improve error signalling

This commit is contained in:
Brian Warner 2016-05-23 00:14:39 -07:00
parent c88d6937c2
commit 528092dd97
4 changed files with 191 additions and 18 deletions

View File

@ -20,6 +20,10 @@ def handle_server_error(func):
class Timeout(Exception): class Timeout(Exception):
pass pass
class WelcomeError(Exception):
"""The server told us to signal an error, probably because our version is
too old to possibly work."""
class WrongPasswordError(Exception): class WrongPasswordError(Exception):
""" """
Key confirmation failed. Either you or your correspondent typed the code Key confirmation failed. Either you or your correspondent typed the code

View File

@ -5,7 +5,6 @@ from twisted.trial import unittest
from twisted.internet import protocol, reactor, defer from twisted.internet import protocol, reactor, defer
from twisted.internet.defer import inlineCallbacks, returnValue from twisted.internet.defer import inlineCallbacks, returnValue
from twisted.internet.endpoints import clientFromString, connectProtocol from twisted.internet.endpoints import clientFromString, connectProtocol
from twisted.web.client import getPage, Agent, readBody
from autobahn.twisted import websocket from autobahn.twisted import websocket
from .. import __version__ from .. import __version__
from .common import ServerBase from .common import ServerBase

View File

@ -1,5 +1,5 @@
from __future__ import print_function from __future__ import print_function
import os, json, re import os, json, re, gc
from binascii import hexlify, unhexlify from binascii import hexlify, unhexlify
import mock import mock
from twisted.trial import unittest from twisted.trial import unittest
@ -84,7 +84,14 @@ class Welcome(unittest.TestCase):
self.assertEqual(se.mock_calls, []) self.assertEqual(se.mock_calls, [])
w.handle_welcome({u"error": u"oops"}) w.handle_welcome({u"error": u"oops"})
self.assertEqual(se.mock_calls, [mock.call(u"oops")]) self.assertEqual(len(se.mock_calls), 1)
self.assertEqual(len(se.mock_calls[0][1]), 1) # posargs
we = se.mock_calls[0][1][0]
self.assertIsInstance(we, wormhole.WelcomeError)
self.assertEqual(we.args, (u"oops",))
# alas WelcomeError instances don't compare against each other
#self.assertEqual(se.mock_calls,
# [mock.call(wormhole.WelcomeError(u"oops"))])
class InputCode(unittest.TestCase): class InputCode(unittest.TestCase):
def test_list(self): def test_list(self):
@ -116,10 +123,10 @@ class GetCode(unittest.TestCase):
self.assertEqual(len(pieces), 3) # nameplate plus two words self.assertEqual(len(pieces), 3) # nameplate plus two words
self.assert_(re.search(r'^\d+-\w+-\w+$', code), code) self.assert_(re.search(r'^\d+-\w+-\w+$', code), code)
class Basic(unittest.TestCase): class Basic(unittest.TestCase):
def test_create(self): def tearDown(self):
wormhole._Wormhole(APPID, u"relay_url", reactor, None, None) # flush out any errorful Deferreds left dangling in cycles
gc.collect()
def check_out(self, out, **kwargs): def check_out(self, out, **kwargs):
# Assert that each kwarg is present in the 'out' dict. Ignore other # Assert that each kwarg is present in the 'out' dict. Ignore other
@ -135,6 +142,17 @@ class Basic(unittest.TestCase):
self.assertEqual(out[i][u"type"], t, (i,t,out)) self.assertEqual(out[i][u"type"], t, (i,t,out))
return out return out
def make_pake(self, code, side, msg1):
sp2 = SPAKE2_Symmetric(wormhole.to_bytes(code),
idSymmetric=wormhole.to_bytes(APPID))
msg2 = sp2.start()
msg2_hex = hexlify(msg2).decode("ascii")
key = sp2.finish(msg1)
return key, msg2_hex
def test_create(self):
wormhole._Wormhole(APPID, u"relay_url", reactor, None, None)
def test_basic(self): def test_basic(self):
# We don't call w._start(), so this doesn't create a WebSocket # We don't call w._start(), so this doesn't create a WebSocket
# connection. We provide a mock connection instead. If we wanted to # connection. We provide a mock connection instead. If we wanted to
@ -201,11 +219,7 @@ class Basic(unittest.TestCase):
# next we build the simulated peer's PAKE operation # next we build the simulated peer's PAKE operation
side2 = w._side + u"other" side2 = w._side + u"other"
msg1 = unhexlify(out[1][u"body"].encode("ascii")) msg1 = unhexlify(out[1][u"body"].encode("ascii"))
sp2 = SPAKE2_Symmetric(wormhole.to_bytes(CODE), key, msg2_hex = self.make_pake(CODE, side2, msg1)
idSymmetric=wormhole.to_bytes(APPID))
msg2 = sp2.start()
msg2_hex = hexlify(msg2).decode("ascii")
key = sp2.finish(msg1)
response(w, type=u"message", phase=u"pake", body=msg2_hex, side=side2) response(w, type=u"message", phase=u"pake", body=msg2_hex, side=side2)
# hearing the peer's PAKE (msg2) makes us release the nameplate, send # hearing the peer's PAKE (msg2) makes us release the nameplate, send
@ -272,6 +286,24 @@ class Basic(unittest.TestCase):
self.check_out(out[0], type=u"close", mood=u"happy") self.check_out(out[0], type=u"close", mood=u"happy")
self.assertEqual(w._drop_connection.mock_calls, [mock.call()]) self.assertEqual(w._drop_connection.mock_calls, [mock.call()])
def test_close_wait_0(self):
# close before even claiming the nameplate
timing = DebugTiming()
w = wormhole._Wormhole(APPID, u"relay_url", reactor, None, timing)
w._drop_connection = mock.Mock()
ws = MockWebSocket()
w._event_connected(ws)
w._event_ws_opened(None)
d = w.close(wait=True)
self.check_outbound(ws, [u"bind"])
self.assertNoResult(d)
self.assertEqual(w._drop_connection.mock_calls, [mock.call()])
self.assertNoResult(d)
w._ws_closed(True, None, None)
self.successResultOf(d)
def test_close_wait_1(self): def test_close_wait_1(self):
# close after claiming the nameplate, but before opening the mailbox # close after claiming the nameplate, but before opening the mailbox
timing = DebugTiming() timing = DebugTiming()
@ -406,6 +438,98 @@ class Basic(unittest.TestCase):
self.assertEqual(len(pieces), 3) # nameplate plus two words self.assertEqual(len(pieces), 3) # nameplate plus two words
self.assert_(re.search(r'^\d+-\w+-\w+$', code), code) self.assert_(re.search(r'^\d+-\w+-\w+$', code), code)
def test_api_errors(self):
# doing things you're not supposed to do
pass
def test_welcome_error(self):
# A welcome message could arrive at any time, with an [error] key
# that should make us halt. In practice, though, this gets sent as
# soon as the connection is established, which limits the possible
# states in which we might see it.
timing = DebugTiming()
w = wormhole._Wormhole(APPID, u"relay_url", reactor, None, timing)
w._drop_connection = mock.Mock()
ws = MockWebSocket()
w._event_connected(ws)
w._event_ws_opened(None)
self.check_outbound(ws, [u"bind"])
WE = wormhole.WelcomeError
d1 = w.get()
d2 = w.get_verifier()
d3 = w.get_code()
# TODO (tricky): test w.input_code
self.assertNoResult(d1)
self.assertNoResult(d2)
self.assertNoResult(d3)
w._signal_error(WE(u"you are not actually welcome"))
self.failureResultOf(d1, WE)
self.failureResultOf(d2, WE)
self.failureResultOf(d3, WE)
# once the error is signalled, all API calls should fail
self.assertRaises(WE, w.send, u"foo")
self.assertRaises(WE, w.derive_key, u"foo")
self.failureResultOf(w.get(), WE)
self.failureResultOf(w.get_verifier(), WE)
def test_confirm_error(self):
# we should only receive the "confirm" message after we receive the
# PAKE message, by which point we should know the key. If the
# confirmation message doesn't decrypt, we signal an error.
timing = DebugTiming()
w = wormhole._Wormhole(APPID, u"relay_url", reactor, None, timing)
w._drop_connection = mock.Mock()
ws = MockWebSocket()
w._event_connected(ws)
w._event_ws_opened(None)
w.set_code(u"123-foo-bar")
response(w, type=u"claimed", mailbox=u"mb456")
WP = wormhole.WrongPasswordError
d1 = w.get()
d2 = w.get_verifier()
self.assertNoResult(d1)
self.assertNoResult(d2)
out = ws.outbound()
# [u"bind", u"claim", u"open", u"add"]
self.assertEqual(len(out), 4)
self.assertEqual(out[3][u"type"], u"add")
sp2 = SPAKE2_Symmetric(b"", idSymmetric=wormhole.to_bytes(APPID))
msg2 = sp2.start()
msg2_hex = hexlify(msg2).decode("ascii")
response(w, type=u"message", phase=u"pake", body=msg2_hex, side=u"s2")
self.assertNoResult(d1)
self.successResultOf(d2) # early get_verifier is unaffected
# TODO: get_verifier would be a lovely place to signal a confirmation
# error, but that's at odds with delivering the verifier as early as
# possible. The confirmation messages should be hot on the heels of
# the PAKE message that produced the verifier. Maybe get_verifier()
# should explicitly wait for confirm()?
# sending a random confirm message will cause a confirmation error
confkey = w.derive_key(u"WRONG")
nonce = os.urandom(wormhole.CONFMSG_NONCE_LENGTH)
badconfirm = wormhole.make_confmsg(confkey, nonce)
badconfirm_hex = hexlify(badconfirm).decode("ascii")
response(w, type=u"message", phase=u"confirm", body=badconfirm_hex,
side=u"s2")
self.failureResultOf(d1, WP)
# once the error is signalled, all API calls should fail
self.assertRaises(WP, w.send, u"foo")
self.assertRaises(WP, w.derive_key, u"foo")
self.failureResultOf(w.get(), WP)
self.failureResultOf(w.get_verifier(), WP)
# event orderings to exercise: # event orderings to exercise:
# #
# * normal sender: set_code, send_phase1, connected, claimed, learn_msg2, # * normal sender: set_code, send_phase1, connected, claimed, learn_msg2,

