magic-wormhole/src/wormhole/test/test_transit.py

1621 lines
56 KiB
Python
Raw Normal View History

from __future__ import print_function, unicode_literals
2018-04-21 07:30:08 +00:00
2016-06-23 05:58:27 +00:00
import gc
2018-04-21 07:30:08 +00:00
import io
2016-02-15 01:57:09 +00:00
from binascii import hexlify, unhexlify
2018-04-21 07:30:08 +00:00
import six
from nacl.exceptions import CryptoError
from nacl.secret import SecretBox
from twisted.internet import address, defer, endpoints, error, protocol, task
2016-02-15 01:57:09 +00:00
from twisted.internet.defer import gatherResults, inlineCallbacks
from twisted.python import log
from twisted.test import proto_helpers
2018-04-21 07:30:08 +00:00
from twisted.trial import unittest
import mock
from wormhole_transit_relay import transit_server
2018-04-21 07:30:08 +00:00
from .. import transit
from .._hints import DirectTCPV1Hint
2018-04-21 07:30:08 +00:00
from ..errors import InternalError
2018-12-22 22:27:54 +00:00
from ..util import HKDF
2016-12-23 04:47:54 +00:00
from .common import ServerBase
2018-04-21 07:30:08 +00:00
2016-02-15 01:57:09 +00:00
class Highlander(unittest.TestCase):
    """Tests for transit.there_can_be_only_one()."""

    def _contenders(self, count):
        """Create `count` pending Deferreds that record their own cancellation.

        Returns (cancelled, contenders): when contender N is cancelled, its
        canceller adds N to the `cancelled` set.
        """
        cancelled = set()

        def canceller_for(index):
            return lambda _d: cancelled.add(index)

        contenders = [defer.Deferred(canceller_for(n)) for n in range(count)]
        return cancelled, contenders

    def test_one_winner(self):
        cancelled, contenders = self._contenders(5)
        winner = transit.there_can_be_only_one(contenders)
        self.assertNoResult(winner)

        # failures alone do not decide the race
        for contender, failure in zip(contenders, [ValueError(), TypeError()]):
            contender.errback(failure)
            self.assertNoResult(winner)

        contenders[2].callback("yay")
        self.assertEqual(self.successResultOf(winner), "yay")
        # the contenders that were still pending get cancelled
        self.assertEqual(cancelled, {3, 4})

    def test_there_might_also_be_none(self):
        cancelled, contenders = self._contenders(4)
        winner = transit.there_can_be_only_one(contenders)
        self.assertNoResult(winner)

        early_failures = [ValueError(), TypeError(), TypeError()]
        for contender, failure in zip(contenders, early_failures):
            contender.errback(failure)
            self.assertNoResult(winner)

        contenders[3].errback(NameError())
        # once every contender has failed, the first failure is reported
        self.failureResultOf(winner, ValueError)
        self.assertEqual(cancelled, set())

    def test_cancel_early(self):
        cancelled, contenders = self._contenders(4)
        winner = transit.there_can_be_only_one(contenders)
        self.assertNoResult(winner)
        self.assertEqual(cancelled, set())

        # cancelling the combined Deferred cancels every contender
        winner.cancel()
        self.failureResultOf(winner, defer.CancelledError)
        self.assertEqual(cancelled, set(range(4)))

    def test_cancel_after_one_failure(self):
        cancelled, contenders = self._contenders(4)
        winner = transit.there_can_be_only_one(contenders)
        self.assertNoResult(winner)
        self.assertEqual(cancelled, set())

        contenders[0].errback(ValueError())
        winner.cancel()
        # the recorded failure is what the caller sees, and only the
        # still-pending contenders are cancelled
        self.failureResultOf(winner, ValueError)
        self.assertEqual(cancelled, {1, 2, 3})
2016-02-15 01:57:09 +00:00
class Forever(unittest.TestCase):
    """Tests for Common._not_forever(), which bounds a Deferred by a timeout."""

    def _forever_setup(self):
        """Wrap a fresh Deferred in a 1.0-second _not_forever() on a fake clock.

        Returns (common, clock, inner, wrapped, cancelled), where `cancelled`
        collects every Deferred handed to the inner Deferred's canceller.
        """
        clock = task.Clock()
        common = transit.Common("", reactor=clock)
        cancelled = []
        inner = defer.Deferred(cancelled.append)
        wrapped = common._not_forever(1.0, inner)
        return common, clock, inner, wrapped, cancelled

    def _assert_untouched(self, d, cancelled):
        """Check that `d` has not fired and nothing has been cancelled."""
        self.assertNoResult(d)
        self.assertEqual(cancelled, [])

    def test_not_forever_fires(self):
        _, clock, _, d, cancelled = self._forever_setup()
        self._assert_untouched(d, cancelled)

        d.callback(1)
        self.assertEqual(self.successResultOf(d), 1)
        self.assertEqual(cancelled, [])
        # the timeout must have been disarmed
        self.assertNot(clock.getDelayedCalls())

    def test_not_forever_errs(self):
        _, clock, _, d, cancelled = self._forever_setup()
        self._assert_untouched(d, cancelled)

        d.errback(ValueError())
        self.assertEqual(cancelled, [])
        self.failureResultOf(d, ValueError)
        self.assertNot(clock.getDelayedCalls())

    def test_not_forever_cancel_early(self):
        _, clock, inner, d, cancelled = self._forever_setup()
        self._assert_untouched(d, cancelled)

        d.cancel()
        self.assertEqual(cancelled, [inner])
        self.failureResultOf(d, defer.CancelledError)
        self.assertNot(clock.getDelayedCalls())

    def test_not_forever_timeout(self):
        _, clock, inner, d, cancelled = self._forever_setup()
        self._assert_untouched(d, cancelled)

        # run the fake clock past the 1.0s limit
        clock.advance(2.0)
        self.assertEqual(cancelled, [inner])
        self.failureResultOf(d, defer.CancelledError)
        self.assertNot(clock.getDelayedCalls())
2018-04-21 07:30:08 +00:00
2016-02-15 01:57:09 +00:00
class Misc(unittest.TestCase):
    """Tests for transit.allocate_tcp_port()."""

    def test_allocate_port(self):
        self.assertIsInstance(transit.allocate_tcp_port(), int)

    def test_allocate_port_no_reuseaddr(self):
        # Pretend to run on cygwin. The test name suggests
        # allocate_tcp_port() takes a no-SO_REUSEADDR path there.
        fake_sys = mock.Mock()
        fake_sys.platform = "cygwin"
        with mock.patch("wormhole.transit.sys", fake_sys):
            port = transit.allocate_tcp_port()
        self.assertIsInstance(port, int)
2018-04-21 07:30:08 +00:00
# ipaddrs.py currently uses native strings: bytes on py2, unicode on py3.
# These fake addresses are fed to a mocked find_addresses(), so they must
# use the same string type.
LOOPADDR, OTHERADDR = ((b"127.0.0.1", b"1.2.3.4") if six.PY2 else
                       ("127.0.0.1", "1.2.3.4"))  # unicode_literals
