7234e25897
I made the classic dataReceived() mistake, and exited the function after delivering the first record. Keep at it until there are no complete records left.
1206 lines
43 KiB
Python
1206 lines
43 KiB
Python
from __future__ import print_function
|
|
from binascii import hexlify, unhexlify
|
|
from twisted.trial import unittest
|
|
from twisted.internet import defer, task, endpoints, protocol, address, error
|
|
from twisted.internet.defer import gatherResults, inlineCallbacks
|
|
from twisted.python import log, failure
|
|
from twisted.test import proto_helpers
|
|
from ..twisted import transit
|
|
from ..errors import UsageError
|
|
from nacl.secret import SecretBox
|
|
from nacl.exceptions import CryptoError
|
|
|
|
class Highlander(unittest.TestCase):
    """Exercise transit.there_can_be_only_one(), which races Deferreds."""

    def _start_race(self):
        """Create four contenders and hand them to there_can_be_only_one().

        Returns (contenders, cancelled, d, result). `cancelled` collects the
        index of every contender whose canceller ran; `result` collects
        whatever the combined Deferred `d` fires with.
        """
        cancelled = set()
        # Bind `i` as a default argument. A bare `lambda d: cancelled.add(i)`
        # late-binds `i`, so every canceller would record the final index
        # no matter which contender was actually cancelled.
        contenders = [defer.Deferred(lambda d, i=i: cancelled.add(i))
                      for i in range(4)]
        result = []
        d = transit.there_can_be_only_one(contenders)
        d.addBoth(result.append)
        self.assertEqual(result, [])
        self.assertEqual(cancelled, set())
        return contenders, cancelled, d, result

    def test_one_winner(self):
        contenders, cancelled, d, result = self._start_race()
        # failures alone do not resolve the race
        contenders[0].errback(ValueError())
        self.assertEqual(result, [])
        contenders[1].errback(TypeError())
        self.assertEqual(result, [])
        # the first success wins, and the only still-pending contender
        # (index 3) is cancelled
        contenders[2].callback("yay")
        self.assertEqual(result, ["yay"])
        self.assertEqual(cancelled, set([3]))

    def test_there_might_also_be_none(self):
        contenders, cancelled, d, result = self._start_race()
        contenders[0].errback(ValueError())
        self.assertEqual(result, [])
        contenders[1].errback(TypeError())
        self.assertEqual(result, [])
        contenders[2].errback(TypeError())
        self.assertEqual(result, [])
        # only once every contender has failed does the race fail
        contenders[3].errback(NameError())
        self.assertEqual(len(result), 1)
        f = result[0]
        self.assertIsInstance(f.value, ValueError)  # first failure is recorded
        # everything fired on its own, so nothing was cancelled
        self.assertEqual(cancelled, set())

    def test_cancel_early(self):
        contenders, cancelled, d, result = self._start_race()
        d.cancel()
        self.assertEqual(len(result), 1)
        self.assertIsInstance(result[0].value, defer.CancelledError)
        # all four were still pending, so all four are cancelled
        self.assertEqual(cancelled, set(range(4)))

    def test_cancel_after_one_failure(self):
        contenders, cancelled, d, result = self._start_race()
        contenders[0].errback(ValueError())
        d.cancel()
        self.assertEqual(len(result), 1)
        # the earlier failure is what gets reported
        self.assertIsInstance(result[0].value, ValueError)
        # contender 0 already fired, so only the other three are cancelled
        self.assertEqual(cancelled, set([1, 2, 3]))
|
|
|
|
class Forever(unittest.TestCase):
    """Tests for Common._not_forever(), driven by a fake clock."""

    def _forever_setup(self):
        """Wrap a fresh Deferred with a 1.0-second _not_forever() timeout.

        Returns (common, clock, inner, wrapped, cancelled, result), where
        `cancelled` records cancellations of the inner Deferred and `result`
        records what the wrapped Deferred fires with.
        """
        fake_clock = task.Clock()
        common = transit.Common(u"", reactor=fake_clock)
        cancelled, result = [], []
        inner = defer.Deferred(cancelled.append)
        wrapped = common._not_forever(1.0, inner)
        wrapped.addBoth(result.append)
        return common, fake_clock, inner, wrapped, cancelled, result

    def test_not_forever_fires(self):
        # a normal result passes through and leaves no timer behind
        _, clock, _, wrapped, cancelled, result = self._forever_setup()
        self.assertEqual((result, cancelled), ([], []))
        wrapped.callback(1)
        self.assertEqual((result, cancelled), ([1], []))
        self.assertNot(clock.getDelayedCalls())

    def test_not_forever_errs(self):
        # a failure passes through and leaves no timer behind
        _, clock, _, wrapped, cancelled, result = self._forever_setup()
        self.assertEqual((result, cancelled), ([], []))
        wrapped.errback(ValueError())
        self.assertEqual(cancelled, [])
        self.assertEqual(len(result), 1)
        self.assertIsInstance(result[0].value, ValueError)
        self.assertNot(clock.getDelayedCalls())

    def test_not_forever_cancel_early(self):
        # cancelling before the timeout cancels the inner Deferred
        _, clock, inner, wrapped, cancelled, result = self._forever_setup()
        self.assertEqual((result, cancelled), ([], []))
        wrapped.cancel()
        self.assertEqual(cancelled, [inner])
        self.assertEqual(len(result), 1)
        self.assertIsInstance(result[0].value, defer.CancelledError)
        self.assertNot(clock.getDelayedCalls())

    def test_not_forever_timeout(self):
        # letting the clock run past the timeout cancels the inner Deferred
        _, clock, inner, wrapped, cancelled, result = self._forever_setup()
        self.assertEqual((result, cancelled), ([], []))
        clock.advance(2.0)
        self.assertEqual(cancelled, [inner])
        self.assertEqual(len(result), 1)
        self.assertIsInstance(result[0].value, defer.CancelledError)
        self.assertNot(clock.getDelayedCalls())
|
|
|
|
class Misc(unittest.TestCase):
    """Small standalone helpers of the transit module."""

    def test_allocate_port(self):
        # the allocator hands back a port number as a plain int
        self.assertIsInstance(transit.allocate_tcp_port(), int)
