From 5241c07b8c40dd9cabc2830169490a1537145ffe Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Thu, 11 Jun 2015 17:34:48 -0700 Subject: [PATCH 1/6] copy eventsource.py from petmail c98d5a0 --- src/wormhole/twisted/eventsource.py | 205 ++++++++++++++++++++++++++++ 1 file changed, 205 insertions(+) create mode 100644 src/wormhole/twisted/eventsource.py diff --git a/src/wormhole/twisted/eventsource.py b/src/wormhole/twisted/eventsource.py new file mode 100644 index 0000000..0630fe5 --- /dev/null +++ b/src/wormhole/twisted/eventsource.py @@ -0,0 +1,205 @@ +from twisted.python import log, failure +from twisted.internet import reactor, defer, protocol +from twisted.application import service +from twisted.protocols import basic +from twisted.web.client import Agent, ResponseDone +from twisted.web.http_headers import Headers +from .eventual import eventually + +class EventSourceParser(basic.LineOnlyReceiver): + delimiter = "\n" + + def __init__(self, handler): + self.current_field = None + self.current_lines = [] + self.handler = handler + self.done_deferred = defer.Deferred() + + def connectionLost(self, why): + if why.check(ResponseDone): + why = None + self.done_deferred.callback(why) + + def dataReceived(self, data): + # exceptions here aren't being logged properly, and tests will hang + # rather than halt. I suspect twisted.web._newclient's + # HTTP11ClientProtocol.dataReceived(), which catches everything and + # responds with self._giveUp() but doesn't log.err. + try: + basic.LineOnlyReceiver.dataReceived(self, data) + except: + log.err() + raise + + def lineReceived(self, line): + if not line: + # blank line ends the field + self.fieldReceived(self.current_field, + "\n".join(self.current_lines)) + self.current_field = None + self.current_lines[:] = [] + return + if self.current_field is None: + self.current_field, data = line.split(": ", 1) + self.current_lines.append(data) + else: + self.current_lines.append(line) + + def fieldReceived(self, name, data): + self.handler(name, data) + +class EventSourceError(Exception): + pass + +# es = EventSource(url, handler) +# d = es.start() +# es.cancel() + +class EventSource: # TODO: service.Service + def __init__(self, url, handler, when_connected=None): + self.url = url + self.handler = handler + self.when_connected = when_connected + self.started = False + self.cancelled = False + self.proto = EventSourceParser(self.handler) + + def start(self): + assert not self.started, "single-use" + self.started = True + a = Agent(reactor) + d = a.request("GET", self.url, + Headers({"accept": ["text/event-stream"]})) + d.addCallback(self._connected) + return d + + def _connected(self, resp): + if resp.code != 200: + raise EventSourceError("%d: %s" % (resp.code, resp.phrase)) + if self.when_connected: + self.when_connected() + #if resp.headers.getRawHeaders("content-type") == ["text/event-stream"]: + resp.deliverBody(self.proto) + if self.cancelled: + self.kill_connection() + return self.proto.done_deferred + + def cancel(self): + self.cancelled = True + if not self.proto.transport: + # _connected hasn't been called yet, but that self.cancelled + # should take care of it when the connection is established + def kill(data): + # this should kill it as soon as any data is delivered + raise ValueError("dead") + self.proto.dataReceived = kill # just in case + return + self.kill_connection() + + def kill_connection(self): + if (hasattr(self.proto.transport, "_producer") + and self.proto.transport._producer): + # This is gross and fragile. We need a clean way to stop the + # client connection. p.transport is a + # twisted.web._newclient.TransportProxyProducer , and its + # ._producer is the tcp.Port. + self.proto.transport._producer.loseConnection() + else: + log.err("get_events: unable to stop connection") + # oh well + #err = EventSourceError("unable to cancel") + try: + self.proto.done_deferred.callback(None) + except defer.AlreadyCalledError: + pass + + +class Connector: + # behave enough like an IConnector to appease ReconnectingClientFactory + def __init__(self, res): + self.res = res + def connect(self): + self.res._maybeStart() + def stopConnecting(self): + self.res._stop_eventsource() + +class ReconnectingEventSource(service.MultiService, + protocol.ReconnectingClientFactory): + def __init__(self, baseurl, connection_starting, handler): + service.MultiService.__init__(self) + # we don't use any of the basic Factory/ClientFactory methods of + # this, just the ReconnectingClientFactory.retry, stopTrying, and + # resetDelay methods. + + self.baseurl = baseurl + self.connection_starting = connection_starting + self.handler = handler + # IService provides self.running, toggled by {start,stop}Service. + # self.active is toggled by {,de}activate. If both .running and + # .active are True, then we want to have an outstanding EventSource + # and will start one if necessary. If either is False, then we don't + # want one to be outstanding, and will initiate shutdown. + self.active = False + self.connector = Connector(self) + self.es = None # set we have an outstanding EventSource + self.when_stopped = [] # list of Deferreds + + def isStopped(self): + return not self.es + + def startService(self): + service.MultiService.startService(self) # sets self.running + self._maybeStart() + + def stopService(self): + # clears self.running + d = defer.maybeDeferred(service.MultiService.stopService, self) + d.addCallback(self._maybeStop) + return d + + def activate(self): + assert not self.active + self.active = True + self._maybeStart() + + def deactivate(self): + assert self.active # XXX + self.active = False + return self._maybeStop() + + def _maybeStart(self): + if not (self.active and self.running): + return + self.continueTrying = True + url = self.connection_starting() + self.es = EventSource(url, self.handler, self.resetDelay) + d = self.es.start() + d.addBoth(self._stopped) + + def _stopped(self, res): + self.es = None + # we might have stopped because of a connection error, or because of + # an intentional shutdown. + if self.active and self.running: + # we still want to be connected, so schedule a reconnection + if isinstance(res, failure.Failure): + log.err(res) + self.retry() # will eventually call _maybeStart + return + # intentional shutdown + self.stopTrying() + for d in self.when_stopped: + eventually(d.callback, None) + self.when_stopped = [] + + def _stop_eventsource(self): + if self.es: + eventually(self.es.cancel) + + def _maybeStop(self, _=None): + self.stopTrying() # cancels timer, calls _stop_eventsource() + if not self.es: + return defer.succeed(None) + d = defer.Deferred() + self.when_stopped.append(d) + return d From 951da1a59b883d56a2956dc41704ea1982ed5d75 Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Sat, 20 Jun 2015 18:54:37 -0700 Subject: [PATCH 2/6] eventsource: add Agent, deliver eventtype correctly import eventual.py from the right place --- src/wormhole/twisted/eventsource.py | 32 +++++++++++++++++++++-------- 1 file changed, 23 insertions(+), 9 deletions(-) diff --git a/src/wormhole/twisted/eventsource.py b/src/wormhole/twisted/eventsource.py index 0630fe5..521ac76 100644 --- a/src/wormhole/twisted/eventsource.py +++ b/src/wormhole/twisted/eventsource.py @@ -4,7 +4,7 @@ from twisted.application import service from twisted.protocols import basic from twisted.web.client import Agent, ResponseDone from twisted.web.http_headers import Headers -from .eventual import eventually +from ..util.eventual import eventually class EventSourceParser(basic.LineOnlyReceiver): delimiter = "\n" @@ -14,6 +14,7 @@ class EventSourceParser(basic.LineOnlyReceiver): self.current_lines = [] self.handler = handler self.done_deferred = defer.Deferred() + self.eventtype = "message" def connectionLost(self, why): if why.check(ResponseDone): @@ -45,8 +46,17 @@ class EventSourceParser(basic.LineOnlyReceiver): else: self.current_lines.append(line) - def fieldReceived(self, name, data): - self.handler(name, data) + def fieldReceived(self, fieldname, data): + if fieldname == "event": + self.eventtype = data + elif fieldname == "data": + self.eventReceived(self.eventtype, data) + self.eventtype = "message" + else: + log.msg("weird fieldname", fieldname, data) + + def eventReceived(self, eventtype, data): + self.handler(eventtype, data) class EventSourceError(Exception): pass @@ -56,20 +66,22 @@ class EventSourceError(Exception): # es.cancel() class EventSource: # TODO: service.Service - def __init__(self, url, handler, when_connected=None): + def __init__(self, url, handler, when_connected=None, agent=None): self.url = url self.handler = handler self.when_connected = when_connected self.started = False self.cancelled = False self.proto = EventSourceParser(self.handler) + if not agent: + agent = Agent(reactor) + self.agent = agent def start(self): assert not self.started, "single-use" self.started = True - a = Agent(reactor) - d = a.request("GET", self.url, - Headers({"accept": ["text/event-stream"]})) + d = self.agent.request("GET", self.url, + Headers({"accept": ["text/event-stream"]})) d.addCallback(self._connected) return d @@ -125,7 +137,7 @@ class Connector: class ReconnectingEventSource(service.MultiService, protocol.ReconnectingClientFactory): - def __init__(self, baseurl, connection_starting, handler): + def __init__(self, baseurl, connection_starting, handler, agent=None): service.MultiService.__init__(self) # we don't use any of the basic Factory/ClientFactory methods of # this, just the ReconnectingClientFactory.retry, stopTrying, and @@ -134,6 +146,7 @@ class ReconnectingEventSource(service.MultiService, self.baseurl = baseurl self.connection_starting = connection_starting self.handler = handler + self.agent = agent # IService provides self.running, toggled by {start,stop}Service. # self.active is toggled by {,de}activate. If both .running and # .active are True, then we want to have an outstanding EventSource @@ -172,7 +185,8 @@ class ReconnectingEventSource(service.MultiService, return self.continueTrying = True url = self.connection_starting() - self.es = EventSource(url, self.handler, self.resetDelay) + self.es = EventSource(url, self.handler, self.resetDelay, + agent=self.agent) d = self.es.start() d.addBoth(self._stopped) From 85dd3ba9485013dc3e09067cd3074a14361f4bb7 Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Sat, 20 Jun 2015 18:53:57 -0700 Subject: [PATCH 3/6] make twisted/ a real package --- src/wormhole/twisted/__init__.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 src/wormhole/twisted/__init__.py diff --git a/src/wormhole/twisted/__init__.py b/src/wormhole/twisted/__init__.py new file mode 100644 index 0000000..e69de29 From 0f58f3906dc726bca3b46b02281d2510eecf48f2 Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Sat, 20 Jun 2015 18:36:22 -0700 Subject: [PATCH 4/6] rough out twisted.SymmetricWormhole --- src/wormhole/twisted/transcribe.py | 264 ++++++++++++++++++++++++++++- 1 file changed, 260 insertions(+), 4 deletions(-) diff --git a/src/wormhole/twisted/transcribe.py b/src/wormhole/twisted/transcribe.py index dfdf821..090cc97 100644 --- a/src/wormhole/twisted/transcribe.py +++ b/src/wormhole/twisted/transcribe.py @@ -1,12 +1,84 @@ -from twisted.application import service -from ..const import RELAY +from __future__ import print_function +import sys, json +from binascii import hexlify, unhexlify +from zope.interface import implementer +#from twisted.application import service +from twisted.internet import reactor, defer +from twisted.web import client as web_client +from twisted.web import error as web_error +from twisted.web.iweb import IBodyProducer +from nacl.secret import SecretBox +from nacl.exceptions import CryptoError +from spake2 import SPAKE2_Symmetric +from .eventsource import ReconnectingEventSource +from .. import __version__ +from .. import codes +from ..errors import ServerError +from ..util.hkdf import HKDF +class WrongPasswordError(Exception): + """ + Key confirmation failed. + """ + +@implementer(IBodyProducer) +class DataProducer: + def __init__(self, data): + self.data = data + def startProducing(self, consumer): + consumer.write(self.data) + return defer.succeed(None) + def stopProducing(self): + pass + def pauseProducing(self): + pass + def resumeProducing(self): + pass + +''' class TwistedInitiator(service.MultiService): - def __init__(self, appid, data, reactor, relay=RELAY): + """I am a service, and I must be running to function. Either call my + .startService() method, or .setServiceParent() me to some other running + service. You can use i.when_done().addCallback(i.disownServiceParent) to + make me go away when everything is done. + """ + def __init__(self, appid, data, reactor, relay): self.appid = appid self.data = data self.reactor = reactor self.relay = relay + self.code = None + + def set_code(self, code): # used for human-made pre-generated codes + assert self.code is None + mo = re.search(r'^(\d+)-', code) + if not mo: + raise ValueError("code (%s) must start with NN-" % code) + self.channel_id = int(mo.group(1)) + self.code = code + self.sp = SPAKE2_A(self.code.encode("ascii"), + idA=self.appid+":Initiator", + idB=self.appid+":Receiver") + + def get_code(self, length=2): + assert self.code is None + d = self._allocate_channel() + def _got_channel_id(channel_id): + code = codes.make_code(channel_id, code_length) + self.set_code(code) + return code + d.addCallback(_got_channel_id) + return d + + def serialize(self): + if not self.code: + raise ValueEror + + def get_data(self, outbound_data): + msg = self.sp.start() + # change SPAKE2 to choose random_scalar earlier, to make getting the + # first message idempotent. + ... def when_get_code(self): pass # return Deferred @@ -15,7 +87,7 @@ class TwistedInitiator(service.MultiService): pass # return Deferred class TwistedReceiver(service.MultiService): - def __init__(self, appid, data, code, reactor, relay=RELAY): + def __init__(self, appid, data, code, reactor, relay): self.appid = appid self.data = data self.code = code @@ -24,4 +96,188 @@ class TwistedReceiver(service.MultiService): def when_get_data(self): pass # return Deferred +''' + +class SymmetricWormhole: + def __init__(self, appid, relay): + self.appid = appid + self.relay = relay + self.agent = web_client.Agent(reactor) + self.key = None + + def set_code(self, code): + assert self.code is None + self.code = code + # allocate the rest now too, so it can be serialized + self.sp = SPAKE2_Symmetric(self.code.encode("ascii"), + idA=self.appid+":SymmetricA", + idB=self.appid+":SymmetricB") + self.msg1 = self.sp.start() + + def _allocate_channel(self): + url = self.relay + "allocate/%s" % self.side + d = self.post(url) + def _got_channel(data_json): + data = json.loads(data_json) + if "welcome" in data: + self.handle_welcome(data["welcome"]) + return data["channel-id"] + d.addCallback(_got_channel) + return d + + def _deallocate(self, res): + d = self.agent.request("POST", self.url("deallocate")) + d.addBoth(lambda _: res) # ignore POST failure, pass-through result + return d + + def get_code(self, code_length=2): + if self.code is not None: + return defer.succeed(self.code) + d = self._allocate_channel() + def _got_channel_id(channel_id): + code = codes.make_code(channel_id, code_length) + self.set_code(code) + return code + d.addCallback(_got_channel_id) + return d + + motd_displayed = False + version_warning_displayed = False + + def handle_welcome(self, welcome): + if ("motd" in welcome and + not self.motd_displayed): + motd_lines = welcome["motd"].splitlines() + motd_formatted = "\n ".join(motd_lines) + print("Server (at %s) says:\n %s" % (self.relay, motd_formatted), + file=sys.stderr) + self.motd_displayed = True + + # Only warn if we're running a release version (e.g. 0.0.6, not + # 0.0.6-DISTANCE-gHASH). Only warn once. + if ("-" not in __version__ and + not self.version_warning_displayed and + welcome["current_version"] != __version__): + print("Warning: errors may occur unless both sides are running the same version", file=sys.stderr) + print("Server claims %s is current, but ours is %s" + % (welcome["current_version"], __version__), file=sys.stderr) + self.version_warning_displayed = True + + if "error" in welcome: + raise ServerError(welcome["error"], self.relay) + + def url(self, verb, msgnum=None): + url = "%s%d/%s/%s" % (self.relay, self.channel_id, self.side, verb) + if msgnum is not None: + url += "/" + msgnum + return url + + def post(self, url, post_json=None): + p = None + if post_json: + data = json.dumps(post_json).encode("utf-8") + p = DataProducer(data) + d = self.agent.request("POST", url, bodyProducer=p) + def _check_error(resp): + if resp.code != 200: + raise web_error.Error(resp.code, resp.phrase) + return resp + d.addCallback(_check_error) + d.addCallback(web_client.readBody) + d.addCallback(lambda data: json.loads(data)) + return d + + def get_msgs(self, old_msgs, verb, msgnum): + # fire with a list of messages that match verb/msgnum, which either + # came from old_msgs, or from an EventSource that we attached to the + # corresponding URL + if old_msgs: + return defer.succeed(old_msgs) + d = defer.Deferred() + msgs = [] + def _handle(name, data): + if name == "welcome": + self.handle_welcome(json.loads(data)) + if name == "message": + msgs.extend(json.loads(data)["message"]) + d.callback(None) + es = ReconnectingEventSource(None, lambda: self.url("post", "pake"), + _handle)#, agent=self.agent) + es.startService() # TODO: .setServiceParent(self) + es.activate() + d.addCallback(lambda _: es.deactivate()) + d.addCallback(lambda _: es.stopService()) + d.addCallback(lambda _: msgs) + return d + + def derive_key(self, purpose, length=SecretBox.KEY_SIZE): + assert type(purpose) == type(b"") + return HKDF(self.key, length, CTXinfo=purpose) + + + def _get_key(self): + # TODO: prevent multiple invocation + if self.key: + return defer.succeed(self.key) + data = {"message": hexlify(self.msg1).decode("ascii")} + d = self.post(self.url("post", "pake"), data) + d.addCallback(lambda j: self.get_msgs(j["messages"], "poll", "pake")) + def _got_pake(msgs): + pake_msg = unhexlify(msgs[0].encode("ascii")) + key = self.sp.finish(pake_msg) + self.key = key + self.verifier = self.derive_key(self.appid+b":Verifier") + return key + d.addCallback(_got_pake) + return d + + def get_verifier(self): + d = self._get_key() + d.addCallback(lambda _: self.verifier) + return d + + def get_data(self, outbound_data): + # only call this once + d = self._get_key() + def _got_key(_): + outbound_key = self.derive_key(b"sender") + outbound_encrypted = self._encrypt_data(outbound_key, outbound_data) + data = {"message": hexlify(outbound_encrypted).decode("ascii")} + return self.post(self.url("post", "data"), data) + d.addCallback(lambda j: self.get_msgs(j["messages"], "poll", "data")) + def _got_data(msgs): + inbound_encrypted = unhexlify(msgs[0].encode("ascii")) + inbound_key = self.derive_key(b"receiver") + try: + inbound_data = self._decrypt_data(inbound_key, + inbound_encrypted) + return inbound_data + except CryptoError: + raise WrongPasswordError + d.addCallback(_got_data) + d.addBoth(self._deallocate) + return d + + + def serialize(self): + assert self.code is not None + data = { + "appid": self.appid, + "payload_for_them": self.payload_for_them.encode("hex"), + "relay": self.relay, + "code": self.code, + "wormhole": json.loads(self.sp.serialize()), + "msg1": self.msg1.encode("hex"), + } + return json.dumps(data) + + @classmethod + def from_serialized(klass, data): + d = json.loads(data) + self = klass(str(d["appid"]), d["payload_for_them"].decode("hex"), + str(d["relay"])) + self.code = str(d["code"]) + self.sp = SPAKE2_Symmetric.from_serialized(json.dumps(d["wormhole"])) + self.msg1 = d["msg1"].decode("hex") + return self From 25472423c6be64ece64d5f01d013b0179432d4ea Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Sat, 20 Jun 2015 19:18:21 -0700 Subject: [PATCH 5/6] make twisted work, get serialization into shape, add proper tests --- src/wormhole/test/test_twisted.py | 132 ++++++++++++++++ src/wormhole/twisted/transcribe.py | 236 ++++++++++++++--------------- 2 files changed, 244 insertions(+), 124 deletions(-) create mode 100644 src/wormhole/test/test_twisted.py diff --git a/src/wormhole/test/test_twisted.py b/src/wormhole/test/test_twisted.py new file mode 100644 index 0000000..51be63d --- /dev/null +++ b/src/wormhole/test/test_twisted.py @@ -0,0 +1,132 @@ +import json +from twisted.trial import unittest +from twisted.internet import defer, protocol, endpoints, reactor +from twisted.application import service +from ..servers.relay import RelayServer +from ..twisted.transcribe import SymmetricWormhole, UsageError +from .. import __version__ +#from twisted.python import log +#import sys +#log.startLogging(sys.stdout) + +def allocate_port(): + ep = endpoints.serverFromString(reactor, "tcp:0:interface=127.0.0.1") + d = ep.listen(protocol.Factory()) + def _listening(lp): + port = lp.getHost().port + d2 = lp.stopListening() + d2.addCallback(lambda _: port) + return d2 + d.addCallback(_listening) + return d + +def allocate_ports(): + d = defer.DeferredList([allocate_port(), allocate_port()]) + def _done(results): + port1 = results[0][1] + port2 = results[1][1] + return (port1, port2) + d.addCallback(_done) + return d + +class Basic(unittest.TestCase): + def setUp(self): + self.sp = service.MultiService() + self.sp.startService() + d = allocate_ports() + def _got_ports(ports): + relayport, transitport = ports + s = RelayServer("tcp:%d:interface=127.0.0.1" % relayport, + "tcp:%s:interface=127.0.0.1" % transitport, + __version__) + s.setServiceParent(self.sp) + self.relayurl = "http://127.0.0.1:%d/wormhole-relay/" % relayport + self.transit = "tcp:127.0.0.1:%d" % transitport + d.addCallback(_got_ports) + return d + + def tearDown(self): + return self.sp.stopService() + + def test_basic(self): + appid = "appid" + w1 = SymmetricWormhole(appid, self.relayurl) + w2 = SymmetricWormhole(appid, self.relayurl) + d = w1.get_code() + def _got_code(code): + w2.set_code(code) + d1 = w1.get_data("data1") + d2 = w2.get_data("data2") + return defer.DeferredList([d1,d2], fireOnOneErrback=False) + d.addCallback(_got_code) + def _done(dl): + ((success1, dataX), (success2, dataY)) = dl + r1,r2 = dl + self.assertTrue(success1) + self.assertTrue(success2) + self.assertEqual(dataX, "data2") + self.assertEqual(dataY, "data1") + d.addCallback(_done) + return d + + def test_fixed_code(self): + appid = "appid" + w1 = SymmetricWormhole(appid, self.relayurl) + w2 = SymmetricWormhole(appid, self.relayurl) + w1.set_code("123-purple-elephant") + w2.set_code("123-purple-elephant") + d1 = w1.get_data("data1") + d2 = w2.get_data("data2") + d = defer.DeferredList([d1,d2], fireOnOneErrback=False) + def _done(dl): + ((success1, dataX), (success2, dataY)) = dl + r1,r2 = dl + self.assertTrue(success1) + self.assertTrue(success2) + self.assertEqual(dataX, "data2") + self.assertEqual(dataY, "data1") + d.addCallback(_done) + return d + + def test_errors(self): + appid = "appid" + w1 = SymmetricWormhole(appid, self.relayurl) + self.assertRaises(UsageError, w1.get_verifier) + self.assertRaises(UsageError, w1.get_data, "data") + w1.set_code("123-purple-elephant") + self.assertRaises(UsageError, w1.set_code, "123-nope") + self.assertRaises(UsageError, w1.get_code) + w2 = SymmetricWormhole(appid, self.relayurl) + d = w2.get_code() + self.assertRaises(UsageError, w2.get_code) + return d + + def test_serialize(self): + appid = "appid" + w1 = SymmetricWormhole(appid, self.relayurl) + self.assertRaises(UsageError, w1.serialize) # too early + w2 = SymmetricWormhole(appid, self.relayurl) + d = w1.get_code() + def _got_code(code): + self.assertRaises(UsageError, w2.serialize) # too early + w2.set_code(code) + w2.serialize() # ok + s = w1.serialize() + self.assertEqual(type(s), type("")) + unpacked = json.loads(s) # this is supposed to be JSON + self.assertEqual(type(unpacked), dict) + new_w1 = SymmetricWormhole.from_serialized(s) + d1 = new_w1.get_data("data1") + d2 = w2.get_data("data2") + return defer.DeferredList([d1,d2], fireOnOneErrback=False) + d.addCallback(_got_code) + def _done(dl): + ((success1, dataX), (success2, dataY)) = dl + r1,r2 = dl + self.assertTrue(success1) + self.assertTrue(success2) + self.assertEqual(dataX, "data2") + self.assertEqual(dataY, "data1") + self.assertRaises(UsageError, w2.serialize) # too late + d.addCallback(_done) + return d diff --git a/src/wormhole/twisted/transcribe.py b/src/wormhole/twisted/transcribe.py index 090cc97..775dd67 100644 --- a/src/wormhole/twisted/transcribe.py +++ b/src/wormhole/twisted/transcribe.py @@ -1,14 +1,14 @@ from __future__ import print_function -import sys, json +import os, sys, json, re from binascii import hexlify, unhexlify from zope.interface import implementer -#from twisted.application import service from twisted.internet import reactor, defer from twisted.web import client as web_client from twisted.web import error as web_error from twisted.web.iweb import IBodyProducer from nacl.secret import SecretBox from nacl.exceptions import CryptoError +from nacl import utils from spake2 import SPAKE2_Symmetric from .eventsource import ReconnectingEventSource from .. import __version__ @@ -21,10 +21,17 @@ class WrongPasswordError(Exception): Key confirmation failed. """ +class ReflectionAttack(Exception): + """An attacker (or bug) reflected our outgoing message back to us.""" + +class UsageError(Exception): + """The programmer did something wrong.""" + @implementer(IBodyProducer) class DataProducer: def __init__(self, data): self.data = data + self.length = len(data) def startProducing(self, consumer): consumer.write(self.data) return defer.succeed(None) @@ -35,112 +42,87 @@ class DataProducer: def resumeProducing(self): pass -''' -class TwistedInitiator(service.MultiService): - """I am a service, and I must be running to function. Either call my - .startService() method, or .setServiceParent() me to some other running - service. You can use i.when_done().addCallback(i.disownServiceParent) to - make me go away when everything is done. - """ - def __init__(self, appid, data, reactor, relay): - self.appid = appid - self.data = data - self.reactor = reactor - self.relay = relay - self.code = None - - def set_code(self, code): # used for human-made pre-generated codes - assert self.code is None - mo = re.search(r'^(\d+)-', code) - if not mo: - raise ValueError("code (%s) must start with NN-" % code) - self.channel_id = int(mo.group(1)) - self.code = code - self.sp = SPAKE2_A(self.code.encode("ascii"), - idA=self.appid+":Initiator", - idB=self.appid+":Receiver") - - def get_code(self, length=2): - assert self.code is None - d = self._allocate_channel() - def _got_channel_id(channel_id): - code = codes.make_code(channel_id, code_length) - self.set_code(code) - return code - d.addCallback(_got_channel_id) - return d - - def serialize(self): - if not self.code: - raise ValueEror - - def get_data(self, outbound_data): - msg = self.sp.start() - # change SPAKE2 to choose random_scalar earlier, to make getting the - # first message idempotent. - ... - - def when_get_code(self): - pass # return Deferred - - def when_get_data(self): - pass # return Deferred - -class TwistedReceiver(service.MultiService): - def __init__(self, appid, data, code, reactor, relay): - self.appid = appid - self.data = data - self.code = code - self.reactor = reactor - self.relay = relay - - def when_get_data(self): - pass # return Deferred -''' - class SymmetricWormhole: def __init__(self, appid, relay): self.appid = appid self.relay = relay self.agent = web_client.Agent(reactor) + self.side = None + self.code = None self.key = None + self._started_get_code = False - def set_code(self, code): - assert self.code is None - self.code = code - # allocate the rest now too, so it can be serialized - self.sp = SPAKE2_Symmetric(self.code.encode("ascii"), - idA=self.appid+":SymmetricA", - idB=self.appid+":SymmetricB") - self.msg1 = self.sp.start() + def get_code(self, code_length=2): + if self.code is not None: raise UsageError + if self._started_get_code: raise UsageError + self._started_get_code = True + self.side = hexlify(os.urandom(5)) + d = self._allocate_channel() + def _got_channel_id(channel_id): + code = codes.make_code(channel_id, code_length) + self._set_code_and_channel_id(code) + self._start() + return code + d.addCallback(_got_channel_id) + return d def _allocate_channel(self): url = self.relay + "allocate/%s" % self.side d = self.post(url) - def _got_channel(data_json): - data = json.loads(data_json) + def _got_channel(data): if "welcome" in data: self.handle_welcome(data["welcome"]) return data["channel-id"] d.addCallback(_got_channel) return d - def _deallocate(self, res): - d = self.agent.request("POST", self.url("deallocate")) - d.addBoth(lambda _: res) # ignore POST failure, pass-through result - return d + def set_code(self, code): + if self.code is not None: raise UsageError + if self.side is not None: raise UsageError + self._set_code_and_channel_id(code) + self.side = hexlify(os.urandom(5)) + self._start() - def get_code(self, code_length=2): - if self.code is not None: - return defer.succeed(self.code) - d = self._allocate_channel() - def _got_channel_id(channel_id): - code = codes.make_code(channel_id, code_length) - self.set_code(code) - return code - d.addCallback(_got_channel_id) - return d + def _set_code_and_channel_id(self, code): + if self.code is not None: raise UsageError + mo = re.search(r'^(\d+)-', code) + if not mo: + raise ValueError("code (%s) must start with NN-" % code) + self.channel_id = int(mo.group(1)) + self.code = code + + def _start(self): + # allocate the rest now too, so it can be serialized + self.sp = SPAKE2_Symmetric(self.code.encode("ascii"), + idA=self.appid+":SymmetricA", + idB=self.appid+":SymmetricB") + self.msg1 = self.sp.start() + + def serialize(self): + # I can only be serialized after get_code/set_code and before + # get_verifier/get_data + if self.code is None: raise UsageError + if self.key is not None: raise UsageError + data = { + "appid": self.appid, + "relay": self.relay, + "code": self.code, + "side": self.side, + "spake2": json.loads(self.sp.serialize()), + "msg1": self.msg1.encode("hex"), + } + return json.dumps(data) + + @classmethod + def from_serialized(klass, data): + d = json.loads(data) + self = klass(d["appid"].encode("ascii"), d["relay"].encode("ascii")) + self._set_code_and_channel_id(d["code"].encode("ascii")) + self.side = d["side"].encode("ascii") + self.sp = SPAKE2_Symmetric.from_serialized(json.dumps(d["spake2"])) + self.msg1 = d["msg1"].decode("hex") + return self motd_displayed = False version_warning_displayed = False @@ -174,6 +156,8 @@ class SymmetricWormhole: return url def post(self, url, post_json=None): + # TODO: retry on failure, with exponential backoff. We're guarding + # against the rendezvous server being temporarily offline. p = None if post_json: data = json.dumps(post_json).encode("utf-8") @@ -188,7 +172,7 @@ class SymmetricWormhole: d.addCallback(lambda data: json.loads(data)) return d - def get_msgs(self, old_msgs, verb, msgnum): + def _get_msgs(self, old_msgs, verb, msgnum): # fire with a list of messages that match verb/msgnum, which either # came from old_msgs, or from an EventSource that we attached to the # corresponding URL @@ -200,9 +184,9 @@ class SymmetricWormhole: if name == "welcome": self.handle_welcome(json.loads(data)) if name == "message": - msgs.extend(json.loads(data)["message"]) + msgs.append(json.loads(data)["message"]) d.callback(None) - es = ReconnectingEventSource(None, lambda: self.url("post", "pake"), + es = ReconnectingEventSource(None, lambda: self.url(verb, msgnum), _handle)#, agent=self.agent) es.startService() # TODO: .setServiceParent(self) es.activate() @@ -212,9 +196,22 @@ class SymmetricWormhole: return d def derive_key(self, purpose, length=SecretBox.KEY_SIZE): + assert self.key is not None # call after get_verifier() or get_data() assert type(purpose) == type(b"") return HKDF(self.key, length, CTXinfo=purpose) + def _encrypt_data(self, key, data): + assert len(key) == SecretBox.KEY_SIZE + box = SecretBox(key) + nonce = utils.random(SecretBox.NONCE_SIZE) + return box.encrypt(data, nonce) + + def _decrypt_data(self, key, encrypted): + assert len(key) == SecretBox.KEY_SIZE + box = SecretBox(key) + data = box.decrypt(encrypted) + return data + def _get_key(self): # TODO: prevent multiple invocation @@ -222,7 +219,7 @@ class SymmetricWormhole: return defer.succeed(self.key) data = {"message": hexlify(self.msg1).decode("ascii")} d = self.post(self.url("post", "pake"), data) - d.addCallback(lambda j: self.get_msgs(j["messages"], "poll", "pake")) + d.addCallback(lambda j: self._get_msgs(j["messages"], "poll", "pake")) def _got_pake(msgs): pake_msg = unhexlify(msgs[0].encode("ascii")) key = self.sp.finish(pake_msg) @@ -233,25 +230,33 @@ class SymmetricWormhole: return d def get_verifier(self): + if self.code is None: raise UsageError d = self._get_key() d.addCallback(lambda _: self.verifier) return d def get_data(self, outbound_data): # only call this once + if self.code is None: raise UsageError d = self._get_key() - def _got_key(_): - outbound_key = self.derive_key(b"sender") - outbound_encrypted = self._encrypt_data(outbound_key, outbound_data) - data = {"message": hexlify(outbound_encrypted).decode("ascii")} - return self.post(self.url("post", "data"), data) - d.addCallback(lambda j: self.get_msgs(j["messages"], "poll", "data")) + d.addCallback(self._get_data2, outbound_data) + return d + + def _get_data2(self, key, outbound_data): + # Without predefined roles, we can't derive predictably unique keys + # for each side, so we use the same key for both. We use random + # nonces to keep the messages distinct, and check for reflection. + data_key = self.derive_key(b"data-key") + outbound_encrypted = self._encrypt_data(data_key, outbound_data) + data = {"message": hexlify(outbound_encrypted).decode("ascii")} + d = self.post(self.url("post", "data"), data) + d.addCallback(lambda j: self._get_msgs(j["messages"], "poll", "data")) def _got_data(msgs): inbound_encrypted = unhexlify(msgs[0].encode("ascii")) - inbound_key = self.derive_key(b"receiver") + if inbound_encrypted == outbound_encrypted: + raise ReflectionAttack try: - inbound_data = self._decrypt_data(inbound_key, - inbound_encrypted) + inbound_data = self._decrypt_data(data_key, inbound_encrypted) return inbound_data except CryptoError: raise WrongPasswordError @@ -259,25 +264,8 @@ class SymmetricWormhole: d.addBoth(self._deallocate) return d - - def serialize(self): - assert self.code is not None - data = { - "appid": self.appid, - "payload_for_them": self.payload_for_them.encode("hex"), - "relay": self.relay, - "code": self.code, - "wormhole": json.loads(self.sp.serialize()), - "msg1": self.msg1.encode("hex"), - } - return json.dumps(data) - - @classmethod - def from_serialized(klass, data): - d = json.loads(data) - self = klass(str(d["appid"]), d["payload_for_them"].decode("hex"), - str(d["relay"])) - self.code = str(d["code"]) - self.sp = SPAKE2_Symmetric.from_serialized(json.dumps(d["wormhole"])) - self.msg1 = d["msg1"].decode("hex") - return self + def _deallocate(self, res): + # only try once, no retries + d = self.agent.request("POST", self.url("deallocate")) + d.addBoth(lambda _: res) # ignore POST failure, pass-through result + return d From 6ee09f5316d6bfce517b62ff993c69cd371e05b4 Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Sat, 20 Jun 2015 19:03:10 -0700 Subject: [PATCH 6/6] add demo of twisted flow, update docs python -m wormhole.twisted.demo send-text TEXT -> CODE python -m wormhole.twisted.demo receive-text CODE -> TEXT --- docs/api.md | 72 ++++++++++++++++++++++-------------- src/wormhole/twisted/demo.py | 30 +++++++++++++++ 2 files changed, 74 insertions(+), 28 deletions(-) create mode 100644 src/wormhole/twisted/demo.py diff --git a/docs/api.md b/docs/api.md index 705aa7e..9adefce 100644 --- a/docs/api.md +++ b/docs/api.md @@ -50,39 +50,45 @@ theirdata = r.get_data(mydata) print("Their data: %s" % theirdata.decode("ascii")) ``` -## Twisted (TODO) +## Twisted -The Twisted-friendly flow, which is not yet implemented, may look like this: +The Twisted-friendly flow looks like this: ```python from twisted.internet import reactor -from wormhole.transcribe import TwistedInitiator -data = b"initiator's data" -ti = TwistedInitiator("appid", data, reactor) -ti.startService() -d1 = ti.when_get_code() -d1.addCallback(lambda code: print("Invitation Code: %s" % code)) -d2 = ti.when_get_data() -d2.addCallback(lambda theirdata: - print("Their data: %s" % theirdata.decode("ascii"))) -d2.addCallback(labmda _: reactor.stop()) +from wormhole.public_relay import RENDEZVOUS_RELAY +from wormhole.twisted.transcribe import SymmetricWormhole +outbound_message = b"outbound data" +w1 = SymmetricWormhole("appid", RENDEZVOUS_RELAY) +d = w1.get_code() +def _got_code(code): + print "Invitation Code:", code + return w1.get_data(outbound_message) +d.addCallback(_got_code) +def _got_data(inbound_message): + print "Inbound message:", inbound_message +d.addCallback(_got_data) +d.addBoth(lambda _: reactor.stop()) reactor.run() ``` +On the other side, you call `set_code()` instead of waiting for `get_code()`: + ```python -from twisted.internet import reactor -from wormhole.transcribe import TwistedReceiver -data = b"receiver's data" -code = sys.argv[1] -tr = TwistedReceiver("appid", code, data, reactor) -tr.startService() -d = tr.when_get_data() -d.addCallback(lambda theirdata: - print("Their data: %s" % theirdata.decode("ascii"))) -d.addCallback(lambda _: reactor.stop()) -reactor.run() +w2 = SymmetricWormhole("appid", RENDEZVOUS_RELAY) +w2.set_code(code) +d = w2.get_data(my_message) ``` +You can call `d=w.get_verifier()` before `get_data()`: this will perform the +first half of the PAKE negotiation, then fire the Deferred with a verifier +object (bytes) which can be converted into a printable representation and +manually compared. When the users are convinced that `get_verifier()` from +both sides are the same, call `d=get_data()` to continue the transfer. If you +call `get_data()` first, it will perform the complete transfer without +pausing. + + ## Application Identifier Applications using this library must provide an "application identifier", a @@ -135,17 +141,27 @@ sync will be abandoned, and all callbacks will errback with a TimeoutError. Both have defaults suitable for face-to-face realtime setup environments. -## Serialization (TODO) +## Serialization + +TODO: only the Twisted form supports serialization so far You may not be able to hold the Initiator/Receiver object in memory for the whole sync process: maybe you allow it to wait for several days, but the program will be restarted during that time. To support this, you can persist -the state of the object by calling `data = i.serialize()`, which will return +the state of the object by calling `data = w.serialize()`, which will return a printable bytestring (the JSON-encoding of a small dictionary). To restore, -call `Initiator.from_serialized(data)`. +use the `from_serialized(data)` classmethod (e.g. `w = +SymmetricWormhole.from_serialized(data)`). -Note that callbacks are not serialized: they must be restored after -deserialization. +There is exactly one point at which you can serialize the wormhole: *after* +establishing the invitation code, but before waiting for `get_verifier()` or +`get_data()`. If you are creating a new code, the correct time is during the +callback fired by `get_code()`. If you are accepting a pre-generated code, +the time is just after calling `set_code()`. + +To properly checkpoint the process, you should store the first message +(returned by `start()`) next to the serialized wormhole instance, so you can +re-send it if necessary. ## Detailed Example diff --git a/src/wormhole/twisted/demo.py b/src/wormhole/twisted/demo.py new file mode 100644 index 0000000..adc3b74 --- /dev/null +++ b/src/wormhole/twisted/demo.py @@ -0,0 +1,30 @@ +import sys +from twisted.internet import reactor +from .transcribe import SymmetricWormhole +from .. import public_relay + +APPID = "lothar.com/wormhole/text-xfer" + +w = SymmetricWormhole(APPID, public_relay.RENDEZVOUS_RELAY) + +if sys.argv[1] == "send-text": + message = sys.argv[2] + d = w.get_code() + def _got_code(code): + print "code is:", code + return w.get_data(message) + d.addCallback(_got_code) + def _got_data(their_data): + print "ack:", their_data + d.addCallback(_got_data) +elif sys.argv[1] == "receive-text": + code = sys.argv[2] + w.set_code(code) + d = w.get_data("ok") + def _got_data(their_data): + print their_data + d.addCallback(_got_data) +else: + raise ValueError("bad command") +d.addCallback(lambda _: reactor.stop()) +reactor.run()