rewrite ProgressPrinter as a class, add tests

This commit is contained in:
Brian Warner 2016-02-17 11:27:29 -08:00
parent 00833a4bde
commit 5e928ac9f0
4 changed files with 127 additions and 47 deletions

View File

@ -1,12 +1,12 @@
from __future__ import print_function from __future__ import print_function
import os, sys, json, binascii, six, tempfile, zipfile import os, sys, json, binascii, six, tempfile, zipfile
from ..errors import handle_server_error from ..errors import handle_server_error
from .progress import ProgressPrinter
APPID = u"lothar.com/wormhole/text-or-file-xfer" APPID = u"lothar.com/wormhole/text-or-file-xfer"
def accept_file(args, them_d, w): def accept_file(args, them_d, w):
from ..blocking.transit import TransitReceiver, TransitError from ..blocking.transit import TransitReceiver, TransitError
from .progress import start_progress, update_progress, finish_progress
file_data = them_d["file"] file_data = them_d["file"]
if args.output_file: if args.output_file:
@ -58,7 +58,8 @@ def accept_file(args, them_d, w):
tmp = filename + ".tmp" tmp = filename + ".tmp"
with open(tmp, "wb") as f: with open(tmp, "wb") as f:
received = 0 received = 0
next_update = start_progress(filesize) p = ProgressPrinter(filesize, sys.stdout)
p.start()
while received < filesize: while received < filesize:
try: try:
plaintext = record_pipe.receive_record() plaintext = record_pipe.receive_record()
@ -69,8 +70,8 @@ def accept_file(args, them_d, w):
return 1 return 1
f.write(plaintext) f.write(plaintext)
received += len(plaintext) received += len(plaintext)
next_update = update_progress(next_update, received, filesize) p.update(received)
finish_progress(filesize) p.finish()
assert received == filesize assert received == filesize
os.rename(tmp, filename) os.rename(tmp, filename)
@ -82,7 +83,6 @@ def accept_file(args, them_d, w):
def accept_directory(args, them_d, w): def accept_directory(args, them_d, w):
from ..blocking.transit import TransitReceiver, TransitError from ..blocking.transit import TransitReceiver, TransitError
from .progress import start_progress, update_progress, finish_progress
file_data = them_d["directory"] file_data = them_d["directory"]
mode = file_data["mode"] mode = file_data["mode"]
@ -143,7 +143,8 @@ def accept_directory(args, them_d, w):
transit_receiver.describe())) transit_receiver.describe()))
f = tempfile.SpooledTemporaryFile() f = tempfile.SpooledTemporaryFile()
received = 0 received = 0
next_update = start_progress(filesize) p = ProgressPrinter(filesize, sys.stdout)
p.start()
while received < filesize: while received < filesize:
try: try:
plaintext = record_pipe.receive_record() plaintext = record_pipe.receive_record()
@ -154,8 +155,8 @@ def accept_directory(args, them_d, w):
return 1 return 1
f.write(plaintext) f.write(plaintext)
received += len(plaintext) received += len(plaintext)
next_update = update_progress(next_update, received, filesize) p.update(received)
finish_progress(filesize) p.finish()
assert received == filesize assert received == filesize
print("Unpacking zipfile..") print("Unpacking zipfile..")
with zipfile.ZipFile(f, "r", zipfile.ZIP_DEFLATED) as zf: with zipfile.ZipFile(f, "r", zipfile.ZIP_DEFLATED) as zf:

View File

@ -1,7 +1,7 @@
from __future__ import print_function from __future__ import print_function
import json, binascii, six import sys, json, binascii, six
from ..errors import TransferError from ..errors import TransferError
from .progress import start_progress, update_progress, finish_progress from .progress import ProgressPrinter
def send_blocking(appid, args, phase1, fd_to_send): def send_blocking(appid, args, phase1, fd_to_send):
from ..blocking.transcribe import Wormhole, WrongPasswordError from ..blocking.transcribe import Wormhole, WrongPasswordError
@ -83,15 +83,16 @@ def _send_file_blocking(w, appid, them_phase1, fd_to_send, transit_sender):
fd_to_send.seek(0,2) fd_to_send.seek(0,2)
filesize = fd_to_send.tell() filesize = fd_to_send.tell()
fd_to_send.seek(0,0) fd_to_send.seek(0,0)
p = ProgressPrinter(filesize, sys.stdout)
with fd_to_send as f: with fd_to_send as f:
sent = 0 sent = 0
next_update = start_progress(filesize) p.start()
while sent < filesize: while sent < filesize:
plaintext = f.read(CHUNKSIZE) plaintext = f.read(CHUNKSIZE)
record_pipe.send_record(plaintext) record_pipe.send_record(plaintext)
sent += len(plaintext) sent += len(plaintext)
next_update = update_progress(next_update, sent, filesize) p.update(sent)
finish_progress(filesize) p.finish()
print("File sent.. waiting for confirmation") print("File sent.. waiting for confirmation")
ack = record_pipe.receive_record() ack = record_pipe.receive_record()

View File

