diff --git a/src/wormhole/blocking/transcribe.py b/src/wormhole/blocking/transcribe.py index 5b5ef69..8fd84d3 100644 --- a/src/wormhole/blocking/transcribe.py +++ b/src/wormhole/blocking/transcribe.py @@ -8,21 +8,21 @@ from nacl import utils from .eventsource import EventSourceFollower from .. import __version__ from .. import codes -from ..errors import (ServerError, Timeout, WrongPasswordError, - ReflectionAttack, UsageError) +from ..errors import ServerError, Timeout, WrongPasswordError, UsageError from ..util.hkdf import HKDF SECOND = 1 MINUTE = 60*SECOND # relay URLs are: -# GET /list -> {channel-ids: [INT..]} -# POST /allocate/SIDE -> {channel-id: INT} -# these return all messages for CHANNEL-ID= and MSGNUM= but SIDE!= : -# POST /CHANNEL-ID/SIDE/post/MSGNUM {message: STR} -> {messages: [STR..]} -# POST /CHANNEL-ID/SIDE/poll/MSGNUM -> {messages: [STR..]} -# GET /CHANNEL-ID/SIDE/poll/MSGNUM (eventsource) -> STR, STR, .. -# POST /CHANNEL-ID/SIDE/deallocate -> waiting | deleted +# GET /list -> {channel-ids: [INT..]} +# POST /allocate {side: SIDE} -> {channel-id: INT} +# these return all messages (base64) for CID= : +# POST /CID {side:, phase:, body:} -> {messages: [{phase:, body:}..]} +# GET /CID (no-eventsource) -> {messages: [{phase:, body:}..]} +# GET /CID (eventsource) -> {phase:, body:}.. +# POST /CID/deallocate {side: SIDE} -> {status: waiting | deleted} +# all JSON responses include a "welcome:{..}" key class Wormhole: motd_displayed = False @@ -40,12 +40,9 @@ class Wormhole: self.code = None self.key = None self.verifier = None - - def _url(self, verb, msgnum=None): - url = "%s%d/%s/%s" % (self.relay, self.channel_id, self.side, verb) - if msgnum is not None: - url += "/" + msgnum - return url + self._channel_url = None + self._messages = set() # (phase,body) , body is bytes + self._sent_messages = set() # (phase,body) def handle_welcome(self, welcome): if ("motd" in welcome and @@ -69,18 +66,9 @@ class Wormhole: if "error" in welcome: raise ServerError(welcome["error"], self.relay) - def _post_json(self, url, post_json=None): - # POST to a URL, parsing the response as JSON. Optionally include a - # JSON request body. - data = None - if post_json: - data = json.dumps(post_json).encode("utf-8") - r = requests.post(url, data=data) - r.raise_for_status() - return r.json() - def _allocate_channel(self): - r = requests.post(self.relay + "allocate/%s" % self.side) + data = json.dumps({"side": self.side}).encode("utf-8") + r = requests.post(self.relay + "allocate", data=data) r.raise_for_status() data = r.json() if "welcome" in data: @@ -123,6 +111,7 @@ class Wormhole: if not mo: raise ValueError("code (%s) must start with NN-" % code) self.channel_id = int(mo.group(1)) + self._channel_url = "%s%d" % (self.relay, self.channel_id) self.code = code def _start(self): @@ -131,36 +120,62 @@ class Wormhole: idSymmetric=self.appid) self.msg1 = self.sp.start() - def _post_message(self, url, msg): + def _add_inbound_messages(self, messages): + for msg in messages: + phase = msg["phase"] + body = unhexlify(msg["body"].encode("ascii")) + self._messages.add( (phase, body) ) + + def _find_inbound_message(self, phase): + for (their_phase,body) in self._messages - self._sent_messages: + if their_phase == phase: + return body + return None + + def _send_message(self, phase, msg): # TODO: retry on failure, with exponential backoff. We're guarding # against the rendezvous server being temporarily offline. + if not isinstance(phase, type(u"")): raise UsageError(type(phase)) if not isinstance(msg, type(b"")): raise UsageError(type(msg)) - resp = self._post_json(url, {"message": hexlify(msg).decode("ascii")}) - return resp["messages"] # other_msgs + self._sent_messages.add( (phase,msg) ) + payload = {"side": self.side, + "phase": phase, + "body": hexlify(msg).decode("ascii")} + data = json.dumps(payload).encode("utf-8") + r = requests.post(self._channel_url, data=data) + r.raise_for_status() + resp = r.json() + self._add_inbound_messages(resp["messages"]) - def _get_message(self, old_msgs, verb, msgnum): + def _get_message(self, phase): + if not isinstance(phase, type(u"")): raise UsageError(type(phase)) # For now, server errors cause the client to fail. TODO: don't. This # will require changing the client to re-post messages when the # server comes back up. - # fire with a bytestring of the first message that matches - # verb/msgnum, which either came from old_msgs, or from an - # EventSource that we attached to the corresponding URL - msgs = old_msgs - while not msgs: + # fire with a bytestring of the first message for 'phase' that wasn't + # one of ours. It will either come from previously-received messages, + # or from an EventSource that we attach to the corresponding URL + body = self._find_inbound_message(phase) + while body is None: remaining = self.started + self.timeout - time.time() if remaining < 0: - raise Timeout - #time.sleep(self.wait) - f = EventSourceFollower(self._url(verb, msgnum), remaining) + return Timeout + f = EventSourceFollower(self._channel_url, remaining) + # we loop here until the connection is lost, or we see the + # message we want for (eventtype, data) in f.iter_events(): if eventtype == "welcome": self.handle_welcome(json.loads(data)) if eventtype == "message": - msgs = [json.loads(data)["message"]] - break - f.close() - return unhexlify(msgs[0].encode("ascii")) + self._add_inbound_messages([json.loads(data)]) + body = self._find_inbound_message(phase) + if body: + f.close() + break + if not body: + time.sleep(self.wait) + return body def derive_key(self, purpose, length=SecretBox.KEY_SIZE): if not isinstance(purpose, type(b"")): raise UsageError @@ -185,8 +200,8 @@ class Wormhole: def _get_key(self): if not self.key: - old_msgs = self._post_message(self._url("post", "pake"), self.msg1) - pake_msg = self._get_message(old_msgs, "poll", "pake") + self._send_message(u"pake", self.msg1) + pake_msg = self._get_message(u"pake") self.key = self.sp.finish(pake_msg) self.verifier = self.derive_key(self.appid+b":Verifier") @@ -214,11 +229,12 @@ class Wormhole: data_key = self.derive_key(b"data-key") outbound_encrypted = self._encrypt_data(data_key, outbound_data) - msgs = self._post_message(self._url("post", "data"), outbound_encrypted) + self._send_message(u"data", outbound_encrypted) - inbound_encrypted = self._get_message(msgs, "poll", "data") - if inbound_encrypted == outbound_encrypted: - raise ReflectionAttack + inbound_encrypted = self._get_message(u"data") + # _find_inbound_message() ignores any inbound message that matches + # something we previously sent out, so we don't need to explicitly + # check for reflection. A reflection attack will just not progress. try: inbound_data = self._decrypt_data(data_key, inbound_encrypted) return inbound_data @@ -227,5 +243,6 @@ class Wormhole: def _deallocate(self): # only try once, no retries - requests.post(self._url("deallocate")) + data = json.dumps({"side": self.side}).encode("utf-8") + requests.post(self._channel_url+"/deallocate", data=data) # ignore POST failure, don't call r.raise_for_status() diff --git a/src/wormhole/db-schemas/v1.sql b/src/wormhole/db-schemas/v1.sql index 43785de..106beeb 100644 --- a/src/wormhole/db-schemas/v1.sql +++ b/src/wormhole/db-schemas/v1.sql @@ -11,11 +11,11 @@ CREATE TABLE `messages` ( `channel_id` INTEGER, `side` VARCHAR, - `msgnum` VARCHAR, -- not numeric, more of a PAKE-phase indicator string - `message` VARCHAR, + `phase` VARCHAR, -- not numeric, more of a PAKE-phase indicator string + `body` VARCHAR, `when` INTEGER ); -CREATE INDEX `messages_idx` ON `messages` (`channel_id`, `side`, `msgnum`); +CREATE INDEX `messages_idx` ON `messages` (`channel_id`, `side`, `phase`); CREATE TABLE `allocations` ( diff --git a/src/wormhole/servers/relay.py b/src/wormhole/servers/relay.py index 02ad986..00a6605 100644 --- a/src/wormhole/servers/relay.py +++ b/src/wormhole/servers/relay.py @@ -51,112 +51,101 @@ class EventsProtocol: # note: no versions of IE (including the current IE11) support EventSource # relay URLs are: -# GET /list -> {channel-ids: [INT..]} -# POST /allocate/SIDE -> {channel-id: INT} -# these return all messages for CHANNEL-ID= and MSGNUM= but SIDE!= : -# POST /CHANNEL-ID/SIDE/post/MSGNUM {message: STR} -> {messages: [STR..]} -# POST /CHANNEL-ID/SIDE/poll/MSGNUM -> {messages: [STR..]} -# GET /CHANNEL-ID/SIDE/poll/MSGNUM (eventsource) -> STR, STR, .. -# POST /CHANNEL-ID/SIDE/deallocate -> waiting | deleted +# GET /list -> {channel-ids: [INT..]} +# POST /allocate {side: SIDE} -> {channel-id: INT} +# these return all messages (base64) for CID= : +# POST /CID {side:, phase:, body:} -> {messages: [{phase:, body:}..]} +# GET /CID (no-eventsource) -> {messages: [{phase:, body:}..]} +# GET /CID (eventsource) -> {phase:, body:}.. +# POST /CID/deallocate {side: SIDE} -> {status: waiting | deleted} +# all JSON responses include a "welcome:{..}" key class Channel(resource.Resource): - isLeaf = True # I handle /CHANNEL-ID/* - def __init__(self, channel_id, relay, db, welcome): resource.Resource.__init__(self) self.channel_id = channel_id self.relay = relay self.db = db self.welcome = welcome - self.event_channels = set() # (side, msgnum, ep) + self.event_channels = set() # ep + self.putChild(b"deallocate", Deallocator(self.channel_id, self.relay)) - def render_GET(self, request): - # rest of URL is: SIDE/poll/MSGNUM - their_side = request.postpath[0].decode("utf-8") - if request.postpath[1] != b"poll": - request.setResponseCode(http.BAD_REQUEST, b"GET to wrong URL") - return b"GET is only for /SIDE/poll/MSGNUM" - their_msgnum = request.postpath[2].decode("utf-8") - if b"text/event-stream" not in (request.getHeader(b"accept") or b""): - request.setResponseCode(http.BAD_REQUEST, b"Must use EventSource") - return b"Must use EventSource (Content-Type: text/event-stream)" - request.setHeader(b"content-type", b"text/event-stream; charset=utf-8") - ep = EventsProtocol(request) - ep.sendEvent(json.dumps(self.welcome), name="welcome") - handle = (their_side, their_msgnum, ep) - self.event_channels.add(handle) - request.notifyFinish().addErrback(self._shutdown, handle) + def get_messages(self, request): + request.setHeader(b"content-type", b"application/json; charset=utf-8") + messages = [] for row in self.db.execute("SELECT * FROM `messages`" " WHERE `channel_id`=?" " ORDER BY `when` ASC", (self.channel_id,)).fetchall(): - self.message_added(row["side"], row["msgnum"], row["message"], - channels=[handle]) + messages.append({"phase": row["phase"], "body": row["body"]}) + data = {"welcome": self.welcome, "messages": messages} + return (json.dumps(data)+"\n").encode("utf-8") + + def render_GET(self, request): + if b"text/event-stream" not in (request.getHeader(b"accept") or b""): + return self.get_messages(request) + request.setHeader(b"content-type", b"text/event-stream; charset=utf-8") + ep = EventsProtocol(request) + ep.sendEvent(json.dumps(self.welcome), name="welcome") + self.event_channels.add(ep) + request.notifyFinish().addErrback(lambda f: + self.event_channels.discard(ep)) + for row in self.db.execute("SELECT * FROM `messages`" + " WHERE `channel_id`=?" + " ORDER BY `when` ASC", + (self.channel_id,)).fetchall(): + data = json.dumps({"phase": row["phase"], "body": row["body"]}) + ep.sendEvent(data) return server.NOT_DONE_YET - def _shutdown(self, _, handle): - self.event_channels.discard(handle) - - def message_added(self, msg_side, msg_msgnum, msg_str, channels=None): - if channels is None: - channels = self.event_channels - for (their_side, their_msgnum, their_ep) in channels: - if msg_side != their_side and msg_msgnum == their_msgnum: - data = json.dumps({ "side": msg_side, "message": msg_str }) - their_ep.sendEvent(data) + def broadcast_message(self, phase, body): + data = json.dumps({"phase": phase, "body": body}) + for ep in self.event_channels: + ep.sendEvent(data) def render_POST(self, request): - # rest of URL is: SIDE/(MSGNUM|deallocate)/(post|poll) - side = request.postpath[0].decode("utf-8") - verb = request.postpath[1].decode("utf-8") + #data = json.load(request.content, encoding="utf-8") + content = request.content.read() + data = json.loads(content.decode("utf-8")) - if verb == "deallocate": - deleted = self.relay.maybe_free_child(self.channel_id, side) - if deleted: - return b"deleted\n" - return b"waiting\n" + side = data["side"] + phase = data["phase"] + if not isinstance(phase, type(u"")): + raise TypeError("phase must be string, not %s" % type(phase)) + body = data["body"] - if verb not in ("post", "poll"): - request.setResponseCode(http.BAD_REQUEST) - return b"bad verb, want 'post' or 'poll'\n" + self.db.execute("INSERT INTO `messages`" + " (`channel_id`, `side`, `phase`, `body`, `when`)" + " VALUES (?,?,?,?,?)", + (self.channel_id, side, phase, body, time.time())) + self.db.execute("INSERT INTO `allocations`" + " (`channel_id`, `side`)" + " VALUES (?,?)", + (self.channel_id, side)) + self.db.commit() + self.broadcast_message(phase, body) + return self.get_messages(request) - msgnum = request.postpath[2].decode("utf-8") +class Deallocator(resource.Resource): + def __init__(self, channel_id, relay): + self.channel_id = channel_id + self.relay = relay - other_messages = [] - for row in self.db.execute("SELECT `message` FROM `messages`" - " WHERE `channel_id`=? AND `side`!=?" - " AND `msgnum`=?" - " ORDER BY `when` ASC", - (self.channel_id, side, msgnum)).fetchall(): - other_messages.append(row["message"]) - - if verb == "post": - #data = json.load(request.content, encoding="utf-8") - content = request.content.read() - data = json.loads(content.decode("utf-8")) - self.db.execute("INSERT INTO `messages`" - " (`channel_id`, `side`, `msgnum`, `message`, `when`)" - " VALUES (?,?,?,?,?)", - (self.channel_id, side, msgnum, data["message"], - time.time())) - self.db.execute("INSERT INTO `allocations`" - " (`channel_id`, `side`)" - " VALUES (?,?)", - (self.channel_id, side)) - self.db.commit() - self.message_added(side, msgnum, data["message"]) - - request.setHeader(b"content-type", b"application/json; charset=utf-8") - data = {"welcome": self.welcome, - "messages": other_messages} - return (json.dumps(data)+"\n").encode("utf-8") + def render_POST(self, request): + content = request.content.read() + data = json.loads(content.decode("utf-8")) + side = data["side"] + deleted = self.relay.maybe_free_child(self.channel_id, side) + resp = {"status": "waiting"} + if deleted: + resp = {"status": "deleted"} + return json.dumps(resp).encode("utf-8") def get_allocated(db): c = db.execute("SELECT DISTINCT `channel_id` FROM `allocations`") return set([row["channel_id"] for row in c.fetchall()]) class Allocator(resource.Resource): - isLeaf = True def __init__(self, db, welcome): resource.Resource.__init__(self) self.db = db @@ -179,7 +168,11 @@ class Allocator(resource.Resource): raise ValueError("unable to find a free channel-id") def render_POST(self, request): - side = request.postpath[0] + content = request.content.read() + data = json.loads(content.decode("utf-8")) + side = data["side"] + if not isinstance(side, type(u"")): + raise TypeError("side must be string, not '%s'" % type(side)) channel_id = self.allocate_channel_id() self.db.execute("INSERT INTO `allocations` VALUES (?,?)", (channel_id, side)) diff --git a/src/wormhole/test/test_server.py b/src/wormhole/test/test_server.py index 9fd1071..d5653eb 100644 --- a/src/wormhole/test/test_server.py +++ b/src/wormhole/test/test_server.py @@ -1,10 +1,13 @@ -import sys +from __future__ import print_function +import sys, json import requests from twisted.trial import unittest -from twisted.internet import reactor +from twisted.internet import reactor, defer from twisted.internet.threads import deferToThread from twisted.web.client import getPage, Agent, readBody +from .. import __version__ from .common import ServerBase +from ..twisted.eventsource_twisted import EventSource class Reachable(ServerBase, unittest.TestCase): @@ -47,3 +50,243 @@ class Reachable(ServerBase, unittest.TestCase): self.failUnlessEqual(res, "Wormhole Relay\n") d.addCallback(_got) return d + +def unjson(data): + return json.loads(data.decode("utf-8")) + +class API(ServerBase, unittest.TestCase): + def get(self, path, is_json=True): + url = (self.relayurl+path).encode("ascii") + d = getPage(url) + if is_json: + d.addCallback(unjson) + return d + def post(self, path, data): + url = (self.relayurl+path).encode("ascii") + d = getPage(url, method=b"POST", + postdata=json.dumps(data).encode("utf-8")) + d.addCallback(unjson) + return d + + def check_welcome(self, data): + self.failUnlessIn("welcome", data) + self.failUnlessEqual(data["welcome"], {"current_version": __version__}) + + def test_allocate_1(self): + d = self.get("list") + def _check_list_1(data): + self.check_welcome(data) + self.failUnlessEqual(data["channel-ids"], []) + d.addCallback(_check_list_1) + + d.addCallback(lambda _: self.post("allocate", {"side": "abc"})) + def _allocated(data): + self.failUnlessEqual(set(data.keys()), + set(["welcome", "channel-id"])) + self.failUnlessIsInstance(data["channel-id"], int) + self.cid = data["channel-id"] + d.addCallback(_allocated) + + d.addCallback(lambda _: self.get("list")) + def _check_list_2(data): + self.failUnlessEqual(data["channel-ids"], [self.cid]) + d.addCallback(_check_list_2) + + d.addCallback(lambda _: self.post("%d/deallocate" % self.cid, + {"side": "abc"})) + def _check_deallocate(res): + self.failUnlessEqual(res["status"], "deleted") + d.addCallback(_check_deallocate) + + d.addCallback(lambda _: self.get("list")) + def _check_list_3(data): + self.failUnlessEqual(data["channel-ids"], []) + d.addCallback(_check_list_3) + + return d + + def test_allocate_2(self): + d = self.post("allocate", {"side": "abc"}) + def _allocated(data): + self.cid = data["channel-id"] + d.addCallback(_allocated) + + # second caller increases the number of known sides to 2 + d.addCallback(lambda _: self.post("%d" % self.cid, + {"side": "def", + "phase": "1", + "body": ""})) + + d.addCallback(lambda _: self.get("list")) + d.addCallback(lambda data: + self.failUnlessEqual(data["channel-ids"], [self.cid])) + + d.addCallback(lambda _: self.post("%d/deallocate" % self.cid, + {"side": "abc"})) + d.addCallback(lambda res: + self.failUnlessEqual(res["status"], "waiting")) + + d.addCallback(lambda _: self.post("%d/deallocate" % self.cid, + {"side": "NOT"})) + d.addCallback(lambda res: + self.failUnlessEqual(res["status"], "waiting")) + + d.addCallback(lambda _: self.post("%d/deallocate" % self.cid, + {"side": "def"})) + d.addCallback(lambda res: + self.failUnlessEqual(res["status"], "deleted")) + + d.addCallback(lambda _: self.get("list")) + d.addCallback(lambda data: + self.failUnlessEqual(data["channel-ids"], [])) + + return d + + def add_message(self, message, side="abc", phase="1"): + return self.post(str(self.cid), {"side": side, "phase": phase, + "body": message}) + + def parse_messages(self, messages): + out = set() + for m in messages: + self.failUnlessEqual(sorted(m.keys()), sorted(["phase", "body"])) + self.failUnlessIsInstance(m["phase"], type(u"")) + self.failUnlessIsInstance(m["body"], type(u"")) + out.add( (m["phase"], m["body"]) ) + return out + + def check_messages(self, one, two): + # Comparing lists-of-dicts is non-trivial in python3 because we can + # neither sort them (dicts are uncomparable), nor turn them into sets + # (dicts are unhashable). This is close enough. + self.failUnlessEqual(len(one), len(two), (one,two)) + for d in one: + self.failUnlessIn(d, two) + + def test_messages(self): + d = self.post("allocate", {"side": "abc"}) + def _allocated(data): + self.cid = data["channel-id"] + d.addCallback(_allocated) + + d.addCallback(lambda _: self.add_message("msg1A")) + def _check1(data): + self.check_welcome(data) + self.failUnlessEqual(data["messages"], + [{"phase": "1", "body": "msg1A"}]) + d.addCallback(_check1) + d.addCallback(lambda _: self.add_message("msg1B", side="def")) + def _check2(data): + self.check_welcome(data) + self.failUnlessEqual(self.parse_messages(data["messages"]), + set([("1", "msg1A"), + ("1", "msg1B")])) + d.addCallback(_check2) + + # adding a duplicate message is not an error, is ignored by clients + d.addCallback(lambda _: self.add_message("msg1B", side="def")) + def _check3(data): + self.check_welcome(data) + self.failUnlessEqual(self.parse_messages(data["messages"]), + set([("1", "msg1A"), + ("1", "msg1B")])) + d.addCallback(_check3) + + d.addCallback(lambda _: self.add_message("msg2A", side="abc", + phase="2")) + def _check4(data): + self.check_welcome(data) + self.failUnlessEqual(self.parse_messages(data["messages"]), + set([("1", "msg1A"), + ("1", "msg1B"), + ("2", "msg2A"), + ])) + d.addCallback(_check4) + + return d + + def test_eventsource(self): + if sys.version_info[0] >= 3: + raise unittest.SkipTest("twisted vs py3") + + d = self.post("allocate", {"side": "abc"}) + def _allocated(data): + self.cid = data["channel-id"] + url = (self.relayurl+str(self.cid)).encode("utf-8") + self.o = OneEventAtATime(url, parser=json.loads) + return self.o.wait_for_connection() + d.addCallback(_allocated) + d.addCallback(lambda _: self.o.wait_for_next_event()) + def _check_welcome(ev): + eventtype, data = ev + self.failUnlessEqual(eventtype, "welcome") + self.failUnlessEqual(data, {"current_version": __version__}) + d.addCallback(_check_welcome) + d.addCallback(lambda _: self.add_message("msg1A")) + d.addCallback(lambda _: self.o.wait_for_next_event()) + def _check_msg1(ev): + eventtype, data = ev + self.failUnlessEqual(eventtype, "message") + self.failUnlessEqual(data, {"phase": "1", "body": "msg1A"}) + d.addCallback(_check_msg1) + + d.addCallback(lambda _: self.add_message("msg1B")) + d.addCallback(lambda _: self.add_message("msg2A", phase="2")) + d.addCallback(lambda _: self.o.wait_for_next_event()) + def _check_msg2(ev): + eventtype, data = ev + self.failUnlessEqual(eventtype, "message") + self.failUnlessEqual(data, {"phase": "1", "body": "msg1B"}) + d.addCallback(_check_msg2) + d.addCallback(lambda _: self.o.wait_for_next_event()) + def _check_msg3(ev): + eventtype, data = ev + self.failUnlessEqual(eventtype, "message") + self.failUnlessEqual(data, {"phase": "2", "body": "msg2A"}) + d.addCallback(_check_msg3) + + d.addCallback(lambda _: self.o.close()) + d.addCallback(lambda _: self.o.wait_for_disconnection()) + return d + +class OneEventAtATime: + def __init__(self, url, parser=lambda e: e): + self.parser = parser + self.d = None + self.connected_d = defer.Deferred() + self.disconnected_d = defer.Deferred() + self.events = [] + self.es = EventSource(url, self.handler, when_connected=self.connected) + d = self.es.start() + d.addBoth(self.disconnected) + + def close(self): + self.es.cancel() + + def wait_for_next_event(self): + assert not self.d + if self.events: + event = self.events.pop(0) + return defer.succeed(event) + self.d = defer.Deferred() + return self.d + + def handler(self, eventtype, data): + event = (eventtype, self.parser(data)) + if self.d: + assert not self.events + d,self.d = self.d,None + d.callback(event) + return + self.events.append(event) + + def wait_for_connection(self): + return self.connected_d + def connected(self): + self.connected_d.callback(None) + + def wait_for_disconnection(self): + return self.disconnected_d + def disconnected(self, why): + self.disconnected_d.callback((why,)) + diff --git a/src/wormhole/twisted/eventsource_twisted.py b/src/wormhole/twisted/eventsource_twisted.py index 939b0ba..2d508a3 100644 --- a/src/wormhole/twisted/eventsource_twisted.py +++ b/src/wormhole/twisted/eventsource_twisted.py @@ -93,6 +93,7 @@ class EventSource: # TODO: service.Service def start(self): assert not self.started, "single-use" self.started = True + assert self.url d = self.agent.request("GET", self.url, Headers({"accept": ["text/event-stream"]})) d.addCallback(self._connected) @@ -154,14 +155,13 @@ class Connector: class ReconnectingEventSource(service.MultiService, protocol.ReconnectingClientFactory): - def __init__(self, baseurl, connection_starting, handler, agent=None): + def __init__(self, url, handler, agent=None): service.MultiService.__init__(self) # we don't use any of the basic Factory/ClientFactory methods of # this, just the ReconnectingClientFactory.retry, stopTrying, and # resetDelay methods. - self.baseurl = baseurl - self.connection_starting = connection_starting + self.url = url self.handler = handler self.agent = agent # IService provides self.running, toggled by {start,stop}Service. @@ -201,8 +201,7 @@ class ReconnectingEventSource(service.MultiService, if not (self.active and self.running): return self.continueTrying = True - url = self.connection_starting() - self.es = EventSource(url, self.handler, self.resetDelay, + self.es = EventSource(self.url, self.handler, self.resetDelay, agent=self.agent) d = self.es.start() d.addBoth(self._stopped) diff --git a/src/wormhole/twisted/transcribe.py b/src/wormhole/twisted/transcribe.py index b56c46b..a681560 100644 --- a/src/wormhole/twisted/transcribe.py +++ b/src/wormhole/twisted/transcribe.py @@ -13,8 +13,7 @@ from spake2 import SPAKE2_Symmetric from .eventsource_twisted import ReconnectingEventSource from .. import __version__ from .. import codes -from ..errors import (ServerError, WrongPasswordError, - ReflectionAttack, UsageError) +from ..errors import ServerError, WrongPasswordError, UsageError from ..util.hkdf import HKDF @implementer(IBodyProducer) @@ -46,12 +45,9 @@ class Wormhole: self.code = None self.key = None self._started_get_code = False - - def _url(self, verb, msgnum=None): - url = "%s%d/%s/%s" % (self.relay, self.channel_id, self.side, verb) - if msgnum is not None: - url += "/" + msgnum - return url + self._channel_url = None + self._messages = set() # (phase,body) , body is bytes + self._sent_messages = set() # (phase,body) def handle_welcome(self, welcome): if ("motd" in welcome and @@ -75,14 +71,10 @@ class Wormhole: if "error" in welcome: raise ServerError(welcome["error"], self.relay) - def _post_json(self, url, post_json=None): - # POST to a URL, parsing the response as JSON. Optionally include a - # JSON request body. - p = None - if post_json: - data = json.dumps(post_json).encode("utf-8") - p = DataProducer(data) - d = self.agent.request("POST", url, bodyProducer=p) + def _post_json(self, url, post_json): + # POST a JSON body to a URL, parsing the response as JSON + data = json.dumps(post_json).encode("utf-8") + d = self.agent.request("POST", url, bodyProducer=DataProducer(data)) def _check_error(resp): if resp.code != 200: raise web_error.Error(resp.code, resp.phrase) @@ -93,8 +85,8 @@ class Wormhole: return d def _allocate_channel(self): - url = self.relay + "allocate/%s" % self.side - d = self._post_json(url) + url = self.relay + "allocate" + d = self._post_json(url, {"side": self.side}) def _got_channel(data): if "welcome" in data: self.handle_welcome(data["welcome"]) @@ -131,6 +123,7 @@ class Wormhole: if not mo: raise ValueError("code (%s) must start with NN-" % code) self.channel_id = int(mo.group(1)) + self._channel_url = "%s%d" % (self.relay, self.channel_id) self.code = code def _start(self): @@ -164,36 +157,56 @@ class Wormhole: self.msg1 = d["msg1"].decode("hex") return self - def _post_message(self, url, msg): + def _add_inbound_messages(self, messages): + for msg in messages: + phase = msg["phase"] + body = unhexlify(msg["body"].encode("ascii")) + self._messages.add( (phase, body) ) + + def _find_inbound_message(self, phase): + for (their_phase,body) in self._messages - self._sent_messages: + if their_phase == phase: + return body + return None + + def _send_message(self, phase, msg): # TODO: retry on failure, with exponential backoff. We're guarding # against the rendezvous server being temporarily offline. + if not isinstance(phase, type(u"")): raise UsageError(type(phase)) if not isinstance(msg, type(b"")): raise UsageError(type(msg)) - d = self._post_json(url, {"message": hexlify(msg).decode("ascii")}) - d.addCallback(lambda resp: resp["messages"]) # other_msgs + self._sent_messages.add( (phase,msg) ) + payload = {"side": self.side, + "phase": phase, + "body": hexlify(msg).decode("ascii")} + d = self._post_json(self._channel_url, payload) + d.addCallback(lambda resp: self._add_inbound_messages(resp["messages"])) return d - def _get_message(self, old_msgs, verb, msgnum): - # fire with a bytestring of the first message that matches - # verb/msgnum, which either came from old_msgs, or from an - # EventSource that we attached to the corresponding URL - if old_msgs: - msg = unhexlify(old_msgs[0].encode("ascii")) - return defer.succeed(msg) + def _get_message(self, phase): + # fire with a bytestring of the first message for 'phase' that wasn't + # one of ours. It will either come from previously-received messages, + # or from an EventSource that we attach to the corresponding URL + body = self._find_inbound_message(phase) + if body is not None: + return defer.succeed(body) d = defer.Deferred() msgs = [] def _handle(name, data): if name == "welcome": self.handle_welcome(json.loads(data)) if name == "message": - msgs.append(json.loads(data)["message"]) - d.callback(None) - es = ReconnectingEventSource(None, lambda: self._url(verb, msgnum), - _handle)#, agent=self.agent) + self._add_inbound_messages([json.loads(data)]) + body = self._find_inbound_message(phase) + if body is not None and not msgs: + msgs.append(body) + d.callback(None) + # TODO: use agent=self.agent + es = ReconnectingEventSource(self._channel_url, _handle) es.startService() # TODO: .setServiceParent(self) es.activate() d.addCallback(lambda _: es.deactivate()) d.addCallback(lambda _: es.stopService()) - d.addCallback(lambda _: unhexlify(msgs[0].encode("ascii"))) + d.addCallback(lambda _: msgs[0]) return d def derive_key(self, purpose, length=SecretBox.KEY_SIZE): @@ -224,8 +237,8 @@ class Wormhole: # TODO: prevent multiple invocation if self.key: return defer.succeed(self.key) - d = self._post_message(self._url("post", "pake"), self.msg1) - d.addCallback(lambda msgs: self._get_message(msgs, "poll", "pake")) + d = self._send_message(u"pake", self.msg1) + d.addCallback(lambda _: self._get_message(u"pake")) def _got_pake(pake_msg): key = self.sp.finish(pake_msg) self.key = key @@ -256,12 +269,12 @@ class Wormhole: data_key = self.derive_key(b"data-key") outbound_encrypted = self._encrypt_data(data_key, outbound_data) - d = self._post_message(self._url("post", "data"), outbound_encrypted) + d = self._send_message(u"data", outbound_encrypted) - d.addCallback(lambda msgs: self._get_message(msgs, "poll", "data")) + d.addCallback(lambda _: self._get_message(u"data")) def _got_data(inbound_encrypted): - if inbound_encrypted == outbound_encrypted: - raise ReflectionAttack + #if inbound_encrypted == outbound_encrypted: + # raise ReflectionAttack try: inbound_data = self._decrypt_data(data_key, inbound_encrypted) return inbound_data @@ -272,6 +285,7 @@ class Wormhole: def _deallocate(self, res): # only try once, no retries - d = self.agent.request("POST", self._url("deallocate")) + d = self._post_json(self._channel_url+"/deallocate", + {"side": self.side}) d.addBoth(lambda _: res) # ignore POST failure, pass-through result return d