factor out describe_hint_obj and endpoint_from_hint_obj

This commit is contained in:
Brian Warner 2018-12-21 23:12:17 -05:00
parent bd1a199f3e
commit 2f4e4d3031
4 changed files with 104 additions and 120 deletions

View File

@ -8,7 +8,7 @@ from automat import MethodicalMachine
from zope.interface import implementer from zope.interface import implementer
from twisted.internet.task import deferLater from twisted.internet.task import deferLater
from twisted.internet.defer import DeferredList from twisted.internet.defer import DeferredList
from twisted.internet.endpoints import HostnameEndpoint, serverFromString from twisted.internet.endpoints import serverFromString
from twisted.internet.protocol import ClientFactory, ServerFactory from twisted.internet.protocol import ClientFactory, ServerFactory
from twisted.python import log from twisted.python import log
from hkdf import Hkdf from hkdf import Hkdf
@ -19,18 +19,8 @@ from ..observer import EmptyableSet
from .connection import DilatedConnectionProtocol, KCM from .connection import DilatedConnectionProtocol, KCM
from .roles import LEADER from .roles import LEADER
from .._hints import parse_hint_argv, DirectTCPV1Hint, TorTCPV1Hint, RelayV1Hint from .._hints import (DirectTCPV1Hint, TorTCPV1Hint, RelayV1Hint,
parse_hint_argv, describe_hint_obj, endpoint_from_hint_obj)
def describe_hint_obj(hint, relay, tor):
prefix = "tor->" if tor else "->"
if relay:
prefix = prefix + "relay:"
if isinstance(hint, DirectTCPV1Hint):
return prefix + "tcp:%s:%d" % (hint.hostname, hint.port)
elif isinstance(hint, TorTCPV1Hint):
return prefix + "tor:%s:%d" % (hint.hostname, hint.port)
else:
return prefix + str(hint)
def parse_tcp_v1_hint(hint): # hint_struct -> hint_obj def parse_tcp_v1_hint(hint): # hint_struct -> hint_obj
@ -340,7 +330,7 @@ class Connector(object):
for h in direct[p]: for h in direct[p]:
if isinstance(h, TorTCPV1Hint) and not self._tor: if isinstance(h, TorTCPV1Hint) and not self._tor:
continue continue
ep = self._endpoint_from_hint_obj(h) ep = endpoint_from_hint_obj(h, self._tor, self._reactor)
desc = describe_hint_obj(h, False, self._tor) desc = describe_hint_obj(h, False, self._tor)
d = deferLater(self._reactor, delay, d = deferLater(self._reactor, delay,
self._connect, ep, desc, is_relay=False) self._connect, ep, desc, is_relay=False)
@ -376,7 +366,7 @@ class Connector(object):
for p in priorities: for p in priorities:
for r in relays[p]: for r in relays[p]:
for h in r.hints: for h in r.hints:
ep = self._endpoint_from_hint_obj(h) ep = endpoint_from_hint_obj(h, self._tor, self._reactor)
desc = describe_hint_obj(h, True, self._tor) desc = describe_hint_obj(h, True, self._tor)
d = deferLater(self._reactor, delay, d = deferLater(self._reactor, delay,
self._connect, ep, desc, is_relay=True) self._connect, ep, desc, is_relay=True)
@ -405,20 +395,6 @@ class Connector(object):
d.addCallback(_connected) d.addCallback(_connected)
return d return d
def _endpoint_from_hint_obj(self, hint):
if self._tor:
if isinstance(hint, (DirectTCPV1Hint, TorTCPV1Hint)):
# this Tor object will throw ValueError for non-public IPv4
# addresses and any IPv6 address
try:
return self._tor.stream_via(hint.hostname, hint.port)
except ValueError:
return None
return None
if isinstance(hint, DirectTCPV1Hint):
return HostnameEndpoint(self._reactor, hint.hostname, hint.port)
return None
# Connection selection. All instances of DilatedConnectionProtocol which # Connection selection. All instances of DilatedConnectionProtocol which
# look viable get passed into our add_contender() method. # look viable get passed into our add_contender() method.

View File

