From 3da52b0a3ec22f69ce279346d1fb7e4dc56fc927 Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Sun, 22 May 2016 11:31:00 -0700 Subject: [PATCH] add 'mock', building out test_wormhole --- src/wormhole/test/test_wormhole.py | 74 ++++++++ src/wormhole/wormhole.py | 273 ++++++++++++++++++++--------- tox.ini | 1 + 3 files changed, 270 insertions(+), 78 deletions(-) create mode 100644 src/wormhole/test/test_wormhole.py diff --git a/src/wormhole/test/test_wormhole.py b/src/wormhole/test/test_wormhole.py new file mode 100644 index 0000000..1591530 --- /dev/null +++ b/src/wormhole/test/test_wormhole.py @@ -0,0 +1,74 @@ +from __future__ import print_function +import json +import mock +from twisted.trial import unittest +from twisted.internet import reactor +from twisted.internet.defer import gatherResults, inlineCallbacks +#from ..twisted.transcribe import (wormhole, wormhole_from_serialized, +# UsageError, WrongPasswordError) +#from .common import ServerBase +from ..wormhole import _Wormhole, _WelcomeHandler +from ..timing import DebugTiming + +APPID = u"appid" + +class MockWebSocket: + def __init__(self): + self._payloads = [] + def sendMessage(self, payload, is_binary): + assert not is_binary + self._payloads.append(payload) + + def outbound(self): + out = [] + while self._payloads: + p = self._payloads.pop(0) + out.append(json.loads(p.decode("utf-8"))) + return out + +def response(w, **kwargs): + payload = json.dumps(kwargs).encode("utf-8") + w._ws_dispatch_response(payload) + +class Welcome(unittest.TestCase): + def test_no_current_version(self): + # WelcomeHandler should tolerate lack of ["current_version"] + w = _WelcomeHandler(u"relay_url", u"current_version") + w.handle_welcome({}) + + +class Basic(unittest.TestCase): + def test_create(self): + w = _Wormhole(APPID, u"relay_url", reactor, None, None) + + def test_basic(self): + # We don't call w._start(), so this doesn't create a WebSocket + # connection. We provide a mock connection instead. + timing = DebugTiming() + with mock.patch("wormhole.wormhole._WelcomeHandler") as whc: + w = _Wormhole(APPID, u"relay_url", reactor, None, timing) + wh = whc.return_value + #w._welcomer = mock.Mock() + # w._connect = lambda self: None + # w._event_connected(mock_ws) + # w._event_ws_opened() + # w._ws_dispatch_response(payload) + self.assertEqual(w._ws_url, u"relay_url") + ws = MockWebSocket() + w._event_connected(ws) + out = ws.outbound() + self.assertEqual(len(out), 0) + + w._event_ws_opened(None) + out = ws.outbound() + self.assertEqual(len(out), 1) + self.assertEqual(out[0]["type"], u"bind") + self.assertEqual(out[0]["appid"], APPID) + self.assertEqual(out[0]["side"], w._side) + self.assertIn(u"id", out[0]) + + # WelcomeHandler should get called upon 'welcome' response + WELCOME = {u"foo": u"bar"} + response(w, type="welcome", welcome=WELCOME) + self.assertEqual(wh.mock_calls, [mock.call.handle_welcome(WELCOME)]) + diff --git a/src/wormhole/wormhole.py b/src/wormhole/wormhole.py index ca3d7c7..2a16bb7 100644 --- a/src/wormhole/wormhole.py +++ b/src/wormhole/wormhole.py @@ -1,4 +1,4 @@ -from __future__ import print_function +from __future__ import print_function, absolute_import import os, sys, json, re, unicodedata from six.moves.urllib_parse import urlparse from binascii import hexlify, unhexlify @@ -11,10 +11,11 @@ from nacl.secret import SecretBox from nacl.exceptions import CryptoError from nacl import utils from spake2 import SPAKE2_Symmetric -from .. import __version__ -from .. import codes -from ..errors import ServerError, Timeout, WrongPasswordError, UsageError -from ..timing import DebugTiming +from . import __version__ +from . import codes +#from .errors import ServerError, Timeout +from .errors import WrongPasswordError, UsageError +#from .timing import DebugTiming from hkdf import Hkdf def HKDF(skm, outlen, salt=None, CTXinfo=b""): @@ -80,7 +81,7 @@ class _GetCode: assert isinstance(code, type(u"")), type(code) returnValue(code) - def _ws_handle_allocated(self, msg): + def _response_handle_allocated(self, msg): nid = msg["nameplate"] assert isinstance(nid, type(u"")), type(nid) self._allocated_d.callback(nid) @@ -125,12 +126,16 @@ class _InputCode: self._reactor.removeSystemEventTrigger(t) returnValue(code) - def _ws_handle_nameplates(self, msg): + def _response_handle_nameplates(self, msg): nameplates = msg["nameplates"] assert isinstance(nameplates, list), type(nameplates) - for nameplate_id in nameplates: + nids = [] + for n in nameplates: + assert isinstance(n, dict), type(n) + nameplate_id = n[u"id"] assert isinstance(nameplate_id, type(u"")), type(nameplate_id) - self._lister_d.callback(nameplates) + nids.append(nameplate_id) + self._lister_d.callback(nids) def _warn_readline(self): # When our process receives a SIGINT, Twisted's SIGINT handler will @@ -166,11 +171,53 @@ class _InputCode: # doesn't see the signal, and we must still wait for stdin to make # readline finish. +class _WelcomeHandler: + def __init__(self, url, current_version): + self._ws_url = url + self._version_warning_displayed = False + self._motd_displayed = False + self._current_version = current_version + + def handle_welcome(self, welcome): + if ("motd" in welcome and + not self._motd_displayed): + motd_lines = welcome["motd"].splitlines() + motd_formatted = "\n ".join(motd_lines) + print("Server (at %s) says:\n %s" % + (self._ws_url, motd_formatted), file=sys.stderr) + self._motd_displayed = True + + # Only warn if we're running a release version (e.g. 0.0.6, not + # 0.0.6-DISTANCE-gHASH). Only warn once. + if ("current_version" in welcome + and "-" not in self._current_version + and not self._version_warning_displayed + and welcome["current_version"] != self._current_version): + print("Warning: errors may occur unless both sides are running the same version", file=sys.stderr) + print("Server claims %s is current, but ours is %s" + % (welcome["current_version"], self._current_version), + file=sys.stderr) + self._version_warning_displayed = True + + if "error" in welcome: + return self._signal_error(welcome["error"]) class _Wormhole: - def __init__(self): + def __init__(self, appid, relay_url, reactor, tor_manager, timing): + self._appid = appid + self._ws_url = relay_url + self._reactor = reactor + self._tor_manager = tor_manager + self._timing = timing + + self._welcomer = _WelcomeHandler(self._ws_url, __version__) + self._side = hexlify(os.urandom(5)).decode("ascii") self._connected = None + self._nameplate_id = None + self._mailbox_id = None + self._mailbox_opened = False + self._mailbox_closed = False self._flag_need_mailbox = True self._flag_need_to_see_mailbox_used = True self._flag_need_to_build_msg1 = True @@ -179,9 +226,11 @@ class _Wormhole: self._flag_need_key = True # rename to not self._key self._next_send_phase = 0 + self._plaintext_to_send = [] # (phase, plaintext, deferred) self._phase_messages_to_send = [] # not yet acked by server self._next_receive_phase = 0 + self._receive_waiters = {} # phase -> Deferred self._phase_messages_received = {} # phase -> message @@ -213,10 +262,11 @@ class _Wormhole: d.addCallback(self._event_connected) # f.d is errbacked if WebSocket negotiation fails, and the WebSocket # drops any data sent before onOpen() fires, so we must wait for it + d.addCallback(lambda _: f.d) d.addCallback(self._event_ws_opened) return d - def _event_connected(self, ws, f): + def _event_connected(self, ws): self._ws = ws self._ws_t = self._timing.add("websocket") @@ -224,30 +274,35 @@ class _Wormhole: self._connected = True self._ws_send_command(u"bind", appid=self._appid, side=self._side) self._maybe_get_mailbox() + self._maybe_send_pake() - def _ws_handle_welcome(self, msg): - welcome = msg["welcome"] - if ("motd" in welcome and - not self.motd_displayed): - motd_lines = welcome["motd"].splitlines() - motd_formatted = "\n ".join(motd_lines) - print("Server (at %s) says:\n %s" % - (self._ws_url, motd_formatted), file=sys.stderr) - self.motd_displayed = True + def _ws_send_command(self, mtype, **kwargs): + # msgid is used by misc/dump-timing.py to correlate our sends with + # their receives, and vice versa. They are also correlated with the + # ACKs we get back from the server (which we otherwise ignore). There + # are so few messages, 16 bits is enough to be mostly-unique. + kwargs["id"] = hexlify(os.urandom(2)).decode("ascii") + kwargs["type"] = mtype + payload = json.dumps(kwargs).encode("utf-8") + self._timing.add("ws_send", _side=self._side, **kwargs) + self._ws.sendMessage(payload, False) - # Only warn if we're running a release version (e.g. 0.0.6, not - # 0.0.6-DISTANCE-gHASH). Only warn once. - if ("-" not in __version__ and - not self.version_warning_displayed and - welcome["current_version"] != __version__): - print("Warning: errors may occur unless both sides are running the same version", file=sys.stderr) - print("Server claims %s is current, but ours is %s" - % (welcome["current_version"], __version__), file=sys.stderr) - self.version_warning_displayed = True + def _ws_dispatch_response(self, payload): + msg = json.loads(payload.decode("utf-8")) + self._timing.add("ws_receive", _side=self._side, message=msg) + mtype = msg["type"] + meth = getattr(self, "_response_handle_"+mtype, None) + if not meth: + # make tests fail, but real application will ignore it + log.err(ValueError("Unknown inbound message type %r" % (msg,))) + return + return meth(msg) - if "error" in welcome: - return self._signal_error(welcome["error"]) + def _response_handle_ack(self, msg): + pass + def _response_handle_welcome(self, msg): + self._welcomer.handle_welcome(msg["welcome"]) # entry point 1: generate a new code @inlineCallbacks @@ -257,7 +312,7 @@ class _Wormhole: self._started_get_code = True with self._timing.add("API get_code"): gc = _GetCode(code_length, self._ws_send_command) - self._ws_handle_allocated = gc._ws_handle_allocated + self._response_handle_allocated = gc._response_handle_allocated code = yield gc.go() self._event_learned_code(code) returnValue(code) @@ -269,9 +324,9 @@ class _Wormhole: if self._started_input_code: raise UsageError self._started_input_code = True with self._timing.add("API input_code"): - gc = _InputCode(prompt, code_length, self._ws_send_command) - self._ws_handle_nameplates = gc._ws_handle_nameplates - code = yield gc.go() + ic = _InputCode(prompt, code_length, self._ws_send_command) + self._response_handle_nameplates = ic._response_handle_nameplates + code = yield ic.go() self._event_learned_code(code) returnValue(None) @@ -320,15 +375,12 @@ class _Wormhole: return self._ws_send_command(u"claim", nameplate=self._nameplate_id) - def _ws_handle_claimed(self, msg): + def _response_handle_claimed(self, msg): mailbox_id = msg["mailbox"] assert isinstance(mailbox_id, type(u"")), type(mailbox_id) self._mailbox_id = mailbox_id self._event_learned_mailbox() - def _event_welcome(self): - pass - def _event_learned_mailbox(self): self._flag_need_mailbox = False if not self._mailbox_id: raise UsageError @@ -340,7 +392,7 @@ class _Wormhole: def _maybe_send_pake(self): # TODO: deal with reentrant call - if not (self._connected and self._mailbox + if not (self._connected and self._mailbox_opened and self._flag_need_to_send_PAKE): return d = self._msg_send(u"pake", self._msg1) @@ -349,42 +401,11 @@ class _Wormhole: d.addCallback(_pake_sent) d.addErrback(log.err) - def _maybe_send_phase_messages(self): - # TODO: deal with reentrant call - if not (self._connected and self._mailbox and self._key): - return - for pm in self._phase_messages_to_send: - (phase, message) = pm - d = self._msg_send(phase, message) - def _phase_message_sent(res, pm=pm): - try: - self._phase_messages_to_send.remove(pm) - except ValueError: - pass - d.addCallback(_phase_message_sent) - d.addErrback(log.err) - - - - def _event_received_message(self, msg): - pass - def _event_mailbox_used(self): - if self._flag_need_to_see_mailbox_used: - self._ws_send_command(u"release") - self._flag_need_to_see_mailbox_used = False - def _event_learned_PAKE(self, pake_msg): with self._timing.add("pake2", waiting="crypto"): self._key = self._sp.finish(pake_msg) self._event_established_key() - def derive_key(self, purpose, length=SecretBox.KEY_SIZE): - if not isinstance(purpose, type(u"")): raise TypeError(type(purpose)) - if self._key is None: - # call after get_verifier() or get() - raise UsageError - return HKDF(self._key, length, CTXinfo=to_bytes(purpose)) - def _event_established_key(self): self._timing.add("key established") if self._send_confirm: @@ -395,7 +416,8 @@ class _Wormhole: self._msg_send(u"confirm", confmsg, wait=True) verifier = self.derive_key(u"wormhole:verifier") self._event_computed_verifier(verifier) - pass + self._maybe_send_phase_messages() + def _event_computed_verifier(self, verifier): self._verifier = verifier d, self._verifier_waiter = self._verifier_waiter, None @@ -412,13 +434,85 @@ class _Wormhole: # this makes all API calls fail return self._signal_error(WrongPasswordError()) + + @inlineCallbacks + def send(self, outbound_data, wait=False): + if not isinstance(outbound_data, type(b"")): + raise TypeError(type(outbound_data)) + if self._closed: raise UsageError + phase = self._next_send_phase + self._next_send_phase += 1 + d = defer.Deferred() + self._plaintext_to_send.append( (phase, outbound_data, d) ) + with self._timing.add("API send", phase=phase, wait=wait): + self._maybe_send_phase_messages() + if wait: + yield d + + def _maybe_send_phase_messages(self): + # TODO: deal with reentrant call + if not (self._connected and self._mailbox_opened and self._key): + return + plaintexts = self._plaintext_to_send + self._plaintext_to_send = [] + for pm in plaintexts: + (phase, plaintext, wait_d) = pm + data_key = self.derive_key(u"wormhole:phase:%d" % phase) + encrypted = self._encrypt_data(data_key, plaintext) + d = self._msg_send(phase, encrypted) + d.addBoth(wait_d.callback) + d.addErrback(log.err) + + def _encrypt_data(self, key, data): + # Without predefined roles, we can't derive predictably unique keys + # for each side, so we use the same key for both. We use random + # nonces to keep the messages distinct, and we automatically ignore + # reflections. + assert isinstance(key, type(b"")), type(key) + assert isinstance(data, type(b"")), type(data) + assert len(key) == SecretBox.KEY_SIZE, len(key) + box = SecretBox(key) + nonce = utils.random(SecretBox.NONCE_SIZE) + return box.encrypt(data, nonce) + + @inlineCallbacks + def _msg_send(self, phase, body, wait=False): + if phase in self._sent_messages: raise UsageError + if not self._mailbox_opened: raise UsageError + if self._mailbox_closed: raise UsageError + self._sent_messages[phase] = body + # TODO: retry on failure, with exponential backoff. We're guarding + # against the rendezvous server being temporarily offline. + t = self._timing.add("add", phase=phase, wait=wait) + yield self._ws_send_command(u"add", phase=phase, + body=hexlify(body).decode("ascii")) + if wait: + while phase not in self._delivered_messages: + yield self._sleep() + t.finish() + + + def _event_received_message(self, msg): + pass + def _event_mailbox_used(self): + if self._flag_need_to_see_mailbox_used: + self._ws_send_command(u"release") + self._flag_need_to_see_mailbox_used = False + + def derive_key(self, purpose, length=SecretBox.KEY_SIZE): + if not isinstance(purpose, type(u"")): raise TypeError(type(purpose)) + if self._key is None: + # call after get_verifier() or get() + raise UsageError + return HKDF(self._key, length, CTXinfo=to_bytes(purpose)) + def _event_received_phase_message(self, phase, message): self._phase_messages_received[phase] = message if phase in self._phase_message_waiters: d = self._phase_message_waiters.pop(phase) d.callback(message) - def _ws_handle_message(self, msg): + def _response_handle_message(self, msg): side = msg["side"] phase = msg["phase"] body = unhexlify(msg["body"].encode("ascii")) @@ -443,12 +537,35 @@ class _Wormhole: if phase == u"confirm": self._event_received_confirm(body) # now notify anyone waiting on it - self._wakeup() + try: + data_key = self.derive_key(u"wormhole:phase:%s" % phase) + inbound_data = self._decrypt_data(data_key, body) + except CryptoError: + raise WrongPasswordError + self._phase_messages_received[phase] = inbound_data + if phase in self._receive_waiters: + d = self._receive_waiters.pop(phase) + d.callback(inbound_data) - def _event_asked_to_send_phase_message(self, phase, message): - pm = (phase, message) - self._phase_messages_to_send.append(pm) - self._maybe_send_phase_messages() + def _decrypt_data(self, key, encrypted): + assert isinstance(key, type(b"")), type(key) + assert isinstance(encrypted, type(b"")), type(encrypted) + assert len(key) == SecretBox.KEY_SIZE, len(key) + box = SecretBox(key) + data = box.decrypt(encrypted) + return data + + @inlineCallbacks + def get(self): + if self._closed: raise UsageError + if self._code is None: raise UsageError + phase = self._next_receive_phase + self._next_receive_phase += 1 + with self._timing.add("API get", phase=phase): + if phase in self._phase_messages_received: + returnValue(self._phase_messages_received[phase]) + d = self._receive_waiters[phase] = defer.Deferred() + yield d def _event_asked_to_close(self): pass diff --git a/tox.ini b/tox.ini index 7ee4127..326afa5 100644 --- a/tox.ini +++ b/tox.ini @@ -19,6 +19,7 @@ skip_missing_interpreters = True [testenv] deps = pyflakes >= 1.2.3 + mock {env:EXTRA_DEPENDENCY:} commands = pyflakes setup.py src