fix subchannel open/close, add test

I think I just managed to forget that inbound_close requires we respond with
a close ourselves. Also outbound open means we must add the subchannel to the
inbound table, so we can receive any data on it at all.
This commit is contained in:
Brian Warner 2019-07-06 01:50:29 -07:00
parent 8043e508fa
commit d1aefa815d
8 changed files with 96 additions and 14 deletions

View File

@ -3,7 +3,7 @@ from attr import attrs, attrib
from attr.validators import provides
from zope.interface import implementer
from twisted.python import log
from .._interfaces import IDilationManager, IInbound
from .._interfaces import IDilationManager, IInbound, ISubChannel
from .subchannel import (SubChannel, _SubchannelAddress)
@ -52,6 +52,11 @@ class Inbound(object):
if self._paused_subchannels:
self._connection.pauseProducing()
def subchannel_local_open(self, scid, sc):
assert ISubChannel.providedBy(sc)
assert scid not in self._open_subchannels
self._open_subchannels[scid] = sc
# Inbound is responsible for tracking the high watermark and deciding
# whether to ignore inbound messages or not

View File

@ -158,6 +158,9 @@ class Manager(object):
def subchannel_stopProducing(self, sc):
self._inbound.subchannel_stopProducing(sc)
def subchannel_local_open(self, scid, sc):
self._inbound.subchannel_local_open(scid, sc)
# forward outbound-ish things to _Outbound
def subchannel_registerProducer(self, sc, producer, streaming):
self._outbound.subchannel_registerProducer(sc, producer, streaming)

View File

@ -253,7 +253,7 @@ class Outbound(object):
self._unpaused_producers.discard(p)
self._check_invariants()
def subchannel_closed(self, sc):
def subchannel_closed(self, scid, sc):
self._check_invariants()
if sc in self._subchannel_producers:
self.subchannel_unregisterProducer(sc)

View File

@ -59,7 +59,7 @@ class _SubchannelAddress(object):
_scid = attrib(validator=instance_of(six.integer_types))
@attrs
@attrs(cmp=False)
@implementer(ITransport)
@implementer(IProducer)
@implementer(IConsumer)
@ -131,7 +131,7 @@ class SubChannel(object):
self._protocol.connectionLost(ConnectionDone())
else:
self._pending_connectionLost = (True, ConnectionDone())
self._manager.subchannel_closed(self)
self._manager.subchannel_closed(self._scid, self)
# we're deleted momentarily
@m.output()
@ -146,7 +146,7 @@ class SubChannel(object):
# primary transitions
open.upon(remote_data, enter=open, outputs=[signal_dataReceived])
open.upon(local_data, enter=open, outputs=[send_data])
open.upon(remote_close, enter=closed, outputs=[signal_connectionLost])
open.upon(remote_close, enter=closed, outputs=[send_close, signal_connectionLost])
open.upon(local_close, enter=closing, outputs=[send_close])
closing.upon(remote_data, enter=closing, outputs=[signal_dataReceived])
closing.upon(remote_close, enter=closed, outputs=[signal_connectionLost])
@ -245,10 +245,11 @@ class SubchannelConnectorEndpoint(object):
peer_addr = _SubchannelAddress(scid)
# ? f.doStart()
# ? f.startedConnecting(CONNECTOR) # ??
t = SubChannel(scid, self._manager, self._host_addr, peer_addr)
sc = SubChannel(scid, self._manager, self._host_addr, peer_addr)
self._manager.subchannel_local_open(scid, sc)
p = protocolFactory.buildProtocol(peer_addr)
t._set_protocol(p)
p.makeConnection(t) # set p.transport = t and call connectionMade()
sc._set_protocol(p)
p.makeConnection(sc) # set p.transport = sc and call connectionMade()
return succeed(p)

View File

