From b2337630828859c249095bfa22006c54ffd21174 Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Mon, 8 Jul 2019 10:26:43 -0700 Subject: [PATCH] subchannel: deliver queued connectionMade before any data The previous implementation would call the control/receiving Protocol completely backwards: dataReceived first, then connectionLost, then finally connectionMade. Which didn't work at all, of course. --- src/wormhole/_dilation/subchannel.py | 6 +++++- src/wormhole/test/dilate/test_endpoints.py | 19 +++++++++++++------ src/wormhole/test/dilate/test_subchannel.py | 2 ++ 3 files changed, 20 insertions(+), 7 deletions(-) diff --git a/src/wormhole/_dilation/subchannel.py b/src/wormhole/_dilation/subchannel.py index e1e811c..ee10e97 100644 --- a/src/wormhole/_dilation/subchannel.py +++ b/src/wormhole/_dilation/subchannel.py @@ -165,11 +165,13 @@ class SubChannel(object): closing.upon(local_close, enter=closing, outputs=[error_closed_close]) # the CLOSED state won't ever see messages, since we'll be deleted - # our endpoints use this + # our endpoints use these def _set_protocol(self, protocol): assert not self._protocol self._protocol = protocol + + def _deliver_queued_data(self): if self._pending_dataReceived: for data in self._pending_dataReceived: self._protocol.dataReceived(data) @@ -244,6 +246,7 @@ class ControlEndpoint(object): self._subchannel_zero._set_protocol(p) # this sets p.transport and calls p.connectionMade() p.makeConnection(self._subchannel_zero) + self._subchannel_zero._deliver_queued_data() returnValue(p) @@ -304,6 +307,7 @@ class SubchannelListenerEndpoint(object): p = self._factory.buildProtocol(peer_addr) t._set_protocol(p) p.makeConnection(t) + t._deliver_queued_data() def _main_channel_ready(self): self._wait_for_main_channel.fire(None) diff --git a/src/wormhole/test/dilate/test_endpoints.py b/src/wormhole/test/dilate/test_endpoints.py index 647fb9d..f36296c 100644 --- a/src/wormhole/test/dilate/test_endpoints.py +++ b/src/wormhole/test/dilate/test_endpoints.py @@ -38,7 +38,8 @@ class Control(unittest.TestCase): self.assertIdentical(self.successResultOf(d), p) self.assertEqual(f.buildProtocol.mock_calls, [mock.call(peeraddr)]) - self.assertEqual(sc0.mock_calls, [mock.call._set_protocol(p)]) + self.assertEqual(sc0.mock_calls, [mock.call._set_protocol(p), + mock.call._deliver_queued_data()]) self.assertEqual(p.mock_calls, [mock.call.makeConnection(sc0)]) d = ep.connect(f) @@ -87,7 +88,8 @@ class Control(unittest.TestCase): eq.flush_sync() self.assertIdentical(self.successResultOf(d), p) self.assertEqual(f.buildProtocol.mock_calls, [mock.call(peeraddr)]) - self.assertEqual(sc0.mock_calls, [mock.call._set_protocol(p)]) + self.assertEqual(sc0.mock_calls, [mock.call._set_protocol(p), + mock.call._deliver_queued_data()]) self.assertEqual(p.mock_calls, [mock.call.makeConnection(sc0)]) d = ep.connect(f) @@ -254,7 +256,8 @@ class Listener(unittest.TestCase): peeraddr1 = _SubchannelAddress(1) ep._got_open(t1, peeraddr1) - self.assertEqual(t1.mock_calls, [mock.call._set_protocol(p1)]) + self.assertEqual(t1.mock_calls, [mock.call._set_protocol(p1), + mock.call._deliver_queued_data()]) self.assertEqual(p1.mock_calls, [mock.call.makeConnection(t1)]) self.assertEqual(f.buildProtocol.mock_calls, [mock.call(peeraddr1)]) @@ -262,7 +265,8 @@ class Listener(unittest.TestCase): peeraddr2 = _SubchannelAddress(2) ep._got_open(t2, peeraddr2) - self.assertEqual(t2.mock_calls, [mock.call._set_protocol(p2)]) + self.assertEqual(t2.mock_calls, [mock.call._set_protocol(p2), + mock.call._deliver_queued_data()]) self.assertEqual(p2.mock_calls, [mock.call.makeConnection(t2)]) self.assertEqual(f.buildProtocol.mock_calls, [mock.call(peeraddr1), mock.call(peeraddr2)]) @@ -320,7 +324,9 @@ class Listener(unittest.TestCase): self.assertEqual(lp.getHost(), hostaddr) lp.startListening() - self.assertEqual(t1.mock_calls, [mock.call._set_protocol(p1)]) + # TODO: assert makeConnection is called *before* _deliver_queued_data + self.assertEqual(t1.mock_calls, [mock.call._set_protocol(p1), + mock.call._deliver_queued_data()]) self.assertEqual(p1.mock_calls, [mock.call.makeConnection(t1)]) self.assertEqual(f.buildProtocol.mock_calls, [mock.call(peeraddr1)]) @@ -328,7 +334,8 @@ class Listener(unittest.TestCase): peeraddr2 = _SubchannelAddress(2) ep._got_open(t2, peeraddr2) - self.assertEqual(t2.mock_calls, [mock.call._set_protocol(p2)]) + self.assertEqual(t2.mock_calls, [mock.call._set_protocol(p2), + mock.call._deliver_queued_data()]) self.assertEqual(p2.mock_calls, [mock.call.makeConnection(t2)]) self.assertEqual(f.buildProtocol.mock_calls, [mock.call(peeraddr1), mock.call(peeraddr2)]) diff --git a/src/wormhole/test/dilate/test_subchannel.py b/src/wormhole/test/dilate/test_subchannel.py index c26248a..0705a10 100644 --- a/src/wormhole/test/dilate/test_subchannel.py +++ b/src/wormhole/test/dilate/test_subchannel.py @@ -112,6 +112,7 @@ class SubChannelAPI(unittest.TestCase): sc.remote_data(b"data") self.assertEqual(p.mock_calls, []) sc._set_protocol(p) + sc._deliver_queued_data() self.assertEqual(p.mock_calls, [mock.call.dataReceived(b"data")]) p.mock_calls[:] = [] sc.remote_data(b"more") @@ -122,6 +123,7 @@ class SubChannelAPI(unittest.TestCase): sc.remote_close() self.assertEqual(p.mock_calls, []) sc._set_protocol(p) + sc._deliver_queued_data() self.assert_connectionDone(p.mock_calls) def test_producer(self):