From 13a02df636cdfbcf156e30430e17663b63335e61 Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Wed, 11 Feb 2015 01:05:11 -0800 Subject: [PATCH] implement relay, fix transcribe.py to use it properly --- src/wormhole/relay.py | 120 +++++++++++++++++++++++++++++++++++++ src/wormhole/transcribe.py | 65 +++++++++----------- 2 files changed, 148 insertions(+), 37 deletions(-) create mode 100644 src/wormhole/relay.py diff --git a/src/wormhole/relay.py b/src/wormhole/relay.py new file mode 100644 index 0000000..8e62e28 --- /dev/null +++ b/src/wormhole/relay.py @@ -0,0 +1,120 @@ +import re, json +from collections import defaultdict +from twisted.python import log +from twisted.application import strports, service +from twisted.web import server, static, resource, http + +class Channel(resource.Resource): + isLeaf = True + + # POST /CHANNEL-ID/SIDE/pake/post {message: STR} -> {messages: [STR..]} + # POST /CHANNEL-ID/SIDE/pake/poll -> {messages: [STR..]} + # POST /CHANNEL-ID/SIDE/data/post {message: STR} -> {messages: [STR..]} + # POST /CHANNEL-ID/SIDE/data/poll -> {messages: [STR..]} + # POST /CHANNEL-ID/SIDE/deallocate -> waiting | deleted + + def __init__(self, channel_id, relay): + resource.Resource.__init__(self) + self.channel_id = channel_id + self.relay = relay + self.sides = set() + self.messages = {"pake": defaultdict(list), # side -> [strings] + "data": defaultdict(list), # side -> [strings] + } + + def render_POST(self, request): + side = request.postpath[0] + self.sides.add(side) + which = request.postpath[1] + + if which == "deallocate": + self.sides.remove(side) + if self.sides: + return "waiting\n" + self.relay.free_child(self.channel_id) + return "deleted\n" + elif which in ("pake", "data"): + all_messages = self.messages[which] + messages = all_messages[side] + other_messages = [] + for other_side, other_msgs in all_messages.items(): + if other_side != side: + other_messages.extend(other_msgs) + else: + request.setResponseCode(http.BAD_REQUEST) + return "bad command, want 'pake' or 'data' or 'deallocate'\n" + + verb = request.postpath[2] + if verb not in ("post", "poll"): + request.setResponseCode(http.BAD_REQUEST) + return "bad verb, want 'post' or 'poll'\n" + + if verb == "post": + data = json.load(request.content) + messages.append(data["message"]) + + request.setHeader("content-type", "application/json; charset=utf-8") + return json.dumps({"messages": other_messages})+"\n" + +class Allocated(resource.Resource): + def __init__(self, channel_id): + resource.Resource.__init__(self) + self.channel_id = channel_id + def render_POST(self, request): + request.setHeader("content-type", "application/json; charset=utf-8") + return json.dumps({"channel-id": self.channel_id})+"\n" + + +class Relay(resource.Resource): + def __init__(self): + resource.Resource.__init__(self) + self.channels = {} + self.next_channel = 1 + + def getChild(self, path, request): + if path == "allocate": + # be more clever later. Rotate through 1-99 unless they're all + # full, then rotate through 1-999, etc. + channel_id = self.next_channel + self.next_channel += 1 + self.channels[channel_id] = Channel(channel_id, self) + log.msg("allocated %d, now have %d channels" % + (channel_id, len(self.channels))) + return Allocated(channel_id) + if not re.search(r'^\d+$', path): + return resource.ErrorPage(http.BAD_REQUEST, + "invalid channel id", + "invalid channel id") + channel_id = int(path) + if not channel_id in self.channels: + return resource.ErrorPage(http.NOT_FOUND, + "invalid channel id", + "invalid channel id") + return self.channels[channel_id] + + def free_child(self, channel_id): + self.channels.pop(channel_id) + log.msg("freed %d, now have %d channels" % + (channel_id, len(self.channels))) + +class Root(resource.Resource): + # child_FOO is a nevow thing, not a twisted.web.resource thing + def __init__(self): + resource.Resource.__init__(self) + self.putChild("", static.Data("Wormhole Relay\n", "text/plain")) + +class RelayServer(service.MultiService): + def __init__(self, listenport): + service.MultiService.__init__(self) + self.root = Root() + site = server.Site(self.root) + self.port_service = strports.service(listenport, site) + self.port_service.setServiceParent(self) + self.relay = Relay() # for tests + self.root.putChild("relay", self.relay) + + def get_root(self): + return self.root + +application = service.Application("foo") +RelayServer("tcp:8009").setServiceParent(application) diff --git a/src/wormhole/transcribe.py b/src/wormhole/transcribe.py index 79521b4..6f7215c 100644 --- a/src/wormhole/transcribe.py +++ b/src/wormhole/transcribe.py @@ -10,12 +10,12 @@ MINUTE = 60*SECOND class Timeout(Exception): pass -# POST /allocate -> {channel-id: INT} -# POST /pake/post/CHANNEL-ID {side: STR, message: STR} -> {messages: [STR..]} -# POST /pake/poll/CHANNEL-ID {side: STR} -> {messages: [STR..]} -# POST /data/post/CHANNEL-ID {side: STR, message: STR} -> {messages: [STR..]} -# POST /data/poll/CHANNEL-ID {side: STR} -> {messages: [STR..]} -# POST /deallocate/CHANNEL-ID {side: STR} -> waiting | ok +# POST /allocate -> {channel-id: INT} +# POST /CHANNEL-ID/SIDE/pake/post {message: STR} -> {messages: [STR..]} +# POST /CHANNEL-ID/SIDE/pake/poll -> {messages: [STR..]} +# POST /CHANNEL-ID/SIDE/data/post {message: STR} -> {messages: [STR..]} +# POST /CHANNEL-ID/SIDE/data/poll -> {messages: [STR..]} +# POST /CHANNEL-ID/SIDE/deallocate -> waiting | deleted class Initiator: def __init__(self, appid, data, relay=RELAY): @@ -33,24 +33,22 @@ class Initiator: r = requests.post(self.relay + "allocate", data="{}") r.raise_for_status() self.channel_id = r.json()["channel-id"] - self.code = codes.make_code(self.channel_id) + self.code = make_code(self.channel_id) self.sp = SPAKE2_A(self.code.encode("ascii"), idA=self.appid+":Initiator", idB=self.appid+":Receiver") msg = self.sp.start() - post_url = self.relay + "pake/post/%d" % self.channel_id - post_data = {"side": self.side, - "message": hexlify(msg).decode("ascii")} + post_url = self.relay + "pake/post/%d/%s" % (self.channel_id, self.side) + post_data = {"message": hexlify(msg).decode("ascii")} r = requests.post(post_url, data=json.dumps(post_data)) r.raise_for_status() return self.code def get_data(self): # poll for PAKE response - pake_url = self.relay + "pake/poll/%d" % self.channel_id - post_data = json.dumps({"side": self.side}) + pake_url = self.relay + "pake/poll/%d/%s" % (self.channel_id, self.side) while True: - r = requests.post(pake_url, data=post_data) + r = requests.post(pake_url, data="{}") r.raise_for_status() msgs = r.json()["messages"] if msgs: @@ -62,17 +60,15 @@ class Initiator: self.key = self.sp.finish(pake_msg) # post encrypted data - post_url = self.relay + "data/post/%d" % self.channel_id - post_data = json.dumps({"side": self.side, - "message": hexlify(self.data).decode("ascii")}) + post_url = self.relay + "data/post/%d/%s" % (self.channel_id, self.side) + post_data = json.dumps({"message": hexlify(self.data).decode("ascii")}) r = requests.post(post_url, data=post_data) r.raise_for_status() # poll for data message - data_url = self.relay + "data/poll/%d" % self.channel_id - post_data = json.dumps({"side": self.side}) + data_url = self.relay + "data/poll/%d/%s" % (self.channel_id, self.side) while True: - r = requests.post(data_url, data=post_data) + r = requests.post(data_url, data="{}") r.raise_for_status() msgs = r.json()["messages"] if msgs: @@ -83,9 +79,8 @@ class Initiator: data = unhexlify(msgs[0].encode("ascii")) # deallocate channel - deallocate_url = self.relay + "deallocate/%s" % self.channel_id - post_data = json.dumps({"side": self.side}) - r = requests.post(deallocate, data=post_data) + deallocate_url = self.relay + "deallocate/%s/%s" % (self.channel_id, self.side) + r = requests.post(deallocate_url, data="{}") r.raise_for_status() return data @@ -109,17 +104,15 @@ class Receiver: def get_data(self): # post PAKE message msg = self.sp.start() - post_url = self.relay + "pake/post/%d" % self.channel_id - post_data = {"side": self.side, - "message": hexlify(msg).decode("ascii")} + post_url = self.relay + "pake/post/%d/%s" % (self.channel_id, self.side) + post_data = {"message": hexlify(msg).decode("ascii")} r = requests.post(post_url, data=json.dumps(post_data)) r.raise_for_status() # poll for PAKE response - pake_url = self.relay + "pake/poll/%d" % self.channel_id - post_data = json.dumps({"side": self.side}) + pake_url = self.relay + "pake/poll/%d/%s" % (self.channel_id, self.side) while True: - r = requests.post(pake_url, data=post_data) + r = requests.post(pake_url, data="{}") r.raise_for_status() msgs = r.json()["messages"] if msgs: @@ -131,17 +124,15 @@ class Receiver: self.key = self.sp.finish(pake_msg) # post data message - post_url = self.relay + "data/post/%d" % self.channel_id - post_data = json.dumps({"side": self.side, - "message": hexlify(self.data).decode("ascii")}) + post_url = self.relay + "data/post/%d/%s" % (self.channel_id, self.side) + post_data = json.dumps({"message": hexlify(self.data).decode("ascii")}) r = requests.post(post_url, data=post_data) r.raise_for_status() # poll for data message - data_url = self.relay + "data/poll/%d" % self.channel_id - post_data = json.dumps({"side": self.side}) + data_url = self.relay + "data/poll/%d/%s" % (self.channel_id, self.side) while True: - r = requests.post(data_url, data=post_data) + r = requests.post(data_url, data="{}") r.raise_for_status() msgs = r.json()["messages"] if msgs: @@ -152,9 +143,9 @@ class Receiver: data = unhexlify(msgs[0].encode("ascii")) # deallocate channel - deallocate_url = self.relay + "deallocate/%s" % self.channel_id - post_data = json.dumps({"side": self.side}) - r = requests.post(deallocate, data=post_data) + deallocate_url = self.relay + "deallocate/%s/%s" % (self.channel_id, + self.side) + r = requests.post(deallocate_url, data="{}") r.raise_for_status() return data