runner: strictly use cwd/stdout/stderr from 'args'

This will make it easier to test the scripts in a controlled fashion.
This commit is contained in:
Brian Warner 2016-02-17 12:22:10 -08:00
parent e2f3bebe38
commit c5b2800a3e
4 changed files with 82 additions and 63 deletions

View File

@ -15,21 +15,24 @@ def accept_file(args, them_d, w):
# the basename() is intended to protect us against
# "~/.ssh/authorized_keys" and other attacks
filename = os.path.basename(file_data["filename"]) # unicode
abs_filename = os.path.join(args.cwd, filename)
filesize = file_data["filesize"]
# get confirmation from the user before writing to the local directory
if os.path.exists(filename):
print("Error: refusing to overwrite existing file %s" % (filename,))
if os.path.exists(abs_filename):
print(u"Error: refusing to overwrite existing file %s" % (filename,),
file=args.stdout)
data = json.dumps({"error": "file already exists"}).encode("utf-8")
w.send_data(data)
return 1
print("Receiving file (%d bytes) into: %s" % (filesize, filename))
print(u"Receiving file (%d bytes) into: %s" % (filesize, filename),
file=args.stdout)
while True and not args.accept_file:
ok = six.moves.input("ok? (y/n): ")
if ok.lower().startswith("y"):
break
print("transfer rejected", file=sys.stderr)
print(u"transfer rejected", file=sys.stderr)
data = json.dumps({"error": "transfer rejected"}).encode("utf-8")
w.send_data(data)
return 1
@ -53,9 +56,9 @@ def accept_file(args, them_d, w):
transit_receiver.add_their_relay_hints(tdata["relay_connection_hints"])
record_pipe = transit_receiver.connect()
print("Receiving %d bytes for '%s' (%s).." % (filesize, filename,
transit_receiver.describe()))
tmp = filename + ".tmp"
print(u"Receiving %d bytes for '%s' (%s).." %
(filesize, filename, transit_receiver.describe()), file=args.stdout)
tmp = abs_filename + ".tmp"
with open(tmp, "wb") as f:
received = 0
p = ProgressPrinter(filesize, sys.stdout)
@ -64,9 +67,11 @@ def accept_file(args, them_d, w):
try:
plaintext = record_pipe.receive_record()
except TransitError:
print()
print("Connection dropped before full file received")
print("got %d bytes, wanted %d" % (received, filesize))
print(u"", file=args.stdout)
print(u"Connection dropped before full file received",
file=args.stdout)
print(u"got %d bytes, wanted %d" % (received, filesize),
file=args.stdout)
return 1
f.write(plaintext)
received += len(plaintext)
@ -74,9 +79,9 @@ def accept_file(args, them_d, w):
p.finish()
assert received == filesize
os.rename(tmp, filename)
os.rename(tmp, abs_filename)
print("Received file written to %s" % filename)
print(u"Received file written to %s" % filename, file=args.stdout)
record_pipe.send_record(b"ok\n")
record_pipe.close()
return 0
@ -87,7 +92,8 @@ def accept_directory(args, them_d, w):
file_data = them_d["directory"]
mode = file_data["mode"]
if mode != "zipfile/deflated":
print("Error: unknown directory-transfer mode '%s'" % (mode,))
print(u"Error: unknown directory-transfer mode '%s'" % (mode,),
file=args.stdout)
data = json.dumps({"error": "unknown mode"}).encode("utf-8")
w.send_data(data)
return 1
@ -98,24 +104,26 @@ def accept_directory(args, them_d, w):
# the basename() is intended to protect us against
# "~/.ssh/authorized_keys" and other attacks
dirname = os.path.basename(file_data["dirname"]) # unicode
abs_dirname = os.path.join(args.cwd, dirname)
filesize = file_data["zipsize"]
num_files = file_data["numfiles"]
num_bytes = file_data["numbytes"]
if os.path.exists(dirname):
print("Error: refusing to overwrite existing directory %s" % (dirname,))
if os.path.exists(abs_dirname):
print(u"Error: refusing to overwrite existing directory %s" %
(dirname,), file=args.stdout)
data = json.dumps({"error": "directory already exists"}).encode("utf-8")
w.send_data(data)
return 1
print("Receiving directory into: %s/" % (dirname,))
print("%d files, %d bytes (%d compressed)" % (num_files, num_bytes,
filesize))
print(u"Receiving directory into: %s/" % (dirname,), file=args.stdout)
print(u"%d files, %d bytes (%d compressed)" %
(num_files, num_bytes, filesize), file=args.stdout)
while True and not args.accept_file:
ok = six.moves.input("ok? (y/n): ")
if ok.lower().startswith("y"):
break
print("transfer rejected", file=sys.stderr)
print(u"transfer rejected", file=sys.stderr)
data = json.dumps({"error": "transfer rejected"}).encode("utf-8")
w.send_data(data)
return 1
@ -139,8 +147,8 @@ def accept_directory(args, them_d, w):
transit_receiver.add_their_relay_hints(tdata["relay_connection_hints"])
record_pipe = transit_receiver.connect()
print("Receiving %d bytes for '%s' (%s).." % (filesize, dirname,
transit_receiver.describe()))
print(u"Receiving %d bytes for '%s' (%s).." %
(filesize, dirname, transit_receiver.describe()), file=args.stdout)
f = tempfile.SpooledTemporaryFile()
received = 0
p = ProgressPrinter(filesize, sys.stdout)
@ -149,23 +157,25 @@ def accept_directory(args, them_d, w):
try:
plaintext = record_pipe.receive_record()
except TransitError:
print()
print("Connection dropped before full file received")
print("got %d bytes, wanted %d" % (received, filesize))
print(u"", file=args.stdout)
print(u"Connection dropped before full file received",
file=args.stdout)
print(u"got %d bytes, wanted %d" % (received, filesize),
file=args.stdout)
return 1
f.write(plaintext)
received += len(plaintext)
p.update(received)
p.finish()
assert received == filesize
print("Unpacking zipfile..")
print(u"Unpacking zipfile..", file=args.stdout)
with zipfile.ZipFile(f, "r", zipfile.ZIP_DEFLATED) as zf:
zf.extractall(path=dirname)
zf.extractall(path=abs_dirname)
# extractall() appears to offer some protection against malicious
# pathnames. For example, "/tmp/oops" and "../tmp/oops" both do the
# same thing as the (safe) "tmp/oops".
print("Received files written to %s/" % dirname)
print(u"Received files written to %s/" % dirname, file=args.stdout)
record_pipe.send_record(b"ok\n")
record_pipe.close()
return 0
@ -187,27 +197,27 @@ def receive(args):
if args.verify:
verifier = binascii.hexlify(w.get_verifier()).decode("ascii")
print("Verifier %s." % verifier)
print(u"Verifier %s." % verifier, file=args.stdout)
try:
them_bytes = w.get_data()
except WrongPasswordError as e:
print("ERROR: " + e.explain(), file=sys.stderr)
print(u"ERROR: " + e.explain(), file=sys.stderr)
return 1
them_d = json.loads(them_bytes.decode("utf-8"))
if "error" in them_d:
print("ERROR: " + them_d["error"], file=sys.stderr)
print(u"ERROR: " + them_d["error"], file=sys.stderr)
return 1
if "message" in them_d:
# we're receiving a text message
print(them_d["message"])
print(them_d["message"], file=args.stdout)
data = json.dumps({"message_ack": "ok"}).encode("utf-8")
w.send_data(data)
return 0
if "error" in them_d:
print("ERROR: " + data["error"], file=sys.stderr)
print(u"ERROR: " + data["error"], file=sys.stderr)
return 1
if "file" in them_d:
@ -216,8 +226,8 @@ def receive(args):
if "directory" in them_d:
return accept_directory(args, them_d, w)
print("I don't know what they're offering\n")
print("Offer details:", them_d)
print(u"I don't know what they're offering\n", file=args.stdout)
print(u"Offer details:", them_d, file=args.stdout)
data = json.dumps({"error": "unknown offer type"}).encode("utf-8")
w.send_data(data)
return 1

