add close-with-mood-on-error to twisted style too
This commit is contained in:
parent
d1cf1c6da0
commit
22a1ce2eda
|
@ -317,18 +317,20 @@ class Basic(ServerBase, unittest.TestCase):
|
||||||
|
|
||||||
def test_errors(self):
|
def test_errors(self):
|
||||||
w1 = Wormhole(APPID, self.relayurl)
|
w1 = Wormhole(APPID, self.relayurl)
|
||||||
self.assertRaises(UsageError, w1.get_verifier)
|
d = self.assertFailure(w1.get_verifier(), UsageError)
|
||||||
self.assertRaises(UsageError, w1.send_data, b"data")
|
d.addCallback(lambda _: self.assertFailure(w1.send_data(b"data"), UsageError))
|
||||||
self.assertRaises(UsageError, w1.get_data)
|
d.addCallback(lambda _: self.assertFailure(w1.get_data(), UsageError))
|
||||||
w1.set_code(u"123-purple-elephant")
|
d.addCallback(lambda _: w1.set_code(u"123-purple-elephant"))
|
||||||
self.assertRaises(UsageError, w1.set_code, u"123-nope")
|
# these two UsageErrors are synchronous, although most of the rest are async
|
||||||
self.assertRaises(UsageError, w1.get_code)
|
d.addCallback(lambda _: self.assertRaises(UsageError, w1.set_code, u"123-nope"))
|
||||||
|
d.addCallback(lambda _: self.assertRaises(UsageError, w1.get_code))
|
||||||
|
def _then(_):
|
||||||
w2 = Wormhole(APPID, self.relayurl)
|
w2 = Wormhole(APPID, self.relayurl)
|
||||||
d = w2.get_code()
|
d2 = w2.get_code()
|
||||||
self.assertRaises(UsageError, w2.get_code)
|
d2.addCallback(lambda _: self.assertRaises(UsageError, w2.get_code))
|
||||||
def _got_code(code):
|
d2.addCallback(lambda _: self.doBoth(w1.close(), w2.close()))
|
||||||
return self.doBoth(w1.close(), w2.close())
|
return d2
|
||||||
d.addCallback(_got_code)
|
d.addCallback(_then)
|
||||||
return d
|
return d
|
||||||
|
|
||||||
def test_repeat_phases(self):
|
def test_repeat_phases(self):
|
||||||
|
@ -339,27 +341,23 @@ class Basic(ServerBase, unittest.TestCase):
|
||||||
# we must let them establish a key before we can send data
|
# we must let them establish a key before we can send data
|
||||||
d = self.doBoth(w1.get_verifier(), w2.get_verifier())
|
d = self.doBoth(w1.get_verifier(), w2.get_verifier())
|
||||||
d.addCallback(lambda _: w1.send_data(b"data1", phase=u"1"))
|
d.addCallback(lambda _: w1.send_data(b"data1", phase=u"1"))
|
||||||
def _sent(res):
|
|
||||||
# underscore-prefixed phases are reserved
|
# underscore-prefixed phases are reserved
|
||||||
self.assertRaises(UsageError, w1.send_data, b"data1", phase=u"_1")
|
d.addCallback(lambda _: self.assertFailure(w1.send_data(b"data1", phase=u"_1"),
|
||||||
self.assertRaises(UsageError, w1.get_data, phase=u"_1")
|
UsageError))
|
||||||
|
d.addCallback(lambda _: self.assertFailure(w1.get_data(phase=u"_1"), UsageError))
|
||||||
# you can't send twice to the same phase
|
# you can't send twice to the same phase
|
||||||
self.assertRaises(UsageError, w1.send_data, b"data1", phase=u"1")
|
d.addCallback(lambda _: self.assertFailure(w1.send_data(b"data1", phase=u"1"),
|
||||||
|
UsageError))
|
||||||
# but you can send to a different one
|
# but you can send to a different one
|
||||||
return w1.send_data(b"data2", phase=u"2")
|
d.addCallback(lambda _: w1.send_data(b"data2", phase=u"2"))
|
||||||
d.addCallback(_sent)
|
|
||||||
d.addCallback(lambda _: w2.get_data(phase=u"1"))
|
d.addCallback(lambda _: w2.get_data(phase=u"1"))
|
||||||
def _got1(res):
|
d.addCallback(lambda res: self.failUnlessEqual(res, b"data1"))
|
||||||
self.failUnlessEqual(res, b"data1")
|
|
||||||
# and you can't read twice from the same phase
|
# and you can't read twice from the same phase
|
||||||
self.assertRaises(UsageError, w2.get_data, phase=u"1")
|
d.addCallback(lambda _: self.assertFailure(w2.get_data(phase=u"1"), UsageError))
|
||||||
# but you can read from a different one
|
# but you can read from a different one
|
||||||
return w2.get_data(phase=u"2")
|
d.addCallback(lambda _: w2.get_data(phase=u"2"))
|
||||||
d.addCallback(_got1)
|
d.addCallback(lambda res: self.failUnlessEqual(res, b"data2"))
|
||||||
def _got2(res):
|
d.addCallback(lambda _: self.doBoth(w1.close(), w2.close()))
|
||||||
self.failUnlessEqual(res, b"data2")
|
|
||||||
return self.doBoth(w1.close(), w2.close())
|
|
||||||
d.addCallback(_got2)
|
|
||||||
return d
|
return d
|
||||||
|
|
||||||
def test_serialize(self):
|
def test_serialize(self):
|
||||||
|
|
|
@ -14,7 +14,7 @@ from spake2 import SPAKE2_Symmetric
|
||||||
from .eventsource_twisted import ReconnectingEventSource
|
from .eventsource_twisted import ReconnectingEventSource
|
||||||
from .. import __version__
|
from .. import __version__
|
||||||
from .. import codes
|
from .. import codes
|
||||||
from ..errors import ServerError, WrongPasswordError, UsageError
|
from ..errors import ServerError, Timeout, WrongPasswordError, UsageError
|
||||||
from ..util.hkdf import HKDF
|
from ..util.hkdf import HKDF
|
||||||
from ..channel_monitor import monitor
|
from ..channel_monitor import monitor
|
||||||
|
|
||||||
|
@ -153,7 +153,7 @@ class Channel:
|
||||||
d.addCallback(_got)
|
d.addCallback(_got)
|
||||||
return d
|
return d
|
||||||
|
|
||||||
def deallocate(self, mood=u"unknown"):
|
def deallocate(self, mood=None):
|
||||||
# only try once, no retries
|
# only try once, no retries
|
||||||
d = post_json(self._agent, self._relay_url+"deallocate",
|
d = post_json(self._agent, self._relay_url+"deallocate",
|
||||||
{"appid": self._appid,
|
{"appid": self._appid,
|
||||||
|
@ -194,6 +194,31 @@ class ChannelManager:
|
||||||
return Channel(self._relay, self._appid, channelid, self._side,
|
return Channel(self._relay, self._appid, channelid, self._side,
|
||||||
self._handle_welcome, self._agent)
|
self._handle_welcome, self._agent)
|
||||||
|
|
||||||
|
|
||||||
|
def close_on_error(meth): # method decorator
|
||||||
|
# Clients report certain errors as "moods", so the server can make a
|
||||||
|
# rough count failed connections (due to mismatched passwords, attacks,
|
||||||
|
# or timeouts). We don't report precondition failures, as those are the
|
||||||
|
# responsibility/fault of the local application code. We count
|
||||||
|
# non-precondition errors in case they represent server-side problems.
|
||||||
|
def _wrapper(self, *args, **kwargs):
|
||||||
|
d = defer.maybeDeferred(meth, self, *args, **kwargs)
|
||||||
|
def _onerror(f):
|
||||||
|
if f.check(Timeout):
|
||||||
|
d2 = self.close(u"lonely")
|
||||||
|
elif f.check(WrongPasswordError):
|
||||||
|
d2 = self.close(u"scary")
|
||||||
|
elif f.check(TypeError, UsageError):
|
||||||
|
# preconditions don't warrant _close_with_error()
|
||||||
|
d2 = defer.succeed(None)
|
||||||
|
else:
|
||||||
|
d2 = self.close(u"errory")
|
||||||
|
d2.addBoth(lambda _: f)
|
||||||
|
return d2
|
||||||
|
d.addErrback(_onerror)
|
||||||
|
return d
|
||||||
|
return _wrapper
|
||||||
|
|
||||||
class Wormhole:
|
class Wormhole:
|
||||||
motd_displayed = False
|
motd_displayed = False
|
||||||
version_warning_displayed = False
|
version_warning_displayed = False
|
||||||
|
@ -213,6 +238,7 @@ class Wormhole:
|
||||||
self._sent_data = set() # phases
|
self._sent_data = set() # phases
|
||||||
self._got_data = set()
|
self._got_data = set()
|
||||||
self._got_confirmation = False
|
self._got_confirmation = False
|
||||||
|
self._closed = False
|
||||||
|
|
||||||
def _set_side(self, side):
|
def _set_side(self, side):
|
||||||
self._side = side
|
self._side = side
|
||||||
|
@ -305,6 +331,7 @@ class Wormhole:
|
||||||
self.msg1 = d["msg1"].decode("hex")
|
self.msg1 = d["msg1"].decode("hex")
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
#@close_on_error # XXX
|
||||||
def derive_key(self, purpose, length=SecretBox.KEY_SIZE):
|
def derive_key(self, purpose, length=SecretBox.KEY_SIZE):
|
||||||
if not isinstance(purpose, type(u"")): raise TypeError(type(purpose))
|
if not isinstance(purpose, type(u"")): raise TypeError(type(purpose))
|
||||||
if self.key is None:
|
if self.key is None:
|
||||||
|
@ -350,16 +377,20 @@ class Wormhole:
|
||||||
d.addCallback(_got_pake)
|
d.addCallback(_got_pake)
|
||||||
return d
|
return d
|
||||||
|
|
||||||
|
@close_on_error
|
||||||
def get_verifier(self):
|
def get_verifier(self):
|
||||||
|
if self._closed: raise UsageError
|
||||||
if self.code is None: raise UsageError
|
if self.code is None: raise UsageError
|
||||||
d = self._get_key()
|
d = self._get_key()
|
||||||
d.addCallback(lambda _: self.verifier)
|
d.addCallback(lambda _: self.verifier)
|
||||||
return d
|
return d
|
||||||
|
|
||||||
|
@close_on_error
|
||||||
def send_data(self, outbound_data, phase=u"data"):
|
def send_data(self, outbound_data, phase=u"data"):
|
||||||
if not isinstance(outbound_data, type(b"")):
|
if not isinstance(outbound_data, type(b"")):
|
||||||
raise TypeError(type(outbound_data))
|
raise TypeError(type(outbound_data))
|
||||||
if not isinstance(phase, type(u"")): raise TypeError(type(phase))
|
if not isinstance(phase, type(u"")): raise TypeError(type(phase))
|
||||||
|
if self._closed: raise UsageError
|
||||||
if phase in self._sent_data: raise UsageError # only call this once
|
if phase in self._sent_data: raise UsageError # only call this once
|
||||||
if phase.startswith(u"_"): raise UsageError # reserved for internals
|
if phase.startswith(u"_"): raise UsageError # reserved for internals
|
||||||
if self.code is None: raise UsageError
|
if self.code is None: raise UsageError
|
||||||
|
@ -377,10 +408,12 @@ class Wormhole:
|
||||||
d.addCallback(_send)
|
d.addCallback(_send)
|
||||||
return d
|
return d
|
||||||
|
|
||||||
|
@close_on_error
|
||||||
def get_data(self, phase=u"data"):
|
def get_data(self, phase=u"data"):
|
||||||
if not isinstance(phase, type(u"")): raise TypeError(type(phase))
|
if not isinstance(phase, type(u"")): raise TypeError(type(phase))
|
||||||
if phase in self._got_data: raise UsageError # only call this once
|
if phase in self._got_data: raise UsageError # only call this once
|
||||||
if phase.startswith(u"_"): raise UsageError # reserved for internals
|
if phase.startswith(u"_"): raise UsageError # reserved for internals
|
||||||
|
if self._closed: raise UsageError
|
||||||
if self.code is None: raise UsageError
|
if self.code is None: raise UsageError
|
||||||
if self._channel is None: raise UsageError
|
if self._channel is None: raise UsageError
|
||||||
self._got_data.add(phase)
|
self._got_data.add(phase)
|
||||||
|
@ -416,7 +449,13 @@ class Wormhole:
|
||||||
d.addCallback(_get)
|
d.addCallback(_get)
|
||||||
return d
|
return d
|
||||||
|
|
||||||
def close(self, res=None):
|
def close(self, res=None, mood=u"happy"):
|
||||||
monitor.close(self._channel)
|
if not isinstance(mood, (type(None), type(u""))):
|
||||||
d = self._channel.deallocate()
|
raise TypeError(type(mood))
|
||||||
|
self._closed = True
|
||||||
|
d = defer.succeed(None)
|
||||||
|
if self._channel:
|
||||||
|
c, self._channel = self._channel, None
|
||||||
|
monitor.close(c)
|
||||||
|
d.addCallback(lambda _: c.deallocate(mood))
|
||||||
return d
|
return d
|
||||||
|
|
Loading…
Reference in New Issue
Block a user