From 1198977e069bcf0ea01e66f538269972dc06060a Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Fri, 13 May 2016 00:37:53 -0700 Subject: [PATCH] SCHEMA CHANGE: channelids are now strs, not ints This will enable the use of large randomly-generated hex or base32 channelids, for post-startup or resumed-connection channels. --- src/wormhole/codes.py | 3 ++- src/wormhole/server/db-schemas/v1.sql | 4 ++-- src/wormhole/server/rendezvous.py | 13 ++++++++----- src/wormhole/server/rendezvous_websocket.py | 5 +++++ src/wormhole/test/test_server.py | 4 ++-- src/wormhole/twisted/transcribe.py | 4 +++- 6 files changed, 22 insertions(+), 11 deletions(-) diff --git a/src/wormhole/codes.py b/src/wormhole/codes.py index fd08694..44f2665 100644 --- a/src/wormhole/codes.py +++ b/src/wormhole/codes.py @@ -4,6 +4,7 @@ from .wordlist import (byte_to_even_word, byte_to_odd_word, even_words_lowercase, odd_words_lowercase) def make_code(channel_id, code_length): + assert isinstance(channel_id, type(u"")), type(channel_id) words = [] for i in range(code_length): # we start with an "odd word" @@ -11,7 +12,7 @@ def make_code(channel_id, code_length): words.append(byte_to_odd_word[os.urandom(1)].lower()) else: words.append(byte_to_even_word[os.urandom(1)].lower()) - return u"%d-%s" % (channel_id, u"-".join(words)) + return u"%s-%s" % (channel_id, u"-".join(words)) def extract_channel_id(code): channel_id = int(code.split("-")[0]) diff --git a/src/wormhole/server/db-schemas/v1.sql b/src/wormhole/server/db-schemas/v1.sql index 7d7115a..2740661 100644 --- a/src/wormhole/server/db-schemas/v1.sql +++ b/src/wormhole/server/db-schemas/v1.sql @@ -10,9 +10,9 @@ CREATE TABLE `version` CREATE TABLE `messages` ( `appid` VARCHAR, - `channelid` INTEGER, + `channelid` VARCHAR, `side` VARCHAR, - `phase` VARCHAR, -- not numeric, more of a PAKE-phase indicator string + `phase` VARCHAR, -- numeric or string -- phase="_allocate" and "_deallocate" are used internally `body` VARCHAR, `server_rx` INTEGER, diff --git a/src/wormhole/server/rendezvous.py b/src/wormhole/server/rendezvous.py index 9f20b61..691d06e 100644 --- a/src/wormhole/server/rendezvous.py +++ b/src/wormhole/server/rendezvous.py @@ -212,28 +212,31 @@ class AppNamespace: claimed = self.get_claimed() for size in range(1,4): # stick to 1-999 for now available = set() - for cid in range(10**(size-1), 10**size): + for cid_int in range(10**(size-1), 10**size): + cid = u"%d" % cid_int if cid not in claimed: available.add(cid) if available: return random.choice(list(available)) # ouch, 999 currently claimed. Try random ones for a while. for tries in range(1000): - cid = random.randrange(1000, 1000*1000) + cid_int = random.randrange(1000, 1000*1000) + cid = u"%d" % cid_int if cid not in claimed: return cid raise ValueError("unable to find a free channel-id") def claim_channel(self, channelid, side): + assert isinstance(channelid, type(u"")), type(channelid) channel = self.get_channel(channelid) channel.claim(side) return channel def get_channel(self, channelid): - assert isinstance(channelid, int) + assert isinstance(channelid, type(u"")) if not channelid in self._channels: if self._log_requests: - log.msg("spawning #%d for appid %s" % (channelid, self._appid)) + log.msg("spawning #%s for appid %s" % (channelid, self._appid)) self._channels[channelid] = Channel(self, self._db, self._welcome, self._blur_usage, self._log_requests, @@ -247,7 +250,7 @@ class AppNamespace: if channelid in self._channels: self._channels.pop(channelid) if self._log_requests: - log.msg("freed+killed #%d, now have %d DB channels, %d live" % + log.msg("freed+killed #%s, now have %d DB channels, %d live" % (channelid, len(self.get_claimed()), len(self._channels))) def prune_old_channels(self): diff --git a/src/wormhole/server/rendezvous_websocket.py b/src/wormhole/server/rendezvous_websocket.py index 9ec113d..358063d 100644 --- a/src/wormhole/server/rendezvous_websocket.py +++ b/src/wormhole/server/rendezvous_websocket.py @@ -132,6 +132,7 @@ class WebSocketRendezvous(websocket.WebSocketServerProtocol): if self._did_allocate: raise Error("You already allocated one channel, don't be greedy") channelid = self._app.find_available_channelid() + assert isinstance(channelid, type(u"")) self._did_allocate = True channel = self._app.claim_channel(channelid, self._side) self._channels[channelid] = channel @@ -141,6 +142,7 @@ class WebSocketRendezvous(websocket.WebSocketServerProtocol): if "channelid" not in msg: raise Error("claim requires 'channelid'") channelid = msg["channelid"] + assert isinstance(channelid, type(u"")), type(channelid) if channelid not in self._channels: channel = self._app.claim_channel(channelid, self._side) self._channels[channelid] = channel @@ -149,6 +151,7 @@ class WebSocketRendezvous(websocket.WebSocketServerProtocol): channelid = msg["channelid"] if channelid not in self._channels: raise Error("must claim channel before watching") + assert isinstance(channelid, type(u"")) channel = self._channels[channelid] def _send(event): self.send("message", channelid=channelid, message=event) @@ -161,6 +164,7 @@ class WebSocketRendezvous(websocket.WebSocketServerProtocol): channelid = msg["channelid"] if channelid not in self._channels: raise Error("must claim channel before adding") + assert isinstance(channelid, type(u"")) channel = self._channels[channelid] if "phase" not in msg: raise Error("missing 'phase'") @@ -174,6 +178,7 @@ class WebSocketRendezvous(websocket.WebSocketServerProtocol): channelid = msg["channelid"] if channelid not in self._channels: raise Error("must claim channel before releasing") + assert isinstance(channelid, type(u"")) channel = self._channels[channelid] deleted = channel.release(self._side, msg.get("mood")) del self._channels[channelid] diff --git a/src/wormhole/test/test_server.py b/src/wormhole/test/test_server.py index 6deef6c..ad3e227 100644 --- a/src/wormhole/test/test_server.py +++ b/src/wormhole/test/test_server.py @@ -220,7 +220,7 @@ class WebSocketAPI(ServerBase, unittest.TestCase): msg = yield c1.next_non_ack() self.assertEqual(msg["type"], u"allocated") cid = msg["channelid"] - self.failUnlessIsInstance(cid, int) + self.failUnlessIsInstance(cid, type(u"")) self.assertEqual(app.get_claimed(), set([cid])) channel = app.get_channel(cid) self.assertEqual(channel.get_messages(), []) @@ -254,7 +254,7 @@ class WebSocketAPI(ServerBase, unittest.TestCase): msg = yield c1.next_non_ack() self.assertEqual(msg["type"], u"allocated") cid = msg["channelid"] - self.failUnlessIsInstance(cid, int) + self.failUnlessIsInstance(cid, type(u"")) self.assertEqual(app.get_claimed(), set([cid])) channel = app.get_channel(cid) self.assertEqual(channel.get_messages(), []) diff --git a/src/wormhole/twisted/transcribe.py b/src/wormhole/twisted/transcribe.py index 5f362a3..d8f39d6 100644 --- a/src/wormhole/twisted/transcribe.py +++ b/src/wormhole/twisted/transcribe.py @@ -234,6 +234,7 @@ class _Wormhole: if self._channelid is not None: return self._signal_error("got duplicate channelid") self._channelid = msg["channelid"] + assert isinstance(self._channelid, type(u"")), type(self._channelid) self._wakeup() def _start(self): @@ -322,7 +323,8 @@ class _Wormhole: if not mo: raise ValueError("code (%s) must start with NN-" % code) with self._timing.add("API set_code"): - self._channelid = int(mo.group(1)) + self._channelid = mo.group(1) + assert isinstance(self._channelid, type(u"")), type(self._channelid) self._set_code(code) self._start()