test_transit: improve coverage
and fix py2/py3-isms in ipaddr tests
This commit is contained in:
		
							parent
							
								
									478405cb6a
								
							
						
					
					
						commit
						afe9f7152d
					
				| 
						 | 
				
			
			@ -1,4 +1,5 @@
 | 
			
		|||
from __future__ import print_function, unicode_literals
 | 
			
		||||
import six
 | 
			
		||||
import io
 | 
			
		||||
import gc
 | 
			
		||||
import mock
 | 
			
		||||
| 
						 | 
				
			
			@ -132,6 +133,13 @@ class Misc(unittest.TestCase):
 | 
			
		|||
        portno = transit.allocate_tcp_port()
 | 
			
		||||
        self.assertIsInstance(portno, int)
 | 
			
		||||
 | 
			
		||||
    def test_allocate_port_no_reuseaddr(self):
 | 
			
		||||
        mock_sys = mock.Mock()
 | 
			
		||||
        mock_sys.platform = "cygwin"
 | 
			
		||||
        with mock.patch("wormhole.transit.sys", mock_sys):
 | 
			
		||||
            portno = transit.allocate_tcp_port()
 | 
			
		||||
        self.assertIsInstance(portno, int)
 | 
			
		||||
 | 
			
		||||
UnknownHint = namedtuple("UnknownHint", ["stuff"])
 | 
			
		||||
 | 
			
		||||
class Hints(unittest.TestCase):
 | 
			
		||||
| 
						 | 
				
			
			@ -205,6 +213,10 @@ class Hints(unittest.TestCase):
 | 
			
		|||
        self.assertEqual(h, transit.DirectTCPV1Hint("host", 1234, 2.6))
 | 
			
		||||
        self.assertEqual(stderr, "")
 | 
			
		||||
 | 
			
		||||
        h,stderr = p("tcp:host:1234:unknown=stuff")
 | 
			
		||||
        self.assertEqual(h, transit.DirectTCPV1Hint("host", 1234, 0.0))
 | 
			
		||||
        self.assertEqual(stderr, "")
 | 
			
		||||
 | 
			
		||||
        h,stderr = p("$!@#^")
 | 
			
		||||
        self.assertEqual(h, None)
 | 
			
		||||
        self.assertEqual(stderr, "unparseable hint '$!@#^'\n")
 | 
			
		||||
| 
						 | 
				
			
			@ -237,6 +249,15 @@ class Hints(unittest.TestCase):
 | 
			
		|||
                         "tor:host:1234")
 | 
			
		||||
        self.assertEqual(d(UnknownHint("stuff")), str(UnknownHint("stuff")))
 | 
			
		||||
 | 
			
		||||
# ipaddrs.py currently uses native strings: bytes on py2, unicode on
 | 
			
		||||
# py3
 | 
			
		||||
if six.PY2:
 | 
			
		||||
    LOOPADDR = b"127.0.0.1"
 | 
			
		||||
    OTHERADDR = b"1.2.3.4"
 | 
			
		||||
else:
 | 
			
		||||
    LOOPADDR = "127.0.0.1" # unicode_literals
 | 
			
		||||
    OTHERADDR = "1.2.3.4"
 | 
			
		||||
 | 
			
		||||
class Basic(unittest.TestCase):
 | 
			
		||||
    @inlineCallbacks
 | 
			
		||||
    def test_relay_hints(self):
 | 
			
		||||
| 
						 | 
				
			
			@ -265,7 +286,7 @@ class Basic(unittest.TestCase):
 | 
			
		|||
        self.assertEqual(c._their_direct_hints, [])
 | 
			
		||||
        self.assertEqual(c._our_relay_hints, set())
 | 
			
		||||
 | 
			
		||||
    def test_ignore_localhost_hint(self):
 | 
			
		||||
    def test_ignore_localhost_hint_orig(self):
 | 
			
		||||
        # this actually starts the listener
 | 
			
		||||
        c = transit.TransitSender("")
 | 
			
		||||
        results = []
 | 
			
		||||
