diff --git a/src/wormhole/server/rendezvous.py b/src/wormhole/server/rendezvous.py index 90364ed..c94649a 100644 --- a/src/wormhole/server/rendezvous.py +++ b/src/wormhole/server/rendezvous.py @@ -40,7 +40,8 @@ class Channel: (self._appid, self._channelid)).fetchall(): if row["phase"] in (u"_allocate", u"_deallocate"): continue - messages.append({"phase": row["phase"], "body": row["body"]}) + messages.append({"phase": row["phase"], "body": row["body"], + "server_rx": row["server_rx"]}) return messages def add_listener(self, ep): @@ -50,9 +51,10 @@ class Channel: def remove_listener(self, ep): self._listeners.discard(ep) - def broadcast_message(self, phase, body): + def broadcast_message(self, phase, body, server_rx): for ep in self._listeners: - ep.send_rendezvous_event({"phase": phase, "body": body}) + ep.send_rendezvous_event({"phase": phase, "body": body, + "server_rx": server_rx}) def _add_message(self, side, phase, body, server_rx): db = self._db @@ -68,7 +70,7 @@ class Channel: def add_message(self, side, phase, body, server_rx): self._add_message(side, phase, body, server_rx) - self.broadcast_message(phase, body) + self.broadcast_message(phase, body, server_rx) return self.get_messages() # for rendezvous_web.py POST /add def deallocate(self, side, mood): diff --git a/src/wormhole/test/test_server.py b/src/wormhole/test/test_server.py index d0a4d0d..71a1239 100644 --- a/src/wormhole/test/test_server.py +++ b/src/wormhole/test/test_server.py @@ -55,6 +55,14 @@ class Reachable(ServerBase, unittest.TestCase): def unjson(data): return json.loads(data.decode("utf-8")) +def strip_message(msg): + m2 = msg.copy() + m2.pop("server_rx", None) + return m2 + +def strip_messages(messages): + return [strip_message(m) for m in messages] + class WebAPI(ServerBase, unittest.TestCase): def build_url(self, path, appid, channelid): url = self.relayurl+path @@ -230,7 +238,7 @@ class WebAPI(ServerBase, unittest.TestCase): d.addCallback(lambda _: self.add_message("msg1A")) def _check1(data): self.check_welcome(data) - self.failUnlessEqual(data["messages"], + self.failUnlessEqual(strip_messages(data["messages"]), [{"phase": "1", "body": "msg1A"}]) d.addCallback(_check1) d.addCallback(lambda _: self.get("get", "app1", str(self.cid))) @@ -238,7 +246,7 @@ class WebAPI(ServerBase, unittest.TestCase): d.addCallback(lambda _: self.add_message("msg1B", side="def")) def _check2(data): self.check_welcome(data) - self.failUnlessEqual(self.parse_messages(data["messages"]), + self.failUnlessEqual(self.parse_messages(strip_messages(data["messages"])), set([("1", "msg1A"), ("1", "msg1B")])) d.addCallback(_check2) @@ -249,7 +257,7 @@ class WebAPI(ServerBase, unittest.TestCase): d.addCallback(lambda _: self.add_message("msg1B", side="def")) def _check3(data): self.check_welcome(data) - self.failUnlessEqual(self.parse_messages(data["messages"]), + self.failUnlessEqual(self.parse_messages(strip_messages(data["messages"])), set([("1", "msg1A"), ("1", "msg1B")])) d.addCallback(_check3) @@ -260,7 +268,7 @@ class WebAPI(ServerBase, unittest.TestCase): phase="2")) def _check4(data): self.check_welcome(data) - self.failUnlessEqual(self.parse_messages(data["messages"]), + self.failUnlessEqual(self.parse_messages(strip_messages(data["messages"])), set([("1", "msg1A"), ("1", "msg1B"), ("2", "msg2A"), @@ -300,7 +308,8 @@ class WebAPI(ServerBase, unittest.TestCase): eventtype, data = ev self.failUnlessEqual(eventtype, "message") data.pop("sent", None) - self.failUnlessEqual(data, {"phase": "1", "body": "msg1A"}) + self.failUnlessEqual(strip_message(data), + {"phase": "1", "body": "msg1A"}) d.addCallback(_check_msg1) d.addCallback(lambda _: self.add_message("msg1B")) @@ -310,14 +319,16 @@ class WebAPI(ServerBase, unittest.TestCase): eventtype, data = ev self.failUnlessEqual(eventtype, "message") data.pop("sent", None) - self.failUnlessEqual(data, {"phase": "1", "body": "msg1B"}) + self.failUnlessEqual(strip_message(data), + {"phase": "1", "body": "msg1B"}) d.addCallback(_check_msg2) d.addCallback(lambda _: self.o.wait_for_next_event()) def _check_msg3(ev): eventtype, data = ev self.failUnlessEqual(eventtype, "message") data.pop("sent", None) - self.failUnlessEqual(data, {"phase": "2", "body": "msg2A"}) + self.failUnlessEqual(strip_message(data), + {"phase": "2", "body": "msg2A"}) d.addCallback(_check_msg3) d.addCallback(lambda _: self.o.close()) @@ -584,7 +595,8 @@ class WebSocketAPI(ServerBase, unittest.TestCase): yield c2.sync() self.assertEqual(app.get_allocated(), set([cid])) - self.assertEqual(channel.get_messages(), [{"phase": "1", "body": ""}]) + self.assertEqual(strip_messages(channel.get_messages()), + [{"phase": "1", "body": ""}]) c1.send(u"list") msg = yield c1.next_event() @@ -666,12 +678,13 @@ class WebSocketAPI(ServerBase, unittest.TestCase): c1.send(u"add", phase="1", body="msg1A") yield c1.sync() - self.assertEqual(channel.get_messages(), + self.assertEqual(strip_messages(channel.get_messages()), [{"phase": "1", "body": "msg1A"}]) self.assertEqual(len(c1.events), 1) # echo should be sent right away msg = yield c1.next_event() self.assertEqual(msg["type"], "message") - self.assertEqual(msg["message"], {"phase": "1", "body": "msg1A"}) + self.assertEqual(strip_message(msg["message"]), + {"phase": "1", "body": "msg1A"}) self.assertIn("sent", msg) self.assertIsInstance(msg["sent"], float) @@ -680,13 +693,15 @@ class WebSocketAPI(ServerBase, unittest.TestCase): msg = yield c1.next_event() self.assertEqual(msg["type"], "message") - self.assertEqual(msg["message"], {"phase": "1", "body": "msg1B"}) + self.assertEqual(strip_message(msg["message"]), + {"phase": "1", "body": "msg1B"}) msg = yield c1.next_event() self.assertEqual(msg["type"], "message") - self.assertEqual(msg["message"], {"phase": "2", "body": "msg2A"}) + self.assertEqual(strip_message(msg["message"]), + {"phase": "2", "body": "msg2A"}) - self.assertEqual(channel.get_messages(), [ + self.assertEqual(strip_messages(channel.get_messages()), [ {"phase": "1", "body": "msg1A"}, {"phase": "1", "body": "msg1B"}, {"phase": "2", "body": "msg2A"}, @@ -703,15 +718,18 @@ class WebSocketAPI(ServerBase, unittest.TestCase): msg = yield c2.next_event() self.assertEqual(msg["type"], "message") - self.assertEqual(msg["message"], {"phase": "1", "body": "msg1A"}) + self.assertEqual(strip_message(msg["message"]), + {"phase": "1", "body": "msg1A"}) msg = yield c2.next_event() self.assertEqual(msg["type"], "message") - self.assertEqual(msg["message"], {"phase": "1", "body": "msg1B"}) + self.assertEqual(strip_message(msg["message"]), + {"phase": "1", "body": "msg1B"}) msg = yield c2.next_event() self.assertEqual(msg["type"], "message") - self.assertEqual(msg["message"], {"phase": "2", "body": "msg2A"}) + self.assertEqual(strip_message(msg["message"]), + {"phase": "2", "body": "msg2A"}) # adding a duplicate is not an error, and clients will ignore it c1.send(u"add", phase="2", body="msg2A") @@ -719,7 +737,8 @@ class WebSocketAPI(ServerBase, unittest.TestCase): # the duplicate message *does* get stored, and delivered msg = yield c2.next_event() self.assertEqual(msg["type"], "message") - self.assertEqual(msg["message"], {"phase": "2", "body": "msg2A"}) + self.assertEqual(strip_message(msg["message"]), + {"phase": "2", "body": "msg2A"}) class Summary(unittest.TestCase):