diff --git a/src/wormhole/server/rendezvous_websocket.py b/src/wormhole/server/rendezvous_websocket.py index 32519b0..cac4f53 100644 --- a/src/wormhole/server/rendezvous_websocket.py +++ b/src/wormhole/server/rendezvous_websocket.py @@ -62,6 +62,7 @@ from ..util import dict_to_bytes, bytes_to_dict # -> {type: "claim", nameplate: str} -> mailbox # <- {type: "claimed", mailbox: str} # -> {type: "release"} +# .nameplate is optional, but must match previous claim() # <- {type: "released"} # # -> {type: "open", mailbox: str} -> message @@ -89,7 +90,9 @@ class WebSocketRendezvous(websocket.WebSocketServerProtocol): self._side = None self._did_allocate = False # only one allocate() per websocket self._listening = False + self._did_claim = False self._nameplate_id = None + self._did_release = False self._mailbox = None def onConnect(self, request): @@ -125,7 +128,7 @@ class WebSocketRendezvous(websocket.WebSocketServerProtocol): if mtype == "claim": return self.handle_claim(msg, server_rx) if mtype == "release": - return self.handle_release(server_rx) + return self.handle_release(msg, server_rx) if mtype == "open": return self.handle_open(msg, server_rx) @@ -172,6 +175,9 @@ class WebSocketRendezvous(websocket.WebSocketServerProtocol): def handle_claim(self, msg, server_rx): if "nameplate" not in msg: raise Error("claim requires 'nameplate'") + if self._did_claim: + raise Error("only one claim per connection") + self._did_claim = True nameplate_id = msg["nameplate"] assert isinstance(nameplate_id, type("")), type(nameplate_id) self._nameplate_id = nameplate_id @@ -182,11 +188,21 @@ class WebSocketRendezvous(websocket.WebSocketServerProtocol): raise Error("crowded") self.send("claimed", mailbox=mailbox_id) - def handle_release(self, server_rx): - if not self._nameplate_id: - raise Error("must claim a nameplate before releasing it") - self._app.release_nameplate(self._nameplate_id, self._side, server_rx) - self._nameplate_id = None + def handle_release(self, msg, server_rx): + if self._did_release: + raise Error("only one release per connection") + if "nameplate" in msg: + if self._nameplate_id is not None: + if msg["nameplate"] != self._nameplate_id: + raise Error("release and claim must use same nameplate") + nameplate_id = msg["nameplate"] + else: + if self._nameplate_id is None: + raise Error("release without nameplate must follow claim") + nameplate_id = self._nameplate_id + assert nameplate_id is not None + self._did_release = True + self._app.release_nameplate(nameplate_id, self._side, server_rx) self.send("released") diff --git a/src/wormhole/test/test_server.py b/src/wormhole/test/test_server.py index d2aaa83..40b34f7 100644 --- a/src/wormhole/test/test_server.py +++ b/src/wormhole/test/test_server.py @@ -754,6 +754,11 @@ class WebSocketAPI(_Util, ServerBase, unittest.TestCase): mailbox_id = m["mailbox"] self.assertEqual(type(mailbox_id), type("")) + c1.send("claim", nameplate="np1") + err = yield c1.next_non_ack() + self.assertEqual(err["type"], "error", err) + self.assertEqual(err["error"], "only one claim per connection") + nids = app.get_nameplate_ids() self.assertEqual(len(nids), 1) self.assertEqual("np1", list(nids)[0]) @@ -796,14 +801,14 @@ class WebSocketAPI(_Util, ServerBase, unittest.TestCase): err = yield c1.next_non_ack() self.assertEqual(err["type"], "error") self.assertEqual(err["error"], - "must claim a nameplate before releasing it") + "release without nameplate must follow claim") c1.send("claim", nameplate="np1") yield c1.next_non_ack() c1.send("release") m = yield c1.next_non_ack() - self.assertEqual(m["type"], "released") + self.assertEqual(m["type"], "released", m) np_row, side_rows = self._nameplate(app, "np1") claims = [(row["side"], row["claimed"]) for row in side_rows] @@ -813,8 +818,45 @@ class WebSocketAPI(_Util, ServerBase, unittest.TestCase): c1.send("release") # no longer claimed err = yield c1.next_non_ack() self.assertEqual(err["type"], "error") + self.assertEqual(err["error"], "only one release per connection") + + @inlineCallbacks + def test_release_named(self): + c1 = yield self.make_client() + yield c1.next_non_ack() + c1.send("bind", appid="appid", side="side") + + c1.send("claim", nameplate="np1") + yield c1.next_non_ack() + + c1.send("release", nameplate="np1") + m = yield c1.next_non_ack() + self.assertEqual(m["type"], "released", m) + + @inlineCallbacks + def test_release_named_ignored(self): + c1 = yield self.make_client() + yield c1.next_non_ack() + c1.send("bind", appid="appid", side="side") + + c1.send("release", nameplate="np1") # didn't do claim first, ignored + m = yield c1.next_non_ack() + self.assertEqual(m["type"], "released", m) + + @inlineCallbacks + def test_release_named_mismatch(self): + c1 = yield self.make_client() + yield c1.next_non_ack() + c1.send("bind", appid="appid", side="side") + + c1.send("claim", nameplate="np1") + yield c1.next_non_ack() + + c1.send("release", nameplate="np2") # mismatching nameplate + err = yield c1.next_non_ack() + self.assertEqual(err["type"], "error") self.assertEqual(err["error"], - "must claim a nameplate before releasing it") + "release and claim must use same nameplate") @inlineCallbacks def test_open(self): @@ -934,6 +976,64 @@ class WebSocketAPI(_Util, ServerBase, unittest.TestCase): yield d self.assertFalse(mb1.has_listeners()) + @inlineCallbacks + def test_interrupted_client_nameplate(self): + # a client's interactions with the server might be split over + # multiple sequential WebSocket connections, e.g. when the server is + # bounced and the client reconnects, or vice versa + c = yield self.make_client() + yield c.next_non_ack() + c.send("bind", appid="appid", side="side") + app = self._rendezvous.get_app("appid") + + c.send("claim", nameplate="np1") + m = yield c.next_non_ack() + self.assertEqual(m["type"], "claimed") + mailbox_id = m["mailbox"] + self.assertEqual(type(mailbox_id), type("")) + np_row, side_rows = self._nameplate(app, "np1") + claims = [(row["side"], row["claimed"]) for row in side_rows] + self.assertEqual(claims, [("side", True)]) + c.close() + yield c.d + + c = yield self.make_client() + yield c.next_non_ack() + c.send("bind", appid="appid", side="side") + c.send("claim", nameplate="np1") # idempotent + m = yield c.next_non_ack() + self.assertEqual(m["type"], "claimed") + self.assertEqual(m["mailbox"], mailbox_id) # mailbox id is stable + np_row, side_rows = self._nameplate(app, "np1") + claims = [(row["side"], row["claimed"]) for row in side_rows] + self.assertEqual(claims, [("side", True)]) + c.close() + yield c.d + + c = yield self.make_client() + yield c.next_non_ack() + c.send("bind", appid="appid", side="side") + # we haven't done a claim with this particular connection, but we can + # still send a release as long as we include the nameplate + c.send("release", nameplate="np1") # release-without-claim + m = yield c.next_non_ack() + self.assertEqual(m["type"], "released") + np_row, side_rows = self._nameplate(app, "np1") + self.assertEqual(np_row, None) + c.close() + yield c.d + + c = yield self.make_client() + yield c.next_non_ack() + c.send("bind", appid="appid", side="side") + # and the release is idempotent, when done on separate connections + c.send("release", nameplate="np1") + m = yield c.next_non_ack() + self.assertEqual(m["type"], "released") + np_row, side_rows = self._nameplate(app, "np1") + self.assertEqual(np_row, None) + c.close() + yield c.d class Summary(unittest.TestCase): def test_mailbox(self):