2018-04-21 07:30:08 +00:00
2016-02-15 01:57:09 +00:00
class Basic(unittest.TestCase):
    """Tests for connection hints, abilities, transit keys and winners."""

    def _listening_hints(self, c):
        """Start `c`'s listener and return its connection hints.

        get_connection_hints() actually starts a listening port. The
        matching _stop_listening() is registered as a cleanup right away, so
        the port is shut down even if an assertion in the test fails.
        """
        d = c.get_connection_hints()
        self.addCleanup(c._stop_listening)
        return self.successResultOf(d)

    @inlineCallbacks
    def test_relay_hints(self):
        URL = "tcp:host:1234"
        c = transit.Common(URL, no_listen=True)
        hints = yield c.get_connection_hints()
        self.assertEqual(hints, [{
            "type":
            "relay-v1",
            "hints": [{
                "type": "direct-tcp-v1",
                "hostname": "host",
                "port": 1234,
                "priority": 0.0
            }],
        }])
        # a non-string relay URL is rejected
        self.assertRaises(InternalError, transit.Common, 123)

    @inlineCallbacks
    def test_no_relay_hints(self):
        c = transit.Common(None, no_listen=True)
        hints = yield c.get_connection_hints()
        self.assertEqual(hints, [])

    def test_ignore_bad_hints(self):
        c = transit.Common("")
        c.add_connection_hints([{"type": "unknown"}])
        c.add_connection_hints([{
            "type": "relay-v1",
            "hints": [{
                "type": "unknown"
            }]
        }])
        self.assertEqual(c._their_direct_hints, [])
        self.assertEqual(c._our_relay_hints, set())

    def test_ignore_localhost_hint_orig(self):
        c = transit.TransitSender("")
        hints = self._listening_hints(c)
        # If there are non-localhost hints, then localhost hints should be
        # removed. But if the only hint is localhost, it should stay.
        if len(hints) == 1 and hints[0]["hostname"] == "127.0.0.1":
            return
        for hint in hints:
            self.assertNotEqual(hint["hostname"], "127.0.0.1")

    def test_ignore_localhost_hint(self):
        c = transit.TransitSender("")
        with mock.patch(
                "wormhole.ipaddrs.find_addresses",
                return_value=[LOOPADDR, OTHERADDR]):
            hints = self._listening_hints(c)
        # If there are non-localhost hints, then localhost hints should be
        # removed.
        self.assertEqual(len(hints), 1)
        self.assertEqual(hints[0]["hostname"], "1.2.3.4")

    def test_keep_only_localhost_hint(self):
        c = transit.TransitSender("")
        with mock.patch(
                "wormhole.ipaddrs.find_addresses", return_value=[LOOPADDR]):
            hints = self._listening_hints(c)
        # If the only hint is localhost, it should stay.
        self.assertEqual(len(hints), 1)
        self.assertEqual(hints[0]["hostname"], "127.0.0.1")

    def test_abilities(self):
        c = transit.Common(None, no_listen=True)
        abilities = c.get_connection_abilities()
        self.assertEqual(abilities, [
            {
                "type": "direct-tcp-v1"
            },
            {
                "type": "relay-v1"
            },
        ])

    def test_transit_key_wait(self):
        KEY = b"123"
        c = transit.Common("")
        d = c._get_transit_key()
        self.assertNoResult(d)
        # setting the key releases the pending waiter
        c.set_transit_key(KEY)
        self.assertEqual(self.successResultOf(d), KEY)

    def test_transit_key_already_set(self):
        KEY = b"123"
        c = transit.Common("")
        c.set_transit_key(KEY)
        d = c._get_transit_key()
        self.assertEqual(self.successResultOf(d), KEY)

    def test_transit_keys(self):
        KEY = b"123"
        s = transit.TransitSender("")
        s.set_transit_key(KEY)
        r = transit.TransitReceiver("")
        r.set_transit_key(KEY)

        # each side's greeting is exactly what the other side expects
        self.assertEqual(s._send_this(), (
            b"transit sender "
            b"559bdeae4b49fa6a23378d2b68f4c7e69378615d4af049c371c6a26e82391089"
            b" ready\n\n"))
        self.assertEqual(s._send_this(), r._expect_this())

        self.assertEqual(r._send_this(), (
            b"transit receiver "
            b"ed447528194bac4c00d0c854b12a97ce51413d89aa74d6304475f516fdc23a1b"
            b" ready\n\n"))
        self.assertEqual(r._send_this(), s._expect_this())

        # sender->receiver and receiver->sender record keys pair up
        self.assertEqual(
            hexlify(s._sender_record_key()),
            b"5a2fba3a9e524ab2e2823ff53b05f946896f6e4ce4e282ffd8e3ac0e5e9e0cda"
        )
        self.assertEqual(
            hexlify(s._sender_record_key()), hexlify(r._receiver_record_key()))

        self.assertEqual(
            hexlify(r._sender_record_key()),
            b"eedb143117249f45b39da324decf6bd9aae33b7ccd58487436de611a3c6b871d"
        )
        self.assertEqual(
            hexlify(r._sender_record_key()), hexlify(s._receiver_record_key()))

    def test_connection_ready(self):
        # the sender picks the first ready connection and rejects the rest
        s = transit.TransitSender("")
        self.assertEqual(s.connection_ready("p1"), "go")
        self.assertEqual(s._winner, "p1")
        self.assertEqual(s.connection_ready("p2"), "nevermind")
        self.assertEqual(s._winner, "p1")

        # the receiver always defers the decision to the sender
        r = transit.TransitReceiver("")
        self.assertEqual(r.connection_ready("p1"), "wait-for-decision")
        self.assertEqual(r.connection_ready("p2"), "wait-for-decision")
2016-02-15 01:57:09 +00:00
class Listener(unittest.TestCase):
    """Tests for building and caching the direct-connection listener."""

    def test_listener(self):
        c = transit.Common("")
        hints, ep = c._build_listener()
        self.assertIsInstance(hints, (list, set))
        if hints:
            # next(iter(...)) works for both accepted container types;
            # indexing would fail if _build_listener returned a set
            self.assertIsInstance(next(iter(hints)), DirectTCPV1Hint)
        self.assertIsInstance(ep, endpoints.TCP4ServerEndpoint)

    def test_get_direct_hints(self):
        # this actually starts the listener; stop it even if we fail below
        c = transit.TransitSender("")
        d = c.get_connection_hints()
        self.addCleanup(c._stop_listening)
        hints = self.successResultOf(d)

        # the hints are supposed to be cached, so calling this twice won't
        # start a second listener
        self.assertTrue(c._listener)
        d2 = c.get_connection_hints()
        self.assertEqual(self.successResultOf(d2), hints)
class DummyProtocol(protocol.Protocol):
    """Buffering protocol that lets a test wait for N bytes or a disconnect."""

    def __init__(self):
        self.buf = b""  # bytes received but not yet handed to a waiter
        self._count = None  # byte count wanted by a pending wait_for()
        self._d = None  # Deferred returned by the pending wait_for()
        self._d2 = None  # Deferred returned by wait_for_disconnect()

    def wait_for(self, count):
        """Return a Deferred that fires with the next `count` bytes.

        If enough bytes are already buffered it fires immediately; otherwise
        dataReceived() fires it once enough data has arrived.
        """
        if len(self.buf) >= count:
            data = self.buf[:count]
            self.buf = self.buf[count:]
            return defer.succeed(data)
        self._d = defer.Deferred()
        self._count = count
        return self._d

    def dataReceived(self, data):
        self.buf += data
        if self._count is not None and len(self.buf) >= self._count:
            got = self.buf[:self._count]
            self.buf = self.buf[self._count:]
            self._count = None
            self._d.callback(got)

    def wait_for_disconnect(self):
        """Return a Deferred that fires (with None) on connectionLost."""
        self._d2 = defer.Deferred()
        return self._d2

    def connectionLost(self, reason):
        if self._d2:
            self._d2.callback(None)
2018-04-21 07:30:08 +00:00
2016-02-15 01:57:09 +00:00
class FakeTransport:
    """In-memory transport that records writes instead of sending them."""

    # when true, loseConnection() notifies the protocol via connectionLost()
    signalConnectionLost = True

    def __init__(self, p, peeraddr):
        self.protocol = p
        self._peeraddr = peeraddr
        self._buf = b""  # everything written since the last read_buf()
        self._connected = True

    def write(self, data):
        self._buf = self._buf + data

    def loseConnection(self):
        self._connected = False
        if not self.signalConnectionLost:
            return
        self.protocol.connectionLost()

    def getPeer(self):
        return self._peeraddr

    def read_buf(self):
        """Return (and forget) everything written since the last call."""
        written, self._buf = self._buf, b""
        return written
2018-04-21 07:30:08 +00:00
2016-02-15 01:57:09 +00:00
class RandomError(Exception):
    """Arbitrary exception used to simulate an unexpected internal failure."""
2018-04-21 07:30:08 +00:00
2016-02-15 01:57:09 +00:00
class MockConnection:
    """Stand-in for transit.Connection that records negotiation calls."""

    def __init__(self, owner, relay_handshake, start, description):
        self.owner = owner
        self.relay_handshake = relay_handshake
        self.start = start
        self._description = description
        self._start_negotiation_called = False
        self._cancelled = False
        # tests fire this Deferred by hand; cancelling it only sets a flag
        self._d = defer.Deferred(self._on_cancel)

    def _on_cancel(self, d):
        self._cancelled = True

    def startNegotiation(self):
        self._start_negotiation_called = True
        return self._d
