run '_Transit' tests on websockets too

This commit is contained in:
meejah 2021-04-07 17:03:23 -06:00
parent 4112f718d4
commit b73c76c8df

View File

@ -196,6 +196,7 @@ class _Transit:
p2.send(handshake(token1, side=side1)) p2.send(handshake(token1, side=side1))
self.flush() self.flush()
self.flush()
self.assertEqual(self.count(), 2) # same-side connections don't match self.assertEqual(self.count(), 2) # same-side connections don't match
# when the second side arrives, the spare first connection should be # when the second side arrives, the spare first connection should be
@ -285,7 +286,8 @@ class _Transit:
token1 = b"\x00"*32 token1 = b"\x00"*32
# sending too many bytes is impatience. # sending too many bytes is impatience.
p1.send(b"please relay " + hexlify(token1) + b"\nNOWNOWNOW") p1.send(b"please relay " + hexlify(token1))
p1.send(b"\nNOWNOWNOW")
self.flush() self.flush()
exp = b"impatient\n" exp = b"impatient\n"
@ -300,7 +302,8 @@ class _Transit:
side1 = b"\x01"*8 side1 = b"\x01"*8
# sending too many bytes is impatience. # sending too many bytes is impatience.
p1.send(b"please relay " + hexlify(token1) + p1.send(b"please relay " + hexlify(token1) +
b" for side " + hexlify(side1) + b"\nNOWNOWNOW") b" for side " + hexlify(side1))
p1.send(b"\nNOWNOWNOW")
self.flush() self.flush()
exp = b"impatient\n" exp = b"impatient\n"
@ -355,6 +358,58 @@ class TransitWithoutLogs(_Transit, ServerBase, unittest.TestCase):
log_requests = False log_requests = False
class TransitWebSockets(_Transit, ServerBase, unittest.TestCase):
def test_bad_handshake_old_slow(self):
"""
This test only makes sense for TCP
"""
def new_protocol(self):
ws_factory = WebSocketServerFactory("ws://localhost:4002")
ws_factory.protocol = WebSocketTransitConnection
ws_factory.transit = self._transit_server
ws_protocol = ws_factory.buildProtocol(('127.0.0.1', 0))
class TransitWebSocketClientProtocol(WebSocketClientProtocol):
_received = b""
connected = False
def connectionMade(self):
self.connected = True
return super(TransitWebSocketClientProtocol, self).connectionMade()
def connectionLost(self, reason):
self.connected = False
return super(TransitWebSocketClientProtocol, self).connectionLost(reason)
def send(self, data):
self.sendMessage(data, True)
def onMessage(self, data, isBinary):
self._received = self._received + data
def get_received_data(self):
return self._received
def reset_received_data(self):
self._received = b""
client_factory = WebSocketClientFactory()
client_factory.protocol = TransitWebSocketClientProtocol
client_protocol = client_factory.buildProtocol(('127.0.0.1', 31337))
client_protocol.disconnect = client_protocol.dropConnection
pump = iosim.connect(
ws_protocol,
iosim.makeFakeServer(ws_protocol),
client_protocol,
iosim.makeFakeClient(client_protocol),
)
self._pumps.append(pump)
return client_protocol
class Usage(ServerBase, unittest.TestCase): class Usage(ServerBase, unittest.TestCase):
log_requests = True log_requests = True
@ -503,6 +558,16 @@ class UsageWebSockets(Usage):
because it is semantically invalid or no handshake (yet). because it is semantically invalid or no handshake (yet).
""" """
def test_send_non_binary_message(self):
"""
A non-binary WebSocket message is an error
"""
ws_factory = WebSocketServerFactory("ws://localhost:4002")
ws_factory.protocol = WebSocketTransitConnection
ws_protocol = ws_factory.buildProtocol(('127.0.0.1', 0))
with self.assertRaises(ValueError):
ws_protocol.onMessage(u"foo", isBinary=False)
def new_protocol(self): def new_protocol(self):
ws_factory = WebSocketServerFactory("ws://localhost:4002") ws_factory = WebSocketServerFactory("ws://localhost:4002")
ws_factory.protocol = WebSocketTransitConnection ws_factory.protocol = WebSocketTransitConnection