Compare commits

...

2 Commits

Author SHA1 Message Date
Brian Warner
790f29d4ba WIP db 2017-11-04 12:40:54 -07:00
Brian Warner
d7800f6337 rewrite connection handling, not sure it's a good idea 2017-11-04 12:23:32 -07:00
5 changed files with 377 additions and 152 deletions

8
docs/running.md Normal file
View File

@ -0,0 +1,8 @@
# How to Run the Transit Relay
```
pip install magic-wormhole-transit-relay
twist wormhole-transit-relay --port tcp:4001
```
The relay runs as a twist/twistd plugin. To

View File

@ -0,0 +1,126 @@
from __future__ import unicode_literals
import os
import sqlite3
import tempfile
from pkg_resources import resource_string
from twisted.python import log
class DBError(Exception):
    """Raised when the database file cannot be opened, created, or upgraded."""
    pass
def get_schema(version):
    """Return the full schema text (CREATE TABLE statements) for *version*.

    Loaded from the package's ``db-schemas/vN.sql`` resource.
    """
    # NOTE(review): the package name "wormhole.server" looks copied from
    # magic-wormhole proper; confirm it matches this package's layout.
    resource_name = "db-schemas/v%d.sql" % version
    return resource_string("wormhole.server", resource_name).decode("utf-8")
def get_upgrader(new_version):
    """Return the SQL text that upgrades a database to *new_version*.

    Loaded from the package's ``db-schemas/upgrade-to-vN.sql`` resource.
    """
    resource_name = "db-schemas/upgrade-to-v%d.sql" % new_version
    return resource_string("wormhole.server", resource_name).decode("utf-8")
TARGET_VERSION = 3  # schema version that get_db() creates/upgrades to by default
def dict_factory(cursor, row):
    """sqlite3 row factory: return each row as a column-name -> value dict."""
    # cursor.description yields 7-tuples whose first item is the column name
    return {col[0]: value
            for col, value in zip(cursor.description, row)}
def _initialize_db_schema(db, target_version):
    """Create the application schema in *db* and record its version.

    Commits the transaction before returning.
    """
    log.msg("populating new database with schema v%s" % target_version)
    db.executescript(get_schema(target_version))
    db.execute("INSERT INTO version (version) VALUES (?)",
               (target_version,))
    db.commit()
def _initialize_db_connection(db):
    """Prepare a freshly-opened connection: dict-style rows and enforced
    foreign keys.

    Raises DBError if the existing data violates a foreign-key constraint.
    """
    db.row_factory = dict_factory
    db.execute("PRAGMA foreign_keys = ON")
    violations = db.execute("PRAGMA foreign_key_check").fetchall()
    if violations:
        raise DBError("failed foreign key check: %s" % (violations,))
def _open_db_connection(dbfile):
    """Open a new connection to the SQLite3 database at the given path.

    The connection is initialized (row factory, foreign keys) before being
    returned. Raises DBError if the file cannot be opened or created.
    """
    try:
        connection = sqlite3.connect(dbfile)
    except (EnvironmentError, sqlite3.OperationalError) as e:
        raise DBError("Unable to create/open db file %s: %s" % (dbfile, e))
    _initialize_db_connection(connection)
    return connection
def _get_temporary_dbfile(dbfile):
"""Get a temporary filename near the given path.
"""
fd, name = tempfile.mkstemp(
prefix=os.path.basename(dbfile) + ".",
dir=os.path.dirname(dbfile)
)
os.close(fd)
return name
def _atomic_create_and_initialize_db(dbfile, target_version):
    """Create and return a new database, initialized with the application
    schema.

    The schema is built in a temporary file in the same directory, which is
    renamed onto ``dbfile`` only once it is complete — so if anything goes
    wrong, nothing is left at the ``dbfile`` path, and the temporary file
    is cleaned up as well.
    """
    temp_dbfile = _get_temporary_dbfile(dbfile)
    db = _open_db_connection(temp_dbfile)
    try:
        _initialize_db_schema(db, target_version)
    except Exception:
        # the original code leaked the temp file on a failed initialization
        db.close()
        os.remove(temp_dbfile)
        raise
    db.close()
    os.rename(temp_dbfile, dbfile)
    return _open_db_connection(dbfile)