|
|
|
|
class Hints(unittest.TestCase):
    """Parsing of connection hints into endpoints."""

    def test_endpoint_from_hint(self):
        common = transit.Common(u"")
        # a well-formed tcp hint becomes a hostname endpoint
        tcp_endpoint = common._endpoint_from_hint("tcp:localhost:1234")
        self.assertIsInstance(tcp_endpoint, endpoints.HostnameEndpoint)
        # unknown hint types and malformed hints yield None
        for unusable in ("unknown:stuff:yowza:pivlor", "tooshort"):
            self.assertEqual(common._endpoint_from_hint(unusable), None)
|
|
|
|
|
|
class Basic(unittest.TestCase):
    """Hint bookkeeping, transit-key handling and winner selection."""

    def test_relay_hints(self):
        relay_url = u"RELAYURL"
        common = transit.Common(relay_url)
        self.assertEqual(common.get_relay_hints(), [relay_url])
        # a non-text relay argument is a usage error
        self.assertRaises(UsageError, transit.Common, 123)

    def test_no_relay_hints(self):
        self.assertEqual(transit.Common(None).get_relay_hints(), [])

    def test_bad_hints(self):
        common = transit.Common(u"")
        # non-text hints are rejected, text hints are accepted
        self.assertRaises(TypeError, common.add_their_direct_hints, [123])
        common.add_their_direct_hints([u"URL"])
        self.assertRaises(TypeError, common.add_their_relay_hints, [123])
        common.add_their_relay_hints([u"URL"])

    def test_transit_key_wait(self):
        key = b"123"
        common = transit.Common(u"")
        seen = []
        # asking before the key is set gives a Deferred that fires later
        common._get_transit_key().addBoth(seen.append)
        self.assertEqual(seen, [])
        common.set_transit_key(key)
        self.assertEqual(seen, [key])

    def test_transit_key_already_set(self):
        key = b"123"
        common = transit.Common(u"")
        common.set_transit_key(key)
        seen = []
        # asking after the key is set fires immediately
        common._get_transit_key().addBoth(seen.append)
        self.assertEqual(seen, [key])

    def test_transit_keys(self):
        key = b"123"
        sender = transit.TransitSender(u"")
        sender.set_transit_key(key)
        receiver = transit.TransitReceiver(u"")
        receiver.set_transit_key(key)

        # each side's handshake is what the other side expects
        self.assertEqual(
            sender._send_this(),
            b"transit sender 559bdeae4b49fa6a23378d2b68f4c7e69378615d4af049c371c6a26e82391089 ready\n\n")
        self.assertEqual(sender._send_this(), receiver._expect_this())

        self.assertEqual(
            receiver._send_this(),
            b"transit receiver ed447528194bac4c00d0c854b12a97ce51413d89aa74d6304475f516fdc23a1b ready\n\n")
        self.assertEqual(receiver._send_this(), sender._expect_this())

        # the record keys are mirrored between the two sides
        self.assertEqual(
            hexlify(sender._sender_record_key()),
            b"5a2fba3a9e524ab2e2823ff53b05f946896f6e4ce4e282ffd8e3ac0e5e9e0cda")
        self.assertEqual(hexlify(sender._sender_record_key()),
                         hexlify(receiver._receiver_record_key()))

        self.assertEqual(
            hexlify(receiver._sender_record_key()),
            b"eedb143117249f45b39da324decf6bd9aae33b7ccd58487436de611a3c6b871d")
        self.assertEqual(hexlify(receiver._sender_record_key()),
                         hexlify(sender._receiver_record_key()))

    def test_connection_ready(self):
        # the sender picks the first ready connection and dismisses the rest
        sender = transit.TransitSender(u"")
        self.assertEqual(sender.connection_ready("p1"), "go")
        self.assertEqual(sender._winner, "p1")
        self.assertEqual(sender.connection_ready("p2"), "nevermind")
        self.assertEqual(sender._winner, "p1")

        # the receiver always defers to the sender's decision
        receiver = transit.TransitReceiver(u"")
        self.assertEqual(receiver.connection_ready("p1"), "wait-for-decision")
        self.assertEqual(receiver.connection_ready("p2"), "wait-for-decision")
|
|
|
|
|
|
class Listener(unittest.TestCase):
    """Tests for the direct-connection listener and its hints."""

    def test_listener(self):
        c = transit.Common(u"")
        hints, ep = c._build_listener()
        self.assertIsInstance(hints, (list, set))
        # Iterate rather than index: a set is an accepted container here but
        # cannot be subscripted, and this way every hint is checked.
        for hint in hints:
            self.assertIsInstance(hint, type(u""))
            self.assertTrue(hint.startswith(u"tcp:"))
        self.assertIsInstance(ep, endpoints.TCP4ServerEndpoint)

    def test_get_direct_hints(self):
        # this actually starts the listener
        c = transit.TransitSender(u"")

        results = []
        d = c.get_direct_hints()
        d.addBoth(results.append)
        # From here on the listener may be running. Stop it even when an
        # assertion below fails, so a failing test does not leak it.
        try:
            self.assertEqual(len(results), 1)
            hints = results[0]

            # the hints are supposed to be cached, so calling this twice
            # won't start a second listener
            self.assertTrue(c._listener)
            results = []
            d = c.get_direct_hints()
            d.addBoth(results.append)
            self.assertEqual(results, [hints])
        finally:
            c._stop_listening()
|
|
|
|
|
|
class DummyProtocol(protocol.Protocol):
    """Protocol that buffers inbound bytes and lets a test wait for them."""

    def __init__(self):
        self.buf = b""
        self._count = None  # byte count the pending wait_for() needs
        self._d2 = None     # fired by connectionLost(), if requested

    def wait_for(self, count):
        """Return a Deferred that fires with the next `count` bytes."""
        if len(self.buf) < count:
            # not enough yet: dataReceived() will fire this later
            self._count = count
            self._d = defer.Deferred()
            return self._d
        data, self.buf = self.buf[:count], self.buf[count:]
        return defer.succeed(data)

    def dataReceived(self, data):
        """Buffer `data`, firing the pending wait_for() once satisfied."""
        self.buf += data
        wanted = self._count
        if wanted is None or len(self.buf) < wanted:
            return
        got, self.buf = self.buf[:wanted], self.buf[wanted:]
        self._count = None
        self._d.callback(got)

    def wait_for_disconnect(self):
        """Return a Deferred that fires with None when the connection drops."""
        self._d2 = defer.Deferred()
        return self._d2

    def connectionLost(self, reason):
        """Notify whoever called wait_for_disconnect(), if anyone did."""
        if not self._d2:
            return
        self._d2.callback(None)
