rewrite server API
This removes "side" and "msgnum" from the URLs, and puts them in a JSON request body instead. The server now maintains a simple set of messages for each channel-id, and isn't responsible for removing duplicates. The client now fetches all messages, and just ignores everything it sent itself. This removes the "reflection attack". Deallocate now returns JSON, for consistency. DB and API use "phase" and "body" instead of msgnum/message. This changes the DB schema, so delete the DB before upgrading the server.
This commit is contained in:
parent
bc3b0f03b9
commit
617bb03ad5
|
@ -8,21 +8,21 @@ from nacl import utils
|
|||
from .eventsource import EventSourceFollower
|
||||
from .. import __version__
|
||||
from .. import codes
|
||||
from ..errors import (ServerError, Timeout, WrongPasswordError,
|
||||
ReflectionAttack, UsageError)
|
||||
from ..errors import ServerError, Timeout, WrongPasswordError, UsageError
|
||||
from ..util.hkdf import HKDF
|
||||
|
||||
SECOND = 1
|
||||
MINUTE = 60*SECOND
|
||||
|
||||
# relay URLs are:
|
||||
# GET /list -> {channel-ids: [INT..]}
|
||||
# POST /allocate/SIDE -> {channel-id: INT}
|
||||
# these return all messages for CHANNEL-ID= and MSGNUM= but SIDE!= :
|
||||
# POST /CHANNEL-ID/SIDE/post/MSGNUM {message: STR} -> {messages: [STR..]}
|
||||
# POST /CHANNEL-ID/SIDE/poll/MSGNUM -> {messages: [STR..]}
|
||||
# GET /CHANNEL-ID/SIDE/poll/MSGNUM (eventsource) -> STR, STR, ..
|
||||
# POST /CHANNEL-ID/SIDE/deallocate -> waiting | deleted
|
||||
# GET /list -> {channel-ids: [INT..]}
|
||||
# POST /allocate {side: SIDE} -> {channel-id: INT}
|
||||
# these return all messages (base64) for CID= :
|
||||
# POST /CID {side:, phase:, body:} -> {messages: [{phase:, body:}..]}
|
||||
# GET /CID (no-eventsource) -> {messages: [{phase:, body:}..]}
|
||||
# GET /CID (eventsource) -> {phase:, body:}..
|
||||
# POST /CID/deallocate {side: SIDE} -> {status: waiting | deleted}
|
||||
# all JSON responses include a "welcome:{..}" key
|
||||
|
||||
class Wormhole:
|
||||
motd_displayed = False
|
||||
|
@ -40,12 +40,9 @@ class Wormhole:
|
|||
self.code = None
|
||||
self.key = None
|
||||
self.verifier = None
|
||||
|
||||
def _url(self, verb, msgnum=None):
|
||||
url = "%s%d/%s/%s" % (self.relay, self.channel_id, self.side, verb)
|
||||
if msgnum is not None:
|
||||
url += "/" + msgnum
|
||||
return url
|
||||
self._channel_url = None
|
||||
self._messages = set() # (phase,body) , body is bytes
|
||||
self._sent_messages = set() # (phase,body)
|
||||
|
||||
def handle_welcome(self, welcome):
|
||||
if ("motd" in welcome and
|
||||
|
@ -69,18 +66,9 @@ class Wormhole:
|
|||
if "error" in welcome:
|
||||
raise ServerError(welcome["error"], self.relay)
|
||||
|
||||
def _post_json(self, url, post_json=None):
|
||||
# POST to a URL, parsing the response as JSON. Optionally include a
|
||||
# JSON request body.
|
||||
data = None
|
||||
if post_json:
|
||||
data = json.dumps(post_json).encode("utf-8")
|
||||
r = requests.post(url, data=data)
|
||||
r.raise_for_status()
|
||||
return r.json()
|
||||
|
||||
def _allocate_channel(self):
|
||||
r = requests.post(self.relay + "allocate/%s" % self.side)
|
||||
data = json.dumps({"side": self.side}).encode("utf-8")
|
||||
r = requests.post(self.relay + "allocate", data=data)
|
||||
r.raise_for_status()
|
||||
data = r.json()
|
||||
if "welcome" in data:
|
||||
|
@ -123,6 +111,7 @@ class Wormhole:
|
|||
if not mo:
|
||||
raise ValueError("code (%s) must start with NN-" % code)
|
||||
self.channel_id = int(mo.group(1))
|
||||
self._channel_url = "%s%d" % (self.relay, self.channel_id)
|
||||
self.code = code
|
||||
|
||||
def _start(self):
|
||||
|
@ -131,36 +120,62 @@ class Wormhole:
|
|||
idSymmetric=self.appid)
|
||||
self.msg1 = self.sp.start()
|
||||
|
||||
def _post_message(self, url, msg):
|
||||
def _add_inbound_messages(self, messages):
|
||||
for msg in messages:
|
||||
phase = msg["phase"]
|
||||
body = unhexlify(msg["body"].encode("ascii"))
|
||||
self._messages.add( (phase, body) )
|
||||
|
||||
def _find_inbound_message(self, phase):
|
||||
for (their_phase,body) in self._messages - self._sent_messages:
|
||||
if their_phase == phase:
|
||||
return body
|
||||
return None
|
||||
|
||||
def _send_message(self, phase, msg):
|
||||
# TODO: retry on failure, with exponential backoff. We're guarding
|
||||
# against the rendezvous server being temporarily offline.
|
||||
if not isinstance(phase, type(u"")): raise UsageError(type(phase))
|
||||
if not isinstance(msg, type(b"")): raise UsageError(type(msg))
|
||||
resp = self._post_json(url, {"message": hexlify(msg).decode("ascii")})
|
||||
return resp["messages"] # other_msgs
|
||||
self._sent_messages.add( (phase,msg) )
|
||||
payload = {"side": self.side,
|
||||
"phase": phase,
|
||||
"body": hexlify(msg).decode("ascii")}
|
||||
data = json.dumps(payload).encode("utf-8")
|
||||
r = requests.post(self._channel_url, data=data)
|
||||
r.raise_for_status()
|
||||
resp = r.json()
|
||||
self._add_inbound_messages(resp["messages"])
|
||||
|
||||
def _get_message(self, old_msgs, verb, msgnum):
|
||||
def _get_message(self, phase):
|
||||
if not isinstance(phase, type(u"")): raise UsageError(type(phase))
|
||||
# For now, server errors cause the client to fail. TODO: don't. This
|
||||
# will require changing the client to re-post messages when the
|
||||
# server comes back up.
|
||||
|
||||
# fire with a bytestring of the first message that matches
|
||||
# verb/msgnum, which either came from old_msgs, or from an
|
||||
# EventSource that we attached to the corresponding URL
|
||||
msgs = old_msgs
|
||||
while not msgs:
|
||||
# fire with a bytestring of the first message for 'phase' that wasn't
|
||||
# one of ours. It will either come from previously-received messages,
|
||||
# or from an EventSource that we attach to the corresponding URL
|
||||
body = self._find_inbound_message(phase)
|
||||
while body is None:
|
||||
remaining = self.started + self.timeout - time.time()
|
||||
if remaining < 0:
|
||||
raise Timeout
|
||||
#time.sleep(self.wait)
|
||||
f = EventSourceFollower(self._url(verb, msgnum), remaining)
|
||||
return Timeout
|
||||
f = EventSourceFollower(self._channel_url, remaining)
|
||||
# we loop here until the connection is lost, or we see the
|
||||
# message we want
|
||||
for (eventtype, data) in f.iter_events():
|
||||
if eventtype == "welcome":
|
||||
self.handle_welcome(json.loads(data))
|
||||
if eventtype == "message":
|
||||
msgs = [json.loads(data)["message"]]
|
||||
break
|
||||
f.close()
|
||||
return unhexlify(msgs[0].encode("ascii"))
|
||||
self._add_inbound_messages([json.loads(data)])
|
||||
body = self._find_inbound_message(phase)
|
||||
if body:
|
||||
f.close()
|
||||
break
|
||||
if not body:
|
||||
time.sleep(self.wait)
|
||||
return body
|
||||
|
||||
def derive_key(self, purpose, length=SecretBox.KEY_SIZE):
|
||||
if not isinstance(purpose, type(b"")): raise UsageError
|
||||
|
@ -185,8 +200,8 @@ class Wormhole:
|
|||
|
||||
def _get_key(self):
|
||||
if not self.key:
|
||||
old_msgs = self._post_message(self._url("post", "pake"), self.msg1)
|
||||
pake_msg = self._get_message(old_msgs, "poll", "pake")
|
||||
self._send_message(u"pake", self.msg1)
|
||||
pake_msg = self._get_message(u"pake")
|
||||
self.key = self.sp.finish(pake_msg)
|
||||
self.verifier = self.derive_key(self.appid+b":Verifier")
|
||||
|
||||
|
@ -214,11 +229,12 @@ class Wormhole:
|
|||
data_key = self.derive_key(b"data-key")
|
||||
|
||||
outbound_encrypted = self._encrypt_data(data_key, outbound_data)
|
||||
msgs = self._post_message(self._url("post", "data"), outbound_encrypted)
|
||||
self._send_message(u"data", outbound_encrypted)
|
||||
|
||||
inbound_encrypted = self._get_message(msgs, "poll", "data")
|
||||
if inbound_encrypted == outbound_encrypted:
|
||||
raise ReflectionAttack
|
||||
inbound_encrypted = self._get_message(u"data")
|
||||
# _find_inbound_message() ignores any inbound message that matches
|
||||
# something we previously sent out, so we don't need to explicitly
|
||||
# check for reflection. A reflection attack will just not progress.
|
||||
try:
|
||||
inbound_data = self._decrypt_data(data_key, inbound_encrypted)
|
||||
return inbound_data
|
||||
|
@ -227,5 +243,6 @@ class Wormhole:
|
|||
|
||||
def _deallocate(self):
|
||||
# only try once, no retries
|
||||
requests.post(self._url("deallocate"))
|
||||
data = json.dumps({"side": self.side}).encode("utf-8")
|
||||
requests.post(self._channel_url+"/deallocate", data=data)
|
||||
# ignore POST failure, don't call r.raise_for_status()
|
||||
|
|
|
@ -11,11 +11,11 @@ CREATE TABLE `messages`
|
|||
(
|
||||
`channel_id` INTEGER,
|
||||
`side` VARCHAR,
|
||||
`msgnum` VARCHAR, -- not numeric, more of a PAKE-phase indicator string
|
||||
`message` VARCHAR,
|
||||
`phase` VARCHAR, -- not numeric, more of a PAKE-phase indicator string
|
||||
`body` VARCHAR,
|
||||
`when` INTEGER
|
||||
);
|
||||
CREATE INDEX `messages_idx` ON `messages` (`channel_id`, `side`, `msgnum`);
|
||||
CREATE INDEX `messages_idx` ON `messages` (`channel_id`, `side`, `phase`);
|
||||
|
||||
CREATE TABLE `allocations`
|
||||
(
|
||||
|
|
|
@ -51,112 +51,101 @@ class EventsProtocol:
|
|||
# note: no versions of IE (including the current IE11) support EventSource
|
||||
|
||||
# relay URLs are:
|
||||
# GET /list -> {channel-ids: [INT..]}
|
||||
# POST /allocate/SIDE -> {channel-id: INT}
|
||||
# these return all messages for CHANNEL-ID= and MSGNUM= but SIDE!= :
|
||||
# POST /CHANNEL-ID/SIDE/post/MSGNUM {message: STR} -> {messages: [STR..]}
|
||||
# POST /CHANNEL-ID/SIDE/poll/MSGNUM -> {messages: [STR..]}
|
||||
# GET /CHANNEL-ID/SIDE/poll/MSGNUM (eventsource) -> STR, STR, ..
|
||||
# POST /CHANNEL-ID/SIDE/deallocate -> waiting | deleted
|
||||
# GET /list -> {channel-ids: [INT..]}
|
||||
# POST /allocate {side: SIDE} -> {channel-id: INT}
|
||||
# these return all messages (base64) for CID= :
|
||||
# POST /CID {side:, phase:, body:} -> {messages: [{phase:, body:}..]}
|
||||
# GET /CID (no-eventsource) -> {messages: [{phase:, body:}..]}
|
||||
# GET /CID (eventsource) -> {phase:, body:}..
|
||||
# POST /CID/deallocate {side: SIDE} -> {status: waiting | deleted}
|
||||
# all JSON responses include a "welcome:{..}" key
|
||||
|
||||
class Channel(resource.Resource):
|
||||
isLeaf = True # I handle /CHANNEL-ID/*
|
||||
|
||||
def __init__(self, channel_id, relay, db, welcome):
|
||||
resource.Resource.__init__(self)
|
||||
self.channel_id = channel_id
|
||||
self.relay = relay
|
||||
self.db = db
|
||||
self.welcome = welcome
|
||||
self.event_channels = set() # (side, msgnum, ep)
|
||||
self.event_channels = set() # ep
|
||||
self.putChild(b"deallocate", Deallocator(self.channel_id, self.relay))
|
||||
|
||||
def render_GET(self, request):
|
||||
# rest of URL is: SIDE/poll/MSGNUM
|
||||
their_side = request.postpath[0].decode("utf-8")
|
||||
if request.postpath[1] != b"poll":
|
||||
request.setResponseCode(http.BAD_REQUEST, b"GET to wrong URL")
|
||||
return b"GET is only for /SIDE/poll/MSGNUM"
|
||||
their_msgnum = request.postpath[2].decode("utf-8")
|
||||
if b"text/event-stream" not in (request.getHeader(b"accept") or b""):
|
||||
request.setResponseCode(http.BAD_REQUEST, b"Must use EventSource")
|
||||
return b"Must use EventSource (Content-Type: text/event-stream)"
|
||||
request.setHeader(b"content-type", b"text/event-stream; charset=utf-8")
|
||||
ep = EventsProtocol(request)
|
||||
ep.sendEvent(json.dumps(self.welcome), name="welcome")
|
||||
handle = (their_side, their_msgnum, ep)
|
||||
self.event_channels.add(handle)
|
||||
request.notifyFinish().addErrback(self._shutdown, handle)
|
||||
def get_messages(self, request):
|
||||
request.setHeader(b"content-type", b"application/json; charset=utf-8")
|
||||
messages = []
|
||||
for row in self.db.execute("SELECT * FROM `messages`"
|
||||
" WHERE `channel_id`=?"
|
||||
" ORDER BY `when` ASC",
|
||||
(self.channel_id,)).fetchall():
|
||||
self.message_added(row["side"], row["msgnum"], row["message"],
|
||||
channels=[handle])
|
||||
messages.append({"phase": row["phase"], "body": row["body"]})
|
||||
data = {"welcome": self.welcome, "messages": messages}
|
||||
return (json.dumps(data)+"\n").encode("utf-8")
|
||||
|
||||
def render_GET(self, request):
|
||||
if b"text/event-stream" not in (request.getHeader(b"accept") or b""):
|
||||
return self.get_messages(request)
|
||||
request.setHeader(b"content-type", b"text/event-stream; charset=utf-8")
|
||||
ep = EventsProtocol(request)
|
||||
ep.sendEvent(json.dumps(self.welcome), name="welcome")
|
||||
self.event_channels.add(ep)
|
||||
request.notifyFinish().addErrback(lambda f:
|
||||
self.event_channels.discard(ep))
|
||||
for row in self.db.execute("SELECT * FROM `messages`"
|
||||
" WHERE `channel_id`=?"
|
||||
" ORDER BY `when` ASC",
|
||||
(self.channel_id,)).fetchall():
|
||||
data = json.dumps({"phase": row["phase"], "body": row["body"]})
|
||||
ep.sendEvent(data)
|
||||
return server.NOT_DONE_YET
|
||||
|
||||
def _shutdown(self, _, handle):
|
||||
self.event_channels.discard(handle)
|
||||
|
||||
def message_added(self, msg_side, msg_msgnum, msg_str, channels=None):
|
||||
if channels is None:
|
||||
channels = self.event_channels
|
||||
for (their_side, their_msgnum, their_ep) in channels:
|
||||
if msg_side != their_side and msg_msgnum == their_msgnum:
|
||||
data = json.dumps({ "side": msg_side, "message": msg_str })
|
||||
their_ep.sendEvent(data)
|
||||
def broadcast_message(self, phase, body):
|
||||
data = json.dumps({"phase": phase, "body": body})
|
||||
for ep in self.event_channels:
|
||||
ep.sendEvent(data)
|
||||
|
||||
def render_POST(self, request):
|
||||
# rest of URL is: SIDE/(MSGNUM|deallocate)/(post|poll)
|
||||
side = request.postpath[0].decode("utf-8")
|
||||
verb = request.postpath[1].decode("utf-8")
|
||||
#data = json.load(request.content, encoding="utf-8")
|
||||
content = request.content.read()
|
||||
data = json.loads(content.decode("utf-8"))
|
||||
|
||||
if verb == "deallocate":
|
||||
deleted = self.relay.maybe_free_child(self.channel_id, side)
|
||||
if deleted:
|
||||
return b"deleted\n"
|
||||
return b"waiting\n"
|
||||
side = data["side"]
|
||||
phase = data["phase"]
|
||||
if not isinstance(phase, type(u"")):
|
||||
raise TypeError("phase must be string, not %s" % type(phase))
|
||||
body = data["body"]
|
||||
|
||||
if verb not in ("post", "poll"):
|
||||
request.setResponseCode(http.BAD_REQUEST)
|
||||
return b"bad verb, want 'post' or 'poll'\n"
|
||||
self.db.execute("INSERT INTO `messages`"
|
||||
" (`channel_id`, `side`, `phase`, `body`, `when`)"
|
||||
" VALUES (?,?,?,?,?)",
|
||||
(self.channel_id, side, phase, body, time.time()))
|
||||
self.db.execute("INSERT INTO `allocations`"
|
||||
" (`channel_id`, `side`)"
|
||||
" VALUES (?,?)",
|
||||
(self.channel_id, side))
|
||||
self.db.commit()
|
||||
self.broadcast_message(phase, body)
|
||||
return self.get_messages(request)
|
||||
|
||||
msgnum = request.postpath[2].decode("utf-8")
|
||||
class Deallocator(resource.Resource):
|
||||
def __init__(self, channel_id, relay):
|
||||
self.channel_id = channel_id
|
||||
self.relay = relay
|
||||
|
||||
other_messages = []
|
||||
for row in self.db.execute("SELECT `message` FROM `messages`"
|
||||
" WHERE `channel_id`=? AND `side`!=?"
|
||||
" AND `msgnum`=?"
|
||||
" ORDER BY `when` ASC",
|
||||
(self.channel_id, side, msgnum)).fetchall():
|
||||
other_messages.append(row["message"])
|
||||
|
||||
if verb == "post":
|
||||
#data = json.load(request.content, encoding="utf-8")
|
||||
content = request.content.read()
|
||||
data = json.loads(content.decode("utf-8"))
|
||||
self.db.execute("INSERT INTO `messages`"
|
||||
" (`channel_id`, `side`, `msgnum`, `message`, `when`)"
|
||||
" VALUES (?,?,?,?,?)",
|
||||
(self.channel_id, side, msgnum, data["message"],
|
||||
time.time()))
|
||||
self.db.execute("INSERT INTO `allocations`"
|
||||
" (`channel_id`, `side`)"
|
||||
" VALUES (?,?)",
|
||||
(self.channel_id, side))
|
||||
self.db.commit()
|
||||
self.message_added(side, msgnum, data["message"])
|
||||
|
||||
request.setHeader(b"content-type", b"application/json; charset=utf-8")
|
||||
data = {"welcome": self.welcome,
|
||||
"messages": other_messages}
|
||||
return (json.dumps(data)+"\n").encode("utf-8")
|
||||
def render_POST(self, request):
|
||||
content = request.content.read()
|
||||
data = json.loads(content.decode("utf-8"))
|
||||
side = data["side"]
|
||||
deleted = self.relay.maybe_free_child(self.channel_id, side)
|
||||
resp = {"status": "waiting"}
|
||||
if deleted:
|
||||
resp = {"status": "deleted"}
|
||||
return json.dumps(resp).encode("utf-8")
|
||||
|
||||
def get_allocated(db):
|
||||
c = db.execute("SELECT DISTINCT `channel_id` FROM `allocations`")
|
||||
return set([row["channel_id"] for row in c.fetchall()])
|
||||
|
||||
class Allocator(resource.Resource):
|
||||
isLeaf = True
|
||||
def __init__(self, db, welcome):
|
||||
resource.Resource.__init__(self)
|
||||
self.db = db
|
||||
|
@ -179,7 +168,11 @@ class Allocator(resource.Resource):
|
|||
raise ValueError("unable to find a free channel-id")
|
||||
|
||||
def render_POST(self, request):
|
||||
side = request.postpath[0]
|
||||
content = request.content.read()
|
||||
data = json.loads(content.decode("utf-8"))
|
||||
side = data["side"]
|
||||
if not isinstance(side, type(u"")):
|
||||
raise TypeError("side must be string, not '%s'" % type(side))
|
||||
channel_id = self.allocate_channel_id()
|
||||
self.db.execute("INSERT INTO `allocations` VALUES (?,?)",
|
||||
(channel_id, side))
|
||||
|
|
|
@ -1,10 +1,13 @@
|
|||
import sys
|
||||
from __future__ import print_function
|
||||
import sys, json
|
||||
import requests
|
||||
from twisted.trial import unittest
|
||||
from twisted.internet import reactor
|
||||
from twisted.internet import reactor, defer
|
||||
from twisted.internet.threads import deferToThread
|
||||
from twisted.web.client import getPage, Agent, readBody
|
||||
from .. import __version__
|
||||
from .common import ServerBase
|
||||
from ..twisted.eventsource_twisted import EventSource
|
||||
|
||||
class Reachable(ServerBase, unittest.TestCase):
|
||||
|
||||
|
@ -47,3 +50,243 @@ class Reachable(ServerBase, unittest.TestCase):
|
|||
self.failUnlessEqual(res, "Wormhole Relay\n")
|
||||
d.addCallback(_got)
|
||||
return d
|
||||
|
||||
def unjson(data):
|
||||
return json.loads(data.decode("utf-8"))
|
||||
|
||||
class API(ServerBase, unittest.TestCase):
|
||||
def get(self, path, is_json=True):
|
||||
url = (self.relayurl+path).encode("ascii")
|
||||
d = getPage(url)
|
||||
if is_json:
|
||||
d.addCallback(unjson)
|
||||
return d
|
||||
def post(self, path, data):
|
||||
url = (self.relayurl+path).encode("ascii")
|
||||
d = getPage(url, method=b"POST",
|
||||
postdata=json.dumps(data).encode("utf-8"))
|
||||
d.addCallback(unjson)
|
||||
return d
|
||||
|
||||
def check_welcome(self, data):
|
||||
self.failUnlessIn("welcome", data)
|
||||
self.failUnlessEqual(data["welcome"], {"current_version": __version__})
|
||||
|
||||
def test_allocate_1(self):
|
||||
d = self.get("list")
|
||||
def _check_list_1(data):
|
||||
self.check_welcome(data)
|
||||
self.failUnlessEqual(data["channel-ids"], [])
|
||||
d.addCallback(_check_list_1)
|
||||
|
||||
d.addCallback(lambda _: self.post("allocate", {"side": "abc"}))
|
||||
def _allocated(data):
|
||||
self.failUnlessEqual(set(data.keys()),
|
||||
set(["welcome", "channel-id"]))
|
||||
self.failUnlessIsInstance(data["channel-id"], int)
|
||||
self.cid = data["channel-id"]
|
||||
d.addCallback(_allocated)
|
||||
|
||||
d.addCallback(lambda _: self.get("list"))
|
||||
def _check_list_2(data):
|
||||
self.failUnlessEqual(data["channel-ids"], [self.cid])
|
||||
d.addCallback(_check_list_2)
|
||||
|
||||
d.addCallback(lambda _: self.post("%d/deallocate" % self.cid,
|
||||
{"side": "abc"}))
|
||||
def _check_deallocate(res):
|
||||
self.failUnlessEqual(res["status"], "deleted")
|
||||
d.addCallback(_check_deallocate)
|
||||
|
||||
d.addCallback(lambda _: self.get("list"))
|
||||
def _check_list_3(data):
|
||||
self.failUnlessEqual(data["channel-ids"], [])
|
||||
d.addCallback(_check_list_3)
|
||||
|
||||
return d
|
||||
|
||||
def test_allocate_2(self):
|
||||
d = self.post("allocate", {"side": "abc"})
|
||||
def _allocated(data):
|
||||
self.cid = data["channel-id"]
|
||||
d.addCallback(_allocated)
|
||||
|
||||
# second caller increases the number of known sides to 2
|
||||
d.addCallback(lambda _: self.post("%d" % self.cid,
|
||||
{"side": "def",
|
||||
"phase": "1",
|
||||
"body": ""}))
|
||||
|
||||
d.addCallback(lambda _: self.get("list"))
|
||||
d.addCallback(lambda data:
|
||||
self.failUnlessEqual(data["channel-ids"], [self.cid]))
|
||||
|
||||
d.addCallback(lambda _: self.post("%d/deallocate" % self.cid,
|
||||
{"side": "abc"}))
|
||||
d.addCallback(lambda res:
|
||||
self.failUnlessEqual(res["status"], "waiting"))
|
||||
|
||||
d.addCallback(lambda _: self.post("%d/deallocate" % self.cid,
|
||||
{"side": "NOT"}))
|
||||
d.addCallback(lambda res:
|
||||
self.failUnlessEqual(res["status"], "waiting"))
|
||||
|
||||
d.addCallback(lambda _: self.post("%d/deallocate" % self.cid,
|
||||
{"side": "def"}))
|
||||
d.addCallback(lambda res:
|
||||
self.failUnlessEqual(res["status"], "deleted"))
|
||||
|
||||
d.addCallback(lambda _: self.get("list"))
|
||||
d.addCallback(lambda data:
|
||||
self.failUnlessEqual(data["channel-ids"], []))
|
||||
|
||||
return d
|
||||
|
||||
def add_message(self, message, side="abc", phase="1"):
|
||||
return self.post(str(self.cid), {"side": side, "phase": phase,
|
||||
"body": message})
|
||||
|
||||
def parse_messages(self, messages):
|
||||
out = set()
|
||||
for m in messages:
|
||||
self.failUnlessEqual(sorted(m.keys()), sorted(["phase", "body"]))
|
||||
self.failUnlessIsInstance(m["phase"], type(u""))
|
||||
self.failUnlessIsInstance(m["body"], type(u""))
|
||||
out.add( (m["phase"], m["body"]) )
|
||||
return out
|
||||
|
||||
def check_messages(self, one, two):
|
||||
# Comparing lists-of-dicts is non-trivial in python3 because we can
|
||||
# neither sort them (dicts are uncomparable), nor turn them into sets
|
||||
# (dicts are unhashable). This is close enough.
|
||||
self.failUnlessEqual(len(one), len(two), (one,two))
|
||||
for d in one:
|
||||
self.failUnlessIn(d, two)
|
||||
|
||||
def test_messages(self):
|
||||
d = self.post("allocate", {"side": "abc"})
|
||||
def _allocated(data):
|
||||
self.cid = data["channel-id"]
|
||||
d.addCallback(_allocated)
|
||||
|
||||
d.addCallback(lambda _: self.add_message("msg1A"))
|
||||
def _check1(data):
|
||||
self.check_welcome(data)
|
||||
self.failUnlessEqual(data["messages"],
|
||||
[{"phase": "1", "body": "msg1A"}])
|
||||
d.addCallback(_check1)
|
||||
d.addCallback(lambda _: self.add_message("msg1B", side="def"))
|
||||
def _check2(data):
|
||||
self.check_welcome(data)
|
||||
self.failUnlessEqual(self.parse_messages(data["messages"]),
|
||||
set([("1", "msg1A"),
|
||||
("1", "msg1B")]))
|
||||
d.addCallback(_check2)
|
||||
|
||||
# adding a duplicate message is not an error, is ignored by clients
|
||||
d.addCallback(lambda _: self.add_message("msg1B", side="def"))
|
||||
def _check3(data):
|
||||
self.check_welcome(data)
|
||||
self.failUnlessEqual(self.parse_messages(data["messages"]),
|
||||
set([("1", "msg1A"),
|
||||
("1", "msg1B")]))
|
||||
d.addCallback(_check3)
|
||||
|
||||
d.addCallback(lambda _: self.add_message("msg2A", side="abc",
|
||||
phase="2"))
|
||||
def _check4(data):
|
||||
self.check_welcome(data)
|
||||
self.failUnlessEqual(self.parse_messages(data["messages"]),
|
||||
set([("1", "msg1A"),
|
||||
("1", "msg1B"),
|
||||
("2", "msg2A"),
|
||||
]))
|
||||
d.addCallback(_check4)
|
||||
|
||||
return d
|
||||
|
||||
def test_eventsource(self):
|
||||
if sys.version_info[0] >= 3:
|
||||
raise unittest.SkipTest("twisted vs py3")
|
||||
|
||||
d = self.post("allocate", {"side": "abc"})
|
||||
def _allocated(data):
|
||||
self.cid = data["channel-id"]
|
||||
url = (self.relayurl+str(self.cid)).encode("utf-8")
|
||||
self.o = OneEventAtATime(url, parser=json.loads)
|
||||
return self.o.wait_for_connection()
|
||||
d.addCallback(_allocated)
|
||||
d.addCallback(lambda _: self.o.wait_for_next_event())
|
||||
def _check_welcome(ev):
|
||||
eventtype, data = ev
|
||||
self.failUnlessEqual(eventtype, "welcome")
|
||||
self.failUnlessEqual(data, {"current_version": __version__})
|
||||
d.addCallback(_check_welcome)
|
||||
d.addCallback(lambda _: self.add_message("msg1A"))
|
||||
d.addCallback(lambda _: self.o.wait_for_next_event())
|
||||
def _check_msg1(ev):
|
||||
eventtype, data = ev
|
||||
self.failUnlessEqual(eventtype, "message")
|
||||
self.failUnlessEqual(data, {"phase": "1", "body": "msg1A"})
|
||||
d.addCallback(_check_msg1)
|
||||
|
||||
d.addCallback(lambda _: self.add_message("msg1B"))
|
||||
d.addCallback(lambda _: self.add_message("msg2A", phase="2"))
|
||||
d.addCallback(lambda _: self.o.wait_for_next_event())
|
||||
def _check_msg2(ev):
|
||||
eventtype, data = ev
|
||||
self.failUnlessEqual(eventtype, "message")
|
||||
self.failUnlessEqual(data, {"phase": "1", "body": "msg1B"})
|
||||
d.addCallback(_check_msg2)
|
||||
d.addCallback(lambda _: self.o.wait_for_next_event())
|
||||
def _check_msg3(ev):
|
||||
eventtype, data = ev
|
||||
self.failUnlessEqual(eventtype, "message")
|
||||
self.failUnlessEqual(data, {"phase": "2", "body": "msg2A"})
|
||||
d.addCallback(_check_msg3)
|
||||
|
||||
d.addCallback(lambda _: self.o.close())
|
||||
d.addCallback(lambda _: self.o.wait_for_disconnection())
|
||||
return d
|
||||
|
||||
class OneEventAtATime:
|
||||
def __init__(self, url, parser=lambda e: e):
|
||||
self.parser = parser
|
||||
self.d = None
|
||||
self.connected_d = defer.Deferred()
|
||||
self.disconnected_d = defer.Deferred()
|
||||
self.events = []
|
||||
self.es = EventSource(url, self.handler, when_connected=self.connected)
|
||||
d = self.es.start()
|
||||
d.addBoth(self.disconnected)
|
||||
|
||||
def close(self):
|
||||
self.es.cancel()
|
||||
|
||||
def wait_for_next_event(self):
|
||||
assert not self.d
|
||||
if self.events:
|
||||
event = self.events.pop(0)
|
||||
return defer.succeed(event)
|
||||
self.d = defer.Deferred()
|
||||
return self.d
|
||||
|
||||
def handler(self, eventtype, data):
|
||||
event = (eventtype, self.parser(data))
|
||||
if self.d:
|
||||
assert not self.events
|
||||
d,self.d = self.d,None
|
||||
d.callback(event)
|
||||
return
|
||||
self.events.append(event)
|
||||
|
||||
def wait_for_connection(self):
|
||||
return self.connected_d
|
||||
def connected(self):
|
||||
self.connected_d.callback(None)
|
||||
|
||||
def wait_for_disconnection(self):
|
||||
return self.disconnected_d
|
||||
def disconnected(self, why):
|
||||
self.disconnected_d.callback((why,))
|
||||
|
||||
|
|
|
@ -93,6 +93,7 @@ class EventSource: # TODO: service.Service
|
|||
def start(self):
|
||||
assert not self.started, "single-use"
|
||||
self.started = True
|
||||
assert self.url
|
||||
d = self.agent.request("GET", self.url,
|
||||
Headers({"accept": ["text/event-stream"]}))
|
||||
d.addCallback(self._connected)
|
||||
|
@ -154,14 +155,13 @@ class Connector:
|
|||
|
||||
class ReconnectingEventSource(service.MultiService,
|
||||
protocol.ReconnectingClientFactory):
|
||||
def __init__(self, baseurl, connection_starting, handler, agent=None):
|
||||
def __init__(self, url, handler, agent=None):
|
||||
service.MultiService.__init__(self)
|
||||
# we don't use any of the basic Factory/ClientFactory methods of
|
||||
# this, just the ReconnectingClientFactory.retry, stopTrying, and
|
||||
# resetDelay methods.
|
||||
|
||||
self.baseurl = baseurl
|
||||
self.connection_starting = connection_starting
|
||||
self.url = url
|
||||
self.handler = handler
|
||||
self.agent = agent
|
||||
# IService provides self.running, toggled by {start,stop}Service.
|
||||
|
@ -201,8 +201,7 @@ class ReconnectingEventSource(service.MultiService,
|
|||
if not (self.active and self.running):
|
||||
return
|
||||
self.continueTrying = True
|
||||
url = self.connection_starting()
|
||||
self.es = EventSource(url, self.handler, self.resetDelay,
|
||||
self.es = EventSource(self.url, self.handler, self.resetDelay,
|
||||
agent=self.agent)
|
||||
d = self.es.start()
|
||||
d.addBoth(self._stopped)
|
||||
|
|
|
@ -13,8 +13,7 @@ from spake2 import SPAKE2_Symmetric
|
|||
from .eventsource_twisted import ReconnectingEventSource
|
||||
from .. import __version__
|
||||
from .. import codes
|
||||
from ..errors import (ServerError, WrongPasswordError,
|
||||
ReflectionAttack, UsageError)
|
||||
from ..errors import ServerError, WrongPasswordError, UsageError
|
||||
from ..util.hkdf import HKDF
|
||||
|
||||
@implementer(IBodyProducer)
|
||||
|
@ -46,12 +45,9 @@ class Wormhole:
|
|||
self.code = None
|
||||
self.key = None
|
||||
self._started_get_code = False
|
||||
|
||||
def _url(self, verb, msgnum=None):
|
||||
url = "%s%d/%s/%s" % (self.relay, self.channel_id, self.side, verb)
|
||||
if msgnum is not None:
|
||||
url += "/" + msgnum
|
||||
return url
|
||||
self._channel_url = None
|
||||
self._messages = set() # (phase,body) , body is bytes
|
||||
self._sent_messages = set() # (phase,body)
|
||||
|
||||
def handle_welcome(self, welcome):
|
||||
if ("motd" in welcome and
|
||||
|
@ -75,14 +71,10 @@ class Wormhole:
|
|||
if "error" in welcome:
|
||||
raise ServerError(welcome["error"], self.relay)
|
||||
|
||||
def _post_json(self, url, post_json=None):
|
||||
# POST to a URL, parsing the response as JSON. Optionally include a
|
||||
# JSON request body.
|
||||
p = None
|
||||
if post_json:
|
||||
data = json.dumps(post_json).encode("utf-8")
|
||||
p = DataProducer(data)
|
||||
d = self.agent.request("POST", url, bodyProducer=p)
|
||||
def _post_json(self, url, post_json):
|
||||
# POST a JSON body to a URL, parsing the response as JSON
|
||||
data = json.dumps(post_json).encode("utf-8")
|
||||
d = self.agent.request("POST", url, bodyProducer=DataProducer(data))
|
||||
def _check_error(resp):
|
||||
if resp.code != 200:
|
||||
raise web_error.Error(resp.code, resp.phrase)
|
||||
|
@ -93,8 +85,8 @@ class Wormhole:
|
|||
return d
|
||||
|
||||
def _allocate_channel(self):
|
||||
url = self.relay + "allocate/%s" % self.side
|
||||
d = self._post_json(url)
|
||||
url = self.relay + "allocate"
|
||||
d = self._post_json(url, {"side": self.side})
|
||||
def _got_channel(data):
|
||||
if "welcome" in data:
|
||||
self.handle_welcome(data["welcome"])
|
||||
|
@ -131,6 +123,7 @@ class Wormhole:
|
|||
if not mo:
|
||||
raise ValueError("code (%s) must start with NN-" % code)
|
||||
self.channel_id = int(mo.group(1))
|
||||
self._channel_url = "%s%d" % (self.relay, self.channel_id)
|
||||
self.code = code
|
||||
|
||||
def _start(self):
|
||||
|
@ -164,36 +157,56 @@ class Wormhole:
|
|||
self.msg1 = d["msg1"].decode("hex")
|
||||
return self
|
||||
|
||||
def _post_message(self, url, msg):
|
||||
def _add_inbound_messages(self, messages):
|
||||
for msg in messages:
|
||||
phase = msg["phase"]
|
||||
body = unhexlify(msg["body"].encode("ascii"))
|
||||
self._messages.add( (phase, body) )
|
||||
|
||||
def _find_inbound_message(self, phase):
|
||||
for (their_phase,body) in self._messages - self._sent_messages:
|
||||
if their_phase == phase:
|
||||
return body
|
||||
return None
|
||||
|
||||
def _send_message(self, phase, msg):
|
||||
# TODO: retry on failure, with exponential backoff. We're guarding
|
||||
# against the rendezvous server being temporarily offline.
|
||||
if not isinstance(phase, type(u"")): raise UsageError(type(phase))
|
||||
if not isinstance(msg, type(b"")): raise UsageError(type(msg))
|
||||
d = self._post_json(url, {"message": hexlify(msg).decode("ascii")})
|
||||
d.addCallback(lambda resp: resp["messages"]) # other_msgs
|
||||
self._sent_messages.add( (phase,msg) )
|
||||
payload = {"side": self.side,
|
||||
"phase": phase,
|
||||
"body": hexlify(msg).decode("ascii")}
|
||||
d = self._post_json(self._channel_url, payload)
|
||||
d.addCallback(lambda resp: self._add_inbound_messages(resp["messages"]))
|
||||
return d
|
||||
|
||||
def _get_message(self, old_msgs, verb, msgnum):
|
||||
# fire with a bytestring of the first message that matches
|
||||
# verb/msgnum, which either came from old_msgs, or from an
|
||||
# EventSource that we attached to the corresponding URL
|
||||
if old_msgs:
|
||||
msg = unhexlify(old_msgs[0].encode("ascii"))
|
||||
return defer.succeed(msg)
|
||||
def _get_message(self, phase):
|
||||
# fire with a bytestring of the first message for 'phase' that wasn't
|
||||
# one of ours. It will either come from previously-received messages,
|
||||
# or from an EventSource that we attach to the corresponding URL
|
||||
body = self._find_inbound_message(phase)
|
||||
if body is not None:
|
||||
return defer.succeed(body)
|
||||
d = defer.Deferred()
|
||||
msgs = []
|
||||
def _handle(name, data):
|
||||
if name == "welcome":
|
||||
self.handle_welcome(json.loads(data))
|
||||
if name == "message":
|
||||
msgs.append(json.loads(data)["message"])
|
||||
d.callback(None)
|
||||
es = ReconnectingEventSource(None, lambda: self._url(verb, msgnum),
|
||||
_handle)#, agent=self.agent)
|
||||
self._add_inbound_messages([json.loads(data)])
|
||||
body = self._find_inbound_message(phase)
|
||||
if body is not None and not msgs:
|
||||
msgs.append(body)
|
||||
d.callback(None)
|
||||
# TODO: use agent=self.agent
|
||||
es = ReconnectingEventSource(self._channel_url, _handle)
|
||||
es.startService() # TODO: .setServiceParent(self)
|
||||
es.activate()
|
||||
d.addCallback(lambda _: es.deactivate())
|
||||
d.addCallback(lambda _: es.stopService())
|
||||
d.addCallback(lambda _: unhexlify(msgs[0].encode("ascii")))
|
||||
d.addCallback(lambda _: msgs[0])
|
||||
return d
|
||||
|
||||
def derive_key(self, purpose, length=SecretBox.KEY_SIZE):
|
||||
|
@ -224,8 +237,8 @@ class Wormhole:
|
|||
# TODO: prevent multiple invocation
|
||||
if self.key:
|
||||
return defer.succeed(self.key)
|
||||
d = self._post_message(self._url("post", "pake"), self.msg1)
|
||||
d.addCallback(lambda msgs: self._get_message(msgs, "poll", "pake"))
|
||||
d = self._send_message(u"pake", self.msg1)
|
||||
d.addCallback(lambda _: self._get_message(u"pake"))
|
||||
def _got_pake(pake_msg):
|
||||
key = self.sp.finish(pake_msg)
|
||||
self.key = key
|
||||
|
@ -256,12 +269,12 @@ class Wormhole:
|
|||
data_key = self.derive_key(b"data-key")
|
||||
|
||||
outbound_encrypted = self._encrypt_data(data_key, outbound_data)
|
||||
d = self._post_message(self._url("post", "data"), outbound_encrypted)
|
||||
d = self._send_message(u"data", outbound_encrypted)
|
||||
|
||||
d.addCallback(lambda msgs: self._get_message(msgs, "poll", "data"))
|
||||
d.addCallback(lambda _: self._get_message(u"data"))
|
||||
def _got_data(inbound_encrypted):
|
||||
if inbound_encrypted == outbound_encrypted:
|
||||
raise ReflectionAttack
|
||||
#if inbound_encrypted == outbound_encrypted:
|
||||
# raise ReflectionAttack
|
||||
try:
|
||||
inbound_data = self._decrypt_data(data_key, inbound_encrypted)
|
||||
return inbound_data
|
||||
|
@ -272,6 +285,7 @@ class Wormhole:
|
|||
|
||||
def _deallocate(self, res):
|
||||
# only try once, no retries
|
||||
d = self.agent.request("POST", self._url("deallocate"))
|
||||
d = self._post_json(self._channel_url+"/deallocate",
|
||||
{"side": self.side})
|
||||
d.addBoth(lambda _: res) # ignore POST failure, pass-through result
|
||||
return d
|
||||
|
|
Loading…
Reference in New Issue
Block a user