split rendezvous server into web, nonweb files

Also rename files/classes from "relay" to "rendezvous".
This commit is contained in:
Brian Warner 2016-04-17 14:41:12 -07:00
parent 7321a2391b
commit c3bd9e936e
5 changed files with 263 additions and 235 deletions

View File

@ -2,7 +2,6 @@ from __future__ import print_function
import json, time, random import json, time, random
from twisted.python import log from twisted.python import log
from twisted.application import service, internet from twisted.application import service, internet
from twisted.web import server, resource
SECONDS = 1.0 SECONDS = 1.0
MINUTE = 60*SECONDS MINUTE = 60*SECONDS
@ -16,200 +15,6 @@ EXPIRATION_CHECK_PERIOD = 2*HOUR
ALLOCATE = u"_allocate" ALLOCATE = u"_allocate"
DEALLOCATE = u"_deallocate" DEALLOCATE = u"_deallocate"
def json_response(request, data):
request.setHeader(b"content-type", b"application/json; charset=utf-8")
return (json.dumps(data)+"\n").encode("utf-8")
class EventsProtocol:
def __init__(self, request):
self.request = request
def sendComment(self, comment):
# this is ignored by clients, but can keep the connection open in the
# face of firewall/NAT timeouts. It also helps unit tests, since
# apparently twisted.web.client.Agent doesn't consider the connection
# to be established until it sees the first byte of the reponse body.
self.request.write(b": " + comment + b"\n\n")
def sendEvent(self, data, name=None, id=None, retry=None):
if name:
self.request.write(b"event: " + name.encode("utf-8") + b"\n")
# e.g. if name=foo, then the client web page should do:
# (new EventSource(url)).addEventListener("foo", handlerfunc)
# Note that this basically defaults to "message".
if id:
self.request.write(b"id: " + id.encode("utf-8") + b"\n")
if retry:
self.request.write(b"retry: " + retry + b"\n") # milliseconds
for line in data.splitlines():
self.request.write(b"data: " + line.encode("utf-8") + b"\n")
self.request.write(b"\n")
def stop(self):
self.request.finish()
# note: no versions of IE (including the current IE11) support EventSource
# relay URLs are as follows: (MESSAGES=[{phase:,body:}..])
# ("-" indicates a deprecated URL)
# GET /list?appid= -> {channelids: [INT..]}
# POST /allocate {appid:,side:} -> {channelid: INT}
# these return all messages (base64) for appid=/channelid= :
# POST /add {appid:,channelid:,side:,phase:,body:} -> {messages: MESSAGES}
# GET /get?appid=&channelid= (no-eventsource) -> {messages: MESSAGES}
#- GET /get?appid=&channelid= (eventsource) -> {phase:, body:}..
# GET /watch?appid=&channelid= (eventsource) -> {phase:, body:}..
# POST /deallocate {appid:,channelid:,side:} -> {status: waiting | deleted}
# all JSON responses include a "welcome:{..}" key
class RelayResource(resource.Resource):
def __init__(self, relay, welcome, log_requests):
resource.Resource.__init__(self)
self._relay = relay
self._welcome = welcome
self._log_requests = log_requests
class ChannelLister(RelayResource):
def render_GET(self, request):
if b"appid" not in request.args:
e = NeedToUpgradeErrorResource(self._welcome)
return e.get_message()
appid = request.args[b"appid"][0].decode("utf-8")
#print("LIST", appid)
app = self._relay.get_app(appid)
allocated = app.get_allocated()
data = {"welcome": self._welcome, "channelids": sorted(allocated),
"sent": time.time()}
return json_response(request, data)
class Allocator(RelayResource):
def render_POST(self, request):
content = request.content.read()
data = json.loads(content.decode("utf-8"))
appid = data["appid"]
side = data["side"]
if not isinstance(side, type(u"")):
raise TypeError("side must be string, not '%s'" % type(side))
#print("ALLOCATE", appid, side)
app = self._relay.get_app(appid)
channelid = app.find_available_channelid()
app.allocate_channel(channelid, side)
if self._log_requests:
log.msg("allocated #%d, now have %d DB channels" %
(channelid, len(app.get_allocated())))
response = {"welcome": self._welcome, "channelid": channelid,
"sent": time.time()}
return json_response(request, response)
def getChild(self, path, req):
# wormhole-0.4.0 "send" started with "POST /allocate/SIDE".
# wormhole-0.5.0 changed that to "POST /allocate". We catch the old
# URL here to deliver a nicer error message (with upgrade
# instructions) than an ugly 404.
return NeedToUpgradeErrorResource(self._welcome)
class NeedToUpgradeErrorResource(resource.Resource):
def __init__(self, welcome):
resource.Resource.__init__(self)
w = welcome.copy()
w["error"] = "Sorry, you must upgrade your client to use this server."
message = {"welcome": w}
self._message = (json.dumps(message)+"\n").encode("utf-8")
def get_message(self):
return self._message
def render_POST(self, request):
return self._message
def render_GET(self, request):
return self._message
def getChild(self, path, req):
return self
class Adder(RelayResource):
def render_POST(self, request):
#content = json.load(request.content, encoding="utf-8")
content = request.content.read()
data = json.loads(content.decode("utf-8"))
appid = data["appid"]
channelid = int(data["channelid"])
side = data["side"]
phase = data["phase"]
if not isinstance(phase, type(u"")):
raise TypeError("phase must be string, not %s" % type(phase))
body = data["body"]
#print("ADD", appid, channelid, side, phase, body)
app = self._relay.get_app(appid)
channel = app.get_channel(channelid)
messages = channel.add_message(side, phase, body)
response = {"welcome": self._welcome, "messages": messages,
"sent": time.time()}
return json_response(request, response)
class GetterOrWatcher(RelayResource):
def render_GET(self, request):
appid = request.args[b"appid"][0].decode("utf-8")
channelid = int(request.args[b"channelid"][0])
#print("GET", appid, channelid)
app = self._relay.get_app(appid)
channel = app.get_channel(channelid)
if b"text/event-stream" not in (request.getHeader(b"accept") or b""):
messages = channel.get_messages()
response = {"welcome": self._welcome, "messages": messages,
"sent": time.time()}
return json_response(request, response)
request.setHeader(b"content-type", b"text/event-stream; charset=utf-8")
ep = EventsProtocol(request)
ep.sendEvent(json.dumps(self._welcome), name="welcome")
old_events = channel.add_listener(ep)
request.notifyFinish().addErrback(lambda f:
channel.remove_listener(ep))
for old_event in old_events:
ep.sendEvent(old_event)
return server.NOT_DONE_YET
class Watcher(RelayResource):
def render_GET(self, request):
appid = request.args[b"appid"][0].decode("utf-8")
channelid = int(request.args[b"channelid"][0])
app = self._relay.get_app(appid)
channel = app.get_channel(channelid)
if b"text/event-stream" not in (request.getHeader(b"accept") or b""):
raise TypeError("/watch is for EventSource only")
request.setHeader(b"content-type", b"text/event-stream; charset=utf-8")
ep = EventsProtocol(request)
ep.sendEvent(json.dumps(self._welcome), name="welcome")
old_events = channel.add_listener(ep)
request.notifyFinish().addErrback(lambda f:
channel.remove_listener(ep))
for old_event in old_events:
ep.sendEvent(old_event)
return server.NOT_DONE_YET
class Deallocator(RelayResource):
def render_POST(self, request):
content = request.content.read()
data = json.loads(content.decode("utf-8"))
appid = data["appid"]
channelid = int(data["channelid"])
side = data["side"]
if not isinstance(side, type(u"")):
raise TypeError("side must be string, not '%s'" % type(side))
mood = data.get("mood")
#print("DEALLOCATE", appid, channelid, side)
app = self._relay.get_app(appid)
channel = app.get_channel(channelid)
deleted = channel.deallocate(side, mood)
response = {"status": "waiting", "sent": time.time()}
if deleted:
response = {"status": "deleted", "sent": time.time()}
return json_response(request, response)
class Channel: class Channel:
def __init__(self, app, db, welcome, blur_usage, log_requests, def __init__(self, app, db, welcome, blur_usage, log_requests,
appid, channelid): appid, channelid):
@ -469,9 +274,8 @@ class AppNamespace:
for channel in self._channels.values(): for channel in self._channels.values():
channel._shutdown() channel._shutdown()
class Relay(resource.Resource, service.MultiService): class Rendezvous(service.MultiService):
def __init__(self, db, welcome, blur_usage): def __init__(self, db, welcome, blur_usage):
resource.Resource.__init__(self)
service.MultiService.__init__(self) service.MultiService.__init__(self)
self._db = db self._db = db
self._welcome = welcome self._welcome = welcome
@ -481,21 +285,11 @@ class Relay(resource.Resource, service.MultiService):
self._apps = {} self._apps = {}
t = internet.TimerService(EXPIRATION_CHECK_PERIOD, self.prune) t = internet.TimerService(EXPIRATION_CHECK_PERIOD, self.prune)
t.setServiceParent(self) t.setServiceParent(self)
self.putChild(b"list", ChannelLister(self, welcome, log_requests))
self.putChild(b"allocate", Allocator(self, welcome, log_requests))
self.putChild(b"add", Adder(self, welcome, log_requests))
self.putChild(b"get", GetterOrWatcher(self, welcome, log_requests))
self.putChild(b"watch", Watcher(self, welcome, log_requests))
self.putChild(b"deallocate", Deallocator(self, welcome, log_requests))
def getChild(self, path, req): def get_welcome(self):
# 0.4.0 used "POST /CID/SIDE/post/MSGNUM" return self._welcome
# 0.5.0 replaced it with "POST /add (json body)" def get_log_requests(self):
# give a nicer error message to old clients return self._log_requests
if (len(req.postpath) >= 2
and req.postpath[1] in (b"post", b"poll", b"deallocate")):
return NeedToUpgradeErrorResource(self._welcome)
return resource.NoResource("No such child resource.")
def get_app(self, appid): def get_app(self, appid):
assert isinstance(appid, type(u"")) assert isinstance(appid, type(u""))

