websocket version of tests, with handshake

This commit is contained in:
meejah 2021-03-04 23:30:59 -07:00
parent 21af1f68a3
commit 5210566150
3 changed files with 59 additions and 1 deletions

View File

@ -72,7 +72,6 @@ class ServerBase:
usage_db=usage_db, usage_db=usage_db,
) )
self._transit_server = Transit(usage, lambda: 123456789.0) self._transit_server = Transit(usage, lambda: 123456789.0)
self._transit_server._debug_log = self.log_requests
def new_protocol(self): def new_protocol(self):
""" """

View File

@ -1,11 +1,19 @@
from __future__ import print_function, unicode_literals from __future__ import print_function, unicode_literals
import base64
from binascii import hexlify from binascii import hexlify
from twisted.trial import unittest from twisted.trial import unittest
from twisted.test import proto_helpers
from .common import ServerBase from .common import ServerBase
from ..server_state import ( from ..server_state import (
MemoryUsageRecorder, MemoryUsageRecorder,
blur_size, blur_size,
) )
from ..transit_server import (
WebSocketTransitConnection,
)
from autobahn.twisted.websocket import WebSocketServerFactory
def handshake(token, side=None): def handshake(token, side=None):
hs = b"please relay " + hexlify(token) hs = b"please relay " + hexlify(token)
@ -458,3 +466,49 @@ class Usage(ServerBase, unittest.TestCase):
self.flush() self.flush()
self.assertEqual(len(self._usage.events), 3, self._usage) self.assertEqual(len(self._usage.events), 3, self._usage)
self.assertEqual(self._usage.events[2]["mood"], "happy") self.assertEqual(self._usage.events[2]["mood"], "happy")
class UsageWebSockets(Usage):
"""
All the tests of 'Usage' except with a WebSocket (instead of TCP)
transport.
This overrides ServerBase.new_protocol to achieve this. It might
be nicer to parametrize these tests in a way that doesn't use
inheritance .. but all the support etc classes are set up that way
already.
"""
def new_protocol(self):
ws_factory = WebSocketServerFactory("ws://localhost:4002") # FIXME: url
ws_factory.protocol = WebSocketTransitConnection
ws_factory.websocket_protocols = ["transit_relay"]
ws_factory.transit = self._transit
protocol = ws_factory.buildProtocol(('127.0.0.1', 4002))
transport = proto_helpers.StringTransportWithDisconnection()
protocol.makeConnection(transport)
transport.protocol = protocol
class Producer:
pass
protocol.registerProducer(Producer(), False)
## protocol.transport.abortConnection = protocol.transport.loseConnection
# unlike in the TCP case, we need to drive a WebSocket
# handshake through the server first.
options = {}
self._websocket_key = b"0" * 16
request = (
"GET /ws HTTP/1.1\x0d\x0a"
"Host: 127.0.0.1:4002\x0d\x0a"
"Upgrade: WebSocket\x0d\x0a"
"Connection: Upgrade\x0d\x0a"
"Sec-WebSocket-Key: {}\x0d\x0a"
"Sec-WebSocket-Protocol: transit-relay\x0d\x0a"
"Sec-WebSocket-Version: 13\x0d\x0a"
"\x0d\x0a"
).format(base64.b64encode(self._websocket_key).decode())
protocol.dataReceived(request.encode("utf8"))
return protocol

View File

@ -289,6 +289,11 @@ class WebSocketTransitConnection(WebSocketServerProtocol):
else: else:
self._state.got_bytes(payload) self._state.got_bytes(payload)
def disconnect_redundant(self):
# this is called if a buddy connected and we were found unnecessary.
# Any token-tracking cleanup will have been done before we're called.
self.transport.loseConnection()
def onClose(self, wasClean, code, reason): def onClose(self, wasClean, code, reason):
""" """
IWebSocketChannel API IWebSocketChannel API