diff --git a/src/wormhole/server/cli.py b/src/wormhole/server/cli.py index b86984f..c644588 100644 --- a/src/wormhole/server/cli.py +++ b/src/wormhole/server/cli.py @@ -22,6 +22,15 @@ def server(ctx): # this is the setuptools entrypoint for bin/wormhole-server ctx.obj = Config() +_relay_database_path = click.option( + "--relay-database-path", default="relay.sqlite", metavar="PATH", + help="location for the relay server state database", +) +_stats_json_path = click.option( + "--stats-json-path", default="stats.json", metavar="PATH", + help="location to write the relay stats file", +) + @server.command() @click.option( "--rendezvous", default="tcp:4000", metavar="tcp:PORT", @@ -52,9 +61,13 @@ def server(ctx): # this is the setuptools entrypoint for bin/wormhole-server "--disallow-list", is_flag=True, help="never send list of allocated nameplates", ) +@_relay_database_path +@_stats_json_path @click.pass_obj def start(cfg, signal_error, no_daemon, blur_usage, advertise_version, - transit, rendezvous, disallow_list): + transit, rendezvous, disallow_list, relay_database_path, + stats_json_path, +): """ Start a relay server """ @@ -66,6 +79,8 @@ def start(cfg, signal_error, no_daemon, blur_usage, advertise_version, cfg.rendezvous = str(rendezvous) cfg.signal_error = signal_error cfg.allow_list = not disallow_list + cfg.relay_database_path = relay_database_path + cfg.stats_json_path = stats_json_path start_server(cfg) @@ -102,9 +117,13 @@ def start(cfg, signal_error, no_daemon, blur_usage, advertise_version, "--disallow-list", is_flag=True, help="never send list of allocated nameplates", ) +@_relay_database_path +@_stats_json_path @click.pass_obj def restart(cfg, signal_error, no_daemon, blur_usage, advertise_version, - transit, rendezvous, disallow_list): + transit, rendezvous, disallow_list, relay_database_path, + stats_json_path, + ): """ Re-start a relay server """ @@ -116,6 +135,8 @@ def restart(cfg, signal_error, no_daemon, blur_usage, advertise_version, cfg.rendezvous = str(rendezvous) cfg.signal_error = signal_error cfg.allow_list = not disallow_list + cfg.relay_database_path = relay_database_path + cfg.stats_json_path = stats_json_path restart_server(cfg) diff --git a/src/wormhole/server/cmd_server.py b/src/wormhole/server/cmd_server.py index 7aff110..1cabea3 100644 --- a/src/wormhole/server/cmd_server.py +++ b/src/wormhole/server/cmd_server.py @@ -17,10 +17,10 @@ class MyPlugin(object): self.args.rendezvous, self.args.transit, self.args.advertise_version, - "relay.sqlite", + self.args.relay_database_path, self.args.blur_usage, signal_error=self.args.signal_error, - stats_file="stats.json", + stats_file=self.args.stats_json_path, allow_list=self.args.allow_list, ) diff --git a/src/wormhole/server/server.py b/src/wormhole/server/server.py index 39cab57..da94f21 100644 --- a/src/wormhole/server/server.py +++ b/src/wormhole/server/server.py @@ -38,6 +38,7 @@ class RelayServer(service.MultiService): service.MultiService.__init__(self) self._blur_usage = blur_usage self._allow_list = allow_list + self._db_url = db_url db = get_db(db_url) welcome = { diff --git a/src/wormhole/test/test_cli.py b/src/wormhole/test/test_cli.py index 90fdbe4..7466ce7 100644 --- a/src/wormhole/test/test_cli.py +++ b/src/wormhole/test/test_cli.py @@ -1172,6 +1172,18 @@ class Dispatch(unittest.TestCase): self.assertEqual(cfg.stderr.getvalue(), expected) +class FakeConfig(object): + no_daemon = True + blur_usage = True + advertise_version = u"fake.version.1" + transit = str('tcp:4321') + rendezvous = str('tcp:1234') + signal_error = True + allow_list = False + relay_database_path = "relay.sqlite" + stats_json_path = "stats.json" + + class Server(unittest.TestCase): def setUp(self): @@ -1183,15 +1195,6 @@ class Server(unittest.TestCase): self.assertEqual(0, result.exit_code) def test_server_plugin(self): - class FakeConfig(object): - no_daemon = True - blur_usage = True - advertise_version = u"fake.version.1" - transit = str('tcp:4321') - rendezvous = str('tcp:1234') - signal_error = True - allow_list = False - cfg = FakeConfig() plugin = MyPlugin(cfg) relay = plugin.makeService(None) @@ -1211,3 +1214,9 @@ class Server(unittest.TestCase): cfg = fake_start_reserver.mock_calls[0][1][0] MyPlugin(cfg).makeService(None) + def test_state_locations(self): + cfg = FakeConfig() + plugin = MyPlugin(cfg) + relay = plugin.makeService(None) + self.assertEqual('relay.sqlite', relay._db_url) + self.assertEqual('stats.json', relay._stats_file)