Use StringTransportWithDisconnection for transit server tests.

Replace the use of TCP in the test suite with Twisted's
StringTransport, specifically StringTransportWithDisconnection which
allows us to trigger a disconnect event on the server side during testing.

The `dataReceived` method on the server is now called directly, and any
effects will be realised immediately.
Responses are available to the test client using the `value()` method of
the transport objects, and the buffer can be cleared using `clear()`.

This allows all asynchronous behaviour to be removed from the transit
server test suite.
Furthermore, as we never have to wait for the server, tests no longer
hang if they fail: the errors are encountered immediately.
This commit is contained in:
Joe Harrison 2020-03-08 14:30:23 +00:00 committed by Brian Warner
parent ac7415a4d0
commit 45824ca5d6
3 changed files with 164 additions and 311 deletions

View File

@ -1,30 +1,28 @@
#from __future__ import unicode_literals
from twisted.internet import reactor, endpoints
from twisted.internet.defer import inlineCallbacks
from twisted.test import proto_helpers
from ..transit_server import Transit
class ServerBase:
log_requests = False
@inlineCallbacks
def setUp(self):
self._lp = None
if self.log_requests:
blur_usage = None
else:
blur_usage = 60.0
yield self._setup_relay(blur_usage=blur_usage)
self._setup_relay(blur_usage=blur_usage)
self._transit_server._debug_log = self.log_requests
@inlineCallbacks
def _setup_relay(self, blur_usage=None, log_file=None, usage_db=None):
ep = endpoints.TCP4ServerEndpoint(reactor, 0, interface="127.0.0.1")
self._transit_server = Transit(blur_usage=blur_usage,
log_file=log_file, usage_db=usage_db)
self._lp = yield ep.listen(self._transit_server)
addr = self._lp.getHost()
# ws://127.0.0.1:%d/wormhole-relay/ws
self.transit = u"tcp:127.0.0.1:%d" % addr.port
def new_protocol(self):
protocol = self._transit_server.buildProtocol(('127.0.0.1', 0))
transport = proto_helpers.StringTransportWithDisconnection()
protocol.makeConnection(transport)
transport.protocol = protocol
return protocol
def tearDown(self):
if self._lp:

View File

