WIP: rewrite w.dilate API to return endpoints synchronously

test_manager still needs rewriting
This commit is contained in:
Brian Warner 2019-07-12 00:01:55 -07:00
parent 75fad02a28
commit 85cb003498
5 changed files with 500 additions and 197 deletions

View File

@ -2,13 +2,20 @@ from __future__ import print_function, unicode_literals
import six import six
import os import os
from collections import deque from collections import deque
try:
# py >= 3.3
from collections.abc import Sequence
except ImportError:
# py 2 and py3 < 3.3
from collections import Sequence
from attr import attrs, attrib from attr import attrs, attrib
from attr.validators import provides, instance_of, optional from attr.validators import provides, instance_of, optional
from automat import MethodicalMachine from automat import MethodicalMachine
from zope.interface import implementer from zope.interface import implementer
from twisted.internet.defer import Deferred, inlineCallbacks, returnValue from twisted.internet.defer import Deferred
from twisted.internet.interfaces import IAddress from twisted.internet.interfaces import (IStreamClientEndpoint,
from twisted.python import log IStreamServerEndpoint)
from twisted.python import log, failure
from .._interfaces import IDilator, IDilationManager, ISend, ITerminator from .._interfaces import IDilator, IDilationManager, ISend, ITerminator
from ..util import dict_to_bytes, bytes_to_dict, bytes_to_hexstr from ..util import dict_to_bytes, bytes_to_dict, bytes_to_hexstr
from ..observer import OneShotObserver from ..observer import OneShotObserver
@ -47,6 +54,15 @@ class UnexpectedKCM(Exception):
class UnknownMessageType(Exception): class UnknownMessageType(Exception):
pass pass
@attrs
class EndpointRecord(Sequence):
control = attrib(validator=provides(IStreamClientEndpoint))
connect = attrib(validator=provides(IStreamClientEndpoint))
listen = attrib(validator=provides(IStreamServerEndpoint))
def __len__(self):
return 3
def __getitem__(self, n):
return (self.control, self.connect, self.listen)[n]
def make_side(): def make_side():
return bytes_to_hexstr(os.urandom(6)) return bytes_to_hexstr(os.urandom(6))
@ -93,13 +109,14 @@ def make_side():
class Manager(object): class Manager(object):
_S = attrib(validator=provides(ISend), repr=False) _S = attrib(validator=provides(ISend), repr=False)
_my_side = attrib(validator=instance_of(type(u""))) _my_side = attrib(validator=instance_of(type(u"")))
_transit_key = attrib(validator=instance_of(bytes), repr=False)
_transit_relay_location = attrib(validator=optional(instance_of(str))) _transit_relay_location = attrib(validator=optional(instance_of(str)))
_reactor = attrib(repr=False) _reactor = attrib(repr=False)
_eventual_queue = attrib(repr=False) _eventual_queue = attrib(repr=False)
_cooperator = attrib(repr=False) _cooperator = attrib(repr=False)
_host_addr = attrib(validator=provides(IAddress)) # TODO: can this validator work when the parameter is optional?
_no_listen = attrib(default=False) _no_listen = attrib(validator=instance_of(bool), default=False)
_dilation_key = None
_tor = None # TODO _tor = None # TODO
_timing = None # TODO _timing = None # TODO
_next_subchannel_id = None # initialized in choose_role _next_subchannel_id = None # initialized in choose_role
@ -111,10 +128,10 @@ class Manager(object):
self._got_versions_d = Deferred() self._got_versions_d = Deferred()
self._my_role = None # determined upon rx_PLEASE self._my_role = None # determined upon rx_PLEASE
self._host_addr = _WormholeAddress()
self._connection = None self._connection = None
self._made_first_connection = False self._made_first_connection = False
self._first_connected = OneShotObserver(self._eventual_queue)
self._stopped = OneShotObserver(self._eventual_queue) self._stopped = OneShotObserver(self._eventual_queue)
self._debug_stall_connector = False self._debug_stall_connector = False
@ -127,18 +144,81 @@ class Manager(object):
self._inbound = Inbound(self, self._host_addr) self._inbound = Inbound(self, self._host_addr)
self._outbound = Outbound(self, self._cooperator) # from us to peer self._outbound = Outbound(self, self._cooperator) # from us to peer
def set_listener_endpoint(self, listener_endpoint): # We must open subchannel0 early, since messages may arrive very
self._inbound.set_listener_endpoint(listener_endpoint) # quickly once the connection is established. This subchannel may or
# may not ever get revealed to the caller, since the peer might not
def set_subchannel_zero(self, scid0, sc0): # even be capable of dilation.
scid0 = 0
peer_addr0 = _SubchannelAddress(scid0)
sc0 = SubChannel(scid0, self, self._host_addr, peer_addr0)
self._inbound.set_subchannel_zero(scid0, sc0) self._inbound.set_subchannel_zero(scid0, sc0)
def when_first_connected(self): # we can open non-zero subchannels as soon as we get our first
return self._first_connected.when_fired() # connection, and we can make the Endpoints even earlier
control_ep = ControlEndpoint(peer_addr0, sc0, self._eventual_queue)
connect_ep = SubchannelConnectorEndpoint(self, self._host_addr, self._eventual_queue)
listen_ep = SubchannelListenerEndpoint(self, self._host_addr, self._eventual_queue)
# TODO: let inbound/outbound create the endpoints, then return them
# to us
self._inbound.set_listener_endpoint(listen_ep)
self._endpoints = EndpointRecord(control_ep, connect_ep, listen_ep)
def get_endpoints(self):
return self._endpoints
def got_dilation_key(self, key):
assert isinstance(key, bytes)
self._dilation_key = key
def got_wormhole_versions(self, their_wormhole_versions):
# this always happens before received_dilation_message
dilation_version = None
their_dilation_versions = set(their_wormhole_versions.get("can-dilate", []))
my_versions = set(DILATION_VERSIONS)
shared_versions = my_versions.intersection(their_dilation_versions)
if "1" in shared_versions:
dilation_version = "1"
# dilation_version is the best mutually-compatible version we have
# with the peer, or None if we have nothing in common
if not dilation_version: # "1" or None
# TODO: be more specific about the error. dilation_version==None
# means we had no version in common with them, which could either
# be because they're so old they don't dilate at all, or because
# they're so new that they no longer accomodate our old version
self.fail(failure.Failure(OldPeerCannotDilateError()))
self.start()
def fail(self, f):
self._endpoints.control._main_channel_failed(f)
self._endpoints.connect._main_channel_failed(f)
self._endpoints.listen._main_channel_failed(f)
def received_dilation_message(self, plaintext):
# this receives new in-order DILATE-n payloads, decrypted but not
# de-JSONed.
message = bytes_to_dict(plaintext)
type = message["type"]
if type == "please":
self.rx_PLEASE(message)
elif type == "connection-hints":
self.rx_HINTS(message)
elif type == "reconnect":
self.rx_RECONNECT()
elif type == "reconnecting":
self.rx_RECONNECTING()
else:
log.err(UnknownDilationMessageType(message))
return
def when_stopped(self): def when_stopped(self):
return self._stopped.when_fired() return self._stopped.when_fired()
def send_dilation_phase(self, **fields): def send_dilation_phase(self, **fields):
dilation_phase = self._next_dilation_phase dilation_phase = self._next_dilation_phase
self._next_dilation_phase += 1 self._next_dilation_phase += 1
@ -204,7 +284,9 @@ class Manager(object):
self._outbound.use_connection(c) # does c.registerProducer self._outbound.use_connection(c) # does c.registerProducer
if not self._made_first_connection: if not self._made_first_connection:
self._made_first_connection = True self._made_first_connection = True
self._first_connected.fire(None) self._endpoints.control._main_channel_ready()
self._endpoints.connect._main_channel_ready()
self._endpoints.listen._main_channel_ready()
pass pass
def connector_connection_lost(self): def connector_connection_lost(self):
@ -272,16 +354,11 @@ class Manager(object):
# state machine # state machine
# We are born WANTING after the local app calls w.dilate(). We start
# CONNECTING when we receive PLEASE from the remote side
def start(self):
self.send_please()
def send_please(self):
self.send_dilation_phase(type="please", side=self._my_side)
@m.state(initial=True) @m.state(initial=True)
def WAITING(self):
pass # pragma: no cover
@m.state()
def WANTING(self): def WANTING(self):
pass # pragma: no cover pass # pragma: no cover
@ -313,6 +390,10 @@ class Manager(object):
def STOPPED(self): def STOPPED(self):
pass # pragma: no cover pass # pragma: no cover
@m.input()
def start(self):
pass # pragma: no cover
@m.input() @m.input()
def rx_PLEASE(self, message): def rx_PLEASE(self, message):
pass # pragma: no cover pass # pragma: no cover
@ -350,6 +431,10 @@ class Manager(object):
def stop(self): def stop(self):
pass # pragma: no cover pass # pragma: no cover
@m.output()
def send_please(self):
self.send_dilation_phase(type="please", side=self._my_side)
@m.output() @m.output()
def choose_role(self, message): def choose_role(self, message):
their_side = message["side"] their_side = message["side"]
@ -378,7 +463,8 @@ class Manager(object):
def _start_connecting(self): def _start_connecting(self):
assert self._my_role is not None assert self._my_role is not None
self._connector = Connector(self._transit_key, assert self._dilation_key is not None
self._connector = Connector(self._dilation_key,
self._transit_relay_location, self._transit_relay_location,
self, self,
self._reactor, self._eventual_queue, self._reactor, self._eventual_queue,
@ -422,6 +508,11 @@ class Manager(object):
def notify_stopped(self): def notify_stopped(self):
self._stopped.fire(None) self._stopped.fire(None)
# We are born WAITING after the local app calls w.dilate(). We enter
# WANTING (and send a PLEASE) when we learn of a mutually-compatible
# dilation_version.
WAITING.upon(start, enter=WANTING, outputs=[send_please])
# we start CONNECTING when we get rx_PLEASE # we start CONNECTING when we get rx_PLEASE
WANTING.upon(rx_PLEASE, enter=CONNECTING, WANTING.upon(rx_PLEASE, enter=CONNECTING,
outputs=[choose_role, start_connecting_ignore_message]) outputs=[choose_role, start_connecting_ignore_message])
@ -489,12 +580,10 @@ class Dilator(object):
_cooperator = attrib() _cooperator = attrib()
def __attrs_post_init__(self): def __attrs_post_init__(self):
self._got_versions_d = Deferred()
self._started = False
self._endpoints = OneShotObserver(self._eventual_queue)
self._pending_inbound_dilate_messages = deque()
self._manager = None self._manager = None
self._host_addr = _WormholeAddress() self._pending_dilation_key = None
self._pending_wormhole_versions = None
self._pending_inbound_dilate_messages = deque()
def wire(self, sender, terminator): def wire(self, sender, terminator):
self._S = ISend(sender) self._S = ISend(sender)
@ -502,77 +591,35 @@ class Dilator(object):
# this is the primary entry point, called when w.dilate() is invoked # this is the primary entry point, called when w.dilate() is invoked
def dilate(self, transit_relay_location=None, no_listen=False): def dilate(self, transit_relay_location=None, no_listen=False):
self._transit_relay_location = transit_relay_location if not self._manager:
if not self._started: # build the manager right away, and tell it later when the
self._started = True # VERSIONS message arrives, and also when the dilation_key is set
self._start(no_listen).addBoth(self._endpoints.fire)
return self._endpoints.when_fired()
@inlineCallbacks
def _start(self, no_listen):
# first, we wait until we hear the VERSION message, which tells us 1:
# the PAKE key works, so we can talk securely, 2: that they can do
# dilation at all (if they can't then w.dilate() errbacks)
dilation_version = yield self._got_versions_d
# TODO: we could probably return the endpoints earlier, if we flunk
# any connection/listen attempts upon OldPeerCannotDilateError, or
# if/when we give up on the initial connection
if not dilation_version: # "1" or None
# TODO: be more specific about the error. dilation_version==None
# means we had no version in common with them, which could either
# be because they're so old they don't dilate at all, or because
# they're so new that they no longer accomodate our old version
raise OldPeerCannotDilateError()
my_dilation_side = make_side() my_dilation_side = make_side()
self._manager = Manager(self._S, my_dilation_side, m = Manager(self._S, my_dilation_side,
self._transit_key, transit_relay_location,
self._transit_relay_location,
self._reactor, self._eventual_queue, self._reactor, self._eventual_queue,
self._cooperator, self._host_addr, no_listen) self._cooperator, no_listen)
# We must open subchannel0 early, since messages may arrive very self._manager = m
# quickly once the connection is established. This subchannel may or if self._pending_dilation_key is not None:
# may not ever get revealed to the caller, since the peer might not m.got_dilation_key(self._pending_dilation_key)
# even be capable of dilation. if self._pending_wormhole_versions:
scid0 = 0 m.got_wormhole_versions(self._pending_wormhole_versions)
peer_addr0 = _SubchannelAddress(scid0)
sc0 = SubChannel(scid0, self._manager, self._host_addr, peer_addr0)
self._manager.set_subchannel_zero(scid0, sc0)
# we can open non-zero subchannels as soon as we get our first
# connection, and we can make the Endpoints even earlier
control_ep = ControlEndpoint(peer_addr0)
control_ep._subchannel_zero_opened(sc0)
connect_ep = SubchannelConnectorEndpoint(self._manager, self._host_addr)
listen_ep = SubchannelListenerEndpoint(self._manager, self._host_addr)
self._manager.set_listener_endpoint(listen_ep)
self._manager.start()
while self._pending_inbound_dilate_messages: while self._pending_inbound_dilate_messages:
plaintext = self._pending_inbound_dilate_messages.popleft() plaintext = self._pending_inbound_dilate_messages.popleft()
self.received_dilate(plaintext) m.received_dilation_message(plaintext)
return self._manager.get_endpoints()
yield self._manager.when_first_connected()
endpoints = (control_ep, connect_ep, listen_ep)
returnValue(endpoints)
# Called by Terminator after everything else (mailbox, nameplate, server # Called by Terminator after everything else (mailbox, nameplate, server
# connection) has shut down. Expects to fire T.stoppedD() when Dilator is # connection) has shut down. Expects to fire T.stoppedD() when Dilator is
# stopped too. # stopped too.
def stop(self): def stop(self):
if not self._started: if self._manager:
self._T.stoppedD()
return
if self._started:
self._manager.stop() self._manager.stop()
# TODO: avoid Deferreds for control flow, hard to serialize # TODO: avoid Deferreds for control flow, hard to serialize
self._manager.when_stopped().addCallback(lambda _: self._T.stoppedD()) self._manager.when_stopped().addCallback(lambda _: self._T.stoppedD())
else:
self._T.stoppedD()
return
# TODO: tolerate multiple calls # TODO: tolerate multiple calls
# from Boss # from Boss
@ -582,39 +629,20 @@ class Dilator(object):
# to tolerate either ordering # to tolerate either ordering
purpose = b"dilation-v1" purpose = b"dilation-v1"
LENGTH = 32 # TODO: whatever Noise wants, I guess LENGTH = 32 # TODO: whatever Noise wants, I guess
self._transit_key = derive_key(key, purpose, LENGTH) dilation_key = derive_key(key, purpose, LENGTH)
if self._manager:
self._manager.got_dilation_key(dilation_key)
else:
self._pending_dilation_key = dilation_key
def got_wormhole_versions(self, their_wormhole_versions): def got_wormhole_versions(self, their_wormhole_versions):
assert self._transit_key is not None if self._manager:
# this always happens before received_dilate self._manager.got_wormhole_versions(their_wormhole_versions)
dilation_version = None else:
their_dilation_versions = set(their_wormhole_versions.get("can-dilate", [])) self._pending_wormhole_versions = their_wormhole_versions
my_versions = set(DILATION_VERSIONS)
shared_versions = my_versions.intersection(their_dilation_versions)
if "1" in shared_versions:
dilation_version = "1"
self._got_versions_d.callback(dilation_version)
def received_dilate(self, plaintext): def received_dilate(self, plaintext):
# this receives new in-order DILATE-n payloads, decrypted but not
# de-JSONed.
# this can appear before our .dilate() method is called, in which case
# we queue them for later
if not self._manager: if not self._manager:
self._pending_inbound_dilate_messages.append(plaintext) self._pending_inbound_dilate_messages.append(plaintext)
return
message = bytes_to_dict(plaintext)
type = message["type"]
if type == "please":
self._manager.rx_PLEASE(message)
elif type == "connection-hints":
self._manager.rx_HINTS(message)
elif type == "reconnect":
self._manager.rx_RECONNECT()
elif type == "reconnecting":
self._manager.rx_RECONNECTING()
else: else:
log.err(UnknownDilationMessageType(message)) self._manager.received_dilation_message(plaintext)
return

View File

@ -3,8 +3,7 @@ from collections import deque
from attr import attrs, attrib from attr import attrs, attrib
from attr.validators import instance_of, provides from attr.validators import instance_of, provides
from zope.interface import implementer from zope.interface import implementer
from twisted.internet.defer import (Deferred, inlineCallbacks, returnValue, from twisted.internet.defer import inlineCallbacks, returnValue
succeed)
from twisted.internet.interfaces import (ITransport, IProducer, IConsumer, from twisted.internet.interfaces import (ITransport, IProducer, IConsumer,
IAddress, IListeningPort, IAddress, IListeningPort,
IStreamClientEndpoint, IStreamClientEndpoint,
@ -12,6 +11,7 @@ from twisted.internet.interfaces import (ITransport, IProducer, IConsumer,
from twisted.internet.error import ConnectionDone from twisted.internet.error import ConnectionDone
from automat import MethodicalMachine from automat import MethodicalMachine
from .._interfaces import ISubChannel, IDilationManager from .._interfaces import ISubChannel, IDilationManager
from ..observer import OneShotObserver
# each subchannel frame (the data passed into transport.write(data)) gets a # each subchannel frame (the data passed into transport.write(data)) gets a
# 9-byte header prefix (type, subchannel id, and sequence number), then gets # 9-byte header prefix (type, subchannel id, and sequence number), then gets
@ -217,27 +217,33 @@ class SubChannel(object):
@implementer(IStreamClientEndpoint) @implementer(IStreamClientEndpoint)
@attrs
class ControlEndpoint(object): class ControlEndpoint(object):
_peer_addr = attrib(validator=provides(IAddress))
_subchannel_zero = attrib(validator=provides(ISubChannel))
_eventual_queue = attrib(repr=False)
_used = False _used = False
def __init__(self, peer_addr): def __attrs_post_init__(self):
self._subchannel_zero = Deferred()
self._peer_addr = peer_addr
self._once = Once(SingleUseEndpointError) self._once = Once(SingleUseEndpointError)
self._wait_for_main_channel = OneShotObserver(self._eventual_queue)
# from manager # from manager
def _subchannel_zero_opened(self, subchannel):
assert ISubChannel.providedBy(subchannel), subchannel def _main_channel_ready(self):
self._subchannel_zero.callback(subchannel) self._wait_for_main_channel.fire(None)
def _main_channel_failed(self, f):
self._wait_for_main_channel.error(f)
@inlineCallbacks @inlineCallbacks
def connect(self, protocolFactory): def connect(self, protocolFactory):
# return Deferred that fires with IProtocol or Failure(ConnectError) # return Deferred that fires with IProtocol or Failure(ConnectError)
self._once() self._once()
t = yield self._subchannel_zero yield self._wait_for_main_channel.when_fired()
p = protocolFactory.buildProtocol(self._peer_addr) p = protocolFactory.buildProtocol(self._peer_addr)
t._set_protocol(p) self._subchannel_zero._set_protocol(p)
p.makeConnection(t) # set p.transport = t and call connectionMade() # this sets p.transport and calls p.connectionMade()
p.makeConnection(self._subchannel_zero)
returnValue(p) returnValue(p)
@ -246,9 +252,21 @@ class ControlEndpoint(object):
class SubchannelConnectorEndpoint(object): class SubchannelConnectorEndpoint(object):
_manager = attrib(validator=provides(IDilationManager)) _manager = attrib(validator=provides(IDilationManager))
_host_addr = attrib(validator=instance_of(_WormholeAddress)) _host_addr = attrib(validator=instance_of(_WormholeAddress))
_eventual_queue = attrib(repr=False)
def __attrs_post_init__(self):
self._connection_deferreds = deque()
self._wait_for_main_channel = OneShotObserver(self._eventual_queue)
def _main_channel_ready(self):
self._wait_for_main_channel.fire(None)
def _main_channel_failed(self, f):
self._wait_for_main_channel.error(f)
@inlineCallbacks
def connect(self, protocolFactory): def connect(self, protocolFactory):
# return Deferred that fires with IProtocol or Failure(ConnectError) # return Deferred that fires with IProtocol or Failure(ConnectError)
yield self._wait_for_main_channel.when_fired()
scid = self._manager.allocate_subchannel_id() scid = self._manager.allocate_subchannel_id()
self._manager.send_open(scid) self._manager.send_open(scid)
peer_addr = _SubchannelAddress(scid) peer_addr = _SubchannelAddress(scid)
@ -259,7 +277,7 @@ class SubchannelConnectorEndpoint(object):
p = protocolFactory.buildProtocol(peer_addr) p = protocolFactory.buildProtocol(peer_addr)
sc._set_protocol(p) sc._set_protocol(p)
p.makeConnection(sc) # set p.transport = sc and call connectionMade() p.makeConnection(sc) # set p.transport = sc and call connectionMade()
return succeed(p) returnValue(p)
@implementer(IStreamServerEndpoint) @implementer(IStreamServerEndpoint)
@ -267,10 +285,13 @@ class SubchannelConnectorEndpoint(object):
class SubchannelListenerEndpoint(object): class SubchannelListenerEndpoint(object):
_manager = attrib(validator=provides(IDilationManager)) _manager = attrib(validator=provides(IDilationManager))
_host_addr = attrib(validator=provides(IAddress)) _host_addr = attrib(validator=provides(IAddress))
_eventual_queue = attrib(repr=False)
def __attrs_post_init__(self): def __attrs_post_init__(self):
self._once = Once(SingleUseEndpointError)
self._factory = None self._factory = None
self._pending_opens = deque() self._pending_opens = deque()
self._wait_for_main_channel = OneShotObserver(self._eventual_queue)
# from manager (actually Inbound) # from manager (actually Inbound)
def _got_open(self, t, peer_addr): def _got_open(self, t, peer_addr):
@ -284,15 +305,23 @@ class SubchannelListenerEndpoint(object):
t._set_protocol(p) t._set_protocol(p)
p.makeConnection(t) p.makeConnection(t)
def _main_channel_ready(self):
self._wait_for_main_channel.fire(None)
def _main_channel_failed(self, f):
self._wait_for_main_channel.error(f)
# IStreamServerEndpoint # IStreamServerEndpoint
@inlineCallbacks
def listen(self, protocolFactory): def listen(self, protocolFactory):
self._once()
yield self._wait_for_main_channel.when_fired()
self._factory = protocolFactory self._factory = protocolFactory
while self._pending_opens: while self._pending_opens:
(t, peer_addr) = self._pending_opens.popleft() (t, peer_addr) = self._pending_opens.popleft()
self._connect(t, peer_addr) self._connect(t, peer_addr)
lp = SubchannelListeningPort(self._host_addr) lp = SubchannelListeningPort(self._host_addr)
return succeed(lp) returnValue(lp)
@implementer(IListeningPort) @implementer(IListeningPort)

View File

@ -2,7 +2,10 @@ from __future__ import print_function, unicode_literals
import mock import mock
from zope.interface import alsoProvides from zope.interface import alsoProvides
from twisted.trial import unittest from twisted.trial import unittest
from twisted.internet.task import Clock
from twisted.python.failure import Failure
from ..._interfaces import ISubChannel from ..._interfaces import ISubChannel
from ...eventual import EventualQueue
from ..._dilation.subchannel import (ControlEndpoint, from ..._dilation.subchannel import (ControlEndpoint,
SubchannelConnectorEndpoint, SubchannelConnectorEndpoint,
SubchannelListenerEndpoint, SubchannelListenerEndpoint,
@ -11,12 +14,18 @@ from ..._dilation.subchannel import (ControlEndpoint,
SingleUseEndpointError) SingleUseEndpointError)
from .common import mock_manager from .common import mock_manager
class CannotDilateError(Exception):
pass
class Endpoints(unittest.TestCase): class Control(unittest.TestCase):
def test_control(self): def test_early_succeed(self):
# ep.connect() is called before dilation can proceed
scid0 = 0 scid0 = 0
peeraddr = _SubchannelAddress(scid0) peeraddr = _SubchannelAddress(scid0)
ep = ControlEndpoint(peeraddr) sc0 = mock.Mock()
alsoProvides(sc0, ISubChannel)
eq = EventualQueue(Clock())
ep = ControlEndpoint(peeraddr, sc0, eq)
f = mock.Mock() f = mock.Mock()
p = mock.Mock() p = mock.Mock()
@ -24,29 +33,105 @@ class Endpoints(unittest.TestCase):
d = ep.connect(f) d = ep.connect(f)
self.assertNoResult(d) self.assertNoResult(d)
t = mock.Mock() ep._main_channel_ready()
alsoProvides(t, ISubChannel) eq.flush_sync()
ep._subchannel_zero_opened(t)
self.assertIdentical(self.successResultOf(d), p) self.assertIdentical(self.successResultOf(d), p)
self.assertEqual(f.buildProtocol.mock_calls, [mock.call(peeraddr)]) self.assertEqual(f.buildProtocol.mock_calls, [mock.call(peeraddr)])
self.assertEqual(t.mock_calls, [mock.call._set_protocol(p)]) self.assertEqual(sc0.mock_calls, [mock.call._set_protocol(p)])
self.assertEqual(p.mock_calls, [mock.call.makeConnection(t)]) self.assertEqual(p.mock_calls, [mock.call.makeConnection(sc0)])
d = ep.connect(f) d = ep.connect(f)
self.failureResultOf(d, SingleUseEndpointError) self.failureResultOf(d, SingleUseEndpointError)
def assert_makeConnection(self, mock_calls): def test_early_fail(self):
# ep.connect() is called before dilation is abandoned
scid0 = 0
peeraddr = _SubchannelAddress(scid0)
sc0 = mock.Mock()
alsoProvides(sc0, ISubChannel)
eq = EventualQueue(Clock())
ep = ControlEndpoint(peeraddr, sc0, eq)
f = mock.Mock()
p = mock.Mock()
f.buildProtocol = mock.Mock(return_value=p)
d = ep.connect(f)
self.assertNoResult(d)
ep._main_channel_failed(Failure(CannotDilateError()))
eq.flush_sync()
self.failureResultOf(d).check(CannotDilateError)
self.assertEqual(f.buildProtocol.mock_calls, [])
self.assertEqual(sc0.mock_calls, [])
d = ep.connect(f)
self.failureResultOf(d, SingleUseEndpointError)
def test_late_succeed(self):
# dilation can proceed, then ep.connect() is called
scid0 = 0
peeraddr = _SubchannelAddress(scid0)
sc0 = mock.Mock()
alsoProvides(sc0, ISubChannel)
eq = EventualQueue(Clock())
ep = ControlEndpoint(peeraddr, sc0, eq)
ep._main_channel_ready()
f = mock.Mock()
p = mock.Mock()
f.buildProtocol = mock.Mock(return_value=p)
d = ep.connect(f)
eq.flush_sync()
self.assertIdentical(self.successResultOf(d), p)
self.assertEqual(f.buildProtocol.mock_calls, [mock.call(peeraddr)])
self.assertEqual(sc0.mock_calls, [mock.call._set_protocol(p)])
self.assertEqual(p.mock_calls, [mock.call.makeConnection(sc0)])
d = ep.connect(f)
self.failureResultOf(d, SingleUseEndpointError)
def test_late_fail(self):
# dilation is abandoned, then ep.connect() is called
scid0 = 0
peeraddr = _SubchannelAddress(scid0)
sc0 = mock.Mock()
alsoProvides(sc0, ISubChannel)
eq = EventualQueue(Clock())
ep = ControlEndpoint(peeraddr, sc0, eq)
ep._main_channel_failed(Failure(CannotDilateError()))
f = mock.Mock()
p = mock.Mock()
f.buildProtocol = mock.Mock(return_value=p)
d = ep.connect(f)
eq.flush_sync()
self.failureResultOf(d).check(CannotDilateError)
self.assertEqual(f.buildProtocol.mock_calls, [])
self.assertEqual(sc0.mock_calls, [])
d = ep.connect(f)
self.failureResultOf(d, SingleUseEndpointError)
class Endpoints(unittest.TestCase):
def OFFassert_makeConnection(self, mock_calls):
self.assertEqual(len(mock_calls), 1) self.assertEqual(len(mock_calls), 1)
self.assertEqual(mock_calls[0][0], "makeConnection") self.assertEqual(mock_calls[0][0], "makeConnection")
self.assertEqual(len(mock_calls[0][1]), 1) self.assertEqual(len(mock_calls[0][1]), 1)
return mock_calls[0][1][0] return mock_calls[0][1][0]
def test_connector(self): class Connector(unittest.TestCase):
def test_early_succeed(self):
m = mock_manager() m = mock_manager()
m.allocate_subchannel_id = mock.Mock(return_value=0) m.allocate_subchannel_id = mock.Mock(return_value=0)
hostaddr = _WormholeAddress() hostaddr = _WormholeAddress()
peeraddr = _SubchannelAddress(0) peeraddr = _SubchannelAddress(0)
ep = SubchannelConnectorEndpoint(m, hostaddr) eq = EventualQueue(Clock())
ep = SubchannelConnectorEndpoint(m, hostaddr, eq)
f = mock.Mock() f = mock.Mock()
p = mock.Mock() p = mock.Mock()
@ -55,38 +140,123 @@ class Endpoints(unittest.TestCase):
with mock.patch("wormhole._dilation.subchannel.SubChannel", with mock.patch("wormhole._dilation.subchannel.SubChannel",
return_value=t) as sc: return_value=t) as sc:
d = ep.connect(f) d = ep.connect(f)
eq.flush_sync()
self.assertNoResult(d)
ep._main_channel_ready()
eq.flush_sync()
self.assertIdentical(self.successResultOf(d), p) self.assertIdentical(self.successResultOf(d), p)
self.assertEqual(f.buildProtocol.mock_calls, [mock.call(peeraddr)]) self.assertEqual(f.buildProtocol.mock_calls, [mock.call(peeraddr)])
self.assertEqual(sc.mock_calls, [mock.call(0, m, hostaddr, peeraddr)]) self.assertEqual(sc.mock_calls, [mock.call(0, m, hostaddr, peeraddr)])
self.assertEqual(t.mock_calls, [mock.call._set_protocol(p)]) self.assertEqual(t.mock_calls, [mock.call._set_protocol(p)])
self.assertEqual(p.mock_calls, [mock.call.makeConnection(t)]) self.assertEqual(p.mock_calls, [mock.call.makeConnection(t)])
def test_listener(self): def test_early_fail(self):
m = mock_manager() m = mock_manager()
m.allocate_subchannel_id = mock.Mock(return_value=0) m.allocate_subchannel_id = mock.Mock(return_value=0)
hostaddr = _WormholeAddress() hostaddr = _WormholeAddress()
ep = SubchannelListenerEndpoint(m, hostaddr) eq = EventualQueue(Clock())
ep = SubchannelConnectorEndpoint(m, hostaddr, eq)
f = mock.Mock()
p = mock.Mock()
t = mock.Mock()
f.buildProtocol = mock.Mock(return_value=p)
with mock.patch("wormhole._dilation.subchannel.SubChannel",
return_value=t) as sc:
d = ep.connect(f)
eq.flush_sync()
self.assertNoResult(d)
ep._main_channel_failed(Failure(CannotDilateError()))
eq.flush_sync()
self.failureResultOf(d).check(CannotDilateError)
self.assertEqual(f.buildProtocol.mock_calls, [])
self.assertEqual(sc.mock_calls, [])
self.assertEqual(t.mock_calls, [])
def test_late_succeed(self):
m = mock_manager()
m.allocate_subchannel_id = mock.Mock(return_value=0)
hostaddr = _WormholeAddress()
peeraddr = _SubchannelAddress(0)
eq = EventualQueue(Clock())
ep = SubchannelConnectorEndpoint(m, hostaddr, eq)
ep._main_channel_ready()
f = mock.Mock()
p = mock.Mock()
t = mock.Mock()
f.buildProtocol = mock.Mock(return_value=p)
with mock.patch("wormhole._dilation.subchannel.SubChannel",
return_value=t) as sc:
d = ep.connect(f)
eq.flush_sync()
self.assertIdentical(self.successResultOf(d), p)
self.assertEqual(f.buildProtocol.mock_calls, [mock.call(peeraddr)])
self.assertEqual(sc.mock_calls, [mock.call(0, m, hostaddr, peeraddr)])
self.assertEqual(t.mock_calls, [mock.call._set_protocol(p)])
self.assertEqual(p.mock_calls, [mock.call.makeConnection(t)])
def test_late_fail(self):
m = mock_manager()
m.allocate_subchannel_id = mock.Mock(return_value=0)
hostaddr = _WormholeAddress()
eq = EventualQueue(Clock())
ep = SubchannelConnectorEndpoint(m, hostaddr, eq)
ep._main_channel_failed(Failure(CannotDilateError()))
f = mock.Mock()
p = mock.Mock()
t = mock.Mock()
f.buildProtocol = mock.Mock(return_value=p)
with mock.patch("wormhole._dilation.subchannel.SubChannel",
return_value=t) as sc:
d = ep.connect(f)
eq.flush_sync()
self.failureResultOf(d).check(CannotDilateError)
self.assertEqual(f.buildProtocol.mock_calls, [])
self.assertEqual(sc.mock_calls, [])
self.assertEqual(t.mock_calls, [])
class Listener(unittest.TestCase):
def test_early_succeed(self):
# listen, main_channel_ready, got_open, got_open
m = mock_manager()
m.allocate_subchannel_id = mock.Mock(return_value=0)
hostaddr = _WormholeAddress()
eq = EventualQueue(Clock())
ep = SubchannelListenerEndpoint(m, hostaddr, eq)
f = mock.Mock() f = mock.Mock()
p1 = mock.Mock() p1 = mock.Mock()
p2 = mock.Mock() p2 = mock.Mock()
f.buildProtocol = mock.Mock(side_effect=[p1, p2]) f.buildProtocol = mock.Mock(side_effect=[p1, p2])
# OPEN that arrives before we ep.listen() should be queued d = ep.listen(f)
eq.flush_sync()
self.assertNoResult(d)
self.assertEqual(f.buildProtocol.mock_calls, [])
ep._main_channel_ready()
eq.flush_sync()
lp = self.successResultOf(d)
self.assertIsInstance(lp, SubchannelListeningPort)
self.assertEqual(lp.getHost(), hostaddr)
# TODO: IListeningPort says we must provide this, but I don't know
# that anyone would ever call it.
lp.startListening()
t1 = mock.Mock() t1 = mock.Mock()
peeraddr1 = _SubchannelAddress(1) peeraddr1 = _SubchannelAddress(1)
ep._got_open(t1, peeraddr1) ep._got_open(t1, peeraddr1)
d = ep.listen(f)
lp = self.successResultOf(d)
self.assertIsInstance(lp, SubchannelListeningPort)
self.assertEqual(lp.getHost(), hostaddr)
lp.startListening()
self.assertEqual(t1.mock_calls, [mock.call._set_protocol(p1)]) self.assertEqual(t1.mock_calls, [mock.call._set_protocol(p1)])
self.assertEqual(p1.mock_calls, [mock.call.makeConnection(t1)]) self.assertEqual(p1.mock_calls, [mock.call.makeConnection(t1)])
self.assertEqual(f.buildProtocol.mock_calls, [mock.call(peeraddr1)])
t2 = mock.Mock() t2 = mock.Mock()
peeraddr2 = _SubchannelAddress(2) peeraddr2 = _SubchannelAddress(2)
@ -94,5 +264,92 @@ class Endpoints(unittest.TestCase):
self.assertEqual(t2.mock_calls, [mock.call._set_protocol(p2)]) self.assertEqual(t2.mock_calls, [mock.call._set_protocol(p2)])
self.assertEqual(p2.mock_calls, [mock.call.makeConnection(t2)]) self.assertEqual(p2.mock_calls, [mock.call.makeConnection(t2)])
self.assertEqual(f.buildProtocol.mock_calls, [mock.call(peeraddr1),
mock.call(peeraddr2)])
lp.stopListening() # TODO: should this do more? lp.stopListening() # TODO: should this do more?
def test_early_fail(self):
# listen, main_channel_fail
m = mock_manager()
m.allocate_subchannel_id = mock.Mock(return_value=0)
hostaddr = _WormholeAddress()
eq = EventualQueue(Clock())
ep = SubchannelListenerEndpoint(m, hostaddr, eq)
f = mock.Mock()
p1 = mock.Mock()
p2 = mock.Mock()
f.buildProtocol = mock.Mock(side_effect=[p1, p2])
d = ep.listen(f)
eq.flush_sync()
self.assertNoResult(d)
ep._main_channel_failed(Failure(CannotDilateError()))
eq.flush_sync()
self.failureResultOf(d).check(CannotDilateError)
self.assertEqual(f.buildProtocol.mock_calls, [])
def test_late_succeed(self):
# main_channel_ready, got_open, listen, got_open
m = mock_manager()
m.allocate_subchannel_id = mock.Mock(return_value=0)
hostaddr = _WormholeAddress()
eq = EventualQueue(Clock())
ep = SubchannelListenerEndpoint(m, hostaddr, eq)
ep._main_channel_ready()
f = mock.Mock()
p1 = mock.Mock()
p2 = mock.Mock()
f.buildProtocol = mock.Mock(side_effect=[p1, p2])
t1 = mock.Mock()
peeraddr1 = _SubchannelAddress(1)
ep._got_open(t1, peeraddr1)
eq.flush_sync()
self.assertEqual(t1.mock_calls, [])
self.assertEqual(p1.mock_calls, [])
d = ep.listen(f)
eq.flush_sync()
lp = self.successResultOf(d)
self.assertIsInstance(lp, SubchannelListeningPort)
self.assertEqual(lp.getHost(), hostaddr)
lp.startListening()
self.assertEqual(t1.mock_calls, [mock.call._set_protocol(p1)])
self.assertEqual(p1.mock_calls, [mock.call.makeConnection(t1)])
self.assertEqual(f.buildProtocol.mock_calls, [mock.call(peeraddr1)])
t2 = mock.Mock()
peeraddr2 = _SubchannelAddress(2)
ep._got_open(t2, peeraddr2)
self.assertEqual(t2.mock_calls, [mock.call._set_protocol(p2)])
self.assertEqual(p2.mock_calls, [mock.call.makeConnection(t2)])
self.assertEqual(f.buildProtocol.mock_calls, [mock.call(peeraddr1),
mock.call(peeraddr2)])
lp.stopListening() # TODO: should this do more?
def test_late_fail(self):
# main_channel_fail, listen
m = mock_manager()
m.allocate_subchannel_id = mock.Mock(return_value=0)
hostaddr = _WormholeAddress()
eq = EventualQueue(Clock())
ep = SubchannelListenerEndpoint(m, hostaddr, eq)
ep._main_channel_failed(Failure(CannotDilateError()))
f = mock.Mock()
p1 = mock.Mock()
p2 = mock.Mock()
f.buildProtocol = mock.Mock(side_effect=[p1, p2])
d = ep.listen(f)
eq.flush_sync()
self.failureResultOf(d).check(CannotDilateError)
self.assertEqual(f.buildProtocol.mock_calls, [])

View File

@ -46,24 +46,21 @@ class Full(ServerBase, unittest.TestCase):
yield doBoth(w1.get_verifier(), w2.get_verifier()) yield doBoth(w1.get_verifier(), w2.get_verifier())
print("connected") print("connected")
eps1_d = w1.dilate() eps1 = w1.dilate()
eps2_d = w2.dilate() eps2 = w2.dilate()
(eps1, eps2) = yield doBoth(eps1_d, eps2_d)
(control_ep1, connect_ep1, listen_ep1) = eps1
(control_ep2, connect_ep2, listen_ep2) = eps2
print("w.dilate ready") print("w.dilate ready")
f1 = Factory() f1 = Factory()
f1.protocol = L f1.protocol = L
f1.d = Deferred() f1.d = Deferred()
f1.d.addCallback(lambda data: eq.fire_eventually(data)) f1.d.addCallback(lambda data: eq.fire_eventually(data))
d1 = control_ep1.connect(f1) d1 = eps1.control.connect(f1)
f2 = Factory() f2 = Factory()
f2.protocol = L f2.protocol = L
f2.d = Deferred() f2.d = Deferred()
f2.d.addCallback(lambda data: eq.fire_eventually(data)) f2.d.addCallback(lambda data: eq.fire_eventually(data))
d2 = control_ep2.connect(f2) d2 = eps2.control.connect(f2)
yield d1 yield d1
yield d2 yield d2
print("control endpoints connected") print("control endpoints connected")
@ -125,14 +122,12 @@ class Reconnect(ServerBase, unittest.TestCase):
w2.set_code(code) w2.set_code(code)
yield doBoth(w1.get_verifier(), w2.get_verifier()) yield doBoth(w1.get_verifier(), w2.get_verifier())
eps1_d = w1.dilate() eps1 = w1.dilate()
eps2_d = w2.dilate() eps2 = w2.dilate()
(eps1, eps2) = yield doBoth(eps1_d, eps2_d) print("w.dilate ready")
(control_ep1, connect_ep1, listen_ep1) = eps1
(control_ep2, connect_ep2, listen_ep2) = eps2
f1 = ReconF(eq); f2 = ReconF(eq) f1 = ReconF(eq); f2 = ReconF(eq)
d1 = control_ep1.connect(f1); d2 = control_ep2.connect(f2) d1 = eps1.control.connect(f1); d2 = eps2.control.connect(f2)
yield d1 yield d1
yield d2 yield d2
@ -194,14 +189,12 @@ class Reconnect(ServerBase, unittest.TestCase):
w2.set_code(code) w2.set_code(code)
yield doBoth(w1.get_verifier(), w2.get_verifier()) yield doBoth(w1.get_verifier(), w2.get_verifier())
eps1_d = w1.dilate() eps1 = w1.dilate()
eps2_d = w2.dilate() eps2 = w2.dilate()
(eps1, eps2) = yield doBoth(eps1_d, eps2_d) print("w.dilate ready")
(control_ep1, connect_ep1, listen_ep1) = eps1
(control_ep2, connect_ep2, listen_ep2) = eps2
f1 = ReconF(eq); f2 = ReconF(eq) f1 = ReconF(eq); f2 = ReconF(eq)
d1 = control_ep1.connect(f1); d2 = control_ep2.connect(f2) d1 = eps1.control.connect(f1); d2 = eps2.control.connect(f2)
yield d1 yield d1
yield d2 yield d2
@ -287,19 +280,17 @@ class Endpoints(ServerBase, unittest.TestCase):
w2.set_code(code) w2.set_code(code)
yield doBoth(w1.get_verifier(), w2.get_verifier()) yield doBoth(w1.get_verifier(), w2.get_verifier())
eps1_d = w1.dilate() eps1 = w1.dilate()
eps2_d = w2.dilate() eps2 = w2.dilate()
(eps1, eps2) = yield doBoth(eps1_d, eps2_d) print("w.dilate ready")
(control_ep1, connect_ep1, listen_ep1) = eps1
(control_ep2, connect_ep2, listen_ep2) = eps2
f0 = ReconF(eq) f0 = ReconF(eq)
yield listen_ep2.listen(f0) yield eps2.listen.listen(f0)
from twisted.python import log from twisted.python import log
f1 = ReconF(eq) f1 = ReconF(eq)
log.msg("connecting") log.msg("connecting")
p1_client = yield connect_ep1.connect(f1) p1_client = yield eps1.connect.connect(f1)
log.msg("sending c->s") log.msg("sending c->s")
p1_client.transport.write(b"hello from p1\n") p1_client.transport.write(b"hello from p1\n")
data = yield f0.deferreds["dataReceived"] data = yield f0.deferreds["dataReceived"]
@ -316,7 +307,7 @@ class Endpoints(ServerBase, unittest.TestCase):
f0.resetDeferred("dataReceived") f0.resetDeferred("dataReceived")
f1.resetDeferred("dataReceived") f1.resetDeferred("dataReceived")
f2 = ReconF(eq) f2 = ReconF(eq)
p2_client = yield connect_ep1.connect(f2) p2_client = yield eps1.connect.connect(f2)
p2_server = yield f0.deferreds["connectionMade"] p2_server = yield f0.deferreds["connectionMade"]
p2_server.transport.write(b"hello p2\n") p2_server.transport.write(b"hello p2\n")
data = yield f2.deferreds["dataReceived"] data = yield f2.deferreds["dataReceived"]

View File

@ -140,25 +140,23 @@ class TestDilator(unittest.TestCase):
def test_peer_cannot_dilate(self): def test_peer_cannot_dilate(self):
dil, send, reactor, eq, clock, coop = make_dilator() dil, send, reactor, eq, clock, coop = make_dilator()
d1 = dil.dilate() eps = dil.dilate()
self.assertNoResult(d1)
dil._transit_key = b"\x01" * 32 dil.got_key(b"\x01" * 32)
dil.got_wormhole_versions({}) # missing "can-dilate" dil.got_wormhole_versions({}) # missing "can-dilate"
d = eps.connect.connect(None)
eq.flush_sync() eq.flush_sync()
f = self.failureResultOf(d1) self.failureResultOf(d).check(OldPeerCannotDilateError)
f.check(OldPeerCannotDilateError)
def test_disjoint_versions(self): def test_disjoint_versions(self):
dil, send, reactor, eq, clock, coop = make_dilator() dil, send, reactor, eq, clock, coop = make_dilator()
d1 = dil.dilate() eps = dil.dilate()
self.assertNoResult(d1)
dil._transit_key = b"key" dil.got_key(b"\x01" * 32)
dil.got_wormhole_versions({"can-dilate": [-1]}) dil.got_wormhole_versions({"can-dilate": [-1]})
d = eps.connect.connect(None)
eq.flush_sync() eq.flush_sync()
f = self.failureResultOf(d1) self.failureResultOf(d).check(OldPeerCannotDilateError)
f.check(OldPeerCannotDilateError)
def test_early_dilate_messages(self): def test_early_dilate_messages(self):
dil, send, reactor, eq, clock, coop = make_dilator() dil, send, reactor, eq, clock, coop = make_dilator()
@ -276,11 +274,11 @@ def make_manager(leader=True):
h.Inbound = mock.Mock(return_value=h.inbound) h.Inbound = mock.Mock(return_value=h.inbound)
h.outbound = mock.Mock() h.outbound = mock.Mock()
h.Outbound = mock.Mock(return_value=h.outbound) h.Outbound = mock.Mock(return_value=h.outbound)
h.hostaddr = mock.Mock() with mock.patch("wormhole._dilation.manager.Inbound", h.Inbound), \
alsoProvides(h.hostaddr, IAddress) mock.patch("wormhole._dilation.manager.Outbound", h.Outbound):
with mock.patch("wormhole._dilation.manager.Inbound", h.Inbound): m = Manager(h.send, side, h.relay, h.reactor, h.eq, h.coop)
with mock.patch("wormhole._dilation.manager.Outbound", h.Outbound): h.hostaddr = m._host_addr
m = Manager(h.send, side, h.key, h.relay, h.reactor, h.eq, h.coop, h.hostaddr) m.got_dilation_key(h.key)
return m, h return m, h