use sqlite to track channel allocation

This commit is contained in:
Brian Warner 2015-05-05 00:14:56 -07:00
commit ec90ef43da
4 changed files with 180 additions and 70 deletions

1
.gitignore vendored
View File

@ -55,3 +55,4 @@ docs/_build/
# PyBuilder # PyBuilder
target/ target/
/twistd.pid /twistd.pid
/relay.sqlite

40
src/wormhole/database.py Normal file
View File

@ -0,0 +1,40 @@
import os, sys
import sqlite3
from pkg_resources import resource_string
class DBError(Exception):
pass
def get_schema(version):
return resource_string("wormhole", "db-schemas/v%d.sql" % version)
def get_db(dbfile, stderr=sys.stderr):
"""Open or create the given db file. The parent directory must exist.
Returns the db connection object, or raises DBError.
"""
must_create = not os.path.exists(dbfile)
try:
db = sqlite3.connect(dbfile)
except (EnvironmentError, sqlite3.OperationalError), e:
raise DBError("Unable to create/open db file %s: %s" % (dbfile, e))
db.row_factory = sqlite3.Row
VERSION = 1
if must_create:
schema = get_schema(VERSION)
db.executescript(schema)
db.execute("INSERT INTO version (version) VALUES (?)", (VERSION,))
db.commit()
try:
version = db.execute("SELECT version FROM version").fetchone()[0]
except sqlite3.DatabaseError, e:
# this indicates that the file is not a compatible database format.
# Perhaps it was created with an old version, or it might be junk.
raise DBError("db file is unusable: %s" % e)
if version != VERSION:
raise DBError("Unable to handle db version %s" % version)
return db

View File

@ -0,0 +1,25 @@
-- note: anything which isn't an boolean, integer, or human-readable unicode
-- string, (i.e. binary strings) will be stored as hex
CREATE TABLE `version`
(
`version` INTEGER -- contains one row, set to 1
);
CREATE TABLE `messages`
(
`channel_id` INTEGER,
`side` VARCHAR,
`msgnum` VARCHAR, -- not numeric, more of a PAKE-phase indicator string
`message` VARCHAR,
`when` INTEGER
);
CREATE INDEX `messages_idx` ON `messages` (`channel_id`, `side`, `msgnum`);
CREATE TABLE `allocations`
(
`channel_id` INTEGER,
`side` VARCHAR
);
CREATE INDEX `allocations_idx` ON `allocations` (`channel_id`);

View File

