diff --git a/src/wormhole_transit_relay/test/test_transit_server.py b/src/wormhole_transit_relay/test/test_transit_server.py index 2b55ef7..c5370de 100644 --- a/src/wormhole_transit_relay/test/test_transit_server.py +++ b/src/wormhole_transit_relay/test/test_transit_server.py @@ -367,11 +367,72 @@ class TransitWithoutLogs(_Transit, ServerBase, unittest.TestCase): return self.new_protocol_tcp() +def _new_protocol_ws(transit_server, log_requests): + """ + Internal helper for test-suites that need to provide WebSocket + client/server pairs. + + :returns: a 2-tuple: (iosim.IOPump, protocol) + """ + ws_factory = WebSocketServerFactory("ws://localhost:4002") + ws_factory.protocol = WebSocketTransitConnection + ws_factory.transit = transit_server + ws_factory.log_requests = log_requests + ws_protocol = ws_factory.buildProtocol(('127.0.0.1', 0)) + + @implementer(IRelayTestClient) + class TransitWebSocketClientProtocol(WebSocketClientProtocol): + _received = b"" + connected = False + + def connectionMade(self): + self.connected = True + return super(TransitWebSocketClientProtocol, self).connectionMade() + + def connectionLost(self, reason): + self.connected = False + return super(TransitWebSocketClientProtocol, self).connectionLost(reason) + + def onMessage(self, data, isBinary): + self._received = self._received + data + + def send(self, data): + self.sendMessage(data, True) + + def get_received_data(self): + return self._received + + def reset_received_data(self): + self._received = b"" + + def disconnect(self): + self.sendClose(1000, True) + + client_factory = WebSocketClientFactory() + client_factory.protocol = TransitWebSocketClientProtocol + client_protocol = client_factory.buildProtocol(('127.0.0.1', 31337)) + client_protocol.disconnect = client_protocol.dropConnection + + pump = iosim.connect( + ws_protocol, + iosim.makeFakeServer(ws_protocol), + client_protocol, + iosim.makeFakeClient(client_protocol), + ) + return pump, client_protocol + + + class TransitWebSockets(_Transit, ServerBase, unittest.TestCase): def new_protocol(self): return self.new_protocol_ws() + def new_protocol_ws(self): + pump, proto = _new_protocol_ws(self._transit_server, self.log_requests) + self._pumps.append(pump) + return proto + def test_websocket_to_tcp(self): """ One client is WebSocket and one is TCP @@ -437,55 +498,6 @@ class TransitWebSockets(_Transit, ServerBase, unittest.TestCase): p1.send(b"more message") self.flush() - def new_protocol_ws(self): - ws_factory = WebSocketServerFactory("ws://localhost:4002") - ws_factory.protocol = WebSocketTransitConnection - ws_factory.transit = self._transit_server - ws_factory.log_requests = self.log_requests - ws_protocol = ws_factory.buildProtocol(('127.0.0.1', 0)) - - @implementer(IRelayTestClient) - class TransitWebSocketClientProtocol(WebSocketClientProtocol): - _received = b"" - connected = False - - def connectionMade(self): - self.connected = True - return super(TransitWebSocketClientProtocol, self).connectionMade() - - def connectionLost(self, reason): - self.connected = False - return super(TransitWebSocketClientProtocol, self).connectionLost(reason) - - def onMessage(self, data, isBinary): - self._received = self._received + data - - def send(self, data): - self.sendMessage(data, True) - - def get_received_data(self): - return self._received - - def reset_received_data(self): - self._received = b"" - - def disconnect(self): - self.sendClose(1000, True) - - client_factory = WebSocketClientFactory() - client_factory.protocol = TransitWebSocketClientProtocol - client_protocol = client_factory.buildProtocol(('127.0.0.1', 31337)) - client_protocol.disconnect = client_protocol.dropConnection - - pump = iosim.connect( - ws_protocol, - iosim.makeFakeServer(ws_protocol), - client_protocol, - iosim.makeFakeClient(client_protocol), - ) - self._pumps.append(pump) - return client_protocol - class Usage(ServerBase, unittest.TestCase): log_requests = True @@ -636,6 +648,11 @@ class UsageWebSockets(Usage): def new_protocol(self): return self.new_protocol_ws() + def new_protocol_ws(self): + pump, proto = _new_protocol_ws(self._transit_server, self.log_requests) + self._pumps.append(pump) + return proto + def test_short(self): """ This test essentially just tests the framing of the line-oriented @@ -654,31 +671,6 @@ class UsageWebSockets(Usage): with self.assertRaises(ValueError): ws_protocol.onMessage(u"foo", isBinary=False) - def new_protocol_ws(self): - ws_factory = WebSocketServerFactory("ws://localhost:4002") - ws_factory.protocol = WebSocketTransitConnection - ws_factory.transit = self._transit_server - ws_factory.log_requests = self.log_requests - ws_protocol = ws_factory.buildProtocol(('127.0.0.1', 0)) - - class TransitWebSocketClientProtocol(WebSocketClientProtocol): - def send(self, data): - self.sendMessage(data, True) - - client_factory = WebSocketClientFactory() - client_factory.protocol = TransitWebSocketClientProtocol - client_protocol = client_factory.buildProtocol(('127.0.0.1', 31337)) - client_protocol.disconnect = client_protocol.dropConnection - - pump = iosim.connect( - ws_protocol, - iosim.makeFakeServer(ws_protocol), - client_protocol, - iosim.makeFakeClient(client_protocol), - ) - self._pumps.append(pump) - return client_protocol - class State(unittest.TestCase): """