factor out describe_hint_obj and endpoint_from_hint_obj

This commit is contained in:
Brian Warner 2018-12-21 23:12:17 -05:00
parent bd1a199f3e
commit 2f4e4d3031
4 changed files with 104 additions and 120 deletions

View File

@ -8,7 +8,7 @@ from automat import MethodicalMachine
from zope.interface import implementer
from twisted.internet.task import deferLater
from twisted.internet.defer import DeferredList
from twisted.internet.endpoints import HostnameEndpoint, serverFromString
from twisted.internet.endpoints import serverFromString
from twisted.internet.protocol import ClientFactory, ServerFactory
from twisted.python import log
from hkdf import Hkdf
@ -19,18 +19,8 @@ from ..observer import EmptyableSet
from .connection import DilatedConnectionProtocol, KCM
from .roles import LEADER
from .._hints import parse_hint_argv, DirectTCPV1Hint, TorTCPV1Hint, RelayV1Hint
def describe_hint_obj(hint, relay, tor):
prefix = "tor->" if tor else "->"
if relay:
prefix = prefix + "relay:"
if isinstance(hint, DirectTCPV1Hint):
return prefix + "tcp:%s:%d" % (hint.hostname, hint.port)
elif isinstance(hint, TorTCPV1Hint):
return prefix + "tor:%s:%d" % (hint.hostname, hint.port)
else:
return prefix + str(hint)
from .._hints import (DirectTCPV1Hint, TorTCPV1Hint, RelayV1Hint,
parse_hint_argv, describe_hint_obj, endpoint_from_hint_obj)
def parse_tcp_v1_hint(hint): # hint_struct -> hint_obj
@ -340,7 +330,7 @@ class Connector(object):
for h in direct[p]:
if isinstance(h, TorTCPV1Hint) and not self._tor:
continue
ep = self._endpoint_from_hint_obj(h)
ep = endpoint_from_hint_obj(h, self._tor, self._reactor)
desc = describe_hint_obj(h, False, self._tor)
d = deferLater(self._reactor, delay,
self._connect, ep, desc, is_relay=False)
@ -376,7 +366,7 @@ class Connector(object):
for p in priorities:
for r in relays[p]:
for h in r.hints:
ep = self._endpoint_from_hint_obj(h)
ep = endpoint_from_hint_obj(h, self._tor, self._reactor)
desc = describe_hint_obj(h, True, self._tor)
d = deferLater(self._reactor, delay,
self._connect, ep, desc, is_relay=True)
@ -405,20 +395,6 @@ class Connector(object):
d.addCallback(_connected)
return d
def _endpoint_from_hint_obj(self, hint):
if self._tor:
if isinstance(hint, (DirectTCPV1Hint, TorTCPV1Hint)):
# this Tor object will throw ValueError for non-public IPv4
# addresses and any IPv6 address
try:
return self._tor.stream_via(hint.hostname, hint.port)
except ValueError:
return None
return None
if isinstance(hint, DirectTCPV1Hint):
return HostnameEndpoint(self._reactor, hint.hostname, hint.port)
return None
# Connection selection. All instances of DilatedConnectionProtocol which
# look viable get passed into our add_contender() method.

View File

@ -2,6 +2,7 @@ from __future__ import print_function, unicode_literals
import sys
import re
from collections import namedtuple
from twisted.internet.endpoints import HostnameEndpoint
# These namedtuples are "hint objects". The JSON-serializable dictionaries
# are "hint dicts".
@ -21,6 +22,17 @@ TorTCPV1Hint = namedtuple("TorTCPV1Hint", ["hostname", "port", "priority"])
# rest of the V1 protocol. Only one hint per relay is useful.
RelayV1Hint = namedtuple("RelayV1Hint", ["hints"])
def describe_hint_obj(hint, relay, tor):
prefix = "tor->" if tor else "->"
if relay:
prefix = prefix + "relay:"
if isinstance(hint, DirectTCPV1Hint):
return prefix + "tcp:%s:%d" % (hint.hostname, hint.port)
elif isinstance(hint, TorTCPV1Hint):
return prefix + "tor:%s:%d" % (hint.hostname, hint.port)
else:
return prefix + str(hint)
def parse_hint_argv(hint, stderr=sys.stderr):
assert isinstance(hint, type(u""))
# return tuple or None for an unparseable hint
@ -56,3 +68,17 @@ def parse_hint_argv(hint, stderr=sys.stderr):
file=stderr)
return None
return DirectTCPV1Hint(hint_host, hint_port, priority)
def endpoint_from_hint_obj(hint, tor, reactor):
if tor:
if isinstance(hint, (DirectTCPV1Hint, TorTCPV1Hint)):
# this Tor object will throw ValueError for non-public IPv4
# addresses and any IPv6 address
try:
return tor.stream_via(hint.hostname, hint.port)
except ValueError:
return None
return None
if isinstance(hint, DirectTCPV1Hint):
return HostnameEndpoint(reactor, hint.hostname, hint.port)
return None

