diff --git a/src/wormhole_transit_relay/test/common.py b/src/wormhole_transit_relay/test/common.py index c502c33..cb84de1 100644 --- a/src/wormhole_transit_relay/test/common.py +++ b/src/wormhole_transit_relay/test/common.py @@ -74,6 +74,13 @@ class ServerBase: self._transit_server = Transit(usage, lambda: 123456789.0) def new_protocol(self): + """ + This should be overridden by derived test-case classes to decide + if they want a TCP or WebSockets protocol. + """ + raise NotImplementedError() + + def new_protocol_tcp(self): """ Create a new client protocol connected to the server. :returns: a IRelayTestClient implementation diff --git a/src/wormhole_transit_relay/test/test_transit_server.py b/src/wormhole_transit_relay/test/test_transit_server.py index 4dbdf31..5937c3b 100644 --- a/src/wormhole_transit_relay/test/test_transit_server.py +++ b/src/wormhole_transit_relay/test/test_transit_server.py @@ -34,6 +34,9 @@ def handshake(token, side=None): return hs class _Transit: + def new_protocol(self): + return self.new_protocol_tcp() + def count(self): return sum([ len(potentials) @@ -366,6 +369,9 @@ class TransitWebSockets(_Transit, ServerBase, unittest.TestCase): # XXX note to self, from pairing with Flo: # - write a WS <--> TCP version of at least one of these tests? + def new_protocol(self): + return self.new_protocol_ws() + def test_bad_handshake_old_slow(self): """ This test only makes sense for TCP @@ -398,7 +404,7 @@ class TransitWebSockets(_Transit, ServerBase, unittest.TestCase): p1.send(b"more message") self.flush() - def new_protocol(self): + def new_protocol_ws(self): ws_factory = WebSocketServerFactory("ws://localhost:4002") ws_factory.protocol = WebSocketTransitConnection ws_factory.transit = self._transit_server @@ -455,6 +461,9 @@ class Usage(ServerBase, unittest.TestCase): self._usage = MemoryUsageRecorder() self._transit_server.usage.add_backend(self._usage) + def new_protocol(self): + return self.new_protocol_tcp() + def test_empty(self): p1 = self.new_protocol() # hang up before sending anything @@ -587,6 +596,9 @@ class UsageWebSockets(Usage): def tearDown(self): return self._pump.stop() + def new_protocol(self): + return self.new_protocol_ws() + def test_short(self): """ This test essentially just tests the framing of the line-oriented @@ -605,7 +617,7 @@ class UsageWebSockets(Usage): with self.assertRaises(ValueError): ws_protocol.onMessage(u"foo", isBinary=False) - def new_protocol(self): + def new_protocol_ws(self): ws_factory = WebSocketServerFactory("ws://localhost:4002") ws_factory.protocol = WebSocketTransitConnection ws_factory.transit = self._transit_server