1622 lines
56 KiB
Python
1622 lines
56 KiB
Python
from __future__ import print_function, unicode_literals
|
|
|
|
import gc
|
|
import io
|
|
from binascii import hexlify, unhexlify
|
|
|
|
import six
|
|
from nacl.exceptions import CryptoError
|
|
from nacl.secret import SecretBox
|
|
from twisted.internet import address, defer, endpoints, error, protocol, task
|
|
from twisted.internet.defer import gatherResults, inlineCallbacks
|
|
from twisted.python import log
|
|
from twisted.test import proto_helpers
|
|
from twisted.trial import unittest
|
|
|
|
import mock
|
|
from wormhole_transit_relay import transit_server
|
|
|
|
from .. import transit
|
|
from .._hints import DirectTCPV1Hint
|
|
from ..errors import InternalError
|
|
from ..util import HKDF
|
|
from .common import ServerBase
|
|
|
|
|
|
class Highlander(unittest.TestCase):
    """Tests for transit.there_can_be_only_one()."""

    def _contenders(self, count):
        """Build ``count`` Deferreds whose canceller records their index.

        Returns a (cancelled, deferreds) pair, where ``cancelled`` is the set
        that collects the indices of the Deferreds that got cancelled.
        """
        cancelled = set()

        def canceller_for(index):
            return lambda _d: cancelled.add(index)

        return cancelled, [defer.Deferred(canceller_for(i))
                           for i in range(count)]

    def test_one_winner(self):
        # earlier failures are ignored, the first success wins, and the
        # contenders that are still pending get cancelled
        cancelled, racers = self._contenders(5)
        winner = transit.there_can_be_only_one(racers)
        self.assertNoResult(winner)
        for racer, exc_type in zip(racers, (ValueError, TypeError)):
            racer.errback(exc_type())
            self.assertNoResult(winner)
        racers[2].callback("yay")
        self.assertEqual(self.successResultOf(winner), "yay")
        self.assertEqual(cancelled, {3, 4})

    def test_there_might_also_be_none(self):
        # when every contender fails, the first failure is reported
        cancelled, racers = self._contenders(4)
        winner = transit.there_can_be_only_one(racers)
        self.assertNoResult(winner)
        early_failures = (ValueError, TypeError, TypeError)
        for racer, exc_type in zip(racers, early_failures):
            racer.errback(exc_type())
            self.assertNoResult(winner)
        racers[3].errback(NameError())
        self.failureResultOf(winner, ValueError)  # first failure is recorded
        self.assertEqual(cancelled, set())

    def test_cancel_early(self):
        # cancelling the combined Deferred cancels every contender
        cancelled, racers = self._contenders(4)
        winner = transit.there_can_be_only_one(racers)
        self.assertNoResult(winner)
        self.assertEqual(cancelled, set())
        winner.cancel()
        self.failureResultOf(winner, defer.CancelledError)
        self.assertEqual(cancelled, set(range(4)))

    def test_cancel_after_one_failure(self):
        # after a failure, cancelling reports that failure and cancels the
        # contenders that had not finished yet
        cancelled, racers = self._contenders(4)
        winner = transit.there_can_be_only_one(racers)
        self.assertNoResult(winner)
        self.assertEqual(cancelled, set())
        racers[0].errback(ValueError())
        winner.cancel()
        self.failureResultOf(winner, ValueError)
        self.assertEqual(cancelled, {1, 2, 3})
|
|
|
|
|
|
class Forever(unittest.TestCase):
    """Tests for Common._not_forever(), which bounds a Deferred by a timer."""

    def _forever_setup(self):
        """Return (common, clock, d0, d, cancelled) for a 1.0 second bound.

        ``d0`` is the guarded Deferred, ``d`` is what _not_forever() returned,
        and ``cancelled`` collects whatever is handed to d0's canceller.
        """
        fake_clock = task.Clock()
        common = transit.Common("", reactor=fake_clock)
        cancelled = []
        guarded = defer.Deferred(cancelled.append)
        bounded = common._not_forever(1.0, guarded)
        return common, fake_clock, guarded, bounded, cancelled

    def test_not_forever_fires(self):
        # a result that arrives in time is delivered and the timer goes away
        _, fake_clock, _, bounded, cancelled = self._forever_setup()
        self.assertNoResult(bounded)
        self.assertEqual(cancelled, [])
        bounded.callback(1)
        self.assertEqual(self.successResultOf(bounded), 1)
        self.assertEqual(cancelled, [])
        self.assertNot(fake_clock.getDelayedCalls())

    def test_not_forever_errs(self):
        # a failure that arrives in time is delivered and the timer goes away
        _, fake_clock, _, bounded, cancelled = self._forever_setup()
        self.assertNoResult(bounded)
        self.assertEqual(cancelled, [])
        bounded.errback(ValueError())
        self.assertEqual(cancelled, [])
        self.failureResultOf(bounded, ValueError)
        self.assertNot(fake_clock.getDelayedCalls())

    def test_not_forever_cancel_early(self):
        # an explicit cancel reaches the guarded Deferred and stops the timer
        _, fake_clock, guarded, bounded, cancelled = self._forever_setup()
        self.assertNoResult(bounded)
        self.assertEqual(cancelled, [])
        bounded.cancel()
        self.assertEqual(cancelled, [guarded])
        self.failureResultOf(bounded, defer.CancelledError)
        self.assertNot(fake_clock.getDelayedCalls())

    def test_not_forever_timeout(self):
        # letting the clock run past the bound cancels the guarded Deferred
        _, fake_clock, guarded, bounded, cancelled = self._forever_setup()
        self.assertNoResult(bounded)
        self.assertEqual(cancelled, [])
        fake_clock.advance(2.0)
        self.assertEqual(cancelled, [guarded])
        self.failureResultOf(bounded, defer.CancelledError)
        self.assertNot(fake_clock.getDelayedCalls())
|
|
|
|
|
|
class Misc(unittest.TestCase):
    """Tests for transit.allocate_tcp_port()."""

    def test_allocate_port(self):
        self.assertIsInstance(transit.allocate_tcp_port(), int)

    def test_allocate_port_no_reuseaddr(self):
        # make wormhole.transit believe it is running on cygwin
        fake_sys = mock.Mock()
        fake_sys.platform = "cygwin"
        with mock.patch("wormhole.transit.sys", fake_sys):
            allocated = transit.allocate_tcp_port()
        self.assertIsInstance(allocated, int)
|
|
|
|
|
|
|
|
# ipaddrs.py currently uses native strings: bytes on py2, unicode on py3,
# so the fake addresses fed to the mocked find_addresses() must match.
if six.PY2:
    LOOPADDR, OTHERADDR = b"127.0.0.1", b"1.2.3.4"
else:
    LOOPADDR, OTHERADDR = "127.0.0.1", "1.2.3.4"  # unicode_literals
