diff --git a/src/wormhole/_dilation/manager.py b/src/wormhole/_dilation/manager.py index abdf68c..574c7b2 100644 --- a/src/wormhole/_dilation/manager.py +++ b/src/wormhole/_dilation/manager.py @@ -6,6 +6,7 @@ from attr.validators import provides, instance_of, optional from automat import MethodicalMachine from zope.interface import implementer from twisted.internet.defer import Deferred, inlineCallbacks, returnValue +from twisted.internet.interfaces import IAddress from twisted.python import log from .._interfaces import IDilator, IDilationManager, ISend, ITerminator from ..util import dict_to_bytes, bytes_to_dict, bytes_to_hexstr @@ -97,6 +98,7 @@ class Manager(object): _reactor = attrib(repr=False) _eventual_queue = attrib(repr=False) _cooperator = attrib(repr=False) + _host_addr = attrib(validator=provides(IAddress)) _no_listen = attrib(default=False) _tor = None # TODO _timing = None # TODO @@ -114,7 +116,6 @@ class Manager(object): self._made_first_connection = False self._first_connected = OneShotObserver(self._eventual_queue) self._stopped = OneShotObserver(self._eventual_queue) - self._host_addr = _WormholeAddress() self._next_dilation_phase = 0 @@ -484,6 +485,7 @@ class Dilator(object): self._endpoints = OneShotObserver(self._eventual_queue) self._pending_inbound_dilate_messages = deque() self._manager = None + self._host_addr = _WormholeAddress() def wire(self, sender, terminator): self._S = ISend(sender) @@ -521,7 +523,16 @@ class Dilator(object): self._transit_key, self._transit_relay_location, self._reactor, self._eventual_queue, - self._cooperator, no_listen) + self._cooperator, self._host_addr, no_listen) + # We must open subchannel0 early, since messages may arrive very + # quickly once the connection is established. This subchannel may or + # may not ever get revealed to the caller, since the peer might not + # even be capable of dilation. + scid0 = to_be4(0) + peer_addr0 = _SubchannelAddress(scid0) + sc0 = SubChannel(scid0, self._manager, self._host_addr, peer_addr0) + self._manager.set_subchannel_zero(scid0, sc0) + self._manager.start() while self._pending_inbound_dilate_messages: @@ -530,15 +541,10 @@ class Dilator(object): yield self._manager.when_first_connected() - # we can open subchannels as soon as we get our first connection - scid0 = to_be4(0) - self._host_addr = _WormholeAddress() # TODO: share with Manager - peer_addr0 = _SubchannelAddress(scid0) + # we can open non-zero subchannels as soon as we get our first + # connection control_ep = ControlEndpoint(peer_addr0) - sc0 = SubChannel(scid0, self._manager, self._host_addr, peer_addr0) control_ep._subchannel_zero_opened(sc0) - self._manager.set_subchannel_zero(scid0, sc0) - connect_ep = SubchannelConnectorEndpoint(self._manager, self._host_addr) listen_ep = SubchannelListenerEndpoint(self._manager, self._host_addr) diff --git a/src/wormhole/test/dilate/test_manager.py b/src/wormhole/test/dilate/test_manager.py index 42ef0f6..5a86d3f 100644 --- a/src/wormhole/test/dilate/test_manager.py +++ b/src/wormhole/test/dilate/test_manager.py @@ -3,6 +3,7 @@ from zope.interface import alsoProvides from twisted.trial import unittest from twisted.internet.defer import Deferred from twisted.internet.task import Clock, Cooperator +from twisted.internet.interfaces import IAddress import mock from ...eventual import EventualQueue from ..._interfaces import ISend, IDilationManager, ITerminator @@ -14,7 +15,6 @@ from ..._dilation.manager import (Dilator, Manager, make_side, UnknownDilationMessageType, UnexpectedKCM, UnknownMessageType) -from ..._dilation.subchannel import _WormholeAddress from ..._dilation.connection import Open, Data, Close, Ack, KCM, Ping, Pong from .common import clear_mock_calls @@ -56,6 +56,16 @@ class TestDilator(unittest.TestCase): self.assertNoResult(d1) self.assertNoResult(d2) + host_addr = dil._host_addr + + peer_addr = object() + m_sca = mock.patch("wormhole._dilation.manager._SubchannelAddress", + return_value=peer_addr) + sc = mock.Mock() + m_sc = mock.patch("wormhole._dilation.manager.SubChannel", + return_value=sc) + scid0 = b"\x00\x00\x00\x00" + m = mock.Mock() alsoProvides(m, IDilationManager) m.when_first_connected.return_value = wfc_d = Deferred() @@ -63,47 +73,38 @@ class TestDilator(unittest.TestCase): return_value=m) as ml: with mock.patch("wormhole._dilation.manager.make_side", return_value="us"): - dil.got_wormhole_versions({"can-dilate": ["1"]}) + with m_sca, m_sc as m_sc_m: + dil.got_wormhole_versions({"can-dilate": ["1"]}) # that should create the Manager self.assertEqual(ml.mock_calls, [mock.call(send, "us", transit_key, - None, reactor, eq, coop, False)]) + None, reactor, eq, coop, host_addr, False)]) + # and create subchannel0 + self.assertEqual(m_sc_m.mock_calls, + [mock.call(scid0, m, host_addr, peer_addr)]) # and tell it to start, and get wait-for-it-to-connect Deferred - self.assertEqual(m.mock_calls, [mock.call.start(), + self.assertEqual(m.mock_calls, [mock.call.set_subchannel_zero(scid0, sc), + mock.call.start(), mock.call.when_first_connected(), ]) clear_mock_calls(m) self.assertNoResult(d1) self.assertNoResult(d2) - host_addr = _WormholeAddress() - m_wa = mock.patch("wormhole._dilation.manager._WormholeAddress", - return_value=host_addr) - peer_addr = object() - m_sca = mock.patch("wormhole._dilation.manager._SubchannelAddress", - return_value=peer_addr) ce = mock.Mock() m_ce = mock.patch("wormhole._dilation.manager.ControlEndpoint", return_value=ce) - sc = mock.Mock() - m_sc = mock.patch("wormhole._dilation.manager.SubChannel", - return_value=sc) - lep = object() m_sle = mock.patch("wormhole._dilation.manager.SubchannelListenerEndpoint", return_value=lep) - with m_wa, m_sca, m_ce as m_ce_m, m_sc as m_sc_m, m_sle as m_sle_m: + with m_ce as m_ce_m, m_sle as m_sle_m: wfc_d.callback(None) eq.flush_sync() - scid0 = b"\x00\x00\x00\x00" self.assertEqual(m_ce_m.mock_calls, [mock.call(peer_addr)]) - self.assertEqual(m_sc_m.mock_calls, - [mock.call(scid0, m, host_addr, peer_addr)]) self.assertEqual(ce.mock_calls, [mock.call._subchannel_zero_opened(sc)]) self.assertEqual(m_sle_m.mock_calls, [mock.call(m, host_addr)]) self.assertEqual(m.mock_calls, - [mock.call.set_subchannel_zero(scid0, sc), - mock.call.set_listener_endpoint(lep), + [mock.call.set_listener_endpoint(lep), ]) clear_mock_calls(m) @@ -166,6 +167,7 @@ class TestDilator(unittest.TestCase): dil, send, reactor, eq, clock, coop = make_dilator() dil._transit_key = b"key" d1 = dil.dilate() + host_addr = dil._host_addr self.assertNoResult(d1) pleasemsg = dict(type="please", side="them") dil.received_dilate(dict_to_bytes(pleasemsg)) @@ -176,14 +178,21 @@ class TestDilator(unittest.TestCase): alsoProvides(m, IDilationManager) m.when_first_connected.return_value = Deferred() + scid0 = b"\x00\x00\x00\x00" + sc = mock.Mock() + m_sc = mock.patch("wormhole._dilation.manager.SubChannel", + return_value=sc) + with mock.patch("wormhole._dilation.manager.Manager", return_value=m) as ml: with mock.patch("wormhole._dilation.manager.make_side", return_value="us"): - dil.got_wormhole_versions({"can-dilate": ["1"]}) + with m_sc: + dil.got_wormhole_versions({"can-dilate": ["1"]}) self.assertEqual(ml.mock_calls, [mock.call(send, "us", b"key", - None, reactor, eq, coop, False)]) - self.assertEqual(m.mock_calls, [mock.call.start(), + None, reactor, eq, coop, host_addr, False)]) + self.assertEqual(m.mock_calls, [mock.call.set_subchannel_zero(scid0, sc), + mock.call.start(), mock.call.rx_PLEASE(pleasemsg), mock.call.rx_HINTS(hintmsg), mock.call.when_first_connected()]) @@ -191,16 +200,24 @@ class TestDilator(unittest.TestCase): def test_transit_relay(self): dil, send, reactor, eq, clock, coop = make_dilator() dil._transit_key = b"key" + host_addr = dil._host_addr relay = object() d1 = dil.dilate(transit_relay_location=relay) self.assertNoResult(d1) + scid0 = b"\x00\x00\x00\x00" + sc = mock.Mock() + m_sc = mock.patch("wormhole._dilation.manager.SubChannel", + return_value=sc) + with mock.patch("wormhole._dilation.manager.Manager") as ml: with mock.patch("wormhole._dilation.manager.make_side", return_value="us"): - dil.got_wormhole_versions({"can-dilate": ["1"]}) + with m_sc: + dil.got_wormhole_versions({"can-dilate": ["1"]}) self.assertEqual(ml.mock_calls, [mock.call(send, "us", b"key", - relay, reactor, eq, coop, False), + relay, reactor, eq, coop, host_addr, False), + mock.call().set_subchannel_zero(scid0, sc), mock.call().start(), mock.call().when_first_connected()]) @@ -234,12 +251,11 @@ def make_manager(leader=True): h.Inbound = mock.Mock(return_value=h.inbound) h.outbound = mock.Mock() h.Outbound = mock.Mock(return_value=h.outbound) - h.hostaddr = object() + h.hostaddr = mock.Mock() + alsoProvides(h.hostaddr, IAddress) with mock.patch("wormhole._dilation.manager.Inbound", h.Inbound): with mock.patch("wormhole._dilation.manager.Outbound", h.Outbound): - with mock.patch("wormhole._dilation.manager._WormholeAddress", - return_value=h.hostaddr): - m = Manager(h.send, side, h.key, h.relay, h.reactor, h.eq, h.coop) + m = Manager(h.send, side, h.key, h.relay, h.reactor, h.eq, h.coop, h.hostaddr) return m, h