From 52ef00b46b30eb980348c1addecb0ec42f2f5c7c Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Thu, 14 Jul 2016 22:22:01 -0600 Subject: [PATCH] CLI: refactor to make testing easier When tests need a Config object, they now call a function which invokes Click with a mocked-out go() function, and grabs the Config object before actually doing anything with it. --- src/wormhole/cli/cli.py | 10 +++-- src/wormhole/test/common.py | 16 ++++++++ src/wormhole/test/test_args.py | 66 +++++++++++++++---------------- src/wormhole/test/test_scripts.py | 17 ++++---- 4 files changed, 64 insertions(+), 45 deletions(-) diff --git a/src/wormhole/cli/cli.py b/src/wormhole/cli/cli.py index 576dcf8..5c58c76 100644 --- a/src/wormhole/cli/cli.py +++ b/src/wormhole/cli/cli.py @@ -177,8 +177,13 @@ def send(cfg, what, text, code, zeromode): cfg.zeromode = zeromode cfg.code = code + return go(cmd_send.send, cfg) + +# this intermediate function can be mocked by tests that need to build a +# Config object +def go(f, cfg): # note: react() does not return - return react(_dispatch_command, (cfg, lambda: cmd_send.send(cfg))) + return react(_dispatch_command, (cfg, lambda: f(cfg))) # wormhole receive (or "wormhole rx") @@ -228,5 +233,4 @@ def receive(cfg, code, zeromode, output_file, accept_file, only_text): else: cfg.code = None - # note: react() does not return - return react(_dispatch_command, (cfg, lambda: cmd_receive.receive(cfg))) + return go(cmd_receive.receive, cfg) diff --git a/src/wormhole/test/common.py b/src/wormhole/test/common.py index bea06da..f3b57ad 100644 --- a/src/wormhole/test/common.py +++ b/src/wormhole/test/common.py @@ -2,6 +2,9 @@ from twisted.application import service from twisted.internet import defer, task from twisted.python import log +from click.testing import CliRunner +import mock +from ..cli import cli from ..transit import allocate_tcp_port from ..server.server import RelayServer from .. import __version__ @@ -68,3 +71,16 @@ class ServerBase: return d wait_d.addCallback(_later) return wait_d + +def config(*argv): + r = CliRunner() + with mock.patch("wormhole.cli.cli.go") as go: + res = r.invoke(cli.wormhole, argv, catch_exceptions=False) + if res.exit_code != 0: + print(res.exit_code) + print(res.output) + print(res) + assert 0 + cfg = go.call_args[0][1] + return cfg + diff --git a/src/wormhole/test/test_args.py b/src/wormhole/test/test_args.py index d82ed2d..8b87e24 100644 --- a/src/wormhole/test/test_args.py +++ b/src/wormhole/test/test_args.py @@ -1,20 +1,12 @@ -import mock +import sys from twisted.trial import unittest -from ..cli.cli import wormhole from ..cli.public_relay import RENDEZVOUS_RELAY, TRANSIT_RELAY -from click.testing import CliRunner +from .common import config #from pprint import pprint -def run(argv): - r = CliRunner() - with mock.patch("wormhole.cli.cli.react") as react: - r.invoke(wormhole, argv) - cfg = react.call_args[0][1][0] - return cfg - class Send(unittest.TestCase): def test_baseline(self): - cfg = run(["send", "--text", "hi"]) + cfg = config("send", "--text", "hi") #pprint(cfg.__dict__) self.assertEqual(cfg.what, None) self.assertEqual(cfg.code, None) @@ -31,51 +23,51 @@ class Send(unittest.TestCase): self.assertEqual(cfg.zeromode, False) def test_file(self): - cfg = run(["send", "fn"]) + cfg = config("send", "fn") #pprint(cfg.__dict__) self.assertEqual(cfg.what, u"fn") self.assertEqual(cfg.text, None) def test_text(self): - cfg = run(["send", "--text", "hi"]) + cfg = config("send", "--text", "hi") self.assertEqual(cfg.what, None) self.assertEqual(cfg.text, u"hi") def test_nolisten(self): - cfg = run(["--no-listen", "send", "fn"]) + cfg = config("--no-listen", "send", "fn") self.assertEqual(cfg.listen, False) def test_code(self): - cfg = run(["send", "--code", "1-abc", "fn"]) + cfg = config("send", "--code", "1-abc", "fn") self.assertEqual(cfg.code, u"1-abc") def test_code_length(self): - cfg = run(["-c", "3", "send", "fn"]) + cfg = config("-c", "3", "send", "fn") self.assertEqual(cfg.code_length, 3) def test_dump_timing(self): - cfg = run(["--dump-timing", "tx.json", "send", "fn"]) + cfg = config("--dump-timing", "tx.json", "send", "fn") self.assertEqual(cfg.dump_timing, "tx.json") def test_hide_progress(self): - cfg = run(["--hide-progress", "send", "fn"]) + cfg = config("--hide-progress", "send", "fn") self.assertEqual(cfg.hide_progress, True) def test_tor(self): - cfg = run(["--tor", "send", "fn"]) + cfg = config("--tor", "send", "fn") self.assertEqual(cfg.tor, True) def test_verify(self): - cfg = run(["--verify", "send", "fn"]) + cfg = config("--verify", "send", "fn") self.assertEqual(cfg.verify, True) def test_zeromode(self): - cfg = run(["send", "-0", "fn"]) + cfg = config("send", "-0", "fn") self.assertEqual(cfg.zeromode, True) class Receive(unittest.TestCase): def test_baseline(self): - cfg = run(["receive"]) + cfg = config("receive") #pprint(cfg.__dict__) self.assertEqual(cfg.accept_file, False) self.assertEqual(cfg.what, None) @@ -94,45 +86,53 @@ class Receive(unittest.TestCase): self.assertEqual(cfg.zeromode, False) def test_nolisten(self): - cfg = run(["--no-listen", "receive"]) + cfg = config("--no-listen", "receive") self.assertEqual(cfg.listen, False) def test_code(self): - cfg = run(["receive", "1-abc"]) + cfg = config("receive", "1-abc") self.assertEqual(cfg.code, u"1-abc") def test_code_length(self): - cfg = run(["-c", "3", "receive"]) + cfg = config("-c", "3", "receive") self.assertEqual(cfg.code_length, 3) def test_dump_timing(self): - cfg = run(["--dump-timing", "tx.json", "receive"]) + cfg = config("--dump-timing", "tx.json", "receive") self.assertEqual(cfg.dump_timing, "tx.json") def test_hide_progress(self): - cfg = run(["--hide-progress", "receive"]) + cfg = config("--hide-progress", "receive") self.assertEqual(cfg.hide_progress, True) def test_tor(self): - cfg = run(["--tor", "receive"]) + cfg = config("--tor", "receive") self.assertEqual(cfg.tor, True) def test_verify(self): - cfg = run(["--verify", "receive"]) + cfg = config("--verify", "receive") self.assertEqual(cfg.verify, True) def test_zeromode(self): - cfg = run(["receive", "-0"]) + cfg = config("receive", "-0") self.assertEqual(cfg.zeromode, True) def test_only_text(self): - cfg = run(["receive", "-t"]) + cfg = config("receive", "-t") self.assertEqual(cfg.only_text, True) def test_accept_file(self): - cfg = run(["receive", "--accept-file"]) + cfg = config("receive", "--accept-file") self.assertEqual(cfg.accept_file, True) def test_output_file(self): - cfg = run(["receive", "--output-file", "fn"]) + cfg = config("receive", "--output-file", "fn") self.assertEqual(cfg.output_file, u"fn") + +class Config(unittest.TestCase): + def test_send(self): + cfg = config("send") + self.assertEqual(cfg.stdout, sys.stdout) + def test_receive(self): + cfg = config("receive") + self.assertEqual(cfg.stdout, sys.stdout) diff --git a/src/wormhole/test/test_scripts.py b/src/wormhole/test/test_scripts.py index ce4d499..66be1e5 100644 --- a/src/wormhole/test/test_scripts.py +++ b/src/wormhole/test/test_scripts.py @@ -6,9 +6,8 @@ from twisted.python import procutils, log from twisted.internet.utils import getProcessOutputAndValue from twisted.internet.defer import gatherResults, inlineCallbacks from .. import __version__ -from .common import ServerBase +from .common import ServerBase, config from ..cli import cmd_send, cmd_receive -from ..cli.cli import Config from ..errors import TransferError, WrongPasswordError, WelcomeError @@ -20,7 +19,7 @@ def build_offer(args): class OfferData(unittest.TestCase): def setUp(self): self._things_to_delete = [] - self.cfg = cfg = Config() + self.cfg = cfg = config("send") cfg.stdout = io.StringIO() cfg.stderr = io.StringIO() @@ -226,8 +225,8 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase): def _do_test(self, as_subprocess=False, mode="text", addslash=False, override_filename=False): assert mode in ("text", "file", "directory") - send_cfg = Config() - recv_cfg = Config() + send_cfg = config("send") + recv_cfg = config("receive") message = "blah blah blah ponies" for cfg in [send_cfg, recv_cfg]: @@ -449,8 +448,8 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase): @inlineCallbacks def test_file_noclobber(self): - send_cfg = Config() - recv_cfg = Config() + send_cfg = config("send") + recv_cfg = config("receive") for cfg in [send_cfg, recv_cfg]: cfg.hide_progress = True @@ -537,7 +536,7 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase): class NotWelcome(ServerBase, unittest.TestCase): def setUp(self): self._setup_relay(error="please upgrade XYZ") - self.cfg = cfg = Config() + self.cfg = cfg = config("send") cfg.hide_progress = True cfg.listen = False cfg.relay_url = self.relayurl @@ -567,7 +566,7 @@ class Cleanup(ServerBase, unittest.TestCase): def setUp(self): d = super(Cleanup, self).setUp() - self.cfg = cfg = Config() + self.cfg = cfg = config("send") # common options for all tests in this suite cfg.hide_progress = True cfg.relay_url = self.relayurl