|
|
|
|
class FakeTransport:
    """In-memory transport stand-in that records what a protocol writes."""

    # when False, loseConnection() does not notify the protocol
    signalConnectionLost = True

    def __init__(self, p, peeraddr):
        self._connected = True
        self._buf = b""
        self._peeraddr = peeraddr
        self.protocol = p

    def write(self, data):
        """Append `data` to the captured output."""
        self._buf += data

    def loseConnection(self):
        """Mark the transport closed and, if enabled, tell the protocol."""
        self._connected = False
        if not self.signalConnectionLost:
            return
        self.protocol.connectionLost()

    def getPeer(self):
        """Return the peer address given at construction."""
        return self._peeraddr

    def read_buf(self):
        """Return everything written since the last call, and reset."""
        pending, self._buf = self._buf, b""
        return pending
|
|
|
|
class RandomError(Exception):
    """Arbitrary exception type the tests use to stand in for a bug."""
|
|
|
|
class MockConnection:
    """Stand-in for a connection protocol, recording how a factory uses it."""

    def __init__(self, owner, relay_handshake, start, description):
        self.owner = owner
        self.relay_handshake = relay_handshake
        self.start = start
        self._description = description
        self._start_negotiation_called = False
        self._cancelled = False

        def _note_cancel(d):
            # canceller for the negotiation Deferred
            self._cancelled = True
        self._d = defer.Deferred(_note_cancel)

    def startNegotiation(self):
        """Record the call and return the negotiation Deferred."""
        self._start_negotiation_called = True
        return self._d
|
|
|
|
class InboundConnectionFactory(unittest.TestCase):
    """Tests for transit.InboundConnectionFactory using MockConnection."""

    def _build(self):
        """Return (factory, done, results) for a MockConnection factory.

        `done` is the factory's whenDone() Deferred; `results` collects
        whatever it fires with.
        """
        factory = transit.InboundConnectionFactory("owner")
        factory.protocol = MockConnection
        results = []
        done = factory.whenDone()
        done.addBoth(results.append)
        self.assertEqual(results, [])
        return factory, done, results

    def _two_connections(self, factory):
        """Build two protocols from different ports and mark both as made."""
        first_addr = address.HostnameAddress("example.com", 1234)
        second_addr = address.HostnameAddress("example.com", 5678)
        p1 = factory.buildProtocol(first_addr)
        p2 = factory.buildProtocol(second_addr)
        factory.connectionWasMade(p1)
        factory.connectionWasMade(p2)
        return p1, p2

    def test_describe(self):
        factory = transit.InboundConnectionFactory(None)
        cases = [
            (address.HostnameAddress("example.com", 1234),
             "<-example.com:1234"),
            (address.IPv4Address("TCP", "1.2.3.4", 1234),
             "<-1.2.3.4:1234"),
            (address.IPv6Address("TCP", "::1", 1234),
             "<-::1:1234"),
            (address.UNIXAddress("/dev/unlikely"),
             "<-UNIXAddress('/dev/unlikely')"),
        ]
        for peer, description in cases:
            self.assertEqual(factory._describePeer(peer), description)

    def test_success(self):
        factory, done, results = self._build()

        peer = address.HostnameAddress("example.com", 1234)
        p = factory.buildProtocol(peer)
        self.assertIsInstance(p, MockConnection)
        self.assertEqual(p.owner, "owner")
        self.assertEqual(p.relay_handshake, None)
        self.assertEqual(p._start_negotiation_called, False)
        # meh .start

        # this is normally called from Connection.connectionMade
        factory.connectionWasMade(p)
        self.assertEqual(p._start_negotiation_called, True)
        self.assertEqual(results, [])
        self.assertEqual(p._description, "<-example.com:1234")

        # finishing negotiation resolves whenDone() with the protocol
        p._d.callback(p)
        self.assertEqual(results, [p])

    def test_one_fail_one_success(self):
        factory, done, results = self._build()
        p1, p2 = self._two_connections(factory)
        self.assertEqual(results, [])

        # a failed handshake on one connection does not end the wait
        p1._d.errback(transit.BadHandshake("nope"))
        self.assertEqual(results, [])
        p2._d.callback(p2)
        self.assertEqual(results, [p2])

    def test_first_success_wins(self):
        factory, done, results = self._build()
        p1, p2 = self._two_connections(factory)
        self.assertEqual(results, [])

        # the first success is delivered and the loser is cancelled
        p1._d.callback(p1)
        self.assertEqual(results, [p1])
        self.assertEqual(p1._cancelled, False)
        self.assertEqual(p2._cancelled, True)

    def test_log_other_errors(self):
        factory, done, results = self._build()

        peer = address.HostnameAddress("example.com", 1234)
        p1 = factory.buildProtocol(peer)

        # if the Connection protocol throws an unexpected error, that should
        # get logged to the Twisted logs (as an Unhandled Error in Deferred)
        # so we can diagnose the bug
        factory.connectionWasMade(p1)
        p1._d.errback(RandomError("boom"))
        self.assertEqual(len(results), 0)

        log.msg("=== note: the next RandomError is expected ===")
        # Make sure the Deferred has gone out of scope, so the UnhandledError
        # happens quickly. We must manually break the gc cycle.
        del p1._d
        self.flushLoggedErrors(RandomError)
        log.msg("=== note: the preceding RandomError was expected ===")

    def test_cancel(self):
        factory, done, results = self._build()
        p1, p2 = self._two_connections(factory)
        self.assertEqual(results, [])

        done.cancel()

        # whenDone() fails with CancelledError and both attempts are
        # cancelled
        self.assertEqual(len(results), 1)
        outcome = results[0]
        self.assertIsInstance(outcome, failure.Failure)
        self.assertIsInstance(outcome.value, defer.CancelledError)
        self.assertEqual(p1._cancelled, True)
        self.assertEqual(p2._cancelled, True)

    # XXX check descriptions
