diff --git a/src/wormhole/scripts/cmd_receive.py b/src/wormhole/scripts/cmd_receive.py index 79dbb24..e46f5c8 100644 --- a/src/wormhole/scripts/cmd_receive.py +++ b/src/wormhole/scripts/cmd_receive.py @@ -89,9 +89,12 @@ def accept_directory(args, them_d, w): w.send_data(data) return 1 - # the basename() is intended to protect us against - # "~/.ssh/authorized_keys" and other attacks - dirname = os.path.basename(file_data["dirname"]) # unicode + if args.output_file: + dirname = args.output_file + else: + # the basename() is intended to protect us against + # "~/.ssh/authorized_keys" and other attacks + dirname = os.path.basename(file_data["dirname"]) # unicode filesize = file_data["zipsize"] num_files = file_data["numfiles"] num_bytes = file_data["numbytes"] @@ -154,6 +157,9 @@ def accept_directory(args, them_d, w): print("Unpacking zipfile..") with zipfile.ZipFile(f, "r", zipfile.ZIP_DEFLATED) as zf: zf.extractall(path=dirname) + # extractall() appears to offer some protection against malicious + # pathnames. For example, "/tmp/oops" and "../tmp/oops" both do the + # same thing as the (safe) "tmp/oops". print("Received files written to %s/" % dirname) record_pipe.send_record(b"ok\n") diff --git a/src/wormhole/test/test_scripts.py b/src/wormhole/test/test_scripts.py index ce7ba50..fb71040 100644 --- a/src/wormhole/test/test_scripts.py +++ b/src/wormhole/test/test_scripts.py @@ -1,4 +1,4 @@ -import os, sys +import os, sys, re from twisted.trial import unittest from twisted.python import procutils, log from twisted.internet.utils import getProcessOutputAndValue @@ -177,3 +177,83 @@ class Scripts(ServerBase, ScriptsBase, unittest.TestCase): self.failUnlessEqual(f.read(), message) d1.addCallback(_check_receiver) return d1 + + def test_send_directory_pre_generated_code(self): + return self._do_test_send_directory_pre_generated_code(False) + def test_send_directory_pre_generated_code_override(self): + return self._do_test_send_directory_pre_generated_code(True) + + def _do_test_send_directory_pre_generated_code(self, override_dirname): + self.maxDiff=None + code = u"1-abc" + dirname = "testdir" + def message(i): + return "test message %d\n" % i + + source_parent_dir = self.mktemp() + os.mkdir(source_parent_dir) + os.mkdir(os.path.join(source_parent_dir, "middle")) + source_dir = os.path.join(source_parent_dir, "middle", dirname) + os.mkdir(source_dir) + for i in range(5): + with open(os.path.join(source_dir, str(i)), "w") as f: + f.write(message(i)) + + target_parent_dir = self.mktemp() + os.mkdir(target_parent_dir) + + wormhole = self.find_executable() + server_args = ["--relay-url", self.relayurl] + send_args = server_args + [ + "send", + "--code", code, + os.path.join("middle", dirname), + ] + + receive_args = server_args + [ + "receive", "--accept-file", + ] + if override_dirname: + receive_args.extend(["-o", "outdir"]) + dirname = "outdir" + receive_args.append(code) + + d1 = getProcessOutputAndValue(wormhole, send_args, + path=source_parent_dir) + + d2 = getProcessOutputAndValue(wormhole, receive_args, + path=target_parent_dir) + def _check_sender(res): + out, err, rc = res + out = out.decode("utf-8") + err = err.decode("utf-8") + self.failUnlessEqual(err, "") + self.failUnlessIn("Sending directory", out) + self.failUnlessIn("named 'testdir'", out) + self.failUnlessIn("On the other computer, please run: " + "wormhole receive\n" + "Wormhole code is: %s\n\n" % code, + out) + self.failUnlessIn("File sent.. waiting for confirmation\n" + "Confirmation received. Transfer complete.\n", + out) + self.failUnlessEqual(rc, 0) + return d2 + d1.addCallback(_check_sender) + def _check_receiver(res): + out, err, rc = res + out = out.decode("utf-8") + err = err.decode("utf-8") + self.failUnless(re.search(r"Receiving \d+ bytes for '%s'" % + dirname, out)) + self.failUnlessIn("Received files written to %s" % dirname, out) + self.failUnlessEqual(err, "") + self.failUnlessEqual(rc, 0) + fn = os.path.join(target_parent_dir, dirname) + self.failUnless(os.path.exists(fn)) + for i in range(5): + fn = os.path.join(target_parent_dir, dirname, str(i)) + with open(fn, "r") as f: + self.failUnlessEqual(f.read(), message(i)) + d1.addCallback(_check_receiver) + return d1