diff --git a/src/wormhole_transit_relay/test/common.py b/src/wormhole_transit_relay/test/common.py index b1ce269..86029b3 100644 --- a/src/wormhole_transit_relay/test/common.py +++ b/src/wormhole_transit_relay/test/common.py @@ -43,12 +43,25 @@ class ServerBase: """ Speak the transit client protocol used by the tests over TCP """ + received = b"" + connected = False + + def connectionMade(self): + self.connected = True + + def connectionLost(self, reason): + self.connected = False + def send(self, data): self.transport.write(data) def disconnect(self): self.transport.loseConnection() + def dataReceived(self, data): + self.received = self.received + data + + client_factory = ClientFactory() client_factory.protocol = TransitClientProtocolTcp client_protocol = client_factory.buildProtocol(('127.0.0.1', 31337)) diff --git a/src/wormhole_transit_relay/test/test_transit_server.py b/src/wormhole_transit_relay/test/test_transit_server.py index 320ecdf..dfdf8de 100644 --- a/src/wormhole_transit_relay/test/test_transit_server.py +++ b/src/wormhole_transit_relay/test/test_transit_server.py @@ -57,23 +57,27 @@ class _Transit: p2 = self.new_protocol() token1 = b"\x00"*32 - p1.dataReceived(handshake(token1, side=None)) - p2.dataReceived(handshake(token1, side=None)) + p1.send(handshake(token1, side=None)) + self.flush() + p2.send(handshake(token1, side=None)) + self.flush() + self.flush() # a correct handshake yields an ack, after which we can send exp = b"ok\n" - self.assertEqual(p1.transport.value(), exp) - self.assertEqual(p2.transport.value(), exp) + self.assertEqual(p1.received, exp) + self.assertEqual(p2.received, exp) - p1.transport.clear() - p2.transport.clear() + p1.received = b"" + p2.received = b"" s1 = b"data1" - p1.dataReceived(s1) - self.assertEqual(p2.transport.value(), s1) + p1.send(s1) + self.flush() + self.assertEqual(p2.received, s1) - p1.transport.loseConnection() - p2.transport.loseConnection() + p1.disconnect() + p2.disconnect() def test_sided_unsided(self): p1 = self.new_protocol() @@ -81,24 +85,28 @@ class _Transit: token1 = b"\x00"*32 side1 = b"\x01"*8 - p1.dataReceived(handshake(token1, side=side1)) - p2.dataReceived(handshake(token1, side=None)) + p1.send(handshake(token1, side=side1)) + self.flush() + p2.send(handshake(token1, side=None)) + self.flush() + self.flush() # a correct handshake yields an ack, after which we can send exp = b"ok\n" - self.assertEqual(p1.transport.value(), exp) - self.assertEqual(p2.transport.value(), exp) + self.assertEqual(p1.received, exp) + self.assertEqual(p2.received, exp) - p1.transport.clear() - p2.transport.clear() + p1.received = b"" + p2.received = b"" # all data they sent after the handshake should be given to us s1 = b"data1" - p1.dataReceived(s1) - self.assertEqual(p2.transport.value(), s1) + p1.send(s1) + self.flush() + self.assertEqual(p2.received, s1) - p1.transport.loseConnection() - p2.transport.loseConnection() + p1.disconnect() + p2.disconnect() def test_unsided_sided(self): p1 = self.new_protocol() @@ -106,24 +114,27 @@ class _Transit: token1 = b"\x00"*32 side1 = b"\x01"*8 - p1.dataReceived(handshake(token1, side=None)) - p2.dataReceived(handshake(token1, side=side1)) + p1.send(handshake(token1, side=None)) + p2.send(handshake(token1, side=side1)) + self.flush() + self.flush() # a correct handshake yields an ack, after which we can send exp = b"ok\n" - self.assertEqual(p1.transport.value(), exp) - self.assertEqual(p2.transport.value(), exp) + self.assertEqual(p1.received, exp) + self.assertEqual(p2.received, exp) - p1.transport.clear() - p2.transport.clear() + p1.received = b"" + p2.received = b"" # all data they sent after the handshake should be given to us s1 = b"data1" - p1.dataReceived(s1) - self.assertEqual(p2.transport.value(), s1) + p1.send(s1) + self.flush() + self.assertEqual(p2.received, s1) - p1.transport.loseConnection() - p2.transport.loseConnection() + p1.disconnect() + p2.disconnect() def test_both_sided(self): p1 = self.new_protocol() @@ -132,24 +143,28 @@ class _Transit: token1 = b"\x00"*32 side1 = b"\x01"*8 side2 = b"\x02"*8 - p1.dataReceived(handshake(token1, side=side1)) - p2.dataReceived(handshake(token1, side=side2)) + p1.send(handshake(token1, side=side1)) + self.flush() + p2.send(handshake(token1, side=side2)) + self.flush() + self.flush() # a correct handshake yields an ack, after which we can send exp = b"ok\n" - self.assertEqual(p1.transport.value(), exp) - self.assertEqual(p2.transport.value(), exp) + self.assertEqual(p1.received, exp) + self.assertEqual(p2.received, exp) - p1.transport.clear() - p2.transport.clear() + p1.received = b"" + p2.received = b"" # all data they sent after the handshake should be given to us s1 = b"data1" - p1.dataReceived(s1) - self.assertEqual(p2.transport.value(), s1) + p1.send(s1) + self.flush() + self.assertEqual(p2.received, s1) - p1.transport.loseConnection() - p2.transport.loseConnection() + p1.disconnect() + p2.disconnect() def test_ignore_same_side(self): p1 = self.new_protocol() @@ -159,41 +174,47 @@ class _Transit: token1 = b"\x00"*32 side1 = b"\x01"*8 - p1.dataReceived(handshake(token1, side=side1)) + p1.send(handshake(token1, side=side1)) + self.flush() self.assertEqual(self.count(), 1) - p2.dataReceived(handshake(token1, side=side1)) + p2.send(handshake(token1, side=side1)) + self.flush() self.assertEqual(self.count(), 2) # same-side connections don't match # when the second side arrives, the spare first connection should be # closed side2 = b"\x02"*8 - p3.dataReceived(handshake(token1, side=side2)) + p3.send(handshake(token1, side=side2)) + self.flush() + self.flush() self.assertEqual(self.count(), 0) self.assertEqual(len(self._transit_server._pending_requests), 0) self.assertEqual(len(self._transit_server._active_connections), 2) # That will trigger a disconnect on exactly one of (p1 or p2). # The other connection should still be connected - self.assertEqual(sum([int(t.transport.connected) for t in [p1, p2]]), 1) + self.assertEqual(sum([int(t.connected) for t in [p1, p2]]), 1) - p1.transport.loseConnection() - p2.transport.loseConnection() - p3.transport.loseConnection() + p1.disconnect() + p2.disconnect() + p3.disconnect() def test_bad_handshake_old(self): p1 = self.new_protocol() token1 = b"\x00"*32 - p1.dataReceived(b"please DELAY " + hexlify(token1) + b"\n") + p1.send(b"please DELAY " + hexlify(token1) + b"\n") + self.flush() exp = b"bad handshake\n" - self.assertEqual(p1.transport.value(), exp) - p1.transport.loseConnection() + self.assertEqual(p1.received, exp) + p1.disconnect() def test_bad_handshake_old_slow(self): p1 = self.new_protocol() - p1.dataReceived(b"please DELAY ") + p1.send(b"please DELAY ") + self.flush() # As in test_impatience_new_slow, the current state machine has code # that can only be reached if we insert a stall here, so dataReceived # gets called twice. Hopefully we can delete this test once @@ -202,12 +223,13 @@ class _Transit: token1 = b"\x00"*32 # the server waits for the exact number of bytes in the expected # handshake message. to trigger "bad handshake", we must match. - p1.dataReceived(hexlify(token1) + b"\n") + p1.send(hexlify(token1) + b"\n") + self.flush() exp = b"bad handshake\n" - self.assertEqual(p1.transport.value(), exp) + self.assertEqual(p1.received, exp) - p1.transport.loseConnection() + p1.disconnect() def test_bad_handshake_new(self): p1 = self.new_protocol() @@ -216,13 +238,14 @@ class _Transit: side1 = b"\x01"*8 # the server waits for the exact number of bytes in the expected # handshake message. to trigger "bad handshake", we must match. - p1.dataReceived(b"please DELAY " + hexlify(token1) + - b" for side " + hexlify(side1) + b"\n") + p1.send(b"please DELAY " + hexlify(token1) + + b" for side " + hexlify(side1) + b"\n") + self.flush() exp = b"bad handshake\n" - self.assertEqual(p1.transport.value(), exp) + self.assertEqual(p1.received, exp) - p1.transport.loseConnection() + p1.disconnect() def test_binary_handshake(self): p1 = self.new_protocol() @@ -234,24 +257,26 @@ class _Transit: # UnicodeDecodeError when it tried to coerce the incoming handshake # to unicode, due to the ("\n" in buf) check. This was fixed to use # (b"\n" in buf). This exercises the old failure. - p1.dataReceived(binary_bad_handshake) + p1.send(binary_bad_handshake) + self.flush() exp = b"bad handshake\n" - self.assertEqual(p1.transport.value(), exp) + self.assertEqual(p1.received, exp) - p1.transport.loseConnection() + p1.disconnect() def test_impatience_old(self): p1 = self.new_protocol() token1 = b"\x00"*32 # sending too many bytes is impatience. - p1.dataReceived(b"please relay " + hexlify(token1) + b"\nNOWNOWNOW") + p1.send(b"please relay " + hexlify(token1) + b"\nNOWNOWNOW") + self.flush() exp = b"impatient\n" - self.assertEqual(p1.transport.value(), exp) + self.assertEqual(p1.received, exp) - p1.transport.loseConnection() + p1.disconnect() def test_impatience_new(self): p1 = self.new_protocol() @@ -259,13 +284,14 @@ class _Transit: token1 = b"\x00"*32 side1 = b"\x01"*8 # sending too many bytes is impatience. - p1.dataReceived(b"please relay " + hexlify(token1) + - b" for side " + hexlify(side1) + b"\nNOWNOWNOW") + p1.send(b"please relay " + hexlify(token1) + + b" for side " + hexlify(side1) + b"\nNOWNOWNOW") + self.flush() exp = b"impatient\n" - self.assertEqual(p1.transport.value(), exp) + self.assertEqual(p1.received, exp) - p1.transport.loseConnection() + p1.disconnect() def test_impatience_new_slow(self): p1 = self.new_protocol() @@ -281,27 +307,29 @@ class _Transit: token1 = b"\x00"*32 side1 = b"\x01"*8 # sending too many bytes is impatience. - p1.dataReceived(b"please relay " + hexlify(token1) + - b" for side " + hexlify(side1) + b"\n") + p1.send(b"please relay " + hexlify(token1) + + b" for side " + hexlify(side1) + b"\n") + self.flush() - - p1.dataReceived(b"NOWNOWNOW") + p1.send(b"NOWNOWNOW") + self.flush() exp = b"impatient\n" - self.assertEqual(p1.transport.value(), exp) + self.assertEqual(p1.received, exp) - p1.transport.loseConnection() + p1.disconnect() def test_short_handshake(self): p1 = self.new_protocol() # hang up before sending a complete handshake - p1.dataReceived(b"short") - p1.transport.loseConnection() + p1.send(b"short") + self.flush() + p1.disconnect() def test_empty_handshake(self): p1 = self.new_protocol() # hang up before sending anything - p1.transport.loseConnection() + p1.disconnect() class TransitWithLogs(_Transit, ServerBase, unittest.TestCase): log_requests = True @@ -321,7 +349,8 @@ class Usage(ServerBase, unittest.TestCase): def test_empty(self): p1 = self.new_protocol() # hang up before sending anything - p1.transport.loseConnection() + p1.disconnect() + self.flush() # that will log the "empty" usage event self.assertEqual(len(self._usage), 1, self._usage) @@ -331,8 +360,9 @@ class Usage(ServerBase, unittest.TestCase): def test_short(self): p1 = self.new_protocol() # hang up before sending a complete handshake - p1.transport.write(b"short") - p1.transport.loseConnection() + p1.send(b"short") + p1.disconnect() + self.flush() # that will log the "empty" usage event self.assertEqual(len(self._usage), 1, self._usage) @@ -342,9 +372,10 @@ class Usage(ServerBase, unittest.TestCase): def test_errory(self): p1 = self.new_protocol() - p1.dataReceived(b"this is a very bad handshake\n") + p1.send(b"this is a very bad handshake\n") + self.flush() # that will log the "errory" usage event, then drop the connection - p1.transport.loseConnection() + p1.disconnect() self.assertEqual(len(self._usage), 1, self._usage) (started, result, total_bytes, total_time, waiting_time) = self._usage[0] self.assertEqual(result, "errory", self._usage) @@ -354,9 +385,11 @@ class Usage(ServerBase, unittest.TestCase): token1 = b"\x00"*32 side1 = b"\x01"*8 - p1.dataReceived(handshake(token1, side=side1)) + p1.send(handshake(token1, side=side1)) + self.flush() # now we disconnect before the peer connects - p1.transport.loseConnection() + p1.disconnect() + self.flush() self.assertEqual(len(self._usage), 1, self._usage) (started, result, total_bytes, total_time, waiting_time) = self._usage[0] @@ -370,15 +403,20 @@ class Usage(ServerBase, unittest.TestCase): token1 = b"\x00"*32 side1 = b"\x01"*8 side2 = b"\x02"*8 - p1.dataReceived(handshake(token1, side=side1)) - p2.dataReceived(handshake(token1, side=side2)) + p1.send(handshake(token1, side=side1)) + self.flush() + p2.send(handshake(token1, side=side2)) + self.flush() self.assertEqual(self._usage, []) # no events yet - p1.dataReceived(b"\x00" * 13) - p2.dataReceived(b"\xff" * 7) + p1.send(b"\x00" * 13) + self.flush() + p2.send(b"\xff" * 7) + self.flush() - p1.transport.loseConnection() + p1.disconnect() + self.flush() self.assertEqual(len(self._usage), 1, self._usage) (started, result, total_bytes, total_time, waiting_time) = self._usage[0] @@ -395,28 +433,34 @@ class Usage(ServerBase, unittest.TestCase): token1 = b"\x00"*32 side1 = b"\x01"*8 side2 = b"\x02"*8 - p1a.dataReceived(handshake(token1, side=side1)) - p1b.dataReceived(handshake(token1, side=side1)) + p1a.send(handshake(token1, side=side1)) + self.flush() + p1b.send(handshake(token1, side=side1)) + self.flush() # connect and disconnect a third client (for side1) to exercise the # code that removes a pending connection without removing the entire # token - p1c.dataReceived(handshake(token1, side=side1)) - p1c.transport.loseConnection() + p1c.send(handshake(token1, side=side1)) + p1c.disconnect() + self.flush() self.assertEqual(len(self._usage), 1, self._usage) (started, result, total_bytes, total_time, waiting_time) = self._usage[0] self.assertEqual(result, "lonely", self._usage) - p2.dataReceived(handshake(token1, side=side2)) + p2.send(handshake(token1, side=side2)) + self.flush() + self.flush() self.assertEqual(len(self._transit_server._pending_requests), 0) self.assertEqual(len(self._usage), 2, self._usage) (started, result, total_bytes, total_time, waiting_time) = self._usage[1] self.assertEqual(result, "redundant", self._usage) # one of the these is unecessary, but probably harmless - p1a.transport.loseConnection() - p1b.transport.loseConnection() + p1a.disconnect() + p1b.disconnect() + self.flush() self.assertEqual(len(self._usage), 3, self._usage) (started, result, total_bytes, total_time, waiting_time) = self._usage[2] self.assertEqual(result, "happy", self._usage)