diff --git a/src/wormhole_transit_relay/server_state.py b/src/wormhole_transit_relay/server_state.py index 60e3101..613eee3 100644 --- a/src/wormhole_transit_relay/server_state.py +++ b/src/wormhole_transit_relay/server_state.py @@ -118,8 +118,10 @@ class DatabaseUsageRecorder: " VALUES (?,?,?,?,?)", (started, total_time, waiting_time, total_bytes, mood) ) - # XXX FIXME see comment in transit_server - #self._update_stats() + # original code did "self._update_stats()" here, thus causing + # "global" stats update on every connection update .. should + # we repeat this behavior, or really only record every + # 60-seconds with the timer? self._db.commit() @@ -227,6 +229,26 @@ class UsageTracker(object): "mood": result, }) + def update_stats(self, rebooted, updated, connected, waiting, + incomplete_bytes): + """ + Update general statistics. + """ + # in original code, this is only recorded in the database + # .. perhaps a better way to do this, but .. + for backend in self._backends: + if isinstance(backend, DatabaseUsageRecorder): + backend._db.execute("DELETE FROM `current`") + backend._db.execute( + "INSERT INTO `current`" + " (`rebooted`, `updated`, `connected`, `waiting`," + " `incomplete_bytes`)" + " VALUES (?, ?, ?, ?, ?)", + (int(rebooted), int(updated), connected, waiting, + incomplete_bytes) + ) + + def _notify_backends(self, data): """ Internal helper. Tell every backend we have about a new usage record. diff --git a/src/wormhole_transit_relay/server_tap.py b/src/wormhole_transit_relay/server_tap.py index 7f89409..704d404 100644 --- a/src/wormhole_transit_relay/server_tap.py +++ b/src/wormhole_transit_relay/server_tap.py @@ -44,8 +44,8 @@ def makeService(config, reactor=reactor): log_file=log_file, usage_db=db, ) - factory = transit_server.Transit(usage) + factory = transit_server.Transit(usage, reactor.seconds) parent = MultiService() StreamServerEndpointService(ep, factory).setServiceParent(parent) -### FIXME TODO TimerService(5*60.0, factory.timerUpdateStats).setServiceParent(parent) + TimerService(5*60.0, factory.update_stats).setServiceParent(parent) return parent diff --git a/src/wormhole_transit_relay/test/common.py b/src/wormhole_transit_relay/test/common.py index adbecf8..d78b844 100644 --- a/src/wormhole_transit_relay/test/common.py +++ b/src/wormhole_transit_relay/test/common.py @@ -70,7 +70,7 @@ class ServerBase: log_file=log_file, usage_db=usage_db, ) - self._transit_server = Transit(usage) + self._transit_server = Transit(usage, lambda: 123456789.0) self._transit_server._debug_log = self.log_requests def new_protocol(self): diff --git a/src/wormhole_transit_relay/test/test_stats.py b/src/wormhole_transit_relay/test/test_stats.py index 0ce46a5..390b524 100644 --- a/src/wormhole_transit_relay/test/test_stats.py +++ b/src/wormhole_transit_relay/test/test_stats.py @@ -12,19 +12,31 @@ from .. import database class DB(unittest.TestCase): def test_db(self): + T = 1519075308.0 + + class Timer: + t = T + def __call__(self): + return self.t + get_time = Timer() + d = self.mktemp() os.mkdir(d) usage_db = os.path.join(d, "usage.sqlite") db = database.get_db(usage_db) - with mock.patch("time.time", return_value=T+0): - t = Transit(create_usage_tracker(blur_usage=None, log_file=None, usage_db=db)) + t = Transit( + create_usage_tracker(blur_usage=None, log_file=None, usage_db=db), + get_time, + ) self.assertEqual(len(t.usage._backends), 1) usage = list(t.usage._backends)[0] - with mock.patch("time.time", return_value=T+1): - usage.record_usage(started=123, mood="happy", total_bytes=100, - total_time=10, waiting_time=2) + get_time.t = T + 1 + usage.record_usage(started=123, mood="happy", total_bytes=100, + total_time=10, waiting_time=2) + t.update_stats() + self.assertEqual(db.execute("SELECT * FROM `usage`").fetchall(), [dict(result="happy", started=123, total_bytes=100, total_time=10, waiting_time=2), @@ -34,9 +46,10 @@ class DB(unittest.TestCase): incomplete_bytes=0, waiting=0, connected=0)) - with mock.patch("time.time", return_value=T+2): - usage.record_usage(started=150, mood="errory", total_bytes=200, - total_time=11, waiting_time=3) + get_time.t = T + 2 + usage.record_usage(started=150, mood="errory", total_bytes=200, + total_time=11, waiting_time=3) + t.update_stats() self.assertEqual(db.execute("SELECT * FROM `usage`").fetchall(), [dict(result="happy", started=123, total_bytes=100, total_time=10, waiting_time=2), @@ -48,15 +61,18 @@ class DB(unittest.TestCase): incomplete_bytes=0, waiting=0, connected=0)) - with mock.patch("time.time", return_value=T+3): - t.timerUpdateStats() + get_time.t = T + 3 + t.update_stats() self.assertEqual(db.execute("SELECT * FROM `current`").fetchone(), dict(rebooted=T+0, updated=T+3, incomplete_bytes=0, waiting=0, connected=0)) def test_no_db(self): - t = Transit(create_usage_tracker(blur_usage=None, log_file=None, usage_db=None)) + t = Transit( + create_usage_tracker(blur_usage=None, log_file=None, usage_db=None), + lambda: 0, + ) self.assertEqual(0, len(t.usage._backends)) @@ -64,7 +80,10 @@ class LogToStdout(unittest.TestCase): def test_log(self): # emit lines of JSON to log_file, if set log_file = io.StringIO() - t = Transit(create_usage_tracker(blur_usage=None, log_file=log_file, usage_db=None)) + t = Transit( + create_usage_tracker(blur_usage=None, log_file=log_file, usage_db=None), + lambda: 0, + ) with mock.patch("time.time", return_value=133): t.usage.record( started=123, @@ -82,7 +101,10 @@ class LogToStdout(unittest.TestCase): # if blurring is enabled, timestamps should be rounded to the # requested amount, and sizes should be rounded up too log_file = io.StringIO() - t = Transit(create_usage_tracker(blur_usage=60, log_file=log_file, usage_db=None)) + t = Transit( + create_usage_tracker(blur_usage=60, log_file=log_file, usage_db=None), + lambda: 0, + ) with mock.patch("time.time", return_value=123 + 10): t.usage.record( @@ -99,7 +121,10 @@ class LogToStdout(unittest.TestCase): "mood": "happy"}) def test_do_not_log(self): - t = Transit(create_usage_tracker(blur_usage=60, log_file=None, usage_db=None)) + t = Transit( + create_usage_tracker(blur_usage=60, log_file=None, usage_db=None), + lambda: 0, + ) t.usage.record( started=123, buddy_started=124, diff --git a/src/wormhole_transit_relay/transit_server.py b/src/wormhole_transit_relay/transit_server.py index 7865c22..640e972 100644 --- a/src/wormhole_transit_relay/transit_server.py +++ b/src/wormhole_transit_relay/transit_server.py @@ -193,33 +193,28 @@ class Transit(protocol.ServerFactory): MAXTIME = 60*SECONDS protocol = TransitConnection - def __init__(self, usage): + def __init__(self, usage, get_timestamp): self.active_connections = ActiveConnections() self.pending_requests = PendingRequests(self.active_connections) self.usage = usage self._debug_log = False + self._timestamp = get_timestamp + self._rebooted = self._timestamp() - self._rebooted = time.time() - - # XXX TODO self._rebooted and the below could be in a separate - # object? or in the DatabaseUsageRecorder .. but not here - def _update_stats(self): - # current status: should be zero when idle - rebooted = self._rebooted - updated = time.time() - connected = len(self._active_connections) / 2 + def update_stats(self): # TODO: when a connection is half-closed, len(active) will be odd. a # moment later (hopefully) the other side will disconnect, but # _update_stats isn't updated until later. - waiting = len(self._pending_requests) + # "waiting" doesn't count multiple parallel connections from the same # side - incomplete_bytes = sum(tc._total_sent - for tc in self._active_connections) - self._db.execute("DELETE FROM `current`") - self._db.execute("INSERT INTO `current`" - " (`rebooted`, `updated`, `connected`, `waiting`," - " `incomplete_bytes`)" - " VALUES (?, ?, ?, ?, ?)", - (rebooted, updated, connected, waiting, - incomplete_bytes)) + self.usage.update_stats( + rebooted=self._rebooted, + updated=self._timestamp(), + connected=len(self.active_connections._connections), + waiting=len(self.pending_requests._requests), + incomplete_bytes=sum( + tc._total_sent + for tc in self.active_connections._connections + ), + )