finish Boss tests

This commit is contained in:
Brian Warner 2017-03-19 15:09:26 -07:00
parent d8d305407b
commit 53a911cc80
5 changed files with 212 additions and 29 deletions

View File

@ -21,7 +21,7 @@ from ._code import Code
from ._terminator import Terminator
from ._wordlist import PGPWordList
from .errors import (ServerError, LonelyError, WrongPasswordError,
KeyFormatError, OnlyOneCodeError)
KeyFormatError, OnlyOneCodeError, _UnknownPhaseError)
from .util import bytes_to_dict
@attrs
@ -126,11 +126,11 @@ class Boss(object):
# would require the Wormhole to be aware of Code (whereas right now
# Wormhole only knows about this Boss instance, and everything else is
# hidden away).
def input_code(self, helper):
def input_code(self):
if self._did_start_code:
raise OnlyOneCodeError()
self._did_start_code = True
self._C.input_code(helper)
return self._C.input_code()
def allocate_code(self, code_length):
if self._did_start_code:
raise OnlyOneCodeError()
@ -175,17 +175,17 @@ class Boss(object):
assert isinstance(phase, type("")), type(phase)
assert isinstance(plaintext, type(b"")), type(plaintext)
if phase == "version":
self.got_version(plaintext)
self._got_version(plaintext)
elif re.search(r'^\d+$', phase):
self.got_phase(int(phase), plaintext)
self._got_phase(int(phase), plaintext)
else:
# Ignore unrecognized phases, for forwards-compatibility. Use
# log.err so tests will catch surprises.
log.err("received unknown phase '%s'" % phase)
log.err(_UnknownPhaseError("received unknown phase '%s'" % phase))
@m.input()
def got_version(self, plaintext): pass
def _got_version(self, plaintext): pass
@m.input()
def got_phase(self, phase, plaintext): pass
def _got_phase(self, phase, plaintext): pass
@m.input()
def got_key(self, key): pass
@m.input()
@ -210,7 +210,7 @@ class Boss(object):
self._their_versions = bytes_to_dict(plaintext)
# but this part is app-to-app
app_versions = self._their_versions.get("app_versions", {})
self._W.got_versions(app_versions)
self._W.got_version(app_versions)
@m.output()
def S_send(self, plaintext):
@ -279,8 +279,8 @@ class Boss(object):
S1_lonely.upon(error, enter=S4_closed, outputs=[W_close_with_error])
S2_happy.upon(rx_welcome, enter=S2_happy, outputs=[process_welcome])
S2_happy.upon(got_phase, enter=S2_happy, outputs=[W_received])
S2_happy.upon(got_version, enter=S2_happy, outputs=[process_version])
S2_happy.upon(_got_phase, enter=S2_happy, outputs=[W_received])
S2_happy.upon(_got_version, enter=S2_happy, outputs=[process_version])
S2_happy.upon(scared, enter=S3_closing, outputs=[close_scared])
S2_happy.upon(close, enter=S3_closing, outputs=[close_happy])
S2_happy.upon(send, enter=S2_happy, outputs=[S_send])
@ -289,8 +289,8 @@ class Boss(object):
S3_closing.upon(rx_welcome, enter=S3_closing, outputs=[])
S3_closing.upon(rx_error, enter=S3_closing, outputs=[])
S3_closing.upon(got_phase, enter=S3_closing, outputs=[])
S3_closing.upon(got_version, enter=S3_closing, outputs=[])
S3_closing.upon(_got_phase, enter=S3_closing, outputs=[])
S3_closing.upon(_got_version, enter=S3_closing, outputs=[])
S3_closing.upon(happy, enter=S3_closing, outputs=[])
S3_closing.upon(scared, enter=S3_closing, outputs=[])
S3_closing.upon(close, enter=S3_closing, outputs=[])
@ -299,8 +299,8 @@ class Boss(object):
S3_closing.upon(error, enter=S4_closed, outputs=[W_close_with_error])
S4_closed.upon(rx_welcome, enter=S4_closed, outputs=[])
S4_closed.upon(got_phase, enter=S4_closed, outputs=[])
S4_closed.upon(got_version, enter=S4_closed, outputs=[])
S4_closed.upon(_got_phase, enter=S4_closed, outputs=[])
S4_closed.upon(_got_version, enter=S4_closed, outputs=[])
S4_closed.upon(happy, enter=S4_closed, outputs=[])
S4_closed.upon(scared, enter=S4_closed, outputs=[])
S4_closed.upon(close, enter=S4_closed, outputs=[])

View File

