diff --git a/src/wormhole/_boss.py b/src/wormhole/_boss.py index 37868d7..a101cb3 100644 --- a/src/wormhole/_boss.py +++ b/src/wormhole/_boss.py @@ -21,7 +21,7 @@ from ._code import Code from ._terminator import Terminator from ._wordlist import PGPWordList from .errors import (ServerError, LonelyError, WrongPasswordError, - KeyFormatError, OnlyOneCodeError) + KeyFormatError, OnlyOneCodeError, _UnknownPhaseError) from .util import bytes_to_dict @attrs @@ -126,11 +126,11 @@ class Boss(object): # would require the Wormhole to be aware of Code (whereas right now # Wormhole only knows about this Boss instance, and everything else is # hidden away). - def input_code(self, helper): + def input_code(self): if self._did_start_code: raise OnlyOneCodeError() self._did_start_code = True - self._C.input_code(helper) + return self._C.input_code() def allocate_code(self, code_length): if self._did_start_code: raise OnlyOneCodeError() @@ -175,17 +175,17 @@ class Boss(object): assert isinstance(phase, type("")), type(phase) assert isinstance(plaintext, type(b"")), type(plaintext) if phase == "version": - self.got_version(plaintext) + self._got_version(plaintext) elif re.search(r'^\d+$', phase): - self.got_phase(int(phase), plaintext) + self._got_phase(int(phase), plaintext) else: # Ignore unrecognized phases, for forwards-compatibility. Use # log.err so tests will catch surprises. - log.err("received unknown phase '%s'" % phase) + log.err(_UnknownPhaseError("received unknown phase '%s'" % phase)) @m.input() - def got_version(self, plaintext): pass + def _got_version(self, plaintext): pass @m.input() - def got_phase(self, phase, plaintext): pass + def _got_phase(self, phase, plaintext): pass @m.input() def got_key(self, key): pass @m.input() @@ -210,7 +210,7 @@ class Boss(object): self._their_versions = bytes_to_dict(plaintext) # but this part is app-to-app app_versions = self._their_versions.get("app_versions", {}) - self._W.got_versions(app_versions) + self._W.got_version(app_versions) @m.output() def S_send(self, plaintext): @@ -279,8 +279,8 @@ class Boss(object): S1_lonely.upon(error, enter=S4_closed, outputs=[W_close_with_error]) S2_happy.upon(rx_welcome, enter=S2_happy, outputs=[process_welcome]) - S2_happy.upon(got_phase, enter=S2_happy, outputs=[W_received]) - S2_happy.upon(got_version, enter=S2_happy, outputs=[process_version]) + S2_happy.upon(_got_phase, enter=S2_happy, outputs=[W_received]) + S2_happy.upon(_got_version, enter=S2_happy, outputs=[process_version]) S2_happy.upon(scared, enter=S3_closing, outputs=[close_scared]) S2_happy.upon(close, enter=S3_closing, outputs=[close_happy]) S2_happy.upon(send, enter=S2_happy, outputs=[S_send]) @@ -289,8 +289,8 @@ class Boss(object): S3_closing.upon(rx_welcome, enter=S3_closing, outputs=[]) S3_closing.upon(rx_error, enter=S3_closing, outputs=[]) - S3_closing.upon(got_phase, enter=S3_closing, outputs=[]) - S3_closing.upon(got_version, enter=S3_closing, outputs=[]) + S3_closing.upon(_got_phase, enter=S3_closing, outputs=[]) + S3_closing.upon(_got_version, enter=S3_closing, outputs=[]) S3_closing.upon(happy, enter=S3_closing, outputs=[]) S3_closing.upon(scared, enter=S3_closing, outputs=[]) S3_closing.upon(close, enter=S3_closing, outputs=[]) @@ -299,8 +299,8 @@ class Boss(object): S3_closing.upon(error, enter=S4_closed, outputs=[W_close_with_error]) S4_closed.upon(rx_welcome, enter=S4_closed, outputs=[]) - S4_closed.upon(got_phase, enter=S4_closed, outputs=[]) - S4_closed.upon(got_version, enter=S4_closed, outputs=[]) + S4_closed.upon(_got_phase, enter=S4_closed, outputs=[]) + S4_closed.upon(_got_version, enter=S4_closed, outputs=[]) S4_closed.upon(happy, enter=S4_closed, outputs=[]) S4_closed.upon(scared, enter=S4_closed, outputs=[]) S4_closed.upon(close, enter=S4_closed, outputs=[]) diff --git a/src/wormhole/_rendezvous.py b/src/wormhole/_rendezvous.py index 2ed6192..a257b78 100644 --- a/src/wormhole/_rendezvous.py +++ b/src/wormhole/_rendezvous.py @@ -8,7 +8,7 @@ from twisted.python import log from twisted.internet import defer, endpoints from twisted.application import internet from autobahn.twisted import websocket -from . import _interfaces +from . import _interfaces, errors from .util import (bytes_to_hexstr, hexstr_to_bytes, bytes_to_dict, dict_to_bytes) @@ -171,7 +171,7 @@ class RendezvousConnector(object): meth = getattr(self, "_response_handle_"+mtype, None) if not meth: # make tests fail, but real application will ignore it - log.err(ValueError("Unknown inbound message type %r" % (msg,))) + log.err(errors._UnknownMessageTypeError("Unknown inbound message type %r" % (msg,))) return try: return meth(msg) diff --git a/src/wormhole/errors.py b/src/wormhole/errors.py index 419605f..1197c26 100644 --- a/src/wormhole/errors.py +++ b/src/wormhole/errors.py @@ -72,3 +72,7 @@ class WormholeClosed(Exception): wormhole was already closed, or if it closes before a real result can be obtained.""" +class _UnknownPhaseError(Exception): + """internal exception type, for tests.""" +class _UnknownMessageTypeError(Exception): + """internal exception type, for tests.""" diff --git a/src/wormhole/test/test_machines.py b/src/wormhole/test/test_machines.py index 3359958..c77bae6 100644 --- a/src/wormhole/test/test_machines.py +++ b/src/wormhole/test/test_machines.py @@ -30,9 +30,11 @@ class Dummy: directlyProvides(self, iface) for meth in meths: self.mock(meth) + self.retval = None def mock(self, meth): def log(*args): self.events.append(("%s.%s" % (self.name, meth),) + args) + return self.retval setattr(self, meth, log) class Send(unittest.TestCase): @@ -1109,13 +1111,13 @@ class Boss(unittest.TestCase): wormhole = Dummy("w", events, None, "got_code", "got_key", "got_verifier", "got_version", "received", "closed") - welcome_handler = mock.Mock() + self._welcome_handler = mock.Mock() versions = {"app": "version1"} reactor = None journal = ImmediateJournal() tor_manager = None b = MockBoss(wormhole, "side", "url", "appid", versions, - welcome_handler, reactor, journal, tor_manager, + self._welcome_handler, reactor, journal, tor_manager, timing.DebugTiming()) t = b._T = Dummy("t", events, ITerminator, "close") s = b._S = Dummy("s", events, ISend, "send") @@ -1134,22 +1136,199 @@ class Boss(unittest.TestCase): self.assertEqual(events, [("w.got_code", "1-code")]) events[:] = [] + b.rx_welcome("welcome") + self.assertEqual(self._welcome_handler.mock_calls, [mock.call("welcome")]) + # pretend a peer message was correctly decrypted + b.got_key(b"key") + b.got_verifier(b"verifier") b.happy() - b.got_version({}) - b.got_phase("phase1", b"msg1") - self.assertEqual(events, [("w.got_version", {}), + b.got_message("version", b"{}") + b.got_message("0", b"msg1") + self.assertEqual(events, [("w.got_key", b"key"), + ("w.got_verifier", b"verifier"), + ("w.got_version", {}), ("w.received", b"msg1"), ]) events[:] = [] + + b.send(b"msg2") + self.assertEqual(events, [("s.send", "0", b"msg2")]) + events[:] = [] + b.close() self.assertEqual(events, [("t.close", "happy")]) events[:] = [] b.closed() - self.assertEqual(events, [("w.closed", "reasons")]) - - + self.assertEqual(events, [("w.closed", "happy")]) + + def test_lonely(self): + b, events = self.build() + b.set_code("1-code") + self.assertEqual(events, [("c.set_code", "1-code")]) + events[:] = [] + + b.got_code("1-code") + self.assertEqual(events, [("w.got_code", "1-code")]) + events[:] = [] + + b.close() + self.assertEqual(events, [("t.close", "lonely")]) + events[:] = [] + + b.closed() + self.assertEqual(len(events), 1, events) + self.assertEqual(events[0][0], "w.closed") + self.assertIsInstance(events[0][1], errors.LonelyError) + + def test_server_error(self): + b, events = self.build() + b.set_code("1-code") + self.assertEqual(events, [("c.set_code", "1-code")]) + events[:] = [] + + orig = {} + b.rx_error("server-error-msg", orig) + self.assertEqual(events, [("t.close", "errory")]) + events[:] = [] + + b.closed() + self.assertEqual(len(events), 1, events) + self.assertEqual(events[0][0], "w.closed") + self.assertIsInstance(events[0][1], errors.ServerError) + self.assertEqual(events[0][1].args[0], "server-error-msg") + + def test_internal_error(self): + b, events = self.build() + b.set_code("1-code") + self.assertEqual(events, [("c.set_code", "1-code")]) + events[:] = [] + + b.error(ValueError("catch me")) + self.assertEqual(len(events), 1, events) + self.assertEqual(events[0][0], "w.closed") + self.assertIsInstance(events[0][1], ValueError) + self.assertEqual(events[0][1].args[0], "catch me") + + def test_close_early(self): + b, events = self.build() + b.set_code("1-code") + self.assertEqual(events, [("c.set_code", "1-code")]) + events[:] = [] + + b.close() # before even w.got_code + self.assertEqual(events, [("t.close", "lonely")]) + events[:] = [] + + b.closed() + self.assertEqual(len(events), 1, events) + self.assertEqual(events[0][0], "w.closed") + self.assertIsInstance(events[0][1], errors.LonelyError) + + def test_error_while_closing(self): + b, events = self.build() + b.set_code("1-code") + self.assertEqual(events, [("c.set_code", "1-code")]) + events[:] = [] + + b.close() + self.assertEqual(events, [("t.close", "lonely")]) + events[:] = [] + + b.error(ValueError("oops")) + self.assertEqual(len(events), 1, events) + self.assertEqual(events[0][0], "w.closed") + self.assertIsInstance(events[0][1], ValueError) + + def test_scary_version(self): + b, events = self.build() + b.set_code("1-code") + self.assertEqual(events, [("c.set_code", "1-code")]) + events[:] = [] + + b.got_code("1-code") + self.assertEqual(events, [("w.got_code", "1-code")]) + events[:] = [] + + b.scared() + self.assertEqual(events, [("t.close", "scary")]) + events[:] = [] + + b.closed() + self.assertEqual(len(events), 1, events) + self.assertEqual(events[0][0], "w.closed") + self.assertIsInstance(events[0][1], errors.WrongPasswordError) + + def test_scary_phase(self): + b, events = self.build() + b.set_code("1-code") + self.assertEqual(events, [("c.set_code", "1-code")]) + events[:] = [] + + b.got_code("1-code") + self.assertEqual(events, [("w.got_code", "1-code")]) + events[:] = [] + + b.happy() # phase=version + + b.scared() # phase=0 + self.assertEqual(events, [("t.close", "scary")]) + events[:] = [] + + b.closed() + self.assertEqual(len(events), 1, events) + self.assertEqual(events[0][0], "w.closed") + self.assertIsInstance(events[0][1], errors.WrongPasswordError) + + def test_unknown_phase(self): + b, events = self.build() + b.set_code("1-code") + self.assertEqual(events, [("c.set_code", "1-code")]) + events[:] = [] + + b.got_code("1-code") + self.assertEqual(events, [("w.got_code", "1-code")]) + events[:] = [] + + b.happy() # phase=version + + b.got_message("unknown-phase", b"spooky") + self.assertEqual(events, []) + + self.flushLoggedErrors(errors._UnknownPhaseError) + + def test_set_code_bad_format(self): + b, events = self.build() + with self.assertRaises(errors.KeyFormatError): + b.set_code("1 code") + + def test_set_code_bad_twice(self): + b, events = self.build() + b.set_code("1-code") + with self.assertRaises(errors.OnlyOneCodeError): + b.set_code("1-code") + + def test_input_code(self): + b, events = self.build() + b._C.retval = "helper" + helper = b.input_code() + self.assertEqual(events, [("c.input_code",)]) + self.assertEqual(helper, "helper") + with self.assertRaises(errors.OnlyOneCodeError): + b.input_code() + + def test_allocate_code(self): + b, events = self.build() + wl = object() + with mock.patch("wormhole._boss.PGPWordList", return_value=wl): + b.allocate_code(3) + self.assertEqual(events, [("c.allocate_code", 3, wl)]) + with self.assertRaises(errors.OnlyOneCodeError): + b.allocate_code(3) + + + # TODO # #Send diff --git a/src/wormhole/wormhole.py b/src/wormhole/wormhole.py index f730195..dc3daf9 100644 --- a/src/wormhole/wormhole.py +++ b/src/wormhole/wormhole.py @@ -124,7 +124,7 @@ class _DelegatedWormhole(object): self._key = key # for derive_key() def got_verifier(self, verifier): self._delegate.wormhole_verified(verifier) - def got_versions(self, versions): + def got_version(self, versions): self._delegate.wormhole_version(versions) def received(self, plaintext): self._delegate.wormhole_received(plaintext) @@ -191,8 +191,8 @@ class _DeferredWormhole(object): def allocate_code(self, code_length=2): self._boss.allocate_code(code_length) - def input_code(self, stdio): # TODO - self._boss.input_code(stdio) + def input_code(self): + return self._boss.input_code() def set_code(self, code): self._boss.set_code(code) @@ -241,7 +241,7 @@ class _DeferredWormhole(object): for d in self._verifier_observers: d.callback(verifier) self._verifier_observers[:] = [] - def got_versions(self, versions): + def got_version(self, versions): self._versions = versions for d in self._version_observers: d.callback(versions)