Merge branch '115-transit': fix transit-relay handling
* correctly handles two peers which offer different transit relays * handles two transit relays that actually point to the same server closes #115
This commit is contained in:
commit
d0830e709f
|
@ -1334,7 +1334,7 @@ class Transit(ServerBase, unittest.TestCase):
|
||||||
a1.transport.loseConnection()
|
a1.transport.loseConnection()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_impatience(self):
|
def test_impatience_old(self):
|
||||||
ep = clientFromString(reactor, self.transit)
|
ep = clientFromString(reactor, self.transit)
|
||||||
a1 = yield connectProtocol(ep, Accumulator())
|
a1 = yield connectProtocol(ep, Accumulator())
|
||||||
|
|
||||||
|
@ -1347,3 +1347,20 @@ class Transit(ServerBase, unittest.TestCase):
|
||||||
self.assertEqual(a1.data, exp)
|
self.assertEqual(a1.data, exp)
|
||||||
|
|
||||||
a1.transport.loseConnection()
|
a1.transport.loseConnection()
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def test_impatience_new(self):
|
||||||
|
ep = clientFromString(reactor, self.transit)
|
||||||
|
a1 = yield connectProtocol(ep, Accumulator())
|
||||||
|
|
||||||
|
token1 = b"\x00"*32
|
||||||
|
side1 = b"\x01"*8
|
||||||
|
# sending too many bytes is impatience.
|
||||||
|
a1.transport.write(b"please relay " + hexlify(token1) +
|
||||||
|
b" for side " + hexlify(side1) + b"\nNOWNOWNOW")
|
||||||
|
|
||||||
|
exp = b"impatient\n"
|
||||||
|
yield a1.waitForBytes(len(exp))
|
||||||
|
self.assertEqual(a1.data, exp)
|
||||||
|
|
||||||
|
a1.transport.loseConnection()
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
from __future__ import print_function, unicode_literals
|
from __future__ import print_function, unicode_literals
|
||||||
import io
|
import io
|
||||||
import gc
|
import gc
|
||||||
|
import mock
|
||||||
from binascii import hexlify, unhexlify
|
from binascii import hexlify, unhexlify
|
||||||
from twisted.trial import unittest
|
from twisted.trial import unittest
|
||||||
from twisted.internet import defer, task, endpoints, protocol, address, error
|
from twisted.internet import defer, task, endpoints, protocol, address, error
|
||||||
|
@ -9,6 +10,7 @@ from twisted.python import log, failure
|
||||||
from twisted.test import proto_helpers
|
from twisted.test import proto_helpers
|
||||||
from ..errors import InternalError
|
from ..errors import InternalError
|
||||||
from .. import transit
|
from .. import transit
|
||||||
|
from ..server import transit_server
|
||||||
from .common import ServerBase
|
from .common import ServerBase
|
||||||
from nacl.secret import SecretBox
|
from nacl.secret import SecretBox
|
||||||
from nacl.exceptions import CryptoError
|
from nacl.exceptions import CryptoError
|
||||||
|
@ -137,6 +139,17 @@ class Hints(unittest.TestCase):
|
||||||
ep = c._endpoint_from_hint_obj("unknown:stuff:yowza:pivlor")
|
ep = c._endpoint_from_hint_obj("unknown:stuff:yowza:pivlor")
|
||||||
self.assertEqual(ep, None)
|
self.assertEqual(ep, None)
|
||||||
|
|
||||||
|
def test_comparable(self):
|
||||||
|
h1 = transit.DirectTCPV1Hint("hostname", "port1")
|
||||||
|
h1b = transit.DirectTCPV1Hint("hostname", "port1")
|
||||||
|
h2 = transit.DirectTCPV1Hint("hostname", "port2")
|
||||||
|
r1 = transit.RelayV1Hint(tuple(sorted([h1, h2])))
|
||||||
|
r2 = transit.RelayV1Hint(tuple(sorted([h2, h1])))
|
||||||
|
r3 = transit.RelayV1Hint(tuple(sorted([h1b, h2])))
|
||||||
|
self.assertEqual(r1, r2)
|
||||||
|
self.assertEqual(r2, r3)
|
||||||
|
self.assertEqual(len(set([r1, r2, r3])), 1)
|
||||||
|
|
||||||
|
|
||||||
class Basic(unittest.TestCase):
|
class Basic(unittest.TestCase):
|
||||||
@inlineCallbacks
|
@inlineCallbacks
|
||||||
|
@ -163,7 +176,7 @@ class Basic(unittest.TestCase):
|
||||||
c.add_connection_hints([{"type": "relay-v1",
|
c.add_connection_hints([{"type": "relay-v1",
|
||||||
"hints": [{"type": "unknown"}]}])
|
"hints": [{"type": "unknown"}]}])
|
||||||
self.assertEqual(c._their_direct_hints, [])
|
self.assertEqual(c._their_direct_hints, [])
|
||||||
self.assertEqual(c._their_relay_hints, [])
|
self.assertEqual(c._our_relay_hints, set())
|
||||||
|
|
||||||
def test_ignore_localhost_hint(self):
|
def test_ignore_localhost_hint(self):
|
||||||
# this actually starts the listener
|
# this actually starts the listener
|
||||||
|
@ -1199,14 +1212,18 @@ DIRECT_HINT = {"type": "direct-tcp-v1",
|
||||||
RELAY_HINT = {"type": "relay-v1",
|
RELAY_HINT = {"type": "relay-v1",
|
||||||
"hints": [{"type": "direct-tcp-v1",
|
"hints": [{"type": "direct-tcp-v1",
|
||||||
"hostname": "relay", "port": 1234}]}
|
"hostname": "relay", "port": 1234}]}
|
||||||
UNUSABLE_HINT = {"type": "unknown"}
|
UNRECOGNIZED_HINT = {"type": "unknown"}
|
||||||
|
UNAVAILABLE_HINT = {"type": "direct-tcp-v1", # e.g. Tor without txtorcon
|
||||||
|
"hostname": "unavailable", "port": 1234}
|
||||||
RELAY_HINT2 = {"type": "relay-v1",
|
RELAY_HINT2 = {"type": "relay-v1",
|
||||||
"hints": [{"type": "direct-tcp-v1",
|
"hints": [{"type": "direct-tcp-v1",
|
||||||
"hostname": "relay", "port": 1234},
|
"hostname": "relay", "port": 1234},
|
||||||
UNUSABLE_HINT]}
|
UNRECOGNIZED_HINT]}
|
||||||
|
UNAVAILABLE_RELAY_HINT = {"type": "relay-v1",
|
||||||
|
"hints": [UNAVAILABLE_HINT]}
|
||||||
DIRECT_HINT_INTERNAL = transit.DirectTCPV1Hint("direct", 1234)
|
DIRECT_HINT_INTERNAL = transit.DirectTCPV1Hint("direct", 1234)
|
||||||
RELAY_HINT_FIRST = transit.DirectTCPV1Hint("relay", 1234)
|
RELAY_HINT_FIRST = transit.DirectTCPV1Hint("relay", 1234)
|
||||||
RELAY_HINT_INTERNAL = transit.RelayV1Hint([RELAY_HINT_FIRST])
|
RELAY_HINT_INTERNAL = transit.RelayV1Hint((RELAY_HINT_FIRST,))
|
||||||
|
|
||||||
class Transit(unittest.TestCase):
|
class Transit(unittest.TestCase):
|
||||||
@inlineCallbacks
|
@inlineCallbacks
|
||||||
|
@ -1216,7 +1233,7 @@ class Transit(unittest.TestCase):
|
||||||
s.set_transit_key(b"key")
|
s.set_transit_key(b"key")
|
||||||
hints = yield s.get_connection_hints() # start the listener
|
hints = yield s.get_connection_hints() # start the listener
|
||||||
del hints
|
del hints
|
||||||
s.add_connection_hints([DIRECT_HINT, UNUSABLE_HINT])
|
s.add_connection_hints([DIRECT_HINT, UNRECOGNIZED_HINT])
|
||||||
|
|
||||||
connectors = []
|
connectors = []
|
||||||
def _start_connector(ep, description, is_relay=False):
|
def _start_connector(ep, description, is_relay=False):
|
||||||
|
@ -1240,7 +1257,7 @@ class Transit(unittest.TestCase):
|
||||||
elif hint == RELAY_HINT_FIRST:
|
elif hint == RELAY_HINT_FIRST:
|
||||||
return "relay"
|
return "relay"
|
||||||
else:
|
else:
|
||||||
return None
|
return None # e.g. UNAVAILABLE_HINT
|
||||||
|
|
||||||
@inlineCallbacks
|
@inlineCallbacks
|
||||||
def test_wait_for_relay(self):
|
def test_wait_for_relay(self):
|
||||||
|
@ -1249,7 +1266,7 @@ class Transit(unittest.TestCase):
|
||||||
s.set_transit_key(b"key")
|
s.set_transit_key(b"key")
|
||||||
hints = yield s.get_connection_hints() # start the listener
|
hints = yield s.get_connection_hints() # start the listener
|
||||||
del hints
|
del hints
|
||||||
s.add_connection_hints([DIRECT_HINT, UNUSABLE_HINT, RELAY_HINT])
|
s.add_connection_hints([DIRECT_HINT, UNRECOGNIZED_HINT, RELAY_HINT])
|
||||||
|
|
||||||
direct_connectors = []
|
direct_connectors = []
|
||||||
relay_connectors = []
|
relay_connectors = []
|
||||||
|
@ -1288,7 +1305,9 @@ class Transit(unittest.TestCase):
|
||||||
s.set_transit_key(b"key")
|
s.set_transit_key(b"key")
|
||||||
hints = yield s.get_connection_hints() # start the listener
|
hints = yield s.get_connection_hints() # start the listener
|
||||||
del hints
|
del hints
|
||||||
s.add_connection_hints([UNUSABLE_HINT, RELAY_HINT2])
|
# include hints that can't be turned into an endpoint at runtime
|
||||||
|
s.add_connection_hints([UNRECOGNIZED_HINT, UNAVAILABLE_HINT,
|
||||||
|
RELAY_HINT2, UNAVAILABLE_RELAY_HINT])
|
||||||
|
|
||||||
direct_connectors = []
|
direct_connectors = []
|
||||||
relay_connectors = []
|
relay_connectors = []
|
||||||
|
@ -1320,6 +1339,65 @@ class Transit(unittest.TestCase):
|
||||||
relay_connectors[0].callback("winner")
|
relay_connectors[0].callback("winner")
|
||||||
self.assertEqual(results, ["winner"])
|
self.assertEqual(results, ["winner"])
|
||||||
|
|
||||||
|
@inlineCallbacks
|
||||||
|
def test_no_contenders(self):
|
||||||
|
clock = task.Clock()
|
||||||
|
s = transit.TransitSender("", reactor=clock, no_listen=True)
|
||||||
|
s.set_transit_key(b"key")
|
||||||
|
hints = yield s.get_connection_hints() # start the listener
|
||||||
|
del hints
|
||||||
|
s.add_connection_hints([]) # no hints at all
|
||||||
|
|
||||||
|
direct_connectors = []
|
||||||
|
relay_connectors = []
|
||||||
|
s._endpoint_from_hint_obj = self._endpoint_from_hint_obj
|
||||||
|
def _start_connector(ep, description, is_relay=False):
|
||||||
|
d = defer.Deferred()
|
||||||
|
if ep == "direct":
|
||||||
|
direct_connectors.append(d)
|
||||||
|
elif ep == "relay":
|
||||||
|
relay_connectors.append(d)
|
||||||
|
else:
|
||||||
|
raise ValueError
|
||||||
|
return d
|
||||||
|
s._start_connector = _start_connector
|
||||||
|
|
||||||
|
d = s.connect()
|
||||||
|
f = self.failureResultOf(d, transit.TransitError)
|
||||||
|
self.assertEqual(str(f.value), "No contenders for connection")
|
||||||
|
|
||||||
|
class RelayHandshake(unittest.TestCase):
|
||||||
|
def old_build_relay_handshake(self, key):
|
||||||
|
token = transit.HKDF(key, 32, CTXinfo=b"transit_relay_token")
|
||||||
|
return (token, b"please relay "+hexlify(token)+b"\n")
|
||||||
|
|
||||||
|
def test_old(self):
|
||||||
|
key = b"\x00"
|
||||||
|
token, old_handshake = self.old_build_relay_handshake(key)
|
||||||
|
tc = transit_server.TransitConnection()
|
||||||
|
tc.factory = mock.Mock()
|
||||||
|
tc.factory.connection_got_token = mock.Mock()
|
||||||
|
tc.dataReceived(old_handshake[:-1])
|
||||||
|
self.assertEqual(tc.factory.connection_got_token.mock_calls, [])
|
||||||
|
tc.dataReceived(old_handshake[-1:])
|
||||||
|
self.assertEqual(tc.factory.connection_got_token.mock_calls,
|
||||||
|
[mock.call(hexlify(token), None, tc)])
|
||||||
|
|
||||||
|
def test_new(self):
|
||||||
|
c = transit.Common(None)
|
||||||
|
c.set_transit_key(b"\x00")
|
||||||
|
new_handshake = c._build_relay_handshake()
|
||||||
|
token, old_handshake = self.old_build_relay_handshake(b"\x00")
|
||||||
|
|
||||||
|
tc = transit_server.TransitConnection()
|
||||||
|
tc.factory = mock.Mock()
|
||||||
|
tc.factory.connection_got_token = mock.Mock()
|
||||||
|
tc.dataReceived(new_handshake[:-1])
|
||||||
|
self.assertEqual(tc.factory.connection_got_token.mock_calls, [])
|
||||||
|
tc.dataReceived(new_handshake[-1:])
|
||||||
|
self.assertEqual(tc.factory.connection_got_token.mock_calls,
|
||||||
|
[mock.call(hexlify(token), c._side.encode("ascii"), tc)])
|
||||||
|
|
||||||
|
|
||||||
class Full(ServerBase, unittest.TestCase):
|
class Full(ServerBase, unittest.TestCase):
|
||||||
def doBoth(self, d1, d2):
|
def doBoth(self, d1, d2):
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
# no unicode_literals, revisit after twisted patch
|
# no unicode_literals, revisit after twisted patch
|
||||||
from __future__ import print_function, absolute_import
|
from __future__ import print_function, absolute_import
|
||||||
import re, sys, time, socket
|
import os, re, sys, time, socket
|
||||||
from collections import namedtuple, deque
|
from collections import namedtuple, deque
|
||||||
from binascii import hexlify, unhexlify
|
from binascii import hexlify, unhexlify
|
||||||
import six
|
import six
|
||||||
|
@ -15,6 +15,7 @@ from nacl.secret import SecretBox
|
||||||
from hkdf import Hkdf
|
from hkdf import Hkdf
|
||||||
from .errors import InternalError
|
from .errors import InternalError
|
||||||
from .timing import DebugTiming
|
from .timing import DebugTiming
|
||||||
|
from .util import bytes_to_hexstr
|
||||||
from . import ipaddrs
|
from . import ipaddrs
|
||||||
|
|
||||||
def HKDF(skm, outlen, salt=None, CTXinfo=b""):
|
def HKDF(skm, outlen, salt=None, CTXinfo=b""):
|
||||||
|
@ -76,9 +77,11 @@ def build_sender_handshake(key):
|
||||||
hexid = HKDF(key, 32, CTXinfo=b"transit_sender")
|
hexid = HKDF(key, 32, CTXinfo=b"transit_sender")
|
||||||
return b"transit sender "+hexlify(hexid)+b" ready\n\n"
|
return b"transit sender "+hexlify(hexid)+b" ready\n\n"
|
||||||
|
|
||||||
def build_relay_handshake(key):
|
def build_sided_relay_handshake(key, side):
|
||||||
|
assert isinstance(side, type(u""))
|
||||||
|
assert len(side) == 8*2
|
||||||
token = HKDF(key, 32, CTXinfo=b"transit_relay_token")
|
token = HKDF(key, 32, CTXinfo=b"transit_relay_token")
|
||||||
return b"please relay "+hexlify(token)+b"\n"
|
return b"please relay "+hexlify(token)+b" for side "+side.encode("ascii")+b"\n"
|
||||||
|
|
||||||
|
|
||||||
# These namedtuples are "hint objects". The JSON-serializable dictionaries
|
# These namedtuples are "hint objects". The JSON-serializable dictionaries
|
||||||
|
@ -92,9 +95,10 @@ def build_relay_handshake(key):
|
||||||
# * the rest of the connection contains transit data
|
# * the rest of the connection contains transit data
|
||||||
DirectTCPV1Hint = namedtuple("DirectTCPV1Hint", ["hostname", "port"])
|
DirectTCPV1Hint = namedtuple("DirectTCPV1Hint", ["hostname", "port"])
|
||||||
TorTCPV1Hint = namedtuple("TorTCPV1Hint", ["hostname", "port"])
|
TorTCPV1Hint = namedtuple("TorTCPV1Hint", ["hostname", "port"])
|
||||||
# RelayV1Hint contains a list of DirectTCPV1Hint and TorTCPV1Hint hints. For
|
# RelayV1Hint contains a tuple of DirectTCPV1Hint and TorTCPV1Hint hints (we
|
||||||
# each one, make the TCP connection, send the relay handshake, then complete
|
# use a tuple rather than a list so they'll be hashable into a set). For each
|
||||||
# the rest of the V1 protocol. Only one hint per relay is useful.
|
# one, make the TCP connection, send the relay handshake, then complete the
|
||||||
|
# rest of the V1 protocol. Only one hint per relay is useful.
|
||||||
RelayV1Hint = namedtuple("RelayV1Hint", ["hints"])
|
RelayV1Hint = namedtuple("RelayV1Hint", ["hints"])
|
||||||
|
|
||||||
def describe_hint_obj(hint):
|
def describe_hint_obj(hint):
|
||||||
|
@ -575,15 +579,18 @@ class Common:
|
||||||
|
|
||||||
def __init__(self, transit_relay, no_listen=False, tor_manager=None,
|
def __init__(self, transit_relay, no_listen=False, tor_manager=None,
|
||||||
reactor=reactor, timing=None):
|
reactor=reactor, timing=None):
|
||||||
|
self._side = bytes_to_hexstr(os.urandom(8)) # unicode
|
||||||
if transit_relay:
|
if transit_relay:
|
||||||
if not isinstance(transit_relay, type(u"")):
|
if not isinstance(transit_relay, type(u"")):
|
||||||
raise InternalError
|
raise InternalError
|
||||||
relay = RelayV1Hint(hints=[parse_hint_argv(transit_relay)])
|
# TODO: allow multiple hints for a single relay
|
||||||
|
relay_hint = parse_hint_argv(transit_relay)
|
||||||
|
relay = RelayV1Hint(hints=(relay_hint,))
|
||||||
self._transit_relays = [relay]
|
self._transit_relays = [relay]
|
||||||
else:
|
else:
|
||||||
self._transit_relays = []
|
self._transit_relays = []
|
||||||
self._their_direct_hints = [] # hintobjs
|
self._their_direct_hints = [] # hintobjs
|
||||||
self._their_relay_hints = []
|
self._our_relay_hints = set(self._transit_relays)
|
||||||
self._tor_manager = tor_manager
|
self._tor_manager = tor_manager
|
||||||
self._transit_key = None
|
self._transit_key = None
|
||||||
self._no_listen = no_listen
|
self._no_listen = no_listen
|
||||||
|
@ -703,10 +710,14 @@ class Common:
|
||||||
# with a set of equally-valid ways to connect to it. Treat
|
# with a set of equally-valid ways to connect to it. Treat
|
||||||
# them as separate relays, instead of merging them all
|
# them as separate relays, instead of merging them all
|
||||||
# together like this.
|
# together like this.
|
||||||
|
relay_hints = []
|
||||||
for rhs in h.get(u"hints", []):
|
for rhs in h.get(u"hints", []):
|
||||||
rh = self._parse_tcp_v1_hint(rhs)
|
h = self._parse_tcp_v1_hint(rhs)
|
||||||
if rh:
|
if h:
|
||||||
self._their_relay_hints.append(rh)
|
relay_hints.append(h)
|
||||||
|
if relay_hints:
|
||||||
|
rh = RelayV1Hint(hints=tuple(sorted(relay_hints)))
|
||||||
|
self._our_relay_hints.add(rh)
|
||||||
else:
|
else:
|
||||||
log.msg("unknown hint type: %r" % (h,))
|
log.msg("unknown hint type: %r" % (h,))
|
||||||
|
|
||||||
|
@ -797,13 +808,14 @@ class Common:
|
||||||
contenders.append(d)
|
contenders.append(d)
|
||||||
relay_delay = self.RELAY_DELAY
|
relay_delay = self.RELAY_DELAY
|
||||||
|
|
||||||
# Start trying the relay a few seconds after we start to try the
|
# Start trying the relays a few seconds after we start to try the
|
||||||
# direct hints. The idea is to prefer direct connections, but not be
|
# direct hints. The idea is to prefer direct connections, but not be
|
||||||
# afraid of using the relay when we have direct hints that don't
|
# afraid of using a relay when we have direct hints that don't
|
||||||
# resolve quickly. Many direct hints will be to unused local-network
|
# resolve quickly. Many direct hints will be to unused local-network
|
||||||
# IP addresses, which won't answer, and would take the full TCP
|
# IP addresses, which won't answer, and would take the full TCP
|
||||||
# timeout (30s or more) to fail.
|
# timeout (30s or more) to fail.
|
||||||
for hint_obj in self._their_relay_hints:
|
for rh in self._our_relay_hints:
|
||||||
|
for hint_obj in rh.hints:
|
||||||
ep = self._endpoint_from_hint_obj(hint_obj)
|
ep = self._endpoint_from_hint_obj(hint_obj)
|
||||||
if not ep:
|
if not ep:
|
||||||
continue
|
continue
|
||||||
|
@ -830,11 +842,14 @@ class Common:
|
||||||
d.addBoth(_done)
|
d.addBoth(_done)
|
||||||
return d
|
return d
|
||||||
|
|
||||||
|
def _build_relay_handshake(self):
|
||||||
|
return build_sided_relay_handshake(self._transit_key, self._side)
|
||||||
|
|
||||||
def _start_connector(self, ep, description, is_relay=False):
|
def _start_connector(self, ep, description, is_relay=False):
|
||||||
relay_handshake = None
|
relay_handshake = None
|
||||||
if is_relay:
|
if is_relay:
|
||||||
assert self._transit_key
|
assert self._transit_key
|
||||||
relay_handshake = build_relay_handshake(self._transit_key)
|
relay_handshake = self._build_relay_handshake()
|
||||||
f = OutboundConnectionFactory(self, relay_handshake, description)
|
f = OutboundConnectionFactory(self, relay_handshake, description)
|
||||||
d = ep.connect(f)
|
d = ep.connect(f)
|
||||||
# fires with protocol, or ConnectError
|
# fires with protocol, or ConnectError
|
||||||
|
|
Loading…
Reference in New Issue
Block a user