move describe() from Transit to RecordPipe
This commit is contained in:
parent
1903c58248
commit
6654efb429
|
@ -166,13 +166,17 @@ class ReceiveBuffer:
|
||||||
return rc
|
return rc
|
||||||
|
|
||||||
class RecordPipe:
|
class RecordPipe:
|
||||||
def __init__(self, skt, send_key, receive_key):
|
def __init__(self, skt, send_key, receive_key, description):
|
||||||
self.skt = skt
|
self.skt = skt
|
||||||
self.send_box = SecretBox(send_key)
|
self.send_box = SecretBox(send_key)
|
||||||
self.send_nonce = 0
|
self.send_nonce = 0
|
||||||
self.receive_buf = ReceiveBuffer(self.skt)
|
self.receive_buf = ReceiveBuffer(self.skt)
|
||||||
self.receive_box = SecretBox(receive_key)
|
self.receive_box = SecretBox(receive_key)
|
||||||
self.next_receive_nonce = 0
|
self.next_receive_nonce = 0
|
||||||
|
self._description = description
|
||||||
|
|
||||||
|
def describe(self):
|
||||||
|
return self._description
|
||||||
|
|
||||||
def send_record(self, record):
|
def send_record(self, record):
|
||||||
if not isinstance(record, type(b"")): raise UsageError
|
if not isinstance(record, type(b"")): raise UsageError
|
||||||
|
@ -338,11 +342,6 @@ class Common:
|
||||||
return self.winning_skt
|
return self.winning_skt
|
||||||
raise TransitError("timeout")
|
raise TransitError("timeout")
|
||||||
|
|
||||||
def describe(self):
|
|
||||||
if not self.winning_skt_description:
|
|
||||||
return "not yet established"
|
|
||||||
return self.winning_skt_description
|
|
||||||
|
|
||||||
def _connector_failed(self, hint):
|
def _connector_failed(self, hint):
|
||||||
debug("- failed connector %s" % hint)
|
debug("- failed connector %s" % hint)
|
||||||
# XXX this was .remove, and occasionally got KeyError
|
# XXX this was .remove, and occasionally got KeyError
|
||||||
|
@ -375,7 +374,8 @@ class Common:
|
||||||
def connect(self):
|
def connect(self):
|
||||||
skt = self.establish_socket()
|
skt = self.establish_socket()
|
||||||
return RecordPipe(skt, self._sender_record_key(),
|
return RecordPipe(skt, self._sender_record_key(),
|
||||||
self._receiver_record_key())
|
self._receiver_record_key(),
|
||||||
|
self.winning_skt_description)
|
||||||
|
|
||||||
class TransitSender(Common):
|
class TransitSender(Common):
|
||||||
is_sender = True
|
is_sender = True
|
||||||
|
|
|
@ -122,7 +122,7 @@ def receive_blocking(args):
|
||||||
record_pipe = transit_receiver.connect()
|
record_pipe = transit_receiver.connect()
|
||||||
|
|
||||||
print(u"Receiving %d bytes for '%s' (%s).." %
|
print(u"Receiving %d bytes for '%s' (%s).." %
|
||||||
(xfersize, destname, transit_receiver.describe()),
|
(xfersize, destname, record_pipe.describe()),
|
||||||
file=args.stdout)
|
file=args.stdout)
|
||||||
if mode == "file":
|
if mode == "file":
|
||||||
tmp_destname = abs_destname + ".tmp"
|
tmp_destname = abs_destname + ".tmp"
|
||||||
|
|
|
@ -92,7 +92,7 @@ def _send_file_blocking(them_phase1, fd_to_send, transit_sender,
|
||||||
transit_sender.add_their_relay_hints(tdata["relay_connection_hints"])
|
transit_sender.add_their_relay_hints(tdata["relay_connection_hints"])
|
||||||
record_pipe = transit_sender.connect()
|
record_pipe = transit_sender.connect()
|
||||||
|
|
||||||
print(u"Sending (%s).." % transit_sender.describe(), file=stdout)
|
print(u"Sending (%s).." % record_pipe.describe(), file=stdout)
|
||||||
|
|
||||||
CHUNKSIZE = 64*1024
|
CHUNKSIZE = 64*1024
|
||||||
fd_to_send.seek(0,2)
|
fd_to_send.seek(0,2)
|
||||||
|
|
|
@ -138,7 +138,7 @@ def _send_file_twisted(tdata, transit_sender, fd_to_send,
|
||||||
|
|
||||||
record_pipe = yield transit_sender.connect()
|
record_pipe = yield transit_sender.connect()
|
||||||
# record_pipe should implement IConsumer, chunks are just records
|
# record_pipe should implement IConsumer, chunks are just records
|
||||||
print(u"Sending (%s).." % transit_sender.describe(), file=stdout)
|
print(u"Sending (%s).." % record_pipe.describe(), file=stdout)
|
||||||
yield pfs.beginFileTransfer(fd_to_send, record_pipe)
|
yield pfs.beginFileTransfer(fd_to_send, record_pipe)
|
||||||
print(u"File sent.. waiting for confirmation", file=stdout)
|
print(u"File sent.. waiting for confirmation", file=stdout)
|
||||||
ack = yield record_pipe.receive_record()
|
ack = yield record_pipe.receive_record()
|
||||||
|
|
|
@ -197,24 +197,14 @@ class Basic(unittest.TestCase):
|
||||||
|
|
||||||
def test_connection_ready(self):
|
def test_connection_ready(self):
|
||||||
s = transit.TransitSender(u"")
|
s = transit.TransitSender(u"")
|
||||||
self.assertEqual(s.describe(), "not yet established")
|
self.assertEqual(s.connection_ready("p1"), "go")
|
||||||
|
|
||||||
self.assertEqual(s.connection_ready("p1", "desc1"), "go")
|
|
||||||
self.assertEqual(s.describe(), "desc1")
|
|
||||||
self.assertEqual(s._winner, "p1")
|
self.assertEqual(s._winner, "p1")
|
||||||
|
self.assertEqual(s.connection_ready("p2"), "nevermind")
|
||||||
self.assertEqual(s.connection_ready("p2", "desc2"), "nevermind")
|
|
||||||
self.assertEqual(s.describe(), "desc1")
|
|
||||||
self.assertEqual(s._winner, "p1")
|
self.assertEqual(s._winner, "p1")
|
||||||
|
|
||||||
r = transit.TransitReceiver(u"")
|
r = transit.TransitReceiver(u"")
|
||||||
self.assertEqual(r.describe(), "not yet established")
|
self.assertEqual(r.connection_ready("p1"), "wait-for-decision")
|
||||||
|
self.assertEqual(r.connection_ready("p2"), "wait-for-decision")
|
||||||
self.assertEqual(r.connection_ready("p1", "desc1"), "wait-for-decision")
|
|
||||||
self.assertEqual(r.describe(), "not yet established")
|
|
||||||
|
|
||||||
self.assertEqual(r.connection_ready("p2", "desc2"), "wait-for-decision")
|
|
||||||
self.assertEqual(r.describe(), "not yet established")
|
|
||||||
|
|
||||||
|
|
||||||
class Listener(unittest.TestCase):
|
class Listener(unittest.TestCase):
|
||||||
|
@ -305,32 +295,32 @@ class RandomError(Exception):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
class MockConnection:
|
class MockConnection:
|
||||||
def __init__(self, owner, relay_handshake, start):
|
def __init__(self, owner, relay_handshake, start, description):
|
||||||
self.owner = owner
|
self.owner = owner
|
||||||
self.relay_handshake = relay_handshake
|
self.relay_handshake = relay_handshake
|
||||||
self.start = start
|
self.start = start
|
||||||
|
self._description = description
|
||||||
def cancel(d):
|
def cancel(d):
|
||||||
self._cancelled = True
|
self._cancelled = True
|
||||||
self._d = defer.Deferred(cancel)
|
self._d = defer.Deferred(cancel)
|
||||||
self._start_negotiation_called = False
|
self._start_negotiation_called = False
|
||||||
self._cancelled = False
|
self._cancelled = False
|
||||||
|
|
||||||
def startNegotiation(self, description):
|
def startNegotiation(self):
|
||||||
self._start_negotiation_called = True
|
self._start_negotiation_called = True
|
||||||
self._description = description
|
|
||||||
return self._d
|
return self._d
|
||||||
|
|
||||||
class InboundConnectionFactory(unittest.TestCase):
|
class InboundConnectionFactory(unittest.TestCase):
|
||||||
def test_describe(self):
|
def test_describe(self):
|
||||||
f = transit.InboundConnectionFactory(None)
|
f = transit.InboundConnectionFactory(None)
|
||||||
addrH = address.HostnameAddress("example.com", 1234)
|
addrH = address.HostnameAddress("example.com", 1234)
|
||||||
self.assertEqual(f.describePeer(addrH), "<-example.com:1234")
|
self.assertEqual(f._describePeer(addrH), "<-example.com:1234")
|
||||||
addr4 = address.IPv4Address("TCP", "1.2.3.4", 1234)
|
addr4 = address.IPv4Address("TCP", "1.2.3.4", 1234)
|
||||||
self.assertEqual(f.describePeer(addr4), "<-1.2.3.4:1234")
|
self.assertEqual(f._describePeer(addr4), "<-1.2.3.4:1234")
|
||||||
addr6 = address.IPv6Address("TCP", "::1", 1234)
|
addr6 = address.IPv6Address("TCP", "::1", 1234)
|
||||||
self.assertEqual(f.describePeer(addr6), "<-::1:1234")
|
self.assertEqual(f._describePeer(addr6), "<-::1:1234")
|
||||||
addrU = address.UNIXAddress("/dev/unlikely")
|
addrU = address.UNIXAddress("/dev/unlikely")
|
||||||
self.assertEqual(f.describePeer(addrU),
|
self.assertEqual(f._describePeer(addrU),
|
||||||
"<-UNIXAddress('/dev/unlikely')")
|
"<-UNIXAddress('/dev/unlikely')")
|
||||||
|
|
||||||
def test_success(self):
|
def test_success(self):
|
||||||
|
@ -350,7 +340,7 @@ class InboundConnectionFactory(unittest.TestCase):
|
||||||
# meh .start
|
# meh .start
|
||||||
|
|
||||||
# this is normally called from Connection.connectionMade
|
# this is normally called from Connection.connectionMade
|
||||||
f.connectionWasMade(p, addr)
|
f.connectionWasMade(p)
|
||||||
self.assertEqual(p._start_negotiation_called, True)
|
self.assertEqual(p._start_negotiation_called, True)
|
||||||
self.assertEqual(results, [])
|
self.assertEqual(results, [])
|
||||||
self.assertEqual(p._description, "<-example.com:1234")
|
self.assertEqual(p._description, "<-example.com:1234")
|
||||||
|
@ -366,12 +356,13 @@ class InboundConnectionFactory(unittest.TestCase):
|
||||||
d.addBoth(results.append)
|
d.addBoth(results.append)
|
||||||
self.assertEqual(results, [])
|
self.assertEqual(results, [])
|
||||||
|
|
||||||
addr = address.HostnameAddress("example.com", 1234)
|
addr1 = address.HostnameAddress("example.com", 1234)
|
||||||
p1 = f.buildProtocol(addr)
|
addr2 = address.HostnameAddress("example.com", 5678)
|
||||||
p2 = f.buildProtocol(addr)
|
p1 = f.buildProtocol(addr1)
|
||||||
|
p2 = f.buildProtocol(addr2)
|
||||||
|
|
||||||
f.connectionWasMade(p1, "desc1")
|
f.connectionWasMade(p1)
|
||||||
f.connectionWasMade(p2, "desc2")
|
f.connectionWasMade(p2)
|
||||||
self.assertEqual(results, [])
|
self.assertEqual(results, [])
|
||||||
|
|
||||||
p1._d.errback(transit.BadHandshake("nope"))
|
p1._d.errback(transit.BadHandshake("nope"))
|
||||||
|
@ -387,12 +378,13 @@ class InboundConnectionFactory(unittest.TestCase):
|
||||||
d.addBoth(results.append)
|
d.addBoth(results.append)
|
||||||
self.assertEqual(results, [])
|
self.assertEqual(results, [])
|
||||||
|
|
||||||
addr = address.HostnameAddress("example.com", 1234)
|
addr1 = address.HostnameAddress("example.com", 1234)
|
||||||
p1 = f.buildProtocol(addr)
|
addr2 = address.HostnameAddress("example.com", 5678)
|
||||||
p2 = f.buildProtocol(addr)
|
p1 = f.buildProtocol(addr1)
|
||||||
|
p2 = f.buildProtocol(addr2)
|
||||||
|
|
||||||
f.connectionWasMade(p1, "desc1")
|
f.connectionWasMade(p1)
|
||||||
f.connectionWasMade(p2, "desc2")
|
f.connectionWasMade(p2)
|
||||||
self.assertEqual(results, [])
|
self.assertEqual(results, [])
|
||||||
|
|
||||||
p1._d.callback(p1)
|
p1._d.callback(p1)
|
||||||
|
@ -414,7 +406,7 @@ class InboundConnectionFactory(unittest.TestCase):
|
||||||
# if the Connection protocol throws an unexpected error, that should
|
# if the Connection protocol throws an unexpected error, that should
|
||||||
# get logged to the Twisted logs (as an Unhandled Error in Deferred)
|
# get logged to the Twisted logs (as an Unhandled Error in Deferred)
|
||||||
# so we can diagnose the bug
|
# so we can diagnose the bug
|
||||||
f.connectionWasMade(p1, "desc1")
|
f.connectionWasMade(p1)
|
||||||
p1._d.errback(RandomError("boom"))
|
p1._d.errback(RandomError("boom"))
|
||||||
self.assertEqual(len(results), 0)
|
self.assertEqual(len(results), 0)
|
||||||
|
|
||||||
|
@ -433,12 +425,13 @@ class InboundConnectionFactory(unittest.TestCase):
|
||||||
d.addBoth(results.append)
|
d.addBoth(results.append)
|
||||||
self.assertEqual(results, [])
|
self.assertEqual(results, [])
|
||||||
|
|
||||||
addr = address.HostnameAddress("example.com", 1234)
|
addr1 = address.HostnameAddress("example.com", 1234)
|
||||||
p1 = f.buildProtocol(addr)
|
addr2 = address.HostnameAddress("example.com", 5678)
|
||||||
p2 = f.buildProtocol(addr)
|
p1 = f.buildProtocol(addr1)
|
||||||
|
p2 = f.buildProtocol(addr2)
|
||||||
|
|
||||||
f.connectionWasMade(p1, "desc1")
|
f.connectionWasMade(p1)
|
||||||
f.connectionWasMade(p2, "desc2")
|
f.connectionWasMade(p2)
|
||||||
self.assertEqual(results, [])
|
self.assertEqual(results, [])
|
||||||
|
|
||||||
d.cancel()
|
d.cancel()
|
||||||
|
@ -454,7 +447,8 @@ class InboundConnectionFactory(unittest.TestCase):
|
||||||
|
|
||||||
class OutboundConnectionFactory(unittest.TestCase):
|
class OutboundConnectionFactory(unittest.TestCase):
|
||||||
def test_success(self):
|
def test_success(self):
|
||||||
f = transit.OutboundConnectionFactory("owner", "relay_handshake")
|
f = transit.OutboundConnectionFactory("owner", "relay_handshake",
|
||||||
|
"description")
|
||||||
f.protocol = MockConnection
|
f.protocol = MockConnection
|
||||||
|
|
||||||
addr = address.HostnameAddress("example.com", 1234)
|
addr = address.HostnameAddress("example.com", 1234)
|
||||||
|
@ -466,16 +460,15 @@ class OutboundConnectionFactory(unittest.TestCase):
|
||||||
# meh .start
|
# meh .start
|
||||||
|
|
||||||
# this is normally called from Connection.connectionMade
|
# this is normally called from Connection.connectionMade
|
||||||
f.connectionWasMade(p, "desc") # no-op for outbound
|
f.connectionWasMade(p) # no-op for outbound
|
||||||
self.assertEqual(p._start_negotiation_called, False)
|
self.assertEqual(p._start_negotiation_called, False)
|
||||||
|
|
||||||
|
|
||||||
class MockOwner:
|
class MockOwner:
|
||||||
_connection_ready_called = False
|
_connection_ready_called = False
|
||||||
def connection_ready(self, connection, description):
|
def connection_ready(self, connection):
|
||||||
self._connection_ready_called = True
|
self._connection_ready_called = True
|
||||||
self._connection = connection
|
self._connection = connection
|
||||||
self._description = description
|
|
||||||
return self._state
|
return self._state
|
||||||
def _send_this(self):
|
def _send_this(self):
|
||||||
return b"send_this"
|
return b"send_this"
|
||||||
|
@ -488,7 +481,7 @@ class MockOwner:
|
||||||
|
|
||||||
class MockFactory:
|
class MockFactory:
|
||||||
_connectionWasMade_called = False
|
_connectionWasMade_called = False
|
||||||
def connectionWasMade(self, p, description):
|
def connectionWasMade(self, p):
|
||||||
self._connectionWasMade_called = True
|
self._connectionWasMade_called = True
|
||||||
self._p = p
|
self._p = p
|
||||||
|
|
||||||
|
@ -496,7 +489,7 @@ class Connection(unittest.TestCase):
|
||||||
# exercise the Connection protocol class
|
# exercise the Connection protocol class
|
||||||
|
|
||||||
def test_check_and_remove(self):
|
def test_check_and_remove(self):
|
||||||
c = transit.Connection(None, None, None)
|
c = transit.Connection(None, None, None, "description")
|
||||||
c.buf = b""
|
c.buf = b""
|
||||||
EXP = b"expectation"
|
EXP = b"expectation"
|
||||||
self.assertFalse(c._check_and_remove(EXP))
|
self.assertFalse(c._check_and_remove(EXP))
|
||||||
|
@ -525,7 +518,7 @@ class Connection(unittest.TestCase):
|
||||||
owner = MockOwner()
|
owner = MockOwner()
|
||||||
factory = MockFactory()
|
factory = MockFactory()
|
||||||
addr = address.HostnameAddress("example.com", 1234)
|
addr = address.HostnameAddress("example.com", 1234)
|
||||||
c = transit.Connection(owner, relay_handshake, None)
|
c = transit.Connection(owner, relay_handshake, None, "description")
|
||||||
self.assertEqual(c.state, "too-early")
|
self.assertEqual(c.state, "too-early")
|
||||||
t = c.transport = FakeTransport(c, addr)
|
t = c.transport = FakeTransport(c, addr)
|
||||||
c.factory = factory
|
c.factory = factory
|
||||||
|
@ -534,7 +527,7 @@ class Connection(unittest.TestCase):
|
||||||
self.assertEqual(factory._p, c)
|
self.assertEqual(factory._p, c)
|
||||||
|
|
||||||
owner._state = "go"
|
owner._state = "go"
|
||||||
d = c.startNegotiation("description")
|
d = c.startNegotiation()
|
||||||
self.assertEqual(c.state, "handshake")
|
self.assertEqual(c.state, "handshake")
|
||||||
self.assertEqual(t.read_buf(), b"send_this")
|
self.assertEqual(t.read_buf(), b"send_this")
|
||||||
results = []
|
results = []
|
||||||
|
@ -555,7 +548,7 @@ class Connection(unittest.TestCase):
|
||||||
owner = MockOwner()
|
owner = MockOwner()
|
||||||
factory = MockFactory()
|
factory = MockFactory()
|
||||||
addr = address.HostnameAddress("example.com", 1234)
|
addr = address.HostnameAddress("example.com", 1234)
|
||||||
c = transit.Connection(owner, relay_handshake, None)
|
c = transit.Connection(owner, relay_handshake, None, "description")
|
||||||
self.assertEqual(c.state, "too-early")
|
self.assertEqual(c.state, "too-early")
|
||||||
t = c.transport = FakeTransport(c, addr)
|
t = c.transport = FakeTransport(c, addr)
|
||||||
c.factory = factory
|
c.factory = factory
|
||||||
|
@ -564,7 +557,7 @@ class Connection(unittest.TestCase):
|
||||||
self.assertEqual(factory._p, c)
|
self.assertEqual(factory._p, c)
|
||||||
|
|
||||||
owner._state = "nevermind"
|
owner._state = "nevermind"
|
||||||
d = c.startNegotiation("description")
|
d = c.startNegotiation()
|
||||||
self.assertEqual(c.state, "handshake")
|
self.assertEqual(c.state, "handshake")
|
||||||
self.assertEqual(t.read_buf(), b"send_this")
|
self.assertEqual(t.read_buf(), b"send_this")
|
||||||
results = []
|
results = []
|
||||||
|
@ -585,7 +578,7 @@ class Connection(unittest.TestCase):
|
||||||
owner = MockOwner()
|
owner = MockOwner()
|
||||||
factory = MockFactory()
|
factory = MockFactory()
|
||||||
addr = address.HostnameAddress("example.com", 1234)
|
addr = address.HostnameAddress("example.com", 1234)
|
||||||
c = transit.Connection(owner, None, None)
|
c = transit.Connection(owner, None, None, "description")
|
||||||
self.assertEqual(c.state, "too-early")
|
self.assertEqual(c.state, "too-early")
|
||||||
t = c.transport = FakeTransport(c, addr)
|
t = c.transport = FakeTransport(c, addr)
|
||||||
c.factory = factory
|
c.factory = factory
|
||||||
|
@ -593,7 +586,7 @@ class Connection(unittest.TestCase):
|
||||||
self.assertEqual(factory._connectionWasMade_called, True)
|
self.assertEqual(factory._connectionWasMade_called, True)
|
||||||
self.assertEqual(factory._p, c)
|
self.assertEqual(factory._p, c)
|
||||||
|
|
||||||
d = c.startNegotiation("description")
|
d = c.startNegotiation()
|
||||||
self.assertEqual(c.state, "handshake")
|
self.assertEqual(c.state, "handshake")
|
||||||
self.assertEqual(t.read_buf(), b"send_this")
|
self.assertEqual(t.read_buf(), b"send_this")
|
||||||
results = []
|
results = []
|
||||||
|
@ -613,7 +606,7 @@ class Connection(unittest.TestCase):
|
||||||
owner = MockOwner()
|
owner = MockOwner()
|
||||||
factory = MockFactory()
|
factory = MockFactory()
|
||||||
addr = address.HostnameAddress("example.com", 1234)
|
addr = address.HostnameAddress("example.com", 1234)
|
||||||
c = transit.Connection(owner, relay_handshake, None)
|
c = transit.Connection(owner, relay_handshake, None, "description")
|
||||||
self.assertEqual(c.state, "too-early")
|
self.assertEqual(c.state, "too-early")
|
||||||
t = c.transport = FakeTransport(c, addr)
|
t = c.transport = FakeTransport(c, addr)
|
||||||
c.factory = factory
|
c.factory = factory
|
||||||
|
@ -623,7 +616,7 @@ class Connection(unittest.TestCase):
|
||||||
self.assertEqual(t.read_buf(), b"") # quiet until startNegotiation
|
self.assertEqual(t.read_buf(), b"") # quiet until startNegotiation
|
||||||
|
|
||||||
owner._state = "go"
|
owner._state = "go"
|
||||||
d = c.startNegotiation("description")
|
d = c.startNegotiation()
|
||||||
self.assertEqual(t.read_buf(), relay_handshake)
|
self.assertEqual(t.read_buf(), relay_handshake)
|
||||||
self.assertEqual(c.state, "relay") # waiting for OK from relay
|
self.assertEqual(c.state, "relay") # waiting for OK from relay
|
||||||
|
|
||||||
|
@ -646,7 +639,7 @@ class Connection(unittest.TestCase):
|
||||||
owner = MockOwner()
|
owner = MockOwner()
|
||||||
factory = MockFactory()
|
factory = MockFactory()
|
||||||
addr = address.HostnameAddress("example.com", 1234)
|
addr = address.HostnameAddress("example.com", 1234)
|
||||||
c = transit.Connection(owner, relay_handshake, None)
|
c = transit.Connection(owner, relay_handshake, None, "description")
|
||||||
self.assertEqual(c.state, "too-early")
|
self.assertEqual(c.state, "too-early")
|
||||||
t = c.transport = FakeTransport(c, addr)
|
t = c.transport = FakeTransport(c, addr)
|
||||||
c.factory = factory
|
c.factory = factory
|
||||||
|
@ -656,7 +649,7 @@ class Connection(unittest.TestCase):
|
||||||
self.assertEqual(t.read_buf(), b"") # quiet until startNegotiation
|
self.assertEqual(t.read_buf(), b"") # quiet until startNegotiation
|
||||||
|
|
||||||
owner._state = "go"
|
owner._state = "go"
|
||||||
d = c.startNegotiation("description")
|
d = c.startNegotiation()
|
||||||
self.assertEqual(t.read_buf(), relay_handshake)
|
self.assertEqual(t.read_buf(), relay_handshake)
|
||||||
self.assertEqual(c.state, "relay") # waiting for OK from relay
|
self.assertEqual(c.state, "relay") # waiting for OK from relay
|
||||||
|
|
||||||
|
@ -678,7 +671,7 @@ class Connection(unittest.TestCase):
|
||||||
owner = MockOwner()
|
owner = MockOwner()
|
||||||
factory = MockFactory()
|
factory = MockFactory()
|
||||||
addr = address.HostnameAddress("example.com", 1234)
|
addr = address.HostnameAddress("example.com", 1234)
|
||||||
c = transit.Connection(owner, None, None)
|
c = transit.Connection(owner, None, None, "description")
|
||||||
self.assertEqual(c.state, "too-early")
|
self.assertEqual(c.state, "too-early")
|
||||||
t = c.transport = FakeTransport(c, addr)
|
t = c.transport = FakeTransport(c, addr)
|
||||||
c.factory = factory
|
c.factory = factory
|
||||||
|
@ -687,7 +680,7 @@ class Connection(unittest.TestCase):
|
||||||
self.assertEqual(factory._p, c)
|
self.assertEqual(factory._p, c)
|
||||||
|
|
||||||
owner._state = "wait-for-decision"
|
owner._state = "wait-for-decision"
|
||||||
d = c.startNegotiation("description")
|
d = c.startNegotiation()
|
||||||
self.assertEqual(c.state, "handshake")
|
self.assertEqual(c.state, "handshake")
|
||||||
self.assertEqual(t.read_buf(), b"send_this")
|
self.assertEqual(t.read_buf(), b"send_this")
|
||||||
results = []
|
results = []
|
||||||
|
@ -707,7 +700,7 @@ class Connection(unittest.TestCase):
|
||||||
owner = MockOwner()
|
owner = MockOwner()
|
||||||
factory = MockFactory()
|
factory = MockFactory()
|
||||||
addr = address.HostnameAddress("example.com", 1234)
|
addr = address.HostnameAddress("example.com", 1234)
|
||||||
c = transit.Connection(owner, None, None)
|
c = transit.Connection(owner, None, None, "description")
|
||||||
self.assertEqual(c.state, "too-early")
|
self.assertEqual(c.state, "too-early")
|
||||||
t = c.transport = FakeTransport(c, addr)
|
t = c.transport = FakeTransport(c, addr)
|
||||||
c.factory = factory
|
c.factory = factory
|
||||||
|
@ -716,7 +709,7 @@ class Connection(unittest.TestCase):
|
||||||
self.assertEqual(factory._p, c)
|
self.assertEqual(factory._p, c)
|
||||||
|
|
||||||
owner._state = "wait-for-decision"
|
owner._state = "wait-for-decision"
|
||||||
d = c.startNegotiation("description")
|
d = c.startNegotiation()
|
||||||
self.assertEqual(c.state, "handshake")
|
self.assertEqual(c.state, "handshake")
|
||||||
self.assertEqual(t.read_buf(), b"send_this")
|
self.assertEqual(t.read_buf(), b"send_this")
|
||||||
results = []
|
results = []
|
||||||
|
@ -742,7 +735,7 @@ class Connection(unittest.TestCase):
|
||||||
owner = MockOwner()
|
owner = MockOwner()
|
||||||
factory = MockFactory()
|
factory = MockFactory()
|
||||||
addr = address.HostnameAddress("example.com", 1234)
|
addr = address.HostnameAddress("example.com", 1234)
|
||||||
c = transit.Connection(owner, None, None)
|
c = transit.Connection(owner, None, None, "description")
|
||||||
self.assertEqual(c.state, "too-early")
|
self.assertEqual(c.state, "too-early")
|
||||||
t = c.transport = FakeTransport(c, addr)
|
t = c.transport = FakeTransport(c, addr)
|
||||||
c.factory = factory
|
c.factory = factory
|
||||||
|
@ -751,7 +744,7 @@ class Connection(unittest.TestCase):
|
||||||
self.assertEqual(factory._p, c)
|
self.assertEqual(factory._p, c)
|
||||||
|
|
||||||
owner._state = "wait-for-decision"
|
owner._state = "wait-for-decision"
|
||||||
d = c.startNegotiation("description")
|
d = c.startNegotiation()
|
||||||
self.assertEqual(c.state, "handshake")
|
self.assertEqual(c.state, "handshake")
|
||||||
self.assertEqual(t.read_buf(), b"send_this")
|
self.assertEqual(t.read_buf(), b"send_this")
|
||||||
results = []
|
results = []
|
||||||
|
@ -775,13 +768,13 @@ class Connection(unittest.TestCase):
|
||||||
owner = MockOwner()
|
owner = MockOwner()
|
||||||
factory = MockFactory()
|
factory = MockFactory()
|
||||||
addr = address.HostnameAddress("example.com", 1234)
|
addr = address.HostnameAddress("example.com", 1234)
|
||||||
c = transit.Connection(owner, None, None)
|
c = transit.Connection(owner, None, None, "description")
|
||||||
self.assertEqual(c.state, "too-early")
|
self.assertEqual(c.state, "too-early")
|
||||||
t = c.transport = FakeTransport(c, addr)
|
t = c.transport = FakeTransport(c, addr)
|
||||||
c.factory = factory
|
c.factory = factory
|
||||||
c.connectionMade()
|
c.connectionMade()
|
||||||
|
|
||||||
d = c.startNegotiation("description")
|
d = c.startNegotiation()
|
||||||
results = []
|
results = []
|
||||||
d.addBoth(results.append)
|
d.addBoth(results.append)
|
||||||
# while we're waiting for negotiation, we get cancelled
|
# while we're waiting for negotiation, we get cancelled
|
||||||
|
@ -799,7 +792,7 @@ class Connection(unittest.TestCase):
|
||||||
owner = MockOwner()
|
owner = MockOwner()
|
||||||
factory = MockFactory()
|
factory = MockFactory()
|
||||||
addr = address.HostnameAddress("example.com", 1234)
|
addr = address.HostnameAddress("example.com", 1234)
|
||||||
c = transit.Connection(owner, None, None)
|
c = transit.Connection(owner, None, None, "description")
|
||||||
def _callLater(period, func):
|
def _callLater(period, func):
|
||||||
clock.callLater(period, func)
|
clock.callLater(period, func)
|
||||||
c.callLater = _callLater
|
c.callLater = _callLater
|
||||||
|
@ -808,7 +801,7 @@ class Connection(unittest.TestCase):
|
||||||
c.factory = factory
|
c.factory = factory
|
||||||
c.connectionMade()
|
c.connectionMade()
|
||||||
# the timer should now be running
|
# the timer should now be running
|
||||||
d = c.startNegotiation("description")
|
d = c.startNegotiation()
|
||||||
results = []
|
results = []
|
||||||
d.addBoth(results.append)
|
d.addBoth(results.append)
|
||||||
# while we're waiting for negotiation, the timer expires
|
# while we're waiting for negotiation, the timer expires
|
||||||
|
@ -825,13 +818,13 @@ class Connection(unittest.TestCase):
|
||||||
owner = MockOwner()
|
owner = MockOwner()
|
||||||
factory = MockFactory()
|
factory = MockFactory()
|
||||||
addr = address.HostnameAddress("example.com", 1234)
|
addr = address.HostnameAddress("example.com", 1234)
|
||||||
c = transit.Connection(owner, None, None)
|
c = transit.Connection(owner, None, None, "description")
|
||||||
t = c.transport = FakeTransport(c, addr)
|
t = c.transport = FakeTransport(c, addr)
|
||||||
c.factory = factory
|
c.factory = factory
|
||||||
c.connectionMade()
|
c.connectionMade()
|
||||||
|
|
||||||
owner._state = "go"
|
owner._state = "go"
|
||||||
d = c.startNegotiation("description")
|
d = c.startNegotiation()
|
||||||
results = []
|
results = []
|
||||||
d.addBoth(results.append)
|
d.addBoth(results.append)
|
||||||
c.dataReceived(b"expect_this")
|
c.dataReceived(b"expect_this")
|
||||||
|
@ -954,7 +947,7 @@ class Connection(unittest.TestCase):
|
||||||
# the key is None.
|
# the key is None.
|
||||||
|
|
||||||
def test_receive_queue(self):
|
def test_receive_queue(self):
|
||||||
c = transit.Connection(None, None, None)
|
c = transit.Connection(None, None, None, "description")
|
||||||
c.transport = FakeTransport(c, None)
|
c.transport = FakeTransport(c, None)
|
||||||
c.transport.signalConnectionLost = False
|
c.transport.signalConnectionLost = False
|
||||||
results = [[] for i in range(5)]
|
results = [[] for i in range(5)]
|
||||||
|
@ -995,7 +988,7 @@ class Connection(unittest.TestCase):
|
||||||
def test_producer(self):
|
def test_producer(self):
|
||||||
# a Transit object (receiving data from the remote peer) produces
|
# a Transit object (receiving data from the remote peer) produces
|
||||||
# data and writes it into a local Consumer
|
# data and writes it into a local Consumer
|
||||||
c = transit.Connection(None, None, None)
|
c = transit.Connection(None, None, None, "description")
|
||||||
c.transport = proto_helpers.StringTransport()
|
c.transport = proto_helpers.StringTransport()
|
||||||
c.recordReceived(b"r1.")
|
c.recordReceived(b"r1.")
|
||||||
c.recordReceived(b"r2.")
|
c.recordReceived(b"r2.")
|
||||||
|
@ -1024,7 +1017,7 @@ class Connection(unittest.TestCase):
|
||||||
|
|
||||||
def test_consumer(self):
|
def test_consumer(self):
|
||||||
# a local producer sends data to a consuming Transit object
|
# a local producer sends data to a consuming Transit object
|
||||||
c = transit.Connection(None, None, None)
|
c = transit.Connection(None, None, None, "description")
|
||||||
c.transport = proto_helpers.StringTransport()
|
c.transport = proto_helpers.StringTransport()
|
||||||
records = []
|
records = []
|
||||||
c.send_record = records.append
|
c.send_record = records.append
|
||||||
|
|
|
@ -26,12 +26,13 @@ TIMEOUT=15
|
||||||
|
|
||||||
@implementer(interfaces.IProducer, interfaces.IConsumer)
|
@implementer(interfaces.IProducer, interfaces.IConsumer)
|
||||||
class Connection(protocol.Protocol, policies.TimeoutMixin):
|
class Connection(protocol.Protocol, policies.TimeoutMixin):
|
||||||
def __init__(self, owner, relay_handshake, start):
|
def __init__(self, owner, relay_handshake, start, description):
|
||||||
self.state = "too-early"
|
self.state = "too-early"
|
||||||
self.buf = b""
|
self.buf = b""
|
||||||
self.owner = owner
|
self.owner = owner
|
||||||
self.relay_handshake = relay_handshake
|
self.relay_handshake = relay_handshake
|
||||||
self.start = start
|
self.start = start
|
||||||
|
self._description = description
|
||||||
self._negotiation_d = defer.Deferred(self._cancel)
|
self._negotiation_d = defer.Deferred(self._cancel)
|
||||||
self._error = None
|
self._error = None
|
||||||
self._consumer = None
|
self._consumer = None
|
||||||
|
@ -41,10 +42,9 @@ class Connection(protocol.Protocol, policies.TimeoutMixin):
|
||||||
def connectionMade(self):
|
def connectionMade(self):
|
||||||
debug("handle %r" % (self.transport,))
|
debug("handle %r" % (self.transport,))
|
||||||
self.setTimeout(TIMEOUT) # does timeoutConnection() when it expires
|
self.setTimeout(TIMEOUT) # does timeoutConnection() when it expires
|
||||||
self.factory.connectionWasMade(self, self.transport.getPeer())
|
self.factory.connectionWasMade(self)
|
||||||
|
|
||||||
def startNegotiation(self, description):
|
def startNegotiation(self):
|
||||||
self.description = description
|
|
||||||
if self.relay_handshake is not None:
|
if self.relay_handshake is not None:
|
||||||
self.transport.write(self.relay_handshake)
|
self.transport.write(self.relay_handshake)
|
||||||
self.state = "relay"
|
self.state = "relay"
|
||||||
|
@ -104,7 +104,7 @@ class Connection(protocol.Protocol, policies.TimeoutMixin):
|
||||||
if self.state == "handshake":
|
if self.state == "handshake":
|
||||||
if not self._check_and_remove(self.owner._expect_this()):
|
if not self._check_and_remove(self.owner._expect_this()):
|
||||||
return
|
return
|
||||||
self.state = self.owner.connection_ready(self, self.description)
|
self.state = self.owner.connection_ready(self)
|
||||||
# If we're the receiver, we'll be moved to state
|
# If we're the receiver, we'll be moved to state
|
||||||
# "wait-for-decision", which means we're waiting for the other
|
# "wait-for-decision", which means we're waiting for the other
|
||||||
# side (the sender) to make a decision. If we're the sender,
|
# side (the sender) to make a decision. If we're the sender,
|
||||||
|
@ -161,6 +161,9 @@ class Connection(protocol.Protocol, policies.TimeoutMixin):
|
||||||
record = self.receive_box.decrypt(encrypted)
|
record = self.receive_box.decrypt(encrypted)
|
||||||
return record
|
return record
|
||||||
|
|
||||||
|
def describe(self):
|
||||||
|
return self._description
|
||||||
|
|
||||||
def send_record(self, record):
|
def send_record(self, record):
|
||||||
if not isinstance(record, type(b"")): raise UsageError
|
if not isinstance(record, type(b"")): raise UsageError
|
||||||
assert SecretBox.NONCE_SIZE == 24
|
assert SecretBox.NONCE_SIZE == 24
|
||||||
|
@ -260,17 +263,19 @@ class Connection(protocol.Protocol, policies.TimeoutMixin):
|
||||||
class OutboundConnectionFactory(protocol.ClientFactory):
|
class OutboundConnectionFactory(protocol.ClientFactory):
|
||||||
protocol = Connection
|
protocol = Connection
|
||||||
|
|
||||||
def __init__(self, owner, relay_handshake):
|
def __init__(self, owner, relay_handshake, description):
|
||||||
self.owner = owner
|
self.owner = owner
|
||||||
self.relay_handshake = relay_handshake
|
self.relay_handshake = relay_handshake
|
||||||
|
self._description = description
|
||||||
self.start = time.time()
|
self.start = time.time()
|
||||||
|
|
||||||
def buildProtocol(self, addr):
|
def buildProtocol(self, addr):
|
||||||
p = self.protocol(self.owner, self.relay_handshake, self.start)
|
p = self.protocol(self.owner, self.relay_handshake, self.start,
|
||||||
|
self._description)
|
||||||
p.factory = self
|
p.factory = self
|
||||||
return p
|
return p
|
||||||
|
|
||||||
def connectionWasMade(self, p, addr):
|
def connectionWasMade(self, p):
|
||||||
# outbound connections are handled via the endpoint
|
# outbound connections are handled via the endpoint
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@ -295,7 +300,7 @@ class InboundConnectionFactory(protocol.ClientFactory):
|
||||||
for d in list(self._pending_connections):
|
for d in list(self._pending_connections):
|
||||||
d.cancel() # that fires _remove and _proto_failed
|
d.cancel() # that fires _remove and _proto_failed
|
||||||
|
|
||||||
def describePeer(self, addr):
|
def _describePeer(self, addr):
|
||||||
if isinstance(addr, address.HostnameAddress):
|
if isinstance(addr, address.HostnameAddress):
|
||||||
return "<-%s:%d" % (addr.hostname, addr.port)
|
return "<-%s:%d" % (addr.hostname, addr.port)
|
||||||
elif isinstance(addr, (address.IPv4Address, address.IPv6Address)):
|
elif isinstance(addr, (address.IPv4Address, address.IPv6Address)):
|
||||||
|
@ -303,12 +308,13 @@ class InboundConnectionFactory(protocol.ClientFactory):
|
||||||
return "<-%r" % addr
|
return "<-%r" % addr
|
||||||
|
|
||||||
def buildProtocol(self, addr):
|
def buildProtocol(self, addr):
|
||||||
p = self.protocol(self.owner, None, self.start)
|
p = self.protocol(self.owner, None, self.start,
|
||||||
|
self._describePeer(addr))
|
||||||
p.factory = self
|
p.factory = self
|
||||||
return p
|
return p
|
||||||
|
|
||||||
def connectionWasMade(self, p, addr):
|
def connectionWasMade(self, p):
|
||||||
d = p.startNegotiation(self.describePeer(addr))
|
d = p.startNegotiation()
|
||||||
self._pending_connections.add(d)
|
self._pending_connections.add(d)
|
||||||
d.addBoth(self._remove, d)
|
d.addBoth(self._remove, d)
|
||||||
d.addCallbacks(self._proto_succeeded, self._proto_failed)
|
d.addCallbacks(self._proto_succeeded, self._proto_failed)
|
||||||
|
@ -413,7 +419,6 @@ class Common:
|
||||||
self._waiting_for_transit_key = []
|
self._waiting_for_transit_key = []
|
||||||
self._listener = None
|
self._listener = None
|
||||||
self._winner = None
|
self._winner = None
|
||||||
self._winner_description = None
|
|
||||||
self._reactor = reactor
|
self._reactor = reactor
|
||||||
|
|
||||||
def _build_listener(self):
|
def _build_listener(self):
|
||||||
|
@ -596,10 +601,10 @@ class Common:
|
||||||
if is_relay:
|
if is_relay:
|
||||||
assert self._transit_key
|
assert self._transit_key
|
||||||
relay_handshake = build_relay_handshake(self._transit_key)
|
relay_handshake = build_relay_handshake(self._transit_key)
|
||||||
f = OutboundConnectionFactory(self, relay_handshake)
|
f = OutboundConnectionFactory(self, relay_handshake, description)
|
||||||
d = ep.connect(f)
|
d = ep.connect(f)
|
||||||
# fires with protocol, or ConnectError
|
# fires with protocol, or ConnectError
|
||||||
d.addCallback(lambda p: p.startNegotiation(description))
|
d.addCallback(lambda p: p.startNegotiation())
|
||||||
return d
|
return d
|
||||||
|
|
||||||
def _endpoint_from_hint(self, hint):
|
def _endpoint_from_hint(self, hint):
|
||||||
|
@ -613,7 +618,7 @@ class Common:
|
||||||
return endpoints.HostnameEndpoint(self._reactor, pieces[1],
|
return endpoints.HostnameEndpoint(self._reactor, pieces[1],
|
||||||
int(pieces[2]))
|
int(pieces[2]))
|
||||||
|
|
||||||
def connection_ready(self, p, description):
|
def connection_ready(self, p):
|
||||||
# inbound/outbound Connection protocols call this when they finish
|
# inbound/outbound Connection protocols call this when they finish
|
||||||
# negotiation. The first one wins and gets a "go". Any subsequent
|
# negotiation. The first one wins and gets a "go". Any subsequent
|
||||||
# ones lose and get a "nevermind" before being closed.
|
# ones lose and get a "nevermind" before being closed.
|
||||||
|
@ -626,14 +631,8 @@ class Common:
|
||||||
return "nevermind"
|
return "nevermind"
|
||||||
# this one wins!
|
# this one wins!
|
||||||
self._winner = p
|
self._winner = p
|
||||||
self._winner_description = description
|
|
||||||
return "go"
|
return "go"
|
||||||
|
|
||||||
def describe(self):
|
|
||||||
if not self._winner:
|
|
||||||
return "not yet established"
|
|
||||||
return self._winner_description
|
|
||||||
|
|
||||||
class TransitSender(Common):
|
class TransitSender(Common):
|
||||||
is_sender = True
|
is_sender = True
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user