diff --git a/src/wormhole_transit_relay/server_state.py b/src/wormhole_transit_relay/server_state.py index e5c7bc3..9f83736 100644 --- a/src/wormhole_transit_relay/server_state.py +++ b/src/wormhole_transit_relay/server_state.py @@ -149,15 +149,12 @@ def create_usage_tracker(blur_usage, log_file, usage_db): """ tracker = UsageTracker(blur_usage) if usage_db: - db = get_db(usage_db) - tracker.add_backend(DatabaseUsageRecorder(db)) + tracker.add_backend(DatabaseUsageRecorder(usage_db)) if log_file: tracker.add_backend(LogFileUsageRecorder(log_file)) return tracker - - class UsageTracker(object): """ Tracks usage statistics of connections diff --git a/src/wormhole_transit_relay/server_tap.py b/src/wormhole_transit_relay/server_tap.py index cbf3efa..7f89409 100644 --- a/src/wormhole_transit_relay/server_tap.py +++ b/src/wormhole_transit_relay/server_tap.py @@ -38,10 +38,11 @@ def makeService(config, reactor=reactor): if config["log-fd"] is not None else None ) + db = None if config["usage-db"] is None else get_db(config["usage-db"]) usage = create_usage_tracker( blur_usage=config["blur-usage"], log_file=log_file, - usage_db=config["usage-db"], + usage_db=db, ) factory = transit_server.Transit(usage) parent = MultiService() diff --git a/src/wormhole_transit_relay/test/test_stats.py b/src/wormhole_transit_relay/test/test_stats.py index bce450b..6cdfc7b 100644 --- a/src/wormhole_transit_relay/test/test_stats.py +++ b/src/wormhole_transit_relay/test/test_stats.py @@ -10,19 +10,16 @@ from ..server_state import create_usage_tracker from .. import database class DB(unittest.TestCase): - def open_db(self, dbfile): - db = sqlite3.connect(dbfile) - database._initialize_db_connection(db) - return db def test_db(self): T = 1519075308.0 d = self.mktemp() os.mkdir(d) usage_db = os.path.join(d, "usage.sqlite") + db = database.get_db(usage_db) with mock.patch("time.time", return_value=T+0): - t = Transit(create_usage_tracker(blur_usage=None, log_file=None, usage_db=usage_db)) - db = self.open_db(usage_db) + t = Transit(create_usage_tracker(blur_usage=None, log_file=None, usage_db=db)) + self.assertEqual(len(t.usage._backends), 1) usage = list(t.usage._backends)[0] with mock.patch("time.time", return_value=T+1):