@ -5,13 +5,16 @@ from twisted.internet import protocol
from twisted.application import strports, service, internet from twisted.application import strports, service, internet
from twisted.web import server, static, resource, http from twisted.web import server, static, resource, http
from .. import __version__ from .. import __version__
from ..database import get_db
SECONDS = 1.0 SECONDS = 1.0
MINUTE = 60*SECONDS MINUTE = 60*SECONDS
HOUR = 60*MINUTE HOUR = 60*MINUTE
DAY = 24*HOUR
MB = 1000*1000 MB = 1000*1000
CHANNEL_EXPIRATION_TIME = 1*HOUR CHANNEL_EXPIRATION_TIME = 3*DAY
EXPIRATION_CHECK_PERIOD = 2*HOUR
class EventsProtocol: class EventsProtocol:
def __init__(self, request): def __init__(self, request):
@ -58,17 +61,14 @@ class EventsProtocol:
class Channel(resource.Resource): class Channel(resource.Resource):
isLeaf = True # I handle /CHANNEL-ID/* isLeaf = True # I handle /CHANNEL-ID/*
def __init__(self, channel_id, relay, welcome): def __init__(self, channel_id, relay, db, welcome):
resource.Resource.__init__(self) resource.Resource.__init__(self)
self.channel_id = channel_id self.channel_id = channel_id
self.relay = relay self.relay = relay
self.db = db
self.welcome = welcome self.welcome = welcome
self.expire_at = time.time() + CHANNEL_EXPIRATION_TIME
self.sides = set()
self.messages = [] # (side, msgnum, str)
self.event_channels = set() # (side, msgnum, ep) self.event_channels = set() # (side, msgnum, ep)
def render_GET(self, request): def render_GET(self, request):
# rest of URL is: SIDE/poll/MSGNUM # rest of URL is: SIDE/poll/MSGNUM
their_side = request.postpath[0] their_side = request.postpath[0]
@ -85,14 +85,17 @@ class Channel(resource.Resource):
handle = (their_side, their_msgnum, ep) handle = (their_side, their_msgnum, ep)
self.event_channels.add(handle) self.event_channels.add(handle)
request.notifyFinish().addErrback(self._shutdown, handle) request.notifyFinish().addErrback(self._shutdown, handle)
for (msg_side, msg_msgnum, msg_str) in self.messages: for row in self.db.execute("SELECT * FROM `messages`"
self.message_added(msg_side, msg_msgnum, msg_str, channels=[handle]) " WHERE `channel_id`=?"
" ORDER BY `when` ASC",
(self.channel_id,)).fetchall():
self.message_added(row["side"], row["msgnum"], row["message"],
channels=[handle])
return server.NOT_DONE_YET return server.NOT_DONE_YET
def _shutdown(self, _, handle): def _shutdown(self, _, handle):
self.event_channels.discard(handle) self.event_channels.discard(handle)
def message_added(self, msg_side, msg_msgnum, msg_str, channels=None): def message_added(self, msg_side, msg_msgnum, msg_str, channels=None):
if channels is None: if channels is None:
channels = self.event_channels channels = self.event_channels
@ -101,19 +104,16 @@ class Channel(resource.Resource):
data = json.dumps({ "side": msg_side, "message": msg_str }) data = json.dumps({ "side": msg_side, "message": msg_str })
their_ep.sendEvent(data) their_ep.sendEvent(data)
def render_POST(self, request): def render_POST(self, request):
# rest of URL is: SIDE/(MSGNUM|deallocate)/(post|poll) # rest of URL is: SIDE/(MSGNUM|deallocate)/(post|poll)
side = request.postpath[0] side = request.postpath[0]
self.sides.add(side)
verb = request.postpath[1] verb = request.postpath[1]
if verb == "deallocate": if verb == "deallocate":
self.sides.remove(side) deleted = self.relay.maybe_free_child(self.channel_id, side)
if self.sides: if deleted:
return "waiting\n" return "deleted\n"
self.relay.free_child(self.channel_id) return "waiting\n"
return "deleted\n"
if verb not in ("post", "poll"): if verb not in ("post", "poll"):
request.setResponseCode(http.BAD_REQUEST) request.setResponseCode(http.BAD_REQUEST)
@ -122,96 +122,139 @@ class Channel(resource.Resource):
msgnum = request.postpath[2] msgnum = request.postpath[2]
other_messages = [] other_messages = []
for (msg_side, msg_msgnum, msg_str) in self.messages: for row in self.db.execute("SELECT `message` FROM `messages`"
if msg_side != side and msg_msgnum == msgnum: " WHERE `channel_id`=? AND `side`!=?"
other_messages.append(msg_str) " AND `msgnum`=?"
" ORDER BY `when` ASC",
(self.channel_id, side, msgnum)).fetchall():
other_messages.append(row["message"])
if verb == "post": if verb == "post":
data = json.load(request.content) data = json.load(request.content)
self.messages.append( (side, msgnum, data["message"]) ) self.db.execute("INSERT INTO `messages`"
" (`channel_id`, `side`, `msgnum`, `message`, `when`)"
" VALUES (?,?,?,?,?)",
(self.channel_id, side, msgnum, data["message"],
time.time()))
self.db.execute("INSERT INTO `allocations`"
" (`channel_id`, `side`)"
" VALUES (?,?)",
(self.channel_id, side))
self.db.commit()
self.message_added(side, msgnum, data["message"]) self.message_added(side, msgnum, data["message"])
request.setHeader("content-type", "application/json; charset=utf-8") request.setHeader("content-type", "application/json; charset=utf-8")
return json.dumps({"welcome": self.welcome, return json.dumps({"welcome": self.welcome,
"messages": other_messages})+"\n" "messages": other_messages})+"\n"
def get_allocated(db):
c = db.execute("SELECT DISTINCT `channel_id` FROM `allocations`")
return set([row["channel_id"] for row in c.fetchall()])
class Allocator(resource.Resource): class Allocator(resource.Resource):
isLeaf = True isLeaf = True
def __init__(self, relay, welcome): def __init__(self, db, welcome):
resource.Resource.__init__(self) resource.Resource.__init__(self)
self.relay = relay self.db = db
self.welcome = welcome self.welcome = welcome
def render_POST(self, request):
#side = request.postpath[0]
channel_id = self.relay.allocate_channel_id()
self.relay.channels[channel_id] = Channel(channel_id, self.relay,
self.welcome)
log.msg("allocated #%d, now have %d channels" %
(channel_id, len(self.relay.channels)))
request.setHeader("content-type", "application/json; charset=utf-8")
return json.dumps({"welcome": self.welcome,
"channel-id": channel_id})+"\n"
class ChannelList(resource.Resource):
def __init__(self, channel_ids, welcome):
resource.Resource.__init__(self)
self.channel_ids = channel_ids
self.welcome = welcome
def render_GET(self, request):
request.setHeader("content-type", "application/json; charset=utf-8")
return json.dumps({"welcome": self.welcome,
"channel-ids": self.channel_ids})+"\n"
class Relay(resource.Resource):
def __init__(self, welcome):
resource.Resource.__init__(self)
self.welcome = welcome
self.channels = {}
def prune_old_channels(self):
now = time.time()
for channel_id in list(self.channels):
c = self.channels[channel_id]
if c.expire_at < now:
log.msg("expiring %d" % channel_id)
self.free_child(channel_id)
def allocate_channel_id(self): def allocate_channel_id(self):
allocated = get_allocated(self.db)
for size in range(1,4): # stick to 1-999 for now for size in range(1,4): # stick to 1-999 for now
available = set() available = set()
for cid in range(10**(size-1), 10**size): for cid in range(10**(size-1), 10**size):
if cid not in self.channels: if cid not in allocated:
available.add(cid) available.add(cid)
if available: if available:
return random.choice(list(available)) return random.choice(list(available))
# ouch, 999 currently allocated. Try random ones for a while. # ouch, 999 currently allocated. Try random ones for a while.
for tries in range(1000): for tries in range(1000):
cid = random.randrange(1000, 1000*1000) cid = random.randrange(1000, 1000*1000)
if cid not in self.channels: if cid not in allocated:
return cid return cid
raise ValueError("unable to find a free channel-id") raise ValueError("unable to find a free channel-id")
def render_POST(self, request):
side = request.postpath[0]
channel_id = self.allocate_channel_id()
self.db.execute("INSERT INTO `allocations` VALUES (?,?)",
(channel_id, side))
self.db.commit()
log.msg("allocated #%d, now have %d DB channels" %
(channel_id, len(get_allocated(self.db))))
request.setHeader("content-type", "application/json; charset=utf-8")
return json.dumps({"welcome": self.welcome,
"channel-id": channel_id})+"\n"
class ChannelList(resource.Resource):
def __init__(self, db, welcome):
resource.Resource.__init__(self)
self.db = db
self.welcome = welcome
def render_GET(self, request):
c = self.db.execute("SELECT DISTINCT `channel_id` FROM `allocations`")
allocated = sorted(set([row["channel_id"] for row in c.fetchall()]))
request.setHeader("content-type", "application/json; charset=utf-8")
return json.dumps({"welcome": self.welcome,
"channel-ids": allocated})+"\n"
class Relay(resource.Resource):
def __init__(self, db, welcome):
resource.Resource.__init__(self)
self.db = db
self.welcome = welcome
self.channels = {}
def getChild(self, path, request): def getChild(self, path, request):
if path == "allocate": if path == "allocate":
return Allocator(self, self.welcome) return Allocator(self.db, self.welcome)
if path == "list": if path == "list":
channel_ids = sorted(self.channels.keys()) return ChannelList(self.db, self.welcome)
return ChannelList(channel_ids, self.welcome)
if not re.search(r'^\d+$', path): if not re.search(r'^\d+$', path):
return resource.ErrorPage(http.BAD_REQUEST, return resource.ErrorPage(http.BAD_REQUEST,
"invalid channel id", "invalid channel id",
"invalid channel id") "invalid channel id")
channel_id = int(path) channel_id = int(path)
if not channel_id in self.channels: if not channel_id in self.channels:
log.msg("claimed #%d, now have %d channels" % log.msg("spawning #%d" % channel_id)
(channel_id, len(self.channels))) self.channels[channel_id] = Channel(channel_id, self, self.db,
self.channels[channel_id] = Channel(channel_id, self, self.welcome) self.welcome)
return self.channels[channel_id] return self.channels[channel_id]
def maybe_free_child(self, channel_id, side):
self.db.execute("DELETE FROM `allocations`"
" WHERE `channel_id`=? AND `side`=?",
(channel_id, side))
self.db.commit()
remaining = self.db.execute("SELECT COUNT(*) FROM `allocations`"
" WHERE `channel_id`=?",
(channel_id,)).fetchone()[0]
if remaining:
return False
self.free_child(channel_id)
return True
def free_child(self, channel_id): def free_child(self, channel_id):
self.channels.pop(channel_id) self.db.execute("DELETE FROM `allocations` WHERE `channel_id`=?",
log.msg("freed #%d, now have %d channels" % (channel_id,))
(channel_id, len(self.channels))) self.db.execute("DELETE FROM `messages` WHERE `channel_id`=?",
(channel_id,))
self.db.commit()
if channel_id in self.channels:
self.channels.pop(channel_id)
log.msg("freed+killed #%d, now have %d DB channels, %d live" %
(channel_id, len(get_allocated(self.db)), len(self.channels)))
def prune_old_channels(self):
old = time.time() - CHANNEL_EXPIRATION_TIME
for channel_id in get_allocated(self.db):
c = self.db.execute("SELECT `when` FROM `messages`"
" WHERE `channel_id`=?"
" ORDER BY `when` DESC LIMIT 1", (channel_id,))
rows = c.fetchall()
if not rows or (rows[0]["when"] < old):
log.msg("expiring %d" % channel_id)
self.free_child(channel_id)
class TransitConnection(protocol.Protocol): class TransitConnection(protocol.Protocol):
def __init__(self): def __init__(self):
@ -342,7 +385,7 @@ class Root(resource.Resource):
class RelayServer(service.MultiService): class RelayServer(service.MultiService):
def __init__(self, relayport, transitport, advertise_version): def __init__(self, relayport, transitport, advertise_version):
service.MultiService.__init__(self) service.MultiService.__init__(self)
self.db = get_db("relay.sqlite")
welcome = { welcome = {
"current_version": __version__, "current_version": __version__,
# adding .motd will cause all clients to display the message, # adding .motd will cause all clients to display the message,
@ -358,9 +401,10 @@ class RelayServer(service.MultiService):
site = server.Site(self.root) site = server.Site(self.root)
self.relayport_service = strports.service(relayport, site) self.relayport_service = strports.service(relayport, site)
self.relayport_service.setServiceParent(self) self.relayport_service.setServiceParent(self)
self.relay = Relay(welcome) # accessible from tests self.relay = Relay(self.db, welcome) # accessible from tests
self.root.putChild("wormhole-relay", self.relay) self.root.putChild("wormhole-relay", self.relay)
t = internet.TimerService(5*MINUTE, self.relay.prune_old_channels) t = internet.TimerService(EXPIRATION_CHECK_PERIOD,
self.relay.prune_old_channels)
t.setServiceParent(self) t.setServiceParent(self)
self.transit = Transit() self.transit = Transit()
self.transit.setServiceParent(self) # for the timer self.transit.setServiceParent(self) # for the timer