@ -2,6 +2,7 @@ from __future__ import print_function, unicode_literals
import sys import sys
import re import re
from collections import namedtuple from collections import namedtuple
from twisted.internet.endpoints import HostnameEndpoint
# These namedtuples are "hint objects". The JSON-serializable dictionaries # These namedtuples are "hint objects". The JSON-serializable dictionaries
# are "hint dicts". # are "hint dicts".
@ -21,6 +22,17 @@ TorTCPV1Hint = namedtuple("TorTCPV1Hint", ["hostname", "port", "priority"])
# rest of the V1 protocol. Only one hint per relay is useful. # rest of the V1 protocol. Only one hint per relay is useful.
RelayV1Hint = namedtuple("RelayV1Hint", ["hints"]) RelayV1Hint = namedtuple("RelayV1Hint", ["hints"])
def describe_hint_obj(hint, relay, tor):
prefix = "tor->" if tor else "->"
if relay:
prefix = prefix + "relay:"
if isinstance(hint, DirectTCPV1Hint):
return prefix + "tcp:%s:%d" % (hint.hostname, hint.port)
elif isinstance(hint, TorTCPV1Hint):
return prefix + "tor:%s:%d" % (hint.hostname, hint.port)
else:
return prefix + str(hint)
def parse_hint_argv(hint, stderr=sys.stderr): def parse_hint_argv(hint, stderr=sys.stderr):
assert isinstance(hint, type(u"")) assert isinstance(hint, type(u""))
# return tuple or None for an unparseable hint # return tuple or None for an unparseable hint
@ -56,3 +68,17 @@ def parse_hint_argv(hint, stderr=sys.stderr):
file=stderr) file=stderr)
return None return None
return DirectTCPV1Hint(hint_host, hint_port, priority) return DirectTCPV1Hint(hint_host, hint_port, priority)
def endpoint_from_hint_obj(hint, tor, reactor):
if tor:
if isinstance(hint, (DirectTCPV1Hint, TorTCPV1Hint)):
# this Tor object will throw ValueError for non-public IPv4
# addresses and any IPv6 address
try:
return tor.stream_via(hint.hostname, hint.port)
except ValueError:
return None
return None
if isinstance(hint, DirectTCPV1Hint):
return HostnameEndpoint(reactor, hint.hostname, hint.port)
return None

View File

