diff --git a/src/wormhole/errors.py b/src/wormhole/errors.py index f994659..cf039fd 100644 --- a/src/wormhole/errors.py +++ b/src/wormhole/errors.py @@ -1,6 +1,5 @@ from __future__ import unicode_literals import functools -import click class ServerError(Exception): def __init__(self, message, relay): @@ -48,14 +47,10 @@ class KeyFormatError(Exception): class ReflectionAttack(Exception): """An attacker (or bug) reflected our outgoing message back to us.""" -# Click needs to receive click.UsageError instances to "do the right -# thing", which is print the error and exit -- perhaps it would be -# better just to re-export click.UsageError here? Or use -# click.UsageError throughout the codebase? -class UsageError(click.UsageError, Exception): +class InternalError(Exception): """The programmer did something wrong.""" -class WormholeClosedError(UsageError): +class WormholeClosedError(InternalError): """API calls may not be made after close() is called.""" class TransferError(Exception): diff --git a/src/wormhole/server/cmd_server.py b/src/wormhole/server/cmd_server.py index 9f9583d..6c3c6cc 100644 --- a/src/wormhole/server/cmd_server.py +++ b/src/wormhole/server/cmd_server.py @@ -1,6 +1,6 @@ from __future__ import print_function, unicode_literals import os, time -from ..errors import UsageError +import click from twisted.python import usage from twisted.scripts import twistd @@ -39,7 +39,7 @@ def kill_server(): try: f = open("twistd.pid", "r") except EnvironmentError: - raise UsageError( + raise click.UsageError( "Unable to find 'twistd.pid' -- is this really a server directory?" ) pid = int(f.read().strip()) diff --git a/src/wormhole/server/cmd_usage.py b/src/wormhole/server/cmd_usage.py index fd294e7..c314313 100644 --- a/src/wormhole/server/cmd_usage.py +++ b/src/wormhole/server/cmd_usage.py @@ -1,8 +1,8 @@ from __future__ import print_function, unicode_literals import os, time, json from collections import defaultdict +import click from .database import get_db -from ..errors import UsageError def abbrev(t): if t is None: @@ -57,7 +57,9 @@ def show_usage(args): print("closed for renovation") return 0 if not os.path.exists("relay.sqlite"): - raise UsageError("cannot find relay.sqlite, please run from the server directory") + raise click.UsageError( + "cannot find relay.sqlite, please run from the server directory" + ) oldest = None newest = None rendezvous_counters = defaultdict(int) @@ -116,7 +118,9 @@ def show_usage(args): def tail_usage(args): if not os.path.exists("relay.sqlite"): - raise UsageError("cannot find relay.sqlite, please run from the server directory") + raise click.UsageError( + "cannot find relay.sqlite, please run from the server directory" + ) db = get_db("relay.sqlite") # we don't seem to have unique row IDs, so this is an inaccurate and # inefficient hack @@ -141,7 +145,9 @@ def tail_usage(args): def count_channels(args): if not os.path.exists("relay.sqlite"): - raise UsageError("cannot find relay.sqlite, please run from the server directory") + raise click.UsageError( + "cannot find relay.sqlite, please run from the server directory" + ) db = get_db("relay.sqlite") c_list = [] c_dict = {} @@ -188,7 +194,9 @@ def count_channels(args): def count_events(args): if not os.path.exists("relay.sqlite"): - raise UsageError("cannot find relay.sqlite, please run from the server directory") + raise click.UsageError( + "cannot find relay.sqlite, please run from the server directory" + ) db = get_db("relay.sqlite") c_list = [] c_dict = {} diff --git a/src/wormhole/test/test_transit.py b/src/wormhole/test/test_transit.py index 6f3f0cd..bdc5806 100644 --- a/src/wormhole/test/test_transit.py +++ b/src/wormhole/test/test_transit.py @@ -6,8 +6,8 @@ from twisted.internet import defer, task, endpoints, protocol, address, error from twisted.internet.defer import gatherResults, inlineCallbacks from twisted.python import log, failure from twisted.test import proto_helpers +from ..errors import InternalError from .. import transit -from ..errors import UsageError from nacl.secret import SecretBox from nacl.exceptions import CryptoError @@ -147,7 +147,7 @@ class Basic(unittest.TestCase): "hostname": "host", "port": 1234}], }]) - self.assertRaises(UsageError, transit.Common, 123) + self.assertRaises(InternalError, transit.Common, 123) @inlineCallbacks def test_no_relay_hints(self): diff --git a/src/wormhole/test/test_wormhole.py b/src/wormhole/test/test_wormhole.py index 2b795f2..67ad14f 100644 --- a/src/wormhole/test/test_wormhole.py +++ b/src/wormhole/test/test_wormhole.py @@ -7,7 +7,7 @@ from twisted.internet import reactor from twisted.internet.defer import Deferred, gatherResults, inlineCallbacks from .common import ServerBase from .. import wormhole -from ..errors import (WrongPasswordError, WelcomeError, UsageError, +from ..errors import (WrongPasswordError, WelcomeError, InternalError, KeyFormatError) from spake2 import SPAKE2_Symmetric from ..timing import DebugTiming @@ -876,20 +876,20 @@ class Errors(ServerBase, unittest.TestCase): def test_codes_1(self): w = wormhole.wormhole(APPID, self.relayurl, reactor) # definitely too early - self.assertRaises(UsageError, w.derive_key, "purpose", 12) + self.assertRaises(InternalError, w.derive_key, "purpose", 12) w.set_code("123-purple-elephant") # code can only be set once - self.assertRaises(UsageError, w.set_code, "123-nope") - yield self.assertFailure(w.get_code(), UsageError) - yield self.assertFailure(w.input_code(), UsageError) + self.assertRaises(InternalError, w.set_code, "123-nope") + yield self.assertFailure(w.get_code(), InternalError) + yield self.assertFailure(w.input_code(), InternalError) yield w.close() @inlineCallbacks def test_codes_2(self): w = wormhole.wormhole(APPID, self.relayurl, reactor) yield w.get_code() - self.assertRaises(UsageError, w.set_code, "123-nope") - yield self.assertFailure(w.get_code(), UsageError) - yield self.assertFailure(w.input_code(), UsageError) + self.assertRaises(InternalError, w.set_code, "123-nope") + yield self.assertFailure(w.get_code(), InternalError) + yield self.assertFailure(w.input_code(), InternalError) yield w.close() diff --git a/src/wormhole/transit.py b/src/wormhole/transit.py index 92ac4df..23a116c 100644 --- a/src/wormhole/transit.py +++ b/src/wormhole/transit.py @@ -13,7 +13,7 @@ from twisted.internet.defer import inlineCallbacks, returnValue from twisted.protocols import policies from nacl.secret import SecretBox from hkdf import Hkdf -from .errors import UsageError +from .errors import InternalError from .timing import DebugTiming from . import ipaddrs @@ -276,7 +276,7 @@ class Connection(protocol.Protocol, policies.TimeoutMixin): return self._description def send_record(self, record): - if not isinstance(record, type(b"")): raise UsageError + if not isinstance(record, type(b"")): raise InternalError assert SecretBox.NONCE_SIZE == 24 assert self.send_nonce < 2**(8*24) assert len(record) < 2**(8*4) @@ -577,7 +577,7 @@ class Common: reactor=reactor, timing=None): if transit_relay: if not isinstance(transit_relay, type(u"")): - raise UsageError + raise InternalError relay = RelayV1Hint(hints=[parse_hint_argv(transit_relay)]) self._transit_relays = [relay] else: diff --git a/src/wormhole/wormhole.py b/src/wormhole/wormhole.py index 0c77e69..11fbd12 100644 --- a/src/wormhole/wormhole.py +++ b/src/wormhole/wormhole.py @@ -14,7 +14,7 @@ from hashlib import sha256 from . import __version__ from . import codes #from .errors import ServerError, Timeout -from .errors import (WrongPasswordError, UsageError, WelcomeError, +from .errors import (WrongPasswordError, InternalError, WelcomeError, WormholeClosedError, KeyFormatError) from .timing import DebugTiming from .util import (to_bytes, bytes_to_hexstr, hexstr_to_bytes, @@ -425,8 +425,8 @@ class _Wormhole: # entry point 1: generate a new code @inlineCallbacks def _API_get_code(self, code_length): - if self._code is not None: raise UsageError - if self._started_get_code: raise UsageError + if self._code is not None: raise InternalError + if self._started_get_code: raise InternalError self._started_get_code = True with self._timing.add("API get_code"): yield self._when_connected() @@ -443,8 +443,8 @@ class _Wormhole: # entry point 2: interactively type in a code, with completion @inlineCallbacks def _API_input_code(self, prompt, code_length): - if self._code is not None: raise UsageError - if self._started_input_code: raise UsageError + if self._code is not None: raise InternalError + if self._started_input_code: raise InternalError self._started_input_code = True with self._timing.add("API input_code"): yield self._when_connected() @@ -465,7 +465,7 @@ class _Wormhole: if not isinstance(code, type(u"")): raise TypeError("Unexpected code type '{}'".format(type(code))) if self._code is not None: - raise UsageError + raise InternalError self._event_learned_code(code) # TODO: entry point 4: restore pre-contact saved state (we haven't heard @@ -528,7 +528,7 @@ class _Wormhole: self._event_learned_mailbox() def _event_learned_mailbox(self): - if not self._mailbox_id: raise UsageError + if not self._mailbox_id: raise InternalError assert self._mailbox_state == CLOSED, self._mailbox_state if self._closing: return @@ -580,7 +580,7 @@ class _Wormhole: def _API_verify(self): if self._error: return defer.fail(self._error) - if self._get_verifier_called: raise UsageError + if self._get_verifier_called: raise InternalError self._get_verifier_called = True if self._verify_result: return defer.succeed(self._verify_result) # bytes or Failure @@ -685,7 +685,7 @@ class _Wormhole: return box.encrypt(data, nonce) def _msg_send(self, phase, body): - if phase in self._sent_phases: raise UsageError + if phase in self._sent_phases: raise InternalError assert self._mailbox_state == OPEN, self._mailbox_state self._sent_phases.add(phase) # TODO: retry on failure, with exponential backoff. We're guarding @@ -702,14 +702,14 @@ class _Wormhole: def _API_derive_key(self, purpose, length): if self._error: raise self._error if self._key is None: - raise UsageError # call derive_key after get_verifier() or get() + raise InternalError # call derive_key after get_verifier() or get() if not isinstance(purpose, type("")): raise TypeError(type(purpose)) return self._derive_key(to_bytes(purpose), length) def _derive_key(self, purpose, length=SecretBox.KEY_SIZE): if not isinstance(purpose, type(b"")): raise TypeError(type(purpose)) if self._key is None: - raise UsageError # call derive_key after get_verifier() or get() + raise InternalError # call derive_key after get_verifier() or get() return HKDF(self._key, length, CTXinfo=purpose) def _response_handle_message(self, msg): @@ -784,7 +784,7 @@ class _Wormhole: @inlineCallbacks def _API_close(self, res, mood="happy"): if self.DEBUG: print("close") - if self._close_called: raise UsageError + if self._close_called: raise InternalError self._close_called = True self._maybe_close(WormholeClosedError(), mood) if self.DEBUG: print("waiting for disconnect")