add some inlineCallbacks for simplicity
This control flow was getting too hairy.
This commit is contained in:
parent
fd143caded
commit
84def8a54b
|
@ -321,13 +321,13 @@ class Basic(ServerBase, unittest.TestCase):
|
||||||
d.addCallback(lambda _: self.assertFailure(w1.send_data(b"data"), UsageError))
|
d.addCallback(lambda _: self.assertFailure(w1.send_data(b"data"), UsageError))
|
||||||
d.addCallback(lambda _: self.assertFailure(w1.get_data(), UsageError))
|
d.addCallback(lambda _: self.assertFailure(w1.get_data(), UsageError))
|
||||||
d.addCallback(lambda _: w1.set_code(u"123-purple-elephant"))
|
d.addCallback(lambda _: w1.set_code(u"123-purple-elephant"))
|
||||||
# these two UsageErrors are synchronous, although most of the rest are async
|
# this UsageError is synchronous, although most of the rest are async
|
||||||
d.addCallback(lambda _: self.assertRaises(UsageError, w1.set_code, u"123-nope"))
|
d.addCallback(lambda _: self.assertRaises(UsageError, w1.set_code, u"123-nope"))
|
||||||
d.addCallback(lambda _: self.assertRaises(UsageError, w1.get_code))
|
d.addCallback(lambda _: self.assertFailure(w1.get_code(), UsageError))
|
||||||
def _then(_):
|
def _then(_):
|
||||||
w2 = Wormhole(APPID, self.relayurl)
|
w2 = Wormhole(APPID, self.relayurl)
|
||||||
d2 = w2.get_code()
|
d2 = w2.get_code()
|
||||||
d2.addCallback(lambda _: self.assertRaises(UsageError, w2.get_code))
|
d2.addCallback(lambda _: self.assertFailure(w2.get_code(), UsageError))
|
||||||
d2.addCallback(lambda _: self.doBoth(w1.close(), w2.close()))
|
d2.addCallback(lambda _: self.doBoth(w1.close(), w2.close()))
|
||||||
return d2
|
return d2
|
||||||
d.addCallback(_then)
|
d.addCallback(_then)
|
||||||
|
|
|
@ -4,6 +4,7 @@ from six.moves.urllib_parse import urlencode
|
||||||
from binascii import hexlify, unhexlify
|
from binascii import hexlify, unhexlify
|
||||||
from zope.interface import implementer
|
from zope.interface import implementer
|
||||||
from twisted.internet import reactor, defer
|
from twisted.internet import reactor, defer
|
||||||
|
from twisted.internet.defer import inlineCallbacks, returnValue
|
||||||
from twisted.web import client as web_client
|
from twisted.web import client as web_client
|
||||||
from twisted.web import error as web_error
|
from twisted.web import error as web_error
|
||||||
from twisted.web.iweb import IBodyProducer
|
from twisted.web.iweb import IBodyProducer
|
||||||
|
@ -150,14 +151,12 @@ class Channel:
|
||||||
d.addCallback(lambda _: msgs[0])
|
d.addCallback(lambda _: msgs[0])
|
||||||
return d
|
return d
|
||||||
|
|
||||||
|
@inlineCallbacks
|
||||||
def get(self, phase):
|
def get(self, phase):
|
||||||
d = self.get_first_of([phase])
|
res = yield self.get_first_of([phase])
|
||||||
def _got(res):
|
|
||||||
(got_phase, body) = res
|
(got_phase, body) = res
|
||||||
assert got_phase == phase
|
assert got_phase == phase
|
||||||
return body
|
returnValue(body)
|
||||||
d.addCallback(_got)
|
|
||||||
return d
|
|
||||||
|
|
||||||
def deallocate(self, mood=None):
|
def deallocate(self, mood=None):
|
||||||
# only try once, no retries
|
# only try once, no retries
|
||||||
|
@ -178,23 +177,21 @@ class ChannelManager:
|
||||||
self._handle_welcome = handle_welcome
|
self._handle_welcome = handle_welcome
|
||||||
self._agent = web_client.Agent(reactor)
|
self._agent = web_client.Agent(reactor)
|
||||||
|
|
||||||
|
@inlineCallbacks
|
||||||
def allocate(self):
|
def allocate(self):
|
||||||
url = self._relay + "allocate"
|
url = self._relay + "allocate"
|
||||||
d = post_json(self._agent, url, {"appid": self._appid,
|
data = yield post_json(self._agent, url, {"appid": self._appid,
|
||||||
"side": self._side})
|
"side": self._side})
|
||||||
def _got_channel(data):
|
|
||||||
if "welcome" in data:
|
if "welcome" in data:
|
||||||
self._handle_welcome(data["welcome"])
|
self._handle_welcome(data["welcome"])
|
||||||
return data["channelid"]
|
returnValue(data["channelid"])
|
||||||
d.addCallback(_got_channel)
|
|
||||||
return d
|
|
||||||
|
|
||||||
|
@inlineCallbacks
|
||||||
def list_channels(self):
|
def list_channels(self):
|
||||||
queryargs = urlencode([("appid", self._appid)])
|
queryargs = urlencode([("appid", self._appid)])
|
||||||
url = self._relay + u"list?%s" % queryargs
|
url = self._relay + u"list?%s" % queryargs
|
||||||
d = get_json(self._agent, url)
|
r = yield get_json(self._agent, url)
|
||||||
d.addCallback(lambda r: r["channelids"])
|
returnValue(r["channelids"])
|
||||||
return d
|
|
||||||
|
|
||||||
def connect(self, channelid):
|
def connect(self, channelid):
|
||||||
return Channel(self._relay, self._appid, channelid, self._side,
|
return Channel(self._relay, self._appid, channelid, self._side,
|
||||||
|
@ -274,19 +271,17 @@ class Wormhole:
|
||||||
if "error" in welcome:
|
if "error" in welcome:
|
||||||
raise ServerError(welcome["error"], self._relay_url)
|
raise ServerError(welcome["error"], self._relay_url)
|
||||||
|
|
||||||
|
@inlineCallbacks
|
||||||
def get_code(self, code_length=2):
|
def get_code(self, code_length=2):
|
||||||
if self.code is not None: raise UsageError
|
if self.code is not None: raise UsageError
|
||||||
if self._started_get_code: raise UsageError
|
if self._started_get_code: raise UsageError
|
||||||
self._started_get_code = True
|
self._started_get_code = True
|
||||||
d = self._channel_manager.allocate()
|
channelid = yield self._channel_manager.allocate()
|
||||||
def _got_channelid(channelid):
|
|
||||||
code = codes.make_code(channelid, code_length)
|
code = codes.make_code(channelid, code_length)
|
||||||
assert isinstance(code, type(u"")), type(code)
|
assert isinstance(code, type(u"")), type(code)
|
||||||
self._set_code_and_channelid(code)
|
self._set_code_and_channelid(code)
|
||||||
self._start()
|
self._start()
|
||||||
return code
|
returnValue(code)
|
||||||
d.addCallback(_got_channelid)
|
|
||||||
return d
|
|
||||||
|
|
||||||
def set_code(self, code):
|
def set_code(self, code):
|
||||||
if not isinstance(code, type(u"")): raise TypeError(type(code))
|
if not isinstance(code, type(u"")): raise TypeError(type(code))
|
||||||
|
@ -361,37 +356,34 @@ class Wormhole:
|
||||||
data = box.decrypt(encrypted)
|
data = box.decrypt(encrypted)
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
@inlineCallbacks
|
||||||
def _get_key(self):
|
def _get_key(self):
|
||||||
# TODO: prevent multiple invocation
|
# TODO: prevent multiple invocation
|
||||||
if self.key:
|
if self.key:
|
||||||
return defer.succeed(self.key)
|
returnValue(self.key)
|
||||||
d = self._channel.send(u"pake", self.msg1)
|
yield self._channel.send(u"pake", self.msg1)
|
||||||
d.addCallback(lambda _: self._channel.get(u"pake"))
|
pake_msg = yield self._channel.get(u"pake")
|
||||||
def _got_pake(pake_msg):
|
|
||||||
key = self.sp.finish(pake_msg)
|
key = self.sp.finish(pake_msg)
|
||||||
self.key = key
|
self.key = key
|
||||||
self.verifier = self.derive_key(u"wormhole:verifier")
|
self.verifier = self.derive_key(u"wormhole:verifier")
|
||||||
if not self._send_confirm:
|
if not self._send_confirm:
|
||||||
return key
|
returnValue(key)
|
||||||
confkey = self.derive_key(u"wormhole:confirmation")
|
confkey = self.derive_key(u"wormhole:confirmation")
|
||||||
nonce = os.urandom(CONFMSG_NONCE_LENGTH)
|
nonce = os.urandom(CONFMSG_NONCE_LENGTH)
|
||||||
confmsg = make_confmsg(confkey, nonce)
|
confmsg = make_confmsg(confkey, nonce)
|
||||||
d1 = self._channel.send(u"_confirm", confmsg)
|
yield self._channel.send(u"_confirm", confmsg)
|
||||||
d1.addCallback(lambda _: key)
|
returnValue(key)
|
||||||
return d1
|
|
||||||
d.addCallback(_got_pake)
|
|
||||||
return d
|
|
||||||
|
|
||||||
@close_on_error
|
@close_on_error
|
||||||
|
@inlineCallbacks
|
||||||
def get_verifier(self):
|
def get_verifier(self):
|
||||||
if self._closed: raise UsageError
|
if self._closed: raise UsageError
|
||||||
if self.code is None: raise UsageError
|
if self.code is None: raise UsageError
|
||||||
d = self._get_key()
|
yield self._get_key()
|
||||||
d.addCallback(lambda _: self.verifier)
|
returnValue(self.verifier)
|
||||||
return d
|
|
||||||
|
|
||||||
@close_on_error
|
@close_on_error
|
||||||
|
@inlineCallbacks
|
||||||
def send_data(self, outbound_data, phase=u"data"):
|
def send_data(self, outbound_data, phase=u"data"):
|
||||||
if not isinstance(outbound_data, type(b"")):
|
if not isinstance(outbound_data, type(b"")):
|
||||||
raise TypeError(type(outbound_data))
|
raise TypeError(type(outbound_data))
|
||||||
|
@ -406,15 +398,13 @@ class Wormhole:
|
||||||
# nonces to keep the messages distinct, and the Channel automatically
|
# nonces to keep the messages distinct, and the Channel automatically
|
||||||
# ignores reflections.
|
# ignores reflections.
|
||||||
self._sent_data.add(phase)
|
self._sent_data.add(phase)
|
||||||
d = self._get_key()
|
yield self._get_key()
|
||||||
def _send(key):
|
|
||||||
data_key = self.derive_key(u"wormhole:phase:%s" % phase)
|
data_key = self.derive_key(u"wormhole:phase:%s" % phase)
|
||||||
outbound_encrypted = self._encrypt_data(data_key, outbound_data)
|
outbound_encrypted = self._encrypt_data(data_key, outbound_data)
|
||||||
return self._channel.send(phase, outbound_encrypted)
|
yield self._channel.send(phase, outbound_encrypted)
|
||||||
d.addCallback(_send)
|
|
||||||
return d
|
|
||||||
|
|
||||||
@close_on_error
|
@close_on_error
|
||||||
|
@inlineCallbacks
|
||||||
def get_data(self, phase=u"data"):
|
def get_data(self, phase=u"data"):
|
||||||
if not isinstance(phase, type(u"")): raise TypeError(type(phase))
|
if not isinstance(phase, type(u"")): raise TypeError(type(phase))
|
||||||
if phase in self._got_data: raise UsageError # only call this once
|
if phase in self._got_data: raise UsageError # only call this once
|
||||||
|
@ -423,14 +413,12 @@ class Wormhole:
|
||||||
if self.code is None: raise UsageError
|
if self.code is None: raise UsageError
|
||||||
if self._channel is None: raise UsageError
|
if self._channel is None: raise UsageError
|
||||||
self._got_data.add(phase)
|
self._got_data.add(phase)
|
||||||
d = self._get_key()
|
yield self._get_key()
|
||||||
def _get(key):
|
|
||||||
phases = []
|
phases = []
|
||||||
if not self._got_confirmation:
|
if not self._got_confirmation:
|
||||||
phases.append(u"_confirm")
|
phases.append(u"_confirm")
|
||||||
phases.append(phase)
|
phases.append(phase)
|
||||||
d1 = self._channel.get_first_of(phases)
|
phase_and_body = yield self._channel.get_first_of(phases)
|
||||||
def _maybe_got_confirm(phase_and_body):
|
|
||||||
(got_phase, body) = phase_and_body
|
(got_phase, body) = phase_and_body
|
||||||
if got_phase == u"_confirm":
|
if got_phase == u"_confirm":
|
||||||
confkey = self.derive_key(u"wormhole:confirmation")
|
confkey = self.derive_key(u"wormhole:confirmation")
|
||||||
|
@ -438,30 +426,24 @@ class Wormhole:
|
||||||
if body != make_confmsg(confkey, nonce):
|
if body != make_confmsg(confkey, nonce):
|
||||||
raise WrongPasswordError
|
raise WrongPasswordError
|
||||||
self._got_confirmation = True
|
self._got_confirmation = True
|
||||||
return self._channel.get_first_of([phase])
|
phase_and_body = yield self._channel.get_first_of([phase])
|
||||||
return phase_and_body
|
|
||||||
d1.addCallback(_maybe_got_confirm)
|
|
||||||
def _got(phase_and_body):
|
|
||||||
(got_phase, body) = phase_and_body
|
(got_phase, body) = phase_and_body
|
||||||
assert got_phase == phase
|
assert got_phase == phase
|
||||||
try:
|
try:
|
||||||
data_key = self.derive_key(u"wormhole:phase:%s" % phase)
|
data_key = self.derive_key(u"wormhole:phase:%s" % phase)
|
||||||
inbound_data = self._decrypt_data(data_key, body)
|
inbound_data = self._decrypt_data(data_key, body)
|
||||||
return inbound_data
|
returnValue(inbound_data)
|
||||||
except CryptoError:
|
except CryptoError:
|
||||||
raise WrongPasswordError
|
raise WrongPasswordError
|
||||||
d1.addCallback(_got)
|
|
||||||
return d1
|
|
||||||
d.addCallback(_get)
|
|
||||||
return d
|
|
||||||
|
|
||||||
|
@inlineCallbacks
|
||||||
def close(self, res=None, mood=u"happy"):
|
def close(self, res=None, mood=u"happy"):
|
||||||
if not isinstance(mood, (type(None), type(u""))):
|
if not isinstance(mood, (type(None), type(u""))):
|
||||||
raise TypeError(type(mood))
|
raise TypeError(type(mood))
|
||||||
self._closed = True
|
self._closed = True
|
||||||
d = defer.succeed(None)
|
if not self._channel:
|
||||||
if self._channel:
|
returnValue(None)
|
||||||
c, self._channel = self._channel, None
|
c, self._channel = self._channel, None
|
||||||
monitor.close(c)
|
monitor.close(c)
|
||||||
d.addCallback(lambda _: c.deallocate(mood))
|
yield c.deallocate(mood)
|
||||||
return d
|
|
||||||
|
|
|
@ -5,6 +5,7 @@ from zope.interface import implementer
|
||||||
from twisted.python.runtime import platformType
|
from twisted.python.runtime import platformType
|
||||||
from twisted.internet import (reactor, interfaces, defer, protocol,
|
from twisted.internet import (reactor, interfaces, defer, protocol,
|
||||||
endpoints, task, address, error)
|
endpoints, task, address, error)
|
||||||
|
from twisted.internet.defer import inlineCallbacks, returnValue
|
||||||
from twisted.protocols import policies
|
from twisted.protocols import policies
|
||||||
from nacl.secret import SecretBox
|
from nacl.secret import SecretBox
|
||||||
from ..util import ipaddrs
|
from ..util import ipaddrs
|
||||||
|
@ -536,15 +537,16 @@ class Common:
|
||||||
self._waiting_for_transit_key.append(d)
|
self._waiting_for_transit_key.append(d)
|
||||||
return d
|
return d
|
||||||
|
|
||||||
|
@inlineCallbacks
|
||||||
def connect(self):
|
def connect(self):
|
||||||
d = self._get_transit_key()
|
yield self._get_transit_key()
|
||||||
d.addCallback(self._connect)
|
|
||||||
# we want to have the transit key before starting any outbound
|
# we want to have the transit key before starting any outbound
|
||||||
# connections, so those connections will know what to say when they
|
# connections, so those connections will know what to say when they
|
||||||
# connect
|
# connect
|
||||||
return d
|
winner = yield self._connect()
|
||||||
|
returnValue(winner)
|
||||||
|
|
||||||
def _connect(self, _):
|
def _connect(self):
|
||||||
# It might be nice to wire this so that a failure in the direct hints
|
# It might be nice to wire this so that a failure in the direct hints
|
||||||
# causes the relay hints to be used right away (fast failover). But
|
# causes the relay hints to be used right away (fast failover). But
|
||||||
# none of our current use cases would take advantage of that: if we
|
# none of our current use cases would take advantage of that: if we
|
||||||
|
|
Loading…
Reference in New Issue
Block a user