test_transit: improve coverage
and fix py2/py3-isms in ipaddr tests
This commit is contained in:
parent
478405cb6a
commit
afe9f7152d
|
@ -1,4 +1,5 @@
|
|||
from __future__ import print_function, unicode_literals
|
||||
import six
|
||||
import io
|
||||
import gc
|
||||
import mock
|
||||
|
@ -132,6 +133,13 @@ class Misc(unittest.TestCase):
|
|||
portno = transit.allocate_tcp_port()
|
||||
self.assertIsInstance(portno, int)
|
||||
|
||||
def test_allocate_port_no_reuseaddr(self):
|
||||
mock_sys = mock.Mock()
|
||||
mock_sys.platform = "cygwin"
|
||||
with mock.patch("wormhole.transit.sys", mock_sys):
|
||||
portno = transit.allocate_tcp_port()
|
||||
self.assertIsInstance(portno, int)
|
||||
|
||||
UnknownHint = namedtuple("UnknownHint", ["stuff"])
|
||||
|
||||
class Hints(unittest.TestCase):
|
||||
|
@ -205,6 +213,10 @@ class Hints(unittest.TestCase):
|
|||
self.assertEqual(h, transit.DirectTCPV1Hint("host", 1234, 2.6))
|
||||
self.assertEqual(stderr, "")
|
||||
|
||||
h,stderr = p("tcp:host:1234:unknown=stuff")
|
||||
self.assertEqual(h, transit.DirectTCPV1Hint("host", 1234, 0.0))
|
||||
self.assertEqual(stderr, "")
|
||||
|
||||
h,stderr = p("$!@#^")
|
||||
self.assertEqual(h, None)
|
||||
self.assertEqual(stderr, "unparseable hint '$!@#^'\n")
|
||||
|
@ -237,6 +249,15 @@ class Hints(unittest.TestCase):
|
|||
"tor:host:1234")
|
||||
self.assertEqual(d(UnknownHint("stuff")), str(UnknownHint("stuff")))
|
||||
|
||||
# ipaddrs.py currently uses native strings: bytes on py2, unicode on
|
||||
# py3
|
||||
if six.PY2:
|
||||
LOOPADDR = b"127.0.0.1"
|
||||
OTHERADDR = b"1.2.3.4"
|
||||
else:
|
||||
LOOPADDR = "127.0.0.1" # unicode_literals
|
||||
OTHERADDR = "1.2.3.4"
|
||||
|
||||
class Basic(unittest.TestCase):
|
||||
@inlineCallbacks
|
||||
def test_relay_hints(self):
|
||||
|
@ -265,7 +286,7 @@ class Basic(unittest.TestCase):
|
|||
self.assertEqual(c._their_direct_hints, [])
|
||||
self.assertEqual(c._our_relay_hints, set())
|
||||
|
||||
def test_ignore_localhost_hint(self):
|
||||
def test_ignore_localhost_hint_orig(self):
|
||||
# this actually starts the listener
|
||||
c = transit.TransitSender("")
|
||||
results = []
|
||||
|
@ -281,6 +302,36 @@ class Basic(unittest.TestCase):
|
|||
for hint in hints:
|
||||
self.assertFalse(hint["hostname"] == "127.0.0.1")
|
||||
|
||||
def test_ignore_localhost_hint(self):
|
||||
# this actually starts the listener
|
||||
c = transit.TransitSender("")
|
||||
with mock.patch("wormhole.ipaddrs.find_addresses",
|
||||
return_value=[LOOPADDR, OTHERADDR]):
|
||||
hints = self.successResultOf(c.get_connection_hints())
|
||||
c._stop_listening()
|
||||
# If there are non-localhost hints, then localhost hints should be
|
||||
# removed.
|
||||
self.assertEqual(len(hints), 1)
|
||||
self.assertEqual(hints[0]["hostname"], "1.2.3.4")
|
||||
|
||||
def test_keep_only_localhost_hint(self):
|
||||
# this actually starts the listener
|
||||
c = transit.TransitSender("")
|
||||
with mock.patch("wormhole.ipaddrs.find_addresses",
|
||||
return_value=[LOOPADDR]):
|
||||
hints = self.successResultOf(c.get_connection_hints())
|
||||
c._stop_listening()
|
||||
# If the only hint is localhost, it should stay.
|
||||
self.assertEqual(len(hints), 1)
|
||||
self.assertEqual(hints[0]["hostname"], "127.0.0.1")
|
||||
|
||||
def test_abilities(self):
|
||||
c = transit.Common(None, no_listen=True)
|
||||
abilities = c.get_connection_abilities()
|
||||
self.assertEqual(abilities, [{"type": "direct-tcp-v1"},
|
||||
{"type": "relay-v1"},
|
||||
])
|
||||
|
||||
def test_transit_key_wait(self):
|
||||
KEY = b"123"
|
||||
c = transit.Common("")
|
||||
|
@ -642,6 +693,10 @@ class Connection(unittest.TestCase):
|
|||
self.assertTrue(c._check_and_remove(EXP))
|
||||
self.assertEqual(c.buf, b" exceeded")
|
||||
|
||||
def test_describe(self):
|
||||
c = transit.Connection(None, None, None, "description")
|
||||
self.assertEqual(c.describe(), "description")
|
||||
|
||||
def test_sender_accepting(self):
|
||||
relay_handshake = None
|
||||
owner = MockOwner()
|
||||
|
@ -730,6 +785,33 @@ class Connection(unittest.TestCase):
|
|||
self.assertIsInstance(f, failure.Failure)
|
||||
self.assertIsInstance(f.value, RandomError)
|
||||
|
||||
def test_handshake_bad_state(self):
|
||||
owner = MockOwner()
|
||||
factory = MockFactory()
|
||||
addr = address.HostnameAddress("example.com", 1234)
|
||||
c = transit.Connection(owner, None, None, "description")
|
||||
self.assertEqual(c.state, "too-early")
|
||||
t = c.transport = FakeTransport(c, addr)
|
||||
c.factory = factory
|
||||
c.connectionMade()
|
||||
self.assertEqual(factory._connectionWasMade_called, True)
|
||||
self.assertEqual(factory._p, c)
|
||||
|
||||
d = c.startNegotiation()
|
||||
self.assertEqual(c.state, "handshake")
|
||||
self.assertEqual(t.read_buf(), b"send_this")
|
||||
results = []
|
||||
d.addBoth(results.append)
|
||||
self.assertEqual(results, [])
|
||||
c.state = "unknown-bogus-state"
|
||||
self.assertRaises(ValueError, c.dataReceived, b"surprise!")
|
||||
self.assertEqual(t._connected, False)
|
||||
self.assertEqual(c.state, "hung up")
|
||||
self.assertEqual(len(results), 1)
|
||||
f = results[0]
|
||||
self.assertIsInstance(f, failure.Failure)
|
||||
self.assertIsInstance(f.value, ValueError)
|
||||
|
||||
def test_relay_handshake(self):
|
||||
relay_handshake = b"relay handshake"
|
||||
owner = MockOwner()
|
||||
|
@ -962,6 +1044,13 @@ class Connection(unittest.TestCase):
|
|||
|
||||
return t, c, owner
|
||||
|
||||
def test_records_not_binary(self):
|
||||
t, c, owner = self.make_connection()
|
||||
|
||||
RECORD1 = u"not binary"
|
||||
with self.assertRaises(InternalError):
|
||||
c.send_record(RECORD1)
|
||||
|
||||
def test_records_good(self):
|
||||
# now make sure that outbound records are encrypted properly
|
||||
t, c, owner = self.make_connection()
|
||||
|
@ -1203,6 +1292,20 @@ class Connection(unittest.TestCase):
|
|||
self.assertIsInstance(f, failure.Failure)
|
||||
self.assertIsInstance(f.value, error.ConnectionClosed)
|
||||
|
||||
def test_connectConsumer_empty(self):
|
||||
# if connectConsumer() expects 0 bytes (e.g. someone is "sending" a
|
||||
# zero-length file), make sure it gets woken up right away, so it can
|
||||
# disconnect itself, even though no bytes will actually arrive
|
||||
c = transit.Connection(None, None, None, "description")
|
||||
c._negotiation_d.addErrback(lambda err: None) # eat it
|
||||
c.transport = proto_helpers.StringTransport()
|
||||
|
||||
consumer = proto_helpers.StringTransport()
|
||||
d = c.connectConsumer(consumer, expected=0)
|
||||
self.assertEqual(self.successResultOf(d), 0)
|
||||
self.assertEqual(consumer.value(), b"")
|
||||
self.assertIs(c._consumer, None)
|
||||
|
||||
def test_writeToFile(self):
|
||||
c = transit.Connection(None, None, None, "description")
|
||||
c._negotiation_d.addErrback(lambda err: None) # eat it
|
||||
|
@ -1293,12 +1396,31 @@ class FileConsumer(unittest.TestCase):
|
|||
self.assertEqual(progress, [99, 1])
|
||||
self.assertEqual(f.getvalue(), b"."*99+b"!")
|
||||
|
||||
def test_hasher(self):
|
||||
hashee = []
|
||||
f = io.BytesIO()
|
||||
progress = []
|
||||
fc = transit.FileConsumer(f, progress.append, hasher=hashee.append)
|
||||
self.assertEqual(progress, [])
|
||||
self.assertEqual(f.getvalue(), b"")
|
||||
self.assertEqual(hashee, [])
|
||||
fc.write(b"."* 99)
|
||||
self.assertEqual(progress, [99])
|
||||
self.assertEqual(f.getvalue(), b"."*99)
|
||||
self.assertEqual(hashee, [b"."*99])
|
||||
fc.write(b"!")
|
||||
self.assertEqual(progress, [99, 1])
|
||||
self.assertEqual(f.getvalue(), b"."*99+b"!")
|
||||
self.assertEqual(hashee, [b"."*99, b"!"])
|
||||
|
||||
|
||||
DIRECT_HINT_JSON = {"type": "direct-tcp-v1",
|
||||
"hostname": "direct", "port": 1234}
|
||||
RELAY_HINT_JSON = {"type": "relay-v1",
|
||||
"hints": [{"type": "direct-tcp-v1",
|
||||
"hostname": "relay", "port": 1234}]}
|
||||
UNRECOGNIZED_DIRECT_HINT_JSON = {"type": "direct-tcp-v1",
|
||||
"hostname": ["cannot", "parse", "list"]}
|
||||
UNRECOGNIZED_HINT_JSON = {"type": "unknown"}
|
||||
UNAVAILABLE_HINT_JSON = {"type": "direct-tcp-v1", # e.g. Tor without txtorcon
|
||||
"hostname": "unavailable", "port": 1234}
|
||||
|
@ -1327,7 +1449,9 @@ class Transit(unittest.TestCase):
|
|||
s.set_transit_key(b"key")
|
||||
hints = yield s.get_connection_hints() # start the listener
|
||||
del hints
|
||||
s.add_connection_hints([DIRECT_HINT_JSON, UNRECOGNIZED_HINT_JSON])
|
||||
s.add_connection_hints([DIRECT_HINT_JSON,
|
||||
UNRECOGNIZED_DIRECT_HINT_JSON,
|
||||
UNRECOGNIZED_HINT_JSON])
|
||||
|
||||
s._start_connector = self._start_connector
|
||||
d = s.connect()
|
||||
|
|
|
@ -245,6 +245,7 @@ class Connection(protocol.Protocol, policies.TimeoutMixin):
|
|||
return self.dataReceivedRECORDS()
|
||||
if isinstance(self.state, Exception): # for tests
|
||||
raise self.state
|
||||
raise ValueError("internal error: unknown state %s" % (self.state,))
|
||||
|
||||
def _negotiationSuccessful(self):
|
||||
self.state = "records"
|
||||
|
|
Loading…
Reference in New Issue
Block a user