From 183303e11e5d341e6b6fad85a26ccbc05f0b77f4 Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Mon, 4 May 2015 18:13:14 -0700 Subject: [PATCH] rework expiration, prune after 3 days, check every 2 hours --- src/wormhole/servers/relay.py | 121 ++++++++++++++++++++-------------- 1 file changed, 70 insertions(+), 51 deletions(-) diff --git a/src/wormhole/servers/relay.py b/src/wormhole/servers/relay.py index 5599e5a..1d43e25 100644 --- a/src/wormhole/servers/relay.py +++ b/src/wormhole/servers/relay.py @@ -10,6 +10,7 @@ from ..database import get_db SECONDS = 1.0 MINUTE = 60*SECONDS HOUR = 60*MINUTE +DAY = 24*HOUR MB = 1000*1000 CHANNEL_EXPIRATION_TIME = 1*HOUR @@ -112,17 +113,10 @@ class Channel(resource.Resource): verb = request.postpath[1] if verb == "deallocate": - self.db.execute("DELETE FROM `allocations`" - " WHERE `channel_id`=? AND `side`=?", - (self.channel_id, side)) - self.db.commit() - remaining = self.db.execute("SELECT COUNT(*) FROM `allocations`" - " WHERE `channel_id`=?", - (self.channel_id,)).fetchone()[0] - if remaining: - return "waiting\n" - self.relay.free_child(self.channel_id) - return "deleted\n" + deleted = self.relay.maybe_free_child(self.channel_id, side) + if deleted: + return "deleted\n" + return "waiting\n" if verb not in ("post", "poll"): request.setResponseCode(http.BAD_REQUEST) @@ -156,19 +150,41 @@ class Channel(resource.Resource): return json.dumps({"welcome": self.welcome, "messages": other_messages})+"\n" +def get_allocated(db): + c = db.execute("SELECT DISTINCT `channel_id` FROM `allocations`") + return set([row["channel_id"] for row in c.fetchall()]) + class Allocator(resource.Resource): isLeaf = True - def __init__(self, relay, welcome): + def __init__(self, db, welcome): resource.Resource.__init__(self) - self.relay = relay + self.db = db self.welcome = welcome + + def allocate_channel_id(self): + allocated = get_allocated(self.db) + for size in range(1,4): # stick to 1-999 for now + available = set() + for cid in range(10**(size-1), 10**size): + if cid not in allocated: + available.add(cid) + if available: + return random.choice(list(available)) + # ouch, 999 currently allocated. Try random ones for a while. + for tries in range(1000): + cid = random.randrange(1000, 1000*1000) + if cid not in allocated: + return cid + raise ValueError("unable to find a free channel-id") + def render_POST(self, request): - #side = request.postpath[0] - channel_id = self.relay.allocate_channel_id() - self.relay.channels[channel_id] = Channel(channel_id, self.relay, - self.relay.db, self.welcome) - log.msg("allocated #%d, now have %d channels" % - (channel_id, len(self.relay.channels))) + side = request.postpath[0] + channel_id = self.allocate_channel_id() + self.db.execute("INSERT INTO `allocations` VALUES (?,?)", + (channel_id, side)) + self.db.commit() + log.msg("allocated #%d, now have %d DB channels" % + (channel_id, len(get_allocated(self.db)))) request.setHeader("content-type", "application/json; charset=utf-8") return json.dumps({"welcome": self.welcome, "channel-id": channel_id})+"\n" @@ -186,40 +202,17 @@ class ChannelList(resource.Resource): "channel-ids": allocated})+"\n" class Relay(resource.Resource): + PRUNE_AGE = 3*DAY # old channels expire after 3 days + def __init__(self, db, welcome): resource.Resource.__init__(self) self.db = db self.welcome = welcome self.channels = {} - def prune_old_channels(self): - now = time.time() - for channel_id in list(self.channels): - c = self.channels[channel_id] - if c.expire_at < now: - log.msg("expiring %d" % channel_id) - self.free_child(channel_id) - - def allocate_channel_id(self): - c = self.db.execute("SELECT DISTINCT `channel_id` FROM `allocations`") - allocated = set([row["channel_id"] for row in c.fetchall()]) - for size in range(1,4): # stick to 1-999 for now - available = set() - for cid in range(10**(size-1), 10**size): - if cid not in allocated: - available.add(cid) - if available: - return random.choice(list(available)) - # ouch, 999 currently allocated. Try random ones for a while. - for tries in range(1000): - cid = random.randrange(1000, 1000*1000) - if cid not in allocated: - return cid - raise ValueError("unable to find a free channel-id") - def getChild(self, path, request): if path == "allocate": - return Allocator(self, self.welcome) + return Allocator(self.db, self.welcome) if path == "list": return ChannelList(self.db, self.welcome) if not re.search(r'^\d+$', path): @@ -228,19 +221,45 @@ class Relay(resource.Resource): "invalid channel id") channel_id = int(path) if not channel_id in self.channels: - log.msg("claimed #%d, now have %d channels" % - (channel_id, len(self.channels))) + log.msg("spawning #%d" % channel_id) self.channels[channel_id] = Channel(channel_id, self, self.db, self.welcome) return self.channels[channel_id] + def maybe_free_child(self, channel_id, side): + self.db.execute("DELETE FROM `allocations`" + " WHERE `channel_id`=? AND `side`=?", + (channel_id, side)) + self.db.commit() + remaining = self.db.execute("SELECT COUNT(*) FROM `allocations`" + " WHERE `channel_id`=?", + (channel_id,)).fetchone()[0] + if remaining: + return False + self.free_child(channel_id) + return True + def free_child(self, channel_id): + self.db.execute("DELETE FROM `allocations` WHERE `channel_id`=?", + (channel_id,)) self.db.execute("DELETE FROM `messages` WHERE `channel_id`=?", (channel_id,)) self.db.commit() - self.channels.pop(channel_id) - log.msg("freed #%d, now have %d channels" % - (channel_id, len(self.channels))) + if channel_id in self.channels: + self.channels.pop(channel_id) + log.msg("freed+killed #%d, now have %d DB channels, %d live" % + (channel_id, len(get_allocated(self.db)), len(self.channels))) + + def prune_old_channels(self): + old = time.time() - self.PRUNE_AGE + for channel_id in get_allocated(self.db): + c = self.db.execute("SELECT `when` FROM `messages`" + " WHERE `channel_id`=?" + " ORDER BY `when` DESC LIMIT 1", (channel_id,)) + rows = c.fetchall() + if not rows or (rows[0]["when"] < old): + log.msg("expiring %d" % channel_id) + self.free_child(channel_id) class TransitConnection(protocol.Protocol): def __init__(self): @@ -389,7 +408,7 @@ class RelayServer(service.MultiService): self.relayport_service.setServiceParent(self) self.relay = Relay(self.db, welcome) # accessible from tests self.root.putChild("wormhole-relay", self.relay) - t = internet.TimerService(5*MINUTE, self.relay.prune_old_channels) + t = internet.TimerService(2*HOUR, self.relay.prune_old_channels) t.setServiceParent(self) self.transit = Transit() self.transit.setServiceParent(self) # for the timer