run '_Transit' tests on websockets too
This commit is contained in:
parent
4112f718d4
commit
b73c76c8df
|
@ -196,6 +196,7 @@ class _Transit:
|
||||||
|
|
||||||
p2.send(handshake(token1, side=side1))
|
p2.send(handshake(token1, side=side1))
|
||||||
self.flush()
|
self.flush()
|
||||||
|
self.flush()
|
||||||
self.assertEqual(self.count(), 2) # same-side connections don't match
|
self.assertEqual(self.count(), 2) # same-side connections don't match
|
||||||
|
|
||||||
# when the second side arrives, the spare first connection should be
|
# when the second side arrives, the spare first connection should be
|
||||||
|
@ -285,7 +286,8 @@ class _Transit:
|
||||||
|
|
||||||
token1 = b"\x00"*32
|
token1 = b"\x00"*32
|
||||||
# sending too many bytes is impatience.
|
# sending too many bytes is impatience.
|
||||||
p1.send(b"please relay " + hexlify(token1) + b"\nNOWNOWNOW")
|
p1.send(b"please relay " + hexlify(token1))
|
||||||
|
p1.send(b"\nNOWNOWNOW")
|
||||||
self.flush()
|
self.flush()
|
||||||
|
|
||||||
exp = b"impatient\n"
|
exp = b"impatient\n"
|
||||||
|
@ -300,7 +302,8 @@ class _Transit:
|
||||||
side1 = b"\x01"*8
|
side1 = b"\x01"*8
|
||||||
# sending too many bytes is impatience.
|
# sending too many bytes is impatience.
|
||||||
p1.send(b"please relay " + hexlify(token1) +
|
p1.send(b"please relay " + hexlify(token1) +
|
||||||
b" for side " + hexlify(side1) + b"\nNOWNOWNOW")
|
b" for side " + hexlify(side1))
|
||||||
|
p1.send(b"\nNOWNOWNOW")
|
||||||
self.flush()
|
self.flush()
|
||||||
|
|
||||||
exp = b"impatient\n"
|
exp = b"impatient\n"
|
||||||
|
@ -355,6 +358,58 @@ class TransitWithoutLogs(_Transit, ServerBase, unittest.TestCase):
|
||||||
log_requests = False
|
log_requests = False
|
||||||
|
|
||||||
|
|
||||||
|
class TransitWebSockets(_Transit, ServerBase, unittest.TestCase):
|
||||||
|
|
||||||
|
def test_bad_handshake_old_slow(self):
|
||||||
|
"""
|
||||||
|
This test only makes sense for TCP
|
||||||
|
"""
|
||||||
|
|
||||||
|
def new_protocol(self):
|
||||||
|
ws_factory = WebSocketServerFactory("ws://localhost:4002")
|
||||||
|
ws_factory.protocol = WebSocketTransitConnection
|
||||||
|
ws_factory.transit = self._transit_server
|
||||||
|
ws_protocol = ws_factory.buildProtocol(('127.0.0.1', 0))
|
||||||
|
|
||||||
|
class TransitWebSocketClientProtocol(WebSocketClientProtocol):
|
||||||
|
_received = b""
|
||||||
|
connected = False
|
||||||
|
|
||||||
|
def connectionMade(self):
|
||||||
|
self.connected = True
|
||||||
|
return super(TransitWebSocketClientProtocol, self).connectionMade()
|
||||||
|
|
||||||
|
def connectionLost(self, reason):
|
||||||
|
self.connected = False
|
||||||
|
return super(TransitWebSocketClientProtocol, self).connectionLost(reason)
|
||||||
|
|
||||||
|
def send(self, data):
|
||||||
|
self.sendMessage(data, True)
|
||||||
|
|
||||||
|
def onMessage(self, data, isBinary):
|
||||||
|
self._received = self._received + data
|
||||||
|
|
||||||
|
def get_received_data(self):
|
||||||
|
return self._received
|
||||||
|
|
||||||
|
def reset_received_data(self):
|
||||||
|
self._received = b""
|
||||||
|
|
||||||
|
client_factory = WebSocketClientFactory()
|
||||||
|
client_factory.protocol = TransitWebSocketClientProtocol
|
||||||
|
client_protocol = client_factory.buildProtocol(('127.0.0.1', 31337))
|
||||||
|
client_protocol.disconnect = client_protocol.dropConnection
|
||||||
|
|
||||||
|
pump = iosim.connect(
|
||||||
|
ws_protocol,
|
||||||
|
iosim.makeFakeServer(ws_protocol),
|
||||||
|
client_protocol,
|
||||||
|
iosim.makeFakeClient(client_protocol),
|
||||||
|
)
|
||||||
|
self._pumps.append(pump)
|
||||||
|
return client_protocol
|
||||||
|
|
||||||
|
|
||||||
class Usage(ServerBase, unittest.TestCase):
|
class Usage(ServerBase, unittest.TestCase):
|
||||||
log_requests = True
|
log_requests = True
|
||||||
|
|
||||||
|
@ -503,6 +558,16 @@ class UsageWebSockets(Usage):
|
||||||
because it is semantically invalid or no handshake (yet).
|
because it is semantically invalid or no handshake (yet).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
def test_send_non_binary_message(self):
|
||||||
|
"""
|
||||||
|
A non-binary WebSocket message is an error
|
||||||
|
"""
|
||||||
|
ws_factory = WebSocketServerFactory("ws://localhost:4002")
|
||||||
|
ws_factory.protocol = WebSocketTransitConnection
|
||||||
|
ws_protocol = ws_factory.buildProtocol(('127.0.0.1', 0))
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
ws_protocol.onMessage(u"foo", isBinary=False)
|
||||||
|
|
||||||
def new_protocol(self):
|
def new_protocol(self):
|
||||||
ws_factory = WebSocketServerFactory("ws://localhost:4002")
|
ws_factory = WebSocketServerFactory("ws://localhost:4002")
|
||||||
ws_factory.protocol = WebSocketTransitConnection
|
ws_factory.protocol = WebSocketTransitConnection
|
||||||
|
|
Loading…
Reference in New Issue
Block a user