def get_db(dbfile, target_version=TARGET_VERSION):
    """Open or create the given db file. The parent directory must exist.

    An existing database older than *target_version* is upgraded in place,
    one schema version at a time. Returns the db connection object, or
    raises DBError (unusable file, missing upgrader, or unhandled version).
    """
    if dbfile == ":memory:":
        db = _open_db_connection(dbfile)
        _initialize_db_schema(db, target_version)
    elif os.path.exists(dbfile):
        db = _open_db_connection(dbfile)
    else:
        db = _atomic_create_and_initialize_db(dbfile, target_version)

    try:
        version = db.execute("SELECT version FROM version").fetchone()["version"]
    except sqlite3.DatabaseError as e:
        # this indicates that the file is not a compatible database format.
        # Perhaps it was created with an old version, or it might be junk.
        raise DBError("db file is unusable: %s" % e)

    while version < target_version:
        log.msg(" need to upgrade from %s to %s" % (version, target_version))
        try:
            upgrader = get_upgrader(version+1)
        except (ValueError, EnvironmentError):
            # a missing upgrade-to-vN.sql resource surfaces from
            # resource_string as an EnvironmentError (IOError/OSError);
            # the previous "except ValueError" never caught that case
            log.msg(" unable to upgrade %s to %s" % (version, version+1))
            raise DBError("Unable to upgrade %s to version %s, left at %s"
                          % (dbfile, version+1, version))
        log.msg(" executing upgrader v%s->v%s" % (version, version+1))
        db.executescript(upgrader)
        db.commit()
        version = version+1

    if version != target_version:
        # e.g. the file claims a version newer than this code understands
        raise DBError("Unable to handle db version %s" % version)

    return db
def dump_db(db):
    """Return the entire database as one SQL text dump.

    ``iterdump()`` needs the stock row factory, so the connection's factory
    is swapped out for the duration and restored afterwards.
    """
    saved_factory = db.row_factory
    db.row_factory = sqlite3.Row
    try:
        return "".join(db.iterdump())
    finally:
        db.row_factory = saved_factory

View File

@ -0,0 +1,30 @@
CREATE TABLE `version` -- contains one row
(
 `version` INTEGER -- set to 1
);
CREATE TABLE `current` -- contains one row
(
 `reboot` INTEGER, -- seconds since epoch of most recent reboot
 `last_update` INTEGER, -- when `current` was last updated
 `connected` INTEGER, -- number of current paired connections
 `waiting` INTEGER, -- number of not-yet-paired connections
 `incomplete_bytes` INTEGER -- bytes sent through not-yet-complete connections
);
CREATE TABLE `usage`
(
 `started` INTEGER, -- seconds since epoch, rounded to "blur time"
 `total_time` INTEGER, -- seconds from open to last close
 `waiting_time` INTEGER, -- seconds from start to 2nd side appearing, or None
 `total_bytes` INTEGER, -- total bytes relayed (both directions)
 `result` VARCHAR -- happy, scary, lonely, errory, pruney
 -- transit moods:
 --  "errory": one side gave the wrong handshake
 --  "lonely": good handshake, but the other side never showed up
 --  "happy": both sides gave correct handshake
);
-- the indexes must reference the table created above, `usage`; the
-- original `transit_usage` does not exist and made this script fail
CREATE INDEX `transit_usage_idx` ON `usage` (`started`);
CREATE INDEX `transit_usage_result_idx` ON `usage` (`result`);

View File

