terminator: shut down Dilator after everything else stops

This makes w.stop() the right way to shut everything down including any
Dilator connections (in-progress, active, or in-shutdown).
This commit is contained in:
Brian Warner 2019-02-10 18:01:14 -08:00
parent 7f90999775
commit c27680b910
7 changed files with 101 additions and 43 deletions

View File

@ -81,8 +81,8 @@ class Boss(object):
self._A.wire(self._RC, self._C) self._A.wire(self._RC, self._C)
self._I.wire(self._C, self._L) self._I.wire(self._C, self._L)
self._C.wire(self, self._A, self._N, self._K, self._I) self._C.wire(self, self._A, self._N, self._K, self._I)
self._T.wire(self, self._RC, self._N, self._M) self._T.wire(self, self._RC, self._N, self._M, self._D)
self._D.wire(self._S) self._D.wire(self._S, self._T)
def _init_other_state(self): def _init_other_state(self):
self._did_start_code = False self._did_start_code = False

View File

@ -5,9 +5,9 @@ from attr import attrs, attrib
from attr.validators import provides, instance_of, optional from attr.validators import provides, instance_of, optional
from automat import MethodicalMachine from automat import MethodicalMachine
from zope.interface import implementer from zope.interface import implementer
from twisted.internet.defer import Deferred, inlineCallbacks, returnValue, succeed from twisted.internet.defer import Deferred, inlineCallbacks, returnValue
from twisted.python import log from twisted.python import log
from .._interfaces import IDilator, IDilationManager, ISend from .._interfaces import IDilator, IDilationManager, ISend, ITerminator
from ..util import dict_to_bytes, bytes_to_dict, bytes_to_hexstr from ..util import dict_to_bytes, bytes_to_dict, bytes_to_hexstr
from ..observer import OneShotObserver from ..observer import OneShotObserver
from .._key import derive_key from .._key import derive_key
@ -486,8 +486,9 @@ class Dilator(object):
self._pending_inbound_dilate_messages = deque() self._pending_inbound_dilate_messages = deque()
self._manager = None self._manager = None
def wire(self, sender): def wire(self, sender, terminator):
self._S = ISend(sender) self._S = ISend(sender)
self._T = ITerminator(terminator)
# this is the primary entry point, called when w.dilate() is invoked # this is the primary entry point, called when w.dilate() is invoked
def dilate(self, transit_relay_location=None): def dilate(self, transit_relay_location=None):
@ -547,12 +548,18 @@ class Dilator(object):
endpoints = (control_ep, connect_ep, listen_ep) endpoints = (control_ep, connect_ep, listen_ep)
returnValue(endpoints) returnValue(endpoints)
# Called by Terminator after everything else (mailbox, nameplate, server
# connection) has shut down. Expects to fire T.stoppedD() when Dilator is
# stopped too.
def stop(self): def stop(self):
if not self._started: if not self._started:
return succeed(None) self._T.stoppedD()
return
if self._started: if self._started:
self._manager.stop() self._manager.stop()
return self._manager.when_stopped() # TODO: avoid Deferreds for control flow, hard to serialize
self._manager.when_stopped().addCallback(lambda _: self._T.stoppedD())
# TODO: tolerate multiple calls
# from Boss # from Boss

View File

@ -246,7 +246,7 @@ class RendezvousConnector(object):
# internal # internal
def _stopped(self, res): def _stopped(self, res):
self._T.stopped() self._T.stoppedRC()
def _tx(self, mtype, **kwargs): def _tx(self, mtype, **kwargs):
assert self._ws assert self._ws

View File

