From c3bd9e936ecb5d068a47245de3ecda395554b432 Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Sun, 17 Apr 2016 14:41:12 -0700 Subject: [PATCH] split rendezvous server into web, nonweb files Also rename files/classes from "relay" to "rendezvous". --- .../{relay_server.py => rendezvous.py} | 216 +----------------- src/wormhole_server/rendezvous_web.py | 216 ++++++++++++++++++ src/wormhole_server/server.py | 54 +++-- tests/common.py | 4 +- tests/test_server.py | 8 +- 5 files changed, 263 insertions(+), 235 deletions(-) rename src/wormhole_server/{relay_server.py => rendezvous.py} (58%) create mode 100644 src/wormhole_server/rendezvous_web.py diff --git a/src/wormhole_server/relay_server.py b/src/wormhole_server/rendezvous.py similarity index 58% rename from src/wormhole_server/relay_server.py rename to src/wormhole_server/rendezvous.py index 6212c96..8e25f4a 100644 --- a/src/wormhole_server/relay_server.py +++ b/src/wormhole_server/rendezvous.py @@ -2,7 +2,6 @@ from __future__ import print_function import json, time, random from twisted.python import log from twisted.application import service, internet -from twisted.web import server, resource SECONDS = 1.0 MINUTE = 60*SECONDS @@ -16,200 +15,6 @@ EXPIRATION_CHECK_PERIOD = 2*HOUR ALLOCATE = u"_allocate" DEALLOCATE = u"_deallocate" -def json_response(request, data): - request.setHeader(b"content-type", b"application/json; charset=utf-8") - return (json.dumps(data)+"\n").encode("utf-8") - -class EventsProtocol: - def __init__(self, request): - self.request = request - - def sendComment(self, comment): - # this is ignored by clients, but can keep the connection open in the - # face of firewall/NAT timeouts. It also helps unit tests, since - # apparently twisted.web.client.Agent doesn't consider the connection - # to be established until it sees the first byte of the reponse body. - self.request.write(b": " + comment + b"\n\n") - - def sendEvent(self, data, name=None, id=None, retry=None): - if name: - self.request.write(b"event: " + name.encode("utf-8") + b"\n") - # e.g. if name=foo, then the client web page should do: - # (new EventSource(url)).addEventListener("foo", handlerfunc) - # Note that this basically defaults to "message". - if id: - self.request.write(b"id: " + id.encode("utf-8") + b"\n") - if retry: - self.request.write(b"retry: " + retry + b"\n") # milliseconds - for line in data.splitlines(): - self.request.write(b"data: " + line.encode("utf-8") + b"\n") - self.request.write(b"\n") - - def stop(self): - self.request.finish() - -# note: no versions of IE (including the current IE11) support EventSource - -# relay URLs are as follows: (MESSAGES=[{phase:,body:}..]) -# ("-" indicates a deprecated URL) -# GET /list?appid= -> {channelids: [INT..]} -# POST /allocate {appid:,side:} -> {channelid: INT} -# these return all messages (base64) for appid=/channelid= : -# POST /add {appid:,channelid:,side:,phase:,body:} -> {messages: MESSAGES} -# GET /get?appid=&channelid= (no-eventsource) -> {messages: MESSAGES} -#- GET /get?appid=&channelid= (eventsource) -> {phase:, body:}.. -# GET /watch?appid=&channelid= (eventsource) -> {phase:, body:}.. -# POST /deallocate {appid:,channelid:,side:} -> {status: waiting | deleted} -# all JSON responses include a "welcome:{..}" key - -class RelayResource(resource.Resource): - def __init__(self, relay, welcome, log_requests): - resource.Resource.__init__(self) - self._relay = relay - self._welcome = welcome - self._log_requests = log_requests - -class ChannelLister(RelayResource): - def render_GET(self, request): - if b"appid" not in request.args: - e = NeedToUpgradeErrorResource(self._welcome) - return e.get_message() - appid = request.args[b"appid"][0].decode("utf-8") - #print("LIST", appid) - app = self._relay.get_app(appid) - allocated = app.get_allocated() - data = {"welcome": self._welcome, "channelids": sorted(allocated), - "sent": time.time()} - return json_response(request, data) - -class Allocator(RelayResource): - def render_POST(self, request): - content = request.content.read() - data = json.loads(content.decode("utf-8")) - appid = data["appid"] - side = data["side"] - if not isinstance(side, type(u"")): - raise TypeError("side must be string, not '%s'" % type(side)) - #print("ALLOCATE", appid, side) - app = self._relay.get_app(appid) - channelid = app.find_available_channelid() - app.allocate_channel(channelid, side) - if self._log_requests: - log.msg("allocated #%d, now have %d DB channels" % - (channelid, len(app.get_allocated()))) - response = {"welcome": self._welcome, "channelid": channelid, - "sent": time.time()} - return json_response(request, response) - - def getChild(self, path, req): - # wormhole-0.4.0 "send" started with "POST /allocate/SIDE". - # wormhole-0.5.0 changed that to "POST /allocate". We catch the old - # URL here to deliver a nicer error message (with upgrade - # instructions) than an ugly 404. - return NeedToUpgradeErrorResource(self._welcome) - -class NeedToUpgradeErrorResource(resource.Resource): - def __init__(self, welcome): - resource.Resource.__init__(self) - w = welcome.copy() - w["error"] = "Sorry, you must upgrade your client to use this server." - message = {"welcome": w} - self._message = (json.dumps(message)+"\n").encode("utf-8") - def get_message(self): - return self._message - def render_POST(self, request): - return self._message - def render_GET(self, request): - return self._message - def getChild(self, path, req): - return self - -class Adder(RelayResource): - def render_POST(self, request): - #content = json.load(request.content, encoding="utf-8") - content = request.content.read() - data = json.loads(content.decode("utf-8")) - appid = data["appid"] - channelid = int(data["channelid"]) - side = data["side"] - phase = data["phase"] - if not isinstance(phase, type(u"")): - raise TypeError("phase must be string, not %s" % type(phase)) - body = data["body"] - #print("ADD", appid, channelid, side, phase, body) - - app = self._relay.get_app(appid) - channel = app.get_channel(channelid) - messages = channel.add_message(side, phase, body) - response = {"welcome": self._welcome, "messages": messages, - "sent": time.time()} - return json_response(request, response) - -class GetterOrWatcher(RelayResource): - def render_GET(self, request): - appid = request.args[b"appid"][0].decode("utf-8") - channelid = int(request.args[b"channelid"][0]) - #print("GET", appid, channelid) - app = self._relay.get_app(appid) - channel = app.get_channel(channelid) - - if b"text/event-stream" not in (request.getHeader(b"accept") or b""): - messages = channel.get_messages() - response = {"welcome": self._welcome, "messages": messages, - "sent": time.time()} - return json_response(request, response) - - request.setHeader(b"content-type", b"text/event-stream; charset=utf-8") - ep = EventsProtocol(request) - ep.sendEvent(json.dumps(self._welcome), name="welcome") - old_events = channel.add_listener(ep) - request.notifyFinish().addErrback(lambda f: - channel.remove_listener(ep)) - for old_event in old_events: - ep.sendEvent(old_event) - return server.NOT_DONE_YET - -class Watcher(RelayResource): - def render_GET(self, request): - appid = request.args[b"appid"][0].decode("utf-8") - channelid = int(request.args[b"channelid"][0]) - app = self._relay.get_app(appid) - channel = app.get_channel(channelid) - if b"text/event-stream" not in (request.getHeader(b"accept") or b""): - raise TypeError("/watch is for EventSource only") - - request.setHeader(b"content-type", b"text/event-stream; charset=utf-8") - ep = EventsProtocol(request) - ep.sendEvent(json.dumps(self._welcome), name="welcome") - old_events = channel.add_listener(ep) - request.notifyFinish().addErrback(lambda f: - channel.remove_listener(ep)) - for old_event in old_events: - ep.sendEvent(old_event) - return server.NOT_DONE_YET - -class Deallocator(RelayResource): - def render_POST(self, request): - content = request.content.read() - data = json.loads(content.decode("utf-8")) - appid = data["appid"] - channelid = int(data["channelid"]) - side = data["side"] - if not isinstance(side, type(u"")): - raise TypeError("side must be string, not '%s'" % type(side)) - mood = data.get("mood") - #print("DEALLOCATE", appid, channelid, side) - - app = self._relay.get_app(appid) - channel = app.get_channel(channelid) - deleted = channel.deallocate(side, mood) - response = {"status": "waiting", "sent": time.time()} - if deleted: - response = {"status": "deleted", "sent": time.time()} - return json_response(request, response) - - - class Channel: def __init__(self, app, db, welcome, blur_usage, log_requests, appid, channelid): @@ -469,9 +274,8 @@ class AppNamespace: for channel in self._channels.values(): channel._shutdown() -class Relay(resource.Resource, service.MultiService): +class Rendezvous(service.MultiService): def __init__(self, db, welcome, blur_usage): - resource.Resource.__init__(self) service.MultiService.__init__(self) self._db = db self._welcome = welcome @@ -481,21 +285,11 @@ class Relay(resource.Resource, service.MultiService): self._apps = {} t = internet.TimerService(EXPIRATION_CHECK_PERIOD, self.prune) t.setServiceParent(self) - self.putChild(b"list", ChannelLister(self, welcome, log_requests)) - self.putChild(b"allocate", Allocator(self, welcome, log_requests)) - self.putChild(b"add", Adder(self, welcome, log_requests)) - self.putChild(b"get", GetterOrWatcher(self, welcome, log_requests)) - self.putChild(b"watch", Watcher(self, welcome, log_requests)) - self.putChild(b"deallocate", Deallocator(self, welcome, log_requests)) - def getChild(self, path, req): - # 0.4.0 used "POST /CID/SIDE/post/MSGNUM" - # 0.5.0 replaced it with "POST /add (json body)" - # give a nicer error message to old clients - if (len(req.postpath) >= 2 - and req.postpath[1] in (b"post", b"poll", b"deallocate")): - return NeedToUpgradeErrorResource(self._welcome) - return resource.NoResource("No such child resource.") + def get_welcome(self): + return self._welcome + def get_log_requests(self): + return self._log_requests def get_app(self, appid): assert isinstance(appid, type(u"")) diff --git a/src/wormhole_server/rendezvous_web.py b/src/wormhole_server/rendezvous_web.py new file mode 100644 index 0000000..63de23f --- /dev/null +++ b/src/wormhole_server/rendezvous_web.py @@ -0,0 +1,216 @@ +import json, time +from twisted.web import server, resource +from twisted.python import log + +def json_response(request, data): + request.setHeader(b"content-type", b"application/json; charset=utf-8") + return (json.dumps(data)+"\n").encode("utf-8") + +class EventsProtocol: + def __init__(self, request): + self.request = request + + def sendComment(self, comment): + # this is ignored by clients, but can keep the connection open in the + # face of firewall/NAT timeouts. It also helps unit tests, since + # apparently twisted.web.client.Agent doesn't consider the connection + # to be established until it sees the first byte of the reponse body. + self.request.write(b": " + comment + b"\n\n") + + def sendEvent(self, data, name=None, id=None, retry=None): + if name: + self.request.write(b"event: " + name.encode("utf-8") + b"\n") + # e.g. if name=foo, then the client web page should do: + # (new EventSource(url)).addEventListener("foo", handlerfunc) + # Note that this basically defaults to "message". + if id: + self.request.write(b"id: " + id.encode("utf-8") + b"\n") + if retry: + self.request.write(b"retry: " + retry + b"\n") # milliseconds + for line in data.splitlines(): + self.request.write(b"data: " + line.encode("utf-8") + b"\n") + self.request.write(b"\n") + + def stop(self): + self.request.finish() + +# note: no versions of IE (including the current IE11) support EventSource + +# relay URLs are as follows: (MESSAGES=[{phase:,body:}..]) +# ("-" indicates a deprecated URL) +# GET /list?appid= -> {channelids: [INT..]} +# POST /allocate {appid:,side:} -> {channelid: INT} +# these return all messages (base64) for appid=/channelid= : +# POST /add {appid:,channelid:,side:,phase:,body:} -> {messages: MESSAGES} +# GET /get?appid=&channelid= (no-eventsource) -> {messages: MESSAGES} +#- GET /get?appid=&channelid= (eventsource) -> {phase:, body:}.. +# GET /watch?appid=&channelid= (eventsource) -> {phase:, body:}.. +# POST /deallocate {appid:,channelid:,side:} -> {status: waiting | deleted} +# all JSON responses include a "welcome:{..}" key + +class RelayResource(resource.Resource): + def __init__(self, rendezvous): + resource.Resource.__init__(self) + self._rendezvous = rendezvous + self._welcome = rendezvous.get_welcome() + +class ChannelLister(RelayResource): + def render_GET(self, request): + if b"appid" not in request.args: + e = NeedToUpgradeErrorResource(self._welcome) + return e.get_message() + appid = request.args[b"appid"][0].decode("utf-8") + #print("LIST", appid) + app = self._rendezvous.get_app(appid) + allocated = app.get_allocated() + data = {"welcome": self._welcome, "channelids": sorted(allocated), + "sent": time.time()} + return json_response(request, data) + +class Allocator(RelayResource): + def render_POST(self, request): + content = request.content.read() + data = json.loads(content.decode("utf-8")) + appid = data["appid"] + side = data["side"] + if not isinstance(side, type(u"")): + raise TypeError("side must be string, not '%s'" % type(side)) + #print("ALLOCATE", appid, side) + app = self._rendezvous.get_app(appid) + channelid = app.find_available_channelid() + app.allocate_channel(channelid, side) + if self._rendezvous.get_log_requests(): + log.msg("allocated #%d, now have %d DB channels" % + (channelid, len(app.get_allocated()))) + response = {"welcome": self._welcome, "channelid": channelid, + "sent": time.time()} + return json_response(request, response) + + def getChild(self, path, req): + # wormhole-0.4.0 "send" started with "POST /allocate/SIDE". + # wormhole-0.5.0 changed that to "POST /allocate". We catch the old + # URL here to deliver a nicer error message (with upgrade + # instructions) than an ugly 404. + return NeedToUpgradeErrorResource(self._welcome) + +class NeedToUpgradeErrorResource(resource.Resource): + def __init__(self, welcome): + resource.Resource.__init__(self) + w = welcome.copy() + w["error"] = "Sorry, you must upgrade your client to use this server." + message = {"welcome": w} + self._message = (json.dumps(message)+"\n").encode("utf-8") + def get_message(self): + return self._message + def render_POST(self, request): + return self._message + def render_GET(self, request): + return self._message + def getChild(self, path, req): + return self + +class Adder(RelayResource): + def render_POST(self, request): + #content = json.load(request.content, encoding="utf-8") + content = request.content.read() + data = json.loads(content.decode("utf-8")) + appid = data["appid"] + channelid = int(data["channelid"]) + side = data["side"] + phase = data["phase"] + if not isinstance(phase, type(u"")): + raise TypeError("phase must be string, not %s" % type(phase)) + body = data["body"] + #print("ADD", appid, channelid, side, phase, body) + + app = self._rendezvous.get_app(appid) + channel = app.get_channel(channelid) + messages = channel.add_message(side, phase, body) + response = {"welcome": self._welcome, "messages": messages, + "sent": time.time()} + return json_response(request, response) + +class GetterOrWatcher(RelayResource): + def render_GET(self, request): + appid = request.args[b"appid"][0].decode("utf-8") + channelid = int(request.args[b"channelid"][0]) + #print("GET", appid, channelid) + app = self._rendezvous.get_app(appid) + channel = app.get_channel(channelid) + + if b"text/event-stream" not in (request.getHeader(b"accept") or b""): + messages = channel.get_messages() + response = {"welcome": self._welcome, "messages": messages, + "sent": time.time()} + return json_response(request, response) + + request.setHeader(b"content-type", b"text/event-stream; charset=utf-8") + ep = EventsProtocol(request) + ep.sendEvent(json.dumps(self._welcome), name="welcome") + old_events = channel.add_listener(ep) + request.notifyFinish().addErrback(lambda f: + channel.remove_listener(ep)) + for old_event in old_events: + ep.sendEvent(old_event) + return server.NOT_DONE_YET + +class Watcher(RelayResource): + def render_GET(self, request): + appid = request.args[b"appid"][0].decode("utf-8") + channelid = int(request.args[b"channelid"][0]) + app = self._rendezvous.get_app(appid) + channel = app.get_channel(channelid) + if b"text/event-stream" not in (request.getHeader(b"accept") or b""): + raise TypeError("/watch is for EventSource only") + + request.setHeader(b"content-type", b"text/event-stream; charset=utf-8") + ep = EventsProtocol(request) + ep.sendEvent(json.dumps(self._welcome), name="welcome") + old_events = channel.add_listener(ep) + request.notifyFinish().addErrback(lambda f: + channel.remove_listener(ep)) + for old_event in old_events: + ep.sendEvent(old_event) + return server.NOT_DONE_YET + +class Deallocator(RelayResource): + def render_POST(self, request): + content = request.content.read() + data = json.loads(content.decode("utf-8")) + appid = data["appid"] + channelid = int(data["channelid"]) + side = data["side"] + if not isinstance(side, type(u"")): + raise TypeError("side must be string, not '%s'" % type(side)) + mood = data.get("mood") + #print("DEALLOCATE", appid, channelid, side) + + app = self._rendezvous.get_app(appid) + channel = app.get_channel(channelid) + deleted = channel.deallocate(side, mood) + response = {"status": "waiting", "sent": time.time()} + if deleted: + response = {"status": "deleted", "sent": time.time()} + return json_response(request, response) + + +class WebRendezvous(resource.Resource): + def __init__(self, rendezvous): + resource.Resource.__init__(self) + self._rendezvous = rendezvous + self.putChild(b"list", ChannelLister(rendezvous)) + self.putChild(b"allocate", Allocator(rendezvous)) + self.putChild(b"add", Adder(rendezvous)) + self.putChild(b"get", GetterOrWatcher(rendezvous)) + self.putChild(b"watch", Watcher(rendezvous)) + self.putChild(b"deallocate", Deallocator(rendezvous)) + + def getChild(self, path, req): + # 0.4.0 used "POST /CID/SIDE/post/MSGNUM" + # 0.5.0 replaced it with "POST /add (json body)" + # give a nicer error message to old clients + if (len(req.postpath) >= 2 + and req.postpath[1] in (b"post", b"poll", b"deallocate")): + welcome = self._rendezvous.get_welcome() + return NeedToUpgradeErrorResource(welcome) + return resource.NoResource("No such child resource.") diff --git a/src/wormhole_server/server.py b/src/wormhole_server/server.py index 321ee27..90971b9 100644 --- a/src/wormhole_server/server.py +++ b/src/wormhole_server/server.py @@ -6,7 +6,8 @@ from twisted.web import server, static, resource from .endpoint_service import ServerEndpointService from wormhole import __version__ from .database import get_db -from .relay_server import Relay +from .rendezvous import Rendezvous +from .rendezvous_web import WebRendezvous from .transit_server import Transit class Root(resource.Resource): @@ -22,11 +23,12 @@ class PrivacyEnhancedSite(server.Site): return server.Site.log(self, request) class RelayServer(service.MultiService): - def __init__(self, relayport, transitport, advertise_version, - db_url=":memory:", blur_usage=None): + def __init__(self, rendezvous_web_port, transit_port, + advertise_version, db_url=":memory:", blur_usage=None): service.MultiService.__init__(self) self._blur_usage = blur_usage - self.db = get_db(db_url) + + db = get_db(db_url) welcome = { "current_version": __version__, # adding .motd will cause all clients to display the message, @@ -38,22 +40,38 @@ class RelayServer(service.MultiService): } if advertise_version: welcome["current_version"] = advertise_version - self.root = Root() - site = PrivacyEnhancedSite(self.root) + + rendezvous = Rendezvous(db, welcome, blur_usage) + rendezvous.setServiceParent(self) # for the pruning timer + + root = Root() + wr = WebRendezvous(rendezvous) + root.putChild(b"wormhole-relay", wr) + + site = PrivacyEnhancedSite(root) if blur_usage: site.logRequests = False - r = endpoints.serverFromString(reactor, relayport) - self.relayport_service = ServerEndpointService(r, site) - self.relayport_service.setServiceParent(self) - self.relay = Relay(self.db, welcome, blur_usage) # accessible from tests - self.relay.setServiceParent(self) # for the pruning timer - self.root.putChild(b"wormhole-relay", self.relay) - if transitport: - self.transit = Transit(self.db, blur_usage) - self.transit.setServiceParent(self) # for the timer - t = endpoints.serverFromString(reactor, transitport) - self.transport_service = ServerEndpointService(t, self.transit) - self.transport_service.setServiceParent(self) + + r = endpoints.serverFromString(reactor, rendezvous_web_port) + rendezvous_web_service = ServerEndpointService(r, site) + rendezvous_web_service.setServiceParent(self) + + if transit_port: + transit = Transit(db, blur_usage) + transit.setServiceParent(self) # for the timer + t = endpoints.serverFromString(reactor, transit_port) + transit_service = ServerEndpointService(t, transit) + transit_service.setServiceParent(self) + + # make some things accessible for tests + self._db = db + self._rendezvous = rendezvous + self._root = root + self._rendezvous_web = wr + self._rendezvous_web_service = rendezvous_web_service + if transit_port: + self._transit = transit + self._transit_service = transit_service def startService(self): service.MultiService.startService(self) diff --git a/tests/common.py b/tests/common.py index 4a46b37..a14a43e 100644 --- a/tests/common.py +++ b/tests/common.py @@ -15,8 +15,8 @@ class ServerBase: "tcp:%s:interface=127.0.0.1" % transitport, __version__) s.setServiceParent(self.sp) - self._relay_server = s.relay - self._transit_server = s.transit + self._relay_server = s._rendezvous + self._transit_server = s._transit self.relayurl = u"http://127.0.0.1:%d/wormhole-relay/" % relayport self.transit = u"tcp:127.0.0.1:%d" % transitport diff --git a/tests/test_server.py b/tests/test_server.py index 87a8c8d..1d76755 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -10,7 +10,7 @@ from twisted.internet.endpoints import clientFromString, connectProtocol from twisted.web.client import getPage, Agent, readBody from wormhole import __version__ from .common import ServerBase -from wormhole_server import relay_server, transit_server +from wormhole_server import rendezvous, transit_server from txwormhole.eventsource import EventSource class Reachable(ServerBase, unittest.TestCase): @@ -369,9 +369,9 @@ class OneEventAtATime: class Summary(unittest.TestCase): def test_summarize(self): - c = relay_server.Channel(None, None, None, None, False, None, None) - A = relay_server.ALLOCATE - D = relay_server.DEALLOCATE + c = rendezvous.Channel(None, None, None, None, False, None, None) + A = rendezvous.ALLOCATE + D = rendezvous.DEALLOCATE messages = [{"when": 1, "side": "a", "phase": A}] self.failUnlessEqual(c._summarize(messages, 2),