2018-04-21 07:30:08 +00:00
2016-02-15 01:57:09 +00:00
class InboundConnectionFactory(unittest.TestCase):
    """Tests for transit.InboundConnectionFactory."""

    def _make_factory(self):
        """Return (factory, whenDone Deferred); the factory builds mocks."""
        factory = transit.InboundConnectionFactory("owner")
        factory.protocol = MockConnection
        done = factory.whenDone()
        self.assertNoResult(done)
        return factory, done

    def _connect(self, factory, *ports):
        """Build one protocol per port, then report each as connected."""
        protos = [
            factory.buildProtocol(address.HostnameAddress("example.com", port))
            for port in ports
        ]
        for proto in protos:
            factory.connectionWasMade(proto)
        return protos

    def test_describe(self):
        factory = transit.InboundConnectionFactory(None)
        expectations = [
            (address.HostnameAddress("example.com", 1234),
             "<-example.com:1234"),
            (address.IPv4Address("TCP", "1.2.3.4", 1234), "<-1.2.3.4:1234"),
            (address.IPv6Address("TCP", "::1", 1234), "<-::1:1234"),
            (address.UNIXAddress("/dev/unlikely"),
             "<-UNIXAddress('/dev/unlikely')"),
        ]
        for addr, description in expectations:
            self.assertEqual(factory._describePeer(addr), description)

    def test_success(self):
        factory, done = self._make_factory()
        proto = factory.buildProtocol(
            address.HostnameAddress("example.com", 1234))
        self.assertIsInstance(proto, MockConnection)
        self.assertEqual(proto.owner, "owner")
        self.assertEqual(proto.relay_handshake, None)
        self.assertEqual(proto._start_negotiation_called, False)
        # (proto.start is not checked.) connectionWasMade() is normally
        # called from Connection.connectionMade
        factory.connectionWasMade(proto)
        self.assertEqual(proto._start_negotiation_called, True)
        self.assertNoResult(done)

        self.assertEqual(proto._description, "<-example.com:1234")
        proto._d.callback(proto)
        self.assertEqual(self.successResultOf(done), proto)

    def test_one_fail_one_success(self):
        factory, done = self._make_factory()
        first, second = self._connect(factory, 1234, 5678)
        self.assertNoResult(done)

        # a failed handshake does not end the race
        first._d.errback(transit.BadHandshake("nope"))
        self.assertNoResult(done)

        second._d.callback(second)
        self.assertEqual(self.successResultOf(done), second)

    def test_first_success_wins(self):
        factory, done = self._make_factory()
        first, second = self._connect(factory, 1234, 5678)
        self.assertNoResult(done)

        first._d.callback(first)
        self.assertEqual(self.successResultOf(done), first)
        # the loser gets cancelled, the winner does not
        self.assertEqual(first._cancelled, False)
        self.assertEqual(second._cancelled, True)

    def test_log_other_errors(self):
        factory, done = self._make_factory()
        # if the Connection protocol throws an unexpected error, that should
        # get logged to the Twisted logs (as an Unhandled Error in Deferred)
        # so we can diagnose the bug
        [proto] = self._connect(factory, 1234)
        our_error = RandomError("boom1")
        proto._d.errback(our_error)
        self.assertNoResult(done)

        log.msg("=== note: the next RandomError is expected ===")
        # Make sure the Deferred has gone out of scope, so the UnhandledError
        # happens quickly. We must manually break the gc cycle.
        del proto._d
        gc.collect()  # make PyPy happy
        errors = self.flushLoggedErrors(RandomError)
        self.assertEqual(1, len(errors))
        self.assertEqual(our_error, errors[0].value)
        log.msg("=== note: the preceding RandomError was expected ===")

    def test_cancel(self):
        factory, done = self._make_factory()
        first, second = self._connect(factory, 1234, 5678)
        self.assertNoResult(done)

        # cancelling whenDone() cancels every pending negotiation
        done.cancel()
        self.failureResultOf(done, defer.CancelledError)
        self.assertEqual(first._cancelled, True)
        self.assertEqual(second._cancelled, True)
2018-04-21 07:30:08 +00:00
2016-02-15 01:57:09 +00:00
# XXX check descriptions
2018-04-21 07:30:08 +00:00
2016-02-15 01:57:09 +00:00
class OutboundConnectionFactory(unittest.TestCase):
    """Tests for transit.OutboundConnectionFactory."""

    def test_success(self):
        factory = transit.OutboundConnectionFactory(
            "owner", "relay_handshake", "description")
        factory.protocol = MockConnection
        proto = factory.buildProtocol(
            address.HostnameAddress("example.com", 1234))
        self.assertIsInstance(proto, MockConnection)
        self.assertEqual(proto.owner, "owner")
        self.assertEqual(proto.relay_handshake, "relay_handshake")
        self.assertEqual(proto._start_negotiation_called, False)
        # connectionWasMade() (normally called from
        # Connection.connectionMade) is a no-op for outbound factories: it
        # must not start negotiation
        factory.connectionWasMade(proto)
        self.assertEqual(proto._start_negotiation_called, False)
class MockOwner:
    """Fake transit owner with fixed handshake strings and record keys.

    Tests assign `_state`, which connection_ready() returns verbatim.
    """

    _connection_ready_called = False

    def connection_ready(self, connection):
        self._connection = connection
        self._connection_ready_called = True
        return self._state

    def _send_this(self):
        return b"send_this"

    def _expect_this(self):
        return b"expect_this"

    def _sender_record_key(self):
        return 32 * b"s"

    def _receiver_record_key(self):
        return 32 * b"r"
2016-02-15 01:57:09 +00:00
class MockFactory:
    """Fake factory that remembers which protocol reported a connection."""

    _connectionWasMade_called = False

    def connectionWasMade(self, p):
        self._p = p
        self._connectionWasMade_called = True
2018-04-21 07:30:08 +00:00
2016-02-15 01:57:09 +00:00
class Connection(unittest.TestCase):
# exercise the Connection protocol class
def test_check_and_remove(self):
    """_check_and_remove() consumes a matching prefix, rejects mismatches."""
    conn = transit.Connection(None, None, None, "description")
    expectation = b"expectation"

    def check(initial, matched, leftover):
        conn.buf = initial
        result = conn._check_and_remove(expectation)
        if matched:
            self.assertTrue(result)
        else:
            self.assertFalse(result)
        self.assertEqual(conn.buf, leftover)

    # nothing received yet: keep waiting
    check(b"", False, b"")

    # wrong bytes: rejected, and the buffer is left alone
    conn.buf = b"unexpected"
    err = self.assertRaises(transit.BadHandshake, conn._check_and_remove,
                            expectation)
    self.assertEqual(
        str(err), "got %r want %r" % (b'unexpected', b'expectation'))
    self.assertEqual(conn.buf, b"unexpected")

    check(b"expect", False, b"expect")  # partial prefix: keep waiting
    check(b"expectation", True, b"")  # exact match is consumed
    check(b"expectation exceeded", True, b" exceeded")  # extra bytes stay
def test_describe(self):
    """describe() returns the description given to the constructor."""
    description = "description"
    conn = transit.Connection(None, None, None, description)
    self.assertEqual(conn.describe(), description)
2016-02-15 01:57:09 +00:00
def test_sender_accepting(self):
    """Sender side: after the peer's handshake the owner says "go"."""
    owner, factory = MockOwner(), MockFactory()
    peer = address.HostnameAddress("example.com", 1234)
    conn = transit.Connection(owner, None, None, "description")
    self.assertEqual(conn.state, "too-early")

    transport = FakeTransport(conn, peer)
    conn.transport = transport
    conn.factory = factory
    conn.connectionMade()
    self.assertEqual(factory._connectionWasMade_called, True)
    self.assertEqual(factory._p, conn)

    owner._state = "go"
    negotiated = conn.startNegotiation()
    self.assertEqual(conn.state, "handshake")
    self.assertEqual(transport.read_buf(), b"send_this")
    self.assertNoResult(negotiated)

    conn.dataReceived(b"expect_this")
    self.assertEqual(transport.read_buf(), b"go\n")
    self.assertEqual(transport._connected, True)
    self.assertEqual(conn.state, "records")
    self.assertEqual(self.successResultOf(negotiated), conn)

    conn.close()
    self.assertEqual(transport._connected, False)
