diff --git a/src/wormhole/server/cli.py b/src/wormhole/server/cli.py index 59dbd9a..9596dbd 100644 --- a/src/wormhole/server/cli.py +++ b/src/wormhole/server/cli.py @@ -1,4 +1,5 @@ from __future__ import print_function +import json import click from ..cli.cli import Config, _compose @@ -19,6 +20,21 @@ def server(ctx): # this is the setuptools entrypoint for bin/wormhole-server # server commands don't use ctx.obj = Config() +def _validate_websocket_protocol_options(ctx, param, value): + return list(_validate_websocket_protocol_option(option) for option in value) + +def _validate_websocket_protocol_option(option): + try: + key, value = option.split("=", 1) + except ValueError: + raise click.BadParameter("format options as OPTION=VALUE") + + try: + value = json.loads(value) + except: + raise click.BadParameter("could not parse JSON value for {}".format(key)) + + return (key, value) LaunchArgs = _compose( click.option( @@ -58,6 +74,11 @@ LaunchArgs = _compose( "--stats-json-path", default="stats.json", metavar="PATH", help="location to write the relay stats file", ), + click.option( + "--websocket-protocol-option", multiple=True, metavar="OPTION=VALUE", + callback=_validate_websocket_protocol_options, + help="a websocket server protocol option to configure", + ), ) diff --git a/src/wormhole/server/server.py b/src/wormhole/server/server.py index da94f21..ee2ba48 100644 --- a/src/wormhole/server/server.py +++ b/src/wormhole/server/server.py @@ -34,7 +34,8 @@ class RelayServer(service.MultiService): def __init__(self, rendezvous_web_port, transit_port, advertise_version, db_url=":memory:", blur_usage=None, - signal_error=None, stats_file=None, allow_list=True): + signal_error=None, stats_file=None, allow_list=True, + websocket_protocol_options=()): service.MultiService.__init__(self) self._blur_usage = blur_usage self._allow_list = allow_list @@ -64,6 +65,7 @@ class RelayServer(service.MultiService): root = Root() wsrf = WebSocketRendezvousFactory(None, self._rendezvous) + _set_options(websocket_protocol_options, wsrf) root.putChild(b"v1", WebSocketResource(wsrf)) site = PrivacyEnhancedSite(root) @@ -137,3 +139,7 @@ class RelayServer(service.MultiService): f.write(json.dumps(data, indent=1).encode("utf-8")) f.write(b"\n") os.rename(tmpfn, self._stats_file) + + +def _set_options(options, factory): + factory.setProtocolOptions(**dict(options)) diff --git a/src/wormhole/test/test_cli.py b/src/wormhole/test/test_cli.py index ba532bb..8c3a0fa 100644 --- a/src/wormhole/test/test_cli.py +++ b/src/wormhole/test/test_cli.py @@ -1238,3 +1238,48 @@ class Server(unittest.TestCase): relay = plugin.makeService(None) self.assertEqual('relay.sqlite', relay._db_url) self.assertEqual('stats.json', relay._stats_file) + + @mock.patch("wormhole.server.cmd_server.start_server") + def test_websocket_protocol_options(self, fake_start_server): + result = self.runner.invoke( + server, [ + 'start', + '--websocket-protocol-option=a=3', + '--websocket-protocol-option=b=true', + '--websocket-protocol-option=c=3.5', + '--websocket-protocol-option=d=["foo","bar"]', + '--websocket-protocol-option', 'e=["foof","barf"]', + ]) + self.assertEqual(0, result.exit_code) + cfg = fake_start_server.mock_calls[0][1][0] + self.assertEqual( + cfg.websocket_protocol_option, + [("a", 3), ("b", True), ("c", 3.5), ("d", ['foo', 'bar']), + ("e", ['foof', 'barf']), + ], + ) + + def test_broken_websocket_protocol_options(self): + result = self.runner.invoke( + server, [ + 'start', + '--websocket-protocol-option=a', + ]) + self.assertNotEqual(0, result.exit_code) + self.assertIn( + 'Error: Invalid value for "--websocket-protocol-option": ' + 'format options as OPTION=VALUE', + result.output, + ) + + result = self.runner.invoke( + server, [ + 'start', + '--websocket-protocol-option=a=foo', + ]) + self.assertNotEqual(0, result.exit_code) + self.assertIn( + 'Error: Invalid value for "--websocket-protocol-option": ' + 'could not parse JSON value for a', + result.output, + ) diff --git a/src/wormhole/test/test_server.py b/src/wormhole/test/test_server.py index 514d573..343ffa4 100644 --- a/src/wormhole/test/test_server.py +++ b/src/wormhole/test/test_server.py @@ -11,6 +11,19 @@ from ..server import server, rendezvous from ..server.rendezvous import Usage, SidedMessage from ..server.database import get_db +def easy_relay( + rendezvous_web_port=str("tcp:0"), + transit_port=str("tcp:0"), + advertise_version=None, + **kwargs +): + return server.RelayServer( + rendezvous_web_port, + transit_port, + advertise_version, + **kwargs + ) + class _Util: def _nameplate(self, app, name): np_row = app._db.execute("SELECT * FROM `nameplates`" @@ -1313,7 +1326,7 @@ class Summary(unittest.TestCase): class DumpStats(unittest.TestCase): def test_nostats(self): - rs = server.RelayServer(str("tcp:0"), str("tcp:0"), None) + rs = easy_relay() # with no ._stats_file, this should do nothing rs.dump_stats(1, 1) @@ -1321,8 +1334,7 @@ class DumpStats(unittest.TestCase): basedir = self.mktemp() os.mkdir(basedir) fn = os.path.join(basedir, "stats.json") - rs = server.RelayServer(str("tcp:0"), str("tcp:0"), None, - stats_file=fn) + rs = easy_relay(stats_file=fn) now = 1234 validity = 500 rs.dump_stats(now, validity) @@ -1339,12 +1351,7 @@ class Startup(unittest.TestCase): @mock.patch('wormhole.server.server.log') def test_empty(self, fake_log): - rs = server.RelayServer( - str("tcp:0"), - str("tcp:0"), - None, - allow_list=False, - ) + rs = easy_relay(allow_list=False) rs.startService() try: logs = '\n'.join([call[1][0] for call in fake_log.mock_calls]) @@ -1353,3 +1360,17 @@ class Startup(unittest.TestCase): ) finally: rs.stopService() + + +class WebSocketProtocolOptions(unittest.TestCase): + @mock.patch('wormhole.server.server.WebSocketRendezvousFactory') + def test_set(self, fake_factory): + easy_relay( + websocket_protocol_options=[ + ("foo", "bar"), + ] + ) + self.assertEqual( + mock.call().setProtocolOptions(foo="bar"), + fake_factory.mock_calls[1], + )