From 5e928ac9f07da675701b1187a4f656b9487e2ec8 Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Wed, 17 Feb 2016 11:27:29 -0800 Subject: [PATCH] rewrite ProgressPrinter as a class, add tests --- src/wormhole/scripts/cmd_receive.py | 17 ++--- src/wormhole/scripts/cmd_send_blocking.py | 11 ++-- src/wormhole/scripts/progress.py | 77 +++++++++++++---------- src/wormhole/test/test_progress.py | 69 ++++++++++++++++++++ 4 files changed, 127 insertions(+), 47 deletions(-) create mode 100644 src/wormhole/test/test_progress.py diff --git a/src/wormhole/scripts/cmd_receive.py b/src/wormhole/scripts/cmd_receive.py index 47bf10a..a54f788 100644 --- a/src/wormhole/scripts/cmd_receive.py +++ b/src/wormhole/scripts/cmd_receive.py @@ -1,12 +1,12 @@ from __future__ import print_function import os, sys, json, binascii, six, tempfile, zipfile from ..errors import handle_server_error +from .progress import ProgressPrinter APPID = u"lothar.com/wormhole/text-or-file-xfer" def accept_file(args, them_d, w): from ..blocking.transit import TransitReceiver, TransitError - from .progress import start_progress, update_progress, finish_progress file_data = them_d["file"] if args.output_file: @@ -58,7 +58,8 @@ def accept_file(args, them_d, w): tmp = filename + ".tmp" with open(tmp, "wb") as f: received = 0 - next_update = start_progress(filesize) + p = ProgressPrinter(filesize, sys.stdout) + p.start() while received < filesize: try: plaintext = record_pipe.receive_record() @@ -69,8 +70,8 @@ def accept_file(args, them_d, w): return 1 f.write(plaintext) received += len(plaintext) - next_update = update_progress(next_update, received, filesize) - finish_progress(filesize) + p.update(received) + p.finish() assert received == filesize os.rename(tmp, filename) @@ -82,7 +83,6 @@ def accept_file(args, them_d, w): def accept_directory(args, them_d, w): from ..blocking.transit import TransitReceiver, TransitError - from .progress import start_progress, update_progress, finish_progress file_data = them_d["directory"] mode = file_data["mode"] @@ -143,7 +143,8 @@ def accept_directory(args, them_d, w): transit_receiver.describe())) f = tempfile.SpooledTemporaryFile() received = 0 - next_update = start_progress(filesize) + p = ProgressPrinter(filesize, sys.stdout) + p.start() while received < filesize: try: plaintext = record_pipe.receive_record() @@ -154,8 +155,8 @@ def accept_directory(args, them_d, w): return 1 f.write(plaintext) received += len(plaintext) - next_update = update_progress(next_update, received, filesize) - finish_progress(filesize) + p.update(received) + p.finish() assert received == filesize print("Unpacking zipfile..") with zipfile.ZipFile(f, "r", zipfile.ZIP_DEFLATED) as zf: diff --git a/src/wormhole/scripts/cmd_send_blocking.py b/src/wormhole/scripts/cmd_send_blocking.py index 580eb4d..e59209e 100644 --- a/src/wormhole/scripts/cmd_send_blocking.py +++ b/src/wormhole/scripts/cmd_send_blocking.py @@ -1,7 +1,7 @@ from __future__ import print_function -import json, binascii, six +import sys, json, binascii, six from ..errors import TransferError -from .progress import start_progress, update_progress, finish_progress +from .progress import ProgressPrinter def send_blocking(appid, args, phase1, fd_to_send): from ..blocking.transcribe import Wormhole, WrongPasswordError @@ -83,15 +83,16 @@ def _send_file_blocking(w, appid, them_phase1, fd_to_send, transit_sender): fd_to_send.seek(0,2) filesize = fd_to_send.tell() fd_to_send.seek(0,0) + p = ProgressPrinter(filesize, sys.stdout) with fd_to_send as f: sent = 0 - next_update = start_progress(filesize) + p.start() while sent < filesize: plaintext = f.read(CHUNKSIZE) record_pipe.send_record(plaintext) sent += len(plaintext) - next_update = update_progress(next_update, sent, filesize) - finish_progress(filesize) + p.update(sent) + p.finish() print("File sent.. waiting for confirmation") ack = record_pipe.receive_record() diff --git a/src/wormhole/scripts/progress.py b/src/wormhole/scripts/progress.py index 2eb8ce3..dec42df 100644 --- a/src/wormhole/scripts/progress.py +++ b/src/wormhole/scripts/progress.py @@ -1,40 +1,49 @@ from __future__ import print_function -import sys, time +import time -def print_progress(completed, expected): - # scp does "<>(13% 168MB 39.3MB/s 00:27 ETA)" - # we do "Progress: #### 13% 168MB" - fmt = "Progress: %-40s %3d%% %4d%s" - short_unit_size, short_unit_name = 1, "B" - if expected > 9999: - short_unit_size, short_unit_name = 1000, "KB" - if expected > 9999*1000: - short_unit_size, short_unit_name = 1000*1000, "MB" - if expected > 9999*1000*1000: - short_unit_size, short_unit_name = 1000*1000*1000, "GB" +class ProgressPrinter: + def __init__(self, expected, stdout, update_every=0.2): + self._expected = expected + self._stdout = stdout + self._update_every = update_every - percentage_complete = (1.0 * completed / expected) if expected else 1.0 - bars = "#" * int(percentage_complete * 40) - perc = int(100 * percentage_complete) - short_unit_count = int(completed / short_unit_size) - out = fmt % (bars, perc, short_unit_count, short_unit_name) - print("\r"+" "*70, end="") - print("\r"+out, end="") - sys.stdout.flush() + def _now(self): + return time.time() -def start_progress(expected, UPDATE_EVERY=0.2): - print_progress(0, expected) - next_update = time.time() + UPDATE_EVERY - return next_update + def start(self): + self._print(0) + self._next_update = self._now() + self._update_every -def update_progress(next_update, completed, expected, UPDATE_EVERY=0.2): - now = time.time() - if now < next_update: - return next_update - next_update = now + UPDATE_EVERY - print_progress(completed, expected) - return next_update + def update(self, completed): + now = self._now() + if now < self._next_update: + return + self._next_update = now + self._update_every + self._print(completed) -def finish_progress(expected): - print_progress(expected, expected) - print() + def finish(self): + self._print(self._expected) + print(u"", file=self._stdout) + + def _print(self, completed): + # scp does "<>(13% 168MB 39.3MB/s 00:27 ETA)" + # we do "Progress: #### 13% 168MB" + fmt = "Progress: %-40s %3d%% %4d%s" + short_unit_size, short_unit_name = 1, "B" + if self._expected > 9999: + short_unit_size, short_unit_name = 1000, "KB" + if self._expected > 9999*1000: + short_unit_size, short_unit_name = 1000*1000, "MB" + if self._expected > 9999*1000*1000: + short_unit_size, short_unit_name = 1000*1000*1000, "GB" + + percentage_complete = ((1.0 * completed / self._expected) + if self._expected + else 1.0) + bars = "#" * int(percentage_complete * 40) + perc = int(100 * percentage_complete) + short_unit_count = int(completed / short_unit_size) + out = fmt % (bars, perc, short_unit_count, short_unit_name) + print(u"\r"+" "*70, end=u"", file=self._stdout) + print(u"\r"+out, end=u"", file=self._stdout) + self._stdout.flush() diff --git a/src/wormhole/test/test_progress.py b/src/wormhole/test/test_progress.py new file mode 100644 index 0000000..879f93f --- /dev/null +++ b/src/wormhole/test/test_progress.py @@ -0,0 +1,69 @@ +from __future__ import print_function +import io, time +from twisted.trial import unittest +from ..scripts import progress + +class Progress(unittest.TestCase): + def test_time(self): + p = progress.ProgressPrinter(1e6, None) + start = time.time() + now = p._now() + finish = time.time() + self.assertTrue(start <= now <= finish, (start, now, finish)) + + def test_basic(self): + stdout = io.StringIO() + p = progress.ProgressPrinter(1e6, stdout) + p._now = lambda: 0.0 + p.start() + + erase = u"\r"+u" "*70 + expected = erase + fmt = "Progress: %-40s %3d%% %4d%s" + expected += u"\r" + fmt % ("", 0, 0, "KB") + self.assertEqual(stdout.getvalue(), expected) + + p.update(1e3) # no change, too soon + self.assertEqual(stdout.getvalue(), expected) + + p._now = lambda: 1.0 + p.update(1e3) # enough "time" has passed + expected += erase + u"\r" + fmt % ("", 0, 1, "KB") + self.assertEqual(stdout.getvalue(), expected) + + p._now = lambda: 2.0 + p.update(500e3) + expected += erase + u"\r" + fmt % ("#"*20, 50, 500, "KB") + self.assertEqual(stdout.getvalue(), expected) + + p._now = lambda: 3.0 + p.finish() + expected += erase + u"\r" + fmt % ("#"*40, 100, 1000, "KB") + expected += u"\n" + self.assertEqual(stdout.getvalue(), expected) + + def test_units(self): + def _try(size): + stdout = io.StringIO() + p = progress.ProgressPrinter(size, stdout) + p.finish() + return stdout.getvalue() + + fmt = "Progress: %-40s %3d%% %4d%s" + def _expect(count, units): + erase = u"\r"+u" "*70 + expected = erase + expected += u"\r" + fmt % ("#"*40, 100, count, units) + expected += u"\n" + return expected + + self.assertEqual(_try(900), _expect(900, "B")) + self.assertEqual(_try(9e3), _expect(9000, "B")) + self.assertEqual(_try(90e3), _expect(90, "KB")) + self.assertEqual(_try(900e3), _expect(900, "KB")) + self.assertEqual(_try(9e6), _expect(9000, "KB")) + self.assertEqual(_try(90e6), _expect(90, "MB")) + self.assertEqual(_try(900e6), _expect(900, "MB")) + self.assertEqual(_try(9e9), _expect(9000, "MB")) + self.assertEqual(_try(90e9), _expect(90, "GB")) + self.assertEqual(_try(900e9), _expect(900, "GB"))