From 2f4e4d30318261548ddf4b7009eba17ffde3b5dd Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Fri, 21 Dec 2018 23:12:17 -0500 Subject: [PATCH] factor out describe_hint_obj and endpoint_from_hint_obj --- src/wormhole/_dilation/connector.py | 34 ++------ src/wormhole/_hints.py | 26 ++++++ src/wormhole/test/test_transit.py | 122 +++++++++++++++------------- src/wormhole/transit.py | 42 ++-------- 4 files changed, 104 insertions(+), 120 deletions(-) diff --git a/src/wormhole/_dilation/connector.py b/src/wormhole/_dilation/connector.py index 8d7f4a5..f472ea9 100644 --- a/src/wormhole/_dilation/connector.py +++ b/src/wormhole/_dilation/connector.py @@ -8,7 +8,7 @@ from automat import MethodicalMachine from zope.interface import implementer from twisted.internet.task import deferLater from twisted.internet.defer import DeferredList -from twisted.internet.endpoints import HostnameEndpoint, serverFromString +from twisted.internet.endpoints import serverFromString from twisted.internet.protocol import ClientFactory, ServerFactory from twisted.python import log from hkdf import Hkdf @@ -19,18 +19,8 @@ from ..observer import EmptyableSet from .connection import DilatedConnectionProtocol, KCM from .roles import LEADER -from .._hints import parse_hint_argv, DirectTCPV1Hint, TorTCPV1Hint, RelayV1Hint - -def describe_hint_obj(hint, relay, tor): - prefix = "tor->" if tor else "->" - if relay: - prefix = prefix + "relay:" - if isinstance(hint, DirectTCPV1Hint): - return prefix + "tcp:%s:%d" % (hint.hostname, hint.port) - elif isinstance(hint, TorTCPV1Hint): - return prefix + "tor:%s:%d" % (hint.hostname, hint.port) - else: - return prefix + str(hint) +from .._hints import (DirectTCPV1Hint, TorTCPV1Hint, RelayV1Hint, + parse_hint_argv, describe_hint_obj, endpoint_from_hint_obj) def parse_tcp_v1_hint(hint): # hint_struct -> hint_obj @@ -340,7 +330,7 @@ class Connector(object): for h in direct[p]: if isinstance(h, TorTCPV1Hint) and not self._tor: continue - ep = self._endpoint_from_hint_obj(h) + ep = endpoint_from_hint_obj(h, self._tor, self._reactor) desc = describe_hint_obj(h, False, self._tor) d = deferLater(self._reactor, delay, self._connect, ep, desc, is_relay=False) @@ -376,7 +366,7 @@ class Connector(object): for p in priorities: for r in relays[p]: for h in r.hints: - ep = self._endpoint_from_hint_obj(h) + ep = endpoint_from_hint_obj(h, self._tor, self._reactor) desc = describe_hint_obj(h, True, self._tor) d = deferLater(self._reactor, delay, self._connect, ep, desc, is_relay=True) @@ -405,20 +395,6 @@ class Connector(object): d.addCallback(_connected) return d - def _endpoint_from_hint_obj(self, hint): - if self._tor: - if isinstance(hint, (DirectTCPV1Hint, TorTCPV1Hint)): - # this Tor object will throw ValueError for non-public IPv4 - # addresses and any IPv6 address - try: - return self._tor.stream_via(hint.hostname, hint.port) - except ValueError: - return None - return None - if isinstance(hint, DirectTCPV1Hint): - return HostnameEndpoint(self._reactor, hint.hostname, hint.port) - return None - # Connection selection. All instances of DilatedConnectionProtocol which # look viable get passed into our add_contender() method. diff --git a/src/wormhole/_hints.py b/src/wormhole/_hints.py index 8955e97..9b79542 100644 --- a/src/wormhole/_hints.py +++ b/src/wormhole/_hints.py @@ -2,6 +2,7 @@ from __future__ import print_function, unicode_literals import sys import re from collections import namedtuple +from twisted.internet.endpoints import HostnameEndpoint # These namedtuples are "hint objects". The JSON-serializable dictionaries # are "hint dicts". @@ -21,6 +22,17 @@ TorTCPV1Hint = namedtuple("TorTCPV1Hint", ["hostname", "port", "priority"]) # rest of the V1 protocol. Only one hint per relay is useful. RelayV1Hint = namedtuple("RelayV1Hint", ["hints"]) +def describe_hint_obj(hint, relay, tor): + prefix = "tor->" if tor else "->" + if relay: + prefix = prefix + "relay:" + if isinstance(hint, DirectTCPV1Hint): + return prefix + "tcp:%s:%d" % (hint.hostname, hint.port) + elif isinstance(hint, TorTCPV1Hint): + return prefix + "tor:%s:%d" % (hint.hostname, hint.port) + else: + return prefix + str(hint) + def parse_hint_argv(hint, stderr=sys.stderr): assert isinstance(hint, type(u"")) # return tuple or None for an unparseable hint @@ -56,3 +68,17 @@ def parse_hint_argv(hint, stderr=sys.stderr): file=stderr) return None return DirectTCPV1Hint(hint_host, hint_port, priority) + +def endpoint_from_hint_obj(hint, tor, reactor): + if tor: + if isinstance(hint, (DirectTCPV1Hint, TorTCPV1Hint)): + # this Tor object will throw ValueError for non-public IPv4 + # addresses and any IPv6 address + try: + return tor.stream_via(hint.hostname, hint.port) + except ValueError: + return None + return None + if isinstance(hint, DirectTCPV1Hint): + return HostnameEndpoint(reactor, hint.hostname, hint.port) + return None diff --git a/src/wormhole/test/test_transit.py b/src/wormhole/test/test_transit.py index 63216ae..9a53983 100644 --- a/src/wormhole/test/test_transit.py +++ b/src/wormhole/test/test_transit.py @@ -8,7 +8,7 @@ from collections import namedtuple import six from nacl.exceptions import CryptoError from nacl.secret import SecretBox -from twisted.internet import address, defer, endpoints, error, protocol, task +from twisted.internet import address, defer, endpoints, error, protocol, task, reactor from twisted.internet.defer import gatherResults, inlineCallbacks from twisted.python import log from twisted.test import proto_helpers @@ -18,6 +18,7 @@ import mock from wormhole_transit_relay import transit_server from .. import transit +from .._hints import endpoint_from_hint_obj from ..errors import InternalError from .common import ServerBase @@ -145,30 +146,32 @@ UnknownHint = namedtuple("UnknownHint", ["stuff"]) class Hints(unittest.TestCase): def test_endpoint_from_hint_obj(self): - c = transit.Common("") - efho = c._endpoint_from_hint_obj + def efho(hint, tor=None): + return endpoint_from_hint_obj(hint, tor, reactor) self.assertIsInstance( efho(transit.DirectTCPV1Hint("host", 1234, 0.0)), endpoints.HostnameEndpoint) self.assertEqual(efho("unknown:stuff:yowza:pivlor"), None) - # c._tor is currently None - self.assertEqual(efho(transit.TorTCPV1Hint("host", "port", 0)), None) - c._tor = mock.Mock() + # tor=None + self.assertEqual(efho(transit.TorTCPV1Hint("host", "port", 0)), None) + + tor = mock.Mock() def tor_ep(hostname, port): if hostname == "non-public": return None return ("tor_ep", hostname, port) + tor.stream_via = mock.Mock(side_effect=tor_ep) - c._tor.stream_via = mock.Mock(side_effect=tor_ep) self.assertEqual( - efho(transit.DirectTCPV1Hint("host", 1234, 0.0)), + efho(transit.DirectTCPV1Hint("host", 1234, 0.0), tor), ("tor_ep", "host", 1234)) self.assertEqual( - efho(transit.TorTCPV1Hint("host2.onion", 1234, 0.0)), + efho(transit.TorTCPV1Hint("host2.onion", 1234, 0.0), tor), ("tor_ep", "host2.onion", 1234)) self.assertEqual( - efho(transit.DirectTCPV1Hint("non-public", 1234, 0.0)), None) + efho(transit.DirectTCPV1Hint("non-public", 1234, 0.0), tor), None) + self.assertEqual(efho(UnknownHint("foo")), None) def test_comparable(self): @@ -270,10 +273,13 @@ class Hints(unittest.TestCase): def test_describe_hint_obj(self): d = transit.describe_hint_obj self.assertEqual( - d(transit.DirectTCPV1Hint("host", 1234, 0.0)), "tcp:host:1234") + d(transit.DirectTCPV1Hint("host", 1234, 0.0), False, False), + "->tcp:host:1234") self.assertEqual( - d(transit.TorTCPV1Hint("host", 1234, 0.0)), "tor:host:1234") - self.assertEqual(d(UnknownHint("stuff")), str(UnknownHint("stuff"))) + d(transit.TorTCPV1Hint("host", 1234, 0.0), False, False), + "->tor:host:1234") + self.assertEqual(d(UnknownHint("stuff"), False, False), + "->%s" % str(UnknownHint("stuff"))) # ipaddrs.py currently uses native strings: bytes on py2, unicode on @@ -1507,7 +1513,7 @@ class Transit(unittest.TestCase): self.assertEqual(self.successResultOf(d), "winner") self.assertEqual(self._descriptions, ["tor->relay:tcp:relay:1234"]) - def _endpoint_from_hint_obj(self, hint): + def _endpoint_from_hint_obj(self, hint, _tor, _reactor): if isinstance(hint, transit.DirectTCPV1Hint): if hint.hostname == "unavailable": return None @@ -1523,20 +1529,21 @@ class Transit(unittest.TestCase): del hints s.add_connection_hints( [DIRECT_HINT_JSON, UNRECOGNIZED_HINT_JSON, RELAY_HINT_JSON]) - s._endpoint_from_hint_obj = self._endpoint_from_hint_obj s._start_connector = self._start_connector - d = s.connect() - self.assertNoResult(d) - # the direct connectors are tried right away, but the relay - # connectors are stalled for a few seconds - self.assertEqual(self._connectors, ["direct"]) + with mock.patch("wormhole.transit.endpoint_from_hint_obj", + self._endpoint_from_hint_obj): + d = s.connect() + self.assertNoResult(d) + # the direct connectors are tried right away, but the relay + # connectors are stalled for a few seconds + self.assertEqual(self._connectors, ["direct"]) - clock.advance(s.RELAY_DELAY + 1.0) - self.assertEqual(self._connectors, ["direct", "relay"]) + clock.advance(s.RELAY_DELAY + 1.0) + self.assertEqual(self._connectors, ["direct", "relay"]) - self._waiters[0].callback("winner") - self.assertEqual(self.successResultOf(d), "winner") + self._waiters[0].callback("winner") + self.assertEqual(self.successResultOf(d), "winner") @inlineCallbacks def test_priorities(self): @@ -1586,31 +1593,32 @@ class Transit(unittest.TestCase): }] }, ]) - s._endpoint_from_hint_obj = self._endpoint_from_hint_obj s._start_connector = self._start_connector - d = s.connect() - self.assertNoResult(d) - # direct connector should be used first, then the priority=3.0 relay, - # then the two 2.0 relays, then the (default) 0.0 relay + with mock.patch("wormhole.transit.endpoint_from_hint_obj", + self._endpoint_from_hint_obj): + d = s.connect() + self.assertNoResult(d) + # direct connector should be used first, then the priority=3.0 relay, + # then the two 2.0 relays, then the (default) 0.0 relay - self.assertEqual(self._connectors, ["direct"]) + self.assertEqual(self._connectors, ["direct"]) - clock.advance(s.RELAY_DELAY + 1.0) - self.assertEqual(self._connectors, ["direct", "relay3"]) + clock.advance(s.RELAY_DELAY + 1.0) + self.assertEqual(self._connectors, ["direct", "relay3"]) - clock.advance(s.RELAY_DELAY) - self.assertIn(self._connectors, - (["direct", "relay3", "relay2", "relay4"], - ["direct", "relay3", "relay4", "relay2"])) + clock.advance(s.RELAY_DELAY) + self.assertIn(self._connectors, + (["direct", "relay3", "relay2", "relay4"], + ["direct", "relay3", "relay4", "relay2"])) - clock.advance(s.RELAY_DELAY) - self.assertIn(self._connectors, - (["direct", "relay3", "relay2", "relay4", "relay"], - ["direct", "relay3", "relay4", "relay2", "relay"])) + clock.advance(s.RELAY_DELAY) + self.assertIn(self._connectors, + (["direct", "relay3", "relay2", "relay4", "relay"], + ["direct", "relay3", "relay4", "relay2", "relay"])) - self._waiters[0].callback("winner") - self.assertEqual(self.successResultOf(d), "winner") + self._waiters[0].callback("winner") + self.assertEqual(self.successResultOf(d), "winner") @inlineCallbacks def test_no_direct_hints(self): @@ -1624,20 +1632,21 @@ class Transit(unittest.TestCase): UNRECOGNIZED_HINT_JSON, UNAVAILABLE_HINT_JSON, RELAY_HINT2_JSON, UNAVAILABLE_RELAY_HINT_JSON ]) - s._endpoint_from_hint_obj = self._endpoint_from_hint_obj s._start_connector = self._start_connector - d = s.connect() - self.assertNoResult(d) - # since there are no usable direct hints, the relay connector will - # only be stalled for 0 seconds - self.assertEqual(self._connectors, []) + with mock.patch("wormhole.transit.endpoint_from_hint_obj", + self._endpoint_from_hint_obj): + d = s.connect() + self.assertNoResult(d) + # since there are no usable direct hints, the relay connector will + # only be stalled for 0 seconds + self.assertEqual(self._connectors, []) - clock.advance(0) - self.assertEqual(self._connectors, ["relay"]) + clock.advance(0) + self.assertEqual(self._connectors, ["relay"]) - self._waiters[0].callback("winner") - self.assertEqual(self.successResultOf(d), "winner") + self._waiters[0].callback("winner") + self.assertEqual(self.successResultOf(d), "winner") @inlineCallbacks def test_no_contenders(self): @@ -1647,12 +1656,13 @@ class Transit(unittest.TestCase): hints = yield s.get_connection_hints() # start the listener del hints s.add_connection_hints([]) # no hints at all - s._endpoint_from_hint_obj = self._endpoint_from_hint_obj s._start_connector = self._start_connector - d = s.connect() - f = self.failureResultOf(d, transit.TransitError) - self.assertEqual(str(f.value), "No contenders for connection") + with mock.patch("wormhole.transit.endpoint_from_hint_obj", + self._endpoint_from_hint_obj): + d = s.connect() + f = self.failureResultOf(d, transit.TransitError) + self.assertEqual(str(f.value), "No contenders for connection") class RelayHandshake(unittest.TestCase): diff --git a/src/wormhole/transit.py b/src/wormhole/transit.py index fa708fb..b122055 100644 --- a/src/wormhole/transit.py +++ b/src/wormhole/transit.py @@ -23,6 +23,8 @@ from . import ipaddrs from .errors import InternalError from .timing import DebugTiming from .util import bytes_to_hexstr +from ._hints import (DirectTCPV1Hint, TorTCPV1Hint, RelayV1Hint, + parse_hint_argv, describe_hint_obj, endpoint_from_hint_obj) def HKDF(skm, outlen, salt=None, CTXinfo=b""): @@ -94,16 +96,6 @@ def build_sided_relay_handshake(key, side): "ascii") + b"\n" -from ._hints import parse_hint_argv, DirectTCPV1Hint, TorTCPV1Hint, RelayV1Hint - -def describe_hint_obj(hint): - if isinstance(hint, DirectTCPV1Hint): - return u"tcp:%s:%d" % (hint.hostname, hint.port) - elif isinstance(hint, TorTCPV1Hint): - return u"tor:%s:%d" % (hint.hostname, hint.port) - else: - return str(hint) - TIMEOUT = 60 # seconds @@ -818,13 +810,11 @@ class Common: # Check the hint type to see if we can support it (e.g. skip # onion hints on a non-Tor client). Do not increase relay_delay # unless we have at least one viable hint. - ep = self._endpoint_from_hint_obj(hint_obj) + ep = endpoint_from_hint_obj(hint_obj, self._tor, self._reactor) if not ep: continue - description = "->%s" % describe_hint_obj(hint_obj) - if self._tor: - description = "tor" + description - d = self._start_connector(ep, description) + d = self._start_connector(ep, + describe_hint_obj(hint_obj, False, self._tor)) contenders.append(d) relay_delay = self.RELAY_DELAY @@ -845,18 +835,15 @@ class Common: for priority in sorted(prioritized_relays, reverse=True): for hint_obj in prioritized_relays[priority]: - ep = self._endpoint_from_hint_obj(hint_obj) + ep = endpoint_from_hint_obj(hint_obj, self._tor, self._reactor) if not ep: continue - description = "->relay:%s" % describe_hint_obj(hint_obj) - if self._tor: - description = "tor" + description d = task.deferLater( self._reactor, relay_delay, self._start_connector, ep, - description, + describe_hint_obj(hint_obj, True, self._tor), is_relay=True) contenders.append(d) relay_delay += self.RELAY_DELAY @@ -894,21 +881,6 @@ class Common: d.addCallback(lambda p: p.startNegotiation()) return d - def _endpoint_from_hint_obj(self, hint): - if self._tor: - if isinstance(hint, (DirectTCPV1Hint, TorTCPV1Hint)): - # this Tor object will throw ValueError for non-public IPv4 - # addresses and any IPv6 address - try: - return self._tor.stream_via(hint.hostname, hint.port) - except ValueError: - return None - return None - if isinstance(hint, DirectTCPV1Hint): - return endpoints.HostnameEndpoint(self._reactor, hint.hostname, - hint.port) - return None - def connection_ready(self, p): # inbound/outbound Connection protocols call this when they finish # negotiation. The first one wins and gets a "go". Any subsequent