|
|
|
|
|
|
class Basic(unittest.TestCase):
    """Tests for hint handling, transit keys and connection selection."""

    def _listen_and_get_hints(self, c):
        """Return ``c``'s connection hints; fetching them starts a listener.

        Stopping the listener is registered as a cleanup, so it happens even
        when an assertion in the calling test fails (otherwise the listening
        port would leak into the next test as a dirty reactor).
        """
        d = c.get_connection_hints()
        self.addCleanup(c._stop_listening)
        return self.successResultOf(d)

    @inlineCallbacks
    def test_relay_hints(self):
        URL = "tcp:host:1234"
        c = transit.Common(URL, no_listen=True)
        hints = yield c.get_connection_hints()
        self.assertEqual(hints, [{
            "type": "relay-v1",
            "hints": [{
                "type": "direct-tcp-v1",
                "hostname": "host",
                "port": 1234,
                "priority": 0.0
            }],
        }])
        self.assertRaises(InternalError, transit.Common, 123)

    @inlineCallbacks
    def test_no_relay_hints(self):
        c = transit.Common(None, no_listen=True)
        hints = yield c.get_connection_hints()
        self.assertEqual(hints, [])

    def test_ignore_bad_hints(self):
        c = transit.Common("")
        c.add_connection_hints([{"type": "unknown"}])
        c.add_connection_hints([{
            "type": "relay-v1",
            "hints": [{
                "type": "unknown"
            }]
        }])
        self.assertEqual(c._their_direct_hints, [])
        self.assertEqual(c._our_relay_hints, set())

    def test_ignore_localhost_hint_orig(self):
        # this actually starts the listener
        c = transit.TransitSender("")
        hints = self._listen_and_get_hints(c)
        # If there are non-localhost hints, then localhost hints should be
        # removed. But if the only hint is localhost, it should stay.
        if len(hints) == 1 and hints[0]["hostname"] == "127.0.0.1":
            return
        for hint in hints:
            # assertNotEqual reports the offending hostname on failure
            self.assertNotEqual(hint["hostname"], "127.0.0.1")

    def test_ignore_localhost_hint(self):
        # this actually starts the listener
        c = transit.TransitSender("")
        with mock.patch(
                "wormhole.ipaddrs.find_addresses",
                return_value=[LOOPADDR, OTHERADDR]):
            hints = self._listen_and_get_hints(c)
        # If there are non-localhost hints, then localhost hints should be
        # removed.
        self.assertEqual(len(hints), 1)
        self.assertEqual(hints[0]["hostname"], "1.2.3.4")

    def test_keep_only_localhost_hint(self):
        # this actually starts the listener
        c = transit.TransitSender("")
        with mock.patch(
                "wormhole.ipaddrs.find_addresses", return_value=[LOOPADDR]):
            hints = self._listen_and_get_hints(c)
        # If the only hint is localhost, it should stay.
        self.assertEqual(len(hints), 1)
        self.assertEqual(hints[0]["hostname"], "127.0.0.1")

    def test_abilities(self):
        c = transit.Common(None, no_listen=True)
        abilities = c.get_connection_abilities()
        self.assertEqual(abilities, [
            {
                "type": "direct-tcp-v1"
            },
            {
                "type": "relay-v1"
            },
        ])

    def test_transit_key_wait(self):
        KEY = b"123"
        c = transit.Common("")
        d = c._get_transit_key()
        self.assertNoResult(d)
        c.set_transit_key(KEY)
        self.assertEqual(self.successResultOf(d), KEY)

    def test_transit_key_already_set(self):
        KEY = b"123"
        c = transit.Common("")
        c.set_transit_key(KEY)
        d = c._get_transit_key()
        self.assertEqual(self.successResultOf(d), KEY)

    def test_transit_keys(self):
        KEY = b"123"
        s = transit.TransitSender("")
        s.set_transit_key(KEY)
        r = transit.TransitReceiver("")
        r.set_transit_key(KEY)

        self.assertEqual(s._send_this(), (
            b"transit sender "
            b"559bdeae4b49fa6a23378d2b68f4c7e69378615d4af049c371c6a26e82391089"
            b" ready\n\n"))
        self.assertEqual(s._send_this(), r._expect_this())

        self.assertEqual(r._send_this(), (
            b"transit receiver "
            b"ed447528194bac4c00d0c854b12a97ce51413d89aa74d6304475f516fdc23a1b"
            b" ready\n\n"))
        self.assertEqual(r._send_this(), s._expect_this())

        self.assertEqual(
            hexlify(s._sender_record_key()),
            b"5a2fba3a9e524ab2e2823ff53b05f946896f6e4ce4e282ffd8e3ac0e5e9e0cda"
        )
        self.assertEqual(
            hexlify(s._sender_record_key()), hexlify(r._receiver_record_key()))

        self.assertEqual(
            hexlify(r._sender_record_key()),
            b"eedb143117249f45b39da324decf6bd9aae33b7ccd58487436de611a3c6b871d"
        )
        self.assertEqual(
            hexlify(r._sender_record_key()), hexlify(s._receiver_record_key()))

    def test_connection_ready(self):
        s = transit.TransitSender("")
        self.assertEqual(s.connection_ready("p1"), "go")
        self.assertEqual(s._winner, "p1")
        self.assertEqual(s.connection_ready("p2"), "nevermind")
        self.assertEqual(s._winner, "p1")

        r = transit.TransitReceiver("")
        self.assertEqual(r.connection_ready("p1"), "wait-for-decision")
        self.assertEqual(r.connection_ready("p2"), "wait-for-decision")
|
|
|
|
|
|
class Listener(unittest.TestCase):
    """Tests for building and starting the direct-connection listener."""

    def test_listener(self):
        c = transit.Common("")
        hints, ep = c._build_listener()
        self.assertIsInstance(hints, (list, set))
        if hints:
            self.assertIsInstance(hints[0], DirectTCPV1Hint)
        self.assertIsInstance(ep, endpoints.TCP4ServerEndpoint)

    def test_get_direct_hints(self):
        # this actually starts the listener
        c = transit.TransitSender("")

        d = c.get_connection_hints()
        # Stop the listener when the test ends, even if an assertion below
        # fails, so the listening port does not leak into later tests.
        self.addCleanup(c._stop_listening)
        hints = self.successResultOf(d)

        # the hints are supposed to be cached, so calling this twice won't
        # start a second listener
        self.assertTrue(c._listener)  # assert_ is a deprecated alias
        d2 = c.get_connection_hints()
        self.assertEqual(self.successResultOf(d2), hints)
|
|
|
|
|
|
class DummyProtocol(protocol.Protocol):
    """Test protocol that buffers inbound bytes so a test can wait for them.

    ``wait_for(count)`` returns a Deferred that fires with the next ``count``
    bytes, and ``wait_for_disconnect()`` returns one that fires with None
    when the connection is lost.
    """

    def __init__(self):
        self.buf = b""
        # number of bytes the pending wait_for() is waiting for, or None
        self._count = None
        # Deferred handed out by wait_for() while it waits for more data
        self._d = None
        # Deferred handed out by wait_for_disconnect(), if any
        self._d2 = None

    def wait_for(self, count):
        """Return a Deferred that fires with the next ``count`` bytes."""
        if len(self.buf) >= count:
            data = self.buf[:count]
            self.buf = self.buf[count:]
            return defer.succeed(data)
        self._d = defer.Deferred()
        self._count = count
        return self._d

    def dataReceived(self, data):
        self.buf += data
        if self._count is not None and len(self.buf) >= self._count:
            got = self.buf[:self._count]
            self.buf = self.buf[self._count:]
            self._count = None
            # Detach the Deferred before firing it, so a callback that calls
            # wait_for() again installs a fresh one.
            d, self._d = self._d, None
            d.callback(got)

    def wait_for_disconnect(self):
        """Return a Deferred that fires (with None) on connectionLost."""
        self._d2 = defer.Deferred()
        return self._d2

    def connectionLost(self, reason=None):
        # ``reason`` defaults to None so this also works with FakeTransport,
        # which calls connectionLost() without arguments. The Deferred is
        # detached first so a repeated call cannot fire it twice.
        d, self._d2 = self._d2, None
        if d:
            d.callback(None)
|
|
|
|
|
|
class FakeTransport:
    """In-memory transport: collects written bytes for the test to read
    back, and tells the protocol when the connection is dropped."""

    # when false, loseConnection() does not call protocol.connectionLost()
    signalConnectionLost = True

    def __init__(self, p, peeraddr):
        self.protocol = p
        self._peeraddr = peeraddr
        self._buf = b""
        self._connected = True

    def write(self, data):
        self._buf = self._buf + data

    def loseConnection(self):
        self._connected = False
        if not self.signalConnectionLost:
            return
        self.protocol.connectionLost()

    def getPeer(self):
        return self._peeraddr

    def read_buf(self):
        """Return everything written since the previous call, then reset."""
        pending, self._buf = self._buf, b""
        return pending
|
|
|
|
|
|
class RandomError(Exception):
    """Exception raised on purpose by tests to simulate an unexpected error."""
|
|
|
|
|
|
class MockConnection:
    """Stand-in for transit.Connection.

    Records its constructor arguments and hands out a cancellable
    negotiation Deferred from startNegotiation().
    """

    def __init__(self, owner, relay_handshake, start, description):
        self.owner = owner
        self.relay_handshake = relay_handshake
        self.start = start
        self._description = description
        self._start_negotiation_called = False
        self._cancelled = False

        def _note_cancel(_d):
            # cancelling the negotiation Deferred only records the fact
            self._cancelled = True

        self._d = defer.Deferred(_note_cancel)

    def startNegotiation(self):
        self._start_negotiation_called = True
        return self._d
