diff --git a/src/wormhole/test/dilate/test_manager.py b/src/wormhole/test/dilate/test_manager.py index 309c1fd..3fdd051 100644 --- a/src/wormhole/test/dilate/test_manager.py +++ b/src/wormhole/test/dilate/test_manager.py @@ -1,12 +1,11 @@ from __future__ import print_function, unicode_literals from zope.interface import alsoProvides from twisted.trial import unittest -from twisted.internet.defer import Deferred from twisted.internet.task import Clock, Cooperator -from twisted.internet.interfaces import IAddress +from twisted.internet.interfaces import IStreamServerEndpoint import mock from ...eventual import EventualQueue -from ..._interfaces import ISend, IDilationManager, ITerminator +from ..._interfaces import ISend, ITerminator, ISubChannel from ...util import dict_to_bytes from ..._dilation import roles from ..._dilation.manager import (Dilator, Manager, make_side, @@ -15,35 +14,57 @@ from ..._dilation.manager import (Dilator, Manager, make_side, UnexpectedKCM, UnknownMessageType) from ..._dilation.connection import Open, Data, Close, Ack, KCM, Ping, Pong +from ..._dilation.subchannel import _SubchannelAddress from .common import clear_mock_calls +class Holder(): + pass def make_dilator(): - reactor = object() - clock = Clock() - eq = EventualQueue(clock) + h = Holder() + h.reactor = object() + h.clock = Clock() + h.eq = EventualQueue(h.clock) term = mock.Mock(side_effect=lambda: True) # one write per Eventual tick def term_factory(): return term - coop = Cooperator(terminationPredicateFactory=term_factory, - scheduler=eq.eventually) - send = mock.Mock() - alsoProvides(send, ISend) - dil = Dilator(reactor, eq, coop) - terminator = mock.Mock() - alsoProvides(terminator, ITerminator) - dil.wire(send, terminator) - return dil, send, reactor, eq, clock, coop + h.coop = Cooperator(terminationPredicateFactory=term_factory, + scheduler=h.eq.eventually) + h.send = mock.Mock() + alsoProvides(h.send, ISend) + dil = Dilator(h.reactor, h.eq, h.coop) + h.terminator = mock.Mock() + alsoProvides(h.terminator, ITerminator) + dil.wire(h.send, h.terminator) + return dil, h class TestDilator(unittest.TestCase): - def test_manager_and_endpoints(self): - dil, send, reactor, eq, clock, coop = make_dilator() - d1 = dil.dilate() - d2 = dil.dilate() - self.assertNoResult(d1) - self.assertNoResult(d2) + # we should test the interleavings between: + # * application calls w.dilate() and gets back endpoints + # * wormhole gets: dilation key, VERSION, 0-n dilation messages + + def test_dilate_first(self): + (dil, h) = make_dilator() + side = object() + m = mock.Mock() + eps = object() + m.get_endpoints = mock.Mock(return_value=eps) + mm = mock.Mock(side_effect=[m]) + with mock.patch("wormhole._dilation.manager.Manager", mm), \ + mock.patch("wormhole._dilation.manager.make_side", + return_value=side): + eps1 = dil.dilate() + eps2 = dil.dilate() + self.assertIdentical(eps1, eps) + self.assertIdentical(eps1, eps2) + self.assertEqual(mm.mock_calls, [mock.call(h.send, side, None, + h.reactor, h.eq, h.coop, False)]) + + self.assertEqual(m.mock_calls, [mock.call.get_endpoints(), + mock.call.get_endpoints()]) + clear_mock_calls(m) key = b"key" transit_key = object() @@ -51,207 +72,108 @@ class TestDilator(unittest.TestCase): return_value=transit_key) as dk: dil.got_key(key) self.assertEqual(dk.mock_calls, [mock.call(key, b"dilation-v1", 32)]) - self.assertIdentical(dil._transit_key, transit_key) - self.assertNoResult(d1) - self.assertNoResult(d2) + self.assertEqual(m.mock_calls, [mock.call.got_dilation_key(transit_key)]) + clear_mock_calls(m) - host_addr = dil._host_addr + wv = object() + dil.got_wormhole_versions(wv) + self.assertEqual(m.mock_calls, [mock.call.got_wormhole_versions(wv)]) + clear_mock_calls(m) - peer_addr = object() - m_sca = mock.patch("wormhole._dilation.manager._SubchannelAddress", - return_value=peer_addr) - sc = mock.Mock() - m_sc = mock.patch("wormhole._dilation.manager.SubChannel", - return_value=sc) - ce = mock.Mock() - m_ce = mock.patch("wormhole._dilation.manager.ControlEndpoint", - return_value=ce) - lep = object() - m_sle = mock.patch("wormhole._dilation.manager.SubchannelListenerEndpoint", - return_value=lep) - scid0 = 0 - - m = mock.Mock() - alsoProvides(m, IDilationManager) - m.when_first_connected.return_value = wfc_d = Deferred() - with mock.patch("wormhole._dilation.manager.Manager", - return_value=m) as ml: - with mock.patch("wormhole._dilation.manager.make_side", - return_value="us"): - with m_sca, m_sc as m_sc_m, m_ce as m_ce_m, m_sle as m_sle_m: - dil.got_wormhole_versions({"can-dilate": ["1"]}) - # that should create the Manager - self.assertEqual(ml.mock_calls, [mock.call(send, "us", transit_key, - None, reactor, eq, coop, host_addr, False)]) - # and the three endpoints - self.assertEqual(m_ce_m.mock_calls, [mock.call(peer_addr)]) - self.assertEqual(ce.mock_calls, [mock.call._subchannel_zero_opened(sc)]) - self.assertEqual(m_sle_m.mock_calls, [mock.call(m, host_addr)]) - # and create subchannel0 - self.assertEqual(m_sc_m.mock_calls, - [mock.call(scid0, m, host_addr, peer_addr)]) - # and tell it to start, and get wait-for-it-to-connect Deferred - self.assertEqual(m.mock_calls, [mock.call.set_subchannel_zero(scid0, sc), - mock.call.set_listener_endpoint(lep), - mock.call.start(), - mock.call.when_first_connected(), + dm1 = object() + dm2 = object() + dil.received_dilate(dm1) + dil.received_dilate(dm2) + self.assertEqual(m.mock_calls, [mock.call.received_dilation_message(dm1), + mock.call.received_dilation_message(dm2), ]) clear_mock_calls(m) - self.assertNoResult(d1) - self.assertNoResult(d2) - wfc_d.callback(None) - eq.flush_sync() - self.assertEqual(m.mock_calls, []) + stopped_d = mock.Mock() + m.when_stopped = mock.Mock(return_value=stopped_d) + dil.stop() + self.assertEqual(m.mock_calls, [mock.call.stop(), + mock.call.when_stopped(), + ]) + + def test_dilate_later(self): + (dil, h) = make_dilator() + m = mock.Mock() + mm = mock.Mock(side_effect=[m]) + + key = b"key" + transit_key = object() + with mock.patch("wormhole._dilation.manager.derive_key", + return_value=transit_key) as dk: + dil.got_key(key) + self.assertEqual(dk.mock_calls, [mock.call(key, b"dilation-v1", 32)]) + + wv = object() + dil.got_wormhole_versions(wv) + + dm1 = object() + dil.received_dilate(dm1) + + self.assertEqual(mm.mock_calls, []) + + with mock.patch("wormhole._dilation.manager.Manager", mm): + dil.dilate() + self.assertEqual(m.mock_calls, [mock.call.got_dilation_key(transit_key), + mock.call.got_wormhole_versions(wv), + mock.call.received_dilation_message(dm1), + mock.call.get_endpoints(), + ]) clear_mock_calls(m) - eps = self.successResultOf(d1) - self.assertEqual(eps, self.successResultOf(d2)) - d3 = dil.dilate() - eq.flush_sync() - self.assertEqual(eps, self.successResultOf(d3)) + dm2 = object() + dil.received_dilate(dm2) + self.assertEqual(m.mock_calls, [mock.call.received_dilation_message(dm2), + ]) - # all subsequent DILATE-n messages should get passed to the manager - self.assertEqual(m.mock_calls, []) - pleasemsg = dict(type="please", side="them") - dil.received_dilate(dict_to_bytes(pleasemsg)) - self.assertEqual(m.mock_calls, [mock.call.rx_PLEASE(pleasemsg)]) - clear_mock_calls(m) - - hintmsg = dict(type="connection-hints") - dil.received_dilate(dict_to_bytes(hintmsg)) - self.assertEqual(m.mock_calls, [mock.call.rx_HINTS(hintmsg)]) - clear_mock_calls(m) - - # we're nominally the LEADER, and the leader would not normally be - # receiving a RECONNECT, but since we've mocked out the Manager it - # won't notice - dil.received_dilate(dict_to_bytes(dict(type="reconnect"))) - self.assertEqual(m.mock_calls, [mock.call.rx_RECONNECT()]) - clear_mock_calls(m) - - dil.received_dilate(dict_to_bytes(dict(type="reconnecting"))) - self.assertEqual(m.mock_calls, [mock.call.rx_RECONNECTING()]) - clear_mock_calls(m) - - dil.received_dilate(dict_to_bytes(dict(type="unknown"))) - self.assertEqual(m.mock_calls, []) - self.flushLoggedErrors(UnknownDilationMessageType) + def test_stop_early(self): + (dil, h) = make_dilator() + # we stop before w.dilate(), so there is no Manager to stop + dil.stop() + self.assertEqual(h.terminator.mock_calls, [mock.call.stoppedD()]) def test_peer_cannot_dilate(self): - dil, send, reactor, eq, clock, coop = make_dilator() + (dil, h) = make_dilator() eps = dil.dilate() dil.got_key(b"\x01" * 32) dil.got_wormhole_versions({}) # missing "can-dilate" d = eps.connect.connect(None) - eq.flush_sync() + h.eq.flush_sync() self.failureResultOf(d).check(OldPeerCannotDilateError) def test_disjoint_versions(self): - dil, send, reactor, eq, clock, coop = make_dilator() + (dil, h) = make_dilator() eps = dil.dilate() dil.got_key(b"\x01" * 32) dil.got_wormhole_versions({"can-dilate": [-1]}) d = eps.connect.connect(None) - eq.flush_sync() + h.eq.flush_sync() self.failureResultOf(d).check(OldPeerCannotDilateError) - def test_early_dilate_messages(self): - dil, send, reactor, eq, clock, coop = make_dilator() - dil._transit_key = b"key" - d1 = dil.dilate() - host_addr = dil._host_addr - self.assertNoResult(d1) - pleasemsg = dict(type="please", side="them") - dil.received_dilate(dict_to_bytes(pleasemsg)) - hintmsg = dict(type="connection-hints") - dil.received_dilate(dict_to_bytes(hintmsg)) - - m = mock.Mock() - alsoProvides(m, IDilationManager) - m.when_first_connected.return_value = Deferred() - - scid0 = 0 - sc = mock.Mock() - m_sc = mock.patch("wormhole._dilation.manager.SubChannel", - return_value=sc) - peer_addr = object() - m_sca = mock.patch("wormhole._dilation.manager._SubchannelAddress", - return_value=peer_addr) - ce = mock.Mock() - m_ce = mock.patch("wormhole._dilation.manager.ControlEndpoint", - return_value=ce) - lep = object() - m_sle = mock.patch("wormhole._dilation.manager.SubchannelListenerEndpoint", - return_value=lep) - - with mock.patch("wormhole._dilation.manager.Manager", - return_value=m) as ml: - with mock.patch("wormhole._dilation.manager.make_side", - return_value="us"): - with m_sca, m_sc, m_ce, m_sle: - dil.got_wormhole_versions({"can-dilate": ["1"]}) - self.assertEqual(ml.mock_calls, [mock.call(send, "us", b"key", - None, reactor, eq, coop, host_addr, False)]) - self.assertEqual(m.mock_calls, [mock.call.set_subchannel_zero(scid0, sc), - mock.call.set_listener_endpoint(lep), - mock.call.start(), - mock.call.rx_PLEASE(pleasemsg), - mock.call.rx_HINTS(hintmsg), - mock.call.when_first_connected()]) - def test_transit_relay(self): - dil, send, reactor, eq, clock, coop = make_dilator() - dil._transit_key = b"key" - host_addr = dil._host_addr - relay = object() - d1 = dil.dilate(transit_relay_location=relay) - self.assertNoResult(d1) - + (dil, h) = make_dilator() + transit_relay_location = object() + side = object() m = mock.Mock() - alsoProvides(m, IDilationManager) - m.when_first_connected.return_value = Deferred() - - scid0 = 0 - sc = mock.Mock() - m_sc = mock.patch("wormhole._dilation.manager.SubChannel", - return_value=sc) - peer_addr = object() - m_sca = mock.patch("wormhole._dilation.manager._SubchannelAddress", - return_value=peer_addr) - ce = mock.Mock() - m_ce = mock.patch("wormhole._dilation.manager.ControlEndpoint", - return_value=ce) - lep = object() - m_sle = mock.patch("wormhole._dilation.manager.SubchannelListenerEndpoint", - return_value=lep) - - with mock.patch("wormhole._dilation.manager.Manager", - return_value=m) as ml: - with mock.patch("wormhole._dilation.manager.make_side", - return_value="us"): - with m_sca, m_sc, m_ce, m_sle: - dil.got_wormhole_versions({"can-dilate": ["1"]}) - - self.assertEqual(ml.mock_calls, [ - mock.call(send, "us", b"key", - relay, reactor, eq, coop, host_addr, False)]) - self.assertEqual(m.mock_calls, [ - mock.call.set_subchannel_zero(scid0, sc), - mock.call.set_listener_endpoint(lep), - mock.call.start(), - mock.call.when_first_connected()]) - + mm = mock.Mock(side_effect=[m]) + with mock.patch("wormhole._dilation.manager.Manager", mm), \ + mock.patch("wormhole._dilation.manager.make_side", + return_value=side): + dil.dilate(transit_relay_location) + self.assertEqual(mm.mock_calls, [mock.call(h.send, side, transit_relay_location, + h.reactor, h.eq, h.coop, False)]) LEADER = "ff3456abcdef" FOLLOWER = "123456abcdef" def make_manager(leader=True): - class Holder: - pass h = Holder() h.send = mock.Mock() alsoProvides(h.send, ISend) @@ -274,8 +196,16 @@ def make_manager(leader=True): h.Inbound = mock.Mock(return_value=h.inbound) h.outbound = mock.Mock() h.Outbound = mock.Mock(return_value=h.outbound) + h.sc0 = mock.Mock() + alsoProvides(h.sc0, ISubChannel) + h.SubChannel = mock.Mock(return_value=h.sc0) + h.listen_ep = mock.Mock() + alsoProvides(h.listen_ep, IStreamServerEndpoint) with mock.patch("wormhole._dilation.manager.Inbound", h.Inbound), \ - mock.patch("wormhole._dilation.manager.Outbound", h.Outbound): + mock.patch("wormhole._dilation.manager.Outbound", h.Outbound), \ + mock.patch("wormhole._dilation.manager.SubChannel", h.SubChannel), \ + mock.patch("wormhole._dilation.manager.SubchannelListenerEndpoint", + return_value=h.listen_ep): m = Manager(h.send, side, h.relay, h.reactor, h.eq, h.coop) h.hostaddr = m._host_addr m.got_dilation_key(h.key) @@ -296,16 +226,26 @@ class TestManager(unittest.TestCase): self.assertEqual(h.send.mock_calls, []) self.assertEqual(h.Inbound.mock_calls, [mock.call(m, h.hostaddr)]) self.assertEqual(h.Outbound.mock_calls, [mock.call(m, h.coop)]) + scid0 = 0 + sc0_peer_addr = _SubchannelAddress(scid0) + self.assertEqual(h.SubChannel.mock_calls, [ + mock.call(scid0, m, m._host_addr, sc0_peer_addr), + ]) + self.assertEqual(h.inbound.mock_calls, [ + mock.call.set_subchannel_zero(scid0, h.sc0), + mock.call.set_listener_endpoint(h.listen_ep) + ]) + clear_mock_calls(h.inbound) - m.start() + m.got_wormhole_versions({"can-dilate": ["1"]}) self.assertEqual(h.send.mock_calls, [ mock.call.send("dilate-0", dict_to_bytes({"type": "please", "side": LEADER})) ]) clear_mock_calls(h.send) - wfc_d = m.when_first_connected() - self.assertNoResult(wfc_d) + listen_d = m.get_endpoints().listen.listen(None) + self.assertNoResult(listen_d) # ignore early hints m.rx_HINTS({}) @@ -327,7 +267,7 @@ class TestManager(unittest.TestCase): self.assertEqual(c.mock_calls, [mock.call.start()]) clear_mock_calls(connector, c) - self.assertNoResult(wfc_d) + self.assertNoResult(listen_d) # now any inbound hints should get passed to our Connector with mock.patch("wormhole._dilation.manager.parse_hint", @@ -347,7 +287,7 @@ class TestManager(unittest.TestCase): clear_mock_calls(h.send) # the first successful connection fires when_first_connected(), so - # the Dilator can create and return the endpoints + # the endpoints can activate c1 = mock.Mock() m.connector_connection_made(c1) @@ -355,23 +295,6 @@ class TestManager(unittest.TestCase): self.assertEqual(h.outbound.mock_calls, [mock.call.use_connection(c1)]) clear_mock_calls(h.inbound, h.outbound) - h.eq.flush_sync() - self.successResultOf(wfc_d) # fires with None - wfc_d2 = m.when_first_connected() - h.eq.flush_sync() - self.successResultOf(wfc_d2) - - scid0 = 0 - sc0 = mock.Mock() - m.set_subchannel_zero(scid0, sc0) - listen_ep = mock.Mock() - m.set_listener_endpoint(listen_ep) - self.assertEqual(h.inbound.mock_calls, [ - mock.call.set_subchannel_zero(scid0, sc0), - mock.call.set_listener_endpoint(listen_ep), - ]) - clear_mock_calls(h.inbound) - # the Leader making a new outbound channel should get scid=1 scid1 = 1 self.assertEqual(m.allocate_subchannel_id(), scid1) @@ -518,12 +441,13 @@ class TestManager(unittest.TestCase): def test_follower(self): m, h = make_manager(leader=False) - m.start() + m.got_wormhole_versions({"can-dilate": ["1"]}) self.assertEqual(h.send.mock_calls, [ mock.call.send("dilate-0", dict_to_bytes({"type": "please", "side": FOLLOWER})) ]) clear_mock_calls(h.send) + clear_mock_calls(h.inbound) c = mock.Mock() connector = mock.Mock(return_value=c) @@ -653,6 +577,7 @@ class TestManager(unittest.TestCase): def test_subchannel(self): m, h = make_manager(leader=True) + clear_mock_calls(h.inbound) sc = object() m.subchannel_pauseProducing(sc) @@ -689,3 +614,14 @@ class TestManager(unittest.TestCase): self.assertEqual(h.outbound.mock_calls, [ mock.call.subchannel_closed(4, sc)]) clear_mock_calls(h.inbound, h.outbound) + + def test_unknown_message(self): + # receive a PLEASE with the same side as us: shouldn't happen + m, h = make_manager(leader=True) + m.start() + + m.received_dilation_message(dict_to_bytes(dict(type="unknown"))) + self.flushLoggedErrors(UnknownDilationMessageType) + + # TODO: test transit relay is used +