diff --git a/src/wormhole_transit_relay/database.py b/src/wormhole_transit_relay/database.py index 65a2e6c..de8bc08 100644 --- a/src/wormhole_transit_relay/database.py +++ b/src/wormhole_transit_relay/database.py @@ -9,12 +9,12 @@ class DBError(Exception): pass def get_schema(version): - schema_bytes = resource_string("wormhole.server", + schema_bytes = resource_string("wormhole_transit_relay", "db-schemas/v%d.sql" % version) return schema_bytes.decode("utf-8") def get_upgrader(new_version): - schema_bytes = resource_string("wormhole.server", + schema_bytes = resource_string("wormhole_transit_relay", "db-schemas/upgrade-to-v%d.sql" % new_version) return schema_bytes.decode("utf-8") diff --git a/src/wormhole_transit_relay/db-schemas/v1.sql b/src/wormhole_transit_relay/db-schemas/v1.sql index 9abb449..f68d742 100644 --- a/src/wormhole_transit_relay/db-schemas/v1.sql +++ b/src/wormhole_transit_relay/db-schemas/v1.sql @@ -7,8 +7,8 @@ CREATE TABLE `version` -- contains one row CREATE TABLE `current` -- contains one row ( - `reboot` INTEGER, -- seconds since epoch of most recent reboot - `last_update` INTEGER, -- when `current` was last updated + `rebooted` INTEGER, -- seconds since epoch of most recent reboot + `updated` INTEGER, -- when `current` was last updated `connected` INTEGER, -- number of current paired connections `waiting` INTEGER, -- number of not-yet-paired connections `incomplete_bytes` INTEGER -- bytes sent through not-yet-complete connections @@ -26,5 +26,5 @@ CREATE TABLE `usage` -- "lonely": good handshake, but the other side never showed up -- "happy": both sides gave correct handshake ); -CREATE INDEX `transit_usage_idx` ON `transit_usage` (`started`); -CREATE INDEX `transit_usage_result_idx` ON `transit_usage` (`result`); +CREATE INDEX `usage_started_index` ON `usage` (`started`); +CREATE INDEX `usage_result_index` ON `usage` (`result`); diff --git a/src/wormhole_transit_relay/server_tap.py b/src/wormhole_transit_relay/server_tap.py index babc441..8aacb63 100644 --- a/src/wormhole_transit_relay/server_tap.py +++ b/src/wormhole_transit_relay/server_tap.py @@ -1,6 +1,8 @@ +import sys from . import transit_server from twisted.internet import reactor from twisted.python import usage +from twisted.application.service import MultiService from twisted.application.internet import (TimerService, StreamServerEndpointService) from twisted.internet import endpoints @@ -92,10 +94,11 @@ class Options(usage.Options): def makeService(config, reactor=reactor): ep = endpoints.serverFromString(reactor, config["port"]) # to listen + log_file = sys.stdout if config["log-stdout"] else None f = transit_server.Transit(blur_usage=config["blur-usage"], - log_stdout=config["log-stdout"], + log_file=log_file, usage_db=config["usage-db"]) - parent = service.MultiService() + parent = MultiService() StreamServerEndpointService(ep, f).setServiceParent(parent) - TimerService(5.0, f.timerUpdateStats).setServiceParent(parent) + TimerService(5*60.0, f.timerUpdateStats).setServiceParent(parent) return parent diff --git a/src/wormhole_transit_relay/test/common.py b/src/wormhole_transit_relay/test/common.py index 440c028..5bdbcba 100644 --- a/src/wormhole_transit_relay/test/common.py +++ b/src/wormhole_transit_relay/test/common.py @@ -4,16 +4,16 @@ from twisted.internet.defer import inlineCallbacks from ..transit_server import Transit class ServerBase: + @inlineCallbacks def setUp(self): self._lp = None - self._setup_relay() + yield self._setup_relay() @inlineCallbacks - def _setup_relay(self, blur_usage=None, usage_logfile=None, stats_file=None): + def _setup_relay(self, blur_usage=None, log_file=None, usage_db=None): ep = endpoints.TCP4ServerEndpoint(reactor, 0, interface="127.0.0.1") self._transit_server = Transit(blur_usage=blur_usage, - usage_logfile=usage_logfile, - stats_file=stats_file) + log_file=log_file, usage_db=usage_db) self._lp = yield ep.listen(self._transit_server) addr = self._lp.getHost() # ws://127.0.0.1:%d/wormhole-relay/ws diff --git a/src/wormhole_transit_relay/test/test_stats.py b/src/wormhole_transit_relay/test/test_stats.py index f05f7d0..56e8c98 100644 --- a/src/wormhole_transit_relay/test/test_stats.py +++ b/src/wormhole_transit_relay/test/test_stats.py @@ -1,59 +1,89 @@ from __future__ import print_function, unicode_literals -import os, json +import os, io, json, sqlite3 import mock from twisted.trial import unittest from ..transit_server import Transit +from .. import database -class UsageLog(unittest.TestCase): - def test_log(self): +class DB(unittest.TestCase): + def open_db(self, dbfile): + db = sqlite3.connect(dbfile) + database._initialize_db_connection(db) + return db + + def test_db(self): d = self.mktemp() os.mkdir(d) - usage_logfile = os.path.join(d, "usage.log") - def read(): - with open(usage_logfile, "r") as f: - return [json.loads(line) for line in f.readlines()] - t = Transit(None, usage_logfile, None) - t.recordUsage(started=123, result="happy", total_bytes=100, - total_time=10, waiting_time=2) - self.assertEqual(read(), [dict(started=123, mood="happy", - total_time=10, waiting_time=2, - total_bytes=100)]) + usage_db = os.path.join(d, "usage.sqlite") + with mock.patch("time.time", return_value=456): + t = Transit(blur_usage=None, log_file=None, usage_db=usage_db) + db = self.open_db(usage_db) - t.recordUsage(started=150, result="errory", total_bytes=200, - total_time=11, waiting_time=3) - self.assertEqual(read(), [dict(started=123, mood="happy", - total_time=10, waiting_time=2, - total_bytes=100), - dict(started=150, mood="errory", - total_time=11, waiting_time=3, - total_bytes=200), - ]) - - if False: - # the current design opens the logfile exactly once, at process - # start, in the faint hopes of surviving an exhaustion of available - # file descriptors. This should be rethought. - os.unlink(usage_logfile) - - t.recordUsage(started=200, result="lonely", total_bytes=300, - total_time=12, waiting_time=4) - self.assertEqual(read(), [dict(started=200, mood="lonely", - total_time=12, waiting_time=4, - total_bytes=300)]) - -class StandardLogfile(unittest.TestCase): - def test_log(self): - # the default, when _blur_usage is None, will log to twistd.log - t = Transit(blur_usage=None, usage_logfile=None, stats_file=None) - with mock.patch("twisted.python.log.msg") as m: - t.recordUsage(started=123, result="happy", total_bytes=100, - total_time=10, waiting_time=2) - self.assertEqual(m.mock_calls, [mock.call(format="Transit.recordUsage {bytes}B", bytes=100)]) - - def test_do_not_log(self): - # the default, when _blur_usage is None, will log to twistd.log - t = Transit(blur_usage=60, usage_logfile=None, stats_file=None) - with mock.patch("twisted.python.log.msg") as m: + with mock.patch("time.time", return_value=457): t.recordUsage(started=123, result="happy", total_bytes=100, total_time=10, waiting_time=2) - self.assertEqual(m.mock_calls, []) + self.assertEqual(db.execute("SELECT * FROM `usage`").fetchall(), + [dict(result="happy", started=123, + total_bytes=100, total_time=10, waiting_time=2), + ]) + self.assertEqual(db.execute("SELECT * FROM `current`").fetchone(), + dict(rebooted=456, updated=457, + incomplete_bytes=0, + waiting=0, connected=0)) + + with mock.patch("time.time", return_value=458): + t.recordUsage(started=150, result="errory", total_bytes=200, + total_time=11, waiting_time=3) + self.assertEqual(db.execute("SELECT * FROM `usage`").fetchall(), + [dict(result="happy", started=123, + total_bytes=100, total_time=10, waiting_time=2), + dict(result="errory", started=150, + total_bytes=200, total_time=11, waiting_time=3), + ]) + self.assertEqual(db.execute("SELECT * FROM `current`").fetchone(), + dict(rebooted=456, updated=458, + incomplete_bytes=0, + waiting=0, connected=0)) + + with mock.patch("time.time", return_value=459): + t.timerUpdateStats() + self.assertEqual(db.execute("SELECT * FROM `current`").fetchone(), + dict(rebooted=456, updated=459, + incomplete_bytes=0, + waiting=0, connected=0)) + + def test_no_db(self): + t = Transit(blur_usage=None, log_file=None, usage_db=None) + + t.recordUsage(started=123, result="happy", total_bytes=100, + total_time=10, waiting_time=2) + t.timerUpdateStats() + +class LogToStdout(unittest.TestCase): + def test_log(self): + # emit lines of JSON to log_file, if set + log_file = io.StringIO() + t = Transit(blur_usage=None, log_file=log_file, usage_db=None) + t.recordUsage(started=123, result="happy", total_bytes=100, + total_time=10, waiting_time=2) + self.assertEqual(json.loads(log_file.getvalue()), + {"started": 123, "total_time": 10, + "waiting_time": 2, "total_bytes": 100, + "mood": "happy"}) + + def test_log_blurred(self): + # if blurring is enabled, timestamps should be rounded to the + # requested amount, and sizes should be rounded up too + log_file = io.StringIO() + t = Transit(blur_usage=60, log_file=log_file, usage_db=None) + t.recordUsage(started=123, result="happy", total_bytes=11999, + total_time=10, waiting_time=2) + self.assertEqual(json.loads(log_file.getvalue()), + {"started": 120, "total_time": 10, + "waiting_time": 2, "total_bytes": 20000, + "mood": "happy"}) + + def test_do_not_log(self): + t = Transit(blur_usage=60, log_file=None, usage_db=None) + t.recordUsage(started=123, result="happy", total_bytes=11999, + total_time=10, waiting_time=2) diff --git a/src/wormhole_transit_relay/test/test_transit_server.py b/src/wormhole_transit_relay/test/test_transit_server.py index eb2e997..101b12a 100644 --- a/src/wormhole_transit_relay/test/test_transit_server.py +++ b/src/wormhole_transit_relay/test/test_transit_server.py @@ -298,3 +298,34 @@ class Transit(ServerBase, unittest.TestCase): self.assertEqual(a1.data, exp) a1.transport.loseConnection() + + @defer.inlineCallbacks + def test_impatience_new2(self): + ep = clientFromString(reactor, self.transit) + a1 = yield connectProtocol(ep, Accumulator()) + # For full coverage, we need dataReceived to see a particular framing + # of these two pieces of data, and ITCPTransport doesn't have flush() + # (which probably wouldn't work anyways). For now, force a 100ms + # stall between the two writes. I tried setTcpNoDelay(True) but it + # didn't seem to help without the stall. The long-term fix is to + # rewrite dataReceived() to remove the multiple "impatient" + # codepaths, deleting the particular clause that this test exercises, + # then remove this test. + + token1 = b"\x00"*32 + side1 = b"\x01"*8 + # sending too many bytes is impatience. + a1.transport.write(b"please relay " + hexlify(token1) + + b" for side " + hexlify(side1) + b"\n") + + d = defer.Deferred() + reactor.callLater(0.1, d.callback, None) + yield d + + a1.transport.write(b"NOWNOWNOW") + + exp = b"impatient\n" + yield a1.waitForBytes(len(exp)) + self.assertEqual(a1.data, exp) + + a1.transport.loseConnection() diff --git a/src/wormhole_transit_relay/transit_server.py b/src/wormhole_transit_relay/transit_server.py index e915944..c8f1011 100644 --- a/src/wormhole_transit_relay/transit_server.py +++ b/src/wormhole_transit_relay/transit_server.py @@ -1,5 +1,5 @@ from __future__ import print_function, unicode_literals -import os, re, time, json +import re, time, json from twisted.python import log from twisted.internet import protocol from .database import get_db @@ -221,13 +221,15 @@ class Transit(protocol.ServerFactory): MAXTIME = 60*SECONDS protocol = TransitConnection - def __init__(self, blur_usage, log_stdout, usage_db): + def __init__(self, blur_usage, log_file, usage_db): self._blur_usage = blur_usage + self._log_requests = blur_usage is None self._debug_log = False - self._log_stdout = log_stdout + self._log_file = log_file self._db = None if usage_db: self._db = get_db(usage_db) + self._rebooted = time.time() # we don't track TransitConnections until they submit a token self._pending_requests = {} # token -> set((side, TransitConnection)) self._active_connections = set() # TransitConnection @@ -285,16 +287,15 @@ class Transit(protocol.ServerFactory): if self._blur_usage: started = self._blur_usage * (started // self._blur_usage) total_bytes = blur_size(total_bytes) - if self._log_stdout: + if self._log_file is not None: data = {"started": started, "total_time": total_time, "waiting_time": waiting_time, "total_bytes": total_bytes, "mood": result, } - sys.stdout.write(json.dumps(data)) - sys.stdout.write("\n") - sys.stdout.flush() + self._log_file.write(json.dumps(data)+"\n") + self._log_file.flush() if self._db: self._db.execute("INSERT INTO `usage`" " (`started`, `total_time`, `waiting_time`," @@ -306,26 +307,27 @@ class Transit(protocol.ServerFactory): self._db.commit() def timerUpdateStats(self): - self._update_stats() - self._db.commit() + if self._db: + self._update_stats() + self._db.commit() def _update_stats(self): # current status: should be zero when idle - reboot = self._reboot - last_update = time.time() + rebooted = self._rebooted + updated = time.time() connected = len(self._active_connections) / 2 # TODO: when a connection is half-closed, len(active) will be odd. a # moment later (hopefully) the other side will disconnect, but # _update_stats isn't updated until later. - waiting = len(self._pending_tokens) + waiting = len(self._pending_requests) # "waiting" doesn't count multiple parallel connections from the same # side incomplete_bytes = sum(tc._total_sent for tc in self._active_connections) self._db.execute("DELETE FROM `current`") self._db.execute("INSERT INTO `current`" - " (`reboot`, `last_update`, `connected`, `waiting`," + " (`rebooted`, `updated`, `connected`, `waiting`," " `incomplete_bytes`)" " VALUES (?, ?, ?, ?, ?)", - (reboot, last_update, connected, waiting, + (rebooted, updated, connected, waiting, incomplete_bytes))