|
|
|
|
|
|
class InboundConnectionFactory(unittest.TestCase):
    """Tests for transit.InboundConnectionFactory, driven by MockConnection."""

    def _factory(self):
        """Return (factory, done) for a factory that builds MockConnections.

        ``done`` is the factory's whenDone() Deferred, checked to be unfired.
        """
        factory = transit.InboundConnectionFactory("owner")
        factory.protocol = MockConnection
        done = factory.whenDone()
        self.assertNoResult(done)
        return factory, done

    def _two_connections(self, factory, done):
        """Build protocols for two peers and report both as connected."""
        first, second = [
            factory.buildProtocol(address.HostnameAddress("example.com", port))
            for port in (1234, 5678)
        ]
        factory.connectionWasMade(first)
        factory.connectionWasMade(second)
        self.assertNoResult(done)
        return first, second

    def test_describe(self):
        f = transit.InboundConnectionFactory(None)
        expectations = [
            (address.HostnameAddress("example.com", 1234),
             "<-example.com:1234"),
            (address.IPv4Address("TCP", "1.2.3.4", 1234), "<-1.2.3.4:1234"),
            (address.IPv6Address("TCP", "::1", 1234), "<-::1:1234"),
            (address.UNIXAddress("/dev/unlikely"),
             "<-UNIXAddress('/dev/unlikely')"),
        ]
        for peer, described in expectations:
            self.assertEqual(f._describePeer(peer), described)

    def test_success(self):
        f, done = self._factory()

        proto = f.buildProtocol(address.HostnameAddress("example.com", 1234))
        self.assertIsInstance(proto, MockConnection)
        self.assertEqual(proto.owner, "owner")
        self.assertEqual(proto.relay_handshake, None)
        self.assertEqual(proto._start_negotiation_called, False)
        # (.start is not checked here)

        # this is normally called from Connection.connectionMade
        f.connectionWasMade(proto)
        self.assertEqual(proto._start_negotiation_called, True)
        self.assertNoResult(done)
        self.assertEqual(proto._description, "<-example.com:1234")

        proto._d.callback(proto)
        self.assertEqual(self.successResultOf(done), proto)

    def test_one_fail_one_success(self):
        f, done = self._factory()
        first, second = self._two_connections(f, done)

        first._d.errback(transit.BadHandshake("nope"))
        self.assertNoResult(done)
        second._d.callback(second)
        self.assertEqual(self.successResultOf(done), second)

    def test_first_success_wins(self):
        f, done = self._factory()
        first, second = self._two_connections(f, done)

        first._d.callback(first)
        self.assertEqual(self.successResultOf(done), first)
        # the loser's negotiation is cancelled, the winner's is not
        self.assertEqual(first._cancelled, False)
        self.assertEqual(second._cancelled, True)

    def test_log_other_errors(self):
        f, done = self._factory()
        proto = f.buildProtocol(address.HostnameAddress("example.com", 1234))

        # An unexpected error from the Connection protocol should end up in
        # the Twisted log (as an Unhandled Error in Deferred), so the bug can
        # be diagnosed.
        f.connectionWasMade(proto)
        our_error = RandomError("boom1")
        proto._d.errback(our_error)
        self.assertNoResult(done)

        log.msg("=== note: the next RandomError is expected ===")
        # Drop the failed Deferred right now so its Unhandled Error is
        # reported immediately; the gc cycle has to be broken by hand.
        del proto._d
        gc.collect()  # make PyPy happy
        logged = self.flushLoggedErrors(RandomError)
        self.assertEqual(1, len(logged))
        self.assertEqual(our_error, logged[0].value)
        log.msg("=== note: the preceding RandomError was expected ===")

    def test_cancel(self):
        f, done = self._factory()
        first, second = self._two_connections(f, done)

        done.cancel()

        self.failureResultOf(done, defer.CancelledError)
        self.assertEqual(first._cancelled, True)
        self.assertEqual(second._cancelled, True)
|
|
|
|
|
|
# XXX check descriptions
|
|
|
|
|
|
class OutboundConnectionFactory(unittest.TestCase):
    """Tests for transit.OutboundConnectionFactory, using MockConnection."""

    def test_success(self):
        factory = transit.OutboundConnectionFactory(
            "owner", "relay_handshake", "description")
        factory.protocol = MockConnection

        peer = address.HostnameAddress("example.com", 1234)
        proto = factory.buildProtocol(peer)
        self.assertIsInstance(proto, MockConnection)
        self.assertEqual(proto.owner, "owner")
        self.assertEqual(proto.relay_handshake, "relay_handshake")
        self.assertEqual(proto._start_negotiation_called, False)
        # (.start is not checked here)

        # this is normally called from Connection.connectionMade; for an
        # outbound factory it does not start the negotiation
        factory.connectionWasMade(proto)
        self.assertEqual(proto._start_negotiation_called, False)
|
|
|
|
|
|
class MockOwner:
    """Fake transit owner that answers a Connection's questions with fixed
    values.

    Tests assign ``_state``, which connection_ready() returns.
    """

    _connection_ready_called = False

    def connection_ready(self, connection):
        self._connection = connection
        self._connection_ready_called = True
        return self._state

    def _send_this(self):
        return b"send_this"

    def _expect_this(self):
        return b"expect_this"

    def _sender_record_key(self):
        return 32 * b"s"

    def _receiver_record_key(self):
        return 32 * b"r"
|
|
|
|
|
|
class MockFactory:
    """Fake factory that remembers the protocol passed to
    connectionWasMade()."""

    _connectionWasMade_called = False

    def connectionWasMade(self, p):
        self._p = p
        self._connectionWasMade_called = True
|
|
|
|
|
|
class Connection(unittest.TestCase):
|
|
# exercise the Connection protocol class
|
|
|
|
def test_check_and_remove(self):
    conn = transit.Connection(None, None, None, "description")
    expectation = b"expectation"

    # nothing buffered yet: not matched, buffer untouched
    conn.buf = b""
    self.assertFalse(conn._check_and_remove(expectation))
    self.assertEqual(conn.buf, b"")

    # a mismatch raises BadHandshake and leaves the buffer alone
    conn.buf = b"unexpected"
    err = self.assertRaises(transit.BadHandshake, conn._check_and_remove,
                            expectation)
    self.assertEqual(
        str(err), "got %r want %r" % (b'unexpected', b'expectation'))
    self.assertEqual(conn.buf, b"unexpected")

    # a matching but incomplete prefix: keep waiting
    conn.buf = b"expect"
    self.assertFalse(conn._check_and_remove(expectation))
    self.assertEqual(conn.buf, b"expect")

    # exact and over-long matches consume only the expectation
    for incoming, leftover in ((b"expectation", b""),
                               (b"expectation exceeded", b" exceeded")):
        conn.buf = incoming
        self.assertTrue(conn._check_and_remove(expectation))
        self.assertEqual(conn.buf, leftover)
|
|
|
|
def test_describe(self):
    # describe() reports the description given to the constructor
    description = "description"
    conn = transit.Connection(None, None, None, description)
    self.assertEqual(conn.describe(), description)
|
|
|
|
def test_sender_accepting(self):
    # When the owner answers "go", a completed handshake is answered with
    # b"go\n" and the connection moves on to exchanging records.
    fake_owner = MockOwner()
    fake_factory = MockFactory()
    peer = address.HostnameAddress("example.com", 1234)
    conn = transit.Connection(fake_owner, None, None, "description")
    self.assertEqual(conn.state, "too-early")
    wire = FakeTransport(conn, peer)
    conn.transport = wire
    conn.factory = fake_factory
    conn.connectionMade()
    self.assertEqual(fake_factory._connectionWasMade_called, True)
    self.assertEqual(fake_factory._p, conn)

    fake_owner._state = "go"
    negotiation = conn.startNegotiation()
    self.assertEqual(conn.state, "handshake")
    self.assertEqual(wire.read_buf(), b"send_this")
    self.assertNoResult(negotiation)

    conn.dataReceived(b"expect_this")
    self.assertEqual(wire.read_buf(), b"go\n")
    self.assertEqual(wire._connected, True)
    self.assertEqual(conn.state, "records")
    self.assertEqual(self.successResultOf(negotiation), conn)

    conn.close()
    self.assertEqual(wire._connected, False)
