Merge branch 'observers-4'

This factors out the various "get me a Deferred which fires when/if we
compute a value" code from the _DeferredWormhole API calls: get_code,
get_unverified_key, get_versions, get_message, etc. It uses an eventual-send
for each one, which will protect against surprises when an application
invokes an wormhole API from within a previous API's callback: without this,
the internal wormhole state isn't guaranteed to be coherent, and crashes
could result.
This commit is contained in:
Brian Warner 2018-02-26 18:32:49 -08:00
commit c5ae678417
7 changed files with 408 additions and 169 deletions

50
src/wormhole/eventual.py Normal file
View File

@ -0,0 +1,50 @@
# inspired-by/adapted-from Foolscap's eventual.py, which Glyph wrote for me
# years ago.
from twisted.internet.defer import Deferred
from twisted.internet.interfaces import IReactorTime
from twisted.python import log
class EventualQueue(object):
def __init__(self, clock):
# pass clock=reactor unless you're testing
self._clock = IReactorTime(clock)
self._calls = []
self._flush_d = None
self._timer = None
def eventually(self, f, *args, **kwargs):
self._calls.append( (f, args, kwargs) )
if not self._timer:
self._timer = self._clock.callLater(0, self._turn)
def fire_eventually(self, value=None):
d = Deferred()
self.eventually(d.callback, value)
return d
def _turn(self):
while self._calls:
(f, args, kwargs) = self._calls.pop(0)
try:
f(*args, **kwargs)
except:
log.err()
self._timer = None
d, self._flush_d = self._flush_d, None
if d:
d.callback(None)
def flush_sync(self):
# if you have control over the Clock, this will synchronously flush the
# queue
assert self._clock.advance, "needs clock=twisted.internet.task.Clock()"
while self._calls:
self._clock.advance(0)
def flush(self):
# this is for unit tests, not application code
assert not self._flush_d, "only one flush at a time"
self._flush_d = Deferred()
self.eventually(lambda: None)
return self._flush_d

69
src/wormhole/observer.py Normal file
View File

@ -0,0 +1,69 @@
from __future__ import unicode_literals, print_function
from twisted.internet.defer import Deferred
from twisted.python.failure import Failure
NoResult = object()
class OneShotObserver(object):
def __init__(self, eventual_queue):
self._eq = eventual_queue
self._result = NoResult
self._observers = [] # list of Deferreds
def when_fired(self):
d = Deferred()
self._observers.append(d)
self._maybe_call_observers()
return d
def fire(self, result):
assert self._result is NoResult
self._result = result
self._maybe_call_observers()
def _maybe_call_observers(self):
if self._result is NoResult:
return
observers, self._observers = self._observers, []
for d in observers:
self._eq.eventually(d.callback, self._result)
def error(self, f):
# errors will override an existing result
assert isinstance(f, Failure)
self._result = f
self._maybe_call_observers()
def fire_if_not_fired(self, result):
if self._result is NoResult:
self.fire(result)
class SequenceObserver(object):
def __init__(self, eventual_queue):
self._eq = eventual_queue
self._error = None
self._results = []
self._observers = []
def when_next_event(self):
d = Deferred()
if self._error:
self._eq.eventually(d.errback, self._error)
elif self._results:
result = self._results.pop(0)
self._eq.eventually(d.callback, result)
else:
self._observers.append(d)
return d
def fire(self, result):
if isinstance(result, Failure):
self._error = result
for d in self._observers:
self._eq.eventually(d.errback, self._error)
self._observers = []
else:
self._results.append(result)
if self._observers:
d = self._observers.pop(0)
self._eq.eventually(d.callback, self._results.pop(0))

View File

@ -128,10 +128,3 @@ def poll_until(predicate):
d = defer.Deferred() d = defer.Deferred()
reactor.callLater(0.001, d.callback, None) reactor.callLater(0.001, d.callback, None)
yield d yield d
@defer.inlineCallbacks
def pause_one_tick():
# return a Deferred that won't fire until at least the next reactor tick
d = defer.Deferred()
reactor.callLater(0.001, d.callback, None)
yield d

