fix subchannel open/close, add test

I think I just managed to forget that inbound_close requires we respond with
a close ourselves. Also outbound open means we must add the subchannel to the
inbound table, so we can receive any data on it at all.
This commit is contained in:
Brian Warner 2019-07-06 01:50:29 -07:00
parent 8043e508fa
commit d1aefa815d
8 changed files with 96 additions and 14 deletions

View File

@ -3,7 +3,7 @@ from attr import attrs, attrib
from attr.validators import provides from attr.validators import provides
from zope.interface import implementer from zope.interface import implementer
from twisted.python import log from twisted.python import log
from .._interfaces import IDilationManager, IInbound from .._interfaces import IDilationManager, IInbound, ISubChannel
from .subchannel import (SubChannel, _SubchannelAddress) from .subchannel import (SubChannel, _SubchannelAddress)
@ -52,6 +52,11 @@ class Inbound(object):
if self._paused_subchannels: if self._paused_subchannels:
self._connection.pauseProducing() self._connection.pauseProducing()
def subchannel_local_open(self, scid, sc):
assert ISubChannel.providedBy(sc)
assert scid not in self._open_subchannels
self._open_subchannels[scid] = sc
# Inbound is responsible for tracking the high watermark and deciding # Inbound is responsible for tracking the high watermark and deciding
# whether to ignore inbound messages or not # whether to ignore inbound messages or not

View File

@ -158,6 +158,9 @@ class Manager(object):
def subchannel_stopProducing(self, sc): def subchannel_stopProducing(self, sc):
self._inbound.subchannel_stopProducing(sc) self._inbound.subchannel_stopProducing(sc)
def subchannel_local_open(self, scid, sc):
self._inbound.subchannel_local_open(scid, sc)
# forward outbound-ish things to _Outbound # forward outbound-ish things to _Outbound
def subchannel_registerProducer(self, sc, producer, streaming): def subchannel_registerProducer(self, sc, producer, streaming):
self._outbound.subchannel_registerProducer(sc, producer, streaming) self._outbound.subchannel_registerProducer(sc, producer, streaming)

View File

@ -253,7 +253,7 @@ class Outbound(object):
self._unpaused_producers.discard(p) self._unpaused_producers.discard(p)
self._check_invariants() self._check_invariants()
def subchannel_closed(self, sc): def subchannel_closed(self, scid, sc):
self._check_invariants() self._check_invariants()
if sc in self._subchannel_producers: if sc in self._subchannel_producers:
self.subchannel_unregisterProducer(sc) self.subchannel_unregisterProducer(sc)

View File

@ -59,7 +59,7 @@ class _SubchannelAddress(object):
_scid = attrib(validator=instance_of(six.integer_types)) _scid = attrib(validator=instance_of(six.integer_types))
@attrs @attrs(cmp=False)
@implementer(ITransport) @implementer(ITransport)
@implementer(IProducer) @implementer(IProducer)
@implementer(IConsumer) @implementer(IConsumer)
@ -131,7 +131,7 @@ class SubChannel(object):
self._protocol.connectionLost(ConnectionDone()) self._protocol.connectionLost(ConnectionDone())
else: else:
self._pending_connectionLost = (True, ConnectionDone()) self._pending_connectionLost = (True, ConnectionDone())
self._manager.subchannel_closed(self) self._manager.subchannel_closed(self._scid, self)
# we're deleted momentarily # we're deleted momentarily
@m.output() @m.output()
@ -146,7 +146,7 @@ class SubChannel(object):
# primary transitions # primary transitions
open.upon(remote_data, enter=open, outputs=[signal_dataReceived]) open.upon(remote_data, enter=open, outputs=[signal_dataReceived])
open.upon(local_data, enter=open, outputs=[send_data]) open.upon(local_data, enter=open, outputs=[send_data])
open.upon(remote_close, enter=closed, outputs=[signal_connectionLost]) open.upon(remote_close, enter=closed, outputs=[send_close, signal_connectionLost])
open.upon(local_close, enter=closing, outputs=[send_close]) open.upon(local_close, enter=closing, outputs=[send_close])
closing.upon(remote_data, enter=closing, outputs=[signal_dataReceived]) closing.upon(remote_data, enter=closing, outputs=[signal_dataReceived])
closing.upon(remote_close, enter=closed, outputs=[signal_connectionLost]) closing.upon(remote_close, enter=closed, outputs=[signal_connectionLost])
@ -245,10 +245,11 @@ class SubchannelConnectorEndpoint(object):
peer_addr = _SubchannelAddress(scid) peer_addr = _SubchannelAddress(scid)
# ? f.doStart() # ? f.doStart()
# ? f.startedConnecting(CONNECTOR) # ?? # ? f.startedConnecting(CONNECTOR) # ??
t = SubChannel(scid, self._manager, self._host_addr, peer_addr) sc = SubChannel(scid, self._manager, self._host_addr, peer_addr)
self._manager.subchannel_local_open(scid, sc)
p = protocolFactory.buildProtocol(peer_addr) p = protocolFactory.buildProtocol(peer_addr)
t._set_protocol(p) sc._set_protocol(p)
p.makeConnection(t) # set p.transport = t and call connectionMade() p.makeConnection(sc) # set p.transport = sc and call connectionMade()
return succeed(p) return succeed(p)

View File

