move describe() from Transit to RecordPipe

This commit is contained in:
Brian Warner 2016-02-28 01:37:52 -08:00
parent 1903c58248
commit 6654efb429
6 changed files with 93 additions and 101 deletions

View File

@ -166,13 +166,17 @@ class ReceiveBuffer:
return rc
class RecordPipe:
def __init__(self, skt, send_key, receive_key):
def __init__(self, skt, send_key, receive_key, description):
self.skt = skt
self.send_box = SecretBox(send_key)
self.send_nonce = 0
self.receive_buf = ReceiveBuffer(self.skt)
self.receive_box = SecretBox(receive_key)
self.next_receive_nonce = 0
self._description = description
def describe(self):
return self._description
def send_record(self, record):
if not isinstance(record, type(b"")): raise UsageError
@ -338,11 +342,6 @@ class Common:
return self.winning_skt
raise TransitError("timeout")
def describe(self):
if not self.winning_skt_description:
return "not yet established"
return self.winning_skt_description
def _connector_failed(self, hint):
debug("- failed connector %s" % hint)
# XXX this was .remove, and occasionally got KeyError
@ -375,7 +374,8 @@ class Common:
def connect(self):
skt = self.establish_socket()
return RecordPipe(skt, self._sender_record_key(),
self._receiver_record_key())
self._receiver_record_key(),
self.winning_skt_description)
class TransitSender(Common):
is_sender = True

View File

@ -122,7 +122,7 @@ def receive_blocking(args):
record_pipe = transit_receiver.connect()
print(u"Receiving %d bytes for '%s' (%s).." %
(xfersize, destname, transit_receiver.describe()),
(xfersize, destname, record_pipe.describe()),
file=args.stdout)
if mode == "file":
tmp_destname = abs_destname + ".tmp"

View File