|
|
|
|
def test_sender_rejecting(self):
    # When the owner answers "nevermind", the peer is told so, the
    # connection is dropped, and negotiation fails with "abandoned".
    fake_owner = MockOwner()
    fake_factory = MockFactory()
    peer = address.HostnameAddress("example.com", 1234)
    conn = transit.Connection(fake_owner, None, None, "description")
    self.assertEqual(conn.state, "too-early")
    wire = FakeTransport(conn, peer)
    conn.transport = wire
    conn.factory = fake_factory
    conn.connectionMade()
    self.assertEqual(fake_factory._connectionWasMade_called, True)
    self.assertEqual(fake_factory._p, conn)

    fake_owner._state = "nevermind"
    negotiation = conn.startNegotiation()
    self.assertEqual(conn.state, "handshake")
    self.assertEqual(wire.read_buf(), b"send_this")
    self.assertNoResult(negotiation)

    conn.dataReceived(b"expect_this")
    self.assertEqual(wire.read_buf(), b"nevermind\n")
    self.assertEqual(wire._connected, False)
    self.assertEqual(conn.state, "hung up")
    failure = self.failureResultOf(negotiation, transit.BadHandshake)
    self.assertEqual(str(failure.value), "abandoned")
|
|
|
|
def test_handshake_other_error(self):
    # An unexpected exception during the handshake propagates out of
    # dataReceived(), hangs up the connection and fails the negotiation.
    fake_owner = MockOwner()
    fake_factory = MockFactory()
    peer = address.HostnameAddress("example.com", 1234)
    conn = transit.Connection(fake_owner, None, None, "description")
    self.assertEqual(conn.state, "too-early")
    wire = FakeTransport(conn, peer)
    conn.transport = wire
    conn.factory = fake_factory
    conn.connectionMade()
    self.assertEqual(fake_factory._connectionWasMade_called, True)
    self.assertEqual(fake_factory._p, conn)

    negotiation = conn.startNegotiation()
    self.assertEqual(conn.state, "handshake")
    self.assertEqual(wire.read_buf(), b"send_this")
    self.assertNoResult(negotiation)
    # storing an exception as the state makes dataReceived() fail with it
    conn.state = RandomError("boom2")
    self.assertRaises(RandomError, conn.dataReceived, b"surprise!")
    self.assertEqual(wire._connected, False)
    self.assertEqual(conn.state, "hung up")
    self.failureResultOf(negotiation, RandomError)
|
|
|
|
def test_handshake_bad_state(self):
    # An unknown state makes dataReceived() raise ValueError, hang up the
    # connection and fail the negotiation.
    fake_owner = MockOwner()
    fake_factory = MockFactory()
    peer = address.HostnameAddress("example.com", 1234)
    conn = transit.Connection(fake_owner, None, None, "description")
    self.assertEqual(conn.state, "too-early")
    wire = FakeTransport(conn, peer)
    conn.transport = wire
    conn.factory = fake_factory
    conn.connectionMade()
    self.assertEqual(fake_factory._connectionWasMade_called, True)
    self.assertEqual(fake_factory._p, conn)

    negotiation = conn.startNegotiation()
    self.assertEqual(conn.state, "handshake")
    self.assertEqual(wire.read_buf(), b"send_this")
    self.assertNoResult(negotiation)
    conn.state = "unknown-bogus-state"
    self.assertRaises(ValueError, conn.dataReceived, b"surprise!")
    self.assertEqual(wire._connected, False)
    self.assertEqual(conn.state, "hung up")
    self.failureResultOf(negotiation, ValueError)
|
|
|
|
def test_relay_handshake(self):
    # With a relay handshake, that handshake is sent first and the normal
    # handshake only starts once the relay answers b"ok\n".
    relay_handshake = b"relay handshake"
    fake_owner = MockOwner()
    fake_factory = MockFactory()
    peer = address.HostnameAddress("example.com", 1234)
    conn = transit.Connection(fake_owner, relay_handshake, None,
                              "description")
    self.assertEqual(conn.state, "too-early")
    wire = FakeTransport(conn, peer)
    conn.transport = wire
    conn.factory = fake_factory
    conn.connectionMade()
    self.assertEqual(fake_factory._connectionWasMade_called, True)
    self.assertEqual(fake_factory._p, conn)
    self.assertEqual(wire.read_buf(), b"")  # quiet until startNegotiation

    fake_owner._state = "go"
    negotiation = conn.startNegotiation()
    self.assertEqual(wire.read_buf(), relay_handshake)
    self.assertEqual(conn.state, "relay")  # waiting for OK from relay

    conn.dataReceived(b"ok\n")
    self.assertEqual(wire.read_buf(), b"send_this")
    self.assertEqual(conn.state, "handshake")
    self.assertNoResult(negotiation)

    conn.dataReceived(b"expect_this")
    self.assertEqual(conn.state, "records")
    self.assertEqual(self.successResultOf(negotiation), conn)
    self.assertEqual(wire.read_buf(), b"go\n")
|
|
|
|
def test_relay_handshake_bad(self):
    # Anything other than b"ok\n" from the relay hangs up the connection
    # and fails the negotiation with BadHandshake.
    relay_handshake = b"relay handshake"
    fake_owner = MockOwner()
    fake_factory = MockFactory()
    peer = address.HostnameAddress("example.com", 1234)
    conn = transit.Connection(fake_owner, relay_handshake, None,
                              "description")
    self.assertEqual(conn.state, "too-early")
    wire = FakeTransport(conn, peer)
    conn.transport = wire
    conn.factory = fake_factory
    conn.connectionMade()
    self.assertEqual(fake_factory._connectionWasMade_called, True)
    self.assertEqual(fake_factory._p, conn)
    self.assertEqual(wire.read_buf(), b"")  # quiet until startNegotiation

    fake_owner._state = "go"
    negotiation = conn.startNegotiation()
    self.assertEqual(wire.read_buf(), relay_handshake)
    self.assertEqual(conn.state, "relay")  # waiting for OK from relay

    conn.dataReceived(b"not ok\n")
    self.assertEqual(wire._connected, False)
    self.assertEqual(conn.state, "hung up")

    failure = self.failureResultOf(negotiation, transit.BadHandshake)
    self.assertEqual(
        str(failure.value), "got %r want %r" % (b"not ok\n", b"ok\n"))
|
|
|
|
def test_receiver_accepted(self):
    # we're on the receiving side, so we wait for the sender to decide;
    # here the sender says b"go\n"
    fake_owner = MockOwner()
    fake_factory = MockFactory()
    peer = address.HostnameAddress("example.com", 1234)
    conn = transit.Connection(fake_owner, None, None, "description")
    self.assertEqual(conn.state, "too-early")
    wire = FakeTransport(conn, peer)
    conn.transport = wire
    conn.factory = fake_factory
    conn.connectionMade()
    self.assertEqual(fake_factory._connectionWasMade_called, True)
    self.assertEqual(fake_factory._p, conn)

    fake_owner._state = "wait-for-decision"
    negotiation = conn.startNegotiation()
    self.assertEqual(conn.state, "handshake")
    self.assertEqual(wire.read_buf(), b"send_this")
    self.assertNoResult(negotiation)

    conn.dataReceived(b"expect_this")
    self.assertEqual(conn.state, "wait-for-decision")
    self.assertNoResult(negotiation)

    conn.dataReceived(b"go\n")
    self.assertEqual(conn.state, "records")
    self.assertEqual(self.successResultOf(negotiation), conn)
|
|
|
|
def test_receiver_rejected_politely(self):
    # we're on the receiving side, so we wait for the sender to decide;
    # here the sender politely declines with b"nevermind\n"
    fake_owner = MockOwner()
    fake_factory = MockFactory()
    peer = address.HostnameAddress("example.com", 1234)
    conn = transit.Connection(fake_owner, None, None, "description")
    self.assertEqual(conn.state, "too-early")
    wire = FakeTransport(conn, peer)
    conn.transport = wire
    conn.factory = fake_factory
    conn.connectionMade()
    self.assertEqual(fake_factory._connectionWasMade_called, True)
    self.assertEqual(fake_factory._p, conn)

    fake_owner._state = "wait-for-decision"
    negotiation = conn.startNegotiation()
    self.assertEqual(conn.state, "handshake")
    self.assertEqual(wire.read_buf(), b"send_this")
    self.assertNoResult(negotiation)

    conn.dataReceived(b"expect_this")
    self.assertEqual(conn.state, "wait-for-decision")
    self.assertNoResult(negotiation)

    conn.dataReceived(b"nevermind\n")  # polite rejection
    self.assertEqual(wire._connected, False)
    self.assertEqual(conn.state, "hung up")
    failure = self.failureResultOf(negotiation, transit.BadHandshake)
    self.assertEqual(
        str(failure.value), "got %r want %r" % (b"nevermind\n", b"go\n"))