def test_sender_rejecting(self):
    """Sender side: owner answers "nevermind", so the connection hangs up."""
    owner, factory = MockOwner(), MockFactory()
    peer = address.HostnameAddress("example.com", 1234)
    conn = transit.Connection(owner, None, None, "description")
    self.assertEqual(conn.state, "too-early")

    transport = FakeTransport(conn, peer)
    conn.transport = transport
    conn.factory = factory
    conn.connectionMade()
    self.assertEqual(factory._connectionWasMade_called, True)
    self.assertEqual(factory._p, conn)

    owner._state = "nevermind"
    negotiated = conn.startNegotiation()
    self.assertEqual(conn.state, "handshake")
    self.assertEqual(transport.read_buf(), b"send_this")
    self.assertNoResult(negotiated)

    conn.dataReceived(b"expect_this")
    self.assertEqual(transport.read_buf(), b"nevermind\n")
    self.assertEqual(transport._connected, False)
    self.assertEqual(conn.state, "hung up")
    failure = self.failureResultOf(negotiated, transit.BadHandshake)
    self.assertEqual(str(failure.value), "abandoned")
def test_handshake_other_error(self):
    """An unexpected exception during the handshake hangs up and fails."""
    owner, factory = MockOwner(), MockFactory()
    peer = address.HostnameAddress("example.com", 1234)
    conn = transit.Connection(owner, None, None, "description")
    self.assertEqual(conn.state, "too-early")

    transport = FakeTransport(conn, peer)
    conn.transport = transport
    conn.factory = factory
    conn.connectionMade()
    self.assertEqual(factory._connectionWasMade_called, True)
    self.assertEqual(factory._p, conn)

    negotiated = conn.startNegotiation()
    self.assertEqual(conn.state, "handshake")
    self.assertEqual(transport.read_buf(), b"send_this")
    self.assertNoResult(negotiated)

    # plant an exception as the state; the next dataReceived() is expected
    # to raise it, hang up, and fail the negotiation with it
    conn.state = RandomError("boom2")
    self.assertRaises(RandomError, conn.dataReceived, b"surprise!")
    self.assertEqual(transport._connected, False)
    self.assertEqual(conn.state, "hung up")
    self.failureResultOf(negotiated, RandomError)
2016-02-15 01:57:09 +00:00
def test_handshake_bad_state(self):
    """An unknown state name makes dataReceived() raise ValueError."""
    owner, factory = MockOwner(), MockFactory()
    peer = address.HostnameAddress("example.com", 1234)
    conn = transit.Connection(owner, None, None, "description")
    self.assertEqual(conn.state, "too-early")

    transport = FakeTransport(conn, peer)
    conn.transport = transport
    conn.factory = factory
    conn.connectionMade()
    self.assertEqual(factory._connectionWasMade_called, True)
    self.assertEqual(factory._p, conn)

    negotiated = conn.startNegotiation()
    self.assertEqual(conn.state, "handshake")
    self.assertEqual(transport.read_buf(), b"send_this")
    self.assertNoResult(negotiated)

    conn.state = "unknown-bogus-state"
    self.assertRaises(ValueError, conn.dataReceived, b"surprise!")
    self.assertEqual(transport._connected, False)
    self.assertEqual(conn.state, "hung up")
    self.failureResultOf(negotiated, ValueError)
2016-02-15 01:57:09 +00:00
def test_relay_handshake(self):
    """With a relay handshake, the relay's "ok" comes before our greeting."""
    relay_handshake = b"relay handshake"
    owner, factory = MockOwner(), MockFactory()
    peer = address.HostnameAddress("example.com", 1234)
    conn = transit.Connection(owner, relay_handshake, None, "description")
    self.assertEqual(conn.state, "too-early")

    transport = FakeTransport(conn, peer)
    conn.transport = transport
    conn.factory = factory
    conn.connectionMade()
    self.assertEqual(factory._connectionWasMade_called, True)
    self.assertEqual(factory._p, conn)
    self.assertEqual(transport.read_buf(), b"")  # quiet until startNegotiation

    owner._state = "go"
    negotiated = conn.startNegotiation()
    self.assertEqual(transport.read_buf(), relay_handshake)
    self.assertEqual(conn.state, "relay")  # waiting for OK from relay

    conn.dataReceived(b"ok\n")
    self.assertEqual(transport.read_buf(), b"send_this")
    self.assertEqual(conn.state, "handshake")
    self.assertNoResult(negotiated)

    conn.dataReceived(b"expect_this")
    self.assertEqual(conn.state, "records")
    self.assertEqual(self.successResultOf(negotiated), conn)
    self.assertEqual(transport.read_buf(), b"go\n")
def test_relay_handshake_bad(self):
    """Anything other than "ok\\n" from the relay aborts negotiation."""
    relay_handshake = b"relay handshake"
    owner, factory = MockOwner(), MockFactory()
    peer = address.HostnameAddress("example.com", 1234)
    conn = transit.Connection(owner, relay_handshake, None, "description")
    self.assertEqual(conn.state, "too-early")

    transport = FakeTransport(conn, peer)
    conn.transport = transport
    conn.factory = factory
    conn.connectionMade()
    self.assertEqual(factory._connectionWasMade_called, True)
    self.assertEqual(factory._p, conn)
    self.assertEqual(transport.read_buf(), b"")  # quiet until startNegotiation

    owner._state = "go"
    negotiated = conn.startNegotiation()
    self.assertEqual(transport.read_buf(), relay_handshake)
    self.assertEqual(conn.state, "relay")  # waiting for OK from relay

    conn.dataReceived(b"not ok\n")
    self.assertEqual(transport._connected, False)
    self.assertEqual(conn.state, "hung up")
    failure = self.failureResultOf(negotiated, transit.BadHandshake)
    self.assertEqual(
        str(failure.value), "got %r want %r" % (b"not ok\n", b"ok\n"))
2016-02-15 01:57:09 +00:00
def test_receiver_accepted(self):
    """Receiver side: wait for the sender's decision, which is "go"."""
    owner, factory = MockOwner(), MockFactory()
    peer = address.HostnameAddress("example.com", 1234)
    conn = transit.Connection(owner, None, None, "description")
    self.assertEqual(conn.state, "too-early")

    transport = FakeTransport(conn, peer)
    conn.transport = transport
    conn.factory = factory
    conn.connectionMade()
    self.assertEqual(factory._connectionWasMade_called, True)
    self.assertEqual(factory._p, conn)

    owner._state = "wait-for-decision"
    negotiated = conn.startNegotiation()
    self.assertEqual(conn.state, "handshake")
    self.assertEqual(transport.read_buf(), b"send_this")
    self.assertNoResult(negotiated)

    conn.dataReceived(b"expect_this")
    self.assertEqual(conn.state, "wait-for-decision")
    self.assertNoResult(negotiated)

    conn.dataReceived(b"go\n")
    self.assertEqual(conn.state, "records")
    self.assertEqual(self.successResultOf(negotiated), conn)
2016-02-15 01:57:09 +00:00
def test_receiver_rejected_politely(self):
    """Receiver side: the sender says "nevermind", so we hang up."""
    owner, factory = MockOwner(), MockFactory()
    peer = address.HostnameAddress("example.com", 1234)
    conn = transit.Connection(owner, None, None, "description")
    self.assertEqual(conn.state, "too-early")

    transport = FakeTransport(conn, peer)
    conn.transport = transport
    conn.factory = factory
    conn.connectionMade()
    self.assertEqual(factory._connectionWasMade_called, True)
    self.assertEqual(factory._p, conn)

    owner._state = "wait-for-decision"
    negotiated = conn.startNegotiation()
    self.assertEqual(conn.state, "handshake")
    self.assertEqual(transport.read_buf(), b"send_this")
    self.assertNoResult(negotiated)

    conn.dataReceived(b"expect_this")
    self.assertEqual(conn.state, "wait-for-decision")
    self.assertNoResult(negotiated)

    conn.dataReceived(b"nevermind\n")  # polite rejection
    self.assertEqual(transport._connected, False)
    self.assertEqual(conn.state, "hung up")
    failure = self.failureResultOf(negotiated, transit.BadHandshake)
    self.assertEqual(
        str(failure.value), "got %r want %r" % (b"nevermind\n", b"go\n"))
