runner: strictly use cwd/stdout/stderr from 'args'

This will make it easier to test the scripts in a controlled fashion.
This commit is contained in:
Brian Warner 2016-02-17 12:22:10 -08:00
parent e2f3bebe38
commit c5b2800a3e
4 changed files with 82 additions and 63 deletions

View File

@ -15,21 +15,24 @@ def accept_file(args, them_d, w):
# the basename() is intended to protect us against # the basename() is intended to protect us against
# "~/.ssh/authorized_keys" and other attacks # "~/.ssh/authorized_keys" and other attacks
filename = os.path.basename(file_data["filename"]) # unicode filename = os.path.basename(file_data["filename"]) # unicode
abs_filename = os.path.join(args.cwd, filename)
filesize = file_data["filesize"] filesize = file_data["filesize"]
# get confirmation from the user before writing to the local directory # get confirmation from the user before writing to the local directory
if os.path.exists(filename): if os.path.exists(abs_filename):
print("Error: refusing to overwrite existing file %s" % (filename,)) print(u"Error: refusing to overwrite existing file %s" % (filename,),
file=args.stdout)
data = json.dumps({"error": "file already exists"}).encode("utf-8") data = json.dumps({"error": "file already exists"}).encode("utf-8")
w.send_data(data) w.send_data(data)
return 1 return 1
print("Receiving file (%d bytes) into: %s" % (filesize, filename)) print(u"Receiving file (%d bytes) into: %s" % (filesize, filename),
file=args.stdout)
while True and not args.accept_file: while True and not args.accept_file:
ok = six.moves.input("ok? (y/n): ") ok = six.moves.input("ok? (y/n): ")
if ok.lower().startswith("y"): if ok.lower().startswith("y"):
break break
print("transfer rejected", file=sys.stderr) print(u"transfer rejected", file=sys.stderr)
data = json.dumps({"error": "transfer rejected"}).encode("utf-8") data = json.dumps({"error": "transfer rejected"}).encode("utf-8")
w.send_data(data) w.send_data(data)
return 1 return 1
@ -53,9 +56,9 @@ def accept_file(args, them_d, w):
transit_receiver.add_their_relay_hints(tdata["relay_connection_hints"]) transit_receiver.add_their_relay_hints(tdata["relay_connection_hints"])
record_pipe = transit_receiver.connect() record_pipe = transit_receiver.connect()
print("Receiving %d bytes for '%s' (%s).." % (filesize, filename, print(u"Receiving %d bytes for '%s' (%s).." %
transit_receiver.describe())) (filesize, filename, transit_receiver.describe()), file=args.stdout)
tmp = filename + ".tmp" tmp = abs_filename + ".tmp"
with open(tmp, "wb") as f: with open(tmp, "wb") as f:
received = 0 received = 0
p = ProgressPrinter(filesize, sys.stdout) p = ProgressPrinter(filesize, sys.stdout)
@ -64,9 +67,11 @@ def accept_file(args, them_d, w):
try: try:
plaintext = record_pipe.receive_record() plaintext = record_pipe.receive_record()
except TransitError: except TransitError:
print() print(u"", file=args.stdout)
print("Connection dropped before full file received") print(u"Connection dropped before full file received",
print("got %d bytes, wanted %d" % (received, filesize)) file=args.stdout)
print(u"got %d bytes, wanted %d" % (received, filesize),
file=args.stdout)
return 1 return 1
f.write(plaintext) f.write(plaintext)
received += len(plaintext) received += len(plaintext)
@ -74,9 +79,9 @@ def accept_file(args, them_d, w):
p.finish() p.finish()
assert received == filesize assert received == filesize
os.rename(tmp, filename) os.rename(tmp, abs_filename)
print("Received file written to %s" % filename) print(u"Received file written to %s" % filename, file=args.stdout)
record_pipe.send_record(b"ok\n") record_pipe.send_record(b"ok\n")
record_pipe.close() record_pipe.close()
return 0 return 0
@ -87,7 +92,8 @@ def accept_directory(args, them_d, w):
file_data = them_d["directory"] file_data = them_d["directory"]
mode = file_data["mode"] mode = file_data["mode"]
if mode != "zipfile/deflated": if mode != "zipfile/deflated":
print("Error: unknown directory-transfer mode '%s'" % (mode,)) print(u"Error: unknown directory-transfer mode '%s'" % (mode,),
file=args.stdout)
data = json.dumps({"error": "unknown mode"}).encode("utf-8") data = json.dumps({"error": "unknown mode"}).encode("utf-8")
w.send_data(data) w.send_data(data)
return 1 return 1
@ -98,24 +104,26 @@ def accept_directory(args, them_d, w):
# the basename() is intended to protect us against # the basename() is intended to protect us against
# "~/.ssh/authorized_keys" and other attacks # "~/.ssh/authorized_keys" and other attacks
dirname = os.path.basename(file_data["dirname"]) # unicode dirname = os.path.basename(file_data["dirname"]) # unicode
abs_dirname = os.path.join(args.cwd, dirname)
filesize = file_data["zipsize"] filesize = file_data["zipsize"]
num_files = file_data["numfiles"] num_files = file_data["numfiles"]
num_bytes = file_data["numbytes"] num_bytes = file_data["numbytes"]
if os.path.exists(dirname): if os.path.exists(abs_dirname):
print("Error: refusing to overwrite existing directory %s" % (dirname,)) print(u"Error: refusing to overwrite existing directory %s" %
(dirname,), file=args.stdout)
data = json.dumps({"error": "directory already exists"}).encode("utf-8") data = json.dumps({"error": "directory already exists"}).encode("utf-8")
w.send_data(data) w.send_data(data)
return 1 return 1
print("Receiving directory into: %s/" % (dirname,)) print(u"Receiving directory into: %s/" % (dirname,), file=args.stdout)
print("%d files, %d bytes (%d compressed)" % (num_files, num_bytes, print(u"%d files, %d bytes (%d compressed)" %
filesize)) (num_files, num_bytes, filesize), file=args.stdout)
while True and not args.accept_file: while True and not args.accept_file:
ok = six.moves.input("ok? (y/n): ") ok = six.moves.input("ok? (y/n): ")
if ok.lower().startswith("y"): if ok.lower().startswith("y"):
break break
print("transfer rejected", file=sys.stderr) print(u"transfer rejected", file=sys.stderr)
data = json.dumps({"error": "transfer rejected"}).encode("utf-8") data = json.dumps({"error": "transfer rejected"}).encode("utf-8")
w.send_data(data) w.send_data(data)
return 1 return 1
@ -139,8 +147,8 @@ def accept_directory(args, them_d, w):
transit_receiver.add_their_relay_hints(tdata["relay_connection_hints"]) transit_receiver.add_their_relay_hints(tdata["relay_connection_hints"])
record_pipe = transit_receiver.connect() record_pipe = transit_receiver.connect()
print("Receiving %d bytes for '%s' (%s).." % (filesize, dirname, print(u"Receiving %d bytes for '%s' (%s).." %
transit_receiver.describe())) (filesize, dirname, transit_receiver.describe()), file=args.stdout)
f = tempfile.SpooledTemporaryFile() f = tempfile.SpooledTemporaryFile()
received = 0 received = 0
p = ProgressPrinter(filesize, sys.stdout) p = ProgressPrinter(filesize, sys.stdout)
@ -149,23 +157,25 @@ def accept_directory(args, them_d, w):
try: try:
plaintext = record_pipe.receive_record() plaintext = record_pipe.receive_record()
except TransitError: except TransitError:
print() print(u"", file=args.stdout)
print("Connection dropped before full file received") print(u"Connection dropped before full file received",
print("got %d bytes, wanted %d" % (received, filesize)) file=args.stdout)
print(u"got %d bytes, wanted %d" % (received, filesize),
file=args.stdout)
return 1 return 1
f.write(plaintext) f.write(plaintext)
received += len(plaintext) received += len(plaintext)
p.update(received) p.update(received)
p.finish() p.finish()
assert received == filesize assert received == filesize
print("Unpacking zipfile..") print(u"Unpacking zipfile..", file=args.stdout)
with zipfile.ZipFile(f, "r", zipfile.ZIP_DEFLATED) as zf: with zipfile.ZipFile(f, "r", zipfile.ZIP_DEFLATED) as zf:
zf.extractall(path=dirname) zf.extractall(path=abs_dirname)
# extractall() appears to offer some protection against malicious # extractall() appears to offer some protection against malicious
# pathnames. For example, "/tmp/oops" and "../tmp/oops" both do the # pathnames. For example, "/tmp/oops" and "../tmp/oops" both do the
# same thing as the (safe) "tmp/oops". # same thing as the (safe) "tmp/oops".
print("Received files written to %s/" % dirname) print(u"Received files written to %s/" % dirname, file=args.stdout)
record_pipe.send_record(b"ok\n") record_pipe.send_record(b"ok\n")
record_pipe.close() record_pipe.close()
return 0 return 0
@ -187,27 +197,27 @@ def receive(args):
if args.verify: if args.verify:
verifier = binascii.hexlify(w.get_verifier()).decode("ascii") verifier = binascii.hexlify(w.get_verifier()).decode("ascii")
print("Verifier %s." % verifier) print(u"Verifier %s." % verifier, file=args.stdout)
try: try:
them_bytes = w.get_data() them_bytes = w.get_data()
except WrongPasswordError as e: except WrongPasswordError as e:
print("ERROR: " + e.explain(), file=sys.stderr) print(u"ERROR: " + e.explain(), file=sys.stderr)
return 1 return 1
them_d = json.loads(them_bytes.decode("utf-8")) them_d = json.loads(them_bytes.decode("utf-8"))
if "error" in them_d: if "error" in them_d:
print("ERROR: " + them_d["error"], file=sys.stderr) print(u"ERROR: " + them_d["error"], file=sys.stderr)
return 1 return 1
if "message" in them_d: if "message" in them_d:
# we're receiving a text message # we're receiving a text message
print(them_d["message"]) print(them_d["message"], file=args.stdout)
data = json.dumps({"message_ack": "ok"}).encode("utf-8") data = json.dumps({"message_ack": "ok"}).encode("utf-8")
w.send_data(data) w.send_data(data)
return 0 return 0
if "error" in them_d: if "error" in them_d:
print("ERROR: " + data["error"], file=sys.stderr) print(u"ERROR: " + data["error"], file=sys.stderr)
return 1 return 1
if "file" in them_d: if "file" in them_d:
@ -216,8 +226,8 @@ def receive(args):
if "directory" in them_d: if "directory" in them_d:
return accept_directory(args, them_d, w) return accept_directory(args, them_d, w)
print("I don't know what they're offering\n") print(u"I don't know what they're offering\n", file=args.stdout)
print("Offer details:", them_d) print(u"Offer details:", them_d, file=args.stdout)
data = json.dumps({"error": "unknown offer type"}).encode("utf-8") data = json.dumps({"error": "unknown offer type"}).encode("utf-8")
w.send_data(data) w.send_data(data)
return 1 return 1

