relay: add database, not used yet

This commit is contained in:
Brian Warner 2015-04-10 11:15:27 -05:00
parent c3b048a4d3
commit 043392ee2a
4 changed files with 69 additions and 6 deletions

1
.gitignore vendored
View File

@ -55,3 +55,4 @@ docs/_build/
# PyBuilder # PyBuilder
target/ target/
/twistd.pid /twistd.pid
/relay.sqlite

40
src/wormhole/database.py Normal file
View File

@ -0,0 +1,40 @@
import os, sys
import sqlite3
from pkg_resources import resource_string
class DBError(Exception):
pass
def get_schema(version):
return resource_string("wormhole", "db-schemas/v%d.sql" % version)
def get_db(dbfile, stderr=sys.stderr):
"""Open or create the given db file. The parent directory must exist.
Returns the db connection object, or raises DBError.
"""
must_create = not os.path.exists(dbfile)
try:
db = sqlite3.connect(dbfile)
except (EnvironmentError, sqlite3.OperationalError), e:
raise DBError("Unable to create/open db file %s: %s" % (dbfile, e))
db.row_factory = sqlite3.Row
VERSION = 1
if must_create:
schema = get_schema(VERSION)
db.executescript(schema)
db.execute("INSERT INTO version (version) VALUES (?)", (VERSION,))
db.commit()
try:
version = db.execute("SELECT version FROM version").fetchone()[0]
except sqlite3.DatabaseError, e:
# this indicates that the file is not a compatible database format.
# Perhaps it was created with an old version, or it might be junk.
raise DBError("db file is unusable: %s" % e)
if version != VERSION:
raise DBError("Unable to handle db version %s" % version)
return db

View File

@ -0,0 +1,18 @@
-- note: anything which isn't an boolean, integer, or human-readable unicode
-- string, (i.e. binary strings) will be stored as hex
CREATE TABLE `version`
(
`version` INTEGER -- contains one row, set to 1
);
CREATE TABLE `messages`
(
`channel_id` INTEGER,
`side` VARCHAR,
`msgnum` VARCHAR, -- not numeric, more of a PAKE-phase indicator string
`message` VARCHAR,
`when` INTEGER
);
CREATE INDEX `lookup` ON `messages` (`channel_id`, `side`, `msgnum`);

View File

@ -5,6 +5,7 @@ from twisted.internet import protocol
from twisted.application import strports, service, internet from twisted.application import strports, service, internet
from twisted.web import server, static, resource, http from twisted.web import server, static, resource, http
from .. import __version__ from .. import __version__
from ..database import get_db
SECONDS = 1.0 SECONDS = 1.0
MINUTE = 60*SECONDS MINUTE = 60*SECONDS
@ -58,10 +59,11 @@ class EventsProtocol:
class Channel(resource.Resource): class Channel(resource.Resource):
isLeaf = True # I handle /CHANNEL-ID/* isLeaf = True # I handle /CHANNEL-ID/*
def __init__(self, channel_id, relay, welcome): def __init__(self, channel_id, relay, db, welcome):
resource.Resource.__init__(self) resource.Resource.__init__(self)
self.channel_id = channel_id self.channel_id = channel_id
self.relay = relay self.relay = relay
self.db = db
self.welcome = welcome self.welcome = welcome
self.expire_at = time.time() + CHANNEL_EXPIRATION_TIME self.expire_at = time.time() + CHANNEL_EXPIRATION_TIME
self.sides = set() self.sides = set()
@ -145,7 +147,7 @@ class Allocator(resource.Resource):
#side = request.postpath[0] #side = request.postpath[0]
channel_id = self.relay.allocate_channel_id() channel_id = self.relay.allocate_channel_id()
self.relay.channels[channel_id] = Channel(channel_id, self.relay, self.relay.channels[channel_id] = Channel(channel_id, self.relay,
self.welcome) self.relay.db, self.welcome)
log.msg("allocated #%d, now have %d channels" % log.msg("allocated #%d, now have %d channels" %
(channel_id, len(self.relay.channels))) (channel_id, len(self.relay.channels)))
request.setHeader("content-type", "application/json; charset=utf-8") request.setHeader("content-type", "application/json; charset=utf-8")
@ -163,8 +165,9 @@ class ChannelList(resource.Resource):
"channel-ids": self.channel_ids})+"\n" "channel-ids": self.channel_ids})+"\n"
class Relay(resource.Resource): class Relay(resource.Resource):
def __init__(self, welcome): def __init__(self, db, welcome):
resource.Resource.__init__(self) resource.Resource.__init__(self)
self.db = db
self.welcome = welcome self.welcome = welcome
self.channels = {} self.channels = {}
@ -205,7 +208,8 @@ class Relay(resource.Resource):
if not channel_id in self.channels: if not channel_id in self.channels:
log.msg("claimed #%d, now have %d channels" % log.msg("claimed #%d, now have %d channels" %
(channel_id, len(self.channels))) (channel_id, len(self.channels)))
self.channels[channel_id] = Channel(channel_id, self, self.welcome) self.channels[channel_id] = Channel(channel_id, self, self.db,
self.welcome)
return self.channels[channel_id] return self.channels[channel_id]
def free_child(self, channel_id): def free_child(self, channel_id):
@ -342,7 +346,7 @@ class Root(resource.Resource):
class RelayServer(service.MultiService): class RelayServer(service.MultiService):
def __init__(self, relayport, transitport, advertise_version): def __init__(self, relayport, transitport, advertise_version):
service.MultiService.__init__(self) service.MultiService.__init__(self)
self.db = get_db("relay.sqlite")
welcome = { welcome = {
"current_version": __version__, "current_version": __version__,
# adding .motd will cause all clients to display the message, # adding .motd will cause all clients to display the message,
@ -358,7 +362,7 @@ class RelayServer(service.MultiService):
site = server.Site(self.root) site = server.Site(self.root)
self.relayport_service = strports.service(relayport, site) self.relayport_service = strports.service(relayport, site)
self.relayport_service.setServiceParent(self) self.relayport_service.setServiceParent(self)
self.relay = Relay(welcome) # accessible from tests self.relay = Relay(self.db, welcome) # accessible from tests
self.root.putChild("wormhole-relay", self.relay) self.root.putChild("wormhole-relay", self.relay)
t = internet.TimerService(5*MINUTE, self.relay.prune_old_channels) t = internet.TimerService(5*MINUTE, self.relay.prune_old_channels)
t.setServiceParent(self) t.setServiceParent(self)