rewrite server API

This removes "side" and "msgnum" from the URLs, and puts them in a JSON
request body instead. The server now maintains a simple set of messages
for each channel-id, and isn't responsible for removing duplicates.

The client now fetches all messages, and just ignores everything it sent
itself. This removes the "reflection attack".

Deallocate now returns JSON, for consistency. DB and API use "phase" and
"body" instead of msgnum/message.

This changes the DB schema, so delete the DB before upgrading the server.
This commit is contained in:
Brian Warner 2015-10-03 12:36:14 -07:00
parent bc3b0f03b9
commit 617bb03ad5
6 changed files with 445 additions and 179 deletions

View File

@ -8,8 +8,7 @@ from nacl import utils
from .eventsource import EventSourceFollower from .eventsource import EventSourceFollower
from .. import __version__ from .. import __version__
from .. import codes from .. import codes
from ..errors import (ServerError, Timeout, WrongPasswordError, from ..errors import ServerError, Timeout, WrongPasswordError, UsageError
ReflectionAttack, UsageError)
from ..util.hkdf import HKDF from ..util.hkdf import HKDF
SECOND = 1 SECOND = 1
@ -17,12 +16,13 @@ MINUTE = 60*SECOND
# relay URLs are: # relay URLs are:
# GET /list -> {channel-ids: [INT..]} # GET /list -> {channel-ids: [INT..]}
# POST /allocate/SIDE -> {channel-id: INT} # POST /allocate {side: SIDE} -> {channel-id: INT}
# these return all messages for CHANNEL-ID= and MSGNUM= but SIDE!= : # these return all messages (base64) for CID= :
# POST /CHANNEL-ID/SIDE/post/MSGNUM {message: STR} -> {messages: [STR..]} # POST /CID {side:, phase:, body:} -> {messages: [{phase:, body:}..]}
# POST /CHANNEL-ID/SIDE/poll/MSGNUM -> {messages: [STR..]} # GET /CID (no-eventsource) -> {messages: [{phase:, body:}..]}
# GET /CHANNEL-ID/SIDE/poll/MSGNUM (eventsource) -> STR, STR, .. # GET /CID (eventsource) -> {phase:, body:}..
# POST /CHANNEL-ID/SIDE/deallocate -> waiting | deleted # POST /CID/deallocate {side: SIDE} -> {status: waiting | deleted}
# all JSON responses include a "welcome:{..}" key
class Wormhole: class Wormhole:
motd_displayed = False motd_displayed = False
@ -40,12 +40,9 @@ class Wormhole:
self.code = None self.code = None
self.key = None self.key = None
self.verifier = None self.verifier = None
self._channel_url = None
def _url(self, verb, msgnum=None): self._messages = set() # (phase,body) , body is bytes
url = "%s%d/%s/%s" % (self.relay, self.channel_id, self.side, verb) self._sent_messages = set() # (phase,body)
if msgnum is not None:
url += "/" + msgnum
return url
def handle_welcome(self, welcome): def handle_welcome(self, welcome):
if ("motd" in welcome and if ("motd" in welcome and
@ -69,18 +66,9 @@ class Wormhole:
if "error" in welcome: if "error" in welcome:
raise ServerError(welcome["error"], self.relay) raise ServerError(welcome["error"], self.relay)
def _post_json(self, url, post_json=None):
# POST to a URL, parsing the response as JSON. Optionally include a
# JSON request body.
data = None
if post_json:
data = json.dumps(post_json).encode("utf-8")
r = requests.post(url, data=data)
r.raise_for_status()
return r.json()
def _allocate_channel(self): def _allocate_channel(self):
r = requests.post(self.relay + "allocate/%s" % self.side) data = json.dumps({"side": self.side}).encode("utf-8")
r = requests.post(self.relay + "allocate", data=data)
r.raise_for_status() r.raise_for_status()
data = r.json() data = r.json()
if "welcome" in data: if "welcome" in data:
@ -123,6 +111,7 @@ class Wormhole:
if not mo: if not mo:
raise ValueError("code (%s) must start with NN-" % code) raise ValueError("code (%s) must start with NN-" % code)
self.channel_id = int(mo.group(1)) self.channel_id = int(mo.group(1))
self._channel_url = "%s%d" % (self.relay, self.channel_id)
self.code = code self.code = code
def _start(self): def _start(self):
@ -131,36 +120,62 @@ class Wormhole:
idSymmetric=self.appid) idSymmetric=self.appid)
self.msg1 = self.sp.start() self.msg1 = self.sp.start()
def _post_message(self, url, msg): def _add_inbound_messages(self, messages):
for msg in messages:
phase = msg["phase"]
body = unhexlify(msg["body"].encode("ascii"))
self._messages.add( (phase, body) )
def _find_inbound_message(self, phase):
for (their_phase,body) in self._messages - self._sent_messages:
if their_phase == phase:
return body
return None
def _send_message(self, phase, msg):
# TODO: retry on failure, with exponential backoff. We're guarding # TODO: retry on failure, with exponential backoff. We're guarding
# against the rendezvous server being temporarily offline. # against the rendezvous server being temporarily offline.
if not isinstance(phase, type(u"")): raise UsageError(type(phase))
if not isinstance(msg, type(b"")): raise UsageError(type(msg)) if not isinstance(msg, type(b"")): raise UsageError(type(msg))
resp = self._post_json(url, {"message": hexlify(msg).decode("ascii")}) self._sent_messages.add( (phase,msg) )
return resp["messages"] # other_msgs payload = {"side": self.side,
"phase": phase,
"body": hexlify(msg).decode("ascii")}
data = json.dumps(payload).encode("utf-8")
r = requests.post(self._channel_url, data=data)
r.raise_for_status()
resp = r.json()
self._add_inbound_messages(resp["messages"])
def _get_message(self, old_msgs, verb, msgnum): def _get_message(self, phase):
if not isinstance(phase, type(u"")): raise UsageError(type(phase))
# For now, server errors cause the client to fail. TODO: don't. This # For now, server errors cause the client to fail. TODO: don't. This
# will require changing the client to re-post messages when the # will require changing the client to re-post messages when the
# server comes back up. # server comes back up.
# fire with a bytestring of the first message that matches # fire with a bytestring of the first message for 'phase' that wasn't
# verb/msgnum, which either came from old_msgs, or from an # one of ours. It will either come from previously-received messages,
# EventSource that we attached to the corresponding URL # or from an EventSource that we attach to the corresponding URL
msgs = old_msgs body = self._find_inbound_message(phase)
while not msgs: while body is None:
remaining = self.started + self.timeout - time.time() remaining = self.started + self.timeout - time.time()
if remaining < 0: if remaining < 0:
raise Timeout return Timeout
#time.sleep(self.wait) f = EventSourceFollower(self._channel_url, remaining)
f = EventSourceFollower(self._url(verb, msgnum), remaining) # we loop here until the connection is lost, or we see the
# message we want
for (eventtype, data) in f.iter_events(): for (eventtype, data) in f.iter_events():
if eventtype == "welcome": if eventtype == "welcome":
self.handle_welcome(json.loads(data)) self.handle_welcome(json.loads(data))
if eventtype == "message": if eventtype == "message":
msgs = [json.loads(data)["message"]] self._add_inbound_messages([json.loads(data)])
break body = self._find_inbound_message(phase)
if body:
f.close() f.close()
return unhexlify(msgs[0].encode("ascii")) break
if not body:
time.sleep(self.wait)
return body
def derive_key(self, purpose, length=SecretBox.KEY_SIZE): def derive_key(self, purpose, length=SecretBox.KEY_SIZE):
if not isinstance(purpose, type(b"")): raise UsageError if not isinstance(purpose, type(b"")): raise UsageError
@ -185,8 +200,8 @@ class Wormhole:
def _get_key(self): def _get_key(self):
if not self.key: if not self.key:
old_msgs = self._post_message(self._url("post", "pake"), self.msg1) self._send_message(u"pake", self.msg1)
pake_msg = self._get_message(old_msgs, "poll", "pake") pake_msg = self._get_message(u"pake")
self.key = self.sp.finish(pake_msg) self.key = self.sp.finish(pake_msg)
self.verifier = self.derive_key(self.appid+b":Verifier") self.verifier = self.derive_key(self.appid+b":Verifier")
@ -214,11 +229,12 @@ class Wormhole:
data_key = self.derive_key(b"data-key") data_key = self.derive_key(b"data-key")
outbound_encrypted = self._encrypt_data(data_key, outbound_data) outbound_encrypted = self._encrypt_data(data_key, outbound_data)
msgs = self._post_message(self._url("post", "data"), outbound_encrypted) self._send_message(u"data", outbound_encrypted)
inbound_encrypted = self._get_message(msgs, "poll", "data") inbound_encrypted = self._get_message(u"data")
if inbound_encrypted == outbound_encrypted: # _find_inbound_message() ignores any inbound message that matches
raise ReflectionAttack # something we previously sent out, so we don't need to explicitly
# check for reflection. A reflection attack will just not progress.
try: try:
inbound_data = self._decrypt_data(data_key, inbound_encrypted) inbound_data = self._decrypt_data(data_key, inbound_encrypted)
return inbound_data return inbound_data
@ -227,5 +243,6 @@ class Wormhole:
def _deallocate(self): def _deallocate(self):
# only try once, no retries # only try once, no retries
requests.post(self._url("deallocate")) data = json.dumps({"side": self.side}).encode("utf-8")
requests.post(self._channel_url+"/deallocate", data=data)
# ignore POST failure, don't call r.raise_for_status() # ignore POST failure, don't call r.raise_for_status()

