diff --git a/src/wormhole/servers/relay_server.py b/src/wormhole/servers/relay_server.py index 862da9d..c1ebe56 100644 --- a/src/wormhole/servers/relay_server.py +++ b/src/wormhole/servers/relay_server.py @@ -57,6 +57,9 @@ class EventsProtocol: # POST /add {appid:,channelid:,side:,phase:,body:} -> {messages: MESSAGES} # GET /get?appid=&channelid= (no-eventsource) -> {messages: MESSAGES} # GET /get?appid=&channelid= (eventsource) -> {phase:, body:}.. +# POST /add_messages {appid:,channelid:,side:,messages:}-> {messages: MESSAGES} +# GET /get_messages?appid=&channelid= -> {messages: MESSAGES} +# GET /watch_messages?appid=&channelid= (eventsource)-> {[phase:, body:]..}.. # POST /deallocate {appid:,channelid:,side:} -> {status: waiting | deleted} # all JSON responses include a "welcome:{..}" key @@ -136,7 +139,7 @@ class Adder(RelayResource): app = self._relay.get_app(appid) channel = app.get_channel(channelid) - response = channel.add_message(side, phase, body) + response = channel.add_messages(side, [(phase, body)]) # response is generated with get_messages(), so it includes both # 'welcome' and 'messages' return json_response(request, response) @@ -156,11 +159,74 @@ class Getter(RelayResource): request.setHeader(b"content-type", b"text/event-stream; charset=utf-8") ep = EventsProtocol(request) ep.sendEvent(json.dumps(self._welcome), name="welcome") - old_events = channel.add_listener(ep.sendEvent) + def _send(messages): + for (phase, body) in messages: + data = json.dumps({"phase": phase, "body": body}) + ep.sendEvent(data) + old_messages = list(channel.add_listener(_send)) request.notifyFinish().addErrback(lambda f: - channel.remove_listener(ep.sendEvent)) - for old_event in old_events: - ep.sendEvent(old_event) + channel.remove_listener(_send)) + if old_messages: + _send(old_messages) + return server.NOT_DONE_YET + +class MessageAdder(RelayResource): + def render_POST(self, request): + #content = json.load(request.content, encoding="utf-8") + content = request.content.read() + data = json.loads(content.decode("utf-8")) + appid = data["appid"] + channelid = int(data["channelid"]) + side = data["side"] + messages = [] + for m in data["messages"]: + phase = m.get("phase") + if not isinstance(phase, type(u"")): + raise TypeError("phase must be string, not %s" % type(phase)) + body = m.get("body") + if not isinstance(body, type(u"")): + raise TypeError("body must be string, not %s" % type(body)) + messages.append( (phase, body) ) + + app = self._relay.get_app(appid) + channel = app.get_channel(channelid) + response = channel.add_messages(side, messages) + # response includes both 'welcome' and 'messages' + return json_response(request, response) + +class MessageGetter(RelayResource): + def render_GET(self, request): + if b"text/event-stream" in (request.getHeader(b"accept") or b""): + raise TypeError("/get_messages is not for EventSource") + appid = request.args[b"appid"][0].decode("utf-8") + channelid = int(request.args[b"channelid"][0]) + app = self._relay.get_app(appid) + channel = app.get_channel(channelid) + response = channel.get_messages() + return json_response(request, response) + +class MessageWatcher(RelayResource): + def render_GET(self, request): + if b"text/event-stream" not in (request.getHeader(b"accept") or b""): + raise TypeError("/watch_messages is only for EventSource") + appid = request.args[b"appid"][0].decode("utf-8") + channelid = int(request.args[b"channelid"][0]) + app = self._relay.get_app(appid) + channel = app.get_channel(channelid) + + request.setHeader(b"content-type", b"text/event-stream; charset=utf-8") + ep = EventsProtocol(request) + ep.sendEvent(json.dumps(self._welcome), name="welcome") + def _send(messages): + data = json.dumps([ {"phase": phase, "body": body} + for (phase, body) in messages ]) + ep.sendEvent(data) + + old_messages = list(channel.add_listener(_send)) + request.notifyFinish().addErrback(lambda f: + channel.remove_listener(_send)) + if old_messages: + _send(old_messages) return server.NOT_DONE_YET class Deallocator(RelayResource): @@ -216,14 +282,13 @@ class Channel: (self._appid, self._channelid)).fetchall(): if row["phase"] in (u"_allocate", u"_deallocate"): continue - yield json.dumps({"phase": row["phase"], "body": row["body"]}) + yield (row["phase"], row["body"]) def remove_listener(self, listener): self._listeners.discard(listener) - def broadcast_message(self, phase, body): - data = json.dumps({"phase": phase, "body": body}) + def broadcast_messages(self, messages): for listener in self._listeners: - listener(data) + listener(messages) def _add_message(self, side, phase, body): db = self._db @@ -237,9 +302,10 @@ class Channel: def allocate(self, side): self._add_message(side, ALLOCATE, None) - def add_message(self, side, phase, body): - self._add_message(side, phase, body) - self.broadcast_message(phase, body) + def add_messages(self, side, messages): + for (phase, body) in messages: + self._add_message(side, phase, body) + self.broadcast_messages(messages) return self.get_messages() def deallocate(self, side, mood): @@ -437,6 +503,9 @@ class Relay(resource.Resource, service.MultiService): self.putChild(b"allocate", Allocator(self, welcome)) self.putChild(b"add", Adder(self, welcome)) self.putChild(b"get", Getter(self, welcome)) + self.putChild(b"add_messages", MessageAdder(self, welcome)) + self.putChild(b"get_messages", MessageGetter(self, welcome)) + self.putChild(b"watch_messages", MessageWatcher(self, welcome)) self.putChild(b"deallocate", Deallocator(self, welcome)) def getChild(self, path, req): diff --git a/src/wormhole/test/test_server.py b/src/wormhole/test/test_server.py index 4223eab..1e526a1 100644 --- a/src/wormhole/test/test_server.py +++ b/src/wormhole/test/test_server.py @@ -203,6 +203,15 @@ class API(ServerBase, unittest.TestCase): "phase": phase, "body": message}) + def add_messages(self, messages, side="abc"): + return self.post("add_messages", + {"appid": "app1", + "channelid": str(self.cid), + "side": side, + "messages": [{"phase": phase, "body": body} + for (phase, body) in messages], + }) + def parse_messages(self, messages): out = set() for m in messages: @@ -271,6 +280,71 @@ class API(ServerBase, unittest.TestCase): return d + def test_messages(self): + # exercise POST /add_messages and GET /get_messages + d = self.post("allocate", {"appid": "app1", "side": "abc"}) + def _allocated(data): + self.cid = data["channelid"] + d.addCallback(_allocated) + + d.addCallback(lambda _: self.add_messages([("1", "msg1A")])) + def _check1(data): + self.check_welcome(data) + self.failUnlessEqual(data["messages"], + [{"phase": "1", "body": "msg1A"}]) + d.addCallback(_check1) + d.addCallback(lambda _: self.get("get_messages", "app1", str(self.cid))) + d.addCallback(_check1) + d.addCallback(lambda _: self.add_messages([("1", "msg1B")], side="def")) + def _check2(data): + self.check_welcome(data) + self.failUnlessEqual(self.parse_messages(data["messages"]), + set([("1", "msg1A"), + ("1", "msg1B")])) + d.addCallback(_check2) + d.addCallback(lambda _: self.get("get_messages", "app1", str(self.cid))) + d.addCallback(_check2) + + # adding a duplicate message is not an error, is ignored by clients + d.addCallback(lambda _: self.add_messages([("1", "msg1B")], side="def")) + def _check3(data): + self.check_welcome(data) + self.failUnlessEqual(self.parse_messages(data["messages"]), + set([("1", "msg1A"), + ("1", "msg1B")])) + d.addCallback(_check3) + d.addCallback(lambda _: self.get("get_messages", "app1", str(self.cid))) + d.addCallback(_check3) + + d.addCallback(lambda _: self.add_messages([("2", "msg2A")], side="abc")) + def _check4(data): + self.check_welcome(data) + self.failUnlessEqual(self.parse_messages(data["messages"]), + set([("1", "msg1A"), + ("1", "msg1B"), + ("2", "msg2A"), + ])) + d.addCallback(_check4) + d.addCallback(lambda _: self.get("get_messages", "app1", str(self.cid))) + d.addCallback(_check4) + + d.addCallback(lambda _: self.add_messages([("3", "msg3A"), + ("4", "msg4A")], side="abc")) + def _check5(data): + self.check_welcome(data) + self.failUnlessEqual(self.parse_messages(data["messages"]), + set([("1", "msg1A"), + ("1", "msg1B"), + ("2", "msg2A"), + ("3", "msg3A"), + ("4", "msg4A"), + ])) + d.addCallback(_check5) + d.addCallback(lambda _: self.get("get_messages", "app1", str(self.cid))) + d.addCallback(_check5) + + return d + def test_watch_message(self): # exercise GET /get (the EventSource version) if sys.version_info[0] >= 3: @@ -316,6 +390,62 @@ class API(ServerBase, unittest.TestCase): d.addCallback(lambda _: self.o.wait_for_disconnection()) return d + def test_watch_messages(self): + # exercise GET /watch_messages (the EventSource version) + if sys.version_info[0] >= 3: + raise unittest.SkipTest("twisted vs py3") + + d = self.post("allocate", {"appid": "app1", "side": "abc"}) + def _allocated(data): + self.cid = data["channelid"] + url = self.build_url("watch_messages", "app1", self.cid) + self.o = OneEventAtATime(url, parser=json.loads) + return self.o.wait_for_connection() + d.addCallback(_allocated) + d.addCallback(lambda _: self.o.wait_for_next_event()) + def _check_welcome(ev): + eventtype, data = ev + self.failUnlessEqual(eventtype, "welcome") + self.failUnlessEqual(data, {"current_version": __version__}) + d.addCallback(_check_welcome) + d.addCallback(lambda _: self.add_message("msg1A")) + d.addCallback(lambda _: self.o.wait_for_next_event()) + def _check_msg1(ev): + eventtype, data = ev + self.failUnlessEqual(eventtype, "message") + self.failUnlessEqual(data, [{"phase": "1", "body": "msg1A"}]) + d.addCallback(_check_msg1) + + d.addCallback(lambda _: self.add_message("msg1B")) + d.addCallback(lambda _: self.add_message("msg2A", phase="2")) + d.addCallback(lambda _: self.o.wait_for_next_event()) + def _check_msg2(ev): + eventtype, data = ev + self.failUnlessEqual(eventtype, "message") + self.failUnlessEqual(data, [{"phase": "1", "body": "msg1B"}]) + d.addCallback(_check_msg2) + d.addCallback(lambda _: self.o.wait_for_next_event()) + def _check_msg3(ev): + eventtype, data = ev + self.failUnlessEqual(eventtype, "message") + self.failUnlessEqual(data, [{"phase": "2", "body": "msg2A"}]) + d.addCallback(_check_msg3) + + d.addCallback(lambda _: self.add_messages([("2", "msg2B"), + ("3", "msg3A")])) + d.addCallback(lambda _: self.o.wait_for_next_event()) + def _check_msg4(ev): + eventtype, data = ev + self.failUnlessEqual(eventtype, "message") + self.failUnlessEqual(data, [{"phase": "2", "body": "msg2B"}, + {"phase": "3", "body": "msg3A"}, + ]) + d.addCallback(_check_msg4) + + d.addCallback(lambda _: self.o.close()) + d.addCallback(lambda _: self.o.wait_for_disconnection()) + return d + class OneEventAtATime: def __init__(self, url, parser=lambda e: e): self.parser = parser