scope channelids to the appid, change API and DB schema

This requires a DB delete/recreate when upgrading. It changes the server
protocol, and app IDs, so clients cannot interoperate with each other
across this change, nor with the server. Flag day for everyone!

Now apps do not share channel IDs, so a lot of usage of app1 will not
cause the wormhole codes for app2 to get longer.
This commit is contained in:
Brian Warner 2015-10-06 17:20:12 -07:00
parent 8692bd2cd7
commit 574d5f2314
9 changed files with 548 additions and 249 deletions

View File

@ -1,5 +1,6 @@
from __future__ import print_function from __future__ import print_function
import os, sys, time, re, requests, json, unicodedata import os, sys, time, re, requests, json, unicodedata
from six.moves.urllib_parse import urlencode
from binascii import hexlify, unhexlify from binascii import hexlify, unhexlify
from spake2 import SPAKE2_Symmetric from spake2 import SPAKE2_Symmetric
from nacl.secret import SecretBox from nacl.secret import SecretBox
@ -18,19 +19,21 @@ MINUTE = 60*SECOND
def to_bytes(u): def to_bytes(u):
return unicodedata.normalize("NFC", u).encode("utf-8") return unicodedata.normalize("NFC", u).encode("utf-8")
# relay URLs are: # relay URLs are as follows: (MESSAGES=[{phase:,body:}..])
# GET /list -> {channelids: [INT..]} # GET /list?appid= -> {channelids: [INT..]}
# POST /allocate {side: SIDE} -> {channelid: INT} # POST /allocate {appid:,side:} -> {channelid: INT}
# these return all messages (base64) for CID= : # these return all messages (base64) for appid=/channelid= :
# POST /CID {side:, phase:, body:} -> {messages: [{phase:, body:}..]} # POST /add {appid:,channelid:,side:,phase:,body:} -> {messages: MESSAGES}
# GET /CID (no-eventsource) -> {messages: [{phase:, body:}..]} # GET /get?appid=&channelid= (no-eventsource) -> {messages: MESSAGES}
# GET /CID (eventsource) -> {phase:, body:}.. # GET /get?appid=&channelid= (eventsource) -> {phase:, body:}..
# POST /CID/deallocate {side: SIDE} -> {status: waiting | deleted} # POST /deallocate {appid:,channelid:,side:} -> {status: waiting | deleted}
# all JSON responses include a "welcome:{..}" key # all JSON responses include a "welcome:{..}" key
class Channel: class Channel:
def __init__(self, relay_url, channelid, side, handle_welcome): def __init__(self, relay_url, appid, channelid, side, handle_welcome):
self._channel_url = u"%s%d" % (relay_url, channelid) self._relay_url = relay_url
self._appid = appid
self._channelid = channelid
self._side = side self._side = side
self._handle_welcome = handle_welcome self._handle_welcome = handle_welcome
self._messages = set() # (phase,body) , body is bytes self._messages = set() # (phase,body) , body is bytes
@ -57,11 +60,13 @@ class Channel:
if not isinstance(phase, type(u"")): raise UsageError(type(phase)) if not isinstance(phase, type(u"")): raise UsageError(type(phase))
if not isinstance(msg, type(b"")): raise UsageError(type(msg)) if not isinstance(msg, type(b"")): raise UsageError(type(msg))
self._sent_messages.add( (phase,msg) ) self._sent_messages.add( (phase,msg) )
payload = {"side": self._side, payload = {"appid": self._appid,
"channelid": self._channelid,
"side": self._side,
"phase": phase, "phase": phase,
"body": hexlify(msg).decode("ascii")} "body": hexlify(msg).decode("ascii")}
data = json.dumps(payload).encode("utf-8") data = json.dumps(payload).encode("utf-8")
r = requests.post(self._channel_url, data=data) r = requests.post(self._relay_url+"add", data=data)
r.raise_for_status() r.raise_for_status()
resp = r.json() resp = r.json()
self._add_inbound_messages(resp["messages"]) self._add_inbound_messages(resp["messages"])
@ -80,7 +85,10 @@ class Channel:
remaining = self._started + self._timeout - time.time() remaining = self._started + self._timeout - time.time()
if remaining < 0: if remaining < 0:
return Timeout return Timeout
f = EventSourceFollower(self._channel_url, remaining) queryargs = urlencode([("appid", self._appid),
("channelid", self._channelid)])
f = EventSourceFollower(self._relay_url+"get?%s" % queryargs,
remaining)
# we loop here until the connection is lost, or we see the # we loop here until the connection is lost, or we see the
# message we want # message we want
for (eventtype, data) in f.iter_events(): for (eventtype, data) in f.iter_events():
@ -98,25 +106,30 @@ class Channel:
def deallocate(self): def deallocate(self):
# only try once, no retries # only try once, no retries
data = json.dumps({"side": self._side}).encode("utf-8") data = json.dumps({"appid": self._appid,
requests.post(self._channel_url+"/deallocate", data=data) "channelid": self._channelid,
"side": self._side}).encode("utf-8")
requests.post(self._relay_url+"deallocate", data=data)
# ignore POST failure, don't call r.raise_for_status() # ignore POST failure, don't call r.raise_for_status()
class ChannelManager: class ChannelManager:
def __init__(self, relay_url, side, handle_welcome): def __init__(self, relay_url, appid, side, handle_welcome):
self._relay_url = relay_url self._relay_url = relay_url
self._appid = appid
self._side = side self._side = side
self._handle_welcome = handle_welcome self._handle_welcome = handle_welcome
def list_channels(self): def list_channels(self):
r = requests.get(self._relay_url + "list") queryargs = urlencode([("appid", self._appid)])
r = requests.get(self._relay_url+"list?%s" % queryargs)
r.raise_for_status() r.raise_for_status()
channelids = r.json()["channelids"] channelids = r.json()["channelids"]
return channelids return channelids
def allocate(self): def allocate(self):
data = json.dumps({"side": self._side}).encode("utf-8") data = json.dumps({"appid": self._appid,
r = requests.post(self._relay_url + "allocate", data=data) "side": self._side}).encode("utf-8")
r = requests.post(self._relay_url+"allocate", data=data)
r.raise_for_status() r.raise_for_status()
data = r.json() data = r.json()
if "welcome" in data: if "welcome" in data:
@ -125,7 +138,7 @@ class ChannelManager:
return channelid return channelid
def connect(self, channelid): def connect(self, channelid):
return Channel(self._relay_url, channelid, self._side, return Channel(self._relay_url, self._appid, channelid, self._side,
self._handle_welcome) self._handle_welcome)
class Wormhole: class Wormhole:
@ -139,7 +152,7 @@ class Wormhole:
self._appid = appid self._appid = appid
self._relay_url = relay_url self._relay_url = relay_url
side = hexlify(os.urandom(5)).decode("ascii") side = hexlify(os.urandom(5)).decode("ascii")
self._channel_manager = ChannelManager(relay_url, side, self._channel_manager = ChannelManager(relay_url, appid, side,
self.handle_welcome) self.handle_welcome)
self.code = None self.code = None
self.key = None self.key = None
@ -152,8 +165,7 @@ class Wormhole:
not self.motd_displayed): not self.motd_displayed):
motd_lines = welcome["motd"].splitlines() motd_lines = welcome["motd"].splitlines()
motd_formatted = "\n ".join(motd_lines) motd_formatted = "\n ".join(motd_lines)
print("Server (at %s) says:\n %s" % (self._relay_url, print("Server (at %s) says:\n %s" % (self._relay_url, motd_formatted),
motd_formatted),
file=sys.stderr) file=sys.stderr)
self.motd_displayed = True self.motd_displayed = True