|
|
|
|
def test_receiver_rejected_rudely(self):
    # we're on the receiving side, so we wait for the sender to decide;
    # here the sender just drops the connection
    fake_owner = MockOwner()
    fake_factory = MockFactory()
    peer = address.HostnameAddress("example.com", 1234)
    conn = transit.Connection(fake_owner, None, None, "description")
    self.assertEqual(conn.state, "too-early")
    wire = FakeTransport(conn, peer)
    conn.transport = wire
    conn.factory = fake_factory
    conn.connectionMade()
    self.assertEqual(fake_factory._connectionWasMade_called, True)
    self.assertEqual(fake_factory._p, conn)

    fake_owner._state = "wait-for-decision"
    negotiation = conn.startNegotiation()
    self.assertEqual(conn.state, "handshake")
    self.assertEqual(wire.read_buf(), b"send_this")
    self.assertNoResult(negotiation)

    conn.dataReceived(b"expect_this")
    self.assertEqual(conn.state, "wait-for-decision")
    self.assertNoResult(negotiation)

    wire.loseConnection()
    self.assertEqual(wire._connected, False)
    failure = self.failureResultOf(negotiation, transit.BadHandshake)
    self.assertEqual(str(failure.value), "connection lost")
|
|
|
|
def test_cancel(self):
    # cancelling the negotiation Deferred hangs up the connection
    fake_owner = MockOwner()
    fake_factory = MockFactory()
    peer = address.HostnameAddress("example.com", 1234)
    conn = transit.Connection(fake_owner, None, None, "description")
    self.assertEqual(conn.state, "too-early")
    wire = FakeTransport(conn, peer)
    conn.transport = wire
    conn.factory = fake_factory
    conn.connectionMade()

    negotiation = conn.startNegotiation()
    # while we're waiting for negotiation, we get cancelled
    negotiation.cancel()

    self.assertEqual(wire._connected, False)
    self.assertEqual(conn.state, "hung up")
    self.failureResultOf(negotiation, defer.CancelledError)
|
|
|
|
def test_timeout(self):
    # If the handshake does not finish within transit.TIMEOUT, the
    # connection is dropped and negotiation fails with "timeout".
    clock = task.Clock()
    owner = MockOwner()
    factory = MockFactory()
    addr = address.HostnameAddress("example.com", 1234)
    c = transit.Connection(owner, None, None, "description")

    # Drive the connection's timer from the fake clock. Delegating straight
    # to clock.callLater returns the IDelayedCall to the connection. A
    # wrapper that drops it would leave the connection unable to cancel or
    # reset its timer (assuming it keeps the call, TimeoutMixin-style --
    # confirm against transit.Connection).
    c.callLater = clock.callLater
    self.assertEqual(c.state, "too-early")
    t = c.transport = FakeTransport(c, addr)
    c.factory = factory
    c.connectionMade()
    # the timer should now be running
    d = c.startNegotiation()
    # while we're waiting for negotiation, the timer expires
    clock.advance(transit.TIMEOUT + 1.0)

    self.assertEqual(t._connected, False)
    f = self.failureResultOf(d, transit.BadHandshake)
    self.assertEqual(str(f.value), "timeout")
|
|
|
|
def make_connection(self):
    """Return (transport, connection, owner) for a Connection whose owner
    said "go" and whose handshake has completed.

    The transport's write buffer is drained, so the next bytes read from it
    are the encrypted records.
    """
    fake_owner = MockOwner()
    fake_factory = MockFactory()
    peer = address.HostnameAddress("example.com", 1234)
    conn = transit.Connection(fake_owner, None, None, "description")
    wire = FakeTransport(conn, peer)
    conn.transport = wire
    conn.factory = fake_factory
    conn.connectionMade()

    fake_owner._state = "go"
    negotiation = conn.startNegotiation()
    conn.dataReceived(b"expect_this")
    self.assertEqual(self.successResultOf(negotiation), conn)
    wire.read_buf()  # discard the handshake bytes

    return wire, conn, fake_owner
|
|
|
|
def test_records_not_binary(self):
    # send_record() refuses text; records must be bytes
    _wire, conn, _owner = self.make_connection()
    with self.assertRaises(InternalError):
        conn.send_record(u"not binary")
|
|
|
|
def test_records_good(self):
    # now make sure that outbound records are encrypted properly
    wire, conn, fake_owner = self.make_connection()

    def check_outbound(record, expected_nonce):
        # send one record, then check its 4-byte length prefix, its nonce,
        # and that it decrypts back to the original bytes
        conn.send_record(record)
        sent = wire.read_buf()
        length_hex = ("%08x" % (24 + len(record) + 16)).encode("ascii")
        self.assertEqual(hexlify(sent[:4]), length_hex)
        sealed = sent[4:]
        opener = SecretBox(fake_owner._sender_record_key())
        nonce_bytes = sealed[:SecretBox.NONCE_SIZE]  # assume it's prepended
        self.assertEqual(int(hexlify(nonce_bytes), 16), expected_nonce)
        self.assertEqual(opener.decrypt(sealed), record)

    check_outbound(b"record", 0)  # first message gets nonce 0
    check_outbound(b"record2", 1)  # second message gets nonce 1

    # and that we can receive records properly
    delivered = []
    conn.recordReceived = delivered.append
    sealer = SecretBox(fake_owner._receiver_record_key())

    def frame(record, nonce):
        # return (4-byte length prefix, nonce-prefixed ciphertext)
        sealed = sealer.encrypt(record, unhexlify("%048x" % nonce))
        return unhexlify("%08x" % len(sealed)), sealed

    # a record arriving in pieces is delivered only once it is complete;
    # inbound nonces count up from 0
    for nonce, record in enumerate([b"record3", b"record4"]):
        prefix, sealed = frame(record, nonce)
        already = list(delivered)
        conn.dataReceived(prefix[:2])
        conn.dataReceived(prefix[2:])
        conn.dataReceived(sealed[:-2])
        self.assertEqual(delivered, already)
        conn.dataReceived(sealed[-2:])
        self.assertEqual(delivered, already + [record])

    # receiving two records at the same time: deliver both
    del delivered[:]
    burst = b""
    for nonce, record in [(2, b"record5"), (3, b"record6")]:
        prefix, sealed = frame(record, nonce)
        burst += prefix + sealed
    conn.dataReceived(burst)
    self.assertEqual(delivered, [b"record5", b"record6"])
|
|
|
|
def corrupt(self, orig):
    """Return a copy of ``orig`` with its final byte altered.

    Used to build ciphertexts that must fail authentication. Every bit of
    the last byte is flipped, so the result always differs from ``orig``
    and has the same length. (A ``256 - n`` mapping would leave 0x80
    unchanged and would overflow to three hex digits for 0x00.)

    :param orig: non-empty bytes to corrupt
    :returns: bytes of the same length as ``orig``
    :raises ValueError: if ``orig`` is empty, since there is nothing to flip
    """
    if not orig:
        raise ValueError("cannot corrupt an empty message")
    last_byte = orig[-1:]
    num = int(hexlify(last_byte).decode("ascii"), 16)
    # XOR with 0xff never maps a byte to itself and stays within 0..255,
    # so "%02x" always produces exactly two hex digits.
    corrupt_num = num ^ 0xff
    as_byte = unhexlify("%02x" % corrupt_num)
    return orig[:-1] + as_byte
|
|
|
|
def test_records_corrupt(self):
    """A record whose ciphertext fails authentication is rejected.

    Once the last bytes of the corrupted record arrive, dataReceived()
    must raise CryptoError. Nothing may be delivered, and the transport
    must end up disconnected.
    """
    transport, conn, owner = self.make_connection()

    delivered = []
    conn.recordReceived = delivered.append

    box = SecretBox(owner._receiver_record_key())
    first_nonce = unhexlify("%048x" % 0)  # the first inbound nonce is 0
    bad_frame = self.corrupt(box.encrypt(b"record", first_nonce))
    frame_len = unhexlify("%08x" % len(bad_frame))  # 4-byte length prefix

    head, tail = bad_frame[:-2], bad_frame[-2:]
    conn.dataReceived(frame_len)
    conn.dataReceived(head)
    self.assertEqual(delivered, [])  # record still incomplete
    self.assertRaises(CryptoError, conn.dataReceived, tail)
    self.assertEqual(delivered, [])
    # the failed decryption must also have dropped the connection
    self.assertEqual(transport._connected, False)
