commit
20496e3976
|
@ -2,13 +2,20 @@ from __future__ import print_function, unicode_literals
|
|||
import six
|
||||
import os
|
||||
from collections import deque
|
||||
try:
|
||||
# py >= 3.3
|
||||
from collections.abc import Sequence
|
||||
except ImportError:
|
||||
# py 2 and py3 < 3.3
|
||||
from collections import Sequence
|
||||
from attr import attrs, attrib
|
||||
from attr.validators import provides, instance_of, optional
|
||||
from automat import MethodicalMachine
|
||||
from zope.interface import implementer
|
||||
from twisted.internet.defer import Deferred, inlineCallbacks, returnValue
|
||||
from twisted.internet.interfaces import IAddress
|
||||
from twisted.python import log
|
||||
from twisted.internet.defer import Deferred
|
||||
from twisted.internet.interfaces import (IStreamClientEndpoint,
|
||||
IStreamServerEndpoint)
|
||||
from twisted.python import log, failure
|
||||
from .._interfaces import IDilator, IDilationManager, ISend, ITerminator
|
||||
from ..util import dict_to_bytes, bytes_to_dict, bytes_to_hexstr
|
||||
from ..observer import OneShotObserver
|
||||
|
@ -47,6 +54,15 @@ class UnexpectedKCM(Exception):
|
|||
class UnknownMessageType(Exception):
|
||||
pass
|
||||
|
||||
@attrs
|
||||
class EndpointRecord(Sequence):
|
||||
control = attrib(validator=provides(IStreamClientEndpoint))
|
||||
connect = attrib(validator=provides(IStreamClientEndpoint))
|
||||
listen = attrib(validator=provides(IStreamServerEndpoint))
|
||||
def __len__(self):
|
||||
return 3
|
||||
def __getitem__(self, n):
|
||||
return (self.control, self.connect, self.listen)[n]
|
||||
|
||||
def make_side():
|
||||
return bytes_to_hexstr(os.urandom(6))
|
||||
|
@ -93,13 +109,14 @@ def make_side():
|
|||
class Manager(object):
|
||||
_S = attrib(validator=provides(ISend), repr=False)
|
||||
_my_side = attrib(validator=instance_of(type(u"")))
|
||||
_transit_key = attrib(validator=instance_of(bytes), repr=False)
|
||||
_transit_relay_location = attrib(validator=optional(instance_of(str)))
|
||||
_reactor = attrib(repr=False)
|
||||
_eventual_queue = attrib(repr=False)
|
||||
_cooperator = attrib(repr=False)
|
||||
_host_addr = attrib(validator=provides(IAddress))
|
||||
_no_listen = attrib(default=False)
|
||||
# TODO: can this validator work when the parameter is optional?
|
||||
_no_listen = attrib(validator=instance_of(bool), default=False)
|
||||
|
||||
_dilation_key = None
|
||||
_tor = None # TODO
|
||||
_timing = None # TODO
|
||||
_next_subchannel_id = None # initialized in choose_role
|
||||
|
@ -111,10 +128,10 @@ class Manager(object):
|
|||
self._got_versions_d = Deferred()
|
||||
|
||||
self._my_role = None # determined upon rx_PLEASE
|
||||
self._host_addr = _WormholeAddress()
|
||||
|
||||
self._connection = None
|
||||
self._made_first_connection = False
|
||||
self._first_connected = OneShotObserver(self._eventual_queue)
|
||||
self._stopped = OneShotObserver(self._eventual_queue)
|
||||
self._debug_stall_connector = False
|
||||
|
||||
|
@ -127,18 +144,81 @@ class Manager(object):
|
|||
self._inbound = Inbound(self, self._host_addr)
|
||||
self._outbound = Outbound(self, self._cooperator) # from us to peer
|
||||
|
||||
def set_listener_endpoint(self, listener_endpoint):
|
||||
self._inbound.set_listener_endpoint(listener_endpoint)
|
||||
|
||||
def set_subchannel_zero(self, scid0, sc0):
|
||||
# We must open subchannel0 early, since messages may arrive very
|
||||
# quickly once the connection is established. This subchannel may or
|
||||
# may not ever get revealed to the caller, since the peer might not
|
||||
# even be capable of dilation.
|
||||
scid0 = 0
|
||||
peer_addr0 = _SubchannelAddress(scid0)
|
||||
sc0 = SubChannel(scid0, self, self._host_addr, peer_addr0)
|
||||
self._inbound.set_subchannel_zero(scid0, sc0)
|
||||
|
||||
def when_first_connected(self):
|
||||
return self._first_connected.when_fired()
|
||||
# we can open non-zero subchannels as soon as we get our first
|
||||
# connection, and we can make the Endpoints even earlier
|
||||
control_ep = ControlEndpoint(peer_addr0, sc0, self._eventual_queue)
|
||||
connect_ep = SubchannelConnectorEndpoint(self, self._host_addr, self._eventual_queue)
|
||||
listen_ep = SubchannelListenerEndpoint(self, self._host_addr, self._eventual_queue)
|
||||
# TODO: let inbound/outbound create the endpoints, then return them
|
||||
# to us
|
||||
self._inbound.set_listener_endpoint(listen_ep)
|
||||
|
||||
self._endpoints = EndpointRecord(control_ep, connect_ep, listen_ep)
|
||||
|
||||
def get_endpoints(self):
|
||||
return self._endpoints
|
||||
|
||||
def got_dilation_key(self, key):
|
||||
assert isinstance(key, bytes)
|
||||
self._dilation_key = key
|
||||
|
||||
def got_wormhole_versions(self, their_wormhole_versions):
|
||||
# this always happens before received_dilation_message
|
||||
dilation_version = None
|
||||
their_dilation_versions = set(their_wormhole_versions.get("can-dilate", []))
|
||||
my_versions = set(DILATION_VERSIONS)
|
||||
shared_versions = my_versions.intersection(their_dilation_versions)
|
||||
if "1" in shared_versions:
|
||||
dilation_version = "1"
|
||||
|
||||
# dilation_version is the best mutually-compatible version we have
|
||||
# with the peer, or None if we have nothing in common
|
||||
|
||||
if not dilation_version: # "1" or None
|
||||
# TODO: be more specific about the error. dilation_version==None
|
||||
# means we had no version in common with them, which could either
|
||||
# be because they're so old they don't dilate at all, or because
|
||||
# they're so new that they no longer accomodate our old version
|
||||
self.fail(failure.Failure(OldPeerCannotDilateError()))
|
||||
|
||||
self.start()
|
||||
|
||||
def fail(self, f):
|
||||
self._endpoints.control._main_channel_failed(f)
|
||||
self._endpoints.connect._main_channel_failed(f)
|
||||
self._endpoints.listen._main_channel_failed(f)
|
||||
|
||||
def received_dilation_message(self, plaintext):
|
||||
# this receives new in-order DILATE-n payloads, decrypted but not
|
||||
# de-JSONed.
|
||||
|
||||
message = bytes_to_dict(plaintext)
|
||||
type = message["type"]
|
||||
if type == "please":
|
||||
self.rx_PLEASE(message)
|
||||
elif type == "connection-hints":
|
||||
self.rx_HINTS(message)
|
||||
elif type == "reconnect":
|
||||
self.rx_RECONNECT()
|
||||
elif type == "reconnecting":
|
||||
self.rx_RECONNECTING()
|
||||
else:
|
||||
log.err(UnknownDilationMessageType(message))
|
||||
return
|
||||
|
||||
def when_stopped(self):
|
||||
return self._stopped.when_fired()
|
||||
|
||||
|
||||
def send_dilation_phase(self, **fields):
|
||||
dilation_phase = self._next_dilation_phase
|
||||
self._next_dilation_phase += 1
|
||||
|
@ -204,7 +284,9 @@ class Manager(object):
|
|||
self._outbound.use_connection(c) # does c.registerProducer
|
||||
if not self._made_first_connection:
|
||||
self._made_first_connection = True
|
||||
self._first_connected.fire(None)
|
||||
self._endpoints.control._main_channel_ready()
|
||||
self._endpoints.connect._main_channel_ready()
|
||||
self._endpoints.listen._main_channel_ready()
|
||||
pass
|
||||
|
||||
def connector_connection_lost(self):
|
||||
|
@ -272,16 +354,11 @@ class Manager(object):
|
|||
|
||||
# state machine
|
||||
|
||||
# We are born WANTING after the local app calls w.dilate(). We start
|
||||
# CONNECTING when we receive PLEASE from the remote side
|
||||
|
||||
def start(self):
|
||||
self.send_please()
|
||||
|
||||
def send_please(self):
|
||||
self.send_dilation_phase(type="please", side=self._my_side)
|
||||
|
||||
@m.state(initial=True)
|
||||
def WAITING(self):
|
||||
pass # pragma: no cover
|
||||
|
||||
@m.state()
|
||||
def WANTING(self):
|
||||
pass # pragma: no cover
|
||||
|
||||
|
@ -313,6 +390,10 @@ class Manager(object):
|
|||
def STOPPED(self):
|
||||
pass # pragma: no cover
|
||||
|
||||
@m.input()
|
||||
def start(self):
|
||||
pass # pragma: no cover
|
||||
|
||||
@m.input()
|
||||
def rx_PLEASE(self, message):
|
||||
pass # pragma: no cover
|
||||
|
@ -350,6 +431,10 @@ class Manager(object):
|
|||
def stop(self):
|
||||
pass # pragma: no cover
|
||||
|
||||
@m.output()
|
||||
def send_please(self):
|
||||
self.send_dilation_phase(type="please", side=self._my_side)
|
||||
|
||||
@m.output()
|
||||
def choose_role(self, message):
|
||||
their_side = message["side"]
|
||||
|
@ -378,7 +463,8 @@ class Manager(object):
|
|||
|
||||
def _start_connecting(self):
|
||||
assert self._my_role is not None
|
||||
self._connector = Connector(self._transit_key,
|
||||
assert self._dilation_key is not None
|
||||
self._connector = Connector(self._dilation_key,
|
||||
self._transit_relay_location,
|
||||
self,
|
||||
self._reactor, self._eventual_queue,
|
||||
|
@ -422,6 +508,11 @@ class Manager(object):
|
|||
def notify_stopped(self):
|
||||
self._stopped.fire(None)
|
||||
|
||||
# We are born WAITING after the local app calls w.dilate(). We enter
|
||||
# WANTING (and send a PLEASE) when we learn of a mutually-compatible
|
||||
# dilation_version.
|
||||
WAITING.upon(start, enter=WANTING, outputs=[send_please])
|
||||
|
||||
# we start CONNECTING when we get rx_PLEASE
|
||||
WANTING.upon(rx_PLEASE, enter=CONNECTING,
|
||||
outputs=[choose_role, start_connecting_ignore_message])
|
||||
|
@ -489,12 +580,10 @@ class Dilator(object):
|
|||
_cooperator = attrib()
|
||||
|
||||
def __attrs_post_init__(self):
|
||||
self._got_versions_d = Deferred()
|
||||
self._started = False
|
||||
self._endpoints = OneShotObserver(self._eventual_queue)
|
||||
self._pending_inbound_dilate_messages = deque()
|
||||
self._manager = None
|
||||
self._host_addr = _WormholeAddress()
|
||||
self._pending_dilation_key = None
|
||||
self._pending_wormhole_versions = None
|
||||
self._pending_inbound_dilate_messages = deque()
|
||||
|
||||
def wire(self, sender, terminator):
|
||||
self._S = ISend(sender)
|
||||
|
@ -502,77 +591,35 @@ class Dilator(object):
|
|||
|
||||
# this is the primary entry point, called when w.dilate() is invoked
|
||||
def dilate(self, transit_relay_location=None, no_listen=False):
|
||||
self._transit_relay_location = transit_relay_location
|
||||
if not self._started:
|
||||
self._started = True
|
||||
self._start(no_listen).addBoth(self._endpoints.fire)
|
||||
return self._endpoints.when_fired()
|
||||
|
||||
@inlineCallbacks
|
||||
def _start(self, no_listen):
|
||||
# first, we wait until we hear the VERSION message, which tells us 1:
|
||||
# the PAKE key works, so we can talk securely, 2: that they can do
|
||||
# dilation at all (if they can't then w.dilate() errbacks)
|
||||
|
||||
dilation_version = yield self._got_versions_d
|
||||
|
||||
# TODO: we could probably return the endpoints earlier, if we flunk
|
||||
# any connection/listen attempts upon OldPeerCannotDilateError, or
|
||||
# if/when we give up on the initial connection
|
||||
|
||||
if not dilation_version: # "1" or None
|
||||
# TODO: be more specific about the error. dilation_version==None
|
||||
# means we had no version in common with them, which could either
|
||||
# be because they're so old they don't dilate at all, or because
|
||||
# they're so new that they no longer accomodate our old version
|
||||
raise OldPeerCannotDilateError()
|
||||
|
||||
my_dilation_side = make_side()
|
||||
self._manager = Manager(self._S, my_dilation_side,
|
||||
self._transit_key,
|
||||
self._transit_relay_location,
|
||||
self._reactor, self._eventual_queue,
|
||||
self._cooperator, self._host_addr, no_listen)
|
||||
# We must open subchannel0 early, since messages may arrive very
|
||||
# quickly once the connection is established. This subchannel may or
|
||||
# may not ever get revealed to the caller, since the peer might not
|
||||
# even be capable of dilation.
|
||||
scid0 = 0
|
||||
peer_addr0 = _SubchannelAddress(scid0)
|
||||
sc0 = SubChannel(scid0, self._manager, self._host_addr, peer_addr0)
|
||||
self._manager.set_subchannel_zero(scid0, sc0)
|
||||
|
||||
self._manager.start()
|
||||
|
||||
while self._pending_inbound_dilate_messages:
|
||||
plaintext = self._pending_inbound_dilate_messages.popleft()
|
||||
self.received_dilate(plaintext)
|
||||
|
||||
yield self._manager.when_first_connected()
|
||||
|
||||
# we can open non-zero subchannels as soon as we get our first
|
||||
# connection
|
||||
control_ep = ControlEndpoint(peer_addr0)
|
||||
control_ep._subchannel_zero_opened(sc0)
|
||||
connect_ep = SubchannelConnectorEndpoint(self._manager, self._host_addr)
|
||||
|
||||
listen_ep = SubchannelListenerEndpoint(self._manager, self._host_addr)
|
||||
self._manager.set_listener_endpoint(listen_ep)
|
||||
|
||||
endpoints = (control_ep, connect_ep, listen_ep)
|
||||
returnValue(endpoints)
|
||||
if not self._manager:
|
||||
# build the manager right away, and tell it later when the
|
||||
# VERSIONS message arrives, and also when the dilation_key is set
|
||||
my_dilation_side = make_side()
|
||||
m = Manager(self._S, my_dilation_side,
|
||||
transit_relay_location,
|
||||
self._reactor, self._eventual_queue,
|
||||
self._cooperator, no_listen)
|
||||
self._manager = m
|
||||
if self._pending_dilation_key is not None:
|
||||
m.got_dilation_key(self._pending_dilation_key)
|
||||
if self._pending_wormhole_versions:
|
||||
m.got_wormhole_versions(self._pending_wormhole_versions)
|
||||
while self._pending_inbound_dilate_messages:
|
||||
plaintext = self._pending_inbound_dilate_messages.popleft()
|
||||
m.received_dilation_message(plaintext)
|
||||
return self._manager.get_endpoints()
|
||||
|
||||
# Called by Terminator after everything else (mailbox, nameplate, server
|
||||
# connection) has shut down. Expects to fire T.stoppedD() when Dilator is
|
||||
# stopped too.
|
||||
def stop(self):
|
||||
if not self._started:
|
||||
self._T.stoppedD()
|
||||
return
|
||||
if self._started:
|
||||
if self._manager:
|
||||
self._manager.stop()
|
||||
# TODO: avoid Deferreds for control flow, hard to serialize
|
||||
self._manager.when_stopped().addCallback(lambda _: self._T.stoppedD())
|
||||
else:
|
||||
self._T.stoppedD()
|
||||
return
|
||||
# TODO: tolerate multiple calls
|
||||
|
||||
# from Boss
|
||||
|
@ -582,39 +629,20 @@ class Dilator(object):
|
|||
# to tolerate either ordering
|
||||
purpose = b"dilation-v1"
|
||||
LENGTH = 32 # TODO: whatever Noise wants, I guess
|
||||
self._transit_key = derive_key(key, purpose, LENGTH)
|
||||
dilation_key = derive_key(key, purpose, LENGTH)
|
||||
if self._manager:
|
||||
self._manager.got_dilation_key(dilation_key)
|
||||
else:
|
||||
self._pending_dilation_key = dilation_key
|
||||
|
||||
def got_wormhole_versions(self, their_wormhole_versions):
|
||||
assert self._transit_key is not None
|
||||
# this always happens before received_dilate
|
||||
dilation_version = None
|
||||
their_dilation_versions = set(their_wormhole_versions.get("can-dilate", []))
|
||||
my_versions = set(DILATION_VERSIONS)
|
||||
shared_versions = my_versions.intersection(their_dilation_versions)
|
||||
if "1" in shared_versions:
|
||||
dilation_version = "1"
|
||||
self._got_versions_d.callback(dilation_version)
|
||||
if self._manager:
|
||||
self._manager.got_wormhole_versions(their_wormhole_versions)
|
||||
else:
|
||||
self._pending_wormhole_versions = their_wormhole_versions
|
||||
|
||||
def received_dilate(self, plaintext):
|
||||
# this receives new in-order DILATE-n payloads, decrypted but not
|
||||
# de-JSONed.
|
||||
|
||||
# this can appear before our .dilate() method is called, in which case
|
||||
# we queue them for later
|
||||
if not self._manager:
|
||||
self._pending_inbound_dilate_messages.append(plaintext)
|
||||
return
|
||||
|
||||
message = bytes_to_dict(plaintext)
|
||||
type = message["type"]
|
||||
if type == "please":
|
||||
self._manager.rx_PLEASE(message)
|
||||
elif type == "connection-hints":
|
||||
self._manager.rx_HINTS(message)
|
||||
elif type == "reconnect":
|
||||
self._manager.rx_RECONNECT()
|
||||
elif type == "reconnecting":
|
||||
self._manager.rx_RECONNECTING()
|
||||
else:
|
||||
log.err(UnknownDilationMessageType(message))
|
||||
return
|
||||
self._manager.received_dilation_message(plaintext)
|
||||
|
|
|
@ -1,9 +1,9 @@
|
|||
import six
|
||||
from collections import deque
|
||||
from attr import attrs, attrib
|
||||
from attr.validators import instance_of, provides
|
||||
from zope.interface import implementer
|
||||
from twisted.internet.defer import (Deferred, inlineCallbacks, returnValue,
|
||||
succeed)
|
||||
from twisted.internet.defer import inlineCallbacks, returnValue
|
||||
from twisted.internet.interfaces import (ITransport, IProducer, IConsumer,
|
||||
IAddress, IListeningPort,
|
||||
IStreamClientEndpoint,
|
||||
|
@ -11,6 +11,7 @@ from twisted.internet.interfaces import (ITransport, IProducer, IConsumer,
|
|||
from twisted.internet.error import ConnectionDone
|
||||
from automat import MethodicalMachine
|
||||
from .._interfaces import ISubChannel, IDilationManager
|
||||
from ..observer import OneShotObserver
|
||||
|
||||
# each subchannel frame (the data passed into transport.write(data)) gets a
|
||||
# 9-byte header prefix (type, subchannel id, and sequence number), then gets
|
||||
|
@ -216,27 +217,33 @@ class SubChannel(object):
|
|||
|
||||
|
||||
@implementer(IStreamClientEndpoint)
|
||||
@attrs
|
||||
class ControlEndpoint(object):
|
||||
_peer_addr = attrib(validator=provides(IAddress))
|
||||
_subchannel_zero = attrib(validator=provides(ISubChannel))
|
||||
_eventual_queue = attrib(repr=False)
|
||||
_used = False
|
||||
|
||||
def __init__(self, peer_addr):
|
||||
self._subchannel_zero = Deferred()
|
||||
self._peer_addr = peer_addr
|
||||
def __attrs_post_init__(self):
|
||||
self._once = Once(SingleUseEndpointError)
|
||||
self._wait_for_main_channel = OneShotObserver(self._eventual_queue)
|
||||
|
||||
# from manager
|
||||
def _subchannel_zero_opened(self, subchannel):
|
||||
assert ISubChannel.providedBy(subchannel), subchannel
|
||||
self._subchannel_zero.callback(subchannel)
|
||||
|
||||
def _main_channel_ready(self):
|
||||
self._wait_for_main_channel.fire(None)
|
||||
def _main_channel_failed(self, f):
|
||||
self._wait_for_main_channel.error(f)
|
||||
|
||||
@inlineCallbacks
|
||||
def connect(self, protocolFactory):
|
||||
# return Deferred that fires with IProtocol or Failure(ConnectError)
|
||||
self._once()
|
||||
t = yield self._subchannel_zero
|
||||
yield self._wait_for_main_channel.when_fired()
|
||||
p = protocolFactory.buildProtocol(self._peer_addr)
|
||||
t._set_protocol(p)
|
||||
p.makeConnection(t) # set p.transport = t and call connectionMade()
|
||||
self._subchannel_zero._set_protocol(p)
|
||||
# this sets p.transport and calls p.connectionMade()
|
||||
p.makeConnection(self._subchannel_zero)
|
||||
returnValue(p)
|
||||
|
||||
|
||||
|
@ -245,9 +252,21 @@ class ControlEndpoint(object):
|
|||
class SubchannelConnectorEndpoint(object):
|
||||
_manager = attrib(validator=provides(IDilationManager))
|
||||
_host_addr = attrib(validator=instance_of(_WormholeAddress))
|
||||
_eventual_queue = attrib(repr=False)
|
||||
|
||||
def __attrs_post_init__(self):
|
||||
self._connection_deferreds = deque()
|
||||
self._wait_for_main_channel = OneShotObserver(self._eventual_queue)
|
||||
|
||||
def _main_channel_ready(self):
|
||||
self._wait_for_main_channel.fire(None)
|
||||
def _main_channel_failed(self, f):
|
||||
self._wait_for_main_channel.error(f)
|
||||
|
||||
@inlineCallbacks
|
||||
def connect(self, protocolFactory):
|
||||
# return Deferred that fires with IProtocol or Failure(ConnectError)
|
||||
yield self._wait_for_main_channel.when_fired()
|
||||
scid = self._manager.allocate_subchannel_id()
|
||||
self._manager.send_open(scid)
|
||||
peer_addr = _SubchannelAddress(scid)
|
||||
|
@ -258,7 +277,7 @@ class SubchannelConnectorEndpoint(object):
|
|||
p = protocolFactory.buildProtocol(peer_addr)
|
||||
sc._set_protocol(p)
|
||||
p.makeConnection(sc) # set p.transport = sc and call connectionMade()
|
||||
return succeed(p)
|
||||
returnValue(p)
|
||||
|
||||
|
||||
@implementer(IStreamServerEndpoint)
|
||||
|
@ -266,12 +285,15 @@ class SubchannelConnectorEndpoint(object):
|
|||
class SubchannelListenerEndpoint(object):
|
||||
_manager = attrib(validator=provides(IDilationManager))
|
||||
_host_addr = attrib(validator=provides(IAddress))
|
||||
_eventual_queue = attrib(repr=False)
|
||||
|
||||
def __attrs_post_init__(self):
|
||||
self._once = Once(SingleUseEndpointError)
|
||||
self._factory = None
|
||||
self._pending_opens = []
|
||||
self._pending_opens = deque()
|
||||
self._wait_for_main_channel = OneShotObserver(self._eventual_queue)
|
||||
|
||||
# from manager
|
||||
# from manager (actually Inbound)
|
||||
def _got_open(self, t, peer_addr):
|
||||
if self._factory:
|
||||
self._connect(t, peer_addr)
|
||||
|
@ -283,15 +305,23 @@ class SubchannelListenerEndpoint(object):
|
|||
t._set_protocol(p)
|
||||
p.makeConnection(t)
|
||||
|
||||
def _main_channel_ready(self):
|
||||
self._wait_for_main_channel.fire(None)
|
||||
def _main_channel_failed(self, f):
|
||||
self._wait_for_main_channel.error(f)
|
||||
|
||||
# IStreamServerEndpoint
|
||||
|
||||
@inlineCallbacks
|
||||
def listen(self, protocolFactory):
|
||||
self._once()
|
||||
yield self._wait_for_main_channel.when_fired()
|
||||
self._factory = protocolFactory
|
||||
for (t, peer_addr) in self._pending_opens:
|
||||
while self._pending_opens:
|
||||
(t, peer_addr) = self._pending_opens.popleft()
|
||||
self._connect(t, peer_addr)
|
||||
self._pending_opens = []
|
||||
lp = SubchannelListeningPort(self._host_addr)
|
||||
return succeed(lp)
|
||||
returnValue(lp)
|
||||
|
||||
|
||||
@implementer(IListeningPort)
|
||||
|
|
|
@ -2,7 +2,10 @@ from __future__ import print_function, unicode_literals
|
|||
import mock
|
||||
from zope.interface import alsoProvides
|
||||
from twisted.trial import unittest
|
||||
from twisted.internet.task import Clock
|
||||
from twisted.python.failure import Failure
|
||||
from ..._interfaces import ISubChannel
|
||||
from ...eventual import EventualQueue
|
||||
from ..._dilation.subchannel import (ControlEndpoint,
|
||||
SubchannelConnectorEndpoint,
|
||||
SubchannelListenerEndpoint,
|
||||
|
@ -11,12 +14,18 @@ from ..._dilation.subchannel import (ControlEndpoint,
|
|||
SingleUseEndpointError)
|
||||
from .common import mock_manager
|
||||
|
||||
class CannotDilateError(Exception):
|
||||
pass
|
||||
|
||||
class Endpoints(unittest.TestCase):
|
||||
def test_control(self):
|
||||
class Control(unittest.TestCase):
|
||||
def test_early_succeed(self):
|
||||
# ep.connect() is called before dilation can proceed
|
||||
scid0 = 0
|
||||
peeraddr = _SubchannelAddress(scid0)
|
||||
ep = ControlEndpoint(peeraddr)
|
||||
sc0 = mock.Mock()
|
||||
alsoProvides(sc0, ISubChannel)
|
||||
eq = EventualQueue(Clock())
|
||||
ep = ControlEndpoint(peeraddr, sc0, eq)
|
||||
|
||||
f = mock.Mock()
|
||||
p = mock.Mock()
|
||||
|
@ -24,29 +33,105 @@ class Endpoints(unittest.TestCase):
|
|||
d = ep.connect(f)
|
||||
self.assertNoResult(d)
|
||||
|
||||
t = mock.Mock()
|
||||
alsoProvides(t, ISubChannel)
|
||||
ep._subchannel_zero_opened(t)
|
||||
ep._main_channel_ready()
|
||||
eq.flush_sync()
|
||||
|
||||
self.assertIdentical(self.successResultOf(d), p)
|
||||
self.assertEqual(f.buildProtocol.mock_calls, [mock.call(peeraddr)])
|
||||
self.assertEqual(t.mock_calls, [mock.call._set_protocol(p)])
|
||||
self.assertEqual(p.mock_calls, [mock.call.makeConnection(t)])
|
||||
self.assertEqual(sc0.mock_calls, [mock.call._set_protocol(p)])
|
||||
self.assertEqual(p.mock_calls, [mock.call.makeConnection(sc0)])
|
||||
|
||||
d = ep.connect(f)
|
||||
self.failureResultOf(d, SingleUseEndpointError)
|
||||
|
||||
def assert_makeConnection(self, mock_calls):
|
||||
def test_early_fail(self):
|
||||
# ep.connect() is called before dilation is abandoned
|
||||
scid0 = 0
|
||||
peeraddr = _SubchannelAddress(scid0)
|
||||
sc0 = mock.Mock()
|
||||
alsoProvides(sc0, ISubChannel)
|
||||
eq = EventualQueue(Clock())
|
||||
ep = ControlEndpoint(peeraddr, sc0, eq)
|
||||
|
||||
f = mock.Mock()
|
||||
p = mock.Mock()
|
||||
f.buildProtocol = mock.Mock(return_value=p)
|
||||
d = ep.connect(f)
|
||||
self.assertNoResult(d)
|
||||
|
||||
ep._main_channel_failed(Failure(CannotDilateError()))
|
||||
eq.flush_sync()
|
||||
|
||||
self.failureResultOf(d).check(CannotDilateError)
|
||||
self.assertEqual(f.buildProtocol.mock_calls, [])
|
||||
self.assertEqual(sc0.mock_calls, [])
|
||||
|
||||
d = ep.connect(f)
|
||||
self.failureResultOf(d, SingleUseEndpointError)
|
||||
|
||||
def test_late_succeed(self):
|
||||
# dilation can proceed, then ep.connect() is called
|
||||
scid0 = 0
|
||||
peeraddr = _SubchannelAddress(scid0)
|
||||
sc0 = mock.Mock()
|
||||
alsoProvides(sc0, ISubChannel)
|
||||
eq = EventualQueue(Clock())
|
||||
ep = ControlEndpoint(peeraddr, sc0, eq)
|
||||
|
||||
ep._main_channel_ready()
|
||||
|
||||
f = mock.Mock()
|
||||
p = mock.Mock()
|
||||
f.buildProtocol = mock.Mock(return_value=p)
|
||||
d = ep.connect(f)
|
||||
eq.flush_sync()
|
||||
self.assertIdentical(self.successResultOf(d), p)
|
||||
self.assertEqual(f.buildProtocol.mock_calls, [mock.call(peeraddr)])
|
||||
self.assertEqual(sc0.mock_calls, [mock.call._set_protocol(p)])
|
||||
self.assertEqual(p.mock_calls, [mock.call.makeConnection(sc0)])
|
||||
|
||||
d = ep.connect(f)
|
||||
self.failureResultOf(d, SingleUseEndpointError)
|
||||
|
||||
def test_late_fail(self):
|
||||
# dilation is abandoned, then ep.connect() is called
|
||||
scid0 = 0
|
||||
peeraddr = _SubchannelAddress(scid0)
|
||||
sc0 = mock.Mock()
|
||||
alsoProvides(sc0, ISubChannel)
|
||||
eq = EventualQueue(Clock())
|
||||
ep = ControlEndpoint(peeraddr, sc0, eq)
|
||||
|
||||
ep._main_channel_failed(Failure(CannotDilateError()))
|
||||
|
||||
f = mock.Mock()
|
||||
p = mock.Mock()
|
||||
f.buildProtocol = mock.Mock(return_value=p)
|
||||
d = ep.connect(f)
|
||||
eq.flush_sync()
|
||||
|
||||
self.failureResultOf(d).check(CannotDilateError)
|
||||
self.assertEqual(f.buildProtocol.mock_calls, [])
|
||||
self.assertEqual(sc0.mock_calls, [])
|
||||
|
||||
d = ep.connect(f)
|
||||
self.failureResultOf(d, SingleUseEndpointError)
|
||||
|
||||
class Endpoints(unittest.TestCase):
|
||||
def OFFassert_makeConnection(self, mock_calls):
|
||||
self.assertEqual(len(mock_calls), 1)
|
||||
self.assertEqual(mock_calls[0][0], "makeConnection")
|
||||
self.assertEqual(len(mock_calls[0][1]), 1)
|
||||
return mock_calls[0][1][0]
|
||||
|
||||
def test_connector(self):
|
||||
class Connector(unittest.TestCase):
|
||||
def test_early_succeed(self):
|
||||
m = mock_manager()
|
||||
m.allocate_subchannel_id = mock.Mock(return_value=0)
|
||||
hostaddr = _WormholeAddress()
|
||||
peeraddr = _SubchannelAddress(0)
|
||||
ep = SubchannelConnectorEndpoint(m, hostaddr)
|
||||
eq = EventualQueue(Clock())
|
||||
ep = SubchannelConnectorEndpoint(m, hostaddr, eq)
|
||||
|
||||
f = mock.Mock()
|
||||
p = mock.Mock()
|
||||
|
@ -55,38 +140,123 @@ class Endpoints(unittest.TestCase):
|
|||
with mock.patch("wormhole._dilation.subchannel.SubChannel",
|
||||
return_value=t) as sc:
|
||||
d = ep.connect(f)
|
||||
eq.flush_sync()
|
||||
self.assertNoResult(d)
|
||||
ep._main_channel_ready()
|
||||
eq.flush_sync()
|
||||
|
||||
self.assertIdentical(self.successResultOf(d), p)
|
||||
self.assertEqual(f.buildProtocol.mock_calls, [mock.call(peeraddr)])
|
||||
self.assertEqual(sc.mock_calls, [mock.call(0, m, hostaddr, peeraddr)])
|
||||
self.assertEqual(t.mock_calls, [mock.call._set_protocol(p)])
|
||||
self.assertEqual(p.mock_calls, [mock.call.makeConnection(t)])
|
||||
|
||||
def test_listener(self):
|
||||
def test_early_fail(self):
|
||||
m = mock_manager()
|
||||
m.allocate_subchannel_id = mock.Mock(return_value=0)
|
||||
hostaddr = _WormholeAddress()
|
||||
ep = SubchannelListenerEndpoint(m, hostaddr)
|
||||
eq = EventualQueue(Clock())
|
||||
ep = SubchannelConnectorEndpoint(m, hostaddr, eq)
|
||||
|
||||
f = mock.Mock()
|
||||
p = mock.Mock()
|
||||
t = mock.Mock()
|
||||
f.buildProtocol = mock.Mock(return_value=p)
|
||||
with mock.patch("wormhole._dilation.subchannel.SubChannel",
|
||||
return_value=t) as sc:
|
||||
d = ep.connect(f)
|
||||
eq.flush_sync()
|
||||
self.assertNoResult(d)
|
||||
ep._main_channel_failed(Failure(CannotDilateError()))
|
||||
eq.flush_sync()
|
||||
|
||||
self.failureResultOf(d).check(CannotDilateError)
|
||||
self.assertEqual(f.buildProtocol.mock_calls, [])
|
||||
self.assertEqual(sc.mock_calls, [])
|
||||
self.assertEqual(t.mock_calls, [])
|
||||
|
||||
def test_late_succeed(self):
|
||||
m = mock_manager()
|
||||
m.allocate_subchannel_id = mock.Mock(return_value=0)
|
||||
hostaddr = _WormholeAddress()
|
||||
peeraddr = _SubchannelAddress(0)
|
||||
eq = EventualQueue(Clock())
|
||||
ep = SubchannelConnectorEndpoint(m, hostaddr, eq)
|
||||
ep._main_channel_ready()
|
||||
|
||||
f = mock.Mock()
|
||||
p = mock.Mock()
|
||||
t = mock.Mock()
|
||||
f.buildProtocol = mock.Mock(return_value=p)
|
||||
with mock.patch("wormhole._dilation.subchannel.SubChannel",
|
||||
return_value=t) as sc:
|
||||
d = ep.connect(f)
|
||||
eq.flush_sync()
|
||||
|
||||
self.assertIdentical(self.successResultOf(d), p)
|
||||
self.assertEqual(f.buildProtocol.mock_calls, [mock.call(peeraddr)])
|
||||
self.assertEqual(sc.mock_calls, [mock.call(0, m, hostaddr, peeraddr)])
|
||||
self.assertEqual(t.mock_calls, [mock.call._set_protocol(p)])
|
||||
self.assertEqual(p.mock_calls, [mock.call.makeConnection(t)])
|
||||
|
||||
def test_late_fail(self):
|
||||
m = mock_manager()
|
||||
m.allocate_subchannel_id = mock.Mock(return_value=0)
|
||||
hostaddr = _WormholeAddress()
|
||||
eq = EventualQueue(Clock())
|
||||
ep = SubchannelConnectorEndpoint(m, hostaddr, eq)
|
||||
ep._main_channel_failed(Failure(CannotDilateError()))
|
||||
|
||||
f = mock.Mock()
|
||||
p = mock.Mock()
|
||||
t = mock.Mock()
|
||||
f.buildProtocol = mock.Mock(return_value=p)
|
||||
with mock.patch("wormhole._dilation.subchannel.SubChannel",
|
||||
return_value=t) as sc:
|
||||
d = ep.connect(f)
|
||||
eq.flush_sync()
|
||||
|
||||
self.failureResultOf(d).check(CannotDilateError)
|
||||
self.assertEqual(f.buildProtocol.mock_calls, [])
|
||||
self.assertEqual(sc.mock_calls, [])
|
||||
self.assertEqual(t.mock_calls, [])
|
||||
|
||||
class Listener(unittest.TestCase):
|
||||
def test_early_succeed(self):
|
||||
# listen, main_channel_ready, got_open, got_open
|
||||
m = mock_manager()
|
||||
m.allocate_subchannel_id = mock.Mock(return_value=0)
|
||||
hostaddr = _WormholeAddress()
|
||||
eq = EventualQueue(Clock())
|
||||
ep = SubchannelListenerEndpoint(m, hostaddr, eq)
|
||||
|
||||
f = mock.Mock()
|
||||
p1 = mock.Mock()
|
||||
p2 = mock.Mock()
|
||||
f.buildProtocol = mock.Mock(side_effect=[p1, p2])
|
||||
|
||||
# OPEN that arrives before we ep.listen() should be queued
|
||||
d = ep.listen(f)
|
||||
eq.flush_sync()
|
||||
self.assertNoResult(d)
|
||||
self.assertEqual(f.buildProtocol.mock_calls, [])
|
||||
|
||||
ep._main_channel_ready()
|
||||
eq.flush_sync()
|
||||
lp = self.successResultOf(d)
|
||||
self.assertIsInstance(lp, SubchannelListeningPort)
|
||||
|
||||
self.assertEqual(lp.getHost(), hostaddr)
|
||||
# TODO: IListeningPort says we must provide this, but I don't know
|
||||
# that anyone would ever call it.
|
||||
lp.startListening()
|
||||
|
||||
t1 = mock.Mock()
|
||||
peeraddr1 = _SubchannelAddress(1)
|
||||
ep._got_open(t1, peeraddr1)
|
||||
|
||||
d = ep.listen(f)
|
||||
lp = self.successResultOf(d)
|
||||
self.assertIsInstance(lp, SubchannelListeningPort)
|
||||
|
||||
self.assertEqual(lp.getHost(), hostaddr)
|
||||
lp.startListening()
|
||||
|
||||
self.assertEqual(t1.mock_calls, [mock.call._set_protocol(p1)])
|
||||
self.assertEqual(p1.mock_calls, [mock.call.makeConnection(t1)])
|
||||
self.assertEqual(f.buildProtocol.mock_calls, [mock.call(peeraddr1)])
|
||||
|
||||
t2 = mock.Mock()
|
||||
peeraddr2 = _SubchannelAddress(2)
|
||||
|
@ -94,5 +264,92 @@ class Endpoints(unittest.TestCase):
|
|||
|
||||
self.assertEqual(t2.mock_calls, [mock.call._set_protocol(p2)])
|
||||
self.assertEqual(p2.mock_calls, [mock.call.makeConnection(t2)])
|
||||
self.assertEqual(f.buildProtocol.mock_calls, [mock.call(peeraddr1),
|
||||
mock.call(peeraddr2)])
|
||||
|
||||
lp.stopListening() # TODO: should this do more?
|
||||
|
||||
def test_early_fail(self):
|
||||
# listen, main_channel_fail
|
||||
m = mock_manager()
|
||||
m.allocate_subchannel_id = mock.Mock(return_value=0)
|
||||
hostaddr = _WormholeAddress()
|
||||
eq = EventualQueue(Clock())
|
||||
ep = SubchannelListenerEndpoint(m, hostaddr, eq)
|
||||
|
||||
f = mock.Mock()
|
||||
p1 = mock.Mock()
|
||||
p2 = mock.Mock()
|
||||
f.buildProtocol = mock.Mock(side_effect=[p1, p2])
|
||||
|
||||
d = ep.listen(f)
|
||||
eq.flush_sync()
|
||||
self.assertNoResult(d)
|
||||
|
||||
ep._main_channel_failed(Failure(CannotDilateError()))
|
||||
eq.flush_sync()
|
||||
self.failureResultOf(d).check(CannotDilateError)
|
||||
self.assertEqual(f.buildProtocol.mock_calls, [])
|
||||
|
||||
def test_late_succeed(self):
|
||||
# main_channel_ready, got_open, listen, got_open
|
||||
m = mock_manager()
|
||||
m.allocate_subchannel_id = mock.Mock(return_value=0)
|
||||
hostaddr = _WormholeAddress()
|
||||
eq = EventualQueue(Clock())
|
||||
ep = SubchannelListenerEndpoint(m, hostaddr, eq)
|
||||
ep._main_channel_ready()
|
||||
|
||||
f = mock.Mock()
|
||||
p1 = mock.Mock()
|
||||
p2 = mock.Mock()
|
||||
f.buildProtocol = mock.Mock(side_effect=[p1, p2])
|
||||
|
||||
t1 = mock.Mock()
|
||||
peeraddr1 = _SubchannelAddress(1)
|
||||
ep._got_open(t1, peeraddr1)
|
||||
eq.flush_sync()
|
||||
|
||||
self.assertEqual(t1.mock_calls, [])
|
||||
self.assertEqual(p1.mock_calls, [])
|
||||
|
||||
d = ep.listen(f)
|
||||
eq.flush_sync()
|
||||
lp = self.successResultOf(d)
|
||||
self.assertIsInstance(lp, SubchannelListeningPort)
|
||||
self.assertEqual(lp.getHost(), hostaddr)
|
||||
lp.startListening()
|
||||
|
||||
self.assertEqual(t1.mock_calls, [mock.call._set_protocol(p1)])
|
||||
self.assertEqual(p1.mock_calls, [mock.call.makeConnection(t1)])
|
||||
self.assertEqual(f.buildProtocol.mock_calls, [mock.call(peeraddr1)])
|
||||
|
||||
t2 = mock.Mock()
|
||||
peeraddr2 = _SubchannelAddress(2)
|
||||
ep._got_open(t2, peeraddr2)
|
||||
|
||||
self.assertEqual(t2.mock_calls, [mock.call._set_protocol(p2)])
|
||||
self.assertEqual(p2.mock_calls, [mock.call.makeConnection(t2)])
|
||||
self.assertEqual(f.buildProtocol.mock_calls, [mock.call(peeraddr1),
|
||||
mock.call(peeraddr2)])
|
||||
|
||||
lp.stopListening() # TODO: should this do more?
|
||||
|
||||
def test_late_fail(self):
|
||||
# main_channel_fail, listen
|
||||
m = mock_manager()
|
||||
m.allocate_subchannel_id = mock.Mock(return_value=0)
|
||||
hostaddr = _WormholeAddress()
|
||||
eq = EventualQueue(Clock())
|
||||
ep = SubchannelListenerEndpoint(m, hostaddr, eq)
|
||||
ep._main_channel_failed(Failure(CannotDilateError()))
|
||||
|
||||
f = mock.Mock()
|
||||
p1 = mock.Mock()
|
||||
p2 = mock.Mock()
|
||||
f.buildProtocol = mock.Mock(side_effect=[p1, p2])
|
||||
|
||||
d = ep.listen(f)
|
||||
eq.flush_sync()
|
||||
self.failureResultOf(d).check(CannotDilateError)
|
||||
self.assertEqual(f.buildProtocol.mock_calls, [])
|
||||
|
|
|
@ -46,24 +46,21 @@ class Full(ServerBase, unittest.TestCase):
|
|||
yield doBoth(w1.get_verifier(), w2.get_verifier())
|
||||
print("connected")
|
||||
|
||||
eps1_d = w1.dilate()
|
||||
eps2_d = w2.dilate()
|
||||
(eps1, eps2) = yield doBoth(eps1_d, eps2_d)
|
||||
(control_ep1, connect_ep1, listen_ep1) = eps1
|
||||
(control_ep2, connect_ep2, listen_ep2) = eps2
|
||||
eps1 = w1.dilate()
|
||||
eps2 = w2.dilate()
|
||||
print("w.dilate ready")
|
||||
|
||||
f1 = Factory()
|
||||
f1.protocol = L
|
||||
f1.d = Deferred()
|
||||
f1.d.addCallback(lambda data: eq.fire_eventually(data))
|
||||
d1 = control_ep1.connect(f1)
|
||||
d1 = eps1.control.connect(f1)
|
||||
|
||||
f2 = Factory()
|
||||
f2.protocol = L
|
||||
f2.d = Deferred()
|
||||
f2.d.addCallback(lambda data: eq.fire_eventually(data))
|
||||
d2 = control_ep2.connect(f2)
|
||||
d2 = eps2.control.connect(f2)
|
||||
yield d1
|
||||
yield d2
|
||||
print("control endpoints connected")
|
||||
|
@ -125,14 +122,12 @@ class Reconnect(ServerBase, unittest.TestCase):
|
|||
w2.set_code(code)
|
||||
yield doBoth(w1.get_verifier(), w2.get_verifier())
|
||||
|
||||
eps1_d = w1.dilate()
|
||||
eps2_d = w2.dilate()
|
||||
(eps1, eps2) = yield doBoth(eps1_d, eps2_d)
|
||||
(control_ep1, connect_ep1, listen_ep1) = eps1
|
||||
(control_ep2, connect_ep2, listen_ep2) = eps2
|
||||
eps1 = w1.dilate()
|
||||
eps2 = w2.dilate()
|
||||
print("w.dilate ready")
|
||||
|
||||
f1 = ReconF(eq); f2 = ReconF(eq)
|
||||
d1 = control_ep1.connect(f1); d2 = control_ep2.connect(f2)
|
||||
d1 = eps1.control.connect(f1); d2 = eps2.control.connect(f2)
|
||||
yield d1
|
||||
yield d2
|
||||
|
||||
|
@ -194,14 +189,12 @@ class Reconnect(ServerBase, unittest.TestCase):
|
|||
w2.set_code(code)
|
||||
yield doBoth(w1.get_verifier(), w2.get_verifier())
|
||||
|
||||
eps1_d = w1.dilate()
|
||||
eps2_d = w2.dilate()
|
||||
(eps1, eps2) = yield doBoth(eps1_d, eps2_d)
|
||||
(control_ep1, connect_ep1, listen_ep1) = eps1
|
||||
(control_ep2, connect_ep2, listen_ep2) = eps2
|
||||
eps1 = w1.dilate()
|
||||
eps2 = w2.dilate()
|
||||
print("w.dilate ready")
|
||||
|
||||
f1 = ReconF(eq); f2 = ReconF(eq)
|
||||
d1 = control_ep1.connect(f1); d2 = control_ep2.connect(f2)
|
||||
d1 = eps1.control.connect(f1); d2 = eps2.control.connect(f2)
|
||||
yield d1
|
||||
yield d2
|
||||
|
||||
|
@ -287,19 +280,17 @@ class Endpoints(ServerBase, unittest.TestCase):
|
|||
w2.set_code(code)
|
||||
yield doBoth(w1.get_verifier(), w2.get_verifier())
|
||||
|
||||
eps1_d = w1.dilate()
|
||||
eps2_d = w2.dilate()
|
||||
(eps1, eps2) = yield doBoth(eps1_d, eps2_d)
|
||||
(control_ep1, connect_ep1, listen_ep1) = eps1
|
||||
(control_ep2, connect_ep2, listen_ep2) = eps2
|
||||
eps1 = w1.dilate()
|
||||
eps2 = w2.dilate()
|
||||
print("w.dilate ready")
|
||||
|
||||
f0 = ReconF(eq)
|
||||
yield listen_ep2.listen(f0)
|
||||
yield eps2.listen.listen(f0)
|
||||
|
||||
from twisted.python import log
|
||||
f1 = ReconF(eq)
|
||||
log.msg("connecting")
|
||||
p1_client = yield connect_ep1.connect(f1)
|
||||
p1_client = yield eps1.connect.connect(f1)
|
||||
log.msg("sending c->s")
|
||||
p1_client.transport.write(b"hello from p1\n")
|
||||
data = yield f0.deferreds["dataReceived"]
|
||||
|
@ -316,7 +307,7 @@ class Endpoints(ServerBase, unittest.TestCase):
|
|||
f0.resetDeferred("dataReceived")
|
||||
f1.resetDeferred("dataReceived")
|
||||
f2 = ReconF(eq)
|
||||
p2_client = yield connect_ep1.connect(f2)
|
||||
p2_client = yield eps1.connect.connect(f2)
|
||||
p2_server = yield f0.deferreds["connectionMade"]
|
||||
p2_server.transport.write(b"hello p2\n")
|
||||
data = yield f2.deferreds["dataReceived"]
|
||||
|
|
|
@ -1,12 +1,11 @@
|
|||
from __future__ import print_function, unicode_literals
|
||||
from zope.interface import alsoProvides
|
||||
from twisted.trial import unittest
|
||||
from twisted.internet.defer import Deferred
|
||||
from twisted.internet.task import Clock, Cooperator
|
||||
from twisted.internet.interfaces import IAddress
|
||||
from twisted.internet.interfaces import IStreamServerEndpoint
|
||||
import mock
|
||||
from ...eventual import EventualQueue
|
||||
from ..._interfaces import ISend, IDilationManager, ITerminator
|
||||
from ..._interfaces import ISend, ITerminator, ISubChannel
|
||||
from ...util import dict_to_bytes
|
||||
from ..._dilation import roles
|
||||
from ..._dilation.manager import (Dilator, Manager, make_side,
|
||||
|
@ -15,35 +14,57 @@ from ..._dilation.manager import (Dilator, Manager, make_side,
|
|||
UnexpectedKCM,
|
||||
UnknownMessageType)
|
||||
from ..._dilation.connection import Open, Data, Close, Ack, KCM, Ping, Pong
|
||||
from ..._dilation.subchannel import _SubchannelAddress
|
||||
from .common import clear_mock_calls
|
||||
|
||||
class Holder():
|
||||
pass
|
||||
|
||||
def make_dilator():
|
||||
reactor = object()
|
||||
clock = Clock()
|
||||
eq = EventualQueue(clock)
|
||||
h = Holder()
|
||||
h.reactor = object()
|
||||
h.clock = Clock()
|
||||
h.eq = EventualQueue(h.clock)
|
||||
term = mock.Mock(side_effect=lambda: True) # one write per Eventual tick
|
||||
|
||||
def term_factory():
|
||||
return term
|
||||
coop = Cooperator(terminationPredicateFactory=term_factory,
|
||||
scheduler=eq.eventually)
|
||||
send = mock.Mock()
|
||||
alsoProvides(send, ISend)
|
||||
dil = Dilator(reactor, eq, coop)
|
||||
terminator = mock.Mock()
|
||||
alsoProvides(terminator, ITerminator)
|
||||
dil.wire(send, terminator)
|
||||
return dil, send, reactor, eq, clock, coop
|
||||
h.coop = Cooperator(terminationPredicateFactory=term_factory,
|
||||
scheduler=h.eq.eventually)
|
||||
h.send = mock.Mock()
|
||||
alsoProvides(h.send, ISend)
|
||||
dil = Dilator(h.reactor, h.eq, h.coop)
|
||||
h.terminator = mock.Mock()
|
||||
alsoProvides(h.terminator, ITerminator)
|
||||
dil.wire(h.send, h.terminator)
|
||||
return dil, h
|
||||
|
||||
|
||||
class TestDilator(unittest.TestCase):
|
||||
def test_manager_and_endpoints(self):
|
||||
dil, send, reactor, eq, clock, coop = make_dilator()
|
||||
d1 = dil.dilate()
|
||||
d2 = dil.dilate()
|
||||
self.assertNoResult(d1)
|
||||
self.assertNoResult(d2)
|
||||
# we should test the interleavings between:
|
||||
# * application calls w.dilate() and gets back endpoints
|
||||
# * wormhole gets: dilation key, VERSION, 0-n dilation messages
|
||||
|
||||
def test_dilate_first(self):
|
||||
(dil, h) = make_dilator()
|
||||
side = object()
|
||||
m = mock.Mock()
|
||||
eps = object()
|
||||
m.get_endpoints = mock.Mock(return_value=eps)
|
||||
mm = mock.Mock(side_effect=[m])
|
||||
with mock.patch("wormhole._dilation.manager.Manager", mm), \
|
||||
mock.patch("wormhole._dilation.manager.make_side",
|
||||
return_value=side):
|
||||
eps1 = dil.dilate()
|
||||
eps2 = dil.dilate()
|
||||
self.assertIdentical(eps1, eps)
|
||||
self.assertIdentical(eps1, eps2)
|
||||
self.assertEqual(mm.mock_calls, [mock.call(h.send, side, None,
|
||||
h.reactor, h.eq, h.coop, False)])
|
||||
|
||||
self.assertEqual(m.mock_calls, [mock.call.get_endpoints(),
|
||||
mock.call.get_endpoints()])
|
||||
clear_mock_calls(m)
|
||||
|
||||
key = b"key"
|
||||
transit_key = object()
|
||||
|
@ -51,183 +72,108 @@ class TestDilator(unittest.TestCase):
|
|||
return_value=transit_key) as dk:
|
||||
dil.got_key(key)
|
||||
self.assertEqual(dk.mock_calls, [mock.call(key, b"dilation-v1", 32)])
|
||||
self.assertIdentical(dil._transit_key, transit_key)
|
||||
self.assertNoResult(d1)
|
||||
self.assertNoResult(d2)
|
||||
self.assertEqual(m.mock_calls, [mock.call.got_dilation_key(transit_key)])
|
||||
clear_mock_calls(m)
|
||||
|
||||
host_addr = dil._host_addr
|
||||
wv = object()
|
||||
dil.got_wormhole_versions(wv)
|
||||
self.assertEqual(m.mock_calls, [mock.call.got_wormhole_versions(wv)])
|
||||
clear_mock_calls(m)
|
||||
|
||||
peer_addr = object()
|
||||
m_sca = mock.patch("wormhole._dilation.manager._SubchannelAddress",
|
||||
return_value=peer_addr)
|
||||
sc = mock.Mock()
|
||||
m_sc = mock.patch("wormhole._dilation.manager.SubChannel",
|
||||
return_value=sc)
|
||||
scid0 = 0
|
||||
|
||||
m = mock.Mock()
|
||||
alsoProvides(m, IDilationManager)
|
||||
m.when_first_connected.return_value = wfc_d = Deferred()
|
||||
with mock.patch("wormhole._dilation.manager.Manager",
|
||||
return_value=m) as ml:
|
||||
with mock.patch("wormhole._dilation.manager.make_side",
|
||||
return_value="us"):
|
||||
with m_sca, m_sc as m_sc_m:
|
||||
dil.got_wormhole_versions({"can-dilate": ["1"]})
|
||||
# that should create the Manager
|
||||
self.assertEqual(ml.mock_calls, [mock.call(send, "us", transit_key,
|
||||
None, reactor, eq, coop, host_addr, False)])
|
||||
# and create subchannel0
|
||||
self.assertEqual(m_sc_m.mock_calls,
|
||||
[mock.call(scid0, m, host_addr, peer_addr)])
|
||||
# and tell it to start, and get wait-for-it-to-connect Deferred
|
||||
self.assertEqual(m.mock_calls, [mock.call.set_subchannel_zero(scid0, sc),
|
||||
mock.call.start(),
|
||||
mock.call.when_first_connected(),
|
||||
dm1 = object()
|
||||
dm2 = object()
|
||||
dil.received_dilate(dm1)
|
||||
dil.received_dilate(dm2)
|
||||
self.assertEqual(m.mock_calls, [mock.call.received_dilation_message(dm1),
|
||||
mock.call.received_dilation_message(dm2),
|
||||
])
|
||||
clear_mock_calls(m)
|
||||
self.assertNoResult(d1)
|
||||
self.assertNoResult(d2)
|
||||
|
||||
ce = mock.Mock()
|
||||
m_ce = mock.patch("wormhole._dilation.manager.ControlEndpoint",
|
||||
return_value=ce)
|
||||
lep = object()
|
||||
m_sle = mock.patch("wormhole._dilation.manager.SubchannelListenerEndpoint",
|
||||
return_value=lep)
|
||||
stopped_d = mock.Mock()
|
||||
m.when_stopped = mock.Mock(return_value=stopped_d)
|
||||
dil.stop()
|
||||
self.assertEqual(m.mock_calls, [mock.call.stop(),
|
||||
mock.call.when_stopped(),
|
||||
])
|
||||
|
||||
with m_ce as m_ce_m, m_sle as m_sle_m:
|
||||
wfc_d.callback(None)
|
||||
eq.flush_sync()
|
||||
self.assertEqual(m_ce_m.mock_calls, [mock.call(peer_addr)])
|
||||
self.assertEqual(ce.mock_calls, [mock.call._subchannel_zero_opened(sc)])
|
||||
self.assertEqual(m_sle_m.mock_calls, [mock.call(m, host_addr)])
|
||||
self.assertEqual(m.mock_calls,
|
||||
[mock.call.set_listener_endpoint(lep),
|
||||
])
|
||||
def test_dilate_later(self):
|
||||
(dil, h) = make_dilator()
|
||||
m = mock.Mock()
|
||||
mm = mock.Mock(side_effect=[m])
|
||||
|
||||
key = b"key"
|
||||
transit_key = object()
|
||||
with mock.patch("wormhole._dilation.manager.derive_key",
|
||||
return_value=transit_key) as dk:
|
||||
dil.got_key(key)
|
||||
self.assertEqual(dk.mock_calls, [mock.call(key, b"dilation-v1", 32)])
|
||||
|
||||
wv = object()
|
||||
dil.got_wormhole_versions(wv)
|
||||
|
||||
dm1 = object()
|
||||
dil.received_dilate(dm1)
|
||||
|
||||
self.assertEqual(mm.mock_calls, [])
|
||||
|
||||
with mock.patch("wormhole._dilation.manager.Manager", mm):
|
||||
dil.dilate()
|
||||
self.assertEqual(m.mock_calls, [mock.call.got_dilation_key(transit_key),
|
||||
mock.call.got_wormhole_versions(wv),
|
||||
mock.call.received_dilation_message(dm1),
|
||||
mock.call.get_endpoints(),
|
||||
])
|
||||
clear_mock_calls(m)
|
||||
|
||||
eps = self.successResultOf(d1)
|
||||
self.assertEqual(eps, self.successResultOf(d2))
|
||||
d3 = dil.dilate()
|
||||
eq.flush_sync()
|
||||
self.assertEqual(eps, self.successResultOf(d3))
|
||||
dm2 = object()
|
||||
dil.received_dilate(dm2)
|
||||
self.assertEqual(m.mock_calls, [mock.call.received_dilation_message(dm2),
|
||||
])
|
||||
|
||||
# all subsequent DILATE-n messages should get passed to the manager
|
||||
self.assertEqual(m.mock_calls, [])
|
||||
pleasemsg = dict(type="please", side="them")
|
||||
dil.received_dilate(dict_to_bytes(pleasemsg))
|
||||
self.assertEqual(m.mock_calls, [mock.call.rx_PLEASE(pleasemsg)])
|
||||
clear_mock_calls(m)
|
||||
|
||||
hintmsg = dict(type="connection-hints")
|
||||
dil.received_dilate(dict_to_bytes(hintmsg))
|
||||
self.assertEqual(m.mock_calls, [mock.call.rx_HINTS(hintmsg)])
|
||||
clear_mock_calls(m)
|
||||
|
||||
# we're nominally the LEADER, and the leader would not normally be
|
||||
# receiving a RECONNECT, but since we've mocked out the Manager it
|
||||
# won't notice
|
||||
dil.received_dilate(dict_to_bytes(dict(type="reconnect")))
|
||||
self.assertEqual(m.mock_calls, [mock.call.rx_RECONNECT()])
|
||||
clear_mock_calls(m)
|
||||
|
||||
dil.received_dilate(dict_to_bytes(dict(type="reconnecting")))
|
||||
self.assertEqual(m.mock_calls, [mock.call.rx_RECONNECTING()])
|
||||
clear_mock_calls(m)
|
||||
|
||||
dil.received_dilate(dict_to_bytes(dict(type="unknown")))
|
||||
self.assertEqual(m.mock_calls, [])
|
||||
self.flushLoggedErrors(UnknownDilationMessageType)
|
||||
def test_stop_early(self):
|
||||
(dil, h) = make_dilator()
|
||||
# we stop before w.dilate(), so there is no Manager to stop
|
||||
dil.stop()
|
||||
self.assertEqual(h.terminator.mock_calls, [mock.call.stoppedD()])
|
||||
|
||||
def test_peer_cannot_dilate(self):
|
||||
dil, send, reactor, eq, clock, coop = make_dilator()
|
||||
d1 = dil.dilate()
|
||||
self.assertNoResult(d1)
|
||||
(dil, h) = make_dilator()
|
||||
eps = dil.dilate()
|
||||
|
||||
dil._transit_key = b"\x01" * 32
|
||||
dil.got_key(b"\x01" * 32)
|
||||
dil.got_wormhole_versions({}) # missing "can-dilate"
|
||||
eq.flush_sync()
|
||||
f = self.failureResultOf(d1)
|
||||
f.check(OldPeerCannotDilateError)
|
||||
d = eps.connect.connect(None)
|
||||
h.eq.flush_sync()
|
||||
self.failureResultOf(d).check(OldPeerCannotDilateError)
|
||||
|
||||
def test_disjoint_versions(self):
|
||||
dil, send, reactor, eq, clock, coop = make_dilator()
|
||||
d1 = dil.dilate()
|
||||
self.assertNoResult(d1)
|
||||
(dil, h) = make_dilator()
|
||||
eps = dil.dilate()
|
||||
|
||||
dil._transit_key = b"key"
|
||||
dil.got_key(b"\x01" * 32)
|
||||
dil.got_wormhole_versions({"can-dilate": [-1]})
|
||||
eq.flush_sync()
|
||||
f = self.failureResultOf(d1)
|
||||
f.check(OldPeerCannotDilateError)
|
||||
|
||||
def test_early_dilate_messages(self):
|
||||
dil, send, reactor, eq, clock, coop = make_dilator()
|
||||
dil._transit_key = b"key"
|
||||
d1 = dil.dilate()
|
||||
host_addr = dil._host_addr
|
||||
self.assertNoResult(d1)
|
||||
pleasemsg = dict(type="please", side="them")
|
||||
dil.received_dilate(dict_to_bytes(pleasemsg))
|
||||
hintmsg = dict(type="connection-hints")
|
||||
dil.received_dilate(dict_to_bytes(hintmsg))
|
||||
|
||||
m = mock.Mock()
|
||||
alsoProvides(m, IDilationManager)
|
||||
m.when_first_connected.return_value = Deferred()
|
||||
|
||||
scid0 = 0
|
||||
sc = mock.Mock()
|
||||
m_sc = mock.patch("wormhole._dilation.manager.SubChannel",
|
||||
return_value=sc)
|
||||
|
||||
with mock.patch("wormhole._dilation.manager.Manager",
|
||||
return_value=m) as ml:
|
||||
with mock.patch("wormhole._dilation.manager.make_side",
|
||||
return_value="us"):
|
||||
with m_sc:
|
||||
dil.got_wormhole_versions({"can-dilate": ["1"]})
|
||||
self.assertEqual(ml.mock_calls, [mock.call(send, "us", b"key",
|
||||
None, reactor, eq, coop, host_addr, False)])
|
||||
self.assertEqual(m.mock_calls, [mock.call.set_subchannel_zero(scid0, sc),
|
||||
mock.call.start(),
|
||||
mock.call.rx_PLEASE(pleasemsg),
|
||||
mock.call.rx_HINTS(hintmsg),
|
||||
mock.call.when_first_connected()])
|
||||
d = eps.connect.connect(None)
|
||||
h.eq.flush_sync()
|
||||
self.failureResultOf(d).check(OldPeerCannotDilateError)
|
||||
|
||||
def test_transit_relay(self):
|
||||
dil, send, reactor, eq, clock, coop = make_dilator()
|
||||
dil._transit_key = b"key"
|
||||
host_addr = dil._host_addr
|
||||
relay = object()
|
||||
d1 = dil.dilate(transit_relay_location=relay)
|
||||
self.assertNoResult(d1)
|
||||
|
||||
scid0 = 0
|
||||
sc = mock.Mock()
|
||||
m_sc = mock.patch("wormhole._dilation.manager.SubChannel",
|
||||
return_value=sc)
|
||||
|
||||
with mock.patch("wormhole._dilation.manager.Manager") as ml:
|
||||
with mock.patch("wormhole._dilation.manager.make_side",
|
||||
return_value="us"):
|
||||
with m_sc:
|
||||
dil.got_wormhole_versions({"can-dilate": ["1"]})
|
||||
self.assertEqual(ml.mock_calls, [mock.call(send, "us", b"key",
|
||||
relay, reactor, eq, coop, host_addr, False),
|
||||
mock.call().set_subchannel_zero(scid0, sc),
|
||||
mock.call().start(),
|
||||
mock.call().when_first_connected()])
|
||||
|
||||
(dil, h) = make_dilator()
|
||||
transit_relay_location = object()
|
||||
side = object()
|
||||
m = mock.Mock()
|
||||
mm = mock.Mock(side_effect=[m])
|
||||
with mock.patch("wormhole._dilation.manager.Manager", mm), \
|
||||
mock.patch("wormhole._dilation.manager.make_side",
|
||||
return_value=side):
|
||||
dil.dilate(transit_relay_location)
|
||||
self.assertEqual(mm.mock_calls, [mock.call(h.send, side, transit_relay_location,
|
||||
h.reactor, h.eq, h.coop, False)])
|
||||
|
||||
LEADER = "ff3456abcdef"
|
||||
FOLLOWER = "123456abcdef"
|
||||
|
||||
|
||||
def make_manager(leader=True):
|
||||
class Holder:
|
||||
pass
|
||||
h = Holder()
|
||||
h.send = mock.Mock()
|
||||
alsoProvides(h.send, ISend)
|
||||
|
@ -250,11 +196,19 @@ def make_manager(leader=True):
|
|||
h.Inbound = mock.Mock(return_value=h.inbound)
|
||||
h.outbound = mock.Mock()
|
||||
h.Outbound = mock.Mock(return_value=h.outbound)
|
||||
h.hostaddr = mock.Mock()
|
||||
alsoProvides(h.hostaddr, IAddress)
|
||||
with mock.patch("wormhole._dilation.manager.Inbound", h.Inbound):
|
||||
with mock.patch("wormhole._dilation.manager.Outbound", h.Outbound):
|
||||
m = Manager(h.send, side, h.key, h.relay, h.reactor, h.eq, h.coop, h.hostaddr)
|
||||
h.sc0 = mock.Mock()
|
||||
alsoProvides(h.sc0, ISubChannel)
|
||||
h.SubChannel = mock.Mock(return_value=h.sc0)
|
||||
h.listen_ep = mock.Mock()
|
||||
alsoProvides(h.listen_ep, IStreamServerEndpoint)
|
||||
with mock.patch("wormhole._dilation.manager.Inbound", h.Inbound), \
|
||||
mock.patch("wormhole._dilation.manager.Outbound", h.Outbound), \
|
||||
mock.patch("wormhole._dilation.manager.SubChannel", h.SubChannel), \
|
||||
mock.patch("wormhole._dilation.manager.SubchannelListenerEndpoint",
|
||||
return_value=h.listen_ep):
|
||||
m = Manager(h.send, side, h.relay, h.reactor, h.eq, h.coop)
|
||||
h.hostaddr = m._host_addr
|
||||
m.got_dilation_key(h.key)
|
||||
return m, h
|
||||
|
||||
|
||||
|
@ -272,16 +226,26 @@ class TestManager(unittest.TestCase):
|
|||
self.assertEqual(h.send.mock_calls, [])
|
||||
self.assertEqual(h.Inbound.mock_calls, [mock.call(m, h.hostaddr)])
|
||||
self.assertEqual(h.Outbound.mock_calls, [mock.call(m, h.coop)])
|
||||
scid0 = 0
|
||||
sc0_peer_addr = _SubchannelAddress(scid0)
|
||||
self.assertEqual(h.SubChannel.mock_calls, [
|
||||
mock.call(scid0, m, m._host_addr, sc0_peer_addr),
|
||||
])
|
||||
self.assertEqual(h.inbound.mock_calls, [
|
||||
mock.call.set_subchannel_zero(scid0, h.sc0),
|
||||
mock.call.set_listener_endpoint(h.listen_ep)
|
||||
])
|
||||
clear_mock_calls(h.inbound)
|
||||
|
||||
m.start()
|
||||
m.got_wormhole_versions({"can-dilate": ["1"]})
|
||||
self.assertEqual(h.send.mock_calls, [
|
||||
mock.call.send("dilate-0",
|
||||
dict_to_bytes({"type": "please", "side": LEADER}))
|
||||
])
|
||||
clear_mock_calls(h.send)
|
||||
|
||||
wfc_d = m.when_first_connected()
|
||||
self.assertNoResult(wfc_d)
|
||||
listen_d = m.get_endpoints().listen.listen(None)
|
||||
self.assertNoResult(listen_d)
|
||||
|
||||
# ignore early hints
|
||||
m.rx_HINTS({})
|
||||
|
@ -303,7 +267,7 @@ class TestManager(unittest.TestCase):
|
|||
self.assertEqual(c.mock_calls, [mock.call.start()])
|
||||
clear_mock_calls(connector, c)
|
||||
|
||||
self.assertNoResult(wfc_d)
|
||||
self.assertNoResult(listen_d)
|
||||
|
||||
# now any inbound hints should get passed to our Connector
|
||||
with mock.patch("wormhole._dilation.manager.parse_hint",
|
||||
|
@ -323,7 +287,7 @@ class TestManager(unittest.TestCase):
|
|||
clear_mock_calls(h.send)
|
||||
|
||||
# the first successful connection fires when_first_connected(), so
|
||||
# the Dilator can create and return the endpoints
|
||||
# the endpoints can activate
|
||||
c1 = mock.Mock()
|
||||
m.connector_connection_made(c1)
|
||||
|
||||
|
@ -331,23 +295,6 @@ class TestManager(unittest.TestCase):
|
|||
self.assertEqual(h.outbound.mock_calls, [mock.call.use_connection(c1)])
|
||||
clear_mock_calls(h.inbound, h.outbound)
|
||||
|
||||
h.eq.flush_sync()
|
||||
self.successResultOf(wfc_d) # fires with None
|
||||
wfc_d2 = m.when_first_connected()
|
||||
h.eq.flush_sync()
|
||||
self.successResultOf(wfc_d2)
|
||||
|
||||
scid0 = 0
|
||||
sc0 = mock.Mock()
|
||||
m.set_subchannel_zero(scid0, sc0)
|
||||
listen_ep = mock.Mock()
|
||||
m.set_listener_endpoint(listen_ep)
|
||||
self.assertEqual(h.inbound.mock_calls, [
|
||||
mock.call.set_subchannel_zero(scid0, sc0),
|
||||
mock.call.set_listener_endpoint(listen_ep),
|
||||
])
|
||||
clear_mock_calls(h.inbound)
|
||||
|
||||
# the Leader making a new outbound channel should get scid=1
|
||||
scid1 = 1
|
||||
self.assertEqual(m.allocate_subchannel_id(), scid1)
|
||||
|
@ -494,12 +441,13 @@ class TestManager(unittest.TestCase):
|
|||
def test_follower(self):
|
||||
m, h = make_manager(leader=False)
|
||||
|
||||
m.start()
|
||||
m.got_wormhole_versions({"can-dilate": ["1"]})
|
||||
self.assertEqual(h.send.mock_calls, [
|
||||
mock.call.send("dilate-0",
|
||||
dict_to_bytes({"type": "please", "side": FOLLOWER}))
|
||||
])
|
||||
clear_mock_calls(h.send)
|
||||
clear_mock_calls(h.inbound)
|
||||
|
||||
c = mock.Mock()
|
||||
connector = mock.Mock(return_value=c)
|
||||
|
@ -629,6 +577,7 @@ class TestManager(unittest.TestCase):
|
|||
|
||||
def test_subchannel(self):
|
||||
m, h = make_manager(leader=True)
|
||||
clear_mock_calls(h.inbound)
|
||||
sc = object()
|
||||
|
||||
m.subchannel_pauseProducing(sc)
|
||||
|
@ -665,3 +614,14 @@ class TestManager(unittest.TestCase):
|
|||
self.assertEqual(h.outbound.mock_calls, [
|
||||
mock.call.subchannel_closed(4, sc)])
|
||||
clear_mock_calls(h.inbound, h.outbound)
|
||||
|
||||
def test_unknown_message(self):
|
||||
# receive a PLEASE with the same side as us: shouldn't happen
|
||||
m, h = make_manager(leader=True)
|
||||
m.start()
|
||||
|
||||
m.received_dilation_message(dict_to_bytes(dict(type="unknown")))
|
||||
self.flushLoggedErrors(UnknownDilationMessageType)
|
||||
|
||||
# TODO: test transit relay is used
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user