diff --git a/src/wormhole_transit_relay/test/common.py b/src/wormhole_transit_relay/test/common.py index 70cfed2..961e680 100644 --- a/src/wormhole_transit_relay/test/common.py +++ b/src/wormhole_transit_relay/test/common.py @@ -3,11 +3,45 @@ from twisted.internet.protocol import ( Protocol, ) from twisted.test import iosim +from zope.interface import ( + Interface, + Attribute, + implementer, +) from ..transit_server import ( Transit, ) +class ITransitClient(Interface): + """ + The client interface used by tests. + """ + + connected = Attribute("True if we are currently connected else False") + + def send(data): + """ + Send some bytes. + :param bytes data: the data to send + """ + + def disconnect(): + """ + Terminate the connection. + """ + + def get_received_data(): + """ + :returns: all the bytes received from the server on this + connection. + """ + + def reset_data(): + """ + Erase any received data to this point. + """ + class ServerBase: log_requests = False @@ -33,14 +67,18 @@ class ServerBase: self._transit_server._debug_log = self.log_requests def new_protocol(self): + """ + Create a new client protocol connected to the server. + :returns: a ITransitClient implementation + """ server_protocol = self._transit_server.buildProtocol(('127.0.0.1', 0)) - # XXX interface? + @implementer(ITransitClient) class TransitClientProtocolTcp(Protocol): """ Speak the transit client protocol used by the tests over TCP """ - received = b"" + _received = b"" connected = False def connectionMade(self): @@ -56,8 +94,13 @@ class ServerBase: self.transport.loseConnection() def dataReceived(self, data): - self.received = self.received + data + self._received = self._received + data + def reset_received_data(self): + self._received = b"" + + def get_received_data(self): + return self._received client_factory = ClientFactory() client_factory.protocol = TransitClientProtocolTcp diff --git a/src/wormhole_transit_relay/test/test_transit_server.py b/src/wormhole_transit_relay/test/test_transit_server.py index dfdf8de..9a0f349 100644 --- a/src/wormhole_transit_relay/test/test_transit_server.py +++ b/src/wormhole_transit_relay/test/test_transit_server.py @@ -65,16 +65,16 @@ class _Transit: # a correct handshake yields an ack, after which we can send exp = b"ok\n" - self.assertEqual(p1.received, exp) - self.assertEqual(p2.received, exp) + self.assertEqual(p1.get_received_data(), exp) + self.assertEqual(p2.get_received_data(), exp) - p1.received = b"" - p2.received = b"" + p1.reset_received_data() + p2.reset_received_data() s1 = b"data1" p1.send(s1) self.flush() - self.assertEqual(p2.received, s1) + self.assertEqual(p2.get_received_data(), s1) p1.disconnect() p2.disconnect() @@ -93,17 +93,17 @@ class _Transit: # a correct handshake yields an ack, after which we can send exp = b"ok\n" - self.assertEqual(p1.received, exp) - self.assertEqual(p2.received, exp) + self.assertEqual(p1.get_received_data(), exp) + self.assertEqual(p2.get_received_data(), exp) - p1.received = b"" - p2.received = b"" + p1.reset_received_data() + p2.reset_received_data() # all data they sent after the handshake should be given to us s1 = b"data1" p1.send(s1) self.flush() - self.assertEqual(p2.received, s1) + self.assertEqual(p2.get_received_data(), s1) p1.disconnect() p2.disconnect() @@ -121,17 +121,17 @@ class _Transit: # a correct handshake yields an ack, after which we can send exp = b"ok\n" - self.assertEqual(p1.received, exp) - self.assertEqual(p2.received, exp) + self.assertEqual(p1.get_received_data(), exp) + self.assertEqual(p2.get_received_data(), exp) - p1.received = b"" - p2.received = b"" + p1.reset_received_data() + p2.reset_received_data() # all data they sent after the handshake should be given to us s1 = b"data1" p1.send(s1) self.flush() - self.assertEqual(p2.received, s1) + self.assertEqual(p2.get_received_data(), s1) p1.disconnect() p2.disconnect() @@ -151,17 +151,17 @@ class _Transit: # a correct handshake yields an ack, after which we can send exp = b"ok\n" - self.assertEqual(p1.received, exp) - self.assertEqual(p2.received, exp) + self.assertEqual(p1.get_received_data(), exp) + self.assertEqual(p2.get_received_data(), exp) - p1.received = b"" - p2.received = b"" + p1.reset_received_data() + p2.reset_received_data() # all data they sent after the handshake should be given to us s1 = b"data1" p1.send(s1) self.flush() - self.assertEqual(p2.received, s1) + self.assertEqual(p2.get_received_data(), s1) p1.disconnect() p2.disconnect() @@ -207,7 +207,7 @@ class _Transit: self.flush() exp = b"bad handshake\n" - self.assertEqual(p1.received, exp) + self.assertEqual(p1.get_received_data(), exp) p1.disconnect() def test_bad_handshake_old_slow(self): @@ -227,7 +227,7 @@ class _Transit: self.flush() exp = b"bad handshake\n" - self.assertEqual(p1.received, exp) + self.assertEqual(p1.get_received_data(), exp) p1.disconnect() @@ -243,7 +243,7 @@ class _Transit: self.flush() exp = b"bad handshake\n" - self.assertEqual(p1.received, exp) + self.assertEqual(p1.get_received_data(), exp) p1.disconnect() @@ -261,7 +261,7 @@ class _Transit: self.flush() exp = b"bad handshake\n" - self.assertEqual(p1.received, exp) + self.assertEqual(p1.get_received_data(), exp) p1.disconnect() @@ -274,7 +274,7 @@ class _Transit: self.flush() exp = b"impatient\n" - self.assertEqual(p1.received, exp) + self.assertEqual(p1.get_received_data(), exp) p1.disconnect() @@ -289,7 +289,7 @@ class _Transit: self.flush() exp = b"impatient\n" - self.assertEqual(p1.received, exp) + self.assertEqual(p1.get_received_data(), exp) p1.disconnect() @@ -315,7 +315,7 @@ class _Transit: self.flush() exp = b"impatient\n" - self.assertEqual(p1.received, exp) + self.assertEqual(p1.get_received_data(), exp) p1.disconnect()