diff --git a/src/wormhole/blocking/transcribe.py b/src/wormhole/blocking/transcribe.py index afd144a..2e88c5b 100644 --- a/src/wormhole/blocking/transcribe.py +++ b/src/wormhole/blocking/transcribe.py @@ -52,31 +52,22 @@ class Channel: return (phase, body) return None - def send(self, phase, body): - return self.send_many([(phase, body)]) - - def send_many(self, messages): + def send(self, phase, msg): # TODO: retry on failure, with exponential backoff. We're guarding # against the rendezvous server being temporarily offline. - payload_messages = [] - for (phase, body) in messages: - if not isinstance(phase, type(u"")): raise TypeError(type(phase)) - if not isinstance(body, type(b"")): raise TypeError(type(body)) - self._sent_messages.add( (phase,body) ) - payload_messages.append({"phase": phase, - "body": hexlify(body).decode("ascii")}) + if not isinstance(phase, type(u"")): raise TypeError(type(phase)) + if not isinstance(msg, type(b"")): raise TypeError(type(msg)) + self._sent_messages.add( (phase,msg) ) payload = {"appid": self._appid, "channelid": self._channelid, "side": self._side, - "messages": payload_messages, - } + "phase": phase, + "body": hexlify(msg).decode("ascii")} data = json.dumps(payload).encode("utf-8") - r = requests.post(self._relay_url+"add_messages", data=data, + r = requests.post(self._relay_url+"add", data=data, timeout=self._timeout) r.raise_for_status() resp = r.json() - if "welcome" in resp: - self._handle_welcome(resp["welcome"]) self._add_inbound_messages(resp["messages"]) def get_first_of(self, phases): @@ -99,15 +90,15 @@ class Channel: raise Timeout queryargs = urlencode([("appid", self._appid), ("channelid", self._channelid)]) - url = self._relay_url + "watch_messages?%s" % queryargs - f = EventSourceFollower(url, remaining) + f = EventSourceFollower(self._relay_url+"get?%s" % queryargs, + remaining) # we loop here until the connection is lost, or we see the # message we want for (eventtype, data) in f.iter_events(): if eventtype == "welcome": self._handle_welcome(json.loads(data)) if eventtype == "message": - self._add_inbound_messages(json.loads(data)) + self._add_inbound_messages([json.loads(data)]) phase_and_body = self._find_inbound_message(phases) if phase_and_body: f.close() diff --git a/src/wormhole/servers/relay_server.py b/src/wormhole/servers/relay_server.py index c1ebe56..862da9d 100644 --- a/src/wormhole/servers/relay_server.py +++ b/src/wormhole/servers/relay_server.py @@ -57,9 +57,6 @@ class EventsProtocol: # POST /add {appid:,channelid:,side:,phase:,body:} -> {messages: MESSAGES} # GET /get?appid=&channelid= (no-eventsource) -> {messages: MESSAGES} # GET /get?appid=&channelid= (eventsource) -> {phase:, body:}.. -# POST /add_messages {appid:,channelid:,side:,messages:}-> {messages: MESSAGES} -# GET /get_messages?appid=&channelid= -> {messages: MESSAGES} -# GET /watch_messages?appid=&channelid= (eventsource)-> {[phase:, body:]..}.. # POST /deallocate {appid:,channelid:,side:} -> {status: waiting | deleted} # all JSON responses include a "welcome:{..}" key @@ -139,7 +136,7 @@ class Adder(RelayResource): app = self._relay.get_app(appid) channel = app.get_channel(channelid) - response = channel.add_messages(side, [(phase, body)]) + response = channel.add_message(side, phase, body) # response is generated with get_messages(), so it includes both # 'welcome' and 'messages' return json_response(request, response) @@ -159,74 +156,11 @@ class Getter(RelayResource): request.setHeader(b"content-type", b"text/event-stream; charset=utf-8") ep = EventsProtocol(request) ep.sendEvent(json.dumps(self._welcome), name="welcome") - def _send(messages): - for (phase, body) in messages: - data = json.dumps({"phase": phase, "body": body}) - ep.sendEvent(data) - old_messages = list(channel.add_listener(_send)) + old_events = channel.add_listener(ep.sendEvent) request.notifyFinish().addErrback(lambda f: - channel.remove_listener(_send)) - if old_messages: - _send(old_messages) - return server.NOT_DONE_YET - -class MessageAdder(RelayResource): - def render_POST(self, request): - #content = json.load(request.content, encoding="utf-8") - content = request.content.read() - data = json.loads(content.decode("utf-8")) - appid = data["appid"] - channelid = int(data["channelid"]) - side = data["side"] - messages = [] - for m in data["messages"]: - phase = m.get("phase") - if not isinstance(phase, type(u"")): - raise TypeError("phase must be string, not %s" % type(phase)) - body = m.get("body") - if not isinstance(body, type(u"")): - raise TypeError("body must be string, not %s" % type(body)) - messages.append( (phase, body) ) - - app = self._relay.get_app(appid) - channel = app.get_channel(channelid) - response = channel.add_messages(side, messages) - # response includes both 'welcome' and 'messages' - return json_response(request, response) - -class MessageGetter(RelayResource): - def render_GET(self, request): - if b"text/event-stream" in (request.getHeader(b"accept") or b""): - raise TypeError("/get_messages is not for EventSource") - appid = request.args[b"appid"][0].decode("utf-8") - channelid = int(request.args[b"channelid"][0]) - app = self._relay.get_app(appid) - channel = app.get_channel(channelid) - response = channel.get_messages() - return json_response(request, response) - -class MessageWatcher(RelayResource): - def render_GET(self, request): - if b"text/event-stream" not in (request.getHeader(b"accept") or b""): - raise TypeError("/watch_messages is only for EventSource") - appid = request.args[b"appid"][0].decode("utf-8") - channelid = int(request.args[b"channelid"][0]) - app = self._relay.get_app(appid) - channel = app.get_channel(channelid) - - request.setHeader(b"content-type", b"text/event-stream; charset=utf-8") - ep = EventsProtocol(request) - ep.sendEvent(json.dumps(self._welcome), name="welcome") - def _send(messages): - data = json.dumps([ {"phase": phase, "body": body} - for (phase, body) in messages ]) - ep.sendEvent(data) - - old_messages = list(channel.add_listener(_send)) - request.notifyFinish().addErrback(lambda f: - channel.remove_listener(_send)) - if old_messages: - _send(old_messages) + channel.remove_listener(ep.sendEvent)) + for old_event in old_events: + ep.sendEvent(old_event) return server.NOT_DONE_YET class Deallocator(RelayResource): @@ -282,13 +216,14 @@ class Channel: (self._appid, self._channelid)).fetchall(): if row["phase"] in (u"_allocate", u"_deallocate"): continue - yield (row["phase"], row["body"]) + yield json.dumps({"phase": row["phase"], "body": row["body"]}) def remove_listener(self, listener): self._listeners.discard(listener) - def broadcast_messages(self, messages): + def broadcast_message(self, phase, body): + data = json.dumps({"phase": phase, "body": body}) for listener in self._listeners: - listener(messages) + listener(data) def _add_message(self, side, phase, body): db = self._db @@ -302,10 +237,9 @@ class Channel: def allocate(self, side): self._add_message(side, ALLOCATE, None) - def add_messages(self, side, messages): - for (phase, body) in messages: - self._add_message(side, phase, body) - self.broadcast_messages(messages) + def add_message(self, side, phase, body): + self._add_message(side, phase, body) + self.broadcast_message(phase, body) return self.get_messages() def deallocate(self, side, mood): @@ -503,9 +437,6 @@ class Relay(resource.Resource, service.MultiService): self.putChild(b"allocate", Allocator(self, welcome)) self.putChild(b"add", Adder(self, welcome)) self.putChild(b"get", Getter(self, welcome)) - self.putChild(b"add_messages", MessageAdder(self, welcome)) - self.putChild(b"get_messages", MessageGetter(self, welcome)) - self.putChild(b"watch_messages", MessageWatcher(self, welcome)) self.putChild(b"deallocate", Deallocator(self, welcome)) def getChild(self, path, req): diff --git a/src/wormhole/test/test_server.py b/src/wormhole/test/test_server.py index 1e526a1..4223eab 100644 --- a/src/wormhole/test/test_server.py +++ b/src/wormhole/test/test_server.py @@ -203,15 +203,6 @@ class API(ServerBase, unittest.TestCase): "phase": phase, "body": message}) - def add_messages(self, messages, side="abc"): - return self.post("add_messages", - {"appid": "app1", - "channelid": str(self.cid), - "side": side, - "messages": [{"phase": phase, "body": body} - for (phase, body) in messages], - }) - def parse_messages(self, messages): out = set() for m in messages: @@ -280,71 +271,6 @@ class API(ServerBase, unittest.TestCase): return d - def test_messages(self): - # exercise POST /add_messages and GET /get_messages - d = self.post("allocate", {"appid": "app1", "side": "abc"}) - def _allocated(data): - self.cid = data["channelid"] - d.addCallback(_allocated) - - d.addCallback(lambda _: self.add_messages([("1", "msg1A")])) - def _check1(data): - self.check_welcome(data) - self.failUnlessEqual(data["messages"], - [{"phase": "1", "body": "msg1A"}]) - d.addCallback(_check1) - d.addCallback(lambda _: self.get("get_messages", "app1", str(self.cid))) - d.addCallback(_check1) - d.addCallback(lambda _: self.add_messages([("1", "msg1B")], side="def")) - def _check2(data): - self.check_welcome(data) - self.failUnlessEqual(self.parse_messages(data["messages"]), - set([("1", "msg1A"), - ("1", "msg1B")])) - d.addCallback(_check2) - d.addCallback(lambda _: self.get("get_messages", "app1", str(self.cid))) - d.addCallback(_check2) - - # adding a duplicate message is not an error, is ignored by clients - d.addCallback(lambda _: self.add_messages([("1", "msg1B")], side="def")) - def _check3(data): - self.check_welcome(data) - self.failUnlessEqual(self.parse_messages(data["messages"]), - set([("1", "msg1A"), - ("1", "msg1B")])) - d.addCallback(_check3) - d.addCallback(lambda _: self.get("get_messages", "app1", str(self.cid))) - d.addCallback(_check3) - - d.addCallback(lambda _: self.add_messages([("2", "msg2A")], side="abc")) - def _check4(data): - self.check_welcome(data) - self.failUnlessEqual(self.parse_messages(data["messages"]), - set([("1", "msg1A"), - ("1", "msg1B"), - ("2", "msg2A"), - ])) - d.addCallback(_check4) - d.addCallback(lambda _: self.get("get_messages", "app1", str(self.cid))) - d.addCallback(_check4) - - d.addCallback(lambda _: self.add_messages([("3", "msg3A"), - ("4", "msg4A")], side="abc")) - def _check5(data): - self.check_welcome(data) - self.failUnlessEqual(self.parse_messages(data["messages"]), - set([("1", "msg1A"), - ("1", "msg1B"), - ("2", "msg2A"), - ("3", "msg3A"), - ("4", "msg4A"), - ])) - d.addCallback(_check5) - d.addCallback(lambda _: self.get("get_messages", "app1", str(self.cid))) - d.addCallback(_check5) - - return d - def test_watch_message(self): # exercise GET /get (the EventSource version) if sys.version_info[0] >= 3: @@ -390,62 +316,6 @@ class API(ServerBase, unittest.TestCase): d.addCallback(lambda _: self.o.wait_for_disconnection()) return d - def test_watch_messages(self): - # exercise GET /watch_messages (the EventSource version) - if sys.version_info[0] >= 3: - raise unittest.SkipTest("twisted vs py3") - - d = self.post("allocate", {"appid": "app1", "side": "abc"}) - def _allocated(data): - self.cid = data["channelid"] - url = self.build_url("watch_messages", "app1", self.cid) - self.o = OneEventAtATime(url, parser=json.loads) - return self.o.wait_for_connection() - d.addCallback(_allocated) - d.addCallback(lambda _: self.o.wait_for_next_event()) - def _check_welcome(ev): - eventtype, data = ev - self.failUnlessEqual(eventtype, "welcome") - self.failUnlessEqual(data, {"current_version": __version__}) - d.addCallback(_check_welcome) - d.addCallback(lambda _: self.add_message("msg1A")) - d.addCallback(lambda _: self.o.wait_for_next_event()) - def _check_msg1(ev): - eventtype, data = ev - self.failUnlessEqual(eventtype, "message") - self.failUnlessEqual(data, [{"phase": "1", "body": "msg1A"}]) - d.addCallback(_check_msg1) - - d.addCallback(lambda _: self.add_message("msg1B")) - d.addCallback(lambda _: self.add_message("msg2A", phase="2")) - d.addCallback(lambda _: self.o.wait_for_next_event()) - def _check_msg2(ev): - eventtype, data = ev - self.failUnlessEqual(eventtype, "message") - self.failUnlessEqual(data, [{"phase": "1", "body": "msg1B"}]) - d.addCallback(_check_msg2) - d.addCallback(lambda _: self.o.wait_for_next_event()) - def _check_msg3(ev): - eventtype, data = ev - self.failUnlessEqual(eventtype, "message") - self.failUnlessEqual(data, [{"phase": "2", "body": "msg2A"}]) - d.addCallback(_check_msg3) - - d.addCallback(lambda _: self.add_messages([("2", "msg2B"), - ("3", "msg3A")])) - d.addCallback(lambda _: self.o.wait_for_next_event()) - def _check_msg4(ev): - eventtype, data = ev - self.failUnlessEqual(eventtype, "message") - self.failUnlessEqual(data, [{"phase": "2", "body": "msg2B"}, - {"phase": "3", "body": "msg3A"}, - ]) - d.addCallback(_check_msg4) - - d.addCallback(lambda _: self.o.close()) - d.addCallback(lambda _: self.o.wait_for_disconnection()) - return d - class OneEventAtATime: def __init__(self, url, parser=lambda e: e): self.parser = parser diff --git a/src/wormhole/twisted/transcribe.py b/src/wormhole/twisted/transcribe.py index 4d8e902..6ea9e41 100644 --- a/src/wormhole/twisted/transcribe.py +++ b/src/wormhole/twisted/transcribe.py @@ -94,30 +94,18 @@ class Channel: return (phase, body) return None - def send(self, phase, body): - return self.send_many([(phase, body)]) - - def send_many(self, messages): + def send(self, phase, msg): # TODO: retry on failure, with exponential backoff. We're guarding # against the rendezvous server being temporarily offline. - payload_messages = [] - for (phase, body) in messages: - if not isinstance(phase, type(u"")): raise TypeError(type(phase)) - if not isinstance(body, type(b"")): raise TypeError(type(body)) - self._sent_messages.add( (phase,body) ) - payload_messages.append({"phase": phase, - "body": hexlify(body).decode("ascii")}) + if not isinstance(phase, type(u"")): raise TypeError(type(phase)) + if not isinstance(msg, type(b"")): raise TypeError(type(msg)) + self._sent_messages.add( (phase,msg) ) payload = {"appid": self._appid, "channelid": self._channelid, "side": self._side, - "messages": payload_messages, - } - d = post_json(self._agent, self._relay_url+"add_messages", payload) - def _maybe_handle_welcome(resp): - if "welcome" in resp: - self._handle_welcome(resp["welcome"]) - return resp - d.addCallback(_maybe_handle_welcome) + "phase": phase, + "body": hexlify(msg).decode("ascii")} + d = post_json(self._agent, self._relay_url+"add", payload) d.addCallback(lambda resp: self._add_inbound_messages(resp["messages"])) return d @@ -139,7 +127,7 @@ class Channel: if name == "welcome": self._handle_welcome(json.loads(data)) if name == "message": - self._add_inbound_messages(json.loads(data)) + self._add_inbound_messages([json.loads(data)]) phase_and_body = self._find_inbound_message(phases) if phase_and_body is not None and not msgs: msgs.append(phase_and_body) @@ -147,8 +135,8 @@ class Channel: # TODO: use agent=self._agent queryargs = urlencode([("appid", self._appid), ("channelid", self._channelid)]) - url = self._relay_url + "watch_messages?%s" % queryargs - es = ReconnectingEventSource(url, _handle) + es = ReconnectingEventSource(self._relay_url+"get?%s" % queryargs, + _handle) es.startService() # TODO: .setServiceParent(self) es.activate() d.addCallback(lambda _: es.deactivate())