Change UsageError -> InternalError, use click.UsageError for human-visible errors

This commit is contained in:
meejah 2016-06-22 02:04:05 -06:00
parent f32cd46e2c
commit 7fab6b3dff
7 changed files with 42 additions and 39 deletions

View File

@ -1,6 +1,5 @@
from __future__ import unicode_literals from __future__ import unicode_literals
import functools import functools
import click
class ServerError(Exception): class ServerError(Exception):
def __init__(self, message, relay): def __init__(self, message, relay):
@ -48,14 +47,10 @@ class KeyFormatError(Exception):
class ReflectionAttack(Exception): class ReflectionAttack(Exception):
"""An attacker (or bug) reflected our outgoing message back to us.""" """An attacker (or bug) reflected our outgoing message back to us."""
# Click needs to receive click.UsageError instances to "do the right class InternalError(Exception):
# thing", which is print the error and exit -- perhaps it would be
# better just to re-export click.UsageError here? Or use
# click.UsageError throughout the codebase?
class UsageError(click.UsageError, Exception):
"""The programmer did something wrong.""" """The programmer did something wrong."""
class WormholeClosedError(UsageError): class WormholeClosedError(InternalError):
"""API calls may not be made after close() is called.""" """API calls may not be made after close() is called."""
class TransferError(Exception): class TransferError(Exception):

View File

@ -1,6 +1,6 @@
from __future__ import print_function, unicode_literals from __future__ import print_function, unicode_literals
import os, time import os, time
from ..errors import UsageError import click
from twisted.python import usage from twisted.python import usage
from twisted.scripts import twistd from twisted.scripts import twistd
@ -39,7 +39,7 @@ def kill_server():
try: try:
f = open("twistd.pid", "r") f = open("twistd.pid", "r")
except EnvironmentError: except EnvironmentError:
raise UsageError( raise click.UsageError(
"Unable to find 'twistd.pid' -- is this really a server directory?" "Unable to find 'twistd.pid' -- is this really a server directory?"
) )
pid = int(f.read().strip()) pid = int(f.read().strip())

View File

@ -1,8 +1,8 @@
from __future__ import print_function, unicode_literals from __future__ import print_function, unicode_literals
import os, time, json import os, time, json
from collections import defaultdict from collections import defaultdict
import click
from .database import get_db from .database import get_db
from ..errors import UsageError
def abbrev(t): def abbrev(t):
if t is None: if t is None:
@ -57,7 +57,9 @@ def show_usage(args):
print("closed for renovation") print("closed for renovation")
return 0 return 0
if not os.path.exists("relay.sqlite"): if not os.path.exists("relay.sqlite"):
raise UsageError("cannot find relay.sqlite, please run from the server directory") raise click.UsageError(
"cannot find relay.sqlite, please run from the server directory"
)
oldest = None oldest = None
newest = None newest = None
rendezvous_counters = defaultdict(int) rendezvous_counters = defaultdict(int)
@ -116,7 +118,9 @@ def show_usage(args):
def tail_usage(args): def tail_usage(args):
if not os.path.exists("relay.sqlite"): if not os.path.exists("relay.sqlite"):
raise UsageError("cannot find relay.sqlite, please run from the server directory") raise click.UsageError(
"cannot find relay.sqlite, please run from the server directory"
)
db = get_db("relay.sqlite") db = get_db("relay.sqlite")
# we don't seem to have unique row IDs, so this is an inaccurate and # we don't seem to have unique row IDs, so this is an inaccurate and
# inefficient hack # inefficient hack
@ -141,7 +145,9 @@ def tail_usage(args):
def count_channels(args): def count_channels(args):
if not os.path.exists("relay.sqlite"): if not os.path.exists("relay.sqlite"):
raise UsageError("cannot find relay.sqlite, please run from the server directory") raise click.UsageError(
"cannot find relay.sqlite, please run from the server directory"
)
db = get_db("relay.sqlite") db = get_db("relay.sqlite")
c_list = [] c_list = []
c_dict = {} c_dict = {}
@ -188,7 +194,9 @@ def count_channels(args):
def count_events(args): def count_events(args):
if not os.path.exists("relay.sqlite"): if not os.path.exists("relay.sqlite"):
raise UsageError("cannot find relay.sqlite, please run from the server directory") raise click.UsageError(
"cannot find relay.sqlite, please run from the server directory"
)
db = get_db("relay.sqlite") db = get_db("relay.sqlite")
c_list = [] c_list = []
c_dict = {} c_dict = {}

View File

@ -6,8 +6,8 @@ from twisted.internet import defer, task, endpoints, protocol, address, error
from twisted.internet.defer import gatherResults, inlineCallbacks from twisted.internet.defer import gatherResults, inlineCallbacks
from twisted.python import log, failure from twisted.python import log, failure
from twisted.test import proto_helpers from twisted.test import proto_helpers
from ..errors import InternalError
from .. import transit from .. import transit
from ..errors import UsageError
from nacl.secret import SecretBox from nacl.secret import SecretBox
from nacl.exceptions import CryptoError from nacl.exceptions import CryptoError
@ -147,7 +147,7 @@ class Basic(unittest.TestCase):
"hostname": "host", "hostname": "host",
"port": 1234}], "port": 1234}],
}]) }])
self.assertRaises(UsageError, transit.Common, 123) self.assertRaises(InternalError, transit.Common, 123)
@inlineCallbacks @inlineCallbacks
def test_no_relay_hints(self): def test_no_relay_hints(self):