View File

@ -0,0 +1,216 @@
import json, time
from twisted.web import server, resource
from twisted.python import log
def json_response(request, data):
request.setHeader(b"content-type", b"application/json; charset=utf-8")
return (json.dumps(data)+"\n").encode("utf-8")
class EventsProtocol:
def __init__(self, request):
self.request = request
def sendComment(self, comment):
# this is ignored by clients, but can keep the connection open in the
# face of firewall/NAT timeouts. It also helps unit tests, since
# apparently twisted.web.client.Agent doesn't consider the connection
# to be established until it sees the first byte of the reponse body.
self.request.write(b": " + comment + b"\n\n")
def sendEvent(self, data, name=None, id=None, retry=None):
if name:
self.request.write(b"event: " + name.encode("utf-8") + b"\n")
# e.g. if name=foo, then the client web page should do:
# (new EventSource(url)).addEventListener("foo", handlerfunc)
# Note that this basically defaults to "message".
if id:
self.request.write(b"id: " + id.encode("utf-8") + b"\n")
if retry:
self.request.write(b"retry: " + retry + b"\n") # milliseconds
for line in data.splitlines():
self.request.write(b"data: " + line.encode("utf-8") + b"\n")
self.request.write(b"\n")
def stop(self):
self.request.finish()
# note: no versions of IE (including the current IE11) support EventSource
# relay URLs are as follows: (MESSAGES=[{phase:,body:}..])
# ("-" indicates a deprecated URL)
# GET /list?appid= -> {channelids: [INT..]}
# POST /allocate {appid:,side:} -> {channelid: INT}
# these return all messages (base64) for appid=/channelid= :
# POST /add {appid:,channelid:,side:,phase:,body:} -> {messages: MESSAGES}
# GET /get?appid=&channelid= (no-eventsource) -> {messages: MESSAGES}
#- GET /get?appid=&channelid= (eventsource) -> {phase:, body:}..
# GET /watch?appid=&channelid= (eventsource) -> {phase:, body:}..
# POST /deallocate {appid:,channelid:,side:} -> {status: waiting | deleted}
# all JSON responses include a "welcome:{..}" key
class RelayResource(resource.Resource):
def __init__(self, rendezvous):
resource.Resource.__init__(self)
self._rendezvous = rendezvous
self._welcome = rendezvous.get_welcome()
class ChannelLister(RelayResource):
def render_GET(self, request):
if b"appid" not in request.args:
e = NeedToUpgradeErrorResource(self._welcome)
return e.get_message()
appid = request.args[b"appid"][0].decode("utf-8")
#print("LIST", appid)
app = self._rendezvous.get_app(appid)
allocated = app.get_allocated()
data = {"welcome": self._welcome, "channelids": sorted(allocated),
"sent": time.time()}
return json_response(request, data)
class Allocator(RelayResource):
def render_POST(self, request):
content = request.content.read()
data = json.loads(content.decode("utf-8"))
appid = data["appid"]
side = data["side"]
if not isinstance(side, type(u"")):
raise TypeError("side must be string, not '%s'" % type(side))
#print("ALLOCATE", appid, side)
app = self._rendezvous.get_app(appid)
channelid = app.find_available_channelid()
app.allocate_channel(channelid, side)
if self._rendezvous.get_log_requests():
log.msg("allocated #%d, now have %d DB channels" %
(channelid, len(app.get_allocated())))
response = {"welcome": self._welcome, "channelid": channelid,
"sent": time.time()}
return json_response(request, response)
def getChild(self, path, req):
# wormhole-0.4.0 "send" started with "POST /allocate/SIDE".
# wormhole-0.5.0 changed that to "POST /allocate". We catch the old
# URL here to deliver a nicer error message (with upgrade
# instructions) than an ugly 404.
return NeedToUpgradeErrorResource(self._welcome)
class NeedToUpgradeErrorResource(resource.Resource):
def __init__(self, welcome):
resource.Resource.__init__(self)
w = welcome.copy()
w["error"] = "Sorry, you must upgrade your client to use this server."
message = {"welcome": w}
self._message = (json.dumps(message)+"\n").encode("utf-8")
def get_message(self):
return self._message
def render_POST(self, request):
return self._message
def render_GET(self, request):
return self._message
def getChild(self, path, req):
return self
class Adder(RelayResource):
def render_POST(self, request):
#content = json.load(request.content, encoding="utf-8")
content = request.content.read()
data = json.loads(content.decode("utf-8"))
appid = data["appid"]
channelid = int(data["channelid"])
side = data["side"]
phase = data["phase"]
if not isinstance(phase, type(u"")):
raise TypeError("phase must be string, not %s" % type(phase))
body = data["body"]
#print("ADD", appid, channelid, side, phase, body)
app = self._rendezvous.get_app(appid)
channel = app.get_channel(channelid)
messages = channel.add_message(side, phase, body)
response = {"welcome": self._welcome, "messages": messages,
"sent": time.time()}
return json_response(request, response)
class GetterOrWatcher(RelayResource):
def render_GET(self, request):
appid = request.args[b"appid"][0].decode("utf-8")
channelid = int(request.args[b"channelid"][0])
#print("GET", appid, channelid)
app = self._rendezvous.get_app(appid)
channel = app.get_channel(channelid)
if b"text/event-stream" not in (request.getHeader(b"accept") or b""):
messages = channel.get_messages()
response = {"welcome": self._welcome, "messages": messages,
"sent": time.time()}
return json_response(request, response)
request.setHeader(b"content-type", b"text/event-stream; charset=utf-8")
ep = EventsProtocol(request)
ep.sendEvent(json.dumps(self._welcome), name="welcome")
old_events = channel.add_listener(ep)
request.notifyFinish().addErrback(lambda f:
channel.remove_listener(ep))
for old_event in old_events:
ep.sendEvent(old_event)
return server.NOT_DONE_YET
class Watcher(RelayResource):
def render_GET(self, request):
appid = request.args[b"appid"][0].decode("utf-8")
channelid = int(request.args[b"channelid"][0])
app = self._rendezvous.get_app(appid)
channel = app.get_channel(channelid)
if b"text/event-stream" not in (request.getHeader(b"accept") or b""):
raise TypeError("/watch is for EventSource only")
request.setHeader(b"content-type", b"text/event-stream; charset=utf-8")
ep = EventsProtocol(request)
ep.sendEvent(json.dumps(self._welcome), name="welcome")
old_events = channel.add_listener(ep)
request.notifyFinish().addErrback(lambda f:
channel.remove_listener(ep))
for old_event in old_events:
ep.sendEvent(old_event)
return server.NOT_DONE_YET
class Deallocator(RelayResource):
def render_POST(self, request):
content = request.content.read()
data = json.loads(content.decode("utf-8"))
appid = data["appid"]
channelid = int(data["channelid"])
side = data["side"]
if not isinstance(side, type(u"")):
raise TypeError("side must be string, not '%s'" % type(side))
mood = data.get("mood")
#print("DEALLOCATE", appid, channelid, side)
app = self._rendezvous.get_app(appid)
channel = app.get_channel(channelid)
deleted = channel.deallocate(side, mood)
response = {"status": "waiting", "sent": time.time()}
if deleted:
response = {"status": "deleted", "sent": time.time()}
return json_response(request, response)
class WebRendezvous(resource.Resource):
def __init__(self, rendezvous):
resource.Resource.__init__(self)
self._rendezvous = rendezvous
self.putChild(b"list", ChannelLister(rendezvous))
self.putChild(b"allocate", Allocator(rendezvous))
self.putChild(b"add", Adder(rendezvous))
self.putChild(b"get", GetterOrWatcher(rendezvous))
self.putChild(b"watch", Watcher(rendezvous))
self.putChild(b"deallocate", Deallocator(rendezvous))
def getChild(self, path, req):
# 0.4.0 used "POST /CID/SIDE/post/MSGNUM"
# 0.5.0 replaced it with "POST /add (json body)"
# give a nicer error message to old clients
if (len(req.postpath) >= 2
and req.postpath[1] in (b"post", b"poll", b"deallocate")):
welcome = self._rendezvous.get_welcome()
return NeedToUpgradeErrorResource(welcome)
return resource.NoResource("No such child resource.")

