CLI: refactor to make testing easier
When tests need a Config object, they now call a function which invokes Click with a mocked-out go() function, and grabs the Config object before actually doing anything with it.
This commit is contained in:
parent
cdb5c19010
commit
52ef00b46b
|
@ -177,8 +177,13 @@ def send(cfg, what, text, code, zeromode):
|
||||||
cfg.zeromode = zeromode
|
cfg.zeromode = zeromode
|
||||||
cfg.code = code
|
cfg.code = code
|
||||||
|
|
||||||
|
return go(cmd_send.send, cfg)
|
||||||
|
|
||||||
|
# this intermediate function can be mocked by tests that need to build a
|
||||||
|
# Config object
|
||||||
|
def go(f, cfg):
|
||||||
# note: react() does not return
|
# note: react() does not return
|
||||||
return react(_dispatch_command, (cfg, lambda: cmd_send.send(cfg)))
|
return react(_dispatch_command, (cfg, lambda: f(cfg)))
|
||||||
|
|
||||||
|
|
||||||
# wormhole receive (or "wormhole rx")
|
# wormhole receive (or "wormhole rx")
|
||||||
|
@ -228,5 +233,4 @@ def receive(cfg, code, zeromode, output_file, accept_file, only_text):
|
||||||
else:
|
else:
|
||||||
cfg.code = None
|
cfg.code = None
|
||||||
|
|
||||||
# note: react() does not return
|
return go(cmd_receive.receive, cfg)
|
||||||
return react(_dispatch_command, (cfg, lambda: cmd_receive.receive(cfg)))
|
|
||||||
|
|
|
@ -2,6 +2,9 @@
|
||||||
from twisted.application import service
|
from twisted.application import service
|
||||||
from twisted.internet import defer, task
|
from twisted.internet import defer, task
|
||||||
from twisted.python import log
|
from twisted.python import log
|
||||||
|
from click.testing import CliRunner
|
||||||
|
import mock
|
||||||
|
from ..cli import cli
|
||||||
from ..transit import allocate_tcp_port
|
from ..transit import allocate_tcp_port
|
||||||
from ..server.server import RelayServer
|
from ..server.server import RelayServer
|
||||||
from .. import __version__
|
from .. import __version__
|
||||||
|
@ -68,3 +71,16 @@ class ServerBase:
|
||||||
return d
|
return d
|
||||||
wait_d.addCallback(_later)
|
wait_d.addCallback(_later)
|
||||||
return wait_d
|
return wait_d
|
||||||
|
|
||||||
|
def config(*argv):
|
||||||
|
r = CliRunner()
|
||||||
|
with mock.patch("wormhole.cli.cli.go") as go:
|
||||||
|
res = r.invoke(cli.wormhole, argv, catch_exceptions=False)
|
||||||
|
if res.exit_code != 0:
|
||||||
|
print(res.exit_code)
|
||||||
|
print(res.output)
|
||||||
|
print(res)
|
||||||
|
assert 0
|
||||||
|
cfg = go.call_args[0][1]
|
||||||
|
return cfg
|
||||||
|
|
||||||
|
|
|
@ -1,20 +1,12 @@
|
||||||
import mock
|
import sys
|
||||||
from twisted.trial import unittest
|
from twisted.trial import unittest
|
||||||
from ..cli.cli import wormhole
|
|
||||||
from ..cli.public_relay import RENDEZVOUS_RELAY, TRANSIT_RELAY
|
from ..cli.public_relay import RENDEZVOUS_RELAY, TRANSIT_RELAY
|
||||||
from click.testing import CliRunner
|
from .common import config
|
||||||
#from pprint import pprint
|
#from pprint import pprint
|
||||||
|
|
||||||
def run(argv):
|
|
||||||
r = CliRunner()
|
|
||||||
with mock.patch("wormhole.cli.cli.react") as react:
|
|
||||||
r.invoke(wormhole, argv)
|
|
||||||
cfg = react.call_args[0][1][0]
|
|
||||||
return cfg
|
|
||||||
|
|
||||||
class Send(unittest.TestCase):
|
class Send(unittest.TestCase):
|
||||||
def test_baseline(self):
|
def test_baseline(self):
|
||||||
cfg = run(["send", "--text", "hi"])
|
cfg = config("send", "--text", "hi")
|
||||||
#pprint(cfg.__dict__)
|
#pprint(cfg.__dict__)
|
||||||
self.assertEqual(cfg.what, None)
|
self.assertEqual(cfg.what, None)
|
||||||
self.assertEqual(cfg.code, None)
|
self.assertEqual(cfg.code, None)
|
||||||
|
@ -31,51 +23,51 @@ class Send(unittest.TestCase):
|
||||||
self.assertEqual(cfg.zeromode, False)
|
self.assertEqual(cfg.zeromode, False)
|
||||||
|
|
||||||
def test_file(self):
|
def test_file(self):
|
||||||
cfg = run(["send", "fn"])
|
cfg = config("send", "fn")
|
||||||
#pprint(cfg.__dict__)
|
#pprint(cfg.__dict__)
|
||||||
self.assertEqual(cfg.what, u"fn")
|
self.assertEqual(cfg.what, u"fn")
|
||||||
self.assertEqual(cfg.text, None)
|
self.assertEqual(cfg.text, None)
|
||||||
|
|
||||||
def test_text(self):
|
def test_text(self):
|
||||||
cfg = run(["send", "--text", "hi"])
|
cfg = config("send", "--text", "hi")
|
||||||
self.assertEqual(cfg.what, None)
|
self.assertEqual(cfg.what, None)
|
||||||
self.assertEqual(cfg.text, u"hi")
|
self.assertEqual(cfg.text, u"hi")
|
||||||
|
|
||||||
def test_nolisten(self):
|
def test_nolisten(self):
|
||||||
cfg = run(["--no-listen", "send", "fn"])
|
cfg = config("--no-listen", "send", "fn")
|
||||||
self.assertEqual(cfg.listen, False)
|
self.assertEqual(cfg.listen, False)
|
||||||
|
|
||||||
def test_code(self):
|
def test_code(self):
|
||||||
cfg = run(["send", "--code", "1-abc", "fn"])
|
cfg = config("send", "--code", "1-abc", "fn")
|
||||||
self.assertEqual(cfg.code, u"1-abc")
|
self.assertEqual(cfg.code, u"1-abc")
|
||||||
|
|
||||||
def test_code_length(self):
|
def test_code_length(self):
|
||||||
cfg = run(["-c", "3", "send", "fn"])
|
cfg = config("-c", "3", "send", "fn")
|
||||||
self.assertEqual(cfg.code_length, 3)
|
self.assertEqual(cfg.code_length, 3)
|
||||||
|
|
||||||
def test_dump_timing(self):
|
def test_dump_timing(self):
|
||||||
cfg = run(["--dump-timing", "tx.json", "send", "fn"])
|
cfg = config("--dump-timing", "tx.json", "send", "fn")
|
||||||
self.assertEqual(cfg.dump_timing, "tx.json")
|
self.assertEqual(cfg.dump_timing, "tx.json")
|
||||||
|
|
||||||
def test_hide_progress(self):
|
def test_hide_progress(self):
|
||||||
cfg = run(["--hide-progress", "send", "fn"])
|
cfg = config("--hide-progress", "send", "fn")
|
||||||
self.assertEqual(cfg.hide_progress, True)
|
self.assertEqual(cfg.hide_progress, True)
|
||||||
|
|
||||||
def test_tor(self):
|
def test_tor(self):
|
||||||
cfg = run(["--tor", "send", "fn"])
|
cfg = config("--tor", "send", "fn")
|
||||||
self.assertEqual(cfg.tor, True)
|
self.assertEqual(cfg.tor, True)
|
||||||
|
|
||||||
def test_verify(self):
|
def test_verify(self):
|
||||||
cfg = run(["--verify", "send", "fn"])
|
cfg = config("--verify", "send", "fn")
|
||||||
self.assertEqual(cfg.verify, True)
|
self.assertEqual(cfg.verify, True)
|
||||||
|
|
||||||
def test_zeromode(self):
|
def test_zeromode(self):
|
||||||
cfg = run(["send", "-0", "fn"])
|
cfg = config("send", "-0", "fn")
|
||||||
self.assertEqual(cfg.zeromode, True)
|
self.assertEqual(cfg.zeromode, True)
|
||||||
|
|
||||||
class Receive(unittest.TestCase):
|
class Receive(unittest.TestCase):
|
||||||
def test_baseline(self):
|
def test_baseline(self):
|
||||||
cfg = run(["receive"])
|
cfg = config("receive")
|
||||||
#pprint(cfg.__dict__)
|
#pprint(cfg.__dict__)
|
||||||
self.assertEqual(cfg.accept_file, False)
|
self.assertEqual(cfg.accept_file, False)
|
||||||
self.assertEqual(cfg.what, None)
|
self.assertEqual(cfg.what, None)
|
||||||
|
@ -94,45 +86,53 @@ class Receive(unittest.TestCase):
|
||||||
self.assertEqual(cfg.zeromode, False)
|
self.assertEqual(cfg.zeromode, False)
|
||||||
|
|
||||||
def test_nolisten(self):
|
def test_nolisten(self):
|
||||||
cfg = run(["--no-listen", "receive"])
|
cfg = config("--no-listen", "receive")
|
||||||
self.assertEqual(cfg.listen, False)
|
self.assertEqual(cfg.listen, False)
|
||||||
|
|
||||||
def test_code(self):
|
def test_code(self):
|
||||||
cfg = run(["receive", "1-abc"])
|
cfg = config("receive", "1-abc")
|
||||||
self.assertEqual(cfg.code, u"1-abc")
|
self.assertEqual(cfg.code, u"1-abc")
|
||||||
|
|
||||||
def test_code_length(self):
|
def test_code_length(self):
|
||||||
cfg = run(["-c", "3", "receive"])
|
cfg = config("-c", "3", "receive")
|
||||||
self.assertEqual(cfg.code_length, 3)
|
self.assertEqual(cfg.code_length, 3)
|
||||||
|
|
||||||
def test_dump_timing(self):
|
def test_dump_timing(self):
|
||||||
cfg = run(["--dump-timing", "tx.json", "receive"])
|
cfg = config("--dump-timing", "tx.json", "receive")
|
||||||
self.assertEqual(cfg.dump_timing, "tx.json")
|
self.assertEqual(cfg.dump_timing, "tx.json")
|
||||||
|
|
||||||
def test_hide_progress(self):
|
def test_hide_progress(self):
|
||||||
cfg = run(["--hide-progress", "receive"])
|
cfg = config("--hide-progress", "receive")
|
||||||
self.assertEqual(cfg.hide_progress, True)
|
self.assertEqual(cfg.hide_progress, True)
|
||||||
|
|
||||||
def test_tor(self):
|
def test_tor(self):
|
||||||
cfg = run(["--tor", "receive"])
|
cfg = config("--tor", "receive")
|
||||||
self.assertEqual(cfg.tor, True)
|
self.assertEqual(cfg.tor, True)
|
||||||
|
|
||||||
def test_verify(self):
|
def test_verify(self):
|
||||||
cfg = run(["--verify", "receive"])
|
cfg = config("--verify", "receive")
|
||||||
self.assertEqual(cfg.verify, True)
|
self.assertEqual(cfg.verify, True)
|
||||||
|
|
||||||
def test_zeromode(self):
|
def test_zeromode(self):
|
||||||
cfg = run(["receive", "-0"])
|
cfg = config("receive", "-0")
|
||||||
self.assertEqual(cfg.zeromode, True)
|
self.assertEqual(cfg.zeromode, True)
|
||||||
|
|
||||||
def test_only_text(self):
|
def test_only_text(self):
|
||||||
cfg = run(["receive", "-t"])
|
cfg = config("receive", "-t")
|
||||||
self.assertEqual(cfg.only_text, True)
|
self.assertEqual(cfg.only_text, True)
|
||||||
|
|
||||||
def test_accept_file(self):
|
def test_accept_file(self):
|
||||||
cfg = run(["receive", "--accept-file"])
|
cfg = config("receive", "--accept-file")
|
||||||
self.assertEqual(cfg.accept_file, True)
|
self.assertEqual(cfg.accept_file, True)
|
||||||
|
|
||||||
def test_output_file(self):
|
def test_output_file(self):
|
||||||
cfg = run(["receive", "--output-file", "fn"])
|
cfg = config("receive", "--output-file", "fn")
|
||||||
self.assertEqual(cfg.output_file, u"fn")
|
self.assertEqual(cfg.output_file, u"fn")
|
||||||
|
|
||||||
|
class Config(unittest.TestCase):
|
||||||
|
def test_send(self):
|
||||||
|
cfg = config("send")
|
||||||
|
self.assertEqual(cfg.stdout, sys.stdout)
|
||||||
|
def test_receive(self):
|
||||||
|
cfg = config("receive")
|
||||||
|
self.assertEqual(cfg.stdout, sys.stdout)
|
||||||
|
|
|
@ -6,9 +6,8 @@ from twisted.python import procutils, log
|
||||||
from twisted.internet.utils import getProcessOutputAndValue
|
from twisted.internet.utils import getProcessOutputAndValue
|
||||||
from twisted.internet.defer import gatherResults, inlineCallbacks
|
from twisted.internet.defer import gatherResults, inlineCallbacks
|
||||||
from .. import __version__
|
from .. import __version__
|
||||||
from .common import ServerBase
|
from .common import ServerBase, config
|
||||||
from ..cli import cmd_send, cmd_receive
|
from ..cli import cmd_send, cmd_receive
|
||||||
from ..cli.cli import Config
|
|
||||||
from ..errors import TransferError, WrongPasswordError, WelcomeError
|
from ..errors import TransferError, WrongPasswordError, WelcomeError
|
||||||
|
|
||||||
|
|
||||||
|
@ -20,7 +19,7 @@ def build_offer(args):
|
||||||
class OfferData(unittest.TestCase):
|
class OfferData(unittest.TestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self._things_to_delete = []
|
self._things_to_delete = []
|
||||||
self.cfg = cfg = Config()
|
self.cfg = cfg = config("send")
|
||||||
cfg.stdout = io.StringIO()
|
cfg.stdout = io.StringIO()
|
||||||
cfg.stderr = io.StringIO()
|
cfg.stderr = io.StringIO()
|
||||||
|
|
||||||
|
@ -226,8 +225,8 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase):
|
||||||
def _do_test(self, as_subprocess=False,
|
def _do_test(self, as_subprocess=False,
|
||||||
mode="text", addslash=False, override_filename=False):
|
mode="text", addslash=False, override_filename=False):
|
||||||
assert mode in ("text", "file", "directory")
|
assert mode in ("text", "file", "directory")
|
||||||
send_cfg = Config()
|
send_cfg = config("send")
|
||||||
recv_cfg = Config()
|
recv_cfg = config("receive")
|
||||||
message = "blah blah blah ponies"
|
message = "blah blah blah ponies"
|
||||||
|
|
||||||
for cfg in [send_cfg, recv_cfg]:
|
for cfg in [send_cfg, recv_cfg]:
|
||||||
|
@ -449,8 +448,8 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase):
|
||||||
|
|
||||||
@inlineCallbacks
|
@inlineCallbacks
|
||||||
def test_file_noclobber(self):
|
def test_file_noclobber(self):
|
||||||
send_cfg = Config()
|
send_cfg = config("send")
|
||||||
recv_cfg = Config()
|
recv_cfg = config("receive")
|
||||||
|
|
||||||
for cfg in [send_cfg, recv_cfg]:
|
for cfg in [send_cfg, recv_cfg]:
|
||||||
cfg.hide_progress = True
|
cfg.hide_progress = True
|
||||||
|
@ -537,7 +536,7 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase):
|
||||||
class NotWelcome(ServerBase, unittest.TestCase):
|
class NotWelcome(ServerBase, unittest.TestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self._setup_relay(error="please upgrade XYZ")
|
self._setup_relay(error="please upgrade XYZ")
|
||||||
self.cfg = cfg = Config()
|
self.cfg = cfg = config("send")
|
||||||
cfg.hide_progress = True
|
cfg.hide_progress = True
|
||||||
cfg.listen = False
|
cfg.listen = False
|
||||||
cfg.relay_url = self.relayurl
|
cfg.relay_url = self.relayurl
|
||||||
|
@ -567,7 +566,7 @@ class Cleanup(ServerBase, unittest.TestCase):
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
d = super(Cleanup, self).setUp()
|
d = super(Cleanup, self).setUp()
|
||||||
self.cfg = cfg = Config()
|
self.cfg = cfg = config("send")
|
||||||
# common options for all tests in this suite
|
# common options for all tests in this suite
|
||||||
cfg.hide_progress = True
|
cfg.hide_progress = True
|
||||||
cfg.relay_url = self.relayurl
|
cfg.relay_url = self.relayurl
|
||||||
|
|
Loading…
Reference in New Issue
Block a user