View File

@ -0,0 +1,57 @@
from __future__ import print_function, unicode_literals
import mock
from twisted.trial import unittest
from twisted.internet import reactor
from twisted.internet.task import Clock
from twisted.internet.defer import Deferred, inlineCallbacks
from ..eventual import EventualQueue
class IntentionalError(Exception):
pass
class Eventual(unittest.TestCase, object):
def test_eventually(self):
c = Clock()
eq = EventualQueue(c)
c1 = mock.Mock()
eq.eventually(c1, "arg1", "arg2", kwarg1="kw1")
eq.eventually(c1, "arg3", "arg4", kwarg5="kw5")
d2 = eq.fire_eventually()
d3 = eq.fire_eventually("value")
self.assertEqual(c1.mock_calls, [])
self.assertNoResult(d2)
self.assertNoResult(d3)
eq.flush_sync()
self.assertEqual(c1.mock_calls,
[mock.call("arg1", "arg2", kwarg1="kw1"),
mock.call("arg3", "arg4", kwarg5="kw5")])
self.assertEqual(self.successResultOf(d2), None)
self.assertEqual(self.successResultOf(d3), "value")
def test_error(self):
c = Clock()
eq = EventualQueue(c)
c1 = mock.Mock(side_effect=IntentionalError)
eq.eventually(c1, "arg1", "arg2", kwarg1="kw1")
self.assertEqual(c1.mock_calls, [])
eq.flush_sync()
self.assertEqual(c1.mock_calls,
[mock.call("arg1", "arg2", kwarg1="kw1")])
self.flushLoggedErrors(IntentionalError)
@inlineCallbacks
def test_flush(self):
eq = EventualQueue(reactor)
d1 = eq.fire_eventually()
d2 = Deferred()
def _more(res):
eq.eventually(d2.callback, None)
d1.addCallback(_more)
yield eq.flush()
# d1 will fire, which will queue d2 to fire, and the flush() ought to
# wait for d2 too
self.successResultOf(d2)

View File

@ -0,0 +1,122 @@
from twisted.trial import unittest
from twisted.internet.task import Clock
from twisted.python.failure import Failure
from ..eventual import EventualQueue
from ..observer import OneShotObserver, SequenceObserver
class OneShot(unittest.TestCase):
def test_fire(self):
c = Clock()
eq = EventualQueue(c)
o = OneShotObserver(eq)
res = object()
d1 = o.when_fired()
eq.flush_sync()
self.assertNoResult(d1)
o.fire(res)
eq.flush_sync()
self.assertIdentical(self.successResultOf(d1), res)
d2 = o.when_fired()
eq.flush_sync()
self.assertIdentical(self.successResultOf(d2), res)
o.fire_if_not_fired(object())
eq.flush_sync()
def test_fire_if_not_fired(self):
c = Clock()
eq = EventualQueue(c)
o = OneShotObserver(eq)
res1 = object()
res2 = object()
d1 = o.when_fired()
eq.flush_sync()
self.assertNoResult(d1)
o.fire_if_not_fired(res1)
o.fire_if_not_fired(res2)
eq.flush_sync()
self.assertIdentical(self.successResultOf(d1), res1)
def test_error_before_firing(self):
c = Clock()
eq = EventualQueue(c)
o = OneShotObserver(eq)
f = Failure(ValueError("oops"))
d1 = o.when_fired()
eq.flush_sync()
self.assertNoResult(d1)
o.error(f)
eq.flush_sync()
self.assertIdentical(self.failureResultOf(d1), f)
d2 = o.when_fired()
eq.flush_sync()
self.assertIdentical(self.failureResultOf(d2), f)
def test_error_after_firing(self):
c = Clock()
eq = EventualQueue(c)
o = OneShotObserver(eq)
res = object()
f = Failure(ValueError("oops"))
o.fire(res)
eq.flush_sync()
d1 = o.when_fired()
o.error(f)
d2 = o.when_fired()
eq.flush_sync()
self.assertIdentical(self.successResultOf(d1), res)
self.assertIdentical(self.failureResultOf(d2), f)
class Sequence(unittest.TestCase):
def test_fire(self):
c = Clock()
eq = EventualQueue(c)
o = SequenceObserver(eq)
d1 = o.when_next_event()
eq.flush_sync()
self.assertNoResult(d1)
d2 = o.when_next_event()
eq.flush_sync()
self.assertNoResult(d1)
self.assertNoResult(d2)
ev1 = object()
ev2 = object()
o.fire(ev1)
eq.flush_sync()
self.assertIdentical(self.successResultOf(d1), ev1)
self.assertNoResult(d2)
o.fire(ev2)
eq.flush_sync()
self.assertIdentical(self.successResultOf(d2), ev2)
ev3 = object()
ev4 = object()
o.fire(ev3)
o.fire(ev4)
d3 = o.when_next_event()
eq.flush_sync()
self.assertIdentical(self.successResultOf(d3), ev3)
d4 = o.when_next_event()
eq.flush_sync()
self.assertIdentical(self.successResultOf(d4), ev4)
def test_error(self):
c = Clock()
eq = EventualQueue(c)
o = SequenceObserver(eq)
d1 = o.when_next_event()
eq.flush_sync()
self.assertNoResult(d1)
f = Failure(ValueError("oops"))
o.fire(f)
eq.flush_sync()
self.assertIdentical(self.failureResultOf(d1), f)
d2 = o.when_next_event()
eq.flush_sync()
self.assertIdentical(self.failureResultOf(d2), f)

