add close-with-mood-on-error to twisted style too

This commit is contained in:
Brian Warner 2015-11-19 17:08:21 -08:00
parent d1cf1c6da0
commit 22a1ce2eda
2 changed files with 74 additions and 37 deletions

View File

@ -317,18 +317,20 @@ class Basic(ServerBase, unittest.TestCase):
def test_errors(self): def test_errors(self):
w1 = Wormhole(APPID, self.relayurl) w1 = Wormhole(APPID, self.relayurl)
self.assertRaises(UsageError, w1.get_verifier) d = self.assertFailure(w1.get_verifier(), UsageError)
self.assertRaises(UsageError, w1.send_data, b"data") d.addCallback(lambda _: self.assertFailure(w1.send_data(b"data"), UsageError))
self.assertRaises(UsageError, w1.get_data) d.addCallback(lambda _: self.assertFailure(w1.get_data(), UsageError))
w1.set_code(u"123-purple-elephant") d.addCallback(lambda _: w1.set_code(u"123-purple-elephant"))
self.assertRaises(UsageError, w1.set_code, u"123-nope") # these two UsageErrors are synchronous, although most of the rest are async
self.assertRaises(UsageError, w1.get_code) d.addCallback(lambda _: self.assertRaises(UsageError, w1.set_code, u"123-nope"))
d.addCallback(lambda _: self.assertRaises(UsageError, w1.get_code))
def _then(_):
w2 = Wormhole(APPID, self.relayurl) w2 = Wormhole(APPID, self.relayurl)
d = w2.get_code() d2 = w2.get_code()
self.assertRaises(UsageError, w2.get_code) d2.addCallback(lambda _: self.assertRaises(UsageError, w2.get_code))
def _got_code(code): d2.addCallback(lambda _: self.doBoth(w1.close(), w2.close()))
return self.doBoth(w1.close(), w2.close()) return d2
d.addCallback(_got_code) d.addCallback(_then)
return d return d
def test_repeat_phases(self): def test_repeat_phases(self):
@ -339,27 +341,23 @@ class Basic(ServerBase, unittest.TestCase):
# we must let them establish a key before we can send data # we must let them establish a key before we can send data
d = self.doBoth(w1.get_verifier(), w2.get_verifier()) d = self.doBoth(w1.get_verifier(), w2.get_verifier())
d.addCallback(lambda _: w1.send_data(b"data1", phase=u"1")) d.addCallback(lambda _: w1.send_data(b"data1", phase=u"1"))
def _sent(res):
# underscore-prefixed phases are reserved # underscore-prefixed phases are reserved
self.assertRaises(UsageError, w1.send_data, b"data1", phase=u"_1") d.addCallback(lambda _: self.assertFailure(w1.send_data(b"data1", phase=u"_1"),
self.assertRaises(UsageError, w1.get_data, phase=u"_1") UsageError))
d.addCallback(lambda _: self.assertFailure(w1.get_data(phase=u"_1"), UsageError))
# you can't send twice to the same phase # you can't send twice to the same phase
self.assertRaises(UsageError, w1.send_data, b"data1", phase=u"1") d.addCallback(lambda _: self.assertFailure(w1.send_data(b"data1", phase=u"1"),
UsageError))
# but you can send to a different one # but you can send to a different one
return w1.send_data(b"data2", phase=u"2") d.addCallback(lambda _: w1.send_data(b"data2", phase=u"2"))
d.addCallback(_sent)
d.addCallback(lambda _: w2.get_data(phase=u"1")) d.addCallback(lambda _: w2.get_data(phase=u"1"))
def _got1(res): d.addCallback(lambda res: self.failUnlessEqual(res, b"data1"))
self.failUnlessEqual(res, b"data1")
# and you can't read twice from the same phase # and you can't read twice from the same phase
self.assertRaises(UsageError, w2.get_data, phase=u"1") d.addCallback(lambda _: self.assertFailure(w2.get_data(phase=u"1"), UsageError))
# but you can read from a different one # but you can read from a different one
return w2.get_data(phase=u"2") d.addCallback(lambda _: w2.get_data(phase=u"2"))
d.addCallback(_got1) d.addCallback(lambda res: self.failUnlessEqual(res, b"data2"))
def _got2(res): d.addCallback(lambda _: self.doBoth(w1.close(), w2.close()))
self.failUnlessEqual(res, b"data2")
return self.doBoth(w1.close(), w2.close())
d.addCallback(_got2)
return d return d
def test_serialize(self): def test_serialize(self):

View File