View File

@ -8,7 +8,7 @@ from collections import namedtuple
import six
from nacl.exceptions import CryptoError
from nacl.secret import SecretBox
from twisted.internet import address, defer, endpoints, error, protocol, task
from twisted.internet import address, defer, endpoints, error, protocol, task, reactor
from twisted.internet.defer import gatherResults, inlineCallbacks
from twisted.python import log
from twisted.test import proto_helpers
@ -18,6 +18,7 @@ import mock
from wormhole_transit_relay import transit_server
from .. import transit
from .._hints import endpoint_from_hint_obj
from ..errors import InternalError
from .common import ServerBase
@ -145,30 +146,32 @@ UnknownHint = namedtuple("UnknownHint", ["stuff"])
class Hints(unittest.TestCase):
def test_endpoint_from_hint_obj(self):
c = transit.Common("")
efho = c._endpoint_from_hint_obj
def efho(hint, tor=None):
return endpoint_from_hint_obj(hint, tor, reactor)
self.assertIsInstance(
efho(transit.DirectTCPV1Hint("host", 1234, 0.0)),
endpoints.HostnameEndpoint)
self.assertEqual(efho("unknown:stuff:yowza:pivlor"), None)
# c._tor is currently None
self.assertEqual(efho(transit.TorTCPV1Hint("host", "port", 0)), None)
c._tor = mock.Mock()
# tor=None
self.assertEqual(efho(transit.TorTCPV1Hint("host", "port", 0)), None)
tor = mock.Mock()
def tor_ep(hostname, port):
if hostname == "non-public":
return None
return ("tor_ep", hostname, port)
tor.stream_via = mock.Mock(side_effect=tor_ep)
c._tor.stream_via = mock.Mock(side_effect=tor_ep)
self.assertEqual(
efho(transit.DirectTCPV1Hint("host", 1234, 0.0)),
efho(transit.DirectTCPV1Hint("host", 1234, 0.0), tor),
("tor_ep", "host", 1234))
self.assertEqual(
efho(transit.TorTCPV1Hint("host2.onion", 1234, 0.0)),
efho(transit.TorTCPV1Hint("host2.onion", 1234, 0.0), tor),
("tor_ep", "host2.onion", 1234))
self.assertEqual(
efho(transit.DirectTCPV1Hint("non-public", 1234, 0.0)), None)
efho(transit.DirectTCPV1Hint("non-public", 1234, 0.0), tor), None)
self.assertEqual(efho(UnknownHint("foo")), None)
def test_comparable(self):
@ -270,10 +273,13 @@ class Hints(unittest.TestCase):
def test_describe_hint_obj(self):
d = transit.describe_hint_obj
self.assertEqual(
d(transit.DirectTCPV1Hint("host", 1234, 0.0)), "tcp:host:1234")
d(transit.DirectTCPV1Hint("host", 1234, 0.0), False, False),
"->tcp:host:1234")
self.assertEqual(
d(transit.TorTCPV1Hint("host", 1234, 0.0)), "tor:host:1234")
self.assertEqual(d(UnknownHint("stuff")), str(UnknownHint("stuff")))
d(transit.TorTCPV1Hint("host", 1234, 0.0), False, False),
"->tor:host:1234")
self.assertEqual(d(UnknownHint("stuff"), False, False),
"->%s" % str(UnknownHint("stuff")))
# ipaddrs.py currently uses native strings: bytes on py2, unicode on
@ -1507,7 +1513,7 @@ class Transit(unittest.TestCase):
self.assertEqual(self.successResultOf(d), "winner")
self.assertEqual(self._descriptions, ["tor->relay:tcp:relay:1234"])
def _endpoint_from_hint_obj(self, hint):
def _endpoint_from_hint_obj(self, hint, _tor, _reactor):
if isinstance(hint, transit.DirectTCPV1Hint):
if hint.hostname == "unavailable":
return None
@ -1523,20 +1529,21 @@ class Transit(unittest.TestCase):
del hints
s.add_connection_hints(
[DIRECT_HINT_JSON, UNRECOGNIZED_HINT_JSON, RELAY_HINT_JSON])
s._endpoint_from_hint_obj = self._endpoint_from_hint_obj
s._start_connector = self._start_connector
d = s.connect()
self.assertNoResult(d)
# the direct connectors are tried right away, but the relay
# connectors are stalled for a few seconds
self.assertEqual(self._connectors, ["direct"])
with mock.patch("wormhole.transit.endpoint_from_hint_obj",
self._endpoint_from_hint_obj):
d = s.connect()
self.assertNoResult(d)
# the direct connectors are tried right away, but the relay
# connectors are stalled for a few seconds
self.assertEqual(self._connectors, ["direct"])
clock.advance(s.RELAY_DELAY + 1.0)
self.assertEqual(self._connectors, ["direct", "relay"])
clock.advance(s.RELAY_DELAY + 1.0)
self.assertEqual(self._connectors, ["direct", "relay"])
self._waiters[0].callback("winner")
self.assertEqual(self.successResultOf(d), "winner")
self._waiters[0].callback("winner")
self.assertEqual(self.successResultOf(d), "winner")
@inlineCallbacks
def test_priorities(self):
@ -1586,31 +1593,32 @@ class Transit(unittest.TestCase):
}]
},
])
s._endpoint_from_hint_obj = self._endpoint_from_hint_obj
s._start_connector = self._start_connector
d = s.connect()
self.assertNoResult(d)
# direct connector should be used first, then the priority=3.0 relay,
# then the two 2.0 relays, then the (default) 0.0 relay
with mock.patch("wormhole.transit.endpoint_from_hint_obj",
self._endpoint_from_hint_obj):
d = s.connect()
self.assertNoResult(d)
# direct connector should be used first, then the priority=3.0 relay,
# then the two 2.0 relays, then the (default) 0.0 relay
self.assertEqual(self._connectors, ["direct"])
self.assertEqual(self._connectors, ["direct"])
clock.advance(s.RELAY_DELAY + 1.0)
self.assertEqual(self._connectors, ["direct", "relay3"])
clock.advance(s.RELAY_DELAY + 1.0)
self.assertEqual(self._connectors, ["direct", "relay3"])
clock.advance(s.RELAY_DELAY)
self.assertIn(self._connectors,
(["direct", "relay3", "relay2", "relay4"],
["direct", "relay3", "relay4", "relay2"]))
clock.advance(s.RELAY_DELAY)
self.assertIn(self._connectors,
(["direct", "relay3", "relay2", "relay4"],
["direct", "relay3", "relay4", "relay2"]))
clock.advance(s.RELAY_DELAY)
self.assertIn(self._connectors,
(["direct", "relay3", "relay2", "relay4", "relay"],
["direct", "relay3", "relay4", "relay2", "relay"]))
clock.advance(s.RELAY_DELAY)
self.assertIn(self._connectors,
(["direct", "relay3", "relay2", "relay4", "relay"],
["direct", "relay3", "relay4", "relay2", "relay"]))
self._waiters[0].callback("winner")
self.assertEqual(self.successResultOf(d), "winner")
self._waiters[0].callback("winner")
self.assertEqual(self.successResultOf(d), "winner")
@inlineCallbacks
def test_no_direct_hints(self):
@ -1624,20 +1632,21 @@ class Transit(unittest.TestCase):
UNRECOGNIZED_HINT_JSON, UNAVAILABLE_HINT_JSON, RELAY_HINT2_JSON,
UNAVAILABLE_RELAY_HINT_JSON
])
s._endpoint_from_hint_obj = self._endpoint_from_hint_obj
s._start_connector = self._start_connector
d = s.connect()
self.assertNoResult(d)
# since there are no usable direct hints, the relay connector will
# only be stalled for 0 seconds
self.assertEqual(self._connectors, [])
with mock.patch("wormhole.transit.endpoint_from_hint_obj",
self._endpoint_from_hint_obj):
d = s.connect()
self.assertNoResult(d)
# since there are no usable direct hints, the relay connector will
# only be stalled for 0 seconds
self.assertEqual(self._connectors, [])
clock.advance(0)
self.assertEqual(self._connectors, ["relay"])
clock.advance(0)
self.assertEqual(self._connectors, ["relay"])
self._waiters[0].callback("winner")
self.assertEqual(self.successResultOf(d), "winner")
self._waiters[0].callback("winner")
self.assertEqual(self.successResultOf(d), "winner")
@inlineCallbacks
def test_no_contenders(self):
@ -1647,12 +1656,13 @@ class Transit(unittest.TestCase):
hints = yield s.get_connection_hints() # start the listener
del hints
s.add_connection_hints([]) # no hints at all
s._endpoint_from_hint_obj = self._endpoint_from_hint_obj
s._start_connector = self._start_connector
d = s.connect()
f = self.failureResultOf(d, transit.TransitError)
self.assertEqual(str(f.value), "No contenders for connection")
with mock.patch("wormhole.transit.endpoint_from_hint_obj",
self._endpoint_from_hint_obj):
d = s.connect()
f = self.failureResultOf(d, transit.TransitError)
self.assertEqual(str(f.value), "No contenders for connection")
class RelayHandshake(unittest.TestCase):

