switch to tqdm for nicer CLI progress bars
This commit is contained in:
parent
16c6c0977e
commit
86edf96412
2
setup.py
2
setup.py
|
@ -26,7 +26,7 @@ setup(name="magic-wormhole",
|
||||||
"wormhole-server = wormhole.server.runner:entry",
|
"wormhole-server = wormhole.server.runner:entry",
|
||||||
]},
|
]},
|
||||||
install_requires=["spake2==0.3", "pynacl", "requests", "argparse",
|
install_requires=["spake2==0.3", "pynacl", "requests", "argparse",
|
||||||
"six", "twisted >= 16.1.0", "hkdf",
|
"six", "twisted >= 16.1.0", "hkdf", "tqdm",
|
||||||
"autobahn[twisted]", "pytrie",
|
"autobahn[twisted]", "pytrie",
|
||||||
# autobahn seems to have a bug, and one plugin throws
|
# autobahn seems to have a bug, and one plugin throws
|
||||||
# errors unless pytrie is installed
|
# errors unless pytrie is installed
|
||||||
|
|
|
@ -1,12 +1,11 @@
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
import io, os, sys, json, binascii, six, tempfile, zipfile
|
import os, sys, json, binascii, six, tempfile, zipfile
|
||||||
|
from tqdm import tqdm
|
||||||
from twisted.internet import reactor, defer
|
from twisted.internet import reactor, defer
|
||||||
from twisted.internet.defer import inlineCallbacks, returnValue
|
from twisted.internet.defer import inlineCallbacks, returnValue
|
||||||
from ..twisted.transcribe import Wormhole, WrongPasswordError
|
from ..twisted.transcribe import Wormhole, WrongPasswordError
|
||||||
from ..twisted.transit import TransitReceiver
|
from ..twisted.transit import TransitReceiver
|
||||||
from ..errors import TransferError
|
from ..errors import TransferError
|
||||||
from .progress import ProgressPrinter
|
|
||||||
|
|
||||||
|
|
||||||
APPID = u"lothar.com/wormhole/text-or-file-xfer"
|
APPID = u"lothar.com/wormhole/text-or-file-xfer"
|
||||||
|
|
||||||
|
@ -217,15 +216,12 @@ class TwistedReceiver:
|
||||||
self.msg(u"Receiving (%s).." % record_pipe.describe())
|
self.msg(u"Receiving (%s).." % record_pipe.describe())
|
||||||
|
|
||||||
_start = self.args.timing.add_event("rx file")
|
_start = self.args.timing.add_event("rx file")
|
||||||
progress_stdout = self.args.stdout
|
progress = tqdm(file=self.args.stdout,
|
||||||
if self.args.hide_progress:
|
disable=self.args.hide_progress,
|
||||||
progress_stdout = io.StringIO()
|
unit="B", unit_scale=True, total=self.xfersize)
|
||||||
progress = ProgressPrinter(self.xfersize, progress_stdout)
|
with progress:
|
||||||
|
|
||||||
progress.start()
|
|
||||||
received = yield record_pipe.writeToFile(f, self.xfersize,
|
received = yield record_pipe.writeToFile(f, self.xfersize,
|
||||||
progress.update)
|
progress.update)
|
||||||
progress.finish()
|
|
||||||
self.args.timing.finish_event(_start)
|
self.args.timing.finish_event(_start)
|
||||||
|
|
||||||
# except TransitError
|
# except TransitError
|
||||||
|
|
|
@ -1,10 +1,10 @@
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
import os, sys, io, json, binascii, six, tempfile, zipfile
|
import os, sys, json, binascii, six, tempfile, zipfile
|
||||||
|
from tqdm import tqdm
|
||||||
from twisted.protocols import basic
|
from twisted.protocols import basic
|
||||||
from twisted.internet import reactor
|
from twisted.internet import reactor
|
||||||
from twisted.internet.defer import inlineCallbacks, returnValue
|
from twisted.internet.defer import inlineCallbacks, returnValue
|
||||||
from ..errors import TransferError
|
from ..errors import TransferError
|
||||||
from .progress import ProgressPrinter
|
|
||||||
from ..twisted.transcribe import Wormhole, WrongPasswordError
|
from ..twisted.transcribe import Wormhole, WrongPasswordError
|
||||||
from ..twisted.transit import TransitSender
|
from ..twisted.transit import TransitSender
|
||||||
|
|
||||||
|
@ -189,24 +189,6 @@ def send_twisted(args, reactor=reactor):
|
||||||
args.stdout, args.hide_progress, args.timing)
|
args.stdout, args.hide_progress, args.timing)
|
||||||
returnValue(0)
|
returnValue(0)
|
||||||
|
|
||||||
class ProgressingFileSender(basic.FileSender):
|
|
||||||
def __init__(self, filesize, stdout):
|
|
||||||
self._sent = 0
|
|
||||||
self._progress = ProgressPrinter(filesize, stdout)
|
|
||||||
self._progress.start()
|
|
||||||
def progress(self, data):
|
|
||||||
self._sent += len(data)
|
|
||||||
self._progress.update(self._sent)
|
|
||||||
return data
|
|
||||||
def beginFileTransfer(self, file, consumer):
|
|
||||||
d = basic.FileSender.beginFileTransfer(self, file, consumer,
|
|
||||||
self.progress)
|
|
||||||
d.addCallback(self.done)
|
|
||||||
return d
|
|
||||||
def done(self, res):
|
|
||||||
self._progress.finish()
|
|
||||||
return res
|
|
||||||
|
|
||||||
@inlineCallbacks
|
@inlineCallbacks
|
||||||
def _send_file_twisted(tdata, transit_sender, fd_to_send,
|
def _send_file_twisted(tdata, transit_sender, fd_to_send,
|
||||||
stdout, hide_progress, timing):
|
stdout, hide_progress, timing):
|
||||||
|
@ -216,16 +198,22 @@ def _send_file_twisted(tdata, transit_sender, fd_to_send,
|
||||||
fd_to_send.seek(0,2)
|
fd_to_send.seek(0,2)
|
||||||
filesize = fd_to_send.tell()
|
filesize = fd_to_send.tell()
|
||||||
fd_to_send.seek(0,0)
|
fd_to_send.seek(0,0)
|
||||||
progress_stdout = stdout
|
|
||||||
if hide_progress:
|
|
||||||
progress_stdout = io.StringIO()
|
|
||||||
|
|
||||||
record_pipe = yield transit_sender.connect()
|
record_pipe = yield transit_sender.connect()
|
||||||
# record_pipe should implement IConsumer, chunks are just records
|
# record_pipe should implement IConsumer, chunks are just records
|
||||||
print(u"Sending (%s).." % record_pipe.describe(), file=stdout)
|
print(u"Sending (%s).." % record_pipe.describe(), file=stdout)
|
||||||
pfs = ProgressingFileSender(filesize, progress_stdout)
|
|
||||||
|
progress = tqdm(file=stdout, disable=hide_progress,
|
||||||
|
unit="B", unit_scale=True,
|
||||||
|
total=filesize)
|
||||||
|
def _count(data):
|
||||||
|
progress.update(len(data))
|
||||||
|
return data
|
||||||
|
fs = basic.FileSender()
|
||||||
|
|
||||||
_start = timing.add_event("tx file")
|
_start = timing.add_event("tx file")
|
||||||
yield pfs.beginFileTransfer(fd_to_send, record_pipe)
|
with progress:
|
||||||
|
yield fs.beginFileTransfer(fd_to_send, record_pipe, transform=_count)
|
||||||
timing.finish_event(_start)
|
timing.finish_event(_start)
|
||||||
|
|
||||||
print(u"File sent.. waiting for confirmation", file=stdout)
|
print(u"File sent.. waiting for confirmation", file=stdout)
|
||||||
|
|
|
@ -1,51 +0,0 @@
|
||||||
from __future__ import print_function
|
|
||||||
import time
|
|
||||||
|
|
||||||
class ProgressPrinter:
|
|
||||||
def __init__(self, expected, stdout, update_every=0.2):
|
|
||||||
self._expected = expected
|
|
||||||
self._stdout = stdout
|
|
||||||
self._update_every = update_every
|
|
||||||
|
|
||||||
def _now(self):
|
|
||||||
return time.time()
|
|
||||||
|
|
||||||
def start(self):
|
|
||||||
self._print(0)
|
|
||||||
self._next_update = self._now() + self._update_every
|
|
||||||
|
|
||||||
def update(self, completed):
|
|
||||||
now = self._now()
|
|
||||||
if now < self._next_update:
|
|
||||||
return
|
|
||||||
self._next_update = now + self._update_every
|
|
||||||
self._print(completed)
|
|
||||||
|
|
||||||
def finish(self):
|
|
||||||
self._print(self._expected)
|
|
||||||
print(u"", file=self._stdout)
|
|
||||||
|
|
||||||
def _print(self, completed):
|
|
||||||
# scp does "<<FILENAME >>(13% 168MB 39.3MB/s 00:27 ETA)"
|
|
||||||
# we do "Progress: #### 13% 168MB"
|
|
||||||
screen_width = 70
|
|
||||||
bar_width = screen_width - 30
|
|
||||||
fmt = "Progress: %-{}s %3d%% %4d%s".format(bar_width)
|
|
||||||
short_unit_size, short_unit_name = 1, "B"
|
|
||||||
if self._expected > 9999:
|
|
||||||
short_unit_size, short_unit_name = 1000, "KB"
|
|
||||||
if self._expected > 9999*1000:
|
|
||||||
short_unit_size, short_unit_name = 1000*1000, "MB"
|
|
||||||
if self._expected > 9999*1000*1000:
|
|
||||||
short_unit_size, short_unit_name = 1000*1000*1000, "GB"
|
|
||||||
|
|
||||||
percentage_complete = ((1.0 * completed / self._expected)
|
|
||||||
if self._expected
|
|
||||||
else 1.0)
|
|
||||||
bars = "#" * int(percentage_complete * bar_width)
|
|
||||||
perc = int(100 * percentage_complete)
|
|
||||||
short_unit_count = int(completed / short_unit_size)
|
|
||||||
out = fmt % (bars, perc, short_unit_count, short_unit_name)
|
|
||||||
print(u"\r"+" "*screen_width, end=u"", file=self._stdout)
|
|
||||||
print(u"\r"+out, end=u"", file=self._stdout)
|
|
||||||
self._stdout.flush()
|
|
|
@ -1,69 +0,0 @@
|
||||||
from __future__ import print_function
|
|
||||||
import io, time
|
|
||||||
from twisted.trial import unittest
|
|
||||||
from ..cli import progress
|
|
||||||
|
|
||||||
class Progress(unittest.TestCase):
|
|
||||||
def test_time(self):
|
|
||||||
p = progress.ProgressPrinter(1e6, None)
|
|
||||||
start = time.time()
|
|
||||||
now = p._now()
|
|
||||||
finish = time.time()
|
|
||||||
self.assertTrue(start <= now <= finish, (start, now, finish))
|
|
||||||
|
|
||||||
def test_basic(self):
|
|
||||||
stdout = io.StringIO()
|
|
||||||
p = progress.ProgressPrinter(1e6, stdout)
|
|
||||||
p._now = lambda: 0.0
|
|
||||||
p.start()
|
|
||||||
|
|
||||||
erase = u"\r"+u" "*70
|
|
||||||
expected = erase
|
|
||||||
fmt = "Progress: %-40s %3d%% %4d%s"
|
|
||||||
expected += u"\r" + fmt % ("", 0, 0, "KB")
|
|
||||||
self.assertEqual(stdout.getvalue(), expected)
|
|
||||||
|
|
||||||
p.update(1e3) # no change, too soon
|
|
||||||
self.assertEqual(stdout.getvalue(), expected)
|
|
||||||
|
|
||||||
p._now = lambda: 1.0
|
|
||||||
p.update(1e3) # enough "time" has passed
|
|
||||||
expected += erase + u"\r" + fmt % ("", 0, 1, "KB")
|
|
||||||
self.assertEqual(stdout.getvalue(), expected)
|
|
||||||
|
|
||||||
p._now = lambda: 2.0
|
|
||||||
p.update(500e3)
|
|
||||||
expected += erase + u"\r" + fmt % ("#"*20, 50, 500, "KB")
|
|
||||||
self.assertEqual(stdout.getvalue(), expected)
|
|
||||||
|
|
||||||
p._now = lambda: 3.0
|
|
||||||
p.finish()
|
|
||||||
expected += erase + u"\r" + fmt % ("#"*40, 100, 1000, "KB")
|
|
||||||
expected += u"\n"
|
|
||||||
self.assertEqual(stdout.getvalue(), expected)
|
|
||||||
|
|
||||||
def test_units(self):
|
|
||||||
def _try(size):
|
|
||||||
stdout = io.StringIO()
|
|
||||||
p = progress.ProgressPrinter(size, stdout)
|
|
||||||
p.finish()
|
|
||||||
return stdout.getvalue()
|
|
||||||
|
|
||||||
fmt = "Progress: %-40s %3d%% %4d%s"
|
|
||||||
def _expect(count, units):
|
|
||||||
erase = u"\r"+u" "*70
|
|
||||||
expected = erase
|
|
||||||
expected += u"\r" + fmt % ("#"*40, 100, count, units)
|
|
||||||
expected += u"\n"
|
|
||||||
return expected
|
|
||||||
|
|
||||||
self.assertEqual(_try(900), _expect(900, "B"))
|
|
||||||
self.assertEqual(_try(9e3), _expect(9000, "B"))
|
|
||||||
self.assertEqual(_try(90e3), _expect(90, "KB"))
|
|
||||||
self.assertEqual(_try(900e3), _expect(900, "KB"))
|
|
||||||
self.assertEqual(_try(9e6), _expect(9000, "KB"))
|
|
||||||
self.assertEqual(_try(90e6), _expect(90, "MB"))
|
|
||||||
self.assertEqual(_try(900e6), _expect(900, "MB"))
|
|
||||||
self.assertEqual(_try(9e9), _expect(9000, "MB"))
|
|
||||||
self.assertEqual(_try(90e9), _expect(90, "GB"))
|
|
||||||
self.assertEqual(_try(900e9), _expect(900, "GB"))
|
|
|
@ -1087,22 +1087,22 @@ class Connection(unittest.TestCase):
|
||||||
d = c.writeToFile(f, 10, progress.append)
|
d = c.writeToFile(f, 10, progress.append)
|
||||||
d.addBoth(results.append)
|
d.addBoth(results.append)
|
||||||
self.assertEqual(f.getvalue(), b"r1.")
|
self.assertEqual(f.getvalue(), b"r1.")
|
||||||
self.assertEqual(progress, [0, 3])
|
self.assertEqual(progress, [3])
|
||||||
self.assertEqual(results, [])
|
self.assertEqual(results, [])
|
||||||
|
|
||||||
c.recordReceived(b"r2.")
|
c.recordReceived(b"r2.")
|
||||||
self.assertEqual(f.getvalue(), b"r1.r2.")
|
self.assertEqual(f.getvalue(), b"r1.r2.")
|
||||||
self.assertEqual(progress, [0, 3, 6])
|
self.assertEqual(progress, [3, 3])
|
||||||
self.assertEqual(results, [])
|
self.assertEqual(results, [])
|
||||||
|
|
||||||
c.recordReceived(b"r3.")
|
c.recordReceived(b"r3.")
|
||||||
self.assertEqual(f.getvalue(), b"r1.r2.r3.")
|
self.assertEqual(f.getvalue(), b"r1.r2.r3.")
|
||||||
self.assertEqual(progress, [0, 3, 6, 9])
|
self.assertEqual(progress, [3, 3, 3])
|
||||||
self.assertEqual(results, [])
|
self.assertEqual(results, [])
|
||||||
|
|
||||||
c.recordReceived(b"!")
|
c.recordReceived(b"!")
|
||||||
self.assertEqual(f.getvalue(), b"r1.r2.r3.!")
|
self.assertEqual(f.getvalue(), b"r1.r2.r3.!")
|
||||||
self.assertEqual(progress, [0, 3, 6, 9, 10])
|
self.assertEqual(progress, [3, 3, 3, 1])
|
||||||
self.assertEqual(results, [10])
|
self.assertEqual(results, [10])
|
||||||
|
|
||||||
# that should automatically disconnect the consumer, and subsequent
|
# that should automatically disconnect the consumer, and subsequent
|
||||||
|
@ -1110,7 +1110,7 @@ class Connection(unittest.TestCase):
|
||||||
self.assertIs(c._consumer, None)
|
self.assertIs(c._consumer, None)
|
||||||
c.recordReceived(b"overflow.")
|
c.recordReceived(b"overflow.")
|
||||||
self.assertEqual(f.getvalue(), b"r1.r2.r3.!")
|
self.assertEqual(f.getvalue(), b"r1.r2.r3.!")
|
||||||
self.assertEqual(progress, [0, 3, 6, 9, 10])
|
self.assertEqual(progress, [3, 3, 3, 1])
|
||||||
|
|
||||||
# test what happens when enough data is queued ahead of time
|
# test what happens when enough data is queued ahead of time
|
||||||
c.recordReceived(b"second.") # now "overflow.second."
|
c.recordReceived(b"second.") # now "overflow.second."
|
||||||
|
@ -1155,14 +1155,14 @@ class FileConsumer(unittest.TestCase):
|
||||||
def test_basic(self):
|
def test_basic(self):
|
||||||
f = io.BytesIO()
|
f = io.BytesIO()
|
||||||
progress = []
|
progress = []
|
||||||
fc = transit.FileConsumer(f, 100, progress.append)
|
fc = transit.FileConsumer(f, progress.append)
|
||||||
self.assertEqual(progress, [0])
|
self.assertEqual(progress, [])
|
||||||
self.assertEqual(f.getvalue(), b"")
|
self.assertEqual(f.getvalue(), b"")
|
||||||
fc.write(b"."* 99)
|
fc.write(b"."* 99)
|
||||||
self.assertEqual(progress, [0, 99])
|
self.assertEqual(progress, [99])
|
||||||
self.assertEqual(f.getvalue(), b"."*99)
|
self.assertEqual(f.getvalue(), b"."*99)
|
||||||
fc.write(b"!")
|
fc.write(b"!")
|
||||||
self.assertEqual(progress, [0, 99, 100])
|
self.assertEqual(progress, [99, 1])
|
||||||
self.assertEqual(f.getvalue(), b"."*99+b"!")
|
self.assertEqual(f.getvalue(), b"."*99+b"!")
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -394,13 +394,11 @@ class Connection(protocol.Protocol, policies.TimeoutMixin):
|
||||||
|
|
||||||
# Helper method to write a known number of bytes to a file. This has no
|
# Helper method to write a known number of bytes to a file. This has no
|
||||||
# flow control: the filehandle cannot push back. 'progress' is an
|
# flow control: the filehandle cannot push back. 'progress' is an
|
||||||
# optional callable which will be called frequently with the number of
|
# optional callable which will be called on each write (with the number
|
||||||
# bytes transferred so far. Returns a Deferred that fires (with the
|
# of bytes written). Returns a Deferred that fires (with the number of
|
||||||
# number of bytes written) when the count is reached or the RecordPipe is
|
# bytes written) when the count is reached or the RecordPipe is closed.
|
||||||
# closed.
|
|
||||||
def writeToFile(self, f, expected, progress=None):
|
def writeToFile(self, f, expected, progress=None):
|
||||||
progress = progress or (lambda n: None)
|
fc = FileConsumer(f, progress)
|
||||||
fc = FileConsumer(f, expected, progress)
|
|
||||||
return self.connectConsumer(fc, expected)
|
return self.connectConsumer(fc, expected)
|
||||||
|
|
||||||
class OutboundConnectionFactory(protocol.ClientFactory):
|
class OutboundConnectionFactory(protocol.ClientFactory):
|
||||||
|
@ -805,19 +803,16 @@ class TransitReceiver(Common):
|
||||||
is_sender = False
|
is_sender = False
|
||||||
|
|
||||||
|
|
||||||
# based on twisted.protocols.ftp.FileConsumer, but:
|
# based on twisted.protocols.ftp.FileConsumer, but don't close the filehandle
|
||||||
# - call a progress-tracking function
|
# when done, and add a progress function that gets called with the length of
|
||||||
# - don't close the filehandle when done
|
# each write.
|
||||||
|
|
||||||
@implementer(interfaces.IConsumer)
|
@implementer(interfaces.IConsumer)
|
||||||
class FileConsumer:
|
class FileConsumer:
|
||||||
def __init__(self, f, xfersize, progress_f):
|
def __init__(self, f, progress=None):
|
||||||
self._f = f
|
self._f = f
|
||||||
self._xfersize = xfersize
|
self._progress = progress
|
||||||
self._received = 0
|
|
||||||
self._progress_f = progress_f
|
|
||||||
self._producer = None
|
self._producer = None
|
||||||
self._progress_f(0)
|
|
||||||
|
|
||||||
def registerProducer(self, producer, streaming):
|
def registerProducer(self, producer, streaming):
|
||||||
assert not self._producer
|
assert not self._producer
|
||||||
|
@ -826,8 +821,8 @@ class FileConsumer:
|
||||||
|
|
||||||
def write(self, bytes):
|
def write(self, bytes):
|
||||||
self._f.write(bytes)
|
self._f.write(bytes)
|
||||||
self._received += len(bytes)
|
if self._progress:
|
||||||
self._progress_f(self._received)
|
self._progress(len(bytes))
|
||||||
|
|
||||||
def unregisterProducer(self):
|
def unregisterProducer(self):
|
||||||
assert self._producer
|
assert self._producer
|
||||||
|
@ -851,8 +846,6 @@ class FileConsumer:
|
||||||
|
|
||||||
# check start/finish time-gathering instrumentation
|
# check start/finish time-gathering instrumentation
|
||||||
|
|
||||||
# add progress API
|
|
||||||
|
|
||||||
# relay URLs are probably mishandled: both sides probably send their URL,
|
# relay URLs are probably mishandled: both sides probably send their URL,
|
||||||
# then connect to the *other* side's URL, when they really should connect to
|
# then connect to the *other* side's URL, when they really should connect to
|
||||||
# both their own and the other side's. The current implementation probably
|
# both their own and the other side's. The current implementation probably
|
||||||
|
|
Loading…
Reference in New Issue
Block a user