add close-with-mood-on-error to twisted style too
This commit is contained in:
parent
d1cf1c6da0
commit
22a1ce2eda
|
@ -317,18 +317,20 @@ class Basic(ServerBase, unittest.TestCase):
|
|||
|
||||
def test_errors(self):
|
||||
w1 = Wormhole(APPID, self.relayurl)
|
||||
self.assertRaises(UsageError, w1.get_verifier)
|
||||
self.assertRaises(UsageError, w1.send_data, b"data")
|
||||
self.assertRaises(UsageError, w1.get_data)
|
||||
w1.set_code(u"123-purple-elephant")
|
||||
self.assertRaises(UsageError, w1.set_code, u"123-nope")
|
||||
self.assertRaises(UsageError, w1.get_code)
|
||||
w2 = Wormhole(APPID, self.relayurl)
|
||||
d = w2.get_code()
|
||||
self.assertRaises(UsageError, w2.get_code)
|
||||
def _got_code(code):
|
||||
return self.doBoth(w1.close(), w2.close())
|
||||
d.addCallback(_got_code)
|
||||
d = self.assertFailure(w1.get_verifier(), UsageError)
|
||||
d.addCallback(lambda _: self.assertFailure(w1.send_data(b"data"), UsageError))
|
||||
d.addCallback(lambda _: self.assertFailure(w1.get_data(), UsageError))
|
||||
d.addCallback(lambda _: w1.set_code(u"123-purple-elephant"))
|
||||
# these two UsageErrors are synchronous, although most of the rest are async
|
||||
d.addCallback(lambda _: self.assertRaises(UsageError, w1.set_code, u"123-nope"))
|
||||
d.addCallback(lambda _: self.assertRaises(UsageError, w1.get_code))
|
||||
def _then(_):
|
||||
w2 = Wormhole(APPID, self.relayurl)
|
||||
d2 = w2.get_code()
|
||||
d2.addCallback(lambda _: self.assertRaises(UsageError, w2.get_code))
|
||||
d2.addCallback(lambda _: self.doBoth(w1.close(), w2.close()))
|
||||
return d2
|
||||
d.addCallback(_then)
|
||||
return d
|
||||
|
||||
def test_repeat_phases(self):
|
||||
|
@ -339,27 +341,23 @@ class Basic(ServerBase, unittest.TestCase):
|
|||
# we must let them establish a key before we can send data
|
||||
d = self.doBoth(w1.get_verifier(), w2.get_verifier())
|
||||
d.addCallback(lambda _: w1.send_data(b"data1", phase=u"1"))
|
||||
def _sent(res):
|
||||
# underscore-prefixed phases are reserved
|
||||
self.assertRaises(UsageError, w1.send_data, b"data1", phase=u"_1")
|
||||
self.assertRaises(UsageError, w1.get_data, phase=u"_1")
|
||||
# you can't send twice to the same phase
|
||||
self.assertRaises(UsageError, w1.send_data, b"data1", phase=u"1")
|
||||
# but you can send to a different one
|
||||
return w1.send_data(b"data2", phase=u"2")
|
||||
d.addCallback(_sent)
|
||||
# underscore-prefixed phases are reserved
|
||||
d.addCallback(lambda _: self.assertFailure(w1.send_data(b"data1", phase=u"_1"),
|
||||
UsageError))
|
||||
d.addCallback(lambda _: self.assertFailure(w1.get_data(phase=u"_1"), UsageError))
|
||||
# you can't send twice to the same phase
|
||||
d.addCallback(lambda _: self.assertFailure(w1.send_data(b"data1", phase=u"1"),
|
||||
UsageError))
|
||||
# but you can send to a different one
|
||||
d.addCallback(lambda _: w1.send_data(b"data2", phase=u"2"))
|
||||
d.addCallback(lambda _: w2.get_data(phase=u"1"))
|
||||
def _got1(res):
|
||||
self.failUnlessEqual(res, b"data1")
|
||||
# and you can't read twice from the same phase
|
||||
self.assertRaises(UsageError, w2.get_data, phase=u"1")
|
||||
# but you can read from a different one
|
||||
return w2.get_data(phase=u"2")
|
||||
d.addCallback(_got1)
|
||||
def _got2(res):
|
||||
self.failUnlessEqual(res, b"data2")
|
||||
return self.doBoth(w1.close(), w2.close())
|
||||
d.addCallback(_got2)
|
||||
d.addCallback(lambda res: self.failUnlessEqual(res, b"data1"))
|
||||
# and you can't read twice from the same phase
|
||||
d.addCallback(lambda _: self.assertFailure(w2.get_data(phase=u"1"), UsageError))
|
||||
# but you can read from a different one
|
||||
d.addCallback(lambda _: w2.get_data(phase=u"2"))
|
||||
d.addCallback(lambda res: self.failUnlessEqual(res, b"data2"))
|
||||
d.addCallback(lambda _: self.doBoth(w1.close(), w2.close()))
|
||||
return d
|
||||
|
||||
def test_serialize(self):
|
||||
|
|
|
@ -14,7 +14,7 @@ from spake2 import SPAKE2_Symmetric
|
|||
from .eventsource_twisted import ReconnectingEventSource
|
||||
from .. import __version__
|
||||
from .. import codes
|
||||
from ..errors import ServerError, WrongPasswordError, UsageError
|
||||
from ..errors import ServerError, Timeout, WrongPasswordError, UsageError
|
||||
from ..util.hkdf import HKDF
|
||||
from ..channel_monitor import monitor
|
||||
|
||||
|
@ -153,7 +153,7 @@ class Channel:
|
|||
d.addCallback(_got)
|
||||
return d
|
||||
|
||||
def deallocate(self, mood=u"unknown"):
|
||||
def deallocate(self, mood=None):
|
||||
# only try once, no retries
|
||||
d = post_json(self._agent, self._relay_url+"deallocate",
|
||||
{"appid": self._appid,
|
||||
|
@ -194,6 +194,31 @@ class ChannelManager:
|
|||
return Channel(self._relay, self._appid, channelid, self._side,
|
||||
self._handle_welcome, self._agent)
|
||||
|
||||
|
||||
def close_on_error(meth): # method decorator
|
||||
# Clients report certain errors as "moods", so the server can make a
|
||||
# rough count failed connections (due to mismatched passwords, attacks,
|
||||
# or timeouts). We don't report precondition failures, as those are the
|
||||
# responsibility/fault of the local application code. We count
|
||||
# non-precondition errors in case they represent server-side problems.
|
||||
def _wrapper(self, *args, **kwargs):
|
||||
d = defer.maybeDeferred(meth, self, *args, **kwargs)
|
||||
def _onerror(f):
|
||||
if f.check(Timeout):
|
||||
d2 = self.close(u"lonely")
|
||||
elif f.check(WrongPasswordError):
|
||||
d2 = self.close(u"scary")
|
||||
elif f.check(TypeError, UsageError):
|
||||
# preconditions don't warrant _close_with_error()
|
||||
d2 = defer.succeed(None)
|
||||
else:
|
||||
d2 = self.close(u"errory")
|
||||
d2.addBoth(lambda _: f)
|
||||
return d2
|
||||
d.addErrback(_onerror)
|
||||
return d
|
||||
return _wrapper
|
||||
|
||||
class Wormhole:
|
||||
motd_displayed = False
|
||||
version_warning_displayed = False
|
||||
|
@ -213,6 +238,7 @@ class Wormhole:
|
|||
self._sent_data = set() # phases
|
||||
self._got_data = set()
|
||||
self._got_confirmation = False
|
||||
self._closed = False
|
||||
|
||||
def _set_side(self, side):
|
||||
self._side = side
|
||||
|
@ -305,6 +331,7 @@ class Wormhole:
|
|||
self.msg1 = d["msg1"].decode("hex")
|
||||
return self
|
||||
|
||||
#@close_on_error # XXX
|
||||
def derive_key(self, purpose, length=SecretBox.KEY_SIZE):
|
||||
if not isinstance(purpose, type(u"")): raise TypeError(type(purpose))
|
||||
if self.key is None:
|
||||
|
@ -350,16 +377,20 @@ class Wormhole:
|
|||
d.addCallback(_got_pake)
|
||||
return d
|
||||
|
||||
@close_on_error
|
||||
def get_verifier(self):
|
||||
if self._closed: raise UsageError
|
||||
if self.code is None: raise UsageError
|
||||
d = self._get_key()
|
||||
d.addCallback(lambda _: self.verifier)
|
||||
return d
|
||||
|
||||
@close_on_error
|
||||
def send_data(self, outbound_data, phase=u"data"):
|
||||
if not isinstance(outbound_data, type(b"")):
|
||||
raise TypeError(type(outbound_data))
|
||||
if not isinstance(phase, type(u"")): raise TypeError(type(phase))
|
||||
if self._closed: raise UsageError
|
||||
if phase in self._sent_data: raise UsageError # only call this once
|
||||
if phase.startswith(u"_"): raise UsageError # reserved for internals
|
||||
if self.code is None: raise UsageError
|
||||
|
@ -377,10 +408,12 @@ class Wormhole:
|
|||
d.addCallback(_send)
|
||||
return d
|
||||
|
||||
@close_on_error
|
||||
def get_data(self, phase=u"data"):
|
||||
if not isinstance(phase, type(u"")): raise TypeError(type(phase))
|
||||
if phase in self._got_data: raise UsageError # only call this once
|
||||
if phase.startswith(u"_"): raise UsageError # reserved for internals
|
||||
if self._closed: raise UsageError
|
||||
if self.code is None: raise UsageError
|
||||
if self._channel is None: raise UsageError
|
||||
self._got_data.add(phase)
|
||||
|
@ -416,7 +449,13 @@ class Wormhole:
|
|||
d.addCallback(_get)
|
||||
return d
|
||||
|
||||
def close(self, res=None):
|
||||
monitor.close(self._channel)
|
||||
d = self._channel.deallocate()
|
||||
def close(self, res=None, mood=u"happy"):
|
||||
if not isinstance(mood, (type(None), type(u""))):
|
||||
raise TypeError(type(mood))
|
||||
self._closed = True
|
||||
d = defer.succeed(None)
|
||||
if self._channel:
|
||||
c, self._channel = self._channel, None
|
||||
monitor.close(c)
|
||||
d.addCallback(lambda _: c.deallocate(mood))
|
||||
return d
|
||||
|
|
Loading…
Reference in New Issue
Block a user