rename channel-id to channelid. changes DB schema.

This commit is contained in:
Brian Warner 2015-10-06 16:16:41 -07:00
parent fc641622ba
commit fc30fa6cd4
5 changed files with 98 additions and 98 deletions

View File

@ -15,8 +15,8 @@ SECOND = 1
MINUTE = 60*SECOND
# relay URLs are:
# GET /list -> {channel-ids: [INT..]}
# POST /allocate {side: SIDE} -> {channel-id: INT}
# GET /list -> {channelids: [INT..]}
# POST /allocate {side: SIDE} -> {channelid: INT}
# these return all messages (base64) for CID= :
# POST /CID {side:, phase:, body:} -> {messages: [{phase:, body:}..]}
# GET /CID (no-eventsource) -> {messages: [{phase:, body:}..]}
@ -25,8 +25,8 @@ MINUTE = 60*SECOND
# all JSON responses include a "welcome:{..}" key
class Channel:
def __init__(self, relay, channel_id, side, handle_welcome):
self._channel_url = "%s%d" % (relay, channel_id)
def __init__(self, relay, channelid, side, handle_welcome):
self._channel_url = "%s%d" % (relay, channelid)
self._side = side
self._handle_welcome = handle_welcome
self._messages = set() # (phase,body) , body is bytes
@ -107,8 +107,8 @@ class ChannelManager:
def list_channels(self):
r = requests.get(self._relay + "list")
r.raise_for_status()
channel_ids = r.json()["channel-ids"]
return channel_ids
channelids = r.json()["channelids"]
return channelids
def allocate(self):
data = json.dumps({"side": self._side}).encode("utf-8")
@ -117,11 +117,11 @@ class ChannelManager:
data = r.json()
if "welcome" in data:
self._handle_welcome(data["welcome"])
channel_id = data["channel-id"]
return channel_id
channelid = data["channelid"]
return channelid
def connect(self, channel_id):
return Channel(self._relay, channel_id, self._side,
def connect(self, channelid):
return Channel(self._relay, channelid, self._side,
self._handle_welcome)
class Wormhole:
@ -167,10 +167,10 @@ class Wormhole:
def get_code(self, code_length=2):
if self.code is not None: raise UsageError
channel_id = self._channel_manager.allocate()
code = codes.make_code(channel_id, code_length)
channelid = self._channel_manager.allocate()
code = codes.make_code(channelid, code_length)
assert isinstance(code, str), type(code)
self._set_code_and_channel_id(code)
self._set_code_and_channelid(code)
self._start()
return code
@ -183,17 +183,17 @@ class Wormhole:
def set_code(self, code): # used for human-made pre-generated codes
if not isinstance(code, str): raise UsageError
if self.code is not None: raise UsageError
self._set_code_and_channel_id(code)
self._set_code_and_channelid(code)
self._start()
def _set_code_and_channel_id(self, code):
def _set_code_and_channelid(self, code):
if self.code is not None: raise UsageError
mo = re.search(r'^(\d+)-', code)
if not mo:
raise ValueError("code (%s) must start with NN-" % code)
self.code = code
channel_id = int(mo.group(1))
self.channel = self._channel_manager.connect(channel_id)
channelid = int(mo.group(1))
self.channel = self._channel_manager.connect(channelid)
def _start(self):
# allocate the rest now too, so it can be serialized

View File

@ -9,17 +9,17 @@ CREATE TABLE `version`
CREATE TABLE `messages`
(
`channel_id` INTEGER,
`channelid` INTEGER,
`side` VARCHAR,
`phase` VARCHAR, -- not numeric, more of a PAKE-phase indicator string
`body` VARCHAR,
`when` INTEGER
);
CREATE INDEX `messages_idx` ON `messages` (`channel_id`, `side`, `phase`);
CREATE INDEX `messages_idx` ON `messages` (`channelid`, `side`, `phase`);
CREATE TABLE `allocations`
(
`channel_id` INTEGER,
`channelid` INTEGER,
`side` VARCHAR
);
CREATE INDEX `allocations_idx` ON `allocations` (`channel_id`);
CREATE INDEX `allocations_idx` ON `allocations` (`channelid`);

View File

@ -47,8 +47,8 @@ class EventsProtocol:
# note: no versions of IE (including the current IE11) support EventSource
# relay URLs are:
# GET /list -> {channel-ids: [INT..]}
# POST /allocate {side: SIDE} -> {channel-id: INT}
# GET /list -> {channelids: [INT..]}
# POST /allocate {side: SIDE} -> {channelid: INT}
# these return all messages (base64) for CID= :
# POST /CID {side:, phase:, body:} -> {messages: [{phase:, body:}..]}
# GET /CID (no-eventsource) -> {messages: [{phase:, body:}..]}
@ -57,22 +57,22 @@ class EventsProtocol:
# all JSON responses include a "welcome:{..}" key
class Channel(resource.Resource):
def __init__(self, channel_id, relay, db, welcome):
def __init__(self, channelid, relay, db, welcome):
resource.Resource.__init__(self)
self.channel_id = channel_id
self.channelid = channelid
self.relay = relay
self.db = db
self.welcome = welcome
self.event_channels = set() # ep
self.putChild(b"deallocate", Deallocator(self.channel_id, self.relay))
self.putChild(b"deallocate", Deallocator(self.channelid, self.relay))
def get_messages(self, request):
request.setHeader(b"content-type", b"application/json; charset=utf-8")
messages = []
for row in self.db.execute("SELECT * FROM `messages`"
" WHERE `channel_id`=?"
" WHERE `channelid`=?"
" ORDER BY `when` ASC",
(self.channel_id,)).fetchall():
(self.channelid,)).fetchall():
messages.append({"phase": row["phase"], "body": row["body"]})
data = {"welcome": self.welcome, "messages": messages}
return (json.dumps(data)+"\n").encode("utf-8")
@ -87,9 +87,9 @@ class Channel(resource.Resource):
request.notifyFinish().addErrback(lambda f:
self.event_channels.discard(ep))
for row in self.db.execute("SELECT * FROM `messages`"
" WHERE `channel_id`=?"
" WHERE `channelid`=?"
" ORDER BY `when` ASC",
(self.channel_id,)).fetchall():
(self.channelid,)).fetchall():
data = json.dumps({"phase": row["phase"], "body": row["body"]})
ep.sendEvent(data)
return server.NOT_DONE_YET
@ -111,35 +111,35 @@ class Channel(resource.Resource):
body = data["body"]
self.db.execute("INSERT INTO `messages`"
" (`channel_id`, `side`, `phase`, `body`, `when`)"
" (`channelid`, `side`, `phase`, `body`, `when`)"
" VALUES (?,?,?,?,?)",
(self.channel_id, side, phase, body, time.time()))
(self.channelid, side, phase, body, time.time()))
self.db.execute("INSERT INTO `allocations`"
" (`channel_id`, `side`)"
" (`channelid`, `side`)"
" VALUES (?,?)",
(self.channel_id, side))
(self.channelid, side))
self.db.commit()
self.broadcast_message(phase, body)
return self.get_messages(request)
class Deallocator(resource.Resource):
def __init__(self, channel_id, relay):
self.channel_id = channel_id
def __init__(self, channelid, relay):
self.channelid = channelid
self.relay = relay
def render_POST(self, request):
content = request.content.read()
data = json.loads(content.decode("utf-8"))
side = data["side"]
deleted = self.relay.maybe_free_child(self.channel_id, side)
deleted = self.relay.maybe_free_child(self.channelid, side)
resp = {"status": "waiting"}
if deleted:
resp = {"status": "deleted"}
return json.dumps(resp).encode("utf-8")
def get_allocated(db):
c = db.execute("SELECT DISTINCT `channel_id` FROM `allocations`")
return set([row["channel_id"] for row in c.fetchall()])
c = db.execute("SELECT DISTINCT `channelid` FROM `allocations`")
return set([row["channelid"] for row in c.fetchall()])
class Allocator(resource.Resource):
def __init__(self, db, welcome):
@ -147,7 +147,7 @@ class Allocator(resource.Resource):
self.db = db
self.welcome = welcome
def allocate_channel_id(self):
def allocate_channelid(self):
allocated = get_allocated(self.db)
for size in range(1,4): # stick to 1-999 for now
available = set()
@ -161,7 +161,7 @@ class Allocator(resource.Resource):
cid = random.randrange(1000, 1000*1000)
if cid not in allocated:
return cid
raise ValueError("unable to find a free channel-id")
raise ValueError("unable to find a free channelid")
def render_POST(self, request):
content = request.content.read()
@ -169,15 +169,15 @@ class Allocator(resource.Resource):
side = data["side"]
if not isinstance(side, type(u"")):
raise TypeError("side must be string, not '%s'" % type(side))
channel_id = self.allocate_channel_id()
channelid = self.allocate_channelid()
self.db.execute("INSERT INTO `allocations` VALUES (?,?)",
(channel_id, side))
(channelid, side))
self.db.commit()
log.msg("allocated #%d, now have %d DB channels" %
(channel_id, len(get_allocated(self.db))))
(channelid, len(get_allocated(self.db))))
request.setHeader(b"content-type", b"application/json; charset=utf-8")
data = {"welcome": self.welcome,
"channel-id": channel_id}
"channelid": channelid}
return (json.dumps(data)+"\n").encode("utf-8")
class ChannelList(resource.Resource):
@ -186,11 +186,11 @@ class ChannelList(resource.Resource):
self.db = db
self.welcome = welcome
def render_GET(self, request):
c = self.db.execute("SELECT DISTINCT `channel_id` FROM `allocations`")
allocated = sorted(set([row["channel_id"] for row in c.fetchall()]))
c = self.db.execute("SELECT DISTINCT `channelid` FROM `allocations`")
allocated = sorted(set([row["channelid"] for row in c.fetchall()]))
request.setHeader(b"content-type", b"application/json; charset=utf-8")
data = {"welcome": self.welcome,
"channel-ids": allocated}
"channelids": allocated}
return (json.dumps(data)+"\n").encode("utf-8")
class Relay(resource.Resource, service.MultiService):
@ -214,45 +214,45 @@ class Relay(resource.Resource, service.MultiService):
return resource.ErrorPage(http.BAD_REQUEST,
"invalid channel id",
"invalid channel id")
channel_id = int(path)
if not channel_id in self.channels:
log.msg("spawning #%d" % channel_id)
self.channels[channel_id] = Channel(channel_id, self, self.db,
self.welcome)
return self.channels[channel_id]
channelid = int(path)
if not channelid in self.channels:
log.msg("spawning #%d" % channelid)
self.channels[channelid] = Channel(channelid, self, self.db,
self.welcome)
return self.channels[channelid]
def maybe_free_child(self, channel_id, side):
def maybe_free_child(self, channelid, side):
self.db.execute("DELETE FROM `allocations`"
" WHERE `channel_id`=? AND `side`=?",
(channel_id, side))
" WHERE `channelid`=? AND `side`=?",
(channelid, side))
self.db.commit()
remaining = self.db.execute("SELECT COUNT(*) FROM `allocations`"
" WHERE `channel_id`=?",
(channel_id,)).fetchone()[0]
" WHERE `channelid`=?",
(channelid,)).fetchone()[0]
if remaining:
return False
self.free_child(channel_id)
self.free_child(channelid)
return True
def free_child(self, channel_id):
self.db.execute("DELETE FROM `allocations` WHERE `channel_id`=?",
(channel_id,))
self.db.execute("DELETE FROM `messages` WHERE `channel_id`=?",
(channel_id,))
def free_child(self, channelid):
self.db.execute("DELETE FROM `allocations` WHERE `channelid`=?",
(channelid,))
self.db.execute("DELETE FROM `messages` WHERE `channelid`=?",
(channelid,))
self.db.commit()
if channel_id in self.channels:
self.channels.pop(channel_id)
if channelid in self.channels:
self.channels.pop(channelid)
log.msg("freed+killed #%d, now have %d DB channels, %d live" %
(channel_id, len(get_allocated(self.db)), len(self.channels)))
(channelid, len(get_allocated(self.db)), len(self.channels)))
def prune_old_channels(self):
old = time.time() - CHANNEL_EXPIRATION_TIME
for channel_id in get_allocated(self.db):
for channelid in get_allocated(self.db):
c = self.db.execute("SELECT `when` FROM `messages`"
" WHERE `channel_id`=?"
" ORDER BY `when` DESC LIMIT 1", (channel_id,))
" WHERE `channelid`=?"
" ORDER BY `when` DESC LIMIT 1", (channelid,))
rows = c.fetchall()
if not rows or (rows[0]["when"] < old):
log.msg("expiring %d" % channel_id)
self.free_child(channel_id)
log.msg("expiring %d" % channelid)
self.free_child(channelid)

