Merge branch 'observers-4'
This factors out the various "get me a Deferred which fires when/if we compute a value" code from the _DeferredWormhole API calls: get_code, get_unverified_key, get_versions, get_message, etc. It uses an eventual-send for each one, which will protect against surprises when an application invokes an wormhole API from within a previous API's callback: without this, the internal wormhole state isn't guaranteed to be coherent, and crashes could result.
This commit is contained in:
commit
c5ae678417
50
src/wormhole/eventual.py
Normal file
50
src/wormhole/eventual.py
Normal file
|
@ -0,0 +1,50 @@
|
|||
# inspired-by/adapted-from Foolscap's eventual.py, which Glyph wrote for me
|
||||
# years ago.
|
||||
|
||||
from twisted.internet.defer import Deferred
|
||||
from twisted.internet.interfaces import IReactorTime
|
||||
from twisted.python import log
|
||||
|
||||
class EventualQueue(object):
|
||||
def __init__(self, clock):
|
||||
# pass clock=reactor unless you're testing
|
||||
self._clock = IReactorTime(clock)
|
||||
self._calls = []
|
||||
self._flush_d = None
|
||||
self._timer = None
|
||||
|
||||
def eventually(self, f, *args, **kwargs):
|
||||
self._calls.append( (f, args, kwargs) )
|
||||
if not self._timer:
|
||||
self._timer = self._clock.callLater(0, self._turn)
|
||||
|
||||
def fire_eventually(self, value=None):
|
||||
d = Deferred()
|
||||
self.eventually(d.callback, value)
|
||||
return d
|
||||
|
||||
def _turn(self):
|
||||
while self._calls:
|
||||
(f, args, kwargs) = self._calls.pop(0)
|
||||
try:
|
||||
f(*args, **kwargs)
|
||||
except:
|
||||
log.err()
|
||||
self._timer = None
|
||||
d, self._flush_d = self._flush_d, None
|
||||
if d:
|
||||
d.callback(None)
|
||||
|
||||
def flush_sync(self):
|
||||
# if you have control over the Clock, this will synchronously flush the
|
||||
# queue
|
||||
assert self._clock.advance, "needs clock=twisted.internet.task.Clock()"
|
||||
while self._calls:
|
||||
self._clock.advance(0)
|
||||
|
||||
def flush(self):
|
||||
# this is for unit tests, not application code
|
||||
assert not self._flush_d, "only one flush at a time"
|
||||
self._flush_d = Deferred()
|
||||
self.eventually(lambda: None)
|
||||
return self._flush_d
|
69
src/wormhole/observer.py
Normal file
69
src/wormhole/observer.py
Normal file
|
@ -0,0 +1,69 @@
|
|||
from __future__ import unicode_literals, print_function
|
||||
from twisted.internet.defer import Deferred
|
||||
from twisted.python.failure import Failure
|
||||
|
||||
NoResult = object()
|
||||
|
||||
class OneShotObserver(object):
|
||||
def __init__(self, eventual_queue):
|
||||
self._eq = eventual_queue
|
||||
self._result = NoResult
|
||||
self._observers = [] # list of Deferreds
|
||||
|
||||
def when_fired(self):
|
||||
d = Deferred()
|
||||
self._observers.append(d)
|
||||
self._maybe_call_observers()
|
||||
return d
|
||||
|
||||
def fire(self, result):
|
||||
assert self._result is NoResult
|
||||
self._result = result
|
||||
self._maybe_call_observers()
|
||||
|
||||
def _maybe_call_observers(self):
|
||||
if self._result is NoResult:
|
||||
return
|
||||
observers, self._observers = self._observers, []
|
||||
for d in observers:
|
||||
self._eq.eventually(d.callback, self._result)
|
||||
|
||||
def error(self, f):
|
||||
# errors will override an existing result
|
||||
assert isinstance(f, Failure)
|
||||
self._result = f
|
||||
self._maybe_call_observers()
|
||||
|
||||
def fire_if_not_fired(self, result):
|
||||
if self._result is NoResult:
|
||||
self.fire(result)
|
||||
|
||||
class SequenceObserver(object):
|
||||
def __init__(self, eventual_queue):
|
||||
self._eq = eventual_queue
|
||||
self._error = None
|
||||
self._results = []
|
||||
self._observers = []
|
||||
|
||||
def when_next_event(self):
|
||||
d = Deferred()
|
||||
if self._error:
|
||||
self._eq.eventually(d.errback, self._error)
|
||||
elif self._results:
|
||||
result = self._results.pop(0)
|
||||
self._eq.eventually(d.callback, result)
|
||||
else:
|
||||
self._observers.append(d)
|
||||
return d
|
||||
|
||||
def fire(self, result):
|
||||
if isinstance(result, Failure):
|
||||
self._error = result
|
||||
for d in self._observers:
|
||||
self._eq.eventually(d.errback, self._error)
|
||||
self._observers = []
|
||||
else:
|
||||
self._results.append(result)
|
||||
if self._observers:
|
||||
d = self._observers.pop(0)
|
||||
self._eq.eventually(d.callback, self._results.pop(0))
|
|
@ -128,10 +128,3 @@ def poll_until(predicate):
|
|||
d = defer.Deferred()
|
||||
reactor.callLater(0.001, d.callback, None)
|
||||
yield d
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def pause_one_tick():
|
||||
# return a Deferred that won't fire until at least the next reactor tick
|
||||
d = defer.Deferred()
|
||||
reactor.callLater(0.001, d.callback, None)
|
||||
yield d
|
||||
|
|
57
src/wormhole/test/test_eventual.py
Normal file
57
src/wormhole/test/test_eventual.py
Normal file
|
@ -0,0 +1,57 @@
|
|||
from __future__ import print_function, unicode_literals
|
||||
import mock
|
||||
from twisted.trial import unittest
|
||||
from twisted.internet import reactor
|
||||
from twisted.internet.task import Clock
|
||||
from twisted.internet.defer import Deferred, inlineCallbacks
|
||||
from ..eventual import EventualQueue
|
||||
|
||||
class IntentionalError(Exception):
|
||||
pass
|
||||
|
||||
class Eventual(unittest.TestCase, object):
|
||||
def test_eventually(self):
|
||||
c = Clock()
|
||||
eq = EventualQueue(c)
|
||||
c1 = mock.Mock()
|
||||
eq.eventually(c1, "arg1", "arg2", kwarg1="kw1")
|
||||
eq.eventually(c1, "arg3", "arg4", kwarg5="kw5")
|
||||
d2 = eq.fire_eventually()
|
||||
d3 = eq.fire_eventually("value")
|
||||
self.assertEqual(c1.mock_calls, [])
|
||||
self.assertNoResult(d2)
|
||||
self.assertNoResult(d3)
|
||||
|
||||
eq.flush_sync()
|
||||
self.assertEqual(c1.mock_calls,
|
||||
[mock.call("arg1", "arg2", kwarg1="kw1"),
|
||||
mock.call("arg3", "arg4", kwarg5="kw5")])
|
||||
self.assertEqual(self.successResultOf(d2), None)
|
||||
self.assertEqual(self.successResultOf(d3), "value")
|
||||
|
||||
def test_error(self):
|
||||
c = Clock()
|
||||
eq = EventualQueue(c)
|
||||
c1 = mock.Mock(side_effect=IntentionalError)
|
||||
eq.eventually(c1, "arg1", "arg2", kwarg1="kw1")
|
||||
self.assertEqual(c1.mock_calls, [])
|
||||
|
||||
eq.flush_sync()
|
||||
self.assertEqual(c1.mock_calls,
|
||||
[mock.call("arg1", "arg2", kwarg1="kw1")])
|
||||
|
||||
self.flushLoggedErrors(IntentionalError)
|
||||
|
||||
@inlineCallbacks
|
||||
def test_flush(self):
|
||||
eq = EventualQueue(reactor)
|
||||
d1 = eq.fire_eventually()
|
||||
d2 = Deferred()
|
||||
def _more(res):
|
||||
eq.eventually(d2.callback, None)
|
||||
d1.addCallback(_more)
|
||||
yield eq.flush()
|
||||
# d1 will fire, which will queue d2 to fire, and the flush() ought to
|
||||
# wait for d2 too
|
||||
self.successResultOf(d2)
|
||||
|
122
src/wormhole/test/test_observer.py
Normal file
122
src/wormhole/test/test_observer.py
Normal file
|
@ -0,0 +1,122 @@
|
|||
from twisted.trial import unittest
|
||||
from twisted.internet.task import Clock
|
||||
from twisted.python.failure import Failure
|
||||
from ..eventual import EventualQueue
|
||||
from ..observer import OneShotObserver, SequenceObserver
|
||||
|
||||
class OneShot(unittest.TestCase):
|
||||
def test_fire(self):
|
||||
c = Clock()
|
||||
eq = EventualQueue(c)
|
||||
o = OneShotObserver(eq)
|
||||
res = object()
|
||||
d1 = o.when_fired()
|
||||
eq.flush_sync()
|
||||
self.assertNoResult(d1)
|
||||
o.fire(res)
|
||||
eq.flush_sync()
|
||||
self.assertIdentical(self.successResultOf(d1), res)
|
||||
d2 = o.when_fired()
|
||||
eq.flush_sync()
|
||||
self.assertIdentical(self.successResultOf(d2), res)
|
||||
o.fire_if_not_fired(object())
|
||||
eq.flush_sync()
|
||||
|
||||
def test_fire_if_not_fired(self):
|
||||
c = Clock()
|
||||
eq = EventualQueue(c)
|
||||
o = OneShotObserver(eq)
|
||||
res1 = object()
|
||||
res2 = object()
|
||||
d1 = o.when_fired()
|
||||
eq.flush_sync()
|
||||
self.assertNoResult(d1)
|
||||
o.fire_if_not_fired(res1)
|
||||
o.fire_if_not_fired(res2)
|
||||
eq.flush_sync()
|
||||
self.assertIdentical(self.successResultOf(d1), res1)
|
||||
|
||||
def test_error_before_firing(self):
|
||||
c = Clock()
|
||||
eq = EventualQueue(c)
|
||||
o = OneShotObserver(eq)
|
||||
f = Failure(ValueError("oops"))
|
||||
d1 = o.when_fired()
|
||||
eq.flush_sync()
|
||||
self.assertNoResult(d1)
|
||||
o.error(f)
|
||||
eq.flush_sync()
|
||||
self.assertIdentical(self.failureResultOf(d1), f)
|
||||
d2 = o.when_fired()
|
||||
eq.flush_sync()
|
||||
self.assertIdentical(self.failureResultOf(d2), f)
|
||||
|
||||
def test_error_after_firing(self):
|
||||
c = Clock()
|
||||
eq = EventualQueue(c)
|
||||
o = OneShotObserver(eq)
|
||||
res = object()
|
||||
f = Failure(ValueError("oops"))
|
||||
|
||||
o.fire(res)
|
||||
eq.flush_sync()
|
||||
d1 = o.when_fired()
|
||||
o.error(f)
|
||||
d2 = o.when_fired()
|
||||
eq.flush_sync()
|
||||
self.assertIdentical(self.successResultOf(d1), res)
|
||||
self.assertIdentical(self.failureResultOf(d2), f)
|
||||
|
||||
|
||||
class Sequence(unittest.TestCase):
|
||||
def test_fire(self):
|
||||
c = Clock()
|
||||
eq = EventualQueue(c)
|
||||
o = SequenceObserver(eq)
|
||||
d1 = o.when_next_event()
|
||||
eq.flush_sync()
|
||||
self.assertNoResult(d1)
|
||||
d2 = o.when_next_event()
|
||||
eq.flush_sync()
|
||||
self.assertNoResult(d1)
|
||||
self.assertNoResult(d2)
|
||||
|
||||
ev1 = object()
|
||||
ev2 = object()
|
||||
o.fire(ev1)
|
||||
eq.flush_sync()
|
||||
self.assertIdentical(self.successResultOf(d1), ev1)
|
||||
self.assertNoResult(d2)
|
||||
|
||||
o.fire(ev2)
|
||||
eq.flush_sync()
|
||||
self.assertIdentical(self.successResultOf(d2), ev2)
|
||||
|
||||
ev3 = object()
|
||||
ev4 = object()
|
||||
o.fire(ev3)
|
||||
o.fire(ev4)
|
||||
|
||||
d3 = o.when_next_event()
|
||||
eq.flush_sync()
|
||||
self.assertIdentical(self.successResultOf(d3), ev3)
|
||||
|
||||
d4 = o.when_next_event()
|
||||
eq.flush_sync()
|
||||
self.assertIdentical(self.successResultOf(d4), ev4)
|
||||
|
||||
def test_error(self):
|
||||
c = Clock()
|
||||
eq = EventualQueue(c)
|
||||
o = SequenceObserver(eq)
|
||||
d1 = o.when_next_event()
|
||||
eq.flush_sync()
|
||||
self.assertNoResult(d1)
|
||||
f = Failure(ValueError("oops"))
|
||||
o.fire(f)
|
||||
eq.flush_sync()
|
||||
self.assertIdentical(self.failureResultOf(d1), f)
|
||||
d2 = o.when_next_event()
|
||||
eq.flush_sync()
|
||||
self.assertIdentical(self.failureResultOf(d2), f)
|
||||
|
|
@ -5,12 +5,13 @@ from twisted.trial import unittest
|
|||
from twisted.internet import reactor
|
||||
from twisted.internet.defer import gatherResults, inlineCallbacks, returnValue
|
||||
from twisted.internet.error import ConnectionRefusedError
|
||||
from .common import ServerBase, poll_until, pause_one_tick
|
||||
from .common import ServerBase, poll_until
|
||||
from .. import wormhole, _rendezvous
|
||||
from ..errors import (WrongPasswordError, ServerConnectionError,
|
||||
KeyFormatError, WormholeClosed, LonelyError,
|
||||
NoKeyError, OnlyOneCodeError)
|
||||
from ..transit import allocate_tcp_port
|
||||
from ..eventual import EventualQueue
|
||||
|
||||
APPID = "appid"
|
||||
|
||||
|
@ -159,9 +160,6 @@ class Wormholes(ServerBase, unittest.TestCase):
|
|||
verifier2 = yield w2.get_verifier()
|
||||
self.assertEqual(verifier1, verifier2)
|
||||
|
||||
self.successResultOf(w1.get_unverified_key())
|
||||
self.successResultOf(w2.get_unverified_key())
|
||||
|
||||
versions1 = yield w1.get_versions()
|
||||
versions2 = yield w2.get_versions()
|
||||
# app-versions are exercised properly in test_versions, this just
|
||||
|
@ -186,18 +184,22 @@ class Wormholes(ServerBase, unittest.TestCase):
|
|||
|
||||
@inlineCallbacks
|
||||
def test_get_code_early(self):
|
||||
w1 = wormhole.create(APPID, self.relayurl, reactor)
|
||||
eq = EventualQueue(reactor)
|
||||
w1 = wormhole.create(APPID, self.relayurl, reactor, _eventual_queue=eq)
|
||||
d = w1.get_code()
|
||||
w1.set_code("1-abc")
|
||||
yield eq.flush()
|
||||
code = self.successResultOf(d)
|
||||
self.assertEqual(code, "1-abc")
|
||||
yield self.assertFailure(w1.close(), LonelyError)
|
||||
|
||||
@inlineCallbacks
|
||||
def test_get_code_late(self):
|
||||
w1 = wormhole.create(APPID, self.relayurl, reactor)
|
||||
eq = EventualQueue(reactor)
|
||||
w1 = wormhole.create(APPID, self.relayurl, reactor, _eventual_queue=eq)
|
||||
w1.set_code("1-abc")
|
||||
d = w1.get_code()
|
||||
yield eq.flush()
|
||||
code = self.successResultOf(d)
|
||||
self.assertEqual(code, "1-abc")
|
||||
yield self.assertFailure(w1.close(), LonelyError)
|
||||
|
@ -323,8 +325,9 @@ class Wormholes(ServerBase, unittest.TestCase):
|
|||
|
||||
@inlineCallbacks
|
||||
def test_closed(self):
|
||||
w1 = wormhole.create(APPID, self.relayurl, reactor)
|
||||
w2 = wormhole.create(APPID, self.relayurl, reactor)
|
||||
eq = EventualQueue(reactor)
|
||||
w1 = wormhole.create(APPID, self.relayurl, reactor, _eventual_queue=eq)
|
||||
w2 = wormhole.create(APPID, self.relayurl, reactor, _eventual_queue=eq)
|
||||
w1.set_code("123-foo")
|
||||
w2.set_code("123-foo")
|
||||
|
||||
|
@ -335,14 +338,14 @@ class Wormholes(ServerBase, unittest.TestCase):
|
|||
yield w1.close()
|
||||
yield w2.close()
|
||||
|
||||
# once closed, all Deferred-yielding API calls get an immediate error
|
||||
self.failureResultOf(w1.get_welcome(), WormholeClosed)
|
||||
f = self.failureResultOf(w1.get_code(), WormholeClosed)
|
||||
self.assertEqual(f.value.args[0], "happy")
|
||||
self.failureResultOf(w1.get_unverified_key(), WormholeClosed)
|
||||
self.failureResultOf(w1.get_verifier(), WormholeClosed)
|
||||
self.failureResultOf(w1.get_versions(), WormholeClosed)
|
||||
self.failureResultOf(w1.get_message(), WormholeClosed)
|
||||
# once closed, all Deferred-yielding API calls get an prompt error
|
||||
yield self.assertFailure(w1.get_welcome(), WormholeClosed)
|
||||
e = yield self.assertFailure(w1.get_code(), WormholeClosed)
|
||||
self.assertEqual(e.args[0], "happy")
|
||||
yield self.assertFailure(w1.get_unverified_key(), WormholeClosed)
|
||||
yield self.assertFailure(w1.get_verifier(), WormholeClosed)
|
||||
yield self.assertFailure(w1.get_versions(), WormholeClosed)
|
||||
yield self.assertFailure(w1.get_message(), WormholeClosed)
|
||||
|
||||
@inlineCallbacks
|
||||
def test_closed_idle(self):
|
||||
|
@ -360,17 +363,18 @@ class Wormholes(ServerBase, unittest.TestCase):
|
|||
|
||||
yield self.assertFailure(w1.close(), LonelyError)
|
||||
|
||||
self.failureResultOf(d_welcome, LonelyError)
|
||||
self.failureResultOf(d_code, LonelyError)
|
||||
self.failureResultOf(d_key, LonelyError)
|
||||
self.failureResultOf(d_verifier, LonelyError)
|
||||
self.failureResultOf(d_versions, LonelyError)
|
||||
self.failureResultOf(d_message, LonelyError)
|
||||
yield self.assertFailure(d_welcome, LonelyError)
|
||||
yield self.assertFailure(d_code, LonelyError)
|
||||
yield self.assertFailure(d_key, LonelyError)
|
||||
yield self.assertFailure(d_verifier, LonelyError)
|
||||
yield self.assertFailure(d_versions, LonelyError)
|
||||
yield self.assertFailure(d_message, LonelyError)
|
||||
|
||||
@inlineCallbacks
|
||||
def test_wrong_password(self):
|
||||
w1 = wormhole.create(APPID, self.relayurl, reactor)
|
||||
w2 = wormhole.create(APPID, self.relayurl, reactor)
|
||||
eq = EventualQueue(reactor)
|
||||
w1 = wormhole.create(APPID, self.relayurl, reactor, _eventual_queue=eq)
|
||||
w2 = wormhole.create(APPID, self.relayurl, reactor, _eventual_queue=eq)
|
||||
w1.allocate_code()
|
||||
code = yield w1.get_code()
|
||||
w2.set_code(code+"not")
|
||||
|
@ -403,9 +407,8 @@ class Wormholes(ServerBase, unittest.TestCase):
|
|||
# wait for each side to notice the failure
|
||||
yield self.assertFailure(w1.get_verifier(), WrongPasswordError)
|
||||
yield self.assertFailure(w2.get_verifier(), WrongPasswordError)
|
||||
# and then wait for the rest of the loops to fire. if we had+used
|
||||
# eventual-send, this wouldn't be a problem
|
||||
yield pause_one_tick()
|
||||
# the rest of the loops should fire within the next tick
|
||||
yield eq.flush()
|
||||
|
||||
# now all the rest should have fired already
|
||||
self.failureResultOf(d1_verified, WrongPasswordError)
|
||||
|
@ -420,27 +423,27 @@ class Wormholes(ServerBase, unittest.TestCase):
|
|||
# before we close
|
||||
|
||||
# any new calls in the error state should immediately fail
|
||||
self.failureResultOf(w1.get_unverified_key(), WrongPasswordError)
|
||||
self.failureResultOf(w1.get_verifier(), WrongPasswordError)
|
||||
self.failureResultOf(w1.get_versions(), WrongPasswordError)
|
||||
self.failureResultOf(w1.get_message(), WrongPasswordError)
|
||||
self.failureResultOf(w2.get_unverified_key(), WrongPasswordError)
|
||||
self.failureResultOf(w2.get_verifier(), WrongPasswordError)
|
||||
self.failureResultOf(w2.get_versions(), WrongPasswordError)
|
||||
self.failureResultOf(w2.get_message(), WrongPasswordError)
|
||||
yield self.assertFailure(w1.get_unverified_key(), WrongPasswordError)
|
||||
yield self.assertFailure(w1.get_verifier(), WrongPasswordError)
|
||||
yield self.assertFailure(w1.get_versions(), WrongPasswordError)
|
||||
yield self.assertFailure(w1.get_message(), WrongPasswordError)
|
||||
yield self.assertFailure(w2.get_unverified_key(), WrongPasswordError)
|
||||
yield self.assertFailure(w2.get_verifier(), WrongPasswordError)
|
||||
yield self.assertFailure(w2.get_versions(), WrongPasswordError)
|
||||
yield self.assertFailure(w2.get_message(), WrongPasswordError)
|
||||
|
||||
yield self.assertFailure(w1.close(), WrongPasswordError)
|
||||
yield self.assertFailure(w2.close(), WrongPasswordError)
|
||||
|
||||
# API calls should still get the error, not WormholeClosed
|
||||
self.failureResultOf(w1.get_unverified_key(), WrongPasswordError)
|
||||
self.failureResultOf(w1.get_verifier(), WrongPasswordError)
|
||||
self.failureResultOf(w1.get_versions(), WrongPasswordError)
|
||||
self.failureResultOf(w1.get_message(), WrongPasswordError)
|
||||
self.failureResultOf(w2.get_unverified_key(), WrongPasswordError)
|
||||
self.failureResultOf(w2.get_verifier(), WrongPasswordError)
|
||||
self.failureResultOf(w2.get_versions(), WrongPasswordError)
|
||||
self.failureResultOf(w2.get_message(), WrongPasswordError)
|
||||
yield self.assertFailure(w1.get_unverified_key(), WrongPasswordError)
|
||||
yield self.assertFailure(w1.get_verifier(), WrongPasswordError)
|
||||
yield self.assertFailure(w1.get_versions(), WrongPasswordError)
|
||||
yield self.assertFailure(w1.get_message(), WrongPasswordError)
|
||||
yield self.assertFailure(w2.get_unverified_key(), WrongPasswordError)
|
||||
yield self.assertFailure(w2.get_verifier(), WrongPasswordError)
|
||||
yield self.assertFailure(w2.get_versions(), WrongPasswordError)
|
||||
yield self.assertFailure(w2.get_message(), WrongPasswordError)
|
||||
|
||||
@inlineCallbacks
|
||||
def test_wrong_password_with_spaces(self):
|
||||
|
@ -493,8 +496,9 @@ class Wormholes(ServerBase, unittest.TestCase):
|
|||
|
||||
@inlineCallbacks
|
||||
def test_verifier(self):
|
||||
w1 = wormhole.create(APPID, self.relayurl, reactor)
|
||||
w2 = wormhole.create(APPID, self.relayurl, reactor)
|
||||
eq = EventualQueue(reactor)
|
||||
w1 = wormhole.create(APPID, self.relayurl, reactor, _eventual_queue=eq)
|
||||
w2 = wormhole.create(APPID, self.relayurl, reactor, _eventual_queue=eq)
|
||||
w1.allocate_code()
|
||||
code = yield w1.get_code()
|
||||
w2.set_code(code)
|
||||
|
@ -510,7 +514,9 @@ class Wormholes(ServerBase, unittest.TestCase):
|
|||
self.assertEqual(dataY, b"data1")
|
||||
|
||||
# calling get_verifier() this late should fire right away
|
||||
v1_late = self.successResultOf(w2.get_verifier())
|
||||
d = w2.get_verifier()
|
||||
yield eq.flush()
|
||||
v1_late = self.successResultOf(d)
|
||||
self.assertEqual(v1_late, v1)
|
||||
|
||||
yield w1.close()
|
||||
|
@ -644,26 +650,30 @@ class Reconnection(ServerBase, unittest.TestCase):
|
|||
self.assertEqual(c2, "happy")
|
||||
|
||||
class InitialFailure(unittest.TestCase):
|
||||
def assertSCEResultOf(self, d, innerType):
|
||||
@inlineCallbacks
|
||||
def assertSCEFailure(self, eq, d, innerType):
|
||||
yield eq.flush()
|
||||
f = self.failureResultOf(d, ServerConnectionError)
|
||||
inner = f.value.reason
|
||||
self.assertIsInstance(inner, innerType)
|
||||
return inner
|
||||
returnValue(inner)
|
||||
|
||||
@inlineCallbacks
|
||||
def test_bad_dns(self):
|
||||
eq = EventualQueue(reactor)
|
||||
# point at a URL that will never connect
|
||||
w = wormhole.create(APPID, "ws://%%%.example.org:4000/v1", reactor)
|
||||
w = wormhole.create(APPID, "ws://%%%.example.org:4000/v1",
|
||||
reactor, _eventual_queue=eq)
|
||||
# that should have already received an error, when it tried to
|
||||
# resolve the bogus DNS name. All API calls will return an error.
|
||||
e = yield self.assertFailure(w.get_unverified_key(),
|
||||
ServerConnectionError)
|
||||
self.assertIsInstance(e.reason, ValueError)
|
||||
|
||||
e = yield self.assertSCEFailure(eq, w.get_unverified_key(), ValueError)
|
||||
self.assertIsInstance(e, ValueError)
|
||||
self.assertEqual(str(e), "invalid hostname: %%%.example.org")
|
||||
self.assertSCEResultOf(w.get_code(), ValueError)
|
||||
self.assertSCEResultOf(w.get_verifier(), ValueError)
|
||||
self.assertSCEResultOf(w.get_versions(), ValueError)
|
||||
self.assertSCEResultOf(w.get_message(), ValueError)
|
||||
yield self.assertSCEFailure(eq, w.get_code(), ValueError)
|
||||
yield self.assertSCEFailure(eq, w.get_verifier(), ValueError)
|
||||
yield self.assertSCEFailure(eq, w.get_versions(), ValueError)
|
||||
yield self.assertSCEFailure(eq, w.get_message(), ValueError)
|
||||
|
||||
@inlineCallbacks
|
||||
def assertSCE(self, d, innerType):
|
||||
|
|
|
@ -3,9 +3,10 @@ import os, sys
|
|||
from attr import attrs, attrib
|
||||
from zope.interface import implementer
|
||||
from twisted.python import failure
|
||||
from twisted.internet import defer
|
||||
from ._interfaces import IWormhole, IDeferredWormhole
|
||||
from .util import bytes_to_hexstr
|
||||
from .eventual import EventualQueue
|
||||
from .observer import OneShotObserver, SequenceObserver
|
||||
from .timing import DebugTiming
|
||||
from .journal import ImmediateJournal
|
||||
from ._boss import Boss
|
||||
|
@ -100,22 +101,16 @@ class _DelegatedWormhole(object):
|
|||
|
||||
@implementer(IWormhole, IDeferredWormhole)
|
||||
class _DeferredWormhole(object):
|
||||
def __init__(self):
|
||||
self._welcome = None
|
||||
self._welcome_observers = []
|
||||
self._code = None
|
||||
self._code_observers = []
|
||||
def __init__(self, eq):
|
||||
self._welcome_observer = OneShotObserver(eq)
|
||||
self._code_observer = OneShotObserver(eq)
|
||||
self._key = None
|
||||
self._key_observers = []
|
||||
self._verifier = None
|
||||
self._verifier_observers = []
|
||||
self._versions = None
|
||||
self._version_observers = []
|
||||
self._received_data = []
|
||||
self._received_observers = []
|
||||
self._observer_result = None
|
||||
self._closed_result = None
|
||||
self._closed_observers = []
|
||||
self._key_observer = OneShotObserver(eq)
|
||||
self._verifier_observer = OneShotObserver(eq)
|
||||
self._version_observer = OneShotObserver(eq)
|
||||
self._received_observer = SequenceObserver(eq)
|
||||
self._closed = False
|
||||
self._closed_observer = OneShotObserver(eq)
|
||||
|
||||
def _set_boss(self, boss):
|
||||
self._boss = boss
|
||||
|
@ -127,58 +122,22 @@ class _DeferredWormhole(object):
|
|||
# the process that will cause it to fire, but forbidding that
|
||||
# ordering would make it easier to cause programming errors that
|
||||
# forget to trigger it entirely.
|
||||
if self._observer_result is not None:
|
||||
return defer.fail(self._observer_result)
|
||||
if self._code is not None:
|
||||
return defer.succeed(self._code)
|
||||
d = defer.Deferred()
|
||||
self._code_observers.append(d)
|
||||
return d
|
||||
return self._code_observer.when_fired()
|
||||
|
||||
def get_welcome(self):
|
||||
if self._observer_result is not None:
|
||||
return defer.fail(self._observer_result)
|
||||
if self._welcome is not None:
|
||||
return defer.succeed(self._welcome)
|
||||
d = defer.Deferred()
|
||||
self._welcome_observers.append(d)
|
||||
return d
|
||||
return self._welcome_observer.when_fired()
|
||||
|
||||
def get_unverified_key(self):
|
||||
if self._observer_result is not None:
|
||||
return defer.fail(self._observer_result)
|
||||
if self._key is not None:
|
||||
return defer.succeed(self._key)
|
||||
d = defer.Deferred()
|
||||
self._key_observers.append(d)
|
||||
return d
|
||||
return self._key_observer.when_fired()
|
||||
|
||||
def get_verifier(self):
|
||||
if self._observer_result is not None:
|
||||
return defer.fail(self._observer_result)
|
||||
if self._verifier is not None:
|
||||
return defer.succeed(self._verifier)
|
||||
d = defer.Deferred()
|
||||
self._verifier_observers.append(d)
|
||||
return d
|
||||
return self._verifier_observer.when_fired()
|
||||
|
||||
def get_versions(self):
|
||||
if self._observer_result is not None:
|
||||
return defer.fail(self._observer_result)
|
||||
if self._versions is not None:
|
||||
return defer.succeed(self._versions)
|
||||
d = defer.Deferred()
|
||||
self._version_observers.append(d)
|
||||
return d
|
||||
return self._version_observer.when_fired()
|
||||
|
||||
def get_message(self):
|
||||
if self._observer_result is not None:
|
||||
return defer.fail(self._observer_result)
|
||||
if self._received_data:
|
||||
return defer.succeed(self._received_data.pop(0))
|
||||
d = defer.Deferred()
|
||||
self._received_observers.append(d)
|
||||
return d
|
||||
return self._received_observer.when_next_event()
|
||||
|
||||
def allocate_code(self, code_length=2):
|
||||
self._boss.allocate_code(code_length)
|
||||
|
@ -207,10 +166,8 @@ class _DeferredWormhole(object):
|
|||
# fails with WormholeError unless we established a connection
|
||||
# (state=="happy"). Fails with WrongPasswordError (a subclass of
|
||||
# WormholeError) if state=="scary".
|
||||
if self._closed_result:
|
||||
return defer.succeed(self._closed_result) # maybe Failure
|
||||
d = defer.Deferred()
|
||||
self._closed_observers.append(d)
|
||||
d = self._closed_observer.when_fired() # maybe Failure
|
||||
if not self._closed:
|
||||
self._boss.close() # only need to close if it wasn't already
|
||||
return d
|
||||
|
||||
|
@ -221,75 +178,56 @@ class _DeferredWormhole(object):
|
|||
|
||||
# from below
|
||||
def got_welcome(self, welcome):
|
||||
self._welcome = welcome
|
||||
for d in self._welcome_observers:
|
||||
d.callback(welcome)
|
||||
self._welcome_observers[:] = []
|
||||
self._welcome_observer.fire_if_not_fired(welcome)
|
||||
def got_code(self, code):
|
||||
self._code = code
|
||||
for d in self._code_observers:
|
||||
d.callback(code)
|
||||
self._code_observers[:] = []
|
||||
self._code_observer.fire_if_not_fired(code)
|
||||
def got_key(self, key):
|
||||
self._key = key # for derive_key()
|
||||
for d in self._key_observers:
|
||||
d.callback(key)
|
||||
self._key_observers[:] = []
|
||||
self._key_observer.fire_if_not_fired(key)
|
||||
|
||||
def got_verifier(self, verifier):
|
||||
self._verifier = verifier
|
||||
for d in self._verifier_observers:
|
||||
d.callback(verifier)
|
||||
self._verifier_observers[:] = []
|
||||
self._verifier_observer.fire_if_not_fired(verifier)
|
||||
def got_versions(self, versions):
|
||||
self._versions = versions
|
||||
for d in self._version_observers:
|
||||
d.callback(versions)
|
||||
self._version_observers[:] = []
|
||||
self._version_observer.fire_if_not_fired(versions)
|
||||
|
||||
def received(self, plaintext):
|
||||
if self._received_observers:
|
||||
self._received_observers.pop(0).callback(plaintext)
|
||||
return
|
||||
self._received_data.append(plaintext)
|
||||
self._received_observer.fire(plaintext)
|
||||
|
||||
def closed(self, result):
|
||||
self._closed = True
|
||||
#print("closed", result, type(result), file=sys.stderr)
|
||||
if isinstance(result, Exception):
|
||||
self._observer_result = self._closed_result = failure.Failure(result)
|
||||
# everything pending gets an error, including close()
|
||||
f = failure.Failure(result)
|
||||
self._closed_observer.error(f)
|
||||
else:
|
||||
# pending w.key()/w.verify()/w.version()/w.read() get an error
|
||||
self._observer_result = WormholeClosed(result)
|
||||
# everything pending except close() gets an error:
|
||||
# w.get_code()/welcome/unverified_key/verifier/versions/message
|
||||
f = failure.Failure(WormholeClosed(result))
|
||||
# but w.close() only gets error if we're unhappy
|
||||
self._closed_result = result
|
||||
for d in self._welcome_observers:
|
||||
d.errback(self._observer_result)
|
||||
for d in self._code_observers:
|
||||
d.errback(self._observer_result)
|
||||
for d in self._key_observers:
|
||||
d.errback(self._observer_result)
|
||||
for d in self._verifier_observers:
|
||||
d.errback(self._observer_result)
|
||||
for d in self._version_observers:
|
||||
d.errback(self._observer_result)
|
||||
for d in self._received_observers:
|
||||
d.errback(self._observer_result)
|
||||
for d in self._closed_observers:
|
||||
d.callback(self._closed_result)
|
||||
self._closed_observer.fire_if_not_fired(result)
|
||||
self._welcome_observer.error(f)
|
||||
self._code_observer.error(f)
|
||||
self._key_observer.error(f)
|
||||
self._verifier_observer.error(f)
|
||||
self._version_observer.error(f)
|
||||
self._received_observer.fire(f)
|
||||
|
||||
|
||||
def create(appid, relay_url, reactor, # use keyword args for everything else
|
||||
versions={},
|
||||
delegate=None, journal=None, tor=None,
|
||||
timing=None,
|
||||
stderr=sys.stderr):
|
||||
stderr=sys.stderr,
|
||||
_eventual_queue=None):
|
||||
timing = timing or DebugTiming()
|
||||
side = bytes_to_hexstr(os.urandom(5))
|
||||
journal = journal or ImmediateJournal()
|
||||
eq = _eventual_queue or EventualQueue(reactor)
|
||||
if delegate:
|
||||
w = _DelegatedWormhole(delegate)
|
||||
else:
|
||||
w = _DeferredWormhole()
|
||||
w = _DeferredWormhole(eq)
|
||||
wormhole_versions = {} # will be used to indicate Wormhole capabilities
|
||||
wormhole_versions["app_versions"] = versions # app-specific capabilities
|
||||
b = Boss(w, side, relay_url, appid, wormhole_versions,
|
||||
|
|
Loading…
Reference in New Issue
Block a user