diff --git a/src/wormhole/test/test_server.py b/src/wormhole/test/test_server.py index 2d67ab2..cf1a5f7 100644 --- a/src/wormhole/test/test_server.py +++ b/src/wormhole/test/test_server.py @@ -1334,7 +1334,7 @@ class Transit(ServerBase, unittest.TestCase): a1.transport.loseConnection() @defer.inlineCallbacks - def test_impatience(self): + def test_impatience_old(self): ep = clientFromString(reactor, self.transit) a1 = yield connectProtocol(ep, Accumulator()) @@ -1347,3 +1347,20 @@ class Transit(ServerBase, unittest.TestCase): self.assertEqual(a1.data, exp) a1.transport.loseConnection() + + @defer.inlineCallbacks + def test_impatience_new(self): + ep = clientFromString(reactor, self.transit) + a1 = yield connectProtocol(ep, Accumulator()) + + token1 = b"\x00"*32 + side1 = b"\x01"*8 + # sending too many bytes is impatience. + a1.transport.write(b"please relay " + hexlify(token1) + + b" for side " + hexlify(side1) + b"\nNOWNOWNOW") + + exp = b"impatient\n" + yield a1.waitForBytes(len(exp)) + self.assertEqual(a1.data, exp) + + a1.transport.loseConnection() diff --git a/src/wormhole/test/test_transit.py b/src/wormhole/test/test_transit.py index 9f76e63..7f63eac 100644 --- a/src/wormhole/test/test_transit.py +++ b/src/wormhole/test/test_transit.py @@ -1,6 +1,7 @@ from __future__ import print_function, unicode_literals import io import gc +import mock from binascii import hexlify, unhexlify from twisted.trial import unittest from twisted.internet import defer, task, endpoints, protocol, address, error @@ -9,6 +10,7 @@ from twisted.python import log, failure from twisted.test import proto_helpers from ..errors import InternalError from .. import transit +from ..server import transit_server from .common import ServerBase from nacl.secret import SecretBox from nacl.exceptions import CryptoError @@ -137,6 +139,17 @@ class Hints(unittest.TestCase): ep = c._endpoint_from_hint_obj("unknown:stuff:yowza:pivlor") self.assertEqual(ep, None) + def test_comparable(self): + h1 = transit.DirectTCPV1Hint("hostname", "port1") + h1b = transit.DirectTCPV1Hint("hostname", "port1") + h2 = transit.DirectTCPV1Hint("hostname", "port2") + r1 = transit.RelayV1Hint(tuple(sorted([h1, h2]))) + r2 = transit.RelayV1Hint(tuple(sorted([h2, h1]))) + r3 = transit.RelayV1Hint(tuple(sorted([h1b, h2]))) + self.assertEqual(r1, r2) + self.assertEqual(r2, r3) + self.assertEqual(len(set([r1, r2, r3])), 1) + class Basic(unittest.TestCase): @inlineCallbacks @@ -163,7 +176,7 @@ class Basic(unittest.TestCase): c.add_connection_hints([{"type": "relay-v1", "hints": [{"type": "unknown"}]}]) self.assertEqual(c._their_direct_hints, []) - self.assertEqual(c._their_relay_hints, []) + self.assertEqual(c._our_relay_hints, set()) def test_ignore_localhost_hint(self): # this actually starts the listener @@ -1199,14 +1212,18 @@ DIRECT_HINT = {"type": "direct-tcp-v1", RELAY_HINT = {"type": "relay-v1", "hints": [{"type": "direct-tcp-v1", "hostname": "relay", "port": 1234}]} -UNUSABLE_HINT = {"type": "unknown"} +UNRECOGNIZED_HINT = {"type": "unknown"} +UNAVAILABLE_HINT = {"type": "direct-tcp-v1", # e.g. Tor without txtorcon + "hostname": "unavailable", "port": 1234} RELAY_HINT2 = {"type": "relay-v1", "hints": [{"type": "direct-tcp-v1", "hostname": "relay", "port": 1234}, - UNUSABLE_HINT]} + UNRECOGNIZED_HINT]} +UNAVAILABLE_RELAY_HINT = {"type": "relay-v1", + "hints": [UNAVAILABLE_HINT]} DIRECT_HINT_INTERNAL = transit.DirectTCPV1Hint("direct", 1234) RELAY_HINT_FIRST = transit.DirectTCPV1Hint("relay", 1234) -RELAY_HINT_INTERNAL = transit.RelayV1Hint([RELAY_HINT_FIRST]) +RELAY_HINT_INTERNAL = transit.RelayV1Hint((RELAY_HINT_FIRST,)) class Transit(unittest.TestCase): @inlineCallbacks @@ -1216,7 +1233,7 @@ class Transit(unittest.TestCase): s.set_transit_key(b"key") hints = yield s.get_connection_hints() # start the listener del hints - s.add_connection_hints([DIRECT_HINT, UNUSABLE_HINT]) + s.add_connection_hints([DIRECT_HINT, UNRECOGNIZED_HINT]) connectors = [] def _start_connector(ep, description, is_relay=False): @@ -1240,7 +1257,7 @@ class Transit(unittest.TestCase): elif hint == RELAY_HINT_FIRST: return "relay" else: - return None + return None # e.g. UNAVAILABLE_HINT @inlineCallbacks def test_wait_for_relay(self): @@ -1249,7 +1266,7 @@ class Transit(unittest.TestCase): s.set_transit_key(b"key") hints = yield s.get_connection_hints() # start the listener del hints - s.add_connection_hints([DIRECT_HINT, UNUSABLE_HINT, RELAY_HINT]) + s.add_connection_hints([DIRECT_HINT, UNRECOGNIZED_HINT, RELAY_HINT]) direct_connectors = [] relay_connectors = [] @@ -1288,7 +1305,9 @@ class Transit(unittest.TestCase): s.set_transit_key(b"key") hints = yield s.get_connection_hints() # start the listener del hints - s.add_connection_hints([UNUSABLE_HINT, RELAY_HINT2]) + # include hints that can't be turned into an endpoint at runtime + s.add_connection_hints([UNRECOGNIZED_HINT, UNAVAILABLE_HINT, + RELAY_HINT2, UNAVAILABLE_RELAY_HINT]) direct_connectors = [] relay_connectors = [] @@ -1320,6 +1339,65 @@ class Transit(unittest.TestCase): relay_connectors[0].callback("winner") self.assertEqual(results, ["winner"]) + @inlineCallbacks + def test_no_contenders(self): + clock = task.Clock() + s = transit.TransitSender("", reactor=clock, no_listen=True) + s.set_transit_key(b"key") + hints = yield s.get_connection_hints() # start the listener + del hints + s.add_connection_hints([]) # no hints at all + + direct_connectors = [] + relay_connectors = [] + s._endpoint_from_hint_obj = self._endpoint_from_hint_obj + def _start_connector(ep, description, is_relay=False): + d = defer.Deferred() + if ep == "direct": + direct_connectors.append(d) + elif ep == "relay": + relay_connectors.append(d) + else: + raise ValueError + return d + s._start_connector = _start_connector + + d = s.connect() + f = self.failureResultOf(d, transit.TransitError) + self.assertEqual(str(f.value), "No contenders for connection") + +class RelayHandshake(unittest.TestCase): + def old_build_relay_handshake(self, key): + token = transit.HKDF(key, 32, CTXinfo=b"transit_relay_token") + return (token, b"please relay "+hexlify(token)+b"\n") + + def test_old(self): + key = b"\x00" + token, old_handshake = self.old_build_relay_handshake(key) + tc = transit_server.TransitConnection() + tc.factory = mock.Mock() + tc.factory.connection_got_token = mock.Mock() + tc.dataReceived(old_handshake[:-1]) + self.assertEqual(tc.factory.connection_got_token.mock_calls, []) + tc.dataReceived(old_handshake[-1:]) + self.assertEqual(tc.factory.connection_got_token.mock_calls, + [mock.call(hexlify(token), None, tc)]) + + def test_new(self): + c = transit.Common(None) + c.set_transit_key(b"\x00") + new_handshake = c._build_relay_handshake() + token, old_handshake = self.old_build_relay_handshake(b"\x00") + + tc = transit_server.TransitConnection() + tc.factory = mock.Mock() + tc.factory.connection_got_token = mock.Mock() + tc.dataReceived(new_handshake[:-1]) + self.assertEqual(tc.factory.connection_got_token.mock_calls, []) + tc.dataReceived(new_handshake[-1:]) + self.assertEqual(tc.factory.connection_got_token.mock_calls, + [mock.call(hexlify(token), c._side.encode("ascii"), tc)]) + class Full(ServerBase, unittest.TestCase): def doBoth(self, d1, d2): diff --git a/src/wormhole/transit.py b/src/wormhole/transit.py index 308e92f..b9d887d 100644 --- a/src/wormhole/transit.py +++ b/src/wormhole/transit.py @@ -1,6 +1,6 @@ # no unicode_literals, revisit after twisted patch from __future__ import print_function, absolute_import -import re, sys, time, socket +import os, re, sys, time, socket from collections import namedtuple, deque from binascii import hexlify, unhexlify import six @@ -15,6 +15,7 @@ from nacl.secret import SecretBox from hkdf import Hkdf from .errors import InternalError from .timing import DebugTiming +from .util import bytes_to_hexstr from . import ipaddrs def HKDF(skm, outlen, salt=None, CTXinfo=b""): @@ -76,9 +77,11 @@ def build_sender_handshake(key): hexid = HKDF(key, 32, CTXinfo=b"transit_sender") return b"transit sender "+hexlify(hexid)+b" ready\n\n" -def build_relay_handshake(key): +def build_sided_relay_handshake(key, side): + assert isinstance(side, type(u"")) + assert len(side) == 8*2 token = HKDF(key, 32, CTXinfo=b"transit_relay_token") - return b"please relay "+hexlify(token)+b"\n" + return b"please relay "+hexlify(token)+b" for side "+side.encode("ascii")+b"\n" # These namedtuples are "hint objects". The JSON-serializable dictionaries @@ -92,9 +95,10 @@ def build_relay_handshake(key): # * the rest of the connection contains transit data DirectTCPV1Hint = namedtuple("DirectTCPV1Hint", ["hostname", "port"]) TorTCPV1Hint = namedtuple("TorTCPV1Hint", ["hostname", "port"]) -# RelayV1Hint contains a list of DirectTCPV1Hint and TorTCPV1Hint hints. For -# each one, make the TCP connection, send the relay handshake, then complete -# the rest of the V1 protocol. Only one hint per relay is useful. +# RelayV1Hint contains a tuple of DirectTCPV1Hint and TorTCPV1Hint hints (we +# use a tuple rather than a list so they'll be hashable into a set). For each +# one, make the TCP connection, send the relay handshake, then complete the +# rest of the V1 protocol. Only one hint per relay is useful. RelayV1Hint = namedtuple("RelayV1Hint", ["hints"]) def describe_hint_obj(hint): @@ -575,15 +579,18 @@ class Common: def __init__(self, transit_relay, no_listen=False, tor_manager=None, reactor=reactor, timing=None): + self._side = bytes_to_hexstr(os.urandom(8)) # unicode if transit_relay: if not isinstance(transit_relay, type(u"")): raise InternalError - relay = RelayV1Hint(hints=[parse_hint_argv(transit_relay)]) + # TODO: allow multiple hints for a single relay + relay_hint = parse_hint_argv(transit_relay) + relay = RelayV1Hint(hints=(relay_hint,)) self._transit_relays = [relay] else: self._transit_relays = [] self._their_direct_hints = [] # hintobjs - self._their_relay_hints = [] + self._our_relay_hints = set(self._transit_relays) self._tor_manager = tor_manager self._transit_key = None self._no_listen = no_listen @@ -703,10 +710,14 @@ class Common: # with a set of equally-valid ways to connect to it. Treat # them as separate relays, instead of merging them all # together like this. + relay_hints = [] for rhs in h.get(u"hints", []): - rh = self._parse_tcp_v1_hint(rhs) - if rh: - self._their_relay_hints.append(rh) + h = self._parse_tcp_v1_hint(rhs) + if h: + relay_hints.append(h) + if relay_hints: + rh = RelayV1Hint(hints=tuple(sorted(relay_hints))) + self._our_relay_hints.add(rh) else: log.msg("unknown hint type: %r" % (h,)) @@ -797,21 +808,22 @@ class Common: contenders.append(d) relay_delay = self.RELAY_DELAY - # Start trying the relay a few seconds after we start to try the + # Start trying the relays a few seconds after we start to try the # direct hints. The idea is to prefer direct connections, but not be - # afraid of using the relay when we have direct hints that don't + # afraid of using a relay when we have direct hints that don't # resolve quickly. Many direct hints will be to unused local-network # IP addresses, which won't answer, and would take the full TCP # timeout (30s or more) to fail. - for hint_obj in self._their_relay_hints: - ep = self._endpoint_from_hint_obj(hint_obj) - if not ep: - continue - description = "->relay:%s" % describe_hint_obj(hint_obj) - d = task.deferLater(self._reactor, relay_delay, - self._start_connector, ep, description, - is_relay=True) - contenders.append(d) + for rh in self._our_relay_hints: + for hint_obj in rh.hints: + ep = self._endpoint_from_hint_obj(hint_obj) + if not ep: + continue + description = "->relay:%s" % describe_hint_obj(hint_obj) + d = task.deferLater(self._reactor, relay_delay, + self._start_connector, ep, description, + is_relay=True) + contenders.append(d) if not contenders: raise TransitError("No contenders for connection") @@ -830,11 +842,14 @@ class Common: d.addBoth(_done) return d + def _build_relay_handshake(self): + return build_sided_relay_handshake(self._transit_key, self._side) + def _start_connector(self, ep, description, is_relay=False): relay_handshake = None if is_relay: assert self._transit_key - relay_handshake = build_relay_handshake(self._transit_key) + relay_handshake = self._build_relay_handshake() f = OutboundConnectionFactory(self, relay_handshake, description) d = ep.connect(f) # fires with protocol, or ConnectError