View File

@ -76,20 +76,20 @@ class API(ServerBase, unittest.TestCase):
d = self.get("list")
def _check_list_1(data):
self.check_welcome(data)
self.failUnlessEqual(data["channel-ids"], [])
self.failUnlessEqual(data["channelids"], [])
d.addCallback(_check_list_1)
d.addCallback(lambda _: self.post("allocate", {"side": "abc"}))
def _allocated(data):
self.failUnlessEqual(set(data.keys()),
set(["welcome", "channel-id"]))
self.failUnlessIsInstance(data["channel-id"], int)
self.cid = data["channel-id"]
set(["welcome", "channelid"]))
self.failUnlessIsInstance(data["channelid"], int)
self.cid = data["channelid"]
d.addCallback(_allocated)
d.addCallback(lambda _: self.get("list"))
def _check_list_2(data):
self.failUnlessEqual(data["channel-ids"], [self.cid])
self.failUnlessEqual(data["channelids"], [self.cid])
d.addCallback(_check_list_2)
d.addCallback(lambda _: self.post("%d/deallocate" % self.cid,
@ -100,7 +100,7 @@ class API(ServerBase, unittest.TestCase):
d.addCallback(lambda _: self.get("list"))
def _check_list_3(data):
self.failUnlessEqual(data["channel-ids"], [])
self.failUnlessEqual(data["channelids"], [])
d.addCallback(_check_list_3)
return d
@ -108,7 +108,7 @@ class API(ServerBase, unittest.TestCase):
def test_allocate_2(self):
d = self.post("allocate", {"side": "abc"})
def _allocated(data):
self.cid = data["channel-id"]
self.cid = data["channelid"]
d.addCallback(_allocated)
# second caller increases the number of known sides to 2
@ -119,7 +119,7 @@ class API(ServerBase, unittest.TestCase):
d.addCallback(lambda _: self.get("list"))
d.addCallback(lambda data:
self.failUnlessEqual(data["channel-ids"], [self.cid]))
self.failUnlessEqual(data["channelids"], [self.cid]))
d.addCallback(lambda _: self.post("%d/deallocate" % self.cid,
{"side": "abc"}))
@ -138,7 +138,7 @@ class API(ServerBase, unittest.TestCase):
d.addCallback(lambda _: self.get("list"))
d.addCallback(lambda data:
self.failUnlessEqual(data["channel-ids"], []))
self.failUnlessEqual(data["channelids"], []))
return d
@ -166,7 +166,7 @@ class API(ServerBase, unittest.TestCase):
def test_messages(self):
d = self.post("allocate", {"side": "abc"})
def _allocated(data):
self.cid = data["channel-id"]
self.cid = data["channelid"]
d.addCallback(_allocated)
d.addCallback(lambda _: self.add_message("msg1A"))
@ -211,7 +211,7 @@ class API(ServerBase, unittest.TestCase):
d = self.post("allocate", {"side": "abc"})
def _allocated(data):
self.cid = data["channel-id"]
self.cid = data["channelid"]
url = (self.relayurl+str(self.cid)).encode("utf-8")
self.o = OneEventAtATime(url, parser=json.loads)
return self.o.wait_for_connection()

