add some inlineCallbacks for simplicity

This control flow was getting too hairy.
This commit is contained in:
Brian Warner 2016-03-01 17:55:25 -08:00
parent fd143caded
commit 84def8a54b
3 changed files with 83 additions and 99 deletions

View File

@ -321,13 +321,13 @@ class Basic(ServerBase, unittest.TestCase):
d.addCallback(lambda _: self.assertFailure(w1.send_data(b"data"), UsageError))
d.addCallback(lambda _: self.assertFailure(w1.get_data(), UsageError))
d.addCallback(lambda _: w1.set_code(u"123-purple-elephant"))
# these two UsageErrors are synchronous, although most of the rest are async
# this UsageError is synchronous, although most of the rest are async
d.addCallback(lambda _: self.assertRaises(UsageError, w1.set_code, u"123-nope"))
d.addCallback(lambda _: self.assertRaises(UsageError, w1.get_code))
d.addCallback(lambda _: self.assertFailure(w1.get_code(), UsageError))
def _then(_):
w2 = Wormhole(APPID, self.relayurl)
d2 = w2.get_code()
d2.addCallback(lambda _: self.assertRaises(UsageError, w2.get_code))
d2.addCallback(lambda _: self.assertFailure(w2.get_code(), UsageError))
d2.addCallback(lambda _: self.doBoth(w1.close(), w2.close()))
return d2
d.addCallback(_then)

View File

