twisted.transit: move FileConsumer into RecordPipe

This adds an expected= argument to Connection.connectConsumer(), which
then returns a Deferred that fires when enough bytes have been written
to the consumer. It also adds Connection.writeToFile(), a helper method
that writes bytes to a filehandle.
This commit is contained in:
Brian Warner 2016-03-02 00:41:33 -08:00
parent 7234e25897
commit df2384bea2
3 changed files with 224 additions and 50 deletions

View File

@ -1,10 +1,8 @@
from __future__ import print_function from __future__ import print_function
import io, json import io, json
from zope.interface import implementer
from twisted.internet import interfaces, defer
from twisted.internet.defer import inlineCallbacks, returnValue from twisted.internet.defer import inlineCallbacks, returnValue
from ..twisted.transcribe import Wormhole, WrongPasswordError from ..twisted.transcribe import Wormhole, WrongPasswordError
from ..twisted.transit import TransitReceiver #, TransitError from ..twisted.transit import TransitReceiver
from .cmd_receive_blocking import BlockingReceiver, RespondError, APPID from .cmd_receive_blocking import BlockingReceiver, RespondError, APPID
from ..errors import TransferError from ..errors import TransferError
from .progress import ProgressPrinter from .progress import ProgressPrinter
@ -105,11 +103,14 @@ class TwistedReceiver(BlockingReceiver):
progress_stdout = self.args.stdout progress_stdout = self.args.stdout
if self.args.hide_progress: if self.args.hide_progress:
progress_stdout = io.StringIO() progress_stdout = io.StringIO()
pfc = ProgressingFileConsumer(f, self.xfersize, progress_stdout) progress = ProgressPrinter(self.xfersize, progress_stdout)
record_pipe.connectConsumer(pfc)
received = yield pfc.when_done progress.start()
record_pipe.disconnectConsumer() received = yield record_pipe.writeToFile(f, self.xfersize,
progress.update)
progress.finish()
self.args.timing.finish_event(_start) self.args.timing.finish_event(_start)
# except TransitError # except TransitError
if received < self.xfersize: if received < self.xfersize:
self.msg() self.msg()
@ -124,36 +125,3 @@ class TwistedReceiver(BlockingReceiver):
yield record_pipe.send_record(b"ok\n") yield record_pipe.send_record(b"ok\n")
yield record_pipe.close() yield record_pipe.close()
self.args.timing.finish_event(_start) self.args.timing.finish_event(_start)
# based on twisted.protocols.ftp.FileConsumer, but:
# - finish after 'xfersize' bytes received, instead of connectionLost()
# - don't close the filehandle when done
@implementer(interfaces.IConsumer)
class ProgressingFileConsumer:
    """Write inbound bytes to a filehandle while driving a ProgressPrinter.

    'when_done' is a Deferred that fires exactly once with the number of
    bytes received: either when at least 'xfersize' bytes have been written,
    or when the producer is unregistered before that point. After it has
    fired, the attribute is reset to None.
    """

    def __init__(self, f, xfersize, progress_stdout):
        self._f = f
        self._xfersize = xfersize
        self._received = 0
        self._progress = ProgressPrinter(xfersize, progress_stdout)
        self._progress.start()
        self.when_done = defer.Deferred()

    def registerProducer(self, producer, streaming):
        self.producer = producer
        assert streaming

    def write(self, bytes):
        self._f.write(bytes)
        self._received += len(bytes)
        self._progress.update(self._received)
        # Only complete once: a write that arrives after we have already
        # fired must not touch the (now None) Deferred or finish() again.
        if self.when_done is not None and self._received >= self._xfersize:
            self._progress.finish()
            d, self.when_done = self.when_done, None
            d.callback(self._received)

    def unregisterProducer(self):
        self.producer = None
        if self.when_done is not None:
            # connection was dropped before all bytes were received. Clear
            # the attribute so that a late write() cannot fire it again.
            d, self.when_done = self.when_done, None
            d.callback(self._received)

View File

