diff --git a/src/wormhole/cli/cmd_receive.py b/src/wormhole/cli/cmd_receive.py index fb43412..c262b9a 100644 --- a/src/wormhole/cli/cmd_receive.py +++ b/src/wormhole/cli/cmd_receive.py @@ -1,5 +1,5 @@ from __future__ import print_function -import os, sys, six, tempfile, zipfile, hashlib +import os, sys, six, tempfile, zipfile, hashlib, shutil from tqdm import tqdm from humanize import naturalsize from twisted.internet import reactor @@ -257,15 +257,26 @@ class TwistedReceiver: # get confirmation from the user before writing to the local directory if os.path.exists(abs_destname): - self._msg(u"Error: refusing to overwrite existing '%s'" % destname) - raise TransferRejectedError() + if self.args.output_file: # overwrite is intentional + self._msg(u"Overwriting '%s'" % destname) + if self.args.accept_file: + self._remove_existing(abs_destname) + else: + self._msg(u"Error: refusing to overwrite existing '%s'" % destname) + raise TransferRejectedError() return abs_destname + def _remove_existing(self, path): + if os.path.isfile(path): os.remove(path) + if os.path.isdir(path): shutil.rmtree(path) + def _ask_permission(self): with self.args.timing.add("permission", waiting="user") as t: while True and not self.args.accept_file: ok = six.moves.input("ok? (y/n): ") if ok.lower().startswith("y"): + if os.path.exists(self.abs_destname): + self._remove_existing(self.abs_destname) break print(u"transfer rejected", file=sys.stderr) t.detail(answer="no") diff --git a/src/wormhole/test/test_scripts.py b/src/wormhole/test/test_scripts.py index c5e2ede..49562e0 100644 --- a/src/wormhole/test/test_scripts.py +++ b/src/wormhole/test/test_scripts.py @@ -238,7 +238,7 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase): @inlineCallbacks def _do_test(self, as_subprocess=False, mode="text", addslash=False, override_filename=False, - fake_tor=False): + fake_tor=False, overwrite=False, mock_accept=False): assert mode in ("text", "file", "empty-file", "directory", "slow-text") if fake_tor: assert not as_subprocess @@ -272,9 +272,14 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase): send_cfg.what = send_filename receive_filename = send_filename - recv_cfg.accept_file = True + recv_cfg.accept_file = False if mock_accept else True if override_filename: recv_cfg.output_file = receive_filename = "outfile" + if overwrite: + recv_cfg.output_file = receive_filename + existing_file = os.path.join(receive_dir, receive_filename) + with open(existing_file, 'w') as f: + f.write('pls overwrite me') elif mode == "directory": # $send_dir/ @@ -305,9 +310,12 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase): send_cfg.what = send_dirname_arg receive_dirname = send_dirname - recv_cfg.accept_file = True + recv_cfg.accept_file = False if mock_accept else True if override_filename: recv_cfg.output_file = receive_dirname = "outdir" + if overwrite: + recv_cfg.output_file = receive_dirname + os.mkdir(os.path.join(receive_dir, receive_dirname)) if as_subprocess: wormhole_bin = self.find_executable() @@ -387,6 +395,9 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase): with mock.patch.object(cmd_send, "VERIFY_TIMER", 0), \ mock.patch.object(cmd_receive, "VERIFY_TIMER", 0): yield gatherResults([send_d, receive_d], True) + elif mock_accept: + with mock.patch.object(cmd_receive.six.moves, 'input', return_value='y'): + yield gatherResults([send_d, receive_d], True) else: yield gatherResults([send_d, receive_d], True) @@ -503,6 +514,10 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase): return self._do_test(mode="file") def test_file_override(self): return self._do_test(mode="file", override_filename=True) + def test_file_overwrite(self): + return self._do_test(mode="file", overwrite=True) + def test_file_overwrite_mock_accept(self): + return self._do_test(mode="file", overwrite=True, mock_accept=True) def test_file_tor(self): return self._do_test(mode="file", fake_tor=True) def test_empty_file(self): @@ -514,6 +529,10 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase): return self._do_test(mode="directory", addslash=True) def test_directory_override(self): return self._do_test(mode="directory", override_filename=True) + def test_directory_overwrite(self): + return self._do_test(mode="directory", overwrite=True) + def test_directory_overwrite_mock_accept(self): + return self._do_test(mode="directory", overwrite=True, mock_accept=True) def test_slow_text(self): return self._do_test(mode="slow-text")