From 25472423c6be64ece64d5f01d013b0179432d4ea Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Sat, 20 Jun 2015 19:18:21 -0700 Subject: [PATCH] make twisted work, get serialization into shape, add proper tests --- src/wormhole/test/test_twisted.py | 132 ++++++++++++++++ src/wormhole/twisted/transcribe.py | 236 ++++++++++++++--------------- 2 files changed, 244 insertions(+), 124 deletions(-) create mode 100644 src/wormhole/test/test_twisted.py diff --git a/src/wormhole/test/test_twisted.py b/src/wormhole/test/test_twisted.py new file mode 100644 index 0000000..51be63d --- /dev/null +++ b/src/wormhole/test/test_twisted.py @@ -0,0 +1,132 @@ +import json +from twisted.trial import unittest +from twisted.internet import defer, protocol, endpoints, reactor +from twisted.application import service +from ..servers.relay import RelayServer +from ..twisted.transcribe import SymmetricWormhole, UsageError +from .. import __version__ +#from twisted.python import log +#import sys +#log.startLogging(sys.stdout) + +def allocate_port(): + ep = endpoints.serverFromString(reactor, "tcp:0:interface=127.0.0.1") + d = ep.listen(protocol.Factory()) + def _listening(lp): + port = lp.getHost().port + d2 = lp.stopListening() + d2.addCallback(lambda _: port) + return d2 + d.addCallback(_listening) + return d + +def allocate_ports(): + d = defer.DeferredList([allocate_port(), allocate_port()]) + def _done(results): + port1 = results[0][1] + port2 = results[1][1] + return (port1, port2) + d.addCallback(_done) + return d + +class Basic(unittest.TestCase): + def setUp(self): + self.sp = service.MultiService() + self.sp.startService() + d = allocate_ports() + def _got_ports(ports): + relayport, transitport = ports + s = RelayServer("tcp:%d:interface=127.0.0.1" % relayport, + "tcp:%s:interface=127.0.0.1" % transitport, + __version__) + s.setServiceParent(self.sp) + self.relayurl = "http://127.0.0.1:%d/wormhole-relay/" % relayport + self.transit = "tcp:127.0.0.1:%d" % transitport + d.addCallback(_got_ports) + return d + + def tearDown(self): + return self.sp.stopService() + + def test_basic(self): + appid = "appid" + w1 = SymmetricWormhole(appid, self.relayurl) + w2 = SymmetricWormhole(appid, self.relayurl) + d = w1.get_code() + def _got_code(code): + w2.set_code(code) + d1 = w1.get_data("data1") + d2 = w2.get_data("data2") + return defer.DeferredList([d1,d2], fireOnOneErrback=False) + d.addCallback(_got_code) + def _done(dl): + ((success1, dataX), (success2, dataY)) = dl + r1,r2 = dl + self.assertTrue(success1) + self.assertTrue(success2) + self.assertEqual(dataX, "data2") + self.assertEqual(dataY, "data1") + d.addCallback(_done) + return d + + def test_fixed_code(self): + appid = "appid" + w1 = SymmetricWormhole(appid, self.relayurl) + w2 = SymmetricWormhole(appid, self.relayurl) + w1.set_code("123-purple-elephant") + w2.set_code("123-purple-elephant") + d1 = w1.get_data("data1") + d2 = w2.get_data("data2") + d = defer.DeferredList([d1,d2], fireOnOneErrback=False) + def _done(dl): + ((success1, dataX), (success2, dataY)) = dl + r1,r2 = dl + self.assertTrue(success1) + self.assertTrue(success2) + self.assertEqual(dataX, "data2") + self.assertEqual(dataY, "data1") + d.addCallback(_done) + return d + + def test_errors(self): + appid = "appid" + w1 = SymmetricWormhole(appid, self.relayurl) + self.assertRaises(UsageError, w1.get_verifier) + self.assertRaises(UsageError, w1.get_data, "data") + w1.set_code("123-purple-elephant") + self.assertRaises(UsageError, w1.set_code, "123-nope") + self.assertRaises(UsageError, w1.get_code) + w2 = SymmetricWormhole(appid, self.relayurl) + d = w2.get_code() + self.assertRaises(UsageError, w2.get_code) + return d + + def test_serialize(self): + appid = "appid" + w1 = SymmetricWormhole(appid, self.relayurl) + self.assertRaises(UsageError, w1.serialize) # too early + w2 = SymmetricWormhole(appid, self.relayurl) + d = w1.get_code() + def _got_code(code): + self.assertRaises(UsageError, w2.serialize) # too early + w2.set_code(code) + w2.serialize() # ok + s = w1.serialize() + self.assertEqual(type(s), type("")) + unpacked = json.loads(s) # this is supposed to be JSON + self.assertEqual(type(unpacked), dict) + new_w1 = SymmetricWormhole.from_serialized(s) + d1 = new_w1.get_data("data1") + d2 = w2.get_data("data2") + return defer.DeferredList([d1,d2], fireOnOneErrback=False) + d.addCallback(_got_code) + def _done(dl): + ((success1, dataX), (success2, dataY)) = dl + r1,r2 = dl + self.assertTrue(success1) + self.assertTrue(success2) + self.assertEqual(dataX, "data2") + self.assertEqual(dataY, "data1") + self.assertRaises(UsageError, w2.serialize) # too late + d.addCallback(_done) + return d diff --git a/src/wormhole/twisted/transcribe.py b/src/wormhole/twisted/transcribe.py index 090cc97..775dd67 100644 --- a/src/wormhole/twisted/transcribe.py +++ b/src/wormhole/twisted/transcribe.py @@ -1,14 +1,14 @@ from __future__ import print_function -import sys, json +import os, sys, json, re from binascii import hexlify, unhexlify from zope.interface import implementer -#from twisted.application import service from twisted.internet import reactor, defer from twisted.web import client as web_client from twisted.web import error as web_error from twisted.web.iweb import IBodyProducer from nacl.secret import SecretBox from nacl.exceptions import CryptoError +from nacl import utils from spake2 import SPAKE2_Symmetric from .eventsource import ReconnectingEventSource from .. import __version__ @@ -21,10 +21,17 @@ class WrongPasswordError(Exception): Key confirmation failed. """ +class ReflectionAttack(Exception): + """An attacker (or bug) reflected our outgoing message back to us.""" + +class UsageError(Exception): + """The programmer did something wrong.""" + @implementer(IBodyProducer) class DataProducer: def __init__(self, data): self.data = data + self.length = len(data) def startProducing(self, consumer): consumer.write(self.data) return defer.succeed(None) @@ -35,112 +42,87 @@ class DataProducer: def resumeProducing(self): pass -''' -class TwistedInitiator(service.MultiService): - """I am a service, and I must be running to function. Either call my - .startService() method, or .setServiceParent() me to some other running - service. You can use i.when_done().addCallback(i.disownServiceParent) to - make me go away when everything is done. - """ - def __init__(self, appid, data, reactor, relay): - self.appid = appid - self.data = data - self.reactor = reactor - self.relay = relay - self.code = None - - def set_code(self, code): # used for human-made pre-generated codes - assert self.code is None - mo = re.search(r'^(\d+)-', code) - if not mo: - raise ValueError("code (%s) must start with NN-" % code) - self.channel_id = int(mo.group(1)) - self.code = code - self.sp = SPAKE2_A(self.code.encode("ascii"), - idA=self.appid+":Initiator", - idB=self.appid+":Receiver") - - def get_code(self, length=2): - assert self.code is None - d = self._allocate_channel() - def _got_channel_id(channel_id): - code = codes.make_code(channel_id, code_length) - self.set_code(code) - return code - d.addCallback(_got_channel_id) - return d - - def serialize(self): - if not self.code: - raise ValueEror - - def get_data(self, outbound_data): - msg = self.sp.start() - # change SPAKE2 to choose random_scalar earlier, to make getting the - # first message idempotent. - ... - - def when_get_code(self): - pass # return Deferred - - def when_get_data(self): - pass # return Deferred - -class TwistedReceiver(service.MultiService): - def __init__(self, appid, data, code, reactor, relay): - self.appid = appid - self.data = data - self.code = code - self.reactor = reactor - self.relay = relay - - def when_get_data(self): - pass # return Deferred -''' - class SymmetricWormhole: def __init__(self, appid, relay): self.appid = appid self.relay = relay self.agent = web_client.Agent(reactor) + self.side = None + self.code = None self.key = None + self._started_get_code = False - def set_code(self, code): - assert self.code is None - self.code = code - # allocate the rest now too, so it can be serialized - self.sp = SPAKE2_Symmetric(self.code.encode("ascii"), - idA=self.appid+":SymmetricA", - idB=self.appid+":SymmetricB") - self.msg1 = self.sp.start() + def get_code(self, code_length=2): + if self.code is not None: raise UsageError + if self._started_get_code: raise UsageError + self._started_get_code = True + self.side = hexlify(os.urandom(5)) + d = self._allocate_channel() + def _got_channel_id(channel_id): + code = codes.make_code(channel_id, code_length) + self._set_code_and_channel_id(code) + self._start() + return code + d.addCallback(_got_channel_id) + return d def _allocate_channel(self): url = self.relay + "allocate/%s" % self.side d = self.post(url) - def _got_channel(data_json): - data = json.loads(data_json) + def _got_channel(data): if "welcome" in data: self.handle_welcome(data["welcome"]) return data["channel-id"] d.addCallback(_got_channel) return d - def _deallocate(self, res): - d = self.agent.request("POST", self.url("deallocate")) - d.addBoth(lambda _: res) # ignore POST failure, pass-through result - return d + def set_code(self, code): + if self.code is not None: raise UsageError + if self.side is not None: raise UsageError + self._set_code_and_channel_id(code) + self.side = hexlify(os.urandom(5)) + self._start() - def get_code(self, code_length=2): - if self.code is not None: - return defer.succeed(self.code) - d = self._allocate_channel() - def _got_channel_id(channel_id): - code = codes.make_code(channel_id, code_length) - self.set_code(code) - return code - d.addCallback(_got_channel_id) - return d + def _set_code_and_channel_id(self, code): + if self.code is not None: raise UsageError + mo = re.search(r'^(\d+)-', code) + if not mo: + raise ValueError("code (%s) must start with NN-" % code) + self.channel_id = int(mo.group(1)) + self.code = code + + def _start(self): + # allocate the rest now too, so it can be serialized + self.sp = SPAKE2_Symmetric(self.code.encode("ascii"), + idA=self.appid+":SymmetricA", + idB=self.appid+":SymmetricB") + self.msg1 = self.sp.start() + + def serialize(self): + # I can only be serialized after get_code/set_code and before + # get_verifier/get_data + if self.code is None: raise UsageError + if self.key is not None: raise UsageError + data = { + "appid": self.appid, + "relay": self.relay, + "code": self.code, + "side": self.side, + "spake2": json.loads(self.sp.serialize()), + "msg1": self.msg1.encode("hex"), + } + return json.dumps(data) + + @classmethod + def from_serialized(klass, data): + d = json.loads(data) + self = klass(d["appid"].encode("ascii"), d["relay"].encode("ascii")) + self._set_code_and_channel_id(d["code"].encode("ascii")) + self.side = d["side"].encode("ascii") + self.sp = SPAKE2_Symmetric.from_serialized(json.dumps(d["spake2"])) + self.msg1 = d["msg1"].decode("hex") + return self motd_displayed = False version_warning_displayed = False @@ -174,6 +156,8 @@ class SymmetricWormhole: return url def post(self, url, post_json=None): + # TODO: retry on failure, with exponential backoff. We're guarding + # against the rendezvous server being temporarily offline. p = None if post_json: data = json.dumps(post_json).encode("utf-8") @@ -188,7 +172,7 @@ class SymmetricWormhole: d.addCallback(lambda data: json.loads(data)) return d - def get_msgs(self, old_msgs, verb, msgnum): + def _get_msgs(self, old_msgs, verb, msgnum): # fire with a list of messages that match verb/msgnum, which either # came from old_msgs, or from an EventSource that we attached to the # corresponding URL @@ -200,9 +184,9 @@ class SymmetricWormhole: if name == "welcome": self.handle_welcome(json.loads(data)) if name == "message": - msgs.extend(json.loads(data)["message"]) + msgs.append(json.loads(data)["message"]) d.callback(None) - es = ReconnectingEventSource(None, lambda: self.url("post", "pake"), + es = ReconnectingEventSource(None, lambda: self.url(verb, msgnum), _handle)#, agent=self.agent) es.startService() # TODO: .setServiceParent(self) es.activate() @@ -212,9 +196,22 @@ class SymmetricWormhole: return d def derive_key(self, purpose, length=SecretBox.KEY_SIZE): + assert self.key is not None # call after get_verifier() or get_data() assert type(purpose) == type(b"") return HKDF(self.key, length, CTXinfo=purpose) + def _encrypt_data(self, key, data): + assert len(key) == SecretBox.KEY_SIZE + box = SecretBox(key) + nonce = utils.random(SecretBox.NONCE_SIZE) + return box.encrypt(data, nonce) + + def _decrypt_data(self, key, encrypted): + assert len(key) == SecretBox.KEY_SIZE + box = SecretBox(key) + data = box.decrypt(encrypted) + return data + def _get_key(self): # TODO: prevent multiple invocation @@ -222,7 +219,7 @@ class SymmetricWormhole: return defer.succeed(self.key) data = {"message": hexlify(self.msg1).decode("ascii")} d = self.post(self.url("post", "pake"), data) - d.addCallback(lambda j: self.get_msgs(j["messages"], "poll", "pake")) + d.addCallback(lambda j: self._get_msgs(j["messages"], "poll", "pake")) def _got_pake(msgs): pake_msg = unhexlify(msgs[0].encode("ascii")) key = self.sp.finish(pake_msg) @@ -233,25 +230,33 @@ class SymmetricWormhole: return d def get_verifier(self): + if self.code is None: raise UsageError d = self._get_key() d.addCallback(lambda _: self.verifier) return d def get_data(self, outbound_data): # only call this once + if self.code is None: raise UsageError d = self._get_key() - def _got_key(_): - outbound_key = self.derive_key(b"sender") - outbound_encrypted = self._encrypt_data(outbound_key, outbound_data) - data = {"message": hexlify(outbound_encrypted).decode("ascii")} - return self.post(self.url("post", "data"), data) - d.addCallback(lambda j: self.get_msgs(j["messages"], "poll", "data")) + d.addCallback(self._get_data2, outbound_data) + return d + + def _get_data2(self, key, outbound_data): + # Without predefined roles, we can't derive predictably unique keys + # for each side, so we use the same key for both. We use random + # nonces to keep the messages distinct, and check for reflection. + data_key = self.derive_key(b"data-key") + outbound_encrypted = self._encrypt_data(data_key, outbound_data) + data = {"message": hexlify(outbound_encrypted).decode("ascii")} + d = self.post(self.url("post", "data"), data) + d.addCallback(lambda j: self._get_msgs(j["messages"], "poll", "data")) def _got_data(msgs): inbound_encrypted = unhexlify(msgs[0].encode("ascii")) - inbound_key = self.derive_key(b"receiver") + if inbound_encrypted == outbound_encrypted: + raise ReflectionAttack try: - inbound_data = self._decrypt_data(inbound_key, - inbound_encrypted) + inbound_data = self._decrypt_data(data_key, inbound_encrypted) return inbound_data except CryptoError: raise WrongPasswordError @@ -259,25 +264,8 @@ class SymmetricWormhole: d.addBoth(self._deallocate) return d - - def serialize(self): - assert self.code is not None - data = { - "appid": self.appid, - "payload_for_them": self.payload_for_them.encode("hex"), - "relay": self.relay, - "code": self.code, - "wormhole": json.loads(self.sp.serialize()), - "msg1": self.msg1.encode("hex"), - } - return json.dumps(data) - - @classmethod - def from_serialized(klass, data): - d = json.loads(data) - self = klass(str(d["appid"]), d["payload_for_them"].decode("hex"), - str(d["relay"])) - self.code = str(d["code"]) - self.sp = SPAKE2_Symmetric.from_serialized(json.dumps(d["wormhole"])) - self.msg1 = d["msg1"].decode("hex") - return self + def _deallocate(self, res): + # only try once, no retries + d = self.agent.request("POST", self.url("deallocate")) + d.addBoth(lambda _: res) # ignore POST failure, pass-through result + return d