From 6654efb429bc79de6bf9b6c3efa22377d7684730 Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Sun, 28 Feb 2016 01:37:52 -0800 Subject: [PATCH] move describe() from Transit to RecordPipe --- src/wormhole/blocking/transit.py | 14 +- src/wormhole/scripts/cmd_receive_blocking.py | 2 +- src/wormhole/scripts/cmd_send_blocking.py | 2 +- src/wormhole/scripts/cmd_send_twisted.py | 2 +- src/wormhole/test/test_transit_twisted.py | 131 +++++++++---------- src/wormhole/twisted/transit.py | 43 +++--- 6 files changed, 93 insertions(+), 101 deletions(-) diff --git a/src/wormhole/blocking/transit.py b/src/wormhole/blocking/transit.py index 2e74ae7..c4a8218 100644 --- a/src/wormhole/blocking/transit.py +++ b/src/wormhole/blocking/transit.py @@ -166,13 +166,17 @@ class ReceiveBuffer: return rc class RecordPipe: - def __init__(self, skt, send_key, receive_key): + def __init__(self, skt, send_key, receive_key, description): self.skt = skt self.send_box = SecretBox(send_key) self.send_nonce = 0 self.receive_buf = ReceiveBuffer(self.skt) self.receive_box = SecretBox(receive_key) self.next_receive_nonce = 0 + self._description = description + + def describe(self): + return self._description def send_record(self, record): if not isinstance(record, type(b"")): raise UsageError @@ -338,11 +342,6 @@ class Common: return self.winning_skt raise TransitError("timeout") - def describe(self): - if not self.winning_skt_description: - return "not yet established" - return self.winning_skt_description - def _connector_failed(self, hint): debug("- failed connector %s" % hint) # XXX this was .remove, and occasionally got KeyError @@ -375,7 +374,8 @@ class Common: def connect(self): skt = self.establish_socket() return RecordPipe(skt, self._sender_record_key(), - self._receiver_record_key()) + self._receiver_record_key(), + self.winning_skt_description) class TransitSender(Common): is_sender = True diff --git a/src/wormhole/scripts/cmd_receive_blocking.py b/src/wormhole/scripts/cmd_receive_blocking.py index 9c883c2..562a35d 100644 --- a/src/wormhole/scripts/cmd_receive_blocking.py +++ b/src/wormhole/scripts/cmd_receive_blocking.py @@ -122,7 +122,7 @@ def receive_blocking(args): record_pipe = transit_receiver.connect() print(u"Receiving %d bytes for '%s' (%s).." % - (xfersize, destname, transit_receiver.describe()), + (xfersize, destname, record_pipe.describe()), file=args.stdout) if mode == "file": tmp_destname = abs_destname + ".tmp" diff --git a/src/wormhole/scripts/cmd_send_blocking.py b/src/wormhole/scripts/cmd_send_blocking.py index 53bb874..436013e 100644 --- a/src/wormhole/scripts/cmd_send_blocking.py +++ b/src/wormhole/scripts/cmd_send_blocking.py @@ -92,7 +92,7 @@ def _send_file_blocking(them_phase1, fd_to_send, transit_sender, transit_sender.add_their_relay_hints(tdata["relay_connection_hints"]) record_pipe = transit_sender.connect() - print(u"Sending (%s).." % transit_sender.describe(), file=stdout) + print(u"Sending (%s).." % record_pipe.describe(), file=stdout) CHUNKSIZE = 64*1024 fd_to_send.seek(0,2) diff --git a/src/wormhole/scripts/cmd_send_twisted.py b/src/wormhole/scripts/cmd_send_twisted.py index cf21258..de70f82 100644 --- a/src/wormhole/scripts/cmd_send_twisted.py +++ b/src/wormhole/scripts/cmd_send_twisted.py @@ -138,7 +138,7 @@ def _send_file_twisted(tdata, transit_sender, fd_to_send, record_pipe = yield transit_sender.connect() # record_pipe should implement IConsumer, chunks are just records - print(u"Sending (%s).." % transit_sender.describe(), file=stdout) + print(u"Sending (%s).." % record_pipe.describe(), file=stdout) yield pfs.beginFileTransfer(fd_to_send, record_pipe) print(u"File sent.. waiting for confirmation", file=stdout) ack = yield record_pipe.receive_record() diff --git a/src/wormhole/test/test_transit_twisted.py b/src/wormhole/test/test_transit_twisted.py index 56878cc..0bb53a3 100644 --- a/src/wormhole/test/test_transit_twisted.py +++ b/src/wormhole/test/test_transit_twisted.py @@ -197,24 +197,14 @@ class Basic(unittest.TestCase): def test_connection_ready(self): s = transit.TransitSender(u"") - self.assertEqual(s.describe(), "not yet established") - - self.assertEqual(s.connection_ready("p1", "desc1"), "go") - self.assertEqual(s.describe(), "desc1") + self.assertEqual(s.connection_ready("p1"), "go") self.assertEqual(s._winner, "p1") - - self.assertEqual(s.connection_ready("p2", "desc2"), "nevermind") - self.assertEqual(s.describe(), "desc1") + self.assertEqual(s.connection_ready("p2"), "nevermind") self.assertEqual(s._winner, "p1") r = transit.TransitReceiver(u"") - self.assertEqual(r.describe(), "not yet established") - - self.assertEqual(r.connection_ready("p1", "desc1"), "wait-for-decision") - self.assertEqual(r.describe(), "not yet established") - - self.assertEqual(r.connection_ready("p2", "desc2"), "wait-for-decision") - self.assertEqual(r.describe(), "not yet established") + self.assertEqual(r.connection_ready("p1"), "wait-for-decision") + self.assertEqual(r.connection_ready("p2"), "wait-for-decision") class Listener(unittest.TestCase): @@ -305,32 +295,32 @@ class RandomError(Exception): pass class MockConnection: - def __init__(self, owner, relay_handshake, start): + def __init__(self, owner, relay_handshake, start, description): self.owner = owner self.relay_handshake = relay_handshake self.start = start + self._description = description def cancel(d): self._cancelled = True self._d = defer.Deferred(cancel) self._start_negotiation_called = False self._cancelled = False - def startNegotiation(self, description): + def startNegotiation(self): self._start_negotiation_called = True - self._description = description return self._d class InboundConnectionFactory(unittest.TestCase): def test_describe(self): f = transit.InboundConnectionFactory(None) addrH = address.HostnameAddress("example.com", 1234) - self.assertEqual(f.describePeer(addrH), "<-example.com:1234") + self.assertEqual(f._describePeer(addrH), "<-example.com:1234") addr4 = address.IPv4Address("TCP", "1.2.3.4", 1234) - self.assertEqual(f.describePeer(addr4), "<-1.2.3.4:1234") + self.assertEqual(f._describePeer(addr4), "<-1.2.3.4:1234") addr6 = address.IPv6Address("TCP", "::1", 1234) - self.assertEqual(f.describePeer(addr6), "<-::1:1234") + self.assertEqual(f._describePeer(addr6), "<-::1:1234") addrU = address.UNIXAddress("/dev/unlikely") - self.assertEqual(f.describePeer(addrU), + self.assertEqual(f._describePeer(addrU), "<-UNIXAddress('/dev/unlikely')") def test_success(self): @@ -350,7 +340,7 @@ class InboundConnectionFactory(unittest.TestCase): # meh .start # this is normally called from Connection.connectionMade - f.connectionWasMade(p, addr) + f.connectionWasMade(p) self.assertEqual(p._start_negotiation_called, True) self.assertEqual(results, []) self.assertEqual(p._description, "<-example.com:1234") @@ -366,12 +356,13 @@ class InboundConnectionFactory(unittest.TestCase): d.addBoth(results.append) self.assertEqual(results, []) - addr = address.HostnameAddress("example.com", 1234) - p1 = f.buildProtocol(addr) - p2 = f.buildProtocol(addr) + addr1 = address.HostnameAddress("example.com", 1234) + addr2 = address.HostnameAddress("example.com", 5678) + p1 = f.buildProtocol(addr1) + p2 = f.buildProtocol(addr2) - f.connectionWasMade(p1, "desc1") - f.connectionWasMade(p2, "desc2") + f.connectionWasMade(p1) + f.connectionWasMade(p2) self.assertEqual(results, []) p1._d.errback(transit.BadHandshake("nope")) @@ -387,12 +378,13 @@ class InboundConnectionFactory(unittest.TestCase): d.addBoth(results.append) self.assertEqual(results, []) - addr = address.HostnameAddress("example.com", 1234) - p1 = f.buildProtocol(addr) - p2 = f.buildProtocol(addr) + addr1 = address.HostnameAddress("example.com", 1234) + addr2 = address.HostnameAddress("example.com", 5678) + p1 = f.buildProtocol(addr1) + p2 = f.buildProtocol(addr2) - f.connectionWasMade(p1, "desc1") - f.connectionWasMade(p2, "desc2") + f.connectionWasMade(p1) + f.connectionWasMade(p2) self.assertEqual(results, []) p1._d.callback(p1) @@ -414,7 +406,7 @@ class InboundConnectionFactory(unittest.TestCase): # if the Connection protocol throws an unexpected error, that should # get logged to the Twisted logs (as an Unhandled Error in Deferred) # so we can diagnose the bug - f.connectionWasMade(p1, "desc1") + f.connectionWasMade(p1) p1._d.errback(RandomError("boom")) self.assertEqual(len(results), 0) @@ -433,12 +425,13 @@ class InboundConnectionFactory(unittest.TestCase): d.addBoth(results.append) self.assertEqual(results, []) - addr = address.HostnameAddress("example.com", 1234) - p1 = f.buildProtocol(addr) - p2 = f.buildProtocol(addr) + addr1 = address.HostnameAddress("example.com", 1234) + addr2 = address.HostnameAddress("example.com", 5678) + p1 = f.buildProtocol(addr1) + p2 = f.buildProtocol(addr2) - f.connectionWasMade(p1, "desc1") - f.connectionWasMade(p2, "desc2") + f.connectionWasMade(p1) + f.connectionWasMade(p2) self.assertEqual(results, []) d.cancel() @@ -454,7 +447,8 @@ class InboundConnectionFactory(unittest.TestCase): class OutboundConnectionFactory(unittest.TestCase): def test_success(self): - f = transit.OutboundConnectionFactory("owner", "relay_handshake") + f = transit.OutboundConnectionFactory("owner", "relay_handshake", + "description") f.protocol = MockConnection addr = address.HostnameAddress("example.com", 1234) @@ -466,16 +460,15 @@ class OutboundConnectionFactory(unittest.TestCase): # meh .start # this is normally called from Connection.connectionMade - f.connectionWasMade(p, "desc") # no-op for outbound + f.connectionWasMade(p) # no-op for outbound self.assertEqual(p._start_negotiation_called, False) class MockOwner: _connection_ready_called = False - def connection_ready(self, connection, description): + def connection_ready(self, connection): self._connection_ready_called = True self._connection = connection - self._description = description return self._state def _send_this(self): return b"send_this" @@ -488,7 +481,7 @@ class MockOwner: class MockFactory: _connectionWasMade_called = False - def connectionWasMade(self, p, description): + def connectionWasMade(self, p): self._connectionWasMade_called = True self._p = p @@ -496,7 +489,7 @@ class Connection(unittest.TestCase): # exercise the Connection protocol class def test_check_and_remove(self): - c = transit.Connection(None, None, None) + c = transit.Connection(None, None, None, "description") c.buf = b"" EXP = b"expectation" self.assertFalse(c._check_and_remove(EXP)) @@ -525,7 +518,7 @@ class Connection(unittest.TestCase): owner = MockOwner() factory = MockFactory() addr = address.HostnameAddress("example.com", 1234) - c = transit.Connection(owner, relay_handshake, None) + c = transit.Connection(owner, relay_handshake, None, "description") self.assertEqual(c.state, "too-early") t = c.transport = FakeTransport(c, addr) c.factory = factory @@ -534,7 +527,7 @@ class Connection(unittest.TestCase): self.assertEqual(factory._p, c) owner._state = "go" - d = c.startNegotiation("description") + d = c.startNegotiation() self.assertEqual(c.state, "handshake") self.assertEqual(t.read_buf(), b"send_this") results = [] @@ -555,7 +548,7 @@ class Connection(unittest.TestCase): owner = MockOwner() factory = MockFactory() addr = address.HostnameAddress("example.com", 1234) - c = transit.Connection(owner, relay_handshake, None) + c = transit.Connection(owner, relay_handshake, None, "description") self.assertEqual(c.state, "too-early") t = c.transport = FakeTransport(c, addr) c.factory = factory @@ -564,7 +557,7 @@ class Connection(unittest.TestCase): self.assertEqual(factory._p, c) owner._state = "nevermind" - d = c.startNegotiation("description") + d = c.startNegotiation() self.assertEqual(c.state, "handshake") self.assertEqual(t.read_buf(), b"send_this") results = [] @@ -585,7 +578,7 @@ class Connection(unittest.TestCase): owner = MockOwner() factory = MockFactory() addr = address.HostnameAddress("example.com", 1234) - c = transit.Connection(owner, None, None) + c = transit.Connection(owner, None, None, "description") self.assertEqual(c.state, "too-early") t = c.transport = FakeTransport(c, addr) c.factory = factory @@ -593,7 +586,7 @@ class Connection(unittest.TestCase): self.assertEqual(factory._connectionWasMade_called, True) self.assertEqual(factory._p, c) - d = c.startNegotiation("description") + d = c.startNegotiation() self.assertEqual(c.state, "handshake") self.assertEqual(t.read_buf(), b"send_this") results = [] @@ -613,7 +606,7 @@ class Connection(unittest.TestCase): owner = MockOwner() factory = MockFactory() addr = address.HostnameAddress("example.com", 1234) - c = transit.Connection(owner, relay_handshake, None) + c = transit.Connection(owner, relay_handshake, None, "description") self.assertEqual(c.state, "too-early") t = c.transport = FakeTransport(c, addr) c.factory = factory @@ -623,7 +616,7 @@ class Connection(unittest.TestCase): self.assertEqual(t.read_buf(), b"") # quiet until startNegotiation owner._state = "go" - d = c.startNegotiation("description") + d = c.startNegotiation() self.assertEqual(t.read_buf(), relay_handshake) self.assertEqual(c.state, "relay") # waiting for OK from relay @@ -646,7 +639,7 @@ class Connection(unittest.TestCase): owner = MockOwner() factory = MockFactory() addr = address.HostnameAddress("example.com", 1234) - c = transit.Connection(owner, relay_handshake, None) + c = transit.Connection(owner, relay_handshake, None, "description") self.assertEqual(c.state, "too-early") t = c.transport = FakeTransport(c, addr) c.factory = factory @@ -656,7 +649,7 @@ class Connection(unittest.TestCase): self.assertEqual(t.read_buf(), b"") # quiet until startNegotiation owner._state = "go" - d = c.startNegotiation("description") + d = c.startNegotiation() self.assertEqual(t.read_buf(), relay_handshake) self.assertEqual(c.state, "relay") # waiting for OK from relay @@ -678,7 +671,7 @@ class Connection(unittest.TestCase): owner = MockOwner() factory = MockFactory() addr = address.HostnameAddress("example.com", 1234) - c = transit.Connection(owner, None, None) + c = transit.Connection(owner, None, None, "description") self.assertEqual(c.state, "too-early") t = c.transport = FakeTransport(c, addr) c.factory = factory @@ -687,7 +680,7 @@ class Connection(unittest.TestCase): self.assertEqual(factory._p, c) owner._state = "wait-for-decision" - d = c.startNegotiation("description") + d = c.startNegotiation() self.assertEqual(c.state, "handshake") self.assertEqual(t.read_buf(), b"send_this") results = [] @@ -707,7 +700,7 @@ class Connection(unittest.TestCase): owner = MockOwner() factory = MockFactory() addr = address.HostnameAddress("example.com", 1234) - c = transit.Connection(owner, None, None) + c = transit.Connection(owner, None, None, "description") self.assertEqual(c.state, "too-early") t = c.transport = FakeTransport(c, addr) c.factory = factory @@ -716,7 +709,7 @@ class Connection(unittest.TestCase): self.assertEqual(factory._p, c) owner._state = "wait-for-decision" - d = c.startNegotiation("description") + d = c.startNegotiation() self.assertEqual(c.state, "handshake") self.assertEqual(t.read_buf(), b"send_this") results = [] @@ -742,7 +735,7 @@ class Connection(unittest.TestCase): owner = MockOwner() factory = MockFactory() addr = address.HostnameAddress("example.com", 1234) - c = transit.Connection(owner, None, None) + c = transit.Connection(owner, None, None, "description") self.assertEqual(c.state, "too-early") t = c.transport = FakeTransport(c, addr) c.factory = factory @@ -751,7 +744,7 @@ class Connection(unittest.TestCase): self.assertEqual(factory._p, c) owner._state = "wait-for-decision" - d = c.startNegotiation("description") + d = c.startNegotiation() self.assertEqual(c.state, "handshake") self.assertEqual(t.read_buf(), b"send_this") results = [] @@ -775,13 +768,13 @@ class Connection(unittest.TestCase): owner = MockOwner() factory = MockFactory() addr = address.HostnameAddress("example.com", 1234) - c = transit.Connection(owner, None, None) + c = transit.Connection(owner, None, None, "description") self.assertEqual(c.state, "too-early") t = c.transport = FakeTransport(c, addr) c.factory = factory c.connectionMade() - d = c.startNegotiation("description") + d = c.startNegotiation() results = [] d.addBoth(results.append) # while we're waiting for negotiation, we get cancelled @@ -799,7 +792,7 @@ class Connection(unittest.TestCase): owner = MockOwner() factory = MockFactory() addr = address.HostnameAddress("example.com", 1234) - c = transit.Connection(owner, None, None) + c = transit.Connection(owner, None, None, "description") def _callLater(period, func): clock.callLater(period, func) c.callLater = _callLater @@ -808,7 +801,7 @@ class Connection(unittest.TestCase): c.factory = factory c.connectionMade() # the timer should now be running - d = c.startNegotiation("description") + d = c.startNegotiation() results = [] d.addBoth(results.append) # while we're waiting for negotiation, the timer expires @@ -825,13 +818,13 @@ class Connection(unittest.TestCase): owner = MockOwner() factory = MockFactory() addr = address.HostnameAddress("example.com", 1234) - c = transit.Connection(owner, None, None) + c = transit.Connection(owner, None, None, "description") t = c.transport = FakeTransport(c, addr) c.factory = factory c.connectionMade() owner._state = "go" - d = c.startNegotiation("description") + d = c.startNegotiation() results = [] d.addBoth(results.append) c.dataReceived(b"expect_this") @@ -954,7 +947,7 @@ class Connection(unittest.TestCase): # the key is None. def test_receive_queue(self): - c = transit.Connection(None, None, None) + c = transit.Connection(None, None, None, "description") c.transport = FakeTransport(c, None) c.transport.signalConnectionLost = False results = [[] for i in range(5)] @@ -995,7 +988,7 @@ class Connection(unittest.TestCase): def test_producer(self): # a Transit object (receiving data from the remote peer) produces # data and writes it into a local Consumer - c = transit.Connection(None, None, None) + c = transit.Connection(None, None, None, "description") c.transport = proto_helpers.StringTransport() c.recordReceived(b"r1.") c.recordReceived(b"r2.") @@ -1024,7 +1017,7 @@ class Connection(unittest.TestCase): def test_consumer(self): # a local producer sends data to a consuming Transit object - c = transit.Connection(None, None, None) + c = transit.Connection(None, None, None, "description") c.transport = proto_helpers.StringTransport() records = [] c.send_record = records.append diff --git a/src/wormhole/twisted/transit.py b/src/wormhole/twisted/transit.py index e5c684f..942fcd1 100644 --- a/src/wormhole/twisted/transit.py +++ b/src/wormhole/twisted/transit.py @@ -26,12 +26,13 @@ TIMEOUT=15 @implementer(interfaces.IProducer, interfaces.IConsumer) class Connection(protocol.Protocol, policies.TimeoutMixin): - def __init__(self, owner, relay_handshake, start): + def __init__(self, owner, relay_handshake, start, description): self.state = "too-early" self.buf = b"" self.owner = owner self.relay_handshake = relay_handshake self.start = start + self._description = description self._negotiation_d = defer.Deferred(self._cancel) self._error = None self._consumer = None @@ -41,10 +42,9 @@ class Connection(protocol.Protocol, policies.TimeoutMixin): def connectionMade(self): debug("handle %r" % (self.transport,)) self.setTimeout(TIMEOUT) # does timeoutConnection() when it expires - self.factory.connectionWasMade(self, self.transport.getPeer()) + self.factory.connectionWasMade(self) - def startNegotiation(self, description): - self.description = description + def startNegotiation(self): if self.relay_handshake is not None: self.transport.write(self.relay_handshake) self.state = "relay" @@ -104,7 +104,7 @@ class Connection(protocol.Protocol, policies.TimeoutMixin): if self.state == "handshake": if not self._check_and_remove(self.owner._expect_this()): return - self.state = self.owner.connection_ready(self, self.description) + self.state = self.owner.connection_ready(self) # If we're the receiver, we'll be moved to state # "wait-for-decision", which means we're waiting for the other # side (the sender) to make a decision. If we're the sender, @@ -161,6 +161,9 @@ class Connection(protocol.Protocol, policies.TimeoutMixin): record = self.receive_box.decrypt(encrypted) return record + def describe(self): + return self._description + def send_record(self, record): if not isinstance(record, type(b"")): raise UsageError assert SecretBox.NONCE_SIZE == 24 @@ -260,17 +263,19 @@ class Connection(protocol.Protocol, policies.TimeoutMixin): class OutboundConnectionFactory(protocol.ClientFactory): protocol = Connection - def __init__(self, owner, relay_handshake): + def __init__(self, owner, relay_handshake, description): self.owner = owner self.relay_handshake = relay_handshake + self._description = description self.start = time.time() def buildProtocol(self, addr): - p = self.protocol(self.owner, self.relay_handshake, self.start) + p = self.protocol(self.owner, self.relay_handshake, self.start, + self._description) p.factory = self return p - def connectionWasMade(self, p, addr): + def connectionWasMade(self, p): # outbound connections are handled via the endpoint pass @@ -295,7 +300,7 @@ class InboundConnectionFactory(protocol.ClientFactory): for d in list(self._pending_connections): d.cancel() # that fires _remove and _proto_failed - def describePeer(self, addr): + def _describePeer(self, addr): if isinstance(addr, address.HostnameAddress): return "<-%s:%d" % (addr.hostname, addr.port) elif isinstance(addr, (address.IPv4Address, address.IPv6Address)): @@ -303,12 +308,13 @@ class InboundConnectionFactory(protocol.ClientFactory): return "<-%r" % addr def buildProtocol(self, addr): - p = self.protocol(self.owner, None, self.start) + p = self.protocol(self.owner, None, self.start, + self._describePeer(addr)) p.factory = self return p - def connectionWasMade(self, p, addr): - d = p.startNegotiation(self.describePeer(addr)) + def connectionWasMade(self, p): + d = p.startNegotiation() self._pending_connections.add(d) d.addBoth(self._remove, d) d.addCallbacks(self._proto_succeeded, self._proto_failed) @@ -413,7 +419,6 @@ class Common: self._waiting_for_transit_key = [] self._listener = None self._winner = None - self._winner_description = None self._reactor = reactor def _build_listener(self): @@ -596,10 +601,10 @@ class Common: if is_relay: assert self._transit_key relay_handshake = build_relay_handshake(self._transit_key) - f = OutboundConnectionFactory(self, relay_handshake) + f = OutboundConnectionFactory(self, relay_handshake, description) d = ep.connect(f) # fires with protocol, or ConnectError - d.addCallback(lambda p: p.startNegotiation(description)) + d.addCallback(lambda p: p.startNegotiation()) return d def _endpoint_from_hint(self, hint): @@ -613,7 +618,7 @@ class Common: return endpoints.HostnameEndpoint(self._reactor, pieces[1], int(pieces[2])) - def connection_ready(self, p, description): + def connection_ready(self, p): # inbound/outbound Connection protocols call this when they finish # negotiation. The first one wins and gets a "go". Any subsequent # ones lose and get a "nevermind" before being closed. @@ -626,14 +631,8 @@ class Common: return "nevermind" # this one wins! self._winner = p - self._winner_description = description return "go" - def describe(self): - if not self._winner: - return "not yet established" - return self._winner_description - class TransitSender(Common): is_sender = True