diff --git a/src/wormhole/_dilation/subchannel.py b/src/wormhole/_dilation/subchannel.py index 1e52e90..cca0d3c 100644 --- a/src/wormhole/_dilation/subchannel.py +++ b/src/wormhole/_dilation/subchannel.py @@ -56,6 +56,11 @@ class SingleUseEndpointError(Exception): class AlreadyClosedError(Exception): pass +class NormalCloseUsedOnHalfCloseable(Exception): + pass +class HalfCloseUsedOnNonHalfCloseable(Exception): + pass + @implementer(IAddress) class _WormholeAddress(object): @@ -265,7 +270,18 @@ class SubChannel(object): def writeSequence(self, iovec): self.write(b"".join(iovec)) + def loseWriteConnection(self): + if not IHalfCloseableProtocol.providedBy(self._protocol): + # this is a clear error + raise HalfCloseUsedOnNonHalfCloseable() + self.local_close(); + def loseConnection(self): + # TODO: what happens if an IHalfCloseableProtocol calls normal + # loseConnection()? I think we need to close the read side too. + if IHalfCloseableProtocol.providedBy(self._protocol): + # I don't know is correct, so avoid this for now + raise NormalCloseUsedOnHalfCloseable() self.local_close() def getHost(self): diff --git a/src/wormhole/test/dilate/test_subchannel.py b/src/wormhole/test/dilate/test_subchannel.py index 45d4d98..88e6565 100644 --- a/src/wormhole/test/dilate/test_subchannel.py +++ b/src/wormhole/test/dilate/test_subchannel.py @@ -6,7 +6,8 @@ from twisted.internet.interfaces import ITransport, IHalfCloseableProtocol from twisted.internet.error import ConnectionDone from ..._dilation.subchannel import (Once, SubChannel, _WormholeAddress, _SubchannelAddress, - AlreadyClosedError) + AlreadyClosedError, + NormalCloseUsedOnHalfCloseable) from .common import mock_manager @@ -174,9 +175,12 @@ class HalfCloseable(unittest.TestCase): self.assertEqual(p.mock_calls, [mock.call.dataReceived(b"inbound1")]) p.mock_calls[:] = [] + with self.assertRaises(NormalCloseUsedOnHalfCloseable) as e: + sc.loseConnection() # TODO: maybe this shouldn't be an error + # after a local close, we can't write anymore, but we can still # receive data - sc.loseConnection() + sc.loseWriteConnection() # TODO or loseConnection? self.assertEqual(m.mock_calls, [mock.call.send_close(scid)]) m.mock_calls[:] = [] self.assertEqual(p.mock_calls, [mock.call.writeConnectionLost()]) @@ -188,10 +192,13 @@ class HalfCloseable(unittest.TestCase): "write not allowed on closed subchannel") with self.assertRaises(AlreadyClosedError) as e: - sc.loseConnection() + sc.loseWriteConnection() self.assertEqual(str(e.exception), "loseConnection not allowed on closed subchannel") + with self.assertRaises(NormalCloseUsedOnHalfCloseable) as e: + sc.loseConnection() # TODO: maybe expect AlreadyClosedError + sc.remote_data(b"inbound2") self.assertEqual(p.mock_calls, [mock.call.dataReceived(b"inbound2")]) p.mock_calls[:] = [] @@ -223,7 +230,7 @@ class HalfCloseable(unittest.TestCase): m.mock_calls[:] = [] # and a local close will shutdown the connection - sc.loseConnection() + sc.loseWriteConnection() self.assertEqual(m.mock_calls, [mock.call.send_close(scid), mock.call.subchannel_closed(scid, sc)]) self.assertEqual(p.mock_calls, [mock.call.writeConnectionLost()])