From 7720312c8f22e914bf9ffce1f1240f05ff86d86d Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Fri, 21 Dec 2018 23:22:02 -0500 Subject: [PATCH] factor out parse_tcp_v1_hint --- src/wormhole/_dilation/connector.py | 23 +--------- src/wormhole/_hints.py | 21 +++++++++ src/wormhole/test/test_transit.py | 68 +++++++++++++---------------- src/wormhole/transit.py | 28 +++--------- 4 files changed, 59 insertions(+), 81 deletions(-) diff --git a/src/wormhole/_dilation/connector.py b/src/wormhole/_dilation/connector.py index f472ea9..3d4d773 100644 --- a/src/wormhole/_dilation/connector.py +++ b/src/wormhole/_dilation/connector.py @@ -20,27 +20,8 @@ from .connection import DilatedConnectionProtocol, KCM from .roles import LEADER from .._hints import (DirectTCPV1Hint, TorTCPV1Hint, RelayV1Hint, - parse_hint_argv, describe_hint_obj, endpoint_from_hint_obj) - - -def parse_tcp_v1_hint(hint): # hint_struct -> hint_obj - hint_type = hint.get("type", "") - if hint_type not in ["direct-tcp-v1", "tor-tcp-v1"]: - log.msg("unknown hint type: %r" % (hint,)) - return None - if not("hostname" in hint and - isinstance(hint["hostname"], type(""))): - log.msg("invalid hostname in hint: %r" % (hint,)) - return None - if not("port" in hint and - isinstance(hint["port"], six.integer_types)): - log.msg("invalid port in hint: %r" % (hint,)) - return None - priority = hint.get("priority", 0.0) - if hint_type == "direct-tcp-v1": - return DirectTCPV1Hint(hint["hostname"], hint["port"], priority) - else: - return TorTCPV1Hint(hint["hostname"], hint["port"], priority) + parse_hint_argv, describe_hint_obj, endpoint_from_hint_obj, + parse_tcp_v1_hint) def parse_hint(hint_struct): diff --git a/src/wormhole/_hints.py b/src/wormhole/_hints.py index 9b79542..0a44b3b 100644 --- a/src/wormhole/_hints.py +++ b/src/wormhole/_hints.py @@ -1,8 +1,10 @@ from __future__ import print_function, unicode_literals import sys import re +import six from collections import namedtuple from twisted.internet.endpoints import HostnameEndpoint +from twisted.python import log # These namedtuples are "hint objects". The JSON-serializable dictionaries # are "hint dicts". @@ -82,3 +84,22 @@ def endpoint_from_hint_obj(hint, tor, reactor): if isinstance(hint, DirectTCPV1Hint): return HostnameEndpoint(reactor, hint.hostname, hint.port) return None + +def parse_tcp_v1_hint(hint): # hint_struct -> hint_obj + hint_type = hint.get("type", "") + if hint_type not in ["direct-tcp-v1", "tor-tcp-v1"]: + log.msg("unknown hint type: %r" % (hint, )) + return None + if not ("hostname" in hint and + isinstance(hint["hostname"], type(""))): + log.msg("invalid hostname in hint: %r" % (hint, )) + return None + if not ("port" in hint and + isinstance(hint["port"], six.integer_types)): + log.msg("invalid port in hint: %r" % (hint, )) + return None + priority = hint.get("priority", 0.0) + if hint_type == "direct-tcp-v1": + return DirectTCPV1Hint(hint["hostname"], hint["port"], priority) + else: + return TorTCPV1Hint(hint["hostname"], hint["port"], priority) diff --git a/src/wormhole/test/test_transit.py b/src/wormhole/test/test_transit.py index 9a53983..86f008b 100644 --- a/src/wormhole/test/test_transit.py +++ b/src/wormhole/test/test_transit.py @@ -18,7 +18,8 @@ import mock from wormhole_transit_relay import transit_server from .. import transit -from .._hints import endpoint_from_hint_obj +from .._hints import (endpoint_from_hint_obj, parse_hint_argv, parse_tcp_v1_hint, + DirectTCPV1Hint, TorTCPV1Hint, RelayV1Hint) from ..errors import InternalError from .common import ServerBase @@ -148,13 +149,12 @@ class Hints(unittest.TestCase): def test_endpoint_from_hint_obj(self): def efho(hint, tor=None): return endpoint_from_hint_obj(hint, tor, reactor) - self.assertIsInstance( - efho(transit.DirectTCPV1Hint("host", 1234, 0.0)), - endpoints.HostnameEndpoint) + self.assertIsInstance(efho(DirectTCPV1Hint("host", 1234, 0.0)), + endpoints.HostnameEndpoint) self.assertEqual(efho("unknown:stuff:yowza:pivlor"), None) # tor=None - self.assertEqual(efho(transit.TorTCPV1Hint("host", "port", 0)), None) + self.assertEqual(efho(TorTCPV1Hint("host", "port", 0)), None) tor = mock.Mock() def tor_ep(hostname, port): @@ -163,50 +163,46 @@ class Hints(unittest.TestCase): return ("tor_ep", hostname, port) tor.stream_via = mock.Mock(side_effect=tor_ep) - self.assertEqual( - efho(transit.DirectTCPV1Hint("host", 1234, 0.0), tor), - ("tor_ep", "host", 1234)) - self.assertEqual( - efho(transit.TorTCPV1Hint("host2.onion", 1234, 0.0), tor), - ("tor_ep", "host2.onion", 1234)) - self.assertEqual( - efho(transit.DirectTCPV1Hint("non-public", 1234, 0.0), tor), None) + self.assertEqual(efho(DirectTCPV1Hint("host", 1234, 0.0), tor), + ("tor_ep", "host", 1234)) + self.assertEqual(efho(TorTCPV1Hint("host2.onion", 1234, 0.0), tor), + ("tor_ep", "host2.onion", 1234)) + self.assertEqual( efho(DirectTCPV1Hint("non-public", 1234, 0.0), tor), None) self.assertEqual(efho(UnknownHint("foo")), None) def test_comparable(self): - h1 = transit.DirectTCPV1Hint("hostname", "port1", 0.0) - h1b = transit.DirectTCPV1Hint("hostname", "port1", 0.0) - h2 = transit.DirectTCPV1Hint("hostname", "port2", 0.0) - r1 = transit.RelayV1Hint(tuple(sorted([h1, h2]))) - r2 = transit.RelayV1Hint(tuple(sorted([h2, h1]))) - r3 = transit.RelayV1Hint(tuple(sorted([h1b, h2]))) + h1 = DirectTCPV1Hint("hostname", "port1", 0.0) + h1b = DirectTCPV1Hint("hostname", "port1", 0.0) + h2 = DirectTCPV1Hint("hostname", "port2", 0.0) + r1 = RelayV1Hint(tuple(sorted([h1, h2]))) + r2 = RelayV1Hint(tuple(sorted([h2, h1]))) + r3 = RelayV1Hint(tuple(sorted([h1b, h2]))) self.assertEqual(r1, r2) self.assertEqual(r2, r3) self.assertEqual(len(set([r1, r2, r3])), 1) def test_parse_tcp_v1_hint(self): - c = transit.Common("") - p = c._parse_tcp_v1_hint + p = parse_tcp_v1_hint self.assertEqual(p({"type": "unknown"}), None) h = p({"type": "direct-tcp-v1", "hostname": "foo", "port": 1234}) - self.assertEqual(h, transit.DirectTCPV1Hint("foo", 1234, 0.0)) + self.assertEqual(h, DirectTCPV1Hint("foo", 1234, 0.0)) h = p({ "type": "direct-tcp-v1", "hostname": "foo", "port": 1234, "priority": 2.5 }) - self.assertEqual(h, transit.DirectTCPV1Hint("foo", 1234, 2.5)) + self.assertEqual(h, DirectTCPV1Hint("foo", 1234, 2.5)) h = p({"type": "tor-tcp-v1", "hostname": "foo", "port": 1234}) - self.assertEqual(h, transit.TorTCPV1Hint("foo", 1234, 0.0)) + self.assertEqual(h, TorTCPV1Hint("foo", 1234, 0.0)) h = p({ "type": "tor-tcp-v1", "hostname": "foo", "port": 1234, "priority": 2.5 }) - self.assertEqual(h, transit.TorTCPV1Hint("foo", 1234, 2.5)) + self.assertEqual(h, TorTCPV1Hint("foo", 1234, 2.5)) self.assertEqual(p({ "type": "direct-tcp-v1" }), None) # missing hostname @@ -229,19 +225,19 @@ class Hints(unittest.TestCase): def test_parse_hint_argv(self): def p(hint): stderr = io.StringIO() - value = transit.parse_hint_argv(hint, stderr=stderr) + value = parse_hint_argv(hint, stderr=stderr) return value, stderr.getvalue() h, stderr = p("tcp:host:1234") - self.assertEqual(h, transit.DirectTCPV1Hint("host", 1234, 0.0)) + self.assertEqual(h, DirectTCPV1Hint("host", 1234, 0.0)) self.assertEqual(stderr, "") h, stderr = p("tcp:host:1234:priority=2.6") - self.assertEqual(h, transit.DirectTCPV1Hint("host", 1234, 2.6)) + self.assertEqual(h, DirectTCPV1Hint("host", 1234, 2.6)) self.assertEqual(stderr, "") h, stderr = p("tcp:host:1234:unknown=stuff") - self.assertEqual(h, transit.DirectTCPV1Hint("host", 1234, 0.0)) + self.assertEqual(h, DirectTCPV1Hint("host", 1234, 0.0)) self.assertEqual(stderr, "") h, stderr = p("$!@#^") @@ -272,12 +268,10 @@ class Hints(unittest.TestCase): def test_describe_hint_obj(self): d = transit.describe_hint_obj - self.assertEqual( - d(transit.DirectTCPV1Hint("host", 1234, 0.0), False, False), - "->tcp:host:1234") - self.assertEqual( - d(transit.TorTCPV1Hint("host", 1234, 0.0), False, False), - "->tor:host:1234") + self.assertEqual(d(DirectTCPV1Hint("host", 1234, 0.0), False, False), + "->tcp:host:1234") + self.assertEqual(d(TorTCPV1Hint("host", 1234, 0.0), False, False), + "->tor:host:1234") self.assertEqual(d(UnknownHint("stuff"), False, False), "->%s" % str(UnknownHint("stuff"))) @@ -443,7 +437,7 @@ class Listener(unittest.TestCase): hints, ep = c._build_listener() self.assertIsInstance(hints, (list, set)) if hints: - self.assertIsInstance(hints[0], transit.DirectTCPV1Hint) + self.assertIsInstance(hints[0], DirectTCPV1Hint) self.assertIsInstance(ep, endpoints.TCP4ServerEndpoint) def test_get_direct_hints(self): @@ -1514,7 +1508,7 @@ class Transit(unittest.TestCase): self.assertEqual(self._descriptions, ["tor->relay:tcp:relay:1234"]) def _endpoint_from_hint_obj(self, hint, _tor, _reactor): - if isinstance(hint, transit.DirectTCPV1Hint): + if isinstance(hint, DirectTCPV1Hint): if hint.hostname == "unavailable": return None return hint.hostname diff --git a/src/wormhole/transit.py b/src/wormhole/transit.py index b122055..63aafdb 100644 --- a/src/wormhole/transit.py +++ b/src/wormhole/transit.py @@ -23,8 +23,9 @@ from . import ipaddrs from .errors import InternalError from .timing import DebugTiming from .util import bytes_to_hexstr -from ._hints import (DirectTCPV1Hint, TorTCPV1Hint, RelayV1Hint, - parse_hint_argv, describe_hint_obj, endpoint_from_hint_obj) +from ._hints import (DirectTCPV1Hint, RelayV1Hint, + parse_hint_argv, describe_hint_obj, endpoint_from_hint_obj, + parse_tcp_v1_hint) def HKDF(skm, outlen, salt=None, CTXinfo=b""): @@ -681,30 +682,11 @@ class Common: self._listener_d.addErrback(lambda f: None) self._listener_d.cancel() - def _parse_tcp_v1_hint(self, hint): # hint_struct -> hint_obj - hint_type = hint.get(u"type", u"") - if hint_type not in [u"direct-tcp-v1", u"tor-tcp-v1"]: - log.msg("unknown hint type: %r" % (hint, )) - return None - if not (u"hostname" in hint and - isinstance(hint[u"hostname"], type(u""))): - log.msg("invalid hostname in hint: %r" % (hint, )) - return None - if not (u"port" in hint and - isinstance(hint[u"port"], six.integer_types)): - log.msg("invalid port in hint: %r" % (hint, )) - return None - priority = hint.get(u"priority", 0.0) - if hint_type == u"direct-tcp-v1": - return DirectTCPV1Hint(hint[u"hostname"], hint[u"port"], priority) - else: - return TorTCPV1Hint(hint[u"hostname"], hint[u"port"], priority) - def add_connection_hints(self, hints): for h in hints: # hint structs hint_type = h.get(u"type", u"") if hint_type in [u"direct-tcp-v1", u"tor-tcp-v1"]: - dh = self._parse_tcp_v1_hint(h) + dh = parse_tcp_v1_hint(h) if dh: self._their_direct_hints.append(dh) # hint_obj elif hint_type == u"relay-v1": @@ -714,7 +696,7 @@ class Common: # together like this. relay_hints = [] for rhs in h.get(u"hints", []): - h = self._parse_tcp_v1_hint(rhs) + h = parse_tcp_v1_hint(rhs) if h: relay_hints.append(h) if relay_hints: