twisted.transit: implement producer/consumer flow control
This commit is contained in:
parent
763d72f582
commit
a235b507c8
|
@ -4,6 +4,7 @@ from twisted.trial import unittest
|
||||||
from twisted.internet import defer, task, endpoints, protocol, address, error
|
from twisted.internet import defer, task, endpoints, protocol, address, error
|
||||||
from twisted.internet.defer import gatherResults, inlineCallbacks
|
from twisted.internet.defer import gatherResults, inlineCallbacks
|
||||||
from twisted.python import log, failure
|
from twisted.python import log, failure
|
||||||
|
from twisted.test import proto_helpers
|
||||||
from ..twisted import transit
|
from ..twisted import transit
|
||||||
from ..errors import UsageError
|
from ..errors import UsageError
|
||||||
from nacl.secret import SecretBox
|
from nacl.secret import SecretBox
|
||||||
|
@ -991,6 +992,53 @@ class Connection(unittest.TestCase):
|
||||||
self.assertIsInstance(f, failure.Failure)
|
self.assertIsInstance(f, failure.Failure)
|
||||||
self.assertIsInstance(f.value, error.ConnectionClosed)
|
self.assertIsInstance(f.value, error.ConnectionClosed)
|
||||||
|
|
||||||
|
def test_producer(self):
|
||||||
|
# a Transit object (receiving data from the remote peer) produces
|
||||||
|
# data and writes it into a local Consumer
|
||||||
|
c = transit.Connection(None, None, None)
|
||||||
|
c.transport = proto_helpers.StringTransport()
|
||||||
|
c.recordReceived(b"r1.")
|
||||||
|
c.recordReceived(b"r2.")
|
||||||
|
|
||||||
|
consumer = proto_helpers.StringTransport()
|
||||||
|
c.connectConsumer(consumer)
|
||||||
|
self.assertIs(c._consumer, consumer)
|
||||||
|
self.assertEqual(consumer.value(), b"r1.r2.")
|
||||||
|
|
||||||
|
self.assertRaises(RuntimeError, c.connectConsumer, consumer)
|
||||||
|
|
||||||
|
c.recordReceived(b"r3.")
|
||||||
|
self.assertEqual(consumer.value(), b"r1.r2.r3.")
|
||||||
|
|
||||||
|
c.pauseProducing()
|
||||||
|
self.assertEqual(c.transport.producerState, "paused")
|
||||||
|
c.resumeProducing()
|
||||||
|
self.assertEqual(c.transport.producerState, "producing")
|
||||||
|
|
||||||
|
c.disconnectConsumer()
|
||||||
|
self.assertEqual(consumer.producer, None)
|
||||||
|
c.connectConsumer(consumer)
|
||||||
|
|
||||||
|
c.stopProducing()
|
||||||
|
self.assertEqual(c.transport.producerState, "stopped")
|
||||||
|
|
||||||
|
def test_consumer(self):
|
||||||
|
# a local producer sends data to a consuming Transit object
|
||||||
|
c = transit.Connection(None, None, None)
|
||||||
|
c.transport = proto_helpers.StringTransport()
|
||||||
|
records = []
|
||||||
|
c.send_record = records.append
|
||||||
|
|
||||||
|
producer = proto_helpers.StringTransport()
|
||||||
|
c.registerProducer(producer, True)
|
||||||
|
self.assertIs(c.transport.producer, producer)
|
||||||
|
|
||||||
|
c.write(b"r1.")
|
||||||
|
self.assertEqual(records, [b"r1."])
|
||||||
|
|
||||||
|
c.unregisterProducer()
|
||||||
|
self.assertEqual(c.transport.producer, None)
|
||||||
|
|
||||||
|
|
||||||
DIRECT_HINT = u"tcp:direct:1234"
|
DIRECT_HINT = u"tcp:direct:1234"
|
||||||
RELAY_HINT = u"tcp:relay:1234"
|
RELAY_HINT = u"tcp:relay:1234"
|
||||||
|
|
|
@ -34,6 +34,7 @@ class Connection(protocol.Protocol, policies.TimeoutMixin):
|
||||||
self.start = start
|
self.start = start
|
||||||
self._negotiation_d = defer.Deferred(self._cancel)
|
self._negotiation_d = defer.Deferred(self._cancel)
|
||||||
self._error = None
|
self._error = None
|
||||||
|
self._consumer = None
|
||||||
self._inbound_records = collections.deque()
|
self._inbound_records = collections.deque()
|
||||||
self._waiting_reads = collections.deque()
|
self._waiting_reads = collections.deque()
|
||||||
|
|
||||||
|
@ -173,12 +174,11 @@ class Connection(protocol.Protocol, policies.TimeoutMixin):
|
||||||
self.transport.write(encrypted)
|
self.transport.write(encrypted)
|
||||||
|
|
||||||
def recordReceived(self, record):
|
def recordReceived(self, record):
|
||||||
|
if self._consumer:
|
||||||
|
self._consumer.write(record)
|
||||||
|
return
|
||||||
self._inbound_records.append(record)
|
self._inbound_records.append(record)
|
||||||
self._deliverRecords()
|
self._deliverRecords()
|
||||||
self._checkPause()
|
|
||||||
|
|
||||||
def _checkPause(self):
|
|
||||||
return # TODO
|
|
||||||
|
|
||||||
def receive_record(self):
|
def receive_record(self):
|
||||||
d = defer.Deferred()
|
d = defer.Deferred()
|
||||||
|
@ -191,7 +191,6 @@ class Connection(protocol.Protocol, policies.TimeoutMixin):
|
||||||
r = self._inbound_records.popleft()
|
r = self._inbound_records.popleft()
|
||||||
d = self._waiting_reads.popleft()
|
d = self._waiting_reads.popleft()
|
||||||
d.callback(r)
|
d.callback(r)
|
||||||
self._checkPause()
|
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
self.transport.loseConnection()
|
self.transport.loseConnection()
|
||||||
|
@ -222,7 +221,41 @@ class Connection(protocol.Protocol, policies.TimeoutMixin):
|
||||||
|
|
||||||
d.errback(self._error or BadHandshake("connection lost"))
|
d.errback(self._error or BadHandshake("connection lost"))
|
||||||
|
|
||||||
|
# IConsumer methods, for outbound flow-control. We pass these through to
|
||||||
|
# the transport. The 'producer' is something like a t.p.basic.FileSender
|
||||||
|
def registerProducer(self, producer, streaming):
|
||||||
|
assert interfaces.IConsumer.providedBy(self.transport)
|
||||||
|
self.transport.registerProducer(producer, streaming)
|
||||||
|
def unregisterProducer(self):
|
||||||
|
self.transport.unregisterProducer()
|
||||||
|
def write(self, data):
|
||||||
|
self.send_record(data)
|
||||||
|
|
||||||
|
# IProducer methods, for inbound flow-control. We pass these through to
|
||||||
|
# the transport.
|
||||||
|
def stopProducing(self):
|
||||||
|
self.transport.stopProducing()
|
||||||
|
def pauseProducing(self):
|
||||||
|
self.transport.pauseProducing()
|
||||||
|
def resumeProducing(self):
|
||||||
|
self.transport.resumeProducing()
|
||||||
|
|
||||||
|
# Helper method to glue an instance of e.g. t.p.ftp.FileConsumer to us.
|
||||||
|
# Inbound records will be written as bytes to the consumer.
|
||||||
|
def connectConsumer(self, consumer):
|
||||||
|
if self._consumer:
|
||||||
|
raise RuntimeError("A consumer is already attached: %r" %
|
||||||
|
self._consumer)
|
||||||
|
self._consumer = consumer
|
||||||
|
# drain any pending records
|
||||||
|
while self._inbound_records:
|
||||||
|
r = self._inbound_records.popleft()
|
||||||
|
consumer.write(r)
|
||||||
|
consumer.registerProducer(self, True)
|
||||||
|
|
||||||
|
def disconnectConsumer(self):
|
||||||
|
self._consumer.unregisterProducer()
|
||||||
|
self._consumer = None
|
||||||
|
|
||||||
class OutboundConnectionFactory(protocol.ClientFactory):
|
class OutboundConnectionFactory(protocol.ClientFactory):
|
||||||
protocol = Connection
|
protocol = Connection
|
||||||
|
|
Loading…
Reference in New Issue
Block a user