2016-02-15 01:57:09 +00:00
def test_receiver_rejected_rudely(self):
    """Receiver side: the sender just drops the connection."""
    owner, factory = MockOwner(), MockFactory()
    peer = address.HostnameAddress("example.com", 1234)
    conn = transit.Connection(owner, None, None, "description")
    self.assertEqual(conn.state, "too-early")

    transport = FakeTransport(conn, peer)
    conn.transport = transport
    conn.factory = factory
    conn.connectionMade()
    self.assertEqual(factory._connectionWasMade_called, True)
    self.assertEqual(factory._p, conn)

    owner._state = "wait-for-decision"
    negotiated = conn.startNegotiation()
    self.assertEqual(conn.state, "handshake")
    self.assertEqual(transport.read_buf(), b"send_this")
    self.assertNoResult(negotiated)

    conn.dataReceived(b"expect_this")
    self.assertEqual(conn.state, "wait-for-decision")
    self.assertNoResult(negotiated)

    # rude rejection: the connection goes away without a decision
    transport.loseConnection()
    self.assertEqual(transport._connected, False)
    failure = self.failureResultOf(negotiated, transit.BadHandshake)
    self.assertEqual(str(failure.value), "connection lost")
def test_cancel(self):
    """Cancelling negotiation hangs up the connection."""
    owner, factory = MockOwner(), MockFactory()
    peer = address.HostnameAddress("example.com", 1234)
    conn = transit.Connection(owner, None, None, "description")
    self.assertEqual(conn.state, "too-early")

    transport = FakeTransport(conn, peer)
    conn.transport = transport
    conn.factory = factory
    conn.connectionMade()
    negotiated = conn.startNegotiation()

    # while we're waiting for negotiation, we get cancelled
    negotiated.cancel()
    self.assertEqual(transport._connected, False)
    self.assertEqual(conn.state, "hung up")
    self.failureResultOf(negotiated, defer.CancelledError)
2016-02-15 01:57:09 +00:00
def test_timeout(self):
    """If negotiation does not finish in time, it fails with "timeout"."""
    clock = task.Clock()
    owner, factory = MockOwner(), MockFactory()
    peer = address.HostnameAddress("example.com", 1234)
    conn = transit.Connection(owner, None, None, "description")

    def _callLater(period, func, *args, **kwargs):
        # Route the connection's timers through the fake clock. Return the
        # DelayedCall, as IReactorTime.callLater does, so code that keeps
        # the handle (to cancel or reset the timer) still works.
        return clock.callLater(period, func, *args, **kwargs)

    conn.callLater = _callLater
    self.assertEqual(conn.state, "too-early")
    transport = FakeTransport(conn, peer)
    conn.transport = transport
    conn.factory = factory
    conn.connectionMade()
    # the timer should now be running
    negotiated = conn.startNegotiation()

    # while we're waiting for negotiation, the timer expires
    clock.advance(transit.TIMEOUT + 1.0)
    self.assertEqual(transport._connected, False)
    failure = self.failureResultOf(negotiated, transit.BadHandshake)
    self.assertEqual(str(failure.value), "timeout")
def make_connection(self):
    """Return (transport, connection, owner) for a sender past the handshake.

    The owner answers "go", the negotiation Deferred is checked to have
    fired with the connection, and the transport's write buffer is drained
    so later reads see only encrypted records.
    """
    owner, factory = MockOwner(), MockFactory()
    peer = address.HostnameAddress("example.com", 1234)
    conn = transit.Connection(owner, None, None, "description")
    transport = FakeTransport(conn, peer)
    conn.transport = transport
    conn.factory = factory
    conn.connectionMade()
    owner._state = "go"
    negotiated = conn.startNegotiation()
    conn.dataReceived(b"expect_this")
    self.assertEqual(self.successResultOf(negotiated), conn)
    transport.read_buf()  # discard handshake output (the write buffer)
    return transport, conn, owner
def test_records_not_binary(self):
    """send_record() refuses a text (non-bytes) record with InternalError."""
    _transport, conn, _owner = self.make_connection()
    self.assertRaises(InternalError, conn.send_record, u"not binary")
2016-02-15 01:57:09 +00:00
def test_records_good(self):
    """Records survive the encrypted framing in both directions.

    Outbound: every send_record() writes a 4-byte length header and then a
    SecretBox message whose prepended nonce counts up from 0. Inbound:
    frames split across several dataReceived() calls, or packed together
    in one call, are delivered to recordReceived() in order.
    """
    transport, conn, owner = self.make_connection()

    def check_sent(record, expected_nonce):
        # Send one record and verify the frame written to the transport.
        conn.send_record(record)
        frame = transport.read_buf()
        # length header covers 24-byte nonce + payload + 16-byte MAC
        want_len = ("%08x" % (24 + len(record) + 16)).encode("ascii")
        self.assertEqual(hexlify(frame[:4]), want_len)
        sealed = frame[4:]
        opener = SecretBox(owner._sender_record_key())
        # assume the nonce is prepended to the ciphertext
        nonce = int(hexlify(sealed[:SecretBox.NONCE_SIZE]), 16)
        self.assertEqual(nonce, expected_nonce)
        self.assertEqual(opener.decrypt(sealed), record)

    check_sent(b"record", 0)  # first message gets nonce 0
    check_sent(b"record2", 1)  # second message gets nonce 1

    # and that we can receive records properly
    received = []
    conn.recordReceived = received.append
    sealer = SecretBox(owner._receiver_record_key())

    def frame_for(record, nonce):
        # Seal `record` under an explicit nonce and prefix its length.
        sealed = sealer.encrypt(record, unhexlify("%048x" % nonce))
        return unhexlify("%08x" % len(sealed)) + sealed  # 4-byte length

    # Inbound nonces start at 0 and increment. Feed each frame in pieces
    # and check that nothing is delivered until the final piece arrives.
    for nonce, record in enumerate([b"record3", b"record4"]):
        frame = frame_for(record, nonce)
        before = list(received)
        conn.dataReceived(frame[:2])
        conn.dataReceived(frame[2:4])
        conn.dataReceived(frame[4:-2])
        self.assertEqual(received, before)
        conn.dataReceived(frame[-2:])
        self.assertEqual(received, before + [record])

    # receiving two records at the same time: deliver both
    del received[:]
    both = frame_for(b"record5", 2) + frame_for(b"record6", 3)
    conn.dataReceived(both)
    self.assertEqual(received, [b"record5", b"record6"])
2016-02-15 01:57:09 +00:00
def corrupt(self, orig):
    """Return *orig* (non-empty bytes) with its final byte altered.

    The last byte is XORed with 0xff, which always yields a different
    value that is still within 0..255. The previous ``256 - num`` left
    0x80 unchanged, and for 0x00 it produced 256, which formats as three
    hex digits and made ``unhexlify`` raise.
    """
    last_byte = orig[-1:]
    num = int(hexlify(last_byte).decode("ascii"), 16)
    corrupt_num = num ^ 0xff
    as_byte = unhexlify("%02x" % corrupt_num)
    return orig[:-1] + as_byte
def test_records_corrupt(self):
    """A tampered inbound record raises CryptoError and drops the link.

    Nothing is delivered to recordReceived(), and the transport ends up
    disconnected.
    """
    transport, conn, owner = self.make_connection()
    received = []
    conn.recordReceived = received.append
    box = SecretBox(owner._receiver_record_key())
    zero_nonce = unhexlify("%048x" % 0)  # first nonce must be 0
    sealed = self.corrupt(box.encrypt(b"record", zero_nonce))
    header = unhexlify("%08x" % len(sealed))  # always 4 bytes long

    conn.dataReceived(header)
    conn.dataReceived(sealed[:-2])
    self.assertEqual(received, [])
    self.assertRaises(CryptoError, conn.dataReceived, sealed[-2:])
    self.assertEqual(received, [])
    # and the connection should have been dropped
    self.assertEqual(transport._connected, False)
