From 45824ca5d60b5f29172710534a6e568f5cd0cbbd Mon Sep 17 00:00:00 2001 From: Joe Harrison Date: Sun, 8 Mar 2020 14:30:23 +0000 Subject: [PATCH] Use StringTransportWithDisconnection for transit server tests. Replace the use of TCP in the test suite with Twisted's StringTransport, specifically StringTransportWithDisconnection which allows us to trigger a disconnect event on the server side during testing. The `dataReceived` method on the server is now called directly, and any effects will be realised immediately. Responses are available to the test client using the `value()` method of the transport objects, and the buffer can be cleared using `clear()`. This allows all asynchronous behaviour to be removed from the transit server test suite. Furthermore, as we never have to wait for the server, tests no longer hang if they fail: the errors are encountered immediately. --- src/wormhole_transit_relay/test/common.py | 20 +- .../test/test_transit_server.py | 446 ++++++------------ src/wormhole_transit_relay/transit_server.py | 9 +- 3 files changed, 164 insertions(+), 311 deletions(-) diff --git a/src/wormhole_transit_relay/test/common.py b/src/wormhole_transit_relay/test/common.py index 9f2e827..53958fb 100644 --- a/src/wormhole_transit_relay/test/common.py +++ b/src/wormhole_transit_relay/test/common.py @@ -1,30 +1,28 @@ -#from __future__ import unicode_literals -from twisted.internet import reactor, endpoints -from twisted.internet.defer import inlineCallbacks +from twisted.test import proto_helpers from ..transit_server import Transit class ServerBase: log_requests = False - @inlineCallbacks def setUp(self): self._lp = None if self.log_requests: blur_usage = None else: blur_usage = 60.0 - yield self._setup_relay(blur_usage=blur_usage) + self._setup_relay(blur_usage=blur_usage) self._transit_server._debug_log = self.log_requests - @inlineCallbacks def _setup_relay(self, blur_usage=None, log_file=None, usage_db=None): - ep = endpoints.TCP4ServerEndpoint(reactor, 0, interface="127.0.0.1") self._transit_server = Transit(blur_usage=blur_usage, log_file=log_file, usage_db=usage_db) - self._lp = yield ep.listen(self._transit_server) - addr = self._lp.getHost() - # ws://127.0.0.1:%d/wormhole-relay/ws - self.transit = u"tcp:127.0.0.1:%d" % addr.port + + def new_protocol(self): + protocol = self._transit_server.buildProtocol(('127.0.0.1', 0)) + transport = proto_helpers.StringTransportWithDisconnection() + protocol.makeConnection(transport) + transport.protocol = protocol + return protocol def tearDown(self): if self._lp: diff --git a/src/wormhole_transit_relay/test/test_transit_server.py b/src/wormhole_transit_relay/test/test_transit_server.py index 325ee04..6faf6a4 100644 --- a/src/wormhole_transit_relay/test/test_transit_server.py +++ b/src/wormhole_transit_relay/test/test_transit_server.py @@ -1,42 +1,28 @@ from __future__ import print_function, unicode_literals from binascii import hexlify from twisted.trial import unittest -from twisted.internet import protocol, reactor, defer -from twisted.internet.endpoints import clientFromString, connectProtocol +from twisted.internet import reactor, defer from .common import ServerBase from .. import transit_server -class Accumulator(protocol.Protocol): - def __init__(self): - self.data = b"" - self.count = 0 - self._wait = None - self._disconnect = defer.Deferred() - def waitForBytes(self, more): - assert self._wait is None - self.count = more - self._wait = defer.Deferred() - self._check_done() - return self._wait - def dataReceived(self, data): - self.data = self.data + data - self._check_done() - def _check_done(self): - if self._wait and len(self.data) >= self.count: - d = self._wait - self._wait = None - d.callback(self) - def connectionLost(self, why): - if self._wait: - self._wait.errback(RuntimeError("closed")) - self._disconnect.callback(None) - def wait(): d = defer.Deferred() reactor.callLater(0.001, d.callback, None) return d +def handshake(token, side=None): + hs = b"please relay " + hexlify(token) + if side is not None: + hs += b" for side " + hexlify(side) + hs += b"\n" + return hs + class _Transit: + def count(self): + return sum([len(potentials) + for potentials + in self._transit_server._pending_requests.values()]) + def test_blur_size(self): blur = transit_server.blur_size self.failUnlessEqual(blur(0), 0) @@ -55,268 +41,195 @@ class _Transit: self.failUnlessEqual(blur(1100e6), 1100e6) self.failUnlessEqual(blur(1150e6), 1200e6) - @defer.inlineCallbacks def test_register(self): - ep = clientFromString(reactor, self.transit) - a1 = yield connectProtocol(ep, Accumulator()) + p1 = self.new_protocol() token1 = b"\x00"*32 side1 = b"\x01"*8 - a1.transport.write(b"please relay " + hexlify(token1) + - b" for side " + hexlify(side1) + b"\n") - # let that arrive - while self.count() == 0: - yield wait() + p1.dataReceived(handshake(token1, side1)) self.assertEqual(self.count(), 1) - a1.transport.loseConnection() - - # let that get removed - while self.count() > 0: - yield wait() + p1.transport.loseConnection() self.assertEqual(self.count(), 0) # the token should be removed too self.assertEqual(len(self._transit_server._pending_requests), 0) - @defer.inlineCallbacks def test_both_unsided(self): - ep = clientFromString(reactor, self.transit) - a1 = yield connectProtocol(ep, Accumulator()) - a2 = yield connectProtocol(ep, Accumulator()) + p1 = self.new_protocol() + p2 = self.new_protocol() token1 = b"\x00"*32 - a1.transport.write(b"please relay " + hexlify(token1) + b"\n") - a2.transport.write(b"please relay " + hexlify(token1) + b"\n") + p1.dataReceived(handshake(token1, side=None)) + p2.dataReceived(handshake(token1, side=None)) # a correct handshake yields an ack, after which we can send exp = b"ok\n" - yield a1.waitForBytes(len(exp)) - self.assertEqual(a1.data, exp) + self.assertEqual(p1.transport.value(), exp) + self.assertEqual(p2.transport.value(), exp) + + p1.transport.clear() + p2.transport.clear() + s1 = b"data1" - a1.transport.write(s1) + p1.dataReceived(s1) + self.assertEqual(p2.transport.value(), s1) - exp = b"ok\n" - yield a2.waitForBytes(len(exp)) - self.assertEqual(a2.data, exp) + p1.transport.loseConnection() + p2.transport.loseConnection() - # all data they sent after the handshake should be given to us - exp = b"ok\n"+s1 - yield a2.waitForBytes(len(exp)) - self.assertEqual(a2.data, exp) - - a1.transport.loseConnection() - a2.transport.loseConnection() - - @defer.inlineCallbacks def test_sided_unsided(self): - ep = clientFromString(reactor, self.transit) - a1 = yield connectProtocol(ep, Accumulator()) - a2 = yield connectProtocol(ep, Accumulator()) + p1 = self.new_protocol() + p2 = self.new_protocol() token1 = b"\x00"*32 side1 = b"\x01"*8 - a1.transport.write(b"please relay " + hexlify(token1) + - b" for side " + hexlify(side1) + b"\n") - a2.transport.write(b"please relay " + hexlify(token1) + b"\n") + p1.dataReceived(handshake(token1, side=side1)) + p2.dataReceived(handshake(token1, side=None)) # a correct handshake yields an ack, after which we can send exp = b"ok\n" - yield a1.waitForBytes(len(exp)) - self.assertEqual(a1.data, exp) - s1 = b"data1" - a1.transport.write(s1) + self.assertEqual(p1.transport.value(), exp) + self.assertEqual(p2.transport.value(), exp) - exp = b"ok\n" - yield a2.waitForBytes(len(exp)) - self.assertEqual(a2.data, exp) + p1.transport.clear() + p2.transport.clear() # all data they sent after the handshake should be given to us - exp = b"ok\n"+s1 - yield a2.waitForBytes(len(exp)) - self.assertEqual(a2.data, exp) + s1 = b"data1" + p1.dataReceived(s1) + self.assertEqual(p2.transport.value(), s1) - a1.transport.loseConnection() - a2.transport.loseConnection() + p1.transport.loseConnection() + p2.transport.loseConnection() - @defer.inlineCallbacks def test_unsided_sided(self): - ep = clientFromString(reactor, self.transit) - a1 = yield connectProtocol(ep, Accumulator()) - a2 = yield connectProtocol(ep, Accumulator()) + p1 = self.new_protocol() + p2 = self.new_protocol() token1 = b"\x00"*32 side1 = b"\x01"*8 - a1.transport.write(b"please relay " + hexlify(token1) + b"\n") - a2.transport.write(b"please relay " + hexlify(token1) + - b" for side " + hexlify(side1) + b"\n") + p1.dataReceived(handshake(token1, side=None)) + p2.dataReceived(handshake(token1, side=side1)) # a correct handshake yields an ack, after which we can send exp = b"ok\n" - yield a1.waitForBytes(len(exp)) - self.assertEqual(a1.data, exp) - s1 = b"data1" - a1.transport.write(s1) + self.assertEqual(p1.transport.value(), exp) + self.assertEqual(p2.transport.value(), exp) - exp = b"ok\n" - yield a2.waitForBytes(len(exp)) - self.assertEqual(a2.data, exp) + p1.transport.clear() + p2.transport.clear() # all data they sent after the handshake should be given to us - exp = b"ok\n"+s1 - yield a2.waitForBytes(len(exp)) - self.assertEqual(a2.data, exp) + s1 = b"data1" + p1.dataReceived(s1) + self.assertEqual(p2.transport.value(), s1) - a1.transport.loseConnection() - a2.transport.loseConnection() + p1.transport.loseConnection() + p2.transport.loseConnection() - @defer.inlineCallbacks def test_both_sided(self): - ep = clientFromString(reactor, self.transit) - a1 = yield connectProtocol(ep, Accumulator()) - a2 = yield connectProtocol(ep, Accumulator()) + p1 = self.new_protocol() + p2 = self.new_protocol() token1 = b"\x00"*32 side1 = b"\x01"*8 side2 = b"\x02"*8 - a1.transport.write(b"please relay " + hexlify(token1) + - b" for side " + hexlify(side1) + b"\n") - a2.transport.write(b"please relay " + hexlify(token1) + - b" for side " + hexlify(side2) + b"\n") + p1.dataReceived(handshake(token1, side=side1)) + p2.dataReceived(handshake(token1, side=side2)) # a correct handshake yields an ack, after which we can send exp = b"ok\n" - yield a1.waitForBytes(len(exp)) - self.assertEqual(a1.data, exp) - s1 = b"data1" - a1.transport.write(s1) + self.assertEqual(p1.transport.value(), exp) + self.assertEqual(p2.transport.value(), exp) - exp = b"ok\n" - yield a2.waitForBytes(len(exp)) - self.assertEqual(a2.data, exp) + p1.transport.clear() + p2.transport.clear() # all data they sent after the handshake should be given to us - exp = b"ok\n"+s1 - yield a2.waitForBytes(len(exp)) - self.assertEqual(a2.data, exp) + s1 = b"data1" + p1.dataReceived(s1) + self.assertEqual(p2.transport.value(), s1) - a1.transport.loseConnection() - a2.transport.loseConnection() + p1.transport.loseConnection() + p2.transport.loseConnection() - def count(self): - return sum([len(potentials) - for potentials - in self._transit_server._pending_requests.values()]) - - @defer.inlineCallbacks def test_ignore_same_side(self): - ep = clientFromString(reactor, self.transit) - a1 = yield connectProtocol(ep, Accumulator()) - a2 = yield connectProtocol(ep, Accumulator()) - a3 = yield connectProtocol(ep, Accumulator()) - disconnects = [] - a1._disconnect.addCallback(disconnects.append) - a2._disconnect.addCallback(disconnects.append) + p1 = self.new_protocol() + p2 = self.new_protocol() + p3 = self.new_protocol() token1 = b"\x00"*32 side1 = b"\x01"*8 - a1.transport.write(b"please relay " + hexlify(token1) + - b" for side " + hexlify(side1) + b"\n") - # let that arrive - while self.count() == 0: - yield wait() - a2.transport.write(b"please relay " + hexlify(token1) + - b" for side " + hexlify(side1) + b"\n") - # let that arrive - while self.count() == 1: - yield wait() + + p1.dataReceived(handshake(token1, side=side1)) + self.assertEqual(self.count(), 1) + + p2.dataReceived(handshake(token1, side=side1)) self.assertEqual(self.count(), 2) # same-side connections don't match # when the second side arrives, the spare first connection should be # closed side2 = b"\x02"*8 - a3.transport.write(b"please relay " + hexlify(token1) + - b" for side " + hexlify(side2) + b"\n") - # let that arrive - while self.count() != 0: - yield wait() + p3.dataReceived(handshake(token1, side=side2)) + self.assertEqual(self.count(), 0) self.assertEqual(len(self._transit_server._pending_requests), 0) self.assertEqual(len(self._transit_server._active_connections), 2) - # That will trigger a disconnect on exactly one of (a1 or a2). Wait - # until our client notices it. - while not disconnects: - yield wait() - # the other connection should still be connected - self.assertEqual(sum([int(t.transport.connected) for t in [a1, a2]]), 1) + # That will trigger a disconnect on exactly one of (p1 or p2). + # The other connection should still be connected + self.assertEqual(sum([int(t.transport.connected) for t in [p1, p2]]), 1) - a1.transport.loseConnection() - a2.transport.loseConnection() - a3.transport.loseConnection() + p1.transport.loseConnection() + p2.transport.loseConnection() + p3.transport.loseConnection() - @defer.inlineCallbacks def test_bad_handshake_old(self): - ep = clientFromString(reactor, self.transit) - a1 = yield connectProtocol(ep, Accumulator()) + p1 = self.new_protocol() token1 = b"\x00"*32 - # the server waits for the exact number of bytes in the expected - # handshake message. to trigger "bad handshake", we must match. - a1.transport.write(b"please DELAY " + hexlify(token1) + b"\n") + p1.dataReceived(b"please DELAY " + hexlify(token1) + b"\n") exp = b"bad handshake\n" - yield a1.waitForBytes(len(exp)) - self.assertEqual(a1.data, exp) + self.assertEqual(p1.transport.value(), exp) + p1.transport.loseConnection() - a1.transport.loseConnection() - - @defer.inlineCallbacks def test_bad_handshake_old_slow(self): - ep = clientFromString(reactor, self.transit) - a1 = yield connectProtocol(ep, Accumulator()) + p1 = self.new_protocol() - a1.transport.write(b"please DELAY ") + p1.dataReceived(b"please DELAY ") # As in test_impatience_new_slow, the current state machine has code # that can only be reached if we insert a stall here, so dataReceived # gets called twice. Hopefully we can delete this test once # dataReceived is refactored to remove that state. - d = defer.Deferred() - reactor.callLater(0.1, d.callback, None) - yield d token1 = b"\x00"*32 # the server waits for the exact number of bytes in the expected # handshake message. to trigger "bad handshake", we must match. - a1.transport.write(hexlify(token1) + b"\n") + p1.dataReceived(hexlify(token1) + b"\n") exp = b"bad handshake\n" - yield a1.waitForBytes(len(exp)) - self.assertEqual(a1.data, exp) + self.assertEqual(p1.transport.value(), exp) - a1.transport.loseConnection() + p1.transport.loseConnection() - @defer.inlineCallbacks def test_bad_handshake_new(self): - ep = clientFromString(reactor, self.transit) - a1 = yield connectProtocol(ep, Accumulator()) + p1 = self.new_protocol() token1 = b"\x00"*32 side1 = b"\x01"*8 # the server waits for the exact number of bytes in the expected # handshake message. to trigger "bad handshake", we must match. - a1.transport.write(b"please DELAY " + hexlify(token1) + - b" for side " + hexlify(side1) + b"\n") + p1.dataReceived(b"please DELAY " + hexlify(token1) + + b" for side " + hexlify(side1) + b"\n") exp = b"bad handshake\n" - yield a1.waitForBytes(len(exp)) - self.assertEqual(a1.data, exp) + self.assertEqual(p1.transport.value(), exp) - a1.transport.loseConnection() + p1.transport.loseConnection() - @defer.inlineCallbacks def test_binary_handshake(self): - ep = clientFromString(reactor, self.transit) - a1 = yield connectProtocol(ep, Accumulator()) + p1 = self.new_protocol() binary_bad_handshake = b"\x00\x01\xe0\x0f\n\xff" # the embedded \n makes the server trigger early, before the full @@ -325,50 +238,41 @@ class _Transit: # UnicodeDecodeError when it tried to coerce the incoming handshake # to unicode, due to the ("\n" in buf) check. This was fixed to use # (b"\n" in buf). This exercises the old failure. - a1.transport.write(binary_bad_handshake) + p1.dataReceived(binary_bad_handshake) exp = b"bad handshake\n" - yield a1.waitForBytes(len(exp)) - self.assertEqual(a1.data, exp) + self.assertEqual(p1.transport.value(), exp) - a1.transport.loseConnection() + p1.transport.loseConnection() - @defer.inlineCallbacks def test_impatience_old(self): - ep = clientFromString(reactor, self.transit) - a1 = yield connectProtocol(ep, Accumulator()) + p1 = self.new_protocol() token1 = b"\x00"*32 # sending too many bytes is impatience. - a1.transport.write(b"please relay " + hexlify(token1) + b"\nNOWNOWNOW") + p1.dataReceived(b"please relay " + hexlify(token1) + b"\nNOWNOWNOW") exp = b"impatient\n" - yield a1.waitForBytes(len(exp)) - self.assertEqual(a1.data, exp) + self.assertEqual(p1.transport.value(), exp) - a1.transport.loseConnection() + p1.transport.loseConnection() - @defer.inlineCallbacks def test_impatience_new(self): - ep = clientFromString(reactor, self.transit) - a1 = yield connectProtocol(ep, Accumulator()) + p1 = self.new_protocol() token1 = b"\x00"*32 side1 = b"\x01"*8 # sending too many bytes is impatience. - a1.transport.write(b"please relay " + hexlify(token1) + - b" for side " + hexlify(side1) + b"\nNOWNOWNOW") + p1.dataReceived(b"please relay " + hexlify(token1) + + b" for side " + hexlify(side1) + b"\nNOWNOWNOW") exp = b"impatient\n" - yield a1.waitForBytes(len(exp)) - self.assertEqual(a1.data, exp) + self.assertEqual(p1.transport.value(), exp) - a1.transport.loseConnection() + p1.transport.loseConnection() - @defer.inlineCallbacks def test_impatience_new_slow(self): - ep = clientFromString(reactor, self.transit) - a1 = yield connectProtocol(ep, Accumulator()) + p1 = self.new_protocol() # For full coverage, we need dataReceived to see a particular framing # of these two pieces of data, and ITCPTransport doesn't have flush() # (which probably wouldn't work anyways). For now, force a 100ms @@ -381,20 +285,16 @@ class _Transit: token1 = b"\x00"*32 side1 = b"\x01"*8 # sending too many bytes is impatience. - a1.transport.write(b"please relay " + hexlify(token1) + - b" for side " + hexlify(side1) + b"\n") + p1.dataReceived(b"please relay " + hexlify(token1) + + b" for side " + hexlify(side1) + b"\n") - d = defer.Deferred() - reactor.callLater(0.1, d.callback, None) - yield d - a1.transport.write(b"NOWNOWNOW") + p1.dataReceived(b"NOWNOWNOW") exp = b"impatient\n" - yield a1.waitForBytes(len(exp)) - self.assertEqual(a1.data, exp) + self.assertEqual(p1.transport.value(), exp) - a1.transport.loseConnection() + p1.transport.loseConnection() @defer.inlineCallbacks def test_short_handshake(self): @@ -420,9 +320,8 @@ class TransitWithoutLogs(_Transit, ServerBase, unittest.TestCase): log_requests = False class Usage(ServerBase, unittest.TestCase): - @defer.inlineCallbacks def setUp(self): - yield super(Usage, self).setUp() + super(Usage, self).setUp() self._usage = [] def record(started, result, total_bytes, total_time, waiting_time): self._usage.append((started, result, total_bytes, @@ -470,130 +369,83 @@ class Usage(ServerBase, unittest.TestCase): @defer.inlineCallbacks def test_errory(self): - ep = clientFromString(reactor, self.transit) - a1 = yield connectProtocol(ep, Accumulator()) + p1 = self.new_protocol() - a1.transport.write(b"this is a very bad handshake\n") + p1.dataReceived(b"this is a very bad handshake\n") # that will log the "errory" usage event, then drop the connection - yield a1._disconnect + p1.transport.loseConnection() self.assertEqual(len(self._usage), 1, self._usage) (started, result, total_bytes, total_time, waiting_time) = self._usage[0] self.assertEqual(result, "errory", self._usage) - @defer.inlineCallbacks def test_lonely(self): - ep = clientFromString(reactor, self.transit) - a1 = yield connectProtocol(ep, Accumulator()) + p1 = self.new_protocol() token1 = b"\x00"*32 side1 = b"\x01"*8 - a1.transport.write(b"please relay " + hexlify(token1) + - b" for side " + hexlify(side1) + b"\n") - while not self._transit_server._pending_requests: - yield wait() # wait for the server to see the connection + p1.dataReceived(handshake(token1, side=side1)) # now we disconnect before the peer connects - a1.transport.loseConnection() - yield a1._disconnect - while self._transit_server._pending_requests: - yield wait() # wait for the server to see the disconnect too + p1.transport.loseConnection() self.assertEqual(len(self._usage), 1, self._usage) (started, result, total_bytes, total_time, waiting_time) = self._usage[0] self.assertEqual(result, "lonely", self._usage) self.assertIdentical(waiting_time, None) - @defer.inlineCallbacks def test_one_happy_one_jilted(self): - ep = clientFromString(reactor, self.transit) - a1 = yield connectProtocol(ep, Accumulator()) - a2 = yield connectProtocol(ep, Accumulator()) + p1 = self.new_protocol() + p2 = self.new_protocol() token1 = b"\x00"*32 side1 = b"\x01"*8 side2 = b"\x02"*8 - a1.transport.write(b"please relay " + hexlify(token1) + - b" for side " + hexlify(side1) + b"\n") - while not self._transit_server._pending_requests: - yield wait() # make sure a1 connects first - a2.transport.write(b"please relay " + hexlify(token1) + - b" for side " + hexlify(side2) + b"\n") - while not self._transit_server._active_connections: - yield wait() # wait for the server to see the connection - self.assertEqual(len(self._transit_server._pending_requests), 0) - self.assertEqual(self._usage, []) # no events yet - a1.transport.write(b"\x00" * 13) - yield a2.waitForBytes(13) - a2.transport.write(b"\xff" * 7) - yield a1.waitForBytes(7) + p1.dataReceived(handshake(token1, side=side1)) + p2.dataReceived(handshake(token1, side=side2)) + + self.assertEqual(self._usage, []) # no events yet + + p1.dataReceived(b"\x00" * 13) + p2.dataReceived(b"\xff" * 7) + + p1.transport.loseConnection() - a1.transport.loseConnection() - yield a1._disconnect - while self._transit_server._active_connections: - yield wait() - yield a2._disconnect self.assertEqual(len(self._usage), 1, self._usage) (started, result, total_bytes, total_time, waiting_time) = self._usage[0] self.assertEqual(result, "happy", self._usage) self.assertEqual(total_bytes, 20) self.assertNotIdentical(waiting_time, None) - @defer.inlineCallbacks def test_redundant(self): - ep = clientFromString(reactor, self.transit) - a1a = yield connectProtocol(ep, Accumulator()) - a1b = yield connectProtocol(ep, Accumulator()) - a1c = yield connectProtocol(ep, Accumulator()) - a2 = yield connectProtocol(ep, Accumulator()) + p1a = self.new_protocol() + p1b = self.new_protocol() + p1c = self.new_protocol() + p2 = self.new_protocol() token1 = b"\x00"*32 side1 = b"\x01"*8 side2 = b"\x02"*8 - a1a.transport.write(b"please relay " + hexlify(token1) + - b" for side " + hexlify(side1) + b"\n") - def count_requests(): - return sum([len(v) - for v in self._transit_server._pending_requests.values()]) - while count_requests() < 1: - yield wait() - a1b.transport.write(b"please relay " + hexlify(token1) + - b" for side " + hexlify(side1) + b"\n") - while count_requests() < 2: - yield wait() + p1a.dataReceived(handshake(token1, side=side1)) + p1b.dataReceived(handshake(token1, side=side1)) # connect and disconnect a third client (for side1) to exercise the # code that removes a pending connection without removing the entire # token - a1c.transport.write(b"please relay " + hexlify(token1) + - b" for side " + hexlify(side1) + b"\n") - while count_requests() < 3: - yield wait() - a1c.transport.loseConnection() - yield a1c._disconnect - while count_requests() > 2: - yield wait() + p1c.dataReceived(handshake(token1, side=side1)) + p1c.transport.loseConnection() + self.assertEqual(len(self._usage), 1, self._usage) (started, result, total_bytes, total_time, waiting_time) = self._usage[0] self.assertEqual(result, "lonely", self._usage) - a2.transport.write(b"please relay " + hexlify(token1) + - b" for side " + hexlify(side2) + b"\n") - # this will claim one of (a1a, a1b), and close the other as redundant - while not self._transit_server._active_connections: - yield wait() # wait for the server to see the connection - self.assertEqual(count_requests(), 0) + p2.dataReceived(handshake(token1, side=side2)) + self.assertEqual(len(self._transit_server._pending_requests), 0) self.assertEqual(len(self._usage), 2, self._usage) (started, result, total_bytes, total_time, waiting_time) = self._usage[1] self.assertEqual(result, "redundant", self._usage) # one of the these is unecessary, but probably harmless - a1a.transport.loseConnection() - a1b.transport.loseConnection() - yield a1a._disconnect - yield a1b._disconnect - while self._transit_server._active_connections: - yield wait() - yield a2._disconnect + p1a.transport.loseConnection() + p1b.transport.loseConnection() self.assertEqual(len(self._usage), 3, self._usage) (started, result, total_bytes, total_time, waiting_time) = self._usage[2] self.assertEqual(result, "happy", self._usage) - diff --git a/src/wormhole_transit_relay/transit_server.py b/src/wormhole_transit_relay/transit_server.py index c3ca635..91d84e0 100644 --- a/src/wormhole_transit_relay/transit_server.py +++ b/src/wormhole_transit_relay/transit_server.py @@ -52,7 +52,10 @@ class TransitConnection(LineReceiver): def connectionMade(self): self._started = time.time() self._log_requests = self.factory._log_requests - self.transport.setTcpKeepAlive(True) + try: + self.transport.setTcpKeepAlive(True) + except AttributeError: + pass def lineReceived(self, line): # old: "please relay {64}\n" @@ -257,14 +260,14 @@ class Transit(protocol.ServerFactory): # drop and stop tracking the rest potentials.remove(old) - for (_, leftover_tc) in potentials: + for (_, leftover_tc) in potentials.copy(): # Don't record this as errory. It's just a spare connection # from the same side as a connection that got used. This # can happen if the connection hint contains multiple # addresses (we don't currently support those, but it'd # probably be useful in the future). leftover_tc.disconnect_redundant() - self._pending_requests.pop(token) + self._pending_requests.pop(token, None) # glue the two ends together self._active_connections.add(new_tc)