add some inlineCallbacks for simplicity

This control flow was getting too hairy.
This commit is contained in:
Brian Warner 2016-03-01 17:55:25 -08:00
parent fd143caded
commit 84def8a54b
3 changed files with 83 additions and 99 deletions

View File

@ -321,13 +321,13 @@ class Basic(ServerBase, unittest.TestCase):
d.addCallback(lambda _: self.assertFailure(w1.send_data(b"data"), UsageError)) d.addCallback(lambda _: self.assertFailure(w1.send_data(b"data"), UsageError))
d.addCallback(lambda _: self.assertFailure(w1.get_data(), UsageError)) d.addCallback(lambda _: self.assertFailure(w1.get_data(), UsageError))
d.addCallback(lambda _: w1.set_code(u"123-purple-elephant")) d.addCallback(lambda _: w1.set_code(u"123-purple-elephant"))
# these two UsageErrors are synchronous, although most of the rest are async # this UsageError is synchronous, although most of the rest are async
d.addCallback(lambda _: self.assertRaises(UsageError, w1.set_code, u"123-nope")) d.addCallback(lambda _: self.assertRaises(UsageError, w1.set_code, u"123-nope"))
d.addCallback(lambda _: self.assertRaises(UsageError, w1.get_code)) d.addCallback(lambda _: self.assertFailure(w1.get_code(), UsageError))
def _then(_): def _then(_):
w2 = Wormhole(APPID, self.relayurl) w2 = Wormhole(APPID, self.relayurl)
d2 = w2.get_code() d2 = w2.get_code()
d2.addCallback(lambda _: self.assertRaises(UsageError, w2.get_code)) d2.addCallback(lambda _: self.assertFailure(w2.get_code(), UsageError))
d2.addCallback(lambda _: self.doBoth(w1.close(), w2.close())) d2.addCallback(lambda _: self.doBoth(w1.close(), w2.close()))
return d2 return d2
d.addCallback(_then) d.addCallback(_then)

View File

