This commit is contained in:
meejah 2021-02-23 13:47:22 -07:00
parent 2b78fbec8f
commit 1a461aa461
3 changed files with 36 additions and 27 deletions

View File

@ -5,6 +5,10 @@ from twisted.application.service import MultiService
from twisted.application.internet import (TimerService, from twisted.application.internet import (TimerService,
StreamServerEndpointService) StreamServerEndpointService)
from twisted.internet import endpoints from twisted.internet import endpoints
from twisted.internet import protocol
from autobahn.twisted.websocket import WebSocketServerFactory
from . import transit_server from . import transit_server
from .server_state import create_usage_tracker from .server_state import create_usage_tracker
from .increase_rlimits import increase_rlimits from .increase_rlimits import increase_rlimits
@ -33,7 +37,9 @@ class Options(usage.Options):
def makeService(config, reactor=reactor): def makeService(config, reactor=reactor):
increase_rlimits() increase_rlimits()
ep = endpoints.serverFromString(reactor, config["port"]) # to listen tcp_ep = endpoints.serverFromString(reactor, config["port"]) # to listen
# XXX FIXME proper websocket option
ws_ep = endpoints.serverFromString(reactor, "tcp:4002:interface=localhost") # to listen
log_file = ( log_file = (
os.fdopen(int(config["log-fd"]), "w") os.fdopen(int(config["log-fd"]), "w")
if config["log-fd"] is not None if config["log-fd"] is not None
@ -45,9 +51,18 @@ def makeService(config, reactor=reactor):
log_file=log_file, log_file=log_file,
usage_db=db, usage_db=db,
) )
##factory = transit_server.Transit(usage, reactor.seconds) transit = transit_server.Transit(usage, reactor.seconds)
factory = transit_server.WebSocketTransit(usage, reactor.seconds) tcp_factory = protocol.ServerFactory()
tcp_factory.protocol = transit_server.TransitConnection
ws_factory = WebSocketServerFactory("ws://localhost:4002") # FIXME: url
ws_factory.protocol = transit_server.WebSocketTransitConnection
ws_factory.websocket_protocols = ["transit_relay"]
tcp_factory.transit = transit
ws_factory.transit = transit
parent = MultiService() parent = MultiService()
StreamServerEndpointService(ep, factory).setServiceParent(parent) StreamServerEndpointService(tcp_ep, tcp_factory).setServiceParent(parent)
TimerService(5*60.0, factory.update_stats).setServiceParent(parent) StreamServerEndpointService(ws_ep, ws_factory).setServiceParent(parent)
TimerService(5*60.0, transit.update_stats).setServiceParent(parent)
return parent return parent

View File

@ -4,10 +4,8 @@ import time
from twisted.python import log from twisted.python import log
from twisted.internet import protocol from twisted.internet import protocol
from twisted.protocols.basic import LineReceiver from twisted.protocols.basic import LineReceiver
from autobahn.twisted.websocket import ( from autobahn.twisted.websocket import WebSocketServerProtocol
WebSocketServerProtocol,
WebSocketServerFactory,
)
@ -69,8 +67,8 @@ class TransitConnection(LineReceiver):
# (besides "use the global one") # (besides "use the global one")
self.started_time = time.time() self.started_time = time.time()
self._state = TransitServerState( self._state = TransitServerState(
self.factory.pending_requests, self.factory.transit.pending_requests,
self.factory.usage, self.factory.transit.usage,
) )
self._state.connection_made(self) self._state.connection_made(self)
## self._log_requests = self.factory._log_requests ## self._log_requests = self.factory._log_requests
@ -119,7 +117,7 @@ class TransitConnection(LineReceiver):
# us # us
self.transport.loseConnection() self.transport.loseConnection()
# XXX probably should be logged by state? # XXX probably should be logged by state?
if self.factory._debug_log: if self.factory.transit._debug_log:
log.msg("transitFailed %r" % self) log.msg("transitFailed %r" % self)
def disconnect_redundant(self): def disconnect_redundant(self):
@ -140,7 +138,9 @@ class TransitConnection(LineReceiver):
# different for websocket versus "normal" socket .. so maybe we need # different for websocket versus "normal" socket .. so maybe we need
# to make Transit *not* the factory directly?) # to make Transit *not* the factory directly?)
class Transit(WebSocketServerFactory):#protocol.ServerFactory): ##WebSocketServerFactory):#protocol.ServerFactory):
class Transit(object):
""" """
I manage pairs of simultaneous connections to a secondary TCP port, I manage pairs of simultaneous connections to a secondary TCP port,
both forwarded to the other. Clients must begin each connection with both forwarded to the other. Clients must begin each connection with
@ -176,10 +176,9 @@ class Transit(WebSocketServerFactory):#protocol.ServerFactory):
MAXLENGTH = 10*MB MAXLENGTH = 10*MB
# TODO: unused # TODO: unused
MAXTIME = 60*SECONDS MAXTIME = 60*SECONDS
protocol = TransitConnection ## protocol = TransitConnection
def __init__(self, usage, get_timestamp): def __init__(self, usage, get_timestamp):
super(Transit, self).__init__()
self.active_connections = ActiveConnections() self.active_connections = ActiveConnections()
self.pending_requests = PendingRequests(self.active_connections) self.pending_requests = PendingRequests(self.active_connections)
self.usage = usage self.usage = usage
@ -253,13 +252,13 @@ class WebSocketTransitConnection(WebSocketServerProtocol):
IProtocol API IProtocol API
""" """
print("connectionMade") print("connectionMade")
super(WebSocketTransitConnection, self).connectionMade()
self.started_time = time.time() self.started_time = time.time()
self._first_message = True self._first_message = True
self._state = TransitServerState( self._state = TransitServerState(
self.factory.pending_requests, self.factory.transit.pending_requests,
self.factory.usage, self.factory.transit.usage,
) )
return super(WebSocketTransitConnection, self).connectionMade()
def onOpen(self): def onOpen(self):
print("onOpen") print("onOpen")
@ -296,8 +295,3 @@ class WebSocketTransitConnection(WebSocketServerProtocol):
""" """
self._state.connection_lost() self._state.connection_lost()
# XXX "transit finished", etc # XXX "transit finished", etc
class WebSocketTransit(Transit, WebSocketServerFactory):
protocol = WebSocketTransitConnection
websocket_protocols = ["transit_relay"]

View File

@ -51,10 +51,8 @@ class RelayEchoClient(WebSocketClientProtocol):
@react @react
@inlineCallbacks @inlineCallbacks
def main(reactor): def main(reactor):
#ep = endpoints.clientFromString(reactor, "ws://localhost:4001/") ep = endpoints.clientFromString(reactor, "tcp:localhost:4002")
from twisted.plugins.autobahn_endpoints import AutobahnClientEndpoint f = WebSocketClientFactory("ws://127.0.0.1:4002/")
ep = endpoints.clientFromString(reactor, "tcp:localhost:4001")
f = WebSocketClientFactory("ws://127.0.0.1:4001/")
f.protocol = RelayEchoClient f.protocol = RelayEchoClient
# NB: write our own factory, probably.. # NB: write our own factory, probably..
f.token = "a" * 64 f.token = "a" * 64
@ -72,3 +70,5 @@ def main(reactor):
proto.sendMessage(b"it's a message", True) proto.sendMessage(b"it's a message", True)
yield proto.sendClose() yield proto.sendClose()
yield f.done yield f.done
print("relayed {} bytes:".format(len(proto.data)))
print(proto.data.decode("utf8"))