View File

@ -9,16 +9,18 @@ CREATE TABLE `version`
CREATE TABLE `messages` CREATE TABLE `messages`
( (
`appid` VARCHAR,
`channelid` INTEGER, `channelid` INTEGER,
`side` VARCHAR, `side` VARCHAR,
`phase` VARCHAR, -- not numeric, more of a PAKE-phase indicator string `phase` VARCHAR, -- not numeric, more of a PAKE-phase indicator string
`body` VARCHAR, `body` VARCHAR,
`when` INTEGER `when` INTEGER
); );
CREATE INDEX `messages_idx` ON `messages` (`channelid`, `side`, `phase`); CREATE INDEX `messages_idx` ON `messages` (`appid`, `channelid`);
CREATE TABLE `allocations` CREATE TABLE `allocations`
( (
`appid` VARCHAR,
`channelid` INTEGER, `channelid` INTEGER,
`side` VARCHAR `side` VARCHAR
); );

View File

@ -1,8 +1,8 @@
from __future__ import print_function from __future__ import print_function
import re, json, time, random import json, time, random
from twisted.python import log from twisted.python import log
from twisted.application import service, internet from twisted.application import service, internet
from twisted.web import server, resource, http from twisted.web import server, resource
SECONDS = 1.0 SECONDS = 1.0
MINUTE = 60*SECONDS MINUTE = 60*SECONDS
@ -13,6 +13,10 @@ MB = 1000*1000
CHANNEL_EXPIRATION_TIME = 3*DAY CHANNEL_EXPIRATION_TIME = 3*DAY
EXPIRATION_CHECK_PERIOD = 2*HOUR EXPIRATION_CHECK_PERIOD = 2*HOUR
def json_response(request, data):
request.setHeader(b"content-type", b"application/json; charset=utf-8")
return (json.dumps(data)+"\n").encode("utf-8")
class EventsProtocol: class EventsProtocol:
def __init__(self, request): def __init__(self, request):
self.request = request self.request = request
@ -46,109 +50,185 @@ class EventsProtocol:
# note: no versions of IE (including the current IE11) support EventSource # note: no versions of IE (including the current IE11) support EventSource
# relay URLs are: # relay URLs are as follows: (MESSAGES=[{phase:,body:}..])
# GET /list -> {channelids: [INT..]} # GET /list?appid= -> {channelids: [INT..]}
# POST /allocate {side: SIDE} -> {channelid: INT} # POST /allocate {appid:,side:} -> {channelid: INT}
# these return all messages (base64) for CID= : # these return all messages (base64) for appid=/channelid= :
# POST /CID {side:, phase:, body:} -> {messages: [{phase:, body:}..]} # POST /add {appid:,channelid:,side:,phase:,body:} -> {messages: MESSAGES}
# GET /CID (no-eventsource) -> {messages: [{phase:, body:}..]} # GET /get?appid=&channelid= (no-eventsource) -> {messages: MESSAGES}
# GET /CID (eventsource) -> {phase:, body:}.. # GET /get?appid=&channelid= (eventsource) -> {phase:, body:}..
# POST /CID/deallocate {side: SIDE} -> {status: waiting | deleted} # POST /deallocate {appid:,channelid:,side:} -> {status: waiting | deleted}
# all JSON responses include a "welcome:{..}" key # all JSON responses include a "welcome:{..}" key
class Channel(resource.Resource): class ChannelLister(resource.Resource):
def __init__(self, channelid, relay, db, welcome): def __init__(self, relay):
resource.Resource.__init__(self) resource.Resource.__init__(self)
self.channelid = channelid self._relay = relay
self.relay = relay
self.db = db
self.welcome = welcome
self.event_channels = set() # ep
self.putChild(b"deallocate", Deallocator(self.channelid, self.relay))
def get_messages(self, request):
request.setHeader(b"content-type", b"application/json; charset=utf-8")
messages = []
for row in self.db.execute("SELECT * FROM `messages`"
" WHERE `channelid`=?"
" ORDER BY `when` ASC",
(self.channelid,)).fetchall():
messages.append({"phase": row["phase"], "body": row["body"]})
data = {"welcome": self.welcome, "messages": messages}
return (json.dumps(data)+"\n").encode("utf-8")
def render_GET(self, request): def render_GET(self, request):
if b"text/event-stream" not in (request.getHeader(b"accept") or b""): appid = request.args[b"appid"][0].decode("utf-8")
return self.get_messages(request) #print("LIST", appid)
request.setHeader(b"content-type", b"text/event-stream; charset=utf-8") app = self._relay.get_app(appid)
ep = EventsProtocol(request) allocated = app.get_allocated()
ep.sendEvent(json.dumps(self.welcome), name="welcome") request.setHeader(b"content-type", b"application/json; charset=utf-8")
self.event_channels.add(ep) data = {"welcome": self._relay.welcome,
request.notifyFinish().addErrback(lambda f: "channelids": sorted(allocated)}
self.event_channels.discard(ep)) return (json.dumps(data)+"\n").encode("utf-8")
for row in self.db.execute("SELECT * FROM `messages`"
" WHERE `channelid`=?"
" ORDER BY `when` ASC",
(self.channelid,)).fetchall():
data = json.dumps({"phase": row["phase"], "body": row["body"]})
ep.sendEvent(data)
return server.NOT_DONE_YET
def broadcast_message(self, phase, body): class Allocator(resource.Resource):
data = json.dumps({"phase": phase, "body": body}) def __init__(self, relay):
for ep in self.event_channels: resource.Resource.__init__(self)
ep.sendEvent(data) self._relay = relay
def render_POST(self, request): def render_POST(self, request):
#data = json.load(request.content, encoding="utf-8")
content = request.content.read() content = request.content.read()
data = json.loads(content.decode("utf-8")) data = json.loads(content.decode("utf-8"))
appid = data["appid"]
side = data["side"]
if not isinstance(side, type(u"")):
raise TypeError("side must be string, not '%s'" % type(side))
#print("ALLOCATE", appid, side)
app = self._relay.get_app(appid)
channelid = app.find_available_channelid()
app.allocate_channel(channelid, side)
log.msg("allocated #%d, now have %d DB channels" %
(channelid, len(app.get_allocated())))
request.setHeader(b"content-type", b"application/json; charset=utf-8")
data = {"welcome": self._relay.welcome,
"channelid": channelid}
return (json.dumps(data)+"\n").encode("utf-8")
class Adder(resource.Resource):
def __init__(self, relay):
resource.Resource.__init__(self)
self._relay = relay
def render_POST(self, request):
#content = json.load(request.content, encoding="utf-8")
content = request.content.read()
data = json.loads(content.decode("utf-8"))
appid = data["appid"]
channelid = int(data["channelid"])
side = data["side"] side = data["side"]
phase = data["phase"] phase = data["phase"]
if not isinstance(phase, type(u"")): if not isinstance(phase, type(u"")):
raise TypeError("phase must be string, not %s" % type(phase)) raise TypeError("phase must be string, not %s" % type(phase))
body = data["body"] body = data["body"]
#print("ADD", appid, channelid, side, phase, body)
self.db.execute("INSERT INTO `messages`" app = self._relay.get_app(appid)
" (`channelid`, `side`, `phase`, `body`, `when`)" channel = app.get_channel(channelid)
" VALUES (?,?,?,?,?)", response = channel.add_message(side, phase, body)
(self.channelid, side, phase, body, time.time())) return json_response(request, response)
self.db.execute("INSERT INTO `allocations`"
" (`channelid`, `side`)" class Getter(resource.Resource):
" VALUES (?,?)", def __init__(self, relay):
(self.channelid, side)) self._relay = relay
self.db.commit()
self.broadcast_message(phase, body) def render_GET(self, request):
return self.get_messages(request) appid = request.args[b"appid"][0].decode("utf-8")
channelid = int(request.args[b"channelid"][0])
#print("GET", appid, channelid)
app = self._relay.get_app(appid)
channel = app.get_channel(channelid)
if b"text/event-stream" not in (request.getHeader(b"accept") or b""):
response = channel.get_messages()
return json_response(request, response)
request.setHeader(b"content-type", b"text/event-stream; charset=utf-8")
ep = EventsProtocol(request)
ep.sendEvent(json.dumps(self._relay.welcome), name="welcome")
old_events = channel.add_listener(ep.sendEvent)
request.notifyFinish().addErrback(lambda f:
channel.remove_listener(ep.sendEvent))
for old_event in old_events:
ep.sendEvent(old_event)
return server.NOT_DONE_YET
class Deallocator(resource.Resource): class Deallocator(resource.Resource):
def __init__(self, channelid, relay): def __init__(self, relay):
self.channelid = channelid self._relay = relay
self.relay = relay
def render_POST(self, request): def render_POST(self, request):
content = request.content.read() content = request.content.read()
data = json.loads(content.decode("utf-8")) data = json.loads(content.decode("utf-8"))
appid = data["appid"]
channelid = int(data["channelid"])
side = data["side"] side = data["side"]
deleted = self.relay.maybe_free_child(self.channelid, side) #print("DEALLOCATE", appid, channelid, side)
resp = {"status": "waiting"} app = self._relay.get_app(appid)
deleted = app.maybe_free_child(channelid, side)
response = {"status": "waiting"}
if deleted: if deleted:
resp = {"status": "deleted"} response = {"status": "deleted"}
return json.dumps(resp).encode("utf-8") return json_response(request, response)
def get_allocated(db): class Channel(resource.Resource):
c = db.execute("SELECT DISTINCT `channelid` FROM `allocations`") def __init__(self, relay, appid, channelid):
return set([row["channelid"] for row in c.fetchall()])
class Allocator(resource.Resource):
def __init__(self, db, welcome):
resource.Resource.__init__(self) resource.Resource.__init__(self)
self.db = db self._relay = relay
self.welcome = welcome self._appid = appid
self._channelid = channelid
self._listeners = set() # callbacks that take JSONable object
def allocate_channelid(self): def get_messages(self):
allocated = get_allocated(self.db) messages = []
db = self._relay.db
for row in db.execute("SELECT * FROM `messages`"
" WHERE `appid`=? AND `channelid`=?"
" ORDER BY `when` ASC",
(self._appid, self._channelid)).fetchall():
messages.append({"phase": row["phase"], "body": row["body"]})
data = {"welcome": self._relay.welcome, "messages": messages}
return data
def add_listener(self, listener):
self._listeners.add(listener)
db = self._relay.db
for row in db.execute("SELECT * FROM `messages`"
" WHERE `appid`=? AND `channelid`=?"
" ORDER BY `when` ASC",
(self._appid, self._channelid)).fetchall():
yield json.dumps({"phase": row["phase"], "body": row["body"]})
def remove_listener(self, listener):
self._listeners.discard(listener)
def broadcast_message(self, phase, body):
data = json.dumps({"phase": phase, "body": body})
for listener in self._listeners:
listener(data)
def add_message(self, side, phase, body):
db = self._relay.db
db.execute("INSERT INTO `messages`"
" (`appid`, `channelid`, `side`, `phase`, `body`, `when`)"
" VALUES (?,?,?,?, ?,?)",
(self._appid, self._channelid, side, phase,
body, time.time()))
db.execute("INSERT INTO `allocations`"
" (`appid`, `channelid`, `side`)"
" VALUES (?,?,?)",
(self._appid, self._channelid, side))
db.commit()
self.broadcast_message(phase, body)
return self.get_messages()
class AppNamespace(resource.Resource):
def __init__(self, relay, appid):
resource.Resource.__init__(self)
self._relay = relay
self._appid = appid
self._channels = {}
def get_allocated(self):
db = self._relay.db
c = db.execute("SELECT DISTINCT `channelid` FROM `allocations`"
" WHERE `appid`=?", (self._appid,))
return set([row["channelid"] for row in c.fetchall()])
def find_available_channelid(self):
allocated = self.get_allocated()
for size in range(1,4): # stick to 1-999 for now for size in range(1,4): # stick to 1-999 for now
available = set() available = set()
for cid in range(10**(size-1), 10**size): for cid in range(10**(size-1), 10**size):
@ -161,37 +241,63 @@ class Allocator(resource.Resource):
cid = random.randrange(1000, 1000*1000) cid = random.randrange(1000, 1000*1000)
if cid not in allocated: if cid not in allocated:
return cid return cid
raise ValueError("unable to find a free channelid") raise ValueError("unable to find a free channel-id")
def render_POST(self, request): def allocate_channel(self, channelid, side):
content = request.content.read() db = self._relay.db
data = json.loads(content.decode("utf-8")) db.execute("INSERT INTO `allocations` VALUES (?,?,?)",
side = data["side"] (self._appid, channelid, side))
if not isinstance(side, type(u"")): db.commit()
raise TypeError("side must be string, not '%s'" % type(side))
channelid = self.allocate_channelid()
self.db.execute("INSERT INTO `allocations` VALUES (?,?)",
(channelid, side))
self.db.commit()
log.msg("allocated #%d, now have %d DB channels" %
(channelid, len(get_allocated(self.db))))
request.setHeader(b"content-type", b"application/json; charset=utf-8")
data = {"welcome": self.welcome,
"channelid": channelid}
return (json.dumps(data)+"\n").encode("utf-8")
class ChannelList(resource.Resource): def get_channel(self, channelid):
def __init__(self, db, welcome): assert isinstance(channelid, int)
resource.Resource.__init__(self) if not channelid in self._channels:
self.db = db log.msg("spawning #%d for appid %s" % (channelid, self._appid))
self.welcome = welcome self._channels[channelid] = Channel(self._relay,
def render_GET(self, request): self._appid, channelid)
c = self.db.execute("SELECT DISTINCT `channelid` FROM `allocations`") return self._channels[channelid]
allocated = sorted(set([row["channelid"] for row in c.fetchall()]))
request.setHeader(b"content-type", b"application/json; charset=utf-8") def maybe_free_child(self, channelid, side):
data = {"welcome": self.welcome, db = self._relay.db
"channelids": allocated} db.execute("DELETE FROM `allocations`"
return (json.dumps(data)+"\n").encode("utf-8") " WHERE `appid`=? AND `channelid`=? AND `side`=?",
(self._appid, channelid, side))
db.commit()
remaining = db.execute("SELECT COUNT(*) FROM `allocations`"
" WHERE `appid`=? AND `channelid`=?",
(self._appid, channelid)).fetchone()[0]
if remaining:
return False
self._free_child(channelid)
return True
def _free_child(self, channelid):
db = self._relay.db
db.execute("DELETE FROM `allocations`"
" WHERE `appid`=? AND `channelid`=?",
(self._appid, channelid))
db.execute("DELETE FROM `messages`"
" WHERE `appid`=? AND `channelid`=?",
(self._appid, channelid))
db.commit()
if channelid in self._channels:
self._channels.pop(channelid)
log.msg("freed+killed #%d, now have %d DB channels, %d live" %
(channelid, len(self.get_allocated()), len(self._channels)))
def prune_old_channels(self):
db = self._relay.db
old = time.time() - CHANNEL_EXPIRATION_TIME
for channelid in self.get_allocated():
c = db.execute("SELECT `when` FROM `messages`"
" WHERE `appid`=? AND `channelid`=?"
" ORDER BY `when` DESC LIMIT 1",
(self._appid, channelid))
rows = c.fetchall()
if not rows or (rows[0]["when"] < old):
log.msg("expiring %d" % channelid)
self._free_child(channelid)
return bool(self._channels)
class Relay(resource.Resource, service.MultiService): class Relay(resource.Resource, service.MultiService):
def __init__(self, db, welcome): def __init__(self, db, welcome):
@ -199,60 +305,24 @@ class Relay(resource.Resource, service.MultiService):
service.MultiService.__init__(self) service.MultiService.__init__(self)
self.db = db self.db = db
self.welcome = welcome self.welcome = welcome
self.channels = {} self._apps = {}
t = internet.TimerService(EXPIRATION_CHECK_PERIOD, t = internet.TimerService(EXPIRATION_CHECK_PERIOD, self.prune)
self.prune_old_channels)
t.setServiceParent(self) t.setServiceParent(self)
self.putChild(b"list", ChannelLister(self))
self.putChild(b"allocate", Allocator(self))
self.putChild(b"add", Adder(self))
self.putChild(b"get", Getter(self))
self.putChild(b"deallocate", Deallocator(self))
def get_app(self, appid):
assert isinstance(appid, type(u""))
if not appid in self._apps:
log.msg("spawning appid %s" % (appid,))
self._apps[appid] = AppNamespace(self, appid)
return self._apps[appid]
def getChild(self, path, request): def prune(self):
if path == b"allocate": for appid in list(self._apps):
return Allocator(self.db, self.welcome) still_active = self._apps[appid].prune_old_channels()
if path == b"list": if not still_active:
return ChannelList(self.db, self.welcome) self._apps.pop(appid)
if not re.search(br'^\d+$', path):
return resource.ErrorPage(http.BAD_REQUEST,
"invalid channel id",
"invalid channel id")
channelid = int(path)
if not channelid in self.channels:
log.msg("spawning #%d" % channelid)
self.channels[channelid] = Channel(channelid, self, self.db,
self.welcome)
return self.channels[channelid]
def maybe_free_child(self, channelid, side):
self.db.execute("DELETE FROM `allocations`"
" WHERE `channelid`=? AND `side`=?",
(channelid, side))
self.db.commit()
remaining = self.db.execute("SELECT COUNT(*) FROM `allocations`"
" WHERE `channelid`=?",
(channelid,)).fetchone()[0]
if remaining:
return False
self.free_child(channelid)
return True
def free_child(self, channelid):
self.db.execute("DELETE FROM `allocations` WHERE `channelid`=?",
(channelid,))
self.db.execute("DELETE FROM `messages` WHERE `channelid`=?",
(channelid,))
self.db.commit()
if channelid in self.channels:
self.channels.pop(channelid)
log.msg("freed+killed #%d, now have %d DB channels, %d live" %
(channelid, len(get_allocated(self.db)), len(self.channels)))
def prune_old_channels(self):
old = time.time() - CHANNEL_EXPIRATION_TIME
for channelid in get_allocated(self.db):
c = self.db.execute("SELECT `when` FROM `messages`"
" WHERE `channelid`=?"
" ORDER BY `when` DESC LIMIT 1", (channelid,))
rows = c.fetchall()
if not rows or (rows[0]["when"] < old):
log.msg("expiring %d" % channelid)
self.free_child(channelid)

