From c10fd9816740fa5bd5f5dfda3c8a44956b09abb8 Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Sun, 22 May 2016 18:40:44 -0700 Subject: [PATCH] many tests working * add "released" ack-response for "release" command, to sync w.close() * move websocket URL to root * relayurl= should now be a "ws://" URL * many tests pass (except for test_twisted, which will be removed, and test_scripts) * still moving integration tests from test_twisted to test_wormhole.Wormholes --- src/wormhole/server/rendezvous_websocket.py | 2 + src/wormhole/server/server.py | 6 +- src/wormhole/test/common.py | 3 +- src/wormhole/test/test_server.py | 29 +- src/wormhole/test/test_wormhole.py | 578 +++++++++++++++++++- src/wormhole/wormhole.py | 237 +++++--- 6 files changed, 717 insertions(+), 138 deletions(-) diff --git a/src/wormhole/server/rendezvous_websocket.py b/src/wormhole/server/rendezvous_websocket.py index 153b037..05e2700 100644 --- a/src/wormhole/server/rendezvous_websocket.py +++ b/src/wormhole/server/rendezvous_websocket.py @@ -60,6 +60,7 @@ from .rendezvous import CrowdedError, SidedMessage # -> {type: "claim", nameplate: str} -> mailbox # <- {type: "claimed", mailbox: str} # -> {type: "release"} +# <- {type: "released"} # # -> {type: "open", mailbox: str} -> message # sends old messages now, and subscribes to deliver future messages @@ -183,6 +184,7 @@ class WebSocketRendezvous(websocket.WebSocketServerProtocol): raise Error("must claim a nameplate before releasing it") self._app.release_nameplate(self._nameplate_id, self._side, server_rx) self._nameplate_id = None + self.send("released") def handle_open(self, msg, server_rx): diff --git a/src/wormhole/server/server.py b/src/wormhole/server/server.py index 694ed19..b847bc7 100644 --- a/src/wormhole/server/server.py +++ b/src/wormhole/server/server.py @@ -45,12 +45,8 @@ class RelayServer(service.MultiService): rendezvous = Rendezvous(db, welcome, blur_usage) rendezvous.setServiceParent(self) # for the pruning timer - root = Root() - wr = resource.Resource() - root.putChild(b"wormhole-relay", wr) - wsrf = WebSocketRendezvousFactory(None, rendezvous) - wr.putChild(b"ws", WebSocketResource(wsrf)) + root = WebSocketResource(wsrf) site = PrivacyEnhancedSite(root) if blur_usage: diff --git a/src/wormhole/test/common.py b/src/wormhole/test/common.py index 382cbfe..48c0685 100644 --- a/src/wormhole/test/common.py +++ b/src/wormhole/test/common.py @@ -17,8 +17,7 @@ class ServerBase: s.setServiceParent(self.sp) self._rendezvous = s._rendezvous self._transit_server = s._transit - self.relayurl = u"http://127.0.0.1:%d/wormhole-relay/" % relayport - self.rdv_ws_url = self.relayurl.replace("http:", "ws:") + "ws" + self.relayurl = u"ws://127.0.0.1:%d/" % relayport self.rdv_ws_port = relayport # ws://127.0.0.1:%d/wormhole-relay/ws self.transit = u"tcp:127.0.0.1:%d" % transitport diff --git a/src/wormhole/test/test_server.py b/src/wormhole/test/test_server.py index 90c03eb..aeaa494 100644 --- a/src/wormhole/test/test_server.py +++ b/src/wormhole/test/test_server.py @@ -12,30 +12,6 @@ from .common import ServerBase from ..server import rendezvous, transit_server from ..server.rendezvous import Usage, SidedMessage -class Reachable(ServerBase, unittest.TestCase): - - def test_getPage(self): - # client.getPage requires bytes URL, returns bytes - url = self.relayurl.replace("wormhole-relay/", "").encode("ascii") - d = getPage(url) - def _got(res): - self.failUnlessEqual(res, b"Wormhole Relay\n") - d.addCallback(_got) - return d - - def test_agent(self): - url = self.relayurl.replace("wormhole-relay/", "").encode("ascii") - agent = Agent(reactor) - d = agent.request(b"GET", url) - def _check(resp): - self.failUnlessEqual(resp.code, 200) - return readBody(resp) - d.addCallback(_check) - def _got(res): - self.failUnlessEqual(res, b"Wormhole Relay\n") - d.addCallback(_got) - return d - class Server(ServerBase, unittest.TestCase): def test_apps(self): app1 = self._rendezvous.get_app(u"appid1") @@ -459,7 +435,7 @@ class WebSocketAPI(ServerBase, unittest.TestCase): @inlineCallbacks def make_client(self): - f = WSFactory(self.rdv_ws_url) + f = WSFactory(self.relayurl) f.d = defer.Deferred() reactor.connectTCP("127.0.0.1", self.rdv_ws_port, f) c = yield f.d @@ -644,7 +620,8 @@ class WebSocketAPI(ServerBase, unittest.TestCase): yield c1.next_non_ack() c1.send(u"release") - yield c1.sync() + m = yield c1.next_non_ack() + self.assertEqual(m[u"type"], u"released") row = app._db.execute("SELECT * FROM `nameplates`" " WHERE `app_id`='appid' AND `id`='np1'").fetchone() diff --git a/src/wormhole/test/test_wormhole.py b/src/wormhole/test/test_wormhole.py index 1591530..426bc0f 100644 --- a/src/wormhole/test/test_wormhole.py +++ b/src/wormhole/test/test_wormhole.py @@ -1,13 +1,13 @@ from __future__ import print_function -import json +import os, json, re +from binascii import hexlify, unhexlify import mock from twisted.trial import unittest from twisted.internet import reactor -from twisted.internet.defer import gatherResults, inlineCallbacks -#from ..twisted.transcribe import (wormhole, wormhole_from_serialized, -# UsageError, WrongPasswordError) -#from .common import ServerBase -from ..wormhole import _Wormhole, _WelcomeHandler +from twisted.internet.defer import Deferred, gatherResults, inlineCallbacks +from .common import ServerBase +from .. import wormhole +from spake2 import SPAKE2_Symmetric from ..timing import DebugTiming APPID = u"appid" @@ -31,29 +31,131 @@ def response(w, **kwargs): w._ws_dispatch_response(payload) class Welcome(unittest.TestCase): - def test_no_current_version(self): - # WelcomeHandler should tolerate lack of ["current_version"] - w = _WelcomeHandler(u"relay_url", u"current_version") + def test_tolerate_no_current_version(self): + w = wormhole._WelcomeHandler(u"relay_url", u"current_version", None) w.handle_welcome({}) + def test_print_motd(self): + w = wormhole._WelcomeHandler(u"relay_url", u"current_version", None) + with mock.patch("sys.stderr") as stderr: + w.handle_welcome({u"motd": u"message of\nthe day"}) + self.assertEqual(stderr.method_calls, + [mock.call.write(u"Server (at relay_url) says:\n" + " message of\n the day"), + mock.call.write(u"\n")]) + # motd is only displayed once + with mock.patch("sys.stderr") as stderr2: + w.handle_welcome({u"motd": u"second message"}) + self.assertEqual(stderr2.method_calls, []) + + def test_current_version(self): + w = wormhole._WelcomeHandler(u"relay_url", u"2.0", None) + with mock.patch("sys.stderr") as stderr: + w.handle_welcome({u"current_version": u"2.0"}) + self.assertEqual(stderr.method_calls, []) + + with mock.patch("sys.stderr") as stderr: + w.handle_welcome({u"current_version": u"3.0"}) + exp1 = (u"Warning: errors may occur unless both sides are" + " running the same version") + exp2 = (u"Server claims 3.0 is current, but ours is 2.0") + self.assertEqual(stderr.method_calls, + [mock.call.write(exp1), + mock.call.write(u"\n"), + mock.call.write(exp2), + mock.call.write(u"\n"), + ]) + + # warning is only displayed once + with mock.patch("sys.stderr") as stderr: + w.handle_welcome({u"current_version": u"3.0"}) + self.assertEqual(stderr.method_calls, []) + + def test_non_release_version(self): + w = wormhole._WelcomeHandler(u"relay_url", u"2.0-dirty", None) + with mock.patch("sys.stderr") as stderr: + w.handle_welcome({u"current_version": u"3.0"}) + self.assertEqual(stderr.method_calls, []) + + def test_signal_error(self): + se = mock.Mock() + w = wormhole._WelcomeHandler(u"relay_url", u"2.0", se) + w.handle_welcome({}) + self.assertEqual(se.mock_calls, []) + + w.handle_welcome({u"error": u"oops"}) + self.assertEqual(se.mock_calls, [mock.call(u"oops")]) + +class InputCode(unittest.TestCase): + def test_list(self): + send_command = mock.Mock() + ic = wormhole._InputCode(None, u"prompt", 2, send_command, + DebugTiming()) + d = ic._list() + self.assertNoResult(d) + self.assertEqual(send_command.mock_calls, [mock.call(u"list")]) + ic._response_handle_nameplates({u"type": u"nameplates", + u"nameplates": [{u"id": u"123"}]}) + res = self.successResultOf(d) + self.assertEqual(res, [u"123"]) + +class GetCode(unittest.TestCase): + def test_get(self): + send_command = mock.Mock() + gc = wormhole._GetCode(2, send_command, DebugTiming()) + d = gc.go() + self.assertNoResult(d) + self.assertEqual(send_command.mock_calls, [mock.call(u"allocate")]) + # TODO: nameplate attributes get added and checked here + gc._response_handle_allocated({u"type": u"allocated", + u"nameplate": u"123"}) + code = self.successResultOf(d) + self.assertIsInstance(code, type(u"")) + self.assert_(code.startswith(u"123-")) + pieces = code.split(u"-") + self.assertEqual(len(pieces), 3) # nameplate plus two words + self.assert_(re.search(r'^\d+-\w+-\w+$', code), code) + class Basic(unittest.TestCase): def test_create(self): - w = _Wormhole(APPID, u"relay_url", reactor, None, None) + wormhole._Wormhole(APPID, u"relay_url", reactor, None, None) + + def check_out(self, out, **kwargs): + # Assert that each kwarg is present in the 'out' dict. Ignore other + # keys ('msgid' in particular) + for key, value in kwargs.items(): + self.assertIn(key, out) + self.assertEqual(out[key], value, (out, key, value)) + + def check_outbound(self, ws, types): + out = ws.outbound() + self.assertEqual(len(out), len(types), (out, types)) + for i,t in enumerate(types): + self.assertEqual(out[i][u"type"], t, (i,t,out)) + return out def test_basic(self): # We don't call w._start(), so this doesn't create a WebSocket - # connection. We provide a mock connection instead. - timing = DebugTiming() - with mock.patch("wormhole.wormhole._WelcomeHandler") as whc: - w = _Wormhole(APPID, u"relay_url", reactor, None, timing) - wh = whc.return_value - #w._welcomer = mock.Mock() + # connection. We provide a mock connection instead. If we wanted to + # exercise _connect, we'd mock out WSFactory. # w._connect = lambda self: None # w._event_connected(mock_ws) # w._event_ws_opened() # w._ws_dispatch_response(payload) + + timing = DebugTiming() + with mock.patch("wormhole.wormhole._WelcomeHandler") as wh_c: + w = wormhole._Wormhole(APPID, u"relay_url", reactor, None, timing) + wh = wh_c.return_value self.assertEqual(w._ws_url, u"relay_url") + self.assertTrue(w._flag_need_nameplate) + self.assertTrue(w._flag_need_to_build_msg1) + self.assertTrue(w._flag_need_to_send_PAKE) + + v = w.get_verifier() + + w._drop_connection = mock.Mock() ws = MockWebSocket() w._event_connected(ws) out = ws.outbound() @@ -62,13 +164,449 @@ class Basic(unittest.TestCase): w._event_ws_opened(None) out = ws.outbound() self.assertEqual(len(out), 1) - self.assertEqual(out[0]["type"], u"bind") - self.assertEqual(out[0]["appid"], APPID) - self.assertEqual(out[0]["side"], w._side) + self.check_out(out[0], type=u"bind", appid=APPID, side=w._side) self.assertIn(u"id", out[0]) - # WelcomeHandler should get called upon 'welcome' response + # WelcomeHandler should get called upon 'welcome' response. Its full + # behavior is exercised in 'Welcome' above. WELCOME = {u"foo": u"bar"} response(w, type="welcome", welcome=WELCOME) self.assertEqual(wh.mock_calls, [mock.call.handle_welcome(WELCOME)]) + # because we're connected, setting the code also claims the mailbox + CODE = u"123-foo-bar" + w.set_code(CODE) + self.assertFalse(w._flag_need_to_build_msg1) + out = ws.outbound() + self.assertEqual(len(out), 1) + self.check_out(out[0], type=u"claim", nameplate=u"123") + + # the server reveals the linked mailbox + response(w, type=u"claimed", mailbox=u"mb456") + + # that triggers event_learned_mailbox, which should send open() and + # PAKE + self.assertTrue(w._mailbox_opened) + out = ws.outbound() + self.assertEqual(len(out), 2) + self.check_out(out[0], type=u"open", mailbox=u"mb456") + self.check_out(out[1], type=u"add", phase=u"pake") + self.assertNoResult(v) + + # server echoes back all "add" messages + response(w, type=u"message", phase=u"pake", body=out[1][u"body"], + side=w._side) + self.assertNoResult(v) + + # next we build the simulated peer's PAKE operation + side2 = w._side + u"other" + msg1 = unhexlify(out[1][u"body"].encode("ascii")) + sp2 = SPAKE2_Symmetric(wormhole.to_bytes(CODE), + idSymmetric=wormhole.to_bytes(APPID)) + msg2 = sp2.start() + msg2_hex = hexlify(msg2).decode("ascii") + key = sp2.finish(msg1) + response(w, type=u"message", phase=u"pake", body=msg2_hex, side=side2) + + # hearing the peer's PAKE (msg2) makes us release the nameplate, send + # the confirmation message, delivered the verifier, and sends any + # queued phase messages + self.assertFalse(w._flag_need_to_see_mailbox_used) + self.assertEqual(w._key, key) + out = ws.outbound() + self.assertEqual(len(out), 2, out) + self.check_out(out[0], type=u"release") + self.check_out(out[1], type=u"add", phase=u"confirm") + verifier = self.successResultOf(v) + self.assertEqual(verifier, w.derive_key(u"wormhole:verifier")) + + # hearing a valid confirmation message doesn't throw an error + confkey = w.derive_key(u"wormhole:confirmation") + nonce = os.urandom(wormhole.CONFMSG_NONCE_LENGTH) + confirm2 = wormhole.make_confmsg(confkey, nonce) + confirm2_hex = hexlify(confirm2).decode("ascii") + response(w, type=u"message", phase=u"confirm", body=confirm2_hex, + side=side2) + + # an outbound message can now be sent immediately + w.send(b"phase0-outbound") + out = ws.outbound() + self.assertEqual(len(out), 1) + self.check_out(out[0], type=u"add", phase=u"0") + # decrypt+check the outbound message + p0_outbound = unhexlify(out[0][u"body"].encode("ascii")) + msgkey0 = w.derive_key(u"wormhole:phase:0") + p0_plaintext = w._decrypt_data(msgkey0, p0_outbound) + self.assertEqual(p0_plaintext, b"phase0-outbound") + + # get() waits for the inbound message to arrive + md = w.get() + self.assertNoResult(md) + self.assertIn(u"0", w._receive_waiters) + self.assertNotIn(u"0", w._received_messages) + p0_inbound = w._encrypt_data(msgkey0, b"phase0-inbound") + p0_inbound_hex = hexlify(p0_inbound).decode("ascii") + response(w, type=u"message", phase=u"0", body=p0_inbound_hex, + side=side2) + p0_in = self.successResultOf(md) + self.assertEqual(p0_in, b"phase0-inbound") + self.assertNotIn(u"0", w._receive_waiters) + self.assertIn(u"0", w._received_messages) + + # receiving an inbound message will queue it until get() is called + msgkey1 = w.derive_key(u"wormhole:phase:1") + p1_inbound = w._encrypt_data(msgkey1, b"phase1-inbound") + p1_inbound_hex = hexlify(p1_inbound).decode("ascii") + response(w, type=u"message", phase=u"1", body=p1_inbound_hex, + side=side2) + self.assertIn(u"1", w._received_messages) + self.assertNotIn(u"1", w._receive_waiters) + p1_in = self.successResultOf(w.get()) + self.assertEqual(p1_in, b"phase1-inbound") + self.assertIn(u"1", w._received_messages) + self.assertNotIn(u"1", w._receive_waiters) + + w.close() + out = ws.outbound() + self.assertEqual(len(out), 1) + self.check_out(out[0], type=u"close", mood=u"happy") + self.assertEqual(w._drop_connection.mock_calls, [mock.call()]) + + def test_close_wait_1(self): + # close after claiming the nameplate, but before opening the mailbox + timing = DebugTiming() + w = wormhole._Wormhole(APPID, u"relay_url", reactor, None, timing) + w._drop_connection = mock.Mock() + ws = MockWebSocket() + w._event_connected(ws) + w._event_ws_opened(None) + CODE = u"123-foo-bar" + w.set_code(CODE) + self.check_outbound(ws, [u"bind", u"claim"]) + + d = w.close(wait=True) + self.check_outbound(ws, [u"release"]) + self.assertNoResult(d) + self.assertEqual(w._drop_connection.mock_calls, []) + + response(w, type=u"released") + self.successResultOf(d) + self.assertEqual(w._drop_connection.mock_calls, [mock.call()]) + + def test_close_wait_2(self): + # close after both claiming the nameplate and opening the mailbox + timing = DebugTiming() + w = wormhole._Wormhole(APPID, u"relay_url", reactor, None, timing) + w._drop_connection = mock.Mock() + ws = MockWebSocket() + w._event_connected(ws) + w._event_ws_opened(None) + CODE = u"123-foo-bar" + w.set_code(CODE) + response(w, type=u"claimed", mailbox=u"mb456") + self.check_outbound(ws, [u"bind", u"claim", u"open", u"add"]) + + d = w.close(wait=True) + self.check_outbound(ws, [u"release", u"close"]) + self.assertNoResult(d) + self.assertEqual(w._drop_connection.mock_calls, []) + + response(w, type=u"released") + self.assertNoResult(d) + self.assertEqual(w._drop_connection.mock_calls, []) + response(w, type=u"closed") + self.successResultOf(d) + self.assertEqual(w._drop_connection.mock_calls, [mock.call()]) + + def test_close_wait_3(self): + # close after claiming the nameplate, opening the mailbox, then + # releasing the nameplate + timing = DebugTiming() + w = wormhole._Wormhole(APPID, u"relay_url", reactor, None, timing) + w._drop_connection = mock.Mock() + ws = MockWebSocket() + w._event_connected(ws) + w._event_ws_opened(None) + CODE = u"123-foo-bar" + w.set_code(CODE) + response(w, type=u"claimed", mailbox=u"mb456") + + w._key = b"" + msgkey = w.derive_key(u"wormhole:phase:misc") + p1_inbound = w._encrypt_data(msgkey, b"") + p1_inbound_hex = hexlify(p1_inbound).decode("ascii") + response(w, type=u"message", phase=u"misc", side=u"side2", + body=p1_inbound_hex) + self.check_outbound(ws, [u"bind", u"claim", u"open", u"add", + u"release"]) + + d = w.close(wait=True) + self.check_outbound(ws, [u"close"]) + self.assertNoResult(d) + self.assertEqual(w._drop_connection.mock_calls, []) + + response(w, type=u"released") + self.assertNoResult(d) + self.assertEqual(w._drop_connection.mock_calls, []) + response(w, type=u"closed") + self.successResultOf(d) + self.assertEqual(w._drop_connection.mock_calls, [mock.call()]) + + def test_get_code_mock(self): + timing = DebugTiming() + w = wormhole._Wormhole(APPID, u"relay_url", reactor, None, timing) + ws = MockWebSocket() # TODO: mock w._ws_send_command instead + w._event_connected(ws) + w._event_ws_opened(None) + self.check_outbound(ws, [u"bind"]) + + gc_c = mock.Mock() + gc = gc_c.return_value = mock.Mock() + gc_d = gc.go.return_value = Deferred() + with mock.patch("wormhole.wormhole._GetCode", gc_c): + d = w.get_code() + self.assertNoResult(d) + + gc_d.callback(u"123-foo-bar") + code = self.successResultOf(d) + self.assertEqual(code, u"123-foo-bar") + + def test_get_code_real(self): + timing = DebugTiming() + w = wormhole._Wormhole(APPID, u"relay_url", reactor, None, timing) + ws = MockWebSocket() + w._event_connected(ws) + w._event_ws_opened(None) + self.check_outbound(ws, [u"bind"]) + + d = w.get_code() + + out = ws.outbound() + self.assertEqual(len(out), 1) + self.check_out(out[0], type=u"allocate") + # TODO: nameplate attributes go here + self.assertNoResult(d) + + response(w, type=u"allocated", nameplate=u"123") + code = self.successResultOf(d) + self.assertIsInstance(code, type(u"")) + self.assert_(code.startswith(u"123-")) + pieces = code.split(u"-") + self.assertEqual(len(pieces), 3) # nameplate plus two words + self.assert_(re.search(r'^\d+-\w+-\w+$', code), code) + +# event orderings to exercise: +# +# * normal sender: set_code, send_phase1, connected, claimed, learn_msg2, +# learn_phase1 +# * normal receiver (argv[2]=code): set_code, connected, learn_msg1, +# learn_phase1, send_phase1, +# * normal receiver (readline): connected, input_code +# * +# * set_code, then connected +# * connected, receive_pake, send_phase, set_code + +class Wormholes(ServerBase, unittest.TestCase): + # integration test, with a real server + + def doBoth(self, d1, d2): + return gatherResults([d1, d2], True) + + @inlineCallbacks + def test_basic(self): + w1 = wormhole.wormhole(APPID, self.relayurl, reactor) + w2 = wormhole.wormhole(APPID, self.relayurl, reactor) + code = yield w1.get_code() + w2.set_code(code) + w1.send(b"data1") + w2.send(b"data2") + dataX = yield w1.get() + dataY = yield w2.get() + self.assertEqual(dataX, b"data2") + self.assertEqual(dataY, b"data1") + yield w1.close(wait=True) + yield w2.close(wait=True) + +class Off: + + @inlineCallbacks + def test_same_message(self): + # the two sides use random nonces for their messages, so it's ok for + # both to try and send the same body: they'll result in distinct + # encrypted messages + w1 = Wormhole(APPID, self.relayurl) + w2 = Wormhole(APPID, self.relayurl) + code = yield w1.get_code() + w2.set_code(code) + yield self.doBoth(w1.send(b"data"), w2.send(b"data")) + dl = yield self.doBoth(w1.get(), w2.get()) + (dataX, dataY) = dl + self.assertEqual(dataX, b"data") + self.assertEqual(dataY, b"data") + yield self.doBoth(w1.close(), w2.close()) + + @inlineCallbacks + def test_interleaved(self): + w1 = Wormhole(APPID, self.relayurl) + w2 = Wormhole(APPID, self.relayurl) + code = yield w1.get_code() + w2.set_code(code) + res = yield self.doBoth(w1.send(b"data1"), w2.get()) + (_, dataY) = res + self.assertEqual(dataY, b"data1") + dl = yield self.doBoth(w1.get(), w2.send(b"data2")) + (dataX, _) = dl + self.assertEqual(dataX, b"data2") + yield self.doBoth(w1.close(), w2.close()) + + @inlineCallbacks + def test_fixed_code(self): + w1 = Wormhole(APPID, self.relayurl) + w2 = Wormhole(APPID, self.relayurl) + w1.set_code(u"123-purple-elephant") + w2.set_code(u"123-purple-elephant") + yield self.doBoth(w1.send(b"data1"), w2.send(b"data2")) + dl = yield self.doBoth(w1.get(), w2.get()) + (dataX, dataY) = dl + self.assertEqual(dataX, b"data2") + self.assertEqual(dataY, b"data1") + yield self.doBoth(w1.close(), w2.close()) + + + @inlineCallbacks + def test_multiple_messages(self): + w1 = Wormhole(APPID, self.relayurl) + w2 = Wormhole(APPID, self.relayurl) + w1.set_code(u"123-purple-elephant") + w2.set_code(u"123-purple-elephant") + yield self.doBoth(w1.send(b"data1"), w2.send(b"data2")) + yield self.doBoth(w1.send(b"data3"), w2.send(b"data4")) + dl = yield self.doBoth(w1.get(), w2.get()) + (dataX, dataY) = dl + self.assertEqual(dataX, b"data2") + self.assertEqual(dataY, b"data1") + dl = yield self.doBoth(w1.get(), w2.get()) + (dataX, dataY) = dl + self.assertEqual(dataX, b"data4") + self.assertEqual(dataY, b"data3") + yield self.doBoth(w1.close(), w2.close()) + + @inlineCallbacks + def test_multiple_messages_2(self): + w1 = Wormhole(APPID, self.relayurl) + w2 = Wormhole(APPID, self.relayurl) + w1.set_code(u"123-purple-elephant") + w2.set_code(u"123-purple-elephant") + # TODO: set_code should be sufficient to kick things off, but for now + # we must also let both sides do at least one send() or get() + yield self.doBoth(w1.send(b"data1"), w2.send(b"ignored")) + yield w1.get() + yield w1.send(b"data2") + yield w1.send(b"data3") + data = yield w2.get() + self.assertEqual(data, b"data1") + data = yield w2.get() + self.assertEqual(data, b"data2") + data = yield w2.get() + self.assertEqual(data, b"data3") + yield self.doBoth(w1.close(), w2.close()) + + @inlineCallbacks + def test_wrong_password(self): + w1 = Wormhole(APPID, self.relayurl) + w2 = Wormhole(APPID, self.relayurl) + code = yield w1.get_code() + w2.set_code(code+"not") + + # w2 can't throw WrongPasswordError until it sees a CONFIRM message, + # and w1 won't send CONFIRM until it sees a PAKE message, which w2 + # won't send until we call get. So we need both sides to be + # running at the same time for this test. + d1 = w1.send(b"data1") + # at this point, w1 should be waiting for w2.PAKE + + yield self.assertFailure(w2.get(), WrongPasswordError) + # * w2 will send w2.PAKE, wait for (and get) w1.PAKE, compute a key, + # send w2.CONFIRM, then wait for w1.DATA. + # * w1 will get w2.PAKE, compute a key, send w1.CONFIRM. + # * w1 might also get w2.CONFIRM, and may notice the error before it + # sends w1.CONFIRM, in which case the wait=True will signal an + # error inside _get_master_key() (inside send), and d1 will + # errback. + # * but w1 might not see w2.CONFIRM yet, in which case it won't + # errback until we do w1.get() + # * w2 gets w1.CONFIRM, notices the error, records it. + # * w2 (waiting for w1.DATA) wakes up, sees the error, throws + # * meanwhile w1 finishes sending its data. w2.CONFIRM may or may not + # have arrived by then + try: + yield d1 + except WrongPasswordError: + pass + + # When we ask w1 to get(), one of two things might happen: + # * if w2.CONFIRM arrived already, it will have recorded the error. + # When w1.get() sleeps (waiting for w2.DATA), we'll notice the + # error before sleeping, and throw WrongPasswordError + # * if w2.CONFIRM hasn't arrived yet, we'll sleep. When w2.CONFIRM + # arrives, we notice and record the error, and wake up, and throw + + # Note that we didn't do w2.send(), so we're hoping that w1 will + # have enough information to detect the error before it sleeps + # (waiting for w2.DATA). Checking for the error both before sleeping + # and after waking up makes this happen. + + # so now w1 should have enough information to throw too + yield self.assertFailure(w1.get(), WrongPasswordError) + + # both sides are closed automatically upon error, but it's still + # legal to call .close(), and should be idempotent + yield self.doBoth(w1.close(), w2.close()) + + @inlineCallbacks + def test_no_confirm(self): + # newer versions (which check confirmations) should will work with + # older versions (that don't send confirmations) + w1 = Wormhole(APPID, self.relayurl) + w1._send_confirm = False + w2 = Wormhole(APPID, self.relayurl) + + code = yield w1.get_code() + w2.set_code(code) + dl = yield self.doBoth(w1.send(b"data1"), w2.get()) + self.assertEqual(dl[1], b"data1") + dl = yield self.doBoth(w1.get(), w2.send(b"data2")) + self.assertEqual(dl[0], b"data2") + yield self.doBoth(w1.close(), w2.close()) + + @inlineCallbacks + def test_verifier(self): + w1 = Wormhole(APPID, self.relayurl) + w2 = Wormhole(APPID, self.relayurl) + code = yield w1.get_code() + w2.set_code(code) + res = yield self.doBoth(w1.get_verifier(), w2.get_verifier()) + v1, v2 = res + self.failUnlessEqual(type(v1), type(b"")) + self.failUnlessEqual(v1, v2) + yield self.doBoth(w1.send(b"data1"), w2.send(b"data2")) + dl = yield self.doBoth(w1.get(), w2.get()) + (dataX, dataY) = dl + self.assertEqual(dataX, b"data2") + self.assertEqual(dataY, b"data1") + yield self.doBoth(w1.close(), w2.close()) + + @inlineCallbacks + def test_errors(self): + w1 = Wormhole(APPID, self.relayurl) + yield self.assertFailure(w1.get_verifier(), UsageError) + yield self.assertFailure(w1.send(b"data"), UsageError) + yield self.assertFailure(w1.get(), UsageError) + w1.set_code(u"123-purple-elephant") + yield self.assertRaises(UsageError, w1.set_code, u"123-nope") + yield self.assertFailure(w1.get_code(), UsageError) + w2 = Wormhole(APPID, self.relayurl) + yield w2.get_code() + yield self.assertFailure(w2.get_code(), UsageError) + yield self.doBoth(w1.close(), w2.close()) + diff --git a/src/wormhole/wormhole.py b/src/wormhole/wormhole.py index 2a16bb7..31ffe4c 100644 --- a/src/wormhole/wormhole.py +++ b/src/wormhole/wormhole.py @@ -15,7 +15,7 @@ from . import __version__ from . import codes #from .errors import ServerError, Timeout from .errors import WrongPasswordError, UsageError -#from .timing import DebugTiming +from .timing import DebugTiming from hkdf import Hkdf def HKDF(skm, outlen, salt=None, CTXinfo=b""): @@ -67,9 +67,10 @@ class WSFactory(websocket.WebSocketClientFactory): class _GetCode: - def __init__(self, code_length, send_command): + def __init__(self, code_length, send_command, timing): self._code_length = code_length self._send_command = send_command + self._timing = timing self._allocated_d = defer.Deferred() @inlineCallbacks @@ -87,11 +88,12 @@ class _GetCode: self._allocated_d.callback(nid) class _InputCode: - def __init__(self, reactor, prompt, code_length, send_command): + def __init__(self, reactor, prompt, code_length, send_command, timing): self._reactor = reactor self._prompt = prompt self._code_length = code_length self._send_command = send_command + self._timing = timing @inlineCallbacks def _list(self): @@ -172,11 +174,12 @@ class _InputCode: # readline finish. class _WelcomeHandler: - def __init__(self, url, current_version): + def __init__(self, url, current_version, signal_error): self._ws_url = url self._version_warning_displayed = False self._motd_displayed = False self._current_version = current_version + self._signal_error = signal_error def handle_welcome(self, welcome): if ("motd" in welcome and @@ -211,28 +214,45 @@ class _Wormhole: self._tor_manager = tor_manager self._timing = timing - self._welcomer = _WelcomeHandler(self._ws_url, __version__) + self._welcomer = _WelcomeHandler(self._ws_url, __version__, + self._signal_error) self._side = hexlify(os.urandom(5)).decode("ascii") self._connected = None + self._connection_waiters = [] + self._started_get_code = False + self._code = None self._nameplate_id = None + self._nameplate_claimed = False + self._nameplate_released = False + self._release_waiter = defer.Deferred() self._mailbox_id = None self._mailbox_opened = False self._mailbox_closed = False - self._flag_need_mailbox = True + self._close_waiter = defer.Deferred() + self._flag_need_nameplate = True self._flag_need_to_see_mailbox_used = True self._flag_need_to_build_msg1 = True self._flag_need_to_send_PAKE = True - self._flag_need_PAKE = True - self._flag_need_key = True # rename to not self._key + self._key = None + self._closed = False + self._mood = u"happy" + + self._get_verifier_called = False + self._verifier_waiter = defer.Deferred() self._next_send_phase = 0 - self._plaintext_to_send = [] # (phase, plaintext, deferred) - self._phase_messages_to_send = [] # not yet acked by server + # send() queues plaintext here, waiting for a connection and the key + self._plaintext_to_send = [] # (phase, plaintext) + self._sent_phases = set() # to detect double-send self._next_receive_phase = 0 self._receive_waiters = {} # phase -> Deferred - self._phase_messages_received = {} # phase -> message + self._received_messages = {} # phase -> plaintext + def _signal_error(self, error): + # close the mailbox with an "errory" mood, errback all Deferreds, + # record the error, fail all subsequent API calls + pass # XXX def _start(self): d = self._connect() # causes stuff to happen @@ -275,6 +295,16 @@ class _Wormhole: self._ws_send_command(u"bind", appid=self._appid, side=self._side) self._maybe_get_mailbox() self._maybe_send_pake() + waiters, self._connection_waiters = self._connection_waiters, [] + for d in waiters: + d.callback(None) + + def _when_connected(self): + if self._connected: + return defer.succeed(None) + d = defer.Deferred() + self._connection_waiters.append(d) + return d def _ws_send_command(self, mtype, **kwargs): # msgid is used by misc/dump-timing.py to correlate our sends with @@ -287,8 +317,10 @@ class _Wormhole: self._timing.add("ws_send", _side=self._side, **kwargs) self._ws.sendMessage(payload, False) + DEBUG=False def _ws_dispatch_response(self, payload): msg = json.loads(payload.decode("utf-8")) + if self.DEBUG and msg["type"]!="ack": print("DIS", msg["type"], msg) self._timing.add("ws_receive", _side=self._side, message=msg) mtype = msg["type"] meth = getattr(self, "_response_handle_"+mtype, None) @@ -311,9 +343,11 @@ class _Wormhole: if self._started_get_code: raise UsageError self._started_get_code = True with self._timing.add("API get_code"): - gc = _GetCode(code_length, self._ws_send_command) + yield self._when_connected() + gc = _GetCode(code_length, self._ws_send_command, self._timing) self._response_handle_allocated = gc._response_handle_allocated code = yield gc.go() + self._nameplate_claimed = True # side-effect of allocation self._event_learned_code(code) returnValue(code) @@ -324,6 +358,7 @@ class _Wormhole: if self._started_input_code: raise UsageError self._started_input_code = True with self._timing.add("API input_code"): + yield self._when_connected() ic = _InputCode(prompt, code_length, self._ws_send_command) self._response_handle_nameplates = ic._response_handle_nameplates code = yield ic.go() @@ -337,6 +372,17 @@ class _Wormhole: if self._code is not None: raise UsageError self._event_learned_code(code) + # TODO: entry point 4: restore pre-contact saved state (we haven't heard + # from the peer yet, so we still need the nameplate) + + # TODO: entry point 5: restore post-contact saved state (so we don't need + # or use the nameplate, only the mailbox) + def _restore_post_contact_state(self, state): + # ... + self._flag_need_nameplate = False + #self._mailbox_id = X(state) + self._event_learned_mailbox() + def _event_learned_code(self, code): self._timing.add("code established") self._code = code @@ -370,10 +416,10 @@ class _Wormhole: self._maybe_get_mailbox() def _maybe_get_mailbox(self): - if not (self._flag_need_mailbox and self._nameplate_id - and self._connected): + if not (self._nameplate_id and self._connected): return self._ws_send_command(u"claim", nameplate=self._nameplate_id) + self._nameplate_claimed = True def _response_handle_claimed(self, msg): mailbox_id = msg["mailbox"] @@ -382,10 +428,10 @@ class _Wormhole: self._event_learned_mailbox() def _event_learned_mailbox(self): - self._flag_need_mailbox = False if not self._mailbox_id: raise UsageError if self._mailbox_opened: raise UsageError self._ws_send_command(u"open", mailbox=self._mailbox_id) + self._mailbox_opened = True # causes old messages to be sent now, and subscribes to new messages self._maybe_send_pake() self._maybe_send_phase_messages() @@ -395,34 +441,36 @@ class _Wormhole: if not (self._connected and self._mailbox_opened and self._flag_need_to_send_PAKE): return - d = self._msg_send(u"pake", self._msg1) - def _pake_sent(res): - self._flag_need_to_send_PAKE = False - d.addCallback(_pake_sent) - d.addErrback(log.err) + self._msg_send(u"pake", self._msg1) + self._flag_need_to_send_PAKE = False - def _event_learned_PAKE(self, pake_msg): + def _event_received_pake(self, pake_msg): with self._timing.add("pake2", waiting="crypto"): self._key = self._sp.finish(pake_msg) self._event_established_key() def _event_established_key(self): self._timing.add("key established") - if self._send_confirm: - # both sides send different (random) confirmation messages - confkey = self.derive_key(u"wormhole:confirmation") - nonce = os.urandom(CONFMSG_NONCE_LENGTH) - confmsg = make_confmsg(confkey, nonce) - self._msg_send(u"confirm", confmsg, wait=True) + + # both sides send different (random) confirmation messages + confkey = self.derive_key(u"wormhole:confirmation") + nonce = os.urandom(CONFMSG_NONCE_LENGTH) + confmsg = make_confmsg(confkey, nonce) + self._msg_send(u"confirm", confmsg) + verifier = self.derive_key(u"wormhole:verifier") self._event_computed_verifier(verifier) + self._maybe_send_phase_messages() + def get_verifier(self): + if self._closed: raise UsageError + if self._get_verifier_called: raise UsageError + self._get_verifier_called = True + return self._verifier_waiter + def _event_computed_verifier(self, verifier): - self._verifier = verifier - d, self._verifier_waiter = self._verifier_waiter, None - if d: - d.callback(verifier) + self._verifier_waiter.callback(verifier) def _event_received_confirm(self, body): # TODO: we might not have a master key yet, if the caller wasn't @@ -435,19 +483,15 @@ class _Wormhole: return self._signal_error(WrongPasswordError()) - @inlineCallbacks - def send(self, outbound_data, wait=False): + def send(self, outbound_data): if not isinstance(outbound_data, type(b"")): raise TypeError(type(outbound_data)) if self._closed: raise UsageError phase = self._next_send_phase self._next_send_phase += 1 - d = defer.Deferred() - self._plaintext_to_send.append( (phase, outbound_data, d) ) - with self._timing.add("API send", phase=phase, wait=wait): + self._plaintext_to_send.append( (phase, outbound_data) ) + with self._timing.add("API send", phase=phase): self._maybe_send_phase_messages() - if wait: - yield d def _maybe_send_phase_messages(self): # TODO: deal with reentrant call @@ -456,18 +500,19 @@ class _Wormhole: plaintexts = self._plaintext_to_send self._plaintext_to_send = [] for pm in plaintexts: - (phase, plaintext, wait_d) = pm + (phase, plaintext) = pm + assert isinstance(phase, int), type(phase) data_key = self.derive_key(u"wormhole:phase:%d" % phase) encrypted = self._encrypt_data(data_key, plaintext) - d = self._msg_send(phase, encrypted) - d.addBoth(wait_d.callback) - d.addErrback(log.err) + self._msg_send(u"%d" % phase, encrypted) def _encrypt_data(self, key, data): # Without predefined roles, we can't derive predictably unique keys # for each side, so we use the same key for both. We use random # nonces to keep the messages distinct, and we automatically ignore # reflections. + # TODO: HKDF(side, nonce, key) ?? include 'side' to prevent + # reflections, since we no longer compare messages assert isinstance(key, type(b"")), type(key) assert isinstance(data, type(b"")), type(data) assert len(key) == SecretBox.KEY_SIZE, len(key) @@ -475,57 +520,39 @@ class _Wormhole: nonce = utils.random(SecretBox.NONCE_SIZE) return box.encrypt(data, nonce) - @inlineCallbacks - def _msg_send(self, phase, body, wait=False): - if phase in self._sent_messages: raise UsageError + def _msg_send(self, phase, body): + if phase in self._sent_phases: raise UsageError if not self._mailbox_opened: raise UsageError if self._mailbox_closed: raise UsageError - self._sent_messages[phase] = body + self._sent_phases.add(phase) # TODO: retry on failure, with exponential backoff. We're guarding # against the rendezvous server being temporarily offline. - t = self._timing.add("add", phase=phase, wait=wait) - yield self._ws_send_command(u"add", phase=phase, - body=hexlify(body).decode("ascii")) - if wait: - while phase not in self._delivered_messages: - yield self._sleep() - t.finish() + self._timing.add("add", phase=phase) + self._ws_send_command(u"add", phase=phase, + body=hexlify(body).decode("ascii")) - def _event_received_message(self, msg): - pass def _event_mailbox_used(self): + if self.DEBUG: print("_event_mailbox_used") if self._flag_need_to_see_mailbox_used: - self._ws_send_command(u"release") + self._maybe_release_nameplate() self._flag_need_to_see_mailbox_used = False def derive_key(self, purpose, length=SecretBox.KEY_SIZE): if not isinstance(purpose, type(u"")): raise TypeError(type(purpose)) if self._key is None: - # call after get_verifier() or get() - raise UsageError + raise UsageError # call derive_key after get_verifier() or get() return HKDF(self._key, length, CTXinfo=to_bytes(purpose)) - def _event_received_phase_message(self, phase, message): - self._phase_messages_received[phase] = message - if phase in self._phase_message_waiters: - d = self._phase_message_waiters.pop(phase) - d.callback(message) - def _response_handle_message(self, msg): side = msg["side"] phase = msg["phase"] + assert isinstance(phase, type(u"")), type(phase) body = unhexlify(msg["body"].encode("ascii")) if side == self._side: return self._event_received_peer_message(phase, body) - def XXXackstuff(): - if phase in self._sent_messages and self._sent_messages[phase] == body: - self._delivered_messages.add(phase) # ack by server - self._wakeup() - return # ignore echoes of our outbound messages - def _event_received_peer_message(self, phase, body): # any message in the mailbox means we no longer need the nameplate self._event_mailbox_used() @@ -534,18 +561,23 @@ class _Wormhole: # err = ServerError("got duplicate phase %s" % phase, self._ws_url) # return self._signal_error(err) #self._received_messages[phase] = body + if phase == u"pake": + self._event_received_pake(body) + return if phase == u"confirm": self._event_received_confirm(body) + return + # now notify anyone waiting on it try: data_key = self.derive_key(u"wormhole:phase:%s" % phase) - inbound_data = self._decrypt_data(data_key, body) + plaintext = self._decrypt_data(data_key, body) except CryptoError: - raise WrongPasswordError - self._phase_messages_received[phase] = inbound_data + raise WrongPasswordError # TODO: signal + self._received_messages[phase] = plaintext if phase in self._receive_waiters: d = self._receive_waiters.pop(phase) - d.callback(inbound_data) + d.callback(plaintext) def _decrypt_data(self, key, encrypted): assert isinstance(key, type(b"")), type(key) @@ -555,28 +587,63 @@ class _Wormhole: data = box.decrypt(encrypted) return data - @inlineCallbacks def get(self): if self._closed: raise UsageError - if self._code is None: raise UsageError - phase = self._next_receive_phase + phase = u"%d" % self._next_receive_phase self._next_receive_phase += 1 with self._timing.add("API get", phase=phase): - if phase in self._phase_messages_received: - returnValue(self._phase_messages_received[phase]) + if phase in self._received_messages: + return defer.succeed(self._received_messages[phase]) d = self._receive_waiters[phase] = defer.Deferred() - yield d + return d - def _event_asked_to_close(self): + @inlineCallbacks + def close(self, mood=None, wait=False): + # TODO: auto-close on error, mostly for load-from-state + if self._closed: raise UsageError + if mood: + self._mood = mood + self._maybe_release_nameplate() + self._maybe_close_mailbox() + if wait: + if self._nameplate_claimed: + yield self._release_waiter + if self._mailbox_opened: + yield self._close_waiter + self._drop_connection() + + def _maybe_release_nameplate(self): + if self.DEBUG: print("_maybe_release_nameplate", self._nameplate_claimed, self._nameplate_released) + if self._nameplate_claimed and not self._nameplate_released: + if self.DEBUG: print(" sending release") + self._ws_send_command(u"release") + self._nameplate_released = True + + def _response_handle_released(self, msg): + self._release_waiter.callback(None) + + def _maybe_close_mailbox(self): + if self._mailbox_opened and not self._mailbox_closed: + self._ws_send_command(u"close", mood=self._mood) + self._mailbox_closed = True + + def _response_handle_closed(self, msg): + self._close_waiter.callback(None) + + def _drop_connection(self): + self._ws.transport.loseConnection() # probably flushes + # calls _ws_closed() when done + + def _ws_closed(self, wasClean, code, reason): pass - - def wormhole(appid, relay_url, reactor, tor_manager=None, timing=None): + timing = timing or DebugTiming() w = _Wormhole(appid, relay_url, reactor, tor_manager, timing) w._start() return w -def wormhole_from_serialized(data, reactor): - w = _Wormhole.from_serialized(data, reactor) +def wormhole_from_serialized(data, reactor, timing=None): + timing = timing or DebugTiming() + w = _Wormhole.from_serialized(data, reactor, timing) return w