SCHEMA CHANGE: channelids are now strs, not ints

This will enable the use of large randomly-generated hex or base32
channelids, for post-startup or resumed-connection channels.
This commit is contained in:
Brian Warner 2016-05-13 00:37:53 -07:00
parent c14e982ae7
commit 1198977e06
6 changed files with 22 additions and 11 deletions

View File

@ -4,6 +4,7 @@ from .wordlist import (byte_to_even_word, byte_to_odd_word,
even_words_lowercase, odd_words_lowercase) even_words_lowercase, odd_words_lowercase)
def make_code(channel_id, code_length): def make_code(channel_id, code_length):
assert isinstance(channel_id, type(u"")), type(channel_id)
words = [] words = []
for i in range(code_length): for i in range(code_length):
# we start with an "odd word" # we start with an "odd word"
@ -11,7 +12,7 @@ def make_code(channel_id, code_length):
words.append(byte_to_odd_word[os.urandom(1)].lower()) words.append(byte_to_odd_word[os.urandom(1)].lower())
else: else:
words.append(byte_to_even_word[os.urandom(1)].lower()) words.append(byte_to_even_word[os.urandom(1)].lower())
return u"%d-%s" % (channel_id, u"-".join(words)) return u"%s-%s" % (channel_id, u"-".join(words))
def extract_channel_id(code): def extract_channel_id(code):
channel_id = int(code.split("-")[0]) channel_id = int(code.split("-")[0])

View File