@ -9,10 +9,10 @@ This plugin sets up a 'Transit Relay' server for magic-wormhole. This service
listens for TCP connections, finds pairs which present the same handshake, and listens for TCP connections, finds pairs which present the same handshake, and
glues the two TCP sockets together. glues the two TCP sockets together.
If --usage-logfile= is provided, a line will be written to the given file after If --log-stdout is provided, a line will be written to stdout after each
each connection is done. This line will be a complete JSON object (starting connection is done. This line will be a complete JSON object (starting with
with "{", ending with "}\n", and containing no internal newlines). The keys "{", ending with "}\n", and containing no internal newlines). The keys will
will be: be:
* 'started': number, seconds since epoch * 'started': number, seconds since epoch
* 'total_time': number, seconds from open to last close * 'total_time': number, seconds from open to last close
@ -27,35 +27,62 @@ second matching side never appeared (and thus 'waiting_time' will be null).
If --blur-usage= is provided, then 'started' will be rounded to the given time If --blur-usage= is provided, then 'started' will be rounded to the given time
interval, and 'total_bytes' will be rounded as well. interval, and 'total_bytes' will be rounded as well.
If --stats-file is provided, the server will periodically write a simple JSON If --usage-db= is provided, the server will maintain a SQLite database in the
dictionary to that file (atomically), with cumulative usage data (since last given file. Current, recent, and historical usage data will be written to the
reboot, and all-time). This information is *not* blurred (the assumption is database, and external tools can query the DB for metrics: the munin plugins
that it will be overwritten on a regular basis, and is aggregated anyways). The in misc/ may be useful. Timestamps and sizes in this file will respect
keys are: --blur-usage. The four tables are:
* active.connected: number of paired connections "current" contains a single row, with these columns:
* active.waiting: number of not-yet-paired connections
* since_reboot.bytes: sum of 'total_bytes'
* since_reboot.total: number of completed connections
* since_reboot.moods: dict mapping mood string to number of connections
* all_time.bytes: same
* all_time.total
* all_time.moods
The server will write twistd.pid and twistd.log files as usual, if daemonized * connected: number of paired connections
by twistd. twistd.log will only contain startup, shutdown, and exception * waiting: number of not-yet-paired connections
messages. To record information about each connection, use --usage-logfile. * incomplete_bytes: bytes transmitted over not-yet-complete connections
"since_reboot" contains a single row, with these columns:
* bytes: sum of 'total_bytes'
* connections: number of completed connections
* mood_happy: count of connections that finished "happy": both sides gave correct handshake
* mood_lonely: one side gave good handshake, other side never showed up
* mood_errory: one side gave a bad handshake
"all_time" contains a single row, with these columns:
* bytes:
* connections:
* mood_happy:
* mood_lonely:
* mood_errory:
"usage" contains one row per closed connection, with these columns:
* started: seconds since epoch, rounded to "blur time"
* total_time: seconds from first open to last close
* waiting_time: seconds from first open to second open, or None
* bytes: total bytes relayed (in both directions)
* result: (string) the mood: happy, lonely, errory
All tables will be updated after each connection is finished. In addition,
the "current" table will be updated at least once every 5 minutes.
If daemonized by twistd, the server will write twistd.pid and twistd.log
files as usual. By default twistd.log will only contain startup, shutdown,
and exception messages. Adding --log-stdout will add per-connection JSON
lines to twistd.log.
""" """
class Options(usage.Options): class Options(usage.Options):
#synopsis = "[--port=] [--usage-logfile=] [--blur-usage=] [--stats-json=]" #synopsis = "[--port=] [--log-stdout] [--blur-usage=] [--usage-db=]"
longdesc = LONGDESC longdesc = LONGDESC
optFlags = {
("log-stdout", None, "write JSON usage logs to stdout"),
}
optParameters = [ optParameters = [
("port", "p", "tcp:4001", "endpoint to listen on"), ("port", "p", "tcp:4001", "endpoint to listen on"),
("blur-usage", None, None, "blur timestamps and data sizes in logs"), ("blur-usage", None, None, "blur timestamps and data sizes in logs"),
("usage-logfile", None, None, "record usage data (JSON lines)"), ("usage-db", None, None, "record usage data (SQLite)"),
("stats-file", None, None, "record usage in JSON format"),
] ]
def opt_blur_usage(self, arg): def opt_blur_usage(self, arg):
@ -65,6 +92,6 @@ class Options(usage.Options):
def makeService(config, reactor=reactor): def makeService(config, reactor=reactor):
ep = endpoints.serverFromString(reactor, config["port"]) # to listen ep = endpoints.serverFromString(reactor, config["port"]) # to listen
f = transit_server.Transit(blur_usage=config["blur-usage"], f = transit_server.Transit(blur_usage=config["blur-usage"],
usage_logfile=config["usage-logfile"], log_stdout=config["log-stdout"],
stats_file=config["stats-file"]) usage_db=config["usage-db"])
return StreamServerEndpointService(ep, f) return StreamServerEndpointService(ep, f)

