add description to inbound connections

This commit is contained in:
Brian Warner 2019-02-10 16:52:17 -08:00
parent 74c416517f
commit ebc63e52e0
4 changed files with 29 additions and 17 deletions

View File

@ -475,6 +475,7 @@ class DilatedConnectionProtocol(Protocol, object):
_eventual_queue = attrib() _eventual_queue = attrib()
_role = attrib() _role = attrib()
_description = attrib()
_connector = attrib(validator=provides(IDilationConnector)) _connector = attrib(validator=provides(IDilationConnector))
_noise = attrib() _noise = attrib()
_outbound_prologue = attrib(validator=instance_of(bytes)) _outbound_prologue = attrib(validator=instance_of(bytes))

View File

@ -9,6 +9,7 @@ from twisted.internet.task import deferLater
from twisted.internet.defer import DeferredList from twisted.internet.defer import DeferredList
from twisted.internet.endpoints import serverFromString from twisted.internet.endpoints import serverFromString
from twisted.internet.protocol import ClientFactory, ServerFactory from twisted.internet.protocol import ClientFactory, ServerFactory
from twisted.internet.address import HostnameAddress, IPv4Address, IPv6Address
from twisted.python import log from twisted.python import log
from .. import ipaddrs # TODO: move into _dilation/ from .. import ipaddrs # TODO: move into _dilation/
from .._interfaces import IDilationConnector, IDilationManager from .._interfaces import IDilationConnector, IDilationManager
@ -110,7 +111,7 @@ class Connector(object):
{"type": "relay-v1"}, {"type": "relay-v1"},
] ]
def build_protocol(self, addr): def build_protocol(self, addr, description):
# encryption: let's use Noise NNpsk0 (or maybe NNpsk2). That uses # encryption: let's use Noise NNpsk0 (or maybe NNpsk2). That uses
# ephemeral keys plus a pre-shared symmetric key (the Transit key), a # ephemeral keys plus a pre-shared symmetric key (the Transit key), a
# different one for each potential connection. # different one for each potential connection.
@ -125,6 +126,7 @@ class Connector(object):
outbound_prologue = PROLOGUE_FOLLOWER outbound_prologue = PROLOGUE_FOLLOWER
inbound_prologue = PROLOGUE_LEADER inbound_prologue = PROLOGUE_LEADER
p = DilatedConnectionProtocol(self._eventual_queue, self._role, p = DilatedConnectionProtocol(self._eventual_queue, self._role,
description,
self, noise, self, noise,
outbound_prologue, inbound_prologue) outbound_prologue, inbound_prologue)
return p return p
@ -368,7 +370,7 @@ class Connector(object):
if is_relay: if is_relay:
relay_handshake = build_sided_relay_handshake(self._dilation_key, relay_handshake = build_sided_relay_handshake(self._dilation_key,
self._side) self._side)
f = OutboundConnectionFactory(self, relay_handshake) f = OutboundConnectionFactory(self, relay_handshake, description)
d = ep.connect(f) d = ep.connect(f)
# fires with protocol, or ConnectError # fires with protocol, or ConnectError
@ -399,20 +401,28 @@ class Connector(object):
class OutboundConnectionFactory(ClientFactory, object): class OutboundConnectionFactory(ClientFactory, object):
_connector = attrib(validator=provides(IDilationConnector)) _connector = attrib(validator=provides(IDilationConnector))
_relay_handshake = attrib(validator=optional(instance_of(bytes))) _relay_handshake = attrib(validator=optional(instance_of(bytes)))
_description = attrib()
def buildProtocol(self, addr): def buildProtocol(self, addr):
p = self._connector.build_protocol(addr) p = self._connector.build_protocol(addr, self._description)
p.factory = self p.factory = self
if self._relay_handshake is not None: if self._relay_handshake is not None:
p.use_relay(self._relay_handshake) p.use_relay(self._relay_handshake)
return p return p
def describe_inbound(addr):
if isinstance(addr, HostnameAddress):
return "<-tcp:%s:%d" % (addr.hostname, addr.port)
elif isinstance(addr, (IPv4Address, IPv6Address)):
return "<-tcp:%s:%d" % (addr.host, addr.port)
return "<-%r" % addr
@attrs @attrs
class InboundConnectionFactory(ServerFactory, object): class InboundConnectionFactory(ServerFactory, object):
_connector = attrib(validator=provides(IDilationConnector)) _connector = attrib(validator=provides(IDilationConnector))
def buildProtocol(self, addr): def buildProtocol(self, addr):
p = self._connector.build_protocol(addr) description = describe_inbound(addr)
p = self._connector.build_protocol(addr, description)
p.factory = self p.factory = self
return p return p

