twisted.transit: implement producer/consumer flow control
This commit is contained in:
parent
763d72f582
commit
a235b507c8
|
@ -4,6 +4,7 @@ from twisted.trial import unittest
|
|||
from twisted.internet import defer, task, endpoints, protocol, address, error
|
||||
from twisted.internet.defer import gatherResults, inlineCallbacks
|
||||
from twisted.python import log, failure
|
||||
from twisted.test import proto_helpers
|
||||
from ..twisted import transit
|
||||
from ..errors import UsageError
|
||||
from nacl.secret import SecretBox
|
||||
|
@ -991,6 +992,53 @@ class Connection(unittest.TestCase):
|
|||
self.assertIsInstance(f, failure.Failure)
|
||||
self.assertIsInstance(f.value, error.ConnectionClosed)
|
||||
|
||||
def test_producer(self):
|
||||
# a Transit object (receiving data from the remote peer) produces
|
||||
# data and writes it into a local Consumer
|
||||
c = transit.Connection(None, None, None)
|
||||
c.transport = proto_helpers.StringTransport()
|
||||
c.recordReceived(b"r1.")
|
||||
c.recordReceived(b"r2.")
|
||||
|
||||
consumer = proto_helpers.StringTransport()
|
||||
c.connectConsumer(consumer)
|
||||
self.assertIs(c._consumer, consumer)
|
||||
self.assertEqual(consumer.value(), b"r1.r2.")
|
||||
|
||||
self.assertRaises(RuntimeError, c.connectConsumer, consumer)
|
||||
|
||||
c.recordReceived(b"r3.")
|
||||
self.assertEqual(consumer.value(), b"r1.r2.r3.")
|
||||
|
||||
c.pauseProducing()
|
||||
self.assertEqual(c.transport.producerState, "paused")
|
||||
c.resumeProducing()
|
||||
self.assertEqual(c.transport.producerState, "producing")
|
||||
|
||||
c.disconnectConsumer()
|
||||
self.assertEqual(consumer.producer, None)
|
||||
c.connectConsumer(consumer)
|
||||
|
||||
c.stopProducing()
|
||||
self.assertEqual(c.transport.producerState, "stopped")
|
||||
|
||||
def test_consumer(self):
|
||||
# a local producer sends data to a consuming Transit object
|
||||
c = transit.Connection(None, None, None)
|
||||
c.transport = proto_helpers.StringTransport()
|
||||
records = []
|
||||
c.send_record = records.append
|
||||
|
||||
producer = proto_helpers.StringTransport()
|
||||
c.registerProducer(producer, True)
|
||||
self.assertIs(c.transport.producer, producer)
|
||||
|
||||
c.write(b"r1.")
|
||||
self.assertEqual(records, [b"r1."])
|
||||
|
||||
c.unregisterProducer()
|
||||
self.assertEqual(c.transport.producer, None)
|
||||
|
||||
|
||||
DIRECT_HINT = u"tcp:direct:1234"
|
||||
RELAY_HINT = u"tcp:relay:1234"
|
||||
|
|
|
@ -34,6 +34,7 @@ class Connection(protocol.Protocol, policies.TimeoutMixin):
|
|||
self.start = start
|
||||
self._negotiation_d = defer.Deferred(self._cancel)
|
||||
self._error = None
|
||||
self._consumer = None
|
||||
self._inbound_records = collections.deque()
|
||||
self._waiting_reads = collections.deque()
|
||||
|
||||
|
@ -173,12 +174,11 @@ class Connection(protocol.Protocol, policies.TimeoutMixin):
|
|||
self.transport.write(encrypted)
|
||||
|
||||
def recordReceived(self, record):
|
||||
if self._consumer:
|
||||
self._consumer.write(record)
|
||||
return
|
||||
self._inbound_records.append(record)
|
||||
self._deliverRecords()
|
||||
self._checkPause()
|
||||
|
||||
def _checkPause(self):
|
||||
return # TODO
|
||||
|
||||
def receive_record(self):
|
||||
d = defer.Deferred()
|
||||
|
@ -191,7 +191,6 @@ class Connection(protocol.Protocol, policies.TimeoutMixin):
|
|||
r = self._inbound_records.popleft()
|
||||
d = self._waiting_reads.popleft()
|
||||
d.callback(r)
|
||||
self._checkPause()
|
||||
|
||||
def close(self):
|
||||
self.transport.loseConnection()
|
||||
|
@ -222,7 +221,41 @@ class Connection(protocol.Protocol, policies.TimeoutMixin):
|
|||
|
||||
d.errback(self._error or BadHandshake("connection lost"))
|
||||
|
||||
# IConsumer methods, for outbound flow-control. We pass these through to
|
||||
# the transport. The 'producer' is something like a t.p.basic.FileSender
|
||||
def registerProducer(self, producer, streaming):
|
||||
assert interfaces.IConsumer.providedBy(self.transport)
|
||||
self.transport.registerProducer(producer, streaming)
|
||||
def unregisterProducer(self):
|
||||
self.transport.unregisterProducer()
|
||||
def write(self, data):
|
||||
self.send_record(data)
|
||||
|
||||
# IProducer methods, for inbound flow-control. We pass these through to
|
||||
# the transport.
|
||||
def stopProducing(self):
|
||||
self.transport.stopProducing()
|
||||
def pauseProducing(self):
|
||||
self.transport.pauseProducing()
|
||||
def resumeProducing(self):
|
||||
self.transport.resumeProducing()
|
||||
|
||||
# Helper method to glue an instance of e.g. t.p.ftp.FileConsumer to us.
|
||||
# Inbound records will be written as bytes to the consumer.
|
||||
def connectConsumer(self, consumer):
|
||||
if self._consumer:
|
||||
raise RuntimeError("A consumer is already attached: %r" %
|
||||
self._consumer)
|
||||
self._consumer = consumer
|
||||
# drain any pending records
|
||||
while self._inbound_records:
|
||||
r = self._inbound_records.popleft()
|
||||
consumer.write(r)
|
||||
consumer.registerProducer(self, True)
|
||||
|
||||
def disconnectConsumer(self):
|
||||
self._consumer.unregisterProducer()
|
||||
self._consumer = None
|
||||
|
||||
class OutboundConnectionFactory(protocol.ClientFactory):
|
||||
protocol = Connection
|
||||
|
|
Loading…
Reference in New Issue
Block a user