factor out describe_hint_obj and endpoint_from_hint_obj
This commit is contained in:
parent
bd1a199f3e
commit
2f4e4d3031
|
@ -8,7 +8,7 @@ from automat import MethodicalMachine
|
|||
from zope.interface import implementer
|
||||
from twisted.internet.task import deferLater
|
||||
from twisted.internet.defer import DeferredList
|
||||
from twisted.internet.endpoints import HostnameEndpoint, serverFromString
|
||||
from twisted.internet.endpoints import serverFromString
|
||||
from twisted.internet.protocol import ClientFactory, ServerFactory
|
||||
from twisted.python import log
|
||||
from hkdf import Hkdf
|
||||
|
@ -19,18 +19,8 @@ from ..observer import EmptyableSet
|
|||
from .connection import DilatedConnectionProtocol, KCM
|
||||
from .roles import LEADER
|
||||
|
||||
from .._hints import parse_hint_argv, DirectTCPV1Hint, TorTCPV1Hint, RelayV1Hint
|
||||
|
||||
def describe_hint_obj(hint, relay, tor):
|
||||
prefix = "tor->" if tor else "->"
|
||||
if relay:
|
||||
prefix = prefix + "relay:"
|
||||
if isinstance(hint, DirectTCPV1Hint):
|
||||
return prefix + "tcp:%s:%d" % (hint.hostname, hint.port)
|
||||
elif isinstance(hint, TorTCPV1Hint):
|
||||
return prefix + "tor:%s:%d" % (hint.hostname, hint.port)
|
||||
else:
|
||||
return prefix + str(hint)
|
||||
from .._hints import (DirectTCPV1Hint, TorTCPV1Hint, RelayV1Hint,
|
||||
parse_hint_argv, describe_hint_obj, endpoint_from_hint_obj)
|
||||
|
||||
|
||||
def parse_tcp_v1_hint(hint): # hint_struct -> hint_obj
|
||||
|
@ -340,7 +330,7 @@ class Connector(object):
|
|||
for h in direct[p]:
|
||||
if isinstance(h, TorTCPV1Hint) and not self._tor:
|
||||
continue
|
||||
ep = self._endpoint_from_hint_obj(h)
|
||||
ep = endpoint_from_hint_obj(h, self._tor, self._reactor)
|
||||
desc = describe_hint_obj(h, False, self._tor)
|
||||
d = deferLater(self._reactor, delay,
|
||||
self._connect, ep, desc, is_relay=False)
|
||||
|
@ -376,7 +366,7 @@ class Connector(object):
|
|||
for p in priorities:
|
||||
for r in relays[p]:
|
||||
for h in r.hints:
|
||||
ep = self._endpoint_from_hint_obj(h)
|
||||
ep = endpoint_from_hint_obj(h, self._tor, self._reactor)
|
||||
desc = describe_hint_obj(h, True, self._tor)
|
||||
d = deferLater(self._reactor, delay,
|
||||
self._connect, ep, desc, is_relay=True)
|
||||
|
@ -405,20 +395,6 @@ class Connector(object):
|
|||
d.addCallback(_connected)
|
||||
return d
|
||||
|
||||
def _endpoint_from_hint_obj(self, hint):
|
||||
if self._tor:
|
||||
if isinstance(hint, (DirectTCPV1Hint, TorTCPV1Hint)):
|
||||
# this Tor object will throw ValueError for non-public IPv4
|
||||
# addresses and any IPv6 address
|
||||
try:
|
||||
return self._tor.stream_via(hint.hostname, hint.port)
|
||||
except ValueError:
|
||||
return None
|
||||
return None
|
||||
if isinstance(hint, DirectTCPV1Hint):
|
||||
return HostnameEndpoint(self._reactor, hint.hostname, hint.port)
|
||||
return None
|
||||
|
||||
# Connection selection. All instances of DilatedConnectionProtocol which
|
||||
# look viable get passed into our add_contender() method.
|
||||
|
||||
|
|
|
@ -2,6 +2,7 @@ from __future__ import print_function, unicode_literals
|
|||
import sys
|
||||
import re
|
||||
from collections import namedtuple
|
||||
from twisted.internet.endpoints import HostnameEndpoint
|
||||
|
||||
# These namedtuples are "hint objects". The JSON-serializable dictionaries
|
||||
# are "hint dicts".
|
||||
|
@ -21,6 +22,17 @@ TorTCPV1Hint = namedtuple("TorTCPV1Hint", ["hostname", "port", "priority"])
|
|||
# rest of the V1 protocol. Only one hint per relay is useful.
|
||||
RelayV1Hint = namedtuple("RelayV1Hint", ["hints"])
|
||||
|
||||
def describe_hint_obj(hint, relay, tor):
|
||||
prefix = "tor->" if tor else "->"
|
||||
if relay:
|
||||
prefix = prefix + "relay:"
|
||||
if isinstance(hint, DirectTCPV1Hint):
|
||||
return prefix + "tcp:%s:%d" % (hint.hostname, hint.port)
|
||||
elif isinstance(hint, TorTCPV1Hint):
|
||||
return prefix + "tor:%s:%d" % (hint.hostname, hint.port)
|
||||
else:
|
||||
return prefix + str(hint)
|
||||
|
||||
def parse_hint_argv(hint, stderr=sys.stderr):
|
||||
assert isinstance(hint, type(u""))
|
||||
# return tuple or None for an unparseable hint
|
||||
|
@ -56,3 +68,17 @@ def parse_hint_argv(hint, stderr=sys.stderr):
|
|||
file=stderr)
|
||||
return None
|
||||
return DirectTCPV1Hint(hint_host, hint_port, priority)
|
||||
|
||||
def endpoint_from_hint_obj(hint, tor, reactor):
|
||||
if tor:
|
||||
if isinstance(hint, (DirectTCPV1Hint, TorTCPV1Hint)):
|
||||
# this Tor object will throw ValueError for non-public IPv4
|
||||
# addresses and any IPv6 address
|
||||
try:
|
||||
return tor.stream_via(hint.hostname, hint.port)
|
||||
except ValueError:
|
||||
return None
|
||||
return None
|
||||
if isinstance(hint, DirectTCPV1Hint):
|
||||
return HostnameEndpoint(reactor, hint.hostname, hint.port)
|
||||
return None
|
||||
|
|
|
@ -8,7 +8,7 @@ from collections import namedtuple
|
|||
import six
|
||||
from nacl.exceptions import CryptoError
|
||||
from nacl.secret import SecretBox
|
||||
from twisted.internet import address, defer, endpoints, error, protocol, task
|
||||
from twisted.internet import address, defer, endpoints, error, protocol, task, reactor
|
||||
from twisted.internet.defer import gatherResults, inlineCallbacks
|
||||
from twisted.python import log
|
||||
from twisted.test import proto_helpers
|
||||
|
@ -18,6 +18,7 @@ import mock
|
|||
from wormhole_transit_relay import transit_server
|
||||
|
||||
from .. import transit
|
||||
from .._hints import endpoint_from_hint_obj
|
||||
from ..errors import InternalError
|
||||
from .common import ServerBase
|
||||
|
||||
|
@ -145,30 +146,32 @@ UnknownHint = namedtuple("UnknownHint", ["stuff"])
|
|||
|
||||
class Hints(unittest.TestCase):
|
||||
def test_endpoint_from_hint_obj(self):
|
||||
c = transit.Common("")
|
||||
efho = c._endpoint_from_hint_obj
|
||||
def efho(hint, tor=None):
|
||||
return endpoint_from_hint_obj(hint, tor, reactor)
|
||||
self.assertIsInstance(
|
||||
efho(transit.DirectTCPV1Hint("host", 1234, 0.0)),
|
||||
endpoints.HostnameEndpoint)
|
||||
self.assertEqual(efho("unknown:stuff:yowza:pivlor"), None)
|
||||
# c._tor is currently None
|
||||
self.assertEqual(efho(transit.TorTCPV1Hint("host", "port", 0)), None)
|
||||
c._tor = mock.Mock()
|
||||
|
||||
# tor=None
|
||||
self.assertEqual(efho(transit.TorTCPV1Hint("host", "port", 0)), None)
|
||||
|
||||
tor = mock.Mock()
|
||||
def tor_ep(hostname, port):
|
||||
if hostname == "non-public":
|
||||
return None
|
||||
return ("tor_ep", hostname, port)
|
||||
tor.stream_via = mock.Mock(side_effect=tor_ep)
|
||||
|
||||
c._tor.stream_via = mock.Mock(side_effect=tor_ep)
|
||||
self.assertEqual(
|
||||
efho(transit.DirectTCPV1Hint("host", 1234, 0.0)),
|
||||
efho(transit.DirectTCPV1Hint("host", 1234, 0.0), tor),
|
||||
("tor_ep", "host", 1234))
|
||||
self.assertEqual(
|
||||
efho(transit.TorTCPV1Hint("host2.onion", 1234, 0.0)),
|
||||
efho(transit.TorTCPV1Hint("host2.onion", 1234, 0.0), tor),
|
||||
("tor_ep", "host2.onion", 1234))
|
||||
self.assertEqual(
|
||||
efho(transit.DirectTCPV1Hint("non-public", 1234, 0.0)), None)
|
||||
efho(transit.DirectTCPV1Hint("non-public", 1234, 0.0), tor), None)
|
||||
|
||||
self.assertEqual(efho(UnknownHint("foo")), None)
|
||||
|
||||
def test_comparable(self):
|
||||
|
@ -270,10 +273,13 @@ class Hints(unittest.TestCase):
|
|||
def test_describe_hint_obj(self):
|
||||
d = transit.describe_hint_obj
|
||||
self.assertEqual(
|
||||
d(transit.DirectTCPV1Hint("host", 1234, 0.0)), "tcp:host:1234")
|
||||
d(transit.DirectTCPV1Hint("host", 1234, 0.0), False, False),
|
||||
"->tcp:host:1234")
|
||||
self.assertEqual(
|
||||
d(transit.TorTCPV1Hint("host", 1234, 0.0)), "tor:host:1234")
|
||||
self.assertEqual(d(UnknownHint("stuff")), str(UnknownHint("stuff")))
|
||||
d(transit.TorTCPV1Hint("host", 1234, 0.0), False, False),
|
||||
"->tor:host:1234")
|
||||
self.assertEqual(d(UnknownHint("stuff"), False, False),
|
||||
"->%s" % str(UnknownHint("stuff")))
|
||||
|
||||
|
||||
# ipaddrs.py currently uses native strings: bytes on py2, unicode on
|
||||
|
@ -1507,7 +1513,7 @@ class Transit(unittest.TestCase):
|
|||
self.assertEqual(self.successResultOf(d), "winner")
|
||||
self.assertEqual(self._descriptions, ["tor->relay:tcp:relay:1234"])
|
||||
|
||||
def _endpoint_from_hint_obj(self, hint):
|
||||
def _endpoint_from_hint_obj(self, hint, _tor, _reactor):
|
||||
if isinstance(hint, transit.DirectTCPV1Hint):
|
||||
if hint.hostname == "unavailable":
|
||||
return None
|
||||
|
@ -1523,20 +1529,21 @@ class Transit(unittest.TestCase):
|
|||
del hints
|
||||
s.add_connection_hints(
|
||||
[DIRECT_HINT_JSON, UNRECOGNIZED_HINT_JSON, RELAY_HINT_JSON])
|
||||
s._endpoint_from_hint_obj = self._endpoint_from_hint_obj
|
||||
s._start_connector = self._start_connector
|
||||
|
||||
d = s.connect()
|
||||
self.assertNoResult(d)
|
||||
# the direct connectors are tried right away, but the relay
|
||||
# connectors are stalled for a few seconds
|
||||
self.assertEqual(self._connectors, ["direct"])
|
||||
with mock.patch("wormhole.transit.endpoint_from_hint_obj",
|
||||
self._endpoint_from_hint_obj):
|
||||
d = s.connect()
|
||||
self.assertNoResult(d)
|
||||
# the direct connectors are tried right away, but the relay
|
||||
# connectors are stalled for a few seconds
|
||||
self.assertEqual(self._connectors, ["direct"])
|
||||
|
||||
clock.advance(s.RELAY_DELAY + 1.0)
|
||||
self.assertEqual(self._connectors, ["direct", "relay"])
|
||||
clock.advance(s.RELAY_DELAY + 1.0)
|
||||
self.assertEqual(self._connectors, ["direct", "relay"])
|
||||
|
||||
self._waiters[0].callback("winner")
|
||||
self.assertEqual(self.successResultOf(d), "winner")
|
||||
self._waiters[0].callback("winner")
|
||||
self.assertEqual(self.successResultOf(d), "winner")
|
||||
|
||||
@inlineCallbacks
|
||||
def test_priorities(self):
|
||||
|
@ -1586,31 +1593,32 @@ class Transit(unittest.TestCase):
|
|||
}]
|
||||
},
|
||||
])
|
||||
s._endpoint_from_hint_obj = self._endpoint_from_hint_obj
|
||||
s._start_connector = self._start_connector
|
||||
|
||||
d = s.connect()
|
||||
self.assertNoResult(d)
|
||||
# direct connector should be used first, then the priority=3.0 relay,
|
||||
# then the two 2.0 relays, then the (default) 0.0 relay
|
||||
with mock.patch("wormhole.transit.endpoint_from_hint_obj",
|
||||
self._endpoint_from_hint_obj):
|
||||
d = s.connect()
|
||||
self.assertNoResult(d)
|
||||
# direct connector should be used first, then the priority=3.0 relay,
|
||||
# then the two 2.0 relays, then the (default) 0.0 relay
|
||||
|
||||
self.assertEqual(self._connectors, ["direct"])
|
||||
self.assertEqual(self._connectors, ["direct"])
|
||||
|
||||
clock.advance(s.RELAY_DELAY + 1.0)
|
||||
self.assertEqual(self._connectors, ["direct", "relay3"])
|
||||
clock.advance(s.RELAY_DELAY + 1.0)
|
||||
self.assertEqual(self._connectors, ["direct", "relay3"])
|
||||
|
||||
clock.advance(s.RELAY_DELAY)
|
||||
self.assertIn(self._connectors,
|
||||
(["direct", "relay3", "relay2", "relay4"],
|
||||
["direct", "relay3", "relay4", "relay2"]))
|
||||
clock.advance(s.RELAY_DELAY)
|
||||
self.assertIn(self._connectors,
|
||||
(["direct", "relay3", "relay2", "relay4"],
|
||||
["direct", "relay3", "relay4", "relay2"]))
|
||||
|
||||
clock.advance(s.RELAY_DELAY)
|
||||
self.assertIn(self._connectors,
|
||||
(["direct", "relay3", "relay2", "relay4", "relay"],
|
||||
["direct", "relay3", "relay4", "relay2", "relay"]))
|
||||
clock.advance(s.RELAY_DELAY)
|
||||
self.assertIn(self._connectors,
|
||||
(["direct", "relay3", "relay2", "relay4", "relay"],
|
||||
["direct", "relay3", "relay4", "relay2", "relay"]))
|
||||
|
||||
self._waiters[0].callback("winner")
|
||||
self.assertEqual(self.successResultOf(d), "winner")
|
||||
self._waiters[0].callback("winner")
|
||||
self.assertEqual(self.successResultOf(d), "winner")
|
||||
|
||||
@inlineCallbacks
|
||||
def test_no_direct_hints(self):
|
||||
|
@ -1624,20 +1632,21 @@ class Transit(unittest.TestCase):
|
|||
UNRECOGNIZED_HINT_JSON, UNAVAILABLE_HINT_JSON, RELAY_HINT2_JSON,
|
||||
UNAVAILABLE_RELAY_HINT_JSON
|
||||
])
|
||||
s._endpoint_from_hint_obj = self._endpoint_from_hint_obj
|
||||
s._start_connector = self._start_connector
|
||||
|
||||
d = s.connect()
|
||||
self.assertNoResult(d)
|
||||
# since there are no usable direct hints, the relay connector will
|
||||
# only be stalled for 0 seconds
|
||||
self.assertEqual(self._connectors, [])
|
||||
with mock.patch("wormhole.transit.endpoint_from_hint_obj",
|
||||
self._endpoint_from_hint_obj):
|
||||
d = s.connect()
|
||||
self.assertNoResult(d)
|
||||
# since there are no usable direct hints, the relay connector will
|
||||
# only be stalled for 0 seconds
|
||||
self.assertEqual(self._connectors, [])
|
||||
|
||||
clock.advance(0)
|
||||
self.assertEqual(self._connectors, ["relay"])
|
||||
clock.advance(0)
|
||||
self.assertEqual(self._connectors, ["relay"])
|
||||
|
||||
self._waiters[0].callback("winner")
|
||||
self.assertEqual(self.successResultOf(d), "winner")
|
||||
self._waiters[0].callback("winner")
|
||||
self.assertEqual(self.successResultOf(d), "winner")
|
||||
|
||||
@inlineCallbacks
|
||||
def test_no_contenders(self):
|
||||
|
@ -1647,12 +1656,13 @@ class Transit(unittest.TestCase):
|
|||
hints = yield s.get_connection_hints() # start the listener
|
||||
del hints
|
||||
s.add_connection_hints([]) # no hints at all
|
||||
s._endpoint_from_hint_obj = self._endpoint_from_hint_obj
|
||||
s._start_connector = self._start_connector
|
||||
|
||||
d = s.connect()
|
||||
f = self.failureResultOf(d, transit.TransitError)
|
||||
self.assertEqual(str(f.value), "No contenders for connection")
|
||||
with mock.patch("wormhole.transit.endpoint_from_hint_obj",
|
||||
self._endpoint_from_hint_obj):
|
||||
d = s.connect()
|
||||
f = self.failureResultOf(d, transit.TransitError)
|
||||
self.assertEqual(str(f.value), "No contenders for connection")
|
||||
|
||||
|
||||
class RelayHandshake(unittest.TestCase):
|
||||
|
|
|
@ -23,6 +23,8 @@ from . import ipaddrs
|
|||
from .errors import InternalError
|
||||
from .timing import DebugTiming
|
||||
from .util import bytes_to_hexstr
|
||||
from ._hints import (DirectTCPV1Hint, TorTCPV1Hint, RelayV1Hint,
|
||||
parse_hint_argv, describe_hint_obj, endpoint_from_hint_obj)
|
||||
|
||||
|
||||
def HKDF(skm, outlen, salt=None, CTXinfo=b""):
|
||||
|
@ -94,16 +96,6 @@ def build_sided_relay_handshake(key, side):
|
|||
"ascii") + b"\n"
|
||||
|
||||
|
||||
from ._hints import parse_hint_argv, DirectTCPV1Hint, TorTCPV1Hint, RelayV1Hint
|
||||
|
||||
def describe_hint_obj(hint):
|
||||
if isinstance(hint, DirectTCPV1Hint):
|
||||
return u"tcp:%s:%d" % (hint.hostname, hint.port)
|
||||
elif isinstance(hint, TorTCPV1Hint):
|
||||
return u"tor:%s:%d" % (hint.hostname, hint.port)
|
||||
else:
|
||||
return str(hint)
|
||||
|
||||
|
||||
TIMEOUT = 60 # seconds
|
||||
|
||||
|
@ -818,13 +810,11 @@ class Common:
|
|||
# Check the hint type to see if we can support it (e.g. skip
|
||||
# onion hints on a non-Tor client). Do not increase relay_delay
|
||||
# unless we have at least one viable hint.
|
||||
ep = self._endpoint_from_hint_obj(hint_obj)
|
||||
ep = endpoint_from_hint_obj(hint_obj, self._tor, self._reactor)
|
||||
if not ep:
|
||||
continue
|
||||
description = "->%s" % describe_hint_obj(hint_obj)
|
||||
if self._tor:
|
||||
description = "tor" + description
|
||||
d = self._start_connector(ep, description)
|
||||
d = self._start_connector(ep,
|
||||
describe_hint_obj(hint_obj, False, self._tor))
|
||||
contenders.append(d)
|
||||
relay_delay = self.RELAY_DELAY
|
||||
|
||||
|
@ -845,18 +835,15 @@ class Common:
|
|||
|
||||
for priority in sorted(prioritized_relays, reverse=True):
|
||||
for hint_obj in prioritized_relays[priority]:
|
||||
ep = self._endpoint_from_hint_obj(hint_obj)
|
||||
ep = endpoint_from_hint_obj(hint_obj, self._tor, self._reactor)
|
||||
if not ep:
|
||||
continue
|
||||
description = "->relay:%s" % describe_hint_obj(hint_obj)
|
||||
if self._tor:
|
||||
description = "tor" + description
|
||||
d = task.deferLater(
|
||||
self._reactor,
|
||||
relay_delay,
|
||||
self._start_connector,
|
||||
ep,
|
||||
description,
|
||||
describe_hint_obj(hint_obj, True, self._tor),
|
||||
is_relay=True)
|
||||
contenders.append(d)
|
||||
relay_delay += self.RELAY_DELAY
|
||||
|
@ -894,21 +881,6 @@ class Common:
|
|||
d.addCallback(lambda p: p.startNegotiation())
|
||||
return d
|
||||
|
||||
def _endpoint_from_hint_obj(self, hint):
|
||||
if self._tor:
|
||||
if isinstance(hint, (DirectTCPV1Hint, TorTCPV1Hint)):
|
||||
# this Tor object will throw ValueError for non-public IPv4
|
||||
# addresses and any IPv6 address
|
||||
try:
|
||||
return self._tor.stream_via(hint.hostname, hint.port)
|
||||
except ValueError:
|
||||
return None
|
||||
return None
|
||||
if isinstance(hint, DirectTCPV1Hint):
|
||||
return endpoints.HostnameEndpoint(self._reactor, hint.hostname,
|
||||
hint.port)
|
||||
return None
|
||||
|
||||
def connection_ready(self, p):
|
||||
# inbound/outbound Connection protocols call this when they finish
|
||||
# negotiation. The first one wins and gets a "go". Any subsequent
|
||||
|
|
Loading…
Reference in New Issue
Block a user