From c5b2800a3e73a7be7d7e443d703b9e7fb98a920d Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Wed, 17 Feb 2016 12:22:10 -0800 Subject: [PATCH] runner: strictly use cwd/stdout/stderr from 'args' This will make it easier to test the scripts in a controlled fashion. --- src/wormhole/scripts/cmd_receive.py | 78 +++++++++++++---------- src/wormhole/scripts/cmd_send.py | 36 ++++++----- src/wormhole/scripts/cmd_send_blocking.py | 21 +++--- src/wormhole/scripts/runner.py | 10 ++- 4 files changed, 82 insertions(+), 63 deletions(-) diff --git a/src/wormhole/scripts/cmd_receive.py b/src/wormhole/scripts/cmd_receive.py index a54f788..12003de 100644 --- a/src/wormhole/scripts/cmd_receive.py +++ b/src/wormhole/scripts/cmd_receive.py @@ -15,21 +15,24 @@ def accept_file(args, them_d, w): # the basename() is intended to protect us against # "~/.ssh/authorized_keys" and other attacks filename = os.path.basename(file_data["filename"]) # unicode + abs_filename = os.path.join(args.cwd, filename) filesize = file_data["filesize"] # get confirmation from the user before writing to the local directory - if os.path.exists(filename): - print("Error: refusing to overwrite existing file %s" % (filename,)) + if os.path.exists(abs_filename): + print(u"Error: refusing to overwrite existing file %s" % (filename,), + file=args.stdout) data = json.dumps({"error": "file already exists"}).encode("utf-8") w.send_data(data) return 1 - print("Receiving file (%d bytes) into: %s" % (filesize, filename)) + print(u"Receiving file (%d bytes) into: %s" % (filesize, filename), + file=args.stdout) while True and not args.accept_file: ok = six.moves.input("ok? (y/n): ") if ok.lower().startswith("y"): break - print("transfer rejected", file=sys.stderr) + print(u"transfer rejected", file=sys.stderr) data = json.dumps({"error": "transfer rejected"}).encode("utf-8") w.send_data(data) return 1 @@ -53,9 +56,9 @@ def accept_file(args, them_d, w): transit_receiver.add_their_relay_hints(tdata["relay_connection_hints"]) record_pipe = transit_receiver.connect() - print("Receiving %d bytes for '%s' (%s).." % (filesize, filename, - transit_receiver.describe())) - tmp = filename + ".tmp" + print(u"Receiving %d bytes for '%s' (%s).." % + (filesize, filename, transit_receiver.describe()), file=args.stdout) + tmp = abs_filename + ".tmp" with open(tmp, "wb") as f: received = 0 p = ProgressPrinter(filesize, sys.stdout) @@ -64,9 +67,11 @@ def accept_file(args, them_d, w): try: plaintext = record_pipe.receive_record() except TransitError: - print() - print("Connection dropped before full file received") - print("got %d bytes, wanted %d" % (received, filesize)) + print(u"", file=args.stdout) + print(u"Connection dropped before full file received", + file=args.stdout) + print(u"got %d bytes, wanted %d" % (received, filesize), + file=args.stdout) return 1 f.write(plaintext) received += len(plaintext) @@ -74,9 +79,9 @@ def accept_file(args, them_d, w): p.finish() assert received == filesize - os.rename(tmp, filename) + os.rename(tmp, abs_filename) - print("Received file written to %s" % filename) + print(u"Received file written to %s" % filename, file=args.stdout) record_pipe.send_record(b"ok\n") record_pipe.close() return 0 @@ -87,7 +92,8 @@ def accept_directory(args, them_d, w): file_data = them_d["directory"] mode = file_data["mode"] if mode != "zipfile/deflated": - print("Error: unknown directory-transfer mode '%s'" % (mode,)) + print(u"Error: unknown directory-transfer mode '%s'" % (mode,), + file=args.stdout) data = json.dumps({"error": "unknown mode"}).encode("utf-8") w.send_data(data) return 1 @@ -98,24 +104,26 @@ def accept_directory(args, them_d, w): # the basename() is intended to protect us against # "~/.ssh/authorized_keys" and other attacks dirname = os.path.basename(file_data["dirname"]) # unicode + abs_dirname = os.path.join(args.cwd, dirname) filesize = file_data["zipsize"] num_files = file_data["numfiles"] num_bytes = file_data["numbytes"] - if os.path.exists(dirname): - print("Error: refusing to overwrite existing directory %s" % (dirname,)) + if os.path.exists(abs_dirname): + print(u"Error: refusing to overwrite existing directory %s" % + (dirname,), file=args.stdout) data = json.dumps({"error": "directory already exists"}).encode("utf-8") w.send_data(data) return 1 - print("Receiving directory into: %s/" % (dirname,)) - print("%d files, %d bytes (%d compressed)" % (num_files, num_bytes, - filesize)) + print(u"Receiving directory into: %s/" % (dirname,), file=args.stdout) + print(u"%d files, %d bytes (%d compressed)" % + (num_files, num_bytes, filesize), file=args.stdout) while True and not args.accept_file: ok = six.moves.input("ok? (y/n): ") if ok.lower().startswith("y"): break - print("transfer rejected", file=sys.stderr) + print(u"transfer rejected", file=sys.stderr) data = json.dumps({"error": "transfer rejected"}).encode("utf-8") w.send_data(data) return 1 @@ -139,8 +147,8 @@ def accept_directory(args, them_d, w): transit_receiver.add_their_relay_hints(tdata["relay_connection_hints"]) record_pipe = transit_receiver.connect() - print("Receiving %d bytes for '%s' (%s).." % (filesize, dirname, - transit_receiver.describe())) + print(u"Receiving %d bytes for '%s' (%s).." % + (filesize, dirname, transit_receiver.describe()), file=args.stdout) f = tempfile.SpooledTemporaryFile() received = 0 p = ProgressPrinter(filesize, sys.stdout) @@ -149,23 +157,25 @@ def accept_directory(args, them_d, w): try: plaintext = record_pipe.receive_record() except TransitError: - print() - print("Connection dropped before full file received") - print("got %d bytes, wanted %d" % (received, filesize)) + print(u"", file=args.stdout) + print(u"Connection dropped before full file received", + file=args.stdout) + print(u"got %d bytes, wanted %d" % (received, filesize), + file=args.stdout) return 1 f.write(plaintext) received += len(plaintext) p.update(received) p.finish() assert received == filesize - print("Unpacking zipfile..") + print(u"Unpacking zipfile..", file=args.stdout) with zipfile.ZipFile(f, "r", zipfile.ZIP_DEFLATED) as zf: - zf.extractall(path=dirname) + zf.extractall(path=abs_dirname) # extractall() appears to offer some protection against malicious # pathnames. For example, "/tmp/oops" and "../tmp/oops" both do the # same thing as the (safe) "tmp/oops". - print("Received files written to %s/" % dirname) + print(u"Received files written to %s/" % dirname, file=args.stdout) record_pipe.send_record(b"ok\n") record_pipe.close() return 0 @@ -187,27 +197,27 @@ def receive(args): if args.verify: verifier = binascii.hexlify(w.get_verifier()).decode("ascii") - print("Verifier %s." % verifier) + print(u"Verifier %s." % verifier, file=args.stdout) try: them_bytes = w.get_data() except WrongPasswordError as e: - print("ERROR: " + e.explain(), file=sys.stderr) + print(u"ERROR: " + e.explain(), file=sys.stderr) return 1 them_d = json.loads(them_bytes.decode("utf-8")) if "error" in them_d: - print("ERROR: " + them_d["error"], file=sys.stderr) + print(u"ERROR: " + them_d["error"], file=sys.stderr) return 1 if "message" in them_d: # we're receiving a text message - print(them_d["message"]) + print(them_d["message"], file=args.stdout) data = json.dumps({"message_ack": "ok"}).encode("utf-8") w.send_data(data) return 0 if "error" in them_d: - print("ERROR: " + data["error"], file=sys.stderr) + print(u"ERROR: " + data["error"], file=sys.stderr) return 1 if "file" in them_d: @@ -216,8 +226,8 @@ def receive(args): if "directory" in them_d: return accept_directory(args, them_d, w) - print("I don't know what they're offering\n") - print("Offer details:", them_d) + print(u"I don't know what they're offering\n", file=args.stdout) + print(u"Offer details:", them_d, file=args.stdout) data = json.dumps({"error": "unknown offer type"}).encode("utf-8") w.send_data(data) return 1 diff --git a/src/wormhole/scripts/cmd_send.py b/src/wormhole/scripts/cmd_send.py index 141146a..e366822 100644 --- a/src/wormhole/scripts/cmd_send.py +++ b/src/wormhole/scripts/cmd_send.py @@ -11,17 +11,18 @@ def send(args): text = args.text if text == "-": - print("Reading text message from stdin..") + print(u"Reading text message from stdin..", file=args.stdout) text = sys.stdin.read() if not text and not args.what: text = six.moves.input("Text to send: ") if text is not None: - print("Sending text message (%d bytes)" % len(text)) + print(u"Sending text message (%d bytes)" % len(text), file=args.stdout) phase1 = { "message": text } fd_to_send = None else: - if not os.path.exists(args.what): + what = os.path.join(args.cwd, args.what) + if not os.path.exists(what): raise TransferError("Cannot send: no file/directory named '%s'" % args.what) phase1, fd_to_send = _build_phase1_data(args) @@ -36,7 +37,8 @@ def send(args): other_cmd = "wormhole --verify receive" if args.zeromode: other_cmd += " -0" - print("On the other computer, please run: %s" % other_cmd) + print(u"On the other computer, please run: %s" % other_cmd, + file=args.stdout) from .cmd_send_blocking import send_blocking rc = send_blocking(APPID, args, phase1, fd_to_send) @@ -44,27 +46,29 @@ def send(args): def _build_phase1_data(args): phase1 = {} - basename = os.path.basename(args.what) - if os.path.isfile(args.what): + what = os.path.join(args.cwd, args.what) + basename = os.path.basename(what) + if os.path.isfile(what): # we're sending a file - filesize = os.stat(args.what).st_size + filesize = os.stat(what).st_size phase1["file"] = { "filename": basename, "filesize": filesize, } - print("Sending %d byte file named '%s'" % (filesize, basename)) - fd_to_send = open(args.what, "rb") - elif os.path.isdir(args.what): - print("Building zipfile..") + print(u"Sending %d byte file named '%s'" % (filesize, basename), + file=args.stdout) + fd_to_send = open(what, "rb") + elif os.path.isdir(what): + print(u"Building zipfile..", file=args.stdout) # We're sending a directory. Create a zipfile in a tempdir and # send that. fd_to_send = tempfile.SpooledTemporaryFile() # TODO: I think ZIP_DEFLATED means compressed.. check it num_files = 0 num_bytes = 0 - tostrip = len(args.what.split(os.sep)) + tostrip = len(what.split(os.sep)) with zipfile.ZipFile(fd_to_send, "w", zipfile.ZIP_DEFLATED) as zf: - for path,dirs,files in os.walk(args.what): + for path,dirs,files in os.walk(what): # path always starts with args.what, then sometimes might # have "/subdir" appended. We want the zipfile to contain # "" or "subdir" @@ -85,8 +89,8 @@ def _build_phase1_data(args): "numbytes": num_bytes, "numfiles": num_files, } - print("Sending directory (%d bytes compressed) named '%s'" - % (filesize, basename)) + print(u"Sending directory (%d bytes compressed) named '%s'" + % (filesize, basename), file=args.stdout) else: - raise TypeError("'%s' is neither file nor directory" % args.what) + raise TypeError("'%s' is neither file nor directory" % what) return phase1, fd_to_send diff --git a/src/wormhole/scripts/cmd_send_blocking.py b/src/wormhole/scripts/cmd_send_blocking.py index 0d52751..90a3bac 100644 --- a/src/wormhole/scripts/cmd_send_blocking.py +++ b/src/wormhole/scripts/cmd_send_blocking.py @@ -1,5 +1,5 @@ from __future__ import print_function -import sys, json, binascii, six +import json, binascii, six from ..errors import TransferError from .progress import ProgressPrinter @@ -21,8 +21,8 @@ def send_blocking(appid, args, phase1, fd_to_send): else: code = w.get_code(args.code_length) if not args.zeromode: - print("Wormhole code is: %s" % code) - print("") + print(u"Wormhole code is: %s" % code, file=args.stdout) + print(u"", file=args.stdout) if args.verify: _do_verify(w) @@ -40,12 +40,12 @@ def send_blocking(appid, args, phase1, fd_to_send): if fd_to_send is None: if them_phase1["message_ack"] == "ok": - print("text message sent") + print(u"text message sent", file=args.stdout) return 0 raise TransferError("error sending text: %r" % (them_phase1,)) return _send_file_blocking(w, appid, them_phase1, fd_to_send, - transit_sender) + transit_sender, args.stdout) def _do_verify(w): verifier = binascii.hexlify(w.get_verifier()).decode("ascii") @@ -59,7 +59,8 @@ def _do_verify(w): w.send_data(reject_data) raise TransferError("verification rejected, abandoning transfer") -def _send_file_blocking(w, appid, them_phase1, fd_to_send, transit_sender): +def _send_file_blocking(w, appid, them_phase1, fd_to_send, transit_sender, + stdout): # we're sending a file, if they accept it @@ -77,13 +78,13 @@ def _send_file_blocking(w, appid, them_phase1, fd_to_send, transit_sender): transit_sender.add_their_relay_hints(tdata["relay_connection_hints"]) record_pipe = transit_sender.connect() - print("Sending (%s).." % transit_sender.describe()) + print(u"Sending (%s).." % transit_sender.describe(), file=stdout) CHUNKSIZE = 64*1024 fd_to_send.seek(0,2) filesize = fd_to_send.tell() fd_to_send.seek(0,0) - p = ProgressPrinter(filesize, sys.stdout) + p = ProgressPrinter(filesize, stdout) with fd_to_send as f: sent = 0 p.start() @@ -94,10 +95,10 @@ def _send_file_blocking(w, appid, them_phase1, fd_to_send, transit_sender): p.update(sent) p.finish() - print("File sent.. waiting for confirmation") + print(u"File sent.. waiting for confirmation", file=stdout) ack = record_pipe.receive_record() record_pipe.close() if ack == b"ok\n": - print("Confirmation received. Transfer complete.") + print(u"Confirmation received. Transfer complete.", file=stdout) return 0 raise TransferError("Transfer failed (remote says: %r)" % ack) diff --git a/src/wormhole/scripts/runner.py b/src/wormhole/scripts/runner.py index 707ce99..83e584a 100644 --- a/src/wormhole/scripts/runner.py +++ b/src/wormhole/scripts/runner.py @@ -1,5 +1,5 @@ from __future__ import print_function -import sys, argparse +import os, sys, argparse from textwrap import dedent from .. import public_relay from .. import __version__ @@ -120,7 +120,7 @@ p.set_defaults(func=cmd_receive.receive) -def run(args, stdout, stderr, executable=None): +def run(args, cwd, stdout, stderr, executable=None): """This is invoked directly by the 'wormhole' entry-point script. It can also invoked by entry() below.""" @@ -130,6 +130,9 @@ def run(args, stdout, stderr, executable=None): # "error: too few arguments" during parse_args(). parser.print_help() sys.exit(0) + args.cwd = cwd + args.stdout = stdout + args.stderr = stderr try: #rc = command.func(args, stdout, stderr) rc = args.func(args) @@ -147,7 +150,8 @@ def run(args, stdout, stderr, executable=None): def entry(): """This is used by a setuptools entry_point. When invoked this way, setuptools has already put the installed package on sys.path .""" - return run(sys.argv[1:], sys.stdout, sys.stderr, executable=sys.argv[0]) + return run(sys.argv[1:], os.getcwd(), sys.stdout, sys.stderr, + executable=sys.argv[0]) if __name__ == "__main__": args = parser.parse_args()