@ -267,3 +267,75 @@ class Reconnect(ServerBase, unittest.TestCase):
yield w1.close() yield w1.close()
yield w2.close() yield w2.close()
class Endpoints(ServerBase, unittest.TestCase):
@inlineCallbacks
def setUp(self):
if not NoiseConnection:
raise unittest.SkipTest("noiseprotocol unavailable")
# test_welcome wants to see [current_cli_version]
yield self._setup_relay(None)
@inlineCallbacks
def test_endpoints(self):
eq = EventualQueue(reactor)
w1 = wormhole.create(APPID, self.relayurl, reactor, _enable_dilate=True)
w2 = wormhole.create(APPID, self.relayurl, reactor, _enable_dilate=True)
w1.allocate_code()
code = yield w1.get_code()
w2.set_code(code)
yield doBoth(w1.get_verifier(), w2.get_verifier())
eps1_d = w1.dilate()
eps2_d = w2.dilate()
(eps1, eps2) = yield doBoth(eps1_d, eps2_d)
(control_ep1, connect_ep1, listen_ep1) = eps1
(control_ep2, connect_ep2, listen_ep2) = eps2
f0 = ReconF(eq)
yield listen_ep2.listen(f0)
from twisted.python import log
f1 = ReconF(eq)
log.msg("connecting")
p1_client = yield connect_ep1.connect(f1)
log.msg("sending c->s")
p1_client.transport.write(b"hello from p1\n")
data = yield f0.deferreds["dataReceived"]
self.assertEqual(data, b"hello from p1\n")
p1_server = self.successResultOf(f0.deferreds["connectionMade"])
log.msg("sending s->c")
p1_server.transport.write(b"hello p1\n")
log.msg("waiting for client to receive")
data = yield f1.deferreds["dataReceived"]
self.assertEqual(data, b"hello p1\n")
# open a second channel
f0.resetDeferred("connectionMade")
f0.resetDeferred("dataReceived")
f1.resetDeferred("dataReceived")
f2 = ReconF(eq)
p2_client = yield connect_ep1.connect(f2)
p2_server = yield f0.deferreds["connectionMade"]
p2_server.transport.write(b"hello p2\n")
data = yield f2.deferreds["dataReceived"]
self.assertEqual(data, b"hello p2\n")
p2_client.transport.write(b"hello from p2\n")
data = yield f0.deferreds["dataReceived"]
self.assertEqual(data, b"hello from p2\n")
self.assertNoResult(f1.deferreds["dataReceived"])
# now close the first subchannel (p1) from the listener side
p1_server.transport.loseConnection()
yield f0.deferreds["connectionLost"]
yield f1.deferreds["connectionLost"]
f0.resetDeferred("connectionLost")
# and close the second from the connector side
p2_client.transport.loseConnection()
yield f0.deferreds["connectionLost"]
yield f2.deferreds["connectionLost"]
yield w1.close()
yield w2.close()

View File

@ -659,9 +659,9 @@ class TestManager(unittest.TestCase):
mock.call.subchannel_unregisterProducer(sc)]) mock.call.subchannel_unregisterProducer(sc)])
clear_mock_calls(h.outbound) clear_mock_calls(h.outbound)
m.subchannel_closed("scid", sc) m.subchannel_closed(4, sc)
self.assertEqual(h.inbound.mock_calls, [ self.assertEqual(h.inbound.mock_calls, [
mock.call.subchannel_closed("scid", sc)]) mock.call.subchannel_closed(4, sc)])
self.assertEqual(h.outbound.mock_calls, [ self.assertEqual(h.outbound.mock_calls, [
mock.call.subchannel_closed("scid", sc)]) mock.call.subchannel_closed(4, sc)])
clear_mock_calls(h.inbound, h.outbound) clear_mock_calls(h.inbound, h.outbound)

View File

@ -337,12 +337,12 @@ class OutboundTest(unittest.TestCase):
self.assertEqual(p1.mock_calls, [mock.call.pauseProducing()]) self.assertEqual(p1.mock_calls, [mock.call.pauseProducing()])
clear_mock_calls(p1) clear_mock_calls(p1)
o.subchannel_closed(sc1) o.subchannel_closed(1, sc1)
self.assertEqual(p1.mock_calls, []) self.assertEqual(p1.mock_calls, [])
self.assertEqual(list(o._all_producers), []) self.assertEqual(list(o._all_producers), [])
sc2 = mock.Mock() sc2 = mock.Mock()
o.subchannel_closed(sc2) o.subchannel_closed(2, sc2)
def test_disconnect(self): def test_disconnect(self):
o, m, c = make_outbound() o, m, c = make_outbound()

View File

@ -92,7 +92,8 @@ class SubChannelAPI(unittest.TestCase):
def test_remote_close(self): def test_remote_close(self):
sc, m, scid, hostaddr, peeraddr, p = make_sc() sc, m, scid, hostaddr, peeraddr, p = make_sc()
sc.remote_close() sc.remote_close()
self.assertEqual(m.mock_calls, [mock.call.subchannel_closed(sc)]) self.assertEqual(m.mock_calls, [mock.call.send_close(scid),
mock.call.subchannel_closed(scid, sc)])
self.assert_connectionDone(p.mock_calls) self.assert_connectionDone(p.mock_calls)
def test_data(self): def test_data(self):