|
|
|
|
class OutboundConnectionFactory(unittest.TestCase):
    """Tests for transit.OutboundConnectionFactory using MockConnection."""

    def test_success(self):
        factory = transit.OutboundConnectionFactory(
            "owner", "relay_handshake", "description")
        factory.protocol = MockConnection

        peer = address.HostnameAddress("example.com", 1234)
        p = factory.buildProtocol(peer)
        self.assertIsInstance(p, MockConnection)
        # constructor arguments are passed through to the protocol
        self.assertEqual(p.owner, "owner")
        self.assertEqual(p.relay_handshake, "relay_handshake")
        self.assertEqual(p._start_negotiation_called, False)
        # meh .start

        # this is normally called from Connection.connectionMade
        factory.connectionWasMade(p)  # no-op for outbound
        self.assertEqual(p._start_negotiation_called, False)
|
|
|
|
|
|
class MockOwner:
    """Stand-in owner for a Connection, with canned handshakes and keys.

    Tests assign `_state` before negotiation; connection_ready() returns it.
    """

    _connection_ready_called = False

    def connection_ready(self, connection):
        """Record the connection and answer with the preset `_state`."""
        self._connection = connection
        self._connection_ready_called = True
        return self._state

    def _send_this(self):
        """Handshake bytes this side sends."""
        return b"send_this"

    def _expect_this(self):
        """Handshake bytes this side expects from the peer."""
        return b"expect_this"

    def _sender_record_key(self):
        """32-byte key: b"s" repeated."""
        return b"s" * 32

    def _receiver_record_key(self):
        """32-byte key: b"r" repeated."""
        return b"r" * 32
|
|
|
|
class MockFactory:
    """Stand-in factory that records the connectionWasMade() notification."""

    _connectionWasMade_called = False

    def connectionWasMade(self, p):
        """Remember that we were called, and with which protocol."""
        self._p = p
        self._connectionWasMade_called = True
|
|
|
|
class Connection(unittest.TestCase):
|
|
# exercise the Connection protocol class
|
|
|
|
def test_check_and_remove(self):
    conn = transit.Connection(None, None, None, "description")
    expected = b"expectation"

    # empty buffer: no match yet, nothing consumed
    conn.buf = b""
    self.assertFalse(conn._check_and_remove(expected))
    self.assertEqual(conn.buf, b"")

    # a mismatch raises BadHandshake and leaves the buffer alone
    conn.buf = b"unexpected"
    e = self.assertRaises(transit.BadHandshake,
                          conn._check_and_remove, expected)
    self.assertEqual(str(e),
                     "got %r want %r" % (b'unexpected', b'expectation'))
    self.assertEqual(conn.buf, b"unexpected")

    # a proper prefix is still undecided
    conn.buf = b"expect"
    self.assertFalse(conn._check_and_remove(expected))
    self.assertEqual(conn.buf, b"expect")

    # an exact match is consumed entirely
    conn.buf = b"expectation"
    self.assertTrue(conn._check_and_remove(expected))
    self.assertEqual(conn.buf, b"")

    # a match followed by more data leaves the remainder buffered
    conn.buf = b"expectation exceeded"
    self.assertTrue(conn._check_and_remove(expected))
    self.assertEqual(conn.buf, b" exceeded")
|
|
|
|
def test_sender_accepting(self):
    # sender side: the owner answers "go", so the connection is accepted
    owner = MockOwner()
    factory = MockFactory()
    peer = address.HostnameAddress("example.com", 1234)
    relay_handshake = None
    conn = transit.Connection(owner, relay_handshake, None, "description")
    self.assertEqual(conn.state, "too-early")
    transport = conn.transport = FakeTransport(conn, peer)
    conn.factory = factory
    conn.connectionMade()
    self.assertEqual(factory._connectionWasMade_called, True)
    self.assertEqual(factory._p, conn)

    owner._state = "go"
    negotiation = conn.startNegotiation()
    self.assertEqual(conn.state, "handshake")
    self.assertEqual(transport.read_buf(), b"send_this")
    outcomes = []
    negotiation.addBoth(outcomes.append)
    self.assertEqual(outcomes, [])

    # the peer's handshake arrives: we announce "go" and switch to records
    conn.dataReceived(b"expect_this")
    self.assertEqual(transport.read_buf(), b"go\n")
    self.assertEqual(transport._connected, True)
    self.assertEqual(conn.state, "records")
    self.assertEqual(outcomes, [conn])

    conn.close()
    self.assertEqual(transport._connected, False)
|
|
|
|
def test_sender_rejecting(self):
    # sender side: the owner answers "nevermind", so we hang up
    owner = MockOwner()
    factory = MockFactory()
    peer = address.HostnameAddress("example.com", 1234)
    relay_handshake = None
    conn = transit.Connection(owner, relay_handshake, None, "description")
    self.assertEqual(conn.state, "too-early")
    transport = conn.transport = FakeTransport(conn, peer)
    conn.factory = factory
    conn.connectionMade()
    self.assertEqual(factory._connectionWasMade_called, True)
    self.assertEqual(factory._p, conn)

    owner._state = "nevermind"
    negotiation = conn.startNegotiation()
    self.assertEqual(conn.state, "handshake")
    self.assertEqual(transport.read_buf(), b"send_this")
    outcomes = []
    negotiation.addBoth(outcomes.append)
    self.assertEqual(outcomes, [])

    # the peer's handshake arrives: we decline and drop the connection
    conn.dataReceived(b"expect_this")
    self.assertEqual(transport.read_buf(), b"nevermind\n")
    self.assertEqual(transport._connected, False)
    self.assertEqual(conn.state, "hung up")
    self.assertEqual(len(outcomes), 1)
    fail = outcomes[0]
    self.assertIsInstance(fail, failure.Failure)
    self.assertIsInstance(fail.value, transit.BadHandshake)
    self.assertEqual(str(fail.value), "abandoned")
|
|
|
|
def test_handshake_other_error(self):
    # an unexpected exception during the handshake hangs up and is
    # delivered through the negotiation Deferred
    owner = MockOwner()
    factory = MockFactory()
    peer = address.HostnameAddress("example.com", 1234)
    conn = transit.Connection(owner, None, None, "description")
    self.assertEqual(conn.state, "too-early")
    transport = conn.transport = FakeTransport(conn, peer)
    conn.factory = factory
    conn.connectionMade()
    self.assertEqual(factory._connectionWasMade_called, True)
    self.assertEqual(factory._p, conn)

    negotiation = conn.startNegotiation()
    self.assertEqual(conn.state, "handshake")
    self.assertEqual(transport.read_buf(), b"send_this")
    outcomes = []
    negotiation.addBoth(outcomes.append)
    self.assertEqual(outcomes, [])

    # replace the state with an exception instance, then feed data
    conn.state = RandomError("boom")
    self.assertRaises(RandomError, conn.dataReceived, b"surprise!")
    self.assertEqual(transport._connected, False)
    self.assertEqual(conn.state, "hung up")
    self.assertEqual(len(outcomes), 1)
    fail = outcomes[0]
    self.assertIsInstance(fail, failure.Failure)
    self.assertIsInstance(fail.value, RandomError)
