diff --git a/src/wormhole/test/test_scripts.py b/src/wormhole/test/test_scripts.py index ab723e2..e1473c4 100644 --- a/src/wormhole/test/test_scripts.py +++ b/src/wormhole/test/test_scripts.py @@ -1,10 +1,12 @@ -import os, sys, re +import os, sys, re, io from twisted.trial import unittest from twisted.python import procutils, log from twisted.internet.utils import getProcessOutputAndValue from twisted.internet.defer import inlineCallbacks +from twisted.internet.threads import deferToThread from .. import __version__ from .common import ServerBase +from ..scripts import runner, cmd_send, cmd_receive class ScriptsBase: def find_executable(self): @@ -81,9 +83,9 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase): return d @inlineCallbacks - def _do_test(self, mode="text", override_filename=False): + def _do_test(self, as_subprocess=False, + mode="text", override_filename=False): assert mode in ("text", "file", "directory") - wormhole = self.find_executable() common_args = ["--hide-progress", "--relay-url", self.relayurl, "--transit-helper", ""] @@ -147,78 +149,102 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase): receive_args.append(code) - send_d = getProcessOutputAndValue(wormhole, send_args, path=send_dir) - receive_d = getProcessOutputAndValue(wormhole, receive_args, - path=receive_dir) + if as_subprocess: + wormhole_bin = self.find_executable() + send_d = getProcessOutputAndValue(wormhole_bin, send_args, + path=send_dir) + receive_d = getProcessOutputAndValue(wormhole_bin, receive_args, + path=receive_dir) + send_res = yield send_d + send_stdout = send_res[0].decode("utf-8") + send_stderr = send_res[1].decode("utf-8") + send_rc = send_res[2] + receive_res = yield receive_d + receive_stdout = receive_res[0].decode("utf-8") + receive_stderr = receive_res[1].decode("utf-8") + receive_rc = receive_res[2] + else: + sargs = runner.parser.parse_args(send_args) + sargs.cwd = send_dir + sargs.stdout = io.StringIO() + sargs.stderr = io.StringIO() + rargs = runner.parser.parse_args(receive_args) + rargs.cwd = receive_dir + rargs.stdout = io.StringIO() + rargs.stderr = io.StringIO() + send_d = deferToThread(cmd_send.send, sargs) + receive_d = deferToThread(cmd_receive.receive, rargs) + + send_rc = yield send_d + send_stdout = sargs.stdout.getvalue() + send_stderr = sargs.stderr.getvalue() + + receive_rc = yield receive_d + receive_stdout = rargs.stdout.getvalue() + receive_stderr = rargs.stderr.getvalue() self.maxDiff = None # show full output for assertion failures # check sender - send_res = yield send_d - out, err, rc = send_res - out = out.decode("utf-8") - err = err.decode("utf-8") - self.failUnlessEqual(err, "") + self.failUnlessEqual(send_stderr, "") if mode == "text": expected = ("Sending text message (%d bytes)\n" "On the other computer, please run: " "wormhole receive\n" "Wormhole code is: %s\n\n" "text message sent\n" % (len(message), code)) - self.failUnlessEqual(out, expected) + self.failUnlessEqual(send_stdout, expected) elif mode == "file": self.failUnlessIn("Sending %d byte file named '%s'\n" % - (len(message), send_filename), out) + (len(message), send_filename), send_stdout) self.failUnlessIn("On the other computer, please run: " "wormhole receive\n" "Wormhole code is: %s\n\n" % code, - out) + send_stdout) self.failUnlessIn("File sent.. waiting for confirmation\n" "Confirmation received. Transfer complete.\n", - out) + send_stdout) elif mode == "directory": - self.failUnlessIn("Sending directory", out) - self.failUnlessIn("named 'testdir'", out) + self.failUnlessIn("Sending directory", send_stdout) + self.failUnlessIn("named 'testdir'", send_stdout) self.failUnlessIn("On the other computer, please run: " "wormhole receive\n" "Wormhole code is: %s\n\n" % code, - out) + send_stdout) self.failUnlessIn("File sent.. waiting for confirmation\n" "Confirmation received. Transfer complete.\n", - out) - self.failUnlessEqual(rc, 0) + send_stdout) + self.failUnlessEqual(send_rc, 0) # check receiver - receive_res = yield receive_d - out, err, rc = receive_res - out = out.decode("utf-8") - err = err.decode("utf-8") + self.failUnlessEqual(receive_stderr, "") if mode == "text": - self.failUnlessEqual(out, message+"\n") + self.failUnlessEqual(receive_stdout, message+"\n") elif mode == "file": self.failUnlessIn("Receiving %d bytes for '%s'" % - (len(message), receive_filename), out) - self.failUnlessIn("Received file written to ", out) + (len(message), receive_filename), receive_stdout) + self.failUnlessIn("Received file written to ", receive_stdout) fn = os.path.join(receive_dir, receive_filename) self.failUnless(os.path.exists(fn)) with open(fn, "r") as f: self.failUnlessEqual(f.read(), message) elif mode == "directory": self.failUnless(re.search(r"Receiving \d+ bytes for '%s'" % - receive_dirname, out)) + receive_dirname, receive_stdout)) self.failUnlessIn("Received files written to %s" % - receive_dirname, out) + receive_dirname, receive_stdout) fn = os.path.join(receive_dir, receive_dirname) self.failUnless(os.path.exists(fn)) for i in range(5): fn = os.path.join(receive_dir, receive_dirname, str(i)) with open(fn, "r") as f: self.failUnlessEqual(f.read(), message(i)) - self.failUnlessEqual(err, "") - self.failUnlessEqual(rc, 0) + self.failUnlessEqual(receive_rc, 0) def test_text(self): return self._do_test() + def test_text_subprocess(self): + return self._do_test(as_subprocess=True) def test_file(self): return self._do_test(mode="file")