diff --git a/src/wormhole/_dilation/subchannel.py b/src/wormhole/_dilation/subchannel.py index ee10e97..c2ecae7 100644 --- a/src/wormhole/_dilation/subchannel.py +++ b/src/wormhole/_dilation/subchannel.py @@ -6,6 +6,7 @@ from zope.interface import implementer from twisted.internet.defer import inlineCallbacks, returnValue from twisted.internet.interfaces import (ITransport, IProducer, IConsumer, IAddress, IListeningPort, + IHalfCloseableProtocol, IStreamClientEndpoint, IStreamServerEndpoint) from twisted.internet.error import ConnectionDone @@ -87,11 +88,29 @@ class SubChannel(object): # self._pending_outbound = {} # self._processed = set() self._protocol = None - self._pending_dataReceived = [] - self._pending_connectionLost = (False, None) + self._pending_remote_data = [] + self._pending_remote_close = False @m.state(initial=True) - def open(self): + def unconnected(self): + pass # pragma: no cover + + # once we get the IProtocol, it's either a IHalfCloseableProtocol, or it + # can only be fully closed + @m.state() + def open_half(self): + pass # pragma: no cover + + @m.state() + def read_closed(): + pass # pragma: no cover + + @m.state() + def write_closed(): + pass # pragma: no cover + + @m.state() + def open_full(self): pass # pragma: no cover @m.state() @@ -102,6 +121,14 @@ class SubChannel(object): def closed(): pass # pragma: no cover + @m.input() + def connect_protocol_half(self): + pass + + @m.input() + def connect_protocol_full(self): + pass + @m.input() def remote_data(self, data): pass @@ -118,6 +145,14 @@ class SubChannel(object): def local_close(self): pass + @m.output() + def queue_remote_data(self, data): + self._pending_remote_data.append(data) + + @m.output() + def queue_remote_close(self): + self._pending_remote_close = True + @m.output() def send_data(self, data): self._manager.send_data(self._scid, data) @@ -128,17 +163,24 @@ class SubChannel(object): @m.output() def signal_dataReceived(self, data): - if self._protocol: - self._protocol.dataReceived(data) - else: - self._pending_dataReceived.append(data) + assert self._protocol + self._protocol.dataReceived(data) + + @m.output() + def signal_readConnectionLost(self): + IHalfCloseableProtocol(self._protocol).readConnectionLost() + + @m.output() + def signal_writeConnectionLost(self): + IHalfCloseableProtocol(self._protocol).writeConnectionLost() @m.output() def signal_connectionLost(self): - if self._protocol: - self._protocol.connectionLost(ConnectionDone()) - else: - self._pending_connectionLost = (True, ConnectionDone()) + assert self._protocol + self._protocol.connectionLost(ConnectionDone()) + + @m.output() + def close_subchannel(self): self._manager.subchannel_closed(self._scid, self) # we're deleted momentarily @@ -151,14 +193,44 @@ class SubChannel(object): raise AlreadyClosedError( "loseConnection not allowed on closed subchannel") - # primary transitions - open.upon(remote_data, enter=open, outputs=[signal_dataReceived]) - open.upon(local_data, enter=open, outputs=[send_data]) - open.upon(remote_close, enter=closed, outputs=[send_close, signal_connectionLost]) - open.upon(local_close, enter=closing, outputs=[send_close]) - closing.upon(remote_data, enter=closing, outputs=[signal_dataReceived]) - closing.upon(remote_close, enter=closed, outputs=[signal_connectionLost]) + # stuff that arrives before we have a protocol connected + unconnected.upon(remote_data, enter=unconnected, outputs=[queue_remote_data]) + unconnected.upon(remote_close, enter=unconnected, outputs=[queue_remote_close]) + # IHalfCloseableProtocol flow + unconnected.upon(connect_protocol_half, enter=open_half, outputs=[]) + open_half.upon(remote_data, enter=open_half, outputs=[signal_dataReceived]) + open_half.upon(local_data, enter=open_half, outputs=[send_data]) + # remote closes first + open_half.upon(remote_close, enter=read_closed, outputs=[signal_readConnectionLost]) + read_closed.upon(local_data, enter=read_closed, outputs=[send_data]) + read_closed.upon(local_close, enter=closed, outputs=[send_close, + close_subchannel, + # TODO: eventual-signal this? + signal_writeConnectionLost, + ]) + # local closes first + open_half.upon(local_close, enter=write_closed, outputs=[signal_writeConnectionLost, + send_close]) + write_closed.upon(local_data, enter=write_closed, outputs=[error_closed_write]) + write_closed.upon(remote_data, enter=write_closed, outputs=[signal_dataReceived]) + write_closed.upon(remote_close, enter=closed, outputs=[close_subchannel, + signal_readConnectionLost, + ]) + # error cases + write_closed.upon(local_close, enter=write_closed, outputs=[error_closed_close]) + + # fully-closeable-only flow + unconnected.upon(connect_protocol_full, enter=open_full, outputs=[]) + open_full.upon(remote_data, enter=open_full, outputs=[signal_dataReceived]) + open_full.upon(local_data, enter=open_full, outputs=[send_data]) + open_full.upon(remote_close, enter=closed, outputs=[send_close, + close_subchannel, + signal_connectionLost]) + open_full.upon(local_close, enter=closing, outputs=[send_close]) + closing.upon(remote_data, enter=closing, outputs=[signal_dataReceived]) + closing.upon(remote_close, enter=closed, outputs=[close_subchannel, + signal_connectionLost]) # error cases # we won't ever see an OPEN, since L4 will log+ignore those for us closing.upon(local_data, enter=closing, outputs=[error_closed_write]) @@ -170,15 +242,19 @@ class SubChannel(object): def _set_protocol(self, protocol): assert not self._protocol self._protocol = protocol + if IHalfCloseableProtocol.providedBy(protocol): + self.connect_protocol_half() + else: + # move from UNCONNECTED to OPEN + self.connect_protocol_full(); def _deliver_queued_data(self): - if self._pending_dataReceived: - for data in self._pending_dataReceived: - self._protocol.dataReceived(data) - self._pending_dataReceived = [] - cl, what = self._pending_connectionLost - if cl: - self._protocol.connectionLost(what) + for data in self._pending_remote_data: + self.remote_data(data) + del self._pending_remote_data + if self._pending_remote_close: + self.remote_close() + del self._pending_remote_close # ITransport def write(self, data): diff --git a/src/wormhole/test/dilate/test_subchannel.py b/src/wormhole/test/dilate/test_subchannel.py index 0705a10..c1b02cb 100644 --- a/src/wormhole/test/dilate/test_subchannel.py +++ b/src/wormhole/test/dilate/test_subchannel.py @@ -1,7 +1,8 @@ from __future__ import print_function, unicode_literals import mock +from zope.interface import directlyProvides from twisted.trial import unittest -from twisted.internet.interfaces import ITransport +from twisted.internet.interfaces import ITransport, IHalfCloseableProtocol from twisted.internet.error import ConnectionDone from ..._dilation.subchannel import (Once, SubChannel, _WormholeAddress, _SubchannelAddress, @@ -9,13 +10,15 @@ from ..._dilation.subchannel import (Once, SubChannel, from .common import mock_manager -def make_sc(set_protocol=True): +def make_sc(set_protocol=True, half_closeable=False): scid = 4 hostaddr = _WormholeAddress() peeraddr = _SubchannelAddress(scid) m = mock_manager() sc = SubChannel(scid, m, hostaddr, peeraddr) p = mock.Mock() + if half_closeable: + directlyProvides(p, IHalfCloseableProtocol) if set_protocol: sc._set_protocol(p) return sc, m, scid, hostaddr, peeraddr, p @@ -145,3 +148,80 @@ class SubChannelAPI(unittest.TestCase): # TODO: more, once this is implemented sc.registerProducer(None, True) sc.unregisterProducer() + +class HalfCloseable(unittest.TestCase): + + def test_create(self): + sc, m, scid, hostaddr, peeraddr, p = make_sc(half_closeable=True) + self.assert_(ITransport.providedBy(sc)) + self.assertEqual(m.mock_calls, []) + self.assertIdentical(sc.getHost(), hostaddr) + self.assertIdentical(sc.getPeer(), peeraddr) + + def test_local_close(self): + sc, m, scid, hostaddr, peeraddr, p = make_sc(half_closeable=True) + + sc.write(b"data") + self.assertEqual(m.mock_calls, [mock.call.send_data(scid, b"data")]) + m.mock_calls[:] = [] + sc.writeSequence([b"more", b"data"]) + self.assertEqual(m.mock_calls, [mock.call.send_data(scid, b"moredata")]) + m.mock_calls[:] = [] + + sc.remote_data(b"inbound1") + self.assertEqual(p.mock_calls, [mock.call.dataReceived(b"inbound1")]) + p.mock_calls[:] = [] + + # after a local close, we can't write anymore, but we can still + # receive data + sc.loseConnection() + self.assertEqual(m.mock_calls, [mock.call.send_close(scid)]) + m.mock_calls[:] = [] + self.assertEqual(p.mock_calls, [mock.call.writeConnectionLost()]) + p.mock_calls[:] = [] + + with self.assertRaises(AlreadyClosedError) as e: + sc.write(b"data") + self.assertEqual(str(e.exception), + "write not allowed on closed subchannel") + + with self.assertRaises(AlreadyClosedError) as e: + sc.loseConnection() + self.assertEqual(str(e.exception), + "loseConnection not allowed on closed subchannel") + + sc.remote_data(b"inbound2") + self.assertEqual(p.mock_calls, [mock.call.dataReceived(b"inbound2")]) + p.mock_calls[:] = [] + + # the remote end will finally shut down the connection + sc.remote_close() + self.assertEqual(m.mock_calls, [mock.call.subchannel_closed(scid, sc)]) + self.assertEqual(p.mock_calls, [mock.call.readConnectionLost()]) + + def test_remote_close(self): + sc, m, scid, hostaddr, peeraddr, p = make_sc(half_closeable=True) + + sc.write(b"data") + self.assertEqual(m.mock_calls, [mock.call.send_data(scid, b"data")]) + m.mock_calls[:] = [] + + sc.remote_data(b"inbound1") + self.assertEqual(p.mock_calls, [mock.call.dataReceived(b"inbound1")]) + p.mock_calls[:] = [] + + # after a remote close, we can still write data + sc.remote_close() + self.assertEqual(m.mock_calls, []) + self.assertEqual(p.mock_calls, [mock.call.readConnectionLost()]) + p.mock_calls[:] = [] + + sc.write(b"out2") + self.assertEqual(m.mock_calls, [mock.call.send_data(scid, b"out2")]) + m.mock_calls[:] = [] + + # and a local close will shutdown the connection + sc.loseConnection() + self.assertEqual(m.mock_calls, [mock.call.send_close(scid), + mock.call.subchannel_closed(scid, sc)]) + self.assertEqual(p.mock_calls, [mock.call.writeConnectionLost()])