From 002773d79fe431bed99f99fe1777bd4629f9f04c Mon Sep 17 00:00:00 2001 From: meejah Date: Wed, 31 Mar 2021 19:47:29 -0600 Subject: [PATCH] WIP: first passing IOPump test --- .../test/test_transit_server.py | 66 +++++++++++-------- src/wormhole_transit_relay/transit_server.py | 1 + 2 files changed, 38 insertions(+), 29 deletions(-) diff --git a/src/wormhole_transit_relay/test/test_transit_server.py b/src/wormhole_transit_relay/test/test_transit_server.py index 9fee87d..6101501 100644 --- a/src/wormhole_transit_relay/test/test_transit_server.py +++ b/src/wormhole_transit_relay/test/test_transit_server.py @@ -3,6 +3,7 @@ import base64 from binascii import hexlify from twisted.trial import unittest from twisted.test import proto_helpers +from twisted.internet.defer import inlineCallbacks from .common import ServerBase from ..server_state import ( MemoryUsageRecorder, @@ -367,8 +368,9 @@ class Usage(ServerBase, unittest.TestCase): self.assertEqual(len(self._usage.events), 1, self._usage) self.assertEqual(self._usage.events[0]["mood"], "empty", self._usage) + @inlineCallbacks def test_short(self): - p1 = self.new_protocol() + p1 = yield self.new_protocol() # hang up before sending a complete handshake p1.send(b"short") p1.disconnect() @@ -451,6 +453,7 @@ class Usage(ServerBase, unittest.TestCase): p1c.disconnect() self.flush() + print(self._usage.events) self.assertEqual(len(self._usage.events), 1, self._usage) self.assertEqual(self._usage.events[0]["mood"], "lonely") @@ -468,6 +471,16 @@ class Usage(ServerBase, unittest.TestCase): self.assertEqual(self._usage.events[2]["mood"], "happy") +from twisted.test import iosim +from twisted.internet.testing import MemoryReactorClock +from twisted.internet.address import IPv4Address +from autobahn.twisted.testing import ( + create_pumper, + create_memory_agent, + MemoryReactorClockResolver, +) + + class UsageWebSockets(Usage): """ All the tests of 'Usage' except with a WebSocket (instead of TCP) @@ -479,36 +492,31 @@ class UsageWebSockets(Usage): already. """ + def setUp(self): + super(UsageWebSockets, self).setUp() + self._pump = create_pumper() + self._reactor = MemoryReactorClockResolver() + return self._pump.start() + + def tearDown(self): + return self._pump.stop() + + @inlineCallbacks def new_protocol(self): - ws_factory = WebSocketServerFactory("ws://localhost:4002") # FIXME: url - ws_factory.protocol = WebSocketTransitConnection - ws_factory.websocket_protocols = ["transit_relay"] - ws_factory.transit = self._transit - protocol = ws_factory.buildProtocol(('127.0.0.1', 4002)) - transport = proto_helpers.StringTransportWithDisconnection() - protocol.makeConnection(transport) - transport.protocol = protocol + class RelayFactory(WebSocketServerFactory): + protocol = WebSocketTransitConnection + websocket_protocols = ["transit_relay"] + transit = self._transit - class Producer: - pass - protocol.registerProducer(Producer(), False) -## protocol.transport.abortConnection = protocol.transport.loseConnection + server_factory = RelayFactory("ws://localhost:4002") - # unlike in the TCP case, we need to drive a WebSocket - # handshake through the server first. - options = {} - self._websocket_key = b"0" * 16 - request = ( - "GET /ws HTTP/1.1\x0d\x0a" - "Host: 127.0.0.1:4002\x0d\x0a" - "Upgrade: WebSocket\x0d\x0a" - "Connection: Upgrade\x0d\x0a" - "Sec-WebSocket-Key: {}\x0d\x0a" - "Sec-WebSocket-Protocol: transit-relay\x0d\x0a" - "Sec-WebSocket-Version: 13\x0d\x0a" - "\x0d\x0a" - ).format(base64.b64encode(self._websocket_key).decode()) - protocol.dataReceived(request.encode("utf8")) + agent = create_memory_agent( + self._reactor, + self._pump, + lambda: server_factory.buildProtocol(IPv4Address("TCP", "127.0.0.1", 31337)), + ) + client_proto = yield agent.open("ws://127.0.0.1:4002/", dict()) + print("PROTO", client_proto) + return client_proto - return protocol diff --git a/src/wormhole_transit_relay/transit_server.py b/src/wormhole_transit_relay/transit_server.py index 9e5a419..133fbd5 100644 --- a/src/wormhole_transit_relay/transit_server.py +++ b/src/wormhole_transit_relay/transit_server.py @@ -298,5 +298,6 @@ class WebSocketTransitConnection(WebSocketServerProtocol): """ IWebSocketChannel API """ + print("onClose", wasClean, code, reason) self._state.connection_lost() # XXX "transit finished", etc