scope channelids to the appid, change API and DB schema
This requires a DB delete/recreate when upgrading. It changes the server protocol, and app IDs, so clients cannot interoperate with each other across this change, nor with the server. Flag day for everyone! Now apps do not share channel IDs, so a lot of usage of app1 will not cause the wormhole codes for app2 to get longer.
This commit is contained in:
parent
8692bd2cd7
commit
574d5f2314
|
@ -1,5 +1,6 @@
|
|||
from __future__ import print_function
|
||||
import os, sys, time, re, requests, json, unicodedata
|
||||
from six.moves.urllib_parse import urlencode
|
||||
from binascii import hexlify, unhexlify
|
||||
from spake2 import SPAKE2_Symmetric
|
||||
from nacl.secret import SecretBox
|
||||
|
@ -18,19 +19,21 @@ MINUTE = 60*SECOND
|
|||
def to_bytes(u):
|
||||
return unicodedata.normalize("NFC", u).encode("utf-8")
|
||||
|
||||
# relay URLs are:
|
||||
# GET /list -> {channelids: [INT..]}
|
||||
# POST /allocate {side: SIDE} -> {channelid: INT}
|
||||
# these return all messages (base64) for CID= :
|
||||
# POST /CID {side:, phase:, body:} -> {messages: [{phase:, body:}..]}
|
||||
# GET /CID (no-eventsource) -> {messages: [{phase:, body:}..]}
|
||||
# GET /CID (eventsource) -> {phase:, body:}..
|
||||
# POST /CID/deallocate {side: SIDE} -> {status: waiting | deleted}
|
||||
# relay URLs are as follows: (MESSAGES=[{phase:,body:}..])
|
||||
# GET /list?appid= -> {channelids: [INT..]}
|
||||
# POST /allocate {appid:,side:} -> {channelid: INT}
|
||||
# these return all messages (base64) for appid=/channelid= :
|
||||
# POST /add {appid:,channelid:,side:,phase:,body:} -> {messages: MESSAGES}
|
||||
# GET /get?appid=&channelid= (no-eventsource) -> {messages: MESSAGES}
|
||||
# GET /get?appid=&channelid= (eventsource) -> {phase:, body:}..
|
||||
# POST /deallocate {appid:,channelid:,side:} -> {status: waiting | deleted}
|
||||
# all JSON responses include a "welcome:{..}" key
|
||||
|
||||
class Channel:
|
||||
def __init__(self, relay_url, channelid, side, handle_welcome):
|
||||
self._channel_url = u"%s%d" % (relay_url, channelid)
|
||||
def __init__(self, relay_url, appid, channelid, side, handle_welcome):
|
||||
self._relay_url = relay_url
|
||||
self._appid = appid
|
||||
self._channelid = channelid
|
||||
self._side = side
|
||||
self._handle_welcome = handle_welcome
|
||||
self._messages = set() # (phase,body) , body is bytes
|
||||
|
@ -57,11 +60,13 @@ class Channel:
|
|||
if not isinstance(phase, type(u"")): raise UsageError(type(phase))
|
||||
if not isinstance(msg, type(b"")): raise UsageError(type(msg))
|
||||
self._sent_messages.add( (phase,msg) )
|
||||
payload = {"side": self._side,
|
||||
payload = {"appid": self._appid,
|
||||
"channelid": self._channelid,
|
||||
"side": self._side,
|
||||
"phase": phase,
|
||||
"body": hexlify(msg).decode("ascii")}
|
||||
data = json.dumps(payload).encode("utf-8")
|
||||
r = requests.post(self._channel_url, data=data)
|
||||
r = requests.post(self._relay_url+"add", data=data)
|
||||
r.raise_for_status()
|
||||
resp = r.json()
|
||||
self._add_inbound_messages(resp["messages"])
|
||||
|
@ -80,7 +85,10 @@ class Channel:
|
|||
remaining = self._started + self._timeout - time.time()
|
||||
if remaining < 0:
|
||||
return Timeout
|
||||
f = EventSourceFollower(self._channel_url, remaining)
|
||||
queryargs = urlencode([("appid", self._appid),
|
||||
("channelid", self._channelid)])
|
||||
f = EventSourceFollower(self._relay_url+"get?%s" % queryargs,
|
||||
remaining)
|
||||
# we loop here until the connection is lost, or we see the
|
||||
# message we want
|
||||
for (eventtype, data) in f.iter_events():
|
||||
|
@ -98,25 +106,30 @@ class Channel:
|
|||
|
||||
def deallocate(self):
|
||||
# only try once, no retries
|
||||
data = json.dumps({"side": self._side}).encode("utf-8")
|
||||
requests.post(self._channel_url+"/deallocate", data=data)
|
||||
data = json.dumps({"appid": self._appid,
|
||||
"channelid": self._channelid,
|
||||
"side": self._side}).encode("utf-8")
|
||||
requests.post(self._relay_url+"deallocate", data=data)
|
||||
# ignore POST failure, don't call r.raise_for_status()
|
||||
|
||||
class ChannelManager:
|
||||
def __init__(self, relay_url, side, handle_welcome):
|
||||
def __init__(self, relay_url, appid, side, handle_welcome):
|
||||
self._relay_url = relay_url
|
||||
self._appid = appid
|
||||
self._side = side
|
||||
self._handle_welcome = handle_welcome
|
||||
|
||||
def list_channels(self):
|
||||
r = requests.get(self._relay_url + "list")
|
||||
queryargs = urlencode([("appid", self._appid)])
|
||||
r = requests.get(self._relay_url+"list?%s" % queryargs)
|
||||
r.raise_for_status()
|
||||
channelids = r.json()["channelids"]
|
||||
return channelids
|
||||
|
||||
def allocate(self):
|
||||
data = json.dumps({"side": self._side}).encode("utf-8")
|
||||
r = requests.post(self._relay_url + "allocate", data=data)
|
||||
data = json.dumps({"appid": self._appid,
|
||||
"side": self._side}).encode("utf-8")
|
||||
r = requests.post(self._relay_url+"allocate", data=data)
|
||||
r.raise_for_status()
|
||||
data = r.json()
|
||||
if "welcome" in data:
|
||||
|
@ -125,7 +138,7 @@ class ChannelManager:
|
|||
return channelid
|
||||
|
||||
def connect(self, channelid):
|
||||
return Channel(self._relay_url, channelid, self._side,
|
||||
return Channel(self._relay_url, self._appid, channelid, self._side,
|
||||
self._handle_welcome)
|
||||
|
||||
class Wormhole:
|
||||
|
@ -139,7 +152,7 @@ class Wormhole:
|
|||
self._appid = appid
|
||||
self._relay_url = relay_url
|
||||
side = hexlify(os.urandom(5)).decode("ascii")
|
||||
self._channel_manager = ChannelManager(relay_url, side,
|
||||
self._channel_manager = ChannelManager(relay_url, appid, side,
|
||||
self.handle_welcome)
|
||||
self.code = None
|
||||
self.key = None
|
||||
|
@ -152,8 +165,7 @@ class Wormhole:
|
|||
not self.motd_displayed):
|
||||
motd_lines = welcome["motd"].splitlines()
|
||||
motd_formatted = "\n ".join(motd_lines)
|
||||
print("Server (at %s) says:\n %s" % (self._relay_url,
|
||||
motd_formatted),
|
||||
print("Server (at %s) says:\n %s" % (self._relay_url, motd_formatted),
|
||||
file=sys.stderr)
|
||||
self.motd_displayed = True
|
||||
|
||||
|
|
|
@ -9,16 +9,18 @@ CREATE TABLE `version`
|
|||
|
||||
CREATE TABLE `messages`
|
||||
(
|
||||
`appid` VARCHAR,
|
||||
`channelid` INTEGER,
|
||||
`side` VARCHAR,
|
||||
`phase` VARCHAR, -- not numeric, more of a PAKE-phase indicator string
|
||||
`body` VARCHAR,
|
||||
`when` INTEGER
|
||||
);
|
||||
CREATE INDEX `messages_idx` ON `messages` (`channelid`, `side`, `phase`);
|
||||
CREATE INDEX `messages_idx` ON `messages` (`appid`, `channelid`);
|
||||
|
||||
CREATE TABLE `allocations`
|
||||
(
|
||||
`appid` VARCHAR,
|
||||
`channelid` INTEGER,
|
||||
`side` VARCHAR
|
||||
);
|
||||
|
|
|
@ -1,8 +1,8 @@
|
|||
from __future__ import print_function
|
||||
import re, json, time, random
|
||||
import json, time, random
|
||||
from twisted.python import log
|
||||
from twisted.application import service, internet
|
||||
from twisted.web import server, resource, http
|
||||
from twisted.web import server, resource
|
||||
|
||||
SECONDS = 1.0
|
||||
MINUTE = 60*SECONDS
|
||||
|
@ -13,6 +13,10 @@ MB = 1000*1000
|
|||
CHANNEL_EXPIRATION_TIME = 3*DAY
|
||||
EXPIRATION_CHECK_PERIOD = 2*HOUR
|
||||
|
||||
def json_response(request, data):
|
||||
request.setHeader(b"content-type", b"application/json; charset=utf-8")
|
||||
return (json.dumps(data)+"\n").encode("utf-8")
|
||||
|
||||
class EventsProtocol:
|
||||
def __init__(self, request):
|
||||
self.request = request
|
||||
|
@ -46,109 +50,185 @@ class EventsProtocol:
|
|||
|
||||
# note: no versions of IE (including the current IE11) support EventSource
|
||||
|
||||
# relay URLs are:
|
||||
# GET /list -> {channelids: [INT..]}
|
||||
# POST /allocate {side: SIDE} -> {channelid: INT}
|
||||
# these return all messages (base64) for CID= :
|
||||
# POST /CID {side:, phase:, body:} -> {messages: [{phase:, body:}..]}
|
||||
# GET /CID (no-eventsource) -> {messages: [{phase:, body:}..]}
|
||||
# GET /CID (eventsource) -> {phase:, body:}..
|
||||
# POST /CID/deallocate {side: SIDE} -> {status: waiting | deleted}
|
||||
# relay URLs are as follows: (MESSAGES=[{phase:,body:}..])
|
||||
# GET /list?appid= -> {channelids: [INT..]}
|
||||
# POST /allocate {appid:,side:} -> {channelid: INT}
|
||||
# these return all messages (base64) for appid=/channelid= :
|
||||
# POST /add {appid:,channelid:,side:,phase:,body:} -> {messages: MESSAGES}
|
||||
# GET /get?appid=&channelid= (no-eventsource) -> {messages: MESSAGES}
|
||||
# GET /get?appid=&channelid= (eventsource) -> {phase:, body:}..
|
||||
# POST /deallocate {appid:,channelid:,side:} -> {status: waiting | deleted}
|
||||
# all JSON responses include a "welcome:{..}" key
|
||||
|
||||
class Channel(resource.Resource):
|
||||
def __init__(self, channelid, relay, db, welcome):
|
||||
class ChannelLister(resource.Resource):
|
||||
def __init__(self, relay):
|
||||
resource.Resource.__init__(self)
|
||||
self.channelid = channelid
|
||||
self.relay = relay
|
||||
self.db = db
|
||||
self.welcome = welcome
|
||||
self.event_channels = set() # ep
|
||||
self.putChild(b"deallocate", Deallocator(self.channelid, self.relay))
|
||||
|
||||
def get_messages(self, request):
|
||||
request.setHeader(b"content-type", b"application/json; charset=utf-8")
|
||||
messages = []
|
||||
for row in self.db.execute("SELECT * FROM `messages`"
|
||||
" WHERE `channelid`=?"
|
||||
" ORDER BY `when` ASC",
|
||||
(self.channelid,)).fetchall():
|
||||
messages.append({"phase": row["phase"], "body": row["body"]})
|
||||
data = {"welcome": self.welcome, "messages": messages}
|
||||
return (json.dumps(data)+"\n").encode("utf-8")
|
||||
self._relay = relay
|
||||
|
||||
def render_GET(self, request):
|
||||
if b"text/event-stream" not in (request.getHeader(b"accept") or b""):
|
||||
return self.get_messages(request)
|
||||
request.setHeader(b"content-type", b"text/event-stream; charset=utf-8")
|
||||
ep = EventsProtocol(request)
|
||||
ep.sendEvent(json.dumps(self.welcome), name="welcome")
|
||||
self.event_channels.add(ep)
|
||||
request.notifyFinish().addErrback(lambda f:
|
||||
self.event_channels.discard(ep))
|
||||
for row in self.db.execute("SELECT * FROM `messages`"
|
||||
" WHERE `channelid`=?"
|
||||
" ORDER BY `when` ASC",
|
||||
(self.channelid,)).fetchall():
|
||||
data = json.dumps({"phase": row["phase"], "body": row["body"]})
|
||||
ep.sendEvent(data)
|
||||
return server.NOT_DONE_YET
|
||||
appid = request.args[b"appid"][0].decode("utf-8")
|
||||
#print("LIST", appid)
|
||||
app = self._relay.get_app(appid)
|
||||
allocated = app.get_allocated()
|
||||
request.setHeader(b"content-type", b"application/json; charset=utf-8")
|
||||
data = {"welcome": self._relay.welcome,
|
||||
"channelids": sorted(allocated)}
|
||||
return (json.dumps(data)+"\n").encode("utf-8")
|
||||
|
||||
def broadcast_message(self, phase, body):
|
||||
data = json.dumps({"phase": phase, "body": body})
|
||||
for ep in self.event_channels:
|
||||
ep.sendEvent(data)
|
||||
class Allocator(resource.Resource):
|
||||
def __init__(self, relay):
|
||||
resource.Resource.__init__(self)
|
||||
self._relay = relay
|
||||
|
||||
def render_POST(self, request):
|
||||
#data = json.load(request.content, encoding="utf-8")
|
||||
content = request.content.read()
|
||||
data = json.loads(content.decode("utf-8"))
|
||||
appid = data["appid"]
|
||||
side = data["side"]
|
||||
if not isinstance(side, type(u"")):
|
||||
raise TypeError("side must be string, not '%s'" % type(side))
|
||||
#print("ALLOCATE", appid, side)
|
||||
app = self._relay.get_app(appid)
|
||||
channelid = app.find_available_channelid()
|
||||
app.allocate_channel(channelid, side)
|
||||
log.msg("allocated #%d, now have %d DB channels" %
|
||||
(channelid, len(app.get_allocated())))
|
||||
request.setHeader(b"content-type", b"application/json; charset=utf-8")
|
||||
data = {"welcome": self._relay.welcome,
|
||||
"channelid": channelid}
|
||||
return (json.dumps(data)+"\n").encode("utf-8")
|
||||
|
||||
class Adder(resource.Resource):
|
||||
def __init__(self, relay):
|
||||
resource.Resource.__init__(self)
|
||||
self._relay = relay
|
||||
|
||||
def render_POST(self, request):
|
||||
#content = json.load(request.content, encoding="utf-8")
|
||||
content = request.content.read()
|
||||
data = json.loads(content.decode("utf-8"))
|
||||
appid = data["appid"]
|
||||
channelid = int(data["channelid"])
|
||||
side = data["side"]
|
||||
phase = data["phase"]
|
||||
if not isinstance(phase, type(u"")):
|
||||
raise TypeError("phase must be string, not %s" % type(phase))
|
||||
body = data["body"]
|
||||
#print("ADD", appid, channelid, side, phase, body)
|
||||
|
||||
self.db.execute("INSERT INTO `messages`"
|
||||
" (`channelid`, `side`, `phase`, `body`, `when`)"
|
||||
" VALUES (?,?,?,?,?)",
|
||||
(self.channelid, side, phase, body, time.time()))
|
||||
self.db.execute("INSERT INTO `allocations`"
|
||||
" (`channelid`, `side`)"
|
||||
" VALUES (?,?)",
|
||||
(self.channelid, side))
|
||||
self.db.commit()
|
||||
self.broadcast_message(phase, body)
|
||||
return self.get_messages(request)
|
||||
app = self._relay.get_app(appid)
|
||||
channel = app.get_channel(channelid)
|
||||
response = channel.add_message(side, phase, body)
|
||||
return json_response(request, response)
|
||||
|
||||
class Getter(resource.Resource):
|
||||
def __init__(self, relay):
|
||||
self._relay = relay
|
||||
|
||||
def render_GET(self, request):
|
||||
appid = request.args[b"appid"][0].decode("utf-8")
|
||||
channelid = int(request.args[b"channelid"][0])
|
||||
#print("GET", appid, channelid)
|
||||
app = self._relay.get_app(appid)
|
||||
channel = app.get_channel(channelid)
|
||||
|
||||
if b"text/event-stream" not in (request.getHeader(b"accept") or b""):
|
||||
response = channel.get_messages()
|
||||
return json_response(request, response)
|
||||
|
||||
request.setHeader(b"content-type", b"text/event-stream; charset=utf-8")
|
||||
ep = EventsProtocol(request)
|
||||
ep.sendEvent(json.dumps(self._relay.welcome), name="welcome")
|
||||
old_events = channel.add_listener(ep.sendEvent)
|
||||
request.notifyFinish().addErrback(lambda f:
|
||||
channel.remove_listener(ep.sendEvent))
|
||||
for old_event in old_events:
|
||||
ep.sendEvent(old_event)
|
||||
return server.NOT_DONE_YET
|
||||
|
||||
class Deallocator(resource.Resource):
|
||||
def __init__(self, channelid, relay):
|
||||
self.channelid = channelid
|
||||
self.relay = relay
|
||||
def __init__(self, relay):
|
||||
self._relay = relay
|
||||
|
||||
def render_POST(self, request):
|
||||
content = request.content.read()
|
||||
data = json.loads(content.decode("utf-8"))
|
||||
appid = data["appid"]
|
||||
channelid = int(data["channelid"])
|
||||
side = data["side"]
|
||||
deleted = self.relay.maybe_free_child(self.channelid, side)
|
||||
resp = {"status": "waiting"}
|
||||
#print("DEALLOCATE", appid, channelid, side)
|
||||
app = self._relay.get_app(appid)
|
||||
deleted = app.maybe_free_child(channelid, side)
|
||||
response = {"status": "waiting"}
|
||||
if deleted:
|
||||
resp = {"status": "deleted"}
|
||||
return json.dumps(resp).encode("utf-8")
|
||||
response = {"status": "deleted"}
|
||||
return json_response(request, response)
|
||||
|
||||
def get_allocated(db):
|
||||
c = db.execute("SELECT DISTINCT `channelid` FROM `allocations`")
|
||||
return set([row["channelid"] for row in c.fetchall()])
|
||||
|
||||
class Allocator(resource.Resource):
|
||||
def __init__(self, db, welcome):
|
||||
class Channel(resource.Resource):
|
||||
def __init__(self, relay, appid, channelid):
|
||||
resource.Resource.__init__(self)
|
||||
self.db = db
|
||||
self.welcome = welcome
|
||||
self._relay = relay
|
||||
self._appid = appid
|
||||
self._channelid = channelid
|
||||
self._listeners = set() # callbacks that take JSONable object
|
||||
|
||||
def allocate_channelid(self):
|
||||
allocated = get_allocated(self.db)
|
||||
def get_messages(self):
|
||||
messages = []
|
||||
db = self._relay.db
|
||||
for row in db.execute("SELECT * FROM `messages`"
|
||||
" WHERE `appid`=? AND `channelid`=?"
|
||||
" ORDER BY `when` ASC",
|
||||
(self._appid, self._channelid)).fetchall():
|
||||
messages.append({"phase": row["phase"], "body": row["body"]})
|
||||
data = {"welcome": self._relay.welcome, "messages": messages}
|
||||
return data
|
||||
|
||||
def add_listener(self, listener):
|
||||
self._listeners.add(listener)
|
||||
db = self._relay.db
|
||||
for row in db.execute("SELECT * FROM `messages`"
|
||||
" WHERE `appid`=? AND `channelid`=?"
|
||||
" ORDER BY `when` ASC",
|
||||
(self._appid, self._channelid)).fetchall():
|
||||
yield json.dumps({"phase": row["phase"], "body": row["body"]})
|
||||
def remove_listener(self, listener):
|
||||
self._listeners.discard(listener)
|
||||
|
||||
def broadcast_message(self, phase, body):
|
||||
data = json.dumps({"phase": phase, "body": body})
|
||||
for listener in self._listeners:
|
||||
listener(data)
|
||||
|
||||
def add_message(self, side, phase, body):
|
||||
db = self._relay.db
|
||||
db.execute("INSERT INTO `messages`"
|
||||
" (`appid`, `channelid`, `side`, `phase`, `body`, `when`)"
|
||||
" VALUES (?,?,?,?, ?,?)",
|
||||
(self._appid, self._channelid, side, phase,
|
||||
body, time.time()))
|
||||
db.execute("INSERT INTO `allocations`"
|
||||
" (`appid`, `channelid`, `side`)"
|
||||
" VALUES (?,?,?)",
|
||||
(self._appid, self._channelid, side))
|
||||
db.commit()
|
||||
self.broadcast_message(phase, body)
|
||||
return self.get_messages()
|
||||
|
||||
class AppNamespace(resource.Resource):
|
||||
def __init__(self, relay, appid):
|
||||
resource.Resource.__init__(self)
|
||||
self._relay = relay
|
||||
self._appid = appid
|
||||
self._channels = {}
|
||||
|
||||
def get_allocated(self):
|
||||
db = self._relay.db
|
||||
c = db.execute("SELECT DISTINCT `channelid` FROM `allocations`"
|
||||
" WHERE `appid`=?", (self._appid,))
|
||||
return set([row["channelid"] for row in c.fetchall()])
|
||||
|
||||
def find_available_channelid(self):
|
||||
allocated = self.get_allocated()
|
||||
for size in range(1,4): # stick to 1-999 for now
|
||||
available = set()
|
||||
for cid in range(10**(size-1), 10**size):
|
||||
|
@ -161,37 +241,63 @@ class Allocator(resource.Resource):
|
|||
cid = random.randrange(1000, 1000*1000)
|
||||
if cid not in allocated:
|
||||
return cid
|
||||
raise ValueError("unable to find a free channelid")
|
||||
raise ValueError("unable to find a free channel-id")
|
||||
|
||||
def render_POST(self, request):
|
||||
content = request.content.read()
|
||||
data = json.loads(content.decode("utf-8"))
|
||||
side = data["side"]
|
||||
if not isinstance(side, type(u"")):
|
||||
raise TypeError("side must be string, not '%s'" % type(side))
|
||||
channelid = self.allocate_channelid()
|
||||
self.db.execute("INSERT INTO `allocations` VALUES (?,?)",
|
||||
(channelid, side))
|
||||
self.db.commit()
|
||||
log.msg("allocated #%d, now have %d DB channels" %
|
||||
(channelid, len(get_allocated(self.db))))
|
||||
request.setHeader(b"content-type", b"application/json; charset=utf-8")
|
||||
data = {"welcome": self.welcome,
|
||||
"channelid": channelid}
|
||||
return (json.dumps(data)+"\n").encode("utf-8")
|
||||
def allocate_channel(self, channelid, side):
|
||||
db = self._relay.db
|
||||
db.execute("INSERT INTO `allocations` VALUES (?,?,?)",
|
||||
(self._appid, channelid, side))
|
||||
db.commit()
|
||||
|
||||
class ChannelList(resource.Resource):
|
||||
def __init__(self, db, welcome):
|
||||
resource.Resource.__init__(self)
|
||||
self.db = db
|
||||
self.welcome = welcome
|
||||
def render_GET(self, request):
|
||||
c = self.db.execute("SELECT DISTINCT `channelid` FROM `allocations`")
|
||||
allocated = sorted(set([row["channelid"] for row in c.fetchall()]))
|
||||
request.setHeader(b"content-type", b"application/json; charset=utf-8")
|
||||
data = {"welcome": self.welcome,
|
||||
"channelids": allocated}
|
||||
return (json.dumps(data)+"\n").encode("utf-8")
|
||||
def get_channel(self, channelid):
|
||||
assert isinstance(channelid, int)
|
||||
if not channelid in self._channels:
|
||||
log.msg("spawning #%d for appid %s" % (channelid, self._appid))
|
||||
self._channels[channelid] = Channel(self._relay,
|
||||
self._appid, channelid)
|
||||
return self._channels[channelid]
|
||||
|
||||
def maybe_free_child(self, channelid, side):
|
||||
db = self._relay.db
|
||||
db.execute("DELETE FROM `allocations`"
|
||||
" WHERE `appid`=? AND `channelid`=? AND `side`=?",
|
||||
(self._appid, channelid, side))
|
||||
db.commit()
|
||||
remaining = db.execute("SELECT COUNT(*) FROM `allocations`"
|
||||
" WHERE `appid`=? AND `channelid`=?",
|
||||
(self._appid, channelid)).fetchone()[0]
|
||||
if remaining:
|
||||
return False
|
||||
self._free_child(channelid)
|
||||
return True
|
||||
|
||||
def _free_child(self, channelid):
|
||||
db = self._relay.db
|
||||
db.execute("DELETE FROM `allocations`"
|
||||
" WHERE `appid`=? AND `channelid`=?",
|
||||
(self._appid, channelid))
|
||||
db.execute("DELETE FROM `messages`"
|
||||
" WHERE `appid`=? AND `channelid`=?",
|
||||
(self._appid, channelid))
|
||||
db.commit()
|
||||
if channelid in self._channels:
|
||||
self._channels.pop(channelid)
|
||||
log.msg("freed+killed #%d, now have %d DB channels, %d live" %
|
||||
(channelid, len(self.get_allocated()), len(self._channels)))
|
||||
|
||||
def prune_old_channels(self):
|
||||
db = self._relay.db
|
||||
old = time.time() - CHANNEL_EXPIRATION_TIME
|
||||
for channelid in self.get_allocated():
|
||||
c = db.execute("SELECT `when` FROM `messages`"
|
||||
" WHERE `appid`=? AND `channelid`=?"
|
||||
" ORDER BY `when` DESC LIMIT 1",
|
||||
(self._appid, channelid))
|
||||
rows = c.fetchall()
|
||||
if not rows or (rows[0]["when"] < old):
|
||||
log.msg("expiring %d" % channelid)
|
||||
self._free_child(channelid)
|
||||
return bool(self._channels)
|
||||
|
||||
class Relay(resource.Resource, service.MultiService):
|
||||
def __init__(self, db, welcome):
|
||||
|
@ -199,60 +305,24 @@ class Relay(resource.Resource, service.MultiService):
|
|||
service.MultiService.__init__(self)
|
||||
self.db = db
|
||||
self.welcome = welcome
|
||||
self.channels = {}
|
||||
t = internet.TimerService(EXPIRATION_CHECK_PERIOD,
|
||||
self.prune_old_channels)
|
||||
self._apps = {}
|
||||
t = internet.TimerService(EXPIRATION_CHECK_PERIOD, self.prune)
|
||||
t.setServiceParent(self)
|
||||
self.putChild(b"list", ChannelLister(self))
|
||||
self.putChild(b"allocate", Allocator(self))
|
||||
self.putChild(b"add", Adder(self))
|
||||
self.putChild(b"get", Getter(self))
|
||||
self.putChild(b"deallocate", Deallocator(self))
|
||||
|
||||
def get_app(self, appid):
|
||||
assert isinstance(appid, type(u""))
|
||||
if not appid in self._apps:
|
||||
log.msg("spawning appid %s" % (appid,))
|
||||
self._apps[appid] = AppNamespace(self, appid)
|
||||
return self._apps[appid]
|
||||
|
||||
def getChild(self, path, request):
|
||||
if path == b"allocate":
|
||||
return Allocator(self.db, self.welcome)
|
||||
if path == b"list":
|
||||
return ChannelList(self.db, self.welcome)
|
||||
if not re.search(br'^\d+$', path):
|
||||
return resource.ErrorPage(http.BAD_REQUEST,
|
||||
"invalid channel id",
|
||||
"invalid channel id")
|
||||
channelid = int(path)
|
||||
if not channelid in self.channels:
|
||||
log.msg("spawning #%d" % channelid)
|
||||
self.channels[channelid] = Channel(channelid, self, self.db,
|
||||
self.welcome)
|
||||
return self.channels[channelid]
|
||||
|
||||
def maybe_free_child(self, channelid, side):
|
||||
self.db.execute("DELETE FROM `allocations`"
|
||||
" WHERE `channelid`=? AND `side`=?",
|
||||
(channelid, side))
|
||||
self.db.commit()
|
||||
remaining = self.db.execute("SELECT COUNT(*) FROM `allocations`"
|
||||
" WHERE `channelid`=?",
|
||||
(channelid,)).fetchone()[0]
|
||||
if remaining:
|
||||
return False
|
||||
self.free_child(channelid)
|
||||
return True
|
||||
|
||||
def free_child(self, channelid):
|
||||
self.db.execute("DELETE FROM `allocations` WHERE `channelid`=?",
|
||||
(channelid,))
|
||||
self.db.execute("DELETE FROM `messages` WHERE `channelid`=?",
|
||||
(channelid,))
|
||||
self.db.commit()
|
||||
if channelid in self.channels:
|
||||
self.channels.pop(channelid)
|
||||
log.msg("freed+killed #%d, now have %d DB channels, %d live" %
|
||||
(channelid, len(get_allocated(self.db)), len(self.channels)))
|
||||
|
||||
def prune_old_channels(self):
|
||||
old = time.time() - CHANNEL_EXPIRATION_TIME
|
||||
for channelid in get_allocated(self.db):
|
||||
c = self.db.execute("SELECT `when` FROM `messages`"
|
||||
" WHERE `channelid`=?"
|
||||
" ORDER BY `when` DESC LIMIT 1", (channelid,))
|
||||
rows = c.fetchall()
|
||||
if not rows or (rows[0]["when"] < old):
|
||||
log.msg("expiring %d" % channelid)
|
||||
self.free_child(channelid)
|
||||
|
||||
def prune(self):
|
||||
for appid in list(self._apps):
|
||||
still_active = self._apps[appid].prune_old_channels()
|
||||
if not still_active:
|
||||
self._apps.pop(appid)
|
||||
|
|
|
@ -14,6 +14,7 @@ class ServerBase:
|
|||
"tcp:%s:interface=127.0.0.1" % transitport,
|
||||
__version__)
|
||||
s.setServiceParent(self.sp)
|
||||
self._relay_server = s.relay
|
||||
self.relayurl = u"http://127.0.0.1:%d/wormhole-relay/" % relayport
|
||||
self.transit = "tcp:127.0.0.1:%d" % transitport
|
||||
d.addCallback(_got_ports)
|
||||
|
|
|
@ -1,12 +1,89 @@
|
|||
from __future__ import print_function
|
||||
import json
|
||||
from twisted.trial import unittest
|
||||
from twisted.internet.defer import gatherResults
|
||||
from twisted.internet.defer import gatherResults, succeed
|
||||
from twisted.internet.threads import deferToThread
|
||||
from ..blocking.transcribe import Wormhole as BlockingWormhole, UsageError
|
||||
from ..blocking.transcribe import (Wormhole as BlockingWormhole, UsageError,
|
||||
ChannelManager)
|
||||
from .common import ServerBase
|
||||
|
||||
APPID = u"appid"
|
||||
|
||||
class Channel(ServerBase, unittest.TestCase):
|
||||
def ignore(self, welcome):
|
||||
pass
|
||||
|
||||
def test_allocate(self):
|
||||
cm = ChannelManager(self.relayurl, APPID, u"side", self.ignore)
|
||||
d = deferToThread(cm.list_channels)
|
||||
def _got_channels(channels):
|
||||
self.failUnlessEqual(channels, [])
|
||||
d.addCallback(_got_channels)
|
||||
d.addCallback(lambda _: deferToThread(cm.allocate))
|
||||
def _allocated(channelid):
|
||||
self.failUnlessEqual(type(channelid), int)
|
||||
self._channelid = channelid
|
||||
d.addCallback(_allocated)
|
||||
d.addCallback(lambda _: deferToThread(cm.connect, self._channelid))
|
||||
def _connected(c):
|
||||
self._channel = c
|
||||
d.addCallback(_connected)
|
||||
d.addCallback(lambda _: deferToThread(self._channel.deallocate))
|
||||
return d
|
||||
|
||||
def test_messages(self):
|
||||
cm1 = ChannelManager(self.relayurl, APPID, u"side1", self.ignore)
|
||||
cm2 = ChannelManager(self.relayurl, APPID, u"side2", self.ignore)
|
||||
c1 = cm1.connect(1)
|
||||
c2 = cm2.connect(1)
|
||||
|
||||
d = succeed(None)
|
||||
d.addCallback(lambda _: deferToThread(c1.send, u"phase1", b"msg1"))
|
||||
d.addCallback(lambda _: deferToThread(c2.get, u"phase1"))
|
||||
d.addCallback(lambda msg: self.failUnlessEqual(msg, b"msg1"))
|
||||
d.addCallback(lambda _: deferToThread(c2.send, u"phase1", b"msg2"))
|
||||
d.addCallback(lambda _: deferToThread(c1.get, u"phase1"))
|
||||
d.addCallback(lambda msg: self.failUnlessEqual(msg, b"msg2"))
|
||||
# it's legal to fetch a phase multiple times, should be idempotent
|
||||
d.addCallback(lambda _: deferToThread(c1.get, u"phase1"))
|
||||
d.addCallback(lambda msg: self.failUnlessEqual(msg, b"msg2"))
|
||||
# deallocating one side is not enough to destroy the channel
|
||||
d.addCallback(lambda _: deferToThread(c2.deallocate))
|
||||
def _not_yet(_):
|
||||
self._relay_server.prune()
|
||||
self.failUnlessEqual(len(self._relay_server._apps), 1)
|
||||
d.addCallback(_not_yet)
|
||||
# but deallocating both will make the messages go away
|
||||
d.addCallback(lambda _: deferToThread(c1.deallocate))
|
||||
def _gone(_):
|
||||
self._relay_server.prune()
|
||||
self.failUnlessEqual(len(self._relay_server._apps), 0)
|
||||
d.addCallback(_gone)
|
||||
|
||||
return d
|
||||
|
||||
def test_appid_independence(self):
|
||||
APPID_A = u"appid_A"
|
||||
APPID_B = u"appid_B"
|
||||
cm1a = ChannelManager(self.relayurl, APPID_A, u"side1", self.ignore)
|
||||
cm2a = ChannelManager(self.relayurl, APPID_A, u"side2", self.ignore)
|
||||
c1a = cm1a.connect(1)
|
||||
c2a = cm2a.connect(1)
|
||||
cm1b = ChannelManager(self.relayurl, APPID_B, u"side1", self.ignore)
|
||||
cm2b = ChannelManager(self.relayurl, APPID_B, u"side2", self.ignore)
|
||||
c1b = cm1b.connect(1)
|
||||
c2b = cm2b.connect(1)
|
||||
|
||||
d = succeed(None)
|
||||
d.addCallback(lambda _: deferToThread(c1a.send, u"phase1", b"msg1a"))
|
||||
d.addCallback(lambda _: deferToThread(c1b.send, u"phase1", b"msg1b"))
|
||||
d.addCallback(lambda _: deferToThread(c2a.get, u"phase1"))
|
||||
d.addCallback(lambda msg: self.failUnlessEqual(msg, b"msg1a"))
|
||||
d.addCallback(lambda _: deferToThread(c2b.get, u"phase1"))
|
||||
d.addCallback(lambda msg: self.failUnlessEqual(msg, b"msg1b"))
|
||||
return d
|
||||
|
||||
|
||||
class Blocking(ServerBase, unittest.TestCase):
|
||||
# we need Twisted to run the server, but we run the sender and receiver
|
||||
# with deferToThread()
|
||||
|
|
|
@ -99,28 +99,27 @@ class Scripts(ServerBase, ScriptsBase, unittest.TestCase):
|
|||
out, err, rc = res
|
||||
out = out.decode("utf-8")
|
||||
err = err.decode("utf-8")
|
||||
self.failUnlessEqual(out,
|
||||
"Sending text message (%d bytes)\n"
|
||||
"On the other computer, please run: "
|
||||
"wormhole receive\n"
|
||||
"Wormhole code is: %s\n\n"
|
||||
"text message sent\n" % (len(message), code)
|
||||
)
|
||||
self.failUnlessEqual(err, "")
|
||||
self.failUnlessEqual(rc, 0)
|
||||
self.maxDiff = None
|
||||
expected = ("Sending text message (%d bytes)\n"
|
||||
"On the other computer, please run: "
|
||||
"wormhole receive\n"
|
||||
"Wormhole code is: %s\n\n"
|
||||
"text message sent\n" % (len(message), code))
|
||||
self.failUnlessEqual( (expected, "", 0),
|
||||
(out, err, rc) )
|
||||
return d2
|
||||
d1.addCallback(_check_sender)
|
||||
def _check_receiver(res):
|
||||
out, err, rc = res
|
||||
out = out.decode("utf-8")
|
||||
err = err.decode("utf-8")
|
||||
self.failUnlessEqual(out, message+"\n")
|
||||
self.failUnlessEqual(err, "")
|
||||
self.failUnlessEqual(rc, 0)
|
||||
self.failUnlessEqual( (message+"\n", "", 0),
|
||||
(out, err, rc) )
|
||||
d1.addCallback(_check_receiver)
|
||||
return d1
|
||||
|
||||
def test_send_file_pre_generated_code(self):
|
||||
self.maxDiff=None
|
||||
code = "1-abc"
|
||||
filename = "testfile"
|
||||
message = "test message"
|
||||
|
@ -150,6 +149,7 @@ class Scripts(ServerBase, ScriptsBase, unittest.TestCase):
|
|||
out, err, rc = res
|
||||
out = out.decode("utf-8")
|
||||
err = err.decode("utf-8")
|
||||
self.failUnlessEqual(err, "")
|
||||
self.failUnlessIn("Sending %d byte file named '%s'\n" %
|
||||
(len(message), filename), out)
|
||||
self.failUnlessIn("On the other computer, please run: "
|
||||
|
@ -159,7 +159,6 @@ class Scripts(ServerBase, ScriptsBase, unittest.TestCase):
|
|||
self.failUnlessIn("File sent.. waiting for confirmation\n"
|
||||
"Confirmation received. Transfer complete.\n",
|
||||
out)
|
||||
self.failUnlessEqual(err, "")
|
||||
self.failUnlessEqual(rc, 0)
|
||||
return d2
|
||||
d1.addCallback(_check_sender)
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
from __future__ import print_function
|
||||
import sys, json
|
||||
import requests
|
||||
from six.moves.urllib_parse import urlencode
|
||||
from twisted.trial import unittest
|
||||
from twisted.internet import reactor, defer
|
||||
from twisted.internet.threads import deferToThread
|
||||
|
@ -55,15 +56,26 @@ def unjson(data):
|
|||
return json.loads(data.decode("utf-8"))
|
||||
|
||||
class API(ServerBase, unittest.TestCase):
|
||||
def get(self, path, is_json=True):
|
||||
url = (self.relayurl+path).encode("ascii")
|
||||
d = getPage(url)
|
||||
if is_json:
|
||||
d.addCallback(unjson)
|
||||
def build_url(self, path, appid, channelid):
|
||||
url = self.relayurl+path
|
||||
queryargs = []
|
||||
if appid:
|
||||
queryargs.append(("appid", appid))
|
||||
if channelid:
|
||||
queryargs.append(("channelid", channelid))
|
||||
if queryargs:
|
||||
url += "?" + urlencode(queryargs)
|
||||
return url
|
||||
|
||||
def get(self, path, appid=None, channelid=None):
|
||||
url = self.build_url(path, appid, channelid)
|
||||
d = getPage(url.encode("ascii"))
|
||||
d.addCallback(unjson)
|
||||
return d
|
||||
|
||||
def post(self, path, data):
|
||||
url = (self.relayurl+path).encode("ascii")
|
||||
d = getPage(url, method=b"POST",
|
||||
url = self.relayurl+path
|
||||
d = getPage(url.encode("ascii"), method=b"POST",
|
||||
postdata=json.dumps(data).encode("utf-8"))
|
||||
d.addCallback(unjson)
|
||||
return d
|
||||
|
@ -73,13 +85,14 @@ class API(ServerBase, unittest.TestCase):
|
|||
self.failUnlessEqual(data["welcome"], {"current_version": __version__})
|
||||
|
||||
def test_allocate_1(self):
|
||||
d = self.get("list")
|
||||
d = self.get("list", "app1")
|
||||
def _check_list_1(data):
|
||||
self.check_welcome(data)
|
||||
self.failUnlessEqual(data["channelids"], [])
|
||||
d.addCallback(_check_list_1)
|
||||
|
||||
d.addCallback(lambda _: self.post("allocate", {"side": "abc"}))
|
||||
d.addCallback(lambda _: self.post("allocate", {"appid": "app1",
|
||||
"side": "abc"}))
|
||||
def _allocated(data):
|
||||
self.failUnlessEqual(set(data.keys()),
|
||||
set(["welcome", "channelid"]))
|
||||
|
@ -87,18 +100,20 @@ class API(ServerBase, unittest.TestCase):
|
|||
self.cid = data["channelid"]
|
||||
d.addCallback(_allocated)
|
||||
|
||||
d.addCallback(lambda _: self.get("list"))
|
||||
d.addCallback(lambda _: self.get("list", "app1"))
|
||||
def _check_list_2(data):
|
||||
self.failUnlessEqual(data["channelids"], [self.cid])
|
||||
d.addCallback(_check_list_2)
|
||||
|
||||
d.addCallback(lambda _: self.post("%d/deallocate" % self.cid,
|
||||
{"side": "abc"}))
|
||||
d.addCallback(lambda _: self.post("deallocate",
|
||||
{"appid": "app1",
|
||||
"channelid": str(self.cid),
|
||||
"side": "abc"}))
|
||||
def _check_deallocate(res):
|
||||
self.failUnlessEqual(res["status"], "deleted")
|
||||
d.addCallback(_check_deallocate)
|
||||
|
||||
d.addCallback(lambda _: self.get("list"))
|
||||
d.addCallback(lambda _: self.get("list", "app1"))
|
||||
def _check_list_3(data):
|
||||
self.failUnlessEqual(data["channelids"], [])
|
||||
d.addCallback(_check_list_3)
|
||||
|
@ -106,45 +121,57 @@ class API(ServerBase, unittest.TestCase):
|
|||
return d
|
||||
|
||||
def test_allocate_2(self):
|
||||
d = self.post("allocate", {"side": "abc"})
|
||||
d = self.post("allocate", {"appid": "app1", "side": "abc"})
|
||||
def _allocated(data):
|
||||
self.cid = data["channelid"]
|
||||
d.addCallback(_allocated)
|
||||
|
||||
# second caller increases the number of known sides to 2
|
||||
d.addCallback(lambda _: self.post("%d" % self.cid,
|
||||
{"side": "def",
|
||||
d.addCallback(lambda _: self.post("add",
|
||||
{"appid": "app1",
|
||||
"channelid": str(self.cid),
|
||||
"side": "def",
|
||||
"phase": "1",
|
||||
"body": ""}))
|
||||
|
||||
d.addCallback(lambda _: self.get("list"))
|
||||
d.addCallback(lambda _: self.get("list", "app1"))
|
||||
d.addCallback(lambda data:
|
||||
self.failUnlessEqual(data["channelids"], [self.cid]))
|
||||
|
||||
d.addCallback(lambda _: self.post("%d/deallocate" % self.cid,
|
||||
{"side": "abc"}))
|
||||
d.addCallback(lambda _: self.post("deallocate",
|
||||
{"appid": "app1",
|
||||
"channelid": str(self.cid),
|
||||
"side": "abc"}))
|
||||
d.addCallback(lambda res:
|
||||
self.failUnlessEqual(res["status"], "waiting"))
|
||||
|
||||
d.addCallback(lambda _: self.post("%d/deallocate" % self.cid,
|
||||
{"side": "NOT"}))
|
||||
d.addCallback(lambda _: self.post("deallocate",
|
||||
{"appid": "app1",
|
||||
"channelid": str(self.cid),
|
||||
"side": "NOT"}))
|
||||
d.addCallback(lambda res:
|
||||
self.failUnlessEqual(res["status"], "waiting"))
|
||||
|
||||
d.addCallback(lambda _: self.post("%d/deallocate" % self.cid,
|
||||
{"side": "def"}))
|
||||
d.addCallback(lambda _: self.post("deallocate",
|
||||
{"appid": "app1",
|
||||
"channelid": str(self.cid),
|
||||
"side": "def"}))
|
||||
d.addCallback(lambda res:
|
||||
self.failUnlessEqual(res["status"], "deleted"))
|
||||
|
||||
d.addCallback(lambda _: self.get("list"))
|
||||
d.addCallback(lambda _: self.get("list", "app1"))
|
||||
d.addCallback(lambda data:
|
||||
self.failUnlessEqual(data["channelids"], []))
|
||||
|
||||
return d
|
||||
|
||||
def add_message(self, message, side="abc", phase="1"):
|
||||
return self.post(str(self.cid), {"side": side, "phase": phase,
|
||||
"body": message})
|
||||
return self.post("add",
|
||||
{"appid": "app1",
|
||||
"channelid": str(self.cid),
|
||||
"side": side,
|
||||
"phase": phase,
|
||||
"body": message})
|
||||
|
||||
def parse_messages(self, messages):
|
||||
out = set()
|
||||
|
@ -164,7 +191,7 @@ class API(ServerBase, unittest.TestCase):
|
|||
self.failUnlessIn(d, two)
|
||||
|
||||
def test_messages(self):
|
||||
d = self.post("allocate", {"side": "abc"})
|
||||
d = self.post("allocate", {"appid": "app1", "side": "abc"})
|
||||
def _allocated(data):
|
||||
self.cid = data["channelid"]
|
||||
d.addCallback(_allocated)
|
||||
|
@ -175,6 +202,8 @@ class API(ServerBase, unittest.TestCase):
|
|||
self.failUnlessEqual(data["messages"],
|
||||
[{"phase": "1", "body": "msg1A"}])
|
||||
d.addCallback(_check1)
|
||||
d.addCallback(lambda _: self.get("get", "app1", str(self.cid)))
|
||||
d.addCallback(_check1)
|
||||
d.addCallback(lambda _: self.add_message("msg1B", side="def"))
|
||||
def _check2(data):
|
||||
self.check_welcome(data)
|
||||
|
@ -182,6 +211,8 @@ class API(ServerBase, unittest.TestCase):
|
|||
set([("1", "msg1A"),
|
||||
("1", "msg1B")]))
|
||||
d.addCallback(_check2)
|
||||
d.addCallback(lambda _: self.get("get", "app1", str(self.cid)))
|
||||
d.addCallback(_check2)
|
||||
|
||||
# adding a duplicate message is not an error, is ignored by clients
|
||||
d.addCallback(lambda _: self.add_message("msg1B", side="def"))
|
||||
|
@ -191,6 +222,8 @@ class API(ServerBase, unittest.TestCase):
|
|||
set([("1", "msg1A"),
|
||||
("1", "msg1B")]))
|
||||
d.addCallback(_check3)
|
||||
d.addCallback(lambda _: self.get("get", "app1", str(self.cid)))
|
||||
d.addCallback(_check3)
|
||||
|
||||
d.addCallback(lambda _: self.add_message("msg2A", side="abc",
|
||||
phase="2"))
|
||||
|
@ -202,6 +235,8 @@ class API(ServerBase, unittest.TestCase):
|
|||
("2", "msg2A"),
|
||||
]))
|
||||
d.addCallback(_check4)
|
||||
d.addCallback(lambda _: self.get("get", "app1", str(self.cid)))
|
||||
d.addCallback(_check4)
|
||||
|
||||
return d
|
||||
|
||||
|
@ -209,10 +244,10 @@ class API(ServerBase, unittest.TestCase):
|
|||
if sys.version_info[0] >= 3:
|
||||
raise unittest.SkipTest("twisted vs py3")
|
||||
|
||||
d = self.post("allocate", {"side": "abc"})
|
||||
d = self.post("allocate", {"appid": "app1", "side": "abc"})
|
||||
def _allocated(data):
|
||||
self.cid = data["channelid"]
|
||||
url = self.relayurl+str(self.cid)
|
||||
url = self.build_url("get", "app1", self.cid)
|
||||
self.o = OneEventAtATime(url, parser=json.loads)
|
||||
return self.o.wait_for_connection()
|
||||
d.addCallback(_allocated)
|
||||
|
|
|
@ -1,11 +1,86 @@
|
|||
from __future__ import print_function
|
||||
import sys, json
|
||||
from twisted.trial import unittest
|
||||
from twisted.internet.defer import gatherResults
|
||||
from ..twisted.transcribe import Wormhole, UsageError
|
||||
from twisted.internet.defer import gatherResults, succeed
|
||||
from ..twisted.transcribe import Wormhole, UsageError, ChannelManager
|
||||
from .common import ServerBase
|
||||
|
||||
APPID = u"appid"
|
||||
|
||||
class Channel(ServerBase, unittest.TestCase):
|
||||
def ignore(self, welcome):
|
||||
pass
|
||||
|
||||
def test_allocate(self):
|
||||
cm = ChannelManager(self.relayurl, APPID, u"side", self.ignore)
|
||||
d = cm.list_channels()
|
||||
def _got_channels(channels):
|
||||
self.failUnlessEqual(channels, [])
|
||||
d.addCallback(_got_channels)
|
||||
d.addCallback(lambda _: cm.allocate())
|
||||
def _allocated(channelid):
|
||||
self.failUnlessEqual(type(channelid), int)
|
||||
self._channelid = channelid
|
||||
d.addCallback(_allocated)
|
||||
d.addCallback(lambda _: cm.connect(self._channelid))
|
||||
def _connected(c):
|
||||
self._channel = c
|
||||
d.addCallback(_connected)
|
||||
d.addCallback(lambda _: self._channel.deallocate())
|
||||
return d
|
||||
|
||||
def test_messages(self):
|
||||
cm1 = ChannelManager(self.relayurl, APPID, u"side1", self.ignore)
|
||||
cm2 = ChannelManager(self.relayurl, APPID, u"side2", self.ignore)
|
||||
c1 = cm1.connect(1)
|
||||
c2 = cm2.connect(1)
|
||||
|
||||
d = succeed(None)
|
||||
d.addCallback(lambda _: c1.send(u"phase1", b"msg1"))
|
||||
d.addCallback(lambda _: c2.get(u"phase1"))
|
||||
d.addCallback(lambda msg: self.failUnlessEqual(msg, b"msg1"))
|
||||
d.addCallback(lambda _: c2.send(u"phase1", b"msg2"))
|
||||
d.addCallback(lambda _: c1.get(u"phase1"))
|
||||
d.addCallback(lambda msg: self.failUnlessEqual(msg, b"msg2"))
|
||||
# it's legal to fetch a phase multiple times, should be idempotent
|
||||
d.addCallback(lambda _: c1.get(u"phase1"))
|
||||
d.addCallback(lambda msg: self.failUnlessEqual(msg, b"msg2"))
|
||||
# deallocating one side is not enough to destroy the channel
|
||||
d.addCallback(lambda _: c2.deallocate())
|
||||
def _not_yet(_):
|
||||
self._relay_server.prune()
|
||||
self.failUnlessEqual(len(self._relay_server._apps), 1)
|
||||
d.addCallback(_not_yet)
|
||||
# but deallocating both will make the messages go away
|
||||
d.addCallback(lambda _: c1.deallocate())
|
||||
def _gone(_):
|
||||
self._relay_server.prune()
|
||||
self.failUnlessEqual(len(self._relay_server._apps), 0)
|
||||
d.addCallback(_gone)
|
||||
|
||||
return d
|
||||
|
||||
def test_appid_independence(self):
|
||||
APPID_A = u"appid_A"
|
||||
APPID_B = u"appid_B"
|
||||
cm1a = ChannelManager(self.relayurl, APPID_A, u"side1", self.ignore)
|
||||
cm2a = ChannelManager(self.relayurl, APPID_A, u"side2", self.ignore)
|
||||
c1a = cm1a.connect(1)
|
||||
c2a = cm2a.connect(1)
|
||||
cm1b = ChannelManager(self.relayurl, APPID_B, u"side1", self.ignore)
|
||||
cm2b = ChannelManager(self.relayurl, APPID_B, u"side2", self.ignore)
|
||||
c1b = cm1b.connect(1)
|
||||
c2b = cm2b.connect(1)
|
||||
|
||||
d = succeed(None)
|
||||
d.addCallback(lambda _: c1a.send(u"phase1", b"msg1a"))
|
||||
d.addCallback(lambda _: c1b.send(u"phase1", b"msg1b"))
|
||||
d.addCallback(lambda _: c2a.get(u"phase1"))
|
||||
d.addCallback(lambda msg: self.failUnlessEqual(msg, b"msg1a"))
|
||||
d.addCallback(lambda _: c2b.get(u"phase1"))
|
||||
d.addCallback(lambda msg: self.failUnlessEqual(msg, b"msg1b"))
|
||||
return d
|
||||
|
||||
class Basic(ServerBase, unittest.TestCase):
|
||||
|
||||
def doBoth(self, d1, d2):
|
||||
|
@ -154,6 +229,7 @@ class Basic(ServerBase, unittest.TestCase):
|
|||
return d
|
||||
|
||||
if sys.version_info[0] >= 3:
|
||||
Channel.skip = "twisted is not yet sufficiently ported to py3"
|
||||
Basic.skip = "twisted is not yet sufficiently ported to py3"
|
||||
# as of 15.4.0, Twisted is still missing:
|
||||
# * web.client.Agent (for all non-EventSource POSTs in transcribe.py)
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
from __future__ import print_function
|
||||
import os, sys, json, re, unicodedata
|
||||
from six.moves.urllib_parse import urlencode
|
||||
from binascii import hexlify, unhexlify
|
||||
from zope.interface import implementer
|
||||
from twisted.internet import reactor, defer
|
||||
|
@ -50,10 +51,24 @@ def post_json(agent, url, request_body):
|
|||
d.addCallback(lambda data: json.loads(data))
|
||||
return d
|
||||
|
||||
def get_json(agent, url):
|
||||
# GET from a URL, parsing the response as JSON
|
||||
d = agent.request("GET", url.encode("utf-8"))
|
||||
def _check_error(resp):
|
||||
if resp.code != 200:
|
||||
raise web_error.Error(resp.code, resp.phrase)
|
||||
return resp
|
||||
d.addCallback(_check_error)
|
||||
d.addCallback(web_client.readBody)
|
||||
d.addCallback(lambda data: json.loads(data))
|
||||
return d
|
||||
|
||||
class Channel:
|
||||
def __init__(self, relay_url, channelid, side, handle_welcome,
|
||||
def __init__(self, relay_url, appid, channelid, side, handle_welcome,
|
||||
agent):
|
||||
self._channel_url = u"%s%d" % (relay_url, channelid)
|
||||
self._relay_url = relay_url
|
||||
self._appid = appid
|
||||
self._channelid = channelid
|
||||
self._side = side
|
||||
self._handle_welcome = handle_welcome
|
||||
self._agent = agent
|
||||
|
@ -78,10 +93,12 @@ class Channel:
|
|||
if not isinstance(phase, type(u"")): raise UsageError(type(phase))
|
||||
if not isinstance(msg, type(b"")): raise UsageError(type(msg))
|
||||
self._sent_messages.add( (phase,msg) )
|
||||
payload = {"side": self._side,
|
||||
payload = {"appid": self._appid,
|
||||
"channelid": self._channelid,
|
||||
"side": self._side,
|
||||
"phase": phase,
|
||||
"body": hexlify(msg).decode("ascii")}
|
||||
d = post_json(self._agent, self._channel_url, payload)
|
||||
d = post_json(self._agent, self._relay_url+"add", payload)
|
||||
d.addCallback(lambda resp: self._add_inbound_messages(resp["messages"]))
|
||||
return d
|
||||
|
||||
|
@ -104,7 +121,10 @@ class Channel:
|
|||
msgs.append(body)
|
||||
d.callback(None)
|
||||
# TODO: use agent=self._agent
|
||||
es = ReconnectingEventSource(self._channel_url, _handle)
|
||||
queryargs = urlencode([("appid", self._appid),
|
||||
("channelid", self._channelid)])
|
||||
es = ReconnectingEventSource(self._relay_url+"get?%s" % queryargs,
|
||||
_handle)
|
||||
es.startService() # TODO: .setServiceParent(self)
|
||||
es.activate()
|
||||
d.addCallback(lambda _: es.deactivate())
|
||||
|
@ -114,22 +134,26 @@ class Channel:
|
|||
|
||||
def deallocate(self):
|
||||
# only try once, no retries
|
||||
d = post_json(self._agent, self._channel_url+"/deallocate",
|
||||
{"side": self._side})
|
||||
d = post_json(self._agent, self._relay_url+"deallocate",
|
||||
{"appid": self._appid,
|
||||
"channelid": self._channelid,
|
||||
"side": self._side})
|
||||
d.addBoth(lambda _: None) # ignore POST failure
|
||||
return d
|
||||
|
||||
class ChannelManager:
|
||||
def __init__(self, relay_url, side, handle_welcome):
|
||||
assert isinstance(relay_url, type(u""))
|
||||
self._relay_url = relay_url
|
||||
def __init__(self, relay, appid, side, handle_welcome):
|
||||
assert isinstance(relay, type(u""))
|
||||
self._relay = relay
|
||||
self._appid = appid
|
||||
self._side = side
|
||||
self._handle_welcome = handle_welcome
|
||||
self._agent = web_client.Agent(reactor)
|
||||
|
||||
def allocate(self):
|
||||
url = self._relay_url + "allocate"
|
||||
d = post_json(self._agent, url, {"side": self._side})
|
||||
url = self._relay + "allocate"
|
||||
d = post_json(self._agent, url, {"appid": self._appid,
|
||||
"side": self._side})
|
||||
def _got_channel(data):
|
||||
if "welcome" in data:
|
||||
self._handle_welcome(data["welcome"])
|
||||
|
@ -138,10 +162,14 @@ class ChannelManager:
|
|||
return d
|
||||
|
||||
def list_channels(self):
|
||||
raise NotImplementedError
|
||||
queryargs = urlencode([("appid", self._appid)])
|
||||
url = self._relay + u"list?%s" % queryargs
|
||||
d = get_json(self._agent, url)
|
||||
d.addCallback(lambda r: r["channelids"])
|
||||
return d
|
||||
|
||||
def connect(self, channelid):
|
||||
return Channel(self._relay_url, channelid, self._side,
|
||||
return Channel(self._relay, self._appid, channelid, self._side,
|
||||
self._handle_welcome, self._agent)
|
||||
|
||||
class Wormhole:
|
||||
|
@ -163,17 +191,16 @@ class Wormhole:
|
|||
|
||||
def _set_side(self, side):
|
||||
self._side = side
|
||||
self._channel_manager = ChannelManager(self._relay_url, self._side,
|
||||
self.handle_welcome)
|
||||
self._channel_manager = ChannelManager(self._relay_url, self._appid,
|
||||
self._side, self.handle_welcome)
|
||||
|
||||
def handle_welcome(self, welcome):
|
||||
if ("motd" in welcome and
|
||||
not self.motd_displayed):
|
||||
motd_lines = welcome["motd"].splitlines()
|
||||
motd_formatted = "\n ".join(motd_lines)
|
||||
print("Server (at %s) says:\n %s" % (self._relay_url,
|
||||
motd_formatted),
|
||||
file=sys.stderr)
|
||||
print("Server (at %s) says:\n %s" %
|
||||
(self._relay_url, motd_formatted), file=sys.stderr)
|
||||
self.motd_displayed = True
|
||||
|
||||
# Only warn if we're running a release version (e.g. 0.0.6, not
|
||||
|
|
Loading…
Reference in New Issue
Block a user