rewrite server API
This removes "side" and "msgnum" from the URLs, and puts them in a JSON request body instead. The server now maintains a simple set of messages for each channel-id, and isn't responsible for removing duplicates. The client now fetches all messages, and just ignores everything it sent itself. This removes the "reflection attack". Deallocate now returns JSON, for consistency. DB and API use "phase" and "body" instead of msgnum/message. This changes the DB schema, so delete the DB before upgrading the server.
This commit is contained in:
parent
bc3b0f03b9
commit
617bb03ad5
|
@ -8,8 +8,7 @@ from nacl import utils
|
||||||
from .eventsource import EventSourceFollower
|
from .eventsource import EventSourceFollower
|
||||||
from .. import __version__
|
from .. import __version__
|
||||||
from .. import codes
|
from .. import codes
|
||||||
from ..errors import (ServerError, Timeout, WrongPasswordError,
|
from ..errors import ServerError, Timeout, WrongPasswordError, UsageError
|
||||||
ReflectionAttack, UsageError)
|
|
||||||
from ..util.hkdf import HKDF
|
from ..util.hkdf import HKDF
|
||||||
|
|
||||||
SECOND = 1
|
SECOND = 1
|
||||||
|
@ -17,12 +16,13 @@ MINUTE = 60*SECOND
|
||||||
|
|
||||||
# relay URLs are:
|
# relay URLs are:
|
||||||
# GET /list -> {channel-ids: [INT..]}
|
# GET /list -> {channel-ids: [INT..]}
|
||||||
# POST /allocate/SIDE -> {channel-id: INT}
|
# POST /allocate {side: SIDE} -> {channel-id: INT}
|
||||||
# these return all messages for CHANNEL-ID= and MSGNUM= but SIDE!= :
|
# these return all messages (base64) for CID= :
|
||||||
# POST /CHANNEL-ID/SIDE/post/MSGNUM {message: STR} -> {messages: [STR..]}
|
# POST /CID {side:, phase:, body:} -> {messages: [{phase:, body:}..]}
|
||||||
# POST /CHANNEL-ID/SIDE/poll/MSGNUM -> {messages: [STR..]}
|
# GET /CID (no-eventsource) -> {messages: [{phase:, body:}..]}
|
||||||
# GET /CHANNEL-ID/SIDE/poll/MSGNUM (eventsource) -> STR, STR, ..
|
# GET /CID (eventsource) -> {phase:, body:}..
|
||||||
# POST /CHANNEL-ID/SIDE/deallocate -> waiting | deleted
|
# POST /CID/deallocate {side: SIDE} -> {status: waiting | deleted}
|
||||||
|
# all JSON responses include a "welcome:{..}" key
|
||||||
|
|
||||||
class Wormhole:
|
class Wormhole:
|
||||||
motd_displayed = False
|
motd_displayed = False
|
||||||
|
@ -40,12 +40,9 @@ class Wormhole:
|
||||||
self.code = None
|
self.code = None
|
||||||
self.key = None
|
self.key = None
|
||||||
self.verifier = None
|
self.verifier = None
|
||||||
|
self._channel_url = None
|
||||||
def _url(self, verb, msgnum=None):
|
self._messages = set() # (phase,body) , body is bytes
|
||||||
url = "%s%d/%s/%s" % (self.relay, self.channel_id, self.side, verb)
|
self._sent_messages = set() # (phase,body)
|
||||||
if msgnum is not None:
|
|
||||||
url += "/" + msgnum
|
|
||||||
return url
|
|
||||||
|
|
||||||
def handle_welcome(self, welcome):
|
def handle_welcome(self, welcome):
|
||||||
if ("motd" in welcome and
|
if ("motd" in welcome and
|
||||||
|
@ -69,18 +66,9 @@ class Wormhole:
|
||||||
if "error" in welcome:
|
if "error" in welcome:
|
||||||
raise ServerError(welcome["error"], self.relay)
|
raise ServerError(welcome["error"], self.relay)
|
||||||
|
|
||||||
def _post_json(self, url, post_json=None):
|
|
||||||
# POST to a URL, parsing the response as JSON. Optionally include a
|
|
||||||
# JSON request body.
|
|
||||||
data = None
|
|
||||||
if post_json:
|
|
||||||
data = json.dumps(post_json).encode("utf-8")
|
|
||||||
r = requests.post(url, data=data)
|
|
||||||
r.raise_for_status()
|
|
||||||
return r.json()
|
|
||||||
|
|
||||||
def _allocate_channel(self):
|
def _allocate_channel(self):
|
||||||
r = requests.post(self.relay + "allocate/%s" % self.side)
|
data = json.dumps({"side": self.side}).encode("utf-8")
|
||||||
|
r = requests.post(self.relay + "allocate", data=data)
|
||||||
r.raise_for_status()
|
r.raise_for_status()
|
||||||
data = r.json()
|
data = r.json()
|
||||||
if "welcome" in data:
|
if "welcome" in data:
|
||||||
|
@ -123,6 +111,7 @@ class Wormhole:
|
||||||
if not mo:
|
if not mo:
|
||||||
raise ValueError("code (%s) must start with NN-" % code)
|
raise ValueError("code (%s) must start with NN-" % code)
|
||||||
self.channel_id = int(mo.group(1))
|
self.channel_id = int(mo.group(1))
|
||||||
|
self._channel_url = "%s%d" % (self.relay, self.channel_id)
|
||||||
self.code = code
|
self.code = code
|
||||||
|
|
||||||
def _start(self):
|
def _start(self):
|
||||||
|
@ -131,36 +120,62 @@ class Wormhole:
|
||||||
idSymmetric=self.appid)
|
idSymmetric=self.appid)
|
||||||
self.msg1 = self.sp.start()
|
self.msg1 = self.sp.start()
|
||||||
|
|
||||||
def _post_message(self, url, msg):
|
def _add_inbound_messages(self, messages):
|
||||||
|
for msg in messages:
|
||||||
|
phase = msg["phase"]
|
||||||
|
body = unhexlify(msg["body"].encode("ascii"))
|
||||||
|
self._messages.add( (phase, body) )
|
||||||
|
|
||||||
|
def _find_inbound_message(self, phase):
|
||||||
|
for (their_phase,body) in self._messages - self._sent_messages:
|
||||||
|
if their_phase == phase:
|
||||||
|
return body
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _send_message(self, phase, msg):
|
||||||
# TODO: retry on failure, with exponential backoff. We're guarding
|
# TODO: retry on failure, with exponential backoff. We're guarding
|
||||||
# against the rendezvous server being temporarily offline.
|
# against the rendezvous server being temporarily offline.
|
||||||
|
if not isinstance(phase, type(u"")): raise UsageError(type(phase))
|
||||||
if not isinstance(msg, type(b"")): raise UsageError(type(msg))
|
if not isinstance(msg, type(b"")): raise UsageError(type(msg))
|
||||||
resp = self._post_json(url, {"message": hexlify(msg).decode("ascii")})
|
self._sent_messages.add( (phase,msg) )
|
||||||
return resp["messages"] # other_msgs
|
payload = {"side": self.side,
|
||||||
|
"phase": phase,
|
||||||
|
"body": hexlify(msg).decode("ascii")}
|
||||||
|
data = json.dumps(payload).encode("utf-8")
|
||||||
|
r = requests.post(self._channel_url, data=data)
|
||||||
|
r.raise_for_status()
|
||||||
|
resp = r.json()
|
||||||
|
self._add_inbound_messages(resp["messages"])
|
||||||
|
|
||||||
def _get_message(self, old_msgs, verb, msgnum):
|
def _get_message(self, phase):
|
||||||
|
if not isinstance(phase, type(u"")): raise UsageError(type(phase))
|
||||||
# For now, server errors cause the client to fail. TODO: don't. This
|
# For now, server errors cause the client to fail. TODO: don't. This
|
||||||
# will require changing the client to re-post messages when the
|
# will require changing the client to re-post messages when the
|
||||||
# server comes back up.
|
# server comes back up.
|
||||||
|
|
||||||
# fire with a bytestring of the first message that matches
|
# fire with a bytestring of the first message for 'phase' that wasn't
|
||||||
# verb/msgnum, which either came from old_msgs, or from an
|
# one of ours. It will either come from previously-received messages,
|
||||||
# EventSource that we attached to the corresponding URL
|
# or from an EventSource that we attach to the corresponding URL
|
||||||
msgs = old_msgs
|
body = self._find_inbound_message(phase)
|
||||||
while not msgs:
|
while body is None:
|
||||||
remaining = self.started + self.timeout - time.time()
|
remaining = self.started + self.timeout - time.time()
|
||||||
if remaining < 0:
|
if remaining < 0:
|
||||||
raise Timeout
|
return Timeout
|
||||||
#time.sleep(self.wait)
|
f = EventSourceFollower(self._channel_url, remaining)
|
||||||
f = EventSourceFollower(self._url(verb, msgnum), remaining)
|
# we loop here until the connection is lost, or we see the
|
||||||
|
# message we want
|
||||||
for (eventtype, data) in f.iter_events():
|
for (eventtype, data) in f.iter_events():
|
||||||
if eventtype == "welcome":
|
if eventtype == "welcome":
|
||||||
self.handle_welcome(json.loads(data))
|
self.handle_welcome(json.loads(data))
|
||||||
if eventtype == "message":
|
if eventtype == "message":
|
||||||
msgs = [json.loads(data)["message"]]
|
self._add_inbound_messages([json.loads(data)])
|
||||||
break
|
body = self._find_inbound_message(phase)
|
||||||
|
if body:
|
||||||
f.close()
|
f.close()
|
||||||
return unhexlify(msgs[0].encode("ascii"))
|
break
|
||||||
|
if not body:
|
||||||
|
time.sleep(self.wait)
|
||||||
|
return body
|
||||||
|
|
||||||
def derive_key(self, purpose, length=SecretBox.KEY_SIZE):
|
def derive_key(self, purpose, length=SecretBox.KEY_SIZE):
|
||||||
if not isinstance(purpose, type(b"")): raise UsageError
|
if not isinstance(purpose, type(b"")): raise UsageError
|
||||||
|
@ -185,8 +200,8 @@ class Wormhole:
|
||||||
|
|
||||||
def _get_key(self):
|
def _get_key(self):
|
||||||
if not self.key:
|
if not self.key:
|
||||||
old_msgs = self._post_message(self._url("post", "pake"), self.msg1)
|
self._send_message(u"pake", self.msg1)
|
||||||
pake_msg = self._get_message(old_msgs, "poll", "pake")
|
pake_msg = self._get_message(u"pake")
|
||||||
self.key = self.sp.finish(pake_msg)
|
self.key = self.sp.finish(pake_msg)
|
||||||
self.verifier = self.derive_key(self.appid+b":Verifier")
|
self.verifier = self.derive_key(self.appid+b":Verifier")
|
||||||
|
|
||||||
|
@ -214,11 +229,12 @@ class Wormhole:
|
||||||
data_key = self.derive_key(b"data-key")
|
data_key = self.derive_key(b"data-key")
|
||||||
|
|
||||||
outbound_encrypted = self._encrypt_data(data_key, outbound_data)
|
outbound_encrypted = self._encrypt_data(data_key, outbound_data)
|
||||||
msgs = self._post_message(self._url("post", "data"), outbound_encrypted)
|
self._send_message(u"data", outbound_encrypted)
|
||||||
|
|
||||||
inbound_encrypted = self._get_message(msgs, "poll", "data")
|
inbound_encrypted = self._get_message(u"data")
|
||||||
if inbound_encrypted == outbound_encrypted:
|
# _find_inbound_message() ignores any inbound message that matches
|
||||||
raise ReflectionAttack
|
# something we previously sent out, so we don't need to explicitly
|
||||||
|
# check for reflection. A reflection attack will just not progress.
|
||||||
try:
|
try:
|
||||||
inbound_data = self._decrypt_data(data_key, inbound_encrypted)
|
inbound_data = self._decrypt_data(data_key, inbound_encrypted)
|
||||||
return inbound_data
|
return inbound_data
|
||||||
|
@ -227,5 +243,6 @@ class Wormhole:
|
||||||
|
|
||||||
def _deallocate(self):
|
def _deallocate(self):
|
||||||
# only try once, no retries
|
# only try once, no retries
|
||||||
requests.post(self._url("deallocate"))
|
data = json.dumps({"side": self.side}).encode("utf-8")
|
||||||
|
requests.post(self._channel_url+"/deallocate", data=data)
|
||||||
# ignore POST failure, don't call r.raise_for_status()
|
# ignore POST failure, don't call r.raise_for_status()
|
||||||
|
|
|
@ -11,11 +11,11 @@ CREATE TABLE `messages`
|
||||||
(
|
(
|
||||||
`channel_id` INTEGER,
|
`channel_id` INTEGER,
|
||||||
`side` VARCHAR,
|
`side` VARCHAR,
|
||||||
`msgnum` VARCHAR, -- not numeric, more of a PAKE-phase indicator string
|
`phase` VARCHAR, -- not numeric, more of a PAKE-phase indicator string
|
||||||
`message` VARCHAR,
|
`body` VARCHAR,
|
||||||
`when` INTEGER
|
`when` INTEGER
|
||||||
);
|
);
|
||||||
CREATE INDEX `messages_idx` ON `messages` (`channel_id`, `side`, `msgnum`);
|
CREATE INDEX `messages_idx` ON `messages` (`channel_id`, `side`, `phase`);
|
||||||
|
|
||||||
CREATE TABLE `allocations`
|
CREATE TABLE `allocations`
|
||||||
(
|
(
|
||||||
|
|
|
@ -52,111 +52,100 @@ class EventsProtocol:
|
||||||
|
|
||||||
# relay URLs are:
|
# relay URLs are:
|
||||||
# GET /list -> {channel-ids: [INT..]}
|
# GET /list -> {channel-ids: [INT..]}
|
||||||
# POST /allocate/SIDE -> {channel-id: INT}
|
# POST /allocate {side: SIDE} -> {channel-id: INT}
|
||||||
# these return all messages for CHANNEL-ID= and MSGNUM= but SIDE!= :
|
# these return all messages (base64) for CID= :
|
||||||
# POST /CHANNEL-ID/SIDE/post/MSGNUM {message: STR} -> {messages: [STR..]}
|
# POST /CID {side:, phase:, body:} -> {messages: [{phase:, body:}..]}
|
||||||
# POST /CHANNEL-ID/SIDE/poll/MSGNUM -> {messages: [STR..]}
|
# GET /CID (no-eventsource) -> {messages: [{phase:, body:}..]}
|
||||||
# GET /CHANNEL-ID/SIDE/poll/MSGNUM (eventsource) -> STR, STR, ..
|
# GET /CID (eventsource) -> {phase:, body:}..
|
||||||
# POST /CHANNEL-ID/SIDE/deallocate -> waiting | deleted
|
# POST /CID/deallocate {side: SIDE} -> {status: waiting | deleted}
|
||||||
|
# all JSON responses include a "welcome:{..}" key
|
||||||
|
|
||||||
class Channel(resource.Resource):
|
class Channel(resource.Resource):
|
||||||
isLeaf = True # I handle /CHANNEL-ID/*
|
|
||||||
|
|
||||||
def __init__(self, channel_id, relay, db, welcome):
|
def __init__(self, channel_id, relay, db, welcome):
|
||||||
resource.Resource.__init__(self)
|
resource.Resource.__init__(self)
|
||||||
self.channel_id = channel_id
|
self.channel_id = channel_id
|
||||||
self.relay = relay
|
self.relay = relay
|
||||||
self.db = db
|
self.db = db
|
||||||
self.welcome = welcome
|
self.welcome = welcome
|
||||||
self.event_channels = set() # (side, msgnum, ep)
|
self.event_channels = set() # ep
|
||||||
|
self.putChild(b"deallocate", Deallocator(self.channel_id, self.relay))
|
||||||
|
|
||||||
def render_GET(self, request):
|
def get_messages(self, request):
|
||||||
# rest of URL is: SIDE/poll/MSGNUM
|
request.setHeader(b"content-type", b"application/json; charset=utf-8")
|
||||||
their_side = request.postpath[0].decode("utf-8")
|
messages = []
|
||||||
if request.postpath[1] != b"poll":
|
|
||||||
request.setResponseCode(http.BAD_REQUEST, b"GET to wrong URL")
|
|
||||||
return b"GET is only for /SIDE/poll/MSGNUM"
|
|
||||||
their_msgnum = request.postpath[2].decode("utf-8")
|
|
||||||
if b"text/event-stream" not in (request.getHeader(b"accept") or b""):
|
|
||||||
request.setResponseCode(http.BAD_REQUEST, b"Must use EventSource")
|
|
||||||
return b"Must use EventSource (Content-Type: text/event-stream)"
|
|
||||||
request.setHeader(b"content-type", b"text/event-stream; charset=utf-8")
|
|
||||||
ep = EventsProtocol(request)
|
|
||||||
ep.sendEvent(json.dumps(self.welcome), name="welcome")
|
|
||||||
handle = (their_side, their_msgnum, ep)
|
|
||||||
self.event_channels.add(handle)
|
|
||||||
request.notifyFinish().addErrback(self._shutdown, handle)
|
|
||||||
for row in self.db.execute("SELECT * FROM `messages`"
|
for row in self.db.execute("SELECT * FROM `messages`"
|
||||||
" WHERE `channel_id`=?"
|
" WHERE `channel_id`=?"
|
||||||
" ORDER BY `when` ASC",
|
" ORDER BY `when` ASC",
|
||||||
(self.channel_id,)).fetchall():
|
(self.channel_id,)).fetchall():
|
||||||
self.message_added(row["side"], row["msgnum"], row["message"],
|
messages.append({"phase": row["phase"], "body": row["body"]})
|
||||||
channels=[handle])
|
data = {"welcome": self.welcome, "messages": messages}
|
||||||
|
return (json.dumps(data)+"\n").encode("utf-8")
|
||||||
|
|
||||||
|
def render_GET(self, request):
|
||||||
|
if b"text/event-stream" not in (request.getHeader(b"accept") or b""):
|
||||||
|
return self.get_messages(request)
|
||||||
|
request.setHeader(b"content-type", b"text/event-stream; charset=utf-8")
|
||||||
|
ep = EventsProtocol(request)
|
||||||
|
ep.sendEvent(json.dumps(self.welcome), name="welcome")
|
||||||
|
self.event_channels.add(ep)
|
||||||
|
request.notifyFinish().addErrback(lambda f:
|
||||||
|
self.event_channels.discard(ep))
|
||||||
|
for row in self.db.execute("SELECT * FROM `messages`"
|
||||||
|
" WHERE `channel_id`=?"
|
||||||
|
" ORDER BY `when` ASC",
|
||||||
|
(self.channel_id,)).fetchall():
|
||||||
|
data = json.dumps({"phase": row["phase"], "body": row["body"]})
|
||||||
|
ep.sendEvent(data)
|
||||||
return server.NOT_DONE_YET
|
return server.NOT_DONE_YET
|
||||||
|
|
||||||
def _shutdown(self, _, handle):
|
def broadcast_message(self, phase, body):
|
||||||
self.event_channels.discard(handle)
|
data = json.dumps({"phase": phase, "body": body})
|
||||||
|
for ep in self.event_channels:
|
||||||
def message_added(self, msg_side, msg_msgnum, msg_str, channels=None):
|
ep.sendEvent(data)
|
||||||
if channels is None:
|
|
||||||
channels = self.event_channels
|
|
||||||
for (their_side, their_msgnum, their_ep) in channels:
|
|
||||||
if msg_side != their_side and msg_msgnum == their_msgnum:
|
|
||||||
data = json.dumps({ "side": msg_side, "message": msg_str })
|
|
||||||
their_ep.sendEvent(data)
|
|
||||||
|
|
||||||
def render_POST(self, request):
|
def render_POST(self, request):
|
||||||
# rest of URL is: SIDE/(MSGNUM|deallocate)/(post|poll)
|
|
||||||
side = request.postpath[0].decode("utf-8")
|
|
||||||
verb = request.postpath[1].decode("utf-8")
|
|
||||||
|
|
||||||
if verb == "deallocate":
|
|
||||||
deleted = self.relay.maybe_free_child(self.channel_id, side)
|
|
||||||
if deleted:
|
|
||||||
return b"deleted\n"
|
|
||||||
return b"waiting\n"
|
|
||||||
|
|
||||||
if verb not in ("post", "poll"):
|
|
||||||
request.setResponseCode(http.BAD_REQUEST)
|
|
||||||
return b"bad verb, want 'post' or 'poll'\n"
|
|
||||||
|
|
||||||
msgnum = request.postpath[2].decode("utf-8")
|
|
||||||
|
|
||||||
other_messages = []
|
|
||||||
for row in self.db.execute("SELECT `message` FROM `messages`"
|
|
||||||
" WHERE `channel_id`=? AND `side`!=?"
|
|
||||||
" AND `msgnum`=?"
|
|
||||||
" ORDER BY `when` ASC",
|
|
||||||
(self.channel_id, side, msgnum)).fetchall():
|
|
||||||
other_messages.append(row["message"])
|
|
||||||
|
|
||||||
if verb == "post":
|
|
||||||
#data = json.load(request.content, encoding="utf-8")
|
#data = json.load(request.content, encoding="utf-8")
|
||||||
content = request.content.read()
|
content = request.content.read()
|
||||||
data = json.loads(content.decode("utf-8"))
|
data = json.loads(content.decode("utf-8"))
|
||||||
|
|
||||||
|
side = data["side"]
|
||||||
|
phase = data["phase"]
|
||||||
|
if not isinstance(phase, type(u"")):
|
||||||
|
raise TypeError("phase must be string, not %s" % type(phase))
|
||||||
|
body = data["body"]
|
||||||
|
|
||||||
self.db.execute("INSERT INTO `messages`"
|
self.db.execute("INSERT INTO `messages`"
|
||||||
" (`channel_id`, `side`, `msgnum`, `message`, `when`)"
|
" (`channel_id`, `side`, `phase`, `body`, `when`)"
|
||||||
" VALUES (?,?,?,?,?)",
|
" VALUES (?,?,?,?,?)",
|
||||||
(self.channel_id, side, msgnum, data["message"],
|
(self.channel_id, side, phase, body, time.time()))
|
||||||
time.time()))
|
|
||||||
self.db.execute("INSERT INTO `allocations`"
|
self.db.execute("INSERT INTO `allocations`"
|
||||||
" (`channel_id`, `side`)"
|
" (`channel_id`, `side`)"
|
||||||
" VALUES (?,?)",
|
" VALUES (?,?)",
|
||||||
(self.channel_id, side))
|
(self.channel_id, side))
|
||||||
self.db.commit()
|
self.db.commit()
|
||||||
self.message_added(side, msgnum, data["message"])
|
self.broadcast_message(phase, body)
|
||||||
|
return self.get_messages(request)
|
||||||
|
|
||||||
request.setHeader(b"content-type", b"application/json; charset=utf-8")
|
class Deallocator(resource.Resource):
|
||||||
data = {"welcome": self.welcome,
|
def __init__(self, channel_id, relay):
|
||||||
"messages": other_messages}
|
self.channel_id = channel_id
|
||||||
return (json.dumps(data)+"\n").encode("utf-8")
|
self.relay = relay
|
||||||
|
|
||||||
|
def render_POST(self, request):
|
||||||
|
content = request.content.read()
|
||||||
|
data = json.loads(content.decode("utf-8"))
|
||||||
|
side = data["side"]
|
||||||
|
deleted = self.relay.maybe_free_child(self.channel_id, side)
|
||||||
|
resp = {"status": "waiting"}
|
||||||
|
if deleted:
|
||||||
|
resp = {"status": "deleted"}
|
||||||
|
return json.dumps(resp).encode("utf-8")
|
||||||
|
|
||||||
def get_allocated(db):
|
def get_allocated(db):
|
||||||
c = db.execute("SELECT DISTINCT `channel_id` FROM `allocations`")
|
c = db.execute("SELECT DISTINCT `channel_id` FROM `allocations`")
|
||||||
return set([row["channel_id"] for row in c.fetchall()])
|
return set([row["channel_id"] for row in c.fetchall()])
|
||||||
|
|
||||||
class Allocator(resource.Resource):
|
class Allocator(resource.Resource):
|
||||||
isLeaf = True
|
|
||||||
def __init__(self, db, welcome):
|
def __init__(self, db, welcome):
|
||||||
resource.Resource.__init__(self)
|
resource.Resource.__init__(self)
|
||||||
self.db = db
|
self.db = db
|
||||||
|
@ -179,7 +168,11 @@ class Allocator(resource.Resource):
|
||||||
raise ValueError("unable to find a free channel-id")
|
raise ValueError("unable to find a free channel-id")
|
||||||
|
|
||||||
def render_POST(self, request):
|
def render_POST(self, request):
|
||||||
side = request.postpath[0]
|
content = request.content.read()
|
||||||
|
data = json.loads(content.decode("utf-8"))
|
||||||
|
side = data["side"]
|
||||||
|
if not isinstance(side, type(u"")):
|
||||||
|
raise TypeError("side must be string, not '%s'" % type(side))
|
||||||
channel_id = self.allocate_channel_id()
|
channel_id = self.allocate_channel_id()
|
||||||
self.db.execute("INSERT INTO `allocations` VALUES (?,?)",
|
self.db.execute("INSERT INTO `allocations` VALUES (?,?)",
|
||||||
(channel_id, side))
|
(channel_id, side))
|
||||||
|
|
|
@ -1,10 +1,13 @@
|
||||||
import sys
|
from __future__ import print_function
|
||||||
|
import sys, json
|
||||||
import requests
|
import requests
|
||||||
from twisted.trial import unittest
|
from twisted.trial import unittest
|
||||||
from twisted.internet import reactor
|
from twisted.internet import reactor, defer
|
||||||
from twisted.internet.threads import deferToThread
|
from twisted.internet.threads import deferToThread
|
||||||
from twisted.web.client import getPage, Agent, readBody
|
from twisted.web.client import getPage, Agent, readBody
|
||||||
|
from .. import __version__
|
||||||
from .common import ServerBase
|
from .common import ServerBase
|
||||||
|
from ..twisted.eventsource_twisted import EventSource
|
||||||
|
|
||||||
class Reachable(ServerBase, unittest.TestCase):
|
class Reachable(ServerBase, unittest.TestCase):
|
||||||
|
|
||||||
|
@ -47,3 +50,243 @@ class Reachable(ServerBase, unittest.TestCase):
|
||||||
self.failUnlessEqual(res, "Wormhole Relay\n")
|
self.failUnlessEqual(res, "Wormhole Relay\n")
|
||||||
d.addCallback(_got)
|
d.addCallback(_got)
|
||||||
return d
|
return d
|
||||||
|
|
||||||
|
def unjson(data):
|
||||||
|
return json.loads(data.decode("utf-8"))
|
||||||
|
|
||||||
|
class API(ServerBase, unittest.TestCase):
|
||||||
|
def get(self, path, is_json=True):
|
||||||
|
url = (self.relayurl+path).encode("ascii")
|
||||||
|
d = getPage(url)
|
||||||
|
if is_json:
|
||||||
|
d.addCallback(unjson)
|
||||||
|
return d
|
||||||
|
def post(self, path, data):
|
||||||
|
url = (self.relayurl+path).encode("ascii")
|
||||||
|
d = getPage(url, method=b"POST",
|
||||||
|
postdata=json.dumps(data).encode("utf-8"))
|
||||||
|
d.addCallback(unjson)
|
||||||
|
return d
|
||||||
|
|
||||||
|
def check_welcome(self, data):
|
||||||
|
self.failUnlessIn("welcome", data)
|
||||||
|
self.failUnlessEqual(data["welcome"], {"current_version": __version__})
|
||||||
|
|
||||||
|
def test_allocate_1(self):
|
||||||
|
d = self.get("list")
|
||||||
|
def _check_list_1(data):
|
||||||
|
self.check_welcome(data)
|
||||||
|
self.failUnlessEqual(data["channel-ids"], [])
|
||||||
|
d.addCallback(_check_list_1)
|
||||||
|
|
||||||
|
d.addCallback(lambda _: self.post("allocate", {"side": "abc"}))
|
||||||
|
def _allocated(data):
|
||||||
|
self.failUnlessEqual(set(data.keys()),
|
||||||
|
set(["welcome", "channel-id"]))
|
||||||
|
self.failUnlessIsInstance(data["channel-id"], int)
|
||||||
|
self.cid = data["channel-id"]
|
||||||
|
d.addCallback(_allocated)
|
||||||
|
|
||||||
|
d.addCallback(lambda _: self.get("list"))
|
||||||
|
def _check_list_2(data):
|
||||||
|
self.failUnlessEqual(data["channel-ids"], [self.cid])
|
||||||
|
d.addCallback(_check_list_2)
|
||||||
|
|
||||||
|
d.addCallback(lambda _: self.post("%d/deallocate" % self.cid,
|
||||||
|
{"side": "abc"}))
|
||||||
|
def _check_deallocate(res):
|
||||||
|
self.failUnlessEqual(res["status"], "deleted")
|
||||||
|
d.addCallback(_check_deallocate)
|
||||||
|
|
||||||
|
d.addCallback(lambda _: self.get("list"))
|
||||||
|
def _check_list_3(data):
|
||||||
|
self.failUnlessEqual(data["channel-ids"], [])
|
||||||
|
d.addCallback(_check_list_3)
|
||||||
|
|
||||||
|
return d
|
||||||
|
|
||||||
|
def test_allocate_2(self):
|
||||||
|
d = self.post("allocate", {"side": "abc"})
|
||||||
|
def _allocated(data):
|
||||||
|
self.cid = data["channel-id"]
|
||||||
|
d.addCallback(_allocated)
|
||||||
|
|
||||||
|
# second caller increases the number of known sides to 2
|
||||||
|
d.addCallback(lambda _: self.post("%d" % self.cid,
|
||||||
|
{"side": "def",
|
||||||
|
"phase": "1",
|
||||||
|
"body": ""}))
|
||||||
|
|
||||||
|
d.addCallback(lambda _: self.get("list"))
|
||||||
|
d.addCallback(lambda data:
|
||||||
|
self.failUnlessEqual(data["channel-ids"], [self.cid]))
|
||||||
|
|
||||||
|
d.addCallback(lambda _: self.post("%d/deallocate" % self.cid,
|
||||||
|
{"side": "abc"}))
|
||||||
|
d.addCallback(lambda res:
|
||||||
|
self.failUnlessEqual(res["status"], "waiting"))
|
||||||
|
|
||||||
|
d.addCallback(lambda _: self.post("%d/deallocate" % self.cid,
|
||||||
|
{"side": "NOT"}))
|
||||||
|
d.addCallback(lambda res:
|
||||||
|
self.failUnlessEqual(res["status"], "waiting"))
|
||||||
|
|
||||||
|
d.addCallback(lambda _: self.post("%d/deallocate" % self.cid,
|
||||||
|
{"side": "def"}))
|
||||||
|
d.addCallback(lambda res:
|
||||||
|
self.failUnlessEqual(res["status"], "deleted"))
|
||||||
|
|
||||||
|
d.addCallback(lambda _: self.get("list"))
|
||||||
|
d.addCallback(lambda data:
|
||||||
|
self.failUnlessEqual(data["channel-ids"], []))
|
||||||
|
|
||||||
|
return d
|
||||||
|
|
||||||
|
def add_message(self, message, side="abc", phase="1"):
|
||||||
|
return self.post(str(self.cid), {"side": side, "phase": phase,
|
||||||
|
"body": message})
|
||||||
|
|
||||||
|
def parse_messages(self, messages):
|
||||||
|
out = set()
|
||||||
|
for m in messages:
|
||||||
|
self.failUnlessEqual(sorted(m.keys()), sorted(["phase", "body"]))
|
||||||
|
self.failUnlessIsInstance(m["phase"], type(u""))
|
||||||
|
self.failUnlessIsInstance(m["body"], type(u""))
|
||||||
|
out.add( (m["phase"], m["body"]) )
|
||||||
|
return out
|
||||||
|
|
||||||
|
def check_messages(self, one, two):
|
||||||
|
# Comparing lists-of-dicts is non-trivial in python3 because we can
|
||||||
|
# neither sort them (dicts are uncomparable), nor turn them into sets
|
||||||
|
# (dicts are unhashable). This is close enough.
|
||||||
|
self.failUnlessEqual(len(one), len(two), (one,two))
|
||||||
|
for d in one:
|
||||||
|
self.failUnlessIn(d, two)
|
||||||
|
|
||||||
|
def test_messages(self):
|
||||||
|
d = self.post("allocate", {"side": "abc"})
|
||||||
|
def _allocated(data):
|
||||||
|
self.cid = data["channel-id"]
|
||||||
|
d.addCallback(_allocated)
|
||||||
|
|
||||||
|
d.addCallback(lambda _: self.add_message("msg1A"))
|
||||||
|
def _check1(data):
|
||||||
|
self.check_welcome(data)
|
||||||
|
self.failUnlessEqual(data["messages"],
|
||||||
|
[{"phase": "1", "body": "msg1A"}])
|
||||||
|
d.addCallback(_check1)
|
||||||
|
d.addCallback(lambda _: self.add_message("msg1B", side="def"))
|
||||||
|
def _check2(data):
|
||||||
|
self.check_welcome(data)
|
||||||
|
self.failUnlessEqual(self.parse_messages(data["messages"]),
|
||||||
|
set([("1", "msg1A"),
|
||||||
|
("1", "msg1B")]))
|
||||||
|
d.addCallback(_check2)
|
||||||
|
|
||||||
|
# adding a duplicate message is not an error, is ignored by clients
|
||||||
|
d.addCallback(lambda _: self.add_message("msg1B", side="def"))
|
||||||
|
def _check3(data):
|
||||||
|
self.check_welcome(data)
|
||||||
|
self.failUnlessEqual(self.parse_messages(data["messages"]),
|
||||||
|
set([("1", "msg1A"),
|
||||||
|
("1", "msg1B")]))
|
||||||
|
d.addCallback(_check3)
|
||||||
|
|
||||||
|
d.addCallback(lambda _: self.add_message("msg2A", side="abc",
|
||||||
|
phase="2"))
|
||||||
|
def _check4(data):
|
||||||
|
self.check_welcome(data)
|
||||||
|
self.failUnlessEqual(self.parse_messages(data["messages"]),
|
||||||
|
set([("1", "msg1A"),
|
||||||
|
("1", "msg1B"),
|
||||||
|
("2", "msg2A"),
|
||||||
|
]))
|
||||||
|
d.addCallback(_check4)
|
||||||
|
|
||||||
|
return d
|
||||||
|
|
||||||
|
def test_eventsource(self):
|
||||||
|
if sys.version_info[0] >= 3:
|
||||||
|
raise unittest.SkipTest("twisted vs py3")
|
||||||
|
|
||||||
|
d = self.post("allocate", {"side": "abc"})
|
||||||
|
def _allocated(data):
|
||||||
|
self.cid = data["channel-id"]
|
||||||
|
url = (self.relayurl+str(self.cid)).encode("utf-8")
|
||||||
|
self.o = OneEventAtATime(url, parser=json.loads)
|
||||||
|
return self.o.wait_for_connection()
|
||||||
|
d.addCallback(_allocated)
|
||||||
|
d.addCallback(lambda _: self.o.wait_for_next_event())
|
||||||
|
def _check_welcome(ev):
|
||||||
|
eventtype, data = ev
|
||||||
|
self.failUnlessEqual(eventtype, "welcome")
|
||||||
|
self.failUnlessEqual(data, {"current_version": __version__})
|
||||||
|
d.addCallback(_check_welcome)
|
||||||
|
d.addCallback(lambda _: self.add_message("msg1A"))
|
||||||
|
d.addCallback(lambda _: self.o.wait_for_next_event())
|
||||||
|
def _check_msg1(ev):
|
||||||
|
eventtype, data = ev
|
||||||
|
self.failUnlessEqual(eventtype, "message")
|
||||||
|
self.failUnlessEqual(data, {"phase": "1", "body": "msg1A"})
|
||||||
|
d.addCallback(_check_msg1)
|
||||||
|
|
||||||
|
d.addCallback(lambda _: self.add_message("msg1B"))
|
||||||
|
d.addCallback(lambda _: self.add_message("msg2A", phase="2"))
|
||||||
|
d.addCallback(lambda _: self.o.wait_for_next_event())
|
||||||
|
def _check_msg2(ev):
|
||||||
|
eventtype, data = ev
|
||||||
|
self.failUnlessEqual(eventtype, "message")
|
||||||
|
self.failUnlessEqual(data, {"phase": "1", "body": "msg1B"})
|
||||||
|
d.addCallback(_check_msg2)
|
||||||
|
d.addCallback(lambda _: self.o.wait_for_next_event())
|
||||||
|
def _check_msg3(ev):
|
||||||
|
eventtype, data = ev
|
||||||
|
self.failUnlessEqual(eventtype, "message")
|
||||||
|
self.failUnlessEqual(data, {"phase": "2", "body": "msg2A"})
|
||||||
|
d.addCallback(_check_msg3)
|
||||||
|
|
||||||
|
d.addCallback(lambda _: self.o.close())
|
||||||
|
d.addCallback(lambda _: self.o.wait_for_disconnection())
|
||||||
|
return d
|
||||||
|
|
||||||
|
class OneEventAtATime:
|
||||||
|
def __init__(self, url, parser=lambda e: e):
|
||||||
|
self.parser = parser
|
||||||
|
self.d = None
|
||||||
|
self.connected_d = defer.Deferred()
|
||||||
|
self.disconnected_d = defer.Deferred()
|
||||||
|
self.events = []
|
||||||
|
self.es = EventSource(url, self.handler, when_connected=self.connected)
|
||||||
|
d = self.es.start()
|
||||||
|
d.addBoth(self.disconnected)
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
self.es.cancel()
|
||||||
|
|
||||||
|
def wait_for_next_event(self):
|
||||||
|
assert not self.d
|
||||||
|
if self.events:
|
||||||
|
event = self.events.pop(0)
|
||||||
|
return defer.succeed(event)
|
||||||
|
self.d = defer.Deferred()
|
||||||
|
return self.d
|
||||||
|
|
||||||
|
def handler(self, eventtype, data):
|
||||||
|
event = (eventtype, self.parser(data))
|
||||||
|
if self.d:
|
||||||
|
assert not self.events
|
||||||
|
d,self.d = self.d,None
|
||||||
|
d.callback(event)
|
||||||
|
return
|
||||||
|
self.events.append(event)
|
||||||
|
|
||||||
|
def wait_for_connection(self):
|
||||||
|
return self.connected_d
|
||||||
|
def connected(self):
|
||||||
|
self.connected_d.callback(None)
|
||||||
|
|
||||||
|
def wait_for_disconnection(self):
|
||||||
|
return self.disconnected_d
|
||||||
|
def disconnected(self, why):
|
||||||
|
self.disconnected_d.callback((why,))
|
||||||
|
|
||||||
|
|
|
@ -93,6 +93,7 @@ class EventSource: # TODO: service.Service
|
||||||
def start(self):
|
def start(self):
|
||||||
assert not self.started, "single-use"
|
assert not self.started, "single-use"
|
||||||
self.started = True
|
self.started = True
|
||||||
|
assert self.url
|
||||||
d = self.agent.request("GET", self.url,
|
d = self.agent.request("GET", self.url,
|
||||||
Headers({"accept": ["text/event-stream"]}))
|
Headers({"accept": ["text/event-stream"]}))
|
||||||
d.addCallback(self._connected)
|
d.addCallback(self._connected)
|
||||||
|
@ -154,14 +155,13 @@ class Connector:
|
||||||
|
|
||||||
class ReconnectingEventSource(service.MultiService,
|
class ReconnectingEventSource(service.MultiService,
|
||||||
protocol.ReconnectingClientFactory):
|
protocol.ReconnectingClientFactory):
|
||||||
def __init__(self, baseurl, connection_starting, handler, agent=None):
|
def __init__(self, url, handler, agent=None):
|
||||||
service.MultiService.__init__(self)
|
service.MultiService.__init__(self)
|
||||||
# we don't use any of the basic Factory/ClientFactory methods of
|
# we don't use any of the basic Factory/ClientFactory methods of
|
||||||
# this, just the ReconnectingClientFactory.retry, stopTrying, and
|
# this, just the ReconnectingClientFactory.retry, stopTrying, and
|
||||||
# resetDelay methods.
|
# resetDelay methods.
|
||||||
|
|
||||||
self.baseurl = baseurl
|
self.url = url
|
||||||
self.connection_starting = connection_starting
|
|
||||||
self.handler = handler
|
self.handler = handler
|
||||||
self.agent = agent
|
self.agent = agent
|
||||||
# IService provides self.running, toggled by {start,stop}Service.
|
# IService provides self.running, toggled by {start,stop}Service.
|
||||||
|
@ -201,8 +201,7 @@ class ReconnectingEventSource(service.MultiService,
|
||||||
if not (self.active and self.running):
|
if not (self.active and self.running):
|
||||||
return
|
return
|
||||||
self.continueTrying = True
|
self.continueTrying = True
|
||||||
url = self.connection_starting()
|
self.es = EventSource(self.url, self.handler, self.resetDelay,
|
||||||
self.es = EventSource(url, self.handler, self.resetDelay,
|
|
||||||
agent=self.agent)
|
agent=self.agent)
|
||||||
d = self.es.start()
|
d = self.es.start()
|
||||||
d.addBoth(self._stopped)
|
d.addBoth(self._stopped)
|
||||||
|
|
|
@ -13,8 +13,7 @@ from spake2 import SPAKE2_Symmetric
|
||||||
from .eventsource_twisted import ReconnectingEventSource
|
from .eventsource_twisted import ReconnectingEventSource
|
||||||
from .. import __version__
|
from .. import __version__
|
||||||
from .. import codes
|
from .. import codes
|
||||||
from ..errors import (ServerError, WrongPasswordError,
|
from ..errors import ServerError, WrongPasswordError, UsageError
|
||||||
ReflectionAttack, UsageError)
|
|
||||||
from ..util.hkdf import HKDF
|
from ..util.hkdf import HKDF
|
||||||
|
|
||||||
@implementer(IBodyProducer)
|
@implementer(IBodyProducer)
|
||||||
|
@ -46,12 +45,9 @@ class Wormhole:
|
||||||
self.code = None
|
self.code = None
|
||||||
self.key = None
|
self.key = None
|
||||||
self._started_get_code = False
|
self._started_get_code = False
|
||||||
|
self._channel_url = None
|
||||||
def _url(self, verb, msgnum=None):
|
self._messages = set() # (phase,body) , body is bytes
|
||||||
url = "%s%d/%s/%s" % (self.relay, self.channel_id, self.side, verb)
|
self._sent_messages = set() # (phase,body)
|
||||||
if msgnum is not None:
|
|
||||||
url += "/" + msgnum
|
|
||||||
return url
|
|
||||||
|
|
||||||
def handle_welcome(self, welcome):
|
def handle_welcome(self, welcome):
|
||||||
if ("motd" in welcome and
|
if ("motd" in welcome and
|
||||||
|
@ -75,14 +71,10 @@ class Wormhole:
|
||||||
if "error" in welcome:
|
if "error" in welcome:
|
||||||
raise ServerError(welcome["error"], self.relay)
|
raise ServerError(welcome["error"], self.relay)
|
||||||
|
|
||||||
def _post_json(self, url, post_json=None):
|
def _post_json(self, url, post_json):
|
||||||
# POST to a URL, parsing the response as JSON. Optionally include a
|
# POST a JSON body to a URL, parsing the response as JSON
|
||||||
# JSON request body.
|
|
||||||
p = None
|
|
||||||
if post_json:
|
|
||||||
data = json.dumps(post_json).encode("utf-8")
|
data = json.dumps(post_json).encode("utf-8")
|
||||||
p = DataProducer(data)
|
d = self.agent.request("POST", url, bodyProducer=DataProducer(data))
|
||||||
d = self.agent.request("POST", url, bodyProducer=p)
|
|
||||||
def _check_error(resp):
|
def _check_error(resp):
|
||||||
if resp.code != 200:
|
if resp.code != 200:
|
||||||
raise web_error.Error(resp.code, resp.phrase)
|
raise web_error.Error(resp.code, resp.phrase)
|
||||||
|
@ -93,8 +85,8 @@ class Wormhole:
|
||||||
return d
|
return d
|
||||||
|
|
||||||
def _allocate_channel(self):
|
def _allocate_channel(self):
|
||||||
url = self.relay + "allocate/%s" % self.side
|
url = self.relay + "allocate"
|
||||||
d = self._post_json(url)
|
d = self._post_json(url, {"side": self.side})
|
||||||
def _got_channel(data):
|
def _got_channel(data):
|
||||||
if "welcome" in data:
|
if "welcome" in data:
|
||||||
self.handle_welcome(data["welcome"])
|
self.handle_welcome(data["welcome"])
|
||||||
|
@ -131,6 +123,7 @@ class Wormhole:
|
||||||
if not mo:
|
if not mo:
|
||||||
raise ValueError("code (%s) must start with NN-" % code)
|
raise ValueError("code (%s) must start with NN-" % code)
|
||||||
self.channel_id = int(mo.group(1))
|
self.channel_id = int(mo.group(1))
|
||||||
|
self._channel_url = "%s%d" % (self.relay, self.channel_id)
|
||||||
self.code = code
|
self.code = code
|
||||||
|
|
||||||
def _start(self):
|
def _start(self):
|
||||||
|
@ -164,36 +157,56 @@ class Wormhole:
|
||||||
self.msg1 = d["msg1"].decode("hex")
|
self.msg1 = d["msg1"].decode("hex")
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def _post_message(self, url, msg):
|
def _add_inbound_messages(self, messages):
|
||||||
|
for msg in messages:
|
||||||
|
phase = msg["phase"]
|
||||||
|
body = unhexlify(msg["body"].encode("ascii"))
|
||||||
|
self._messages.add( (phase, body) )
|
||||||
|
|
||||||
|
def _find_inbound_message(self, phase):
|
||||||
|
for (their_phase,body) in self._messages - self._sent_messages:
|
||||||
|
if their_phase == phase:
|
||||||
|
return body
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _send_message(self, phase, msg):
|
||||||
# TODO: retry on failure, with exponential backoff. We're guarding
|
# TODO: retry on failure, with exponential backoff. We're guarding
|
||||||
# against the rendezvous server being temporarily offline.
|
# against the rendezvous server being temporarily offline.
|
||||||
|
if not isinstance(phase, type(u"")): raise UsageError(type(phase))
|
||||||
if not isinstance(msg, type(b"")): raise UsageError(type(msg))
|
if not isinstance(msg, type(b"")): raise UsageError(type(msg))
|
||||||
d = self._post_json(url, {"message": hexlify(msg).decode("ascii")})
|
self._sent_messages.add( (phase,msg) )
|
||||||
d.addCallback(lambda resp: resp["messages"]) # other_msgs
|
payload = {"side": self.side,
|
||||||
|
"phase": phase,
|
||||||
|
"body": hexlify(msg).decode("ascii")}
|
||||||
|
d = self._post_json(self._channel_url, payload)
|
||||||
|
d.addCallback(lambda resp: self._add_inbound_messages(resp["messages"]))
|
||||||
return d
|
return d
|
||||||
|
|
||||||
def _get_message(self, old_msgs, verb, msgnum):
|
def _get_message(self, phase):
|
||||||
# fire with a bytestring of the first message that matches
|
# fire with a bytestring of the first message for 'phase' that wasn't
|
||||||
# verb/msgnum, which either came from old_msgs, or from an
|
# one of ours. It will either come from previously-received messages,
|
||||||
# EventSource that we attached to the corresponding URL
|
# or from an EventSource that we attach to the corresponding URL
|
||||||
if old_msgs:
|
body = self._find_inbound_message(phase)
|
||||||
msg = unhexlify(old_msgs[0].encode("ascii"))
|
if body is not None:
|
||||||
return defer.succeed(msg)
|
return defer.succeed(body)
|
||||||
d = defer.Deferred()
|
d = defer.Deferred()
|
||||||
msgs = []
|
msgs = []
|
||||||
def _handle(name, data):
|
def _handle(name, data):
|
||||||
if name == "welcome":
|
if name == "welcome":
|
||||||
self.handle_welcome(json.loads(data))
|
self.handle_welcome(json.loads(data))
|
||||||
if name == "message":
|
if name == "message":
|
||||||
msgs.append(json.loads(data)["message"])
|
self._add_inbound_messages([json.loads(data)])
|
||||||
|
body = self._find_inbound_message(phase)
|
||||||
|
if body is not None and not msgs:
|
||||||
|
msgs.append(body)
|
||||||
d.callback(None)
|
d.callback(None)
|
||||||
es = ReconnectingEventSource(None, lambda: self._url(verb, msgnum),
|
# TODO: use agent=self.agent
|
||||||
_handle)#, agent=self.agent)
|
es = ReconnectingEventSource(self._channel_url, _handle)
|
||||||
es.startService() # TODO: .setServiceParent(self)
|
es.startService() # TODO: .setServiceParent(self)
|
||||||
es.activate()
|
es.activate()
|
||||||
d.addCallback(lambda _: es.deactivate())
|
d.addCallback(lambda _: es.deactivate())
|
||||||
d.addCallback(lambda _: es.stopService())
|
d.addCallback(lambda _: es.stopService())
|
||||||
d.addCallback(lambda _: unhexlify(msgs[0].encode("ascii")))
|
d.addCallback(lambda _: msgs[0])
|
||||||
return d
|
return d
|
||||||
|
|
||||||
def derive_key(self, purpose, length=SecretBox.KEY_SIZE):
|
def derive_key(self, purpose, length=SecretBox.KEY_SIZE):
|
||||||
|
@ -224,8 +237,8 @@ class Wormhole:
|
||||||
# TODO: prevent multiple invocation
|
# TODO: prevent multiple invocation
|
||||||
if self.key:
|
if self.key:
|
||||||
return defer.succeed(self.key)
|
return defer.succeed(self.key)
|
||||||
d = self._post_message(self._url("post", "pake"), self.msg1)
|
d = self._send_message(u"pake", self.msg1)
|
||||||
d.addCallback(lambda msgs: self._get_message(msgs, "poll", "pake"))
|
d.addCallback(lambda _: self._get_message(u"pake"))
|
||||||
def _got_pake(pake_msg):
|
def _got_pake(pake_msg):
|
||||||
key = self.sp.finish(pake_msg)
|
key = self.sp.finish(pake_msg)
|
||||||
self.key = key
|
self.key = key
|
||||||
|
@ -256,12 +269,12 @@ class Wormhole:
|
||||||
data_key = self.derive_key(b"data-key")
|
data_key = self.derive_key(b"data-key")
|
||||||
|
|
||||||
outbound_encrypted = self._encrypt_data(data_key, outbound_data)
|
outbound_encrypted = self._encrypt_data(data_key, outbound_data)
|
||||||
d = self._post_message(self._url("post", "data"), outbound_encrypted)
|
d = self._send_message(u"data", outbound_encrypted)
|
||||||
|
|
||||||
d.addCallback(lambda msgs: self._get_message(msgs, "poll", "data"))
|
d.addCallback(lambda _: self._get_message(u"data"))
|
||||||
def _got_data(inbound_encrypted):
|
def _got_data(inbound_encrypted):
|
||||||
if inbound_encrypted == outbound_encrypted:
|
#if inbound_encrypted == outbound_encrypted:
|
||||||
raise ReflectionAttack
|
# raise ReflectionAttack
|
||||||
try:
|
try:
|
||||||
inbound_data = self._decrypt_data(data_key, inbound_encrypted)
|
inbound_data = self._decrypt_data(data_key, inbound_encrypted)
|
||||||
return inbound_data
|
return inbound_data
|
||||||
|
@ -272,6 +285,7 @@ class Wormhole:
|
||||||
|
|
||||||
def _deallocate(self, res):
|
def _deallocate(self, res):
|
||||||
# only try once, no retries
|
# only try once, no retries
|
||||||
d = self.agent.request("POST", self._url("deallocate"))
|
d = self._post_json(self._channel_url+"/deallocate",
|
||||||
|
{"side": self.side})
|
||||||
d.addBoth(lambda _: res) # ignore POST failure, pass-through result
|
d.addBoth(lambda _: res) # ignore POST failure, pass-through result
|
||||||
return d
|
return d
|
||||||
|
|
Loading…
Reference in New Issue
Block a user