Merge branch '115-transit': fix transit-relay handling

* correctly handles two peers which offer different transit relays
* handles two transit relays that actually point to the same server

closes #115
This commit is contained in:
Brian Warner 2016-12-23 22:32:17 -05:00
commit d0830e709f
3 changed files with 142 additions and 32 deletions

View File

@ -1334,7 +1334,7 @@ class Transit(ServerBase, unittest.TestCase):
a1.transport.loseConnection() a1.transport.loseConnection()
@defer.inlineCallbacks @defer.inlineCallbacks
def test_impatience(self): def test_impatience_old(self):
ep = clientFromString(reactor, self.transit) ep = clientFromString(reactor, self.transit)
a1 = yield connectProtocol(ep, Accumulator()) a1 = yield connectProtocol(ep, Accumulator())
@ -1347,3 +1347,20 @@ class Transit(ServerBase, unittest.TestCase):
self.assertEqual(a1.data, exp) self.assertEqual(a1.data, exp)
a1.transport.loseConnection() a1.transport.loseConnection()
@defer.inlineCallbacks
def test_impatience_new(self):
ep = clientFromString(reactor, self.transit)
a1 = yield connectProtocol(ep, Accumulator())
token1 = b"\x00"*32
side1 = b"\x01"*8
# sending too many bytes is impatience.
a1.transport.write(b"please relay " + hexlify(token1) +
b" for side " + hexlify(side1) + b"\nNOWNOWNOW")
exp = b"impatient\n"
yield a1.waitForBytes(len(exp))
self.assertEqual(a1.data, exp)
a1.transport.loseConnection()

View File