View File

@ -11,17 +11,18 @@ def send(args):
text = args.text text = args.text
if text == "-": if text == "-":
print("Reading text message from stdin..") print(u"Reading text message from stdin..", file=args.stdout)
text = sys.stdin.read() text = sys.stdin.read()
if not text and not args.what: if not text and not args.what:
text = six.moves.input("Text to send: ") text = six.moves.input("Text to send: ")
if text is not None: if text is not None:
print("Sending text message (%d bytes)" % len(text)) print(u"Sending text message (%d bytes)" % len(text), file=args.stdout)
phase1 = { "message": text } phase1 = { "message": text }
fd_to_send = None fd_to_send = None
else: else:
if not os.path.exists(args.what): what = os.path.join(args.cwd, args.what)
if not os.path.exists(what):
raise TransferError("Cannot send: no file/directory named '%s'" % raise TransferError("Cannot send: no file/directory named '%s'" %
args.what) args.what)
phase1, fd_to_send = _build_phase1_data(args) phase1, fd_to_send = _build_phase1_data(args)
@ -36,7 +37,8 @@ def send(args):
other_cmd = "wormhole --verify receive" other_cmd = "wormhole --verify receive"
if args.zeromode: if args.zeromode:
other_cmd += " -0" other_cmd += " -0"
print("On the other computer, please run: %s" % other_cmd) print(u"On the other computer, please run: %s" % other_cmd,
file=args.stdout)
from .cmd_send_blocking import send_blocking from .cmd_send_blocking import send_blocking
rc = send_blocking(APPID, args, phase1, fd_to_send) rc = send_blocking(APPID, args, phase1, fd_to_send)
@ -44,27 +46,29 @@ def send(args):
def _build_phase1_data(args): def _build_phase1_data(args):
phase1 = {} phase1 = {}
basename = os.path.basename(args.what) what = os.path.join(args.cwd, args.what)
if os.path.isfile(args.what): basename = os.path.basename(what)
if os.path.isfile(what):
# we're sending a file # we're sending a file
filesize = os.stat(args.what).st_size filesize = os.stat(what).st_size
phase1["file"] = { phase1["file"] = {
"filename": basename, "filename": basename,
"filesize": filesize, "filesize": filesize,
} }
print("Sending %d byte file named '%s'" % (filesize, basename)) print(u"Sending %d byte file named '%s'" % (filesize, basename),
fd_to_send = open(args.what, "rb") file=args.stdout)
elif os.path.isdir(args.what): fd_to_send = open(what, "rb")
print("Building zipfile..") elif os.path.isdir(what):
print(u"Building zipfile..", file=args.stdout)
# We're sending a directory. Create a zipfile in a tempdir and # We're sending a directory. Create a zipfile in a tempdir and
# send that. # send that.
fd_to_send = tempfile.SpooledTemporaryFile() fd_to_send = tempfile.SpooledTemporaryFile()
# TODO: I think ZIP_DEFLATED means compressed.. check it # TODO: I think ZIP_DEFLATED means compressed.. check it
num_files = 0 num_files = 0
num_bytes = 0 num_bytes = 0
tostrip = len(args.what.split(os.sep)) tostrip = len(what.split(os.sep))
with zipfile.ZipFile(fd_to_send, "w", zipfile.ZIP_DEFLATED) as zf: with zipfile.ZipFile(fd_to_send, "w", zipfile.ZIP_DEFLATED) as zf:
for path,dirs,files in os.walk(args.what): for path,dirs,files in os.walk(what):
# path always starts with args.what, then sometimes might # path always starts with args.what, then sometimes might
# have "/subdir" appended. We want the zipfile to contain # have "/subdir" appended. We want the zipfile to contain
# "" or "subdir" # "" or "subdir"
@ -85,8 +89,8 @@ def _build_phase1_data(args):
"numbytes": num_bytes, "numbytes": num_bytes,
"numfiles": num_files, "numfiles": num_files,
} }
print("Sending directory (%d bytes compressed) named '%s'" print(u"Sending directory (%d bytes compressed) named '%s'"
% (filesize, basename)) % (filesize, basename), file=args.stdout)
else: else:
raise TypeError("'%s' is neither file nor directory" % args.what) raise TypeError("'%s' is neither file nor directory" % what)
return phase1, fd_to_send return phase1, fd_to_send

