rewrite server API

This removes "side" and "msgnum" from the URLs, and puts them in a JSON
request body instead. The server now maintains a simple set of messages
for each channel-id, and isn't responsible for removing duplicates.

The client now fetches all messages, and just ignores everything it sent
itself. This removes the "reflection attack".

Deallocate now returns JSON, for consistency. DB and API use "phase" and
"body" instead of msgnum/message.

This changes the DB schema, so delete the DB before upgrading the server.
This commit is contained in:
Brian Warner 2015-10-03 12:36:14 -07:00
parent bc3b0f03b9
commit 617bb03ad5
6 changed files with 445 additions and 179 deletions

View File

@ -8,21 +8,21 @@ from nacl import utils
from .eventsource import EventSourceFollower
from .. import __version__
from .. import codes
from ..errors import (ServerError, Timeout, WrongPasswordError,
ReflectionAttack, UsageError)
from ..errors import ServerError, Timeout, WrongPasswordError, UsageError
from ..util.hkdf import HKDF
SECOND = 1
MINUTE = 60*SECOND
# relay URLs are:
# GET /list -> {channel-ids: [INT..]}
# POST /allocate/SIDE -> {channel-id: INT}
# these return all messages for CHANNEL-ID= and MSGNUM= but SIDE!= :
# POST /CHANNEL-ID/SIDE/post/MSGNUM {message: STR} -> {messages: [STR..]}
# POST /CHANNEL-ID/SIDE/poll/MSGNUM -> {messages: [STR..]}
# GET /CHANNEL-ID/SIDE/poll/MSGNUM (eventsource) -> STR, STR, ..
# POST /CHANNEL-ID/SIDE/deallocate -> waiting | deleted
# GET /list -> {channel-ids: [INT..]}
# POST /allocate {side: SIDE} -> {channel-id: INT}
# these return all messages (base64) for CID= :
# POST /CID {side:, phase:, body:} -> {messages: [{phase:, body:}..]}
# GET /CID (no-eventsource) -> {messages: [{phase:, body:}..]}
# GET /CID (eventsource) -> {phase:, body:}..
# POST /CID/deallocate {side: SIDE} -> {status: waiting | deleted}
# all JSON responses include a "welcome:{..}" key
class Wormhole:
motd_displayed = False
@ -40,12 +40,9 @@ class Wormhole:
self.code = None
self.key = None
self.verifier = None
def _url(self, verb, msgnum=None):
url = "%s%d/%s/%s" % (self.relay, self.channel_id, self.side, verb)
if msgnum is not None:
url += "/" + msgnum
return url
self._channel_url = None
self._messages = set() # (phase,body) , body is bytes
self._sent_messages = set() # (phase,body)
def handle_welcome(self, welcome):
if ("motd" in welcome and
@ -69,18 +66,9 @@ class Wormhole:
if "error" in welcome:
raise ServerError(welcome["error"], self.relay)
def _post_json(self, url, post_json=None):
# POST to a URL, parsing the response as JSON. Optionally include a
# JSON request body.
data = None
if post_json:
data = json.dumps(post_json).encode("utf-8")
r = requests.post(url, data=data)
r.raise_for_status()
return r.json()
def _allocate_channel(self):
r = requests.post(self.relay + "allocate/%s" % self.side)
data = json.dumps({"side": self.side}).encode("utf-8")
r = requests.post(self.relay + "allocate", data=data)
r.raise_for_status()
data = r.json()
if "welcome" in data:
@ -123,6 +111,7 @@ class Wormhole:
if not mo:
raise ValueError("code (%s) must start with NN-" % code)
self.channel_id = int(mo.group(1))
self._channel_url = "%s%d" % (self.relay, self.channel_id)
self.code = code
def _start(self):
@ -131,36 +120,62 @@ class Wormhole:
idSymmetric=self.appid)
self.msg1 = self.sp.start()
def _post_message(self, url, msg):
def _add_inbound_messages(self, messages):
for msg in messages:
phase = msg["phase"]
body = unhexlify(msg["body"].encode("ascii"))
self._messages.add( (phase, body) )
def _find_inbound_message(self, phase):
for (their_phase,body) in self._messages - self._sent_messages:
if their_phase == phase:
return body
return None
def _send_message(self, phase, msg):
# TODO: retry on failure, with exponential backoff. We're guarding
# against the rendezvous server being temporarily offline.
if not isinstance(phase, type(u"")): raise UsageError(type(phase))
if not isinstance(msg, type(b"")): raise UsageError(type(msg))
resp = self._post_json(url, {"message": hexlify(msg).decode("ascii")})
return resp["messages"] # other_msgs
self._sent_messages.add( (phase,msg) )
payload = {"side": self.side,
"phase": phase,
"body": hexlify(msg).decode("ascii")}
data = json.dumps(payload).encode("utf-8")
r = requests.post(self._channel_url, data=data)
r.raise_for_status()
resp = r.json()
self._add_inbound_messages(resp["messages"])
def _get_message(self, old_msgs, verb, msgnum):
def _get_message(self, phase):
if not isinstance(phase, type(u"")): raise UsageError(type(phase))
# For now, server errors cause the client to fail. TODO: don't. This
# will require changing the client to re-post messages when the
# server comes back up.
# fire with a bytestring of the first message that matches
# verb/msgnum, which either came from old_msgs, or from an
# EventSource that we attached to the corresponding URL
msgs = old_msgs
while not msgs:
# fire with a bytestring of the first message for 'phase' that wasn't
# one of ours. It will either come from previously-received messages,
# or from an EventSource that we attach to the corresponding URL
body = self._find_inbound_message(phase)
while body is None:
remaining = self.started + self.timeout - time.time()
if remaining < 0:
raise Timeout
#time.sleep(self.wait)
f = EventSourceFollower(self._url(verb, msgnum), remaining)
return Timeout
f = EventSourceFollower(self._channel_url, remaining)
# we loop here until the connection is lost, or we see the
# message we want
for (eventtype, data) in f.iter_events():
if eventtype == "welcome":
self.handle_welcome(json.loads(data))
if eventtype == "message":
msgs = [json.loads(data)["message"]]
break
f.close()
return unhexlify(msgs[0].encode("ascii"))
self._add_inbound_messages([json.loads(data)])
body = self._find_inbound_message(phase)
if body:
f.close()
break
if not body:
time.sleep(self.wait)
return body
def derive_key(self, purpose, length=SecretBox.KEY_SIZE):
if not isinstance(purpose, type(b"")): raise UsageError
@ -185,8 +200,8 @@ class Wormhole:
def _get_key(self):
if not self.key:
old_msgs = self._post_message(self._url("post", "pake"), self.msg1)
pake_msg = self._get_message(old_msgs, "poll", "pake")
self._send_message(u"pake", self.msg1)
pake_msg = self._get_message(u"pake")
self.key = self.sp.finish(pake_msg)
self.verifier = self.derive_key(self.appid+b":Verifier")
@ -214,11 +229,12 @@ class Wormhole:
data_key = self.derive_key(b"data-key")
outbound_encrypted = self._encrypt_data(data_key, outbound_data)
msgs = self._post_message(self._url("post", "data"), outbound_encrypted)
self._send_message(u"data", outbound_encrypted)
inbound_encrypted = self._get_message(msgs, "poll", "data")
if inbound_encrypted == outbound_encrypted:
raise ReflectionAttack
inbound_encrypted = self._get_message(u"data")
# _find_inbound_message() ignores any inbound message that matches
# something we previously sent out, so we don't need to explicitly
# check for reflection. A reflection attack will just not progress.
try:
inbound_data = self._decrypt_data(data_key, inbound_encrypted)
return inbound_data
@ -227,5 +243,6 @@ class Wormhole:
def _deallocate(self):
# only try once, no retries
requests.post(self._url("deallocate"))
data = json.dumps({"side": self.side}).encode("utf-8")
requests.post(self._channel_url+"/deallocate", data=data)
# ignore POST failure, don't call r.raise_for_status()

