twisted.transit: move FileConsumer into RecordPipe
This adds an expected= argument to Connection.connectConsumer(), which then returns a Deferred that fires when enough bytes have been written to the consumer. It also adds Connection.writeToFile(), a helper method that writes bytes to a filehandle.
This commit is contained in:
parent
7234e25897
commit
df2384bea2
|
@ -1,10 +1,8 @@
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
import io, json
|
import io, json
|
||||||
from zope.interface import implementer
|
|
||||||
from twisted.internet import interfaces, defer
|
|
||||||
from twisted.internet.defer import inlineCallbacks, returnValue
|
from twisted.internet.defer import inlineCallbacks, returnValue
|
||||||
from ..twisted.transcribe import Wormhole, WrongPasswordError
|
from ..twisted.transcribe import Wormhole, WrongPasswordError
|
||||||
from ..twisted.transit import TransitReceiver #, TransitError
|
from ..twisted.transit import TransitReceiver
|
||||||
from .cmd_receive_blocking import BlockingReceiver, RespondError, APPID
|
from .cmd_receive_blocking import BlockingReceiver, RespondError, APPID
|
||||||
from ..errors import TransferError
|
from ..errors import TransferError
|
||||||
from .progress import ProgressPrinter
|
from .progress import ProgressPrinter
|
||||||
|
@ -105,11 +103,14 @@ class TwistedReceiver(BlockingReceiver):
|
||||||
progress_stdout = self.args.stdout
|
progress_stdout = self.args.stdout
|
||||||
if self.args.hide_progress:
|
if self.args.hide_progress:
|
||||||
progress_stdout = io.StringIO()
|
progress_stdout = io.StringIO()
|
||||||
pfc = ProgressingFileConsumer(f, self.xfersize, progress_stdout)
|
progress = ProgressPrinter(self.xfersize, progress_stdout)
|
||||||
record_pipe.connectConsumer(pfc)
|
|
||||||
received = yield pfc.when_done
|
progress.start()
|
||||||
record_pipe.disconnectConsumer()
|
received = yield record_pipe.writeToFile(f, self.xfersize,
|
||||||
|
progress.update)
|
||||||
|
progress.finish()
|
||||||
self.args.timing.finish_event(_start)
|
self.args.timing.finish_event(_start)
|
||||||
|
|
||||||
# except TransitError
|
# except TransitError
|
||||||
if received < self.xfersize:
|
if received < self.xfersize:
|
||||||
self.msg()
|
self.msg()
|
||||||
|
@ -124,36 +125,3 @@ class TwistedReceiver(BlockingReceiver):
|
||||||
yield record_pipe.send_record(b"ok\n")
|
yield record_pipe.send_record(b"ok\n")
|
||||||
yield record_pipe.close()
|
yield record_pipe.close()
|
||||||
self.args.timing.finish_event(_start)
|
self.args.timing.finish_event(_start)
|
||||||
|
|
||||||
# based on twisted.protocols.ftp.FileConsumer, but:
|
|
||||||
# - finish after 'xfersize' bytes received, instead of connectionLost()
|
|
||||||
# - don't close the filehandle when done
|
|
||||||
|
|
||||||
@implementer(interfaces.IConsumer)
|
|
||||||
class ProgressingFileConsumer:
|
|
||||||
def __init__(self, f, xfersize, progress_stdout):
|
|
||||||
self._f = f
|
|
||||||
self._xfersize = xfersize
|
|
||||||
self._received = 0
|
|
||||||
self._progress = ProgressPrinter(xfersize, progress_stdout)
|
|
||||||
self._progress.start()
|
|
||||||
self.when_done = defer.Deferred()
|
|
||||||
|
|
||||||
def registerProducer(self, producer, streaming):
|
|
||||||
self.producer = producer
|
|
||||||
assert streaming
|
|
||||||
|
|
||||||
def write(self, bytes):
|
|
||||||
self._f.write(bytes)
|
|
||||||
self._received += len(bytes)
|
|
||||||
self._progress.update(self._received)
|
|
||||||
if self._received >= self._xfersize:
|
|
||||||
self._progress.finish()
|
|
||||||
d,self.when_done = self.when_done,None
|
|
||||||
d.callback(self._received)
|
|
||||||
|
|
||||||
def unregisterProducer(self):
|
|
||||||
self.producer = None
|
|
||||||
if self.when_done:
|
|
||||||
# connection was dropped before all bytes were received
|
|
||||||
self.when_done.callback(self._received)
|
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
import io
|
||||||
from binascii import hexlify, unhexlify
|
from binascii import hexlify, unhexlify
|
||||||
from twisted.trial import unittest
|
from twisted.trial import unittest
|
||||||
from twisted.internet import defer, task, endpoints, protocol, address, error
|
from twisted.internet import defer, task, endpoints, protocol, address, error
|
||||||
|
@ -1008,7 +1009,8 @@ class Connection(unittest.TestCase):
|
||||||
c.recordReceived(b"r2.")
|
c.recordReceived(b"r2.")
|
||||||
|
|
||||||
consumer = proto_helpers.StringTransport()
|
consumer = proto_helpers.StringTransport()
|
||||||
c.connectConsumer(consumer)
|
rv = c.connectConsumer(consumer)
|
||||||
|
self.assertIs(rv, None)
|
||||||
self.assertIs(c._consumer, consumer)
|
self.assertIs(c._consumer, consumer)
|
||||||
self.assertEqual(consumer.value(), b"r1.r2.")
|
self.assertEqual(consumer.value(), b"r1.r2.")
|
||||||
|
|
||||||
|
@ -1029,6 +1031,109 @@ class Connection(unittest.TestCase):
|
||||||
c.stopProducing()
|
c.stopProducing()
|
||||||
self.assertEqual(c.transport.producerState, "stopped")
|
self.assertEqual(c.transport.producerState, "stopped")
|
||||||
|
|
||||||
|
def test_connectConsumer(self):
|
||||||
|
# connectConsumer() takes an optional number of bytes to expect, and
|
||||||
|
# fires a Deferred when that many have been written
|
||||||
|
c = transit.Connection(None, None, None, "description")
|
||||||
|
c._negotiation_d.addErrback(lambda err: None) # eat it
|
||||||
|
c.transport = proto_helpers.StringTransport()
|
||||||
|
c.recordReceived(b"r1.")
|
||||||
|
|
||||||
|
consumer = proto_helpers.StringTransport()
|
||||||
|
results = []
|
||||||
|
d = c.connectConsumer(consumer, expected=10)
|
||||||
|
d.addBoth(results.append)
|
||||||
|
self.assertEqual(consumer.value(), b"r1.")
|
||||||
|
self.assertEqual(results, [])
|
||||||
|
|
||||||
|
c.recordReceived(b"r2.")
|
||||||
|
self.assertEqual(consumer.value(), b"r1.r2.")
|
||||||
|
self.assertEqual(results, [])
|
||||||
|
|
||||||
|
c.recordReceived(b"r3.")
|
||||||
|
self.assertEqual(consumer.value(), b"r1.r2.r3.")
|
||||||
|
self.assertEqual(results, [])
|
||||||
|
|
||||||
|
c.recordReceived(b"!")
|
||||||
|
self.assertEqual(consumer.value(), b"r1.r2.r3.!")
|
||||||
|
self.assertEqual(results, [10])
|
||||||
|
|
||||||
|
# that should automatically disconnect the consumer, and subsequent
|
||||||
|
# records should get queued, not delivered
|
||||||
|
self.assertIs(c._consumer, None)
|
||||||
|
c.recordReceived(b"overflow")
|
||||||
|
self.assertEqual(consumer.value(), b"r1.r2.r3.!")
|
||||||
|
|
||||||
|
# now test that the Deferred errbacks when the connection is lost
|
||||||
|
results = []
|
||||||
|
d = c.connectConsumer(consumer, expected=10)
|
||||||
|
d.addBoth(results.append)
|
||||||
|
|
||||||
|
c.connectionLost()
|
||||||
|
self.assertEqual(len(results), 1)
|
||||||
|
f = results[0]
|
||||||
|
self.assertIsInstance(f, failure.Failure)
|
||||||
|
self.assertIsInstance(f.value, error.ConnectionClosed)
|
||||||
|
|
||||||
|
def test_writeToFile(self):
|
||||||
|
c = transit.Connection(None, None, None, "description")
|
||||||
|
c._negotiation_d.addErrback(lambda err: None) # eat it
|
||||||
|
c.transport = proto_helpers.StringTransport()
|
||||||
|
c.recordReceived(b"r1.")
|
||||||
|
|
||||||
|
f = io.BytesIO()
|
||||||
|
progress = []
|
||||||
|
results = []
|
||||||
|
d = c.writeToFile(f, 10, progress.append)
|
||||||
|
d.addBoth(results.append)
|
||||||
|
self.assertEqual(f.getvalue(), b"r1.")
|
||||||
|
self.assertEqual(progress, [0, 3])
|
||||||
|
self.assertEqual(results, [])
|
||||||
|
|
||||||
|
c.recordReceived(b"r2.")
|
||||||
|
self.assertEqual(f.getvalue(), b"r1.r2.")
|
||||||
|
self.assertEqual(progress, [0, 3, 6])
|
||||||
|
self.assertEqual(results, [])
|
||||||
|
|
||||||
|
c.recordReceived(b"r3.")
|
||||||
|
self.assertEqual(f.getvalue(), b"r1.r2.r3.")
|
||||||
|
self.assertEqual(progress, [0, 3, 6, 9])
|
||||||
|
self.assertEqual(results, [])
|
||||||
|
|
||||||
|
c.recordReceived(b"!")
|
||||||
|
self.assertEqual(f.getvalue(), b"r1.r2.r3.!")
|
||||||
|
self.assertEqual(progress, [0, 3, 6, 9, 10])
|
||||||
|
self.assertEqual(results, [10])
|
||||||
|
|
||||||
|
# that should automatically disconnect the consumer, and subsequent
|
||||||
|
# records should get queued, not delivered
|
||||||
|
self.assertIs(c._consumer, None)
|
||||||
|
c.recordReceived(b"overflow.")
|
||||||
|
self.assertEqual(f.getvalue(), b"r1.r2.r3.!")
|
||||||
|
self.assertEqual(progress, [0, 3, 6, 9, 10])
|
||||||
|
|
||||||
|
# test what happens when enough data is queued ahead of time
|
||||||
|
c.recordReceived(b"second.") # now "overflow.second."
|
||||||
|
c.recordReceived(b"third.") # now "overflow.second.third."
|
||||||
|
f = io.BytesIO()
|
||||||
|
results = []
|
||||||
|
d = c.writeToFile(f, 10)
|
||||||
|
d.addBoth(results.append)
|
||||||
|
self.assertEqual(f.getvalue(), b"overflow.second.") # whole records
|
||||||
|
self.assertEqual(results, [16])
|
||||||
|
self.assertEqual(list(c._inbound_records), [b"third."])
|
||||||
|
|
||||||
|
# now test that the Deferred errbacks when the connection is lost
|
||||||
|
results = []
|
||||||
|
d = c.writeToFile(f, 10)
|
||||||
|
d.addBoth(results.append)
|
||||||
|
|
||||||
|
c.connectionLost()
|
||||||
|
self.assertEqual(len(results), 1)
|
||||||
|
f = results[0]
|
||||||
|
self.assertIsInstance(f, failure.Failure)
|
||||||
|
self.assertIsInstance(f.value, error.ConnectionClosed)
|
||||||
|
|
||||||
def test_consumer(self):
|
def test_consumer(self):
|
||||||
# a local producer sends data to a consuming Transit object
|
# a local producer sends data to a consuming Transit object
|
||||||
c = transit.Connection(None, None, None, "description")
|
c = transit.Connection(None, None, None, "description")
|
||||||
|
@ -1046,6 +1151,20 @@ class Connection(unittest.TestCase):
|
||||||
c.unregisterProducer()
|
c.unregisterProducer()
|
||||||
self.assertEqual(c.transport.producer, None)
|
self.assertEqual(c.transport.producer, None)
|
||||||
|
|
||||||
|
class FileConsumer(unittest.TestCase):
|
||||||
|
def test_basic(self):
|
||||||
|
f = io.BytesIO()
|
||||||
|
progress = []
|
||||||
|
fc = transit.FileConsumer(f, 100, progress.append)
|
||||||
|
self.assertEqual(progress, [0])
|
||||||
|
self.assertEqual(f.getvalue(), b"")
|
||||||
|
fc.write(b"."* 99)
|
||||||
|
self.assertEqual(progress, [0, 99])
|
||||||
|
self.assertEqual(f.getvalue(), b"."*99)
|
||||||
|
fc.write(b"!")
|
||||||
|
self.assertEqual(progress, [0, 99, 100])
|
||||||
|
self.assertEqual(f.getvalue(), b"."*99+b"!")
|
||||||
|
|
||||||
|
|
||||||
DIRECT_HINT = u"tcp:direct:1234"
|
DIRECT_HINT = u"tcp:direct:1234"
|
||||||
RELAY_HINT = u"tcp:relay:1234"
|
RELAY_HINT = u"tcp:relay:1234"
|
||||||
|
|
|
@ -38,6 +38,9 @@ class Connection(protocol.Protocol, policies.TimeoutMixin):
|
||||||
self._negotiation_d = defer.Deferred(self._cancel)
|
self._negotiation_d = defer.Deferred(self._cancel)
|
||||||
self._error = None
|
self._error = None
|
||||||
self._consumer = None
|
self._consumer = None
|
||||||
|
self._consumer_bytes_written = 0
|
||||||
|
self._consumer_bytes_expected = None
|
||||||
|
self._consumer_deferred = None
|
||||||
self._inbound_records = collections.deque()
|
self._inbound_records = collections.deque()
|
||||||
self._waiting_reads = collections.deque()
|
self._waiting_reads = collections.deque()
|
||||||
|
|
||||||
|
@ -181,7 +184,7 @@ class Connection(protocol.Protocol, policies.TimeoutMixin):
|
||||||
|
|
||||||
def recordReceived(self, record):
|
def recordReceived(self, record):
|
||||||
if self._consumer:
|
if self._consumer:
|
||||||
self._consumer.write(record)
|
self._writeToConsumer(record)
|
||||||
return
|
return
|
||||||
self._inbound_records.append(record)
|
self._inbound_records.append(record)
|
||||||
self._deliverRecords()
|
self._deliverRecords()
|
||||||
|
@ -226,6 +229,8 @@ class Connection(protocol.Protocol, policies.TimeoutMixin):
|
||||||
# timeout: BadHandshake("timeout")
|
# timeout: BadHandshake("timeout")
|
||||||
|
|
||||||
d.errback(self._error or BadHandshake("connection lost"))
|
d.errback(self._error or BadHandshake("connection lost"))
|
||||||
|
if self._consumer_deferred:
|
||||||
|
self._consumer_deferred.errback(error.ConnectionClosed())
|
||||||
|
|
||||||
# IConsumer methods, for outbound flow-control. We pass these through to
|
# IConsumer methods, for outbound flow-control. We pass these through to
|
||||||
# the transport. The 'producer' is something like a t.p.basic.FileSender
|
# the transport. The 'producer' is something like a t.p.basic.FileSender
|
||||||
|
@ -246,22 +251,75 @@ class Connection(protocol.Protocol, policies.TimeoutMixin):
|
||||||
def resumeProducing(self):
|
def resumeProducing(self):
|
||||||
self.transport.resumeProducing()
|
self.transport.resumeProducing()
|
||||||
|
|
||||||
# Helper method to glue an instance of e.g. t.p.ftp.FileConsumer to us.
|
# Helper methods
|
||||||
# Inbound records will be written as bytes to the consumer.
|
|
||||||
def connectConsumer(self, consumer):
|
def connectConsumer(self, consumer, expected=None):
|
||||||
|
"""Helper method to glue an instance of e.g. t.p.ftp.FileConsumer to
|
||||||
|
us. Inbound records will be written as bytes to the consumer.
|
||||||
|
|
||||||
|
Set 'expected' to an integer to automatically disconnect when at
|
||||||
|
least that number of bytes have been written. This function will then
|
||||||
|
return a Deferred (that fires with the number of bytes actually
|
||||||
|
received). If the connection is lost while this Deferred is
|
||||||
|
outstanding, it will errback.
|
||||||
|
|
||||||
|
If 'expected' is None, then this function returns None instead of a
|
||||||
|
Deferred, and you must call disconnectConsumer() when you are done."""
|
||||||
|
|
||||||
if self._consumer:
|
if self._consumer:
|
||||||
raise RuntimeError("A consumer is already attached: %r" %
|
raise RuntimeError("A consumer is already attached: %r" %
|
||||||
self._consumer)
|
self._consumer)
|
||||||
self._consumer = consumer
|
|
||||||
# drain any pending records
|
# be aware of an ordering hazard: when we call the consumer's
|
||||||
while self._inbound_records:
|
# .registerProducer method, they are likely to immediately call
|
||||||
r = self._inbound_records.popleft()
|
# self.resumeProducing, which we'll deliver to self.transport, which
|
||||||
consumer.write(r)
|
# might call our .dataReceived, which may cause more records to be
|
||||||
|
# available. By waiting to set self._consumer until *after* we drain
|
||||||
|
# any pending records, we avoid delivering records out of order,
|
||||||
|
# which would be bad.
|
||||||
consumer.registerProducer(self, True)
|
consumer.registerProducer(self, True)
|
||||||
|
# There might be enough data queued to exceed 'expected' before we
|
||||||
|
# leave this function. We must be sure to register the producer
|
||||||
|
# before it gets unregistered.
|
||||||
|
|
||||||
|
self._consumer = consumer
|
||||||
|
self._consumer_bytes_written = 0
|
||||||
|
self._consumer_bytes_expected = expected
|
||||||
|
d = None
|
||||||
|
if expected is not None:
|
||||||
|
d = defer.Deferred()
|
||||||
|
self._consumer_deferred = d
|
||||||
|
# drain any pending records
|
||||||
|
while self._consumer and self._inbound_records:
|
||||||
|
r = self._inbound_records.popleft()
|
||||||
|
self._writeToConsumer(r)
|
||||||
|
return d
|
||||||
|
|
||||||
|
def _writeToConsumer(self, record):
|
||||||
|
self._consumer.write(record)
|
||||||
|
self._consumer_bytes_written += len(record)
|
||||||
|
if self._consumer_bytes_expected is not None:
|
||||||
|
if self._consumer_bytes_written >= self._consumer_bytes_expected:
|
||||||
|
d = self._consumer_deferred
|
||||||
|
self.disconnectConsumer()
|
||||||
|
d.callback(self._consumer_bytes_written)
|
||||||
|
|
||||||
def disconnectConsumer(self):
|
def disconnectConsumer(self):
|
||||||
self._consumer.unregisterProducer()
|
self._consumer.unregisterProducer()
|
||||||
self._consumer = None
|
self._consumer = None
|
||||||
|
self._consumer_bytes_expected = None
|
||||||
|
self._consumer_deferred = None
|
||||||
|
|
||||||
|
# Helper method to write a known number of bytes to a file. This has no
|
||||||
|
# flow control: the filehandle cannot push back. 'progress' is an
|
||||||
|
# optional callable which will be called frequently with the number of
|
||||||
|
# bytes transferred so far. Returns a Deferred that fires (with the
|
||||||
|
# number of bytes written) when the count is reached or the RecordPipe is
|
||||||
|
# closed.
|
||||||
|
def writeToFile(self, f, expected, progress=None):
|
||||||
|
progress = progress or (lambda n: None)
|
||||||
|
fc = FileConsumer(f, expected, progress)
|
||||||
|
return self.connectConsumer(fc, expected)
|
||||||
|
|
||||||
class OutboundConnectionFactory(protocol.ClientFactory):
|
class OutboundConnectionFactory(protocol.ClientFactory):
|
||||||
protocol = Connection
|
protocol = Connection
|
||||||
|
@ -647,6 +705,35 @@ class TransitSender(Common):
|
||||||
class TransitReceiver(Common):
|
class TransitReceiver(Common):
|
||||||
is_sender = False
|
is_sender = False
|
||||||
|
|
||||||
|
|
||||||
|
# based on twisted.protocols.ftp.FileConsumer, but:
|
||||||
|
# - call a progress-tracking function
|
||||||
|
# - don't close the filehandle when done
|
||||||
|
|
||||||
|
@implementer(interfaces.IConsumer)
|
||||||
|
class FileConsumer:
|
||||||
|
def __init__(self, f, xfersize, progress_f):
|
||||||
|
self._f = f
|
||||||
|
self._xfersize = xfersize
|
||||||
|
self._received = 0
|
||||||
|
self._progress_f = progress_f
|
||||||
|
self._producer = None
|
||||||
|
self._progress_f(0)
|
||||||
|
|
||||||
|
def registerProducer(self, producer, streaming):
|
||||||
|
assert not self._producer
|
||||||
|
self._producer = producer
|
||||||
|
assert streaming
|
||||||
|
|
||||||
|
def write(self, bytes):
|
||||||
|
self._f.write(bytes)
|
||||||
|
self._received += len(bytes)
|
||||||
|
self._progress_f(self._received)
|
||||||
|
|
||||||
|
def unregisterProducer(self):
|
||||||
|
assert self._producer
|
||||||
|
self._producer = None
|
||||||
|
|
||||||
# the TransitSender/Receiver.connect() yields a Connection, on which you can
|
# the TransitSender/Receiver.connect() yields a Connection, on which you can
|
||||||
# do send_record(), but what should the receive API be? set a callback for
|
# do send_record(), but what should the receive API be? set a callback for
|
||||||
# inbound records? get a Deferred for the next record? The producer/consumer
|
# inbound records? get a Deferred for the next record? The producer/consumer
|
||||||
|
|
Loading…
Reference in New Issue
Block a user