CLI: refactor to make testing easier

When tests need a Config object, they now call a function which invokes
Click with a mocked-out go() function, and grabs the Config object
before actually doing anything with it.
This commit is contained in:
Brian Warner 2016-07-14 22:22:01 -06:00
parent cdb5c19010
commit 52ef00b46b
4 changed files with 64 additions and 45 deletions

View File

@ -177,8 +177,13 @@ def send(cfg, what, text, code, zeromode):
cfg.zeromode = zeromode cfg.zeromode = zeromode
cfg.code = code cfg.code = code
return go(cmd_send.send, cfg)
# this intermediate function can be mocked by tests that need to build a
# Config object
def go(f, cfg):
# note: react() does not return # note: react() does not return
return react(_dispatch_command, (cfg, lambda: cmd_send.send(cfg))) return react(_dispatch_command, (cfg, lambda: f(cfg)))
# wormhole receive (or "wormhole rx") # wormhole receive (or "wormhole rx")
@ -228,5 +233,4 @@ def receive(cfg, code, zeromode, output_file, accept_file, only_text):
else: else:
cfg.code = None cfg.code = None
# note: react() does not return return go(cmd_receive.receive, cfg)
return react(_dispatch_command, (cfg, lambda: cmd_receive.receive(cfg)))

View File

@ -2,6 +2,9 @@
from twisted.application import service from twisted.application import service
from twisted.internet import defer, task from twisted.internet import defer, task
from twisted.python import log from twisted.python import log
from click.testing import CliRunner
import mock
from ..cli import cli
from ..transit import allocate_tcp_port from ..transit import allocate_tcp_port
from ..server.server import RelayServer from ..server.server import RelayServer
from .. import __version__ from .. import __version__
@ -68,3 +71,16 @@ class ServerBase:
return d return d
wait_d.addCallback(_later) wait_d.addCallback(_later)
return wait_d return wait_d
def config(*argv):
r = CliRunner()
with mock.patch("wormhole.cli.cli.go") as go:
res = r.invoke(cli.wormhole, argv, catch_exceptions=False)
if res.exit_code != 0:
print(res.exit_code)
print(res.output)
print(res)
assert 0
cfg = go.call_args[0][1]
return cfg

View File

