factor out parse_tcp_v1_hint

This commit is contained in:
Brian Warner 2018-12-21 23:22:02 -05:00
parent 2f4e4d3031
commit 7720312c8f
4 changed files with 59 additions and 81 deletions

View File

@ -20,27 +20,8 @@ from .connection import DilatedConnectionProtocol, KCM
from .roles import LEADER
from .._hints import (DirectTCPV1Hint, TorTCPV1Hint, RelayV1Hint,
parse_hint_argv, describe_hint_obj, endpoint_from_hint_obj)
def parse_tcp_v1_hint(hint): # hint_struct -> hint_obj
hint_type = hint.get("type", "")
if hint_type not in ["direct-tcp-v1", "tor-tcp-v1"]:
log.msg("unknown hint type: %r" % (hint,))
return None
if not("hostname" in hint and
isinstance(hint["hostname"], type(""))):
log.msg("invalid hostname in hint: %r" % (hint,))
return None
if not("port" in hint and
isinstance(hint["port"], six.integer_types)):
log.msg("invalid port in hint: %r" % (hint,))
return None
priority = hint.get("priority", 0.0)
if hint_type == "direct-tcp-v1":
return DirectTCPV1Hint(hint["hostname"], hint["port"], priority)
else:
return TorTCPV1Hint(hint["hostname"], hint["port"], priority)
parse_hint_argv, describe_hint_obj, endpoint_from_hint_obj,
parse_tcp_v1_hint)
def parse_hint(hint_struct):

View File

@ -1,8 +1,10 @@
from __future__ import print_function, unicode_literals
import sys
import re
import six
from collections import namedtuple
from twisted.internet.endpoints import HostnameEndpoint
from twisted.python import log
# These namedtuples are "hint objects". The JSON-serializable dictionaries
# are "hint dicts".
@ -82,3 +84,22 @@ def endpoint_from_hint_obj(hint, tor, reactor):
if isinstance(hint, DirectTCPV1Hint):
return HostnameEndpoint(reactor, hint.hostname, hint.port)
return None
def parse_tcp_v1_hint(hint): # hint_struct -> hint_obj
hint_type = hint.get("type", "")
if hint_type not in ["direct-tcp-v1", "tor-tcp-v1"]:
log.msg("unknown hint type: %r" % (hint, ))
return None
if not ("hostname" in hint and
isinstance(hint["hostname"], type(""))):
log.msg("invalid hostname in hint: %r" % (hint, ))
return None
if not ("port" in hint and
isinstance(hint["port"], six.integer_types)):
log.msg("invalid port in hint: %r" % (hint, ))
return None
priority = hint.get("priority", 0.0)
if hint_type == "direct-tcp-v1":
return DirectTCPV1Hint(hint["hostname"], hint["port"], priority)
else:
return TorTCPV1Hint(hint["hostname"], hint["port"], priority)

View File

