Merge PR139

closes #73
This commit is contained in:
Brian Warner 2017-02-22 17:48:20 -08:00
commit 4df4cf0016
2 changed files with 36 additions and 6 deletions

View File

@ -1,5 +1,5 @@
from __future__ import print_function from __future__ import print_function
import os, sys, six, tempfile, zipfile, hashlib import os, sys, six, tempfile, zipfile, hashlib, shutil
from tqdm import tqdm from tqdm import tqdm
from humanize import naturalsize from humanize import naturalsize
from twisted.internet import reactor from twisted.internet import reactor
@ -257,15 +257,26 @@ class TwistedReceiver:
# get confirmation from the user before writing to the local directory # get confirmation from the user before writing to the local directory
if os.path.exists(abs_destname): if os.path.exists(abs_destname):
self._msg(u"Error: refusing to overwrite existing '%s'" % destname) if self.args.output_file: # overwrite is intentional
raise TransferRejectedError() self._msg(u"Overwriting '%s'" % destname)
if self.args.accept_file:
self._remove_existing(abs_destname)
else:
self._msg(u"Error: refusing to overwrite existing '%s'" % destname)
raise TransferRejectedError()
return abs_destname return abs_destname
def _remove_existing(self, path):
if os.path.isfile(path): os.remove(path)
if os.path.isdir(path): shutil.rmtree(path)
def _ask_permission(self): def _ask_permission(self):
with self.args.timing.add("permission", waiting="user") as t: with self.args.timing.add("permission", waiting="user") as t:
while True and not self.args.accept_file: while True and not self.args.accept_file:
ok = six.moves.input("ok? (y/n): ") ok = six.moves.input("ok? (y/n): ")
if ok.lower().startswith("y"): if ok.lower().startswith("y"):
if os.path.exists(self.abs_destname):
self._remove_existing(self.abs_destname)
break break
print(u"transfer rejected", file=sys.stderr) print(u"transfer rejected", file=sys.stderr)
t.detail(answer="no") t.detail(answer="no")

View File

@ -238,7 +238,7 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase):
@inlineCallbacks @inlineCallbacks
def _do_test(self, as_subprocess=False, def _do_test(self, as_subprocess=False,
mode="text", addslash=False, override_filename=False, mode="text", addslash=False, override_filename=False,
fake_tor=False): fake_tor=False, overwrite=False, mock_accept=False):
assert mode in ("text", "file", "empty-file", "directory", "slow-text") assert mode in ("text", "file", "empty-file", "directory", "slow-text")
if fake_tor: if fake_tor:
assert not as_subprocess assert not as_subprocess
@ -272,9 +272,14 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase):
send_cfg.what = send_filename send_cfg.what = send_filename
receive_filename = send_filename receive_filename = send_filename
recv_cfg.accept_file = True recv_cfg.accept_file = False if mock_accept else True
if override_filename: if override_filename:
recv_cfg.output_file = receive_filename = "outfile" recv_cfg.output_file = receive_filename = "outfile"
if overwrite:
recv_cfg.output_file = receive_filename
existing_file = os.path.join(receive_dir, receive_filename)
with open(existing_file, 'w') as f:
f.write('pls overwrite me')
elif mode == "directory": elif mode == "directory":
# $send_dir/ # $send_dir/
@ -305,9 +310,12 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase):
send_cfg.what = send_dirname_arg send_cfg.what = send_dirname_arg
receive_dirname = send_dirname receive_dirname = send_dirname
recv_cfg.accept_file = True recv_cfg.accept_file = False if mock_accept else True
if override_filename: if override_filename:
recv_cfg.output_file = receive_dirname = "outdir" recv_cfg.output_file = receive_dirname = "outdir"
if overwrite:
recv_cfg.output_file = receive_dirname
os.mkdir(os.path.join(receive_dir, receive_dirname))
if as_subprocess: if as_subprocess:
wormhole_bin = self.find_executable() wormhole_bin = self.find_executable()
@ -387,6 +395,9 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase):
with mock.patch.object(cmd_send, "VERIFY_TIMER", 0), \ with mock.patch.object(cmd_send, "VERIFY_TIMER", 0), \
mock.patch.object(cmd_receive, "VERIFY_TIMER", 0): mock.patch.object(cmd_receive, "VERIFY_TIMER", 0):
yield gatherResults([send_d, receive_d], True) yield gatherResults([send_d, receive_d], True)
elif mock_accept:
with mock.patch.object(cmd_receive.six.moves, 'input', return_value='y'):
yield gatherResults([send_d, receive_d], True)
else: else:
yield gatherResults([send_d, receive_d], True) yield gatherResults([send_d, receive_d], True)
@ -503,6 +514,10 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase):
return self._do_test(mode="file") return self._do_test(mode="file")
def test_file_override(self): def test_file_override(self):
return self._do_test(mode="file", override_filename=True) return self._do_test(mode="file", override_filename=True)
def test_file_overwrite(self):
return self._do_test(mode="file", overwrite=True)
def test_file_overwrite_mock_accept(self):
return self._do_test(mode="file", overwrite=True, mock_accept=True)
def test_file_tor(self): def test_file_tor(self):
return self._do_test(mode="file", fake_tor=True) return self._do_test(mode="file", fake_tor=True)
def test_empty_file(self): def test_empty_file(self):
@ -514,6 +529,10 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase):
return self._do_test(mode="directory", addslash=True) return self._do_test(mode="directory", addslash=True)
def test_directory_override(self): def test_directory_override(self):
return self._do_test(mode="directory", override_filename=True) return self._do_test(mode="directory", override_filename=True)
def test_directory_overwrite(self):
return self._do_test(mode="directory", overwrite=True)
def test_directory_overwrite_mock_accept(self):
return self._do_test(mode="directory", overwrite=True, mock_accept=True)
def test_slow_text(self): def test_slow_text(self):
return self._do_test(mode="slow-text") return self._do_test(mode="slow-text")