Use StringTransportWithDisconnection for transit server tests.

Replace the use of TCP in the test suite with Twisted's
StringTransport, specifically StringTransportWithDisconnection which
allows us to trigger a disconnect event on the server side during testing.

The `dataReceived` method on the server is now called directly, and any
effects will be realised immediately.
Responses are available to the test client using the `value()` method of
the transport objects, and the buffer can be cleared using `clear()`.

This allows all asynchronous behaviour to be removed from the transit
server test suite.
Furthermore, as we never have to wait for the server, tests no longer
hang if they fail: the errors are encountered immediately.
This commit is contained in:
Joe Harrison 2020-03-08 14:30:23 +00:00 committed by Brian Warner
parent ac7415a4d0
commit 45824ca5d6
3 changed files with 164 additions and 311 deletions

View File

@ -1,30 +1,28 @@
#from __future__ import unicode_literals from twisted.test import proto_helpers
from twisted.internet import reactor, endpoints
from twisted.internet.defer import inlineCallbacks
from ..transit_server import Transit from ..transit_server import Transit
class ServerBase: class ServerBase:
log_requests = False log_requests = False
@inlineCallbacks
def setUp(self): def setUp(self):
self._lp = None self._lp = None
if self.log_requests: if self.log_requests:
blur_usage = None blur_usage = None
else: else:
blur_usage = 60.0 blur_usage = 60.0
yield self._setup_relay(blur_usage=blur_usage) self._setup_relay(blur_usage=blur_usage)
self._transit_server._debug_log = self.log_requests self._transit_server._debug_log = self.log_requests
@inlineCallbacks
def _setup_relay(self, blur_usage=None, log_file=None, usage_db=None): def _setup_relay(self, blur_usage=None, log_file=None, usage_db=None):
ep = endpoints.TCP4ServerEndpoint(reactor, 0, interface="127.0.0.1")
self._transit_server = Transit(blur_usage=blur_usage, self._transit_server = Transit(blur_usage=blur_usage,
log_file=log_file, usage_db=usage_db) log_file=log_file, usage_db=usage_db)
self._lp = yield ep.listen(self._transit_server)
addr = self._lp.getHost() def new_protocol(self):
# ws://127.0.0.1:%d/wormhole-relay/ws protocol = self._transit_server.buildProtocol(('127.0.0.1', 0))
self.transit = u"tcp:127.0.0.1:%d" % addr.port transport = proto_helpers.StringTransportWithDisconnection()
protocol.makeConnection(transport)
transport.protocol = protocol
return protocol
def tearDown(self): def tearDown(self):
if self._lp: if self._lp:

View File