View File

@ -11,17 +11,18 @@ def send(args):
text = args.text
if text == "-":
print("Reading text message from stdin..")
print(u"Reading text message from stdin..", file=args.stdout)
text = sys.stdin.read()
if not text and not args.what:
text = six.moves.input("Text to send: ")
if text is not None:
print("Sending text message (%d bytes)" % len(text))
print(u"Sending text message (%d bytes)" % len(text), file=args.stdout)
phase1 = { "message": text }
fd_to_send = None
else:
if not os.path.exists(args.what):
what = os.path.join(args.cwd, args.what)
if not os.path.exists(what):
raise TransferError("Cannot send: no file/directory named '%s'" %
args.what)
phase1, fd_to_send = _build_phase1_data(args)
@ -36,7 +37,8 @@ def send(args):
other_cmd = "wormhole --verify receive"
if args.zeromode:
other_cmd += " -0"
print("On the other computer, please run: %s" % other_cmd)
print(u"On the other computer, please run: %s" % other_cmd,
file=args.stdout)
from .cmd_send_blocking import send_blocking
rc = send_blocking(APPID, args, phase1, fd_to_send)
@ -44,27 +46,29 @@ def send(args):
def _build_phase1_data(args):
phase1 = {}
basename = os.path.basename(args.what)
if os.path.isfile(args.what):
what = os.path.join(args.cwd, args.what)
basename = os.path.basename(what)
if os.path.isfile(what):
# we're sending a file
filesize = os.stat(args.what).st_size
filesize = os.stat(what).st_size
phase1["file"] = {
"filename": basename,
"filesize": filesize,
}
print("Sending %d byte file named '%s'" % (filesize, basename))
fd_to_send = open(args.what, "rb")
elif os.path.isdir(args.what):
print("Building zipfile..")
print(u"Sending %d byte file named '%s'" % (filesize, basename),
file=args.stdout)
fd_to_send = open(what, "rb")
elif os.path.isdir(what):
print(u"Building zipfile..", file=args.stdout)
# We're sending a directory. Create a zipfile in a tempdir and
# send that.
fd_to_send = tempfile.SpooledTemporaryFile()
# TODO: I think ZIP_DEFLATED means compressed.. check it
num_files = 0
num_bytes = 0
tostrip = len(args.what.split(os.sep))
tostrip = len(what.split(os.sep))
with zipfile.ZipFile(fd_to_send, "w", zipfile.ZIP_DEFLATED) as zf:
for path,dirs,files in os.walk(args.what):
for path,dirs,files in os.walk(what):
# path always starts with args.what, then sometimes might
# have "/subdir" appended. We want the zipfile to contain
# "" or "subdir"
@ -85,8 +89,8 @@ def _build_phase1_data(args):
"numbytes": num_bytes,
"numfiles": num_files,
}
print("Sending directory (%d bytes compressed) named '%s'"
% (filesize, basename))
print(u"Sending directory (%d bytes compressed) named '%s'"
% (filesize, basename), file=args.stdout)
else:
raise TypeError("'%s' is neither file nor directory" % args.what)
raise TypeError("'%s' is neither file nor directory" % what)
return phase1, fd_to_send

