test_transit: simplify by using successResultOf/failureResultOf

This commit is contained in:
Brian Warner 2017-11-29 15:03:03 -06:00
parent f03c8bc516
commit 8227d963a3

View File

@ -8,7 +8,7 @@ from collections import namedtuple
from twisted.trial import unittest from twisted.trial import unittest
from twisted.internet import defer, task, endpoints, protocol, address, error from twisted.internet import defer, task, endpoints, protocol, address, error
from twisted.internet.defer import gatherResults, inlineCallbacks from twisted.internet.defer import gatherResults, inlineCallbacks
from twisted.python import log, failure from twisted.python import log
from twisted.test import proto_helpers from twisted.test import proto_helpers
from wormhole_transit_relay import transit_server from wormhole_transit_relay import transit_server
from ..errors import InternalError from ..errors import InternalError
@ -22,65 +22,53 @@ class Highlander(unittest.TestCase):
cancelled = set() cancelled = set()
contenders = [defer.Deferred(lambda d: cancelled.add(i)) contenders = [defer.Deferred(lambda d: cancelled.add(i))
for i in range(4)] for i in range(4)]
result = []
d = transit.there_can_be_only_one(contenders) d = transit.there_can_be_only_one(contenders)
d.addBoth(result.append) self.assertNoResult(d)
self.assertEqual(result, [])
contenders[0].errback(ValueError()) contenders[0].errback(ValueError())
self.assertEqual(result, []) self.assertNoResult(d)
contenders[1].errback(TypeError()) contenders[1].errback(TypeError())
self.assertEqual(result, []) self.assertNoResult(d)
contenders[2].callback("yay") contenders[2].callback("yay")
self.assertEqual(result, ["yay"]) self.assertEqual(self.successResultOf(d), "yay")
self.assertEqual(cancelled, set([3])) self.assertEqual(cancelled, set([3]))
def test_there_might_also_be_none(self): def test_there_might_also_be_none(self):
cancelled = set() cancelled = set()
contenders = [defer.Deferred(lambda d: cancelled.add(i)) contenders = [defer.Deferred(lambda d: cancelled.add(i))
for i in range(4)] for i in range(4)]
result = []
d = transit.there_can_be_only_one(contenders) d = transit.there_can_be_only_one(contenders)
d.addBoth(result.append) self.assertNoResult(d)
self.assertEqual(result, [])
contenders[0].errback(ValueError()) contenders[0].errback(ValueError())
self.assertEqual(result, []) self.assertNoResult(d)
contenders[1].errback(TypeError()) contenders[1].errback(TypeError())
self.assertEqual(result, []) self.assertNoResult(d)
contenders[2].errback(TypeError()) contenders[2].errback(TypeError())
self.assertEqual(result, []) self.assertNoResult(d)
contenders[3].errback(NameError()) contenders[3].errback(NameError())
self.assertEqual(len(result), 1) self.failureResultOf(d, ValueError) # first failure is recorded
f = result[0]
self.assertIsInstance(f.value, ValueError) # first failure is recorded
self.assertEqual(cancelled, set()) self.assertEqual(cancelled, set())
def test_cancel_early(self): def test_cancel_early(self):
cancelled = set() cancelled = set()
contenders = [defer.Deferred(lambda d, i=i: cancelled.add(i)) contenders = [defer.Deferred(lambda d, i=i: cancelled.add(i))
for i in range(4)] for i in range(4)]
result = []
d = transit.there_can_be_only_one(contenders) d = transit.there_can_be_only_one(contenders)
d.addBoth(result.append) self.assertNoResult(d)
self.assertEqual(result, [])
self.assertEqual(cancelled, set()) self.assertEqual(cancelled, set())
d.cancel() d.cancel()
self.assertEqual(len(result), 1) self.failureResultOf(d, defer.CancelledError)
self.assertIsInstance(result[0].value, defer.CancelledError)
self.assertEqual(cancelled, set(range(4))) self.assertEqual(cancelled, set(range(4)))
def test_cancel_after_one_failure(self): def test_cancel_after_one_failure(self):
cancelled = set() cancelled = set()
contenders = [defer.Deferred(lambda d, i=i: cancelled.add(i)) contenders = [defer.Deferred(lambda d, i=i: cancelled.add(i))
for i in range(4)] for i in range(4)]
result = []
d = transit.there_can_be_only_one(contenders) d = transit.there_can_be_only_one(contenders)
d.addBoth(result.append) self.assertNoResult(d)
self.assertEqual(result, [])
self.assertEqual(cancelled, set()) self.assertEqual(cancelled, set())
contenders[0].errback(ValueError()) contenders[0].errback(ValueError())
d.cancel() d.cancel()
self.assertEqual(len(result), 1) self.failureResultOf(d, ValueError)
self.assertIsInstance(result[0].value, ValueError)
self.assertEqual(cancelled, set([1,2,3])) self.assertEqual(cancelled, set([1,2,3]))
class Forever(unittest.TestCase): class Forever(unittest.TestCase):
@ -88,44 +76,44 @@ class Forever(unittest.TestCase):
clock = task.Clock() clock = task.Clock()
c = transit.Common("", reactor=clock) c = transit.Common("", reactor=clock)
cancelled = [] cancelled = []
result = []
d0 = defer.Deferred(cancelled.append) d0 = defer.Deferred(cancelled.append)
d = c._not_forever(1.0, d0) d = c._not_forever(1.0, d0)
d.addBoth(result.append) return c, clock, d0, d, cancelled
return c, clock, d0, d, cancelled, result
def test_not_forever_fires(self): def test_not_forever_fires(self):
c, clock, d0, d, cancelled, result = self._forever_setup() c, clock, d0, d, cancelled = self._forever_setup()
self.assertEqual((result, cancelled), ([], [])) self.assertNoResult(d)
self.assertEqual(cancelled, [])
d.callback(1) d.callback(1)
self.assertEqual((result, cancelled), ([1], [])) self.assertEqual(self.successResultOf(d), 1)
self.assertEqual(cancelled, [])
self.assertNot(clock.getDelayedCalls()) self.assertNot(clock.getDelayedCalls())
def test_not_forever_errs(self): def test_not_forever_errs(self):
c, clock, d0, d, cancelled, result = self._forever_setup() c, clock, d0, d, cancelled = self._forever_setup()
self.assertEqual((result, cancelled), ([], [])) self.assertNoResult(d)
self.assertEqual(cancelled, [])
d.errback(ValueError()) d.errback(ValueError())
self.assertEqual(cancelled, []) self.assertEqual(cancelled, [])
self.assertEqual(len(result), 1) self.failureResultOf(d, ValueError)
self.assertIsInstance(result[0].value, ValueError)
self.assertNot(clock.getDelayedCalls()) self.assertNot(clock.getDelayedCalls())
def test_not_forever_cancel_early(self): def test_not_forever_cancel_early(self):
c, clock, d0, d, cancelled, result = self._forever_setup() c, clock, d0, d, cancelled = self._forever_setup()
self.assertEqual((result, cancelled), ([], [])) self.assertNoResult(d)
self.assertEqual(cancelled, [])
d.cancel() d.cancel()
self.assertEqual(cancelled, [d0]) self.assertEqual(cancelled, [d0])
self.assertEqual(len(result), 1) self.failureResultOf(d, defer.CancelledError)
self.assertIsInstance(result[0].value, defer.CancelledError)
self.assertNot(clock.getDelayedCalls()) self.assertNot(clock.getDelayedCalls())
def test_not_forever_timeout(self): def test_not_forever_timeout(self):
c, clock, d0, d, cancelled, result = self._forever_setup() c, clock, d0, d, cancelled = self._forever_setup()
self.assertEqual((result, cancelled), ([], [])) self.assertNoResult(d)
self.assertEqual(cancelled, [])
clock.advance(2.0) clock.advance(2.0)
self.assertEqual(cancelled, [d0]) self.assertEqual(cancelled, [d0])
self.assertEqual(len(result), 1) self.failureResultOf(d, defer.CancelledError)
self.assertIsInstance(result[0].value, defer.CancelledError)
self.assertNot(clock.getDelayedCalls()) self.assertNot(clock.getDelayedCalls())
class Misc(unittest.TestCase): class Misc(unittest.TestCase):
@ -289,10 +277,7 @@ class Basic(unittest.TestCase):
def test_ignore_localhost_hint_orig(self): def test_ignore_localhost_hint_orig(self):
# this actually starts the listener # this actually starts the listener
c = transit.TransitSender("") c = transit.TransitSender("")
results = [] hints = self.successResultOf(c.get_connection_hints())
d = c.get_connection_hints()
d.addBoth(results.append)
hints = results[0]
c._stop_listening() c._stop_listening()
# If there are non-localhost hints, then localhost hints should be # If there are non-localhost hints, then localhost hints should be
# removed. But if the only hint is localhost, it should stay. # removed. But if the only hint is localhost, it should stay.
@ -335,21 +320,17 @@ class Basic(unittest.TestCase):
def test_transit_key_wait(self): def test_transit_key_wait(self):
KEY = b"123" KEY = b"123"
c = transit.Common("") c = transit.Common("")
results = []
d = c._get_transit_key() d = c._get_transit_key()
d.addBoth(results.append) self.assertNoResult(d)
self.assertEqual(results, [])
c.set_transit_key(KEY) c.set_transit_key(KEY)
self.assertEqual(results, [KEY]) self.assertEqual(self.successResultOf(d), KEY)
def test_transit_key_already_set(self): def test_transit_key_already_set(self):
KEY = b"123" KEY = b"123"
c = transit.Common("") c = transit.Common("")
c.set_transit_key(KEY) c.set_transit_key(KEY)
results = []
d = c._get_transit_key() d = c._get_transit_key()
d.addBoth(results.append) self.assertEqual(self.successResultOf(d), KEY)
self.assertEqual(results, [KEY])
def test_transit_keys(self): def test_transit_keys(self):
KEY = b"123" KEY = b"123"
@ -397,19 +378,14 @@ class Listener(unittest.TestCase):
# this actually starts the listener # this actually starts the listener
c = transit.TransitSender("") c = transit.TransitSender("")
results = []
d = c.get_connection_hints() d = c.get_connection_hints()
d.addBoth(results.append) hints = self.successResultOf(d)
self.assertEqual(len(results), 1)
hints = results[0]
# the hints are supposed to be cached, so calling this twice won't # the hints are supposed to be cached, so calling this twice won't
# start a second listener # start a second listener
self.assert_(c._listener) self.assert_(c._listener)
results = [] d2 = c.get_connection_hints()
d = c.get_connection_hints() self.assertEqual(self.successResultOf(d2), hints)
d.addBoth(results.append)
self.assertEqual(results, [hints])
c._stop_listening() c._stop_listening()
@ -502,10 +478,8 @@ class InboundConnectionFactory(unittest.TestCase):
def test_success(self): def test_success(self):
f = transit.InboundConnectionFactory("owner") f = transit.InboundConnectionFactory("owner")
f.protocol = MockConnection f.protocol = MockConnection
results = []
d = f.whenDone() d = f.whenDone()
d.addBoth(results.append) self.assertNoResult(d)
self.assertEqual(results, [])
addr = address.HostnameAddress("example.com", 1234) addr = address.HostnameAddress("example.com", 1234)
p = f.buildProtocol(addr) p = f.buildProtocol(addr)
@ -518,19 +492,17 @@ class InboundConnectionFactory(unittest.TestCase):
# this is normally called from Connection.connectionMade # this is normally called from Connection.connectionMade
f.connectionWasMade(p) f.connectionWasMade(p)
self.assertEqual(p._start_negotiation_called, True) self.assertEqual(p._start_negotiation_called, True)
self.assertEqual(results, []) self.assertNoResult(d)
self.assertEqual(p._description, "<-example.com:1234") self.assertEqual(p._description, "<-example.com:1234")
p._d.callback(p) p._d.callback(p)
self.assertEqual(results, [p]) self.assertEqual(self.successResultOf(d), p)
def test_one_fail_one_success(self): def test_one_fail_one_success(self):
f = transit.InboundConnectionFactory("owner") f = transit.InboundConnectionFactory("owner")
f.protocol = MockConnection f.protocol = MockConnection
results = []
d = f.whenDone() d = f.whenDone()
d.addBoth(results.append) self.assertNoResult(d)
self.assertEqual(results, [])
addr1 = address.HostnameAddress("example.com", 1234) addr1 = address.HostnameAddress("example.com", 1234)
addr2 = address.HostnameAddress("example.com", 5678) addr2 = address.HostnameAddress("example.com", 5678)
@ -539,20 +511,18 @@ class InboundConnectionFactory(unittest.TestCase):
f.connectionWasMade(p1) f.connectionWasMade(p1)
f.connectionWasMade(p2) f.connectionWasMade(p2)
self.assertEqual(results, []) self.assertNoResult(d)
p1._d.errback(transit.BadHandshake("nope")) p1._d.errback(transit.BadHandshake("nope"))
self.assertEqual(results, []) self.assertNoResult(d)
p2._d.callback(p2) p2._d.callback(p2)
self.assertEqual(results, [p2]) self.assertEqual(self.successResultOf(d), p2)
def test_first_success_wins(self): def test_first_success_wins(self):
f = transit.InboundConnectionFactory("owner") f = transit.InboundConnectionFactory("owner")
f.protocol = MockConnection f.protocol = MockConnection
results = []
d = f.whenDone() d = f.whenDone()
d.addBoth(results.append) self.assertNoResult(d)
self.assertEqual(results, [])
addr1 = address.HostnameAddress("example.com", 1234) addr1 = address.HostnameAddress("example.com", 1234)
addr2 = address.HostnameAddress("example.com", 5678) addr2 = address.HostnameAddress("example.com", 5678)
@ -561,20 +531,18 @@ class InboundConnectionFactory(unittest.TestCase):
f.connectionWasMade(p1) f.connectionWasMade(p1)
f.connectionWasMade(p2) f.connectionWasMade(p2)
self.assertEqual(results, []) self.assertNoResult(d)
p1._d.callback(p1) p1._d.callback(p1)
self.assertEqual(results, [p1]) self.assertEqual(self.successResultOf(d), p1)
self.assertEqual(p1._cancelled, False) self.assertEqual(p1._cancelled, False)
self.assertEqual(p2._cancelled, True) self.assertEqual(p2._cancelled, True)
def test_log_other_errors(self): def test_log_other_errors(self):
f = transit.InboundConnectionFactory("owner") f = transit.InboundConnectionFactory("owner")
f.protocol = MockConnection f.protocol = MockConnection
results = []
d = f.whenDone() d = f.whenDone()
d.addBoth(results.append) self.assertNoResult(d)
self.assertEqual(results, [])
addr = address.HostnameAddress("example.com", 1234) addr = address.HostnameAddress("example.com", 1234)
p1 = f.buildProtocol(addr) p1 = f.buildProtocol(addr)
@ -585,7 +553,7 @@ class InboundConnectionFactory(unittest.TestCase):
f.connectionWasMade(p1) f.connectionWasMade(p1)
our_error = RandomError("boom1") our_error = RandomError("boom1")
p1._d.errback(our_error) p1._d.errback(our_error)
self.assertEqual(len(results), 0) self.assertNoResult(d)
log.msg("=== note: the next RandomError is expected ===") log.msg("=== note: the next RandomError is expected ===")
# Make sure the Deferred has gone out of scope, so the UnhandledError # Make sure the Deferred has gone out of scope, so the UnhandledError
@ -600,10 +568,8 @@ class InboundConnectionFactory(unittest.TestCase):
def test_cancel(self): def test_cancel(self):
f = transit.InboundConnectionFactory("owner") f = transit.InboundConnectionFactory("owner")
f.protocol = MockConnection f.protocol = MockConnection
results = []
d = f.whenDone() d = f.whenDone()
d.addBoth(results.append) self.assertNoResult(d)
self.assertEqual(results, [])
addr1 = address.HostnameAddress("example.com", 1234) addr1 = address.HostnameAddress("example.com", 1234)
addr2 = address.HostnameAddress("example.com", 5678) addr2 = address.HostnameAddress("example.com", 5678)
@ -612,14 +578,11 @@ class InboundConnectionFactory(unittest.TestCase):
f.connectionWasMade(p1) f.connectionWasMade(p1)
f.connectionWasMade(p2) f.connectionWasMade(p2)
self.assertEqual(results, []) self.assertNoResult(d)
d.cancel() d.cancel()
self.assertEqual(len(results), 1) self.failureResultOf(d, defer.CancelledError)
f = results[0]
self.assertIsInstance(f, failure.Failure)
self.assertIsInstance(f.value, defer.CancelledError)
self.assertEqual(p1._cancelled, True) self.assertEqual(p1._cancelled, True)
self.assertEqual(p2._cancelled, True) self.assertEqual(p2._cancelled, True)
@ -714,15 +677,13 @@ class Connection(unittest.TestCase):
d = c.startNegotiation() d = c.startNegotiation()
self.assertEqual(c.state, "handshake") self.assertEqual(c.state, "handshake")
self.assertEqual(t.read_buf(), b"send_this") self.assertEqual(t.read_buf(), b"send_this")
results = [] self.assertNoResult(d)
d.addBoth(results.append)
self.assertEqual(results, [])
c.dataReceived(b"expect_this") c.dataReceived(b"expect_this")
self.assertEqual(t.read_buf(), b"go\n") self.assertEqual(t.read_buf(), b"go\n")
self.assertEqual(t._connected, True) self.assertEqual(t._connected, True)
self.assertEqual(c.state, "records") self.assertEqual(c.state, "records")
self.assertEqual(results, [c]) self.assertEqual(self.successResultOf(d), c)
c.close() c.close()
self.assertEqual(t._connected, False) self.assertEqual(t._connected, False)
@ -744,18 +705,13 @@ class Connection(unittest.TestCase):
d = c.startNegotiation() d = c.startNegotiation()
self.assertEqual(c.state, "handshake") self.assertEqual(c.state, "handshake")
self.assertEqual(t.read_buf(), b"send_this") self.assertEqual(t.read_buf(), b"send_this")
results = [] self.assertNoResult(d)
d.addBoth(results.append)
self.assertEqual(results, [])
c.dataReceived(b"expect_this") c.dataReceived(b"expect_this")
self.assertEqual(t.read_buf(), b"nevermind\n") self.assertEqual(t.read_buf(), b"nevermind\n")
self.assertEqual(t._connected, False) self.assertEqual(t._connected, False)
self.assertEqual(c.state, "hung up") self.assertEqual(c.state, "hung up")
self.assertEqual(len(results), 1) f = self.failureResultOf(d, transit.BadHandshake)
f = results[0]
self.assertIsInstance(f, failure.Failure)
self.assertIsInstance(f.value, transit.BadHandshake)
self.assertEqual(str(f.value), "abandoned") self.assertEqual(str(f.value), "abandoned")
def test_handshake_other_error(self): def test_handshake_other_error(self):
@ -773,17 +729,12 @@ class Connection(unittest.TestCase):
d = c.startNegotiation() d = c.startNegotiation()
self.assertEqual(c.state, "handshake") self.assertEqual(c.state, "handshake")
self.assertEqual(t.read_buf(), b"send_this") self.assertEqual(t.read_buf(), b"send_this")
results = [] self.assertNoResult(d)
d.addBoth(results.append)
self.assertEqual(results, [])
c.state = RandomError("boom2") c.state = RandomError("boom2")
self.assertRaises(RandomError, c.dataReceived, b"surprise!") self.assertRaises(RandomError, c.dataReceived, b"surprise!")
self.assertEqual(t._connected, False) self.assertEqual(t._connected, False)
self.assertEqual(c.state, "hung up") self.assertEqual(c.state, "hung up")
self.assertEqual(len(results), 1) self.failureResultOf(d, RandomError)
f = results[0]
self.assertIsInstance(f, failure.Failure)
self.assertIsInstance(f.value, RandomError)
def test_handshake_bad_state(self): def test_handshake_bad_state(self):
owner = MockOwner() owner = MockOwner()
@ -800,17 +751,12 @@ class Connection(unittest.TestCase):
d = c.startNegotiation() d = c.startNegotiation()
self.assertEqual(c.state, "handshake") self.assertEqual(c.state, "handshake")
self.assertEqual(t.read_buf(), b"send_this") self.assertEqual(t.read_buf(), b"send_this")
results = [] self.assertNoResult(d)
d.addBoth(results.append)
self.assertEqual(results, [])
c.state = "unknown-bogus-state" c.state = "unknown-bogus-state"
self.assertRaises(ValueError, c.dataReceived, b"surprise!") self.assertRaises(ValueError, c.dataReceived, b"surprise!")
self.assertEqual(t._connected, False) self.assertEqual(t._connected, False)
self.assertEqual(c.state, "hung up") self.assertEqual(c.state, "hung up")
self.assertEqual(len(results), 1) self.failureResultOf(d, ValueError)
f = results[0]
self.assertIsInstance(f, failure.Failure)
self.assertIsInstance(f.value, ValueError)
def test_relay_handshake(self): def test_relay_handshake(self):
relay_handshake = b"relay handshake" relay_handshake = b"relay handshake"
@ -835,13 +781,11 @@ class Connection(unittest.TestCase):
self.assertEqual(t.read_buf(), b"send_this") self.assertEqual(t.read_buf(), b"send_this")
self.assertEqual(c.state, "handshake") self.assertEqual(c.state, "handshake")
results = [] self.assertNoResult(d)
d.addBoth(results.append)
self.assertEqual(results, [])
c.dataReceived(b"expect_this") c.dataReceived(b"expect_this")
self.assertEqual(c.state, "records") self.assertEqual(c.state, "records")
self.assertEqual(results, [c]) self.assertEqual(self.successResultOf(d), c)
self.assertEqual(t.read_buf(), b"go\n") self.assertEqual(t.read_buf(), b"go\n")
@ -868,12 +812,7 @@ class Connection(unittest.TestCase):
self.assertEqual(t._connected, False) self.assertEqual(t._connected, False)
self.assertEqual(c.state, "hung up") self.assertEqual(c.state, "hung up")
results = [] f = self.failureResultOf(d, transit.BadHandshake)
d.addBoth(results.append)
self.assertEqual(len(results), 1)
f = results[0]
self.assertIsInstance(f, failure.Failure)
self.assertIsInstance(f.value, transit.BadHandshake)
self.assertEqual(str(f.value), self.assertEqual(str(f.value),
"got %r want %r" % (b"not ok\n", b"ok\n")) "got %r want %r" % (b"not ok\n", b"ok\n"))
@ -894,17 +833,15 @@ class Connection(unittest.TestCase):
d = c.startNegotiation() d = c.startNegotiation()
self.assertEqual(c.state, "handshake") self.assertEqual(c.state, "handshake")
self.assertEqual(t.read_buf(), b"send_this") self.assertEqual(t.read_buf(), b"send_this")
results = [] self.assertNoResult(d)
d.addBoth(results.append)
self.assertEqual(results, [])
c.dataReceived(b"expect_this") c.dataReceived(b"expect_this")
self.assertEqual(c.state, "wait-for-decision") self.assertEqual(c.state, "wait-for-decision")
self.assertEqual(results, []) self.assertNoResult(d)
c.dataReceived(b"go\n") c.dataReceived(b"go\n")
self.assertEqual(c.state, "records") self.assertEqual(c.state, "records")
self.assertEqual(results, [c]) self.assertEqual(self.successResultOf(d), c)
def test_receiver_rejected_politely(self): def test_receiver_rejected_politely(self):
# we're on the receiving side, so we wait for the sender to decide # we're on the receiving side, so we wait for the sender to decide
@ -923,21 +860,16 @@ class Connection(unittest.TestCase):
d = c.startNegotiation() d = c.startNegotiation()
self.assertEqual(c.state, "handshake") self.assertEqual(c.state, "handshake")
self.assertEqual(t.read_buf(), b"send_this") self.assertEqual(t.read_buf(), b"send_this")
results = [] self.assertNoResult(d)
d.addBoth(results.append)
self.assertEqual(results, [])
c.dataReceived(b"expect_this") c.dataReceived(b"expect_this")
self.assertEqual(c.state, "wait-for-decision") self.assertEqual(c.state, "wait-for-decision")
self.assertEqual(results, []) self.assertNoResult(d)
c.dataReceived(b"nevermind\n") # polite rejection c.dataReceived(b"nevermind\n") # polite rejection
self.assertEqual(t._connected, False) self.assertEqual(t._connected, False)
self.assertEqual(c.state, "hung up") self.assertEqual(c.state, "hung up")
self.assertEqual(len(results), 1) f = self.failureResultOf(d, transit.BadHandshake)
f = results[0]
self.assertIsInstance(f, failure.Failure)
self.assertIsInstance(f.value, transit.BadHandshake)
self.assertEqual(str(f.value), self.assertEqual(str(f.value),
"got %r want %r" % (b"nevermind\n", b"go\n")) "got %r want %r" % (b"nevermind\n", b"go\n"))
@ -958,20 +890,15 @@ class Connection(unittest.TestCase):
d = c.startNegotiation() d = c.startNegotiation()
self.assertEqual(c.state, "handshake") self.assertEqual(c.state, "handshake")
self.assertEqual(t.read_buf(), b"send_this") self.assertEqual(t.read_buf(), b"send_this")
results = [] self.assertNoResult(d)
d.addBoth(results.append)
self.assertEqual(results, [])
c.dataReceived(b"expect_this") c.dataReceived(b"expect_this")
self.assertEqual(c.state, "wait-for-decision") self.assertEqual(c.state, "wait-for-decision")
self.assertEqual(results, []) self.assertNoResult(d)
t.loseConnection() t.loseConnection()
self.assertEqual(t._connected, False) self.assertEqual(t._connected, False)
self.assertEqual(len(results), 1) f = self.failureResultOf(d, transit.BadHandshake)
f = results[0]
self.assertIsInstance(f, failure.Failure)
self.assertIsInstance(f.value, transit.BadHandshake)
self.assertEqual(str(f.value), "connection lost") self.assertEqual(str(f.value), "connection lost")
@ -986,17 +913,12 @@ class Connection(unittest.TestCase):
c.connectionMade() c.connectionMade()
d = c.startNegotiation() d = c.startNegotiation()
results = []
d.addBoth(results.append)
# while we're waiting for negotiation, we get cancelled # while we're waiting for negotiation, we get cancelled
d.cancel() d.cancel()
self.assertEqual(t._connected, False) self.assertEqual(t._connected, False)
self.assertEqual(c.state, "hung up") self.assertEqual(c.state, "hung up")
self.assertEqual(len(results), 1) self.failureResultOf(d, defer.CancelledError)
f = results[0]
self.assertIsInstance(f, failure.Failure)
self.assertIsInstance(f.value, defer.CancelledError)
def test_timeout(self): def test_timeout(self):
clock = task.Clock() clock = task.Clock()
@ -1013,16 +935,11 @@ class Connection(unittest.TestCase):
c.connectionMade() c.connectionMade()
# the timer should now be running # the timer should now be running
d = c.startNegotiation() d = c.startNegotiation()
results = []
d.addBoth(results.append)
# while we're waiting for negotiation, the timer expires # while we're waiting for negotiation, the timer expires
clock.advance(transit.TIMEOUT + 1.0) clock.advance(transit.TIMEOUT + 1.0)
self.assertEqual(t._connected, False) self.assertEqual(t._connected, False)
self.assertEqual(len(results), 1) f = self.failureResultOf(d, transit.BadHandshake)
f = results[0]
self.assertIsInstance(f, failure.Failure)
self.assertIsInstance(f.value, transit.BadHandshake)
self.assertEqual(str(f.value), "timeout") self.assertEqual(str(f.value), "timeout")
def make_connection(self): def make_connection(self):
@ -1036,10 +953,8 @@ class Connection(unittest.TestCase):
owner._state = "go" owner._state = "go"
d = c.startNegotiation() d = c.startNegotiation()
results = []
d.addBoth(results.append)
c.dataReceived(b"expect_this") c.dataReceived(b"expect_this")
self.assertEqual(results, [c]) self.assertEqual(self.successResultOf(d), c)
t.read_buf() # flush input buffer, prepare for encrypted records t.read_buf() # flush input buffer, prepare for encrypted records
return t, c, owner return t, c, owner
@ -1182,40 +1097,32 @@ class Connection(unittest.TestCase):
c = transit.Connection(None, None, None, "description") c = transit.Connection(None, None, None, "description")
c.transport = FakeTransport(c, None) c.transport = FakeTransport(c, None)
c.transport.signalConnectionLost = False c.transport.signalConnectionLost = False
results = [[] for i in range(5)]
c.recordReceived(b"0") c.recordReceived(b"0")
c.recordReceived(b"1") c.recordReceived(b"1")
c.recordReceived(b"2") c.recordReceived(b"2")
c.receive_record().addBoth(results[0].append) d0 = c.receive_record()
self.assertEqual(results[0], [b"0"]) self.assertEqual(self.successResultOf(d0), b"0")
d1 = c.receive_record() d1 = c.receive_record()
d2 = c.receive_record() d2 = c.receive_record()
# they must fire in order of receipt, not order of addCallback # they must fire in order of receipt, not order of addCallback
d2.addBoth(results[2].append) self.assertEqual(self.successResultOf(d2), b"2")
self.assertEqual(results[2], [b"2"]) self.assertEqual(self.successResultOf(d1), b"1")
d1.addBoth(results[1].append)
self.assertEqual(results[1], [b"1"])
c.receive_record().addBoth(results[3].append) d3 = c.receive_record()
c.receive_record().addBoth(results[4].append) d4 = c.receive_record()
self.assertEqual(results[3], []) self.assertNoResult(d3)
self.assertEqual(results[4], []) self.assertNoResult(d4)
c.recordReceived(b"3") c.recordReceived(b"3")
self.assertEqual(results[3], [b"3"]) self.assertEqual(self.successResultOf(d3), b"3")
self.assertEqual(results[4], []) self.assertNoResult(d4)
c.recordReceived(b"4") c.recordReceived(b"4")
self.assertEqual(results[3], [b"3"]) self.assertEqual(self.successResultOf(d4), b"4")
self.assertEqual(results[4], [b"4"])
closed = [] d5 = c.receive_record()
c.receive_record().addBoth(closed.append)
c.close() c.close()
self.assertEqual(len(closed), 1) self.failureResultOf(d5, error.ConnectionClosed)
f = closed[0]
self.assertIsInstance(f, failure.Failure)
self.assertIsInstance(f.value, error.ConnectionClosed)
def test_producer(self): def test_producer(self):
# a Transit object (receiving data from the remote peer) produces # a Transit object (receiving data from the remote peer) produces
@ -1257,23 +1164,21 @@ class Connection(unittest.TestCase):
c.recordReceived(b"r1.") c.recordReceived(b"r1.")
consumer = proto_helpers.StringTransport() consumer = proto_helpers.StringTransport()
results = []
d = c.connectConsumer(consumer, expected=10) d = c.connectConsumer(consumer, expected=10)
d.addBoth(results.append)
self.assertEqual(consumer.value(), b"r1.") self.assertEqual(consumer.value(), b"r1.")
self.assertEqual(results, []) self.assertNoResult(d)
c.recordReceived(b"r2.") c.recordReceived(b"r2.")
self.assertEqual(consumer.value(), b"r1.r2.") self.assertEqual(consumer.value(), b"r1.r2.")
self.assertEqual(results, []) self.assertNoResult(d)
c.recordReceived(b"r3.") c.recordReceived(b"r3.")
self.assertEqual(consumer.value(), b"r1.r2.r3.") self.assertEqual(consumer.value(), b"r1.r2.r3.")
self.assertEqual(results, []) self.assertNoResult(d)
c.recordReceived(b"!") c.recordReceived(b"!")
self.assertEqual(consumer.value(), b"r1.r2.r3.!") self.assertEqual(consumer.value(), b"r1.r2.r3.!")
self.assertEqual(results, [10]) self.assertEqual(self.successResultOf(d), 10)
# that should automatically disconnect the consumer, and subsequent # that should automatically disconnect the consumer, and subsequent
# records should get queued, not delivered # records should get queued, not delivered
@ -1282,15 +1187,10 @@ class Connection(unittest.TestCase):
self.assertEqual(consumer.value(), b"r1.r2.r3.!") self.assertEqual(consumer.value(), b"r1.r2.r3.!")
# now test that the Deferred errbacks when the connection is lost # now test that the Deferred errbacks when the connection is lost
results = []
d = c.connectConsumer(consumer, expected=10) d = c.connectConsumer(consumer, expected=10)
d.addBoth(results.append)
c.connectionLost() c.connectionLost()
self.assertEqual(len(results), 1) self.failureResultOf(d, error.ConnectionClosed)
f = results[0]
self.assertIsInstance(f, failure.Failure)
self.assertIsInstance(f.value, error.ConnectionClosed)
def test_connectConsumer_empty(self): def test_connectConsumer_empty(self):
# if connectConsumer() expects 0 bytes (e.g. someone is "sending" a # if connectConsumer() expects 0 bytes (e.g. someone is "sending" a
@ -1314,27 +1214,25 @@ class Connection(unittest.TestCase):
f = io.BytesIO() f = io.BytesIO()
progress = [] progress = []
results = []
d = c.writeToFile(f, 10, progress.append) d = c.writeToFile(f, 10, progress.append)
d.addBoth(results.append)
self.assertEqual(f.getvalue(), b"r1.") self.assertEqual(f.getvalue(), b"r1.")
self.assertEqual(progress, [3]) self.assertEqual(progress, [3])
self.assertEqual(results, []) self.assertNoResult(d)
c.recordReceived(b"r2.") c.recordReceived(b"r2.")
self.assertEqual(f.getvalue(), b"r1.r2.") self.assertEqual(f.getvalue(), b"r1.r2.")
self.assertEqual(progress, [3, 3]) self.assertEqual(progress, [3, 3])
self.assertEqual(results, []) self.assertNoResult(d)
c.recordReceived(b"r3.") c.recordReceived(b"r3.")
self.assertEqual(f.getvalue(), b"r1.r2.r3.") self.assertEqual(f.getvalue(), b"r1.r2.r3.")
self.assertEqual(progress, [3, 3, 3]) self.assertEqual(progress, [3, 3, 3])
self.assertEqual(results, []) self.assertNoResult(d)
c.recordReceived(b"!") c.recordReceived(b"!")
self.assertEqual(f.getvalue(), b"r1.r2.r3.!") self.assertEqual(f.getvalue(), b"r1.r2.r3.!")
self.assertEqual(progress, [3, 3, 3, 1]) self.assertEqual(progress, [3, 3, 3, 1])
self.assertEqual(results, [10]) self.assertEqual(self.successResultOf(d), 10)
# that should automatically disconnect the consumer, and subsequent # that should automatically disconnect the consumer, and subsequent
# records should get queued, not delivered # records should get queued, not delivered
@ -1347,23 +1245,16 @@ class Connection(unittest.TestCase):
c.recordReceived(b"second.") # now "overflow.second." c.recordReceived(b"second.") # now "overflow.second."
c.recordReceived(b"third.") # now "overflow.second.third." c.recordReceived(b"third.") # now "overflow.second.third."
f = io.BytesIO() f = io.BytesIO()
results = []
d = c.writeToFile(f, 10) d = c.writeToFile(f, 10)
d.addBoth(results.append)
self.assertEqual(f.getvalue(), b"overflow.second.") # whole records self.assertEqual(f.getvalue(), b"overflow.second.") # whole records
self.assertEqual(results, [16]) self.assertEqual(self.successResultOf(d), 16)
self.assertEqual(list(c._inbound_records), [b"third."]) self.assertEqual(list(c._inbound_records), [b"third."])
# now test that the Deferred errbacks when the connection is lost # now test that the Deferred errbacks when the connection is lost
results = []
d = c.writeToFile(f, 10) d = c.writeToFile(f, 10)
d.addBoth(results.append)
c.connectionLost() c.connectionLost()
self.assertEqual(len(results), 1) self.failureResultOf(d, error.ConnectionClosed)
f = results[0]
self.assertIsInstance(f, failure.Failure)
self.assertIsInstance(f.value, error.ConnectionClosed)
def test_consumer(self): def test_consumer(self):
# a local producer sends data to a consuming Transit object # a local producer sends data to a consuming Transit object
@ -1457,14 +1348,12 @@ class Transit(unittest.TestCase):
s._start_connector = self._start_connector s._start_connector = self._start_connector
d = s.connect() d = s.connect()
results = [] self.assertNoResult(d)
d.addBoth(results.append)
self.assertEqual(results, [])
self.assertEqual(len(self._waiters), 1) self.assertEqual(len(self._waiters), 1)
self.assertIsInstance(self._waiters[0], defer.Deferred) self.assertIsInstance(self._waiters[0], defer.Deferred)
self._waiters[0].callback("winner") self._waiters[0].callback("winner")
self.assertEqual(results, ["winner"]) self.assertEqual(self.successResultOf(d), "winner")
self.assertEqual(self._descriptions, ["->tcp:direct:1234"]) self.assertEqual(self._descriptions, ["->tcp:direct:1234"])
@inlineCallbacks @inlineCallbacks
@ -1478,14 +1367,12 @@ class Transit(unittest.TestCase):
s._start_connector = self._start_connector s._start_connector = self._start_connector
d = s.connect() d = s.connect()
results = [] self.assertNoResult(d)
d.addBoth(results.append)
self.assertEqual(results, [])
self.assertEqual(len(self._waiters), 1) self.assertEqual(len(self._waiters), 1)
self.assertIsInstance(self._waiters[0], defer.Deferred) self.assertIsInstance(self._waiters[0], defer.Deferred)
self._waiters[0].callback("winner") self._waiters[0].callback("winner")
self.assertEqual(results, ["winner"]) self.assertEqual(self.successResultOf(d), "winner")
self.assertEqual(self._descriptions, ["tor->tcp:direct:1234"]) self.assertEqual(self._descriptions, ["tor->tcp:direct:1234"])
@inlineCallbacks @inlineCallbacks
@ -1499,17 +1386,15 @@ class Transit(unittest.TestCase):
s._start_connector = self._start_connector s._start_connector = self._start_connector
d = s.connect() d = s.connect()
results = []
d.addBoth(results.append)
# move the clock forward any amount, since relay connections are # move the clock forward any amount, since relay connections are
# triggered starting at T+0.0 # triggered starting at T+0.0
clock.advance(1.0) clock.advance(1.0)
self.assertEqual(results, []) self.assertNoResult(d)
self.assertEqual(len(self._waiters), 1) self.assertEqual(len(self._waiters), 1)
self.assertIsInstance(self._waiters[0], defer.Deferred) self.assertIsInstance(self._waiters[0], defer.Deferred)
self._waiters[0].callback("winner") self._waiters[0].callback("winner")
self.assertEqual(results, ["winner"]) self.assertEqual(self.successResultOf(d), "winner")
self.assertEqual(self._descriptions, ["tor->relay:tcp:relay:1234"]) self.assertEqual(self._descriptions, ["tor->relay:tcp:relay:1234"])
def _endpoint_from_hint_obj(self, hint): def _endpoint_from_hint_obj(self, hint):
@ -1533,9 +1418,7 @@ class Transit(unittest.TestCase):
s._start_connector = self._start_connector s._start_connector = self._start_connector
d = s.connect() d = s.connect()
results = [] self.assertNoResult(d)
d.addBoth(results.append)
self.assertEqual(results, [])
# the direct connectors are tried right away, but the relay # the direct connectors are tried right away, but the relay
# connectors are stalled for a few seconds # connectors are stalled for a few seconds
self.assertEqual(self._connectors, ["direct"]) self.assertEqual(self._connectors, ["direct"])
@ -1544,7 +1427,7 @@ class Transit(unittest.TestCase):
self.assertEqual(self._connectors, ["direct", "relay"]) self.assertEqual(self._connectors, ["direct", "relay"])
self._waiters[0].callback("winner") self._waiters[0].callback("winner")
self.assertEqual(results, ["winner"]) self.assertEqual(self.successResultOf(d), "winner")
@inlineCallbacks @inlineCallbacks
def test_priorities(self): def test_priorities(self):
@ -1572,9 +1455,7 @@ class Transit(unittest.TestCase):
s._start_connector = self._start_connector s._start_connector = self._start_connector
d = s.connect() d = s.connect()
results = [] self.assertNoResult(d)
d.addBoth(results.append)
self.assertEqual(results, [])
# direct connector should be used first, then the priority=3.0 relay, # direct connector should be used first, then the priority=3.0 relay,
# then the two 2.0 relays, then the (default) 0.0 relay # then the two 2.0 relays, then the (default) 0.0 relay
@ -1594,7 +1475,7 @@ class Transit(unittest.TestCase):
["direct", "relay3", "relay4", "relay2", "relay"])) ["direct", "relay3", "relay4", "relay2", "relay"]))
self._waiters[0].callback("winner") self._waiters[0].callback("winner")
self.assertEqual(results, ["winner"]) self.assertEqual(self.successResultOf(d), "winner")
@inlineCallbacks @inlineCallbacks
def test_no_direct_hints(self): def test_no_direct_hints(self):
@ -1612,9 +1493,7 @@ class Transit(unittest.TestCase):
s._start_connector = self._start_connector s._start_connector = self._start_connector
d = s.connect() d = s.connect()
results = [] self.assertNoResult(d)
d.addBoth(results.append)
self.assertEqual(results, [])
# since there are no usable direct hints, the relay connector will # since there are no usable direct hints, the relay connector will
# only be stalled for 0 seconds # only be stalled for 0 seconds
self.assertEqual(self._connectors, []) self.assertEqual(self._connectors, [])
@ -1623,7 +1502,7 @@ class Transit(unittest.TestCase):
self.assertEqual(self._connectors, ["relay"]) self.assertEqual(self._connectors, ["relay"])
self._waiters[0].callback("winner") self._waiters[0].callback("winner")
self.assertEqual(results, ["winner"]) self.assertEqual(self.successResultOf(d), "winner")
@inlineCallbacks @inlineCallbacks
def test_no_contenders(self): def test_no_contenders(self):