View File

@ -7,7 +7,7 @@ from twisted.internet import reactor
from twisted.internet.defer import Deferred, gatherResults, inlineCallbacks from twisted.internet.defer import Deferred, gatherResults, inlineCallbacks
from .common import ServerBase from .common import ServerBase
from .. import wormhole from .. import wormhole
from ..errors import (WrongPasswordError, WelcomeError, UsageError, from ..errors import (WrongPasswordError, WelcomeError, InternalError,
KeyFormatError) KeyFormatError)
from spake2 import SPAKE2_Symmetric from spake2 import SPAKE2_Symmetric
from ..timing import DebugTiming from ..timing import DebugTiming
@ -876,20 +876,20 @@ class Errors(ServerBase, unittest.TestCase):
def test_codes_1(self): def test_codes_1(self):
w = wormhole.wormhole(APPID, self.relayurl, reactor) w = wormhole.wormhole(APPID, self.relayurl, reactor)
# definitely too early # definitely too early
self.assertRaises(UsageError, w.derive_key, "purpose", 12) self.assertRaises(InternalError, w.derive_key, "purpose", 12)
w.set_code("123-purple-elephant") w.set_code("123-purple-elephant")
# code can only be set once # code can only be set once
self.assertRaises(UsageError, w.set_code, "123-nope") self.assertRaises(InternalError, w.set_code, "123-nope")
yield self.assertFailure(w.get_code(), UsageError) yield self.assertFailure(w.get_code(), InternalError)
yield self.assertFailure(w.input_code(), UsageError) yield self.assertFailure(w.input_code(), InternalError)
yield w.close() yield w.close()
@inlineCallbacks @inlineCallbacks
def test_codes_2(self): def test_codes_2(self):
w = wormhole.wormhole(APPID, self.relayurl, reactor) w = wormhole.wormhole(APPID, self.relayurl, reactor)
yield w.get_code() yield w.get_code()
self.assertRaises(UsageError, w.set_code, "123-nope") self.assertRaises(InternalError, w.set_code, "123-nope")
yield self.assertFailure(w.get_code(), UsageError) yield self.assertFailure(w.get_code(), InternalError)
yield self.assertFailure(w.input_code(), UsageError) yield self.assertFailure(w.input_code(), InternalError)
yield w.close() yield w.close()

View File

@ -13,7 +13,7 @@ from twisted.internet.defer import inlineCallbacks, returnValue
from twisted.protocols import policies from twisted.protocols import policies
from nacl.secret import SecretBox from nacl.secret import SecretBox
from hkdf import Hkdf from hkdf import Hkdf
from .errors import UsageError from .errors import InternalError
from .timing import DebugTiming from .timing import DebugTiming
from . import ipaddrs from . import ipaddrs
@ -276,7 +276,7 @@ class Connection(protocol.Protocol, policies.TimeoutMixin):
return self._description return self._description
def send_record(self, record): def send_record(self, record):
if not isinstance(record, type(b"")): raise UsageError if not isinstance(record, type(b"")): raise InternalError
assert SecretBox.NONCE_SIZE == 24 assert SecretBox.NONCE_SIZE == 24
assert self.send_nonce < 2**(8*24) assert self.send_nonce < 2**(8*24)
assert len(record) < 2**(8*4) assert len(record) < 2**(8*4)
@ -577,7 +577,7 @@ class Common:
reactor=reactor, timing=None): reactor=reactor, timing=None):
if transit_relay: if transit_relay:
if not isinstance(transit_relay, type(u"")): if not isinstance(transit_relay, type(u"")):
raise UsageError raise InternalError
relay = RelayV1Hint(hints=[parse_hint_argv(transit_relay)]) relay = RelayV1Hint(hints=[parse_hint_argv(transit_relay)])
self._transit_relays = [relay] self._transit_relays = [relay]
else: else:

View File