|
|
|
|
def test_relay_handshake(self):
    # with a relay handshake configured, it is exchanged before the
    # normal handshake
    owner = MockOwner()
    factory = MockFactory()
    peer = address.HostnameAddress("example.com", 1234)
    relay_handshake = b"relay handshake"
    conn = transit.Connection(owner, relay_handshake, None, "description")
    self.assertEqual(conn.state, "too-early")
    transport = conn.transport = FakeTransport(conn, peer)
    conn.factory = factory
    conn.connectionMade()
    self.assertEqual(factory._connectionWasMade_called, True)
    self.assertEqual(factory._p, conn)
    # quiet until startNegotiation
    self.assertEqual(transport.read_buf(), b"")

    owner._state = "go"
    negotiation = conn.startNegotiation()
    self.assertEqual(transport.read_buf(), relay_handshake)
    # waiting for OK from relay
    self.assertEqual(conn.state, "relay")

    # the relay says ok: the normal handshake goes out
    conn.dataReceived(b"ok\n")
    self.assertEqual(transport.read_buf(), b"send_this")
    self.assertEqual(conn.state, "handshake")

    outcomes = []
    negotiation.addBoth(outcomes.append)
    self.assertEqual(outcomes, [])

    conn.dataReceived(b"expect_this")
    self.assertEqual(conn.state, "records")
    self.assertEqual(outcomes, [conn])

    self.assertEqual(transport.read_buf(), b"go\n")
|
|
|
|
def test_relay_handshake_bad(self):
    # anything other than "ok\n" from the relay is a bad handshake
    owner = MockOwner()
    factory = MockFactory()
    peer = address.HostnameAddress("example.com", 1234)
    relay_handshake = b"relay handshake"
    conn = transit.Connection(owner, relay_handshake, None, "description")
    self.assertEqual(conn.state, "too-early")
    transport = conn.transport = FakeTransport(conn, peer)
    conn.factory = factory
    conn.connectionMade()
    self.assertEqual(factory._connectionWasMade_called, True)
    self.assertEqual(factory._p, conn)
    # quiet until startNegotiation
    self.assertEqual(transport.read_buf(), b"")

    owner._state = "go"
    negotiation = conn.startNegotiation()
    self.assertEqual(transport.read_buf(), relay_handshake)
    # waiting for OK from relay
    self.assertEqual(conn.state, "relay")

    conn.dataReceived(b"not ok\n")
    self.assertEqual(transport._connected, False)
    self.assertEqual(conn.state, "hung up")

    # the Deferred has already failed by the time we attach to it
    outcomes = []
    negotiation.addBoth(outcomes.append)
    self.assertEqual(len(outcomes), 1)
    fail = outcomes[0]
    self.assertIsInstance(fail, failure.Failure)
    self.assertIsInstance(fail.value, transit.BadHandshake)
    self.assertEqual(str(fail.value),
                     "got %r want %r" % (b"not ok\n", b"ok\n"))
|
|
|
|
def test_receiver_accepted(self):
    # we're on the receiving side, so we wait for the sender to decide
    owner = MockOwner()
    factory = MockFactory()
    peer = address.HostnameAddress("example.com", 1234)
    conn = transit.Connection(owner, None, None, "description")
    self.assertEqual(conn.state, "too-early")
    transport = conn.transport = FakeTransport(conn, peer)
    conn.factory = factory
    conn.connectionMade()
    self.assertEqual(factory._connectionWasMade_called, True)
    self.assertEqual(factory._p, conn)

    owner._state = "wait-for-decision"
    negotiation = conn.startNegotiation()
    self.assertEqual(conn.state, "handshake")
    self.assertEqual(transport.read_buf(), b"send_this")
    outcomes = []
    negotiation.addBoth(outcomes.append)
    self.assertEqual(outcomes, [])

    # handshake done, but nothing resolves until the sender decides
    conn.dataReceived(b"expect_this")
    self.assertEqual(conn.state, "wait-for-decision")
    self.assertEqual(outcomes, [])

    # the sender picks this connection
    conn.dataReceived(b"go\n")
    self.assertEqual(conn.state, "records")
    self.assertEqual(outcomes, [conn])
|
|
|
|
def test_receiver_rejected_politely(self):
    # we're on the receiving side, so we wait for the sender to decide
    owner = MockOwner()
    factory = MockFactory()
    peer = address.HostnameAddress("example.com", 1234)
    conn = transit.Connection(owner, None, None, "description")
    self.assertEqual(conn.state, "too-early")
    transport = conn.transport = FakeTransport(conn, peer)
    conn.factory = factory
    conn.connectionMade()
    self.assertEqual(factory._connectionWasMade_called, True)
    self.assertEqual(factory._p, conn)

    owner._state = "wait-for-decision"
    negotiation = conn.startNegotiation()
    self.assertEqual(conn.state, "handshake")
    self.assertEqual(transport.read_buf(), b"send_this")
    outcomes = []
    negotiation.addBoth(outcomes.append)
    self.assertEqual(outcomes, [])

    conn.dataReceived(b"expect_this")
    self.assertEqual(conn.state, "wait-for-decision")
    self.assertEqual(outcomes, [])

    # polite rejection: the sender tells us it chose another connection
    conn.dataReceived(b"nevermind\n")
    self.assertEqual(transport._connected, False)
    self.assertEqual(conn.state, "hung up")
    self.assertEqual(len(outcomes), 1)
    fail = outcomes[0]
    self.assertIsInstance(fail, failure.Failure)
    self.assertIsInstance(fail.value, transit.BadHandshake)
    self.assertEqual(str(fail.value),
                     "got %r want %r" % (b"nevermind\n", b"go\n"))
