terminator: shut down Dilator after everything else stops

This makes w.stop() the right way to shut everything down including any
Dilator connections (in-progress, active, or in-shutdown).
This commit is contained in:
Brian Warner 2019-02-10 18:01:14 -08:00
parent 7f90999775
commit c27680b910
7 changed files with 101 additions and 43 deletions

View File

@ -81,8 +81,8 @@ class Boss(object):
self._A.wire(self._RC, self._C)
self._I.wire(self._C, self._L)
self._C.wire(self, self._A, self._N, self._K, self._I)
self._T.wire(self, self._RC, self._N, self._M)
self._D.wire(self._S)
self._T.wire(self, self._RC, self._N, self._M, self._D)
self._D.wire(self._S, self._T)
def _init_other_state(self):
self._did_start_code = False

View File

@ -5,9 +5,9 @@ from attr import attrs, attrib
from attr.validators import provides, instance_of, optional
from automat import MethodicalMachine
from zope.interface import implementer
from twisted.internet.defer import Deferred, inlineCallbacks, returnValue, succeed
from twisted.internet.defer import Deferred, inlineCallbacks, returnValue
from twisted.python import log
from .._interfaces import IDilator, IDilationManager, ISend
from .._interfaces import IDilator, IDilationManager, ISend, ITerminator
from ..util import dict_to_bytes, bytes_to_dict, bytes_to_hexstr
from ..observer import OneShotObserver
from .._key import derive_key
@ -486,8 +486,9 @@ class Dilator(object):
self._pending_inbound_dilate_messages = deque()
self._manager = None
def wire(self, sender):
def wire(self, sender, terminator):
self._S = ISend(sender)
self._T = ITerminator(terminator)
# this is the primary entry point, called when w.dilate() is invoked
def dilate(self, transit_relay_location=None):
@ -547,12 +548,18 @@ class Dilator(object):
endpoints = (control_ep, connect_ep, listen_ep)
returnValue(endpoints)
# Called by Terminator after everything else (mailbox, nameplate, server
# connection) has shut down. Expects to fire T.stoppedD() when Dilator is
# stopped too.
def stop(self):
if not self._started:
return succeed(None)
self._T.stoppedD()
return
if self._started:
self._manager.stop()
return self._manager.when_stopped()
# TODO: avoid Deferreds for control flow, hard to serialize
self._manager.when_stopped().addCallback(lambda _: self._T.stoppedD())
# TODO: tolerate multiple calls
# from Boss

View File

@ -246,7 +246,7 @@ class RendezvousConnector(object):
# internal
def _stopped(self, res):
self._T.stopped()
self._T.stoppedRC()
def _tx(self, mtype, **kwargs):
assert self._ws

View File