View File

@ -11,11 +11,11 @@ CREATE TABLE `messages`
( (
`channel_id` INTEGER, `channel_id` INTEGER,
`side` VARCHAR, `side` VARCHAR,
`msgnum` VARCHAR, -- not numeric, more of a PAKE-phase indicator string `phase` VARCHAR, -- not numeric, more of a PAKE-phase indicator string
`message` VARCHAR, `body` VARCHAR,
`when` INTEGER `when` INTEGER
); );
CREATE INDEX `messages_idx` ON `messages` (`channel_id`, `side`, `msgnum`); CREATE INDEX `messages_idx` ON `messages` (`channel_id`, `side`, `phase`);
CREATE TABLE `allocations` CREATE TABLE `allocations`
( (

View File

@ -52,111 +52,100 @@ class EventsProtocol:
# relay URLs are: # relay URLs are:
# GET /list -> {channel-ids: [INT..]} # GET /list -> {channel-ids: [INT..]}
# POST /allocate/SIDE -> {channel-id: INT} # POST /allocate {side: SIDE} -> {channel-id: INT}
# these return all messages for CHANNEL-ID= and MSGNUM= but SIDE!= : # these return all messages (base64) for CID= :
# POST /CHANNEL-ID/SIDE/post/MSGNUM {message: STR} -> {messages: [STR..]} # POST /CID {side:, phase:, body:} -> {messages: [{phase:, body:}..]}
# POST /CHANNEL-ID/SIDE/poll/MSGNUM -> {messages: [STR..]} # GET /CID (no-eventsource) -> {messages: [{phase:, body:}..]}
# GET /CHANNEL-ID/SIDE/poll/MSGNUM (eventsource) -> STR, STR, .. # GET /CID (eventsource) -> {phase:, body:}..
# POST /CHANNEL-ID/SIDE/deallocate -> waiting | deleted # POST /CID/deallocate {side: SIDE} -> {status: waiting | deleted}
# all JSON responses include a "welcome:{..}" key
class Channel(resource.Resource): class Channel(resource.Resource):
isLeaf = True # I handle /CHANNEL-ID/*
def __init__(self, channel_id, relay, db, welcome): def __init__(self, channel_id, relay, db, welcome):
resource.Resource.__init__(self) resource.Resource.__init__(self)
self.channel_id = channel_id self.channel_id = channel_id
self.relay = relay self.relay = relay
self.db = db self.db = db
self.welcome = welcome self.welcome = welcome
self.event_channels = set() # (side, msgnum, ep) self.event_channels = set() # ep
self.putChild(b"deallocate", Deallocator(self.channel_id, self.relay))
def render_GET(self, request): def get_messages(self, request):
# rest of URL is: SIDE/poll/MSGNUM request.setHeader(b"content-type", b"application/json; charset=utf-8")
their_side = request.postpath[0].decode("utf-8") messages = []
if request.postpath[1] != b"poll":
request.setResponseCode(http.BAD_REQUEST, b"GET to wrong URL")
return b"GET is only for /SIDE/poll/MSGNUM"
their_msgnum = request.postpath[2].decode("utf-8")
if b"text/event-stream" not in (request.getHeader(b"accept") or b""):
request.setResponseCode(http.BAD_REQUEST, b"Must use EventSource")
return b"Must use EventSource (Content-Type: text/event-stream)"
request.setHeader(b"content-type", b"text/event-stream; charset=utf-8")
ep = EventsProtocol(request)
ep.sendEvent(json.dumps(self.welcome), name="welcome")
handle = (their_side, their_msgnum, ep)
self.event_channels.add(handle)
request.notifyFinish().addErrback(self._shutdown, handle)
for row in self.db.execute("SELECT * FROM `messages`" for row in self.db.execute("SELECT * FROM `messages`"
" WHERE `channel_id`=?" " WHERE `channel_id`=?"
" ORDER BY `when` ASC", " ORDER BY `when` ASC",
(self.channel_id,)).fetchall(): (self.channel_id,)).fetchall():
self.message_added(row["side"], row["msgnum"], row["message"], messages.append({"phase": row["phase"], "body": row["body"]})
channels=[handle]) data = {"welcome": self.welcome, "messages": messages}
return (json.dumps(data)+"\n").encode("utf-8")
def render_GET(self, request):
if b"text/event-stream" not in (request.getHeader(b"accept") or b""):
return self.get_messages(request)
request.setHeader(b"content-type", b"text/event-stream; charset=utf-8")
ep = EventsProtocol(request)
ep.sendEvent(json.dumps(self.welcome), name="welcome")
self.event_channels.add(ep)
request.notifyFinish().addErrback(lambda f:
self.event_channels.discard(ep))
for row in self.db.execute("SELECT * FROM `messages`"
" WHERE `channel_id`=?"
" ORDER BY `when` ASC",
(self.channel_id,)).fetchall():
data = json.dumps({"phase": row["phase"], "body": row["body"]})
ep.sendEvent(data)
return server.NOT_DONE_YET return server.NOT_DONE_YET
def _shutdown(self, _, handle): def broadcast_message(self, phase, body):
self.event_channels.discard(handle) data = json.dumps({"phase": phase, "body": body})
for ep in self.event_channels:
def message_added(self, msg_side, msg_msgnum, msg_str, channels=None): ep.sendEvent(data)
if channels is None:
channels = self.event_channels
for (their_side, their_msgnum, their_ep) in channels:
if msg_side != their_side and msg_msgnum == their_msgnum:
data = json.dumps({ "side": msg_side, "message": msg_str })
their_ep.sendEvent(data)
def render_POST(self, request): def render_POST(self, request):
# rest of URL is: SIDE/(MSGNUM|deallocate)/(post|poll)
side = request.postpath[0].decode("utf-8")
verb = request.postpath[1].decode("utf-8")
if verb == "deallocate":
deleted = self.relay.maybe_free_child(self.channel_id, side)
if deleted:
return b"deleted\n"
return b"waiting\n"
if verb not in ("post", "poll"):
request.setResponseCode(http.BAD_REQUEST)
return b"bad verb, want 'post' or 'poll'\n"
msgnum = request.postpath[2].decode("utf-8")
other_messages = []
for row in self.db.execute("SELECT `message` FROM `messages`"
" WHERE `channel_id`=? AND `side`!=?"
" AND `msgnum`=?"
" ORDER BY `when` ASC",
(self.channel_id, side, msgnum)).fetchall():
other_messages.append(row["message"])
if verb == "post":
#data = json.load(request.content, encoding="utf-8") #data = json.load(request.content, encoding="utf-8")
content = request.content.read() content = request.content.read()
data = json.loads(content.decode("utf-8")) data = json.loads(content.decode("utf-8"))
side = data["side"]
phase = data["phase"]
if not isinstance(phase, type(u"")):
raise TypeError("phase must be string, not %s" % type(phase))
body = data["body"]
self.db.execute("INSERT INTO `messages`" self.db.execute("INSERT INTO `messages`"
" (`channel_id`, `side`, `msgnum`, `message`, `when`)" " (`channel_id`, `side`, `phase`, `body`, `when`)"
" VALUES (?,?,?,?,?)", " VALUES (?,?,?,?,?)",
(self.channel_id, side, msgnum, data["message"], (self.channel_id, side, phase, body, time.time()))
time.time()))
self.db.execute("INSERT INTO `allocations`" self.db.execute("INSERT INTO `allocations`"
" (`channel_id`, `side`)" " (`channel_id`, `side`)"
" VALUES (?,?)", " VALUES (?,?)",
(self.channel_id, side)) (self.channel_id, side))
self.db.commit() self.db.commit()
self.message_added(side, msgnum, data["message"]) self.broadcast_message(phase, body)
return self.get_messages(request)
request.setHeader(b"content-type", b"application/json; charset=utf-8") class Deallocator(resource.Resource):
data = {"welcome": self.welcome, def __init__(self, channel_id, relay):
"messages": other_messages} self.channel_id = channel_id
return (json.dumps(data)+"\n").encode("utf-8") self.relay = relay
def render_POST(self, request):
content = request.content.read()
data = json.loads(content.decode("utf-8"))
side = data["side"]
deleted = self.relay.maybe_free_child(self.channel_id, side)
resp = {"status": "waiting"}
if deleted:
resp = {"status": "deleted"}
return json.dumps(resp).encode("utf-8")
def get_allocated(db): def get_allocated(db):
c = db.execute("SELECT DISTINCT `channel_id` FROM `allocations`") c = db.execute("SELECT DISTINCT `channel_id` FROM `allocations`")
return set([row["channel_id"] for row in c.fetchall()]) return set([row["channel_id"] for row in c.fetchall()])
class Allocator(resource.Resource): class Allocator(resource.Resource):
isLeaf = True
def __init__(self, db, welcome): def __init__(self, db, welcome):
resource.Resource.__init__(self) resource.Resource.__init__(self)
self.db = db self.db = db
@ -179,7 +168,11 @@ class Allocator(resource.Resource):
raise ValueError("unable to find a free channel-id") raise ValueError("unable to find a free channel-id")
def render_POST(self, request): def render_POST(self, request):
side = request.postpath[0] content = request.content.read()
data = json.loads(content.decode("utf-8"))
side = data["side"]
if not isinstance(side, type(u"")):
raise TypeError("side must be string, not '%s'" % type(side))
channel_id = self.allocate_channel_id() channel_id = self.allocate_channel_id()
self.db.execute("INSERT INTO `allocations` VALUES (?,?)", self.db.execute("INSERT INTO `allocations` VALUES (?,?)",
(channel_id, side)) (channel_id, side))

View File

@ -1,10 +1,13 @@
import sys from __future__ import print_function
import sys, json
import requests import requests
from twisted.trial import unittest from twisted.trial import unittest
from twisted.internet import reactor from twisted.internet import reactor, defer
from twisted.internet.threads import deferToThread from twisted.internet.threads import deferToThread
from twisted.web.client import getPage, Agent, readBody from twisted.web.client import getPage, Agent, readBody
from .. import __version__
from .common import ServerBase from .common import ServerBase
from ..twisted.eventsource_twisted import EventSource
class Reachable(ServerBase, unittest.TestCase): class Reachable(ServerBase, unittest.TestCase):
@ -47,3 +50,243 @@ class Reachable(ServerBase, unittest.TestCase):
self.failUnlessEqual(res, "Wormhole Relay\n") self.failUnlessEqual(res, "Wormhole Relay\n")
d.addCallback(_got) d.addCallback(_got)
return d return d
def unjson(data):
return json.loads(data.decode("utf-8"))
class API(ServerBase, unittest.TestCase):
def get(self, path, is_json=True):
url = (self.relayurl+path).encode("ascii")
d = getPage(url)
if is_json:
d.addCallback(unjson)
return d
def post(self, path, data):
url = (self.relayurl+path).encode("ascii")
d = getPage(url, method=b"POST",
postdata=json.dumps(data).encode("utf-8"))
d.addCallback(unjson)
return d
def check_welcome(self, data):
self.failUnlessIn("welcome", data)
self.failUnlessEqual(data["welcome"], {"current_version": __version__})
def test_allocate_1(self):
d = self.get("list")
def _check_list_1(data):
self.check_welcome(data)
self.failUnlessEqual(data["channel-ids"], [])
d.addCallback(_check_list_1)
d.addCallback(lambda _: self.post("allocate", {"side": "abc"}))
def _allocated(data):
self.failUnlessEqual(set(data.keys()),
set(["welcome", "channel-id"]))
self.failUnlessIsInstance(data["channel-id"], int)
self.cid = data["channel-id"]
d.addCallback(_allocated)
d.addCallback(lambda _: self.get("list"))
def _check_list_2(data):
self.failUnlessEqual(data["channel-ids"], [self.cid])
d.addCallback(_check_list_2)
d.addCallback(lambda _: self.post("%d/deallocate" % self.cid,
{"side": "abc"}))
def _check_deallocate(res):
self.failUnlessEqual(res["status"], "deleted")
d.addCallback(_check_deallocate)
d.addCallback(lambda _: self.get("list"))
def _check_list_3(data):
self.failUnlessEqual(data["channel-ids"], [])
d.addCallback(_check_list_3)
return d
def test_allocate_2(self):
d = self.post("allocate", {"side": "abc"})
def _allocated(data):
self.cid = data["channel-id"]
d.addCallback(_allocated)
# second caller increases the number of known sides to 2
d.addCallback(lambda _: self.post("%d" % self.cid,
{"side": "def",
"phase": "1",
"body": ""}))
d.addCallback(lambda _: self.get("list"))
d.addCallback(lambda data:
self.failUnlessEqual(data["channel-ids"], [self.cid]))
d.addCallback(lambda _: self.post("%d/deallocate" % self.cid,
{"side": "abc"}))
d.addCallback(lambda res:
self.failUnlessEqual(res["status"], "waiting"))
d.addCallback(lambda _: self.post("%d/deallocate" % self.cid,
{"side": "NOT"}))
d.addCallback(lambda res:
self.failUnlessEqual(res["status"], "waiting"))
d.addCallback(lambda _: self.post("%d/deallocate" % self.cid,
{"side": "def"}))
d.addCallback(lambda res:
self.failUnlessEqual(res["status"], "deleted"))
d.addCallback(lambda _: self.get("list"))
d.addCallback(lambda data:
self.failUnlessEqual(data["channel-ids"], []))
return d
def add_message(self, message, side="abc", phase="1"):
return self.post(str(self.cid), {"side": side, "phase": phase,
"body": message})
def parse_messages(self, messages):
out = set()
for m in messages:
self.failUnlessEqual(sorted(m.keys()), sorted(["phase", "body"]))
self.failUnlessIsInstance(m["phase"], type(u""))
self.failUnlessIsInstance(m["body"], type(u""))
out.add( (m["phase"], m["body"]) )
return out
def check_messages(self, one, two):
# Comparing lists-of-dicts is non-trivial in python3 because we can
# neither sort them (dicts are uncomparable), nor turn them into sets
# (dicts are unhashable). This is close enough.
self.failUnlessEqual(len(one), len(two), (one,two))
for d in one:
self.failUnlessIn(d, two)
def test_messages(self):
d = self.post("allocate", {"side": "abc"})
def _allocated(data):
self.cid = data["channel-id"]
d.addCallback(_allocated)
d.addCallback(lambda _: self.add_message("msg1A"))
def _check1(data):
self.check_welcome(data)
self.failUnlessEqual(data["messages"],
[{"phase": "1", "body": "msg1A"}])
d.addCallback(_check1)
d.addCallback(lambda _: self.add_message("msg1B", side="def"))
def _check2(data):
self.check_welcome(data)
self.failUnlessEqual(self.parse_messages(data["messages"]),
set([("1", "msg1A"),
("1", "msg1B")]))
d.addCallback(_check2)
# adding a duplicate message is not an error, is ignored by clients
d.addCallback(lambda _: self.add_message("msg1B", side="def"))
def _check3(data):
self.check_welcome(data)
self.failUnlessEqual(self.parse_messages(data["messages"]),
set([("1", "msg1A"),
("1", "msg1B")]))
d.addCallback(_check3)
d.addCallback(lambda _: self.add_message("msg2A", side="abc",
phase="2"))
def _check4(data):
self.check_welcome(data)
self.failUnlessEqual(self.parse_messages(data["messages"]),
set([("1", "msg1A"),
("1", "msg1B"),
("2", "msg2A"),
]))
d.addCallback(_check4)
return d
def test_eventsource(self):
if sys.version_info[0] >= 3:
raise unittest.SkipTest("twisted vs py3")
d = self.post("allocate", {"side": "abc"})
def _allocated(data):
self.cid = data["channel-id"]
url = (self.relayurl+str(self.cid)).encode("utf-8")
self.o = OneEventAtATime(url, parser=json.loads)
return self.o.wait_for_connection()
d.addCallback(_allocated)
d.addCallback(lambda _: self.o.wait_for_next_event())
def _check_welcome(ev):
eventtype, data = ev
self.failUnlessEqual(eventtype, "welcome")
self.failUnlessEqual(data, {"current_version": __version__})
d.addCallback(_check_welcome)
d.addCallback(lambda _: self.add_message("msg1A"))
d.addCallback(lambda _: self.o.wait_for_next_event())
def _check_msg1(ev):
eventtype, data = ev
self.failUnlessEqual(eventtype, "message")
self.failUnlessEqual(data, {"phase": "1", "body": "msg1A"})
d.addCallback(_check_msg1)
d.addCallback(lambda _: self.add_message("msg1B"))
d.addCallback(lambda _: self.add_message("msg2A", phase="2"))
d.addCallback(lambda _: self.o.wait_for_next_event())
def _check_msg2(ev):
eventtype, data = ev
self.failUnlessEqual(eventtype, "message")
self.failUnlessEqual(data, {"phase": "1", "body": "msg1B"})
d.addCallback(_check_msg2)
d.addCallback(lambda _: self.o.wait_for_next_event())
def _check_msg3(ev):
eventtype, data = ev
self.failUnlessEqual(eventtype, "message")
self.failUnlessEqual(data, {"phase": "2", "body": "msg2A"})
d.addCallback(_check_msg3)
d.addCallback(lambda _: self.o.close())
d.addCallback(lambda _: self.o.wait_for_disconnection())
return d
class OneEventAtATime:
def __init__(self, url, parser=lambda e: e):
self.parser = parser
self.d = None
self.connected_d = defer.Deferred()
self.disconnected_d = defer.Deferred()
self.events = []
self.es = EventSource(url, self.handler, when_connected=self.connected)
d = self.es.start()
d.addBoth(self.disconnected)
def close(self):
self.es.cancel()
def wait_for_next_event(self):
assert not self.d
if self.events:
event = self.events.pop(0)
return defer.succeed(event)
self.d = defer.Deferred()
return self.d
def handler(self, eventtype, data):
event = (eventtype, self.parser(data))
if self.d:
assert not self.events
d,self.d = self.d,None
d.callback(event)
return
self.events.append(event)
def wait_for_connection(self):
return self.connected_d
def connected(self):
self.connected_d.callback(None)
def wait_for_disconnection(self):
return self.disconnected_d
def disconnected(self, why):
self.disconnected_d.callback((why,))

View File

@ -93,6 +93,7 @@ class EventSource: # TODO: service.Service
def start(self): def start(self):
assert not self.started, "single-use" assert not self.started, "single-use"
self.started = True self.started = True
assert self.url
d = self.agent.request("GET", self.url, d = self.agent.request("GET", self.url,
Headers({"accept": ["text/event-stream"]})) Headers({"accept": ["text/event-stream"]}))
d.addCallback(self._connected) d.addCallback(self._connected)
@ -154,14 +155,13 @@ class Connector:
class ReconnectingEventSource(service.MultiService, class ReconnectingEventSource(service.MultiService,
protocol.ReconnectingClientFactory): protocol.ReconnectingClientFactory):
def __init__(self, baseurl, connection_starting, handler, agent=None): def __init__(self, url, handler, agent=None):
service.MultiService.__init__(self) service.MultiService.__init__(self)
# we don't use any of the basic Factory/ClientFactory methods of # we don't use any of the basic Factory/ClientFactory methods of
# this, just the ReconnectingClientFactory.retry, stopTrying, and # this, just the ReconnectingClientFactory.retry, stopTrying, and
# resetDelay methods. # resetDelay methods.
self.baseurl = baseurl self.url = url
self.connection_starting = connection_starting
self.handler = handler self.handler = handler
self.agent = agent self.agent = agent
# IService provides self.running, toggled by {start,stop}Service. # IService provides self.running, toggled by {start,stop}Service.
@ -201,8 +201,7 @@ class ReconnectingEventSource(service.MultiService,
if not (self.active and self.running): if not (self.active and self.running):
return return
self.continueTrying = True self.continueTrying = True
url = self.connection_starting() self.es = EventSource(self.url, self.handler, self.resetDelay,
self.es = EventSource(url, self.handler, self.resetDelay,
agent=self.agent) agent=self.agent)
d = self.es.start() d = self.es.start()
d.addBoth(self._stopped) d.addBoth(self._stopped)

View File

@ -13,8 +13,7 @@ from spake2 import SPAKE2_Symmetric
from .eventsource_twisted import ReconnectingEventSource from .eventsource_twisted import ReconnectingEventSource
from .. import __version__ from .. import __version__
from .. import codes from .. import codes
from ..errors import (ServerError, WrongPasswordError, from ..errors import ServerError, WrongPasswordError, UsageError
ReflectionAttack, UsageError)
from ..util.hkdf import HKDF from ..util.hkdf import HKDF
@implementer(IBodyProducer) @implementer(IBodyProducer)
@ -46,12 +45,9 @@ class Wormhole:
self.code = None self.code = None
self.key = None self.key = None
self._started_get_code = False self._started_get_code = False
self._channel_url = None
def _url(self, verb, msgnum=None): self._messages = set() # (phase,body) , body is bytes
url = "%s%d/%s/%s" % (self.relay, self.channel_id, self.side, verb) self._sent_messages = set() # (phase,body)
if msgnum is not None:
url += "/" + msgnum
return url
def handle_welcome(self, welcome): def handle_welcome(self, welcome):
if ("motd" in welcome and if ("motd" in welcome and
@ -75,14 +71,10 @@ class Wormhole:
if "error" in welcome: if "error" in welcome:
raise ServerError(welcome["error"], self.relay) raise ServerError(welcome["error"], self.relay)
def _post_json(self, url, post_json=None): def _post_json(self, url, post_json):
# POST to a URL, parsing the response as JSON. Optionally include a # POST a JSON body to a URL, parsing the response as JSON
# JSON request body.
p = None
if post_json:
data = json.dumps(post_json).encode("utf-8") data = json.dumps(post_json).encode("utf-8")
p = DataProducer(data) d = self.agent.request("POST", url, bodyProducer=DataProducer(data))
d = self.agent.request("POST", url, bodyProducer=p)
def _check_error(resp): def _check_error(resp):
if resp.code != 200: if resp.code != 200:
raise web_error.Error(resp.code, resp.phrase) raise web_error.Error(resp.code, resp.phrase)
@ -93,8 +85,8 @@ class Wormhole:
return d return d
def _allocate_channel(self): def _allocate_channel(self):
url = self.relay + "allocate/%s" % self.side url = self.relay + "allocate"
d = self._post_json(url) d = self._post_json(url, {"side": self.side})
def _got_channel(data): def _got_channel(data):
if "welcome" in data: if "welcome" in data:
self.handle_welcome(data["welcome"]) self.handle_welcome(data["welcome"])
@ -131,6 +123,7 @@ class Wormhole:
if not mo: if not mo:
raise ValueError("code (%s) must start with NN-" % code) raise ValueError("code (%s) must start with NN-" % code)
self.channel_id = int(mo.group(1)) self.channel_id = int(mo.group(1))
self._channel_url = "%s%d" % (self.relay, self.channel_id)
self.code = code self.code = code
def _start(self): def _start(self):
@ -164,36 +157,56 @@ class Wormhole:
self.msg1 = d["msg1"].decode("hex") self.msg1 = d["msg1"].decode("hex")
return self return self
def _post_message(self, url, msg): def _add_inbound_messages(self, messages):
for msg in messages:
phase = msg["phase"]
body = unhexlify(msg["body"].encode("ascii"))
self._messages.add( (phase, body) )
def _find_inbound_message(self, phase):
for (their_phase,body) in self._messages - self._sent_messages:
if their_phase == phase:
return body
return None
def _send_message(self, phase, msg):
# TODO: retry on failure, with exponential backoff. We're guarding # TODO: retry on failure, with exponential backoff. We're guarding
# against the rendezvous server being temporarily offline. # against the rendezvous server being temporarily offline.
if not isinstance(phase, type(u"")): raise UsageError(type(phase))
if not isinstance(msg, type(b"")): raise UsageError(type(msg)) if not isinstance(msg, type(b"")): raise UsageError(type(msg))
d = self._post_json(url, {"message": hexlify(msg).decode("ascii")}) self._sent_messages.add( (phase,msg) )
d.addCallback(lambda resp: resp["messages"]) # other_msgs payload = {"side": self.side,
"phase": phase,
"body": hexlify(msg).decode("ascii")}
d = self._post_json(self._channel_url, payload)
d.addCallback(lambda resp: self._add_inbound_messages(resp["messages"]))
return d return d
def _get_message(self, old_msgs, verb, msgnum): def _get_message(self, phase):
# fire with a bytestring of the first message that matches # fire with a bytestring of the first message for 'phase' that wasn't
# verb/msgnum, which either came from old_msgs, or from an # one of ours. It will either come from previously-received messages,
# EventSource that we attached to the corresponding URL # or from an EventSource that we attach to the corresponding URL
if old_msgs: body = self._find_inbound_message(phase)
msg = unhexlify(old_msgs[0].encode("ascii")) if body is not None:
return defer.succeed(msg) return defer.succeed(body)
d = defer.Deferred() d = defer.Deferred()
msgs = [] msgs = []
def _handle(name, data): def _handle(name, data):
if name == "welcome": if name == "welcome":
self.handle_welcome(json.loads(data)) self.handle_welcome(json.loads(data))
if name == "message": if name == "message":
msgs.append(json.loads(data)["message"]) self._add_inbound_messages([json.loads(data)])
body = self._find_inbound_message(phase)
if body is not None and not msgs:
msgs.append(body)
d.callback(None) d.callback(None)
es = ReconnectingEventSource(None, lambda: self._url(verb, msgnum), # TODO: use agent=self.agent
_handle)#, agent=self.agent) es = ReconnectingEventSource(self._channel_url, _handle)
es.startService() # TODO: .setServiceParent(self) es.startService() # TODO: .setServiceParent(self)
es.activate() es.activate()
d.addCallback(lambda _: es.deactivate()) d.addCallback(lambda _: es.deactivate())
d.addCallback(lambda _: es.stopService()) d.addCallback(lambda _: es.stopService())
d.addCallback(lambda _: unhexlify(msgs[0].encode("ascii"))) d.addCallback(lambda _: msgs[0])
return d return d
def derive_key(self, purpose, length=SecretBox.KEY_SIZE): def derive_key(self, purpose, length=SecretBox.KEY_SIZE):
@ -224,8 +237,8 @@ class Wormhole:
# TODO: prevent multiple invocation # TODO: prevent multiple invocation
if self.key: if self.key:
return defer.succeed(self.key) return defer.succeed(self.key)
d = self._post_message(self._url("post", "pake"), self.msg1) d = self._send_message(u"pake", self.msg1)
d.addCallback(lambda msgs: self._get_message(msgs, "poll", "pake")) d.addCallback(lambda _: self._get_message(u"pake"))
def _got_pake(pake_msg): def _got_pake(pake_msg):
key = self.sp.finish(pake_msg) key = self.sp.finish(pake_msg)
self.key = key self.key = key
@ -256,12 +269,12 @@ class Wormhole:
data_key = self.derive_key(b"data-key") data_key = self.derive_key(b"data-key")
outbound_encrypted = self._encrypt_data(data_key, outbound_data) outbound_encrypted = self._encrypt_data(data_key, outbound_data)
d = self._post_message(self._url("post", "data"), outbound_encrypted) d = self._send_message(u"data", outbound_encrypted)
d.addCallback(lambda msgs: self._get_message(msgs, "poll", "data")) d.addCallback(lambda _: self._get_message(u"data"))
def _got_data(inbound_encrypted): def _got_data(inbound_encrypted):
if inbound_encrypted == outbound_encrypted: #if inbound_encrypted == outbound_encrypted:
raise ReflectionAttack # raise ReflectionAttack
try: try:
inbound_data = self._decrypt_data(data_key, inbound_encrypted) inbound_data = self._decrypt_data(data_key, inbound_encrypted)
return inbound_data return inbound_data
@ -272,6 +285,7 @@ class Wormhole:
def _deallocate(self, res): def _deallocate(self, res):
# only try once, no retries # only try once, no retries
d = self.agent.request("POST", self._url("deallocate")) d = self._post_json(self._channel_url+"/deallocate",
{"side": self.side})
d.addBoth(lambda _: res) # ignore POST failure, pass-through result d.addBoth(lambda _: res) # ignore POST failure, pass-through result
return d return d