diff --git a/src/wormhole/_dilation/manager.py b/src/wormhole/_dilation/manager.py index 14b0c89..19430bb 100644 --- a/src/wormhole/_dilation/manager.py +++ b/src/wormhole/_dilation/manager.py @@ -2,13 +2,20 @@ from __future__ import print_function, unicode_literals import six import os from collections import deque +try: + # py >= 3.3 + from collections.abc import Sequence +except ImportError: + # py 2 and py3 < 3.3 + from collections import Sequence from attr import attrs, attrib from attr.validators import provides, instance_of, optional from automat import MethodicalMachine from zope.interface import implementer -from twisted.internet.defer import Deferred, inlineCallbacks, returnValue -from twisted.internet.interfaces import IAddress -from twisted.python import log +from twisted.internet.defer import Deferred +from twisted.internet.interfaces import (IStreamClientEndpoint, + IStreamServerEndpoint) +from twisted.python import log, failure from .._interfaces import IDilator, IDilationManager, ISend, ITerminator from ..util import dict_to_bytes, bytes_to_dict, bytes_to_hexstr from ..observer import OneShotObserver @@ -47,6 +54,15 @@ class UnexpectedKCM(Exception): class UnknownMessageType(Exception): pass +@attrs +class EndpointRecord(Sequence): + control = attrib(validator=provides(IStreamClientEndpoint)) + connect = attrib(validator=provides(IStreamClientEndpoint)) + listen = attrib(validator=provides(IStreamServerEndpoint)) + def __len__(self): + return 3 + def __getitem__(self, n): + return (self.control, self.connect, self.listen)[n] def make_side(): return bytes_to_hexstr(os.urandom(6)) @@ -93,13 +109,14 @@ def make_side(): class Manager(object): _S = attrib(validator=provides(ISend), repr=False) _my_side = attrib(validator=instance_of(type(u""))) - _transit_key = attrib(validator=instance_of(bytes), repr=False) _transit_relay_location = attrib(validator=optional(instance_of(str))) _reactor = attrib(repr=False) _eventual_queue = attrib(repr=False) _cooperator = attrib(repr=False) - _host_addr = attrib(validator=provides(IAddress)) - _no_listen = attrib(default=False) + # TODO: can this validator work when the parameter is optional? + _no_listen = attrib(validator=instance_of(bool), default=False) + + _dilation_key = None _tor = None # TODO _timing = None # TODO _next_subchannel_id = None # initialized in choose_role @@ -111,10 +128,10 @@ class Manager(object): self._got_versions_d = Deferred() self._my_role = None # determined upon rx_PLEASE + self._host_addr = _WormholeAddress() self._connection = None self._made_first_connection = False - self._first_connected = OneShotObserver(self._eventual_queue) self._stopped = OneShotObserver(self._eventual_queue) self._debug_stall_connector = False @@ -127,18 +144,81 @@ class Manager(object): self._inbound = Inbound(self, self._host_addr) self._outbound = Outbound(self, self._cooperator) # from us to peer - def set_listener_endpoint(self, listener_endpoint): - self._inbound.set_listener_endpoint(listener_endpoint) - - def set_subchannel_zero(self, scid0, sc0): + # We must open subchannel0 early, since messages may arrive very + # quickly once the connection is established. This subchannel may or + # may not ever get revealed to the caller, since the peer might not + # even be capable of dilation. + scid0 = 0 + peer_addr0 = _SubchannelAddress(scid0) + sc0 = SubChannel(scid0, self, self._host_addr, peer_addr0) self._inbound.set_subchannel_zero(scid0, sc0) - def when_first_connected(self): - return self._first_connected.when_fired() + # we can open non-zero subchannels as soon as we get our first + # connection, and we can make the Endpoints even earlier + control_ep = ControlEndpoint(peer_addr0, sc0, self._eventual_queue) + connect_ep = SubchannelConnectorEndpoint(self, self._host_addr, self._eventual_queue) + listen_ep = SubchannelListenerEndpoint(self, self._host_addr, self._eventual_queue) + # TODO: let inbound/outbound create the endpoints, then return them + # to us + self._inbound.set_listener_endpoint(listen_ep) + + self._endpoints = EndpointRecord(control_ep, connect_ep, listen_ep) + + def get_endpoints(self): + return self._endpoints + + def got_dilation_key(self, key): + assert isinstance(key, bytes) + self._dilation_key = key + + def got_wormhole_versions(self, their_wormhole_versions): + # this always happens before received_dilation_message + dilation_version = None + their_dilation_versions = set(their_wormhole_versions.get("can-dilate", [])) + my_versions = set(DILATION_VERSIONS) + shared_versions = my_versions.intersection(their_dilation_versions) + if "1" in shared_versions: + dilation_version = "1" + + # dilation_version is the best mutually-compatible version we have + # with the peer, or None if we have nothing in common + + if not dilation_version: # "1" or None + # TODO: be more specific about the error. dilation_version==None + # means we had no version in common with them, which could either + # be because they're so old they don't dilate at all, or because + # they're so new that they no longer accomodate our old version + self.fail(failure.Failure(OldPeerCannotDilateError())) + + self.start() + + def fail(self, f): + self._endpoints.control._main_channel_failed(f) + self._endpoints.connect._main_channel_failed(f) + self._endpoints.listen._main_channel_failed(f) + + def received_dilation_message(self, plaintext): + # this receives new in-order DILATE-n payloads, decrypted but not + # de-JSONed. + + message = bytes_to_dict(plaintext) + type = message["type"] + if type == "please": + self.rx_PLEASE(message) + elif type == "connection-hints": + self.rx_HINTS(message) + elif type == "reconnect": + self.rx_RECONNECT() + elif type == "reconnecting": + self.rx_RECONNECTING() + else: + log.err(UnknownDilationMessageType(message)) + return def when_stopped(self): return self._stopped.when_fired() + def send_dilation_phase(self, **fields): dilation_phase = self._next_dilation_phase self._next_dilation_phase += 1 @@ -204,7 +284,9 @@ class Manager(object): self._outbound.use_connection(c) # does c.registerProducer if not self._made_first_connection: self._made_first_connection = True - self._first_connected.fire(None) + self._endpoints.control._main_channel_ready() + self._endpoints.connect._main_channel_ready() + self._endpoints.listen._main_channel_ready() pass def connector_connection_lost(self): @@ -272,16 +354,11 @@ class Manager(object): # state machine - # We are born WANTING after the local app calls w.dilate(). We start - # CONNECTING when we receive PLEASE from the remote side - - def start(self): - self.send_please() - - def send_please(self): - self.send_dilation_phase(type="please", side=self._my_side) - @m.state(initial=True) + def WAITING(self): + pass # pragma: no cover + + @m.state() def WANTING(self): pass # pragma: no cover @@ -313,6 +390,10 @@ class Manager(object): def STOPPED(self): pass # pragma: no cover + @m.input() + def start(self): + pass # pragma: no cover + @m.input() def rx_PLEASE(self, message): pass # pragma: no cover @@ -350,6 +431,10 @@ class Manager(object): def stop(self): pass # pragma: no cover + @m.output() + def send_please(self): + self.send_dilation_phase(type="please", side=self._my_side) + @m.output() def choose_role(self, message): their_side = message["side"] @@ -378,7 +463,8 @@ class Manager(object): def _start_connecting(self): assert self._my_role is not None - self._connector = Connector(self._transit_key, + assert self._dilation_key is not None + self._connector = Connector(self._dilation_key, self._transit_relay_location, self, self._reactor, self._eventual_queue, @@ -422,6 +508,11 @@ class Manager(object): def notify_stopped(self): self._stopped.fire(None) + # We are born WAITING after the local app calls w.dilate(). We enter + # WANTING (and send a PLEASE) when we learn of a mutually-compatible + # dilation_version. + WAITING.upon(start, enter=WANTING, outputs=[send_please]) + # we start CONNECTING when we get rx_PLEASE WANTING.upon(rx_PLEASE, enter=CONNECTING, outputs=[choose_role, start_connecting_ignore_message]) @@ -489,12 +580,10 @@ class Dilator(object): _cooperator = attrib() def __attrs_post_init__(self): - self._got_versions_d = Deferred() - self._started = False - self._endpoints = OneShotObserver(self._eventual_queue) - self._pending_inbound_dilate_messages = deque() self._manager = None - self._host_addr = _WormholeAddress() + self._pending_dilation_key = None + self._pending_wormhole_versions = None + self._pending_inbound_dilate_messages = deque() def wire(self, sender, terminator): self._S = ISend(sender) @@ -502,77 +591,35 @@ class Dilator(object): # this is the primary entry point, called when w.dilate() is invoked def dilate(self, transit_relay_location=None, no_listen=False): - self._transit_relay_location = transit_relay_location - if not self._started: - self._started = True - self._start(no_listen).addBoth(self._endpoints.fire) - return self._endpoints.when_fired() - - @inlineCallbacks - def _start(self, no_listen): - # first, we wait until we hear the VERSION message, which tells us 1: - # the PAKE key works, so we can talk securely, 2: that they can do - # dilation at all (if they can't then w.dilate() errbacks) - - dilation_version = yield self._got_versions_d - - # TODO: we could probably return the endpoints earlier, if we flunk - # any connection/listen attempts upon OldPeerCannotDilateError, or - # if/when we give up on the initial connection - - if not dilation_version: # "1" or None - # TODO: be more specific about the error. dilation_version==None - # means we had no version in common with them, which could either - # be because they're so old they don't dilate at all, or because - # they're so new that they no longer accomodate our old version - raise OldPeerCannotDilateError() - - my_dilation_side = make_side() - self._manager = Manager(self._S, my_dilation_side, - self._transit_key, - self._transit_relay_location, - self._reactor, self._eventual_queue, - self._cooperator, self._host_addr, no_listen) - # We must open subchannel0 early, since messages may arrive very - # quickly once the connection is established. This subchannel may or - # may not ever get revealed to the caller, since the peer might not - # even be capable of dilation. - scid0 = 0 - peer_addr0 = _SubchannelAddress(scid0) - sc0 = SubChannel(scid0, self._manager, self._host_addr, peer_addr0) - self._manager.set_subchannel_zero(scid0, sc0) - - # we can open non-zero subchannels as soon as we get our first - # connection, and we can make the Endpoints even earlier - control_ep = ControlEndpoint(peer_addr0) - control_ep._subchannel_zero_opened(sc0) - connect_ep = SubchannelConnectorEndpoint(self._manager, self._host_addr) - - listen_ep = SubchannelListenerEndpoint(self._manager, self._host_addr) - self._manager.set_listener_endpoint(listen_ep) - - self._manager.start() - - while self._pending_inbound_dilate_messages: - plaintext = self._pending_inbound_dilate_messages.popleft() - self.received_dilate(plaintext) - - yield self._manager.when_first_connected() - - endpoints = (control_ep, connect_ep, listen_ep) - returnValue(endpoints) + if not self._manager: + # build the manager right away, and tell it later when the + # VERSIONS message arrives, and also when the dilation_key is set + my_dilation_side = make_side() + m = Manager(self._S, my_dilation_side, + transit_relay_location, + self._reactor, self._eventual_queue, + self._cooperator, no_listen) + self._manager = m + if self._pending_dilation_key is not None: + m.got_dilation_key(self._pending_dilation_key) + if self._pending_wormhole_versions: + m.got_wormhole_versions(self._pending_wormhole_versions) + while self._pending_inbound_dilate_messages: + plaintext = self._pending_inbound_dilate_messages.popleft() + m.received_dilation_message(plaintext) + return self._manager.get_endpoints() # Called by Terminator after everything else (mailbox, nameplate, server # connection) has shut down. Expects to fire T.stoppedD() when Dilator is # stopped too. def stop(self): - if not self._started: - self._T.stoppedD() - return - if self._started: + if self._manager: self._manager.stop() # TODO: avoid Deferreds for control flow, hard to serialize self._manager.when_stopped().addCallback(lambda _: self._T.stoppedD()) + else: + self._T.stoppedD() + return # TODO: tolerate multiple calls # from Boss @@ -582,39 +629,20 @@ class Dilator(object): # to tolerate either ordering purpose = b"dilation-v1" LENGTH = 32 # TODO: whatever Noise wants, I guess - self._transit_key = derive_key(key, purpose, LENGTH) + dilation_key = derive_key(key, purpose, LENGTH) + if self._manager: + self._manager.got_dilation_key(dilation_key) + else: + self._pending_dilation_key = dilation_key def got_wormhole_versions(self, their_wormhole_versions): - assert self._transit_key is not None - # this always happens before received_dilate - dilation_version = None - their_dilation_versions = set(their_wormhole_versions.get("can-dilate", [])) - my_versions = set(DILATION_VERSIONS) - shared_versions = my_versions.intersection(their_dilation_versions) - if "1" in shared_versions: - dilation_version = "1" - self._got_versions_d.callback(dilation_version) + if self._manager: + self._manager.got_wormhole_versions(their_wormhole_versions) + else: + self._pending_wormhole_versions = their_wormhole_versions def received_dilate(self, plaintext): - # this receives new in-order DILATE-n payloads, decrypted but not - # de-JSONed. - - # this can appear before our .dilate() method is called, in which case - # we queue them for later if not self._manager: self._pending_inbound_dilate_messages.append(plaintext) - return - - message = bytes_to_dict(plaintext) - type = message["type"] - if type == "please": - self._manager.rx_PLEASE(message) - elif type == "connection-hints": - self._manager.rx_HINTS(message) - elif type == "reconnect": - self._manager.rx_RECONNECT() - elif type == "reconnecting": - self._manager.rx_RECONNECTING() else: - log.err(UnknownDilationMessageType(message)) - return + self._manager.received_dilation_message(plaintext) diff --git a/src/wormhole/_dilation/subchannel.py b/src/wormhole/_dilation/subchannel.py index 32c9cff..e1e811c 100644 --- a/src/wormhole/_dilation/subchannel.py +++ b/src/wormhole/_dilation/subchannel.py @@ -3,8 +3,7 @@ from collections import deque from attr import attrs, attrib from attr.validators import instance_of, provides from zope.interface import implementer -from twisted.internet.defer import (Deferred, inlineCallbacks, returnValue, - succeed) +from twisted.internet.defer import inlineCallbacks, returnValue from twisted.internet.interfaces import (ITransport, IProducer, IConsumer, IAddress, IListeningPort, IStreamClientEndpoint, @@ -12,6 +11,7 @@ from twisted.internet.interfaces import (ITransport, IProducer, IConsumer, from twisted.internet.error import ConnectionDone from automat import MethodicalMachine from .._interfaces import ISubChannel, IDilationManager +from ..observer import OneShotObserver # each subchannel frame (the data passed into transport.write(data)) gets a # 9-byte header prefix (type, subchannel id, and sequence number), then gets @@ -217,27 +217,33 @@ class SubChannel(object): @implementer(IStreamClientEndpoint) +@attrs class ControlEndpoint(object): + _peer_addr = attrib(validator=provides(IAddress)) + _subchannel_zero = attrib(validator=provides(ISubChannel)) + _eventual_queue = attrib(repr=False) _used = False - def __init__(self, peer_addr): - self._subchannel_zero = Deferred() - self._peer_addr = peer_addr + def __attrs_post_init__(self): self._once = Once(SingleUseEndpointError) + self._wait_for_main_channel = OneShotObserver(self._eventual_queue) # from manager - def _subchannel_zero_opened(self, subchannel): - assert ISubChannel.providedBy(subchannel), subchannel - self._subchannel_zero.callback(subchannel) + + def _main_channel_ready(self): + self._wait_for_main_channel.fire(None) + def _main_channel_failed(self, f): + self._wait_for_main_channel.error(f) @inlineCallbacks def connect(self, protocolFactory): # return Deferred that fires with IProtocol or Failure(ConnectError) self._once() - t = yield self._subchannel_zero + yield self._wait_for_main_channel.when_fired() p = protocolFactory.buildProtocol(self._peer_addr) - t._set_protocol(p) - p.makeConnection(t) # set p.transport = t and call connectionMade() + self._subchannel_zero._set_protocol(p) + # this sets p.transport and calls p.connectionMade() + p.makeConnection(self._subchannel_zero) returnValue(p) @@ -246,9 +252,21 @@ class ControlEndpoint(object): class SubchannelConnectorEndpoint(object): _manager = attrib(validator=provides(IDilationManager)) _host_addr = attrib(validator=instance_of(_WormholeAddress)) + _eventual_queue = attrib(repr=False) + def __attrs_post_init__(self): + self._connection_deferreds = deque() + self._wait_for_main_channel = OneShotObserver(self._eventual_queue) + + def _main_channel_ready(self): + self._wait_for_main_channel.fire(None) + def _main_channel_failed(self, f): + self._wait_for_main_channel.error(f) + + @inlineCallbacks def connect(self, protocolFactory): # return Deferred that fires with IProtocol or Failure(ConnectError) + yield self._wait_for_main_channel.when_fired() scid = self._manager.allocate_subchannel_id() self._manager.send_open(scid) peer_addr = _SubchannelAddress(scid) @@ -259,7 +277,7 @@ class SubchannelConnectorEndpoint(object): p = protocolFactory.buildProtocol(peer_addr) sc._set_protocol(p) p.makeConnection(sc) # set p.transport = sc and call connectionMade() - return succeed(p) + returnValue(p) @implementer(IStreamServerEndpoint) @@ -267,10 +285,13 @@ class SubchannelConnectorEndpoint(object): class SubchannelListenerEndpoint(object): _manager = attrib(validator=provides(IDilationManager)) _host_addr = attrib(validator=provides(IAddress)) + _eventual_queue = attrib(repr=False) def __attrs_post_init__(self): + self._once = Once(SingleUseEndpointError) self._factory = None self._pending_opens = deque() + self._wait_for_main_channel = OneShotObserver(self._eventual_queue) # from manager (actually Inbound) def _got_open(self, t, peer_addr): @@ -284,15 +305,23 @@ class SubchannelListenerEndpoint(object): t._set_protocol(p) p.makeConnection(t) + def _main_channel_ready(self): + self._wait_for_main_channel.fire(None) + def _main_channel_failed(self, f): + self._wait_for_main_channel.error(f) + # IStreamServerEndpoint + @inlineCallbacks def listen(self, protocolFactory): + self._once() + yield self._wait_for_main_channel.when_fired() self._factory = protocolFactory while self._pending_opens: (t, peer_addr) = self._pending_opens.popleft() self._connect(t, peer_addr) lp = SubchannelListeningPort(self._host_addr) - return succeed(lp) + returnValue(lp) @implementer(IListeningPort) diff --git a/src/wormhole/test/dilate/test_endpoints.py b/src/wormhole/test/dilate/test_endpoints.py index 08c5600..647fb9d 100644 --- a/src/wormhole/test/dilate/test_endpoints.py +++ b/src/wormhole/test/dilate/test_endpoints.py @@ -2,7 +2,10 @@ from __future__ import print_function, unicode_literals import mock from zope.interface import alsoProvides from twisted.trial import unittest +from twisted.internet.task import Clock +from twisted.python.failure import Failure from ..._interfaces import ISubChannel +from ...eventual import EventualQueue from ..._dilation.subchannel import (ControlEndpoint, SubchannelConnectorEndpoint, SubchannelListenerEndpoint, @@ -11,12 +14,18 @@ from ..._dilation.subchannel import (ControlEndpoint, SingleUseEndpointError) from .common import mock_manager +class CannotDilateError(Exception): + pass -class Endpoints(unittest.TestCase): - def test_control(self): +class Control(unittest.TestCase): + def test_early_succeed(self): + # ep.connect() is called before dilation can proceed scid0 = 0 peeraddr = _SubchannelAddress(scid0) - ep = ControlEndpoint(peeraddr) + sc0 = mock.Mock() + alsoProvides(sc0, ISubChannel) + eq = EventualQueue(Clock()) + ep = ControlEndpoint(peeraddr, sc0, eq) f = mock.Mock() p = mock.Mock() @@ -24,29 +33,105 @@ class Endpoints(unittest.TestCase): d = ep.connect(f) self.assertNoResult(d) - t = mock.Mock() - alsoProvides(t, ISubChannel) - ep._subchannel_zero_opened(t) + ep._main_channel_ready() + eq.flush_sync() + self.assertIdentical(self.successResultOf(d), p) self.assertEqual(f.buildProtocol.mock_calls, [mock.call(peeraddr)]) - self.assertEqual(t.mock_calls, [mock.call._set_protocol(p)]) - self.assertEqual(p.mock_calls, [mock.call.makeConnection(t)]) + self.assertEqual(sc0.mock_calls, [mock.call._set_protocol(p)]) + self.assertEqual(p.mock_calls, [mock.call.makeConnection(sc0)]) d = ep.connect(f) self.failureResultOf(d, SingleUseEndpointError) - def assert_makeConnection(self, mock_calls): + def test_early_fail(self): + # ep.connect() is called before dilation is abandoned + scid0 = 0 + peeraddr = _SubchannelAddress(scid0) + sc0 = mock.Mock() + alsoProvides(sc0, ISubChannel) + eq = EventualQueue(Clock()) + ep = ControlEndpoint(peeraddr, sc0, eq) + + f = mock.Mock() + p = mock.Mock() + f.buildProtocol = mock.Mock(return_value=p) + d = ep.connect(f) + self.assertNoResult(d) + + ep._main_channel_failed(Failure(CannotDilateError())) + eq.flush_sync() + + self.failureResultOf(d).check(CannotDilateError) + self.assertEqual(f.buildProtocol.mock_calls, []) + self.assertEqual(sc0.mock_calls, []) + + d = ep.connect(f) + self.failureResultOf(d, SingleUseEndpointError) + + def test_late_succeed(self): + # dilation can proceed, then ep.connect() is called + scid0 = 0 + peeraddr = _SubchannelAddress(scid0) + sc0 = mock.Mock() + alsoProvides(sc0, ISubChannel) + eq = EventualQueue(Clock()) + ep = ControlEndpoint(peeraddr, sc0, eq) + + ep._main_channel_ready() + + f = mock.Mock() + p = mock.Mock() + f.buildProtocol = mock.Mock(return_value=p) + d = ep.connect(f) + eq.flush_sync() + self.assertIdentical(self.successResultOf(d), p) + self.assertEqual(f.buildProtocol.mock_calls, [mock.call(peeraddr)]) + self.assertEqual(sc0.mock_calls, [mock.call._set_protocol(p)]) + self.assertEqual(p.mock_calls, [mock.call.makeConnection(sc0)]) + + d = ep.connect(f) + self.failureResultOf(d, SingleUseEndpointError) + + def test_late_fail(self): + # dilation is abandoned, then ep.connect() is called + scid0 = 0 + peeraddr = _SubchannelAddress(scid0) + sc0 = mock.Mock() + alsoProvides(sc0, ISubChannel) + eq = EventualQueue(Clock()) + ep = ControlEndpoint(peeraddr, sc0, eq) + + ep._main_channel_failed(Failure(CannotDilateError())) + + f = mock.Mock() + p = mock.Mock() + f.buildProtocol = mock.Mock(return_value=p) + d = ep.connect(f) + eq.flush_sync() + + self.failureResultOf(d).check(CannotDilateError) + self.assertEqual(f.buildProtocol.mock_calls, []) + self.assertEqual(sc0.mock_calls, []) + + d = ep.connect(f) + self.failureResultOf(d, SingleUseEndpointError) + +class Endpoints(unittest.TestCase): + def OFFassert_makeConnection(self, mock_calls): self.assertEqual(len(mock_calls), 1) self.assertEqual(mock_calls[0][0], "makeConnection") self.assertEqual(len(mock_calls[0][1]), 1) return mock_calls[0][1][0] - def test_connector(self): +class Connector(unittest.TestCase): + def test_early_succeed(self): m = mock_manager() m.allocate_subchannel_id = mock.Mock(return_value=0) hostaddr = _WormholeAddress() peeraddr = _SubchannelAddress(0) - ep = SubchannelConnectorEndpoint(m, hostaddr) + eq = EventualQueue(Clock()) + ep = SubchannelConnectorEndpoint(m, hostaddr, eq) f = mock.Mock() p = mock.Mock() @@ -55,38 +140,123 @@ class Endpoints(unittest.TestCase): with mock.patch("wormhole._dilation.subchannel.SubChannel", return_value=t) as sc: d = ep.connect(f) + eq.flush_sync() + self.assertNoResult(d) + ep._main_channel_ready() + eq.flush_sync() + self.assertIdentical(self.successResultOf(d), p) self.assertEqual(f.buildProtocol.mock_calls, [mock.call(peeraddr)]) self.assertEqual(sc.mock_calls, [mock.call(0, m, hostaddr, peeraddr)]) self.assertEqual(t.mock_calls, [mock.call._set_protocol(p)]) self.assertEqual(p.mock_calls, [mock.call.makeConnection(t)]) - def test_listener(self): + def test_early_fail(self): m = mock_manager() m.allocate_subchannel_id = mock.Mock(return_value=0) hostaddr = _WormholeAddress() - ep = SubchannelListenerEndpoint(m, hostaddr) + eq = EventualQueue(Clock()) + ep = SubchannelConnectorEndpoint(m, hostaddr, eq) + + f = mock.Mock() + p = mock.Mock() + t = mock.Mock() + f.buildProtocol = mock.Mock(return_value=p) + with mock.patch("wormhole._dilation.subchannel.SubChannel", + return_value=t) as sc: + d = ep.connect(f) + eq.flush_sync() + self.assertNoResult(d) + ep._main_channel_failed(Failure(CannotDilateError())) + eq.flush_sync() + + self.failureResultOf(d).check(CannotDilateError) + self.assertEqual(f.buildProtocol.mock_calls, []) + self.assertEqual(sc.mock_calls, []) + self.assertEqual(t.mock_calls, []) + + def test_late_succeed(self): + m = mock_manager() + m.allocate_subchannel_id = mock.Mock(return_value=0) + hostaddr = _WormholeAddress() + peeraddr = _SubchannelAddress(0) + eq = EventualQueue(Clock()) + ep = SubchannelConnectorEndpoint(m, hostaddr, eq) + ep._main_channel_ready() + + f = mock.Mock() + p = mock.Mock() + t = mock.Mock() + f.buildProtocol = mock.Mock(return_value=p) + with mock.patch("wormhole._dilation.subchannel.SubChannel", + return_value=t) as sc: + d = ep.connect(f) + eq.flush_sync() + + self.assertIdentical(self.successResultOf(d), p) + self.assertEqual(f.buildProtocol.mock_calls, [mock.call(peeraddr)]) + self.assertEqual(sc.mock_calls, [mock.call(0, m, hostaddr, peeraddr)]) + self.assertEqual(t.mock_calls, [mock.call._set_protocol(p)]) + self.assertEqual(p.mock_calls, [mock.call.makeConnection(t)]) + + def test_late_fail(self): + m = mock_manager() + m.allocate_subchannel_id = mock.Mock(return_value=0) + hostaddr = _WormholeAddress() + eq = EventualQueue(Clock()) + ep = SubchannelConnectorEndpoint(m, hostaddr, eq) + ep._main_channel_failed(Failure(CannotDilateError())) + + f = mock.Mock() + p = mock.Mock() + t = mock.Mock() + f.buildProtocol = mock.Mock(return_value=p) + with mock.patch("wormhole._dilation.subchannel.SubChannel", + return_value=t) as sc: + d = ep.connect(f) + eq.flush_sync() + + self.failureResultOf(d).check(CannotDilateError) + self.assertEqual(f.buildProtocol.mock_calls, []) + self.assertEqual(sc.mock_calls, []) + self.assertEqual(t.mock_calls, []) + +class Listener(unittest.TestCase): + def test_early_succeed(self): + # listen, main_channel_ready, got_open, got_open + m = mock_manager() + m.allocate_subchannel_id = mock.Mock(return_value=0) + hostaddr = _WormholeAddress() + eq = EventualQueue(Clock()) + ep = SubchannelListenerEndpoint(m, hostaddr, eq) f = mock.Mock() p1 = mock.Mock() p2 = mock.Mock() f.buildProtocol = mock.Mock(side_effect=[p1, p2]) - # OPEN that arrives before we ep.listen() should be queued + d = ep.listen(f) + eq.flush_sync() + self.assertNoResult(d) + self.assertEqual(f.buildProtocol.mock_calls, []) + + ep._main_channel_ready() + eq.flush_sync() + lp = self.successResultOf(d) + self.assertIsInstance(lp, SubchannelListeningPort) + + self.assertEqual(lp.getHost(), hostaddr) + # TODO: IListeningPort says we must provide this, but I don't know + # that anyone would ever call it. + lp.startListening() t1 = mock.Mock() peeraddr1 = _SubchannelAddress(1) ep._got_open(t1, peeraddr1) - d = ep.listen(f) - lp = self.successResultOf(d) - self.assertIsInstance(lp, SubchannelListeningPort) - - self.assertEqual(lp.getHost(), hostaddr) - lp.startListening() - self.assertEqual(t1.mock_calls, [mock.call._set_protocol(p1)]) self.assertEqual(p1.mock_calls, [mock.call.makeConnection(t1)]) + self.assertEqual(f.buildProtocol.mock_calls, [mock.call(peeraddr1)]) t2 = mock.Mock() peeraddr2 = _SubchannelAddress(2) @@ -94,5 +264,92 @@ class Endpoints(unittest.TestCase): self.assertEqual(t2.mock_calls, [mock.call._set_protocol(p2)]) self.assertEqual(p2.mock_calls, [mock.call.makeConnection(t2)]) + self.assertEqual(f.buildProtocol.mock_calls, [mock.call(peeraddr1), + mock.call(peeraddr2)]) lp.stopListening() # TODO: should this do more? + + def test_early_fail(self): + # listen, main_channel_fail + m = mock_manager() + m.allocate_subchannel_id = mock.Mock(return_value=0) + hostaddr = _WormholeAddress() + eq = EventualQueue(Clock()) + ep = SubchannelListenerEndpoint(m, hostaddr, eq) + + f = mock.Mock() + p1 = mock.Mock() + p2 = mock.Mock() + f.buildProtocol = mock.Mock(side_effect=[p1, p2]) + + d = ep.listen(f) + eq.flush_sync() + self.assertNoResult(d) + + ep._main_channel_failed(Failure(CannotDilateError())) + eq.flush_sync() + self.failureResultOf(d).check(CannotDilateError) + self.assertEqual(f.buildProtocol.mock_calls, []) + + def test_late_succeed(self): + # main_channel_ready, got_open, listen, got_open + m = mock_manager() + m.allocate_subchannel_id = mock.Mock(return_value=0) + hostaddr = _WormholeAddress() + eq = EventualQueue(Clock()) + ep = SubchannelListenerEndpoint(m, hostaddr, eq) + ep._main_channel_ready() + + f = mock.Mock() + p1 = mock.Mock() + p2 = mock.Mock() + f.buildProtocol = mock.Mock(side_effect=[p1, p2]) + + t1 = mock.Mock() + peeraddr1 = _SubchannelAddress(1) + ep._got_open(t1, peeraddr1) + eq.flush_sync() + + self.assertEqual(t1.mock_calls, []) + self.assertEqual(p1.mock_calls, []) + + d = ep.listen(f) + eq.flush_sync() + lp = self.successResultOf(d) + self.assertIsInstance(lp, SubchannelListeningPort) + self.assertEqual(lp.getHost(), hostaddr) + lp.startListening() + + self.assertEqual(t1.mock_calls, [mock.call._set_protocol(p1)]) + self.assertEqual(p1.mock_calls, [mock.call.makeConnection(t1)]) + self.assertEqual(f.buildProtocol.mock_calls, [mock.call(peeraddr1)]) + + t2 = mock.Mock() + peeraddr2 = _SubchannelAddress(2) + ep._got_open(t2, peeraddr2) + + self.assertEqual(t2.mock_calls, [mock.call._set_protocol(p2)]) + self.assertEqual(p2.mock_calls, [mock.call.makeConnection(t2)]) + self.assertEqual(f.buildProtocol.mock_calls, [mock.call(peeraddr1), + mock.call(peeraddr2)]) + + lp.stopListening() # TODO: should this do more? + + def test_late_fail(self): + # main_channel_fail, listen + m = mock_manager() + m.allocate_subchannel_id = mock.Mock(return_value=0) + hostaddr = _WormholeAddress() + eq = EventualQueue(Clock()) + ep = SubchannelListenerEndpoint(m, hostaddr, eq) + ep._main_channel_failed(Failure(CannotDilateError())) + + f = mock.Mock() + p1 = mock.Mock() + p2 = mock.Mock() + f.buildProtocol = mock.Mock(side_effect=[p1, p2]) + + d = ep.listen(f) + eq.flush_sync() + self.failureResultOf(d).check(CannotDilateError) + self.assertEqual(f.buildProtocol.mock_calls, []) diff --git a/src/wormhole/test/dilate/test_full.py b/src/wormhole/test/dilate/test_full.py index cc0e8b0..79c0019 100644 --- a/src/wormhole/test/dilate/test_full.py +++ b/src/wormhole/test/dilate/test_full.py @@ -46,24 +46,21 @@ class Full(ServerBase, unittest.TestCase): yield doBoth(w1.get_verifier(), w2.get_verifier()) print("connected") - eps1_d = w1.dilate() - eps2_d = w2.dilate() - (eps1, eps2) = yield doBoth(eps1_d, eps2_d) - (control_ep1, connect_ep1, listen_ep1) = eps1 - (control_ep2, connect_ep2, listen_ep2) = eps2 + eps1 = w1.dilate() + eps2 = w2.dilate() print("w.dilate ready") f1 = Factory() f1.protocol = L f1.d = Deferred() f1.d.addCallback(lambda data: eq.fire_eventually(data)) - d1 = control_ep1.connect(f1) + d1 = eps1.control.connect(f1) f2 = Factory() f2.protocol = L f2.d = Deferred() f2.d.addCallback(lambda data: eq.fire_eventually(data)) - d2 = control_ep2.connect(f2) + d2 = eps2.control.connect(f2) yield d1 yield d2 print("control endpoints connected") @@ -125,14 +122,12 @@ class Reconnect(ServerBase, unittest.TestCase): w2.set_code(code) yield doBoth(w1.get_verifier(), w2.get_verifier()) - eps1_d = w1.dilate() - eps2_d = w2.dilate() - (eps1, eps2) = yield doBoth(eps1_d, eps2_d) - (control_ep1, connect_ep1, listen_ep1) = eps1 - (control_ep2, connect_ep2, listen_ep2) = eps2 + eps1 = w1.dilate() + eps2 = w2.dilate() + print("w.dilate ready") f1 = ReconF(eq); f2 = ReconF(eq) - d1 = control_ep1.connect(f1); d2 = control_ep2.connect(f2) + d1 = eps1.control.connect(f1); d2 = eps2.control.connect(f2) yield d1 yield d2 @@ -194,14 +189,12 @@ class Reconnect(ServerBase, unittest.TestCase): w2.set_code(code) yield doBoth(w1.get_verifier(), w2.get_verifier()) - eps1_d = w1.dilate() - eps2_d = w2.dilate() - (eps1, eps2) = yield doBoth(eps1_d, eps2_d) - (control_ep1, connect_ep1, listen_ep1) = eps1 - (control_ep2, connect_ep2, listen_ep2) = eps2 + eps1 = w1.dilate() + eps2 = w2.dilate() + print("w.dilate ready") f1 = ReconF(eq); f2 = ReconF(eq) - d1 = control_ep1.connect(f1); d2 = control_ep2.connect(f2) + d1 = eps1.control.connect(f1); d2 = eps2.control.connect(f2) yield d1 yield d2 @@ -287,19 +280,17 @@ class Endpoints(ServerBase, unittest.TestCase): w2.set_code(code) yield doBoth(w1.get_verifier(), w2.get_verifier()) - eps1_d = w1.dilate() - eps2_d = w2.dilate() - (eps1, eps2) = yield doBoth(eps1_d, eps2_d) - (control_ep1, connect_ep1, listen_ep1) = eps1 - (control_ep2, connect_ep2, listen_ep2) = eps2 + eps1 = w1.dilate() + eps2 = w2.dilate() + print("w.dilate ready") f0 = ReconF(eq) - yield listen_ep2.listen(f0) + yield eps2.listen.listen(f0) from twisted.python import log f1 = ReconF(eq) log.msg("connecting") - p1_client = yield connect_ep1.connect(f1) + p1_client = yield eps1.connect.connect(f1) log.msg("sending c->s") p1_client.transport.write(b"hello from p1\n") data = yield f0.deferreds["dataReceived"] @@ -316,7 +307,7 @@ class Endpoints(ServerBase, unittest.TestCase): f0.resetDeferred("dataReceived") f1.resetDeferred("dataReceived") f2 = ReconF(eq) - p2_client = yield connect_ep1.connect(f2) + p2_client = yield eps1.connect.connect(f2) p2_server = yield f0.deferreds["connectionMade"] p2_server.transport.write(b"hello p2\n") data = yield f2.deferreds["dataReceived"] diff --git a/src/wormhole/test/dilate/test_manager.py b/src/wormhole/test/dilate/test_manager.py index cea92de..309c1fd 100644 --- a/src/wormhole/test/dilate/test_manager.py +++ b/src/wormhole/test/dilate/test_manager.py @@ -140,25 +140,23 @@ class TestDilator(unittest.TestCase): def test_peer_cannot_dilate(self): dil, send, reactor, eq, clock, coop = make_dilator() - d1 = dil.dilate() - self.assertNoResult(d1) + eps = dil.dilate() - dil._transit_key = b"\x01" * 32 + dil.got_key(b"\x01" * 32) dil.got_wormhole_versions({}) # missing "can-dilate" + d = eps.connect.connect(None) eq.flush_sync() - f = self.failureResultOf(d1) - f.check(OldPeerCannotDilateError) + self.failureResultOf(d).check(OldPeerCannotDilateError) def test_disjoint_versions(self): dil, send, reactor, eq, clock, coop = make_dilator() - d1 = dil.dilate() - self.assertNoResult(d1) + eps = dil.dilate() - dil._transit_key = b"key" + dil.got_key(b"\x01" * 32) dil.got_wormhole_versions({"can-dilate": [-1]}) + d = eps.connect.connect(None) eq.flush_sync() - f = self.failureResultOf(d1) - f.check(OldPeerCannotDilateError) + self.failureResultOf(d).check(OldPeerCannotDilateError) def test_early_dilate_messages(self): dil, send, reactor, eq, clock, coop = make_dilator() @@ -276,11 +274,11 @@ def make_manager(leader=True): h.Inbound = mock.Mock(return_value=h.inbound) h.outbound = mock.Mock() h.Outbound = mock.Mock(return_value=h.outbound) - h.hostaddr = mock.Mock() - alsoProvides(h.hostaddr, IAddress) - with mock.patch("wormhole._dilation.manager.Inbound", h.Inbound): - with mock.patch("wormhole._dilation.manager.Outbound", h.Outbound): - m = Manager(h.send, side, h.key, h.relay, h.reactor, h.eq, h.coop, h.hostaddr) + with mock.patch("wormhole._dilation.manager.Inbound", h.Inbound), \ + mock.patch("wormhole._dilation.manager.Outbound", h.Outbound): + m = Manager(h.send, side, h.relay, h.reactor, h.eq, h.coop) + h.hostaddr = m._host_addr + m.got_dilation_key(h.key) return m, h