From a235b507c8760bc2ddcd49e465908924a122e157 Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Mon, 15 Feb 2016 21:23:20 -0800 Subject: [PATCH] twisted.transit: implement producer/consumer flow control --- src/wormhole/test/test_transit_twisted.py | 48 +++++++++++++++++++++++ src/wormhole/twisted/transit.py | 43 +++++++++++++++++--- 2 files changed, 86 insertions(+), 5 deletions(-) diff --git a/src/wormhole/test/test_transit_twisted.py b/src/wormhole/test/test_transit_twisted.py index 310f55f..56878cc 100644 --- a/src/wormhole/test/test_transit_twisted.py +++ b/src/wormhole/test/test_transit_twisted.py @@ -4,6 +4,7 @@ from twisted.trial import unittest from twisted.internet import defer, task, endpoints, protocol, address, error from twisted.internet.defer import gatherResults, inlineCallbacks from twisted.python import log, failure +from twisted.test import proto_helpers from ..twisted import transit from ..errors import UsageError from nacl.secret import SecretBox @@ -991,6 +992,53 @@ class Connection(unittest.TestCase): self.assertIsInstance(f, failure.Failure) self.assertIsInstance(f.value, error.ConnectionClosed) + def test_producer(self): + # a Transit object (receiving data from the remote peer) produces + # data and writes it into a local Consumer + c = transit.Connection(None, None, None) + c.transport = proto_helpers.StringTransport() + c.recordReceived(b"r1.") + c.recordReceived(b"r2.") + + consumer = proto_helpers.StringTransport() + c.connectConsumer(consumer) + self.assertIs(c._consumer, consumer) + self.assertEqual(consumer.value(), b"r1.r2.") + + self.assertRaises(RuntimeError, c.connectConsumer, consumer) + + c.recordReceived(b"r3.") + self.assertEqual(consumer.value(), b"r1.r2.r3.") + + c.pauseProducing() + self.assertEqual(c.transport.producerState, "paused") + c.resumeProducing() + self.assertEqual(c.transport.producerState, "producing") + + c.disconnectConsumer() + self.assertEqual(consumer.producer, None) + c.connectConsumer(consumer) + + c.stopProducing() + self.assertEqual(c.transport.producerState, "stopped") + + def test_consumer(self): + # a local producer sends data to a consuming Transit object + c = transit.Connection(None, None, None) + c.transport = proto_helpers.StringTransport() + records = [] + c.send_record = records.append + + producer = proto_helpers.StringTransport() + c.registerProducer(producer, True) + self.assertIs(c.transport.producer, producer) + + c.write(b"r1.") + self.assertEqual(records, [b"r1."]) + + c.unregisterProducer() + self.assertEqual(c.transport.producer, None) + DIRECT_HINT = u"tcp:direct:1234" RELAY_HINT = u"tcp:relay:1234" diff --git a/src/wormhole/twisted/transit.py b/src/wormhole/twisted/transit.py index 6544a71..6830f8e 100644 --- a/src/wormhole/twisted/transit.py +++ b/src/wormhole/twisted/transit.py @@ -34,6 +34,7 @@ class Connection(protocol.Protocol, policies.TimeoutMixin): self.start = start self._negotiation_d = defer.Deferred(self._cancel) self._error = None + self._consumer = None self._inbound_records = collections.deque() self._waiting_reads = collections.deque() @@ -173,12 +174,11 @@ class Connection(protocol.Protocol, policies.TimeoutMixin): self.transport.write(encrypted) def recordReceived(self, record): + if self._consumer: + self._consumer.write(record) + return self._inbound_records.append(record) self._deliverRecords() - self._checkPause() - - def _checkPause(self): - return # TODO def receive_record(self): d = defer.Deferred() @@ -191,7 +191,6 @@ class Connection(protocol.Protocol, policies.TimeoutMixin): r = self._inbound_records.popleft() d = self._waiting_reads.popleft() d.callback(r) - self._checkPause() def close(self): self.transport.loseConnection() @@ -222,7 +221,41 @@ class Connection(protocol.Protocol, policies.TimeoutMixin): d.errback(self._error or BadHandshake("connection lost")) + # IConsumer methods, for outbound flow-control. We pass these through to + # the transport. The 'producer' is something like a t.p.basic.FileSender + def registerProducer(self, producer, streaming): + assert interfaces.IConsumer.providedBy(self.transport) + self.transport.registerProducer(producer, streaming) + def unregisterProducer(self): + self.transport.unregisterProducer() + def write(self, data): + self.send_record(data) + # IProducer methods, for inbound flow-control. We pass these through to + # the transport. + def stopProducing(self): + self.transport.stopProducing() + def pauseProducing(self): + self.transport.pauseProducing() + def resumeProducing(self): + self.transport.resumeProducing() + + # Helper method to glue an instance of e.g. t.p.ftp.FileConsumer to us. + # Inbound records will be written as bytes to the consumer. + def connectConsumer(self, consumer): + if self._consumer: + raise RuntimeError("A consumer is already attached: %r" % + self._consumer) + self._consumer = consumer + # drain any pending records + while self._inbound_records: + r = self._inbound_records.popleft() + consumer.write(r) + consumer.registerProducer(self, True) + + def disconnectConsumer(self): + self._consumer.unregisterProducer() + self._consumer = None class OutboundConnectionFactory(protocol.ClientFactory): protocol = Connection