From 511a73c491bc3d1fc6f9bcef3e06c680b20c8eb1 Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Fri, 30 Dec 2016 23:34:20 -0500 Subject: [PATCH] improve coverage: Transit._endpoint_from_hint_obj --- src/wormhole/test/test_transit.py | 25 +++++++++++++++++++++---- 1 file changed, 21 insertions(+), 4 deletions(-) diff --git a/src/wormhole/test/test_transit.py b/src/wormhole/test/test_transit.py index 7f63eac..657f9e6 100644 --- a/src/wormhole/test/test_transit.py +++ b/src/wormhole/test/test_transit.py @@ -3,6 +3,7 @@ import io import gc import mock from binascii import hexlify, unhexlify +from collections import namedtuple from twisted.trial import unittest from twisted.internet import defer, task, endpoints, protocol, address, error from twisted.internet.defer import gatherResults, inlineCallbacks @@ -134,10 +135,26 @@ class Misc(unittest.TestCase): class Hints(unittest.TestCase): def test_endpoint_from_hint_obj(self): c = transit.Common("") - ep = c._endpoint_from_hint_obj(transit.DirectTCPV1Hint("localhost", 1234)) - self.assertIsInstance(ep, endpoints.HostnameEndpoint) - ep = c._endpoint_from_hint_obj("unknown:stuff:yowza:pivlor") - self.assertEqual(ep, None) + efho = c._endpoint_from_hint_obj + self.assertIsInstance(efho(transit.DirectTCPV1Hint("host", 1234)), + endpoints.HostnameEndpoint) + self.assertEqual(efho("unknown:stuff:yowza:pivlor"), None) + # c._tor_manager is currently None + self.assertEqual(efho(transit.TorTCPV1Hint("host", "port")), None) + c._tor_manager = mock.Mock() + def tor_ep(hostname, port): + if hostname == "non-public": + return None + return ("tor_ep", hostname, port) + c._tor_manager.get_endpoint_for = mock.Mock(side_effect=tor_ep) + self.assertEqual(efho(transit.DirectTCPV1Hint("host", 1234)), + ("tor_ep", "host", 1234)) + self.assertEqual(efho(transit.TorTCPV1Hint("host2.onion", 1234)), + ("tor_ep", "host2.onion", 1234)) + self.assertEqual(efho(transit.DirectTCPV1Hint("non-public", 1234)), + None) + UnknownHint = namedtuple("UnknownHint", ["stuff"]) + self.assertEqual(efho(UnknownHint("foo")), None) def test_comparable(self): h1 = transit.DirectTCPV1Hint("hostname", "port1")