|
|
|
|
def test_receiver_rejected_rudely(self):
    # we're on the receiving side, so we wait for the sender to decide
    owner = MockOwner()
    factory = MockFactory()
    peer = address.HostnameAddress("example.com", 1234)
    conn = transit.Connection(owner, None, None, "description")
    self.assertEqual(conn.state, "too-early")
    transport = conn.transport = FakeTransport(conn, peer)
    conn.factory = factory
    conn.connectionMade()
    self.assertEqual(factory._connectionWasMade_called, True)
    self.assertEqual(factory._p, conn)

    owner._state = "wait-for-decision"
    negotiation = conn.startNegotiation()
    self.assertEqual(conn.state, "handshake")
    self.assertEqual(transport.read_buf(), b"send_this")
    outcomes = []
    negotiation.addBoth(outcomes.append)
    self.assertEqual(outcomes, [])

    conn.dataReceived(b"expect_this")
    self.assertEqual(conn.state, "wait-for-decision")
    self.assertEqual(outcomes, [])

    # rude rejection: the connection just drops without a decision
    transport.loseConnection()
    self.assertEqual(transport._connected, False)
    self.assertEqual(len(outcomes), 1)
    fail = outcomes[0]
    self.assertIsInstance(fail, failure.Failure)
    self.assertIsInstance(fail.value, transit.BadHandshake)
    self.assertEqual(str(fail.value), "connection lost")
|
|
|
|
|
|
def test_cancel(self):
    # cancelling the negotiation Deferred hangs up the connection
    owner = MockOwner()
    factory = MockFactory()
    peer = address.HostnameAddress("example.com", 1234)
    conn = transit.Connection(owner, None, None, "description")
    self.assertEqual(conn.state, "too-early")
    transport = conn.transport = FakeTransport(conn, peer)
    conn.factory = factory
    conn.connectionMade()

    negotiation = conn.startNegotiation()
    outcomes = []
    negotiation.addBoth(outcomes.append)
    # while we're waiting for negotiation, we get cancelled
    negotiation.cancel()

    self.assertEqual(transport._connected, False)
    self.assertEqual(conn.state, "hung up")
    self.assertEqual(len(outcomes), 1)
    fail = outcomes[0]
    self.assertIsInstance(fail, failure.Failure)
    self.assertIsInstance(fail.value, defer.CancelledError)
|
|
|
|
def test_timeout(self):
    # a negotiation that takes longer than transit.TIMEOUT is abandoned
    fake_clock = task.Clock()
    owner = MockOwner()
    factory = MockFactory()
    peer = address.HostnameAddress("example.com", 1234)
    conn = transit.Connection(owner, None, None, "description")

    def _schedule(period, func):
        # route the connection's timer through the fake clock; like the
        # hook it replaces in this test, it returns nothing
        fake_clock.callLater(period, func)
    conn.callLater = _schedule

    self.assertEqual(conn.state, "too-early")
    transport = conn.transport = FakeTransport(conn, peer)
    conn.factory = factory
    conn.connectionMade()
    # the timer should now be running
    negotiation = conn.startNegotiation()
    outcomes = []
    negotiation.addBoth(outcomes.append)
    # while we're waiting for negotiation, the timer expires
    fake_clock.advance(transit.TIMEOUT + 1.0)

    self.assertEqual(transport._connected, False)
    self.assertEqual(len(outcomes), 1)
    fail = outcomes[0]
    self.assertIsInstance(fail, failure.Failure)
    self.assertIsInstance(fail.value, transit.BadHandshake)
    self.assertEqual(str(fail.value), "timeout")
|
|
|
|
def make_connection(self):
    """Build a Connection that has finished negotiating as the sender.

    Returns (transport, connection, owner). The transport's captured
    output is drained first, so callers only see encrypted records.
    """
    owner = MockOwner()
    peer = address.HostnameAddress("example.com", 1234)
    conn = transit.Connection(owner, None, None, "description")
    transport = conn.transport = FakeTransport(conn, peer)
    conn.factory = MockFactory()
    conn.connectionMade()

    owner._state = "go"
    outcomes = []
    conn.startNegotiation().addBoth(outcomes.append)
    conn.dataReceived(b"expect_this")
    self.assertEqual(outcomes, [conn])
    # discard the handshake output, prepare for encrypted records
    transport.read_buf()

    return transport, conn, owner
|
|
|
|
def test_records_good(self):
    # now make sure that outbound records are encrypted properly
    t, c, owner = self.make_connection()

    def check_outbound(record, expected_nonce):
        # send one record and verify its framing, nonce and plaintext
        c.send_record(record)
        wire = t.read_buf()
        length_prefix, boxed = wire[:4], wire[4:]
        expected = ("%08x" % (24+len(record)+16)).encode("ascii")
        self.assertEqual(hexlify(length_prefix), expected)
        receive_box = SecretBox(owner._sender_record_key())
        nonce_buf = boxed[:SecretBox.NONCE_SIZE]  # assume it's prepended
        self.assertEqual(int(hexlify(nonce_buf), 16), expected_nonce)
        self.assertEqual(receive_box.decrypt(boxed), record)

    check_outbound(b"record", 0)   # first message gets nonce 0
    check_outbound(b"record2", 1)  # second message gets nonce 1

    # and that we can receive records properly
    inbound_records = []
    c.recordReceived = inbound_records.append
    send_box = SecretBox(owner._receiver_record_key())

    def seal(record, nonce):
        # return (length prefix, encrypted record) as the peer would send
        boxed = send_box.encrypt(record, unhexlify("%048x" % nonce))
        length = unhexlify("%08x" % len(boxed))  # always 4 bytes long
        return length, boxed

    def deliver_in_pieces(record, nonce, already):
        # dribble the record in; it must only appear once it is complete
        length, boxed = seal(record, nonce)
        c.dataReceived(length[:2])
        c.dataReceived(length[2:])
        c.dataReceived(boxed[:-2])
        self.assertEqual(inbound_records, already)
        c.dataReceived(boxed[-2:])
        self.assertEqual(inbound_records, already + [record])

    deliver_in_pieces(b"record3", 0, [])            # first nonce must be 0
    deliver_in_pieces(b"record4", 1, [b"record3"])  # nonces increment

    # receiving two records at the same time: deliver both
    inbound_records[:] = []
    length5, boxed5 = seal(b"record5", 2)
    length6, boxed6 = seal(b"record6", 3)
    c.dataReceived((length5 + boxed5) + (length6 + boxed6))
    self.assertEqual(inbound_records, [b"record5", b"record6"])