@ -1,42 +1,28 @@
from __future__ import print_function, unicode_literals
from binascii import hexlify
from twisted.trial import unittest
from twisted.internet import protocol, reactor, defer
from twisted.internet.endpoints import clientFromString, connectProtocol
from twisted.internet import reactor, defer
from .common import ServerBase
from .. import transit_server
class Accumulator(protocol.Protocol):
def __init__(self):
self.data = b""
self.count = 0
self._wait = None
self._disconnect = defer.Deferred()
def waitForBytes(self, more):
assert self._wait is None
self.count = more
self._wait = defer.Deferred()
self._check_done()
return self._wait
def dataReceived(self, data):
self.data = self.data + data
self._check_done()
def _check_done(self):
if self._wait and len(self.data) >= self.count:
d = self._wait
self._wait = None
d.callback(self)
def connectionLost(self, why):
if self._wait:
self._wait.errback(RuntimeError("closed"))
self._disconnect.callback(None)
def wait():
d = defer.Deferred()
reactor.callLater(0.001, d.callback, None)
return d
def handshake(token, side=None):
hs = b"please relay " + hexlify(token)
if side is not None:
hs += b" for side " + hexlify(side)
hs += b"\n"
return hs
class _Transit:
def count(self):
return sum([len(potentials)
for potentials
in self._transit_server._pending_requests.values()])
def test_blur_size(self):
blur = transit_server.blur_size
self.failUnlessEqual(blur(0), 0)
@ -55,268 +41,195 @@ class _Transit:
self.failUnlessEqual(blur(1100e6), 1100e6)
self.failUnlessEqual(blur(1150e6), 1200e6)
@defer.inlineCallbacks
def test_register(self):
ep = clientFromString(reactor, self.transit)
a1 = yield connectProtocol(ep, Accumulator())
p1 = self.new_protocol()
token1 = b"\x00"*32
side1 = b"\x01"*8
a1.transport.write(b"please relay " + hexlify(token1) +
b" for side " + hexlify(side1) + b"\n")
# let that arrive
while self.count() == 0:
yield wait()
p1.dataReceived(handshake(token1, side1))
self.assertEqual(self.count(), 1)
a1.transport.loseConnection()
# let that get removed
while self.count() > 0:
yield wait()
p1.transport.loseConnection()
self.assertEqual(self.count(), 0)
# the token should be removed too
self.assertEqual(len(self._transit_server._pending_requests), 0)
@defer.inlineCallbacks
def test_both_unsided(self):
ep = clientFromString(reactor, self.transit)
a1 = yield connectProtocol(ep, Accumulator())
a2 = yield connectProtocol(ep, Accumulator())
p1 = self.new_protocol()
p2 = self.new_protocol()
token1 = b"\x00"*32
a1.transport.write(b"please relay " + hexlify(token1) + b"\n")
a2.transport.write(b"please relay " + hexlify(token1) + b"\n")
p1.dataReceived(handshake(token1, side=None))
p2.dataReceived(handshake(token1, side=None))
# a correct handshake yields an ack, after which we can send
exp = b"ok\n"
yield a1.waitForBytes(len(exp))
self.assertEqual(a1.data, exp)
self.assertEqual(p1.transport.value(), exp)
self.assertEqual(p2.transport.value(), exp)
p1.transport.clear()
p2.transport.clear()
s1 = b"data1"
a1.transport.write(s1)
p1.dataReceived(s1)
self.assertEqual(p2.transport.value(), s1)
exp = b"ok\n"
yield a2.waitForBytes(len(exp))
self.assertEqual(a2.data, exp)
p1.transport.loseConnection()
p2.transport.loseConnection()
# all data they sent after the handshake should be given to us
exp = b"ok\n"+s1
yield a2.waitForBytes(len(exp))
self.assertEqual(a2.data, exp)
a1.transport.loseConnection()
a2.transport.loseConnection()
@defer.inlineCallbacks
def test_sided_unsided(self):
ep = clientFromString(reactor, self.transit)
a1 = yield connectProtocol(ep, Accumulator())
a2 = yield connectProtocol(ep, Accumulator())
p1 = self.new_protocol()
p2 = self.new_protocol()
token1 = b"\x00"*32
side1 = b"\x01"*8
a1.transport.write(b"please relay " + hexlify(token1) +
b" for side " + hexlify(side1) + b"\n")
a2.transport.write(b"please relay " + hexlify(token1) + b"\n")
p1.dataReceived(handshake(token1, side=side1))
p2.dataReceived(handshake(token1, side=None))
# a correct handshake yields an ack, after which we can send
exp = b"ok\n"
yield a1.waitForBytes(len(exp))
self.assertEqual(a1.data, exp)
s1 = b"data1"
a1.transport.write(s1)
self.assertEqual(p1.transport.value(), exp)
self.assertEqual(p2.transport.value(), exp)
exp = b"ok\n"
yield a2.waitForBytes(len(exp))
self.assertEqual(a2.data, exp)
p1.transport.clear()
p2.transport.clear()
# all data they sent after the handshake should be given to us
exp = b"ok\n"+s1
yield a2.waitForBytes(len(exp))
self.assertEqual(a2.data, exp)
s1 = b"data1"
p1.dataReceived(s1)
self.assertEqual(p2.transport.value(), s1)
a1.transport.loseConnection()
a2.transport.loseConnection()
p1.transport.loseConnection()
p2.transport.loseConnection()
@defer.inlineCallbacks
def test_unsided_sided(self):
ep = clientFromString(reactor, self.transit)
a1 = yield connectProtocol(ep, Accumulator())
a2 = yield connectProtocol(ep, Accumulator())
p1 = self.new_protocol()
p2 = self.new_protocol()
token1 = b"\x00"*32
side1 = b"\x01"*8
a1.transport.write(b"please relay " + hexlify(token1) + b"\n")
a2.transport.write(b"please relay " + hexlify(token1) +
b" for side " + hexlify(side1) + b"\n")
p1.dataReceived(handshake(token1, side=None))
p2.dataReceived(handshake(token1, side=side1))
# a correct handshake yields an ack, after which we can send
exp = b"ok\n"
yield a1.waitForBytes(len(exp))
self.assertEqual(a1.data, exp)
s1 = b"data1"
a1.transport.write(s1)
self.assertEqual(p1.transport.value(), exp)
self.assertEqual(p2.transport.value(), exp)
exp = b"ok\n"
yield a2.waitForBytes(len(exp))
self.assertEqual(a2.data, exp)
p1.transport.clear()
p2.transport.clear()
# all data they sent after the handshake should be given to us
exp = b"ok\n"+s1
yield a2.waitForBytes(len(exp))
self.assertEqual(a2.data, exp)
s1 = b"data1"
p1.dataReceived(s1)
self.assertEqual(p2.transport.value(), s1)
a1.transport.loseConnection()
a2.transport.loseConnection()
p1.transport.loseConnection()
p2.transport.loseConnection()
@defer.inlineCallbacks
def test_both_sided(self):
ep = clientFromString(reactor, self.transit)
a1 = yield connectProtocol(ep, Accumulator())
a2 = yield connectProtocol(ep, Accumulator())
p1 = self.new_protocol()
p2 = self.new_protocol()
token1 = b"\x00"*32
side1 = b"\x01"*8
side2 = b"\x02"*8
a1.transport.write(b"please relay " + hexlify(token1) +
b" for side " + hexlify(side1) + b"\n")
a2.transport.write(b"please relay " + hexlify(token1) +
b" for side " + hexlify(side2) + b"\n")
p1.dataReceived(handshake(token1, side=side1))
p2.dataReceived(handshake(token1, side=side2))
# a correct handshake yields an ack, after which we can send
exp = b"ok\n"
yield a1.waitForBytes(len(exp))
self.assertEqual(a1.data, exp)
s1 = b"data1"
a1.transport.write(s1)
self.assertEqual(p1.transport.value(), exp)
self.assertEqual(p2.transport.value(), exp)
exp = b"ok\n"
yield a2.waitForBytes(len(exp))
self.assertEqual(a2.data, exp)
p1.transport.clear()
p2.transport.clear()
# all data they sent after the handshake should be given to us
exp = b"ok\n"+s1
yield a2.waitForBytes(len(exp))
self.assertEqual(a2.data, exp)
s1 = b"data1"
p1.dataReceived(s1)
self.assertEqual(p2.transport.value(), s1)
a1.transport.loseConnection()
a2.transport.loseConnection()
p1.transport.loseConnection()
p2.transport.loseConnection()
def count(self):
return sum([len(potentials)
for potentials
in self._transit_server._pending_requests.values()])
@defer.inlineCallbacks
def test_ignore_same_side(self):
ep = clientFromString(reactor, self.transit)
a1 = yield connectProtocol(ep, Accumulator())
a2 = yield connectProtocol(ep, Accumulator())
a3 = yield connectProtocol(ep, Accumulator())
disconnects = []
a1._disconnect.addCallback(disconnects.append)
a2._disconnect.addCallback(disconnects.append)
p1 = self.new_protocol()
p2 = self.new_protocol()
p3 = self.new_protocol()
token1 = b"\x00"*32
side1 = b"\x01"*8
a1.transport.write(b"please relay " + hexlify(token1) +
b" for side " + hexlify(side1) + b"\n")
# let that arrive
while self.count() == 0:
yield wait()
a2.transport.write(b"please relay " + hexlify(token1) +
b" for side " + hexlify(side1) + b"\n")
# let that arrive
while self.count() == 1:
yield wait()
p1.dataReceived(handshake(token1, side=side1))
self.assertEqual(self.count(), 1)
p2.dataReceived(handshake(token1, side=side1))
self.assertEqual(self.count(), 2) # same-side connections don't match
# when the second side arrives, the spare first connection should be
# closed
side2 = b"\x02"*8
a3.transport.write(b"please relay " + hexlify(token1) +
b" for side " + hexlify(side2) + b"\n")
# let that arrive
while self.count() != 0:
yield wait()
p3.dataReceived(handshake(token1, side=side2))
self.assertEqual(self.count(), 0)
self.assertEqual(len(self._transit_server._pending_requests), 0)
self.assertEqual(len(self._transit_server._active_connections), 2)
# That will trigger a disconnect on exactly one of (a1 or a2). Wait
# until our client notices it.
while not disconnects:
yield wait()
# the other connection should still be connected
self.assertEqual(sum([int(t.transport.connected) for t in [a1, a2]]), 1)
# That will trigger a disconnect on exactly one of (p1 or p2).
# The other connection should still be connected
self.assertEqual(sum([int(t.transport.connected) for t in [p1, p2]]), 1)
a1.transport.loseConnection()
a2.transport.loseConnection()
a3.transport.loseConnection()
p1.transport.loseConnection()
p2.transport.loseConnection()
p3.transport.loseConnection()
@defer.inlineCallbacks
def test_bad_handshake_old(self):
ep = clientFromString(reactor, self.transit)
a1 = yield connectProtocol(ep, Accumulator())
p1 = self.new_protocol()
token1 = b"\x00"*32
# the server waits for the exact number of bytes in the expected
# handshake message. to trigger "bad handshake", we must match.
a1.transport.write(b"please DELAY " + hexlify(token1) + b"\n")
p1.dataReceived(b"please DELAY " + hexlify(token1) + b"\n")
exp = b"bad handshake\n"
yield a1.waitForBytes(len(exp))
self.assertEqual(a1.data, exp)
self.assertEqual(p1.transport.value(), exp)
p1.transport.loseConnection()
a1.transport.loseConnection()
@defer.inlineCallbacks
def test_bad_handshake_old_slow(self):
ep = clientFromString(reactor, self.transit)
a1 = yield connectProtocol(ep, Accumulator())
p1 = self.new_protocol()
a1.transport.write(b"please DELAY ")
p1.dataReceived(b"please DELAY ")
# As in test_impatience_new_slow, the current state machine has code
# that can only be reached if we insert a stall here, so dataReceived
# gets called twice. Hopefully we can delete this test once
# dataReceived is refactored to remove that state.
d = defer.Deferred()
reactor.callLater(0.1, d.callback, None)
yield d
token1 = b"\x00"*32
# the server waits for the exact number of bytes in the expected
# handshake message. to trigger "bad handshake", we must match.
a1.transport.write(hexlify(token1) + b"\n")
p1.dataReceived(hexlify(token1) + b"\n")
exp = b"bad handshake\n"
yield a1.waitForBytes(len(exp))
self.assertEqual(a1.data, exp)
self.assertEqual(p1.transport.value(), exp)
a1.transport.loseConnection()
p1.transport.loseConnection()
@defer.inlineCallbacks
def test_bad_handshake_new(self):
ep = clientFromString(reactor, self.transit)
a1 = yield connectProtocol(ep, Accumulator())
p1 = self.new_protocol()
token1 = b"\x00"*32
side1 = b"\x01"*8
# the server waits for the exact number of bytes in the expected
# handshake message. to trigger "bad handshake", we must match.
a1.transport.write(b"please DELAY " + hexlify(token1) +
p1.dataReceived(b"please DELAY " + hexlify(token1) +
b" for side " + hexlify(side1) + b"\n")
exp = b"bad handshake\n"
yield a1.waitForBytes(len(exp))
self.assertEqual(a1.data, exp)
self.assertEqual(p1.transport.value(), exp)
a1.transport.loseConnection()
p1.transport.loseConnection()
@defer.inlineCallbacks
def test_binary_handshake(self):
ep = clientFromString(reactor, self.transit)
a1 = yield connectProtocol(ep, Accumulator())
p1 = self.new_protocol()
binary_bad_handshake = b"\x00\x01\xe0\x0f\n\xff"
# the embedded \n makes the server trigger early, before the full
@ -325,50 +238,41 @@ class _Transit:
# UnicodeDecodeError when it tried to coerce the incoming handshake
# to unicode, due to the ("\n" in buf) check. This was fixed to use
# (b"\n" in buf). This exercises the old failure.
a1.transport.write(binary_bad_handshake)
p1.dataReceived(binary_bad_handshake)
exp = b"bad handshake\n"
yield a1.waitForBytes(len(exp))
self.assertEqual(a1.data, exp)
self.assertEqual(p1.transport.value(), exp)
a1.transport.loseConnection()
p1.transport.loseConnection()
@defer.inlineCallbacks
def test_impatience_old(self):
ep = clientFromString(reactor, self.transit)
a1 = yield connectProtocol(ep, Accumulator())
p1 = self.new_protocol()
token1 = b"\x00"*32
# sending too many bytes is impatience.
a1.transport.write(b"please relay " + hexlify(token1) + b"\nNOWNOWNOW")
p1.dataReceived(b"please relay " + hexlify(token1) + b"\nNOWNOWNOW")
exp = b"impatient\n"
yield a1.waitForBytes(len(exp))
self.assertEqual(a1.data, exp)
self.assertEqual(p1.transport.value(), exp)
a1.transport.loseConnection()
p1.transport.loseConnection()
@defer.inlineCallbacks
def test_impatience_new(self):
ep = clientFromString(reactor, self.transit)
a1 = yield connectProtocol(ep, Accumulator())
p1 = self.new_protocol()
token1 = b"\x00"*32
side1 = b"\x01"*8
# sending too many bytes is impatience.
a1.transport.write(b"please relay " + hexlify(token1) +
p1.dataReceived(b"please relay " + hexlify(token1) +
b" for side " + hexlify(side1) + b"\nNOWNOWNOW")
exp = b"impatient\n"
yield a1.waitForBytes(len(exp))
self.assertEqual(a1.data, exp)
self.assertEqual(p1.transport.value(), exp)
a1.transport.loseConnection()
p1.transport.loseConnection()
@defer.inlineCallbacks
def test_impatience_new_slow(self):
ep = clientFromString(reactor, self.transit)
a1 = yield connectProtocol(ep, Accumulator())
p1 = self.new_protocol()
# For full coverage, we need dataReceived to see a particular framing
# of these two pieces of data, and ITCPTransport doesn't have flush()
# (which probably wouldn't work anyways). For now, force a 100ms
@ -381,20 +285,16 @@ class _Transit:
token1 = b"\x00"*32
side1 = b"\x01"*8
# sending too many bytes is impatience.
a1.transport.write(b"please relay " + hexlify(token1) +
p1.dataReceived(b"please relay " + hexlify(token1) +
b" for side " + hexlify(side1) + b"\n")
d = defer.Deferred()
reactor.callLater(0.1, d.callback, None)
yield d
a1.transport.write(b"NOWNOWNOW")
p1.dataReceived(b"NOWNOWNOW")
exp = b"impatient\n"
yield a1.waitForBytes(len(exp))
self.assertEqual(a1.data, exp)
self.assertEqual(p1.transport.value(), exp)
a1.transport.loseConnection()
p1.transport.loseConnection()
@defer.inlineCallbacks
def test_short_handshake(self):
@ -420,9 +320,8 @@ class TransitWithoutLogs(_Transit, ServerBase, unittest.TestCase):
log_requests = False
class Usage(ServerBase, unittest.TestCase):
@defer.inlineCallbacks
def setUp(self):
yield super(Usage, self).setUp()
super(Usage, self).setUp()
self._usage = []
def record(started, result, total_bytes, total_time, waiting_time):
self._usage.append((started, result, total_bytes,
@ -470,130 +369,83 @@ class Usage(ServerBase, unittest.TestCase):
@defer.inlineCallbacks
def test_errory(self):
ep = clientFromString(reactor, self.transit)
a1 = yield connectProtocol(ep, Accumulator())
p1 = self.new_protocol()
a1.transport.write(b"this is a very bad handshake\n")
p1.dataReceived(b"this is a very bad handshake\n")
# that will log the "errory" usage event, then drop the connection
yield a1._disconnect
p1.transport.loseConnection()
self.assertEqual(len(self._usage), 1, self._usage)
(started, result, total_bytes, total_time, waiting_time) = self._usage[0]
self.assertEqual(result, "errory", self._usage)
@defer.inlineCallbacks
def test_lonely(self):
ep = clientFromString(reactor, self.transit)
a1 = yield connectProtocol(ep, Accumulator())
p1 = self.new_protocol()
token1 = b"\x00"*32
side1 = b"\x01"*8
a1.transport.write(b"please relay " + hexlify(token1) +
b" for side " + hexlify(side1) + b"\n")
while not self._transit_server._pending_requests:
yield wait() # wait for the server to see the connection
p1.dataReceived(handshake(token1, side=side1))
# now we disconnect before the peer connects
a1.transport.loseConnection()
yield a1._disconnect
while self._transit_server._pending_requests:
yield wait() # wait for the server to see the disconnect too
p1.transport.loseConnection()
self.assertEqual(len(self._usage), 1, self._usage)
(started, result, total_bytes, total_time, waiting_time) = self._usage[0]
self.assertEqual(result, "lonely", self._usage)
self.assertIdentical(waiting_time, None)
@defer.inlineCallbacks
def test_one_happy_one_jilted(self):
ep = clientFromString(reactor, self.transit)
a1 = yield connectProtocol(ep, Accumulator())
a2 = yield connectProtocol(ep, Accumulator())
p1 = self.new_protocol()
p2 = self.new_protocol()
token1 = b"\x00"*32
side1 = b"\x01"*8
side2 = b"\x02"*8
a1.transport.write(b"please relay " + hexlify(token1) +
b" for side " + hexlify(side1) + b"\n")
while not self._transit_server._pending_requests:
yield wait() # make sure a1 connects first
a2.transport.write(b"please relay " + hexlify(token1) +
b" for side " + hexlify(side2) + b"\n")
while not self._transit_server._active_connections:
yield wait() # wait for the server to see the connection
self.assertEqual(len(self._transit_server._pending_requests), 0)
self.assertEqual(self._usage, []) # no events yet
a1.transport.write(b"\x00" * 13)
yield a2.waitForBytes(13)
a2.transport.write(b"\xff" * 7)
yield a1.waitForBytes(7)
p1.dataReceived(handshake(token1, side=side1))
p2.dataReceived(handshake(token1, side=side2))
self.assertEqual(self._usage, []) # no events yet
p1.dataReceived(b"\x00" * 13)
p2.dataReceived(b"\xff" * 7)
p1.transport.loseConnection()
a1.transport.loseConnection()
yield a1._disconnect
while self._transit_server._active_connections:
yield wait()
yield a2._disconnect
self.assertEqual(len(self._usage), 1, self._usage)
(started, result, total_bytes, total_time, waiting_time) = self._usage[0]
self.assertEqual(result, "happy", self._usage)
self.assertEqual(total_bytes, 20)
self.assertNotIdentical(waiting_time, None)
@defer.inlineCallbacks
def test_redundant(self):
ep = clientFromString(reactor, self.transit)
a1a = yield connectProtocol(ep, Accumulator())
a1b = yield connectProtocol(ep, Accumulator())
a1c = yield connectProtocol(ep, Accumulator())
a2 = yield connectProtocol(ep, Accumulator())
p1a = self.new_protocol()
p1b = self.new_protocol()
p1c = self.new_protocol()
p2 = self.new_protocol()
token1 = b"\x00"*32
side1 = b"\x01"*8
side2 = b"\x02"*8
a1a.transport.write(b"please relay " + hexlify(token1) +
b" for side " + hexlify(side1) + b"\n")
def count_requests():
return sum([len(v)
for v in self._transit_server._pending_requests.values()])
while count_requests() < 1:
yield wait()
a1b.transport.write(b"please relay " + hexlify(token1) +
b" for side " + hexlify(side1) + b"\n")
while count_requests() < 2:
yield wait()
p1a.dataReceived(handshake(token1, side=side1))
p1b.dataReceived(handshake(token1, side=side1))
# connect and disconnect a third client (for side1) to exercise the
# code that removes a pending connection without removing the entire
# token
a1c.transport.write(b"please relay " + hexlify(token1) +
b" for side " + hexlify(side1) + b"\n")
while count_requests() < 3:
yield wait()
a1c.transport.loseConnection()
yield a1c._disconnect
while count_requests() > 2:
yield wait()
p1c.dataReceived(handshake(token1, side=side1))
p1c.transport.loseConnection()
self.assertEqual(len(self._usage), 1, self._usage)
(started, result, total_bytes, total_time, waiting_time) = self._usage[0]
self.assertEqual(result, "lonely", self._usage)
a2.transport.write(b"please relay " + hexlify(token1) +
b" for side " + hexlify(side2) + b"\n")
# this will claim one of (a1a, a1b), and close the other as redundant
while not self._transit_server._active_connections:
yield wait() # wait for the server to see the connection
self.assertEqual(count_requests(), 0)
p2.dataReceived(handshake(token1, side=side2))
self.assertEqual(len(self._transit_server._pending_requests), 0)
self.assertEqual(len(self._usage), 2, self._usage)
(started, result, total_bytes, total_time, waiting_time) = self._usage[1]
self.assertEqual(result, "redundant", self._usage)
# one of the these is unecessary, but probably harmless
a1a.transport.loseConnection()
a1b.transport.loseConnection()
yield a1a._disconnect
yield a1b._disconnect
while self._transit_server._active_connections:
yield wait()
yield a2._disconnect
p1a.transport.loseConnection()
p1b.transport.loseConnection()
self.assertEqual(len(self._usage), 3, self._usage)
(started, result, total_bytes, total_time, waiting_time) = self._usage[2]
self.assertEqual(result, "happy", self._usage)

View File

@ -52,7 +52,10 @@ class TransitConnection(LineReceiver):
def connectionMade(self):
self._started = time.time()
self._log_requests = self.factory._log_requests
try:
self.transport.setTcpKeepAlive(True)
except AttributeError:
pass
def lineReceived(self, line):
# old: "please relay {64}\n"
@ -257,14 +260,14 @@ class Transit(protocol.ServerFactory):
# drop and stop tracking the rest
potentials.remove(old)
for (_, leftover_tc) in potentials:
for (_, leftover_tc) in potentials.copy():
# Don't record this as errory. It's just a spare connection
# from the same side as a connection that got used. This
# can happen if the connection hint contains multiple
# addresses (we don't currently support those, but it'd
# probably be useful in the future).
leftover_tc.disconnect_redundant()
self._pending_requests.pop(token)
self._pending_requests.pop(token, None)
# glue the two ends together
self._active_connections.add(new_tc)