View File

@ -14,6 +14,7 @@ class ServerBase:
"tcp:%s:interface=127.0.0.1" % transitport, "tcp:%s:interface=127.0.0.1" % transitport,
__version__) __version__)
s.setServiceParent(self.sp) s.setServiceParent(self.sp)
self._relay_server = s.relay
self.relayurl = u"http://127.0.0.1:%d/wormhole-relay/" % relayport self.relayurl = u"http://127.0.0.1:%d/wormhole-relay/" % relayport
self.transit = "tcp:127.0.0.1:%d" % transitport self.transit = "tcp:127.0.0.1:%d" % transitport
d.addCallback(_got_ports) d.addCallback(_got_ports)

View File

@ -1,12 +1,89 @@
from __future__ import print_function
import json import json
from twisted.trial import unittest from twisted.trial import unittest
from twisted.internet.defer import gatherResults from twisted.internet.defer import gatherResults, succeed
from twisted.internet.threads import deferToThread from twisted.internet.threads import deferToThread
from ..blocking.transcribe import Wormhole as BlockingWormhole, UsageError from ..blocking.transcribe import (Wormhole as BlockingWormhole, UsageError,
ChannelManager)
from .common import ServerBase from .common import ServerBase
APPID = u"appid" APPID = u"appid"
class Channel(ServerBase, unittest.TestCase):
def ignore(self, welcome):
pass
def test_allocate(self):
cm = ChannelManager(self.relayurl, APPID, u"side", self.ignore)
d = deferToThread(cm.list_channels)
def _got_channels(channels):
self.failUnlessEqual(channels, [])
d.addCallback(_got_channels)
d.addCallback(lambda _: deferToThread(cm.allocate))
def _allocated(channelid):
self.failUnlessEqual(type(channelid), int)
self._channelid = channelid
d.addCallback(_allocated)
d.addCallback(lambda _: deferToThread(cm.connect, self._channelid))
def _connected(c):
self._channel = c
d.addCallback(_connected)
d.addCallback(lambda _: deferToThread(self._channel.deallocate))
return d
def test_messages(self):
cm1 = ChannelManager(self.relayurl, APPID, u"side1", self.ignore)
cm2 = ChannelManager(self.relayurl, APPID, u"side2", self.ignore)
c1 = cm1.connect(1)
c2 = cm2.connect(1)
d = succeed(None)
d.addCallback(lambda _: deferToThread(c1.send, u"phase1", b"msg1"))
d.addCallback(lambda _: deferToThread(c2.get, u"phase1"))
d.addCallback(lambda msg: self.failUnlessEqual(msg, b"msg1"))
d.addCallback(lambda _: deferToThread(c2.send, u"phase1", b"msg2"))
d.addCallback(lambda _: deferToThread(c1.get, u"phase1"))
d.addCallback(lambda msg: self.failUnlessEqual(msg, b"msg2"))
# it's legal to fetch a phase multiple times, should be idempotent
d.addCallback(lambda _: deferToThread(c1.get, u"phase1"))
d.addCallback(lambda msg: self.failUnlessEqual(msg, b"msg2"))
# deallocating one side is not enough to destroy the channel
d.addCallback(lambda _: deferToThread(c2.deallocate))
def _not_yet(_):
self._relay_server.prune()
self.failUnlessEqual(len(self._relay_server._apps), 1)
d.addCallback(_not_yet)
# but deallocating both will make the messages go away
d.addCallback(lambda _: deferToThread(c1.deallocate))
def _gone(_):
self._relay_server.prune()
self.failUnlessEqual(len(self._relay_server._apps), 0)
d.addCallback(_gone)
return d
def test_appid_independence(self):
APPID_A = u"appid_A"
APPID_B = u"appid_B"
cm1a = ChannelManager(self.relayurl, APPID_A, u"side1", self.ignore)
cm2a = ChannelManager(self.relayurl, APPID_A, u"side2", self.ignore)
c1a = cm1a.connect(1)
c2a = cm2a.connect(1)
cm1b = ChannelManager(self.relayurl, APPID_B, u"side1", self.ignore)
cm2b = ChannelManager(self.relayurl, APPID_B, u"side2", self.ignore)
c1b = cm1b.connect(1)
c2b = cm2b.connect(1)
d = succeed(None)
d.addCallback(lambda _: deferToThread(c1a.send, u"phase1", b"msg1a"))
d.addCallback(lambda _: deferToThread(c1b.send, u"phase1", b"msg1b"))
d.addCallback(lambda _: deferToThread(c2a.get, u"phase1"))
d.addCallback(lambda msg: self.failUnlessEqual(msg, b"msg1a"))
d.addCallback(lambda _: deferToThread(c2b.get, u"phase1"))
d.addCallback(lambda msg: self.failUnlessEqual(msg, b"msg1b"))
return d
class Blocking(ServerBase, unittest.TestCase): class Blocking(ServerBase, unittest.TestCase):
# we need Twisted to run the server, but we run the sender and receiver # we need Twisted to run the server, but we run the sender and receiver
# with deferToThread() # with deferToThread()

