unify new_protocol_ws, make it a bare helper
This commit is contained in:
parent
0ce08b66cf
commit
807dfc1c18
|
@ -367,11 +367,72 @@ class TransitWithoutLogs(_Transit, ServerBase, unittest.TestCase):
|
|||
return self.new_protocol_tcp()
|
||||
|
||||
|
||||
def _new_protocol_ws(transit_server, log_requests):
|
||||
"""
|
||||
Internal helper for test-suites that need to provide WebSocket
|
||||
client/server pairs.
|
||||
|
||||
:returns: a 2-tuple: (iosim.IOPump, protocol)
|
||||
"""
|
||||
ws_factory = WebSocketServerFactory("ws://localhost:4002")
|
||||
ws_factory.protocol = WebSocketTransitConnection
|
||||
ws_factory.transit = transit_server
|
||||
ws_factory.log_requests = log_requests
|
||||
ws_protocol = ws_factory.buildProtocol(('127.0.0.1', 0))
|
||||
|
||||
@implementer(IRelayTestClient)
|
||||
class TransitWebSocketClientProtocol(WebSocketClientProtocol):
|
||||
_received = b""
|
||||
connected = False
|
||||
|
||||
def connectionMade(self):
|
||||
self.connected = True
|
||||
return super(TransitWebSocketClientProtocol, self).connectionMade()
|
||||
|
||||
def connectionLost(self, reason):
|
||||
self.connected = False
|
||||
return super(TransitWebSocketClientProtocol, self).connectionLost(reason)
|
||||
|
||||
def onMessage(self, data, isBinary):
|
||||
self._received = self._received + data
|
||||
|
||||
def send(self, data):
|
||||
self.sendMessage(data, True)
|
||||
|
||||
def get_received_data(self):
|
||||
return self._received
|
||||
|
||||
def reset_received_data(self):
|
||||
self._received = b""
|
||||
|
||||
def disconnect(self):
|
||||
self.sendClose(1000, True)
|
||||
|
||||
client_factory = WebSocketClientFactory()
|
||||
client_factory.protocol = TransitWebSocketClientProtocol
|
||||
client_protocol = client_factory.buildProtocol(('127.0.0.1', 31337))
|
||||
client_protocol.disconnect = client_protocol.dropConnection
|
||||
|
||||
pump = iosim.connect(
|
||||
ws_protocol,
|
||||
iosim.makeFakeServer(ws_protocol),
|
||||
client_protocol,
|
||||
iosim.makeFakeClient(client_protocol),
|
||||
)
|
||||
return pump, client_protocol
|
||||
|
||||
|
||||
|
||||
class TransitWebSockets(_Transit, ServerBase, unittest.TestCase):
|
||||
|
||||
def new_protocol(self):
|
||||
return self.new_protocol_ws()
|
||||
|
||||
def new_protocol_ws(self):
|
||||
pump, proto = _new_protocol_ws(self._transit_server, self.log_requests)
|
||||
self._pumps.append(pump)
|
||||
return proto
|
||||
|
||||
def test_websocket_to_tcp(self):
|
||||
"""
|
||||
One client is WebSocket and one is TCP
|
||||
|
@ -437,55 +498,6 @@ class TransitWebSockets(_Transit, ServerBase, unittest.TestCase):
|
|||
p1.send(b"more message")
|
||||
self.flush()
|
||||
|
||||
def new_protocol_ws(self):
|
||||
ws_factory = WebSocketServerFactory("ws://localhost:4002")
|
||||
ws_factory.protocol = WebSocketTransitConnection
|
||||
ws_factory.transit = self._transit_server
|
||||
ws_factory.log_requests = self.log_requests
|
||||
ws_protocol = ws_factory.buildProtocol(('127.0.0.1', 0))
|
||||
|
||||
@implementer(IRelayTestClient)
|
||||
class TransitWebSocketClientProtocol(WebSocketClientProtocol):
|
||||
_received = b""
|
||||
connected = False
|
||||
|
||||
def connectionMade(self):
|
||||
self.connected = True
|
||||
return super(TransitWebSocketClientProtocol, self).connectionMade()
|
||||
|
||||
def connectionLost(self, reason):
|
||||
self.connected = False
|
||||
return super(TransitWebSocketClientProtocol, self).connectionLost(reason)
|
||||
|
||||
def onMessage(self, data, isBinary):
|
||||
self._received = self._received + data
|
||||
|
||||
def send(self, data):
|
||||
self.sendMessage(data, True)
|
||||
|
||||
def get_received_data(self):
|
||||
return self._received
|
||||
|
||||
def reset_received_data(self):
|
||||
self._received = b""
|
||||
|
||||
def disconnect(self):
|
||||
self.sendClose(1000, True)
|
||||
|
||||
client_factory = WebSocketClientFactory()
|
||||
client_factory.protocol = TransitWebSocketClientProtocol
|
||||
client_protocol = client_factory.buildProtocol(('127.0.0.1', 31337))
|
||||
client_protocol.disconnect = client_protocol.dropConnection
|
||||
|
||||
pump = iosim.connect(
|
||||
ws_protocol,
|
||||
iosim.makeFakeServer(ws_protocol),
|
||||
client_protocol,
|
||||
iosim.makeFakeClient(client_protocol),
|
||||
)
|
||||
self._pumps.append(pump)
|
||||
return client_protocol
|
||||
|
||||
|
||||
class Usage(ServerBase, unittest.TestCase):
|
||||
log_requests = True
|
||||
|
@ -636,6 +648,11 @@ class UsageWebSockets(Usage):
|
|||
def new_protocol(self):
|
||||
return self.new_protocol_ws()
|
||||
|
||||
def new_protocol_ws(self):
|
||||
pump, proto = _new_protocol_ws(self._transit_server, self.log_requests)
|
||||
self._pumps.append(pump)
|
||||
return proto
|
||||
|
||||
def test_short(self):
|
||||
"""
|
||||
This test essentially just tests the framing of the line-oriented
|
||||
|
@ -654,31 +671,6 @@ class UsageWebSockets(Usage):
|
|||
with self.assertRaises(ValueError):
|
||||
ws_protocol.onMessage(u"foo", isBinary=False)
|
||||
|
||||
def new_protocol_ws(self):
|
||||
ws_factory = WebSocketServerFactory("ws://localhost:4002")
|
||||
ws_factory.protocol = WebSocketTransitConnection
|
||||
ws_factory.transit = self._transit_server
|
||||
ws_factory.log_requests = self.log_requests
|
||||
ws_protocol = ws_factory.buildProtocol(('127.0.0.1', 0))
|
||||
|
||||
class TransitWebSocketClientProtocol(WebSocketClientProtocol):
|
||||
def send(self, data):
|
||||
self.sendMessage(data, True)
|
||||
|
||||
client_factory = WebSocketClientFactory()
|
||||
client_factory.protocol = TransitWebSocketClientProtocol
|
||||
client_protocol = client_factory.buildProtocol(('127.0.0.1', 31337))
|
||||
client_protocol.disconnect = client_protocol.dropConnection
|
||||
|
||||
pump = iosim.connect(
|
||||
ws_protocol,
|
||||
iosim.makeFakeServer(ws_protocol),
|
||||
client_protocol,
|
||||
iosim.makeFakeClient(client_protocol),
|
||||
)
|
||||
self._pumps.append(pump)
|
||||
return client_protocol
|
||||
|
||||
|
||||
class State(unittest.TestCase):
|
||||
"""
|
||||
|
|
Loading…
Reference in New Issue
Block a user