diff --git a/src/wormhole/test/test_transit_twisted.py b/src/wormhole/test/test_transit_twisted.py new file mode 100644 index 0000000..3f1da6b --- /dev/null +++ b/src/wormhole/test/test_transit_twisted.py @@ -0,0 +1,1109 @@ +from __future__ import print_function +from binascii import hexlify, unhexlify +from twisted.trial import unittest +from twisted.internet import defer, task, endpoints, protocol, address +from twisted.internet.defer import gatherResults, inlineCallbacks +from twisted.python import log, failure +from ..twisted import transit +from ..errors import UsageError +from nacl.secret import SecretBox +from nacl.exceptions import CryptoError + +class Highlander(unittest.TestCase): + def test_one_winner(self): + cancelled = set() + contenders = [defer.Deferred(lambda d: cancelled.add(i)) + for i in range(4)] + result = [] + d = transit.there_can_be_only_one(contenders) + d.addBoth(result.append) + self.assertEqual(result, []) + contenders[0].errback(ValueError()) + self.assertEqual(result, []) + contenders[1].errback(TypeError()) + self.assertEqual(result, []) + contenders[2].callback("yay") + self.assertEqual(result, ["yay"]) + self.assertEqual(cancelled, set([3])) + + def test_there_might_also_be_none(self): + cancelled = set() + contenders = [defer.Deferred(lambda d: cancelled.add(i)) + for i in range(4)] + result = [] + d = transit.there_can_be_only_one(contenders) + d.addBoth(result.append) + self.assertEqual(result, []) + contenders[0].errback(ValueError()) + self.assertEqual(result, []) + contenders[1].errback(TypeError()) + self.assertEqual(result, []) + contenders[2].errback(TypeError()) + self.assertEqual(result, []) + contenders[3].errback(NameError()) + self.assertEqual(len(result), 1) + f = result[0] + self.assertIsInstance(f.value, ValueError) # first failure is recorded + self.assertEqual(cancelled, set()) + + def test_cancel_early(self): + cancelled = set() + contenders = [defer.Deferred(lambda d, i=i: cancelled.add(i)) + for i in range(4)] + result = [] + d = transit.there_can_be_only_one(contenders) + d.addBoth(result.append) + self.assertEqual(result, []) + self.assertEqual(cancelled, set()) + d.cancel() + self.assertEqual(len(result), 1) + self.assertIsInstance(result[0].value, defer.CancelledError) + self.assertEqual(cancelled, set(range(4))) + + def test_cancel_after_one_failure(self): + cancelled = set() + contenders = [defer.Deferred(lambda d, i=i: cancelled.add(i)) + for i in range(4)] + result = [] + d = transit.there_can_be_only_one(contenders) + d.addBoth(result.append) + self.assertEqual(result, []) + self.assertEqual(cancelled, set()) + contenders[0].errback(ValueError()) + d.cancel() + self.assertEqual(len(result), 1) + self.assertIsInstance(result[0].value, ValueError) + self.assertEqual(cancelled, set([1,2,3])) + +class Forever(unittest.TestCase): + def _forever_setup(self): + clock = task.Clock() + c = transit.Common(u"", reactor=clock) + cancelled = [] + result = [] + d0 = defer.Deferred(cancelled.append) + d = c._not_forever(1.0, d0) + d.addBoth(result.append) + return c, clock, d0, d, cancelled, result + + def test_not_forever_fires(self): + c, clock, d0, d, cancelled, result = self._forever_setup() + self.assertEqual((result, cancelled), ([], [])) + d.callback(1) + self.assertEqual((result, cancelled), ([1], [])) + self.assertNot(clock.getDelayedCalls()) + + def test_not_forever_errs(self): + c, clock, d0, d, cancelled, result = self._forever_setup() + self.assertEqual((result, cancelled), ([], [])) + d.errback(ValueError()) + self.assertEqual(cancelled, []) + self.assertEqual(len(result), 1) + self.assertIsInstance(result[0].value, ValueError) + self.assertNot(clock.getDelayedCalls()) + + def test_not_forever_cancel_early(self): + c, clock, d0, d, cancelled, result = self._forever_setup() + self.assertEqual((result, cancelled), ([], [])) + d.cancel() + self.assertEqual(cancelled, [d0]) + self.assertEqual(len(result), 1) + self.assertIsInstance(result[0].value, defer.CancelledError) + self.assertNot(clock.getDelayedCalls()) + + def test_not_forever_timeout(self): + c, clock, d0, d, cancelled, result = self._forever_setup() + self.assertEqual((result, cancelled), ([], [])) + clock.advance(2.0) + self.assertEqual(cancelled, [d0]) + self.assertEqual(len(result), 1) + self.assertIsInstance(result[0].value, defer.CancelledError) + self.assertNot(clock.getDelayedCalls()) + +class Misc(unittest.TestCase): + def test_allocate_port(self): + portno = transit.allocate_tcp_port() + self.assertIsInstance(portno, int) + +class Hints(unittest.TestCase): + def test_endpoint_from_hint(self): + c = transit.Common(u"") + ep = c._endpoint_from_hint("tcp:localhost:1234") + self.assertIsInstance(ep, endpoints.HostnameEndpoint) + ep = c._endpoint_from_hint("unknown:stuff:yowza:pivlor") + self.assertEqual(ep, None) + ep = c._endpoint_from_hint("tooshort") + self.assertEqual(ep, None) + + +class Basic(unittest.TestCase): + def test_relay_hints(self): + URL = u"RELAYURL" + c = transit.Common(URL) + self.assertEqual(c.get_relay_hints(), [URL]) + self.assertRaises(UsageError, transit.Common, 123) + + def test_no_relay_hints(self): + c = transit.Common(None) + self.assertEqual(c.get_relay_hints(), []) + + def test_bad_hints(self): + c = transit.Common(u"") + self.assertRaises(TypeError, c.add_their_direct_hints, [123]) + c.add_their_direct_hints([u"URL"]) + self.assertRaises(TypeError, c.add_their_relay_hints, [123]) + c.add_their_relay_hints([u"URL"]) + + def test_transit_key_wait(self): + KEY = b"123" + c = transit.Common(u"") + results = [] + d = c._get_transit_key() + d.addBoth(results.append) + self.assertEqual(results, []) + c.set_transit_key(KEY) + self.assertEqual(results, [KEY]) + + def test_transit_key_already_set(self): + KEY = b"123" + c = transit.Common(u"") + c.set_transit_key(KEY) + results = [] + d = c._get_transit_key() + d.addBoth(results.append) + self.assertEqual(results, [KEY]) + + def test_transit_keys(self): + KEY = b"123" + s = transit.TransitSender(u"") + s.set_transit_key(KEY) + r = transit.TransitReceiver(u"") + r.set_transit_key(KEY) + + self.assertEqual(s._send_this(), b"transit sender 559bdeae4b49fa6a23378d2b68f4c7e69378615d4af049c371c6a26e82391089 ready\n\n") + self.assertEqual(s._send_this(), r._expect_this()) + + self.assertEqual(r._send_this(), b"transit receiver ed447528194bac4c00d0c854b12a97ce51413d89aa74d6304475f516fdc23a1b ready\n\n") + self.assertEqual(r._send_this(), s._expect_this()) + + self.assertEqual(hexlify(s._sender_record_key()), b"5a2fba3a9e524ab2e2823ff53b05f946896f6e4ce4e282ffd8e3ac0e5e9e0cda") + self.assertEqual(hexlify(s._sender_record_key()), + hexlify(r._receiver_record_key())) + + self.assertEqual(hexlify(r._sender_record_key()), b"eedb143117249f45b39da324decf6bd9aae33b7ccd58487436de611a3c6b871d") + self.assertEqual(hexlify(r._sender_record_key()), + hexlify(s._receiver_record_key())) + + def test_connection_ready(self): + s = transit.TransitSender(u"") + self.assertEqual(s.describe(), "not yet established") + + self.assertEqual(s.connection_ready("p1", "desc1"), "go") + self.assertEqual(s.describe(), "desc1") + self.assertEqual(s._winner, "p1") + + self.assertEqual(s.connection_ready("p2", "desc2"), "nevermind") + self.assertEqual(s.describe(), "desc1") + self.assertEqual(s._winner, "p1") + + r = transit.TransitReceiver(u"") + self.assertEqual(r.describe(), "not yet established") + + self.assertEqual(r.connection_ready("p1", "desc1"), "wait-for-decision") + self.assertEqual(r.describe(), "not yet established") + + self.assertEqual(r.connection_ready("p2", "desc2"), "wait-for-decision") + self.assertEqual(r.describe(), "not yet established") + + +class Listener(unittest.TestCase): + def test_listener(self): + c = transit.Common(u"") + hints, ep = c._build_listener() + self.assertIsInstance(hints, (list, set)) + if hints: + self.assertIsInstance(hints[0], type(u"")) + self.assert_(hints[0].startswith(u"tcp:")) + self.assertIsInstance(ep, endpoints.TCP4ServerEndpoint) + + def test_get_direct_hints(self): + # this actually starts the listener + c = transit.TransitSender(u"") + + results = [] + d = c.get_direct_hints() + d.addBoth(results.append) + self.assertEqual(len(results), 1) + hints = results[0] + + # the hints are supposed to be cached, so calling this twice won't + # start a second listener + self.assert_(c._listener) + results = [] + d = c.get_direct_hints() + d.addBoth(results.append) + self.assertEqual(results, [hints]) + + c._stop_listening() + + +class DummyProtocol(protocol.Protocol): + def __init__(self): + self.buf = b"" + self._count = None + self._d2 = None + + def wait_for(self, count): + if len(self.buf) >= count: + data = self.buf[:count] + self.buf = self.buf[count:] + return defer.succeed(data) + self._d = defer.Deferred() + self._count = count + return self._d + + def dataReceived(self, data): + self.buf += data + #print("oDR", self._count, len(self.buf)) + if self._count is not None and len(self.buf) >= self._count: + got = self.buf[:self._count] + self.buf = self.buf[self._count:] + self._count = None + self._d.callback(got) + + def wait_for_disconnect(self): + self._d2 = defer.Deferred() + return self._d2 + + def connectionLost(self, reason): + if self._d2: + self._d2.callback(None) + +class FakeTransport: + def __init__(self, p, peeraddr): + self.protocol = p + self._peeraddr = peeraddr + self._buf = b"" + self._connected = True + def write(self, data): + self._buf += data + def loseConnection(self): + self._connected = False + self.protocol.connectionLost() + def getPeer(self): + return self._peeraddr + + def read_buf(self): + b = self._buf + self._buf = b"" + return b + +class RandomError(Exception): + pass + +class MockConnection: + def __init__(self, owner, relay_handshake, start): + self.owner = owner + self.relay_handshake = relay_handshake + self.start = start + def cancel(d): + self._cancelled = True + self._d = defer.Deferred(cancel) + self._start_negotiation_called = False + self._cancelled = False + + def startNegotiation(self, description): + self._start_negotiation_called = True + self._description = description + return self._d + +class InboundConnectionFactory(unittest.TestCase): + def test_describe(self): + f = transit.InboundConnectionFactory(None) + addrH = address.HostnameAddress("example.com", 1234) + self.assertEqual(f.describePeer(addrH), "<-example.com:1234") + addr4 = address.IPv4Address("TCP", "1.2.3.4", 1234) + self.assertEqual(f.describePeer(addr4), "<-1.2.3.4:1234") + addr6 = address.IPv6Address("TCP", "::1", 1234) + self.assertEqual(f.describePeer(addr6), "<-::1:1234") + addrU = address.UNIXAddress("/dev/unlikely") + self.assertEqual(f.describePeer(addrU), + "<-UNIXAddress('/dev/unlikely')") + + def test_success(self): + f = transit.InboundConnectionFactory("owner") + f.protocol = MockConnection + results = [] + d = f.whenDone() + d.addBoth(results.append) + self.assertEqual(results, []) + + addr = address.HostnameAddress("example.com", 1234) + p = f.buildProtocol(addr) + self.assertIsInstance(p, MockConnection) + self.assertEqual(p.owner, "owner") + self.assertEqual(p.relay_handshake, None) + self.assertEqual(p._start_negotiation_called, False) + # meh .start + + # this is normally called from Connection.connectionMade + f.connectionWasMade(p, addr) + self.assertEqual(p._start_negotiation_called, True) + self.assertEqual(results, []) + self.assertEqual(p._description, "<-example.com:1234") + + p._d.callback(p) + self.assertEqual(results, [p]) + + def test_one_fail_one_success(self): + f = transit.InboundConnectionFactory("owner") + f.protocol = MockConnection + results = [] + d = f.whenDone() + d.addBoth(results.append) + self.assertEqual(results, []) + + addr = address.HostnameAddress("example.com", 1234) + p1 = f.buildProtocol(addr) + p2 = f.buildProtocol(addr) + + f.connectionWasMade(p1, "desc1") + f.connectionWasMade(p2, "desc2") + self.assertEqual(results, []) + + p1._d.errback(transit.BadHandshake("nope")) + self.assertEqual(results, []) + p2._d.callback(p2) + self.assertEqual(results, [p2]) + + def test_first_success_wins(self): + f = transit.InboundConnectionFactory("owner") + f.protocol = MockConnection + results = [] + d = f.whenDone() + d.addBoth(results.append) + self.assertEqual(results, []) + + addr = address.HostnameAddress("example.com", 1234) + p1 = f.buildProtocol(addr) + p2 = f.buildProtocol(addr) + + f.connectionWasMade(p1, "desc1") + f.connectionWasMade(p2, "desc2") + self.assertEqual(results, []) + + p1._d.callback(p1) + self.assertEqual(results, [p1]) + self.assertEqual(p1._cancelled, False) + self.assertEqual(p2._cancelled, True) + + def test_log_other_errors(self): + f = transit.InboundConnectionFactory("owner") + f.protocol = MockConnection + results = [] + d = f.whenDone() + d.addBoth(results.append) + self.assertEqual(results, []) + + addr = address.HostnameAddress("example.com", 1234) + p1 = f.buildProtocol(addr) + + # if the Connection protocol throws an unexpected error, that should + # get logged to the Twisted logs (as an Unhandled Error in Deferred) + # so we can diagnose the bug + f.connectionWasMade(p1, "desc1") + p1._d.errback(RandomError("boom")) + self.assertEqual(len(results), 0) + + log.msg("=== note: the next RandomError is expected ===") + # Make sure the Deferred has gone out of scope, so the UnhandledError + # happens quickly. We must manually break the gc cycle. + del p1._d + self.flushLoggedErrors(RandomError) + log.msg("=== note: the preceding RandomError was expected ===") + + def test_cancel(self): + f = transit.InboundConnectionFactory("owner") + f.protocol = MockConnection + results = [] + d = f.whenDone() + d.addBoth(results.append) + self.assertEqual(results, []) + + addr = address.HostnameAddress("example.com", 1234) + p1 = f.buildProtocol(addr) + p2 = f.buildProtocol(addr) + + f.connectionWasMade(p1, "desc1") + f.connectionWasMade(p2, "desc2") + self.assertEqual(results, []) + + d.cancel() + + self.assertEqual(len(results), 1) + f = results[0] + self.assertIsInstance(f, failure.Failure) + self.assertIsInstance(f.value, defer.CancelledError) + self.assertEqual(p1._cancelled, True) + self.assertEqual(p2._cancelled, True) + +# XXX check descriptions + +class OutboundConnectionFactory(unittest.TestCase): + def test_success(self): + f = transit.OutboundConnectionFactory("owner", "relay_handshake") + f.protocol = MockConnection + + addr = address.HostnameAddress("example.com", 1234) + p = f.buildProtocol(addr) + self.assertIsInstance(p, MockConnection) + self.assertEqual(p.owner, "owner") + self.assertEqual(p.relay_handshake, "relay_handshake") + self.assertEqual(p._start_negotiation_called, False) + # meh .start + + # this is normally called from Connection.connectionMade + f.connectionWasMade(p, "desc") # no-op for outbound + self.assertEqual(p._start_negotiation_called, False) + + +class MockOwner: + _connection_ready_called = False + def connection_ready(self, connection, description): + self._connection_ready_called = True + self._connection = connection + self._description = description + return self._state + def _send_this(self): + return b"send_this" + def _expect_this(self): + return b"expect_this" + def _sender_record_key(self): + return b"s"*32 + def _receiver_record_key(self): + return b"r"*32 + +class MockFactory: + _connectionWasMade_called = False + def connectionWasMade(self, p, description): + self._connectionWasMade_called = True + self._p = p + +class Connection(unittest.TestCase): + # exercise the Connection protocol class + + def test_check_and_remove(self): + c = transit.Connection(None, None, None) + c.buf = b"" + EXP = b"expectation" + self.assertFalse(c._check_and_remove(EXP)) + self.assertEqual(c.buf, b"") + + c.buf = b"unexpected" + e = self.assertRaises(transit.BadHandshake, c._check_and_remove, EXP) + self.assertEqual(str(e), + "got %r want %r" % (b'unexpected', b'expectation')) + self.assertEqual(c.buf, b"unexpected") + + c.buf = b"expect" + self.assertFalse(c._check_and_remove(EXP)) + self.assertEqual(c.buf, b"expect") + + c.buf = b"expectation" + self.assertTrue(c._check_and_remove(EXP)) + self.assertEqual(c.buf, b"") + + c.buf = b"expectation exceeded" + self.assertTrue(c._check_and_remove(EXP)) + self.assertEqual(c.buf, b" exceeded") + + def test_sender_accepting(self): + relay_handshake = None + owner = MockOwner() + factory = MockFactory() + addr = address.HostnameAddress("example.com", 1234) + c = transit.Connection(owner, relay_handshake, None) + self.assertEqual(c.state, "too-early") + t = c.transport = FakeTransport(c, addr) + c.factory = factory + c.connectionMade() + self.assertEqual(factory._connectionWasMade_called, True) + self.assertEqual(factory._p, c) + + owner._state = "go" + d = c.startNegotiation("description") + self.assertEqual(c.state, "handshake") + self.assertEqual(t.read_buf(), b"send_this") + results = [] + d.addBoth(results.append) + self.assertEqual(results, []) + + c.dataReceived(b"expect_this") + self.assertEqual(t.read_buf(), b"go\n") + self.assertEqual(t._connected, True) + self.assertEqual(c.state, "records") + self.assertEqual(results, [c]) + + c.close() + self.assertEqual(t._connected, False) + + def test_sender_rejecting(self): + relay_handshake = None + owner = MockOwner() + factory = MockFactory() + addr = address.HostnameAddress("example.com", 1234) + c = transit.Connection(owner, relay_handshake, None) + self.assertEqual(c.state, "too-early") + t = c.transport = FakeTransport(c, addr) + c.factory = factory + c.connectionMade() + self.assertEqual(factory._connectionWasMade_called, True) + self.assertEqual(factory._p, c) + + owner._state = "nevermind" + d = c.startNegotiation("description") + self.assertEqual(c.state, "handshake") + self.assertEqual(t.read_buf(), b"send_this") + results = [] + d.addBoth(results.append) + self.assertEqual(results, []) + + c.dataReceived(b"expect_this") + self.assertEqual(t.read_buf(), b"nevermind\n") + self.assertEqual(t._connected, False) + self.assertEqual(c.state, "hung up") + self.assertEqual(len(results), 1) + f = results[0] + self.assertIsInstance(f, failure.Failure) + self.assertIsInstance(f.value, transit.BadHandshake) + self.assertEqual(str(f.value), "abandoned") + + def test_handshake_other_error(self): + owner = MockOwner() + factory = MockFactory() + addr = address.HostnameAddress("example.com", 1234) + c = transit.Connection(owner, None, None) + self.assertEqual(c.state, "too-early") + t = c.transport = FakeTransport(c, addr) + c.factory = factory + c.connectionMade() + self.assertEqual(factory._connectionWasMade_called, True) + self.assertEqual(factory._p, c) + + d = c.startNegotiation("description") + self.assertEqual(c.state, "handshake") + self.assertEqual(t.read_buf(), b"send_this") + results = [] + d.addBoth(results.append) + self.assertEqual(results, []) + c.state = RandomError("boom") + self.assertRaises(RandomError, c.dataReceived, b"surprise!") + self.assertEqual(t._connected, False) + self.assertEqual(c.state, "hung up") + self.assertEqual(len(results), 1) + f = results[0] + self.assertIsInstance(f, failure.Failure) + self.assertIsInstance(f.value, RandomError) + + def test_relay_handshake(self): + relay_handshake = b"relay handshake" + owner = MockOwner() + factory = MockFactory() + addr = address.HostnameAddress("example.com", 1234) + c = transit.Connection(owner, relay_handshake, None) + self.assertEqual(c.state, "too-early") + t = c.transport = FakeTransport(c, addr) + c.factory = factory + c.connectionMade() + self.assertEqual(factory._connectionWasMade_called, True) + self.assertEqual(factory._p, c) + self.assertEqual(t.read_buf(), b"") # quiet until startNegotiation + + owner._state = "go" + d = c.startNegotiation("description") + self.assertEqual(t.read_buf(), relay_handshake) + self.assertEqual(c.state, "relay") # waiting for OK from relay + + c.dataReceived(b"ok\n") + self.assertEqual(t.read_buf(), b"send_this") + self.assertEqual(c.state, "handshake") + + results = [] + d.addBoth(results.append) + self.assertEqual(results, []) + + c.dataReceived(b"expect_this") + self.assertEqual(c.state, "records") + self.assertEqual(results, [c]) + + self.assertEqual(t.read_buf(), b"go\n") + + def test_relay_handshake_bad(self): + relay_handshake = b"relay handshake" + owner = MockOwner() + factory = MockFactory() + addr = address.HostnameAddress("example.com", 1234) + c = transit.Connection(owner, relay_handshake, None) + self.assertEqual(c.state, "too-early") + t = c.transport = FakeTransport(c, addr) + c.factory = factory + c.connectionMade() + self.assertEqual(factory._connectionWasMade_called, True) + self.assertEqual(factory._p, c) + self.assertEqual(t.read_buf(), b"") # quiet until startNegotiation + + owner._state = "go" + d = c.startNegotiation("description") + self.assertEqual(t.read_buf(), relay_handshake) + self.assertEqual(c.state, "relay") # waiting for OK from relay + + c.dataReceived(b"not ok\n") + self.assertEqual(t._connected, False) + self.assertEqual(c.state, "hung up") + + results = [] + d.addBoth(results.append) + self.assertEqual(len(results), 1) + f = results[0] + self.assertIsInstance(f, failure.Failure) + self.assertIsInstance(f.value, transit.BadHandshake) + self.assertEqual(str(f.value), + "got %r want %r" % (b"not ok\n", b"ok\n")) + + def test_receiver_accepted(self): + # we're on the receiving side, so we wait for the sender to decide + owner = MockOwner() + factory = MockFactory() + addr = address.HostnameAddress("example.com", 1234) + c = transit.Connection(owner, None, None) + self.assertEqual(c.state, "too-early") + t = c.transport = FakeTransport(c, addr) + c.factory = factory + c.connectionMade() + self.assertEqual(factory._connectionWasMade_called, True) + self.assertEqual(factory._p, c) + + owner._state = "wait-for-decision" + d = c.startNegotiation("description") + self.assertEqual(c.state, "handshake") + self.assertEqual(t.read_buf(), b"send_this") + results = [] + d.addBoth(results.append) + self.assertEqual(results, []) + + c.dataReceived(b"expect_this") + self.assertEqual(c.state, "wait-for-decision") + self.assertEqual(results, []) + + c.dataReceived(b"go\n") + self.assertEqual(c.state, "records") + self.assertEqual(results, [c]) + + def test_receiver_rejected_politely(self): + # we're on the receiving side, so we wait for the sender to decide + owner = MockOwner() + factory = MockFactory() + addr = address.HostnameAddress("example.com", 1234) + c = transit.Connection(owner, None, None) + self.assertEqual(c.state, "too-early") + t = c.transport = FakeTransport(c, addr) + c.factory = factory + c.connectionMade() + self.assertEqual(factory._connectionWasMade_called, True) + self.assertEqual(factory._p, c) + + owner._state = "wait-for-decision" + d = c.startNegotiation("description") + self.assertEqual(c.state, "handshake") + self.assertEqual(t.read_buf(), b"send_this") + results = [] + d.addBoth(results.append) + self.assertEqual(results, []) + + c.dataReceived(b"expect_this") + self.assertEqual(c.state, "wait-for-decision") + self.assertEqual(results, []) + + c.dataReceived(b"nevermind\n") # polite rejection + self.assertEqual(t._connected, False) + self.assertEqual(c.state, "hung up") + self.assertEqual(len(results), 1) + f = results[0] + self.assertIsInstance(f, failure.Failure) + self.assertIsInstance(f.value, transit.BadHandshake) + self.assertEqual(str(f.value), + "got %r want %r" % (b"nevermind\n", b"go\n")) + + def test_receiver_rejected_rudely(self): + # we're on the receiving side, so we wait for the sender to decide + owner = MockOwner() + factory = MockFactory() + addr = address.HostnameAddress("example.com", 1234) + c = transit.Connection(owner, None, None) + self.assertEqual(c.state, "too-early") + t = c.transport = FakeTransport(c, addr) + c.factory = factory + c.connectionMade() + self.assertEqual(factory._connectionWasMade_called, True) + self.assertEqual(factory._p, c) + + owner._state = "wait-for-decision" + d = c.startNegotiation("description") + self.assertEqual(c.state, "handshake") + self.assertEqual(t.read_buf(), b"send_this") + results = [] + d.addBoth(results.append) + self.assertEqual(results, []) + + c.dataReceived(b"expect_this") + self.assertEqual(c.state, "wait-for-decision") + self.assertEqual(results, []) + + t.loseConnection() + self.assertEqual(t._connected, False) + self.assertEqual(len(results), 1) + f = results[0] + self.assertIsInstance(f, failure.Failure) + self.assertIsInstance(f.value, transit.BadHandshake) + self.assertEqual(str(f.value), "connection lost") + + + def test_cancel(self): + owner = MockOwner() + factory = MockFactory() + addr = address.HostnameAddress("example.com", 1234) + c = transit.Connection(owner, None, None) + self.assertEqual(c.state, "too-early") + t = c.transport = FakeTransport(c, addr) + c.factory = factory + c.connectionMade() + + d = c.startNegotiation("description") + results = [] + d.addBoth(results.append) + # while we're waiting for negotiation, we get cancelled + d.cancel() + + self.assertEqual(t._connected, False) + self.assertEqual(c.state, "hung up") + self.assertEqual(len(results), 1) + f = results[0] + self.assertIsInstance(f, failure.Failure) + self.assertIsInstance(f.value, defer.CancelledError) + + def test_timeout(self): + clock = task.Clock() + owner = MockOwner() + factory = MockFactory() + addr = address.HostnameAddress("example.com", 1234) + c = transit.Connection(owner, None, None) + def _callLater(period, func): + clock.callLater(period, func) + c.callLater = _callLater + self.assertEqual(c.state, "too-early") + t = c.transport = FakeTransport(c, addr) + c.factory = factory + c.connectionMade() + # the timer should now be running + d = c.startNegotiation("description") + results = [] + d.addBoth(results.append) + # while we're waiting for negotiation, the timer expires + clock.advance(transit.TIMEOUT + 1.0) + + self.assertEqual(t._connected, False) + self.assertEqual(len(results), 1) + f = results[0] + self.assertIsInstance(f, failure.Failure) + self.assertIsInstance(f.value, transit.BadHandshake) + self.assertEqual(str(f.value), "timeout") + + def make_connection(self): + owner = MockOwner() + factory = MockFactory() + addr = address.HostnameAddress("example.com", 1234) + c = transit.Connection(owner, None, None) + t = c.transport = FakeTransport(c, addr) + c.factory = factory + c.connectionMade() + + owner._state = "go" + d = c.startNegotiation("description") + results = [] + d.addBoth(results.append) + c.dataReceived(b"expect_this") + self.assertEqual(results, [c]) + t.read_buf() # flush input buffer, prepare for encrypted records + + return t, c, owner + + def test_records_good(self): + # now make sure that outbound records are encrypted properly + t, c, owner = self.make_connection() + + RECORD1 = b"record" + c.send_record(RECORD1) + buf = t.read_buf() + expected = ("%08x" % (24+len(RECORD1)+16)).encode("ascii") + self.assertEqual(hexlify(buf[:4]), expected) + encrypted = buf[4:] + receive_box = SecretBox(owner._sender_record_key()) + nonce_buf = encrypted[:SecretBox.NONCE_SIZE] # assume it's prepended + nonce = int(hexlify(nonce_buf), 16) + self.assertEqual(nonce, 0) # first message gets nonce 0 + decrypted = receive_box.decrypt(encrypted) + self.assertEqual(decrypted, RECORD1) + + # second message gets nonce 1 + RECORD2 = b"record2" + c.send_record(RECORD2) + buf = t.read_buf() + expected = ("%08x" % (24+len(RECORD2)+16)).encode("ascii") + self.assertEqual(hexlify(buf[:4]), expected) + encrypted = buf[4:] + receive_box = SecretBox(owner._sender_record_key()) + nonce_buf = encrypted[:SecretBox.NONCE_SIZE] # assume it's prepended + nonce = int(hexlify(nonce_buf), 16) + self.assertEqual(nonce, 1) + decrypted = receive_box.decrypt(encrypted) + self.assertEqual(decrypted, RECORD2) + + # and that we can receive records properly + inbound_records = [] + c.recordReceived = inbound_records.append + + RECORD3 = b"record3" + send_box = SecretBox(owner._receiver_record_key()) + nonce_buf = unhexlify("%048x" % 0) # first nonce must be 0 + encrypted = send_box.encrypt(RECORD3, nonce_buf) + length = unhexlify("%08x" % len(encrypted)) # always 4 bytes long + c.dataReceived(length[:2]) + c.dataReceived(length[2:]) + c.dataReceived(encrypted[:-2]) + self.assertEqual(inbound_records, []) + c.dataReceived(encrypted[-2:]) + self.assertEqual(inbound_records, [RECORD3]) + + RECORD4 = b"record4" + send_box = SecretBox(owner._receiver_record_key()) + nonce_buf = unhexlify("%048x" % 1) # nonces increment + encrypted = send_box.encrypt(RECORD4, nonce_buf) + length = unhexlify("%08x" % len(encrypted)) # always 4 bytes long + c.dataReceived(length[:2]) + c.dataReceived(length[2:]) + c.dataReceived(encrypted[:-2]) + self.assertEqual(inbound_records, [RECORD3]) + c.dataReceived(encrypted[-2:]) + self.assertEqual(inbound_records, [RECORD3, RECORD4]) + + def corrupt(self, orig): + last_byte = orig[-1:] + num = int(hexlify(last_byte).decode("ascii"), 16) + corrupt_num = 256 - num + as_byte = unhexlify("%02x" % corrupt_num) + return orig[:-1] + as_byte + + def test_records_corrupt(self): + # corrupt records should be rejected + t, c, owner = self.make_connection() + + inbound_records = [] + c.recordReceived = inbound_records.append + + RECORD = b"record" + send_box = SecretBox(owner._receiver_record_key()) + nonce_buf = unhexlify("%048x" % 0) # first nonce must be 0 + encrypted = self.corrupt(send_box.encrypt(RECORD, nonce_buf)) + length = unhexlify("%08x" % len(encrypted)) # always 4 bytes long + c.dataReceived(length) + c.dataReceived(encrypted[:-2]) + self.assertEqual(inbound_records, []) + self.assertRaises(CryptoError, c.dataReceived, encrypted[-2:]) + self.assertEqual(inbound_records, []) + # and the connection should have been dropped + self.assertEqual(t._connected, False) + + def test_out_of_order_nonce(self): + # an inbound out-of-order nonce should be rejected + t, c, owner = self.make_connection() + + inbound_records = [] + c.recordReceived = inbound_records.append + + RECORD = b"record" + send_box = SecretBox(owner._receiver_record_key()) + nonce_buf = unhexlify("%048x" % 1) # first nonce must be 0 + encrypted = send_box.encrypt(RECORD, nonce_buf) + length = unhexlify("%08x" % len(encrypted)) # always 4 bytes long + c.dataReceived(length) + c.dataReceived(encrypted[:-2]) + self.assertEqual(inbound_records, []) + self.assertRaises(transit.BadNonce, c.dataReceived, encrypted[-2:]) + self.assertEqual(inbound_records, []) + # and the connection should have been dropped + self.assertEqual(t._connected, False) + + # TODO: check that .connectionLost/loseConnection signatures are + # consistent: zero args, or one arg? + + # XXX: if we don't set the transit key before connecting, what + # happens? We currently get a type-check assertion from HKDF because + # the key is None. + +DIRECT_HINT = u"tcp:direct:1234" +RELAY_HINT = u"tcp:relay:1234" +UNUSABLE_HINT = u"unusable:foo:bar" + +class Transit(unittest.TestCase): + @inlineCallbacks + def test_success_direct(self): + clock = task.Clock() + s = transit.TransitSender(u"", reactor=clock) + s.set_transit_key(b"key") + hints = yield s.get_direct_hints() # start the listener + del hints + s.add_their_direct_hints([DIRECT_HINT, UNUSABLE_HINT]) + s.add_their_relay_hints([]) + + connectors = [] + def _start_connector(ep, description, is_relay=False): + d = defer.Deferred() + connectors.append(d) + return d + s._start_connector = _start_connector + d = s.connect() + results = [] + d.addBoth(results.append) + self.assertEqual(results, []) + self.assertEqual(len(connectors), 1) + self.assertIsInstance(connectors[0], defer.Deferred) + + connectors[0].callback("winner") + self.assertEqual(results, ["winner"]) + + def _endpoint_from_hint(self, hint): + if hint == DIRECT_HINT: + return "direct" + elif hint == RELAY_HINT: + return "relay" + elif hint == UNUSABLE_HINT: + return None + else: + return "ep" + + @inlineCallbacks + def test_wait_for_relay(self): + clock = task.Clock() + s = transit.TransitSender(u"", reactor=clock) + s.set_transit_key(b"key") + hints = yield s.get_direct_hints() # start the listener + del hints + s.add_their_direct_hints([DIRECT_HINT, UNUSABLE_HINT]) + s.add_their_relay_hints([RELAY_HINT]) + + direct_connectors = [] + relay_connectors = [] + s._endpoint_from_hint = self._endpoint_from_hint + def _start_connector(ep, description, is_relay=False): + d = defer.Deferred() + if ep == "direct": + direct_connectors.append(d) + elif ep == "relay": + relay_connectors.append(d) + else: + raise ValueError + return d + s._start_connector = _start_connector + + d = s.connect() + results = [] + d.addBoth(results.append) + self.assertEqual(results, []) + # the direct connectors are tried right away, but the relay + # connectors are stalled for a few seconds + self.assertEqual(len(direct_connectors), 1) + self.assertEqual(len(relay_connectors), 0) + + clock.advance(s.RELAY_DELAY + 1.0) + self.assertEqual(len(direct_connectors), 1) + self.assertEqual(len(relay_connectors), 1) + + direct_connectors[0].callback("winner") + self.assertEqual(results, ["winner"]) + + @inlineCallbacks + def test_no_direct_hints(self): + clock = task.Clock() + s = transit.TransitSender(u"", reactor=clock) + s.set_transit_key(b"key") + hints = yield s.get_direct_hints() # start the listener + del hints + s.add_their_direct_hints([UNUSABLE_HINT]) + s.add_their_relay_hints([RELAY_HINT, UNUSABLE_HINT]) + + direct_connectors = [] + relay_connectors = [] + s._endpoint_from_hint = self._endpoint_from_hint + def _start_connector(ep, description, is_relay=False): + d = defer.Deferred() + if ep == "direct": + direct_connectors.append(d) + elif ep == "relay": + relay_connectors.append(d) + else: + raise ValueError + return d + s._start_connector = _start_connector + + d = s.connect() + results = [] + d.addBoth(results.append) + self.assertEqual(results, []) + # since there are no usable direct hints, the relay connector will + # only be stalled for 0 seconds + self.assertEqual(len(direct_connectors), 0) + self.assertEqual(len(relay_connectors), 0) + + clock.advance(0) + self.assertEqual(len(direct_connectors), 0) + self.assertEqual(len(relay_connectors), 1) + + relay_connectors[0].callback("winner") + self.assertEqual(results, ["winner"]) + + +class Full(unittest.TestCase): + def doBoth(self, d1, d2): + return gatherResults([d1, d2], True) + + @inlineCallbacks + def test_full(self): + KEY = b"k"*32 + s = transit.TransitSender(None) + r = transit.TransitReceiver(None) + + s.set_transit_key(KEY) + r.set_transit_key(KEY) + + shints = yield s.get_direct_hints() + rhints = yield r.get_direct_hints() + + s.add_their_direct_hints(rhints) + r.add_their_direct_hints(shints) + + s.add_their_relay_hints([]) + r.add_their_relay_hints([]) + + (x,y) = yield self.doBoth(s.connect(), r.connect()) + self.assertIsInstance(x, transit.Connection) + self.assertIsInstance(y, transit.Connection) + + d = defer.Deferred() + y.recordReceived = d.callback + + x.send_record(b"record1") + r = yield d + self.assertEqual(r, b"record1") + + yield x.close() + yield y.close() diff --git a/src/wormhole/twisted/transit.py b/src/wormhole/twisted/transit.py new file mode 100644 index 0000000..fbb5d88 --- /dev/null +++ b/src/wormhole/twisted/transit.py @@ -0,0 +1,603 @@ +from __future__ import print_function +import sys, time, socket +from binascii import hexlify, unhexlify +from zope.interface import implementer +from twisted.python.runtime import platformType +from twisted.internet import (reactor, interfaces, defer, protocol, + endpoints, task, address) +from twisted.protocols import policies +from nacl.secret import SecretBox +from ..util import ipaddrs +from ..util.hkdf import HKDF +from ..errors import UsageError +from ..transit_common import (BadHandshake, + BadNonce, + build_receiver_handshake, + build_sender_handshake, + build_relay_handshake) + +def debug(msg): + if False: + print(msg) +def since(start): + return time.time() - start + +TIMEOUT=15 + +@implementer(interfaces.IProducer, interfaces.IConsumer) +class Connection(protocol.Protocol, policies.TimeoutMixin): + def __init__(self, owner, relay_handshake, start): + self.state = "too-early" + self.buf = b"" + self.owner = owner + self.relay_handshake = relay_handshake + self.start = start + self._negotiation_d = defer.Deferred(self._cancel) + self._error = None + + def connectionMade(self): + debug("handle %r" % (self.transport,)) + self.setTimeout(TIMEOUT) # does timeoutConnection() when it expires + self.factory.connectionWasMade(self, self.transport.getPeer()) + + def startNegotiation(self, description): + self.description = description + if self.relay_handshake is not None: + self.transport.write(self.relay_handshake) + self.state = "relay" + else: + self.state = "start" + self.dataReceived(b"") # cycle the state machine + return self._negotiation_d + + def _cancel(self, d): + self.state = "hung up" # stop reacting to anything further + self._error = defer.CancelledError() + self.transport.loseConnection() + # if connectionLost isn't called synchronously, then our + # self._negotiation_d will have been errbacked by Deferred.cancel + # (which is our caller). So if it's still around, clobber it + if self._negotiation_d: + self._negotiation_d = None + + + def dataReceived(self, data): + try: + self._dataReceived(data) + except Exception as e: + self.setTimeout(None) + self._error = e + self.transport.loseConnection() + self.state = "hung up" + if not isinstance(e, BadHandshake): + raise + + def _check_and_remove(self, expected): + # any divergence is a handshake error + if not self.buf.startswith(expected[:len(self.buf)]): + raise BadHandshake("got %r want %r" % (self.buf, expected)) + if len(self.buf) < len(expected): + return False # keep waiting + self.buf = self.buf[len(expected):] + return True + + def _dataReceived(self, data): + # protocol is: + # (maybe: send relay handshake, wait for ok) + # send (send|receive)_handshake + # wait for (receive|send)_handshake + # sender: decide, send "go" or hang up + # receiver: wait for "go" + self.buf += data + + assert self.state != "too-early" + if self.state == "relay": + if not self._check_and_remove(b"ok\n"): + return + self.state = "start" + if self.state == "start": + self.transport.write(self.owner._send_this()) + self.state = "handshake" + if self.state == "handshake": + if not self._check_and_remove(self.owner._expect_this()): + return + self.state = self.owner.connection_ready(self, self.description) + # If we're the receiver, we'll be moved to state + # "wait-for-decision", which means we're waiting for the other + # side (the sender) to make a decision. If we're the sender, + # we'll either be moved to state "go" (send GO and move directly + # to state "records") or state "nevermind" (send NEVERMIND and + # hang up). + + if self.state == "wait-for-decision": + if not self._check_and_remove(b"go\n"): + return + self._negotiationSuccessful() + if self.state == "go": + GO = b"go\n" + self.transport.write(GO) + self._negotiationSuccessful() + if self.state == "nevermind": + self.transport.write(b"nevermind\n") + raise BadHandshake("abandoned") + if self.state == "records": + return self.dataReceivedRECORDS() + if isinstance(self.state, Exception): # for tests + raise self.state + + def _negotiationSuccessful(self): + self.state = "records" + self.setTimeout(None) + send_key = self.owner._sender_record_key() + self.send_box = SecretBox(send_key) + self.send_nonce = 0 + receive_key = self.owner._receiver_record_key() + self.receive_box = SecretBox(receive_key) + self.next_receive_nonce = 0 + d, self._negotiation_d = self._negotiation_d, None + d.callback(self) + + def dataReceivedRECORDS(self): + if len(self.buf) < 4: + return + length = int(hexlify(self.buf[:4]), 16) + if len(self.buf) < 4+length: + return + encrypted, self.buf = self.buf[4:4+length], self.buf[4+length:] + + record = self._decrypt_record(encrypted) + self.recordReceived(record) + + def _decrypt_record(self, encrypted): + nonce_buf = encrypted[:SecretBox.NONCE_SIZE] # assume it's prepended + nonce = int(hexlify(nonce_buf), 16) + if nonce != self.next_receive_nonce: + raise BadNonce("received out-of-order record: got %d, expected %d" + % (nonce, self.next_receive_nonce)) + self.next_receive_nonce += 1 + record = self.receive_box.decrypt(encrypted) + return record + + def send_record(self, record): + if not isinstance(record, type(b"")): raise UsageError + assert SecretBox.NONCE_SIZE == 24 + assert self.send_nonce < 2**(8*24) + assert len(record) < 2**(8*4) + nonce = unhexlify("%048x" % self.send_nonce) # big-endian + self.send_nonce += 1 + encrypted = self.send_box.encrypt(record, nonce) + length = unhexlify("%08x" % len(encrypted)) # always 4 bytes long + self.transport.write(length) + self.transport.write(encrypted) + + def recordReceived(self, record): + pass + + def close(self): + self.transport.loseConnection() + + def timeoutConnection(self): + self._error = BadHandshake("timeout") + self.transport.loseConnection() + + def connectionLost(self, reason=None): + self.setTimeout(None) + d, self._negotiation_d = self._negotiation_d, None + # the Deferred is only relevant until negotiation finishes, so skip + # this if it's alredy been fired + if d: + # Each call to loseConnection() sets self._error first, so we can + # deliver useful information to the Factory that's waiting on + # this (although they'll generally ignore the specific error, + # except for logging unexpected ones). The possible cases are: + # + # cancel: defer.CancelledError + # far-end disconnect: BadHandshake("connection lost") + # handshake error (something we didn't like): BadHandshake(what) + # other error: some other Exception + # timeout: BadHandshake("timeout") + + d.errback(self._error or BadHandshake("connection lost")) + + + +class OutboundConnectionFactory(protocol.ClientFactory): + protocol = Connection + + def __init__(self, owner, relay_handshake): + self.owner = owner + self.relay_handshake = relay_handshake + self.start = time.time() + + def buildProtocol(self, addr): + p = self.protocol(self.owner, self.relay_handshake, self.start) + p.factory = self + return p + + def connectionWasMade(self, p, addr): + # outbound connections are handled via the endpoint + pass + + +class InboundConnectionFactory(protocol.ClientFactory): + protocol = Connection + + def __init__(self, owner): + self.owner = owner + self.start = time.time() + self._inbound_d = defer.Deferred(self._cancel) + self._pending_connections = set() + + def whenDone(self): + return self._inbound_d + + def _cancel(self, inbound_d): + self._shutdown() + # our _inbound_d will be errbacked by Deferred.cancel() + + def _shutdown(self): + for d in list(self._pending_connections): + d.cancel() # that fires _remove and _proto_failed + + def describePeer(self, addr): + if isinstance(addr, address.HostnameAddress): + return "<-%s:%d" % (addr.hostname, addr.port) + elif isinstance(addr, (address.IPv4Address, address.IPv6Address)): + return "<-%s:%d" % (addr.host, addr.port) + return "<-%r" % addr + + def buildProtocol(self, addr): + p = self.protocol(self.owner, None, self.start) + p.factory = self + return p + + def connectionWasMade(self, p, addr): + d = p.startNegotiation(self.describePeer(addr)) + self._pending_connections.add(d) + d.addBoth(self._remove, d) + d.addCallbacks(self._proto_succeeded, self._proto_failed) + + def _remove(self, res, d): + self._pending_connections.remove(d) + return res + + def _proto_succeeded(self, p): + self._shutdown() + self._inbound_d.callback(p) + + def _proto_failed(self, f): + # ignore these two, let Twisted log everything else + f.trap(BadHandshake, defer.CancelledError) + pass + +def allocate_tcp_port(): + """Return an (integer) available TCP port on localhost. This briefly + listens on the port in question, then closes it right away.""" + # We want to bind() the socket but not listen(). Twisted (in + # tcp.Port.createInternetSocket) would do several other things: + # non-blocking, close-on-exec, and SO_REUSEADDR. We don't need + # non-blocking because we never listen on it, and we don't need + # close-on-exec because we close it right away. So just add SO_REUSEADDR. + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + if platformType == "posix" and sys.platform != "cygwin": + s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + s.bind(("127.0.0.1", 0)) + port = s.getsockname()[1] + s.close() + return port + +class _ThereCanBeOnlyOne: + """Accept a list of contender Deferreds, and return a summary Deferred. + When the first contender fires successfully, cancel the rest and fire the + summary with the winning contender's result. If all error, errback the + summary. + + status_cb=? + """ + def __init__(self, contenders): + self._remaining = set(contenders) + self._winner_d = defer.Deferred(self._cancel) + self._first_success = None + self._first_failure = None + self._have_winner = False + self._fired = False + + def _cancel(self, _): + for d in list(self._remaining): + d.cancel() + # since that will errback everything in _remaining, we'll have hit + # _maybe_done() and fired self._winner_d by this point + + def run(self): + for d in list(self._remaining): + d.addBoth(self._remove, d) + d.addCallbacks(self._succeeded, self._failed) + d.addCallback(self._maybe_done) + return self._winner_d + + def _remove(self, res, d): + self._remaining.remove(d) + return res + + def _succeeded(self, res): + self._have_winner = True + self._first_success = res + for d in list(self._remaining): + d.cancel() + + def _failed(self, f): + if self._first_failure is None: + self._first_failure = f + + def _maybe_done(self, _): + if self._remaining: + return + if self._fired: + return + self._fired = True + if self._have_winner: + self._winner_d.callback(self._first_success) + else: + self._winner_d.errback(self._first_failure) + +def there_can_be_only_one(contenders): + return _ThereCanBeOnlyOne(contenders).run() + +class Common: + RELAY_DELAY = 2.0 + + def __init__(self, transit_relay, reactor=reactor): + if not isinstance(transit_relay, (type(None), type(u""))): + raise UsageError + self._transit_relay = transit_relay + self._transit_key = None + self._waiting_for_transit_key = [] + self._listener = None + self._winner = None + self._winner_description = None + self._reactor = reactor + + def _build_listener(self): + portnum = allocate_tcp_port() + direct_hints = [u"tcp:%s:%d" % (addr, portnum) + for addr in ipaddrs.find_addresses()] + ep = endpoints.serverFromString(reactor, "tcp:%d" % portnum) + return direct_hints, ep + + def get_direct_hints(self): + if self._listener: + return defer.succeed(self._my_direct_hints) + # there is a slight race here: if someone calls get_direct_hints() a + # second time, before the listener has actually started listening, + # then they'll get a Deferred that fires (with the hints) before the + # listener starts listening. But most applications won't call this + # multiple times, and the race is between 1: the parent Wormhole + # protocol getting the connection hints to the other end, and 2: the + # listener being ready for connections, and I'm confident that the + # listener will win. + self._my_direct_hints, self._listener = self._build_listener() + + # Start the server, so it will be running by the time anyone tries to + # connect to the direct hints we return. + f = InboundConnectionFactory(self) + self._listener_f = f # for tests # XX move to __init__ ? + self._listener_d = f.whenDone() + d = self._listener.listen(f) + def _listening(lp): + # lp is an IListeningPort + #self._listener_port = lp # for tests + def _stop_listening(res): + lp.stopListening() + return res + self._listener_d.addBoth(_stop_listening) + return self._my_direct_hints + d.addCallback(_listening) + return d + + def _stop_listening(self): + # this is for unit tests. The usual control flow (via connect()) + # wires the listener's Deferred into a there_can_be_only_one(), which + # eats the errback. If we don't ever call connect(), we must catch it + # ourselves. + self._listener_d.addErrback(lambda f: None) + self._listener_d.cancel() + + def get_relay_hints(self): + if self._transit_relay: + return [self._transit_relay] + return [] + + def add_their_direct_hints(self, hints): + for h in hints: + if not isinstance(h, type(u"")): + raise TypeError("hint '%r' should be unicode, not %s" + % (h, type(h))) + self._their_direct_hints = set(hints) + def add_their_relay_hints(self, hints): + for h in hints: + if not isinstance(h, type(u"")): + raise TypeError("hint '%r' should be unicode, not %s" + % (h, type(h))) + self._their_relay_hints = set(hints) + + def _send_this(self): + if self.is_sender: + return build_sender_handshake(self._transit_key) + else: + return build_receiver_handshake(self._transit_key) + + def _expect_this(self): + if self.is_sender: + return build_receiver_handshake(self._transit_key) + else: + return build_sender_handshake(self._transit_key)# + b"go\n" + + def _sender_record_key(self): + if self.is_sender: + return HKDF(self._transit_key, SecretBox.KEY_SIZE, + CTXinfo=b"transit_record_sender_key") + else: + return HKDF(self._transit_key, SecretBox.KEY_SIZE, + CTXinfo=b"transit_record_receiver_key") + + def _receiver_record_key(self): + if self.is_sender: + return HKDF(self._transit_key, SecretBox.KEY_SIZE, + CTXinfo=b"transit_record_receiver_key") + else: + return HKDF(self._transit_key, SecretBox.KEY_SIZE, + CTXinfo=b"transit_record_sender_key") + + def set_transit_key(self, key): + # We use pubsub to protect against the race where the sender knows + # the hints and the key, and connects to the receiver's transit + # socket before the receiver gets the relay message (and thus the + # key). + self._transit_key = key + waiters = self._waiting_for_transit_key + del self._waiting_for_transit_key + for d in waiters: + # We don't need eventual-send here. It's safer in general, but + # set_transit_key() is only called once, and _get_transit_key() + # won't touch the subscribers list once the key is set. + d.callback(key) + + def _get_transit_key(self): + if self._transit_key: + return defer.succeed(self._transit_key) + d = defer.Deferred() + self._waiting_for_transit_key.append(d) + return d + + def connect(self): + d = self._get_transit_key() + d.addCallback(self._connect) + # we want to have the transit key before starting any outbound + # connections, so those connections will know what to say when they + # connect + return d + + def _connect(self, _): + # It might be nice to wire this so that a failure in the direct hints + # causes the relay hints to be used right away (fast failover). But + # none of our current use cases would take advantage of that: if we + # have any viable direct hints, then they're either going to succeed + # quickly or hang for a long time. + contenders = [] + contenders.append(self._listener_d) + relay_delay = 0 + + for hint in self._their_direct_hints: + # Check the hint type to see if we can support it (e.g. skip + # onion hints on a non-Tor client). Do not increase relay_delay + # unless we have at least one viable hint. + ep = self._endpoint_from_hint(hint) + if not ep: + continue + description = "->%s" % (hint,) + d = self._start_connector(ep, description) + contenders.append(d) + relay_delay = self.RELAY_DELAY + + # Start trying the relay a few seconds after we start to try the + # direct hints. The idea is to prefer direct connections, but not be + # afraid of using the relay when we have direct hints that don't + # resolve quickly. Many direct hints will be to unused local-network + # IP addresses, which won't answer, and would take the full TCP + # timeout (30s or more) to fail. + for hint in self._their_relay_hints: + ep = self._endpoint_from_hint(hint) + if not ep: + continue + description = "->relay:%s" % (hint,) + d = task.deferLater(self._reactor, relay_delay, + self._start_connector, ep, description, + is_relay=True) + contenders.append(d) + + winner = there_can_be_only_one(contenders) + return self._not_forever(2*TIMEOUT, winner) + + def _not_forever(self, timeout, d): + """If the timer fires first, cancel the deferred. If the deferred fires + first, cancel the timer.""" + t = self._reactor.callLater(timeout, d.cancel) + def _done(res): + if t.active(): + t.cancel() + return res + d.addBoth(_done) + return d + + def _start_connector(self, ep, description, is_relay=False): + relay_handshake = None + if is_relay: + relay_handshake = build_relay_handshake(self._transit_key) + f = OutboundConnectionFactory(self, relay_handshake) + d = ep.connect(f) + # fires with protocol, or ConnectError + d.addCallback(lambda p: p.startNegotiation(description)) + return d + + def _endpoint_from_hint(self, hint): + # TODO: use transit_common.parse_hint_tcp + if ":" not in hint: + return None + hint_type = hint.split(":")[0] + if hint_type != "tcp": + return None + pieces = hint.split(":") + return endpoints.HostnameEndpoint(self._reactor, pieces[1], + int(pieces[2])) + + def connection_ready(self, p, description): + # inbound/outbound Connection protocols call this when they finish + # negotiation. The first one wins and gets a "go". Any subsequent + # ones lose and get a "nevermind" before being closed. + + if not self.is_sender: + return "wait-for-decision" + + if self._winner: + # we already have a winner, so this one loses + return "nevermind" + # this one wins! + self._winner = p + self._winner_description = description + return "go" + + def describe(self): + if not self._winner: + return "not yet established" + return self._winner_description + +class TransitSender(Common): + is_sender = True + +class TransitReceiver(Common): + is_sender = False + +# the TransitSender/Receiver.connect() yields a Connection, on which you can +# do send_record(), but what should the receive API be? set a callback for +# inbound records? get a Deferred for the next record? The producer/consumer +# API is enough for file transfer, but what would other applications want? + +# how should the Listener be managed? we want to shut it down when the +# connect() Deferred is cancelled, as well as terminating any negotiations in +# progress. +# +# the factory should return/manage a deferred, which fires iff an inbound +# connection completes negotiation successfully, can be cancelled (which +# stops the listener and drops all pending connections), but will never +# timeout, and only errbacks if cancelled. + +# write unit test for _ThereCanBeOnlyOne + +# check start/finish time-gathering instrumentation + +# add progress API + +# relay URLs are probably mishandled: both sides probably send their URL, +# then connect to the *other* side's URL, when they really should connect to +# both their own and the other side's. The current implementation probably +# only works if the two URLs are the same.