From a03fb3900e2552116f0fc10bbca23c2d8737edfd Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Fri, 10 Apr 2015 12:00:08 -0500 Subject: [PATCH] relay: track allocations through DB --- src/wormhole/db-schemas/v1.sql | 9 ++++++- src/wormhole/servers/relay.py | 44 +++++++++++++++++++++++++--------- 2 files changed, 41 insertions(+), 12 deletions(-) diff --git a/src/wormhole/db-schemas/v1.sql b/src/wormhole/db-schemas/v1.sql index 27262ea..43785de 100644 --- a/src/wormhole/db-schemas/v1.sql +++ b/src/wormhole/db-schemas/v1.sql @@ -15,4 +15,11 @@ CREATE TABLE `messages` `message` VARCHAR, `when` INTEGER ); -CREATE INDEX `lookup` ON `messages` (`channel_id`, `side`, `msgnum`); +CREATE INDEX `messages_idx` ON `messages` (`channel_id`, `side`, `msgnum`); + +CREATE TABLE `allocations` +( + `channel_id` INTEGER, + `side` VARCHAR +); +CREATE INDEX `allocations_idx` ON `allocations` (`channel_id`); diff --git a/src/wormhole/servers/relay.py b/src/wormhole/servers/relay.py index dafba87..bedac04 100644 --- a/src/wormhole/servers/relay.py +++ b/src/wormhole/servers/relay.py @@ -66,8 +66,6 @@ class Channel(resource.Resource): self.db = db self.welcome = welcome self.expire_at = time.time() + CHANNEL_EXPIRATION_TIME - self.sides = set() - self.messages = [] # (side, msgnum, str) self.event_channels = set() # (side, msgnum, ep) @@ -87,8 +85,12 @@ class Channel(resource.Resource): handle = (their_side, their_msgnum, ep) self.event_channels.add(handle) request.notifyFinish().addErrback(self._shutdown, handle) - for (msg_side, msg_msgnum, msg_str) in self.messages: - self.message_added(msg_side, msg_msgnum, msg_str, channels=[handle]) + for row in self.db.execute("SELECT * FROM `messages`" + " WHERE `channel_id`=?" + " ORDER BY `when` ASC", + (self.channel_id,)).fetchall(): + self.message_added(row["side"], row["msgnum"], row["message"], + channels=[handle]) return server.NOT_DONE_YET def _shutdown(self, _, handle): @@ -107,12 +109,17 @@ class Channel(resource.Resource): def render_POST(self, request): # rest of URL is: SIDE/(MSGNUM|deallocate)/(post|poll) side = request.postpath[0] - self.sides.add(side) verb = request.postpath[1] if verb == "deallocate": - self.sides.remove(side) - if self.sides: + self.db.execute("DELETE FROM `allocations`" + " WHERE `channel_id`=? AND `side`=?", + (self.channel_id, side)) + self.db.commit() + remaining = self.db.execute("SELECT COUNT(*) FROM `allocations`" + " WHERE `channel_id`=?", + (self.channel_id,)).fetchone()[0] + if remaining: return "waiting\n" self.relay.free_child(self.channel_id) return "deleted\n" @@ -124,13 +131,25 @@ class Channel(resource.Resource): msgnum = request.postpath[2] other_messages = [] - for (msg_side, msg_msgnum, msg_str) in self.messages: - if msg_side != side and msg_msgnum == msgnum: - other_messages.append(msg_str) + for row in self.db.execute("SELECT `message` FROM `messages`" + " WHERE `channel_id`=? AND `side`!=?" + " AND `msgnum`=?" + " ORDER BY `when` ASC", + (self.channel_id, side, msgnum)).fetchall(): + other_messages.append(row["message"]) if verb == "post": data = json.load(request.content) - self.messages.append( (side, msgnum, data["message"]) ) + self.db.execute("INSERT INTO `messages`" + " (`channel_id`, `side`, `msgnum`, `message`, `when`)" + " VALUES (?,?,?,?,?)", + (self.channel_id, side, msgnum, data["message"], + time.time())) + self.db.execute("INSERT INTO `allocations`" + " (`channel_id`, `side`)" + " VALUES (?,?)", + (self.channel_id, side)) + self.db.commit() self.message_added(side, msgnum, data["message"]) request.setHeader("content-type", "application/json; charset=utf-8") @@ -213,6 +232,9 @@ class Relay(resource.Resource): return self.channels[channel_id] def free_child(self, channel_id): + self.db.execute("DELETE FROM `messages` WHERE `channel_id`=?", + (channel_id,)) + self.db.commit() self.channels.pop(channel_id) log.msg("freed #%d, now have %d channels" % (channel_id, len(self.channels)))