From 28b3f685d6c7db3a3f2857ca5a2b341c2160320c Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Wed, 20 Apr 2016 01:51:03 -0700 Subject: [PATCH] rendezvous: allow both allocate and claim (of the same channelid) This allows the Wormhole setup path to be simpler: consistently doing a claim() just before watch(), regardless of whether we allocated the channelid (with get_code), or dictated it (with set_code or from_serialized). --- src/wormhole_server/rendezvous.py | 3 ++ src/wormhole_server/rendezvous_websocket.py | 7 ++-- tests/test_server.py | 37 +++++++++++++++++++++ 3 files changed, 45 insertions(+), 2 deletions(-) diff --git a/src/wormhole_server/rendezvous.py b/src/wormhole_server/rendezvous.py index 73be8dd..23d5750 100644 --- a/src/wormhole_server/rendezvous.py +++ b/src/wormhole_server/rendezvous.py @@ -28,6 +28,9 @@ class Channel: # takes a JSONable object) and # .stop_rendezvous_watcher() + def get_channelid(self): + return self._channelid + def get_messages(self): messages = [] db = self._db diff --git a/src/wormhole_server/rendezvous_websocket.py b/src/wormhole_server/rendezvous_websocket.py index 53ec2fa..a4876f4 100644 --- a/src/wormhole_server/rendezvous_websocket.py +++ b/src/wormhole_server/rendezvous_websocket.py @@ -127,10 +127,13 @@ class WebSocketRendezvous(websocket.WebSocketServerProtocol): self.send("allocated", channelid=channelid) def handle_claim(self, msg): - if self._channel: - raise Error("Already bound to a channelid") if "channelid" not in msg: raise Error("claim requires 'channelid'") + # we allow allocate+claim as long as they match + if self._channel is not None: + old_cid = self._channel.get_channelid() + if msg["channelid"] != old_cid: + raise Error("Already bound to channelid %d" % old_cid) self._channel = self._app.allocate_channel(msg["channelid"], self._side) def handle_watch(self, channel, msg): diff --git a/tests/test_server.py b/tests/test_server.py index cb2ae85..76e4128 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -373,6 +373,7 @@ class WSClient(websocket.WebSocketClientProtocol): def __init__(self): websocket.WebSocketClientProtocol.__init__(self) self.events = [] + self.errors = [] self.d = None self.ping_counter = itertools.count(0) def onOpen(self): @@ -380,6 +381,8 @@ class WSClient(websocket.WebSocketClientProtocol): def onMessage(self, payload, isBinary): assert not isBinary event = json.loads(payload.decode("utf-8")) + if event["type"] == "error": + self.errors.append(event) if self.d: assert not self.events d,self.d = self.d,None @@ -608,6 +611,40 @@ class WebSocketAPI(ServerBase, unittest.TestCase): self.assertEqual(msg["type"], u"channelids") self.assertEqual(msg["channelids"], []) + @inlineCallbacks + def test_allocate_and_claim(self): + c1 = yield self.make_client() + msg = yield c1.next_event() + self.check_welcome(msg) + c1.send(u"bind", appid=u"appid", side=u"side") + c1.send(u"allocate") + msg = yield c1.next_event() + self.assertEqual(msg["type"], u"allocated") + cid = msg["channelid"] + c1.send(u"claim", channelid=cid) + yield c1.sync() + # there should no error + self.assertEqual(c1.errors, []) + + @inlineCallbacks + def test_allocate_and_claim_different(self): + c1 = yield self.make_client() + msg = yield c1.next_event() + self.check_welcome(msg) + c1.send(u"bind", appid=u"appid", side=u"side") + c1.send(u"allocate") + msg = yield c1.next_event() + self.assertEqual(msg["type"], u"allocated") + cid = msg["channelid"] + c1.send(u"claim", channelid=cid+1) + yield c1.sync() + # that should signal an error + self.assertEqual(len(c1.errors), 1, c1.errors) + msg = c1.errors[0] + self.assertEqual(msg["type"], "error") + self.assertEqual(msg["error"], "Already bound to channelid %d" % cid) + self.assertEqual(msg["orig"], {"type": "claim", "channelid": cid+1}) + @inlineCallbacks def test_message(self): c1 = yield self.make_client()