update subchannel state machine for half-close

also handle open-but-not-yet-connected subchannels better
This commit is contained in:
Brian Warner 2019-08-04 11:51:33 -07:00
parent b233763082
commit 327e72e6ac
2 changed files with 183 additions and 27 deletions

View File

@ -6,6 +6,7 @@ from zope.interface import implementer
from twisted.internet.defer import inlineCallbacks, returnValue from twisted.internet.defer import inlineCallbacks, returnValue
from twisted.internet.interfaces import (ITransport, IProducer, IConsumer, from twisted.internet.interfaces import (ITransport, IProducer, IConsumer,
IAddress, IListeningPort, IAddress, IListeningPort,
IHalfCloseableProtocol,
IStreamClientEndpoint, IStreamClientEndpoint,
IStreamServerEndpoint) IStreamServerEndpoint)
from twisted.internet.error import ConnectionDone from twisted.internet.error import ConnectionDone
@ -87,11 +88,29 @@ class SubChannel(object):
# self._pending_outbound = {} # self._pending_outbound = {}
# self._processed = set() # self._processed = set()
self._protocol = None self._protocol = None
self._pending_dataReceived = [] self._pending_remote_data = []
self._pending_connectionLost = (False, None) self._pending_remote_close = False
@m.state(initial=True) @m.state(initial=True)
def open(self): def unconnected(self):
pass # pragma: no cover
# once we get the IProtocol, it's either a IHalfCloseableProtocol, or it
# can only be fully closed
@m.state()
def open_half(self):
pass # pragma: no cover
@m.state()
def read_closed():
pass # pragma: no cover
@m.state()
def write_closed():
pass # pragma: no cover
@m.state()
def open_full(self):
pass # pragma: no cover pass # pragma: no cover
@m.state() @m.state()
@ -102,6 +121,14 @@ class SubChannel(object):
def closed(): def closed():
pass # pragma: no cover pass # pragma: no cover
@m.input()
def connect_protocol_half(self):
pass
@m.input()
def connect_protocol_full(self):
pass
@m.input() @m.input()
def remote_data(self, data): def remote_data(self, data):
pass pass
@ -118,6 +145,14 @@ class SubChannel(object):
def local_close(self): def local_close(self):
pass pass
@m.output()
def queue_remote_data(self, data):
self._pending_remote_data.append(data)
@m.output()
def queue_remote_close(self):
self._pending_remote_close = True
@m.output() @m.output()
def send_data(self, data): def send_data(self, data):
self._manager.send_data(self._scid, data) self._manager.send_data(self._scid, data)
@ -128,17 +163,24 @@ class SubChannel(object):
@m.output() @m.output()
def signal_dataReceived(self, data): def signal_dataReceived(self, data):
if self._protocol: assert self._protocol
self._protocol.dataReceived(data) self._protocol.dataReceived(data)
else:
self._pending_dataReceived.append(data) @m.output()
def signal_readConnectionLost(self):
IHalfCloseableProtocol(self._protocol).readConnectionLost()
@m.output()
def signal_writeConnectionLost(self):
IHalfCloseableProtocol(self._protocol).writeConnectionLost()
@m.output() @m.output()
def signal_connectionLost(self): def signal_connectionLost(self):
if self._protocol: assert self._protocol
self._protocol.connectionLost(ConnectionDone()) self._protocol.connectionLost(ConnectionDone())
else:
self._pending_connectionLost = (True, ConnectionDone()) @m.output()
def close_subchannel(self):
self._manager.subchannel_closed(self._scid, self) self._manager.subchannel_closed(self._scid, self)
# we're deleted momentarily # we're deleted momentarily
@ -151,14 +193,44 @@ class SubChannel(object):
raise AlreadyClosedError( raise AlreadyClosedError(
"loseConnection not allowed on closed subchannel") "loseConnection not allowed on closed subchannel")
# primary transitions # stuff that arrives before we have a protocol connected
open.upon(remote_data, enter=open, outputs=[signal_dataReceived]) unconnected.upon(remote_data, enter=unconnected, outputs=[queue_remote_data])
open.upon(local_data, enter=open, outputs=[send_data]) unconnected.upon(remote_close, enter=unconnected, outputs=[queue_remote_close])
open.upon(remote_close, enter=closed, outputs=[send_close, signal_connectionLost])
open.upon(local_close, enter=closing, outputs=[send_close])
closing.upon(remote_data, enter=closing, outputs=[signal_dataReceived])
closing.upon(remote_close, enter=closed, outputs=[signal_connectionLost])
# IHalfCloseableProtocol flow
unconnected.upon(connect_protocol_half, enter=open_half, outputs=[])
open_half.upon(remote_data, enter=open_half, outputs=[signal_dataReceived])
open_half.upon(local_data, enter=open_half, outputs=[send_data])
# remote closes first
open_half.upon(remote_close, enter=read_closed, outputs=[signal_readConnectionLost])
read_closed.upon(local_data, enter=read_closed, outputs=[send_data])
read_closed.upon(local_close, enter=closed, outputs=[send_close,
close_subchannel,
# TODO: eventual-signal this?
signal_writeConnectionLost,
])
# local closes first
open_half.upon(local_close, enter=write_closed, outputs=[signal_writeConnectionLost,
send_close])
write_closed.upon(local_data, enter=write_closed, outputs=[error_closed_write])
write_closed.upon(remote_data, enter=write_closed, outputs=[signal_dataReceived])
write_closed.upon(remote_close, enter=closed, outputs=[close_subchannel,
signal_readConnectionLost,
])
# error cases
write_closed.upon(local_close, enter=write_closed, outputs=[error_closed_close])
# fully-closeable-only flow
unconnected.upon(connect_protocol_full, enter=open_full, outputs=[])
open_full.upon(remote_data, enter=open_full, outputs=[signal_dataReceived])
open_full.upon(local_data, enter=open_full, outputs=[send_data])
open_full.upon(remote_close, enter=closed, outputs=[send_close,
close_subchannel,
signal_connectionLost])
open_full.upon(local_close, enter=closing, outputs=[send_close])
closing.upon(remote_data, enter=closing, outputs=[signal_dataReceived])
closing.upon(remote_close, enter=closed, outputs=[close_subchannel,
signal_connectionLost])
# error cases # error cases
# we won't ever see an OPEN, since L4 will log+ignore those for us # we won't ever see an OPEN, since L4 will log+ignore those for us
closing.upon(local_data, enter=closing, outputs=[error_closed_write]) closing.upon(local_data, enter=closing, outputs=[error_closed_write])
@ -170,15 +242,19 @@ class SubChannel(object):
def _set_protocol(self, protocol): def _set_protocol(self, protocol):
assert not self._protocol assert not self._protocol
self._protocol = protocol self._protocol = protocol
if IHalfCloseableProtocol.providedBy(protocol):
self.connect_protocol_half()
else:
# move from UNCONNECTED to OPEN
self.connect_protocol_full();
def _deliver_queued_data(self): def _deliver_queued_data(self):
if self._pending_dataReceived: for data in self._pending_remote_data:
for data in self._pending_dataReceived: self.remote_data(data)
self._protocol.dataReceived(data) del self._pending_remote_data
self._pending_dataReceived = [] if self._pending_remote_close:
cl, what = self._pending_connectionLost self.remote_close()
if cl: del self._pending_remote_close
self._protocol.connectionLost(what)
# ITransport # ITransport
def write(self, data): def write(self, data):

