rewrite ProgressPrinter as a class, add tests
This commit is contained in:
		
							parent
							
								
									00833a4bde
								
							
						
					
					
						commit
						5e928ac9f0
					
				|  | @ -1,12 +1,12 @@ | |||
| from __future__ import print_function | ||||
| import os, sys, json, binascii, six, tempfile, zipfile | ||||
| from ..errors import handle_server_error | ||||
| from .progress import ProgressPrinter | ||||
| 
 | ||||
| APPID = u"lothar.com/wormhole/text-or-file-xfer" | ||||
| 
 | ||||
| def accept_file(args, them_d, w): | ||||
|     from ..blocking.transit import TransitReceiver, TransitError | ||||
|     from .progress import start_progress, update_progress, finish_progress | ||||
| 
 | ||||
|     file_data = them_d["file"] | ||||
|     if args.output_file: | ||||
|  | @ -58,7 +58,8 @@ def accept_file(args, them_d, w): | |||
|     tmp = filename + ".tmp" | ||||
|     with open(tmp, "wb") as f: | ||||
|         received = 0 | ||||
|         next_update = start_progress(filesize) | ||||
|         p = ProgressPrinter(filesize, sys.stdout) | ||||
|         p.start() | ||||
|         while received < filesize: | ||||
|             try: | ||||
|                 plaintext = record_pipe.receive_record() | ||||
|  | @ -69,8 +70,8 @@ def accept_file(args, them_d, w): | |||
|                 return 1 | ||||
|             f.write(plaintext) | ||||
|             received += len(plaintext) | ||||
|             next_update = update_progress(next_update, received, filesize) | ||||
|         finish_progress(filesize) | ||||
|             p.update(received) | ||||
|         p.finish() | ||||
|         assert received == filesize | ||||
| 
 | ||||
|     os.rename(tmp, filename) | ||||
|  | @ -82,7 +83,6 @@ def accept_file(args, them_d, w): | |||
| 
 | ||||
| def accept_directory(args, them_d, w): | ||||
|     from ..blocking.transit import TransitReceiver, TransitError | ||||
|     from .progress import start_progress, update_progress, finish_progress | ||||
| 
 | ||||
|     file_data = them_d["directory"] | ||||
|     mode = file_data["mode"] | ||||
|  | @ -143,7 +143,8 @@ def accept_directory(args, them_d, w): | |||
|                                                   transit_receiver.describe())) | ||||
|     f = tempfile.SpooledTemporaryFile() | ||||
|     received = 0 | ||||
|     next_update = start_progress(filesize) | ||||
|     p = ProgressPrinter(filesize, sys.stdout) | ||||
|     p.start() | ||||
|     while received < filesize: | ||||
|         try: | ||||
|             plaintext = record_pipe.receive_record() | ||||
|  | @ -154,8 +155,8 @@ def accept_directory(args, them_d, w): | |||
|             return 1 | ||||
|         f.write(plaintext) | ||||
|         received += len(plaintext) | ||||
|         next_update = update_progress(next_update, received, filesize) | ||||
|     finish_progress(filesize) | ||||
|         p.update(received) | ||||
|     p.finish() | ||||
|     assert received == filesize | ||||
|     print("Unpacking zipfile..") | ||||
|     with zipfile.ZipFile(f, "r", zipfile.ZIP_DEFLATED) as zf: | ||||
|  |  | |||
|  | @ -1,7 +1,7 @@ | |||
| from __future__ import print_function | ||||
| import json, binascii, six | ||||
| import sys, json, binascii, six | ||||
| from ..errors import TransferError | ||||
| from .progress import start_progress, update_progress, finish_progress | ||||
| from .progress import ProgressPrinter | ||||
| 
 | ||||