| 
						 | 
				
			
			@ -281,6 +302,36 @@ class Basic(unittest.TestCase):
 | 
			
		|||
        for hint in hints:
 | 
			
		||||
            self.assertFalse(hint["hostname"] == "127.0.0.1")
 | 
			
		||||
 | 
			
		||||
    def test_ignore_localhost_hint(self):
 | 
			
		||||
        # this actually starts the listener
 | 
			
		||||
        c = transit.TransitSender("")
 | 
			
		||||
        with mock.patch("wormhole.ipaddrs.find_addresses",
 | 
			
		||||
                        return_value=[LOOPADDR, OTHERADDR]):
 | 
			
		||||
            hints = self.successResultOf(c.get_connection_hints())
 | 
			
		||||
        c._stop_listening()
 | 
			
		||||
        # If there are non-localhost hints, then localhost hints should be
 | 
			
		||||
        # removed.
 | 
			
		||||
        self.assertEqual(len(hints), 1)
 | 
			
		||||
        self.assertEqual(hints[0]["hostname"], "1.2.3.4")
 | 
			
		||||
 | 
			
		||||
    def test_keep_only_localhost_hint(self):
 | 
			
		||||
        # this actually starts the listener
 | 
			
		||||
        c = transit.TransitSender("")
 | 
			
		||||
        with mock.patch("wormhole.ipaddrs.find_addresses",
 | 
			
		||||
                        return_value=[LOOPADDR]):
 | 
			
		||||
            hints = self.successResultOf(c.get_connection_hints())
 | 
			
		||||
        c._stop_listening()
 | 
			
		||||
        # If the only hint is localhost, it should stay.
 | 
			
		||||
        self.assertEqual(len(hints), 1)
 | 
			
		||||
        self.assertEqual(hints[0]["hostname"], "127.0.0.1")
 | 
			
		||||
 | 
			
		||||
    def test_abilities(self):
 | 
			
		||||
        c = transit.Common(None, no_listen=True)
 | 
			
		||||
        abilities = c.get_connection_abilities()
 | 
			
		||||
        self.assertEqual(abilities, [{"type": "direct-tcp-v1"},
 | 
			
		||||
                                     {"type": "relay-v1"},
 | 
			
		||||
                                     ])
 | 
			
		||||
 | 
			
		||||
    def test_transit_key_wait(self):
 | 
			
		||||
        KEY = b"123"
 | 
			
		||||
        c = transit.Common("")
 | 
			
		||||
| 
						 | 
				
			
			@ -642,6 +693,10 @@ class Connection(unittest.TestCase):
 | 
			
		|||
        self.assertTrue(c._check_and_remove(EXP))
 | 
			
		||||
        self.assertEqual(c.buf, b" exceeded")
 | 
			
		||||
 | 
			
		||||
    def test_describe(self):
 | 
			
		||||
        c = transit.Connection(None, None, None, "description")
 | 
			
		||||
        self.assertEqual(c.describe(), "description")
 | 
			
		||||
 | 
			
		||||
    def test_sender_accepting(self):
 | 
			
		||||
        relay_handshake = None
 | 
			
		||||
        owner = MockOwner()
 | 
			
		||||
