From f18edc89f95047e7a94028bc3b5b7cd0bf675f8d Mon Sep 17 00:00:00 2001 From: meejah Date: Fri, 2 Apr 2021 14:58:31 -0600 Subject: [PATCH] refine --- .../test/test_transit_server.py | 22 +++++++++++-------- src/wormhole_transit_relay/transit_server.py | 4 ++++ 2 files changed, 17 insertions(+), 9 deletions(-) diff --git a/src/wormhole_transit_relay/test/test_transit_server.py b/src/wormhole_transit_relay/test/test_transit_server.py index 946ca66..795b996 100644 --- a/src/wormhole_transit_relay/test/test_transit_server.py +++ b/src/wormhole_transit_relay/test/test_transit_server.py @@ -594,7 +594,7 @@ class New(unittest.TestCase): server_protocol = server_factory.buildProtocol(('127.0.0.1', 0)) class ClientProtocol(protocol.Protocol): - def sendMessage(self, data): + def send(self, data): self.transport.write(data) def disconnect(self): @@ -623,8 +623,13 @@ class New(unittest.TestCase): ws_protocol = ws_factory.buildProtocol(('127.0.0.1', 0)) from autobahn.twisted.websocket import WebSocketClientFactory, WebSocketClientProtocol + + class TransitWebSocketClientProtocol(WebSocketClientProtocol): + def send(self, data): + self.sendMessage(data, True) + client_factory = WebSocketClientFactory() - client_factory.protocol = WebSocketClientProtocol + client_factory.protocol = TransitWebSocketClientProtocol client_factory.protocols = ["binary"] client_protocol = client_factory.buildProtocol(('127.0.0.1', 31337)) client_protocol.disconnect = client_protocol.dropConnection @@ -639,9 +644,10 @@ class New(unittest.TestCase): return client_protocol def test_short(self): + # XXX this test only makes sense for TCP p1 = self.new_protocol() # hang up before sending a complete handshake -# p1.sendMessage(b"short") # <-- only makes sense for TCP + p1.send(b"short") p1.disconnect() self.flush() @@ -659,21 +665,19 @@ class New(unittest.TestCase): from twisted.internet import reactor print("p1 data") - p1.sendMessage(handshake(token1, side=side1), True) + p1.send(handshake(token1, side=side1)) print("p2 data") - p2.sendMessage(handshake(token1, side=side2), True) + p2.send(handshake(token1, side=side2)) self.flush() print("shouldn't be events yet") self.assertEqual(self._usage.events, []) # no events yet print("p1 moar") - for x in range(13): - p1.sendMessage(b"\x00", True) - ##p1.sendMessage(b"\x00" * 13) + p1.send(b"\x00" * 13) self.flush() print("p2 moar") - p2.sendMessage(b"\xff" * 7, True) + p2.send(b"\xff" * 7) self.flush() print("p1 lose") diff --git a/src/wormhole_transit_relay/transit_server.py b/src/wormhole_transit_relay/transit_server.py index b81b5d6..f096bfc 100644 --- a/src/wormhole_transit_relay/transit_server.py +++ b/src/wormhole_transit_relay/transit_server.py @@ -277,6 +277,10 @@ class WebSocketTransitConnection(WebSocketServerProtocol): """ We may have a 'handshake' on our hands or we may just have some bytes to relay """ + if not isBinary: + raise ValueError( + "All messages must be binary" + ) # print("onMessage isBinary={}: {}".format(isBinary, payload)) if self._first_message: self._first_message = False