@ -267,3 +267,75 @@ class Reconnect(ServerBase, unittest.TestCase):
yield w1.close()
yield w2.close()
class Endpoints(ServerBase, unittest.TestCase):
@inlineCallbacks
def setUp(self):
if not NoiseConnection:
raise unittest.SkipTest("noiseprotocol unavailable")
# test_welcome wants to see [current_cli_version]
yield self._setup_relay(None)
@inlineCallbacks
def test_endpoints(self):
eq = EventualQueue(reactor)
w1 = wormhole.create(APPID, self.relayurl, reactor, _enable_dilate=True)
w2 = wormhole.create(APPID, self.relayurl, reactor, _enable_dilate=True)
w1.allocate_code()
code = yield w1.get_code()
w2.set_code(code)
yield doBoth(w1.get_verifier(), w2.get_verifier())
eps1_d = w1.dilate()
eps2_d = w2.dilate()
(eps1, eps2) = yield doBoth(eps1_d, eps2_d)
(control_ep1, connect_ep1, listen_ep1) = eps1
(control_ep2, connect_ep2, listen_ep2) = eps2
f0 = ReconF(eq)
yield listen_ep2.listen(f0)
from twisted.python import log
f1 = ReconF(eq)
log.msg("connecting")
p1_client = yield connect_ep1.connect(f1)
log.msg("sending c->s")
p1_client.transport.write(b"hello from p1\n")
data = yield f0.deferreds["dataReceived"]
self.assertEqual(data, b"hello from p1\n")
p1_server = self.successResultOf(f0.deferreds["connectionMade"])
log.msg("sending s->c")
p1_server.transport.write(b"hello p1\n")
log.msg("waiting for client to receive")
data = yield f1.deferreds["dataReceived"]
self.assertEqual(data, b"hello p1\n")
# open a second channel
f0.resetDeferred("connectionMade")
f0.resetDeferred("dataReceived")
f1.resetDeferred("dataReceived")
f2 = ReconF(eq)
p2_client = yield connect_ep1.connect(f2)
p2_server = yield f0.deferreds["connectionMade"]
p2_server.transport.write(b"hello p2\n")
data = yield f2.deferreds["dataReceived"]
self.assertEqual(data, b"hello p2\n")
p2_client.transport.write(b"hello from p2\n")
data = yield f0.deferreds["dataReceived"]
self.assertEqual(data, b"hello from p2\n")
self.assertNoResult(f1.deferreds["dataReceived"])
# now close the first subchannel (p1) from the listener side
p1_server.transport.loseConnection()
yield f0.deferreds["connectionLost"]
yield f1.deferreds["connectionLost"]
f0.resetDeferred("connectionLost")
# and close the second from the connector side
p2_client.transport.loseConnection()
yield f0.deferreds["connectionLost"]
yield f2.deferreds["connectionLost"]
yield w1.close()
yield w2.close()

View File

@ -659,9 +659,9 @@ class TestManager(unittest.TestCase):
mock.call.subchannel_unregisterProducer(sc)])
clear_mock_calls(h.outbound)
m.subchannel_closed("scid", sc)
m.subchannel_closed(4, sc)
self.assertEqual(h.inbound.mock_calls, [
mock.call.subchannel_closed("scid", sc)])
mock.call.subchannel_closed(4, sc)])
self.assertEqual(h.outbound.mock_calls, [
mock.call.subchannel_closed("scid", sc)])
mock.call.subchannel_closed(4, sc)])
clear_mock_calls(h.inbound, h.outbound)

View File

@ -337,12 +337,12 @@ class OutboundTest(unittest.TestCase):
self.assertEqual(p1.mock_calls, [mock.call.pauseProducing()])
clear_mock_calls(p1)
o.subchannel_closed(sc1)
o.subchannel_closed(1, sc1)
self.assertEqual(p1.mock_calls, [])
self.assertEqual(list(o._all_producers), [])
sc2 = mock.Mock()
o.subchannel_closed(sc2)
o.subchannel_closed(2, sc2)
def test_disconnect(self):
o, m, c = make_outbound()

View File

@ -92,7 +92,8 @@ class SubChannelAPI(unittest.TestCase):
def test_remote_close(self):
sc, m, scid, hostaddr, peeraddr, p = make_sc()
sc.remote_close()
self.assertEqual(m.mock_calls, [mock.call.subchannel_closed(sc)])
self.assertEqual(m.mock_calls, [mock.call.send_close(scid),
mock.call.subchannel_closed(scid, sc)])
self.assert_connectionDone(p.mock_calls)
def test_data(self):