improve error signalling
This commit is contained in:
parent
c88d6937c2
commit
528092dd97
|
@ -20,6 +20,10 @@ def handle_server_error(func):
|
|||
class Timeout(Exception):
|
||||
pass
|
||||
|
||||
class WelcomeError(Exception):
|
||||
"""The server told us to signal an error, probably because our version is
|
||||
too old to possibly work."""
|
||||
|
||||
class WrongPasswordError(Exception):
|
||||
"""
|
||||
Key confirmation failed. Either you or your correspondent typed the code
|
||||
|
|
|
@ -5,7 +5,6 @@ from twisted.trial import unittest
|
|||
from twisted.internet import protocol, reactor, defer
|
||||
from twisted.internet.defer import inlineCallbacks, returnValue
|
||||
from twisted.internet.endpoints import clientFromString, connectProtocol
|
||||
from twisted.web.client import getPage, Agent, readBody
|
||||
from autobahn.twisted import websocket
|
||||
from .. import __version__
|
||||
from .common import ServerBase
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
from __future__ import print_function
|
||||
import os, json, re
|
||||
import os, json, re, gc
|
||||
from binascii import hexlify, unhexlify
|
||||
import mock
|
||||
from twisted.trial import unittest
|
||||
|
@ -84,7 +84,14 @@ class Welcome(unittest.TestCase):
|
|||
self.assertEqual(se.mock_calls, [])
|
||||
|
||||
w.handle_welcome({u"error": u"oops"})
|
||||
self.assertEqual(se.mock_calls, [mock.call(u"oops")])
|
||||
self.assertEqual(len(se.mock_calls), 1)
|
||||
self.assertEqual(len(se.mock_calls[0][1]), 1) # posargs
|
||||
we = se.mock_calls[0][1][0]
|
||||
self.assertIsInstance(we, wormhole.WelcomeError)
|
||||
self.assertEqual(we.args, (u"oops",))
|
||||
# alas WelcomeError instances don't compare against each other
|
||||
#self.assertEqual(se.mock_calls,
|
||||
# [mock.call(wormhole.WelcomeError(u"oops"))])
|
||||
|
||||
class InputCode(unittest.TestCase):
|
||||
def test_list(self):
|
||||
|
@ -116,10 +123,10 @@ class GetCode(unittest.TestCase):
|
|||
self.assertEqual(len(pieces), 3) # nameplate plus two words
|
||||
self.assert_(re.search(r'^\d+-\w+-\w+$', code), code)
|
||||
|
||||
|
||||
class Basic(unittest.TestCase):
|
||||
def test_create(self):
|
||||
wormhole._Wormhole(APPID, u"relay_url", reactor, None, None)
|
||||
def tearDown(self):
|
||||
# flush out any errorful Deferreds left dangling in cycles
|
||||
gc.collect()
|
||||
|
||||
def check_out(self, out, **kwargs):
|
||||
# Assert that each kwarg is present in the 'out' dict. Ignore other
|
||||
|
@ -135,6 +142,17 @@ class Basic(unittest.TestCase):
|
|||
self.assertEqual(out[i][u"type"], t, (i,t,out))
|
||||
return out
|
||||
|
||||
def make_pake(self, code, side, msg1):
|
||||
sp2 = SPAKE2_Symmetric(wormhole.to_bytes(code),
|
||||
idSymmetric=wormhole.to_bytes(APPID))
|
||||
msg2 = sp2.start()
|
||||
msg2_hex = hexlify(msg2).decode("ascii")
|
||||
key = sp2.finish(msg1)
|
||||
return key, msg2_hex
|
||||
|
||||
def test_create(self):
|
||||
wormhole._Wormhole(APPID, u"relay_url", reactor, None, None)
|
||||
|
||||
def test_basic(self):
|
||||
# We don't call w._start(), so this doesn't create a WebSocket
|
||||
# connection. We provide a mock connection instead. If we wanted to
|
||||
|
@ -201,11 +219,7 @@ class Basic(unittest.TestCase):
|
|||
# next we build the simulated peer's PAKE operation
|
||||
side2 = w._side + u"other"
|
||||
msg1 = unhexlify(out[1][u"body"].encode("ascii"))
|
||||
sp2 = SPAKE2_Symmetric(wormhole.to_bytes(CODE),
|
||||
idSymmetric=wormhole.to_bytes(APPID))
|
||||
msg2 = sp2.start()
|
||||
msg2_hex = hexlify(msg2).decode("ascii")
|
||||
key = sp2.finish(msg1)
|
||||
key, msg2_hex = self.make_pake(CODE, side2, msg1)
|
||||
response(w, type=u"message", phase=u"pake", body=msg2_hex, side=side2)
|
||||
|
||||
# hearing the peer's PAKE (msg2) makes us release the nameplate, send
|
||||
|
@ -272,6 +286,24 @@ class Basic(unittest.TestCase):
|
|||
self.check_out(out[0], type=u"close", mood=u"happy")
|
||||
self.assertEqual(w._drop_connection.mock_calls, [mock.call()])
|
||||
|
||||
def test_close_wait_0(self):
|
||||
# close before even claiming the nameplate
|
||||
timing = DebugTiming()
|
||||
w = wormhole._Wormhole(APPID, u"relay_url", reactor, None, timing)
|
||||
w._drop_connection = mock.Mock()
|
||||
ws = MockWebSocket()
|
||||
w._event_connected(ws)
|
||||
w._event_ws_opened(None)
|
||||
|
||||
d = w.close(wait=True)
|
||||
self.check_outbound(ws, [u"bind"])
|
||||
self.assertNoResult(d)
|
||||
self.assertEqual(w._drop_connection.mock_calls, [mock.call()])
|
||||
self.assertNoResult(d)
|
||||
|
||||
w._ws_closed(True, None, None)
|
||||
self.successResultOf(d)
|
||||
|
||||
def test_close_wait_1(self):
|
||||
# close after claiming the nameplate, but before opening the mailbox
|
||||
timing = DebugTiming()
|
||||
|
@ -406,6 +438,98 @@ class Basic(unittest.TestCase):
|
|||
self.assertEqual(len(pieces), 3) # nameplate plus two words
|
||||
self.assert_(re.search(r'^\d+-\w+-\w+$', code), code)
|
||||
|
||||
def test_api_errors(self):
|
||||
# doing things you're not supposed to do
|
||||
pass
|
||||
|
||||
def test_welcome_error(self):
|
||||
# A welcome message could arrive at any time, with an [error] key
|
||||
# that should make us halt. In practice, though, this gets sent as
|
||||
# soon as the connection is established, which limits the possible
|
||||
# states in which we might see it.
|
||||
|
||||
timing = DebugTiming()
|
||||
w = wormhole._Wormhole(APPID, u"relay_url", reactor, None, timing)
|
||||
w._drop_connection = mock.Mock()
|
||||
ws = MockWebSocket()
|
||||
w._event_connected(ws)
|
||||
w._event_ws_opened(None)
|
||||
self.check_outbound(ws, [u"bind"])
|
||||
|
||||
WE = wormhole.WelcomeError
|
||||
d1 = w.get()
|
||||
d2 = w.get_verifier()
|
||||
d3 = w.get_code()
|
||||
# TODO (tricky): test w.input_code
|
||||
|
||||
self.assertNoResult(d1)
|
||||
self.assertNoResult(d2)
|
||||
self.assertNoResult(d3)
|
||||
|
||||
w._signal_error(WE(u"you are not actually welcome"))
|
||||
self.failureResultOf(d1, WE)
|
||||
self.failureResultOf(d2, WE)
|
||||
self.failureResultOf(d3, WE)
|
||||
|
||||
# once the error is signalled, all API calls should fail
|
||||
self.assertRaises(WE, w.send, u"foo")
|
||||
self.assertRaises(WE, w.derive_key, u"foo")
|
||||
self.failureResultOf(w.get(), WE)
|
||||
self.failureResultOf(w.get_verifier(), WE)
|
||||
|
||||
def test_confirm_error(self):
|
||||
# we should only receive the "confirm" message after we receive the
|
||||
# PAKE message, by which point we should know the key. If the
|
||||
# confirmation message doesn't decrypt, we signal an error.
|
||||
timing = DebugTiming()
|
||||
w = wormhole._Wormhole(APPID, u"relay_url", reactor, None, timing)
|
||||
w._drop_connection = mock.Mock()
|
||||
ws = MockWebSocket()
|
||||
w._event_connected(ws)
|
||||
w._event_ws_opened(None)
|
||||
w.set_code(u"123-foo-bar")
|
||||
response(w, type=u"claimed", mailbox=u"mb456")
|
||||
|
||||
WP = wormhole.WrongPasswordError
|
||||
d1 = w.get()
|
||||
d2 = w.get_verifier()
|
||||
self.assertNoResult(d1)
|
||||
self.assertNoResult(d2)
|
||||
|
||||
out = ws.outbound()
|
||||
# [u"bind", u"claim", u"open", u"add"]
|
||||
self.assertEqual(len(out), 4)
|
||||
self.assertEqual(out[3][u"type"], u"add")
|
||||
|
||||
sp2 = SPAKE2_Symmetric(b"", idSymmetric=wormhole.to_bytes(APPID))
|
||||
msg2 = sp2.start()
|
||||
msg2_hex = hexlify(msg2).decode("ascii")
|
||||
response(w, type=u"message", phase=u"pake", body=msg2_hex, side=u"s2")
|
||||
self.assertNoResult(d1)
|
||||
self.successResultOf(d2) # early get_verifier is unaffected
|
||||
# TODO: get_verifier would be a lovely place to signal a confirmation
|
||||
# error, but that's at odds with delivering the verifier as early as
|
||||
# possible. The confirmation messages should be hot on the heels of
|
||||
# the PAKE message that produced the verifier. Maybe get_verifier()
|
||||
# should explicitly wait for confirm()?
|
||||
|
||||
# sending a random confirm message will cause a confirmation error
|
||||
confkey = w.derive_key(u"WRONG")
|
||||
nonce = os.urandom(wormhole.CONFMSG_NONCE_LENGTH)
|
||||
badconfirm = wormhole.make_confmsg(confkey, nonce)
|
||||
badconfirm_hex = hexlify(badconfirm).decode("ascii")
|
||||
response(w, type=u"message", phase=u"confirm", body=badconfirm_hex,
|
||||
side=u"s2")
|
||||
|
||||
self.failureResultOf(d1, WP)
|
||||
|
||||
# once the error is signalled, all API calls should fail
|
||||
self.assertRaises(WP, w.send, u"foo")
|
||||
self.assertRaises(WP, w.derive_key, u"foo")
|
||||
self.failureResultOf(w.get(), WP)
|
||||
self.failureResultOf(w.get_verifier(), WP)
|
||||
|
||||
|
||||
# event orderings to exercise:
|
||||
#
|
||||
# * normal sender: set_code, send_phase1, connected, claimed, learn_msg2,
|
||||
|
|
|
@ -14,7 +14,7 @@ from spake2 import SPAKE2_Symmetric
|
|||
from . import __version__
|
||||
from . import codes
|
||||
#from .errors import ServerError, Timeout
|
||||
from .errors import WrongPasswordError, UsageError
|
||||
from .errors import WrongPasswordError, UsageError, WelcomeError
|
||||
from .timing import DebugTiming
|
||||
from hkdf import Hkdf
|
||||
|
||||
|
@ -203,7 +203,7 @@ class _WelcomeHandler:
|
|||
self._version_warning_displayed = True
|
||||
|
||||
if "error" in welcome:
|
||||
return self._signal_error(welcome["error"])
|
||||
return self._signal_error(WelcomeError(welcome["error"]))
|
||||
|
||||
|
||||
class _Wormhole:
|
||||
|
@ -220,15 +220,16 @@ class _Wormhole:
|
|||
self._connected = None
|
||||
self._connection_waiters = []
|
||||
self._started_get_code = False
|
||||
self._get_code = None
|
||||
self._code = None
|
||||
self._nameplate_id = None
|
||||
self._nameplate_claimed = False
|
||||
self._nameplate_released = False
|
||||
self._release_waiter = defer.Deferred()
|
||||
self._release_waiter = None
|
||||
self._mailbox_id = None
|
||||
self._mailbox_opened = False
|
||||
self._mailbox_closed = False
|
||||
self._close_waiter = defer.Deferred()
|
||||
self._close_waiter = None
|
||||
self._flag_need_nameplate = True
|
||||
self._flag_need_to_see_mailbox_used = True
|
||||
self._flag_need_to_build_msg1 = True
|
||||
|
@ -237,6 +238,7 @@ class _Wormhole:
|
|||
self._closed = False
|
||||
self._disconnect_waiter = defer.Deferred()
|
||||
self._mood = u"happy"
|
||||
self._error = None
|
||||
|
||||
self._get_verifier_called = False
|
||||
self._verifier_waiter = defer.Deferred()
|
||||
|
@ -253,7 +255,24 @@ class _Wormhole:
|
|||
def _signal_error(self, error):
|
||||
# close the mailbox with an "errory" mood, errback all Deferreds,
|
||||
# record the error, fail all subsequent API calls
|
||||
pass # XXX
|
||||
if self.DEBUG: print("_signal_error", error)
|
||||
self._error = error # causes new API calls to fail
|
||||
for d in self._connection_waiters:
|
||||
d.errback(error)
|
||||
if self._get_code:
|
||||
self._get_code._allocated_d.errback(error)
|
||||
if not self._verifier_waiter.called:
|
||||
self._verifier_waiter.errback(error)
|
||||
for d in self._receive_waiters.values():
|
||||
d.errback(error)
|
||||
|
||||
self._maybe_close(mood=u"errory")
|
||||
if self._release_waiter and not self._release_waiter.called:
|
||||
self._release_waiter.errback(error)
|
||||
if self._close_waiter and not self._close_waiter.called:
|
||||
self._close_waiter.errback(error)
|
||||
# leave self._disconnect_waiter alone
|
||||
if self.DEBUG: print("_signal_error done")
|
||||
|
||||
def _start(self):
|
||||
d = self._connect() # causes stuff to happen
|
||||
|
@ -346,8 +365,11 @@ class _Wormhole:
|
|||
with self._timing.add("API get_code"):
|
||||
yield self._when_connected()
|
||||
gc = _GetCode(code_length, self._ws_send_command, self._timing)
|
||||
self._get_code = gc
|
||||
self._response_handle_allocated = gc._response_handle_allocated
|
||||
# TODO: signal_error
|
||||
code = yield gc.go()
|
||||
self._get_code = None
|
||||
self._nameplate_claimed = True # side-effect of allocation
|
||||
self._event_learned_code(code)
|
||||
returnValue(code)
|
||||
|
@ -362,6 +384,7 @@ class _Wormhole:
|
|||
yield self._when_connected()
|
||||
ic = _InputCode(prompt, code_length, self._ws_send_command)
|
||||
self._response_handle_nameplates = ic._response_handle_nameplates
|
||||
# TODO: signal_error
|
||||
code = yield ic.go()
|
||||
self._event_learned_code(code)
|
||||
returnValue(None)
|
||||
|
@ -465,9 +488,11 @@ class _Wormhole:
|
|||
self._maybe_send_phase_messages()
|
||||
|
||||
def get_verifier(self):
|
||||
if self._error: return defer.fail(self._error)
|
||||
if self._closed: raise UsageError
|
||||
if self._get_verifier_called: raise UsageError
|
||||
self._get_verifier_called = True
|
||||
# TODO: maybe have this wait on _event_received_confirm too
|
||||
return self._verifier_waiter
|
||||
|
||||
def _event_computed_verifier(self, verifier):
|
||||
|
@ -481,10 +506,12 @@ class _Wormhole:
|
|||
nonce = body[:CONFMSG_NONCE_LENGTH]
|
||||
if body != make_confmsg(confkey, nonce):
|
||||
# this makes all API calls fail
|
||||
if self.DEBUG: print("CONFIRM FAILED")
|
||||
return self._signal_error(WrongPasswordError())
|
||||
|
||||
|
||||
def send(self, outbound_data):
|
||||
if self._error: raise self._error
|
||||
if not isinstance(outbound_data, type(b"")):
|
||||
raise TypeError(type(outbound_data))
|
||||
if self._closed: raise UsageError
|
||||
|
@ -540,6 +567,7 @@ class _Wormhole:
|
|||
self._flag_need_to_see_mailbox_used = False
|
||||
|
||||
def derive_key(self, purpose, length=SecretBox.KEY_SIZE):
|
||||
if self._error: raise self._error
|
||||
if not isinstance(purpose, type(u"")): raise TypeError(type(purpose))
|
||||
if self._key is None:
|
||||
raise UsageError # call derive_key after get_verifier() or get()
|
||||
|
@ -589,6 +617,7 @@ class _Wormhole:
|
|||
return data
|
||||
|
||||
def get(self):
|
||||
if self._error: return defer.fail(self._error)
|
||||
if self._closed: raise UsageError
|
||||
phase = u"%d" % self._next_receive_phase
|
||||
self._next_receive_phase += 1
|
||||
|
@ -598,21 +627,34 @@ class _Wormhole:
|
|||
d = self._receive_waiters[phase] = defer.Deferred()
|
||||
return d
|
||||
|
||||
def _maybe_close(self, mood):
|
||||
if self._closed:
|
||||
return
|
||||
self.close(mood)
|
||||
|
||||
@inlineCallbacks
|
||||
def close(self, mood=None, wait=False):
|
||||
# TODO: auto-close on error, mostly for load-from-state
|
||||
if self.DEBUG: print("close", wait)
|
||||
if self._closed: raise UsageError
|
||||
self._closed = True
|
||||
if mood:
|
||||
self._mood = mood
|
||||
self._maybe_release_nameplate()
|
||||
self._maybe_close_mailbox()
|
||||
if wait:
|
||||
if self._nameplate_claimed:
|
||||
if self.DEBUG: print("waiting for released")
|
||||
self._release_waiter = defer.Deferred()
|
||||
yield self._release_waiter
|
||||
if self._mailbox_opened:
|
||||
if self.DEBUG: print("waiting for closed")
|
||||
self._close_waiter = defer.Deferred()
|
||||
yield self._close_waiter
|
||||
if self.DEBUG: print("dropping connection")
|
||||
self._drop_connection()
|
||||
if wait:
|
||||
if self.DEBUG: print("waiting for disconnect")
|
||||
yield self._disconnect_waiter
|
||||
|
||||
def _maybe_release_nameplate(self):
|
||||
|
@ -623,15 +665,19 @@ class _Wormhole:
|
|||
self._nameplate_released = True
|
||||
|
||||
def _response_handle_released(self, msg):
|
||||
self._release_waiter.callback(None)
|
||||
if self._release_waiter and not self._release_waiter.called:
|
||||
self._release_waiter.callback(None)
|
||||
|
||||
def _maybe_close_mailbox(self):
|
||||
if self.DEBUG: print("_maybe_close_mailbox", self._mailbox_opened, self._mailbox_closed)
|
||||
if self._mailbox_opened and not self._mailbox_closed:
|
||||
if self.DEBUG: print(" sending close")
|
||||
self._ws_send_command(u"close", mood=self._mood)
|
||||
self._mailbox_closed = True
|
||||
|
||||
def _response_handle_closed(self, msg):
|
||||
self._close_waiter.callback(None)
|
||||
if self._close_waiter and not self._close_waiter.called:
|
||||
self._close_waiter.callback(None)
|
||||
|
||||
def _drop_connection(self):
|
||||
self._ws.transport.loseConnection() # probably flushes
|
||||
|
|
Loading…
Reference in New Issue
Block a user