View File

@ -1,5 +1,5 @@
from __future__ import print_function from __future__ import print_function
import sys, json, binascii, six import json, binascii, six
from ..errors import TransferError from ..errors import TransferError
from .progress import ProgressPrinter from .progress import ProgressPrinter
@ -21,8 +21,8 @@ def send_blocking(appid, args, phase1, fd_to_send):
else: else:
code = w.get_code(args.code_length) code = w.get_code(args.code_length)
if not args.zeromode: if not args.zeromode:
print("Wormhole code is: %s" % code) print(u"Wormhole code is: %s" % code, file=args.stdout)
print("") print(u"", file=args.stdout)
if args.verify: if args.verify:
_do_verify(w) _do_verify(w)
@ -40,12 +40,12 @@ def send_blocking(appid, args, phase1, fd_to_send):
if fd_to_send is None: if fd_to_send is None:
if them_phase1["message_ack"] == "ok": if them_phase1["message_ack"] == "ok":
print("text message sent") print(u"text message sent", file=args.stdout)
return 0 return 0
raise TransferError("error sending text: %r" % (them_phase1,)) raise TransferError("error sending text: %r" % (them_phase1,))
return _send_file_blocking(w, appid, them_phase1, fd_to_send, return _send_file_blocking(w, appid, them_phase1, fd_to_send,
transit_sender) transit_sender, args.stdout)
def _do_verify(w): def _do_verify(w):
verifier = binascii.hexlify(w.get_verifier()).decode("ascii") verifier = binascii.hexlify(w.get_verifier()).decode("ascii")
@ -59,7 +59,8 @@ def _do_verify(w):
w.send_data(reject_data) w.send_data(reject_data)
raise TransferError("verification rejected, abandoning transfer") raise TransferError("verification rejected, abandoning transfer")
def _send_file_blocking(w, appid, them_phase1, fd_to_send, transit_sender): def _send_file_blocking(w, appid, them_phase1, fd_to_send, transit_sender,
stdout):
# we're sending a file, if they accept it # we're sending a file, if they accept it
@ -77,13 +78,13 @@ def _send_file_blocking(w, appid, them_phase1, fd_to_send, transit_sender):
transit_sender.add_their_relay_hints(tdata["relay_connection_hints"]) transit_sender.add_their_relay_hints(tdata["relay_connection_hints"])
record_pipe = transit_sender.connect() record_pipe = transit_sender.connect()
print("Sending (%s).." % transit_sender.describe()) print(u"Sending (%s).." % transit_sender.describe(), file=stdout)
CHUNKSIZE = 64*1024 CHUNKSIZE = 64*1024
fd_to_send.seek(0,2) fd_to_send.seek(0,2)
filesize = fd_to_send.tell() filesize = fd_to_send.tell()
fd_to_send.seek(0,0) fd_to_send.seek(0,0)
p = ProgressPrinter(filesize, sys.stdout) p = ProgressPrinter(filesize, stdout)
with fd_to_send as f: with fd_to_send as f:
sent = 0 sent = 0
p.start() p.start()
@ -94,10 +95,10 @@ def _send_file_blocking(w, appid, them_phase1, fd_to_send, transit_sender):
p.update(sent) p.update(sent)
p.finish() p.finish()
print("File sent.. waiting for confirmation") print(u"File sent.. waiting for confirmation", file=stdout)
ack = record_pipe.receive_record() ack = record_pipe.receive_record()
record_pipe.close() record_pipe.close()
if ack == b"ok\n": if ack == b"ok\n":
print("Confirmation received. Transfer complete.") print(u"Confirmation received. Transfer complete.", file=stdout)
return 0 return 0
raise TransferError("Transfer failed (remote says: %r)" % ack) raise TransferError("Transfer failed (remote says: %r)" % ack)