@ -1,40 +1,49 @@
from __future__ import print_function from __future__ import print_function
import sys, time import time
def print_progress(completed, expected): class ProgressPrinter:
def __init__(self, expected, stdout, update_every=0.2):
self._expected = expected
self._stdout = stdout
self._update_every = update_every
def _now(self):
return time.time()
def start(self):
self._print(0)
self._next_update = self._now() + self._update_every
def update(self, completed):
now = self._now()
if now < self._next_update:
return
self._next_update = now + self._update_every
self._print(completed)
def finish(self):
self._print(self._expected)
print(u"", file=self._stdout)
def _print(self, completed):
# scp does "<<FILENAME >>(13% 168MB 39.3MB/s 00:27 ETA)" # scp does "<<FILENAME >>(13% 168MB 39.3MB/s 00:27 ETA)"
# we do "Progress: #### 13% 168MB" # we do "Progress: #### 13% 168MB"
fmt = "Progress: %-40s %3d%% %4d%s" fmt = "Progress: %-40s %3d%% %4d%s"
short_unit_size, short_unit_name = 1, "B" short_unit_size, short_unit_name = 1, "B"
if expected > 9999: if self._expected > 9999:
short_unit_size, short_unit_name = 1000, "KB" short_unit_size, short_unit_name = 1000, "KB"
if expected > 9999*1000: if self._expected > 9999*1000:
short_unit_size, short_unit_name = 1000*1000, "MB" short_unit_size, short_unit_name = 1000*1000, "MB"
if expected > 9999*1000*1000: if self._expected > 9999*1000*1000:
short_unit_size, short_unit_name = 1000*1000*1000, "GB" short_unit_size, short_unit_name = 1000*1000*1000, "GB"
percentage_complete = (1.0 * completed / expected) if expected else 1.0 percentage_complete = ((1.0 * completed / self._expected)
if self._expected
else 1.0)
bars = "#" * int(percentage_complete * 40) bars = "#" * int(percentage_complete * 40)
perc = int(100 * percentage_complete) perc = int(100 * percentage_complete)
short_unit_count = int(completed / short_unit_size) short_unit_count = int(completed / short_unit_size)
out = fmt % (bars, perc, short_unit_count, short_unit_name) out = fmt % (bars, perc, short_unit_count, short_unit_name)
print("\r"+" "*70, end="") print(u"\r"+" "*70, end=u"", file=self._stdout)
print("\r"+out, end="") print(u"\r"+out, end=u"", file=self._stdout)
sys.stdout.flush() self._stdout.flush()
def start_progress(expected, UPDATE_EVERY=0.2):
print_progress(0, expected)
next_update = time.time() + UPDATE_EVERY
return next_update
def update_progress(next_update, completed, expected, UPDATE_EVERY=0.2):
now = time.time()
if now < next_update:
return next_update
next_update = now + UPDATE_EVERY
print_progress(completed, expected)
return next_update
def finish_progress(expected):
print_progress(expected, expected)
print()

View File

@ -0,0 +1,69 @@
from __future__ import print_function
import io, time
from twisted.trial import unittest
from ..scripts import progress
class Progress(unittest.TestCase):
def test_time(self):
p = progress.ProgressPrinter(1e6, None)
start = time.time()
now = p._now()
finish = time.time()
self.assertTrue(start <= now <= finish, (start, now, finish))
def test_basic(self):
stdout = io.StringIO()
p = progress.ProgressPrinter(1e6, stdout)
p._now = lambda: 0.0
p.start()
erase = u"\r"+u" "*70
expected = erase
fmt = "Progress: %-40s %3d%% %4d%s"
expected += u"\r" + fmt % ("", 0, 0, "KB")
self.assertEqual(stdout.getvalue(), expected)
p.update(1e3) # no change, too soon
self.assertEqual(stdout.getvalue(), expected)
p._now = lambda: 1.0
p.update(1e3) # enough "time" has passed
expected += erase + u"\r" + fmt % ("", 0, 1, "KB")
self.assertEqual(stdout.getvalue(), expected)
p._now = lambda: 2.0
p.update(500e3)
expected += erase + u"\r" + fmt % ("#"*20, 50, 500, "KB")
self.assertEqual(stdout.getvalue(), expected)
p._now = lambda: 3.0
p.finish()
expected += erase + u"\r" + fmt % ("#"*40, 100, 1000, "KB")
expected += u"\n"
self.assertEqual(stdout.getvalue(), expected)
def test_units(self):
def _try(size):
stdout = io.StringIO()
p = progress.ProgressPrinter(size, stdout)
p.finish()
return stdout.getvalue()
fmt = "Progress: %-40s %3d%% %4d%s"
def _expect(count, units):
erase = u"\r"+u" "*70
expected = erase
expected += u"\r" + fmt % ("#"*40, 100, count, units)
expected += u"\n"
return expected
self.assertEqual(_try(900), _expect(900, "B"))
self.assertEqual(_try(9e3), _expect(9000, "B"))
self.assertEqual(_try(90e3), _expect(90, "KB"))
self.assertEqual(_try(900e3), _expect(900, "KB"))
self.assertEqual(_try(9e6), _expect(9000, "KB"))
self.assertEqual(_try(90e6), _expect(90, "MB"))
self.assertEqual(_try(900e6), _expect(900, "MB"))
self.assertEqual(_try(9e9), _expect(9000, "MB"))
self.assertEqual(_try(90e9), _expect(90, "GB"))
self.assertEqual(_try(900e9), _expect(900, "GB"))