@ -10,9 +10,9 @@ CREATE TABLE `version`
CREATE TABLE `messages` CREATE TABLE `messages`
( (
`appid` VARCHAR, `appid` VARCHAR,
`channelid` INTEGER, `channelid` VARCHAR,
`side` VARCHAR, `side` VARCHAR,
`phase` VARCHAR, -- not numeric, more of a PAKE-phase indicator string `phase` VARCHAR, -- numeric or string
-- phase="_allocate" and "_deallocate" are used internally -- phase="_allocate" and "_deallocate" are used internally
`body` VARCHAR, `body` VARCHAR,
`server_rx` INTEGER, `server_rx` INTEGER,

View File

@ -212,28 +212,31 @@ class AppNamespace:
claimed = self.get_claimed() claimed = self.get_claimed()
for size in range(1,4): # stick to 1-999 for now for size in range(1,4): # stick to 1-999 for now
available = set() available = set()
for cid in range(10**(size-1), 10**size): for cid_int in range(10**(size-1), 10**size):
cid = u"%d" % cid_int
if cid not in claimed: if cid not in claimed:
available.add(cid) available.add(cid)
if available: if available:
return random.choice(list(available)) return random.choice(list(available))
# ouch, 999 currently claimed. Try random ones for a while. # ouch, 999 currently claimed. Try random ones for a while.
for tries in range(1000): for tries in range(1000):
cid = random.randrange(1000, 1000*1000) cid_int = random.randrange(1000, 1000*1000)
cid = u"%d" % cid_int
if cid not in claimed: if cid not in claimed:
return cid return cid
raise ValueError("unable to find a free channel-id") raise ValueError("unable to find a free channel-id")
def claim_channel(self, channelid, side): def claim_channel(self, channelid, side):
assert isinstance(channelid, type(u"")), type(channelid)
channel = self.get_channel(channelid) channel = self.get_channel(channelid)
channel.claim(side) channel.claim(side)
return channel return channel
def get_channel(self, channelid): def get_channel(self, channelid):
assert isinstance(channelid, int) assert isinstance(channelid, type(u""))
if not channelid in self._channels: if not channelid in self._channels:
if self._log_requests: if self._log_requests:
log.msg("spawning #%d for appid %s" % (channelid, self._appid)) log.msg("spawning #%s for appid %s" % (channelid, self._appid))
self._channels[channelid] = Channel(self, self._db, self._welcome, self._channels[channelid] = Channel(self, self._db, self._welcome,
self._blur_usage, self._blur_usage,
self._log_requests, self._log_requests,
@ -247,7 +250,7 @@ class AppNamespace:
if channelid in self._channels: if channelid in self._channels:
self._channels.pop(channelid) self._channels.pop(channelid)
if self._log_requests: if self._log_requests:
log.msg("freed+killed #%d, now have %d DB channels, %d live" % log.msg("freed+killed #%s, now have %d DB channels, %d live" %
(channelid, len(self.get_claimed()), len(self._channels))) (channelid, len(self.get_claimed()), len(self._channels)))
def prune_old_channels(self): def prune_old_channels(self):

View File

@ -132,6 +132,7 @@ class WebSocketRendezvous(websocket.WebSocketServerProtocol):
if self._did_allocate: if self._did_allocate:
raise Error("You already allocated one channel, don't be greedy") raise Error("You already allocated one channel, don't be greedy")
channelid = self._app.find_available_channelid() channelid = self._app.find_available_channelid()
assert isinstance(channelid, type(u""))
self._did_allocate = True self._did_allocate = True
channel = self._app.claim_channel(channelid, self._side) channel = self._app.claim_channel(channelid, self._side)
self._channels[channelid] = channel self._channels[channelid] = channel
@ -141,6 +142,7 @@ class WebSocketRendezvous(websocket.WebSocketServerProtocol):
if "channelid" not in msg: if "channelid" not in msg:
raise Error("claim requires 'channelid'") raise Error("claim requires 'channelid'")
channelid = msg["channelid"] channelid = msg["channelid"]
assert isinstance(channelid, type(u"")), type(channelid)
if channelid not in self._channels: if channelid not in self._channels:
channel = self._app.claim_channel(channelid, self._side) channel = self._app.claim_channel(channelid, self._side)
self._channels[channelid] = channel self._channels[channelid] = channel
@ -149,6 +151,7 @@ class WebSocketRendezvous(websocket.WebSocketServerProtocol):
channelid = msg["channelid"] channelid = msg["channelid"]
if channelid not in self._channels: if channelid not in self._channels:
raise Error("must claim channel before watching") raise Error("must claim channel before watching")
assert isinstance(channelid, type(u""))
channel = self._channels[channelid] channel = self._channels[channelid]
def _send(event): def _send(event):
self.send("message", channelid=channelid, message=event) self.send("message", channelid=channelid, message=event)
@ -161,6 +164,7 @@ class WebSocketRendezvous(websocket.WebSocketServerProtocol):
channelid = msg["channelid"] channelid = msg["channelid"]
if channelid not in self._channels: if channelid not in self._channels:
raise Error("must claim channel before adding") raise Error("must claim channel before adding")
assert isinstance(channelid, type(u""))
channel = self._channels[channelid] channel = self._channels[channelid]
if "phase" not in msg: if "phase" not in msg:
raise Error("missing 'phase'") raise Error("missing 'phase'")
@ -174,6 +178,7 @@ class WebSocketRendezvous(websocket.WebSocketServerProtocol):
channelid = msg["channelid"] channelid = msg["channelid"]
if channelid not in self._channels: if channelid not in self._channels:
raise Error("must claim channel before releasing") raise Error("must claim channel before releasing")
assert isinstance(channelid, type(u""))
channel = self._channels[channelid] channel = self._channels[channelid]
deleted = channel.release(self._side, msg.get("mood")) deleted = channel.release(self._side, msg.get("mood"))
del self._channels[channelid] del self._channels[channelid]

View File

@ -220,7 +220,7 @@ class WebSocketAPI(ServerBase, unittest.TestCase):
msg = yield c1.next_non_ack() msg = yield c1.next_non_ack()
self.assertEqual(msg["type"], u"allocated") self.assertEqual(msg["type"], u"allocated")
cid = msg["channelid"] cid = msg["channelid"]
self.failUnlessIsInstance(cid, int) self.failUnlessIsInstance(cid, type(u""))
self.assertEqual(app.get_claimed(), set([cid])) self.assertEqual(app.get_claimed(), set([cid]))
channel = app.get_channel(cid) channel = app.get_channel(cid)
self.assertEqual(channel.get_messages(), []) self.assertEqual(channel.get_messages(), [])
@ -254,7 +254,7 @@ class WebSocketAPI(ServerBase, unittest.TestCase):
msg = yield c1.next_non_ack() msg = yield c1.next_non_ack()
self.assertEqual(msg["type"], u"allocated") self.assertEqual(msg["type"], u"allocated")
cid = msg["channelid"] cid = msg["channelid"]
self.failUnlessIsInstance(cid, int) self.failUnlessIsInstance(cid, type(u""))
self.assertEqual(app.get_claimed(), set([cid])) self.assertEqual(app.get_claimed(), set([cid]))
channel = app.get_channel(cid) channel = app.get_channel(cid)
self.assertEqual(channel.get_messages(), []) self.assertEqual(channel.get_messages(), [])

View File

@ -234,6 +234,7 @@ class _Wormhole:
if self._channelid is not None: if self._channelid is not None:
return self._signal_error("got duplicate channelid") return self._signal_error("got duplicate channelid")
self._channelid = msg["channelid"] self._channelid = msg["channelid"]
assert isinstance(self._channelid, type(u"")), type(self._channelid)
self._wakeup() self._wakeup()
def _start(self): def _start(self):
@ -322,7 +323,8 @@ class _Wormhole:
if not mo: if not mo:
raise ValueError("code (%s) must start with NN-" % code) raise ValueError("code (%s) must start with NN-" % code)
with self._timing.add("API set_code"): with self._timing.add("API set_code"):
self._channelid = int(mo.group(1)) self._channelid = mo.group(1)
assert isinstance(self._channelid, type(u"")), type(self._channelid)
self._set_code(code) self._set_code(code)
self._start() self._start()