View File

@ -1,5 +1,5 @@
from __future__ import print_function from __future__ import print_function
import sys, argparse import os, sys, argparse
from textwrap import dedent from textwrap import dedent
from .. import public_relay from .. import public_relay
from .. import __version__ from .. import __version__
@ -120,7 +120,7 @@ p.set_defaults(func=cmd_receive.receive)
def run(args, stdout, stderr, executable=None): def run(args, cwd, stdout, stderr, executable=None):
"""This is invoked directly by the 'wormhole' entry-point script. It can """This is invoked directly by the 'wormhole' entry-point script. It can
also invoked by entry() below.""" also invoked by entry() below."""
@ -130,6 +130,9 @@ def run(args, stdout, stderr, executable=None):
# "error: too few arguments" during parse_args(). # "error: too few arguments" during parse_args().
parser.print_help() parser.print_help()
sys.exit(0) sys.exit(0)
args.cwd = cwd
args.stdout = stdout
args.stderr = stderr
try: try:
#rc = command.func(args, stdout, stderr) #rc = command.func(args, stdout, stderr)
rc = args.func(args) rc = args.func(args)
@ -147,7 +150,8 @@ def run(args, stdout, stderr, executable=None):
def entry(): def entry():
"""This is used by a setuptools entry_point. When invoked this way, """This is used by a setuptools entry_point. When invoked this way,
setuptools has already put the installed package on sys.path .""" setuptools has already put the installed package on sys.path ."""
return run(sys.argv[1:], sys.stdout, sys.stderr, executable=sys.argv[0]) return run(sys.argv[1:], os.getcwd(), sys.stdout, sys.stderr,
executable=sys.argv[0])
if __name__ == "__main__": if __name__ == "__main__":
args = parser.parse_args() args = parser.parse_args()