|
|
|
|
def corrupt(self, orig):
    """Return a copy of `orig` whose final byte is guaranteed to differ.

    Every bit of the last byte is flipped. The previous `256 - num`
    arithmetic left 0x80 unchanged (so nothing was corrupted) and produced
    the out-of-range value 256 for 0x00.

    Raises ValueError when `orig` is empty, since there is no byte to
    corrupt.
    """
    if not orig:
        raise ValueError("cannot corrupt an empty byte string")
    last = bytearray(orig[-1:])
    last[0] ^= 0xFF  # x ^ 0xFF never equals x and stays within 0..255
    return orig[:-1] + bytes(last)
|
|
|
|
def test_records_corrupt(self):
    # corrupt records should be rejected
    t, c, owner = self.make_connection()

    inbound_records = []
    c.recordReceived = inbound_records.append

    record = b"record"
    send_box = SecretBox(owner._receiver_record_key())
    first_nonce = unhexlify("%048x" % 0)  # first nonce must be 0
    damaged = self.corrupt(send_box.encrypt(record, first_nonce))
    length = unhexlify("%08x" % len(damaged))  # always 4 bytes long
    c.dataReceived(length)
    c.dataReceived(damaged[:-2])
    self.assertEqual(inbound_records, [])
    # the final bytes complete the record, and decryption fails
    self.assertRaises(CryptoError, c.dataReceived, damaged[-2:])
    self.assertEqual(inbound_records, [])
    # and the connection should have been dropped
    self.assertEqual(t._connected, False)
|
|
|
|
def test_out_of_order_nonce(self):
    # An inbound record carrying an out-of-order nonce must be rejected:
    # dataReceived raises transit.BadNonce, nothing is delivered, and the
    # transport ends up disconnected.
    transport, conn, owner = self.make_connection()

    delivered = []
    conn.recordReceived = delivered.append

    # The first nonce must be 0, so a first record sent with nonce 1 is
    # out of order.
    box = SecretBox(owner._receiver_record_key())
    wrong_nonce = unhexlify("%048x" % 1)
    sealed = box.encrypt(b"record", wrong_nonce)
    header = unhexlify("%08x" % len(sealed))  # always 4 bytes long
    body_head, body_tail = sealed[:-2], sealed[-2:]

    # Nothing is delivered while the record is still incomplete.
    conn.dataReceived(header)
    conn.dataReceived(body_head)
    self.assertEqual(delivered, [])

    # The last two bytes complete the record, which is then rejected.
    self.assertRaises(transit.BadNonce, conn.dataReceived, body_tail)
    self.assertEqual(delivered, [])
    # and the connection should have been dropped
    self.assertEqual(transport._connected, False)
|
|
|
|
# TODO: check that .connectionLost/loseConnection signatures are
|
|
# consistent: zero args, or one arg?
|
|
|
|
# XXX: if we don't set the transit key before connecting, what
|
|
# happens? We currently get a type-check assertion from HKDF because
|
|
# the key is None.
|
|
|
|
def test_receive_queue(self):
    # receive_record() hands out records strictly in arrival order,
    # whether the record or the request for it shows up first, and
    # close() fails any request that is still outstanding.
    conn = transit.Connection(None, None, None, "description")
    conn.transport = FakeTransport(conn, None)
    conn.transport.signalConnectionLost = False
    slots = [[] for _ in range(5)]

    # Three records arrive before anybody asks for one.
    for early in (b"0", b"1", b"2"):
        conn.recordReceived(early)

    conn.receive_record().addBoth(slots[0].append)
    self.assertEqual(slots[0], [b"0"])

    second = conn.receive_record()
    third = conn.receive_record()
    # they must fire in order of receipt, not order of addCallback
    third.addBoth(slots[2].append)
    self.assertEqual(slots[2], [b"2"])
    second.addBoth(slots[1].append)
    self.assertEqual(slots[1], [b"1"])

    # Now the requests come first and have to wait for data.
    conn.receive_record().addBoth(slots[3].append)
    conn.receive_record().addBoth(slots[4].append)
    self.assertEqual(slots[3], [])
    self.assertEqual(slots[4], [])

    conn.recordReceived(b"3")
    self.assertEqual(slots[3], [b"3"])
    self.assertEqual(slots[4], [])

    conn.recordReceived(b"4")
    self.assertEqual(slots[3], [b"3"])
    self.assertEqual(slots[4], [b"4"])

    # A request still pending at close() fails with ConnectionClosed.
    at_close = []
    conn.receive_record().addBoth(at_close.append)
    conn.close()
    self.assertEqual(len(at_close), 1)
    outcome = at_close[0]
    self.assertIsInstance(outcome, failure.Failure)
    self.assertIsInstance(outcome.value, error.ConnectionClosed)
|
|
|
|
def test_producer(self):
    # a Transit object (receiving data from the remote peer) produces
    # data and writes it into a local Consumer
    conn = transit.Connection(None, None, None, "description")
    conn.transport = proto_helpers.StringTransport()
    for early in (b"r1.", b"r2."):
        conn.recordReceived(early)

    # Records that arrived before the consumer was attached show up in
    # it as soon as it is connected.
    sink = proto_helpers.StringTransport()
    conn.connectConsumer(sink)
    self.assertIs(conn._consumer, sink)
    self.assertEqual(sink.value(), b"r1.r2.")

    # A second connectConsumer() while one is attached is an error.
    self.assertRaises(RuntimeError, conn.connectConsumer, sink)

    # Later records are written straight through.
    conn.recordReceived(b"r3.")
    self.assertEqual(sink.value(), b"r1.r2.r3.")

    # Flow-control calls are reflected in the transport's producerState.
    conn.pauseProducing()
    self.assertEqual(conn.transport.producerState, "paused")
    conn.resumeProducing()
    self.assertEqual(conn.transport.producerState, "producing")

    # After disconnecting, the same consumer can be attached again.
    conn.disconnectConsumer()
    self.assertEqual(sink.producer, None)
    conn.connectConsumer(sink)

    conn.stopProducing()
    self.assertEqual(conn.transport.producerState, "stopped")
