cli._dispatch_command: improve test coverage

This commit is contained in:
Brian Warner 2017-04-23 15:56:19 -04:00
parent 82b4327f23
commit d6d6669b23
2 changed files with 89 additions and 9 deletions

View File

@ -3,6 +3,7 @@ from __future__ import print_function
import os import os
import time import time
start = time.time() start = time.time()
import six
from textwrap import fill, dedent from textwrap import fill, dedent
from sys import stdout, stderr from sys import stdout, stderr
from . import public_relay from . import public_relay
@ -106,28 +107,28 @@ def _dispatch_command(reactor, cfg, command):
yield maybeDeferred(command) yield maybeDeferred(command)
except (WrongPasswordError, KeyFormatError, NoTorError) as e: except (WrongPasswordError, KeyFormatError, NoTorError) as e:
msg = fill("ERROR: " + dedent(e.__doc__)) msg = fill("ERROR: " + dedent(e.__doc__))
print(msg, file=stderr) print(msg, file=cfg.stderr)
raise SystemExit(1) raise SystemExit(1)
except WelcomeError as e: except WelcomeError as e:
msg = fill("ERROR: " + dedent(e.__doc__)) msg = fill("ERROR: " + dedent(e.__doc__))
print(msg, file=stderr) print(msg, file=cfg.stderr)
print(file=stderr) print(six.u(""), file=cfg.stderr)
print(str(e), file=stderr) print(six.text_type(e), file=cfg.stderr)
raise SystemExit(1) raise SystemExit(1)
except TransferError as e: except TransferError as e:
print("TransferError: %s" % str(e), file=stderr) print(u"TransferError: %s" % six.text_type(e), file=cfg.stderr)
raise SystemExit(1) raise SystemExit(1)
except Exception as e: except Exception as e:
# this prints a proper traceback, whereas # this prints a proper traceback, whereas
# traceback.print_exc() just prints a TB to the "yield" # traceback.print_exc() just prints a TB to the "yield"
# line above ... # line above ...
Failure().printTraceback(file=stderr) Failure().printTraceback(file=cfg.stderr)
print("ERROR:", e, file=stderr) print(u"ERROR:", six.text_type(e), file=cfg.stderr)
raise SystemExit(1) raise SystemExit(1)
cfg.timing.add("exit") cfg.timing.add("exit")
if cfg.dump_timing: if cfg.dump_timing:
cfg.timing.write(cfg.dump_timing, stderr) cfg.timing.write(cfg.dump_timing, cfg.stderr)
CommonArgs = _compose( CommonArgs = _compose(

View File

@ -1,5 +1,6 @@
from __future__ import print_function, unicode_literals from __future__ import print_function, unicode_literals
import os, sys, re, io, zipfile, six, stat import os, sys, re, io, zipfile, six, stat
from textwrap import fill, dedent
from humanize import naturalsize from humanize import naturalsize
import mock import mock
from twisted.trial import unittest from twisted.trial import unittest
@ -9,7 +10,7 @@ from twisted.internet.utils import getProcessOutputAndValue
from twisted.internet.defer import gatherResults, inlineCallbacks, returnValue from twisted.internet.defer import gatherResults, inlineCallbacks, returnValue
from .. import __version__ from .. import __version__
from .common import ServerBase, config from .common import ServerBase, config
from ..cli import cmd_send, cmd_receive, welcome from ..cli import cmd_send, cmd_receive, welcome, cli
from ..errors import TransferError, WrongPasswordError, WelcomeError from ..errors import TransferError, WrongPasswordError, WelcomeError
@ -989,3 +990,81 @@ class Welcome(unittest.TestCase):
def test_motd(self): def test_motd(self):
stderr = self.do({"motd": "hello"}) stderr = self.do({"motd": "hello"})
self.assertEqual(stderr, "Server (at url) says:\n hello\n") self.assertEqual(stderr, "Server (at url) says:\n hello\n")
class Dispatch(unittest.TestCase):
@inlineCallbacks
def test_success(self):
cfg = config("send")
cfg.stderr = io.StringIO()
called = []
def fake():
called.append(1)
yield cli._dispatch_command(reactor, cfg, fake)
self.assertEqual(called, [1])
self.assertEqual(cfg.stderr.getvalue(), "")
@inlineCallbacks
def test_timing(self):
cfg = config("send")
cfg.stderr = io.StringIO()
cfg.timing = mock.Mock()
cfg.dump_timing = "filename"
def fake():
pass
yield cli._dispatch_command(reactor, cfg, fake)
self.assertEqual(cfg.stderr.getvalue(), "")
self.assertEqual(cfg.timing.mock_calls[-1],
mock.call.write("filename", cfg.stderr))
@inlineCallbacks
def test_wrong_password_error(self):
cfg = config("send")
cfg.stderr = io.StringIO()
def fake():
raise WrongPasswordError("abcd")
yield self.assertFailure(cli._dispatch_command(reactor, cfg, fake),
SystemExit)
expected = fill("ERROR: " + dedent(WrongPasswordError.__doc__))+"\n"
self.assertEqual(cfg.stderr.getvalue(), expected)
@inlineCallbacks
def test_welcome_error(self):
cfg = config("send")
cfg.stderr = io.StringIO()
def fake():
raise WelcomeError("abcd")
yield self.assertFailure(cli._dispatch_command(reactor, cfg, fake),
SystemExit)
expected = fill("ERROR: " + dedent(WelcomeError.__doc__))+"\n\nabcd\n"
self.assertEqual(cfg.stderr.getvalue(), expected)
@inlineCallbacks
def test_transfer_error(self):
cfg = config("send")
cfg.stderr = io.StringIO()
def fake():
raise TransferError("abcd")
yield self.assertFailure(cli._dispatch_command(reactor, cfg, fake),
SystemExit)
expected = "TransferError: abcd\n"
self.assertEqual(cfg.stderr.getvalue(), expected)
@inlineCallbacks
def test_other_error(self):
cfg = config("send")
cfg.stderr = io.StringIO()
def fake():
raise ValueError("abcd")
# I'm seeing unicode problems with the Failure().printTraceback, and
# the output would be kind of unpredictable anyways, so we'll mock it
# out here.
f = mock.Mock()
def mock_print(file):
file.write("<TRACEBACK>\n")
f.printTraceback = mock_print
with mock.patch("wormhole.cli.cli.Failure", return_value=f):
yield self.assertFailure(cli._dispatch_command(reactor, cfg, fake),
SystemExit)
expected = "<TRACEBACK>\nERROR: abcd\n"
self.assertEqual(cfg.stderr.getvalue(), expected)