View File

@ -6,7 +6,8 @@ from twisted.web import server, static, resource
from .endpoint_service import ServerEndpointService from .endpoint_service import ServerEndpointService
from wormhole import __version__ from wormhole import __version__
from .database import get_db from .database import get_db
from .relay_server import Relay from .rendezvous import Rendezvous
from .rendezvous_web import WebRendezvous
from .transit_server import Transit from .transit_server import Transit
class Root(resource.Resource): class Root(resource.Resource):
@ -22,11 +23,12 @@ class PrivacyEnhancedSite(server.Site):
return server.Site.log(self, request) return server.Site.log(self, request)
class RelayServer(service.MultiService): class RelayServer(service.MultiService):
def __init__(self, relayport, transitport, advertise_version, def __init__(self, rendezvous_web_port, transit_port,
db_url=":memory:", blur_usage=None): advertise_version, db_url=":memory:", blur_usage=None):
service.MultiService.__init__(self) service.MultiService.__init__(self)
self._blur_usage = blur_usage self._blur_usage = blur_usage
self.db = get_db(db_url)
db = get_db(db_url)
welcome = { welcome = {
"current_version": __version__, "current_version": __version__,
# adding .motd will cause all clients to display the message, # adding .motd will cause all clients to display the message,
@ -38,22 +40,38 @@ class RelayServer(service.MultiService):
} }
if advertise_version: if advertise_version:
welcome["current_version"] = advertise_version welcome["current_version"] = advertise_version
self.root = Root()
site = PrivacyEnhancedSite(self.root) rendezvous = Rendezvous(db, welcome, blur_usage)
rendezvous.setServiceParent(self) # for the pruning timer
root = Root()
wr = WebRendezvous(rendezvous)
root.putChild(b"wormhole-relay", wr)
site = PrivacyEnhancedSite(root)
if blur_usage: if blur_usage:
site.logRequests = False site.logRequests = False
r = endpoints.serverFromString(reactor, relayport)
self.relayport_service = ServerEndpointService(r, site) r = endpoints.serverFromString(reactor, rendezvous_web_port)
self.relayport_service.setServiceParent(self) rendezvous_web_service = ServerEndpointService(r, site)
self.relay = Relay(self.db, welcome, blur_usage) # accessible from tests rendezvous_web_service.setServiceParent(self)
self.relay.setServiceParent(self) # for the pruning timer
self.root.putChild(b"wormhole-relay", self.relay) if transit_port:
if transitport: transit = Transit(db, blur_usage)
self.transit = Transit(self.db, blur_usage) transit.setServiceParent(self) # for the timer
self.transit.setServiceParent(self) # for the timer t = endpoints.serverFromString(reactor, transit_port)
t = endpoints.serverFromString(reactor, transitport) transit_service = ServerEndpointService(t, transit)
self.transport_service = ServerEndpointService(t, self.transit) transit_service.setServiceParent(self)
self.transport_service.setServiceParent(self)
# make some things accessible for tests
self._db = db
self._rendezvous = rendezvous
self._root = root
self._rendezvous_web = wr
self._rendezvous_web_service = rendezvous_web_service
if transit_port:
self._transit = transit
self._transit_service = transit_service
def startService(self): def startService(self):
service.MultiService.startService(self) service.MultiService.startService(self)