View File

@ -11,11 +11,11 @@ CREATE TABLE `messages`
(
`channel_id` INTEGER,
`side` VARCHAR,
`msgnum` VARCHAR, -- not numeric, more of a PAKE-phase indicator string
`message` VARCHAR,
`phase` VARCHAR, -- not numeric, more of a PAKE-phase indicator string
`body` VARCHAR,
`when` INTEGER
);
CREATE INDEX `messages_idx` ON `messages` (`channel_id`, `side`, `msgnum`);
CREATE INDEX `messages_idx` ON `messages` (`channel_id`, `side`, `phase`);
CREATE TABLE `allocations`
(

View File

@ -51,112 +51,101 @@ class EventsProtocol:
# note: no versions of IE (including the current IE11) support EventSource
# relay URLs are:
# GET /list -> {channel-ids: [INT..]}
# POST /allocate/SIDE -> {channel-id: INT}
# these return all messages for CHANNEL-ID= and MSGNUM= but SIDE!= :
# POST /CHANNEL-ID/SIDE/post/MSGNUM {message: STR} -> {messages: [STR..]}
# POST /CHANNEL-ID/SIDE/poll/MSGNUM -> {messages: [STR..]}
# GET /CHANNEL-ID/SIDE/poll/MSGNUM (eventsource) -> STR, STR, ..
# POST /CHANNEL-ID/SIDE/deallocate -> waiting | deleted
# GET /list -> {channel-ids: [INT..]}
# POST /allocate {side: SIDE} -> {channel-id: INT}
# these return all messages (base64) for CID= :
# POST /CID {side:, phase:, body:} -> {messages: [{phase:, body:}..]}
# GET /CID (no-eventsource) -> {messages: [{phase:, body:}..]}
# GET /CID (eventsource) -> {phase:, body:}..
# POST /CID/deallocate {side: SIDE} -> {status: waiting | deleted}
# all JSON responses include a "welcome:{..}" key
class Channel(resource.Resource):
isLeaf = True # I handle /CHANNEL-ID/*
def __init__(self, channel_id, relay, db, welcome):
resource.Resource.__init__(self)
self.channel_id = channel_id
self.relay = relay
self.db = db
self.welcome = welcome
self.event_channels = set() # (side, msgnum, ep)
self.event_channels = set() # ep
self.putChild(b"deallocate", Deallocator(self.channel_id, self.relay))
def render_GET(self, request):
# rest of URL is: SIDE/poll/MSGNUM
their_side = request.postpath[0].decode("utf-8")
if request.postpath[1] != b"poll":
request.setResponseCode(http.BAD_REQUEST, b"GET to wrong URL")
return b"GET is only for /SIDE/poll/MSGNUM"
their_msgnum = request.postpath[2].decode("utf-8")
if b"text/event-stream" not in (request.getHeader(b"accept") or b""):
request.setResponseCode(http.BAD_REQUEST, b"Must use EventSource")
return b"Must use EventSource (Content-Type: text/event-stream)"
request.setHeader(b"content-type", b"text/event-stream; charset=utf-8")
ep = EventsProtocol(request)
ep.sendEvent(json.dumps(self.welcome), name="welcome")
handle = (their_side, their_msgnum, ep)
self.event_channels.add(handle)
request.notifyFinish().addErrback(self._shutdown, handle)
def get_messages(self, request):
request.setHeader(b"content-type", b"application/json; charset=utf-8")
messages = []
for row in self.db.execute("SELECT * FROM `messages`"
" WHERE `channel_id`=?"
" ORDER BY `when` ASC",
(self.channel_id,)).fetchall():
self.message_added(row["side"], row["msgnum"], row["message"],
channels=[handle])
messages.append({"phase": row["phase"], "body": row["body"]})
data = {"welcome": self.welcome, "messages": messages}
return (json.dumps(data)+"\n").encode("utf-8")
def render_GET(self, request):
if b"text/event-stream" not in (request.getHeader(b"accept") or b""):
return self.get_messages(request)
request.setHeader(b"content-type", b"text/event-stream; charset=utf-8")
ep = EventsProtocol(request)
ep.sendEvent(json.dumps(self.welcome), name="welcome")
self.event_channels.add(ep)
request.notifyFinish().addErrback(lambda f:
self.event_channels.discard(ep))
for row in self.db.execute("SELECT * FROM `messages`"
" WHERE `channel_id`=?"
" ORDER BY `when` ASC",
(self.channel_id,)).fetchall():
data = json.dumps({"phase": row["phase"], "body": row["body"]})
ep.sendEvent(data)
return server.NOT_DONE_YET
def _shutdown(self, _, handle):
self.event_channels.discard(handle)
def message_added(self, msg_side, msg_msgnum, msg_str, channels=None):
if channels is None:
channels = self.event_channels
for (their_side, their_msgnum, their_ep) in channels:
if msg_side != their_side and msg_msgnum == their_msgnum:
data = json.dumps({ "side": msg_side, "message": msg_str })
their_ep.sendEvent(data)
def broadcast_message(self, phase, body):
data = json.dumps({"phase": phase, "body": body})
for ep in self.event_channels:
ep.sendEvent(data)
def render_POST(self, request):
# rest of URL is: SIDE/(MSGNUM|deallocate)/(post|poll)
side = request.postpath[0].decode("utf-8")
verb = request.postpath[1].decode("utf-8")
#data = json.load(request.content, encoding="utf-8")
content = request.content.read()
data = json.loads(content.decode("utf-8"))
if verb == "deallocate":
deleted = self.relay.maybe_free_child(self.channel_id, side)
if deleted:
return b"deleted\n"
return b"waiting\n"
side = data["side"]
phase = data["phase"]
if not isinstance(phase, type(u"")):
raise TypeError("phase must be string, not %s" % type(phase))
body = data["body"]
if verb not in ("post", "poll"):
request.setResponseCode(http.BAD_REQUEST)
return b"bad verb, want 'post' or 'poll'\n"
self.db.execute("INSERT INTO `messages`"
" (`channel_id`, `side`, `phase`, `body`, `when`)"
" VALUES (?,?,?,?,?)",
(self.channel_id, side, phase, body, time.time()))
self.db.execute("INSERT INTO `allocations`"
" (`channel_id`, `side`)"
" VALUES (?,?)",
(self.channel_id, side))
self.db.commit()
self.broadcast_message(phase, body)
return self.get_messages(request)
msgnum = request.postpath[2].decode("utf-8")
class Deallocator(resource.Resource):
def __init__(self, channel_id, relay):
self.channel_id = channel_id
self.relay = relay
other_messages = []
for row in self.db.execute("SELECT `message` FROM `messages`"
" WHERE `channel_id`=? AND `side`!=?"
" AND `msgnum`=?"
" ORDER BY `when` ASC",
(self.channel_id, side, msgnum)).fetchall():
other_messages.append(row["message"])
if verb == "post":
#data = json.load(request.content, encoding="utf-8")
content = request.content.read()
data = json.loads(content.decode("utf-8"))
self.db.execute("INSERT INTO `messages`"
" (`channel_id`, `side`, `msgnum`, `message`, `when`)"
" VALUES (?,?,?,?,?)",
(self.channel_id, side, msgnum, data["message"],
time.time()))
self.db.execute("INSERT INTO `allocations`"
" (`channel_id`, `side`)"
" VALUES (?,?)",
(self.channel_id, side))
self.db.commit()
self.message_added(side, msgnum, data["message"])
request.setHeader(b"content-type", b"application/json; charset=utf-8")
data = {"welcome": self.welcome,
"messages": other_messages}
return (json.dumps(data)+"\n").encode("utf-8")
def render_POST(self, request):
content = request.content.read()
data = json.loads(content.decode("utf-8"))
side = data["side"]
deleted = self.relay.maybe_free_child(self.channel_id, side)
resp = {"status": "waiting"}
if deleted:
resp = {"status": "deleted"}
return json.dumps(resp).encode("utf-8")
def get_allocated(db):
c = db.execute("SELECT DISTINCT `channel_id` FROM `allocations`")
return set([row["channel_id"] for row in c.fetchall()])
class Allocator(resource.Resource):
isLeaf = True
def __init__(self, db, welcome):
resource.Resource.__init__(self)
self.db = db
@ -179,7 +168,11 @@ class Allocator(resource.Resource):
raise ValueError("unable to find a free channel-id")
def render_POST(self, request):
side = request.postpath[0]
content = request.content.read()
data = json.loads(content.decode("utf-8"))
side = data["side"]
if not isinstance(side, type(u"")):
raise TypeError("side must be string, not '%s'" % type(side))
channel_id = self.allocate_channel_id()
self.db.execute("INSERT INTO `allocations` VALUES (?,?)",
(channel_id, side))

View File

@ -1,10 +1,13 @@
import sys
from __future__ import print_function
import sys, json
import requests
from twisted.trial import unittest
from twisted.internet import reactor
from twisted.internet import reactor, defer
from twisted.internet.threads import deferToThread
from twisted.web.client import getPage, Agent, readBody
from .. import __version__
from .common import ServerBase
from ..twisted.eventsource_twisted import EventSource
class Reachable(ServerBase, unittest.TestCase):
@ -47,3 +50,243 @@ class Reachable(ServerBase, unittest.TestCase):
self.failUnlessEqual(res, "Wormhole Relay\n")
d.addCallback(_got)
return d
def unjson(data):
return json.loads(data.decode("utf-8"))
class API(ServerBase, unittest.TestCase):
def get(self, path, is_json=True):
url = (self.relayurl+path).encode("ascii")
d = getPage(url)
if is_json:
d.addCallback(unjson)
return d
def post(self, path, data):
url = (self.relayurl+path).encode("ascii")
d = getPage(url, method=b"POST",
postdata=json.dumps(data).encode("utf-8"))
d.addCallback(unjson)
return d
def check_welcome(self, data):
self.failUnlessIn("welcome", data)
self.failUnlessEqual(data["welcome"], {"current_version": __version__})
def test_allocate_1(self):
d = self.get("list")
def _check_list_1(data):
self.check_welcome(data)
self.failUnlessEqual(data["channel-ids"], [])
d.addCallback(_check_list_1)
d.addCallback(lambda _: self.post("allocate", {"side": "abc"}))
def _allocated(data):
self.failUnlessEqual(set(data.keys()),
set(["welcome", "channel-id"]))
self.failUnlessIsInstance(data["channel-id"], int)
self.cid = data["channel-id"]
d.addCallback(_allocated)
d.addCallback(lambda _: self.get("list"))
def _check_list_2(data):
self.failUnlessEqual(data["channel-ids"], [self.cid])
d.addCallback(_check_list_2)
d.addCallback(lambda _: self.post("%d/deallocate" % self.cid,
{"side": "abc"}))
def _check_deallocate(res):
self.failUnlessEqual(res["status"], "deleted")
d.addCallback(_check_deallocate)
d.addCallback(lambda _: self.get("list"))
def _check_list_3(data):
self.failUnlessEqual(data["channel-ids"], [])
d.addCallback(_check_list_3)
return d
def test_allocate_2(self):
d = self.post("allocate", {"side": "abc"})
def _allocated(data):
self.cid = data["channel-id"]
d.addCallback(_allocated)
# second caller increases the number of known sides to 2
d.addCallback(lambda _: self.post("%d" % self.cid,
{"side": "def",
"phase": "1",
"body": ""}))
d.addCallback(lambda _: self.get("list"))
d.addCallback(lambda data:
self.failUnlessEqual(data["channel-ids"], [self.cid]))
d.addCallback(lambda _: self.post("%d/deallocate" % self.cid,
{"side": "abc"}))
d.addCallback(lambda res:
self.failUnlessEqual(res["status"], "waiting"))
d.addCallback(lambda _: self.post("%d/deallocate" % self.cid,
{"side": "NOT"}))
d.addCallback(lambda res:
self.failUnlessEqual(res["status"], "waiting"))
d.addCallback(lambda _: self.post("%d/deallocate" % self.cid,
{"side": "def"}))
d.addCallback(lambda res:
self.failUnlessEqual(res["status"], "deleted"))
d.addCallback(lambda _: self.get("list"))
d.addCallback(lambda data:
self.failUnlessEqual(data["channel-ids"], []))
return d
def add_message(self, message, side="abc", phase="1"):
return self.post(str(self.cid), {"side": side, "phase": phase,
"body": message})
def parse_messages(self, messages):
out = set()
for m in messages:
self.failUnlessEqual(sorted(m.keys()), sorted(["phase", "body"]))
self.failUnlessIsInstance(m["phase"], type(u""))
self.failUnlessIsInstance(m["body"], type(u""))
out.add( (m["phase"], m["body"]) )
return out
def check_messages(self, one, two):
# Comparing lists-of-dicts is non-trivial in python3 because we can
# neither sort them (dicts are uncomparable), nor turn them into sets
# (dicts are unhashable). This is close enough.
self.failUnlessEqual(len(one), len(two), (one,two))
for d in one:
self.failUnlessIn(d, two)
def test_messages(self):
d = self.post("allocate", {"side": "abc"})
def _allocated(data):
self.cid = data["channel-id"]
d.addCallback(_allocated)
d.addCallback(lambda _: self.add_message("msg1A"))
def _check1(data):
self.check_welcome(data)
self.failUnlessEqual(data["messages"],
[{"phase": "1", "body": "msg1A"}])
d.addCallback(_check1)
d.addCallback(lambda _: self.add_message("msg1B", side="def"))
def _check2(data):
self.check_welcome(data)
self.failUnlessEqual(self.parse_messages(data["messages"]),
set([("1", "msg1A"),
("1", "msg1B")]))
d.addCallback(_check2)
# adding a duplicate message is not an error, is ignored by clients
d.addCallback(lambda _: self.add_message("msg1B", side="def"))
def _check3(data):
self.check_welcome(data)
self.failUnlessEqual(self.parse_messages(data["messages"]),
set([("1", "msg1A"),
("1", "msg1B")]))
d.addCallback(_check3)
d.addCallback(lambda _: self.add_message("msg2A", side="abc",
phase="2"))
def _check4(data):
self.check_welcome(data)
self.failUnlessEqual(self.parse_messages(data["messages"]),
set([("1", "msg1A"),
("1", "msg1B"),
("2", "msg2A"),
]))
d.addCallback(_check4)
return d
def test_eventsource(self):
if sys.version_info[0] >= 3:
raise unittest.SkipTest("twisted vs py3")
d = self.post("allocate", {"side": "abc"})
def _allocated(data):
self.cid = data["channel-id"]
url = (self.relayurl+str(self.cid)).encode("utf-8")
self.o = OneEventAtATime(url, parser=json.loads)
return self.o.wait_for_connection()
d.addCallback(_allocated)
d.addCallback(lambda _: self.o.wait_for_next_event())
def _check_welcome(ev):
eventtype, data = ev
self.failUnlessEqual(eventtype, "welcome")
self.failUnlessEqual(data, {"current_version": __version__})
d.addCallback(_check_welcome)
d.addCallback(lambda _: self.add_message("msg1A"))
d.addCallback(lambda _: self.o.wait_for_next_event())
def _check_msg1(ev):
eventtype, data = ev
self.failUnlessEqual(eventtype, "message")
self.failUnlessEqual(data, {"phase": "1", "body": "msg1A"})
d.addCallback(_check_msg1)
d.addCallback(lambda _: self.add_message("msg1B"))
d.addCallback(lambda _: self.add_message("msg2A", phase="2"))
d.addCallback(lambda _: self.o.wait_for_next_event())
def _check_msg2(ev):
eventtype, data = ev
self.failUnlessEqual(eventtype, "message")
self.failUnlessEqual(data, {"phase": "1", "body": "msg1B"})
d.addCallback(_check_msg2)
d.addCallback(lambda _: self.o.wait_for_next_event())
def _check_msg3(ev):
eventtype, data = ev
self.failUnlessEqual(eventtype, "message")
self.failUnlessEqual(data, {"phase": "2", "body": "msg2A"})
d.addCallback(_check_msg3)
d.addCallback(lambda _: self.o.close())
d.addCallback(lambda _: self.o.wait_for_disconnection())
return d
class OneEventAtATime:
def __init__(self, url, parser=lambda e: e):
self.parser = parser
self.d = None
self.connected_d = defer.Deferred()
self.disconnected_d = defer.Deferred()
self.events = []
self.es = EventSource(url, self.handler, when_connected=self.connected)
d = self.es.start()
d.addBoth(self.disconnected)
def close(self):
self.es.cancel()
def wait_for_next_event(self):
assert not self.d
if self.events:
event = self.events.pop(0)
return defer.succeed(event)
self.d = defer.Deferred()
return self.d
def handler(self, eventtype, data):
event = (eventtype, self.parser(data))
if self.d:
assert not self.events
d,self.d = self.d,None
d.callback(event)
return
self.events.append(event)
def wait_for_connection(self):
return self.connected_d
def connected(self):
self.connected_d.callback(None)
def wait_for_disconnection(self):
return self.disconnected_d
def disconnected(self, why):
self.disconnected_d.callback((why,))

View File

@ -93,6 +93,7 @@ class EventSource: # TODO: service.Service
def start(self):
assert not self.started, "single-use"
self.started = True
assert self.url
d = self.agent.request("GET", self.url,
Headers({"accept": ["text/event-stream"]}))
d.addCallback(self._connected)
@ -154,14 +155,13 @@ class Connector:
class ReconnectingEventSource(service.MultiService,
protocol.ReconnectingClientFactory):
def __init__(self, baseurl, connection_starting, handler, agent=None):
def __init__(self, url, handler, agent=None):
service.MultiService.__init__(self)
# we don't use any of the basic Factory/ClientFactory methods of
# this, just the ReconnectingClientFactory.retry, stopTrying, and
# resetDelay methods.
self.baseurl = baseurl
self.connection_starting = connection_starting
self.url = url
self.handler = handler
self.agent = agent
# IService provides self.running, toggled by {start,stop}Service.
@ -201,8 +201,7 @@ class ReconnectingEventSource(service.MultiService,
if not (self.active and self.running):
return
self.continueTrying = True
url = self.connection_starting()
self.es = EventSource(url, self.handler, self.resetDelay,
self.es = EventSource(self.url, self.handler, self.resetDelay,
agent=self.agent)
d = self.es.start()
d.addBoth(self._stopped)

View File

@ -13,8 +13,7 @@ from spake2 import SPAKE2_Symmetric
from .eventsource_twisted import ReconnectingEventSource
from .. import __version__
from .. import codes
from ..errors import (ServerError, WrongPasswordError,
ReflectionAttack, UsageError)
from ..errors import ServerError, WrongPasswordError, UsageError
from ..util.hkdf import HKDF
@implementer(IBodyProducer)
@ -46,12 +45,9 @@ class Wormhole:
self.code = None
self.key = None
self._started_get_code = False
def _url(self, verb, msgnum=None):
url = "%s%d/%s/%s" % (self.relay, self.channel_id, self.side, verb)
if msgnum is not None:
url += "/" + msgnum
return url
self._channel_url = None
self._messages = set() # (phase,body) , body is bytes
self._sent_messages = set() # (phase,body)
def handle_welcome(self, welcome):
if ("motd" in welcome and
@ -75,14 +71,10 @@ class Wormhole:
if "error" in welcome:
raise ServerError(welcome["error"], self.relay)
def _post_json(self, url, post_json=None):
# POST to a URL, parsing the response as JSON. Optionally include a
# JSON request body.
p = None
if post_json:
data = json.dumps(post_json).encode("utf-8")
p = DataProducer(data)
d = self.agent.request("POST", url, bodyProducer=p)
def _post_json(self, url, post_json):
# POST a JSON body to a URL, parsing the response as JSON
data = json.dumps(post_json).encode("utf-8")
d = self.agent.request("POST", url, bodyProducer=DataProducer(data))
def _check_error(resp):
if resp.code != 200:
raise web_error.Error(resp.code, resp.phrase)
@ -93,8 +85,8 @@ class Wormhole:
return d
def _allocate_channel(self):
url = self.relay + "allocate/%s" % self.side
d = self._post_json(url)
url = self.relay + "allocate"
d = self._post_json(url, {"side": self.side})
def _got_channel(data):
if "welcome" in data:
self.handle_welcome(data["welcome"])
@ -131,6 +123,7 @@ class Wormhole:
if not mo:
raise ValueError("code (%s) must start with NN-" % code)
self.channel_id = int(mo.group(1))
self._channel_url = "%s%d" % (self.relay, self.channel_id)
self.code = code
def _start(self):
@ -164,36 +157,56 @@ class Wormhole:
self.msg1 = d["msg1"].decode("hex")
return self
def _post_message(self, url, msg):
def _add_inbound_messages(self, messages):
for msg in messages:
phase = msg["phase"]
body = unhexlify(msg["body"].encode("ascii"))
self._messages.add( (phase, body) )
def _find_inbound_message(self, phase):
for (their_phase,body) in self._messages - self._sent_messages:
if their_phase == phase:
return body
return None
def _send_message(self, phase, msg):
# TODO: retry on failure, with exponential backoff. We're guarding
# against the rendezvous server being temporarily offline.
if not isinstance(phase, type(u"")): raise UsageError(type(phase))
if not isinstance(msg, type(b"")): raise UsageError(type(msg))
d = self._post_json(url, {"message": hexlify(msg).decode("ascii")})
d.addCallback(lambda resp: resp["messages"]) # other_msgs
self._sent_messages.add( (phase,msg) )
payload = {"side": self.side,
"phase": phase,
"body": hexlify(msg).decode("ascii")}
d = self._post_json(self._channel_url, payload)
d.addCallback(lambda resp: self._add_inbound_messages(resp["messages"]))
return d
def _get_message(self, old_msgs, verb, msgnum):
# fire with a bytestring of the first message that matches
# verb/msgnum, which either came from old_msgs, or from an
# EventSource that we attached to the corresponding URL
if old_msgs:
msg = unhexlify(old_msgs[0].encode("ascii"))
return defer.succeed(msg)
def _get_message(self, phase):
# fire with a bytestring of the first message for 'phase' that wasn't
# one of ours. It will either come from previously-received messages,
# or from an EventSource that we attach to the corresponding URL
body = self._find_inbound_message(phase)
if body is not None:
return defer.succeed(body)
d = defer.Deferred()
msgs = []
def _handle(name, data):
if name == "welcome":
self.handle_welcome(json.loads(data))
if name == "message":
msgs.append(json.loads(data)["message"])
d.callback(None)
es = ReconnectingEventSource(None, lambda: self._url(verb, msgnum),
_handle)#, agent=self.agent)
self._add_inbound_messages([json.loads(data)])
body = self._find_inbound_message(phase)
if body is not None and not msgs:
msgs.append(body)
d.callback(None)
# TODO: use agent=self.agent
es = ReconnectingEventSource(self._channel_url, _handle)
es.startService() # TODO: .setServiceParent(self)
es.activate()
d.addCallback(lambda _: es.deactivate())
d.addCallback(lambda _: es.stopService())
d.addCallback(lambda _: unhexlify(msgs[0].encode("ascii")))
d.addCallback(lambda _: msgs[0])
return d
def derive_key(self, purpose, length=SecretBox.KEY_SIZE):
@ -224,8 +237,8 @@ class Wormhole:
# TODO: prevent multiple invocation
if self.key:
return defer.succeed(self.key)
d = self._post_message(self._url("post", "pake"), self.msg1)
d.addCallback(lambda msgs: self._get_message(msgs, "poll", "pake"))
d = self._send_message(u"pake", self.msg1)
d.addCallback(lambda _: self._get_message(u"pake"))
def _got_pake(pake_msg):
key = self.sp.finish(pake_msg)
self.key = key
@ -256,12 +269,12 @@ class Wormhole:
data_key = self.derive_key(b"data-key")
outbound_encrypted = self._encrypt_data(data_key, outbound_data)
d = self._post_message(self._url("post", "data"), outbound_encrypted)
d = self._send_message(u"data", outbound_encrypted)
d.addCallback(lambda msgs: self._get_message(msgs, "poll", "data"))
d.addCallback(lambda _: self._get_message(u"data"))
def _got_data(inbound_encrypted):
if inbound_encrypted == outbound_encrypted:
raise ReflectionAttack
#if inbound_encrypted == outbound_encrypted:
# raise ReflectionAttack
try:
inbound_data = self._decrypt_data(data_key, inbound_encrypted)
return inbound_data
@ -272,6 +285,7 @@ class Wormhole:
def _deallocate(self, res):
# only try once, no retries
d = self.agent.request("POST", self._url("deallocate"))
d = self._post_json(self._channel_url+"/deallocate",
{"side": self.side})
d.addBoth(lambda _: res) # ignore POST failure, pass-through result
return d