switch to tqdm for nicer CLI progress bars

This commit is contained in:
Brian Warner 2016-04-24 12:04:05 -07:00
parent 16c6c0977e
commit 86edf96412
7 changed files with 42 additions and 185 deletions

View File

@ -26,7 +26,7 @@ setup(name="magic-wormhole",
"wormhole-server = wormhole.server.runner:entry",
]},
install_requires=["spake2==0.3", "pynacl", "requests", "argparse",
"six", "twisted >= 16.1.0", "hkdf",
"six", "twisted >= 16.1.0", "hkdf", "tqdm",
"autobahn[twisted]", "pytrie",
# autobahn seems to have a bug, and one plugin throws
# errors unless pytrie is installed

View File

@ -1,12 +1,11 @@
from __future__ import print_function
import io, os, sys, json, binascii, six, tempfile, zipfile
import os, sys, json, binascii, six, tempfile, zipfile
from tqdm import tqdm
from twisted.internet import reactor, defer
from twisted.internet.defer import inlineCallbacks, returnValue
from ..twisted.transcribe import Wormhole, WrongPasswordError
from ..twisted.transit import TransitReceiver
from ..errors import TransferError
from .progress import ProgressPrinter
APPID = u"lothar.com/wormhole/text-or-file-xfer"
@ -217,15 +216,12 @@ class TwistedReceiver:
self.msg(u"Receiving (%s).." % record_pipe.describe())
_start = self.args.timing.add_event("rx file")
progress_stdout = self.args.stdout
if self.args.hide_progress:
progress_stdout = io.StringIO()
progress = ProgressPrinter(self.xfersize, progress_stdout)
progress.start()
received = yield record_pipe.writeToFile(f, self.xfersize,
progress.update)
progress.finish()
progress = tqdm(file=self.args.stdout,
disable=self.args.hide_progress,
unit="B", unit_scale=True, total=self.xfersize)
with progress:
received = yield record_pipe.writeToFile(f, self.xfersize,
progress.update)
self.args.timing.finish_event(_start)
# except TransitError

View File

@ -1,10 +1,10 @@
from __future__ import print_function
import os, sys, io, json, binascii, six, tempfile, zipfile
import os, sys, json, binascii, six, tempfile, zipfile
from tqdm import tqdm
from twisted.protocols import basic
from twisted.internet import reactor
from twisted.internet.defer import inlineCallbacks, returnValue
from ..errors import TransferError
from .progress import ProgressPrinter
from ..twisted.transcribe import Wormhole, WrongPasswordError
from ..twisted.transit import TransitSender
@ -189,24 +189,6 @@ def send_twisted(args, reactor=reactor):
args.stdout, args.hide_progress, args.timing)
returnValue(0)
class ProgressingFileSender(basic.FileSender):
def __init__(self, filesize, stdout):
self._sent = 0
self._progress = ProgressPrinter(filesize, stdout)
self._progress.start()
def progress(self, data):
self._sent += len(data)
self._progress.update(self._sent)
return data
def beginFileTransfer(self, file, consumer):
d = basic.FileSender.beginFileTransfer(self, file, consumer,
self.progress)
d.addCallback(self.done)
return d
def done(self, res):
self._progress.finish()
return res
@inlineCallbacks
def _send_file_twisted(tdata, transit_sender, fd_to_send,
stdout, hide_progress, timing):
@ -216,16 +198,22 @@ def _send_file_twisted(tdata, transit_sender, fd_to_send,
fd_to_send.seek(0,2)
filesize = fd_to_send.tell()
fd_to_send.seek(0,0)
progress_stdout = stdout
if hide_progress:
progress_stdout = io.StringIO()
record_pipe = yield transit_sender.connect()
# record_pipe should implement IConsumer, chunks are just records
print(u"Sending (%s).." % record_pipe.describe(), file=stdout)
pfs = ProgressingFileSender(filesize, progress_stdout)
progress = tqdm(file=stdout, disable=hide_progress,
unit="B", unit_scale=True,
total=filesize)
def _count(data):
progress.update(len(data))
return data
fs = basic.FileSender()
_start = timing.add_event("tx file")
yield pfs.beginFileTransfer(fd_to_send, record_pipe)
with progress:
yield fs.beginFileTransfer(fd_to_send, record_pipe, transform=_count)
timing.finish_event(_start)
print(u"File sent.. waiting for confirmation", file=stdout)

View File

@ -1,51 +0,0 @@
from __future__ import print_function
import time
class ProgressPrinter:
def __init__(self, expected, stdout, update_every=0.2):
self._expected = expected
self._stdout = stdout
self._update_every = update_every
def _now(self):
return time.time()
def start(self):
self._print(0)
self._next_update = self._now() + self._update_every
def update(self, completed):
now = self._now()
if now < self._next_update:
return
self._next_update = now + self._update_every
self._print(completed)
def finish(self):
self._print(self._expected)
print(u"", file=self._stdout)
def _print(self, completed):
# scp does "<<FILENAME >>(13% 168MB 39.3MB/s 00:27 ETA)"
# we do "Progress: #### 13% 168MB"
screen_width = 70
bar_width = screen_width - 30
fmt = "Progress: %-{}s %3d%% %4d%s".format(bar_width)
short_unit_size, short_unit_name = 1, "B"
if self._expected > 9999:
short_unit_size, short_unit_name = 1000, "KB"
if self._expected > 9999*1000:
short_unit_size, short_unit_name = 1000*1000, "MB"
if self._expected > 9999*1000*1000:
short_unit_size, short_unit_name = 1000*1000*1000, "GB"
percentage_complete = ((1.0 * completed / self._expected)
if self._expected
else 1.0)
bars = "#" * int(percentage_complete * bar_width)
perc = int(100 * percentage_complete)
short_unit_count = int(completed / short_unit_size)
out = fmt % (bars, perc, short_unit_count, short_unit_name)
print(u"\r"+" "*screen_width, end=u"", file=self._stdout)
print(u"\r"+out, end=u"", file=self._stdout)
self._stdout.flush()

View File

@ -1,69 +0,0 @@
from __future__ import print_function
import io, time
from twisted.trial import unittest
from ..cli import progress
class Progress(unittest.TestCase):
def test_time(self):
p = progress.ProgressPrinter(1e6, None)
start = time.time()
now = p._now()
finish = time.time()
self.assertTrue(start <= now <= finish, (start, now, finish))
def test_basic(self):
stdout = io.StringIO()
p = progress.ProgressPrinter(1e6, stdout)
p._now = lambda: 0.0
p.start()
erase = u"\r"+u" "*70
expected = erase
fmt = "Progress: %-40s %3d%% %4d%s"
expected += u"\r" + fmt % ("", 0, 0, "KB")
self.assertEqual(stdout.getvalue(), expected)
p.update(1e3) # no change, too soon
self.assertEqual(stdout.getvalue(), expected)
p._now = lambda: 1.0
p.update(1e3) # enough "time" has passed
expected += erase + u"\r" + fmt % ("", 0, 1, "KB")
self.assertEqual(stdout.getvalue(), expected)
p._now = lambda: 2.0
p.update(500e3)
expected += erase + u"\r" + fmt % ("#"*20, 50, 500, "KB")
self.assertEqual(stdout.getvalue(), expected)
p._now = lambda: 3.0
p.finish()
expected += erase + u"\r" + fmt % ("#"*40, 100, 1000, "KB")
expected += u"\n"
self.assertEqual(stdout.getvalue(), expected)
def test_units(self):
def _try(size):
stdout = io.StringIO()
p = progress.ProgressPrinter(size, stdout)
p.finish()
return stdout.getvalue()
fmt = "Progress: %-40s %3d%% %4d%s"
def _expect(count, units):
erase = u"\r"+u" "*70
expected = erase
expected += u"\r" + fmt % ("#"*40, 100, count, units)
expected += u"\n"
return expected
self.assertEqual(_try(900), _expect(900, "B"))
self.assertEqual(_try(9e3), _expect(9000, "B"))
self.assertEqual(_try(90e3), _expect(90, "KB"))
self.assertEqual(_try(900e3), _expect(900, "KB"))
self.assertEqual(_try(9e6), _expect(9000, "KB"))
self.assertEqual(_try(90e6), _expect(90, "MB"))
self.assertEqual(_try(900e6), _expect(900, "MB"))
self.assertEqual(_try(9e9), _expect(9000, "MB"))
self.assertEqual(_try(90e9), _expect(90, "GB"))
self.assertEqual(_try(900e9), _expect(900, "GB"))

View File

@ -1087,22 +1087,22 @@ class Connection(unittest.TestCase):
d = c.writeToFile(f, 10, progress.append)
d.addBoth(results.append)
self.assertEqual(f.getvalue(), b"r1.")
self.assertEqual(progress, [0, 3])
self.assertEqual(progress, [3])
self.assertEqual(results, [])
c.recordReceived(b"r2.")
self.assertEqual(f.getvalue(), b"r1.r2.")
self.assertEqual(progress, [0, 3, 6])
self.assertEqual(progress, [3, 3])
self.assertEqual(results, [])
c.recordReceived(b"r3.")
self.assertEqual(f.getvalue(), b"r1.r2.r3.")
self.assertEqual(progress, [0, 3, 6, 9])
self.assertEqual(progress, [3, 3, 3])
self.assertEqual(results, [])
c.recordReceived(b"!")
self.assertEqual(f.getvalue(), b"r1.r2.r3.!")
self.assertEqual(progress, [0, 3, 6, 9, 10])
self.assertEqual(progress, [3, 3, 3, 1])
self.assertEqual(results, [10])
# that should automatically disconnect the consumer, and subsequent
@ -1110,7 +1110,7 @@ class Connection(unittest.TestCase):
self.assertIs(c._consumer, None)
c.recordReceived(b"overflow.")
self.assertEqual(f.getvalue(), b"r1.r2.r3.!")
self.assertEqual(progress, [0, 3, 6, 9, 10])
self.assertEqual(progress, [3, 3, 3, 1])
# test what happens when enough data is queued ahead of time
c.recordReceived(b"second.") # now "overflow.second."
@ -1155,14 +1155,14 @@ class FileConsumer(unittest.TestCase):
def test_basic(self):
f = io.BytesIO()
progress = []
fc = transit.FileConsumer(f, 100, progress.append)
self.assertEqual(progress, [0])
fc = transit.FileConsumer(f, progress.append)
self.assertEqual(progress, [])
self.assertEqual(f.getvalue(), b"")
fc.write(b"."* 99)
self.assertEqual(progress, [0, 99])
self.assertEqual(progress, [99])
self.assertEqual(f.getvalue(), b"."*99)
fc.write(b"!")
self.assertEqual(progress, [0, 99, 100])
self.assertEqual(progress, [99, 1])
self.assertEqual(f.getvalue(), b"."*99+b"!")

View File

@ -394,13 +394,11 @@ class Connection(protocol.Protocol, policies.TimeoutMixin):
# Helper method to write a known number of bytes to a file. This has no
# flow control: the filehandle cannot push back. 'progress' is an
# optional callable which will be called frequently with the number of
# bytes transferred so far. Returns a Deferred that fires (with the
# number of bytes written) when the count is reached or the RecordPipe is
# closed.
# optional callable which will be called on each write (with the number
# of bytes written). Returns a Deferred that fires (with the number of
# bytes written) when the count is reached or the RecordPipe is closed.
def writeToFile(self, f, expected, progress=None):
progress = progress or (lambda n: None)
fc = FileConsumer(f, expected, progress)
fc = FileConsumer(f, progress)
return self.connectConsumer(fc, expected)
class OutboundConnectionFactory(protocol.ClientFactory):
@ -805,19 +803,16 @@ class TransitReceiver(Common):
is_sender = False
# based on twisted.protocols.ftp.FileConsumer, but:
# - call a progress-tracking function
# - don't close the filehandle when done
# based on twisted.protocols.ftp.FileConsumer, but don't close the filehandle
# when done, and add a progress function that gets called with the length of
# each write.
@implementer(interfaces.IConsumer)
class FileConsumer:
def __init__(self, f, xfersize, progress_f):
def __init__(self, f, progress=None):
self._f = f
self._xfersize = xfersize
self._received = 0
self._progress_f = progress_f
self._progress = progress
self._producer = None
self._progress_f(0)
def registerProducer(self, producer, streaming):
assert not self._producer
@ -826,8 +821,8 @@ class FileConsumer:
def write(self, bytes):
self._f.write(bytes)
self._received += len(bytes)
self._progress_f(self._received)
if self._progress:
self._progress(len(bytes))
def unregisterProducer(self):
assert self._producer
@ -851,8 +846,6 @@ class FileConsumer:
# check start/finish time-gathering instrumentation
# add progress API
# relay URLs are probably mishandled: both sides probably send their URL,
# then connect to the *other* side's URL, when they really should connect to
# both their own and the other side's. The current implementation probably