|
|
|
|
def test_consumer(self):
    # a local producer sends data to a consuming Transit object
    conn = transit.Connection(None, None, None, "description")
    conn.transport = proto_helpers.StringTransport()
    # Capture outgoing records instead of really sending them.
    sent = []
    conn.send_record = sent.append

    # Registering a producer on the Connection registers it on the
    # underlying transport.
    source = proto_helpers.StringTransport()
    conn.registerProducer(source, True)
    self.assertIs(conn.transport.producer, source)

    # write() turns the data into one outgoing record.
    conn.write(b"r1.")
    self.assertEqual(sent, [b"r1."])

    conn.unregisterProducer()
    self.assertEqual(conn.transport.producer, None)
|
|
|
|
|
|
# Hint strings handed to TransitSender in the tests below.
# Transit._endpoint_from_hint maps them to the fake endpoints "direct",
# "relay" and None respectively.
DIRECT_HINT = u"tcp:direct:1234"
RELAY_HINT = u"tcp:relay:1234"
UNUSABLE_HINT = u"unusable:foo:bar"
|
|
|
|
class Transit(unittest.TestCase):
    # Tests of TransitSender.connect() in which the connector machinery
    # is replaced by fakes, so that each test decides by hand which
    # connection attempt "wins".

    @inlineCallbacks
    def test_success_direct(self):
        clock = task.Clock()
        sender = transit.TransitSender(u"", reactor=clock)
        sender.set_transit_key(b"key")
        yield sender.get_direct_hints()  # start the listener
        sender.add_their_direct_hints([DIRECT_HINT, UNUSABLE_HINT])
        sender.add_their_relay_hints([])

        # Every connection attempt is recorded as an unfired Deferred.
        attempts = []

        def fake_start_connector(ep, description, is_relay=False):
            pending = defer.Deferred()
            attempts.append(pending)
            return pending
        sender._start_connector = fake_start_connector

        outcomes = []
        sender.connect().addBoth(outcomes.append)
        self.assertEqual(outcomes, [])
        self.assertEqual(len(attempts), 1)
        self.assertIsInstance(attempts[0], defer.Deferred)

        attempts[0].callback("winner")
        self.assertEqual(outcomes, ["winner"])

    def _endpoint_from_hint(self, hint):
        # Stand-in for the sender's hint parser: known hints map to
        # marker values instead of real endpoints.
        if hint == DIRECT_HINT:
            return "direct"
        if hint == RELAY_HINT:
            return "relay"
        if hint == UNUSABLE_HINT:
            return None
        return "ep"

    def _install_fake_connectors(self, sender):
        # Swap the sender's hint parsing and connector startup for fakes.
        # Returns (direct, relay): two lists that collect the Deferred
        # created for each connection attempt of that kind.
        direct, relay = [], []

        def fake_start_connector(ep, description, is_relay=False):
            pending = defer.Deferred()
            if ep == "direct":
                direct.append(pending)
                return pending
            if ep == "relay":
                relay.append(pending)
                return pending
            raise ValueError

        sender._endpoint_from_hint = self._endpoint_from_hint
        sender._start_connector = fake_start_connector
        return direct, relay

    @inlineCallbacks
    def test_wait_for_relay(self):
        clock = task.Clock()
        sender = transit.TransitSender(u"", reactor=clock)
        sender.set_transit_key(b"key")
        yield sender.get_direct_hints()  # start the listener
        sender.add_their_direct_hints([DIRECT_HINT, UNUSABLE_HINT])
        sender.add_their_relay_hints([RELAY_HINT])
        direct, relay = self._install_fake_connectors(sender)

        outcomes = []
        sender.connect().addBoth(outcomes.append)
        self.assertEqual(outcomes, [])
        # the direct connectors are tried right away, but the relay
        # connectors are stalled for a few seconds
        self.assertEqual(len(direct), 1)
        self.assertEqual(len(relay), 0)

        clock.advance(sender.RELAY_DELAY + 1.0)
        self.assertEqual(len(direct), 1)
        self.assertEqual(len(relay), 1)

        direct[0].callback("winner")
        self.assertEqual(outcomes, ["winner"])

    @inlineCallbacks
    def test_no_direct_hints(self):
        clock = task.Clock()
        sender = transit.TransitSender(u"", reactor=clock)
        sender.set_transit_key(b"key")
        yield sender.get_direct_hints()  # start the listener
        sender.add_their_direct_hints([UNUSABLE_HINT])
        sender.add_their_relay_hints([RELAY_HINT, UNUSABLE_HINT])
        direct, relay = self._install_fake_connectors(sender)

        outcomes = []
        sender.connect().addBoth(outcomes.append)
        self.assertEqual(outcomes, [])
        # since there are no usable direct hints, the relay connector will
        # only be stalled for 0 seconds
        self.assertEqual(len(direct), 0)
        self.assertEqual(len(relay), 0)

        clock.advance(0)
        self.assertEqual(len(direct), 0)
        self.assertEqual(len(relay), 1)

        relay[0].callback("winner")
        self.assertEqual(outcomes, ["winner"])
|
|
|
|
|
|
class Full(unittest.TestCase):
    # End-to-end test: a TransitSender and a TransitReceiver exchange
    # direct hints, connect to each other, and pass one record across.

    def doBoth(self, d1, d2):
        # Wait for both Deferreds and fire with the list of their results.
        return gatherResults([d1, d2], consumeErrors=True)

    @inlineCallbacks
    def test_full(self):
        shared_key = b"k"*32
        sender = transit.TransitSender(None)
        receiver = transit.TransitReceiver(None)

        sender.set_transit_key(shared_key)
        receiver.set_transit_key(shared_key)

        # Each side learns where the other one is listening.
        sender_hints = yield sender.get_direct_hints()
        receiver_hints = yield receiver.get_direct_hints()

        sender.add_their_direct_hints(receiver_hints)
        receiver.add_their_direct_hints(sender_hints)

        # No relays: only the direct hints are available.
        sender.add_their_relay_hints([])
        receiver.add_their_relay_hints([])

        both = self.doBoth(sender.connect(), receiver.connect())
        sender_conn, receiver_conn = yield both
        self.assertIsInstance(sender_conn, transit.Connection)
        self.assertIsInstance(receiver_conn, transit.Connection)

        # Ask for the record first, then send it from the other side.
        pending = receiver_conn.receive_record()

        sender_conn.send_record(b"record1")
        record = yield pending
        self.assertEqual(record, b"record1")

        yield sender_conn.close()
        yield receiver_conn.close()
|