@ -15,15 +15,17 @@ class Terminator(object):
def __init__(self): def __init__(self):
self._mood = None self._mood = None
def wire(self, boss, rendezvous_connector, nameplate, mailbox): def wire(self, boss, rendezvous_connector, nameplate, mailbox, dilator):
self._B = _interfaces.IBoss(boss) self._B = _interfaces.IBoss(boss)
self._RC = _interfaces.IRendezvousConnector(rendezvous_connector) self._RC = _interfaces.IRendezvousConnector(rendezvous_connector)
self._N = _interfaces.INameplate(nameplate) self._N = _interfaces.INameplate(nameplate)
self._M = _interfaces.IMailbox(mailbox) self._M = _interfaces.IMailbox(mailbox)
self._D = _interfaces.IDilator(dilator)
# 4*2-1 main states: # 2*2-1+1 main states:
# (nm, m, n, 0): nameplate and/or mailbox is active # (nm, m, n, d): nameplate and/or mailbox is active
# (o, ""): open (not-yet-closing), or trying to close # (o, ""): open (not-yet-closing), or trying to close
# after closing the mailbox-server connection, we stop Dilation
# S0 is special: we don't hang out in it # S0 is special: we don't hang out in it
# TODO: rename o to 0, "" to 1. "S1" is special/terminal # TODO: rename o to 0, "" to 1. "S1" is special/terminal
@ -64,7 +66,11 @@ class Terminator(object):
# def S0(self): pass # unused # def S0(self): pass # unused
@m.state() @m.state()
def S_stopping(self): def S_stoppingRC(self):
pass # pragma: no cover
@m.state()
def S_stoppingD(self):
pass # pragma: no cover pass # pragma: no cover
@m.state() @m.state()
@ -88,7 +94,11 @@ class Terminator(object):
# from RendezvousConnector # from RendezvousConnector
@m.input() @m.input()
def stopped(self): def stoppedRC(self):
pass
@m.input()
def stoppedD(self):
pass pass
@m.output() @m.output()
@ -107,6 +117,10 @@ class Terminator(object):
def RC_stop(self): def RC_stop(self):
self._RC.stop() self._RC.stop()
@m.output()
def stop_dilator(self):
self._D.stop()
@m.output() @m.output()
def B_closed(self): def B_closed(self):
self._B.closed() self._B.closed()
@ -115,20 +129,19 @@ class Terminator(object):
Snmo.upon(close, enter=Snm, outputs=[close_nameplate, close_mailbox]) Snmo.upon(close, enter=Snm, outputs=[close_nameplate, close_mailbox])
Snmo.upon(nameplate_done, enter=Smo, outputs=[]) Snmo.upon(nameplate_done, enter=Smo, outputs=[])
Sno.upon(close, enter=Sn, outputs=[close_nameplate, close_mailbox]) Sno.upon(close, enter=Sn, outputs=[close_nameplate])
Sno.upon(nameplate_done, enter=S0o, outputs=[]) Sno.upon(nameplate_done, enter=S0o, outputs=[])
Smo.upon(close, enter=Sm, outputs=[close_nameplate, close_mailbox]) Smo.upon(close, enter=Sm, outputs=[close_mailbox])
Smo.upon(mailbox_done, enter=S0o, outputs=[]) Smo.upon(mailbox_done, enter=S0o, outputs=[])
Snm.upon(mailbox_done, enter=Sn, outputs=[]) Snm.upon(mailbox_done, enter=Sn, outputs=[])
Snm.upon(nameplate_done, enter=Sm, outputs=[]) Snm.upon(nameplate_done, enter=Sm, outputs=[])
Sn.upon(nameplate_done, enter=S_stopping, outputs=[RC_stop]) Sn.upon(nameplate_done, enter=S_stoppingRC, outputs=[RC_stop])
S0o.upon( Sm.upon(mailbox_done, enter=S_stoppingRC, outputs=[RC_stop])
close, S0o.upon(close, enter=S_stoppingRC, outputs=[ignore_mood_and_RC_stop])
enter=S_stopping,
outputs=[close_nameplate, close_mailbox, ignore_mood_and_RC_stop])
Sm.upon(mailbox_done, enter=S_stopping, outputs=[RC_stop])
S_stopping.upon(stopped, enter=S_stopped, outputs=[B_closed]) S_stoppingRC.upon(stoppedRC, enter=S_stoppingD, outputs=[stop_dilator])
S_stoppingD.upon(stoppedD, enter=S_stopped, outputs=[B_closed])

