magic-wormhole/src/wormhole/test/dilate/test_subchannel.py
Brian Warner d1aefa815d fix subchannel open/close, add test
I think I just managed to forget that inbound_close requires we respond with
a close ourselves. Also outbound open means we must add the subchannel to the
inbound table, so we can receive any data on it at all.
2019-07-06 01:50:29 -07:00

146 lines
5.2 KiB
Python

from __future__ import print_function, unicode_literals
import mock
from twisted.trial import unittest
from twisted.internet.interfaces import ITransport
from twisted.internet.error import ConnectionDone
from ..._dilation.subchannel import (Once, SubChannel,
_WormholeAddress, _SubchannelAddress,
AlreadyClosedError)
from .common import mock_manager
def make_sc(set_protocol=True):
    """Build a SubChannel wired to a mock manager for use in tests.

    A mock protocol is always created. When ``set_protocol`` is true it is
    attached to the subchannel via ``_set_protocol`` before returning;
    otherwise it is returned unattached so the caller can attach it later.

    Returns the tuple ``(subchannel, manager, subchannel id, host address,
    peer address, protocol)``.
    """
    subchannel_id = 4
    host_address = _WormholeAddress()
    peer_address = _SubchannelAddress(subchannel_id)
    manager = mock_manager()
    subchannel = SubChannel(subchannel_id, manager,
                            host_address, peer_address)
    protocol = mock.Mock()
    if set_protocol:
        subchannel._set_protocol(protocol)
    return (subchannel, manager, subchannel_id,
            host_address, peer_address, protocol)