def test_out_of_order_nonce(self):
    """An inbound record carrying an unexpected nonce is rejected.

    The receiver expects nonce 0 first. A record sealed with nonce 1 must
    raise transit.BadNonce, deliver nothing and drop the connection.
    """
    t, c, owner = self.make_connection()
    inbound_records = []
    c.recordReceived = inbound_records.append
    RECORD = b"record"
    send_box = SecretBox(owner._receiver_record_key())

    # The first inbound nonce must be 0; deliberately send 1 instead.
    nonce_buf = unhexlify("%048x" % 1)

    encrypted = send_box.encrypt(RECORD, nonce_buf)
    length = unhexlify("%08x" % len(encrypted))  # always 4 bytes long

    c.dataReceived(length)
    c.dataReceived(encrypted[:-2])
    self.assertEqual(inbound_records, [])
    self.assertRaises(transit.BadNonce, c.dataReceived, encrypted[-2:])
    self.assertEqual(inbound_records, [])
    # and the connection should have been dropped
    self.assertEqual(t._connected, False)
# TODO: check that .connectionLost/loseConnection signatures are
# consistent: zero args, or one arg?
# XXX: if we don't set the transit key before connecting, what
# happens? We currently get a type-check assertion from HKDF because
# the key is None.
def test_receive_queue(self):
    """receive_record() hands out records strictly in order of arrival.

    Records queued before any request are returned first. Requests made
    on an empty queue wait for the next records. A request still pending
    at close() errbacks with ConnectionClosed.
    """
    conn = transit.Connection(None, None, None, "description")
    conn.transport = FakeTransport(conn, None)
    conn.transport.signalConnectionLost = False
    for chunk in (b"0", b"1", b"2"):
        conn.recordReceived(chunk)

    first = conn.receive_record()
    self.assertEqual(self.successResultOf(first), b"0")
    second = conn.receive_record()
    third = conn.receive_record()
    # they must fire in order of receipt, not order of addCallback
    self.assertEqual(self.successResultOf(third), b"2")
    self.assertEqual(self.successResultOf(second), b"1")

    # the queue is now empty, so new requests must wait
    fourth = conn.receive_record()
    fifth = conn.receive_record()
    self.assertNoResult(fourth)
    self.assertNoResult(fifth)
    conn.recordReceived(b"3")
    self.assertEqual(self.successResultOf(fourth), b"3")
    self.assertNoResult(fifth)
    conn.recordReceived(b"4")
    self.assertEqual(self.successResultOf(fifth), b"4")

    pending = conn.receive_record()
    conn.close()
    self.failureResultOf(pending, error.ConnectionClosed)
def test_producer(self):
    """The connection produces inbound records into a local consumer.

    Records queued before connectConsumer() are flushed to the consumer
    immediately. Pause, resume and stop are forwarded to the underlying
    transport.
    """
    conn = transit.Connection(None, None, None, "description")
    conn.transport = proto_helpers.StringTransport()
    conn.recordReceived(b"r1.")
    conn.recordReceived(b"r2.")
    sink = proto_helpers.StringTransport()
    self.assertIs(conn.connectConsumer(sink), None)
    self.assertIs(conn._consumer, sink)
    self.assertEqual(sink.value(), b"r1.r2.")
    # connecting again while a consumer is attached is an error
    self.assertRaises(RuntimeError, conn.connectConsumer, sink)
    conn.recordReceived(b"r3.")
    self.assertEqual(sink.value(), b"r1.r2.r3.")

    conn.pauseProducing()
    self.assertEqual(conn.transport.producerState, "paused")
    conn.resumeProducing()
    self.assertEqual(conn.transport.producerState, "producing")

    conn.disconnectConsumer()
    self.assertEqual(sink.producer, None)
    conn.connectConsumer(sink)
    conn.stopProducing()
    self.assertEqual(conn.transport.producerState, "stopped")
def test_connectConsumer(self):
    """connectConsumer(expected=N) fires once N bytes have been written.

    After firing, the consumer is detached and later records stay queued.
    A pending Deferred errbacks with ConnectionClosed on connectionLost().
    """
    conn = transit.Connection(None, None, None, "description")
    conn._negotiation_d.addErrback(lambda err: None)  # eat it
    conn.transport = proto_helpers.StringTransport()
    conn.recordReceived(b"r1.")
    sink = proto_helpers.StringTransport()
    done = conn.connectConsumer(sink, expected=10)

    written = b"r1."
    self.assertEqual(sink.value(), written)
    self.assertNoResult(done)
    for chunk in (b"r2.", b"r3."):
        conn.recordReceived(chunk)
        written += chunk
        self.assertEqual(sink.value(), written)
        self.assertNoResult(done)

    conn.recordReceived(b"!")
    self.assertEqual(sink.value(), b"r1.r2.r3.!")
    self.assertEqual(self.successResultOf(done), 10)

    # the consumer is detached now, so further records are queued
    self.assertIs(conn._consumer, None)
    conn.recordReceived(b"overflow")
    self.assertEqual(sink.value(), b"r1.r2.r3.!")

    done = conn.connectConsumer(sink, expected=10)
    conn.connectionLost()
    self.failureResultOf(done, error.ConnectionClosed)
def test_connectConsumer_empty(self):
    """connectConsumer(expected=0) fires at once and detaches itself.

    This covers a zero-length "file": no bytes will ever arrive, so the
    Deferred must fire right away with 0 and nothing is written.
    """
    conn = transit.Connection(None, None, None, "description")
    conn._negotiation_d.addErrback(lambda err: None)  # eat it
    conn.transport = proto_helpers.StringTransport()
    sink = proto_helpers.StringTransport()
    done = conn.connectConsumer(sink, expected=0)
    self.assertEqual(self.successResultOf(done), 0)
    self.assertEqual(sink.value(), b"")
    self.assertIs(conn._consumer, None)
def test_writeToFile(self):
    """writeToFile() copies records into a file until the count is met.

    Progress is reported once per record. Whole records are written even
    when they overshoot the requested size, and leftover records stay
    queued. A pending write errbacks when the connection is lost.
    """
    conn = transit.Connection(None, None, None, "description")
    conn._negotiation_d.addErrback(lambda err: None)  # eat it
    conn.transport = proto_helpers.StringTransport()
    conn.recordReceived(b"r1.")
    outfile = io.BytesIO()
    progress = []
    done = conn.writeToFile(outfile, 10, progress.append)

    written = b"r1."
    expected_progress = [3]
    self.assertEqual(outfile.getvalue(), written)
    self.assertEqual(progress, expected_progress)
    self.assertNoResult(done)
    for chunk in (b"r2.", b"r3."):
        conn.recordReceived(chunk)
        written += chunk
        expected_progress.append(len(chunk))
        self.assertEqual(outfile.getvalue(), written)
        self.assertEqual(progress, expected_progress)
        self.assertNoResult(done)

    conn.recordReceived(b"!")
    self.assertEqual(outfile.getvalue(), b"r1.r2.r3.!")
    self.assertEqual(progress, [3, 3, 3, 1])
    self.assertEqual(self.successResultOf(done), 10)

    # the consumer detached itself, so later records are only queued
    self.assertIs(conn._consumer, None)
    conn.recordReceived(b"overflow.")
    self.assertEqual(outfile.getvalue(), b"r1.r2.r3.!")
    self.assertEqual(progress, [3, 3, 3, 1])

    # enough data already queued: satisfied immediately, whole records only
    conn.recordReceived(b"second.")  # now "overflow.second."
    conn.recordReceived(b"third.")  # now "overflow.second.third."
    outfile = io.BytesIO()
    done = conn.writeToFile(outfile, 10)
    self.assertEqual(outfile.getvalue(), b"overflow.second.")  # whole records
    self.assertEqual(self.successResultOf(done), 16)
    self.assertEqual(list(conn._inbound_records), [b"third."])

    done = conn.writeToFile(outfile, 10)
    conn.connectionLost()
    self.failureResultOf(done, error.ConnectionClosed)
def test_consumer(self):
    """The connection acts as a consumer for a local producer.

    write() hands its data to send_record(), and producer registration is
    passed through to the underlying transport.
    """
    conn = transit.Connection(None, None, None, "description")
    conn.transport = proto_helpers.StringTransport()
    sent = []
    conn.send_record = sent.append
    source = proto_helpers.StringTransport()
    conn.registerProducer(source, True)
    self.assertIs(conn.transport.producer, source)
    conn.write(b"r1.")
    self.assertEqual(sent, [b"r1."])
    conn.unregisterProducer()
    self.assertEqual(conn.transport.producer, None)
