diff --git a/.gitignore b/.gitignore index 1f8a294..f198bc7 100644 --- a/.gitignore +++ b/.gitignore @@ -55,3 +55,4 @@ docs/_build/ # PyBuilder target/ /twistd.pid +/relay.sqlite diff --git a/src/wormhole/database.py b/src/wormhole/database.py new file mode 100644 index 0000000..b764a24 --- /dev/null +++ b/src/wormhole/database.py @@ -0,0 +1,40 @@ +import os, sys +import sqlite3 +from pkg_resources import resource_string + +class DBError(Exception): + pass + +def get_schema(version): + return resource_string("wormhole", "db-schemas/v%d.sql" % version) + +def get_db(dbfile, stderr=sys.stderr): + """Open or create the given db file. The parent directory must exist. + Returns the db connection object, or raises DBError. + """ + + must_create = not os.path.exists(dbfile) + try: + db = sqlite3.connect(dbfile) + except (EnvironmentError, sqlite3.OperationalError), e: + raise DBError("Unable to create/open db file %s: %s" % (dbfile, e)) + db.row_factory = sqlite3.Row + + VERSION = 1 + if must_create: + schema = get_schema(VERSION) + db.executescript(schema) + db.execute("INSERT INTO version (version) VALUES (?)", (VERSION,)) + db.commit() + + try: + version = db.execute("SELECT version FROM version").fetchone()[0] + except sqlite3.DatabaseError, e: + # this indicates that the file is not a compatible database format. + # Perhaps it was created with an old version, or it might be junk. + raise DBError("db file is unusable: %s" % e) + + if version != VERSION: + raise DBError("Unable to handle db version %s" % version) + + return db diff --git a/src/wormhole/db-schemas/v1.sql b/src/wormhole/db-schemas/v1.sql new file mode 100644 index 0000000..27262ea --- /dev/null +++ b/src/wormhole/db-schemas/v1.sql @@ -0,0 +1,18 @@ + +-- note: anything which isn't an boolean, integer, or human-readable unicode +-- string, (i.e. binary strings) will be stored as hex + +CREATE TABLE `version` +( + `version` INTEGER -- contains one row, set to 1 +); + +CREATE TABLE `messages` +( + `channel_id` INTEGER, + `side` VARCHAR, + `msgnum` VARCHAR, -- not numeric, more of a PAKE-phase indicator string + `message` VARCHAR, + `when` INTEGER +); +CREATE INDEX `lookup` ON `messages` (`channel_id`, `side`, `msgnum`); diff --git a/src/wormhole/servers/relay.py b/src/wormhole/servers/relay.py index 53a4ca3..dafba87 100644 --- a/src/wormhole/servers/relay.py +++ b/src/wormhole/servers/relay.py @@ -5,6 +5,7 @@ from twisted.internet import protocol from twisted.application import strports, service, internet from twisted.web import server, static, resource, http from .. import __version__ +from ..database import get_db SECONDS = 1.0 MINUTE = 60*SECONDS @@ -58,10 +59,11 @@ class EventsProtocol: class Channel(resource.Resource): isLeaf = True # I handle /CHANNEL-ID/* - def __init__(self, channel_id, relay, welcome): + def __init__(self, channel_id, relay, db, welcome): resource.Resource.__init__(self) self.channel_id = channel_id self.relay = relay + self.db = db self.welcome = welcome self.expire_at = time.time() + CHANNEL_EXPIRATION_TIME self.sides = set() @@ -145,7 +147,7 @@ class Allocator(resource.Resource): #side = request.postpath[0] channel_id = self.relay.allocate_channel_id() self.relay.channels[channel_id] = Channel(channel_id, self.relay, - self.welcome) + self.relay.db, self.welcome) log.msg("allocated #%d, now have %d channels" % (channel_id, len(self.relay.channels))) request.setHeader("content-type", "application/json; charset=utf-8") @@ -163,8 +165,9 @@ class ChannelList(resource.Resource): "channel-ids": self.channel_ids})+"\n" class Relay(resource.Resource): - def __init__(self, welcome): + def __init__(self, db, welcome): resource.Resource.__init__(self) + self.db = db self.welcome = welcome self.channels = {} @@ -205,7 +208,8 @@ class Relay(resource.Resource): if not channel_id in self.channels: log.msg("claimed #%d, now have %d channels" % (channel_id, len(self.channels))) - self.channels[channel_id] = Channel(channel_id, self, self.welcome) + self.channels[channel_id] = Channel(channel_id, self, self.db, + self.welcome) return self.channels[channel_id] def free_child(self, channel_id): @@ -342,7 +346,7 @@ class Root(resource.Resource): class RelayServer(service.MultiService): def __init__(self, relayport, transitport, advertise_version): service.MultiService.__init__(self) - + self.db = get_db("relay.sqlite") welcome = { "current_version": __version__, # adding .motd will cause all clients to display the message, @@ -358,7 +362,7 @@ class RelayServer(service.MultiService): site = server.Site(self.root) self.relayport_service = strports.service(relayport, site) self.relayport_service.setServiceParent(self) - self.relay = Relay(welcome) # accessible from tests + self.relay = Relay(self.db, welcome) # accessible from tests self.root.putChild("wormhole-relay", self.relay) t = internet.TimerService(5*MINUTE, self.relay.prune_old_channels) t.setServiceParent(self)