This commit is contained in:
meejah 2021-04-12 09:35:55 -06:00
parent 8132ea8f91
commit 5f43e53db1

View File

@ -99,7 +99,6 @@ class _Transit:
self.assertEqual(p2.get_received_data(), s1) self.assertEqual(p2.get_received_data(), s1)
p1.disconnect() p1.disconnect()
p2.disconnect()
self.flush() self.flush()
def test_sided_unsided(self): def test_sided_unsided(self):
@ -128,7 +127,6 @@ class _Transit:
self.assertEqual(p2.get_received_data(), s1) self.assertEqual(p2.get_received_data(), s1)
p1.disconnect() p1.disconnect()
p2.disconnect()
self.flush() self.flush()
def test_unsided_sided(self): def test_unsided_sided(self):
@ -365,6 +363,9 @@ class TransitWithoutLogs(_Transit, ServerBase, unittest.TestCase):
class TransitWebSockets(_Transit, ServerBase, unittest.TestCase): class TransitWebSockets(_Transit, ServerBase, unittest.TestCase):
# XXX note to self, from pairing with Flo:
# - write a WS <--> TCP version of at least one of these tests?
def test_bad_handshake_old_slow(self): def test_bad_handshake_old_slow(self):
""" """
This test only makes sense for TCP This test only makes sense for TCP
@ -387,8 +388,6 @@ class TransitWebSockets(_Transit, ServerBase, unittest.TestCase):
# p2 loses connection, then p1 sends a message # p2 loses connection, then p1 sends a message
p2.transport.loseConnection() p2.transport.loseConnection()
self.flush() self.flush()
p1.send(b"more message")
self.flush()
# at this point, p1 learns that p2 is disconnected (because it # at this point, p1 learns that p2 is disconnected (because it
# tried to relay "a message" but failed) # tried to relay "a message" but failed)
@ -417,18 +416,21 @@ class TransitWebSockets(_Transit, ServerBase, unittest.TestCase):
self.connected = False self.connected = False
return super(TransitWebSocketClientProtocol, self).connectionLost(reason) return super(TransitWebSocketClientProtocol, self).connectionLost(reason)
def send(self, data):
self.sendMessage(data, True)
def onMessage(self, data, isBinary): def onMessage(self, data, isBinary):
self._received = self._received + data self._received = self._received + data
def send(self, data):
self.sendMessage(data, True)
def get_received_data(self): def get_received_data(self):
return self._received return self._received
def reset_received_data(self): def reset_received_data(self):
self._received = b"" self._received = b""
def disconnect(self):
self.sendClose(1000, True)
client_factory = WebSocketClientFactory() client_factory = WebSocketClientFactory()
client_factory.protocol = TransitWebSocketClientProtocol client_factory.protocol = TransitWebSocketClientProtocol
client_protocol = client_factory.buildProtocol(('127.0.0.1', 31337)) client_protocol = client_factory.buildProtocol(('127.0.0.1', 31337))