diff --git a/src/wormhole/cli/cmd_ssh.py b/src/wormhole/cli/cmd_ssh.py index 6b5c4ee..715f0bc 100644 --- a/src/wormhole/cli/cmd_ssh.py +++ b/src/wormhole/cli/cmd_ssh.py @@ -8,6 +8,8 @@ import click from .. import xfer_util +class PubkeyError(Exception): + pass def find_public_key(hint=None): """ @@ -23,11 +25,11 @@ def find_public_key(hint=None): hint = expanduser('~/.ssh/') else: if not exists(hint): - raise RuntimeError("Can't find '{}'".format(hint)) + raise PubkeyError("Can't find '{}'".format(hint)) pubkeys = [f for f in os.listdir(hint) if f.endswith('.pub')] if len(pubkeys) == 0: - raise RuntimeError("No public keys in '{}'".format(hint)) + raise PubkeyError("No public keys in '{}'".format(hint)) elif len(pubkeys) > 1: got_key = False while not got_key: diff --git a/src/wormhole/test/test_ssh.py b/src/wormhole/test/test_ssh.py new file mode 100644 index 0000000..b81be35 --- /dev/null +++ b/src/wormhole/test/test_ssh.py @@ -0,0 +1,59 @@ +import os, io +import mock +from twisted.trial import unittest +from ..cli import cmd_ssh + +OTHERS = ["config", "config~", "known_hosts", "known_hosts~"] + +class FindPubkey(unittest.TestCase): + def test_find_one(self): + files = OTHERS + ["id_rsa.pub", "id_rsa"] + pubkey_data = b"ssh-rsa AAAAkeystuff email@host\n" + pubkey_file = io.BytesIO(pubkey_data) + with mock.patch("wormhole.cli.cmd_ssh.exists", return_value=True): + with mock.patch("os.listdir", return_value=files) as ld: + with mock.patch("wormhole.cli.cmd_ssh.open", + return_value=pubkey_file): + res = cmd_ssh.find_public_key() + self.assertEqual(ld.mock_calls, + [mock.call(os.path.expanduser("~/.ssh/"))]) + self.assertEqual(len(res), 3, res) + kind, keyid, pubkey = res + self.assertEqual(kind, "ssh-rsa") + self.assertEqual(keyid, "email@host") + self.assertEqual(pubkey, pubkey_data) + + def test_find_none(self): + files = OTHERS # no pubkey + with mock.patch("wormhole.cli.cmd_ssh.exists", return_value=True): + with mock.patch("os.listdir", return_value=files): + e = self.assertRaises(cmd_ssh.PubkeyError, + cmd_ssh.find_public_key) + dot_ssh = os.path.expanduser("~/.ssh/") + self.assertEqual(str(e), "No public keys in '{}'".format(dot_ssh)) + + def test_bad_hint(self): + with mock.patch("wormhole.cli.cmd_ssh.exists", return_value=False): + e = self.assertRaises(cmd_ssh.PubkeyError, + cmd_ssh.find_public_key, + hint="bogus/path") + self.assertEqual(str(e), "Can't find 'bogus/path'") + + + def test_find_multiple(self): + files = OTHERS + ["id_rsa.pub", "id_rsa", "id_dsa.pub", "id_dsa"] + pubkey_data = b"ssh-rsa AAAAkeystuff email@host\n" + pubkey_file = io.BytesIO(pubkey_data) + with mock.patch("wormhole.cli.cmd_ssh.exists", return_value=True): + with mock.patch("os.listdir", return_value=files): + responses = iter(["frog", "NaN", "-1", "0"]) + with mock.patch("click.prompt", + side_effect=lambda p: next(responses)): + with mock.patch("wormhole.cli.cmd_ssh.open", + return_value=pubkey_file): + res = cmd_ssh.find_public_key() + self.assertEqual(len(res), 3, res) + kind, keyid, pubkey = res + self.assertEqual(kind, "ssh-rsa") + self.assertEqual(keyid, "email@host") + self.assertEqual(pubkey, pubkey_data) diff --git a/src/wormhole/test/test_xfer_util.py b/src/wormhole/test/test_xfer_util.py new file mode 100644 index 0000000..1c08f3a --- /dev/null +++ b/src/wormhole/test/test_xfer_util.py @@ -0,0 +1,49 @@ +from twisted.trial import unittest +from twisted.internet import reactor, defer +from twisted.internet.defer import inlineCallbacks +from .. import xfer_util +from .common import ServerBase + +APPID = u"appid" + +class Xfer(ServerBase, unittest.TestCase): + @inlineCallbacks + def test_xfer(self): + code = u"1-code" + data = u"data" + d1 = xfer_util.send(reactor, APPID, self.relayurl, data, code) + d2 = xfer_util.receive(reactor, APPID, self.relayurl, code) + send_result = yield d1 + receive_result = yield d2 + self.assertEqual(send_result, None) + self.assertEqual(receive_result, data) + + @inlineCallbacks + def test_on_code(self): + code = u"1-code" + data = u"data" + send_code = [] + receive_code = [] + d1 = xfer_util.send(reactor, APPID, self.relayurl, data, code, + on_code=send_code.append) + d2 = xfer_util.receive(reactor, APPID, self.relayurl, code, + on_code=receive_code.append) + send_result = yield d1 + receive_result = yield d2 + self.assertEqual(send_code, [code]) + self.assertEqual(receive_code, [code]) + self.assertEqual(send_result, None) + self.assertEqual(receive_result, data) + + @inlineCallbacks + def test_make_code(self): + data = u"data" + got_code = defer.Deferred() + d1 = xfer_util.send(reactor, APPID, self.relayurl, data, code=None, + on_code=got_code.callback) + code = yield got_code + d2 = xfer_util.receive(reactor, APPID, self.relayurl, code) + send_result = yield d1 + receive_result = yield d2 + self.assertEqual(send_result, None) + self.assertEqual(receive_result, data)