From aebee618167bbd5ce926fa6997197063bc6e22c7 Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Tue, 7 Mar 2017 12:09:06 +0100 Subject: [PATCH] fix close behavior: Deferreds should errback once closed --- src/wormhole/errors.py | 6 ++++++ src/wormhole/wormhole.py | 32 ++++++++++++++++++++------------ 2 files changed, 26 insertions(+), 12 deletions(-) diff --git a/src/wormhole/errors.py b/src/wormhole/errors.py index 865b5e2..b6ba419 100644 --- a/src/wormhole/errors.py +++ b/src/wormhole/errors.py @@ -53,3 +53,9 @@ class NoTorError(WormholeError): class NoKeyError(WormholeError): """w.derive_key() was called before got_verifier() fired""" + +class WormholeClosed(Exception): + """Deferred-returning API calls errback with WormholeClosed if the + wormhole was already closed, or if it closes before a real result can be + obtained.""" + diff --git a/src/wormhole/wormhole.py b/src/wormhole/wormhole.py index d316daf..e73eb1b 100644 --- a/src/wormhole/wormhole.py +++ b/src/wormhole/wormhole.py @@ -10,7 +10,7 @@ from .timing import DebugTiming from .journal import ImmediateJournal from ._boss import Boss from ._key import derive_key -from .errors import WelcomeError, NoKeyError +from .errors import WelcomeError, NoKeyError, WormholeClosed from .util import to_bytes # We can provide different APIs to different apps: @@ -131,9 +131,6 @@ class _DelegatedWormhole(object): def closed(self, result): self._delegate.wormhole_closed(result) -class WormholeClosed(Exception): - pass - @implementer(IWormhole) class _DeferredWormhole(object): def __init__(self): @@ -146,6 +143,7 @@ class _DeferredWormhole(object): self._version_observers = [] self._received_data = [] self._received_observers = [] + self._observer_result = None self._closed_result = None self._closed_observers = [] @@ -154,20 +152,28 @@ class _DeferredWormhole(object): # from above def when_code(self): - if self._code: + # TODO: consider throwing error unless one of allocate/set/input_code + # was called first + if self._observer_result is not None: + return defer.fail(self._observer_result) + if self._code is not None: return defer.succeed(self._code) d = defer.Deferred() self._code_observers.append(d) return d def when_verifier(self): - if self._verifier: + if self._observer_result is not None: + return defer.fail(self._observer_result) + if self._verifier is not None: return defer.succeed(self._verifier) d = defer.Deferred() self._verifier_observers.append(d) return d def when_version(self): + if self._observer_result is not None: + return defer.fail(self._observer_result) if self._version is not None: return defer.succeed(self._version) d = defer.Deferred() @@ -175,6 +181,8 @@ class _DeferredWormhole(object): return d def when_received(self): + if self._observer_result is not None: + return defer.fail(self._observer_result) if self._received_data: return defer.succeed(self._received_data.pop(0)) d = defer.Deferred() @@ -209,9 +217,9 @@ class _DeferredWormhole(object): # WormholeError) if state=="scary". if self._closed_result: return defer.succeed(self._closed_result) # maybe Failure - self._boss.close() # only need to close if it wasn't already d = defer.Deferred() self._closed_observers.append(d) + self._boss.close() # only need to close if it wasn't already return d def debug_set_trace(self, client_name, which="B N M S O K R RC L C T", @@ -248,18 +256,18 @@ class _DeferredWormhole(object): def closed(self, result): #print("closed", result, type(result)) if isinstance(result, Exception): - observer_result = self._closed_result = failure.Failure(result) + self._observer_result = self._closed_result = failure.Failure(result) else: # pending w.verify()/w.version()/w.read() get an error - observer_result = WormholeClosed(result) + self._observer_result = WormholeClosed(result) # but w.close() only gets error if we're unhappy self._closed_result = result for d in self._verifier_observers: - d.errback(observer_result) + d.errback(self._observer_result) for d in self._version_observers: - d.errback(observer_result) + d.errback(self._observer_result) for d in self._received_observers: - d.errback(observer_result) + d.errback(self._observer_result) for d in self._closed_observers: d.callback(self._closed_result)