@ -1,6 +1,7 @@
from __future__ import print_function, unicode_literals from __future__ import print_function, unicode_literals
import io import io
import gc import gc
import mock
from binascii import hexlify, unhexlify from binascii import hexlify, unhexlify
from twisted.trial import unittest from twisted.trial import unittest
from twisted.internet import defer, task, endpoints, protocol, address, error from twisted.internet import defer, task, endpoints, protocol, address, error
@ -9,6 +10,7 @@ from twisted.python import log, failure
from twisted.test import proto_helpers from twisted.test import proto_helpers
from ..errors import InternalError from ..errors import InternalError
from .. import transit from .. import transit
from ..server import transit_server
from .common import ServerBase from .common import ServerBase
from nacl.secret import SecretBox from nacl.secret import SecretBox
from nacl.exceptions import CryptoError from nacl.exceptions import CryptoError
@ -137,6 +139,17 @@ class Hints(unittest.TestCase):
ep = c._endpoint_from_hint_obj("unknown:stuff:yowza:pivlor") ep = c._endpoint_from_hint_obj("unknown:stuff:yowza:pivlor")
self.assertEqual(ep, None) self.assertEqual(ep, None)
def test_comparable(self):
h1 = transit.DirectTCPV1Hint("hostname", "port1")
h1b = transit.DirectTCPV1Hint("hostname", "port1")
h2 = transit.DirectTCPV1Hint("hostname", "port2")
r1 = transit.RelayV1Hint(tuple(sorted([h1, h2])))
r2 = transit.RelayV1Hint(tuple(sorted([h2, h1])))
r3 = transit.RelayV1Hint(tuple(sorted([h1b, h2])))
self.assertEqual(r1, r2)
self.assertEqual(r2, r3)
self.assertEqual(len(set([r1, r2, r3])), 1)
class Basic(unittest.TestCase): class Basic(unittest.TestCase):
@inlineCallbacks @inlineCallbacks
@ -163,7 +176,7 @@ class Basic(unittest.TestCase):
c.add_connection_hints([{"type": "relay-v1", c.add_connection_hints([{"type": "relay-v1",
"hints": [{"type": "unknown"}]}]) "hints": [{"type": "unknown"}]}])
self.assertEqual(c._their_direct_hints, []) self.assertEqual(c._their_direct_hints, [])
self.assertEqual(c._their_relay_hints, []) self.assertEqual(c._our_relay_hints, set())
def test_ignore_localhost_hint(self): def test_ignore_localhost_hint(self):
# this actually starts the listener # this actually starts the listener
@ -1199,14 +1212,18 @@ DIRECT_HINT = {"type": "direct-tcp-v1",
RELAY_HINT = {"type": "relay-v1", RELAY_HINT = {"type": "relay-v1",
"hints": [{"type": "direct-tcp-v1", "hints": [{"type": "direct-tcp-v1",
"hostname": "relay", "port": 1234}]} "hostname": "relay", "port": 1234}]}
UNUSABLE_HINT = {"type": "unknown"} UNRECOGNIZED_HINT = {"type": "unknown"}
UNAVAILABLE_HINT = {"type": "direct-tcp-v1", # e.g. Tor without txtorcon
"hostname": "unavailable", "port": 1234}
RELAY_HINT2 = {"type": "relay-v1", RELAY_HINT2 = {"type": "relay-v1",
"hints": [{"type": "direct-tcp-v1", "hints": [{"type": "direct-tcp-v1",
"hostname": "relay", "port": 1234}, "hostname": "relay", "port": 1234},
UNUSABLE_HINT]} UNRECOGNIZED_HINT]}
UNAVAILABLE_RELAY_HINT = {"type": "relay-v1",
"hints": [UNAVAILABLE_HINT]}
DIRECT_HINT_INTERNAL = transit.DirectTCPV1Hint("direct", 1234) DIRECT_HINT_INTERNAL = transit.DirectTCPV1Hint("direct", 1234)
RELAY_HINT_FIRST = transit.DirectTCPV1Hint("relay", 1234) RELAY_HINT_FIRST = transit.DirectTCPV1Hint("relay", 1234)
RELAY_HINT_INTERNAL = transit.RelayV1Hint([RELAY_HINT_FIRST]) RELAY_HINT_INTERNAL = transit.RelayV1Hint((RELAY_HINT_FIRST,))
class Transit(unittest.TestCase): class Transit(unittest.TestCase):
@inlineCallbacks @inlineCallbacks
@ -1216,7 +1233,7 @@ class Transit(unittest.TestCase):
s.set_transit_key(b"key") s.set_transit_key(b"key")
hints = yield s.get_connection_hints() # start the listener hints = yield s.get_connection_hints() # start the listener
del hints del hints
s.add_connection_hints([DIRECT_HINT, UNUSABLE_HINT]) s.add_connection_hints([DIRECT_HINT, UNRECOGNIZED_HINT])
connectors = [] connectors = []
def _start_connector(ep, description, is_relay=False): def _start_connector(ep, description, is_relay=False):
@ -1240,7 +1257,7 @@ class Transit(unittest.TestCase):
elif hint == RELAY_HINT_FIRST: elif hint == RELAY_HINT_FIRST:
return "relay" return "relay"
else: else:
return None return None # e.g. UNAVAILABLE_HINT
@inlineCallbacks @inlineCallbacks
def test_wait_for_relay(self): def test_wait_for_relay(self):
@ -1249,7 +1266,7 @@ class Transit(unittest.TestCase):
s.set_transit_key(b"key") s.set_transit_key(b"key")
hints = yield s.get_connection_hints() # start the listener hints = yield s.get_connection_hints() # start the listener
del hints del hints
s.add_connection_hints([DIRECT_HINT, UNUSABLE_HINT, RELAY_HINT]) s.add_connection_hints([DIRECT_HINT, UNRECOGNIZED_HINT, RELAY_HINT])
direct_connectors = [] direct_connectors = []
relay_connectors = [] relay_connectors = []
@ -1288,7 +1305,9 @@ class Transit(unittest.TestCase):
s.set_transit_key(b"key") s.set_transit_key(b"key")
hints = yield s.get_connection_hints() # start the listener hints = yield s.get_connection_hints() # start the listener
del hints del hints
s.add_connection_hints([UNUSABLE_HINT, RELAY_HINT2]) # include hints that can't be turned into an endpoint at runtime
s.add_connection_hints([UNRECOGNIZED_HINT, UNAVAILABLE_HINT,
RELAY_HINT2, UNAVAILABLE_RELAY_HINT])
direct_connectors = [] direct_connectors = []
relay_connectors = [] relay_connectors = []
@ -1320,6 +1339,65 @@ class Transit(unittest.TestCase):
relay_connectors[0].callback("winner") relay_connectors[0].callback("winner")
self.assertEqual(results, ["winner"]) self.assertEqual(results, ["winner"])
@inlineCallbacks
def test_no_contenders(self):
clock = task.Clock()
s = transit.TransitSender("", reactor=clock, no_listen=True)
s.set_transit_key(b"key")
hints = yield s.get_connection_hints() # start the listener
del hints
s.add_connection_hints([]) # no hints at all
direct_connectors = []
relay_connectors = []
s._endpoint_from_hint_obj = self._endpoint_from_hint_obj
def _start_connector(ep, description, is_relay=False):
d = defer.Deferred()
if ep == "direct":
direct_connectors.append(d)
elif ep == "relay":
relay_connectors.append(d)
else:
raise ValueError
return d
s._start_connector = _start_connector
d = s.connect()
f = self.failureResultOf(d, transit.TransitError)
self.assertEqual(str(f.value), "No contenders for connection")
class RelayHandshake(unittest.TestCase):
def old_build_relay_handshake(self, key):
token = transit.HKDF(key, 32, CTXinfo=b"transit_relay_token")
return (token, b"please relay "+hexlify(token)+b"\n")
def test_old(self):
key = b"\x00"
token, old_handshake = self.old_build_relay_handshake(key)
tc = transit_server.TransitConnection()
tc.factory = mock.Mock()
tc.factory.connection_got_token = mock.Mock()
tc.dataReceived(old_handshake[:-1])
self.assertEqual(tc.factory.connection_got_token.mock_calls, [])
tc.dataReceived(old_handshake[-1:])
self.assertEqual(tc.factory.connection_got_token.mock_calls,
[mock.call(hexlify(token), None, tc)])
def test_new(self):
c = transit.Common(None)
c.set_transit_key(b"\x00")
new_handshake = c._build_relay_handshake()
token, old_handshake = self.old_build_relay_handshake(b"\x00")
tc = transit_server.TransitConnection()
tc.factory = mock.Mock()
tc.factory.connection_got_token = mock.Mock()
tc.dataReceived(new_handshake[:-1])
self.assertEqual(tc.factory.connection_got_token.mock_calls, [])
tc.dataReceived(new_handshake[-1:])
self.assertEqual(tc.factory.connection_got_token.mock_calls,
[mock.call(hexlify(token), c._side.encode("ascii"), tc)])
class Full(ServerBase, unittest.TestCase): class Full(ServerBase, unittest.TestCase):
def doBoth(self, d1, d2): def doBoth(self, d1, d2):

View File

@ -1,6 +1,6 @@
# no unicode_literals, revisit after twisted patch # no unicode_literals, revisit after twisted patch
from __future__ import print_function, absolute_import from __future__ import print_function, absolute_import
import re, sys, time, socket import os, re, sys, time, socket
from collections import namedtuple, deque from collections import namedtuple, deque
from binascii import hexlify, unhexlify from binascii import hexlify, unhexlify
import six import six
@ -15,6 +15,7 @@ from nacl.secret import SecretBox
from hkdf import Hkdf from hkdf import Hkdf
from .errors import InternalError from .errors import InternalError
from .timing import DebugTiming from .timing import DebugTiming
from .util import bytes_to_hexstr
from . import ipaddrs from . import ipaddrs
def HKDF(skm, outlen, salt=None, CTXinfo=b""): def HKDF(skm, outlen, salt=None, CTXinfo=b""):
@ -76,9 +77,11 @@ def build_sender_handshake(key):
hexid = HKDF(key, 32, CTXinfo=b"transit_sender") hexid = HKDF(key, 32, CTXinfo=b"transit_sender")
return b"transit sender "+hexlify(hexid)+b" ready\n\n" return b"transit sender "+hexlify(hexid)+b" ready\n\n"
def build_relay_handshake(key): def build_sided_relay_handshake(key, side):
assert isinstance(side, type(u""))
assert len(side) == 8*2
token = HKDF(key, 32, CTXinfo=b"transit_relay_token") token = HKDF(key, 32, CTXinfo=b"transit_relay_token")
return b"please relay "+hexlify(token)+b"\n" return b"please relay "+hexlify(token)+b" for side "+side.encode("ascii")+b"\n"
# These namedtuples are "hint objects". The JSON-serializable dictionaries # These namedtuples are "hint objects". The JSON-serializable dictionaries
@ -92,9 +95,10 @@ def build_relay_handshake(key):
# * the rest of the connection contains transit data # * the rest of the connection contains transit data
DirectTCPV1Hint = namedtuple("DirectTCPV1Hint", ["hostname", "port"]) DirectTCPV1Hint = namedtuple("DirectTCPV1Hint", ["hostname", "port"])
TorTCPV1Hint = namedtuple("TorTCPV1Hint", ["hostname", "port"]) TorTCPV1Hint = namedtuple("TorTCPV1Hint", ["hostname", "port"])
# RelayV1Hint contains a list of DirectTCPV1Hint and TorTCPV1Hint hints. For # RelayV1Hint contains a tuple of DirectTCPV1Hint and TorTCPV1Hint hints (we
# each one, make the TCP connection, send the relay handshake, then complete # use a tuple rather than a list so they'll be hashable into a set). For each
# the rest of the V1 protocol. Only one hint per relay is useful. # one, make the TCP connection, send the relay handshake, then complete the
# rest of the V1 protocol. Only one hint per relay is useful.
RelayV1Hint = namedtuple("RelayV1Hint", ["hints"]) RelayV1Hint = namedtuple("RelayV1Hint", ["hints"])
def describe_hint_obj(hint): def describe_hint_obj(hint):
@ -575,15 +579,18 @@ class Common:
def __init__(self, transit_relay, no_listen=False, tor_manager=None, def __init__(self, transit_relay, no_listen=False, tor_manager=None,
reactor=reactor, timing=None): reactor=reactor, timing=None):
self._side = bytes_to_hexstr(os.urandom(8)) # unicode
if transit_relay: if transit_relay:
if not isinstance(transit_relay, type(u"")): if not isinstance(transit_relay, type(u"")):
raise InternalError raise InternalError
relay = RelayV1Hint(hints=[parse_hint_argv(transit_relay)]) # TODO: allow multiple hints for a single relay
relay_hint = parse_hint_argv(transit_relay)
relay = RelayV1Hint(hints=(relay_hint,))
self._transit_relays = [relay] self._transit_relays = [relay]
else: else:
self._transit_relays = [] self._transit_relays = []
self._their_direct_hints = [] # hintobjs self._their_direct_hints = [] # hintobjs
self._their_relay_hints = [] self._our_relay_hints = set(self._transit_relays)
self._tor_manager = tor_manager self._tor_manager = tor_manager
self._transit_key = None self._transit_key = None
self._no_listen = no_listen self._no_listen = no_listen
@ -703,10 +710,14 @@ class Common:
# with a set of equally-valid ways to connect to it. Treat # with a set of equally-valid ways to connect to it. Treat
# them as separate relays, instead of merging them all # them as separate relays, instead of merging them all
# together like this. # together like this.
relay_hints = []
for rhs in h.get(u"hints", []): for rhs in h.get(u"hints", []):
rh = self._parse_tcp_v1_hint(rhs) h = self._parse_tcp_v1_hint(rhs)
if rh: if h:
self._their_relay_hints.append(rh) relay_hints.append(h)
if relay_hints:
rh = RelayV1Hint(hints=tuple(sorted(relay_hints)))
self._our_relay_hints.add(rh)
else: else:
log.msg("unknown hint type: %r" % (h,)) log.msg("unknown hint type: %r" % (h,))
@ -797,13 +808,14 @@ class Common:
contenders.append(d) contenders.append(d)
relay_delay = self.RELAY_DELAY relay_delay = self.RELAY_DELAY
# Start trying the relay a few seconds after we start to try the # Start trying the relays a few seconds after we start to try the
# direct hints. The idea is to prefer direct connections, but not be # direct hints. The idea is to prefer direct connections, but not be
# afraid of using the relay when we have direct hints that don't # afraid of using a relay when we have direct hints that don't
# resolve quickly. Many direct hints will be to unused local-network # resolve quickly. Many direct hints will be to unused local-network
# IP addresses, which won't answer, and would take the full TCP # IP addresses, which won't answer, and would take the full TCP
# timeout (30s or more) to fail. # timeout (30s or more) to fail.
for hint_obj in self._their_relay_hints: for rh in self._our_relay_hints:
for hint_obj in rh.hints:
ep = self._endpoint_from_hint_obj(hint_obj) ep = self._endpoint_from_hint_obj(hint_obj)
if not ep: if not ep:
continue continue
@ -830,11 +842,14 @@ class Common:
d.addBoth(_done) d.addBoth(_done)
return d return d
def _build_relay_handshake(self):
return build_sided_relay_handshake(self._transit_key, self._side)
def _start_connector(self, ep, description, is_relay=False): def _start_connector(self, ep, description, is_relay=False):
relay_handshake = None relay_handshake = None
if is_relay: if is_relay:
assert self._transit_key assert self._transit_key
relay_handshake = build_relay_handshake(self._transit_key) relay_handshake = self._build_relay_handshake()
f = OutboundConnectionFactory(self, relay_handshake, description) f = OutboundConnectionFactory(self, relay_handshake, description)
d = ep.connect(f) d = ep.connect(f)
# fires with protocol, or ConnectError # fires with protocol, or ConnectError