diff --git a/src/wormhole/blocking/transcribe.py b/src/wormhole/blocking/transcribe.py index 0a97053..420d513 100644 --- a/src/wormhole/blocking/transcribe.py +++ b/src/wormhole/blocking/transcribe.py @@ -1,5 +1,6 @@ from __future__ import print_function import os, sys, time, re, requests, json, unicodedata +from six.moves.urllib_parse import urlencode from binascii import hexlify, unhexlify from spake2 import SPAKE2_Symmetric from nacl.secret import SecretBox @@ -18,19 +19,21 @@ MINUTE = 60*SECOND def to_bytes(u): return unicodedata.normalize("NFC", u).encode("utf-8") -# relay URLs are: -# GET /list -> {channelids: [INT..]} -# POST /allocate {side: SIDE} -> {channelid: INT} -# these return all messages (base64) for CID= : -# POST /CID {side:, phase:, body:} -> {messages: [{phase:, body:}..]} -# GET /CID (no-eventsource) -> {messages: [{phase:, body:}..]} -# GET /CID (eventsource) -> {phase:, body:}.. -# POST /CID/deallocate {side: SIDE} -> {status: waiting | deleted} +# relay URLs are as follows: (MESSAGES=[{phase:,body:}..]) +# GET /list?appid= -> {channelids: [INT..]} +# POST /allocate {appid:,side:} -> {channelid: INT} +# these return all messages (base64) for appid=/channelid= : +# POST /add {appid:,channelid:,side:,phase:,body:} -> {messages: MESSAGES} +# GET /get?appid=&channelid= (no-eventsource) -> {messages: MESSAGES} +# GET /get?appid=&channelid= (eventsource) -> {phase:, body:}.. +# POST /deallocate {appid:,channelid:,side:} -> {status: waiting | deleted} # all JSON responses include a "welcome:{..}" key class Channel: - def __init__(self, relay_url, channelid, side, handle_welcome): - self._channel_url = u"%s%d" % (relay_url, channelid) + def __init__(self, relay_url, appid, channelid, side, handle_welcome): + self._relay_url = relay_url + self._appid = appid + self._channelid = channelid self._side = side self._handle_welcome = handle_welcome self._messages = set() # (phase,body) , body is bytes @@ -57,11 +60,13 @@ class Channel: if not isinstance(phase, type(u"")): raise UsageError(type(phase)) if not isinstance(msg, type(b"")): raise UsageError(type(msg)) self._sent_messages.add( (phase,msg) ) - payload = {"side": self._side, + payload = {"appid": self._appid, + "channelid": self._channelid, + "side": self._side, "phase": phase, "body": hexlify(msg).decode("ascii")} data = json.dumps(payload).encode("utf-8") - r = requests.post(self._channel_url, data=data) + r = requests.post(self._relay_url+"add", data=data) r.raise_for_status() resp = r.json() self._add_inbound_messages(resp["messages"]) @@ -80,7 +85,10 @@ class Channel: remaining = self._started + self._timeout - time.time() if remaining < 0: return Timeout - f = EventSourceFollower(self._channel_url, remaining) + queryargs = urlencode([("appid", self._appid), + ("channelid", self._channelid)]) + f = EventSourceFollower(self._relay_url+"get?%s" % queryargs, + remaining) # we loop here until the connection is lost, or we see the # message we want for (eventtype, data) in f.iter_events(): @@ -98,25 +106,30 @@ class Channel: def deallocate(self): # only try once, no retries - data = json.dumps({"side": self._side}).encode("utf-8") - requests.post(self._channel_url+"/deallocate", data=data) + data = json.dumps({"appid": self._appid, + "channelid": self._channelid, + "side": self._side}).encode("utf-8") + requests.post(self._relay_url+"deallocate", data=data) # ignore POST failure, don't call r.raise_for_status() class ChannelManager: - def __init__(self, relay_url, side, handle_welcome): + def __init__(self, relay_url, appid, side, handle_welcome): self._relay_url = relay_url + self._appid = appid self._side = side self._handle_welcome = handle_welcome def list_channels(self): - r = requests.get(self._relay_url + "list") + queryargs = urlencode([("appid", self._appid)]) + r = requests.get(self._relay_url+"list?%s" % queryargs) r.raise_for_status() channelids = r.json()["channelids"] return channelids def allocate(self): - data = json.dumps({"side": self._side}).encode("utf-8") - r = requests.post(self._relay_url + "allocate", data=data) + data = json.dumps({"appid": self._appid, + "side": self._side}).encode("utf-8") + r = requests.post(self._relay_url+"allocate", data=data) r.raise_for_status() data = r.json() if "welcome" in data: @@ -125,7 +138,7 @@ class ChannelManager: return channelid def connect(self, channelid): - return Channel(self._relay_url, channelid, self._side, + return Channel(self._relay_url, self._appid, channelid, self._side, self._handle_welcome) class Wormhole: @@ -139,7 +152,7 @@ class Wormhole: self._appid = appid self._relay_url = relay_url side = hexlify(os.urandom(5)).decode("ascii") - self._channel_manager = ChannelManager(relay_url, side, + self._channel_manager = ChannelManager(relay_url, appid, side, self.handle_welcome) self.code = None self.key = None @@ -152,8 +165,7 @@ class Wormhole: not self.motd_displayed): motd_lines = welcome["motd"].splitlines() motd_formatted = "\n ".join(motd_lines) - print("Server (at %s) says:\n %s" % (self._relay_url, - motd_formatted), + print("Server (at %s) says:\n %s" % (self._relay_url, motd_formatted), file=sys.stderr) self.motd_displayed = True diff --git a/src/wormhole/db-schemas/v1.sql b/src/wormhole/db-schemas/v1.sql index 3858140..9f7872d 100644 --- a/src/wormhole/db-schemas/v1.sql +++ b/src/wormhole/db-schemas/v1.sql @@ -9,16 +9,18 @@ CREATE TABLE `version` CREATE TABLE `messages` ( + `appid` VARCHAR, `channelid` INTEGER, `side` VARCHAR, `phase` VARCHAR, -- not numeric, more of a PAKE-phase indicator string `body` VARCHAR, `when` INTEGER ); -CREATE INDEX `messages_idx` ON `messages` (`channelid`, `side`, `phase`); +CREATE INDEX `messages_idx` ON `messages` (`appid`, `channelid`); CREATE TABLE `allocations` ( + `appid` VARCHAR, `channelid` INTEGER, `side` VARCHAR ); diff --git a/src/wormhole/servers/relay_server.py b/src/wormhole/servers/relay_server.py index 56f2ed1..e9884ce 100644 --- a/src/wormhole/servers/relay_server.py +++ b/src/wormhole/servers/relay_server.py @@ -1,8 +1,8 @@ from __future__ import print_function -import re, json, time, random +import json, time, random from twisted.python import log from twisted.application import service, internet -from twisted.web import server, resource, http +from twisted.web import server, resource SECONDS = 1.0 MINUTE = 60*SECONDS @@ -13,6 +13,10 @@ MB = 1000*1000 CHANNEL_EXPIRATION_TIME = 3*DAY EXPIRATION_CHECK_PERIOD = 2*HOUR +def json_response(request, data): + request.setHeader(b"content-type", b"application/json; charset=utf-8") + return (json.dumps(data)+"\n").encode("utf-8") + class EventsProtocol: def __init__(self, request): self.request = request @@ -46,109 +50,185 @@ class EventsProtocol: # note: no versions of IE (including the current IE11) support EventSource -# relay URLs are: -# GET /list -> {channelids: [INT..]} -# POST /allocate {side: SIDE} -> {channelid: INT} -# these return all messages (base64) for CID= : -# POST /CID {side:, phase:, body:} -> {messages: [{phase:, body:}..]} -# GET /CID (no-eventsource) -> {messages: [{phase:, body:}..]} -# GET /CID (eventsource) -> {phase:, body:}.. -# POST /CID/deallocate {side: SIDE} -> {status: waiting | deleted} +# relay URLs are as follows: (MESSAGES=[{phase:,body:}..]) +# GET /list?appid= -> {channelids: [INT..]} +# POST /allocate {appid:,side:} -> {channelid: INT} +# these return all messages (base64) for appid=/channelid= : +# POST /add {appid:,channelid:,side:,phase:,body:} -> {messages: MESSAGES} +# GET /get?appid=&channelid= (no-eventsource) -> {messages: MESSAGES} +# GET /get?appid=&channelid= (eventsource) -> {phase:, body:}.. +# POST /deallocate {appid:,channelid:,side:} -> {status: waiting | deleted} # all JSON responses include a "welcome:{..}" key -class Channel(resource.Resource): - def __init__(self, channelid, relay, db, welcome): +class ChannelLister(resource.Resource): + def __init__(self, relay): resource.Resource.__init__(self) - self.channelid = channelid - self.relay = relay - self.db = db - self.welcome = welcome - self.event_channels = set() # ep - self.putChild(b"deallocate", Deallocator(self.channelid, self.relay)) - - def get_messages(self, request): - request.setHeader(b"content-type", b"application/json; charset=utf-8") - messages = [] - for row in self.db.execute("SELECT * FROM `messages`" - " WHERE `channelid`=?" - " ORDER BY `when` ASC", - (self.channelid,)).fetchall(): - messages.append({"phase": row["phase"], "body": row["body"]}) - data = {"welcome": self.welcome, "messages": messages} - return (json.dumps(data)+"\n").encode("utf-8") + self._relay = relay def render_GET(self, request): - if b"text/event-stream" not in (request.getHeader(b"accept") or b""): - return self.get_messages(request) - request.setHeader(b"content-type", b"text/event-stream; charset=utf-8") - ep = EventsProtocol(request) - ep.sendEvent(json.dumps(self.welcome), name="welcome") - self.event_channels.add(ep) - request.notifyFinish().addErrback(lambda f: - self.event_channels.discard(ep)) - for row in self.db.execute("SELECT * FROM `messages`" - " WHERE `channelid`=?" - " ORDER BY `when` ASC", - (self.channelid,)).fetchall(): - data = json.dumps({"phase": row["phase"], "body": row["body"]}) - ep.sendEvent(data) - return server.NOT_DONE_YET + appid = request.args[b"appid"][0].decode("utf-8") + #print("LIST", appid) + app = self._relay.get_app(appid) + allocated = app.get_allocated() + request.setHeader(b"content-type", b"application/json; charset=utf-8") + data = {"welcome": self._relay.welcome, + "channelids": sorted(allocated)} + return (json.dumps(data)+"\n").encode("utf-8") - def broadcast_message(self, phase, body): - data = json.dumps({"phase": phase, "body": body}) - for ep in self.event_channels: - ep.sendEvent(data) +class Allocator(resource.Resource): + def __init__(self, relay): + resource.Resource.__init__(self) + self._relay = relay def render_POST(self, request): - #data = json.load(request.content, encoding="utf-8") content = request.content.read() data = json.loads(content.decode("utf-8")) + appid = data["appid"] + side = data["side"] + if not isinstance(side, type(u"")): + raise TypeError("side must be string, not '%s'" % type(side)) + #print("ALLOCATE", appid, side) + app = self._relay.get_app(appid) + channelid = app.find_available_channelid() + app.allocate_channel(channelid, side) + log.msg("allocated #%d, now have %d DB channels" % + (channelid, len(app.get_allocated()))) + request.setHeader(b"content-type", b"application/json; charset=utf-8") + data = {"welcome": self._relay.welcome, + "channelid": channelid} + return (json.dumps(data)+"\n").encode("utf-8") +class Adder(resource.Resource): + def __init__(self, relay): + resource.Resource.__init__(self) + self._relay = relay + + def render_POST(self, request): + #content = json.load(request.content, encoding="utf-8") + content = request.content.read() + data = json.loads(content.decode("utf-8")) + appid = data["appid"] + channelid = int(data["channelid"]) side = data["side"] phase = data["phase"] if not isinstance(phase, type(u"")): raise TypeError("phase must be string, not %s" % type(phase)) body = data["body"] + #print("ADD", appid, channelid, side, phase, body) - self.db.execute("INSERT INTO `messages`" - " (`channelid`, `side`, `phase`, `body`, `when`)" - " VALUES (?,?,?,?,?)", - (self.channelid, side, phase, body, time.time())) - self.db.execute("INSERT INTO `allocations`" - " (`channelid`, `side`)" - " VALUES (?,?)", - (self.channelid, side)) - self.db.commit() - self.broadcast_message(phase, body) - return self.get_messages(request) + app = self._relay.get_app(appid) + channel = app.get_channel(channelid) + response = channel.add_message(side, phase, body) + return json_response(request, response) + +class Getter(resource.Resource): + def __init__(self, relay): + self._relay = relay + + def render_GET(self, request): + appid = request.args[b"appid"][0].decode("utf-8") + channelid = int(request.args[b"channelid"][0]) + #print("GET", appid, channelid) + app = self._relay.get_app(appid) + channel = app.get_channel(channelid) + + if b"text/event-stream" not in (request.getHeader(b"accept") or b""): + response = channel.get_messages() + return json_response(request, response) + + request.setHeader(b"content-type", b"text/event-stream; charset=utf-8") + ep = EventsProtocol(request) + ep.sendEvent(json.dumps(self._relay.welcome), name="welcome") + old_events = channel.add_listener(ep.sendEvent) + request.notifyFinish().addErrback(lambda f: + channel.remove_listener(ep.sendEvent)) + for old_event in old_events: + ep.sendEvent(old_event) + return server.NOT_DONE_YET class Deallocator(resource.Resource): - def __init__(self, channelid, relay): - self.channelid = channelid - self.relay = relay + def __init__(self, relay): + self._relay = relay def render_POST(self, request): content = request.content.read() data = json.loads(content.decode("utf-8")) + appid = data["appid"] + channelid = int(data["channelid"]) side = data["side"] - deleted = self.relay.maybe_free_child(self.channelid, side) - resp = {"status": "waiting"} + #print("DEALLOCATE", appid, channelid, side) + app = self._relay.get_app(appid) + deleted = app.maybe_free_child(channelid, side) + response = {"status": "waiting"} if deleted: - resp = {"status": "deleted"} - return json.dumps(resp).encode("utf-8") + response = {"status": "deleted"} + return json_response(request, response) -def get_allocated(db): - c = db.execute("SELECT DISTINCT `channelid` FROM `allocations`") - return set([row["channelid"] for row in c.fetchall()]) - -class Allocator(resource.Resource): - def __init__(self, db, welcome): +class Channel(resource.Resource): + def __init__(self, relay, appid, channelid): resource.Resource.__init__(self) - self.db = db - self.welcome = welcome + self._relay = relay + self._appid = appid + self._channelid = channelid + self._listeners = set() # callbacks that take JSONable object - def allocate_channelid(self): - allocated = get_allocated(self.db) + def get_messages(self): + messages = [] + db = self._relay.db + for row in db.execute("SELECT * FROM `messages`" + " WHERE `appid`=? AND `channelid`=?" + " ORDER BY `when` ASC", + (self._appid, self._channelid)).fetchall(): + messages.append({"phase": row["phase"], "body": row["body"]}) + data = {"welcome": self._relay.welcome, "messages": messages} + return data + + def add_listener(self, listener): + self._listeners.add(listener) + db = self._relay.db + for row in db.execute("SELECT * FROM `messages`" + " WHERE `appid`=? AND `channelid`=?" + " ORDER BY `when` ASC", + (self._appid, self._channelid)).fetchall(): + yield json.dumps({"phase": row["phase"], "body": row["body"]}) + def remove_listener(self, listener): + self._listeners.discard(listener) + + def broadcast_message(self, phase, body): + data = json.dumps({"phase": phase, "body": body}) + for listener in self._listeners: + listener(data) + + def add_message(self, side, phase, body): + db = self._relay.db + db.execute("INSERT INTO `messages`" + " (`appid`, `channelid`, `side`, `phase`, `body`, `when`)" + " VALUES (?,?,?,?, ?,?)", + (self._appid, self._channelid, side, phase, + body, time.time())) + db.execute("INSERT INTO `allocations`" + " (`appid`, `channelid`, `side`)" + " VALUES (?,?,?)", + (self._appid, self._channelid, side)) + db.commit() + self.broadcast_message(phase, body) + return self.get_messages() + +class AppNamespace(resource.Resource): + def __init__(self, relay, appid): + resource.Resource.__init__(self) + self._relay = relay + self._appid = appid + self._channels = {} + + def get_allocated(self): + db = self._relay.db + c = db.execute("SELECT DISTINCT `channelid` FROM `allocations`" + " WHERE `appid`=?", (self._appid,)) + return set([row["channelid"] for row in c.fetchall()]) + + def find_available_channelid(self): + allocated = self.get_allocated() for size in range(1,4): # stick to 1-999 for now available = set() for cid in range(10**(size-1), 10**size): @@ -161,37 +241,63 @@ class Allocator(resource.Resource): cid = random.randrange(1000, 1000*1000) if cid not in allocated: return cid - raise ValueError("unable to find a free channelid") + raise ValueError("unable to find a free channel-id") - def render_POST(self, request): - content = request.content.read() - data = json.loads(content.decode("utf-8")) - side = data["side"] - if not isinstance(side, type(u"")): - raise TypeError("side must be string, not '%s'" % type(side)) - channelid = self.allocate_channelid() - self.db.execute("INSERT INTO `allocations` VALUES (?,?)", - (channelid, side)) - self.db.commit() - log.msg("allocated #%d, now have %d DB channels" % - (channelid, len(get_allocated(self.db)))) - request.setHeader(b"content-type", b"application/json; charset=utf-8") - data = {"welcome": self.welcome, - "channelid": channelid} - return (json.dumps(data)+"\n").encode("utf-8") + def allocate_channel(self, channelid, side): + db = self._relay.db + db.execute("INSERT INTO `allocations` VALUES (?,?,?)", + (self._appid, channelid, side)) + db.commit() -class ChannelList(resource.Resource): - def __init__(self, db, welcome): - resource.Resource.__init__(self) - self.db = db - self.welcome = welcome - def render_GET(self, request): - c = self.db.execute("SELECT DISTINCT `channelid` FROM `allocations`") - allocated = sorted(set([row["channelid"] for row in c.fetchall()])) - request.setHeader(b"content-type", b"application/json; charset=utf-8") - data = {"welcome": self.welcome, - "channelids": allocated} - return (json.dumps(data)+"\n").encode("utf-8") + def get_channel(self, channelid): + assert isinstance(channelid, int) + if not channelid in self._channels: + log.msg("spawning #%d for appid %s" % (channelid, self._appid)) + self._channels[channelid] = Channel(self._relay, + self._appid, channelid) + return self._channels[channelid] + + def maybe_free_child(self, channelid, side): + db = self._relay.db + db.execute("DELETE FROM `allocations`" + " WHERE `appid`=? AND `channelid`=? AND `side`=?", + (self._appid, channelid, side)) + db.commit() + remaining = db.execute("SELECT COUNT(*) FROM `allocations`" + " WHERE `appid`=? AND `channelid`=?", + (self._appid, channelid)).fetchone()[0] + if remaining: + return False + self._free_child(channelid) + return True + + def _free_child(self, channelid): + db = self._relay.db + db.execute("DELETE FROM `allocations`" + " WHERE `appid`=? AND `channelid`=?", + (self._appid, channelid)) + db.execute("DELETE FROM `messages`" + " WHERE `appid`=? AND `channelid`=?", + (self._appid, channelid)) + db.commit() + if channelid in self._channels: + self._channels.pop(channelid) + log.msg("freed+killed #%d, now have %d DB channels, %d live" % + (channelid, len(self.get_allocated()), len(self._channels))) + + def prune_old_channels(self): + db = self._relay.db + old = time.time() - CHANNEL_EXPIRATION_TIME + for channelid in self.get_allocated(): + c = db.execute("SELECT `when` FROM `messages`" + " WHERE `appid`=? AND `channelid`=?" + " ORDER BY `when` DESC LIMIT 1", + (self._appid, channelid)) + rows = c.fetchall() + if not rows or (rows[0]["when"] < old): + log.msg("expiring %d" % channelid) + self._free_child(channelid) + return bool(self._channels) class Relay(resource.Resource, service.MultiService): def __init__(self, db, welcome): @@ -199,60 +305,24 @@ class Relay(resource.Resource, service.MultiService): service.MultiService.__init__(self) self.db = db self.welcome = welcome - self.channels = {} - t = internet.TimerService(EXPIRATION_CHECK_PERIOD, - self.prune_old_channels) + self._apps = {} + t = internet.TimerService(EXPIRATION_CHECK_PERIOD, self.prune) t.setServiceParent(self) + self.putChild(b"list", ChannelLister(self)) + self.putChild(b"allocate", Allocator(self)) + self.putChild(b"add", Adder(self)) + self.putChild(b"get", Getter(self)) + self.putChild(b"deallocate", Deallocator(self)) + def get_app(self, appid): + assert isinstance(appid, type(u"")) + if not appid in self._apps: + log.msg("spawning appid %s" % (appid,)) + self._apps[appid] = AppNamespace(self, appid) + return self._apps[appid] - def getChild(self, path, request): - if path == b"allocate": - return Allocator(self.db, self.welcome) - if path == b"list": - return ChannelList(self.db, self.welcome) - if not re.search(br'^\d+$', path): - return resource.ErrorPage(http.BAD_REQUEST, - "invalid channel id", - "invalid channel id") - channelid = int(path) - if not channelid in self.channels: - log.msg("spawning #%d" % channelid) - self.channels[channelid] = Channel(channelid, self, self.db, - self.welcome) - return self.channels[channelid] - - def maybe_free_child(self, channelid, side): - self.db.execute("DELETE FROM `allocations`" - " WHERE `channelid`=? AND `side`=?", - (channelid, side)) - self.db.commit() - remaining = self.db.execute("SELECT COUNT(*) FROM `allocations`" - " WHERE `channelid`=?", - (channelid,)).fetchone()[0] - if remaining: - return False - self.free_child(channelid) - return True - - def free_child(self, channelid): - self.db.execute("DELETE FROM `allocations` WHERE `channelid`=?", - (channelid,)) - self.db.execute("DELETE FROM `messages` WHERE `channelid`=?", - (channelid,)) - self.db.commit() - if channelid in self.channels: - self.channels.pop(channelid) - log.msg("freed+killed #%d, now have %d DB channels, %d live" % - (channelid, len(get_allocated(self.db)), len(self.channels))) - - def prune_old_channels(self): - old = time.time() - CHANNEL_EXPIRATION_TIME - for channelid in get_allocated(self.db): - c = self.db.execute("SELECT `when` FROM `messages`" - " WHERE `channelid`=?" - " ORDER BY `when` DESC LIMIT 1", (channelid,)) - rows = c.fetchall() - if not rows or (rows[0]["when"] < old): - log.msg("expiring %d" % channelid) - self.free_child(channelid) - + def prune(self): + for appid in list(self._apps): + still_active = self._apps[appid].prune_old_channels() + if not still_active: + self._apps.pop(appid) diff --git a/src/wormhole/test/common.py b/src/wormhole/test/common.py index 068249f..b27a08a 100644 --- a/src/wormhole/test/common.py +++ b/src/wormhole/test/common.py @@ -14,6 +14,7 @@ class ServerBase: "tcp:%s:interface=127.0.0.1" % transitport, __version__) s.setServiceParent(self.sp) + self._relay_server = s.relay self.relayurl = u"http://127.0.0.1:%d/wormhole-relay/" % relayport self.transit = "tcp:127.0.0.1:%d" % transitport d.addCallback(_got_ports) diff --git a/src/wormhole/test/test_blocking.py b/src/wormhole/test/test_blocking.py index 977e47d..65e5be0 100644 --- a/src/wormhole/test/test_blocking.py +++ b/src/wormhole/test/test_blocking.py @@ -1,12 +1,89 @@ +from __future__ import print_function import json from twisted.trial import unittest -from twisted.internet.defer import gatherResults +from twisted.internet.defer import gatherResults, succeed from twisted.internet.threads import deferToThread -from ..blocking.transcribe import Wormhole as BlockingWormhole, UsageError +from ..blocking.transcribe import (Wormhole as BlockingWormhole, UsageError, + ChannelManager) from .common import ServerBase APPID = u"appid" +class Channel(ServerBase, unittest.TestCase): + def ignore(self, welcome): + pass + + def test_allocate(self): + cm = ChannelManager(self.relayurl, APPID, u"side", self.ignore) + d = deferToThread(cm.list_channels) + def _got_channels(channels): + self.failUnlessEqual(channels, []) + d.addCallback(_got_channels) + d.addCallback(lambda _: deferToThread(cm.allocate)) + def _allocated(channelid): + self.failUnlessEqual(type(channelid), int) + self._channelid = channelid + d.addCallback(_allocated) + d.addCallback(lambda _: deferToThread(cm.connect, self._channelid)) + def _connected(c): + self._channel = c + d.addCallback(_connected) + d.addCallback(lambda _: deferToThread(self._channel.deallocate)) + return d + + def test_messages(self): + cm1 = ChannelManager(self.relayurl, APPID, u"side1", self.ignore) + cm2 = ChannelManager(self.relayurl, APPID, u"side2", self.ignore) + c1 = cm1.connect(1) + c2 = cm2.connect(1) + + d = succeed(None) + d.addCallback(lambda _: deferToThread(c1.send, u"phase1", b"msg1")) + d.addCallback(lambda _: deferToThread(c2.get, u"phase1")) + d.addCallback(lambda msg: self.failUnlessEqual(msg, b"msg1")) + d.addCallback(lambda _: deferToThread(c2.send, u"phase1", b"msg2")) + d.addCallback(lambda _: deferToThread(c1.get, u"phase1")) + d.addCallback(lambda msg: self.failUnlessEqual(msg, b"msg2")) + # it's legal to fetch a phase multiple times, should be idempotent + d.addCallback(lambda _: deferToThread(c1.get, u"phase1")) + d.addCallback(lambda msg: self.failUnlessEqual(msg, b"msg2")) + # deallocating one side is not enough to destroy the channel + d.addCallback(lambda _: deferToThread(c2.deallocate)) + def _not_yet(_): + self._relay_server.prune() + self.failUnlessEqual(len(self._relay_server._apps), 1) + d.addCallback(_not_yet) + # but deallocating both will make the messages go away + d.addCallback(lambda _: deferToThread(c1.deallocate)) + def _gone(_): + self._relay_server.prune() + self.failUnlessEqual(len(self._relay_server._apps), 0) + d.addCallback(_gone) + + return d + + def test_appid_independence(self): + APPID_A = u"appid_A" + APPID_B = u"appid_B" + cm1a = ChannelManager(self.relayurl, APPID_A, u"side1", self.ignore) + cm2a = ChannelManager(self.relayurl, APPID_A, u"side2", self.ignore) + c1a = cm1a.connect(1) + c2a = cm2a.connect(1) + cm1b = ChannelManager(self.relayurl, APPID_B, u"side1", self.ignore) + cm2b = ChannelManager(self.relayurl, APPID_B, u"side2", self.ignore) + c1b = cm1b.connect(1) + c2b = cm2b.connect(1) + + d = succeed(None) + d.addCallback(lambda _: deferToThread(c1a.send, u"phase1", b"msg1a")) + d.addCallback(lambda _: deferToThread(c1b.send, u"phase1", b"msg1b")) + d.addCallback(lambda _: deferToThread(c2a.get, u"phase1")) + d.addCallback(lambda msg: self.failUnlessEqual(msg, b"msg1a")) + d.addCallback(lambda _: deferToThread(c2b.get, u"phase1")) + d.addCallback(lambda msg: self.failUnlessEqual(msg, b"msg1b")) + return d + + class Blocking(ServerBase, unittest.TestCase): # we need Twisted to run the server, but we run the sender and receiver # with deferToThread() diff --git a/src/wormhole/test/test_scripts.py b/src/wormhole/test/test_scripts.py index 91c72d8..6fe3c93 100644 --- a/src/wormhole/test/test_scripts.py +++ b/src/wormhole/test/test_scripts.py @@ -99,28 +99,27 @@ class Scripts(ServerBase, ScriptsBase, unittest.TestCase): out, err, rc = res out = out.decode("utf-8") err = err.decode("utf-8") - self.failUnlessEqual(out, - "Sending text message (%d bytes)\n" - "On the other computer, please run: " - "wormhole receive\n" - "Wormhole code is: %s\n\n" - "text message sent\n" % (len(message), code) - ) - self.failUnlessEqual(err, "") - self.failUnlessEqual(rc, 0) + self.maxDiff = None + expected = ("Sending text message (%d bytes)\n" + "On the other computer, please run: " + "wormhole receive\n" + "Wormhole code is: %s\n\n" + "text message sent\n" % (len(message), code)) + self.failUnlessEqual( (expected, "", 0), + (out, err, rc) ) return d2 d1.addCallback(_check_sender) def _check_receiver(res): out, err, rc = res out = out.decode("utf-8") err = err.decode("utf-8") - self.failUnlessEqual(out, message+"\n") - self.failUnlessEqual(err, "") - self.failUnlessEqual(rc, 0) + self.failUnlessEqual( (message+"\n", "", 0), + (out, err, rc) ) d1.addCallback(_check_receiver) return d1 def test_send_file_pre_generated_code(self): + self.maxDiff=None code = "1-abc" filename = "testfile" message = "test message" @@ -150,6 +149,7 @@ class Scripts(ServerBase, ScriptsBase, unittest.TestCase): out, err, rc = res out = out.decode("utf-8") err = err.decode("utf-8") + self.failUnlessEqual(err, "") self.failUnlessIn("Sending %d byte file named '%s'\n" % (len(message), filename), out) self.failUnlessIn("On the other computer, please run: " @@ -159,7 +159,6 @@ class Scripts(ServerBase, ScriptsBase, unittest.TestCase): self.failUnlessIn("File sent.. waiting for confirmation\n" "Confirmation received. Transfer complete.\n", out) - self.failUnlessEqual(err, "") self.failUnlessEqual(rc, 0) return d2 d1.addCallback(_check_sender) diff --git a/src/wormhole/test/test_server.py b/src/wormhole/test/test_server.py index 32a2f2c..9e3f813 100644 --- a/src/wormhole/test/test_server.py +++ b/src/wormhole/test/test_server.py @@ -1,6 +1,7 @@ from __future__ import print_function import sys, json import requests +from six.moves.urllib_parse import urlencode from twisted.trial import unittest from twisted.internet import reactor, defer from twisted.internet.threads import deferToThread @@ -55,15 +56,26 @@ def unjson(data): return json.loads(data.decode("utf-8")) class API(ServerBase, unittest.TestCase): - def get(self, path, is_json=True): - url = (self.relayurl+path).encode("ascii") - d = getPage(url) - if is_json: - d.addCallback(unjson) + def build_url(self, path, appid, channelid): + url = self.relayurl+path + queryargs = [] + if appid: + queryargs.append(("appid", appid)) + if channelid: + queryargs.append(("channelid", channelid)) + if queryargs: + url += "?" + urlencode(queryargs) + return url + + def get(self, path, appid=None, channelid=None): + url = self.build_url(path, appid, channelid) + d = getPage(url.encode("ascii")) + d.addCallback(unjson) return d + def post(self, path, data): - url = (self.relayurl+path).encode("ascii") - d = getPage(url, method=b"POST", + url = self.relayurl+path + d = getPage(url.encode("ascii"), method=b"POST", postdata=json.dumps(data).encode("utf-8")) d.addCallback(unjson) return d @@ -73,13 +85,14 @@ class API(ServerBase, unittest.TestCase): self.failUnlessEqual(data["welcome"], {"current_version": __version__}) def test_allocate_1(self): - d = self.get("list") + d = self.get("list", "app1") def _check_list_1(data): self.check_welcome(data) self.failUnlessEqual(data["channelids"], []) d.addCallback(_check_list_1) - d.addCallback(lambda _: self.post("allocate", {"side": "abc"})) + d.addCallback(lambda _: self.post("allocate", {"appid": "app1", + "side": "abc"})) def _allocated(data): self.failUnlessEqual(set(data.keys()), set(["welcome", "channelid"])) @@ -87,18 +100,20 @@ class API(ServerBase, unittest.TestCase): self.cid = data["channelid"] d.addCallback(_allocated) - d.addCallback(lambda _: self.get("list")) + d.addCallback(lambda _: self.get("list", "app1")) def _check_list_2(data): self.failUnlessEqual(data["channelids"], [self.cid]) d.addCallback(_check_list_2) - d.addCallback(lambda _: self.post("%d/deallocate" % self.cid, - {"side": "abc"})) + d.addCallback(lambda _: self.post("deallocate", + {"appid": "app1", + "channelid": str(self.cid), + "side": "abc"})) def _check_deallocate(res): self.failUnlessEqual(res["status"], "deleted") d.addCallback(_check_deallocate) - d.addCallback(lambda _: self.get("list")) + d.addCallback(lambda _: self.get("list", "app1")) def _check_list_3(data): self.failUnlessEqual(data["channelids"], []) d.addCallback(_check_list_3) @@ -106,45 +121,57 @@ class API(ServerBase, unittest.TestCase): return d def test_allocate_2(self): - d = self.post("allocate", {"side": "abc"}) + d = self.post("allocate", {"appid": "app1", "side": "abc"}) def _allocated(data): self.cid = data["channelid"] d.addCallback(_allocated) # second caller increases the number of known sides to 2 - d.addCallback(lambda _: self.post("%d" % self.cid, - {"side": "def", + d.addCallback(lambda _: self.post("add", + {"appid": "app1", + "channelid": str(self.cid), + "side": "def", "phase": "1", "body": ""})) - d.addCallback(lambda _: self.get("list")) + d.addCallback(lambda _: self.get("list", "app1")) d.addCallback(lambda data: self.failUnlessEqual(data["channelids"], [self.cid])) - d.addCallback(lambda _: self.post("%d/deallocate" % self.cid, - {"side": "abc"})) + d.addCallback(lambda _: self.post("deallocate", + {"appid": "app1", + "channelid": str(self.cid), + "side": "abc"})) d.addCallback(lambda res: self.failUnlessEqual(res["status"], "waiting")) - d.addCallback(lambda _: self.post("%d/deallocate" % self.cid, - {"side": "NOT"})) + d.addCallback(lambda _: self.post("deallocate", + {"appid": "app1", + "channelid": str(self.cid), + "side": "NOT"})) d.addCallback(lambda res: self.failUnlessEqual(res["status"], "waiting")) - d.addCallback(lambda _: self.post("%d/deallocate" % self.cid, - {"side": "def"})) + d.addCallback(lambda _: self.post("deallocate", + {"appid": "app1", + "channelid": str(self.cid), + "side": "def"})) d.addCallback(lambda res: self.failUnlessEqual(res["status"], "deleted")) - d.addCallback(lambda _: self.get("list")) + d.addCallback(lambda _: self.get("list", "app1")) d.addCallback(lambda data: self.failUnlessEqual(data["channelids"], [])) return d def add_message(self, message, side="abc", phase="1"): - return self.post(str(self.cid), {"side": side, "phase": phase, - "body": message}) + return self.post("add", + {"appid": "app1", + "channelid": str(self.cid), + "side": side, + "phase": phase, + "body": message}) def parse_messages(self, messages): out = set() @@ -164,7 +191,7 @@ class API(ServerBase, unittest.TestCase): self.failUnlessIn(d, two) def test_messages(self): - d = self.post("allocate", {"side": "abc"}) + d = self.post("allocate", {"appid": "app1", "side": "abc"}) def _allocated(data): self.cid = data["channelid"] d.addCallback(_allocated) @@ -175,6 +202,8 @@ class API(ServerBase, unittest.TestCase): self.failUnlessEqual(data["messages"], [{"phase": "1", "body": "msg1A"}]) d.addCallback(_check1) + d.addCallback(lambda _: self.get("get", "app1", str(self.cid))) + d.addCallback(_check1) d.addCallback(lambda _: self.add_message("msg1B", side="def")) def _check2(data): self.check_welcome(data) @@ -182,6 +211,8 @@ class API(ServerBase, unittest.TestCase): set([("1", "msg1A"), ("1", "msg1B")])) d.addCallback(_check2) + d.addCallback(lambda _: self.get("get", "app1", str(self.cid))) + d.addCallback(_check2) # adding a duplicate message is not an error, is ignored by clients d.addCallback(lambda _: self.add_message("msg1B", side="def")) @@ -191,6 +222,8 @@ class API(ServerBase, unittest.TestCase): set([("1", "msg1A"), ("1", "msg1B")])) d.addCallback(_check3) + d.addCallback(lambda _: self.get("get", "app1", str(self.cid))) + d.addCallback(_check3) d.addCallback(lambda _: self.add_message("msg2A", side="abc", phase="2")) @@ -202,6 +235,8 @@ class API(ServerBase, unittest.TestCase): ("2", "msg2A"), ])) d.addCallback(_check4) + d.addCallback(lambda _: self.get("get", "app1", str(self.cid))) + d.addCallback(_check4) return d @@ -209,10 +244,10 @@ class API(ServerBase, unittest.TestCase): if sys.version_info[0] >= 3: raise unittest.SkipTest("twisted vs py3") - d = self.post("allocate", {"side": "abc"}) + d = self.post("allocate", {"appid": "app1", "side": "abc"}) def _allocated(data): self.cid = data["channelid"] - url = self.relayurl+str(self.cid) + url = self.build_url("get", "app1", self.cid) self.o = OneEventAtATime(url, parser=json.loads) return self.o.wait_for_connection() d.addCallback(_allocated) diff --git a/src/wormhole/test/test_twisted.py b/src/wormhole/test/test_twisted.py index 4aa345b..2d20057 100644 --- a/src/wormhole/test/test_twisted.py +++ b/src/wormhole/test/test_twisted.py @@ -1,11 +1,86 @@ +from __future__ import print_function import sys, json from twisted.trial import unittest -from twisted.internet.defer import gatherResults -from ..twisted.transcribe import Wormhole, UsageError +from twisted.internet.defer import gatherResults, succeed +from ..twisted.transcribe import Wormhole, UsageError, ChannelManager from .common import ServerBase APPID = u"appid" +class Channel(ServerBase, unittest.TestCase): + def ignore(self, welcome): + pass + + def test_allocate(self): + cm = ChannelManager(self.relayurl, APPID, u"side", self.ignore) + d = cm.list_channels() + def _got_channels(channels): + self.failUnlessEqual(channels, []) + d.addCallback(_got_channels) + d.addCallback(lambda _: cm.allocate()) + def _allocated(channelid): + self.failUnlessEqual(type(channelid), int) + self._channelid = channelid + d.addCallback(_allocated) + d.addCallback(lambda _: cm.connect(self._channelid)) + def _connected(c): + self._channel = c + d.addCallback(_connected) + d.addCallback(lambda _: self._channel.deallocate()) + return d + + def test_messages(self): + cm1 = ChannelManager(self.relayurl, APPID, u"side1", self.ignore) + cm2 = ChannelManager(self.relayurl, APPID, u"side2", self.ignore) + c1 = cm1.connect(1) + c2 = cm2.connect(1) + + d = succeed(None) + d.addCallback(lambda _: c1.send(u"phase1", b"msg1")) + d.addCallback(lambda _: c2.get(u"phase1")) + d.addCallback(lambda msg: self.failUnlessEqual(msg, b"msg1")) + d.addCallback(lambda _: c2.send(u"phase1", b"msg2")) + d.addCallback(lambda _: c1.get(u"phase1")) + d.addCallback(lambda msg: self.failUnlessEqual(msg, b"msg2")) + # it's legal to fetch a phase multiple times, should be idempotent + d.addCallback(lambda _: c1.get(u"phase1")) + d.addCallback(lambda msg: self.failUnlessEqual(msg, b"msg2")) + # deallocating one side is not enough to destroy the channel + d.addCallback(lambda _: c2.deallocate()) + def _not_yet(_): + self._relay_server.prune() + self.failUnlessEqual(len(self._relay_server._apps), 1) + d.addCallback(_not_yet) + # but deallocating both will make the messages go away + d.addCallback(lambda _: c1.deallocate()) + def _gone(_): + self._relay_server.prune() + self.failUnlessEqual(len(self._relay_server._apps), 0) + d.addCallback(_gone) + + return d + + def test_appid_independence(self): + APPID_A = u"appid_A" + APPID_B = u"appid_B" + cm1a = ChannelManager(self.relayurl, APPID_A, u"side1", self.ignore) + cm2a = ChannelManager(self.relayurl, APPID_A, u"side2", self.ignore) + c1a = cm1a.connect(1) + c2a = cm2a.connect(1) + cm1b = ChannelManager(self.relayurl, APPID_B, u"side1", self.ignore) + cm2b = ChannelManager(self.relayurl, APPID_B, u"side2", self.ignore) + c1b = cm1b.connect(1) + c2b = cm2b.connect(1) + + d = succeed(None) + d.addCallback(lambda _: c1a.send(u"phase1", b"msg1a")) + d.addCallback(lambda _: c1b.send(u"phase1", b"msg1b")) + d.addCallback(lambda _: c2a.get(u"phase1")) + d.addCallback(lambda msg: self.failUnlessEqual(msg, b"msg1a")) + d.addCallback(lambda _: c2b.get(u"phase1")) + d.addCallback(lambda msg: self.failUnlessEqual(msg, b"msg1b")) + return d + class Basic(ServerBase, unittest.TestCase): def doBoth(self, d1, d2): @@ -154,6 +229,7 @@ class Basic(ServerBase, unittest.TestCase): return d if sys.version_info[0] >= 3: + Channel.skip = "twisted is not yet sufficiently ported to py3" Basic.skip = "twisted is not yet sufficiently ported to py3" # as of 15.4.0, Twisted is still missing: # * web.client.Agent (for all non-EventSource POSTs in transcribe.py) diff --git a/src/wormhole/twisted/transcribe.py b/src/wormhole/twisted/transcribe.py index 2357867..a475223 100644 --- a/src/wormhole/twisted/transcribe.py +++ b/src/wormhole/twisted/transcribe.py @@ -1,5 +1,6 @@ from __future__ import print_function import os, sys, json, re, unicodedata +from six.moves.urllib_parse import urlencode from binascii import hexlify, unhexlify from zope.interface import implementer from twisted.internet import reactor, defer @@ -50,10 +51,24 @@ def post_json(agent, url, request_body): d.addCallback(lambda data: json.loads(data)) return d +def get_json(agent, url): + # GET from a URL, parsing the response as JSON + d = agent.request("GET", url.encode("utf-8")) + def _check_error(resp): + if resp.code != 200: + raise web_error.Error(resp.code, resp.phrase) + return resp + d.addCallback(_check_error) + d.addCallback(web_client.readBody) + d.addCallback(lambda data: json.loads(data)) + return d + class Channel: - def __init__(self, relay_url, channelid, side, handle_welcome, + def __init__(self, relay_url, appid, channelid, side, handle_welcome, agent): - self._channel_url = u"%s%d" % (relay_url, channelid) + self._relay_url = relay_url + self._appid = appid + self._channelid = channelid self._side = side self._handle_welcome = handle_welcome self._agent = agent @@ -78,10 +93,12 @@ class Channel: if not isinstance(phase, type(u"")): raise UsageError(type(phase)) if not isinstance(msg, type(b"")): raise UsageError(type(msg)) self._sent_messages.add( (phase,msg) ) - payload = {"side": self._side, + payload = {"appid": self._appid, + "channelid": self._channelid, + "side": self._side, "phase": phase, "body": hexlify(msg).decode("ascii")} - d = post_json(self._agent, self._channel_url, payload) + d = post_json(self._agent, self._relay_url+"add", payload) d.addCallback(lambda resp: self._add_inbound_messages(resp["messages"])) return d @@ -104,7 +121,10 @@ class Channel: msgs.append(body) d.callback(None) # TODO: use agent=self._agent - es = ReconnectingEventSource(self._channel_url, _handle) + queryargs = urlencode([("appid", self._appid), + ("channelid", self._channelid)]) + es = ReconnectingEventSource(self._relay_url+"get?%s" % queryargs, + _handle) es.startService() # TODO: .setServiceParent(self) es.activate() d.addCallback(lambda _: es.deactivate()) @@ -114,22 +134,26 @@ class Channel: def deallocate(self): # only try once, no retries - d = post_json(self._agent, self._channel_url+"/deallocate", - {"side": self._side}) + d = post_json(self._agent, self._relay_url+"deallocate", + {"appid": self._appid, + "channelid": self._channelid, + "side": self._side}) d.addBoth(lambda _: None) # ignore POST failure return d class ChannelManager: - def __init__(self, relay_url, side, handle_welcome): - assert isinstance(relay_url, type(u"")) - self._relay_url = relay_url + def __init__(self, relay, appid, side, handle_welcome): + assert isinstance(relay, type(u"")) + self._relay = relay + self._appid = appid self._side = side self._handle_welcome = handle_welcome self._agent = web_client.Agent(reactor) def allocate(self): - url = self._relay_url + "allocate" - d = post_json(self._agent, url, {"side": self._side}) + url = self._relay + "allocate" + d = post_json(self._agent, url, {"appid": self._appid, + "side": self._side}) def _got_channel(data): if "welcome" in data: self._handle_welcome(data["welcome"]) @@ -138,10 +162,14 @@ class ChannelManager: return d def list_channels(self): - raise NotImplementedError + queryargs = urlencode([("appid", self._appid)]) + url = self._relay + u"list?%s" % queryargs + d = get_json(self._agent, url) + d.addCallback(lambda r: r["channelids"]) + return d def connect(self, channelid): - return Channel(self._relay_url, channelid, self._side, + return Channel(self._relay, self._appid, channelid, self._side, self._handle_welcome, self._agent) class Wormhole: @@ -163,17 +191,16 @@ class Wormhole: def _set_side(self, side): self._side = side - self._channel_manager = ChannelManager(self._relay_url, self._side, - self.handle_welcome) + self._channel_manager = ChannelManager(self._relay_url, self._appid, + self._side, self.handle_welcome) def handle_welcome(self, welcome): if ("motd" in welcome and not self.motd_displayed): motd_lines = welcome["motd"].splitlines() motd_formatted = "\n ".join(motd_lines) - print("Server (at %s) says:\n %s" % (self._relay_url, - motd_formatted), - file=sys.stderr) + print("Server (at %s) says:\n %s" % + (self._relay_url, motd_formatted), file=sys.stderr) self.motd_displayed = True # Only warn if we're running a release version (e.g. 0.0.6, not