View File

@ -20,7 +20,7 @@ def make_con(role, use_relay=False):
alsoProvides(connector, IDilationConnector) alsoProvides(connector, IDilationConnector)
n = mock.Mock() # pretends to be a Noise object n = mock.Mock() # pretends to be a Noise object
n.write_message = mock.Mock(side_effect=[b"handshake"]) n.write_message = mock.Mock(side_effect=[b"handshake"])
c = DilatedConnectionProtocol(eq, role, connector, n, c = DilatedConnectionProtocol(eq, role, "desc", connector, n,
b"outbound_prologue\n", b"inbound_prologue\n") b"outbound_prologue\n", b"inbound_prologue\n")
if use_relay: if use_relay:
c.use_relay(b"relay_handshake\n") c.use_relay(b"relay_handshake\n")

View File

@ -5,6 +5,7 @@ from zope.interface import alsoProvides
from twisted.trial import unittest from twisted.trial import unittest
from twisted.internet.task import Clock from twisted.internet.task import Clock
from twisted.internet.defer import Deferred from twisted.internet.defer import Deferred
from twisted.internet.address import IPv4Address
from ...eventual import EventualQueue from ...eventual import EventualQueue
from ..._interfaces import IDilationManager, IDilationConnector from ..._interfaces import IDilationManager, IDilationConnector
from ..._hints import DirectTCPV1Hint, RelayV1Hint, TorTCPV1Hint from ..._hints import DirectTCPV1Hint, RelayV1Hint, TorTCPV1Hint
@ -34,11 +35,11 @@ class Outbound(unittest.TestCase):
p0 = mock.Mock() p0 = mock.Mock()
c.build_protocol = mock.Mock(return_value=p0) c.build_protocol = mock.Mock(return_value=p0)
relay_handshake = None relay_handshake = None
f = OutboundConnectionFactory(c, relay_handshake) f = OutboundConnectionFactory(c, relay_handshake, "desc")
addr = object() addr = object()
p = f.buildProtocol(addr) p = f.buildProtocol(addr)
self.assertIdentical(p, p0) self.assertIdentical(p, p0)
self.assertEqual(c.mock_calls, [mock.call.build_protocol(addr)]) self.assertEqual(c.mock_calls, [mock.call.build_protocol(addr, "desc")])
self.assertEqual(p.mock_calls, []) self.assertEqual(p.mock_calls, [])
self.assertIdentical(p.factory, f) self.assertIdentical(p.factory, f)
@ -48,11 +49,11 @@ class Outbound(unittest.TestCase):
p0 = mock.Mock() p0 = mock.Mock()
c.build_protocol = mock.Mock(return_value=p0) c.build_protocol = mock.Mock(return_value=p0)
relay_handshake = b"relay handshake" relay_handshake = b"relay handshake"
f = OutboundConnectionFactory(c, relay_handshake) f = OutboundConnectionFactory(c, relay_handshake, "desc")
addr = object() addr = object()
p = f.buildProtocol(addr) p = f.buildProtocol(addr)
self.assertIdentical(p, p0) self.assertIdentical(p, p0)
self.assertEqual(c.mock_calls, [mock.call.build_protocol(addr)]) self.assertEqual(c.mock_calls, [mock.call.build_protocol(addr, "desc")])
self.assertEqual(p.mock_calls, [mock.call.use_relay(relay_handshake)]) self.assertEqual(p.mock_calls, [mock.call.use_relay(relay_handshake)])
self.assertIdentical(p.factory, f) self.assertIdentical(p.factory, f)
@ -63,10 +64,10 @@ class Inbound(unittest.TestCase):
p0 = mock.Mock() p0 = mock.Mock()
c.build_protocol = mock.Mock(return_value=p0) c.build_protocol = mock.Mock(return_value=p0)
f = InboundConnectionFactory(c) f = InboundConnectionFactory(c)
addr = object() addr = IPv4Address("TCP", "1.2.3.4", 55)
p = f.buildProtocol(addr) p = f.buildProtocol(addr)
self.assertIdentical(p, p0) self.assertIdentical(p, p0)
self.assertEqual(c.mock_calls, [mock.call.build_protocol(addr)]) self.assertEqual(c.mock_calls, [mock.call.build_protocol(addr, "<-tcp:1.2.3.4:55")])
self.assertIdentical(p.factory, f) self.assertIdentical(p.factory, f)
def make_connector(listen=True, tor=False, relay=None, role=roles.LEADER): def make_connector(listen=True, tor=False, relay=None, role=roles.LEADER):
@ -115,13 +116,13 @@ class TestConnector(unittest.TestCase):
return_value=n0) as bn: return_value=n0) as bn:
with mock.patch("wormhole._dilation.connector.DilatedConnectionProtocol", with mock.patch("wormhole._dilation.connector.DilatedConnectionProtocol",
return_value=p0) as dcp: return_value=p0) as dcp:
p = c.build_protocol(addr) p = c.build_protocol(addr, "desc")
self.assertEqual(bn.mock_calls, [mock.call()]) self.assertEqual(bn.mock_calls, [mock.call()])
self.assertEqual(n0.mock_calls, [mock.call.set_psks(h.dilation_key), self.assertEqual(n0.mock_calls, [mock.call.set_psks(h.dilation_key),
mock.call.set_as_initiator()]) mock.call.set_as_initiator()])
self.assertIdentical(p, p0) self.assertIdentical(p, p0)
self.assertEqual(dcp.mock_calls, self.assertEqual(dcp.mock_calls,
[mock.call(h.eq, h.role, c, n0, [mock.call(h.eq, h.role, "desc", c, n0,
PROLOGUE_LEADER, PROLOGUE_FOLLOWER)]) PROLOGUE_LEADER, PROLOGUE_FOLLOWER)])
def test_build_protocol_follower(self): def test_build_protocol_follower(self):
@ -133,13 +134,13 @@ class TestConnector(unittest.TestCase):
return_value=n0) as bn: return_value=n0) as bn:
with mock.patch("wormhole._dilation.connector.DilatedConnectionProtocol", with mock.patch("wormhole._dilation.connector.DilatedConnectionProtocol",
return_value=p0) as dcp: return_value=p0) as dcp:
p = c.build_protocol(addr) p = c.build_protocol(addr, "desc")
self.assertEqual(bn.mock_calls, [mock.call()]) self.assertEqual(bn.mock_calls, [mock.call()])
self.assertEqual(n0.mock_calls, [mock.call.set_psks(h.dilation_key), self.assertEqual(n0.mock_calls, [mock.call.set_psks(h.dilation_key),
mock.call.set_as_responder()]) mock.call.set_as_responder()])
self.assertIdentical(p, p0) self.assertIdentical(p, p0)
self.assertEqual(dcp.mock_calls, self.assertEqual(dcp.mock_calls,
[mock.call(h.eq, h.role, c, n0, [mock.call(h.eq, h.role, "desc", c, n0,
PROLOGUE_FOLLOWER, PROLOGUE_LEADER)]) PROLOGUE_FOLLOWER, PROLOGUE_LEADER)])
def test_start_stop(self): def test_start_stop(self):
@ -244,7 +245,7 @@ class TestConnector(unittest.TestCase):
with mock.patch("wormhole._dilation.connector.OutboundConnectionFactory", with mock.patch("wormhole._dilation.connector.OutboundConnectionFactory",
return_value=f) as ocf: return_value=f) as ocf:
h.clock.advance(1.0) h.clock.advance(1.0)
self.assertEqual(ocf.mock_calls, [mock.call(c, None)]) self.assertEqual(ocf.mock_calls, [mock.call(c, None, "->tcp:foo:55")])
self.assertEqual(ep.connect.mock_calls, [mock.call(f)]) self.assertEqual(ep.connect.mock_calls, [mock.call(f)])
p = mock.Mock() p = mock.Mock()
d.callback(p) d.callback(p)
@ -269,7 +270,7 @@ class TestConnector(unittest.TestCase):
return_value=f) as ocf: return_value=f) as ocf:
h.clock.advance(1.0) h.clock.advance(1.0)
handshake = build_sided_relay_handshake(h.dilation_key, h.side) handshake = build_sided_relay_handshake(h.dilation_key, h.side)
self.assertEqual(ocf.mock_calls, [mock.call(c, handshake)]) self.assertEqual(ocf.mock_calls, [mock.call(c, handshake, "->relay:tcp:foo:55")])
def test_listen_but_tor(self): def test_listen_but_tor(self):
c, h = make_connector(listen=True, tor=True, role=roles.LEADER) c, h = make_connector(listen=True, tor=True, role=roles.LEADER)