@ -15,15 +15,17 @@ class Terminator(object):
def __init__(self):
self._mood = None
def wire(self, boss, rendezvous_connector, nameplate, mailbox):
def wire(self, boss, rendezvous_connector, nameplate, mailbox, dilator):
self._B = _interfaces.IBoss(boss)
self._RC = _interfaces.IRendezvousConnector(rendezvous_connector)
self._N = _interfaces.INameplate(nameplate)
self._M = _interfaces.IMailbox(mailbox)
self._D = _interfaces.IDilator(dilator)
# 4*2-1 main states:
# (nm, m, n, 0): nameplate and/or mailbox is active
# 2*2-1+1 main states:
# (nm, m, n, d): nameplate and/or mailbox is active
# (o, ""): open (not-yet-closing), or trying to close
# after closing the mailbox-server connection, we stop Dilation
# S0 is special: we don't hang out in it
# TODO: rename o to 0, "" to 1. "S1" is special/terminal
@ -64,7 +66,11 @@ class Terminator(object):
# def S0(self): pass # unused
@m.state()
def S_stopping(self):
def S_stoppingRC(self):
pass # pragma: no cover
@m.state()
def S_stoppingD(self):
pass # pragma: no cover
@m.state()
@ -88,7 +94,11 @@ class Terminator(object):
# from RendezvousConnector
@m.input()
def stopped(self):
def stoppedRC(self):
pass
@m.input()
def stoppedD(self):
pass
@m.output()
@ -107,6 +117,10 @@ class Terminator(object):
def RC_stop(self):
self._RC.stop()
@m.output()
def stop_dilator(self):
self._D.stop()
@m.output()
def B_closed(self):
self._B.closed()
@ -115,20 +129,19 @@ class Terminator(object):
Snmo.upon(close, enter=Snm, outputs=[close_nameplate, close_mailbox])
Snmo.upon(nameplate_done, enter=Smo, outputs=[])
Sno.upon(close, enter=Sn, outputs=[close_nameplate, close_mailbox])
Sno.upon(close, enter=Sn, outputs=[close_nameplate])
Sno.upon(nameplate_done, enter=S0o, outputs=[])
Smo.upon(close, enter=Sm, outputs=[close_nameplate, close_mailbox])
Smo.upon(close, enter=Sm, outputs=[close_mailbox])
Smo.upon(mailbox_done, enter=S0o, outputs=[])
Snm.upon(mailbox_done, enter=Sn, outputs=[])
Snm.upon(nameplate_done, enter=Sm, outputs=[])
Sn.upon(nameplate_done, enter=S_stopping, outputs=[RC_stop])
S0o.upon(
close,
enter=S_stopping,
outputs=[close_nameplate, close_mailbox, ignore_mood_and_RC_stop])
Sm.upon(mailbox_done, enter=S_stopping, outputs=[RC_stop])
Sn.upon(nameplate_done, enter=S_stoppingRC, outputs=[RC_stop])
Sm.upon(mailbox_done, enter=S_stoppingRC, outputs=[RC_stop])
S0o.upon(close, enter=S_stoppingRC, outputs=[ignore_mood_and_RC_stop])
S_stopping.upon(stopped, enter=S_stopped, outputs=[B_closed])
S_stoppingRC.upon(stoppedRC, enter=S_stoppingD, outputs=[stop_dilator])
S_stoppingD.upon(stoppedD, enter=S_stopped, outputs=[B_closed])

View File

@ -3,11 +3,12 @@ import mock
from twisted.internet import reactor
from twisted.trial import unittest
from twisted.internet.task import Cooperator
from twisted.internet.defer import inlineCallbacks
from twisted.internet.defer import Deferred, inlineCallbacks
from zope.interface import implementer
from ... import _interfaces
from ...eventual import EventualQueue
from ..._interfaces import ITerminator
from ..._dilation import manager
from ..._dilation._noise import NoiseConnection
@ -27,6 +28,13 @@ class MySend(object):
self.rx_phase += 1
self.dilator.received_dilate(plaintext)
@implementer(ITerminator)
class FakeTerminator(object):
def __init__(self):
self.d = Deferred()
def stoppedD(self):
self.d.callback(None)
class Connect(unittest.TestCase):
@inlineCallbacks
def test1(self):
@ -41,14 +49,17 @@ class Connect(unittest.TestCase):
eq = EventualQueue(reactor)
cooperator = Cooperator(scheduler=eq.eventually)
t_left = FakeTerminator()
t_right = FakeTerminator()
d_left = manager.Dilator(reactor, eq, cooperator, no_listen=True)
d_left.wire(send_left)
d_left.wire(send_left, t_left)
d_left.got_key(key)
d_left.got_wormhole_versions({"can-dilate": ["1"]})
send_left.dilator = d_left
d_right = manager.Dilator(reactor, eq, cooperator)
d_right.wire(send_right)
d_right.wire(send_right, t_right)
d_right.got_key(key)
d_right.got_wormhole_versions({"can-dilate": ["1"]})
send_right.dilator = d_right
@ -69,8 +80,13 @@ class Connect(unittest.TestCase):
#control_ep_left.connect(
# we shut down with w.close(), which calls Dilator.stop(), which
# calls manager.stop()
yield d_left.stop()
yield d_right.stop()
# we normally shut down with w.close(), which calls Dilator.stop(),
# which calls Terminator.stoppedD(), which (after everything else is
# done) calls Boss.stopped
d_left.stop()
d_right.stop()
yield t_left.d
yield t_right.d