2018-04-21 07:30:08 +00:00
class FileConsumer(unittest.TestCase):
    """Tests for transit.FileConsumer writing into a file-like object."""

    def test_basic(self):
        # each write() reaches the file and reports its length as progress
        outfile = io.BytesIO()
        progress = []
        consumer = transit.FileConsumer(outfile, progress.append)
        self.assertEqual(progress, [])
        self.assertEqual(outfile.getvalue(), b"")

        consumer.write(b"." * 99)
        self.assertEqual(progress, [99])
        self.assertEqual(outfile.getvalue(), b"." * 99)

        consumer.write(b"!")
        self.assertEqual(progress, [99, 1])
        self.assertEqual(outfile.getvalue(), b"." * 99 + b"!")

    def test_hasher(self):
        # with a hasher, every chunk written is also passed to the hasher
        digested = []
        outfile = io.BytesIO()
        progress = []
        consumer = transit.FileConsumer(outfile, progress.append,
                                        hasher=digested.append)
        self.assertEqual(progress, [])
        self.assertEqual(outfile.getvalue(), b"")
        self.assertEqual(digested, [])

        consumer.write(b"." * 99)
        self.assertEqual(progress, [99])
        self.assertEqual(outfile.getvalue(), b"." * 99)
        self.assertEqual(digested, [b"." * 99])

        consumer.write(b"!")
        self.assertEqual(progress, [99, 1])
        self.assertEqual(outfile.getvalue(), b"." * 99 + b"!")
        self.assertEqual(digested, [b"." * 99, b"!"])
# Connection-hint fixtures (JSON-shaped dicts) fed to add_connection_hints().

# A well-formed direct TCP hint.
DIRECT_HINT_JSON = {"type": "direct-tcp-v1", "hostname": "direct", "port": 1234}

# A relay hint wrapping one well-formed direct TCP hint.
RELAY_HINT_JSON = {
    "type": "relay-v1",
    "hints": [
        {"type": "direct-tcp-v1", "hostname": "relay", "port": 1234},
    ],
}

# Known type, but the hostname is a list and there is no port.
UNRECOGNIZED_DIRECT_HINT_JSON = {
    "type": "direct-tcp-v1",
    "hostname": ["cannot", "parse", "list"],
}

# A hint type that is not understood at all.
UNRECOGNIZED_HINT_JSON = {"type": "unknown"}

# Parseable, but for which no endpoint can be built at runtime.
UNAVAILABLE_HINT_JSON = {
    "type": "direct-tcp-v1",  # e.g. Tor without txtorcon
    "hostname": "unavailable",
    "port": 1234,
}

# A relay hint with one usable direct hint plus one unrecognized hint.
RELAY_HINT2_JSON = {
    "type": "relay-v1",
    "hints": [
        {"type": "direct-tcp-v1", "hostname": "relay", "port": 1234},
        UNRECOGNIZED_HINT_JSON,
    ],
}

# A relay hint whose only sub-hint is unavailable.
UNAVAILABLE_RELAY_HINT_JSON = {
    "type": "relay-v1",
    "hints": [UNAVAILABLE_HINT_JSON],
}
2016-02-15 01:57:09 +00:00
class Transit(unittest.TestCase):
    """How TransitSender.connect() chooses and orders connection attempts.

    The sender's _start_connector is replaced with a recorder that notes
    each endpoint and description and returns a Deferred that the test
    fires by hand, so no real connections are made.
    """

    def setUp(self):
        self._connectors = []
        self._waiters = []
        self._descriptions = []

    def _start_connector(self, ep, description, is_relay=False):
        # Record the attempt instead of connecting; the test decides when
        # (and whether) the returned Deferred fires.
        pending = defer.Deferred()
        self._connectors.append(ep)
        self._waiters.append(pending)
        self._descriptions.append(description)
        return pending

    @inlineCallbacks
    def test_success_direct(self):
        clock = task.Clock()
        sender = transit.TransitSender("", reactor=clock)
        sender.set_transit_key(b"key")
        yield sender.get_connection_hints()  # start the listener
        sender.add_connection_hints([
            DIRECT_HINT_JSON,
            UNRECOGNIZED_DIRECT_HINT_JSON,
            UNRECOGNIZED_HINT_JSON,
        ])
        sender._start_connector = self._start_connector

        connection_d = sender.connect()
        self.assertNoResult(connection_d)
        # only one attempt: the two unrecognized hints are skipped
        self.assertEqual(len(self._waiters), 1)
        self.assertIsInstance(self._waiters[0], defer.Deferred)

        self._waiters[0].callback("winner")
        self.assertEqual(self.successResultOf(connection_d), "winner")
        self.assertEqual(self._descriptions, ["->tcp:direct:1234"])

    @inlineCallbacks
    def test_success_direct_tor(self):
        clock = task.Clock()
        sender = transit.TransitSender("", tor=mock.Mock(), reactor=clock)
        sender.set_transit_key(b"key")
        yield sender.get_connection_hints()  # start the listener
        sender.add_connection_hints([DIRECT_HINT_JSON])
        sender._start_connector = self._start_connector

        connection_d = sender.connect()
        self.assertNoResult(connection_d)
        self.assertEqual(len(self._waiters), 1)
        self.assertIsInstance(self._waiters[0], defer.Deferred)

        self._waiters[0].callback("winner")
        self.assertEqual(self.successResultOf(connection_d), "winner")
        # with tor configured the description carries a "tor->" prefix
        self.assertEqual(self._descriptions, ["tor->tcp:direct:1234"])

    @inlineCallbacks
    def test_success_direct_tor_relay(self):
        clock = task.Clock()
        sender = transit.TransitSender("", tor=mock.Mock(), reactor=clock)
        sender.set_transit_key(b"key")
        yield sender.get_connection_hints()  # start the listener
        sender.add_connection_hints([RELAY_HINT_JSON])
        sender._start_connector = self._start_connector

        connection_d = sender.connect()
        # Relay connections are triggered starting at T+0.0, so moving the
        # clock forward by any amount releases them.
        clock.advance(1.0)
        self.assertNoResult(connection_d)
        self.assertEqual(len(self._waiters), 1)
        self.assertIsInstance(self._waiters[0], defer.Deferred)

        self._waiters[0].callback("winner")
        self.assertEqual(self.successResultOf(connection_d), "winner")
        self.assertEqual(self._descriptions, ["tor->relay:tcp:relay:1234"])

    def _endpoint_from_hint_obj(self, hint, _tor, _reactor):
        # Stand-in for wormhole.transit.endpoint_from_hint_obj. A direct
        # hint becomes its bare hostname, so tests can compare strings.
        # Hostname "unavailable" and every other hint type yield None.
        if not isinstance(hint, DirectTCPV1Hint):
            return None
        if hint.hostname == "unavailable":
            return None
        return hint.hostname

    @inlineCallbacks
    def test_wait_for_relay(self):
        clock = task.Clock()
        sender = transit.TransitSender("", reactor=clock, no_listen=True)
        sender.set_transit_key(b"key")
        yield sender.get_connection_hints()
        sender.add_connection_hints(
            [DIRECT_HINT_JSON, UNRECOGNIZED_HINT_JSON, RELAY_HINT_JSON])
        sender._start_connector = self._start_connector

        with mock.patch("wormhole.transit.endpoint_from_hint_obj",
                        self._endpoint_from_hint_obj):
            connection_d = sender.connect()
            self.assertNoResult(connection_d)
            # direct connectors are tried right away; relay connectors
            # are held back for RELAY_DELAY
            self.assertEqual(self._connectors, ["direct"])

            clock.advance(sender.RELAY_DELAY + 1.0)
            self.assertEqual(self._connectors, ["direct", "relay"])

            self._waiters[0].callback("winner")
            self.assertEqual(self.successResultOf(connection_d), "winner")

    @inlineCallbacks
    def test_priorities(self):
        clock = task.Clock()
        sender = transit.TransitSender("", reactor=clock, no_listen=True)
        sender.set_transit_key(b"key")
        yield sender.get_connection_hints()
        sender.add_connection_hints([
            {"type": "relay-v1",
             "hints": [{"type": "direct-tcp-v1",
                        "hostname": "relay", "port": 1234}]},
            {"type": "direct-tcp-v1", "hostname": "direct", "port": 1234},
            {"type": "relay-v1",
             "hints": [{"type": "direct-tcp-v1", "priority": 2.0,
                        "hostname": "relay2", "port": 1234},
                       {"type": "direct-tcp-v1", "priority": 3.0,
                        "hostname": "relay3", "port": 1234}]},
            {"type": "relay-v1",
             "hints": [{"type": "direct-tcp-v1", "priority": 2.0,
                        "hostname": "relay4", "port": 1234}]},
        ])
        sender._start_connector = self._start_connector

        with mock.patch("wormhole.transit.endpoint_from_hint_obj",
                        self._endpoint_from_hint_obj):
            connection_d = sender.connect()
            self.assertNoResult(connection_d)
            # Order: the direct connector first, then the priority=3.0
            # relay, then the two 2.0 relays (either order), then the
            # (default) 0.0 relay.
            self.assertEqual(self._connectors, ["direct"])
            clock.advance(sender.RELAY_DELAY + 1.0)
            self.assertEqual(self._connectors, ["direct", "relay3"])
            clock.advance(sender.RELAY_DELAY)
            self.assertIn(self._connectors,
                          (["direct", "relay3", "relay2", "relay4"],
                           ["direct", "relay3", "relay4", "relay2"]))
            clock.advance(sender.RELAY_DELAY)
            self.assertIn(self._connectors,
                          (["direct", "relay3", "relay2", "relay4", "relay"],
                           ["direct", "relay3", "relay4", "relay2", "relay"]))
            self._waiters[0].callback("winner")
            self.assertEqual(self.successResultOf(connection_d), "winner")

    @inlineCallbacks
    def test_no_direct_hints(self):
        clock = task.Clock()
        sender = transit.TransitSender("", reactor=clock, no_listen=True)
        sender.set_transit_key(b"key")
        yield sender.get_connection_hints()  # start the listener
        # include hints that can't be turned into an endpoint at runtime
        sender.add_connection_hints([
            UNRECOGNIZED_HINT_JSON,
            UNAVAILABLE_HINT_JSON,
            RELAY_HINT2_JSON,
            UNAVAILABLE_RELAY_HINT_JSON,
        ])
        sender._start_connector = self._start_connector

        with mock.patch("wormhole.transit.endpoint_from_hint_obj",
                        self._endpoint_from_hint_obj):
            connection_d = sender.connect()
            self.assertNoResult(connection_d)
            # with no usable direct hints, the relay connector is only
            # delayed by 0 seconds
            self.assertEqual(self._connectors, [])

            clock.advance(0)
            self.assertEqual(self._connectors, ["relay"])

            self._waiters[0].callback("winner")
            self.assertEqual(self.successResultOf(connection_d), "winner")

    @inlineCallbacks
    def test_no_contenders(self):
        clock = task.Clock()
        sender = transit.TransitSender("", reactor=clock, no_listen=True)
        sender.set_transit_key(b"key")
        yield sender.get_connection_hints()  # start the listener
        sender.add_connection_hints([])  # no hints at all
        sender._start_connector = self._start_connector

        with mock.patch("wormhole.transit.endpoint_from_hint_obj",
                        self._endpoint_from_hint_obj):
            connection_d = sender.connect()
            failure = self.failureResultOf(connection_d, transit.TransitError)
            self.assertEqual(str(failure.value),
                             "No contenders for connection")
