diff --git a/src/wormhole_transit_relay/test/test_transit_server.py b/src/wormhole_transit_relay/test/test_transit_server.py index 6101501..943013d 100644 --- a/src/wormhole_transit_relay/test/test_transit_server.py +++ b/src/wormhole_transit_relay/test/test_transit_server.py @@ -48,8 +48,9 @@ class _Transit: self.failUnlessEqual(blur_size(1100e6), 1100e6) self.failUnlessEqual(blur_size(1150e6), 1200e6) + @inlineCallbacks def test_register(self): - p1 = self.new_protocol() + p1 = yield self.new_protocol() token1 = b"\x00"*32 side1 = b"\x01"*8 @@ -65,9 +66,10 @@ class _Transit: # the token should be removed too self.assertEqual(len(self._transit.pending_requests._requests), 0) + @inlineCallbacks def test_both_unsided(self): - p1 = self.new_protocol() - p2 = self.new_protocol() + p1 = yield self.new_protocol() + p2 = yield self.new_protocol() token1 = b"\x00"*32 p1.send(handshake(token1, side=None)) @@ -92,9 +94,10 @@ class _Transit: p2.disconnect() self.flush() + @inlineCallbacks def test_sided_unsided(self): - p1 = self.new_protocol() - p2 = self.new_protocol() + p1 = yield self.new_protocol() + p2 = yield self.new_protocol() token1 = b"\x00"*32 side1 = b"\x01"*8 @@ -121,9 +124,10 @@ class _Transit: p2.disconnect() self.flush() + @inlineCallbacks def test_unsided_sided(self): - p1 = self.new_protocol() - p2 = self.new_protocol() + p1 = yield self.new_protocol() + p2 = yield self.new_protocol() token1 = b"\x00"*32 side1 = b"\x01"*8 @@ -148,9 +152,10 @@ class _Transit: p1.disconnect() p2.disconnect() + @inlineCallbacks def test_both_sided(self): - p1 = self.new_protocol() - p2 = self.new_protocol() + p1 = yield self.new_protocol() + p2 = yield self.new_protocol() token1 = b"\x00"*32 side1 = b"\x01"*8 @@ -177,10 +182,11 @@ class _Transit: p1.disconnect() p2.disconnect() + @inlineCallbacks def test_ignore_same_side(self): - p1 = self.new_protocol() - p2 = self.new_protocol() - p3 = self.new_protocol() + p1 = yield self.new_protocol() + p2 = yield self.new_protocol() + p3 = yield self.new_protocol() token1 = b"\x00"*32 side1 = b"\x01"*8 @@ -209,8 +215,9 @@ class _Transit: p2.disconnect() p3.disconnect() + @inlineCallbacks def test_bad_handshake_old(self): - p1 = self.new_protocol() + p1 = yield self.new_protocol() token1 = b"\x00"*32 p1.send(b"please DELAY " + hexlify(token1) + b"\n") @@ -220,8 +227,9 @@ class _Transit: self.assertEqual(p1.get_received_data(), exp) p1.disconnect() + @inlineCallbacks def test_bad_handshake_old_slow(self): - p1 = self.new_protocol() + p1 = yield self.new_protocol() p1.send(b"please DELAY ") self.flush() @@ -241,8 +249,9 @@ class _Transit: p1.disconnect() + @inlineCallbacks def test_bad_handshake_new(self): - p1 = self.new_protocol() + p1 = yield self.new_protocol() token1 = b"\x00"*32 side1 = b"\x01"*8 @@ -257,8 +266,9 @@ class _Transit: p1.disconnect() + @inlineCallbacks def test_binary_handshake(self): - p1 = self.new_protocol() + p1 = yield self.new_protocol() binary_bad_handshake = b"\x00\x01\xe0\x0f\n\xff" # the embedded \n makes the server trigger early, before the full @@ -275,8 +285,9 @@ class _Transit: p1.disconnect() + @inlineCallbacks def test_impatience_old(self): - p1 = self.new_protocol() + p1 = yield self.new_protocol() token1 = b"\x00"*32 # sending too many bytes is impatience. @@ -288,8 +299,9 @@ class _Transit: p1.disconnect() + @inlineCallbacks def test_impatience_new(self): - p1 = self.new_protocol() + p1 = yield self.new_protocol() token1 = b"\x00"*32 side1 = b"\x01"*8 @@ -303,8 +315,9 @@ class _Transit: p1.disconnect() + @inlineCallbacks def test_impatience_new_slow(self): - p1 = self.new_protocol() + p1 = yield self.new_protocol() # For full coverage, we need dataReceived to see a particular framing # of these two pieces of data, and ITCPTransport doesn't have flush() # (which probably wouldn't work anyways). For now, force a 100ms @@ -329,15 +342,17 @@ class _Transit: p1.disconnect() + @inlineCallbacks def test_short_handshake(self): - p1 = self.new_protocol() + p1 = yield self.new_protocol() # hang up before sending a complete handshake p1.send(b"short") self.flush() p1.disconnect() + @inlineCallbacks def test_empty_handshake(self): - p1 = self.new_protocol() + p1 = yield self.new_protocol() # hang up before sending anything p1.disconnect() @@ -358,8 +373,9 @@ class Usage(ServerBase, unittest.TestCase): self._usage = MemoryUsageRecorder() self._transit.usage.add_backend(self._usage) + @inlineCallbacks def test_empty(self): - p1 = self.new_protocol() + p1 = yield self.new_protocol() # hang up before sending anything p1.disconnect() self.flush() @@ -380,8 +396,9 @@ class Usage(ServerBase, unittest.TestCase): self.assertEqual(len(self._usage.events), 1, self._usage) self.assertEqual("empty", self._usage.events[0]["mood"]) + @inlineCallbacks def test_errory(self): - p1 = self.new_protocol() + p1 = yield self.new_protocol() p1.send(b"this is a very bad handshake\n") self.flush() @@ -390,8 +407,9 @@ class Usage(ServerBase, unittest.TestCase): self.assertEqual(len(self._usage.events), 1, self._usage) self.assertEqual(self._usage.events[0]["mood"], "errory", self._usage) + @inlineCallbacks def test_lonely(self): - p1 = self.new_protocol() + p1 = yield self.new_protocol() token1 = b"\x00"*32 side1 = b"\x01"*8 @@ -405,9 +423,10 @@ class Usage(ServerBase, unittest.TestCase): self.assertEqual(self._usage.events[0]["mood"], "lonely", self._usage) self.assertIdentical(self._usage.events[0]["waiting_time"], None) + @inlineCallbacks def test_one_happy_one_jilted(self): - p1 = self.new_protocol() - p2 = self.new_protocol() + p1 = yield self.new_protocol() + p2 = yield self.new_protocol() token1 = b"\x00"*32 side1 = b"\x01"*8 @@ -432,11 +451,12 @@ class Usage(ServerBase, unittest.TestCase): self.assertEqual(self._usage.events[0]["total_bytes"], 20) self.assertNotIdentical(self._usage.events[0]["waiting_time"], None) + @inlineCallbacks def test_redundant(self): - p1a = self.new_protocol() - p1b = self.new_protocol() - p1c = self.new_protocol() - p2 = self.new_protocol() + p1a = yield self.new_protocol() + p1b = yield self.new_protocol() + p1c = yield self.new_protocol() + p2 = yield self.new_protocol() token1 = b"\x00"*32 side1 = b"\x01"*8 @@ -453,7 +473,8 @@ class Usage(ServerBase, unittest.TestCase): p1c.disconnect() self.flush() - print(self._usage.events) + for x in self._usage.events: + print(x) self.assertEqual(len(self._usage.events), 1, self._usage) self.assertEqual(self._usage.events[0]["mood"], "lonely") @@ -517,6 +538,5 @@ class UsageWebSockets(Usage): lambda: server_factory.buildProtocol(IPv4Address("TCP", "127.0.0.1", 31337)), ) client_proto = yield agent.open("ws://127.0.0.1:4002/", dict()) - print("PROTO", client_proto) return client_proto diff --git a/src/wormhole_transit_relay/transit_server.py b/src/wormhole_transit_relay/transit_server.py index 133fbd5..0c2510a 100644 --- a/src/wormhole_transit_relay/transit_server.py +++ b/src/wormhole_transit_relay/transit_server.py @@ -240,18 +240,18 @@ class WebSocketTransitConnection(WebSocketServerProtocol): """ IWebSocketChannel API """ - print("onConnect: {}".format(request)) + # print("onConnect: {}".format(request)) # ideally more like self._reactor.seconds() ... but Twisted # doesn't have a good way to get the reactor for a protocol # (besides "use the global one") - print("protocols: {}".format(request.protocols)) - return None#"transit_relay" + # print("protocols: {}".format(request.protocols)) + return None #"transit_relay" def connectionMade(self): """ IProtocol API """ - print("connectionMade") + # print("connectionMade") super(WebSocketTransitConnection, self).connectionMade() self.started_time = time.time() self._first_message = True @@ -261,14 +261,14 @@ class WebSocketTransitConnection(WebSocketServerProtocol): ) def onOpen(self): - print("onOpen") + # print("onOpen") self._state.connection_made(self) def onMessage(self, payload, isBinary): """ We may have a 'handshake' on our hands or we may just have some bytes to relay """ - print("onMessage isBinary={}: {}".format(isBinary, payload)) + # print("onMessage isBinary={}: {}".format(isBinary, payload)) if self._first_message: self._first_message = False token = None @@ -298,6 +298,6 @@ class WebSocketTransitConnection(WebSocketServerProtocol): """ IWebSocketChannel API """ - print("onClose", wasClean, code, reason) + # print("onClose", wasClean, code, reason) self._state.connection_lost() # XXX "transit finished", etc