From 443d2489722de30188e2b7f6173d28cbf6c2ef28 Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Mon, 8 Jul 2019 01:02:15 -0700 Subject: [PATCH 1/4] manager: call inbound.set_listener_endpoint() before start() This should fix the immediate issue of the remote side opening a subchannel (and sending data on it) before the local side even sees the Endpoints, so before it can register a listening factory to receive the OPEN. We were already buffering early OPENs in the SubchannelListenerEndpoint, but this makes sure that endpoint is available (for the manager's Inbound half to deliver) them as soon as the dilation connection is established. The downside to buffering OPENs (and all data written to inbound subchannels) is that the application has no way to reject or pause them, until it registers the listening factory. If the application never calls `listen_ep.listen()`, we'll buffer this data forever (or until the wormhole is closed). The upside is that we don't lose a roundtrip waiting for an ack on the OPEN. See ticket #335 for more details. refs #335 --- src/wormhole/_dilation/manager.py | 18 +++--- src/wormhole/test/dilate/test_manager.py | 78 ++++++++++++++++-------- 2 files changed, 61 insertions(+), 35 deletions(-) diff --git a/src/wormhole/_dilation/manager.py b/src/wormhole/_dilation/manager.py index 5ea2ae6..14b0c89 100644 --- a/src/wormhole/_dilation/manager.py +++ b/src/wormhole/_dilation/manager.py @@ -542,6 +542,15 @@ class Dilator(object): sc0 = SubChannel(scid0, self._manager, self._host_addr, peer_addr0) self._manager.set_subchannel_zero(scid0, sc0) + # we can open non-zero subchannels as soon as we get our first + # connection, and we can make the Endpoints even earlier + control_ep = ControlEndpoint(peer_addr0) + control_ep._subchannel_zero_opened(sc0) + connect_ep = SubchannelConnectorEndpoint(self._manager, self._host_addr) + + listen_ep = SubchannelListenerEndpoint(self._manager, self._host_addr) + self._manager.set_listener_endpoint(listen_ep) + self._manager.start() while self._pending_inbound_dilate_messages: @@ -550,15 +559,6 @@ class Dilator(object): yield self._manager.when_first_connected() - # we can open non-zero subchannels as soon as we get our first - # connection - control_ep = ControlEndpoint(peer_addr0) - control_ep._subchannel_zero_opened(sc0) - connect_ep = SubchannelConnectorEndpoint(self._manager, self._host_addr) - - listen_ep = SubchannelListenerEndpoint(self._manager, self._host_addr) - self._manager.set_listener_endpoint(listen_ep) - endpoints = (control_ep, connect_ep, listen_ep) returnValue(endpoints) diff --git a/src/wormhole/test/dilate/test_manager.py b/src/wormhole/test/dilate/test_manager.py index e36e403..cea92de 100644 --- a/src/wormhole/test/dilate/test_manager.py +++ b/src/wormhole/test/dilate/test_manager.py @@ -63,6 +63,12 @@ class TestDilator(unittest.TestCase): sc = mock.Mock() m_sc = mock.patch("wormhole._dilation.manager.SubChannel", return_value=sc) + ce = mock.Mock() + m_ce = mock.patch("wormhole._dilation.manager.ControlEndpoint", + return_value=ce) + lep = object() + m_sle = mock.patch("wormhole._dilation.manager.SubchannelListenerEndpoint", + return_value=lep) scid0 = 0 m = mock.Mock() @@ -72,16 +78,21 @@ class TestDilator(unittest.TestCase): return_value=m) as ml: with mock.patch("wormhole._dilation.manager.make_side", return_value="us"): - with m_sca, m_sc as m_sc_m: - dil.got_wormhole_versions({"can-dilate": ["1"]}) + with m_sca, m_sc as m_sc_m, m_ce as m_ce_m, m_sle as m_sle_m: + dil.got_wormhole_versions({"can-dilate": ["1"]}) # that should create the Manager self.assertEqual(ml.mock_calls, [mock.call(send, "us", transit_key, None, reactor, eq, coop, host_addr, False)]) + # and the three endpoints + self.assertEqual(m_ce_m.mock_calls, [mock.call(peer_addr)]) + self.assertEqual(ce.mock_calls, [mock.call._subchannel_zero_opened(sc)]) + self.assertEqual(m_sle_m.mock_calls, [mock.call(m, host_addr)]) # and create subchannel0 self.assertEqual(m_sc_m.mock_calls, [mock.call(scid0, m, host_addr, peer_addr)]) # and tell it to start, and get wait-for-it-to-connect Deferred self.assertEqual(m.mock_calls, [mock.call.set_subchannel_zero(scid0, sc), + mock.call.set_listener_endpoint(lep), mock.call.start(), mock.call.when_first_connected(), ]) @@ -89,22 +100,9 @@ class TestDilator(unittest.TestCase): self.assertNoResult(d1) self.assertNoResult(d2) - ce = mock.Mock() - m_ce = mock.patch("wormhole._dilation.manager.ControlEndpoint", - return_value=ce) - lep = object() - m_sle = mock.patch("wormhole._dilation.manager.SubchannelListenerEndpoint", - return_value=lep) - - with m_ce as m_ce_m, m_sle as m_sle_m: - wfc_d.callback(None) - eq.flush_sync() - self.assertEqual(m_ce_m.mock_calls, [mock.call(peer_addr)]) - self.assertEqual(ce.mock_calls, [mock.call._subchannel_zero_opened(sc)]) - self.assertEqual(m_sle_m.mock_calls, [mock.call(m, host_addr)]) - self.assertEqual(m.mock_calls, - [mock.call.set_listener_endpoint(lep), - ]) + wfc_d.callback(None) + eq.flush_sync() + self.assertEqual(m.mock_calls, []) clear_mock_calls(m) eps = self.successResultOf(d1) @@ -181,16 +179,26 @@ class TestDilator(unittest.TestCase): sc = mock.Mock() m_sc = mock.patch("wormhole._dilation.manager.SubChannel", return_value=sc) + peer_addr = object() + m_sca = mock.patch("wormhole._dilation.manager._SubchannelAddress", + return_value=peer_addr) + ce = mock.Mock() + m_ce = mock.patch("wormhole._dilation.manager.ControlEndpoint", + return_value=ce) + lep = object() + m_sle = mock.patch("wormhole._dilation.manager.SubchannelListenerEndpoint", + return_value=lep) with mock.patch("wormhole._dilation.manager.Manager", return_value=m) as ml: with mock.patch("wormhole._dilation.manager.make_side", return_value="us"): - with m_sc: + with m_sca, m_sc, m_ce, m_sle: dil.got_wormhole_versions({"can-dilate": ["1"]}) self.assertEqual(ml.mock_calls, [mock.call(send, "us", b"key", None, reactor, eq, coop, host_addr, False)]) self.assertEqual(m.mock_calls, [mock.call.set_subchannel_zero(scid0, sc), + mock.call.set_listener_endpoint(lep), mock.call.start(), mock.call.rx_PLEASE(pleasemsg), mock.call.rx_HINTS(hintmsg), @@ -204,21 +212,39 @@ class TestDilator(unittest.TestCase): d1 = dil.dilate(transit_relay_location=relay) self.assertNoResult(d1) + m = mock.Mock() + alsoProvides(m, IDilationManager) + m.when_first_connected.return_value = Deferred() + scid0 = 0 sc = mock.Mock() m_sc = mock.patch("wormhole._dilation.manager.SubChannel", return_value=sc) + peer_addr = object() + m_sca = mock.patch("wormhole._dilation.manager._SubchannelAddress", + return_value=peer_addr) + ce = mock.Mock() + m_ce = mock.patch("wormhole._dilation.manager.ControlEndpoint", + return_value=ce) + lep = object() + m_sle = mock.patch("wormhole._dilation.manager.SubchannelListenerEndpoint", + return_value=lep) - with mock.patch("wormhole._dilation.manager.Manager") as ml: + with mock.patch("wormhole._dilation.manager.Manager", + return_value=m) as ml: with mock.patch("wormhole._dilation.manager.make_side", return_value="us"): - with m_sc: + with m_sca, m_sc, m_ce, m_sle: dil.got_wormhole_versions({"can-dilate": ["1"]}) - self.assertEqual(ml.mock_calls, [mock.call(send, "us", b"key", - relay, reactor, eq, coop, host_addr, False), - mock.call().set_subchannel_zero(scid0, sc), - mock.call().start(), - mock.call().when_first_connected()]) + + self.assertEqual(ml.mock_calls, [ + mock.call(send, "us", b"key", + relay, reactor, eq, coop, host_addr, False)]) + self.assertEqual(m.mock_calls, [ + mock.call.set_subchannel_zero(scid0, sc), + mock.call.set_listener_endpoint(lep), + mock.call.start(), + mock.call.when_first_connected()]) LEADER = "ff3456abcdef" From 75fad02a28eaca12f67c9ea7ac78894b17e4858d Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Mon, 8 Jul 2019 01:02:39 -0700 Subject: [PATCH 2/4] subchannel: queue pending opens with a deque(), not a list slightly cleaner --- src/wormhole/_dilation/subchannel.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/wormhole/_dilation/subchannel.py b/src/wormhole/_dilation/subchannel.py index bcdff6a..32c9cff 100644 --- a/src/wormhole/_dilation/subchannel.py +++ b/src/wormhole/_dilation/subchannel.py @@ -1,4 +1,5 @@ import six +from collections import deque from attr import attrs, attrib from attr.validators import instance_of, provides from zope.interface import implementer @@ -269,9 +270,9 @@ class SubchannelListenerEndpoint(object): def __attrs_post_init__(self): self._factory = None - self._pending_opens = [] + self._pending_opens = deque() - # from manager + # from manager (actually Inbound) def _got_open(self, t, peer_addr): if self._factory: self._connect(t, peer_addr) @@ -287,9 +288,9 @@ class SubchannelListenerEndpoint(object): def listen(self, protocolFactory): self._factory = protocolFactory - for (t, peer_addr) in self._pending_opens: + while self._pending_opens: + (t, peer_addr) = self._pending_opens.popleft() self._connect(t, peer_addr) - self._pending_opens = [] lp = SubchannelListeningPort(self._host_addr) return succeed(lp) From 85cb0034987a3256696626d0d5de6600c92fedba Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Fri, 12 Jul 2019 00:01:55 -0700 Subject: [PATCH 3/4] WIP: rewrite w.dilate API to return endpoints synchronously test_manager still needs rewriting --- src/wormhole/_dilation/manager.py | 270 ++++++++++--------- src/wormhole/_dilation/subchannel.py | 55 +++- src/wormhole/test/dilate/test_endpoints.py | 299 +++++++++++++++++++-- src/wormhole/test/dilate/test_full.py | 45 ++-- src/wormhole/test/dilate/test_manager.py | 28 +- 5 files changed, 500 insertions(+), 197 deletions(-) diff --git a/src/wormhole/_dilation/manager.py b/src/wormhole/_dilation/manager.py index 14b0c89..19430bb 100644 --- a/src/wormhole/_dilation/manager.py +++ b/src/wormhole/_dilation/manager.py @@ -2,13 +2,20 @@ from __future__ import print_function, unicode_literals import six import os from collections import deque +try: + # py >= 3.3 + from collections.abc import Sequence +except ImportError: + # py 2 and py3 < 3.3 + from collections import Sequence from attr import attrs, attrib from attr.validators import provides, instance_of, optional from automat import MethodicalMachine from zope.interface import implementer -from twisted.internet.defer import Deferred, inlineCallbacks, returnValue -from twisted.internet.interfaces import IAddress -from twisted.python import log +from twisted.internet.defer import Deferred +from twisted.internet.interfaces import (IStreamClientEndpoint, + IStreamServerEndpoint) +from twisted.python import log, failure from .._interfaces import IDilator, IDilationManager, ISend, ITerminator from ..util import dict_to_bytes, bytes_to_dict, bytes_to_hexstr from ..observer import OneShotObserver @@ -47,6 +54,15 @@ class UnexpectedKCM(Exception): class UnknownMessageType(Exception): pass +@attrs +class EndpointRecord(Sequence): + control = attrib(validator=provides(IStreamClientEndpoint)) + connect = attrib(validator=provides(IStreamClientEndpoint)) + listen = attrib(validator=provides(IStreamServerEndpoint)) + def __len__(self): + return 3 + def __getitem__(self, n): + return (self.control, self.connect, self.listen)[n] def make_side(): return bytes_to_hexstr(os.urandom(6)) @@ -93,13 +109,14 @@ def make_side(): class Manager(object): _S = attrib(validator=provides(ISend), repr=False) _my_side = attrib(validator=instance_of(type(u""))) - _transit_key = attrib(validator=instance_of(bytes), repr=False) _transit_relay_location = attrib(validator=optional(instance_of(str))) _reactor = attrib(repr=False) _eventual_queue = attrib(repr=False) _cooperator = attrib(repr=False) - _host_addr = attrib(validator=provides(IAddress)) - _no_listen = attrib(default=False) + # TODO: can this validator work when the parameter is optional? + _no_listen = attrib(validator=instance_of(bool), default=False) + + _dilation_key = None _tor = None # TODO _timing = None # TODO _next_subchannel_id = None # initialized in choose_role @@ -111,10 +128,10 @@ class Manager(object): self._got_versions_d = Deferred() self._my_role = None # determined upon rx_PLEASE + self._host_addr = _WormholeAddress() self._connection = None self._made_first_connection = False - self._first_connected = OneShotObserver(self._eventual_queue) self._stopped = OneShotObserver(self._eventual_queue) self._debug_stall_connector = False @@ -127,18 +144,81 @@ class Manager(object): self._inbound = Inbound(self, self._host_addr) self._outbound = Outbound(self, self._cooperator) # from us to peer - def set_listener_endpoint(self, listener_endpoint): - self._inbound.set_listener_endpoint(listener_endpoint) - - def set_subchannel_zero(self, scid0, sc0): + # We must open subchannel0 early, since messages may arrive very + # quickly once the connection is established. This subchannel may or + # may not ever get revealed to the caller, since the peer might not + # even be capable of dilation. + scid0 = 0 + peer_addr0 = _SubchannelAddress(scid0) + sc0 = SubChannel(scid0, self, self._host_addr, peer_addr0) self._inbound.set_subchannel_zero(scid0, sc0) - def when_first_connected(self): - return self._first_connected.when_fired() + # we can open non-zero subchannels as soon as we get our first + # connection, and we can make the Endpoints even earlier + control_ep = ControlEndpoint(peer_addr0, sc0, self._eventual_queue) + connect_ep = SubchannelConnectorEndpoint(self, self._host_addr, self._eventual_queue) + listen_ep = SubchannelListenerEndpoint(self, self._host_addr, self._eventual_queue) + # TODO: let inbound/outbound create the endpoints, then return them + # to us + self._inbound.set_listener_endpoint(listen_ep) + + self._endpoints = EndpointRecord(control_ep, connect_ep, listen_ep) + + def get_endpoints(self): + return self._endpoints + + def got_dilation_key(self, key): + assert isinstance(key, bytes) + self._dilation_key = key + + def got_wormhole_versions(self, their_wormhole_versions): + # this always happens before received_dilation_message + dilation_version = None + their_dilation_versions = set(their_wormhole_versions.get("can-dilate", [])) + my_versions = set(DILATION_VERSIONS) + shared_versions = my_versions.intersection(their_dilation_versions) + if "1" in shared_versions: + dilation_version = "1" + + # dilation_version is the best mutually-compatible version we have + # with the peer, or None if we have nothing in common + + if not dilation_version: # "1" or None + # TODO: be more specific about the error. dilation_version==None + # means we had no version in common with them, which could either + # be because they're so old they don't dilate at all, or because + # they're so new that they no longer accomodate our old version + self.fail(failure.Failure(OldPeerCannotDilateError())) + + self.start() + + def fail(self, f): + self._endpoints.control._main_channel_failed(f) + self._endpoints.connect._main_channel_failed(f) + self._endpoints.listen._main_channel_failed(f) + + def received_dilation_message(self, plaintext): + # this receives new in-order DILATE-n payloads, decrypted but not + # de-JSONed. + + message = bytes_to_dict(plaintext) + type = message["type"] + if type == "please": + self.rx_PLEASE(message) + elif type == "connection-hints": + self.rx_HINTS(message) + elif type == "reconnect": + self.rx_RECONNECT() + elif type == "reconnecting": + self.rx_RECONNECTING() + else: + log.err(UnknownDilationMessageType(message)) + return def when_stopped(self): return self._stopped.when_fired() + def send_dilation_phase(self, **fields): dilation_phase = self._next_dilation_phase self._next_dilation_phase += 1 @@ -204,7 +284,9 @@ class Manager(object): self._outbound.use_connection(c) # does c.registerProducer if not self._made_first_connection: self._made_first_connection = True - self._first_connected.fire(None) + self._endpoints.control._main_channel_ready() + self._endpoints.connect._main_channel_ready() + self._endpoints.listen._main_channel_ready() pass def connector_connection_lost(self): @@ -272,16 +354,11 @@ class Manager(object): # state machine - # We are born WANTING after the local app calls w.dilate(). We start - # CONNECTING when we receive PLEASE from the remote side - - def start(self): - self.send_please() - - def send_please(self): - self.send_dilation_phase(type="please", side=self._my_side) - @m.state(initial=True) + def WAITING(self): + pass # pragma: no cover + + @m.state() def WANTING(self): pass # pragma: no cover @@ -313,6 +390,10 @@ class Manager(object): def STOPPED(self): pass # pragma: no cover + @m.input() + def start(self): + pass # pragma: no cover + @m.input() def rx_PLEASE(self, message): pass # pragma: no cover @@ -350,6 +431,10 @@ class Manager(object): def stop(self): pass # pragma: no cover + @m.output() + def send_please(self): + self.send_dilation_phase(type="please", side=self._my_side) + @m.output() def choose_role(self, message): their_side = message["side"] @@ -378,7 +463,8 @@ class Manager(object): def _start_connecting(self): assert self._my_role is not None - self._connector = Connector(self._transit_key, + assert self._dilation_key is not None + self._connector = Connector(self._dilation_key, self._transit_relay_location, self, self._reactor, self._eventual_queue, @@ -422,6 +508,11 @@ class Manager(object): def notify_stopped(self): self._stopped.fire(None) + # We are born WAITING after the local app calls w.dilate(). We enter + # WANTING (and send a PLEASE) when we learn of a mutually-compatible + # dilation_version. + WAITING.upon(start, enter=WANTING, outputs=[send_please]) + # we start CONNECTING when we get rx_PLEASE WANTING.upon(rx_PLEASE, enter=CONNECTING, outputs=[choose_role, start_connecting_ignore_message]) @@ -489,12 +580,10 @@ class Dilator(object): _cooperator = attrib() def __attrs_post_init__(self): - self._got_versions_d = Deferred() - self._started = False - self._endpoints = OneShotObserver(self._eventual_queue) - self._pending_inbound_dilate_messages = deque() self._manager = None - self._host_addr = _WormholeAddress() + self._pending_dilation_key = None + self._pending_wormhole_versions = None + self._pending_inbound_dilate_messages = deque() def wire(self, sender, terminator): self._S = ISend(sender) @@ -502,77 +591,35 @@ class Dilator(object): # this is the primary entry point, called when w.dilate() is invoked def dilate(self, transit_relay_location=None, no_listen=False): - self._transit_relay_location = transit_relay_location - if not self._started: - self._started = True - self._start(no_listen).addBoth(self._endpoints.fire) - return self._endpoints.when_fired() - - @inlineCallbacks - def _start(self, no_listen): - # first, we wait until we hear the VERSION message, which tells us 1: - # the PAKE key works, so we can talk securely, 2: that they can do - # dilation at all (if they can't then w.dilate() errbacks) - - dilation_version = yield self._got_versions_d - - # TODO: we could probably return the endpoints earlier, if we flunk - # any connection/listen attempts upon OldPeerCannotDilateError, or - # if/when we give up on the initial connection - - if not dilation_version: # "1" or None - # TODO: be more specific about the error. dilation_version==None - # means we had no version in common with them, which could either - # be because they're so old they don't dilate at all, or because - # they're so new that they no longer accomodate our old version - raise OldPeerCannotDilateError() - - my_dilation_side = make_side() - self._manager = Manager(self._S, my_dilation_side, - self._transit_key, - self._transit_relay_location, - self._reactor, self._eventual_queue, - self._cooperator, self._host_addr, no_listen) - # We must open subchannel0 early, since messages may arrive very - # quickly once the connection is established. This subchannel may or - # may not ever get revealed to the caller, since the peer might not - # even be capable of dilation. - scid0 = 0 - peer_addr0 = _SubchannelAddress(scid0) - sc0 = SubChannel(scid0, self._manager, self._host_addr, peer_addr0) - self._manager.set_subchannel_zero(scid0, sc0) - - # we can open non-zero subchannels as soon as we get our first - # connection, and we can make the Endpoints even earlier - control_ep = ControlEndpoint(peer_addr0) - control_ep._subchannel_zero_opened(sc0) - connect_ep = SubchannelConnectorEndpoint(self._manager, self._host_addr) - - listen_ep = SubchannelListenerEndpoint(self._manager, self._host_addr) - self._manager.set_listener_endpoint(listen_ep) - - self._manager.start() - - while self._pending_inbound_dilate_messages: - plaintext = self._pending_inbound_dilate_messages.popleft() - self.received_dilate(plaintext) - - yield self._manager.when_first_connected() - - endpoints = (control_ep, connect_ep, listen_ep) - returnValue(endpoints) + if not self._manager: + # build the manager right away, and tell it later when the + # VERSIONS message arrives, and also when the dilation_key is set + my_dilation_side = make_side() + m = Manager(self._S, my_dilation_side, + transit_relay_location, + self._reactor, self._eventual_queue, + self._cooperator, no_listen) + self._manager = m + if self._pending_dilation_key is not None: + m.got_dilation_key(self._pending_dilation_key) + if self._pending_wormhole_versions: + m.got_wormhole_versions(self._pending_wormhole_versions) + while self._pending_inbound_dilate_messages: + plaintext = self._pending_inbound_dilate_messages.popleft() + m.received_dilation_message(plaintext) + return self._manager.get_endpoints() # Called by Terminator after everything else (mailbox, nameplate, server # connection) has shut down. Expects to fire T.stoppedD() when Dilator is # stopped too. def stop(self): - if not self._started: - self._T.stoppedD() - return - if self._started: + if self._manager: self._manager.stop() # TODO: avoid Deferreds for control flow, hard to serialize self._manager.when_stopped().addCallback(lambda _: self._T.stoppedD()) + else: + self._T.stoppedD() + return # TODO: tolerate multiple calls # from Boss @@ -582,39 +629,20 @@ class Dilator(object): # to tolerate either ordering purpose = b"dilation-v1" LENGTH = 32 # TODO: whatever Noise wants, I guess - self._transit_key = derive_key(key, purpose, LENGTH) + dilation_key = derive_key(key, purpose, LENGTH) + if self._manager: + self._manager.got_dilation_key(dilation_key) + else: + self._pending_dilation_key = dilation_key def got_wormhole_versions(self, their_wormhole_versions): - assert self._transit_key is not None - # this always happens before received_dilate - dilation_version = None - their_dilation_versions = set(their_wormhole_versions.get("can-dilate", [])) - my_versions = set(DILATION_VERSIONS) - shared_versions = my_versions.intersection(their_dilation_versions) - if "1" in shared_versions: - dilation_version = "1" - self._got_versions_d.callback(dilation_version) + if self._manager: + self._manager.got_wormhole_versions(their_wormhole_versions) + else: + self._pending_wormhole_versions = their_wormhole_versions def received_dilate(self, plaintext): - # this receives new in-order DILATE-n payloads, decrypted but not - # de-JSONed. - - # this can appear before our .dilate() method is called, in which case - # we queue them for later if not self._manager: self._pending_inbound_dilate_messages.append(plaintext) - return - - message = bytes_to_dict(plaintext) - type = message["type"] - if type == "please": - self._manager.rx_PLEASE(message) - elif type == "connection-hints": - self._manager.rx_HINTS(message) - elif type == "reconnect": - self._manager.rx_RECONNECT() - elif type == "reconnecting": - self._manager.rx_RECONNECTING() else: - log.err(UnknownDilationMessageType(message)) - return + self._manager.received_dilation_message(plaintext) diff --git a/src/wormhole/_dilation/subchannel.py b/src/wormhole/_dilation/subchannel.py index 32c9cff..e1e811c 100644 --- a/src/wormhole/_dilation/subchannel.py +++ b/src/wormhole/_dilation/subchannel.py @@ -3,8 +3,7 @@ from collections import deque from attr import attrs, attrib from attr.validators import instance_of, provides from zope.interface import implementer -from twisted.internet.defer import (Deferred, inlineCallbacks, returnValue, - succeed) +from twisted.internet.defer import inlineCallbacks, returnValue from twisted.internet.interfaces import (ITransport, IProducer, IConsumer, IAddress, IListeningPort, IStreamClientEndpoint, @@ -12,6 +11,7 @@ from twisted.internet.interfaces import (ITransport, IProducer, IConsumer, from twisted.internet.error import ConnectionDone from automat import MethodicalMachine from .._interfaces import ISubChannel, IDilationManager +from ..observer import OneShotObserver # each subchannel frame (the data passed into transport.write(data)) gets a # 9-byte header prefix (type, subchannel id, and sequence number), then gets @@ -217,27 +217,33 @@ class SubChannel(object): @implementer(IStreamClientEndpoint) +@attrs class ControlEndpoint(object): + _peer_addr = attrib(validator=provides(IAddress)) + _subchannel_zero = attrib(validator=provides(ISubChannel)) + _eventual_queue = attrib(repr=False) _used = False - def __init__(self, peer_addr): - self._subchannel_zero = Deferred() - self._peer_addr = peer_addr + def __attrs_post_init__(self): self._once = Once(SingleUseEndpointError) + self._wait_for_main_channel = OneShotObserver(self._eventual_queue) # from manager - def _subchannel_zero_opened(self, subchannel): - assert ISubChannel.providedBy(subchannel), subchannel - self._subchannel_zero.callback(subchannel) + + def _main_channel_ready(self): + self._wait_for_main_channel.fire(None) + def _main_channel_failed(self, f): + self._wait_for_main_channel.error(f) @inlineCallbacks def connect(self, protocolFactory): # return Deferred that fires with IProtocol or Failure(ConnectError) self._once() - t = yield self._subchannel_zero + yield self._wait_for_main_channel.when_fired() p = protocolFactory.buildProtocol(self._peer_addr) - t._set_protocol(p) - p.makeConnection(t) # set p.transport = t and call connectionMade() + self._subchannel_zero._set_protocol(p) + # this sets p.transport and calls p.connectionMade() + p.makeConnection(self._subchannel_zero) returnValue(p) @@ -246,9 +252,21 @@ class ControlEndpoint(object): class SubchannelConnectorEndpoint(object): _manager = attrib(validator=provides(IDilationManager)) _host_addr = attrib(validator=instance_of(_WormholeAddress)) + _eventual_queue = attrib(repr=False) + def __attrs_post_init__(self): + self._connection_deferreds = deque() + self._wait_for_main_channel = OneShotObserver(self._eventual_queue) + + def _main_channel_ready(self): + self._wait_for_main_channel.fire(None) + def _main_channel_failed(self, f): + self._wait_for_main_channel.error(f) + + @inlineCallbacks def connect(self, protocolFactory): # return Deferred that fires with IProtocol or Failure(ConnectError) + yield self._wait_for_main_channel.when_fired() scid = self._manager.allocate_subchannel_id() self._manager.send_open(scid) peer_addr = _SubchannelAddress(scid) @@ -259,7 +277,7 @@ class SubchannelConnectorEndpoint(object): p = protocolFactory.buildProtocol(peer_addr) sc._set_protocol(p) p.makeConnection(sc) # set p.transport = sc and call connectionMade() - return succeed(p) + returnValue(p) @implementer(IStreamServerEndpoint) @@ -267,10 +285,13 @@ class SubchannelConnectorEndpoint(object): class SubchannelListenerEndpoint(object): _manager = attrib(validator=provides(IDilationManager)) _host_addr = attrib(validator=provides(IAddress)) + _eventual_queue = attrib(repr=False) def __attrs_post_init__(self): + self._once = Once(SingleUseEndpointError) self._factory = None self._pending_opens = deque() + self._wait_for_main_channel = OneShotObserver(self._eventual_queue) # from manager (actually Inbound) def _got_open(self, t, peer_addr): @@ -284,15 +305,23 @@ class SubchannelListenerEndpoint(object): t._set_protocol(p) p.makeConnection(t) + def _main_channel_ready(self): + self._wait_for_main_channel.fire(None) + def _main_channel_failed(self, f): + self._wait_for_main_channel.error(f) + # IStreamServerEndpoint + @inlineCallbacks def listen(self, protocolFactory): + self._once() + yield self._wait_for_main_channel.when_fired() self._factory = protocolFactory while self._pending_opens: (t, peer_addr) = self._pending_opens.popleft() self._connect(t, peer_addr) lp = SubchannelListeningPort(self._host_addr) - return succeed(lp) + returnValue(lp) @implementer(IListeningPort) diff --git a/src/wormhole/test/dilate/test_endpoints.py b/src/wormhole/test/dilate/test_endpoints.py index 08c5600..647fb9d 100644 --- a/src/wormhole/test/dilate/test_endpoints.py +++ b/src/wormhole/test/dilate/test_endpoints.py @@ -2,7 +2,10 @@ from __future__ import print_function, unicode_literals import mock from zope.interface import alsoProvides from twisted.trial import unittest +from twisted.internet.task import Clock +from twisted.python.failure import Failure from ..._interfaces import ISubChannel +from ...eventual import EventualQueue from ..._dilation.subchannel import (ControlEndpoint, SubchannelConnectorEndpoint, SubchannelListenerEndpoint, @@ -11,12 +14,18 @@ from ..._dilation.subchannel import (ControlEndpoint, SingleUseEndpointError) from .common import mock_manager +class CannotDilateError(Exception): + pass -class Endpoints(unittest.TestCase): - def test_control(self): +class Control(unittest.TestCase): + def test_early_succeed(self): + # ep.connect() is called before dilation can proceed scid0 = 0 peeraddr = _SubchannelAddress(scid0) - ep = ControlEndpoint(peeraddr) + sc0 = mock.Mock() + alsoProvides(sc0, ISubChannel) + eq = EventualQueue(Clock()) + ep = ControlEndpoint(peeraddr, sc0, eq) f = mock.Mock() p = mock.Mock() @@ -24,29 +33,105 @@ class Endpoints(unittest.TestCase): d = ep.connect(f) self.assertNoResult(d) - t = mock.Mock() - alsoProvides(t, ISubChannel) - ep._subchannel_zero_opened(t) + ep._main_channel_ready() + eq.flush_sync() + self.assertIdentical(self.successResultOf(d), p) self.assertEqual(f.buildProtocol.mock_calls, [mock.call(peeraddr)]) - self.assertEqual(t.mock_calls, [mock.call._set_protocol(p)]) - self.assertEqual(p.mock_calls, [mock.call.makeConnection(t)]) + self.assertEqual(sc0.mock_calls, [mock.call._set_protocol(p)]) + self.assertEqual(p.mock_calls, [mock.call.makeConnection(sc0)]) d = ep.connect(f) self.failureResultOf(d, SingleUseEndpointError) - def assert_makeConnection(self, mock_calls): + def test_early_fail(self): + # ep.connect() is called before dilation is abandoned + scid0 = 0 + peeraddr = _SubchannelAddress(scid0) + sc0 = mock.Mock() + alsoProvides(sc0, ISubChannel) + eq = EventualQueue(Clock()) + ep = ControlEndpoint(peeraddr, sc0, eq) + + f = mock.Mock() + p = mock.Mock() + f.buildProtocol = mock.Mock(return_value=p) + d = ep.connect(f) + self.assertNoResult(d) + + ep._main_channel_failed(Failure(CannotDilateError())) + eq.flush_sync() + + self.failureResultOf(d).check(CannotDilateError) + self.assertEqual(f.buildProtocol.mock_calls, []) + self.assertEqual(sc0.mock_calls, []) + + d = ep.connect(f) + self.failureResultOf(d, SingleUseEndpointError) + + def test_late_succeed(self): + # dilation can proceed, then ep.connect() is called + scid0 = 0 + peeraddr = _SubchannelAddress(scid0) + sc0 = mock.Mock() + alsoProvides(sc0, ISubChannel) + eq = EventualQueue(Clock()) + ep = ControlEndpoint(peeraddr, sc0, eq) + + ep._main_channel_ready() + + f = mock.Mock() + p = mock.Mock() + f.buildProtocol = mock.Mock(return_value=p) + d = ep.connect(f) + eq.flush_sync() + self.assertIdentical(self.successResultOf(d), p) + self.assertEqual(f.buildProtocol.mock_calls, [mock.call(peeraddr)]) + self.assertEqual(sc0.mock_calls, [mock.call._set_protocol(p)]) + self.assertEqual(p.mock_calls, [mock.call.makeConnection(sc0)]) + + d = ep.connect(f) + self.failureResultOf(d, SingleUseEndpointError) + + def test_late_fail(self): + # dilation is abandoned, then ep.connect() is called + scid0 = 0 + peeraddr = _SubchannelAddress(scid0) + sc0 = mock.Mock() + alsoProvides(sc0, ISubChannel) + eq = EventualQueue(Clock()) + ep = ControlEndpoint(peeraddr, sc0, eq) + + ep._main_channel_failed(Failure(CannotDilateError())) + + f = mock.Mock() + p = mock.Mock() + f.buildProtocol = mock.Mock(return_value=p) + d = ep.connect(f) + eq.flush_sync() + + self.failureResultOf(d).check(CannotDilateError) + self.assertEqual(f.buildProtocol.mock_calls, []) + self.assertEqual(sc0.mock_calls, []) + + d = ep.connect(f) + self.failureResultOf(d, SingleUseEndpointError) + +class Endpoints(unittest.TestCase): + def OFFassert_makeConnection(self, mock_calls): self.assertEqual(len(mock_calls), 1) self.assertEqual(mock_calls[0][0], "makeConnection") self.assertEqual(len(mock_calls[0][1]), 1) return mock_calls[0][1][0] - def test_connector(self): +class Connector(unittest.TestCase): + def test_early_succeed(self): m = mock_manager() m.allocate_subchannel_id = mock.Mock(return_value=0) hostaddr = _WormholeAddress() peeraddr = _SubchannelAddress(0) - ep = SubchannelConnectorEndpoint(m, hostaddr) + eq = EventualQueue(Clock()) + ep = SubchannelConnectorEndpoint(m, hostaddr, eq) f = mock.Mock() p = mock.Mock() @@ -55,38 +140,123 @@ class Endpoints(unittest.TestCase): with mock.patch("wormhole._dilation.subchannel.SubChannel", return_value=t) as sc: d = ep.connect(f) + eq.flush_sync() + self.assertNoResult(d) + ep._main_channel_ready() + eq.flush_sync() + self.assertIdentical(self.successResultOf(d), p) self.assertEqual(f.buildProtocol.mock_calls, [mock.call(peeraddr)]) self.assertEqual(sc.mock_calls, [mock.call(0, m, hostaddr, peeraddr)]) self.assertEqual(t.mock_calls, [mock.call._set_protocol(p)]) self.assertEqual(p.mock_calls, [mock.call.makeConnection(t)]) - def test_listener(self): + def test_early_fail(self): m = mock_manager() m.allocate_subchannel_id = mock.Mock(return_value=0) hostaddr = _WormholeAddress() - ep = SubchannelListenerEndpoint(m, hostaddr) + eq = EventualQueue(Clock()) + ep = SubchannelConnectorEndpoint(m, hostaddr, eq) + + f = mock.Mock() + p = mock.Mock() + t = mock.Mock() + f.buildProtocol = mock.Mock(return_value=p) + with mock.patch("wormhole._dilation.subchannel.SubChannel", + return_value=t) as sc: + d = ep.connect(f) + eq.flush_sync() + self.assertNoResult(d) + ep._main_channel_failed(Failure(CannotDilateError())) + eq.flush_sync() + + self.failureResultOf(d).check(CannotDilateError) + self.assertEqual(f.buildProtocol.mock_calls, []) + self.assertEqual(sc.mock_calls, []) + self.assertEqual(t.mock_calls, []) + + def test_late_succeed(self): + m = mock_manager() + m.allocate_subchannel_id = mock.Mock(return_value=0) + hostaddr = _WormholeAddress() + peeraddr = _SubchannelAddress(0) + eq = EventualQueue(Clock()) + ep = SubchannelConnectorEndpoint(m, hostaddr, eq) + ep._main_channel_ready() + + f = mock.Mock() + p = mock.Mock() + t = mock.Mock() + f.buildProtocol = mock.Mock(return_value=p) + with mock.patch("wormhole._dilation.subchannel.SubChannel", + return_value=t) as sc: + d = ep.connect(f) + eq.flush_sync() + + self.assertIdentical(self.successResultOf(d), p) + self.assertEqual(f.buildProtocol.mock_calls, [mock.call(peeraddr)]) + self.assertEqual(sc.mock_calls, [mock.call(0, m, hostaddr, peeraddr)]) + self.assertEqual(t.mock_calls, [mock.call._set_protocol(p)]) + self.assertEqual(p.mock_calls, [mock.call.makeConnection(t)]) + + def test_late_fail(self): + m = mock_manager() + m.allocate_subchannel_id = mock.Mock(return_value=0) + hostaddr = _WormholeAddress() + eq = EventualQueue(Clock()) + ep = SubchannelConnectorEndpoint(m, hostaddr, eq) + ep._main_channel_failed(Failure(CannotDilateError())) + + f = mock.Mock() + p = mock.Mock() + t = mock.Mock() + f.buildProtocol = mock.Mock(return_value=p) + with mock.patch("wormhole._dilation.subchannel.SubChannel", + return_value=t) as sc: + d = ep.connect(f) + eq.flush_sync() + + self.failureResultOf(d).check(CannotDilateError) + self.assertEqual(f.buildProtocol.mock_calls, []) + self.assertEqual(sc.mock_calls, []) + self.assertEqual(t.mock_calls, []) + +class Listener(unittest.TestCase): + def test_early_succeed(self): + # listen, main_channel_ready, got_open, got_open + m = mock_manager() + m.allocate_subchannel_id = mock.Mock(return_value=0) + hostaddr = _WormholeAddress() + eq = EventualQueue(Clock()) + ep = SubchannelListenerEndpoint(m, hostaddr, eq) f = mock.Mock() p1 = mock.Mock() p2 = mock.Mock() f.buildProtocol = mock.Mock(side_effect=[p1, p2]) - # OPEN that arrives before we ep.listen() should be queued + d = ep.listen(f) + eq.flush_sync() + self.assertNoResult(d) + self.assertEqual(f.buildProtocol.mock_calls, []) + + ep._main_channel_ready() + eq.flush_sync() + lp = self.successResultOf(d) + self.assertIsInstance(lp, SubchannelListeningPort) + + self.assertEqual(lp.getHost(), hostaddr) + # TODO: IListeningPort says we must provide this, but I don't know + # that anyone would ever call it. + lp.startListening() t1 = mock.Mock() peeraddr1 = _SubchannelAddress(1) ep._got_open(t1, peeraddr1) - d = ep.listen(f) - lp = self.successResultOf(d) - self.assertIsInstance(lp, SubchannelListeningPort) - - self.assertEqual(lp.getHost(), hostaddr) - lp.startListening() - self.assertEqual(t1.mock_calls, [mock.call._set_protocol(p1)]) self.assertEqual(p1.mock_calls, [mock.call.makeConnection(t1)]) + self.assertEqual(f.buildProtocol.mock_calls, [mock.call(peeraddr1)]) t2 = mock.Mock() peeraddr2 = _SubchannelAddress(2) @@ -94,5 +264,92 @@ class Endpoints(unittest.TestCase): self.assertEqual(t2.mock_calls, [mock.call._set_protocol(p2)]) self.assertEqual(p2.mock_calls, [mock.call.makeConnection(t2)]) + self.assertEqual(f.buildProtocol.mock_calls, [mock.call(peeraddr1), + mock.call(peeraddr2)]) lp.stopListening() # TODO: should this do more? + + def test_early_fail(self): + # listen, main_channel_fail + m = mock_manager() + m.allocate_subchannel_id = mock.Mock(return_value=0) + hostaddr = _WormholeAddress() + eq = EventualQueue(Clock()) + ep = SubchannelListenerEndpoint(m, hostaddr, eq) + + f = mock.Mock() + p1 = mock.Mock() + p2 = mock.Mock() + f.buildProtocol = mock.Mock(side_effect=[p1, p2]) + + d = ep.listen(f) + eq.flush_sync() + self.assertNoResult(d) + + ep._main_channel_failed(Failure(CannotDilateError())) + eq.flush_sync() + self.failureResultOf(d).check(CannotDilateError) + self.assertEqual(f.buildProtocol.mock_calls, []) + + def test_late_succeed(self): + # main_channel_ready, got_open, listen, got_open + m = mock_manager() + m.allocate_subchannel_id = mock.Mock(return_value=0) + hostaddr = _WormholeAddress() + eq = EventualQueue(Clock()) + ep = SubchannelListenerEndpoint(m, hostaddr, eq) + ep._main_channel_ready() + + f = mock.Mock() + p1 = mock.Mock() + p2 = mock.Mock() + f.buildProtocol = mock.Mock(side_effect=[p1, p2]) + + t1 = mock.Mock() + peeraddr1 = _SubchannelAddress(1) + ep._got_open(t1, peeraddr1) + eq.flush_sync() + + self.assertEqual(t1.mock_calls, []) + self.assertEqual(p1.mock_calls, []) + + d = ep.listen(f) + eq.flush_sync() + lp = self.successResultOf(d) + self.assertIsInstance(lp, SubchannelListeningPort) + self.assertEqual(lp.getHost(), hostaddr) + lp.startListening() + + self.assertEqual(t1.mock_calls, [mock.call._set_protocol(p1)]) + self.assertEqual(p1.mock_calls, [mock.call.makeConnection(t1)]) + self.assertEqual(f.buildProtocol.mock_calls, [mock.call(peeraddr1)]) + + t2 = mock.Mock() + peeraddr2 = _SubchannelAddress(2) + ep._got_open(t2, peeraddr2) + + self.assertEqual(t2.mock_calls, [mock.call._set_protocol(p2)]) + self.assertEqual(p2.mock_calls, [mock.call.makeConnection(t2)]) + self.assertEqual(f.buildProtocol.mock_calls, [mock.call(peeraddr1), + mock.call(peeraddr2)]) + + lp.stopListening() # TODO: should this do more? + + def test_late_fail(self): + # main_channel_fail, listen + m = mock_manager() + m.allocate_subchannel_id = mock.Mock(return_value=0) + hostaddr = _WormholeAddress() + eq = EventualQueue(Clock()) + ep = SubchannelListenerEndpoint(m, hostaddr, eq) + ep._main_channel_failed(Failure(CannotDilateError())) + + f = mock.Mock() + p1 = mock.Mock() + p2 = mock.Mock() + f.buildProtocol = mock.Mock(side_effect=[p1, p2]) + + d = ep.listen(f) + eq.flush_sync() + self.failureResultOf(d).check(CannotDilateError) + self.assertEqual(f.buildProtocol.mock_calls, []) diff --git a/src/wormhole/test/dilate/test_full.py b/src/wormhole/test/dilate/test_full.py index cc0e8b0..79c0019 100644 --- a/src/wormhole/test/dilate/test_full.py +++ b/src/wormhole/test/dilate/test_full.py @@ -46,24 +46,21 @@ class Full(ServerBase, unittest.TestCase): yield doBoth(w1.get_verifier(), w2.get_verifier()) print("connected") - eps1_d = w1.dilate() - eps2_d = w2.dilate() - (eps1, eps2) = yield doBoth(eps1_d, eps2_d) - (control_ep1, connect_ep1, listen_ep1) = eps1 - (control_ep2, connect_ep2, listen_ep2) = eps2 + eps1 = w1.dilate() + eps2 = w2.dilate() print("w.dilate ready") f1 = Factory() f1.protocol = L f1.d = Deferred() f1.d.addCallback(lambda data: eq.fire_eventually(data)) - d1 = control_ep1.connect(f1) + d1 = eps1.control.connect(f1) f2 = Factory() f2.protocol = L f2.d = Deferred() f2.d.addCallback(lambda data: eq.fire_eventually(data)) - d2 = control_ep2.connect(f2) + d2 = eps2.control.connect(f2) yield d1 yield d2 print("control endpoints connected") @@ -125,14 +122,12 @@ class Reconnect(ServerBase, unittest.TestCase): w2.set_code(code) yield doBoth(w1.get_verifier(), w2.get_verifier()) - eps1_d = w1.dilate() - eps2_d = w2.dilate() - (eps1, eps2) = yield doBoth(eps1_d, eps2_d) - (control_ep1, connect_ep1, listen_ep1) = eps1 - (control_ep2, connect_ep2, listen_ep2) = eps2 + eps1 = w1.dilate() + eps2 = w2.dilate() + print("w.dilate ready") f1 = ReconF(eq); f2 = ReconF(eq) - d1 = control_ep1.connect(f1); d2 = control_ep2.connect(f2) + d1 = eps1.control.connect(f1); d2 = eps2.control.connect(f2) yield d1 yield d2 @@ -194,14 +189,12 @@ class Reconnect(ServerBase, unittest.TestCase): w2.set_code(code) yield doBoth(w1.get_verifier(), w2.get_verifier()) - eps1_d = w1.dilate() - eps2_d = w2.dilate() - (eps1, eps2) = yield doBoth(eps1_d, eps2_d) - (control_ep1, connect_ep1, listen_ep1) = eps1 - (control_ep2, connect_ep2, listen_ep2) = eps2 + eps1 = w1.dilate() + eps2 = w2.dilate() + print("w.dilate ready") f1 = ReconF(eq); f2 = ReconF(eq) - d1 = control_ep1.connect(f1); d2 = control_ep2.connect(f2) + d1 = eps1.control.connect(f1); d2 = eps2.control.connect(f2) yield d1 yield d2 @@ -287,19 +280,17 @@ class Endpoints(ServerBase, unittest.TestCase): w2.set_code(code) yield doBoth(w1.get_verifier(), w2.get_verifier()) - eps1_d = w1.dilate() - eps2_d = w2.dilate() - (eps1, eps2) = yield doBoth(eps1_d, eps2_d) - (control_ep1, connect_ep1, listen_ep1) = eps1 - (control_ep2, connect_ep2, listen_ep2) = eps2 + eps1 = w1.dilate() + eps2 = w2.dilate() + print("w.dilate ready") f0 = ReconF(eq) - yield listen_ep2.listen(f0) + yield eps2.listen.listen(f0) from twisted.python import log f1 = ReconF(eq) log.msg("connecting") - p1_client = yield connect_ep1.connect(f1) + p1_client = yield eps1.connect.connect(f1) log.msg("sending c->s") p1_client.transport.write(b"hello from p1\n") data = yield f0.deferreds["dataReceived"] @@ -316,7 +307,7 @@ class Endpoints(ServerBase, unittest.TestCase): f0.resetDeferred("dataReceived") f1.resetDeferred("dataReceived") f2 = ReconF(eq) - p2_client = yield connect_ep1.connect(f2) + p2_client = yield eps1.connect.connect(f2) p2_server = yield f0.deferreds["connectionMade"] p2_server.transport.write(b"hello p2\n") data = yield f2.deferreds["dataReceived"] diff --git a/src/wormhole/test/dilate/test_manager.py b/src/wormhole/test/dilate/test_manager.py index cea92de..309c1fd 100644 --- a/src/wormhole/test/dilate/test_manager.py +++ b/src/wormhole/test/dilate/test_manager.py @@ -140,25 +140,23 @@ class TestDilator(unittest.TestCase): def test_peer_cannot_dilate(self): dil, send, reactor, eq, clock, coop = make_dilator() - d1 = dil.dilate() - self.assertNoResult(d1) + eps = dil.dilate() - dil._transit_key = b"\x01" * 32 + dil.got_key(b"\x01" * 32) dil.got_wormhole_versions({}) # missing "can-dilate" + d = eps.connect.connect(None) eq.flush_sync() - f = self.failureResultOf(d1) - f.check(OldPeerCannotDilateError) + self.failureResultOf(d).check(OldPeerCannotDilateError) def test_disjoint_versions(self): dil, send, reactor, eq, clock, coop = make_dilator() - d1 = dil.dilate() - self.assertNoResult(d1) + eps = dil.dilate() - dil._transit_key = b"key" + dil.got_key(b"\x01" * 32) dil.got_wormhole_versions({"can-dilate": [-1]}) + d = eps.connect.connect(None) eq.flush_sync() - f = self.failureResultOf(d1) - f.check(OldPeerCannotDilateError) + self.failureResultOf(d).check(OldPeerCannotDilateError) def test_early_dilate_messages(self): dil, send, reactor, eq, clock, coop = make_dilator() @@ -276,11 +274,11 @@ def make_manager(leader=True): h.Inbound = mock.Mock(return_value=h.inbound) h.outbound = mock.Mock() h.Outbound = mock.Mock(return_value=h.outbound) - h.hostaddr = mock.Mock() - alsoProvides(h.hostaddr, IAddress) - with mock.patch("wormhole._dilation.manager.Inbound", h.Inbound): - with mock.patch("wormhole._dilation.manager.Outbound", h.Outbound): - m = Manager(h.send, side, h.key, h.relay, h.reactor, h.eq, h.coop, h.hostaddr) + with mock.patch("wormhole._dilation.manager.Inbound", h.Inbound), \ + mock.patch("wormhole._dilation.manager.Outbound", h.Outbound): + m = Manager(h.send, side, h.relay, h.reactor, h.eq, h.coop) + h.hostaddr = m._host_addr + m.got_dilation_key(h.key) return m, h From b633602a02aef56aaae905d0c663c6c827ae7162 Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Sat, 13 Jul 2019 19:25:50 -0700 Subject: [PATCH 4/4] update test_manager to match --- src/wormhole/test/dilate/test_manager.py | 362 ++++++++++------------- 1 file changed, 149 insertions(+), 213 deletions(-) diff --git a/src/wormhole/test/dilate/test_manager.py b/src/wormhole/test/dilate/test_manager.py index 309c1fd..3fdd051 100644 --- a/src/wormhole/test/dilate/test_manager.py +++ b/src/wormhole/test/dilate/test_manager.py @@ -1,12 +1,11 @@ from __future__ import print_function, unicode_literals from zope.interface import alsoProvides from twisted.trial import unittest -from twisted.internet.defer import Deferred from twisted.internet.task import Clock, Cooperator -from twisted.internet.interfaces import IAddress +from twisted.internet.interfaces import IStreamServerEndpoint import mock from ...eventual import EventualQueue -from ..._interfaces import ISend, IDilationManager, ITerminator +from ..._interfaces import ISend, ITerminator, ISubChannel from ...util import dict_to_bytes from ..._dilation import roles from ..._dilation.manager import (Dilator, Manager, make_side, @@ -15,35 +14,57 @@ from ..._dilation.manager import (Dilator, Manager, make_side, UnexpectedKCM, UnknownMessageType) from ..._dilation.connection import Open, Data, Close, Ack, KCM, Ping, Pong +from ..._dilation.subchannel import _SubchannelAddress from .common import clear_mock_calls +class Holder(): + pass def make_dilator(): - reactor = object() - clock = Clock() - eq = EventualQueue(clock) + h = Holder() + h.reactor = object() + h.clock = Clock() + h.eq = EventualQueue(h.clock) term = mock.Mock(side_effect=lambda: True) # one write per Eventual tick def term_factory(): return term - coop = Cooperator(terminationPredicateFactory=term_factory, - scheduler=eq.eventually) - send = mock.Mock() - alsoProvides(send, ISend) - dil = Dilator(reactor, eq, coop) - terminator = mock.Mock() - alsoProvides(terminator, ITerminator) - dil.wire(send, terminator) - return dil, send, reactor, eq, clock, coop + h.coop = Cooperator(terminationPredicateFactory=term_factory, + scheduler=h.eq.eventually) + h.send = mock.Mock() + alsoProvides(h.send, ISend) + dil = Dilator(h.reactor, h.eq, h.coop) + h.terminator = mock.Mock() + alsoProvides(h.terminator, ITerminator) + dil.wire(h.send, h.terminator) + return dil, h class TestDilator(unittest.TestCase): - def test_manager_and_endpoints(self): - dil, send, reactor, eq, clock, coop = make_dilator() - d1 = dil.dilate() - d2 = dil.dilate() - self.assertNoResult(d1) - self.assertNoResult(d2) + # we should test the interleavings between: + # * application calls w.dilate() and gets back endpoints + # * wormhole gets: dilation key, VERSION, 0-n dilation messages + + def test_dilate_first(self): + (dil, h) = make_dilator() + side = object() + m = mock.Mock() + eps = object() + m.get_endpoints = mock.Mock(return_value=eps) + mm = mock.Mock(side_effect=[m]) + with mock.patch("wormhole._dilation.manager.Manager", mm), \ + mock.patch("wormhole._dilation.manager.make_side", + return_value=side): + eps1 = dil.dilate() + eps2 = dil.dilate() + self.assertIdentical(eps1, eps) + self.assertIdentical(eps1, eps2) + self.assertEqual(mm.mock_calls, [mock.call(h.send, side, None, + h.reactor, h.eq, h.coop, False)]) + + self.assertEqual(m.mock_calls, [mock.call.get_endpoints(), + mock.call.get_endpoints()]) + clear_mock_calls(m) key = b"key" transit_key = object() @@ -51,207 +72,108 @@ class TestDilator(unittest.TestCase): return_value=transit_key) as dk: dil.got_key(key) self.assertEqual(dk.mock_calls, [mock.call(key, b"dilation-v1", 32)]) - self.assertIdentical(dil._transit_key, transit_key) - self.assertNoResult(d1) - self.assertNoResult(d2) + self.assertEqual(m.mock_calls, [mock.call.got_dilation_key(transit_key)]) + clear_mock_calls(m) - host_addr = dil._host_addr + wv = object() + dil.got_wormhole_versions(wv) + self.assertEqual(m.mock_calls, [mock.call.got_wormhole_versions(wv)]) + clear_mock_calls(m) - peer_addr = object() - m_sca = mock.patch("wormhole._dilation.manager._SubchannelAddress", - return_value=peer_addr) - sc = mock.Mock() - m_sc = mock.patch("wormhole._dilation.manager.SubChannel", - return_value=sc) - ce = mock.Mock() - m_ce = mock.patch("wormhole._dilation.manager.ControlEndpoint", - return_value=ce) - lep = object() - m_sle = mock.patch("wormhole._dilation.manager.SubchannelListenerEndpoint", - return_value=lep) - scid0 = 0 - - m = mock.Mock() - alsoProvides(m, IDilationManager) - m.when_first_connected.return_value = wfc_d = Deferred() - with mock.patch("wormhole._dilation.manager.Manager", - return_value=m) as ml: - with mock.patch("wormhole._dilation.manager.make_side", - return_value="us"): - with m_sca, m_sc as m_sc_m, m_ce as m_ce_m, m_sle as m_sle_m: - dil.got_wormhole_versions({"can-dilate": ["1"]}) - # that should create the Manager - self.assertEqual(ml.mock_calls, [mock.call(send, "us", transit_key, - None, reactor, eq, coop, host_addr, False)]) - # and the three endpoints - self.assertEqual(m_ce_m.mock_calls, [mock.call(peer_addr)]) - self.assertEqual(ce.mock_calls, [mock.call._subchannel_zero_opened(sc)]) - self.assertEqual(m_sle_m.mock_calls, [mock.call(m, host_addr)]) - # and create subchannel0 - self.assertEqual(m_sc_m.mock_calls, - [mock.call(scid0, m, host_addr, peer_addr)]) - # and tell it to start, and get wait-for-it-to-connect Deferred - self.assertEqual(m.mock_calls, [mock.call.set_subchannel_zero(scid0, sc), - mock.call.set_listener_endpoint(lep), - mock.call.start(), - mock.call.when_first_connected(), + dm1 = object() + dm2 = object() + dil.received_dilate(dm1) + dil.received_dilate(dm2) + self.assertEqual(m.mock_calls, [mock.call.received_dilation_message(dm1), + mock.call.received_dilation_message(dm2), ]) clear_mock_calls(m) - self.assertNoResult(d1) - self.assertNoResult(d2) - wfc_d.callback(None) - eq.flush_sync() - self.assertEqual(m.mock_calls, []) + stopped_d = mock.Mock() + m.when_stopped = mock.Mock(return_value=stopped_d) + dil.stop() + self.assertEqual(m.mock_calls, [mock.call.stop(), + mock.call.when_stopped(), + ]) + + def test_dilate_later(self): + (dil, h) = make_dilator() + m = mock.Mock() + mm = mock.Mock(side_effect=[m]) + + key = b"key" + transit_key = object() + with mock.patch("wormhole._dilation.manager.derive_key", + return_value=transit_key) as dk: + dil.got_key(key) + self.assertEqual(dk.mock_calls, [mock.call(key, b"dilation-v1", 32)]) + + wv = object() + dil.got_wormhole_versions(wv) + + dm1 = object() + dil.received_dilate(dm1) + + self.assertEqual(mm.mock_calls, []) + + with mock.patch("wormhole._dilation.manager.Manager", mm): + dil.dilate() + self.assertEqual(m.mock_calls, [mock.call.got_dilation_key(transit_key), + mock.call.got_wormhole_versions(wv), + mock.call.received_dilation_message(dm1), + mock.call.get_endpoints(), + ]) clear_mock_calls(m) - eps = self.successResultOf(d1) - self.assertEqual(eps, self.successResultOf(d2)) - d3 = dil.dilate() - eq.flush_sync() - self.assertEqual(eps, self.successResultOf(d3)) + dm2 = object() + dil.received_dilate(dm2) + self.assertEqual(m.mock_calls, [mock.call.received_dilation_message(dm2), + ]) - # all subsequent DILATE-n messages should get passed to the manager - self.assertEqual(m.mock_calls, []) - pleasemsg = dict(type="please", side="them") - dil.received_dilate(dict_to_bytes(pleasemsg)) - self.assertEqual(m.mock_calls, [mock.call.rx_PLEASE(pleasemsg)]) - clear_mock_calls(m) - - hintmsg = dict(type="connection-hints") - dil.received_dilate(dict_to_bytes(hintmsg)) - self.assertEqual(m.mock_calls, [mock.call.rx_HINTS(hintmsg)]) - clear_mock_calls(m) - - # we're nominally the LEADER, and the leader would not normally be - # receiving a RECONNECT, but since we've mocked out the Manager it - # won't notice - dil.received_dilate(dict_to_bytes(dict(type="reconnect"))) - self.assertEqual(m.mock_calls, [mock.call.rx_RECONNECT()]) - clear_mock_calls(m) - - dil.received_dilate(dict_to_bytes(dict(type="reconnecting"))) - self.assertEqual(m.mock_calls, [mock.call.rx_RECONNECTING()]) - clear_mock_calls(m) - - dil.received_dilate(dict_to_bytes(dict(type="unknown"))) - self.assertEqual(m.mock_calls, []) - self.flushLoggedErrors(UnknownDilationMessageType) + def test_stop_early(self): + (dil, h) = make_dilator() + # we stop before w.dilate(), so there is no Manager to stop + dil.stop() + self.assertEqual(h.terminator.mock_calls, [mock.call.stoppedD()]) def test_peer_cannot_dilate(self): - dil, send, reactor, eq, clock, coop = make_dilator() + (dil, h) = make_dilator() eps = dil.dilate() dil.got_key(b"\x01" * 32) dil.got_wormhole_versions({}) # missing "can-dilate" d = eps.connect.connect(None) - eq.flush_sync() + h.eq.flush_sync() self.failureResultOf(d).check(OldPeerCannotDilateError) def test_disjoint_versions(self): - dil, send, reactor, eq, clock, coop = make_dilator() + (dil, h) = make_dilator() eps = dil.dilate() dil.got_key(b"\x01" * 32) dil.got_wormhole_versions({"can-dilate": [-1]}) d = eps.connect.connect(None) - eq.flush_sync() + h.eq.flush_sync() self.failureResultOf(d).check(OldPeerCannotDilateError) - def test_early_dilate_messages(self): - dil, send, reactor, eq, clock, coop = make_dilator() - dil._transit_key = b"key" - d1 = dil.dilate() - host_addr = dil._host_addr - self.assertNoResult(d1) - pleasemsg = dict(type="please", side="them") - dil.received_dilate(dict_to_bytes(pleasemsg)) - hintmsg = dict(type="connection-hints") - dil.received_dilate(dict_to_bytes(hintmsg)) - - m = mock.Mock() - alsoProvides(m, IDilationManager) - m.when_first_connected.return_value = Deferred() - - scid0 = 0 - sc = mock.Mock() - m_sc = mock.patch("wormhole._dilation.manager.SubChannel", - return_value=sc) - peer_addr = object() - m_sca = mock.patch("wormhole._dilation.manager._SubchannelAddress", - return_value=peer_addr) - ce = mock.Mock() - m_ce = mock.patch("wormhole._dilation.manager.ControlEndpoint", - return_value=ce) - lep = object() - m_sle = mock.patch("wormhole._dilation.manager.SubchannelListenerEndpoint", - return_value=lep) - - with mock.patch("wormhole._dilation.manager.Manager", - return_value=m) as ml: - with mock.patch("wormhole._dilation.manager.make_side", - return_value="us"): - with m_sca, m_sc, m_ce, m_sle: - dil.got_wormhole_versions({"can-dilate": ["1"]}) - self.assertEqual(ml.mock_calls, [mock.call(send, "us", b"key", - None, reactor, eq, coop, host_addr, False)]) - self.assertEqual(m.mock_calls, [mock.call.set_subchannel_zero(scid0, sc), - mock.call.set_listener_endpoint(lep), - mock.call.start(), - mock.call.rx_PLEASE(pleasemsg), - mock.call.rx_HINTS(hintmsg), - mock.call.when_first_connected()]) - def test_transit_relay(self): - dil, send, reactor, eq, clock, coop = make_dilator() - dil._transit_key = b"key" - host_addr = dil._host_addr - relay = object() - d1 = dil.dilate(transit_relay_location=relay) - self.assertNoResult(d1) - + (dil, h) = make_dilator() + transit_relay_location = object() + side = object() m = mock.Mock() - alsoProvides(m, IDilationManager) - m.when_first_connected.return_value = Deferred() - - scid0 = 0 - sc = mock.Mock() - m_sc = mock.patch("wormhole._dilation.manager.SubChannel", - return_value=sc) - peer_addr = object() - m_sca = mock.patch("wormhole._dilation.manager._SubchannelAddress", - return_value=peer_addr) - ce = mock.Mock() - m_ce = mock.patch("wormhole._dilation.manager.ControlEndpoint", - return_value=ce) - lep = object() - m_sle = mock.patch("wormhole._dilation.manager.SubchannelListenerEndpoint", - return_value=lep) - - with mock.patch("wormhole._dilation.manager.Manager", - return_value=m) as ml: - with mock.patch("wormhole._dilation.manager.make_side", - return_value="us"): - with m_sca, m_sc, m_ce, m_sle: - dil.got_wormhole_versions({"can-dilate": ["1"]}) - - self.assertEqual(ml.mock_calls, [ - mock.call(send, "us", b"key", - relay, reactor, eq, coop, host_addr, False)]) - self.assertEqual(m.mock_calls, [ - mock.call.set_subchannel_zero(scid0, sc), - mock.call.set_listener_endpoint(lep), - mock.call.start(), - mock.call.when_first_connected()]) - + mm = mock.Mock(side_effect=[m]) + with mock.patch("wormhole._dilation.manager.Manager", mm), \ + mock.patch("wormhole._dilation.manager.make_side", + return_value=side): + dil.dilate(transit_relay_location) + self.assertEqual(mm.mock_calls, [mock.call(h.send, side, transit_relay_location, + h.reactor, h.eq, h.coop, False)]) LEADER = "ff3456abcdef" FOLLOWER = "123456abcdef" def make_manager(leader=True): - class Holder: - pass h = Holder() h.send = mock.Mock() alsoProvides(h.send, ISend) @@ -274,8 +196,16 @@ def make_manager(leader=True): h.Inbound = mock.Mock(return_value=h.inbound) h.outbound = mock.Mock() h.Outbound = mock.Mock(return_value=h.outbound) + h.sc0 = mock.Mock() + alsoProvides(h.sc0, ISubChannel) + h.SubChannel = mock.Mock(return_value=h.sc0) + h.listen_ep = mock.Mock() + alsoProvides(h.listen_ep, IStreamServerEndpoint) with mock.patch("wormhole._dilation.manager.Inbound", h.Inbound), \ - mock.patch("wormhole._dilation.manager.Outbound", h.Outbound): + mock.patch("wormhole._dilation.manager.Outbound", h.Outbound), \ + mock.patch("wormhole._dilation.manager.SubChannel", h.SubChannel), \ + mock.patch("wormhole._dilation.manager.SubchannelListenerEndpoint", + return_value=h.listen_ep): m = Manager(h.send, side, h.relay, h.reactor, h.eq, h.coop) h.hostaddr = m._host_addr m.got_dilation_key(h.key) @@ -296,16 +226,26 @@ class TestManager(unittest.TestCase): self.assertEqual(h.send.mock_calls, []) self.assertEqual(h.Inbound.mock_calls, [mock.call(m, h.hostaddr)]) self.assertEqual(h.Outbound.mock_calls, [mock.call(m, h.coop)]) + scid0 = 0 + sc0_peer_addr = _SubchannelAddress(scid0) + self.assertEqual(h.SubChannel.mock_calls, [ + mock.call(scid0, m, m._host_addr, sc0_peer_addr), + ]) + self.assertEqual(h.inbound.mock_calls, [ + mock.call.set_subchannel_zero(scid0, h.sc0), + mock.call.set_listener_endpoint(h.listen_ep) + ]) + clear_mock_calls(h.inbound) - m.start() + m.got_wormhole_versions({"can-dilate": ["1"]}) self.assertEqual(h.send.mock_calls, [ mock.call.send("dilate-0", dict_to_bytes({"type": "please", "side": LEADER})) ]) clear_mock_calls(h.send) - wfc_d = m.when_first_connected() - self.assertNoResult(wfc_d) + listen_d = m.get_endpoints().listen.listen(None) + self.assertNoResult(listen_d) # ignore early hints m.rx_HINTS({}) @@ -327,7 +267,7 @@ class TestManager(unittest.TestCase): self.assertEqual(c.mock_calls, [mock.call.start()]) clear_mock_calls(connector, c) - self.assertNoResult(wfc_d) + self.assertNoResult(listen_d) # now any inbound hints should get passed to our Connector with mock.patch("wormhole._dilation.manager.parse_hint", @@ -347,7 +287,7 @@ class TestManager(unittest.TestCase): clear_mock_calls(h.send) # the first successful connection fires when_first_connected(), so - # the Dilator can create and return the endpoints + # the endpoints can activate c1 = mock.Mock() m.connector_connection_made(c1) @@ -355,23 +295,6 @@ class TestManager(unittest.TestCase): self.assertEqual(h.outbound.mock_calls, [mock.call.use_connection(c1)]) clear_mock_calls(h.inbound, h.outbound) - h.eq.flush_sync() - self.successResultOf(wfc_d) # fires with None - wfc_d2 = m.when_first_connected() - h.eq.flush_sync() - self.successResultOf(wfc_d2) - - scid0 = 0 - sc0 = mock.Mock() - m.set_subchannel_zero(scid0, sc0) - listen_ep = mock.Mock() - m.set_listener_endpoint(listen_ep) - self.assertEqual(h.inbound.mock_calls, [ - mock.call.set_subchannel_zero(scid0, sc0), - mock.call.set_listener_endpoint(listen_ep), - ]) - clear_mock_calls(h.inbound) - # the Leader making a new outbound channel should get scid=1 scid1 = 1 self.assertEqual(m.allocate_subchannel_id(), scid1) @@ -518,12 +441,13 @@ class TestManager(unittest.TestCase): def test_follower(self): m, h = make_manager(leader=False) - m.start() + m.got_wormhole_versions({"can-dilate": ["1"]}) self.assertEqual(h.send.mock_calls, [ mock.call.send("dilate-0", dict_to_bytes({"type": "please", "side": FOLLOWER})) ]) clear_mock_calls(h.send) + clear_mock_calls(h.inbound) c = mock.Mock() connector = mock.Mock(return_value=c) @@ -653,6 +577,7 @@ class TestManager(unittest.TestCase): def test_subchannel(self): m, h = make_manager(leader=True) + clear_mock_calls(h.inbound) sc = object() m.subchannel_pauseProducing(sc) @@ -689,3 +614,14 @@ class TestManager(unittest.TestCase): self.assertEqual(h.outbound.mock_calls, [ mock.call.subchannel_closed(4, sc)]) clear_mock_calls(h.inbound, h.outbound) + + def test_unknown_message(self): + # receive a PLEASE with the same side as us: shouldn't happen + m, h = make_manager(leader=True) + m.start() + + m.received_dilation_message(dict_to_bytes(dict(type="unknown"))) + self.flushLoggedErrors(UnknownDilationMessageType) + + # TODO: test transit relay is used +