diff --git a/src/wormhole_transit_relay/test/test_transit_server.py b/src/wormhole_transit_relay/test/test_transit_server.py index 4095f45..73f75fd 100644 --- a/src/wormhole_transit_relay/test/test_transit_server.py +++ b/src/wormhole_transit_relay/test/test_transit_server.py @@ -196,6 +196,7 @@ class _Transit: p2.send(handshake(token1, side=side1)) self.flush() + self.flush() self.assertEqual(self.count(), 2) # same-side connections don't match # when the second side arrives, the spare first connection should be @@ -285,7 +286,8 @@ class _Transit: token1 = b"\x00"*32 # sending too many bytes is impatience. - p1.send(b"please relay " + hexlify(token1) + b"\nNOWNOWNOW") + p1.send(b"please relay " + hexlify(token1)) + p1.send(b"\nNOWNOWNOW") self.flush() exp = b"impatient\n" @@ -300,7 +302,8 @@ class _Transit: side1 = b"\x01"*8 # sending too many bytes is impatience. p1.send(b"please relay " + hexlify(token1) + - b" for side " + hexlify(side1) + b"\nNOWNOWNOW") + b" for side " + hexlify(side1)) + p1.send(b"\nNOWNOWNOW") self.flush() exp = b"impatient\n" @@ -355,6 +358,58 @@ class TransitWithoutLogs(_Transit, ServerBase, unittest.TestCase): log_requests = False +class TransitWebSockets(_Transit, ServerBase, unittest.TestCase): + + def test_bad_handshake_old_slow(self): + """ + This test only makes sense for TCP + """ + + def new_protocol(self): + ws_factory = WebSocketServerFactory("ws://localhost:4002") + ws_factory.protocol = WebSocketTransitConnection + ws_factory.transit = self._transit_server + ws_protocol = ws_factory.buildProtocol(('127.0.0.1', 0)) + + class TransitWebSocketClientProtocol(WebSocketClientProtocol): + _received = b"" + connected = False + + def connectionMade(self): + self.connected = True + return super(TransitWebSocketClientProtocol, self).connectionMade() + + def connectionLost(self, reason): + self.connected = False + return super(TransitWebSocketClientProtocol, self).connectionLost(reason) + + def send(self, data): + self.sendMessage(data, True) + + def onMessage(self, data, isBinary): + self._received = self._received + data + + def get_received_data(self): + return self._received + + def reset_received_data(self): + self._received = b"" + + client_factory = WebSocketClientFactory() + client_factory.protocol = TransitWebSocketClientProtocol + client_protocol = client_factory.buildProtocol(('127.0.0.1', 31337)) + client_protocol.disconnect = client_protocol.dropConnection + + pump = iosim.connect( + ws_protocol, + iosim.makeFakeServer(ws_protocol), + client_protocol, + iosim.makeFakeClient(client_protocol), + ) + self._pumps.append(pump) + return client_protocol + + class Usage(ServerBase, unittest.TestCase): log_requests = True @@ -503,6 +558,16 @@ class UsageWebSockets(Usage): because it is semantically invalid or no handshake (yet). """ + def test_send_non_binary_message(self): + """ + A non-binary WebSocket message is an error + """ + ws_factory = WebSocketServerFactory("ws://localhost:4002") + ws_factory.protocol = WebSocketTransitConnection + ws_protocol = ws_factory.buildProtocol(('127.0.0.1', 0)) + with self.assertRaises(ValueError): + ws_protocol.onMessage(u"foo", isBinary=False) + def new_protocol(self): ws_factory = WebSocketServerFactory("ws://localhost:4002") ws_factory.protocol = WebSocketTransitConnection