From 1bb5634d0ef5b3d653194571020b9e312d6149f6 Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Fri, 21 Dec 2018 23:28:45 -0500 Subject: [PATCH] factor Hints tests out of test_transit into a new file --- src/wormhole/test/test_hints.py | 142 ++++++++++++++++++++++++++++++ src/wormhole/test/test_transit.py | 139 +---------------------------- 2 files changed, 144 insertions(+), 137 deletions(-) create mode 100644 src/wormhole/test/test_hints.py diff --git a/src/wormhole/test/test_hints.py b/src/wormhole/test/test_hints.py new file mode 100644 index 0000000..8cc71e3 --- /dev/null +++ b/src/wormhole/test/test_hints.py @@ -0,0 +1,142 @@ +from __future__ import print_function, unicode_literals +import io +from collections import namedtuple +import mock +from twisted.internet import endpoints, reactor +from twisted.trial import unittest +from .._hints import (endpoint_from_hint_obj, parse_hint_argv, parse_tcp_v1_hint, + describe_hint_obj, + DirectTCPV1Hint, TorTCPV1Hint, RelayV1Hint) + +UnknownHint = namedtuple("UnknownHint", ["stuff"]) + + +class Hints(unittest.TestCase): + def test_endpoint_from_hint_obj(self): + def efho(hint, tor=None): + return endpoint_from_hint_obj(hint, tor, reactor) + self.assertIsInstance(efho(DirectTCPV1Hint("host", 1234, 0.0)), + endpoints.HostnameEndpoint) + self.assertEqual(efho("unknown:stuff:yowza:pivlor"), None) + + # tor=None + self.assertEqual(efho(TorTCPV1Hint("host", "port", 0)), None) + + tor = mock.Mock() + def tor_ep(hostname, port): + if hostname == "non-public": + return None + return ("tor_ep", hostname, port) + tor.stream_via = mock.Mock(side_effect=tor_ep) + + self.assertEqual(efho(DirectTCPV1Hint("host", 1234, 0.0), tor), + ("tor_ep", "host", 1234)) + self.assertEqual(efho(TorTCPV1Hint("host2.onion", 1234, 0.0), tor), + ("tor_ep", "host2.onion", 1234)) + self.assertEqual( efho(DirectTCPV1Hint("non-public", 1234, 0.0), tor), None) + + self.assertEqual(efho(UnknownHint("foo")), None) + + def test_comparable(self): + h1 = DirectTCPV1Hint("hostname", "port1", 0.0) + h1b = DirectTCPV1Hint("hostname", "port1", 0.0) + h2 = DirectTCPV1Hint("hostname", "port2", 0.0) + r1 = RelayV1Hint(tuple(sorted([h1, h2]))) + r2 = RelayV1Hint(tuple(sorted([h2, h1]))) + r3 = RelayV1Hint(tuple(sorted([h1b, h2]))) + self.assertEqual(r1, r2) + self.assertEqual(r2, r3) + self.assertEqual(len(set([r1, r2, r3])), 1) + + def test_parse_tcp_v1_hint(self): + p = parse_tcp_v1_hint + self.assertEqual(p({"type": "unknown"}), None) + h = p({"type": "direct-tcp-v1", "hostname": "foo", "port": 1234}) + self.assertEqual(h, DirectTCPV1Hint("foo", 1234, 0.0)) + h = p({ + "type": "direct-tcp-v1", + "hostname": "foo", + "port": 1234, + "priority": 2.5 + }) + self.assertEqual(h, DirectTCPV1Hint("foo", 1234, 2.5)) + h = p({"type": "tor-tcp-v1", "hostname": "foo", "port": 1234}) + self.assertEqual(h, TorTCPV1Hint("foo", 1234, 0.0)) + h = p({ + "type": "tor-tcp-v1", + "hostname": "foo", + "port": 1234, + "priority": 2.5 + }) + self.assertEqual(h, TorTCPV1Hint("foo", 1234, 2.5)) + self.assertEqual(p({ + "type": "direct-tcp-v1" + }), None) # missing hostname + self.assertEqual(p({ + "type": "direct-tcp-v1", + "hostname": 12 + }), None) # invalid hostname + self.assertEqual( + p({ + "type": "direct-tcp-v1", + "hostname": "foo" + }), None) # missing port + self.assertEqual( + p({ + "type": "direct-tcp-v1", + "hostname": "foo", + "port": "not a number" + }), None) # invalid port + + def test_parse_hint_argv(self): + def p(hint): + stderr = io.StringIO() + value = parse_hint_argv(hint, stderr=stderr) + return value, stderr.getvalue() + + h, stderr = p("tcp:host:1234") + self.assertEqual(h, DirectTCPV1Hint("host", 1234, 0.0)) + self.assertEqual(stderr, "") + + h, stderr = p("tcp:host:1234:priority=2.6") + self.assertEqual(h, DirectTCPV1Hint("host", 1234, 2.6)) + self.assertEqual(stderr, "") + + h, stderr = p("tcp:host:1234:unknown=stuff") + self.assertEqual(h, DirectTCPV1Hint("host", 1234, 0.0)) + self.assertEqual(stderr, "") + + h, stderr = p("$!@#^") + self.assertEqual(h, None) + self.assertEqual(stderr, "unparseable hint '$!@#^'\n") + + h, stderr = p("unknown:stuff") + self.assertEqual(h, None) + self.assertEqual(stderr, + "unknown hint type 'unknown' in 'unknown:stuff'\n") + + h, stderr = p("tcp:just-a-hostname") + self.assertEqual(h, None) + self.assertEqual( + stderr, + "unparseable TCP hint (need more colons) 'tcp:just-a-hostname'\n") + + h, stderr = p("tcp:host:number") + self.assertEqual(h, None) + self.assertEqual(stderr, + "non-numeric port in TCP hint 'tcp:host:number'\n") + + h, stderr = p("tcp:host:1234:priority=bad") + self.assertEqual(h, None) + self.assertEqual( + stderr, + "non-float priority= in TCP hint 'tcp:host:1234:priority=bad'\n") + + def test_describe_hint_obj(self): + d = describe_hint_obj + self.assertEqual(d(DirectTCPV1Hint("host", 1234, 0.0), False, False), + "->tcp:host:1234") + self.assertEqual(d(TorTCPV1Hint("host", 1234, 0.0), False, False), + "->tor:host:1234") + self.assertEqual(d(UnknownHint("stuff"), False, False), + "->%s" % str(UnknownHint("stuff"))) diff --git a/src/wormhole/test/test_transit.py b/src/wormhole/test/test_transit.py index 86f008b..b3d7590 100644 --- a/src/wormhole/test/test_transit.py +++ b/src/wormhole/test/test_transit.py @@ -3,12 +3,11 @@ from __future__ import print_function, unicode_literals import gc import io from binascii import hexlify, unhexlify -from collections import namedtuple import six from nacl.exceptions import CryptoError from nacl.secret import SecretBox -from twisted.internet import address, defer, endpoints, error, protocol, task, reactor +from twisted.internet import address, defer, endpoints, error, protocol, task from twisted.internet.defer import gatherResults, inlineCallbacks from twisted.python import log from twisted.test import proto_helpers @@ -18,8 +17,7 @@ import mock from wormhole_transit_relay import transit_server from .. import transit -from .._hints import (endpoint_from_hint_obj, parse_hint_argv, parse_tcp_v1_hint, - DirectTCPV1Hint, TorTCPV1Hint, RelayV1Hint) +from .._hints import DirectTCPV1Hint from ..errors import InternalError from .common import ServerBase @@ -142,139 +140,6 @@ class Misc(unittest.TestCase): self.assertIsInstance(portno, int) -UnknownHint = namedtuple("UnknownHint", ["stuff"]) - - -class Hints(unittest.TestCase): - def test_endpoint_from_hint_obj(self): - def efho(hint, tor=None): - return endpoint_from_hint_obj(hint, tor, reactor) - self.assertIsInstance(efho(DirectTCPV1Hint("host", 1234, 0.0)), - endpoints.HostnameEndpoint) - self.assertEqual(efho("unknown:stuff:yowza:pivlor"), None) - - # tor=None - self.assertEqual(efho(TorTCPV1Hint("host", "port", 0)), None) - - tor = mock.Mock() - def tor_ep(hostname, port): - if hostname == "non-public": - return None - return ("tor_ep", hostname, port) - tor.stream_via = mock.Mock(side_effect=tor_ep) - - self.assertEqual(efho(DirectTCPV1Hint("host", 1234, 0.0), tor), - ("tor_ep", "host", 1234)) - self.assertEqual(efho(TorTCPV1Hint("host2.onion", 1234, 0.0), tor), - ("tor_ep", "host2.onion", 1234)) - self.assertEqual( efho(DirectTCPV1Hint("non-public", 1234, 0.0), tor), None) - - self.assertEqual(efho(UnknownHint("foo")), None) - - def test_comparable(self): - h1 = DirectTCPV1Hint("hostname", "port1", 0.0) - h1b = DirectTCPV1Hint("hostname", "port1", 0.0) - h2 = DirectTCPV1Hint("hostname", "port2", 0.0) - r1 = RelayV1Hint(tuple(sorted([h1, h2]))) - r2 = RelayV1Hint(tuple(sorted([h2, h1]))) - r3 = RelayV1Hint(tuple(sorted([h1b, h2]))) - self.assertEqual(r1, r2) - self.assertEqual(r2, r3) - self.assertEqual(len(set([r1, r2, r3])), 1) - - def test_parse_tcp_v1_hint(self): - p = parse_tcp_v1_hint - self.assertEqual(p({"type": "unknown"}), None) - h = p({"type": "direct-tcp-v1", "hostname": "foo", "port": 1234}) - self.assertEqual(h, DirectTCPV1Hint("foo", 1234, 0.0)) - h = p({ - "type": "direct-tcp-v1", - "hostname": "foo", - "port": 1234, - "priority": 2.5 - }) - self.assertEqual(h, DirectTCPV1Hint("foo", 1234, 2.5)) - h = p({"type": "tor-tcp-v1", "hostname": "foo", "port": 1234}) - self.assertEqual(h, TorTCPV1Hint("foo", 1234, 0.0)) - h = p({ - "type": "tor-tcp-v1", - "hostname": "foo", - "port": 1234, - "priority": 2.5 - }) - self.assertEqual(h, TorTCPV1Hint("foo", 1234, 2.5)) - self.assertEqual(p({ - "type": "direct-tcp-v1" - }), None) # missing hostname - self.assertEqual(p({ - "type": "direct-tcp-v1", - "hostname": 12 - }), None) # invalid hostname - self.assertEqual( - p({ - "type": "direct-tcp-v1", - "hostname": "foo" - }), None) # missing port - self.assertEqual( - p({ - "type": "direct-tcp-v1", - "hostname": "foo", - "port": "not a number" - }), None) # invalid port - - def test_parse_hint_argv(self): - def p(hint): - stderr = io.StringIO() - value = parse_hint_argv(hint, stderr=stderr) - return value, stderr.getvalue() - - h, stderr = p("tcp:host:1234") - self.assertEqual(h, DirectTCPV1Hint("host", 1234, 0.0)) - self.assertEqual(stderr, "") - - h, stderr = p("tcp:host:1234:priority=2.6") - self.assertEqual(h, DirectTCPV1Hint("host", 1234, 2.6)) - self.assertEqual(stderr, "") - - h, stderr = p("tcp:host:1234:unknown=stuff") - self.assertEqual(h, DirectTCPV1Hint("host", 1234, 0.0)) - self.assertEqual(stderr, "") - - h, stderr = p("$!@#^") - self.assertEqual(h, None) - self.assertEqual(stderr, "unparseable hint '$!@#^'\n") - - h, stderr = p("unknown:stuff") - self.assertEqual(h, None) - self.assertEqual(stderr, - "unknown hint type 'unknown' in 'unknown:stuff'\n") - - h, stderr = p("tcp:just-a-hostname") - self.assertEqual(h, None) - self.assertEqual( - stderr, - "unparseable TCP hint (need more colons) 'tcp:just-a-hostname'\n") - - h, stderr = p("tcp:host:number") - self.assertEqual(h, None) - self.assertEqual(stderr, - "non-numeric port in TCP hint 'tcp:host:number'\n") - - h, stderr = p("tcp:host:1234:priority=bad") - self.assertEqual(h, None) - self.assertEqual( - stderr, - "non-float priority= in TCP hint 'tcp:host:1234:priority=bad'\n") - - def test_describe_hint_obj(self): - d = transit.describe_hint_obj - self.assertEqual(d(DirectTCPV1Hint("host", 1234, 0.0), False, False), - "->tcp:host:1234") - self.assertEqual(d(TorTCPV1Hint("host", 1234, 0.0), False, False), - "->tor:host:1234") - self.assertEqual(d(UnknownHint("stuff"), False, False), - "->%s" % str(UnknownHint("stuff"))) - # ipaddrs.py currently uses native strings: bytes on py2, unicode on # py3