switch to tqdm for nicer CLI progress bars
This commit is contained in:
parent
16c6c0977e
commit
86edf96412
2
setup.py
2
setup.py
|
@ -26,7 +26,7 @@ setup(name="magic-wormhole",
|
|||
"wormhole-server = wormhole.server.runner:entry",
|
||||
]},
|
||||
install_requires=["spake2==0.3", "pynacl", "requests", "argparse",
|
||||
"six", "twisted >= 16.1.0", "hkdf",
|
||||
"six", "twisted >= 16.1.0", "hkdf", "tqdm",
|
||||
"autobahn[twisted]", "pytrie",
|
||||
# autobahn seems to have a bug, and one plugin throws
|
||||
# errors unless pytrie is installed
|
||||
|
|
|
@ -1,12 +1,11 @@
|
|||
from __future__ import print_function
|
||||
import io, os, sys, json, binascii, six, tempfile, zipfile
|
||||
import os, sys, json, binascii, six, tempfile, zipfile
|
||||
from tqdm import tqdm
|
||||
from twisted.internet import reactor, defer
|
||||
from twisted.internet.defer import inlineCallbacks, returnValue
|
||||
from ..twisted.transcribe import Wormhole, WrongPasswordError
|
||||
from ..twisted.transit import TransitReceiver
|
||||
from ..errors import TransferError
|
||||
from .progress import ProgressPrinter
|
||||
|
||||
|
||||
APPID = u"lothar.com/wormhole/text-or-file-xfer"
|
||||
|
||||
|
@ -217,15 +216,12 @@ class TwistedReceiver:
|
|||
self.msg(u"Receiving (%s).." % record_pipe.describe())
|
||||
|
||||
_start = self.args.timing.add_event("rx file")
|
||||
progress_stdout = self.args.stdout
|
||||
if self.args.hide_progress:
|
||||
progress_stdout = io.StringIO()
|
||||
progress = ProgressPrinter(self.xfersize, progress_stdout)
|
||||
|
||||
progress.start()
|
||||
progress = tqdm(file=self.args.stdout,
|
||||
disable=self.args.hide_progress,
|
||||
unit="B", unit_scale=True, total=self.xfersize)
|
||||
with progress:
|
||||
received = yield record_pipe.writeToFile(f, self.xfersize,
|
||||
progress.update)
|
||||
progress.finish()
|
||||
self.args.timing.finish_event(_start)
|
||||
|
||||
# except TransitError
|
||||
|
|
|
@ -1,10 +1,10 @@
|
|||
from __future__ import print_function
|
||||
import os, sys, io, json, binascii, six, tempfile, zipfile
|
||||
import os, sys, json, binascii, six, tempfile, zipfile
|
||||
from tqdm import tqdm
|
||||
from twisted.protocols import basic
|
||||
from twisted.internet import reactor
|
||||
from twisted.internet.defer import inlineCallbacks, returnValue
|
||||
from ..errors import TransferError
|
||||
from .progress import ProgressPrinter
|
||||
from ..twisted.transcribe import Wormhole, WrongPasswordError
|
||||
from ..twisted.transit import TransitSender
|
||||
|
||||
|
@ -189,24 +189,6 @@ def send_twisted(args, reactor=reactor):
|
|||
args.stdout, args.hide_progress, args.timing)
|
||||
returnValue(0)
|
||||
|
||||
class ProgressingFileSender(basic.FileSender):
|
||||
def __init__(self, filesize, stdout):
|
||||
self._sent = 0
|
||||
self._progress = ProgressPrinter(filesize, stdout)
|
||||
self._progress.start()
|
||||
def progress(self, data):
|
||||
self._sent += len(data)
|
||||
self._progress.update(self._sent)
|
||||
return data
|
||||
def beginFileTransfer(self, file, consumer):
|
||||
d = basic.FileSender.beginFileTransfer(self, file, consumer,
|
||||
self.progress)
|
||||
d.addCallback(self.done)
|
||||
return d
|
||||
def done(self, res):
|
||||
self._progress.finish()
|
||||
return res
|
||||
|
||||
@inlineCallbacks
|
||||
def _send_file_twisted(tdata, transit_sender, fd_to_send,
|
||||
stdout, hide_progress, timing):
|
||||
|
@ -216,16 +198,22 @@ def _send_file_twisted(tdata, transit_sender, fd_to_send,
|
|||
fd_to_send.seek(0,2)
|
||||
filesize = fd_to_send.tell()
|
||||
fd_to_send.seek(0,0)
|
||||
progress_stdout = stdout
|
||||
if hide_progress:
|
||||
progress_stdout = io.StringIO()
|
||||
|
||||
record_pipe = yield transit_sender.connect()
|
||||
# record_pipe should implement IConsumer, chunks are just records
|
||||
print(u"Sending (%s).." % record_pipe.describe(), file=stdout)
|
||||
pfs = ProgressingFileSender(filesize, progress_stdout)
|
||||
|
||||
progress = tqdm(file=stdout, disable=hide_progress,
|
||||
unit="B", unit_scale=True,
|
||||
total=filesize)
|
||||
def _count(data):
|
||||
progress.update(len(data))
|
||||
return data
|
||||
fs = basic.FileSender()
|
||||
|
||||
_start = timing.add_event("tx file")
|
||||
yield pfs.beginFileTransfer(fd_to_send, record_pipe)
|
||||
with progress:
|
||||
yield fs.beginFileTransfer(fd_to_send, record_pipe, transform=_count)
|
||||
timing.finish_event(_start)
|
||||
|
||||
print(u"File sent.. waiting for confirmation", file=stdout)
|
||||
|
|
|
@ -1,51 +0,0 @@
|
|||
from __future__ import print_function
|
||||
import time
|
||||
|
||||
class ProgressPrinter:
|
||||
def __init__(self, expected, stdout, update_every=0.2):
|
||||
self._expected = expected
|
||||
self._stdout = stdout
|
||||
self._update_every = update_every
|
||||
|
||||
def _now(self):
|
||||
return time.time()
|
||||
|
||||
def start(self):
|
||||
self._print(0)
|
||||
self._next_update = self._now() + self._update_every
|
||||
|
||||
def update(self, completed):
|
||||
now = self._now()
|
||||
if now < self._next_update:
|
||||
return
|
||||
self._next_update = now + self._update_every
|
||||
self._print(completed)
|
||||
|
||||
def finish(self):
|
||||
self._print(self._expected)
|
||||
print(u"", file=self._stdout)
|
||||
|
||||
def _print(self, completed):
|
||||
# scp does "<<FILENAME >>(13% 168MB 39.3MB/s 00:27 ETA)"
|
||||
# we do "Progress: #### 13% 168MB"
|
||||
screen_width = 70
|
||||
bar_width = screen_width - 30
|
||||
fmt = "Progress: %-{}s %3d%% %4d%s".format(bar_width)
|
||||
short_unit_size, short_unit_name = 1, "B"
|
||||
if self._expected > 9999:
|
||||
short_unit_size, short_unit_name = 1000, "KB"
|
||||
if self._expected > 9999*1000:
|
||||
short_unit_size, short_unit_name = 1000*1000, "MB"
|
||||
if self._expected > 9999*1000*1000:
|
||||
short_unit_size, short_unit_name = 1000*1000*1000, "GB"
|
||||
|
||||
percentage_complete = ((1.0 * completed / self._expected)
|
||||
if self._expected
|
||||
else 1.0)
|
||||
bars = "#" * int(percentage_complete * bar_width)
|
||||
perc = int(100 * percentage_complete)
|
||||
short_unit_count = int(completed / short_unit_size)
|
||||
out = fmt % (bars, perc, short_unit_count, short_unit_name)
|
||||
print(u"\r"+" "*screen_width, end=u"", file=self._stdout)
|
||||
print(u"\r"+out, end=u"", file=self._stdout)
|
||||
self._stdout.flush()
|
|
@ -1,69 +0,0 @@
|
|||
from __future__ import print_function
|
||||
import io, time
|
||||
from twisted.trial import unittest
|
||||
from ..cli import progress
|
||||
|
||||
class Progress(unittest.TestCase):
|
||||
def test_time(self):
|
||||
p = progress.ProgressPrinter(1e6, None)
|
||||
start = time.time()
|
||||
now = p._now()
|
||||
finish = time.time()
|
||||
self.assertTrue(start <= now <= finish, (start, now, finish))
|
||||
|
||||
def test_basic(self):
|
||||
stdout = io.StringIO()
|
||||
p = progress.ProgressPrinter(1e6, stdout)
|
||||
p._now = lambda: 0.0
|
||||
p.start()
|
||||
|
||||
erase = u"\r"+u" "*70
|
||||
expected = erase
|
||||
fmt = "Progress: %-40s %3d%% %4d%s"
|
||||
expected += u"\r" + fmt % ("", 0, 0, "KB")
|
||||
self.assertEqual(stdout.getvalue(), expected)
|
||||
|
||||
p.update(1e3) # no change, too soon
|
||||
self.assertEqual(stdout.getvalue(), expected)
|
||||
|
||||
p._now = lambda: 1.0
|
||||
p.update(1e3) # enough "time" has passed
|
||||
expected += erase + u"\r" + fmt % ("", 0, 1, "KB")
|
||||
self.assertEqual(stdout.getvalue(), expected)
|
||||
|
||||
p._now = lambda: 2.0
|
||||
p.update(500e3)
|
||||
expected += erase + u"\r" + fmt % ("#"*20, 50, 500, "KB")
|
||||
self.assertEqual(stdout.getvalue(), expected)
|
||||
|
||||
p._now = lambda: 3.0
|
||||
p.finish()
|
||||
expected += erase + u"\r" + fmt % ("#"*40, 100, 1000, "KB")
|
||||
expected += u"\n"
|
||||
self.assertEqual(stdout.getvalue(), expected)
|
||||
|
||||
def test_units(self):
|
||||
def _try(size):
|
||||
stdout = io.StringIO()
|
||||
p = progress.ProgressPrinter(size, stdout)
|
||||
p.finish()
|
||||
return stdout.getvalue()
|
||||
|
||||
fmt = "Progress: %-40s %3d%% %4d%s"
|
||||
def _expect(count, units):
|
||||
erase = u"\r"+u" "*70
|
||||
expected = erase
|
||||
expected += u"\r" + fmt % ("#"*40, 100, count, units)
|
||||
expected += u"\n"
|
||||
return expected
|
||||
|
||||
self.assertEqual(_try(900), _expect(900, "B"))
|
||||
self.assertEqual(_try(9e3), _expect(9000, "B"))
|
||||
self.assertEqual(_try(90e3), _expect(90, "KB"))
|
||||
self.assertEqual(_try(900e3), _expect(900, "KB"))
|
||||
self.assertEqual(_try(9e6), _expect(9000, "KB"))
|
||||
self.assertEqual(_try(90e6), _expect(90, "MB"))
|
||||
self.assertEqual(_try(900e6), _expect(900, "MB"))
|
||||
self.assertEqual(_try(9e9), _expect(9000, "MB"))
|
||||
self.assertEqual(_try(90e9), _expect(90, "GB"))
|
||||
self.assertEqual(_try(900e9), _expect(900, "GB"))
|
|
@ -1087,22 +1087,22 @@ class Connection(unittest.TestCase):
|
|||
d = c.writeToFile(f, 10, progress.append)
|
||||
d.addBoth(results.append)
|
||||
self.assertEqual(f.getvalue(), b"r1.")
|
||||
self.assertEqual(progress, [0, 3])
|
||||
self.assertEqual(progress, [3])
|
||||
self.assertEqual(results, [])
|
||||
|
||||
c.recordReceived(b"r2.")
|
||||
self.assertEqual(f.getvalue(), b"r1.r2.")
|
||||
self.assertEqual(progress, [0, 3, 6])
|
||||
self.assertEqual(progress, [3, 3])
|
||||
self.assertEqual(results, [])
|
||||
|
||||
c.recordReceived(b"r3.")
|
||||
self.assertEqual(f.getvalue(), b"r1.r2.r3.")
|
||||
self.assertEqual(progress, [0, 3, 6, 9])
|
||||
self.assertEqual(progress, [3, 3, 3])
|
||||
self.assertEqual(results, [])
|
||||
|
||||
c.recordReceived(b"!")
|
||||
self.assertEqual(f.getvalue(), b"r1.r2.r3.!")
|
||||
self.assertEqual(progress, [0, 3, 6, 9, 10])
|
||||
self.assertEqual(progress, [3, 3, 3, 1])
|
||||
self.assertEqual(results, [10])
|
||||
|
||||
# that should automatically disconnect the consumer, and subsequent
|
||||
|
@ -1110,7 +1110,7 @@ class Connection(unittest.TestCase):
|
|||
self.assertIs(c._consumer, None)
|
||||
c.recordReceived(b"overflow.")
|
||||
self.assertEqual(f.getvalue(), b"r1.r2.r3.!")
|
||||
self.assertEqual(progress, [0, 3, 6, 9, 10])
|
||||
self.assertEqual(progress, [3, 3, 3, 1])
|
||||
|
||||
# test what happens when enough data is queued ahead of time
|
||||
c.recordReceived(b"second.") # now "overflow.second."
|
||||
|
@ -1155,14 +1155,14 @@ class FileConsumer(unittest.TestCase):
|
|||
def test_basic(self):
|
||||
f = io.BytesIO()
|
||||
progress = []
|
||||
fc = transit.FileConsumer(f, 100, progress.append)
|
||||
self.assertEqual(progress, [0])
|
||||
fc = transit.FileConsumer(f, progress.append)
|
||||
self.assertEqual(progress, [])
|
||||
self.assertEqual(f.getvalue(), b"")
|
||||
fc.write(b"."* 99)
|
||||
self.assertEqual(progress, [0, 99])
|
||||
self.assertEqual(progress, [99])
|
||||
self.assertEqual(f.getvalue(), b"."*99)
|
||||
fc.write(b"!")
|
||||
self.assertEqual(progress, [0, 99, 100])
|
||||
self.assertEqual(progress, [99, 1])
|
||||
self.assertEqual(f.getvalue(), b"."*99+b"!")
|
||||
|
||||
|
||||
|
|
|
@ -394,13 +394,11 @@ class Connection(protocol.Protocol, policies.TimeoutMixin):
|
|||
|
||||
# Helper method to write a known number of bytes to a file. This has no
|
||||
# flow control: the filehandle cannot push back. 'progress' is an
|
||||
# optional callable which will be called frequently with the number of
|
||||
# bytes transferred so far. Returns a Deferred that fires (with the
|
||||
# number of bytes written) when the count is reached or the RecordPipe is
|
||||
# closed.
|
||||
# optional callable which will be called on each write (with the number
|
||||
# of bytes written). Returns a Deferred that fires (with the number of
|
||||
# bytes written) when the count is reached or the RecordPipe is closed.
|
||||
def writeToFile(self, f, expected, progress=None):
|
||||
progress = progress or (lambda n: None)
|
||||
fc = FileConsumer(f, expected, progress)
|
||||
fc = FileConsumer(f, progress)
|
||||
return self.connectConsumer(fc, expected)
|
||||
|
||||
class OutboundConnectionFactory(protocol.ClientFactory):
|
||||
|
@ -805,19 +803,16 @@ class TransitReceiver(Common):
|
|||
is_sender = False
|
||||
|
||||
|
||||
# based on twisted.protocols.ftp.FileConsumer, but:
|
||||
# - call a progress-tracking function
|
||||
# - don't close the filehandle when done
|
||||
# based on twisted.protocols.ftp.FileConsumer, but don't close the filehandle
|
||||
# when done, and add a progress function that gets called with the length of
|
||||
# each write.
|
||||
|
||||
@implementer(interfaces.IConsumer)
|
||||
class FileConsumer:
|
||||
def __init__(self, f, xfersize, progress_f):
|
||||
def __init__(self, f, progress=None):
|
||||
self._f = f
|
||||
self._xfersize = xfersize
|
||||
self._received = 0
|
||||
self._progress_f = progress_f
|
||||
self._progress = progress
|
||||
self._producer = None
|
||||
self._progress_f(0)
|
||||
|
||||
def registerProducer(self, producer, streaming):
|
||||
assert not self._producer
|
||||
|
@ -826,8 +821,8 @@ class FileConsumer:
|
|||
|
||||
def write(self, bytes):
|
||||
self._f.write(bytes)
|
||||
self._received += len(bytes)
|
||||
self._progress_f(self._received)
|
||||
if self._progress:
|
||||
self._progress(len(bytes))
|
||||
|
||||
def unregisterProducer(self):
|
||||
assert self._producer
|
||||
|
@ -851,8 +846,6 @@ class FileConsumer:
|
|||
|
||||
# check start/finish time-gathering instrumentation
|
||||
|
||||
# add progress API
|
||||
|
||||
# relay URLs are probably mishandled: both sides probably send their URL,
|
||||
# then connect to the *other* side's URL, when they really should connect to
|
||||
# both their own and the other side's. The current implementation probably
|
||||
|
|
Loading…
Reference in New Issue
Block a user