Merge branch 'observers-4'
This factors out the various "get me a Deferred which fires when/if we compute a value" code from the _DeferredWormhole API calls: get_code, get_unverified_key, get_versions, get_message, etc. It uses an eventual-send for each one, which will protect against surprises when an application invokes an wormhole API from within a previous API's callback: without this, the internal wormhole state isn't guaranteed to be coherent, and crashes could result.
This commit is contained in:
commit
c5ae678417
50
src/wormhole/eventual.py
Normal file
50
src/wormhole/eventual.py
Normal file
|
@ -0,0 +1,50 @@
|
||||||
|
# inspired-by/adapted-from Foolscap's eventual.py, which Glyph wrote for me
|
||||||
|
# years ago.
|
||||||
|
|
||||||
|
from twisted.internet.defer import Deferred
|
||||||
|
from twisted.internet.interfaces import IReactorTime
|
||||||
|
from twisted.python import log
|
||||||
|
|
||||||
|
class EventualQueue(object):
|
||||||
|
def __init__(self, clock):
|
||||||
|
# pass clock=reactor unless you're testing
|
||||||
|
self._clock = IReactorTime(clock)
|
||||||
|
self._calls = []
|
||||||
|
self._flush_d = None
|
||||||
|
self._timer = None
|
||||||
|
|
||||||
|
def eventually(self, f, *args, **kwargs):
|
||||||
|
self._calls.append( (f, args, kwargs) )
|
||||||
|
if not self._timer:
|
||||||
|
self._timer = self._clock.callLater(0, self._turn)
|
||||||
|
|
||||||
|
def fire_eventually(self, value=None):
|
||||||
|
d = Deferred()
|
||||||
|
self.eventually(d.callback, value)
|
||||||
|
return d
|
||||||
|
|
||||||
|
def _turn(self):
|
||||||
|
while self._calls:
|
||||||
|
(f, args, kwargs) = self._calls.pop(0)
|
||||||
|
try:
|
||||||
|
f(*args, **kwargs)
|
||||||
|
except:
|
||||||
|
log.err()
|
||||||
|
self._timer = None
|
||||||
|
d, self._flush_d = self._flush_d, None
|
||||||
|
if d:
|
||||||
|
d.callback(None)
|
||||||
|
|
||||||
|
def flush_sync(self):
|
||||||
|
# if you have control over the Clock, this will synchronously flush the
|
||||||
|
# queue
|
||||||
|
assert self._clock.advance, "needs clock=twisted.internet.task.Clock()"
|
||||||
|
while self._calls:
|
||||||
|
self._clock.advance(0)
|
||||||
|
|
||||||
|
def flush(self):
|
||||||
|
# this is for unit tests, not application code
|
||||||
|
assert not self._flush_d, "only one flush at a time"
|
||||||
|
self._flush_d = Deferred()
|
||||||
|
self.eventually(lambda: None)
|
||||||
|
return self._flush_d
|
69
src/wormhole/observer.py
Normal file
69
src/wormhole/observer.py
Normal file
|
@ -0,0 +1,69 @@
|
||||||
|
from __future__ import unicode_literals, print_function
|
||||||
|
from twisted.internet.defer import Deferred
|
||||||
|
from twisted.python.failure import Failure
|
||||||
|
|
||||||
|
NoResult = object()
|
||||||
|
|
||||||
|
class OneShotObserver(object):
|
||||||
|
def __init__(self, eventual_queue):
|
||||||
|
self._eq = eventual_queue
|
||||||
|
self._result = NoResult
|
||||||
|
self._observers = [] # list of Deferreds
|
||||||
|
|
||||||
|
def when_fired(self):
|
||||||
|
d = Deferred()
|
||||||
|
self._observers.append(d)
|
||||||
|
self._maybe_call_observers()
|
||||||
|
return d
|
||||||
|
|
||||||
|
def fire(self, result):
|
||||||
|
assert self._result is NoResult
|
||||||
|
self._result = result
|
||||||
|
self._maybe_call_observers()
|
||||||
|
|
||||||
|
def _maybe_call_observers(self):
|
||||||
|
if self._result is NoResult:
|
||||||
|
return
|
||||||
|
observers, self._observers = self._observers, []
|
||||||
|
for d in observers:
|
||||||
|
self._eq.eventually(d.callback, self._result)
|
||||||
|
|
||||||
|
def error(self, f):
|
||||||
|
# errors will override an existing result
|
||||||
|
assert isinstance(f, Failure)
|
||||||
|
self._result = f
|
||||||
|
self._maybe_call_observers()
|
||||||
|
|
||||||
|
def fire_if_not_fired(self, result):
|
||||||
|
if self._result is NoResult:
|
||||||
|
self.fire(result)
|
||||||
|
|
||||||
|
class SequenceObserver(object):
|
||||||
|
def __init__(self, eventual_queue):
|
||||||
|
self._eq = eventual_queue
|
||||||
|
self._error = None
|
||||||
|
self._results = []
|
||||||
|
self._observers = []
|
||||||
|
|
||||||
|
def when_next_event(self):
|
||||||
|
d = Deferred()
|
||||||
|
if self._error:
|
||||||
|
self._eq.eventually(d.errback, self._error)
|
||||||
|
elif self._results:
|
||||||
|
result = self._results.pop(0)
|
||||||
|
self._eq.eventually(d.callback, result)
|
||||||
|
else:
|
||||||
|
self._observers.append(d)
|
||||||
|
return d
|
||||||
|
|
||||||
|
def fire(self, result):
|
||||||
|
if isinstance(result, Failure):
|
||||||
|
self._error = result
|
||||||
|
for d in self._observers:
|
||||||
|
self._eq.eventually(d.errback, self._error)
|
||||||
|
self._observers = []
|
||||||
|
else:
|
||||||
|
self._results.append(result)
|
||||||
|
if self._observers:
|
||||||
|
d = self._observers.pop(0)
|
||||||
|
self._eq.eventually(d.callback, self._results.pop(0))
|
|
@ -128,10 +128,3 @@ def poll_until(predicate):
|
||||||
d = defer.Deferred()
|
d = defer.Deferred()
|
||||||
reactor.callLater(0.001, d.callback, None)
|
reactor.callLater(0.001, d.callback, None)
|
||||||
yield d
|
yield d
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def pause_one_tick():
|
|
||||||
# return a Deferred that won't fire until at least the next reactor tick
|
|
||||||
d = defer.Deferred()
|
|
||||||
reactor.callLater(0.001, d.callback, None)
|
|
||||||
yield d
|
|
||||||
|
|
57
src/wormhole/test/test_eventual.py
Normal file
57
src/wormhole/test/test_eventual.py
Normal file
|
@ -0,0 +1,57 @@
|
||||||
|
from __future__ import print_function, unicode_literals
|
||||||
|
import mock
|
||||||
|
from twisted.trial import unittest
|
||||||
|
from twisted.internet import reactor
|
||||||
|
from twisted.internet.task import Clock
|
||||||
|
from twisted.internet.defer import Deferred, inlineCallbacks
|
||||||
|
from ..eventual import EventualQueue
|
||||||
|
|
||||||
|
class IntentionalError(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
class Eventual(unittest.TestCase, object):
|
||||||
|
def test_eventually(self):
|
||||||
|
c = Clock()
|
||||||
|
eq = EventualQueue(c)
|
||||||
|
c1 = mock.Mock()
|
||||||
|
eq.eventually(c1, "arg1", "arg2", kwarg1="kw1")
|
||||||
|
eq.eventually(c1, "arg3", "arg4", kwarg5="kw5")
|
||||||
|
d2 = eq.fire_eventually()
|
||||||
|
d3 = eq.fire_eventually("value")
|
||||||
|
self.assertEqual(c1.mock_calls, [])
|
||||||
|
self.assertNoResult(d2)
|
||||||
|
self.assertNoResult(d3)
|
||||||
|
|
||||||
|
eq.flush_sync()
|
||||||
|
self.assertEqual(c1.mock_calls,
|
||||||
|
[mock.call("arg1", "arg2", kwarg1="kw1"),
|
||||||
|
mock.call("arg3", "arg4", kwarg5="kw5")])
|
||||||
|
self.assertEqual(self.successResultOf(d2), None)
|
||||||
|
self.assertEqual(self.successResultOf(d3), "value")
|
||||||
|
|
||||||
|
def test_error(self):
|
||||||
|
c = Clock()
|
||||||
|
eq = EventualQueue(c)
|
||||||
|
c1 = mock.Mock(side_effect=IntentionalError)
|
||||||
|
eq.eventually(c1, "arg1", "arg2", kwarg1="kw1")
|
||||||
|
self.assertEqual(c1.mock_calls, [])
|
||||||
|
|
||||||
|
eq.flush_sync()
|
||||||
|
self.assertEqual(c1.mock_calls,
|
||||||
|
[mock.call("arg1", "arg2", kwarg1="kw1")])
|
||||||
|
|
||||||
|
self.flushLoggedErrors(IntentionalError)
|
||||||
|
|
||||||
|
@inlineCallbacks
|
||||||
|
def test_flush(self):
|
||||||
|
eq = EventualQueue(reactor)
|
||||||
|
d1 = eq.fire_eventually()
|
||||||
|
d2 = Deferred()
|
||||||
|
def _more(res):
|
||||||
|
eq.eventually(d2.callback, None)
|
||||||
|
d1.addCallback(_more)
|
||||||
|
yield eq.flush()
|
||||||
|
# d1 will fire, which will queue d2 to fire, and the flush() ought to
|
||||||
|
# wait for d2 too
|
||||||
|
self.successResultOf(d2)
|
||||||
|
|
122
src/wormhole/test/test_observer.py
Normal file
122
src/wormhole/test/test_observer.py
Normal file
|
@ -0,0 +1,122 @@
|
||||||
|
from twisted.trial import unittest
|
||||||
|
from twisted.internet.task import Clock
|
||||||
|
from twisted.python.failure import Failure
|
||||||
|
from ..eventual import EventualQueue
|
||||||
|
from ..observer import OneShotObserver, SequenceObserver
|
||||||
|
|
||||||
|
class OneShot(unittest.TestCase):
|
||||||
|
def test_fire(self):
|
||||||
|
c = Clock()
|
||||||
|
eq = EventualQueue(c)
|
||||||
|
o = OneShotObserver(eq)
|
||||||
|
res = object()
|
||||||
|
d1 = o.when_fired()
|
||||||
|
eq.flush_sync()
|
||||||
|
self.assertNoResult(d1)
|
||||||
|
o.fire(res)
|
||||||
|
eq.flush_sync()
|
||||||
|
self.assertIdentical(self.successResultOf(d1), res)
|
||||||
|
d2 = o.when_fired()
|
||||||
|
eq.flush_sync()
|
||||||
|
self.assertIdentical(self.successResultOf(d2), res)
|
||||||
|
o.fire_if_not_fired(object())
|
||||||
|
eq.flush_sync()
|
||||||
|
|
||||||
|
def test_fire_if_not_fired(self):
|
||||||
|
c = Clock()
|
||||||
|
eq = EventualQueue(c)
|
||||||
|
o = OneShotObserver(eq)
|
||||||
|
res1 = object()
|
||||||
|
res2 = object()
|
||||||
|
d1 = o.when_fired()
|
||||||
|
eq.flush_sync()
|
||||||
|
self.assertNoResult(d1)
|
||||||
|
o.fire_if_not_fired(res1)
|
||||||
|
o.fire_if_not_fired(res2)
|
||||||
|
eq.flush_sync()
|
||||||
|
self.assertIdentical(self.successResultOf(d1), res1)
|
||||||
|
|
||||||
|
def test_error_before_firing(self):
|
||||||
|
c = Clock()
|
||||||
|
eq = EventualQueue(c)
|
||||||
|
o = OneShotObserver(eq)
|
||||||
|
f = Failure(ValueError("oops"))
|
||||||
|
d1 = o.when_fired()
|
||||||
|
eq.flush_sync()
|
||||||
|
self.assertNoResult(d1)
|
||||||
|
o.error(f)
|
||||||
|
eq.flush_sync()
|
||||||
|
self.assertIdentical(self.failureResultOf(d1), f)
|
||||||
|
d2 = o.when_fired()
|
||||||
|
eq.flush_sync()
|
||||||
|
self.assertIdentical(self.failureResultOf(d2), f)
|
||||||
|
|
||||||
|
def test_error_after_firing(self):
|
||||||
|
c = Clock()
|
||||||
|
eq = EventualQueue(c)
|
||||||
|
o = OneShotObserver(eq)
|
||||||
|
res = object()
|
||||||
|
f = Failure(ValueError("oops"))
|
||||||
|
|
||||||
|
o.fire(res)
|
||||||
|
eq.flush_sync()
|
||||||
|
d1 = o.when_fired()
|
||||||
|
o.error(f)
|
||||||
|
d2 = o.when_fired()
|
||||||
|
eq.flush_sync()
|
||||||
|
self.assertIdentical(self.successResultOf(d1), res)
|
||||||
|
self.assertIdentical(self.failureResultOf(d2), f)
|
||||||
|
|
||||||
|
|
||||||
|
class Sequence(unittest.TestCase):
|
||||||
|
def test_fire(self):
|
||||||
|
c = Clock()
|
||||||
|
eq = EventualQueue(c)
|
||||||
|
o = SequenceObserver(eq)
|
||||||
|
d1 = o.when_next_event()
|
||||||
|
eq.flush_sync()
|
||||||
|
self.assertNoResult(d1)
|
||||||
|
d2 = o.when_next_event()
|
||||||
|
eq.flush_sync()
|
||||||
|
self.assertNoResult(d1)
|
||||||
|
self.assertNoResult(d2)
|
||||||
|
|
||||||
|
ev1 = object()
|
||||||
|
ev2 = object()
|
||||||
|
o.fire(ev1)
|
||||||
|
eq.flush_sync()
|
||||||
|
self.assertIdentical(self.successResultOf(d1), ev1)
|
||||||
|
self.assertNoResult(d2)
|
||||||
|
|
||||||
|
o.fire(ev2)
|
||||||
|
eq.flush_sync()
|
||||||
|
self.assertIdentical(self.successResultOf(d2), ev2)
|
||||||
|
|
||||||
|
ev3 = object()
|
||||||
|
ev4 = object()
|
||||||
|
o.fire(ev3)
|
||||||
|
o.fire(ev4)
|
||||||
|
|
||||||
|
d3 = o.when_next_event()
|
||||||
|
eq.flush_sync()
|
||||||
|
self.assertIdentical(self.successResultOf(d3), ev3)
|
||||||
|
|
||||||
|
d4 = o.when_next_event()
|
||||||
|
eq.flush_sync()
|
||||||
|
self.assertIdentical(self.successResultOf(d4), ev4)
|
||||||
|
|
||||||
|
def test_error(self):
|
||||||
|
c = Clock()
|
||||||
|
eq = EventualQueue(c)
|
||||||
|
o = SequenceObserver(eq)
|
||||||
|
d1 = o.when_next_event()
|
||||||
|
eq.flush_sync()
|
||||||
|
self.assertNoResult(d1)
|
||||||
|
f = Failure(ValueError("oops"))
|
||||||
|
o.fire(f)
|
||||||
|
eq.flush_sync()
|
||||||
|
self.assertIdentical(self.failureResultOf(d1), f)
|
||||||
|
d2 = o.when_next_event()
|
||||||
|
eq.flush_sync()
|
||||||
|
self.assertIdentical(self.failureResultOf(d2), f)
|
||||||
|
|
|
@ -5,12 +5,13 @@ from twisted.trial import unittest
|
||||||
from twisted.internet import reactor
|
from twisted.internet import reactor
|
||||||
from twisted.internet.defer import gatherResults, inlineCallbacks, returnValue
|
from twisted.internet.defer import gatherResults, inlineCallbacks, returnValue
|
||||||
from twisted.internet.error import ConnectionRefusedError
|
from twisted.internet.error import ConnectionRefusedError
|
||||||
from .common import ServerBase, poll_until, pause_one_tick
|
from .common import ServerBase, poll_until
|
||||||
from .. import wormhole, _rendezvous
|
from .. import wormhole, _rendezvous
|
||||||
from ..errors import (WrongPasswordError, ServerConnectionError,
|
from ..errors import (WrongPasswordError, ServerConnectionError,
|
||||||
KeyFormatError, WormholeClosed, LonelyError,
|
KeyFormatError, WormholeClosed, LonelyError,
|
||||||
NoKeyError, OnlyOneCodeError)
|
NoKeyError, OnlyOneCodeError)
|
||||||
from ..transit import allocate_tcp_port
|
from ..transit import allocate_tcp_port
|
||||||
|
from ..eventual import EventualQueue
|
||||||
|
|
||||||
APPID = "appid"
|
APPID = "appid"
|
||||||
|
|
||||||
|
@ -159,9 +160,6 @@ class Wormholes(ServerBase, unittest.TestCase):
|
||||||
verifier2 = yield w2.get_verifier()
|
verifier2 = yield w2.get_verifier()
|
||||||
self.assertEqual(verifier1, verifier2)
|
self.assertEqual(verifier1, verifier2)
|
||||||
|
|
||||||
self.successResultOf(w1.get_unverified_key())
|
|
||||||
self.successResultOf(w2.get_unverified_key())
|
|
||||||
|
|
||||||
versions1 = yield w1.get_versions()
|
versions1 = yield w1.get_versions()
|
||||||
versions2 = yield w2.get_versions()
|
versions2 = yield w2.get_versions()
|
||||||
# app-versions are exercised properly in test_versions, this just
|
# app-versions are exercised properly in test_versions, this just
|
||||||
|
@ -186,18 +184,22 @@ class Wormholes(ServerBase, unittest.TestCase):
|
||||||
|
|
||||||
@inlineCallbacks
|
@inlineCallbacks
|
||||||
def test_get_code_early(self):
|
def test_get_code_early(self):
|
||||||
w1 = wormhole.create(APPID, self.relayurl, reactor)
|
eq = EventualQueue(reactor)
|
||||||
|
w1 = wormhole.create(APPID, self.relayurl, reactor, _eventual_queue=eq)
|
||||||
d = w1.get_code()
|
d = w1.get_code()
|
||||||
w1.set_code("1-abc")
|
w1.set_code("1-abc")
|
||||||
|
yield eq.flush()
|
||||||
code = self.successResultOf(d)
|
code = self.successResultOf(d)
|
||||||
self.assertEqual(code, "1-abc")
|
self.assertEqual(code, "1-abc")
|
||||||
yield self.assertFailure(w1.close(), LonelyError)
|
yield self.assertFailure(w1.close(), LonelyError)
|
||||||
|
|
||||||
@inlineCallbacks
|
@inlineCallbacks
|
||||||
def test_get_code_late(self):
|
def test_get_code_late(self):
|
||||||
w1 = wormhole.create(APPID, self.relayurl, reactor)
|
eq = EventualQueue(reactor)
|
||||||
|
w1 = wormhole.create(APPID, self.relayurl, reactor, _eventual_queue=eq)
|
||||||
w1.set_code("1-abc")
|
w1.set_code("1-abc")
|
||||||
d = w1.get_code()
|
d = w1.get_code()
|
||||||
|
yield eq.flush()
|
||||||
code = self.successResultOf(d)
|
code = self.successResultOf(d)
|
||||||
self.assertEqual(code, "1-abc")
|
self.assertEqual(code, "1-abc")
|
||||||
yield self.assertFailure(w1.close(), LonelyError)
|
yield self.assertFailure(w1.close(), LonelyError)
|
||||||
|
@ -323,8 +325,9 @@ class Wormholes(ServerBase, unittest.TestCase):
|
||||||
|
|
||||||
@inlineCallbacks
|
@inlineCallbacks
|
||||||
def test_closed(self):
|
def test_closed(self):
|
||||||
w1 = wormhole.create(APPID, self.relayurl, reactor)
|
eq = EventualQueue(reactor)
|
||||||
w2 = wormhole.create(APPID, self.relayurl, reactor)
|
w1 = wormhole.create(APPID, self.relayurl, reactor, _eventual_queue=eq)
|
||||||
|
w2 = wormhole.create(APPID, self.relayurl, reactor, _eventual_queue=eq)
|
||||||
w1.set_code("123-foo")
|
w1.set_code("123-foo")
|
||||||
w2.set_code("123-foo")
|
w2.set_code("123-foo")
|
||||||
|
|
||||||
|
@ -335,14 +338,14 @@ class Wormholes(ServerBase, unittest.TestCase):
|
||||||
yield w1.close()
|
yield w1.close()
|
||||||
yield w2.close()
|
yield w2.close()
|
||||||
|
|
||||||
# once closed, all Deferred-yielding API calls get an immediate error
|
# once closed, all Deferred-yielding API calls get an prompt error
|
||||||
self.failureResultOf(w1.get_welcome(), WormholeClosed)
|
yield self.assertFailure(w1.get_welcome(), WormholeClosed)
|
||||||
f = self.failureResultOf(w1.get_code(), WormholeClosed)
|
e = yield self.assertFailure(w1.get_code(), WormholeClosed)
|
||||||
self.assertEqual(f.value.args[0], "happy")
|
self.assertEqual(e.args[0], "happy")
|
||||||
self.failureResultOf(w1.get_unverified_key(), WormholeClosed)
|
yield self.assertFailure(w1.get_unverified_key(), WormholeClosed)
|
||||||
self.failureResultOf(w1.get_verifier(), WormholeClosed)
|
yield self.assertFailure(w1.get_verifier(), WormholeClosed)
|
||||||
self.failureResultOf(w1.get_versions(), WormholeClosed)
|
yield self.assertFailure(w1.get_versions(), WormholeClosed)
|
||||||
self.failureResultOf(w1.get_message(), WormholeClosed)
|
yield self.assertFailure(w1.get_message(), WormholeClosed)
|
||||||
|
|
||||||
@inlineCallbacks
|
@inlineCallbacks
|
||||||
def test_closed_idle(self):
|
def test_closed_idle(self):
|
||||||
|
@ -360,17 +363,18 @@ class Wormholes(ServerBase, unittest.TestCase):
|
||||||
|
|
||||||
yield self.assertFailure(w1.close(), LonelyError)
|
yield self.assertFailure(w1.close(), LonelyError)
|
||||||
|
|
||||||
self.failureResultOf(d_welcome, LonelyError)
|
yield self.assertFailure(d_welcome, LonelyError)
|
||||||
self.failureResultOf(d_code, LonelyError)
|
yield self.assertFailure(d_code, LonelyError)
|
||||||
self.failureResultOf(d_key, LonelyError)
|
yield self.assertFailure(d_key, LonelyError)
|
||||||
self.failureResultOf(d_verifier, LonelyError)
|
yield self.assertFailure(d_verifier, LonelyError)
|
||||||
self.failureResultOf(d_versions, LonelyError)
|
yield self.assertFailure(d_versions, LonelyError)
|
||||||
self.failureResultOf(d_message, LonelyError)
|
yield self.assertFailure(d_message, LonelyError)
|
||||||
|
|
||||||
@inlineCallbacks
|
@inlineCallbacks
|
||||||
def test_wrong_password(self):
|
def test_wrong_password(self):
|
||||||
w1 = wormhole.create(APPID, self.relayurl, reactor)
|
eq = EventualQueue(reactor)
|
||||||
w2 = wormhole.create(APPID, self.relayurl, reactor)
|
w1 = wormhole.create(APPID, self.relayurl, reactor, _eventual_queue=eq)
|
||||||
|
w2 = wormhole.create(APPID, self.relayurl, reactor, _eventual_queue=eq)
|
||||||
w1.allocate_code()
|
w1.allocate_code()
|
||||||
code = yield w1.get_code()
|
code = yield w1.get_code()
|
||||||
w2.set_code(code+"not")
|
w2.set_code(code+"not")
|
||||||
|
@ -403,9 +407,8 @@ class Wormholes(ServerBase, unittest.TestCase):
|
||||||
# wait for each side to notice the failure
|
# wait for each side to notice the failure
|
||||||
yield self.assertFailure(w1.get_verifier(), WrongPasswordError)
|
yield self.assertFailure(w1.get_verifier(), WrongPasswordError)
|
||||||
yield self.assertFailure(w2.get_verifier(), WrongPasswordError)
|
yield self.assertFailure(w2.get_verifier(), WrongPasswordError)
|
||||||
# and then wait for the rest of the loops to fire. if we had+used
|
# the rest of the loops should fire within the next tick
|
||||||
# eventual-send, this wouldn't be a problem
|
yield eq.flush()
|
||||||
yield pause_one_tick()
|
|
||||||
|
|
||||||
# now all the rest should have fired already
|
# now all the rest should have fired already
|
||||||
self.failureResultOf(d1_verified, WrongPasswordError)
|
self.failureResultOf(d1_verified, WrongPasswordError)
|
||||||
|
@ -420,27 +423,27 @@ class Wormholes(ServerBase, unittest.TestCase):
|
||||||
# before we close
|
# before we close
|
||||||
|
|
||||||
# any new calls in the error state should immediately fail
|
# any new calls in the error state should immediately fail
|
||||||
self.failureResultOf(w1.get_unverified_key(), WrongPasswordError)
|
yield self.assertFailure(w1.get_unverified_key(), WrongPasswordError)
|
||||||
self.failureResultOf(w1.get_verifier(), WrongPasswordError)
|
yield self.assertFailure(w1.get_verifier(), WrongPasswordError)
|
||||||
self.failureResultOf(w1.get_versions(), WrongPasswordError)
|
yield self.assertFailure(w1.get_versions(), WrongPasswordError)
|
||||||
self.failureResultOf(w1.get_message(), WrongPasswordError)
|
yield self.assertFailure(w1.get_message(), WrongPasswordError)
|
||||||
self.failureResultOf(w2.get_unverified_key(), WrongPasswordError)
|
yield self.assertFailure(w2.get_unverified_key(), WrongPasswordError)
|
||||||
self.failureResultOf(w2.get_verifier(), WrongPasswordError)
|
yield self.assertFailure(w2.get_verifier(), WrongPasswordError)
|
||||||
self.failureResultOf(w2.get_versions(), WrongPasswordError)
|
yield self.assertFailure(w2.get_versions(), WrongPasswordError)
|
||||||
self.failureResultOf(w2.get_message(), WrongPasswordError)
|
yield self.assertFailure(w2.get_message(), WrongPasswordError)
|
||||||
|
|
||||||
yield self.assertFailure(w1.close(), WrongPasswordError)
|
yield self.assertFailure(w1.close(), WrongPasswordError)
|
||||||
yield self.assertFailure(w2.close(), WrongPasswordError)
|
yield self.assertFailure(w2.close(), WrongPasswordError)
|
||||||
|
|
||||||
# API calls should still get the error, not WormholeClosed
|
# API calls should still get the error, not WormholeClosed
|
||||||
self.failureResultOf(w1.get_unverified_key(), WrongPasswordError)
|
yield self.assertFailure(w1.get_unverified_key(), WrongPasswordError)
|
||||||
self.failureResultOf(w1.get_verifier(), WrongPasswordError)
|
yield self.assertFailure(w1.get_verifier(), WrongPasswordError)
|
||||||
self.failureResultOf(w1.get_versions(), WrongPasswordError)
|
yield self.assertFailure(w1.get_versions(), WrongPasswordError)
|
||||||
self.failureResultOf(w1.get_message(), WrongPasswordError)
|
yield self.assertFailure(w1.get_message(), WrongPasswordError)
|
||||||
self.failureResultOf(w2.get_unverified_key(), WrongPasswordError)
|
yield self.assertFailure(w2.get_unverified_key(), WrongPasswordError)
|
||||||
self.failureResultOf(w2.get_verifier(), WrongPasswordError)
|
yield self.assertFailure(w2.get_verifier(), WrongPasswordError)
|
||||||
self.failureResultOf(w2.get_versions(), WrongPasswordError)
|
yield self.assertFailure(w2.get_versions(), WrongPasswordError)
|
||||||
self.failureResultOf(w2.get_message(), WrongPasswordError)
|
yield self.assertFailure(w2.get_message(), WrongPasswordError)
|
||||||
|
|
||||||
@inlineCallbacks
|
@inlineCallbacks
|
||||||
def test_wrong_password_with_spaces(self):
|
def test_wrong_password_with_spaces(self):
|
||||||
|
@ -493,8 +496,9 @@ class Wormholes(ServerBase, unittest.TestCase):
|
||||||
|
|
||||||
@inlineCallbacks
|
@inlineCallbacks
|
||||||
def test_verifier(self):
|
def test_verifier(self):
|
||||||
w1 = wormhole.create(APPID, self.relayurl, reactor)
|
eq = EventualQueue(reactor)
|
||||||
w2 = wormhole.create(APPID, self.relayurl, reactor)
|
w1 = wormhole.create(APPID, self.relayurl, reactor, _eventual_queue=eq)
|
||||||
|
w2 = wormhole.create(APPID, self.relayurl, reactor, _eventual_queue=eq)
|
||||||
w1.allocate_code()
|
w1.allocate_code()
|
||||||
code = yield w1.get_code()
|
code = yield w1.get_code()
|
||||||
w2.set_code(code)
|
w2.set_code(code)
|
||||||
|
@ -510,7 +514,9 @@ class Wormholes(ServerBase, unittest.TestCase):
|
||||||
self.assertEqual(dataY, b"data1")
|
self.assertEqual(dataY, b"data1")
|
||||||
|
|
||||||
# calling get_verifier() this late should fire right away
|
# calling get_verifier() this late should fire right away
|
||||||
v1_late = self.successResultOf(w2.get_verifier())
|
d = w2.get_verifier()
|
||||||
|
yield eq.flush()
|
||||||
|
v1_late = self.successResultOf(d)
|
||||||
self.assertEqual(v1_late, v1)
|
self.assertEqual(v1_late, v1)
|
||||||
|
|
||||||
yield w1.close()
|
yield w1.close()
|
||||||
|
@ -644,26 +650,30 @@ class Reconnection(ServerBase, unittest.TestCase):
|
||||||
self.assertEqual(c2, "happy")
|
self.assertEqual(c2, "happy")
|
||||||
|
|
||||||
class InitialFailure(unittest.TestCase):
|
class InitialFailure(unittest.TestCase):
|
||||||
def assertSCEResultOf(self, d, innerType):
|
@inlineCallbacks
|
||||||
|
def assertSCEFailure(self, eq, d, innerType):
|
||||||
|
yield eq.flush()
|
||||||
f = self.failureResultOf(d, ServerConnectionError)
|
f = self.failureResultOf(d, ServerConnectionError)
|
||||||
inner = f.value.reason
|
inner = f.value.reason
|
||||||
self.assertIsInstance(inner, innerType)
|
self.assertIsInstance(inner, innerType)
|
||||||
return inner
|
returnValue(inner)
|
||||||
|
|
||||||
@inlineCallbacks
|
@inlineCallbacks
|
||||||
def test_bad_dns(self):
|
def test_bad_dns(self):
|
||||||
|
eq = EventualQueue(reactor)
|
||||||
# point at a URL that will never connect
|
# point at a URL that will never connect
|
||||||
w = wormhole.create(APPID, "ws://%%%.example.org:4000/v1", reactor)
|
w = wormhole.create(APPID, "ws://%%%.example.org:4000/v1",
|
||||||
|
reactor, _eventual_queue=eq)
|
||||||
# that should have already received an error, when it tried to
|
# that should have already received an error, when it tried to
|
||||||
# resolve the bogus DNS name. All API calls will return an error.
|
# resolve the bogus DNS name. All API calls will return an error.
|
||||||
e = yield self.assertFailure(w.get_unverified_key(),
|
|
||||||
ServerConnectionError)
|
e = yield self.assertSCEFailure(eq, w.get_unverified_key(), ValueError)
|
||||||
self.assertIsInstance(e.reason, ValueError)
|
self.assertIsInstance(e, ValueError)
|
||||||
self.assertEqual(str(e), "invalid hostname: %%%.example.org")
|
self.assertEqual(str(e), "invalid hostname: %%%.example.org")
|
||||||
self.assertSCEResultOf(w.get_code(), ValueError)
|
yield self.assertSCEFailure(eq, w.get_code(), ValueError)
|
||||||
self.assertSCEResultOf(w.get_verifier(), ValueError)
|
yield self.assertSCEFailure(eq, w.get_verifier(), ValueError)
|
||||||
self.assertSCEResultOf(w.get_versions(), ValueError)
|
yield self.assertSCEFailure(eq, w.get_versions(), ValueError)
|
||||||
self.assertSCEResultOf(w.get_message(), ValueError)
|
yield self.assertSCEFailure(eq, w.get_message(), ValueError)
|
||||||
|
|
||||||
@inlineCallbacks
|
@inlineCallbacks
|
||||||
def assertSCE(self, d, innerType):
|
def assertSCE(self, d, innerType):
|
||||||
|
|
|
@ -3,9 +3,10 @@ import os, sys
|
||||||
from attr import attrs, attrib
|
from attr import attrs, attrib
|
||||||
from zope.interface import implementer
|
from zope.interface import implementer
|
||||||
from twisted.python import failure
|
from twisted.python import failure
|
||||||
from twisted.internet import defer
|
|
||||||
from ._interfaces import IWormhole, IDeferredWormhole
|
from ._interfaces import IWormhole, IDeferredWormhole
|
||||||
from .util import bytes_to_hexstr
|
from .util import bytes_to_hexstr
|
||||||
|
from .eventual import EventualQueue
|
||||||
|
from .observer import OneShotObserver, SequenceObserver
|
||||||
from .timing import DebugTiming
|
from .timing import DebugTiming
|
||||||
from .journal import ImmediateJournal
|
from .journal import ImmediateJournal
|
||||||
from ._boss import Boss
|
from ._boss import Boss
|
||||||
|
@ -100,22 +101,16 @@ class _DelegatedWormhole(object):
|
||||||
|
|
||||||
@implementer(IWormhole, IDeferredWormhole)
|
@implementer(IWormhole, IDeferredWormhole)
|
||||||
class _DeferredWormhole(object):
|
class _DeferredWormhole(object):
|
||||||
def __init__(self):
|
def __init__(self, eq):
|
||||||
self._welcome = None
|
self._welcome_observer = OneShotObserver(eq)
|
||||||
self._welcome_observers = []
|
self._code_observer = OneShotObserver(eq)
|
||||||
self._code = None
|
|
||||||
self._code_observers = []
|
|
||||||
self._key = None
|
self._key = None
|
||||||
self._key_observers = []
|
self._key_observer = OneShotObserver(eq)
|
||||||
self._verifier = None
|
self._verifier_observer = OneShotObserver(eq)
|
||||||
self._verifier_observers = []
|
self._version_observer = OneShotObserver(eq)
|
||||||
self._versions = None
|
self._received_observer = SequenceObserver(eq)
|
||||||
self._version_observers = []
|
self._closed = False
|
||||||
self._received_data = []
|
self._closed_observer = OneShotObserver(eq)
|
||||||
self._received_observers = []
|
|
||||||
self._observer_result = None
|
|
||||||
self._closed_result = None
|
|
||||||
self._closed_observers = []
|
|
||||||
|
|
||||||
def _set_boss(self, boss):
|
def _set_boss(self, boss):
|
||||||
self._boss = boss
|
self._boss = boss
|
||||||
|
@ -127,58 +122,22 @@ class _DeferredWormhole(object):
|
||||||
# the process that will cause it to fire, but forbidding that
|
# the process that will cause it to fire, but forbidding that
|
||||||
# ordering would make it easier to cause programming errors that
|
# ordering would make it easier to cause programming errors that
|
||||||
# forget to trigger it entirely.
|
# forget to trigger it entirely.
|
||||||
if self._observer_result is not None:
|
return self._code_observer.when_fired()
|
||||||
return defer.fail(self._observer_result)
|
|
||||||
if self._code is not None:
|
|
||||||
return defer.succeed(self._code)
|
|
||||||
d = defer.Deferred()
|
|
||||||
self._code_observers.append(d)
|
|
||||||
return d
|
|
||||||
|
|
||||||
def get_welcome(self):
|
def get_welcome(self):
|
||||||
if self._observer_result is not None:
|
return self._welcome_observer.when_fired()
|
||||||
return defer.fail(self._observer_result)
|
|
||||||
if self._welcome is not None:
|
|
||||||
return defer.succeed(self._welcome)
|
|
||||||
d = defer.Deferred()
|
|
||||||
self._welcome_observers.append(d)
|
|
||||||
return d
|
|
||||||
|
|
||||||
def get_unverified_key(self):
|
def get_unverified_key(self):
|
||||||
if self._observer_result is not None:
|
return self._key_observer.when_fired()
|
||||||
return defer.fail(self._observer_result)
|
|
||||||
if self._key is not None:
|
|
||||||
return defer.succeed(self._key)
|
|
||||||
d = defer.Deferred()
|
|
||||||
self._key_observers.append(d)
|
|
||||||
return d
|
|
||||||
|
|
||||||
def get_verifier(self):
|
def get_verifier(self):
|
||||||
if self._observer_result is not None:
|
return self._verifier_observer.when_fired()
|
||||||
return defer.fail(self._observer_result)
|
|
||||||
if self._verifier is not None:
|
|
||||||
return defer.succeed(self._verifier)
|
|
||||||
d = defer.Deferred()
|
|
||||||
self._verifier_observers.append(d)
|
|
||||||
return d
|
|
||||||
|
|
||||||
def get_versions(self):
|
def get_versions(self):
|
||||||
if self._observer_result is not None:
|
return self._version_observer.when_fired()
|
||||||
return defer.fail(self._observer_result)
|
|
||||||
if self._versions is not None:
|
|
||||||
return defer.succeed(self._versions)
|
|
||||||
d = defer.Deferred()
|
|
||||||
self._version_observers.append(d)
|
|
||||||
return d
|
|
||||||
|
|
||||||
def get_message(self):
|
def get_message(self):
|
||||||
if self._observer_result is not None:
|
return self._received_observer.when_next_event()
|
||||||
return defer.fail(self._observer_result)
|
|
||||||
if self._received_data:
|
|
||||||
return defer.succeed(self._received_data.pop(0))
|
|
||||||
d = defer.Deferred()
|
|
||||||
self._received_observers.append(d)
|
|
||||||
return d
|
|
||||||
|
|
||||||
def allocate_code(self, code_length=2):
|
def allocate_code(self, code_length=2):
|
||||||
self._boss.allocate_code(code_length)
|
self._boss.allocate_code(code_length)
|
||||||
|
@ -207,10 +166,8 @@ class _DeferredWormhole(object):
|
||||||
# fails with WormholeError unless we established a connection
|
# fails with WormholeError unless we established a connection
|
||||||
# (state=="happy"). Fails with WrongPasswordError (a subclass of
|
# (state=="happy"). Fails with WrongPasswordError (a subclass of
|
||||||
# WormholeError) if state=="scary".
|
# WormholeError) if state=="scary".
|
||||||
if self._closed_result:
|
d = self._closed_observer.when_fired() # maybe Failure
|
||||||
return defer.succeed(self._closed_result) # maybe Failure
|
if not self._closed:
|
||||||
d = defer.Deferred()
|
|
||||||
self._closed_observers.append(d)
|
|
||||||
self._boss.close() # only need to close if it wasn't already
|
self._boss.close() # only need to close if it wasn't already
|
||||||
return d
|
return d
|
||||||
|
|
||||||
|
@ -221,75 +178,56 @@ class _DeferredWormhole(object):
|
||||||
|
|
||||||
# from below
|
# from below
|
||||||
def got_welcome(self, welcome):
|
def got_welcome(self, welcome):
|
||||||
self._welcome = welcome
|
self._welcome_observer.fire_if_not_fired(welcome)
|
||||||
for d in self._welcome_observers:
|
|
||||||
d.callback(welcome)
|
|
||||||
self._welcome_observers[:] = []
|
|
||||||
def got_code(self, code):
|
def got_code(self, code):
|
||||||
self._code = code
|
self._code_observer.fire_if_not_fired(code)
|
||||||
for d in self._code_observers:
|
|
||||||
d.callback(code)
|
|
||||||
self._code_observers[:] = []
|
|
||||||
def got_key(self, key):
|
def got_key(self, key):
|
||||||
self._key = key # for derive_key()
|
self._key = key # for derive_key()
|
||||||
for d in self._key_observers:
|
self._key_observer.fire_if_not_fired(key)
|
||||||
d.callback(key)
|
|
||||||
self._key_observers[:] = []
|
|
||||||
|
|
||||||
def got_verifier(self, verifier):
|
def got_verifier(self, verifier):
|
||||||
self._verifier = verifier
|
self._verifier_observer.fire_if_not_fired(verifier)
|
||||||
for d in self._verifier_observers:
|
|
||||||
d.callback(verifier)
|
|
||||||
self._verifier_observers[:] = []
|
|
||||||
def got_versions(self, versions):
|
def got_versions(self, versions):
|
||||||
self._versions = versions
|
self._version_observer.fire_if_not_fired(versions)
|
||||||
for d in self._version_observers:
|
|
||||||
d.callback(versions)
|
|
||||||
self._version_observers[:] = []
|
|
||||||
|
|
||||||
def received(self, plaintext):
|
def received(self, plaintext):
|
||||||
if self._received_observers:
|
self._received_observer.fire(plaintext)
|
||||||
self._received_observers.pop(0).callback(plaintext)
|
|
||||||
return
|
|
||||||
self._received_data.append(plaintext)
|
|
||||||
|
|
||||||
def closed(self, result):
|
def closed(self, result):
|
||||||
|
self._closed = True
|
||||||
#print("closed", result, type(result), file=sys.stderr)
|
#print("closed", result, type(result), file=sys.stderr)
|
||||||
if isinstance(result, Exception):
|
if isinstance(result, Exception):
|
||||||
self._observer_result = self._closed_result = failure.Failure(result)
|
# everything pending gets an error, including close()
|
||||||
|
f = failure.Failure(result)
|
||||||
|
self._closed_observer.error(f)
|
||||||
else:
|
else:
|
||||||
# pending w.key()/w.verify()/w.version()/w.read() get an error
|
# everything pending except close() gets an error:
|
||||||
self._observer_result = WormholeClosed(result)
|
# w.get_code()/welcome/unverified_key/verifier/versions/message
|
||||||
|
f = failure.Failure(WormholeClosed(result))
|
||||||
# but w.close() only gets error if we're unhappy
|
# but w.close() only gets error if we're unhappy
|
||||||
self._closed_result = result
|
self._closed_observer.fire_if_not_fired(result)
|
||||||
for d in self._welcome_observers:
|
self._welcome_observer.error(f)
|
||||||
d.errback(self._observer_result)
|
self._code_observer.error(f)
|
||||||
for d in self._code_observers:
|
self._key_observer.error(f)
|
||||||
d.errback(self._observer_result)
|
self._verifier_observer.error(f)
|
||||||
for d in self._key_observers:
|
self._version_observer.error(f)
|
||||||
d.errback(self._observer_result)
|
self._received_observer.fire(f)
|
||||||
for d in self._verifier_observers:
|
|
||||||
d.errback(self._observer_result)
|
|
||||||
for d in self._version_observers:
|
|
||||||
d.errback(self._observer_result)
|
|
||||||
for d in self._received_observers:
|
|
||||||
d.errback(self._observer_result)
|
|
||||||
for d in self._closed_observers:
|
|
||||||
d.callback(self._closed_result)
|
|
||||||
|
|
||||||
|
|
||||||
def create(appid, relay_url, reactor, # use keyword args for everything else
|
def create(appid, relay_url, reactor, # use keyword args for everything else
|
||||||
versions={},
|
versions={},
|
||||||
delegate=None, journal=None, tor=None,
|
delegate=None, journal=None, tor=None,
|
||||||
timing=None,
|
timing=None,
|
||||||
stderr=sys.stderr):
|
stderr=sys.stderr,
|
||||||
|
_eventual_queue=None):
|
||||||
timing = timing or DebugTiming()
|
timing = timing or DebugTiming()
|
||||||
side = bytes_to_hexstr(os.urandom(5))
|
side = bytes_to_hexstr(os.urandom(5))
|
||||||
journal = journal or ImmediateJournal()
|
journal = journal or ImmediateJournal()
|
||||||
|
eq = _eventual_queue or EventualQueue(reactor)
|
||||||
if delegate:
|
if delegate:
|
||||||
w = _DelegatedWormhole(delegate)
|
w = _DelegatedWormhole(delegate)
|
||||||
else:
|
else:
|
||||||
w = _DeferredWormhole()
|
w = _DeferredWormhole(eq)
|
||||||
wormhole_versions = {} # will be used to indicate Wormhole capabilities
|
wormhole_versions = {} # will be used to indicate Wormhole capabilities
|
||||||
wormhole_versions["app_versions"] = versions # app-specific capabilities
|
wormhole_versions["app_versions"] = versions # app-specific capabilities
|
||||||
b = Boss(w, side, relay_url, appid, wormhole_versions,
|
b = Boss(w, side, relay_url, appid, wormhole_versions,
|
||||||
|
|
Loading…
Reference in New Issue
Block a user