diff --git a/setup.py b/setup.py index 95c90e5..0e4e4e3 100644 --- a/setup.py +++ b/setup.py @@ -26,7 +26,7 @@ setup(name="magic-wormhole", "wormhole-server = wormhole.server.runner:entry", ]}, install_requires=["spake2==0.3", "pynacl", "requests", "argparse", - "six", "twisted >= 16.1.0", "hkdf", + "six", "twisted >= 16.1.0", "hkdf", "tqdm", "autobahn[twisted]", "pytrie", # autobahn seems to have a bug, and one plugin throws # errors unless pytrie is installed diff --git a/src/wormhole/cli/cmd_receive.py b/src/wormhole/cli/cmd_receive.py index 14716dc..9ce9373 100644 --- a/src/wormhole/cli/cmd_receive.py +++ b/src/wormhole/cli/cmd_receive.py @@ -1,12 +1,11 @@ from __future__ import print_function -import io, os, sys, json, binascii, six, tempfile, zipfile +import os, sys, json, binascii, six, tempfile, zipfile +from tqdm import tqdm from twisted.internet import reactor, defer from twisted.internet.defer import inlineCallbacks, returnValue from ..twisted.transcribe import Wormhole, WrongPasswordError from ..twisted.transit import TransitReceiver from ..errors import TransferError -from .progress import ProgressPrinter - APPID = u"lothar.com/wormhole/text-or-file-xfer" @@ -217,15 +216,12 @@ class TwistedReceiver: self.msg(u"Receiving (%s).." % record_pipe.describe()) _start = self.args.timing.add_event("rx file") - progress_stdout = self.args.stdout - if self.args.hide_progress: - progress_stdout = io.StringIO() - progress = ProgressPrinter(self.xfersize, progress_stdout) - - progress.start() - received = yield record_pipe.writeToFile(f, self.xfersize, - progress.update) - progress.finish() + progress = tqdm(file=self.args.stdout, + disable=self.args.hide_progress, + unit="B", unit_scale=True, total=self.xfersize) + with progress: + received = yield record_pipe.writeToFile(f, self.xfersize, + progress.update) self.args.timing.finish_event(_start) # except TransitError diff --git a/src/wormhole/cli/cmd_send.py b/src/wormhole/cli/cmd_send.py index 9e6fad1..b5bbe24 100644 --- a/src/wormhole/cli/cmd_send.py +++ b/src/wormhole/cli/cmd_send.py @@ -1,10 +1,10 @@ from __future__ import print_function -import os, sys, io, json, binascii, six, tempfile, zipfile +import os, sys, json, binascii, six, tempfile, zipfile +from tqdm import tqdm from twisted.protocols import basic from twisted.internet import reactor from twisted.internet.defer import inlineCallbacks, returnValue from ..errors import TransferError -from .progress import ProgressPrinter from ..twisted.transcribe import Wormhole, WrongPasswordError from ..twisted.transit import TransitSender @@ -189,24 +189,6 @@ def send_twisted(args, reactor=reactor): args.stdout, args.hide_progress, args.timing) returnValue(0) -class ProgressingFileSender(basic.FileSender): - def __init__(self, filesize, stdout): - self._sent = 0 - self._progress = ProgressPrinter(filesize, stdout) - self._progress.start() - def progress(self, data): - self._sent += len(data) - self._progress.update(self._sent) - return data - def beginFileTransfer(self, file, consumer): - d = basic.FileSender.beginFileTransfer(self, file, consumer, - self.progress) - d.addCallback(self.done) - return d - def done(self, res): - self._progress.finish() - return res - @inlineCallbacks def _send_file_twisted(tdata, transit_sender, fd_to_send, stdout, hide_progress, timing): @@ -216,16 +198,22 @@ def _send_file_twisted(tdata, transit_sender, fd_to_send, fd_to_send.seek(0,2) filesize = fd_to_send.tell() fd_to_send.seek(0,0) - progress_stdout = stdout - if hide_progress: - progress_stdout = io.StringIO() record_pipe = yield transit_sender.connect() # record_pipe should implement IConsumer, chunks are just records print(u"Sending (%s).." % record_pipe.describe(), file=stdout) - pfs = ProgressingFileSender(filesize, progress_stdout) + + progress = tqdm(file=stdout, disable=hide_progress, + unit="B", unit_scale=True, + total=filesize) + def _count(data): + progress.update(len(data)) + return data + fs = basic.FileSender() + _start = timing.add_event("tx file") - yield pfs.beginFileTransfer(fd_to_send, record_pipe) + with progress: + yield fs.beginFileTransfer(fd_to_send, record_pipe, transform=_count) timing.finish_event(_start) print(u"File sent.. waiting for confirmation", file=stdout) diff --git a/src/wormhole/cli/progress.py b/src/wormhole/cli/progress.py deleted file mode 100644 index 23f13e7..0000000 --- a/src/wormhole/cli/progress.py +++ /dev/null @@ -1,51 +0,0 @@ -from __future__ import print_function -import time - -class ProgressPrinter: - def __init__(self, expected, stdout, update_every=0.2): - self._expected = expected - self._stdout = stdout - self._update_every = update_every - - def _now(self): - return time.time() - - def start(self): - self._print(0) - self._next_update = self._now() + self._update_every - - def update(self, completed): - now = self._now() - if now < self._next_update: - return - self._next_update = now + self._update_every - self._print(completed) - - def finish(self): - self._print(self._expected) - print(u"", file=self._stdout) - - def _print(self, completed): - # scp does "<>(13% 168MB 39.3MB/s 00:27 ETA)" - # we do "Progress: #### 13% 168MB" - screen_width = 70 - bar_width = screen_width - 30 - fmt = "Progress: %-{}s %3d%% %4d%s".format(bar_width) - short_unit_size, short_unit_name = 1, "B" - if self._expected > 9999: - short_unit_size, short_unit_name = 1000, "KB" - if self._expected > 9999*1000: - short_unit_size, short_unit_name = 1000*1000, "MB" - if self._expected > 9999*1000*1000: - short_unit_size, short_unit_name = 1000*1000*1000, "GB" - - percentage_complete = ((1.0 * completed / self._expected) - if self._expected - else 1.0) - bars = "#" * int(percentage_complete * bar_width) - perc = int(100 * percentage_complete) - short_unit_count = int(completed / short_unit_size) - out = fmt % (bars, perc, short_unit_count, short_unit_name) - print(u"\r"+" "*screen_width, end=u"", file=self._stdout) - print(u"\r"+out, end=u"", file=self._stdout) - self._stdout.flush() diff --git a/src/wormhole/test/test_progress.py b/src/wormhole/test/test_progress.py deleted file mode 100644 index b48a105..0000000 --- a/src/wormhole/test/test_progress.py +++ /dev/null @@ -1,69 +0,0 @@ -from __future__ import print_function -import io, time -from twisted.trial import unittest -from ..cli import progress - -class Progress(unittest.TestCase): - def test_time(self): - p = progress.ProgressPrinter(1e6, None) - start = time.time() - now = p._now() - finish = time.time() - self.assertTrue(start <= now <= finish, (start, now, finish)) - - def test_basic(self): - stdout = io.StringIO() - p = progress.ProgressPrinter(1e6, stdout) - p._now = lambda: 0.0 - p.start() - - erase = u"\r"+u" "*70 - expected = erase - fmt = "Progress: %-40s %3d%% %4d%s" - expected += u"\r" + fmt % ("", 0, 0, "KB") - self.assertEqual(stdout.getvalue(), expected) - - p.update(1e3) # no change, too soon - self.assertEqual(stdout.getvalue(), expected) - - p._now = lambda: 1.0 - p.update(1e3) # enough "time" has passed - expected += erase + u"\r" + fmt % ("", 0, 1, "KB") - self.assertEqual(stdout.getvalue(), expected) - - p._now = lambda: 2.0 - p.update(500e3) - expected += erase + u"\r" + fmt % ("#"*20, 50, 500, "KB") - self.assertEqual(stdout.getvalue(), expected) - - p._now = lambda: 3.0 - p.finish() - expected += erase + u"\r" + fmt % ("#"*40, 100, 1000, "KB") - expected += u"\n" - self.assertEqual(stdout.getvalue(), expected) - - def test_units(self): - def _try(size): - stdout = io.StringIO() - p = progress.ProgressPrinter(size, stdout) - p.finish() - return stdout.getvalue() - - fmt = "Progress: %-40s %3d%% %4d%s" - def _expect(count, units): - erase = u"\r"+u" "*70 - expected = erase - expected += u"\r" + fmt % ("#"*40, 100, count, units) - expected += u"\n" - return expected - - self.assertEqual(_try(900), _expect(900, "B")) - self.assertEqual(_try(9e3), _expect(9000, "B")) - self.assertEqual(_try(90e3), _expect(90, "KB")) - self.assertEqual(_try(900e3), _expect(900, "KB")) - self.assertEqual(_try(9e6), _expect(9000, "KB")) - self.assertEqual(_try(90e6), _expect(90, "MB")) - self.assertEqual(_try(900e6), _expect(900, "MB")) - self.assertEqual(_try(9e9), _expect(9000, "MB")) - self.assertEqual(_try(90e9), _expect(90, "GB")) - self.assertEqual(_try(900e9), _expect(900, "GB")) diff --git a/src/wormhole/test/test_transit_twisted.py b/src/wormhole/test/test_transit_twisted.py index 4e4ea81..4d87816 100644 --- a/src/wormhole/test/test_transit_twisted.py +++ b/src/wormhole/test/test_transit_twisted.py @@ -1087,22 +1087,22 @@ class Connection(unittest.TestCase): d = c.writeToFile(f, 10, progress.append) d.addBoth(results.append) self.assertEqual(f.getvalue(), b"r1.") - self.assertEqual(progress, [0, 3]) + self.assertEqual(progress, [3]) self.assertEqual(results, []) c.recordReceived(b"r2.") self.assertEqual(f.getvalue(), b"r1.r2.") - self.assertEqual(progress, [0, 3, 6]) + self.assertEqual(progress, [3, 3]) self.assertEqual(results, []) c.recordReceived(b"r3.") self.assertEqual(f.getvalue(), b"r1.r2.r3.") - self.assertEqual(progress, [0, 3, 6, 9]) + self.assertEqual(progress, [3, 3, 3]) self.assertEqual(results, []) c.recordReceived(b"!") self.assertEqual(f.getvalue(), b"r1.r2.r3.!") - self.assertEqual(progress, [0, 3, 6, 9, 10]) + self.assertEqual(progress, [3, 3, 3, 1]) self.assertEqual(results, [10]) # that should automatically disconnect the consumer, and subsequent @@ -1110,7 +1110,7 @@ class Connection(unittest.TestCase): self.assertIs(c._consumer, None) c.recordReceived(b"overflow.") self.assertEqual(f.getvalue(), b"r1.r2.r3.!") - self.assertEqual(progress, [0, 3, 6, 9, 10]) + self.assertEqual(progress, [3, 3, 3, 1]) # test what happens when enough data is queued ahead of time c.recordReceived(b"second.") # now "overflow.second." @@ -1155,14 +1155,14 @@ class FileConsumer(unittest.TestCase): def test_basic(self): f = io.BytesIO() progress = [] - fc = transit.FileConsumer(f, 100, progress.append) - self.assertEqual(progress, [0]) + fc = transit.FileConsumer(f, progress.append) + self.assertEqual(progress, []) self.assertEqual(f.getvalue(), b"") fc.write(b"."* 99) - self.assertEqual(progress, [0, 99]) + self.assertEqual(progress, [99]) self.assertEqual(f.getvalue(), b"."*99) fc.write(b"!") - self.assertEqual(progress, [0, 99, 100]) + self.assertEqual(progress, [99, 1]) self.assertEqual(f.getvalue(), b"."*99+b"!") diff --git a/src/wormhole/twisted/transit.py b/src/wormhole/twisted/transit.py index 8d3d757..7987962 100644 --- a/src/wormhole/twisted/transit.py +++ b/src/wormhole/twisted/transit.py @@ -394,13 +394,11 @@ class Connection(protocol.Protocol, policies.TimeoutMixin): # Helper method to write a known number of bytes to a file. This has no # flow control: the filehandle cannot push back. 'progress' is an - # optional callable which will be called frequently with the number of - # bytes transferred so far. Returns a Deferred that fires (with the - # number of bytes written) when the count is reached or the RecordPipe is - # closed. + # optional callable which will be called on each write (with the number + # of bytes written). Returns a Deferred that fires (with the number of + # bytes written) when the count is reached or the RecordPipe is closed. def writeToFile(self, f, expected, progress=None): - progress = progress or (lambda n: None) - fc = FileConsumer(f, expected, progress) + fc = FileConsumer(f, progress) return self.connectConsumer(fc, expected) class OutboundConnectionFactory(protocol.ClientFactory): @@ -805,19 +803,16 @@ class TransitReceiver(Common): is_sender = False -# based on twisted.protocols.ftp.FileConsumer, but: -# - call a progress-tracking function -# - don't close the filehandle when done +# based on twisted.protocols.ftp.FileConsumer, but don't close the filehandle +# when done, and add a progress function that gets called with the length of +# each write. @implementer(interfaces.IConsumer) class FileConsumer: - def __init__(self, f, xfersize, progress_f): + def __init__(self, f, progress=None): self._f = f - self._xfersize = xfersize - self._received = 0 - self._progress_f = progress_f + self._progress = progress self._producer = None - self._progress_f(0) def registerProducer(self, producer, streaming): assert not self._producer @@ -826,8 +821,8 @@ class FileConsumer: def write(self, bytes): self._f.write(bytes) - self._received += len(bytes) - self._progress_f(self._received) + if self._progress: + self._progress(len(bytes)) def unregisterProducer(self): assert self._producer @@ -851,8 +846,6 @@ class FileConsumer: # check start/finish time-gathering instrumentation -# add progress API - # relay URLs are probably mishandled: both sides probably send their URL, # then connect to the *other* side's URL, when they really should connect to # both their own and the other side's. The current implementation probably