@ -1,20 +1,12 @@
import mock import sys
from twisted.trial import unittest from twisted.trial import unittest
from ..cli.cli import wormhole
from ..cli.public_relay import RENDEZVOUS_RELAY, TRANSIT_RELAY from ..cli.public_relay import RENDEZVOUS_RELAY, TRANSIT_RELAY
from click.testing import CliRunner from .common import config
#from pprint import pprint #from pprint import pprint
def run(argv):
r = CliRunner()
with mock.patch("wormhole.cli.cli.react") as react:
r.invoke(wormhole, argv)
cfg = react.call_args[0][1][0]
return cfg
class Send(unittest.TestCase): class Send(unittest.TestCase):
def test_baseline(self): def test_baseline(self):
cfg = run(["send", "--text", "hi"]) cfg = config("send", "--text", "hi")
#pprint(cfg.__dict__) #pprint(cfg.__dict__)
self.assertEqual(cfg.what, None) self.assertEqual(cfg.what, None)
self.assertEqual(cfg.code, None) self.assertEqual(cfg.code, None)
@ -31,51 +23,51 @@ class Send(unittest.TestCase):
self.assertEqual(cfg.zeromode, False) self.assertEqual(cfg.zeromode, False)
def test_file(self): def test_file(self):
cfg = run(["send", "fn"]) cfg = config("send", "fn")
#pprint(cfg.__dict__) #pprint(cfg.__dict__)
self.assertEqual(cfg.what, u"fn") self.assertEqual(cfg.what, u"fn")
self.assertEqual(cfg.text, None) self.assertEqual(cfg.text, None)
def test_text(self): def test_text(self):
cfg = run(["send", "--text", "hi"]) cfg = config("send", "--text", "hi")
self.assertEqual(cfg.what, None) self.assertEqual(cfg.what, None)
self.assertEqual(cfg.text, u"hi") self.assertEqual(cfg.text, u"hi")
def test_nolisten(self): def test_nolisten(self):
cfg = run(["--no-listen", "send", "fn"]) cfg = config("--no-listen", "send", "fn")
self.assertEqual(cfg.listen, False) self.assertEqual(cfg.listen, False)
def test_code(self): def test_code(self):
cfg = run(["send", "--code", "1-abc", "fn"]) cfg = config("send", "--code", "1-abc", "fn")
self.assertEqual(cfg.code, u"1-abc") self.assertEqual(cfg.code, u"1-abc")
def test_code_length(self): def test_code_length(self):
cfg = run(["-c", "3", "send", "fn"]) cfg = config("-c", "3", "send", "fn")
self.assertEqual(cfg.code_length, 3) self.assertEqual(cfg.code_length, 3)
def test_dump_timing(self): def test_dump_timing(self):
cfg = run(["--dump-timing", "tx.json", "send", "fn"]) cfg = config("--dump-timing", "tx.json", "send", "fn")
self.assertEqual(cfg.dump_timing, "tx.json") self.assertEqual(cfg.dump_timing, "tx.json")
def test_hide_progress(self): def test_hide_progress(self):
cfg = run(["--hide-progress", "send", "fn"]) cfg = config("--hide-progress", "send", "fn")
self.assertEqual(cfg.hide_progress, True) self.assertEqual(cfg.hide_progress, True)
def test_tor(self): def test_tor(self):
cfg = run(["--tor", "send", "fn"]) cfg = config("--tor", "send", "fn")
self.assertEqual(cfg.tor, True) self.assertEqual(cfg.tor, True)
def test_verify(self): def test_verify(self):
cfg = run(["--verify", "send", "fn"]) cfg = config("--verify", "send", "fn")
self.assertEqual(cfg.verify, True) self.assertEqual(cfg.verify, True)
def test_zeromode(self): def test_zeromode(self):
cfg = run(["send", "-0", "fn"]) cfg = config("send", "-0", "fn")
self.assertEqual(cfg.zeromode, True) self.assertEqual(cfg.zeromode, True)
class Receive(unittest.TestCase): class Receive(unittest.TestCase):
def test_baseline(self): def test_baseline(self):
cfg = run(["receive"]) cfg = config("receive")
#pprint(cfg.__dict__) #pprint(cfg.__dict__)
self.assertEqual(cfg.accept_file, False) self.assertEqual(cfg.accept_file, False)
self.assertEqual(cfg.what, None) self.assertEqual(cfg.what, None)
@ -94,45 +86,53 @@ class Receive(unittest.TestCase):
self.assertEqual(cfg.zeromode, False) self.assertEqual(cfg.zeromode, False)
def test_nolisten(self): def test_nolisten(self):
cfg = run(["--no-listen", "receive"]) cfg = config("--no-listen", "receive")
self.assertEqual(cfg.listen, False) self.assertEqual(cfg.listen, False)
def test_code(self): def test_code(self):
cfg = run(["receive", "1-abc"]) cfg = config("receive", "1-abc")
self.assertEqual(cfg.code, u"1-abc") self.assertEqual(cfg.code, u"1-abc")
def test_code_length(self): def test_code_length(self):
cfg = run(["-c", "3", "receive"]) cfg = config("-c", "3", "receive")
self.assertEqual(cfg.code_length, 3) self.assertEqual(cfg.code_length, 3)
def test_dump_timing(self): def test_dump_timing(self):
cfg = run(["--dump-timing", "tx.json", "receive"]) cfg = config("--dump-timing", "tx.json", "receive")
self.assertEqual(cfg.dump_timing, "tx.json") self.assertEqual(cfg.dump_timing, "tx.json")
def test_hide_progress(self): def test_hide_progress(self):
cfg = run(["--hide-progress", "receive"]) cfg = config("--hide-progress", "receive")
self.assertEqual(cfg.hide_progress, True) self.assertEqual(cfg.hide_progress, True)
def test_tor(self): def test_tor(self):
cfg = run(["--tor", "receive"]) cfg = config("--tor", "receive")
self.assertEqual(cfg.tor, True) self.assertEqual(cfg.tor, True)
def test_verify(self): def test_verify(self):
cfg = run(["--verify", "receive"]) cfg = config("--verify", "receive")
self.assertEqual(cfg.verify, True) self.assertEqual(cfg.verify, True)
def test_zeromode(self): def test_zeromode(self):
cfg = run(["receive", "-0"]) cfg = config("receive", "-0")
self.assertEqual(cfg.zeromode, True) self.assertEqual(cfg.zeromode, True)
def test_only_text(self): def test_only_text(self):
cfg = run(["receive", "-t"]) cfg = config("receive", "-t")
self.assertEqual(cfg.only_text, True) self.assertEqual(cfg.only_text, True)
def test_accept_file(self): def test_accept_file(self):
cfg = run(["receive", "--accept-file"]) cfg = config("receive", "--accept-file")
self.assertEqual(cfg.accept_file, True) self.assertEqual(cfg.accept_file, True)
def test_output_file(self): def test_output_file(self):
cfg = run(["receive", "--output-file", "fn"]) cfg = config("receive", "--output-file", "fn")
self.assertEqual(cfg.output_file, u"fn") self.assertEqual(cfg.output_file, u"fn")
class Config(unittest.TestCase):
def test_send(self):
cfg = config("send")
self.assertEqual(cfg.stdout, sys.stdout)
def test_receive(self):
cfg = config("receive")
self.assertEqual(cfg.stdout, sys.stdout)

