WIP: rewrite w.dilate API to return endpoints synchronously
test_manager still needs rewriting
This commit is contained in:
parent
75fad02a28
commit
85cb003498
|
@ -2,13 +2,20 @@ from __future__ import print_function, unicode_literals
|
||||||
import six
|
import six
|
||||||
import os
|
import os
|
||||||
from collections import deque
|
from collections import deque
|
||||||
|
try:
|
||||||
|
# py >= 3.3
|
||||||
|
from collections.abc import Sequence
|
||||||
|
except ImportError:
|
||||||
|
# py 2 and py3 < 3.3
|
||||||
|
from collections import Sequence
|
||||||
from attr import attrs, attrib
|
from attr import attrs, attrib
|
||||||
from attr.validators import provides, instance_of, optional
|
from attr.validators import provides, instance_of, optional
|
||||||
from automat import MethodicalMachine
|
from automat import MethodicalMachine
|
||||||
from zope.interface import implementer
|
from zope.interface import implementer
|
||||||
from twisted.internet.defer import Deferred, inlineCallbacks, returnValue
|
from twisted.internet.defer import Deferred
|
||||||
from twisted.internet.interfaces import IAddress
|
from twisted.internet.interfaces import (IStreamClientEndpoint,
|
||||||
from twisted.python import log
|
IStreamServerEndpoint)
|
||||||
|
from twisted.python import log, failure
|
||||||
from .._interfaces import IDilator, IDilationManager, ISend, ITerminator
|
from .._interfaces import IDilator, IDilationManager, ISend, ITerminator
|
||||||
from ..util import dict_to_bytes, bytes_to_dict, bytes_to_hexstr
|
from ..util import dict_to_bytes, bytes_to_dict, bytes_to_hexstr
|
||||||
from ..observer import OneShotObserver
|
from ..observer import OneShotObserver
|
||||||
|
@ -47,6 +54,15 @@ class UnexpectedKCM(Exception):
|
||||||
class UnknownMessageType(Exception):
|
class UnknownMessageType(Exception):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@attrs
|
||||||
|
class EndpointRecord(Sequence):
|
||||||
|
control = attrib(validator=provides(IStreamClientEndpoint))
|
||||||
|
connect = attrib(validator=provides(IStreamClientEndpoint))
|
||||||
|
listen = attrib(validator=provides(IStreamServerEndpoint))
|
||||||
|
def __len__(self):
|
||||||
|
return 3
|
||||||
|
def __getitem__(self, n):
|
||||||
|
return (self.control, self.connect, self.listen)[n]
|
||||||
|
|
||||||
def make_side():
|
def make_side():
|
||||||
return bytes_to_hexstr(os.urandom(6))
|
return bytes_to_hexstr(os.urandom(6))
|
||||||
|
@ -93,13 +109,14 @@ def make_side():
|
||||||
class Manager(object):
|
class Manager(object):
|
||||||
_S = attrib(validator=provides(ISend), repr=False)
|
_S = attrib(validator=provides(ISend), repr=False)
|
||||||
_my_side = attrib(validator=instance_of(type(u"")))
|
_my_side = attrib(validator=instance_of(type(u"")))
|
||||||
_transit_key = attrib(validator=instance_of(bytes), repr=False)
|
|
||||||
_transit_relay_location = attrib(validator=optional(instance_of(str)))
|
_transit_relay_location = attrib(validator=optional(instance_of(str)))
|
||||||
_reactor = attrib(repr=False)
|
_reactor = attrib(repr=False)
|
||||||
_eventual_queue = attrib(repr=False)
|
_eventual_queue = attrib(repr=False)
|
||||||
_cooperator = attrib(repr=False)
|
_cooperator = attrib(repr=False)
|
||||||
_host_addr = attrib(validator=provides(IAddress))
|
# TODO: can this validator work when the parameter is optional?
|
||||||
_no_listen = attrib(default=False)
|
_no_listen = attrib(validator=instance_of(bool), default=False)
|
||||||
|
|
||||||
|
_dilation_key = None
|
||||||
_tor = None # TODO
|
_tor = None # TODO
|
||||||
_timing = None # TODO
|
_timing = None # TODO
|
||||||
_next_subchannel_id = None # initialized in choose_role
|
_next_subchannel_id = None # initialized in choose_role
|
||||||
|
@ -111,10 +128,10 @@ class Manager(object):
|
||||||
self._got_versions_d = Deferred()
|
self._got_versions_d = Deferred()
|
||||||
|
|
||||||
self._my_role = None # determined upon rx_PLEASE
|
self._my_role = None # determined upon rx_PLEASE
|
||||||
|
self._host_addr = _WormholeAddress()
|
||||||
|
|
||||||
self._connection = None
|
self._connection = None
|
||||||
self._made_first_connection = False
|
self._made_first_connection = False
|
||||||
self._first_connected = OneShotObserver(self._eventual_queue)
|
|
||||||
self._stopped = OneShotObserver(self._eventual_queue)
|
self._stopped = OneShotObserver(self._eventual_queue)
|
||||||
self._debug_stall_connector = False
|
self._debug_stall_connector = False
|
||||||
|
|
||||||
|
@ -127,18 +144,81 @@ class Manager(object):
|
||||||
self._inbound = Inbound(self, self._host_addr)
|
self._inbound = Inbound(self, self._host_addr)
|
||||||
self._outbound = Outbound(self, self._cooperator) # from us to peer
|
self._outbound = Outbound(self, self._cooperator) # from us to peer
|
||||||
|
|
||||||
def set_listener_endpoint(self, listener_endpoint):
|
# We must open subchannel0 early, since messages may arrive very
|
||||||
self._inbound.set_listener_endpoint(listener_endpoint)
|
# quickly once the connection is established. This subchannel may or
|
||||||
|
# may not ever get revealed to the caller, since the peer might not
|
||||||
def set_subchannel_zero(self, scid0, sc0):
|
# even be capable of dilation.
|
||||||
|
scid0 = 0
|
||||||
|
peer_addr0 = _SubchannelAddress(scid0)
|
||||||
|
sc0 = SubChannel(scid0, self, self._host_addr, peer_addr0)
|
||||||
self._inbound.set_subchannel_zero(scid0, sc0)
|
self._inbound.set_subchannel_zero(scid0, sc0)
|
||||||
|
|
||||||
def when_first_connected(self):
|
# we can open non-zero subchannels as soon as we get our first
|
||||||
return self._first_connected.when_fired()
|
# connection, and we can make the Endpoints even earlier
|
||||||
|
control_ep = ControlEndpoint(peer_addr0, sc0, self._eventual_queue)
|
||||||
|
connect_ep = SubchannelConnectorEndpoint(self, self._host_addr, self._eventual_queue)
|
||||||
|
listen_ep = SubchannelListenerEndpoint(self, self._host_addr, self._eventual_queue)
|
||||||
|
# TODO: let inbound/outbound create the endpoints, then return them
|
||||||
|
# to us
|
||||||
|
self._inbound.set_listener_endpoint(listen_ep)
|
||||||
|
|
||||||
|
self._endpoints = EndpointRecord(control_ep, connect_ep, listen_ep)
|
||||||
|
|
||||||
|
def get_endpoints(self):
|
||||||
|
return self._endpoints
|
||||||
|
|
||||||
|
def got_dilation_key(self, key):
|
||||||
|
assert isinstance(key, bytes)
|
||||||
|
self._dilation_key = key
|
||||||
|
|
||||||
|
def got_wormhole_versions(self, their_wormhole_versions):
|
||||||
|
# this always happens before received_dilation_message
|
||||||
|
dilation_version = None
|
||||||
|
their_dilation_versions = set(their_wormhole_versions.get("can-dilate", []))
|
||||||
|
my_versions = set(DILATION_VERSIONS)
|
||||||
|
shared_versions = my_versions.intersection(their_dilation_versions)
|
||||||
|
if "1" in shared_versions:
|
||||||
|
dilation_version = "1"
|
||||||
|
|
||||||
|
# dilation_version is the best mutually-compatible version we have
|
||||||
|
# with the peer, or None if we have nothing in common
|
||||||
|
|
||||||
|
if not dilation_version: # "1" or None
|
||||||
|
# TODO: be more specific about the error. dilation_version==None
|
||||||
|
# means we had no version in common with them, which could either
|
||||||
|
# be because they're so old they don't dilate at all, or because
|
||||||
|
# they're so new that they no longer accomodate our old version
|
||||||
|
self.fail(failure.Failure(OldPeerCannotDilateError()))
|
||||||
|
|
||||||
|
self.start()
|
||||||
|
|
||||||
|
def fail(self, f):
|
||||||
|
self._endpoints.control._main_channel_failed(f)
|
||||||
|
self._endpoints.connect._main_channel_failed(f)
|
||||||
|
self._endpoints.listen._main_channel_failed(f)
|
||||||
|
|
||||||
|
def received_dilation_message(self, plaintext):
|
||||||
|
# this receives new in-order DILATE-n payloads, decrypted but not
|
||||||
|
# de-JSONed.
|
||||||
|
|
||||||
|
message = bytes_to_dict(plaintext)
|
||||||
|
type = message["type"]
|
||||||
|
if type == "please":
|
||||||
|
self.rx_PLEASE(message)
|
||||||
|
elif type == "connection-hints":
|
||||||
|
self.rx_HINTS(message)
|
||||||
|
elif type == "reconnect":
|
||||||
|
self.rx_RECONNECT()
|
||||||
|
elif type == "reconnecting":
|
||||||
|
self.rx_RECONNECTING()
|
||||||
|
else:
|
||||||
|
log.err(UnknownDilationMessageType(message))
|
||||||
|
return
|
||||||
|
|
||||||
def when_stopped(self):
|
def when_stopped(self):
|
||||||
return self._stopped.when_fired()
|
return self._stopped.when_fired()
|
||||||
|
|
||||||
|
|
||||||
def send_dilation_phase(self, **fields):
|
def send_dilation_phase(self, **fields):
|
||||||
dilation_phase = self._next_dilation_phase
|
dilation_phase = self._next_dilation_phase
|
||||||
self._next_dilation_phase += 1
|
self._next_dilation_phase += 1
|
||||||
|
@ -204,7 +284,9 @@ class Manager(object):
|
||||||
self._outbound.use_connection(c) # does c.registerProducer
|
self._outbound.use_connection(c) # does c.registerProducer
|
||||||
if not self._made_first_connection:
|
if not self._made_first_connection:
|
||||||
self._made_first_connection = True
|
self._made_first_connection = True
|
||||||
self._first_connected.fire(None)
|
self._endpoints.control._main_channel_ready()
|
||||||
|
self._endpoints.connect._main_channel_ready()
|
||||||
|
self._endpoints.listen._main_channel_ready()
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def connector_connection_lost(self):
|
def connector_connection_lost(self):
|
||||||
|
@ -272,16 +354,11 @@ class Manager(object):
|
||||||
|
|
||||||
# state machine
|
# state machine
|
||||||
|
|
||||||
# We are born WANTING after the local app calls w.dilate(). We start
|
|
||||||
# CONNECTING when we receive PLEASE from the remote side
|
|
||||||
|
|
||||||
def start(self):
|
|
||||||
self.send_please()
|
|
||||||
|
|
||||||
def send_please(self):
|
|
||||||
self.send_dilation_phase(type="please", side=self._my_side)
|
|
||||||
|
|
||||||
@m.state(initial=True)
|
@m.state(initial=True)
|
||||||
|
def WAITING(self):
|
||||||
|
pass # pragma: no cover
|
||||||
|
|
||||||
|
@m.state()
|
||||||
def WANTING(self):
|
def WANTING(self):
|
||||||
pass # pragma: no cover
|
pass # pragma: no cover
|
||||||
|
|
||||||
|
@ -313,6 +390,10 @@ class Manager(object):
|
||||||
def STOPPED(self):
|
def STOPPED(self):
|
||||||
pass # pragma: no cover
|
pass # pragma: no cover
|
||||||
|
|
||||||
|
@m.input()
|
||||||
|
def start(self):
|
||||||
|
pass # pragma: no cover
|
||||||
|
|
||||||
@m.input()
|
@m.input()
|
||||||
def rx_PLEASE(self, message):
|
def rx_PLEASE(self, message):
|
||||||
pass # pragma: no cover
|
pass # pragma: no cover
|
||||||
|
@ -350,6 +431,10 @@ class Manager(object):
|
||||||
def stop(self):
|
def stop(self):
|
||||||
pass # pragma: no cover
|
pass # pragma: no cover
|
||||||
|
|
||||||
|
@m.output()
|
||||||
|
def send_please(self):
|
||||||
|
self.send_dilation_phase(type="please", side=self._my_side)
|
||||||
|
|
||||||
@m.output()
|
@m.output()
|
||||||
def choose_role(self, message):
|
def choose_role(self, message):
|
||||||
their_side = message["side"]
|
their_side = message["side"]
|
||||||
|
@ -378,7 +463,8 @@ class Manager(object):
|
||||||
|
|
||||||
def _start_connecting(self):
|
def _start_connecting(self):
|
||||||
assert self._my_role is not None
|
assert self._my_role is not None
|
||||||
self._connector = Connector(self._transit_key,
|
assert self._dilation_key is not None
|
||||||
|
self._connector = Connector(self._dilation_key,
|
||||||
self._transit_relay_location,
|
self._transit_relay_location,
|
||||||
self,
|
self,
|
||||||
self._reactor, self._eventual_queue,
|
self._reactor, self._eventual_queue,
|
||||||
|
@ -422,6 +508,11 @@ class Manager(object):
|
||||||
def notify_stopped(self):
|
def notify_stopped(self):
|
||||||
self._stopped.fire(None)
|
self._stopped.fire(None)
|
||||||
|
|
||||||
|
# We are born WAITING after the local app calls w.dilate(). We enter
|
||||||
|
# WANTING (and send a PLEASE) when we learn of a mutually-compatible
|
||||||
|
# dilation_version.
|
||||||
|
WAITING.upon(start, enter=WANTING, outputs=[send_please])
|
||||||
|
|
||||||
# we start CONNECTING when we get rx_PLEASE
|
# we start CONNECTING when we get rx_PLEASE
|
||||||
WANTING.upon(rx_PLEASE, enter=CONNECTING,
|
WANTING.upon(rx_PLEASE, enter=CONNECTING,
|
||||||
outputs=[choose_role, start_connecting_ignore_message])
|
outputs=[choose_role, start_connecting_ignore_message])
|
||||||
|
@ -489,12 +580,10 @@ class Dilator(object):
|
||||||
_cooperator = attrib()
|
_cooperator = attrib()
|
||||||
|
|
||||||
def __attrs_post_init__(self):
|
def __attrs_post_init__(self):
|
||||||
self._got_versions_d = Deferred()
|
|
||||||
self._started = False
|
|
||||||
self._endpoints = OneShotObserver(self._eventual_queue)
|
|
||||||
self._pending_inbound_dilate_messages = deque()
|
|
||||||
self._manager = None
|
self._manager = None
|
||||||
self._host_addr = _WormholeAddress()
|
self._pending_dilation_key = None
|
||||||
|
self._pending_wormhole_versions = None
|
||||||
|
self._pending_inbound_dilate_messages = deque()
|
||||||
|
|
||||||
def wire(self, sender, terminator):
|
def wire(self, sender, terminator):
|
||||||
self._S = ISend(sender)
|
self._S = ISend(sender)
|
||||||
|
@ -502,77 +591,35 @@ class Dilator(object):
|
||||||
|
|
||||||
# this is the primary entry point, called when w.dilate() is invoked
|
# this is the primary entry point, called when w.dilate() is invoked
|
||||||
def dilate(self, transit_relay_location=None, no_listen=False):
|
def dilate(self, transit_relay_location=None, no_listen=False):
|
||||||
self._transit_relay_location = transit_relay_location
|
if not self._manager:
|
||||||
if not self._started:
|
# build the manager right away, and tell it later when the
|
||||||
self._started = True
|
# VERSIONS message arrives, and also when the dilation_key is set
|
||||||
self._start(no_listen).addBoth(self._endpoints.fire)
|
|
||||||
return self._endpoints.when_fired()
|
|
||||||
|
|
||||||
@inlineCallbacks
|
|
||||||
def _start(self, no_listen):
|
|
||||||
# first, we wait until we hear the VERSION message, which tells us 1:
|
|
||||||
# the PAKE key works, so we can talk securely, 2: that they can do
|
|
||||||
# dilation at all (if they can't then w.dilate() errbacks)
|
|
||||||
|
|
||||||
dilation_version = yield self._got_versions_d
|
|
||||||
|
|
||||||
# TODO: we could probably return the endpoints earlier, if we flunk
|
|
||||||
# any connection/listen attempts upon OldPeerCannotDilateError, or
|
|
||||||
# if/when we give up on the initial connection
|
|
||||||
|
|
||||||
if not dilation_version: # "1" or None
|
|
||||||
# TODO: be more specific about the error. dilation_version==None
|
|
||||||
# means we had no version in common with them, which could either
|
|
||||||
# be because they're so old they don't dilate at all, or because
|
|
||||||
# they're so new that they no longer accomodate our old version
|
|
||||||
raise OldPeerCannotDilateError()
|
|
||||||
|
|
||||||
my_dilation_side = make_side()
|
my_dilation_side = make_side()
|
||||||
self._manager = Manager(self._S, my_dilation_side,
|
m = Manager(self._S, my_dilation_side,
|
||||||
self._transit_key,
|
transit_relay_location,
|
||||||
self._transit_relay_location,
|
|
||||||
self._reactor, self._eventual_queue,
|
self._reactor, self._eventual_queue,
|
||||||
self._cooperator, self._host_addr, no_listen)
|
self._cooperator, no_listen)
|
||||||
# We must open subchannel0 early, since messages may arrive very
|
self._manager = m
|
||||||
# quickly once the connection is established. This subchannel may or
|
if self._pending_dilation_key is not None:
|
||||||
# may not ever get revealed to the caller, since the peer might not
|
m.got_dilation_key(self._pending_dilation_key)
|
||||||
# even be capable of dilation.
|
if self._pending_wormhole_versions:
|
||||||
scid0 = 0
|
m.got_wormhole_versions(self._pending_wormhole_versions)
|
||||||
peer_addr0 = _SubchannelAddress(scid0)
|
|
||||||
sc0 = SubChannel(scid0, self._manager, self._host_addr, peer_addr0)
|
|
||||||
self._manager.set_subchannel_zero(scid0, sc0)
|
|
||||||
|
|
||||||
# we can open non-zero subchannels as soon as we get our first
|
|
||||||
# connection, and we can make the Endpoints even earlier
|
|
||||||
control_ep = ControlEndpoint(peer_addr0)
|
|
||||||
control_ep._subchannel_zero_opened(sc0)
|
|
||||||
connect_ep = SubchannelConnectorEndpoint(self._manager, self._host_addr)
|
|
||||||
|
|
||||||
listen_ep = SubchannelListenerEndpoint(self._manager, self._host_addr)
|
|
||||||
self._manager.set_listener_endpoint(listen_ep)
|
|
||||||
|
|
||||||
self._manager.start()
|
|
||||||
|
|
||||||
while self._pending_inbound_dilate_messages:
|
while self._pending_inbound_dilate_messages:
|
||||||
plaintext = self._pending_inbound_dilate_messages.popleft()
|
plaintext = self._pending_inbound_dilate_messages.popleft()
|
||||||
self.received_dilate(plaintext)
|
m.received_dilation_message(plaintext)
|
||||||
|
return self._manager.get_endpoints()
|
||||||
yield self._manager.when_first_connected()
|
|
||||||
|
|
||||||
endpoints = (control_ep, connect_ep, listen_ep)
|
|
||||||
returnValue(endpoints)
|
|
||||||
|
|
||||||
# Called by Terminator after everything else (mailbox, nameplate, server
|
# Called by Terminator after everything else (mailbox, nameplate, server
|
||||||
# connection) has shut down. Expects to fire T.stoppedD() when Dilator is
|
# connection) has shut down. Expects to fire T.stoppedD() when Dilator is
|
||||||
# stopped too.
|
# stopped too.
|
||||||
def stop(self):
|
def stop(self):
|
||||||
if not self._started:
|
if self._manager:
|
||||||
self._T.stoppedD()
|
|
||||||
return
|
|
||||||
if self._started:
|
|
||||||
self._manager.stop()
|
self._manager.stop()
|
||||||
# TODO: avoid Deferreds for control flow, hard to serialize
|
# TODO: avoid Deferreds for control flow, hard to serialize
|
||||||
self._manager.when_stopped().addCallback(lambda _: self._T.stoppedD())
|
self._manager.when_stopped().addCallback(lambda _: self._T.stoppedD())
|
||||||
|
else:
|
||||||
|
self._T.stoppedD()
|
||||||
|
return
|
||||||
# TODO: tolerate multiple calls
|
# TODO: tolerate multiple calls
|
||||||
|
|
||||||
# from Boss
|
# from Boss
|
||||||
|
@ -582,39 +629,20 @@ class Dilator(object):
|
||||||
# to tolerate either ordering
|
# to tolerate either ordering
|
||||||
purpose = b"dilation-v1"
|
purpose = b"dilation-v1"
|
||||||
LENGTH = 32 # TODO: whatever Noise wants, I guess
|
LENGTH = 32 # TODO: whatever Noise wants, I guess
|
||||||
self._transit_key = derive_key(key, purpose, LENGTH)
|
dilation_key = derive_key(key, purpose, LENGTH)
|
||||||
|
if self._manager:
|
||||||
|
self._manager.got_dilation_key(dilation_key)
|
||||||
|
else:
|
||||||
|
self._pending_dilation_key = dilation_key
|
||||||
|
|
||||||
def got_wormhole_versions(self, their_wormhole_versions):
|
def got_wormhole_versions(self, their_wormhole_versions):
|
||||||
assert self._transit_key is not None
|
if self._manager:
|
||||||
# this always happens before received_dilate
|
self._manager.got_wormhole_versions(their_wormhole_versions)
|
||||||
dilation_version = None
|
else:
|
||||||
their_dilation_versions = set(their_wormhole_versions.get("can-dilate", []))
|
self._pending_wormhole_versions = their_wormhole_versions
|
||||||
my_versions = set(DILATION_VERSIONS)
|
|
||||||
shared_versions = my_versions.intersection(their_dilation_versions)
|
|
||||||
if "1" in shared_versions:
|
|
||||||
dilation_version = "1"
|
|
||||||
self._got_versions_d.callback(dilation_version)
|
|
||||||
|
|
||||||
def received_dilate(self, plaintext):
|
def received_dilate(self, plaintext):
|
||||||
# this receives new in-order DILATE-n payloads, decrypted but not
|
|
||||||
# de-JSONed.
|
|
||||||
|
|
||||||
# this can appear before our .dilate() method is called, in which case
|
|
||||||
# we queue them for later
|
|
||||||
if not self._manager:
|
if not self._manager:
|
||||||
self._pending_inbound_dilate_messages.append(plaintext)
|
self._pending_inbound_dilate_messages.append(plaintext)
|
||||||
return
|
|
||||||
|
|
||||||
message = bytes_to_dict(plaintext)
|
|
||||||
type = message["type"]
|
|
||||||
if type == "please":
|
|
||||||
self._manager.rx_PLEASE(message)
|
|
||||||
elif type == "connection-hints":
|
|
||||||
self._manager.rx_HINTS(message)
|
|
||||||
elif type == "reconnect":
|
|
||||||
self._manager.rx_RECONNECT()
|
|
||||||
elif type == "reconnecting":
|
|
||||||
self._manager.rx_RECONNECTING()
|
|
||||||
else:
|
else:
|
||||||
log.err(UnknownDilationMessageType(message))
|
self._manager.received_dilation_message(plaintext)
|
||||||
return
|
|
||||||
|
|
|
@ -3,8 +3,7 @@ from collections import deque
|
||||||
from attr import attrs, attrib
|
from attr import attrs, attrib
|
||||||
from attr.validators import instance_of, provides
|
from attr.validators import instance_of, provides
|
||||||
from zope.interface import implementer
|
from zope.interface import implementer
|
||||||
from twisted.internet.defer import (Deferred, inlineCallbacks, returnValue,
|
from twisted.internet.defer import inlineCallbacks, returnValue
|
||||||
succeed)
|
|
||||||
from twisted.internet.interfaces import (ITransport, IProducer, IConsumer,
|
from twisted.internet.interfaces import (ITransport, IProducer, IConsumer,
|
||||||
IAddress, IListeningPort,
|
IAddress, IListeningPort,
|
||||||
IStreamClientEndpoint,
|
IStreamClientEndpoint,
|
||||||
|
@ -12,6 +11,7 @@ from twisted.internet.interfaces import (ITransport, IProducer, IConsumer,
|
||||||
from twisted.internet.error import ConnectionDone
|
from twisted.internet.error import ConnectionDone
|
||||||
from automat import MethodicalMachine
|
from automat import MethodicalMachine
|
||||||
from .._interfaces import ISubChannel, IDilationManager
|
from .._interfaces import ISubChannel, IDilationManager
|
||||||
|
from ..observer import OneShotObserver
|
||||||
|
|
||||||
# each subchannel frame (the data passed into transport.write(data)) gets a
|
# each subchannel frame (the data passed into transport.write(data)) gets a
|
||||||
# 9-byte header prefix (type, subchannel id, and sequence number), then gets
|
# 9-byte header prefix (type, subchannel id, and sequence number), then gets
|
||||||
|
@ -217,27 +217,33 @@ class SubChannel(object):
|
||||||
|
|
||||||
|
|
||||||
@implementer(IStreamClientEndpoint)
|
@implementer(IStreamClientEndpoint)
|
||||||
|
@attrs
|
||||||
class ControlEndpoint(object):
|
class ControlEndpoint(object):
|
||||||
|
_peer_addr = attrib(validator=provides(IAddress))
|
||||||
|
_subchannel_zero = attrib(validator=provides(ISubChannel))
|
||||||
|
_eventual_queue = attrib(repr=False)
|
||||||
_used = False
|
_used = False
|
||||||
|
|
||||||
def __init__(self, peer_addr):
|
def __attrs_post_init__(self):
|
||||||
self._subchannel_zero = Deferred()
|
|
||||||
self._peer_addr = peer_addr
|
|
||||||
self._once = Once(SingleUseEndpointError)
|
self._once = Once(SingleUseEndpointError)
|
||||||
|
self._wait_for_main_channel = OneShotObserver(self._eventual_queue)
|
||||||
|
|
||||||
# from manager
|
# from manager
|
||||||
def _subchannel_zero_opened(self, subchannel):
|
|
||||||
assert ISubChannel.providedBy(subchannel), subchannel
|
def _main_channel_ready(self):
|
||||||
self._subchannel_zero.callback(subchannel)
|
self._wait_for_main_channel.fire(None)
|
||||||
|
def _main_channel_failed(self, f):
|
||||||
|
self._wait_for_main_channel.error(f)
|
||||||
|
|
||||||
@inlineCallbacks
|
@inlineCallbacks
|
||||||
def connect(self, protocolFactory):
|
def connect(self, protocolFactory):
|
||||||
# return Deferred that fires with IProtocol or Failure(ConnectError)
|
# return Deferred that fires with IProtocol or Failure(ConnectError)
|
||||||
self._once()
|
self._once()
|
||||||
t = yield self._subchannel_zero
|
yield self._wait_for_main_channel.when_fired()
|
||||||
p = protocolFactory.buildProtocol(self._peer_addr)
|
p = protocolFactory.buildProtocol(self._peer_addr)
|
||||||
t._set_protocol(p)
|
self._subchannel_zero._set_protocol(p)
|
||||||
p.makeConnection(t) # set p.transport = t and call connectionMade()
|
# this sets p.transport and calls p.connectionMade()
|
||||||
|
p.makeConnection(self._subchannel_zero)
|
||||||
returnValue(p)
|
returnValue(p)
|
||||||
|
|
||||||
|
|
||||||
|
@ -246,9 +252,21 @@ class ControlEndpoint(object):
|
||||||
class SubchannelConnectorEndpoint(object):
|
class SubchannelConnectorEndpoint(object):
|
||||||
_manager = attrib(validator=provides(IDilationManager))
|
_manager = attrib(validator=provides(IDilationManager))
|
||||||
_host_addr = attrib(validator=instance_of(_WormholeAddress))
|
_host_addr = attrib(validator=instance_of(_WormholeAddress))
|
||||||
|
_eventual_queue = attrib(repr=False)
|
||||||
|
|
||||||
|
def __attrs_post_init__(self):
|
||||||
|
self._connection_deferreds = deque()
|
||||||
|
self._wait_for_main_channel = OneShotObserver(self._eventual_queue)
|
||||||
|
|
||||||
|
def _main_channel_ready(self):
|
||||||
|
self._wait_for_main_channel.fire(None)
|
||||||
|
def _main_channel_failed(self, f):
|
||||||
|
self._wait_for_main_channel.error(f)
|
||||||
|
|
||||||
|
@inlineCallbacks
|
||||||
def connect(self, protocolFactory):
|
def connect(self, protocolFactory):
|
||||||
# return Deferred that fires with IProtocol or Failure(ConnectError)
|
# return Deferred that fires with IProtocol or Failure(ConnectError)
|
||||||
|
yield self._wait_for_main_channel.when_fired()
|
||||||
scid = self._manager.allocate_subchannel_id()
|
scid = self._manager.allocate_subchannel_id()
|
||||||
self._manager.send_open(scid)
|
self._manager.send_open(scid)
|
||||||
peer_addr = _SubchannelAddress(scid)
|
peer_addr = _SubchannelAddress(scid)
|
||||||
|
@ -259,7 +277,7 @@ class SubchannelConnectorEndpoint(object):
|
||||||
p = protocolFactory.buildProtocol(peer_addr)
|
p = protocolFactory.buildProtocol(peer_addr)
|
||||||
sc._set_protocol(p)
|
sc._set_protocol(p)
|
||||||
p.makeConnection(sc) # set p.transport = sc and call connectionMade()
|
p.makeConnection(sc) # set p.transport = sc and call connectionMade()
|
||||||
return succeed(p)
|
returnValue(p)
|
||||||
|
|
||||||
|
|
||||||
@implementer(IStreamServerEndpoint)
|
@implementer(IStreamServerEndpoint)
|
||||||
|
@ -267,10 +285,13 @@ class SubchannelConnectorEndpoint(object):
|
||||||
class SubchannelListenerEndpoint(object):
|
class SubchannelListenerEndpoint(object):
|
||||||
_manager = attrib(validator=provides(IDilationManager))
|
_manager = attrib(validator=provides(IDilationManager))
|
||||||
_host_addr = attrib(validator=provides(IAddress))
|
_host_addr = attrib(validator=provides(IAddress))
|
||||||
|
_eventual_queue = attrib(repr=False)
|
||||||
|
|
||||||
def __attrs_post_init__(self):
|
def __attrs_post_init__(self):
|
||||||
|
self._once = Once(SingleUseEndpointError)
|
||||||
self._factory = None
|
self._factory = None
|
||||||
self._pending_opens = deque()
|
self._pending_opens = deque()
|
||||||
|
self._wait_for_main_channel = OneShotObserver(self._eventual_queue)
|
||||||
|
|
||||||
# from manager (actually Inbound)
|
# from manager (actually Inbound)
|
||||||
def _got_open(self, t, peer_addr):
|
def _got_open(self, t, peer_addr):
|
||||||
|
@ -284,15 +305,23 @@ class SubchannelListenerEndpoint(object):
|
||||||
t._set_protocol(p)
|
t._set_protocol(p)
|
||||||
p.makeConnection(t)
|
p.makeConnection(t)
|
||||||
|
|
||||||
|
def _main_channel_ready(self):
|
||||||
|
self._wait_for_main_channel.fire(None)
|
||||||
|
def _main_channel_failed(self, f):
|
||||||
|
self._wait_for_main_channel.error(f)
|
||||||
|
|
||||||
# IStreamServerEndpoint
|
# IStreamServerEndpoint
|
||||||
|
|
||||||
|
@inlineCallbacks
|
||||||
def listen(self, protocolFactory):
|
def listen(self, protocolFactory):
|
||||||
|
self._once()
|
||||||
|
yield self._wait_for_main_channel.when_fired()
|
||||||
self._factory = protocolFactory
|
self._factory = protocolFactory
|
||||||
while self._pending_opens:
|
while self._pending_opens:
|
||||||
(t, peer_addr) = self._pending_opens.popleft()
|
(t, peer_addr) = self._pending_opens.popleft()
|
||||||
self._connect(t, peer_addr)
|
self._connect(t, peer_addr)
|
||||||
lp = SubchannelListeningPort(self._host_addr)
|
lp = SubchannelListeningPort(self._host_addr)
|
||||||
return succeed(lp)
|
returnValue(lp)
|
||||||
|
|
||||||
|
|
||||||
@implementer(IListeningPort)
|
@implementer(IListeningPort)
|
||||||
|
|
|
@ -2,7 +2,10 @@ from __future__ import print_function, unicode_literals
|
||||||
import mock
|
import mock
|
||||||
from zope.interface import alsoProvides
|
from zope.interface import alsoProvides
|
||||||
from twisted.trial import unittest
|
from twisted.trial import unittest
|
||||||
|
from twisted.internet.task import Clock
|
||||||
|
from twisted.python.failure import Failure
|
||||||
from ..._interfaces import ISubChannel
|
from ..._interfaces import ISubChannel
|
||||||
|
from ...eventual import EventualQueue
|
||||||
from ..._dilation.subchannel import (ControlEndpoint,
|
from ..._dilation.subchannel import (ControlEndpoint,
|
||||||
SubchannelConnectorEndpoint,
|
SubchannelConnectorEndpoint,
|
||||||
SubchannelListenerEndpoint,
|
SubchannelListenerEndpoint,
|
||||||
|
@ -11,12 +14,18 @@ from ..._dilation.subchannel import (ControlEndpoint,
|
||||||
SingleUseEndpointError)
|
SingleUseEndpointError)
|
||||||
from .common import mock_manager
|
from .common import mock_manager
|
||||||
|
|
||||||
|
class CannotDilateError(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
class Endpoints(unittest.TestCase):
|
class Control(unittest.TestCase):
|
||||||
def test_control(self):
|
def test_early_succeed(self):
|
||||||
|
# ep.connect() is called before dilation can proceed
|
||||||
scid0 = 0
|
scid0 = 0
|
||||||
peeraddr = _SubchannelAddress(scid0)
|
peeraddr = _SubchannelAddress(scid0)
|
||||||
ep = ControlEndpoint(peeraddr)
|
sc0 = mock.Mock()
|
||||||
|
alsoProvides(sc0, ISubChannel)
|
||||||
|
eq = EventualQueue(Clock())
|
||||||
|
ep = ControlEndpoint(peeraddr, sc0, eq)
|
||||||
|
|
||||||
f = mock.Mock()
|
f = mock.Mock()
|
||||||
p = mock.Mock()
|
p = mock.Mock()
|
||||||
|
@ -24,29 +33,105 @@ class Endpoints(unittest.TestCase):
|
||||||
d = ep.connect(f)
|
d = ep.connect(f)
|
||||||
self.assertNoResult(d)
|
self.assertNoResult(d)
|
||||||
|
|
||||||
t = mock.Mock()
|
ep._main_channel_ready()
|
||||||
alsoProvides(t, ISubChannel)
|
eq.flush_sync()
|
||||||
ep._subchannel_zero_opened(t)
|
|
||||||
self.assertIdentical(self.successResultOf(d), p)
|
self.assertIdentical(self.successResultOf(d), p)
|
||||||
self.assertEqual(f.buildProtocol.mock_calls, [mock.call(peeraddr)])
|
self.assertEqual(f.buildProtocol.mock_calls, [mock.call(peeraddr)])
|
||||||
self.assertEqual(t.mock_calls, [mock.call._set_protocol(p)])
|
self.assertEqual(sc0.mock_calls, [mock.call._set_protocol(p)])
|
||||||
self.assertEqual(p.mock_calls, [mock.call.makeConnection(t)])
|
self.assertEqual(p.mock_calls, [mock.call.makeConnection(sc0)])
|
||||||
|
|
||||||
d = ep.connect(f)
|
d = ep.connect(f)
|
||||||
self.failureResultOf(d, SingleUseEndpointError)
|
self.failureResultOf(d, SingleUseEndpointError)
|
||||||
|
|
||||||
def assert_makeConnection(self, mock_calls):
|
def test_early_fail(self):
|
||||||
|
# ep.connect() is called before dilation is abandoned
|
||||||
|
scid0 = 0
|
||||||
|
peeraddr = _SubchannelAddress(scid0)
|
||||||
|
sc0 = mock.Mock()
|
||||||
|
alsoProvides(sc0, ISubChannel)
|
||||||
|
eq = EventualQueue(Clock())
|
||||||
|
ep = ControlEndpoint(peeraddr, sc0, eq)
|
||||||
|
|
||||||
|
f = mock.Mock()
|
||||||
|
p = mock.Mock()
|
||||||
|
f.buildProtocol = mock.Mock(return_value=p)
|
||||||
|
d = ep.connect(f)
|
||||||
|
self.assertNoResult(d)
|
||||||
|
|
||||||
|
ep._main_channel_failed(Failure(CannotDilateError()))
|
||||||
|
eq.flush_sync()
|
||||||
|
|
||||||
|
self.failureResultOf(d).check(CannotDilateError)
|
||||||
|
self.assertEqual(f.buildProtocol.mock_calls, [])
|
||||||
|
self.assertEqual(sc0.mock_calls, [])
|
||||||
|
|
||||||
|
d = ep.connect(f)
|
||||||
|
self.failureResultOf(d, SingleUseEndpointError)
|
||||||
|
|
||||||
|
def test_late_succeed(self):
|
||||||
|
# dilation can proceed, then ep.connect() is called
|
||||||
|
scid0 = 0
|
||||||
|
peeraddr = _SubchannelAddress(scid0)
|
||||||
|
sc0 = mock.Mock()
|
||||||
|
alsoProvides(sc0, ISubChannel)
|
||||||
|
eq = EventualQueue(Clock())
|
||||||
|
ep = ControlEndpoint(peeraddr, sc0, eq)
|
||||||
|
|
||||||
|
ep._main_channel_ready()
|
||||||
|
|
||||||
|
f = mock.Mock()
|
||||||
|
p = mock.Mock()
|
||||||
|
f.buildProtocol = mock.Mock(return_value=p)
|
||||||
|
d = ep.connect(f)
|
||||||
|
eq.flush_sync()
|
||||||
|
self.assertIdentical(self.successResultOf(d), p)
|
||||||
|
self.assertEqual(f.buildProtocol.mock_calls, [mock.call(peeraddr)])
|
||||||
|
self.assertEqual(sc0.mock_calls, [mock.call._set_protocol(p)])
|
||||||
|
self.assertEqual(p.mock_calls, [mock.call.makeConnection(sc0)])
|
||||||
|
|
||||||
|
d = ep.connect(f)
|
||||||
|
self.failureResultOf(d, SingleUseEndpointError)
|
||||||
|
|
||||||
|
def test_late_fail(self):
|
||||||
|
# dilation is abandoned, then ep.connect() is called
|
||||||
|
scid0 = 0
|
||||||
|
peeraddr = _SubchannelAddress(scid0)
|
||||||
|
sc0 = mock.Mock()
|
||||||
|
alsoProvides(sc0, ISubChannel)
|
||||||
|
eq = EventualQueue(Clock())
|
||||||
|
ep = ControlEndpoint(peeraddr, sc0, eq)
|
||||||
|
|
||||||
|
ep._main_channel_failed(Failure(CannotDilateError()))
|
||||||
|
|
||||||
|
f = mock.Mock()
|
||||||
|
p = mock.Mock()
|
||||||
|
f.buildProtocol = mock.Mock(return_value=p)
|
||||||
|
d = ep.connect(f)
|
||||||
|
eq.flush_sync()
|
||||||
|
|
||||||
|
self.failureResultOf(d).check(CannotDilateError)
|
||||||
|
self.assertEqual(f.buildProtocol.mock_calls, [])
|
||||||
|
self.assertEqual(sc0.mock_calls, [])
|
||||||
|
|
||||||
|
d = ep.connect(f)
|
||||||
|
self.failureResultOf(d, SingleUseEndpointError)
|
||||||
|
|
||||||
|
class Endpoints(unittest.TestCase):
|
||||||
|
def OFFassert_makeConnection(self, mock_calls):
|
||||||
self.assertEqual(len(mock_calls), 1)
|
self.assertEqual(len(mock_calls), 1)
|
||||||
self.assertEqual(mock_calls[0][0], "makeConnection")
|
self.assertEqual(mock_calls[0][0], "makeConnection")
|
||||||
self.assertEqual(len(mock_calls[0][1]), 1)
|
self.assertEqual(len(mock_calls[0][1]), 1)
|
||||||
return mock_calls[0][1][0]
|
return mock_calls[0][1][0]
|
||||||
|
|
||||||
def test_connector(self):
|
class Connector(unittest.TestCase):
|
||||||
|
def test_early_succeed(self):
|
||||||
m = mock_manager()
|
m = mock_manager()
|
||||||
m.allocate_subchannel_id = mock.Mock(return_value=0)
|
m.allocate_subchannel_id = mock.Mock(return_value=0)
|
||||||
hostaddr = _WormholeAddress()
|
hostaddr = _WormholeAddress()
|
||||||
peeraddr = _SubchannelAddress(0)
|
peeraddr = _SubchannelAddress(0)
|
||||||
ep = SubchannelConnectorEndpoint(m, hostaddr)
|
eq = EventualQueue(Clock())
|
||||||
|
ep = SubchannelConnectorEndpoint(m, hostaddr, eq)
|
||||||
|
|
||||||
f = mock.Mock()
|
f = mock.Mock()
|
||||||
p = mock.Mock()
|
p = mock.Mock()
|
||||||
|
@ -55,38 +140,123 @@ class Endpoints(unittest.TestCase):
|
||||||
with mock.patch("wormhole._dilation.subchannel.SubChannel",
|
with mock.patch("wormhole._dilation.subchannel.SubChannel",
|
||||||
return_value=t) as sc:
|
return_value=t) as sc:
|
||||||
d = ep.connect(f)
|
d = ep.connect(f)
|
||||||
|
eq.flush_sync()
|
||||||
|
self.assertNoResult(d)
|
||||||
|
ep._main_channel_ready()
|
||||||
|
eq.flush_sync()
|
||||||
|
|
||||||
self.assertIdentical(self.successResultOf(d), p)
|
self.assertIdentical(self.successResultOf(d), p)
|
||||||
self.assertEqual(f.buildProtocol.mock_calls, [mock.call(peeraddr)])
|
self.assertEqual(f.buildProtocol.mock_calls, [mock.call(peeraddr)])
|
||||||
self.assertEqual(sc.mock_calls, [mock.call(0, m, hostaddr, peeraddr)])
|
self.assertEqual(sc.mock_calls, [mock.call(0, m, hostaddr, peeraddr)])
|
||||||
self.assertEqual(t.mock_calls, [mock.call._set_protocol(p)])
|
self.assertEqual(t.mock_calls, [mock.call._set_protocol(p)])
|
||||||
self.assertEqual(p.mock_calls, [mock.call.makeConnection(t)])
|
self.assertEqual(p.mock_calls, [mock.call.makeConnection(t)])
|
||||||
|
|
||||||
def test_listener(self):
|
def test_early_fail(self):
|
||||||
m = mock_manager()
|
m = mock_manager()
|
||||||
m.allocate_subchannel_id = mock.Mock(return_value=0)
|
m.allocate_subchannel_id = mock.Mock(return_value=0)
|
||||||
hostaddr = _WormholeAddress()
|
hostaddr = _WormholeAddress()
|
||||||
ep = SubchannelListenerEndpoint(m, hostaddr)
|
eq = EventualQueue(Clock())
|
||||||
|
ep = SubchannelConnectorEndpoint(m, hostaddr, eq)
|
||||||
|
|
||||||
|
f = mock.Mock()
|
||||||
|
p = mock.Mock()
|
||||||
|
t = mock.Mock()
|
||||||
|
f.buildProtocol = mock.Mock(return_value=p)
|
||||||
|
with mock.patch("wormhole._dilation.subchannel.SubChannel",
|
||||||
|
return_value=t) as sc:
|
||||||
|
d = ep.connect(f)
|
||||||
|
eq.flush_sync()
|
||||||
|
self.assertNoResult(d)
|
||||||
|
ep._main_channel_failed(Failure(CannotDilateError()))
|
||||||
|
eq.flush_sync()
|
||||||
|
|
||||||
|
self.failureResultOf(d).check(CannotDilateError)
|
||||||
|
self.assertEqual(f.buildProtocol.mock_calls, [])
|
||||||
|
self.assertEqual(sc.mock_calls, [])
|
||||||
|
self.assertEqual(t.mock_calls, [])
|
||||||
|
|
||||||
|
def test_late_succeed(self):
|
||||||
|
m = mock_manager()
|
||||||
|
m.allocate_subchannel_id = mock.Mock(return_value=0)
|
||||||
|
hostaddr = _WormholeAddress()
|
||||||
|
peeraddr = _SubchannelAddress(0)
|
||||||
|
eq = EventualQueue(Clock())
|
||||||
|
ep = SubchannelConnectorEndpoint(m, hostaddr, eq)
|
||||||
|
ep._main_channel_ready()
|
||||||
|
|
||||||
|
f = mock.Mock()
|
||||||
|
p = mock.Mock()
|
||||||
|
t = mock.Mock()
|
||||||
|
f.buildProtocol = mock.Mock(return_value=p)
|
||||||
|
with mock.patch("wormhole._dilation.subchannel.SubChannel",
|
||||||
|
return_value=t) as sc:
|
||||||
|
d = ep.connect(f)
|
||||||
|
eq.flush_sync()
|
||||||
|
|
||||||
|
self.assertIdentical(self.successResultOf(d), p)
|
||||||
|
self.assertEqual(f.buildProtocol.mock_calls, [mock.call(peeraddr)])
|
||||||
|
self.assertEqual(sc.mock_calls, [mock.call(0, m, hostaddr, peeraddr)])
|
||||||
|
self.assertEqual(t.mock_calls, [mock.call._set_protocol(p)])
|
||||||
|
self.assertEqual(p.mock_calls, [mock.call.makeConnection(t)])
|
||||||
|
|
||||||
|
def test_late_fail(self):
|
||||||
|
m = mock_manager()
|
||||||
|
m.allocate_subchannel_id = mock.Mock(return_value=0)
|
||||||
|
hostaddr = _WormholeAddress()
|
||||||
|
eq = EventualQueue(Clock())
|
||||||
|
ep = SubchannelConnectorEndpoint(m, hostaddr, eq)
|
||||||
|
ep._main_channel_failed(Failure(CannotDilateError()))
|
||||||
|
|
||||||
|
f = mock.Mock()
|
||||||
|
p = mock.Mock()
|
||||||
|
t = mock.Mock()
|
||||||
|
f.buildProtocol = mock.Mock(return_value=p)
|
||||||
|
with mock.patch("wormhole._dilation.subchannel.SubChannel",
|
||||||
|
return_value=t) as sc:
|
||||||
|
d = ep.connect(f)
|
||||||
|
eq.flush_sync()
|
||||||
|
|
||||||
|
self.failureResultOf(d).check(CannotDilateError)
|
||||||
|
self.assertEqual(f.buildProtocol.mock_calls, [])
|
||||||
|
self.assertEqual(sc.mock_calls, [])
|
||||||
|
self.assertEqual(t.mock_calls, [])
|
||||||
|
|
||||||
|
class Listener(unittest.TestCase):
|
||||||
|
def test_early_succeed(self):
|
||||||
|
# listen, main_channel_ready, got_open, got_open
|
||||||
|
m = mock_manager()
|
||||||
|
m.allocate_subchannel_id = mock.Mock(return_value=0)
|
||||||
|
hostaddr = _WormholeAddress()
|
||||||
|
eq = EventualQueue(Clock())
|
||||||
|
ep = SubchannelListenerEndpoint(m, hostaddr, eq)
|
||||||
|
|
||||||
f = mock.Mock()
|
f = mock.Mock()
|
||||||
p1 = mock.Mock()
|
p1 = mock.Mock()
|
||||||
p2 = mock.Mock()
|
p2 = mock.Mock()
|
||||||
f.buildProtocol = mock.Mock(side_effect=[p1, p2])
|
f.buildProtocol = mock.Mock(side_effect=[p1, p2])
|
||||||
|
|
||||||
# OPEN that arrives before we ep.listen() should be queued
|
d = ep.listen(f)
|
||||||
|
eq.flush_sync()
|
||||||
|
self.assertNoResult(d)
|
||||||
|
self.assertEqual(f.buildProtocol.mock_calls, [])
|
||||||
|
|
||||||
|
ep._main_channel_ready()
|
||||||
|
eq.flush_sync()
|
||||||
|
lp = self.successResultOf(d)
|
||||||
|
self.assertIsInstance(lp, SubchannelListeningPort)
|
||||||
|
|
||||||
|
self.assertEqual(lp.getHost(), hostaddr)
|
||||||
|
# TODO: IListeningPort says we must provide this, but I don't know
|
||||||
|
# that anyone would ever call it.
|
||||||
|
lp.startListening()
|
||||||
|
|
||||||
t1 = mock.Mock()
|
t1 = mock.Mock()
|
||||||
peeraddr1 = _SubchannelAddress(1)
|
peeraddr1 = _SubchannelAddress(1)
|
||||||
ep._got_open(t1, peeraddr1)
|
ep._got_open(t1, peeraddr1)
|
||||||
|
|
||||||
d = ep.listen(f)
|
|
||||||
lp = self.successResultOf(d)
|
|
||||||
self.assertIsInstance(lp, SubchannelListeningPort)
|
|
||||||
|
|
||||||
self.assertEqual(lp.getHost(), hostaddr)
|
|
||||||
lp.startListening()
|
|
||||||
|
|
||||||
self.assertEqual(t1.mock_calls, [mock.call._set_protocol(p1)])
|
self.assertEqual(t1.mock_calls, [mock.call._set_protocol(p1)])
|
||||||
self.assertEqual(p1.mock_calls, [mock.call.makeConnection(t1)])
|
self.assertEqual(p1.mock_calls, [mock.call.makeConnection(t1)])
|
||||||
|
self.assertEqual(f.buildProtocol.mock_calls, [mock.call(peeraddr1)])
|
||||||
|
|
||||||
t2 = mock.Mock()
|
t2 = mock.Mock()
|
||||||
peeraddr2 = _SubchannelAddress(2)
|
peeraddr2 = _SubchannelAddress(2)
|
||||||
|
@ -94,5 +264,92 @@ class Endpoints(unittest.TestCase):
|
||||||
|
|
||||||
self.assertEqual(t2.mock_calls, [mock.call._set_protocol(p2)])
|
self.assertEqual(t2.mock_calls, [mock.call._set_protocol(p2)])
|
||||||
self.assertEqual(p2.mock_calls, [mock.call.makeConnection(t2)])
|
self.assertEqual(p2.mock_calls, [mock.call.makeConnection(t2)])
|
||||||
|
self.assertEqual(f.buildProtocol.mock_calls, [mock.call(peeraddr1),
|
||||||
|
mock.call(peeraddr2)])
|
||||||
|
|
||||||
lp.stopListening() # TODO: should this do more?
|
lp.stopListening() # TODO: should this do more?
|
||||||
|
|
||||||
|
def test_early_fail(self):
|
||||||
|
# listen, main_channel_fail
|
||||||
|
m = mock_manager()
|
||||||
|
m.allocate_subchannel_id = mock.Mock(return_value=0)
|
||||||
|
hostaddr = _WormholeAddress()
|
||||||
|
eq = EventualQueue(Clock())
|
||||||
|
ep = SubchannelListenerEndpoint(m, hostaddr, eq)
|
||||||
|
|
||||||
|
f = mock.Mock()
|
||||||
|
p1 = mock.Mock()
|
||||||
|
p2 = mock.Mock()
|
||||||
|
f.buildProtocol = mock.Mock(side_effect=[p1, p2])
|
||||||
|
|
||||||
|
d = ep.listen(f)
|
||||||
|
eq.flush_sync()
|
||||||
|
self.assertNoResult(d)
|
||||||
|
|
||||||
|
ep._main_channel_failed(Failure(CannotDilateError()))
|
||||||
|
eq.flush_sync()
|
||||||
|
self.failureResultOf(d).check(CannotDilateError)
|
||||||
|
self.assertEqual(f.buildProtocol.mock_calls, [])
|
||||||
|
|
||||||
|
def test_late_succeed(self):
|
||||||
|
# main_channel_ready, got_open, listen, got_open
|
||||||
|
m = mock_manager()
|
||||||
|
m.allocate_subchannel_id = mock.Mock(return_value=0)
|
||||||
|
hostaddr = _WormholeAddress()
|
||||||
|
eq = EventualQueue(Clock())
|
||||||
|
ep = SubchannelListenerEndpoint(m, hostaddr, eq)
|
||||||
|
ep._main_channel_ready()
|
||||||
|
|
||||||
|
f = mock.Mock()
|
||||||
|
p1 = mock.Mock()
|
||||||
|
p2 = mock.Mock()
|
||||||
|
f.buildProtocol = mock.Mock(side_effect=[p1, p2])
|
||||||
|
|
||||||
|
t1 = mock.Mock()
|
||||||
|
peeraddr1 = _SubchannelAddress(1)
|
||||||
|
ep._got_open(t1, peeraddr1)
|
||||||
|
eq.flush_sync()
|
||||||
|
|
||||||
|
self.assertEqual(t1.mock_calls, [])
|
||||||
|
self.assertEqual(p1.mock_calls, [])
|
||||||
|
|
||||||
|
d = ep.listen(f)
|
||||||
|
eq.flush_sync()
|
||||||
|
lp = self.successResultOf(d)
|
||||||
|
self.assertIsInstance(lp, SubchannelListeningPort)
|
||||||
|
self.assertEqual(lp.getHost(), hostaddr)
|
||||||
|
lp.startListening()
|
||||||
|
|
||||||
|
self.assertEqual(t1.mock_calls, [mock.call._set_protocol(p1)])
|
||||||
|
self.assertEqual(p1.mock_calls, [mock.call.makeConnection(t1)])
|
||||||
|
self.assertEqual(f.buildProtocol.mock_calls, [mock.call(peeraddr1)])
|
||||||
|
|
||||||
|
t2 = mock.Mock()
|
||||||
|
peeraddr2 = _SubchannelAddress(2)
|
||||||
|
ep._got_open(t2, peeraddr2)
|
||||||
|
|
||||||
|
self.assertEqual(t2.mock_calls, [mock.call._set_protocol(p2)])
|
||||||
|
self.assertEqual(p2.mock_calls, [mock.call.makeConnection(t2)])
|
||||||
|
self.assertEqual(f.buildProtocol.mock_calls, [mock.call(peeraddr1),
|
||||||
|
mock.call(peeraddr2)])
|
||||||
|
|
||||||
|
lp.stopListening() # TODO: should this do more?
|
||||||
|
|
||||||
|
def test_late_fail(self):
|
||||||
|
# main_channel_fail, listen
|
||||||
|
m = mock_manager()
|
||||||
|
m.allocate_subchannel_id = mock.Mock(return_value=0)
|
||||||
|
hostaddr = _WormholeAddress()
|
||||||
|
eq = EventualQueue(Clock())
|
||||||
|
ep = SubchannelListenerEndpoint(m, hostaddr, eq)
|
||||||
|
ep._main_channel_failed(Failure(CannotDilateError()))
|
||||||
|
|
||||||
|
f = mock.Mock()
|
||||||
|
p1 = mock.Mock()
|
||||||
|
p2 = mock.Mock()
|
||||||
|
f.buildProtocol = mock.Mock(side_effect=[p1, p2])
|
||||||
|
|
||||||
|
d = ep.listen(f)
|
||||||
|
eq.flush_sync()
|
||||||
|
self.failureResultOf(d).check(CannotDilateError)
|
||||||
|
self.assertEqual(f.buildProtocol.mock_calls, [])
|
||||||
|
|
|
@ -46,24 +46,21 @@ class Full(ServerBase, unittest.TestCase):
|
||||||
yield doBoth(w1.get_verifier(), w2.get_verifier())
|
yield doBoth(w1.get_verifier(), w2.get_verifier())
|
||||||
print("connected")
|
print("connected")
|
||||||
|
|
||||||
eps1_d = w1.dilate()
|
eps1 = w1.dilate()
|
||||||
eps2_d = w2.dilate()
|
eps2 = w2.dilate()
|
||||||
(eps1, eps2) = yield doBoth(eps1_d, eps2_d)
|
|
||||||
(control_ep1, connect_ep1, listen_ep1) = eps1
|
|
||||||
(control_ep2, connect_ep2, listen_ep2) = eps2
|
|
||||||
print("w.dilate ready")
|
print("w.dilate ready")
|
||||||
|
|
||||||
f1 = Factory()
|
f1 = Factory()
|
||||||
f1.protocol = L
|
f1.protocol = L
|
||||||
f1.d = Deferred()
|
f1.d = Deferred()
|
||||||
f1.d.addCallback(lambda data: eq.fire_eventually(data))
|
f1.d.addCallback(lambda data: eq.fire_eventually(data))
|
||||||
d1 = control_ep1.connect(f1)
|
d1 = eps1.control.connect(f1)
|
||||||
|
|
||||||
f2 = Factory()
|
f2 = Factory()
|
||||||
f2.protocol = L
|
f2.protocol = L
|
||||||
f2.d = Deferred()
|
f2.d = Deferred()
|
||||||
f2.d.addCallback(lambda data: eq.fire_eventually(data))
|
f2.d.addCallback(lambda data: eq.fire_eventually(data))
|
||||||
d2 = control_ep2.connect(f2)
|
d2 = eps2.control.connect(f2)
|
||||||
yield d1
|
yield d1
|
||||||
yield d2
|
yield d2
|
||||||
print("control endpoints connected")
|
print("control endpoints connected")
|
||||||
|
@ -125,14 +122,12 @@ class Reconnect(ServerBase, unittest.TestCase):
|
||||||
w2.set_code(code)
|
w2.set_code(code)
|
||||||
yield doBoth(w1.get_verifier(), w2.get_verifier())
|
yield doBoth(w1.get_verifier(), w2.get_verifier())
|
||||||
|
|
||||||
eps1_d = w1.dilate()
|
eps1 = w1.dilate()
|
||||||
eps2_d = w2.dilate()
|
eps2 = w2.dilate()
|
||||||
(eps1, eps2) = yield doBoth(eps1_d, eps2_d)
|
print("w.dilate ready")
|
||||||
(control_ep1, connect_ep1, listen_ep1) = eps1
|
|
||||||
(control_ep2, connect_ep2, listen_ep2) = eps2
|
|
||||||
|
|
||||||
f1 = ReconF(eq); f2 = ReconF(eq)
|
f1 = ReconF(eq); f2 = ReconF(eq)
|
||||||
d1 = control_ep1.connect(f1); d2 = control_ep2.connect(f2)
|
d1 = eps1.control.connect(f1); d2 = eps2.control.connect(f2)
|
||||||
yield d1
|
yield d1
|
||||||
yield d2
|
yield d2
|
||||||
|
|
||||||
|
@ -194,14 +189,12 @@ class Reconnect(ServerBase, unittest.TestCase):
|
||||||
w2.set_code(code)
|
w2.set_code(code)
|
||||||
yield doBoth(w1.get_verifier(), w2.get_verifier())
|
yield doBoth(w1.get_verifier(), w2.get_verifier())
|
||||||
|
|
||||||
eps1_d = w1.dilate()
|
eps1 = w1.dilate()
|
||||||
eps2_d = w2.dilate()
|
eps2 = w2.dilate()
|
||||||
(eps1, eps2) = yield doBoth(eps1_d, eps2_d)
|
print("w.dilate ready")
|
||||||
(control_ep1, connect_ep1, listen_ep1) = eps1
|
|
||||||
(control_ep2, connect_ep2, listen_ep2) = eps2
|
|
||||||
|
|
||||||
f1 = ReconF(eq); f2 = ReconF(eq)
|
f1 = ReconF(eq); f2 = ReconF(eq)
|
||||||
d1 = control_ep1.connect(f1); d2 = control_ep2.connect(f2)
|
d1 = eps1.control.connect(f1); d2 = eps2.control.connect(f2)
|
||||||
yield d1
|
yield d1
|
||||||
yield d2
|
yield d2
|
||||||
|
|
||||||
|
@ -287,19 +280,17 @@ class Endpoints(ServerBase, unittest.TestCase):
|
||||||
w2.set_code(code)
|
w2.set_code(code)
|
||||||
yield doBoth(w1.get_verifier(), w2.get_verifier())
|
yield doBoth(w1.get_verifier(), w2.get_verifier())
|
||||||
|
|
||||||
eps1_d = w1.dilate()
|
eps1 = w1.dilate()
|
||||||
eps2_d = w2.dilate()
|
eps2 = w2.dilate()
|
||||||
(eps1, eps2) = yield doBoth(eps1_d, eps2_d)
|
print("w.dilate ready")
|
||||||
(control_ep1, connect_ep1, listen_ep1) = eps1
|
|
||||||
(control_ep2, connect_ep2, listen_ep2) = eps2
|
|
||||||
|
|
||||||
f0 = ReconF(eq)
|
f0 = ReconF(eq)
|
||||||
yield listen_ep2.listen(f0)
|
yield eps2.listen.listen(f0)
|
||||||
|
|
||||||
from twisted.python import log
|
from twisted.python import log
|
||||||
f1 = ReconF(eq)
|
f1 = ReconF(eq)
|
||||||
log.msg("connecting")
|
log.msg("connecting")
|
||||||
p1_client = yield connect_ep1.connect(f1)
|
p1_client = yield eps1.connect.connect(f1)
|
||||||
log.msg("sending c->s")
|
log.msg("sending c->s")
|
||||||
p1_client.transport.write(b"hello from p1\n")
|
p1_client.transport.write(b"hello from p1\n")
|
||||||
data = yield f0.deferreds["dataReceived"]
|
data = yield f0.deferreds["dataReceived"]
|
||||||
|
@ -316,7 +307,7 @@ class Endpoints(ServerBase, unittest.TestCase):
|
||||||
f0.resetDeferred("dataReceived")
|
f0.resetDeferred("dataReceived")
|
||||||
f1.resetDeferred("dataReceived")
|
f1.resetDeferred("dataReceived")
|
||||||
f2 = ReconF(eq)
|
f2 = ReconF(eq)
|
||||||
p2_client = yield connect_ep1.connect(f2)
|
p2_client = yield eps1.connect.connect(f2)
|
||||||
p2_server = yield f0.deferreds["connectionMade"]
|
p2_server = yield f0.deferreds["connectionMade"]
|
||||||
p2_server.transport.write(b"hello p2\n")
|
p2_server.transport.write(b"hello p2\n")
|
||||||
data = yield f2.deferreds["dataReceived"]
|
data = yield f2.deferreds["dataReceived"]
|
||||||
|
|
|
@ -140,25 +140,23 @@ class TestDilator(unittest.TestCase):
|
||||||
|
|
||||||
def test_peer_cannot_dilate(self):
|
def test_peer_cannot_dilate(self):
|
||||||
dil, send, reactor, eq, clock, coop = make_dilator()
|
dil, send, reactor, eq, clock, coop = make_dilator()
|
||||||
d1 = dil.dilate()
|
eps = dil.dilate()
|
||||||
self.assertNoResult(d1)
|
|
||||||
|
|
||||||
dil._transit_key = b"\x01" * 32
|
dil.got_key(b"\x01" * 32)
|
||||||
dil.got_wormhole_versions({}) # missing "can-dilate"
|
dil.got_wormhole_versions({}) # missing "can-dilate"
|
||||||
|
d = eps.connect.connect(None)
|
||||||
eq.flush_sync()
|
eq.flush_sync()
|
||||||
f = self.failureResultOf(d1)
|
self.failureResultOf(d).check(OldPeerCannotDilateError)
|
||||||
f.check(OldPeerCannotDilateError)
|
|
||||||
|
|
||||||
def test_disjoint_versions(self):
|
def test_disjoint_versions(self):
|
||||||
dil, send, reactor, eq, clock, coop = make_dilator()
|
dil, send, reactor, eq, clock, coop = make_dilator()
|
||||||
d1 = dil.dilate()
|
eps = dil.dilate()
|
||||||
self.assertNoResult(d1)
|
|
||||||
|
|
||||||
dil._transit_key = b"key"
|
dil.got_key(b"\x01" * 32)
|
||||||
dil.got_wormhole_versions({"can-dilate": [-1]})
|
dil.got_wormhole_versions({"can-dilate": [-1]})
|
||||||
|
d = eps.connect.connect(None)
|
||||||
eq.flush_sync()
|
eq.flush_sync()
|
||||||
f = self.failureResultOf(d1)
|
self.failureResultOf(d).check(OldPeerCannotDilateError)
|
||||||
f.check(OldPeerCannotDilateError)
|
|
||||||
|
|
||||||
def test_early_dilate_messages(self):
|
def test_early_dilate_messages(self):
|
||||||
dil, send, reactor, eq, clock, coop = make_dilator()
|
dil, send, reactor, eq, clock, coop = make_dilator()
|
||||||
|
@ -276,11 +274,11 @@ def make_manager(leader=True):
|
||||||
h.Inbound = mock.Mock(return_value=h.inbound)
|
h.Inbound = mock.Mock(return_value=h.inbound)
|
||||||
h.outbound = mock.Mock()
|
h.outbound = mock.Mock()
|
||||||
h.Outbound = mock.Mock(return_value=h.outbound)
|
h.Outbound = mock.Mock(return_value=h.outbound)
|
||||||
h.hostaddr = mock.Mock()
|
with mock.patch("wormhole._dilation.manager.Inbound", h.Inbound), \
|
||||||
alsoProvides(h.hostaddr, IAddress)
|
mock.patch("wormhole._dilation.manager.Outbound", h.Outbound):
|
||||||
with mock.patch("wormhole._dilation.manager.Inbound", h.Inbound):
|
m = Manager(h.send, side, h.relay, h.reactor, h.eq, h.coop)
|
||||||
with mock.patch("wormhole._dilation.manager.Outbound", h.Outbound):
|
h.hostaddr = m._host_addr
|
||||||
m = Manager(h.send, side, h.key, h.relay, h.reactor, h.eq, h.coop, h.hostaddr)
|
m.got_dilation_key(h.key)
|
||||||
return m, h
|
return m, h
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user