View File

@ -5,12 +5,13 @@ from twisted.trial import unittest
from twisted.internet import reactor from twisted.internet import reactor
from twisted.internet.defer import gatherResults, inlineCallbacks, returnValue from twisted.internet.defer import gatherResults, inlineCallbacks, returnValue
from twisted.internet.error import ConnectionRefusedError from twisted.internet.error import ConnectionRefusedError
from .common import ServerBase, poll_until, pause_one_tick from .common import ServerBase, poll_until
from .. import wormhole, _rendezvous from .. import wormhole, _rendezvous
from ..errors import (WrongPasswordError, ServerConnectionError, from ..errors import (WrongPasswordError, ServerConnectionError,
KeyFormatError, WormholeClosed, LonelyError, KeyFormatError, WormholeClosed, LonelyError,
NoKeyError, OnlyOneCodeError) NoKeyError, OnlyOneCodeError)
from ..transit import allocate_tcp_port from ..transit import allocate_tcp_port
from ..eventual import EventualQueue
APPID = "appid" APPID = "appid"
@ -159,9 +160,6 @@ class Wormholes(ServerBase, unittest.TestCase):
verifier2 = yield w2.get_verifier() verifier2 = yield w2.get_verifier()
self.assertEqual(verifier1, verifier2) self.assertEqual(verifier1, verifier2)
self.successResultOf(w1.get_unverified_key())
self.successResultOf(w2.get_unverified_key())
versions1 = yield w1.get_versions() versions1 = yield w1.get_versions()
versions2 = yield w2.get_versions() versions2 = yield w2.get_versions()
# app-versions are exercised properly in test_versions, this just # app-versions are exercised properly in test_versions, this just
@ -186,18 +184,22 @@ class Wormholes(ServerBase, unittest.TestCase):
@inlineCallbacks @inlineCallbacks
def test_get_code_early(self): def test_get_code_early(self):
w1 = wormhole.create(APPID, self.relayurl, reactor) eq = EventualQueue(reactor)
w1 = wormhole.create(APPID, self.relayurl, reactor, _eventual_queue=eq)
d = w1.get_code() d = w1.get_code()
w1.set_code("1-abc") w1.set_code("1-abc")
yield eq.flush()
code = self.successResultOf(d) code = self.successResultOf(d)
self.assertEqual(code, "1-abc") self.assertEqual(code, "1-abc")
yield self.assertFailure(w1.close(), LonelyError) yield self.assertFailure(w1.close(), LonelyError)
@inlineCallbacks @inlineCallbacks
def test_get_code_late(self): def test_get_code_late(self):
w1 = wormhole.create(APPID, self.relayurl, reactor) eq = EventualQueue(reactor)
w1 = wormhole.create(APPID, self.relayurl, reactor, _eventual_queue=eq)
w1.set_code("1-abc") w1.set_code("1-abc")
d = w1.get_code() d = w1.get_code()
yield eq.flush()
code = self.successResultOf(d) code = self.successResultOf(d)
self.assertEqual(code, "1-abc") self.assertEqual(code, "1-abc")
yield self.assertFailure(w1.close(), LonelyError) yield self.assertFailure(w1.close(), LonelyError)
@ -323,8 +325,9 @@ class Wormholes(ServerBase, unittest.TestCase):
@inlineCallbacks @inlineCallbacks
def test_closed(self): def test_closed(self):
w1 = wormhole.create(APPID, self.relayurl, reactor) eq = EventualQueue(reactor)
w2 = wormhole.create(APPID, self.relayurl, reactor) w1 = wormhole.create(APPID, self.relayurl, reactor, _eventual_queue=eq)
w2 = wormhole.create(APPID, self.relayurl, reactor, _eventual_queue=eq)
w1.set_code("123-foo") w1.set_code("123-foo")
w2.set_code("123-foo") w2.set_code("123-foo")
@ -335,14 +338,14 @@ class Wormholes(ServerBase, unittest.TestCase):
yield w1.close() yield w1.close()
yield w2.close() yield w2.close()
# once closed, all Deferred-yielding API calls get an immediate error # once closed, all Deferred-yielding API calls get an prompt error
self.failureResultOf(w1.get_welcome(), WormholeClosed) yield self.assertFailure(w1.get_welcome(), WormholeClosed)
f = self.failureResultOf(w1.get_code(), WormholeClosed) e = yield self.assertFailure(w1.get_code(), WormholeClosed)
self.assertEqual(f.value.args[0], "happy") self.assertEqual(e.args[0], "happy")
self.failureResultOf(w1.get_unverified_key(), WormholeClosed) yield self.assertFailure(w1.get_unverified_key(), WormholeClosed)
self.failureResultOf(w1.get_verifier(), WormholeClosed) yield self.assertFailure(w1.get_verifier(), WormholeClosed)
self.failureResultOf(w1.get_versions(), WormholeClosed) yield self.assertFailure(w1.get_versions(), WormholeClosed)
self.failureResultOf(w1.get_message(), WormholeClosed) yield self.assertFailure(w1.get_message(), WormholeClosed)
@inlineCallbacks @inlineCallbacks
def test_closed_idle(self): def test_closed_idle(self):
@ -360,17 +363,18 @@ class Wormholes(ServerBase, unittest.TestCase):
yield self.assertFailure(w1.close(), LonelyError) yield self.assertFailure(w1.close(), LonelyError)
self.failureResultOf(d_welcome, LonelyError) yield self.assertFailure(d_welcome, LonelyError)
self.failureResultOf(d_code, LonelyError) yield self.assertFailure(d_code, LonelyError)
self.failureResultOf(d_key, LonelyError) yield self.assertFailure(d_key, LonelyError)
self.failureResultOf(d_verifier, LonelyError) yield self.assertFailure(d_verifier, LonelyError)
self.failureResultOf(d_versions, LonelyError) yield self.assertFailure(d_versions, LonelyError)
self.failureResultOf(d_message, LonelyError) yield self.assertFailure(d_message, LonelyError)
@inlineCallbacks @inlineCallbacks
def test_wrong_password(self): def test_wrong_password(self):
w1 = wormhole.create(APPID, self.relayurl, reactor) eq = EventualQueue(reactor)
w2 = wormhole.create(APPID, self.relayurl, reactor) w1 = wormhole.create(APPID, self.relayurl, reactor, _eventual_queue=eq)
w2 = wormhole.create(APPID, self.relayurl, reactor, _eventual_queue=eq)
w1.allocate_code() w1.allocate_code()
code = yield w1.get_code() code = yield w1.get_code()
w2.set_code(code+"not") w2.set_code(code+"not")
@ -403,9 +407,8 @@ class Wormholes(ServerBase, unittest.TestCase):
# wait for each side to notice the failure # wait for each side to notice the failure
yield self.assertFailure(w1.get_verifier(), WrongPasswordError) yield self.assertFailure(w1.get_verifier(), WrongPasswordError)
yield self.assertFailure(w2.get_verifier(), WrongPasswordError) yield self.assertFailure(w2.get_verifier(), WrongPasswordError)
# and then wait for the rest of the loops to fire. if we had+used # the rest of the loops should fire within the next tick
# eventual-send, this wouldn't be a problem yield eq.flush()
yield pause_one_tick()
# now all the rest should have fired already # now all the rest should have fired already
self.failureResultOf(d1_verified, WrongPasswordError) self.failureResultOf(d1_verified, WrongPasswordError)
@ -420,27 +423,27 @@ class Wormholes(ServerBase, unittest.TestCase):
# before we close # before we close
# any new calls in the error state should immediately fail # any new calls in the error state should immediately fail
self.failureResultOf(w1.get_unverified_key(), WrongPasswordError) yield self.assertFailure(w1.get_unverified_key(), WrongPasswordError)
self.failureResultOf(w1.get_verifier(), WrongPasswordError) yield self.assertFailure(w1.get_verifier(), WrongPasswordError)
self.failureResultOf(w1.get_versions(), WrongPasswordError) yield self.assertFailure(w1.get_versions(), WrongPasswordError)
self.failureResultOf(w1.get_message(), WrongPasswordError) yield self.assertFailure(w1.get_message(), WrongPasswordError)
self.failureResultOf(w2.get_unverified_key(), WrongPasswordError) yield self.assertFailure(w2.get_unverified_key(), WrongPasswordError)
self.failureResultOf(w2.get_verifier(), WrongPasswordError) yield self.assertFailure(w2.get_verifier(), WrongPasswordError)
self.failureResultOf(w2.get_versions(), WrongPasswordError) yield self.assertFailure(w2.get_versions(), WrongPasswordError)
self.failureResultOf(w2.get_message(), WrongPasswordError) yield self.assertFailure(w2.get_message(), WrongPasswordError)
yield self.assertFailure(w1.close(), WrongPasswordError) yield self.assertFailure(w1.close(), WrongPasswordError)
yield self.assertFailure(w2.close(), WrongPasswordError) yield self.assertFailure(w2.close(), WrongPasswordError)
# API calls should still get the error, not WormholeClosed # API calls should still get the error, not WormholeClosed
self.failureResultOf(w1.get_unverified_key(), WrongPasswordError) yield self.assertFailure(w1.get_unverified_key(), WrongPasswordError)
self.failureResultOf(w1.get_verifier(), WrongPasswordError) yield self.assertFailure(w1.get_verifier(), WrongPasswordError)
self.failureResultOf(w1.get_versions(), WrongPasswordError) yield self.assertFailure(w1.get_versions(), WrongPasswordError)
self.failureResultOf(w1.get_message(), WrongPasswordError) yield self.assertFailure(w1.get_message(), WrongPasswordError)
self.failureResultOf(w2.get_unverified_key(), WrongPasswordError) yield self.assertFailure(w2.get_unverified_key(), WrongPasswordError)
self.failureResultOf(w2.get_verifier(), WrongPasswordError) yield self.assertFailure(w2.get_verifier(), WrongPasswordError)
self.failureResultOf(w2.get_versions(), WrongPasswordError) yield self.assertFailure(w2.get_versions(), WrongPasswordError)
self.failureResultOf(w2.get_message(), WrongPasswordError) yield self.assertFailure(w2.get_message(), WrongPasswordError)
@inlineCallbacks @inlineCallbacks
def test_wrong_password_with_spaces(self): def test_wrong_password_with_spaces(self):
@ -493,8 +496,9 @@ class Wormholes(ServerBase, unittest.TestCase):
@inlineCallbacks @inlineCallbacks
def test_verifier(self): def test_verifier(self):
w1 = wormhole.create(APPID, self.relayurl, reactor) eq = EventualQueue(reactor)
w2 = wormhole.create(APPID, self.relayurl, reactor) w1 = wormhole.create(APPID, self.relayurl, reactor, _eventual_queue=eq)
w2 = wormhole.create(APPID, self.relayurl, reactor, _eventual_queue=eq)
w1.allocate_code() w1.allocate_code()
code = yield w1.get_code() code = yield w1.get_code()
w2.set_code(code) w2.set_code(code)
@ -510,7 +514,9 @@ class Wormholes(ServerBase, unittest.TestCase):
self.assertEqual(dataY, b"data1") self.assertEqual(dataY, b"data1")
# calling get_verifier() this late should fire right away # calling get_verifier() this late should fire right away
v1_late = self.successResultOf(w2.get_verifier()) d = w2.get_verifier()
yield eq.flush()
v1_late = self.successResultOf(d)
self.assertEqual(v1_late, v1) self.assertEqual(v1_late, v1)
yield w1.close() yield w1.close()
@ -644,26 +650,30 @@ class Reconnection(ServerBase, unittest.TestCase):
self.assertEqual(c2, "happy") self.assertEqual(c2, "happy")
class InitialFailure(unittest.TestCase): class InitialFailure(unittest.TestCase):
def assertSCEResultOf(self, d, innerType): @inlineCallbacks
def assertSCEFailure(self, eq, d, innerType):
yield eq.flush()
f = self.failureResultOf(d, ServerConnectionError) f = self.failureResultOf(d, ServerConnectionError)
inner = f.value.reason inner = f.value.reason
self.assertIsInstance(inner, innerType) self.assertIsInstance(inner, innerType)
return inner returnValue(inner)
@inlineCallbacks @inlineCallbacks
def test_bad_dns(self): def test_bad_dns(self):
eq = EventualQueue(reactor)
# point at a URL that will never connect # point at a URL that will never connect
w = wormhole.create(APPID, "ws://%%%.example.org:4000/v1", reactor) w = wormhole.create(APPID, "ws://%%%.example.org:4000/v1",
reactor, _eventual_queue=eq)
# that should have already received an error, when it tried to # that should have already received an error, when it tried to
# resolve the bogus DNS name. All API calls will return an error. # resolve the bogus DNS name. All API calls will return an error.
e = yield self.assertFailure(w.get_unverified_key(),
ServerConnectionError) e = yield self.assertSCEFailure(eq, w.get_unverified_key(), ValueError)
self.assertIsInstance(e.reason, ValueError) self.assertIsInstance(e, ValueError)
self.assertEqual(str(e), "invalid hostname: %%%.example.org") self.assertEqual(str(e), "invalid hostname: %%%.example.org")
self.assertSCEResultOf(w.get_code(), ValueError) yield self.assertSCEFailure(eq, w.get_code(), ValueError)
self.assertSCEResultOf(w.get_verifier(), ValueError) yield self.assertSCEFailure(eq, w.get_verifier(), ValueError)
self.assertSCEResultOf(w.get_versions(), ValueError) yield self.assertSCEFailure(eq, w.get_versions(), ValueError)
self.assertSCEResultOf(w.get_message(), ValueError) yield self.assertSCEFailure(eq, w.get_message(), ValueError)
@inlineCallbacks @inlineCallbacks
def assertSCE(self, d, innerType): def assertSCE(self, d, innerType):

