add close-with-mood-on-error to twisted style too

This commit is contained in:
Brian Warner 2015-11-19 17:08:21 -08:00
parent d1cf1c6da0
commit 22a1ce2eda
2 changed files with 74 additions and 37 deletions

View File

@ -317,18 +317,20 @@ class Basic(ServerBase, unittest.TestCase):
def test_errors(self):
w1 = Wormhole(APPID, self.relayurl)
self.assertRaises(UsageError, w1.get_verifier)
self.assertRaises(UsageError, w1.send_data, b"data")
self.assertRaises(UsageError, w1.get_data)
w1.set_code(u"123-purple-elephant")
self.assertRaises(UsageError, w1.set_code, u"123-nope")
self.assertRaises(UsageError, w1.get_code)
w2 = Wormhole(APPID, self.relayurl)
d = w2.get_code()
self.assertRaises(UsageError, w2.get_code)
def _got_code(code):
return self.doBoth(w1.close(), w2.close())
d.addCallback(_got_code)
d = self.assertFailure(w1.get_verifier(), UsageError)
d.addCallback(lambda _: self.assertFailure(w1.send_data(b"data"), UsageError))
d.addCallback(lambda _: self.assertFailure(w1.get_data(), UsageError))
d.addCallback(lambda _: w1.set_code(u"123-purple-elephant"))
# these two UsageErrors are synchronous, although most of the rest are async
d.addCallback(lambda _: self.assertRaises(UsageError, w1.set_code, u"123-nope"))
d.addCallback(lambda _: self.assertRaises(UsageError, w1.get_code))
def _then(_):
w2 = Wormhole(APPID, self.relayurl)
d2 = w2.get_code()
d2.addCallback(lambda _: self.assertRaises(UsageError, w2.get_code))
d2.addCallback(lambda _: self.doBoth(w1.close(), w2.close()))
return d2
d.addCallback(_then)
return d
def test_repeat_phases(self):
@ -339,27 +341,23 @@ class Basic(ServerBase, unittest.TestCase):
# we must let them establish a key before we can send data
d = self.doBoth(w1.get_verifier(), w2.get_verifier())
d.addCallback(lambda _: w1.send_data(b"data1", phase=u"1"))
def _sent(res):
# underscore-prefixed phases are reserved
self.assertRaises(UsageError, w1.send_data, b"data1", phase=u"_1")
self.assertRaises(UsageError, w1.get_data, phase=u"_1")
# you can't send twice to the same phase
self.assertRaises(UsageError, w1.send_data, b"data1", phase=u"1")
# but you can send to a different one
return w1.send_data(b"data2", phase=u"2")
d.addCallback(_sent)
# underscore-prefixed phases are reserved
d.addCallback(lambda _: self.assertFailure(w1.send_data(b"data1", phase=u"_1"),
UsageError))
d.addCallback(lambda _: self.assertFailure(w1.get_data(phase=u"_1"), UsageError))
# you can't send twice to the same phase
d.addCallback(lambda _: self.assertFailure(w1.send_data(b"data1", phase=u"1"),
UsageError))
# but you can send to a different one
d.addCallback(lambda _: w1.send_data(b"data2", phase=u"2"))
d.addCallback(lambda _: w2.get_data(phase=u"1"))
def _got1(res):
self.failUnlessEqual(res, b"data1")
# and you can't read twice from the same phase
self.assertRaises(UsageError, w2.get_data, phase=u"1")
# but you can read from a different one
return w2.get_data(phase=u"2")
d.addCallback(_got1)
def _got2(res):
self.failUnlessEqual(res, b"data2")
return self.doBoth(w1.close(), w2.close())
d.addCallback(_got2)
d.addCallback(lambda res: self.failUnlessEqual(res, b"data1"))
# and you can't read twice from the same phase
d.addCallback(lambda _: self.assertFailure(w2.get_data(phase=u"1"), UsageError))
# but you can read from a different one
d.addCallback(lambda _: w2.get_data(phase=u"2"))
d.addCallback(lambda res: self.failUnlessEqual(res, b"data2"))
d.addCallback(lambda _: self.doBoth(w1.close(), w2.close()))
return d
def test_serialize(self):

