tolerate trailing slash on "wormhole send dirname/"

Previously, the trailing slash would cause the receiving side to get an
empty-named directory.
This commit is contained in:
Brian Warner 2016-03-24 08:46:29 -07:00
parent 5a018c23f4
commit 049fac01db
2 changed files with 25 additions and 4 deletions

View File

@ -34,6 +34,7 @@ def build_phase1_data(args):
return phase1, fd_to_send return phase1, fd_to_send
what = os.path.join(args.cwd, args.what) what = os.path.join(args.cwd, args.what)
what = what.rstrip(os.sep)
if not os.path.exists(what): if not os.path.exists(what):
raise TransferError("Cannot send: no file/directory named '%s'" % raise TransferError("Cannot send: no file/directory named '%s'" %
args.what) args.what)

View File

@ -71,7 +71,7 @@ class Phase1Data(unittest.TestCase):
self.assertEqual(str(e), self.assertEqual(str(e),
"Cannot send: no file/directory named '%s'" % filename) "Cannot send: no file/directory named '%s'" % filename)
def test_directory(self): def _do_test_directory(self, addslash):
parent_dir = self.mktemp() parent_dir = self.mktemp()
os.mkdir(parent_dir) os.mkdir(parent_dir)
send_dir = "dirname" send_dir = "dirname"
@ -81,7 +81,10 @@ class Phase1Data(unittest.TestCase):
with open(os.path.join(parent_dir, send_dir, p), "wb") as f: with open(os.path.join(parent_dir, send_dir, p), "wb") as f:
f.write(("%s ponies\n" % p).encode("ascii")) f.write(("%s ponies\n" % p).encode("ascii"))
send_args = [ "send", send_dir ] send_dir_arg = send_dir
if addslash:
send_dir_arg += os.sep
send_args = [ "send", send_dir_arg ]
args = runner.parser.parse_args(send_args) args = runner.parser.parse_args(send_args)
args.cwd = parent_dir args.cwd = parent_dir
args.stdout = io.StringIO() args.stdout = io.StringIO()
@ -110,6 +113,12 @@ class Phase1Data(unittest.TestCase):
self.assertEqual(("%s ponies\n" % name).encode("ascii"), self.assertEqual(("%s ponies\n" % name).encode("ascii"),
contents) contents)
def test_directory(self):
return self._do_test_directory(addslash=False)
def test_directory_addslash(self):
return self._do_test_directory(addslash=True)
def test_unknown(self): def test_unknown(self):
filename = "unknown" filename = "unknown"
send_dir = self.mktemp() send_dir = self.mktemp()
@ -210,7 +219,7 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase):
@inlineCallbacks @inlineCallbacks
def _do_test(self, as_subprocess=False, def _do_test(self, as_subprocess=False,
mode="text", override_filename=False, mode="text", addslash=False, override_filename=False,
sender_twisted=False, receiver_twisted=False): sender_twisted=False, receiver_twisted=False):
assert mode in ("text", "file", "directory") assert mode in ("text", "file", "directory")
common_args = ["--hide-progress", common_args = ["--hide-progress",
@ -266,7 +275,10 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase):
for i in range(5): for i in range(5):
with open(os.path.join(source_dir, str(i)), "w") as f: with open(os.path.join(source_dir, str(i)), "w") as f:
f.write(message(i)) f.write(message(i))
send_args.append(os.path.join("middle", send_dirname)) send_dirname_arg = os.path.join("middle", send_dirname)
if addslash:
send_dirname_arg += os.sep
send_args.append(send_dirname_arg)
receive_dirname = send_dirname receive_dirname = send_dirname
receive_args.append("--accept-file") receive_args.append("--accept-file")
@ -427,12 +439,20 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase):
def test_directory(self): def test_directory(self):
return self._do_test(mode="directory") return self._do_test(mode="directory")
def test_directory_addslash(self):
return self._do_test(mode="directory", addslash=True)
def test_directory_override(self): def test_directory_override(self):
return self._do_test(mode="directory", override_filename=True) return self._do_test(mode="directory", override_filename=True)
def test_directory_twisted_to_blocking(self): def test_directory_twisted_to_blocking(self):
return self._do_test(mode="directory", sender_twisted=True) return self._do_test(mode="directory", sender_twisted=True)
def test_directory_twisted_to_blocking_addslash(self):
return self._do_test(mode="directory", addslash=True,
sender_twisted=True)
def test_directory_blocking_to_twisted(self): def test_directory_blocking_to_twisted(self):
return self._do_test(mode="directory", receiver_twisted=True) return self._do_test(mode="directory", receiver_twisted=True)
def test_directory_twisted_to_twisted(self): def test_directory_twisted_to_twisted(self):
return self._do_test(mode="directory", return self._do_test(mode="directory",
sender_twisted=True, receiver_twisted=True) sender_twisted=True, receiver_twisted=True)
def test_directory_twisted_to_twisted_addslash(self):
return self._do_test(mode="directory", addslash=True,
sender_twisted=True, receiver_twisted=True)