diff --git a/src/wormhole/scripts/send_common.py b/src/wormhole/scripts/send_common.py index 260e201..20f2b63 100644 --- a/src/wormhole/scripts/send_common.py +++ b/src/wormhole/scripts/send_common.py @@ -37,8 +37,6 @@ def build_phase1_data(args): if not os.path.exists(what): raise TransferError("Cannot send: no file/directory named '%s'" % args.what) - - what = os.path.join(args.cwd, args.what) basename = os.path.basename(what) if os.path.isfile(what): @@ -88,4 +86,4 @@ def build_phase1_data(args): % (filesize, basename), file=args.stdout) return phase1, fd_to_send - raise TypeError("'%s' is neither file nor directory" % what) + raise TypeError("'%s' is neither file nor directory" % args.what) diff --git a/src/wormhole/test/test_scripts.py b/src/wormhole/test/test_scripts.py index 3834c8f..1e9767c 100644 --- a/src/wormhole/test/test_scripts.py +++ b/src/wormhole/test/test_scripts.py @@ -1,4 +1,4 @@ -import os, sys, re, io +import os, sys, re, io, zipfile from twisted.trial import unittest from twisted.python import procutils, log from twisted.internet.utils import getProcessOutputAndValue @@ -7,6 +7,129 @@ from twisted.internet.threads import deferToThread from .. import __version__ from .common import ServerBase from ..scripts import runner, cmd_send_blocking, cmd_receive +from ..scripts.send_common import build_phase1_data +from ..errors import TransferError + +class Phase1Data(unittest.TestCase): + def test_text(self): + message = "blah blah blah ponies" + + send_args = [ "send", "--text", message ] + args = runner.parser.parse_args(send_args) + args.cwd = os.getcwd() + args.stdout = io.StringIO() + args.stderr = io.StringIO() + + d, fd_to_send = build_phase1_data(args) + + self.assertIn("message", d) + self.assertNotIn("file", d) + self.assertNotIn("directory", d) + self.assertEqual(d["message"], message) + self.assertEqual(fd_to_send, None) + + def test_file(self): + filename = "my file" + message = b"yay ponies\n" + send_dir = self.mktemp() + os.mkdir(send_dir) + abs_filename = os.path.join(send_dir, filename) + with open(abs_filename, "wb") as f: + f.write(message) + + send_args = [ "send", filename ] + args = runner.parser.parse_args(send_args) + args.cwd = send_dir + args.stdout = io.StringIO() + args.stderr = io.StringIO() + + d, fd_to_send = build_phase1_data(args) + + self.assertNotIn("message", d) + self.assertIn("file", d) + self.assertNotIn("directory", d) + self.assertEqual(d["file"]["filesize"], len(message)) + self.assertEqual(d["file"]["filename"], filename) + self.assertEqual(fd_to_send.tell(), 0) + self.assertEqual(fd_to_send.read(), message) + + def test_missing_file(self): + filename = "missing" + send_dir = self.mktemp() + os.mkdir(send_dir) + + send_args = [ "send", filename ] + args = runner.parser.parse_args(send_args) + args.cwd = send_dir + args.stdout = io.StringIO() + args.stderr = io.StringIO() + + e = self.assertRaises(TransferError, build_phase1_data, args) + self.assertEqual(str(e), + "Cannot send: no file/directory named '%s'" % filename) + + def test_directory(self): + parent_dir = self.mktemp() + os.mkdir(parent_dir) + send_dir = "dirname" + os.mkdir(os.path.join(parent_dir, send_dir)) + ponies = [str(i) for i in range(5)] + for p in ponies: + with open(os.path.join(parent_dir, send_dir, p), "wb") as f: + f.write(("%s ponies\n" % p).encode("ascii")) + + send_args = [ "send", send_dir ] + args = runner.parser.parse_args(send_args) + args.cwd = parent_dir + args.stdout = io.StringIO() + args.stderr = io.StringIO() + + d, fd_to_send = build_phase1_data(args) + + self.assertNotIn("message", d) + self.assertNotIn("file", d) + self.assertIn("directory", d) + self.assertEqual(d["directory"]["dirname"], send_dir) + self.assertEqual(d["directory"]["mode"], "zipfile/deflated") + self.assertEqual(d["directory"]["numfiles"], 5) + self.assertIn("numbytes", d["directory"]) + self.assertIsInstance(d["directory"]["numbytes"], type(123)) + + self.assertEqual(fd_to_send.tell(), 0) + zdata = fd_to_send.read() + self.assertEqual(len(zdata), d["directory"]["zipsize"]) + fd_to_send.seek(0, 0) + with zipfile.ZipFile(fd_to_send, "r", zipfile.ZIP_DEFLATED) as zf: + zipnames = zf.namelist() + self.assertEqual(list(sorted(ponies)), list(sorted(zipnames))) + for name in zipnames: + contents = zf.open(name, "r").read() + self.assertEqual(("%s ponies\n" % name).encode("ascii"), + contents) + + def test_unknown(self): + filename = "unknown" + send_dir = self.mktemp() + os.mkdir(send_dir) + abs_filename = os.path.abspath(os.path.join(send_dir, filename)) + + try: + os.mkfifo(abs_filename) + except OSError: + raise unittest.SkipTest("is mkfifo supported on this platform?") + self.assertFalse(os.path.isfile(abs_filename)) + self.assertFalse(os.path.isdir(abs_filename)) + + send_args = [ "send", filename ] + args = runner.parser.parse_args(send_args) + args.cwd = send_dir + args.stdout = io.StringIO() + args.stderr = io.StringIO() + + e = self.assertRaises(TypeError, build_phase1_data, args) + self.assertEqual(str(e), + "'%s' is neither file nor directory" % filename) + class ScriptsBase: def find_executable(self):