factor out parse_tcp_v1_hint
This commit is contained in:
parent
2f4e4d3031
commit
7720312c8f
|
@ -20,27 +20,8 @@ from .connection import DilatedConnectionProtocol, KCM
|
|||
from .roles import LEADER
|
||||
|
||||
from .._hints import (DirectTCPV1Hint, TorTCPV1Hint, RelayV1Hint,
|
||||
parse_hint_argv, describe_hint_obj, endpoint_from_hint_obj)
|
||||
|
||||
|
||||
def parse_tcp_v1_hint(hint): # hint_struct -> hint_obj
|
||||
hint_type = hint.get("type", "")
|
||||
if hint_type not in ["direct-tcp-v1", "tor-tcp-v1"]:
|
||||
log.msg("unknown hint type: %r" % (hint,))
|
||||
return None
|
||||
if not("hostname" in hint and
|
||||
isinstance(hint["hostname"], type(""))):
|
||||
log.msg("invalid hostname in hint: %r" % (hint,))
|
||||
return None
|
||||
if not("port" in hint and
|
||||
isinstance(hint["port"], six.integer_types)):
|
||||
log.msg("invalid port in hint: %r" % (hint,))
|
||||
return None
|
||||
priority = hint.get("priority", 0.0)
|
||||
if hint_type == "direct-tcp-v1":
|
||||
return DirectTCPV1Hint(hint["hostname"], hint["port"], priority)
|
||||
else:
|
||||
return TorTCPV1Hint(hint["hostname"], hint["port"], priority)
|
||||
parse_hint_argv, describe_hint_obj, endpoint_from_hint_obj,
|
||||
parse_tcp_v1_hint)
|
||||
|
||||
|
||||
def parse_hint(hint_struct):
|
||||
|
|
|
@ -1,8 +1,10 @@
|
|||
from __future__ import print_function, unicode_literals
|
||||
import sys
|
||||
import re
|
||||
import six
|
||||
from collections import namedtuple
|
||||
from twisted.internet.endpoints import HostnameEndpoint
|
||||
from twisted.python import log
|
||||
|
||||
# These namedtuples are "hint objects". The JSON-serializable dictionaries
|
||||
# are "hint dicts".
|
||||
|
@ -82,3 +84,22 @@ def endpoint_from_hint_obj(hint, tor, reactor):
|
|||
if isinstance(hint, DirectTCPV1Hint):
|
||||
return HostnameEndpoint(reactor, hint.hostname, hint.port)
|
||||
return None
|
||||
|
||||
def parse_tcp_v1_hint(hint): # hint_struct -> hint_obj
|
||||
hint_type = hint.get("type", "")
|
||||
if hint_type not in ["direct-tcp-v1", "tor-tcp-v1"]:
|
||||
log.msg("unknown hint type: %r" % (hint, ))
|
||||
return None
|
||||
if not ("hostname" in hint and
|
||||
isinstance(hint["hostname"], type(""))):
|
||||
log.msg("invalid hostname in hint: %r" % (hint, ))
|
||||
return None
|
||||
if not ("port" in hint and
|
||||
isinstance(hint["port"], six.integer_types)):
|
||||
log.msg("invalid port in hint: %r" % (hint, ))
|
||||
return None
|
||||
priority = hint.get("priority", 0.0)
|
||||
if hint_type == "direct-tcp-v1":
|
||||
return DirectTCPV1Hint(hint["hostname"], hint["port"], priority)
|
||||
else:
|
||||
return TorTCPV1Hint(hint["hostname"], hint["port"], priority)
|
||||
|
|
|
@ -18,7 +18,8 @@ import mock
|
|||
from wormhole_transit_relay import transit_server
|
||||
|
||||
from .. import transit
|
||||
from .._hints import endpoint_from_hint_obj
|
||||
from .._hints import (endpoint_from_hint_obj, parse_hint_argv, parse_tcp_v1_hint,
|
||||
DirectTCPV1Hint, TorTCPV1Hint, RelayV1Hint)
|
||||
from ..errors import InternalError
|
||||
from .common import ServerBase
|
||||
|
||||
|
@ -148,13 +149,12 @@ class Hints(unittest.TestCase):
|
|||
def test_endpoint_from_hint_obj(self):
|
||||
def efho(hint, tor=None):
|
||||
return endpoint_from_hint_obj(hint, tor, reactor)
|
||||
self.assertIsInstance(
|
||||
efho(transit.DirectTCPV1Hint("host", 1234, 0.0)),
|
||||
endpoints.HostnameEndpoint)
|
||||
self.assertIsInstance(efho(DirectTCPV1Hint("host", 1234, 0.0)),
|
||||
endpoints.HostnameEndpoint)
|
||||
self.assertEqual(efho("unknown:stuff:yowza:pivlor"), None)
|
||||
|
||||
# tor=None
|
||||
self.assertEqual(efho(transit.TorTCPV1Hint("host", "port", 0)), None)
|
||||
self.assertEqual(efho(TorTCPV1Hint("host", "port", 0)), None)
|
||||
|
||||
tor = mock.Mock()
|
||||
def tor_ep(hostname, port):
|
||||
|
@ -163,50 +163,46 @@ class Hints(unittest.TestCase):
|
|||
return ("tor_ep", hostname, port)
|
||||
tor.stream_via = mock.Mock(side_effect=tor_ep)
|
||||
|
||||
self.assertEqual(
|
||||
efho(transit.DirectTCPV1Hint("host", 1234, 0.0), tor),
|
||||
("tor_ep", "host", 1234))
|
||||
self.assertEqual(
|
||||
efho(transit.TorTCPV1Hint("host2.onion", 1234, 0.0), tor),
|
||||
("tor_ep", "host2.onion", 1234))
|
||||
self.assertEqual(
|
||||
efho(transit.DirectTCPV1Hint("non-public", 1234, 0.0), tor), None)
|
||||
self.assertEqual(efho(DirectTCPV1Hint("host", 1234, 0.0), tor),
|
||||
("tor_ep", "host", 1234))
|
||||
self.assertEqual(efho(TorTCPV1Hint("host2.onion", 1234, 0.0), tor),
|
||||
("tor_ep", "host2.onion", 1234))
|
||||
self.assertEqual( efho(DirectTCPV1Hint("non-public", 1234, 0.0), tor), None)
|
||||
|
||||
self.assertEqual(efho(UnknownHint("foo")), None)
|
||||
|
||||
def test_comparable(self):
|
||||
h1 = transit.DirectTCPV1Hint("hostname", "port1", 0.0)
|
||||
h1b = transit.DirectTCPV1Hint("hostname", "port1", 0.0)
|
||||
h2 = transit.DirectTCPV1Hint("hostname", "port2", 0.0)
|
||||
r1 = transit.RelayV1Hint(tuple(sorted([h1, h2])))
|
||||
r2 = transit.RelayV1Hint(tuple(sorted([h2, h1])))
|
||||
r3 = transit.RelayV1Hint(tuple(sorted([h1b, h2])))
|
||||
h1 = DirectTCPV1Hint("hostname", "port1", 0.0)
|
||||
h1b = DirectTCPV1Hint("hostname", "port1", 0.0)
|
||||
h2 = DirectTCPV1Hint("hostname", "port2", 0.0)
|
||||
r1 = RelayV1Hint(tuple(sorted([h1, h2])))
|
||||
r2 = RelayV1Hint(tuple(sorted([h2, h1])))
|
||||
r3 = RelayV1Hint(tuple(sorted([h1b, h2])))
|
||||
self.assertEqual(r1, r2)
|
||||
self.assertEqual(r2, r3)
|
||||
self.assertEqual(len(set([r1, r2, r3])), 1)
|
||||
|
||||
def test_parse_tcp_v1_hint(self):
|
||||
c = transit.Common("")
|
||||
p = c._parse_tcp_v1_hint
|
||||
p = parse_tcp_v1_hint
|
||||
self.assertEqual(p({"type": "unknown"}), None)
|
||||
h = p({"type": "direct-tcp-v1", "hostname": "foo", "port": 1234})
|
||||
self.assertEqual(h, transit.DirectTCPV1Hint("foo", 1234, 0.0))
|
||||
self.assertEqual(h, DirectTCPV1Hint("foo", 1234, 0.0))
|
||||
h = p({
|
||||
"type": "direct-tcp-v1",
|
||||
"hostname": "foo",
|
||||
"port": 1234,
|
||||
"priority": 2.5
|
||||
})
|
||||
self.assertEqual(h, transit.DirectTCPV1Hint("foo", 1234, 2.5))
|
||||
self.assertEqual(h, DirectTCPV1Hint("foo", 1234, 2.5))
|
||||
h = p({"type": "tor-tcp-v1", "hostname": "foo", "port": 1234})
|
||||
self.assertEqual(h, transit.TorTCPV1Hint("foo", 1234, 0.0))
|
||||
self.assertEqual(h, TorTCPV1Hint("foo", 1234, 0.0))
|
||||
h = p({
|
||||
"type": "tor-tcp-v1",
|
||||
"hostname": "foo",
|
||||
"port": 1234,
|
||||
"priority": 2.5
|
||||
})
|
||||
self.assertEqual(h, transit.TorTCPV1Hint("foo", 1234, 2.5))
|
||||
self.assertEqual(h, TorTCPV1Hint("foo", 1234, 2.5))
|
||||
self.assertEqual(p({
|
||||
"type": "direct-tcp-v1"
|
||||
}), None) # missing hostname
|
||||
|
@ -229,19 +225,19 @@ class Hints(unittest.TestCase):
|
|||
def test_parse_hint_argv(self):
|
||||
def p(hint):
|
||||
stderr = io.StringIO()
|
||||
value = transit.parse_hint_argv(hint, stderr=stderr)
|
||||
value = parse_hint_argv(hint, stderr=stderr)
|
||||
return value, stderr.getvalue()
|
||||
|
||||
h, stderr = p("tcp:host:1234")
|
||||
self.assertEqual(h, transit.DirectTCPV1Hint("host", 1234, 0.0))
|
||||
self.assertEqual(h, DirectTCPV1Hint("host", 1234, 0.0))
|
||||
self.assertEqual(stderr, "")
|
||||
|
||||
h, stderr = p("tcp:host:1234:priority=2.6")
|
||||
self.assertEqual(h, transit.DirectTCPV1Hint("host", 1234, 2.6))
|
||||
self.assertEqual(h, DirectTCPV1Hint("host", 1234, 2.6))
|
||||
self.assertEqual(stderr, "")
|
||||
|
||||
h, stderr = p("tcp:host:1234:unknown=stuff")
|
||||
self.assertEqual(h, transit.DirectTCPV1Hint("host", 1234, 0.0))
|
||||
self.assertEqual(h, DirectTCPV1Hint("host", 1234, 0.0))
|
||||
self.assertEqual(stderr, "")
|
||||
|
||||
h, stderr = p("$!@#^")
|
||||
|
@ -272,12 +268,10 @@ class Hints(unittest.TestCase):
|
|||
|
||||
def test_describe_hint_obj(self):
|
||||
d = transit.describe_hint_obj
|
||||
self.assertEqual(
|
||||
d(transit.DirectTCPV1Hint("host", 1234, 0.0), False, False),
|
||||
"->tcp:host:1234")
|
||||
self.assertEqual(
|
||||
d(transit.TorTCPV1Hint("host", 1234, 0.0), False, False),
|
||||
"->tor:host:1234")
|
||||
self.assertEqual(d(DirectTCPV1Hint("host", 1234, 0.0), False, False),
|
||||
"->tcp:host:1234")
|
||||
self.assertEqual(d(TorTCPV1Hint("host", 1234, 0.0), False, False),
|
||||
"->tor:host:1234")
|
||||
self.assertEqual(d(UnknownHint("stuff"), False, False),
|
||||
"->%s" % str(UnknownHint("stuff")))
|
||||
|
||||
|
@ -443,7 +437,7 @@ class Listener(unittest.TestCase):
|
|||
hints, ep = c._build_listener()
|
||||
self.assertIsInstance(hints, (list, set))
|
||||
if hints:
|
||||
self.assertIsInstance(hints[0], transit.DirectTCPV1Hint)
|
||||
self.assertIsInstance(hints[0], DirectTCPV1Hint)
|
||||
self.assertIsInstance(ep, endpoints.TCP4ServerEndpoint)
|
||||
|
||||
def test_get_direct_hints(self):
|
||||
|
@ -1514,7 +1508,7 @@ class Transit(unittest.TestCase):
|
|||
self.assertEqual(self._descriptions, ["tor->relay:tcp:relay:1234"])
|
||||
|
||||
def _endpoint_from_hint_obj(self, hint, _tor, _reactor):
|
||||
if isinstance(hint, transit.DirectTCPV1Hint):
|
||||
if isinstance(hint, DirectTCPV1Hint):
|
||||
if hint.hostname == "unavailable":
|
||||
return None
|
||||
return hint.hostname
|
||||
|
|
|
@ -23,8 +23,9 @@ from . import ipaddrs
|
|||
from .errors import InternalError
|
||||
from .timing import DebugTiming
|
||||
from .util import bytes_to_hexstr
|
||||
from ._hints import (DirectTCPV1Hint, TorTCPV1Hint, RelayV1Hint,
|
||||
parse_hint_argv, describe_hint_obj, endpoint_from_hint_obj)
|
||||
from ._hints import (DirectTCPV1Hint, RelayV1Hint,
|
||||
parse_hint_argv, describe_hint_obj, endpoint_from_hint_obj,
|
||||
parse_tcp_v1_hint)
|
||||
|
||||
|
||||
def HKDF(skm, outlen, salt=None, CTXinfo=b""):
|
||||
|
@ -681,30 +682,11 @@ class Common:
|
|||
self._listener_d.addErrback(lambda f: None)
|
||||
self._listener_d.cancel()
|
||||
|
||||
def _parse_tcp_v1_hint(self, hint): # hint_struct -> hint_obj
|
||||
hint_type = hint.get(u"type", u"")
|
||||
if hint_type not in [u"direct-tcp-v1", u"tor-tcp-v1"]:
|
||||
log.msg("unknown hint type: %r" % (hint, ))
|
||||
return None
|
||||
if not (u"hostname" in hint and
|
||||
isinstance(hint[u"hostname"], type(u""))):
|
||||
log.msg("invalid hostname in hint: %r" % (hint, ))
|
||||
return None
|
||||
if not (u"port" in hint and
|
||||
isinstance(hint[u"port"], six.integer_types)):
|
||||
log.msg("invalid port in hint: %r" % (hint, ))
|
||||
return None
|
||||
priority = hint.get(u"priority", 0.0)
|
||||
if hint_type == u"direct-tcp-v1":
|
||||
return DirectTCPV1Hint(hint[u"hostname"], hint[u"port"], priority)
|
||||
else:
|
||||
return TorTCPV1Hint(hint[u"hostname"], hint[u"port"], priority)
|
||||
|
||||
def add_connection_hints(self, hints):
|
||||
for h in hints: # hint structs
|
||||
hint_type = h.get(u"type", u"")
|
||||
if hint_type in [u"direct-tcp-v1", u"tor-tcp-v1"]:
|
||||
dh = self._parse_tcp_v1_hint(h)
|
||||
dh = parse_tcp_v1_hint(h)
|
||||
if dh:
|
||||
self._their_direct_hints.append(dh) # hint_obj
|
||||
elif hint_type == u"relay-v1":
|
||||
|
@ -714,7 +696,7 @@ class Common:
|
|||
# together like this.
|
||||
relay_hints = []
|
||||
for rhs in h.get(u"hints", []):
|
||||
h = self._parse_tcp_v1_hint(rhs)
|
||||
h = parse_tcp_v1_hint(rhs)
|
||||
if h:
|
||||
relay_hints.append(h)
|
||||
if relay_hints:
|
||||
|
|
Loading…
Reference in New Issue
Block a user