2016-12-24 02:31:19 +00:00
2018-04-21 07:30:08 +00:00
2016-12-22 23:17:05 +00:00
class RelayHandshake(unittest.TestCase):
    """The transit relay server accepts both relay-handshake formats."""

    def old_build_relay_handshake(self, key):
        """Build the old-style relay handshake for *key*.

        Returns ``(token, handshake)``. The token is derived from *key* by
        HKDF (32 bytes, context b"transit_relay_token"). The handshake is
        b"please relay " plus the hex-encoded token and a trailing newline.
        """
        token = HKDF(key, 32, CTXinfo=b"transit_relay_token")
        return (token, b"please relay " + hexlify(token) + b"\n")

    def _relay_connection(self):
        # A server-side TransitConnection whose factory is a mock, so the
        # connection_got_token() calls it makes can be inspected.
        tc = transit_server.TransitConnection()
        tc.factory = mock.Mock()
        tc.factory.connection_got_token = mock.Mock()
        return tc

    def test_old(self):
        key = b"\x00"
        token, old_handshake = self.old_build_relay_handshake(key)
        tc = self._relay_connection()
        # nothing is reported until the trailing newline arrives
        tc.dataReceived(old_handshake[:-1])
        self.assertEqual(tc.factory.connection_got_token.mock_calls, [])
        tc.dataReceived(old_handshake[-1:])
        # the old format has no side field, so None is reported for it
        self.assertEqual(tc.factory.connection_got_token.mock_calls,
                         [mock.call(hexlify(token), None, tc)])

    def test_new(self):
        key = b"\x00"
        c = transit.Common(None)
        c.set_transit_key(key)
        new_handshake = c._build_relay_handshake()
        # the new handshake must carry the same token as the old format
        token, _ = self.old_build_relay_handshake(key)
        tc = self._relay_connection()
        tc.dataReceived(new_handshake[:-1])
        self.assertEqual(tc.factory.connection_got_token.mock_calls, [])
        tc.dataReceived(new_handshake[-1:])
        # the new format also reports the sender's side
        self.assertEqual(
            tc.factory.connection_got_token.mock_calls,
            [mock.call(hexlify(token), c._side.encode("ascii"), tc)])
2016-12-22 23:17:05 +00:00
2016-02-15 01:57:09 +00:00
2016-12-23 04:47:54 +00:00
class Full(ServerBase, unittest.TestCase):
2016-02-15 01:57:09 +00:00
def doBoth(self, d1, d2):
    """Return a Deferred that fires with [result1, result2] once both fire.

    consumeErrors=True consumes failures in the inputs, so they are not
    also logged as unhandled. The first failure is reported through the
    returned Deferred (as a FirstError).
    """
    return gatherResults([d1, d2], consumeErrors=True)
@inlineCallbacks
def test_direct(self):
    """Sender and receiver connect directly and exchange one record.

    No relay is passed (None), so each peer must connect using the hints
    from the other peer's listener.
    """
    KEY = b"k" * 32
    sender = transit.TransitSender(None)
    receiver = transit.TransitReceiver(None)
    sender.set_transit_key(KEY)
    receiver.set_transit_key(KEY)
    shints = yield sender.get_connection_hints()
    rhints = yield receiver.get_connection_hints()
    sender.add_connection_hints(rhints)
    receiver.add_connection_hints(shints)

    (x, y) = yield self.doBoth(sender.connect(), receiver.connect())
    self.assertIsInstance(x, transit.Connection)
    self.assertIsInstance(y, transit.Connection)

    d = y.receive_record()
    x.send_record(b"record1")
    # Give the received bytes their own name instead of reusing the name
    # that held the TransitReceiver, which clobbered it.
    record = yield d
    self.assertEqual(record, b"record1")

    yield x.close()
    yield y.close()
2016-12-23 04:47:54 +00:00
@inlineCallbacks
def test_relay(self):
2018-04-21 07:30:08 +00:00
KEY = b"k" * 32
2016-12-23 04:47:54 +00:00
s = transit.TransitSender(self.transit, no_listen=True)
r = transit.TransitReceiver(self.transit, no_listen=True)
s.set_transit_key(KEY)
r.set_transit_key(KEY)
shints = yield s.get_connection_hints()
rhints = yield r.get_connection_hints()
s.add_connection_hints(rhints)
r.add_connection_hints(shints)
2018-04-21 07:30:08 +00:00
(x, y) = yield self.doBoth(s.connect(), r.connect())
2016-12-23 04:47:54 +00:00
self.assertIsInstance(x, transit.Connection)
self.assertIsInstance(y, transit.Connection)
d = y.receive_record()
x.send_record(b"record1")
r = yield d
self.assertEqual(r, b"record1")
yield x.close()
yield y.close()