From b9c2bbc524627c84b87585ed6002622f9443e720 Mon Sep 17 00:00:00 2001 From: meejah Date: Fri, 2 Apr 2021 15:28:08 -0600 Subject: [PATCH 01/13] refactor to use IOPump: one test passes --- src/wormhole_transit_relay/test/common.py | 57 ++++++++++++++++--- .../test/test_transit_server.py | 6 +- 2 files changed, 52 insertions(+), 11 deletions(-) diff --git a/src/wormhole_transit_relay/test/common.py b/src/wormhole_transit_relay/test/common.py index 53958fb..b1ce269 100644 --- a/src/wormhole_transit_relay/test/common.py +++ b/src/wormhole_transit_relay/test/common.py @@ -1,28 +1,67 @@ from twisted.test import proto_helpers -from ..transit_server import Transit +from twisted.internet.protocol import ( + ServerFactory, + ClientFactory, + Protocol, +) +from twisted.test import iosim +from ..transit_server import ( + Transit, + TransitConnection, +) + class ServerBase: log_requests = False def setUp(self): + self._pumps = [] self._lp = None if self.log_requests: blur_usage = None else: blur_usage = 60.0 self._setup_relay(blur_usage=blur_usage) - self._transit_server._debug_log = self.log_requests + + def flush(self): + for pump in self._pumps: + pump.flush() def _setup_relay(self, blur_usage=None, log_file=None, usage_db=None): - self._transit_server = Transit(blur_usage=blur_usage, - log_file=log_file, usage_db=usage_db) + self._transit_server = Transit( + blur_usage=blur_usage, + log_file=log_file, + usage_db=usage_db, + ) + self._transit_server._debug_log = self.log_requests def new_protocol(self): - protocol = self._transit_server.buildProtocol(('127.0.0.1', 0)) - transport = proto_helpers.StringTransportWithDisconnection() - protocol.makeConnection(transport) - transport.protocol = protocol - return protocol + server_protocol = self._transit_server.buildProtocol(('127.0.0.1', 0)) + + # XXX interface? + class TransitClientProtocolTcp(Protocol): + """ + Speak the transit client protocol used by the tests over TCP + """ + def send(self, data): + self.transport.write(data) + + def disconnect(self): + self.transport.loseConnection() + + client_factory = ClientFactory() + client_factory.protocol = TransitClientProtocolTcp + client_protocol = client_factory.buildProtocol(('127.0.0.1', 31337)) + + pump = iosim.connect( + server_protocol, + iosim.makeFakeServer(server_protocol), + client_protocol, + iosim.makeFakeClient(client_protocol), + ) + pump.flush() + self._pumps.append(pump) + return client_protocol def tearDown(self): if self._lp: diff --git a/src/wormhole_transit_relay/test/test_transit_server.py b/src/wormhole_transit_relay/test/test_transit_server.py index a4763d9..320ecdf 100644 --- a/src/wormhole_transit_relay/test/test_transit_server.py +++ b/src/wormhole_transit_relay/test/test_transit_server.py @@ -41,10 +41,12 @@ class _Transit: token1 = b"\x00"*32 side1 = b"\x01"*8 - p1.dataReceived(handshake(token1, side1)) + p1.send(handshake(token1, side1)) + self.flush() self.assertEqual(self.count(), 1) - p1.transport.loseConnection() + p1.disconnect() + self.flush() self.assertEqual(self.count(), 0) # the token should be removed too From 5e21a3c35a7b5beabe363ef9799c3ac1920156e2 Mon Sep 17 00:00:00 2001 From: meejah Date: Fri, 2 Apr 2021 15:50:37 -0600 Subject: [PATCH 02/13] all tests pass --- src/wormhole_transit_relay/test/common.py | 13 + .../test/test_transit_server.py | 236 +++++++++++------- 2 files changed, 153 insertions(+), 96 deletions(-) diff --git a/src/wormhole_transit_relay/test/common.py b/src/wormhole_transit_relay/test/common.py index b1ce269..86029b3 100644 --- a/src/wormhole_transit_relay/test/common.py +++ b/src/wormhole_transit_relay/test/common.py @@ -43,12 +43,25 @@ class ServerBase: """ Speak the transit client protocol used by the tests over TCP """ + received = b"" + connected = False + + def connectionMade(self): + self.connected = True + + def connectionLost(self, reason): + self.connected = False + def send(self, data): self.transport.write(data) def disconnect(self): self.transport.loseConnection() + def dataReceived(self, data): + self.received = self.received + data + + client_factory = ClientFactory() client_factory.protocol = TransitClientProtocolTcp client_protocol = client_factory.buildProtocol(('127.0.0.1', 31337)) diff --git a/src/wormhole_transit_relay/test/test_transit_server.py b/src/wormhole_transit_relay/test/test_transit_server.py index 320ecdf..dfdf8de 100644 --- a/src/wormhole_transit_relay/test/test_transit_server.py +++ b/src/wormhole_transit_relay/test/test_transit_server.py @@ -57,23 +57,27 @@ class _Transit: p2 = self.new_protocol() token1 = b"\x00"*32 - p1.dataReceived(handshake(token1, side=None)) - p2.dataReceived(handshake(token1, side=None)) + p1.send(handshake(token1, side=None)) + self.flush() + p2.send(handshake(token1, side=None)) + self.flush() + self.flush() # a correct handshake yields an ack, after which we can send exp = b"ok\n" - self.assertEqual(p1.transport.value(), exp) - self.assertEqual(p2.transport.value(), exp) + self.assertEqual(p1.received, exp) + self.assertEqual(p2.received, exp) - p1.transport.clear() - p2.transport.clear() + p1.received = b"" + p2.received = b"" s1 = b"data1" - p1.dataReceived(s1) - self.assertEqual(p2.transport.value(), s1) + p1.send(s1) + self.flush() + self.assertEqual(p2.received, s1) - p1.transport.loseConnection() - p2.transport.loseConnection() + p1.disconnect() + p2.disconnect() def test_sided_unsided(self): p1 = self.new_protocol() @@ -81,24 +85,28 @@ class _Transit: token1 = b"\x00"*32 side1 = b"\x01"*8 - p1.dataReceived(handshake(token1, side=side1)) - p2.dataReceived(handshake(token1, side=None)) + p1.send(handshake(token1, side=side1)) + self.flush() + p2.send(handshake(token1, side=None)) + self.flush() + self.flush() # a correct handshake yields an ack, after which we can send exp = b"ok\n" - self.assertEqual(p1.transport.value(), exp) - self.assertEqual(p2.transport.value(), exp) + self.assertEqual(p1.received, exp) + self.assertEqual(p2.received, exp) - p1.transport.clear() - p2.transport.clear() + p1.received = b"" + p2.received = b"" # all data they sent after the handshake should be given to us s1 = b"data1" - p1.dataReceived(s1) - self.assertEqual(p2.transport.value(), s1) + p1.send(s1) + self.flush() + self.assertEqual(p2.received, s1) - p1.transport.loseConnection() - p2.transport.loseConnection() + p1.disconnect() + p2.disconnect() def test_unsided_sided(self): p1 = self.new_protocol() @@ -106,24 +114,27 @@ class _Transit: token1 = b"\x00"*32 side1 = b"\x01"*8 - p1.dataReceived(handshake(token1, side=None)) - p2.dataReceived(handshake(token1, side=side1)) + p1.send(handshake(token1, side=None)) + p2.send(handshake(token1, side=side1)) + self.flush() + self.flush() # a correct handshake yields an ack, after which we can send exp = b"ok\n" - self.assertEqual(p1.transport.value(), exp) - self.assertEqual(p2.transport.value(), exp) + self.assertEqual(p1.received, exp) + self.assertEqual(p2.received, exp) - p1.transport.clear() - p2.transport.clear() + p1.received = b"" + p2.received = b"" # all data they sent after the handshake should be given to us s1 = b"data1" - p1.dataReceived(s1) - self.assertEqual(p2.transport.value(), s1) + p1.send(s1) + self.flush() + self.assertEqual(p2.received, s1) - p1.transport.loseConnection() - p2.transport.loseConnection() + p1.disconnect() + p2.disconnect() def test_both_sided(self): p1 = self.new_protocol() @@ -132,24 +143,28 @@ class _Transit: token1 = b"\x00"*32 side1 = b"\x01"*8 side2 = b"\x02"*8 - p1.dataReceived(handshake(token1, side=side1)) - p2.dataReceived(handshake(token1, side=side2)) + p1.send(handshake(token1, side=side1)) + self.flush() + p2.send(handshake(token1, side=side2)) + self.flush() + self.flush() # a correct handshake yields an ack, after which we can send exp = b"ok\n" - self.assertEqual(p1.transport.value(), exp) - self.assertEqual(p2.transport.value(), exp) + self.assertEqual(p1.received, exp) + self.assertEqual(p2.received, exp) - p1.transport.clear() - p2.transport.clear() + p1.received = b"" + p2.received = b"" # all data they sent after the handshake should be given to us s1 = b"data1" - p1.dataReceived(s1) - self.assertEqual(p2.transport.value(), s1) + p1.send(s1) + self.flush() + self.assertEqual(p2.received, s1) - p1.transport.loseConnection() - p2.transport.loseConnection() + p1.disconnect() + p2.disconnect() def test_ignore_same_side(self): p1 = self.new_protocol() @@ -159,41 +174,47 @@ class _Transit: token1 = b"\x00"*32 side1 = b"\x01"*8 - p1.dataReceived(handshake(token1, side=side1)) + p1.send(handshake(token1, side=side1)) + self.flush() self.assertEqual(self.count(), 1) - p2.dataReceived(handshake(token1, side=side1)) + p2.send(handshake(token1, side=side1)) + self.flush() self.assertEqual(self.count(), 2) # same-side connections don't match # when the second side arrives, the spare first connection should be # closed side2 = b"\x02"*8 - p3.dataReceived(handshake(token1, side=side2)) + p3.send(handshake(token1, side=side2)) + self.flush() + self.flush() self.assertEqual(self.count(), 0) self.assertEqual(len(self._transit_server._pending_requests), 0) self.assertEqual(len(self._transit_server._active_connections), 2) # That will trigger a disconnect on exactly one of (p1 or p2). # The other connection should still be connected - self.assertEqual(sum([int(t.transport.connected) for t in [p1, p2]]), 1) + self.assertEqual(sum([int(t.connected) for t in [p1, p2]]), 1) - p1.transport.loseConnection() - p2.transport.loseConnection() - p3.transport.loseConnection() + p1.disconnect() + p2.disconnect() + p3.disconnect() def test_bad_handshake_old(self): p1 = self.new_protocol() token1 = b"\x00"*32 - p1.dataReceived(b"please DELAY " + hexlify(token1) + b"\n") + p1.send(b"please DELAY " + hexlify(token1) + b"\n") + self.flush() exp = b"bad handshake\n" - self.assertEqual(p1.transport.value(), exp) - p1.transport.loseConnection() + self.assertEqual(p1.received, exp) + p1.disconnect() def test_bad_handshake_old_slow(self): p1 = self.new_protocol() - p1.dataReceived(b"please DELAY ") + p1.send(b"please DELAY ") + self.flush() # As in test_impatience_new_slow, the current state machine has code # that can only be reached if we insert a stall here, so dataReceived # gets called twice. Hopefully we can delete this test once @@ -202,12 +223,13 @@ class _Transit: token1 = b"\x00"*32 # the server waits for the exact number of bytes in the expected # handshake message. to trigger "bad handshake", we must match. - p1.dataReceived(hexlify(token1) + b"\n") + p1.send(hexlify(token1) + b"\n") + self.flush() exp = b"bad handshake\n" - self.assertEqual(p1.transport.value(), exp) + self.assertEqual(p1.received, exp) - p1.transport.loseConnection() + p1.disconnect() def test_bad_handshake_new(self): p1 = self.new_protocol() @@ -216,13 +238,14 @@ class _Transit: side1 = b"\x01"*8 # the server waits for the exact number of bytes in the expected # handshake message. to trigger "bad handshake", we must match. - p1.dataReceived(b"please DELAY " + hexlify(token1) + - b" for side " + hexlify(side1) + b"\n") + p1.send(b"please DELAY " + hexlify(token1) + + b" for side " + hexlify(side1) + b"\n") + self.flush() exp = b"bad handshake\n" - self.assertEqual(p1.transport.value(), exp) + self.assertEqual(p1.received, exp) - p1.transport.loseConnection() + p1.disconnect() def test_binary_handshake(self): p1 = self.new_protocol() @@ -234,24 +257,26 @@ class _Transit: # UnicodeDecodeError when it tried to coerce the incoming handshake # to unicode, due to the ("\n" in buf) check. This was fixed to use # (b"\n" in buf). This exercises the old failure. - p1.dataReceived(binary_bad_handshake) + p1.send(binary_bad_handshake) + self.flush() exp = b"bad handshake\n" - self.assertEqual(p1.transport.value(), exp) + self.assertEqual(p1.received, exp) - p1.transport.loseConnection() + p1.disconnect() def test_impatience_old(self): p1 = self.new_protocol() token1 = b"\x00"*32 # sending too many bytes is impatience. - p1.dataReceived(b"please relay " + hexlify(token1) + b"\nNOWNOWNOW") + p1.send(b"please relay " + hexlify(token1) + b"\nNOWNOWNOW") + self.flush() exp = b"impatient\n" - self.assertEqual(p1.transport.value(), exp) + self.assertEqual(p1.received, exp) - p1.transport.loseConnection() + p1.disconnect() def test_impatience_new(self): p1 = self.new_protocol() @@ -259,13 +284,14 @@ class _Transit: token1 = b"\x00"*32 side1 = b"\x01"*8 # sending too many bytes is impatience. - p1.dataReceived(b"please relay " + hexlify(token1) + - b" for side " + hexlify(side1) + b"\nNOWNOWNOW") + p1.send(b"please relay " + hexlify(token1) + + b" for side " + hexlify(side1) + b"\nNOWNOWNOW") + self.flush() exp = b"impatient\n" - self.assertEqual(p1.transport.value(), exp) + self.assertEqual(p1.received, exp) - p1.transport.loseConnection() + p1.disconnect() def test_impatience_new_slow(self): p1 = self.new_protocol() @@ -281,27 +307,29 @@ class _Transit: token1 = b"\x00"*32 side1 = b"\x01"*8 # sending too many bytes is impatience. - p1.dataReceived(b"please relay " + hexlify(token1) + - b" for side " + hexlify(side1) + b"\n") + p1.send(b"please relay " + hexlify(token1) + + b" for side " + hexlify(side1) + b"\n") + self.flush() - - p1.dataReceived(b"NOWNOWNOW") + p1.send(b"NOWNOWNOW") + self.flush() exp = b"impatient\n" - self.assertEqual(p1.transport.value(), exp) + self.assertEqual(p1.received, exp) - p1.transport.loseConnection() + p1.disconnect() def test_short_handshake(self): p1 = self.new_protocol() # hang up before sending a complete handshake - p1.dataReceived(b"short") - p1.transport.loseConnection() + p1.send(b"short") + self.flush() + p1.disconnect() def test_empty_handshake(self): p1 = self.new_protocol() # hang up before sending anything - p1.transport.loseConnection() + p1.disconnect() class TransitWithLogs(_Transit, ServerBase, unittest.TestCase): log_requests = True @@ -321,7 +349,8 @@ class Usage(ServerBase, unittest.TestCase): def test_empty(self): p1 = self.new_protocol() # hang up before sending anything - p1.transport.loseConnection() + p1.disconnect() + self.flush() # that will log the "empty" usage event self.assertEqual(len(self._usage), 1, self._usage) @@ -331,8 +360,9 @@ class Usage(ServerBase, unittest.TestCase): def test_short(self): p1 = self.new_protocol() # hang up before sending a complete handshake - p1.transport.write(b"short") - p1.transport.loseConnection() + p1.send(b"short") + p1.disconnect() + self.flush() # that will log the "empty" usage event self.assertEqual(len(self._usage), 1, self._usage) @@ -342,9 +372,10 @@ class Usage(ServerBase, unittest.TestCase): def test_errory(self): p1 = self.new_protocol() - p1.dataReceived(b"this is a very bad handshake\n") + p1.send(b"this is a very bad handshake\n") + self.flush() # that will log the "errory" usage event, then drop the connection - p1.transport.loseConnection() + p1.disconnect() self.assertEqual(len(self._usage), 1, self._usage) (started, result, total_bytes, total_time, waiting_time) = self._usage[0] self.assertEqual(result, "errory", self._usage) @@ -354,9 +385,11 @@ class Usage(ServerBase, unittest.TestCase): token1 = b"\x00"*32 side1 = b"\x01"*8 - p1.dataReceived(handshake(token1, side=side1)) + p1.send(handshake(token1, side=side1)) + self.flush() # now we disconnect before the peer connects - p1.transport.loseConnection() + p1.disconnect() + self.flush() self.assertEqual(len(self._usage), 1, self._usage) (started, result, total_bytes, total_time, waiting_time) = self._usage[0] @@ -370,15 +403,20 @@ class Usage(ServerBase, unittest.TestCase): token1 = b"\x00"*32 side1 = b"\x01"*8 side2 = b"\x02"*8 - p1.dataReceived(handshake(token1, side=side1)) - p2.dataReceived(handshake(token1, side=side2)) + p1.send(handshake(token1, side=side1)) + self.flush() + p2.send(handshake(token1, side=side2)) + self.flush() self.assertEqual(self._usage, []) # no events yet - p1.dataReceived(b"\x00" * 13) - p2.dataReceived(b"\xff" * 7) + p1.send(b"\x00" * 13) + self.flush() + p2.send(b"\xff" * 7) + self.flush() - p1.transport.loseConnection() + p1.disconnect() + self.flush() self.assertEqual(len(self._usage), 1, self._usage) (started, result, total_bytes, total_time, waiting_time) = self._usage[0] @@ -395,28 +433,34 @@ class Usage(ServerBase, unittest.TestCase): token1 = b"\x00"*32 side1 = b"\x01"*8 side2 = b"\x02"*8 - p1a.dataReceived(handshake(token1, side=side1)) - p1b.dataReceived(handshake(token1, side=side1)) + p1a.send(handshake(token1, side=side1)) + self.flush() + p1b.send(handshake(token1, side=side1)) + self.flush() # connect and disconnect a third client (for side1) to exercise the # code that removes a pending connection without removing the entire # token - p1c.dataReceived(handshake(token1, side=side1)) - p1c.transport.loseConnection() + p1c.send(handshake(token1, side=side1)) + p1c.disconnect() + self.flush() self.assertEqual(len(self._usage), 1, self._usage) (started, result, total_bytes, total_time, waiting_time) = self._usage[0] self.assertEqual(result, "lonely", self._usage) - p2.dataReceived(handshake(token1, side=side2)) + p2.send(handshake(token1, side=side2)) + self.flush() + self.flush() self.assertEqual(len(self._transit_server._pending_requests), 0) self.assertEqual(len(self._usage), 2, self._usage) (started, result, total_bytes, total_time, waiting_time) = self._usage[1] self.assertEqual(result, "redundant", self._usage) # one of the these is unecessary, but probably harmless - p1a.transport.loseConnection() - p1b.transport.loseConnection() + p1a.disconnect() + p1b.disconnect() + self.flush() self.assertEqual(len(self._usage), 3, self._usage) (started, result, total_bytes, total_time, waiting_time) = self._usage[2] self.assertEqual(result, "happy", self._usage) From 85f3f5b63cd281f539cffb44a07d281681b9f87e Mon Sep 17 00:00:00 2001 From: meejah Date: Fri, 2 Apr 2021 15:52:28 -0600 Subject: [PATCH 03/13] 'mock' location --- src/wormhole_transit_relay/test/test_rlimits.py | 2 +- src/wormhole_transit_relay/test/test_service.py | 2 +- src/wormhole_transit_relay/test/test_stats.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/wormhole_transit_relay/test/test_rlimits.py b/src/wormhole_transit_relay/test/test_rlimits.py index 10497e4..3ee23a9 100644 --- a/src/wormhole_transit_relay/test/test_rlimits.py +++ b/src/wormhole_transit_relay/test/test_rlimits.py @@ -1,5 +1,5 @@ from __future__ import print_function, unicode_literals -import mock +from unittest import mock from twisted.trial import unittest from ..increase_rlimits import increase_rlimits diff --git a/src/wormhole_transit_relay/test/test_service.py b/src/wormhole_transit_relay/test/test_service.py index dac642c..f72765c 100644 --- a/src/wormhole_transit_relay/test/test_service.py +++ b/src/wormhole_transit_relay/test/test_service.py @@ -1,6 +1,6 @@ from __future__ import unicode_literals, print_function from twisted.trial import unittest -import mock +from unittest import mock from twisted.application.service import MultiService from .. import server_tap diff --git a/src/wormhole_transit_relay/test/test_stats.py b/src/wormhole_transit_relay/test/test_stats.py index f9433ef..43b912f 100644 --- a/src/wormhole_transit_relay/test/test_stats.py +++ b/src/wormhole_transit_relay/test/test_stats.py @@ -1,6 +1,6 @@ from __future__ import print_function, unicode_literals import os, io, json, sqlite3 -import mock +from unittest import mock from twisted.trial import unittest from ..transit_server import Transit from .. import database From 8447f88159d196aa7d4928edb03e62a6c93eddba Mon Sep 17 00:00:00 2001 From: meejah Date: Fri, 2 Apr 2021 16:11:18 -0600 Subject: [PATCH 04/13] pyflakes --- src/wormhole_transit_relay/test/common.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/wormhole_transit_relay/test/common.py b/src/wormhole_transit_relay/test/common.py index 86029b3..70cfed2 100644 --- a/src/wormhole_transit_relay/test/common.py +++ b/src/wormhole_transit_relay/test/common.py @@ -1,13 +1,10 @@ -from twisted.test import proto_helpers from twisted.internet.protocol import ( - ServerFactory, ClientFactory, Protocol, ) from twisted.test import iosim from ..transit_server import ( Transit, - TransitConnection, ) From 04342964157857532c06c3f4233398e285867bd5 Mon Sep 17 00:00:00 2001 From: meejah Date: Fri, 2 Apr 2021 19:57:22 -0600 Subject: [PATCH 05/13] still support py27 --- src/wormhole_transit_relay/test/test_rlimits.py | 5 ++++- src/wormhole_transit_relay/test/test_service.py | 5 ++++- src/wormhole_transit_relay/test/test_stats.py | 5 ++++- 3 files changed, 12 insertions(+), 3 deletions(-) diff --git a/src/wormhole_transit_relay/test/test_rlimits.py b/src/wormhole_transit_relay/test/test_rlimits.py index 3ee23a9..1354c40 100644 --- a/src/wormhole_transit_relay/test/test_rlimits.py +++ b/src/wormhole_transit_relay/test/test_rlimits.py @@ -1,5 +1,8 @@ from __future__ import print_function, unicode_literals -from unittest import mock +try: + from unittest import mock +except ImportError: + import mock from twisted.trial import unittest from ..increase_rlimits import increase_rlimits diff --git a/src/wormhole_transit_relay/test/test_service.py b/src/wormhole_transit_relay/test/test_service.py index f72765c..003de32 100644 --- a/src/wormhole_transit_relay/test/test_service.py +++ b/src/wormhole_transit_relay/test/test_service.py @@ -1,6 +1,9 @@ from __future__ import unicode_literals, print_function from twisted.trial import unittest -from unittest import mock +try: + from unittest import mock +except ImportError: + import mock from twisted.application.service import MultiService from .. import server_tap diff --git a/src/wormhole_transit_relay/test/test_stats.py b/src/wormhole_transit_relay/test/test_stats.py index 43b912f..1f114b1 100644 --- a/src/wormhole_transit_relay/test/test_stats.py +++ b/src/wormhole_transit_relay/test/test_stats.py @@ -1,6 +1,9 @@ from __future__ import print_function, unicode_literals import os, io, json, sqlite3 -from unittest import mock +try: + from unittest import mock +except ImportError: + import mock from twisted.trial import unittest from ..transit_server import Transit from .. import database From 45c09fdd0526e9206cebacfa37620f0f3a47095f Mon Sep 17 00:00:00 2001 From: meejah Date: Fri, 2 Apr 2021 20:10:01 -0600 Subject: [PATCH 06/13] explicit interface, different naming --- src/wormhole_transit_relay/test/common.py | 49 +++++++++++++++-- .../test/test_transit_server.py | 54 +++++++++---------- 2 files changed, 73 insertions(+), 30 deletions(-) diff --git a/src/wormhole_transit_relay/test/common.py b/src/wormhole_transit_relay/test/common.py index 70cfed2..961e680 100644 --- a/src/wormhole_transit_relay/test/common.py +++ b/src/wormhole_transit_relay/test/common.py @@ -3,11 +3,45 @@ from twisted.internet.protocol import ( Protocol, ) from twisted.test import iosim +from zope.interface import ( + Interface, + Attribute, + implementer, +) from ..transit_server import ( Transit, ) +class ITransitClient(Interface): + """ + The client interface used by tests. + """ + + connected = Attribute("True if we are currently connected else False") + + def send(data): + """ + Send some bytes. + :param bytes data: the data to send + """ + + def disconnect(): + """ + Terminate the connection. + """ + + def get_received_data(): + """ + :returns: all the bytes received from the server on this + connection. + """ + + def reset_data(): + """ + Erase any received data to this point. + """ + class ServerBase: log_requests = False @@ -33,14 +67,18 @@ class ServerBase: self._transit_server._debug_log = self.log_requests def new_protocol(self): + """ + Create a new client protocol connected to the server. + :returns: a ITransitClient implementation + """ server_protocol = self._transit_server.buildProtocol(('127.0.0.1', 0)) - # XXX interface? + @implementer(ITransitClient) class TransitClientProtocolTcp(Protocol): """ Speak the transit client protocol used by the tests over TCP """ - received = b"" + _received = b"" connected = False def connectionMade(self): @@ -56,8 +94,13 @@ class ServerBase: self.transport.loseConnection() def dataReceived(self, data): - self.received = self.received + data + self._received = self._received + data + def reset_received_data(self): + self._received = b"" + + def get_received_data(self): + return self._received client_factory = ClientFactory() client_factory.protocol = TransitClientProtocolTcp diff --git a/src/wormhole_transit_relay/test/test_transit_server.py b/src/wormhole_transit_relay/test/test_transit_server.py index dfdf8de..9a0f349 100644 --- a/src/wormhole_transit_relay/test/test_transit_server.py +++ b/src/wormhole_transit_relay/test/test_transit_server.py @@ -65,16 +65,16 @@ class _Transit: # a correct handshake yields an ack, after which we can send exp = b"ok\n" - self.assertEqual(p1.received, exp) - self.assertEqual(p2.received, exp) + self.assertEqual(p1.get_received_data(), exp) + self.assertEqual(p2.get_received_data(), exp) - p1.received = b"" - p2.received = b"" + p1.reset_received_data() + p2.reset_received_data() s1 = b"data1" p1.send(s1) self.flush() - self.assertEqual(p2.received, s1) + self.assertEqual(p2.get_received_data(), s1) p1.disconnect() p2.disconnect() @@ -93,17 +93,17 @@ class _Transit: # a correct handshake yields an ack, after which we can send exp = b"ok\n" - self.assertEqual(p1.received, exp) - self.assertEqual(p2.received, exp) + self.assertEqual(p1.get_received_data(), exp) + self.assertEqual(p2.get_received_data(), exp) - p1.received = b"" - p2.received = b"" + p1.reset_received_data() + p2.reset_received_data() # all data they sent after the handshake should be given to us s1 = b"data1" p1.send(s1) self.flush() - self.assertEqual(p2.received, s1) + self.assertEqual(p2.get_received_data(), s1) p1.disconnect() p2.disconnect() @@ -121,17 +121,17 @@ class _Transit: # a correct handshake yields an ack, after which we can send exp = b"ok\n" - self.assertEqual(p1.received, exp) - self.assertEqual(p2.received, exp) + self.assertEqual(p1.get_received_data(), exp) + self.assertEqual(p2.get_received_data(), exp) - p1.received = b"" - p2.received = b"" + p1.reset_received_data() + p2.reset_received_data() # all data they sent after the handshake should be given to us s1 = b"data1" p1.send(s1) self.flush() - self.assertEqual(p2.received, s1) + self.assertEqual(p2.get_received_data(), s1) p1.disconnect() p2.disconnect() @@ -151,17 +151,17 @@ class _Transit: # a correct handshake yields an ack, after which we can send exp = b"ok\n" - self.assertEqual(p1.received, exp) - self.assertEqual(p2.received, exp) + self.assertEqual(p1.get_received_data(), exp) + self.assertEqual(p2.get_received_data(), exp) - p1.received = b"" - p2.received = b"" + p1.reset_received_data() + p2.reset_received_data() # all data they sent after the handshake should be given to us s1 = b"data1" p1.send(s1) self.flush() - self.assertEqual(p2.received, s1) + self.assertEqual(p2.get_received_data(), s1) p1.disconnect() p2.disconnect() @@ -207,7 +207,7 @@ class _Transit: self.flush() exp = b"bad handshake\n" - self.assertEqual(p1.received, exp) + self.assertEqual(p1.get_received_data(), exp) p1.disconnect() def test_bad_handshake_old_slow(self): @@ -227,7 +227,7 @@ class _Transit: self.flush() exp = b"bad handshake\n" - self.assertEqual(p1.received, exp) + self.assertEqual(p1.get_received_data(), exp) p1.disconnect() @@ -243,7 +243,7 @@ class _Transit: self.flush() exp = b"bad handshake\n" - self.assertEqual(p1.received, exp) + self.assertEqual(p1.get_received_data(), exp) p1.disconnect() @@ -261,7 +261,7 @@ class _Transit: self.flush() exp = b"bad handshake\n" - self.assertEqual(p1.received, exp) + self.assertEqual(p1.get_received_data(), exp) p1.disconnect() @@ -274,7 +274,7 @@ class _Transit: self.flush() exp = b"impatient\n" - self.assertEqual(p1.received, exp) + self.assertEqual(p1.get_received_data(), exp) p1.disconnect() @@ -289,7 +289,7 @@ class _Transit: self.flush() exp = b"impatient\n" - self.assertEqual(p1.received, exp) + self.assertEqual(p1.get_received_data(), exp) p1.disconnect() @@ -315,7 +315,7 @@ class _Transit: self.flush() exp = b"impatient\n" - self.assertEqual(p1.received, exp) + self.assertEqual(p1.get_received_data(), exp) p1.disconnect() From 2d506de55af0d2625beb9f72020e31f54285c599 Mon Sep 17 00:00:00 2001 From: meejah Date: Fri, 2 Apr 2021 20:33:57 -0600 Subject: [PATCH 07/13] upcall, 2.7-friendly --- src/wormhole_transit_relay/test/common.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/wormhole_transit_relay/test/common.py b/src/wormhole_transit_relay/test/common.py index 961e680..a0b0c28 100644 --- a/src/wormhole_transit_relay/test/common.py +++ b/src/wormhole_transit_relay/test/common.py @@ -83,9 +83,11 @@ class ServerBase: def connectionMade(self): self.connected = True + return Protocol.connectionMade(self) def connectionLost(self, reason): self.connected = False + return Protocol.connectionLost(self, reason) def send(self, data): self.transport.write(data) From 2903c7f2a0d532635d4ba78b3035c03f748bdf2d Mon Sep 17 00:00:00 2001 From: meejah Date: Fri, 2 Apr 2021 23:05:52 -0600 Subject: [PATCH 08/13] re-org + comments --- src/wormhole_transit_relay/test/common.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/wormhole_transit_relay/test/common.py b/src/wormhole_transit_relay/test/common.py index a0b0c28..ed2e976 100644 --- a/src/wormhole_transit_relay/test/common.py +++ b/src/wormhole_transit_relay/test/common.py @@ -81,6 +81,8 @@ class ServerBase: _received = b"" connected = False + # override Protocol callbacks + def connectionMade(self): self.connected = True return Protocol.connectionMade(self) @@ -89,15 +91,17 @@ class ServerBase: self.connected = False return Protocol.connectionLost(self, reason) + def dataReceived(self, data): + self._received = self._received + data + + # ITransitClient API + def send(self, data): self.transport.write(data) def disconnect(self): self.transport.loseConnection() - def dataReceived(self, data): - self._received = self._received + data - def reset_received_data(self): self._received = b"" From f3c391e98bfd814c7f52e5a69a0d926331248d16 Mon Sep 17 00:00:00 2001 From: meejah Date: Fri, 2 Apr 2021 23:15:48 -0600 Subject: [PATCH 09/13] more coverage --- src/wormhole_transit_relay/test/test_transit_server.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/wormhole_transit_relay/test/test_transit_server.py b/src/wormhole_transit_relay/test/test_transit_server.py index 9a0f349..e44b099 100644 --- a/src/wormhole_transit_relay/test/test_transit_server.py +++ b/src/wormhole_transit_relay/test/test_transit_server.py @@ -338,6 +338,8 @@ class TransitWithoutLogs(_Transit, ServerBase, unittest.TestCase): log_requests = False class Usage(ServerBase, unittest.TestCase): + log_requests = True + def setUp(self): super(Usage, self).setUp() self._usage = [] From fc3507c1f646ae2f580b9b0cc40f77f9507986b3 Mon Sep 17 00:00:00 2001 From: meejah Date: Fri, 2 Apr 2021 23:36:33 -0600 Subject: [PATCH 10/13] flip around 'if' logic to simplify --- src/wormhole_transit_relay/transit_server.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/src/wormhole_transit_relay/transit_server.py b/src/wormhole_transit_relay/transit_server.py index 91d84e0..426f50f 100644 --- a/src/wormhole_transit_relay/transit_server.py +++ b/src/wormhole_transit_relay/transit_server.py @@ -84,16 +84,16 @@ class TransitConnection(LineReceiver): # point the sender will only transmit data as fast as the # receiver can handle it. if self._sent_ok: - if not self._buddy: - # Our buddy disconnected (we're "jilted"), so we hung up too, - # but our incoming data hasn't stopped yet (it will in a - # moment, after our disconnect makes a roundtrip through the - # kernel). This probably means the file receiver hung up, and - # this connection is the file sender. In may-2020 this - # happened 11 times in 40 days. - return - self._total_sent += len(data) - self._buddy.transport.write(data) + # if self._buddy is None then our buddy disconnected + # (we're "jilted"), so we hung up too, but our incoming + # data hasn't stopped yet (it will in a moment, after our + # disconnect makes a roundtrip through the kernel). This + # probably means the file receiver hung up, and this + # connection is the file sender. In may-2020 this happened + # 11 times in 40 days. + if self._buddy: + self._total_sent += len(data) + self._buddy.transport.write(data) return # handshake is complete but not yet sent_ok From 591740ce5f3cfd2b5fd578121284d63db4fbdfd2 Mon Sep 17 00:00:00 2001 From: meejah Date: Sat, 10 Apr 2021 18:42:12 -0600 Subject: [PATCH 11/13] better name for interface --- src/wormhole_transit_relay/test/common.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/wormhole_transit_relay/test/common.py b/src/wormhole_transit_relay/test/common.py index ed2e976..fb232b5 100644 --- a/src/wormhole_transit_relay/test/common.py +++ b/src/wormhole_transit_relay/test/common.py @@ -13,7 +13,7 @@ from ..transit_server import ( ) -class ITransitClient(Interface): +class IRelayTestClient(Interface): """ The client interface used by tests. """ @@ -69,11 +69,11 @@ class ServerBase: def new_protocol(self): """ Create a new client protocol connected to the server. - :returns: a ITransitClient implementation + :returns: a IRelayTestClient implementation """ server_protocol = self._transit_server.buildProtocol(('127.0.0.1', 0)) - @implementer(ITransitClient) + @implementer(IRelayTestClient) class TransitClientProtocolTcp(Protocol): """ Speak the transit client protocol used by the tests over TCP @@ -94,7 +94,7 @@ class ServerBase: def dataReceived(self, data): self._received = self._received + data - # ITransitClient API + # IRelayTestClient def send(self, data): self.transport.write(data) From 6efc274b811ede1526e119f0dd398a37eede1566 Mon Sep 17 00:00:00 2001 From: meejah Date: Mon, 12 Apr 2021 08:44:06 -0600 Subject: [PATCH 12/13] get rid of double-flush() pairing with florian --- src/wormhole_transit_relay/test/common.py | 5 ++++- src/wormhole_transit_relay/test/test_transit_server.py | 6 ------ 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/src/wormhole_transit_relay/test/common.py b/src/wormhole_transit_relay/test/common.py index fb232b5..8073ee0 100644 --- a/src/wormhole_transit_relay/test/common.py +++ b/src/wormhole_transit_relay/test/common.py @@ -55,8 +55,11 @@ class ServerBase: self._setup_relay(blur_usage=blur_usage) def flush(self): + did_work = False for pump in self._pumps: - pump.flush() + did_work = pump.flush() or did_work + if did_work: + self.flush() def _setup_relay(self, blur_usage=None, log_file=None, usage_db=None): self._transit_server = Transit( diff --git a/src/wormhole_transit_relay/test/test_transit_server.py b/src/wormhole_transit_relay/test/test_transit_server.py index e44b099..8fbdef8 100644 --- a/src/wormhole_transit_relay/test/test_transit_server.py +++ b/src/wormhole_transit_relay/test/test_transit_server.py @@ -61,7 +61,6 @@ class _Transit: self.flush() p2.send(handshake(token1, side=None)) self.flush() - self.flush() # a correct handshake yields an ack, after which we can send exp = b"ok\n" @@ -89,7 +88,6 @@ class _Transit: self.flush() p2.send(handshake(token1, side=None)) self.flush() - self.flush() # a correct handshake yields an ack, after which we can send exp = b"ok\n" @@ -117,7 +115,6 @@ class _Transit: p1.send(handshake(token1, side=None)) p2.send(handshake(token1, side=side1)) self.flush() - self.flush() # a correct handshake yields an ack, after which we can send exp = b"ok\n" @@ -147,7 +144,6 @@ class _Transit: self.flush() p2.send(handshake(token1, side=side2)) self.flush() - self.flush() # a correct handshake yields an ack, after which we can send exp = b"ok\n" @@ -187,7 +183,6 @@ class _Transit: side2 = b"\x02"*8 p3.send(handshake(token1, side=side2)) self.flush() - self.flush() self.assertEqual(self.count(), 0) self.assertEqual(len(self._transit_server._pending_requests), 0) self.assertEqual(len(self._transit_server._active_connections), 2) @@ -453,7 +448,6 @@ class Usage(ServerBase, unittest.TestCase): p2.send(handshake(token1, side=side2)) self.flush() - self.flush() self.assertEqual(len(self._transit_server._pending_requests), 0) self.assertEqual(len(self._usage), 2, self._usage) (started, result, total_bytes, total_time, waiting_time) = self._usage[1] From 00086a798dfe5df830d40d61084590cca3043053 Mon Sep 17 00:00:00 2001 From: meejah Date: Mon, 12 Apr 2021 08:50:22 -0600 Subject: [PATCH 13/13] flush cleanup --- src/wormhole_transit_relay/test/test_transit_server.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/wormhole_transit_relay/test/test_transit_server.py b/src/wormhole_transit_relay/test/test_transit_server.py index 8fbdef8..bca740e 100644 --- a/src/wormhole_transit_relay/test/test_transit_server.py +++ b/src/wormhole_transit_relay/test/test_transit_server.py @@ -77,6 +77,7 @@ class _Transit: p1.disconnect() p2.disconnect() + self.flush() def test_sided_unsided(self): p1 = self.new_protocol() @@ -105,6 +106,7 @@ class _Transit: p1.disconnect() p2.disconnect() + self.flush() def test_unsided_sided(self): p1 = self.new_protocol()