rendezvous: allow both allocate and claim (of the same channelid)
This allows the Wormhole setup path to be simpler: consistently doing a claim() just before watch(), regardless of whether we allocated the channelid (with get_code), or dictated it (with set_code or from_serialized).
This commit is contained in:
		
							parent
							
								
									558f01818f
								
							
						
					
					
						commit
						28b3f685d6
					
				| 
						 | 
					@ -28,6 +28,9 @@ class Channel:
 | 
				
			||||||
                                # takes a JSONable object) and
 | 
					                                # takes a JSONable object) and
 | 
				
			||||||
                                # .stop_rendezvous_watcher()
 | 
					                                # .stop_rendezvous_watcher()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def get_channelid(self):
 | 
				
			||||||
 | 
					        return self._channelid
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def get_messages(self):
 | 
					    def get_messages(self):
 | 
				
			||||||
        messages = []
 | 
					        messages = []
 | 
				
			||||||
        db = self._db
 | 
					        db = self._db
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -127,10 +127,13 @@ class WebSocketRendezvous(websocket.WebSocketServerProtocol):
 | 
				
			||||||
        self.send("allocated", channelid=channelid)
 | 
					        self.send("allocated", channelid=channelid)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def handle_claim(self, msg):
 | 
					    def handle_claim(self, msg):
 | 
				
			||||||
        if self._channel:
 | 
					 | 
				
			||||||
            raise Error("Already bound to a channelid")
 | 
					 | 
				
			||||||
        if "channelid" not in msg:
 | 
					        if "channelid" not in msg:
 | 
				
			||||||
            raise Error("claim requires 'channelid'")
 | 
					            raise Error("claim requires 'channelid'")
 | 
				
			||||||
 | 
					        # we allow allocate+claim as long as they match
 | 
				
			||||||
 | 
					        if self._channel is not None:
 | 
				
			||||||
 | 
					            old_cid = self._channel.get_channelid()
 | 
				
			||||||
 | 
					            if msg["channelid"] != old_cid:
 | 
				
			||||||
 | 
					                raise Error("Already bound to channelid %d" % old_cid)
 | 
				
			||||||
        self._channel = self._app.allocate_channel(msg["channelid"], self._side)
 | 
					        self._channel = self._app.allocate_channel(msg["channelid"], self._side)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def handle_watch(self, channel, msg):
 | 
					    def handle_watch(self, channel, msg):
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -373,6 +373,7 @@ class WSClient(websocket.WebSocketClientProtocol):
 | 
				
			||||||
    def __init__(self):
 | 
					    def __init__(self):
 | 
				
			||||||
        websocket.WebSocketClientProtocol.__init__(self)
 | 
					        websocket.WebSocketClientProtocol.__init__(self)
 | 
				
			||||||
        self.events = []
 | 
					        self.events = []
 | 
				
			||||||
 | 
					        self.errors = []
 | 
				
			||||||
        self.d = None
 | 
					        self.d = None
 | 
				
			||||||
        self.ping_counter = itertools.count(0)
 | 
					        self.ping_counter = itertools.count(0)
 | 
				
			||||||
    def onOpen(self):
 | 
					    def onOpen(self):
 | 
				
			||||||
| 
						 | 
					@ -380,6 +381,8 @@ class WSClient(websocket.WebSocketClientProtocol):
 | 
				
			||||||
    def onMessage(self, payload, isBinary):
 | 
					    def onMessage(self, payload, isBinary):
 | 
				
			||||||
        assert not isBinary
 | 
					        assert not isBinary
 | 
				
			||||||
        event = json.loads(payload.decode("utf-8"))
 | 
					        event = json.loads(payload.decode("utf-8"))
 | 
				
			||||||
 | 
					        if event["type"] == "error":
 | 
				
			||||||
 | 
					            self.errors.append(event)
 | 
				
			||||||
        if self.d:
 | 
					        if self.d:
 | 
				
			||||||
            assert not self.events
 | 
					            assert not self.events
 | 
				
			||||||
            d,self.d = self.d,None
 | 
					            d,self.d = self.d,None
 | 
				
			||||||
| 
						 | 
					@ -608,6 +611,40 @@ class WebSocketAPI(ServerBase, unittest.TestCase):
 | 
				
			||||||
        self.assertEqual(msg["type"], u"channelids")
 | 
					        self.assertEqual(msg["type"], u"channelids")
 | 
				
			||||||
        self.assertEqual(msg["channelids"], [])
 | 
					        self.assertEqual(msg["channelids"], [])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @inlineCallbacks
 | 
				
			||||||
 | 
					    def test_allocate_and_claim(self):
 | 
				
			||||||
 | 
					        c1 = yield self.make_client()
 | 
				
			||||||
 | 
					        msg = yield c1.next_event()
 | 
				
			||||||
 | 
					        self.check_welcome(msg)
 | 
				
			||||||
 | 
					        c1.send(u"bind", appid=u"appid", side=u"side")
 | 
				
			||||||
 | 
					        c1.send(u"allocate")
 | 
				
			||||||
 | 
					        msg = yield c1.next_event()
 | 
				
			||||||
 | 
					        self.assertEqual(msg["type"], u"allocated")
 | 
				
			||||||
 | 
					        cid = msg["channelid"]
 | 
				
			||||||
 | 
					        c1.send(u"claim", channelid=cid)
 | 
				
			||||||
 | 
					        yield c1.sync()
 | 
				
			||||||
 | 
					        # there should no error
 | 
				
			||||||
 | 
					        self.assertEqual(c1.errors, [])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @inlineCallbacks
 | 
				
			||||||
 | 
					    def test_allocate_and_claim_different(self):
 | 
				
			||||||
 | 
					        c1 = yield self.make_client()
 | 
				
			||||||
 | 
					        msg = yield c1.next_event()
 | 
				
			||||||
 | 
					        self.check_welcome(msg)
 | 
				
			||||||
 | 
					        c1.send(u"bind", appid=u"appid", side=u"side")
 | 
				
			||||||
 | 
					        c1.send(u"allocate")
 | 
				
			||||||
 | 
					        msg = yield c1.next_event()
 | 
				
			||||||
 | 
					        self.assertEqual(msg["type"], u"allocated")
 | 
				
			||||||
 | 
					        cid = msg["channelid"]
 | 
				
			||||||
 | 
					        c1.send(u"claim", channelid=cid+1)
 | 
				
			||||||
 | 
					        yield c1.sync()
 | 
				
			||||||
 | 
					        # that should signal an error
 | 
				
			||||||
 | 
					        self.assertEqual(len(c1.errors), 1, c1.errors)
 | 
				
			||||||
 | 
					        msg = c1.errors[0]
 | 
				
			||||||
 | 
					        self.assertEqual(msg["type"], "error")
 | 
				
			||||||
 | 
					        self.assertEqual(msg["error"], "Already bound to channelid %d" % cid)
 | 
				
			||||||
 | 
					        self.assertEqual(msg["orig"], {"type": "claim", "channelid": cid+1})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @inlineCallbacks
 | 
					    @inlineCallbacks
 | 
				
			||||||
    def test_message(self):
 | 
					    def test_message(self):
 | 
				
			||||||
        c1 = yield self.make_client()
 | 
					        c1 = yield self.make_client()
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue
	
	Block a user