switch to tqdm for nicer CLI progress bars

This commit is contained in:
Brian Warner 2016-04-24 12:04:05 -07:00
parent 16c6c0977e
commit 86edf96412
7 changed files with 42 additions and 185 deletions

View File

@ -26,7 +26,7 @@ setup(name="magic-wormhole",
"wormhole-server = wormhole.server.runner:entry", "wormhole-server = wormhole.server.runner:entry",
]}, ]},
install_requires=["spake2==0.3", "pynacl", "requests", "argparse", install_requires=["spake2==0.3", "pynacl", "requests", "argparse",
"six", "twisted >= 16.1.0", "hkdf", "six", "twisted >= 16.1.0", "hkdf", "tqdm",
"autobahn[twisted]", "pytrie", "autobahn[twisted]", "pytrie",
# autobahn seems to have a bug, and one plugin throws # autobahn seems to have a bug, and one plugin throws
# errors unless pytrie is installed # errors unless pytrie is installed

View File

@ -1,12 +1,11 @@
from __future__ import print_function from __future__ import print_function
import io, os, sys, json, binascii, six, tempfile, zipfile import os, sys, json, binascii, six, tempfile, zipfile
from tqdm import tqdm
from twisted.internet import reactor, defer from twisted.internet import reactor, defer
from twisted.internet.defer import inlineCallbacks, returnValue from twisted.internet.defer import inlineCallbacks, returnValue
from ..twisted.transcribe import Wormhole, WrongPasswordError from ..twisted.transcribe import Wormhole, WrongPasswordError
from ..twisted.transit import TransitReceiver from ..twisted.transit import TransitReceiver
from ..errors import TransferError from ..errors import TransferError
from .progress import ProgressPrinter
APPID = u"lothar.com/wormhole/text-or-file-xfer" APPID = u"lothar.com/wormhole/text-or-file-xfer"
@ -217,15 +216,12 @@ class TwistedReceiver:
self.msg(u"Receiving (%s).." % record_pipe.describe()) self.msg(u"Receiving (%s).." % record_pipe.describe())
_start = self.args.timing.add_event("rx file") _start = self.args.timing.add_event("rx file")
progress_stdout = self.args.stdout progress = tqdm(file=self.args.stdout,
if self.args.hide_progress: disable=self.args.hide_progress,
progress_stdout = io.StringIO() unit="B", unit_scale=True, total=self.xfersize)
progress = ProgressPrinter(self.xfersize, progress_stdout) with progress:
progress.start()
received = yield record_pipe.writeToFile(f, self.xfersize, received = yield record_pipe.writeToFile(f, self.xfersize,
progress.update) progress.update)
progress.finish()
self.args.timing.finish_event(_start) self.args.timing.finish_event(_start)
# except TransitError # except TransitError

View File

@ -1,10 +1,10 @@
from __future__ import print_function from __future__ import print_function
import os, sys, io, json, binascii, six, tempfile, zipfile import os, sys, json, binascii, six, tempfile, zipfile
from tqdm import tqdm
from twisted.protocols import basic from twisted.protocols import basic
from twisted.internet import reactor from twisted.internet import reactor
from twisted.internet.defer import inlineCallbacks, returnValue from twisted.internet.defer import inlineCallbacks, returnValue
from ..errors import TransferError from ..errors import TransferError
from .progress import ProgressPrinter
from ..twisted.transcribe import Wormhole, WrongPasswordError from ..twisted.transcribe import Wormhole, WrongPasswordError
from ..twisted.transit import TransitSender from ..twisted.transit import TransitSender
@ -189,24 +189,6 @@ def send_twisted(args, reactor=reactor):
args.stdout, args.hide_progress, args.timing) args.stdout, args.hide_progress, args.timing)
returnValue(0) returnValue(0)
class ProgressingFileSender(basic.FileSender):
def __init__(self, filesize, stdout):
self._sent = 0
self._progress = ProgressPrinter(filesize, stdout)
self._progress.start()
def progress(self, data):
self._sent += len(data)
self._progress.update(self._sent)
return data
def beginFileTransfer(self, file, consumer):
d = basic.FileSender.beginFileTransfer(self, file, consumer,
self.progress)
d.addCallback(self.done)
return d
def done(self, res):
self._progress.finish()
return res
@inlineCallbacks @inlineCallbacks
def _send_file_twisted(tdata, transit_sender, fd_to_send, def _send_file_twisted(tdata, transit_sender, fd_to_send,
stdout, hide_progress, timing): stdout, hide_progress, timing):
@ -216,16 +198,22 @@ def _send_file_twisted(tdata, transit_sender, fd_to_send,
fd_to_send.seek(0,2) fd_to_send.seek(0,2)
filesize = fd_to_send.tell() filesize = fd_to_send.tell()
fd_to_send.seek(0,0) fd_to_send.seek(0,0)
progress_stdout = stdout
if hide_progress:
progress_stdout = io.StringIO()
record_pipe = yield transit_sender.connect() record_pipe = yield transit_sender.connect()
# record_pipe should implement IConsumer, chunks are just records # record_pipe should implement IConsumer, chunks are just records
print(u"Sending (%s).." % record_pipe.describe(), file=stdout) print(u"Sending (%s).." % record_pipe.describe(), file=stdout)
pfs = ProgressingFileSender(filesize, progress_stdout)
progress = tqdm(file=stdout, disable=hide_progress,
unit="B", unit_scale=True,
total=filesize)
def _count(data):
progress.update(len(data))
return data
fs = basic.FileSender()
_start = timing.add_event("tx file") _start = timing.add_event("tx file")
yield pfs.beginFileTransfer(fd_to_send, record_pipe) with progress:
yield fs.beginFileTransfer(fd_to_send, record_pipe, transform=_count)
timing.finish_event(_start) timing.finish_event(_start)
print(u"File sent.. waiting for confirmation", file=stdout) print(u"File sent.. waiting for confirmation", file=stdout)

View File

@ -1,51 +0,0 @@
from __future__ import print_function
import time
class ProgressPrinter:
def __init__(self, expected, stdout, update_every=0.2):
self._expected = expected
self._stdout = stdout
self._update_every = update_every
def _now(self):
return time.time()
def start(self):
self._print(0)
self._next_update = self._now() + self._update_every
def update(self, completed):
now = self._now()
if now < self._next_update:
return
self._next_update = now + self._update_every
self._print(completed)
def finish(self):
self._print(self._expected)
print(u"", file=self._stdout)
def _print(self, completed):
# scp does "<<FILENAME >>(13% 168MB 39.3MB/s 00:27 ETA)"
# we do "Progress: #### 13% 168MB"
screen_width = 70
bar_width = screen_width - 30
fmt = "Progress: %-{}s %3d%% %4d%s".format(bar_width)
short_unit_size, short_unit_name = 1, "B"
if self._expected > 9999:
short_unit_size, short_unit_name = 1000, "KB"
if self._expected > 9999*1000:
short_unit_size, short_unit_name = 1000*1000, "MB"
if self._expected > 9999*1000*1000:
short_unit_size, short_unit_name = 1000*1000*1000, "GB"
percentage_complete = ((1.0 * completed / self._expected)
if self._expected
else 1.0)
bars = "#" * int(percentage_complete * bar_width)
perc = int(100 * percentage_complete)
short_unit_count = int(completed / short_unit_size)
out = fmt % (bars, perc, short_unit_count, short_unit_name)
print(u"\r"+" "*screen_width, end=u"", file=self._stdout)
print(u"\r"+out, end=u"", file=self._stdout)
self._stdout.flush()

View File

@ -1,69 +0,0 @@
from __future__ import print_function
import io, time
from twisted.trial import unittest
from ..cli import progress
class Progress(unittest.TestCase):
def test_time(self):
p = progress.ProgressPrinter(1e6, None)
start = time.time()
now = p._now()
finish = time.time()
self.assertTrue(start <= now <= finish, (start, now, finish))
def test_basic(self):
stdout = io.StringIO()
p = progress.ProgressPrinter(1e6, stdout)
p._now = lambda: 0.0
p.start()
erase = u"\r"+u" "*70
expected = erase
fmt = "Progress: %-40s %3d%% %4d%s"
expected += u"\r" + fmt % ("", 0, 0, "KB")
self.assertEqual(stdout.getvalue(), expected)
p.update(1e3) # no change, too soon
self.assertEqual(stdout.getvalue(), expected)
p._now = lambda: 1.0
p.update(1e3) # enough "time" has passed
expected += erase + u"\r" + fmt % ("", 0, 1, "KB")
self.assertEqual(stdout.getvalue(), expected)
p._now = lambda: 2.0
p.update(500e3)
expected += erase + u"\r" + fmt % ("#"*20, 50, 500, "KB")
self.assertEqual(stdout.getvalue(), expected)
p._now = lambda: 3.0
p.finish()
expected += erase + u"\r" + fmt % ("#"*40, 100, 1000, "KB")
expected += u"\n"
self.assertEqual(stdout.getvalue(), expected)
def test_units(self):
def _try(size):
stdout = io.StringIO()
p = progress.ProgressPrinter(size, stdout)
p.finish()
return stdout.getvalue()
fmt = "Progress: %-40s %3d%% %4d%s"
def _expect(count, units):
erase = u"\r"+u" "*70
expected = erase
expected += u"\r" + fmt % ("#"*40, 100, count, units)
expected += u"\n"
return expected
self.assertEqual(_try(900), _expect(900, "B"))
self.assertEqual(_try(9e3), _expect(9000, "B"))
self.assertEqual(_try(90e3), _expect(90, "KB"))
self.assertEqual(_try(900e3), _expect(900, "KB"))
self.assertEqual(_try(9e6), _expect(9000, "KB"))
self.assertEqual(_try(90e6), _expect(90, "MB"))
self.assertEqual(_try(900e6), _expect(900, "MB"))
self.assertEqual(_try(9e9), _expect(9000, "MB"))
self.assertEqual(_try(90e9), _expect(90, "GB"))
self.assertEqual(_try(900e9), _expect(900, "GB"))

View File

@ -1087,22 +1087,22 @@ class Connection(unittest.TestCase):
d = c.writeToFile(f, 10, progress.append) d = c.writeToFile(f, 10, progress.append)
d.addBoth(results.append) d.addBoth(results.append)
self.assertEqual(f.getvalue(), b"r1.") self.assertEqual(f.getvalue(), b"r1.")
self.assertEqual(progress, [0, 3]) self.assertEqual(progress, [3])
self.assertEqual(results, []) self.assertEqual(results, [])
c.recordReceived(b"r2.") c.recordReceived(b"r2.")
self.assertEqual(f.getvalue(), b"r1.r2.") self.assertEqual(f.getvalue(), b"r1.r2.")
self.assertEqual(progress, [0, 3, 6]) self.assertEqual(progress, [3, 3])
self.assertEqual(results, []) self.assertEqual(results, [])
c.recordReceived(b"r3.") c.recordReceived(b"r3.")
self.assertEqual(f.getvalue(), b"r1.r2.r3.") self.assertEqual(f.getvalue(), b"r1.r2.r3.")
self.assertEqual(progress, [0, 3, 6, 9]) self.assertEqual(progress, [3, 3, 3])
self.assertEqual(results, []) self.assertEqual(results, [])
c.recordReceived(b"!") c.recordReceived(b"!")
self.assertEqual(f.getvalue(), b"r1.r2.r3.!") self.assertEqual(f.getvalue(), b"r1.r2.r3.!")
self.assertEqual(progress, [0, 3, 6, 9, 10]) self.assertEqual(progress, [3, 3, 3, 1])
self.assertEqual(results, [10]) self.assertEqual(results, [10])
# that should automatically disconnect the consumer, and subsequent # that should automatically disconnect the consumer, and subsequent
@ -1110,7 +1110,7 @@ class Connection(unittest.TestCase):
self.assertIs(c._consumer, None) self.assertIs(c._consumer, None)
c.recordReceived(b"overflow.") c.recordReceived(b"overflow.")
self.assertEqual(f.getvalue(), b"r1.r2.r3.!") self.assertEqual(f.getvalue(), b"r1.r2.r3.!")
self.assertEqual(progress, [0, 3, 6, 9, 10]) self.assertEqual(progress, [3, 3, 3, 1])
# test what happens when enough data is queued ahead of time # test what happens when enough data is queued ahead of time
c.recordReceived(b"second.") # now "overflow.second." c.recordReceived(b"second.") # now "overflow.second."
@ -1155,14 +1155,14 @@ class FileConsumer(unittest.TestCase):
def test_basic(self): def test_basic(self):
f = io.BytesIO() f = io.BytesIO()
progress = [] progress = []
fc = transit.FileConsumer(f, 100, progress.append) fc = transit.FileConsumer(f, progress.append)
self.assertEqual(progress, [0]) self.assertEqual(progress, [])
self.assertEqual(f.getvalue(), b"") self.assertEqual(f.getvalue(), b"")
fc.write(b"."* 99) fc.write(b"."* 99)
self.assertEqual(progress, [0, 99]) self.assertEqual(progress, [99])
self.assertEqual(f.getvalue(), b"."*99) self.assertEqual(f.getvalue(), b"."*99)
fc.write(b"!") fc.write(b"!")
self.assertEqual(progress, [0, 99, 100]) self.assertEqual(progress, [99, 1])
self.assertEqual(f.getvalue(), b"."*99+b"!") self.assertEqual(f.getvalue(), b"."*99+b"!")

View File

@ -394,13 +394,11 @@ class Connection(protocol.Protocol, policies.TimeoutMixin):
# Helper method to write a known number of bytes to a file. This has no # Helper method to write a known number of bytes to a file. This has no
# flow control: the filehandle cannot push back. 'progress' is an # flow control: the filehandle cannot push back. 'progress' is an
# optional callable which will be called frequently with the number of # optional callable which will be called on each write (with the number
# bytes transferred so far. Returns a Deferred that fires (with the # of bytes written). Returns a Deferred that fires (with the number of
# number of bytes written) when the count is reached or the RecordPipe is # bytes written) when the count is reached or the RecordPipe is closed.
# closed.
def writeToFile(self, f, expected, progress=None): def writeToFile(self, f, expected, progress=None):
progress = progress or (lambda n: None) fc = FileConsumer(f, progress)
fc = FileConsumer(f, expected, progress)
return self.connectConsumer(fc, expected) return self.connectConsumer(fc, expected)
class OutboundConnectionFactory(protocol.ClientFactory): class OutboundConnectionFactory(protocol.ClientFactory):
@ -805,19 +803,16 @@ class TransitReceiver(Common):
is_sender = False is_sender = False
# based on twisted.protocols.ftp.FileConsumer, but: # based on twisted.protocols.ftp.FileConsumer, but don't close the filehandle
# - call a progress-tracking function # when done, and add a progress function that gets called with the length of
# - don't close the filehandle when done # each write.
@implementer(interfaces.IConsumer) @implementer(interfaces.IConsumer)
class FileConsumer: class FileConsumer:
def __init__(self, f, xfersize, progress_f): def __init__(self, f, progress=None):
self._f = f self._f = f
self._xfersize = xfersize self._progress = progress
self._received = 0
self._progress_f = progress_f
self._producer = None self._producer = None
self._progress_f(0)
def registerProducer(self, producer, streaming): def registerProducer(self, producer, streaming):
assert not self._producer assert not self._producer
@ -826,8 +821,8 @@ class FileConsumer:
def write(self, bytes): def write(self, bytes):
self._f.write(bytes) self._f.write(bytes)
self._received += len(bytes) if self._progress:
self._progress_f(self._received) self._progress(len(bytes))
def unregisterProducer(self): def unregisterProducer(self):
assert self._producer assert self._producer
@ -851,8 +846,6 @@ class FileConsumer:
# check start/finish time-gathering instrumentation # check start/finish time-gathering instrumentation
# add progress API
# relay URLs are probably mishandled: both sides probably send their URL, # relay URLs are probably mishandled: both sides probably send their URL,
# then connect to the *other* side's URL, when they really should connect to # then connect to the *other* side's URL, when they really should connect to
# both their own and the other side's. The current implementation probably # both their own and the other side's. The current implementation probably