runner: strictly use cwd/stdout/stderr from 'args'
This will make it easier to test the scripts in a controlled fashion.
This commit is contained in:
parent
e2f3bebe38
commit
c5b2800a3e
|
@ -15,21 +15,24 @@ def accept_file(args, them_d, w):
|
|||
# the basename() is intended to protect us against
|
||||
# "~/.ssh/authorized_keys" and other attacks
|
||||
filename = os.path.basename(file_data["filename"]) # unicode
|
||||
abs_filename = os.path.join(args.cwd, filename)
|
||||
filesize = file_data["filesize"]
|
||||
|
||||
# get confirmation from the user before writing to the local directory
|
||||
if os.path.exists(filename):
|
||||
print("Error: refusing to overwrite existing file %s" % (filename,))
|
||||
if os.path.exists(abs_filename):
|
||||
print(u"Error: refusing to overwrite existing file %s" % (filename,),
|
||||
file=args.stdout)
|
||||
data = json.dumps({"error": "file already exists"}).encode("utf-8")
|
||||
w.send_data(data)
|
||||
return 1
|
||||
|
||||
print("Receiving file (%d bytes) into: %s" % (filesize, filename))
|
||||
print(u"Receiving file (%d bytes) into: %s" % (filesize, filename),
|
||||
file=args.stdout)
|
||||
while True and not args.accept_file:
|
||||
ok = six.moves.input("ok? (y/n): ")
|
||||
if ok.lower().startswith("y"):
|
||||
break
|
||||
print("transfer rejected", file=sys.stderr)
|
||||
print(u"transfer rejected", file=sys.stderr)
|
||||
data = json.dumps({"error": "transfer rejected"}).encode("utf-8")
|
||||
w.send_data(data)
|
||||
return 1
|
||||
|
@ -53,9 +56,9 @@ def accept_file(args, them_d, w):
|
|||
transit_receiver.add_their_relay_hints(tdata["relay_connection_hints"])
|
||||
record_pipe = transit_receiver.connect()
|
||||
|
||||
print("Receiving %d bytes for '%s' (%s).." % (filesize, filename,
|
||||
transit_receiver.describe()))
|
||||
tmp = filename + ".tmp"
|
||||
print(u"Receiving %d bytes for '%s' (%s).." %
|
||||
(filesize, filename, transit_receiver.describe()), file=args.stdout)
|
||||
tmp = abs_filename + ".tmp"
|
||||
with open(tmp, "wb") as f:
|
||||
received = 0
|
||||
p = ProgressPrinter(filesize, sys.stdout)
|
||||
|
@ -64,9 +67,11 @@ def accept_file(args, them_d, w):
|
|||
try:
|
||||
plaintext = record_pipe.receive_record()
|
||||
except TransitError:
|
||||
print()
|
||||
print("Connection dropped before full file received")
|
||||
print("got %d bytes, wanted %d" % (received, filesize))
|
||||
print(u"", file=args.stdout)
|
||||
print(u"Connection dropped before full file received",
|
||||
file=args.stdout)
|
||||
print(u"got %d bytes, wanted %d" % (received, filesize),
|
||||
file=args.stdout)
|
||||
return 1
|
||||
f.write(plaintext)
|
||||
received += len(plaintext)
|
||||
|
@ -74,9 +79,9 @@ def accept_file(args, them_d, w):
|
|||
p.finish()
|
||||
assert received == filesize
|
||||
|
||||
os.rename(tmp, filename)
|
||||
os.rename(tmp, abs_filename)
|
||||
|
||||
print("Received file written to %s" % filename)
|
||||
print(u"Received file written to %s" % filename, file=args.stdout)
|
||||
record_pipe.send_record(b"ok\n")
|
||||
record_pipe.close()
|
||||
return 0
|
||||
|
@ -87,7 +92,8 @@ def accept_directory(args, them_d, w):
|
|||
file_data = them_d["directory"]
|
||||
mode = file_data["mode"]
|
||||
if mode != "zipfile/deflated":
|
||||
print("Error: unknown directory-transfer mode '%s'" % (mode,))
|
||||
print(u"Error: unknown directory-transfer mode '%s'" % (mode,),
|
||||
file=args.stdout)
|
||||
data = json.dumps({"error": "unknown mode"}).encode("utf-8")
|
||||
w.send_data(data)
|
||||
return 1
|
||||
|
@ -98,24 +104,26 @@ def accept_directory(args, them_d, w):
|
|||
# the basename() is intended to protect us against
|
||||
# "~/.ssh/authorized_keys" and other attacks
|
||||
dirname = os.path.basename(file_data["dirname"]) # unicode
|
||||
abs_dirname = os.path.join(args.cwd, dirname)
|
||||
filesize = file_data["zipsize"]
|
||||
num_files = file_data["numfiles"]
|
||||
num_bytes = file_data["numbytes"]
|
||||
|
||||
if os.path.exists(dirname):
|
||||
print("Error: refusing to overwrite existing directory %s" % (dirname,))
|
||||
if os.path.exists(abs_dirname):
|
||||
print(u"Error: refusing to overwrite existing directory %s" %
|
||||
(dirname,), file=args.stdout)
|
||||
data = json.dumps({"error": "directory already exists"}).encode("utf-8")
|
||||
w.send_data(data)
|
||||
return 1
|
||||
|
||||
print("Receiving directory into: %s/" % (dirname,))
|
||||
print("%d files, %d bytes (%d compressed)" % (num_files, num_bytes,
|
||||
filesize))
|
||||
print(u"Receiving directory into: %s/" % (dirname,), file=args.stdout)
|
||||
print(u"%d files, %d bytes (%d compressed)" %
|
||||
(num_files, num_bytes, filesize), file=args.stdout)
|
||||
while True and not args.accept_file:
|
||||
ok = six.moves.input("ok? (y/n): ")
|
||||
if ok.lower().startswith("y"):
|
||||
break
|
||||
print("transfer rejected", file=sys.stderr)
|
||||
print(u"transfer rejected", file=sys.stderr)
|
||||
data = json.dumps({"error": "transfer rejected"}).encode("utf-8")
|
||||
w.send_data(data)
|
||||
return 1
|
||||
|
@ -139,8 +147,8 @@ def accept_directory(args, them_d, w):
|
|||
transit_receiver.add_their_relay_hints(tdata["relay_connection_hints"])
|
||||
record_pipe = transit_receiver.connect()
|
||||
|
||||
print("Receiving %d bytes for '%s' (%s).." % (filesize, dirname,
|
||||
transit_receiver.describe()))
|
||||
print(u"Receiving %d bytes for '%s' (%s).." %
|
||||
(filesize, dirname, transit_receiver.describe()), file=args.stdout)
|
||||
f = tempfile.SpooledTemporaryFile()
|
||||
received = 0
|
||||
p = ProgressPrinter(filesize, sys.stdout)
|
||||
|
@ -149,23 +157,25 @@ def accept_directory(args, them_d, w):
|
|||
try:
|
||||
plaintext = record_pipe.receive_record()
|
||||
except TransitError:
|
||||
print()
|
||||
print("Connection dropped before full file received")
|
||||
print("got %d bytes, wanted %d" % (received, filesize))
|
||||
print(u"", file=args.stdout)
|
||||
print(u"Connection dropped before full file received",
|
||||
file=args.stdout)
|
||||
print(u"got %d bytes, wanted %d" % (received, filesize),
|
||||
file=args.stdout)
|
||||
return 1
|
||||
f.write(plaintext)
|
||||
received += len(plaintext)
|
||||
p.update(received)
|
||||
p.finish()
|
||||
assert received == filesize
|
||||
print("Unpacking zipfile..")
|
||||
print(u"Unpacking zipfile..", file=args.stdout)
|
||||
with zipfile.ZipFile(f, "r", zipfile.ZIP_DEFLATED) as zf:
|
||||
zf.extractall(path=dirname)
|
||||
zf.extractall(path=abs_dirname)
|
||||
# extractall() appears to offer some protection against malicious
|
||||
# pathnames. For example, "/tmp/oops" and "../tmp/oops" both do the
|
||||
# same thing as the (safe) "tmp/oops".
|
||||
|
||||
print("Received files written to %s/" % dirname)
|
||||
print(u"Received files written to %s/" % dirname, file=args.stdout)
|
||||
record_pipe.send_record(b"ok\n")
|
||||
record_pipe.close()
|
||||
return 0
|
||||
|
@ -187,27 +197,27 @@ def receive(args):
|
|||
|
||||
if args.verify:
|
||||
verifier = binascii.hexlify(w.get_verifier()).decode("ascii")
|
||||
print("Verifier %s." % verifier)
|
||||
print(u"Verifier %s." % verifier, file=args.stdout)
|
||||
|
||||
try:
|
||||
them_bytes = w.get_data()
|
||||
except WrongPasswordError as e:
|
||||
print("ERROR: " + e.explain(), file=sys.stderr)
|
||||
print(u"ERROR: " + e.explain(), file=sys.stderr)
|
||||
return 1
|
||||
them_d = json.loads(them_bytes.decode("utf-8"))
|
||||
if "error" in them_d:
|
||||
print("ERROR: " + them_d["error"], file=sys.stderr)
|
||||
print(u"ERROR: " + them_d["error"], file=sys.stderr)
|
||||
return 1
|
||||
|
||||
if "message" in them_d:
|
||||
# we're receiving a text message
|
||||
print(them_d["message"])
|
||||
print(them_d["message"], file=args.stdout)
|
||||
data = json.dumps({"message_ack": "ok"}).encode("utf-8")
|
||||
w.send_data(data)
|
||||
return 0
|
||||
|
||||
if "error" in them_d:
|
||||
print("ERROR: " + data["error"], file=sys.stderr)
|
||||
print(u"ERROR: " + data["error"], file=sys.stderr)
|
||||
return 1
|
||||
|
||||
if "file" in them_d:
|
||||
|
@ -216,8 +226,8 @@ def receive(args):
|
|||
if "directory" in them_d:
|
||||
return accept_directory(args, them_d, w)
|
||||
|
||||
print("I don't know what they're offering\n")
|
||||
print("Offer details:", them_d)
|
||||
print(u"I don't know what they're offering\n", file=args.stdout)
|
||||
print(u"Offer details:", them_d, file=args.stdout)
|
||||
data = json.dumps({"error": "unknown offer type"}).encode("utf-8")
|
||||
w.send_data(data)
|
||||
return 1
|
||||
|
|
|
@ -11,17 +11,18 @@ def send(args):
|
|||
|
||||
text = args.text
|
||||
if text == "-":
|
||||
print("Reading text message from stdin..")
|
||||
print(u"Reading text message from stdin..", file=args.stdout)
|
||||
text = sys.stdin.read()
|
||||
if not text and not args.what:
|
||||
text = six.moves.input("Text to send: ")
|
||||
|
||||
if text is not None:
|
||||
print("Sending text message (%d bytes)" % len(text))
|
||||
print(u"Sending text message (%d bytes)" % len(text), file=args.stdout)
|
||||
phase1 = { "message": text }
|
||||
fd_to_send = None
|
||||
else:
|
||||
if not os.path.exists(args.what):
|
||||
what = os.path.join(args.cwd, args.what)
|
||||
if not os.path.exists(what):
|
||||
raise TransferError("Cannot send: no file/directory named '%s'" %
|
||||
args.what)
|
||||
phase1, fd_to_send = _build_phase1_data(args)
|
||||
|
@ -36,7 +37,8 @@ def send(args):
|
|||
other_cmd = "wormhole --verify receive"
|
||||
if args.zeromode:
|
||||
other_cmd += " -0"
|
||||
print("On the other computer, please run: %s" % other_cmd)
|
||||
print(u"On the other computer, please run: %s" % other_cmd,
|
||||
file=args.stdout)
|
||||
|
||||
from .cmd_send_blocking import send_blocking
|
||||
rc = send_blocking(APPID, args, phase1, fd_to_send)
|
||||
|
@ -44,27 +46,29 @@ def send(args):
|
|||
|
||||
def _build_phase1_data(args):
|
||||
phase1 = {}
|
||||
basename = os.path.basename(args.what)
|
||||
if os.path.isfile(args.what):
|
||||
what = os.path.join(args.cwd, args.what)
|
||||
basename = os.path.basename(what)
|
||||
if os.path.isfile(what):
|
||||
# we're sending a file
|
||||
filesize = os.stat(args.what).st_size
|
||||
filesize = os.stat(what).st_size
|
||||
phase1["file"] = {
|
||||
"filename": basename,
|
||||
"filesize": filesize,
|
||||
}
|
||||
print("Sending %d byte file named '%s'" % (filesize, basename))
|
||||
fd_to_send = open(args.what, "rb")
|
||||
elif os.path.isdir(args.what):
|
||||
print("Building zipfile..")
|
||||
print(u"Sending %d byte file named '%s'" % (filesize, basename),
|
||||
file=args.stdout)
|
||||
fd_to_send = open(what, "rb")
|
||||
elif os.path.isdir(what):
|
||||
print(u"Building zipfile..", file=args.stdout)
|
||||
# We're sending a directory. Create a zipfile in a tempdir and
|
||||
# send that.
|
||||
fd_to_send = tempfile.SpooledTemporaryFile()
|
||||
# TODO: I think ZIP_DEFLATED means compressed.. check it
|
||||
num_files = 0
|
||||
num_bytes = 0
|
||||
tostrip = len(args.what.split(os.sep))
|
||||
tostrip = len(what.split(os.sep))
|
||||
with zipfile.ZipFile(fd_to_send, "w", zipfile.ZIP_DEFLATED) as zf:
|
||||
for path,dirs,files in os.walk(args.what):
|
||||
for path,dirs,files in os.walk(what):
|
||||
# path always starts with args.what, then sometimes might
|
||||
# have "/subdir" appended. We want the zipfile to contain
|
||||
# "" or "subdir"
|
||||
|
@ -85,8 +89,8 @@ def _build_phase1_data(args):
|
|||
"numbytes": num_bytes,
|
||||
"numfiles": num_files,
|
||||
}
|
||||
print("Sending directory (%d bytes compressed) named '%s'"
|
||||
% (filesize, basename))
|
||||
print(u"Sending directory (%d bytes compressed) named '%s'"
|
||||
% (filesize, basename), file=args.stdout)
|
||||
else:
|
||||
raise TypeError("'%s' is neither file nor directory" % args.what)
|
||||
raise TypeError("'%s' is neither file nor directory" % what)
|
||||
return phase1, fd_to_send
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
from __future__ import print_function
|
||||
import sys, json, binascii, six
|
||||
import json, binascii, six
|
||||
from ..errors import TransferError
|
||||
from .progress import ProgressPrinter
|
||||
|
||||
|
@ -21,8 +21,8 @@ def send_blocking(appid, args, phase1, fd_to_send):
|
|||
else:
|
||||
code = w.get_code(args.code_length)
|
||||
if not args.zeromode:
|
||||
print("Wormhole code is: %s" % code)
|
||||
print("")
|
||||
print(u"Wormhole code is: %s" % code, file=args.stdout)
|
||||
print(u"", file=args.stdout)
|
||||
|
||||
if args.verify:
|
||||
_do_verify(w)
|
||||
|
@ -40,12 +40,12 @@ def send_blocking(appid, args, phase1, fd_to_send):
|
|||
|
||||
if fd_to_send is None:
|
||||
if them_phase1["message_ack"] == "ok":
|
||||
print("text message sent")
|
||||
print(u"text message sent", file=args.stdout)
|
||||
return 0
|
||||
raise TransferError("error sending text: %r" % (them_phase1,))
|
||||
|
||||
return _send_file_blocking(w, appid, them_phase1, fd_to_send,
|
||||
transit_sender)
|
||||
transit_sender, args.stdout)
|
||||
|
||||
def _do_verify(w):
|
||||
verifier = binascii.hexlify(w.get_verifier()).decode("ascii")
|
||||
|
@ -59,7 +59,8 @@ def _do_verify(w):
|
|||
w.send_data(reject_data)
|
||||
raise TransferError("verification rejected, abandoning transfer")
|
||||
|
||||
def _send_file_blocking(w, appid, them_phase1, fd_to_send, transit_sender):
|
||||
def _send_file_blocking(w, appid, them_phase1, fd_to_send, transit_sender,
|
||||
stdout):
|
||||
|
||||
# we're sending a file, if they accept it
|
||||
|
||||
|
@ -77,13 +78,13 @@ def _send_file_blocking(w, appid, them_phase1, fd_to_send, transit_sender):
|
|||
transit_sender.add_their_relay_hints(tdata["relay_connection_hints"])
|
||||
record_pipe = transit_sender.connect()
|
||||
|
||||
print("Sending (%s).." % transit_sender.describe())
|
||||
print(u"Sending (%s).." % transit_sender.describe(), file=stdout)
|
||||
|
||||
CHUNKSIZE = 64*1024
|
||||
fd_to_send.seek(0,2)
|
||||
filesize = fd_to_send.tell()
|
||||
fd_to_send.seek(0,0)
|
||||
p = ProgressPrinter(filesize, sys.stdout)
|
||||
p = ProgressPrinter(filesize, stdout)
|
||||
with fd_to_send as f:
|
||||
sent = 0
|
||||
p.start()
|
||||
|
@ -94,10 +95,10 @@ def _send_file_blocking(w, appid, them_phase1, fd_to_send, transit_sender):
|
|||
p.update(sent)
|
||||
p.finish()
|
||||
|
||||
print("File sent.. waiting for confirmation")
|
||||
print(u"File sent.. waiting for confirmation", file=stdout)
|
||||
ack = record_pipe.receive_record()
|
||||
record_pipe.close()
|
||||
if ack == b"ok\n":
|
||||
print("Confirmation received. Transfer complete.")
|
||||
print(u"Confirmation received. Transfer complete.", file=stdout)
|
||||
return 0
|
||||
raise TransferError("Transfer failed (remote says: %r)" % ack)
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
from __future__ import print_function
|
||||
import sys, argparse
|
||||
import os, sys, argparse
|
||||
from textwrap import dedent
|
||||
from .. import public_relay
|
||||
from .. import __version__
|
||||
|
@ -120,7 +120,7 @@ p.set_defaults(func=cmd_receive.receive)
|
|||
|
||||
|
||||
|
||||
def run(args, stdout, stderr, executable=None):
|
||||
def run(args, cwd, stdout, stderr, executable=None):
|
||||
"""This is invoked directly by the 'wormhole' entry-point script. It can
|
||||
also invoked by entry() below."""
|
||||
|
||||
|
@ -130,6 +130,9 @@ def run(args, stdout, stderr, executable=None):
|
|||
# "error: too few arguments" during parse_args().
|
||||
parser.print_help()
|
||||
sys.exit(0)
|
||||
args.cwd = cwd
|
||||
args.stdout = stdout
|
||||
args.stderr = stderr
|
||||
try:
|
||||
#rc = command.func(args, stdout, stderr)
|
||||
rc = args.func(args)
|
||||
|
@ -147,7 +150,8 @@ def run(args, stdout, stderr, executable=None):
|
|||
def entry():
|
||||
"""This is used by a setuptools entry_point. When invoked this way,
|
||||
setuptools has already put the installed package on sys.path ."""
|
||||
return run(sys.argv[1:], sys.stdout, sys.stderr, executable=sys.argv[0])
|
||||
return run(sys.argv[1:], os.getcwd(), sys.stdout, sys.stderr,
|
||||
executable=sys.argv[0])
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parser.parse_args()
|
||||
|
|
Loading…
Reference in New Issue
Block a user