websocket version of tests, with handshake
This commit is contained in:
parent
21af1f68a3
commit
5210566150
|
@ -72,7 +72,6 @@ class ServerBase:
|
|||
usage_db=usage_db,
|
||||
)
|
||||
self._transit_server = Transit(usage, lambda: 123456789.0)
|
||||
self._transit_server._debug_log = self.log_requests
|
||||
|
||||
def new_protocol(self):
|
||||
"""
|
||||
|
|
|
@ -1,11 +1,19 @@
|
|||
from __future__ import print_function, unicode_literals
|
||||
import base64
|
||||
from binascii import hexlify
|
||||
from twisted.trial import unittest
|
||||
from twisted.test import proto_helpers
|
||||
from .common import ServerBase
|
||||
from ..server_state import (
|
||||
MemoryUsageRecorder,
|
||||
blur_size,
|
||||
)
|
||||
from ..transit_server import (
|
||||
WebSocketTransitConnection,
|
||||
)
|
||||
|
||||
from autobahn.twisted.websocket import WebSocketServerFactory
|
||||
|
||||
|
||||
def handshake(token, side=None):
|
||||
hs = b"please relay " + hexlify(token)
|
||||
|
@ -458,3 +466,49 @@ class Usage(ServerBase, unittest.TestCase):
|
|||
self.flush()
|
||||
self.assertEqual(len(self._usage.events), 3, self._usage)
|
||||
self.assertEqual(self._usage.events[2]["mood"], "happy")
|
||||
|
||||
|
||||
class UsageWebSockets(Usage):
|
||||
"""
|
||||
All the tests of 'Usage' except with a WebSocket (instead of TCP)
|
||||
transport.
|
||||
|
||||
This overrides ServerBase.new_protocol to achieve this. It might
|
||||
be nicer to parametrize these tests in a way that doesn't use
|
||||
inheritance .. but all the support etc classes are set up that way
|
||||
already.
|
||||
"""
|
||||
|
||||
def new_protocol(self):
|
||||
ws_factory = WebSocketServerFactory("ws://localhost:4002") # FIXME: url
|
||||
ws_factory.protocol = WebSocketTransitConnection
|
||||
ws_factory.websocket_protocols = ["transit_relay"]
|
||||
ws_factory.transit = self._transit
|
||||
|
||||
protocol = ws_factory.buildProtocol(('127.0.0.1', 4002))
|
||||
transport = proto_helpers.StringTransportWithDisconnection()
|
||||
protocol.makeConnection(transport)
|
||||
transport.protocol = protocol
|
||||
|
||||
class Producer:
|
||||
pass
|
||||
protocol.registerProducer(Producer(), False)
|
||||
## protocol.transport.abortConnection = protocol.transport.loseConnection
|
||||
|
||||
# unlike in the TCP case, we need to drive a WebSocket
|
||||
# handshake through the server first.
|
||||
options = {}
|
||||
self._websocket_key = b"0" * 16
|
||||
request = (
|
||||
"GET /ws HTTP/1.1\x0d\x0a"
|
||||
"Host: 127.0.0.1:4002\x0d\x0a"
|
||||
"Upgrade: WebSocket\x0d\x0a"
|
||||
"Connection: Upgrade\x0d\x0a"
|
||||
"Sec-WebSocket-Key: {}\x0d\x0a"
|
||||
"Sec-WebSocket-Protocol: transit-relay\x0d\x0a"
|
||||
"Sec-WebSocket-Version: 13\x0d\x0a"
|
||||
"\x0d\x0a"
|
||||
).format(base64.b64encode(self._websocket_key).decode())
|
||||
protocol.dataReceived(request.encode("utf8"))
|
||||
|
||||
return protocol
|
||||
|
|
|
@ -289,6 +289,11 @@ class WebSocketTransitConnection(WebSocketServerProtocol):
|
|||
else:
|
||||
self._state.got_bytes(payload)
|
||||
|
||||
def disconnect_redundant(self):
|
||||
# this is called if a buddy connected and we were found unnecessary.
|
||||
# Any token-tracking cleanup will have been done before we're called.
|
||||
self.transport.loseConnection()
|
||||
|
||||
def onClose(self, wasClean, code, reason):
|
||||
"""
|
||||
IWebSocketChannel API
|
||||
|
|
Loading…
Reference in New Issue
Block a user