From e7466a359505962783fc05002cab0c504c898406 Mon Sep 17 00:00:00 2001 From: meejah Date: Wed, 14 Apr 2021 16:46:46 -0600 Subject: [PATCH] add CLI options for WebSockets support --- src/wormhole_transit_relay/server_tap.py | 24 +++++++++++++++++++----- 1 file changed, 19 insertions(+), 5 deletions(-) diff --git a/src/wormhole_transit_relay/server_tap.py b/src/wormhole_transit_relay/server_tap.py index 3973617..a54082b 100644 --- a/src/wormhole_transit_relay/server_tap.py +++ b/src/wormhole_transit_relay/server_tap.py @@ -26,6 +26,8 @@ class Options(usage.Options): optParameters = [ ("port", "p", "tcp:4001:interface=\:\:", "endpoint to listen on"), + ("websocket", "w", None, "endpoint to listen for WebSocket connections"), + ("websocket-url", "u", None, "WebSocket URL (derived from endpoint if not provided)"), ("blur-usage", None, None, "blur timestamps and data sizes in logs"), ("log-fd", None, None, "write JSON usage logs to this file descriptor"), ("usage-db", None, None, "record usage data (SQLite)"), @@ -38,8 +40,11 @@ class Options(usage.Options): def makeService(config, reactor=reactor): increase_rlimits() tcp_ep = endpoints.serverFromString(reactor, config["port"]) # to listen - # XXX FIXME proper websocket option - ws_ep = endpoints.serverFromString(reactor, "tcp:4002") # to listen + ws_ep = ( + endpoints.serverFromString(reactor, config["websocket"]) + if config["websocket"] is not None + else None + ) log_file = ( os.fdopen(int(config["log-fd"]), "w") if config["log-fd"] is not None @@ -55,13 +60,22 @@ def makeService(config, reactor=reactor): tcp_factory = protocol.ServerFactory() tcp_factory.protocol = transit_server.TransitConnection - ws_factory = WebSocketServerFactory("ws://localhost:4002") # FIXME: url - ws_factory.protocol = transit_server.WebSocketTransitConnection + if ws_ep is not None: + ws_url = config["websocket-url"] + if ws_url is None: + # we're using a "private" attribute here but I don't see + # any useful alternative unless we also want to parse + # Twisted endpoint-strings. + ws_url = "ws://localhost:{}/".format(ws_ep._port) + print("Using WebSocket URL '{}'".format(ws_url)) + ws_factory = WebSocketServerFactory(ws_url) + ws_factory.protocol = transit_server.WebSocketTransitConnection tcp_factory.transit = transit ws_factory.transit = transit parent = MultiService() StreamServerEndpointService(tcp_ep, tcp_factory).setServiceParent(parent) - StreamServerEndpointService(ws_ep, ws_factory).setServiceParent(parent) + if ws_ep is not None: + StreamServerEndpointService(ws_ep, ws_factory).setServiceParent(parent) TimerService(5*60.0, transit.update_stats).setServiceParent(parent) return parent