View File

@ -14,7 +14,7 @@ from spake2 import SPAKE2_Symmetric
from . import __version__ from . import __version__
from . import codes from . import codes
#from .errors import ServerError, Timeout #from .errors import ServerError, Timeout
from .errors import WrongPasswordError, UsageError from .errors import WrongPasswordError, UsageError, WelcomeError
from .timing import DebugTiming from .timing import DebugTiming
from hkdf import Hkdf from hkdf import Hkdf
@ -203,7 +203,7 @@ class _WelcomeHandler:
self._version_warning_displayed = True self._version_warning_displayed = True
if "error" in welcome: if "error" in welcome:
return self._signal_error(welcome["error"]) return self._signal_error(WelcomeError(welcome["error"]))
class _Wormhole: class _Wormhole:
@ -220,15 +220,16 @@ class _Wormhole:
self._connected = None self._connected = None
self._connection_waiters = [] self._connection_waiters = []
self._started_get_code = False self._started_get_code = False
self._get_code = None
self._code = None self._code = None
self._nameplate_id = None self._nameplate_id = None
self._nameplate_claimed = False self._nameplate_claimed = False
self._nameplate_released = False self._nameplate_released = False
self._release_waiter = defer.Deferred() self._release_waiter = None
self._mailbox_id = None self._mailbox_id = None
self._mailbox_opened = False self._mailbox_opened = False
self._mailbox_closed = False self._mailbox_closed = False
self._close_waiter = defer.Deferred() self._close_waiter = None
self._flag_need_nameplate = True self._flag_need_nameplate = True
self._flag_need_to_see_mailbox_used = True self._flag_need_to_see_mailbox_used = True
self._flag_need_to_build_msg1 = True self._flag_need_to_build_msg1 = True
@ -237,6 +238,7 @@ class _Wormhole:
self._closed = False self._closed = False
self._disconnect_waiter = defer.Deferred() self._disconnect_waiter = defer.Deferred()
self._mood = u"happy" self._mood = u"happy"
self._error = None
self._get_verifier_called = False self._get_verifier_called = False
self._verifier_waiter = defer.Deferred() self._verifier_waiter = defer.Deferred()
@ -253,7 +255,24 @@ class _Wormhole:
def _signal_error(self, error): def _signal_error(self, error):
# close the mailbox with an "errory" mood, errback all Deferreds, # close the mailbox with an "errory" mood, errback all Deferreds,
# record the error, fail all subsequent API calls # record the error, fail all subsequent API calls
pass # XXX if self.DEBUG: print("_signal_error", error)
self._error = error # causes new API calls to fail
for d in self._connection_waiters:
d.errback(error)
if self._get_code:
self._get_code._allocated_d.errback(error)
if not self._verifier_waiter.called:
self._verifier_waiter.errback(error)
for d in self._receive_waiters.values():
d.errback(error)
self._maybe_close(mood=u"errory")
if self._release_waiter and not self._release_waiter.called:
self._release_waiter.errback(error)
if self._close_waiter and not self._close_waiter.called:
self._close_waiter.errback(error)
# leave self._disconnect_waiter alone
if self.DEBUG: print("_signal_error done")
def _start(self): def _start(self):
d = self._connect() # causes stuff to happen d = self._connect() # causes stuff to happen
@ -346,8 +365,11 @@ class _Wormhole:
with self._timing.add("API get_code"): with self._timing.add("API get_code"):
yield self._when_connected() yield self._when_connected()
gc = _GetCode(code_length, self._ws_send_command, self._timing) gc = _GetCode(code_length, self._ws_send_command, self._timing)
self._get_code = gc
self._response_handle_allocated = gc._response_handle_allocated self._response_handle_allocated = gc._response_handle_allocated
# TODO: signal_error
code = yield gc.go() code = yield gc.go()
self._get_code = None
self._nameplate_claimed = True # side-effect of allocation self._nameplate_claimed = True # side-effect of allocation
self._event_learned_code(code) self._event_learned_code(code)
returnValue(code) returnValue(code)
@ -362,6 +384,7 @@ class _Wormhole:
yield self._when_connected() yield self._when_connected()
ic = _InputCode(prompt, code_length, self._ws_send_command) ic = _InputCode(prompt, code_length, self._ws_send_command)
self._response_handle_nameplates = ic._response_handle_nameplates self._response_handle_nameplates = ic._response_handle_nameplates
# TODO: signal_error
code = yield ic.go() code = yield ic.go()
self._event_learned_code(code) self._event_learned_code(code)
returnValue(None) returnValue(None)
@ -465,9 +488,11 @@ class _Wormhole:
self._maybe_send_phase_messages() self._maybe_send_phase_messages()
def get_verifier(self): def get_verifier(self):
if self._error: return defer.fail(self._error)
if self._closed: raise UsageError if self._closed: raise UsageError
if self._get_verifier_called: raise UsageError if self._get_verifier_called: raise UsageError
self._get_verifier_called = True self._get_verifier_called = True
# TODO: maybe have this wait on _event_received_confirm too
return self._verifier_waiter return self._verifier_waiter
def _event_computed_verifier(self, verifier): def _event_computed_verifier(self, verifier):
@ -481,10 +506,12 @@ class _Wormhole:
nonce = body[:CONFMSG_NONCE_LENGTH] nonce = body[:CONFMSG_NONCE_LENGTH]
if body != make_confmsg(confkey, nonce): if body != make_confmsg(confkey, nonce):
# this makes all API calls fail # this makes all API calls fail
if self.DEBUG: print("CONFIRM FAILED")
return self._signal_error(WrongPasswordError()) return self._signal_error(WrongPasswordError())
def send(self, outbound_data): def send(self, outbound_data):
if self._error: raise self._error
if not isinstance(outbound_data, type(b"")): if not isinstance(outbound_data, type(b"")):
raise TypeError(type(outbound_data)) raise TypeError(type(outbound_data))
if self._closed: raise UsageError if self._closed: raise UsageError
@ -540,6 +567,7 @@ class _Wormhole:
self._flag_need_to_see_mailbox_used = False self._flag_need_to_see_mailbox_used = False
def derive_key(self, purpose, length=SecretBox.KEY_SIZE): def derive_key(self, purpose, length=SecretBox.KEY_SIZE):
if self._error: raise self._error
if not isinstance(purpose, type(u"")): raise TypeError(type(purpose)) if not isinstance(purpose, type(u"")): raise TypeError(type(purpose))
if self._key is None: if self._key is None:
raise UsageError # call derive_key after get_verifier() or get() raise UsageError # call derive_key after get_verifier() or get()
@ -589,6 +617,7 @@ class _Wormhole:
return data return data
def get(self): def get(self):
if self._error: return defer.fail(self._error)
if self._closed: raise UsageError if self._closed: raise UsageError
phase = u"%d" % self._next_receive_phase phase = u"%d" % self._next_receive_phase
self._next_receive_phase += 1 self._next_receive_phase += 1
@ -598,21 +627,34 @@ class _Wormhole:
d = self._receive_waiters[phase] = defer.Deferred() d = self._receive_waiters[phase] = defer.Deferred()
return d return d
def _maybe_close(self, mood):
if self._closed:
return
self.close(mood)
@inlineCallbacks @inlineCallbacks
def close(self, mood=None, wait=False): def close(self, mood=None, wait=False):
# TODO: auto-close on error, mostly for load-from-state # TODO: auto-close on error, mostly for load-from-state
if self.DEBUG: print("close", wait)
if self._closed: raise UsageError if self._closed: raise UsageError
self._closed = True
if mood: if mood:
self._mood = mood self._mood = mood
self._maybe_release_nameplate() self._maybe_release_nameplate()
self._maybe_close_mailbox() self._maybe_close_mailbox()
if wait: if wait:
if self._nameplate_claimed: if self._nameplate_claimed:
if self.DEBUG: print("waiting for released")
self._release_waiter = defer.Deferred()
yield self._release_waiter yield self._release_waiter
if self._mailbox_opened: if self._mailbox_opened:
if self.DEBUG: print("waiting for closed")
self._close_waiter = defer.Deferred()
yield self._close_waiter yield self._close_waiter
if self.DEBUG: print("dropping connection")
self._drop_connection() self._drop_connection()
if wait: if wait:
if self.DEBUG: print("waiting for disconnect")
yield self._disconnect_waiter yield self._disconnect_waiter
def _maybe_release_nameplate(self): def _maybe_release_nameplate(self):
@ -623,15 +665,19 @@ class _Wormhole:
self._nameplate_released = True self._nameplate_released = True
def _response_handle_released(self, msg): def _response_handle_released(self, msg):
self._release_waiter.callback(None) if self._release_waiter and not self._release_waiter.called:
self._release_waiter.callback(None)
def _maybe_close_mailbox(self): def _maybe_close_mailbox(self):
if self.DEBUG: print("_maybe_close_mailbox", self._mailbox_opened, self._mailbox_closed)
if self._mailbox_opened and not self._mailbox_closed: if self._mailbox_opened and not self._mailbox_closed:
if self.DEBUG: print(" sending close")
self._ws_send_command(u"close", mood=self._mood) self._ws_send_command(u"close", mood=self._mood)
self._mailbox_closed = True self._mailbox_closed = True
def _response_handle_closed(self, msg): def _response_handle_closed(self, msg):
self._close_waiter.callback(None) if self._close_waiter and not self._close_waiter.called:
self._close_waiter.callback(None)
def _drop_connection(self): def _drop_connection(self):
self._ws.transport.loseConnection() # probably flushes self._ws.transport.loseConnection() # probably flushes