@ -8,7 +8,7 @@ from collections import namedtuple
import six import six
from nacl.exceptions import CryptoError from nacl.exceptions import CryptoError
from nacl.secret import SecretBox from nacl.secret import SecretBox
from twisted.internet import address, defer, endpoints, error, protocol, task from twisted.internet import address, defer, endpoints, error, protocol, task, reactor
from twisted.internet.defer import gatherResults, inlineCallbacks from twisted.internet.defer import gatherResults, inlineCallbacks
from twisted.python import log from twisted.python import log
from twisted.test import proto_helpers from twisted.test import proto_helpers
@ -18,6 +18,7 @@ import mock
from wormhole_transit_relay import transit_server from wormhole_transit_relay import transit_server
from .. import transit from .. import transit
from .._hints import endpoint_from_hint_obj
from ..errors import InternalError from ..errors import InternalError
from .common import ServerBase from .common import ServerBase
@ -145,30 +146,32 @@ UnknownHint = namedtuple("UnknownHint", ["stuff"])
class Hints(unittest.TestCase): class Hints(unittest.TestCase):
def test_endpoint_from_hint_obj(self): def test_endpoint_from_hint_obj(self):
c = transit.Common("") def efho(hint, tor=None):
efho = c._endpoint_from_hint_obj return endpoint_from_hint_obj(hint, tor, reactor)
self.assertIsInstance( self.assertIsInstance(
efho(transit.DirectTCPV1Hint("host", 1234, 0.0)), efho(transit.DirectTCPV1Hint("host", 1234, 0.0)),
endpoints.HostnameEndpoint) endpoints.HostnameEndpoint)
self.assertEqual(efho("unknown:stuff:yowza:pivlor"), None) self.assertEqual(efho("unknown:stuff:yowza:pivlor"), None)
# c._tor is currently None
self.assertEqual(efho(transit.TorTCPV1Hint("host", "port", 0)), None)
c._tor = mock.Mock()
# tor=None
self.assertEqual(efho(transit.TorTCPV1Hint("host", "port", 0)), None)
tor = mock.Mock()
def tor_ep(hostname, port): def tor_ep(hostname, port):
if hostname == "non-public": if hostname == "non-public":
return None return None
return ("tor_ep", hostname, port) return ("tor_ep", hostname, port)
tor.stream_via = mock.Mock(side_effect=tor_ep)
c._tor.stream_via = mock.Mock(side_effect=tor_ep)
self.assertEqual( self.assertEqual(
efho(transit.DirectTCPV1Hint("host", 1234, 0.0)), efho(transit.DirectTCPV1Hint("host", 1234, 0.0), tor),
("tor_ep", "host", 1234)) ("tor_ep", "host", 1234))
self.assertEqual( self.assertEqual(
efho(transit.TorTCPV1Hint("host2.onion", 1234, 0.0)), efho(transit.TorTCPV1Hint("host2.onion", 1234, 0.0), tor),
("tor_ep", "host2.onion", 1234)) ("tor_ep", "host2.onion", 1234))
self.assertEqual( self.assertEqual(
efho(transit.DirectTCPV1Hint("non-public", 1234, 0.0)), None) efho(transit.DirectTCPV1Hint("non-public", 1234, 0.0), tor), None)
self.assertEqual(efho(UnknownHint("foo")), None) self.assertEqual(efho(UnknownHint("foo")), None)
def test_comparable(self): def test_comparable(self):
@ -270,10 +273,13 @@ class Hints(unittest.TestCase):
def test_describe_hint_obj(self): def test_describe_hint_obj(self):
d = transit.describe_hint_obj d = transit.describe_hint_obj
self.assertEqual( self.assertEqual(
d(transit.DirectTCPV1Hint("host", 1234, 0.0)), "tcp:host:1234") d(transit.DirectTCPV1Hint("host", 1234, 0.0), False, False),
"->tcp:host:1234")
self.assertEqual( self.assertEqual(
d(transit.TorTCPV1Hint("host", 1234, 0.0)), "tor:host:1234") d(transit.TorTCPV1Hint("host", 1234, 0.0), False, False),
self.assertEqual(d(UnknownHint("stuff")), str(UnknownHint("stuff"))) "->tor:host:1234")
self.assertEqual(d(UnknownHint("stuff"), False, False),
"->%s" % str(UnknownHint("stuff")))
# ipaddrs.py currently uses native strings: bytes on py2, unicode on # ipaddrs.py currently uses native strings: bytes on py2, unicode on
@ -1507,7 +1513,7 @@ class Transit(unittest.TestCase):
self.assertEqual(self.successResultOf(d), "winner") self.assertEqual(self.successResultOf(d), "winner")
self.assertEqual(self._descriptions, ["tor->relay:tcp:relay:1234"]) self.assertEqual(self._descriptions, ["tor->relay:tcp:relay:1234"])
def _endpoint_from_hint_obj(self, hint): def _endpoint_from_hint_obj(self, hint, _tor, _reactor):
if isinstance(hint, transit.DirectTCPV1Hint): if isinstance(hint, transit.DirectTCPV1Hint):
if hint.hostname == "unavailable": if hint.hostname == "unavailable":
return None return None
@ -1523,20 +1529,21 @@ class Transit(unittest.TestCase):
del hints del hints
s.add_connection_hints( s.add_connection_hints(
[DIRECT_HINT_JSON, UNRECOGNIZED_HINT_JSON, RELAY_HINT_JSON]) [DIRECT_HINT_JSON, UNRECOGNIZED_HINT_JSON, RELAY_HINT_JSON])
s._endpoint_from_hint_obj = self._endpoint_from_hint_obj
s._start_connector = self._start_connector s._start_connector = self._start_connector
d = s.connect() with mock.patch("wormhole.transit.endpoint_from_hint_obj",
self.assertNoResult(d) self._endpoint_from_hint_obj):
# the direct connectors are tried right away, but the relay d = s.connect()
# connectors are stalled for a few seconds self.assertNoResult(d)
self.assertEqual(self._connectors, ["direct"]) # the direct connectors are tried right away, but the relay
# connectors are stalled for a few seconds
self.assertEqual(self._connectors, ["direct"])
clock.advance(s.RELAY_DELAY + 1.0) clock.advance(s.RELAY_DELAY + 1.0)
self.assertEqual(self._connectors, ["direct", "relay"]) self.assertEqual(self._connectors, ["direct", "relay"])
self._waiters[0].callback("winner") self._waiters[0].callback("winner")
self.assertEqual(self.successResultOf(d), "winner") self.assertEqual(self.successResultOf(d), "winner")
@inlineCallbacks @inlineCallbacks
def test_priorities(self): def test_priorities(self):
@ -1586,31 +1593,32 @@ class Transit(unittest.TestCase):
}] }]
}, },
]) ])
s._endpoint_from_hint_obj = self._endpoint_from_hint_obj
s._start_connector = self._start_connector s._start_connector = self._start_connector
d = s.connect() with mock.patch("wormhole.transit.endpoint_from_hint_obj",
self.assertNoResult(d) self._endpoint_from_hint_obj):
# direct connector should be used first, then the priority=3.0 relay, d = s.connect()
# then the two 2.0 relays, then the (default) 0.0 relay self.assertNoResult(d)
# direct connector should be used first, then the priority=3.0 relay,
# then the two 2.0 relays, then the (default) 0.0 relay
self.assertEqual(self._connectors, ["direct"]) self.assertEqual(self._connectors, ["direct"])
clock.advance(s.RELAY_DELAY + 1.0) clock.advance(s.RELAY_DELAY + 1.0)
self.assertEqual(self._connectors, ["direct", "relay3"]) self.assertEqual(self._connectors, ["direct", "relay3"])
clock.advance(s.RELAY_DELAY) clock.advance(s.RELAY_DELAY)
self.assertIn(self._connectors, self.assertIn(self._connectors,
(["direct", "relay3", "relay2", "relay4"], (["direct", "relay3", "relay2", "relay4"],
["direct", "relay3", "relay4", "relay2"])) ["direct", "relay3", "relay4", "relay2"]))
clock.advance(s.RELAY_DELAY) clock.advance(s.RELAY_DELAY)
self.assertIn(self._connectors, self.assertIn(self._connectors,
(["direct", "relay3", "relay2", "relay4", "relay"], (["direct", "relay3", "relay2", "relay4", "relay"],
["direct", "relay3", "relay4", "relay2", "relay"])) ["direct", "relay3", "relay4", "relay2", "relay"]))
self._waiters[0].callback("winner") self._waiters[0].callback("winner")
self.assertEqual(self.successResultOf(d), "winner") self.assertEqual(self.successResultOf(d), "winner")
@inlineCallbacks @inlineCallbacks
def test_no_direct_hints(self): def test_no_direct_hints(self):
@ -1624,20 +1632,21 @@ class Transit(unittest.TestCase):
UNRECOGNIZED_HINT_JSON, UNAVAILABLE_HINT_JSON, RELAY_HINT2_JSON, UNRECOGNIZED_HINT_JSON, UNAVAILABLE_HINT_JSON, RELAY_HINT2_JSON,
UNAVAILABLE_RELAY_HINT_JSON UNAVAILABLE_RELAY_HINT_JSON
]) ])
s._endpoint_from_hint_obj = self._endpoint_from_hint_obj
s._start_connector = self._start_connector s._start_connector = self._start_connector
d = s.connect() with mock.patch("wormhole.transit.endpoint_from_hint_obj",
self.assertNoResult(d) self._endpoint_from_hint_obj):
# since there are no usable direct hints, the relay connector will d = s.connect()
# only be stalled for 0 seconds self.assertNoResult(d)
self.assertEqual(self._connectors, []) # since there are no usable direct hints, the relay connector will
# only be stalled for 0 seconds
self.assertEqual(self._connectors, [])
clock.advance(0) clock.advance(0)
self.assertEqual(self._connectors, ["relay"]) self.assertEqual(self._connectors, ["relay"])
self._waiters[0].callback("winner") self._waiters[0].callback("winner")
self.assertEqual(self.successResultOf(d), "winner") self.assertEqual(self.successResultOf(d), "winner")
@inlineCallbacks @inlineCallbacks
def test_no_contenders(self): def test_no_contenders(self):
@ -1647,12 +1656,13 @@ class Transit(unittest.TestCase):
hints = yield s.get_connection_hints() # start the listener hints = yield s.get_connection_hints() # start the listener
del hints del hints
s.add_connection_hints([]) # no hints at all s.add_connection_hints([]) # no hints at all
s._endpoint_from_hint_obj = self._endpoint_from_hint_obj
s._start_connector = self._start_connector s._start_connector = self._start_connector
d = s.connect() with mock.patch("wormhole.transit.endpoint_from_hint_obj",
f = self.failureResultOf(d, transit.TransitError) self._endpoint_from_hint_obj):
self.assertEqual(str(f.value), "No contenders for connection") d = s.connect()
f = self.failureResultOf(d, transit.TransitError)
self.assertEqual(str(f.value), "No contenders for connection")
class RelayHandshake(unittest.TestCase): class RelayHandshake(unittest.TestCase):