| 
						 | 
				
			
			@ -730,6 +785,33 @@ class Connection(unittest.TestCase):
 | 
			
		|||
        self.assertIsInstance(f, failure.Failure)
 | 
			
		||||
        self.assertIsInstance(f.value, RandomError)
 | 
			
		||||
 | 
			
		||||
    def test_handshake_bad_state(self):
 | 
			
		||||
        owner = MockOwner()
 | 
			
		||||
        factory = MockFactory()
 | 
			
		||||
        addr = address.HostnameAddress("example.com", 1234)
 | 
			
		||||
        c = transit.Connection(owner, None, None, "description")
 | 
			
		||||
        self.assertEqual(c.state, "too-early")
 | 
			
		||||
        t = c.transport = FakeTransport(c, addr)
 | 
			
		||||
        c.factory = factory
 | 
			
		||||
        c.connectionMade()
 | 
			
		||||
        self.assertEqual(factory._connectionWasMade_called, True)
 | 
			
		||||
        self.assertEqual(factory._p, c)
 | 
			
		||||
 | 
			
		||||
        d = c.startNegotiation()
 | 
			
		||||
        self.assertEqual(c.state, "handshake")
 | 
			
		||||
        self.assertEqual(t.read_buf(), b"send_this")
 | 
			
		||||
        results = []
 | 
			
		||||
        d.addBoth(results.append)
 | 
			
		||||
        self.assertEqual(results, [])
 | 
			
		||||
        c.state = "unknown-bogus-state"
 | 
			
		||||
        self.assertRaises(ValueError, c.dataReceived, b"surprise!")
 | 
			
		||||
        self.assertEqual(t._connected, False)
 | 
			
		||||
        self.assertEqual(c.state, "hung up")
 | 
			
		||||
        self.assertEqual(len(results), 1)
 | 
			
		||||
        f = results[0]
 | 
			
		||||
        self.assertIsInstance(f, failure.Failure)
 | 
			
		||||
        self.assertIsInstance(f.value, ValueError)
 | 
			
		||||
 | 
			
		||||
    def test_relay_handshake(self):
 | 
			
		||||
        relay_handshake = b"relay handshake"
 | 
			
		||||
        owner = MockOwner()
 | 
			
		||||
| 
						 | 
				
			
			@ -962,6 +1044,13 @@ class Connection(unittest.TestCase):
 | 
			
		|||
 | 
			
		||||
        return t, c, owner
 | 
			
		||||
 | 
			
		||||
    def test_records_not_binary(self):
 | 
			
		||||
        t, c, owner = self.make_connection()
 | 
			
		||||
 | 
			
		||||
        RECORD1 = u"not binary"
 | 
			
		||||
        with self.assertRaises(InternalError):
 | 
			
		||||
            c.send_record(RECORD1)
 | 
			
		||||
 | 
			
		||||
    def test_records_good(self):
 | 
			
		||||
        # now make sure that outbound records are encrypted properly
 | 
			
		||||
        t, c, owner = self.make_connection()
 | 
			
		||||
| 
						 | 
				
			
			@ -1203,6 +1292,20 @@ class Connection(unittest.TestCase):
 | 
			
		|||
        self.assertIsInstance(f, failure.Failure)
 | 
			
		||||
        self.assertIsInstance(f.value, error.ConnectionClosed)
 | 
			
		||||
 | 
			
		||||
    def test_connectConsumer_empty(self):
 | 
			
		||||
        # if connectConsumer() expects 0 bytes (e.g. someone is "sending" a
 | 
			
		||||
        # zero-length file), make sure it gets woken up right away, so it can
 | 
			
		||||
        # disconnect itself, even though no bytes will actually arrive
 | 
			
		||||
        c = transit.Connection(None, None, None, "description")
 | 
			
		||||
        c._negotiation_d.addErrback(lambda err: None) # eat it
 | 
			
		||||
        c.transport = proto_helpers.StringTransport()
 | 
			
		||||
 | 
			
		||||
        consumer = proto_helpers.StringTransport()
 | 
			
		||||
        d = c.connectConsumer(consumer, expected=0)
 | 
			
		||||
        self.assertEqual(self.successResultOf(d), 0)
 | 
			
		||||
        self.assertEqual(consumer.value(), b"")
 | 
			
		||||
        self.assertIs(c._consumer, None)
 | 
			
		||||
 | 
			
		||||
    def test_writeToFile(self):
 | 
			
		||||
        c = transit.Connection(None, None, None, "description")
 | 
			
		||||
        c._negotiation_d.addErrback(lambda err: None) # eat it
 | 
			
		||||
