twisted.transit: implement producer/consumer flow control

This commit is contained in:
Brian Warner 2016-02-15 21:23:20 -08:00
parent 763d72f582
commit a235b507c8
2 changed files with 86 additions and 5 deletions

View File

@ -4,6 +4,7 @@ from twisted.trial import unittest
from twisted.internet import defer, task, endpoints, protocol, address, error from twisted.internet import defer, task, endpoints, protocol, address, error
from twisted.internet.defer import gatherResults, inlineCallbacks from twisted.internet.defer import gatherResults, inlineCallbacks
from twisted.python import log, failure from twisted.python import log, failure
from twisted.test import proto_helpers
from ..twisted import transit from ..twisted import transit
from ..errors import UsageError from ..errors import UsageError
from nacl.secret import SecretBox from nacl.secret import SecretBox
@ -991,6 +992,53 @@ class Connection(unittest.TestCase):
self.assertIsInstance(f, failure.Failure) self.assertIsInstance(f, failure.Failure)
self.assertIsInstance(f.value, error.ConnectionClosed) self.assertIsInstance(f.value, error.ConnectionClosed)
def test_producer(self):
# a Transit object (receiving data from the remote peer) produces
# data and writes it into a local Consumer
c = transit.Connection(None, None, None)
c.transport = proto_helpers.StringTransport()
c.recordReceived(b"r1.")
c.recordReceived(b"r2.")
consumer = proto_helpers.StringTransport()
c.connectConsumer(consumer)
self.assertIs(c._consumer, consumer)
self.assertEqual(consumer.value(), b"r1.r2.")
self.assertRaises(RuntimeError, c.connectConsumer, consumer)
c.recordReceived(b"r3.")
self.assertEqual(consumer.value(), b"r1.r2.r3.")
c.pauseProducing()
self.assertEqual(c.transport.producerState, "paused")
c.resumeProducing()
self.assertEqual(c.transport.producerState, "producing")
c.disconnectConsumer()
self.assertEqual(consumer.producer, None)
c.connectConsumer(consumer)
c.stopProducing()
self.assertEqual(c.transport.producerState, "stopped")
def test_consumer(self):
# a local producer sends data to a consuming Transit object
c = transit.Connection(None, None, None)
c.transport = proto_helpers.StringTransport()
records = []
c.send_record = records.append
producer = proto_helpers.StringTransport()
c.registerProducer(producer, True)
self.assertIs(c.transport.producer, producer)
c.write(b"r1.")
self.assertEqual(records, [b"r1."])
c.unregisterProducer()
self.assertEqual(c.transport.producer, None)
DIRECT_HINT = u"tcp:direct:1234" DIRECT_HINT = u"tcp:direct:1234"
RELAY_HINT = u"tcp:relay:1234" RELAY_HINT = u"tcp:relay:1234"

View File

@ -34,6 +34,7 @@ class Connection(protocol.Protocol, policies.TimeoutMixin):
self.start = start self.start = start
self._negotiation_d = defer.Deferred(self._cancel) self._negotiation_d = defer.Deferred(self._cancel)
self._error = None self._error = None
self._consumer = None
self._inbound_records = collections.deque() self._inbound_records = collections.deque()
self._waiting_reads = collections.deque() self._waiting_reads = collections.deque()
@ -173,12 +174,11 @@ class Connection(protocol.Protocol, policies.TimeoutMixin):
self.transport.write(encrypted) self.transport.write(encrypted)
def recordReceived(self, record): def recordReceived(self, record):
if self._consumer:
self._consumer.write(record)
return
self._inbound_records.append(record) self._inbound_records.append(record)
self._deliverRecords() self._deliverRecords()
self._checkPause()
def _checkPause(self):
return # TODO
def receive_record(self): def receive_record(self):
d = defer.Deferred() d = defer.Deferred()
@ -191,7 +191,6 @@ class Connection(protocol.Protocol, policies.TimeoutMixin):
r = self._inbound_records.popleft() r = self._inbound_records.popleft()
d = self._waiting_reads.popleft() d = self._waiting_reads.popleft()
d.callback(r) d.callback(r)
self._checkPause()
def close(self): def close(self):
self.transport.loseConnection() self.transport.loseConnection()
@ -222,7 +221,41 @@ class Connection(protocol.Protocol, policies.TimeoutMixin):
d.errback(self._error or BadHandshake("connection lost")) d.errback(self._error or BadHandshake("connection lost"))
# IConsumer methods, for outbound flow-control. We pass these through to
# the transport. The 'producer' is something like a t.p.basic.FileSender
def registerProducer(self, producer, streaming):
assert interfaces.IConsumer.providedBy(self.transport)
self.transport.registerProducer(producer, streaming)
def unregisterProducer(self):
self.transport.unregisterProducer()
def write(self, data):
self.send_record(data)
# IProducer methods, for inbound flow-control. We pass these through to
# the transport.
def stopProducing(self):
self.transport.stopProducing()
def pauseProducing(self):
self.transport.pauseProducing()
def resumeProducing(self):
self.transport.resumeProducing()
# Helper method to glue an instance of e.g. t.p.ftp.FileConsumer to us.
# Inbound records will be written as bytes to the consumer.
def connectConsumer(self, consumer):
if self._consumer:
raise RuntimeError("A consumer is already attached: %r" %
self._consumer)
self._consumer = consumer
# drain any pending records
while self._inbound_records:
r = self._inbound_records.popleft()
consumer.write(r)
consumer.registerProducer(self, True)
def disconnectConsumer(self):
self._consumer.unregisterProducer()
self._consumer = None
class OutboundConnectionFactory(protocol.ClientFactory): class OutboundConnectionFactory(protocol.ClientFactory):
protocol = Connection protocol = Connection