View File

@ -23,6 +23,8 @@ from . import ipaddrs
from .errors import InternalError from .errors import InternalError
from .timing import DebugTiming from .timing import DebugTiming
from .util import bytes_to_hexstr from .util import bytes_to_hexstr
from ._hints import (DirectTCPV1Hint, TorTCPV1Hint, RelayV1Hint,
parse_hint_argv, describe_hint_obj, endpoint_from_hint_obj)
def HKDF(skm, outlen, salt=None, CTXinfo=b""): def HKDF(skm, outlen, salt=None, CTXinfo=b""):
@ -94,16 +96,6 @@ def build_sided_relay_handshake(key, side):
"ascii") + b"\n" "ascii") + b"\n"
from ._hints import parse_hint_argv, DirectTCPV1Hint, TorTCPV1Hint, RelayV1Hint
def describe_hint_obj(hint):
if isinstance(hint, DirectTCPV1Hint):
return u"tcp:%s:%d" % (hint.hostname, hint.port)
elif isinstance(hint, TorTCPV1Hint):
return u"tor:%s:%d" % (hint.hostname, hint.port)
else:
return str(hint)
TIMEOUT = 60 # seconds TIMEOUT = 60 # seconds
@ -818,13 +810,11 @@ class Common:
# Check the hint type to see if we can support it (e.g. skip # Check the hint type to see if we can support it (e.g. skip
# onion hints on a non-Tor client). Do not increase relay_delay # onion hints on a non-Tor client). Do not increase relay_delay
# unless we have at least one viable hint. # unless we have at least one viable hint.
ep = self._endpoint_from_hint_obj(hint_obj) ep = endpoint_from_hint_obj(hint_obj, self._tor, self._reactor)
if not ep: if not ep:
continue continue
description = "->%s" % describe_hint_obj(hint_obj) d = self._start_connector(ep,
if self._tor: describe_hint_obj(hint_obj, False, self._tor))
description = "tor" + description
d = self._start_connector(ep, description)
contenders.append(d) contenders.append(d)
relay_delay = self.RELAY_DELAY relay_delay = self.RELAY_DELAY
@ -845,18 +835,15 @@ class Common:
for priority in sorted(prioritized_relays, reverse=True): for priority in sorted(prioritized_relays, reverse=True):
for hint_obj in prioritized_relays[priority]: for hint_obj in prioritized_relays[priority]:
ep = self._endpoint_from_hint_obj(hint_obj) ep = endpoint_from_hint_obj(hint_obj, self._tor, self._reactor)
if not ep: if not ep:
continue continue
description = "->relay:%s" % describe_hint_obj(hint_obj)
if self._tor:
description = "tor" + description
d = task.deferLater( d = task.deferLater(
self._reactor, self._reactor,
relay_delay, relay_delay,
self._start_connector, self._start_connector,
ep, ep,
description, describe_hint_obj(hint_obj, True, self._tor),
is_relay=True) is_relay=True)
contenders.append(d) contenders.append(d)
relay_delay += self.RELAY_DELAY relay_delay += self.RELAY_DELAY
@ -894,21 +881,6 @@ class Common:
d.addCallback(lambda p: p.startNegotiation()) d.addCallback(lambda p: p.startNegotiation())
return d return d
def _endpoint_from_hint_obj(self, hint):
if self._tor:
if isinstance(hint, (DirectTCPV1Hint, TorTCPV1Hint)):
# this Tor object will throw ValueError for non-public IPv4
# addresses and any IPv6 address
try:
return self._tor.stream_via(hint.hostname, hint.port)
except ValueError:
return None
return None
if isinstance(hint, DirectTCPV1Hint):
return endpoints.HostnameEndpoint(self._reactor, hint.hostname,
hint.port)
return None
def connection_ready(self, p): def connection_ready(self, p):
# inbound/outbound Connection protocols call this when they finish # inbound/outbound Connection protocols call this when they finish
# negotiation. The first one wins and gets a "go". Any subsequent # negotiation. The first one wins and gets a "go". Any subsequent