| def send_blocking(appid, args, phase1, fd_to_send): | ||||
|     from ..blocking.transcribe import Wormhole, WrongPasswordError | ||||
|  | @ -83,15 +83,16 @@ def _send_file_blocking(w, appid, them_phase1, fd_to_send, transit_sender): | |||
|     fd_to_send.seek(0,2) | ||||
|     filesize = fd_to_send.tell() | ||||
|     fd_to_send.seek(0,0) | ||||
|     p = ProgressPrinter(filesize, sys.stdout) | ||||
|     with fd_to_send as f: | ||||
|         sent = 0 | ||||
|         next_update = start_progress(filesize) | ||||
|         p.start() | ||||
|         while sent < filesize: | ||||
|             plaintext = f.read(CHUNKSIZE) | ||||
|             record_pipe.send_record(plaintext) | ||||
|             sent += len(plaintext) | ||||
|             next_update = update_progress(next_update, sent, filesize) | ||||
|         finish_progress(filesize) | ||||
|             p.update(sent) | ||||
|         p.finish() | ||||
| 
 | ||||
|     print("File sent.. waiting for confirmation") | ||||
|     ack = record_pipe.receive_record() | ||||
|  |  | |||
|  | @ -1,40 +1,49 @@ | |||
| from __future__ import print_function | ||||
| import sys, time | ||||
| import time | ||||
| 
 | ||||
| def print_progress(completed, expected): | ||||
| class ProgressPrinter: | ||||
|     def __init__(self, expected, stdout, update_every=0.2): | ||||
|         self._expected = expected | ||||
|         self._stdout = stdout | ||||
|         self._update_every = update_every | ||||
| 
 | ||||
|     def _now(self): | ||||
|         return time.time() | ||||
| 
 | ||||
|     def start(self): | ||||
|         self._print(0) | ||||
|         self._next_update = self._now() + self._update_every | ||||
| 
 | ||||
|     def update(self, completed): | ||||
|         now = self._now() | ||||
|         if now < self._next_update: | ||||
|             return | ||||
|         self._next_update = now + self._update_every | ||||
|         self._print(completed) | ||||
| 
 | ||||
|     def finish(self): | ||||
|         self._print(self._expected) | ||||
|         print(u"", file=self._stdout) | ||||
| 
 | ||||
|     def _print(self, completed): | ||||
|         # scp does "<<FILENAME >>(13%  168MB  39.3MB/s   00:27 ETA)" | ||||
|         # we do "Progress: ####         13%  168MB" | ||||
|         fmt = "Progress: %-40s %3d%%  %4d%s" | ||||
|         short_unit_size, short_unit_name = 1, "B" | ||||
|     if expected > 9999: | ||||
|         if self._expected > 9999: | ||||
|             short_unit_size, short_unit_name = 1000, "KB" | ||||
|     if expected > 9999*1000: | ||||
|         if self._expected > 9999*1000: | ||||
|             short_unit_size, short_unit_name = 1000*1000, "MB" | ||||
|     if expected > 9999*1000*1000: | ||||
|         if self._expected > 9999*1000*1000: | ||||
|             short_unit_size, short_unit_name = 1000*1000*1000, "GB" | ||||
| 
 | ||||
|     percentage_complete = (1.0 * completed / expected) if expected else 1.0 | ||||
|         percentage_complete = ((1.0 * completed / self._expected) | ||||
|                                if self._expected | ||||
|                                else 1.0) | ||||
|         bars = "#" * int(percentage_complete * 40) | ||||
|         perc = int(100 * percentage_complete) | ||||
|         short_unit_count = int(completed / short_unit_size) | ||||
|         out = fmt % (bars, perc, short_unit_count, short_unit_name) | ||||
|     print("\r"+" "*70, end="") | ||||
|     print("\r"+out, end="") | ||||
|     sys.stdout.flush() | ||||
| 
 | ||||
| def start_progress(expected, UPDATE_EVERY=0.2): | ||||
|     print_progress(0, expected) | ||||
|     next_update = time.time() + UPDATE_EVERY | ||||
|     return next_update | ||||
| 
 | ||||
| def update_progress(next_update, completed, expected, UPDATE_EVERY=0.2): | ||||
|     now = time.time() | ||||
|     if now < next_update: | ||||
|         return next_update | ||||
|     next_update = now + UPDATE_EVERY | ||||
|     print_progress(completed, expected) | ||||
|     return next_update | ||||
| 
 | ||||