@ -1,4 +1,5 @@
from __future__ import print_function from __future__ import print_function
import io
from binascii import hexlify, unhexlify from binascii import hexlify, unhexlify
from twisted.trial import unittest from twisted.trial import unittest
from twisted.internet import defer, task, endpoints, protocol, address, error from twisted.internet import defer, task, endpoints, protocol, address, error
@ -1008,7 +1009,8 @@ class Connection(unittest.TestCase):
c.recordReceived(b"r2.") c.recordReceived(b"r2.")
consumer = proto_helpers.StringTransport() consumer = proto_helpers.StringTransport()
c.connectConsumer(consumer) rv = c.connectConsumer(consumer)
self.assertIs(rv, None)
self.assertIs(c._consumer, consumer) self.assertIs(c._consumer, consumer)
self.assertEqual(consumer.value(), b"r1.r2.") self.assertEqual(consumer.value(), b"r1.r2.")
@ -1029,6 +1031,109 @@ class Connection(unittest.TestCase):
c.stopProducing() c.stopProducing()
self.assertEqual(c.transport.producerState, "stopped") self.assertEqual(c.transport.producerState, "stopped")
def test_connectConsumer(self):
    # connectConsumer() takes an optional number of bytes to expect, and
    # fires a Deferred when that many have been written
    c = transit.Connection(None, None, None, "description")
    c._negotiation_d.addErrback(lambda err: None) # eat it
    c.transport = proto_helpers.StringTransport()
    # queue one record before any consumer is attached
    c.recordReceived(b"r1.")
    consumer = proto_helpers.StringTransport()
    results = []
    d = c.connectConsumer(consumer, expected=10)
    d.addBoth(results.append)
    # the queued record is delivered on attach: 3 of the 10 expected bytes
    self.assertEqual(consumer.value(), b"r1.")
    self.assertEqual(results, [])
    c.recordReceived(b"r2.")
    self.assertEqual(consumer.value(), b"r1.r2.")
    self.assertEqual(results, [])
    c.recordReceived(b"r3.")
    # 9 bytes: still one short of the target
    self.assertEqual(consumer.value(), b"r1.r2.r3.")
    self.assertEqual(results, [])
    c.recordReceived(b"!")
    # the tenth byte fires the Deferred with the total written
    self.assertEqual(consumer.value(), b"r1.r2.r3.!")
    self.assertEqual(results, [10])
    # that should automatically disconnect the consumer, and subsequent
    # records should get queued, not delivered
    self.assertIs(c._consumer, None)
    c.recordReceived(b"overflow")
    self.assertEqual(consumer.value(), b"r1.r2.r3.!")
    # now test that the Deferred errbacks when the connection is lost
    # (only the 8 queued bytes of b"overflow" are available, so the new
    # Deferred is still outstanding when the connection goes away)
    results = []
    d = c.connectConsumer(consumer, expected=10)
    d.addBoth(results.append)
    c.connectionLost()
    self.assertEqual(len(results), 1)
    f = results[0]
    self.assertIsInstance(f, failure.Failure)
    self.assertIsInstance(f.value, error.ConnectionClosed)
def test_writeToFile(self):
    # writeToFile() copies inbound records into a filehandle, reports the
    # running byte count to a progress callable, and returns a Deferred
    # that fires once the expected number of bytes has been written
    c = transit.Connection(None, None, None, "description")
    c._negotiation_d.addErrback(lambda err: None) # eat it
    c.transport = proto_helpers.StringTransport()
    # queue one record before the file is attached
    c.recordReceived(b"r1.")
    f = io.BytesIO()
    progress = []
    results = []
    d = c.writeToFile(f, 10, progress.append)
    d.addBoth(results.append)
    # progress sees an initial 0, then the 3 bytes of the queued record
    self.assertEqual(f.getvalue(), b"r1.")
    self.assertEqual(progress, [0, 3])
    self.assertEqual(results, [])
    c.recordReceived(b"r2.")
    self.assertEqual(f.getvalue(), b"r1.r2.")
    self.assertEqual(progress, [0, 3, 6])
    self.assertEqual(results, [])
    c.recordReceived(b"r3.")
    self.assertEqual(f.getvalue(), b"r1.r2.r3.")
    self.assertEqual(progress, [0, 3, 6, 9])
    self.assertEqual(results, [])
    c.recordReceived(b"!")
    # the tenth byte completes the transfer
    self.assertEqual(f.getvalue(), b"r1.r2.r3.!")
    self.assertEqual(progress, [0, 3, 6, 9, 10])
    self.assertEqual(results, [10])
    # that should automatically disconnect the consumer, and subsequent
    # records should get queued, not delivered
    self.assertIs(c._consumer, None)
    c.recordReceived(b"overflow.")
    self.assertEqual(f.getvalue(), b"r1.r2.r3.!")
    self.assertEqual(progress, [0, 3, 6, 9, 10])
    # test what happens when enough data is queued ahead of time
    c.recordReceived(b"second.") # now "overflow.second."
    c.recordReceived(b"third.") # now "overflow.second.third."
    f = io.BytesIO()
    results = []
    d = c.writeToFile(f, 10)
    d.addBoth(results.append)
    # records are never split: 9 + 7 = 16 bytes are written to pass the
    # 10-byte target, and the third record stays queued
    self.assertEqual(f.getvalue(), b"overflow.second.") # whole records
    self.assertEqual(results, [16])
    self.assertEqual(list(c._inbound_records), [b"third."])
    # now test that the Deferred errbacks when the connection is lost
    # (only the 6 queued bytes of b"third." are available this time)
    results = []
    d = c.writeToFile(f, 10)
    d.addBoth(results.append)
    c.connectionLost()
    self.assertEqual(len(results), 1)
    f = results[0]
    self.assertIsInstance(f, failure.Failure)
    self.assertIsInstance(f.value, error.ConnectionClosed)
