Merge branch '335-buffer-1'

closes #335
This commit is contained in:
Brian Warner 2019-07-13 19:48:58 -07:00
commit 20496e3976
5 changed files with 653 additions and 387 deletions

View File

@ -2,13 +2,20 @@ from __future__ import print_function, unicode_literals
import six
import os
from collections import deque
try:
# py >= 3.3
from collections.abc import Sequence
except ImportError:
# py 2 and py3 < 3.3
from collections import Sequence
from attr import attrs, attrib
from attr.validators import provides, instance_of, optional
from automat import MethodicalMachine
from zope.interface import implementer
from twisted.internet.defer import Deferred, inlineCallbacks, returnValue
from twisted.internet.interfaces import IAddress
from twisted.python import log
from twisted.internet.defer import Deferred
from twisted.internet.interfaces import (IStreamClientEndpoint,
IStreamServerEndpoint)
from twisted.python import log, failure
from .._interfaces import IDilator, IDilationManager, ISend, ITerminator
from ..util import dict_to_bytes, bytes_to_dict, bytes_to_hexstr
from ..observer import OneShotObserver
@ -47,6 +54,15 @@ class UnexpectedKCM(Exception):
class UnknownMessageType(Exception):
pass
@attrs
class EndpointRecord(Sequence):
control = attrib(validator=provides(IStreamClientEndpoint))
connect = attrib(validator=provides(IStreamClientEndpoint))
listen = attrib(validator=provides(IStreamServerEndpoint))
def __len__(self):
return 3
def __getitem__(self, n):
return (self.control, self.connect, self.listen)[n]
def make_side():
return bytes_to_hexstr(os.urandom(6))
@ -93,13 +109,14 @@ def make_side():
class Manager(object):
_S = attrib(validator=provides(ISend), repr=False)
_my_side = attrib(validator=instance_of(type(u"")))
_transit_key = attrib(validator=instance_of(bytes), repr=False)
_transit_relay_location = attrib(validator=optional(instance_of(str)))
_reactor = attrib(repr=False)
_eventual_queue = attrib(repr=False)
_cooperator = attrib(repr=False)
_host_addr = attrib(validator=provides(IAddress))
_no_listen = attrib(default=False)
# TODO: can this validator work when the parameter is optional?
_no_listen = attrib(validator=instance_of(bool), default=False)
_dilation_key = None
_tor = None # TODO
_timing = None # TODO
_next_subchannel_id = None # initialized in choose_role
@ -111,10 +128,10 @@ class Manager(object):
self._got_versions_d = Deferred()
self._my_role = None # determined upon rx_PLEASE
self._host_addr = _WormholeAddress()
self._connection = None
self._made_first_connection = False
self._first_connected = OneShotObserver(self._eventual_queue)
self._stopped = OneShotObserver(self._eventual_queue)
self._debug_stall_connector = False
@ -127,18 +144,81 @@ class Manager(object):
self._inbound = Inbound(self, self._host_addr)
self._outbound = Outbound(self, self._cooperator) # from us to peer
def set_listener_endpoint(self, listener_endpoint):
self._inbound.set_listener_endpoint(listener_endpoint)
def set_subchannel_zero(self, scid0, sc0):
# We must open subchannel0 early, since messages may arrive very
# quickly once the connection is established. This subchannel may or
# may not ever get revealed to the caller, since the peer might not
# even be capable of dilation.
scid0 = 0
peer_addr0 = _SubchannelAddress(scid0)
sc0 = SubChannel(scid0, self, self._host_addr, peer_addr0)
self._inbound.set_subchannel_zero(scid0, sc0)
def when_first_connected(self):
return self._first_connected.when_fired()
# we can open non-zero subchannels as soon as we get our first
# connection, and we can make the Endpoints even earlier
control_ep = ControlEndpoint(peer_addr0, sc0, self._eventual_queue)
connect_ep = SubchannelConnectorEndpoint(self, self._host_addr, self._eventual_queue)
listen_ep = SubchannelListenerEndpoint(self, self._host_addr, self._eventual_queue)
# TODO: let inbound/outbound create the endpoints, then return them
# to us
self._inbound.set_listener_endpoint(listen_ep)
self._endpoints = EndpointRecord(control_ep, connect_ep, listen_ep)
def get_endpoints(self):
return self._endpoints
def got_dilation_key(self, key):
assert isinstance(key, bytes)
self._dilation_key = key
def got_wormhole_versions(self, their_wormhole_versions):
# this always happens before received_dilation_message
dilation_version = None
their_dilation_versions = set(their_wormhole_versions.get("can-dilate", []))
my_versions = set(DILATION_VERSIONS)
shared_versions = my_versions.intersection(their_dilation_versions)
if "1" in shared_versions:
dilation_version = "1"
# dilation_version is the best mutually-compatible version we have
# with the peer, or None if we have nothing in common
if not dilation_version: # "1" or None
# TODO: be more specific about the error. dilation_version==None
# means we had no version in common with them, which could either
# be because they're so old they don't dilate at all, or because
# they're so new that they no longer accomodate our old version
self.fail(failure.Failure(OldPeerCannotDilateError()))
self.start()
def fail(self, f):
self._endpoints.control._main_channel_failed(f)
self._endpoints.connect._main_channel_failed(f)
self._endpoints.listen._main_channel_failed(f)
def received_dilation_message(self, plaintext):
# this receives new in-order DILATE-n payloads, decrypted but not
# de-JSONed.
message = bytes_to_dict(plaintext)
type = message["type"]
if type == "please":
self.rx_PLEASE(message)
elif type == "connection-hints":
self.rx_HINTS(message)
elif type == "reconnect":
self.rx_RECONNECT()
elif type == "reconnecting":
self.rx_RECONNECTING()
else:
log.err(UnknownDilationMessageType(message))
return
def when_stopped(self):
return self._stopped.when_fired()
def send_dilation_phase(self, **fields):
dilation_phase = self._next_dilation_phase
self._next_dilation_phase += 1
@ -204,7 +284,9 @@ class Manager(object):
self._outbound.use_connection(c) # does c.registerProducer
if not self._made_first_connection:
self._made_first_connection = True
self._first_connected.fire(None)
self._endpoints.control._main_channel_ready()
self._endpoints.connect._main_channel_ready()
self._endpoints.listen._main_channel_ready()
pass
def connector_connection_lost(self):
@ -272,16 +354,11 @@ class Manager(object):
# state machine
# We are born WANTING after the local app calls w.dilate(). We start
# CONNECTING when we receive PLEASE from the remote side
def start(self):
self.send_please()
def send_please(self):
self.send_dilation_phase(type="please", side=self._my_side)
@m.state(initial=True)
def WAITING(self):
pass # pragma: no cover
@m.state()
def WANTING(self):
pass # pragma: no cover
@ -313,6 +390,10 @@ class Manager(object):
def STOPPED(self):
pass # pragma: no cover
@m.input()
def start(self):
pass # pragma: no cover
@m.input()
def rx_PLEASE(self, message):
pass # pragma: no cover
@ -350,6 +431,10 @@ class Manager(object):
def stop(self):
pass # pragma: no cover
@m.output()
def send_please(self):
self.send_dilation_phase(type="please", side=self._my_side)
@m.output()
def choose_role(self, message):
their_side = message["side"]
@ -378,7 +463,8 @@ class Manager(object):
def _start_connecting(self):
assert self._my_role is not None
self._connector = Connector(self._transit_key,
assert self._dilation_key is not None
self._connector = Connector(self._dilation_key,
self._transit_relay_location,
self,
self._reactor, self._eventual_queue,
@ -422,6 +508,11 @@ class Manager(object):
def notify_stopped(self):
self._stopped.fire(None)
# We are born WAITING after the local app calls w.dilate(). We enter
# WANTING (and send a PLEASE) when we learn of a mutually-compatible
# dilation_version.
WAITING.upon(start, enter=WANTING, outputs=[send_please])
# we start CONNECTING when we get rx_PLEASE
WANTING.upon(rx_PLEASE, enter=CONNECTING,
outputs=[choose_role, start_connecting_ignore_message])
@ -489,12 +580,10 @@ class Dilator(object):
_cooperator = attrib()
def __attrs_post_init__(self):
self._got_versions_d = Deferred()
self._started = False
self._endpoints = OneShotObserver(self._eventual_queue)
self._pending_inbound_dilate_messages = deque()
self._manager = None
self._host_addr = _WormholeAddress()
self._pending_dilation_key = None
self._pending_wormhole_versions = None
self._pending_inbound_dilate_messages = deque()
def wire(self, sender, terminator):
self._S = ISend(sender)
@ -502,77 +591,35 @@ class Dilator(object):
# this is the primary entry point, called when w.dilate() is invoked
def dilate(self, transit_relay_location=None, no_listen=False):
self._transit_relay_location = transit_relay_location
if not self._started:
self._started = True
self._start(no_listen).addBoth(self._endpoints.fire)
return self._endpoints.when_fired()
@inlineCallbacks
def _start(self, no_listen):
# first, we wait until we hear the VERSION message, which tells us 1:
# the PAKE key works, so we can talk securely, 2: that they can do
# dilation at all (if they can't then w.dilate() errbacks)
dilation_version = yield self._got_versions_d
# TODO: we could probably return the endpoints earlier, if we flunk
# any connection/listen attempts upon OldPeerCannotDilateError, or
# if/when we give up on the initial connection
if not dilation_version: # "1" or None
# TODO: be more specific about the error. dilation_version==None
# means we had no version in common with them, which could either
# be because they're so old they don't dilate at all, or because
# they're so new that they no longer accomodate our old version
raise OldPeerCannotDilateError()
my_dilation_side = make_side()
self._manager = Manager(self._S, my_dilation_side,
self._transit_key,
self._transit_relay_location,
self._reactor, self._eventual_queue,
self._cooperator, self._host_addr, no_listen)
# We must open subchannel0 early, since messages may arrive very
# quickly once the connection is established. This subchannel may or
# may not ever get revealed to the caller, since the peer might not
# even be capable of dilation.
scid0 = 0
peer_addr0 = _SubchannelAddress(scid0)
sc0 = SubChannel(scid0, self._manager, self._host_addr, peer_addr0)
self._manager.set_subchannel_zero(scid0, sc0)
self._manager.start()
while self._pending_inbound_dilate_messages:
plaintext = self._pending_inbound_dilate_messages.popleft()
self.received_dilate(plaintext)
yield self._manager.when_first_connected()
# we can open non-zero subchannels as soon as we get our first
# connection
control_ep = ControlEndpoint(peer_addr0)
control_ep._subchannel_zero_opened(sc0)
connect_ep = SubchannelConnectorEndpoint(self._manager, self._host_addr)
listen_ep = SubchannelListenerEndpoint(self._manager, self._host_addr)
self._manager.set_listener_endpoint(listen_ep)
endpoints = (control_ep, connect_ep, listen_ep)
returnValue(endpoints)
if not self._manager:
# build the manager right away, and tell it later when the
# VERSIONS message arrives, and also when the dilation_key is set
my_dilation_side = make_side()
m = Manager(self._S, my_dilation_side,
transit_relay_location,
self._reactor, self._eventual_queue,
self._cooperator, no_listen)
self._manager = m
if self._pending_dilation_key is not None:
m.got_dilation_key(self._pending_dilation_key)
if self._pending_wormhole_versions:
m.got_wormhole_versions(self._pending_wormhole_versions)
while self._pending_inbound_dilate_messages:
plaintext = self._pending_inbound_dilate_messages.popleft()
m.received_dilation_message(plaintext)
return self._manager.get_endpoints()
# Called by Terminator after everything else (mailbox, nameplate, server
# connection) has shut down. Expects to fire T.stoppedD() when Dilator is
# stopped too.
def stop(self):
if not self._started:
self._T.stoppedD()
return
if self._started:
if self._manager:
self._manager.stop()
# TODO: avoid Deferreds for control flow, hard to serialize
self._manager.when_stopped().addCallback(lambda _: self._T.stoppedD())
else:
self._T.stoppedD()
return
# TODO: tolerate multiple calls
# from Boss
@ -582,39 +629,20 @@ class Dilator(object):
# to tolerate either ordering
purpose = b"dilation-v1"
LENGTH = 32 # TODO: whatever Noise wants, I guess
self._transit_key = derive_key(key, purpose, LENGTH)
dilation_key = derive_key(key, purpose, LENGTH)
if self._manager:
self._manager.got_dilation_key(dilation_key)
else:
self._pending_dilation_key = dilation_key
def got_wormhole_versions(self, their_wormhole_versions):
assert self._transit_key is not None
# this always happens before received_dilate
dilation_version = None
their_dilation_versions = set(their_wormhole_versions.get("can-dilate", []))
my_versions = set(DILATION_VERSIONS)
shared_versions = my_versions.intersection(their_dilation_versions)
if "1" in shared_versions:
dilation_version = "1"
self._got_versions_d.callback(dilation_version)
if self._manager:
self._manager.got_wormhole_versions(their_wormhole_versions)
else:
self._pending_wormhole_versions = their_wormhole_versions
def received_dilate(self, plaintext):
# this receives new in-order DILATE-n payloads, decrypted but not
# de-JSONed.
# this can appear before our .dilate() method is called, in which case
# we queue them for later
if not self._manager:
self._pending_inbound_dilate_messages.append(plaintext)
return
message = bytes_to_dict(plaintext)
type = message["type"]
if type == "please":
self._manager.rx_PLEASE(message)
elif type == "connection-hints":
self._manager.rx_HINTS(message)
elif type == "reconnect":
self._manager.rx_RECONNECT()
elif type == "reconnecting":
self._manager.rx_RECONNECTING()
else:
log.err(UnknownDilationMessageType(message))
return
self._manager.received_dilation_message(plaintext)

View File

@ -1,9 +1,9 @@
import six
from collections import deque
from attr import attrs, attrib
from attr.validators import instance_of, provides
from zope.interface import implementer
from twisted.internet.defer import (Deferred, inlineCallbacks, returnValue,
succeed)
from twisted.internet.defer import inlineCallbacks, returnValue
from twisted.internet.interfaces import (ITransport, IProducer, IConsumer,
IAddress, IListeningPort,
IStreamClientEndpoint,
@ -11,6 +11,7 @@ from twisted.internet.interfaces import (ITransport, IProducer, IConsumer,
from twisted.internet.error import ConnectionDone
from automat import MethodicalMachine
from .._interfaces import ISubChannel, IDilationManager
from ..observer import OneShotObserver
# each subchannel frame (the data passed into transport.write(data)) gets a
# 9-byte header prefix (type, subchannel id, and sequence number), then gets
@ -216,27 +217,33 @@ class SubChannel(object):
@implementer(IStreamClientEndpoint)
@attrs
class ControlEndpoint(object):
_peer_addr = attrib(validator=provides(IAddress))
_subchannel_zero = attrib(validator=provides(ISubChannel))
_eventual_queue = attrib(repr=False)
_used = False
def __init__(self, peer_addr):
self._subchannel_zero = Deferred()
self._peer_addr = peer_addr
def __attrs_post_init__(self):
self._once = Once(SingleUseEndpointError)
self._wait_for_main_channel = OneShotObserver(self._eventual_queue)
# from manager
def _subchannel_zero_opened(self, subchannel):
assert ISubChannel.providedBy(subchannel), subchannel
self._subchannel_zero.callback(subchannel)
def _main_channel_ready(self):
self._wait_for_main_channel.fire(None)
def _main_channel_failed(self, f):
self._wait_for_main_channel.error(f)
@inlineCallbacks
def connect(self, protocolFactory):
# return Deferred that fires with IProtocol or Failure(ConnectError)
self._once()
t = yield self._subchannel_zero
yield self._wait_for_main_channel.when_fired()
p = protocolFactory.buildProtocol(self._peer_addr)
t._set_protocol(p)
p.makeConnection(t) # set p.transport = t and call connectionMade()
self._subchannel_zero._set_protocol(p)
# this sets p.transport and calls p.connectionMade()
p.makeConnection(self._subchannel_zero)
returnValue(p)
@ -245,9 +252,21 @@ class ControlEndpoint(object):
class SubchannelConnectorEndpoint(object):
_manager = attrib(validator=provides(IDilationManager))
_host_addr = attrib(validator=instance_of(_WormholeAddress))
_eventual_queue = attrib(repr=False)
def __attrs_post_init__(self):
self._connection_deferreds = deque()
self._wait_for_main_channel = OneShotObserver(self._eventual_queue)
def _main_channel_ready(self):
self._wait_for_main_channel.fire(None)
def _main_channel_failed(self, f):
self._wait_for_main_channel.error(f)
@inlineCallbacks
def connect(self, protocolFactory):
# return Deferred that fires with IProtocol or Failure(ConnectError)
yield self._wait_for_main_channel.when_fired()
scid = self._manager.allocate_subchannel_id()
self._manager.send_open(scid)
peer_addr = _SubchannelAddress(scid)
@ -258,7 +277,7 @@ class SubchannelConnectorEndpoint(object):
p = protocolFactory.buildProtocol(peer_addr)
sc._set_protocol(p)
p.makeConnection(sc) # set p.transport = sc and call connectionMade()
return succeed(p)
returnValue(p)
@implementer(IStreamServerEndpoint)
@ -266,12 +285,15 @@ class SubchannelConnectorEndpoint(object):
class SubchannelListenerEndpoint(object):
_manager = attrib(validator=provides(IDilationManager))
_host_addr = attrib(validator=provides(IAddress))
_eventual_queue = attrib(repr=False)
def __attrs_post_init__(self):
self._once = Once(SingleUseEndpointError)
self._factory = None
self._pending_opens = []
self._pending_opens = deque()
self._wait_for_main_channel = OneShotObserver(self._eventual_queue)
# from manager
# from manager (actually Inbound)
def _got_open(self, t, peer_addr):
if self._factory:
self._connect(t, peer_addr)
@ -283,15 +305,23 @@ class SubchannelListenerEndpoint(object):
t._set_protocol(p)
p.makeConnection(t)
def _main_channel_ready(self):
self._wait_for_main_channel.fire(None)
def _main_channel_failed(self, f):
self._wait_for_main_channel.error(f)
# IStreamServerEndpoint
@inlineCallbacks
def listen(self, protocolFactory):
self._once()
yield self._wait_for_main_channel.when_fired()
self._factory = protocolFactory
for (t, peer_addr) in self._pending_opens:
while self._pending_opens:
(t, peer_addr) = self._pending_opens.popleft()
self._connect(t, peer_addr)
self._pending_opens = []
lp = SubchannelListeningPort(self._host_addr)
return succeed(lp)
returnValue(lp)
@implementer(IListeningPort)

View File

@ -2,7 +2,10 @@ from __future__ import print_function, unicode_literals
import mock
from zope.interface import alsoProvides
from twisted.trial import unittest
from twisted.internet.task import Clock
from twisted.python.failure import Failure
from ..._interfaces import ISubChannel
from ...eventual import EventualQueue
from ..._dilation.subchannel import (ControlEndpoint,
SubchannelConnectorEndpoint,
SubchannelListenerEndpoint,
@ -11,12 +14,18 @@ from ..._dilation.subchannel import (ControlEndpoint,
SingleUseEndpointError)
from .common import mock_manager
class CannotDilateError(Exception):
pass
class Endpoints(unittest.TestCase):
def test_control(self):
class Control(unittest.TestCase):
def test_early_succeed(self):
# ep.connect() is called before dilation can proceed
scid0 = 0
peeraddr = _SubchannelAddress(scid0)
ep = ControlEndpoint(peeraddr)
sc0 = mock.Mock()
alsoProvides(sc0, ISubChannel)
eq = EventualQueue(Clock())
ep = ControlEndpoint(peeraddr, sc0, eq)
f = mock.Mock()
p = mock.Mock()
@ -24,29 +33,105 @@ class Endpoints(unittest.TestCase):
d = ep.connect(f)
self.assertNoResult(d)
t = mock.Mock()
alsoProvides(t, ISubChannel)
ep._subchannel_zero_opened(t)
ep._main_channel_ready()
eq.flush_sync()
self.assertIdentical(self.successResultOf(d), p)
self.assertEqual(f.buildProtocol.mock_calls, [mock.call(peeraddr)])
self.assertEqual(t.mock_calls, [mock.call._set_protocol(p)])
self.assertEqual(p.mock_calls, [mock.call.makeConnection(t)])
self.assertEqual(sc0.mock_calls, [mock.call._set_protocol(p)])
self.assertEqual(p.mock_calls, [mock.call.makeConnection(sc0)])
d = ep.connect(f)
self.failureResultOf(d, SingleUseEndpointError)
def assert_makeConnection(self, mock_calls):
def test_early_fail(self):
# ep.connect() is called before dilation is abandoned
scid0 = 0
peeraddr = _SubchannelAddress(scid0)
sc0 = mock.Mock()
alsoProvides(sc0, ISubChannel)
eq = EventualQueue(Clock())
ep = ControlEndpoint(peeraddr, sc0, eq)
f = mock.Mock()
p = mock.Mock()
f.buildProtocol = mock.Mock(return_value=p)
d = ep.connect(f)
self.assertNoResult(d)
ep._main_channel_failed(Failure(CannotDilateError()))
eq.flush_sync()
self.failureResultOf(d).check(CannotDilateError)
self.assertEqual(f.buildProtocol.mock_calls, [])
self.assertEqual(sc0.mock_calls, [])
d = ep.connect(f)
self.failureResultOf(d, SingleUseEndpointError)
def test_late_succeed(self):
# dilation can proceed, then ep.connect() is called
scid0 = 0
peeraddr = _SubchannelAddress(scid0)
sc0 = mock.Mock()
alsoProvides(sc0, ISubChannel)
eq = EventualQueue(Clock())
ep = ControlEndpoint(peeraddr, sc0, eq)
ep._main_channel_ready()
f = mock.Mock()
p = mock.Mock()
f.buildProtocol = mock.Mock(return_value=p)
d = ep.connect(f)
eq.flush_sync()
self.assertIdentical(self.successResultOf(d), p)
self.assertEqual(f.buildProtocol.mock_calls, [mock.call(peeraddr)])
self.assertEqual(sc0.mock_calls, [mock.call._set_protocol(p)])
self.assertEqual(p.mock_calls, [mock.call.makeConnection(sc0)])
d = ep.connect(f)
self.failureResultOf(d, SingleUseEndpointError)
def test_late_fail(self):
# dilation is abandoned, then ep.connect() is called
scid0 = 0
peeraddr = _SubchannelAddress(scid0)
sc0 = mock.Mock()
alsoProvides(sc0, ISubChannel)
eq = EventualQueue(Clock())
ep = ControlEndpoint(peeraddr, sc0, eq)
ep._main_channel_failed(Failure(CannotDilateError()))
f = mock.Mock()
p = mock.Mock()
f.buildProtocol = mock.Mock(return_value=p)
d = ep.connect(f)
eq.flush_sync()
self.failureResultOf(d).check(CannotDilateError)
self.assertEqual(f.buildProtocol.mock_calls, [])
self.assertEqual(sc0.mock_calls, [])
d = ep.connect(f)
self.failureResultOf(d, SingleUseEndpointError)
class Endpoints(unittest.TestCase):
def OFFassert_makeConnection(self, mock_calls):
self.assertEqual(len(mock_calls), 1)
self.assertEqual(mock_calls[0][0], "makeConnection")
self.assertEqual(len(mock_calls[0][1]), 1)
return mock_calls[0][1][0]
def test_connector(self):
class Connector(unittest.TestCase):
def test_early_succeed(self):
m = mock_manager()
m.allocate_subchannel_id = mock.Mock(return_value=0)
hostaddr = _WormholeAddress()
peeraddr = _SubchannelAddress(0)
ep = SubchannelConnectorEndpoint(m, hostaddr)
eq = EventualQueue(Clock())
ep = SubchannelConnectorEndpoint(m, hostaddr, eq)
f = mock.Mock()
p = mock.Mock()
@ -55,38 +140,123 @@ class Endpoints(unittest.TestCase):
with mock.patch("wormhole._dilation.subchannel.SubChannel",
return_value=t) as sc:
d = ep.connect(f)
eq.flush_sync()
self.assertNoResult(d)
ep._main_channel_ready()
eq.flush_sync()
self.assertIdentical(self.successResultOf(d), p)
self.assertEqual(f.buildProtocol.mock_calls, [mock.call(peeraddr)])
self.assertEqual(sc.mock_calls, [mock.call(0, m, hostaddr, peeraddr)])
self.assertEqual(t.mock_calls, [mock.call._set_protocol(p)])
self.assertEqual(p.mock_calls, [mock.call.makeConnection(t)])
def test_listener(self):
def test_early_fail(self):
m = mock_manager()
m.allocate_subchannel_id = mock.Mock(return_value=0)
hostaddr = _WormholeAddress()
ep = SubchannelListenerEndpoint(m, hostaddr)
eq = EventualQueue(Clock())
ep = SubchannelConnectorEndpoint(m, hostaddr, eq)
f = mock.Mock()
p = mock.Mock()
t = mock.Mock()
f.buildProtocol = mock.Mock(return_value=p)
with mock.patch("wormhole._dilation.subchannel.SubChannel",
return_value=t) as sc:
d = ep.connect(f)
eq.flush_sync()
self.assertNoResult(d)
ep._main_channel_failed(Failure(CannotDilateError()))
eq.flush_sync()
self.failureResultOf(d).check(CannotDilateError)
self.assertEqual(f.buildProtocol.mock_calls, [])
self.assertEqual(sc.mock_calls, [])
self.assertEqual(t.mock_calls, [])
def test_late_succeed(self):
m = mock_manager()
m.allocate_subchannel_id = mock.Mock(return_value=0)
hostaddr = _WormholeAddress()
peeraddr = _SubchannelAddress(0)
eq = EventualQueue(Clock())
ep = SubchannelConnectorEndpoint(m, hostaddr, eq)
ep._main_channel_ready()
f = mock.Mock()
p = mock.Mock()
t = mock.Mock()
f.buildProtocol = mock.Mock(return_value=p)
with mock.patch("wormhole._dilation.subchannel.SubChannel",
return_value=t) as sc:
d = ep.connect(f)
eq.flush_sync()
self.assertIdentical(self.successResultOf(d), p)
self.assertEqual(f.buildProtocol.mock_calls, [mock.call(peeraddr)])
self.assertEqual(sc.mock_calls, [mock.call(0, m, hostaddr, peeraddr)])
self.assertEqual(t.mock_calls, [mock.call._set_protocol(p)])
self.assertEqual(p.mock_calls, [mock.call.makeConnection(t)])
def test_late_fail(self):
m = mock_manager()
m.allocate_subchannel_id = mock.Mock(return_value=0)
hostaddr = _WormholeAddress()
eq = EventualQueue(Clock())
ep = SubchannelConnectorEndpoint(m, hostaddr, eq)
ep._main_channel_failed(Failure(CannotDilateError()))
f = mock.Mock()
p = mock.Mock()
t = mock.Mock()
f.buildProtocol = mock.Mock(return_value=p)
with mock.patch("wormhole._dilation.subchannel.SubChannel",
return_value=t) as sc:
d = ep.connect(f)
eq.flush_sync()
self.failureResultOf(d).check(CannotDilateError)
self.assertEqual(f.buildProtocol.mock_calls, [])
self.assertEqual(sc.mock_calls, [])
self.assertEqual(t.mock_calls, [])
class Listener(unittest.TestCase):
def test_early_succeed(self):
# listen, main_channel_ready, got_open, got_open
m = mock_manager()
m.allocate_subchannel_id = mock.Mock(return_value=0)
hostaddr = _WormholeAddress()
eq = EventualQueue(Clock())
ep = SubchannelListenerEndpoint(m, hostaddr, eq)
f = mock.Mock()
p1 = mock.Mock()
p2 = mock.Mock()
f.buildProtocol = mock.Mock(side_effect=[p1, p2])
# OPEN that arrives before we ep.listen() should be queued
d = ep.listen(f)
eq.flush_sync()
self.assertNoResult(d)
self.assertEqual(f.buildProtocol.mock_calls, [])
ep._main_channel_ready()
eq.flush_sync()
lp = self.successResultOf(d)
self.assertIsInstance(lp, SubchannelListeningPort)
self.assertEqual(lp.getHost(), hostaddr)
# TODO: IListeningPort says we must provide this, but I don't know
# that anyone would ever call it.
lp.startListening()
t1 = mock.Mock()
peeraddr1 = _SubchannelAddress(1)
ep._got_open(t1, peeraddr1)
d = ep.listen(f)
lp = self.successResultOf(d)
self.assertIsInstance(lp, SubchannelListeningPort)
self.assertEqual(lp.getHost(), hostaddr)
lp.startListening()
self.assertEqual(t1.mock_calls, [mock.call._set_protocol(p1)])
self.assertEqual(p1.mock_calls, [mock.call.makeConnection(t1)])
self.assertEqual(f.buildProtocol.mock_calls, [mock.call(peeraddr1)])
t2 = mock.Mock()
peeraddr2 = _SubchannelAddress(2)
@ -94,5 +264,92 @@ class Endpoints(unittest.TestCase):
self.assertEqual(t2.mock_calls, [mock.call._set_protocol(p2)])
self.assertEqual(p2.mock_calls, [mock.call.makeConnection(t2)])
self.assertEqual(f.buildProtocol.mock_calls, [mock.call(peeraddr1),
mock.call(peeraddr2)])
lp.stopListening() # TODO: should this do more?
def test_early_fail(self):
# listen, main_channel_fail
m = mock_manager()
m.allocate_subchannel_id = mock.Mock(return_value=0)
hostaddr = _WormholeAddress()
eq = EventualQueue(Clock())
ep = SubchannelListenerEndpoint(m, hostaddr, eq)
f = mock.Mock()
p1 = mock.Mock()
p2 = mock.Mock()
f.buildProtocol = mock.Mock(side_effect=[p1, p2])
d = ep.listen(f)
eq.flush_sync()
self.assertNoResult(d)
ep._main_channel_failed(Failure(CannotDilateError()))
eq.flush_sync()
self.failureResultOf(d).check(CannotDilateError)
self.assertEqual(f.buildProtocol.mock_calls, [])
def test_late_succeed(self):
# main_channel_ready, got_open, listen, got_open
m = mock_manager()
m.allocate_subchannel_id = mock.Mock(return_value=0)
hostaddr = _WormholeAddress()
eq = EventualQueue(Clock())
ep = SubchannelListenerEndpoint(m, hostaddr, eq)
ep._main_channel_ready()
f = mock.Mock()
p1 = mock.Mock()
p2 = mock.Mock()
f.buildProtocol = mock.Mock(side_effect=[p1, p2])
t1 = mock.Mock()
peeraddr1 = _SubchannelAddress(1)
ep._got_open(t1, peeraddr1)
eq.flush_sync()
self.assertEqual(t1.mock_calls, [])
self.assertEqual(p1.mock_calls, [])
d = ep.listen(f)
eq.flush_sync()
lp = self.successResultOf(d)
self.assertIsInstance(lp, SubchannelListeningPort)
self.assertEqual(lp.getHost(), hostaddr)
lp.startListening()
self.assertEqual(t1.mock_calls, [mock.call._set_protocol(p1)])
self.assertEqual(p1.mock_calls, [mock.call.makeConnection(t1)])
self.assertEqual(f.buildProtocol.mock_calls, [mock.call(peeraddr1)])
t2 = mock.Mock()
peeraddr2 = _SubchannelAddress(2)
ep._got_open(t2, peeraddr2)
self.assertEqual(t2.mock_calls, [mock.call._set_protocol(p2)])
self.assertEqual(p2.mock_calls, [mock.call.makeConnection(t2)])
self.assertEqual(f.buildProtocol.mock_calls, [mock.call(peeraddr1),
mock.call(peeraddr2)])
lp.stopListening() # TODO: should this do more?
def test_late_fail(self):
# main_channel_fail, listen
m = mock_manager()
m.allocate_subchannel_id = mock.Mock(return_value=0)
hostaddr = _WormholeAddress()
eq = EventualQueue(Clock())
ep = SubchannelListenerEndpoint(m, hostaddr, eq)
ep._main_channel_failed(Failure(CannotDilateError()))
f = mock.Mock()
p1 = mock.Mock()
p2 = mock.Mock()
f.buildProtocol = mock.Mock(side_effect=[p1, p2])
d = ep.listen(f)
eq.flush_sync()
self.failureResultOf(d).check(CannotDilateError)
self.assertEqual(f.buildProtocol.mock_calls, [])

View File

@ -46,24 +46,21 @@ class Full(ServerBase, unittest.TestCase):
yield doBoth(w1.get_verifier(), w2.get_verifier())
print("connected")
eps1_d = w1.dilate()
eps2_d = w2.dilate()
(eps1, eps2) = yield doBoth(eps1_d, eps2_d)
(control_ep1, connect_ep1, listen_ep1) = eps1
(control_ep2, connect_ep2, listen_ep2) = eps2
eps1 = w1.dilate()
eps2 = w2.dilate()
print("w.dilate ready")
f1 = Factory()
f1.protocol = L
f1.d = Deferred()
f1.d.addCallback(lambda data: eq.fire_eventually(data))
d1 = control_ep1.connect(f1)
d1 = eps1.control.connect(f1)
f2 = Factory()
f2.protocol = L
f2.d = Deferred()
f2.d.addCallback(lambda data: eq.fire_eventually(data))
d2 = control_ep2.connect(f2)
d2 = eps2.control.connect(f2)
yield d1
yield d2
print("control endpoints connected")
@ -125,14 +122,12 @@ class Reconnect(ServerBase, unittest.TestCase):
w2.set_code(code)
yield doBoth(w1.get_verifier(), w2.get_verifier())
eps1_d = w1.dilate()
eps2_d = w2.dilate()
(eps1, eps2) = yield doBoth(eps1_d, eps2_d)
(control_ep1, connect_ep1, listen_ep1) = eps1
(control_ep2, connect_ep2, listen_ep2) = eps2
eps1 = w1.dilate()
eps2 = w2.dilate()
print("w.dilate ready")
f1 = ReconF(eq); f2 = ReconF(eq)
d1 = control_ep1.connect(f1); d2 = control_ep2.connect(f2)
d1 = eps1.control.connect(f1); d2 = eps2.control.connect(f2)
yield d1
yield d2
@ -194,14 +189,12 @@ class Reconnect(ServerBase, unittest.TestCase):
w2.set_code(code)
yield doBoth(w1.get_verifier(), w2.get_verifier())
eps1_d = w1.dilate()
eps2_d = w2.dilate()
(eps1, eps2) = yield doBoth(eps1_d, eps2_d)
(control_ep1, connect_ep1, listen_ep1) = eps1
(control_ep2, connect_ep2, listen_ep2) = eps2
eps1 = w1.dilate()
eps2 = w2.dilate()
print("w.dilate ready")
f1 = ReconF(eq); f2 = ReconF(eq)
d1 = control_ep1.connect(f1); d2 = control_ep2.connect(f2)
d1 = eps1.control.connect(f1); d2 = eps2.control.connect(f2)
yield d1
yield d2
@ -287,19 +280,17 @@ class Endpoints(ServerBase, unittest.TestCase):
w2.set_code(code)
yield doBoth(w1.get_verifier(), w2.get_verifier())
eps1_d = w1.dilate()
eps2_d = w2.dilate()
(eps1, eps2) = yield doBoth(eps1_d, eps2_d)
(control_ep1, connect_ep1, listen_ep1) = eps1
(control_ep2, connect_ep2, listen_ep2) = eps2
eps1 = w1.dilate()
eps2 = w2.dilate()
print("w.dilate ready")
f0 = ReconF(eq)
yield listen_ep2.listen(f0)
yield eps2.listen.listen(f0)
from twisted.python import log
f1 = ReconF(eq)
log.msg("connecting")
p1_client = yield connect_ep1.connect(f1)
p1_client = yield eps1.connect.connect(f1)
log.msg("sending c->s")
p1_client.transport.write(b"hello from p1\n")
data = yield f0.deferreds["dataReceived"]
@ -316,7 +307,7 @@ class Endpoints(ServerBase, unittest.TestCase):
f0.resetDeferred("dataReceived")
f1.resetDeferred("dataReceived")
f2 = ReconF(eq)
p2_client = yield connect_ep1.connect(f2)
p2_client = yield eps1.connect.connect(f2)
p2_server = yield f0.deferreds["connectionMade"]
p2_server.transport.write(b"hello p2\n")
data = yield f2.deferreds["dataReceived"]

View File

@ -1,12 +1,11 @@
from __future__ import print_function, unicode_literals
from zope.interface import alsoProvides
from twisted.trial import unittest
from twisted.internet.defer import Deferred
from twisted.internet.task import Clock, Cooperator
from twisted.internet.interfaces import IAddress
from twisted.internet.interfaces import IStreamServerEndpoint
import mock
from ...eventual import EventualQueue
from ..._interfaces import ISend, IDilationManager, ITerminator
from ..._interfaces import ISend, ITerminator, ISubChannel
from ...util import dict_to_bytes
from ..._dilation import roles
from ..._dilation.manager import (Dilator, Manager, make_side,
@ -15,35 +14,57 @@ from ..._dilation.manager import (Dilator, Manager, make_side,
UnexpectedKCM,
UnknownMessageType)
from ..._dilation.connection import Open, Data, Close, Ack, KCM, Ping, Pong
from ..._dilation.subchannel import _SubchannelAddress
from .common import clear_mock_calls
class Holder():
pass
def make_dilator():
reactor = object()
clock = Clock()
eq = EventualQueue(clock)
h = Holder()
h.reactor = object()
h.clock = Clock()
h.eq = EventualQueue(h.clock)
term = mock.Mock(side_effect=lambda: True) # one write per Eventual tick
def term_factory():
return term
coop = Cooperator(terminationPredicateFactory=term_factory,
scheduler=eq.eventually)
send = mock.Mock()
alsoProvides(send, ISend)
dil = Dilator(reactor, eq, coop)
terminator = mock.Mock()
alsoProvides(terminator, ITerminator)
dil.wire(send, terminator)
return dil, send, reactor, eq, clock, coop
h.coop = Cooperator(terminationPredicateFactory=term_factory,
scheduler=h.eq.eventually)
h.send = mock.Mock()
alsoProvides(h.send, ISend)
dil = Dilator(h.reactor, h.eq, h.coop)
h.terminator = mock.Mock()
alsoProvides(h.terminator, ITerminator)
dil.wire(h.send, h.terminator)
return dil, h
class TestDilator(unittest.TestCase):
def test_manager_and_endpoints(self):
dil, send, reactor, eq, clock, coop = make_dilator()
d1 = dil.dilate()
d2 = dil.dilate()
self.assertNoResult(d1)
self.assertNoResult(d2)
# we should test the interleavings between:
# * application calls w.dilate() and gets back endpoints
# * wormhole gets: dilation key, VERSION, 0-n dilation messages
def test_dilate_first(self):
(dil, h) = make_dilator()
side = object()
m = mock.Mock()
eps = object()
m.get_endpoints = mock.Mock(return_value=eps)
mm = mock.Mock(side_effect=[m])
with mock.patch("wormhole._dilation.manager.Manager", mm), \
mock.patch("wormhole._dilation.manager.make_side",
return_value=side):
eps1 = dil.dilate()
eps2 = dil.dilate()
self.assertIdentical(eps1, eps)
self.assertIdentical(eps1, eps2)
self.assertEqual(mm.mock_calls, [mock.call(h.send, side, None,
h.reactor, h.eq, h.coop, False)])
self.assertEqual(m.mock_calls, [mock.call.get_endpoints(),
mock.call.get_endpoints()])
clear_mock_calls(m)
key = b"key"
transit_key = object()
@ -51,183 +72,108 @@ class TestDilator(unittest.TestCase):
return_value=transit_key) as dk:
dil.got_key(key)
self.assertEqual(dk.mock_calls, [mock.call(key, b"dilation-v1", 32)])
self.assertIdentical(dil._transit_key, transit_key)
self.assertNoResult(d1)
self.assertNoResult(d2)
self.assertEqual(m.mock_calls, [mock.call.got_dilation_key(transit_key)])
clear_mock_calls(m)
host_addr = dil._host_addr
wv = object()
dil.got_wormhole_versions(wv)
self.assertEqual(m.mock_calls, [mock.call.got_wormhole_versions(wv)])
clear_mock_calls(m)
peer_addr = object()
m_sca = mock.patch("wormhole._dilation.manager._SubchannelAddress",
return_value=peer_addr)
sc = mock.Mock()
m_sc = mock.patch("wormhole._dilation.manager.SubChannel",
return_value=sc)
scid0 = 0
m = mock.Mock()
alsoProvides(m, IDilationManager)
m.when_first_connected.return_value = wfc_d = Deferred()
with mock.patch("wormhole._dilation.manager.Manager",
return_value=m) as ml:
with mock.patch("wormhole._dilation.manager.make_side",
return_value="us"):
with m_sca, m_sc as m_sc_m:
dil.got_wormhole_versions({"can-dilate": ["1"]})
# that should create the Manager
self.assertEqual(ml.mock_calls, [mock.call(send, "us", transit_key,
None, reactor, eq, coop, host_addr, False)])
# and create subchannel0
self.assertEqual(m_sc_m.mock_calls,
[mock.call(scid0, m, host_addr, peer_addr)])
# and tell it to start, and get wait-for-it-to-connect Deferred
self.assertEqual(m.mock_calls, [mock.call.set_subchannel_zero(scid0, sc),
mock.call.start(),
mock.call.when_first_connected(),
dm1 = object()
dm2 = object()
dil.received_dilate(dm1)
dil.received_dilate(dm2)
self.assertEqual(m.mock_calls, [mock.call.received_dilation_message(dm1),
mock.call.received_dilation_message(dm2),
])
clear_mock_calls(m)
self.assertNoResult(d1)
self.assertNoResult(d2)
ce = mock.Mock()
m_ce = mock.patch("wormhole._dilation.manager.ControlEndpoint",
return_value=ce)
lep = object()
m_sle = mock.patch("wormhole._dilation.manager.SubchannelListenerEndpoint",
return_value=lep)
stopped_d = mock.Mock()
m.when_stopped = mock.Mock(return_value=stopped_d)
dil.stop()
self.assertEqual(m.mock_calls, [mock.call.stop(),
mock.call.when_stopped(),
])
with m_ce as m_ce_m, m_sle as m_sle_m:
wfc_d.callback(None)
eq.flush_sync()
self.assertEqual(m_ce_m.mock_calls, [mock.call(peer_addr)])
self.assertEqual(ce.mock_calls, [mock.call._subchannel_zero_opened(sc)])
self.assertEqual(m_sle_m.mock_calls, [mock.call(m, host_addr)])
self.assertEqual(m.mock_calls,
[mock.call.set_listener_endpoint(lep),
])
def test_dilate_later(self):
(dil, h) = make_dilator()
m = mock.Mock()
mm = mock.Mock(side_effect=[m])
key = b"key"
transit_key = object()
with mock.patch("wormhole._dilation.manager.derive_key",
return_value=transit_key) as dk:
dil.got_key(key)
self.assertEqual(dk.mock_calls, [mock.call(key, b"dilation-v1", 32)])
wv = object()
dil.got_wormhole_versions(wv)
dm1 = object()
dil.received_dilate(dm1)
self.assertEqual(mm.mock_calls, [])
with mock.patch("wormhole._dilation.manager.Manager", mm):
dil.dilate()
self.assertEqual(m.mock_calls, [mock.call.got_dilation_key(transit_key),
mock.call.got_wormhole_versions(wv),
mock.call.received_dilation_message(dm1),
mock.call.get_endpoints(),
])
clear_mock_calls(m)
eps = self.successResultOf(d1)
self.assertEqual(eps, self.successResultOf(d2))
d3 = dil.dilate()
eq.flush_sync()
self.assertEqual(eps, self.successResultOf(d3))
dm2 = object()
dil.received_dilate(dm2)
self.assertEqual(m.mock_calls, [mock.call.received_dilation_message(dm2),
])
# all subsequent DILATE-n messages should get passed to the manager
self.assertEqual(m.mock_calls, [])
pleasemsg = dict(type="please", side="them")
dil.received_dilate(dict_to_bytes(pleasemsg))
self.assertEqual(m.mock_calls, [mock.call.rx_PLEASE(pleasemsg)])
clear_mock_calls(m)
hintmsg = dict(type="connection-hints")
dil.received_dilate(dict_to_bytes(hintmsg))
self.assertEqual(m.mock_calls, [mock.call.rx_HINTS(hintmsg)])
clear_mock_calls(m)
# we're nominally the LEADER, and the leader would not normally be
# receiving a RECONNECT, but since we've mocked out the Manager it
# won't notice
dil.received_dilate(dict_to_bytes(dict(type="reconnect")))
self.assertEqual(m.mock_calls, [mock.call.rx_RECONNECT()])
clear_mock_calls(m)
dil.received_dilate(dict_to_bytes(dict(type="reconnecting")))
self.assertEqual(m.mock_calls, [mock.call.rx_RECONNECTING()])
clear_mock_calls(m)
dil.received_dilate(dict_to_bytes(dict(type="unknown")))
self.assertEqual(m.mock_calls, [])
self.flushLoggedErrors(UnknownDilationMessageType)
def test_stop_early(self):
(dil, h) = make_dilator()
# we stop before w.dilate(), so there is no Manager to stop
dil.stop()
self.assertEqual(h.terminator.mock_calls, [mock.call.stoppedD()])
def test_peer_cannot_dilate(self):
dil, send, reactor, eq, clock, coop = make_dilator()
d1 = dil.dilate()
self.assertNoResult(d1)
(dil, h) = make_dilator()
eps = dil.dilate()
dil._transit_key = b"\x01" * 32
dil.got_key(b"\x01" * 32)
dil.got_wormhole_versions({}) # missing "can-dilate"
eq.flush_sync()
f = self.failureResultOf(d1)
f.check(OldPeerCannotDilateError)
d = eps.connect.connect(None)
h.eq.flush_sync()
self.failureResultOf(d).check(OldPeerCannotDilateError)
def test_disjoint_versions(self):
dil, send, reactor, eq, clock, coop = make_dilator()
d1 = dil.dilate()
self.assertNoResult(d1)
(dil, h) = make_dilator()
eps = dil.dilate()
dil._transit_key = b"key"
dil.got_key(b"\x01" * 32)
dil.got_wormhole_versions({"can-dilate": [-1]})
eq.flush_sync()
f = self.failureResultOf(d1)
f.check(OldPeerCannotDilateError)
def test_early_dilate_messages(self):
dil, send, reactor, eq, clock, coop = make_dilator()
dil._transit_key = b"key"
d1 = dil.dilate()
host_addr = dil._host_addr
self.assertNoResult(d1)
pleasemsg = dict(type="please", side="them")
dil.received_dilate(dict_to_bytes(pleasemsg))
hintmsg = dict(type="connection-hints")
dil.received_dilate(dict_to_bytes(hintmsg))
m = mock.Mock()
alsoProvides(m, IDilationManager)
m.when_first_connected.return_value = Deferred()
scid0 = 0
sc = mock.Mock()
m_sc = mock.patch("wormhole._dilation.manager.SubChannel",
return_value=sc)
with mock.patch("wormhole._dilation.manager.Manager",
return_value=m) as ml:
with mock.patch("wormhole._dilation.manager.make_side",
return_value="us"):
with m_sc:
dil.got_wormhole_versions({"can-dilate": ["1"]})
self.assertEqual(ml.mock_calls, [mock.call(send, "us", b"key",
None, reactor, eq, coop, host_addr, False)])
self.assertEqual(m.mock_calls, [mock.call.set_subchannel_zero(scid0, sc),
mock.call.start(),
mock.call.rx_PLEASE(pleasemsg),
mock.call.rx_HINTS(hintmsg),
mock.call.when_first_connected()])
d = eps.connect.connect(None)
h.eq.flush_sync()
self.failureResultOf(d).check(OldPeerCannotDilateError)
def test_transit_relay(self):
dil, send, reactor, eq, clock, coop = make_dilator()
dil._transit_key = b"key"
host_addr = dil._host_addr
relay = object()
d1 = dil.dilate(transit_relay_location=relay)
self.assertNoResult(d1)
scid0 = 0
sc = mock.Mock()
m_sc = mock.patch("wormhole._dilation.manager.SubChannel",
return_value=sc)
with mock.patch("wormhole._dilation.manager.Manager") as ml:
with mock.patch("wormhole._dilation.manager.make_side",
return_value="us"):
with m_sc:
dil.got_wormhole_versions({"can-dilate": ["1"]})
self.assertEqual(ml.mock_calls, [mock.call(send, "us", b"key",
relay, reactor, eq, coop, host_addr, False),
mock.call().set_subchannel_zero(scid0, sc),
mock.call().start(),
mock.call().when_first_connected()])
(dil, h) = make_dilator()
transit_relay_location = object()
side = object()
m = mock.Mock()
mm = mock.Mock(side_effect=[m])
with mock.patch("wormhole._dilation.manager.Manager", mm), \
mock.patch("wormhole._dilation.manager.make_side",
return_value=side):
dil.dilate(transit_relay_location)
self.assertEqual(mm.mock_calls, [mock.call(h.send, side, transit_relay_location,
h.reactor, h.eq, h.coop, False)])
LEADER = "ff3456abcdef"
FOLLOWER = "123456abcdef"
def make_manager(leader=True):
class Holder:
pass
h = Holder()
h.send = mock.Mock()
alsoProvides(h.send, ISend)
@ -250,11 +196,19 @@ def make_manager(leader=True):
h.Inbound = mock.Mock(return_value=h.inbound)
h.outbound = mock.Mock()
h.Outbound = mock.Mock(return_value=h.outbound)
h.hostaddr = mock.Mock()
alsoProvides(h.hostaddr, IAddress)
with mock.patch("wormhole._dilation.manager.Inbound", h.Inbound):
with mock.patch("wormhole._dilation.manager.Outbound", h.Outbound):
m = Manager(h.send, side, h.key, h.relay, h.reactor, h.eq, h.coop, h.hostaddr)
h.sc0 = mock.Mock()
alsoProvides(h.sc0, ISubChannel)
h.SubChannel = mock.Mock(return_value=h.sc0)
h.listen_ep = mock.Mock()
alsoProvides(h.listen_ep, IStreamServerEndpoint)
with mock.patch("wormhole._dilation.manager.Inbound", h.Inbound), \
mock.patch("wormhole._dilation.manager.Outbound", h.Outbound), \
mock.patch("wormhole._dilation.manager.SubChannel", h.SubChannel), \
mock.patch("wormhole._dilation.manager.SubchannelListenerEndpoint",
return_value=h.listen_ep):
m = Manager(h.send, side, h.relay, h.reactor, h.eq, h.coop)
h.hostaddr = m._host_addr
m.got_dilation_key(h.key)
return m, h
@ -272,16 +226,26 @@ class TestManager(unittest.TestCase):
self.assertEqual(h.send.mock_calls, [])
self.assertEqual(h.Inbound.mock_calls, [mock.call(m, h.hostaddr)])
self.assertEqual(h.Outbound.mock_calls, [mock.call(m, h.coop)])
scid0 = 0
sc0_peer_addr = _SubchannelAddress(scid0)
self.assertEqual(h.SubChannel.mock_calls, [
mock.call(scid0, m, m._host_addr, sc0_peer_addr),
])
self.assertEqual(h.inbound.mock_calls, [
mock.call.set_subchannel_zero(scid0, h.sc0),
mock.call.set_listener_endpoint(h.listen_ep)
])
clear_mock_calls(h.inbound)
m.start()
m.got_wormhole_versions({"can-dilate": ["1"]})
self.assertEqual(h.send.mock_calls, [
mock.call.send("dilate-0",
dict_to_bytes({"type": "please", "side": LEADER}))
])
clear_mock_calls(h.send)
wfc_d = m.when_first_connected()
self.assertNoResult(wfc_d)
listen_d = m.get_endpoints().listen.listen(None)
self.assertNoResult(listen_d)
# ignore early hints
m.rx_HINTS({})
@ -303,7 +267,7 @@ class TestManager(unittest.TestCase):
self.assertEqual(c.mock_calls, [mock.call.start()])
clear_mock_calls(connector, c)
self.assertNoResult(wfc_d)
self.assertNoResult(listen_d)
# now any inbound hints should get passed to our Connector
with mock.patch("wormhole._dilation.manager.parse_hint",
@ -323,7 +287,7 @@ class TestManager(unittest.TestCase):
clear_mock_calls(h.send)
# the first successful connection fires when_first_connected(), so
# the Dilator can create and return the endpoints
# the endpoints can activate
c1 = mock.Mock()
m.connector_connection_made(c1)
@ -331,23 +295,6 @@ class TestManager(unittest.TestCase):
self.assertEqual(h.outbound.mock_calls, [mock.call.use_connection(c1)])
clear_mock_calls(h.inbound, h.outbound)
h.eq.flush_sync()
self.successResultOf(wfc_d) # fires with None
wfc_d2 = m.when_first_connected()
h.eq.flush_sync()
self.successResultOf(wfc_d2)
scid0 = 0
sc0 = mock.Mock()
m.set_subchannel_zero(scid0, sc0)
listen_ep = mock.Mock()
m.set_listener_endpoint(listen_ep)
self.assertEqual(h.inbound.mock_calls, [
mock.call.set_subchannel_zero(scid0, sc0),
mock.call.set_listener_endpoint(listen_ep),
])
clear_mock_calls(h.inbound)
# the Leader making a new outbound channel should get scid=1
scid1 = 1
self.assertEqual(m.allocate_subchannel_id(), scid1)
@ -494,12 +441,13 @@ class TestManager(unittest.TestCase):
def test_follower(self):
m, h = make_manager(leader=False)
m.start()
m.got_wormhole_versions({"can-dilate": ["1"]})
self.assertEqual(h.send.mock_calls, [
mock.call.send("dilate-0",
dict_to_bytes({"type": "please", "side": FOLLOWER}))
])
clear_mock_calls(h.send)
clear_mock_calls(h.inbound)
c = mock.Mock()
connector = mock.Mock(return_value=c)
@ -629,6 +577,7 @@ class TestManager(unittest.TestCase):
def test_subchannel(self):
m, h = make_manager(leader=True)
clear_mock_calls(h.inbound)
sc = object()
m.subchannel_pauseProducing(sc)
@ -665,3 +614,14 @@ class TestManager(unittest.TestCase):
self.assertEqual(h.outbound.mock_calls, [
mock.call.subchannel_closed(4, sc)])
clear_mock_calls(h.inbound, h.outbound)
def test_unknown_message(self):
# receive a PLEASE with the same side as us: shouldn't happen
m, h = make_manager(leader=True)
m.start()
m.received_dilation_message(dict_to_bytes(dict(type="unknown")))
self.flushLoggedErrors(UnknownDilationMessageType)
# TODO: test transit relay is used