View File

@ -5,7 +5,7 @@ from twisted.internet.defer import Deferred
from twisted.internet.task import Clock, Cooperator
import mock
from ...eventual import EventualQueue
from ..._interfaces import ISend, IDilationManager
from ..._interfaces import ISend, IDilationManager, ITerminator
from ...util import dict_to_bytes
from ..._dilation import roles
from ..._dilation.encode import to_be4
@ -32,7 +32,9 @@ def make_dilator():
send = mock.Mock()
alsoProvides(send, ISend)
dil = Dilator(reactor, eq, coop)
dil.wire(send)
terminator = mock.Mock()
alsoProvides(terminator, ITerminator)
dil.wire(send, terminator)
return dil, send, reactor, eq, clock, coop

View File

@ -1220,7 +1220,8 @@ class Terminator(unittest.TestCase):
rc = Dummy("rc", events, IRendezvousConnector, "stop")
n = Dummy("n", events, INameplate, "close")
m = Dummy("m", events, IMailbox, "close")
t.wire(b, rc, n, m)
d = Dummy("d", events, IDilator, "stop")
t.wire(b, rc, n, m, d)
return t, b, rc, n, m, events
# there are three events, and we need to test all orderings of them
@ -1229,45 +1230,64 @@ class Terminator(unittest.TestCase):
input_events = {
"mailbox": lambda: t.mailbox_done(),
"nameplate": lambda: t.nameplate_done(),
"close": lambda: t.close("happy"),
"rc": lambda: t.close("happy"),
}
close_events = [
("n.close", ),
("m.close", "happy"),
]
if ev1 == "mailbox":
close_events.remove(("m.close", "happy"))
elif ev1 == "nameplate":
close_events.remove(("n.close",))
input_events[ev1]()
expected = []
if ev1 == "close":
if ev1 == "rc":
expected.extend(close_events)
self.assertEqual(events, expected)
events[:] = []
if ev2 == "mailbox":
close_events.remove(("m.close", "happy"))
elif ev2 == "nameplate":
close_events.remove(("n.close",))
input_events[ev2]()
expected = []
if ev2 == "close":
if ev2 == "rc":
expected.extend(close_events)
self.assertEqual(events, expected)
events[:] = []
if ev3 == "mailbox":
close_events.remove(("m.close", "happy"))
elif ev3 == "nameplate":
close_events.remove(("n.close",))
input_events[ev3]()
expected = []
if ev3 == "close":
if ev3 == "rc":
expected.extend(close_events)
expected.append(("rc.stop", ))
self.assertEqual(events, expected)
events[:] = []
t.stopped()
t.stoppedRC()
self.assertEqual(events, [("d.stop", )])
events[:] = []
t.stoppedD()
self.assertEqual(events, [("b.closed", )])
def test_terminate(self):
self._do_test("mailbox", "nameplate", "close")
self._do_test("mailbox", "close", "nameplate")
self._do_test("nameplate", "mailbox", "close")
self._do_test("nameplate", "close", "mailbox")
self._do_test("close", "nameplate", "mailbox")
self._do_test("close", "mailbox", "nameplate")
self._do_test("mailbox", "nameplate", "rc")
self._do_test("mailbox", "rc", "nameplate")
self._do_test("nameplate", "mailbox", "rc")
self._do_test("nameplate", "rc", "mailbox")
self._do_test("rc", "nameplate", "mailbox")
self._do_test("rc", "mailbox", "nameplate")
# TODO: test moods