From d1aefa815d3d08f341326c0552715e37537ed485 Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Sat, 6 Jul 2019 01:50:29 -0700 Subject: [PATCH] fix subchannel open/close, add test I think I just managed to forget that inbound_close requires we respond with a close ourselves. Also outbound open means we must add the subchannel to the inbound table, so we can receive any data on it at all. --- src/wormhole/_dilation/inbound.py | 7 +- src/wormhole/_dilation/manager.py | 3 + src/wormhole/_dilation/outbound.py | 2 +- src/wormhole/_dilation/subchannel.py | 13 ++-- src/wormhole/test/dilate/test_full.py | 72 +++++++++++++++++++++ src/wormhole/test/dilate/test_manager.py | 6 +- src/wormhole/test/dilate/test_outbound.py | 4 +- src/wormhole/test/dilate/test_subchannel.py | 3 +- 8 files changed, 96 insertions(+), 14 deletions(-) diff --git a/src/wormhole/_dilation/inbound.py b/src/wormhole/_dilation/inbound.py index 3f7ed95..cfc01b3 100644 --- a/src/wormhole/_dilation/inbound.py +++ b/src/wormhole/_dilation/inbound.py @@ -3,7 +3,7 @@ from attr import attrs, attrib from attr.validators import provides from zope.interface import implementer from twisted.python import log -from .._interfaces import IDilationManager, IInbound +from .._interfaces import IDilationManager, IInbound, ISubChannel from .subchannel import (SubChannel, _SubchannelAddress) @@ -52,6 +52,11 @@ class Inbound(object): if self._paused_subchannels: self._connection.pauseProducing() + def subchannel_local_open(self, scid, sc): + assert ISubChannel.providedBy(sc) + assert scid not in self._open_subchannels + self._open_subchannels[scid] = sc + # Inbound is responsible for tracking the high watermark and deciding # whether to ignore inbound messages or not diff --git a/src/wormhole/_dilation/manager.py b/src/wormhole/_dilation/manager.py index 1e8262a..5ea2ae6 100644 --- a/src/wormhole/_dilation/manager.py +++ b/src/wormhole/_dilation/manager.py @@ -158,6 +158,9 @@ class Manager(object): def subchannel_stopProducing(self, sc): self._inbound.subchannel_stopProducing(sc) + def subchannel_local_open(self, scid, sc): + self._inbound.subchannel_local_open(scid, sc) + # forward outbound-ish things to _Outbound def subchannel_registerProducer(self, sc, producer, streaming): self._outbound.subchannel_registerProducer(sc, producer, streaming) diff --git a/src/wormhole/_dilation/outbound.py b/src/wormhole/_dilation/outbound.py index 96786ca..c958849 100644 --- a/src/wormhole/_dilation/outbound.py +++ b/src/wormhole/_dilation/outbound.py @@ -253,7 +253,7 @@ class Outbound(object): self._unpaused_producers.discard(p) self._check_invariants() - def subchannel_closed(self, sc): + def subchannel_closed(self, scid, sc): self._check_invariants() if sc in self._subchannel_producers: self.subchannel_unregisterProducer(sc) diff --git a/src/wormhole/_dilation/subchannel.py b/src/wormhole/_dilation/subchannel.py index b32fd4e..8b0e7bf 100644 --- a/src/wormhole/_dilation/subchannel.py +++ b/src/wormhole/_dilation/subchannel.py @@ -59,7 +59,7 @@ class _SubchannelAddress(object): _scid = attrib(validator=instance_of(six.integer_types)) -@attrs +@attrs(cmp=False) @implementer(ITransport) @implementer(IProducer) @implementer(IConsumer) @@ -131,7 +131,7 @@ class SubChannel(object): self._protocol.connectionLost(ConnectionDone()) else: self._pending_connectionLost = (True, ConnectionDone()) - self._manager.subchannel_closed(self) + self._manager.subchannel_closed(self._scid, self) # we're deleted momentarily @m.output() @@ -146,7 +146,7 @@ class SubChannel(object): # primary transitions open.upon(remote_data, enter=open, outputs=[signal_dataReceived]) open.upon(local_data, enter=open, outputs=[send_data]) - open.upon(remote_close, enter=closed, outputs=[signal_connectionLost]) + open.upon(remote_close, enter=closed, outputs=[send_close, signal_connectionLost]) open.upon(local_close, enter=closing, outputs=[send_close]) closing.upon(remote_data, enter=closing, outputs=[signal_dataReceived]) closing.upon(remote_close, enter=closed, outputs=[signal_connectionLost]) @@ -245,10 +245,11 @@ class SubchannelConnectorEndpoint(object): peer_addr = _SubchannelAddress(scid) # ? f.doStart() # ? f.startedConnecting(CONNECTOR) # ?? - t = SubChannel(scid, self._manager, self._host_addr, peer_addr) + sc = SubChannel(scid, self._manager, self._host_addr, peer_addr) + self._manager.subchannel_local_open(scid, sc) p = protocolFactory.buildProtocol(peer_addr) - t._set_protocol(p) - p.makeConnection(t) # set p.transport = t and call connectionMade() + sc._set_protocol(p) + p.makeConnection(sc) # set p.transport = sc and call connectionMade() return succeed(p) diff --git a/src/wormhole/test/dilate/test_full.py b/src/wormhole/test/dilate/test_full.py index 25025f0..cc0e8b0 100644 --- a/src/wormhole/test/dilate/test_full.py +++ b/src/wormhole/test/dilate/test_full.py @@ -267,3 +267,75 @@ class Reconnect(ServerBase, unittest.TestCase): yield w1.close() yield w2.close() + + +class Endpoints(ServerBase, unittest.TestCase): + @inlineCallbacks + def setUp(self): + if not NoiseConnection: + raise unittest.SkipTest("noiseprotocol unavailable") + # test_welcome wants to see [current_cli_version] + yield self._setup_relay(None) + + @inlineCallbacks + def test_endpoints(self): + eq = EventualQueue(reactor) + w1 = wormhole.create(APPID, self.relayurl, reactor, _enable_dilate=True) + w2 = wormhole.create(APPID, self.relayurl, reactor, _enable_dilate=True) + w1.allocate_code() + code = yield w1.get_code() + w2.set_code(code) + yield doBoth(w1.get_verifier(), w2.get_verifier()) + + eps1_d = w1.dilate() + eps2_d = w2.dilate() + (eps1, eps2) = yield doBoth(eps1_d, eps2_d) + (control_ep1, connect_ep1, listen_ep1) = eps1 + (control_ep2, connect_ep2, listen_ep2) = eps2 + + f0 = ReconF(eq) + yield listen_ep2.listen(f0) + + from twisted.python import log + f1 = ReconF(eq) + log.msg("connecting") + p1_client = yield connect_ep1.connect(f1) + log.msg("sending c->s") + p1_client.transport.write(b"hello from p1\n") + data = yield f0.deferreds["dataReceived"] + self.assertEqual(data, b"hello from p1\n") + p1_server = self.successResultOf(f0.deferreds["connectionMade"]) + log.msg("sending s->c") + p1_server.transport.write(b"hello p1\n") + log.msg("waiting for client to receive") + data = yield f1.deferreds["dataReceived"] + self.assertEqual(data, b"hello p1\n") + + # open a second channel + f0.resetDeferred("connectionMade") + f0.resetDeferred("dataReceived") + f1.resetDeferred("dataReceived") + f2 = ReconF(eq) + p2_client = yield connect_ep1.connect(f2) + p2_server = yield f0.deferreds["connectionMade"] + p2_server.transport.write(b"hello p2\n") + data = yield f2.deferreds["dataReceived"] + self.assertEqual(data, b"hello p2\n") + p2_client.transport.write(b"hello from p2\n") + data = yield f0.deferreds["dataReceived"] + self.assertEqual(data, b"hello from p2\n") + self.assertNoResult(f1.deferreds["dataReceived"]) + + # now close the first subchannel (p1) from the listener side + p1_server.transport.loseConnection() + yield f0.deferreds["connectionLost"] + yield f1.deferreds["connectionLost"] + + f0.resetDeferred("connectionLost") + # and close the second from the connector side + p2_client.transport.loseConnection() + yield f0.deferreds["connectionLost"] + yield f2.deferreds["connectionLost"] + + yield w1.close() + yield w2.close() diff --git a/src/wormhole/test/dilate/test_manager.py b/src/wormhole/test/dilate/test_manager.py index 8ae0b20..e36e403 100644 --- a/src/wormhole/test/dilate/test_manager.py +++ b/src/wormhole/test/dilate/test_manager.py @@ -659,9 +659,9 @@ class TestManager(unittest.TestCase): mock.call.subchannel_unregisterProducer(sc)]) clear_mock_calls(h.outbound) - m.subchannel_closed("scid", sc) + m.subchannel_closed(4, sc) self.assertEqual(h.inbound.mock_calls, [ - mock.call.subchannel_closed("scid", sc)]) + mock.call.subchannel_closed(4, sc)]) self.assertEqual(h.outbound.mock_calls, [ - mock.call.subchannel_closed("scid", sc)]) + mock.call.subchannel_closed(4, sc)]) clear_mock_calls(h.inbound, h.outbound) diff --git a/src/wormhole/test/dilate/test_outbound.py b/src/wormhole/test/dilate/test_outbound.py index ed43a47..7521f2d 100644 --- a/src/wormhole/test/dilate/test_outbound.py +++ b/src/wormhole/test/dilate/test_outbound.py @@ -337,12 +337,12 @@ class OutboundTest(unittest.TestCase): self.assertEqual(p1.mock_calls, [mock.call.pauseProducing()]) clear_mock_calls(p1) - o.subchannel_closed(sc1) + o.subchannel_closed(1, sc1) self.assertEqual(p1.mock_calls, []) self.assertEqual(list(o._all_producers), []) sc2 = mock.Mock() - o.subchannel_closed(sc2) + o.subchannel_closed(2, sc2) def test_disconnect(self): o, m, c = make_outbound() diff --git a/src/wormhole/test/dilate/test_subchannel.py b/src/wormhole/test/dilate/test_subchannel.py index b079d00..c26248a 100644 --- a/src/wormhole/test/dilate/test_subchannel.py +++ b/src/wormhole/test/dilate/test_subchannel.py @@ -92,7 +92,8 @@ class SubChannelAPI(unittest.TestCase): def test_remote_close(self): sc, m, scid, hostaddr, peeraddr, p = make_sc() sc.remote_close() - self.assertEqual(m.mock_calls, [mock.call.subchannel_closed(sc)]) + self.assertEqual(m.mock_calls, [mock.call.send_close(scid), + mock.call.subchannel_closed(scid, sc)]) self.assert_connectionDone(p.mock_calls) def test_data(self):