View File

@ -46,9 +46,9 @@ def post_json(agent, url, request_body):
return d
class Channel:
def __init__(self, relay, channel_id, side, handle_welcome,
def __init__(self, relay, channelid, side, handle_welcome,
agent):
self._channel_url = "%s%d" % (relay, channel_id)
self._channel_url = "%s%d" % (relay, channelid)
self._side = side
self._handle_welcome = handle_welcome
self._agent = agent
@ -127,15 +127,15 @@ class ChannelManager:
def _got_channel(data):
if "welcome" in data:
self._handle_welcome(data["welcome"])
return data["channel-id"]
return data["channelid"]
d.addCallback(_got_channel)
return d
def list_channels(self):
raise NotImplementedError
def connect(self, channel_id):
return Channel(self._relay, channel_id, self._side,
def connect(self, channelid):
return Channel(self._relay, channelid, self._side,
self._handle_welcome, self._agent)
class Wormhole:
@ -186,29 +186,29 @@ class Wormhole:
if self._started_get_code: raise UsageError
self._started_get_code = True
d = self._channel_manager.allocate()
def _got_channel_id(channel_id):
code = codes.make_code(channel_id, code_length)
def _got_channelid(channelid):
code = codes.make_code(channelid, code_length)
assert isinstance(code, str), type(code)
self._set_code_and_channel_id(code)
self._set_code_and_channelid(code)
self._start()
return code
d.addCallback(_got_channel_id)
d.addCallback(_got_channelid)
return d
def set_code(self, code):
if not isinstance(code, str): raise UsageError
if self.code is not None: raise UsageError
self._set_code_and_channel_id(code)
self._set_code_and_channelid(code)
self._start()
def _set_code_and_channel_id(self, code):
def _set_code_and_channelid(self, code):
if self.code is not None: raise UsageError
mo = re.search(r'^(\d+)-', code)
if not mo:
raise ValueError("code (%s) must start with NN-" % code)
self.code = code
channel_id = int(mo.group(1))
self.channel = self._channel_manager.connect(channel_id)
channelid = int(mo.group(1))
self.channel = self._channel_manager.connect(channelid)
def _start(self):
# allocate the rest now too, so it can be serialized
@ -238,7 +238,7 @@ class Wormhole:
d = json.loads(data)
self = klass(d["appid"].encode("ascii"), d["relay"].encode("ascii"))
self._set_side(d["side"].encode("ascii"))
self._set_code_and_channel_id(d["code"].encode("ascii"))
self._set_code_and_channelid(d["code"].encode("ascii"))
self.sp = SPAKE2_Symmetric.from_serialized(json.dumps(d["spake2"]))
self.msg1 = d["msg1"].decode("hex")
return self