View File

@ -3,11 +3,12 @@ import mock
from twisted.internet import reactor from twisted.internet import reactor
from twisted.trial import unittest from twisted.trial import unittest
from twisted.internet.task import Cooperator from twisted.internet.task import Cooperator
from twisted.internet.defer import inlineCallbacks from twisted.internet.defer import Deferred, inlineCallbacks
from zope.interface import implementer from zope.interface import implementer
from ... import _interfaces from ... import _interfaces
from ...eventual import EventualQueue from ...eventual import EventualQueue
from ..._interfaces import ITerminator
from ..._dilation import manager from ..._dilation import manager
from ..._dilation._noise import NoiseConnection from ..._dilation._noise import NoiseConnection
@ -27,6 +28,13 @@ class MySend(object):
self.rx_phase += 1 self.rx_phase += 1
self.dilator.received_dilate(plaintext) self.dilator.received_dilate(plaintext)
@implementer(ITerminator)
class FakeTerminator(object):
def __init__(self):
self.d = Deferred()
def stoppedD(self):
self.d.callback(None)
class Connect(unittest.TestCase): class Connect(unittest.TestCase):
@inlineCallbacks @inlineCallbacks
def test1(self): def test1(self):
@ -41,14 +49,17 @@ class Connect(unittest.TestCase):
eq = EventualQueue(reactor) eq = EventualQueue(reactor)
cooperator = Cooperator(scheduler=eq.eventually) cooperator = Cooperator(scheduler=eq.eventually)
t_left = FakeTerminator()
t_right = FakeTerminator()
d_left = manager.Dilator(reactor, eq, cooperator, no_listen=True) d_left = manager.Dilator(reactor, eq, cooperator, no_listen=True)
d_left.wire(send_left) d_left.wire(send_left, t_left)
d_left.got_key(key) d_left.got_key(key)
d_left.got_wormhole_versions({"can-dilate": ["1"]}) d_left.got_wormhole_versions({"can-dilate": ["1"]})
send_left.dilator = d_left send_left.dilator = d_left
d_right = manager.Dilator(reactor, eq, cooperator) d_right = manager.Dilator(reactor, eq, cooperator)
d_right.wire(send_right) d_right.wire(send_right, t_right)
d_right.got_key(key) d_right.got_key(key)
d_right.got_wormhole_versions({"can-dilate": ["1"]}) d_right.got_wormhole_versions({"can-dilate": ["1"]})
send_right.dilator = d_right send_right.dilator = d_right
@ -69,8 +80,13 @@ class Connect(unittest.TestCase):
#control_ep_left.connect( #control_ep_left.connect(
# we shut down with w.close(), which calls Dilator.stop(), which # we normally shut down with w.close(), which calls Dilator.stop(),
# calls manager.stop() # which calls Terminator.stoppedD(), which (after everything else is
yield d_left.stop() # done) calls Boss.stopped
yield d_right.stop() d_left.stop()
d_right.stop()
yield t_left.d
yield t_right.d

View File

@ -5,7 +5,7 @@ from twisted.internet.defer import Deferred
from twisted.internet.task import Clock, Cooperator from twisted.internet.task import Clock, Cooperator
import mock import mock
from ...eventual import EventualQueue from ...eventual import EventualQueue
from ..._interfaces import ISend, IDilationManager from ..._interfaces import ISend, IDilationManager, ITerminator
from ...util import dict_to_bytes from ...util import dict_to_bytes
from ..._dilation import roles from ..._dilation import roles
from ..._dilation.encode import to_be4 from ..._dilation.encode import to_be4
@ -32,7 +32,9 @@ def make_dilator():
send = mock.Mock() send = mock.Mock()
alsoProvides(send, ISend) alsoProvides(send, ISend)
dil = Dilator(reactor, eq, coop) dil = Dilator(reactor, eq, coop)
dil.wire(send) terminator = mock.Mock()
alsoProvides(terminator, ITerminator)
dil.wire(send, terminator)
return dil, send, reactor, eq, clock, coop return dil, send, reactor, eq, clock, coop