View File

@ -3,9 +3,10 @@ import os, sys
from attr import attrs, attrib from attr import attrs, attrib
from zope.interface import implementer from zope.interface import implementer
from twisted.python import failure from twisted.python import failure
from twisted.internet import defer
from ._interfaces import IWormhole, IDeferredWormhole from ._interfaces import IWormhole, IDeferredWormhole
from .util import bytes_to_hexstr from .util import bytes_to_hexstr
from .eventual import EventualQueue
from .observer import OneShotObserver, SequenceObserver
from .timing import DebugTiming from .timing import DebugTiming
from .journal import ImmediateJournal from .journal import ImmediateJournal
from ._boss import Boss from ._boss import Boss
@ -100,22 +101,16 @@ class _DelegatedWormhole(object):
@implementer(IWormhole, IDeferredWormhole) @implementer(IWormhole, IDeferredWormhole)
class _DeferredWormhole(object): class _DeferredWormhole(object):
def __init__(self): def __init__(self, eq):
self._welcome = None self._welcome_observer = OneShotObserver(eq)
self._welcome_observers = [] self._code_observer = OneShotObserver(eq)
self._code = None
self._code_observers = []
self._key = None self._key = None
self._key_observers = [] self._key_observer = OneShotObserver(eq)
self._verifier = None self._verifier_observer = OneShotObserver(eq)
self._verifier_observers = [] self._version_observer = OneShotObserver(eq)
self._versions = None self._received_observer = SequenceObserver(eq)
self._version_observers = [] self._closed = False
self._received_data = [] self._closed_observer = OneShotObserver(eq)
self._received_observers = []
self._observer_result = None
self._closed_result = None
self._closed_observers = []
def _set_boss(self, boss): def _set_boss(self, boss):
self._boss = boss self._boss = boss
@ -127,58 +122,22 @@ class _DeferredWormhole(object):
# the process that will cause it to fire, but forbidding that # the process that will cause it to fire, but forbidding that
# ordering would make it easier to cause programming errors that # ordering would make it easier to cause programming errors that
# forget to trigger it entirely. # forget to trigger it entirely.
if self._observer_result is not None: return self._code_observer.when_fired()
return defer.fail(self._observer_result)
if self._code is not None:
return defer.succeed(self._code)
d = defer.Deferred()
self._code_observers.append(d)
return d
def get_welcome(self): def get_welcome(self):
if self._observer_result is not None: return self._welcome_observer.when_fired()
return defer.fail(self._observer_result)
if self._welcome is not None:
return defer.succeed(self._welcome)
d = defer.Deferred()
self._welcome_observers.append(d)
return d
def get_unverified_key(self): def get_unverified_key(self):
if self._observer_result is not None: return self._key_observer.when_fired()
return defer.fail(self._observer_result)
if self._key is not None:
return defer.succeed(self._key)
d = defer.Deferred()
self._key_observers.append(d)
return d
def get_verifier(self): def get_verifier(self):
if self._observer_result is not None: return self._verifier_observer.when_fired()
return defer.fail(self._observer_result)
if self._verifier is not None:
return defer.succeed(self._verifier)
d = defer.Deferred()
self._verifier_observers.append(d)
return d
def get_versions(self): def get_versions(self):
if self._observer_result is not None: return self._version_observer.when_fired()
return defer.fail(self._observer_result)
if self._versions is not None:
return defer.succeed(self._versions)
d = defer.Deferred()
self._version_observers.append(d)
return d
def get_message(self): def get_message(self):
if self._observer_result is not None: return self._received_observer.when_next_event()
return defer.fail(self._observer_result)
if self._received_data:
return defer.succeed(self._received_data.pop(0))
d = defer.Deferred()
self._received_observers.append(d)
return d
def allocate_code(self, code_length=2): def allocate_code(self, code_length=2):
self._boss.allocate_code(code_length) self._boss.allocate_code(code_length)
@ -207,10 +166,8 @@ class _DeferredWormhole(object):
# fails with WormholeError unless we established a connection # fails with WormholeError unless we established a connection
# (state=="happy"). Fails with WrongPasswordError (a subclass of # (state=="happy"). Fails with WrongPasswordError (a subclass of
# WormholeError) if state=="scary". # WormholeError) if state=="scary".
if self._closed_result: d = self._closed_observer.when_fired() # maybe Failure
return defer.succeed(self._closed_result) # maybe Failure if not self._closed:
d = defer.Deferred()
self._closed_observers.append(d)
self._boss.close() # only need to close if it wasn't already self._boss.close() # only need to close if it wasn't already
return d return d
@ -221,75 +178,56 @@ class _DeferredWormhole(object):
# from below # from below
def got_welcome(self, welcome): def got_welcome(self, welcome):
self._welcome = welcome self._welcome_observer.fire_if_not_fired(welcome)
for d in self._welcome_observers:
d.callback(welcome)
self._welcome_observers[:] = []
def got_code(self, code): def got_code(self, code):
self._code = code self._code_observer.fire_if_not_fired(code)
for d in self._code_observers:
d.callback(code)
self._code_observers[:] = []
def got_key(self, key): def got_key(self, key):
self._key = key # for derive_key() self._key = key # for derive_key()
for d in self._key_observers: self._key_observer.fire_if_not_fired(key)
d.callback(key)
self._key_observers[:] = []
def got_verifier(self, verifier): def got_verifier(self, verifier):
self._verifier = verifier self._verifier_observer.fire_if_not_fired(verifier)
for d in self._verifier_observers:
d.callback(verifier)
self._verifier_observers[:] = []
def got_versions(self, versions): def got_versions(self, versions):
self._versions = versions self._version_observer.fire_if_not_fired(versions)
for d in self._version_observers:
d.callback(versions)
self._version_observers[:] = []
def received(self, plaintext): def received(self, plaintext):
if self._received_observers: self._received_observer.fire(plaintext)
self._received_observers.pop(0).callback(plaintext)
return
self._received_data.append(plaintext)
def closed(self, result): def closed(self, result):
self._closed = True
#print("closed", result, type(result), file=sys.stderr) #print("closed", result, type(result), file=sys.stderr)
if isinstance(result, Exception): if isinstance(result, Exception):
self._observer_result = self._closed_result = failure.Failure(result) # everything pending gets an error, including close()
f = failure.Failure(result)
self._closed_observer.error(f)
else: else:
# pending w.key()/w.verify()/w.version()/w.read() get an error # everything pending except close() gets an error:
self._observer_result = WormholeClosed(result) # w.get_code()/welcome/unverified_key/verifier/versions/message
f = failure.Failure(WormholeClosed(result))
# but w.close() only gets error if we're unhappy # but w.close() only gets error if we're unhappy
self._closed_result = result self._closed_observer.fire_if_not_fired(result)
for d in self._welcome_observers: self._welcome_observer.error(f)
d.errback(self._observer_result) self._code_observer.error(f)
for d in self._code_observers: self._key_observer.error(f)
d.errback(self._observer_result) self._verifier_observer.error(f)
for d in self._key_observers: self._version_observer.error(f)
d.errback(self._observer_result) self._received_observer.fire(f)
for d in self._verifier_observers:
d.errback(self._observer_result)
for d in self._version_observers:
d.errback(self._observer_result)
for d in self._received_observers:
d.errback(self._observer_result)
for d in self._closed_observers:
d.callback(self._closed_result)
def create(appid, relay_url, reactor, # use keyword args for everything else def create(appid, relay_url, reactor, # use keyword args for everything else
versions={}, versions={},
delegate=None, journal=None, tor=None, delegate=None, journal=None, tor=None,
timing=None, timing=None,
stderr=sys.stderr): stderr=sys.stderr,
_eventual_queue=None):
timing = timing or DebugTiming() timing = timing or DebugTiming()
side = bytes_to_hexstr(os.urandom(5)) side = bytes_to_hexstr(os.urandom(5))
journal = journal or ImmediateJournal() journal = journal or ImmediateJournal()
eq = _eventual_queue or EventualQueue(reactor)
if delegate: if delegate:
w = _DelegatedWormhole(delegate) w = _DelegatedWormhole(delegate)
else: else:
w = _DeferredWormhole() w = _DeferredWormhole(eq)
wormhole_versions = {} # will be used to indicate Wormhole capabilities wormhole_versions = {} # will be used to indicate Wormhole capabilities
wormhole_versions["app_versions"] = versions # app-specific capabilities wormhole_versions["app_versions"] = versions # app-specific capabilities
b = Boss(w, side, relay_url, appid, wormhole_versions, b = Boss(w, side, relay_url, appid, wormhole_versions,