From 84def8a54ba1f64e198756f2f5a4bef4fa2abd39 Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Tue, 1 Mar 2016 17:55:25 -0800 Subject: [PATCH] add some inlineCallbacks for simplicity This control flow was getting too hairy. --- src/wormhole/test/test_twisted.py | 6 +- src/wormhole/twisted/transcribe.py | 166 +++++++++++++---------------- src/wormhole/twisted/transit.py | 10 +- 3 files changed, 83 insertions(+), 99 deletions(-) diff --git a/src/wormhole/test/test_twisted.py b/src/wormhole/test/test_twisted.py index 688b7a9..fde7d1d 100644 --- a/src/wormhole/test/test_twisted.py +++ b/src/wormhole/test/test_twisted.py @@ -321,13 +321,13 @@ class Basic(ServerBase, unittest.TestCase): d.addCallback(lambda _: self.assertFailure(w1.send_data(b"data"), UsageError)) d.addCallback(lambda _: self.assertFailure(w1.get_data(), UsageError)) d.addCallback(lambda _: w1.set_code(u"123-purple-elephant")) - # these two UsageErrors are synchronous, although most of the rest are async + # this UsageError is synchronous, although most of the rest are async d.addCallback(lambda _: self.assertRaises(UsageError, w1.set_code, u"123-nope")) - d.addCallback(lambda _: self.assertRaises(UsageError, w1.get_code)) + d.addCallback(lambda _: self.assertFailure(w1.get_code(), UsageError)) def _then(_): w2 = Wormhole(APPID, self.relayurl) d2 = w2.get_code() - d2.addCallback(lambda _: self.assertRaises(UsageError, w2.get_code)) + d2.addCallback(lambda _: self.assertFailure(w2.get_code(), UsageError)) d2.addCallback(lambda _: self.doBoth(w1.close(), w2.close())) return d2 d.addCallback(_then) diff --git a/src/wormhole/twisted/transcribe.py b/src/wormhole/twisted/transcribe.py index f5c98ff..932ab00 100644 --- a/src/wormhole/twisted/transcribe.py +++ b/src/wormhole/twisted/transcribe.py @@ -4,6 +4,7 @@ from six.moves.urllib_parse import urlencode from binascii import hexlify, unhexlify from zope.interface import implementer from twisted.internet import reactor, defer +from twisted.internet.defer import inlineCallbacks, returnValue from twisted.web import client as web_client from twisted.web import error as web_error from twisted.web.iweb import IBodyProducer @@ -150,14 +151,12 @@ class Channel: d.addCallback(lambda _: msgs[0]) return d + @inlineCallbacks def get(self, phase): - d = self.get_first_of([phase]) - def _got(res): - (got_phase, body) = res - assert got_phase == phase - return body - d.addCallback(_got) - return d + res = yield self.get_first_of([phase]) + (got_phase, body) = res + assert got_phase == phase + returnValue(body) def deallocate(self, mood=None): # only try once, no retries @@ -178,23 +177,21 @@ class ChannelManager: self._handle_welcome = handle_welcome self._agent = web_client.Agent(reactor) + @inlineCallbacks def allocate(self): url = self._relay + "allocate" - d = post_json(self._agent, url, {"appid": self._appid, - "side": self._side}) - def _got_channel(data): - if "welcome" in data: - self._handle_welcome(data["welcome"]) - return data["channelid"] - d.addCallback(_got_channel) - return d + data = yield post_json(self._agent, url, {"appid": self._appid, + "side": self._side}) + if "welcome" in data: + self._handle_welcome(data["welcome"]) + returnValue(data["channelid"]) + @inlineCallbacks def list_channels(self): queryargs = urlencode([("appid", self._appid)]) url = self._relay + u"list?%s" % queryargs - d = get_json(self._agent, url) - d.addCallback(lambda r: r["channelids"]) - return d + r = yield get_json(self._agent, url) + returnValue(r["channelids"]) def connect(self, channelid): return Channel(self._relay, self._appid, channelid, self._side, @@ -274,19 +271,17 @@ class Wormhole: if "error" in welcome: raise ServerError(welcome["error"], self._relay_url) + @inlineCallbacks def get_code(self, code_length=2): if self.code is not None: raise UsageError if self._started_get_code: raise UsageError self._started_get_code = True - d = self._channel_manager.allocate() - def _got_channelid(channelid): - code = codes.make_code(channelid, code_length) - assert isinstance(code, type(u"")), type(code) - self._set_code_and_channelid(code) - self._start() - return code - d.addCallback(_got_channelid) - return d + channelid = yield self._channel_manager.allocate() + code = codes.make_code(channelid, code_length) + assert isinstance(code, type(u"")), type(code) + self._set_code_and_channelid(code) + self._start() + returnValue(code) def set_code(self, code): if not isinstance(code, type(u"")): raise TypeError(type(code)) @@ -361,37 +356,34 @@ class Wormhole: data = box.decrypt(encrypted) return data - + @inlineCallbacks def _get_key(self): # TODO: prevent multiple invocation if self.key: - return defer.succeed(self.key) - d = self._channel.send(u"pake", self.msg1) - d.addCallback(lambda _: self._channel.get(u"pake")) - def _got_pake(pake_msg): - key = self.sp.finish(pake_msg) - self.key = key - self.verifier = self.derive_key(u"wormhole:verifier") - if not self._send_confirm: - return key - confkey = self.derive_key(u"wormhole:confirmation") - nonce = os.urandom(CONFMSG_NONCE_LENGTH) - confmsg = make_confmsg(confkey, nonce) - d1 = self._channel.send(u"_confirm", confmsg) - d1.addCallback(lambda _: key) - return d1 - d.addCallback(_got_pake) - return d + returnValue(self.key) + yield self._channel.send(u"pake", self.msg1) + pake_msg = yield self._channel.get(u"pake") + key = self.sp.finish(pake_msg) + self.key = key + self.verifier = self.derive_key(u"wormhole:verifier") + if not self._send_confirm: + returnValue(key) + confkey = self.derive_key(u"wormhole:confirmation") + nonce = os.urandom(CONFMSG_NONCE_LENGTH) + confmsg = make_confmsg(confkey, nonce) + yield self._channel.send(u"_confirm", confmsg) + returnValue(key) @close_on_error + @inlineCallbacks def get_verifier(self): if self._closed: raise UsageError if self.code is None: raise UsageError - d = self._get_key() - d.addCallback(lambda _: self.verifier) - return d + yield self._get_key() + returnValue(self.verifier) @close_on_error + @inlineCallbacks def send_data(self, outbound_data, phase=u"data"): if not isinstance(outbound_data, type(b"")): raise TypeError(type(outbound_data)) @@ -406,15 +398,13 @@ class Wormhole: # nonces to keep the messages distinct, and the Channel automatically # ignores reflections. self._sent_data.add(phase) - d = self._get_key() - def _send(key): - data_key = self.derive_key(u"wormhole:phase:%s" % phase) - outbound_encrypted = self._encrypt_data(data_key, outbound_data) - return self._channel.send(phase, outbound_encrypted) - d.addCallback(_send) - return d + yield self._get_key() + data_key = self.derive_key(u"wormhole:phase:%s" % phase) + outbound_encrypted = self._encrypt_data(data_key, outbound_data) + yield self._channel.send(phase, outbound_encrypted) @close_on_error + @inlineCallbacks def get_data(self, phase=u"data"): if not isinstance(phase, type(u"")): raise TypeError(type(phase)) if phase in self._got_data: raise UsageError # only call this once @@ -423,45 +413,37 @@ class Wormhole: if self.code is None: raise UsageError if self._channel is None: raise UsageError self._got_data.add(phase) - d = self._get_key() - def _get(key): - phases = [] - if not self._got_confirmation: - phases.append(u"_confirm") - phases.append(phase) - d1 = self._channel.get_first_of(phases) - def _maybe_got_confirm(phase_and_body): - (got_phase, body) = phase_and_body - if got_phase == u"_confirm": - confkey = self.derive_key(u"wormhole:confirmation") - nonce = body[:CONFMSG_NONCE_LENGTH] - if body != make_confmsg(confkey, nonce): - raise WrongPasswordError - self._got_confirmation = True - return self._channel.get_first_of([phase]) - return phase_and_body - d1.addCallback(_maybe_got_confirm) - def _got(phase_and_body): - (got_phase, body) = phase_and_body - assert got_phase == phase - try: - data_key = self.derive_key(u"wormhole:phase:%s" % phase) - inbound_data = self._decrypt_data(data_key, body) - return inbound_data - except CryptoError: - raise WrongPasswordError - d1.addCallback(_got) - return d1 - d.addCallback(_get) - return d + yield self._get_key() + phases = [] + if not self._got_confirmation: + phases.append(u"_confirm") + phases.append(phase) + phase_and_body = yield self._channel.get_first_of(phases) + (got_phase, body) = phase_and_body + if got_phase == u"_confirm": + confkey = self.derive_key(u"wormhole:confirmation") + nonce = body[:CONFMSG_NONCE_LENGTH] + if body != make_confmsg(confkey, nonce): + raise WrongPasswordError + self._got_confirmation = True + phase_and_body = yield self._channel.get_first_of([phase]) + (got_phase, body) = phase_and_body + assert got_phase == phase + try: + data_key = self.derive_key(u"wormhole:phase:%s" % phase) + inbound_data = self._decrypt_data(data_key, body) + returnValue(inbound_data) + except CryptoError: + raise WrongPasswordError + @inlineCallbacks def close(self, res=None, mood=u"happy"): if not isinstance(mood, (type(None), type(u""))): raise TypeError(type(mood)) self._closed = True - d = defer.succeed(None) - if self._channel: - c, self._channel = self._channel, None - monitor.close(c) - d.addCallback(lambda _: c.deallocate(mood)) - return d + if not self._channel: + returnValue(None) + c, self._channel = self._channel, None + monitor.close(c) + yield c.deallocate(mood) + diff --git a/src/wormhole/twisted/transit.py b/src/wormhole/twisted/transit.py index 942fcd1..54cda1e 100644 --- a/src/wormhole/twisted/transit.py +++ b/src/wormhole/twisted/transit.py @@ -5,6 +5,7 @@ from zope.interface import implementer from twisted.python.runtime import platformType from twisted.internet import (reactor, interfaces, defer, protocol, endpoints, task, address, error) +from twisted.internet.defer import inlineCallbacks, returnValue from twisted.protocols import policies from nacl.secret import SecretBox from ..util import ipaddrs @@ -536,15 +537,16 @@ class Common: self._waiting_for_transit_key.append(d) return d + @inlineCallbacks def connect(self): - d = self._get_transit_key() - d.addCallback(self._connect) + yield self._get_transit_key() # we want to have the transit key before starting any outbound # connections, so those connections will know what to say when they # connect - return d + winner = yield self._connect() + returnValue(winner) - def _connect(self, _): + def _connect(self): # It might be nice to wire this so that a failure in the direct hints # causes the relay hints to be used right away (fast failover). But # none of our current use cases would take advantage of that: if we