test_transit: improve coverage

and fix py2/py3-isms in ipaddr tests
This commit is contained in:
Brian Warner 2017-04-16 18:13:51 -04:00
parent 478405cb6a
commit afe9f7152d
2 changed files with 127 additions and 2 deletions

View File

@ -1,4 +1,5 @@
from __future__ import print_function, unicode_literals
import six
import io
import gc
import mock
@ -132,6 +133,13 @@ class Misc(unittest.TestCase):
portno = transit.allocate_tcp_port()
self.assertIsInstance(portno, int)
def test_allocate_port_no_reuseaddr(self):
mock_sys = mock.Mock()
mock_sys.platform = "cygwin"
with mock.patch("wormhole.transit.sys", mock_sys):
portno = transit.allocate_tcp_port()
self.assertIsInstance(portno, int)
UnknownHint = namedtuple("UnknownHint", ["stuff"])
class Hints(unittest.TestCase):
@ -205,6 +213,10 @@ class Hints(unittest.TestCase):
self.assertEqual(h, transit.DirectTCPV1Hint("host", 1234, 2.6))
self.assertEqual(stderr, "")
h,stderr = p("tcp:host:1234:unknown=stuff")
self.assertEqual(h, transit.DirectTCPV1Hint("host", 1234, 0.0))
self.assertEqual(stderr, "")
h,stderr = p("$!@#^")
self.assertEqual(h, None)
self.assertEqual(stderr, "unparseable hint '$!@#^'\n")
@ -237,6 +249,15 @@ class Hints(unittest.TestCase):
"tor:host:1234")
self.assertEqual(d(UnknownHint("stuff")), str(UnknownHint("stuff")))
# ipaddrs.py currently uses native strings: bytes on py2, unicode on
# py3
if six.PY2:
LOOPADDR = b"127.0.0.1"
OTHERADDR = b"1.2.3.4"
else:
LOOPADDR = "127.0.0.1" # unicode_literals
OTHERADDR = "1.2.3.4"
class Basic(unittest.TestCase):
@inlineCallbacks
def test_relay_hints(self):
@ -265,7 +286,7 @@ class Basic(unittest.TestCase):
self.assertEqual(c._their_direct_hints, [])
self.assertEqual(c._our_relay_hints, set())
def test_ignore_localhost_hint(self):
def test_ignore_localhost_hint_orig(self):
# this actually starts the listener
c = transit.TransitSender("")
results = []
@ -281,6 +302,36 @@ class Basic(unittest.TestCase):
for hint in hints:
self.assertFalse(hint["hostname"] == "127.0.0.1")
def test_ignore_localhost_hint(self):
# this actually starts the listener
c = transit.TransitSender("")
with mock.patch("wormhole.ipaddrs.find_addresses",
return_value=[LOOPADDR, OTHERADDR]):
hints = self.successResultOf(c.get_connection_hints())
c._stop_listening()
# If there are non-localhost hints, then localhost hints should be
# removed.
self.assertEqual(len(hints), 1)
self.assertEqual(hints[0]["hostname"], "1.2.3.4")
def test_keep_only_localhost_hint(self):
# this actually starts the listener
c = transit.TransitSender("")
with mock.patch("wormhole.ipaddrs.find_addresses",
return_value=[LOOPADDR]):
hints = self.successResultOf(c.get_connection_hints())
c._stop_listening()
# If the only hint is localhost, it should stay.
self.assertEqual(len(hints), 1)
self.assertEqual(hints[0]["hostname"], "127.0.0.1")
def test_abilities(self):
c = transit.Common(None, no_listen=True)
abilities = c.get_connection_abilities()
self.assertEqual(abilities, [{"type": "direct-tcp-v1"},
{"type": "relay-v1"},
])
def test_transit_key_wait(self):
KEY = b"123"
c = transit.Common("")
@ -642,6 +693,10 @@ class Connection(unittest.TestCase):
self.assertTrue(c._check_and_remove(EXP))
self.assertEqual(c.buf, b" exceeded")
def test_describe(self):
c = transit.Connection(None, None, None, "description")
self.assertEqual(c.describe(), "description")
def test_sender_accepting(self):
relay_handshake = None
owner = MockOwner()
@ -730,6 +785,33 @@ class Connection(unittest.TestCase):
self.assertIsInstance(f, failure.Failure)
self.assertIsInstance(f.value, RandomError)
def test_handshake_bad_state(self):
owner = MockOwner()
factory = MockFactory()
addr = address.HostnameAddress("example.com", 1234)
c = transit.Connection(owner, None, None, "description")
self.assertEqual(c.state, "too-early")
t = c.transport = FakeTransport(c, addr)
c.factory = factory
c.connectionMade()
self.assertEqual(factory._connectionWasMade_called, True)
self.assertEqual(factory._p, c)
d = c.startNegotiation()
self.assertEqual(c.state, "handshake")
self.assertEqual(t.read_buf(), b"send_this")
results = []
d.addBoth(results.append)
self.assertEqual(results, [])
c.state = "unknown-bogus-state"
self.assertRaises(ValueError, c.dataReceived, b"surprise!")
self.assertEqual(t._connected, False)
self.assertEqual(c.state, "hung up")
self.assertEqual(len(results), 1)
f = results[0]
self.assertIsInstance(f, failure.Failure)
self.assertIsInstance(f.value, ValueError)
def test_relay_handshake(self):
relay_handshake = b"relay handshake"
owner = MockOwner()
@ -962,6 +1044,13 @@ class Connection(unittest.TestCase):
return t, c, owner
def test_records_not_binary(self):
t, c, owner = self.make_connection()
RECORD1 = u"not binary"
with self.assertRaises(InternalError):
c.send_record(RECORD1)
def test_records_good(self):
# now make sure that outbound records are encrypted properly
t, c, owner = self.make_connection()
@ -1203,6 +1292,20 @@ class Connection(unittest.TestCase):
self.assertIsInstance(f, failure.Failure)
self.assertIsInstance(f.value, error.ConnectionClosed)
def test_connectConsumer_empty(self):
# if connectConsumer() expects 0 bytes (e.g. someone is "sending" a
# zero-length file), make sure it gets woken up right away, so it can
# disconnect itself, even though no bytes will actually arrive
c = transit.Connection(None, None, None, "description")
c._negotiation_d.addErrback(lambda err: None) # eat it
c.transport = proto_helpers.StringTransport()
consumer = proto_helpers.StringTransport()
d = c.connectConsumer(consumer, expected=0)
self.assertEqual(self.successResultOf(d), 0)
self.assertEqual(consumer.value(), b"")
self.assertIs(c._consumer, None)
def test_writeToFile(self):
c = transit.Connection(None, None, None, "description")
c._negotiation_d.addErrback(lambda err: None) # eat it
@ -1293,12 +1396,31 @@ class FileConsumer(unittest.TestCase):
self.assertEqual(progress, [99, 1])
self.assertEqual(f.getvalue(), b"."*99+b"!")
def test_hasher(self):
hashee = []
f = io.BytesIO()
progress = []
fc = transit.FileConsumer(f, progress.append, hasher=hashee.append)
self.assertEqual(progress, [])
self.assertEqual(f.getvalue(), b"")
self.assertEqual(hashee, [])
fc.write(b"."* 99)
self.assertEqual(progress, [99])
self.assertEqual(f.getvalue(), b"."*99)
self.assertEqual(hashee, [b"."*99])
fc.write(b"!")
self.assertEqual(progress, [99, 1])
self.assertEqual(f.getvalue(), b"."*99+b"!")
self.assertEqual(hashee, [b"."*99, b"!"])
DIRECT_HINT_JSON = {"type": "direct-tcp-v1",
"hostname": "direct", "port": 1234}
RELAY_HINT_JSON = {"type": "relay-v1",
"hints": [{"type": "direct-tcp-v1",
"hostname": "relay", "port": 1234}]}
UNRECOGNIZED_DIRECT_HINT_JSON = {"type": "direct-tcp-v1",
"hostname": ["cannot", "parse", "list"]}
UNRECOGNIZED_HINT_JSON = {"type": "unknown"}
UNAVAILABLE_HINT_JSON = {"type": "direct-tcp-v1", # e.g. Tor without txtorcon
"hostname": "unavailable", "port": 1234}
@ -1327,7 +1449,9 @@ class Transit(unittest.TestCase):
s.set_transit_key(b"key")
hints = yield s.get_connection_hints() # start the listener
del hints
s.add_connection_hints([DIRECT_HINT_JSON, UNRECOGNIZED_HINT_JSON])
s.add_connection_hints([DIRECT_HINT_JSON,
UNRECOGNIZED_DIRECT_HINT_JSON,
UNRECOGNIZED_HINT_JSON])
s._start_connector = self._start_connector
d = s.connect()

View File

@ -245,6 +245,7 @@ class Connection(protocol.Protocol, policies.TimeoutMixin):
return self.dataReceivedRECORDS()
if isinstance(self.state, Exception): # for tests
raise self.state
raise ValueError("internal error: unknown state %s" % (self.state,))
def _negotiationSuccessful(self):
self.state = "records"