update subchannel state machine for half-close
also handle open-but-not-yet-connected subchannels better
This commit is contained in:
parent
b233763082
commit
327e72e6ac
|
@ -6,6 +6,7 @@ from zope.interface import implementer
|
||||||
from twisted.internet.defer import inlineCallbacks, returnValue
|
from twisted.internet.defer import inlineCallbacks, returnValue
|
||||||
from twisted.internet.interfaces import (ITransport, IProducer, IConsumer,
|
from twisted.internet.interfaces import (ITransport, IProducer, IConsumer,
|
||||||
IAddress, IListeningPort,
|
IAddress, IListeningPort,
|
||||||
|
IHalfCloseableProtocol,
|
||||||
IStreamClientEndpoint,
|
IStreamClientEndpoint,
|
||||||
IStreamServerEndpoint)
|
IStreamServerEndpoint)
|
||||||
from twisted.internet.error import ConnectionDone
|
from twisted.internet.error import ConnectionDone
|
||||||
|
@ -87,11 +88,29 @@ class SubChannel(object):
|
||||||
# self._pending_outbound = {}
|
# self._pending_outbound = {}
|
||||||
# self._processed = set()
|
# self._processed = set()
|
||||||
self._protocol = None
|
self._protocol = None
|
||||||
self._pending_dataReceived = []
|
self._pending_remote_data = []
|
||||||
self._pending_connectionLost = (False, None)
|
self._pending_remote_close = False
|
||||||
|
|
||||||
@m.state(initial=True)
|
@m.state(initial=True)
|
||||||
def open(self):
|
def unconnected(self):
|
||||||
|
pass # pragma: no cover
|
||||||
|
|
||||||
|
# once we get the IProtocol, it's either a IHalfCloseableProtocol, or it
|
||||||
|
# can only be fully closed
|
||||||
|
@m.state()
|
||||||
|
def open_half(self):
|
||||||
|
pass # pragma: no cover
|
||||||
|
|
||||||
|
@m.state()
|
||||||
|
def read_closed():
|
||||||
|
pass # pragma: no cover
|
||||||
|
|
||||||
|
@m.state()
|
||||||
|
def write_closed():
|
||||||
|
pass # pragma: no cover
|
||||||
|
|
||||||
|
@m.state()
|
||||||
|
def open_full(self):
|
||||||
pass # pragma: no cover
|
pass # pragma: no cover
|
||||||
|
|
||||||
@m.state()
|
@m.state()
|
||||||
|
@ -102,6 +121,14 @@ class SubChannel(object):
|
||||||
def closed():
|
def closed():
|
||||||
pass # pragma: no cover
|
pass # pragma: no cover
|
||||||
|
|
||||||
|
@m.input()
|
||||||
|
def connect_protocol_half(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@m.input()
|
||||||
|
def connect_protocol_full(self):
|
||||||
|
pass
|
||||||
|
|
||||||
@m.input()
|
@m.input()
|
||||||
def remote_data(self, data):
|
def remote_data(self, data):
|
||||||
pass
|
pass
|
||||||
|
@ -118,6 +145,14 @@ class SubChannel(object):
|
||||||
def local_close(self):
|
def local_close(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@m.output()
|
||||||
|
def queue_remote_data(self, data):
|
||||||
|
self._pending_remote_data.append(data)
|
||||||
|
|
||||||
|
@m.output()
|
||||||
|
def queue_remote_close(self):
|
||||||
|
self._pending_remote_close = True
|
||||||
|
|
||||||
@m.output()
|
@m.output()
|
||||||
def send_data(self, data):
|
def send_data(self, data):
|
||||||
self._manager.send_data(self._scid, data)
|
self._manager.send_data(self._scid, data)
|
||||||
|
@ -128,17 +163,24 @@ class SubChannel(object):
|
||||||
|
|
||||||
@m.output()
|
@m.output()
|
||||||
def signal_dataReceived(self, data):
|
def signal_dataReceived(self, data):
|
||||||
if self._protocol:
|
assert self._protocol
|
||||||
self._protocol.dataReceived(data)
|
self._protocol.dataReceived(data)
|
||||||
else:
|
|
||||||
self._pending_dataReceived.append(data)
|
@m.output()
|
||||||
|
def signal_readConnectionLost(self):
|
||||||
|
IHalfCloseableProtocol(self._protocol).readConnectionLost()
|
||||||
|
|
||||||
|
@m.output()
|
||||||
|
def signal_writeConnectionLost(self):
|
||||||
|
IHalfCloseableProtocol(self._protocol).writeConnectionLost()
|
||||||
|
|
||||||
@m.output()
|
@m.output()
|
||||||
def signal_connectionLost(self):
|
def signal_connectionLost(self):
|
||||||
if self._protocol:
|
assert self._protocol
|
||||||
self._protocol.connectionLost(ConnectionDone())
|
self._protocol.connectionLost(ConnectionDone())
|
||||||
else:
|
|
||||||
self._pending_connectionLost = (True, ConnectionDone())
|
@m.output()
|
||||||
|
def close_subchannel(self):
|
||||||
self._manager.subchannel_closed(self._scid, self)
|
self._manager.subchannel_closed(self._scid, self)
|
||||||
# we're deleted momentarily
|
# we're deleted momentarily
|
||||||
|
|
||||||
|
@ -151,14 +193,44 @@ class SubChannel(object):
|
||||||
raise AlreadyClosedError(
|
raise AlreadyClosedError(
|
||||||
"loseConnection not allowed on closed subchannel")
|
"loseConnection not allowed on closed subchannel")
|
||||||
|
|
||||||
# primary transitions
|
# stuff that arrives before we have a protocol connected
|
||||||
open.upon(remote_data, enter=open, outputs=[signal_dataReceived])
|
unconnected.upon(remote_data, enter=unconnected, outputs=[queue_remote_data])
|
||||||
open.upon(local_data, enter=open, outputs=[send_data])
|
unconnected.upon(remote_close, enter=unconnected, outputs=[queue_remote_close])
|
||||||
open.upon(remote_close, enter=closed, outputs=[send_close, signal_connectionLost])
|
|
||||||
open.upon(local_close, enter=closing, outputs=[send_close])
|
|
||||||
closing.upon(remote_data, enter=closing, outputs=[signal_dataReceived])
|
|
||||||
closing.upon(remote_close, enter=closed, outputs=[signal_connectionLost])
|
|
||||||
|
|
||||||
|
# IHalfCloseableProtocol flow
|
||||||
|
unconnected.upon(connect_protocol_half, enter=open_half, outputs=[])
|
||||||
|
open_half.upon(remote_data, enter=open_half, outputs=[signal_dataReceived])
|
||||||
|
open_half.upon(local_data, enter=open_half, outputs=[send_data])
|
||||||
|
# remote closes first
|
||||||
|
open_half.upon(remote_close, enter=read_closed, outputs=[signal_readConnectionLost])
|
||||||
|
read_closed.upon(local_data, enter=read_closed, outputs=[send_data])
|
||||||
|
read_closed.upon(local_close, enter=closed, outputs=[send_close,
|
||||||
|
close_subchannel,
|
||||||
|
# TODO: eventual-signal this?
|
||||||
|
signal_writeConnectionLost,
|
||||||
|
])
|
||||||
|
# local closes first
|
||||||
|
open_half.upon(local_close, enter=write_closed, outputs=[signal_writeConnectionLost,
|
||||||
|
send_close])
|
||||||
|
write_closed.upon(local_data, enter=write_closed, outputs=[error_closed_write])
|
||||||
|
write_closed.upon(remote_data, enter=write_closed, outputs=[signal_dataReceived])
|
||||||
|
write_closed.upon(remote_close, enter=closed, outputs=[close_subchannel,
|
||||||
|
signal_readConnectionLost,
|
||||||
|
])
|
||||||
|
# error cases
|
||||||
|
write_closed.upon(local_close, enter=write_closed, outputs=[error_closed_close])
|
||||||
|
|
||||||
|
# fully-closeable-only flow
|
||||||
|
unconnected.upon(connect_protocol_full, enter=open_full, outputs=[])
|
||||||
|
open_full.upon(remote_data, enter=open_full, outputs=[signal_dataReceived])
|
||||||
|
open_full.upon(local_data, enter=open_full, outputs=[send_data])
|
||||||
|
open_full.upon(remote_close, enter=closed, outputs=[send_close,
|
||||||
|
close_subchannel,
|
||||||
|
signal_connectionLost])
|
||||||
|
open_full.upon(local_close, enter=closing, outputs=[send_close])
|
||||||
|
closing.upon(remote_data, enter=closing, outputs=[signal_dataReceived])
|
||||||
|
closing.upon(remote_close, enter=closed, outputs=[close_subchannel,
|
||||||
|
signal_connectionLost])
|
||||||
# error cases
|
# error cases
|
||||||
# we won't ever see an OPEN, since L4 will log+ignore those for us
|
# we won't ever see an OPEN, since L4 will log+ignore those for us
|
||||||
closing.upon(local_data, enter=closing, outputs=[error_closed_write])
|
closing.upon(local_data, enter=closing, outputs=[error_closed_write])
|
||||||
|
@ -170,15 +242,19 @@ class SubChannel(object):
|
||||||
def _set_protocol(self, protocol):
|
def _set_protocol(self, protocol):
|
||||||
assert not self._protocol
|
assert not self._protocol
|
||||||
self._protocol = protocol
|
self._protocol = protocol
|
||||||
|
if IHalfCloseableProtocol.providedBy(protocol):
|
||||||
|
self.connect_protocol_half()
|
||||||
|
else:
|
||||||
|
# move from UNCONNECTED to OPEN
|
||||||
|
self.connect_protocol_full();
|
||||||
|
|
||||||
def _deliver_queued_data(self):
|
def _deliver_queued_data(self):
|
||||||
if self._pending_dataReceived:
|
for data in self._pending_remote_data:
|
||||||
for data in self._pending_dataReceived:
|
self.remote_data(data)
|
||||||
self._protocol.dataReceived(data)
|
del self._pending_remote_data
|
||||||
self._pending_dataReceived = []
|
if self._pending_remote_close:
|
||||||
cl, what = self._pending_connectionLost
|
self.remote_close()
|
||||||
if cl:
|
del self._pending_remote_close
|
||||||
self._protocol.connectionLost(what)
|
|
||||||
|
|
||||||
# ITransport
|
# ITransport
|
||||||
def write(self, data):
|
def write(self, data):
|
||||||
|
|
|
@ -1,7 +1,8 @@
|
||||||
from __future__ import print_function, unicode_literals
|
from __future__ import print_function, unicode_literals
|
||||||
import mock
|
import mock
|
||||||
|
from zope.interface import directlyProvides
|
||||||
from twisted.trial import unittest
|
from twisted.trial import unittest
|
||||||
from twisted.internet.interfaces import ITransport
|
from twisted.internet.interfaces import ITransport, IHalfCloseableProtocol
|
||||||
from twisted.internet.error import ConnectionDone
|
from twisted.internet.error import ConnectionDone
|
||||||
from ..._dilation.subchannel import (Once, SubChannel,
|
from ..._dilation.subchannel import (Once, SubChannel,
|
||||||
_WormholeAddress, _SubchannelAddress,
|
_WormholeAddress, _SubchannelAddress,
|
||||||
|
@ -9,13 +10,15 @@ from ..._dilation.subchannel import (Once, SubChannel,
|
||||||
from .common import mock_manager
|
from .common import mock_manager
|
||||||
|
|
||||||
|
|
||||||
def make_sc(set_protocol=True):
|
def make_sc(set_protocol=True, half_closeable=False):
|
||||||
scid = 4
|
scid = 4
|
||||||
hostaddr = _WormholeAddress()
|
hostaddr = _WormholeAddress()
|
||||||
peeraddr = _SubchannelAddress(scid)
|
peeraddr = _SubchannelAddress(scid)
|
||||||
m = mock_manager()
|
m = mock_manager()
|
||||||
sc = SubChannel(scid, m, hostaddr, peeraddr)
|
sc = SubChannel(scid, m, hostaddr, peeraddr)
|
||||||
p = mock.Mock()
|
p = mock.Mock()
|
||||||
|
if half_closeable:
|
||||||
|
directlyProvides(p, IHalfCloseableProtocol)
|
||||||
if set_protocol:
|
if set_protocol:
|
||||||
sc._set_protocol(p)
|
sc._set_protocol(p)
|
||||||
return sc, m, scid, hostaddr, peeraddr, p
|
return sc, m, scid, hostaddr, peeraddr, p
|
||||||
|
@ -145,3 +148,80 @@ class SubChannelAPI(unittest.TestCase):
|
||||||
# TODO: more, once this is implemented
|
# TODO: more, once this is implemented
|
||||||
sc.registerProducer(None, True)
|
sc.registerProducer(None, True)
|
||||||
sc.unregisterProducer()
|
sc.unregisterProducer()
|
||||||
|
|
||||||
|
class HalfCloseable(unittest.TestCase):
|
||||||
|
|
||||||
|
def test_create(self):
|
||||||
|
sc, m, scid, hostaddr, peeraddr, p = make_sc(half_closeable=True)
|
||||||
|
self.assert_(ITransport.providedBy(sc))
|
||||||
|
self.assertEqual(m.mock_calls, [])
|
||||||
|
self.assertIdentical(sc.getHost(), hostaddr)
|
||||||
|
self.assertIdentical(sc.getPeer(), peeraddr)
|
||||||
|
|
||||||
|
def test_local_close(self):
|
||||||
|
sc, m, scid, hostaddr, peeraddr, p = make_sc(half_closeable=True)
|
||||||
|
|
||||||
|
sc.write(b"data")
|
||||||
|
self.assertEqual(m.mock_calls, [mock.call.send_data(scid, b"data")])
|
||||||
|
m.mock_calls[:] = []
|
||||||
|
sc.writeSequence([b"more", b"data"])
|
||||||
|
self.assertEqual(m.mock_calls, [mock.call.send_data(scid, b"moredata")])
|
||||||
|
m.mock_calls[:] = []
|
||||||
|
|
||||||
|
sc.remote_data(b"inbound1")
|
||||||
|
self.assertEqual(p.mock_calls, [mock.call.dataReceived(b"inbound1")])
|
||||||
|
p.mock_calls[:] = []
|
||||||
|
|
||||||
|
# after a local close, we can't write anymore, but we can still
|
||||||
|
# receive data
|
||||||
|
sc.loseConnection()
|
||||||
|
self.assertEqual(m.mock_calls, [mock.call.send_close(scid)])
|
||||||
|
m.mock_calls[:] = []
|
||||||
|
self.assertEqual(p.mock_calls, [mock.call.writeConnectionLost()])
|
||||||
|
p.mock_calls[:] = []
|
||||||
|
|
||||||
|
with self.assertRaises(AlreadyClosedError) as e:
|
||||||
|
sc.write(b"data")
|
||||||
|
self.assertEqual(str(e.exception),
|
||||||
|
"write not allowed on closed subchannel")
|
||||||
|
|
||||||
|
with self.assertRaises(AlreadyClosedError) as e:
|
||||||
|
sc.loseConnection()
|
||||||
|
self.assertEqual(str(e.exception),
|
||||||
|
"loseConnection not allowed on closed subchannel")
|
||||||
|
|
||||||
|
sc.remote_data(b"inbound2")
|
||||||
|
self.assertEqual(p.mock_calls, [mock.call.dataReceived(b"inbound2")])
|
||||||
|
p.mock_calls[:] = []
|
||||||
|
|
||||||
|
# the remote end will finally shut down the connection
|
||||||
|
sc.remote_close()
|
||||||
|
self.assertEqual(m.mock_calls, [mock.call.subchannel_closed(scid, sc)])
|
||||||
|
self.assertEqual(p.mock_calls, [mock.call.readConnectionLost()])
|
||||||
|
|
||||||
|
def test_remote_close(self):
|
||||||
|
sc, m, scid, hostaddr, peeraddr, p = make_sc(half_closeable=True)
|
||||||
|
|
||||||
|
sc.write(b"data")
|
||||||
|
self.assertEqual(m.mock_calls, [mock.call.send_data(scid, b"data")])
|
||||||
|
m.mock_calls[:] = []
|
||||||
|
|
||||||
|
sc.remote_data(b"inbound1")
|
||||||
|
self.assertEqual(p.mock_calls, [mock.call.dataReceived(b"inbound1")])
|
||||||
|
p.mock_calls[:] = []
|
||||||
|
|
||||||
|
# after a remote close, we can still write data
|
||||||
|
sc.remote_close()
|
||||||
|
self.assertEqual(m.mock_calls, [])
|
||||||
|
self.assertEqual(p.mock_calls, [mock.call.readConnectionLost()])
|
||||||
|
p.mock_calls[:] = []
|
||||||
|
|
||||||
|
sc.write(b"out2")
|
||||||
|
self.assertEqual(m.mock_calls, [mock.call.send_data(scid, b"out2")])
|
||||||
|
m.mock_calls[:] = []
|
||||||
|
|
||||||
|
# and a local close will shutdown the connection
|
||||||
|
sc.loseConnection()
|
||||||
|
self.assertEqual(m.mock_calls, [mock.call.send_close(scid),
|
||||||
|
mock.call.subchannel_closed(scid, sc)])
|
||||||
|
self.assertEqual(p.mock_calls, [mock.call.writeConnectionLost()])
|
||||||
|
|
Loading…
Reference in New Issue
Block a user