cli._dispatch_command: improve test coverage
This commit is contained in:
parent
82b4327f23
commit
d6d6669b23
|
@ -3,6 +3,7 @@ from __future__ import print_function
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
start = time.time()
|
start = time.time()
|
||||||
|
import six
|
||||||
from textwrap import fill, dedent
|
from textwrap import fill, dedent
|
||||||
from sys import stdout, stderr
|
from sys import stdout, stderr
|
||||||
from . import public_relay
|
from . import public_relay
|
||||||
|
@ -106,28 +107,28 @@ def _dispatch_command(reactor, cfg, command):
|
||||||
yield maybeDeferred(command)
|
yield maybeDeferred(command)
|
||||||
except (WrongPasswordError, KeyFormatError, NoTorError) as e:
|
except (WrongPasswordError, KeyFormatError, NoTorError) as e:
|
||||||
msg = fill("ERROR: " + dedent(e.__doc__))
|
msg = fill("ERROR: " + dedent(e.__doc__))
|
||||||
print(msg, file=stderr)
|
print(msg, file=cfg.stderr)
|
||||||
raise SystemExit(1)
|
raise SystemExit(1)
|
||||||
except WelcomeError as e:
|
except WelcomeError as e:
|
||||||
msg = fill("ERROR: " + dedent(e.__doc__))
|
msg = fill("ERROR: " + dedent(e.__doc__))
|
||||||
print(msg, file=stderr)
|
print(msg, file=cfg.stderr)
|
||||||
print(file=stderr)
|
print(six.u(""), file=cfg.stderr)
|
||||||
print(str(e), file=stderr)
|
print(six.text_type(e), file=cfg.stderr)
|
||||||
raise SystemExit(1)
|
raise SystemExit(1)
|
||||||
except TransferError as e:
|
except TransferError as e:
|
||||||
print("TransferError: %s" % str(e), file=stderr)
|
print(u"TransferError: %s" % six.text_type(e), file=cfg.stderr)
|
||||||
raise SystemExit(1)
|
raise SystemExit(1)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# this prints a proper traceback, whereas
|
# this prints a proper traceback, whereas
|
||||||
# traceback.print_exc() just prints a TB to the "yield"
|
# traceback.print_exc() just prints a TB to the "yield"
|
||||||
# line above ...
|
# line above ...
|
||||||
Failure().printTraceback(file=stderr)
|
Failure().printTraceback(file=cfg.stderr)
|
||||||
print("ERROR:", e, file=stderr)
|
print(u"ERROR:", six.text_type(e), file=cfg.stderr)
|
||||||
raise SystemExit(1)
|
raise SystemExit(1)
|
||||||
|
|
||||||
cfg.timing.add("exit")
|
cfg.timing.add("exit")
|
||||||
if cfg.dump_timing:
|
if cfg.dump_timing:
|
||||||
cfg.timing.write(cfg.dump_timing, stderr)
|
cfg.timing.write(cfg.dump_timing, cfg.stderr)
|
||||||
|
|
||||||
|
|
||||||
CommonArgs = _compose(
|
CommonArgs = _compose(
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
from __future__ import print_function, unicode_literals
|
from __future__ import print_function, unicode_literals
|
||||||
import os, sys, re, io, zipfile, six, stat
|
import os, sys, re, io, zipfile, six, stat
|
||||||
|
from textwrap import fill, dedent
|
||||||
from humanize import naturalsize
|
from humanize import naturalsize
|
||||||
import mock
|
import mock
|
||||||
from twisted.trial import unittest
|
from twisted.trial import unittest
|
||||||
|
@ -9,7 +10,7 @@ from twisted.internet.utils import getProcessOutputAndValue
|
||||||
from twisted.internet.defer import gatherResults, inlineCallbacks, returnValue
|
from twisted.internet.defer import gatherResults, inlineCallbacks, returnValue
|
||||||
from .. import __version__
|
from .. import __version__
|
||||||
from .common import ServerBase, config
|
from .common import ServerBase, config
|
||||||
from ..cli import cmd_send, cmd_receive, welcome
|
from ..cli import cmd_send, cmd_receive, welcome, cli
|
||||||
from ..errors import TransferError, WrongPasswordError, WelcomeError
|
from ..errors import TransferError, WrongPasswordError, WelcomeError
|
||||||
|
|
||||||
|
|
||||||
|
@ -989,3 +990,81 @@ class Welcome(unittest.TestCase):
|
||||||
def test_motd(self):
|
def test_motd(self):
|
||||||
stderr = self.do({"motd": "hello"})
|
stderr = self.do({"motd": "hello"})
|
||||||
self.assertEqual(stderr, "Server (at url) says:\n hello\n")
|
self.assertEqual(stderr, "Server (at url) says:\n hello\n")
|
||||||
|
|
||||||
|
class Dispatch(unittest.TestCase):
|
||||||
|
@inlineCallbacks
|
||||||
|
def test_success(self):
|
||||||
|
cfg = config("send")
|
||||||
|
cfg.stderr = io.StringIO()
|
||||||
|
called = []
|
||||||
|
def fake():
|
||||||
|
called.append(1)
|
||||||
|
yield cli._dispatch_command(reactor, cfg, fake)
|
||||||
|
self.assertEqual(called, [1])
|
||||||
|
self.assertEqual(cfg.stderr.getvalue(), "")
|
||||||
|
|
||||||
|
@inlineCallbacks
|
||||||
|
def test_timing(self):
|
||||||
|
cfg = config("send")
|
||||||
|
cfg.stderr = io.StringIO()
|
||||||
|
cfg.timing = mock.Mock()
|
||||||
|
cfg.dump_timing = "filename"
|
||||||
|
def fake():
|
||||||
|
pass
|
||||||
|
yield cli._dispatch_command(reactor, cfg, fake)
|
||||||
|
self.assertEqual(cfg.stderr.getvalue(), "")
|
||||||
|
self.assertEqual(cfg.timing.mock_calls[-1],
|
||||||
|
mock.call.write("filename", cfg.stderr))
|
||||||
|
|
||||||
|
@inlineCallbacks
|
||||||
|
def test_wrong_password_error(self):
|
||||||
|
cfg = config("send")
|
||||||
|
cfg.stderr = io.StringIO()
|
||||||
|
def fake():
|
||||||
|
raise WrongPasswordError("abcd")
|
||||||
|
yield self.assertFailure(cli._dispatch_command(reactor, cfg, fake),
|
||||||
|
SystemExit)
|
||||||
|
expected = fill("ERROR: " + dedent(WrongPasswordError.__doc__))+"\n"
|
||||||
|
self.assertEqual(cfg.stderr.getvalue(), expected)
|
||||||
|
|
||||||
|
@inlineCallbacks
|
||||||
|
def test_welcome_error(self):
|
||||||
|
cfg = config("send")
|
||||||
|
cfg.stderr = io.StringIO()
|
||||||
|
def fake():
|
||||||
|
raise WelcomeError("abcd")
|
||||||
|
yield self.assertFailure(cli._dispatch_command(reactor, cfg, fake),
|
||||||
|
SystemExit)
|
||||||
|
expected = fill("ERROR: " + dedent(WelcomeError.__doc__))+"\n\nabcd\n"
|
||||||
|
self.assertEqual(cfg.stderr.getvalue(), expected)
|
||||||
|
|
||||||
|
@inlineCallbacks
|
||||||
|
def test_transfer_error(self):
|
||||||
|
cfg = config("send")
|
||||||
|
cfg.stderr = io.StringIO()
|
||||||
|
def fake():
|
||||||
|
raise TransferError("abcd")
|
||||||
|
yield self.assertFailure(cli._dispatch_command(reactor, cfg, fake),
|
||||||
|
SystemExit)
|
||||||
|
expected = "TransferError: abcd\n"
|
||||||
|
self.assertEqual(cfg.stderr.getvalue(), expected)
|
||||||
|
|
||||||
|
@inlineCallbacks
|
||||||
|
def test_other_error(self):
|
||||||
|
cfg = config("send")
|
||||||
|
cfg.stderr = io.StringIO()
|
||||||
|
def fake():
|
||||||
|
raise ValueError("abcd")
|
||||||
|
# I'm seeing unicode problems with the Failure().printTraceback, and
|
||||||
|
# the output would be kind of unpredictable anyways, so we'll mock it
|
||||||
|
# out here.
|
||||||
|
f = mock.Mock()
|
||||||
|
def mock_print(file):
|
||||||
|
file.write("<TRACEBACK>\n")
|
||||||
|
f.printTraceback = mock_print
|
||||||
|
with mock.patch("wormhole.cli.cli.Failure", return_value=f):
|
||||||
|
yield self.assertFailure(cli._dispatch_command(reactor, cfg, fake),
|
||||||
|
SystemExit)
|
||||||
|
expected = "<TRACEBACK>\nERROR: abcd\n"
|
||||||
|
self.assertEqual(cfg.stderr.getvalue(), expected)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user