fix close behavior: Deferreds should errback once closed

This commit is contained in:
Brian Warner 2017-03-07 12:09:06 +01:00
parent 9571fcd388
commit aebee61816
2 changed files with 26 additions and 12 deletions

View File

@ -53,3 +53,9 @@ class NoTorError(WormholeError):
class NoKeyError(WormholeError): class NoKeyError(WormholeError):
"""w.derive_key() was called before got_verifier() fired""" """w.derive_key() was called before got_verifier() fired"""
class WormholeClosed(Exception):
"""Deferred-returning API calls errback with WormholeClosed if the
wormhole was already closed, or if it closes before a real result can be
obtained."""

View File

@ -10,7 +10,7 @@ from .timing import DebugTiming
from .journal import ImmediateJournal from .journal import ImmediateJournal
from ._boss import Boss from ._boss import Boss
from ._key import derive_key from ._key import derive_key
from .errors import WelcomeError, NoKeyError from .errors import WelcomeError, NoKeyError, WormholeClosed
from .util import to_bytes from .util import to_bytes
# We can provide different APIs to different apps: # We can provide different APIs to different apps:
@ -131,9 +131,6 @@ class _DelegatedWormhole(object):
def closed(self, result): def closed(self, result):
self._delegate.wormhole_closed(result) self._delegate.wormhole_closed(result)
class WormholeClosed(Exception):
pass
@implementer(IWormhole) @implementer(IWormhole)
class _DeferredWormhole(object): class _DeferredWormhole(object):
def __init__(self): def __init__(self):
@ -146,6 +143,7 @@ class _DeferredWormhole(object):
self._version_observers = [] self._version_observers = []
self._received_data = [] self._received_data = []
self._received_observers = [] self._received_observers = []
self._observer_result = None
self._closed_result = None self._closed_result = None
self._closed_observers = [] self._closed_observers = []
@ -154,20 +152,28 @@ class _DeferredWormhole(object):
# from above # from above
def when_code(self): def when_code(self):
if self._code: # TODO: consider throwing error unless one of allocate/set/input_code
# was called first
if self._observer_result is not None:
return defer.fail(self._observer_result)
if self._code is not None:
return defer.succeed(self._code) return defer.succeed(self._code)
d = defer.Deferred() d = defer.Deferred()
self._code_observers.append(d) self._code_observers.append(d)
return d return d
def when_verifier(self): def when_verifier(self):
if self._verifier: if self._observer_result is not None:
return defer.fail(self._observer_result)
if self._verifier is not None:
return defer.succeed(self._verifier) return defer.succeed(self._verifier)
d = defer.Deferred() d = defer.Deferred()
self._verifier_observers.append(d) self._verifier_observers.append(d)
return d return d
def when_version(self): def when_version(self):
if self._observer_result is not None:
return defer.fail(self._observer_result)
if self._version is not None: if self._version is not None:
return defer.succeed(self._version) return defer.succeed(self._version)
d = defer.Deferred() d = defer.Deferred()
@ -175,6 +181,8 @@ class _DeferredWormhole(object):
return d return d
def when_received(self): def when_received(self):
if self._observer_result is not None:
return defer.fail(self._observer_result)
if self._received_data: if self._received_data:
return defer.succeed(self._received_data.pop(0)) return defer.succeed(self._received_data.pop(0))
d = defer.Deferred() d = defer.Deferred()
@ -209,9 +217,9 @@ class _DeferredWormhole(object):
# WormholeError) if state=="scary". # WormholeError) if state=="scary".
if self._closed_result: if self._closed_result:
return defer.succeed(self._closed_result) # maybe Failure return defer.succeed(self._closed_result) # maybe Failure
self._boss.close() # only need to close if it wasn't already
d = defer.Deferred() d = defer.Deferred()
self._closed_observers.append(d) self._closed_observers.append(d)
self._boss.close() # only need to close if it wasn't already
return d return d
def debug_set_trace(self, client_name, which="B N M S O K R RC L C T", def debug_set_trace(self, client_name, which="B N M S O K R RC L C T",
@ -248,18 +256,18 @@ class _DeferredWormhole(object):
def closed(self, result): def closed(self, result):
#print("closed", result, type(result)) #print("closed", result, type(result))
if isinstance(result, Exception): if isinstance(result, Exception):
observer_result = self._closed_result = failure.Failure(result) self._observer_result = self._closed_result = failure.Failure(result)
else: else:
# pending w.verify()/w.version()/w.read() get an error # pending w.verify()/w.version()/w.read() get an error
observer_result = WormholeClosed(result) self._observer_result = WormholeClosed(result)
# but w.close() only gets error if we're unhappy # but w.close() only gets error if we're unhappy
self._closed_result = result self._closed_result = result
for d in self._verifier_observers: for d in self._verifier_observers:
d.errback(observer_result) d.errback(self._observer_result)
for d in self._version_observers: for d in self._version_observers:
d.errback(observer_result) d.errback(self._observer_result)
for d in self._received_observers: for d in self._received_observers:
d.errback(observer_result) d.errback(self._observer_result)
for d in self._closed_observers: for d in self._closed_observers:
d.callback(self._closed_result) d.callback(self._closed_result)