factor out parse_tcp_v1_hint

This commit is contained in:
Brian Warner 2018-12-21 23:22:02 -05:00
parent 2f4e4d3031
commit 7720312c8f
4 changed files with 59 additions and 81 deletions

View File

@ -20,27 +20,8 @@ from .connection import DilatedConnectionProtocol, KCM
from .roles import LEADER from .roles import LEADER
from .._hints import (DirectTCPV1Hint, TorTCPV1Hint, RelayV1Hint, from .._hints import (DirectTCPV1Hint, TorTCPV1Hint, RelayV1Hint,
parse_hint_argv, describe_hint_obj, endpoint_from_hint_obj) parse_hint_argv, describe_hint_obj, endpoint_from_hint_obj,
parse_tcp_v1_hint)
def parse_tcp_v1_hint(hint): # hint_struct -> hint_obj
hint_type = hint.get("type", "")
if hint_type not in ["direct-tcp-v1", "tor-tcp-v1"]:
log.msg("unknown hint type: %r" % (hint,))
return None
if not("hostname" in hint and
isinstance(hint["hostname"], type(""))):
log.msg("invalid hostname in hint: %r" % (hint,))
return None
if not("port" in hint and
isinstance(hint["port"], six.integer_types)):
log.msg("invalid port in hint: %r" % (hint,))
return None
priority = hint.get("priority", 0.0)
if hint_type == "direct-tcp-v1":
return DirectTCPV1Hint(hint["hostname"], hint["port"], priority)
else:
return TorTCPV1Hint(hint["hostname"], hint["port"], priority)
def parse_hint(hint_struct): def parse_hint(hint_struct):

View File

@ -1,8 +1,10 @@
from __future__ import print_function, unicode_literals from __future__ import print_function, unicode_literals
import sys import sys
import re import re
import six
from collections import namedtuple from collections import namedtuple
from twisted.internet.endpoints import HostnameEndpoint from twisted.internet.endpoints import HostnameEndpoint
from twisted.python import log
# These namedtuples are "hint objects". The JSON-serializable dictionaries # These namedtuples are "hint objects". The JSON-serializable dictionaries
# are "hint dicts". # are "hint dicts".
@ -82,3 +84,22 @@ def endpoint_from_hint_obj(hint, tor, reactor):
if isinstance(hint, DirectTCPV1Hint): if isinstance(hint, DirectTCPV1Hint):
return HostnameEndpoint(reactor, hint.hostname, hint.port) return HostnameEndpoint(reactor, hint.hostname, hint.port)
return None return None
def parse_tcp_v1_hint(hint): # hint_struct -> hint_obj
hint_type = hint.get("type", "")
if hint_type not in ["direct-tcp-v1", "tor-tcp-v1"]:
log.msg("unknown hint type: %r" % (hint, ))
return None
if not ("hostname" in hint and
isinstance(hint["hostname"], type(""))):
log.msg("invalid hostname in hint: %r" % (hint, ))
return None
if not ("port" in hint and
isinstance(hint["port"], six.integer_types)):
log.msg("invalid port in hint: %r" % (hint, ))
return None
priority = hint.get("priority", 0.0)
if hint_type == "direct-tcp-v1":
return DirectTCPV1Hint(hint["hostname"], hint["port"], priority)
else:
return TorTCPV1Hint(hint["hostname"], hint["port"], priority)

View File