View File

@ -1,7 +1,8 @@
from __future__ import print_function, unicode_literals from __future__ import print_function, unicode_literals
import mock import mock
from zope.interface import directlyProvides
from twisted.trial import unittest from twisted.trial import unittest
from twisted.internet.interfaces import ITransport from twisted.internet.interfaces import ITransport, IHalfCloseableProtocol
from twisted.internet.error import ConnectionDone from twisted.internet.error import ConnectionDone
from ..._dilation.subchannel import (Once, SubChannel, from ..._dilation.subchannel import (Once, SubChannel,
_WormholeAddress, _SubchannelAddress, _WormholeAddress, _SubchannelAddress,
@ -9,13 +10,15 @@ from ..._dilation.subchannel import (Once, SubChannel,
from .common import mock_manager from .common import mock_manager
def make_sc(set_protocol=True): def make_sc(set_protocol=True, half_closeable=False):
scid = 4 scid = 4
hostaddr = _WormholeAddress() hostaddr = _WormholeAddress()
peeraddr = _SubchannelAddress(scid) peeraddr = _SubchannelAddress(scid)
m = mock_manager() m = mock_manager()
sc = SubChannel(scid, m, hostaddr, peeraddr) sc = SubChannel(scid, m, hostaddr, peeraddr)
p = mock.Mock() p = mock.Mock()
if half_closeable:
directlyProvides(p, IHalfCloseableProtocol)
if set_protocol: if set_protocol:
sc._set_protocol(p) sc._set_protocol(p)
return sc, m, scid, hostaddr, peeraddr, p return sc, m, scid, hostaddr, peeraddr, p
@ -145,3 +148,80 @@ class SubChannelAPI(unittest.TestCase):
# TODO: more, once this is implemented # TODO: more, once this is implemented
sc.registerProducer(None, True) sc.registerProducer(None, True)
sc.unregisterProducer() sc.unregisterProducer()
class HalfCloseable(unittest.TestCase):
def test_create(self):
sc, m, scid, hostaddr, peeraddr, p = make_sc(half_closeable=True)
self.assert_(ITransport.providedBy(sc))
self.assertEqual(m.mock_calls, [])
self.assertIdentical(sc.getHost(), hostaddr)
self.assertIdentical(sc.getPeer(), peeraddr)
def test_local_close(self):
sc, m, scid, hostaddr, peeraddr, p = make_sc(half_closeable=True)
sc.write(b"data")
self.assertEqual(m.mock_calls, [mock.call.send_data(scid, b"data")])
m.mock_calls[:] = []
sc.writeSequence([b"more", b"data"])
self.assertEqual(m.mock_calls, [mock.call.send_data(scid, b"moredata")])
m.mock_calls[:] = []
sc.remote_data(b"inbound1")
self.assertEqual(p.mock_calls, [mock.call.dataReceived(b"inbound1")])
p.mock_calls[:] = []
# after a local close, we can't write anymore, but we can still
# receive data
sc.loseConnection()
self.assertEqual(m.mock_calls, [mock.call.send_close(scid)])
m.mock_calls[:] = []
self.assertEqual(p.mock_calls, [mock.call.writeConnectionLost()])
p.mock_calls[:] = []
with self.assertRaises(AlreadyClosedError) as e:
sc.write(b"data")
self.assertEqual(str(e.exception),
"write not allowed on closed subchannel")
with self.assertRaises(AlreadyClosedError) as e:
sc.loseConnection()
self.assertEqual(str(e.exception),
"loseConnection not allowed on closed subchannel")
sc.remote_data(b"inbound2")
self.assertEqual(p.mock_calls, [mock.call.dataReceived(b"inbound2")])
p.mock_calls[:] = []
# the remote end will finally shut down the connection
sc.remote_close()
self.assertEqual(m.mock_calls, [mock.call.subchannel_closed(scid, sc)])
self.assertEqual(p.mock_calls, [mock.call.readConnectionLost()])
def test_remote_close(self):
sc, m, scid, hostaddr, peeraddr, p = make_sc(half_closeable=True)
sc.write(b"data")
self.assertEqual(m.mock_calls, [mock.call.send_data(scid, b"data")])
m.mock_calls[:] = []
sc.remote_data(b"inbound1")
self.assertEqual(p.mock_calls, [mock.call.dataReceived(b"inbound1")])
p.mock_calls[:] = []
# after a remote close, we can still write data
sc.remote_close()
self.assertEqual(m.mock_calls, [])
self.assertEqual(p.mock_calls, [mock.call.readConnectionLost()])
p.mock_calls[:] = []
sc.write(b"out2")
self.assertEqual(m.mock_calls, [mock.call.send_data(scid, b"out2")])
m.mock_calls[:] = []
# and a local close will shutdown the connection
sc.loseConnection()
self.assertEqual(m.mock_calls, [mock.call.send_close(scid),
mock.call.subchannel_closed(scid, sc)])
self.assertEqual(p.mock_calls, [mock.call.writeConnectionLost()])