diff --git a/docs/logging.md b/docs/logging.md new file mode 100644 index 0000000..3cbc957 --- /dev/null +++ b/docs/logging.md @@ -0,0 +1,91 @@ +# Usage Logs + +The transit relay does not emit or record any logging by default. By adding +option flags to the twist/twistd command line, you can enable one of two +different kinds of logs. + +To avoid collecting information which could later be used to correlate +clients with external network traces, logged information can be "blurred". +This reduces the resolution of the data, retaining enough to answer questions +about how much the server is being used, but discarding fine-grained +timestamps or exact transfer sizes. The ``--blur-usage=`` option enables +this, and it takes an integer value (in seconds) to specify the desired time +window. + +## Logging JSON Upon Each Connection + +If --log-fd is provided, a line will be written to the given (numeric) file +descriptor after each connection is done. These events could be delivered to +a comprehensive logging system like XXX for offline analysis. + +Each line will be a complete JSON object (starting with ``{``, ending with +``}\n``, and containing no internal newlines). The keys will be: + +* ``started``: number, seconds since epoch +* ``total_time``: number, seconds from open to last close +* ``waiting_time``: number, seconds from start to 2nd side appearing, or null +* ``total_bytes``: number, total bytes relayed (sum of both directions) +* ``mood``: string, one of: happy, lonely, errory + +A mood of ``happy`` means both sides gave a correct handshake. ``lonely`` +means a second matching side never appeared (and thus ``waiting_time`` will +be null). ``errory`` means the first side gave an invalid handshake. + +If --blur-usage= is provided, then ``started`` will be rounded to the given +time interval, and ``total_bytes`` will be rounded to a fixed set of buckets: + +* file sizes less than 1MB: rounded to the next largest multiple of 10kB +* less than 1GB: multiple of 1MB +* 1GB or larger: multiple of 100MB + +## Usage Database + +If --usage-db= is provided, the server will maintain a SQLite database in the +given file. Current, recent, and historical usage data will be written to the +database, and external tools can query the DB for metrics: the munin plugins +in misc/ may be useful. Timestamps and sizes in this file will respect +--blur-usage. The four tables are: + +``current`` contains a single row, with these columns: + +* connected: number of paired connections +* waiting: number of not-yet-paired connections +* partal_bytes: bytes transmitted over not-yet-complete connections + +``since_reboot`` contains a single row, with these columns: + +* bytes: sum of ``total_bytes`` +* connections: number of completed connections +* mood_happy: count of connections that finished "happy": both sides gave correct handshake +* mood_lonely: one side gave good handshake, other side never showed up +* mood_errory: one side gave a bad handshake + +``all_time`` contains a single row, with these columns: + +* bytes: +* connections: +* mood_happy: +* mood_lonely: +* mood_errory: + +``usage`` contains one row per closed connection, with these columns: + +* started: seconds since epoch, rounded to "blur time" +* total_time: seconds from first open to last close +* waiting_time: seconds from first open to second open, or None +* bytes: total bytes relayed (in both directions) +* result: (string) the mood: happy, lonely, errory + +All tables will be updated after each connection is finished. In addition, +the ``current`` table will be updated at least once every 5 minutes. + +## Logfiles for twistd + +If daemonized by twistd, the server will write ``twistd.pid`` and +``twistd.log`` files as usual. By default ``twistd.log`` will only contain +startup, shutdown, and exception messages. + +Setting ``--log-fd=1`` (file descriptor 1 is always stdout) will cause the +per-connection JSON lines to be interleaved with any messages sent to +Twisted's logging system. It may be better to use a different file +descriptor. diff --git a/misc/migrate_usage_db.py b/misc/migrate_usage_db.py new file mode 100644 index 0000000..6530d40 --- /dev/null +++ b/misc/migrate_usage_db.py @@ -0,0 +1,47 @@ +"""Migrate the usage data from the old bundled Transit Relay database. + +The magic-wormhole package used to include both servers (Rendezvous and +Transit). "wormhole server" started both of these, and used the +"relay.sqlite" database to store both immediate server state and long-term +usage data. + +These were split out to their own packages: version 0.11 omitted the Transit +Relay in favor of the new "magic-wormhole-transit-relay" distribution. + +This script reads the long-term Transit usage data from the pre-0.11 +wormhole-server relay.sqlite, and copies it into a new "usage.sqlite" +database in the current directory. + +It will refuse to touch an existing "usage.sqlite" file. + +The resuting "usage.sqlite" should be passed into --usage-db=, e.g. "twist +transitrelay --usage=.../PATH/TO/usage.sqlite". +""" + +from __future__ import unicode_literals, print_function +import sys +from wormhole_transit_relay.database import open_existing_db, create_db + +source_fn = sys.argv[1] +source_db = open_existing_db(source_fn) +target_db = create_db("usage.sqlite") + +num_rows = 0 +for row in source_db.execute("SELECT * FROM `transit_usage`" + " ORDER BY `started`").fetchall(): + target_db.execute("INSERT INTO `usage`" + " (`started`, `total_time`, `waiting_time`," + " `total_bytes`, `result`)" + " VALUES(?,?,?,?,?)", + (row["started"], row["total_time"], row["waiting_time"], + row["total_bytes"], row["result"])) + num_rows += 1 +target_db.execute("INSERT INTO `current`" + " (`rebooted`, `updated`, `connected`, `waiting`," + " `incomplete_bytes`)" + " VALUES(?,?,?,?,?)", + (0, 0, 0, 0, 0)) +target_db.commit() + +print("usage database migrated (%d rows) into 'usage.sqlite'" % num_rows) +sys.exit(0) diff --git a/src/wormhole_transit_relay/database.py b/src/wormhole_transit_relay/database.py new file mode 100644 index 0000000..7bb09d8 --- /dev/null +++ b/src/wormhole_transit_relay/database.py @@ -0,0 +1,152 @@ +from __future__ import unicode_literals +import os +import sqlite3 +import tempfile +from pkg_resources import resource_string +from twisted.python import log + +class DBError(Exception): + pass + +def get_schema(version): + schema_bytes = resource_string("wormhole_transit_relay", + "db-schemas/v%d.sql" % version) + return schema_bytes.decode("utf-8") + +def get_upgrader(new_version): + schema_bytes = resource_string("wormhole_transit_relay", + "db-schemas/upgrade-to-v%d.sql" % new_version) + return schema_bytes.decode("utf-8") + +TARGET_VERSION = 1 + +def dict_factory(cursor, row): + d = {} + for idx, col in enumerate(cursor.description): + d[col[0]] = row[idx] + return d + +def _initialize_db_schema(db, target_version): + """Creates the application schema in the given database. + """ + log.msg("populating new database with schema v%s" % target_version) + schema = get_schema(target_version) + db.executescript(schema) + db.execute("INSERT INTO version (version) VALUES (?)", + (target_version,)) + db.commit() + +def _initialize_db_connection(db): + """Sets up the db connection object with a row factory and with necessary + foreign key settings. + """ + db.row_factory = dict_factory + db.execute("PRAGMA foreign_keys = ON") + problems = db.execute("PRAGMA foreign_key_check").fetchall() + if problems: + raise DBError("failed foreign key check: %s" % (problems,)) + +def _open_db_connection(dbfile): + """Open a new connection to the SQLite3 database at the given path. + """ + try: + db = sqlite3.connect(dbfile) + except (EnvironmentError, sqlite3.OperationalError) as e: + raise DBError("Unable to create/open db file %s: %s" % (dbfile, e)) + _initialize_db_connection(db) + return db + +def _get_temporary_dbfile(dbfile): + """Get a temporary filename near the given path. + """ + fd, name = tempfile.mkstemp( + prefix=os.path.basename(dbfile) + ".", + dir=os.path.dirname(dbfile) + ) + os.close(fd) + return name + +def _atomic_create_and_initialize_db(dbfile, target_version): + """Create and return a new database, initialized with the application + schema. + + If anything goes wrong, nothing is left at the ``dbfile`` path. + """ + temp_dbfile = _get_temporary_dbfile(dbfile) + db = _open_db_connection(temp_dbfile) + _initialize_db_schema(db, target_version) + db.close() + os.rename(temp_dbfile, dbfile) + return _open_db_connection(dbfile) + +def get_db(dbfile, target_version=TARGET_VERSION): + """Open or create the given db file. The parent directory must exist. + Returns the db connection object, or raises DBError. + """ + if dbfile == ":memory:": + db = _open_db_connection(dbfile) + _initialize_db_schema(db, target_version) + elif os.path.exists(dbfile): + db = _open_db_connection(dbfile) + else: + db = _atomic_create_and_initialize_db(dbfile, target_version) + + try: + version = db.execute("SELECT version FROM version").fetchone()["version"] + except sqlite3.DatabaseError as e: + # this indicates that the file is not a compatible database format. + # Perhaps it was created with an old version, or it might be junk. + raise DBError("db file is unusable: %s" % e) + + while version < target_version: + log.msg(" need to upgrade from %s to %s" % (version, target_version)) + try: + upgrader = get_upgrader(version+1) + except ValueError: # ResourceError?? + log.msg(" unable to upgrade %s to %s" % (version, version+1)) + raise DBError("Unable to upgrade %s to version %s, left at %s" + % (dbfile, version+1, version)) + log.msg(" executing upgrader v%s->v%s" % (version, version+1)) + db.executescript(upgrader) + db.commit() + version = version+1 + + if version != target_version: + raise DBError("Unable to handle db version %s" % version) + + return db + +class DBDoesntExist(Exception): + pass + +def open_existing_db(dbfile): + assert dbfile != ":memory:" + if not os.path.exists(dbfile): + raise DBDoesntExist() + return _open_db_connection(dbfile) + +class DBAlreadyExists(Exception): + pass + +def create_db(dbfile): + """Create the given db file. Refuse to touch a pre-existing file. + + This is meant for use by migration tools, to create the output target""" + + if dbfile == ":memory:": + db = _open_db_connection(dbfile) + _initialize_db_schema(db, TARGET_VERSION) + elif os.path.exists(dbfile): + raise DBAlreadyExists() + else: + db = _atomic_create_and_initialize_db(dbfile, TARGET_VERSION) + return db + +def dump_db(db): + # to let _iterdump work, we need to restore the original row factory + orig = db.row_factory + try: + db.row_factory = sqlite3.Row + return "".join(db.iterdump()) + finally: + db.row_factory = orig diff --git a/src/wormhole_transit_relay/db-schemas/v1.sql b/src/wormhole_transit_relay/db-schemas/v1.sql new file mode 100644 index 0000000..f68d742 --- /dev/null +++ b/src/wormhole_transit_relay/db-schemas/v1.sql @@ -0,0 +1,30 @@ + +CREATE TABLE `version` -- contains one row +( + `version` INTEGER -- set to 1 +); + + +CREATE TABLE `current` -- contains one row +( + `rebooted` INTEGER, -- seconds since epoch of most recent reboot + `updated` INTEGER, -- when `current` was last updated + `connected` INTEGER, -- number of current paired connections + `waiting` INTEGER, -- number of not-yet-paired connections + `incomplete_bytes` INTEGER -- bytes sent through not-yet-complete connections +); + +CREATE TABLE `usage` +( + `started` INTEGER, -- seconds since epoch, rounded to "blur time" + `total_time` INTEGER, -- seconds from open to last close + `waiting_time` INTEGER, -- seconds from start to 2nd side appearing, or None + `total_bytes` INTEGER, -- total bytes relayed (both directions) + `result` VARCHAR -- happy, scary, lonely, errory, pruney + -- transit moods: + -- "errory": one side gave the wrong handshake + -- "lonely": good handshake, but the other side never showed up + -- "happy": both sides gave correct handshake +); +CREATE INDEX `usage_started_index` ON `usage` (`started`); +CREATE INDEX `usage_result_index` ON `usage` (`result`); diff --git a/src/wormhole_transit_relay/server_tap.py b/src/wormhole_transit_relay/server_tap.py index 1444a85..0f7ca50 100644 --- a/src/wormhole_transit_relay/server_tap.py +++ b/src/wormhole_transit_relay/server_tap.py @@ -1,70 +1,42 @@ +import os from . import transit_server from twisted.internet import reactor from twisted.python import usage -from twisted.application.internet import StreamServerEndpointService +from twisted.application.service import MultiService +from twisted.application.internet import (TimerService, + StreamServerEndpointService) from twisted.internet import endpoints LONGDESC = """\ This plugin sets up a 'Transit Relay' server for magic-wormhole. This service listens for TCP connections, finds pairs which present the same handshake, and glues the two TCP sockets together. - -If --usage-logfile= is provided, a line will be written to the given file after -each connection is done. This line will be a complete JSON object (starting -with "{", ending with "}\n", and containing no internal newlines). The keys -will be: - -* 'started': number, seconds since epoch -* 'total_time': number, seconds from open to last close -* 'waiting_time': number, seconds from start to 2nd side appearing, or null -* 'total_bytes': number, total bytes relayed (sum of both directions) -* 'mood': string, one of: happy, lonely, errory - -A mood of "happy" means both sides gave a correct handshake. "lonely" means a -second matching side never appeared (and thus 'waiting_time' will be null). -"errory" means the first side gave an invalid handshake. - -If --blur-usage= is provided, then 'started' will be rounded to the given time -interval, and 'total_bytes' will be rounded as well. - -If --stats-file is provided, the server will periodically write a simple JSON -dictionary to that file (atomically), with cumulative usage data (since last -reboot, and all-time). This information is *not* blurred (the assumption is -that it will be overwritten on a regular basis, and is aggregated anyways). The -keys are: - -* active.connected: number of paired connections -* active.waiting: number of not-yet-paired connections -* since_reboot.bytes: sum of 'total_bytes' -* since_reboot.total: number of completed connections -* since_reboot.moods: dict mapping mood string to number of connections -* all_time.bytes: same -* all_time.total -* all_time.moods - -The server will write twistd.pid and twistd.log files as usual, if daemonized -by twistd. twistd.log will only contain startup, shutdown, and exception -messages. To record information about each connection, use --usage-logfile. """ class Options(usage.Options): - #synopsis = "[--port=] [--usage-logfile=] [--blur-usage=] [--stats-json=]" + synopsis = "[--port=] [--log-fd] [--blur-usage=] [--usage-db=]" longdesc = LONGDESC optParameters = [ ("port", "p", "tcp:4001", "endpoint to listen on"), ("blur-usage", None, None, "blur timestamps and data sizes in logs"), - ("usage-logfile", None, None, "record usage data (JSON lines)"), - ("stats-file", None, None, "record usage in JSON format"), + ("log-fd", None, None, "write JSON usage logs to this file descriptor"), + ("usage-db", None, None, "record usage data (SQLite)"), ] def opt_blur_usage(self, arg): - self["blur_usage"] = int(arg) + self["blur-usage"] = int(arg) def makeService(config, reactor=reactor): ep = endpoints.serverFromString(reactor, config["port"]) # to listen + log_file = (os.fdopen(int(config["log-fd"]), "w") + if config["log-fd"] is not None + else None) f = transit_server.Transit(blur_usage=config["blur-usage"], - usage_logfile=config["usage-logfile"], - stats_file=config["stats-file"]) - return StreamServerEndpointService(ep, f) + log_file=log_file, + usage_db=config["usage-db"]) + parent = MultiService() + StreamServerEndpointService(ep, f).setServiceParent(parent) + TimerService(5*60.0, f.timerUpdateStats).setServiceParent(parent) + return parent diff --git a/src/wormhole_transit_relay/test/common.py b/src/wormhole_transit_relay/test/common.py index 440c028..5bdbcba 100644 --- a/src/wormhole_transit_relay/test/common.py +++ b/src/wormhole_transit_relay/test/common.py @@ -4,16 +4,16 @@ from twisted.internet.defer import inlineCallbacks from ..transit_server import Transit class ServerBase: + @inlineCallbacks def setUp(self): self._lp = None - self._setup_relay() + yield self._setup_relay() @inlineCallbacks - def _setup_relay(self, blur_usage=None, usage_logfile=None, stats_file=None): + def _setup_relay(self, blur_usage=None, log_file=None, usage_db=None): ep = endpoints.TCP4ServerEndpoint(reactor, 0, interface="127.0.0.1") self._transit_server = Transit(blur_usage=blur_usage, - usage_logfile=usage_logfile, - stats_file=stats_file) + log_file=log_file, usage_db=usage_db) self._lp = yield ep.listen(self._transit_server) addr = self._lp.getHost() # ws://127.0.0.1:%d/wormhole-relay/ws diff --git a/src/wormhole_transit_relay/test/test_config.py b/src/wormhole_transit_relay/test/test_config.py new file mode 100644 index 0000000..02e3fee --- /dev/null +++ b/src/wormhole_transit_relay/test/test_config.py @@ -0,0 +1,23 @@ +from __future__ import unicode_literals, print_function +from twisted.trial import unittest +from .. import server_tap + +class Config(unittest.TestCase): + def test_defaults(self): + o = server_tap.Options() + o.parseOptions([]) + self.assertEqual(o, {"blur-usage": None, "log-fd": None, + "usage-db": None, "port": "tcp:4001"}) + def test_blur(self): + o = server_tap.Options() + o.parseOptions(["--blur-usage=60"]) + self.assertEqual(o, {"blur-usage": 60, "log-fd": None, + "usage-db": None, "port": "tcp:4001"}) + + def test_string(self): + o = server_tap.Options() + s = str(o) + self.assertIn("This plugin sets up a 'Transit Relay'", s) + self.assertIn("--blur-usage=", s) + self.assertIn("blur timestamps and data sizes in logs", s) + diff --git a/src/wormhole_transit_relay/test/test_database.py b/src/wormhole_transit_relay/test/test_database.py new file mode 100644 index 0000000..0b5c918 --- /dev/null +++ b/src/wormhole_transit_relay/test/test_database.py @@ -0,0 +1,104 @@ +from __future__ import print_function, unicode_literals +import os +from twisted.python import filepath +from twisted.trial import unittest +from .. import database +from ..database import get_db, TARGET_VERSION, dump_db + +class Get(unittest.TestCase): + def test_create_default(self): + db_url = ":memory:" + db = get_db(db_url) + rows = db.execute("SELECT * FROM version").fetchall() + self.assertEqual(len(rows), 1) + self.assertEqual(rows[0]["version"], TARGET_VERSION) + + def test_failed_create_allows_subsequent_create(self): + patch = self.patch(database, "get_schema", lambda version: b"this is a broken schema") + dbfile = filepath.FilePath(self.mktemp()) + self.assertRaises(Exception, lambda: get_db(dbfile.path)) + patch.restore() + get_db(dbfile.path) + + def OFF_test_upgrade(self): # disabled until we add a v2 schema + basedir = self.mktemp() + os.mkdir(basedir) + fn = os.path.join(basedir, "upgrade.db") + self.assertNotEqual(TARGET_VERSION, 2) + + # create an old-version DB in a file + db = get_db(fn, 2) + rows = db.execute("SELECT * FROM version").fetchall() + self.assertEqual(len(rows), 1) + self.assertEqual(rows[0]["version"], 2) + del db + + # then upgrade the file to the latest version + dbA = get_db(fn, TARGET_VERSION) + rows = dbA.execute("SELECT * FROM version").fetchall() + self.assertEqual(len(rows), 1) + self.assertEqual(rows[0]["version"], TARGET_VERSION) + dbA_text = dump_db(dbA) + del dbA + + # make sure the upgrades got committed to disk + dbB = get_db(fn, TARGET_VERSION) + dbB_text = dump_db(dbB) + del dbB + self.assertEqual(dbA_text, dbB_text) + + # The upgraded schema should be equivalent to that of a new DB. + # However a text dump will differ because ALTER TABLE always appends + # the new column to the end of a table, whereas our schema puts it + # somewhere in the middle (wherever it fits naturally). Also ALTER + # TABLE doesn't include comments. + if False: + latest_db = get_db(":memory:", TARGET_VERSION) + latest_text = dump_db(latest_db) + with open("up.sql","w") as f: f.write(dbA_text) + with open("new.sql","w") as f: f.write(latest_text) + # check with "diff -u _trial_temp/up.sql _trial_temp/new.sql" + self.assertEqual(dbA_text, latest_text) + +class Create(unittest.TestCase): + def test_memory(self): + db = database.create_db(":memory:") + latest_text = dump_db(db) + self.assertIn("CREATE TABLE", latest_text) + + def test_preexisting(self): + basedir = self.mktemp() + os.mkdir(basedir) + fn = os.path.join(basedir, "preexisting.db") + with open(fn, "w"): + pass + with self.assertRaises(database.DBAlreadyExists): + database.create_db(fn) + + def test_create(self): + basedir = self.mktemp() + os.mkdir(basedir) + fn = os.path.join(basedir, "created.db") + db = database.create_db(fn) + latest_text = dump_db(db) + self.assertIn("CREATE TABLE", latest_text) + +class Open(unittest.TestCase): + def test_open(self): + basedir = self.mktemp() + os.mkdir(basedir) + fn = os.path.join(basedir, "created.db") + db1 = database.create_db(fn) + latest_text = dump_db(db1) + self.assertIn("CREATE TABLE", latest_text) + db2 = database.open_existing_db(fn) + self.assertIn("CREATE TABLE", dump_db(db2)) + + def test_doesnt_exist(self): + basedir = self.mktemp() + os.mkdir(basedir) + fn = os.path.join(basedir, "created.db") + with self.assertRaises(database.DBDoesntExist): + database.open_existing_db(fn) + + diff --git a/src/wormhole_transit_relay/test/test_service.py b/src/wormhole_transit_relay/test/test_service.py new file mode 100644 index 0000000..dac642c --- /dev/null +++ b/src/wormhole_transit_relay/test/test_service.py @@ -0,0 +1,39 @@ +from __future__ import unicode_literals, print_function +from twisted.trial import unittest +import mock +from twisted.application.service import MultiService +from .. import server_tap + +class Service(unittest.TestCase): + def test_defaults(self): + o = server_tap.Options() + o.parseOptions([]) + with mock.patch("wormhole_transit_relay.server_tap.transit_server.Transit") as t: + s = server_tap.makeService(o) + self.assertEqual(t.mock_calls, + [mock.call(blur_usage=None, + log_file=None, usage_db=None)]) + self.assertIsInstance(s, MultiService) + + def test_blur(self): + o = server_tap.Options() + o.parseOptions(["--blur-usage=60"]) + with mock.patch("wormhole_transit_relay.server_tap.transit_server.Transit") as t: + server_tap.makeService(o) + self.assertEqual(t.mock_calls, + [mock.call(blur_usage=60, + log_file=None, usage_db=None)]) + + def test_log_fd(self): + o = server_tap.Options() + o.parseOptions(["--log-fd=99"]) + fd = object() + with mock.patch("wormhole_transit_relay.server_tap.transit_server.Transit") as t: + with mock.patch("wormhole_transit_relay.server_tap.os.fdopen", + return_value=fd) as f: + server_tap.makeService(o) + self.assertEqual(f.mock_calls, [mock.call(99, "w")]) + self.assertEqual(t.mock_calls, + [mock.call(blur_usage=None, + log_file=fd, usage_db=None)]) + diff --git a/src/wormhole_transit_relay/test/test_stats.py b/src/wormhole_transit_relay/test/test_stats.py index f05f7d0..56e8c98 100644 --- a/src/wormhole_transit_relay/test/test_stats.py +++ b/src/wormhole_transit_relay/test/test_stats.py @@ -1,59 +1,89 @@ from __future__ import print_function, unicode_literals -import os, json +import os, io, json, sqlite3 import mock from twisted.trial import unittest from ..transit_server import Transit +from .. import database -class UsageLog(unittest.TestCase): - def test_log(self): +class DB(unittest.TestCase): + def open_db(self, dbfile): + db = sqlite3.connect(dbfile) + database._initialize_db_connection(db) + return db + + def test_db(self): d = self.mktemp() os.mkdir(d) - usage_logfile = os.path.join(d, "usage.log") - def read(): - with open(usage_logfile, "r") as f: - return [json.loads(line) for line in f.readlines()] - t = Transit(None, usage_logfile, None) - t.recordUsage(started=123, result="happy", total_bytes=100, - total_time=10, waiting_time=2) - self.assertEqual(read(), [dict(started=123, mood="happy", - total_time=10, waiting_time=2, - total_bytes=100)]) + usage_db = os.path.join(d, "usage.sqlite") + with mock.patch("time.time", return_value=456): + t = Transit(blur_usage=None, log_file=None, usage_db=usage_db) + db = self.open_db(usage_db) - t.recordUsage(started=150, result="errory", total_bytes=200, - total_time=11, waiting_time=3) - self.assertEqual(read(), [dict(started=123, mood="happy", - total_time=10, waiting_time=2, - total_bytes=100), - dict(started=150, mood="errory", - total_time=11, waiting_time=3, - total_bytes=200), - ]) - - if False: - # the current design opens the logfile exactly once, at process - # start, in the faint hopes of surviving an exhaustion of available - # file descriptors. This should be rethought. - os.unlink(usage_logfile) - - t.recordUsage(started=200, result="lonely", total_bytes=300, - total_time=12, waiting_time=4) - self.assertEqual(read(), [dict(started=200, mood="lonely", - total_time=12, waiting_time=4, - total_bytes=300)]) - -class StandardLogfile(unittest.TestCase): - def test_log(self): - # the default, when _blur_usage is None, will log to twistd.log - t = Transit(blur_usage=None, usage_logfile=None, stats_file=None) - with mock.patch("twisted.python.log.msg") as m: - t.recordUsage(started=123, result="happy", total_bytes=100, - total_time=10, waiting_time=2) - self.assertEqual(m.mock_calls, [mock.call(format="Transit.recordUsage {bytes}B", bytes=100)]) - - def test_do_not_log(self): - # the default, when _blur_usage is None, will log to twistd.log - t = Transit(blur_usage=60, usage_logfile=None, stats_file=None) - with mock.patch("twisted.python.log.msg") as m: + with mock.patch("time.time", return_value=457): t.recordUsage(started=123, result="happy", total_bytes=100, total_time=10, waiting_time=2) - self.assertEqual(m.mock_calls, []) + self.assertEqual(db.execute("SELECT * FROM `usage`").fetchall(), + [dict(result="happy", started=123, + total_bytes=100, total_time=10, waiting_time=2), + ]) + self.assertEqual(db.execute("SELECT * FROM `current`").fetchone(), + dict(rebooted=456, updated=457, + incomplete_bytes=0, + waiting=0, connected=0)) + + with mock.patch("time.time", return_value=458): + t.recordUsage(started=150, result="errory", total_bytes=200, + total_time=11, waiting_time=3) + self.assertEqual(db.execute("SELECT * FROM `usage`").fetchall(), + [dict(result="happy", started=123, + total_bytes=100, total_time=10, waiting_time=2), + dict(result="errory", started=150, + total_bytes=200, total_time=11, waiting_time=3), + ]) + self.assertEqual(db.execute("SELECT * FROM `current`").fetchone(), + dict(rebooted=456, updated=458, + incomplete_bytes=0, + waiting=0, connected=0)) + + with mock.patch("time.time", return_value=459): + t.timerUpdateStats() + self.assertEqual(db.execute("SELECT * FROM `current`").fetchone(), + dict(rebooted=456, updated=459, + incomplete_bytes=0, + waiting=0, connected=0)) + + def test_no_db(self): + t = Transit(blur_usage=None, log_file=None, usage_db=None) + + t.recordUsage(started=123, result="happy", total_bytes=100, + total_time=10, waiting_time=2) + t.timerUpdateStats() + +class LogToStdout(unittest.TestCase): + def test_log(self): + # emit lines of JSON to log_file, if set + log_file = io.StringIO() + t = Transit(blur_usage=None, log_file=log_file, usage_db=None) + t.recordUsage(started=123, result="happy", total_bytes=100, + total_time=10, waiting_time=2) + self.assertEqual(json.loads(log_file.getvalue()), + {"started": 123, "total_time": 10, + "waiting_time": 2, "total_bytes": 100, + "mood": "happy"}) + + def test_log_blurred(self): + # if blurring is enabled, timestamps should be rounded to the + # requested amount, and sizes should be rounded up too + log_file = io.StringIO() + t = Transit(blur_usage=60, log_file=log_file, usage_db=None) + t.recordUsage(started=123, result="happy", total_bytes=11999, + total_time=10, waiting_time=2) + self.assertEqual(json.loads(log_file.getvalue()), + {"started": 120, "total_time": 10, + "waiting_time": 2, "total_bytes": 20000, + "mood": "happy"}) + + def test_do_not_log(self): + t = Transit(blur_usage=60, log_file=None, usage_db=None) + t.recordUsage(started=123, result="happy", total_bytes=11999, + total_time=10, waiting_time=2) diff --git a/src/wormhole_transit_relay/test/test_transit_server.py b/src/wormhole_transit_relay/test/test_transit_server.py index eb2e997..beaec7a 100644 --- a/src/wormhole_transit_relay/test/test_transit_server.py +++ b/src/wormhole_transit_relay/test/test_transit_server.py @@ -232,7 +232,7 @@ class Transit(ServerBase, unittest.TestCase): a2.transport.loseConnection() @defer.inlineCallbacks - def test_bad_handshake(self): + def test_bad_handshake_old(self): ep = clientFromString(reactor, self.transit) a1 = yield connectProtocol(ep, Accumulator()) @@ -247,6 +247,49 @@ class Transit(ServerBase, unittest.TestCase): a1.transport.loseConnection() + @defer.inlineCallbacks + def test_bad_handshake_old_slow(self): + ep = clientFromString(reactor, self.transit) + a1 = yield connectProtocol(ep, Accumulator()) + + a1.transport.write(b"please DELAY ") + # As in test_impatience_new_slow, the current state machine has code + # that can only be reached if we insert a stall here, so dataReceived + # gets called twice. Hopefully we can delete this test once + # dataReceived is refactored to remove that state. + d = defer.Deferred() + reactor.callLater(0.1, d.callback, None) + yield d + + token1 = b"\x00"*32 + # the server waits for the exact number of bytes in the expected + # handshake message. to trigger "bad handshake", we must match. + a1.transport.write(hexlify(token1) + b"\n") + + exp = b"bad handshake\n" + yield a1.waitForBytes(len(exp)) + self.assertEqual(a1.data, exp) + + a1.transport.loseConnection() + + @defer.inlineCallbacks + def test_bad_handshake_new(self): + ep = clientFromString(reactor, self.transit) + a1 = yield connectProtocol(ep, Accumulator()) + + token1 = b"\x00"*32 + side1 = b"\x01"*8 + # the server waits for the exact number of bytes in the expected + # handshake message. to trigger "bad handshake", we must match. + a1.transport.write(b"please DELAY " + hexlify(token1) + + b" for side " + hexlify(side1) + b"\n") + + exp = b"bad handshake\n" + yield a1.waitForBytes(len(exp)) + self.assertEqual(a1.data, exp) + + a1.transport.loseConnection() + @defer.inlineCallbacks def test_binary_handshake(self): ep = clientFromString(reactor, self.transit) @@ -298,3 +341,34 @@ class Transit(ServerBase, unittest.TestCase): self.assertEqual(a1.data, exp) a1.transport.loseConnection() + + @defer.inlineCallbacks + def test_impatience_new_slow(self): + ep = clientFromString(reactor, self.transit) + a1 = yield connectProtocol(ep, Accumulator()) + # For full coverage, we need dataReceived to see a particular framing + # of these two pieces of data, and ITCPTransport doesn't have flush() + # (which probably wouldn't work anyways). For now, force a 100ms + # stall between the two writes. I tried setTcpNoDelay(True) but it + # didn't seem to help without the stall. The long-term fix is to + # rewrite dataReceived() to remove the multiple "impatient" + # codepaths, deleting the particular clause that this test exercises, + # then remove this test. + + token1 = b"\x00"*32 + side1 = b"\x01"*8 + # sending too many bytes is impatience. + a1.transport.write(b"please relay " + hexlify(token1) + + b" for side " + hexlify(side1) + b"\n") + + d = defer.Deferred() + reactor.callLater(0.1, d.callback, None) + yield d + + a1.transport.write(b"NOWNOWNOW") + + exp = b"impatient\n" + yield a1.waitForBytes(len(exp)) + self.assertEqual(a1.data, exp) + + a1.transport.loseConnection() diff --git a/src/wormhole_transit_relay/transit_server.py b/src/wormhole_transit_relay/transit_server.py index 42f19ca..5fa6c90 100644 --- a/src/wormhole_transit_relay/transit_server.py +++ b/src/wormhole_transit_relay/transit_server.py @@ -1,7 +1,8 @@ from __future__ import print_function, unicode_literals -import os, re, time, json +import re, time, json from twisted.python import log from twisted.internet import protocol +from .database import get_db SECONDS = 1.0 MINUTE = 60*SECONDS @@ -220,15 +221,23 @@ class Transit(protocol.ServerFactory): MAXTIME = 60*SECONDS protocol = TransitConnection - def __init__(self, blur_usage, usage_logfile, stats_file): + def __init__(self, blur_usage, log_file, usage_db): self._blur_usage = blur_usage self._log_requests = blur_usage is None - self._usage_logfile = open(usage_logfile, "a") if usage_logfile else None - self._stats_file = stats_file + if self._blur_usage: + log.msg("blurring access times to %d seconds" % self._blur_usage) + log.msg("not logging Transit connections to Twisted log") + else: + log.msg("not blurring access times") + self._debug_log = False + self._log_file = log_file + self._db = None + if usage_db: + self._db = get_db(usage_db) + self._rebooted = time.time() + # we don't track TransitConnections until they submit a token self._pending_requests = {} # token -> set((side, TransitConnection)) self._active_connections = set() # TransitConnection - self._counts = {"lonely": 0, "happy": 0, "errory": 0} - self._count_bytes = 0 def connection_got_token(self, token, new_side, new_tc): if token not in self._pending_requests: @@ -240,7 +249,7 @@ class Transit(protocol.ServerFactory): or (new_side is None) or (old_side != new_side)): # we found a match - if self._log_requests: + if self._debug_log: log.msg("transit relay 2: %s" % new_tc.describeToken()) # drop and stop tracking the rest @@ -255,33 +264,11 @@ class Transit(protocol.ServerFactory): new_tc.buddy_connected(old_tc) old_tc.buddy_connected(new_tc) return - if self._log_requests: + if self._debug_log: log.msg("transit relay 1: %s" % new_tc.describeToken()) potentials.add((new_side, new_tc)) # TODO: timer - def recordUsage(self, started, result, total_bytes, - total_time, waiting_time): - self._counts[result] += 1 - self._count_bytes += total_bytes - if self._log_requests: - log.msg(format="Transit.recordUsage {bytes}B", bytes=total_bytes) - if self._blur_usage: - started = self._blur_usage * (started // self._blur_usage) - total_bytes = blur_size(total_bytes) - if self._usage_logfile: - data = {"started": started, - "total_time": total_time, - "waiting_time": waiting_time, - "total_bytes": total_bytes, - "mood": result, - } - self._usage_logfile.write(json.dumps(data)) - self._usage_logfile.write("\n") - self._usage_logfile.flush() - if self._stats_file: - self._update_stats(total_bytes, result) - def transitFinished(self, tc, token, side, description): if token in self._pending_requests: side_tc = (side, tc) @@ -289,50 +276,63 @@ class Transit(protocol.ServerFactory): self._pending_requests[token].remove(side_tc) if not self._pending_requests[token]: # set is now empty del self._pending_requests[token] - if self._log_requests: + if self._debug_log: log.msg("transitFinished %s" % (description,)) self._active_connections.discard(tc) def transitFailed(self, p): - if self._log_requests: + if self._debug_log: log.msg("transitFailed %r" % p) pass - def _update_stats(self, total_bytes, mood): - try: - with open(self._stats_file, "r") as f: - stats = json.load(f) - except (EnvironmentError, ValueError): - stats = {} + def recordUsage(self, started, result, total_bytes, + total_time, waiting_time): + if self._debug_log: + log.msg(format="Transit.recordUsage {bytes}B", bytes=total_bytes) + if self._blur_usage: + started = self._blur_usage * (started // self._blur_usage) + total_bytes = blur_size(total_bytes) + if self._log_file is not None: + data = {"started": started, + "total_time": total_time, + "waiting_time": waiting_time, + "total_bytes": total_bytes, + "mood": result, + } + self._log_file.write(json.dumps(data)+"\n") + self._log_file.flush() + if self._db: + self._db.execute("INSERT INTO `usage`" + " (`started`, `total_time`, `waiting_time`," + " `total_bytes`, `result`)" + " VALUES (?,?,?, ?,?)", + (started, total_time, waiting_time, + total_bytes, result)) + self._update_stats() + self._db.commit() - # current status: expected to be zero most of the time - stats["active"] = {"connected": len(self._active_connections) / 2, - "waiting": len(self._pending_requests), - } + def timerUpdateStats(self): + if self._db: + self._update_stats() + self._db.commit() - # usage since last reboot - rb = stats["since_reboot"] = {} - rb["bytes"] = self._count_bytes - rb["total"] = sum(self._counts.values(), 0) - rbm = rb["moods"] = {} - for result, count in self._counts.items(): - rbm[result] = count - - # historical usage (all-time) - if "all_time" not in stats: - stats["all_time"] = {} - u = stats["all_time"] - u["total"] = u.get("total", 0) + 1 - u["bytes"] = u.get("bytes", 0) + total_bytes - if "moods" not in u: - u["moods"] = {} - um = u["moods"] - for m in "happy", "lonely", "errory": - if m not in um: - um[m] = 0 - um[mood] += 1 - tmpfile = self._stats_file + ".tmp" - with open(tmpfile, "w") as f: - f.write(json.dumps(stats)) - f.write("\n") - os.rename(tmpfile, self._stats_file) + def _update_stats(self): + # current status: should be zero when idle + rebooted = self._rebooted + updated = time.time() + connected = len(self._active_connections) / 2 + # TODO: when a connection is half-closed, len(active) will be odd. a + # moment later (hopefully) the other side will disconnect, but + # _update_stats isn't updated until later. + waiting = len(self._pending_requests) + # "waiting" doesn't count multiple parallel connections from the same + # side + incomplete_bytes = sum(tc._total_sent + for tc in self._active_connections) + self._db.execute("DELETE FROM `current`") + self._db.execute("INSERT INTO `current`" + " (`rebooted`, `updated`, `connected`, `waiting`," + " `incomplete_bytes`)" + " VALUES (?, ?, ?, ?, ?)", + (rebooted, updated, connected, waiting, + incomplete_bytes))