From e11a6f82431c64b33ab0bad8854e85f8a4cb85fc Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Mon, 23 May 2016 22:53:00 -0700 Subject: [PATCH] new connection management, test_wormhole passes --- events.dot | 67 +++- src/wormhole/errors.py | 3 + src/wormhole/server/rendezvous_websocket.py | 2 +- src/wormhole/test/test_wormhole.py | 359 +++++++++++--------- src/wormhole/wormhole.py | 306 +++++++++++------ 5 files changed, 459 insertions(+), 278 deletions(-) diff --git a/events.dot b/events.dot index 66ea93b..f6aa1c3 100644 --- a/events.dot +++ b/events.dot @@ -1,31 +1,39 @@ digraph { + api_get_code [label="get_code" shape="hexagon" color="red"] + api_input_code [label="input_code" shape="hexagon" color="red"] + api_set_code [label="set_code" shape="hexagon" color="red"] + send [label="API\nsend" shape="hexagon" color="red"] + get [label="API\nget" shape="hexagon" color="red"] + close [label="API\nclose" shape="hexagon" color="red"] + + event_connected [label="connected" shape="box"] event_learned_code [label="learned\ncode" shape="box"] event_learned_nameplate [label="learned\nnameplate" shape="box"] - event_learned_mailbox [label="learned\nmailbox" shape="box"] - event_connected [label="connected" shape="box"] + event_received_mailbox [label="received\nmailbox" shape="box"] + event_opened_mailbox [label="opened\nmailbox" shape="box"] event_built_msg1 [label="built\nmsg1" shape="box"] event_mailbox_used [label="mailbox\nused" shape="box"] event_learned_PAKE [label="learned\nmsg2" shape="box"] event_established_key [label="established\nkey" shape="box"] event_computed_verifier [label="computed\nverifier" shape="box"] event_received_confirm [label="received\nconfirm" shape="box"] + event_received_message [label="received\nmessage" shape="box"] + event_received_released [label="ack\nreleased" shape="box"] + event_received_closed [label="ack\nclosed" shape="box"] event_connected -> api_get_code event_connected -> api_input_code - api_get_code [label="get_code" shape="hexagon"] - api_input_code [label="input_code" shape="hexagon"] - api_set_code [label="set_code" shape="hexagon"] api_get_code -> event_learned_code api_input_code -> event_learned_code api_set_code -> event_learned_code maybe_build_msg1 [label="build\nmsg1"] - maybe_get_mailbox [label="get\nmailbox"] + maybe_claim_nameplate [label="claim\nnameplate"] maybe_send_pake [label="send\npake"] maybe_send_phase_messages [label="send\nphase\nmessages"] - event_connected -> maybe_get_mailbox + event_connected -> maybe_claim_nameplate event_connected -> maybe_send_pake event_built_msg1 -> maybe_send_pake @@ -34,22 +42,23 @@ digraph { event_learned_code -> event_learned_nameplate maybe_build_msg1 -> event_built_msg1 - event_learned_nameplate -> maybe_get_mailbox + event_learned_nameplate -> maybe_claim_nameplate + maybe_claim_nameplate -> event_received_mailbox [style="dashed"] - maybe_get_mailbox -> event_learned_mailbox [style="dashed"] - maybe_get_mailbox -> event_mailbox_used [style="dashed"] - maybe_get_mailbox -> event_learned_PAKE [style="dashed"] - maybe_get_mailbox -> event_received_confirm [style="dashed"] + event_received_mailbox -> event_opened_mailbox + maybe_claim_nameplate -> event_learned_PAKE [style="dashed"] + maybe_claim_nameplate -> event_received_confirm [style="dashed"] - event_learned_mailbox -> event_learned_PAKE [style="dashed"] + event_opened_mailbox -> event_learned_PAKE [style="dashed"] event_learned_PAKE -> event_mailbox_used [style="dashed"] - event_mailbox_used -> event_received_confirm [style="dashed"] + event_learned_PAKE -> event_received_confirm [style="dashed"] + event_received_confirm -> event_received_message [style="dashed"] - send [label="API\nsend" shape="hexagon"] send -> maybe_send_phase_messages - event_mailbox_used -> release - event_learned_mailbox -> maybe_send_pake - event_learned_mailbox -> maybe_send_phase_messages + release_nameplate [label="release\nnameplate"] + event_mailbox_used -> release_nameplate + event_opened_mailbox -> maybe_send_pake + event_opened_mailbox -> maybe_send_phase_messages event_learned_PAKE -> event_established_key event_established_key -> event_computed_verifier @@ -59,4 +68,26 @@ digraph { event_computed_verifier -> check_verifier event_received_confirm -> check_verifier + check_verifier -> error + event_received_message -> error + event_received_message -> get + event_established_key -> get + + close -> close_mailbox + close -> release_nameplate + error [label="signal\nerror"] + error -> close_mailbox + error -> release_nameplate + + release_nameplate -> event_received_released [style="dashed"] + close_mailbox [label="close\nmailbox"] + close_mailbox -> event_received_closed [style="dashed"] + + maybe_close_websocket [label="close\nwebsocket"] + event_received_released -> maybe_close_websocket + event_received_closed -> maybe_close_websocket + maybe_close_websocket -> event_websocket_closed [style="dashed"] + event_websocket_closed [label="websocket\nclosed"] + + } diff --git a/src/wormhole/errors.py b/src/wormhole/errors.py index 0c140ac..141523d 100644 --- a/src/wormhole/errors.py +++ b/src/wormhole/errors.py @@ -41,5 +41,8 @@ class ReflectionAttack(Exception): class UsageError(Exception): """The programmer did something wrong.""" +class WormholeClosedError(UsageError): + """API calls may not be made after close() is called.""" + class TransferError(Exception): """Something bad happened and the transfer failed.""" diff --git a/src/wormhole/server/rendezvous_websocket.py b/src/wormhole/server/rendezvous_websocket.py index 05e2700..2a19ab2 100644 --- a/src/wormhole/server/rendezvous_websocket.py +++ b/src/wormhole/server/rendezvous_websocket.py @@ -68,7 +68,7 @@ from .rendezvous import CrowdedError, SidedMessage # -> {type: "add", phase: str, body: hex} # will send echo in a "message" # # -> {type: "close", mood: str} -> closed -# <- {type: "closed", status: waiting|deleted} +# <- {type: "closed"} # # <- {type: "error", error: str, orig: {}} # in response to malformed msgs diff --git a/src/wormhole/test/test_wormhole.py b/src/wormhole/test/test_wormhole.py index ad8e5ef..c0f2823 100644 --- a/src/wormhole/test/test_wormhole.py +++ b/src/wormhole/test/test_wormhole.py @@ -7,8 +7,10 @@ from twisted.internet import reactor from twisted.internet.defer import Deferred, gatherResults, inlineCallbacks from .common import ServerBase from .. import wormhole +from ..errors import WrongPasswordError, WelcomeError, UsageError from spake2 import SPAKE2_Symmetric from ..timing import DebugTiming +from nacl.secret import SecretBox APPID = u"appid" @@ -87,11 +89,10 @@ class Welcome(unittest.TestCase): self.assertEqual(len(se.mock_calls), 1) self.assertEqual(len(se.mock_calls[0][1]), 1) # posargs we = se.mock_calls[0][1][0] - self.assertIsInstance(we, wormhole.WelcomeError) + self.assertIsInstance(we, WelcomeError) self.assertEqual(we.args, (u"oops",)) # alas WelcomeError instances don't compare against each other - #self.assertEqual(se.mock_calls, - # [mock.call(wormhole.WelcomeError(u"oops"))]) + #self.assertEqual(se.mock_calls, [mock.call(WelcomeError(u"oops"))]) class InputCode(unittest.TestCase): def test_list(self): @@ -171,7 +172,7 @@ class Basic(unittest.TestCase): self.assertTrue(w._flag_need_to_build_msg1) self.assertTrue(w._flag_need_to_send_PAKE) - v = w.get_verifier() + v = w.verify() w._drop_connection = mock.Mock() ws = MockWebSocket() @@ -204,7 +205,7 @@ class Basic(unittest.TestCase): # that triggers event_learned_mailbox, which should send open() and # PAKE - self.assertTrue(w._mailbox_opened) + self.assertEqual(w._mailbox_state, wormhole.OPEN) out = ws.outbound() self.assertEqual(len(out), 2) self.check_out(out[0], type=u"open", mailbox=u"mb456") @@ -232,10 +233,11 @@ class Basic(unittest.TestCase): self.check_out(out[0], type=u"release") self.check_out(out[1], type=u"add", phase=u"confirm") verifier = self.successResultOf(v) - self.assertEqual(verifier, w.derive_key(u"wormhole:verifier")) + self.assertEqual(verifier, + w.derive_key(u"wormhole:verifier", SecretBox.KEY_SIZE)) # hearing a valid confirmation message doesn't throw an error - confkey = w.derive_key(u"wormhole:confirmation") + confkey = w.derive_key(u"wormhole:confirmation", SecretBox.KEY_SIZE) nonce = os.urandom(wormhole.CONFMSG_NONCE_LENGTH) confirm2 = wormhole.make_confmsg(confkey, nonce) confirm2_hex = hexlify(confirm2).decode("ascii") @@ -249,7 +251,7 @@ class Basic(unittest.TestCase): self.check_out(out[0], type=u"add", phase=u"0") # decrypt+check the outbound message p0_outbound = unhexlify(out[0][u"body"].encode("ascii")) - msgkey0 = w.derive_key(u"wormhole:phase:0") + msgkey0 = w.derive_key(u"wormhole:phase:0", SecretBox.KEY_SIZE) p0_plaintext = w._decrypt_data(msgkey0, p0_outbound) self.assertEqual(p0_plaintext, b"phase0-outbound") @@ -268,7 +270,7 @@ class Basic(unittest.TestCase): self.assertIn(u"0", w._received_messages) # receiving an inbound message will queue it until get() is called - msgkey1 = w.derive_key(u"wormhole:phase:1") + msgkey1 = w.derive_key(u"wormhole:phase:1", SecretBox.KEY_SIZE) p1_inbound = w._encrypt_data(msgkey1, b"phase1-inbound") p1_inbound_hex = hexlify(p1_inbound).decode("ascii") response(w, type=u"message", phase=u"1", body=p1_inbound_hex, @@ -284,9 +286,34 @@ class Basic(unittest.TestCase): out = ws.outbound() self.assertEqual(len(out), 1) self.check_out(out[0], type=u"close", mood=u"happy") + self.assertEqual(w._drop_connection.mock_calls, []) + + response(w, type=u"released") + self.assertEqual(w._drop_connection.mock_calls, []) + response(w, type=u"closed") self.assertEqual(w._drop_connection.mock_calls, [mock.call()]) + w._ws_closed(True, None, None) def test_close_wait_0(self): + # Close before the connection is established. The connection still + # gets established, but it is then torn down before sending anything. + timing = DebugTiming() + w = wormhole._Wormhole(APPID, u"relay_url", reactor, None, timing) + w._drop_connection = mock.Mock() + + d = w.close(wait=True) + self.assertNoResult(d) + + ws = MockWebSocket() + w._event_connected(ws) + w._event_ws_opened(None) + self.assertEqual(w._drop_connection.mock_calls, [mock.call()]) + self.assertNoResult(d) + + w._ws_closed(True, None, None) + self.successResultOf(d) + + def test_close_wait_1(self): # close before even claiming the nameplate timing = DebugTiming() w = wormhole._Wormhole(APPID, u"relay_url", reactor, None, timing) @@ -304,8 +331,40 @@ class Basic(unittest.TestCase): w._ws_closed(True, None, None) self.successResultOf(d) - def test_close_wait_1(self): + def test_close_wait_2(self): + # Close after claiming the nameplate, but before opening the mailbox. + # The 'claimed' response arrives before we close. + timing = DebugTiming() + w = wormhole._Wormhole(APPID, u"relay_url", reactor, None, timing) + w._drop_connection = mock.Mock() + ws = MockWebSocket() + w._event_connected(ws) + w._event_ws_opened(None) + CODE = u"123-foo-bar" + w.set_code(CODE) + self.check_outbound(ws, [u"bind", u"claim"]) + + response(w, type=u"claimed", mailbox=u"mb123") + + d = w.close(wait=True) + self.check_outbound(ws, [u"open", u"add", u"release", u"close"]) + self.assertNoResult(d) + self.assertEqual(w._drop_connection.mock_calls, []) + + response(w, type=u"released") + self.assertNoResult(d) + self.assertEqual(w._drop_connection.mock_calls, []) + + response(w, type=u"closed") + self.assertEqual(w._drop_connection.mock_calls, [mock.call()]) + self.assertNoResult(d) + + w._ws_closed(True, None, None) + self.successResultOf(d) + + def test_close_wait_3(self): # close after claiming the nameplate, but before opening the mailbox + # The 'claimed' response arrives after we start to close. timing = DebugTiming() w = wormhole._Wormhole(APPID, u"relay_url", reactor, None, timing) w._drop_connection = mock.Mock() @@ -317,6 +376,7 @@ class Basic(unittest.TestCase): self.check_outbound(ws, [u"bind", u"claim"]) d = w.close(wait=True) + response(w, type=u"claimed", mailbox=u"mb123") self.check_outbound(ws, [u"release"]) self.assertNoResult(d) self.assertEqual(w._drop_connection.mock_calls, []) @@ -328,7 +388,7 @@ class Basic(unittest.TestCase): w._ws_closed(True, None, None) self.successResultOf(d) - def test_close_wait_2(self): + def test_close_wait_4(self): # close after both claiming the nameplate and opening the mailbox timing = DebugTiming() w = wormhole._Wormhole(APPID, u"relay_url", reactor, None, timing) @@ -357,7 +417,7 @@ class Basic(unittest.TestCase): w._ws_closed(True, None, None) self.successResultOf(d) - def test_close_wait_3(self): + def test_close_wait_5(self): # close after claiming the nameplate, opening the mailbox, then # releasing the nameplate timing = DebugTiming() @@ -371,7 +431,7 @@ class Basic(unittest.TestCase): response(w, type=u"claimed", mailbox=u"mb456") w._key = b"" - msgkey = w.derive_key(u"wormhole:phase:misc") + msgkey = w.derive_key(u"wormhole:phase:misc", SecretBox.KEY_SIZE) p1_inbound = w._encrypt_data(msgkey, b"") p1_inbound_hex = hexlify(p1_inbound).decode("ascii") response(w, type=u"message", phase=u"misc", side=u"side2", @@ -395,6 +455,11 @@ class Basic(unittest.TestCase): w._ws_closed(True, None, None) self.successResultOf(d) + def test_close_errbacks(self): + # make sure the Deferreds returned by verify() and get() are properly + # errbacked upon close + pass + def test_get_code_mock(self): timing = DebugTiming() w = wormhole._Wormhole(APPID, u"relay_url", reactor, None, timing) @@ -438,6 +503,11 @@ class Basic(unittest.TestCase): self.assertEqual(len(pieces), 3) # nameplate plus two words self.assert_(re.search(r'^\d+-\w+-\w+$', code), code) + def test_verifier(self): + # make sure verify() can be called both before and after the verifier + # is computed + pass + def test_api_errors(self): # doing things you're not supposed to do pass @@ -456,9 +526,8 @@ class Basic(unittest.TestCase): w._event_ws_opened(None) self.check_outbound(ws, [u"bind"]) - WE = wormhole.WelcomeError d1 = w.get() - d2 = w.get_verifier() + d2 = w.verify() d3 = w.get_code() # TODO (tricky): test w.input_code @@ -466,16 +535,17 @@ class Basic(unittest.TestCase): self.assertNoResult(d2) self.assertNoResult(d3) - w._signal_error(WE(u"you are not actually welcome")) - self.failureResultOf(d1, WE) - self.failureResultOf(d2, WE) - self.failureResultOf(d3, WE) + w._signal_error(WelcomeError(u"you are not actually welcome"), u"pouty") + self.failureResultOf(d1, WelcomeError) + self.failureResultOf(d2, WelcomeError) + self.failureResultOf(d3, WelcomeError) # once the error is signalled, all API calls should fail - self.assertRaises(WE, w.send, u"foo") - self.assertRaises(WE, w.derive_key, u"foo") - self.failureResultOf(w.get(), WE) - self.failureResultOf(w.get_verifier(), WE) + self.assertRaises(WelcomeError, w.send, u"foo") + self.assertRaises(WelcomeError, + w.derive_key, u"foo", SecretBox.KEY_SIZE) + self.failureResultOf(w.get(), WelcomeError) + self.failureResultOf(w.verify(), WelcomeError) def test_confirm_error(self): # we should only receive the "confirm" message after we receive the @@ -490,9 +560,8 @@ class Basic(unittest.TestCase): w.set_code(u"123-foo-bar") response(w, type=u"claimed", mailbox=u"mb456") - WP = wormhole.WrongPasswordError d1 = w.get() - d2 = w.get_verifier() + d2 = w.verify() self.assertNoResult(d1) self.assertNoResult(d2) @@ -506,28 +575,25 @@ class Basic(unittest.TestCase): msg2_hex = hexlify(msg2).decode("ascii") response(w, type=u"message", phase=u"pake", body=msg2_hex, side=u"s2") self.assertNoResult(d1) - self.successResultOf(d2) # early get_verifier is unaffected - # TODO: get_verifier would be a lovely place to signal a confirmation - # error, but that's at odds with delivering the verifier as early as - # possible. The confirmation messages should be hot on the heels of - # the PAKE message that produced the verifier. Maybe get_verifier() - # should explicitly wait for confirm()? + self.successResultOf(d2) # early verify is unaffected + # TODO: change verify() to wait for "confirm" # sending a random confirm message will cause a confirmation error - confkey = w.derive_key(u"WRONG") + confkey = w.derive_key(u"WRONG", SecretBox.KEY_SIZE) nonce = os.urandom(wormhole.CONFMSG_NONCE_LENGTH) badconfirm = wormhole.make_confmsg(confkey, nonce) badconfirm_hex = hexlify(badconfirm).decode("ascii") response(w, type=u"message", phase=u"confirm", body=badconfirm_hex, side=u"s2") - self.failureResultOf(d1, WP) + self.failureResultOf(d1, WrongPasswordError) # once the error is signalled, all API calls should fail - self.assertRaises(WP, w.send, u"foo") - self.assertRaises(WP, w.derive_key, u"foo") - self.failureResultOf(w.get(), WP) - self.failureResultOf(w.get_verifier(), WP) + self.assertRaises(WrongPasswordError, w.send, u"foo") + self.assertRaises(WrongPasswordError, + w.derive_key, u"foo", SecretBox.KEY_SIZE) + self.failureResultOf(w.get(), WrongPasswordError) + self.failureResultOf(w.verify(), WrongPasswordError) # event orderings to exercise: @@ -562,60 +628,88 @@ class Wormholes(ServerBase, unittest.TestCase): yield w1.close(wait=True) yield w2.close(wait=True) -class Off: - @inlineCallbacks def test_same_message(self): # the two sides use random nonces for their messages, so it's ok for # both to try and send the same body: they'll result in distinct # encrypted messages - w1 = Wormhole(APPID, self.relayurl) - w2 = Wormhole(APPID, self.relayurl) + w1 = wormhole.wormhole(APPID, self.relayurl, reactor) + w2 = wormhole.wormhole(APPID, self.relayurl, reactor) code = yield w1.get_code() w2.set_code(code) - yield self.doBoth(w1.send(b"data"), w2.send(b"data")) - dl = yield self.doBoth(w1.get(), w2.get()) - (dataX, dataY) = dl + w1.send(b"data") + w2.send(b"data") + dataX = yield w1.get() + dataY = yield w2.get() self.assertEqual(dataX, b"data") self.assertEqual(dataY, b"data") - yield self.doBoth(w1.close(), w2.close()) + yield w1.close(wait=True) + yield w2.close(wait=True) @inlineCallbacks def test_interleaved(self): - w1 = Wormhole(APPID, self.relayurl) - w2 = Wormhole(APPID, self.relayurl) + w1 = wormhole.wormhole(APPID, self.relayurl, reactor) + w2 = wormhole.wormhole(APPID, self.relayurl, reactor) code = yield w1.get_code() w2.set_code(code) - res = yield self.doBoth(w1.send(b"data1"), w2.get()) - (_, dataY) = res + w1.send(b"data1") + dataY = yield w2.get() self.assertEqual(dataY, b"data1") - dl = yield self.doBoth(w1.get(), w2.send(b"data2")) - (dataX, _) = dl + d = w1.get() + w2.send(b"data2") + dataX = yield d self.assertEqual(dataX, b"data2") - yield self.doBoth(w1.close(), w2.close()) + yield w1.close(wait=True) + yield w2.close(wait=True) + + @inlineCallbacks + def test_unidirectional(self): + w1 = wormhole.wormhole(APPID, self.relayurl, reactor) + w2 = wormhole.wormhole(APPID, self.relayurl, reactor) + code = yield w1.get_code() + w2.set_code(code) + w1.send(b"data1") + dataY = yield w2.get() + self.assertEqual(dataY, b"data1") + yield w1.close(wait=True) + yield w2.close(wait=True) + + @inlineCallbacks + def test_early(self): + w1 = wormhole.wormhole(APPID, self.relayurl, reactor) + w1.send(b"data1") + w2 = wormhole.wormhole(APPID, self.relayurl, reactor) + d = w2.get() + w1.set_code(u"123-abc-def") + w2.set_code(u"123-abc-def") + dataY = yield d + self.assertEqual(dataY, b"data1") + yield w1.close(wait=True) + yield w2.close(wait=True) @inlineCallbacks def test_fixed_code(self): - w1 = Wormhole(APPID, self.relayurl) - w2 = Wormhole(APPID, self.relayurl) + w1 = wormhole.wormhole(APPID, self.relayurl, reactor) + w2 = wormhole.wormhole(APPID, self.relayurl, reactor) w1.set_code(u"123-purple-elephant") w2.set_code(u"123-purple-elephant") - yield self.doBoth(w1.send(b"data1"), w2.send(b"data2")) + w1.send(b"data1"), w2.send(b"data2") dl = yield self.doBoth(w1.get(), w2.get()) (dataX, dataY) = dl self.assertEqual(dataX, b"data2") self.assertEqual(dataY, b"data1") - yield self.doBoth(w1.close(), w2.close()) + yield w1.close(wait=True) + yield w2.close(wait=True) @inlineCallbacks def test_multiple_messages(self): - w1 = Wormhole(APPID, self.relayurl) - w2 = Wormhole(APPID, self.relayurl) + w1 = wormhole.wormhole(APPID, self.relayurl, reactor) + w2 = wormhole.wormhole(APPID, self.relayurl, reactor) w1.set_code(u"123-purple-elephant") w2.set_code(u"123-purple-elephant") - yield self.doBoth(w1.send(b"data1"), w2.send(b"data2")) - yield self.doBoth(w1.send(b"data3"), w2.send(b"data4")) + w1.send(b"data1"), w2.send(b"data2") + w1.send(b"data3"), w2.send(b"data4") dl = yield self.doBoth(w1.get(), w2.get()) (dataX, dataY) = dl self.assertEqual(dataX, b"data2") @@ -624,124 +718,69 @@ class Off: (dataX, dataY) = dl self.assertEqual(dataX, b"data4") self.assertEqual(dataY, b"data3") - yield self.doBoth(w1.close(), w2.close()) - - @inlineCallbacks - def test_multiple_messages_2(self): - w1 = Wormhole(APPID, self.relayurl) - w2 = Wormhole(APPID, self.relayurl) - w1.set_code(u"123-purple-elephant") - w2.set_code(u"123-purple-elephant") - # TODO: set_code should be sufficient to kick things off, but for now - # we must also let both sides do at least one send() or get() - yield self.doBoth(w1.send(b"data1"), w2.send(b"ignored")) - yield w1.get() - yield w1.send(b"data2") - yield w1.send(b"data3") - data = yield w2.get() - self.assertEqual(data, b"data1") - data = yield w2.get() - self.assertEqual(data, b"data2") - data = yield w2.get() - self.assertEqual(data, b"data3") - yield self.doBoth(w1.close(), w2.close()) + yield w1.close(wait=True) + yield w2.close(wait=True) @inlineCallbacks def test_wrong_password(self): - w1 = Wormhole(APPID, self.relayurl) - w2 = Wormhole(APPID, self.relayurl) + w1 = wormhole.wormhole(APPID, self.relayurl, reactor) + w2 = wormhole.wormhole(APPID, self.relayurl, reactor) code = yield w1.get_code() w2.set_code(code+"not") + # That's enough to allow both sides to discover the mismatch, but + # only after the confirmation message gets through. API calls that + # don't wait will appear to work until the mismatched confirmation + # message arrives. + w1.send(b"should still work") + w2.send(b"should still work") - # w2 can't throw WrongPasswordError until it sees a CONFIRM message, - # and w1 won't send CONFIRM until it sees a PAKE message, which w2 - # won't send until we call get. So we need both sides to be - # running at the same time for this test. - d1 = w1.send(b"data1") - # at this point, w1 should be waiting for w2.PAKE - + # API calls that wait (i.e. get) will errback yield self.assertFailure(w2.get(), WrongPasswordError) - # * w2 will send w2.PAKE, wait for (and get) w1.PAKE, compute a key, - # send w2.CONFIRM, then wait for w1.DATA. - # * w1 will get w2.PAKE, compute a key, send w1.CONFIRM. - # * w1 might also get w2.CONFIRM, and may notice the error before it - # sends w1.CONFIRM, in which case the wait=True will signal an - # error inside _get_master_key() (inside send), and d1 will - # errback. - # * but w1 might not see w2.CONFIRM yet, in which case it won't - # errback until we do w1.get() - # * w2 gets w1.CONFIRM, notices the error, records it. - # * w2 (waiting for w1.DATA) wakes up, sees the error, throws - # * meanwhile w1 finishes sending its data. w2.CONFIRM may or may not - # have arrived by then - try: - yield d1 - except WrongPasswordError: - pass - - # When we ask w1 to get(), one of two things might happen: - # * if w2.CONFIRM arrived already, it will have recorded the error. - # When w1.get() sleeps (waiting for w2.DATA), we'll notice the - # error before sleeping, and throw WrongPasswordError - # * if w2.CONFIRM hasn't arrived yet, we'll sleep. When w2.CONFIRM - # arrives, we notice and record the error, and wake up, and throw - - # Note that we didn't do w2.send(), so we're hoping that w1 will - # have enough information to detect the error before it sleeps - # (waiting for w2.DATA). Checking for the error both before sleeping - # and after waking up makes this happen. - - # so now w1 should have enough information to throw too yield self.assertFailure(w1.get(), WrongPasswordError) - # both sides are closed automatically upon error, but it's still - # legal to call .close(), and should be idempotent - yield self.doBoth(w1.close(), w2.close()) - - @inlineCallbacks - def test_no_confirm(self): - # newer versions (which check confirmations) should will work with - # older versions (that don't send confirmations) - w1 = Wormhole(APPID, self.relayurl) - w1._send_confirm = False - w2 = Wormhole(APPID, self.relayurl) - - code = yield w1.get_code() - w2.set_code(code) - dl = yield self.doBoth(w1.send(b"data1"), w2.get()) - self.assertEqual(dl[1], b"data1") - dl = yield self.doBoth(w1.get(), w2.send(b"data2")) - self.assertEqual(dl[0], b"data2") - yield self.doBoth(w1.close(), w2.close()) + yield w1.close(wait=True) + yield w2.close(wait=True) + self.flushLoggedErrors(WrongPasswordError) @inlineCallbacks def test_verifier(self): - w1 = Wormhole(APPID, self.relayurl) - w2 = Wormhole(APPID, self.relayurl) + w1 = wormhole.wormhole(APPID, self.relayurl, reactor) + w2 = wormhole.wormhole(APPID, self.relayurl, reactor) code = yield w1.get_code() w2.set_code(code) - res = yield self.doBoth(w1.get_verifier(), w2.get_verifier()) - v1, v2 = res + v1 = yield w1.verify() + v2 = yield w2.verify() self.failUnlessEqual(type(v1), type(b"")) self.failUnlessEqual(v1, v2) - yield self.doBoth(w1.send(b"data1"), w2.send(b"data2")) - dl = yield self.doBoth(w1.get(), w2.get()) - (dataX, dataY) = dl + w1.send(b"data1") + w2.send(b"data2") + dataX = yield w1.get() + dataY = yield w2.get() self.assertEqual(dataX, b"data2") self.assertEqual(dataY, b"data1") - yield self.doBoth(w1.close(), w2.close()) + yield w1.close(wait=True) + yield w2.close(wait=True) + +class Errors(ServerBase, unittest.TestCase): + @inlineCallbacks + def test_codes_1(self): + w = wormhole.wormhole(APPID, self.relayurl, reactor) + # definitely too early + self.assertRaises(UsageError, w.derive_key, u"purpose", 12) + + w.set_code(u"123-purple-elephant") + # code can only be set once + self.assertRaises(UsageError, w.set_code, u"123-nope") + yield self.assertFailure(w.get_code(), UsageError) + yield self.assertFailure(w.input_code(), UsageError) + yield w.close(wait=True) @inlineCallbacks - def test_errors(self): - w1 = Wormhole(APPID, self.relayurl) - yield self.assertFailure(w1.get_verifier(), UsageError) - yield self.assertFailure(w1.send(b"data"), UsageError) - yield self.assertFailure(w1.get(), UsageError) - w1.set_code(u"123-purple-elephant") - yield self.assertRaises(UsageError, w1.set_code, u"123-nope") - yield self.assertFailure(w1.get_code(), UsageError) - w2 = Wormhole(APPID, self.relayurl) - yield w2.get_code() - yield self.assertFailure(w2.get_code(), UsageError) - yield self.doBoth(w1.close(), w2.close()) + def test_codes_2(self): + w = wormhole.wormhole(APPID, self.relayurl, reactor) + yield w.get_code() + self.assertRaises(UsageError, w.set_code, u"123-nope") + yield self.assertFailure(w.get_code(), UsageError) + yield self.assertFailure(w.input_code(), UsageError) + yield w.close(wait=True) diff --git a/src/wormhole/wormhole.py b/src/wormhole/wormhole.py index 53cefca..134aa0b 100644 --- a/src/wormhole/wormhole.py +++ b/src/wormhole/wormhole.py @@ -14,7 +14,8 @@ from spake2 import SPAKE2_Symmetric from . import __version__ from . import codes #from .errors import ServerError, Timeout -from .errors import WrongPasswordError, UsageError, WelcomeError +from .errors import (WrongPasswordError, UsageError, WelcomeError, + WormholeClosedError) from .timing import DebugTiming from hkdf import Hkdf @@ -205,6 +206,9 @@ class _WelcomeHandler: if "error" in welcome: return self._signal_error(WelcomeError(welcome["error"])) +# states for nameplates, mailboxes, and the websocket connection +(CLOSED, OPENING, OPEN, CLOSING) = ("closed", "opening", "open", "closing") + class _Wormhole: def __init__(self, appid, relay_url, reactor, tor_manager, timing): @@ -217,31 +221,28 @@ class _Wormhole: self._welcomer = _WelcomeHandler(self._ws_url, __version__, self._signal_error) self._side = hexlify(os.urandom(5)).decode("ascii") - self._connected = None + self._connection_state = CLOSED self._connection_waiters = [] self._started_get_code = False self._get_code = None self._code = None self._nameplate_id = None - self._nameplate_claimed = False - self._nameplate_released = False - self._release_waiter = None + self._nameplate_state = CLOSED self._mailbox_id = None - self._mailbox_opened = False - self._mailbox_closed = False - self._close_waiter = None + self._mailbox_state = CLOSED self._flag_need_nameplate = True self._flag_need_to_see_mailbox_used = True self._flag_need_to_build_msg1 = True self._flag_need_to_send_PAKE = True self._key = None - self._closed = False + self._close_called = False # the close() API has been called + self._closing = False # we've started shutdown self._disconnect_waiter = defer.Deferred() - self._mood = u"happy" self._error = None self._get_verifier_called = False - self._verifier_waiter = defer.Deferred() + self._verifier = None + self._verifier_waiter = None self._next_send_phase = 0 # send() queues plaintext here, waiting for a connection and the key @@ -252,33 +253,62 @@ class _Wormhole: self._receive_waiters = {} # phase -> Deferred self._received_messages = {} # phase -> plaintext - def _signal_error(self, error): - # close the mailbox with an "errory" mood, errback all Deferreds, - # record the error, fail all subsequent API calls - if self.DEBUG: print("_signal_error", error) - self._error = error # causes new API calls to fail - for d in self._connection_waiters: - d.errback(error) - if self._get_code: - self._get_code._allocated_d.errback(error) - if not self._verifier_waiter.called: - self._verifier_waiter.errback(error) - for d in self._receive_waiters.values(): - d.errback(error) + # API METHODS for applications to call - self._maybe_close(mood=u"errory") - if self._release_waiter and not self._release_waiter.called: - self._release_waiter.errback(error) - if self._close_waiter and not self._close_waiter.called: - self._close_waiter.errback(error) - # leave self._disconnect_waiter alone - if self.DEBUG: print("_signal_error done") + # You must use at least one of these entry points, to establish the + # wormhole code. Other APIs will stall or be queued until we have one. + + # entry point 1: generate a new code. returns a Deferred + def get_code(self, code_length=2): # XX rename to allocate_code()? create_? + return self._API_get_code(code_length) + + # entry point 2: interactively type in a code, with completion. returns + # Deferred + def input_code(self, prompt="Enter wormhole code: ", code_length=2): + return self._API_input_code(prompt, code_length) + + # entry point 3: paste in a fully-formed code. No return value. + def set_code(self, code): + self._API_set_code(code) + + # todo: restore-saved-state entry points + + def verify(self): + """Returns a Deferred that fires when we've heard back from the other + side, and have confirmed that they used the right wormhole code. When + successful, the Deferred fires with a "verifier" (a bytestring) which + can be compared out-of-band before making additional API calls. If + they used the wrong wormhole code, the Deferred errbacks with + WrongPasswordError. + """ + return self._API_verify() + + def send(self, outbound_data): + return self._API_send(outbound_data) + + def get(self): + return self._API_get() + + def derive_key(self, purpose, length): + """Derive a new key from the established wormhole channel for some + other purpose. This is a deterministic randomized function of the + session key and the 'purpose' string (unicode/py3-string). This + cannot be called until verify() or get() has fired. + """ + return self._API_derive_key(purpose, length) + + def close(self, wait=False): + return self._API_close(wait) + + # INTERNAL METHODS beyond here def _start(self): d = self._connect() # causes stuff to happen d.addErrback(log.err) return d # fires when connection is established, if you care + + def _make_endpoint(self, hostname, port): if self._tor_manager: return self._tor_manager.get_endpoint_for(hostname, port) @@ -289,6 +319,7 @@ class _Wormhole: # TODO: if we lose the connection, make a new one, re-establish the # state assert self._side + self._connection_state = OPENING p = urlparse(self._ws_url) f = WSFactory(self._ws_url) f.wormhole = self @@ -311,16 +342,18 @@ class _Wormhole: self._ws_t = self._timing.add("websocket") def _event_ws_opened(self, _): - self._connected = True + self._connection_state = OPEN + if self._closing: + return self._maybe_finished_closing() self._ws_send_command(u"bind", appid=self._appid, side=self._side) - self._maybe_get_mailbox() + self._maybe_claim_nameplate() self._maybe_send_pake() waiters, self._connection_waiters = self._connection_waiters, [] for d in waiters: d.callback(None) def _when_connected(self): - if self._connected: + if self._connection_state == OPEN: return defer.succeed(None) d = defer.Deferred() self._connection_waiters.append(d) @@ -331,6 +364,7 @@ class _Wormhole: # their receives, and vice versa. They are also correlated with the # ACKs we get back from the server (which we otherwise ignore). There # are so few messages, 16 bits is enough to be mostly-unique. + if self.DEBUG: print("SEND", mtype) kwargs["id"] = hexlify(os.urandom(2)).decode("ascii") kwargs["type"] = mtype payload = json.dumps(kwargs).encode("utf-8") @@ -358,7 +392,7 @@ class _Wormhole: # entry point 1: generate a new code @inlineCallbacks - def get_code(self, code_length=2): # XX rename to allocate_code()? create_? + def _API_get_code(self, code_length): if self._code is not None: raise UsageError if self._started_get_code: raise UsageError self._started_get_code = True @@ -370,13 +404,13 @@ class _Wormhole: # TODO: signal_error code = yield gc.go() self._get_code = None - self._nameplate_claimed = True # side-effect of allocation + self._nameplate_state = OPEN self._event_learned_code(code) returnValue(code) # entry point 2: interactively type in a code, with completion @inlineCallbacks - def input_code(self, prompt="Enter wormhole code: ", code_length=2): + def _API_input_code(self, prompt, code_length): if self._code is not None: raise UsageError if self._started_input_code: raise UsageError self._started_input_code = True @@ -390,7 +424,7 @@ class _Wormhole: returnValue(None) # entry point 3: paste in a fully-formed code - def set_code(self, code): + def _API_set_code(self, code): self._timing.add("API set_code") if not isinstance(code, type(u"")): raise TypeError(type(code)) if self._code is not None: raise UsageError @@ -437,13 +471,13 @@ class _Wormhole: # for each such condition Y, every _event_Y must call _maybe_X def _event_learned_nameplate(self): - self._maybe_get_mailbox() + self._maybe_claim_nameplate() - def _maybe_get_mailbox(self): - if not (self._nameplate_id and self._connected): + def _maybe_claim_nameplate(self): + if not (self._nameplate_id and self._connection_state == OPEN): return self._ws_send_command(u"claim", nameplate=self._nameplate_id) - self._nameplate_claimed = True + self._nameplate_state = OPEN def _response_handle_claimed(self, msg): mailbox_id = msg["mailbox"] @@ -453,16 +487,19 @@ class _Wormhole: def _event_learned_mailbox(self): if not self._mailbox_id: raise UsageError - if self._mailbox_opened: raise UsageError + assert self._mailbox_state == CLOSED, self._mailbox_state + if self._closing: + return self._ws_send_command(u"open", mailbox=self._mailbox_id) - self._mailbox_opened = True + self._mailbox_state = OPEN # causes old messages to be sent now, and subscribes to new messages self._maybe_send_pake() self._maybe_send_phase_messages() def _maybe_send_pake(self): # TODO: deal with reentrant call - if not (self._connected and self._mailbox_opened + if not (self._connection_state == OPEN + and self._mailbox_state == OPEN and self._flag_need_to_send_PAKE): return self._msg_send(u"pake", self._msg1) @@ -477,44 +514,52 @@ class _Wormhole: self._timing.add("key established") # both sides send different (random) confirmation messages - confkey = self.derive_key(u"wormhole:confirmation") + confkey = self._derive_key(u"wormhole:confirmation") nonce = os.urandom(CONFMSG_NONCE_LENGTH) confmsg = make_confmsg(confkey, nonce) self._msg_send(u"confirm", confmsg) - verifier = self.derive_key(u"wormhole:verifier") + verifier = self._derive_key(u"wormhole:verifier") self._event_computed_verifier(verifier) self._maybe_send_phase_messages() - def get_verifier(self): + def _API_verify(self): + # TODO: rename "verify()", make it stall until confirm received. If + # you want to discover WrongPasswordError before doing send(), call + # verify() first. If you also want to deny a successful MitM (and + # have some other way to check a long verifier), use the return value + # of verify(). if self._error: return defer.fail(self._error) - if self._closed: raise UsageError if self._get_verifier_called: raise UsageError self._get_verifier_called = True + if self._verifier: + return defer.succeed(self._verifier) # TODO: maybe have this wait on _event_received_confirm too + self._verifier_waiter = defer.Deferred() return self._verifier_waiter def _event_computed_verifier(self, verifier): - self._verifier_waiter.callback(verifier) + self._verifier = verifier + if self._verifier_waiter: + self._verifier_waiter.callback(verifier) def _event_received_confirm(self, body): # TODO: we might not have a master key yet, if the caller wasn't # waiting in _get_master_key() when a back-to-back pake+_confirm # message pair arrived. - confkey = self.derive_key(u"wormhole:confirmation") + confkey = self._derive_key(u"wormhole:confirmation") nonce = body[:CONFMSG_NONCE_LENGTH] if body != make_confmsg(confkey, nonce): # this makes all API calls fail if self.DEBUG: print("CONFIRM FAILED") - return self._signal_error(WrongPasswordError()) + return self._signal_error(WrongPasswordError(), u"scary") - def send(self, outbound_data): + def _API_send(self, outbound_data): if self._error: raise self._error if not isinstance(outbound_data, type(b"")): raise TypeError(type(outbound_data)) - if self._closed: raise UsageError phase = self._next_send_phase self._next_send_phase += 1 self._plaintext_to_send.append( (phase, outbound_data) ) @@ -523,14 +568,16 @@ class _Wormhole: def _maybe_send_phase_messages(self): # TODO: deal with reentrant call - if not (self._connected and self._mailbox_opened and self._key): + if not (self._connection_state == OPEN + and self._mailbox_state == OPEN + and self._key): return plaintexts = self._plaintext_to_send self._plaintext_to_send = [] for pm in plaintexts: (phase, plaintext) = pm assert isinstance(phase, int), type(phase) - data_key = self.derive_key(u"wormhole:phase:%d" % phase) + data_key = self._derive_key(u"wormhole:phase:%d" % phase) encrypted = self._encrypt_data(data_key, plaintext) self._msg_send(u"%d" % phase, encrypted) @@ -550,8 +597,7 @@ class _Wormhole: def _msg_send(self, phase, body): if phase in self._sent_phases: raise UsageError - if not self._mailbox_opened: raise UsageError - if self._mailbox_closed: raise UsageError + assert self._mailbox_state == OPEN, self._mailbox_state self._sent_phases.add(phase) # TODO: retry on failure, with exponential backoff. We're guarding # against the rendezvous server being temporarily offline. @@ -566,8 +612,11 @@ class _Wormhole: self._maybe_release_nameplate() self._flag_need_to_see_mailbox_used = False - def derive_key(self, purpose, length=SecretBox.KEY_SIZE): + def _API_derive_key(self, purpose, length): if self._error: raise self._error + return self._derive_key(purpose, length) + + def _derive_key(self, purpose, length=SecretBox.KEY_SIZE): if not isinstance(purpose, type(u"")): raise TypeError(type(purpose)) if self._key is None: raise UsageError # call derive_key after get_verifier() or get() @@ -597,12 +646,19 @@ class _Wormhole: self._event_received_confirm(body) return - # now notify anyone waiting on it + # It's a phase message, aimed at the application above us. Decrypt + # and deliver upstairs, notifying anyone waiting on it try: - data_key = self.derive_key(u"wormhole:phase:%s" % phase) + data_key = self._derive_key(u"wormhole:phase:%s" % phase) plaintext = self._decrypt_data(data_key, body) except CryptoError: - raise WrongPasswordError # TODO: signal + e = WrongPasswordError() + self._signal_error(e, u"scary") # flunk all other API calls + # make tests fail, if they aren't explicitly catching it + if self.DEBUG: print("CryptoError in msg received") + log.err(e) + if self.DEBUG: print(" did log.err", e) + return # ignore this message self._received_messages[phase] = plaintext if phase in self._receive_waiters: d = self._receive_waiters.pop(phase) @@ -616,9 +672,8 @@ class _Wormhole: data = box.decrypt(encrypted) return data - def get(self): + def _API_get(self): if self._error: return defer.fail(self._error) - if self._closed: raise UsageError phase = u"%d" % self._next_receive_phase self._next_receive_phase += 1 with self._timing.add("API get", phase=phase): @@ -627,64 +682,117 @@ class _Wormhole: d = self._receive_waiters[phase] = defer.Deferred() return d - def _maybe_close(self, mood): - if self._closed: + def _signal_error(self, error, mood): + if self.DEBUG: print("_signal_error", error, mood) + if self._error: return - self.close(mood) + self._maybe_close(error, mood) + if self.DEBUG: print("_signal_error done") @inlineCallbacks - def close(self, mood=None, wait=False): - # TODO: auto-close on error, mostly for load-from-state + def _API_close(self, wait=False, mood=u"happy"): if self.DEBUG: print("close", wait) - if self._closed: raise UsageError - self._closed = True - if mood: - self._mood = mood - self._maybe_release_nameplate() - self._maybe_close_mailbox() - if wait: - if self._nameplate_claimed: - if self.DEBUG: print("waiting for released") - self._release_waiter = defer.Deferred() - yield self._release_waiter - if self._mailbox_opened: - if self.DEBUG: print("waiting for closed") - self._close_waiter = defer.Deferred() - yield self._close_waiter - if self.DEBUG: print("dropping connection") - self._drop_connection() + if self._close_called: raise UsageError + self._close_called = True + self._maybe_close(WormholeClosedError(), mood) if wait: if self.DEBUG: print("waiting for disconnect") yield self._disconnect_waiter + def _maybe_close(self, error, mood): + if self._closing: + return + + # ordering constraints: + # * must wait for nameplate/mailbox acks before closing the websocket + # * must mark APIs for failure before errbacking Deferreds + # * since we give up control + # * must mark self._closing before errbacking Deferreds + # * since caller may call close() when we give up control + # * and close() will reenter _maybe_close + + self._error = error # causes new API calls to fail + + # since we're about to give up control by errbacking any API + # Deferreds, set self._closing, to make sure that a new call to + # close() isn't going to confuse anything + self._closing = True + + # now errback all API deferreds except close(): get_code, + # input_code, verify, get + for d in self._connection_waiters: # input_code, get_code (early) + if self.DEBUG: print("EB cw") + d.errback(error) + if self._get_code: # get_code (late) + if self.DEBUG: print("EB gc") + self._get_code._allocated_d.errback(error) + if self._verifier_waiter and not self._verifier_waiter.called: + if self.DEBUG: print("EB VW") + self._verifier_waiter.errback(error) + for d in self._receive_waiters.values(): + if self.DEBUG: print("EB RW") + d.errback(error) + # Release nameplate and close mailbox, if either was claimed/open. + # Since _closing is True when both ACKs come back, the handlers will + # close the websocket. When *that* finishes, _disconnect_waiter() + # will fire. + self._maybe_release_nameplate() + self._maybe_close_mailbox(mood) + # In the off chance we got closed before we even claimed the + # nameplate, give _maybe_finished_closing a chance to run now. + self._maybe_finished_closing() + def _maybe_release_nameplate(self): - if self.DEBUG: print("_maybe_release_nameplate", self._nameplate_claimed, self._nameplate_released) - if self._nameplate_claimed and not self._nameplate_released: + if self.DEBUG: print("_maybe_release_nameplate", self._nameplate_state) + if self._nameplate_state == OPEN: if self.DEBUG: print(" sending release") self._ws_send_command(u"release") - self._nameplate_released = True + self._nameplate_state = CLOSING def _response_handle_released(self, msg): - if self._release_waiter and not self._release_waiter.called: - self._release_waiter.callback(None) + self._nameplate_state = CLOSED + self._maybe_finished_closing() - def _maybe_close_mailbox(self): - if self.DEBUG: print("_maybe_close_mailbox", self._mailbox_opened, self._mailbox_closed) - if self._mailbox_opened and not self._mailbox_closed: + def _maybe_close_mailbox(self, mood): + if self.DEBUG: print("_maybe_close_mailbox", self._mailbox_state) + if self._mailbox_state == OPEN: if self.DEBUG: print(" sending close") - self._ws_send_command(u"close", mood=self._mood) - self._mailbox_closed = True + self._ws_send_command(u"close", mood=mood) + self._mailbox_state = CLOSING def _response_handle_closed(self, msg): - if self._close_waiter and not self._close_waiter.called: - self._close_waiter.callback(None) + self._mailbox_state = CLOSED + self._maybe_finished_closing() + + def _maybe_finished_closing(self): + if self.DEBUG: print("_maybe_finished_closing", self._closing, self._nameplate_state, self._mailbox_state, self._connection_state) + if not self._closing: + return + if (self._nameplate_state == CLOSED + and self._mailbox_state == CLOSED + and self._connection_state == OPEN): + self._connection_state = CLOSING + self._drop_connection() def _drop_connection(self): - self._ws.transport.loseConnection() # probably flushes + # separate method so it can be overridden by tests + self._ws.transport.loseConnection() # probably flushes output # calls _ws_closed() when done def _ws_closed(self, wasClean, code, reason): + # For now (until we add reconnection), losing the websocket means + # losing everything. Make all API callers fail. Help someone waiting + # in close() to finish + self._connection_state = CLOSED self._disconnect_waiter.callback(None) + self._maybe_finished_closing() + + # what needs to happen when _ws_closed() happens unexpectedly + # * errback all API deferreds + # * maybe: cause new API calls to fail + # * obviously can't release nameplate or close mailbox + # * can't re-close websocket + # * close(wait=True) callers should fire right away def wormhole(appid, relay_url, reactor, tor_manager=None, timing=None): timing = timing or DebugTiming()