@ -14,7 +14,7 @@ from hashlib import sha256
from . import __version__ from . import __version__
from . import codes from . import codes
#from .errors import ServerError, Timeout #from .errors import ServerError, Timeout
from .errors import (WrongPasswordError, UsageError, WelcomeError, from .errors import (WrongPasswordError, InternalError, WelcomeError,
WormholeClosedError, KeyFormatError) WormholeClosedError, KeyFormatError)
from .timing import DebugTiming from .timing import DebugTiming
from .util import (to_bytes, bytes_to_hexstr, hexstr_to_bytes, from .util import (to_bytes, bytes_to_hexstr, hexstr_to_bytes,
@ -425,8 +425,8 @@ class _Wormhole:
# entry point 1: generate a new code # entry point 1: generate a new code
@inlineCallbacks @inlineCallbacks
def _API_get_code(self, code_length): def _API_get_code(self, code_length):
if self._code is not None: raise UsageError if self._code is not None: raise InternalError
if self._started_get_code: raise UsageError if self._started_get_code: raise InternalError
self._started_get_code = True self._started_get_code = True
with self._timing.add("API get_code"): with self._timing.add("API get_code"):
yield self._when_connected() yield self._when_connected()
@ -443,8 +443,8 @@ class _Wormhole:
# entry point 2: interactively type in a code, with completion # entry point 2: interactively type in a code, with completion
@inlineCallbacks @inlineCallbacks
def _API_input_code(self, prompt, code_length): def _API_input_code(self, prompt, code_length):
if self._code is not None: raise UsageError if self._code is not None: raise InternalError
if self._started_input_code: raise UsageError if self._started_input_code: raise InternalError
self._started_input_code = True self._started_input_code = True
with self._timing.add("API input_code"): with self._timing.add("API input_code"):
yield self._when_connected() yield self._when_connected()
@ -465,7 +465,7 @@ class _Wormhole:
if not isinstance(code, type(u"")): if not isinstance(code, type(u"")):
raise TypeError("Unexpected code type '{}'".format(type(code))) raise TypeError("Unexpected code type '{}'".format(type(code)))
if self._code is not None: if self._code is not None:
raise UsageError raise InternalError
self._event_learned_code(code) self._event_learned_code(code)
# TODO: entry point 4: restore pre-contact saved state (we haven't heard # TODO: entry point 4: restore pre-contact saved state (we haven't heard
@ -528,7 +528,7 @@ class _Wormhole:
self._event_learned_mailbox() self._event_learned_mailbox()
def _event_learned_mailbox(self): def _event_learned_mailbox(self):
if not self._mailbox_id: raise UsageError if not self._mailbox_id: raise InternalError
assert self._mailbox_state == CLOSED, self._mailbox_state assert self._mailbox_state == CLOSED, self._mailbox_state
if self._closing: if self._closing:
return return
@ -580,7 +580,7 @@ class _Wormhole:
def _API_verify(self): def _API_verify(self):
if self._error: return defer.fail(self._error) if self._error: return defer.fail(self._error)
if self._get_verifier_called: raise UsageError if self._get_verifier_called: raise InternalError
self._get_verifier_called = True self._get_verifier_called = True
if self._verify_result: if self._verify_result:
return defer.succeed(self._verify_result) # bytes or Failure return defer.succeed(self._verify_result) # bytes or Failure
@ -685,7 +685,7 @@ class _Wormhole:
return box.encrypt(data, nonce) return box.encrypt(data, nonce)
def _msg_send(self, phase, body): def _msg_send(self, phase, body):
if phase in self._sent_phases: raise UsageError if phase in self._sent_phases: raise InternalError
assert self._mailbox_state == OPEN, self._mailbox_state assert self._mailbox_state == OPEN, self._mailbox_state
self._sent_phases.add(phase) self._sent_phases.add(phase)
# TODO: retry on failure, with exponential backoff. We're guarding # TODO: retry on failure, with exponential backoff. We're guarding
@ -702,14 +702,14 @@ class _Wormhole:
def _API_derive_key(self, purpose, length): def _API_derive_key(self, purpose, length):
if self._error: raise self._error if self._error: raise self._error
if self._key is None: if self._key is None:
raise UsageError # call derive_key after get_verifier() or get() raise InternalError # call derive_key after get_verifier() or get()
if not isinstance(purpose, type("")): raise TypeError(type(purpose)) if not isinstance(purpose, type("")): raise TypeError(type(purpose))
return self._derive_key(to_bytes(purpose), length) return self._derive_key(to_bytes(purpose), length)
def _derive_key(self, purpose, length=SecretBox.KEY_SIZE): def _derive_key(self, purpose, length=SecretBox.KEY_SIZE):
if not isinstance(purpose, type(b"")): raise TypeError(type(purpose)) if not isinstance(purpose, type(b"")): raise TypeError(type(purpose))
if self._key is None: if self._key is None:
raise UsageError # call derive_key after get_verifier() or get() raise InternalError # call derive_key after get_verifier() or get()
return HKDF(self._key, length, CTXinfo=purpose) return HKDF(self._key, length, CTXinfo=purpose)
def _response_handle_message(self, msg): def _response_handle_message(self, msg):
@ -784,7 +784,7 @@ class _Wormhole:
@inlineCallbacks @inlineCallbacks
def _API_close(self, res, mood="happy"): def _API_close(self, res, mood="happy"):
if self.DEBUG: print("close") if self.DEBUG: print("close")
if self._close_called: raise UsageError if self._close_called: raise InternalError
self._close_called = True self._close_called = True
self._maybe_close(WormholeClosedError(), mood) self._maybe_close(WormholeClosedError(), mood)
if self.DEBUG: print("waiting for disconnect") if self.DEBUG: print("waiting for disconnect")