rewrite ProgressPrinter as a class, add tests
This commit is contained in:
parent
00833a4bde
commit
5e928ac9f0
|
@ -1,12 +1,12 @@
|
|||
from __future__ import print_function
|
||||
import os, sys, json, binascii, six, tempfile, zipfile
|
||||
from ..errors import handle_server_error
|
||||
from .progress import ProgressPrinter
|
||||
|
||||
APPID = u"lothar.com/wormhole/text-or-file-xfer"
|
||||
|
||||
def accept_file(args, them_d, w):
|
||||
from ..blocking.transit import TransitReceiver, TransitError
|
||||
from .progress import start_progress, update_progress, finish_progress
|
||||
|
||||
file_data = them_d["file"]
|
||||
if args.output_file:
|
||||
|
@ -58,7 +58,8 @@ def accept_file(args, them_d, w):
|
|||
tmp = filename + ".tmp"
|
||||
with open(tmp, "wb") as f:
|
||||
received = 0
|
||||
next_update = start_progress(filesize)
|
||||
p = ProgressPrinter(filesize, sys.stdout)
|
||||
p.start()
|
||||
while received < filesize:
|
||||
try:
|
||||
plaintext = record_pipe.receive_record()
|
||||
|
@ -69,8 +70,8 @@ def accept_file(args, them_d, w):
|
|||
return 1
|
||||
f.write(plaintext)
|
||||
received += len(plaintext)
|
||||
next_update = update_progress(next_update, received, filesize)
|
||||
finish_progress(filesize)
|
||||
p.update(received)
|
||||
p.finish()
|
||||
assert received == filesize
|
||||
|
||||
os.rename(tmp, filename)
|
||||
|
@ -82,7 +83,6 @@ def accept_file(args, them_d, w):
|
|||
|
||||
def accept_directory(args, them_d, w):
|
||||
from ..blocking.transit import TransitReceiver, TransitError
|
||||
from .progress import start_progress, update_progress, finish_progress
|
||||
|
||||
file_data = them_d["directory"]
|
||||
mode = file_data["mode"]
|
||||
|
@ -143,7 +143,8 @@ def accept_directory(args, them_d, w):
|
|||
transit_receiver.describe()))
|
||||
f = tempfile.SpooledTemporaryFile()
|
||||
received = 0
|
||||
next_update = start_progress(filesize)
|
||||
p = ProgressPrinter(filesize, sys.stdout)
|
||||
p.start()
|
||||
while received < filesize:
|
||||
try:
|
||||
plaintext = record_pipe.receive_record()
|
||||
|
@ -154,8 +155,8 @@ def accept_directory(args, them_d, w):
|
|||
return 1
|
||||
f.write(plaintext)
|
||||
received += len(plaintext)
|
||||
next_update = update_progress(next_update, received, filesize)
|
||||
finish_progress(filesize)
|
||||
p.update(received)
|
||||
p.finish()
|
||||
assert received == filesize
|
||||
print("Unpacking zipfile..")
|
||||
with zipfile.ZipFile(f, "r", zipfile.ZIP_DEFLATED) as zf:
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
from __future__ import print_function
|
||||
import json, binascii, six
|
||||
import sys, json, binascii, six
|
||||
from ..errors import TransferError
|
||||
from .progress import start_progress, update_progress, finish_progress
|
||||
from .progress import ProgressPrinter
|
||||
|
||||
def send_blocking(appid, args, phase1, fd_to_send):
|
||||
from ..blocking.transcribe import Wormhole, WrongPasswordError
|
||||
|
@ -83,15 +83,16 @@ def _send_file_blocking(w, appid, them_phase1, fd_to_send, transit_sender):
|
|||
fd_to_send.seek(0,2)
|
||||
filesize = fd_to_send.tell()
|
||||
fd_to_send.seek(0,0)
|
||||
p = ProgressPrinter(filesize, sys.stdout)
|
||||
with fd_to_send as f:
|
||||
sent = 0
|
||||
next_update = start_progress(filesize)
|
||||
p.start()
|
||||
while sent < filesize:
|
||||
plaintext = f.read(CHUNKSIZE)
|
||||
record_pipe.send_record(plaintext)
|
||||
sent += len(plaintext)
|
||||
next_update = update_progress(next_update, sent, filesize)
|
||||
finish_progress(filesize)
|
||||
p.update(sent)
|
||||
p.finish()
|
||||
|
||||
print("File sent.. waiting for confirmation")
|
||||
ack = record_pipe.receive_record()
|
||||
|
|
|
@ -1,40 +1,49 @@
|
|||
from __future__ import print_function
|
||||
import sys, time
|
||||
import time
|
||||
|
||||
def print_progress(completed, expected):
|
||||
class ProgressPrinter:
|
||||
def __init__(self, expected, stdout, update_every=0.2):
|
||||
self._expected = expected
|
||||
self._stdout = stdout
|
||||
self._update_every = update_every
|
||||
|
||||
def _now(self):
|
||||
return time.time()
|
||||
|
||||
def start(self):
|
||||
self._print(0)
|
||||
self._next_update = self._now() + self._update_every
|
||||
|
||||
def update(self, completed):
|
||||
now = self._now()
|
||||
if now < self._next_update:
|
||||
return
|
||||
self._next_update = now + self._update_every
|
||||
self._print(completed)
|
||||
|
||||
def finish(self):
|
||||
self._print(self._expected)
|
||||
print(u"", file=self._stdout)
|
||||
|
||||
def _print(self, completed):
|
||||
# scp does "<<FILENAME >>(13% 168MB 39.3MB/s 00:27 ETA)"
|
||||
# we do "Progress: #### 13% 168MB"
|
||||
fmt = "Progress: %-40s %3d%% %4d%s"
|
||||
short_unit_size, short_unit_name = 1, "B"
|
||||
if expected > 9999:
|
||||
if self._expected > 9999:
|
||||
short_unit_size, short_unit_name = 1000, "KB"
|
||||
if expected > 9999*1000:
|
||||
if self._expected > 9999*1000:
|
||||
short_unit_size, short_unit_name = 1000*1000, "MB"
|
||||
if expected > 9999*1000*1000:
|
||||
if self._expected > 9999*1000*1000:
|
||||
short_unit_size, short_unit_name = 1000*1000*1000, "GB"
|
||||
|
||||
percentage_complete = (1.0 * completed / expected) if expected else 1.0
|
||||
percentage_complete = ((1.0 * completed / self._expected)
|
||||
if self._expected
|
||||
else 1.0)
|
||||
bars = "#" * int(percentage_complete * 40)
|
||||
perc = int(100 * percentage_complete)
|
||||
short_unit_count = int(completed / short_unit_size)
|
||||
out = fmt % (bars, perc, short_unit_count, short_unit_name)
|
||||
print("\r"+" "*70, end="")
|
||||
print("\r"+out, end="")
|
||||
sys.stdout.flush()
|
||||
|
||||
def start_progress(expected, UPDATE_EVERY=0.2):
|
||||
print_progress(0, expected)
|
||||
next_update = time.time() + UPDATE_EVERY
|
||||
return next_update
|
||||
|
||||
def update_progress(next_update, completed, expected, UPDATE_EVERY=0.2):
|
||||
now = time.time()
|
||||
if now < next_update:
|
||||
return next_update
|
||||
next_update = now + UPDATE_EVERY
|
||||
print_progress(completed, expected)
|
||||
return next_update
|
||||
|
||||
def finish_progress(expected):
|
||||
print_progress(expected, expected)
|
||||
print()
|
||||
print(u"\r"+" "*70, end=u"", file=self._stdout)
|
||||
print(u"\r"+out, end=u"", file=self._stdout)
|
||||
self._stdout.flush()
|
||||
|
|
69
src/wormhole/test/test_progress.py
Normal file
69
src/wormhole/test/test_progress.py
Normal file
|
@ -0,0 +1,69 @@
|
|||
from __future__ import print_function
|
||||
import io, time
|
||||
from twisted.trial import unittest
|
||||
from ..scripts import progress
|
||||
|
||||
class Progress(unittest.TestCase):
|
||||
def test_time(self):
|
||||
p = progress.ProgressPrinter(1e6, None)
|
||||
start = time.time()
|
||||
now = p._now()
|
||||
finish = time.time()
|
||||
self.assertTrue(start <= now <= finish, (start, now, finish))
|
||||
|
||||
def test_basic(self):
|
||||
stdout = io.StringIO()
|
||||
p = progress.ProgressPrinter(1e6, stdout)
|
||||
p._now = lambda: 0.0
|
||||
p.start()
|
||||
|
||||
erase = u"\r"+u" "*70
|
||||
expected = erase
|
||||
fmt = "Progress: %-40s %3d%% %4d%s"
|
||||
expected += u"\r" + fmt % ("", 0, 0, "KB")
|
||||
self.assertEqual(stdout.getvalue(), expected)
|
||||
|
||||
p.update(1e3) # no change, too soon
|
||||
self.assertEqual(stdout.getvalue(), expected)
|
||||
|
||||
p._now = lambda: 1.0
|
||||
p.update(1e3) # enough "time" has passed
|
||||
expected += erase + u"\r" + fmt % ("", 0, 1, "KB")
|
||||
self.assertEqual(stdout.getvalue(), expected)
|
||||
|
||||
p._now = lambda: 2.0
|
||||
p.update(500e3)
|
||||
expected += erase + u"\r" + fmt % ("#"*20, 50, 500, "KB")
|
||||
self.assertEqual(stdout.getvalue(), expected)
|
||||
|
||||
p._now = lambda: 3.0
|
||||
p.finish()
|
||||
expected += erase + u"\r" + fmt % ("#"*40, 100, 1000, "KB")
|
||||
expected += u"\n"
|
||||
self.assertEqual(stdout.getvalue(), expected)
|
||||
|
||||
def test_units(self):
|
||||
def _try(size):
|
||||
stdout = io.StringIO()
|
||||
p = progress.ProgressPrinter(size, stdout)
|
||||
p.finish()
|
||||
return stdout.getvalue()
|
||||
|
||||
fmt = "Progress: %-40s %3d%% %4d%s"
|
||||
def _expect(count, units):
|
||||
erase = u"\r"+u" "*70
|
||||
expected = erase
|
||||
expected += u"\r" + fmt % ("#"*40, 100, count, units)
|
||||
expected += u"\n"
|
||||
return expected
|
||||
|
||||
self.assertEqual(_try(900), _expect(900, "B"))
|
||||
self.assertEqual(_try(9e3), _expect(9000, "B"))
|
||||
self.assertEqual(_try(90e3), _expect(90, "KB"))
|
||||
self.assertEqual(_try(900e3), _expect(900, "KB"))
|
||||
self.assertEqual(_try(9e6), _expect(9000, "KB"))
|
||||
self.assertEqual(_try(90e6), _expect(90, "MB"))
|
||||
self.assertEqual(_try(900e6), _expect(900, "MB"))
|
||||
self.assertEqual(_try(9e9), _expect(9000, "MB"))
|
||||
self.assertEqual(_try(90e9), _expect(90, "GB"))
|
||||
self.assertEqual(_try(900e9), _expect(900, "GB"))
|
Loading…
Reference in New Issue
Block a user