diff --git a/src/wormhole/server/rendezvous_websocket.py b/src/wormhole/server/rendezvous_websocket.py index cac4f53..1349849 100644 --- a/src/wormhole/server/rendezvous_websocket.py +++ b/src/wormhole/server/rendezvous_websocket.py @@ -71,6 +71,7 @@ from ..util import dict_to_bytes, bytes_to_dict # -> {type: "add", phase: str, body: hex} # will send echo in a "message" # # -> {type: "close", mood: str} -> closed +# .mailbox is optional, but must match previous open() # <- {type: "closed"} # # <- {type: "error", error: str, orig: {}} # in response to malformed msgs @@ -93,7 +94,10 @@ class WebSocketRendezvous(websocket.WebSocketServerProtocol): self._did_claim = False self._nameplate_id = None self._did_release = False + self._did_open = False self._mailbox = None + self._mailbox_id = None + self._did_close = False def onConnect(self, request): rv = self.factory.rendezvous @@ -208,11 +212,12 @@ class WebSocketRendezvous(websocket.WebSocketServerProtocol): def handle_open(self, msg, server_rx): if self._mailbox: - raise Error("you already have a mailbox open") + raise Error("only one open per connection") if "mailbox" not in msg: raise Error("open requires 'mailbox'") mailbox_id = msg["mailbox"] assert isinstance(mailbox_id, type("")) + self._mailbox_id = mailbox_id self._mailbox = self._app.open_mailbox(mailbox_id, self._side, server_rx) def _send(sm): @@ -238,11 +243,24 @@ class WebSocketRendezvous(websocket.WebSocketServerProtocol): self._mailbox.add_message(sm) def handle_close(self, msg, server_rx): + if self._did_close: + raise Error("only one close per connection") + if "mailbox" in msg: + if self._mailbox_id is not None: + if msg["mailbox"] != self._mailbox_id: + raise Error("open and close must use same mailbox") + mailbox_id = msg["mailbox"] + else: + if self._mailbox_id is None: + raise Error("close without mailbox must follow open") + mailbox_id = self._mailbox_id if not self._mailbox: - raise Error("must open mailbox before closing") + self._mailbox = self._app.open_mailbox(mailbox_id, self._side, + server_rx) if self._listening: self._mailbox.remove_listener(self) self._listening = False + self._did_close = True self._mailbox.close(self._side, msg.get("mood"), server_rx) self._mailbox = None self.send("closed") diff --git a/src/wormhole/test/test_server.py b/src/wormhole/test/test_server.py index 40b34f7..3fc1274 100644 --- a/src/wormhole/test/test_server.py +++ b/src/wormhole/test/test_server.py @@ -891,7 +891,7 @@ class WebSocketAPI(_Util, ServerBase, unittest.TestCase): c1.send("open", mailbox="mb1") err = yield c1.next_non_ack() self.assertEqual(err["type"], "error") - self.assertEqual(err["error"], "you already have a mailbox open") + self.assertEqual(err["error"], "only one open per connection") @inlineCallbacks def test_add(self): @@ -938,7 +938,7 @@ class WebSocketAPI(_Util, ServerBase, unittest.TestCase): c1.send("close", mood="mood") # must open first err = yield c1.next_non_ack() self.assertEqual(err["type"], "error") - self.assertEqual(err["error"], "must open mailbox before closing") + self.assertEqual(err["error"], "close without mailbox must follow open") c1.send("open", mailbox="mb1") yield c1.sync() @@ -952,8 +952,46 @@ class WebSocketAPI(_Util, ServerBase, unittest.TestCase): c1.send("close", mood="mood") # already closed err = yield c1.next_non_ack() + self.assertEqual(err["type"], "error", m) + self.assertEqual(err["error"], "only one close per connection") + + @inlineCallbacks + def test_close_named(self): + c1 = yield self.make_client() + yield c1.next_non_ack() + c1.send("bind", appid="appid", side="side") + + c1.send("open", mailbox="mb1") + yield c1.sync() + + c1.send("close", mailbox="mb1", mood="mood") + m = yield c1.next_non_ack() + self.assertEqual(m["type"], "closed") + + @inlineCallbacks + def test_close_named_ignored(self): + c1 = yield self.make_client() + yield c1.next_non_ack() + c1.send("bind", appid="appid", side="side") + + c1.send("close", mailbox="mb1", mood="mood") # no open first, ignored + m = yield c1.next_non_ack() + self.assertEqual(m["type"], "closed") + + @inlineCallbacks + def test_close_named_mismatch(self): + c1 = yield self.make_client() + yield c1.next_non_ack() + c1.send("bind", appid="appid", side="side") + + c1.send("open", mailbox="mb1") + yield c1.sync() + + c1.send("close", mailbox="mb2", mood="mood") + err = yield c1.next_non_ack() self.assertEqual(err["type"], "error") - self.assertEqual(err["error"], "must open mailbox before closing") + self.assertEqual(err["error"], "open and close must use same mailbox") + @inlineCallbacks def test_disconnect(self): @@ -1035,6 +1073,64 @@ class WebSocketAPI(_Util, ServerBase, unittest.TestCase): c.close() yield c.d + @inlineCallbacks + def test_interrupted_client_mailbox(self): + # a client's interactions with the server might be split over + # multiple sequential WebSocket connections, e.g. when the server is + # bounced and the client reconnects, or vice versa + c = yield self.make_client() + yield c.next_non_ack() + c.send("bind", appid="appid", side="side") + app = self._rendezvous.get_app("appid") + mb1 = app.open_mailbox("mb1", "side2", 0) + mb1.add_message(SidedMessage(side="side2", phase="phase", + body="body", server_rx=0, + msg_id="msgid")) + + c.send("open", mailbox="mb1") + m = yield c.next_non_ack() + self.assertEqual(m["type"], "message") + self.assertEqual(m["body"], "body") + self.assertTrue(mb1.has_listeners()) + c.close() + yield c.d + + c = yield self.make_client() + yield c.next_non_ack() + c.send("bind", appid="appid", side="side") + # open should be idempotent + c.send("open", mailbox="mb1") + m = yield c.next_non_ack() + self.assertEqual(m["type"], "message") + self.assertEqual(m["body"], "body") + mb_row, side_rows = self._mailbox(app, "mb1") + openeds = [(row["side"], row["opened"]) for row in side_rows] + self.assertIn(("side", 1), openeds) # TODO: why 1, and not True? + + # close on the same connection as open is ok + c.send("close", mailbox="mb1", mood="mood") + m = yield c.next_non_ack() + self.assertEqual(m["type"], "closed", m) + mb_row, side_rows = self._mailbox(app, "mb1") + openeds = [(row["side"], row["opened"]) for row in side_rows] + self.assertIn(("side", 0), openeds) + c.close() + yield c.d + + # close (on a separate connection) is idempotent + c = yield self.make_client() + yield c.next_non_ack() + c.send("bind", appid="appid", side="side") + c.send("close", mailbox="mb1", mood="mood") + m = yield c.next_non_ack() + self.assertEqual(m["type"], "closed", m) + mb_row, side_rows = self._mailbox(app, "mb1") + openeds = [(row["side"], row["opened"]) for row in side_rows] + self.assertIn(("side", 0), openeds) + c.close() + yield c.d + + class Summary(unittest.TestCase): def test_mailbox(self): app = rendezvous.AppNamespace(None, None, False, None)