@ -4,6 +4,7 @@ from six.moves.urllib_parse import urlencode
from binascii import hexlify, unhexlify from binascii import hexlify, unhexlify
from zope.interface import implementer from zope.interface import implementer
from twisted.internet import reactor, defer from twisted.internet import reactor, defer
from twisted.internet.defer import inlineCallbacks, returnValue
from twisted.web import client as web_client from twisted.web import client as web_client
from twisted.web import error as web_error from twisted.web import error as web_error
from twisted.web.iweb import IBodyProducer from twisted.web.iweb import IBodyProducer
@ -150,14 +151,12 @@ class Channel:
d.addCallback(lambda _: msgs[0]) d.addCallback(lambda _: msgs[0])
return d return d
@inlineCallbacks
def get(self, phase): def get(self, phase):
d = self.get_first_of([phase]) res = yield self.get_first_of([phase])
def _got(res): (got_phase, body) = res
(got_phase, body) = res assert got_phase == phase
assert got_phase == phase returnValue(body)
return body
d.addCallback(_got)
return d
def deallocate(self, mood=None): def deallocate(self, mood=None):
# only try once, no retries # only try once, no retries
@ -178,23 +177,21 @@ class ChannelManager:
self._handle_welcome = handle_welcome self._handle_welcome = handle_welcome
self._agent = web_client.Agent(reactor) self._agent = web_client.Agent(reactor)
@inlineCallbacks
def allocate(self): def allocate(self):
url = self._relay + "allocate" url = self._relay + "allocate"
d = post_json(self._agent, url, {"appid": self._appid, data = yield post_json(self._agent, url, {"appid": self._appid,
"side": self._side}) "side": self._side})
def _got_channel(data): if "welcome" in data:
if "welcome" in data: self._handle_welcome(data["welcome"])
self._handle_welcome(data["welcome"]) returnValue(data["channelid"])
return data["channelid"]
d.addCallback(_got_channel)
return d
@inlineCallbacks
def list_channels(self): def list_channels(self):
queryargs = urlencode([("appid", self._appid)]) queryargs = urlencode([("appid", self._appid)])
url = self._relay + u"list?%s" % queryargs url = self._relay + u"list?%s" % queryargs
d = get_json(self._agent, url) r = yield get_json(self._agent, url)
d.addCallback(lambda r: r["channelids"]) returnValue(r["channelids"])
return d
def connect(self, channelid): def connect(self, channelid):
return Channel(self._relay, self._appid, channelid, self._side, return Channel(self._relay, self._appid, channelid, self._side,
@ -274,19 +271,17 @@ class Wormhole:
if "error" in welcome: if "error" in welcome:
raise ServerError(welcome["error"], self._relay_url) raise ServerError(welcome["error"], self._relay_url)
@inlineCallbacks
def get_code(self, code_length=2): def get_code(self, code_length=2):
if self.code is not None: raise UsageError if self.code is not None: raise UsageError
if self._started_get_code: raise UsageError if self._started_get_code: raise UsageError
self._started_get_code = True self._started_get_code = True
d = self._channel_manager.allocate() channelid = yield self._channel_manager.allocate()
def _got_channelid(channelid): code = codes.make_code(channelid, code_length)
code = codes.make_code(channelid, code_length) assert isinstance(code, type(u"")), type(code)
assert isinstance(code, type(u"")), type(code) self._set_code_and_channelid(code)
self._set_code_and_channelid(code) self._start()
self._start() returnValue(code)
return code
d.addCallback(_got_channelid)
return d
def set_code(self, code): def set_code(self, code):
if not isinstance(code, type(u"")): raise TypeError(type(code)) if not isinstance(code, type(u"")): raise TypeError(type(code))
@ -361,37 +356,34 @@ class Wormhole:
data = box.decrypt(encrypted) data = box.decrypt(encrypted)
return data return data
@inlineCallbacks
def _get_key(self): def _get_key(self):
# TODO: prevent multiple invocation # TODO: prevent multiple invocation
if self.key: if self.key:
return defer.succeed(self.key) returnValue(self.key)
d = self._channel.send(u"pake", self.msg1) yield self._channel.send(u"pake", self.msg1)
d.addCallback(lambda _: self._channel.get(u"pake")) pake_msg = yield self._channel.get(u"pake")
def _got_pake(pake_msg): key = self.sp.finish(pake_msg)
key = self.sp.finish(pake_msg) self.key = key
self.key = key self.verifier = self.derive_key(u"wormhole:verifier")
self.verifier = self.derive_key(u"wormhole:verifier") if not self._send_confirm:
if not self._send_confirm: returnValue(key)
return key confkey = self.derive_key(u"wormhole:confirmation")
confkey = self.derive_key(u"wormhole:confirmation") nonce = os.urandom(CONFMSG_NONCE_LENGTH)
nonce = os.urandom(CONFMSG_NONCE_LENGTH) confmsg = make_confmsg(confkey, nonce)
confmsg = make_confmsg(confkey, nonce) yield self._channel.send(u"_confirm", confmsg)
d1 = self._channel.send(u"_confirm", confmsg) returnValue(key)
d1.addCallback(lambda _: key)
return d1
d.addCallback(_got_pake)
return d
@close_on_error @close_on_error
@inlineCallbacks
def get_verifier(self): def get_verifier(self):
if self._closed: raise UsageError if self._closed: raise UsageError
if self.code is None: raise UsageError if self.code is None: raise UsageError
d = self._get_key() yield self._get_key()
d.addCallback(lambda _: self.verifier) returnValue(self.verifier)
return d
@close_on_error @close_on_error
@inlineCallbacks
def send_data(self, outbound_data, phase=u"data"): def send_data(self, outbound_data, phase=u"data"):
if not isinstance(outbound_data, type(b"")): if not isinstance(outbound_data, type(b"")):
raise TypeError(type(outbound_data)) raise TypeError(type(outbound_data))
@ -406,15 +398,13 @@ class Wormhole:
# nonces to keep the messages distinct, and the Channel automatically # nonces to keep the messages distinct, and the Channel automatically
# ignores reflections. # ignores reflections.
self._sent_data.add(phase) self._sent_data.add(phase)
d = self._get_key() yield self._get_key()
def _send(key): data_key = self.derive_key(u"wormhole:phase:%s" % phase)
data_key = self.derive_key(u"wormhole:phase:%s" % phase) outbound_encrypted = self._encrypt_data(data_key, outbound_data)
outbound_encrypted = self._encrypt_data(data_key, outbound_data) yield self._channel.send(phase, outbound_encrypted)
return self._channel.send(phase, outbound_encrypted)
d.addCallback(_send)
return d
@close_on_error @close_on_error
@inlineCallbacks
def get_data(self, phase=u"data"): def get_data(self, phase=u"data"):
if not isinstance(phase, type(u"")): raise TypeError(type(phase)) if not isinstance(phase, type(u"")): raise TypeError(type(phase))
if phase in self._got_data: raise UsageError # only call this once if phase in self._got_data: raise UsageError # only call this once
@ -423,45 +413,37 @@ class Wormhole:
if self.code is None: raise UsageError if self.code is None: raise UsageError
if self._channel is None: raise UsageError if self._channel is None: raise UsageError
self._got_data.add(phase) self._got_data.add(phase)
d = self._get_key() yield self._get_key()
def _get(key): phases = []
phases = [] if not self._got_confirmation:
if not self._got_confirmation: phases.append(u"_confirm")
phases.append(u"_confirm") phases.append(phase)
phases.append(phase) phase_and_body = yield self._channel.get_first_of(phases)
d1 = self._channel.get_first_of(phases) (got_phase, body) = phase_and_body
def _maybe_got_confirm(phase_and_body): if got_phase == u"_confirm":
(got_phase, body) = phase_and_body confkey = self.derive_key(u"wormhole:confirmation")
if got_phase == u"_confirm": nonce = body[:CONFMSG_NONCE_LENGTH]
confkey = self.derive_key(u"wormhole:confirmation") if body != make_confmsg(confkey, nonce):
nonce = body[:CONFMSG_NONCE_LENGTH] raise WrongPasswordError
if body != make_confmsg(confkey, nonce): self._got_confirmation = True
raise WrongPasswordError phase_and_body = yield self._channel.get_first_of([phase])
self._got_confirmation = True (got_phase, body) = phase_and_body
return self._channel.get_first_of([phase]) assert got_phase == phase
return phase_and_body try:
d1.addCallback(_maybe_got_confirm) data_key = self.derive_key(u"wormhole:phase:%s" % phase)
def _got(phase_and_body): inbound_data = self._decrypt_data(data_key, body)
(got_phase, body) = phase_and_body returnValue(inbound_data)
assert got_phase == phase except CryptoError:
try: raise WrongPasswordError
data_key = self.derive_key(u"wormhole:phase:%s" % phase)
inbound_data = self._decrypt_data(data_key, body)
return inbound_data
except CryptoError:
raise WrongPasswordError
d1.addCallback(_got)
return d1
d.addCallback(_get)
return d
@inlineCallbacks
def close(self, res=None, mood=u"happy"): def close(self, res=None, mood=u"happy"):
if not isinstance(mood, (type(None), type(u""))): if not isinstance(mood, (type(None), type(u""))):
raise TypeError(type(mood)) raise TypeError(type(mood))
self._closed = True self._closed = True
d = defer.succeed(None) if not self._channel:
if self._channel: returnValue(None)
c, self._channel = self._channel, None c, self._channel = self._channel, None
monitor.close(c) monitor.close(c)
d.addCallback(lambda _: c.deallocate(mood)) yield c.deallocate(mood)
return d

View File

@ -5,6 +5,7 @@ from zope.interface import implementer
from twisted.python.runtime import platformType from twisted.python.runtime import platformType
from twisted.internet import (reactor, interfaces, defer, protocol, from twisted.internet import (reactor, interfaces, defer, protocol,
endpoints, task, address, error) endpoints, task, address, error)
from twisted.internet.defer import inlineCallbacks, returnValue
from twisted.protocols import policies from twisted.protocols import policies
from nacl.secret import SecretBox from nacl.secret import SecretBox
from ..util import ipaddrs from ..util import ipaddrs
@ -536,15 +537,16 @@ class Common:
self._waiting_for_transit_key.append(d) self._waiting_for_transit_key.append(d)
return d return d
@inlineCallbacks
def connect(self): def connect(self):
d = self._get_transit_key() yield self._get_transit_key()
d.addCallback(self._connect)
# we want to have the transit key before starting any outbound # we want to have the transit key before starting any outbound
# connections, so those connections will know what to say when they # connections, so those connections will know what to say when they
# connect # connect
return d winner = yield self._connect()
returnValue(winner)
def _connect(self, _): def _connect(self):
# It might be nice to wire this so that a failure in the direct hints # It might be nice to wire this so that a failure in the direct hints
# causes the relay hints to be used right away (fast failover). But # causes the relay hints to be used right away (fast failover). But
# none of our current use cases would take advantage of that: if we # none of our current use cases would take advantage of that: if we