@ -14,7 +14,7 @@ from spake2 import SPAKE2_Symmetric
from .eventsource_twisted import ReconnectingEventSource from .eventsource_twisted import ReconnectingEventSource
from .. import __version__ from .. import __version__
from .. import codes from .. import codes
from ..errors import ServerError, WrongPasswordError, UsageError from ..errors import ServerError, Timeout, WrongPasswordError, UsageError
from ..util.hkdf import HKDF from ..util.hkdf import HKDF
from ..channel_monitor import monitor from ..channel_monitor import monitor
@ -153,7 +153,7 @@ class Channel:
d.addCallback(_got) d.addCallback(_got)
return d return d
def deallocate(self, mood=u"unknown"): def deallocate(self, mood=None):
# only try once, no retries # only try once, no retries
d = post_json(self._agent, self._relay_url+"deallocate", d = post_json(self._agent, self._relay_url+"deallocate",
{"appid": self._appid, {"appid": self._appid,
@ -194,6 +194,31 @@ class ChannelManager:
return Channel(self._relay, self._appid, channelid, self._side, return Channel(self._relay, self._appid, channelid, self._side,
self._handle_welcome, self._agent) self._handle_welcome, self._agent)
def close_on_error(meth): # method decorator
# Clients report certain errors as "moods", so the server can make a
# rough count failed connections (due to mismatched passwords, attacks,
# or timeouts). We don't report precondition failures, as those are the
# responsibility/fault of the local application code. We count
# non-precondition errors in case they represent server-side problems.
def _wrapper(self, *args, **kwargs):
d = defer.maybeDeferred(meth, self, *args, **kwargs)
def _onerror(f):
if f.check(Timeout):
d2 = self.close(u"lonely")
elif f.check(WrongPasswordError):
d2 = self.close(u"scary")
elif f.check(TypeError, UsageError):
# preconditions don't warrant _close_with_error()
d2 = defer.succeed(None)
else:
d2 = self.close(u"errory")
d2.addBoth(lambda _: f)
return d2
d.addErrback(_onerror)
return d
return _wrapper
class Wormhole: class Wormhole:
motd_displayed = False motd_displayed = False
version_warning_displayed = False version_warning_displayed = False
@ -213,6 +238,7 @@ class Wormhole:
self._sent_data = set() # phases self._sent_data = set() # phases
self._got_data = set() self._got_data = set()
self._got_confirmation = False self._got_confirmation = False
self._closed = False
def _set_side(self, side): def _set_side(self, side):
self._side = side self._side = side
@ -305,6 +331,7 @@ class Wormhole:
self.msg1 = d["msg1"].decode("hex") self.msg1 = d["msg1"].decode("hex")
return self return self
#@close_on_error # XXX
def derive_key(self, purpose, length=SecretBox.KEY_SIZE): def derive_key(self, purpose, length=SecretBox.KEY_SIZE):
if not isinstance(purpose, type(u"")): raise TypeError(type(purpose)) if not isinstance(purpose, type(u"")): raise TypeError(type(purpose))
if self.key is None: if self.key is None:
@ -350,16 +377,20 @@ class Wormhole:
d.addCallback(_got_pake) d.addCallback(_got_pake)
return d return d
@close_on_error
def get_verifier(self): def get_verifier(self):
if self._closed: raise UsageError
if self.code is None: raise UsageError if self.code is None: raise UsageError
d = self._get_key() d = self._get_key()
d.addCallback(lambda _: self.verifier) d.addCallback(lambda _: self.verifier)
return d return d
@close_on_error
def send_data(self, outbound_data, phase=u"data"): def send_data(self, outbound_data, phase=u"data"):
if not isinstance(outbound_data, type(b"")): if not isinstance(outbound_data, type(b"")):
raise TypeError(type(outbound_data)) raise TypeError(type(outbound_data))
if not isinstance(phase, type(u"")): raise TypeError(type(phase)) if not isinstance(phase, type(u"")): raise TypeError(type(phase))
if self._closed: raise UsageError
if phase in self._sent_data: raise UsageError # only call this once if phase in self._sent_data: raise UsageError # only call this once
if phase.startswith(u"_"): raise UsageError # reserved for internals if phase.startswith(u"_"): raise UsageError # reserved for internals
if self.code is None: raise UsageError if self.code is None: raise UsageError
@ -377,10 +408,12 @@ class Wormhole:
d.addCallback(_send) d.addCallback(_send)
return d return d
@close_on_error
def get_data(self, phase=u"data"): def get_data(self, phase=u"data"):
if not isinstance(phase, type(u"")): raise TypeError(type(phase)) if not isinstance(phase, type(u"")): raise TypeError(type(phase))
if phase in self._got_data: raise UsageError # only call this once if phase in self._got_data: raise UsageError # only call this once
if phase.startswith(u"_"): raise UsageError # reserved for internals if phase.startswith(u"_"): raise UsageError # reserved for internals
if self._closed: raise UsageError
if self.code is None: raise UsageError if self.code is None: raise UsageError
if self._channel is None: raise UsageError if self._channel is None: raise UsageError
self._got_data.add(phase) self._got_data.add(phase)
@ -416,7 +449,13 @@ class Wormhole:
d.addCallback(_get) d.addCallback(_get)
return d return d
def close(self, res=None): def close(self, res=None, mood=u"happy"):
monitor.close(self._channel) if not isinstance(mood, (type(None), type(u""))):
d = self._channel.deallocate() raise TypeError(type(mood))
self._closed = True
d = defer.succeed(None)
if self._channel:
c, self._channel = self._channel, None
monitor.close(c)
d.addCallback(lambda _: c.deallocate(mood))
return d return d