commit
4df4cf0016
|
@ -1,5 +1,5 @@
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
import os, sys, six, tempfile, zipfile, hashlib
|
import os, sys, six, tempfile, zipfile, hashlib, shutil
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from humanize import naturalsize
|
from humanize import naturalsize
|
||||||
from twisted.internet import reactor
|
from twisted.internet import reactor
|
||||||
|
@ -257,15 +257,26 @@ class TwistedReceiver:
|
||||||
|
|
||||||
# get confirmation from the user before writing to the local directory
|
# get confirmation from the user before writing to the local directory
|
||||||
if os.path.exists(abs_destname):
|
if os.path.exists(abs_destname):
|
||||||
self._msg(u"Error: refusing to overwrite existing '%s'" % destname)
|
if self.args.output_file: # overwrite is intentional
|
||||||
raise TransferRejectedError()
|
self._msg(u"Overwriting '%s'" % destname)
|
||||||
|
if self.args.accept_file:
|
||||||
|
self._remove_existing(abs_destname)
|
||||||
|
else:
|
||||||
|
self._msg(u"Error: refusing to overwrite existing '%s'" % destname)
|
||||||
|
raise TransferRejectedError()
|
||||||
return abs_destname
|
return abs_destname
|
||||||
|
|
||||||
|
def _remove_existing(self, path):
|
||||||
|
if os.path.isfile(path): os.remove(path)
|
||||||
|
if os.path.isdir(path): shutil.rmtree(path)
|
||||||
|
|
||||||
def _ask_permission(self):
|
def _ask_permission(self):
|
||||||
with self.args.timing.add("permission", waiting="user") as t:
|
with self.args.timing.add("permission", waiting="user") as t:
|
||||||
while True and not self.args.accept_file:
|
while True and not self.args.accept_file:
|
||||||
ok = six.moves.input("ok? (y/n): ")
|
ok = six.moves.input("ok? (y/n): ")
|
||||||
if ok.lower().startswith("y"):
|
if ok.lower().startswith("y"):
|
||||||
|
if os.path.exists(self.abs_destname):
|
||||||
|
self._remove_existing(self.abs_destname)
|
||||||
break
|
break
|
||||||
print(u"transfer rejected", file=sys.stderr)
|
print(u"transfer rejected", file=sys.stderr)
|
||||||
t.detail(answer="no")
|
t.detail(answer="no")
|
||||||
|
|
|
@ -238,7 +238,7 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase):
|
||||||
@inlineCallbacks
|
@inlineCallbacks
|
||||||
def _do_test(self, as_subprocess=False,
|
def _do_test(self, as_subprocess=False,
|
||||||
mode="text", addslash=False, override_filename=False,
|
mode="text", addslash=False, override_filename=False,
|
||||||
fake_tor=False):
|
fake_tor=False, overwrite=False, mock_accept=False):
|
||||||
assert mode in ("text", "file", "empty-file", "directory", "slow-text")
|
assert mode in ("text", "file", "empty-file", "directory", "slow-text")
|
||||||
if fake_tor:
|
if fake_tor:
|
||||||
assert not as_subprocess
|
assert not as_subprocess
|
||||||
|
@ -272,9 +272,14 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase):
|
||||||
send_cfg.what = send_filename
|
send_cfg.what = send_filename
|
||||||
receive_filename = send_filename
|
receive_filename = send_filename
|
||||||
|
|
||||||
recv_cfg.accept_file = True
|
recv_cfg.accept_file = False if mock_accept else True
|
||||||
if override_filename:
|
if override_filename:
|
||||||
recv_cfg.output_file = receive_filename = "outfile"
|
recv_cfg.output_file = receive_filename = "outfile"
|
||||||
|
if overwrite:
|
||||||
|
recv_cfg.output_file = receive_filename
|
||||||
|
existing_file = os.path.join(receive_dir, receive_filename)
|
||||||
|
with open(existing_file, 'w') as f:
|
||||||
|
f.write('pls overwrite me')
|
||||||
|
|
||||||
elif mode == "directory":
|
elif mode == "directory":
|
||||||
# $send_dir/
|
# $send_dir/
|
||||||
|
@ -305,9 +310,12 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase):
|
||||||
send_cfg.what = send_dirname_arg
|
send_cfg.what = send_dirname_arg
|
||||||
receive_dirname = send_dirname
|
receive_dirname = send_dirname
|
||||||
|
|
||||||
recv_cfg.accept_file = True
|
recv_cfg.accept_file = False if mock_accept else True
|
||||||
if override_filename:
|
if override_filename:
|
||||||
recv_cfg.output_file = receive_dirname = "outdir"
|
recv_cfg.output_file = receive_dirname = "outdir"
|
||||||
|
if overwrite:
|
||||||
|
recv_cfg.output_file = receive_dirname
|
||||||
|
os.mkdir(os.path.join(receive_dir, receive_dirname))
|
||||||
|
|
||||||
if as_subprocess:
|
if as_subprocess:
|
||||||
wormhole_bin = self.find_executable()
|
wormhole_bin = self.find_executable()
|
||||||
|
@ -387,6 +395,9 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase):
|
||||||
with mock.patch.object(cmd_send, "VERIFY_TIMER", 0), \
|
with mock.patch.object(cmd_send, "VERIFY_TIMER", 0), \
|
||||||
mock.patch.object(cmd_receive, "VERIFY_TIMER", 0):
|
mock.patch.object(cmd_receive, "VERIFY_TIMER", 0):
|
||||||
yield gatherResults([send_d, receive_d], True)
|
yield gatherResults([send_d, receive_d], True)
|
||||||
|
elif mock_accept:
|
||||||
|
with mock.patch.object(cmd_receive.six.moves, 'input', return_value='y'):
|
||||||
|
yield gatherResults([send_d, receive_d], True)
|
||||||
else:
|
else:
|
||||||
yield gatherResults([send_d, receive_d], True)
|
yield gatherResults([send_d, receive_d], True)
|
||||||
|
|
||||||
|
@ -503,6 +514,10 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase):
|
||||||
return self._do_test(mode="file")
|
return self._do_test(mode="file")
|
||||||
def test_file_override(self):
|
def test_file_override(self):
|
||||||
return self._do_test(mode="file", override_filename=True)
|
return self._do_test(mode="file", override_filename=True)
|
||||||
|
def test_file_overwrite(self):
|
||||||
|
return self._do_test(mode="file", overwrite=True)
|
||||||
|
def test_file_overwrite_mock_accept(self):
|
||||||
|
return self._do_test(mode="file", overwrite=True, mock_accept=True)
|
||||||
def test_file_tor(self):
|
def test_file_tor(self):
|
||||||
return self._do_test(mode="file", fake_tor=True)
|
return self._do_test(mode="file", fake_tor=True)
|
||||||
def test_empty_file(self):
|
def test_empty_file(self):
|
||||||
|
@ -514,6 +529,10 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase):
|
||||||
return self._do_test(mode="directory", addslash=True)
|
return self._do_test(mode="directory", addslash=True)
|
||||||
def test_directory_override(self):
|
def test_directory_override(self):
|
||||||
return self._do_test(mode="directory", override_filename=True)
|
return self._do_test(mode="directory", override_filename=True)
|
||||||
|
def test_directory_overwrite(self):
|
||||||
|
return self._do_test(mode="directory", overwrite=True)
|
||||||
|
def test_directory_overwrite_mock_accept(self):
|
||||||
|
return self._do_test(mode="directory", overwrite=True, mock_accept=True)
|
||||||
|
|
||||||
def test_slow_text(self):
|
def test_slow_text(self):
|
||||||
return self._do_test(mode="slow-text")
|
return self._do_test(mode="slow-text")
|
||||||
|
|
Loading…
Reference in New Issue
Block a user