View File

@ -2,6 +2,7 @@ from __future__ import print_function, unicode_literals
import os, re, time, json import os, re, time, json
from twisted.python import log from twisted.python import log
from twisted.internet import protocol from twisted.internet import protocol
from .database import get_db
SECONDS = 1.0 SECONDS = 1.0
MINUTE = 60*SECONDS MINUTE = 60*SECONDS
@ -28,7 +29,6 @@ class TransitConnection(protocol.Protocol):
self._token_buffer = b"" self._token_buffer = b""
self._sent_ok = False self._sent_ok = False
self._buddy = None self._buddy = None
self._had_buddy = False
self._total_sent = 0 self._total_sent = 0
def describeToken(self): def describeToken(self):
@ -131,7 +131,7 @@ class TransitConnection(protocol.Protocol):
def _got_handshake(self, token, side): def _got_handshake(self, token, side):
self._got_token = token self._got_token = token
self._got_side = side self._got_side = side
self.factory.connection_got_token(token, side, self) self.factory.transitGotToken(token, side, self)
def buddy_connected(self, them): def buddy_connected(self, them):
self._buddy = them self._buddy = them
@ -153,39 +153,57 @@ class TransitConnection(protocol.Protocol):
def connectionLost(self, reason): def connectionLost(self, reason):
if self._buddy: if self._buddy:
self._buddy.buddy_disconnected() self._buddy.buddy_disconnected() # hang up on the buddy
self.factory.transitFinished(self, self._got_token, self._got_side, self.factory.transitFinished(self, self._got_token, self._got_side,
self.describeToken()) self.describeToken())
# Record usage. There are five cases:
# * 1: we connected, never had a buddy
# * 2: we connected first, we disconnect before the buddy
# * 3: we connected first, buddy disconnects first
# * 4: buddy connected first, we disconnect before buddy
# * 5: buddy connected first, buddy disconnects first
# whoever disconnects first gets to write the usage record (1,2,4)
finished = time.time()
if not self._had_buddy: # 1
total_time = finished - self._started
self.factory.recordUsage(self._started, "lonely", 0,
total_time, None)
if self._had_buddy and self._buddy: # 2,4
total_bytes = self._total_sent + self._buddy._total_sent
starts = [self._started, self._buddy._started]
total_time = finished - min(starts)
waiting_time = max(starts) - min(starts)
self.factory.recordUsage(self._started, "happy", total_bytes,
total_time, waiting_time)
def disconnect(self): def disconnect(self):
# called when we hang up on a connection because they violated the
# protocol, or we abandon a losing connection because a different one
# from that side won
self.transport.loseConnection() self.transport.loseConnection()
self.factory.transitFailed(self) self.factory.transitFailed(self)
finished = time.time()
total_time = finished - self._started PENDING, OPEN, LINGERING, EMPTY = range(4)
self.factory.recordUsage(self._started, "errory", 0,
total_time, None) class Channel(object):
def __init__(self, factory):
self._factory = factory
self._state = PENDING
self._connections = set() # (side, tc)
def gotConnection(self, side, tc):
if self._state == PENDING:
for old in self._connections:
(old_side, old_tc) = old
if ((old_side is None)
or (side is None)
or (old_side != side)):
# we found a match
if self._debug_log:
log.msg("transit relay 2: %s" % new_tc.describeToken())
self._state = OPEN
self._factory.channelOpen(self)
# drop and stop tracking the rest
self._connections.remove(old)
for (_, leftover_tc) in self._connections:
# TODO: not "errory"? the ones we drop are the parallel
# connections from the first client ('side' was the
# same), so it's not really an error. More of a "you
# lost, one of your other connections won, sorry"
leftover_tc.disconnect()
self._pending_tokens.pop(token)
# glue the two ends together
self._active_connections.add(new_tc)
self._active_connections.add(old_tc)
new_tc.buddy_connected(old_tc)
old_tc.buddy_connected(new_tc)
return
if self._debug_log:
log.msg("transit relay 1: %s" % new_tc.describeToken())
potentials.add((new_side, new_tc))
# TODO: timer
class Transit(protocol.ServerFactory): class Transit(protocol.ServerFactory):
# I manage pairs of simultaneous connections to a secondary TCP port, # I manage pairs of simultaneous connections to a secondary TCP port,
@ -220,119 +238,135 @@ class Transit(protocol.ServerFactory):
MAXTIME = 60*SECONDS MAXTIME = 60*SECONDS
protocol = TransitConnection protocol = TransitConnection
def __init__(self, blur_usage, usage_logfile, stats_file): def __init__(self, blur_usage, log_stdout, usage_db):
self._blur_usage = blur_usage self._blur_usage = blur_usage
self._log_requests = blur_usage is None self._debug_log = False
self._usage_logfile = open(usage_logfile, "a") if usage_logfile else None self._log_stdout = log_stdout
self._stats_file = stats_file self._db = None
self._pending_requests = {} # token -> set((side, TransitConnection)) if usage_db:
self._active_connections = set() # TransitConnection self._db = get_db(usage_db)
self._counts = {"lonely": 0, "happy": 0, "errory": 0} # we don't track TransitConnections until they submit a token
self._count_bytes = 0
def connection_got_token(self, token, new_side, new_tc): # Channels are indexed by token, and are either pending, open, or
if token not in self._pending_requests: # lingering
self._pending_requests[token] = set() self._channels = {} # token -> Channel
potentials = self._pending_requests[token] self._pending_channels = set()
for old in potentials: self._open_channels = set()
(old_side, old_tc) = old self._lingering_channels = set()
if ((old_side is None)
or (new_side is None)
or (old_side != new_side)):
# we found a match
if self._log_requests:
log.msg("transit relay 2: %s" % new_tc.describeToken())
# drop and stop tracking the rest def transitGotToken(self, token, new_side, new_tc):
potentials.remove(old) if token not in self._channels:
for (_, leftover_tc) in potentials: self._channels[token] = Channel(self)
leftover_tc.disconnect() # TODO: not "errory"? self._channels[token].gotConnection(new_side, new_tc)
self._pending_requests.pop(token)
# glue the two ends together def channelOpen(self, c):
self._active_connections.add(new_tc) self._pending_channels.remove(c)
self._active_connections.add(old_tc) self._open_channels.add(c)
new_tc.buddy_connected(old_tc) def channelClosed(self, c):
old_tc.buddy_connected(new_tc) self._open_channels.remove(c)
return self._lingering_channels.add(c)
if self._log_requests: def channelEmpty(self, c):
log.msg("transit relay 1: %s" % new_tc.describeToken()) self._lingering_channels.remove(c)
potentials.add((new_side, new_tc))
# TODO: timer def transitFinished(self, tc, token, side, description):
# we're called each time a TransitConnection shuts down
if token in self._pending_tokens:
side_tc = (side, tc)
if side_tc in self._pending_tokens[token]:
self._pending_tokens[token].remove(side_tc)
if not self._pending_tokens[token]: # set is now empty
del self._pending_tokens[token]
if self._debug_log:
log.msg("transitFinished %s" % (description,))
self._active_connections.discard(tc)
# Record usage. There are five cases:
# * 1: we connected, never had a buddy
# * 2: we connected first, we disconnect before the buddy
# * 3: we connected first, buddy disconnects first
# * 4: buddy connected first, we disconnect before buddy
# * 5: buddy connected first, buddy disconnects first
# whoever disconnects first gets to write the usage record (1,2,4)
finished = time.time()
if self._had_buddy:
if self._buddy: # 2,4: we disconnected first
total_bytes = self._total_sent + self._buddy._total_sent
starts = [self._started, self._buddy._started]
total_time = finished - min(starts)
waiting_time = max(starts) - min(starts)
self.factory.buddyIsLingering(self._buddy)
self.factory.recordUsage(self._started, "happy", total_bytes,
total_time, waiting_time)
else: # 3, 5: we disconnected last
self.factory.doneLingering(self)
else: # 1: we were the only one
total_time = finished - self._started
self.factory.recordUsage(self._started, "lonely", 0,
total_time, None)
def transitFailed(self, p):
if self._debug_log:
log.msg("transitFailed %r" % p)
finished = time.time()
total_time = finished - self._started
self.factory.recordUsage(self._started, "errory", 0,
total_time, None)
pass
def buddyIsLingering(self, buddy_tc):
self._active_connections.remove(buddy_tc)
self._lingering_connections.add(buddy_tc)
def doneLingering(self, old_tc):
self._lingering_connections.remove(buddy_tc)
def recordUsage(self, started, result, total_bytes, def recordUsage(self, started, result, total_bytes,
total_time, waiting_time): total_time, waiting_time):
self._counts[result] += 1 if self._debug_log:
self._count_bytes += total_bytes
if self._log_requests:
log.msg(format="Transit.recordUsage {bytes}B", bytes=total_bytes) log.msg(format="Transit.recordUsage {bytes}B", bytes=total_bytes)
if self._blur_usage: if self._blur_usage:
started = self._blur_usage * (started // self._blur_usage) started = self._blur_usage * (started // self._blur_usage)
total_bytes = blur_size(total_bytes) total_bytes = blur_size(total_bytes)
if self._usage_logfile: if self._log_stdout:
data = {"started": started, data = {"started": started,
"total_time": total_time, "total_time": total_time,
"waiting_time": waiting_time, "waiting_time": waiting_time,
"total_bytes": total_bytes, "total_bytes": total_bytes,
"mood": result, "mood": result,
} }
self._usage_logfile.write(json.dumps(data)) sys.stdout.write(json.dumps(data))
self._usage_logfile.write("\n") sys.stdout.write("\n")
self._usage_logfile.flush() sys.stdout.flush()
if self._stats_file: if self._db:
self._update_stats(total_bytes, result) self._db.execute("INSERT INTO `usage`"
" (`started`, `total_time`, `waiting_time`,"
" `total_bytes`, `result`)"
" VALUES (?,?,?, ?,?)",
(started, total_time, waiting_time,
total_bytes, result))
self._update_stats()
self._db.commit()
def transitFinished(self, tc, token, side, description): def _update_stats(self):
if token in self._pending_requests: # current status: should be zero when idle
side_tc = (side, tc) reboot = self._reboot
if side_tc in self._pending_requests[token]: last_update = time.time()
self._pending_requests[token].remove(side_tc) connected = len(self._active_connections) / 2
if not self._pending_requests[token]: # set is now empty # TODO: when a connection is half-closed, len(active) will be odd. a
del self._pending_requests[token] # moment later (hopefully) the other side will disconnect, but
if self._log_requests: # _update_stats isn't updated until later.
log.msg("transitFinished %s" % (description,)) waiting = len(self._pending_tokens)
self._active_connections.discard(tc) # "waiting" doesn't count multiple parallel connections from the same
# side
def transitFailed(self, p): incomplete_bytes = sum(tc._total_sent
if self._log_requests: for tc in self._active_connections)
log.msg("transitFailed %r" % p) self._db.execute("DELETE FROM `current`")
pass self._db.execute("INSERT INTO `current`"
" (`reboot`, `last_update`, `connected`, `waiting`,"
def _update_stats(self, total_bytes, mood): " `incomplete_bytes`)"
try: " VALUES (?, ?, ?, ?, ?)",
with open(self._stats_file, "r") as f: (reboot, last_update, connected, waiting,
stats = json.load(f) incomplete_bytes))
except (EnvironmentError, ValueError):
stats = {}
# current status: expected to be zero most of the time
stats["active"] = {"connected": len(self._active_connections) / 2,
"waiting": len(self._pending_requests),
}
# usage since last reboot
rb = stats["since_reboot"] = {}
rb["bytes"] = self._count_bytes
rb["total"] = sum(self._counts.values(), 0)
rbm = rb["moods"] = {}
for result, count in self._counts.items():
rbm[result] = count
# historical usage (all-time)
if "all_time" not in stats:
stats["all_time"] = {}
u = stats["all_time"]
u["total"] = u.get("total", 0) + 1
u["bytes"] = u.get("bytes", 0) + total_bytes
if "moods" not in u:
u["moods"] = {}
um = u["moods"]
for m in "happy", "lonely", "errory":
if m not in um:
um[m] = 0
um[mood] += 1
tmpfile = self._stats_file + ".tmp"
with open(tmpfile, "w") as f:
f.write(json.dumps(stats))
f.write("\n")
os.rename(tmpfile, self._stats_file)