diff --git a/src/wormhole/_boss.py b/src/wormhole/_boss.py index ce650b7..226735b 100644 --- a/src/wormhole/_boss.py +++ b/src/wormhole/_boss.py @@ -81,8 +81,8 @@ class Boss(object): self._A.wire(self._RC, self._C) self._I.wire(self._C, self._L) self._C.wire(self, self._A, self._N, self._K, self._I) - self._T.wire(self, self._RC, self._N, self._M) - self._D.wire(self._S) + self._T.wire(self, self._RC, self._N, self._M, self._D) + self._D.wire(self._S, self._T) def _init_other_state(self): self._did_start_code = False diff --git a/src/wormhole/_dilation/manager.py b/src/wormhole/_dilation/manager.py index 7d50380..f5724be 100644 --- a/src/wormhole/_dilation/manager.py +++ b/src/wormhole/_dilation/manager.py @@ -5,9 +5,9 @@ from attr import attrs, attrib from attr.validators import provides, instance_of, optional from automat import MethodicalMachine from zope.interface import implementer -from twisted.internet.defer import Deferred, inlineCallbacks, returnValue, succeed +from twisted.internet.defer import Deferred, inlineCallbacks, returnValue from twisted.python import log -from .._interfaces import IDilator, IDilationManager, ISend +from .._interfaces import IDilator, IDilationManager, ISend, ITerminator from ..util import dict_to_bytes, bytes_to_dict, bytes_to_hexstr from ..observer import OneShotObserver from .._key import derive_key @@ -486,8 +486,9 @@ class Dilator(object): self._pending_inbound_dilate_messages = deque() self._manager = None - def wire(self, sender): + def wire(self, sender, terminator): self._S = ISend(sender) + self._T = ITerminator(terminator) # this is the primary entry point, called when w.dilate() is invoked def dilate(self, transit_relay_location=None): @@ -547,12 +548,18 @@ class Dilator(object): endpoints = (control_ep, connect_ep, listen_ep) returnValue(endpoints) + # Called by Terminator after everything else (mailbox, nameplate, server + # connection) has shut down. Expects to fire T.stoppedD() when Dilator is + # stopped too. def stop(self): if not self._started: - return succeed(None) + self._T.stoppedD() + return if self._started: self._manager.stop() - return self._manager.when_stopped() + # TODO: avoid Deferreds for control flow, hard to serialize + self._manager.when_stopped().addCallback(lambda _: self._T.stoppedD()) + # TODO: tolerate multiple calls # from Boss diff --git a/src/wormhole/_rendezvous.py b/src/wormhole/_rendezvous.py index db02166..b27f318 100644 --- a/src/wormhole/_rendezvous.py +++ b/src/wormhole/_rendezvous.py @@ -246,7 +246,7 @@ class RendezvousConnector(object): # internal def _stopped(self, res): - self._T.stopped() + self._T.stoppedRC() def _tx(self, mtype, **kwargs): assert self._ws diff --git a/src/wormhole/_terminator.py b/src/wormhole/_terminator.py index fe4bdcb..c45f6d7 100644 --- a/src/wormhole/_terminator.py +++ b/src/wormhole/_terminator.py @@ -15,15 +15,17 @@ class Terminator(object): def __init__(self): self._mood = None - def wire(self, boss, rendezvous_connector, nameplate, mailbox): + def wire(self, boss, rendezvous_connector, nameplate, mailbox, dilator): self._B = _interfaces.IBoss(boss) self._RC = _interfaces.IRendezvousConnector(rendezvous_connector) self._N = _interfaces.INameplate(nameplate) self._M = _interfaces.IMailbox(mailbox) + self._D = _interfaces.IDilator(dilator) - # 4*2-1 main states: - # (nm, m, n, 0): nameplate and/or mailbox is active + # 2*2-1+1 main states: + # (nm, m, n, d): nameplate and/or mailbox is active # (o, ""): open (not-yet-closing), or trying to close + # after closing the mailbox-server connection, we stop Dilation # S0 is special: we don't hang out in it # TODO: rename o to 0, "" to 1. "S1" is special/terminal @@ -64,7 +66,11 @@ class Terminator(object): # def S0(self): pass # unused @m.state() - def S_stopping(self): + def S_stoppingRC(self): + pass # pragma: no cover + + @m.state() + def S_stoppingD(self): pass # pragma: no cover @m.state() @@ -88,7 +94,11 @@ class Terminator(object): # from RendezvousConnector @m.input() - def stopped(self): + def stoppedRC(self): + pass + + @m.input() + def stoppedD(self): pass @m.output() @@ -107,6 +117,10 @@ class Terminator(object): def RC_stop(self): self._RC.stop() + @m.output() + def stop_dilator(self): + self._D.stop() + @m.output() def B_closed(self): self._B.closed() @@ -115,20 +129,19 @@ class Terminator(object): Snmo.upon(close, enter=Snm, outputs=[close_nameplate, close_mailbox]) Snmo.upon(nameplate_done, enter=Smo, outputs=[]) - Sno.upon(close, enter=Sn, outputs=[close_nameplate, close_mailbox]) + Sno.upon(close, enter=Sn, outputs=[close_nameplate]) Sno.upon(nameplate_done, enter=S0o, outputs=[]) - Smo.upon(close, enter=Sm, outputs=[close_nameplate, close_mailbox]) + Smo.upon(close, enter=Sm, outputs=[close_mailbox]) Smo.upon(mailbox_done, enter=S0o, outputs=[]) Snm.upon(mailbox_done, enter=Sn, outputs=[]) Snm.upon(nameplate_done, enter=Sm, outputs=[]) - Sn.upon(nameplate_done, enter=S_stopping, outputs=[RC_stop]) - S0o.upon( - close, - enter=S_stopping, - outputs=[close_nameplate, close_mailbox, ignore_mood_and_RC_stop]) - Sm.upon(mailbox_done, enter=S_stopping, outputs=[RC_stop]) + Sn.upon(nameplate_done, enter=S_stoppingRC, outputs=[RC_stop]) + Sm.upon(mailbox_done, enter=S_stoppingRC, outputs=[RC_stop]) + S0o.upon(close, enter=S_stoppingRC, outputs=[ignore_mood_and_RC_stop]) - S_stopping.upon(stopped, enter=S_stopped, outputs=[B_closed]) + S_stoppingRC.upon(stoppedRC, enter=S_stoppingD, outputs=[stop_dilator]) + + S_stoppingD.upon(stoppedD, enter=S_stopped, outputs=[B_closed]) diff --git a/src/wormhole/test/dilate/test_connect.py b/src/wormhole/test/dilate/test_connect.py index a96330f..7a60400 100644 --- a/src/wormhole/test/dilate/test_connect.py +++ b/src/wormhole/test/dilate/test_connect.py @@ -3,11 +3,12 @@ import mock from twisted.internet import reactor from twisted.trial import unittest from twisted.internet.task import Cooperator -from twisted.internet.defer import inlineCallbacks +from twisted.internet.defer import Deferred, inlineCallbacks from zope.interface import implementer from ... import _interfaces from ...eventual import EventualQueue +from ..._interfaces import ITerminator from ..._dilation import manager from ..._dilation._noise import NoiseConnection @@ -27,6 +28,13 @@ class MySend(object): self.rx_phase += 1 self.dilator.received_dilate(plaintext) +@implementer(ITerminator) +class FakeTerminator(object): + def __init__(self): + self.d = Deferred() + def stoppedD(self): + self.d.callback(None) + class Connect(unittest.TestCase): @inlineCallbacks def test1(self): @@ -41,14 +49,17 @@ class Connect(unittest.TestCase): eq = EventualQueue(reactor) cooperator = Cooperator(scheduler=eq.eventually) + t_left = FakeTerminator() + t_right = FakeTerminator() + d_left = manager.Dilator(reactor, eq, cooperator, no_listen=True) - d_left.wire(send_left) + d_left.wire(send_left, t_left) d_left.got_key(key) d_left.got_wormhole_versions({"can-dilate": ["1"]}) send_left.dilator = d_left d_right = manager.Dilator(reactor, eq, cooperator) - d_right.wire(send_right) + d_right.wire(send_right, t_right) d_right.got_key(key) d_right.got_wormhole_versions({"can-dilate": ["1"]}) send_right.dilator = d_right @@ -69,8 +80,13 @@ class Connect(unittest.TestCase): #control_ep_left.connect( - # we shut down with w.close(), which calls Dilator.stop(), which - # calls manager.stop() - yield d_left.stop() - yield d_right.stop() + # we normally shut down with w.close(), which calls Dilator.stop(), + # which calls Terminator.stoppedD(), which (after everything else is + # done) calls Boss.stopped + d_left.stop() + d_right.stop() + + yield t_left.d + yield t_right.d + diff --git a/src/wormhole/test/dilate/test_manager.py b/src/wormhole/test/dilate/test_manager.py index ee31c5e..b8258f1 100644 --- a/src/wormhole/test/dilate/test_manager.py +++ b/src/wormhole/test/dilate/test_manager.py @@ -5,7 +5,7 @@ from twisted.internet.defer import Deferred from twisted.internet.task import Clock, Cooperator import mock from ...eventual import EventualQueue -from ..._interfaces import ISend, IDilationManager +from ..._interfaces import ISend, IDilationManager, ITerminator from ...util import dict_to_bytes from ..._dilation import roles from ..._dilation.encode import to_be4 @@ -32,7 +32,9 @@ def make_dilator(): send = mock.Mock() alsoProvides(send, ISend) dil = Dilator(reactor, eq, coop) - dil.wire(send) + terminator = mock.Mock() + alsoProvides(terminator, ITerminator) + dil.wire(send, terminator) return dil, send, reactor, eq, clock, coop diff --git a/src/wormhole/test/test_machines.py b/src/wormhole/test/test_machines.py index dff3fb0..9e417d6 100644 --- a/src/wormhole/test/test_machines.py +++ b/src/wormhole/test/test_machines.py @@ -1220,7 +1220,8 @@ class Terminator(unittest.TestCase): rc = Dummy("rc", events, IRendezvousConnector, "stop") n = Dummy("n", events, INameplate, "close") m = Dummy("m", events, IMailbox, "close") - t.wire(b, rc, n, m) + d = Dummy("d", events, IDilator, "stop") + t.wire(b, rc, n, m, d) return t, b, rc, n, m, events # there are three events, and we need to test all orderings of them @@ -1229,45 +1230,64 @@ class Terminator(unittest.TestCase): input_events = { "mailbox": lambda: t.mailbox_done(), "nameplate": lambda: t.nameplate_done(), - "close": lambda: t.close("happy"), + "rc": lambda: t.close("happy"), } close_events = [ ("n.close", ), ("m.close", "happy"), ] + if ev1 == "mailbox": + close_events.remove(("m.close", "happy")) + elif ev1 == "nameplate": + close_events.remove(("n.close",)) + input_events[ev1]() expected = [] - if ev1 == "close": + if ev1 == "rc": expected.extend(close_events) self.assertEqual(events, expected) events[:] = [] + if ev2 == "mailbox": + close_events.remove(("m.close", "happy")) + elif ev2 == "nameplate": + close_events.remove(("n.close",)) + input_events[ev2]() expected = [] - if ev2 == "close": + if ev2 == "rc": expected.extend(close_events) self.assertEqual(events, expected) events[:] = [] + if ev3 == "mailbox": + close_events.remove(("m.close", "happy")) + elif ev3 == "nameplate": + close_events.remove(("n.close",)) + input_events[ev3]() expected = [] - if ev3 == "close": + if ev3 == "rc": expected.extend(close_events) expected.append(("rc.stop", )) self.assertEqual(events, expected) events[:] = [] - t.stopped() + t.stoppedRC() + self.assertEqual(events, [("d.stop", )]) + events[:] = [] + + t.stoppedD() self.assertEqual(events, [("b.closed", )]) def test_terminate(self): - self._do_test("mailbox", "nameplate", "close") - self._do_test("mailbox", "close", "nameplate") - self._do_test("nameplate", "mailbox", "close") - self._do_test("nameplate", "close", "mailbox") - self._do_test("close", "nameplate", "mailbox") - self._do_test("close", "mailbox", "nameplate") + self._do_test("mailbox", "nameplate", "rc") + self._do_test("mailbox", "rc", "nameplate") + self._do_test("nameplate", "mailbox", "rc") + self._do_test("nameplate", "rc", "mailbox") + self._do_test("rc", "nameplate", "mailbox") + self._do_test("rc", "mailbox", "nameplate") # TODO: test moods