Merge PR 14: improve tests
many thanks to @sigwinch28 for the improvements closes #14
This commit is contained in:
commit
46ec26f2bb
|
@ -1,30 +1,28 @@
|
||||||
#from __future__ import unicode_literals
|
from twisted.test import proto_helpers
|
||||||
from twisted.internet import reactor, endpoints
|
|
||||||
from twisted.internet.defer import inlineCallbacks
|
|
||||||
from ..transit_server import Transit
|
from ..transit_server import Transit
|
||||||
|
|
||||||
class ServerBase:
|
class ServerBase:
|
||||||
log_requests = False
|
log_requests = False
|
||||||
|
|
||||||
@inlineCallbacks
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self._lp = None
|
self._lp = None
|
||||||
if self.log_requests:
|
if self.log_requests:
|
||||||
blur_usage = None
|
blur_usage = None
|
||||||
else:
|
else:
|
||||||
blur_usage = 60.0
|
blur_usage = 60.0
|
||||||
yield self._setup_relay(blur_usage=blur_usage)
|
self._setup_relay(blur_usage=blur_usage)
|
||||||
self._transit_server._debug_log = self.log_requests
|
self._transit_server._debug_log = self.log_requests
|
||||||
|
|
||||||
@inlineCallbacks
|
|
||||||
def _setup_relay(self, blur_usage=None, log_file=None, usage_db=None):
|
def _setup_relay(self, blur_usage=None, log_file=None, usage_db=None):
|
||||||
ep = endpoints.TCP4ServerEndpoint(reactor, 0, interface="127.0.0.1")
|
|
||||||
self._transit_server = Transit(blur_usage=blur_usage,
|
self._transit_server = Transit(blur_usage=blur_usage,
|
||||||
log_file=log_file, usage_db=usage_db)
|
log_file=log_file, usage_db=usage_db)
|
||||||
self._lp = yield ep.listen(self._transit_server)
|
|
||||||
addr = self._lp.getHost()
|
def new_protocol(self):
|
||||||
# ws://127.0.0.1:%d/wormhole-relay/ws
|
protocol = self._transit_server.buildProtocol(('127.0.0.1', 0))
|
||||||
self.transit = u"tcp:127.0.0.1:%d" % addr.port
|
transport = proto_helpers.StringTransportWithDisconnection()
|
||||||
|
protocol.makeConnection(transport)
|
||||||
|
transport.protocol = protocol
|
||||||
|
return protocol
|
||||||
|
|
||||||
def tearDown(self):
|
def tearDown(self):
|
||||||
if self._lp:
|
if self._lp:
|
||||||
|
|
|
@ -1,42 +1,22 @@
|
||||||
from __future__ import print_function, unicode_literals
|
from __future__ import print_function, unicode_literals
|
||||||
from binascii import hexlify
|
from binascii import hexlify
|
||||||
from twisted.trial import unittest
|
from twisted.trial import unittest
|
||||||
from twisted.internet import protocol, reactor, defer
|
|
||||||
from twisted.internet.endpoints import clientFromString, connectProtocol
|
|
||||||
from .common import ServerBase
|
from .common import ServerBase
|
||||||
from .. import transit_server
|
from .. import transit_server
|
||||||
|
|
||||||
class Accumulator(protocol.Protocol):
|
def handshake(token, side=None):
|
||||||
def __init__(self):
|
hs = b"please relay " + hexlify(token)
|
||||||
self.data = b""
|
if side is not None:
|
||||||
self.count = 0
|
hs += b" for side " + hexlify(side)
|
||||||
self._wait = None
|
hs += b"\n"
|
||||||
self._disconnect = defer.Deferred()
|
return hs
|
||||||
def waitForBytes(self, more):
|
|
||||||
assert self._wait is None
|
|
||||||
self.count = more
|
|
||||||
self._wait = defer.Deferred()
|
|
||||||
self._check_done()
|
|
||||||
return self._wait
|
|
||||||
def dataReceived(self, data):
|
|
||||||
self.data = self.data + data
|
|
||||||
self._check_done()
|
|
||||||
def _check_done(self):
|
|
||||||
if self._wait and len(self.data) >= self.count:
|
|
||||||
d = self._wait
|
|
||||||
self._wait = None
|
|
||||||
d.callback(self)
|
|
||||||
def connectionLost(self, why):
|
|
||||||
if self._wait:
|
|
||||||
self._wait.errback(RuntimeError("closed"))
|
|
||||||
self._disconnect.callback(None)
|
|
||||||
|
|
||||||
def wait():
|
|
||||||
d = defer.Deferred()
|
|
||||||
reactor.callLater(0.001, d.callback, None)
|
|
||||||
return d
|
|
||||||
|
|
||||||
class _Transit:
|
class _Transit:
|
||||||
|
def count(self):
|
||||||
|
return sum([len(potentials)
|
||||||
|
for potentials
|
||||||
|
in self._transit_server._pending_requests.values()])
|
||||||
|
|
||||||
def test_blur_size(self):
|
def test_blur_size(self):
|
||||||
blur = transit_server.blur_size
|
blur = transit_server.blur_size
|
||||||
self.failUnlessEqual(blur(0), 0)
|
self.failUnlessEqual(blur(0), 0)
|
||||||
|
@ -55,268 +35,195 @@ class _Transit:
|
||||||
self.failUnlessEqual(blur(1100e6), 1100e6)
|
self.failUnlessEqual(blur(1100e6), 1100e6)
|
||||||
self.failUnlessEqual(blur(1150e6), 1200e6)
|
self.failUnlessEqual(blur(1150e6), 1200e6)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def test_register(self):
|
def test_register(self):
|
||||||
ep = clientFromString(reactor, self.transit)
|
p1 = self.new_protocol()
|
||||||
a1 = yield connectProtocol(ep, Accumulator())
|
|
||||||
|
|
||||||
token1 = b"\x00"*32
|
token1 = b"\x00"*32
|
||||||
side1 = b"\x01"*8
|
side1 = b"\x01"*8
|
||||||
a1.transport.write(b"please relay " + hexlify(token1) +
|
|
||||||
b" for side " + hexlify(side1) + b"\n")
|
|
||||||
|
|
||||||
# let that arrive
|
p1.dataReceived(handshake(token1, side1))
|
||||||
while self.count() == 0:
|
|
||||||
yield wait()
|
|
||||||
self.assertEqual(self.count(), 1)
|
self.assertEqual(self.count(), 1)
|
||||||
|
|
||||||
a1.transport.loseConnection()
|
p1.transport.loseConnection()
|
||||||
|
|
||||||
# let that get removed
|
|
||||||
while self.count() > 0:
|
|
||||||
yield wait()
|
|
||||||
self.assertEqual(self.count(), 0)
|
self.assertEqual(self.count(), 0)
|
||||||
|
|
||||||
# the token should be removed too
|
# the token should be removed too
|
||||||
self.assertEqual(len(self._transit_server._pending_requests), 0)
|
self.assertEqual(len(self._transit_server._pending_requests), 0)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def test_both_unsided(self):
|
def test_both_unsided(self):
|
||||||
ep = clientFromString(reactor, self.transit)
|
p1 = self.new_protocol()
|
||||||
a1 = yield connectProtocol(ep, Accumulator())
|
p2 = self.new_protocol()
|
||||||
a2 = yield connectProtocol(ep, Accumulator())
|
|
||||||
|
|
||||||
token1 = b"\x00"*32
|
token1 = b"\x00"*32
|
||||||
a1.transport.write(b"please relay " + hexlify(token1) + b"\n")
|
p1.dataReceived(handshake(token1, side=None))
|
||||||
a2.transport.write(b"please relay " + hexlify(token1) + b"\n")
|
p2.dataReceived(handshake(token1, side=None))
|
||||||
|
|
||||||
# a correct handshake yields an ack, after which we can send
|
# a correct handshake yields an ack, after which we can send
|
||||||
exp = b"ok\n"
|
exp = b"ok\n"
|
||||||
yield a1.waitForBytes(len(exp))
|
self.assertEqual(p1.transport.value(), exp)
|
||||||
self.assertEqual(a1.data, exp)
|
self.assertEqual(p2.transport.value(), exp)
|
||||||
|
|
||||||
|
p1.transport.clear()
|
||||||
|
p2.transport.clear()
|
||||||
|
|
||||||
s1 = b"data1"
|
s1 = b"data1"
|
||||||
a1.transport.write(s1)
|
p1.dataReceived(s1)
|
||||||
|
self.assertEqual(p2.transport.value(), s1)
|
||||||
|
|
||||||
exp = b"ok\n"
|
p1.transport.loseConnection()
|
||||||
yield a2.waitForBytes(len(exp))
|
p2.transport.loseConnection()
|
||||||
self.assertEqual(a2.data, exp)
|
|
||||||
|
|
||||||
# all data they sent after the handshake should be given to us
|
|
||||||
exp = b"ok\n"+s1
|
|
||||||
yield a2.waitForBytes(len(exp))
|
|
||||||
self.assertEqual(a2.data, exp)
|
|
||||||
|
|
||||||
a1.transport.loseConnection()
|
|
||||||
a2.transport.loseConnection()
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def test_sided_unsided(self):
|
def test_sided_unsided(self):
|
||||||
ep = clientFromString(reactor, self.transit)
|
p1 = self.new_protocol()
|
||||||
a1 = yield connectProtocol(ep, Accumulator())
|
p2 = self.new_protocol()
|
||||||
a2 = yield connectProtocol(ep, Accumulator())
|
|
||||||
|
|
||||||
token1 = b"\x00"*32
|
token1 = b"\x00"*32
|
||||||
side1 = b"\x01"*8
|
side1 = b"\x01"*8
|
||||||
a1.transport.write(b"please relay " + hexlify(token1) +
|
p1.dataReceived(handshake(token1, side=side1))
|
||||||
b" for side " + hexlify(side1) + b"\n")
|
p2.dataReceived(handshake(token1, side=None))
|
||||||
a2.transport.write(b"please relay " + hexlify(token1) + b"\n")
|
|
||||||
|
|
||||||
# a correct handshake yields an ack, after which we can send
|
# a correct handshake yields an ack, after which we can send
|
||||||
exp = b"ok\n"
|
exp = b"ok\n"
|
||||||
yield a1.waitForBytes(len(exp))
|
self.assertEqual(p1.transport.value(), exp)
|
||||||
self.assertEqual(a1.data, exp)
|
self.assertEqual(p2.transport.value(), exp)
|
||||||
s1 = b"data1"
|
|
||||||
a1.transport.write(s1)
|
|
||||||
|
|
||||||
exp = b"ok\n"
|
p1.transport.clear()
|
||||||
yield a2.waitForBytes(len(exp))
|
p2.transport.clear()
|
||||||
self.assertEqual(a2.data, exp)
|
|
||||||
|
|
||||||
# all data they sent after the handshake should be given to us
|
# all data they sent after the handshake should be given to us
|
||||||
exp = b"ok\n"+s1
|
s1 = b"data1"
|
||||||
yield a2.waitForBytes(len(exp))
|
p1.dataReceived(s1)
|
||||||
self.assertEqual(a2.data, exp)
|
self.assertEqual(p2.transport.value(), s1)
|
||||||
|
|
||||||
a1.transport.loseConnection()
|
p1.transport.loseConnection()
|
||||||
a2.transport.loseConnection()
|
p2.transport.loseConnection()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def test_unsided_sided(self):
|
def test_unsided_sided(self):
|
||||||
ep = clientFromString(reactor, self.transit)
|
p1 = self.new_protocol()
|
||||||
a1 = yield connectProtocol(ep, Accumulator())
|
p2 = self.new_protocol()
|
||||||
a2 = yield connectProtocol(ep, Accumulator())
|
|
||||||
|
|
||||||
token1 = b"\x00"*32
|
token1 = b"\x00"*32
|
||||||
side1 = b"\x01"*8
|
side1 = b"\x01"*8
|
||||||
a1.transport.write(b"please relay " + hexlify(token1) + b"\n")
|
p1.dataReceived(handshake(token1, side=None))
|
||||||
a2.transport.write(b"please relay " + hexlify(token1) +
|
p2.dataReceived(handshake(token1, side=side1))
|
||||||
b" for side " + hexlify(side1) + b"\n")
|
|
||||||
|
|
||||||
# a correct handshake yields an ack, after which we can send
|
# a correct handshake yields an ack, after which we can send
|
||||||
exp = b"ok\n"
|
exp = b"ok\n"
|
||||||
yield a1.waitForBytes(len(exp))
|
self.assertEqual(p1.transport.value(), exp)
|
||||||
self.assertEqual(a1.data, exp)
|
self.assertEqual(p2.transport.value(), exp)
|
||||||
s1 = b"data1"
|
|
||||||
a1.transport.write(s1)
|
|
||||||
|
|
||||||
exp = b"ok\n"
|
p1.transport.clear()
|
||||||
yield a2.waitForBytes(len(exp))
|
p2.transport.clear()
|
||||||
self.assertEqual(a2.data, exp)
|
|
||||||
|
|
||||||
# all data they sent after the handshake should be given to us
|
# all data they sent after the handshake should be given to us
|
||||||
exp = b"ok\n"+s1
|
s1 = b"data1"
|
||||||
yield a2.waitForBytes(len(exp))
|
p1.dataReceived(s1)
|
||||||
self.assertEqual(a2.data, exp)
|
self.assertEqual(p2.transport.value(), s1)
|
||||||
|
|
||||||
a1.transport.loseConnection()
|
p1.transport.loseConnection()
|
||||||
a2.transport.loseConnection()
|
p2.transport.loseConnection()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def test_both_sided(self):
|
def test_both_sided(self):
|
||||||
ep = clientFromString(reactor, self.transit)
|
p1 = self.new_protocol()
|
||||||
a1 = yield connectProtocol(ep, Accumulator())
|
p2 = self.new_protocol()
|
||||||
a2 = yield connectProtocol(ep, Accumulator())
|
|
||||||
|
|
||||||
token1 = b"\x00"*32
|
token1 = b"\x00"*32
|
||||||
side1 = b"\x01"*8
|
side1 = b"\x01"*8
|
||||||
side2 = b"\x02"*8
|
side2 = b"\x02"*8
|
||||||
a1.transport.write(b"please relay " + hexlify(token1) +
|
p1.dataReceived(handshake(token1, side=side1))
|
||||||
b" for side " + hexlify(side1) + b"\n")
|
p2.dataReceived(handshake(token1, side=side2))
|
||||||
a2.transport.write(b"please relay " + hexlify(token1) +
|
|
||||||
b" for side " + hexlify(side2) + b"\n")
|
|
||||||
|
|
||||||
# a correct handshake yields an ack, after which we can send
|
# a correct handshake yields an ack, after which we can send
|
||||||
exp = b"ok\n"
|
exp = b"ok\n"
|
||||||
yield a1.waitForBytes(len(exp))
|
self.assertEqual(p1.transport.value(), exp)
|
||||||
self.assertEqual(a1.data, exp)
|
self.assertEqual(p2.transport.value(), exp)
|
||||||
s1 = b"data1"
|
|
||||||
a1.transport.write(s1)
|
|
||||||
|
|
||||||
exp = b"ok\n"
|
p1.transport.clear()
|
||||||
yield a2.waitForBytes(len(exp))
|
p2.transport.clear()
|
||||||
self.assertEqual(a2.data, exp)
|
|
||||||
|
|
||||||
# all data they sent after the handshake should be given to us
|
# all data they sent after the handshake should be given to us
|
||||||
exp = b"ok\n"+s1
|
s1 = b"data1"
|
||||||
yield a2.waitForBytes(len(exp))
|
p1.dataReceived(s1)
|
||||||
self.assertEqual(a2.data, exp)
|
self.assertEqual(p2.transport.value(), s1)
|
||||||
|
|
||||||
a1.transport.loseConnection()
|
p1.transport.loseConnection()
|
||||||
a2.transport.loseConnection()
|
p2.transport.loseConnection()
|
||||||
|
|
||||||
def count(self):
|
|
||||||
return sum([len(potentials)
|
|
||||||
for potentials
|
|
||||||
in self._transit_server._pending_requests.values()])
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def test_ignore_same_side(self):
|
def test_ignore_same_side(self):
|
||||||
ep = clientFromString(reactor, self.transit)
|
p1 = self.new_protocol()
|
||||||
a1 = yield connectProtocol(ep, Accumulator())
|
p2 = self.new_protocol()
|
||||||
a2 = yield connectProtocol(ep, Accumulator())
|
p3 = self.new_protocol()
|
||||||
a3 = yield connectProtocol(ep, Accumulator())
|
|
||||||
disconnects = []
|
|
||||||
a1._disconnect.addCallback(disconnects.append)
|
|
||||||
a2._disconnect.addCallback(disconnects.append)
|
|
||||||
|
|
||||||
token1 = b"\x00"*32
|
token1 = b"\x00"*32
|
||||||
side1 = b"\x01"*8
|
side1 = b"\x01"*8
|
||||||
a1.transport.write(b"please relay " + hexlify(token1) +
|
|
||||||
b" for side " + hexlify(side1) + b"\n")
|
p1.dataReceived(handshake(token1, side=side1))
|
||||||
# let that arrive
|
self.assertEqual(self.count(), 1)
|
||||||
while self.count() == 0:
|
|
||||||
yield wait()
|
p2.dataReceived(handshake(token1, side=side1))
|
||||||
a2.transport.write(b"please relay " + hexlify(token1) +
|
|
||||||
b" for side " + hexlify(side1) + b"\n")
|
|
||||||
# let that arrive
|
|
||||||
while self.count() == 1:
|
|
||||||
yield wait()
|
|
||||||
self.assertEqual(self.count(), 2) # same-side connections don't match
|
self.assertEqual(self.count(), 2) # same-side connections don't match
|
||||||
|
|
||||||
# when the second side arrives, the spare first connection should be
|
# when the second side arrives, the spare first connection should be
|
||||||
# closed
|
# closed
|
||||||
side2 = b"\x02"*8
|
side2 = b"\x02"*8
|
||||||
a3.transport.write(b"please relay " + hexlify(token1) +
|
p3.dataReceived(handshake(token1, side=side2))
|
||||||
b" for side " + hexlify(side2) + b"\n")
|
self.assertEqual(self.count(), 0)
|
||||||
# let that arrive
|
|
||||||
while self.count() != 0:
|
|
||||||
yield wait()
|
|
||||||
self.assertEqual(len(self._transit_server._pending_requests), 0)
|
self.assertEqual(len(self._transit_server._pending_requests), 0)
|
||||||
self.assertEqual(len(self._transit_server._active_connections), 2)
|
self.assertEqual(len(self._transit_server._active_connections), 2)
|
||||||
# That will trigger a disconnect on exactly one of (a1 or a2). Wait
|
# That will trigger a disconnect on exactly one of (p1 or p2).
|
||||||
# until our client notices it.
|
# The other connection should still be connected
|
||||||
while not disconnects:
|
self.assertEqual(sum([int(t.transport.connected) for t in [p1, p2]]), 1)
|
||||||
yield wait()
|
|
||||||
# the other connection should still be connected
|
|
||||||
self.assertEqual(sum([int(t.transport.connected) for t in [a1, a2]]), 1)
|
|
||||||
|
|
||||||
a1.transport.loseConnection()
|
p1.transport.loseConnection()
|
||||||
a2.transport.loseConnection()
|
p2.transport.loseConnection()
|
||||||
a3.transport.loseConnection()
|
p3.transport.loseConnection()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def test_bad_handshake_old(self):
|
def test_bad_handshake_old(self):
|
||||||
ep = clientFromString(reactor, self.transit)
|
p1 = self.new_protocol()
|
||||||
a1 = yield connectProtocol(ep, Accumulator())
|
|
||||||
|
|
||||||
token1 = b"\x00"*32
|
token1 = b"\x00"*32
|
||||||
# the server waits for the exact number of bytes in the expected
|
p1.dataReceived(b"please DELAY " + hexlify(token1) + b"\n")
|
||||||
# handshake message. to trigger "bad handshake", we must match.
|
|
||||||
a1.transport.write(b"please DELAY " + hexlify(token1) + b"\n")
|
|
||||||
|
|
||||||
exp = b"bad handshake\n"
|
exp = b"bad handshake\n"
|
||||||
yield a1.waitForBytes(len(exp))
|
self.assertEqual(p1.transport.value(), exp)
|
||||||
self.assertEqual(a1.data, exp)
|
p1.transport.loseConnection()
|
||||||
|
|
||||||
a1.transport.loseConnection()
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def test_bad_handshake_old_slow(self):
|
def test_bad_handshake_old_slow(self):
|
||||||
ep = clientFromString(reactor, self.transit)
|
p1 = self.new_protocol()
|
||||||
a1 = yield connectProtocol(ep, Accumulator())
|
|
||||||
|
|
||||||
a1.transport.write(b"please DELAY ")
|
p1.dataReceived(b"please DELAY ")
|
||||||
# As in test_impatience_new_slow, the current state machine has code
|
# As in test_impatience_new_slow, the current state machine has code
|
||||||
# that can only be reached if we insert a stall here, so dataReceived
|
# that can only be reached if we insert a stall here, so dataReceived
|
||||||
# gets called twice. Hopefully we can delete this test once
|
# gets called twice. Hopefully we can delete this test once
|
||||||
# dataReceived is refactored to remove that state.
|
# dataReceived is refactored to remove that state.
|
||||||
d = defer.Deferred()
|
|
||||||
reactor.callLater(0.1, d.callback, None)
|
|
||||||
yield d
|
|
||||||
|
|
||||||
token1 = b"\x00"*32
|
token1 = b"\x00"*32
|
||||||
# the server waits for the exact number of bytes in the expected
|
# the server waits for the exact number of bytes in the expected
|
||||||
# handshake message. to trigger "bad handshake", we must match.
|
# handshake message. to trigger "bad handshake", we must match.
|
||||||
a1.transport.write(hexlify(token1) + b"\n")
|
p1.dataReceived(hexlify(token1) + b"\n")
|
||||||
|
|
||||||
exp = b"bad handshake\n"
|
exp = b"bad handshake\n"
|
||||||
yield a1.waitForBytes(len(exp))
|
self.assertEqual(p1.transport.value(), exp)
|
||||||
self.assertEqual(a1.data, exp)
|
|
||||||
|
|
||||||
a1.transport.loseConnection()
|
p1.transport.loseConnection()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def test_bad_handshake_new(self):
|
def test_bad_handshake_new(self):
|
||||||
ep = clientFromString(reactor, self.transit)
|
p1 = self.new_protocol()
|
||||||
a1 = yield connectProtocol(ep, Accumulator())
|
|
||||||
|
|
||||||
token1 = b"\x00"*32
|
token1 = b"\x00"*32
|
||||||
side1 = b"\x01"*8
|
side1 = b"\x01"*8
|
||||||
# the server waits for the exact number of bytes in the expected
|
# the server waits for the exact number of bytes in the expected
|
||||||
# handshake message. to trigger "bad handshake", we must match.
|
# handshake message. to trigger "bad handshake", we must match.
|
||||||
a1.transport.write(b"please DELAY " + hexlify(token1) +
|
p1.dataReceived(b"please DELAY " + hexlify(token1) +
|
||||||
b" for side " + hexlify(side1) + b"\n")
|
b" for side " + hexlify(side1) + b"\n")
|
||||||
|
|
||||||
exp = b"bad handshake\n"
|
exp = b"bad handshake\n"
|
||||||
yield a1.waitForBytes(len(exp))
|
self.assertEqual(p1.transport.value(), exp)
|
||||||
self.assertEqual(a1.data, exp)
|
|
||||||
|
|
||||||
a1.transport.loseConnection()
|
p1.transport.loseConnection()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def test_binary_handshake(self):
|
def test_binary_handshake(self):
|
||||||
ep = clientFromString(reactor, self.transit)
|
p1 = self.new_protocol()
|
||||||
a1 = yield connectProtocol(ep, Accumulator())
|
|
||||||
|
|
||||||
binary_bad_handshake = b"\x00\x01\xe0\x0f\n\xff"
|
binary_bad_handshake = b"\x00\x01\xe0\x0f\n\xff"
|
||||||
# the embedded \n makes the server trigger early, before the full
|
# the embedded \n makes the server trigger early, before the full
|
||||||
|
@ -325,50 +232,41 @@ class _Transit:
|
||||||
# UnicodeDecodeError when it tried to coerce the incoming handshake
|
# UnicodeDecodeError when it tried to coerce the incoming handshake
|
||||||
# to unicode, due to the ("\n" in buf) check. This was fixed to use
|
# to unicode, due to the ("\n" in buf) check. This was fixed to use
|
||||||
# (b"\n" in buf). This exercises the old failure.
|
# (b"\n" in buf). This exercises the old failure.
|
||||||
a1.transport.write(binary_bad_handshake)
|
p1.dataReceived(binary_bad_handshake)
|
||||||
|
|
||||||
exp = b"bad handshake\n"
|
exp = b"bad handshake\n"
|
||||||
yield a1.waitForBytes(len(exp))
|
self.assertEqual(p1.transport.value(), exp)
|
||||||
self.assertEqual(a1.data, exp)
|
|
||||||
|
|
||||||
a1.transport.loseConnection()
|
p1.transport.loseConnection()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def test_impatience_old(self):
|
def test_impatience_old(self):
|
||||||
ep = clientFromString(reactor, self.transit)
|
p1 = self.new_protocol()
|
||||||
a1 = yield connectProtocol(ep, Accumulator())
|
|
||||||
|
|
||||||
token1 = b"\x00"*32
|
token1 = b"\x00"*32
|
||||||
# sending too many bytes is impatience.
|
# sending too many bytes is impatience.
|
||||||
a1.transport.write(b"please relay " + hexlify(token1) + b"\nNOWNOWNOW")
|
p1.dataReceived(b"please relay " + hexlify(token1) + b"\nNOWNOWNOW")
|
||||||
|
|
||||||
exp = b"impatient\n"
|
exp = b"impatient\n"
|
||||||
yield a1.waitForBytes(len(exp))
|
self.assertEqual(p1.transport.value(), exp)
|
||||||
self.assertEqual(a1.data, exp)
|
|
||||||
|
|
||||||
a1.transport.loseConnection()
|
p1.transport.loseConnection()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def test_impatience_new(self):
|
def test_impatience_new(self):
|
||||||
ep = clientFromString(reactor, self.transit)
|
p1 = self.new_protocol()
|
||||||
a1 = yield connectProtocol(ep, Accumulator())
|
|
||||||
|
|
||||||
token1 = b"\x00"*32
|
token1 = b"\x00"*32
|
||||||
side1 = b"\x01"*8
|
side1 = b"\x01"*8
|
||||||
# sending too many bytes is impatience.
|
# sending too many bytes is impatience.
|
||||||
a1.transport.write(b"please relay " + hexlify(token1) +
|
p1.dataReceived(b"please relay " + hexlify(token1) +
|
||||||
b" for side " + hexlify(side1) + b"\nNOWNOWNOW")
|
b" for side " + hexlify(side1) + b"\nNOWNOWNOW")
|
||||||
|
|
||||||
exp = b"impatient\n"
|
exp = b"impatient\n"
|
||||||
yield a1.waitForBytes(len(exp))
|
self.assertEqual(p1.transport.value(), exp)
|
||||||
self.assertEqual(a1.data, exp)
|
|
||||||
|
|
||||||
a1.transport.loseConnection()
|
p1.transport.loseConnection()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def test_impatience_new_slow(self):
|
def test_impatience_new_slow(self):
|
||||||
ep = clientFromString(reactor, self.transit)
|
p1 = self.new_protocol()
|
||||||
a1 = yield connectProtocol(ep, Accumulator())
|
|
||||||
# For full coverage, we need dataReceived to see a particular framing
|
# For full coverage, we need dataReceived to see a particular framing
|
||||||
# of these two pieces of data, and ITCPTransport doesn't have flush()
|
# of these two pieces of data, and ITCPTransport doesn't have flush()
|
||||||
# (which probably wouldn't work anyways). For now, force a 100ms
|
# (which probably wouldn't work anyways). For now, force a 100ms
|
||||||
|
@ -381,37 +279,27 @@ class _Transit:
|
||||||
token1 = b"\x00"*32
|
token1 = b"\x00"*32
|
||||||
side1 = b"\x01"*8
|
side1 = b"\x01"*8
|
||||||
# sending too many bytes is impatience.
|
# sending too many bytes is impatience.
|
||||||
a1.transport.write(b"please relay " + hexlify(token1) +
|
p1.dataReceived(b"please relay " + hexlify(token1) +
|
||||||
b" for side " + hexlify(side1) + b"\n")
|
b" for side " + hexlify(side1) + b"\n")
|
||||||
|
|
||||||
d = defer.Deferred()
|
|
||||||
reactor.callLater(0.1, d.callback, None)
|
|
||||||
yield d
|
|
||||||
|
|
||||||
a1.transport.write(b"NOWNOWNOW")
|
p1.dataReceived(b"NOWNOWNOW")
|
||||||
|
|
||||||
exp = b"impatient\n"
|
exp = b"impatient\n"
|
||||||
yield a1.waitForBytes(len(exp))
|
self.assertEqual(p1.transport.value(), exp)
|
||||||
self.assertEqual(a1.data, exp)
|
|
||||||
|
|
||||||
a1.transport.loseConnection()
|
p1.transport.loseConnection()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def test_short_handshake(self):
|
def test_short_handshake(self):
|
||||||
ep = clientFromString(reactor, self.transit)
|
p1 = self.new_protocol()
|
||||||
a1 = yield connectProtocol(ep, Accumulator())
|
|
||||||
|
|
||||||
# hang up before sending a complete handshake
|
# hang up before sending a complete handshake
|
||||||
a1.transport.write(b"short")
|
p1.dataReceived(b"short")
|
||||||
a1.transport.loseConnection()
|
p1.transport.loseConnection()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def test_empty_handshake(self):
|
def test_empty_handshake(self):
|
||||||
ep = clientFromString(reactor, self.transit)
|
p1 = self.new_protocol()
|
||||||
a1 = yield connectProtocol(ep, Accumulator())
|
|
||||||
|
|
||||||
# hang up before sending anything
|
# hang up before sending anything
|
||||||
a1.transport.loseConnection()
|
p1.transport.loseConnection()
|
||||||
|
|
||||||
class TransitWithLogs(_Transit, ServerBase, unittest.TestCase):
|
class TransitWithLogs(_Transit, ServerBase, unittest.TestCase):
|
||||||
log_requests = True
|
log_requests = True
|
||||||
|
@ -420,180 +308,113 @@ class TransitWithoutLogs(_Transit, ServerBase, unittest.TestCase):
|
||||||
log_requests = False
|
log_requests = False
|
||||||
|
|
||||||
class Usage(ServerBase, unittest.TestCase):
|
class Usage(ServerBase, unittest.TestCase):
|
||||||
@defer.inlineCallbacks
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
yield super(Usage, self).setUp()
|
super(Usage, self).setUp()
|
||||||
self._usage = []
|
self._usage = []
|
||||||
def record(started, result, total_bytes, total_time, waiting_time):
|
def record(started, result, total_bytes, total_time, waiting_time):
|
||||||
self._usage.append((started, result, total_bytes,
|
self._usage.append((started, result, total_bytes,
|
||||||
total_time, waiting_time))
|
total_time, waiting_time))
|
||||||
self._transit_server.recordUsage = record
|
self._transit_server.recordUsage = record
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def test_empty(self):
|
def test_empty(self):
|
||||||
ep = clientFromString(reactor, self.transit)
|
p1 = self.new_protocol()
|
||||||
a1 = yield connectProtocol(ep, Accumulator())
|
|
||||||
|
|
||||||
# hang up before sending anything
|
# hang up before sending anything
|
||||||
a1.transport.loseConnection()
|
p1.transport.loseConnection()
|
||||||
yield a1._disconnect
|
|
||||||
# give the server a chance to react. in most of the other tests, the
|
|
||||||
# server hangs up on us, so this test needs extra synchronization
|
|
||||||
while len(self._usage) == 0:
|
|
||||||
d = defer.Deferred()
|
|
||||||
reactor.callLater(0.01, d.callback, None)
|
|
||||||
yield d
|
|
||||||
# that will log the "empty" usage event
|
# that will log the "empty" usage event
|
||||||
self.assertEqual(len(self._usage), 1, self._usage)
|
self.assertEqual(len(self._usage), 1, self._usage)
|
||||||
(started, result, total_bytes, total_time, waiting_time) = self._usage[0]
|
(started, result, total_bytes, total_time, waiting_time) = self._usage[0]
|
||||||
self.assertEqual(result, "empty", self._usage)
|
self.assertEqual(result, "empty", self._usage)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def test_short(self):
|
def test_short(self):
|
||||||
ep = clientFromString(reactor, self.transit)
|
p1 = self.new_protocol()
|
||||||
a1 = yield connectProtocol(ep, Accumulator())
|
|
||||||
|
|
||||||
# hang up before sending a complete handshake
|
# hang up before sending a complete handshake
|
||||||
a1.transport.write(b"short")
|
p1.transport.write(b"short")
|
||||||
a1.transport.loseConnection()
|
p1.transport.loseConnection()
|
||||||
yield a1._disconnect
|
|
||||||
# give the server a chance to react. in most of the other tests, the
|
|
||||||
# server hangs up on us, so this test needs extra synchronization
|
|
||||||
while len(self._usage) == 0:
|
|
||||||
d = defer.Deferred()
|
|
||||||
reactor.callLater(0.01, d.callback, None)
|
|
||||||
yield d
|
|
||||||
# that will log the "empty" usage event
|
# that will log the "empty" usage event
|
||||||
self.assertEqual(len(self._usage), 1, self._usage)
|
self.assertEqual(len(self._usage), 1, self._usage)
|
||||||
(started, result, total_bytes, total_time, waiting_time) = self._usage[0]
|
(started, result, total_bytes, total_time, waiting_time) = self._usage[0]
|
||||||
self.assertEqual(result, "empty", self._usage)
|
self.assertEqual(result, "empty", self._usage)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def test_errory(self):
|
def test_errory(self):
|
||||||
ep = clientFromString(reactor, self.transit)
|
p1 = self.new_protocol()
|
||||||
a1 = yield connectProtocol(ep, Accumulator())
|
|
||||||
|
|
||||||
a1.transport.write(b"this is a very bad handshake\n")
|
p1.dataReceived(b"this is a very bad handshake\n")
|
||||||
# that will log the "errory" usage event, then drop the connection
|
# that will log the "errory" usage event, then drop the connection
|
||||||
yield a1._disconnect
|
p1.transport.loseConnection()
|
||||||
self.assertEqual(len(self._usage), 1, self._usage)
|
self.assertEqual(len(self._usage), 1, self._usage)
|
||||||
(started, result, total_bytes, total_time, waiting_time) = self._usage[0]
|
(started, result, total_bytes, total_time, waiting_time) = self._usage[0]
|
||||||
self.assertEqual(result, "errory", self._usage)
|
self.assertEqual(result, "errory", self._usage)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def test_lonely(self):
|
def test_lonely(self):
|
||||||
ep = clientFromString(reactor, self.transit)
|
p1 = self.new_protocol()
|
||||||
a1 = yield connectProtocol(ep, Accumulator())
|
|
||||||
|
|
||||||
token1 = b"\x00"*32
|
token1 = b"\x00"*32
|
||||||
side1 = b"\x01"*8
|
side1 = b"\x01"*8
|
||||||
a1.transport.write(b"please relay " + hexlify(token1) +
|
p1.dataReceived(handshake(token1, side=side1))
|
||||||
b" for side " + hexlify(side1) + b"\n")
|
|
||||||
while not self._transit_server._pending_requests:
|
|
||||||
yield wait() # wait for the server to see the connection
|
|
||||||
# now we disconnect before the peer connects
|
# now we disconnect before the peer connects
|
||||||
a1.transport.loseConnection()
|
p1.transport.loseConnection()
|
||||||
yield a1._disconnect
|
|
||||||
while self._transit_server._pending_requests:
|
|
||||||
yield wait() # wait for the server to see the disconnect too
|
|
||||||
|
|
||||||
self.assertEqual(len(self._usage), 1, self._usage)
|
self.assertEqual(len(self._usage), 1, self._usage)
|
||||||
(started, result, total_bytes, total_time, waiting_time) = self._usage[0]
|
(started, result, total_bytes, total_time, waiting_time) = self._usage[0]
|
||||||
self.assertEqual(result, "lonely", self._usage)
|
self.assertEqual(result, "lonely", self._usage)
|
||||||
self.assertIdentical(waiting_time, None)
|
self.assertIdentical(waiting_time, None)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def test_one_happy_one_jilted(self):
|
def test_one_happy_one_jilted(self):
|
||||||
ep = clientFromString(reactor, self.transit)
|
p1 = self.new_protocol()
|
||||||
a1 = yield connectProtocol(ep, Accumulator())
|
p2 = self.new_protocol()
|
||||||
a2 = yield connectProtocol(ep, Accumulator())
|
|
||||||
|
|
||||||
token1 = b"\x00"*32
|
token1 = b"\x00"*32
|
||||||
side1 = b"\x01"*8
|
side1 = b"\x01"*8
|
||||||
side2 = b"\x02"*8
|
side2 = b"\x02"*8
|
||||||
a1.transport.write(b"please relay " + hexlify(token1) +
|
p1.dataReceived(handshake(token1, side=side1))
|
||||||
b" for side " + hexlify(side1) + b"\n")
|
p2.dataReceived(handshake(token1, side=side2))
|
||||||
while not self._transit_server._pending_requests:
|
|
||||||
yield wait() # make sure a1 connects first
|
self.assertEqual(self._usage, []) # no events yet
|
||||||
a2.transport.write(b"please relay " + hexlify(token1) +
|
|
||||||
b" for side " + hexlify(side2) + b"\n")
|
p1.dataReceived(b"\x00" * 13)
|
||||||
while not self._transit_server._active_connections:
|
p2.dataReceived(b"\xff" * 7)
|
||||||
yield wait() # wait for the server to see the connection
|
|
||||||
self.assertEqual(len(self._transit_server._pending_requests), 0)
|
p1.transport.loseConnection()
|
||||||
self.assertEqual(self._usage, []) # no events yet
|
|
||||||
a1.transport.write(b"\x00" * 13)
|
|
||||||
yield a2.waitForBytes(13)
|
|
||||||
a2.transport.write(b"\xff" * 7)
|
|
||||||
yield a1.waitForBytes(7)
|
|
||||||
|
|
||||||
a1.transport.loseConnection()
|
|
||||||
yield a1._disconnect
|
|
||||||
while self._transit_server._active_connections:
|
|
||||||
yield wait()
|
|
||||||
yield a2._disconnect
|
|
||||||
self.assertEqual(len(self._usage), 1, self._usage)
|
self.assertEqual(len(self._usage), 1, self._usage)
|
||||||
(started, result, total_bytes, total_time, waiting_time) = self._usage[0]
|
(started, result, total_bytes, total_time, waiting_time) = self._usage[0]
|
||||||
self.assertEqual(result, "happy", self._usage)
|
self.assertEqual(result, "happy", self._usage)
|
||||||
self.assertEqual(total_bytes, 20)
|
self.assertEqual(total_bytes, 20)
|
||||||
self.assertNotIdentical(waiting_time, None)
|
self.assertNotIdentical(waiting_time, None)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def test_redundant(self):
|
def test_redundant(self):
|
||||||
ep = clientFromString(reactor, self.transit)
|
p1a = self.new_protocol()
|
||||||
a1a = yield connectProtocol(ep, Accumulator())
|
p1b = self.new_protocol()
|
||||||
a1b = yield connectProtocol(ep, Accumulator())
|
p1c = self.new_protocol()
|
||||||
a1c = yield connectProtocol(ep, Accumulator())
|
p2 = self.new_protocol()
|
||||||
a2 = yield connectProtocol(ep, Accumulator())
|
|
||||||
|
|
||||||
token1 = b"\x00"*32
|
token1 = b"\x00"*32
|
||||||
side1 = b"\x01"*8
|
side1 = b"\x01"*8
|
||||||
side2 = b"\x02"*8
|
side2 = b"\x02"*8
|
||||||
a1a.transport.write(b"please relay " + hexlify(token1) +
|
p1a.dataReceived(handshake(token1, side=side1))
|
||||||
b" for side " + hexlify(side1) + b"\n")
|
p1b.dataReceived(handshake(token1, side=side1))
|
||||||
def count_requests():
|
|
||||||
return sum([len(v)
|
|
||||||
for v in self._transit_server._pending_requests.values()])
|
|
||||||
while count_requests() < 1:
|
|
||||||
yield wait()
|
|
||||||
a1b.transport.write(b"please relay " + hexlify(token1) +
|
|
||||||
b" for side " + hexlify(side1) + b"\n")
|
|
||||||
while count_requests() < 2:
|
|
||||||
yield wait()
|
|
||||||
|
|
||||||
# connect and disconnect a third client (for side1) to exercise the
|
# connect and disconnect a third client (for side1) to exercise the
|
||||||
# code that removes a pending connection without removing the entire
|
# code that removes a pending connection without removing the entire
|
||||||
# token
|
# token
|
||||||
a1c.transport.write(b"please relay " + hexlify(token1) +
|
p1c.dataReceived(handshake(token1, side=side1))
|
||||||
b" for side " + hexlify(side1) + b"\n")
|
p1c.transport.loseConnection()
|
||||||
while count_requests() < 3:
|
|
||||||
yield wait()
|
|
||||||
a1c.transport.loseConnection()
|
|
||||||
yield a1c._disconnect
|
|
||||||
while count_requests() > 2:
|
|
||||||
yield wait()
|
|
||||||
self.assertEqual(len(self._usage), 1, self._usage)
|
self.assertEqual(len(self._usage), 1, self._usage)
|
||||||
(started, result, total_bytes, total_time, waiting_time) = self._usage[0]
|
(started, result, total_bytes, total_time, waiting_time) = self._usage[0]
|
||||||
self.assertEqual(result, "lonely", self._usage)
|
self.assertEqual(result, "lonely", self._usage)
|
||||||
|
|
||||||
a2.transport.write(b"please relay " + hexlify(token1) +
|
p2.dataReceived(handshake(token1, side=side2))
|
||||||
b" for side " + hexlify(side2) + b"\n")
|
self.assertEqual(len(self._transit_server._pending_requests), 0)
|
||||||
# this will claim one of (a1a, a1b), and close the other as redundant
|
|
||||||
while not self._transit_server._active_connections:
|
|
||||||
yield wait() # wait for the server to see the connection
|
|
||||||
self.assertEqual(count_requests(), 0)
|
|
||||||
self.assertEqual(len(self._usage), 2, self._usage)
|
self.assertEqual(len(self._usage), 2, self._usage)
|
||||||
(started, result, total_bytes, total_time, waiting_time) = self._usage[1]
|
(started, result, total_bytes, total_time, waiting_time) = self._usage[1]
|
||||||
self.assertEqual(result, "redundant", self._usage)
|
self.assertEqual(result, "redundant", self._usage)
|
||||||
|
|
||||||
# one of the these is unecessary, but probably harmless
|
# one of the these is unecessary, but probably harmless
|
||||||
a1a.transport.loseConnection()
|
p1a.transport.loseConnection()
|
||||||
a1b.transport.loseConnection()
|
p1b.transport.loseConnection()
|
||||||
yield a1a._disconnect
|
|
||||||
yield a1b._disconnect
|
|
||||||
while self._transit_server._active_connections:
|
|
||||||
yield wait()
|
|
||||||
yield a2._disconnect
|
|
||||||
self.assertEqual(len(self._usage), 3, self._usage)
|
self.assertEqual(len(self._usage), 3, self._usage)
|
||||||
(started, result, total_bytes, total_time, waiting_time) = self._usage[2]
|
(started, result, total_bytes, total_time, waiting_time) = self._usage[2]
|
||||||
self.assertEqual(result, "happy", self._usage)
|
self.assertEqual(result, "happy", self._usage)
|
||||||
|
|
||||||
|
|
|
@ -52,7 +52,10 @@ class TransitConnection(LineReceiver):
|
||||||
def connectionMade(self):
|
def connectionMade(self):
|
||||||
self._started = time.time()
|
self._started = time.time()
|
||||||
self._log_requests = self.factory._log_requests
|
self._log_requests = self.factory._log_requests
|
||||||
|
try:
|
||||||
self.transport.setTcpKeepAlive(True)
|
self.transport.setTcpKeepAlive(True)
|
||||||
|
except AttributeError:
|
||||||
|
pass
|
||||||
|
|
||||||
def lineReceived(self, line):
|
def lineReceived(self, line):
|
||||||
# old: "please relay {64}\n"
|
# old: "please relay {64}\n"
|
||||||
|
@ -257,14 +260,14 @@ class Transit(protocol.ServerFactory):
|
||||||
|
|
||||||
# drop and stop tracking the rest
|
# drop and stop tracking the rest
|
||||||
potentials.remove(old)
|
potentials.remove(old)
|
||||||
for (_, leftover_tc) in potentials:
|
for (_, leftover_tc) in potentials.copy():
|
||||||
# Don't record this as errory. It's just a spare connection
|
# Don't record this as errory. It's just a spare connection
|
||||||
# from the same side as a connection that got used. This
|
# from the same side as a connection that got used. This
|
||||||
# can happen if the connection hint contains multiple
|
# can happen if the connection hint contains multiple
|
||||||
# addresses (we don't currently support those, but it'd
|
# addresses (we don't currently support those, but it'd
|
||||||
# probably be useful in the future).
|
# probably be useful in the future).
|
||||||
leftover_tc.disconnect_redundant()
|
leftover_tc.disconnect_redundant()
|
||||||
self._pending_requests.pop(token)
|
self._pending_requests.pop(token, None)
|
||||||
|
|
||||||
# glue the two ends together
|
# glue the two ends together
|
||||||
self._active_connections.add(new_tc)
|
self._active_connections.add(new_tc)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user