unify new_protocol_ws, make it a bare helper

This commit is contained in:
meejah 2021-04-18 21:33:01 -06:00
parent 0ce08b66cf
commit 807dfc1c18

View File

@ -367,11 +367,72 @@ class TransitWithoutLogs(_Transit, ServerBase, unittest.TestCase):
return self.new_protocol_tcp() return self.new_protocol_tcp()
def _new_protocol_ws(transit_server, log_requests):
"""
Internal helper for test-suites that need to provide WebSocket
client/server pairs.
:returns: a 2-tuple: (iosim.IOPump, protocol)
"""
ws_factory = WebSocketServerFactory("ws://localhost:4002")
ws_factory.protocol = WebSocketTransitConnection
ws_factory.transit = transit_server
ws_factory.log_requests = log_requests
ws_protocol = ws_factory.buildProtocol(('127.0.0.1', 0))
@implementer(IRelayTestClient)
class TransitWebSocketClientProtocol(WebSocketClientProtocol):
_received = b""
connected = False
def connectionMade(self):
self.connected = True
return super(TransitWebSocketClientProtocol, self).connectionMade()
def connectionLost(self, reason):
self.connected = False
return super(TransitWebSocketClientProtocol, self).connectionLost(reason)
def onMessage(self, data, isBinary):
self._received = self._received + data
def send(self, data):
self.sendMessage(data, True)
def get_received_data(self):
return self._received
def reset_received_data(self):
self._received = b""
def disconnect(self):
self.sendClose(1000, True)
client_factory = WebSocketClientFactory()
client_factory.protocol = TransitWebSocketClientProtocol
client_protocol = client_factory.buildProtocol(('127.0.0.1', 31337))
client_protocol.disconnect = client_protocol.dropConnection
pump = iosim.connect(
ws_protocol,
iosim.makeFakeServer(ws_protocol),
client_protocol,
iosim.makeFakeClient(client_protocol),
)
return pump, client_protocol
class TransitWebSockets(_Transit, ServerBase, unittest.TestCase): class TransitWebSockets(_Transit, ServerBase, unittest.TestCase):
def new_protocol(self): def new_protocol(self):
return self.new_protocol_ws() return self.new_protocol_ws()
def new_protocol_ws(self):
pump, proto = _new_protocol_ws(self._transit_server, self.log_requests)
self._pumps.append(pump)
return proto
def test_websocket_to_tcp(self): def test_websocket_to_tcp(self):
""" """
One client is WebSocket and one is TCP One client is WebSocket and one is TCP
@ -437,55 +498,6 @@ class TransitWebSockets(_Transit, ServerBase, unittest.TestCase):
p1.send(b"more message") p1.send(b"more message")
self.flush() self.flush()
def new_protocol_ws(self):
ws_factory = WebSocketServerFactory("ws://localhost:4002")
ws_factory.protocol = WebSocketTransitConnection
ws_factory.transit = self._transit_server
ws_factory.log_requests = self.log_requests
ws_protocol = ws_factory.buildProtocol(('127.0.0.1', 0))
@implementer(IRelayTestClient)
class TransitWebSocketClientProtocol(WebSocketClientProtocol):
_received = b""
connected = False
def connectionMade(self):
self.connected = True
return super(TransitWebSocketClientProtocol, self).connectionMade()
def connectionLost(self, reason):
self.connected = False
return super(TransitWebSocketClientProtocol, self).connectionLost(reason)
def onMessage(self, data, isBinary):
self._received = self._received + data
def send(self, data):
self.sendMessage(data, True)
def get_received_data(self):
return self._received
def reset_received_data(self):
self._received = b""
def disconnect(self):
self.sendClose(1000, True)
client_factory = WebSocketClientFactory()
client_factory.protocol = TransitWebSocketClientProtocol
client_protocol = client_factory.buildProtocol(('127.0.0.1', 31337))
client_protocol.disconnect = client_protocol.dropConnection
pump = iosim.connect(
ws_protocol,
iosim.makeFakeServer(ws_protocol),
client_protocol,
iosim.makeFakeClient(client_protocol),
)
self._pumps.append(pump)
return client_protocol
class Usage(ServerBase, unittest.TestCase): class Usage(ServerBase, unittest.TestCase):
log_requests = True log_requests = True
@ -636,6 +648,11 @@ class UsageWebSockets(Usage):
def new_protocol(self): def new_protocol(self):
return self.new_protocol_ws() return self.new_protocol_ws()
def new_protocol_ws(self):
pump, proto = _new_protocol_ws(self._transit_server, self.log_requests)
self._pumps.append(pump)
return proto
def test_short(self): def test_short(self):
""" """
This test essentially just tests the framing of the line-oriented This test essentially just tests the framing of the line-oriented
@ -654,31 +671,6 @@ class UsageWebSockets(Usage):
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
ws_protocol.onMessage(u"foo", isBinary=False) ws_protocol.onMessage(u"foo", isBinary=False)
def new_protocol_ws(self):
ws_factory = WebSocketServerFactory("ws://localhost:4002")
ws_factory.protocol = WebSocketTransitConnection
ws_factory.transit = self._transit_server
ws_factory.log_requests = self.log_requests
ws_protocol = ws_factory.buildProtocol(('127.0.0.1', 0))
class TransitWebSocketClientProtocol(WebSocketClientProtocol):
def send(self, data):
self.sendMessage(data, True)
client_factory = WebSocketClientFactory()
client_factory.protocol = TransitWebSocketClientProtocol
client_protocol = client_factory.buildProtocol(('127.0.0.1', 31337))
client_protocol.disconnect = client_protocol.dropConnection
pump = iosim.connect(
ws_protocol,
iosim.makeFakeServer(ws_protocol),
client_protocol,
iosim.makeFakeClient(client_protocol),
)
self._pumps.append(pump)
return client_protocol
class State(unittest.TestCase): class State(unittest.TestCase):
""" """