class SubChannelAPI(unittest.TestCase):
    """Tests for the SubChannel transport API, driven through a mock manager.

    Each test builds its fixture with ``make_sc()`` and then checks the calls
    recorded on the mock manager (``m``) and/or the mock protocol (``p``).
    """

    def test_once(self):
        """A Once can be called a single time; the second call raises."""
        o = Once(ValueError)
        o()
        with self.assertRaises(ValueError):
            o()

    def test_create(self):
        """A fresh SubChannel provides ITransport and reports its addresses."""
        sc, m, scid, hostaddr, peeraddr, p = make_sc()
        self.assertTrue(ITransport.providedBy(sc))
        # construction alone must not talk to the manager
        self.assertEqual(m.mock_calls, [])
        self.assertIs(sc.getHost(), hostaddr)
        self.assertIs(sc.getPeer(), peeraddr)

    def test_write(self):
        """write() and writeSequence() each produce one send_data call."""
        sc, m, scid, hostaddr, peeraddr, p = make_sc()

        sc.write(b"data")
        self.assertEqual(m.mock_calls, [mock.call.send_data(scid, b"data")])
        m.mock_calls[:] = []

        # the sequence is joined into a single message
        sc.writeSequence([b"more", b"data"])
        self.assertEqual(m.mock_calls,
                         [mock.call.send_data(scid, b"moredata")])

    def test_write_when_closing(self):
        """write() after loseConnection() raises AlreadyClosedError."""
        sc, m, scid, hostaddr, peeraddr, p = make_sc()

        sc.loseConnection()
        self.assertEqual(m.mock_calls, [mock.call.send_close(scid)])
        m.mock_calls[:] = []

        with self.assertRaises(AlreadyClosedError) as e:
            sc.write(b"data")
        self.assertEqual(str(e.exception),
                         "write not allowed on closed subchannel")

    def test_local_close(self):
        """A local close sends a close, then waits for the remote close."""
        sc, m, scid, hostaddr, peeraddr, p = make_sc()

        sc.loseConnection()
        self.assertEqual(m.mock_calls, [mock.call.send_close(scid)])
        m.mock_calls[:] = []

        # late arriving data is still delivered
        sc.remote_data(b"late")
        self.assertEqual(p.mock_calls, [mock.call.dataReceived(b"late")])
        p.mock_calls[:] = []

        sc.remote_close()
        self.assert_connectionDone(p.mock_calls)

    def test_local_double_close(self):
        """A second loseConnection() raises AlreadyClosedError."""
        sc, m, scid, hostaddr, peeraddr, p = make_sc()

        sc.loseConnection()
        self.assertEqual(m.mock_calls, [mock.call.send_close(scid)])
        m.mock_calls[:] = []

        with self.assertRaises(AlreadyClosedError) as e:
            sc.loseConnection()
        self.assertEqual(str(e.exception),
                         "loseConnection not allowed on closed subchannel")

    def assert_connectionDone(self, mock_calls):
        """Assert mock_calls is exactly one connectionLost(ConnectionDone).

        ``mock_calls`` is a list of recorded mock calls; each entry is
        indexed as ``(name, args, ...)``.
        """
        self.assertEqual(len(mock_calls), 1)
        self.assertEqual(mock_calls[0][0], "connectionLost")
        self.assertEqual(len(mock_calls[0][1]), 1)
        self.assertIsInstance(mock_calls[0][1][0], ConnectionDone)

    def test_remote_close(self):
        """A remote close is answered with our own close, then torn down."""
        sc, m, scid, hostaddr, peeraddr, p = make_sc()
        sc.remote_close()
        self.assertEqual(m.mock_calls,
                         [mock.call.send_close(scid),
                          mock.call.subchannel_closed(scid, sc)])
        self.assert_connectionDone(p.mock_calls)

    def test_data(self):
        """Each remote_data() is delivered as its own dataReceived call."""
        sc, m, scid, hostaddr, peeraddr, p = make_sc()
        sc.remote_data(b"data")
        self.assertEqual(p.mock_calls, [mock.call.dataReceived(b"data")])
        p.mock_calls[:] = []

        sc.remote_data(b"not")
        sc.remote_data(b"coalesced")
        self.assertEqual(p.mock_calls,
                         [mock.call.dataReceived(b"not"),
                          mock.call.dataReceived(b"coalesced"),
                          ])

    def test_data_before_open(self):
        """Data arriving before the protocol is set is delivered once set."""
        sc, m, scid, hostaddr, peeraddr, p = make_sc(set_protocol=False)
        sc.remote_data(b"data")
        self.assertEqual(p.mock_calls, [])

        sc._set_protocol(p)
        self.assertEqual(p.mock_calls, [mock.call.dataReceived(b"data")])
        p.mock_calls[:] = []

        # once the protocol is attached, data flows straight through
        sc.remote_data(b"more")
        self.assertEqual(p.mock_calls, [mock.call.dataReceived(b"more")])

    def test_close_before_open(self):
        """A close arriving before the protocol is set is delivered once set."""
        sc, m, scid, hostaddr, peeraddr, p = make_sc(set_protocol=False)
        sc.remote_close()
        self.assertEqual(p.mock_calls, [])

        sc._set_protocol(p)
        self.assert_connectionDone(p.mock_calls)

    def test_producer(self):
        """pause/resume/stopProducing are forwarded to the manager."""
        sc, m, scid, hostaddr, peeraddr, p = make_sc()

        sc.pauseProducing()
        self.assertEqual(m.mock_calls,
                         [mock.call.subchannel_pauseProducing(sc)])
        m.mock_calls[:] = []

        sc.resumeProducing()
        self.assertEqual(m.mock_calls,
                         [mock.call.subchannel_resumeProducing(sc)])
        m.mock_calls[:] = []

        sc.stopProducing()
        self.assertEqual(m.mock_calls,
                         [mock.call.subchannel_stopProducing(sc)])

    def test_consumer(self):
        """registerProducer()/unregisterProducer() can be called."""
        sc, m, scid, hostaddr, peeraddr, p = make_sc()

        # TODO: more, once this is implemented
        sc.registerProducer(None, True)
        sc.unregisterProducer()