diff --git a/src/wormhole/test/test_twisted.py b/src/wormhole/test/test_twisted.py index 341a44a..920522c 100644 --- a/src/wormhole/test/test_twisted.py +++ b/src/wormhole/test/test_twisted.py @@ -317,18 +317,20 @@ class Basic(ServerBase, unittest.TestCase): def test_errors(self): w1 = Wormhole(APPID, self.relayurl) - self.assertRaises(UsageError, w1.get_verifier) - self.assertRaises(UsageError, w1.send_data, b"data") - self.assertRaises(UsageError, w1.get_data) - w1.set_code(u"123-purple-elephant") - self.assertRaises(UsageError, w1.set_code, u"123-nope") - self.assertRaises(UsageError, w1.get_code) - w2 = Wormhole(APPID, self.relayurl) - d = w2.get_code() - self.assertRaises(UsageError, w2.get_code) - def _got_code(code): - return self.doBoth(w1.close(), w2.close()) - d.addCallback(_got_code) + d = self.assertFailure(w1.get_verifier(), UsageError) + d.addCallback(lambda _: self.assertFailure(w1.send_data(b"data"), UsageError)) + d.addCallback(lambda _: self.assertFailure(w1.get_data(), UsageError)) + d.addCallback(lambda _: w1.set_code(u"123-purple-elephant")) + # these two UsageErrors are synchronous, although most of the rest are async + d.addCallback(lambda _: self.assertRaises(UsageError, w1.set_code, u"123-nope")) + d.addCallback(lambda _: self.assertRaises(UsageError, w1.get_code)) + def _then(_): + w2 = Wormhole(APPID, self.relayurl) + d2 = w2.get_code() + d2.addCallback(lambda _: self.assertRaises(UsageError, w2.get_code)) + d2.addCallback(lambda _: self.doBoth(w1.close(), w2.close())) + return d2 + d.addCallback(_then) return d def test_repeat_phases(self): @@ -339,27 +341,23 @@ class Basic(ServerBase, unittest.TestCase): # we must let them establish a key before we can send data d = self.doBoth(w1.get_verifier(), w2.get_verifier()) d.addCallback(lambda _: w1.send_data(b"data1", phase=u"1")) - def _sent(res): - # underscore-prefixed phases are reserved - self.assertRaises(UsageError, w1.send_data, b"data1", phase=u"_1") - self.assertRaises(UsageError, w1.get_data, phase=u"_1") - # you can't send twice to the same phase - self.assertRaises(UsageError, w1.send_data, b"data1", phase=u"1") - # but you can send to a different one - return w1.send_data(b"data2", phase=u"2") - d.addCallback(_sent) + # underscore-prefixed phases are reserved + d.addCallback(lambda _: self.assertFailure(w1.send_data(b"data1", phase=u"_1"), + UsageError)) + d.addCallback(lambda _: self.assertFailure(w1.get_data(phase=u"_1"), UsageError)) + # you can't send twice to the same phase + d.addCallback(lambda _: self.assertFailure(w1.send_data(b"data1", phase=u"1"), + UsageError)) + # but you can send to a different one + d.addCallback(lambda _: w1.send_data(b"data2", phase=u"2")) d.addCallback(lambda _: w2.get_data(phase=u"1")) - def _got1(res): - self.failUnlessEqual(res, b"data1") - # and you can't read twice from the same phase - self.assertRaises(UsageError, w2.get_data, phase=u"1") - # but you can read from a different one - return w2.get_data(phase=u"2") - d.addCallback(_got1) - def _got2(res): - self.failUnlessEqual(res, b"data2") - return self.doBoth(w1.close(), w2.close()) - d.addCallback(_got2) + d.addCallback(lambda res: self.failUnlessEqual(res, b"data1")) + # and you can't read twice from the same phase + d.addCallback(lambda _: self.assertFailure(w2.get_data(phase=u"1"), UsageError)) + # but you can read from a different one + d.addCallback(lambda _: w2.get_data(phase=u"2")) + d.addCallback(lambda res: self.failUnlessEqual(res, b"data2")) + d.addCallback(lambda _: self.doBoth(w1.close(), w2.close())) return d def test_serialize(self): diff --git a/src/wormhole/twisted/transcribe.py b/src/wormhole/twisted/transcribe.py index c38f702..e71cd42 100644 --- a/src/wormhole/twisted/transcribe.py +++ b/src/wormhole/twisted/transcribe.py @@ -14,7 +14,7 @@ from spake2 import SPAKE2_Symmetric from .eventsource_twisted import ReconnectingEventSource from .. import __version__ from .. import codes -from ..errors import ServerError, WrongPasswordError, UsageError +from ..errors import ServerError, Timeout, WrongPasswordError, UsageError from ..util.hkdf import HKDF from ..channel_monitor import monitor @@ -153,7 +153,7 @@ class Channel: d.addCallback(_got) return d - def deallocate(self, mood=u"unknown"): + def deallocate(self, mood=None): # only try once, no retries d = post_json(self._agent, self._relay_url+"deallocate", {"appid": self._appid, @@ -194,6 +194,31 @@ class ChannelManager: return Channel(self._relay, self._appid, channelid, self._side, self._handle_welcome, self._agent) + +def close_on_error(meth): # method decorator + # Clients report certain errors as "moods", so the server can make a + # rough count failed connections (due to mismatched passwords, attacks, + # or timeouts). We don't report precondition failures, as those are the + # responsibility/fault of the local application code. We count + # non-precondition errors in case they represent server-side problems. + def _wrapper(self, *args, **kwargs): + d = defer.maybeDeferred(meth, self, *args, **kwargs) + def _onerror(f): + if f.check(Timeout): + d2 = self.close(u"lonely") + elif f.check(WrongPasswordError): + d2 = self.close(u"scary") + elif f.check(TypeError, UsageError): + # preconditions don't warrant _close_with_error() + d2 = defer.succeed(None) + else: + d2 = self.close(u"errory") + d2.addBoth(lambda _: f) + return d2 + d.addErrback(_onerror) + return d + return _wrapper + class Wormhole: motd_displayed = False version_warning_displayed = False @@ -213,6 +238,7 @@ class Wormhole: self._sent_data = set() # phases self._got_data = set() self._got_confirmation = False + self._closed = False def _set_side(self, side): self._side = side @@ -305,6 +331,7 @@ class Wormhole: self.msg1 = d["msg1"].decode("hex") return self + #@close_on_error # XXX def derive_key(self, purpose, length=SecretBox.KEY_SIZE): if not isinstance(purpose, type(u"")): raise TypeError(type(purpose)) if self.key is None: @@ -350,16 +377,20 @@ class Wormhole: d.addCallback(_got_pake) return d + @close_on_error def get_verifier(self): + if self._closed: raise UsageError if self.code is None: raise UsageError d = self._get_key() d.addCallback(lambda _: self.verifier) return d + @close_on_error def send_data(self, outbound_data, phase=u"data"): if not isinstance(outbound_data, type(b"")): raise TypeError(type(outbound_data)) if not isinstance(phase, type(u"")): raise TypeError(type(phase)) + if self._closed: raise UsageError if phase in self._sent_data: raise UsageError # only call this once if phase.startswith(u"_"): raise UsageError # reserved for internals if self.code is None: raise UsageError @@ -377,10 +408,12 @@ class Wormhole: d.addCallback(_send) return d + @close_on_error def get_data(self, phase=u"data"): if not isinstance(phase, type(u"")): raise TypeError(type(phase)) if phase in self._got_data: raise UsageError # only call this once if phase.startswith(u"_"): raise UsageError # reserved for internals + if self._closed: raise UsageError if self.code is None: raise UsageError if self._channel is None: raise UsageError self._got_data.add(phase) @@ -416,7 +449,13 @@ class Wormhole: d.addCallback(_get) return d - def close(self, res=None): - monitor.close(self._channel) - d = self._channel.deallocate() + def close(self, res=None, mood=u"happy"): + if not isinstance(mood, (type(None), type(u""))): + raise TypeError(type(mood)) + self._closed = True + d = defer.succeed(None) + if self._channel: + c, self._channel = self._channel, None + monitor.close(c) + d.addCallback(lambda _: c.deallocate(mood)) return d