diff --git a/src/wormhole/scripts/cmd_receive_twisted.py b/src/wormhole/scripts/cmd_receive_twisted.py index cc2423c..487fc14 100644 --- a/src/wormhole/scripts/cmd_receive_twisted.py +++ b/src/wormhole/scripts/cmd_receive_twisted.py @@ -1,10 +1,8 @@ from __future__ import print_function import io, json -from zope.interface import implementer -from twisted.internet import interfaces, defer from twisted.internet.defer import inlineCallbacks, returnValue from ..twisted.transcribe import Wormhole, WrongPasswordError -from ..twisted.transit import TransitReceiver #, TransitError +from ..twisted.transit import TransitReceiver from .cmd_receive_blocking import BlockingReceiver, RespondError, APPID from ..errors import TransferError from .progress import ProgressPrinter @@ -105,11 +103,14 @@ class TwistedReceiver(BlockingReceiver): progress_stdout = self.args.stdout if self.args.hide_progress: progress_stdout = io.StringIO() - pfc = ProgressingFileConsumer(f, self.xfersize, progress_stdout) - record_pipe.connectConsumer(pfc) - received = yield pfc.when_done - record_pipe.disconnectConsumer() + progress = ProgressPrinter(self.xfersize, progress_stdout) + + progress.start() + received = yield record_pipe.writeToFile(f, self.xfersize, + progress.update) + progress.finish() self.args.timing.finish_event(_start) + # except TransitError if received < self.xfersize: self.msg() @@ -124,36 +125,3 @@ class TwistedReceiver(BlockingReceiver): yield record_pipe.send_record(b"ok\n") yield record_pipe.close() self.args.timing.finish_event(_start) - -# based on twisted.protocols.ftp.FileConsumer, but: -# - finish after 'xfersize' bytes received, instead of connectionLost() -# - don't close the filehandle when done - -@implementer(interfaces.IConsumer) -class ProgressingFileConsumer: - def __init__(self, f, xfersize, progress_stdout): - self._f = f - self._xfersize = xfersize - self._received = 0 - self._progress = ProgressPrinter(xfersize, progress_stdout) - self._progress.start() - self.when_done = defer.Deferred() - - def registerProducer(self, producer, streaming): - self.producer = producer - assert streaming - - def write(self, bytes): - self._f.write(bytes) - self._received += len(bytes) - self._progress.update(self._received) - if self._received >= self._xfersize: - self._progress.finish() - d,self.when_done = self.when_done,None - d.callback(self._received) - - def unregisterProducer(self): - self.producer = None - if self.when_done: - # connection was dropped before all bytes were received - self.when_done.callback(self._received) diff --git a/src/wormhole/test/test_transit_twisted.py b/src/wormhole/test/test_transit_twisted.py index e89bb74..4e4ea81 100644 --- a/src/wormhole/test/test_transit_twisted.py +++ b/src/wormhole/test/test_transit_twisted.py @@ -1,4 +1,5 @@ from __future__ import print_function +import io from binascii import hexlify, unhexlify from twisted.trial import unittest from twisted.internet import defer, task, endpoints, protocol, address, error @@ -1008,7 +1009,8 @@ class Connection(unittest.TestCase): c.recordReceived(b"r2.") consumer = proto_helpers.StringTransport() - c.connectConsumer(consumer) + rv = c.connectConsumer(consumer) + self.assertIs(rv, None) self.assertIs(c._consumer, consumer) self.assertEqual(consumer.value(), b"r1.r2.") @@ -1029,6 +1031,109 @@ class Connection(unittest.TestCase): c.stopProducing() self.assertEqual(c.transport.producerState, "stopped") + def test_connectConsumer(self): + # connectConsumer() takes an optional number of bytes to expect, and + # fires a Deferred when that many have been written + c = transit.Connection(None, None, None, "description") + c._negotiation_d.addErrback(lambda err: None) # eat it + c.transport = proto_helpers.StringTransport() + c.recordReceived(b"r1.") + + consumer = proto_helpers.StringTransport() + results = [] + d = c.connectConsumer(consumer, expected=10) + d.addBoth(results.append) + self.assertEqual(consumer.value(), b"r1.") + self.assertEqual(results, []) + + c.recordReceived(b"r2.") + self.assertEqual(consumer.value(), b"r1.r2.") + self.assertEqual(results, []) + + c.recordReceived(b"r3.") + self.assertEqual(consumer.value(), b"r1.r2.r3.") + self.assertEqual(results, []) + + c.recordReceived(b"!") + self.assertEqual(consumer.value(), b"r1.r2.r3.!") + self.assertEqual(results, [10]) + + # that should automatically disconnect the consumer, and subsequent + # records should get queued, not delivered + self.assertIs(c._consumer, None) + c.recordReceived(b"overflow") + self.assertEqual(consumer.value(), b"r1.r2.r3.!") + + # now test that the Deferred errbacks when the connection is lost + results = [] + d = c.connectConsumer(consumer, expected=10) + d.addBoth(results.append) + + c.connectionLost() + self.assertEqual(len(results), 1) + f = results[0] + self.assertIsInstance(f, failure.Failure) + self.assertIsInstance(f.value, error.ConnectionClosed) + + def test_writeToFile(self): + c = transit.Connection(None, None, None, "description") + c._negotiation_d.addErrback(lambda err: None) # eat it + c.transport = proto_helpers.StringTransport() + c.recordReceived(b"r1.") + + f = io.BytesIO() + progress = [] + results = [] + d = c.writeToFile(f, 10, progress.append) + d.addBoth(results.append) + self.assertEqual(f.getvalue(), b"r1.") + self.assertEqual(progress, [0, 3]) + self.assertEqual(results, []) + + c.recordReceived(b"r2.") + self.assertEqual(f.getvalue(), b"r1.r2.") + self.assertEqual(progress, [0, 3, 6]) + self.assertEqual(results, []) + + c.recordReceived(b"r3.") + self.assertEqual(f.getvalue(), b"r1.r2.r3.") + self.assertEqual(progress, [0, 3, 6, 9]) + self.assertEqual(results, []) + + c.recordReceived(b"!") + self.assertEqual(f.getvalue(), b"r1.r2.r3.!") + self.assertEqual(progress, [0, 3, 6, 9, 10]) + self.assertEqual(results, [10]) + + # that should automatically disconnect the consumer, and subsequent + # records should get queued, not delivered + self.assertIs(c._consumer, None) + c.recordReceived(b"overflow.") + self.assertEqual(f.getvalue(), b"r1.r2.r3.!") + self.assertEqual(progress, [0, 3, 6, 9, 10]) + + # test what happens when enough data is queued ahead of time + c.recordReceived(b"second.") # now "overflow.second." + c.recordReceived(b"third.") # now "overflow.second.third." + f = io.BytesIO() + results = [] + d = c.writeToFile(f, 10) + d.addBoth(results.append) + self.assertEqual(f.getvalue(), b"overflow.second.") # whole records + self.assertEqual(results, [16]) + self.assertEqual(list(c._inbound_records), [b"third."]) + + # now test that the Deferred errbacks when the connection is lost + results = [] + d = c.writeToFile(f, 10) + d.addBoth(results.append) + + c.connectionLost() + self.assertEqual(len(results), 1) + f = results[0] + self.assertIsInstance(f, failure.Failure) + self.assertIsInstance(f.value, error.ConnectionClosed) + def test_consumer(self): # a local producer sends data to a consuming Transit object c = transit.Connection(None, None, None, "description") @@ -1046,6 +1151,20 @@ class Connection(unittest.TestCase): c.unregisterProducer() self.assertEqual(c.transport.producer, None) +class FileConsumer(unittest.TestCase): + def test_basic(self): + f = io.BytesIO() + progress = [] + fc = transit.FileConsumer(f, 100, progress.append) + self.assertEqual(progress, [0]) + self.assertEqual(f.getvalue(), b"") + fc.write(b"."* 99) + self.assertEqual(progress, [0, 99]) + self.assertEqual(f.getvalue(), b"."*99) + fc.write(b"!") + self.assertEqual(progress, [0, 99, 100]) + self.assertEqual(f.getvalue(), b"."*99+b"!") + DIRECT_HINT = u"tcp:direct:1234" RELAY_HINT = u"tcp:relay:1234" diff --git a/src/wormhole/twisted/transit.py b/src/wormhole/twisted/transit.py index 58176fa..83b1ca3 100644 --- a/src/wormhole/twisted/transit.py +++ b/src/wormhole/twisted/transit.py @@ -38,6 +38,9 @@ class Connection(protocol.Protocol, policies.TimeoutMixin): self._negotiation_d = defer.Deferred(self._cancel) self._error = None self._consumer = None + self._consumer_bytes_written = 0 + self._consumer_bytes_expected = None + self._consumer_deferred = None self._inbound_records = collections.deque() self._waiting_reads = collections.deque() @@ -181,7 +184,7 @@ class Connection(protocol.Protocol, policies.TimeoutMixin): def recordReceived(self, record): if self._consumer: - self._consumer.write(record) + self._writeToConsumer(record) return self._inbound_records.append(record) self._deliverRecords() @@ -226,6 +229,8 @@ class Connection(protocol.Protocol, policies.TimeoutMixin): # timeout: BadHandshake("timeout") d.errback(self._error or BadHandshake("connection lost")) + if self._consumer_deferred: + self._consumer_deferred.errback(error.ConnectionClosed()) # IConsumer methods, for outbound flow-control. We pass these through to # the transport. The 'producer' is something like a t.p.basic.FileSender @@ -246,22 +251,75 @@ class Connection(protocol.Protocol, policies.TimeoutMixin): def resumeProducing(self): self.transport.resumeProducing() - # Helper method to glue an instance of e.g. t.p.ftp.FileConsumer to us. - # Inbound records will be written as bytes to the consumer. - def connectConsumer(self, consumer): + # Helper methods + + def connectConsumer(self, consumer, expected=None): + """Helper method to glue an instance of e.g. t.p.ftp.FileConsumer to + us. Inbound records will be written as bytes to the consumer. + + Set 'expected' to an integer to automatically disconnect when at + least that number of bytes have been written. This function will then + return a Deferred (that fires with the number of bytes actually + received). If the connection is lost while this Deferred is + outstanding, it will errback. + + If 'expected' is None, then this function returns None instead of a + Deferred, and you must call disconnectConsumer() when you are done.""" + if self._consumer: raise RuntimeError("A consumer is already attached: %r" % self._consumer) - self._consumer = consumer - # drain any pending records - while self._inbound_records: - r = self._inbound_records.popleft() - consumer.write(r) + + # be aware of an ordering hazard: when we call the consumer's + # .registerProducer method, they are likely to immediately call + # self.resumeProducing, which we'll deliver to self.transport, which + # might call our .dataReceived, which may cause more records to be + # available. By waiting to set self._consumer until *after* we drain + # any pending records, we avoid delivering records out of order, + # which would be bad. consumer.registerProducer(self, True) + # There might be enough data queued to exceed 'expected' before we + # leave this function. We must be sure to register the producer + # before it gets unregistered. + + self._consumer = consumer + self._consumer_bytes_written = 0 + self._consumer_bytes_expected = expected + d = None + if expected is not None: + d = defer.Deferred() + self._consumer_deferred = d + # drain any pending records + while self._consumer and self._inbound_records: + r = self._inbound_records.popleft() + self._writeToConsumer(r) + return d + + def _writeToConsumer(self, record): + self._consumer.write(record) + self._consumer_bytes_written += len(record) + if self._consumer_bytes_expected is not None: + if self._consumer_bytes_written >= self._consumer_bytes_expected: + d = self._consumer_deferred + self.disconnectConsumer() + d.callback(self._consumer_bytes_written) def disconnectConsumer(self): self._consumer.unregisterProducer() self._consumer = None + self._consumer_bytes_expected = None + self._consumer_deferred = None + + # Helper method to write a known number of bytes to a file. This has no + # flow control: the filehandle cannot push back. 'progress' is an + # optional callable which will be called frequently with the number of + # bytes transferred so far. Returns a Deferred that fires (with the + # number of bytes written) when the count is reached or the RecordPipe is + # closed. + def writeToFile(self, f, expected, progress=None): + progress = progress or (lambda n: None) + fc = FileConsumer(f, expected, progress) + return self.connectConsumer(fc, expected) class OutboundConnectionFactory(protocol.ClientFactory): protocol = Connection @@ -647,6 +705,35 @@ class TransitSender(Common): class TransitReceiver(Common): is_sender = False + +# based on twisted.protocols.ftp.FileConsumer, but: +# - call a progress-tracking function +# - don't close the filehandle when done + +@implementer(interfaces.IConsumer) +class FileConsumer: + def __init__(self, f, xfersize, progress_f): + self._f = f + self._xfersize = xfersize + self._received = 0 + self._progress_f = progress_f + self._producer = None + self._progress_f(0) + + def registerProducer(self, producer, streaming): + assert not self._producer + self._producer = producer + assert streaming + + def write(self, bytes): + self._f.write(bytes) + self._received += len(bytes) + self._progress_f(self._received) + + def unregisterProducer(self): + assert self._producer + self._producer = None + # the TransitSender/Receiver.connect() yields a Connection, on which you can # do send_record(), but what should the receive API be? set a callback for # inbound records? get a Deferred for the next record? The producer/consumer