View File

@ -1,5 +1,5 @@
from __future__ import print_function
import sys, json, binascii, six
import json, binascii, six
from ..errors import TransferError
from .progress import ProgressPrinter
@ -21,8 +21,8 @@ def send_blocking(appid, args, phase1, fd_to_send):
else:
code = w.get_code(args.code_length)
if not args.zeromode:
print("Wormhole code is: %s" % code)
print("")
print(u"Wormhole code is: %s" % code, file=args.stdout)
print(u"", file=args.stdout)
if args.verify:
_do_verify(w)
@ -40,12 +40,12 @@ def send_blocking(appid, args, phase1, fd_to_send):
if fd_to_send is None:
if them_phase1["message_ack"] == "ok":
print("text message sent")
print(u"text message sent", file=args.stdout)
return 0
raise TransferError("error sending text: %r" % (them_phase1,))
return _send_file_blocking(w, appid, them_phase1, fd_to_send,
transit_sender)
transit_sender, args.stdout)
def _do_verify(w):
verifier = binascii.hexlify(w.get_verifier()).decode("ascii")
@ -59,7 +59,8 @@ def _do_verify(w):
w.send_data(reject_data)
raise TransferError("verification rejected, abandoning transfer")
def _send_file_blocking(w, appid, them_phase1, fd_to_send, transit_sender):
def _send_file_blocking(w, appid, them_phase1, fd_to_send, transit_sender,
stdout):
# we're sending a file, if they accept it
@ -77,13 +78,13 @@ def _send_file_blocking(w, appid, them_phase1, fd_to_send, transit_sender):
transit_sender.add_their_relay_hints(tdata["relay_connection_hints"])
record_pipe = transit_sender.connect()
print("Sending (%s).." % transit_sender.describe())
print(u"Sending (%s).." % transit_sender.describe(), file=stdout)
CHUNKSIZE = 64*1024
fd_to_send.seek(0,2)
filesize = fd_to_send.tell()
fd_to_send.seek(0,0)
p = ProgressPrinter(filesize, sys.stdout)
p = ProgressPrinter(filesize, stdout)
with fd_to_send as f:
sent = 0
p.start()
@ -94,10 +95,10 @@ def _send_file_blocking(w, appid, them_phase1, fd_to_send, transit_sender):
p.update(sent)
p.finish()
print("File sent.. waiting for confirmation")
print(u"File sent.. waiting for confirmation", file=stdout)
ack = record_pipe.receive_record()
record_pipe.close()
if ack == b"ok\n":
print("Confirmation received. Transfer complete.")
print(u"Confirmation received. Transfer complete.", file=stdout)
return 0
raise TransferError("Transfer failed (remote says: %r)" % ack)

View File

@ -1,5 +1,5 @@
from __future__ import print_function
import sys, argparse
import os, sys, argparse
from textwrap import dedent
from .. import public_relay
from .. import __version__
@ -120,7 +120,7 @@ p.set_defaults(func=cmd_receive.receive)
def run(args, stdout, stderr, executable=None):
def run(args, cwd, stdout, stderr, executable=None):
"""This is invoked directly by the 'wormhole' entry-point script. It can
also invoked by entry() below."""
@ -130,6 +130,9 @@ def run(args, stdout, stderr, executable=None):
# "error: too few arguments" during parse_args().
parser.print_help()
sys.exit(0)
args.cwd = cwd
args.stdout = stdout
args.stderr = stderr
try:
#rc = command.func(args, stdout, stderr)
rc = args.func(args)
@ -147,7 +150,8 @@ def run(args, stdout, stderr, executable=None):
def entry():
"""This is used by a setuptools entry_point. When invoked this way,
setuptools has already put the installed package on sys.path ."""
return run(sys.argv[1:], sys.stdout, sys.stderr, executable=sys.argv[0])
return run(sys.argv[1:], os.getcwd(), sys.stdout, sys.stderr,
executable=sys.argv[0])
if __name__ == "__main__":
args = parser.parse_args()