diff --git a/src/wormhole_transit_relay/test/common.py b/src/wormhole_transit_relay/test/common.py index 5a3bbf6..8690e91 100644 --- a/src/wormhole_transit_relay/test/common.py +++ b/src/wormhole_transit_relay/test/common.py @@ -72,7 +72,6 @@ class ServerBase: usage_db=usage_db, ) self._transit_server = Transit(usage, lambda: 123456789.0) - self._transit_server._debug_log = self.log_requests def new_protocol(self): """ diff --git a/src/wormhole_transit_relay/test/test_transit_server.py b/src/wormhole_transit_relay/test/test_transit_server.py index 1854a17..9fee87d 100644 --- a/src/wormhole_transit_relay/test/test_transit_server.py +++ b/src/wormhole_transit_relay/test/test_transit_server.py @@ -1,11 +1,19 @@ from __future__ import print_function, unicode_literals +import base64 from binascii import hexlify from twisted.trial import unittest +from twisted.test import proto_helpers from .common import ServerBase from ..server_state import ( MemoryUsageRecorder, blur_size, ) +from ..transit_server import ( + WebSocketTransitConnection, +) + +from autobahn.twisted.websocket import WebSocketServerFactory + def handshake(token, side=None): hs = b"please relay " + hexlify(token) @@ -458,3 +466,49 @@ class Usage(ServerBase, unittest.TestCase): self.flush() self.assertEqual(len(self._usage.events), 3, self._usage) self.assertEqual(self._usage.events[2]["mood"], "happy") + + +class UsageWebSockets(Usage): + """ + All the tests of 'Usage' except with a WebSocket (instead of TCP) + transport. + + This overrides ServerBase.new_protocol to achieve this. It might + be nicer to parametrize these tests in a way that doesn't use + inheritance .. but all the support etc classes are set up that way + already. + """ + + def new_protocol(self): + ws_factory = WebSocketServerFactory("ws://localhost:4002") # FIXME: url + ws_factory.protocol = WebSocketTransitConnection + ws_factory.websocket_protocols = ["transit_relay"] + ws_factory.transit = self._transit + + protocol = ws_factory.buildProtocol(('127.0.0.1', 4002)) + transport = proto_helpers.StringTransportWithDisconnection() + protocol.makeConnection(transport) + transport.protocol = protocol + + class Producer: + pass + protocol.registerProducer(Producer(), False) +## protocol.transport.abortConnection = protocol.transport.loseConnection + + # unlike in the TCP case, we need to drive a WebSocket + # handshake through the server first. + options = {} + self._websocket_key = b"0" * 16 + request = ( + "GET /ws HTTP/1.1\x0d\x0a" + "Host: 127.0.0.1:4002\x0d\x0a" + "Upgrade: WebSocket\x0d\x0a" + "Connection: Upgrade\x0d\x0a" + "Sec-WebSocket-Key: {}\x0d\x0a" + "Sec-WebSocket-Protocol: transit-relay\x0d\x0a" + "Sec-WebSocket-Version: 13\x0d\x0a" + "\x0d\x0a" + ).format(base64.b64encode(self._websocket_key).decode()) + protocol.dataReceived(request.encode("utf8")) + + return protocol diff --git a/src/wormhole_transit_relay/transit_server.py b/src/wormhole_transit_relay/transit_server.py index 5ca92f1..9e5a419 100644 --- a/src/wormhole_transit_relay/transit_server.py +++ b/src/wormhole_transit_relay/transit_server.py @@ -289,6 +289,11 @@ class WebSocketTransitConnection(WebSocketServerProtocol): else: self._state.got_bytes(payload) + def disconnect_redundant(self): + # this is called if a buddy connected and we were found unnecessary. + # Any token-tracking cleanup will have been done before we're called. + self.transport.loseConnection() + def onClose(self, wasClean, code, reason): """ IWebSocketChannel API