@ -92,7 +92,7 @@ def _send_file_blocking(them_phase1, fd_to_send, transit_sender,
transit_sender.add_their_relay_hints(tdata["relay_connection_hints"])
record_pipe = transit_sender.connect()
print(u"Sending (%s).." % transit_sender.describe(), file=stdout)
print(u"Sending (%s).." % record_pipe.describe(), file=stdout)
CHUNKSIZE = 64*1024
fd_to_send.seek(0,2)

View File

@ -138,7 +138,7 @@ def _send_file_twisted(tdata, transit_sender, fd_to_send,
record_pipe = yield transit_sender.connect()
# record_pipe should implement IConsumer, chunks are just records
print(u"Sending (%s).." % transit_sender.describe(), file=stdout)
print(u"Sending (%s).." % record_pipe.describe(), file=stdout)
yield pfs.beginFileTransfer(fd_to_send, record_pipe)
print(u"File sent.. waiting for confirmation", file=stdout)
ack = yield record_pipe.receive_record()

View File

@ -197,24 +197,14 @@ class Basic(unittest.TestCase):
def test_connection_ready(self):
s = transit.TransitSender(u"")
self.assertEqual(s.describe(), "not yet established")
self.assertEqual(s.connection_ready("p1", "desc1"), "go")
self.assertEqual(s.describe(), "desc1")
self.assertEqual(s.connection_ready("p1"), "go")
self.assertEqual(s._winner, "p1")
self.assertEqual(s.connection_ready("p2", "desc2"), "nevermind")
self.assertEqual(s.describe(), "desc1")
self.assertEqual(s.connection_ready("p2"), "nevermind")
self.assertEqual(s._winner, "p1")
r = transit.TransitReceiver(u"")
self.assertEqual(r.describe(), "not yet established")
self.assertEqual(r.connection_ready("p1", "desc1"), "wait-for-decision")
self.assertEqual(r.describe(), "not yet established")
self.assertEqual(r.connection_ready("p2", "desc2"), "wait-for-decision")
self.assertEqual(r.describe(), "not yet established")
self.assertEqual(r.connection_ready("p1"), "wait-for-decision")
self.assertEqual(r.connection_ready("p2"), "wait-for-decision")
class Listener(unittest.TestCase):
@ -305,32 +295,32 @@ class RandomError(Exception):
pass
class MockConnection:
def __init__(self, owner, relay_handshake, start):
def __init__(self, owner, relay_handshake, start, description):
self.owner = owner
self.relay_handshake = relay_handshake
self.start = start
self._description = description
def cancel(d):
self._cancelled = True
self._d = defer.Deferred(cancel)
self._start_negotiation_called = False
self._cancelled = False
def startNegotiation(self, description):
def startNegotiation(self):
self._start_negotiation_called = True
self._description = description
return self._d
class InboundConnectionFactory(unittest.TestCase):
def test_describe(self):
f = transit.InboundConnectionFactory(None)
addrH = address.HostnameAddress("example.com", 1234)
self.assertEqual(f.describePeer(addrH), "<-example.com:1234")
self.assertEqual(f._describePeer(addrH), "<-example.com:1234")
addr4 = address.IPv4Address("TCP", "1.2.3.4", 1234)
self.assertEqual(f.describePeer(addr4), "<-1.2.3.4:1234")
self.assertEqual(f._describePeer(addr4), "<-1.2.3.4:1234")
addr6 = address.IPv6Address("TCP", "::1", 1234)
self.assertEqual(f.describePeer(addr6), "<-::1:1234")
self.assertEqual(f._describePeer(addr6), "<-::1:1234")
addrU = address.UNIXAddress("/dev/unlikely")
self.assertEqual(f.describePeer(addrU),
self.assertEqual(f._describePeer(addrU),
"<-UNIXAddress('/dev/unlikely')")
def test_success(self):
@ -350,7 +340,7 @@ class InboundConnectionFactory(unittest.TestCase):
# meh .start
# this is normally called from Connection.connectionMade
f.connectionWasMade(p, addr)
f.connectionWasMade(p)
self.assertEqual(p._start_negotiation_called, True)
self.assertEqual(results, [])
self.assertEqual(p._description, "<-example.com:1234")
@ -366,12 +356,13 @@ class InboundConnectionFactory(unittest.TestCase):
d.addBoth(results.append)
self.assertEqual(results, [])
addr = address.HostnameAddress("example.com", 1234)
p1 = f.buildProtocol(addr)
p2 = f.buildProtocol(addr)
addr1 = address.HostnameAddress("example.com", 1234)
addr2 = address.HostnameAddress("example.com", 5678)
p1 = f.buildProtocol(addr1)
p2 = f.buildProtocol(addr2)
f.connectionWasMade(p1, "desc1")
f.connectionWasMade(p2, "desc2")
f.connectionWasMade(p1)
f.connectionWasMade(p2)
self.assertEqual(results, [])
p1._d.errback(transit.BadHandshake("nope"))
@ -387,12 +378,13 @@ class InboundConnectionFactory(unittest.TestCase):
d.addBoth(results.append)
self.assertEqual(results, [])
addr = address.HostnameAddress("example.com", 1234)
p1 = f.buildProtocol(addr)
p2 = f.buildProtocol(addr)
addr1 = address.HostnameAddress("example.com", 1234)
addr2 = address.HostnameAddress("example.com", 5678)
p1 = f.buildProtocol(addr1)
p2 = f.buildProtocol(addr2)
f.connectionWasMade(p1, "desc1")
f.connectionWasMade(p2, "desc2")
f.connectionWasMade(p1)
f.connectionWasMade(p2)
self.assertEqual(results, [])
p1._d.callback(p1)
@ -414,7 +406,7 @@ class InboundConnectionFactory(unittest.TestCase):
# if the Connection protocol throws an unexpected error, that should
# get logged to the Twisted logs (as an Unhandled Error in Deferred)
# so we can diagnose the bug
f.connectionWasMade(p1, "desc1")
f.connectionWasMade(p1)
p1._d.errback(RandomError("boom"))
self.assertEqual(len(results), 0)
@ -433,12 +425,13 @@ class InboundConnectionFactory(unittest.TestCase):
d.addBoth(results.append)
self.assertEqual(results, [])
addr = address.HostnameAddress("example.com", 1234)
p1 = f.buildProtocol(addr)
p2 = f.buildProtocol(addr)
addr1 = address.HostnameAddress("example.com", 1234)
addr2 = address.HostnameAddress("example.com", 5678)
p1 = f.buildProtocol(addr1)
p2 = f.buildProtocol(addr2)
f.connectionWasMade(p1, "desc1")
f.connectionWasMade(p2, "desc2")
f.connectionWasMade(p1)
f.connectionWasMade(p2)
self.assertEqual(results, [])
d.cancel()
@ -454,7 +447,8 @@ class InboundConnectionFactory(unittest.TestCase):
class OutboundConnectionFactory(unittest.TestCase):
def test_success(self):
f = transit.OutboundConnectionFactory("owner", "relay_handshake")
f = transit.OutboundConnectionFactory("owner", "relay_handshake",
"description")
f.protocol = MockConnection
addr = address.HostnameAddress("example.com", 1234)
@ -466,16 +460,15 @@ class OutboundConnectionFactory(unittest.TestCase):
# meh .start
# this is normally called from Connection.connectionMade
f.connectionWasMade(p, "desc") # no-op for outbound
f.connectionWasMade(p) # no-op for outbound
self.assertEqual(p._start_negotiation_called, False)
class MockOwner:
_connection_ready_called = False
def connection_ready(self, connection, description):
def connection_ready(self, connection):
self._connection_ready_called = True
self._connection = connection
self._description = description
return self._state
def _send_this(self):
return b"send_this"
@ -488,7 +481,7 @@ class MockOwner:
class MockFactory:
_connectionWasMade_called = False
def connectionWasMade(self, p, description):
def connectionWasMade(self, p):
self._connectionWasMade_called = True
self._p = p
@ -496,7 +489,7 @@ class Connection(unittest.TestCase):
# exercise the Connection protocol class
def test_check_and_remove(self):
c = transit.Connection(None, None, None)
c = transit.Connection(None, None, None, "description")
c.buf = b""
EXP = b"expectation"
self.assertFalse(c._check_and_remove(EXP))
@ -525,7 +518,7 @@ class Connection(unittest.TestCase):
owner = MockOwner()
factory = MockFactory()
addr = address.HostnameAddress("example.com", 1234)
c = transit.Connection(owner, relay_handshake, None)
c = transit.Connection(owner, relay_handshake, None, "description")
self.assertEqual(c.state, "too-early")
t = c.transport = FakeTransport(c, addr)
c.factory = factory
@ -534,7 +527,7 @@ class Connection(unittest.TestCase):
self.assertEqual(factory._p, c)
owner._state = "go"
d = c.startNegotiation("description")
d = c.startNegotiation()
self.assertEqual(c.state, "handshake")
self.assertEqual(t.read_buf(), b"send_this")
results = []
@ -555,7 +548,7 @@ class Connection(unittest.TestCase):
owner = MockOwner()
factory = MockFactory()
addr = address.HostnameAddress("example.com", 1234)
c = transit.Connection(owner, relay_handshake, None)
c = transit.Connection(owner, relay_handshake, None, "description")
self.assertEqual(c.state, "too-early")
t = c.transport = FakeTransport(c, addr)
c.factory = factory
@ -564,7 +557,7 @@ class Connection(unittest.TestCase):
self.assertEqual(factory._p, c)
owner._state = "nevermind"
d = c.startNegotiation("description")
d = c.startNegotiation()
self.assertEqual(c.state, "handshake")
self.assertEqual(t.read_buf(), b"send_this")
results = []
@ -585,7 +578,7 @@ class Connection(unittest.TestCase):
owner = MockOwner()
factory = MockFactory()
addr = address.HostnameAddress("example.com", 1234)
c = transit.Connection(owner, None, None)
c = transit.Connection(owner, None, None, "description")
self.assertEqual(c.state, "too-early")
t = c.transport = FakeTransport(c, addr)
c.factory = factory
@ -593,7 +586,7 @@ class Connection(unittest.TestCase):
self.assertEqual(factory._connectionWasMade_called, True)
self.assertEqual(factory._p, c)
d = c.startNegotiation("description")
d = c.startNegotiation()
self.assertEqual(c.state, "handshake")
self.assertEqual(t.read_buf(), b"send_this")
results = []
@ -613,7 +606,7 @@ class Connection(unittest.TestCase):
owner = MockOwner()
factory = MockFactory()
addr = address.HostnameAddress("example.com", 1234)
c = transit.Connection(owner, relay_handshake, None)
c = transit.Connection(owner, relay_handshake, None, "description")
self.assertEqual(c.state, "too-early")
t = c.transport = FakeTransport(c, addr)
c.factory = factory
@ -623,7 +616,7 @@ class Connection(unittest.TestCase):
self.assertEqual(t.read_buf(), b"") # quiet until startNegotiation
owner._state = "go"
d = c.startNegotiation("description")
d = c.startNegotiation()
self.assertEqual(t.read_buf(), relay_handshake)
self.assertEqual(c.state, "relay") # waiting for OK from relay
@ -646,7 +639,7 @@ class Connection(unittest.TestCase):
owner = MockOwner()
factory = MockFactory()
addr = address.HostnameAddress("example.com", 1234)
c = transit.Connection(owner, relay_handshake, None)
c = transit.Connection(owner, relay_handshake, None, "description")
self.assertEqual(c.state, "too-early")
t = c.transport = FakeTransport(c, addr)
c.factory = factory
@ -656,7 +649,7 @@ class Connection(unittest.TestCase):
self.assertEqual(t.read_buf(), b"") # quiet until startNegotiation
owner._state = "go"
d = c.startNegotiation("description")
d = c.startNegotiation()
self.assertEqual(t.read_buf(), relay_handshake)
self.assertEqual(c.state, "relay") # waiting for OK from relay
@ -678,7 +671,7 @@ class Connection(unittest.TestCase):
owner = MockOwner()
factory = MockFactory()
addr = address.HostnameAddress("example.com", 1234)
c = transit.Connection(owner, None, None)
c = transit.Connection(owner, None, None, "description")
self.assertEqual(c.state, "too-early")
t = c.transport = FakeTransport(c, addr)
c.factory = factory
@ -687,7 +680,7 @@ class Connection(unittest.TestCase):
self.assertEqual(factory._p, c)
owner._state = "wait-for-decision"
d = c.startNegotiation("description")
d = c.startNegotiation()
self.assertEqual(c.state, "handshake")
self.assertEqual(t.read_buf(), b"send_this")
results = []
@ -707,7 +700,7 @@ class Connection(unittest.TestCase):
owner = MockOwner()
factory = MockFactory()
addr = address.HostnameAddress("example.com", 1234)
c = transit.Connection(owner, None, None)
c = transit.Connection(owner, None, None, "description")
self.assertEqual(c.state, "too-early")
t = c.transport = FakeTransport(c, addr)
c.factory = factory
@ -716,7 +709,7 @@ class Connection(unittest.TestCase):
self.assertEqual(factory._p, c)
owner._state = "wait-for-decision"
d = c.startNegotiation("description")
d = c.startNegotiation()
self.assertEqual(c.state, "handshake")
self.assertEqual(t.read_buf(), b"send_this")
results = []
@ -742,7 +735,7 @@ class Connection(unittest.TestCase):
owner = MockOwner()
factory = MockFactory()
addr = address.HostnameAddress("example.com", 1234)
c = transit.Connection(owner, None, None)
c = transit.Connection(owner, None, None, "description")
self.assertEqual(c.state, "too-early")
t = c.transport = FakeTransport(c, addr)
c.factory = factory
@ -751,7 +744,7 @@ class Connection(unittest.TestCase):
self.assertEqual(factory._p, c)
owner._state = "wait-for-decision"
d = c.startNegotiation("description")
d = c.startNegotiation()
self.assertEqual(c.state, "handshake")
self.assertEqual(t.read_buf(), b"send_this")
results = []
@ -775,13 +768,13 @@ class Connection(unittest.TestCase):
owner = MockOwner()
factory = MockFactory()
addr = address.HostnameAddress("example.com", 1234)
c = transit.Connection(owner, None, None)
c = transit.Connection(owner, None, None, "description")
self.assertEqual(c.state, "too-early")
t = c.transport = FakeTransport(c, addr)
c.factory = factory
c.connectionMade()
d = c.startNegotiation("description")
d = c.startNegotiation()
results = []
d.addBoth(results.append)
# while we're waiting for negotiation, we get cancelled
@ -799,7 +792,7 @@ class Connection(unittest.TestCase):
owner = MockOwner()
factory = MockFactory()
addr = address.HostnameAddress("example.com", 1234)
c = transit.Connection(owner, None, None)
c = transit.Connection(owner, None, None, "description")
def _callLater(period, func):
clock.callLater(period, func)
c.callLater = _callLater
@ -808,7 +801,7 @@ class Connection(unittest.TestCase):
c.factory = factory
c.connectionMade()
# the timer should now be running
d = c.startNegotiation("description")
d = c.startNegotiation()
results = []
d.addBoth(results.append)
# while we're waiting for negotiation, the timer expires
@ -825,13 +818,13 @@ class Connection(unittest.TestCase):
owner = MockOwner()
factory = MockFactory()
addr = address.HostnameAddress("example.com", 1234)
c = transit.Connection(owner, None, None)
c = transit.Connection(owner, None, None, "description")
t = c.transport = FakeTransport(c, addr)
c.factory = factory
c.connectionMade()
owner._state = "go"
d = c.startNegotiation("description")
d = c.startNegotiation()
results = []
d.addBoth(results.append)
c.dataReceived(b"expect_this")
@ -954,7 +947,7 @@ class Connection(unittest.TestCase):
# the key is None.
def test_receive_queue(self):
c = transit.Connection(None, None, None)
c = transit.Connection(None, None, None, "description")
c.transport = FakeTransport(c, None)
c.transport.signalConnectionLost = False
results = [[] for i in range(5)]
@ -995,7 +988,7 @@ class Connection(unittest.TestCase):
def test_producer(self):
# a Transit object (receiving data from the remote peer) produces
# data and writes it into a local Consumer
c = transit.Connection(None, None, None)
c = transit.Connection(None, None, None, "description")
c.transport = proto_helpers.StringTransport()
c.recordReceived(b"r1.")
c.recordReceived(b"r2.")
@ -1024,7 +1017,7 @@ class Connection(unittest.TestCase):
def test_consumer(self):
# a local producer sends data to a consuming Transit object
c = transit.Connection(None, None, None)
c = transit.Connection(None, None, None, "description")
c.transport = proto_helpers.StringTransport()
records = []
c.send_record = records.append

View File

@ -26,12 +26,13 @@ TIMEOUT=15
@implementer(interfaces.IProducer, interfaces.IConsumer)
class Connection(protocol.Protocol, policies.TimeoutMixin):
def __init__(self, owner, relay_handshake, start):
def __init__(self, owner, relay_handshake, start, description):
self.state = "too-early"
self.buf = b""
self.owner = owner
self.relay_handshake = relay_handshake
self.start = start
self._description = description
self._negotiation_d = defer.Deferred(self._cancel)
self._error = None
self._consumer = None
@ -41,10 +42,9 @@ class Connection(protocol.Protocol, policies.TimeoutMixin):
def connectionMade(self):
debug("handle %r" % (self.transport,))
self.setTimeout(TIMEOUT) # does timeoutConnection() when it expires
self.factory.connectionWasMade(self, self.transport.getPeer())
self.factory.connectionWasMade(self)
def startNegotiation(self, description):
self.description = description
def startNegotiation(self):
if self.relay_handshake is not None:
self.transport.write(self.relay_handshake)
self.state = "relay"
@ -104,7 +104,7 @@ class Connection(protocol.Protocol, policies.TimeoutMixin):
if self.state == "handshake":
if not self._check_and_remove(self.owner._expect_this()):
return
self.state = self.owner.connection_ready(self, self.description)
self.state = self.owner.connection_ready(self)
# If we're the receiver, we'll be moved to state
# "wait-for-decision", which means we're waiting for the other
# side (the sender) to make a decision. If we're the sender,
@ -161,6 +161,9 @@ class Connection(protocol.Protocol, policies.TimeoutMixin):
record = self.receive_box.decrypt(encrypted)
return record
def describe(self):
return self._description
def send_record(self, record):
if not isinstance(record, type(b"")): raise UsageError
assert SecretBox.NONCE_SIZE == 24
@ -260,17 +263,19 @@ class Connection(protocol.Protocol, policies.TimeoutMixin):
class OutboundConnectionFactory(protocol.ClientFactory):
protocol = Connection
def __init__(self, owner, relay_handshake):
def __init__(self, owner, relay_handshake, description):
self.owner = owner
self.relay_handshake = relay_handshake
self._description = description
self.start = time.time()
def buildProtocol(self, addr):
p = self.protocol(self.owner, self.relay_handshake, self.start)
p = self.protocol(self.owner, self.relay_handshake, self.start,
self._description)
p.factory = self
return p
def connectionWasMade(self, p, addr):
def connectionWasMade(self, p):
# outbound connections are handled via the endpoint
pass
@ -295,7 +300,7 @@ class InboundConnectionFactory(protocol.ClientFactory):
for d in list(self._pending_connections):
d.cancel() # that fires _remove and _proto_failed
def describePeer(self, addr):
def _describePeer(self, addr):
if isinstance(addr, address.HostnameAddress):
return "<-%s:%d" % (addr.hostname, addr.port)
elif isinstance(addr, (address.IPv4Address, address.IPv6Address)):
@ -303,12 +308,13 @@ class InboundConnectionFactory(protocol.ClientFactory):
return "<-%r" % addr
def buildProtocol(self, addr):
p = self.protocol(self.owner, None, self.start)
p = self.protocol(self.owner, None, self.start,
self._describePeer(addr))
p.factory = self
return p
def connectionWasMade(self, p, addr):
d = p.startNegotiation(self.describePeer(addr))
def connectionWasMade(self, p):
d = p.startNegotiation()
self._pending_connections.add(d)
d.addBoth(self._remove, d)
d.addCallbacks(self._proto_succeeded, self._proto_failed)
@ -413,7 +419,6 @@ class Common:
self._waiting_for_transit_key = []
self._listener = None
self._winner = None
self._winner_description = None
self._reactor = reactor
def _build_listener(self):
@ -596,10 +601,10 @@ class Common:
if is_relay:
assert self._transit_key
relay_handshake = build_relay_handshake(self._transit_key)
f = OutboundConnectionFactory(self, relay_handshake)
f = OutboundConnectionFactory(self, relay_handshake, description)
d = ep.connect(f)
# fires with protocol, or ConnectError
d.addCallback(lambda p: p.startNegotiation(description))
d.addCallback(lambda p: p.startNegotiation())
return d
def _endpoint_from_hint(self, hint):
@ -613,7 +618,7 @@ class Common:
return endpoints.HostnameEndpoint(self._reactor, pieces[1],
int(pieces[2]))
def connection_ready(self, p, description):
def connection_ready(self, p):
# inbound/outbound Connection protocols call this when they finish
# negotiation. The first one wins and gets a "go". Any subsequent
# ones lose and get a "nevermind" before being closed.
@ -626,14 +631,8 @@ class Common:
return "nevermind"
# this one wins!
self._winner = p
self._winner_description = description
return "go"
def describe(self):
if not self._winner:
return "not yet established"
return self._winner_description
class TransitSender(Common):
is_sender = True