|
|
|
|
def test_out_of_order_nonce(self):
    """An inbound record carrying an unexpected nonce is rejected.

    The receiver expects the first inbound nonce to be 0. A correctly
    encrypted record that carries nonce 1 must make dataReceived() raise
    transit.BadNonce. Nothing may be delivered, and the transport must
    end up disconnected.
    """
    t, c, owner = self.make_connection()

    inbound_records = []
    c.recordReceived = inbound_records.append

    RECORD = b"record"
    send_box = SecretBox(owner._receiver_record_key())
    # Deliberately skip nonce 0. The first inbound nonce must be 0, so
    # starting at 1 is out of order.
    nonce_buf = unhexlify("%048x" % 1)
    encrypted = send_box.encrypt(RECORD, nonce_buf)
    length = unhexlify("%08x" % len(encrypted))  # always 4 bytes long
    c.dataReceived(length)
    c.dataReceived(encrypted[:-2])
    self.assertEqual(inbound_records, [])
    self.assertRaises(transit.BadNonce, c.dataReceived, encrypted[-2:])
    self.assertEqual(inbound_records, [])
    # and the connection should have been dropped
    self.assertEqual(t._connected, False)
|
|
|
|
# TODO: check that .connectionLost/loseConnection signatures are
|
|
# consistent: zero args, or one arg?
|
|
|
|
# XXX: if we don't set the transit key before connecting, what
|
|
# happens? We currently get a type-check assertion from HKDF because
|
|
# the key is None.
|
|
|
|
def test_receive_queue(self):
    """receive_record() hands out queued and future records in order.

    - Records that arrived before anyone asked are returned immediately.
    - Requests made before a record arrives wait for it.
    - A request still pending at close() fails with ConnectionClosed.
    """
    conn = transit.Connection(None, None, None, "description")
    conn.transport = FakeTransport(conn, None)
    conn.transport.signalConnectionLost = False
    for early in (b"0", b"1", b"2"):
        conn.recordReceived(early)

    first = conn.receive_record()
    self.assertEqual(self.successResultOf(first), b"0")
    second, third = conn.receive_record(), conn.receive_record()
    # results follow the order of the receive_record() calls, not the
    # order in which the Deferreds are inspected
    self.assertEqual(self.successResultOf(third), b"2")
    self.assertEqual(self.successResultOf(second), b"1")

    # with the queue drained, new requests wait for future records
    fourth, fifth = conn.receive_record(), conn.receive_record()
    self.assertNoResult(fourth)
    self.assertNoResult(fifth)

    conn.recordReceived(b"3")
    self.assertEqual(self.successResultOf(fourth), b"3")
    self.assertNoResult(fifth)

    conn.recordReceived(b"4")
    self.assertEqual(self.successResultOf(fifth), b"4")

    # a request still outstanding when the connection closes fails
    pending = conn.receive_record()
    conn.close()
    self.failureResultOf(pending, error.ConnectionClosed)
|
|
|
|
def test_producer(self):
    """A Connection is a producer that feeds peer records to a consumer.

    Records queued before the consumer is attached are flushed into it.
    Later records are written as they arrive. pause/resume/stop are
    passed through to the connection's transport.
    """
    conn = transit.Connection(None, None, None, "description")
    conn.transport = proto_helpers.StringTransport()
    for early in (b"r1.", b"r2."):
        conn.recordReceived(early)

    sink = proto_helpers.StringTransport()
    # called without ``expected``, connectConsumer() returns None
    self.assertIs(conn.connectConsumer(sink), None)
    self.assertIs(conn._consumer, sink)
    self.assertEqual(sink.value(), b"r1.r2.")

    # attaching while a consumer is already attached is an error
    self.assertRaises(RuntimeError, conn.connectConsumer, sink)

    conn.recordReceived(b"r3.")
    self.assertEqual(sink.value(), b"r1.r2.r3.")

    conn.pauseProducing()
    self.assertEqual(conn.transport.producerState, "paused")
    conn.resumeProducing()
    self.assertEqual(conn.transport.producerState, "producing")

    conn.disconnectConsumer()
    self.assertEqual(sink.producer, None)
    conn.connectConsumer(sink)  # re-attaching after a disconnect is fine

    conn.stopProducing()
    self.assertEqual(conn.transport.producerState, "stopped")
|
|
|
|
def test_connectConsumer(self):
    """connectConsumer(consumer, expected=N) fires once N bytes are written.

    When the Deferred fires (with N), the consumer is detached
    automatically and later records are queued instead of delivered.
    If the connection is lost first, the Deferred fails with
    ConnectionClosed.
    """
    conn = transit.Connection(None, None, None, "description")
    # discard any failure of the negotiation Deferred so it is not
    # reported as an unhandled error ("eat it")
    conn._negotiation_d.addErrback(lambda err: None)
    conn.transport = proto_helpers.StringTransport()
    conn.recordReceived(b"r1.")

    sink = proto_helpers.StringTransport()
    done = conn.connectConsumer(sink, expected=10)
    # the record that was already queued is delivered right away
    written = b"r1."
    self.assertEqual(sink.value(), written)
    self.assertNoResult(done)

    for rec in (b"r2.", b"r3."):
        conn.recordReceived(rec)
        written += rec
        self.assertEqual(sink.value(), written)
        self.assertNoResult(done)

    # the tenth byte completes the expected amount
    conn.recordReceived(b"!")
    self.assertEqual(sink.value(), b"r1.r2.r3.!")
    self.assertEqual(self.successResultOf(done), 10)

    # the consumer is now detached, so further records stay queued
    self.assertIs(conn._consumer, None)
    conn.recordReceived(b"overflow")
    self.assertEqual(sink.value(), b"r1.r2.r3.!")

    # an unfinished transfer fails when the connection is lost
    done = conn.connectConsumer(sink, expected=10)
    conn.connectionLost()
    self.failureResultOf(done, error.ConnectionClosed)
|
|
|
|
def test_connectConsumer_empty(self):
    """connectConsumer(expected=0) completes immediately.

    A zero-length transfer (e.g. someone "sending" an empty file) never
    receives any bytes. The Deferred must therefore fire right away, so
    the consumer can detach itself.
    """
    conn = transit.Connection(None, None, None, "description")
    # discard any failure of the negotiation Deferred ("eat it")
    conn._negotiation_d.addErrback(lambda err: None)
    conn.transport = proto_helpers.StringTransport()

    sink = proto_helpers.StringTransport()
    done = conn.connectConsumer(sink, expected=0)
    self.assertEqual(self.successResultOf(done), 0)
    self.assertEqual(sink.value(), b"")
    self.assertIs(conn._consumer, None)
|
|
|
|
def test_writeToFile(self):
    """writeToFile(f, N, progress) copies at least N bytes of records into f.

    Each delivered record is written to ``f`` and its length is reported
    to ``progress``. The Deferred fires with the byte count once N bytes,
    rounded up to whole records, have been written; after that, further
    records stay queued. An unfinished writeToFile() fails with
    ConnectionClosed when the connection is lost.
    """
    conn = transit.Connection(None, None, None, "description")
    # discard any failure of the negotiation Deferred ("eat it")
    conn._negotiation_d.addErrback(lambda err: None)
    conn.transport = proto_helpers.StringTransport()
    conn.recordReceived(b"r1.")

    out = io.BytesIO()
    progress = []
    done = conn.writeToFile(out, 10, progress.append)
    written, sizes = b"r1.", [3]
    self.assertEqual(out.getvalue(), written)
    self.assertEqual(progress, sizes)
    self.assertNoResult(done)

    for rec in (b"r2.", b"r3."):
        conn.recordReceived(rec)
        written += rec
        sizes.append(len(rec))
        self.assertEqual(out.getvalue(), written)
        self.assertEqual(progress, sizes)
        self.assertNoResult(done)

    conn.recordReceived(b"!")
    self.assertEqual(out.getvalue(), b"r1.r2.r3.!")
    self.assertEqual(progress, [3, 3, 3, 1])
    self.assertEqual(self.successResultOf(done), 10)

    # finished: the consumer is detached, so new records are queued
    self.assertIs(conn._consumer, None)
    conn.recordReceived(b"overflow.")
    self.assertEqual(out.getvalue(), b"r1.r2.r3.!")
    self.assertEqual(progress, [3, 3, 3, 1])

    # with enough data queued up front, only whole records are written,
    # stopping at the first record that reaches the requested size
    conn.recordReceived(b"second.")  # queue: overflow.second.
    conn.recordReceived(b"third.")   # queue: overflow.second.third.
    out = io.BytesIO()
    done = conn.writeToFile(out, 10)
    self.assertEqual(out.getvalue(), b"overflow.second.")
    self.assertEqual(self.successResultOf(done), 16)
    self.assertEqual(list(conn._inbound_records), [b"third."])

    # an unfinished writeToFile() fails when the connection is lost
    done = conn.writeToFile(out, 10)
    conn.connectionLost()
    self.failureResultOf(done, error.ConnectionClosed)