View File

@ -15,8 +15,8 @@ class ServerBase:
"tcp:%s:interface=127.0.0.1" % transitport, "tcp:%s:interface=127.0.0.1" % transitport,
__version__) __version__)
s.setServiceParent(self.sp) s.setServiceParent(self.sp)
self._relay_server = s.relay self._relay_server = s._rendezvous
self._transit_server = s.transit self._transit_server = s._transit
self.relayurl = u"http://127.0.0.1:%d/wormhole-relay/" % relayport self.relayurl = u"http://127.0.0.1:%d/wormhole-relay/" % relayport
self.transit = u"tcp:127.0.0.1:%d" % transitport self.transit = u"tcp:127.0.0.1:%d" % transitport

View File

@ -10,7 +10,7 @@ from twisted.internet.endpoints import clientFromString, connectProtocol
from twisted.web.client import getPage, Agent, readBody from twisted.web.client import getPage, Agent, readBody
from wormhole import __version__ from wormhole import __version__
from .common import ServerBase from .common import ServerBase
from wormhole_server import relay_server, transit_server from wormhole_server import rendezvous, transit_server
from txwormhole.eventsource import EventSource from txwormhole.eventsource import EventSource
class Reachable(ServerBase, unittest.TestCase): class Reachable(ServerBase, unittest.TestCase):
@ -369,9 +369,9 @@ class OneEventAtATime:
class Summary(unittest.TestCase): class Summary(unittest.TestCase):
def test_summarize(self): def test_summarize(self):
c = relay_server.Channel(None, None, None, None, False, None, None) c = rendezvous.Channel(None, None, None, None, False, None, None)
A = relay_server.ALLOCATE A = rendezvous.ALLOCATE
D = relay_server.DEALLOCATE D = rendezvous.DEALLOCATE
messages = [{"when": 1, "side": "a", "phase": A}] messages = [{"when": 1, "side": "a", "phase": A}]
self.failUnlessEqual(c._summarize(messages, 2), self.failUnlessEqual(c._summarize(messages, 2),