rewrite ProgressPrinter as a class, add tests
This commit is contained in:
parent
00833a4bde
commit
5e928ac9f0
|
@ -1,12 +1,12 @@
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
import os, sys, json, binascii, six, tempfile, zipfile
|
import os, sys, json, binascii, six, tempfile, zipfile
|
||||||
from ..errors import handle_server_error
|
from ..errors import handle_server_error
|
||||||
|
from .progress import ProgressPrinter
|
||||||
|
|
||||||
APPID = u"lothar.com/wormhole/text-or-file-xfer"
|
APPID = u"lothar.com/wormhole/text-or-file-xfer"
|
||||||
|
|
||||||
def accept_file(args, them_d, w):
|
def accept_file(args, them_d, w):
|
||||||
from ..blocking.transit import TransitReceiver, TransitError
|
from ..blocking.transit import TransitReceiver, TransitError
|
||||||
from .progress import start_progress, update_progress, finish_progress
|
|
||||||
|
|
||||||
file_data = them_d["file"]
|
file_data = them_d["file"]
|
||||||
if args.output_file:
|
if args.output_file:
|
||||||
|
@ -58,7 +58,8 @@ def accept_file(args, them_d, w):
|
||||||
tmp = filename + ".tmp"
|
tmp = filename + ".tmp"
|
||||||
with open(tmp, "wb") as f:
|
with open(tmp, "wb") as f:
|
||||||
received = 0
|
received = 0
|
||||||
next_update = start_progress(filesize)
|
p = ProgressPrinter(filesize, sys.stdout)
|
||||||
|
p.start()
|
||||||
while received < filesize:
|
while received < filesize:
|
||||||
try:
|
try:
|
||||||
plaintext = record_pipe.receive_record()
|
plaintext = record_pipe.receive_record()
|
||||||
|
@ -69,8 +70,8 @@ def accept_file(args, them_d, w):
|
||||||
return 1
|
return 1
|
||||||
f.write(plaintext)
|
f.write(plaintext)
|
||||||
received += len(plaintext)
|
received += len(plaintext)
|
||||||
next_update = update_progress(next_update, received, filesize)
|
p.update(received)
|
||||||
finish_progress(filesize)
|
p.finish()
|
||||||
assert received == filesize
|
assert received == filesize
|
||||||
|
|
||||||
os.rename(tmp, filename)
|
os.rename(tmp, filename)
|
||||||
|
@ -82,7 +83,6 @@ def accept_file(args, them_d, w):
|
||||||
|
|
||||||
def accept_directory(args, them_d, w):
|
def accept_directory(args, them_d, w):
|
||||||
from ..blocking.transit import TransitReceiver, TransitError
|
from ..blocking.transit import TransitReceiver, TransitError
|
||||||
from .progress import start_progress, update_progress, finish_progress
|
|
||||||
|
|
||||||
file_data = them_d["directory"]
|
file_data = them_d["directory"]
|
||||||
mode = file_data["mode"]
|
mode = file_data["mode"]
|
||||||
|
@ -143,7 +143,8 @@ def accept_directory(args, them_d, w):
|
||||||
transit_receiver.describe()))
|
transit_receiver.describe()))
|
||||||
f = tempfile.SpooledTemporaryFile()
|
f = tempfile.SpooledTemporaryFile()
|
||||||
received = 0
|
received = 0
|
||||||
next_update = start_progress(filesize)
|
p = ProgressPrinter(filesize, sys.stdout)
|
||||||
|
p.start()
|
||||||
while received < filesize:
|
while received < filesize:
|
||||||
try:
|
try:
|
||||||
plaintext = record_pipe.receive_record()
|
plaintext = record_pipe.receive_record()
|
||||||
|
@ -154,8 +155,8 @@ def accept_directory(args, them_d, w):
|
||||||
return 1
|
return 1
|
||||||
f.write(plaintext)
|
f.write(plaintext)
|
||||||
received += len(plaintext)
|
received += len(plaintext)
|
||||||
next_update = update_progress(next_update, received, filesize)
|
p.update(received)
|
||||||
finish_progress(filesize)
|
p.finish()
|
||||||
assert received == filesize
|
assert received == filesize
|
||||||
print("Unpacking zipfile..")
|
print("Unpacking zipfile..")
|
||||||
with zipfile.ZipFile(f, "r", zipfile.ZIP_DEFLATED) as zf:
|
with zipfile.ZipFile(f, "r", zipfile.ZIP_DEFLATED) as zf:
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
import json, binascii, six
|
import sys, json, binascii, six
|
||||||
from ..errors import TransferError
|
from ..errors import TransferError
|
||||||
from .progress import start_progress, update_progress, finish_progress
|
from .progress import ProgressPrinter
|
||||||
|
|
||||||
def send_blocking(appid, args, phase1, fd_to_send):
|
def send_blocking(appid, args, phase1, fd_to_send):
|
||||||
from ..blocking.transcribe import Wormhole, WrongPasswordError
|
from ..blocking.transcribe import Wormhole, WrongPasswordError
|
||||||
|
@ -83,15 +83,16 @@ def _send_file_blocking(w, appid, them_phase1, fd_to_send, transit_sender):
|
||||||
fd_to_send.seek(0,2)
|
fd_to_send.seek(0,2)
|
||||||
filesize = fd_to_send.tell()
|
filesize = fd_to_send.tell()
|
||||||
fd_to_send.seek(0,0)
|
fd_to_send.seek(0,0)
|
||||||
|
p = ProgressPrinter(filesize, sys.stdout)
|
||||||
with fd_to_send as f:
|
with fd_to_send as f:
|
||||||
sent = 0
|
sent = 0
|
||||||
next_update = start_progress(filesize)
|
p.start()
|
||||||
while sent < filesize:
|
while sent < filesize:
|
||||||
plaintext = f.read(CHUNKSIZE)
|
plaintext = f.read(CHUNKSIZE)
|
||||||
record_pipe.send_record(plaintext)
|
record_pipe.send_record(plaintext)
|
||||||
sent += len(plaintext)
|
sent += len(plaintext)
|
||||||
next_update = update_progress(next_update, sent, filesize)
|
p.update(sent)
|
||||||
finish_progress(filesize)
|
p.finish()
|
||||||
|
|
||||||
print("File sent.. waiting for confirmation")
|
print("File sent.. waiting for confirmation")
|
||||||
ack = record_pipe.receive_record()
|
ack = record_pipe.receive_record()
|
||||||
|
|
|
@ -1,40 +1,49 @@
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
import sys, time
|
import time
|
||||||
|
|
||||||
def print_progress(completed, expected):
|
class ProgressPrinter:
|
||||||
# scp does "<<FILENAME >>(13% 168MB 39.3MB/s 00:27 ETA)"
|
def __init__(self, expected, stdout, update_every=0.2):
|
||||||
# we do "Progress: #### 13% 168MB"
|
self._expected = expected
|
||||||
fmt = "Progress: %-40s %3d%% %4d%s"
|
self._stdout = stdout
|
||||||
short_unit_size, short_unit_name = 1, "B"
|
self._update_every = update_every
|
||||||
if expected > 9999:
|
|
||||||
short_unit_size, short_unit_name = 1000, "KB"
|
|
||||||
if expected > 9999*1000:
|
|
||||||
short_unit_size, short_unit_name = 1000*1000, "MB"
|
|
||||||
if expected > 9999*1000*1000:
|
|
||||||
short_unit_size, short_unit_name = 1000*1000*1000, "GB"
|
|
||||||
|
|
||||||
percentage_complete = (1.0 * completed / expected) if expected else 1.0
|
def _now(self):
|
||||||
bars = "#" * int(percentage_complete * 40)
|
return time.time()
|
||||||
perc = int(100 * percentage_complete)
|
|
||||||
short_unit_count = int(completed / short_unit_size)
|
|
||||||
out = fmt % (bars, perc, short_unit_count, short_unit_name)
|
|
||||||
print("\r"+" "*70, end="")
|
|
||||||
print("\r"+out, end="")
|
|
||||||
sys.stdout.flush()
|
|
||||||
|
|
||||||
def start_progress(expected, UPDATE_EVERY=0.2):
|
def start(self):
|
||||||
print_progress(0, expected)
|
self._print(0)
|
||||||
next_update = time.time() + UPDATE_EVERY
|
self._next_update = self._now() + self._update_every
|
||||||
return next_update
|
|
||||||
|
|
||||||
def update_progress(next_update, completed, expected, UPDATE_EVERY=0.2):
|
def update(self, completed):
|
||||||
now = time.time()
|
now = self._now()
|
||||||
if now < next_update:
|
if now < self._next_update:
|
||||||
return next_update
|
return
|
||||||
next_update = now + UPDATE_EVERY
|
self._next_update = now + self._update_every
|
||||||
print_progress(completed, expected)
|
self._print(completed)
|
||||||
return next_update
|
|
||||||
|
|
||||||
def finish_progress(expected):
|
def finish(self):
|
||||||
print_progress(expected, expected)
|
self._print(self._expected)
|
||||||
print()
|
print(u"", file=self._stdout)
|
||||||
|
|
||||||
|
def _print(self, completed):
|
||||||
|
# scp does "<<FILENAME >>(13% 168MB 39.3MB/s 00:27 ETA)"
|
||||||
|
# we do "Progress: #### 13% 168MB"
|
||||||
|
fmt = "Progress: %-40s %3d%% %4d%s"
|
||||||
|
short_unit_size, short_unit_name = 1, "B"
|
||||||
|
if self._expected > 9999:
|
||||||
|
short_unit_size, short_unit_name = 1000, "KB"
|
||||||
|
if self._expected > 9999*1000:
|
||||||
|
short_unit_size, short_unit_name = 1000*1000, "MB"
|
||||||
|
if self._expected > 9999*1000*1000:
|
||||||
|
short_unit_size, short_unit_name = 1000*1000*1000, "GB"
|
||||||
|
|
||||||
|
percentage_complete = ((1.0 * completed / self._expected)
|
||||||
|
if self._expected
|
||||||
|
else 1.0)
|
||||||
|
bars = "#" * int(percentage_complete * 40)
|
||||||
|
perc = int(100 * percentage_complete)
|
||||||
|
short_unit_count = int(completed / short_unit_size)
|
||||||
|
out = fmt % (bars, perc, short_unit_count, short_unit_name)
|
||||||
|
print(u"\r"+" "*70, end=u"", file=self._stdout)
|
||||||
|
print(u"\r"+out, end=u"", file=self._stdout)
|
||||||
|
self._stdout.flush()
|
||||||
|
|
69
src/wormhole/test/test_progress.py
Normal file
69
src/wormhole/test/test_progress.py
Normal file
|
@ -0,0 +1,69 @@
|
||||||
|
from __future__ import print_function
|
||||||
|
import io, time
|
||||||
|
from twisted.trial import unittest
|
||||||
|
from ..scripts import progress
|
||||||
|
|
||||||
|
class Progress(unittest.TestCase):
|
||||||
|
def test_time(self):
|
||||||
|
p = progress.ProgressPrinter(1e6, None)
|
||||||
|
start = time.time()
|
||||||
|
now = p._now()
|
||||||
|
finish = time.time()
|
||||||
|
self.assertTrue(start <= now <= finish, (start, now, finish))
|
||||||
|
|
||||||
|
def test_basic(self):
|
||||||
|
stdout = io.StringIO()
|
||||||
|
p = progress.ProgressPrinter(1e6, stdout)
|
||||||
|
p._now = lambda: 0.0
|
||||||
|
p.start()
|
||||||
|
|
||||||
|
erase = u"\r"+u" "*70
|
||||||
|
expected = erase
|
||||||
|
fmt = "Progress: %-40s %3d%% %4d%s"
|
||||||
|
expected += u"\r" + fmt % ("", 0, 0, "KB")
|
||||||
|
self.assertEqual(stdout.getvalue(), expected)
|
||||||
|
|
||||||
|
p.update(1e3) # no change, too soon
|
||||||
|
self.assertEqual(stdout.getvalue(), expected)
|
||||||
|
|
||||||
|
p._now = lambda: 1.0
|
||||||
|
p.update(1e3) # enough "time" has passed
|
||||||
|
expected += erase + u"\r" + fmt % ("", 0, 1, "KB")
|
||||||
|
self.assertEqual(stdout.getvalue(), expected)
|
||||||
|
|
||||||
|
p._now = lambda: 2.0
|
||||||
|
p.update(500e3)
|
||||||
|
expected += erase + u"\r" + fmt % ("#"*20, 50, 500, "KB")
|
||||||
|
self.assertEqual(stdout.getvalue(), expected)
|
||||||
|
|
||||||
|
p._now = lambda: 3.0
|
||||||
|
p.finish()
|
||||||
|
expected += erase + u"\r" + fmt % ("#"*40, 100, 1000, "KB")
|
||||||
|
expected += u"\n"
|
||||||
|
self.assertEqual(stdout.getvalue(), expected)
|
||||||
|
|
||||||
|
def test_units(self):
|
||||||
|
def _try(size):
|
||||||
|
stdout = io.StringIO()
|
||||||
|
p = progress.ProgressPrinter(size, stdout)
|
||||||
|
p.finish()
|
||||||
|
return stdout.getvalue()
|
||||||
|
|
||||||
|
fmt = "Progress: %-40s %3d%% %4d%s"
|
||||||
|
def _expect(count, units):
|
||||||
|
erase = u"\r"+u" "*70
|
||||||
|
expected = erase
|
||||||
|
expected += u"\r" + fmt % ("#"*40, 100, count, units)
|
||||||
|
expected += u"\n"
|
||||||
|
return expected
|
||||||
|
|
||||||
|
self.assertEqual(_try(900), _expect(900, "B"))
|
||||||
|
self.assertEqual(_try(9e3), _expect(9000, "B"))
|
||||||
|
self.assertEqual(_try(90e3), _expect(90, "KB"))
|
||||||
|
self.assertEqual(_try(900e3), _expect(900, "KB"))
|
||||||
|
self.assertEqual(_try(9e6), _expect(9000, "KB"))
|
||||||
|
self.assertEqual(_try(90e6), _expect(90, "MB"))
|
||||||
|
self.assertEqual(_try(900e6), _expect(900, "MB"))
|
||||||
|
self.assertEqual(_try(9e9), _expect(9000, "MB"))
|
||||||
|
self.assertEqual(_try(90e9), _expect(90, "GB"))
|
||||||
|
self.assertEqual(_try(900e9), _expect(900, "GB"))
|
Loading…
Reference in New Issue
Block a user