View File

@ -99,28 +99,27 @@ class Scripts(ServerBase, ScriptsBase, unittest.TestCase):
out, err, rc = res out, err, rc = res
out = out.decode("utf-8") out = out.decode("utf-8")
err = err.decode("utf-8") err = err.decode("utf-8")
self.failUnlessEqual(out, self.maxDiff = None
"Sending text message (%d bytes)\n" expected = ("Sending text message (%d bytes)\n"
"On the other computer, please run: " "On the other computer, please run: "
"wormhole receive\n" "wormhole receive\n"
"Wormhole code is: %s\n\n" "Wormhole code is: %s\n\n"
"text message sent\n" % (len(message), code) "text message sent\n" % (len(message), code))
) self.failUnlessEqual( (expected, "", 0),
self.failUnlessEqual(err, "") (out, err, rc) )
self.failUnlessEqual(rc, 0)
return d2 return d2
d1.addCallback(_check_sender) d1.addCallback(_check_sender)
def _check_receiver(res): def _check_receiver(res):
out, err, rc = res out, err, rc = res
out = out.decode("utf-8") out = out.decode("utf-8")
err = err.decode("utf-8") err = err.decode("utf-8")
self.failUnlessEqual(out, message+"\n") self.failUnlessEqual( (message+"\n", "", 0),
self.failUnlessEqual(err, "") (out, err, rc) )
self.failUnlessEqual(rc, 0)
d1.addCallback(_check_receiver) d1.addCallback(_check_receiver)
return d1 return d1
def test_send_file_pre_generated_code(self): def test_send_file_pre_generated_code(self):
self.maxDiff=None
code = "1-abc" code = "1-abc"
filename = "testfile" filename = "testfile"
message = "test message" message = "test message"
@ -150,6 +149,7 @@ class Scripts(ServerBase, ScriptsBase, unittest.TestCase):
out, err, rc = res out, err, rc = res
out = out.decode("utf-8") out = out.decode("utf-8")
err = err.decode("utf-8") err = err.decode("utf-8")
self.failUnlessEqual(err, "")
self.failUnlessIn("Sending %d byte file named '%s'\n" % self.failUnlessIn("Sending %d byte file named '%s'\n" %
(len(message), filename), out) (len(message), filename), out)
self.failUnlessIn("On the other computer, please run: " self.failUnlessIn("On the other computer, please run: "
@ -159,7 +159,6 @@ class Scripts(ServerBase, ScriptsBase, unittest.TestCase):
self.failUnlessIn("File sent.. waiting for confirmation\n" self.failUnlessIn("File sent.. waiting for confirmation\n"
"Confirmation received. Transfer complete.\n", "Confirmation received. Transfer complete.\n",
out) out)
self.failUnlessEqual(err, "")
self.failUnlessEqual(rc, 0) self.failUnlessEqual(rc, 0)
return d2 return d2
d1.addCallback(_check_sender) d1.addCallback(_check_sender)

View File

@ -1,6 +1,7 @@
from __future__ import print_function from __future__ import print_function
import sys, json import sys, json
import requests import requests
from six.moves.urllib_parse import urlencode
from twisted.trial import unittest from twisted.trial import unittest
from twisted.internet import reactor, defer from twisted.internet import reactor, defer
from twisted.internet.threads import deferToThread from twisted.internet.threads import deferToThread
@ -55,15 +56,26 @@ def unjson(data):
return json.loads(data.decode("utf-8")) return json.loads(data.decode("utf-8"))
class API(ServerBase, unittest.TestCase): class API(ServerBase, unittest.TestCase):
def get(self, path, is_json=True): def build_url(self, path, appid, channelid):
url = (self.relayurl+path).encode("ascii") url = self.relayurl+path
d = getPage(url) queryargs = []
if is_json: if appid:
d.addCallback(unjson) queryargs.append(("appid", appid))
if channelid:
queryargs.append(("channelid", channelid))
if queryargs:
url += "?" + urlencode(queryargs)
return url
def get(self, path, appid=None, channelid=None):
url = self.build_url(path, appid, channelid)
d = getPage(url.encode("ascii"))
d.addCallback(unjson)
return d return d
def post(self, path, data): def post(self, path, data):
url = (self.relayurl+path).encode("ascii") url = self.relayurl+path
d = getPage(url, method=b"POST", d = getPage(url.encode("ascii"), method=b"POST",
postdata=json.dumps(data).encode("utf-8")) postdata=json.dumps(data).encode("utf-8"))
d.addCallback(unjson) d.addCallback(unjson)
return d return d
@ -73,13 +85,14 @@ class API(ServerBase, unittest.TestCase):
self.failUnlessEqual(data["welcome"], {"current_version": __version__}) self.failUnlessEqual(data["welcome"], {"current_version": __version__})
def test_allocate_1(self): def test_allocate_1(self):
d = self.get("list") d = self.get("list", "app1")
def _check_list_1(data): def _check_list_1(data):
self.check_welcome(data) self.check_welcome(data)
self.failUnlessEqual(data["channelids"], []) self.failUnlessEqual(data["channelids"], [])
d.addCallback(_check_list_1) d.addCallback(_check_list_1)
d.addCallback(lambda _: self.post("allocate", {"side": "abc"})) d.addCallback(lambda _: self.post("allocate", {"appid": "app1",
"side": "abc"}))
def _allocated(data): def _allocated(data):
self.failUnlessEqual(set(data.keys()), self.failUnlessEqual(set(data.keys()),
set(["welcome", "channelid"])) set(["welcome", "channelid"]))
@ -87,18 +100,20 @@ class API(ServerBase, unittest.TestCase):
self.cid = data["channelid"] self.cid = data["channelid"]
d.addCallback(_allocated) d.addCallback(_allocated)
d.addCallback(lambda _: self.get("list")) d.addCallback(lambda _: self.get("list", "app1"))
def _check_list_2(data): def _check_list_2(data):
self.failUnlessEqual(data["channelids"], [self.cid]) self.failUnlessEqual(data["channelids"], [self.cid])
d.addCallback(_check_list_2) d.addCallback(_check_list_2)
d.addCallback(lambda _: self.post("%d/deallocate" % self.cid, d.addCallback(lambda _: self.post("deallocate",
{"side": "abc"})) {"appid": "app1",
"channelid": str(self.cid),
"side": "abc"}))
def _check_deallocate(res): def _check_deallocate(res):
self.failUnlessEqual(res["status"], "deleted") self.failUnlessEqual(res["status"], "deleted")
d.addCallback(_check_deallocate) d.addCallback(_check_deallocate)
d.addCallback(lambda _: self.get("list")) d.addCallback(lambda _: self.get("list", "app1"))
def _check_list_3(data): def _check_list_3(data):
self.failUnlessEqual(data["channelids"], []) self.failUnlessEqual(data["channelids"], [])
d.addCallback(_check_list_3) d.addCallback(_check_list_3)
@ -106,45 +121,57 @@ class API(ServerBase, unittest.TestCase):
return d return d
def test_allocate_2(self): def test_allocate_2(self):
d = self.post("allocate", {"side": "abc"}) d = self.post("allocate", {"appid": "app1", "side": "abc"})
def _allocated(data): def _allocated(data):
self.cid = data["channelid"] self.cid = data["channelid"]
d.addCallback(_allocated) d.addCallback(_allocated)
# second caller increases the number of known sides to 2 # second caller increases the number of known sides to 2
d.addCallback(lambda _: self.post("%d" % self.cid, d.addCallback(lambda _: self.post("add",
{"side": "def", {"appid": "app1",
"channelid": str(self.cid),
"side": "def",
"phase": "1", "phase": "1",
"body": ""})) "body": ""}))
d.addCallback(lambda _: self.get("list")) d.addCallback(lambda _: self.get("list", "app1"))
d.addCallback(lambda data: d.addCallback(lambda data:
self.failUnlessEqual(data["channelids"], [self.cid])) self.failUnlessEqual(data["channelids"], [self.cid]))
d.addCallback(lambda _: self.post("%d/deallocate" % self.cid, d.addCallback(lambda _: self.post("deallocate",
{"side": "abc"})) {"appid": "app1",
"channelid": str(self.cid),
"side": "abc"}))
d.addCallback(lambda res: d.addCallback(lambda res:
self.failUnlessEqual(res["status"], "waiting")) self.failUnlessEqual(res["status"], "waiting"))
d.addCallback(lambda _: self.post("%d/deallocate" % self.cid, d.addCallback(lambda _: self.post("deallocate",
{"side": "NOT"})) {"appid": "app1",
"channelid": str(self.cid),
"side": "NOT"}))
d.addCallback(lambda res: d.addCallback(lambda res:
self.failUnlessEqual(res["status"], "waiting")) self.failUnlessEqual(res["status"], "waiting"))
d.addCallback(lambda _: self.post("%d/deallocate" % self.cid, d.addCallback(lambda _: self.post("deallocate",
{"side": "def"})) {"appid": "app1",
"channelid": str(self.cid),
"side": "def"}))
d.addCallback(lambda res: d.addCallback(lambda res:
self.failUnlessEqual(res["status"], "deleted")) self.failUnlessEqual(res["status"], "deleted"))
d.addCallback(lambda _: self.get("list")) d.addCallback(lambda _: self.get("list", "app1"))
d.addCallback(lambda data: d.addCallback(lambda data:
self.failUnlessEqual(data["channelids"], [])) self.failUnlessEqual(data["channelids"], []))
return d return d
def add_message(self, message, side="abc", phase="1"): def add_message(self, message, side="abc", phase="1"):
return self.post(str(self.cid), {"side": side, "phase": phase, return self.post("add",
"body": message}) {"appid": "app1",
"channelid": str(self.cid),
"side": side,
"phase": phase,
"body": message})
def parse_messages(self, messages): def parse_messages(self, messages):
out = set() out = set()
@ -164,7 +191,7 @@ class API(ServerBase, unittest.TestCase):
self.failUnlessIn(d, two) self.failUnlessIn(d, two)
def test_messages(self): def test_messages(self):
d = self.post("allocate", {"side": "abc"}) d = self.post("allocate", {"appid": "app1", "side": "abc"})
def _allocated(data): def _allocated(data):
self.cid = data["channelid"] self.cid = data["channelid"]
d.addCallback(_allocated) d.addCallback(_allocated)
@ -175,6 +202,8 @@ class API(ServerBase, unittest.TestCase):
self.failUnlessEqual(data["messages"], self.failUnlessEqual(data["messages"],
[{"phase": "1", "body": "msg1A"}]) [{"phase": "1", "body": "msg1A"}])
d.addCallback(_check1) d.addCallback(_check1)
d.addCallback(lambda _: self.get("get", "app1", str(self.cid)))
d.addCallback(_check1)
d.addCallback(lambda _: self.add_message("msg1B", side="def")) d.addCallback(lambda _: self.add_message("msg1B", side="def"))
def _check2(data): def _check2(data):
self.check_welcome(data) self.check_welcome(data)
@ -182,6 +211,8 @@ class API(ServerBase, unittest.TestCase):
set([("1", "msg1A"), set([("1", "msg1A"),
("1", "msg1B")])) ("1", "msg1B")]))
d.addCallback(_check2) d.addCallback(_check2)
d.addCallback(lambda _: self.get("get", "app1", str(self.cid)))
d.addCallback(_check2)
# adding a duplicate message is not an error, is ignored by clients # adding a duplicate message is not an error, is ignored by clients
d.addCallback(lambda _: self.add_message("msg1B", side="def")) d.addCallback(lambda _: self.add_message("msg1B", side="def"))
@ -191,6 +222,8 @@ class API(ServerBase, unittest.TestCase):
set([("1", "msg1A"), set([("1", "msg1A"),
("1", "msg1B")])) ("1", "msg1B")]))
d.addCallback(_check3) d.addCallback(_check3)
d.addCallback(lambda _: self.get("get", "app1", str(self.cid)))
d.addCallback(_check3)
d.addCallback(lambda _: self.add_message("msg2A", side="abc", d.addCallback(lambda _: self.add_message("msg2A", side="abc",
phase="2")) phase="2"))
@ -202,6 +235,8 @@ class API(ServerBase, unittest.TestCase):
("2", "msg2A"), ("2", "msg2A"),
])) ]))
d.addCallback(_check4) d.addCallback(_check4)
d.addCallback(lambda _: self.get("get", "app1", str(self.cid)))
d.addCallback(_check4)
return d return d
@ -209,10 +244,10 @@ class API(ServerBase, unittest.TestCase):
if sys.version_info[0] >= 3: if sys.version_info[0] >= 3:
raise unittest.SkipTest("twisted vs py3") raise unittest.SkipTest("twisted vs py3")
d = self.post("allocate", {"side": "abc"}) d = self.post("allocate", {"appid": "app1", "side": "abc"})
def _allocated(data): def _allocated(data):
self.cid = data["channelid"] self.cid = data["channelid"]
url = self.relayurl+str(self.cid) url = self.build_url("get", "app1", self.cid)
self.o = OneEventAtATime(url, parser=json.loads) self.o = OneEventAtATime(url, parser=json.loads)
return self.o.wait_for_connection() return self.o.wait_for_connection()
d.addCallback(_allocated) d.addCallback(_allocated)

