fix subchannel open/close, add test
I think I just managed to forget that inbound_close requires we respond with a close ourselves. Also outbound open means we must add the subchannel to the inbound table, so we can receive any data on it at all.
This commit is contained in:
parent
8043e508fa
commit
d1aefa815d
|
@ -3,7 +3,7 @@ from attr import attrs, attrib
|
|||
from attr.validators import provides
|
||||
from zope.interface import implementer
|
||||
from twisted.python import log
|
||||
from .._interfaces import IDilationManager, IInbound
|
||||
from .._interfaces import IDilationManager, IInbound, ISubChannel
|
||||
from .subchannel import (SubChannel, _SubchannelAddress)
|
||||
|
||||
|
||||
|
@ -52,6 +52,11 @@ class Inbound(object):
|
|||
if self._paused_subchannels:
|
||||
self._connection.pauseProducing()
|
||||
|
||||
def subchannel_local_open(self, scid, sc):
|
||||
assert ISubChannel.providedBy(sc)
|
||||
assert scid not in self._open_subchannels
|
||||
self._open_subchannels[scid] = sc
|
||||
|
||||
# Inbound is responsible for tracking the high watermark and deciding
|
||||
# whether to ignore inbound messages or not
|
||||
|
||||
|
|
|
@ -158,6 +158,9 @@ class Manager(object):
|
|||
def subchannel_stopProducing(self, sc):
|
||||
self._inbound.subchannel_stopProducing(sc)
|
||||
|
||||
def subchannel_local_open(self, scid, sc):
|
||||
self._inbound.subchannel_local_open(scid, sc)
|
||||
|
||||
# forward outbound-ish things to _Outbound
|
||||
def subchannel_registerProducer(self, sc, producer, streaming):
|
||||
self._outbound.subchannel_registerProducer(sc, producer, streaming)
|
||||
|
|
|
@ -253,7 +253,7 @@ class Outbound(object):
|
|||
self._unpaused_producers.discard(p)
|
||||
self._check_invariants()
|
||||
|
||||
def subchannel_closed(self, sc):
|
||||
def subchannel_closed(self, scid, sc):
|
||||
self._check_invariants()
|
||||
if sc in self._subchannel_producers:
|
||||
self.subchannel_unregisterProducer(sc)
|
||||
|
|
|
@ -59,7 +59,7 @@ class _SubchannelAddress(object):
|
|||
_scid = attrib(validator=instance_of(six.integer_types))
|
||||
|
||||
|
||||
@attrs
|
||||
@attrs(cmp=False)
|
||||
@implementer(ITransport)
|
||||
@implementer(IProducer)
|
||||
@implementer(IConsumer)
|
||||
|
@ -131,7 +131,7 @@ class SubChannel(object):
|
|||
self._protocol.connectionLost(ConnectionDone())
|
||||
else:
|
||||
self._pending_connectionLost = (True, ConnectionDone())
|
||||
self._manager.subchannel_closed(self)
|
||||
self._manager.subchannel_closed(self._scid, self)
|
||||
# we're deleted momentarily
|
||||
|
||||
@m.output()
|
||||
|
@ -146,7 +146,7 @@ class SubChannel(object):
|
|||
# primary transitions
|
||||
open.upon(remote_data, enter=open, outputs=[signal_dataReceived])
|
||||
open.upon(local_data, enter=open, outputs=[send_data])
|
||||
open.upon(remote_close, enter=closed, outputs=[signal_connectionLost])
|
||||
open.upon(remote_close, enter=closed, outputs=[send_close, signal_connectionLost])
|
||||
open.upon(local_close, enter=closing, outputs=[send_close])
|
||||
closing.upon(remote_data, enter=closing, outputs=[signal_dataReceived])
|
||||
closing.upon(remote_close, enter=closed, outputs=[signal_connectionLost])
|
||||
|
@ -245,10 +245,11 @@ class SubchannelConnectorEndpoint(object):
|
|||
peer_addr = _SubchannelAddress(scid)
|
||||
# ? f.doStart()
|
||||
# ? f.startedConnecting(CONNECTOR) # ??
|
||||
t = SubChannel(scid, self._manager, self._host_addr, peer_addr)
|
||||
sc = SubChannel(scid, self._manager, self._host_addr, peer_addr)
|
||||
self._manager.subchannel_local_open(scid, sc)
|
||||
p = protocolFactory.buildProtocol(peer_addr)
|
||||
t._set_protocol(p)
|
||||
p.makeConnection(t) # set p.transport = t and call connectionMade()
|
||||
sc._set_protocol(p)
|
||||
p.makeConnection(sc) # set p.transport = sc and call connectionMade()
|
||||
return succeed(p)
|
||||
|
||||
|
||||
|
|
|
@ -267,3 +267,75 @@ class Reconnect(ServerBase, unittest.TestCase):
|
|||
|
||||
yield w1.close()
|
||||
yield w2.close()
|
||||
|
||||
|
||||
class Endpoints(ServerBase, unittest.TestCase):
|
||||
@inlineCallbacks
|
||||
def setUp(self):
|
||||
if not NoiseConnection:
|
||||
raise unittest.SkipTest("noiseprotocol unavailable")
|
||||
# test_welcome wants to see [current_cli_version]
|
||||
yield self._setup_relay(None)
|
||||
|
||||
@inlineCallbacks
|
||||
def test_endpoints(self):
|
||||
eq = EventualQueue(reactor)
|
||||
w1 = wormhole.create(APPID, self.relayurl, reactor, _enable_dilate=True)
|
||||
w2 = wormhole.create(APPID, self.relayurl, reactor, _enable_dilate=True)
|
||||
w1.allocate_code()
|
||||
code = yield w1.get_code()
|
||||
w2.set_code(code)
|
||||
yield doBoth(w1.get_verifier(), w2.get_verifier())
|
||||
|
||||
eps1_d = w1.dilate()
|
||||
eps2_d = w2.dilate()
|
||||
(eps1, eps2) = yield doBoth(eps1_d, eps2_d)
|
||||
(control_ep1, connect_ep1, listen_ep1) = eps1
|
||||
(control_ep2, connect_ep2, listen_ep2) = eps2
|
||||
|
||||
f0 = ReconF(eq)
|
||||
yield listen_ep2.listen(f0)
|
||||
|
||||
from twisted.python import log
|
||||
f1 = ReconF(eq)
|
||||
log.msg("connecting")
|
||||
p1_client = yield connect_ep1.connect(f1)
|
||||
log.msg("sending c->s")
|
||||
p1_client.transport.write(b"hello from p1\n")
|
||||
data = yield f0.deferreds["dataReceived"]
|
||||
self.assertEqual(data, b"hello from p1\n")
|
||||
p1_server = self.successResultOf(f0.deferreds["connectionMade"])
|
||||
log.msg("sending s->c")
|
||||
p1_server.transport.write(b"hello p1\n")
|
||||
log.msg("waiting for client to receive")
|
||||
data = yield f1.deferreds["dataReceived"]
|
||||
self.assertEqual(data, b"hello p1\n")
|
||||
|
||||
# open a second channel
|
||||
f0.resetDeferred("connectionMade")
|
||||
f0.resetDeferred("dataReceived")
|
||||
f1.resetDeferred("dataReceived")
|
||||
f2 = ReconF(eq)
|
||||
p2_client = yield connect_ep1.connect(f2)
|
||||
p2_server = yield f0.deferreds["connectionMade"]
|
||||
p2_server.transport.write(b"hello p2\n")
|
||||
data = yield f2.deferreds["dataReceived"]
|
||||
self.assertEqual(data, b"hello p2\n")
|
||||
p2_client.transport.write(b"hello from p2\n")
|
||||
data = yield f0.deferreds["dataReceived"]
|
||||
self.assertEqual(data, b"hello from p2\n")
|
||||
self.assertNoResult(f1.deferreds["dataReceived"])
|
||||
|
||||
# now close the first subchannel (p1) from the listener side
|
||||
p1_server.transport.loseConnection()
|
||||
yield f0.deferreds["connectionLost"]
|
||||
yield f1.deferreds["connectionLost"]
|
||||
|
||||
f0.resetDeferred("connectionLost")
|
||||
# and close the second from the connector side
|
||||
p2_client.transport.loseConnection()
|
||||
yield f0.deferreds["connectionLost"]
|
||||
yield f2.deferreds["connectionLost"]
|
||||
|
||||
yield w1.close()
|
||||
yield w2.close()
|
||||
|
|
|
@ -659,9 +659,9 @@ class TestManager(unittest.TestCase):
|
|||
mock.call.subchannel_unregisterProducer(sc)])
|
||||
clear_mock_calls(h.outbound)
|
||||
|
||||
m.subchannel_closed("scid", sc)
|
||||
m.subchannel_closed(4, sc)
|
||||
self.assertEqual(h.inbound.mock_calls, [
|
||||
mock.call.subchannel_closed("scid", sc)])
|
||||
mock.call.subchannel_closed(4, sc)])
|
||||
self.assertEqual(h.outbound.mock_calls, [
|
||||
mock.call.subchannel_closed("scid", sc)])
|
||||
mock.call.subchannel_closed(4, sc)])
|
||||
clear_mock_calls(h.inbound, h.outbound)
|
||||
|
|
|
@ -337,12 +337,12 @@ class OutboundTest(unittest.TestCase):
|
|||
self.assertEqual(p1.mock_calls, [mock.call.pauseProducing()])
|
||||
clear_mock_calls(p1)
|
||||
|
||||
o.subchannel_closed(sc1)
|
||||
o.subchannel_closed(1, sc1)
|
||||
self.assertEqual(p1.mock_calls, [])
|
||||
self.assertEqual(list(o._all_producers), [])
|
||||
|
||||
sc2 = mock.Mock()
|
||||
o.subchannel_closed(sc2)
|
||||
o.subchannel_closed(2, sc2)
|
||||
|
||||
def test_disconnect(self):
|
||||
o, m, c = make_outbound()
|
||||
|
|
|
@ -92,7 +92,8 @@ class SubChannelAPI(unittest.TestCase):
|
|||
def test_remote_close(self):
|
||||
sc, m, scid, hostaddr, peeraddr, p = make_sc()
|
||||
sc.remote_close()
|
||||
self.assertEqual(m.mock_calls, [mock.call.subchannel_closed(sc)])
|
||||
self.assertEqual(m.mock_calls, [mock.call.send_close(scid),
|
||||
mock.call.subchannel_closed(scid, sc)])
|
||||
self.assert_connectionDone(p.mock_calls)
|
||||
|
||||
def test_data(self):
|
||||
|
|
Loading…
Reference in New Issue
Block a user