run '_Transit' tests on websockets too
This commit is contained in:
parent
4112f718d4
commit
b73c76c8df
|
@ -196,6 +196,7 @@ class _Transit:
|
|||
|
||||
p2.send(handshake(token1, side=side1))
|
||||
self.flush()
|
||||
self.flush()
|
||||
self.assertEqual(self.count(), 2) # same-side connections don't match
|
||||
|
||||
# when the second side arrives, the spare first connection should be
|
||||
|
@ -285,7 +286,8 @@ class _Transit:
|
|||
|
||||
token1 = b"\x00"*32
|
||||
# sending too many bytes is impatience.
|
||||
p1.send(b"please relay " + hexlify(token1) + b"\nNOWNOWNOW")
|
||||
p1.send(b"please relay " + hexlify(token1))
|
||||
p1.send(b"\nNOWNOWNOW")
|
||||
self.flush()
|
||||
|
||||
exp = b"impatient\n"
|
||||
|
@ -300,7 +302,8 @@ class _Transit:
|
|||
side1 = b"\x01"*8
|
||||
# sending too many bytes is impatience.
|
||||
p1.send(b"please relay " + hexlify(token1) +
|
||||
b" for side " + hexlify(side1) + b"\nNOWNOWNOW")
|
||||
b" for side " + hexlify(side1))
|
||||
p1.send(b"\nNOWNOWNOW")
|
||||
self.flush()
|
||||
|
||||
exp = b"impatient\n"
|
||||
|
@ -355,6 +358,58 @@ class TransitWithoutLogs(_Transit, ServerBase, unittest.TestCase):
|
|||
log_requests = False
|
||||
|
||||
|
||||
class TransitWebSockets(_Transit, ServerBase, unittest.TestCase):
|
||||
|
||||
def test_bad_handshake_old_slow(self):
|
||||
"""
|
||||
This test only makes sense for TCP
|
||||
"""
|
||||
|
||||
def new_protocol(self):
|
||||
ws_factory = WebSocketServerFactory("ws://localhost:4002")
|
||||
ws_factory.protocol = WebSocketTransitConnection
|
||||
ws_factory.transit = self._transit_server
|
||||
ws_protocol = ws_factory.buildProtocol(('127.0.0.1', 0))
|
||||
|
||||
class TransitWebSocketClientProtocol(WebSocketClientProtocol):
|
||||
_received = b""
|
||||
connected = False
|
||||
|
||||
def connectionMade(self):
|
||||
self.connected = True
|
||||
return super(TransitWebSocketClientProtocol, self).connectionMade()
|
||||
|
||||
def connectionLost(self, reason):
|
||||
self.connected = False
|
||||
return super(TransitWebSocketClientProtocol, self).connectionLost(reason)
|
||||
|
||||
def send(self, data):
|
||||
self.sendMessage(data, True)
|
||||
|
||||
def onMessage(self, data, isBinary):
|
||||
self._received = self._received + data
|
||||
|
||||
def get_received_data(self):
|
||||
return self._received
|
||||
|
||||
def reset_received_data(self):
|
||||
self._received = b""
|
||||
|
||||
client_factory = WebSocketClientFactory()
|
||||
client_factory.protocol = TransitWebSocketClientProtocol
|
||||
client_protocol = client_factory.buildProtocol(('127.0.0.1', 31337))
|
||||
client_protocol.disconnect = client_protocol.dropConnection
|
||||
|
||||
pump = iosim.connect(
|
||||
ws_protocol,
|
||||
iosim.makeFakeServer(ws_protocol),
|
||||
client_protocol,
|
||||
iosim.makeFakeClient(client_protocol),
|
||||
)
|
||||
self._pumps.append(pump)
|
||||
return client_protocol
|
||||
|
||||
|
||||
class Usage(ServerBase, unittest.TestCase):
|
||||
log_requests = True
|
||||
|
||||
|
@ -503,6 +558,16 @@ class UsageWebSockets(Usage):
|
|||
because it is semantically invalid or no handshake (yet).
|
||||
"""
|
||||
|
||||
def test_send_non_binary_message(self):
|
||||
"""
|
||||
A non-binary WebSocket message is an error
|
||||
"""
|
||||
ws_factory = WebSocketServerFactory("ws://localhost:4002")
|
||||
ws_factory.protocol = WebSocketTransitConnection
|
||||
ws_protocol = ws_factory.buildProtocol(('127.0.0.1', 0))
|
||||
with self.assertRaises(ValueError):
|
||||
ws_protocol.onMessage(u"foo", isBinary=False)
|
||||
|
||||
def new_protocol(self):
|
||||
ws_factory = WebSocketServerFactory("ws://localhost:4002")
|
||||
ws_factory.protocol = WebSocketTransitConnection
|
||||
|
|
Loading…
Reference in New Issue
Block a user