@ -18,7 +18,8 @@ import mock
from wormhole_transit_relay import transit_server
from .. import transit
from .._hints import endpoint_from_hint_obj
from .._hints import (endpoint_from_hint_obj, parse_hint_argv, parse_tcp_v1_hint,
DirectTCPV1Hint, TorTCPV1Hint, RelayV1Hint)
from ..errors import InternalError
from .common import ServerBase
@ -148,13 +149,12 @@ class Hints(unittest.TestCase):
def test_endpoint_from_hint_obj(self):
def efho(hint, tor=None):
return endpoint_from_hint_obj(hint, tor, reactor)
self.assertIsInstance(
efho(transit.DirectTCPV1Hint("host", 1234, 0.0)),
endpoints.HostnameEndpoint)
self.assertIsInstance(efho(DirectTCPV1Hint("host", 1234, 0.0)),
endpoints.HostnameEndpoint)
self.assertEqual(efho("unknown:stuff:yowza:pivlor"), None)
# tor=None
self.assertEqual(efho(transit.TorTCPV1Hint("host", "port", 0)), None)
self.assertEqual(efho(TorTCPV1Hint("host", "port", 0)), None)
tor = mock.Mock()
def tor_ep(hostname, port):
@ -163,50 +163,46 @@ class Hints(unittest.TestCase):
return ("tor_ep", hostname, port)
tor.stream_via = mock.Mock(side_effect=tor_ep)
self.assertEqual(
efho(transit.DirectTCPV1Hint("host", 1234, 0.0), tor),
("tor_ep", "host", 1234))
self.assertEqual(
efho(transit.TorTCPV1Hint("host2.onion", 1234, 0.0), tor),
("tor_ep", "host2.onion", 1234))
self.assertEqual(
efho(transit.DirectTCPV1Hint("non-public", 1234, 0.0), tor), None)
self.assertEqual(efho(DirectTCPV1Hint("host", 1234, 0.0), tor),
("tor_ep", "host", 1234))
self.assertEqual(efho(TorTCPV1Hint("host2.onion", 1234, 0.0), tor),
("tor_ep", "host2.onion", 1234))
self.assertEqual( efho(DirectTCPV1Hint("non-public", 1234, 0.0), tor), None)
self.assertEqual(efho(UnknownHint("foo")), None)
def test_comparable(self):
h1 = transit.DirectTCPV1Hint("hostname", "port1", 0.0)
h1b = transit.DirectTCPV1Hint("hostname", "port1", 0.0)
h2 = transit.DirectTCPV1Hint("hostname", "port2", 0.0)
r1 = transit.RelayV1Hint(tuple(sorted([h1, h2])))
r2 = transit.RelayV1Hint(tuple(sorted([h2, h1])))
r3 = transit.RelayV1Hint(tuple(sorted([h1b, h2])))
h1 = DirectTCPV1Hint("hostname", "port1", 0.0)
h1b = DirectTCPV1Hint("hostname", "port1", 0.0)
h2 = DirectTCPV1Hint("hostname", "port2", 0.0)
r1 = RelayV1Hint(tuple(sorted([h1, h2])))
r2 = RelayV1Hint(tuple(sorted([h2, h1])))
r3 = RelayV1Hint(tuple(sorted([h1b, h2])))
self.assertEqual(r1, r2)
self.assertEqual(r2, r3)
self.assertEqual(len(set([r1, r2, r3])), 1)
def test_parse_tcp_v1_hint(self):
c = transit.Common("")
p = c._parse_tcp_v1_hint
p = parse_tcp_v1_hint
self.assertEqual(p({"type": "unknown"}), None)
h = p({"type": "direct-tcp-v1", "hostname": "foo", "port": 1234})
self.assertEqual(h, transit.DirectTCPV1Hint("foo", 1234, 0.0))
self.assertEqual(h, DirectTCPV1Hint("foo", 1234, 0.0))
h = p({
"type": "direct-tcp-v1",
"hostname": "foo",
"port": 1234,
"priority": 2.5
})
self.assertEqual(h, transit.DirectTCPV1Hint("foo", 1234, 2.5))
self.assertEqual(h, DirectTCPV1Hint("foo", 1234, 2.5))
h = p({"type": "tor-tcp-v1", "hostname": "foo", "port": 1234})
self.assertEqual(h, transit.TorTCPV1Hint("foo", 1234, 0.0))
self.assertEqual(h, TorTCPV1Hint("foo", 1234, 0.0))
h = p({
"type": "tor-tcp-v1",
"hostname": "foo",
"port": 1234,
"priority": 2.5
})
self.assertEqual(h, transit.TorTCPV1Hint("foo", 1234, 2.5))
self.assertEqual(h, TorTCPV1Hint("foo", 1234, 2.5))
self.assertEqual(p({
"type": "direct-tcp-v1"
}), None) # missing hostname
@ -229,19 +225,19 @@ class Hints(unittest.TestCase):
def test_parse_hint_argv(self):
def p(hint):
stderr = io.StringIO()
value = transit.parse_hint_argv(hint, stderr=stderr)
value = parse_hint_argv(hint, stderr=stderr)
return value, stderr.getvalue()
h, stderr = p("tcp:host:1234")
self.assertEqual(h, transit.DirectTCPV1Hint("host", 1234, 0.0))
self.assertEqual(h, DirectTCPV1Hint("host", 1234, 0.0))
self.assertEqual(stderr, "")
h, stderr = p("tcp:host:1234:priority=2.6")
self.assertEqual(h, transit.DirectTCPV1Hint("host", 1234, 2.6))
self.assertEqual(h, DirectTCPV1Hint("host", 1234, 2.6))
self.assertEqual(stderr, "")
h, stderr = p("tcp:host:1234:unknown=stuff")
self.assertEqual(h, transit.DirectTCPV1Hint("host", 1234, 0.0))
self.assertEqual(h, DirectTCPV1Hint("host", 1234, 0.0))
self.assertEqual(stderr, "")
h, stderr = p("$!@#^")
@ -272,12 +268,10 @@ class Hints(unittest.TestCase):
def test_describe_hint_obj(self):
d = transit.describe_hint_obj
self.assertEqual(
d(transit.DirectTCPV1Hint("host", 1234, 0.0), False, False),
"->tcp:host:1234")
self.assertEqual(
d(transit.TorTCPV1Hint("host", 1234, 0.0), False, False),
"->tor:host:1234")
self.assertEqual(d(DirectTCPV1Hint("host", 1234, 0.0), False, False),
"->tcp:host:1234")
self.assertEqual(d(TorTCPV1Hint("host", 1234, 0.0), False, False),
"->tor:host:1234")
self.assertEqual(d(UnknownHint("stuff"), False, False),
"->%s" % str(UnknownHint("stuff")))
@ -443,7 +437,7 @@ class Listener(unittest.TestCase):
hints, ep = c._build_listener()
self.assertIsInstance(hints, (list, set))
if hints:
self.assertIsInstance(hints[0], transit.DirectTCPV1Hint)
self.assertIsInstance(hints[0], DirectTCPV1Hint)
self.assertIsInstance(ep, endpoints.TCP4ServerEndpoint)
def test_get_direct_hints(self):
@ -1514,7 +1508,7 @@ class Transit(unittest.TestCase):
self.assertEqual(self._descriptions, ["tor->relay:tcp:relay:1234"])
def _endpoint_from_hint_obj(self, hint, _tor, _reactor):
if isinstance(hint, transit.DirectTCPV1Hint):
if isinstance(hint, DirectTCPV1Hint):
if hint.hostname == "unavailable":
return None
return hint.hostname