@ -4,6 +4,7 @@ from six.moves.urllib_parse import urlencode
from binascii import hexlify, unhexlify
from zope.interface import implementer
from twisted.internet import reactor, defer
from twisted.internet.defer import inlineCallbacks, returnValue
from twisted.web import client as web_client
from twisted.web import error as web_error
from twisted.web.iweb import IBodyProducer
@ -150,14 +151,12 @@ class Channel:
d.addCallback(lambda _: msgs[0])
return d
@inlineCallbacks
def get(self, phase):
d = self.get_first_of([phase])
def _got(res):
(got_phase, body) = res
assert got_phase == phase
return body
d.addCallback(_got)
return d
res = yield self.get_first_of([phase])
(got_phase, body) = res
assert got_phase == phase
returnValue(body)
def deallocate(self, mood=None):
# only try once, no retries
@ -178,23 +177,21 @@ class ChannelManager:
self._handle_welcome = handle_welcome
self._agent = web_client.Agent(reactor)
@inlineCallbacks
def allocate(self):
url = self._relay + "allocate"
d = post_json(self._agent, url, {"appid": self._appid,
"side": self._side})
def _got_channel(data):
if "welcome" in data:
self._handle_welcome(data["welcome"])
return data["channelid"]
d.addCallback(_got_channel)
return d
data = yield post_json(self._agent, url, {"appid": self._appid,
"side": self._side})
if "welcome" in data:
self._handle_welcome(data["welcome"])
returnValue(data["channelid"])
@inlineCallbacks
def list_channels(self):
queryargs = urlencode([("appid", self._appid)])
url = self._relay + u"list?%s" % queryargs
d = get_json(self._agent, url)
d.addCallback(lambda r: r["channelids"])
return d
r = yield get_json(self._agent, url)
returnValue(r["channelids"])
def connect(self, channelid):
return Channel(self._relay, self._appid, channelid, self._side,
@ -274,19 +271,17 @@ class Wormhole:
if "error" in welcome:
raise ServerError(welcome["error"], self._relay_url)
@inlineCallbacks
def get_code(self, code_length=2):
if self.code is not None: raise UsageError
if self._started_get_code: raise UsageError
self._started_get_code = True
d = self._channel_manager.allocate()
def _got_channelid(channelid):
code = codes.make_code(channelid, code_length)
assert isinstance(code, type(u"")), type(code)
self._set_code_and_channelid(code)
self._start()
return code
d.addCallback(_got_channelid)
return d
channelid = yield self._channel_manager.allocate()
code = codes.make_code(channelid, code_length)
assert isinstance(code, type(u"")), type(code)
self._set_code_and_channelid(code)
self._start()
returnValue(code)
def set_code(self, code):
if not isinstance(code, type(u"")): raise TypeError(type(code))
@ -361,37 +356,34 @@ class Wormhole:
data = box.decrypt(encrypted)
return data
@inlineCallbacks
def _get_key(self):
# TODO: prevent multiple invocation
if self.key:
return defer.succeed(self.key)
d = self._channel.send(u"pake", self.msg1)
d.addCallback(lambda _: self._channel.get(u"pake"))
def _got_pake(pake_msg):
key = self.sp.finish(pake_msg)
self.key = key
self.verifier = self.derive_key(u"wormhole:verifier")
if not self._send_confirm:
return key
confkey = self.derive_key(u"wormhole:confirmation")
nonce = os.urandom(CONFMSG_NONCE_LENGTH)
confmsg = make_confmsg(confkey, nonce)
d1 = self._channel.send(u"_confirm", confmsg)
d1.addCallback(lambda _: key)
return d1
d.addCallback(_got_pake)
return d
returnValue(self.key)
yield self._channel.send(u"pake", self.msg1)
pake_msg = yield self._channel.get(u"pake")
key = self.sp.finish(pake_msg)
self.key = key
self.verifier = self.derive_key(u"wormhole:verifier")
if not self._send_confirm:
returnValue(key)
confkey = self.derive_key(u"wormhole:confirmation")
nonce = os.urandom(CONFMSG_NONCE_LENGTH)
confmsg = make_confmsg(confkey, nonce)
yield self._channel.send(u"_confirm", confmsg)
returnValue(key)
@close_on_error
@inlineCallbacks
def get_verifier(self):
if self._closed: raise UsageError
if self.code is None: raise UsageError
d = self._get_key()
d.addCallback(lambda _: self.verifier)
return d
yield self._get_key()
returnValue(self.verifier)
@close_on_error
@inlineCallbacks
def send_data(self, outbound_data, phase=u"data"):
if not isinstance(outbound_data, type(b"")):
raise TypeError(type(outbound_data))
@ -406,15 +398,13 @@ class Wormhole:
# nonces to keep the messages distinct, and the Channel automatically
# ignores reflections.
self._sent_data.add(phase)
d = self._get_key()
def _send(key):
data_key = self.derive_key(u"wormhole:phase:%s" % phase)
outbound_encrypted = self._encrypt_data(data_key, outbound_data)
return self._channel.send(phase, outbound_encrypted)
d.addCallback(_send)
return d
yield self._get_key()
data_key = self.derive_key(u"wormhole:phase:%s" % phase)
outbound_encrypted = self._encrypt_data(data_key, outbound_data)
yield self._channel.send(phase, outbound_encrypted)
@close_on_error
@inlineCallbacks
def get_data(self, phase=u"data"):
if not isinstance(phase, type(u"")): raise TypeError(type(phase))
if phase in self._got_data: raise UsageError # only call this once
@ -423,45 +413,37 @@ class Wormhole:
if self.code is None: raise UsageError
if self._channel is None: raise UsageError
self._got_data.add(phase)
d = self._get_key()
def _get(key):
phases = []
if not self._got_confirmation:
phases.append(u"_confirm")
phases.append(phase)
d1 = self._channel.get_first_of(phases)
def _maybe_got_confirm(phase_and_body):
(got_phase, body) = phase_and_body
if got_phase == u"_confirm":
confkey = self.derive_key(u"wormhole:confirmation")
nonce = body[:CONFMSG_NONCE_LENGTH]
if body != make_confmsg(confkey, nonce):
raise WrongPasswordError
self._got_confirmation = True
return self._channel.get_first_of([phase])
return phase_and_body
d1.addCallback(_maybe_got_confirm)
def _got(phase_and_body):
(got_phase, body) = phase_and_body
assert got_phase == phase
try:
data_key = self.derive_key(u"wormhole:phase:%s" % phase)
inbound_data = self._decrypt_data(data_key, body)
return inbound_data
except CryptoError:
raise WrongPasswordError
d1.addCallback(_got)
return d1
d.addCallback(_get)
return d
yield self._get_key()
phases = []
if not self._got_confirmation:
phases.append(u"_confirm")
phases.append(phase)
phase_and_body = yield self._channel.get_first_of(phases)
(got_phase, body) = phase_and_body
if got_phase == u"_confirm":
confkey = self.derive_key(u"wormhole:confirmation")
nonce = body[:CONFMSG_NONCE_LENGTH]
if body != make_confmsg(confkey, nonce):
raise WrongPasswordError
self._got_confirmation = True
phase_and_body = yield self._channel.get_first_of([phase])
(got_phase, body) = phase_and_body
assert got_phase == phase
try:
data_key = self.derive_key(u"wormhole:phase:%s" % phase)
inbound_data = self._decrypt_data(data_key, body)
returnValue(inbound_data)
except CryptoError:
raise WrongPasswordError
@inlineCallbacks
def close(self, res=None, mood=u"happy"):
if not isinstance(mood, (type(None), type(u""))):
raise TypeError(type(mood))
self._closed = True
d = defer.succeed(None)
if self._channel:
c, self._channel = self._channel, None
monitor.close(c)
d.addCallback(lambda _: c.deallocate(mood))
return d
if not self._channel:
returnValue(None)
c, self._channel = self._channel, None
monitor.close(c)
yield c.deallocate(mood)

View File

@ -5,6 +5,7 @@ from zope.interface import implementer
from twisted.python.runtime import platformType
from twisted.internet import (reactor, interfaces, defer, protocol,
endpoints, task, address, error)
from twisted.internet.defer import inlineCallbacks, returnValue
from twisted.protocols import policies
from nacl.secret import SecretBox
from ..util import ipaddrs
@ -536,15 +537,16 @@ class Common:
self._waiting_for_transit_key.append(d)
return d
@inlineCallbacks
def connect(self):
d = self._get_transit_key()
d.addCallback(self._connect)
yield self._get_transit_key()
# we want to have the transit key before starting any outbound
# connections, so those connections will know what to say when they
# connect
return d
winner = yield self._connect()
returnValue(winner)
def _connect(self, _):
def _connect(self):
# It might be nice to wire this so that a failure in the direct hints
# causes the relay hints to be used right away (fast failover). But
# none of our current use cases would take advantage of that: if we