def test_consumer(self): def test_consumer(self):
# a local producer sends data to a consuming Transit object # a local producer sends data to a consuming Transit object
c = transit.Connection(None, None, None, "description") c = transit.Connection(None, None, None, "description")
@ -1046,6 +1151,20 @@ class Connection(unittest.TestCase):
c.unregisterProducer() c.unregisterProducer()
self.assertEqual(c.transport.producer, None) self.assertEqual(c.transport.producer, None)
# Exercises transit.FileConsumer on its own, without a Connection: bytes
# passed to write() must land in the filehandle, and the progress callable
# must see the running total.
class FileConsumer(unittest.TestCase):
    def test_basic(self):
        f = io.BytesIO()
        progress = []
        fc = transit.FileConsumer(f, 100, progress.append)
        # construction alone reports a starting count of 0
        self.assertEqual(progress, [0])
        self.assertEqual(f.getvalue(), b"")
        fc.write(b"."* 99)
        self.assertEqual(progress, [0, 99])
        self.assertEqual(f.getvalue(), b"."*99)
        fc.write(b"!")
        # the count is cumulative across writes
        self.assertEqual(progress, [0, 99, 100])
        self.assertEqual(f.getvalue(), b"."*99+b"!")
DIRECT_HINT = u"tcp:direct:1234" DIRECT_HINT = u"tcp:direct:1234"
RELAY_HINT = u"tcp:relay:1234" RELAY_HINT = u"tcp:relay:1234"

View File

@ -38,6 +38,9 @@ class Connection(protocol.Protocol, policies.TimeoutMixin):
self._negotiation_d = defer.Deferred(self._cancel) self._negotiation_d = defer.Deferred(self._cancel)
self._error = None self._error = None
self._consumer = None self._consumer = None
self._consumer_bytes_written = 0
self._consumer_bytes_expected = None
self._consumer_deferred = None
self._inbound_records = collections.deque() self._inbound_records = collections.deque()
self._waiting_reads = collections.deque() self._waiting_reads = collections.deque()
@ -181,7 +184,7 @@ class Connection(protocol.Protocol, policies.TimeoutMixin):
def recordReceived(self, record): def recordReceived(self, record):
if self._consumer: if self._consumer:
self._consumer.write(record) self._writeToConsumer(record)
return return
self._inbound_records.append(record) self._inbound_records.append(record)
self._deliverRecords() self._deliverRecords()
@ -226,6 +229,8 @@ class Connection(protocol.Protocol, policies.TimeoutMixin):
# timeout: BadHandshake("timeout") # timeout: BadHandshake("timeout")
d.errback(self._error or BadHandshake("connection lost")) d.errback(self._error or BadHandshake("connection lost"))
if self._consumer_deferred:
self._consumer_deferred.errback(error.ConnectionClosed())
# IConsumer methods, for outbound flow-control. We pass these through to # IConsumer methods, for outbound flow-control. We pass these through to
# the transport. The 'producer' is something like a t.p.basic.FileSender # the transport. The 'producer' is something like a t.p.basic.FileSender
@ -246,22 +251,75 @@ class Connection(protocol.Protocol, policies.TimeoutMixin):
def resumeProducing(self): def resumeProducing(self):
self.transport.resumeProducing() self.transport.resumeProducing()
# Helper method to glue an instance of e.g. t.p.ftp.FileConsumer to us. # Helper methods
# Inbound records will be written as bytes to the consumer.
def connectConsumer(self, consumer): def connectConsumer(self, consumer, expected=None):
"""Helper method to glue an instance of e.g. t.p.ftp.FileConsumer to
us. Inbound records will be written as bytes to the consumer.
Set 'expected' to an integer to automatically disconnect when at
least that number of bytes have been written. This function will then
return a Deferred (that fires with the number of bytes actually
received). If the connection is lost while this Deferred is
outstanding, it will errback.
If 'expected' is None, then this function returns None instead of a
Deferred, and you must call disconnectConsumer() when you are done."""
if self._consumer: if self._consumer:
raise RuntimeError("A consumer is already attached: %r" % raise RuntimeError("A consumer is already attached: %r" %
self._consumer) self._consumer)
self._consumer = consumer
# drain any pending records # be aware of an ordering hazard: when we call the consumer's
while self._inbound_records: # .registerProducer method, they are likely to immediately call
r = self._inbound_records.popleft() # self.resumeProducing, which we'll deliver to self.transport, which
consumer.write(r) # might call our .dataReceived, which may cause more records to be
# available. By calling registerProducer *before* we set self._consumer,
# any such records are queued behind the pending ones (and drained in
# order below) instead of being delivered out of order, which would be bad.
consumer.registerProducer(self, True) consumer.registerProducer(self, True)
# There might be enough data queued to exceed 'expected' before we
# leave this function. We must be sure to register the producer
# before it gets unregistered.
self._consumer = consumer
self._consumer_bytes_written = 0
self._consumer_bytes_expected = expected
d = None
if expected is not None:
d = defer.Deferred()
self._consumer_deferred = d
# drain any pending records
while self._consumer and self._inbound_records:
r = self._inbound_records.popleft()
self._writeToConsumer(r)
return d
def _writeToConsumer(self, record):
    """Pass one inbound record to the attached consumer and count its bytes.

    If a byte target is set and the running total has reached it, the
    consumer is detached and the pending Deferred is fired with the total.
    """
    self._consumer.write(record)
    self._consumer_bytes_written += len(record)
    target = self._consumer_bytes_expected
    if target is None or self._consumer_bytes_written < target:
        return
    # take our own reference first: disconnectConsumer() resets the
    # attribute to None
    done_d = self._consumer_deferred
    self.disconnectConsumer()
    done_d.callback(self._consumer_bytes_written)
