twisted.transit: move FileConsumer into RecordPipe
This adds an expected= argument to Connection.connectConsumer(), which then returns a Deferred that fires when enough bytes have been written to the consumer. It also adds Connection.writeToFile(), a helper method that writes bytes to a filehandle.
This commit is contained in:
parent
7234e25897
commit
df2384bea2
|
@ -1,10 +1,8 @@
|
|||
from __future__ import print_function
|
||||
import io, json
|
||||
from zope.interface import implementer
|
||||
from twisted.internet import interfaces, defer
|
||||
from twisted.internet.defer import inlineCallbacks, returnValue
|
||||
from ..twisted.transcribe import Wormhole, WrongPasswordError
|
||||
from ..twisted.transit import TransitReceiver #, TransitError
|
||||
from ..twisted.transit import TransitReceiver
|
||||
from .cmd_receive_blocking import BlockingReceiver, RespondError, APPID
|
||||
from ..errors import TransferError
|
||||
from .progress import ProgressPrinter
|
||||
|
@ -105,11 +103,14 @@ class TwistedReceiver(BlockingReceiver):
|
|||
progress_stdout = self.args.stdout
|
||||
if self.args.hide_progress:
|
||||
progress_stdout = io.StringIO()
|
||||
pfc = ProgressingFileConsumer(f, self.xfersize, progress_stdout)
|
||||
record_pipe.connectConsumer(pfc)
|
||||
received = yield pfc.when_done
|
||||
record_pipe.disconnectConsumer()
|
||||
progress = ProgressPrinter(self.xfersize, progress_stdout)
|
||||
|
||||
progress.start()
|
||||
received = yield record_pipe.writeToFile(f, self.xfersize,
|
||||
progress.update)
|
||||
progress.finish()
|
||||
self.args.timing.finish_event(_start)
|
||||
|
||||
# except TransitError
|
||||
if received < self.xfersize:
|
||||
self.msg()
|
||||
|
@ -124,36 +125,3 @@ class TwistedReceiver(BlockingReceiver):
|
|||
yield record_pipe.send_record(b"ok\n")
|
||||
yield record_pipe.close()
|
||||
self.args.timing.finish_event(_start)
|
||||
|
||||
# based on twisted.protocols.ftp.FileConsumer, but:
|
||||
# - finish after 'xfersize' bytes received, instead of connectionLost()
|
||||
# - don't close the filehandle when done
|
||||
|
||||
@implementer(interfaces.IConsumer)
|
||||
class ProgressingFileConsumer:
|
||||
def __init__(self, f, xfersize, progress_stdout):
|
||||
self._f = f
|
||||
self._xfersize = xfersize
|
||||
self._received = 0
|
||||
self._progress = ProgressPrinter(xfersize, progress_stdout)
|
||||
self._progress.start()
|
||||
self.when_done = defer.Deferred()
|
||||
|
||||
def registerProducer(self, producer, streaming):
|
||||
self.producer = producer
|
||||
assert streaming
|
||||
|
||||
def write(self, bytes):
|
||||
self._f.write(bytes)
|
||||
self._received += len(bytes)
|
||||
self._progress.update(self._received)
|
||||
if self._received >= self._xfersize:
|
||||
self._progress.finish()
|
||||
d,self.when_done = self.when_done,None
|
||||
d.callback(self._received)
|
||||
|
||||
def unregisterProducer(self):
|
||||
self.producer = None
|
||||
if self.when_done:
|
||||
# connection was dropped before all bytes were received
|
||||
self.when_done.callback(self._received)
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
from __future__ import print_function
|
||||
import io
|
||||
from binascii import hexlify, unhexlify
|
||||
from twisted.trial import unittest
|
||||
from twisted.internet import defer, task, endpoints, protocol, address, error
|
||||
|
@ -1008,7 +1009,8 @@ class Connection(unittest.TestCase):
|
|||
c.recordReceived(b"r2.")
|
||||
|
||||
consumer = proto_helpers.StringTransport()
|
||||
c.connectConsumer(consumer)
|
||||
rv = c.connectConsumer(consumer)
|
||||
self.assertIs(rv, None)
|
||||
self.assertIs(c._consumer, consumer)
|
||||
self.assertEqual(consumer.value(), b"r1.r2.")
|
||||
|
||||
|
@ -1029,6 +1031,109 @@ class Connection(unittest.TestCase):
|
|||
c.stopProducing()
|
||||
self.assertEqual(c.transport.producerState, "stopped")
|
||||
|
||||
def test_connectConsumer(self):
|
||||
# connectConsumer() takes an optional number of bytes to expect, and
|
||||
# fires a Deferred when that many have been written
|
||||
c = transit.Connection(None, None, None, "description")
|
||||
c._negotiation_d.addErrback(lambda err: None) # eat it
|
||||
c.transport = proto_helpers.StringTransport()
|
||||
c.recordReceived(b"r1.")
|
||||
|
||||
consumer = proto_helpers.StringTransport()
|
||||
results = []
|
||||
d = c.connectConsumer(consumer, expected=10)
|
||||
d.addBoth(results.append)
|
||||
self.assertEqual(consumer.value(), b"r1.")
|
||||
self.assertEqual(results, [])
|
||||
|
||||
c.recordReceived(b"r2.")
|
||||
self.assertEqual(consumer.value(), b"r1.r2.")
|
||||
self.assertEqual(results, [])
|
||||
|
||||
c.recordReceived(b"r3.")
|
||||
self.assertEqual(consumer.value(), b"r1.r2.r3.")
|
||||
self.assertEqual(results, [])
|
||||
|
||||
c.recordReceived(b"!")
|
||||
self.assertEqual(consumer.value(), b"r1.r2.r3.!")
|
||||
self.assertEqual(results, [10])
|
||||
|
||||
# that should automatically disconnect the consumer, and subsequent
|
||||
# records should get queued, not delivered
|
||||
self.assertIs(c._consumer, None)
|
||||
c.recordReceived(b"overflow")
|
||||
self.assertEqual(consumer.value(), b"r1.r2.r3.!")
|
||||
|
||||
# now test that the Deferred errbacks when the connection is lost
|
||||
results = []
|
||||
d = c.connectConsumer(consumer, expected=10)
|
||||
d.addBoth(results.append)
|
||||
|
||||
c.connectionLost()
|
||||
self.assertEqual(len(results), 1)
|
||||
f = results[0]
|
||||
self.assertIsInstance(f, failure.Failure)
|
||||
self.assertIsInstance(f.value, error.ConnectionClosed)
|
||||
|
||||
def test_writeToFile(self):
|
||||
c = transit.Connection(None, None, None, "description")
|
||||
c._negotiation_d.addErrback(lambda err: None) # eat it
|
||||
c.transport = proto_helpers.StringTransport()
|
||||
c.recordReceived(b"r1.")
|
||||
|
||||
f = io.BytesIO()
|
||||
progress = []
|
||||
results = []
|
||||
d = c.writeToFile(f, 10, progress.append)
|
||||
d.addBoth(results.append)
|
||||
self.assertEqual(f.getvalue(), b"r1.")
|
||||
self.assertEqual(progress, [0, 3])
|
||||
self.assertEqual(results, [])
|
||||
|
||||
c.recordReceived(b"r2.")
|
||||
self.assertEqual(f.getvalue(), b"r1.r2.")
|
||||
self.assertEqual(progress, [0, 3, 6])
|
||||
self.assertEqual(results, [])
|
||||
|
||||
c.recordReceived(b"r3.")
|
||||
self.assertEqual(f.getvalue(), b"r1.r2.r3.")
|
||||
self.assertEqual(progress, [0, 3, 6, 9])
|
||||
self.assertEqual(results, [])
|
||||
|
||||
c.recordReceived(b"!")
|
||||
self.assertEqual(f.getvalue(), b"r1.r2.r3.!")
|
||||
self.assertEqual(progress, [0, 3, 6, 9, 10])
|
||||
self.assertEqual(results, [10])
|
||||
|
||||
# that should automatically disconnect the consumer, and subsequent
|
||||
# records should get queued, not delivered
|
||||
self.assertIs(c._consumer, None)
|
||||
c.recordReceived(b"overflow.")
|
||||
self.assertEqual(f.getvalue(), b"r1.r2.r3.!")
|
||||
self.assertEqual(progress, [0, 3, 6, 9, 10])
|
||||
|
||||
# test what happens when enough data is queued ahead of time
|
||||
c.recordReceived(b"second.") # now "overflow.second."
|
||||
c.recordReceived(b"third.") # now "overflow.second.third."
|
||||
f = io.BytesIO()
|
||||
results = []
|
||||
d = c.writeToFile(f, 10)
|
||||
d.addBoth(results.append)
|
||||
self.assertEqual(f.getvalue(), b"overflow.second.") # whole records
|
||||
self.assertEqual(results, [16])
|
||||
self.assertEqual(list(c._inbound_records), [b"third."])
|
||||
|
||||
# now test that the Deferred errbacks when the connection is lost
|
||||
results = []
|
||||
d = c.writeToFile(f, 10)
|
||||
d.addBoth(results.append)
|
||||
|
||||
c.connectionLost()
|
||||
self.assertEqual(len(results), 1)
|
||||
f = results[0]
|
||||
self.assertIsInstance(f, failure.Failure)
|
||||
self.assertIsInstance(f.value, error.ConnectionClosed)
|
||||
|
||||
def test_consumer(self):
|
||||
# a local producer sends data to a consuming Transit object
|
||||
c = transit.Connection(None, None, None, "description")
|
||||
|
@ -1046,6 +1151,20 @@ class Connection(unittest.TestCase):
|
|||
c.unregisterProducer()
|
||||
self.assertEqual(c.transport.producer, None)
|
||||
|
||||
class FileConsumer(unittest.TestCase):
|
||||
def test_basic(self):
|
||||
f = io.BytesIO()
|
||||
progress = []
|
||||
fc = transit.FileConsumer(f, 100, progress.append)
|
||||
self.assertEqual(progress, [0])
|
||||
self.assertEqual(f.getvalue(), b"")
|
||||
fc.write(b"."* 99)
|
||||
self.assertEqual(progress, [0, 99])
|
||||
self.assertEqual(f.getvalue(), b"."*99)
|
||||
fc.write(b"!")
|
||||
self.assertEqual(progress, [0, 99, 100])
|
||||
self.assertEqual(f.getvalue(), b"."*99+b"!")
|
||||
|
||||
|
||||
DIRECT_HINT = u"tcp:direct:1234"
|
||||
RELAY_HINT = u"tcp:relay:1234"
|
||||
|
|
|
@ -38,6 +38,9 @@ class Connection(protocol.Protocol, policies.TimeoutMixin):
|
|||
self._negotiation_d = defer.Deferred(self._cancel)
|
||||
self._error = None
|
||||
self._consumer = None
|
||||
self._consumer_bytes_written = 0
|
||||
self._consumer_bytes_expected = None
|
||||
self._consumer_deferred = None
|
||||
self._inbound_records = collections.deque()
|
||||
self._waiting_reads = collections.deque()
|
||||
|
||||
|
@ -181,7 +184,7 @@ class Connection(protocol.Protocol, policies.TimeoutMixin):
|
|||
|
||||
def recordReceived(self, record):
|
||||
if self._consumer:
|
||||
self._consumer.write(record)
|
||||
self._writeToConsumer(record)
|
||||
return
|
||||
self._inbound_records.append(record)
|
||||
self._deliverRecords()
|
||||
|
@ -226,6 +229,8 @@ class Connection(protocol.Protocol, policies.TimeoutMixin):
|
|||
# timeout: BadHandshake("timeout")
|
||||
|
||||
d.errback(self._error or BadHandshake("connection lost"))
|
||||
if self._consumer_deferred:
|
||||
self._consumer_deferred.errback(error.ConnectionClosed())
|
||||
|
||||
# IConsumer methods, for outbound flow-control. We pass these through to
|
||||
# the transport. The 'producer' is something like a t.p.basic.FileSender
|
||||
|
@ -246,22 +251,75 @@ class Connection(protocol.Protocol, policies.TimeoutMixin):
|
|||
def resumeProducing(self):
|
||||
self.transport.resumeProducing()
|
||||
|
||||
# Helper method to glue an instance of e.g. t.p.ftp.FileConsumer to us.
|
||||
# Inbound records will be written as bytes to the consumer.
|
||||
def connectConsumer(self, consumer):
|
||||
# Helper methods
|
||||
|
||||
def connectConsumer(self, consumer, expected=None):
|
||||
"""Helper method to glue an instance of e.g. t.p.ftp.FileConsumer to
|
||||
us. Inbound records will be written as bytes to the consumer.
|
||||
|
||||
Set 'expected' to an integer to automatically disconnect when at
|
||||
least that number of bytes have been written. This function will then
|
||||
return a Deferred (that fires with the number of bytes actually
|
||||
received). If the connection is lost while this Deferred is
|
||||
outstanding, it will errback.
|
||||
|
||||
If 'expected' is None, then this function returns None instead of a
|
||||
Deferred, and you must call disconnectConsumer() when you are done."""
|
||||
|
||||
if self._consumer:
|
||||
raise RuntimeError("A consumer is already attached: %r" %
|
||||
self._consumer)
|
||||
self._consumer = consumer
|
||||
# drain any pending records
|
||||
while self._inbound_records:
|
||||
r = self._inbound_records.popleft()
|
||||
consumer.write(r)
|
||||
|
||||
# be aware of an ordering hazard: when we call the consumer's
|
||||
# .registerProducer method, they are likely to immediately call
|
||||
# self.resumeProducing, which we'll deliver to self.transport, which
|
||||
# might call our .dataReceived, which may cause more records to be
|
||||
# available. By waiting to set self._consumer until *after* we drain
|
||||
# any pending records, we avoid delivering records out of order,
|
||||
# which would be bad.
|
||||
consumer.registerProducer(self, True)
|
||||
# There might be enough data queued to exceed 'expected' before we
|
||||
# leave this function. We must be sure to register the producer
|
||||
# before it gets unregistered.
|
||||
|
||||
self._consumer = consumer
|
||||
self._consumer_bytes_written = 0
|
||||
self._consumer_bytes_expected = expected
|
||||
d = None
|
||||
if expected is not None:
|
||||
d = defer.Deferred()
|
||||
self._consumer_deferred = d
|
||||
# drain any pending records
|
||||
while self._consumer and self._inbound_records:
|
||||
r = self._inbound_records.popleft()
|
||||
self._writeToConsumer(r)
|
||||
return d
|
||||
|
||||
def _writeToConsumer(self, record):
|
||||
self._consumer.write(record)
|
||||
self._consumer_bytes_written += len(record)
|
||||
if self._consumer_bytes_expected is not None:
|
||||
if self._consumer_bytes_written >= self._consumer_bytes_expected:
|
||||
d = self._consumer_deferred
|
||||
self.disconnectConsumer()
|
||||
d.callback(self._consumer_bytes_written)
|
||||
|
||||
def disconnectConsumer(self):
|
||||
self._consumer.unregisterProducer()
|
||||
self._consumer = None
|
||||
self._consumer_bytes_expected = None
|
||||
self._consumer_deferred = None
|
||||
|
||||
# Helper method to write a known number of bytes to a file. This has no
|
||||
# flow control: the filehandle cannot push back. 'progress' is an
|
||||
# optional callable which will be called frequently with the number of
|
||||
# bytes transferred so far. Returns a Deferred that fires (with the
|
||||
# number of bytes written) when the count is reached or the RecordPipe is
|
||||
# closed.
|
||||
def writeToFile(self, f, expected, progress=None):
|
||||
progress = progress or (lambda n: None)
|
||||
fc = FileConsumer(f, expected, progress)
|
||||
return self.connectConsumer(fc, expected)
|
||||
|
||||
class OutboundConnectionFactory(protocol.ClientFactory):
|
||||
protocol = Connection
|
||||
|
@ -647,6 +705,35 @@ class TransitSender(Common):
|
|||
class TransitReceiver(Common):
|
||||
is_sender = False
|
||||
|
||||
|
||||
# based on twisted.protocols.ftp.FileConsumer, but:
|
||||
# - call a progress-tracking function
|
||||
# - don't close the filehandle when done
|
||||
|
||||
@implementer(interfaces.IConsumer)
|
||||
class FileConsumer:
|
||||
def __init__(self, f, xfersize, progress_f):
|
||||
self._f = f
|
||||
self._xfersize = xfersize
|
||||
self._received = 0
|
||||
self._progress_f = progress_f
|
||||
self._producer = None
|
||||
self._progress_f(0)
|
||||
|
||||
def registerProducer(self, producer, streaming):
|
||||
assert not self._producer
|
||||
self._producer = producer
|
||||
assert streaming
|
||||
|
||||
def write(self, bytes):
|
||||
self._f.write(bytes)
|
||||
self._received += len(bytes)
|
||||
self._progress_f(self._received)
|
||||
|
||||
def unregisterProducer(self):
|
||||
assert self._producer
|
||||
self._producer = None
|
||||
|
||||
# the TransitSender/Receiver.connect() yields a Connection, on which you can
|
||||
# do send_record(), but what should the receive API be? set a callback for
|
||||
# inbound records? get a Deferred for the next record? The producer/consumer
|
||||
|
|
Loading…
Reference in New Issue
Block a user