diff --git a/src/wormhole/scripts/cmd_receive.py b/src/wormhole/scripts/cmd_receive.py index 2936638..e46f5c8 100644 --- a/src/wormhole/scripts/cmd_receive.py +++ b/src/wormhole/scripts/cmd_receive.py @@ -1,88 +1,46 @@ from __future__ import print_function -import os, sys, json, binascii, six +import os, sys, json, binascii, six, tempfile, zipfile from ..errors import handle_server_error APPID = u"lothar.com/wormhole/text-or-file-xfer" -@handle_server_error -def receive(args): - # we're receiving text, or a file - from ..blocking.transcribe import Wormhole, WrongPasswordError +def accept_file(args, them_d, w): from ..blocking.transit import TransitReceiver, TransitError from .progress import start_progress, update_progress, finish_progress - assert isinstance(args.relay_url, type(u"")) - with Wormhole(APPID, args.relay_url) as w: - if args.zeromode: - assert not args.code - args.code = u"0-" - code = args.code - if not code: - code = w.input_code("Enter receive wormhole code: ", args.code_length) - w.set_code(code) + file_data = them_d["file"] + # the basename() is intended to protect us against + # "~/.ssh/authorized_keys" and other attacks + filename = os.path.basename(file_data["filename"]) # unicode + filesize = file_data["filesize"] - if args.verify: - verifier = binascii.hexlify(w.get_verifier()).decode("ascii") - print("Verifier %s." % verifier) - - try: - them_bytes = w.get_data() - except WrongPasswordError as e: - print("ERROR: " + e.explain(), file=sys.stderr) - return 1 - them_d = json.loads(them_bytes.decode("utf-8")) - if "error" in them_d: - print("ERROR: " + them_d["error"], file=sys.stderr) - return 1 - - if "message" in them_d: - # we're receiving a text message - print(them_d["message"]) - data = json.dumps({"message_ack": "ok"}).encode("utf-8") - w.send_data(data) - return 0 - - if not "file" in them_d: - print("I don't know what they're offering\n") - print(them_d) - return 1 - - if "error" in them_d: - print("ERROR: " + data["error"], file=sys.stderr) - return 1 - - file_data = them_d["file"] - # the basename() is intended to protect us against - # "~/.ssh/authorized_keys" and other attacks - filename = os.path.basename(file_data["filename"]) # unicode - filesize = file_data["filesize"] - - # get confirmation from the user before writing to the local directory - if os.path.exists(filename): - print("Error: refusing to overwrite existing file %s" % (filename,)) - data = json.dumps({"error": "file already exists"}).encode("utf-8") - w.send_data(data) - return 1 - - print("Receiving file (%d bytes) into: %s" % (filesize, filename)) - while True and not args.accept_file: - ok = six.moves.input("ok? (y/n): ") - if ok.lower().startswith("y"): - break - print("transfer rejected", file=sys.stderr) - data = json.dumps({"error": "transfer rejected"}).encode("utf-8") - w.send_data(data) - return 1 - - transit_receiver = TransitReceiver(args.transit_helper) - data = json.dumps({ - "file_ack": "ok", - "transit": { - "direct_connection_hints": transit_receiver.get_direct_hints(), - "relay_connection_hints": transit_receiver.get_relay_hints(), - }, - }).encode("utf-8") + # get confirmation from the user before writing to the local directory + if os.path.exists(filename): + print("Error: refusing to overwrite existing file %s" % (filename,)) + data = json.dumps({"error": "file already exists"}).encode("utf-8") w.send_data(data) + return 1 + + print("Receiving file (%d bytes) into: %s" % (filesize, filename)) + while True and not args.accept_file: + ok = six.moves.input("ok? (y/n): ") + if ok.lower().startswith("y"): + break + print("transfer rejected", file=sys.stderr) + data = json.dumps({"error": "transfer rejected"}).encode("utf-8") + w.send_data(data) + return 1 + + transit_receiver = TransitReceiver(args.transit_helper) + data = json.dumps({ + "file_ack": "ok", + "transit": { + "direct_connection_hints": transit_receiver.get_direct_hints(), + "relay_connection_hints": transit_receiver.get_relay_hints(), + }, + }).encode("utf-8") + w.send_data(data) + # now done with the Wormhole object # now receive the rest of the owl tdata = them_d["transit"] @@ -118,3 +76,144 @@ def receive(args): record_pipe.send_record(b"ok\n") record_pipe.close() return 0 + +def accept_directory(args, them_d, w): + from ..blocking.transit import TransitReceiver, TransitError + from .progress import start_progress, update_progress, finish_progress + + file_data = them_d["directory"] + mode = file_data["mode"] + if mode != "zipfile/deflated": + print("Error: unknown directory-transfer mode '%s'" % (mode,)) + data = json.dumps({"error": "unknown mode"}).encode("utf-8") + w.send_data(data) + return 1 + + if args.output_file: + dirname = args.output_file + else: + # the basename() is intended to protect us against + # "~/.ssh/authorized_keys" and other attacks + dirname = os.path.basename(file_data["dirname"]) # unicode + filesize = file_data["zipsize"] + num_files = file_data["numfiles"] + num_bytes = file_data["numbytes"] + + if os.path.exists(dirname): + print("Error: refusing to overwrite existing directory %s" % (dirname,)) + data = json.dumps({"error": "directory already exists"}).encode("utf-8") + w.send_data(data) + return 1 + + print("Receiving directory into: %s/" % (dirname,)) + print("%d files, %d bytes (%d compressed)" % (num_files, num_bytes, + filesize)) + while True and not args.accept_file: + ok = six.moves.input("ok? (y/n): ") + if ok.lower().startswith("y"): + break + print("transfer rejected", file=sys.stderr) + data = json.dumps({"error": "transfer rejected"}).encode("utf-8") + w.send_data(data) + return 1 + + transit_receiver = TransitReceiver(args.transit_helper) + data = json.dumps({ + "file_ack": "ok", + "transit": { + "direct_connection_hints": transit_receiver.get_direct_hints(), + "relay_connection_hints": transit_receiver.get_relay_hints(), + }, + }).encode("utf-8") + w.send_data(data) + # now done with the Wormhole object + + # now receive the rest of the owl + tdata = them_d["transit"] + transit_key = w.derive_key(APPID+u"/transit-key") + transit_receiver.set_transit_key(transit_key) + transit_receiver.add_their_direct_hints(tdata["direct_connection_hints"]) + transit_receiver.add_their_relay_hints(tdata["relay_connection_hints"]) + record_pipe = transit_receiver.connect() + + print("Receiving %d bytes for '%s' (%s).." % (filesize, dirname, + transit_receiver.describe())) + f = tempfile.SpooledTemporaryFile() + received = 0 + next_update = start_progress(filesize) + while received < filesize: + try: + plaintext = record_pipe.receive_record() + except TransitError: + print() + print("Connection dropped before full file received") + print("got %d bytes, wanted %d" % (received, filesize)) + return 1 + f.write(plaintext) + received += len(plaintext) + next_update = update_progress(next_update, received, filesize) + finish_progress(filesize) + assert received == filesize + print("Unpacking zipfile..") + with zipfile.ZipFile(f, "r", zipfile.ZIP_DEFLATED) as zf: + zf.extractall(path=dirname) + # extractall() appears to offer some protection against malicious + # pathnames. For example, "/tmp/oops" and "../tmp/oops" both do the + # same thing as the (safe) "tmp/oops". + + print("Received files written to %s/" % dirname) + record_pipe.send_record(b"ok\n") + record_pipe.close() + return 0 + +@handle_server_error +def receive(args): + # we're receiving text, or a file + from ..blocking.transcribe import Wormhole, WrongPasswordError + assert isinstance(args.relay_url, type(u"")) + + with Wormhole(APPID, args.relay_url) as w: + if args.zeromode: + assert not args.code + args.code = u"0-" + code = args.code + if not code: + code = w.input_code("Enter receive wormhole code: ", args.code_length) + w.set_code(code) + + if args.verify: + verifier = binascii.hexlify(w.get_verifier()).decode("ascii") + print("Verifier %s." % verifier) + + try: + them_bytes = w.get_data() + except WrongPasswordError as e: + print("ERROR: " + e.explain(), file=sys.stderr) + return 1 + them_d = json.loads(them_bytes.decode("utf-8")) + if "error" in them_d: + print("ERROR: " + them_d["error"], file=sys.stderr) + return 1 + + if "message" in them_d: + # we're receiving a text message + print(them_d["message"]) + data = json.dumps({"message_ack": "ok"}).encode("utf-8") + w.send_data(data) + return 0 + + if "error" in them_d: + print("ERROR: " + data["error"], file=sys.stderr) + return 1 + + if "file" in them_d: + return accept_file(args, them_d, w) + + if "directory" in them_d: + return accept_directory(args, them_d, w) + + print("I don't know what they're offering\n") + print("Offer details:", them_d) + data = json.dumps({"error": "unknown offer type"}).encode("utf-8") + w.send_data(data) + return 1 diff --git a/src/wormhole/scripts/cmd_send.py b/src/wormhole/scripts/cmd_send.py index 9daf6cb..f026264 100644 --- a/src/wormhole/scripts/cmd_send.py +++ b/src/wormhole/scripts/cmd_send.py @@ -1,12 +1,12 @@ from __future__ import print_function -import os, sys, json, binascii, six +import os, sys, json, binascii, six, tempfile, zipfile from ..errors import handle_server_error APPID = u"lothar.com/wormhole/text-or-file-xfer" @handle_server_error def send(args): - # we're sending text, or a file + # we're sending text, or a file/directory from ..blocking.transcribe import Wormhole, WrongPasswordError from ..blocking.transit import TransitSender from .progress import start_progress, update_progress, finish_progress @@ -26,25 +26,60 @@ def send(args): "message": text, } else: - if not os.path.isfile(args.what): - print("Cannot send: no file named '%s'" % args.what) + if not os.path.exists(args.what): + print("Cannot send: no file/directory named '%s'" % args.what) return 1 - # we're sending a file sending_message = False - filesize = os.stat(args.what).st_size - basename = os.path.basename(args.what) - print("Sending %d byte file named '%s'" % (filesize, basename)) transit_sender = TransitSender(args.transit_helper) phase1 = { - "file": { - "filename": basename, - "filesize": filesize, - }, "transit": { "direct_connection_hints": transit_sender.get_direct_hints(), "relay_connection_hints": transit_sender.get_relay_hints(), }, } + basename = os.path.basename(args.what) + if os.path.isfile(args.what): + # we're sending a file + filesize = os.stat(args.what).st_size + phase1["file"] = { + "filename": basename, + "filesize": filesize, + } + print("Sending %d byte file named '%s'" % (filesize, basename)) + fd_to_send = open(args.what, "rb") + elif os.path.isdir(args.what): + print("Building zipfile..") + # We're sending a directory. Create a zipfile in a tempdir and + # send that. + fd_to_send = tempfile.SpooledTemporaryFile() + # TODO: I think ZIP_DEFLATED means compressed.. check it + num_files = 0 + num_bytes = 0 + tostrip = len(args.what.split(os.sep)) + with zipfile.ZipFile(fd_to_send, "w", zipfile.ZIP_DEFLATED) as zf: + for path,dirs,files in os.walk(args.what): + # path always starts with args.what, then sometimes might + # have "/subdir" appended. We want the zipfile to contain + # "" or "subdir" + localpath = list(path.split(os.sep)[tostrip:]) + for fn in files: + archivename = os.path.join(*tuple(localpath+[fn])) + localfilename = os.path.join(path, fn) + zf.write(localfilename, archivename) + num_bytes += os.stat(localfilename).st_size + num_files += 1 + fd_to_send.seek(0,2) + filesize = fd_to_send.tell() + fd_to_send.seek(0,0) + phase1["directory"] = { + "mode": "zipfile/deflated", + "dirname": basename, + "zipsize": filesize, + "numbytes": num_bytes, + "numfiles": num_files, + } + print("Sending directory (%d bytes compressed) named '%s'" + % (filesize, basename)) with Wormhole(APPID, args.relay_url) as w: if args.zeromode: @@ -114,7 +149,7 @@ def send(args): print("Sending (%s).." % transit_sender.describe()) CHUNKSIZE = 64*1024 - with open(args.what, "rb") as f: + with fd_to_send as f: sent = 0 next_update = start_progress(filesize) while sent < filesize: diff --git a/src/wormhole/scripts/runner.py b/src/wormhole/scripts/runner.py index d78205a..e4497b9 100644 --- a/src/wormhole/scripts/runner.py +++ b/src/wormhole/scripts/runner.py @@ -76,21 +76,21 @@ sp_tail_usage.set_defaults(func=cmd_usage.tail_usage) # CLI: send p = subparsers.add_parser("send", - description="Send text message or file", - usage="wormhole send [FILENAME]") + description="Send text message, file, or directory", + usage="wormhole send [FILENAME|DIRNAME]") p.add_argument("--text", metavar="MESSAGE", help="text message to send, instead of a file. Use '-' to read from stdin.") p.add_argument("--code", metavar="CODE", help="human-generated code phrase", type=type(u"")) p.add_argument("-0", dest="zeromode", action="store_true", help="enable no-code anything-goes mode") -p.add_argument("what", nargs="?", default=None, metavar="[FILENAME]", - help="the file to send") +p.add_argument("what", nargs="?", default=None, metavar="[FILENAME|DIRNAME]", + help="the file/directory to send") p.set_defaults(func=cmd_send.send) # CLI: receive p = subparsers.add_parser("receive", - description="Receive a text message or file", + description="Receive a text message, file, or directory", usage="wormhole receive [CODE]") p.add_argument("-0", dest="zeromode", action="store_true", help="enable no-code anything-goes mode") @@ -98,10 +98,10 @@ p.add_argument("-t", "--only-text", dest="only_text", action="store_true", help="refuse file transfers, only accept text transfers") p.add_argument("--accept-file", dest="accept_file", action="store_true", help="accept file transfer with asking for confirmation") -p.add_argument("-o", "--output-file", default=None, metavar="FILENAME", +p.add_argument("-o", "--output-file", default=None, metavar="FILENAME|DIRNAME", help=dedent("""\ - The file to create, overriding the filename suggested by the - sender."""), + The file or directory to create, overriding the name suggested + by the sender."""), ) p.add_argument("code", nargs="?", default=None, metavar="[CODE]", help=dedent("""\ diff --git a/src/wormhole/test/test_scripts.py b/src/wormhole/test/test_scripts.py index ce7ba50..fb71040 100644 --- a/src/wormhole/test/test_scripts.py +++ b/src/wormhole/test/test_scripts.py @@ -1,4 +1,4 @@ -import os, sys +import os, sys, re from twisted.trial import unittest from twisted.python import procutils, log from twisted.internet.utils import getProcessOutputAndValue @@ -177,3 +177,83 @@ class Scripts(ServerBase, ScriptsBase, unittest.TestCase): self.failUnlessEqual(f.read(), message) d1.addCallback(_check_receiver) return d1 + + def test_send_directory_pre_generated_code(self): + return self._do_test_send_directory_pre_generated_code(False) + def test_send_directory_pre_generated_code_override(self): + return self._do_test_send_directory_pre_generated_code(True) + + def _do_test_send_directory_pre_generated_code(self, override_dirname): + self.maxDiff=None + code = u"1-abc" + dirname = "testdir" + def message(i): + return "test message %d\n" % i + + source_parent_dir = self.mktemp() + os.mkdir(source_parent_dir) + os.mkdir(os.path.join(source_parent_dir, "middle")) + source_dir = os.path.join(source_parent_dir, "middle", dirname) + os.mkdir(source_dir) + for i in range(5): + with open(os.path.join(source_dir, str(i)), "w") as f: + f.write(message(i)) + + target_parent_dir = self.mktemp() + os.mkdir(target_parent_dir) + + wormhole = self.find_executable() + server_args = ["--relay-url", self.relayurl] + send_args = server_args + [ + "send", + "--code", code, + os.path.join("middle", dirname), + ] + + receive_args = server_args + [ + "receive", "--accept-file", + ] + if override_dirname: + receive_args.extend(["-o", "outdir"]) + dirname = "outdir" + receive_args.append(code) + + d1 = getProcessOutputAndValue(wormhole, send_args, + path=source_parent_dir) + + d2 = getProcessOutputAndValue(wormhole, receive_args, + path=target_parent_dir) + def _check_sender(res): + out, err, rc = res + out = out.decode("utf-8") + err = err.decode("utf-8") + self.failUnlessEqual(err, "") + self.failUnlessIn("Sending directory", out) + self.failUnlessIn("named 'testdir'", out) + self.failUnlessIn("On the other computer, please run: " + "wormhole receive\n" + "Wormhole code is: %s\n\n" % code, + out) + self.failUnlessIn("File sent.. waiting for confirmation\n" + "Confirmation received. Transfer complete.\n", + out) + self.failUnlessEqual(rc, 0) + return d2 + d1.addCallback(_check_sender) + def _check_receiver(res): + out, err, rc = res + out = out.decode("utf-8") + err = err.decode("utf-8") + self.failUnless(re.search(r"Receiving \d+ bytes for '%s'" % + dirname, out)) + self.failUnlessIn("Received files written to %s" % dirname, out) + self.failUnlessEqual(err, "") + self.failUnlessEqual(rc, 0) + fn = os.path.join(target_parent_dir, dirname) + self.failUnless(os.path.exists(fn)) + for i in range(5): + fn = os.path.join(target_parent_dir, dirname, str(i)) + with open(fn, "r") as f: + self.failUnlessEqual(f.read(), message(i)) + d1.addCallback(_check_receiver) + return d1