more-explicit about which protocol clients use

This commit is contained in:
meejah 2021-04-14 15:10:58 -06:00
parent 5b7ec9ef4c
commit 5a405443b9
2 changed files with 21 additions and 2 deletions

View File

@ -74,6 +74,13 @@ class ServerBase:
self._transit_server = Transit(usage, lambda: 123456789.0) self._transit_server = Transit(usage, lambda: 123456789.0)
def new_protocol(self): def new_protocol(self):
"""
This should be overridden by derived test-case classes to decide
if they want a TCP or WebSockets protocol.
"""
raise NotImplementedError()
def new_protocol_tcp(self):
""" """
Create a new client protocol connected to the server. Create a new client protocol connected to the server.
:returns: a IRelayTestClient implementation :returns: a IRelayTestClient implementation

View File

@ -34,6 +34,9 @@ def handshake(token, side=None):
return hs return hs
class _Transit: class _Transit:
def new_protocol(self):
return self.new_protocol_tcp()
def count(self): def count(self):
return sum([ return sum([
len(potentials) len(potentials)
@ -366,6 +369,9 @@ class TransitWebSockets(_Transit, ServerBase, unittest.TestCase):
# XXX note to self, from pairing with Flo: # XXX note to self, from pairing with Flo:
# - write a WS <--> TCP version of at least one of these tests? # - write a WS <--> TCP version of at least one of these tests?
def new_protocol(self):
return self.new_protocol_ws()
def test_bad_handshake_old_slow(self): def test_bad_handshake_old_slow(self):
""" """
This test only makes sense for TCP This test only makes sense for TCP
@ -398,7 +404,7 @@ class TransitWebSockets(_Transit, ServerBase, unittest.TestCase):
p1.send(b"more message") p1.send(b"more message")
self.flush() self.flush()
def new_protocol(self): def new_protocol_ws(self):
ws_factory = WebSocketServerFactory("ws://localhost:4002") ws_factory = WebSocketServerFactory("ws://localhost:4002")
ws_factory.protocol = WebSocketTransitConnection ws_factory.protocol = WebSocketTransitConnection
ws_factory.transit = self._transit_server ws_factory.transit = self._transit_server
@ -455,6 +461,9 @@ class Usage(ServerBase, unittest.TestCase):
self._usage = MemoryUsageRecorder() self._usage = MemoryUsageRecorder()
self._transit_server.usage.add_backend(self._usage) self._transit_server.usage.add_backend(self._usage)
def new_protocol(self):
return self.new_protocol_tcp()
def test_empty(self): def test_empty(self):
p1 = self.new_protocol() p1 = self.new_protocol()
# hang up before sending anything # hang up before sending anything
@ -587,6 +596,9 @@ class UsageWebSockets(Usage):
def tearDown(self): def tearDown(self):
return self._pump.stop() return self._pump.stop()
def new_protocol(self):
return self.new_protocol_ws()
def test_short(self): def test_short(self):
""" """
This test essentially just tests the framing of the line-oriented This test essentially just tests the framing of the line-oriented
@ -605,7 +617,7 @@ class UsageWebSockets(Usage):
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
ws_protocol.onMessage(u"foo", isBinary=False) ws_protocol.onMessage(u"foo", isBinary=False)
def new_protocol(self): def new_protocol_ws(self):
ws_factory = WebSocketServerFactory("ws://localhost:4002") ws_factory = WebSocketServerFactory("ws://localhost:4002")
ws_factory.protocol = WebSocketTransitConnection ws_factory.protocol = WebSocketTransitConnection
ws_factory.transit = self._transit_server ws_factory.transit = self._transit_server