move describe() from Transit to RecordPipe

This commit is contained in:
Brian Warner 2016-02-28 01:37:52 -08:00
parent 1903c58248
commit 6654efb429
6 changed files with 93 additions and 101 deletions

View File

@ -166,13 +166,17 @@ class ReceiveBuffer:
return rc return rc
class RecordPipe: class RecordPipe:
def __init__(self, skt, send_key, receive_key): def __init__(self, skt, send_key, receive_key, description):
self.skt = skt self.skt = skt
self.send_box = SecretBox(send_key) self.send_box = SecretBox(send_key)
self.send_nonce = 0 self.send_nonce = 0
self.receive_buf = ReceiveBuffer(self.skt) self.receive_buf = ReceiveBuffer(self.skt)
self.receive_box = SecretBox(receive_key) self.receive_box = SecretBox(receive_key)
self.next_receive_nonce = 0 self.next_receive_nonce = 0
self._description = description
def describe(self):
return self._description
def send_record(self, record): def send_record(self, record):
if not isinstance(record, type(b"")): raise UsageError if not isinstance(record, type(b"")): raise UsageError
@ -338,11 +342,6 @@ class Common:
return self.winning_skt return self.winning_skt
raise TransitError("timeout") raise TransitError("timeout")
def describe(self):
if not self.winning_skt_description:
return "not yet established"
return self.winning_skt_description
def _connector_failed(self, hint): def _connector_failed(self, hint):
debug("- failed connector %s" % hint) debug("- failed connector %s" % hint)
# XXX this was .remove, and occasionally got KeyError # XXX this was .remove, and occasionally got KeyError
@ -375,7 +374,8 @@ class Common:
def connect(self): def connect(self):
skt = self.establish_socket() skt = self.establish_socket()
return RecordPipe(skt, self._sender_record_key(), return RecordPipe(skt, self._sender_record_key(),
self._receiver_record_key()) self._receiver_record_key(),
self.winning_skt_description)
class TransitSender(Common): class TransitSender(Common):
is_sender = True is_sender = True

View File

@ -122,7 +122,7 @@ def receive_blocking(args):
record_pipe = transit_receiver.connect() record_pipe = transit_receiver.connect()
print(u"Receiving %d bytes for '%s' (%s).." % print(u"Receiving %d bytes for '%s' (%s).." %
(xfersize, destname, transit_receiver.describe()), (xfersize, destname, record_pipe.describe()),
file=args.stdout) file=args.stdout)
if mode == "file": if mode == "file":
tmp_destname = abs_destname + ".tmp" tmp_destname = abs_destname + ".tmp"

View File

