subchannel: enforce separation between half-close and full-close API

This commit is contained in:
Brian Warner 2019-08-11 18:47:54 -07:00
parent 1c8c2997c7
commit 1219fd08ca
2 changed files with 27 additions and 4 deletions

View File

@ -56,6 +56,11 @@ class SingleUseEndpointError(Exception):
class AlreadyClosedError(Exception):
pass
class NormalCloseUsedOnHalfCloseable(Exception):
pass
class HalfCloseUsedOnNonHalfCloseable(Exception):
pass
@implementer(IAddress)
class _WormholeAddress(object):
@ -265,7 +270,18 @@ class SubChannel(object):
def writeSequence(self, iovec):
self.write(b"".join(iovec))
def loseWriteConnection(self):
if not IHalfCloseableProtocol.providedBy(self._protocol):
# this is a clear error
raise HalfCloseUsedOnNonHalfCloseable()
self.local_close();
def loseConnection(self):
# TODO: what happens if an IHalfCloseableProtocol calls normal
# loseConnection()? I think we need to close the read side too.
if IHalfCloseableProtocol.providedBy(self._protocol):
# I don't know is correct, so avoid this for now
raise NormalCloseUsedOnHalfCloseable()
self.local_close()
def getHost(self):

View File

@ -6,7 +6,8 @@ from twisted.internet.interfaces import ITransport, IHalfCloseableProtocol
from twisted.internet.error import ConnectionDone
from ..._dilation.subchannel import (Once, SubChannel,
_WormholeAddress, _SubchannelAddress,
AlreadyClosedError)
AlreadyClosedError,
NormalCloseUsedOnHalfCloseable)
from .common import mock_manager
@ -174,9 +175,12 @@ class HalfCloseable(unittest.TestCase):
self.assertEqual(p.mock_calls, [mock.call.dataReceived(b"inbound1")])
p.mock_calls[:] = []
with self.assertRaises(NormalCloseUsedOnHalfCloseable) as e:
sc.loseConnection() # TODO: maybe this shouldn't be an error
# after a local close, we can't write anymore, but we can still
# receive data
sc.loseConnection()
sc.loseWriteConnection() # TODO or loseConnection?
self.assertEqual(m.mock_calls, [mock.call.send_close(scid)])
m.mock_calls[:] = []
self.assertEqual(p.mock_calls, [mock.call.writeConnectionLost()])
@ -188,10 +192,13 @@ class HalfCloseable(unittest.TestCase):
"write not allowed on closed subchannel")
with self.assertRaises(AlreadyClosedError) as e:
sc.loseConnection()
sc.loseWriteConnection()
self.assertEqual(str(e.exception),
"loseConnection not allowed on closed subchannel")
with self.assertRaises(NormalCloseUsedOnHalfCloseable) as e:
sc.loseConnection() # TODO: maybe expect AlreadyClosedError
sc.remote_data(b"inbound2")
self.assertEqual(p.mock_calls, [mock.call.dataReceived(b"inbound2")])
p.mock_calls[:] = []
@ -223,7 +230,7 @@ class HalfCloseable(unittest.TestCase):
m.mock_calls[:] = []
# and a local close will shutdown the connection
sc.loseConnection()
sc.loseWriteConnection()
self.assertEqual(m.mock_calls, [mock.call.send_close(scid),
mock.call.subchannel_closed(scid, sc)])
self.assertEqual(p.mock_calls, [mock.call.writeConnectionLost()])