improve error signalling

This commit is contained in:
Brian Warner 2016-05-23 00:14:39 -07:00
parent c88d6937c2
commit 528092dd97
4 changed files with 191 additions and 18 deletions

View File

@ -20,6 +20,10 @@ def handle_server_error(func):
class Timeout(Exception):
pass
class WelcomeError(Exception):
"""The server told us to signal an error, probably because our version is
too old to possibly work."""
class WrongPasswordError(Exception):
"""
Key confirmation failed. Either you or your correspondent typed the code

View File

@ -5,7 +5,6 @@ from twisted.trial import unittest
from twisted.internet import protocol, reactor, defer
from twisted.internet.defer import inlineCallbacks, returnValue
from twisted.internet.endpoints import clientFromString, connectProtocol
from twisted.web.client import getPage, Agent, readBody
from autobahn.twisted import websocket
from .. import __version__
from .common import ServerBase

View File

@ -1,5 +1,5 @@
from __future__ import print_function
import os, json, re
import os, json, re, gc
from binascii import hexlify, unhexlify
import mock
from twisted.trial import unittest
@ -84,7 +84,14 @@ class Welcome(unittest.TestCase):
self.assertEqual(se.mock_calls, [])
w.handle_welcome({u"error": u"oops"})
self.assertEqual(se.mock_calls, [mock.call(u"oops")])
self.assertEqual(len(se.mock_calls), 1)
self.assertEqual(len(se.mock_calls[0][1]), 1) # posargs
we = se.mock_calls[0][1][0]
self.assertIsInstance(we, wormhole.WelcomeError)
self.assertEqual(we.args, (u"oops",))
# alas WelcomeError instances don't compare against each other
#self.assertEqual(se.mock_calls,
# [mock.call(wormhole.WelcomeError(u"oops"))])
class InputCode(unittest.TestCase):
def test_list(self):
@ -116,10 +123,10 @@ class GetCode(unittest.TestCase):
self.assertEqual(len(pieces), 3) # nameplate plus two words
self.assert_(re.search(r'^\d+-\w+-\w+$', code), code)
class Basic(unittest.TestCase):
def test_create(self):
wormhole._Wormhole(APPID, u"relay_url", reactor, None, None)
def tearDown(self):
# flush out any errorful Deferreds left dangling in cycles
gc.collect()
def check_out(self, out, **kwargs):
# Assert that each kwarg is present in the 'out' dict. Ignore other
@ -135,6 +142,17 @@ class Basic(unittest.TestCase):
self.assertEqual(out[i][u"type"], t, (i,t,out))
return out
def make_pake(self, code, side, msg1):
sp2 = SPAKE2_Symmetric(wormhole.to_bytes(code),
idSymmetric=wormhole.to_bytes(APPID))
msg2 = sp2.start()
msg2_hex = hexlify(msg2).decode("ascii")
key = sp2.finish(msg1)
return key, msg2_hex
def test_create(self):
wormhole._Wormhole(APPID, u"relay_url", reactor, None, None)
def test_basic(self):
# We don't call w._start(), so this doesn't create a WebSocket
# connection. We provide a mock connection instead. If we wanted to
@ -201,11 +219,7 @@ class Basic(unittest.TestCase):
# next we build the simulated peer's PAKE operation
side2 = w._side + u"other"
msg1 = unhexlify(out[1][u"body"].encode("ascii"))
sp2 = SPAKE2_Symmetric(wormhole.to_bytes(CODE),
idSymmetric=wormhole.to_bytes(APPID))
msg2 = sp2.start()
msg2_hex = hexlify(msg2).decode("ascii")
key = sp2.finish(msg1)
key, msg2_hex = self.make_pake(CODE, side2, msg1)
response(w, type=u"message", phase=u"pake", body=msg2_hex, side=side2)
# hearing the peer's PAKE (msg2) makes us release the nameplate, send
@ -272,6 +286,24 @@ class Basic(unittest.TestCase):
self.check_out(out[0], type=u"close", mood=u"happy")
self.assertEqual(w._drop_connection.mock_calls, [mock.call()])
def test_close_wait_0(self):
# close before even claiming the nameplate
timing = DebugTiming()
w = wormhole._Wormhole(APPID, u"relay_url", reactor, None, timing)
w._drop_connection = mock.Mock()
ws = MockWebSocket()
w._event_connected(ws)
w._event_ws_opened(None)
d = w.close(wait=True)
self.check_outbound(ws, [u"bind"])
self.assertNoResult(d)
self.assertEqual(w._drop_connection.mock_calls, [mock.call()])
self.assertNoResult(d)
w._ws_closed(True, None, None)
self.successResultOf(d)
def test_close_wait_1(self):
# close after claiming the nameplate, but before opening the mailbox
timing = DebugTiming()
@ -406,6 +438,98 @@ class Basic(unittest.TestCase):
self.assertEqual(len(pieces), 3) # nameplate plus two words
self.assert_(re.search(r'^\d+-\w+-\w+$', code), code)
def test_api_errors(self):
# doing things you're not supposed to do
pass
def test_welcome_error(self):
# A welcome message could arrive at any time, with an [error] key
# that should make us halt. In practice, though, this gets sent as
# soon as the connection is established, which limits the possible
# states in which we might see it.
timing = DebugTiming()
w = wormhole._Wormhole(APPID, u"relay_url", reactor, None, timing)
w._drop_connection = mock.Mock()
ws = MockWebSocket()
w._event_connected(ws)
w._event_ws_opened(None)
self.check_outbound(ws, [u"bind"])
WE = wormhole.WelcomeError
d1 = w.get()
d2 = w.get_verifier()
d3 = w.get_code()
# TODO (tricky): test w.input_code
self.assertNoResult(d1)
self.assertNoResult(d2)
self.assertNoResult(d3)
w._signal_error(WE(u"you are not actually welcome"))
self.failureResultOf(d1, WE)
self.failureResultOf(d2, WE)
self.failureResultOf(d3, WE)
# once the error is signalled, all API calls should fail
self.assertRaises(WE, w.send, u"foo")
self.assertRaises(WE, w.derive_key, u"foo")
self.failureResultOf(w.get(), WE)
self.failureResultOf(w.get_verifier(), WE)
def test_confirm_error(self):
# we should only receive the "confirm" message after we receive the
# PAKE message, by which point we should know the key. If the
# confirmation message doesn't decrypt, we signal an error.
timing = DebugTiming()
w = wormhole._Wormhole(APPID, u"relay_url", reactor, None, timing)
w._drop_connection = mock.Mock()
ws = MockWebSocket()
w._event_connected(ws)
w._event_ws_opened(None)
w.set_code(u"123-foo-bar")
response(w, type=u"claimed", mailbox=u"mb456")
WP = wormhole.WrongPasswordError
d1 = w.get()
d2 = w.get_verifier()
self.assertNoResult(d1)
self.assertNoResult(d2)
out = ws.outbound()
# [u"bind", u"claim", u"open", u"add"]
self.assertEqual(len(out), 4)
self.assertEqual(out[3][u"type"], u"add")
sp2 = SPAKE2_Symmetric(b"", idSymmetric=wormhole.to_bytes(APPID))
msg2 = sp2.start()
msg2_hex = hexlify(msg2).decode("ascii")
response(w, type=u"message", phase=u"pake", body=msg2_hex, side=u"s2")
self.assertNoResult(d1)
self.successResultOf(d2) # early get_verifier is unaffected
# TODO: get_verifier would be a lovely place to signal a confirmation
# error, but that's at odds with delivering the verifier as early as
# possible. The confirmation messages should be hot on the heels of
# the PAKE message that produced the verifier. Maybe get_verifier()
# should explicitly wait for confirm()?
# sending a random confirm message will cause a confirmation error
confkey = w.derive_key(u"WRONG")
nonce = os.urandom(wormhole.CONFMSG_NONCE_LENGTH)
badconfirm = wormhole.make_confmsg(confkey, nonce)
badconfirm_hex = hexlify(badconfirm).decode("ascii")
response(w, type=u"message", phase=u"confirm", body=badconfirm_hex,
side=u"s2")
self.failureResultOf(d1, WP)
# once the error is signalled, all API calls should fail
self.assertRaises(WP, w.send, u"foo")
self.assertRaises(WP, w.derive_key, u"foo")
self.failureResultOf(w.get(), WP)
self.failureResultOf(w.get_verifier(), WP)
# event orderings to exercise:
#
# * normal sender: set_code, send_phase1, connected, claimed, learn_msg2,

View File

@ -14,7 +14,7 @@ from spake2 import SPAKE2_Symmetric
from . import __version__
from . import codes
#from .errors import ServerError, Timeout
from .errors import WrongPasswordError, UsageError
from .errors import WrongPasswordError, UsageError, WelcomeError
from .timing import DebugTiming
from hkdf import Hkdf
@ -203,7 +203,7 @@ class _WelcomeHandler:
self._version_warning_displayed = True
if "error" in welcome:
return self._signal_error(welcome["error"])
return self._signal_error(WelcomeError(welcome["error"]))
class _Wormhole:
@ -220,15 +220,16 @@ class _Wormhole:
self._connected = None
self._connection_waiters = []
self._started_get_code = False
self._get_code = None
self._code = None
self._nameplate_id = None
self._nameplate_claimed = False
self._nameplate_released = False
self._release_waiter = defer.Deferred()
self._release_waiter = None
self._mailbox_id = None
self._mailbox_opened = False
self._mailbox_closed = False
self._close_waiter = defer.Deferred()
self._close_waiter = None
self._flag_need_nameplate = True
self._flag_need_to_see_mailbox_used = True
self._flag_need_to_build_msg1 = True
@ -237,6 +238,7 @@ class _Wormhole:
self._closed = False
self._disconnect_waiter = defer.Deferred()
self._mood = u"happy"
self._error = None
self._get_verifier_called = False
self._verifier_waiter = defer.Deferred()
@ -253,7 +255,24 @@ class _Wormhole:
def _signal_error(self, error):
# close the mailbox with an "errory" mood, errback all Deferreds,
# record the error, fail all subsequent API calls
pass # XXX
if self.DEBUG: print("_signal_error", error)
self._error = error # causes new API calls to fail
for d in self._connection_waiters:
d.errback(error)
if self._get_code:
self._get_code._allocated_d.errback(error)
if not self._verifier_waiter.called:
self._verifier_waiter.errback(error)
for d in self._receive_waiters.values():
d.errback(error)
self._maybe_close(mood=u"errory")
if self._release_waiter and not self._release_waiter.called:
self._release_waiter.errback(error)
if self._close_waiter and not self._close_waiter.called:
self._close_waiter.errback(error)
# leave self._disconnect_waiter alone
if self.DEBUG: print("_signal_error done")
def _start(self):
d = self._connect() # causes stuff to happen
@ -346,8 +365,11 @@ class _Wormhole:
with self._timing.add("API get_code"):
yield self._when_connected()
gc = _GetCode(code_length, self._ws_send_command, self._timing)
self._get_code = gc
self._response_handle_allocated = gc._response_handle_allocated
# TODO: signal_error
code = yield gc.go()
self._get_code = None
self._nameplate_claimed = True # side-effect of allocation
self._event_learned_code(code)
returnValue(code)
@ -362,6 +384,7 @@ class _Wormhole:
yield self._when_connected()
ic = _InputCode(prompt, code_length, self._ws_send_command)
self._response_handle_nameplates = ic._response_handle_nameplates
# TODO: signal_error
code = yield ic.go()
self._event_learned_code(code)
returnValue(None)
@ -465,9 +488,11 @@ class _Wormhole:
self._maybe_send_phase_messages()
def get_verifier(self):
if self._error: return defer.fail(self._error)
if self._closed: raise UsageError
if self._get_verifier_called: raise UsageError
self._get_verifier_called = True
# TODO: maybe have this wait on _event_received_confirm too
return self._verifier_waiter
def _event_computed_verifier(self, verifier):
@ -481,10 +506,12 @@ class _Wormhole:
nonce = body[:CONFMSG_NONCE_LENGTH]
if body != make_confmsg(confkey, nonce):
# this makes all API calls fail
if self.DEBUG: print("CONFIRM FAILED")
return self._signal_error(WrongPasswordError())
def send(self, outbound_data):
if self._error: raise self._error
if not isinstance(outbound_data, type(b"")):
raise TypeError(type(outbound_data))
if self._closed: raise UsageError
@ -540,6 +567,7 @@ class _Wormhole:
self._flag_need_to_see_mailbox_used = False
def derive_key(self, purpose, length=SecretBox.KEY_SIZE):
if self._error: raise self._error
if not isinstance(purpose, type(u"")): raise TypeError(type(purpose))
if self._key is None:
raise UsageError # call derive_key after get_verifier() or get()
@ -589,6 +617,7 @@ class _Wormhole:
return data
def get(self):
if self._error: return defer.fail(self._error)
if self._closed: raise UsageError
phase = u"%d" % self._next_receive_phase
self._next_receive_phase += 1
@ -598,21 +627,34 @@ class _Wormhole:
d = self._receive_waiters[phase] = defer.Deferred()
return d
def _maybe_close(self, mood):
if self._closed:
return
self.close(mood)
@inlineCallbacks
def close(self, mood=None, wait=False):
# TODO: auto-close on error, mostly for load-from-state
if self.DEBUG: print("close", wait)
if self._closed: raise UsageError
self._closed = True
if mood:
self._mood = mood
self._maybe_release_nameplate()
self._maybe_close_mailbox()
if wait:
if self._nameplate_claimed:
if self.DEBUG: print("waiting for released")
self._release_waiter = defer.Deferred()
yield self._release_waiter
if self._mailbox_opened:
if self.DEBUG: print("waiting for closed")
self._close_waiter = defer.Deferred()
yield self._close_waiter
if self.DEBUG: print("dropping connection")
self._drop_connection()
if wait:
if self.DEBUG: print("waiting for disconnect")
yield self._disconnect_waiter
def _maybe_release_nameplate(self):
@ -623,15 +665,19 @@ class _Wormhole:
self._nameplate_released = True
def _response_handle_released(self, msg):
self._release_waiter.callback(None)
if self._release_waiter and not self._release_waiter.called:
self._release_waiter.callback(None)
def _maybe_close_mailbox(self):
if self.DEBUG: print("_maybe_close_mailbox", self._mailbox_opened, self._mailbox_closed)
if self._mailbox_opened and not self._mailbox_closed:
if self.DEBUG: print(" sending close")
self._ws_send_command(u"close", mood=self._mood)
self._mailbox_closed = True
def _response_handle_closed(self, msg):
self._close_waiter.callback(None)
if self._close_waiter and not self._close_waiter.called:
self._close_waiter.callback(None)
def _drop_connection(self):
self._ws.transport.loseConnection() # probably flushes