View File

@ -23,6 +23,8 @@ from . import ipaddrs
from .errors import InternalError
from .timing import DebugTiming
from .util import bytes_to_hexstr
from ._hints import (DirectTCPV1Hint, TorTCPV1Hint, RelayV1Hint,
parse_hint_argv, describe_hint_obj, endpoint_from_hint_obj)
def HKDF(skm, outlen, salt=None, CTXinfo=b""):
@ -94,16 +96,6 @@ def build_sided_relay_handshake(key, side):
"ascii") + b"\n"
from ._hints import parse_hint_argv, DirectTCPV1Hint, TorTCPV1Hint, RelayV1Hint
def describe_hint_obj(hint):
if isinstance(hint, DirectTCPV1Hint):
return u"tcp:%s:%d" % (hint.hostname, hint.port)
elif isinstance(hint, TorTCPV1Hint):
return u"tor:%s:%d" % (hint.hostname, hint.port)
else:
return str(hint)
TIMEOUT = 60 # seconds
@ -818,13 +810,11 @@ class Common:
# Check the hint type to see if we can support it (e.g. skip
# onion hints on a non-Tor client). Do not increase relay_delay
# unless we have at least one viable hint.
ep = self._endpoint_from_hint_obj(hint_obj)
ep = endpoint_from_hint_obj(hint_obj, self._tor, self._reactor)
if not ep:
continue
description = "->%s" % describe_hint_obj(hint_obj)
if self._tor:
description = "tor" + description
d = self._start_connector(ep, description)
d = self._start_connector(ep,
describe_hint_obj(hint_obj, False, self._tor))
contenders.append(d)
relay_delay = self.RELAY_DELAY
@ -845,18 +835,15 @@ class Common:
for priority in sorted(prioritized_relays, reverse=True):
for hint_obj in prioritized_relays[priority]:
ep = self._endpoint_from_hint_obj(hint_obj)
ep = endpoint_from_hint_obj(hint_obj, self._tor, self._reactor)
if not ep:
continue
description = "->relay:%s" % describe_hint_obj(hint_obj)
if self._tor:
description = "tor" + description
d = task.deferLater(
self._reactor,
relay_delay,
self._start_connector,
ep,
description,
describe_hint_obj(hint_obj, True, self._tor),
is_relay=True)
contenders.append(d)
relay_delay += self.RELAY_DELAY
@ -894,21 +881,6 @@ class Common:
d.addCallback(lambda p: p.startNegotiation())
return d
def _endpoint_from_hint_obj(self, hint):
if self._tor:
if isinstance(hint, (DirectTCPV1Hint, TorTCPV1Hint)):
# this Tor object will throw ValueError for non-public IPv4
# addresses and any IPv6 address
try:
return self._tor.stream_via(hint.hostname, hint.port)
except ValueError:
return None
return None
if isinstance(hint, DirectTCPV1Hint):
return endpoints.HostnameEndpoint(self._reactor, hint.hostname,
hint.port)
return None
def connection_ready(self, p):
# inbound/outbound Connection protocols call this when they finish
# negotiation. The first one wins and gets a "go". Any subsequent