scope channelids to the appid, change API and DB schema

This requires a DB delete/recreate when upgrading. It changes the server
protocol, and app IDs, so clients cannot interoperate with each other
across this change, nor with the server. Flag day for everyone!

Now apps do not share channel IDs, so a lot of usage of app1 will not
cause the wormhole codes for app2 to get longer.
This commit is contained in:
Brian Warner 2015-10-06 17:20:12 -07:00
parent 8692bd2cd7
commit 574d5f2314
9 changed files with 548 additions and 249 deletions

View File

@ -1,5 +1,6 @@
from __future__ import print_function
import os, sys, time, re, requests, json, unicodedata
from six.moves.urllib_parse import urlencode
from binascii import hexlify, unhexlify
from spake2 import SPAKE2_Symmetric
from nacl.secret import SecretBox
@ -18,19 +19,21 @@ MINUTE = 60*SECOND
def to_bytes(u):
return unicodedata.normalize("NFC", u).encode("utf-8")
# relay URLs are:
# GET /list -> {channelids: [INT..]}
# POST /allocate {side: SIDE} -> {channelid: INT}
# these return all messages (base64) for CID= :
# POST /CID {side:, phase:, body:} -> {messages: [{phase:, body:}..]}
# GET /CID (no-eventsource) -> {messages: [{phase:, body:}..]}
# GET /CID (eventsource) -> {phase:, body:}..
# POST /CID/deallocate {side: SIDE} -> {status: waiting | deleted}
# relay URLs are as follows: (MESSAGES=[{phase:,body:}..])
# GET /list?appid= -> {channelids: [INT..]}
# POST /allocate {appid:,side:} -> {channelid: INT}
# these return all messages (base64) for appid=/channelid= :
# POST /add {appid:,channelid:,side:,phase:,body:} -> {messages: MESSAGES}
# GET /get?appid=&channelid= (no-eventsource) -> {messages: MESSAGES}
# GET /get?appid=&channelid= (eventsource) -> {phase:, body:}..
# POST /deallocate {appid:,channelid:,side:} -> {status: waiting | deleted}
# all JSON responses include a "welcome:{..}" key
class Channel:
def __init__(self, relay_url, channelid, side, handle_welcome):
self._channel_url = u"%s%d" % (relay_url, channelid)
def __init__(self, relay_url, appid, channelid, side, handle_welcome):
self._relay_url = relay_url
self._appid = appid
self._channelid = channelid
self._side = side
self._handle_welcome = handle_welcome
self._messages = set() # (phase,body) , body is bytes
@ -57,11 +60,13 @@ class Channel:
if not isinstance(phase, type(u"")): raise UsageError(type(phase))
if not isinstance(msg, type(b"")): raise UsageError(type(msg))
self._sent_messages.add( (phase,msg) )
payload = {"side": self._side,
payload = {"appid": self._appid,
"channelid": self._channelid,
"side": self._side,
"phase": phase,
"body": hexlify(msg).decode("ascii")}
data = json.dumps(payload).encode("utf-8")
r = requests.post(self._channel_url, data=data)
r = requests.post(self._relay_url+"add", data=data)
r.raise_for_status()
resp = r.json()
self._add_inbound_messages(resp["messages"])
@ -80,7 +85,10 @@ class Channel:
remaining = self._started + self._timeout - time.time()
if remaining < 0:
return Timeout
f = EventSourceFollower(self._channel_url, remaining)
queryargs = urlencode([("appid", self._appid),
("channelid", self._channelid)])
f = EventSourceFollower(self._relay_url+"get?%s" % queryargs,
remaining)
# we loop here until the connection is lost, or we see the
# message we want
for (eventtype, data) in f.iter_events():
@ -98,25 +106,30 @@ class Channel:
def deallocate(self):
# only try once, no retries
data = json.dumps({"side": self._side}).encode("utf-8")
requests.post(self._channel_url+"/deallocate", data=data)
data = json.dumps({"appid": self._appid,
"channelid": self._channelid,
"side": self._side}).encode("utf-8")
requests.post(self._relay_url+"deallocate", data=data)
# ignore POST failure, don't call r.raise_for_status()
class ChannelManager:
def __init__(self, relay_url, side, handle_welcome):
def __init__(self, relay_url, appid, side, handle_welcome):
self._relay_url = relay_url
self._appid = appid
self._side = side
self._handle_welcome = handle_welcome
def list_channels(self):
r = requests.get(self._relay_url + "list")
queryargs = urlencode([("appid", self._appid)])
r = requests.get(self._relay_url+"list?%s" % queryargs)
r.raise_for_status()
channelids = r.json()["channelids"]
return channelids
def allocate(self):
data = json.dumps({"side": self._side}).encode("utf-8")
r = requests.post(self._relay_url + "allocate", data=data)
data = json.dumps({"appid": self._appid,
"side": self._side}).encode("utf-8")
r = requests.post(self._relay_url+"allocate", data=data)
r.raise_for_status()
data = r.json()
if "welcome" in data:
@ -125,7 +138,7 @@ class ChannelManager:
return channelid
def connect(self, channelid):
return Channel(self._relay_url, channelid, self._side,
return Channel(self._relay_url, self._appid, channelid, self._side,
self._handle_welcome)
class Wormhole:
@ -139,7 +152,7 @@ class Wormhole:
self._appid = appid
self._relay_url = relay_url
side = hexlify(os.urandom(5)).decode("ascii")
self._channel_manager = ChannelManager(relay_url, side,
self._channel_manager = ChannelManager(relay_url, appid, side,
self.handle_welcome)
self.code = None
self.key = None
@ -152,8 +165,7 @@ class Wormhole:
not self.motd_displayed):
motd_lines = welcome["motd"].splitlines()
motd_formatted = "\n ".join(motd_lines)
print("Server (at %s) says:\n %s" % (self._relay_url,
motd_formatted),
print("Server (at %s) says:\n %s" % (self._relay_url, motd_formatted),
file=sys.stderr)
self.motd_displayed = True

View File

@ -9,16 +9,18 @@ CREATE TABLE `version`
CREATE TABLE `messages`
(
`appid` VARCHAR,
`channelid` INTEGER,
`side` VARCHAR,
`phase` VARCHAR, -- not numeric, more of a PAKE-phase indicator string
`body` VARCHAR,
`when` INTEGER
);
CREATE INDEX `messages_idx` ON `messages` (`channelid`, `side`, `phase`);
CREATE INDEX `messages_idx` ON `messages` (`appid`, `channelid`);
CREATE TABLE `allocations`
(
`appid` VARCHAR,
`channelid` INTEGER,
`side` VARCHAR
);

View File

@ -1,8 +1,8 @@
from __future__ import print_function
import re, json, time, random
import json, time, random
from twisted.python import log
from twisted.application import service, internet
from twisted.web import server, resource, http
from twisted.web import server, resource
SECONDS = 1.0
MINUTE = 60*SECONDS
@ -13,6 +13,10 @@ MB = 1000*1000
CHANNEL_EXPIRATION_TIME = 3*DAY
EXPIRATION_CHECK_PERIOD = 2*HOUR
def json_response(request, data):
request.setHeader(b"content-type", b"application/json; charset=utf-8")
return (json.dumps(data)+"\n").encode("utf-8")
class EventsProtocol:
def __init__(self, request):
self.request = request
@ -46,109 +50,185 @@ class EventsProtocol:
# note: no versions of IE (including the current IE11) support EventSource
# relay URLs are:
# GET /list -> {channelids: [INT..]}
# POST /allocate {side: SIDE} -> {channelid: INT}
# these return all messages (base64) for CID= :
# POST /CID {side:, phase:, body:} -> {messages: [{phase:, body:}..]}
# GET /CID (no-eventsource) -> {messages: [{phase:, body:}..]}
# GET /CID (eventsource) -> {phase:, body:}..
# POST /CID/deallocate {side: SIDE} -> {status: waiting | deleted}
# relay URLs are as follows: (MESSAGES=[{phase:,body:}..])
# GET /list?appid= -> {channelids: [INT..]}
# POST /allocate {appid:,side:} -> {channelid: INT}
# these return all messages (base64) for appid=/channelid= :
# POST /add {appid:,channelid:,side:,phase:,body:} -> {messages: MESSAGES}
# GET /get?appid=&channelid= (no-eventsource) -> {messages: MESSAGES}
# GET /get?appid=&channelid= (eventsource) -> {phase:, body:}..
# POST /deallocate {appid:,channelid:,side:} -> {status: waiting | deleted}
# all JSON responses include a "welcome:{..}" key
class Channel(resource.Resource):
def __init__(self, channelid, relay, db, welcome):
class ChannelLister(resource.Resource):
def __init__(self, relay):
resource.Resource.__init__(self)
self.channelid = channelid
self.relay = relay
self.db = db
self.welcome = welcome
self.event_channels = set() # ep
self.putChild(b"deallocate", Deallocator(self.channelid, self.relay))
def get_messages(self, request):
request.setHeader(b"content-type", b"application/json; charset=utf-8")
messages = []
for row in self.db.execute("SELECT * FROM `messages`"
" WHERE `channelid`=?"
" ORDER BY `when` ASC",
(self.channelid,)).fetchall():
messages.append({"phase": row["phase"], "body": row["body"]})
data = {"welcome": self.welcome, "messages": messages}
return (json.dumps(data)+"\n").encode("utf-8")
self._relay = relay
def render_GET(self, request):
if b"text/event-stream" not in (request.getHeader(b"accept") or b""):
return self.get_messages(request)
request.setHeader(b"content-type", b"text/event-stream; charset=utf-8")
ep = EventsProtocol(request)
ep.sendEvent(json.dumps(self.welcome), name="welcome")
self.event_channels.add(ep)
request.notifyFinish().addErrback(lambda f:
self.event_channels.discard(ep))
for row in self.db.execute("SELECT * FROM `messages`"
" WHERE `channelid`=?"
" ORDER BY `when` ASC",
(self.channelid,)).fetchall():
data = json.dumps({"phase": row["phase"], "body": row["body"]})
ep.sendEvent(data)
return server.NOT_DONE_YET
appid = request.args[b"appid"][0].decode("utf-8")
#print("LIST", appid)
app = self._relay.get_app(appid)
allocated = app.get_allocated()
request.setHeader(b"content-type", b"application/json; charset=utf-8")
data = {"welcome": self._relay.welcome,
"channelids": sorted(allocated)}
return (json.dumps(data)+"\n").encode("utf-8")
def broadcast_message(self, phase, body):
data = json.dumps({"phase": phase, "body": body})
for ep in self.event_channels:
ep.sendEvent(data)
class Allocator(resource.Resource):
def __init__(self, relay):
resource.Resource.__init__(self)
self._relay = relay
def render_POST(self, request):
#data = json.load(request.content, encoding="utf-8")
content = request.content.read()
data = json.loads(content.decode("utf-8"))
appid = data["appid"]
side = data["side"]
if not isinstance(side, type(u"")):
raise TypeError("side must be string, not '%s'" % type(side))
#print("ALLOCATE", appid, side)
app = self._relay.get_app(appid)
channelid = app.find_available_channelid()
app.allocate_channel(channelid, side)
log.msg("allocated #%d, now have %d DB channels" %
(channelid, len(app.get_allocated())))
request.setHeader(b"content-type", b"application/json; charset=utf-8")
data = {"welcome": self._relay.welcome,
"channelid": channelid}
return (json.dumps(data)+"\n").encode("utf-8")
class Adder(resource.Resource):
def __init__(self, relay):
resource.Resource.__init__(self)
self._relay = relay
def render_POST(self, request):
#content = json.load(request.content, encoding="utf-8")
content = request.content.read()
data = json.loads(content.decode("utf-8"))
appid = data["appid"]
channelid = int(data["channelid"])
side = data["side"]
phase = data["phase"]
if not isinstance(phase, type(u"")):
raise TypeError("phase must be string, not %s" % type(phase))
body = data["body"]
#print("ADD", appid, channelid, side, phase, body)
self.db.execute("INSERT INTO `messages`"
" (`channelid`, `side`, `phase`, `body`, `when`)"
" VALUES (?,?,?,?,?)",
(self.channelid, side, phase, body, time.time()))
self.db.execute("INSERT INTO `allocations`"
" (`channelid`, `side`)"
" VALUES (?,?)",
(self.channelid, side))
self.db.commit()
self.broadcast_message(phase, body)
return self.get_messages(request)
app = self._relay.get_app(appid)
channel = app.get_channel(channelid)
response = channel.add_message(side, phase, body)
return json_response(request, response)
class Getter(resource.Resource):
def __init__(self, relay):
self._relay = relay
def render_GET(self, request):
appid = request.args[b"appid"][0].decode("utf-8")
channelid = int(request.args[b"channelid"][0])
#print("GET", appid, channelid)
app = self._relay.get_app(appid)
channel = app.get_channel(channelid)
if b"text/event-stream" not in (request.getHeader(b"accept") or b""):
response = channel.get_messages()
return json_response(request, response)
request.setHeader(b"content-type", b"text/event-stream; charset=utf-8")
ep = EventsProtocol(request)
ep.sendEvent(json.dumps(self._relay.welcome), name="welcome")
old_events = channel.add_listener(ep.sendEvent)
request.notifyFinish().addErrback(lambda f:
channel.remove_listener(ep.sendEvent))
for old_event in old_events:
ep.sendEvent(old_event)
return server.NOT_DONE_YET
class Deallocator(resource.Resource):
def __init__(self, channelid, relay):
self.channelid = channelid
self.relay = relay
def __init__(self, relay):
self._relay = relay
def render_POST(self, request):
content = request.content.read()
data = json.loads(content.decode("utf-8"))
appid = data["appid"]
channelid = int(data["channelid"])
side = data["side"]
deleted = self.relay.maybe_free_child(self.channelid, side)
resp = {"status": "waiting"}
#print("DEALLOCATE", appid, channelid, side)
app = self._relay.get_app(appid)
deleted = app.maybe_free_child(channelid, side)
response = {"status": "waiting"}
if deleted:
resp = {"status": "deleted"}
return json.dumps(resp).encode("utf-8")
response = {"status": "deleted"}
return json_response(request, response)
def get_allocated(db):
c = db.execute("SELECT DISTINCT `channelid` FROM `allocations`")
return set([row["channelid"] for row in c.fetchall()])
class Allocator(resource.Resource):
def __init__(self, db, welcome):
class Channel(resource.Resource):
def __init__(self, relay, appid, channelid):
resource.Resource.__init__(self)
self.db = db
self.welcome = welcome
self._relay = relay
self._appid = appid
self._channelid = channelid
self._listeners = set() # callbacks that take JSONable object
def allocate_channelid(self):
allocated = get_allocated(self.db)
def get_messages(self):
messages = []
db = self._relay.db
for row in db.execute("SELECT * FROM `messages`"
" WHERE `appid`=? AND `channelid`=?"
" ORDER BY `when` ASC",
(self._appid, self._channelid)).fetchall():
messages.append({"phase": row["phase"], "body": row["body"]})
data = {"welcome": self._relay.welcome, "messages": messages}
return data
def add_listener(self, listener):
self._listeners.add(listener)
db = self._relay.db
for row in db.execute("SELECT * FROM `messages`"
" WHERE `appid`=? AND `channelid`=?"
" ORDER BY `when` ASC",
(self._appid, self._channelid)).fetchall():
yield json.dumps({"phase": row["phase"], "body": row["body"]})
def remove_listener(self, listener):
self._listeners.discard(listener)
def broadcast_message(self, phase, body):
data = json.dumps({"phase": phase, "body": body})
for listener in self._listeners:
listener(data)
def add_message(self, side, phase, body):
db = self._relay.db
db.execute("INSERT INTO `messages`"
" (`appid`, `channelid`, `side`, `phase`, `body`, `when`)"
" VALUES (?,?,?,?, ?,?)",
(self._appid, self._channelid, side, phase,
body, time.time()))
db.execute("INSERT INTO `allocations`"
" (`appid`, `channelid`, `side`)"
" VALUES (?,?,?)",
(self._appid, self._channelid, side))
db.commit()
self.broadcast_message(phase, body)
return self.get_messages()
class AppNamespace(resource.Resource):
def __init__(self, relay, appid):
resource.Resource.__init__(self)
self._relay = relay
self._appid = appid
self._channels = {}
def get_allocated(self):
db = self._relay.db
c = db.execute("SELECT DISTINCT `channelid` FROM `allocations`"
" WHERE `appid`=?", (self._appid,))
return set([row["channelid"] for row in c.fetchall()])
def find_available_channelid(self):
allocated = self.get_allocated()
for size in range(1,4): # stick to 1-999 for now
available = set()
for cid in range(10**(size-1), 10**size):
@ -161,37 +241,63 @@ class Allocator(resource.Resource):
cid = random.randrange(1000, 1000*1000)
if cid not in allocated:
return cid
raise ValueError("unable to find a free channelid")
raise ValueError("unable to find a free channel-id")
def render_POST(self, request):
content = request.content.read()
data = json.loads(content.decode("utf-8"))
side = data["side"]
if not isinstance(side, type(u"")):
raise TypeError("side must be string, not '%s'" % type(side))
channelid = self.allocate_channelid()
self.db.execute("INSERT INTO `allocations` VALUES (?,?)",
(channelid, side))
self.db.commit()
log.msg("allocated #%d, now have %d DB channels" %
(channelid, len(get_allocated(self.db))))
request.setHeader(b"content-type", b"application/json; charset=utf-8")
data = {"welcome": self.welcome,
"channelid": channelid}
return (json.dumps(data)+"\n").encode("utf-8")
def allocate_channel(self, channelid, side):
db = self._relay.db
db.execute("INSERT INTO `allocations` VALUES (?,?,?)",
(self._appid, channelid, side))
db.commit()
class ChannelList(resource.Resource):
def __init__(self, db, welcome):
resource.Resource.__init__(self)
self.db = db
self.welcome = welcome
def render_GET(self, request):
c = self.db.execute("SELECT DISTINCT `channelid` FROM `allocations`")
allocated = sorted(set([row["channelid"] for row in c.fetchall()]))
request.setHeader(b"content-type", b"application/json; charset=utf-8")
data = {"welcome": self.welcome,
"channelids": allocated}
return (json.dumps(data)+"\n").encode("utf-8")
def get_channel(self, channelid):
assert isinstance(channelid, int)
if not channelid in self._channels:
log.msg("spawning #%d for appid %s" % (channelid, self._appid))
self._channels[channelid] = Channel(self._relay,
self._appid, channelid)
return self._channels[channelid]
def maybe_free_child(self, channelid, side):
db = self._relay.db
db.execute("DELETE FROM `allocations`"
" WHERE `appid`=? AND `channelid`=? AND `side`=?",
(self._appid, channelid, side))
db.commit()
remaining = db.execute("SELECT COUNT(*) FROM `allocations`"
" WHERE `appid`=? AND `channelid`=?",
(self._appid, channelid)).fetchone()[0]
if remaining:
return False
self._free_child(channelid)
return True
def _free_child(self, channelid):
db = self._relay.db
db.execute("DELETE FROM `allocations`"
" WHERE `appid`=? AND `channelid`=?",
(self._appid, channelid))
db.execute("DELETE FROM `messages`"
" WHERE `appid`=? AND `channelid`=?",
(self._appid, channelid))
db.commit()
if channelid in self._channels:
self._channels.pop(channelid)
log.msg("freed+killed #%d, now have %d DB channels, %d live" %
(channelid, len(self.get_allocated()), len(self._channels)))
def prune_old_channels(self):
db = self._relay.db
old = time.time() - CHANNEL_EXPIRATION_TIME
for channelid in self.get_allocated():
c = db.execute("SELECT `when` FROM `messages`"
" WHERE `appid`=? AND `channelid`=?"
" ORDER BY `when` DESC LIMIT 1",
(self._appid, channelid))
rows = c.fetchall()
if not rows or (rows[0]["when"] < old):
log.msg("expiring %d" % channelid)
self._free_child(channelid)
return bool(self._channels)
class Relay(resource.Resource, service.MultiService):
def __init__(self, db, welcome):
@ -199,60 +305,24 @@ class Relay(resource.Resource, service.MultiService):
service.MultiService.__init__(self)
self.db = db
self.welcome = welcome
self.channels = {}
t = internet.TimerService(EXPIRATION_CHECK_PERIOD,
self.prune_old_channels)
self._apps = {}
t = internet.TimerService(EXPIRATION_CHECK_PERIOD, self.prune)
t.setServiceParent(self)
self.putChild(b"list", ChannelLister(self))
self.putChild(b"allocate", Allocator(self))
self.putChild(b"add", Adder(self))
self.putChild(b"get", Getter(self))
self.putChild(b"deallocate", Deallocator(self))
def get_app(self, appid):
assert isinstance(appid, type(u""))
if not appid in self._apps:
log.msg("spawning appid %s" % (appid,))
self._apps[appid] = AppNamespace(self, appid)
return self._apps[appid]
def getChild(self, path, request):
if path == b"allocate":
return Allocator(self.db, self.welcome)
if path == b"list":
return ChannelList(self.db, self.welcome)
if not re.search(br'^\d+$', path):
return resource.ErrorPage(http.BAD_REQUEST,
"invalid channel id",
"invalid channel id")
channelid = int(path)
if not channelid in self.channels:
log.msg("spawning #%d" % channelid)
self.channels[channelid] = Channel(channelid, self, self.db,
self.welcome)
return self.channels[channelid]
def maybe_free_child(self, channelid, side):
self.db.execute("DELETE FROM `allocations`"
" WHERE `channelid`=? AND `side`=?",
(channelid, side))
self.db.commit()
remaining = self.db.execute("SELECT COUNT(*) FROM `allocations`"
" WHERE `channelid`=?",
(channelid,)).fetchone()[0]
if remaining:
return False
self.free_child(channelid)
return True
def free_child(self, channelid):
self.db.execute("DELETE FROM `allocations` WHERE `channelid`=?",
(channelid,))
self.db.execute("DELETE FROM `messages` WHERE `channelid`=?",
(channelid,))
self.db.commit()
if channelid in self.channels:
self.channels.pop(channelid)
log.msg("freed+killed #%d, now have %d DB channels, %d live" %
(channelid, len(get_allocated(self.db)), len(self.channels)))
def prune_old_channels(self):
old = time.time() - CHANNEL_EXPIRATION_TIME
for channelid in get_allocated(self.db):
c = self.db.execute("SELECT `when` FROM `messages`"
" WHERE `channelid`=?"
" ORDER BY `when` DESC LIMIT 1", (channelid,))
rows = c.fetchall()
if not rows or (rows[0]["when"] < old):
log.msg("expiring %d" % channelid)
self.free_child(channelid)
def prune(self):
for appid in list(self._apps):
still_active = self._apps[appid].prune_old_channels()
if not still_active:
self._apps.pop(appid)

View File

@ -14,6 +14,7 @@ class ServerBase:
"tcp:%s:interface=127.0.0.1" % transitport,
__version__)
s.setServiceParent(self.sp)
self._relay_server = s.relay
self.relayurl = u"http://127.0.0.1:%d/wormhole-relay/" % relayport
self.transit = "tcp:127.0.0.1:%d" % transitport
d.addCallback(_got_ports)

View File

@ -1,12 +1,89 @@
from __future__ import print_function
import json
from twisted.trial import unittest
from twisted.internet.defer import gatherResults
from twisted.internet.defer import gatherResults, succeed
from twisted.internet.threads import deferToThread
from ..blocking.transcribe import Wormhole as BlockingWormhole, UsageError
from ..blocking.transcribe import (Wormhole as BlockingWormhole, UsageError,
ChannelManager)
from .common import ServerBase
APPID = u"appid"
class Channel(ServerBase, unittest.TestCase):
def ignore(self, welcome):
pass
def test_allocate(self):
cm = ChannelManager(self.relayurl, APPID, u"side", self.ignore)
d = deferToThread(cm.list_channels)
def _got_channels(channels):
self.failUnlessEqual(channels, [])
d.addCallback(_got_channels)
d.addCallback(lambda _: deferToThread(cm.allocate))
def _allocated(channelid):
self.failUnlessEqual(type(channelid), int)
self._channelid = channelid
d.addCallback(_allocated)
d.addCallback(lambda _: deferToThread(cm.connect, self._channelid))
def _connected(c):
self._channel = c
d.addCallback(_connected)
d.addCallback(lambda _: deferToThread(self._channel.deallocate))
return d
def test_messages(self):
cm1 = ChannelManager(self.relayurl, APPID, u"side1", self.ignore)
cm2 = ChannelManager(self.relayurl, APPID, u"side2", self.ignore)
c1 = cm1.connect(1)
c2 = cm2.connect(1)
d = succeed(None)
d.addCallback(lambda _: deferToThread(c1.send, u"phase1", b"msg1"))
d.addCallback(lambda _: deferToThread(c2.get, u"phase1"))
d.addCallback(lambda msg: self.failUnlessEqual(msg, b"msg1"))
d.addCallback(lambda _: deferToThread(c2.send, u"phase1", b"msg2"))
d.addCallback(lambda _: deferToThread(c1.get, u"phase1"))
d.addCallback(lambda msg: self.failUnlessEqual(msg, b"msg2"))
# it's legal to fetch a phase multiple times, should be idempotent
d.addCallback(lambda _: deferToThread(c1.get, u"phase1"))
d.addCallback(lambda msg: self.failUnlessEqual(msg, b"msg2"))
# deallocating one side is not enough to destroy the channel
d.addCallback(lambda _: deferToThread(c2.deallocate))
def _not_yet(_):
self._relay_server.prune()
self.failUnlessEqual(len(self._relay_server._apps), 1)
d.addCallback(_not_yet)
# but deallocating both will make the messages go away
d.addCallback(lambda _: deferToThread(c1.deallocate))
def _gone(_):
self._relay_server.prune()
self.failUnlessEqual(len(self._relay_server._apps), 0)
d.addCallback(_gone)
return d
def test_appid_independence(self):
APPID_A = u"appid_A"
APPID_B = u"appid_B"
cm1a = ChannelManager(self.relayurl, APPID_A, u"side1", self.ignore)
cm2a = ChannelManager(self.relayurl, APPID_A, u"side2", self.ignore)
c1a = cm1a.connect(1)
c2a = cm2a.connect(1)
cm1b = ChannelManager(self.relayurl, APPID_B, u"side1", self.ignore)
cm2b = ChannelManager(self.relayurl, APPID_B, u"side2", self.ignore)
c1b = cm1b.connect(1)
c2b = cm2b.connect(1)
d = succeed(None)
d.addCallback(lambda _: deferToThread(c1a.send, u"phase1", b"msg1a"))
d.addCallback(lambda _: deferToThread(c1b.send, u"phase1", b"msg1b"))
d.addCallback(lambda _: deferToThread(c2a.get, u"phase1"))
d.addCallback(lambda msg: self.failUnlessEqual(msg, b"msg1a"))
d.addCallback(lambda _: deferToThread(c2b.get, u"phase1"))
d.addCallback(lambda msg: self.failUnlessEqual(msg, b"msg1b"))
return d
class Blocking(ServerBase, unittest.TestCase):
# we need Twisted to run the server, but we run the sender and receiver
# with deferToThread()

View File

@ -99,28 +99,27 @@ class Scripts(ServerBase, ScriptsBase, unittest.TestCase):
out, err, rc = res
out = out.decode("utf-8")
err = err.decode("utf-8")
self.failUnlessEqual(out,
"Sending text message (%d bytes)\n"
"On the other computer, please run: "
"wormhole receive\n"
"Wormhole code is: %s\n\n"
"text message sent\n" % (len(message), code)
)
self.failUnlessEqual(err, "")
self.failUnlessEqual(rc, 0)
self.maxDiff = None
expected = ("Sending text message (%d bytes)\n"
"On the other computer, please run: "
"wormhole receive\n"
"Wormhole code is: %s\n\n"
"text message sent\n" % (len(message), code))
self.failUnlessEqual( (expected, "", 0),
(out, err, rc) )
return d2
d1.addCallback(_check_sender)
def _check_receiver(res):
out, err, rc = res
out = out.decode("utf-8")
err = err.decode("utf-8")
self.failUnlessEqual(out, message+"\n")
self.failUnlessEqual(err, "")
self.failUnlessEqual(rc, 0)
self.failUnlessEqual( (message+"\n", "", 0),
(out, err, rc) )
d1.addCallback(_check_receiver)
return d1
def test_send_file_pre_generated_code(self):
self.maxDiff=None
code = "1-abc"
filename = "testfile"
message = "test message"
@ -150,6 +149,7 @@ class Scripts(ServerBase, ScriptsBase, unittest.TestCase):
out, err, rc = res
out = out.decode("utf-8")
err = err.decode("utf-8")
self.failUnlessEqual(err, "")
self.failUnlessIn("Sending %d byte file named '%s'\n" %
(len(message), filename), out)
self.failUnlessIn("On the other computer, please run: "
@ -159,7 +159,6 @@ class Scripts(ServerBase, ScriptsBase, unittest.TestCase):
self.failUnlessIn("File sent.. waiting for confirmation\n"
"Confirmation received. Transfer complete.\n",
out)
self.failUnlessEqual(err, "")
self.failUnlessEqual(rc, 0)
return d2
d1.addCallback(_check_sender)

View File

@ -1,6 +1,7 @@
from __future__ import print_function
import sys, json
import requests
from six.moves.urllib_parse import urlencode
from twisted.trial import unittest
from twisted.internet import reactor, defer
from twisted.internet.threads import deferToThread
@ -55,15 +56,26 @@ def unjson(data):
return json.loads(data.decode("utf-8"))
class API(ServerBase, unittest.TestCase):
def get(self, path, is_json=True):
url = (self.relayurl+path).encode("ascii")
d = getPage(url)
if is_json:
d.addCallback(unjson)
def build_url(self, path, appid, channelid):
url = self.relayurl+path
queryargs = []
if appid:
queryargs.append(("appid", appid))
if channelid:
queryargs.append(("channelid", channelid))
if queryargs:
url += "?" + urlencode(queryargs)
return url
def get(self, path, appid=None, channelid=None):
url = self.build_url(path, appid, channelid)
d = getPage(url.encode("ascii"))
d.addCallback(unjson)
return d
def post(self, path, data):
url = (self.relayurl+path).encode("ascii")
d = getPage(url, method=b"POST",
url = self.relayurl+path
d = getPage(url.encode("ascii"), method=b"POST",
postdata=json.dumps(data).encode("utf-8"))
d.addCallback(unjson)
return d
@ -73,13 +85,14 @@ class API(ServerBase, unittest.TestCase):
self.failUnlessEqual(data["welcome"], {"current_version": __version__})
def test_allocate_1(self):
d = self.get("list")
d = self.get("list", "app1")
def _check_list_1(data):
self.check_welcome(data)
self.failUnlessEqual(data["channelids"], [])
d.addCallback(_check_list_1)
d.addCallback(lambda _: self.post("allocate", {"side": "abc"}))
d.addCallback(lambda _: self.post("allocate", {"appid": "app1",
"side": "abc"}))
def _allocated(data):
self.failUnlessEqual(set(data.keys()),
set(["welcome", "channelid"]))
@ -87,18 +100,20 @@ class API(ServerBase, unittest.TestCase):
self.cid = data["channelid"]
d.addCallback(_allocated)
d.addCallback(lambda _: self.get("list"))
d.addCallback(lambda _: self.get("list", "app1"))
def _check_list_2(data):
self.failUnlessEqual(data["channelids"], [self.cid])
d.addCallback(_check_list_2)
d.addCallback(lambda _: self.post("%d/deallocate" % self.cid,
{"side": "abc"}))
d.addCallback(lambda _: self.post("deallocate",
{"appid": "app1",
"channelid": str(self.cid),
"side": "abc"}))
def _check_deallocate(res):
self.failUnlessEqual(res["status"], "deleted")
d.addCallback(_check_deallocate)
d.addCallback(lambda _: self.get("list"))
d.addCallback(lambda _: self.get("list", "app1"))
def _check_list_3(data):
self.failUnlessEqual(data["channelids"], [])
d.addCallback(_check_list_3)
@ -106,45 +121,57 @@ class API(ServerBase, unittest.TestCase):
return d
def test_allocate_2(self):
d = self.post("allocate", {"side": "abc"})
d = self.post("allocate", {"appid": "app1", "side": "abc"})
def _allocated(data):
self.cid = data["channelid"]
d.addCallback(_allocated)
# second caller increases the number of known sides to 2
d.addCallback(lambda _: self.post("%d" % self.cid,
{"side": "def",
d.addCallback(lambda _: self.post("add",
{"appid": "app1",
"channelid": str(self.cid),
"side": "def",
"phase": "1",
"body": ""}))
d.addCallback(lambda _: self.get("list"))
d.addCallback(lambda _: self.get("list", "app1"))
d.addCallback(lambda data:
self.failUnlessEqual(data["channelids"], [self.cid]))
d.addCallback(lambda _: self.post("%d/deallocate" % self.cid,
{"side": "abc"}))
d.addCallback(lambda _: self.post("deallocate",
{"appid": "app1",
"channelid": str(self.cid),
"side": "abc"}))
d.addCallback(lambda res:
self.failUnlessEqual(res["status"], "waiting"))
d.addCallback(lambda _: self.post("%d/deallocate" % self.cid,
{"side": "NOT"}))
d.addCallback(lambda _: self.post("deallocate",
{"appid": "app1",
"channelid": str(self.cid),
"side": "NOT"}))
d.addCallback(lambda res:
self.failUnlessEqual(res["status"], "waiting"))
d.addCallback(lambda _: self.post("%d/deallocate" % self.cid,
{"side": "def"}))
d.addCallback(lambda _: self.post("deallocate",
{"appid": "app1",
"channelid": str(self.cid),
"side": "def"}))
d.addCallback(lambda res:
self.failUnlessEqual(res["status"], "deleted"))
d.addCallback(lambda _: self.get("list"))
d.addCallback(lambda _: self.get("list", "app1"))
d.addCallback(lambda data:
self.failUnlessEqual(data["channelids"], []))
return d
def add_message(self, message, side="abc", phase="1"):
return self.post(str(self.cid), {"side": side, "phase": phase,
"body": message})
return self.post("add",
{"appid": "app1",
"channelid": str(self.cid),
"side": side,
"phase": phase,
"body": message})
def parse_messages(self, messages):
out = set()
@ -164,7 +191,7 @@ class API(ServerBase, unittest.TestCase):
self.failUnlessIn(d, two)
def test_messages(self):
d = self.post("allocate", {"side": "abc"})
d = self.post("allocate", {"appid": "app1", "side": "abc"})
def _allocated(data):
self.cid = data["channelid"]
d.addCallback(_allocated)
@ -175,6 +202,8 @@ class API(ServerBase, unittest.TestCase):
self.failUnlessEqual(data["messages"],
[{"phase": "1", "body": "msg1A"}])
d.addCallback(_check1)
d.addCallback(lambda _: self.get("get", "app1", str(self.cid)))
d.addCallback(_check1)
d.addCallback(lambda _: self.add_message("msg1B", side="def"))
def _check2(data):
self.check_welcome(data)
@ -182,6 +211,8 @@ class API(ServerBase, unittest.TestCase):
set([("1", "msg1A"),
("1", "msg1B")]))
d.addCallback(_check2)
d.addCallback(lambda _: self.get("get", "app1", str(self.cid)))
d.addCallback(_check2)
# adding a duplicate message is not an error, is ignored by clients
d.addCallback(lambda _: self.add_message("msg1B", side="def"))
@ -191,6 +222,8 @@ class API(ServerBase, unittest.TestCase):
set([("1", "msg1A"),
("1", "msg1B")]))
d.addCallback(_check3)
d.addCallback(lambda _: self.get("get", "app1", str(self.cid)))
d.addCallback(_check3)
d.addCallback(lambda _: self.add_message("msg2A", side="abc",
phase="2"))
@ -202,6 +235,8 @@ class API(ServerBase, unittest.TestCase):
("2", "msg2A"),
]))
d.addCallback(_check4)
d.addCallback(lambda _: self.get("get", "app1", str(self.cid)))
d.addCallback(_check4)
return d
@ -209,10 +244,10 @@ class API(ServerBase, unittest.TestCase):
if sys.version_info[0] >= 3:
raise unittest.SkipTest("twisted vs py3")
d = self.post("allocate", {"side": "abc"})
d = self.post("allocate", {"appid": "app1", "side": "abc"})
def _allocated(data):
self.cid = data["channelid"]
url = self.relayurl+str(self.cid)
url = self.build_url("get", "app1", self.cid)
self.o = OneEventAtATime(url, parser=json.loads)
return self.o.wait_for_connection()
d.addCallback(_allocated)

View File

@ -1,11 +1,86 @@
from __future__ import print_function
import sys, json
from twisted.trial import unittest
from twisted.internet.defer import gatherResults
from ..twisted.transcribe import Wormhole, UsageError
from twisted.internet.defer import gatherResults, succeed
from ..twisted.transcribe import Wormhole, UsageError, ChannelManager
from .common import ServerBase
APPID = u"appid"
class Channel(ServerBase, unittest.TestCase):
def ignore(self, welcome):
pass
def test_allocate(self):
cm = ChannelManager(self.relayurl, APPID, u"side", self.ignore)
d = cm.list_channels()
def _got_channels(channels):
self.failUnlessEqual(channels, [])
d.addCallback(_got_channels)
d.addCallback(lambda _: cm.allocate())
def _allocated(channelid):
self.failUnlessEqual(type(channelid), int)
self._channelid = channelid
d.addCallback(_allocated)
d.addCallback(lambda _: cm.connect(self._channelid))
def _connected(c):
self._channel = c
d.addCallback(_connected)
d.addCallback(lambda _: self._channel.deallocate())
return d
def test_messages(self):
cm1 = ChannelManager(self.relayurl, APPID, u"side1", self.ignore)
cm2 = ChannelManager(self.relayurl, APPID, u"side2", self.ignore)
c1 = cm1.connect(1)
c2 = cm2.connect(1)
d = succeed(None)
d.addCallback(lambda _: c1.send(u"phase1", b"msg1"))
d.addCallback(lambda _: c2.get(u"phase1"))
d.addCallback(lambda msg: self.failUnlessEqual(msg, b"msg1"))
d.addCallback(lambda _: c2.send(u"phase1", b"msg2"))
d.addCallback(lambda _: c1.get(u"phase1"))
d.addCallback(lambda msg: self.failUnlessEqual(msg, b"msg2"))
# it's legal to fetch a phase multiple times, should be idempotent
d.addCallback(lambda _: c1.get(u"phase1"))
d.addCallback(lambda msg: self.failUnlessEqual(msg, b"msg2"))
# deallocating one side is not enough to destroy the channel
d.addCallback(lambda _: c2.deallocate())
def _not_yet(_):
self._relay_server.prune()
self.failUnlessEqual(len(self._relay_server._apps), 1)
d.addCallback(_not_yet)
# but deallocating both will make the messages go away
d.addCallback(lambda _: c1.deallocate())
def _gone(_):
self._relay_server.prune()
self.failUnlessEqual(len(self._relay_server._apps), 0)
d.addCallback(_gone)
return d
def test_appid_independence(self):
APPID_A = u"appid_A"
APPID_B = u"appid_B"
cm1a = ChannelManager(self.relayurl, APPID_A, u"side1", self.ignore)
cm2a = ChannelManager(self.relayurl, APPID_A, u"side2", self.ignore)
c1a = cm1a.connect(1)
c2a = cm2a.connect(1)
cm1b = ChannelManager(self.relayurl, APPID_B, u"side1", self.ignore)
cm2b = ChannelManager(self.relayurl, APPID_B, u"side2", self.ignore)
c1b = cm1b.connect(1)
c2b = cm2b.connect(1)
d = succeed(None)
d.addCallback(lambda _: c1a.send(u"phase1", b"msg1a"))
d.addCallback(lambda _: c1b.send(u"phase1", b"msg1b"))
d.addCallback(lambda _: c2a.get(u"phase1"))
d.addCallback(lambda msg: self.failUnlessEqual(msg, b"msg1a"))
d.addCallback(lambda _: c2b.get(u"phase1"))
d.addCallback(lambda msg: self.failUnlessEqual(msg, b"msg1b"))
return d
class Basic(ServerBase, unittest.TestCase):
def doBoth(self, d1, d2):
@ -154,6 +229,7 @@ class Basic(ServerBase, unittest.TestCase):
return d
if sys.version_info[0] >= 3:
Channel.skip = "twisted is not yet sufficiently ported to py3"
Basic.skip = "twisted is not yet sufficiently ported to py3"
# as of 15.4.0, Twisted is still missing:
# * web.client.Agent (for all non-EventSource POSTs in transcribe.py)

View File

@ -1,5 +1,6 @@
from __future__ import print_function
import os, sys, json, re, unicodedata
from six.moves.urllib_parse import urlencode
from binascii import hexlify, unhexlify
from zope.interface import implementer
from twisted.internet import reactor, defer
@ -50,10 +51,24 @@ def post_json(agent, url, request_body):
d.addCallback(lambda data: json.loads(data))
return d
def get_json(agent, url):
# GET from a URL, parsing the response as JSON
d = agent.request("GET", url.encode("utf-8"))
def _check_error(resp):
if resp.code != 200:
raise web_error.Error(resp.code, resp.phrase)
return resp
d.addCallback(_check_error)
d.addCallback(web_client.readBody)
d.addCallback(lambda data: json.loads(data))
return d
class Channel:
def __init__(self, relay_url, channelid, side, handle_welcome,
def __init__(self, relay_url, appid, channelid, side, handle_welcome,
agent):
self._channel_url = u"%s%d" % (relay_url, channelid)
self._relay_url = relay_url
self._appid = appid
self._channelid = channelid
self._side = side
self._handle_welcome = handle_welcome
self._agent = agent
@ -78,10 +93,12 @@ class Channel:
if not isinstance(phase, type(u"")): raise UsageError(type(phase))
if not isinstance(msg, type(b"")): raise UsageError(type(msg))
self._sent_messages.add( (phase,msg) )
payload = {"side": self._side,
payload = {"appid": self._appid,
"channelid": self._channelid,
"side": self._side,
"phase": phase,
"body": hexlify(msg).decode("ascii")}
d = post_json(self._agent, self._channel_url, payload)
d = post_json(self._agent, self._relay_url+"add", payload)
d.addCallback(lambda resp: self._add_inbound_messages(resp["messages"]))
return d
@ -104,7 +121,10 @@ class Channel:
msgs.append(body)
d.callback(None)
# TODO: use agent=self._agent
es = ReconnectingEventSource(self._channel_url, _handle)
queryargs = urlencode([("appid", self._appid),
("channelid", self._channelid)])
es = ReconnectingEventSource(self._relay_url+"get?%s" % queryargs,
_handle)
es.startService() # TODO: .setServiceParent(self)
es.activate()
d.addCallback(lambda _: es.deactivate())
@ -114,22 +134,26 @@ class Channel:
def deallocate(self):
# only try once, no retries
d = post_json(self._agent, self._channel_url+"/deallocate",
{"side": self._side})
d = post_json(self._agent, self._relay_url+"deallocate",
{"appid": self._appid,
"channelid": self._channelid,
"side": self._side})
d.addBoth(lambda _: None) # ignore POST failure
return d
class ChannelManager:
def __init__(self, relay_url, side, handle_welcome):
assert isinstance(relay_url, type(u""))
self._relay_url = relay_url
def __init__(self, relay, appid, side, handle_welcome):
assert isinstance(relay, type(u""))
self._relay = relay
self._appid = appid
self._side = side
self._handle_welcome = handle_welcome
self._agent = web_client.Agent(reactor)
def allocate(self):
url = self._relay_url + "allocate"
d = post_json(self._agent, url, {"side": self._side})
url = self._relay + "allocate"
d = post_json(self._agent, url, {"appid": self._appid,
"side": self._side})
def _got_channel(data):
if "welcome" in data:
self._handle_welcome(data["welcome"])
@ -138,10 +162,14 @@ class ChannelManager:
return d
def list_channels(self):
raise NotImplementedError
queryargs = urlencode([("appid", self._appid)])
url = self._relay + u"list?%s" % queryargs
d = get_json(self._agent, url)
d.addCallback(lambda r: r["channelids"])
return d
def connect(self, channelid):
return Channel(self._relay_url, channelid, self._side,
return Channel(self._relay, self._appid, channelid, self._side,
self._handle_welcome, self._agent)
class Wormhole:
@ -163,17 +191,16 @@ class Wormhole:
def _set_side(self, side):
self._side = side
self._channel_manager = ChannelManager(self._relay_url, self._side,
self.handle_welcome)
self._channel_manager = ChannelManager(self._relay_url, self._appid,
self._side, self.handle_welcome)
def handle_welcome(self, welcome):
if ("motd" in welcome and
not self.motd_displayed):
motd_lines = welcome["motd"].splitlines()
motd_formatted = "\n ".join(motd_lines)
print("Server (at %s) says:\n %s" % (self._relay_url,
motd_formatted),
file=sys.stderr)
print("Server (at %s) says:\n %s" %
(self._relay_url, motd_formatted), file=sys.stderr)
self.motd_displayed = True
# Only warn if we're running a release version (e.g. 0.0.6, not