update subchannel state machine for half-close

also handle open-but-not-yet-connected subchannels better
This commit is contained in:
Brian Warner 2019-08-04 11:51:33 -07:00
parent b233763082
commit 327e72e6ac
2 changed files with 183 additions and 27 deletions

View File

@ -6,6 +6,7 @@ from zope.interface import implementer
from twisted.internet.defer import inlineCallbacks, returnValue
from twisted.internet.interfaces import (ITransport, IProducer, IConsumer,
IAddress, IListeningPort,
IHalfCloseableProtocol,
IStreamClientEndpoint,
IStreamServerEndpoint)
from twisted.internet.error import ConnectionDone
@ -87,11 +88,29 @@ class SubChannel(object):
# self._pending_outbound = {}
# self._processed = set()
self._protocol = None
self._pending_dataReceived = []
self._pending_connectionLost = (False, None)
self._pending_remote_data = []
self._pending_remote_close = False
@m.state(initial=True)
def open(self):
def unconnected(self):
pass # pragma: no cover
# once we get the IProtocol, it's either a IHalfCloseableProtocol, or it
# can only be fully closed
@m.state()
def open_half(self):
pass # pragma: no cover
@m.state()
def read_closed():
pass # pragma: no cover
@m.state()
def write_closed():
pass # pragma: no cover
@m.state()
def open_full(self):
pass # pragma: no cover
@m.state()
@ -102,6 +121,14 @@ class SubChannel(object):
def closed():
pass # pragma: no cover
@m.input()
def connect_protocol_half(self):
pass
@m.input()
def connect_protocol_full(self):
pass
@m.input()
def remote_data(self, data):
pass
@ -118,6 +145,14 @@ class SubChannel(object):
def local_close(self):
pass
@m.output()
def queue_remote_data(self, data):
self._pending_remote_data.append(data)
@m.output()
def queue_remote_close(self):
self._pending_remote_close = True
@m.output()
def send_data(self, data):
self._manager.send_data(self._scid, data)
@ -128,17 +163,24 @@ class SubChannel(object):
@m.output()
def signal_dataReceived(self, data):
if self._protocol:
self._protocol.dataReceived(data)
else:
self._pending_dataReceived.append(data)
assert self._protocol
self._protocol.dataReceived(data)
@m.output()
def signal_readConnectionLost(self):
IHalfCloseableProtocol(self._protocol).readConnectionLost()
@m.output()
def signal_writeConnectionLost(self):
IHalfCloseableProtocol(self._protocol).writeConnectionLost()
@m.output()
def signal_connectionLost(self):
if self._protocol:
self._protocol.connectionLost(ConnectionDone())
else:
self._pending_connectionLost = (True, ConnectionDone())
assert self._protocol
self._protocol.connectionLost(ConnectionDone())
@m.output()
def close_subchannel(self):
self._manager.subchannel_closed(self._scid, self)
# we're deleted momentarily
@ -151,14 +193,44 @@ class SubChannel(object):
raise AlreadyClosedError(
"loseConnection not allowed on closed subchannel")
# primary transitions
open.upon(remote_data, enter=open, outputs=[signal_dataReceived])
open.upon(local_data, enter=open, outputs=[send_data])
open.upon(remote_close, enter=closed, outputs=[send_close, signal_connectionLost])
open.upon(local_close, enter=closing, outputs=[send_close])
closing.upon(remote_data, enter=closing, outputs=[signal_dataReceived])
closing.upon(remote_close, enter=closed, outputs=[signal_connectionLost])
# stuff that arrives before we have a protocol connected
unconnected.upon(remote_data, enter=unconnected, outputs=[queue_remote_data])
unconnected.upon(remote_close, enter=unconnected, outputs=[queue_remote_close])
# IHalfCloseableProtocol flow
unconnected.upon(connect_protocol_half, enter=open_half, outputs=[])
open_half.upon(remote_data, enter=open_half, outputs=[signal_dataReceived])
open_half.upon(local_data, enter=open_half, outputs=[send_data])
# remote closes first
open_half.upon(remote_close, enter=read_closed, outputs=[signal_readConnectionLost])
read_closed.upon(local_data, enter=read_closed, outputs=[send_data])
read_closed.upon(local_close, enter=closed, outputs=[send_close,
close_subchannel,
# TODO: eventual-signal this?
signal_writeConnectionLost,
])
# local closes first
open_half.upon(local_close, enter=write_closed, outputs=[signal_writeConnectionLost,
send_close])
write_closed.upon(local_data, enter=write_closed, outputs=[error_closed_write])
write_closed.upon(remote_data, enter=write_closed, outputs=[signal_dataReceived])
write_closed.upon(remote_close, enter=closed, outputs=[close_subchannel,
signal_readConnectionLost,
])
# error cases
write_closed.upon(local_close, enter=write_closed, outputs=[error_closed_close])
# fully-closeable-only flow
unconnected.upon(connect_protocol_full, enter=open_full, outputs=[])
open_full.upon(remote_data, enter=open_full, outputs=[signal_dataReceived])
open_full.upon(local_data, enter=open_full, outputs=[send_data])
open_full.upon(remote_close, enter=closed, outputs=[send_close,
close_subchannel,
signal_connectionLost])
open_full.upon(local_close, enter=closing, outputs=[send_close])
closing.upon(remote_data, enter=closing, outputs=[signal_dataReceived])
closing.upon(remote_close, enter=closed, outputs=[close_subchannel,
signal_connectionLost])
# error cases
# we won't ever see an OPEN, since L4 will log+ignore those for us
closing.upon(local_data, enter=closing, outputs=[error_closed_write])
@ -170,15 +242,19 @@ class SubChannel(object):
def _set_protocol(self, protocol):
assert not self._protocol
self._protocol = protocol
if IHalfCloseableProtocol.providedBy(protocol):
self.connect_protocol_half()
else:
# move from UNCONNECTED to OPEN
self.connect_protocol_full();
def _deliver_queued_data(self):
if self._pending_dataReceived:
for data in self._pending_dataReceived:
self._protocol.dataReceived(data)
self._pending_dataReceived = []
cl, what = self._pending_connectionLost
if cl:
self._protocol.connectionLost(what)
for data in self._pending_remote_data:
self.remote_data(data)
del self._pending_remote_data
if self._pending_remote_close:
self.remote_close()
del self._pending_remote_close
# ITransport
def write(self, data):

