diff --git a/src/wormhole_transit_relay/server_tap.py b/src/wormhole_transit_relay/server_tap.py index 4caa5c9..42e986a 100644 --- a/src/wormhole_transit_relay/server_tap.py +++ b/src/wormhole_transit_relay/server_tap.py @@ -5,6 +5,10 @@ from twisted.application.service import MultiService from twisted.application.internet import (TimerService, StreamServerEndpointService) from twisted.internet import endpoints +from twisted.internet import protocol + +from autobahn.twisted.websocket import WebSocketServerFactory + from . import transit_server from .server_state import create_usage_tracker from .increase_rlimits import increase_rlimits @@ -33,7 +37,9 @@ class Options(usage.Options): def makeService(config, reactor=reactor): increase_rlimits() - ep = endpoints.serverFromString(reactor, config["port"]) # to listen + tcp_ep = endpoints.serverFromString(reactor, config["port"]) # to listen + # XXX FIXME proper websocket option + ws_ep = endpoints.serverFromString(reactor, "tcp:4002:interface=localhost") # to listen log_file = ( os.fdopen(int(config["log-fd"]), "w") if config["log-fd"] is not None @@ -45,9 +51,18 @@ def makeService(config, reactor=reactor): log_file=log_file, usage_db=db, ) - ##factory = transit_server.Transit(usage, reactor.seconds) - factory = transit_server.WebSocketTransit(usage, reactor.seconds) + transit = transit_server.Transit(usage, reactor.seconds) + tcp_factory = protocol.ServerFactory() + tcp_factory.protocol = transit_server.TransitConnection + + ws_factory = WebSocketServerFactory("ws://localhost:4002") # FIXME: url + ws_factory.protocol = transit_server.WebSocketTransitConnection + ws_factory.websocket_protocols = ["transit_relay"] + + tcp_factory.transit = transit + ws_factory.transit = transit parent = MultiService() - StreamServerEndpointService(ep, factory).setServiceParent(parent) - TimerService(5*60.0, factory.update_stats).setServiceParent(parent) + StreamServerEndpointService(tcp_ep, tcp_factory).setServiceParent(parent) + StreamServerEndpointService(ws_ep, ws_factory).setServiceParent(parent) + TimerService(5*60.0, transit.update_stats).setServiceParent(parent) return parent diff --git a/src/wormhole_transit_relay/transit_server.py b/src/wormhole_transit_relay/transit_server.py index af6760c..5ca92f1 100644 --- a/src/wormhole_transit_relay/transit_server.py +++ b/src/wormhole_transit_relay/transit_server.py @@ -4,10 +4,8 @@ import time from twisted.python import log from twisted.internet import protocol from twisted.protocols.basic import LineReceiver -from autobahn.twisted.websocket import ( - WebSocketServerProtocol, - WebSocketServerFactory, -) +from autobahn.twisted.websocket import WebSocketServerProtocol + @@ -69,8 +67,8 @@ class TransitConnection(LineReceiver): # (besides "use the global one") self.started_time = time.time() self._state = TransitServerState( - self.factory.pending_requests, - self.factory.usage, + self.factory.transit.pending_requests, + self.factory.transit.usage, ) self._state.connection_made(self) ## self._log_requests = self.factory._log_requests @@ -119,7 +117,7 @@ class TransitConnection(LineReceiver): # us self.transport.loseConnection() # XXX probably should be logged by state? - if self.factory._debug_log: + if self.factory.transit._debug_log: log.msg("transitFailed %r" % self) def disconnect_redundant(self): @@ -140,7 +138,9 @@ class TransitConnection(LineReceiver): # different for websocket versus "normal" socket .. so maybe we need # to make Transit *not* the factory directly?) -class Transit(WebSocketServerFactory):#protocol.ServerFactory): +##WebSocketServerFactory):#protocol.ServerFactory): + +class Transit(object): """ I manage pairs of simultaneous connections to a secondary TCP port, both forwarded to the other. Clients must begin each connection with @@ -176,10 +176,9 @@ class Transit(WebSocketServerFactory):#protocol.ServerFactory): MAXLENGTH = 10*MB # TODO: unused MAXTIME = 60*SECONDS - protocol = TransitConnection +## protocol = TransitConnection def __init__(self, usage, get_timestamp): - super(Transit, self).__init__() self.active_connections = ActiveConnections() self.pending_requests = PendingRequests(self.active_connections) self.usage = usage @@ -253,13 +252,13 @@ class WebSocketTransitConnection(WebSocketServerProtocol): IProtocol API """ print("connectionMade") + super(WebSocketTransitConnection, self).connectionMade() self.started_time = time.time() self._first_message = True self._state = TransitServerState( - self.factory.pending_requests, - self.factory.usage, + self.factory.transit.pending_requests, + self.factory.transit.usage, ) - return super(WebSocketTransitConnection, self).connectionMade() def onOpen(self): print("onOpen") @@ -296,8 +295,3 @@ class WebSocketTransitConnection(WebSocketServerProtocol): """ self._state.connection_lost() # XXX "transit finished", etc - - -class WebSocketTransit(Transit, WebSocketServerFactory): - protocol = WebSocketTransitConnection - websocket_protocols = ["transit_relay"] diff --git a/ws_client.py b/ws_client.py index a8dee88..03158a8 100644 --- a/ws_client.py +++ b/ws_client.py @@ -51,10 +51,8 @@ class RelayEchoClient(WebSocketClientProtocol): @react @inlineCallbacks def main(reactor): - #ep = endpoints.clientFromString(reactor, "ws://localhost:4001/") - from twisted.plugins.autobahn_endpoints import AutobahnClientEndpoint - ep = endpoints.clientFromString(reactor, "tcp:localhost:4001") - f = WebSocketClientFactory("ws://127.0.0.1:4001/") + ep = endpoints.clientFromString(reactor, "tcp:localhost:4002") + f = WebSocketClientFactory("ws://127.0.0.1:4002/") f.protocol = RelayEchoClient # NB: write our own factory, probably.. f.token = "a" * 64 @@ -72,3 +70,5 @@ def main(reactor): proto.sendMessage(b"it's a message", True) yield proto.sendClose() yield f.done + print("relayed {} bytes:".format(len(proto.data))) + print(proto.data.decode("utf8"))