diff --git a/.gitignore b/.gitignore index 1f8a294..f198bc7 100644 --- a/.gitignore +++ b/.gitignore @@ -55,3 +55,4 @@ docs/_build/ # PyBuilder target/ /twistd.pid +/relay.sqlite diff --git a/src/wormhole/database.py b/src/wormhole/database.py new file mode 100644 index 0000000..b764a24 --- /dev/null +++ b/src/wormhole/database.py @@ -0,0 +1,40 @@ +import os, sys +import sqlite3 +from pkg_resources import resource_string + +class DBError(Exception): + pass + +def get_schema(version): + return resource_string("wormhole", "db-schemas/v%d.sql" % version) + +def get_db(dbfile, stderr=sys.stderr): + """Open or create the given db file. The parent directory must exist. + Returns the db connection object, or raises DBError. + """ + + must_create = not os.path.exists(dbfile) + try: + db = sqlite3.connect(dbfile) + except (EnvironmentError, sqlite3.OperationalError), e: + raise DBError("Unable to create/open db file %s: %s" % (dbfile, e)) + db.row_factory = sqlite3.Row + + VERSION = 1 + if must_create: + schema = get_schema(VERSION) + db.executescript(schema) + db.execute("INSERT INTO version (version) VALUES (?)", (VERSION,)) + db.commit() + + try: + version = db.execute("SELECT version FROM version").fetchone()[0] + except sqlite3.DatabaseError, e: + # this indicates that the file is not a compatible database format. + # Perhaps it was created with an old version, or it might be junk. + raise DBError("db file is unusable: %s" % e) + + if version != VERSION: + raise DBError("Unable to handle db version %s" % version) + + return db diff --git a/src/wormhole/db-schemas/v1.sql b/src/wormhole/db-schemas/v1.sql new file mode 100644 index 0000000..43785de --- /dev/null +++ b/src/wormhole/db-schemas/v1.sql @@ -0,0 +1,25 @@ + +-- note: anything which isn't an boolean, integer, or human-readable unicode +-- string, (i.e. binary strings) will be stored as hex + +CREATE TABLE `version` +( + `version` INTEGER -- contains one row, set to 1 +); + +CREATE TABLE `messages` +( + `channel_id` INTEGER, + `side` VARCHAR, + `msgnum` VARCHAR, -- not numeric, more of a PAKE-phase indicator string + `message` VARCHAR, + `when` INTEGER +); +CREATE INDEX `messages_idx` ON `messages` (`channel_id`, `side`, `msgnum`); + +CREATE TABLE `allocations` +( + `channel_id` INTEGER, + `side` VARCHAR +); +CREATE INDEX `allocations_idx` ON `allocations` (`channel_id`); diff --git a/src/wormhole/servers/relay.py b/src/wormhole/servers/relay.py index 53a4ca3..68dbb3e 100644 --- a/src/wormhole/servers/relay.py +++ b/src/wormhole/servers/relay.py @@ -5,13 +5,16 @@ from twisted.internet import protocol from twisted.application import strports, service, internet from twisted.web import server, static, resource, http from .. import __version__ +from ..database import get_db SECONDS = 1.0 MINUTE = 60*SECONDS HOUR = 60*MINUTE +DAY = 24*HOUR MB = 1000*1000 -CHANNEL_EXPIRATION_TIME = 1*HOUR +CHANNEL_EXPIRATION_TIME = 3*DAY +EXPIRATION_CHECK_PERIOD = 2*HOUR class EventsProtocol: def __init__(self, request): @@ -58,17 +61,14 @@ class EventsProtocol: class Channel(resource.Resource): isLeaf = True # I handle /CHANNEL-ID/* - def __init__(self, channel_id, relay, welcome): + def __init__(self, channel_id, relay, db, welcome): resource.Resource.__init__(self) self.channel_id = channel_id self.relay = relay + self.db = db self.welcome = welcome - self.expire_at = time.time() + CHANNEL_EXPIRATION_TIME - self.sides = set() - self.messages = [] # (side, msgnum, str) self.event_channels = set() # (side, msgnum, ep) - def render_GET(self, request): # rest of URL is: SIDE/poll/MSGNUM their_side = request.postpath[0] @@ -85,14 +85,17 @@ class Channel(resource.Resource): handle = (their_side, their_msgnum, ep) self.event_channels.add(handle) request.notifyFinish().addErrback(self._shutdown, handle) - for (msg_side, msg_msgnum, msg_str) in self.messages: - self.message_added(msg_side, msg_msgnum, msg_str, channels=[handle]) + for row in self.db.execute("SELECT * FROM `messages`" + " WHERE `channel_id`=?" + " ORDER BY `when` ASC", + (self.channel_id,)).fetchall(): + self.message_added(row["side"], row["msgnum"], row["message"], + channels=[handle]) return server.NOT_DONE_YET def _shutdown(self, _, handle): self.event_channels.discard(handle) - def message_added(self, msg_side, msg_msgnum, msg_str, channels=None): if channels is None: channels = self.event_channels @@ -101,19 +104,16 @@ class Channel(resource.Resource): data = json.dumps({ "side": msg_side, "message": msg_str }) their_ep.sendEvent(data) - def render_POST(self, request): # rest of URL is: SIDE/(MSGNUM|deallocate)/(post|poll) side = request.postpath[0] - self.sides.add(side) verb = request.postpath[1] if verb == "deallocate": - self.sides.remove(side) - if self.sides: - return "waiting\n" - self.relay.free_child(self.channel_id) - return "deleted\n" + deleted = self.relay.maybe_free_child(self.channel_id, side) + if deleted: + return "deleted\n" + return "waiting\n" if verb not in ("post", "poll"): request.setResponseCode(http.BAD_REQUEST) @@ -122,96 +122,139 @@ class Channel(resource.Resource): msgnum = request.postpath[2] other_messages = [] - for (msg_side, msg_msgnum, msg_str) in self.messages: - if msg_side != side and msg_msgnum == msgnum: - other_messages.append(msg_str) + for row in self.db.execute("SELECT `message` FROM `messages`" + " WHERE `channel_id`=? AND `side`!=?" + " AND `msgnum`=?" + " ORDER BY `when` ASC", + (self.channel_id, side, msgnum)).fetchall(): + other_messages.append(row["message"]) if verb == "post": data = json.load(request.content) - self.messages.append( (side, msgnum, data["message"]) ) + self.db.execute("INSERT INTO `messages`" + " (`channel_id`, `side`, `msgnum`, `message`, `when`)" + " VALUES (?,?,?,?,?)", + (self.channel_id, side, msgnum, data["message"], + time.time())) + self.db.execute("INSERT INTO `allocations`" + " (`channel_id`, `side`)" + " VALUES (?,?)", + (self.channel_id, side)) + self.db.commit() self.message_added(side, msgnum, data["message"]) request.setHeader("content-type", "application/json; charset=utf-8") return json.dumps({"welcome": self.welcome, "messages": other_messages})+"\n" +def get_allocated(db): + c = db.execute("SELECT DISTINCT `channel_id` FROM `allocations`") + return set([row["channel_id"] for row in c.fetchall()]) + class Allocator(resource.Resource): isLeaf = True - def __init__(self, relay, welcome): + def __init__(self, db, welcome): resource.Resource.__init__(self) - self.relay = relay + self.db = db self.welcome = welcome - def render_POST(self, request): - #side = request.postpath[0] - channel_id = self.relay.allocate_channel_id() - self.relay.channels[channel_id] = Channel(channel_id, self.relay, - self.welcome) - log.msg("allocated #%d, now have %d channels" % - (channel_id, len(self.relay.channels))) - request.setHeader("content-type", "application/json; charset=utf-8") - return json.dumps({"welcome": self.welcome, - "channel-id": channel_id})+"\n" - -class ChannelList(resource.Resource): - def __init__(self, channel_ids, welcome): - resource.Resource.__init__(self) - self.channel_ids = channel_ids - self.welcome = welcome - def render_GET(self, request): - request.setHeader("content-type", "application/json; charset=utf-8") - return json.dumps({"welcome": self.welcome, - "channel-ids": self.channel_ids})+"\n" - -class Relay(resource.Resource): - def __init__(self, welcome): - resource.Resource.__init__(self) - self.welcome = welcome - self.channels = {} - - def prune_old_channels(self): - now = time.time() - for channel_id in list(self.channels): - c = self.channels[channel_id] - if c.expire_at < now: - log.msg("expiring %d" % channel_id) - self.free_child(channel_id) def allocate_channel_id(self): + allocated = get_allocated(self.db) for size in range(1,4): # stick to 1-999 for now available = set() for cid in range(10**(size-1), 10**size): - if cid not in self.channels: + if cid not in allocated: available.add(cid) if available: return random.choice(list(available)) # ouch, 999 currently allocated. Try random ones for a while. for tries in range(1000): cid = random.randrange(1000, 1000*1000) - if cid not in self.channels: + if cid not in allocated: return cid raise ValueError("unable to find a free channel-id") + def render_POST(self, request): + side = request.postpath[0] + channel_id = self.allocate_channel_id() + self.db.execute("INSERT INTO `allocations` VALUES (?,?)", + (channel_id, side)) + self.db.commit() + log.msg("allocated #%d, now have %d DB channels" % + (channel_id, len(get_allocated(self.db)))) + request.setHeader("content-type", "application/json; charset=utf-8") + return json.dumps({"welcome": self.welcome, + "channel-id": channel_id})+"\n" + +class ChannelList(resource.Resource): + def __init__(self, db, welcome): + resource.Resource.__init__(self) + self.db = db + self.welcome = welcome + def render_GET(self, request): + c = self.db.execute("SELECT DISTINCT `channel_id` FROM `allocations`") + allocated = sorted(set([row["channel_id"] for row in c.fetchall()])) + request.setHeader("content-type", "application/json; charset=utf-8") + return json.dumps({"welcome": self.welcome, + "channel-ids": allocated})+"\n" + +class Relay(resource.Resource): + def __init__(self, db, welcome): + resource.Resource.__init__(self) + self.db = db + self.welcome = welcome + self.channels = {} + def getChild(self, path, request): if path == "allocate": - return Allocator(self, self.welcome) + return Allocator(self.db, self.welcome) if path == "list": - channel_ids = sorted(self.channels.keys()) - return ChannelList(channel_ids, self.welcome) + return ChannelList(self.db, self.welcome) if not re.search(r'^\d+$', path): return resource.ErrorPage(http.BAD_REQUEST, "invalid channel id", "invalid channel id") channel_id = int(path) if not channel_id in self.channels: - log.msg("claimed #%d, now have %d channels" % - (channel_id, len(self.channels))) - self.channels[channel_id] = Channel(channel_id, self, self.welcome) + log.msg("spawning #%d" % channel_id) + self.channels[channel_id] = Channel(channel_id, self, self.db, + self.welcome) return self.channels[channel_id] + def maybe_free_child(self, channel_id, side): + self.db.execute("DELETE FROM `allocations`" + " WHERE `channel_id`=? AND `side`=?", + (channel_id, side)) + self.db.commit() + remaining = self.db.execute("SELECT COUNT(*) FROM `allocations`" + " WHERE `channel_id`=?", + (channel_id,)).fetchone()[0] + if remaining: + return False + self.free_child(channel_id) + return True + def free_child(self, channel_id): - self.channels.pop(channel_id) - log.msg("freed #%d, now have %d channels" % - (channel_id, len(self.channels))) + self.db.execute("DELETE FROM `allocations` WHERE `channel_id`=?", + (channel_id,)) + self.db.execute("DELETE FROM `messages` WHERE `channel_id`=?", + (channel_id,)) + self.db.commit() + if channel_id in self.channels: + self.channels.pop(channel_id) + log.msg("freed+killed #%d, now have %d DB channels, %d live" % + (channel_id, len(get_allocated(self.db)), len(self.channels))) + + def prune_old_channels(self): + old = time.time() - CHANNEL_EXPIRATION_TIME + for channel_id in get_allocated(self.db): + c = self.db.execute("SELECT `when` FROM `messages`" + " WHERE `channel_id`=?" + " ORDER BY `when` DESC LIMIT 1", (channel_id,)) + rows = c.fetchall() + if not rows or (rows[0]["when"] < old): + log.msg("expiring %d" % channel_id) + self.free_child(channel_id) class TransitConnection(protocol.Protocol): def __init__(self): @@ -342,7 +385,7 @@ class Root(resource.Resource): class RelayServer(service.MultiService): def __init__(self, relayport, transitport, advertise_version): service.MultiService.__init__(self) - + self.db = get_db("relay.sqlite") welcome = { "current_version": __version__, # adding .motd will cause all clients to display the message, @@ -358,9 +401,10 @@ class RelayServer(service.MultiService): site = server.Site(self.root) self.relayport_service = strports.service(relayport, site) self.relayport_service.setServiceParent(self) - self.relay = Relay(welcome) # accessible from tests + self.relay = Relay(self.db, welcome) # accessible from tests self.root.putChild("wormhole-relay", self.relay) - t = internet.TimerService(5*MINUTE, self.relay.prune_old_channels) + t = internet.TimerService(EXPIRATION_CHECK_PERIOD, + self.relay.prune_old_channels) t.setServiceParent(self) self.transit = Transit() self.transit.setServiceParent(self) # for the timer