View File

@ -1,11 +1,86 @@
from __future__ import print_function
import sys, json import sys, json
from twisted.trial import unittest from twisted.trial import unittest
from twisted.internet.defer import gatherResults from twisted.internet.defer import gatherResults, succeed
from ..twisted.transcribe import Wormhole, UsageError from ..twisted.transcribe import Wormhole, UsageError, ChannelManager
from .common import ServerBase from .common import ServerBase
APPID = u"appid" APPID = u"appid"
class Channel(ServerBase, unittest.TestCase):
def ignore(self, welcome):
pass
def test_allocate(self):
cm = ChannelManager(self.relayurl, APPID, u"side", self.ignore)
d = cm.list_channels()
def _got_channels(channels):
self.failUnlessEqual(channels, [])
d.addCallback(_got_channels)
d.addCallback(lambda _: cm.allocate())
def _allocated(channelid):
self.failUnlessEqual(type(channelid), int)
self._channelid = channelid
d.addCallback(_allocated)
d.addCallback(lambda _: cm.connect(self._channelid))
def _connected(c):
self._channel = c
d.addCallback(_connected)
d.addCallback(lambda _: self._channel.deallocate())
return d
def test_messages(self):
cm1 = ChannelManager(self.relayurl, APPID, u"side1", self.ignore)
cm2 = ChannelManager(self.relayurl, APPID, u"side2", self.ignore)
c1 = cm1.connect(1)
c2 = cm2.connect(1)
d = succeed(None)
d.addCallback(lambda _: c1.send(u"phase1", b"msg1"))
d.addCallback(lambda _: c2.get(u"phase1"))
d.addCallback(lambda msg: self.failUnlessEqual(msg, b"msg1"))
d.addCallback(lambda _: c2.send(u"phase1", b"msg2"))
d.addCallback(lambda _: c1.get(u"phase1"))
d.addCallback(lambda msg: self.failUnlessEqual(msg, b"msg2"))
# it's legal to fetch a phase multiple times, should be idempotent
d.addCallback(lambda _: c1.get(u"phase1"))
d.addCallback(lambda msg: self.failUnlessEqual(msg, b"msg2"))
# deallocating one side is not enough to destroy the channel
d.addCallback(lambda _: c2.deallocate())
def _not_yet(_):
self._relay_server.prune()
self.failUnlessEqual(len(self._relay_server._apps), 1)
d.addCallback(_not_yet)
# but deallocating both will make the messages go away
d.addCallback(lambda _: c1.deallocate())
def _gone(_):
self._relay_server.prune()
self.failUnlessEqual(len(self._relay_server._apps), 0)
d.addCallback(_gone)
return d
def test_appid_independence(self):
APPID_A = u"appid_A"
APPID_B = u"appid_B"
cm1a = ChannelManager(self.relayurl, APPID_A, u"side1", self.ignore)
cm2a = ChannelManager(self.relayurl, APPID_A, u"side2", self.ignore)
c1a = cm1a.connect(1)
c2a = cm2a.connect(1)
cm1b = ChannelManager(self.relayurl, APPID_B, u"side1", self.ignore)
cm2b = ChannelManager(self.relayurl, APPID_B, u"side2", self.ignore)
c1b = cm1b.connect(1)
c2b = cm2b.connect(1)
d = succeed(None)
d.addCallback(lambda _: c1a.send(u"phase1", b"msg1a"))
d.addCallback(lambda _: c1b.send(u"phase1", b"msg1b"))
d.addCallback(lambda _: c2a.get(u"phase1"))
d.addCallback(lambda msg: self.failUnlessEqual(msg, b"msg1a"))
d.addCallback(lambda _: c2b.get(u"phase1"))
d.addCallback(lambda msg: self.failUnlessEqual(msg, b"msg1b"))
return d
class Basic(ServerBase, unittest.TestCase): class Basic(ServerBase, unittest.TestCase):
def doBoth(self, d1, d2): def doBoth(self, d1, d2):
@ -154,6 +229,7 @@ class Basic(ServerBase, unittest.TestCase):
return d return d
if sys.version_info[0] >= 3: if sys.version_info[0] >= 3:
Channel.skip = "twisted is not yet sufficiently ported to py3"
Basic.skip = "twisted is not yet sufficiently ported to py3" Basic.skip = "twisted is not yet sufficiently ported to py3"
# as of 15.4.0, Twisted is still missing: # as of 15.4.0, Twisted is still missing:
# * web.client.Agent (for all non-EventSource POSTs in transcribe.py) # * web.client.Agent (for all non-EventSource POSTs in transcribe.py)

