diff --git a/src/wormhole/server/rendezvous_websocket.py b/src/wormhole/server/rendezvous_websocket.py index 40c0b50..9ec113d 100644 --- a/src/wormhole/server/rendezvous_websocket.py +++ b/src/wormhole/server/rendezvous_websocket.py @@ -3,26 +3,26 @@ from twisted.internet import reactor from twisted.python import log from autobahn.twisted import websocket -# Each WebSocket connection is bound to one "appid", one "side", and one -# "channelid". The connection's appid and side are set by the "bind" message -# (which must be the first message on the connection). The channelid is set -# by either a "allocate" message (where the server picks the channelid), or -# by a "claim" message (where the client picks it). All three values must be -# set before any other message (watch, add, deallocate) can be sent. Channels -# are maintained (saved from deletion) by a "claim" message (and also -# incidentally by "allocate"). Channels are deleted when the last claim is -# released with "release". +# Each WebSocket connection is bound to one "appid", one "side", and zero or +# more "channelids". The connection's appid and side are set by the "bind" +# message (which must be the first message on the connection). Both must be +# set before any other message (allocate, claim, watch, add, deallocate) will +# be accepted. Short channel IDs can be obtained from the server with an +# "allocate" message. Longer ones can be selected independently by the +# client. Channels are maintained (saved from deletion) by a "claim" message +# (and also incidentally by "allocate"). Channels are deleted when the last +# claim is released with "release". # All websocket messages are JSON-encoded. The client can send us "inbound" # messages (marked as "->" below), which may (or may not) provoke immediate # (or delayed) "outbound" messages (marked as "<-"). There is no guaranteed # correlation between requests and responses. In this list, "A -> B" means # that some time after A is received, at least one message of type B will be -# sent out. +# sent out (probably). # All outbound messages include a "server_tx" key, which is a float (seconds # since epoch) with the server clock just before the outbound message was -# written to the socket. +# written to the socket. Unrecognized keys will be ignored. # connection -> welcome # <- {type: "welcome", welcome: {}} # .welcome keys are all optional: @@ -30,17 +30,20 @@ from autobahn.twisted import websocket # motd: all clients display message, then continue normally # error: all clients display mesage, then terminate with error # -> {type: "bind", appid:, side:} +# # -> {type: "list"} -> channelids # <- {type: "channelids", channelids: [int..]} # -> {type: "allocate"} -> allocated # <- {type: "allocated", channelid: int} # -> {type: "claim", channelid: int} -# -> {type: "watch"} -> message # sends old messages and more in future -# <- {type: "message", message: {phase:, body:}} # body is hex -# -> {type: "add", phase: str, body: hex} # may send echo # -# -> {type: "release", mood: str} -> deallocated -# <- {type: "released", status: waiting|deleted} +# -> {type: "watch", channelid: int} -> message +# sends old messages and more in future +# <- {type: "message", channelid: int, message: {phase:, body:}} # body is hex +# -> {type: "add", channelid: int, phase: str, body: hex} # will send echo +# +# -> {type: "release", channelid: int, mood: str} -> deallocated +# <- {type: "released", channelid: int, status: waiting|deleted} # # <- {type: "error", error: str, orig: {}} # in response to malformed msgs @@ -57,8 +60,8 @@ class WebSocketRendezvous(websocket.WebSocketServerProtocol): websocket.WebSocketServerProtocol.__init__(self) self._app = None self._side = None - self._channel = None - self._watching = False + self._did_allocate = False # only one allocate() per websocket + self._channels = {} # channel-id -> Channel (claimed) def onConnect(self, request): rv = self.factory.rendezvous @@ -95,26 +98,17 @@ class WebSocketRendezvous(websocket.WebSocketServerProtocol): return self.handle_allocate() if mtype == "claim": return self.handle_claim(msg) - - if not self._channel: - raise Error("Must set channel first") if mtype == "watch": - return self.handle_watch(self._channel, msg) + return self.handle_watch(msg) if mtype == "add": - return self.handle_add(self._channel, msg, server_rx) + return self.handle_add(msg, server_rx) if mtype == "release": - return self.handle_release(self._channel, msg) + return self.handle_release(msg) raise Error("Unknown type") except Error as e: self.send("error", error=e._explain, orig=msg) - def send_rendezvous_event(self, event): - self.send("message", message=event) - - def stop_rendezvous_watcher(self): - self._reactor.callLater(0, self.transport.loseConnection) - def handle_ping(self, msg): if "ping" not in msg: raise Error("ping requires 'ping'") @@ -135,34 +129,39 @@ class WebSocketRendezvous(websocket.WebSocketServerProtocol): self.send("channelids", channelids=channelids) def handle_allocate(self): - if self._channel: - raise Error("Already bound to a channelid") + if self._did_allocate: + raise Error("You already allocated one channel, don't be greedy") channelid = self._app.find_available_channelid() - self._channel = self._app.claim_channel(channelid, self._side) + self._did_allocate = True + channel = self._app.claim_channel(channelid, self._side) + self._channels[channelid] = channel self.send("allocated", channelid=channelid) def handle_claim(self, msg): if "channelid" not in msg: raise Error("claim requires 'channelid'") - # we allow allocate+claim as long as they match - if self._channel is not None: - old_cid = self._channel.get_channelid() - if msg["channelid"] != old_cid: - raise Error("Already bound to channelid %d" % old_cid) - self._channel = self._app.claim_channel(msg["channelid"], self._side) + channelid = msg["channelid"] + if channelid not in self._channels: + channel = self._app.claim_channel(channelid, self._side) + self._channels[channelid] = channel - def handle_watch(self, channel, msg): - if self._watching: - raise Error("already watching") - self._watching = True + def handle_watch(self, msg): + channelid = msg["channelid"] + if channelid not in self._channels: + raise Error("must claim channel before watching") + channel = self._channels[channelid] def _send(event): - self.send_rendezvous_event(event) + self.send("message", channelid=channelid, message=event) def _stop(): - self.stop_rendezvous_watcher() + self._reactor.callLater(0, self.transport.loseConnection) for old_message in channel.add_listener(self, _send, _stop): _send(old_message) - def handle_add(self, channel, msg, server_rx): + def handle_add(self, msg, server_rx): + channelid = msg["channelid"] + if channelid not in self._channels: + raise Error("must claim channel before adding") + channel = self._channels[channelid] if "phase" not in msg: raise Error("missing 'phase'") if "body" not in msg: @@ -171,8 +170,13 @@ class WebSocketRendezvous(websocket.WebSocketServerProtocol): channel.add_message(self._side, msg["phase"], msg["body"], server_rx, msgid) - def handle_release(self, channel, msg): + def handle_release(self, msg): + channelid = msg["channelid"] + if channelid not in self._channels: + raise Error("must claim channel before releasing") + channel = self._channels[channelid] deleted = channel.release(self._side, msg.get("mood")) + del self._channels[channelid] self.send("released", status="deleted" if deleted else "waiting") def send(self, mtype, **kwargs): diff --git a/src/wormhole/test/test_server.py b/src/wormhole/test/test_server.py index 8708dd1..6deef6c 100644 --- a/src/wormhole/test/test_server.py +++ b/src/wormhole/test/test_server.py @@ -230,7 +230,7 @@ class WebSocketAPI(ServerBase, unittest.TestCase): self.assertEqual(msg["type"], u"channelids") self.assertEqual(msg["channelids"], [cid]) - c1.send(u"release") + c1.send(u"release", channelid=cid) msg = yield c1.next_non_ack() self.assertEqual(msg["type"], u"released") self.assertEqual(msg["status"], u"deleted") @@ -265,7 +265,7 @@ class WebSocketAPI(ServerBase, unittest.TestCase): self.check_welcome(msg) c2.send(u"bind", appid=u"appid", side=u"side-2") c2.send(u"claim", channelid=cid) - c2.send(u"add", phase="1", body="") + c2.send(u"add", channelid=cid, phase="1", body="") yield c2.sync() self.assertEqual(app.get_claimed(), set([cid])) @@ -282,12 +282,12 @@ class WebSocketAPI(ServerBase, unittest.TestCase): self.assertEqual(msg["type"], u"channelids") self.assertEqual(msg["channelids"], [cid]) - c1.send(u"release") + c1.send(u"release", channelid=cid) msg = yield c1.next_non_ack() self.assertEqual(msg["type"], u"released") self.assertEqual(msg["status"], u"waiting") - c2.send(u"release") + c2.send(u"release", channelid=cid) msg = yield c2.next_non_ack() self.assertEqual(msg["type"], u"released") self.assertEqual(msg["status"], u"deleted") @@ -312,25 +312,6 @@ class WebSocketAPI(ServerBase, unittest.TestCase): # there should no error self.assertEqual(c1.errors, []) - @inlineCallbacks - def test_allocate_and_claim_different(self): - c1 = yield self.make_client() - msg = yield c1.next_non_ack() - self.check_welcome(msg) - c1.send(u"bind", appid=u"appid", side=u"side") - c1.send(u"allocate") - msg = yield c1.next_non_ack() - self.assertEqual(msg["type"], u"allocated") - cid = msg["channelid"] - c1.send(u"claim", channelid=cid+1) - yield c1.sync() - # that should signal an error - self.assertEqual(len(c1.errors), 1, c1.errors) - msg = c1.errors[0] - self.assertEqual(msg["type"], "error") - self.assertEqual(msg["error"], "Already bound to channelid %d" % cid) - self.assertEqual(msg["orig"], {"type": "claim", "channelid": cid+1}) - @inlineCallbacks def test_message(self): c1 = yield self.make_client() @@ -345,13 +326,13 @@ class WebSocketAPI(ServerBase, unittest.TestCase): channel = app.get_channel(cid) self.assertEqual(channel.get_messages(), []) - c1.send(u"watch") + c1.send(u"watch", channelid=cid) yield c1.sync() self.assertEqual(len(channel._listeners), 1) c1.strip_acks() self.assertEqual(c1.events, []) - c1.send(u"add", phase="1", body="msg1A") + c1.send(u"add", channelid=cid, phase="1", body="msg1A") yield c1.sync() c1.strip_acks() self.assertEqual(strip_messages(channel.get_messages()), @@ -364,8 +345,8 @@ class WebSocketAPI(ServerBase, unittest.TestCase): self.assertIn("server_tx", msg) self.assertIsInstance(msg["server_tx"], float) - c1.send(u"add", phase="1", body="msg1B") - c1.send(u"add", phase="2", body="msg2A") + c1.send(u"add", channelid=cid, phase="1", body="msg1B") + c1.send(u"add", channelid=cid, phase="2", body="msg2A") msg = yield c1.next_non_ack() self.assertEqual(msg["type"], "message") @@ -390,7 +371,7 @@ class WebSocketAPI(ServerBase, unittest.TestCase): c2.send(u"bind", appid=u"appid", side=u"side") c2.send(u"claim", channelid=cid) # 'watch' triggers delivery of old messages, in temporal order - c2.send(u"watch") + c2.send(u"watch", channelid=cid) msg = yield c2.next_non_ack() self.assertEqual(msg["type"], "message") @@ -408,7 +389,7 @@ class WebSocketAPI(ServerBase, unittest.TestCase): {"phase": "2", "body": "msg2A"}) # adding a duplicate is not an error, and clients will ignore it - c1.send(u"add", phase="2", body="msg2A") + c1.send(u"add", channelid=cid, phase="2", body="msg2A") # the duplicate message *does* get stored, and delivered msg = yield c2.next_non_ack() diff --git a/src/wormhole/twisted/transcribe.py b/src/wormhole/twisted/transcribe.py index 64d83c3..5f362a3 100644 --- a/src/wormhole/twisted/transcribe.py +++ b/src/wormhole/twisted/transcribe.py @@ -211,7 +211,7 @@ class _Wormhole: if not self._ws_channel_claimed: yield self._ws_send(u"claim", channelid=self._channelid) self._ws_channel_claimed = True - yield self._ws_send(u"watch") + yield self._ws_send(u"watch", channelid=self._channelid) # entry point 1: generate a new code @inlineCallbacks @@ -406,7 +406,7 @@ class _Wormhole: # TODO: retry on failure, with exponential backoff. We're guarding # against the rendezvous server being temporarily offline. t = self._timing.add("add", phase=phase, wait=wait) - yield self._ws_send(u"add", phase=phase, + yield self._ws_send(u"add", channelid=self._channelid, phase=phase, body=hexlify(body).decode("ascii")) if wait: while phase not in self._delivered_messages: @@ -546,7 +546,8 @@ class _Wormhole: @inlineCallbacks def _release(self, mood): with self._timing.add("release"): - yield self._ws_send(u"release", mood=mood) + yield self._ws_send(u"release", channelid=self._channelid, + mood=mood) while self._released_status is None: yield self._sleep(wake_on_error=False) # TODO: set a timeout, don't wait forever for an ack