add some inlineCallbacks for simplicity
This control flow was getting too hairy.
This commit is contained in:
parent
fd143caded
commit
84def8a54b
|
@ -321,13 +321,13 @@ class Basic(ServerBase, unittest.TestCase):
|
|||
d.addCallback(lambda _: self.assertFailure(w1.send_data(b"data"), UsageError))
|
||||
d.addCallback(lambda _: self.assertFailure(w1.get_data(), UsageError))
|
||||
d.addCallback(lambda _: w1.set_code(u"123-purple-elephant"))
|
||||
# these two UsageErrors are synchronous, although most of the rest are async
|
||||
# this UsageError is synchronous, although most of the rest are async
|
||||
d.addCallback(lambda _: self.assertRaises(UsageError, w1.set_code, u"123-nope"))
|
||||
d.addCallback(lambda _: self.assertRaises(UsageError, w1.get_code))
|
||||
d.addCallback(lambda _: self.assertFailure(w1.get_code(), UsageError))
|
||||
def _then(_):
|
||||
w2 = Wormhole(APPID, self.relayurl)
|
||||
d2 = w2.get_code()
|
||||
d2.addCallback(lambda _: self.assertRaises(UsageError, w2.get_code))
|
||||
d2.addCallback(lambda _: self.assertFailure(w2.get_code(), UsageError))
|
||||
d2.addCallback(lambda _: self.doBoth(w1.close(), w2.close()))
|
||||
return d2
|
||||
d.addCallback(_then)
|
||||
|
|
|
@ -4,6 +4,7 @@ from six.moves.urllib_parse import urlencode
|
|||
from binascii import hexlify, unhexlify
|
||||
from zope.interface import implementer
|
||||
from twisted.internet import reactor, defer
|
||||
from twisted.internet.defer import inlineCallbacks, returnValue
|
||||
from twisted.web import client as web_client
|
||||
from twisted.web import error as web_error
|
||||
from twisted.web.iweb import IBodyProducer
|
||||
|
@ -150,14 +151,12 @@ class Channel:
|
|||
d.addCallback(lambda _: msgs[0])
|
||||
return d
|
||||
|
||||
@inlineCallbacks
|
||||
def get(self, phase):
|
||||
d = self.get_first_of([phase])
|
||||
def _got(res):
|
||||
(got_phase, body) = res
|
||||
assert got_phase == phase
|
||||
return body
|
||||
d.addCallback(_got)
|
||||
return d
|
||||
res = yield self.get_first_of([phase])
|
||||
(got_phase, body) = res
|
||||
assert got_phase == phase
|
||||
returnValue(body)
|
||||
|
||||
def deallocate(self, mood=None):
|
||||
# only try once, no retries
|
||||
|
@ -178,23 +177,21 @@ class ChannelManager:
|
|||
self._handle_welcome = handle_welcome
|
||||
self._agent = web_client.Agent(reactor)
|
||||
|
||||
@inlineCallbacks
|
||||
def allocate(self):
|
||||
url = self._relay + "allocate"
|
||||
d = post_json(self._agent, url, {"appid": self._appid,
|
||||
"side": self._side})
|
||||
def _got_channel(data):
|
||||
if "welcome" in data:
|
||||
self._handle_welcome(data["welcome"])
|
||||
return data["channelid"]
|
||||
d.addCallback(_got_channel)
|
||||
return d
|
||||
data = yield post_json(self._agent, url, {"appid": self._appid,
|
||||
"side": self._side})
|
||||
if "welcome" in data:
|
||||
self._handle_welcome(data["welcome"])
|
||||
returnValue(data["channelid"])
|
||||
|
||||
@inlineCallbacks
|
||||
def list_channels(self):
|
||||
queryargs = urlencode([("appid", self._appid)])
|
||||
url = self._relay + u"list?%s" % queryargs
|
||||
d = get_json(self._agent, url)
|
||||
d.addCallback(lambda r: r["channelids"])
|
||||
return d
|
||||
r = yield get_json(self._agent, url)
|
||||
returnValue(r["channelids"])
|
||||
|
||||
def connect(self, channelid):
|
||||
return Channel(self._relay, self._appid, channelid, self._side,
|
||||
|
@ -274,19 +271,17 @@ class Wormhole:
|
|||
if "error" in welcome:
|
||||
raise ServerError(welcome["error"], self._relay_url)
|
||||
|
||||
@inlineCallbacks
|
||||
def get_code(self, code_length=2):
|
||||
if self.code is not None: raise UsageError
|
||||
if self._started_get_code: raise UsageError
|
||||
self._started_get_code = True
|
||||
d = self._channel_manager.allocate()
|
||||
def _got_channelid(channelid):
|
||||
code = codes.make_code(channelid, code_length)
|
||||
assert isinstance(code, type(u"")), type(code)
|
||||
self._set_code_and_channelid(code)
|
||||
self._start()
|
||||
return code
|
||||
d.addCallback(_got_channelid)
|
||||
return d
|
||||
channelid = yield self._channel_manager.allocate()
|
||||
code = codes.make_code(channelid, code_length)
|
||||
assert isinstance(code, type(u"")), type(code)
|
||||
self._set_code_and_channelid(code)
|
||||
self._start()
|
||||
returnValue(code)
|
||||
|
||||
def set_code(self, code):
|
||||
if not isinstance(code, type(u"")): raise TypeError(type(code))
|
||||
|
@ -361,37 +356,34 @@ class Wormhole:
|
|||
data = box.decrypt(encrypted)
|
||||
return data
|
||||
|
||||
|
||||
@inlineCallbacks
|
||||
def _get_key(self):
|
||||
# TODO: prevent multiple invocation
|
||||
if self.key:
|
||||
return defer.succeed(self.key)
|
||||
d = self._channel.send(u"pake", self.msg1)
|
||||
d.addCallback(lambda _: self._channel.get(u"pake"))
|
||||
def _got_pake(pake_msg):
|
||||
key = self.sp.finish(pake_msg)
|
||||
self.key = key
|
||||
self.verifier = self.derive_key(u"wormhole:verifier")
|
||||
if not self._send_confirm:
|
||||
return key
|
||||
confkey = self.derive_key(u"wormhole:confirmation")
|
||||
nonce = os.urandom(CONFMSG_NONCE_LENGTH)
|
||||
confmsg = make_confmsg(confkey, nonce)
|
||||
d1 = self._channel.send(u"_confirm", confmsg)
|
||||
d1.addCallback(lambda _: key)
|
||||
return d1
|
||||
d.addCallback(_got_pake)
|
||||
return d
|
||||
returnValue(self.key)
|
||||
yield self._channel.send(u"pake", self.msg1)
|
||||
pake_msg = yield self._channel.get(u"pake")
|
||||
key = self.sp.finish(pake_msg)
|
||||
self.key = key
|
||||
self.verifier = self.derive_key(u"wormhole:verifier")
|
||||
if not self._send_confirm:
|
||||
returnValue(key)
|
||||
confkey = self.derive_key(u"wormhole:confirmation")
|
||||
nonce = os.urandom(CONFMSG_NONCE_LENGTH)
|
||||
confmsg = make_confmsg(confkey, nonce)
|
||||
yield self._channel.send(u"_confirm", confmsg)
|
||||
returnValue(key)
|
||||
|
||||
@close_on_error
|
||||
@inlineCallbacks
|
||||
def get_verifier(self):
|
||||
if self._closed: raise UsageError
|
||||
if self.code is None: raise UsageError
|
||||
d = self._get_key()
|
||||
d.addCallback(lambda _: self.verifier)
|
||||
return d
|
||||
yield self._get_key()
|
||||
returnValue(self.verifier)
|
||||
|
||||
@close_on_error
|
||||
@inlineCallbacks
|
||||
def send_data(self, outbound_data, phase=u"data"):
|
||||
if not isinstance(outbound_data, type(b"")):
|
||||
raise TypeError(type(outbound_data))
|
||||
|
@ -406,15 +398,13 @@ class Wormhole:
|
|||
# nonces to keep the messages distinct, and the Channel automatically
|
||||
# ignores reflections.
|
||||
self._sent_data.add(phase)
|
||||
d = self._get_key()
|
||||
def _send(key):
|
||||
data_key = self.derive_key(u"wormhole:phase:%s" % phase)
|
||||
outbound_encrypted = self._encrypt_data(data_key, outbound_data)
|
||||
return self._channel.send(phase, outbound_encrypted)
|
||||
d.addCallback(_send)
|
||||
return d
|
||||
yield self._get_key()
|
||||
data_key = self.derive_key(u"wormhole:phase:%s" % phase)
|
||||
outbound_encrypted = self._encrypt_data(data_key, outbound_data)
|
||||
yield self._channel.send(phase, outbound_encrypted)
|
||||
|
||||
@close_on_error
|
||||
@inlineCallbacks
|
||||
def get_data(self, phase=u"data"):
|
||||
if not isinstance(phase, type(u"")): raise TypeError(type(phase))
|
||||
if phase in self._got_data: raise UsageError # only call this once
|
||||
|
@ -423,45 +413,37 @@ class Wormhole:
|
|||
if self.code is None: raise UsageError
|
||||
if self._channel is None: raise UsageError
|
||||
self._got_data.add(phase)
|
||||
d = self._get_key()
|
||||
def _get(key):
|
||||
phases = []
|
||||
if not self._got_confirmation:
|
||||
phases.append(u"_confirm")
|
||||
phases.append(phase)
|
||||
d1 = self._channel.get_first_of(phases)
|
||||
def _maybe_got_confirm(phase_and_body):
|
||||
(got_phase, body) = phase_and_body
|
||||
if got_phase == u"_confirm":
|
||||
confkey = self.derive_key(u"wormhole:confirmation")
|
||||
nonce = body[:CONFMSG_NONCE_LENGTH]
|
||||
if body != make_confmsg(confkey, nonce):
|
||||
raise WrongPasswordError
|
||||
self._got_confirmation = True
|
||||
return self._channel.get_first_of([phase])
|
||||
return phase_and_body
|
||||
d1.addCallback(_maybe_got_confirm)
|
||||
def _got(phase_and_body):
|
||||
(got_phase, body) = phase_and_body
|
||||
assert got_phase == phase
|
||||
try:
|
||||
data_key = self.derive_key(u"wormhole:phase:%s" % phase)
|
||||
inbound_data = self._decrypt_data(data_key, body)
|
||||
return inbound_data
|
||||
except CryptoError:
|
||||
raise WrongPasswordError
|
||||
d1.addCallback(_got)
|
||||
return d1
|
||||
d.addCallback(_get)
|
||||
return d
|
||||
yield self._get_key()
|
||||
phases = []
|
||||
if not self._got_confirmation:
|
||||
phases.append(u"_confirm")
|
||||
phases.append(phase)
|
||||
phase_and_body = yield self._channel.get_first_of(phases)
|
||||
(got_phase, body) = phase_and_body
|
||||
if got_phase == u"_confirm":
|
||||
confkey = self.derive_key(u"wormhole:confirmation")
|
||||
nonce = body[:CONFMSG_NONCE_LENGTH]
|
||||
if body != make_confmsg(confkey, nonce):
|
||||
raise WrongPasswordError
|
||||
self._got_confirmation = True
|
||||
phase_and_body = yield self._channel.get_first_of([phase])
|
||||
(got_phase, body) = phase_and_body
|
||||
assert got_phase == phase
|
||||
try:
|
||||
data_key = self.derive_key(u"wormhole:phase:%s" % phase)
|
||||
inbound_data = self._decrypt_data(data_key, body)
|
||||
returnValue(inbound_data)
|
||||
except CryptoError:
|
||||
raise WrongPasswordError
|
||||
|
||||
@inlineCallbacks
|
||||
def close(self, res=None, mood=u"happy"):
|
||||
if not isinstance(mood, (type(None), type(u""))):
|
||||
raise TypeError(type(mood))
|
||||
self._closed = True
|
||||
d = defer.succeed(None)
|
||||
if self._channel:
|
||||
c, self._channel = self._channel, None
|
||||
monitor.close(c)
|
||||
d.addCallback(lambda _: c.deallocate(mood))
|
||||
return d
|
||||
if not self._channel:
|
||||
returnValue(None)
|
||||
c, self._channel = self._channel, None
|
||||
monitor.close(c)
|
||||
yield c.deallocate(mood)
|
||||
|
||||
|
|
|
@ -5,6 +5,7 @@ from zope.interface import implementer
|
|||
from twisted.python.runtime import platformType
|
||||
from twisted.internet import (reactor, interfaces, defer, protocol,
|
||||
endpoints, task, address, error)
|
||||
from twisted.internet.defer import inlineCallbacks, returnValue
|
||||
from twisted.protocols import policies
|
||||
from nacl.secret import SecretBox
|
||||
from ..util import ipaddrs
|
||||
|
@ -536,15 +537,16 @@ class Common:
|
|||
self._waiting_for_transit_key.append(d)
|
||||
return d
|
||||
|
||||
@inlineCallbacks
|
||||
def connect(self):
|
||||
d = self._get_transit_key()
|
||||
d.addCallback(self._connect)
|
||||
yield self._get_transit_key()
|
||||
# we want to have the transit key before starting any outbound
|
||||
# connections, so those connections will know what to say when they
|
||||
# connect
|
||||
return d
|
||||
winner = yield self._connect()
|
||||
returnValue(winner)
|
||||
|
||||
def _connect(self, _):
|
||||
def _connect(self):
|
||||
# It might be nice to wire this so that a failure in the direct hints
|
||||
# causes the relay hints to be used right away (fast failover). But
|
||||
# none of our current use cases would take advantage of that: if we
|
||||
|
|
Loading…
Reference in New Issue
Block a user