View File

@ -23,8 +23,9 @@ from . import ipaddrs
from .errors import InternalError
from .timing import DebugTiming
from .util import bytes_to_hexstr
from ._hints import (DirectTCPV1Hint, TorTCPV1Hint, RelayV1Hint,
parse_hint_argv, describe_hint_obj, endpoint_from_hint_obj)
from ._hints import (DirectTCPV1Hint, RelayV1Hint,
parse_hint_argv, describe_hint_obj, endpoint_from_hint_obj,
parse_tcp_v1_hint)
def HKDF(skm, outlen, salt=None, CTXinfo=b""):
@ -681,30 +682,11 @@ class Common:
self._listener_d.addErrback(lambda f: None)
self._listener_d.cancel()
def _parse_tcp_v1_hint(self, hint): # hint_struct -> hint_obj
hint_type = hint.get(u"type", u"")
if hint_type not in [u"direct-tcp-v1", u"tor-tcp-v1"]:
log.msg("unknown hint type: %r" % (hint, ))
return None
if not (u"hostname" in hint and
isinstance(hint[u"hostname"], type(u""))):
log.msg("invalid hostname in hint: %r" % (hint, ))
return None
if not (u"port" in hint and
isinstance(hint[u"port"], six.integer_types)):
log.msg("invalid port in hint: %r" % (hint, ))
return None
priority = hint.get(u"priority", 0.0)
if hint_type == u"direct-tcp-v1":
return DirectTCPV1Hint(hint[u"hostname"], hint[u"port"], priority)
else:
return TorTCPV1Hint(hint[u"hostname"], hint[u"port"], priority)
def add_connection_hints(self, hints):
for h in hints: # hint structs
hint_type = h.get(u"type", u"")
if hint_type in [u"direct-tcp-v1", u"tor-tcp-v1"]:
dh = self._parse_tcp_v1_hint(h)
dh = parse_tcp_v1_hint(h)
if dh:
self._their_direct_hints.append(dh) # hint_obj
elif hint_type == u"relay-v1":
@ -714,7 +696,7 @@ class Common:
# together like this.
relay_hints = []
for rhs in h.get(u"hints", []):
h = self._parse_tcp_v1_hint(rhs)
h = parse_tcp_v1_hint(rhs)
if h:
relay_hints.append(h)
if relay_hints: