hack in prelim websocket support

This commit is contained in:
meejah 2021-02-23 13:31:55 -07:00
parent c2147ee985
commit 34d039c38c
3 changed files with 119 additions and 2 deletions

View File

@ -612,11 +612,19 @@ class TransitServerState(object):
Terminal state
"""
# need a listening.upon(connection_lost) for special websocket
# case where handshake fails?
listening.upon(
connection_made,
enter=wait_relay,
outputs=[_remember_client],
)
listening.upon(
connection_lost,
enter=done,
outputs=[_mood_errory],
)
wait_relay.upon(
please_relay,

View File

@ -45,7 +45,8 @@ def makeService(config, reactor=reactor):
log_file=log_file,
usage_db=db,
)
factory = transit_server.Transit(usage, reactor.seconds)
##factory = transit_server.Transit(usage, reactor.seconds)
factory = transit_server.WebSocketTransit(usage, reactor.seconds)
parent = MultiService()
StreamServerEndpointService(ep, factory).setServiceParent(parent)
TimerService(5*60.0, factory.update_stats).setServiceParent(parent)

View File

@ -4,6 +4,12 @@ import time
from twisted.python import log
from twisted.internet import protocol
from twisted.protocols.basic import LineReceiver
from autobahn.twisted.websocket import (
WebSocketServerProtocol,
WebSocketServerFactory,
)
SECONDS = 1.0
MINUTE = 60*SECONDS
@ -129,8 +135,12 @@ class TransitConnection(LineReceiver):
# XXX describeToken -> self._state.get_token()
# XXX multiple-inheritance sucks...
# ("Transit" wants to be "the factory" but the base class is slightly
# different for websocket versus "normal" socket .. so maybe we need
# to make Transit *not* the factory directly?)
class Transit(protocol.ServerFactory):
class Transit(WebSocketServerFactory):#protocol.ServerFactory):
"""
I manage pairs of simultaneous connections to a secondary TCP port,
both forwarded to the other. Clients must begin each connection with
@ -169,6 +179,7 @@ class Transit(protocol.ServerFactory):
protocol = TransitConnection
def __init__(self, usage, get_timestamp):
super(Transit, self).__init__()
self.active_connections = ActiveConnections()
self.pending_requests = PendingRequests(self.active_connections)
self.usage = usage
@ -193,3 +204,100 @@ class Transit(protocol.ServerFactory):
for tc in self.active_connections._connections
),
)
@implementer(ITransitClient)
class WebSocketTransitConnection(WebSocketServerProtocol):
started_time = None
def send(self, data):
"""
ITransitClient API
"""
self.sendMessage(data, isBinary=True)
def disconnect(self):
"""
ITransitClient API
"""
self.sendClose(1000, None)
def connect_partner(self, other):
"""
ITransitClient API
"""
self._buddy = other
def disconnect_partner(self):
"""
ITransitClient API
"""
if self._buddy is not None:
log.msg("buddy_disconnected {}".format(self._buddy.get_token()))
self._buddy._client.transport.loseConnection()
self._buddy = None
def onConnect(self, request):
"""
IWebSocketChannel API
"""
print("onConnect: {}".format(request))
# ideally more like self._reactor.seconds() ... but Twisted
# doesn't have a good way to get the reactor for a protocol
# (besides "use the global one")
print("protocols: {}".format(request.protocols))
return None#"transit_relay"
def connectionMade(self):
"""
IProtocol API
"""
print("connectionMade")
self.started_time = time.time()
self._first_message = True
self._state = TransitServerState(
self.factory.pending_requests,
self.factory.usage,
)
return super(WebSocketTransitConnection, self).connectionMade()
def onOpen(self):
print("onOpen")
self._state.connection_made(self)
def onMessage(self, payload, isBinary):
"""
We may have a 'handshake' on our hands or we may just have some bytes to relay
"""
print("onMessage isBinary={}: {}".format(isBinary, payload))
if self._first_message:
self._first_message = False
token = None
old = re.search(br"^please relay (\w{64})$", payload)
if old:
token = old.group(1)
self._state.please_relay(token)
# new: "please relay {64} for side {16}\n"
new = re.search(br"^please relay (\w{64}) for side (\w{16})$", payload)
if new:
token = new.group(1)
side = new.group(2)
self._state.please_relay_for_side(token, side)
if token is None:
self._state.bad_token()
else:
self._state.got_bytes(payload)
def onClose(self, wasClean, code, reason):
"""
IWebSocketChannel API
"""
self._state.connection_lost()
# XXX "transit finished", etc
class WebSocketTransit(Transit, WebSocketServerFactory):
protocol = WebSocketTransitConnection
websocket_protocols = ["transit_relay"]