diff --git a/src/wormhole_transit_relay/test/common.py b/src/wormhole_transit_relay/test/common.py index 53958fb..8073ee0 100644 --- a/src/wormhole_transit_relay/test/common.py +++ b/src/wormhole_transit_relay/test/common.py @@ -1,28 +1,129 @@ -from twisted.test import proto_helpers -from ..transit_server import Transit +from twisted.internet.protocol import ( + ClientFactory, + Protocol, +) +from twisted.test import iosim +from zope.interface import ( + Interface, + Attribute, + implementer, +) +from ..transit_server import ( + Transit, +) + + +class IRelayTestClient(Interface): + """ + The client interface used by tests. + """ + + connected = Attribute("True if we are currently connected else False") + + def send(data): + """ + Send some bytes. + :param bytes data: the data to send + """ + + def disconnect(): + """ + Terminate the connection. + """ + + def get_received_data(): + """ + :returns: all the bytes received from the server on this + connection. + """ + + def reset_data(): + """ + Erase any received data to this point. + """ class ServerBase: log_requests = False def setUp(self): + self._pumps = [] self._lp = None if self.log_requests: blur_usage = None else: blur_usage = 60.0 self._setup_relay(blur_usage=blur_usage) - self._transit_server._debug_log = self.log_requests + + def flush(self): + did_work = False + for pump in self._pumps: + did_work = pump.flush() or did_work + if did_work: + self.flush() def _setup_relay(self, blur_usage=None, log_file=None, usage_db=None): - self._transit_server = Transit(blur_usage=blur_usage, - log_file=log_file, usage_db=usage_db) + self._transit_server = Transit( + blur_usage=blur_usage, + log_file=log_file, + usage_db=usage_db, + ) + self._transit_server._debug_log = self.log_requests def new_protocol(self): - protocol = self._transit_server.buildProtocol(('127.0.0.1', 0)) - transport = proto_helpers.StringTransportWithDisconnection() - protocol.makeConnection(transport) - transport.protocol = protocol - return protocol + """ + Create a new client protocol connected to the server. + :returns: a IRelayTestClient implementation + """ + server_protocol = self._transit_server.buildProtocol(('127.0.0.1', 0)) + + @implementer(IRelayTestClient) + class TransitClientProtocolTcp(Protocol): + """ + Speak the transit client protocol used by the tests over TCP + """ + _received = b"" + connected = False + + # override Protocol callbacks + + def connectionMade(self): + self.connected = True + return Protocol.connectionMade(self) + + def connectionLost(self, reason): + self.connected = False + return Protocol.connectionLost(self, reason) + + def dataReceived(self, data): + self._received = self._received + data + + # IRelayTestClient + + def send(self, data): + self.transport.write(data) + + def disconnect(self): + self.transport.loseConnection() + + def reset_received_data(self): + self._received = b"" + + def get_received_data(self): + return self._received + + client_factory = ClientFactory() + client_factory.protocol = TransitClientProtocolTcp + client_protocol = client_factory.buildProtocol(('127.0.0.1', 31337)) + + pump = iosim.connect( + server_protocol, + iosim.makeFakeServer(server_protocol), + client_protocol, + iosim.makeFakeClient(client_protocol), + ) + pump.flush() + self._pumps.append(pump) + return client_protocol def tearDown(self): if self._lp: diff --git a/src/wormhole_transit_relay/test/test_rlimits.py b/src/wormhole_transit_relay/test/test_rlimits.py index 10497e4..1354c40 100644 --- a/src/wormhole_transit_relay/test/test_rlimits.py +++ b/src/wormhole_transit_relay/test/test_rlimits.py @@ -1,5 +1,8 @@ from __future__ import print_function, unicode_literals -import mock +try: + from unittest import mock +except ImportError: + import mock from twisted.trial import unittest from ..increase_rlimits import increase_rlimits diff --git a/src/wormhole_transit_relay/test/test_service.py b/src/wormhole_transit_relay/test/test_service.py index dac642c..003de32 100644 --- a/src/wormhole_transit_relay/test/test_service.py +++ b/src/wormhole_transit_relay/test/test_service.py @@ -1,6 +1,9 @@ from __future__ import unicode_literals, print_function from twisted.trial import unittest -import mock +try: + from unittest import mock +except ImportError: + import mock from twisted.application.service import MultiService from .. import server_tap diff --git a/src/wormhole_transit_relay/test/test_stats.py b/src/wormhole_transit_relay/test/test_stats.py index f9433ef..1f114b1 100644 --- a/src/wormhole_transit_relay/test/test_stats.py +++ b/src/wormhole_transit_relay/test/test_stats.py @@ -1,6 +1,9 @@ from __future__ import print_function, unicode_literals import os, io, json, sqlite3 -import mock +try: + from unittest import mock +except ImportError: + import mock from twisted.trial import unittest from ..transit_server import Transit from .. import database diff --git a/src/wormhole_transit_relay/test/test_transit_server.py b/src/wormhole_transit_relay/test/test_transit_server.py index a4763d9..bca740e 100644 --- a/src/wormhole_transit_relay/test/test_transit_server.py +++ b/src/wormhole_transit_relay/test/test_transit_server.py @@ -41,10 +41,12 @@ class _Transit: token1 = b"\x00"*32 side1 = b"\x01"*8 - p1.dataReceived(handshake(token1, side1)) + p1.send(handshake(token1, side1)) + self.flush() self.assertEqual(self.count(), 1) - p1.transport.loseConnection() + p1.disconnect() + self.flush() self.assertEqual(self.count(), 0) # the token should be removed too @@ -55,23 +57,27 @@ class _Transit: p2 = self.new_protocol() token1 = b"\x00"*32 - p1.dataReceived(handshake(token1, side=None)) - p2.dataReceived(handshake(token1, side=None)) + p1.send(handshake(token1, side=None)) + self.flush() + p2.send(handshake(token1, side=None)) + self.flush() # a correct handshake yields an ack, after which we can send exp = b"ok\n" - self.assertEqual(p1.transport.value(), exp) - self.assertEqual(p2.transport.value(), exp) + self.assertEqual(p1.get_received_data(), exp) + self.assertEqual(p2.get_received_data(), exp) - p1.transport.clear() - p2.transport.clear() + p1.reset_received_data() + p2.reset_received_data() s1 = b"data1" - p1.dataReceived(s1) - self.assertEqual(p2.transport.value(), s1) + p1.send(s1) + self.flush() + self.assertEqual(p2.get_received_data(), s1) - p1.transport.loseConnection() - p2.transport.loseConnection() + p1.disconnect() + p2.disconnect() + self.flush() def test_sided_unsided(self): p1 = self.new_protocol() @@ -79,24 +85,28 @@ class _Transit: token1 = b"\x00"*32 side1 = b"\x01"*8 - p1.dataReceived(handshake(token1, side=side1)) - p2.dataReceived(handshake(token1, side=None)) + p1.send(handshake(token1, side=side1)) + self.flush() + p2.send(handshake(token1, side=None)) + self.flush() # a correct handshake yields an ack, after which we can send exp = b"ok\n" - self.assertEqual(p1.transport.value(), exp) - self.assertEqual(p2.transport.value(), exp) + self.assertEqual(p1.get_received_data(), exp) + self.assertEqual(p2.get_received_data(), exp) - p1.transport.clear() - p2.transport.clear() + p1.reset_received_data() + p2.reset_received_data() # all data they sent after the handshake should be given to us s1 = b"data1" - p1.dataReceived(s1) - self.assertEqual(p2.transport.value(), s1) + p1.send(s1) + self.flush() + self.assertEqual(p2.get_received_data(), s1) - p1.transport.loseConnection() - p2.transport.loseConnection() + p1.disconnect() + p2.disconnect() + self.flush() def test_unsided_sided(self): p1 = self.new_protocol() @@ -104,24 +114,26 @@ class _Transit: token1 = b"\x00"*32 side1 = b"\x01"*8 - p1.dataReceived(handshake(token1, side=None)) - p2.dataReceived(handshake(token1, side=side1)) + p1.send(handshake(token1, side=None)) + p2.send(handshake(token1, side=side1)) + self.flush() # a correct handshake yields an ack, after which we can send exp = b"ok\n" - self.assertEqual(p1.transport.value(), exp) - self.assertEqual(p2.transport.value(), exp) + self.assertEqual(p1.get_received_data(), exp) + self.assertEqual(p2.get_received_data(), exp) - p1.transport.clear() - p2.transport.clear() + p1.reset_received_data() + p2.reset_received_data() # all data they sent after the handshake should be given to us s1 = b"data1" - p1.dataReceived(s1) - self.assertEqual(p2.transport.value(), s1) + p1.send(s1) + self.flush() + self.assertEqual(p2.get_received_data(), s1) - p1.transport.loseConnection() - p2.transport.loseConnection() + p1.disconnect() + p2.disconnect() def test_both_sided(self): p1 = self.new_protocol() @@ -130,24 +142,27 @@ class _Transit: token1 = b"\x00"*32 side1 = b"\x01"*8 side2 = b"\x02"*8 - p1.dataReceived(handshake(token1, side=side1)) - p2.dataReceived(handshake(token1, side=side2)) + p1.send(handshake(token1, side=side1)) + self.flush() + p2.send(handshake(token1, side=side2)) + self.flush() # a correct handshake yields an ack, after which we can send exp = b"ok\n" - self.assertEqual(p1.transport.value(), exp) - self.assertEqual(p2.transport.value(), exp) + self.assertEqual(p1.get_received_data(), exp) + self.assertEqual(p2.get_received_data(), exp) - p1.transport.clear() - p2.transport.clear() + p1.reset_received_data() + p2.reset_received_data() # all data they sent after the handshake should be given to us s1 = b"data1" - p1.dataReceived(s1) - self.assertEqual(p2.transport.value(), s1) + p1.send(s1) + self.flush() + self.assertEqual(p2.get_received_data(), s1) - p1.transport.loseConnection() - p2.transport.loseConnection() + p1.disconnect() + p2.disconnect() def test_ignore_same_side(self): p1 = self.new_protocol() @@ -157,41 +172,46 @@ class _Transit: token1 = b"\x00"*32 side1 = b"\x01"*8 - p1.dataReceived(handshake(token1, side=side1)) + p1.send(handshake(token1, side=side1)) + self.flush() self.assertEqual(self.count(), 1) - p2.dataReceived(handshake(token1, side=side1)) + p2.send(handshake(token1, side=side1)) + self.flush() self.assertEqual(self.count(), 2) # same-side connections don't match # when the second side arrives, the spare first connection should be # closed side2 = b"\x02"*8 - p3.dataReceived(handshake(token1, side=side2)) + p3.send(handshake(token1, side=side2)) + self.flush() self.assertEqual(self.count(), 0) self.assertEqual(len(self._transit_server._pending_requests), 0) self.assertEqual(len(self._transit_server._active_connections), 2) # That will trigger a disconnect on exactly one of (p1 or p2). # The other connection should still be connected - self.assertEqual(sum([int(t.transport.connected) for t in [p1, p2]]), 1) + self.assertEqual(sum([int(t.connected) for t in [p1, p2]]), 1) - p1.transport.loseConnection() - p2.transport.loseConnection() - p3.transport.loseConnection() + p1.disconnect() + p2.disconnect() + p3.disconnect() def test_bad_handshake_old(self): p1 = self.new_protocol() token1 = b"\x00"*32 - p1.dataReceived(b"please DELAY " + hexlify(token1) + b"\n") + p1.send(b"please DELAY " + hexlify(token1) + b"\n") + self.flush() exp = b"bad handshake\n" - self.assertEqual(p1.transport.value(), exp) - p1.transport.loseConnection() + self.assertEqual(p1.get_received_data(), exp) + p1.disconnect() def test_bad_handshake_old_slow(self): p1 = self.new_protocol() - p1.dataReceived(b"please DELAY ") + p1.send(b"please DELAY ") + self.flush() # As in test_impatience_new_slow, the current state machine has code # that can only be reached if we insert a stall here, so dataReceived # gets called twice. Hopefully we can delete this test once @@ -200,12 +220,13 @@ class _Transit: token1 = b"\x00"*32 # the server waits for the exact number of bytes in the expected # handshake message. to trigger "bad handshake", we must match. - p1.dataReceived(hexlify(token1) + b"\n") + p1.send(hexlify(token1) + b"\n") + self.flush() exp = b"bad handshake\n" - self.assertEqual(p1.transport.value(), exp) + self.assertEqual(p1.get_received_data(), exp) - p1.transport.loseConnection() + p1.disconnect() def test_bad_handshake_new(self): p1 = self.new_protocol() @@ -214,13 +235,14 @@ class _Transit: side1 = b"\x01"*8 # the server waits for the exact number of bytes in the expected # handshake message. to trigger "bad handshake", we must match. - p1.dataReceived(b"please DELAY " + hexlify(token1) + - b" for side " + hexlify(side1) + b"\n") + p1.send(b"please DELAY " + hexlify(token1) + + b" for side " + hexlify(side1) + b"\n") + self.flush() exp = b"bad handshake\n" - self.assertEqual(p1.transport.value(), exp) + self.assertEqual(p1.get_received_data(), exp) - p1.transport.loseConnection() + p1.disconnect() def test_binary_handshake(self): p1 = self.new_protocol() @@ -232,24 +254,26 @@ class _Transit: # UnicodeDecodeError when it tried to coerce the incoming handshake # to unicode, due to the ("\n" in buf) check. This was fixed to use # (b"\n" in buf). This exercises the old failure. - p1.dataReceived(binary_bad_handshake) + p1.send(binary_bad_handshake) + self.flush() exp = b"bad handshake\n" - self.assertEqual(p1.transport.value(), exp) + self.assertEqual(p1.get_received_data(), exp) - p1.transport.loseConnection() + p1.disconnect() def test_impatience_old(self): p1 = self.new_protocol() token1 = b"\x00"*32 # sending too many bytes is impatience. - p1.dataReceived(b"please relay " + hexlify(token1) + b"\nNOWNOWNOW") + p1.send(b"please relay " + hexlify(token1) + b"\nNOWNOWNOW") + self.flush() exp = b"impatient\n" - self.assertEqual(p1.transport.value(), exp) + self.assertEqual(p1.get_received_data(), exp) - p1.transport.loseConnection() + p1.disconnect() def test_impatience_new(self): p1 = self.new_protocol() @@ -257,13 +281,14 @@ class _Transit: token1 = b"\x00"*32 side1 = b"\x01"*8 # sending too many bytes is impatience. - p1.dataReceived(b"please relay " + hexlify(token1) + - b" for side " + hexlify(side1) + b"\nNOWNOWNOW") + p1.send(b"please relay " + hexlify(token1) + + b" for side " + hexlify(side1) + b"\nNOWNOWNOW") + self.flush() exp = b"impatient\n" - self.assertEqual(p1.transport.value(), exp) + self.assertEqual(p1.get_received_data(), exp) - p1.transport.loseConnection() + p1.disconnect() def test_impatience_new_slow(self): p1 = self.new_protocol() @@ -279,27 +304,29 @@ class _Transit: token1 = b"\x00"*32 side1 = b"\x01"*8 # sending too many bytes is impatience. - p1.dataReceived(b"please relay " + hexlify(token1) + - b" for side " + hexlify(side1) + b"\n") + p1.send(b"please relay " + hexlify(token1) + + b" for side " + hexlify(side1) + b"\n") + self.flush() - - p1.dataReceived(b"NOWNOWNOW") + p1.send(b"NOWNOWNOW") + self.flush() exp = b"impatient\n" - self.assertEqual(p1.transport.value(), exp) + self.assertEqual(p1.get_received_data(), exp) - p1.transport.loseConnection() + p1.disconnect() def test_short_handshake(self): p1 = self.new_protocol() # hang up before sending a complete handshake - p1.dataReceived(b"short") - p1.transport.loseConnection() + p1.send(b"short") + self.flush() + p1.disconnect() def test_empty_handshake(self): p1 = self.new_protocol() # hang up before sending anything - p1.transport.loseConnection() + p1.disconnect() class TransitWithLogs(_Transit, ServerBase, unittest.TestCase): log_requests = True @@ -308,6 +335,8 @@ class TransitWithoutLogs(_Transit, ServerBase, unittest.TestCase): log_requests = False class Usage(ServerBase, unittest.TestCase): + log_requests = True + def setUp(self): super(Usage, self).setUp() self._usage = [] @@ -319,7 +348,8 @@ class Usage(ServerBase, unittest.TestCase): def test_empty(self): p1 = self.new_protocol() # hang up before sending anything - p1.transport.loseConnection() + p1.disconnect() + self.flush() # that will log the "empty" usage event self.assertEqual(len(self._usage), 1, self._usage) @@ -329,8 +359,9 @@ class Usage(ServerBase, unittest.TestCase): def test_short(self): p1 = self.new_protocol() # hang up before sending a complete handshake - p1.transport.write(b"short") - p1.transport.loseConnection() + p1.send(b"short") + p1.disconnect() + self.flush() # that will log the "empty" usage event self.assertEqual(len(self._usage), 1, self._usage) @@ -340,9 +371,10 @@ class Usage(ServerBase, unittest.TestCase): def test_errory(self): p1 = self.new_protocol() - p1.dataReceived(b"this is a very bad handshake\n") + p1.send(b"this is a very bad handshake\n") + self.flush() # that will log the "errory" usage event, then drop the connection - p1.transport.loseConnection() + p1.disconnect() self.assertEqual(len(self._usage), 1, self._usage) (started, result, total_bytes, total_time, waiting_time) = self._usage[0] self.assertEqual(result, "errory", self._usage) @@ -352,9 +384,11 @@ class Usage(ServerBase, unittest.TestCase): token1 = b"\x00"*32 side1 = b"\x01"*8 - p1.dataReceived(handshake(token1, side=side1)) + p1.send(handshake(token1, side=side1)) + self.flush() # now we disconnect before the peer connects - p1.transport.loseConnection() + p1.disconnect() + self.flush() self.assertEqual(len(self._usage), 1, self._usage) (started, result, total_bytes, total_time, waiting_time) = self._usage[0] @@ -368,15 +402,20 @@ class Usage(ServerBase, unittest.TestCase): token1 = b"\x00"*32 side1 = b"\x01"*8 side2 = b"\x02"*8 - p1.dataReceived(handshake(token1, side=side1)) - p2.dataReceived(handshake(token1, side=side2)) + p1.send(handshake(token1, side=side1)) + self.flush() + p2.send(handshake(token1, side=side2)) + self.flush() self.assertEqual(self._usage, []) # no events yet - p1.dataReceived(b"\x00" * 13) - p2.dataReceived(b"\xff" * 7) + p1.send(b"\x00" * 13) + self.flush() + p2.send(b"\xff" * 7) + self.flush() - p1.transport.loseConnection() + p1.disconnect() + self.flush() self.assertEqual(len(self._usage), 1, self._usage) (started, result, total_bytes, total_time, waiting_time) = self._usage[0] @@ -393,28 +432,33 @@ class Usage(ServerBase, unittest.TestCase): token1 = b"\x00"*32 side1 = b"\x01"*8 side2 = b"\x02"*8 - p1a.dataReceived(handshake(token1, side=side1)) - p1b.dataReceived(handshake(token1, side=side1)) + p1a.send(handshake(token1, side=side1)) + self.flush() + p1b.send(handshake(token1, side=side1)) + self.flush() # connect and disconnect a third client (for side1) to exercise the # code that removes a pending connection without removing the entire # token - p1c.dataReceived(handshake(token1, side=side1)) - p1c.transport.loseConnection() + p1c.send(handshake(token1, side=side1)) + p1c.disconnect() + self.flush() self.assertEqual(len(self._usage), 1, self._usage) (started, result, total_bytes, total_time, waiting_time) = self._usage[0] self.assertEqual(result, "lonely", self._usage) - p2.dataReceived(handshake(token1, side=side2)) + p2.send(handshake(token1, side=side2)) + self.flush() self.assertEqual(len(self._transit_server._pending_requests), 0) self.assertEqual(len(self._usage), 2, self._usage) (started, result, total_bytes, total_time, waiting_time) = self._usage[1] self.assertEqual(result, "redundant", self._usage) # one of the these is unecessary, but probably harmless - p1a.transport.loseConnection() - p1b.transport.loseConnection() + p1a.disconnect() + p1b.disconnect() + self.flush() self.assertEqual(len(self._usage), 3, self._usage) (started, result, total_bytes, total_time, waiting_time) = self._usage[2] self.assertEqual(result, "happy", self._usage) diff --git a/src/wormhole_transit_relay/transit_server.py b/src/wormhole_transit_relay/transit_server.py index 91d84e0..426f50f 100644 --- a/src/wormhole_transit_relay/transit_server.py +++ b/src/wormhole_transit_relay/transit_server.py @@ -84,16 +84,16 @@ class TransitConnection(LineReceiver): # point the sender will only transmit data as fast as the # receiver can handle it. if self._sent_ok: - if not self._buddy: - # Our buddy disconnected (we're "jilted"), so we hung up too, - # but our incoming data hasn't stopped yet (it will in a - # moment, after our disconnect makes a roundtrip through the - # kernel). This probably means the file receiver hung up, and - # this connection is the file sender. In may-2020 this - # happened 11 times in 40 days. - return - self._total_sent += len(data) - self._buddy.transport.write(data) + # if self._buddy is None then our buddy disconnected + # (we're "jilted"), so we hung up too, but our incoming + # data hasn't stopped yet (it will in a moment, after our + # disconnect makes a roundtrip through the kernel). This + # probably means the file receiver hung up, and this + # connection is the file sender. In may-2020 this happened + # 11 times in 40 days. + if self._buddy: + self._total_sent += len(data) + self._buddy.transport.write(data) return # handshake is complete but not yet sent_ok