diff --git a/src/wormhole/servers/relay_server.py b/src/wormhole/servers/relay_server.py index 57fbab6..e3601db 100644 --- a/src/wormhole/servers/relay_server.py +++ b/src/wormhole/servers/relay_server.py @@ -57,29 +57,26 @@ class EventsProtocol: # POST /deallocate {appid:,channelid:,side:} -> {status: waiting | deleted} # all JSON responses include a "welcome:{..}" key -class ChannelLister(resource.Resource): - def __init__(self, relay): +class RelayResource(resource.Resource): + def __init__(self, relay, welcome): resource.Resource.__init__(self) self._relay = relay + self._welcome = welcome +class ChannelLister(RelayResource): def render_GET(self, request): if b"appid" not in request.args: - e = NeedToUpgradeErrorResource(self._relay.welcome) + e = NeedToUpgradeErrorResource(self._welcome) return e.get_message() appid = request.args[b"appid"][0].decode("utf-8") #print("LIST", appid) app = self._relay.get_app(appid) allocated = app.get_allocated() request.setHeader(b"content-type", b"application/json; charset=utf-8") - data = {"welcome": self._relay.welcome, - "channelids": sorted(allocated)} + data = {"welcome": self._welcome, "channelids": sorted(allocated)} return (json.dumps(data)+"\n").encode("utf-8") -class Allocator(resource.Resource): - def __init__(self, relay): - resource.Resource.__init__(self) - self._relay = relay - +class Allocator(RelayResource): def render_POST(self, request): content = request.content.read() data = json.loads(content.decode("utf-8")) @@ -94,8 +91,7 @@ class Allocator(resource.Resource): log.msg("allocated #%d, now have %d DB channels" % (channelid, len(app.get_allocated()))) request.setHeader(b"content-type", b"application/json; charset=utf-8") - data = {"welcome": self._relay.welcome, - "channelid": channelid} + data = {"welcome": self._welcome, "channelid": channelid} return (json.dumps(data)+"\n").encode("utf-8") def getChild(self, path, req): @@ -103,7 +99,7 @@ class Allocator(resource.Resource): # wormhole-0.5.0 changed that to "POST /allocate". We catch the old # URL here to deliver a nicer error message (with upgrade # instructions) than an ugly 404. - return NeedToUpgradeErrorResource(self._relay.welcome) + return NeedToUpgradeErrorResource(self._welcome) class NeedToUpgradeErrorResource(resource.Resource): def __init__(self, welcome): @@ -121,11 +117,7 @@ class NeedToUpgradeErrorResource(resource.Resource): def getChild(self, path, req): return self -class Adder(resource.Resource): - def __init__(self, relay): - resource.Resource.__init__(self) - self._relay = relay - +class Adder(RelayResource): def render_POST(self, request): #content = json.load(request.content, encoding="utf-8") content = request.content.read() @@ -142,12 +134,11 @@ class Adder(resource.Resource): app = self._relay.get_app(appid) channel = app.get_channel(channelid) response = channel.add_message(side, phase, body) + # response is generated with get_messages(), so it includes both + # 'welcome' and 'messages' return json_response(request, response) -class Getter(resource.Resource): - def __init__(self, relay): - self._relay = relay - +class Getter(RelayResource): def render_GET(self, request): appid = request.args[b"appid"][0].decode("utf-8") channelid = int(request.args[b"channelid"][0]) @@ -161,7 +152,7 @@ class Getter(resource.Resource): request.setHeader(b"content-type", b"text/event-stream; charset=utf-8") ep = EventsProtocol(request) - ep.sendEvent(json.dumps(self._relay.welcome), name="welcome") + ep.sendEvent(json.dumps(self._welcome), name="welcome") old_events = channel.add_listener(ep.sendEvent) request.notifyFinish().addErrback(lambda f: channel.remove_listener(ep.sendEvent)) @@ -169,10 +160,7 @@ class Getter(resource.Resource): ep.sendEvent(old_event) return server.NOT_DONE_YET -class Deallocator(resource.Resource): - def __init__(self, relay): - self._relay = relay - +class Deallocator(RelayResource): def render_POST(self, request): content = request.content.read() data = json.loads(content.decode("utf-8")) @@ -331,11 +319,11 @@ class Relay(resource.Resource, service.MultiService): self._apps = {} t = internet.TimerService(EXPIRATION_CHECK_PERIOD, self.prune) t.setServiceParent(self) - self.putChild(b"list", ChannelLister(self)) - self.putChild(b"allocate", Allocator(self)) - self.putChild(b"add", Adder(self)) - self.putChild(b"get", Getter(self)) - self.putChild(b"deallocate", Deallocator(self)) + self.putChild(b"list", ChannelLister(self, welcome)) + self.putChild(b"allocate", Allocator(self, welcome)) + self.putChild(b"add", Adder(self, welcome)) + self.putChild(b"get", Getter(self, welcome)) + self.putChild(b"deallocate", Deallocator(self, welcome)) def getChild(self, path, req): # 0.4.0 used "POST /CID/SIDE/post/MSGNUM"