diff --git a/src/wormhole/scripts/cmd_receive.py b/src/wormhole/scripts/cmd_receive.py index e46f5c8..47bf10a 100644 --- a/src/wormhole/scripts/cmd_receive.py +++ b/src/wormhole/scripts/cmd_receive.py @@ -9,9 +9,12 @@ def accept_file(args, them_d, w): from .progress import start_progress, update_progress, finish_progress file_data = them_d["file"] - # the basename() is intended to protect us against - # "~/.ssh/authorized_keys" and other attacks - filename = os.path.basename(file_data["filename"]) # unicode + if args.output_file: + filename = args.output_file + else: + # the basename() is intended to protect us against + # "~/.ssh/authorized_keys" and other attacks + filename = os.path.basename(file_data["filename"]) # unicode filesize = file_data["filesize"] # get confirmation from the user before writing to the local directory diff --git a/src/wormhole/test/test_scripts.py b/src/wormhole/test/test_scripts.py index fb71040..74c4106 100644 --- a/src/wormhole/test/test_scripts.py +++ b/src/wormhole/test/test_scripts.py @@ -119,6 +119,11 @@ class Scripts(ServerBase, ScriptsBase, unittest.TestCase): return d1 def test_send_file_pre_generated_code(self): + return self._do_test_send_file_pre_generated_code(False) + def test_send_file_pre_generated_code_override(self): + return self._do_test_send_file_pre_generated_code(True) + + def _do_test_send_file_pre_generated_code(self, override_filename): self.maxDiff=None code = u"1-abc" filename = "testfile" @@ -136,22 +141,26 @@ class Scripts(ServerBase, ScriptsBase, unittest.TestCase): "--code", code, filename, ] - d1 = getProcessOutputAndValue(wormhole, send_args, path=send_dir) receive_dir = self.mktemp() os.mkdir(receive_dir) receive_args = server_args + [ "receive", "--accept-file", - code, ] + if override_filename: + receive_args.extend(["-o", "outfile"]) + filename = "outfile" + receive_args.append(code) + + d1 = getProcessOutputAndValue(wormhole, send_args, path=send_dir) d2 = getProcessOutputAndValue(wormhole, receive_args, path=receive_dir) def _check_sender(res): out, err, rc = res out = out.decode("utf-8") err = err.decode("utf-8") self.failUnlessEqual(err, "") - self.failUnlessIn("Sending %d byte file named '%s'\n" % - (len(message), filename), out) + self.failUnlessIn("Sending %d byte file named 'testfile'\n" % + len(message), out) self.failUnlessIn("On the other computer, please run: " "wormhole receive\n" "Wormhole code is: %s\n\n" % code,