| def finish_progress(expected): | ||||
|     print_progress(expected, expected) | ||||
|     print() | ||||
|         print(u"\r"+" "*70, end=u"", file=self._stdout) | ||||
|         print(u"\r"+out, end=u"", file=self._stdout) | ||||
|         self._stdout.flush() | ||||
|  |  | |||
							
								
								
									
										69
									
								
								src/wormhole/test/test_progress.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										69
									
								
								src/wormhole/test/test_progress.py
									
									
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,69 @@ | |||
| from __future__ import print_function | ||||
| import io, time | ||||
| from twisted.trial import unittest | ||||
| from ..scripts import progress | ||||
| 
 | ||||
| class Progress(unittest.TestCase): | ||||
|     def test_time(self): | ||||
|         p = progress.ProgressPrinter(1e6, None) | ||||
|         start = time.time() | ||||
|         now = p._now() | ||||
|         finish = time.time() | ||||
|         self.assertTrue(start <= now <= finish, (start, now, finish)) | ||||
| 
 | ||||
|     def test_basic(self): | ||||
|         stdout = io.StringIO() | ||||
|         p = progress.ProgressPrinter(1e6, stdout) | ||||
|         p._now = lambda: 0.0 | ||||
|         p.start() | ||||
| 
 | ||||
|         erase = u"\r"+u" "*70 | ||||
|         expected = erase | ||||
|         fmt = "Progress: %-40s %3d%%  %4d%s" | ||||
|         expected += u"\r" + fmt % ("", 0, 0, "KB") | ||||
|         self.assertEqual(stdout.getvalue(), expected) | ||||
| 
 | ||||
|         p.update(1e3) # no change, too soon | ||||
|         self.assertEqual(stdout.getvalue(), expected) | ||||
| 
 | ||||
|         p._now = lambda: 1.0 | ||||
|         p.update(1e3) # enough "time" has passed | ||||
|         expected += erase + u"\r" + fmt % ("", 0, 1, "KB") | ||||
|         self.assertEqual(stdout.getvalue(), expected) | ||||
| 
 | ||||
|         p._now = lambda: 2.0 | ||||
|         p.update(500e3) | ||||
|         expected += erase + u"\r" + fmt % ("#"*20, 50, 500, "KB") | ||||
|         self.assertEqual(stdout.getvalue(), expected) | ||||
| 
 | ||||
|         p._now = lambda: 3.0 | ||||
|         p.finish() | ||||
|         expected += erase + u"\r" + fmt % ("#"*40, 100, 1000, "KB") | ||||
|         expected += u"\n" | ||||
|         self.assertEqual(stdout.getvalue(), expected) | ||||
| 
 | ||||
|     def test_units(self): | ||||
|         def _try(size): | ||||
|             stdout = io.StringIO() | ||||
|             p = progress.ProgressPrinter(size, stdout) | ||||
|             p.finish() | ||||
|             return stdout.getvalue() | ||||
| 
 | ||||
|         fmt = "Progress: %-40s %3d%%  %4d%s" | ||||
|         def _expect(count, units): | ||||
|             erase = u"\r"+u" "*70 | ||||
|             expected = erase | ||||
|             expected += u"\r" + fmt % ("#"*40, 100, count, units) | ||||
|             expected += u"\n" | ||||
|             return expected | ||||
| 
 | ||||
|         self.assertEqual(_try(900), _expect(900, "B")) | ||||
|         self.assertEqual(_try(9e3), _expect(9000, "B")) | ||||
|         self.assertEqual(_try(90e3), _expect(90, "KB")) | ||||
|         self.assertEqual(_try(900e3), _expect(900, "KB")) | ||||
|         self.assertEqual(_try(9e6), _expect(9000, "KB")) | ||||
|         self.assertEqual(_try(90e6), _expect(90, "MB")) | ||||
|         self.assertEqual(_try(900e6), _expect(900, "MB")) | ||||
|         self.assertEqual(_try(9e9), _expect(9000, "MB")) | ||||
|         self.assertEqual(_try(90e9), _expect(90, "GB")) | ||||
|         self.assertEqual(_try(900e9), _expect(900, "GB")) | ||||
		Loading…
	
		Reference in New Issue
	
	Block a user