From 5f43e53db17daeddc3139a67eedfc1bf35e8b941 Mon Sep 17 00:00:00 2001 From: meejah Date: Mon, 12 Apr 2021 09:35:55 -0600 Subject: [PATCH] cleanup --- .../test/test_transit_server.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/src/wormhole_transit_relay/test/test_transit_server.py b/src/wormhole_transit_relay/test/test_transit_server.py index 826093d..994121e 100644 --- a/src/wormhole_transit_relay/test/test_transit_server.py +++ b/src/wormhole_transit_relay/test/test_transit_server.py @@ -99,7 +99,6 @@ class _Transit: self.assertEqual(p2.get_received_data(), s1) p1.disconnect() - p2.disconnect() self.flush() def test_sided_unsided(self): @@ -128,7 +127,6 @@ class _Transit: self.assertEqual(p2.get_received_data(), s1) p1.disconnect() - p2.disconnect() self.flush() def test_unsided_sided(self): @@ -365,6 +363,9 @@ class TransitWithoutLogs(_Transit, ServerBase, unittest.TestCase): class TransitWebSockets(_Transit, ServerBase, unittest.TestCase): + # XXX note to self, from pairing with Flo: + # - write a WS <--> TCP version of at least one of these tests? + def test_bad_handshake_old_slow(self): """ This test only makes sense for TCP @@ -387,8 +388,6 @@ class TransitWebSockets(_Transit, ServerBase, unittest.TestCase): # p2 loses connection, then p1 sends a message p2.transport.loseConnection() self.flush() - p1.send(b"more message") - self.flush() # at this point, p1 learns that p2 is disconnected (because it # tried to relay "a message" but failed) @@ -417,18 +416,21 @@ class TransitWebSockets(_Transit, ServerBase, unittest.TestCase): self.connected = False return super(TransitWebSocketClientProtocol, self).connectionLost(reason) - def send(self, data): - self.sendMessage(data, True) - def onMessage(self, data, isBinary): self._received = self._received + data + def send(self, data): + self.sendMessage(data, True) + def get_received_data(self): return self._received def reset_received_data(self): self._received = b"" + def disconnect(self): + self.sendClose(1000, True) + client_factory = WebSocketClientFactory() client_factory.protocol = TransitWebSocketClientProtocol client_protocol = client_factory.buildProtocol(('127.0.0.1', 31337))