from __future__ import print_function from binascii import hexlify, unhexlify from twisted.trial import unittest from twisted.internet import defer, task, endpoints, protocol, address from twisted.internet.defer import gatherResults, inlineCallbacks from twisted.python import log, failure from ..twisted import transit from ..errors import UsageError from nacl.secret import SecretBox from nacl.exceptions import CryptoError class Highlander(unittest.TestCase): def test_one_winner(self): cancelled = set() contenders = [defer.Deferred(lambda d: cancelled.add(i)) for i in range(4)] result = [] d = transit.there_can_be_only_one(contenders) d.addBoth(result.append) self.assertEqual(result, []) contenders[0].errback(ValueError()) self.assertEqual(result, []) contenders[1].errback(TypeError()) self.assertEqual(result, []) contenders[2].callback("yay") self.assertEqual(result, ["yay"]) self.assertEqual(cancelled, set([3])) def test_there_might_also_be_none(self): cancelled = set() contenders = [defer.Deferred(lambda d: cancelled.add(i)) for i in range(4)] result = [] d = transit.there_can_be_only_one(contenders) d.addBoth(result.append) self.assertEqual(result, []) contenders[0].errback(ValueError()) self.assertEqual(result, []) contenders[1].errback(TypeError()) self.assertEqual(result, []) contenders[2].errback(TypeError()) self.assertEqual(result, []) contenders[3].errback(NameError()) self.assertEqual(len(result), 1) f = result[0] self.assertIsInstance(f.value, ValueError) # first failure is recorded self.assertEqual(cancelled, set()) def test_cancel_early(self): cancelled = set() contenders = [defer.Deferred(lambda d, i=i: cancelled.add(i)) for i in range(4)] result = [] d = transit.there_can_be_only_one(contenders) d.addBoth(result.append) self.assertEqual(result, []) self.assertEqual(cancelled, set()) d.cancel() self.assertEqual(len(result), 1) self.assertIsInstance(result[0].value, defer.CancelledError) self.assertEqual(cancelled, set(range(4))) def test_cancel_after_one_failure(self): cancelled = set() contenders = [defer.Deferred(lambda d, i=i: cancelled.add(i)) for i in range(4)] result = [] d = transit.there_can_be_only_one(contenders) d.addBoth(result.append) self.assertEqual(result, []) self.assertEqual(cancelled, set()) contenders[0].errback(ValueError()) d.cancel() self.assertEqual(len(result), 1) self.assertIsInstance(result[0].value, ValueError) self.assertEqual(cancelled, set([1,2,3])) class Forever(unittest.TestCase): def _forever_setup(self): clock = task.Clock() c = transit.Common(u"", reactor=clock) cancelled = [] result = [] d0 = defer.Deferred(cancelled.append) d = c._not_forever(1.0, d0) d.addBoth(result.append) return c, clock, d0, d, cancelled, result def test_not_forever_fires(self): c, clock, d0, d, cancelled, result = self._forever_setup() self.assertEqual((result, cancelled), ([], [])) d.callback(1) self.assertEqual((result, cancelled), ([1], [])) self.assertNot(clock.getDelayedCalls()) def test_not_forever_errs(self): c, clock, d0, d, cancelled, result = self._forever_setup() self.assertEqual((result, cancelled), ([], [])) d.errback(ValueError()) self.assertEqual(cancelled, []) self.assertEqual(len(result), 1) self.assertIsInstance(result[0].value, ValueError) self.assertNot(clock.getDelayedCalls()) def test_not_forever_cancel_early(self): c, clock, d0, d, cancelled, result = self._forever_setup() self.assertEqual((result, cancelled), ([], [])) d.cancel() self.assertEqual(cancelled, [d0]) self.assertEqual(len(result), 1) self.assertIsInstance(result[0].value, defer.CancelledError) self.assertNot(clock.getDelayedCalls()) def test_not_forever_timeout(self): c, clock, d0, d, cancelled, result = self._forever_setup() self.assertEqual((result, cancelled), ([], [])) clock.advance(2.0) self.assertEqual(cancelled, [d0]) self.assertEqual(len(result), 1) self.assertIsInstance(result[0].value, defer.CancelledError) self.assertNot(clock.getDelayedCalls()) class Misc(unittest.TestCase): def test_allocate_port(self): portno = transit.allocate_tcp_port() self.assertIsInstance(portno, int) class Hints(unittest.TestCase): def test_endpoint_from_hint(self): c = transit.Common(u"") ep = c._endpoint_from_hint("tcp:localhost:1234") self.assertIsInstance(ep, endpoints.HostnameEndpoint) ep = c._endpoint_from_hint("unknown:stuff:yowza:pivlor") self.assertEqual(ep, None) ep = c._endpoint_from_hint("tooshort") self.assertEqual(ep, None) class Basic(unittest.TestCase): def test_relay_hints(self): URL = u"RELAYURL" c = transit.Common(URL) self.assertEqual(c.get_relay_hints(), [URL]) self.assertRaises(UsageError, transit.Common, 123) def test_no_relay_hints(self): c = transit.Common(None) self.assertEqual(c.get_relay_hints(), []) def test_bad_hints(self): c = transit.Common(u"") self.assertRaises(TypeError, c.add_their_direct_hints, [123]) c.add_their_direct_hints([u"URL"]) self.assertRaises(TypeError, c.add_their_relay_hints, [123]) c.add_their_relay_hints([u"URL"]) def test_transit_key_wait(self): KEY = b"123" c = transit.Common(u"") results = [] d = c._get_transit_key() d.addBoth(results.append) self.assertEqual(results, []) c.set_transit_key(KEY) self.assertEqual(results, [KEY]) def test_transit_key_already_set(self): KEY = b"123" c = transit.Common(u"") c.set_transit_key(KEY) results = [] d = c._get_transit_key() d.addBoth(results.append) self.assertEqual(results, [KEY]) def test_transit_keys(self): KEY = b"123" s = transit.TransitSender(u"") s.set_transit_key(KEY) r = transit.TransitReceiver(u"") r.set_transit_key(KEY) self.assertEqual(s._send_this(), b"transit sender 559bdeae4b49fa6a23378d2b68f4c7e69378615d4af049c371c6a26e82391089 ready\n\n") self.assertEqual(s._send_this(), r._expect_this()) self.assertEqual(r._send_this(), b"transit receiver ed447528194bac4c00d0c854b12a97ce51413d89aa74d6304475f516fdc23a1b ready\n\n") self.assertEqual(r._send_this(), s._expect_this()) self.assertEqual(hexlify(s._sender_record_key()), b"5a2fba3a9e524ab2e2823ff53b05f946896f6e4ce4e282ffd8e3ac0e5e9e0cda") self.assertEqual(hexlify(s._sender_record_key()), hexlify(r._receiver_record_key())) self.assertEqual(hexlify(r._sender_record_key()), b"eedb143117249f45b39da324decf6bd9aae33b7ccd58487436de611a3c6b871d") self.assertEqual(hexlify(r._sender_record_key()), hexlify(s._receiver_record_key())) def test_connection_ready(self): s = transit.TransitSender(u"") self.assertEqual(s.describe(), "not yet established") self.assertEqual(s.connection_ready("p1", "desc1"), "go") self.assertEqual(s.describe(), "desc1") self.assertEqual(s._winner, "p1") self.assertEqual(s.connection_ready("p2", "desc2"), "nevermind") self.assertEqual(s.describe(), "desc1") self.assertEqual(s._winner, "p1") r = transit.TransitReceiver(u"") self.assertEqual(r.describe(), "not yet established") self.assertEqual(r.connection_ready("p1", "desc1"), "wait-for-decision") self.assertEqual(r.describe(), "not yet established") self.assertEqual(r.connection_ready("p2", "desc2"), "wait-for-decision") self.assertEqual(r.describe(), "not yet established") class Listener(unittest.TestCase): def test_listener(self): c = transit.Common(u"") hints, ep = c._build_listener() self.assertIsInstance(hints, (list, set)) if hints: self.assertIsInstance(hints[0], type(u"")) self.assert_(hints[0].startswith(u"tcp:")) self.assertIsInstance(ep, endpoints.TCP4ServerEndpoint) def test_get_direct_hints(self): # this actually starts the listener c = transit.TransitSender(u"") results = [] d = c.get_direct_hints() d.addBoth(results.append) self.assertEqual(len(results), 1) hints = results[0] # the hints are supposed to be cached, so calling this twice won't # start a second listener self.assert_(c._listener) results = [] d = c.get_direct_hints() d.addBoth(results.append) self.assertEqual(results, [hints]) c._stop_listening() class DummyProtocol(protocol.Protocol): def __init__(self): self.buf = b"" self._count = None self._d2 = None def wait_for(self, count): if len(self.buf) >= count: data = self.buf[:count] self.buf = self.buf[count:] return defer.succeed(data) self._d = defer.Deferred() self._count = count return self._d def dataReceived(self, data): self.buf += data #print("oDR", self._count, len(self.buf)) if self._count is not None and len(self.buf) >= self._count: got = self.buf[:self._count] self.buf = self.buf[self._count:] self._count = None self._d.callback(got) def wait_for_disconnect(self): self._d2 = defer.Deferred() return self._d2 def connectionLost(self, reason): if self._d2: self._d2.callback(None) class FakeTransport: def __init__(self, p, peeraddr): self.protocol = p self._peeraddr = peeraddr self._buf = b"" self._connected = True def write(self, data): self._buf += data def loseConnection(self): self._connected = False self.protocol.connectionLost() def getPeer(self): return self._peeraddr def read_buf(self): b = self._buf self._buf = b"" return b class RandomError(Exception): pass class MockConnection: def __init__(self, owner, relay_handshake, start): self.owner = owner self.relay_handshake = relay_handshake self.start = start def cancel(d): self._cancelled = True self._d = defer.Deferred(cancel) self._start_negotiation_called = False self._cancelled = False def startNegotiation(self, description): self._start_negotiation_called = True self._description = description return self._d class InboundConnectionFactory(unittest.TestCase): def test_describe(self): f = transit.InboundConnectionFactory(None) addrH = address.HostnameAddress("example.com", 1234) self.assertEqual(f.describePeer(addrH), "<-example.com:1234") addr4 = address.IPv4Address("TCP", "1.2.3.4", 1234) self.assertEqual(f.describePeer(addr4), "<-1.2.3.4:1234") addr6 = address.IPv6Address("TCP", "::1", 1234) self.assertEqual(f.describePeer(addr6), "<-::1:1234") addrU = address.UNIXAddress("/dev/unlikely") self.assertEqual(f.describePeer(addrU), "<-UNIXAddress('/dev/unlikely')") def test_success(self): f = transit.InboundConnectionFactory("owner") f.protocol = MockConnection results = [] d = f.whenDone() d.addBoth(results.append) self.assertEqual(results, []) addr = address.HostnameAddress("example.com", 1234) p = f.buildProtocol(addr) self.assertIsInstance(p, MockConnection) self.assertEqual(p.owner, "owner") self.assertEqual(p.relay_handshake, None) self.assertEqual(p._start_negotiation_called, False) # meh .start # this is normally called from Connection.connectionMade f.connectionWasMade(p, addr) self.assertEqual(p._start_negotiation_called, True) self.assertEqual(results, []) self.assertEqual(p._description, "<-example.com:1234") p._d.callback(p) self.assertEqual(results, [p]) def test_one_fail_one_success(self): f = transit.InboundConnectionFactory("owner") f.protocol = MockConnection results = [] d = f.whenDone() d.addBoth(results.append) self.assertEqual(results, []) addr = address.HostnameAddress("example.com", 1234) p1 = f.buildProtocol(addr) p2 = f.buildProtocol(addr) f.connectionWasMade(p1, "desc1") f.connectionWasMade(p2, "desc2") self.assertEqual(results, []) p1._d.errback(transit.BadHandshake("nope")) self.assertEqual(results, []) p2._d.callback(p2) self.assertEqual(results, [p2]) def test_first_success_wins(self): f = transit.InboundConnectionFactory("owner") f.protocol = MockConnection results = [] d = f.whenDone() d.addBoth(results.append) self.assertEqual(results, []) addr = address.HostnameAddress("example.com", 1234) p1 = f.buildProtocol(addr) p2 = f.buildProtocol(addr) f.connectionWasMade(p1, "desc1") f.connectionWasMade(p2, "desc2") self.assertEqual(results, []) p1._d.callback(p1) self.assertEqual(results, [p1]) self.assertEqual(p1._cancelled, False) self.assertEqual(p2._cancelled, True) def test_log_other_errors(self): f = transit.InboundConnectionFactory("owner") f.protocol = MockConnection results = [] d = f.whenDone() d.addBoth(results.append) self.assertEqual(results, []) addr = address.HostnameAddress("example.com", 1234) p1 = f.buildProtocol(addr) # if the Connection protocol throws an unexpected error, that should # get logged to the Twisted logs (as an Unhandled Error in Deferred) # so we can diagnose the bug f.connectionWasMade(p1, "desc1") p1._d.errback(RandomError("boom")) self.assertEqual(len(results), 0) log.msg("=== note: the next RandomError is expected ===") # Make sure the Deferred has gone out of scope, so the UnhandledError # happens quickly. We must manually break the gc cycle. del p1._d self.flushLoggedErrors(RandomError) log.msg("=== note: the preceding RandomError was expected ===") def test_cancel(self): f = transit.InboundConnectionFactory("owner") f.protocol = MockConnection results = [] d = f.whenDone() d.addBoth(results.append) self.assertEqual(results, []) addr = address.HostnameAddress("example.com", 1234) p1 = f.buildProtocol(addr) p2 = f.buildProtocol(addr) f.connectionWasMade(p1, "desc1") f.connectionWasMade(p2, "desc2") self.assertEqual(results, []) d.cancel() self.assertEqual(len(results), 1) f = results[0] self.assertIsInstance(f, failure.Failure) self.assertIsInstance(f.value, defer.CancelledError) self.assertEqual(p1._cancelled, True) self.assertEqual(p2._cancelled, True) # XXX check descriptions class OutboundConnectionFactory(unittest.TestCase): def test_success(self): f = transit.OutboundConnectionFactory("owner", "relay_handshake") f.protocol = MockConnection addr = address.HostnameAddress("example.com", 1234) p = f.buildProtocol(addr) self.assertIsInstance(p, MockConnection) self.assertEqual(p.owner, "owner") self.assertEqual(p.relay_handshake, "relay_handshake") self.assertEqual(p._start_negotiation_called, False) # meh .start # this is normally called from Connection.connectionMade f.connectionWasMade(p, "desc") # no-op for outbound self.assertEqual(p._start_negotiation_called, False) class MockOwner: _connection_ready_called = False def connection_ready(self, connection, description): self._connection_ready_called = True self._connection = connection self._description = description return self._state def _send_this(self): return b"send_this" def _expect_this(self): return b"expect_this" def _sender_record_key(self): return b"s"*32 def _receiver_record_key(self): return b"r"*32 class MockFactory: _connectionWasMade_called = False def connectionWasMade(self, p, description): self._connectionWasMade_called = True self._p = p class Connection(unittest.TestCase): # exercise the Connection protocol class def test_check_and_remove(self): c = transit.Connection(None, None, None) c.buf = b"" EXP = b"expectation" self.assertFalse(c._check_and_remove(EXP)) self.assertEqual(c.buf, b"") c.buf = b"unexpected" e = self.assertRaises(transit.BadHandshake, c._check_and_remove, EXP) self.assertEqual(str(e), "got %r want %r" % (b'unexpected', b'expectation')) self.assertEqual(c.buf, b"unexpected") c.buf = b"expect" self.assertFalse(c._check_and_remove(EXP)) self.assertEqual(c.buf, b"expect") c.buf = b"expectation" self.assertTrue(c._check_and_remove(EXP)) self.assertEqual(c.buf, b"") c.buf = b"expectation exceeded" self.assertTrue(c._check_and_remove(EXP)) self.assertEqual(c.buf, b" exceeded") def test_sender_accepting(self): relay_handshake = None owner = MockOwner() factory = MockFactory() addr = address.HostnameAddress("example.com", 1234) c = transit.Connection(owner, relay_handshake, None) self.assertEqual(c.state, "too-early") t = c.transport = FakeTransport(c, addr) c.factory = factory c.connectionMade() self.assertEqual(factory._connectionWasMade_called, True) self.assertEqual(factory._p, c) owner._state = "go" d = c.startNegotiation("description") self.assertEqual(c.state, "handshake") self.assertEqual(t.read_buf(), b"send_this") results = [] d.addBoth(results.append) self.assertEqual(results, []) c.dataReceived(b"expect_this") self.assertEqual(t.read_buf(), b"go\n") self.assertEqual(t._connected, True) self.assertEqual(c.state, "records") self.assertEqual(results, [c]) c.close() self.assertEqual(t._connected, False) def test_sender_rejecting(self): relay_handshake = None owner = MockOwner() factory = MockFactory() addr = address.HostnameAddress("example.com", 1234) c = transit.Connection(owner, relay_handshake, None) self.assertEqual(c.state, "too-early") t = c.transport = FakeTransport(c, addr) c.factory = factory c.connectionMade() self.assertEqual(factory._connectionWasMade_called, True) self.assertEqual(factory._p, c) owner._state = "nevermind" d = c.startNegotiation("description") self.assertEqual(c.state, "handshake") self.assertEqual(t.read_buf(), b"send_this") results = [] d.addBoth(results.append) self.assertEqual(results, []) c.dataReceived(b"expect_this") self.assertEqual(t.read_buf(), b"nevermind\n") self.assertEqual(t._connected, False) self.assertEqual(c.state, "hung up") self.assertEqual(len(results), 1) f = results[0] self.assertIsInstance(f, failure.Failure) self.assertIsInstance(f.value, transit.BadHandshake) self.assertEqual(str(f.value), "abandoned") def test_handshake_other_error(self): owner = MockOwner() factory = MockFactory() addr = address.HostnameAddress("example.com", 1234) c = transit.Connection(owner, None, None) self.assertEqual(c.state, "too-early") t = c.transport = FakeTransport(c, addr) c.factory = factory c.connectionMade() self.assertEqual(factory._connectionWasMade_called, True) self.assertEqual(factory._p, c) d = c.startNegotiation("description") self.assertEqual(c.state, "handshake") self.assertEqual(t.read_buf(), b"send_this") results = [] d.addBoth(results.append) self.assertEqual(results, []) c.state = RandomError("boom") self.assertRaises(RandomError, c.dataReceived, b"surprise!") self.assertEqual(t._connected, False) self.assertEqual(c.state, "hung up") self.assertEqual(len(results), 1) f = results[0] self.assertIsInstance(f, failure.Failure) self.assertIsInstance(f.value, RandomError) def test_relay_handshake(self): relay_handshake = b"relay handshake" owner = MockOwner() factory = MockFactory() addr = address.HostnameAddress("example.com", 1234) c = transit.Connection(owner, relay_handshake, None) self.assertEqual(c.state, "too-early") t = c.transport = FakeTransport(c, addr) c.factory = factory c.connectionMade() self.assertEqual(factory._connectionWasMade_called, True) self.assertEqual(factory._p, c) self.assertEqual(t.read_buf(), b"") # quiet until startNegotiation owner._state = "go" d = c.startNegotiation("description") self.assertEqual(t.read_buf(), relay_handshake) self.assertEqual(c.state, "relay") # waiting for OK from relay c.dataReceived(b"ok\n") self.assertEqual(t.read_buf(), b"send_this") self.assertEqual(c.state, "handshake") results = [] d.addBoth(results.append) self.assertEqual(results, []) c.dataReceived(b"expect_this") self.assertEqual(c.state, "records") self.assertEqual(results, [c]) self.assertEqual(t.read_buf(), b"go\n") def test_relay_handshake_bad(self): relay_handshake = b"relay handshake" owner = MockOwner() factory = MockFactory() addr = address.HostnameAddress("example.com", 1234) c = transit.Connection(owner, relay_handshake, None) self.assertEqual(c.state, "too-early") t = c.transport = FakeTransport(c, addr) c.factory = factory c.connectionMade() self.assertEqual(factory._connectionWasMade_called, True) self.assertEqual(factory._p, c) self.assertEqual(t.read_buf(), b"") # quiet until startNegotiation owner._state = "go" d = c.startNegotiation("description") self.assertEqual(t.read_buf(), relay_handshake) self.assertEqual(c.state, "relay") # waiting for OK from relay c.dataReceived(b"not ok\n") self.assertEqual(t._connected, False) self.assertEqual(c.state, "hung up") results = [] d.addBoth(results.append) self.assertEqual(len(results), 1) f = results[0] self.assertIsInstance(f, failure.Failure) self.assertIsInstance(f.value, transit.BadHandshake) self.assertEqual(str(f.value), "got %r want %r" % (b"not ok\n", b"ok\n")) def test_receiver_accepted(self): # we're on the receiving side, so we wait for the sender to decide owner = MockOwner() factory = MockFactory() addr = address.HostnameAddress("example.com", 1234) c = transit.Connection(owner, None, None) self.assertEqual(c.state, "too-early") t = c.transport = FakeTransport(c, addr) c.factory = factory c.connectionMade() self.assertEqual(factory._connectionWasMade_called, True) self.assertEqual(factory._p, c) owner._state = "wait-for-decision" d = c.startNegotiation("description") self.assertEqual(c.state, "handshake") self.assertEqual(t.read_buf(), b"send_this") results = [] d.addBoth(results.append) self.assertEqual(results, []) c.dataReceived(b"expect_this") self.assertEqual(c.state, "wait-for-decision") self.assertEqual(results, []) c.dataReceived(b"go\n") self.assertEqual(c.state, "records") self.assertEqual(results, [c]) def test_receiver_rejected_politely(self): # we're on the receiving side, so we wait for the sender to decide owner = MockOwner() factory = MockFactory() addr = address.HostnameAddress("example.com", 1234) c = transit.Connection(owner, None, None) self.assertEqual(c.state, "too-early") t = c.transport = FakeTransport(c, addr) c.factory = factory c.connectionMade() self.assertEqual(factory._connectionWasMade_called, True) self.assertEqual(factory._p, c) owner._state = "wait-for-decision" d = c.startNegotiation("description") self.assertEqual(c.state, "handshake") self.assertEqual(t.read_buf(), b"send_this") results = [] d.addBoth(results.append) self.assertEqual(results, []) c.dataReceived(b"expect_this") self.assertEqual(c.state, "wait-for-decision") self.assertEqual(results, []) c.dataReceived(b"nevermind\n") # polite rejection self.assertEqual(t._connected, False) self.assertEqual(c.state, "hung up") self.assertEqual(len(results), 1) f = results[0] self.assertIsInstance(f, failure.Failure) self.assertIsInstance(f.value, transit.BadHandshake) self.assertEqual(str(f.value), "got %r want %r" % (b"nevermind\n", b"go\n")) def test_receiver_rejected_rudely(self): # we're on the receiving side, so we wait for the sender to decide owner = MockOwner() factory = MockFactory() addr = address.HostnameAddress("example.com", 1234) c = transit.Connection(owner, None, None) self.assertEqual(c.state, "too-early") t = c.transport = FakeTransport(c, addr) c.factory = factory c.connectionMade() self.assertEqual(factory._connectionWasMade_called, True) self.assertEqual(factory._p, c) owner._state = "wait-for-decision" d = c.startNegotiation("description") self.assertEqual(c.state, "handshake") self.assertEqual(t.read_buf(), b"send_this") results = [] d.addBoth(results.append) self.assertEqual(results, []) c.dataReceived(b"expect_this") self.assertEqual(c.state, "wait-for-decision") self.assertEqual(results, []) t.loseConnection() self.assertEqual(t._connected, False) self.assertEqual(len(results), 1) f = results[0] self.assertIsInstance(f, failure.Failure) self.assertIsInstance(f.value, transit.BadHandshake) self.assertEqual(str(f.value), "connection lost") def test_cancel(self): owner = MockOwner() factory = MockFactory() addr = address.HostnameAddress("example.com", 1234) c = transit.Connection(owner, None, None) self.assertEqual(c.state, "too-early") t = c.transport = FakeTransport(c, addr) c.factory = factory c.connectionMade() d = c.startNegotiation("description") results = [] d.addBoth(results.append) # while we're waiting for negotiation, we get cancelled d.cancel() self.assertEqual(t._connected, False) self.assertEqual(c.state, "hung up") self.assertEqual(len(results), 1) f = results[0] self.assertIsInstance(f, failure.Failure) self.assertIsInstance(f.value, defer.CancelledError) def test_timeout(self): clock = task.Clock() owner = MockOwner() factory = MockFactory() addr = address.HostnameAddress("example.com", 1234) c = transit.Connection(owner, None, None) def _callLater(period, func): clock.callLater(period, func) c.callLater = _callLater self.assertEqual(c.state, "too-early") t = c.transport = FakeTransport(c, addr) c.factory = factory c.connectionMade() # the timer should now be running d = c.startNegotiation("description") results = [] d.addBoth(results.append) # while we're waiting for negotiation, the timer expires clock.advance(transit.TIMEOUT + 1.0) self.assertEqual(t._connected, False) self.assertEqual(len(results), 1) f = results[0] self.assertIsInstance(f, failure.Failure) self.assertIsInstance(f.value, transit.BadHandshake) self.assertEqual(str(f.value), "timeout") def make_connection(self): owner = MockOwner() factory = MockFactory() addr = address.HostnameAddress("example.com", 1234) c = transit.Connection(owner, None, None) t = c.transport = FakeTransport(c, addr) c.factory = factory c.connectionMade() owner._state = "go" d = c.startNegotiation("description") results = [] d.addBoth(results.append) c.dataReceived(b"expect_this") self.assertEqual(results, [c]) t.read_buf() # flush input buffer, prepare for encrypted records return t, c, owner def test_records_good(self): # now make sure that outbound records are encrypted properly t, c, owner = self.make_connection() RECORD1 = b"record" c.send_record(RECORD1) buf = t.read_buf() expected = ("%08x" % (24+len(RECORD1)+16)).encode("ascii") self.assertEqual(hexlify(buf[:4]), expected) encrypted = buf[4:] receive_box = SecretBox(owner._sender_record_key()) nonce_buf = encrypted[:SecretBox.NONCE_SIZE] # assume it's prepended nonce = int(hexlify(nonce_buf), 16) self.assertEqual(nonce, 0) # first message gets nonce 0 decrypted = receive_box.decrypt(encrypted) self.assertEqual(decrypted, RECORD1) # second message gets nonce 1 RECORD2 = b"record2" c.send_record(RECORD2) buf = t.read_buf() expected = ("%08x" % (24+len(RECORD2)+16)).encode("ascii") self.assertEqual(hexlify(buf[:4]), expected) encrypted = buf[4:] receive_box = SecretBox(owner._sender_record_key()) nonce_buf = encrypted[:SecretBox.NONCE_SIZE] # assume it's prepended nonce = int(hexlify(nonce_buf), 16) self.assertEqual(nonce, 1) decrypted = receive_box.decrypt(encrypted) self.assertEqual(decrypted, RECORD2) # and that we can receive records properly inbound_records = [] c.recordReceived = inbound_records.append RECORD3 = b"record3" send_box = SecretBox(owner._receiver_record_key()) nonce_buf = unhexlify("%048x" % 0) # first nonce must be 0 encrypted = send_box.encrypt(RECORD3, nonce_buf) length = unhexlify("%08x" % len(encrypted)) # always 4 bytes long c.dataReceived(length[:2]) c.dataReceived(length[2:]) c.dataReceived(encrypted[:-2]) self.assertEqual(inbound_records, []) c.dataReceived(encrypted[-2:]) self.assertEqual(inbound_records, [RECORD3]) RECORD4 = b"record4" send_box = SecretBox(owner._receiver_record_key()) nonce_buf = unhexlify("%048x" % 1) # nonces increment encrypted = send_box.encrypt(RECORD4, nonce_buf) length = unhexlify("%08x" % len(encrypted)) # always 4 bytes long c.dataReceived(length[:2]) c.dataReceived(length[2:]) c.dataReceived(encrypted[:-2]) self.assertEqual(inbound_records, [RECORD3]) c.dataReceived(encrypted[-2:]) self.assertEqual(inbound_records, [RECORD3, RECORD4]) def corrupt(self, orig): last_byte = orig[-1:] num = int(hexlify(last_byte).decode("ascii"), 16) corrupt_num = 256 - num as_byte = unhexlify("%02x" % corrupt_num) return orig[:-1] + as_byte def test_records_corrupt(self): # corrupt records should be rejected t, c, owner = self.make_connection() inbound_records = [] c.recordReceived = inbound_records.append RECORD = b"record" send_box = SecretBox(owner._receiver_record_key()) nonce_buf = unhexlify("%048x" % 0) # first nonce must be 0 encrypted = self.corrupt(send_box.encrypt(RECORD, nonce_buf)) length = unhexlify("%08x" % len(encrypted)) # always 4 bytes long c.dataReceived(length) c.dataReceived(encrypted[:-2]) self.assertEqual(inbound_records, []) self.assertRaises(CryptoError, c.dataReceived, encrypted[-2:]) self.assertEqual(inbound_records, []) # and the connection should have been dropped self.assertEqual(t._connected, False) def test_out_of_order_nonce(self): # an inbound out-of-order nonce should be rejected t, c, owner = self.make_connection() inbound_records = [] c.recordReceived = inbound_records.append RECORD = b"record" send_box = SecretBox(owner._receiver_record_key()) nonce_buf = unhexlify("%048x" % 1) # first nonce must be 0 encrypted = send_box.encrypt(RECORD, nonce_buf) length = unhexlify("%08x" % len(encrypted)) # always 4 bytes long c.dataReceived(length) c.dataReceived(encrypted[:-2]) self.assertEqual(inbound_records, []) self.assertRaises(transit.BadNonce, c.dataReceived, encrypted[-2:]) self.assertEqual(inbound_records, []) # and the connection should have been dropped self.assertEqual(t._connected, False) # TODO: check that .connectionLost/loseConnection signatures are # consistent: zero args, or one arg? # XXX: if we don't set the transit key before connecting, what # happens? We currently get a type-check assertion from HKDF because # the key is None. DIRECT_HINT = u"tcp:direct:1234" RELAY_HINT = u"tcp:relay:1234" UNUSABLE_HINT = u"unusable:foo:bar" class Transit(unittest.TestCase): @inlineCallbacks def test_success_direct(self): clock = task.Clock() s = transit.TransitSender(u"", reactor=clock) s.set_transit_key(b"key") hints = yield s.get_direct_hints() # start the listener del hints s.add_their_direct_hints([DIRECT_HINT, UNUSABLE_HINT]) s.add_their_relay_hints([]) connectors = [] def _start_connector(ep, description, is_relay=False): d = defer.Deferred() connectors.append(d) return d s._start_connector = _start_connector d = s.connect() results = [] d.addBoth(results.append) self.assertEqual(results, []) self.assertEqual(len(connectors), 1) self.assertIsInstance(connectors[0], defer.Deferred) connectors[0].callback("winner") self.assertEqual(results, ["winner"]) def _endpoint_from_hint(self, hint): if hint == DIRECT_HINT: return "direct" elif hint == RELAY_HINT: return "relay" elif hint == UNUSABLE_HINT: return None else: return "ep" @inlineCallbacks def test_wait_for_relay(self): clock = task.Clock() s = transit.TransitSender(u"", reactor=clock) s.set_transit_key(b"key") hints = yield s.get_direct_hints() # start the listener del hints s.add_their_direct_hints([DIRECT_HINT, UNUSABLE_HINT]) s.add_their_relay_hints([RELAY_HINT]) direct_connectors = [] relay_connectors = [] s._endpoint_from_hint = self._endpoint_from_hint def _start_connector(ep, description, is_relay=False): d = defer.Deferred() if ep == "direct": direct_connectors.append(d) elif ep == "relay": relay_connectors.append(d) else: raise ValueError return d s._start_connector = _start_connector d = s.connect() results = [] d.addBoth(results.append) self.assertEqual(results, []) # the direct connectors are tried right away, but the relay # connectors are stalled for a few seconds self.assertEqual(len(direct_connectors), 1) self.assertEqual(len(relay_connectors), 0) clock.advance(s.RELAY_DELAY + 1.0) self.assertEqual(len(direct_connectors), 1) self.assertEqual(len(relay_connectors), 1) direct_connectors[0].callback("winner") self.assertEqual(results, ["winner"]) @inlineCallbacks def test_no_direct_hints(self): clock = task.Clock() s = transit.TransitSender(u"", reactor=clock) s.set_transit_key(b"key") hints = yield s.get_direct_hints() # start the listener del hints s.add_their_direct_hints([UNUSABLE_HINT]) s.add_their_relay_hints([RELAY_HINT, UNUSABLE_HINT]) direct_connectors = [] relay_connectors = [] s._endpoint_from_hint = self._endpoint_from_hint def _start_connector(ep, description, is_relay=False): d = defer.Deferred() if ep == "direct": direct_connectors.append(d) elif ep == "relay": relay_connectors.append(d) else: raise ValueError return d s._start_connector = _start_connector d = s.connect() results = [] d.addBoth(results.append) self.assertEqual(results, []) # since there are no usable direct hints, the relay connector will # only be stalled for 0 seconds self.assertEqual(len(direct_connectors), 0) self.assertEqual(len(relay_connectors), 0) clock.advance(0) self.assertEqual(len(direct_connectors), 0) self.assertEqual(len(relay_connectors), 1) relay_connectors[0].callback("winner") self.assertEqual(results, ["winner"]) class Full(unittest.TestCase): def doBoth(self, d1, d2): return gatherResults([d1, d2], True) @inlineCallbacks def test_full(self): KEY = b"k"*32 s = transit.TransitSender(None) r = transit.TransitReceiver(None) s.set_transit_key(KEY) r.set_transit_key(KEY) shints = yield s.get_direct_hints() rhints = yield r.get_direct_hints() s.add_their_direct_hints(rhints) r.add_their_direct_hints(shints) s.add_their_relay_hints([]) r.add_their_relay_hints([]) (x,y) = yield self.doBoth(s.connect(), r.connect()) self.assertIsInstance(x, transit.Connection) self.assertIsInstance(y, transit.Connection) d = defer.Deferred() y.recordReceived = d.callback x.send_record(b"record1") r = yield d self.assertEqual(r, b"record1") yield x.close() yield y.close()