diff --git a/src/wormhole/cli/cmd_receive.py b/src/wormhole/cli/cmd_receive.py index c6dacbc..c262b9a 100644 --- a/src/wormhole/cli/cmd_receive.py +++ b/src/wormhole/cli/cmd_receive.py @@ -267,10 +267,8 @@ class TwistedReceiver: return abs_destname def _remove_existing(self, path): - if os.path.isfile(path): - os.remove(path) - elif os.path.isdir(path): - shutil.rmtree(path) + if os.path.isfile(path): os.remove(path) + if os.path.isdir(path): shutil.rmtree(path) def _ask_permission(self): with self.args.timing.add("permission", waiting="user") as t: diff --git a/src/wormhole/test/test_scripts.py b/src/wormhole/test/test_scripts.py index a354b62..49562e0 100644 --- a/src/wormhole/test/test_scripts.py +++ b/src/wormhole/test/test_scripts.py @@ -238,7 +238,7 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase): @inlineCallbacks def _do_test(self, as_subprocess=False, mode="text", addslash=False, override_filename=False, - fake_tor=False, overwrite=False): + fake_tor=False, overwrite=False, mock_accept=False): assert mode in ("text", "file", "empty-file", "directory", "slow-text") if fake_tor: assert not as_subprocess @@ -272,7 +272,7 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase): send_cfg.what = send_filename receive_filename = send_filename - recv_cfg.accept_file = True + recv_cfg.accept_file = False if mock_accept else True if override_filename: recv_cfg.output_file = receive_filename = "outfile" if overwrite: @@ -310,7 +310,7 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase): send_cfg.what = send_dirname_arg receive_dirname = send_dirname - recv_cfg.accept_file = True + recv_cfg.accept_file = False if mock_accept else True if override_filename: recv_cfg.output_file = receive_dirname = "outdir" if overwrite: @@ -395,6 +395,9 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase): with mock.patch.object(cmd_send, "VERIFY_TIMER", 0), \ mock.patch.object(cmd_receive, "VERIFY_TIMER", 0): yield gatherResults([send_d, receive_d], True) + elif mock_accept: + with mock.patch.object(cmd_receive.six.moves, 'input', return_value='y'): + yield gatherResults([send_d, receive_d], True) else: yield gatherResults([send_d, receive_d], True) @@ -513,6 +516,8 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase): return self._do_test(mode="file", override_filename=True) def test_file_overwrite(self): return self._do_test(mode="file", overwrite=True) + def test_file_overwrite_mock_accept(self): + return self._do_test(mode="file", overwrite=True, mock_accept=True) def test_file_tor(self): return self._do_test(mode="file", fake_tor=True) def test_empty_file(self): @@ -526,6 +531,8 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase): return self._do_test(mode="directory", override_filename=True) def test_directory_overwrite(self): return self._do_test(mode="directory", overwrite=True) + def test_directory_overwrite_mock_accept(self): + return self._do_test(mode="directory", overwrite=True, mock_accept=True) def test_slow_text(self): return self._do_test(mode="slow-text")