diff --git a/src/wormhole/server/server.py b/src/wormhole/server/server.py index da94f21..ee2ba48 100644 --- a/src/wormhole/server/server.py +++ b/src/wormhole/server/server.py @@ -34,7 +34,8 @@ class RelayServer(service.MultiService): def __init__(self, rendezvous_web_port, transit_port, advertise_version, db_url=":memory:", blur_usage=None, - signal_error=None, stats_file=None, allow_list=True): + signal_error=None, stats_file=None, allow_list=True, + websocket_protocol_options=()): service.MultiService.__init__(self) self._blur_usage = blur_usage self._allow_list = allow_list @@ -64,6 +65,7 @@ class RelayServer(service.MultiService): root = Root() wsrf = WebSocketRendezvousFactory(None, self._rendezvous) + _set_options(websocket_protocol_options, wsrf) root.putChild(b"v1", WebSocketResource(wsrf)) site = PrivacyEnhancedSite(root) @@ -137,3 +139,7 @@ class RelayServer(service.MultiService): f.write(json.dumps(data, indent=1).encode("utf-8")) f.write(b"\n") os.rename(tmpfn, self._stats_file) + + +def _set_options(options, factory): + factory.setProtocolOptions(**dict(options)) diff --git a/src/wormhole/test/test_server.py b/src/wormhole/test/test_server.py index b940cc6..343ffa4 100644 --- a/src/wormhole/test/test_server.py +++ b/src/wormhole/test/test_server.py @@ -1360,3 +1360,17 @@ class Startup(unittest.TestCase): ) finally: rs.stopService() + + +class WebSocketProtocolOptions(unittest.TestCase): + @mock.patch('wormhole.server.server.WebSocketRendezvousFactory') + def test_set(self, fake_factory): + easy_relay( + websocket_protocol_options=[ + ("foo", "bar"), + ] + ) + self.assertEqual( + mock.call().setProtocolOptions(foo="bar"), + fake_factory.mock_calls[1], + )