| 
						 | 
				
			
			@ -1293,12 +1396,31 @@ class FileConsumer(unittest.TestCase):
 | 
			
		|||
        self.assertEqual(progress, [99, 1])
 | 
			
		||||
        self.assertEqual(f.getvalue(), b"."*99+b"!")
 | 
			
		||||
 | 
			
		||||
    def test_hasher(self):
 | 
			
		||||
        hashee = []
 | 
			
		||||
        f = io.BytesIO()
 | 
			
		||||
        progress = []
 | 
			
		||||
        fc = transit.FileConsumer(f, progress.append, hasher=hashee.append)
 | 
			
		||||
        self.assertEqual(progress, [])
 | 
			
		||||
        self.assertEqual(f.getvalue(), b"")
 | 
			
		||||
        self.assertEqual(hashee, [])
 | 
			
		||||
        fc.write(b"."* 99)
 | 
			
		||||
        self.assertEqual(progress, [99])
 | 
			
		||||
        self.assertEqual(f.getvalue(), b"."*99)
 | 
			
		||||
        self.assertEqual(hashee, [b"."*99])
 | 
			
		||||
        fc.write(b"!")
 | 
			
		||||
        self.assertEqual(progress, [99, 1])
 | 
			
		||||
        self.assertEqual(f.getvalue(), b"."*99+b"!")
 | 
			
		||||
        self.assertEqual(hashee, [b"."*99, b"!"])
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
DIRECT_HINT_JSON = {"type": "direct-tcp-v1",
 | 
			
		||||
                    "hostname": "direct", "port": 1234}
 | 
			
		||||
RELAY_HINT_JSON = {"type": "relay-v1",
 | 
			
		||||
                   "hints": [{"type": "direct-tcp-v1",
 | 
			
		||||
                              "hostname": "relay", "port": 1234}]}
 | 
			
		||||
UNRECOGNIZED_DIRECT_HINT_JSON = {"type": "direct-tcp-v1",
 | 
			
		||||
                                 "hostname": ["cannot", "parse", "list"]}
 | 
			
		||||
UNRECOGNIZED_HINT_JSON = {"type": "unknown"}
 | 
			
		||||
UNAVAILABLE_HINT_JSON = {"type": "direct-tcp-v1", # e.g. Tor without txtorcon
 | 
			
		||||
                         "hostname": "unavailable", "port": 1234}
 | 
			
		||||
| 
						 | 
				
			
			@ -1327,7 +1449,9 @@ class Transit(unittest.TestCase):
 | 
			
		|||
        s.set_transit_key(b"key")
 | 
			
		||||
        hints = yield s.get_connection_hints() # start the listener
 | 
			
		||||
        del hints
 | 
			
		||||
        s.add_connection_hints([DIRECT_HINT_JSON, UNRECOGNIZED_HINT_JSON])
 | 
			
		||||
        s.add_connection_hints([DIRECT_HINT_JSON,
 | 
			
		||||
                                UNRECOGNIZED_DIRECT_HINT_JSON,
 | 
			
		||||
                                UNRECOGNIZED_HINT_JSON])
 | 
			
		||||
 | 
			
		||||
        s._start_connector = self._start_connector
 | 
			
		||||
        d = s.connect()
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -245,6 +245,7 @@ class Connection(protocol.Protocol, policies.TimeoutMixin):
 | 
			
		|||
            return self.dataReceivedRECORDS()
 | 
			
		||||
        if isinstance(self.state, Exception): # for tests
 | 
			
		||||
            raise self.state
 | 
			
		||||
        raise ValueError("internal error: unknown state %s" % (self.state,))
 | 
			
		||||
 | 
			
		||||
    def _negotiationSuccessful(self):
 | 
			
		||||
        self.state = "records"
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue
	
	Block a user