update subchannel state machine for half-close
also handle open-but-not-yet-connected subchannels better
This commit is contained in:
parent
b233763082
commit
327e72e6ac
|
@ -6,6 +6,7 @@ from zope.interface import implementer
|
|||
from twisted.internet.defer import inlineCallbacks, returnValue
|
||||
from twisted.internet.interfaces import (ITransport, IProducer, IConsumer,
|
||||
IAddress, IListeningPort,
|
||||
IHalfCloseableProtocol,
|
||||
IStreamClientEndpoint,
|
||||
IStreamServerEndpoint)
|
||||
from twisted.internet.error import ConnectionDone
|
||||
|
@ -87,11 +88,29 @@ class SubChannel(object):
|
|||
# self._pending_outbound = {}
|
||||
# self._processed = set()
|
||||
self._protocol = None
|
||||
self._pending_dataReceived = []
|
||||
self._pending_connectionLost = (False, None)
|
||||
self._pending_remote_data = []
|
||||
self._pending_remote_close = False
|
||||
|
||||
@m.state(initial=True)
|
||||
def open(self):
|
||||
def unconnected(self):
|
||||
pass # pragma: no cover
|
||||
|
||||
# once we get the IProtocol, it's either a IHalfCloseableProtocol, or it
|
||||
# can only be fully closed
|
||||
@m.state()
|
||||
def open_half(self):
|
||||
pass # pragma: no cover
|
||||
|
||||
@m.state()
|
||||
def read_closed():
|
||||
pass # pragma: no cover
|
||||
|
||||
@m.state()
|
||||
def write_closed():
|
||||
pass # pragma: no cover
|
||||
|
||||
@m.state()
|
||||
def open_full(self):
|
||||
pass # pragma: no cover
|
||||
|
||||
@m.state()
|
||||
|
@ -102,6 +121,14 @@ class SubChannel(object):
|
|||
def closed():
|
||||
pass # pragma: no cover
|
||||
|
||||
@m.input()
|
||||
def connect_protocol_half(self):
|
||||
pass
|
||||
|
||||
@m.input()
|
||||
def connect_protocol_full(self):
|
||||
pass
|
||||
|
||||
@m.input()
|
||||
def remote_data(self, data):
|
||||
pass
|
||||
|
@ -118,6 +145,14 @@ class SubChannel(object):
|
|||
def local_close(self):
|
||||
pass
|
||||
|
||||
@m.output()
|
||||
def queue_remote_data(self, data):
|
||||
self._pending_remote_data.append(data)
|
||||
|
||||
@m.output()
|
||||
def queue_remote_close(self):
|
||||
self._pending_remote_close = True
|
||||
|
||||
@m.output()
|
||||
def send_data(self, data):
|
||||
self._manager.send_data(self._scid, data)
|
||||
|
@ -128,17 +163,24 @@ class SubChannel(object):
|
|||
|
||||
@m.output()
|
||||
def signal_dataReceived(self, data):
|
||||
if self._protocol:
|
||||
self._protocol.dataReceived(data)
|
||||
else:
|
||||
self._pending_dataReceived.append(data)
|
||||
assert self._protocol
|
||||
self._protocol.dataReceived(data)
|
||||
|
||||
@m.output()
|
||||
def signal_readConnectionLost(self):
|
||||
IHalfCloseableProtocol(self._protocol).readConnectionLost()
|
||||
|
||||
@m.output()
|
||||
def signal_writeConnectionLost(self):
|
||||
IHalfCloseableProtocol(self._protocol).writeConnectionLost()
|
||||
|
||||
@m.output()
|
||||
def signal_connectionLost(self):
|
||||
if self._protocol:
|
||||
self._protocol.connectionLost(ConnectionDone())
|
||||
else:
|
||||
self._pending_connectionLost = (True, ConnectionDone())
|
||||
assert self._protocol
|
||||
self._protocol.connectionLost(ConnectionDone())
|
||||
|
||||
@m.output()
|
||||
def close_subchannel(self):
|
||||
self._manager.subchannel_closed(self._scid, self)
|
||||
# we're deleted momentarily
|
||||
|
||||
|
@ -151,14 +193,44 @@ class SubChannel(object):
|
|||
raise AlreadyClosedError(
|
||||
"loseConnection not allowed on closed subchannel")
|
||||
|
||||
# primary transitions
|
||||
open.upon(remote_data, enter=open, outputs=[signal_dataReceived])
|
||||
open.upon(local_data, enter=open, outputs=[send_data])
|
||||
open.upon(remote_close, enter=closed, outputs=[send_close, signal_connectionLost])
|
||||
open.upon(local_close, enter=closing, outputs=[send_close])
|
||||
closing.upon(remote_data, enter=closing, outputs=[signal_dataReceived])
|
||||
closing.upon(remote_close, enter=closed, outputs=[signal_connectionLost])
|
||||
# stuff that arrives before we have a protocol connected
|
||||
unconnected.upon(remote_data, enter=unconnected, outputs=[queue_remote_data])
|
||||
unconnected.upon(remote_close, enter=unconnected, outputs=[queue_remote_close])
|
||||
|
||||
# IHalfCloseableProtocol flow
|
||||
unconnected.upon(connect_protocol_half, enter=open_half, outputs=[])
|
||||
open_half.upon(remote_data, enter=open_half, outputs=[signal_dataReceived])
|
||||
open_half.upon(local_data, enter=open_half, outputs=[send_data])
|
||||
# remote closes first
|
||||
open_half.upon(remote_close, enter=read_closed, outputs=[signal_readConnectionLost])
|
||||
read_closed.upon(local_data, enter=read_closed, outputs=[send_data])
|
||||
read_closed.upon(local_close, enter=closed, outputs=[send_close,
|
||||
close_subchannel,
|
||||
# TODO: eventual-signal this?
|
||||
signal_writeConnectionLost,
|
||||
])
|
||||
# local closes first
|
||||
open_half.upon(local_close, enter=write_closed, outputs=[signal_writeConnectionLost,
|
||||
send_close])
|
||||
write_closed.upon(local_data, enter=write_closed, outputs=[error_closed_write])
|
||||
write_closed.upon(remote_data, enter=write_closed, outputs=[signal_dataReceived])
|
||||
write_closed.upon(remote_close, enter=closed, outputs=[close_subchannel,
|
||||
signal_readConnectionLost,
|
||||
])
|
||||
# error cases
|
||||
write_closed.upon(local_close, enter=write_closed, outputs=[error_closed_close])
|
||||
|
||||
# fully-closeable-only flow
|
||||
unconnected.upon(connect_protocol_full, enter=open_full, outputs=[])
|
||||
open_full.upon(remote_data, enter=open_full, outputs=[signal_dataReceived])
|
||||
open_full.upon(local_data, enter=open_full, outputs=[send_data])
|
||||
open_full.upon(remote_close, enter=closed, outputs=[send_close,
|
||||
close_subchannel,
|
||||
signal_connectionLost])
|
||||
open_full.upon(local_close, enter=closing, outputs=[send_close])
|
||||
closing.upon(remote_data, enter=closing, outputs=[signal_dataReceived])
|
||||
closing.upon(remote_close, enter=closed, outputs=[close_subchannel,
|
||||
signal_connectionLost])
|
||||
# error cases
|
||||
# we won't ever see an OPEN, since L4 will log+ignore those for us
|
||||
closing.upon(local_data, enter=closing, outputs=[error_closed_write])
|
||||
|
@ -170,15 +242,19 @@ class SubChannel(object):
|
|||
def _set_protocol(self, protocol):
|
||||
assert not self._protocol
|
||||
self._protocol = protocol
|
||||
if IHalfCloseableProtocol.providedBy(protocol):
|
||||
self.connect_protocol_half()
|
||||
else:
|
||||
# move from UNCONNECTED to OPEN
|
||||
self.connect_protocol_full();
|
||||
|
||||
def _deliver_queued_data(self):
|
||||
if self._pending_dataReceived:
|
||||
for data in self._pending_dataReceived:
|
||||
self._protocol.dataReceived(data)
|
||||
self._pending_dataReceived = []
|
||||
cl, what = self._pending_connectionLost
|
||||
if cl:
|
||||
self._protocol.connectionLost(what)
|
||||
for data in self._pending_remote_data:
|
||||
self.remote_data(data)
|
||||
del self._pending_remote_data
|
||||
if self._pending_remote_close:
|
||||
self.remote_close()
|
||||
del self._pending_remote_close
|
||||
|
||||
# ITransport
|
||||
def write(self, data):
|
||||
|
|
|
@ -1,7 +1,8 @@
|
|||
from __future__ import print_function, unicode_literals
|
||||
import mock
|
||||
from zope.interface import directlyProvides
|
||||
from twisted.trial import unittest
|
||||
from twisted.internet.interfaces import ITransport
|
||||
from twisted.internet.interfaces import ITransport, IHalfCloseableProtocol
|
||||
from twisted.internet.error import ConnectionDone
|
||||
from ..._dilation.subchannel import (Once, SubChannel,
|
||||
_WormholeAddress, _SubchannelAddress,
|
||||
|
@ -9,13 +10,15 @@ from ..._dilation.subchannel import (Once, SubChannel,
|
|||
from .common import mock_manager
|
||||
|
||||
|
||||
def make_sc(set_protocol=True):
|
||||
def make_sc(set_protocol=True, half_closeable=False):
|
||||
scid = 4
|
||||
hostaddr = _WormholeAddress()
|
||||
peeraddr = _SubchannelAddress(scid)
|
||||
m = mock_manager()
|
||||
sc = SubChannel(scid, m, hostaddr, peeraddr)
|
||||
p = mock.Mock()
|
||||
if half_closeable:
|
||||
directlyProvides(p, IHalfCloseableProtocol)
|
||||
if set_protocol:
|
||||
sc._set_protocol(p)
|
||||
return sc, m, scid, hostaddr, peeraddr, p
|
||||
|
@ -145,3 +148,80 @@ class SubChannelAPI(unittest.TestCase):
|
|||
# TODO: more, once this is implemented
|
||||
sc.registerProducer(None, True)
|
||||
sc.unregisterProducer()
|
||||
|
||||
class HalfCloseable(unittest.TestCase):
|
||||
|
||||
def test_create(self):
|
||||
sc, m, scid, hostaddr, peeraddr, p = make_sc(half_closeable=True)
|
||||
self.assert_(ITransport.providedBy(sc))
|
||||
self.assertEqual(m.mock_calls, [])
|
||||
self.assertIdentical(sc.getHost(), hostaddr)
|
||||
self.assertIdentical(sc.getPeer(), peeraddr)
|
||||
|
||||
def test_local_close(self):
|
||||
sc, m, scid, hostaddr, peeraddr, p = make_sc(half_closeable=True)
|
||||
|
||||
sc.write(b"data")
|
||||
self.assertEqual(m.mock_calls, [mock.call.send_data(scid, b"data")])
|
||||
m.mock_calls[:] = []
|
||||
sc.writeSequence([b"more", b"data"])
|
||||
self.assertEqual(m.mock_calls, [mock.call.send_data(scid, b"moredata")])
|
||||
m.mock_calls[:] = []
|
||||
|
||||
sc.remote_data(b"inbound1")
|
||||
self.assertEqual(p.mock_calls, [mock.call.dataReceived(b"inbound1")])
|
||||
p.mock_calls[:] = []
|
||||
|
||||
# after a local close, we can't write anymore, but we can still
|
||||
# receive data
|
||||
sc.loseConnection()
|
||||
self.assertEqual(m.mock_calls, [mock.call.send_close(scid)])
|
||||
m.mock_calls[:] = []
|
||||
self.assertEqual(p.mock_calls, [mock.call.writeConnectionLost()])
|
||||
p.mock_calls[:] = []
|
||||
|
||||
with self.assertRaises(AlreadyClosedError) as e:
|
||||
sc.write(b"data")
|
||||
self.assertEqual(str(e.exception),
|
||||
"write not allowed on closed subchannel")
|
||||
|
||||
with self.assertRaises(AlreadyClosedError) as e:
|
||||
sc.loseConnection()
|
||||
self.assertEqual(str(e.exception),
|
||||
"loseConnection not allowed on closed subchannel")
|
||||
|
||||
sc.remote_data(b"inbound2")
|
||||
self.assertEqual(p.mock_calls, [mock.call.dataReceived(b"inbound2")])
|
||||
p.mock_calls[:] = []
|
||||
|
||||
# the remote end will finally shut down the connection
|
||||
sc.remote_close()
|
||||
self.assertEqual(m.mock_calls, [mock.call.subchannel_closed(scid, sc)])
|
||||
self.assertEqual(p.mock_calls, [mock.call.readConnectionLost()])
|
||||
|
||||
def test_remote_close(self):
|
||||
sc, m, scid, hostaddr, peeraddr, p = make_sc(half_closeable=True)
|
||||
|
||||
sc.write(b"data")
|
||||
self.assertEqual(m.mock_calls, [mock.call.send_data(scid, b"data")])
|
||||
m.mock_calls[:] = []
|
||||
|
||||
sc.remote_data(b"inbound1")
|
||||
self.assertEqual(p.mock_calls, [mock.call.dataReceived(b"inbound1")])
|
||||
p.mock_calls[:] = []
|
||||
|
||||
# after a remote close, we can still write data
|
||||
sc.remote_close()
|
||||
self.assertEqual(m.mock_calls, [])
|
||||
self.assertEqual(p.mock_calls, [mock.call.readConnectionLost()])
|
||||
p.mock_calls[:] = []
|
||||
|
||||
sc.write(b"out2")
|
||||
self.assertEqual(m.mock_calls, [mock.call.send_data(scid, b"out2")])
|
||||
m.mock_calls[:] = []
|
||||
|
||||
# and a local close will shutdown the connection
|
||||
sc.loseConnection()
|
||||
self.assertEqual(m.mock_calls, [mock.call.send_close(scid),
|
||||
mock.call.subchannel_closed(scid, sc)])
|
||||
self.assertEqual(p.mock_calls, [mock.call.writeConnectionLost()])
|
||||
|
|
Loading…
Reference in New Issue
Block a user