def disconnectConsumer(self): def disconnectConsumer(self):
self._consumer.unregisterProducer() self._consumer.unregisterProducer()
self._consumer = None self._consumer = None
self._consumer_bytes_expected = None
self._consumer_deferred = None
# Convenience wrapper: stream inbound records into the filehandle 'f' until
# at least 'expected' bytes have been written. There is no flow control
# here, because a filehandle has no way to push back. 'progress', when
# given, is a callable that receives the running byte count (0 when the
# transfer is set up, then the new total after every write). Returns the
# Deferred from connectConsumer(): it fires with the number of bytes
# written once the count is reached, and errbacks if the connection is
# lost first.
def writeToFile(self, f, expected, progress=None):
    report = progress if progress else (lambda n: None)
    file_consumer = FileConsumer(f, expected, report)
    return self.connectConsumer(file_consumer, expected)
class OutboundConnectionFactory(protocol.ClientFactory): class OutboundConnectionFactory(protocol.ClientFactory):
protocol = Connection protocol = Connection
@ -647,6 +705,35 @@ class TransitSender(Common):
class TransitReceiver(Common): class TransitReceiver(Common):
is_sender = False is_sender = False
# Modeled on twisted.protocols.ftp.FileConsumer, with two differences:
# - a progress-tracking function is called with the running byte count
# - the filehandle is left open when the transfer is over
@implementer(interfaces.IConsumer)
class FileConsumer:
    """Consumer that copies every chunk it is given into a filehandle."""

    def __init__(self, f, xfersize, progress_f):
        self._producer = None
        self._received = 0
        self._f = f
        self._xfersize = xfersize  # stored only; nothing in this class reads it
        self._progress_f = progress_f
        # announce the starting point before any data arrives
        progress_f(0)

    def registerProducer(self, producer, streaming):
        """Attach 'producer'; only one at a time, and 'streaming' must be true."""
        assert not self._producer
        self._producer = producer
        assert streaming

    def write(self, bytes):
        """Write 'bytes' to the filehandle, then report the new total."""
        self._f.write(bytes)
        total = self._received + len(bytes)
        self._received = total
        self._progress_f(total)

    def unregisterProducer(self):
        """Detach the current producer; one must be attached."""
        assert self._producer
        self._producer = None
# the TransitSender/Receiver.connect() yields a Connection, on which you can # the TransitSender/Receiver.connect() yields a Connection, on which you can
# do send_record(), but what should the receive API be? set a callback for # do send_record(), but what should the receive API be? set a callback for
# inbound records? get a Deferred for the next record? The producer/consumer # inbound records? get a Deferred for the next record? The producer/consumer