View File

@ -1,5 +1,6 @@
from __future__ import print_function from __future__ import print_function
import os, sys, json, re, unicodedata import os, sys, json, re, unicodedata
from six.moves.urllib_parse import urlencode
from binascii import hexlify, unhexlify from binascii import hexlify, unhexlify
from zope.interface import implementer from zope.interface import implementer
from twisted.internet import reactor, defer from twisted.internet import reactor, defer
@ -50,10 +51,24 @@ def post_json(agent, url, request_body):
d.addCallback(lambda data: json.loads(data)) d.addCallback(lambda data: json.loads(data))
return d return d
def get_json(agent, url):
# GET from a URL, parsing the response as JSON
d = agent.request("GET", url.encode("utf-8"))
def _check_error(resp):
if resp.code != 200:
raise web_error.Error(resp.code, resp.phrase)
return resp
d.addCallback(_check_error)
d.addCallback(web_client.readBody)
d.addCallback(lambda data: json.loads(data))
return d
class Channel: class Channel:
def __init__(self, relay_url, channelid, side, handle_welcome, def __init__(self, relay_url, appid, channelid, side, handle_welcome,
agent): agent):
self._channel_url = u"%s%d" % (relay_url, channelid) self._relay_url = relay_url
self._appid = appid
self._channelid = channelid
self._side = side self._side = side
self._handle_welcome = handle_welcome self._handle_welcome = handle_welcome
self._agent = agent self._agent = agent
@ -78,10 +93,12 @@ class Channel:
if not isinstance(phase, type(u"")): raise UsageError(type(phase)) if not isinstance(phase, type(u"")): raise UsageError(type(phase))
if not isinstance(msg, type(b"")): raise UsageError(type(msg)) if not isinstance(msg, type(b"")): raise UsageError(type(msg))
self._sent_messages.add( (phase,msg) ) self._sent_messages.add( (phase,msg) )
payload = {"side": self._side, payload = {"appid": self._appid,
"channelid": self._channelid,
"side": self._side,
"phase": phase, "phase": phase,
"body": hexlify(msg).decode("ascii")} "body": hexlify(msg).decode("ascii")}
d = post_json(self._agent, self._channel_url, payload) d = post_json(self._agent, self._relay_url+"add", payload)
d.addCallback(lambda resp: self._add_inbound_messages(resp["messages"])) d.addCallback(lambda resp: self._add_inbound_messages(resp["messages"]))
return d return d
@ -104,7 +121,10 @@ class Channel:
msgs.append(body) msgs.append(body)
d.callback(None) d.callback(None)
# TODO: use agent=self._agent # TODO: use agent=self._agent
es = ReconnectingEventSource(self._channel_url, _handle) queryargs = urlencode([("appid", self._appid),
("channelid", self._channelid)])
es = ReconnectingEventSource(self._relay_url+"get?%s" % queryargs,
_handle)
es.startService() # TODO: .setServiceParent(self) es.startService() # TODO: .setServiceParent(self)
es.activate() es.activate()
d.addCallback(lambda _: es.deactivate()) d.addCallback(lambda _: es.deactivate())
@ -114,22 +134,26 @@ class Channel:
def deallocate(self): def deallocate(self):
# only try once, no retries # only try once, no retries
d = post_json(self._agent, self._channel_url+"/deallocate", d = post_json(self._agent, self._relay_url+"deallocate",
{"side": self._side}) {"appid": self._appid,
"channelid": self._channelid,
"side": self._side})
d.addBoth(lambda _: None) # ignore POST failure d.addBoth(lambda _: None) # ignore POST failure
return d return d
class ChannelManager: class ChannelManager:
def __init__(self, relay_url, side, handle_welcome): def __init__(self, relay, appid, side, handle_welcome):
assert isinstance(relay_url, type(u"")) assert isinstance(relay, type(u""))
self._relay_url = relay_url self._relay = relay
self._appid = appid
self._side = side self._side = side
self._handle_welcome = handle_welcome self._handle_welcome = handle_welcome
self._agent = web_client.Agent(reactor) self._agent = web_client.Agent(reactor)
def allocate(self): def allocate(self):
url = self._relay_url + "allocate" url = self._relay + "allocate"
d = post_json(self._agent, url, {"side": self._side}) d = post_json(self._agent, url, {"appid": self._appid,
"side": self._side})
def _got_channel(data): def _got_channel(data):
if "welcome" in data: if "welcome" in data:
self._handle_welcome(data["welcome"]) self._handle_welcome(data["welcome"])
@ -138,10 +162,14 @@ class ChannelManager:
return d return d
def list_channels(self): def list_channels(self):
raise NotImplementedError queryargs = urlencode([("appid", self._appid)])
url = self._relay + u"list?%s" % queryargs
d = get_json(self._agent, url)
d.addCallback(lambda r: r["channelids"])
return d
def connect(self, channelid): def connect(self, channelid):
return Channel(self._relay_url, channelid, self._side, return Channel(self._relay, self._appid, channelid, self._side,
self._handle_welcome, self._agent) self._handle_welcome, self._agent)
class Wormhole: class Wormhole:
@ -163,17 +191,16 @@ class Wormhole:
def _set_side(self, side): def _set_side(self, side):
self._side = side self._side = side
self._channel_manager = ChannelManager(self._relay_url, self._side, self._channel_manager = ChannelManager(self._relay_url, self._appid,
self.handle_welcome) self._side, self.handle_welcome)
def handle_welcome(self, welcome): def handle_welcome(self, welcome):
if ("motd" in welcome and if ("motd" in welcome and
not self.motd_displayed): not self.motd_displayed):
motd_lines = welcome["motd"].splitlines() motd_lines = welcome["motd"].splitlines()
motd_formatted = "\n ".join(motd_lines) motd_formatted = "\n ".join(motd_lines)
print("Server (at %s) says:\n %s" % (self._relay_url, print("Server (at %s) says:\n %s" %
motd_formatted), (self._relay_url, motd_formatted), file=sys.stderr)
file=sys.stderr)
self.motd_displayed = True self.motd_displayed = True
# Only warn if we're running a release version (e.g. 0.0.6, not # Only warn if we're running a release version (e.g. 0.0.6, not