From 617bb03ad542aa661b081951ba4708b92d4cb4e6 Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Sat, 3 Oct 2015 12:36:14 -0700 Subject: [PATCH] rewrite server API This removes "side" and "msgnum" from the URLs, and puts them in a JSON request body instead. The server now maintains a simple set of messages for each channel-id, and isn't responsible for removing duplicates. The client now fetches all messages, and just ignores everything it sent itself. This removes the "reflection attack". Deallocate now returns JSON, for consistency. DB and API use "phase" and "body" instead of msgnum/message. This changes the DB schema, so delete the DB before upgrading the server. --- src/wormhole/blocking/transcribe.py | 115 +++++---- src/wormhole/db-schemas/v1.sql | 6 +- src/wormhole/servers/relay.py | 153 ++++++------ src/wormhole/test/test_server.py | 247 +++++++++++++++++++- src/wormhole/twisted/eventsource_twisted.py | 9 +- src/wormhole/twisted/transcribe.py | 94 ++++---- 6 files changed, 445 insertions(+), 179 deletions(-) diff --git a/src/wormhole/blocking/transcribe.py b/src/wormhole/blocking/transcribe.py index 5b5ef69..8fd84d3 100644 --- a/src/wormhole/blocking/transcribe.py +++ b/src/wormhole/blocking/transcribe.py @@ -8,21 +8,21 @@ from nacl import utils from .eventsource import EventSourceFollower from .. import __version__ from .. import codes -from ..errors import (ServerError, Timeout, WrongPasswordError, - ReflectionAttack, UsageError) +from ..errors import ServerError, Timeout, WrongPasswordError, UsageError from ..util.hkdf import HKDF SECOND = 1 MINUTE = 60*SECOND # relay URLs are: -# GET /list -> {channel-ids: [INT..]} -# POST /allocate/SIDE -> {channel-id: INT} -# these return all messages for CHANNEL-ID= and MSGNUM= but SIDE!= : -# POST /CHANNEL-ID/SIDE/post/MSGNUM {message: STR} -> {messages: [STR..]} -# POST /CHANNEL-ID/SIDE/poll/MSGNUM -> {messages: [STR..]} -# GET /CHANNEL-ID/SIDE/poll/MSGNUM (eventsource) -> STR, STR, .. 
-# POST /CHANNEL-ID/SIDE/deallocate -> waiting | deleted +# GET /list -> {channel-ids: [INT..]} +# POST /allocate {side: SIDE} -> {channel-id: INT} +# these return all messages (base64) for CID= : +# POST /CID {side:, phase:, body:} -> {messages: [{phase:, body:}..]} +# GET /CID (no-eventsource) -> {messages: [{phase:, body:}..]} +# GET /CID (eventsource) -> {phase:, body:}.. +# POST /CID/deallocate {side: SIDE} -> {status: waiting | deleted} +# all JSON responses include a "welcome:{..}" key class Wormhole: motd_displayed = False @@ -40,12 +40,9 @@ class Wormhole: self.code = None self.key = None self.verifier = None - - def _url(self, verb, msgnum=None): - url = "%s%d/%s/%s" % (self.relay, self.channel_id, self.side, verb) - if msgnum is not None: - url += "/" + msgnum - return url + self._channel_url = None + self._messages = set() # (phase,body) , body is bytes + self._sent_messages = set() # (phase,body) def handle_welcome(self, welcome): if ("motd" in welcome and @@ -69,18 +66,9 @@ class Wormhole: if "error" in welcome: raise ServerError(welcome["error"], self.relay) - def _post_json(self, url, post_json=None): - # POST to a URL, parsing the response as JSON. Optionally include a - # JSON request body. 
- data = None - if post_json: - data = json.dumps(post_json).encode("utf-8") - r = requests.post(url, data=data) - r.raise_for_status() - return r.json() - def _allocate_channel(self): - r = requests.post(self.relay + "allocate/%s" % self.side) + data = json.dumps({"side": self.side}).encode("utf-8") + r = requests.post(self.relay + "allocate", data=data) r.raise_for_status() data = r.json() if "welcome" in data: @@ -123,6 +111,7 @@ class Wormhole: if not mo: raise ValueError("code (%s) must start with NN-" % code) self.channel_id = int(mo.group(1)) + self._channel_url = "%s%d" % (self.relay, self.channel_id) self.code = code def _start(self): @@ -131,36 +120,62 @@ class Wormhole: idSymmetric=self.appid) self.msg1 = self.sp.start() - def _post_message(self, url, msg): + def _add_inbound_messages(self, messages): + for msg in messages: + phase = msg["phase"] + body = unhexlify(msg["body"].encode("ascii")) + self._messages.add( (phase, body) ) + + def _find_inbound_message(self, phase): + for (their_phase,body) in self._messages - self._sent_messages: + if their_phase == phase: + return body + return None + + def _send_message(self, phase, msg): # TODO: retry on failure, with exponential backoff. We're guarding # against the rendezvous server being temporarily offline. 
+ if not isinstance(phase, type(u"")): raise UsageError(type(phase)) if not isinstance(msg, type(b"")): raise UsageError(type(msg)) - resp = self._post_json(url, {"message": hexlify(msg).decode("ascii")}) - return resp["messages"] # other_msgs + self._sent_messages.add( (phase,msg) ) + payload = {"side": self.side, + "phase": phase, + "body": hexlify(msg).decode("ascii")} + data = json.dumps(payload).encode("utf-8") + r = requests.post(self._channel_url, data=data) + r.raise_for_status() + resp = r.json() + self._add_inbound_messages(resp["messages"]) - def _get_message(self, old_msgs, verb, msgnum): + def _get_message(self, phase): + if not isinstance(phase, type(u"")): raise UsageError(type(phase)) # For now, server errors cause the client to fail. TODO: don't. This # will require changing the client to re-post messages when the # server comes back up. - # fire with a bytestring of the first message that matches - # verb/msgnum, which either came from old_msgs, or from an - # EventSource that we attached to the corresponding URL - msgs = old_msgs - while not msgs: + # fire with a bytestring of the first message for 'phase' that wasn't + # one of ours. 
It will either come from previously-received messages, + # or from an EventSource that we attach to the corresponding URL + body = self._find_inbound_message(phase) + while body is None: remaining = self.started + self.timeout - time.time() if remaining < 0: - raise Timeout - #time.sleep(self.wait) - f = EventSourceFollower(self._url(verb, msgnum), remaining) + raise Timeout + f = EventSourceFollower(self._channel_url, remaining) + # we loop here until the connection is lost, or we see the + # message we want for (eventtype, data) in f.iter_events(): if eventtype == "welcome": self.handle_welcome(json.loads(data)) if eventtype == "message": - msgs = [json.loads(data)["message"]] - break - f.close() - return unhexlify(msgs[0].encode("ascii")) + self._add_inbound_messages([json.loads(data)]) + body = self._find_inbound_message(phase) + if body is not None: + f.close() + break + if body is None: + time.sleep(self.wait) + return body def derive_key(self, purpose, length=SecretBox.KEY_SIZE): if not isinstance(purpose, type(b"")): raise UsageError @@ -185,8 +200,8 @@ class Wormhole: def _get_key(self): if not self.key: - old_msgs = self._post_message(self._url("post", "pake"), self.msg1) - pake_msg = self._get_message(old_msgs, "poll", "pake") + self._send_message(u"pake", self.msg1) + pake_msg = self._get_message(u"pake") self.key = self.sp.finish(pake_msg) self.verifier = self.derive_key(self.appid+b":Verifier") @@ -214,11 +229,12 @@ class Wormhole: data_key = self.derive_key(b"data-key") outbound_encrypted = self._encrypt_data(data_key, outbound_data) - msgs = self._post_message(self._url("post", "data"), outbound_encrypted) + self._send_message(u"data", outbound_encrypted) - inbound_encrypted = self._get_message(msgs, "poll", "data") - if inbound_encrypted == outbound_encrypted: - raise ReflectionAttack + inbound_encrypted = self._get_message(u"data") + # _find_inbound_message() ignores any inbound message that matches + # something we previously sent out, so we don't need to 
explicitly + # check for reflection. A reflection attack will just not progress. try: inbound_data = self._decrypt_data(data_key, inbound_encrypted) return inbound_data @@ -227,5 +243,6 @@ class Wormhole: def _deallocate(self): # only try once, no retries - requests.post(self._url("deallocate")) + data = json.dumps({"side": self.side}).encode("utf-8") + requests.post(self._channel_url+"/deallocate", data=data) # ignore POST failure, don't call r.raise_for_status() diff --git a/src/wormhole/db-schemas/v1.sql b/src/wormhole/db-schemas/v1.sql index 43785de..106beeb 100644 --- a/src/wormhole/db-schemas/v1.sql +++ b/src/wormhole/db-schemas/v1.sql @@ -11,11 +11,11 @@ CREATE TABLE `messages` ( `channel_id` INTEGER, `side` VARCHAR, - `msgnum` VARCHAR, -- not numeric, more of a PAKE-phase indicator string - `message` VARCHAR, + `phase` VARCHAR, -- not numeric, more of a PAKE-phase indicator string + `body` VARCHAR, `when` INTEGER ); -CREATE INDEX `messages_idx` ON `messages` (`channel_id`, `side`, `msgnum`); +CREATE INDEX `messages_idx` ON `messages` (`channel_id`, `side`, `phase`); CREATE TABLE `allocations` ( diff --git a/src/wormhole/servers/relay.py b/src/wormhole/servers/relay.py index 02ad986..00a6605 100644 --- a/src/wormhole/servers/relay.py +++ b/src/wormhole/servers/relay.py @@ -51,112 +51,101 @@ class EventsProtocol: # note: no versions of IE (including the current IE11) support EventSource # relay URLs are: -# GET /list -> {channel-ids: [INT..]} -# POST /allocate/SIDE -> {channel-id: INT} -# these return all messages for CHANNEL-ID= and MSGNUM= but SIDE!= : -# POST /CHANNEL-ID/SIDE/post/MSGNUM {message: STR} -> {messages: [STR..]} -# POST /CHANNEL-ID/SIDE/poll/MSGNUM -> {messages: [STR..]} -# GET /CHANNEL-ID/SIDE/poll/MSGNUM (eventsource) -> STR, STR, .. 
-# POST /CHANNEL-ID/SIDE/deallocate -> waiting | deleted +# GET /list -> {channel-ids: [INT..]} +# POST /allocate {side: SIDE} -> {channel-id: INT} +# these return all messages (base64) for CID= : +# POST /CID {side:, phase:, body:} -> {messages: [{phase:, body:}..]} +# GET /CID (no-eventsource) -> {messages: [{phase:, body:}..]} +# GET /CID (eventsource) -> {phase:, body:}.. +# POST /CID/deallocate {side: SIDE} -> {status: waiting | deleted} +# all JSON responses include a "welcome:{..}" key class Channel(resource.Resource): - isLeaf = True # I handle /CHANNEL-ID/* - def __init__(self, channel_id, relay, db, welcome): resource.Resource.__init__(self) self.channel_id = channel_id self.relay = relay self.db = db self.welcome = welcome - self.event_channels = set() # (side, msgnum, ep) + self.event_channels = set() # ep + self.putChild(b"deallocate", Deallocator(self.channel_id, self.relay)) - def render_GET(self, request): - # rest of URL is: SIDE/poll/MSGNUM - their_side = request.postpath[0].decode("utf-8") - if request.postpath[1] != b"poll": - request.setResponseCode(http.BAD_REQUEST, b"GET to wrong URL") - return b"GET is only for /SIDE/poll/MSGNUM" - their_msgnum = request.postpath[2].decode("utf-8") - if b"text/event-stream" not in (request.getHeader(b"accept") or b""): - request.setResponseCode(http.BAD_REQUEST, b"Must use EventSource") - return b"Must use EventSource (Content-Type: text/event-stream)" - request.setHeader(b"content-type", b"text/event-stream; charset=utf-8") - ep = EventsProtocol(request) - ep.sendEvent(json.dumps(self.welcome), name="welcome") - handle = (their_side, their_msgnum, ep) - self.event_channels.add(handle) - request.notifyFinish().addErrback(self._shutdown, handle) + def get_messages(self, request): + request.setHeader(b"content-type", b"application/json; charset=utf-8") + messages = [] for row in self.db.execute("SELECT * FROM `messages`" " WHERE `channel_id`=?" 
" ORDER BY `when` ASC", (self.channel_id,)).fetchall(): - self.message_added(row["side"], row["msgnum"], row["message"], - channels=[handle]) + messages.append({"phase": row["phase"], "body": row["body"]}) + data = {"welcome": self.welcome, "messages": messages} + return (json.dumps(data)+"\n").encode("utf-8") + + def render_GET(self, request): + if b"text/event-stream" not in (request.getHeader(b"accept") or b""): + return self.get_messages(request) + request.setHeader(b"content-type", b"text/event-stream; charset=utf-8") + ep = EventsProtocol(request) + ep.sendEvent(json.dumps(self.welcome), name="welcome") + self.event_channels.add(ep) + request.notifyFinish().addErrback(lambda f: + self.event_channels.discard(ep)) + for row in self.db.execute("SELECT * FROM `messages`" + " WHERE `channel_id`=?" + " ORDER BY `when` ASC", + (self.channel_id,)).fetchall(): + data = json.dumps({"phase": row["phase"], "body": row["body"]}) + ep.sendEvent(data) return server.NOT_DONE_YET - def _shutdown(self, _, handle): - self.event_channels.discard(handle) - - def message_added(self, msg_side, msg_msgnum, msg_str, channels=None): - if channels is None: - channels = self.event_channels - for (their_side, their_msgnum, their_ep) in channels: - if msg_side != their_side and msg_msgnum == their_msgnum: - data = json.dumps({ "side": msg_side, "message": msg_str }) - their_ep.sendEvent(data) + def broadcast_message(self, phase, body): + data = json.dumps({"phase": phase, "body": body}) + for ep in self.event_channels: + ep.sendEvent(data) def render_POST(self, request): - # rest of URL is: SIDE/(MSGNUM|deallocate)/(post|poll) - side = request.postpath[0].decode("utf-8") - verb = request.postpath[1].decode("utf-8") + #data = json.load(request.content, encoding="utf-8") + content = request.content.read() + data = json.loads(content.decode("utf-8")) - if verb == "deallocate": - deleted = self.relay.maybe_free_child(self.channel_id, side) - if deleted: - return b"deleted\n" - return 
b"waiting\n" + side = data["side"] + phase = data["phase"] + if not isinstance(phase, type(u"")): + raise TypeError("phase must be string, not %s" % type(phase)) + body = data["body"] - if verb not in ("post", "poll"): - request.setResponseCode(http.BAD_REQUEST) - return b"bad verb, want 'post' or 'poll'\n" + self.db.execute("INSERT INTO `messages`" + " (`channel_id`, `side`, `phase`, `body`, `when`)" + " VALUES (?,?,?,?,?)", + (self.channel_id, side, phase, body, time.time())) + self.db.execute("INSERT INTO `allocations`" + " (`channel_id`, `side`)" + " VALUES (?,?)", + (self.channel_id, side)) + self.db.commit() + self.broadcast_message(phase, body) + return self.get_messages(request) - msgnum = request.postpath[2].decode("utf-8") +class Deallocator(resource.Resource): + def __init__(self, channel_id, relay): + self.channel_id = channel_id + self.relay = relay - other_messages = [] - for row in self.db.execute("SELECT `message` FROM `messages`" - " WHERE `channel_id`=? AND `side`!=?" - " AND `msgnum`=?" 
- " ORDER BY `when` ASC", - (self.channel_id, side, msgnum)).fetchall(): - other_messages.append(row["message"]) - - if verb == "post": - #data = json.load(request.content, encoding="utf-8") - content = request.content.read() - data = json.loads(content.decode("utf-8")) - self.db.execute("INSERT INTO `messages`" - " (`channel_id`, `side`, `msgnum`, `message`, `when`)" - " VALUES (?,?,?,?,?)", - (self.channel_id, side, msgnum, data["message"], - time.time())) - self.db.execute("INSERT INTO `allocations`" - " (`channel_id`, `side`)" - " VALUES (?,?)", - (self.channel_id, side)) - self.db.commit() - self.message_added(side, msgnum, data["message"]) - - request.setHeader(b"content-type", b"application/json; charset=utf-8") - data = {"welcome": self.welcome, - "messages": other_messages} - return (json.dumps(data)+"\n").encode("utf-8") + def render_POST(self, request): + content = request.content.read() + data = json.loads(content.decode("utf-8")) + side = data["side"] + deleted = self.relay.maybe_free_child(self.channel_id, side) + resp = {"status": "waiting"} + if deleted: + resp = {"status": "deleted"} + return json.dumps(resp).encode("utf-8") def get_allocated(db): c = db.execute("SELECT DISTINCT `channel_id` FROM `allocations`") return set([row["channel_id"] for row in c.fetchall()]) class Allocator(resource.Resource): - isLeaf = True def __init__(self, db, welcome): resource.Resource.__init__(self) self.db = db @@ -179,7 +168,11 @@ class Allocator(resource.Resource): raise ValueError("unable to find a free channel-id") def render_POST(self, request): - side = request.postpath[0] + content = request.content.read() + data = json.loads(content.decode("utf-8")) + side = data["side"] + if not isinstance(side, type(u"")): + raise TypeError("side must be string, not '%s'" % type(side)) channel_id = self.allocate_channel_id() self.db.execute("INSERT INTO `allocations` VALUES (?,?)", (channel_id, side)) diff --git a/src/wormhole/test/test_server.py 
b/src/wormhole/test/test_server.py index 9fd1071..d5653eb 100644 --- a/src/wormhole/test/test_server.py +++ b/src/wormhole/test/test_server.py @@ -1,10 +1,13 @@ -import sys +from __future__ import print_function +import sys, json import requests from twisted.trial import unittest -from twisted.internet import reactor +from twisted.internet import reactor, defer from twisted.internet.threads import deferToThread from twisted.web.client import getPage, Agent, readBody +from .. import __version__ from .common import ServerBase +from ..twisted.eventsource_twisted import EventSource class Reachable(ServerBase, unittest.TestCase): @@ -47,3 +50,243 @@ class Reachable(ServerBase, unittest.TestCase): self.failUnlessEqual(res, "Wormhole Relay\n") d.addCallback(_got) return d + +def unjson(data): + return json.loads(data.decode("utf-8")) + +class API(ServerBase, unittest.TestCase): + def get(self, path, is_json=True): + url = (self.relayurl+path).encode("ascii") + d = getPage(url) + if is_json: + d.addCallback(unjson) + return d + def post(self, path, data): + url = (self.relayurl+path).encode("ascii") + d = getPage(url, method=b"POST", + postdata=json.dumps(data).encode("utf-8")) + d.addCallback(unjson) + return d + + def check_welcome(self, data): + self.failUnlessIn("welcome", data) + self.failUnlessEqual(data["welcome"], {"current_version": __version__}) + + def test_allocate_1(self): + d = self.get("list") + def _check_list_1(data): + self.check_welcome(data) + self.failUnlessEqual(data["channel-ids"], []) + d.addCallback(_check_list_1) + + d.addCallback(lambda _: self.post("allocate", {"side": "abc"})) + def _allocated(data): + self.failUnlessEqual(set(data.keys()), + set(["welcome", "channel-id"])) + self.failUnlessIsInstance(data["channel-id"], int) + self.cid = data["channel-id"] + d.addCallback(_allocated) + + d.addCallback(lambda _: self.get("list")) + def _check_list_2(data): + self.failUnlessEqual(data["channel-ids"], [self.cid]) + d.addCallback(_check_list_2) + 
+ d.addCallback(lambda _: self.post("%d/deallocate" % self.cid, + {"side": "abc"})) + def _check_deallocate(res): + self.failUnlessEqual(res["status"], "deleted") + d.addCallback(_check_deallocate) + + d.addCallback(lambda _: self.get("list")) + def _check_list_3(data): + self.failUnlessEqual(data["channel-ids"], []) + d.addCallback(_check_list_3) + + return d + + def test_allocate_2(self): + d = self.post("allocate", {"side": "abc"}) + def _allocated(data): + self.cid = data["channel-id"] + d.addCallback(_allocated) + + # second caller increases the number of known sides to 2 + d.addCallback(lambda _: self.post("%d" % self.cid, + {"side": "def", + "phase": "1", + "body": ""})) + + d.addCallback(lambda _: self.get("list")) + d.addCallback(lambda data: + self.failUnlessEqual(data["channel-ids"], [self.cid])) + + d.addCallback(lambda _: self.post("%d/deallocate" % self.cid, + {"side": "abc"})) + d.addCallback(lambda res: + self.failUnlessEqual(res["status"], "waiting")) + + d.addCallback(lambda _: self.post("%d/deallocate" % self.cid, + {"side": "NOT"})) + d.addCallback(lambda res: + self.failUnlessEqual(res["status"], "waiting")) + + d.addCallback(lambda _: self.post("%d/deallocate" % self.cid, + {"side": "def"})) + d.addCallback(lambda res: + self.failUnlessEqual(res["status"], "deleted")) + + d.addCallback(lambda _: self.get("list")) + d.addCallback(lambda data: + self.failUnlessEqual(data["channel-ids"], [])) + + return d + + def add_message(self, message, side="abc", phase="1"): + return self.post(str(self.cid), {"side": side, "phase": phase, + "body": message}) + + def parse_messages(self, messages): + out = set() + for m in messages: + self.failUnlessEqual(sorted(m.keys()), sorted(["phase", "body"])) + self.failUnlessIsInstance(m["phase"], type(u"")) + self.failUnlessIsInstance(m["body"], type(u"")) + out.add( (m["phase"], m["body"]) ) + return out + + def check_messages(self, one, two): + # Comparing lists-of-dicts is non-trivial in python3 because we can + # 
neither sort them (dicts are uncomparable), nor turn them into sets + # (dicts are unhashable). This is close enough. + self.failUnlessEqual(len(one), len(two), (one,two)) + for d in one: + self.failUnlessIn(d, two) + + def test_messages(self): + d = self.post("allocate", {"side": "abc"}) + def _allocated(data): + self.cid = data["channel-id"] + d.addCallback(_allocated) + + d.addCallback(lambda _: self.add_message("msg1A")) + def _check1(data): + self.check_welcome(data) + self.failUnlessEqual(data["messages"], + [{"phase": "1", "body": "msg1A"}]) + d.addCallback(_check1) + d.addCallback(lambda _: self.add_message("msg1B", side="def")) + def _check2(data): + self.check_welcome(data) + self.failUnlessEqual(self.parse_messages(data["messages"]), + set([("1", "msg1A"), + ("1", "msg1B")])) + d.addCallback(_check2) + + # adding a duplicate message is not an error, is ignored by clients + d.addCallback(lambda _: self.add_message("msg1B", side="def")) + def _check3(data): + self.check_welcome(data) + self.failUnlessEqual(self.parse_messages(data["messages"]), + set([("1", "msg1A"), + ("1", "msg1B")])) + d.addCallback(_check3) + + d.addCallback(lambda _: self.add_message("msg2A", side="abc", + phase="2")) + def _check4(data): + self.check_welcome(data) + self.failUnlessEqual(self.parse_messages(data["messages"]), + set([("1", "msg1A"), + ("1", "msg1B"), + ("2", "msg2A"), + ])) + d.addCallback(_check4) + + return d + + def test_eventsource(self): + if sys.version_info[0] >= 3: + raise unittest.SkipTest("twisted vs py3") + + d = self.post("allocate", {"side": "abc"}) + def _allocated(data): + self.cid = data["channel-id"] + url = (self.relayurl+str(self.cid)).encode("utf-8") + self.o = OneEventAtATime(url, parser=json.loads) + return self.o.wait_for_connection() + d.addCallback(_allocated) + d.addCallback(lambda _: self.o.wait_for_next_event()) + def _check_welcome(ev): + eventtype, data = ev + self.failUnlessEqual(eventtype, "welcome") + self.failUnlessEqual(data, 
{"current_version": __version__}) + d.addCallback(_check_welcome) + d.addCallback(lambda _: self.add_message("msg1A")) + d.addCallback(lambda _: self.o.wait_for_next_event()) + def _check_msg1(ev): + eventtype, data = ev + self.failUnlessEqual(eventtype, "message") + self.failUnlessEqual(data, {"phase": "1", "body": "msg1A"}) + d.addCallback(_check_msg1) + + d.addCallback(lambda _: self.add_message("msg1B")) + d.addCallback(lambda _: self.add_message("msg2A", phase="2")) + d.addCallback(lambda _: self.o.wait_for_next_event()) + def _check_msg2(ev): + eventtype, data = ev + self.failUnlessEqual(eventtype, "message") + self.failUnlessEqual(data, {"phase": "1", "body": "msg1B"}) + d.addCallback(_check_msg2) + d.addCallback(lambda _: self.o.wait_for_next_event()) + def _check_msg3(ev): + eventtype, data = ev + self.failUnlessEqual(eventtype, "message") + self.failUnlessEqual(data, {"phase": "2", "body": "msg2A"}) + d.addCallback(_check_msg3) + + d.addCallback(lambda _: self.o.close()) + d.addCallback(lambda _: self.o.wait_for_disconnection()) + return d + +class OneEventAtATime: + def __init__(self, url, parser=lambda e: e): + self.parser = parser + self.d = None + self.connected_d = defer.Deferred() + self.disconnected_d = defer.Deferred() + self.events = [] + self.es = EventSource(url, self.handler, when_connected=self.connected) + d = self.es.start() + d.addBoth(self.disconnected) + + def close(self): + self.es.cancel() + + def wait_for_next_event(self): + assert not self.d + if self.events: + event = self.events.pop(0) + return defer.succeed(event) + self.d = defer.Deferred() + return self.d + + def handler(self, eventtype, data): + event = (eventtype, self.parser(data)) + if self.d: + assert not self.events + d,self.d = self.d,None + d.callback(event) + return + self.events.append(event) + + def wait_for_connection(self): + return self.connected_d + def connected(self): + self.connected_d.callback(None) + + def wait_for_disconnection(self): + return 
self.disconnected_d + def disconnected(self, why): + self.disconnected_d.callback((why,)) + diff --git a/src/wormhole/twisted/eventsource_twisted.py b/src/wormhole/twisted/eventsource_twisted.py index 939b0ba..2d508a3 100644 --- a/src/wormhole/twisted/eventsource_twisted.py +++ b/src/wormhole/twisted/eventsource_twisted.py @@ -93,6 +93,7 @@ class EventSource: # TODO: service.Service def start(self): assert not self.started, "single-use" self.started = True + assert self.url d = self.agent.request("GET", self.url, Headers({"accept": ["text/event-stream"]})) d.addCallback(self._connected) @@ -154,14 +155,13 @@ class Connector: class ReconnectingEventSource(service.MultiService, protocol.ReconnectingClientFactory): - def __init__(self, baseurl, connection_starting, handler, agent=None): + def __init__(self, url, handler, agent=None): service.MultiService.__init__(self) # we don't use any of the basic Factory/ClientFactory methods of # this, just the ReconnectingClientFactory.retry, stopTrying, and # resetDelay methods. - self.baseurl = baseurl - self.connection_starting = connection_starting + self.url = url self.handler = handler self.agent = agent # IService provides self.running, toggled by {start,stop}Service. @@ -201,8 +201,7 @@ class ReconnectingEventSource(service.MultiService, if not (self.active and self.running): return self.continueTrying = True - url = self.connection_starting() - self.es = EventSource(url, self.handler, self.resetDelay, + self.es = EventSource(self.url, self.handler, self.resetDelay, agent=self.agent) d = self.es.start() d.addBoth(self._stopped) diff --git a/src/wormhole/twisted/transcribe.py b/src/wormhole/twisted/transcribe.py index b56c46b..a681560 100644 --- a/src/wormhole/twisted/transcribe.py +++ b/src/wormhole/twisted/transcribe.py @@ -13,8 +13,7 @@ from spake2 import SPAKE2_Symmetric from .eventsource_twisted import ReconnectingEventSource from .. import __version__ from .. 
import codes -from ..errors import (ServerError, WrongPasswordError, - ReflectionAttack, UsageError) +from ..errors import ServerError, WrongPasswordError, UsageError from ..util.hkdf import HKDF @implementer(IBodyProducer) @@ -46,12 +45,9 @@ class Wormhole: self.code = None self.key = None self._started_get_code = False - - def _url(self, verb, msgnum=None): - url = "%s%d/%s/%s" % (self.relay, self.channel_id, self.side, verb) - if msgnum is not None: - url += "/" + msgnum - return url + self._channel_url = None + self._messages = set() # (phase,body) , body is bytes + self._sent_messages = set() # (phase,body) def handle_welcome(self, welcome): if ("motd" in welcome and @@ -75,14 +71,10 @@ class Wormhole: if "error" in welcome: raise ServerError(welcome["error"], self.relay) - def _post_json(self, url, post_json=None): - # POST to a URL, parsing the response as JSON. Optionally include a - # JSON request body. - p = None - if post_json: - data = json.dumps(post_json).encode("utf-8") - p = DataProducer(data) - d = self.agent.request("POST", url, bodyProducer=p) + def _post_json(self, url, post_json): + # POST a JSON body to a URL, parsing the response as JSON + data = json.dumps(post_json).encode("utf-8") + d = self.agent.request("POST", url, bodyProducer=DataProducer(data)) def _check_error(resp): if resp.code != 200: raise web_error.Error(resp.code, resp.phrase) @@ -93,8 +85,8 @@ class Wormhole: return d def _allocate_channel(self): - url = self.relay + "allocate/%s" % self.side - d = self._post_json(url) + url = self.relay + "allocate" + d = self._post_json(url, {"side": self.side}) def _got_channel(data): if "welcome" in data: self.handle_welcome(data["welcome"]) @@ -131,6 +123,7 @@ class Wormhole: if not mo: raise ValueError("code (%s) must start with NN-" % code) self.channel_id = int(mo.group(1)) + self._channel_url = "%s%d" % (self.relay, self.channel_id) self.code = code def _start(self): @@ -164,36 +157,56 @@ class Wormhole: self.msg1 = 
d["msg1"].decode("hex") return self - def _post_message(self, url, msg): + def _add_inbound_messages(self, messages): + for msg in messages: + phase = msg["phase"] + body = unhexlify(msg["body"].encode("ascii")) + self._messages.add( (phase, body) ) + + def _find_inbound_message(self, phase): + for (their_phase,body) in self._messages - self._sent_messages: + if their_phase == phase: + return body + return None + + def _send_message(self, phase, msg): # TODO: retry on failure, with exponential backoff. We're guarding # against the rendezvous server being temporarily offline. + if not isinstance(phase, type(u"")): raise UsageError(type(phase)) if not isinstance(msg, type(b"")): raise UsageError(type(msg)) - d = self._post_json(url, {"message": hexlify(msg).decode("ascii")}) - d.addCallback(lambda resp: resp["messages"]) # other_msgs + self._sent_messages.add( (phase,msg) ) + payload = {"side": self.side, + "phase": phase, + "body": hexlify(msg).decode("ascii")} + d = self._post_json(self._channel_url, payload) + d.addCallback(lambda resp: self._add_inbound_messages(resp["messages"])) return d - def _get_message(self, old_msgs, verb, msgnum): - # fire with a bytestring of the first message that matches - # verb/msgnum, which either came from old_msgs, or from an - # EventSource that we attached to the corresponding URL - if old_msgs: - msg = unhexlify(old_msgs[0].encode("ascii")) - return defer.succeed(msg) + def _get_message(self, phase): + # fire with a bytestring of the first message for 'phase' that wasn't + # one of ours. 
It will either come from previously-received messages, + # or from an EventSource that we attach to the corresponding URL + body = self._find_inbound_message(phase) + if body is not None: + return defer.succeed(body) d = defer.Deferred() msgs = [] def _handle(name, data): if name == "welcome": self.handle_welcome(json.loads(data)) if name == "message": - msgs.append(json.loads(data)["message"]) - d.callback(None) - es = ReconnectingEventSource(None, lambda: self._url(verb, msgnum), - _handle)#, agent=self.agent) + self._add_inbound_messages([json.loads(data)]) + body = self._find_inbound_message(phase) + if body is not None and not msgs: + msgs.append(body) + d.callback(None) + # TODO: use agent=self.agent + es = ReconnectingEventSource(self._channel_url, _handle) es.startService() # TODO: .setServiceParent(self) es.activate() d.addCallback(lambda _: es.deactivate()) d.addCallback(lambda _: es.stopService()) - d.addCallback(lambda _: unhexlify(msgs[0].encode("ascii"))) + d.addCallback(lambda _: msgs[0]) return d def derive_key(self, purpose, length=SecretBox.KEY_SIZE): @@ -224,8 +237,8 @@ class Wormhole: # TODO: prevent multiple invocation if self.key: return defer.succeed(self.key) - d = self._post_message(self._url("post", "pake"), self.msg1) - d.addCallback(lambda msgs: self._get_message(msgs, "poll", "pake")) + d = self._send_message(u"pake", self.msg1) + d.addCallback(lambda _: self._get_message(u"pake")) def _got_pake(pake_msg): key = self.sp.finish(pake_msg) self.key = key @@ -256,12 +269,12 @@ class Wormhole: data_key = self.derive_key(b"data-key") outbound_encrypted = self._encrypt_data(data_key, outbound_data) - d = self._post_message(self._url("post", "data"), outbound_encrypted) + d = self._send_message(u"data", outbound_encrypted) - d.addCallback(lambda msgs: self._get_message(msgs, "poll", "data")) + d.addCallback(lambda _: self._get_message(u"data")) def _got_data(inbound_encrypted): - if inbound_encrypted == outbound_encrypted: - raise 
ReflectionAttack + #if inbound_encrypted == outbound_encrypted: + # raise ReflectionAttack try: inbound_data = self._decrypt_data(data_key, inbound_encrypted) return inbound_data @@ -272,6 +285,7 @@ class Wormhole: def _deallocate(self, res): # only try once, no retries - d = self.agent.request("POST", self._url("deallocate")) + d = self._post_json(self._channel_url+"/deallocate", + {"side": self.side}) d.addBoth(lambda _: res) # ignore POST failure, pass-through result return d