View File

@ -1220,7 +1220,8 @@ class Terminator(unittest.TestCase):
rc = Dummy("rc", events, IRendezvousConnector, "stop") rc = Dummy("rc", events, IRendezvousConnector, "stop")
n = Dummy("n", events, INameplate, "close") n = Dummy("n", events, INameplate, "close")
m = Dummy("m", events, IMailbox, "close") m = Dummy("m", events, IMailbox, "close")
t.wire(b, rc, n, m) d = Dummy("d", events, IDilator, "stop")
t.wire(b, rc, n, m, d)
return t, b, rc, n, m, events return t, b, rc, n, m, events
# there are three events, and we need to test all orderings of them # there are three events, and we need to test all orderings of them
@ -1229,45 +1230,64 @@ class Terminator(unittest.TestCase):
input_events = { input_events = {
"mailbox": lambda: t.mailbox_done(), "mailbox": lambda: t.mailbox_done(),
"nameplate": lambda: t.nameplate_done(), "nameplate": lambda: t.nameplate_done(),
"close": lambda: t.close("happy"), "rc": lambda: t.close("happy"),
} }
close_events = [ close_events = [
("n.close", ), ("n.close", ),
("m.close", "happy"), ("m.close", "happy"),
] ]
if ev1 == "mailbox":
close_events.remove(("m.close", "happy"))
elif ev1 == "nameplate":
close_events.remove(("n.close",))
input_events[ev1]() input_events[ev1]()
expected = [] expected = []
if ev1 == "close": if ev1 == "rc":
expected.extend(close_events) expected.extend(close_events)
self.assertEqual(events, expected) self.assertEqual(events, expected)
events[:] = [] events[:] = []
if ev2 == "mailbox":
close_events.remove(("m.close", "happy"))
elif ev2 == "nameplate":
close_events.remove(("n.close",))
input_events[ev2]() input_events[ev2]()
expected = [] expected = []
if ev2 == "close": if ev2 == "rc":
expected.extend(close_events) expected.extend(close_events)
self.assertEqual(events, expected) self.assertEqual(events, expected)
events[:] = [] events[:] = []
if ev3 == "mailbox":
close_events.remove(("m.close", "happy"))
elif ev3 == "nameplate":
close_events.remove(("n.close",))
input_events[ev3]() input_events[ev3]()
expected = [] expected = []
if ev3 == "close": if ev3 == "rc":
expected.extend(close_events) expected.extend(close_events)
expected.append(("rc.stop", )) expected.append(("rc.stop", ))
self.assertEqual(events, expected) self.assertEqual(events, expected)
events[:] = [] events[:] = []
t.stopped() t.stoppedRC()
self.assertEqual(events, [("d.stop", )])
events[:] = []
t.stoppedD()
self.assertEqual(events, [("b.closed", )]) self.assertEqual(events, [("b.closed", )])
def test_terminate(self): def test_terminate(self):
self._do_test("mailbox", "nameplate", "close") self._do_test("mailbox", "nameplate", "rc")
self._do_test("mailbox", "close", "nameplate") self._do_test("mailbox", "rc", "nameplate")
self._do_test("nameplate", "mailbox", "close") self._do_test("nameplate", "mailbox", "rc")
self._do_test("nameplate", "close", "mailbox") self._do_test("nameplate", "rc", "mailbox")
self._do_test("close", "nameplate", "mailbox") self._do_test("rc", "nameplate", "mailbox")
self._do_test("close", "mailbox", "nameplate") self._do_test("rc", "mailbox", "nameplate")
# TODO: test moods # TODO: test moods