diff --git a/src/wormhole/_dilation/manager.py b/src/wormhole/_dilation/manager.py index 5ea2ae6..19430bb 100644 --- a/src/wormhole/_dilation/manager.py +++ b/src/wormhole/_dilation/manager.py @@ -2,13 +2,20 @@ from __future__ import print_function, unicode_literals import six import os from collections import deque +try: + # py >= 3.3 + from collections.abc import Sequence +except ImportError: + # py 2 and py3 < 3.3 + from collections import Sequence from attr import attrs, attrib from attr.validators import provides, instance_of, optional from automat import MethodicalMachine from zope.interface import implementer -from twisted.internet.defer import Deferred, inlineCallbacks, returnValue -from twisted.internet.interfaces import IAddress -from twisted.python import log +from twisted.internet.defer import Deferred +from twisted.internet.interfaces import (IStreamClientEndpoint, + IStreamServerEndpoint) +from twisted.python import log, failure from .._interfaces import IDilator, IDilationManager, ISend, ITerminator from ..util import dict_to_bytes, bytes_to_dict, bytes_to_hexstr from ..observer import OneShotObserver @@ -47,6 +54,15 @@ class UnexpectedKCM(Exception): class UnknownMessageType(Exception): pass +@attrs +class EndpointRecord(Sequence): + control = attrib(validator=provides(IStreamClientEndpoint)) + connect = attrib(validator=provides(IStreamClientEndpoint)) + listen = attrib(validator=provides(IStreamServerEndpoint)) + def __len__(self): + return 3 + def __getitem__(self, n): + return (self.control, self.connect, self.listen)[n] def make_side(): return bytes_to_hexstr(os.urandom(6)) @@ -93,13 +109,14 @@ def make_side(): class Manager(object): _S = attrib(validator=provides(ISend), repr=False) _my_side = attrib(validator=instance_of(type(u""))) - _transit_key = attrib(validator=instance_of(bytes), repr=False) _transit_relay_location = attrib(validator=optional(instance_of(str))) _reactor = attrib(repr=False) _eventual_queue = attrib(repr=False) _cooperator = attrib(repr=False) - _host_addr = attrib(validator=provides(IAddress)) - _no_listen = attrib(default=False) + # TODO: can this validator work when the parameter is optional? + _no_listen = attrib(validator=instance_of(bool), default=False) + + _dilation_key = None _tor = None # TODO _timing = None # TODO _next_subchannel_id = None # initialized in choose_role @@ -111,10 +128,10 @@ class Manager(object): self._got_versions_d = Deferred() self._my_role = None # determined upon rx_PLEASE + self._host_addr = _WormholeAddress() self._connection = None self._made_first_connection = False - self._first_connected = OneShotObserver(self._eventual_queue) self._stopped = OneShotObserver(self._eventual_queue) self._debug_stall_connector = False @@ -127,18 +144,81 @@ class Manager(object): self._inbound = Inbound(self, self._host_addr) self._outbound = Outbound(self, self._cooperator) # from us to peer - def set_listener_endpoint(self, listener_endpoint): - self._inbound.set_listener_endpoint(listener_endpoint) - - def set_subchannel_zero(self, scid0, sc0): + # We must open subchannel0 early, since messages may arrive very + # quickly once the connection is established. This subchannel may or + # may not ever get revealed to the caller, since the peer might not + # even be capable of dilation. + scid0 = 0 + peer_addr0 = _SubchannelAddress(scid0) + sc0 = SubChannel(scid0, self, self._host_addr, peer_addr0) self._inbound.set_subchannel_zero(scid0, sc0) - def when_first_connected(self): - return self._first_connected.when_fired() + # we can open non-zero subchannels as soon as we get our first + # connection, and we can make the Endpoints even earlier + control_ep = ControlEndpoint(peer_addr0, sc0, self._eventual_queue) + connect_ep = SubchannelConnectorEndpoint(self, self._host_addr, self._eventual_queue) + listen_ep = SubchannelListenerEndpoint(self, self._host_addr, self._eventual_queue) + # TODO: let inbound/outbound create the endpoints, then return them + # to us + self._inbound.set_listener_endpoint(listen_ep) + + self._endpoints = EndpointRecord(control_ep, connect_ep, listen_ep) + + def get_endpoints(self): + return self._endpoints + + def got_dilation_key(self, key): + assert isinstance(key, bytes) + self._dilation_key = key + + def got_wormhole_versions(self, their_wormhole_versions): + # this always happens before received_dilation_message + dilation_version = None + their_dilation_versions = set(their_wormhole_versions.get("can-dilate", [])) + my_versions = set(DILATION_VERSIONS) + shared_versions = my_versions.intersection(their_dilation_versions) + if "1" in shared_versions: + dilation_version = "1" + + # dilation_version is the best mutually-compatible version we have + # with the peer, or None if we have nothing in common + + if not dilation_version: # "1" or None + # TODO: be more specific about the error. dilation_version==None + # means we had no version in common with them, which could either + # be because they're so old they don't dilate at all, or because + # they're so new that they no longer accomodate our old version + self.fail(failure.Failure(OldPeerCannotDilateError())) + + self.start() + + def fail(self, f): + self._endpoints.control._main_channel_failed(f) + self._endpoints.connect._main_channel_failed(f) + self._endpoints.listen._main_channel_failed(f) + + def received_dilation_message(self, plaintext): + # this receives new in-order DILATE-n payloads, decrypted but not + # de-JSONed. + + message = bytes_to_dict(plaintext) + type = message["type"] + if type == "please": + self.rx_PLEASE(message) + elif type == "connection-hints": + self.rx_HINTS(message) + elif type == "reconnect": + self.rx_RECONNECT() + elif type == "reconnecting": + self.rx_RECONNECTING() + else: + log.err(UnknownDilationMessageType(message)) + return def when_stopped(self): return self._stopped.when_fired() + def send_dilation_phase(self, **fields): dilation_phase = self._next_dilation_phase self._next_dilation_phase += 1 @@ -204,7 +284,9 @@ class Manager(object): self._outbound.use_connection(c) # does c.registerProducer if not self._made_first_connection: self._made_first_connection = True - self._first_connected.fire(None) + self._endpoints.control._main_channel_ready() + self._endpoints.connect._main_channel_ready() + self._endpoints.listen._main_channel_ready() pass def connector_connection_lost(self): @@ -272,16 +354,11 @@ class Manager(object): # state machine - # We are born WANTING after the local app calls w.dilate(). We start - # CONNECTING when we receive PLEASE from the remote side - - def start(self): - self.send_please() - - def send_please(self): - self.send_dilation_phase(type="please", side=self._my_side) - @m.state(initial=True) + def WAITING(self): + pass # pragma: no cover + + @m.state() def WANTING(self): pass # pragma: no cover @@ -313,6 +390,10 @@ class Manager(object): def STOPPED(self): pass # pragma: no cover + @m.input() + def start(self): + pass # pragma: no cover + @m.input() def rx_PLEASE(self, message): pass # pragma: no cover @@ -350,6 +431,10 @@ class Manager(object): def stop(self): pass # pragma: no cover + @m.output() + def send_please(self): + self.send_dilation_phase(type="please", side=self._my_side) + @m.output() def choose_role(self, message): their_side = message["side"] @@ -378,7 +463,8 @@ class Manager(object): def _start_connecting(self): assert self._my_role is not None - self._connector = Connector(self._transit_key, + assert self._dilation_key is not None + self._connector = Connector(self._dilation_key, self._transit_relay_location, self, self._reactor, self._eventual_queue, @@ -422,6 +508,11 @@ class Manager(object): def notify_stopped(self): self._stopped.fire(None) + # We are born WAITING after the local app calls w.dilate(). We enter + # WANTING (and send a PLEASE) when we learn of a mutually-compatible + # dilation_version. + WAITING.upon(start, enter=WANTING, outputs=[send_please]) + # we start CONNECTING when we get rx_PLEASE WANTING.upon(rx_PLEASE, enter=CONNECTING, outputs=[choose_role, start_connecting_ignore_message]) @@ -489,12 +580,10 @@ class Dilator(object): _cooperator = attrib() def __attrs_post_init__(self): - self._got_versions_d = Deferred() - self._started = False - self._endpoints = OneShotObserver(self._eventual_queue) - self._pending_inbound_dilate_messages = deque() self._manager = None - self._host_addr = _WormholeAddress() + self._pending_dilation_key = None + self._pending_wormhole_versions = None + self._pending_inbound_dilate_messages = deque() def wire(self, sender, terminator): self._S = ISend(sender) @@ -502,77 +591,35 @@ class Dilator(object): # this is the primary entry point, called when w.dilate() is invoked def dilate(self, transit_relay_location=None, no_listen=False): - self._transit_relay_location = transit_relay_location - if not self._started: - self._started = True - self._start(no_listen).addBoth(self._endpoints.fire) - return self._endpoints.when_fired() - - @inlineCallbacks - def _start(self, no_listen): - # first, we wait until we hear the VERSION message, which tells us 1: - # the PAKE key works, so we can talk securely, 2: that they can do - # dilation at all (if they can't then w.dilate() errbacks) - - dilation_version = yield self._got_versions_d - - # TODO: we could probably return the endpoints earlier, if we flunk - # any connection/listen attempts upon OldPeerCannotDilateError, or - # if/when we give up on the initial connection - - if not dilation_version: # "1" or None - # TODO: be more specific about the error. dilation_version==None - # means we had no version in common with them, which could either - # be because they're so old they don't dilate at all, or because - # they're so new that they no longer accomodate our old version - raise OldPeerCannotDilateError() - - my_dilation_side = make_side() - self._manager = Manager(self._S, my_dilation_side, - self._transit_key, - self._transit_relay_location, - self._reactor, self._eventual_queue, - self._cooperator, self._host_addr, no_listen) - # We must open subchannel0 early, since messages may arrive very - # quickly once the connection is established. This subchannel may or - # may not ever get revealed to the caller, since the peer might not - # even be capable of dilation. - scid0 = 0 - peer_addr0 = _SubchannelAddress(scid0) - sc0 = SubChannel(scid0, self._manager, self._host_addr, peer_addr0) - self._manager.set_subchannel_zero(scid0, sc0) - - self._manager.start() - - while self._pending_inbound_dilate_messages: - plaintext = self._pending_inbound_dilate_messages.popleft() - self.received_dilate(plaintext) - - yield self._manager.when_first_connected() - - # we can open non-zero subchannels as soon as we get our first - # connection - control_ep = ControlEndpoint(peer_addr0) - control_ep._subchannel_zero_opened(sc0) - connect_ep = SubchannelConnectorEndpoint(self._manager, self._host_addr) - - listen_ep = SubchannelListenerEndpoint(self._manager, self._host_addr) - self._manager.set_listener_endpoint(listen_ep) - - endpoints = (control_ep, connect_ep, listen_ep) - returnValue(endpoints) + if not self._manager: + # build the manager right away, and tell it later when the + # VERSIONS message arrives, and also when the dilation_key is set + my_dilation_side = make_side() + m = Manager(self._S, my_dilation_side, + transit_relay_location, + self._reactor, self._eventual_queue, + self._cooperator, no_listen) + self._manager = m + if self._pending_dilation_key is not None: + m.got_dilation_key(self._pending_dilation_key) + if self._pending_wormhole_versions: + m.got_wormhole_versions(self._pending_wormhole_versions) + while self._pending_inbound_dilate_messages: + plaintext = self._pending_inbound_dilate_messages.popleft() + m.received_dilation_message(plaintext) + return self._manager.get_endpoints() # Called by Terminator after everything else (mailbox, nameplate, server # connection) has shut down. Expects to fire T.stoppedD() when Dilator is # stopped too. def stop(self): - if not self._started: - self._T.stoppedD() - return - if self._started: + if self._manager: self._manager.stop() # TODO: avoid Deferreds for control flow, hard to serialize self._manager.when_stopped().addCallback(lambda _: self._T.stoppedD()) + else: + self._T.stoppedD() + return # TODO: tolerate multiple calls # from Boss @@ -582,39 +629,20 @@ class Dilator(object): # to tolerate either ordering purpose = b"dilation-v1" LENGTH = 32 # TODO: whatever Noise wants, I guess - self._transit_key = derive_key(key, purpose, LENGTH) + dilation_key = derive_key(key, purpose, LENGTH) + if self._manager: + self._manager.got_dilation_key(dilation_key) + else: + self._pending_dilation_key = dilation_key def got_wormhole_versions(self, their_wormhole_versions): - assert self._transit_key is not None - # this always happens before received_dilate - dilation_version = None - their_dilation_versions = set(their_wormhole_versions.get("can-dilate", [])) - my_versions = set(DILATION_VERSIONS) - shared_versions = my_versions.intersection(their_dilation_versions) - if "1" in shared_versions: - dilation_version = "1" - self._got_versions_d.callback(dilation_version) + if self._manager: + self._manager.got_wormhole_versions(their_wormhole_versions) + else: + self._pending_wormhole_versions = their_wormhole_versions def received_dilate(self, plaintext): - # this receives new in-order DILATE-n payloads, decrypted but not - # de-JSONed. - - # this can appear before our .dilate() method is called, in which case - # we queue them for later if not self._manager: self._pending_inbound_dilate_messages.append(plaintext) - return - - message = bytes_to_dict(plaintext) - type = message["type"] - if type == "please": - self._manager.rx_PLEASE(message) - elif type == "connection-hints": - self._manager.rx_HINTS(message) - elif type == "reconnect": - self._manager.rx_RECONNECT() - elif type == "reconnecting": - self._manager.rx_RECONNECTING() else: - log.err(UnknownDilationMessageType(message)) - return + self._manager.received_dilation_message(plaintext) diff --git a/src/wormhole/_dilation/subchannel.py b/src/wormhole/_dilation/subchannel.py index bcdff6a..e1e811c 100644 --- a/src/wormhole/_dilation/subchannel.py +++ b/src/wormhole/_dilation/subchannel.py @@ -1,9 +1,9 @@ import six +from collections import deque from attr import attrs, attrib from attr.validators import instance_of, provides from zope.interface import implementer -from twisted.internet.defer import (Deferred, inlineCallbacks, returnValue, - succeed) +from twisted.internet.defer import inlineCallbacks, returnValue from twisted.internet.interfaces import (ITransport, IProducer, IConsumer, IAddress, IListeningPort, IStreamClientEndpoint, @@ -11,6 +11,7 @@ from twisted.internet.interfaces import (ITransport, IProducer, IConsumer, from twisted.internet.error import ConnectionDone from automat import MethodicalMachine from .._interfaces import ISubChannel, IDilationManager +from ..observer import OneShotObserver # each subchannel frame (the data passed into transport.write(data)) gets a # 9-byte header prefix (type, subchannel id, and sequence number), then gets @@ -216,27 +217,33 @@ class SubChannel(object): @implementer(IStreamClientEndpoint) +@attrs class ControlEndpoint(object): + _peer_addr = attrib(validator=provides(IAddress)) + _subchannel_zero = attrib(validator=provides(ISubChannel)) + _eventual_queue = attrib(repr=False) _used = False - def __init__(self, peer_addr): - self._subchannel_zero = Deferred() - self._peer_addr = peer_addr + def __attrs_post_init__(self): self._once = Once(SingleUseEndpointError) + self._wait_for_main_channel = OneShotObserver(self._eventual_queue) # from manager - def _subchannel_zero_opened(self, subchannel): - assert ISubChannel.providedBy(subchannel), subchannel - self._subchannel_zero.callback(subchannel) + + def _main_channel_ready(self): + self._wait_for_main_channel.fire(None) + def _main_channel_failed(self, f): + self._wait_for_main_channel.error(f) @inlineCallbacks def connect(self, protocolFactory): # return Deferred that fires with IProtocol or Failure(ConnectError) self._once() - t = yield self._subchannel_zero + yield self._wait_for_main_channel.when_fired() p = protocolFactory.buildProtocol(self._peer_addr) - t._set_protocol(p) - p.makeConnection(t) # set p.transport = t and call connectionMade() + self._subchannel_zero._set_protocol(p) + # this sets p.transport and calls p.connectionMade() + p.makeConnection(self._subchannel_zero) returnValue(p) @@ -245,9 +252,21 @@ class ControlEndpoint(object): class SubchannelConnectorEndpoint(object): _manager = attrib(validator=provides(IDilationManager)) _host_addr = attrib(validator=instance_of(_WormholeAddress)) + _eventual_queue = attrib(repr=False) + def __attrs_post_init__(self): + self._connection_deferreds = deque() + self._wait_for_main_channel = OneShotObserver(self._eventual_queue) + + def _main_channel_ready(self): + self._wait_for_main_channel.fire(None) + def _main_channel_failed(self, f): + self._wait_for_main_channel.error(f) + + @inlineCallbacks def connect(self, protocolFactory): # return Deferred that fires with IProtocol or Failure(ConnectError) + yield self._wait_for_main_channel.when_fired() scid = self._manager.allocate_subchannel_id() self._manager.send_open(scid) peer_addr = _SubchannelAddress(scid) @@ -258,7 +277,7 @@ class SubchannelConnectorEndpoint(object): p = protocolFactory.buildProtocol(peer_addr) sc._set_protocol(p) p.makeConnection(sc) # set p.transport = sc and call connectionMade() - return succeed(p) + returnValue(p) @implementer(IStreamServerEndpoint) @@ -266,12 +285,15 @@ class SubchannelConnectorEndpoint(object): class SubchannelListenerEndpoint(object): _manager = attrib(validator=provides(IDilationManager)) _host_addr = attrib(validator=provides(IAddress)) + _eventual_queue = attrib(repr=False) def __attrs_post_init__(self): + self._once = Once(SingleUseEndpointError) self._factory = None - self._pending_opens = [] + self._pending_opens = deque() + self._wait_for_main_channel = OneShotObserver(self._eventual_queue) - # from manager + # from manager (actually Inbound) def _got_open(self, t, peer_addr): if self._factory: self._connect(t, peer_addr) @@ -283,15 +305,23 @@ class SubchannelListenerEndpoint(object): t._set_protocol(p) p.makeConnection(t) + def _main_channel_ready(self): + self._wait_for_main_channel.fire(None) + def _main_channel_failed(self, f): + self._wait_for_main_channel.error(f) + # IStreamServerEndpoint + @inlineCallbacks def listen(self, protocolFactory): + self._once() + yield self._wait_for_main_channel.when_fired() self._factory = protocolFactory - for (t, peer_addr) in self._pending_opens: + while self._pending_opens: + (t, peer_addr) = self._pending_opens.popleft() self._connect(t, peer_addr) - self._pending_opens = [] lp = SubchannelListeningPort(self._host_addr) - return succeed(lp) + returnValue(lp) @implementer(IListeningPort) diff --git a/src/wormhole/test/dilate/test_endpoints.py b/src/wormhole/test/dilate/test_endpoints.py index 08c5600..647fb9d 100644 --- a/src/wormhole/test/dilate/test_endpoints.py +++ b/src/wormhole/test/dilate/test_endpoints.py @@ -2,7 +2,10 @@ from __future__ import print_function, unicode_literals import mock from zope.interface import alsoProvides from twisted.trial import unittest +from twisted.internet.task import Clock +from twisted.python.failure import Failure from ..._interfaces import ISubChannel +from ...eventual import EventualQueue from ..._dilation.subchannel import (ControlEndpoint, SubchannelConnectorEndpoint, SubchannelListenerEndpoint, @@ -11,12 +14,18 @@ from ..._dilation.subchannel import (ControlEndpoint, SingleUseEndpointError) from .common import mock_manager +class CannotDilateError(Exception): + pass -class Endpoints(unittest.TestCase): - def test_control(self): +class Control(unittest.TestCase): + def test_early_succeed(self): + # ep.connect() is called before dilation can proceed scid0 = 0 peeraddr = _SubchannelAddress(scid0) - ep = ControlEndpoint(peeraddr) + sc0 = mock.Mock() + alsoProvides(sc0, ISubChannel) + eq = EventualQueue(Clock()) + ep = ControlEndpoint(peeraddr, sc0, eq) f = mock.Mock() p = mock.Mock() @@ -24,29 +33,105 @@ class Endpoints(unittest.TestCase): d = ep.connect(f) self.assertNoResult(d) - t = mock.Mock() - alsoProvides(t, ISubChannel) - ep._subchannel_zero_opened(t) + ep._main_channel_ready() + eq.flush_sync() + self.assertIdentical(self.successResultOf(d), p) self.assertEqual(f.buildProtocol.mock_calls, [mock.call(peeraddr)]) - self.assertEqual(t.mock_calls, [mock.call._set_protocol(p)]) - self.assertEqual(p.mock_calls, [mock.call.makeConnection(t)]) + self.assertEqual(sc0.mock_calls, [mock.call._set_protocol(p)]) + self.assertEqual(p.mock_calls, [mock.call.makeConnection(sc0)]) d = ep.connect(f) self.failureResultOf(d, SingleUseEndpointError) - def assert_makeConnection(self, mock_calls): + def test_early_fail(self): + # ep.connect() is called before dilation is abandoned + scid0 = 0 + peeraddr = _SubchannelAddress(scid0) + sc0 = mock.Mock() + alsoProvides(sc0, ISubChannel) + eq = EventualQueue(Clock()) + ep = ControlEndpoint(peeraddr, sc0, eq) + + f = mock.Mock() + p = mock.Mock() + f.buildProtocol = mock.Mock(return_value=p) + d = ep.connect(f) + self.assertNoResult(d) + + ep._main_channel_failed(Failure(CannotDilateError())) + eq.flush_sync() + + self.failureResultOf(d).check(CannotDilateError) + self.assertEqual(f.buildProtocol.mock_calls, []) + self.assertEqual(sc0.mock_calls, []) + + d = ep.connect(f) + self.failureResultOf(d, SingleUseEndpointError) + + def test_late_succeed(self): + # dilation can proceed, then ep.connect() is called + scid0 = 0 + peeraddr = _SubchannelAddress(scid0) + sc0 = mock.Mock() + alsoProvides(sc0, ISubChannel) + eq = EventualQueue(Clock()) + ep = ControlEndpoint(peeraddr, sc0, eq) + + ep._main_channel_ready() + + f = mock.Mock() + p = mock.Mock() + f.buildProtocol = mock.Mock(return_value=p) + d = ep.connect(f) + eq.flush_sync() + self.assertIdentical(self.successResultOf(d), p) + self.assertEqual(f.buildProtocol.mock_calls, [mock.call(peeraddr)]) + self.assertEqual(sc0.mock_calls, [mock.call._set_protocol(p)]) + self.assertEqual(p.mock_calls, [mock.call.makeConnection(sc0)]) + + d = ep.connect(f) + self.failureResultOf(d, SingleUseEndpointError) + + def test_late_fail(self): + # dilation is abandoned, then ep.connect() is called + scid0 = 0 + peeraddr = _SubchannelAddress(scid0) + sc0 = mock.Mock() + alsoProvides(sc0, ISubChannel) + eq = EventualQueue(Clock()) + ep = ControlEndpoint(peeraddr, sc0, eq) + + ep._main_channel_failed(Failure(CannotDilateError())) + + f = mock.Mock() + p = mock.Mock() + f.buildProtocol = mock.Mock(return_value=p) + d = ep.connect(f) + eq.flush_sync() + + self.failureResultOf(d).check(CannotDilateError) + self.assertEqual(f.buildProtocol.mock_calls, []) + self.assertEqual(sc0.mock_calls, []) + + d = ep.connect(f) + self.failureResultOf(d, SingleUseEndpointError) + +class Endpoints(unittest.TestCase): + def OFFassert_makeConnection(self, mock_calls): self.assertEqual(len(mock_calls), 1) self.assertEqual(mock_calls[0][0], "makeConnection") self.assertEqual(len(mock_calls[0][1]), 1) return mock_calls[0][1][0] - def test_connector(self): +class Connector(unittest.TestCase): + def test_early_succeed(self): m = mock_manager() m.allocate_subchannel_id = mock.Mock(return_value=0) hostaddr = _WormholeAddress() peeraddr = _SubchannelAddress(0) - ep = SubchannelConnectorEndpoint(m, hostaddr) + eq = EventualQueue(Clock()) + ep = SubchannelConnectorEndpoint(m, hostaddr, eq) f = mock.Mock() p = mock.Mock() @@ -55,38 +140,123 @@ class Endpoints(unittest.TestCase): with mock.patch("wormhole._dilation.subchannel.SubChannel", return_value=t) as sc: d = ep.connect(f) + eq.flush_sync() + self.assertNoResult(d) + ep._main_channel_ready() + eq.flush_sync() + self.assertIdentical(self.successResultOf(d), p) self.assertEqual(f.buildProtocol.mock_calls, [mock.call(peeraddr)]) self.assertEqual(sc.mock_calls, [mock.call(0, m, hostaddr, peeraddr)]) self.assertEqual(t.mock_calls, [mock.call._set_protocol(p)]) self.assertEqual(p.mock_calls, [mock.call.makeConnection(t)]) - def test_listener(self): + def test_early_fail(self): m = mock_manager() m.allocate_subchannel_id = mock.Mock(return_value=0) hostaddr = _WormholeAddress() - ep = SubchannelListenerEndpoint(m, hostaddr) + eq = EventualQueue(Clock()) + ep = SubchannelConnectorEndpoint(m, hostaddr, eq) + + f = mock.Mock() + p = mock.Mock() + t = mock.Mock() + f.buildProtocol = mock.Mock(return_value=p) + with mock.patch("wormhole._dilation.subchannel.SubChannel", + return_value=t) as sc: + d = ep.connect(f) + eq.flush_sync() + self.assertNoResult(d) + ep._main_channel_failed(Failure(CannotDilateError())) + eq.flush_sync() + + self.failureResultOf(d).check(CannotDilateError) + self.assertEqual(f.buildProtocol.mock_calls, []) + self.assertEqual(sc.mock_calls, []) + self.assertEqual(t.mock_calls, []) + + def test_late_succeed(self): + m = mock_manager() + m.allocate_subchannel_id = mock.Mock(return_value=0) + hostaddr = _WormholeAddress() + peeraddr = _SubchannelAddress(0) + eq = EventualQueue(Clock()) + ep = SubchannelConnectorEndpoint(m, hostaddr, eq) + ep._main_channel_ready() + + f = mock.Mock() + p = mock.Mock() + t = mock.Mock() + f.buildProtocol = mock.Mock(return_value=p) + with mock.patch("wormhole._dilation.subchannel.SubChannel", + return_value=t) as sc: + d = ep.connect(f) + eq.flush_sync() + + self.assertIdentical(self.successResultOf(d), p) + self.assertEqual(f.buildProtocol.mock_calls, [mock.call(peeraddr)]) + self.assertEqual(sc.mock_calls, [mock.call(0, m, hostaddr, peeraddr)]) + self.assertEqual(t.mock_calls, [mock.call._set_protocol(p)]) + self.assertEqual(p.mock_calls, [mock.call.makeConnection(t)]) + + def test_late_fail(self): + m = mock_manager() + m.allocate_subchannel_id = mock.Mock(return_value=0) + hostaddr = _WormholeAddress() + eq = EventualQueue(Clock()) + ep = SubchannelConnectorEndpoint(m, hostaddr, eq) + ep._main_channel_failed(Failure(CannotDilateError())) + + f = mock.Mock() + p = mock.Mock() + t = mock.Mock() + f.buildProtocol = mock.Mock(return_value=p) + with mock.patch("wormhole._dilation.subchannel.SubChannel", + return_value=t) as sc: + d = ep.connect(f) + eq.flush_sync() + + self.failureResultOf(d).check(CannotDilateError) + self.assertEqual(f.buildProtocol.mock_calls, []) + self.assertEqual(sc.mock_calls, []) + self.assertEqual(t.mock_calls, []) + +class Listener(unittest.TestCase): + def test_early_succeed(self): + # listen, main_channel_ready, got_open, got_open + m = mock_manager() + m.allocate_subchannel_id = mock.Mock(return_value=0) + hostaddr = _WormholeAddress() + eq = EventualQueue(Clock()) + ep = SubchannelListenerEndpoint(m, hostaddr, eq) f = mock.Mock() p1 = mock.Mock() p2 = mock.Mock() f.buildProtocol = mock.Mock(side_effect=[p1, p2]) - # OPEN that arrives before we ep.listen() should be queued + d = ep.listen(f) + eq.flush_sync() + self.assertNoResult(d) + self.assertEqual(f.buildProtocol.mock_calls, []) + + ep._main_channel_ready() + eq.flush_sync() + lp = self.successResultOf(d) + self.assertIsInstance(lp, SubchannelListeningPort) + + self.assertEqual(lp.getHost(), hostaddr) + # TODO: IListeningPort says we must provide this, but I don't know + # that anyone would ever call it. + lp.startListening() t1 = mock.Mock() peeraddr1 = _SubchannelAddress(1) ep._got_open(t1, peeraddr1) - d = ep.listen(f) - lp = self.successResultOf(d) - self.assertIsInstance(lp, SubchannelListeningPort) - - self.assertEqual(lp.getHost(), hostaddr) - lp.startListening() - self.assertEqual(t1.mock_calls, [mock.call._set_protocol(p1)]) self.assertEqual(p1.mock_calls, [mock.call.makeConnection(t1)]) + self.assertEqual(f.buildProtocol.mock_calls, [mock.call(peeraddr1)]) t2 = mock.Mock() peeraddr2 = _SubchannelAddress(2) @@ -94,5 +264,92 @@ class Endpoints(unittest.TestCase): self.assertEqual(t2.mock_calls, [mock.call._set_protocol(p2)]) self.assertEqual(p2.mock_calls, [mock.call.makeConnection(t2)]) + self.assertEqual(f.buildProtocol.mock_calls, [mock.call(peeraddr1), + mock.call(peeraddr2)]) lp.stopListening() # TODO: should this do more? + + def test_early_fail(self): + # listen, main_channel_fail + m = mock_manager() + m.allocate_subchannel_id = mock.Mock(return_value=0) + hostaddr = _WormholeAddress() + eq = EventualQueue(Clock()) + ep = SubchannelListenerEndpoint(m, hostaddr, eq) + + f = mock.Mock() + p1 = mock.Mock() + p2 = mock.Mock() + f.buildProtocol = mock.Mock(side_effect=[p1, p2]) + + d = ep.listen(f) + eq.flush_sync() + self.assertNoResult(d) + + ep._main_channel_failed(Failure(CannotDilateError())) + eq.flush_sync() + self.failureResultOf(d).check(CannotDilateError) + self.assertEqual(f.buildProtocol.mock_calls, []) + + def test_late_succeed(self): + # main_channel_ready, got_open, listen, got_open + m = mock_manager() + m.allocate_subchannel_id = mock.Mock(return_value=0) + hostaddr = _WormholeAddress() + eq = EventualQueue(Clock()) + ep = SubchannelListenerEndpoint(m, hostaddr, eq) + ep._main_channel_ready() + + f = mock.Mock() + p1 = mock.Mock() + p2 = mock.Mock() + f.buildProtocol = mock.Mock(side_effect=[p1, p2]) + + t1 = mock.Mock() + peeraddr1 = _SubchannelAddress(1) + ep._got_open(t1, peeraddr1) + eq.flush_sync() + + self.assertEqual(t1.mock_calls, []) + self.assertEqual(p1.mock_calls, []) + + d = ep.listen(f) + eq.flush_sync() + lp = self.successResultOf(d) + self.assertIsInstance(lp, SubchannelListeningPort) + self.assertEqual(lp.getHost(), hostaddr) + lp.startListening() + + self.assertEqual(t1.mock_calls, [mock.call._set_protocol(p1)]) + self.assertEqual(p1.mock_calls, [mock.call.makeConnection(t1)]) + self.assertEqual(f.buildProtocol.mock_calls, [mock.call(peeraddr1)]) + + t2 = mock.Mock() + peeraddr2 = _SubchannelAddress(2) + ep._got_open(t2, peeraddr2) + + self.assertEqual(t2.mock_calls, [mock.call._set_protocol(p2)]) + self.assertEqual(p2.mock_calls, [mock.call.makeConnection(t2)]) + self.assertEqual(f.buildProtocol.mock_calls, [mock.call(peeraddr1), + mock.call(peeraddr2)]) + + lp.stopListening() # TODO: should this do more? + + def test_late_fail(self): + # main_channel_fail, listen + m = mock_manager() + m.allocate_subchannel_id = mock.Mock(return_value=0) + hostaddr = _WormholeAddress() + eq = EventualQueue(Clock()) + ep = SubchannelListenerEndpoint(m, hostaddr, eq) + ep._main_channel_failed(Failure(CannotDilateError())) + + f = mock.Mock() + p1 = mock.Mock() + p2 = mock.Mock() + f.buildProtocol = mock.Mock(side_effect=[p1, p2]) + + d = ep.listen(f) + eq.flush_sync() + self.failureResultOf(d).check(CannotDilateError) + self.assertEqual(f.buildProtocol.mock_calls, []) diff --git a/src/wormhole/test/dilate/test_full.py b/src/wormhole/test/dilate/test_full.py index cc0e8b0..79c0019 100644 --- a/src/wormhole/test/dilate/test_full.py +++ b/src/wormhole/test/dilate/test_full.py @@ -46,24 +46,21 @@ class Full(ServerBase, unittest.TestCase): yield doBoth(w1.get_verifier(), w2.get_verifier()) print("connected") - eps1_d = w1.dilate() - eps2_d = w2.dilate() - (eps1, eps2) = yield doBoth(eps1_d, eps2_d) - (control_ep1, connect_ep1, listen_ep1) = eps1 - (control_ep2, connect_ep2, listen_ep2) = eps2 + eps1 = w1.dilate() + eps2 = w2.dilate() print("w.dilate ready") f1 = Factory() f1.protocol = L f1.d = Deferred() f1.d.addCallback(lambda data: eq.fire_eventually(data)) - d1 = control_ep1.connect(f1) + d1 = eps1.control.connect(f1) f2 = Factory() f2.protocol = L f2.d = Deferred() f2.d.addCallback(lambda data: eq.fire_eventually(data)) - d2 = control_ep2.connect(f2) + d2 = eps2.control.connect(f2) yield d1 yield d2 print("control endpoints connected") @@ -125,14 +122,12 @@ class Reconnect(ServerBase, unittest.TestCase): w2.set_code(code) yield doBoth(w1.get_verifier(), w2.get_verifier()) - eps1_d = w1.dilate() - eps2_d = w2.dilate() - (eps1, eps2) = yield doBoth(eps1_d, eps2_d) - (control_ep1, connect_ep1, listen_ep1) = eps1 - (control_ep2, connect_ep2, listen_ep2) = eps2 + eps1 = w1.dilate() + eps2 = w2.dilate() + print("w.dilate ready") f1 = ReconF(eq); f2 = ReconF(eq) - d1 = control_ep1.connect(f1); d2 = control_ep2.connect(f2) + d1 = eps1.control.connect(f1); d2 = eps2.control.connect(f2) yield d1 yield d2 @@ -194,14 +189,12 @@ class Reconnect(ServerBase, unittest.TestCase): w2.set_code(code) yield doBoth(w1.get_verifier(), w2.get_verifier()) - eps1_d = w1.dilate() - eps2_d = w2.dilate() - (eps1, eps2) = yield doBoth(eps1_d, eps2_d) - (control_ep1, connect_ep1, listen_ep1) = eps1 - (control_ep2, connect_ep2, listen_ep2) = eps2 + eps1 = w1.dilate() + eps2 = w2.dilate() + print("w.dilate ready") f1 = ReconF(eq); f2 = ReconF(eq) - d1 = control_ep1.connect(f1); d2 = control_ep2.connect(f2) + d1 = eps1.control.connect(f1); d2 = eps2.control.connect(f2) yield d1 yield d2 @@ -287,19 +280,17 @@ class Endpoints(ServerBase, unittest.TestCase): w2.set_code(code) yield doBoth(w1.get_verifier(), w2.get_verifier()) - eps1_d = w1.dilate() - eps2_d = w2.dilate() - (eps1, eps2) = yield doBoth(eps1_d, eps2_d) - (control_ep1, connect_ep1, listen_ep1) = eps1 - (control_ep2, connect_ep2, listen_ep2) = eps2 + eps1 = w1.dilate() + eps2 = w2.dilate() + print("w.dilate ready") f0 = ReconF(eq) - yield listen_ep2.listen(f0) + yield eps2.listen.listen(f0) from twisted.python import log f1 = ReconF(eq) log.msg("connecting") - p1_client = yield connect_ep1.connect(f1) + p1_client = yield eps1.connect.connect(f1) log.msg("sending c->s") p1_client.transport.write(b"hello from p1\n") data = yield f0.deferreds["dataReceived"] @@ -316,7 +307,7 @@ class Endpoints(ServerBase, unittest.TestCase): f0.resetDeferred("dataReceived") f1.resetDeferred("dataReceived") f2 = ReconF(eq) - p2_client = yield connect_ep1.connect(f2) + p2_client = yield eps1.connect.connect(f2) p2_server = yield f0.deferreds["connectionMade"] p2_server.transport.write(b"hello p2\n") data = yield f2.deferreds["dataReceived"] diff --git a/src/wormhole/test/dilate/test_manager.py b/src/wormhole/test/dilate/test_manager.py index e36e403..3fdd051 100644 --- a/src/wormhole/test/dilate/test_manager.py +++ b/src/wormhole/test/dilate/test_manager.py @@ -1,12 +1,11 @@ from __future__ import print_function, unicode_literals from zope.interface import alsoProvides from twisted.trial import unittest -from twisted.internet.defer import Deferred from twisted.internet.task import Clock, Cooperator -from twisted.internet.interfaces import IAddress +from twisted.internet.interfaces import IStreamServerEndpoint import mock from ...eventual import EventualQueue -from ..._interfaces import ISend, IDilationManager, ITerminator +from ..._interfaces import ISend, ITerminator, ISubChannel from ...util import dict_to_bytes from ..._dilation import roles from ..._dilation.manager import (Dilator, Manager, make_side, @@ -15,35 +14,57 @@ from ..._dilation.manager import (Dilator, Manager, make_side, UnexpectedKCM, UnknownMessageType) from ..._dilation.connection import Open, Data, Close, Ack, KCM, Ping, Pong +from ..._dilation.subchannel import _SubchannelAddress from .common import clear_mock_calls +class Holder(): + pass def make_dilator(): - reactor = object() - clock = Clock() - eq = EventualQueue(clock) + h = Holder() + h.reactor = object() + h.clock = Clock() + h.eq = EventualQueue(h.clock) term = mock.Mock(side_effect=lambda: True) # one write per Eventual tick def term_factory(): return term - coop = Cooperator(terminationPredicateFactory=term_factory, - scheduler=eq.eventually) - send = mock.Mock() - alsoProvides(send, ISend) - dil = Dilator(reactor, eq, coop) - terminator = mock.Mock() - alsoProvides(terminator, ITerminator) - dil.wire(send, terminator) - return dil, send, reactor, eq, clock, coop + h.coop = Cooperator(terminationPredicateFactory=term_factory, + scheduler=h.eq.eventually) + h.send = mock.Mock() + alsoProvides(h.send, ISend) + dil = Dilator(h.reactor, h.eq, h.coop) + h.terminator = mock.Mock() + alsoProvides(h.terminator, ITerminator) + dil.wire(h.send, h.terminator) + return dil, h class TestDilator(unittest.TestCase): - def test_manager_and_endpoints(self): - dil, send, reactor, eq, clock, coop = make_dilator() - d1 = dil.dilate() - d2 = dil.dilate() - self.assertNoResult(d1) - self.assertNoResult(d2) + # we should test the interleavings between: + # * application calls w.dilate() and gets back endpoints + # * wormhole gets: dilation key, VERSION, 0-n dilation messages + + def test_dilate_first(self): + (dil, h) = make_dilator() + side = object() + m = mock.Mock() + eps = object() + m.get_endpoints = mock.Mock(return_value=eps) + mm = mock.Mock(side_effect=[m]) + with mock.patch("wormhole._dilation.manager.Manager", mm), \ + mock.patch("wormhole._dilation.manager.make_side", + return_value=side): + eps1 = dil.dilate() + eps2 = dil.dilate() + self.assertIdentical(eps1, eps) + self.assertIdentical(eps1, eps2) + self.assertEqual(mm.mock_calls, [mock.call(h.send, side, None, + h.reactor, h.eq, h.coop, False)]) + + self.assertEqual(m.mock_calls, [mock.call.get_endpoints(), + mock.call.get_endpoints()]) + clear_mock_calls(m) key = b"key" transit_key = object() @@ -51,183 +72,108 @@ class TestDilator(unittest.TestCase): return_value=transit_key) as dk: dil.got_key(key) self.assertEqual(dk.mock_calls, [mock.call(key, b"dilation-v1", 32)]) - self.assertIdentical(dil._transit_key, transit_key) - self.assertNoResult(d1) - self.assertNoResult(d2) + self.assertEqual(m.mock_calls, [mock.call.got_dilation_key(transit_key)]) + clear_mock_calls(m) - host_addr = dil._host_addr + wv = object() + dil.got_wormhole_versions(wv) + self.assertEqual(m.mock_calls, [mock.call.got_wormhole_versions(wv)]) + clear_mock_calls(m) - peer_addr = object() - m_sca = mock.patch("wormhole._dilation.manager._SubchannelAddress", - return_value=peer_addr) - sc = mock.Mock() - m_sc = mock.patch("wormhole._dilation.manager.SubChannel", - return_value=sc) - scid0 = 0 - - m = mock.Mock() - alsoProvides(m, IDilationManager) - m.when_first_connected.return_value = wfc_d = Deferred() - with mock.patch("wormhole._dilation.manager.Manager", - return_value=m) as ml: - with mock.patch("wormhole._dilation.manager.make_side", - return_value="us"): - with m_sca, m_sc as m_sc_m: - dil.got_wormhole_versions({"can-dilate": ["1"]}) - # that should create the Manager - self.assertEqual(ml.mock_calls, [mock.call(send, "us", transit_key, - None, reactor, eq, coop, host_addr, False)]) - # and create subchannel0 - self.assertEqual(m_sc_m.mock_calls, - [mock.call(scid0, m, host_addr, peer_addr)]) - # and tell it to start, and get wait-for-it-to-connect Deferred - self.assertEqual(m.mock_calls, [mock.call.set_subchannel_zero(scid0, sc), - mock.call.start(), - mock.call.when_first_connected(), + dm1 = object() + dm2 = object() + dil.received_dilate(dm1) + dil.received_dilate(dm2) + self.assertEqual(m.mock_calls, [mock.call.received_dilation_message(dm1), + mock.call.received_dilation_message(dm2), ]) clear_mock_calls(m) - self.assertNoResult(d1) - self.assertNoResult(d2) - ce = mock.Mock() - m_ce = mock.patch("wormhole._dilation.manager.ControlEndpoint", - return_value=ce) - lep = object() - m_sle = mock.patch("wormhole._dilation.manager.SubchannelListenerEndpoint", - return_value=lep) + stopped_d = mock.Mock() + m.when_stopped = mock.Mock(return_value=stopped_d) + dil.stop() + self.assertEqual(m.mock_calls, [mock.call.stop(), + mock.call.when_stopped(), + ]) - with m_ce as m_ce_m, m_sle as m_sle_m: - wfc_d.callback(None) - eq.flush_sync() - self.assertEqual(m_ce_m.mock_calls, [mock.call(peer_addr)]) - self.assertEqual(ce.mock_calls, [mock.call._subchannel_zero_opened(sc)]) - self.assertEqual(m_sle_m.mock_calls, [mock.call(m, host_addr)]) - self.assertEqual(m.mock_calls, - [mock.call.set_listener_endpoint(lep), - ]) + def test_dilate_later(self): + (dil, h) = make_dilator() + m = mock.Mock() + mm = mock.Mock(side_effect=[m]) + + key = b"key" + transit_key = object() + with mock.patch("wormhole._dilation.manager.derive_key", + return_value=transit_key) as dk: + dil.got_key(key) + self.assertEqual(dk.mock_calls, [mock.call(key, b"dilation-v1", 32)]) + + wv = object() + dil.got_wormhole_versions(wv) + + dm1 = object() + dil.received_dilate(dm1) + + self.assertEqual(mm.mock_calls, []) + + with mock.patch("wormhole._dilation.manager.Manager", mm): + dil.dilate() + self.assertEqual(m.mock_calls, [mock.call.got_dilation_key(transit_key), + mock.call.got_wormhole_versions(wv), + mock.call.received_dilation_message(dm1), + mock.call.get_endpoints(), + ]) clear_mock_calls(m) - eps = self.successResultOf(d1) - self.assertEqual(eps, self.successResultOf(d2)) - d3 = dil.dilate() - eq.flush_sync() - self.assertEqual(eps, self.successResultOf(d3)) + dm2 = object() + dil.received_dilate(dm2) + self.assertEqual(m.mock_calls, [mock.call.received_dilation_message(dm2), + ]) - # all subsequent DILATE-n messages should get passed to the manager - self.assertEqual(m.mock_calls, []) - pleasemsg = dict(type="please", side="them") - dil.received_dilate(dict_to_bytes(pleasemsg)) - self.assertEqual(m.mock_calls, [mock.call.rx_PLEASE(pleasemsg)]) - clear_mock_calls(m) - - hintmsg = dict(type="connection-hints") - dil.received_dilate(dict_to_bytes(hintmsg)) - self.assertEqual(m.mock_calls, [mock.call.rx_HINTS(hintmsg)]) - clear_mock_calls(m) - - # we're nominally the LEADER, and the leader would not normally be - # receiving a RECONNECT, but since we've mocked out the Manager it - # won't notice - dil.received_dilate(dict_to_bytes(dict(type="reconnect"))) - self.assertEqual(m.mock_calls, [mock.call.rx_RECONNECT()]) - clear_mock_calls(m) - - dil.received_dilate(dict_to_bytes(dict(type="reconnecting"))) - self.assertEqual(m.mock_calls, [mock.call.rx_RECONNECTING()]) - clear_mock_calls(m) - - dil.received_dilate(dict_to_bytes(dict(type="unknown"))) - self.assertEqual(m.mock_calls, []) - self.flushLoggedErrors(UnknownDilationMessageType) + def test_stop_early(self): + (dil, h) = make_dilator() + # we stop before w.dilate(), so there is no Manager to stop + dil.stop() + self.assertEqual(h.terminator.mock_calls, [mock.call.stoppedD()]) def test_peer_cannot_dilate(self): - dil, send, reactor, eq, clock, coop = make_dilator() - d1 = dil.dilate() - self.assertNoResult(d1) + (dil, h) = make_dilator() + eps = dil.dilate() - dil._transit_key = b"\x01" * 32 + dil.got_key(b"\x01" * 32) dil.got_wormhole_versions({}) # missing "can-dilate" - eq.flush_sync() - f = self.failureResultOf(d1) - f.check(OldPeerCannotDilateError) + d = eps.connect.connect(None) + h.eq.flush_sync() + self.failureResultOf(d).check(OldPeerCannotDilateError) def test_disjoint_versions(self): - dil, send, reactor, eq, clock, coop = make_dilator() - d1 = dil.dilate() - self.assertNoResult(d1) + (dil, h) = make_dilator() + eps = dil.dilate() - dil._transit_key = b"key" + dil.got_key(b"\x01" * 32) dil.got_wormhole_versions({"can-dilate": [-1]}) - eq.flush_sync() - f = self.failureResultOf(d1) - f.check(OldPeerCannotDilateError) - - def test_early_dilate_messages(self): - dil, send, reactor, eq, clock, coop = make_dilator() - dil._transit_key = b"key" - d1 = dil.dilate() - host_addr = dil._host_addr - self.assertNoResult(d1) - pleasemsg = dict(type="please", side="them") - dil.received_dilate(dict_to_bytes(pleasemsg)) - hintmsg = dict(type="connection-hints") - dil.received_dilate(dict_to_bytes(hintmsg)) - - m = mock.Mock() - alsoProvides(m, IDilationManager) - m.when_first_connected.return_value = Deferred() - - scid0 = 0 - sc = mock.Mock() - m_sc = mock.patch("wormhole._dilation.manager.SubChannel", - return_value=sc) - - with mock.patch("wormhole._dilation.manager.Manager", - return_value=m) as ml: - with mock.patch("wormhole._dilation.manager.make_side", - return_value="us"): - with m_sc: - dil.got_wormhole_versions({"can-dilate": ["1"]}) - self.assertEqual(ml.mock_calls, [mock.call(send, "us", b"key", - None, reactor, eq, coop, host_addr, False)]) - self.assertEqual(m.mock_calls, [mock.call.set_subchannel_zero(scid0, sc), - mock.call.start(), - mock.call.rx_PLEASE(pleasemsg), - mock.call.rx_HINTS(hintmsg), - mock.call.when_first_connected()]) + d = eps.connect.connect(None) + h.eq.flush_sync() + self.failureResultOf(d).check(OldPeerCannotDilateError) def test_transit_relay(self): - dil, send, reactor, eq, clock, coop = make_dilator() - dil._transit_key = b"key" - host_addr = dil._host_addr - relay = object() - d1 = dil.dilate(transit_relay_location=relay) - self.assertNoResult(d1) - - scid0 = 0 - sc = mock.Mock() - m_sc = mock.patch("wormhole._dilation.manager.SubChannel", - return_value=sc) - - with mock.patch("wormhole._dilation.manager.Manager") as ml: - with mock.patch("wormhole._dilation.manager.make_side", - return_value="us"): - with m_sc: - dil.got_wormhole_versions({"can-dilate": ["1"]}) - self.assertEqual(ml.mock_calls, [mock.call(send, "us", b"key", - relay, reactor, eq, coop, host_addr, False), - mock.call().set_subchannel_zero(scid0, sc), - mock.call().start(), - mock.call().when_first_connected()]) - + (dil, h) = make_dilator() + transit_relay_location = object() + side = object() + m = mock.Mock() + mm = mock.Mock(side_effect=[m]) + with mock.patch("wormhole._dilation.manager.Manager", mm), \ + mock.patch("wormhole._dilation.manager.make_side", + return_value=side): + dil.dilate(transit_relay_location) + self.assertEqual(mm.mock_calls, [mock.call(h.send, side, transit_relay_location, + h.reactor, h.eq, h.coop, False)]) LEADER = "ff3456abcdef" FOLLOWER = "123456abcdef" def make_manager(leader=True): - class Holder: - pass h = Holder() h.send = mock.Mock() alsoProvides(h.send, ISend) @@ -250,11 +196,19 @@ def make_manager(leader=True): h.Inbound = mock.Mock(return_value=h.inbound) h.outbound = mock.Mock() h.Outbound = mock.Mock(return_value=h.outbound) - h.hostaddr = mock.Mock() - alsoProvides(h.hostaddr, IAddress) - with mock.patch("wormhole._dilation.manager.Inbound", h.Inbound): - with mock.patch("wormhole._dilation.manager.Outbound", h.Outbound): - m = Manager(h.send, side, h.key, h.relay, h.reactor, h.eq, h.coop, h.hostaddr) + h.sc0 = mock.Mock() + alsoProvides(h.sc0, ISubChannel) + h.SubChannel = mock.Mock(return_value=h.sc0) + h.listen_ep = mock.Mock() + alsoProvides(h.listen_ep, IStreamServerEndpoint) + with mock.patch("wormhole._dilation.manager.Inbound", h.Inbound), \ + mock.patch("wormhole._dilation.manager.Outbound", h.Outbound), \ + mock.patch("wormhole._dilation.manager.SubChannel", h.SubChannel), \ + mock.patch("wormhole._dilation.manager.SubchannelListenerEndpoint", + return_value=h.listen_ep): + m = Manager(h.send, side, h.relay, h.reactor, h.eq, h.coop) + h.hostaddr = m._host_addr + m.got_dilation_key(h.key) return m, h @@ -272,16 +226,26 @@ class TestManager(unittest.TestCase): self.assertEqual(h.send.mock_calls, []) self.assertEqual(h.Inbound.mock_calls, [mock.call(m, h.hostaddr)]) self.assertEqual(h.Outbound.mock_calls, [mock.call(m, h.coop)]) + scid0 = 0 + sc0_peer_addr = _SubchannelAddress(scid0) + self.assertEqual(h.SubChannel.mock_calls, [ + mock.call(scid0, m, m._host_addr, sc0_peer_addr), + ]) + self.assertEqual(h.inbound.mock_calls, [ + mock.call.set_subchannel_zero(scid0, h.sc0), + mock.call.set_listener_endpoint(h.listen_ep) + ]) + clear_mock_calls(h.inbound) - m.start() + m.got_wormhole_versions({"can-dilate": ["1"]}) self.assertEqual(h.send.mock_calls, [ mock.call.send("dilate-0", dict_to_bytes({"type": "please", "side": LEADER})) ]) clear_mock_calls(h.send) - wfc_d = m.when_first_connected() - self.assertNoResult(wfc_d) + listen_d = m.get_endpoints().listen.listen(None) + self.assertNoResult(listen_d) # ignore early hints m.rx_HINTS({}) @@ -303,7 +267,7 @@ class TestManager(unittest.TestCase): self.assertEqual(c.mock_calls, [mock.call.start()]) clear_mock_calls(connector, c) - self.assertNoResult(wfc_d) + self.assertNoResult(listen_d) # now any inbound hints should get passed to our Connector with mock.patch("wormhole._dilation.manager.parse_hint", @@ -323,7 +287,7 @@ class TestManager(unittest.TestCase): clear_mock_calls(h.send) # the first successful connection fires when_first_connected(), so - # the Dilator can create and return the endpoints + # the endpoints can activate c1 = mock.Mock() m.connector_connection_made(c1) @@ -331,23 +295,6 @@ class TestManager(unittest.TestCase): self.assertEqual(h.outbound.mock_calls, [mock.call.use_connection(c1)]) clear_mock_calls(h.inbound, h.outbound) - h.eq.flush_sync() - self.successResultOf(wfc_d) # fires with None - wfc_d2 = m.when_first_connected() - h.eq.flush_sync() - self.successResultOf(wfc_d2) - - scid0 = 0 - sc0 = mock.Mock() - m.set_subchannel_zero(scid0, sc0) - listen_ep = mock.Mock() - m.set_listener_endpoint(listen_ep) - self.assertEqual(h.inbound.mock_calls, [ - mock.call.set_subchannel_zero(scid0, sc0), - mock.call.set_listener_endpoint(listen_ep), - ]) - clear_mock_calls(h.inbound) - # the Leader making a new outbound channel should get scid=1 scid1 = 1 self.assertEqual(m.allocate_subchannel_id(), scid1) @@ -494,12 +441,13 @@ class TestManager(unittest.TestCase): def test_follower(self): m, h = make_manager(leader=False) - m.start() + m.got_wormhole_versions({"can-dilate": ["1"]}) self.assertEqual(h.send.mock_calls, [ mock.call.send("dilate-0", dict_to_bytes({"type": "please", "side": FOLLOWER})) ]) clear_mock_calls(h.send) + clear_mock_calls(h.inbound) c = mock.Mock() connector = mock.Mock(return_value=c) @@ -629,6 +577,7 @@ class TestManager(unittest.TestCase): def test_subchannel(self): m, h = make_manager(leader=True) + clear_mock_calls(h.inbound) sc = object() m.subchannel_pauseProducing(sc) @@ -665,3 +614,14 @@ class TestManager(unittest.TestCase): self.assertEqual(h.outbound.mock_calls, [ mock.call.subchannel_closed(4, sc)]) clear_mock_calls(h.inbound, h.outbound) + + def test_unknown_message(self): + # receive a PLEASE with the same side as us: shouldn't happen + m, h = make_manager(leader=True) + m.start() + + m.received_dilation_message(dict_to_bytes(dict(type="unknown"))) + self.flushLoggedErrors(UnknownDilationMessageType) + + # TODO: test transit relay is used +