cli._dispatch_command: improve test coverage
This commit is contained in:
parent
82b4327f23
commit
d6d6669b23
|
@ -3,6 +3,7 @@ from __future__ import print_function
|
|||
import os
|
||||
import time
|
||||
start = time.time()
|
||||
import six
|
||||
from textwrap import fill, dedent
|
||||
from sys import stdout, stderr
|
||||
from . import public_relay
|
||||
|
@ -106,28 +107,28 @@ def _dispatch_command(reactor, cfg, command):
|
|||
yield maybeDeferred(command)
|
||||
except (WrongPasswordError, KeyFormatError, NoTorError) as e:
|
||||
msg = fill("ERROR: " + dedent(e.__doc__))
|
||||
print(msg, file=stderr)
|
||||
print(msg, file=cfg.stderr)
|
||||
raise SystemExit(1)
|
||||
except WelcomeError as e:
|
||||
msg = fill("ERROR: " + dedent(e.__doc__))
|
||||
print(msg, file=stderr)
|
||||
print(file=stderr)
|
||||
print(str(e), file=stderr)
|
||||
print(msg, file=cfg.stderr)
|
||||
print(six.u(""), file=cfg.stderr)
|
||||
print(six.text_type(e), file=cfg.stderr)
|
||||
raise SystemExit(1)
|
||||
except TransferError as e:
|
||||
print("TransferError: %s" % str(e), file=stderr)
|
||||
print(u"TransferError: %s" % six.text_type(e), file=cfg.stderr)
|
||||
raise SystemExit(1)
|
||||
except Exception as e:
|
||||
# this prints a proper traceback, whereas
|
||||
# traceback.print_exc() just prints a TB to the "yield"
|
||||
# line above ...
|
||||
Failure().printTraceback(file=stderr)
|
||||
print("ERROR:", e, file=stderr)
|
||||
Failure().printTraceback(file=cfg.stderr)
|
||||
print(u"ERROR:", six.text_type(e), file=cfg.stderr)
|
||||
raise SystemExit(1)
|
||||
|
||||
cfg.timing.add("exit")
|
||||
if cfg.dump_timing:
|
||||
cfg.timing.write(cfg.dump_timing, stderr)
|
||||
cfg.timing.write(cfg.dump_timing, cfg.stderr)
|
||||
|
||||
|
||||
CommonArgs = _compose(
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
from __future__ import print_function, unicode_literals
|
||||
import os, sys, re, io, zipfile, six, stat
|
||||
from textwrap import fill, dedent
|
||||
from humanize import naturalsize
|
||||
import mock
|
||||
from twisted.trial import unittest
|
||||
|
@ -9,7 +10,7 @@ from twisted.internet.utils import getProcessOutputAndValue
|
|||
from twisted.internet.defer import gatherResults, inlineCallbacks, returnValue
|
||||
from .. import __version__
|
||||
from .common import ServerBase, config
|
||||
from ..cli import cmd_send, cmd_receive, welcome
|
||||
from ..cli import cmd_send, cmd_receive, welcome, cli
|
||||
from ..errors import TransferError, WrongPasswordError, WelcomeError
|
||||
|
||||
|
||||
|
@ -989,3 +990,81 @@ class Welcome(unittest.TestCase):
|
|||
def test_motd(self):
|
||||
stderr = self.do({"motd": "hello"})
|
||||
self.assertEqual(stderr, "Server (at url) says:\n hello\n")
|
||||
|
||||
class Dispatch(unittest.TestCase):
|
||||
@inlineCallbacks
|
||||
def test_success(self):
|
||||
cfg = config("send")
|
||||
cfg.stderr = io.StringIO()
|
||||
called = []
|
||||
def fake():
|
||||
called.append(1)
|
||||
yield cli._dispatch_command(reactor, cfg, fake)
|
||||
self.assertEqual(called, [1])
|
||||
self.assertEqual(cfg.stderr.getvalue(), "")
|
||||
|
||||
@inlineCallbacks
|
||||
def test_timing(self):
|
||||
cfg = config("send")
|
||||
cfg.stderr = io.StringIO()
|
||||
cfg.timing = mock.Mock()
|
||||
cfg.dump_timing = "filename"
|
||||
def fake():
|
||||
pass
|
||||
yield cli._dispatch_command(reactor, cfg, fake)
|
||||
self.assertEqual(cfg.stderr.getvalue(), "")
|
||||
self.assertEqual(cfg.timing.mock_calls[-1],
|
||||
mock.call.write("filename", cfg.stderr))
|
||||
|
||||
@inlineCallbacks
|
||||
def test_wrong_password_error(self):
|
||||
cfg = config("send")
|
||||
cfg.stderr = io.StringIO()
|
||||
def fake():
|
||||
raise WrongPasswordError("abcd")
|
||||
yield self.assertFailure(cli._dispatch_command(reactor, cfg, fake),
|
||||
SystemExit)
|
||||
expected = fill("ERROR: " + dedent(WrongPasswordError.__doc__))+"\n"
|
||||
self.assertEqual(cfg.stderr.getvalue(), expected)
|
||||
|
||||
@inlineCallbacks
|
||||
def test_welcome_error(self):
|
||||
cfg = config("send")
|
||||
cfg.stderr = io.StringIO()
|
||||
def fake():
|
||||
raise WelcomeError("abcd")
|
||||
yield self.assertFailure(cli._dispatch_command(reactor, cfg, fake),
|
||||
SystemExit)
|
||||
expected = fill("ERROR: " + dedent(WelcomeError.__doc__))+"\n\nabcd\n"
|
||||
self.assertEqual(cfg.stderr.getvalue(), expected)
|
||||
|
||||
@inlineCallbacks
|
||||
def test_transfer_error(self):
|
||||
cfg = config("send")
|
||||
cfg.stderr = io.StringIO()
|
||||
def fake():
|
||||
raise TransferError("abcd")
|
||||
yield self.assertFailure(cli._dispatch_command(reactor, cfg, fake),
|
||||
SystemExit)
|
||||
expected = "TransferError: abcd\n"
|
||||
self.assertEqual(cfg.stderr.getvalue(), expected)
|
||||
|
||||
@inlineCallbacks
|
||||
def test_other_error(self):
|
||||
cfg = config("send")
|
||||
cfg.stderr = io.StringIO()
|
||||
def fake():
|
||||
raise ValueError("abcd")
|
||||
# I'm seeing unicode problems with the Failure().printTraceback, and
|
||||
# the output would be kind of unpredictable anyways, so we'll mock it
|
||||
# out here.
|
||||
f = mock.Mock()
|
||||
def mock_print(file):
|
||||
file.write("<TRACEBACK>\n")
|
||||
f.printTraceback = mock_print
|
||||
with mock.patch("wormhole.cli.cli.Failure", return_value=f):
|
||||
yield self.assertFailure(cli._dispatch_command(reactor, cfg, fake),
|
||||
SystemExit)
|
||||
expected = "<TRACEBACK>\nERROR: abcd\n"
|
||||
self.assertEqual(cfg.stderr.getvalue(), expected)
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user