@ -18,7 +18,8 @@ import mock
from wormhole_transit_relay import transit_server from wormhole_transit_relay import transit_server
from .. import transit from .. import transit
from .._hints import endpoint_from_hint_obj from .._hints import (endpoint_from_hint_obj, parse_hint_argv, parse_tcp_v1_hint,
DirectTCPV1Hint, TorTCPV1Hint, RelayV1Hint)
from ..errors import InternalError from ..errors import InternalError
from .common import ServerBase from .common import ServerBase
@ -148,13 +149,12 @@ class Hints(unittest.TestCase):
def test_endpoint_from_hint_obj(self): def test_endpoint_from_hint_obj(self):
def efho(hint, tor=None): def efho(hint, tor=None):
return endpoint_from_hint_obj(hint, tor, reactor) return endpoint_from_hint_obj(hint, tor, reactor)
self.assertIsInstance( self.assertIsInstance(efho(DirectTCPV1Hint("host", 1234, 0.0)),
efho(transit.DirectTCPV1Hint("host", 1234, 0.0)), endpoints.HostnameEndpoint)
endpoints.HostnameEndpoint)
self.assertEqual(efho("unknown:stuff:yowza:pivlor"), None) self.assertEqual(efho("unknown:stuff:yowza:pivlor"), None)
# tor=None # tor=None
self.assertEqual(efho(transit.TorTCPV1Hint("host", "port", 0)), None) self.assertEqual(efho(TorTCPV1Hint("host", "port", 0)), None)
tor = mock.Mock() tor = mock.Mock()
def tor_ep(hostname, port): def tor_ep(hostname, port):
@ -163,50 +163,46 @@ class Hints(unittest.TestCase):
return ("tor_ep", hostname, port) return ("tor_ep", hostname, port)
tor.stream_via = mock.Mock(side_effect=tor_ep) tor.stream_via = mock.Mock(side_effect=tor_ep)
self.assertEqual( self.assertEqual(efho(DirectTCPV1Hint("host", 1234, 0.0), tor),
efho(transit.DirectTCPV1Hint("host", 1234, 0.0), tor), ("tor_ep", "host", 1234))
("tor_ep", "host", 1234)) self.assertEqual(efho(TorTCPV1Hint("host2.onion", 1234, 0.0), tor),
self.assertEqual( ("tor_ep", "host2.onion", 1234))
efho(transit.TorTCPV1Hint("host2.onion", 1234, 0.0), tor), self.assertEqual( efho(DirectTCPV1Hint("non-public", 1234, 0.0), tor), None)
("tor_ep", "host2.onion", 1234))
self.assertEqual(
efho(transit.DirectTCPV1Hint("non-public", 1234, 0.0), tor), None)
self.assertEqual(efho(UnknownHint("foo")), None) self.assertEqual(efho(UnknownHint("foo")), None)
def test_comparable(self): def test_comparable(self):
h1 = transit.DirectTCPV1Hint("hostname", "port1", 0.0) h1 = DirectTCPV1Hint("hostname", "port1", 0.0)
h1b = transit.DirectTCPV1Hint("hostname", "port1", 0.0) h1b = DirectTCPV1Hint("hostname", "port1", 0.0)
h2 = transit.DirectTCPV1Hint("hostname", "port2", 0.0) h2 = DirectTCPV1Hint("hostname", "port2", 0.0)
r1 = transit.RelayV1Hint(tuple(sorted([h1, h2]))) r1 = RelayV1Hint(tuple(sorted([h1, h2])))
r2 = transit.RelayV1Hint(tuple(sorted([h2, h1]))) r2 = RelayV1Hint(tuple(sorted([h2, h1])))
r3 = transit.RelayV1Hint(tuple(sorted([h1b, h2]))) r3 = RelayV1Hint(tuple(sorted([h1b, h2])))
self.assertEqual(r1, r2) self.assertEqual(r1, r2)
self.assertEqual(r2, r3) self.assertEqual(r2, r3)
self.assertEqual(len(set([r1, r2, r3])), 1) self.assertEqual(len(set([r1, r2, r3])), 1)
def test_parse_tcp_v1_hint(self): def test_parse_tcp_v1_hint(self):
c = transit.Common("") p = parse_tcp_v1_hint
p = c._parse_tcp_v1_hint
self.assertEqual(p({"type": "unknown"}), None) self.assertEqual(p({"type": "unknown"}), None)
h = p({"type": "direct-tcp-v1", "hostname": "foo", "port": 1234}) h = p({"type": "direct-tcp-v1", "hostname": "foo", "port": 1234})
self.assertEqual(h, transit.DirectTCPV1Hint("foo", 1234, 0.0)) self.assertEqual(h, DirectTCPV1Hint("foo", 1234, 0.0))
h = p({ h = p({
"type": "direct-tcp-v1", "type": "direct-tcp-v1",
"hostname": "foo", "hostname": "foo",
"port": 1234, "port": 1234,
"priority": 2.5 "priority": 2.5
}) })
self.assertEqual(h, transit.DirectTCPV1Hint("foo", 1234, 2.5)) self.assertEqual(h, DirectTCPV1Hint("foo", 1234, 2.5))
h = p({"type": "tor-tcp-v1", "hostname": "foo", "port": 1234}) h = p({"type": "tor-tcp-v1", "hostname": "foo", "port": 1234})
self.assertEqual(h, transit.TorTCPV1Hint("foo", 1234, 0.0)) self.assertEqual(h, TorTCPV1Hint("foo", 1234, 0.0))
h = p({ h = p({
"type": "tor-tcp-v1", "type": "tor-tcp-v1",
"hostname": "foo", "hostname": "foo",
"port": 1234, "port": 1234,
"priority": 2.5 "priority": 2.5
}) })
self.assertEqual(h, transit.TorTCPV1Hint("foo", 1234, 2.5)) self.assertEqual(h, TorTCPV1Hint("foo", 1234, 2.5))
self.assertEqual(p({ self.assertEqual(p({
"type": "direct-tcp-v1" "type": "direct-tcp-v1"
}), None) # missing hostname }), None) # missing hostname
@ -229,19 +225,19 @@ class Hints(unittest.TestCase):
def test_parse_hint_argv(self): def test_parse_hint_argv(self):
def p(hint): def p(hint):
stderr = io.StringIO() stderr = io.StringIO()
value = transit.parse_hint_argv(hint, stderr=stderr) value = parse_hint_argv(hint, stderr=stderr)
return value, stderr.getvalue() return value, stderr.getvalue()
h, stderr = p("tcp:host:1234") h, stderr = p("tcp:host:1234")
self.assertEqual(h, transit.DirectTCPV1Hint("host", 1234, 0.0)) self.assertEqual(h, DirectTCPV1Hint("host", 1234, 0.0))
self.assertEqual(stderr, "") self.assertEqual(stderr, "")
h, stderr = p("tcp:host:1234:priority=2.6") h, stderr = p("tcp:host:1234:priority=2.6")
self.assertEqual(h, transit.DirectTCPV1Hint("host", 1234, 2.6)) self.assertEqual(h, DirectTCPV1Hint("host", 1234, 2.6))
self.assertEqual(stderr, "") self.assertEqual(stderr, "")
h, stderr = p("tcp:host:1234:unknown=stuff") h, stderr = p("tcp:host:1234:unknown=stuff")
self.assertEqual(h, transit.DirectTCPV1Hint("host", 1234, 0.0)) self.assertEqual(h, DirectTCPV1Hint("host", 1234, 0.0))
self.assertEqual(stderr, "") self.assertEqual(stderr, "")
h, stderr = p("$!@#^") h, stderr = p("$!@#^")
@ -272,12 +268,10 @@ class Hints(unittest.TestCase):
def test_describe_hint_obj(self): def test_describe_hint_obj(self):
d = transit.describe_hint_obj d = transit.describe_hint_obj
self.assertEqual( self.assertEqual(d(DirectTCPV1Hint("host", 1234, 0.0), False, False),
d(transit.DirectTCPV1Hint("host", 1234, 0.0), False, False), "->tcp:host:1234")
"->tcp:host:1234") self.assertEqual(d(TorTCPV1Hint("host", 1234, 0.0), False, False),
self.assertEqual( "->tor:host:1234")
d(transit.TorTCPV1Hint("host", 1234, 0.0), False, False),
"->tor:host:1234")
self.assertEqual(d(UnknownHint("stuff"), False, False), self.assertEqual(d(UnknownHint("stuff"), False, False),
"->%s" % str(UnknownHint("stuff"))) "->%s" % str(UnknownHint("stuff")))
@ -443,7 +437,7 @@ class Listener(unittest.TestCase):
hints, ep = c._build_listener() hints, ep = c._build_listener()
self.assertIsInstance(hints, (list, set)) self.assertIsInstance(hints, (list, set))
if hints: if hints:
self.assertIsInstance(hints[0], transit.DirectTCPV1Hint) self.assertIsInstance(hints[0], DirectTCPV1Hint)
self.assertIsInstance(ep, endpoints.TCP4ServerEndpoint) self.assertIsInstance(ep, endpoints.TCP4ServerEndpoint)
def test_get_direct_hints(self): def test_get_direct_hints(self):
@ -1514,7 +1508,7 @@ class Transit(unittest.TestCase):
self.assertEqual(self._descriptions, ["tor->relay:tcp:relay:1234"]) self.assertEqual(self._descriptions, ["tor->relay:tcp:relay:1234"])
def _endpoint_from_hint_obj(self, hint, _tor, _reactor): def _endpoint_from_hint_obj(self, hint, _tor, _reactor):
if isinstance(hint, transit.DirectTCPV1Hint): if isinstance(hint, DirectTCPV1Hint):
if hint.hostname == "unavailable": if hint.hostname == "unavailable":
return None return None
return hint.hostname return hint.hostname

View File

@ -23,8 +23,9 @@ from . import ipaddrs
from .errors import InternalError from .errors import InternalError
from .timing import DebugTiming from .timing import DebugTiming
from .util import bytes_to_hexstr from .util import bytes_to_hexstr
from ._hints import (DirectTCPV1Hint, TorTCPV1Hint, RelayV1Hint, from ._hints import (DirectTCPV1Hint, RelayV1Hint,
parse_hint_argv, describe_hint_obj, endpoint_from_hint_obj) parse_hint_argv, describe_hint_obj, endpoint_from_hint_obj,
parse_tcp_v1_hint)
def HKDF(skm, outlen, salt=None, CTXinfo=b""): def HKDF(skm, outlen, salt=None, CTXinfo=b""):
@ -681,30 +682,11 @@ class Common:
self._listener_d.addErrback(lambda f: None) self._listener_d.addErrback(lambda f: None)
self._listener_d.cancel() self._listener_d.cancel()
def _parse_tcp_v1_hint(self, hint): # hint_struct -> hint_obj
hint_type = hint.get(u"type", u"")
if hint_type not in [u"direct-tcp-v1", u"tor-tcp-v1"]:
log.msg("unknown hint type: %r" % (hint, ))
return None
if not (u"hostname" in hint and
isinstance(hint[u"hostname"], type(u""))):
log.msg("invalid hostname in hint: %r" % (hint, ))
return None
if not (u"port" in hint and
isinstance(hint[u"port"], six.integer_types)):
log.msg("invalid port in hint: %r" % (hint, ))
return None
priority = hint.get(u"priority", 0.0)
if hint_type == u"direct-tcp-v1":
return DirectTCPV1Hint(hint[u"hostname"], hint[u"port"], priority)
else:
return TorTCPV1Hint(hint[u"hostname"], hint[u"port"], priority)
def add_connection_hints(self, hints): def add_connection_hints(self, hints):
for h in hints: # hint structs for h in hints: # hint structs
hint_type = h.get(u"type", u"") hint_type = h.get(u"type", u"")
if hint_type in [u"direct-tcp-v1", u"tor-tcp-v1"]: if hint_type in [u"direct-tcp-v1", u"tor-tcp-v1"]:
dh = self._parse_tcp_v1_hint(h) dh = parse_tcp_v1_hint(h)
if dh: if dh:
self._their_direct_hints.append(dh) # hint_obj self._their_direct_hints.append(dh) # hint_obj
elif hint_type == u"relay-v1": elif hint_type == u"relay-v1":
@ -714,7 +696,7 @@ class Common:
# together like this. # together like this.
relay_hints = [] relay_hints = []
for rhs in h.get(u"hints", []): for rhs in h.get(u"hints", []):
h = self._parse_tcp_v1_hint(rhs) h = parse_tcp_v1_hint(rhs)
if h: if h:
relay_hints.append(h) relay_hints.append(h)
if relay_hints: if relay_hints: