From afe9f7152d972830901e16cdad37e58c3e8441c6 Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Sun, 16 Apr 2017 18:13:51 -0400 Subject: [PATCH] test_transit: improve coverage and fix py2/py3-isms in ipaddr tests --- src/wormhole/test/test_transit.py | 128 +++++++++++++++++++++++++++++- src/wormhole/transit.py | 1 + 2 files changed, 127 insertions(+), 2 deletions(-) diff --git a/src/wormhole/test/test_transit.py b/src/wormhole/test/test_transit.py index 62848b4..c170b32 100644 --- a/src/wormhole/test/test_transit.py +++ b/src/wormhole/test/test_transit.py @@ -1,4 +1,5 @@ from __future__ import print_function, unicode_literals +import six import io import gc import mock @@ -132,6 +133,13 @@ class Misc(unittest.TestCase): portno = transit.allocate_tcp_port() self.assertIsInstance(portno, int) + def test_allocate_port_no_reuseaddr(self): + mock_sys = mock.Mock() + mock_sys.platform = "cygwin" + with mock.patch("wormhole.transit.sys", mock_sys): + portno = transit.allocate_tcp_port() + self.assertIsInstance(portno, int) + UnknownHint = namedtuple("UnknownHint", ["stuff"]) class Hints(unittest.TestCase): @@ -205,6 +213,10 @@ class Hints(unittest.TestCase): self.assertEqual(h, transit.DirectTCPV1Hint("host", 1234, 2.6)) self.assertEqual(stderr, "") + h,stderr = p("tcp:host:1234:unknown=stuff") + self.assertEqual(h, transit.DirectTCPV1Hint("host", 1234, 0.0)) + self.assertEqual(stderr, "") + h,stderr = p("$!@#^") self.assertEqual(h, None) self.assertEqual(stderr, "unparseable hint '$!@#^'\n") @@ -237,6 +249,15 @@ class Hints(unittest.TestCase): "tor:host:1234") self.assertEqual(d(UnknownHint("stuff")), str(UnknownHint("stuff"))) +# ipaddrs.py currently uses native strings: bytes on py2, unicode on +# py3 +if six.PY2: + LOOPADDR = b"127.0.0.1" + OTHERADDR = b"1.2.3.4" +else: + LOOPADDR = "127.0.0.1" # unicode_literals + OTHERADDR = "1.2.3.4" + class Basic(unittest.TestCase): @inlineCallbacks def test_relay_hints(self): @@ -265,7 +286,7 @@ class Basic(unittest.TestCase): self.assertEqual(c._their_direct_hints, []) self.assertEqual(c._our_relay_hints, set()) - def test_ignore_localhost_hint(self): + def test_ignore_localhost_hint_orig(self): # this actually starts the listener c = transit.TransitSender("") results = [] @@ -281,6 +302,36 @@ class Basic(unittest.TestCase): for hint in hints: self.assertFalse(hint["hostname"] == "127.0.0.1") + def test_ignore_localhost_hint(self): + # this actually starts the listener + c = transit.TransitSender("") + with mock.patch("wormhole.ipaddrs.find_addresses", + return_value=[LOOPADDR, OTHERADDR]): + hints = self.successResultOf(c.get_connection_hints()) + c._stop_listening() + # If there are non-localhost hints, then localhost hints should be + # removed. + self.assertEqual(len(hints), 1) + self.assertEqual(hints[0]["hostname"], "1.2.3.4") + + def test_keep_only_localhost_hint(self): + # this actually starts the listener + c = transit.TransitSender("") + with mock.patch("wormhole.ipaddrs.find_addresses", + return_value=[LOOPADDR]): + hints = self.successResultOf(c.get_connection_hints()) + c._stop_listening() + # If the only hint is localhost, it should stay. + self.assertEqual(len(hints), 1) + self.assertEqual(hints[0]["hostname"], "127.0.0.1") + + def test_abilities(self): + c = transit.Common(None, no_listen=True) + abilities = c.get_connection_abilities() + self.assertEqual(abilities, [{"type": "direct-tcp-v1"}, + {"type": "relay-v1"}, + ]) + def test_transit_key_wait(self): KEY = b"123" c = transit.Common("") @@ -642,6 +693,10 @@ class Connection(unittest.TestCase): self.assertTrue(c._check_and_remove(EXP)) self.assertEqual(c.buf, b" exceeded") + def test_describe(self): + c = transit.Connection(None, None, None, "description") + self.assertEqual(c.describe(), "description") + def test_sender_accepting(self): relay_handshake = None owner = MockOwner() @@ -730,6 +785,33 @@ class Connection(unittest.TestCase): self.assertIsInstance(f, failure.Failure) self.assertIsInstance(f.value, RandomError) + def test_handshake_bad_state(self): + owner = MockOwner() + factory = MockFactory() + addr = address.HostnameAddress("example.com", 1234) + c = transit.Connection(owner, None, None, "description") + self.assertEqual(c.state, "too-early") + t = c.transport = FakeTransport(c, addr) + c.factory = factory + c.connectionMade() + self.assertEqual(factory._connectionWasMade_called, True) + self.assertEqual(factory._p, c) + + d = c.startNegotiation() + self.assertEqual(c.state, "handshake") + self.assertEqual(t.read_buf(), b"send_this") + results = [] + d.addBoth(results.append) + self.assertEqual(results, []) + c.state = "unknown-bogus-state" + self.assertRaises(ValueError, c.dataReceived, b"surprise!") + self.assertEqual(t._connected, False) + self.assertEqual(c.state, "hung up") + self.assertEqual(len(results), 1) + f = results[0] + self.assertIsInstance(f, failure.Failure) + self.assertIsInstance(f.value, ValueError) + def test_relay_handshake(self): relay_handshake = b"relay handshake" owner = MockOwner() @@ -962,6 +1044,13 @@ class Connection(unittest.TestCase): return t, c, owner + def test_records_not_binary(self): + t, c, owner = self.make_connection() + + RECORD1 = u"not binary" + with self.assertRaises(InternalError): + c.send_record(RECORD1) + def test_records_good(self): # now make sure that outbound records are encrypted properly t, c, owner = self.make_connection() @@ -1203,6 +1292,20 @@ class Connection(unittest.TestCase): self.assertIsInstance(f, failure.Failure) self.assertIsInstance(f.value, error.ConnectionClosed) + def test_connectConsumer_empty(self): + # if connectConsumer() expects 0 bytes (e.g. someone is "sending" a + # zero-length file), make sure it gets woken up right away, so it can + # disconnect itself, even though no bytes will actually arrive + c = transit.Connection(None, None, None, "description") + c._negotiation_d.addErrback(lambda err: None) # eat it + c.transport = proto_helpers.StringTransport() + + consumer = proto_helpers.StringTransport() + d = c.connectConsumer(consumer, expected=0) + self.assertEqual(self.successResultOf(d), 0) + self.assertEqual(consumer.value(), b"") + self.assertIs(c._consumer, None) + def test_writeToFile(self): c = transit.Connection(None, None, None, "description") c._negotiation_d.addErrback(lambda err: None) # eat it @@ -1293,12 +1396,31 @@ class FileConsumer(unittest.TestCase): self.assertEqual(progress, [99, 1]) self.assertEqual(f.getvalue(), b"."*99+b"!") + def test_hasher(self): + hashee = [] + f = io.BytesIO() + progress = [] + fc = transit.FileConsumer(f, progress.append, hasher=hashee.append) + self.assertEqual(progress, []) + self.assertEqual(f.getvalue(), b"") + self.assertEqual(hashee, []) + fc.write(b"."* 99) + self.assertEqual(progress, [99]) + self.assertEqual(f.getvalue(), b"."*99) + self.assertEqual(hashee, [b"."*99]) + fc.write(b"!") + self.assertEqual(progress, [99, 1]) + self.assertEqual(f.getvalue(), b"."*99+b"!") + self.assertEqual(hashee, [b"."*99, b"!"]) + DIRECT_HINT_JSON = {"type": "direct-tcp-v1", "hostname": "direct", "port": 1234} RELAY_HINT_JSON = {"type": "relay-v1", "hints": [{"type": "direct-tcp-v1", "hostname": "relay", "port": 1234}]} +UNRECOGNIZED_DIRECT_HINT_JSON = {"type": "direct-tcp-v1", + "hostname": ["cannot", "parse", "list"]} UNRECOGNIZED_HINT_JSON = {"type": "unknown"} UNAVAILABLE_HINT_JSON = {"type": "direct-tcp-v1", # e.g. Tor without txtorcon "hostname": "unavailable", "port": 1234} @@ -1327,7 +1449,9 @@ class Transit(unittest.TestCase): s.set_transit_key(b"key") hints = yield s.get_connection_hints() # start the listener del hints - s.add_connection_hints([DIRECT_HINT_JSON, UNRECOGNIZED_HINT_JSON]) + s.add_connection_hints([DIRECT_HINT_JSON, + UNRECOGNIZED_DIRECT_HINT_JSON, + UNRECOGNIZED_HINT_JSON]) s._start_connector = self._start_connector d = s.connect() diff --git a/src/wormhole/transit.py b/src/wormhole/transit.py index e0002a9..5088895 100644 --- a/src/wormhole/transit.py +++ b/src/wormhole/transit.py @@ -245,6 +245,7 @@ class Connection(protocol.Protocol, policies.TimeoutMixin): return self.dataReceivedRECORDS() if isinstance(self.state, Exception): # for tests raise self.state + raise ValueError("internal error: unknown state %s" % (self.state,)) def _negotiationSuccessful(self): self.state = "records"