diff --git a/src/wormhole_transit_relay/test/common.py b/src/wormhole_transit_relay/test/common.py index 8690e91..a98a338 100644 --- a/src/wormhole_transit_relay/test/common.py +++ b/src/wormhole_transit_relay/test/common.py @@ -78,7 +78,10 @@ class ServerBase: Create a new client protocol connected to the server. :returns: a IRelayTestClient implementation """ - server_protocol = self._transit_server.buildProtocol(('127.0.0.1', 0)) + server_factory = ServerFactory() + server_factory.protocol = TransitConnection + server_factory.transit = self._transit_server + server_protocol = server_factory.buildProtocol(('127.0.0.1', 0)) @implementer(IRelayTestClient) class TransitClientProtocolTcp(Protocol): diff --git a/src/wormhole_transit_relay/test/test_transit_server.py b/src/wormhole_transit_relay/test/test_transit_server.py index 795b996..e75dac2 100644 --- a/src/wormhole_transit_relay/test/test_transit_server.py +++ b/src/wormhole_transit_relay/test/test_transit_server.py @@ -36,7 +36,7 @@ class _Transit: return sum([ len(potentials) for potentials - in self._transit.pending_requests._requests.values() + in self._transit_server.pending_requests._requests.values() ]) def test_blur_size(self): @@ -56,9 +56,8 @@ class _Transit: self.failUnlessEqual(blur_size(1100e6), 1100e6) self.failUnlessEqual(blur_size(1150e6), 1200e6) - @inlineCallbacks def test_register(self): - p1 = yield self.new_protocol() + p1 = self.new_protocol() token1 = b"\x00"*32 side1 = b"\x01"*8 @@ -72,12 +71,11 @@ class _Transit: self.assertEqual(self.count(), 0) # the token should be removed too - self.assertEqual(len(self._transit.pending_requests._requests), 0) + self.assertEqual(len(self._transit_server.pending_requests._requests), 0) - @inlineCallbacks def test_both_unsided(self): - p1 = yield self.new_protocol() - p2 = yield self.new_protocol() + p1 = self.new_protocol() + p2 = self.new_protocol() token1 = b"\x00"*32 p1.send(handshake(token1, side=None)) @@ -102,10 +100,9 @@ class _Transit: p2.disconnect() self.flush() - @inlineCallbacks def test_sided_unsided(self): - p1 = yield self.new_protocol() - p2 = yield self.new_protocol() + p1 = self.new_protocol() + p2 = self.new_protocol() token1 = b"\x00"*32 side1 = b"\x01"*8 @@ -132,10 +129,9 @@ class _Transit: p2.disconnect() self.flush() - @inlineCallbacks def test_unsided_sided(self): - p1 = yield self.new_protocol() - p2 = yield self.new_protocol() + p1 = self.new_protocol() + p2 = self.new_protocol() token1 = b"\x00"*32 side1 = b"\x01"*8 @@ -160,10 +156,9 @@ class _Transit: p1.disconnect() p2.disconnect() - @inlineCallbacks def test_both_sided(self): - p1 = yield self.new_protocol() - p2 = yield self.new_protocol() + p1 = self.new_protocol() + p2 = self.new_protocol() token1 = b"\x00"*32 side1 = b"\x01"*8 @@ -190,11 +185,10 @@ class _Transit: p1.disconnect() p2.disconnect() - @inlineCallbacks def test_ignore_same_side(self): - p1 = yield self.new_protocol() - p2 = yield self.new_protocol() - p3 = yield self.new_protocol() + p1 = self.new_protocol() + p2 = self.new_protocol() + p3 = self.new_protocol() token1 = b"\x00"*32 side1 = b"\x01"*8 @@ -213,8 +207,8 @@ class _Transit: p3.send(handshake(token1, side=side2)) self.flush() self.assertEqual(self.count(), 0) - self.assertEqual(len(self._transit.pending_requests._requests), 0) - self.assertEqual(len(self._transit.active_connections._connections), 2) + self.assertEqual(len(self._transit_server.pending_requests._requests), 0) + self.assertEqual(len(self._transit_server.active_connections._connections), 2) # That will trigger a disconnect on exactly one of (p1 or p2). # The other connection should still be connected self.assertEqual(sum([int(t.connected) for t in [p1, p2]]), 1) @@ -223,9 +217,8 @@ class _Transit: p2.disconnect() p3.disconnect() - @inlineCallbacks def test_bad_handshake_old(self): - p1 = yield self.new_protocol() + p1 = self.new_protocol() token1 = b"\x00"*32 p1.send(b"please DELAY " + hexlify(token1) + b"\n") @@ -235,9 +228,8 @@ class _Transit: self.assertEqual(p1.get_received_data(), exp) p1.disconnect() - @inlineCallbacks def test_bad_handshake_old_slow(self): - p1 = yield self.new_protocol() + p1 = self.new_protocol() p1.send(b"please DELAY ") self.flush() @@ -257,9 +249,8 @@ class _Transit: p1.disconnect() - @inlineCallbacks def test_bad_handshake_new(self): - p1 = yield self.new_protocol() + p1 = self.new_protocol() token1 = b"\x00"*32 side1 = b"\x01"*8 @@ -274,9 +265,8 @@ class _Transit: p1.disconnect() - @inlineCallbacks def test_binary_handshake(self): - p1 = yield self.new_protocol() + p1 = self.new_protocol() binary_bad_handshake = b"\x00\x01\xe0\x0f\n\xff" # the embedded \n makes the server trigger early, before the full @@ -293,9 +283,8 @@ class _Transit: p1.disconnect() - @inlineCallbacks def test_impatience_old(self): - p1 = yield self.new_protocol() + p1 = self.new_protocol() token1 = b"\x00"*32 # sending too many bytes is impatience. @@ -307,9 +296,8 @@ class _Transit: p1.disconnect() - @inlineCallbacks def test_impatience_new(self): - p1 = yield self.new_protocol() + p1 = self.new_protocol() token1 = b"\x00"*32 side1 = b"\x01"*8 @@ -323,9 +311,8 @@ class _Transit: p1.disconnect() - @inlineCallbacks def test_impatience_new_slow(self): - p1 = yield self.new_protocol() + p1 = self.new_protocol() # For full coverage, we need dataReceived to see a particular framing # of these two pieces of data, and ITCPTransport doesn't have flush() # (which probably wouldn't work anyways). For now, force a 100ms @@ -350,17 +337,15 @@ class _Transit: p1.disconnect() - @inlineCallbacks def test_short_handshake(self): - p1 = yield self.new_protocol() + p1 = self.new_protocol() # hang up before sending a complete handshake p1.send(b"short") self.flush() p1.disconnect() - @inlineCallbacks def test_empty_handshake(self): - p1 = yield self.new_protocol() + p1 = self.new_protocol() # hang up before sending anything p1.disconnect() @@ -379,11 +364,10 @@ class Usage(ServerBase, unittest.TestCase): def setUp(self): super(Usage, self).setUp() self._usage = MemoryUsageRecorder() - self._transit.usage.add_backend(self._usage) + self._transit_server.usage.add_backend(self._usage) - @inlineCallbacks def test_empty(self): - p1 = yield self.new_protocol() + p1 = self.new_protocol() # hang up before sending anything p1.disconnect() self.flush() @@ -392,9 +376,8 @@ class Usage(ServerBase, unittest.TestCase): self.assertEqual(len(self._usage.events), 1, self._usage) self.assertEqual(self._usage.events[0]["mood"], "empty", self._usage) - @inlineCallbacks def test_short(self): - p1 = yield self.new_protocol() + p1 = self.new_protocol() # hang up before sending a complete handshake p1.send(b"short") p1.disconnect() @@ -404,9 +387,8 @@ class Usage(ServerBase, unittest.TestCase): self.assertEqual(len(self._usage.events), 1, self._usage) self.assertEqual("empty", self._usage.events[0]["mood"]) - @inlineCallbacks def test_errory(self): - p1 = yield self.new_protocol() + p1 = self.new_protocol() p1.send(b"this is a very bad handshake\n") self.flush() @@ -415,9 +397,8 @@ class Usage(ServerBase, unittest.TestCase): self.assertEqual(len(self._usage.events), 1, self._usage) self.assertEqual(self._usage.events[0]["mood"], "errory", self._usage) - @inlineCallbacks def test_lonely(self): - p1 = yield self.new_protocol() + p1 = self.new_protocol() token1 = b"\x00"*32 side1 = b"\x01"*8 @@ -431,10 +412,9 @@ class Usage(ServerBase, unittest.TestCase): self.assertEqual(self._usage.events[0]["mood"], "lonely", self._usage) self.assertIdentical(self._usage.events[0]["waiting_time"], None) - @inlineCallbacks def test_one_happy_one_jilted(self): - p1 = yield self.new_protocol() - p2 = yield self.new_protocol() + p1 = self.new_protocol() + p2 = self.new_protocol() print(dir(p1.factory)) return @@ -462,12 +442,11 @@ class Usage(ServerBase, unittest.TestCase): self.assertEqual(self._usage.events[0]["total_bytes"], 20) self.assertNotIdentical(self._usage.events[0]["waiting_time"], None) - @inlineCallbacks def test_redundant(self): - p1a = yield self.new_protocol() - p1b = yield self.new_protocol() - p1c = yield self.new_protocol() - p2 = yield self.new_protocol() + p1a = self.new_protocol() + p1b = self.new_protocol() + p1c = self.new_protocol() + p2 = self.new_protocol() token1 = b"\x00"*32 side1 = b"\x01"*8 @@ -533,23 +512,6 @@ class UsageWebSockets(Usage): def tearDown(self): return self._pump.stop() - @inlineCallbacks - def new_protocol(self): - - class RelayFactory(WebSocketServerFactory): - protocol = WebSocketTransitConnection - websocket_protocols = ["transit_relay"] - transit = self._transit - - server_factory = RelayFactory("ws://localhost:4002") - - agent = create_memory_agent( - self._reactor, - self._pump, - lambda: server_factory.buildProtocol(IPv4Address("TCP", "127.0.0.1", 31337)), - ) - client_proto = yield agent.open("ws://127.0.0.1:4002/", dict()) - return client_proto class New(unittest.TestCase): @@ -577,9 +539,9 @@ class New(unittest.TestCase): log_file=log_file, usage_db=usage_db, ) - self._transit = Transit(usage, lambda: 123456789.0) - self._transit._debug_log = self.log_requests - self._transit.usage.add_backend(self._usage) + self._transit_server = Transit(usage, lambda: 123456789.0) + self._transit_server._debug_log = self.log_requests + self._transit_server.usage.add_backend(self._usage) def new_protocol(self): if False: @@ -590,7 +552,7 @@ class New(unittest.TestCase): def _new_protocol_tcp(self): server_factory = ServerFactory() server_factory.protocol = TransitConnection - server_factory.transit = self._transit + server_factory.transit = self._transit_server server_protocol = server_factory.buildProtocol(('127.0.0.1', 0)) class ClientProtocol(protocol.Protocol): @@ -618,7 +580,7 @@ class New(unittest.TestCase): def _new_protocol_ws(self): ws_factory = WebSocketServerFactory("ws://localhost:4002") # FIXME: url ws_factory.protocol = WebSocketTransitConnection - ws_factory.transit = self._transit + ws_factory.transit = self._transit_server ws_factory.websocket_protocols = ["binary"] ws_protocol = ws_factory.buildProtocol(('127.0.0.1', 0))