@ -8,7 +8,7 @@ from twisted.python import log
from twisted.internet import defer, endpoints
from twisted.application import internet
from autobahn.twisted import websocket
from . import _interfaces
from . import _interfaces, errors
from .util import (bytes_to_hexstr, hexstr_to_bytes,
bytes_to_dict, dict_to_bytes)
@ -171,7 +171,7 @@ class RendezvousConnector(object):
meth = getattr(self, "_response_handle_"+mtype, None)
if not meth:
# make tests fail, but real application will ignore it
log.err(ValueError("Unknown inbound message type %r" % (msg,)))
log.err(errors._UnknownMessageTypeError("Unknown inbound message type %r" % (msg,)))
return
try:
return meth(msg)

View File

@ -72,3 +72,7 @@ class WormholeClosed(Exception):
wormhole was already closed, or if it closes before a real result can be
obtained."""
class _UnknownPhaseError(Exception):
"""internal exception type, for tests."""
class _UnknownMessageTypeError(Exception):
"""internal exception type, for tests."""

View File

@ -30,9 +30,11 @@ class Dummy:
directlyProvides(self, iface)
for meth in meths:
self.mock(meth)
self.retval = None
def mock(self, meth):
def log(*args):
self.events.append(("%s.%s" % (self.name, meth),) + args)
return self.retval
setattr(self, meth, log)
class Send(unittest.TestCase):
@ -1109,13 +1111,13 @@ class Boss(unittest.TestCase):
wormhole = Dummy("w", events, None,
"got_code", "got_key", "got_verifier", "got_version",
"received", "closed")
welcome_handler = mock.Mock()
self._welcome_handler = mock.Mock()
versions = {"app": "version1"}
reactor = None
journal = ImmediateJournal()
tor_manager = None
b = MockBoss(wormhole, "side", "url", "appid", versions,
welcome_handler, reactor, journal, tor_manager,
self._welcome_handler, reactor, journal, tor_manager,
timing.DebugTiming())
t = b._T = Dummy("t", events, ITerminator, "close")
s = b._S = Dummy("s", events, ISend, "send")
@ -1134,22 +1136,199 @@ class Boss(unittest.TestCase):
self.assertEqual(events, [("w.got_code", "1-code")])
events[:] = []
b.rx_welcome("welcome")
self.assertEqual(self._welcome_handler.mock_calls, [mock.call("welcome")])
# pretend a peer message was correctly decrypted
b.got_key(b"key")
b.got_verifier(b"verifier")
b.happy()
b.got_version({})
b.got_phase("phase1", b"msg1")
self.assertEqual(events, [("w.got_version", {}),
b.got_message("version", b"{}")
b.got_message("0", b"msg1")
self.assertEqual(events, [("w.got_key", b"key"),
("w.got_verifier", b"verifier"),
("w.got_version", {}),
("w.received", b"msg1"),
])
events[:] = []
b.send(b"msg2")
self.assertEqual(events, [("s.send", "0", b"msg2")])
events[:] = []
b.close()
self.assertEqual(events, [("t.close", "happy")])
events[:] = []
b.closed()
self.assertEqual(events, [("w.closed", "reasons")])
self.assertEqual(events, [("w.closed", "happy")])
def test_lonely(self):
b, events = self.build()
b.set_code("1-code")
self.assertEqual(events, [("c.set_code", "1-code")])
events[:] = []
b.got_code("1-code")
self.assertEqual(events, [("w.got_code", "1-code")])
events[:] = []
b.close()
self.assertEqual(events, [("t.close", "lonely")])
events[:] = []
b.closed()
self.assertEqual(len(events), 1, events)
self.assertEqual(events[0][0], "w.closed")
self.assertIsInstance(events[0][1], errors.LonelyError)
def test_server_error(self):
b, events = self.build()
b.set_code("1-code")
self.assertEqual(events, [("c.set_code", "1-code")])
events[:] = []
orig = {}
b.rx_error("server-error-msg", orig)
self.assertEqual(events, [("t.close", "errory")])
events[:] = []
b.closed()
self.assertEqual(len(events), 1, events)
self.assertEqual(events[0][0], "w.closed")
self.assertIsInstance(events[0][1], errors.ServerError)
self.assertEqual(events[0][1].args[0], "server-error-msg")
def test_internal_error(self):
b, events = self.build()
b.set_code("1-code")
self.assertEqual(events, [("c.set_code", "1-code")])
events[:] = []
b.error(ValueError("catch me"))
self.assertEqual(len(events), 1, events)
self.assertEqual(events[0][0], "w.closed")
self.assertIsInstance(events[0][1], ValueError)
self.assertEqual(events[0][1].args[0], "catch me")
def test_close_early(self):
b, events = self.build()
b.set_code("1-code")
self.assertEqual(events, [("c.set_code", "1-code")])
events[:] = []
b.close() # before even w.got_code
self.assertEqual(events, [("t.close", "lonely")])
events[:] = []
b.closed()
self.assertEqual(len(events), 1, events)
self.assertEqual(events[0][0], "w.closed")
self.assertIsInstance(events[0][1], errors.LonelyError)
def test_error_while_closing(self):
b, events = self.build()
b.set_code("1-code")
self.assertEqual(events, [("c.set_code", "1-code")])
events[:] = []
b.close()
self.assertEqual(events, [("t.close", "lonely")])
events[:] = []
b.error(ValueError("oops"))
self.assertEqual(len(events), 1, events)
self.assertEqual(events[0][0], "w.closed")
self.assertIsInstance(events[0][1], ValueError)
def test_scary_version(self):
b, events = self.build()
b.set_code("1-code")
self.assertEqual(events, [("c.set_code", "1-code")])
events[:] = []
b.got_code("1-code")
self.assertEqual(events, [("w.got_code", "1-code")])
events[:] = []
b.scared()
self.assertEqual(events, [("t.close", "scary")])
events[:] = []
b.closed()
self.assertEqual(len(events), 1, events)
self.assertEqual(events[0][0], "w.closed")
self.assertIsInstance(events[0][1], errors.WrongPasswordError)
def test_scary_phase(self):
b, events = self.build()
b.set_code("1-code")
self.assertEqual(events, [("c.set_code", "1-code")])
events[:] = []
b.got_code("1-code")
self.assertEqual(events, [("w.got_code", "1-code")])
events[:] = []
b.happy() # phase=version
b.scared() # phase=0
self.assertEqual(events, [("t.close", "scary")])
events[:] = []
b.closed()
self.assertEqual(len(events), 1, events)
self.assertEqual(events[0][0], "w.closed")
self.assertIsInstance(events[0][1], errors.WrongPasswordError)
def test_unknown_phase(self):
b, events = self.build()
b.set_code("1-code")
self.assertEqual(events, [("c.set_code", "1-code")])
events[:] = []
b.got_code("1-code")
self.assertEqual(events, [("w.got_code", "1-code")])
events[:] = []
b.happy() # phase=version
b.got_message("unknown-phase", b"spooky")
self.assertEqual(events, [])
self.flushLoggedErrors(errors._UnknownPhaseError)
def test_set_code_bad_format(self):
b, events = self.build()
with self.assertRaises(errors.KeyFormatError):
b.set_code("1 code")
def test_set_code_bad_twice(self):
b, events = self.build()
b.set_code("1-code")
with self.assertRaises(errors.OnlyOneCodeError):
b.set_code("1-code")
def test_input_code(self):
b, events = self.build()
b._C.retval = "helper"
helper = b.input_code()
self.assertEqual(events, [("c.input_code",)])
self.assertEqual(helper, "helper")
with self.assertRaises(errors.OnlyOneCodeError):
b.input_code()
def test_allocate_code(self):
b, events = self.build()
wl = object()
with mock.patch("wormhole._boss.PGPWordList", return_value=wl):
b.allocate_code(3)
self.assertEqual(events, [("c.allocate_code", 3, wl)])
with self.assertRaises(errors.OnlyOneCodeError):
b.allocate_code(3)
# TODO
# #Send

View File

@ -124,7 +124,7 @@ class _DelegatedWormhole(object):
self._key = key # for derive_key()
def got_verifier(self, verifier):
self._delegate.wormhole_verified(verifier)
def got_versions(self, versions):
def got_version(self, versions):
self._delegate.wormhole_version(versions)
def received(self, plaintext):
self._delegate.wormhole_received(plaintext)
@ -191,8 +191,8 @@ class _DeferredWormhole(object):
def allocate_code(self, code_length=2):
self._boss.allocate_code(code_length)
def input_code(self, stdio): # TODO
self._boss.input_code(stdio)
def input_code(self):
return self._boss.input_code()
def set_code(self, code):
self._boss.set_code(code)
@ -241,7 +241,7 @@ class _DeferredWormhole(object):
for d in self._verifier_observers:
d.callback(verifier)
self._verifier_observers[:] = []
def got_versions(self, versions):
def got_version(self, versions):
self._versions = versions
for d in self._version_observers:
d.callback(versions)