twisted.transit: move FileConsumer into RecordPipe

This adds an expected= argument to Connection.connectConsumer(). When
expected= is given, connectConsumer() returns a Deferred that fires
(with the byte count) once that many bytes have been written to the
consumer, and errbacks if the connection is lost first. It also adds
Connection.writeToFile(), a helper method that writes a known number of
inbound bytes to a filehandle, reporting progress through an optional
callback.
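
For orientation, a rough caller-side sketch of the two new entry points
(illustrative only, not part of the diff below; record_pipe is the
Connection yielded by TransitSender/Receiver.connect(), and f/xfersize
come from the caller):

from twisted.internet.defer import inlineCallbacks, returnValue

@inlineCallbacks
def receive_file(record_pipe, f, xfersize, progress_cb=None):
    # writeToFile() drives the whole transfer: it writes exactly xfersize
    # inbound bytes to the open filehandle f, optionally calling
    # progress_cb with the running byte count, and fires with the number
    # of bytes actually received.
    received = yield record_pipe.writeToFile(f, xfersize, progress_cb)
    returnValue(received)

@inlineCallbacks
def receive_into_consumer(record_pipe, consumer, xfersize):
    # Lower-level form: attach any streaming IConsumer and pass expected=.
    # The Deferred fires once that many bytes have been written, and the
    # consumer is disconnected automatically.
    received = yield record_pipe.connectConsumer(consumer, expected=xfersize)
    returnValue(received)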
Author: Brian Warner
Date:   2016-03-02 00:41:33 -08:00
Parent: 7234e25897
Commit: df2384bea2
3 changed files with 224 additions and 50 deletions


@@ -1,10 +1,8 @@
from __future__ import print_function
import io, json
from zope.interface import implementer
from twisted.internet import interfaces, defer
from twisted.internet.defer import inlineCallbacks, returnValue
from ..twisted.transcribe import Wormhole, WrongPasswordError
from ..twisted.transit import TransitReceiver #, TransitError
from ..twisted.transit import TransitReceiver
from .cmd_receive_blocking import BlockingReceiver, RespondError, APPID
from ..errors import TransferError
from .progress import ProgressPrinter
@@ -105,11 +103,14 @@ class TwistedReceiver(BlockingReceiver):
progress_stdout = self.args.stdout
if self.args.hide_progress:
progress_stdout = io.StringIO()
pfc = ProgressingFileConsumer(f, self.xfersize, progress_stdout)
record_pipe.connectConsumer(pfc)
received = yield pfc.when_done
record_pipe.disconnectConsumer()
progress = ProgressPrinter(self.xfersize, progress_stdout)
progress.start()
received = yield record_pipe.writeToFile(f, self.xfersize,
progress.update)
progress.finish()
self.args.timing.finish_event(_start)
# except TransitError
if received < self.xfersize:
self.msg()
@@ -124,36 +125,3 @@ class TwistedReceiver(BlockingReceiver):
yield record_pipe.send_record(b"ok\n")
yield record_pipe.close()
self.args.timing.finish_event(_start)
# based on twisted.protocols.ftp.FileConsumer, but:
# - finish after 'xfersize' bytes received, instead of connectionLost()
# - don't close the filehandle when done
@implementer(interfaces.IConsumer)
class ProgressingFileConsumer:
def __init__(self, f, xfersize, progress_stdout):
self._f = f
self._xfersize = xfersize
self._received = 0
self._progress = ProgressPrinter(xfersize, progress_stdout)
self._progress.start()
self.when_done = defer.Deferred()
def registerProducer(self, producer, streaming):
self.producer = producer
assert streaming
def write(self, bytes):
self._f.write(bytes)
self._received += len(bytes)
self._progress.update(self._received)
if self._received >= self._xfersize:
self._progress.finish()
d,self.when_done = self.when_done,None
d.callback(self._received)
def unregisterProducer(self):
self.producer = None
if self.when_done:
# connection was dropped before all bytes were received
self.when_done.callback(self._received)


@@ -1,4 +1,5 @@
from __future__ import print_function
import io
from binascii import hexlify, unhexlify
from twisted.trial import unittest
from twisted.internet import defer, task, endpoints, protocol, address, error
@@ -1008,7 +1009,8 @@ class Connection(unittest.TestCase):
c.recordReceived(b"r2.")
consumer = proto_helpers.StringTransport()
c.connectConsumer(consumer)
rv = c.connectConsumer(consumer)
self.assertIs(rv, None)
self.assertIs(c._consumer, consumer)
self.assertEqual(consumer.value(), b"r1.r2.")
@@ -1029,6 +1031,109 @@ class Connection(unittest.TestCase):
c.stopProducing()
self.assertEqual(c.transport.producerState, "stopped")
def test_connectConsumer(self):
# connectConsumer() takes an optional number of bytes to expect, and
# fires a Deferred when that many have been written
c = transit.Connection(None, None, None, "description")
c._negotiation_d.addErrback(lambda err: None) # eat it
c.transport = proto_helpers.StringTransport()
c.recordReceived(b"r1.")
consumer = proto_helpers.StringTransport()
results = []
d = c.connectConsumer(consumer, expected=10)
d.addBoth(results.append)
self.assertEqual(consumer.value(), b"r1.")
self.assertEqual(results, [])
c.recordReceived(b"r2.")
self.assertEqual(consumer.value(), b"r1.r2.")
self.assertEqual(results, [])
c.recordReceived(b"r3.")
self.assertEqual(consumer.value(), b"r1.r2.r3.")
self.assertEqual(results, [])
c.recordReceived(b"!")
self.assertEqual(consumer.value(), b"r1.r2.r3.!")
self.assertEqual(results, [10])
# that should automatically disconnect the consumer, and subsequent
# records should get queued, not delivered
self.assertIs(c._consumer, None)
c.recordReceived(b"overflow")
self.assertEqual(consumer.value(), b"r1.r2.r3.!")
# now test that the Deferred errbacks when the connection is lost
results = []
d = c.connectConsumer(consumer, expected=10)
d.addBoth(results.append)
c.connectionLost()
self.assertEqual(len(results), 1)
f = results[0]
self.assertIsInstance(f, failure.Failure)
self.assertIsInstance(f.value, error.ConnectionClosed)
def test_writeToFile(self):
c = transit.Connection(None, None, None, "description")
c._negotiation_d.addErrback(lambda err: None) # eat it
c.transport = proto_helpers.StringTransport()
c.recordReceived(b"r1.")
f = io.BytesIO()
progress = []
results = []
d = c.writeToFile(f, 10, progress.append)
d.addBoth(results.append)
self.assertEqual(f.getvalue(), b"r1.")
self.assertEqual(progress, [0, 3])
self.assertEqual(results, [])
c.recordReceived(b"r2.")
self.assertEqual(f.getvalue(), b"r1.r2.")
self.assertEqual(progress, [0, 3, 6])
self.assertEqual(results, [])
c.recordReceived(b"r3.")
self.assertEqual(f.getvalue(), b"r1.r2.r3.")
self.assertEqual(progress, [0, 3, 6, 9])
self.assertEqual(results, [])
c.recordReceived(b"!")
self.assertEqual(f.getvalue(), b"r1.r2.r3.!")
self.assertEqual(progress, [0, 3, 6, 9, 10])
self.assertEqual(results, [10])
# that should automatically disconnect the consumer, and subsequent
# records should get queued, not delivered
self.assertIs(c._consumer, None)
c.recordReceived(b"overflow.")
self.assertEqual(f.getvalue(), b"r1.r2.r3.!")
self.assertEqual(progress, [0, 3, 6, 9, 10])
# test what happens when enough data is queued ahead of time
c.recordReceived(b"second.") # now "overflow.second."
c.recordReceived(b"third.") # now "overflow.second.third."
f = io.BytesIO()
results = []
d = c.writeToFile(f, 10)
d.addBoth(results.append)
self.assertEqual(f.getvalue(), b"overflow.second.") # whole records
self.assertEqual(results, [16])
self.assertEqual(list(c._inbound_records), [b"third."])
# now test that the Deferred errbacks when the connection is lost
results = []
d = c.writeToFile(f, 10)
d.addBoth(results.append)
c.connectionLost()
self.assertEqual(len(results), 1)
f = results[0]
self.assertIsInstance(f, failure.Failure)
self.assertIsInstance(f.value, error.ConnectionClosed)
def test_consumer(self):
# a local producer sends data to a consuming Transit object
c = transit.Connection(None, None, None, "description")
@@ -1046,6 +1151,20 @@ class Connection(unittest.TestCase):
c.unregisterProducer()
self.assertEqual(c.transport.producer, None)
class FileConsumer(unittest.TestCase):
def test_basic(self):
f = io.BytesIO()
progress = []
fc = transit.FileConsumer(f, 100, progress.append)
self.assertEqual(progress, [0])
self.assertEqual(f.getvalue(), b"")
fc.write(b"."* 99)
self.assertEqual(progress, [0, 99])
self.assertEqual(f.getvalue(), b"."*99)
fc.write(b"!")
self.assertEqual(progress, [0, 99, 100])
self.assertEqual(f.getvalue(), b"."*99+b"!")
DIRECT_HINT = u"tcp:direct:1234"
RELAY_HINT = u"tcp:relay:1234"


@@ -38,6 +38,9 @@ class Connection(protocol.Protocol, policies.TimeoutMixin):
self._negotiation_d = defer.Deferred(self._cancel)
self._error = None
self._consumer = None
self._consumer_bytes_written = 0
self._consumer_bytes_expected = None
self._consumer_deferred = None
self._inbound_records = collections.deque()
self._waiting_reads = collections.deque()
@@ -181,7 +184,7 @@ class Connection(protocol.Protocol, policies.TimeoutMixin):
def recordReceived(self, record):
if self._consumer:
self._consumer.write(record)
self._writeToConsumer(record)
return
self._inbound_records.append(record)
self._deliverRecords()
@@ -226,6 +229,8 @@ class Connection(protocol.Protocol, policies.TimeoutMixin):
# timeout: BadHandshake("timeout")
d.errback(self._error or BadHandshake("connection lost"))
if self._consumer_deferred:
self._consumer_deferred.errback(error.ConnectionClosed())
# IConsumer methods, for outbound flow-control. We pass these through to
# the transport. The 'producer' is something like a t.p.basic.FileSender
@@ -246,22 +251,75 @@ class Connection(protocol.Protocol, policies.TimeoutMixin):
def resumeProducing(self):
self.transport.resumeProducing()
# Helper method to glue an instance of e.g. t.p.ftp.FileConsumer to us.
# Inbound records will be written as bytes to the consumer.
def connectConsumer(self, consumer):
# Helper methods
def connectConsumer(self, consumer, expected=None):
"""Helper method to glue an instance of e.g. t.p.ftp.FileConsumer to
us. Inbound records will be written as bytes to the consumer.
Set 'expected' to an integer to automatically disconnect when at
least that number of bytes have been written. This function will then
return a Deferred (that fires with the number of bytes actually
received). If the connection is lost while this Deferred is
outstanding, it will errback.
If 'expected' is None, then this function returns None instead of a
Deferred, and you must call disconnectConsumer() when you are done."""
if self._consumer:
raise RuntimeError("A consumer is already attached: %r" %
self._consumer)
self._consumer = consumer
# drain any pending records
while self._inbound_records:
r = self._inbound_records.popleft()
consumer.write(r)
# be aware of an ordering hazard: when we call the consumer's
# .registerProducer method, they are likely to immediately call
# self.resumeProducing, which we'll deliver to self.transport, which
# might call our .dataReceived, which may cause more records to be
# available. By waiting to set self._consumer until *after* we drain
# any pending records, we avoid delivering records out of order,
# which would be bad.
consumer.registerProducer(self, True)
# There might be enough data queued to exceed 'expected' before we
# leave this function. We must be sure to register the producer
# before it gets unregistered.
self._consumer = consumer
self._consumer_bytes_written = 0
self._consumer_bytes_expected = expected
d = None
if expected is not None:
d = defer.Deferred()
self._consumer_deferred = d
# drain any pending records
while self._consumer and self._inbound_records:
r = self._inbound_records.popleft()
self._writeToConsumer(r)
return d
def _writeToConsumer(self, record):
self._consumer.write(record)
self._consumer_bytes_written += len(record)
if self._consumer_bytes_expected is not None:
if self._consumer_bytes_written >= self._consumer_bytes_expected:
d = self._consumer_deferred
self.disconnectConsumer()
d.callback(self._consumer_bytes_written)
def disconnectConsumer(self):
self._consumer.unregisterProducer()
self._consumer = None
self._consumer_bytes_expected = None
self._consumer_deferred = None
# Helper method to write a known number of bytes to a file. This has no
# flow control: the filehandle cannot push back. 'progress' is an
# optional callable which will be called frequently with the number of
# bytes transferred so far. Returns a Deferred that fires (with the
# number of bytes written) when the count is reached or the RecordPipe is
# closed.
def writeToFile(self, f, expected, progress=None):
progress = progress or (lambda n: None)
fc = FileConsumer(f, expected, progress)
return self.connectConsumer(fc, expected)
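
For the open-ended case described in connectConsumer's docstring above
(expected=None), the caller still drives the disconnect; a minimal
sketch, assuming record_pipe is a connected Connection and consumer is
any streaming IConsumer (names are placeholders, not taken from this
diff):

rv = record_pipe.connectConsumer(consumer)  # no expected=, so returns None
assert rv is None
# ... inbound records are now written straight to consumer.write() ...
record_pipe.disconnectConsumer()  # the caller must detach explicitly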
class OutboundConnectionFactory(protocol.ClientFactory):
protocol = Connection
@@ -647,6 +705,35 @@ class TransitSender(Common):
class TransitReceiver(Common):
is_sender = False
# based on twisted.protocols.ftp.FileConsumer, but:
# - call a progress-tracking function
# - don't close the filehandle when done
@implementer(interfaces.IConsumer)
class FileConsumer:
def __init__(self, f, xfersize, progress_f):
self._f = f
self._xfersize = xfersize
self._received = 0
self._progress_f = progress_f
self._producer = None
self._progress_f(0)
def registerProducer(self, producer, streaming):
assert not self._producer
self._producer = producer
assert streaming
def write(self, bytes):
self._f.write(bytes)
self._received += len(bytes)
self._progress_f(self._received)
def unregisterProducer(self):
assert self._producer
self._producer = None
# the TransitSender/Receiver.connect() yields a Connection, on which you can
# do send_record(), but what should the receive API be? set a callback for
# inbound records? get a Deferred for the next record? The producer/consumer