test send/receive directory

This commit is contained in:
Brian Warner 2015-11-29 01:33:15 -06:00
parent 6fefcde061
commit 3a343f9895
2 changed files with 90 additions and 4 deletions

View File

@ -89,9 +89,12 @@ def accept_directory(args, them_d, w):
w.send_data(data) w.send_data(data)
return 1 return 1
# the basename() is intended to protect us against if args.output_file:
# "~/.ssh/authorized_keys" and other attacks dirname = args.output_file
dirname = os.path.basename(file_data["dirname"]) # unicode else:
# the basename() is intended to protect us against
# "~/.ssh/authorized_keys" and other attacks
dirname = os.path.basename(file_data["dirname"]) # unicode
filesize = file_data["zipsize"] filesize = file_data["zipsize"]
num_files = file_data["numfiles"] num_files = file_data["numfiles"]
num_bytes = file_data["numbytes"] num_bytes = file_data["numbytes"]
@ -154,6 +157,9 @@ def accept_directory(args, them_d, w):
print("Unpacking zipfile..") print("Unpacking zipfile..")
with zipfile.ZipFile(f, "r", zipfile.ZIP_DEFLATED) as zf: with zipfile.ZipFile(f, "r", zipfile.ZIP_DEFLATED) as zf:
zf.extractall(path=dirname) zf.extractall(path=dirname)
# extractall() appears to offer some protection against malicious
# pathnames. For example, "/tmp/oops" and "../tmp/oops" both do the
# same thing as the (safe) "tmp/oops".
print("Received files written to %s/" % dirname) print("Received files written to %s/" % dirname)
record_pipe.send_record(b"ok\n") record_pipe.send_record(b"ok\n")

View File

@ -1,4 +1,4 @@
import os, sys import os, sys, re
from twisted.trial import unittest from twisted.trial import unittest
from twisted.python import procutils, log from twisted.python import procutils, log
from twisted.internet.utils import getProcessOutputAndValue from twisted.internet.utils import getProcessOutputAndValue
@ -177,3 +177,83 @@ class Scripts(ServerBase, ScriptsBase, unittest.TestCase):
self.failUnlessEqual(f.read(), message) self.failUnlessEqual(f.read(), message)
d1.addCallback(_check_receiver) d1.addCallback(_check_receiver)
return d1 return d1
def test_send_directory_pre_generated_code(self):
return self._do_test_send_directory_pre_generated_code(False)
def test_send_directory_pre_generated_code_override(self):
return self._do_test_send_directory_pre_generated_code(True)
def _do_test_send_directory_pre_generated_code(self, override_dirname):
self.maxDiff=None
code = u"1-abc"
dirname = "testdir"
def message(i):
return "test message %d\n" % i
source_parent_dir = self.mktemp()
os.mkdir(source_parent_dir)
os.mkdir(os.path.join(source_parent_dir, "middle"))
source_dir = os.path.join(source_parent_dir, "middle", dirname)
os.mkdir(source_dir)
for i in range(5):
with open(os.path.join(source_dir, str(i)), "w") as f:
f.write(message(i))
target_parent_dir = self.mktemp()
os.mkdir(target_parent_dir)
wormhole = self.find_executable()
server_args = ["--relay-url", self.relayurl]
send_args = server_args + [
"send",
"--code", code,
os.path.join("middle", dirname),
]
receive_args = server_args + [
"receive", "--accept-file",
]
if override_dirname:
receive_args.extend(["-o", "outdir"])
dirname = "outdir"
receive_args.append(code)
d1 = getProcessOutputAndValue(wormhole, send_args,
path=source_parent_dir)
d2 = getProcessOutputAndValue(wormhole, receive_args,
path=target_parent_dir)
def _check_sender(res):
out, err, rc = res
out = out.decode("utf-8")
err = err.decode("utf-8")
self.failUnlessEqual(err, "")
self.failUnlessIn("Sending directory", out)
self.failUnlessIn("named 'testdir'", out)
self.failUnlessIn("On the other computer, please run: "
"wormhole receive\n"
"Wormhole code is: %s\n\n" % code,
out)
self.failUnlessIn("File sent.. waiting for confirmation\n"
"Confirmation received. Transfer complete.\n",
out)
self.failUnlessEqual(rc, 0)
return d2
d1.addCallback(_check_sender)
def _check_receiver(res):
out, err, rc = res
out = out.decode("utf-8")
err = err.decode("utf-8")
self.failUnless(re.search(r"Receiving \d+ bytes for '%s'" %
dirname, out))
self.failUnlessIn("Received files written to %s" % dirname, out)
self.failUnlessEqual(err, "")
self.failUnlessEqual(rc, 0)
fn = os.path.join(target_parent_dir, dirname)
self.failUnless(os.path.exists(fn))
for i in range(5):
fn = os.path.join(target_parent_dir, dirname, str(i))
with open(fn, "r") as f:
self.failUnlessEqual(f.read(), message(i))
d1.addCallback(_check_receiver)
return d1