From 34d039c38cd28d905a418559ae5ca0459b0de4f6 Mon Sep 17 00:00:00 2001 From: meejah Date: Tue, 23 Feb 2021 13:31:55 -0700 Subject: [PATCH] hack in prelim websocket support --- src/wormhole_transit_relay/server_state.py | 8 ++ src/wormhole_transit_relay/server_tap.py | 3 +- src/wormhole_transit_relay/transit_server.py | 110 ++++++++++++++++++- 3 files changed, 119 insertions(+), 2 deletions(-) diff --git a/src/wormhole_transit_relay/server_state.py b/src/wormhole_transit_relay/server_state.py index a773f87..86cf9b6 100644 --- a/src/wormhole_transit_relay/server_state.py +++ b/src/wormhole_transit_relay/server_state.py @@ -612,11 +612,19 @@ class TransitServerState(object): Terminal state """ + # need a listening.upon(connection_lost) for special websocket + # case where handshake fails? + listening.upon( connection_made, enter=wait_relay, outputs=[_remember_client], ) + listening.upon( + connection_lost, + enter=done, + outputs=[_mood_errory], + ) wait_relay.upon( please_relay, diff --git a/src/wormhole_transit_relay/server_tap.py b/src/wormhole_transit_relay/server_tap.py index b4028d2..4caa5c9 100644 --- a/src/wormhole_transit_relay/server_tap.py +++ b/src/wormhole_transit_relay/server_tap.py @@ -45,7 +45,8 @@ def makeService(config, reactor=reactor): log_file=log_file, usage_db=db, ) - factory = transit_server.Transit(usage, reactor.seconds) + ##factory = transit_server.Transit(usage, reactor.seconds) + factory = transit_server.WebSocketTransit(usage, reactor.seconds) parent = MultiService() StreamServerEndpointService(ep, factory).setServiceParent(parent) TimerService(5*60.0, factory.update_stats).setServiceParent(parent) diff --git a/src/wormhole_transit_relay/transit_server.py b/src/wormhole_transit_relay/transit_server.py index d9dabfd..af6760c 100644 --- a/src/wormhole_transit_relay/transit_server.py +++ b/src/wormhole_transit_relay/transit_server.py @@ -4,6 +4,12 @@ import time from twisted.python import log from twisted.internet import protocol from twisted.protocols.basic import LineReceiver +from autobahn.twisted.websocket import ( + WebSocketServerProtocol, + WebSocketServerFactory, +) + + SECONDS = 1.0 MINUTE = 60*SECONDS @@ -129,8 +135,12 @@ class TransitConnection(LineReceiver): # XXX describeToken -> self._state.get_token() +# XXX multiple-inheritance sucks... +# ("Transit" wants to be "the factory" but the base class is slightly +# different for websocket versus "normal" socket .. so maybe we need +# to make Transit *not* the factory directly?) -class Transit(protocol.ServerFactory): +class Transit(WebSocketServerFactory):#protocol.ServerFactory): """ I manage pairs of simultaneous connections to a secondary TCP port, both forwarded to the other. Clients must begin each connection with @@ -169,6 +179,7 @@ class Transit(protocol.ServerFactory): protocol = TransitConnection def __init__(self, usage, get_timestamp): + super(Transit, self).__init__() self.active_connections = ActiveConnections() self.pending_requests = PendingRequests(self.active_connections) self.usage = usage @@ -193,3 +204,100 @@ class Transit(protocol.ServerFactory): for tc in self.active_connections._connections ), ) + + +@implementer(ITransitClient) +class WebSocketTransitConnection(WebSocketServerProtocol): + started_time = None + + def send(self, data): + """ + ITransitClient API + """ + self.sendMessage(data, isBinary=True) + + def disconnect(self): + """ + ITransitClient API + """ + self.sendClose(1000, None) + + def connect_partner(self, other): + """ + ITransitClient API + """ + self._buddy = other + + def disconnect_partner(self): + """ + ITransitClient API + """ + if self._buddy is not None: + log.msg("buddy_disconnected {}".format(self._buddy.get_token())) + self._buddy._client.transport.loseConnection() + self._buddy = None + + def onConnect(self, request): + """ + IWebSocketChannel API + """ + print("onConnect: {}".format(request)) + # ideally more like self._reactor.seconds() ... but Twisted + # doesn't have a good way to get the reactor for a protocol + # (besides "use the global one") + print("protocols: {}".format(request.protocols)) + return None#"transit_relay" + + def connectionMade(self): + """ + IProtocol API + """ + print("connectionMade") + self.started_time = time.time() + self._first_message = True + self._state = TransitServerState( + self.factory.pending_requests, + self.factory.usage, + ) + return super(WebSocketTransitConnection, self).connectionMade() + + def onOpen(self): + print("onOpen") + self._state.connection_made(self) + + def onMessage(self, payload, isBinary): + """ + We may have a 'handshake' on our hands or we may just have some bytes to relay + """ + print("onMessage isBinary={}: {}".format(isBinary, payload)) + if self._first_message: + self._first_message = False + token = None + old = re.search(br"^please relay (\w{64})$", payload) + if old: + token = old.group(1) + self._state.please_relay(token) + + # new: "please relay {64} for side {16}\n" + new = re.search(br"^please relay (\w{64}) for side (\w{16})$", payload) + if new: + token = new.group(1) + side = new.group(2) + self._state.please_relay_for_side(token, side) + + if token is None: + self._state.bad_token() + else: + self._state.got_bytes(payload) + + def onClose(self, wasClean, code, reason): + """ + IWebSocketChannel API + """ + self._state.connection_lost() + # XXX "transit finished", etc + + +class WebSocketTransit(Transit, WebSocketServerFactory): + protocol = WebSocketTransitConnection + websocket_protocols = ["transit_relay"]