@ -92,7 +92,7 @@ def _send_file_blocking(them_phase1, fd_to_send, transit_sender,
transit_sender.add_their_relay_hints(tdata["relay_connection_hints"]) transit_sender.add_their_relay_hints(tdata["relay_connection_hints"])
record_pipe = transit_sender.connect() record_pipe = transit_sender.connect()
print(u"Sending (%s).." % transit_sender.describe(), file=stdout) print(u"Sending (%s).." % record_pipe.describe(), file=stdout)
CHUNKSIZE = 64*1024 CHUNKSIZE = 64*1024
fd_to_send.seek(0,2) fd_to_send.seek(0,2)

View File

@ -138,7 +138,7 @@ def _send_file_twisted(tdata, transit_sender, fd_to_send,
record_pipe = yield transit_sender.connect() record_pipe = yield transit_sender.connect()
# record_pipe should implement IConsumer, chunks are just records # record_pipe should implement IConsumer, chunks are just records
print(u"Sending (%s).." % transit_sender.describe(), file=stdout) print(u"Sending (%s).." % record_pipe.describe(), file=stdout)
yield pfs.beginFileTransfer(fd_to_send, record_pipe) yield pfs.beginFileTransfer(fd_to_send, record_pipe)
print(u"File sent.. waiting for confirmation", file=stdout) print(u"File sent.. waiting for confirmation", file=stdout)
ack = yield record_pipe.receive_record() ack = yield record_pipe.receive_record()

View File

@ -197,24 +197,14 @@ class Basic(unittest.TestCase):
def test_connection_ready(self): def test_connection_ready(self):
s = transit.TransitSender(u"") s = transit.TransitSender(u"")
self.assertEqual(s.describe(), "not yet established") self.assertEqual(s.connection_ready("p1"), "go")
self.assertEqual(s.connection_ready("p1", "desc1"), "go")
self.assertEqual(s.describe(), "desc1")
self.assertEqual(s._winner, "p1") self.assertEqual(s._winner, "p1")
self.assertEqual(s.connection_ready("p2"), "nevermind")
self.assertEqual(s.connection_ready("p2", "desc2"), "nevermind")
self.assertEqual(s.describe(), "desc1")
self.assertEqual(s._winner, "p1") self.assertEqual(s._winner, "p1")
r = transit.TransitReceiver(u"") r = transit.TransitReceiver(u"")
self.assertEqual(r.describe(), "not yet established") self.assertEqual(r.connection_ready("p1"), "wait-for-decision")
self.assertEqual(r.connection_ready("p2"), "wait-for-decision")
self.assertEqual(r.connection_ready("p1", "desc1"), "wait-for-decision")
self.assertEqual(r.describe(), "not yet established")
self.assertEqual(r.connection_ready("p2", "desc2"), "wait-for-decision")
self.assertEqual(r.describe(), "not yet established")
class Listener(unittest.TestCase): class Listener(unittest.TestCase):
@ -305,32 +295,32 @@ class RandomError(Exception):
pass pass
class MockConnection: class MockConnection:
def __init__(self, owner, relay_handshake, start): def __init__(self, owner, relay_handshake, start, description):
self.owner = owner self.owner = owner
self.relay_handshake = relay_handshake self.relay_handshake = relay_handshake
self.start = start self.start = start
self._description = description
def cancel(d): def cancel(d):
self._cancelled = True self._cancelled = True
self._d = defer.Deferred(cancel) self._d = defer.Deferred(cancel)
self._start_negotiation_called = False self._start_negotiation_called = False
self._cancelled = False self._cancelled = False
def startNegotiation(self, description): def startNegotiation(self):
self._start_negotiation_called = True self._start_negotiation_called = True
self._description = description
return self._d return self._d
class InboundConnectionFactory(unittest.TestCase): class InboundConnectionFactory(unittest.TestCase):
def test_describe(self): def test_describe(self):
f = transit.InboundConnectionFactory(None) f = transit.InboundConnectionFactory(None)
addrH = address.HostnameAddress("example.com", 1234) addrH = address.HostnameAddress("example.com", 1234)
self.assertEqual(f.describePeer(addrH), "<-example.com:1234") self.assertEqual(f._describePeer(addrH), "<-example.com:1234")
addr4 = address.IPv4Address("TCP", "1.2.3.4", 1234) addr4 = address.IPv4Address("TCP", "1.2.3.4", 1234)
self.assertEqual(f.describePeer(addr4), "<-1.2.3.4:1234") self.assertEqual(f._describePeer(addr4), "<-1.2.3.4:1234")
addr6 = address.IPv6Address("TCP", "::1", 1234) addr6 = address.IPv6Address("TCP", "::1", 1234)
self.assertEqual(f.describePeer(addr6), "<-::1:1234") self.assertEqual(f._describePeer(addr6), "<-::1:1234")
addrU = address.UNIXAddress("/dev/unlikely") addrU = address.UNIXAddress("/dev/unlikely")
self.assertEqual(f.describePeer(addrU), self.assertEqual(f._describePeer(addrU),
"<-UNIXAddress('/dev/unlikely')") "<-UNIXAddress('/dev/unlikely')")
def test_success(self): def test_success(self):
@ -350,7 +340,7 @@ class InboundConnectionFactory(unittest.TestCase):
# meh .start # meh .start
# this is normally called from Connection.connectionMade # this is normally called from Connection.connectionMade
f.connectionWasMade(p, addr) f.connectionWasMade(p)
self.assertEqual(p._start_negotiation_called, True) self.assertEqual(p._start_negotiation_called, True)
self.assertEqual(results, []) self.assertEqual(results, [])
self.assertEqual(p._description, "<-example.com:1234") self.assertEqual(p._description, "<-example.com:1234")
@ -366,12 +356,13 @@ class InboundConnectionFactory(unittest.TestCase):
d.addBoth(results.append) d.addBoth(results.append)
self.assertEqual(results, []) self.assertEqual(results, [])
addr = address.HostnameAddress("example.com", 1234) addr1 = address.HostnameAddress("example.com", 1234)
p1 = f.buildProtocol(addr) addr2 = address.HostnameAddress("example.com", 5678)
p2 = f.buildProtocol(addr) p1 = f.buildProtocol(addr1)
p2 = f.buildProtocol(addr2)
f.connectionWasMade(p1, "desc1") f.connectionWasMade(p1)
f.connectionWasMade(p2, "desc2") f.connectionWasMade(p2)
self.assertEqual(results, []) self.assertEqual(results, [])
p1._d.errback(transit.BadHandshake("nope")) p1._d.errback(transit.BadHandshake("nope"))
@ -387,12 +378,13 @@ class InboundConnectionFactory(unittest.TestCase):
d.addBoth(results.append) d.addBoth(results.append)
self.assertEqual(results, []) self.assertEqual(results, [])
addr = address.HostnameAddress("example.com", 1234) addr1 = address.HostnameAddress("example.com", 1234)
p1 = f.buildProtocol(addr) addr2 = address.HostnameAddress("example.com", 5678)
p2 = f.buildProtocol(addr) p1 = f.buildProtocol(addr1)
p2 = f.buildProtocol(addr2)
f.connectionWasMade(p1, "desc1") f.connectionWasMade(p1)
f.connectionWasMade(p2, "desc2") f.connectionWasMade(p2)
self.assertEqual(results, []) self.assertEqual(results, [])
p1._d.callback(p1) p1._d.callback(p1)
@ -414,7 +406,7 @@ class InboundConnectionFactory(unittest.TestCase):
# if the Connection protocol throws an unexpected error, that should # if the Connection protocol throws an unexpected error, that should
# get logged to the Twisted logs (as an Unhandled Error in Deferred) # get logged to the Twisted logs (as an Unhandled Error in Deferred)
# so we can diagnose the bug # so we can diagnose the bug
f.connectionWasMade(p1, "desc1") f.connectionWasMade(p1)
p1._d.errback(RandomError("boom")) p1._d.errback(RandomError("boom"))
self.assertEqual(len(results), 0) self.assertEqual(len(results), 0)
@ -433,12 +425,13 @@ class InboundConnectionFactory(unittest.TestCase):
d.addBoth(results.append) d.addBoth(results.append)
self.assertEqual(results, []) self.assertEqual(results, [])
addr = address.HostnameAddress("example.com", 1234) addr1 = address.HostnameAddress("example.com", 1234)
p1 = f.buildProtocol(addr) addr2 = address.HostnameAddress("example.com", 5678)
p2 = f.buildProtocol(addr) p1 = f.buildProtocol(addr1)
p2 = f.buildProtocol(addr2)
f.connectionWasMade(p1, "desc1") f.connectionWasMade(p1)
f.connectionWasMade(p2, "desc2") f.connectionWasMade(p2)
self.assertEqual(results, []) self.assertEqual(results, [])
d.cancel() d.cancel()
@ -454,7 +447,8 @@ class InboundConnectionFactory(unittest.TestCase):
class OutboundConnectionFactory(unittest.TestCase): class OutboundConnectionFactory(unittest.TestCase):
def test_success(self): def test_success(self):
f = transit.OutboundConnectionFactory("owner", "relay_handshake") f = transit.OutboundConnectionFactory("owner", "relay_handshake",
"description")
f.protocol = MockConnection f.protocol = MockConnection
addr = address.HostnameAddress("example.com", 1234) addr = address.HostnameAddress("example.com", 1234)
@ -466,16 +460,15 @@ class OutboundConnectionFactory(unittest.TestCase):
# meh .start # meh .start
# this is normally called from Connection.connectionMade # this is normally called from Connection.connectionMade
f.connectionWasMade(p, "desc") # no-op for outbound f.connectionWasMade(p) # no-op for outbound
self.assertEqual(p._start_negotiation_called, False) self.assertEqual(p._start_negotiation_called, False)
class MockOwner: class MockOwner:
_connection_ready_called = False _connection_ready_called = False
def connection_ready(self, connection, description): def connection_ready(self, connection):
self._connection_ready_called = True self._connection_ready_called = True
self._connection = connection self._connection = connection
self._description = description
return self._state return self._state
def _send_this(self): def _send_this(self):
return b"send_this" return b"send_this"
@ -488,7 +481,7 @@ class MockOwner:
class MockFactory: class MockFactory:
_connectionWasMade_called = False _connectionWasMade_called = False
def connectionWasMade(self, p, description): def connectionWasMade(self, p):
self._connectionWasMade_called = True self._connectionWasMade_called = True
self._p = p self._p = p
@ -496,7 +489,7 @@ class Connection(unittest.TestCase):
# exercise the Connection protocol class # exercise the Connection protocol class
def test_check_and_remove(self): def test_check_and_remove(self):
c = transit.Connection(None, None, None) c = transit.Connection(None, None, None, "description")
c.buf = b"" c.buf = b""
EXP = b"expectation" EXP = b"expectation"
self.assertFalse(c._check_and_remove(EXP)) self.assertFalse(c._check_and_remove(EXP))
@ -525,7 +518,7 @@ class Connection(unittest.TestCase):
owner = MockOwner() owner = MockOwner()
factory = MockFactory() factory = MockFactory()
addr = address.HostnameAddress("example.com", 1234) addr = address.HostnameAddress("example.com", 1234)
c = transit.Connection(owner, relay_handshake, None) c = transit.Connection(owner, relay_handshake, None, "description")
self.assertEqual(c.state, "too-early") self.assertEqual(c.state, "too-early")
t = c.transport = FakeTransport(c, addr) t = c.transport = FakeTransport(c, addr)
c.factory = factory c.factory = factory
@ -534,7 +527,7 @@ class Connection(unittest.TestCase):
self.assertEqual(factory._p, c) self.assertEqual(factory._p, c)
owner._state = "go" owner._state = "go"
d = c.startNegotiation("description") d = c.startNegotiation()
self.assertEqual(c.state, "handshake") self.assertEqual(c.state, "handshake")
self.assertEqual(t.read_buf(), b"send_this") self.assertEqual(t.read_buf(), b"send_this")
results = [] results = []
@ -555,7 +548,7 @@ class Connection(unittest.TestCase):
owner = MockOwner() owner = MockOwner()
factory = MockFactory() factory = MockFactory()
addr = address.HostnameAddress("example.com", 1234) addr = address.HostnameAddress("example.com", 1234)
c = transit.Connection(owner, relay_handshake, None) c = transit.Connection(owner, relay_handshake, None, "description")
self.assertEqual(c.state, "too-early") self.assertEqual(c.state, "too-early")
t = c.transport = FakeTransport(c, addr) t = c.transport = FakeTransport(c, addr)
c.factory = factory c.factory = factory
@ -564,7 +557,7 @@ class Connection(unittest.TestCase):
self.assertEqual(factory._p, c) self.assertEqual(factory._p, c)
owner._state = "nevermind" owner._state = "nevermind"
d = c.startNegotiation("description") d = c.startNegotiation()
self.assertEqual(c.state, "handshake") self.assertEqual(c.state, "handshake")
self.assertEqual(t.read_buf(), b"send_this") self.assertEqual(t.read_buf(), b"send_this")
results = [] results = []
@ -585,7 +578,7 @@ class Connection(unittest.TestCase):
owner = MockOwner() owner = MockOwner()
factory = MockFactory() factory = MockFactory()
addr = address.HostnameAddress("example.com", 1234) addr = address.HostnameAddress("example.com", 1234)
c = transit.Connection(owner, None, None) c = transit.Connection(owner, None, None, "description")
self.assertEqual(c.state, "too-early") self.assertEqual(c.state, "too-early")
t = c.transport = FakeTransport(c, addr) t = c.transport = FakeTransport(c, addr)
c.factory = factory c.factory = factory
@ -593,7 +586,7 @@ class Connection(unittest.TestCase):
self.assertEqual(factory._connectionWasMade_called, True) self.assertEqual(factory._connectionWasMade_called, True)
self.assertEqual(factory._p, c) self.assertEqual(factory._p, c)
d = c.startNegotiation("description") d = c.startNegotiation()
self.assertEqual(c.state, "handshake") self.assertEqual(c.state, "handshake")
self.assertEqual(t.read_buf(), b"send_this") self.assertEqual(t.read_buf(), b"send_this")
results = [] results = []
@ -613,7 +606,7 @@ class Connection(unittest.TestCase):
owner = MockOwner() owner = MockOwner()
factory = MockFactory() factory = MockFactory()
addr = address.HostnameAddress("example.com", 1234) addr = address.HostnameAddress("example.com", 1234)
c = transit.Connection(owner, relay_handshake, None) c = transit.Connection(owner, relay_handshake, None, "description")
self.assertEqual(c.state, "too-early") self.assertEqual(c.state, "too-early")
t = c.transport = FakeTransport(c, addr) t = c.transport = FakeTransport(c, addr)
c.factory = factory c.factory = factory
@ -623,7 +616,7 @@ class Connection(unittest.TestCase):
self.assertEqual(t.read_buf(), b"") # quiet until startNegotiation self.assertEqual(t.read_buf(), b"") # quiet until startNegotiation
owner._state = "go" owner._state = "go"
d = c.startNegotiation("description") d = c.startNegotiation()
self.assertEqual(t.read_buf(), relay_handshake) self.assertEqual(t.read_buf(), relay_handshake)
self.assertEqual(c.state, "relay") # waiting for OK from relay self.assertEqual(c.state, "relay") # waiting for OK from relay
@ -646,7 +639,7 @@ class Connection(unittest.TestCase):
owner = MockOwner() owner = MockOwner()
factory = MockFactory() factory = MockFactory()
addr = address.HostnameAddress("example.com", 1234) addr = address.HostnameAddress("example.com", 1234)
c = transit.Connection(owner, relay_handshake, None) c = transit.Connection(owner, relay_handshake, None, "description")
self.assertEqual(c.state, "too-early") self.assertEqual(c.state, "too-early")
t = c.transport = FakeTransport(c, addr) t = c.transport = FakeTransport(c, addr)
c.factory = factory c.factory = factory
@ -656,7 +649,7 @@ class Connection(unittest.TestCase):
self.assertEqual(t.read_buf(), b"") # quiet until startNegotiation self.assertEqual(t.read_buf(), b"") # quiet until startNegotiation
owner._state = "go" owner._state = "go"
d = c.startNegotiation("description") d = c.startNegotiation()
self.assertEqual(t.read_buf(), relay_handshake) self.assertEqual(t.read_buf(), relay_handshake)
self.assertEqual(c.state, "relay") # waiting for OK from relay self.assertEqual(c.state, "relay") # waiting for OK from relay
@ -678,7 +671,7 @@ class Connection(unittest.TestCase):
owner = MockOwner() owner = MockOwner()
factory = MockFactory() factory = MockFactory()
addr = address.HostnameAddress("example.com", 1234) addr = address.HostnameAddress("example.com", 1234)
c = transit.Connection(owner, None, None) c = transit.Connection(owner, None, None, "description")
self.assertEqual(c.state, "too-early") self.assertEqual(c.state, "too-early")
t = c.transport = FakeTransport(c, addr) t = c.transport = FakeTransport(c, addr)
c.factory = factory c.factory = factory
@ -687,7 +680,7 @@ class Connection(unittest.TestCase):
self.assertEqual(factory._p, c) self.assertEqual(factory._p, c)
owner._state = "wait-for-decision" owner._state = "wait-for-decision"
d = c.startNegotiation("description") d = c.startNegotiation()
self.assertEqual(c.state, "handshake") self.assertEqual(c.state, "handshake")
self.assertEqual(t.read_buf(), b"send_this") self.assertEqual(t.read_buf(), b"send_this")
results = [] results = []
@ -707,7 +700,7 @@ class Connection(unittest.TestCase):
owner = MockOwner() owner = MockOwner()
factory = MockFactory() factory = MockFactory()
addr = address.HostnameAddress("example.com", 1234) addr = address.HostnameAddress("example.com", 1234)
c = transit.Connection(owner, None, None) c = transit.Connection(owner, None, None, "description")
self.assertEqual(c.state, "too-early") self.assertEqual(c.state, "too-early")
t = c.transport = FakeTransport(c, addr) t = c.transport = FakeTransport(c, addr)
c.factory = factory c.factory = factory
@ -716,7 +709,7 @@ class Connection(unittest.TestCase):
self.assertEqual(factory._p, c) self.assertEqual(factory._p, c)
owner._state = "wait-for-decision" owner._state = "wait-for-decision"
d = c.startNegotiation("description") d = c.startNegotiation()
self.assertEqual(c.state, "handshake") self.assertEqual(c.state, "handshake")
self.assertEqual(t.read_buf(), b"send_this") self.assertEqual(t.read_buf(), b"send_this")
results = [] results = []
@ -742,7 +735,7 @@ class Connection(unittest.TestCase):
owner = MockOwner() owner = MockOwner()
factory = MockFactory() factory = MockFactory()
addr = address.HostnameAddress("example.com", 1234) addr = address.HostnameAddress("example.com", 1234)
c = transit.Connection(owner, None, None) c = transit.Connection(owner, None, None, "description")
self.assertEqual(c.state, "too-early") self.assertEqual(c.state, "too-early")
t = c.transport = FakeTransport(c, addr) t = c.transport = FakeTransport(c, addr)
c.factory = factory c.factory = factory
@ -751,7 +744,7 @@ class Connection(unittest.TestCase):
self.assertEqual(factory._p, c) self.assertEqual(factory._p, c)
owner._state = "wait-for-decision" owner._state = "wait-for-decision"
d = c.startNegotiation("description") d = c.startNegotiation()
self.assertEqual(c.state, "handshake") self.assertEqual(c.state, "handshake")
self.assertEqual(t.read_buf(), b"send_this") self.assertEqual(t.read_buf(), b"send_this")
results = [] results = []
@ -775,13 +768,13 @@ class Connection(unittest.TestCase):
owner = MockOwner() owner = MockOwner()
factory = MockFactory() factory = MockFactory()
addr = address.HostnameAddress("example.com", 1234) addr = address.HostnameAddress("example.com", 1234)
c = transit.Connection(owner, None, None) c = transit.Connection(owner, None, None, "description")
self.assertEqual(c.state, "too-early") self.assertEqual(c.state, "too-early")
t = c.transport = FakeTransport(c, addr) t = c.transport = FakeTransport(c, addr)
c.factory = factory c.factory = factory
c.connectionMade() c.connectionMade()
d = c.startNegotiation("description") d = c.startNegotiation()
results = [] results = []
d.addBoth(results.append) d.addBoth(results.append)
# while we're waiting for negotiation, we get cancelled # while we're waiting for negotiation, we get cancelled
@ -799,7 +792,7 @@ class Connection(unittest.TestCase):
owner = MockOwner() owner = MockOwner()
factory = MockFactory() factory = MockFactory()
addr = address.HostnameAddress("example.com", 1234) addr = address.HostnameAddress("example.com", 1234)
c = transit.Connection(owner, None, None) c = transit.Connection(owner, None, None, "description")
def _callLater(period, func): def _callLater(period, func):
clock.callLater(period, func) clock.callLater(period, func)
c.callLater = _callLater c.callLater = _callLater
@ -808,7 +801,7 @@ class Connection(unittest.TestCase):
c.factory = factory c.factory = factory
c.connectionMade() c.connectionMade()
# the timer should now be running # the timer should now be running
d = c.startNegotiation("description") d = c.startNegotiation()
results = [] results = []
d.addBoth(results.append) d.addBoth(results.append)
# while we're waiting for negotiation, the timer expires # while we're waiting for negotiation, the timer expires
@ -825,13 +818,13 @@ class Connection(unittest.TestCase):
owner = MockOwner() owner = MockOwner()
factory = MockFactory() factory = MockFactory()
addr = address.HostnameAddress("example.com", 1234) addr = address.HostnameAddress("example.com", 1234)
c = transit.Connection(owner, None, None) c = transit.Connection(owner, None, None, "description")
t = c.transport = FakeTransport(c, addr) t = c.transport = FakeTransport(c, addr)
c.factory = factory c.factory = factory
c.connectionMade() c.connectionMade()
owner._state = "go" owner._state = "go"
d = c.startNegotiation("description") d = c.startNegotiation()
results = [] results = []
d.addBoth(results.append) d.addBoth(results.append)
c.dataReceived(b"expect_this") c.dataReceived(b"expect_this")
@ -954,7 +947,7 @@ class Connection(unittest.TestCase):
# the key is None. # the key is None.
def test_receive_queue(self): def test_receive_queue(self):
c = transit.Connection(None, None, None) c = transit.Connection(None, None, None, "description")
c.transport = FakeTransport(c, None) c.transport = FakeTransport(c, None)
c.transport.signalConnectionLost = False c.transport.signalConnectionLost = False
results = [[] for i in range(5)] results = [[] for i in range(5)]
@ -995,7 +988,7 @@ class Connection(unittest.TestCase):
def test_producer(self): def test_producer(self):
# a Transit object (receiving data from the remote peer) produces # a Transit object (receiving data from the remote peer) produces
# data and writes it into a local Consumer # data and writes it into a local Consumer
c = transit.Connection(None, None, None) c = transit.Connection(None, None, None, "description")
c.transport = proto_helpers.StringTransport() c.transport = proto_helpers.StringTransport()
c.recordReceived(b"r1.") c.recordReceived(b"r1.")
c.recordReceived(b"r2.") c.recordReceived(b"r2.")
@ -1024,7 +1017,7 @@ class Connection(unittest.TestCase):
def test_consumer(self): def test_consumer(self):
# a local producer sends data to a consuming Transit object # a local producer sends data to a consuming Transit object
c = transit.Connection(None, None, None) c = transit.Connection(None, None, None, "description")
c.transport = proto_helpers.StringTransport() c.transport = proto_helpers.StringTransport()
records = [] records = []
c.send_record = records.append c.send_record = records.append

View File

@ -26,12 +26,13 @@ TIMEOUT=15
@implementer(interfaces.IProducer, interfaces.IConsumer) @implementer(interfaces.IProducer, interfaces.IConsumer)
class Connection(protocol.Protocol, policies.TimeoutMixin): class Connection(protocol.Protocol, policies.TimeoutMixin):
def __init__(self, owner, relay_handshake, start): def __init__(self, owner, relay_handshake, start, description):
self.state = "too-early" self.state = "too-early"
self.buf = b"" self.buf = b""
self.owner = owner self.owner = owner
self.relay_handshake = relay_handshake self.relay_handshake = relay_handshake
self.start = start self.start = start
self._description = description
self._negotiation_d = defer.Deferred(self._cancel) self._negotiation_d = defer.Deferred(self._cancel)
self._error = None self._error = None
self._consumer = None self._consumer = None
@ -41,10 +42,9 @@ class Connection(protocol.Protocol, policies.TimeoutMixin):
def connectionMade(self): def connectionMade(self):
debug("handle %r" % (self.transport,)) debug("handle %r" % (self.transport,))
self.setTimeout(TIMEOUT) # does timeoutConnection() when it expires self.setTimeout(TIMEOUT) # does timeoutConnection() when it expires
self.factory.connectionWasMade(self, self.transport.getPeer()) self.factory.connectionWasMade(self)
def startNegotiation(self, description): def startNegotiation(self):
self.description = description
if self.relay_handshake is not None: if self.relay_handshake is not None:
self.transport.write(self.relay_handshake) self.transport.write(self.relay_handshake)
self.state = "relay" self.state = "relay"
@ -104,7 +104,7 @@ class Connection(protocol.Protocol, policies.TimeoutMixin):
if self.state == "handshake": if self.state == "handshake":
if not self._check_and_remove(self.owner._expect_this()): if not self._check_and_remove(self.owner._expect_this()):
return return
self.state = self.owner.connection_ready(self, self.description) self.state = self.owner.connection_ready(self)
# If we're the receiver, we'll be moved to state # If we're the receiver, we'll be moved to state
# "wait-for-decision", which means we're waiting for the other # "wait-for-decision", which means we're waiting for the other
# side (the sender) to make a decision. If we're the sender, # side (the sender) to make a decision. If we're the sender,
@ -161,6 +161,9 @@ class Connection(protocol.Protocol, policies.TimeoutMixin):
record = self.receive_box.decrypt(encrypted) record = self.receive_box.decrypt(encrypted)
return record return record
def describe(self):
return self._description
def send_record(self, record): def send_record(self, record):
if not isinstance(record, type(b"")): raise UsageError if not isinstance(record, type(b"")): raise UsageError
assert SecretBox.NONCE_SIZE == 24 assert SecretBox.NONCE_SIZE == 24
@ -260,17 +263,19 @@ class Connection(protocol.Protocol, policies.TimeoutMixin):
class OutboundConnectionFactory(protocol.ClientFactory): class OutboundConnectionFactory(protocol.ClientFactory):
protocol = Connection protocol = Connection
def __init__(self, owner, relay_handshake): def __init__(self, owner, relay_handshake, description):
self.owner = owner self.owner = owner
self.relay_handshake = relay_handshake self.relay_handshake = relay_handshake
self._description = description
self.start = time.time() self.start = time.time()
def buildProtocol(self, addr): def buildProtocol(self, addr):
p = self.protocol(self.owner, self.relay_handshake, self.start) p = self.protocol(self.owner, self.relay_handshake, self.start,
self._description)
p.factory = self p.factory = self
return p return p
def connectionWasMade(self, p, addr): def connectionWasMade(self, p):
# outbound connections are handled via the endpoint # outbound connections are handled via the endpoint
pass pass
@ -295,7 +300,7 @@ class InboundConnectionFactory(protocol.ClientFactory):
for d in list(self._pending_connections): for d in list(self._pending_connections):
d.cancel() # that fires _remove and _proto_failed d.cancel() # that fires _remove and _proto_failed
def describePeer(self, addr): def _describePeer(self, addr):
if isinstance(addr, address.HostnameAddress): if isinstance(addr, address.HostnameAddress):
return "<-%s:%d" % (addr.hostname, addr.port) return "<-%s:%d" % (addr.hostname, addr.port)
elif isinstance(addr, (address.IPv4Address, address.IPv6Address)): elif isinstance(addr, (address.IPv4Address, address.IPv6Address)):
@ -303,12 +308,13 @@ class InboundConnectionFactory(protocol.ClientFactory):
return "<-%r" % addr return "<-%r" % addr
def buildProtocol(self, addr): def buildProtocol(self, addr):
p = self.protocol(self.owner, None, self.start) p = self.protocol(self.owner, None, self.start,
self._describePeer(addr))
p.factory = self p.factory = self
return p return p
def connectionWasMade(self, p, addr): def connectionWasMade(self, p):
d = p.startNegotiation(self.describePeer(addr)) d = p.startNegotiation()
self._pending_connections.add(d) self._pending_connections.add(d)
d.addBoth(self._remove, d) d.addBoth(self._remove, d)
d.addCallbacks(self._proto_succeeded, self._proto_failed) d.addCallbacks(self._proto_succeeded, self._proto_failed)
@ -413,7 +419,6 @@ class Common:
self._waiting_for_transit_key = [] self._waiting_for_transit_key = []
self._listener = None self._listener = None
self._winner = None self._winner = None
self._winner_description = None
self._reactor = reactor self._reactor = reactor
def _build_listener(self): def _build_listener(self):
@ -596,10 +601,10 @@ class Common:
if is_relay: if is_relay:
assert self._transit_key assert self._transit_key
relay_handshake = build_relay_handshake(self._transit_key) relay_handshake = build_relay_handshake(self._transit_key)
f = OutboundConnectionFactory(self, relay_handshake) f = OutboundConnectionFactory(self, relay_handshake, description)
d = ep.connect(f) d = ep.connect(f)
# fires with protocol, or ConnectError # fires with protocol, or ConnectError
d.addCallback(lambda p: p.startNegotiation(description)) d.addCallback(lambda p: p.startNegotiation())
return d return d
def _endpoint_from_hint(self, hint): def _endpoint_from_hint(self, hint):
@ -613,7 +618,7 @@ class Common:
return endpoints.HostnameEndpoint(self._reactor, pieces[1], return endpoints.HostnameEndpoint(self._reactor, pieces[1],
int(pieces[2])) int(pieces[2]))
def connection_ready(self, p, description): def connection_ready(self, p):
# inbound/outbound Connection protocols call this when they finish # inbound/outbound Connection protocols call this when they finish
# negotiation. The first one wins and gets a "go". Any subsequent # negotiation. The first one wins and gets a "go". Any subsequent
# ones lose and get a "nevermind" before being closed. # ones lose and get a "nevermind" before being closed.
@ -626,14 +631,8 @@ class Common:
return "nevermind" return "nevermind"
# this one wins! # this one wins!
self._winner = p self._winner = p
self._winner_description = description
return "go" return "go"
def describe(self):
if not self._winner:
return "not yet established"
return self._winner_description
class TransitSender(Common): class TransitSender(Common):
is_sender = True is_sender = True