From 5530c33185c3a62f222b390be01835bee9eab3e4 Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Thu, 5 May 2016 19:02:52 -0700 Subject: [PATCH] rdv_ws: send acks for each message but only if the client is modern enough to include "id" in the message, which lets us avoid sending acks to an 0.7.5 client (which would cause them to abort, they don't like unrecognized server messages). The acks let the client learn the server_rx time of messages that terminate on the server, like "allocate" and "claim". --- src/wormhole/server/rendezvous_websocket.py | 5 ++ src/wormhole/test/test_server.py | 82 ++++++++++++--------- src/wormhole/twisted/transcribe.py | 3 + 3 files changed, 55 insertions(+), 35 deletions(-) diff --git a/src/wormhole/server/rendezvous_websocket.py b/src/wormhole/server/rendezvous_websocket.py index 60e381a..93abf2c 100644 --- a/src/wormhole/server/rendezvous_websocket.py +++ b/src/wormhole/server/rendezvous_websocket.py @@ -71,6 +71,11 @@ class WebSocketRendezvous(websocket.WebSocketServerProtocol): try: if "type" not in msg: raise Error("missing 'type'") + if "id" in msg: + # Only ack clients modern enough to include [id]. Older ones + # won't recognize the message, then they'll abort. + self.send("ack", id=msg["id"]) + mtype = msg["type"] if mtype == "ping": return self.handle_ping(msg) diff --git a/src/wormhole/test/test_server.py b/src/wormhole/test/test_server.py index a33e7b8..d9010f1 100644 --- a/src/wormhole/test/test_server.py +++ b/src/wormhole/test/test_server.py @@ -410,6 +410,16 @@ class WSClient(websocket.WebSocketClientProtocol): self.d = defer.Deferred() return self.d + @inlineCallbacks + def next_non_ack(self): + while True: + m = yield self.next_event() + if m["type"] != "ack": + returnValue(m) + + def strip_acks(self): + self.events = [e for e in self.events if e["type"] != u"ack"] + def send(self, mtype, **kwargs): kwargs["type"] = mtype payload = json.dumps(kwargs).encode("utf-8") @@ -462,9 +472,9 @@ class WSClientSync(unittest.TestCase): d.addBoth(sunc.append) add("pong", pong=1) yield d - m = yield c.next_event() + m = yield c.next_non_ack() self.assertEqual(m["type"], "one") - m = yield c.next_event() + m = yield c.next_non_ack() self.assertEqual(m["type"], "two") self.assertEqual(c.events, []) @@ -476,9 +486,9 @@ class WSClientSync(unittest.TestCase): add("two", two=2) add("pong", pong=2) yield d - m = yield c.next_event() + m = yield c.next_non_ack() self.assertEqual(m["type"], "one") - m = yield c.next_event() + m = yield c.next_non_ack() self.assertEqual(m["type"], "two") self.assertEqual(c.events, []) @@ -490,9 +500,9 @@ class WSClientSync(unittest.TestCase): add("two", two=2) add("pong", pong=3) yield d - m = yield c.next_event() + m = yield c.next_non_ack() self.assertEqual(m["type"], "one") - m = yield c.next_event() + m = yield c.next_non_ack() self.assertEqual(m["type"], "two") self.assertEqual(c.events, []) @@ -524,14 +534,14 @@ class WebSocketAPI(ServerBase, unittest.TestCase): @inlineCallbacks def test_welcome(self): c1 = yield self.make_client() - msg = yield c1.next_event() + msg = yield c1.next_non_ack() self.check_welcome(msg) self.assertEqual(self._rendezvous._apps, {}) @inlineCallbacks def test_allocate_1(self): c1 = yield self.make_client() - msg = yield c1.next_event() + msg = yield c1.next_non_ack() self.check_welcome(msg) c1.send(u"bind", appid=u"appid", side=u"side") yield c1.sync() @@ -539,12 +549,12 @@ class WebSocketAPI(ServerBase, unittest.TestCase): app = self._rendezvous.get_app(u"appid") self.assertEqual(app.get_allocated(), set()) c1.send(u"list") - msg = yield c1.next_event() + msg = yield c1.next_non_ack() self.assertEqual(msg["type"], u"channelids") self.assertEqual(msg["channelids"], []) c1.send(u"allocate") - msg = yield c1.next_event() + msg = yield c1.next_non_ack() self.assertEqual(msg["type"], u"allocated") cid = msg["channelid"] self.failUnlessIsInstance(cid, int) @@ -553,32 +563,32 @@ class WebSocketAPI(ServerBase, unittest.TestCase): self.assertEqual(channel.get_messages(), []) c1.send(u"list") - msg = yield c1.next_event() + msg = yield c1.next_non_ack() self.assertEqual(msg["type"], u"channelids") self.assertEqual(msg["channelids"], [cid]) c1.send(u"deallocate") - msg = yield c1.next_event() + msg = yield c1.next_non_ack() self.assertEqual(msg["type"], u"deallocated") self.assertEqual(msg["status"], u"deleted") self.assertEqual(app.get_allocated(), set()) c1.send(u"list") - msg = yield c1.next_event() + msg = yield c1.next_non_ack() self.assertEqual(msg["type"], u"channelids") self.assertEqual(msg["channelids"], []) @inlineCallbacks def test_allocate_2(self): c1 = yield self.make_client() - msg = yield c1.next_event() + msg = yield c1.next_non_ack() self.check_welcome(msg) c1.send(u"bind", appid=u"appid", side=u"side") yield c1.sync() app = self._rendezvous.get_app(u"appid") self.assertEqual(app.get_allocated(), set()) c1.send(u"allocate") - msg = yield c1.next_event() + msg = yield c1.next_non_ack() self.assertEqual(msg["type"], u"allocated") cid = msg["channelid"] self.failUnlessIsInstance(cid, int) @@ -588,7 +598,7 @@ class WebSocketAPI(ServerBase, unittest.TestCase): # second caller increases the number of known sides to 2 c2 = yield self.make_client() - msg = yield c2.next_event() + msg = yield c2.next_non_ack() self.check_welcome(msg) c2.send(u"bind", appid=u"appid", side=u"side-2") c2.send(u"claim", channelid=cid) @@ -600,38 +610,38 @@ class WebSocketAPI(ServerBase, unittest.TestCase): [{"phase": "1", "body": ""}]) c1.send(u"list") - msg = yield c1.next_event() + msg = yield c1.next_non_ack() self.assertEqual(msg["type"], u"channelids") self.assertEqual(msg["channelids"], [cid]) c2.send(u"list") - msg = yield c2.next_event() + msg = yield c2.next_non_ack() self.assertEqual(msg["type"], u"channelids") self.assertEqual(msg["channelids"], [cid]) c1.send(u"deallocate") - msg = yield c1.next_event() + msg = yield c1.next_non_ack() self.assertEqual(msg["type"], u"deallocated") self.assertEqual(msg["status"], u"waiting") c2.send(u"deallocate") - msg = yield c2.next_event() + msg = yield c2.next_non_ack() self.assertEqual(msg["type"], u"deallocated") self.assertEqual(msg["status"], u"deleted") c2.send(u"list") - msg = yield c2.next_event() + msg = yield c2.next_non_ack() self.assertEqual(msg["type"], u"channelids") self.assertEqual(msg["channelids"], []) @inlineCallbacks def test_allocate_and_claim(self): c1 = yield self.make_client() - msg = yield c1.next_event() + msg = yield c1.next_non_ack() self.check_welcome(msg) c1.send(u"bind", appid=u"appid", side=u"side") c1.send(u"allocate") - msg = yield c1.next_event() + msg = yield c1.next_non_ack() self.assertEqual(msg["type"], u"allocated") cid = msg["channelid"] c1.send(u"claim", channelid=cid) @@ -642,11 +652,11 @@ class WebSocketAPI(ServerBase, unittest.TestCase): @inlineCallbacks def test_allocate_and_claim_different(self): c1 = yield self.make_client() - msg = yield c1.next_event() + msg = yield c1.next_non_ack() self.check_welcome(msg) c1.send(u"bind", appid=u"appid", side=u"side") c1.send(u"allocate") - msg = yield c1.next_event() + msg = yield c1.next_non_ack() self.assertEqual(msg["type"], u"allocated") cid = msg["channelid"] c1.send(u"claim", channelid=cid+1) @@ -661,11 +671,11 @@ class WebSocketAPI(ServerBase, unittest.TestCase): @inlineCallbacks def test_message(self): c1 = yield self.make_client() - msg = yield c1.next_event() + msg = yield c1.next_non_ack() self.check_welcome(msg) c1.send(u"bind", appid=u"appid", side=u"side") c1.send(u"allocate") - msg = yield c1.next_event() + msg = yield c1.next_non_ack() self.assertEqual(msg["type"], u"allocated") cid = msg["channelid"] app = self._rendezvous.get_app(u"appid") @@ -675,14 +685,16 @@ class WebSocketAPI(ServerBase, unittest.TestCase): c1.send(u"watch") yield c1.sync() self.assertEqual(len(channel._listeners), 1) + c1.strip_acks() self.assertEqual(c1.events, []) c1.send(u"add", phase="1", body="msg1A") yield c1.sync() + c1.strip_acks() self.assertEqual(strip_messages(channel.get_messages()), [{"phase": "1", "body": "msg1A"}]) self.assertEqual(len(c1.events), 1) # echo should be sent right away - msg = yield c1.next_event() + msg = yield c1.next_non_ack() self.assertEqual(msg["type"], "message") self.assertEqual(strip_message(msg["message"]), {"phase": "1", "body": "msg1A"}) @@ -692,12 +704,12 @@ class WebSocketAPI(ServerBase, unittest.TestCase): c1.send(u"add", phase="1", body="msg1B") c1.send(u"add", phase="2", body="msg2A") - msg = yield c1.next_event() + msg = yield c1.next_non_ack() self.assertEqual(msg["type"], "message") self.assertEqual(strip_message(msg["message"]), {"phase": "1", "body": "msg1B"}) - msg = yield c1.next_event() + msg = yield c1.next_non_ack() self.assertEqual(msg["type"], "message") self.assertEqual(strip_message(msg["message"]), {"phase": "2", "body": "msg2A"}) @@ -710,24 +722,24 @@ class WebSocketAPI(ServerBase, unittest.TestCase): # second client should see everything c2 = yield self.make_client() - msg = yield c2.next_event() + msg = yield c2.next_non_ack() self.check_welcome(msg) c2.send(u"bind", appid=u"appid", side=u"side") c2.send(u"claim", channelid=cid) # 'watch' triggers delivery of old messages, in temporal order c2.send(u"watch") - msg = yield c2.next_event() + msg = yield c2.next_non_ack() self.assertEqual(msg["type"], "message") self.assertEqual(strip_message(msg["message"]), {"phase": "1", "body": "msg1A"}) - msg = yield c2.next_event() + msg = yield c2.next_non_ack() self.assertEqual(msg["type"], "message") self.assertEqual(strip_message(msg["message"]), {"phase": "1", "body": "msg1B"}) - msg = yield c2.next_event() + msg = yield c2.next_non_ack() self.assertEqual(msg["type"], "message") self.assertEqual(strip_message(msg["message"]), {"phase": "2", "body": "msg2A"}) @@ -736,7 +748,7 @@ class WebSocketAPI(ServerBase, unittest.TestCase): c1.send(u"add", phase="2", body="msg2A") # the duplicate message *does* get stored, and delivered - msg = yield c2.next_event() + msg = yield c2.next_non_ack() self.assertEqual(msg["type"], "message") self.assertEqual(strip_message(msg["message"]), {"phase": "2", "body": "msg2A"}) diff --git a/src/wormhole/twisted/transcribe.py b/src/wormhole/twisted/transcribe.py index d6a7c5d..678c142 100644 --- a/src/wormhole/twisted/transcribe.py +++ b/src/wormhole/twisted/transcribe.py @@ -139,6 +139,9 @@ class Wormhole: return return meth(msg) + def _ws_handle_ack(self, msg): + pass + def _ws_handle_welcome(self, msg): self._timing.add("welcome").server_sent(msg["server_tx"]) welcome = msg["welcome"]