|
|
|
|
def test_consumer(self):
    """A Connection also acts as a consumer for a local producer.

    Producer (un)registration is passed through to the transport. Each
    write() is handed to send_record() as a single record.
    """
    conn = transit.Connection(None, None, None, "description")
    conn.transport = proto_helpers.StringTransport()
    sent = []
    conn.send_record = sent.append

    source = proto_helpers.StringTransport()
    conn.registerProducer(source, True)
    self.assertIs(conn.transport.producer, source)

    conn.write(b"r1.")
    self.assertEqual(sent, [b"r1."])

    conn.unregisterProducer()
    self.assertEqual(conn.transport.producer, None)
|
|
|
|
|
|
class FileConsumer(unittest.TestCase):
    """Tests for transit.FileConsumer.

    Every write() goes to the wrapped file, is reported to the progress
    callback as a byte count, and is fed to the optional hasher.
    """

    def test_basic(self):
        out = io.BytesIO()
        seen = []
        fc = transit.FileConsumer(out, seen.append)
        # nothing is written or reported before the first write()
        self.assertEqual(seen, [])
        self.assertEqual(out.getvalue(), b"")

        chunk = b"." * 99
        fc.write(chunk)
        self.assertEqual(seen, [99])
        self.assertEqual(out.getvalue(), chunk)

        fc.write(b"!")
        self.assertEqual(seen, [99, 1])
        self.assertEqual(out.getvalue(), chunk + b"!")

    def test_hasher(self):
        hashed = []
        out = io.BytesIO()
        seen = []
        fc = transit.FileConsumer(out, seen.append, hasher=hashed.append)
        self.assertEqual(seen, [])
        self.assertEqual(out.getvalue(), b"")
        self.assertEqual(hashed, [])

        chunk = b"." * 99
        fc.write(chunk)
        self.assertEqual(seen, [99])
        self.assertEqual(out.getvalue(), chunk)
        # the hasher sees each write's data unchanged
        self.assertEqual(hashed, [chunk])

        fc.write(b"!")
        self.assertEqual(seen, [99, 1])
        self.assertEqual(out.getvalue(), chunk + b"!")
        self.assertEqual(hashed, [chunk, b"!"])
|
|
|
|
|
|
# Connection-hint fixtures, in the dict form peers exchange, used by the
# Transit tests.

# A usable direct TCP hint.
DIRECT_HINT_JSON = {"type": "direct-tcp-v1", "hostname": "direct",
                    "port": 1234}

# A relay hint wrapping a single direct TCP hint for the relay host.
RELAY_HINT_JSON = {"type": "relay-v1",
                   "hints": [{"type": "direct-tcp-v1", "hostname": "relay",
                              "port": 1234}]}

# Claims to be direct-tcp-v1 but is malformed (list hostname, no port);
# connect() must skip it.
UNRECOGNIZED_DIRECT_HINT_JSON = {"type": "direct-tcp-v1",
                                 "hostname": ["cannot", "parse", "list"]}

# A hint type that is not understood at all.
UNRECOGNIZED_HINT_JSON = {"type": "unknown"}

# Well-formed, but no endpoint can be built for it at runtime (e.g. Tor
# without txtorcon). The tests' fake endpoint factory returns None for
# the hostname "unavailable".
UNAVAILABLE_HINT_JSON = {"type": "direct-tcp-v1", "hostname": "unavailable",
                         "port": 1234}

# Like RELAY_HINT_JSON, plus an unrecognized sub-hint that must be ignored.
RELAY_HINT2_JSON = {"type": "relay-v1",
                    "hints": [{"type": "direct-tcp-v1", "hostname": "relay",
                               "port": 1234},
                              UNRECOGNIZED_HINT_JSON]}

# A relay hint whose only sub-hint is unavailable.
UNAVAILABLE_RELAY_HINT_JSON = {"type": "relay-v1",
                               "hints": [UNAVAILABLE_HINT_JSON]}