View File

@ -1,7 +1,8 @@
from __future__ import print_function, unicode_literals
import mock
from zope.interface import directlyProvides
from twisted.trial import unittest
from twisted.internet.interfaces import ITransport
from twisted.internet.interfaces import ITransport, IHalfCloseableProtocol
from twisted.internet.error import ConnectionDone
from ..._dilation.subchannel import (Once, SubChannel,
_WormholeAddress, _SubchannelAddress,
@ -9,13 +10,15 @@ from ..._dilation.subchannel import (Once, SubChannel,
from .common import mock_manager
def make_sc(set_protocol=True):
def make_sc(set_protocol=True, half_closeable=False):
scid = 4
hostaddr = _WormholeAddress()
peeraddr = _SubchannelAddress(scid)
m = mock_manager()
sc = SubChannel(scid, m, hostaddr, peeraddr)
p = mock.Mock()
if half_closeable:
directlyProvides(p, IHalfCloseableProtocol)
if set_protocol:
sc._set_protocol(p)
return sc, m, scid, hostaddr, peeraddr, p
@ -145,3 +148,80 @@ class SubChannelAPI(unittest.TestCase):
# TODO: more, once this is implemented
sc.registerProducer(None, True)
sc.unregisterProducer()
class HalfCloseable(unittest.TestCase):
def test_create(self):
sc, m, scid, hostaddr, peeraddr, p = make_sc(half_closeable=True)
self.assert_(ITransport.providedBy(sc))
self.assertEqual(m.mock_calls, [])
self.assertIdentical(sc.getHost(), hostaddr)
self.assertIdentical(sc.getPeer(), peeraddr)
def test_local_close(self):
sc, m, scid, hostaddr, peeraddr, p = make_sc(half_closeable=True)
sc.write(b"data")
self.assertEqual(m.mock_calls, [mock.call.send_data(scid, b"data")])
m.mock_calls[:] = []
sc.writeSequence([b"more", b"data"])
self.assertEqual(m.mock_calls, [mock.call.send_data(scid, b"moredata")])
m.mock_calls[:] = []
sc.remote_data(b"inbound1")
self.assertEqual(p.mock_calls, [mock.call.dataReceived(b"inbound1")])
p.mock_calls[:] = []
# after a local close, we can't write anymore, but we can still
# receive data
sc.loseConnection()
self.assertEqual(m.mock_calls, [mock.call.send_close(scid)])
m.mock_calls[:] = []
self.assertEqual(p.mock_calls, [mock.call.writeConnectionLost()])
p.mock_calls[:] = []
with self.assertRaises(AlreadyClosedError) as e:
sc.write(b"data")
self.assertEqual(str(e.exception),
"write not allowed on closed subchannel")
with self.assertRaises(AlreadyClosedError) as e:
sc.loseConnection()
self.assertEqual(str(e.exception),
"loseConnection not allowed on closed subchannel")
sc.remote_data(b"inbound2")
self.assertEqual(p.mock_calls, [mock.call.dataReceived(b"inbound2")])
p.mock_calls[:] = []
# the remote end will finally shut down the connection
sc.remote_close()
self.assertEqual(m.mock_calls, [mock.call.subchannel_closed(scid, sc)])
self.assertEqual(p.mock_calls, [mock.call.readConnectionLost()])
def test_remote_close(self):
sc, m, scid, hostaddr, peeraddr, p = make_sc(half_closeable=True)
sc.write(b"data")
self.assertEqual(m.mock_calls, [mock.call.send_data(scid, b"data")])
m.mock_calls[:] = []
sc.remote_data(b"inbound1")
self.assertEqual(p.mock_calls, [mock.call.dataReceived(b"inbound1")])
p.mock_calls[:] = []
# after a remote close, we can still write data
sc.remote_close()
self.assertEqual(m.mock_calls, [])
self.assertEqual(p.mock_calls, [mock.call.readConnectionLost()])
p.mock_calls[:] = []
sc.write(b"out2")
self.assertEqual(m.mock_calls, [mock.call.send_data(scid, b"out2")])
m.mock_calls[:] = []
# and a local close will shutdown the connection
sc.loseConnection()
self.assertEqual(m.mock_calls, [mock.call.send_close(scid),
mock.call.subchannel_closed(scid, sc)])
self.assertEqual(p.mock_calls, [mock.call.writeConnectionLost()])