View File

@ -14,7 +14,7 @@ from spake2 import SPAKE2_Symmetric
from .eventsource_twisted import ReconnectingEventSource
from .. import __version__
from .. import codes
from ..errors import ServerError, WrongPasswordError, UsageError
from ..errors import ServerError, Timeout, WrongPasswordError, UsageError
from ..util.hkdf import HKDF
from ..channel_monitor import monitor
@ -153,7 +153,7 @@ class Channel:
d.addCallback(_got)
return d
def deallocate(self, mood=u"unknown"):
def deallocate(self, mood=None):
# only try once, no retries
d = post_json(self._agent, self._relay_url+"deallocate",
{"appid": self._appid,
@ -194,6 +194,31 @@ class ChannelManager:
return Channel(self._relay, self._appid, channelid, self._side,
self._handle_welcome, self._agent)
def close_on_error(meth): # method decorator
# Clients report certain errors as "moods", so the server can make a
# rough count failed connections (due to mismatched passwords, attacks,
# or timeouts). We don't report precondition failures, as those are the
# responsibility/fault of the local application code. We count
# non-precondition errors in case they represent server-side problems.
def _wrapper(self, *args, **kwargs):
d = defer.maybeDeferred(meth, self, *args, **kwargs)
def _onerror(f):
if f.check(Timeout):
d2 = self.close(u"lonely")
elif f.check(WrongPasswordError):
d2 = self.close(u"scary")
elif f.check(TypeError, UsageError):
# preconditions don't warrant _close_with_error()
d2 = defer.succeed(None)
else:
d2 = self.close(u"errory")
d2.addBoth(lambda _: f)
return d2
d.addErrback(_onerror)
return d
return _wrapper
class Wormhole:
motd_displayed = False
version_warning_displayed = False
@ -213,6 +238,7 @@ class Wormhole:
self._sent_data = set() # phases
self._got_data = set()
self._got_confirmation = False
self._closed = False
def _set_side(self, side):
self._side = side
@ -305,6 +331,7 @@ class Wormhole:
self.msg1 = d["msg1"].decode("hex")
return self
#@close_on_error # XXX
def derive_key(self, purpose, length=SecretBox.KEY_SIZE):
if not isinstance(purpose, type(u"")): raise TypeError(type(purpose))
if self.key is None:
@ -350,16 +377,20 @@ class Wormhole:
d.addCallback(_got_pake)
return d
@close_on_error
def get_verifier(self):
if self._closed: raise UsageError
if self.code is None: raise UsageError
d = self._get_key()
d.addCallback(lambda _: self.verifier)
return d
@close_on_error
def send_data(self, outbound_data, phase=u"data"):
if not isinstance(outbound_data, type(b"")):
raise TypeError(type(outbound_data))
if not isinstance(phase, type(u"")): raise TypeError(type(phase))
if self._closed: raise UsageError
if phase in self._sent_data: raise UsageError # only call this once
if phase.startswith(u"_"): raise UsageError # reserved for internals
if self.code is None: raise UsageError
@ -377,10 +408,12 @@ class Wormhole:
d.addCallback(_send)
return d
@close_on_error
def get_data(self, phase=u"data"):
if not isinstance(phase, type(u"")): raise TypeError(type(phase))
if phase in self._got_data: raise UsageError # only call this once
if phase.startswith(u"_"): raise UsageError # reserved for internals
if self._closed: raise UsageError
if self.code is None: raise UsageError
if self._channel is None: raise UsageError
self._got_data.add(phase)
@ -416,7 +449,13 @@ class Wormhole:
d.addCallback(_get)
return d
def close(self, res=None):
monitor.close(self._channel)
d = self._channel.deallocate()
def close(self, res=None, mood=u"happy"):
if not isinstance(mood, (type(None), type(u""))):
raise TypeError(type(mood))
self._closed = True
d = defer.succeed(None)
if self._channel:
c, self._channel = self._channel, None
monitor.close(c)
d.addCallback(lambda _: c.deallocate(mood))
return d