|
|
|
|
|
|
class Transit(unittest.TestCase):
    """Tests for how TransitSender.connect() turns peer hints into
    connection attempts, and in what order and when it starts them.

    No real connections are made. ``_start_connector`` is replaced by a
    recorder. Several tests also patch ``endpoint_from_hint_obj`` with a
    fake that maps each direct hint to its hostname string.
    """

    def setUp(self):
        # One entry per connection attempt, kept in parallel lists: the
        # endpoint, the Deferred returned for it, and its description.
        self._connectors = []
        self._waiters = []
        self._descriptions = []

    def _start_connector(self, ep, description, is_relay=False):
        """Record a connection attempt instead of making it.

        Returns a Deferred that the test fires by hand to pick the winner.
        """
        waiter = defer.Deferred()
        self._connectors.append(ep)
        self._waiters.append(waiter)
        self._descriptions.append(description)
        return waiter

    def _assert_single_winner(self, d, description):
        """Check that exactly one attempt, described as ``description``,
        was started. Then fire it and check that ``d`` delivers its result.
        """
        self.assertNoResult(d)
        self.assertEqual(len(self._waiters), 1)
        self.assertIsInstance(self._waiters[0], defer.Deferred)

        self._waiters[0].callback("winner")
        self.assertEqual(self.successResultOf(d), "winner")
        self.assertEqual(self._descriptions, [description])

    @inlineCallbacks
    def test_success_direct(self):
        clock = task.Clock()
        sender = transit.TransitSender("", reactor=clock)
        sender.set_transit_key(b"key")
        yield sender.get_connection_hints()  # start the listener
        sender.add_connection_hints([
            DIRECT_HINT_JSON,
            UNRECOGNIZED_DIRECT_HINT_JSON,
            UNRECOGNIZED_HINT_JSON,
        ])
        sender._start_connector = self._start_connector

        d = sender.connect()
        # only the well-formed direct hint yields an attempt
        self._assert_single_winner(d, "->tcp:direct:1234")

    @inlineCallbacks
    def test_success_direct_tor(self):
        clock = task.Clock()
        sender = transit.TransitSender("", tor=mock.Mock(), reactor=clock)
        sender.set_transit_key(b"key")
        yield sender.get_connection_hints()  # start the listener
        sender.add_connection_hints([DIRECT_HINT_JSON])
        sender._start_connector = self._start_connector

        d = sender.connect()
        self._assert_single_winner(d, "tor->tcp:direct:1234")

    @inlineCallbacks
    def test_success_direct_tor_relay(self):
        clock = task.Clock()
        sender = transit.TransitSender("", tor=mock.Mock(), reactor=clock)
        sender.set_transit_key(b"key")
        yield sender.get_connection_hints()  # start the listener
        sender.add_connection_hints([RELAY_HINT_JSON])
        sender._start_connector = self._start_connector

        d = sender.connect()
        # relay attempts are triggered starting at T+0.0, so any clock
        # advance is enough to start them
        clock.advance(1.0)
        self._assert_single_winner(d, "tor->relay:tcp:relay:1234")

    def _endpoint_from_hint_obj(self, hint, _tor, _reactor):
        """Fake endpoint factory.

        A DirectTCPV1Hint maps to its hostname string, except hostname
        "unavailable", which maps to None. Any other hint also maps to None.
        """
        if (not isinstance(hint, DirectTCPV1Hint)
                or hint.hostname == "unavailable"):
            return None
        return hint.hostname

    @inlineCallbacks
    def test_wait_for_relay(self):
        clock = task.Clock()
        sender = transit.TransitSender("", reactor=clock, no_listen=True)
        sender.set_transit_key(b"key")
        yield sender.get_connection_hints()
        sender.add_connection_hints(
            [DIRECT_HINT_JSON, UNRECOGNIZED_HINT_JSON, RELAY_HINT_JSON])
        sender._start_connector = self._start_connector

        with mock.patch("wormhole.transit.endpoint_from_hint_obj",
                        self._endpoint_from_hint_obj):
            d = sender.connect()
            self.assertNoResult(d)
            # direct attempts start right away; relay attempts are held
            # back until RELAY_DELAY has passed
            self.assertEqual(self._connectors, ["direct"])

            clock.advance(sender.RELAY_DELAY + 1.0)
            self.assertEqual(self._connectors, ["direct", "relay"])

            self._waiters[0].callback("winner")
            self.assertEqual(self.successResultOf(d), "winner")

    @inlineCallbacks
    def test_priorities(self):
        clock = task.Clock()
        sender = transit.TransitSender("", reactor=clock, no_listen=True)
        sender.set_transit_key(b"key")
        yield sender.get_connection_hints()

        def direct(hostname, priority=None):
            # a direct-tcp-v1 hint dict, with "priority" only if given
            hint = {"type": "direct-tcp-v1", "hostname": hostname,
                    "port": 1234}
            if priority is not None:
                hint["priority"] = priority
            return hint

        def relay(*subhints):
            # a relay-v1 hint dict wrapping the given direct hints
            return {"type": "relay-v1", "hints": list(subhints)}

        sender.add_connection_hints([
            relay(direct("relay")),
            direct("direct"),
            relay(direct("relay2", 2.0), direct("relay3", 3.0)),
            relay(direct("relay4", 2.0)),
        ])
        sender._start_connector = self._start_connector

        with mock.patch("wormhole.transit.endpoint_from_hint_obj",
                        self._endpoint_from_hint_obj):
            d = sender.connect()
            self.assertNoResult(d)
            # Order: the direct hint first, then the priority-3.0 relay,
            # then both 2.0 relays (in either order), then the relay with
            # the default 0.0 priority. Each lower priority level starts
            # one RELAY_DELAY after the previous one.
            self.assertEqual(self._connectors, ["direct"])

            clock.advance(sender.RELAY_DELAY + 1.0)
            self.assertEqual(self._connectors, ["direct", "relay3"])

            clock.advance(sender.RELAY_DELAY)
            self.assertIn(self._connectors,
                          (["direct", "relay3", "relay2", "relay4"],
                           ["direct", "relay3", "relay4", "relay2"]))

            clock.advance(sender.RELAY_DELAY)
            self.assertIn(self._connectors,
                          (["direct", "relay3", "relay2", "relay4", "relay"],
                           ["direct", "relay3", "relay4", "relay2", "relay"]))

            self._waiters[0].callback("winner")
            self.assertEqual(self.successResultOf(d), "winner")

    @inlineCallbacks
    def test_no_direct_hints(self):
        clock = task.Clock()
        sender = transit.TransitSender("", reactor=clock, no_listen=True)
        sender.set_transit_key(b"key")
        yield sender.get_connection_hints()
        # none of these can be turned into a direct endpoint at runtime
        sender.add_connection_hints([
            UNRECOGNIZED_HINT_JSON,
            UNAVAILABLE_HINT_JSON,
            RELAY_HINT2_JSON,
            UNAVAILABLE_RELAY_HINT_JSON,
        ])
        sender._start_connector = self._start_connector

        with mock.patch("wormhole.transit.endpoint_from_hint_obj",
                        self._endpoint_from_hint_obj):
            d = sender.connect()
            self.assertNoResult(d)
            # Without usable direct hints the relay delay is 0 seconds,
            # but the attempt is still scheduled rather than started inline.
            self.assertEqual(self._connectors, [])

            clock.advance(0)
            self.assertEqual(self._connectors, ["relay"])

            self._waiters[0].callback("winner")
            self.assertEqual(self.successResultOf(d), "winner")

    @inlineCallbacks
    def test_no_contenders(self):
        clock = task.Clock()
        sender = transit.TransitSender("", reactor=clock, no_listen=True)
        sender.set_transit_key(b"key")
        yield sender.get_connection_hints()
        sender.add_connection_hints([])  # no hints at all
        sender._start_connector = self._start_connector

        with mock.patch("wormhole.transit.endpoint_from_hint_obj",
                        self._endpoint_from_hint_obj):
            d = sender.connect()
            failure = self.failureResultOf(d, transit.TransitError)
            self.assertEqual(str(failure.value),
                             "No contenders for connection")
|
|
|
|
|
|
class RelayHandshake(unittest.TestCase):
    """The relay server accepts both the legacy and the current client
    handshake and reports the derived token (plus the side, when sent)."""

    def old_build_relay_handshake(self, key):
        """Return ``(token, handshake)`` for the legacy relay handshake.

        The token is a 32-byte HKDF of ``key``. The handshake line is
        ``b"please relay <hex token>\\n"`` and carries no side identifier.
        """
        token = HKDF(key, 32, CTXinfo=b"transit_relay_token")
        handshake = b"".join([b"please relay ", hexlify(token), b"\n"])
        return token, handshake

    def _feed_handshake(self, handshake):
        """Deliver ``handshake`` to a new TransitConnection in two pieces.

        Asserts that no token is reported before the final byte arrives,
        then returns the connection.
        """
        tc = transit_server.TransitConnection()
        tc.factory = mock.Mock()
        tc.factory.connection_got_token = mock.Mock()
        tc.dataReceived(handshake[:-1])
        self.assertEqual(tc.factory.connection_got_token.mock_calls, [])
        tc.dataReceived(handshake[-1:])
        return tc

    def test_old(self):
        token, handshake = self.old_build_relay_handshake(b"\x00")
        tc = self._feed_handshake(handshake)
        self.assertEqual(tc.factory.connection_got_token.mock_calls,
                         [mock.call(hexlify(token), None, tc)])

    def test_new(self):
        common = transit.Common(None)
        common.set_transit_key(b"\x00")
        new_handshake = common._build_relay_handshake()
        # the new handshake must carry the same token as the legacy one
        token, _ = self.old_build_relay_handshake(b"\x00")

        tc = self._feed_handshake(new_handshake)
        expected_side = common._side.encode("ascii")
        self.assertEqual(tc.factory.connection_got_token.mock_calls,
                         [mock.call(hexlify(token), expected_side, tc)])
|
|
|
|
|
|
class Full(ServerBase, unittest.TestCase):
    """End-to-end transit tests with a real sender/receiver pair.

    The pair connects either directly or through the transit relay
    provided by ServerBase (``self.transit``).
    """

    def doBoth(self, d1, d2):
        """Wait for both Deferreds; fail if either one fails.

        Failures are consumed so they are not also logged as unhandled.
        """
        return gatherResults([d1, d2], consumeErrors=True)

    @inlineCallbacks
    def _connect_and_exchange(self, sender, receiver):
        """Connect ``sender`` and ``receiver`` using each other's hints.

        Checks that one record travels from the sender's connection to the
        receiver's, then closes both connections.
        """
        key = b"k" * 32
        sender.set_transit_key(key)
        receiver.set_transit_key(key)

        sender_hints = yield sender.get_connection_hints()
        receiver_hints = yield receiver.get_connection_hints()
        sender.add_connection_hints(receiver_hints)
        receiver.add_connection_hints(sender_hints)

        s_conn, r_conn = yield self.doBoth(sender.connect(),
                                           receiver.connect())
        for conn in (s_conn, r_conn):
            self.assertIsInstance(conn, transit.Connection)

        pending = r_conn.receive_record()
        s_conn.send_record(b"record1")
        received = yield pending
        self.assertEqual(received, b"record1")

        yield s_conn.close()
        yield r_conn.close()

    def test_direct(self):
        # TODO: this sometimes fails with EADDRINUSE (when
        # get_connection_hints() starts the listeners)
        return self._connect_and_exchange(transit.TransitSender(None),
                                          transit.TransitReceiver(None))

    def test_relay(self):
        return self._connect_and_exchange(
            transit.TransitSender(self.transit, no_listen=True),
            transit.TransitReceiver(self.transit, no_listen=True))
|