@ -1,42 +1,28 @@
from __future__ import print_function, unicode_literals from __future__ import print_function, unicode_literals
from binascii import hexlify from binascii import hexlify
from twisted.trial import unittest from twisted.trial import unittest
from twisted.internet import protocol, reactor, defer from twisted.internet import reactor, defer
from twisted.internet.endpoints import clientFromString, connectProtocol
from .common import ServerBase from .common import ServerBase
from .. import transit_server from .. import transit_server
class Accumulator(protocol.Protocol):
def __init__(self):
self.data = b""
self.count = 0
self._wait = None
self._disconnect = defer.Deferred()
def waitForBytes(self, more):
assert self._wait is None
self.count = more
self._wait = defer.Deferred()
self._check_done()
return self._wait
def dataReceived(self, data):
self.data = self.data + data
self._check_done()
def _check_done(self):
if self._wait and len(self.data) >= self.count:
d = self._wait
self._wait = None
d.callback(self)
def connectionLost(self, why):
if self._wait:
self._wait.errback(RuntimeError("closed"))
self._disconnect.callback(None)
def wait(): def wait():
d = defer.Deferred() d = defer.Deferred()
reactor.callLater(0.001, d.callback, None) reactor.callLater(0.001, d.callback, None)
return d return d
def handshake(token, side=None):
hs = b"please relay " + hexlify(token)
if side is not None:
hs += b" for side " + hexlify(side)
hs += b"\n"
return hs
class _Transit: class _Transit:
def count(self):
return sum([len(potentials)
for potentials
in self._transit_server._pending_requests.values()])
def test_blur_size(self): def test_blur_size(self):
blur = transit_server.blur_size blur = transit_server.blur_size
self.failUnlessEqual(blur(0), 0) self.failUnlessEqual(blur(0), 0)
@ -55,268 +41,195 @@ class _Transit:
self.failUnlessEqual(blur(1100e6), 1100e6) self.failUnlessEqual(blur(1100e6), 1100e6)
self.failUnlessEqual(blur(1150e6), 1200e6) self.failUnlessEqual(blur(1150e6), 1200e6)
@defer.inlineCallbacks
def test_register(self): def test_register(self):
ep = clientFromString(reactor, self.transit) p1 = self.new_protocol()
a1 = yield connectProtocol(ep, Accumulator())
token1 = b"\x00"*32 token1 = b"\x00"*32
side1 = b"\x01"*8 side1 = b"\x01"*8
a1.transport.write(b"please relay " + hexlify(token1) +
b" for side " + hexlify(side1) + b"\n")
# let that arrive p1.dataReceived(handshake(token1, side1))
while self.count() == 0:
yield wait()
self.assertEqual(self.count(), 1) self.assertEqual(self.count(), 1)
a1.transport.loseConnection() p1.transport.loseConnection()
# let that get removed
while self.count() > 0:
yield wait()
self.assertEqual(self.count(), 0) self.assertEqual(self.count(), 0)
# the token should be removed too # the token should be removed too
self.assertEqual(len(self._transit_server._pending_requests), 0) self.assertEqual(len(self._transit_server._pending_requests), 0)
@defer.inlineCallbacks
def test_both_unsided(self): def test_both_unsided(self):
ep = clientFromString(reactor, self.transit) p1 = self.new_protocol()
a1 = yield connectProtocol(ep, Accumulator()) p2 = self.new_protocol()
a2 = yield connectProtocol(ep, Accumulator())
token1 = b"\x00"*32 token1 = b"\x00"*32
a1.transport.write(b"please relay " + hexlify(token1) + b"\n") p1.dataReceived(handshake(token1, side=None))
a2.transport.write(b"please relay " + hexlify(token1) + b"\n") p2.dataReceived(handshake(token1, side=None))
# a correct handshake yields an ack, after which we can send # a correct handshake yields an ack, after which we can send
exp = b"ok\n" exp = b"ok\n"
yield a1.waitForBytes(len(exp)) self.assertEqual(p1.transport.value(), exp)
self.assertEqual(a1.data, exp) self.assertEqual(p2.transport.value(), exp)
p1.transport.clear()
p2.transport.clear()
s1 = b"data1" s1 = b"data1"
a1.transport.write(s1) p1.dataReceived(s1)
self.assertEqual(p2.transport.value(), s1)
exp = b"ok\n" p1.transport.loseConnection()
yield a2.waitForBytes(len(exp)) p2.transport.loseConnection()
self.assertEqual(a2.data, exp)
# all data they sent after the handshake should be given to us
exp = b"ok\n"+s1
yield a2.waitForBytes(len(exp))
self.assertEqual(a2.data, exp)
a1.transport.loseConnection()
a2.transport.loseConnection()
@defer.inlineCallbacks
def test_sided_unsided(self): def test_sided_unsided(self):
ep = clientFromString(reactor, self.transit) p1 = self.new_protocol()
a1 = yield connectProtocol(ep, Accumulator()) p2 = self.new_protocol()
a2 = yield connectProtocol(ep, Accumulator())
token1 = b"\x00"*32 token1 = b"\x00"*32
side1 = b"\x01"*8 side1 = b"\x01"*8
a1.transport.write(b"please relay " + hexlify(token1) + p1.dataReceived(handshake(token1, side=side1))
b" for side " + hexlify(side1) + b"\n") p2.dataReceived(handshake(token1, side=None))
a2.transport.write(b"please relay " + hexlify(token1) + b"\n")
# a correct handshake yields an ack, after which we can send # a correct handshake yields an ack, after which we can send
exp = b"ok\n" exp = b"ok\n"
yield a1.waitForBytes(len(exp)) self.assertEqual(p1.transport.value(), exp)
self.assertEqual(a1.data, exp) self.assertEqual(p2.transport.value(), exp)
s1 = b"data1"
a1.transport.write(s1)
exp = b"ok\n" p1.transport.clear()
yield a2.waitForBytes(len(exp)) p2.transport.clear()
self.assertEqual(a2.data, exp)
# all data they sent after the handshake should be given to us # all data they sent after the handshake should be given to us
exp = b"ok\n"+s1 s1 = b"data1"
yield a2.waitForBytes(len(exp)) p1.dataReceived(s1)
self.assertEqual(a2.data, exp) self.assertEqual(p2.transport.value(), s1)
a1.transport.loseConnection() p1.transport.loseConnection()
a2.transport.loseConnection() p2.transport.loseConnection()
@defer.inlineCallbacks
def test_unsided_sided(self): def test_unsided_sided(self):
ep = clientFromString(reactor, self.transit) p1 = self.new_protocol()
a1 = yield connectProtocol(ep, Accumulator()) p2 = self.new_protocol()
a2 = yield connectProtocol(ep, Accumulator())
token1 = b"\x00"*32 token1 = b"\x00"*32
side1 = b"\x01"*8 side1 = b"\x01"*8
a1.transport.write(b"please relay " + hexlify(token1) + b"\n") p1.dataReceived(handshake(token1, side=None))
a2.transport.write(b"please relay " + hexlify(token1) + p2.dataReceived(handshake(token1, side=side1))
b" for side " + hexlify(side1) + b"\n")
# a correct handshake yields an ack, after which we can send # a correct handshake yields an ack, after which we can send
exp = b"ok\n" exp = b"ok\n"
yield a1.waitForBytes(len(exp)) self.assertEqual(p1.transport.value(), exp)
self.assertEqual(a1.data, exp) self.assertEqual(p2.transport.value(), exp)
s1 = b"data1"
a1.transport.write(s1)
exp = b"ok\n" p1.transport.clear()
yield a2.waitForBytes(len(exp)) p2.transport.clear()
self.assertEqual(a2.data, exp)
# all data they sent after the handshake should be given to us # all data they sent after the handshake should be given to us
exp = b"ok\n"+s1 s1 = b"data1"
yield a2.waitForBytes(len(exp)) p1.dataReceived(s1)
self.assertEqual(a2.data, exp) self.assertEqual(p2.transport.value(), s1)
a1.transport.loseConnection() p1.transport.loseConnection()
a2.transport.loseConnection() p2.transport.loseConnection()
@defer.inlineCallbacks
def test_both_sided(self): def test_both_sided(self):
ep = clientFromString(reactor, self.transit) p1 = self.new_protocol()
a1 = yield connectProtocol(ep, Accumulator()) p2 = self.new_protocol()
a2 = yield connectProtocol(ep, Accumulator())
token1 = b"\x00"*32 token1 = b"\x00"*32
side1 = b"\x01"*8 side1 = b"\x01"*8
side2 = b"\x02"*8 side2 = b"\x02"*8
a1.transport.write(b"please relay " + hexlify(token1) + p1.dataReceived(handshake(token1, side=side1))
b" for side " + hexlify(side1) + b"\n") p2.dataReceived(handshake(token1, side=side2))
a2.transport.write(b"please relay " + hexlify(token1) +
b" for side " + hexlify(side2) + b"\n")
# a correct handshake yields an ack, after which we can send # a correct handshake yields an ack, after which we can send
exp = b"ok\n" exp = b"ok\n"
yield a1.waitForBytes(len(exp)) self.assertEqual(p1.transport.value(), exp)
self.assertEqual(a1.data, exp) self.assertEqual(p2.transport.value(), exp)
s1 = b"data1"
a1.transport.write(s1)
exp = b"ok\n" p1.transport.clear()
yield a2.waitForBytes(len(exp)) p2.transport.clear()
self.assertEqual(a2.data, exp)
# all data they sent after the handshake should be given to us # all data they sent after the handshake should be given to us
exp = b"ok\n"+s1 s1 = b"data1"
yield a2.waitForBytes(len(exp)) p1.dataReceived(s1)
self.assertEqual(a2.data, exp) self.assertEqual(p2.transport.value(), s1)
a1.transport.loseConnection() p1.transport.loseConnection()
a2.transport.loseConnection() p2.transport.loseConnection()
def count(self):
return sum([len(potentials)
for potentials
in self._transit_server._pending_requests.values()])
@defer.inlineCallbacks
def test_ignore_same_side(self): def test_ignore_same_side(self):
ep = clientFromString(reactor, self.transit) p1 = self.new_protocol()
a1 = yield connectProtocol(ep, Accumulator()) p2 = self.new_protocol()
a2 = yield connectProtocol(ep, Accumulator()) p3 = self.new_protocol()
a3 = yield connectProtocol(ep, Accumulator())
disconnects = []
a1._disconnect.addCallback(disconnects.append)
a2._disconnect.addCallback(disconnects.append)
token1 = b"\x00"*32 token1 = b"\x00"*32
side1 = b"\x01"*8 side1 = b"\x01"*8
a1.transport.write(b"please relay " + hexlify(token1) +
b" for side " + hexlify(side1) + b"\n") p1.dataReceived(handshake(token1, side=side1))
# let that arrive self.assertEqual(self.count(), 1)
while self.count() == 0:
yield wait() p2.dataReceived(handshake(token1, side=side1))
a2.transport.write(b"please relay " + hexlify(token1) +
b" for side " + hexlify(side1) + b"\n")
# let that arrive
while self.count() == 1:
yield wait()
self.assertEqual(self.count(), 2) # same-side connections don't match self.assertEqual(self.count(), 2) # same-side connections don't match
# when the second side arrives, the spare first connection should be # when the second side arrives, the spare first connection should be
# closed # closed
side2 = b"\x02"*8 side2 = b"\x02"*8
a3.transport.write(b"please relay " + hexlify(token1) + p3.dataReceived(handshake(token1, side=side2))
b" for side " + hexlify(side2) + b"\n") self.assertEqual(self.count(), 0)
# let that arrive
while self.count() != 0:
yield wait()
self.assertEqual(len(self._transit_server._pending_requests), 0) self.assertEqual(len(self._transit_server._pending_requests), 0)
self.assertEqual(len(self._transit_server._active_connections), 2) self.assertEqual(len(self._transit_server._active_connections), 2)
# That will trigger a disconnect on exactly one of (a1 or a2). Wait # That will trigger a disconnect on exactly one of (p1 or p2).
# until our client notices it. # The other connection should still be connected
while not disconnects: self.assertEqual(sum([int(t.transport.connected) for t in [p1, p2]]), 1)
yield wait()
# the other connection should still be connected
self.assertEqual(sum([int(t.transport.connected) for t in [a1, a2]]), 1)
a1.transport.loseConnection() p1.transport.loseConnection()
a2.transport.loseConnection() p2.transport.loseConnection()
a3.transport.loseConnection() p3.transport.loseConnection()
@defer.inlineCallbacks
def test_bad_handshake_old(self): def test_bad_handshake_old(self):
ep = clientFromString(reactor, self.transit) p1 = self.new_protocol()
a1 = yield connectProtocol(ep, Accumulator())
token1 = b"\x00"*32 token1 = b"\x00"*32
# the server waits for the exact number of bytes in the expected p1.dataReceived(b"please DELAY " + hexlify(token1) + b"\n")
# handshake message. to trigger "bad handshake", we must match.
a1.transport.write(b"please DELAY " + hexlify(token1) + b"\n")
exp = b"bad handshake\n" exp = b"bad handshake\n"
yield a1.waitForBytes(len(exp)) self.assertEqual(p1.transport.value(), exp)
self.assertEqual(a1.data, exp) p1.transport.loseConnection()
a1.transport.loseConnection()
@defer.inlineCallbacks
def test_bad_handshake_old_slow(self): def test_bad_handshake_old_slow(self):
ep = clientFromString(reactor, self.transit) p1 = self.new_protocol()
a1 = yield connectProtocol(ep, Accumulator())
a1.transport.write(b"please DELAY ") p1.dataReceived(b"please DELAY ")
# As in test_impatience_new_slow, the current state machine has code # As in test_impatience_new_slow, the current state machine has code
# that can only be reached if we insert a stall here, so dataReceived # that can only be reached if we insert a stall here, so dataReceived
# gets called twice. Hopefully we can delete this test once # gets called twice. Hopefully we can delete this test once
# dataReceived is refactored to remove that state. # dataReceived is refactored to remove that state.
d = defer.Deferred()
reactor.callLater(0.1, d.callback, None)
yield d
token1 = b"\x00"*32 token1 = b"\x00"*32
# the server waits for the exact number of bytes in the expected # the server waits for the exact number of bytes in the expected
# handshake message. to trigger "bad handshake", we must match. # handshake message. to trigger "bad handshake", we must match.
a1.transport.write(hexlify(token1) + b"\n") p1.dataReceived(hexlify(token1) + b"\n")
exp = b"bad handshake\n" exp = b"bad handshake\n"
yield a1.waitForBytes(len(exp)) self.assertEqual(p1.transport.value(), exp)
self.assertEqual(a1.data, exp)
a1.transport.loseConnection() p1.transport.loseConnection()
@defer.inlineCallbacks
def test_bad_handshake_new(self): def test_bad_handshake_new(self):
ep = clientFromString(reactor, self.transit) p1 = self.new_protocol()
a1 = yield connectProtocol(ep, Accumulator())
token1 = b"\x00"*32 token1 = b"\x00"*32
side1 = b"\x01"*8 side1 = b"\x01"*8
# the server waits for the exact number of bytes in the expected # the server waits for the exact number of bytes in the expected
# handshake message. to trigger "bad handshake", we must match. # handshake message. to trigger "bad handshake", we must match.
a1.transport.write(b"please DELAY " + hexlify(token1) + p1.dataReceived(b"please DELAY " + hexlify(token1) +
b" for side " + hexlify(side1) + b"\n") b" for side " + hexlify(side1) + b"\n")
exp = b"bad handshake\n" exp = b"bad handshake\n"
yield a1.waitForBytes(len(exp)) self.assertEqual(p1.transport.value(), exp)
self.assertEqual(a1.data, exp)
a1.transport.loseConnection() p1.transport.loseConnection()
@defer.inlineCallbacks
def test_binary_handshake(self): def test_binary_handshake(self):
ep = clientFromString(reactor, self.transit) p1 = self.new_protocol()
a1 = yield connectProtocol(ep, Accumulator())
binary_bad_handshake = b"\x00\x01\xe0\x0f\n\xff" binary_bad_handshake = b"\x00\x01\xe0\x0f\n\xff"
# the embedded \n makes the server trigger early, before the full # the embedded \n makes the server trigger early, before the full
@ -325,50 +238,41 @@ class _Transit:
# UnicodeDecodeError when it tried to coerce the incoming handshake # UnicodeDecodeError when it tried to coerce the incoming handshake
# to unicode, due to the ("\n" in buf) check. This was fixed to use # to unicode, due to the ("\n" in buf) check. This was fixed to use
# (b"\n" in buf). This exercises the old failure. # (b"\n" in buf). This exercises the old failure.
a1.transport.write(binary_bad_handshake) p1.dataReceived(binary_bad_handshake)
exp = b"bad handshake\n" exp = b"bad handshake\n"
yield a1.waitForBytes(len(exp)) self.assertEqual(p1.transport.value(), exp)
self.assertEqual(a1.data, exp)
a1.transport.loseConnection() p1.transport.loseConnection()
@defer.inlineCallbacks
def test_impatience_old(self): def test_impatience_old(self):
ep = clientFromString(reactor, self.transit) p1 = self.new_protocol()
a1 = yield connectProtocol(ep, Accumulator())
token1 = b"\x00"*32 token1 = b"\x00"*32
# sending too many bytes is impatience. # sending too many bytes is impatience.
a1.transport.write(b"please relay " + hexlify(token1) + b"\nNOWNOWNOW") p1.dataReceived(b"please relay " + hexlify(token1) + b"\nNOWNOWNOW")
exp = b"impatient\n" exp = b"impatient\n"
yield a1.waitForBytes(len(exp)) self.assertEqual(p1.transport.value(), exp)
self.assertEqual(a1.data, exp)
a1.transport.loseConnection() p1.transport.loseConnection()
@defer.inlineCallbacks
def test_impatience_new(self): def test_impatience_new(self):
ep = clientFromString(reactor, self.transit) p1 = self.new_protocol()
a1 = yield connectProtocol(ep, Accumulator())
token1 = b"\x00"*32 token1 = b"\x00"*32
side1 = b"\x01"*8 side1 = b"\x01"*8
# sending too many bytes is impatience. # sending too many bytes is impatience.
a1.transport.write(b"please relay " + hexlify(token1) + p1.dataReceived(b"please relay " + hexlify(token1) +
b" for side " + hexlify(side1) + b"\nNOWNOWNOW") b" for side " + hexlify(side1) + b"\nNOWNOWNOW")
exp = b"impatient\n" exp = b"impatient\n"
yield a1.waitForBytes(len(exp)) self.assertEqual(p1.transport.value(), exp)
self.assertEqual(a1.data, exp)
a1.transport.loseConnection() p1.transport.loseConnection()
@defer.inlineCallbacks
def test_impatience_new_slow(self): def test_impatience_new_slow(self):
ep = clientFromString(reactor, self.transit) p1 = self.new_protocol()
a1 = yield connectProtocol(ep, Accumulator())
# For full coverage, we need dataReceived to see a particular framing # For full coverage, we need dataReceived to see a particular framing
# of these two pieces of data, and ITCPTransport doesn't have flush() # of these two pieces of data, and ITCPTransport doesn't have flush()
# (which probably wouldn't work anyways). For now, force a 100ms # (which probably wouldn't work anyways). For now, force a 100ms
@ -381,20 +285,16 @@ class _Transit:
token1 = b"\x00"*32 token1 = b"\x00"*32
side1 = b"\x01"*8 side1 = b"\x01"*8
# sending too many bytes is impatience. # sending too many bytes is impatience.
a1.transport.write(b"please relay " + hexlify(token1) + p1.dataReceived(b"please relay " + hexlify(token1) +
b" for side " + hexlify(side1) + b"\n") b" for side " + hexlify(side1) + b"\n")
d = defer.Deferred()
reactor.callLater(0.1, d.callback, None)
yield d
a1.transport.write(b"NOWNOWNOW") p1.dataReceived(b"NOWNOWNOW")
exp = b"impatient\n" exp = b"impatient\n"
yield a1.waitForBytes(len(exp)) self.assertEqual(p1.transport.value(), exp)
self.assertEqual(a1.data, exp)
a1.transport.loseConnection() p1.transport.loseConnection()
@defer.inlineCallbacks @defer.inlineCallbacks
def test_short_handshake(self): def test_short_handshake(self):
@ -420,9 +320,8 @@ class TransitWithoutLogs(_Transit, ServerBase, unittest.TestCase):
log_requests = False log_requests = False
class Usage(ServerBase, unittest.TestCase): class Usage(ServerBase, unittest.TestCase):
@defer.inlineCallbacks
def setUp(self): def setUp(self):
yield super(Usage, self).setUp() super(Usage, self).setUp()
self._usage = [] self._usage = []
def record(started, result, total_bytes, total_time, waiting_time): def record(started, result, total_bytes, total_time, waiting_time):
self._usage.append((started, result, total_bytes, self._usage.append((started, result, total_bytes,
@ -470,130 +369,83 @@ class Usage(ServerBase, unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_errory(self): def test_errory(self):
ep = clientFromString(reactor, self.transit) p1 = self.new_protocol()
a1 = yield connectProtocol(ep, Accumulator())
a1.transport.write(b"this is a very bad handshake\n") p1.dataReceived(b"this is a very bad handshake\n")
# that will log the "errory" usage event, then drop the connection # that will log the "errory" usage event, then drop the connection
yield a1._disconnect p1.transport.loseConnection()
self.assertEqual(len(self._usage), 1, self._usage) self.assertEqual(len(self._usage), 1, self._usage)
(started, result, total_bytes, total_time, waiting_time) = self._usage[0] (started, result, total_bytes, total_time, waiting_time) = self._usage[0]
self.assertEqual(result, "errory", self._usage) self.assertEqual(result, "errory", self._usage)
@defer.inlineCallbacks
def test_lonely(self): def test_lonely(self):
ep = clientFromString(reactor, self.transit) p1 = self.new_protocol()
a1 = yield connectProtocol(ep, Accumulator())
token1 = b"\x00"*32 token1 = b"\x00"*32
side1 = b"\x01"*8 side1 = b"\x01"*8
a1.transport.write(b"please relay " + hexlify(token1) + p1.dataReceived(handshake(token1, side=side1))
b" for side " + hexlify(side1) + b"\n")
while not self._transit_server._pending_requests:
yield wait() # wait for the server to see the connection
# now we disconnect before the peer connects # now we disconnect before the peer connects
a1.transport.loseConnection() p1.transport.loseConnection()
yield a1._disconnect
while self._transit_server._pending_requests:
yield wait() # wait for the server to see the disconnect too
self.assertEqual(len(self._usage), 1, self._usage) self.assertEqual(len(self._usage), 1, self._usage)
(started, result, total_bytes, total_time, waiting_time) = self._usage[0] (started, result, total_bytes, total_time, waiting_time) = self._usage[0]
self.assertEqual(result, "lonely", self._usage) self.assertEqual(result, "lonely", self._usage)
self.assertIdentical(waiting_time, None) self.assertIdentical(waiting_time, None)
@defer.inlineCallbacks
def test_one_happy_one_jilted(self): def test_one_happy_one_jilted(self):
ep = clientFromString(reactor, self.transit) p1 = self.new_protocol()
a1 = yield connectProtocol(ep, Accumulator()) p2 = self.new_protocol()
a2 = yield connectProtocol(ep, Accumulator())
token1 = b"\x00"*32 token1 = b"\x00"*32
side1 = b"\x01"*8 side1 = b"\x01"*8
side2 = b"\x02"*8 side2 = b"\x02"*8
a1.transport.write(b"please relay " + hexlify(token1) + p1.dataReceived(handshake(token1, side=side1))
b" for side " + hexlify(side1) + b"\n") p2.dataReceived(handshake(token1, side=side2))
while not self._transit_server._pending_requests:
yield wait() # make sure a1 connects first self.assertEqual(self._usage, []) # no events yet
a2.transport.write(b"please relay " + hexlify(token1) +
b" for side " + hexlify(side2) + b"\n") p1.dataReceived(b"\x00" * 13)
while not self._transit_server._active_connections: p2.dataReceived(b"\xff" * 7)
yield wait() # wait for the server to see the connection
self.assertEqual(len(self._transit_server._pending_requests), 0) p1.transport.loseConnection()
self.assertEqual(self._usage, []) # no events yet
a1.transport.write(b"\x00" * 13)
yield a2.waitForBytes(13)
a2.transport.write(b"\xff" * 7)
yield a1.waitForBytes(7)
a1.transport.loseConnection()
yield a1._disconnect
while self._transit_server._active_connections:
yield wait()
yield a2._disconnect
self.assertEqual(len(self._usage), 1, self._usage) self.assertEqual(len(self._usage), 1, self._usage)
(started, result, total_bytes, total_time, waiting_time) = self._usage[0] (started, result, total_bytes, total_time, waiting_time) = self._usage[0]
self.assertEqual(result, "happy", self._usage) self.assertEqual(result, "happy", self._usage)
self.assertEqual(total_bytes, 20) self.assertEqual(total_bytes, 20)
self.assertNotIdentical(waiting_time, None) self.assertNotIdentical(waiting_time, None)
@defer.inlineCallbacks
def test_redundant(self): def test_redundant(self):
ep = clientFromString(reactor, self.transit) p1a = self.new_protocol()
a1a = yield connectProtocol(ep, Accumulator()) p1b = self.new_protocol()
a1b = yield connectProtocol(ep, Accumulator()) p1c = self.new_protocol()
a1c = yield connectProtocol(ep, Accumulator()) p2 = self.new_protocol()
a2 = yield connectProtocol(ep, Accumulator())
token1 = b"\x00"*32 token1 = b"\x00"*32
side1 = b"\x01"*8 side1 = b"\x01"*8
side2 = b"\x02"*8 side2 = b"\x02"*8
a1a.transport.write(b"please relay " + hexlify(token1) + p1a.dataReceived(handshake(token1, side=side1))
b" for side " + hexlify(side1) + b"\n") p1b.dataReceived(handshake(token1, side=side1))
def count_requests():
return sum([len(v)
for v in self._transit_server._pending_requests.values()])
while count_requests() < 1:
yield wait()
a1b.transport.write(b"please relay " + hexlify(token1) +
b" for side " + hexlify(side1) + b"\n")
while count_requests() < 2:
yield wait()
# connect and disconnect a third client (for side1) to exercise the # connect and disconnect a third client (for side1) to exercise the
# code that removes a pending connection without removing the entire # code that removes a pending connection without removing the entire
# token # token
a1c.transport.write(b"please relay " + hexlify(token1) + p1c.dataReceived(handshake(token1, side=side1))
b" for side " + hexlify(side1) + b"\n") p1c.transport.loseConnection()
while count_requests() < 3:
yield wait()
a1c.transport.loseConnection()
yield a1c._disconnect
while count_requests() > 2:
yield wait()
self.assertEqual(len(self._usage), 1, self._usage) self.assertEqual(len(self._usage), 1, self._usage)
(started, result, total_bytes, total_time, waiting_time) = self._usage[0] (started, result, total_bytes, total_time, waiting_time) = self._usage[0]
self.assertEqual(result, "lonely", self._usage) self.assertEqual(result, "lonely", self._usage)
a2.transport.write(b"please relay " + hexlify(token1) + p2.dataReceived(handshake(token1, side=side2))
b" for side " + hexlify(side2) + b"\n") self.assertEqual(len(self._transit_server._pending_requests), 0)
# this will claim one of (a1a, a1b), and close the other as redundant
while not self._transit_server._active_connections:
yield wait() # wait for the server to see the connection
self.assertEqual(count_requests(), 0)
self.assertEqual(len(self._usage), 2, self._usage) self.assertEqual(len(self._usage), 2, self._usage)
(started, result, total_bytes, total_time, waiting_time) = self._usage[1] (started, result, total_bytes, total_time, waiting_time) = self._usage[1]
self.assertEqual(result, "redundant", self._usage) self.assertEqual(result, "redundant", self._usage)
# one of the these is unecessary, but probably harmless # one of the these is unecessary, but probably harmless
a1a.transport.loseConnection() p1a.transport.loseConnection()
a1b.transport.loseConnection() p1b.transport.loseConnection()
yield a1a._disconnect
yield a1b._disconnect
while self._transit_server._active_connections:
yield wait()
yield a2._disconnect
self.assertEqual(len(self._usage), 3, self._usage) self.assertEqual(len(self._usage), 3, self._usage)
(started, result, total_bytes, total_time, waiting_time) = self._usage[2] (started, result, total_bytes, total_time, waiting_time) = self._usage[2]
self.assertEqual(result, "happy", self._usage) self.assertEqual(result, "happy", self._usage)

View File

@ -52,7 +52,10 @@ class TransitConnection(LineReceiver):
def connectionMade(self): def connectionMade(self):
self._started = time.time() self._started = time.time()
self._log_requests = self.factory._log_requests self._log_requests = self.factory._log_requests
self.transport.setTcpKeepAlive(True) try:
self.transport.setTcpKeepAlive(True)
except AttributeError:
pass
def lineReceived(self, line): def lineReceived(self, line):
# old: "please relay {64}\n" # old: "please relay {64}\n"
@ -257,14 +260,14 @@ class Transit(protocol.ServerFactory):
# drop and stop tracking the rest # drop and stop tracking the rest
potentials.remove(old) potentials.remove(old)
for (_, leftover_tc) in potentials: for (_, leftover_tc) in potentials.copy():
# Don't record this as errory. It's just a spare connection # Don't record this as errory. It's just a spare connection
# from the same side as a connection that got used. This # from the same side as a connection that got used. This
# can happen if the connection hint contains multiple # can happen if the connection hint contains multiple
# addresses (we don't currently support those, but it'd # addresses (we don't currently support those, but it'd
# probably be useful in the future). # probably be useful in the future).
leftover_tc.disconnect_redundant() leftover_tc.disconnect_redundant()
self._pending_requests.pop(token) self._pending_requests.pop(token, None)
# glue the two ends together # glue the two ends together
self._active_connections.add(new_tc) self._active_connections.add(new_tc)