Use StringTransportWithDisconnection for transit server tests.
Replace the use of TCP in the test suite with Twisted's StringTransport, specifically StringTransportWithDisconnection which allows us to trigger a disconnect event on the server side during testing. The `dataReceived` method on the server is now called directly, and any effects will be realised immediately. Responses are available to the test client using the `value()` method of the transport objects, and the buffer can be cleared using `clear()`. This allows all asynchronous behaviour to be removed from the transit server test suite. Furthermore, as we never have to wait for the server, tests no longer hang if they fail: the errors are encountered immediately.
This commit is contained in:
parent
ac7415a4d0
commit
45824ca5d6
|
@ -1,30 +1,28 @@
|
|||
#from __future__ import unicode_literals
|
||||
from twisted.internet import reactor, endpoints
|
||||
from twisted.internet.defer import inlineCallbacks
|
||||
from twisted.test import proto_helpers
|
||||
from ..transit_server import Transit
|
||||
|
||||
class ServerBase:
|
||||
log_requests = False
|
||||
|
||||
@inlineCallbacks
|
||||
def setUp(self):
|
||||
self._lp = None
|
||||
if self.log_requests:
|
||||
blur_usage = None
|
||||
else:
|
||||
blur_usage = 60.0
|
||||
yield self._setup_relay(blur_usage=blur_usage)
|
||||
self._setup_relay(blur_usage=blur_usage)
|
||||
self._transit_server._debug_log = self.log_requests
|
||||
|
||||
@inlineCallbacks
|
||||
def _setup_relay(self, blur_usage=None, log_file=None, usage_db=None):
|
||||
ep = endpoints.TCP4ServerEndpoint(reactor, 0, interface="127.0.0.1")
|
||||
self._transit_server = Transit(blur_usage=blur_usage,
|
||||
log_file=log_file, usage_db=usage_db)
|
||||
self._lp = yield ep.listen(self._transit_server)
|
||||
addr = self._lp.getHost()
|
||||
# ws://127.0.0.1:%d/wormhole-relay/ws
|
||||
self.transit = u"tcp:127.0.0.1:%d" % addr.port
|
||||
|
||||
def new_protocol(self):
|
||||
protocol = self._transit_server.buildProtocol(('127.0.0.1', 0))
|
||||
transport = proto_helpers.StringTransportWithDisconnection()
|
||||
protocol.makeConnection(transport)
|
||||
transport.protocol = protocol
|
||||
return protocol
|
||||
|
||||
def tearDown(self):
|
||||
if self._lp:
|
||||
|
|
|
@ -1,42 +1,28 @@
|
|||
from __future__ import print_function, unicode_literals
|
||||
from binascii import hexlify
|
||||
from twisted.trial import unittest
|
||||
from twisted.internet import protocol, reactor, defer
|
||||
from twisted.internet.endpoints import clientFromString, connectProtocol
|
||||
from twisted.internet import reactor, defer
|
||||
from .common import ServerBase
|
||||
from .. import transit_server
|
||||
|
||||
class Accumulator(protocol.Protocol):
|
||||
def __init__(self):
|
||||
self.data = b""
|
||||
self.count = 0
|
||||
self._wait = None
|
||||
self._disconnect = defer.Deferred()
|
||||
def waitForBytes(self, more):
|
||||
assert self._wait is None
|
||||
self.count = more
|
||||
self._wait = defer.Deferred()
|
||||
self._check_done()
|
||||
return self._wait
|
||||
def dataReceived(self, data):
|
||||
self.data = self.data + data
|
||||
self._check_done()
|
||||
def _check_done(self):
|
||||
if self._wait and len(self.data) >= self.count:
|
||||
d = self._wait
|
||||
self._wait = None
|
||||
d.callback(self)
|
||||
def connectionLost(self, why):
|
||||
if self._wait:
|
||||
self._wait.errback(RuntimeError("closed"))
|
||||
self._disconnect.callback(None)
|
||||
|
||||
def wait():
|
||||
d = defer.Deferred()
|
||||
reactor.callLater(0.001, d.callback, None)
|
||||
return d
|
||||
|
||||
def handshake(token, side=None):
|
||||
hs = b"please relay " + hexlify(token)
|
||||
if side is not None:
|
||||
hs += b" for side " + hexlify(side)
|
||||
hs += b"\n"
|
||||
return hs
|
||||
|
||||
class _Transit:
|
||||
def count(self):
|
||||
return sum([len(potentials)
|
||||
for potentials
|
||||
in self._transit_server._pending_requests.values()])
|
||||
|
||||
def test_blur_size(self):
|
||||
blur = transit_server.blur_size
|
||||
self.failUnlessEqual(blur(0), 0)
|
||||
|
@ -55,268 +41,195 @@ class _Transit:
|
|||
self.failUnlessEqual(blur(1100e6), 1100e6)
|
||||
self.failUnlessEqual(blur(1150e6), 1200e6)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_register(self):
|
||||
ep = clientFromString(reactor, self.transit)
|
||||
a1 = yield connectProtocol(ep, Accumulator())
|
||||
p1 = self.new_protocol()
|
||||
|
||||
token1 = b"\x00"*32
|
||||
side1 = b"\x01"*8
|
||||
a1.transport.write(b"please relay " + hexlify(token1) +
|
||||
b" for side " + hexlify(side1) + b"\n")
|
||||
|
||||
# let that arrive
|
||||
while self.count() == 0:
|
||||
yield wait()
|
||||
p1.dataReceived(handshake(token1, side1))
|
||||
self.assertEqual(self.count(), 1)
|
||||
|
||||
a1.transport.loseConnection()
|
||||
|
||||
# let that get removed
|
||||
while self.count() > 0:
|
||||
yield wait()
|
||||
p1.transport.loseConnection()
|
||||
self.assertEqual(self.count(), 0)
|
||||
|
||||
# the token should be removed too
|
||||
self.assertEqual(len(self._transit_server._pending_requests), 0)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_both_unsided(self):
|
||||
ep = clientFromString(reactor, self.transit)
|
||||
a1 = yield connectProtocol(ep, Accumulator())
|
||||
a2 = yield connectProtocol(ep, Accumulator())
|
||||
p1 = self.new_protocol()
|
||||
p2 = self.new_protocol()
|
||||
|
||||
token1 = b"\x00"*32
|
||||
a1.transport.write(b"please relay " + hexlify(token1) + b"\n")
|
||||
a2.transport.write(b"please relay " + hexlify(token1) + b"\n")
|
||||
p1.dataReceived(handshake(token1, side=None))
|
||||
p2.dataReceived(handshake(token1, side=None))
|
||||
|
||||
# a correct handshake yields an ack, after which we can send
|
||||
exp = b"ok\n"
|
||||
yield a1.waitForBytes(len(exp))
|
||||
self.assertEqual(a1.data, exp)
|
||||
self.assertEqual(p1.transport.value(), exp)
|
||||
self.assertEqual(p2.transport.value(), exp)
|
||||
|
||||
p1.transport.clear()
|
||||
p2.transport.clear()
|
||||
|
||||
s1 = b"data1"
|
||||
a1.transport.write(s1)
|
||||
p1.dataReceived(s1)
|
||||
self.assertEqual(p2.transport.value(), s1)
|
||||
|
||||
exp = b"ok\n"
|
||||
yield a2.waitForBytes(len(exp))
|
||||
self.assertEqual(a2.data, exp)
|
||||
p1.transport.loseConnection()
|
||||
p2.transport.loseConnection()
|
||||
|
||||
# all data they sent after the handshake should be given to us
|
||||
exp = b"ok\n"+s1
|
||||
yield a2.waitForBytes(len(exp))
|
||||
self.assertEqual(a2.data, exp)
|
||||
|
||||
a1.transport.loseConnection()
|
||||
a2.transport.loseConnection()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_sided_unsided(self):
|
||||
ep = clientFromString(reactor, self.transit)
|
||||
a1 = yield connectProtocol(ep, Accumulator())
|
||||
a2 = yield connectProtocol(ep, Accumulator())
|
||||
p1 = self.new_protocol()
|
||||
p2 = self.new_protocol()
|
||||
|
||||
token1 = b"\x00"*32
|
||||
side1 = b"\x01"*8
|
||||
a1.transport.write(b"please relay " + hexlify(token1) +
|
||||
b" for side " + hexlify(side1) + b"\n")
|
||||
a2.transport.write(b"please relay " + hexlify(token1) + b"\n")
|
||||
p1.dataReceived(handshake(token1, side=side1))
|
||||
p2.dataReceived(handshake(token1, side=None))
|
||||
|
||||
# a correct handshake yields an ack, after which we can send
|
||||
exp = b"ok\n"
|
||||
yield a1.waitForBytes(len(exp))
|
||||
self.assertEqual(a1.data, exp)
|
||||
s1 = b"data1"
|
||||
a1.transport.write(s1)
|
||||
self.assertEqual(p1.transport.value(), exp)
|
||||
self.assertEqual(p2.transport.value(), exp)
|
||||
|
||||
exp = b"ok\n"
|
||||
yield a2.waitForBytes(len(exp))
|
||||
self.assertEqual(a2.data, exp)
|
||||
p1.transport.clear()
|
||||
p2.transport.clear()
|
||||
|
||||
# all data they sent after the handshake should be given to us
|
||||
exp = b"ok\n"+s1
|
||||
yield a2.waitForBytes(len(exp))
|
||||
self.assertEqual(a2.data, exp)
|
||||
s1 = b"data1"
|
||||
p1.dataReceived(s1)
|
||||
self.assertEqual(p2.transport.value(), s1)
|
||||
|
||||
a1.transport.loseConnection()
|
||||
a2.transport.loseConnection()
|
||||
p1.transport.loseConnection()
|
||||
p2.transport.loseConnection()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_unsided_sided(self):
|
||||
ep = clientFromString(reactor, self.transit)
|
||||
a1 = yield connectProtocol(ep, Accumulator())
|
||||
a2 = yield connectProtocol(ep, Accumulator())
|
||||
p1 = self.new_protocol()
|
||||
p2 = self.new_protocol()
|
||||
|
||||
token1 = b"\x00"*32
|
||||
side1 = b"\x01"*8
|
||||
a1.transport.write(b"please relay " + hexlify(token1) + b"\n")
|
||||
a2.transport.write(b"please relay " + hexlify(token1) +
|
||||
b" for side " + hexlify(side1) + b"\n")
|
||||
p1.dataReceived(handshake(token1, side=None))
|
||||
p2.dataReceived(handshake(token1, side=side1))
|
||||
|
||||
# a correct handshake yields an ack, after which we can send
|
||||
exp = b"ok\n"
|
||||
yield a1.waitForBytes(len(exp))
|
||||
self.assertEqual(a1.data, exp)
|
||||
s1 = b"data1"
|
||||
a1.transport.write(s1)
|
||||
self.assertEqual(p1.transport.value(), exp)
|
||||
self.assertEqual(p2.transport.value(), exp)
|
||||
|
||||
exp = b"ok\n"
|
||||
yield a2.waitForBytes(len(exp))
|
||||
self.assertEqual(a2.data, exp)
|
||||
p1.transport.clear()
|
||||
p2.transport.clear()
|
||||
|
||||
# all data they sent after the handshake should be given to us
|
||||
exp = b"ok\n"+s1
|
||||
yield a2.waitForBytes(len(exp))
|
||||
self.assertEqual(a2.data, exp)
|
||||
s1 = b"data1"
|
||||
p1.dataReceived(s1)
|
||||
self.assertEqual(p2.transport.value(), s1)
|
||||
|
||||
a1.transport.loseConnection()
|
||||
a2.transport.loseConnection()
|
||||
p1.transport.loseConnection()
|
||||
p2.transport.loseConnection()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_both_sided(self):
|
||||
ep = clientFromString(reactor, self.transit)
|
||||
a1 = yield connectProtocol(ep, Accumulator())
|
||||
a2 = yield connectProtocol(ep, Accumulator())
|
||||
p1 = self.new_protocol()
|
||||
p2 = self.new_protocol()
|
||||
|
||||
token1 = b"\x00"*32
|
||||
side1 = b"\x01"*8
|
||||
side2 = b"\x02"*8
|
||||
a1.transport.write(b"please relay " + hexlify(token1) +
|
||||
b" for side " + hexlify(side1) + b"\n")
|
||||
a2.transport.write(b"please relay " + hexlify(token1) +
|
||||
b" for side " + hexlify(side2) + b"\n")
|
||||
p1.dataReceived(handshake(token1, side=side1))
|
||||
p2.dataReceived(handshake(token1, side=side2))
|
||||
|
||||
# a correct handshake yields an ack, after which we can send
|
||||
exp = b"ok\n"
|
||||
yield a1.waitForBytes(len(exp))
|
||||
self.assertEqual(a1.data, exp)
|
||||
s1 = b"data1"
|
||||
a1.transport.write(s1)
|
||||
self.assertEqual(p1.transport.value(), exp)
|
||||
self.assertEqual(p2.transport.value(), exp)
|
||||
|
||||
exp = b"ok\n"
|
||||
yield a2.waitForBytes(len(exp))
|
||||
self.assertEqual(a2.data, exp)
|
||||
p1.transport.clear()
|
||||
p2.transport.clear()
|
||||
|
||||
# all data they sent after the handshake should be given to us
|
||||
exp = b"ok\n"+s1
|
||||
yield a2.waitForBytes(len(exp))
|
||||
self.assertEqual(a2.data, exp)
|
||||
s1 = b"data1"
|
||||
p1.dataReceived(s1)
|
||||
self.assertEqual(p2.transport.value(), s1)
|
||||
|
||||
a1.transport.loseConnection()
|
||||
a2.transport.loseConnection()
|
||||
p1.transport.loseConnection()
|
||||
p2.transport.loseConnection()
|
||||
|
||||
def count(self):
|
||||
return sum([len(potentials)
|
||||
for potentials
|
||||
in self._transit_server._pending_requests.values()])
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_ignore_same_side(self):
|
||||
ep = clientFromString(reactor, self.transit)
|
||||
a1 = yield connectProtocol(ep, Accumulator())
|
||||
a2 = yield connectProtocol(ep, Accumulator())
|
||||
a3 = yield connectProtocol(ep, Accumulator())
|
||||
disconnects = []
|
||||
a1._disconnect.addCallback(disconnects.append)
|
||||
a2._disconnect.addCallback(disconnects.append)
|
||||
p1 = self.new_protocol()
|
||||
p2 = self.new_protocol()
|
||||
p3 = self.new_protocol()
|
||||
|
||||
token1 = b"\x00"*32
|
||||
side1 = b"\x01"*8
|
||||
a1.transport.write(b"please relay " + hexlify(token1) +
|
||||
b" for side " + hexlify(side1) + b"\n")
|
||||
# let that arrive
|
||||
while self.count() == 0:
|
||||
yield wait()
|
||||
a2.transport.write(b"please relay " + hexlify(token1) +
|
||||
b" for side " + hexlify(side1) + b"\n")
|
||||
# let that arrive
|
||||
while self.count() == 1:
|
||||
yield wait()
|
||||
|
||||
p1.dataReceived(handshake(token1, side=side1))
|
||||
self.assertEqual(self.count(), 1)
|
||||
|
||||
p2.dataReceived(handshake(token1, side=side1))
|
||||
self.assertEqual(self.count(), 2) # same-side connections don't match
|
||||
|
||||
# when the second side arrives, the spare first connection should be
|
||||
# closed
|
||||
side2 = b"\x02"*8
|
||||
a3.transport.write(b"please relay " + hexlify(token1) +
|
||||
b" for side " + hexlify(side2) + b"\n")
|
||||
# let that arrive
|
||||
while self.count() != 0:
|
||||
yield wait()
|
||||
p3.dataReceived(handshake(token1, side=side2))
|
||||
self.assertEqual(self.count(), 0)
|
||||
self.assertEqual(len(self._transit_server._pending_requests), 0)
|
||||
self.assertEqual(len(self._transit_server._active_connections), 2)
|
||||
# That will trigger a disconnect on exactly one of (a1 or a2). Wait
|
||||
# until our client notices it.
|
||||
while not disconnects:
|
||||
yield wait()
|
||||
# the other connection should still be connected
|
||||
self.assertEqual(sum([int(t.transport.connected) for t in [a1, a2]]), 1)
|
||||
# That will trigger a disconnect on exactly one of (p1 or p2).
|
||||
# The other connection should still be connected
|
||||
self.assertEqual(sum([int(t.transport.connected) for t in [p1, p2]]), 1)
|
||||
|
||||
a1.transport.loseConnection()
|
||||
a2.transport.loseConnection()
|
||||
a3.transport.loseConnection()
|
||||
p1.transport.loseConnection()
|
||||
p2.transport.loseConnection()
|
||||
p3.transport.loseConnection()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_bad_handshake_old(self):
|
||||
ep = clientFromString(reactor, self.transit)
|
||||
a1 = yield connectProtocol(ep, Accumulator())
|
||||
p1 = self.new_protocol()
|
||||
|
||||
token1 = b"\x00"*32
|
||||
# the server waits for the exact number of bytes in the expected
|
||||
# handshake message. to trigger "bad handshake", we must match.
|
||||
a1.transport.write(b"please DELAY " + hexlify(token1) + b"\n")
|
||||
p1.dataReceived(b"please DELAY " + hexlify(token1) + b"\n")
|
||||
|
||||
exp = b"bad handshake\n"
|
||||
yield a1.waitForBytes(len(exp))
|
||||
self.assertEqual(a1.data, exp)
|
||||
self.assertEqual(p1.transport.value(), exp)
|
||||
p1.transport.loseConnection()
|
||||
|
||||
a1.transport.loseConnection()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_bad_handshake_old_slow(self):
|
||||
ep = clientFromString(reactor, self.transit)
|
||||
a1 = yield connectProtocol(ep, Accumulator())
|
||||
p1 = self.new_protocol()
|
||||
|
||||
a1.transport.write(b"please DELAY ")
|
||||
p1.dataReceived(b"please DELAY ")
|
||||
# As in test_impatience_new_slow, the current state machine has code
|
||||
# that can only be reached if we insert a stall here, so dataReceived
|
||||
# gets called twice. Hopefully we can delete this test once
|
||||
# dataReceived is refactored to remove that state.
|
||||
d = defer.Deferred()
|
||||
reactor.callLater(0.1, d.callback, None)
|
||||
yield d
|
||||
|
||||
token1 = b"\x00"*32
|
||||
# the server waits for the exact number of bytes in the expected
|
||||
# handshake message. to trigger "bad handshake", we must match.
|
||||
a1.transport.write(hexlify(token1) + b"\n")
|
||||
p1.dataReceived(hexlify(token1) + b"\n")
|
||||
|
||||
exp = b"bad handshake\n"
|
||||
yield a1.waitForBytes(len(exp))
|
||||
self.assertEqual(a1.data, exp)
|
||||
self.assertEqual(p1.transport.value(), exp)
|
||||
|
||||
a1.transport.loseConnection()
|
||||
p1.transport.loseConnection()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_bad_handshake_new(self):
|
||||
ep = clientFromString(reactor, self.transit)
|
||||
a1 = yield connectProtocol(ep, Accumulator())
|
||||
p1 = self.new_protocol()
|
||||
|
||||
token1 = b"\x00"*32
|
||||
side1 = b"\x01"*8
|
||||
# the server waits for the exact number of bytes in the expected
|
||||
# handshake message. to trigger "bad handshake", we must match.
|
||||
a1.transport.write(b"please DELAY " + hexlify(token1) +
|
||||
b" for side " + hexlify(side1) + b"\n")
|
||||
p1.dataReceived(b"please DELAY " + hexlify(token1) +
|
||||
b" for side " + hexlify(side1) + b"\n")
|
||||
|
||||
exp = b"bad handshake\n"
|
||||
yield a1.waitForBytes(len(exp))
|
||||
self.assertEqual(a1.data, exp)
|
||||
self.assertEqual(p1.transport.value(), exp)
|
||||
|
||||
a1.transport.loseConnection()
|
||||
p1.transport.loseConnection()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_binary_handshake(self):
|
||||
ep = clientFromString(reactor, self.transit)
|
||||
a1 = yield connectProtocol(ep, Accumulator())
|
||||
p1 = self.new_protocol()
|
||||
|
||||
binary_bad_handshake = b"\x00\x01\xe0\x0f\n\xff"
|
||||
# the embedded \n makes the server trigger early, before the full
|
||||
|
@ -325,50 +238,41 @@ class _Transit:
|
|||
# UnicodeDecodeError when it tried to coerce the incoming handshake
|
||||
# to unicode, due to the ("\n" in buf) check. This was fixed to use
|
||||
# (b"\n" in buf). This exercises the old failure.
|
||||
a1.transport.write(binary_bad_handshake)
|
||||
p1.dataReceived(binary_bad_handshake)
|
||||
|
||||
exp = b"bad handshake\n"
|
||||
yield a1.waitForBytes(len(exp))
|
||||
self.assertEqual(a1.data, exp)
|
||||
self.assertEqual(p1.transport.value(), exp)
|
||||
|
||||
a1.transport.loseConnection()
|
||||
p1.transport.loseConnection()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_impatience_old(self):
|
||||
ep = clientFromString(reactor, self.transit)
|
||||
a1 = yield connectProtocol(ep, Accumulator())
|
||||
p1 = self.new_protocol()
|
||||
|
||||
token1 = b"\x00"*32
|
||||
# sending too many bytes is impatience.
|
||||
a1.transport.write(b"please relay " + hexlify(token1) + b"\nNOWNOWNOW")
|
||||
p1.dataReceived(b"please relay " + hexlify(token1) + b"\nNOWNOWNOW")
|
||||
|
||||
exp = b"impatient\n"
|
||||
yield a1.waitForBytes(len(exp))
|
||||
self.assertEqual(a1.data, exp)
|
||||
self.assertEqual(p1.transport.value(), exp)
|
||||
|
||||
a1.transport.loseConnection()
|
||||
p1.transport.loseConnection()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_impatience_new(self):
|
||||
ep = clientFromString(reactor, self.transit)
|
||||
a1 = yield connectProtocol(ep, Accumulator())
|
||||
p1 = self.new_protocol()
|
||||
|
||||
token1 = b"\x00"*32
|
||||
side1 = b"\x01"*8
|
||||
# sending too many bytes is impatience.
|
||||
a1.transport.write(b"please relay " + hexlify(token1) +
|
||||
b" for side " + hexlify(side1) + b"\nNOWNOWNOW")
|
||||
p1.dataReceived(b"please relay " + hexlify(token1) +
|
||||
b" for side " + hexlify(side1) + b"\nNOWNOWNOW")
|
||||
|
||||
exp = b"impatient\n"
|
||||
yield a1.waitForBytes(len(exp))
|
||||
self.assertEqual(a1.data, exp)
|
||||
self.assertEqual(p1.transport.value(), exp)
|
||||
|
||||
a1.transport.loseConnection()
|
||||
p1.transport.loseConnection()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_impatience_new_slow(self):
|
||||
ep = clientFromString(reactor, self.transit)
|
||||
a1 = yield connectProtocol(ep, Accumulator())
|
||||
p1 = self.new_protocol()
|
||||
# For full coverage, we need dataReceived to see a particular framing
|
||||
# of these two pieces of data, and ITCPTransport doesn't have flush()
|
||||
# (which probably wouldn't work anyways). For now, force a 100ms
|
||||
|
@ -381,20 +285,16 @@ class _Transit:
|
|||
token1 = b"\x00"*32
|
||||
side1 = b"\x01"*8
|
||||
# sending too many bytes is impatience.
|
||||
a1.transport.write(b"please relay " + hexlify(token1) +
|
||||
b" for side " + hexlify(side1) + b"\n")
|
||||
p1.dataReceived(b"please relay " + hexlify(token1) +
|
||||
b" for side " + hexlify(side1) + b"\n")
|
||||
|
||||
d = defer.Deferred()
|
||||
reactor.callLater(0.1, d.callback, None)
|
||||
yield d
|
||||
|
||||
a1.transport.write(b"NOWNOWNOW")
|
||||
p1.dataReceived(b"NOWNOWNOW")
|
||||
|
||||
exp = b"impatient\n"
|
||||
yield a1.waitForBytes(len(exp))
|
||||
self.assertEqual(a1.data, exp)
|
||||
self.assertEqual(p1.transport.value(), exp)
|
||||
|
||||
a1.transport.loseConnection()
|
||||
p1.transport.loseConnection()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_short_handshake(self):
|
||||
|
@ -420,9 +320,8 @@ class TransitWithoutLogs(_Transit, ServerBase, unittest.TestCase):
|
|||
log_requests = False
|
||||
|
||||
class Usage(ServerBase, unittest.TestCase):
|
||||
@defer.inlineCallbacks
|
||||
def setUp(self):
|
||||
yield super(Usage, self).setUp()
|
||||
super(Usage, self).setUp()
|
||||
self._usage = []
|
||||
def record(started, result, total_bytes, total_time, waiting_time):
|
||||
self._usage.append((started, result, total_bytes,
|
||||
|
@ -470,130 +369,83 @@ class Usage(ServerBase, unittest.TestCase):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def test_errory(self):
|
||||
ep = clientFromString(reactor, self.transit)
|
||||
a1 = yield connectProtocol(ep, Accumulator())
|
||||
p1 = self.new_protocol()
|
||||
|
||||
a1.transport.write(b"this is a very bad handshake\n")
|
||||
p1.dataReceived(b"this is a very bad handshake\n")
|
||||
# that will log the "errory" usage event, then drop the connection
|
||||
yield a1._disconnect
|
||||
p1.transport.loseConnection()
|
||||
self.assertEqual(len(self._usage), 1, self._usage)
|
||||
(started, result, total_bytes, total_time, waiting_time) = self._usage[0]
|
||||
self.assertEqual(result, "errory", self._usage)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_lonely(self):
|
||||
ep = clientFromString(reactor, self.transit)
|
||||
a1 = yield connectProtocol(ep, Accumulator())
|
||||
p1 = self.new_protocol()
|
||||
|
||||
token1 = b"\x00"*32
|
||||
side1 = b"\x01"*8
|
||||
a1.transport.write(b"please relay " + hexlify(token1) +
|
||||
b" for side " + hexlify(side1) + b"\n")
|
||||
while not self._transit_server._pending_requests:
|
||||
yield wait() # wait for the server to see the connection
|
||||
p1.dataReceived(handshake(token1, side=side1))
|
||||
# now we disconnect before the peer connects
|
||||
a1.transport.loseConnection()
|
||||
yield a1._disconnect
|
||||
while self._transit_server._pending_requests:
|
||||
yield wait() # wait for the server to see the disconnect too
|
||||
p1.transport.loseConnection()
|
||||
|
||||
self.assertEqual(len(self._usage), 1, self._usage)
|
||||
(started, result, total_bytes, total_time, waiting_time) = self._usage[0]
|
||||
self.assertEqual(result, "lonely", self._usage)
|
||||
self.assertIdentical(waiting_time, None)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_one_happy_one_jilted(self):
|
||||
ep = clientFromString(reactor, self.transit)
|
||||
a1 = yield connectProtocol(ep, Accumulator())
|
||||
a2 = yield connectProtocol(ep, Accumulator())
|
||||
p1 = self.new_protocol()
|
||||
p2 = self.new_protocol()
|
||||
|
||||
token1 = b"\x00"*32
|
||||
side1 = b"\x01"*8
|
||||
side2 = b"\x02"*8
|
||||
a1.transport.write(b"please relay " + hexlify(token1) +
|
||||
b" for side " + hexlify(side1) + b"\n")
|
||||
while not self._transit_server._pending_requests:
|
||||
yield wait() # make sure a1 connects first
|
||||
a2.transport.write(b"please relay " + hexlify(token1) +
|
||||
b" for side " + hexlify(side2) + b"\n")
|
||||
while not self._transit_server._active_connections:
|
||||
yield wait() # wait for the server to see the connection
|
||||
self.assertEqual(len(self._transit_server._pending_requests), 0)
|
||||
self.assertEqual(self._usage, []) # no events yet
|
||||
a1.transport.write(b"\x00" * 13)
|
||||
yield a2.waitForBytes(13)
|
||||
a2.transport.write(b"\xff" * 7)
|
||||
yield a1.waitForBytes(7)
|
||||
p1.dataReceived(handshake(token1, side=side1))
|
||||
p2.dataReceived(handshake(token1, side=side2))
|
||||
|
||||
self.assertEqual(self._usage, []) # no events yet
|
||||
|
||||
p1.dataReceived(b"\x00" * 13)
|
||||
p2.dataReceived(b"\xff" * 7)
|
||||
|
||||
p1.transport.loseConnection()
|
||||
|
||||
a1.transport.loseConnection()
|
||||
yield a1._disconnect
|
||||
while self._transit_server._active_connections:
|
||||
yield wait()
|
||||
yield a2._disconnect
|
||||
self.assertEqual(len(self._usage), 1, self._usage)
|
||||
(started, result, total_bytes, total_time, waiting_time) = self._usage[0]
|
||||
self.assertEqual(result, "happy", self._usage)
|
||||
self.assertEqual(total_bytes, 20)
|
||||
self.assertNotIdentical(waiting_time, None)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_redundant(self):
|
||||
ep = clientFromString(reactor, self.transit)
|
||||
a1a = yield connectProtocol(ep, Accumulator())
|
||||
a1b = yield connectProtocol(ep, Accumulator())
|
||||
a1c = yield connectProtocol(ep, Accumulator())
|
||||
a2 = yield connectProtocol(ep, Accumulator())
|
||||
p1a = self.new_protocol()
|
||||
p1b = self.new_protocol()
|
||||
p1c = self.new_protocol()
|
||||
p2 = self.new_protocol()
|
||||
|
||||
token1 = b"\x00"*32
|
||||
side1 = b"\x01"*8
|
||||
side2 = b"\x02"*8
|
||||
a1a.transport.write(b"please relay " + hexlify(token1) +
|
||||
b" for side " + hexlify(side1) + b"\n")
|
||||
def count_requests():
|
||||
return sum([len(v)
|
||||
for v in self._transit_server._pending_requests.values()])
|
||||
while count_requests() < 1:
|
||||
yield wait()
|
||||
a1b.transport.write(b"please relay " + hexlify(token1) +
|
||||
b" for side " + hexlify(side1) + b"\n")
|
||||
while count_requests() < 2:
|
||||
yield wait()
|
||||
p1a.dataReceived(handshake(token1, side=side1))
|
||||
p1b.dataReceived(handshake(token1, side=side1))
|
||||
|
||||
# connect and disconnect a third client (for side1) to exercise the
|
||||
# code that removes a pending connection without removing the entire
|
||||
# token
|
||||
a1c.transport.write(b"please relay " + hexlify(token1) +
|
||||
b" for side " + hexlify(side1) + b"\n")
|
||||
while count_requests() < 3:
|
||||
yield wait()
|
||||
a1c.transport.loseConnection()
|
||||
yield a1c._disconnect
|
||||
while count_requests() > 2:
|
||||
yield wait()
|
||||
p1c.dataReceived(handshake(token1, side=side1))
|
||||
p1c.transport.loseConnection()
|
||||
|
||||
self.assertEqual(len(self._usage), 1, self._usage)
|
||||
(started, result, total_bytes, total_time, waiting_time) = self._usage[0]
|
||||
self.assertEqual(result, "lonely", self._usage)
|
||||
|
||||
a2.transport.write(b"please relay " + hexlify(token1) +
|
||||
b" for side " + hexlify(side2) + b"\n")
|
||||
# this will claim one of (a1a, a1b), and close the other as redundant
|
||||
while not self._transit_server._active_connections:
|
||||
yield wait() # wait for the server to see the connection
|
||||
self.assertEqual(count_requests(), 0)
|
||||
p2.dataReceived(handshake(token1, side=side2))
|
||||
self.assertEqual(len(self._transit_server._pending_requests), 0)
|
||||
self.assertEqual(len(self._usage), 2, self._usage)
|
||||
(started, result, total_bytes, total_time, waiting_time) = self._usage[1]
|
||||
self.assertEqual(result, "redundant", self._usage)
|
||||
|
||||
# one of the these is unecessary, but probably harmless
|
||||
a1a.transport.loseConnection()
|
||||
a1b.transport.loseConnection()
|
||||
yield a1a._disconnect
|
||||
yield a1b._disconnect
|
||||
while self._transit_server._active_connections:
|
||||
yield wait()
|
||||
yield a2._disconnect
|
||||
p1a.transport.loseConnection()
|
||||
p1b.transport.loseConnection()
|
||||
self.assertEqual(len(self._usage), 3, self._usage)
|
||||
(started, result, total_bytes, total_time, waiting_time) = self._usage[2]
|
||||
self.assertEqual(result, "happy", self._usage)
|
||||
|
||||
|
|
|
@ -52,7 +52,10 @@ class TransitConnection(LineReceiver):
|
|||
def connectionMade(self):
|
||||
self._started = time.time()
|
||||
self._log_requests = self.factory._log_requests
|
||||
self.transport.setTcpKeepAlive(True)
|
||||
try:
|
||||
self.transport.setTcpKeepAlive(True)
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
def lineReceived(self, line):
|
||||
# old: "please relay {64}\n"
|
||||
|
@ -257,14 +260,14 @@ class Transit(protocol.ServerFactory):
|
|||
|
||||
# drop and stop tracking the rest
|
||||
potentials.remove(old)
|
||||
for (_, leftover_tc) in potentials:
|
||||
for (_, leftover_tc) in potentials.copy():
|
||||
# Don't record this as errory. It's just a spare connection
|
||||
# from the same side as a connection that got used. This
|
||||
# can happen if the connection hint contains multiple
|
||||
# addresses (we don't currently support those, but it'd
|
||||
# probably be useful in the future).
|
||||
leftover_tc.disconnect_redundant()
|
||||
self._pending_requests.pop(token)
|
||||
self._pending_requests.pop(token, None)
|
||||
|
||||
# glue the two ends together
|
||||
self._active_connections.add(new_tc)
|
||||
|
|
Loading…
Reference in New Issue
Block a user