From 574d5f2314abd0ff12f9664e7ac5d5eaee38a7f0 Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Tue, 6 Oct 2015 17:20:12 -0700 Subject: [PATCH] scope channelids to the appid, change API and DB schema This requires a DB delete/recreate when upgrading. It changes the server protocol, and app IDs, so clients cannot interoperate with each other across this change, nor with the server. Flag day for everyone! Now apps do not share channel IDs, so a lot of usage of app1 will not cause the wormhole codes for app2 to get longer. --- src/wormhole/blocking/transcribe.py | 58 ++-- src/wormhole/db-schemas/v1.sql | 4 +- src/wormhole/servers/relay_server.py | 390 ++++++++++++++++----------- src/wormhole/test/common.py | 1 + src/wormhole/test/test_blocking.py | 81 +++++- src/wormhole/test/test_scripts.py | 25 +- src/wormhole/test/test_server.py | 93 +++++-- src/wormhole/test/test_twisted.py | 80 +++++- src/wormhole/twisted/transcribe.py | 65 +++-- 9 files changed, 548 insertions(+), 249 deletions(-) diff --git a/src/wormhole/blocking/transcribe.py b/src/wormhole/blocking/transcribe.py index 0a97053..420d513 100644 --- a/src/wormhole/blocking/transcribe.py +++ b/src/wormhole/blocking/transcribe.py @@ -1,5 +1,6 @@ from __future__ import print_function import os, sys, time, re, requests, json, unicodedata +from six.moves.urllib_parse import urlencode from binascii import hexlify, unhexlify from spake2 import SPAKE2_Symmetric from nacl.secret import SecretBox @@ -18,19 +19,21 @@ MINUTE = 60*SECOND def to_bytes(u): return unicodedata.normalize("NFC", u).encode("utf-8") -# relay URLs are: -# GET /list -> {channelids: [INT..]} -# POST /allocate {side: SIDE} -> {channelid: INT} -# these return all messages (base64) for CID= : -# POST /CID {side:, phase:, body:} -> {messages: [{phase:, body:}..]} -# GET /CID (no-eventsource) -> {messages: [{phase:, body:}..]} -# GET /CID (eventsource) -> {phase:, body:}.. -# POST /CID/deallocate {side: SIDE} -> {status: waiting | deleted} +# relay URLs are as follows: (MESSAGES=[{phase:,body:}..]) +# GET /list?appid= -> {channelids: [INT..]} +# POST /allocate {appid:,side:} -> {channelid: INT} +# these return all messages (base64) for appid=/channelid= : +# POST /add {appid:,channelid:,side:,phase:,body:} -> {messages: MESSAGES} +# GET /get?appid=&channelid= (no-eventsource) -> {messages: MESSAGES} +# GET /get?appid=&channelid= (eventsource) -> {phase:, body:}.. +# POST /deallocate {appid:,channelid:,side:} -> {status: waiting | deleted} # all JSON responses include a "welcome:{..}" key class Channel: - def __init__(self, relay_url, channelid, side, handle_welcome): - self._channel_url = u"%s%d" % (relay_url, channelid) + def __init__(self, relay_url, appid, channelid, side, handle_welcome): + self._relay_url = relay_url + self._appid = appid + self._channelid = channelid self._side = side self._handle_welcome = handle_welcome self._messages = set() # (phase,body) , body is bytes @@ -57,11 +60,13 @@ class Channel: if not isinstance(phase, type(u"")): raise UsageError(type(phase)) if not isinstance(msg, type(b"")): raise UsageError(type(msg)) self._sent_messages.add( (phase,msg) ) - payload = {"side": self._side, + payload = {"appid": self._appid, + "channelid": self._channelid, + "side": self._side, "phase": phase, "body": hexlify(msg).decode("ascii")} data = json.dumps(payload).encode("utf-8") - r = requests.post(self._channel_url, data=data) + r = requests.post(self._relay_url+"add", data=data) r.raise_for_status() resp = r.json() self._add_inbound_messages(resp["messages"]) @@ -80,7 +85,10 @@ class Channel: remaining = self._started + self._timeout - time.time() if remaining < 0: return Timeout - f = EventSourceFollower(self._channel_url, remaining) + queryargs = urlencode([("appid", self._appid), + ("channelid", self._channelid)]) + f = EventSourceFollower(self._relay_url+"get?%s" % queryargs, + remaining) # we loop here until the connection is lost, or we see the # message we want for (eventtype, data) in f.iter_events(): @@ -98,25 +106,30 @@ class Channel: def deallocate(self): # only try once, no retries - data = json.dumps({"side": self._side}).encode("utf-8") - requests.post(self._channel_url+"/deallocate", data=data) + data = json.dumps({"appid": self._appid, + "channelid": self._channelid, + "side": self._side}).encode("utf-8") + requests.post(self._relay_url+"deallocate", data=data) # ignore POST failure, don't call r.raise_for_status() class ChannelManager: - def __init__(self, relay_url, side, handle_welcome): + def __init__(self, relay_url, appid, side, handle_welcome): self._relay_url = relay_url + self._appid = appid self._side = side self._handle_welcome = handle_welcome def list_channels(self): - r = requests.get(self._relay_url + "list") + queryargs = urlencode([("appid", self._appid)]) + r = requests.get(self._relay_url+"list?%s" % queryargs) r.raise_for_status() channelids = r.json()["channelids"] return channelids def allocate(self): - data = json.dumps({"side": self._side}).encode("utf-8") - r = requests.post(self._relay_url + "allocate", data=data) + data = json.dumps({"appid": self._appid, + "side": self._side}).encode("utf-8") + r = requests.post(self._relay_url+"allocate", data=data) r.raise_for_status() data = r.json() if "welcome" in data: @@ -125,7 +138,7 @@ class ChannelManager: return channelid def connect(self, channelid): - return Channel(self._relay_url, channelid, self._side, + return Channel(self._relay_url, self._appid, channelid, self._side, self._handle_welcome) class Wormhole: @@ -139,7 +152,7 @@ class Wormhole: self._appid = appid self._relay_url = relay_url side = hexlify(os.urandom(5)).decode("ascii") - self._channel_manager = ChannelManager(relay_url, side, + self._channel_manager = ChannelManager(relay_url, appid, side, self.handle_welcome) self.code = None self.key = None @@ -152,8 +165,7 @@ class Wormhole: not self.motd_displayed): motd_lines = welcome["motd"].splitlines() motd_formatted = "\n ".join(motd_lines) - print("Server (at %s) says:\n %s" % (self._relay_url, - motd_formatted), + print("Server (at %s) says:\n %s" % (self._relay_url, motd_formatted), file=sys.stderr) self.motd_displayed = True diff --git a/src/wormhole/db-schemas/v1.sql b/src/wormhole/db-schemas/v1.sql index 3858140..9f7872d 100644 --- a/src/wormhole/db-schemas/v1.sql +++ b/src/wormhole/db-schemas/v1.sql @@ -9,16 +9,18 @@ CREATE TABLE `version` CREATE TABLE `messages` ( + `appid` VARCHAR, `channelid` INTEGER, `side` VARCHAR, `phase` VARCHAR, -- not numeric, more of a PAKE-phase indicator string `body` VARCHAR, `when` INTEGER ); -CREATE INDEX `messages_idx` ON `messages` (`channelid`, `side`, `phase`); +CREATE INDEX `messages_idx` ON `messages` (`appid`, `channelid`); CREATE TABLE `allocations` ( + `appid` VARCHAR, `channelid` INTEGER, `side` VARCHAR ); diff --git a/src/wormhole/servers/relay_server.py b/src/wormhole/servers/relay_server.py index 56f2ed1..e9884ce 100644 --- a/src/wormhole/servers/relay_server.py +++ b/src/wormhole/servers/relay_server.py @@ -1,8 +1,8 @@ from __future__ import print_function -import re, json, time, random +import json, time, random from twisted.python import log from twisted.application import service, internet -from twisted.web import server, resource, http +from twisted.web import server, resource SECONDS = 1.0 MINUTE = 60*SECONDS @@ -13,6 +13,10 @@ MB = 1000*1000 CHANNEL_EXPIRATION_TIME = 3*DAY EXPIRATION_CHECK_PERIOD = 2*HOUR +def json_response(request, data): + request.setHeader(b"content-type", b"application/json; charset=utf-8") + return (json.dumps(data)+"\n").encode("utf-8") + class EventsProtocol: def __init__(self, request): self.request = request @@ -46,109 +50,185 @@ class EventsProtocol: # note: no versions of IE (including the current IE11) support EventSource -# relay URLs are: -# GET /list -> {channelids: [INT..]} -# POST /allocate {side: SIDE} -> {channelid: INT} -# these return all messages (base64) for CID= : -# POST /CID {side:, phase:, body:} -> {messages: [{phase:, body:}..]} -# GET /CID (no-eventsource) -> {messages: [{phase:, body:}..]} -# GET /CID (eventsource) -> {phase:, body:}.. -# POST /CID/deallocate {side: SIDE} -> {status: waiting | deleted} +# relay URLs are as follows: (MESSAGES=[{phase:,body:}..]) +# GET /list?appid= -> {channelids: [INT..]} +# POST /allocate {appid:,side:} -> {channelid: INT} +# these return all messages (base64) for appid=/channelid= : +# POST /add {appid:,channelid:,side:,phase:,body:} -> {messages: MESSAGES} +# GET /get?appid=&channelid= (no-eventsource) -> {messages: MESSAGES} +# GET /get?appid=&channelid= (eventsource) -> {phase:, body:}.. +# POST /deallocate {appid:,channelid:,side:} -> {status: waiting | deleted} # all JSON responses include a "welcome:{..}" key -class Channel(resource.Resource): - def __init__(self, channelid, relay, db, welcome): +class ChannelLister(resource.Resource): + def __init__(self, relay): resource.Resource.__init__(self) - self.channelid = channelid - self.relay = relay - self.db = db - self.welcome = welcome - self.event_channels = set() # ep - self.putChild(b"deallocate", Deallocator(self.channelid, self.relay)) - - def get_messages(self, request): - request.setHeader(b"content-type", b"application/json; charset=utf-8") - messages = [] - for row in self.db.execute("SELECT * FROM `messages`" - " WHERE `channelid`=?" - " ORDER BY `when` ASC", - (self.channelid,)).fetchall(): - messages.append({"phase": row["phase"], "body": row["body"]}) - data = {"welcome": self.welcome, "messages": messages} - return (json.dumps(data)+"\n").encode("utf-8") + self._relay = relay def render_GET(self, request): - if b"text/event-stream" not in (request.getHeader(b"accept") or b""): - return self.get_messages(request) - request.setHeader(b"content-type", b"text/event-stream; charset=utf-8") - ep = EventsProtocol(request) - ep.sendEvent(json.dumps(self.welcome), name="welcome") - self.event_channels.add(ep) - request.notifyFinish().addErrback(lambda f: - self.event_channels.discard(ep)) - for row in self.db.execute("SELECT * FROM `messages`" - " WHERE `channelid`=?" - " ORDER BY `when` ASC", - (self.channelid,)).fetchall(): - data = json.dumps({"phase": row["phase"], "body": row["body"]}) - ep.sendEvent(data) - return server.NOT_DONE_YET + appid = request.args[b"appid"][0].decode("utf-8") + #print("LIST", appid) + app = self._relay.get_app(appid) + allocated = app.get_allocated() + request.setHeader(b"content-type", b"application/json; charset=utf-8") + data = {"welcome": self._relay.welcome, + "channelids": sorted(allocated)} + return (json.dumps(data)+"\n").encode("utf-8") - def broadcast_message(self, phase, body): - data = json.dumps({"phase": phase, "body": body}) - for ep in self.event_channels: - ep.sendEvent(data) +class Allocator(resource.Resource): + def __init__(self, relay): + resource.Resource.__init__(self) + self._relay = relay def render_POST(self, request): - #data = json.load(request.content, encoding="utf-8") content = request.content.read() data = json.loads(content.decode("utf-8")) + appid = data["appid"] + side = data["side"] + if not isinstance(side, type(u"")): + raise TypeError("side must be string, not '%s'" % type(side)) + #print("ALLOCATE", appid, side) + app = self._relay.get_app(appid) + channelid = app.find_available_channelid() + app.allocate_channel(channelid, side) + log.msg("allocated #%d, now have %d DB channels" % + (channelid, len(app.get_allocated()))) + request.setHeader(b"content-type", b"application/json; charset=utf-8") + data = {"welcome": self._relay.welcome, + "channelid": channelid} + return (json.dumps(data)+"\n").encode("utf-8") +class Adder(resource.Resource): + def __init__(self, relay): + resource.Resource.__init__(self) + self._relay = relay + + def render_POST(self, request): + #content = json.load(request.content, encoding="utf-8") + content = request.content.read() + data = json.loads(content.decode("utf-8")) + appid = data["appid"] + channelid = int(data["channelid"]) side = data["side"] phase = data["phase"] if not isinstance(phase, type(u"")): raise TypeError("phase must be string, not %s" % type(phase)) body = data["body"] + #print("ADD", appid, channelid, side, phase, body) - self.db.execute("INSERT INTO `messages`" - " (`channelid`, `side`, `phase`, `body`, `when`)" - " VALUES (?,?,?,?,?)", - (self.channelid, side, phase, body, time.time())) - self.db.execute("INSERT INTO `allocations`" - " (`channelid`, `side`)" - " VALUES (?,?)", - (self.channelid, side)) - self.db.commit() - self.broadcast_message(phase, body) - return self.get_messages(request) + app = self._relay.get_app(appid) + channel = app.get_channel(channelid) + response = channel.add_message(side, phase, body) + return json_response(request, response) + +class Getter(resource.Resource): + def __init__(self, relay): + self._relay = relay + + def render_GET(self, request): + appid = request.args[b"appid"][0].decode("utf-8") + channelid = int(request.args[b"channelid"][0]) + #print("GET", appid, channelid) + app = self._relay.get_app(appid) + channel = app.get_channel(channelid) + + if b"text/event-stream" not in (request.getHeader(b"accept") or b""): + response = channel.get_messages() + return json_response(request, response) + + request.setHeader(b"content-type", b"text/event-stream; charset=utf-8") + ep = EventsProtocol(request) + ep.sendEvent(json.dumps(self._relay.welcome), name="welcome") + old_events = channel.add_listener(ep.sendEvent) + request.notifyFinish().addErrback(lambda f: + channel.remove_listener(ep.sendEvent)) + for old_event in old_events: + ep.sendEvent(old_event) + return server.NOT_DONE_YET class Deallocator(resource.Resource): - def __init__(self, channelid, relay): - self.channelid = channelid - self.relay = relay + def __init__(self, relay): + self._relay = relay def render_POST(self, request): content = request.content.read() data = json.loads(content.decode("utf-8")) + appid = data["appid"] + channelid = int(data["channelid"]) side = data["side"] - deleted = self.relay.maybe_free_child(self.channelid, side) - resp = {"status": "waiting"} + #print("DEALLOCATE", appid, channelid, side) + app = self._relay.get_app(appid) + deleted = app.maybe_free_child(channelid, side) + response = {"status": "waiting"} if deleted: - resp = {"status": "deleted"} - return json.dumps(resp).encode("utf-8") + response = {"status": "deleted"} + return json_response(request, response) -def get_allocated(db): - c = db.execute("SELECT DISTINCT `channelid` FROM `allocations`") - return set([row["channelid"] for row in c.fetchall()]) - -class Allocator(resource.Resource): - def __init__(self, db, welcome): +class Channel(resource.Resource): + def __init__(self, relay, appid, channelid): resource.Resource.__init__(self) - self.db = db - self.welcome = welcome + self._relay = relay + self._appid = appid + self._channelid = channelid + self._listeners = set() # callbacks that take JSONable object - def allocate_channelid(self): - allocated = get_allocated(self.db) + def get_messages(self): + messages = [] + db = self._relay.db + for row in db.execute("SELECT * FROM `messages`" + " WHERE `appid`=? AND `channelid`=?" + " ORDER BY `when` ASC", + (self._appid, self._channelid)).fetchall(): + messages.append({"phase": row["phase"], "body": row["body"]}) + data = {"welcome": self._relay.welcome, "messages": messages} + return data + + def add_listener(self, listener): + self._listeners.add(listener) + db = self._relay.db + for row in db.execute("SELECT * FROM `messages`" + " WHERE `appid`=? AND `channelid`=?" + " ORDER BY `when` ASC", + (self._appid, self._channelid)).fetchall(): + yield json.dumps({"phase": row["phase"], "body": row["body"]}) + def remove_listener(self, listener): + self._listeners.discard(listener) + + def broadcast_message(self, phase, body): + data = json.dumps({"phase": phase, "body": body}) + for listener in self._listeners: + listener(data) + + def add_message(self, side, phase, body): + db = self._relay.db + db.execute("INSERT INTO `messages`" + " (`appid`, `channelid`, `side`, `phase`, `body`, `when`)" + " VALUES (?,?,?,?, ?,?)", + (self._appid, self._channelid, side, phase, + body, time.time())) + db.execute("INSERT INTO `allocations`" + " (`appid`, `channelid`, `side`)" + " VALUES (?,?,?)", + (self._appid, self._channelid, side)) + db.commit() + self.broadcast_message(phase, body) + return self.get_messages() + +class AppNamespace(resource.Resource): + def __init__(self, relay, appid): + resource.Resource.__init__(self) + self._relay = relay + self._appid = appid + self._channels = {} + + def get_allocated(self): + db = self._relay.db + c = db.execute("SELECT DISTINCT `channelid` FROM `allocations`" + " WHERE `appid`=?", (self._appid,)) + return set([row["channelid"] for row in c.fetchall()]) + + def find_available_channelid(self): + allocated = self.get_allocated() for size in range(1,4): # stick to 1-999 for now available = set() for cid in range(10**(size-1), 10**size): @@ -161,37 +241,63 @@ class Allocator(resource.Resource): cid = random.randrange(1000, 1000*1000) if cid not in allocated: return cid - raise ValueError("unable to find a free channelid") + raise ValueError("unable to find a free channel-id") - def render_POST(self, request): - content = request.content.read() - data = json.loads(content.decode("utf-8")) - side = data["side"] - if not isinstance(side, type(u"")): - raise TypeError("side must be string, not '%s'" % type(side)) - channelid = self.allocate_channelid() - self.db.execute("INSERT INTO `allocations` VALUES (?,?)", - (channelid, side)) - self.db.commit() - log.msg("allocated #%d, now have %d DB channels" % - (channelid, len(get_allocated(self.db)))) - request.setHeader(b"content-type", b"application/json; charset=utf-8") - data = {"welcome": self.welcome, - "channelid": channelid} - return (json.dumps(data)+"\n").encode("utf-8") + def allocate_channel(self, channelid, side): + db = self._relay.db + db.execute("INSERT INTO `allocations` VALUES (?,?,?)", + (self._appid, channelid, side)) + db.commit() -class ChannelList(resource.Resource): - def __init__(self, db, welcome): - resource.Resource.__init__(self) - self.db = db - self.welcome = welcome - def render_GET(self, request): - c = self.db.execute("SELECT DISTINCT `channelid` FROM `allocations`") - allocated = sorted(set([row["channelid"] for row in c.fetchall()])) - request.setHeader(b"content-type", b"application/json; charset=utf-8") - data = {"welcome": self.welcome, - "channelids": allocated} - return (json.dumps(data)+"\n").encode("utf-8") + def get_channel(self, channelid): + assert isinstance(channelid, int) + if not channelid in self._channels: + log.msg("spawning #%d for appid %s" % (channelid, self._appid)) + self._channels[channelid] = Channel(self._relay, + self._appid, channelid) + return self._channels[channelid] + + def maybe_free_child(self, channelid, side): + db = self._relay.db + db.execute("DELETE FROM `allocations`" + " WHERE `appid`=? AND `channelid`=? AND `side`=?", + (self._appid, channelid, side)) + db.commit() + remaining = db.execute("SELECT COUNT(*) FROM `allocations`" + " WHERE `appid`=? AND `channelid`=?", + (self._appid, channelid)).fetchone()[0] + if remaining: + return False + self._free_child(channelid) + return True + + def _free_child(self, channelid): + db = self._relay.db + db.execute("DELETE FROM `allocations`" + " WHERE `appid`=? AND `channelid`=?", + (self._appid, channelid)) + db.execute("DELETE FROM `messages`" + " WHERE `appid`=? AND `channelid`=?", + (self._appid, channelid)) + db.commit() + if channelid in self._channels: + self._channels.pop(channelid) + log.msg("freed+killed #%d, now have %d DB channels, %d live" % + (channelid, len(self.get_allocated()), len(self._channels))) + + def prune_old_channels(self): + db = self._relay.db + old = time.time() - CHANNEL_EXPIRATION_TIME + for channelid in self.get_allocated(): + c = db.execute("SELECT `when` FROM `messages`" + " WHERE `appid`=? AND `channelid`=?" + " ORDER BY `when` DESC LIMIT 1", + (self._appid, channelid)) + rows = c.fetchall() + if not rows or (rows[0]["when"] < old): + log.msg("expiring %d" % channelid) + self._free_child(channelid) + return bool(self._channels) class Relay(resource.Resource, service.MultiService): def __init__(self, db, welcome): @@ -199,60 +305,24 @@ class Relay(resource.Resource, service.MultiService): service.MultiService.__init__(self) self.db = db self.welcome = welcome - self.channels = {} - t = internet.TimerService(EXPIRATION_CHECK_PERIOD, - self.prune_old_channels) + self._apps = {} + t = internet.TimerService(EXPIRATION_CHECK_PERIOD, self.prune) t.setServiceParent(self) + self.putChild(b"list", ChannelLister(self)) + self.putChild(b"allocate", Allocator(self)) + self.putChild(b"add", Adder(self)) + self.putChild(b"get", Getter(self)) + self.putChild(b"deallocate", Deallocator(self)) + def get_app(self, appid): + assert isinstance(appid, type(u"")) + if not appid in self._apps: + log.msg("spawning appid %s" % (appid,)) + self._apps[appid] = AppNamespace(self, appid) + return self._apps[appid] - def getChild(self, path, request): - if path == b"allocate": - return Allocator(self.db, self.welcome) - if path == b"list": - return ChannelList(self.db, self.welcome) - if not re.search(br'^\d+$', path): - return resource.ErrorPage(http.BAD_REQUEST, - "invalid channel id", - "invalid channel id") - channelid = int(path) - if not channelid in self.channels: - log.msg("spawning #%d" % channelid) - self.channels[channelid] = Channel(channelid, self, self.db, - self.welcome) - return self.channels[channelid] - - def maybe_free_child(self, channelid, side): - self.db.execute("DELETE FROM `allocations`" - " WHERE `channelid`=? AND `side`=?", - (channelid, side)) - self.db.commit() - remaining = self.db.execute("SELECT COUNT(*) FROM `allocations`" - " WHERE `channelid`=?", - (channelid,)).fetchone()[0] - if remaining: - return False - self.free_child(channelid) - return True - - def free_child(self, channelid): - self.db.execute("DELETE FROM `allocations` WHERE `channelid`=?", - (channelid,)) - self.db.execute("DELETE FROM `messages` WHERE `channelid`=?", - (channelid,)) - self.db.commit() - if channelid in self.channels: - self.channels.pop(channelid) - log.msg("freed+killed #%d, now have %d DB channels, %d live" % - (channelid, len(get_allocated(self.db)), len(self.channels))) - - def prune_old_channels(self): - old = time.time() - CHANNEL_EXPIRATION_TIME - for channelid in get_allocated(self.db): - c = self.db.execute("SELECT `when` FROM `messages`" - " WHERE `channelid`=?" - " ORDER BY `when` DESC LIMIT 1", (channelid,)) - rows = c.fetchall() - if not rows or (rows[0]["when"] < old): - log.msg("expiring %d" % channelid) - self.free_child(channelid) - + def prune(self): + for appid in list(self._apps): + still_active = self._apps[appid].prune_old_channels() + if not still_active: + self._apps.pop(appid) diff --git a/src/wormhole/test/common.py b/src/wormhole/test/common.py index 068249f..b27a08a 100644 --- a/src/wormhole/test/common.py +++ b/src/wormhole/test/common.py @@ -14,6 +14,7 @@ class ServerBase: "tcp:%s:interface=127.0.0.1" % transitport, __version__) s.setServiceParent(self.sp) + self._relay_server = s.relay self.relayurl = u"http://127.0.0.1:%d/wormhole-relay/" % relayport self.transit = "tcp:127.0.0.1:%d" % transitport d.addCallback(_got_ports) diff --git a/src/wormhole/test/test_blocking.py b/src/wormhole/test/test_blocking.py index 977e47d..65e5be0 100644 --- a/src/wormhole/test/test_blocking.py +++ b/src/wormhole/test/test_blocking.py @@ -1,12 +1,89 @@ +from __future__ import print_function import json from twisted.trial import unittest -from twisted.internet.defer import gatherResults +from twisted.internet.defer import gatherResults, succeed from twisted.internet.threads import deferToThread -from ..blocking.transcribe import Wormhole as BlockingWormhole, UsageError +from ..blocking.transcribe import (Wormhole as BlockingWormhole, UsageError, + ChannelManager) from .common import ServerBase APPID = u"appid" +class Channel(ServerBase, unittest.TestCase): + def ignore(self, welcome): + pass + + def test_allocate(self): + cm = ChannelManager(self.relayurl, APPID, u"side", self.ignore) + d = deferToThread(cm.list_channels) + def _got_channels(channels): + self.failUnlessEqual(channels, []) + d.addCallback(_got_channels) + d.addCallback(lambda _: deferToThread(cm.allocate)) + def _allocated(channelid): + self.failUnlessEqual(type(channelid), int) + self._channelid = channelid + d.addCallback(_allocated) + d.addCallback(lambda _: deferToThread(cm.connect, self._channelid)) + def _connected(c): + self._channel = c + d.addCallback(_connected) + d.addCallback(lambda _: deferToThread(self._channel.deallocate)) + return d + + def test_messages(self): + cm1 = ChannelManager(self.relayurl, APPID, u"side1", self.ignore) + cm2 = ChannelManager(self.relayurl, APPID, u"side2", self.ignore) + c1 = cm1.connect(1) + c2 = cm2.connect(1) + + d = succeed(None) + d.addCallback(lambda _: deferToThread(c1.send, u"phase1", b"msg1")) + d.addCallback(lambda _: deferToThread(c2.get, u"phase1")) + d.addCallback(lambda msg: self.failUnlessEqual(msg, b"msg1")) + d.addCallback(lambda _: deferToThread(c2.send, u"phase1", b"msg2")) + d.addCallback(lambda _: deferToThread(c1.get, u"phase1")) + d.addCallback(lambda msg: self.failUnlessEqual(msg, b"msg2")) + # it's legal to fetch a phase multiple times, should be idempotent + d.addCallback(lambda _: deferToThread(c1.get, u"phase1")) + d.addCallback(lambda msg: self.failUnlessEqual(msg, b"msg2")) + # deallocating one side is not enough to destroy the channel + d.addCallback(lambda _: deferToThread(c2.deallocate)) + def _not_yet(_): + self._relay_server.prune() + self.failUnlessEqual(len(self._relay_server._apps), 1) + d.addCallback(_not_yet) + # but deallocating both will make the messages go away + d.addCallback(lambda _: deferToThread(c1.deallocate)) + def _gone(_): + self._relay_server.prune() + self.failUnlessEqual(len(self._relay_server._apps), 0) + d.addCallback(_gone) + + return d + + def test_appid_independence(self): + APPID_A = u"appid_A" + APPID_B = u"appid_B" + cm1a = ChannelManager(self.relayurl, APPID_A, u"side1", self.ignore) + cm2a = ChannelManager(self.relayurl, APPID_A, u"side2", self.ignore) + c1a = cm1a.connect(1) + c2a = cm2a.connect(1) + cm1b = ChannelManager(self.relayurl, APPID_B, u"side1", self.ignore) + cm2b = ChannelManager(self.relayurl, APPID_B, u"side2", self.ignore) + c1b = cm1b.connect(1) + c2b = cm2b.connect(1) + + d = succeed(None) + d.addCallback(lambda _: deferToThread(c1a.send, u"phase1", b"msg1a")) + d.addCallback(lambda _: deferToThread(c1b.send, u"phase1", b"msg1b")) + d.addCallback(lambda _: deferToThread(c2a.get, u"phase1")) + d.addCallback(lambda msg: self.failUnlessEqual(msg, b"msg1a")) + d.addCallback(lambda _: deferToThread(c2b.get, u"phase1")) + d.addCallback(lambda msg: self.failUnlessEqual(msg, b"msg1b")) + return d + + class Blocking(ServerBase, unittest.TestCase): # we need Twisted to run the server, but we run the sender and receiver # with deferToThread() diff --git a/src/wormhole/test/test_scripts.py b/src/wormhole/test/test_scripts.py index 91c72d8..6fe3c93 100644 --- a/src/wormhole/test/test_scripts.py +++ b/src/wormhole/test/test_scripts.py @@ -99,28 +99,27 @@ class Scripts(ServerBase, ScriptsBase, unittest.TestCase): out, err, rc = res out = out.decode("utf-8") err = err.decode("utf-8") - self.failUnlessEqual(out, - "Sending text message (%d bytes)\n" - "On the other computer, please run: " - "wormhole receive\n" - "Wormhole code is: %s\n\n" - "text message sent\n" % (len(message), code) - ) - self.failUnlessEqual(err, "") - self.failUnlessEqual(rc, 0) + self.maxDiff = None + expected = ("Sending text message (%d bytes)\n" + "On the other computer, please run: " + "wormhole receive\n" + "Wormhole code is: %s\n\n" + "text message sent\n" % (len(message), code)) + self.failUnlessEqual( (expected, "", 0), + (out, err, rc) ) return d2 d1.addCallback(_check_sender) def _check_receiver(res): out, err, rc = res out = out.decode("utf-8") err = err.decode("utf-8") - self.failUnlessEqual(out, message+"\n") - self.failUnlessEqual(err, "") - self.failUnlessEqual(rc, 0) + self.failUnlessEqual( (message+"\n", "", 0), + (out, err, rc) ) d1.addCallback(_check_receiver) return d1 def test_send_file_pre_generated_code(self): + self.maxDiff=None code = "1-abc" filename = "testfile" message = "test message" @@ -150,6 +149,7 @@ class Scripts(ServerBase, ScriptsBase, unittest.TestCase): out, err, rc = res out = out.decode("utf-8") err = err.decode("utf-8") + self.failUnlessEqual(err, "") self.failUnlessIn("Sending %d byte file named '%s'\n" % (len(message), filename), out) self.failUnlessIn("On the other computer, please run: " @@ -159,7 +159,6 @@ class Scripts(ServerBase, ScriptsBase, unittest.TestCase): self.failUnlessIn("File sent.. waiting for confirmation\n" "Confirmation received. Transfer complete.\n", out) - self.failUnlessEqual(err, "") self.failUnlessEqual(rc, 0) return d2 d1.addCallback(_check_sender) diff --git a/src/wormhole/test/test_server.py b/src/wormhole/test/test_server.py index 32a2f2c..9e3f813 100644 --- a/src/wormhole/test/test_server.py +++ b/src/wormhole/test/test_server.py @@ -1,6 +1,7 @@ from __future__ import print_function import sys, json import requests +from six.moves.urllib_parse import urlencode from twisted.trial import unittest from twisted.internet import reactor, defer from twisted.internet.threads import deferToThread @@ -55,15 +56,26 @@ def unjson(data): return json.loads(data.decode("utf-8")) class API(ServerBase, unittest.TestCase): - def get(self, path, is_json=True): - url = (self.relayurl+path).encode("ascii") - d = getPage(url) - if is_json: - d.addCallback(unjson) + def build_url(self, path, appid, channelid): + url = self.relayurl+path + queryargs = [] + if appid: + queryargs.append(("appid", appid)) + if channelid: + queryargs.append(("channelid", channelid)) + if queryargs: + url += "?" + urlencode(queryargs) + return url + + def get(self, path, appid=None, channelid=None): + url = self.build_url(path, appid, channelid) + d = getPage(url.encode("ascii")) + d.addCallback(unjson) return d + def post(self, path, data): - url = (self.relayurl+path).encode("ascii") - d = getPage(url, method=b"POST", + url = self.relayurl+path + d = getPage(url.encode("ascii"), method=b"POST", postdata=json.dumps(data).encode("utf-8")) d.addCallback(unjson) return d @@ -73,13 +85,14 @@ class API(ServerBase, unittest.TestCase): self.failUnlessEqual(data["welcome"], {"current_version": __version__}) def test_allocate_1(self): - d = self.get("list") + d = self.get("list", "app1") def _check_list_1(data): self.check_welcome(data) self.failUnlessEqual(data["channelids"], []) d.addCallback(_check_list_1) - d.addCallback(lambda _: self.post("allocate", {"side": "abc"})) + d.addCallback(lambda _: self.post("allocate", {"appid": "app1", + "side": "abc"})) def _allocated(data): self.failUnlessEqual(set(data.keys()), set(["welcome", "channelid"])) @@ -87,18 +100,20 @@ class API(ServerBase, unittest.TestCase): self.cid = data["channelid"] d.addCallback(_allocated) - d.addCallback(lambda _: self.get("list")) + d.addCallback(lambda _: self.get("list", "app1")) def _check_list_2(data): self.failUnlessEqual(data["channelids"], [self.cid]) d.addCallback(_check_list_2) - d.addCallback(lambda _: self.post("%d/deallocate" % self.cid, - {"side": "abc"})) + d.addCallback(lambda _: self.post("deallocate", + {"appid": "app1", + "channelid": str(self.cid), + "side": "abc"})) def _check_deallocate(res): self.failUnlessEqual(res["status"], "deleted") d.addCallback(_check_deallocate) - d.addCallback(lambda _: self.get("list")) + d.addCallback(lambda _: self.get("list", "app1")) def _check_list_3(data): self.failUnlessEqual(data["channelids"], []) d.addCallback(_check_list_3) @@ -106,45 +121,57 @@ class API(ServerBase, unittest.TestCase): return d def test_allocate_2(self): - d = self.post("allocate", {"side": "abc"}) + d = self.post("allocate", {"appid": "app1", "side": "abc"}) def _allocated(data): self.cid = data["channelid"] d.addCallback(_allocated) # second caller increases the number of known sides to 2 - d.addCallback(lambda _: self.post("%d" % self.cid, - {"side": "def", + d.addCallback(lambda _: self.post("add", + {"appid": "app1", + "channelid": str(self.cid), + "side": "def", "phase": "1", "body": ""})) - d.addCallback(lambda _: self.get("list")) + d.addCallback(lambda _: self.get("list", "app1")) d.addCallback(lambda data: self.failUnlessEqual(data["channelids"], [self.cid])) - d.addCallback(lambda _: self.post("%d/deallocate" % self.cid, - {"side": "abc"})) + d.addCallback(lambda _: self.post("deallocate", + {"appid": "app1", + "channelid": str(self.cid), + "side": "abc"})) d.addCallback(lambda res: self.failUnlessEqual(res["status"], "waiting")) - d.addCallback(lambda _: self.post("%d/deallocate" % self.cid, - {"side": "NOT"})) + d.addCallback(lambda _: self.post("deallocate", + {"appid": "app1", + "channelid": str(self.cid), + "side": "NOT"})) d.addCallback(lambda res: self.failUnlessEqual(res["status"], "waiting")) - d.addCallback(lambda _: self.post("%d/deallocate" % self.cid, - {"side": "def"})) + d.addCallback(lambda _: self.post("deallocate", + {"appid": "app1", + "channelid": str(self.cid), + "side": "def"})) d.addCallback(lambda res: self.failUnlessEqual(res["status"], "deleted")) - d.addCallback(lambda _: self.get("list")) + d.addCallback(lambda _: self.get("list", "app1")) d.addCallback(lambda data: self.failUnlessEqual(data["channelids"], [])) return d def add_message(self, message, side="abc", phase="1"): - return self.post(str(self.cid), {"side": side, "phase": phase, - "body": message}) + return self.post("add", + {"appid": "app1", + "channelid": str(self.cid), + "side": side, + "phase": phase, + "body": message}) def parse_messages(self, messages): out = set() @@ -164,7 +191,7 @@ class API(ServerBase, unittest.TestCase): self.failUnlessIn(d, two) def test_messages(self): - d = self.post("allocate", {"side": "abc"}) + d = self.post("allocate", {"appid": "app1", "side": "abc"}) def _allocated(data): self.cid = data["channelid"] d.addCallback(_allocated) @@ -175,6 +202,8 @@ class API(ServerBase, unittest.TestCase): self.failUnlessEqual(data["messages"], [{"phase": "1", "body": "msg1A"}]) d.addCallback(_check1) + d.addCallback(lambda _: self.get("get", "app1", str(self.cid))) + d.addCallback(_check1) d.addCallback(lambda _: self.add_message("msg1B", side="def")) def _check2(data): self.check_welcome(data) @@ -182,6 +211,8 @@ class API(ServerBase, unittest.TestCase): set([("1", "msg1A"), ("1", "msg1B")])) d.addCallback(_check2) + d.addCallback(lambda _: self.get("get", "app1", str(self.cid))) + d.addCallback(_check2) # adding a duplicate message is not an error, is ignored by clients d.addCallback(lambda _: self.add_message("msg1B", side="def")) @@ -191,6 +222,8 @@ class API(ServerBase, unittest.TestCase): set([("1", "msg1A"), ("1", "msg1B")])) d.addCallback(_check3) + d.addCallback(lambda _: self.get("get", "app1", str(self.cid))) + d.addCallback(_check3) d.addCallback(lambda _: self.add_message("msg2A", side="abc", phase="2")) @@ -202,6 +235,8 @@ class API(ServerBase, unittest.TestCase): ("2", "msg2A"), ])) d.addCallback(_check4) + d.addCallback(lambda _: self.get("get", "app1", str(self.cid))) + d.addCallback(_check4) return d @@ -209,10 +244,10 @@ class API(ServerBase, unittest.TestCase): if sys.version_info[0] >= 3: raise unittest.SkipTest("twisted vs py3") - d = self.post("allocate", {"side": "abc"}) + d = self.post("allocate", {"appid": "app1", "side": "abc"}) def _allocated(data): self.cid = data["channelid"] - url = self.relayurl+str(self.cid) + url = self.build_url("get", "app1", self.cid) self.o = OneEventAtATime(url, parser=json.loads) return self.o.wait_for_connection() d.addCallback(_allocated) diff --git a/src/wormhole/test/test_twisted.py b/src/wormhole/test/test_twisted.py index 4aa345b..2d20057 100644 --- a/src/wormhole/test/test_twisted.py +++ b/src/wormhole/test/test_twisted.py @@ -1,11 +1,86 @@ +from __future__ import print_function import sys, json from twisted.trial import unittest -from twisted.internet.defer import gatherResults -from ..twisted.transcribe import Wormhole, UsageError +from twisted.internet.defer import gatherResults, succeed +from ..twisted.transcribe import Wormhole, UsageError, ChannelManager from .common import ServerBase APPID = u"appid" +class Channel(ServerBase, unittest.TestCase): + def ignore(self, welcome): + pass + + def test_allocate(self): + cm = ChannelManager(self.relayurl, APPID, u"side", self.ignore) + d = cm.list_channels() + def _got_channels(channels): + self.failUnlessEqual(channels, []) + d.addCallback(_got_channels) + d.addCallback(lambda _: cm.allocate()) + def _allocated(channelid): + self.failUnlessEqual(type(channelid), int) + self._channelid = channelid + d.addCallback(_allocated) + d.addCallback(lambda _: cm.connect(self._channelid)) + def _connected(c): + self._channel = c + d.addCallback(_connected) + d.addCallback(lambda _: self._channel.deallocate()) + return d + + def test_messages(self): + cm1 = ChannelManager(self.relayurl, APPID, u"side1", self.ignore) + cm2 = ChannelManager(self.relayurl, APPID, u"side2", self.ignore) + c1 = cm1.connect(1) + c2 = cm2.connect(1) + + d = succeed(None) + d.addCallback(lambda _: c1.send(u"phase1", b"msg1")) + d.addCallback(lambda _: c2.get(u"phase1")) + d.addCallback(lambda msg: self.failUnlessEqual(msg, b"msg1")) + d.addCallback(lambda _: c2.send(u"phase1", b"msg2")) + d.addCallback(lambda _: c1.get(u"phase1")) + d.addCallback(lambda msg: self.failUnlessEqual(msg, b"msg2")) + # it's legal to fetch a phase multiple times, should be idempotent + d.addCallback(lambda _: c1.get(u"phase1")) + d.addCallback(lambda msg: self.failUnlessEqual(msg, b"msg2")) + # deallocating one side is not enough to destroy the channel + d.addCallback(lambda _: c2.deallocate()) + def _not_yet(_): + self._relay_server.prune() + self.failUnlessEqual(len(self._relay_server._apps), 1) + d.addCallback(_not_yet) + # but deallocating both will make the messages go away + d.addCallback(lambda _: c1.deallocate()) + def _gone(_): + self._relay_server.prune() + self.failUnlessEqual(len(self._relay_server._apps), 0) + d.addCallback(_gone) + + return d + + def test_appid_independence(self): + APPID_A = u"appid_A" + APPID_B = u"appid_B" + cm1a = ChannelManager(self.relayurl, APPID_A, u"side1", self.ignore) + cm2a = ChannelManager(self.relayurl, APPID_A, u"side2", self.ignore) + c1a = cm1a.connect(1) + c2a = cm2a.connect(1) + cm1b = ChannelManager(self.relayurl, APPID_B, u"side1", self.ignore) + cm2b = ChannelManager(self.relayurl, APPID_B, u"side2", self.ignore) + c1b = cm1b.connect(1) + c2b = cm2b.connect(1) + + d = succeed(None) + d.addCallback(lambda _: c1a.send(u"phase1", b"msg1a")) + d.addCallback(lambda _: c1b.send(u"phase1", b"msg1b")) + d.addCallback(lambda _: c2a.get(u"phase1")) + d.addCallback(lambda msg: self.failUnlessEqual(msg, b"msg1a")) + d.addCallback(lambda _: c2b.get(u"phase1")) + d.addCallback(lambda msg: self.failUnlessEqual(msg, b"msg1b")) + return d + class Basic(ServerBase, unittest.TestCase): def doBoth(self, d1, d2): @@ -154,6 +229,7 @@ class Basic(ServerBase, unittest.TestCase): return d if sys.version_info[0] >= 3: + Channel.skip = "twisted is not yet sufficiently ported to py3" Basic.skip = "twisted is not yet sufficiently ported to py3" # as of 15.4.0, Twisted is still missing: # * web.client.Agent (for all non-EventSource POSTs in transcribe.py) diff --git a/src/wormhole/twisted/transcribe.py b/src/wormhole/twisted/transcribe.py index 2357867..a475223 100644 --- a/src/wormhole/twisted/transcribe.py +++ b/src/wormhole/twisted/transcribe.py @@ -1,5 +1,6 @@ from __future__ import print_function import os, sys, json, re, unicodedata +from six.moves.urllib_parse import urlencode from binascii import hexlify, unhexlify from zope.interface import implementer from twisted.internet import reactor, defer @@ -50,10 +51,24 @@ def post_json(agent, url, request_body): d.addCallback(lambda data: json.loads(data)) return d +def get_json(agent, url): + # GET from a URL, parsing the response as JSON + d = agent.request("GET", url.encode("utf-8")) + def _check_error(resp): + if resp.code != 200: + raise web_error.Error(resp.code, resp.phrase) + return resp + d.addCallback(_check_error) + d.addCallback(web_client.readBody) + d.addCallback(lambda data: json.loads(data)) + return d + class Channel: - def __init__(self, relay_url, channelid, side, handle_welcome, + def __init__(self, relay_url, appid, channelid, side, handle_welcome, agent): - self._channel_url = u"%s%d" % (relay_url, channelid) + self._relay_url = relay_url + self._appid = appid + self._channelid = channelid self._side = side self._handle_welcome = handle_welcome self._agent = agent @@ -78,10 +93,12 @@ class Channel: if not isinstance(phase, type(u"")): raise UsageError(type(phase)) if not isinstance(msg, type(b"")): raise UsageError(type(msg)) self._sent_messages.add( (phase,msg) ) - payload = {"side": self._side, + payload = {"appid": self._appid, + "channelid": self._channelid, + "side": self._side, "phase": phase, "body": hexlify(msg).decode("ascii")} - d = post_json(self._agent, self._channel_url, payload) + d = post_json(self._agent, self._relay_url+"add", payload) d.addCallback(lambda resp: self._add_inbound_messages(resp["messages"])) return d @@ -104,7 +121,10 @@ class Channel: msgs.append(body) d.callback(None) # TODO: use agent=self._agent - es = ReconnectingEventSource(self._channel_url, _handle) + queryargs = urlencode([("appid", self._appid), + ("channelid", self._channelid)]) + es = ReconnectingEventSource(self._relay_url+"get?%s" % queryargs, + _handle) es.startService() # TODO: .setServiceParent(self) es.activate() d.addCallback(lambda _: es.deactivate()) @@ -114,22 +134,26 @@ class Channel: def deallocate(self): # only try once, no retries - d = post_json(self._agent, self._channel_url+"/deallocate", - {"side": self._side}) + d = post_json(self._agent, self._relay_url+"deallocate", + {"appid": self._appid, + "channelid": self._channelid, + "side": self._side}) d.addBoth(lambda _: None) # ignore POST failure return d class ChannelManager: - def __init__(self, relay_url, side, handle_welcome): - assert isinstance(relay_url, type(u"")) - self._relay_url = relay_url + def __init__(self, relay, appid, side, handle_welcome): + assert isinstance(relay, type(u"")) + self._relay = relay + self._appid = appid self._side = side self._handle_welcome = handle_welcome self._agent = web_client.Agent(reactor) def allocate(self): - url = self._relay_url + "allocate" - d = post_json(self._agent, url, {"side": self._side}) + url = self._relay + "allocate" + d = post_json(self._agent, url, {"appid": self._appid, + "side": self._side}) def _got_channel(data): if "welcome" in data: self._handle_welcome(data["welcome"]) @@ -138,10 +162,14 @@ class ChannelManager: return d def list_channels(self): - raise NotImplementedError + queryargs = urlencode([("appid", self._appid)]) + url = self._relay + u"list?%s" % queryargs + d = get_json(self._agent, url) + d.addCallback(lambda r: r["channelids"]) + return d def connect(self, channelid): - return Channel(self._relay_url, channelid, self._side, + return Channel(self._relay, self._appid, channelid, self._side, self._handle_welcome, self._agent) class Wormhole: @@ -163,17 +191,16 @@ class Wormhole: def _set_side(self, side): self._side = side - self._channel_manager = ChannelManager(self._relay_url, self._side, - self.handle_welcome) + self._channel_manager = ChannelManager(self._relay_url, self._appid, + self._side, self.handle_welcome) def handle_welcome(self, welcome): if ("motd" in welcome and not self.motd_displayed): motd_lines = welcome["motd"].splitlines() motd_formatted = "\n ".join(motd_lines) - print("Server (at %s) says:\n %s" % (self._relay_url, - motd_formatted), - file=sys.stderr) + print("Server (at %s) says:\n %s" % + (self._relay_url, motd_formatted), file=sys.stderr) self.motd_displayed = True # Only warn if we're running a release version (e.g. 0.0.6, not