diff --git a/src/wormhole/cli/cmd_receive.py b/src/wormhole/cli/cmd_receive.py index fb43412..537582a 100644 --- a/src/wormhole/cli/cmd_receive.py +++ b/src/wormhole/cli/cmd_receive.py @@ -257,8 +257,11 @@ class TwistedReceiver: # get confirmation from the user before writing to the local directory if os.path.exists(abs_destname): - self._msg(u"Error: refusing to overwrite existing '%s'" % destname) - raise TransferRejectedError() + if self.args.output_file: # overwrite is intentional + self._msg(u"Overwriting '%s'" % destname) + else: + self._msg(u"Error: refusing to overwrite existing '%s'" % destname) + raise TransferRejectedError() return abs_destname def _ask_permission(self): diff --git a/src/wormhole/test/test_scripts.py b/src/wormhole/test/test_scripts.py index c5e2ede..a354b62 100644 --- a/src/wormhole/test/test_scripts.py +++ b/src/wormhole/test/test_scripts.py @@ -238,7 +238,7 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase): @inlineCallbacks def _do_test(self, as_subprocess=False, mode="text", addslash=False, override_filename=False, - fake_tor=False): + fake_tor=False, overwrite=False): assert mode in ("text", "file", "empty-file", "directory", "slow-text") if fake_tor: assert not as_subprocess @@ -275,6 +275,11 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase): recv_cfg.accept_file = True if override_filename: recv_cfg.output_file = receive_filename = "outfile" + if overwrite: + recv_cfg.output_file = receive_filename + existing_file = os.path.join(receive_dir, receive_filename) + with open(existing_file, 'w') as f: + f.write('pls overwrite me') elif mode == "directory": # $send_dir/ @@ -308,6 +313,9 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase): recv_cfg.accept_file = True if override_filename: recv_cfg.output_file = receive_dirname = "outdir" + if overwrite: + recv_cfg.output_file = receive_dirname + os.mkdir(os.path.join(receive_dir, receive_dirname)) if as_subprocess: wormhole_bin = self.find_executable() @@ -503,6 +511,8 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase): return self._do_test(mode="file") def test_file_override(self): return self._do_test(mode="file", override_filename=True) + def test_file_overwrite(self): + return self._do_test(mode="file", overwrite=True) def test_file_tor(self): return self._do_test(mode="file", fake_tor=True) def test_empty_file(self): @@ -514,6 +524,8 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase): return self._do_test(mode="directory", addslash=True) def test_directory_override(self): return self._do_test(mode="directory", override_filename=True) + def test_directory_overwrite(self): + return self._do_test(mode="directory", overwrite=True) def test_slow_text(self): return self._do_test(mode="slow-text")