View File

@ -6,9 +6,8 @@ from twisted.python import procutils, log
from twisted.internet.utils import getProcessOutputAndValue from twisted.internet.utils import getProcessOutputAndValue
from twisted.internet.defer import gatherResults, inlineCallbacks from twisted.internet.defer import gatherResults, inlineCallbacks
from .. import __version__ from .. import __version__
from .common import ServerBase from .common import ServerBase, config
from ..cli import cmd_send, cmd_receive from ..cli import cmd_send, cmd_receive
from ..cli.cli import Config
from ..errors import TransferError, WrongPasswordError, WelcomeError from ..errors import TransferError, WrongPasswordError, WelcomeError
@ -20,7 +19,7 @@ def build_offer(args):
class OfferData(unittest.TestCase): class OfferData(unittest.TestCase):
def setUp(self): def setUp(self):
self._things_to_delete = [] self._things_to_delete = []
self.cfg = cfg = Config() self.cfg = cfg = config("send")
cfg.stdout = io.StringIO() cfg.stdout = io.StringIO()
cfg.stderr = io.StringIO() cfg.stderr = io.StringIO()
@ -226,8 +225,8 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase):
def _do_test(self, as_subprocess=False, def _do_test(self, as_subprocess=False,
mode="text", addslash=False, override_filename=False): mode="text", addslash=False, override_filename=False):
assert mode in ("text", "file", "directory") assert mode in ("text", "file", "directory")
send_cfg = Config() send_cfg = config("send")
recv_cfg = Config() recv_cfg = config("receive")
message = "blah blah blah ponies" message = "blah blah blah ponies"
for cfg in [send_cfg, recv_cfg]: for cfg in [send_cfg, recv_cfg]:
@ -449,8 +448,8 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase):
@inlineCallbacks @inlineCallbacks
def test_file_noclobber(self): def test_file_noclobber(self):
send_cfg = Config() send_cfg = config("send")
recv_cfg = Config() recv_cfg = config("receive")
for cfg in [send_cfg, recv_cfg]: for cfg in [send_cfg, recv_cfg]:
cfg.hide_progress = True cfg.hide_progress = True
@ -537,7 +536,7 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase):
class NotWelcome(ServerBase, unittest.TestCase): class NotWelcome(ServerBase, unittest.TestCase):
def setUp(self): def setUp(self):
self._setup_relay(error="please upgrade XYZ") self._setup_relay(error="please upgrade XYZ")
self.cfg = cfg = Config() self.cfg = cfg = config("send")
cfg.hide_progress = True cfg.hide_progress = True
cfg.listen = False cfg.listen = False
cfg.relay_url = self.relayurl cfg.relay_url = self.relayurl
@ -567,7 +566,7 @@ class Cleanup(ServerBase, unittest.TestCase):
def setUp(self): def setUp(self):
d = super(Cleanup, self).setUp() d = super(Cleanup, self).setUp()
self.cfg = cfg = Config() self.cfg = cfg = config("send")
# common options for all tests in this suite # common options for all tests in this suite
cfg.hide_progress = True cfg.hide_progress = True
cfg.relay_url = self.relayurl cfg.relay_url = self.relayurl