diff --git a/src/wormhole/errors.py b/src/wormhole/errors.py index 4d91270..0c140ac 100644 --- a/src/wormhole/errors.py +++ b/src/wormhole/errors.py @@ -20,6 +20,10 @@ def handle_server_error(func): class Timeout(Exception): pass +class WelcomeError(Exception): + """The server told us to signal an error, probably because our version is + too old to possibly work.""" + class WrongPasswordError(Exception): """ Key confirmation failed. Either you or your correspondent typed the code diff --git a/src/wormhole/test/test_server.py b/src/wormhole/test/test_server.py index aeaa494..437e8bd 100644 --- a/src/wormhole/test/test_server.py +++ b/src/wormhole/test/test_server.py @@ -5,7 +5,6 @@ from twisted.trial import unittest from twisted.internet import protocol, reactor, defer from twisted.internet.defer import inlineCallbacks, returnValue from twisted.internet.endpoints import clientFromString, connectProtocol -from twisted.web.client import getPage, Agent, readBody from autobahn.twisted import websocket from .. import __version__ from .common import ServerBase diff --git a/src/wormhole/test/test_wormhole.py b/src/wormhole/test/test_wormhole.py index 1a6d865..ad8e5ef 100644 --- a/src/wormhole/test/test_wormhole.py +++ b/src/wormhole/test/test_wormhole.py @@ -1,5 +1,5 @@ from __future__ import print_function -import os, json, re +import os, json, re, gc from binascii import hexlify, unhexlify import mock from twisted.trial import unittest @@ -84,7 +84,14 @@ class Welcome(unittest.TestCase): self.assertEqual(se.mock_calls, []) w.handle_welcome({u"error": u"oops"}) - self.assertEqual(se.mock_calls, [mock.call(u"oops")]) + self.assertEqual(len(se.mock_calls), 1) + self.assertEqual(len(se.mock_calls[0][1]), 1) # posargs + we = se.mock_calls[0][1][0] + self.assertIsInstance(we, wormhole.WelcomeError) + self.assertEqual(we.args, (u"oops",)) + # alas WelcomeError instances don't compare against each other + #self.assertEqual(se.mock_calls, + # [mock.call(wormhole.WelcomeError(u"oops"))]) class InputCode(unittest.TestCase): def test_list(self): @@ -116,10 +123,10 @@ class GetCode(unittest.TestCase): self.assertEqual(len(pieces), 3) # nameplate plus two words self.assert_(re.search(r'^\d+-\w+-\w+$', code), code) - class Basic(unittest.TestCase): - def test_create(self): - wormhole._Wormhole(APPID, u"relay_url", reactor, None, None) + def tearDown(self): + # flush out any errorful Deferreds left dangling in cycles + gc.collect() def check_out(self, out, **kwargs): # Assert that each kwarg is present in the 'out' dict. Ignore other @@ -135,6 +142,17 @@ class Basic(unittest.TestCase): self.assertEqual(out[i][u"type"], t, (i,t,out)) return out + def make_pake(self, code, side, msg1): + sp2 = SPAKE2_Symmetric(wormhole.to_bytes(code), + idSymmetric=wormhole.to_bytes(APPID)) + msg2 = sp2.start() + msg2_hex = hexlify(msg2).decode("ascii") + key = sp2.finish(msg1) + return key, msg2_hex + + def test_create(self): + wormhole._Wormhole(APPID, u"relay_url", reactor, None, None) + def test_basic(self): # We don't call w._start(), so this doesn't create a WebSocket # connection. We provide a mock connection instead. If we wanted to @@ -201,11 +219,7 @@ class Basic(unittest.TestCase): # next we build the simulated peer's PAKE operation side2 = w._side + u"other" msg1 = unhexlify(out[1][u"body"].encode("ascii")) - sp2 = SPAKE2_Symmetric(wormhole.to_bytes(CODE), - idSymmetric=wormhole.to_bytes(APPID)) - msg2 = sp2.start() - msg2_hex = hexlify(msg2).decode("ascii") - key = sp2.finish(msg1) + key, msg2_hex = self.make_pake(CODE, side2, msg1) response(w, type=u"message", phase=u"pake", body=msg2_hex, side=side2) # hearing the peer's PAKE (msg2) makes us release the nameplate, send @@ -272,6 +286,24 @@ class Basic(unittest.TestCase): self.check_out(out[0], type=u"close", mood=u"happy") self.assertEqual(w._drop_connection.mock_calls, [mock.call()]) + def test_close_wait_0(self): + # close before even claiming the nameplate + timing = DebugTiming() + w = wormhole._Wormhole(APPID, u"relay_url", reactor, None, timing) + w._drop_connection = mock.Mock() + ws = MockWebSocket() + w._event_connected(ws) + w._event_ws_opened(None) + + d = w.close(wait=True) + self.check_outbound(ws, [u"bind"]) + self.assertNoResult(d) + self.assertEqual(w._drop_connection.mock_calls, [mock.call()]) + self.assertNoResult(d) + + w._ws_closed(True, None, None) + self.successResultOf(d) + def test_close_wait_1(self): # close after claiming the nameplate, but before opening the mailbox timing = DebugTiming() @@ -406,6 +438,98 @@ class Basic(unittest.TestCase): self.assertEqual(len(pieces), 3) # nameplate plus two words self.assert_(re.search(r'^\d+-\w+-\w+$', code), code) + def test_api_errors(self): + # doing things you're not supposed to do + pass + + def test_welcome_error(self): + # A welcome message could arrive at any time, with an [error] key + # that should make us halt. In practice, though, this gets sent as + # soon as the connection is established, which limits the possible + # states in which we might see it. + + timing = DebugTiming() + w = wormhole._Wormhole(APPID, u"relay_url", reactor, None, timing) + w._drop_connection = mock.Mock() + ws = MockWebSocket() + w._event_connected(ws) + w._event_ws_opened(None) + self.check_outbound(ws, [u"bind"]) + + WE = wormhole.WelcomeError + d1 = w.get() + d2 = w.get_verifier() + d3 = w.get_code() + # TODO (tricky): test w.input_code + + self.assertNoResult(d1) + self.assertNoResult(d2) + self.assertNoResult(d3) + + w._signal_error(WE(u"you are not actually welcome")) + self.failureResultOf(d1, WE) + self.failureResultOf(d2, WE) + self.failureResultOf(d3, WE) + + # once the error is signalled, all API calls should fail + self.assertRaises(WE, w.send, u"foo") + self.assertRaises(WE, w.derive_key, u"foo") + self.failureResultOf(w.get(), WE) + self.failureResultOf(w.get_verifier(), WE) + + def test_confirm_error(self): + # we should only receive the "confirm" message after we receive the + # PAKE message, by which point we should know the key. If the + # confirmation message doesn't decrypt, we signal an error. + timing = DebugTiming() + w = wormhole._Wormhole(APPID, u"relay_url", reactor, None, timing) + w._drop_connection = mock.Mock() + ws = MockWebSocket() + w._event_connected(ws) + w._event_ws_opened(None) + w.set_code(u"123-foo-bar") + response(w, type=u"claimed", mailbox=u"mb456") + + WP = wormhole.WrongPasswordError + d1 = w.get() + d2 = w.get_verifier() + self.assertNoResult(d1) + self.assertNoResult(d2) + + out = ws.outbound() + # [u"bind", u"claim", u"open", u"add"] + self.assertEqual(len(out), 4) + self.assertEqual(out[3][u"type"], u"add") + + sp2 = SPAKE2_Symmetric(b"", idSymmetric=wormhole.to_bytes(APPID)) + msg2 = sp2.start() + msg2_hex = hexlify(msg2).decode("ascii") + response(w, type=u"message", phase=u"pake", body=msg2_hex, side=u"s2") + self.assertNoResult(d1) + self.successResultOf(d2) # early get_verifier is unaffected + # TODO: get_verifier would be a lovely place to signal a confirmation + # error, but that's at odds with delivering the verifier as early as + # possible. The confirmation messages should be hot on the heels of + # the PAKE message that produced the verifier. Maybe get_verifier() + # should explicitly wait for confirm()? + + # sending a random confirm message will cause a confirmation error + confkey = w.derive_key(u"WRONG") + nonce = os.urandom(wormhole.CONFMSG_NONCE_LENGTH) + badconfirm = wormhole.make_confmsg(confkey, nonce) + badconfirm_hex = hexlify(badconfirm).decode("ascii") + response(w, type=u"message", phase=u"confirm", body=badconfirm_hex, + side=u"s2") + + self.failureResultOf(d1, WP) + + # once the error is signalled, all API calls should fail + self.assertRaises(WP, w.send, u"foo") + self.assertRaises(WP, w.derive_key, u"foo") + self.failureResultOf(w.get(), WP) + self.failureResultOf(w.get_verifier(), WP) + + # event orderings to exercise: # # * normal sender: set_code, send_phase1, connected, claimed, learn_msg2, diff --git a/src/wormhole/wormhole.py b/src/wormhole/wormhole.py index ccb7f9c..53cefca 100644 --- a/src/wormhole/wormhole.py +++ b/src/wormhole/wormhole.py @@ -14,7 +14,7 @@ from spake2 import SPAKE2_Symmetric from . import __version__ from . import codes #from .errors import ServerError, Timeout -from .errors import WrongPasswordError, UsageError +from .errors import WrongPasswordError, UsageError, WelcomeError from .timing import DebugTiming from hkdf import Hkdf @@ -203,7 +203,7 @@ class _WelcomeHandler: self._version_warning_displayed = True if "error" in welcome: - return self._signal_error(welcome["error"]) + return self._signal_error(WelcomeError(welcome["error"])) class _Wormhole: @@ -220,15 +220,16 @@ class _Wormhole: self._connected = None self._connection_waiters = [] self._started_get_code = False + self._get_code = None self._code = None self._nameplate_id = None self._nameplate_claimed = False self._nameplate_released = False - self._release_waiter = defer.Deferred() + self._release_waiter = None self._mailbox_id = None self._mailbox_opened = False self._mailbox_closed = False - self._close_waiter = defer.Deferred() + self._close_waiter = None self._flag_need_nameplate = True self._flag_need_to_see_mailbox_used = True self._flag_need_to_build_msg1 = True @@ -237,6 +238,7 @@ class _Wormhole: self._closed = False self._disconnect_waiter = defer.Deferred() self._mood = u"happy" + self._error = None self._get_verifier_called = False self._verifier_waiter = defer.Deferred() @@ -253,7 +255,24 @@ class _Wormhole: def _signal_error(self, error): # close the mailbox with an "errory" mood, errback all Deferreds, # record the error, fail all subsequent API calls - pass # XXX + if self.DEBUG: print("_signal_error", error) + self._error = error # causes new API calls to fail + for d in self._connection_waiters: + d.errback(error) + if self._get_code: + self._get_code._allocated_d.errback(error) + if not self._verifier_waiter.called: + self._verifier_waiter.errback(error) + for d in self._receive_waiters.values(): + d.errback(error) + + self._maybe_close(mood=u"errory") + if self._release_waiter and not self._release_waiter.called: + self._release_waiter.errback(error) + if self._close_waiter and not self._close_waiter.called: + self._close_waiter.errback(error) + # leave self._disconnect_waiter alone + if self.DEBUG: print("_signal_error done") def _start(self): d = self._connect() # causes stuff to happen @@ -346,8 +365,11 @@ class _Wormhole: with self._timing.add("API get_code"): yield self._when_connected() gc = _GetCode(code_length, self._ws_send_command, self._timing) + self._get_code = gc self._response_handle_allocated = gc._response_handle_allocated + # TODO: signal_error code = yield gc.go() + self._get_code = None self._nameplate_claimed = True # side-effect of allocation self._event_learned_code(code) returnValue(code) @@ -362,6 +384,7 @@ class _Wormhole: yield self._when_connected() ic = _InputCode(prompt, code_length, self._ws_send_command) self._response_handle_nameplates = ic._response_handle_nameplates + # TODO: signal_error code = yield ic.go() self._event_learned_code(code) returnValue(None) @@ -465,9 +488,11 @@ class _Wormhole: self._maybe_send_phase_messages() def get_verifier(self): + if self._error: return defer.fail(self._error) if self._closed: raise UsageError if self._get_verifier_called: raise UsageError self._get_verifier_called = True + # TODO: maybe have this wait on _event_received_confirm too return self._verifier_waiter def _event_computed_verifier(self, verifier): @@ -481,10 +506,12 @@ class _Wormhole: nonce = body[:CONFMSG_NONCE_LENGTH] if body != make_confmsg(confkey, nonce): # this makes all API calls fail + if self.DEBUG: print("CONFIRM FAILED") return self._signal_error(WrongPasswordError()) def send(self, outbound_data): + if self._error: raise self._error if not isinstance(outbound_data, type(b"")): raise TypeError(type(outbound_data)) if self._closed: raise UsageError @@ -540,6 +567,7 @@ class _Wormhole: self._flag_need_to_see_mailbox_used = False def derive_key(self, purpose, length=SecretBox.KEY_SIZE): + if self._error: raise self._error if not isinstance(purpose, type(u"")): raise TypeError(type(purpose)) if self._key is None: raise UsageError # call derive_key after get_verifier() or get() @@ -589,6 +617,7 @@ class _Wormhole: return data def get(self): + if self._error: return defer.fail(self._error) if self._closed: raise UsageError phase = u"%d" % self._next_receive_phase self._next_receive_phase += 1 @@ -598,21 +627,34 @@ class _Wormhole: d = self._receive_waiters[phase] = defer.Deferred() return d + def _maybe_close(self, mood): + if self._closed: + return + self.close(mood) + @inlineCallbacks def close(self, mood=None, wait=False): # TODO: auto-close on error, mostly for load-from-state + if self.DEBUG: print("close", wait) if self._closed: raise UsageError + self._closed = True if mood: self._mood = mood self._maybe_release_nameplate() self._maybe_close_mailbox() if wait: if self._nameplate_claimed: + if self.DEBUG: print("waiting for released") + self._release_waiter = defer.Deferred() yield self._release_waiter if self._mailbox_opened: + if self.DEBUG: print("waiting for closed") + self._close_waiter = defer.Deferred() yield self._close_waiter + if self.DEBUG: print("dropping connection") self._drop_connection() if wait: + if self.DEBUG: print("waiting for disconnect") yield self._disconnect_waiter def _maybe_release_nameplate(self): @@ -623,15 +665,19 @@ class _Wormhole: self._nameplate_released = True def _response_handle_released(self, msg): - self._release_waiter.callback(None) + if self._release_waiter and not self._release_waiter.called: + self._release_waiter.callback(None) def _maybe_close_mailbox(self): + if self.DEBUG: print("_maybe_close_mailbox", self._mailbox_opened, self._mailbox_closed) if self._mailbox_opened and not self._mailbox_closed: + if self.DEBUG: print(" sending close") self._ws_send_command(u"close", mood=self._mood) self._mailbox_closed = True def _response_handle_closed(self, msg): - self._close_waiter.callback(None) + if self._close_waiter and not self._close_waiter.called: + self._close_waiter.callback(None) def _drop_connection(self): self._ws.transport.loseConnection() # probably flushes