wormhole: switch to observers for all APIs
Tests can pass an EventualQueue into wormhole.create(), to override the default. This lets the tests flush the queue without using a haphazard real-time delay. closes #23 (in fact, we added multiple-Deferreds-per-API a while ago, but this does it in a much cleaner fashion, and with the safety of an eventual-send)
This commit is contained in:
parent
caabb3510c
commit
be47f53e7c
|
@ -128,10 +128,3 @@ def poll_until(predicate):
|
|||
d = defer.Deferred()
|
||||
reactor.callLater(0.001, d.callback, None)
|
||||
yield d
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def pause_one_tick():
|
||||
# return a Deferred that won't fire until at least the next reactor tick
|
||||
d = defer.Deferred()
|
||||
reactor.callLater(0.001, d.callback, None)
|
||||
yield d
|
||||
|
|
|
@ -5,12 +5,13 @@ from twisted.trial import unittest
|
|||
from twisted.internet import reactor
|
||||
from twisted.internet.defer import gatherResults, inlineCallbacks, returnValue
|
||||
from twisted.internet.error import ConnectionRefusedError
|
||||
from .common import ServerBase, poll_until, pause_one_tick
|
||||
from .common import ServerBase, poll_until
|
||||
from .. import wormhole, _rendezvous
|
||||
from ..errors import (WrongPasswordError, ServerConnectionError,
|
||||
KeyFormatError, WormholeClosed, LonelyError,
|
||||
NoKeyError, OnlyOneCodeError)
|
||||
from ..transit import allocate_tcp_port
|
||||
from ..eventual import EventualQueue
|
||||
|
||||
APPID = "appid"
|
||||
|
||||
|
@ -159,9 +160,6 @@ class Wormholes(ServerBase, unittest.TestCase):
|
|||
verifier2 = yield w2.get_verifier()
|
||||
self.assertEqual(verifier1, verifier2)
|
||||
|
||||
self.successResultOf(w1.get_unverified_key())
|
||||
self.successResultOf(w2.get_unverified_key())
|
||||
|
||||
versions1 = yield w1.get_versions()
|
||||
versions2 = yield w2.get_versions()
|
||||
# app-versions are exercised properly in test_versions, this just
|
||||
|
@ -186,18 +184,22 @@ class Wormholes(ServerBase, unittest.TestCase):
|
|||
|
||||
@inlineCallbacks
|
||||
def test_get_code_early(self):
|
||||
w1 = wormhole.create(APPID, self.relayurl, reactor)
|
||||
eq = EventualQueue(reactor)
|
||||
w1 = wormhole.create(APPID, self.relayurl, reactor, _eventual_queue=eq)
|
||||
d = w1.get_code()
|
||||
w1.set_code("1-abc")
|
||||
yield eq.flush()
|
||||
code = self.successResultOf(d)
|
||||
self.assertEqual(code, "1-abc")
|
||||
yield self.assertFailure(w1.close(), LonelyError)
|
||||
|
||||
@inlineCallbacks
|
||||
def test_get_code_late(self):
|
||||
w1 = wormhole.create(APPID, self.relayurl, reactor)
|
||||
eq = EventualQueue(reactor)
|
||||
w1 = wormhole.create(APPID, self.relayurl, reactor, _eventual_queue=eq)
|
||||
w1.set_code("1-abc")
|
||||
d = w1.get_code()
|
||||
yield eq.flush()
|
||||
code = self.successResultOf(d)
|
||||
self.assertEqual(code, "1-abc")
|
||||
yield self.assertFailure(w1.close(), LonelyError)
|
||||
|
@ -323,8 +325,9 @@ class Wormholes(ServerBase, unittest.TestCase):
|
|||
|
||||
@inlineCallbacks
|
||||
def test_closed(self):
|
||||
w1 = wormhole.create(APPID, self.relayurl, reactor)
|
||||
w2 = wormhole.create(APPID, self.relayurl, reactor)
|
||||
eq = EventualQueue(reactor)
|
||||
w1 = wormhole.create(APPID, self.relayurl, reactor, _eventual_queue=eq)
|
||||
w2 = wormhole.create(APPID, self.relayurl, reactor, _eventual_queue=eq)
|
||||
w1.set_code("123-foo")
|
||||
w2.set_code("123-foo")
|
||||
|
||||
|
@ -335,14 +338,14 @@ class Wormholes(ServerBase, unittest.TestCase):
|
|||
yield w1.close()
|
||||
yield w2.close()
|
||||
|
||||
# once closed, all Deferred-yielding API calls get an immediate error
|
||||
self.failureResultOf(w1.get_welcome(), WormholeClosed)
|
||||
f = self.failureResultOf(w1.get_code(), WormholeClosed)
|
||||
self.assertEqual(f.value.args[0], "happy")
|
||||
self.failureResultOf(w1.get_unverified_key(), WormholeClosed)
|
||||
self.failureResultOf(w1.get_verifier(), WormholeClosed)
|
||||
self.failureResultOf(w1.get_versions(), WormholeClosed)
|
||||
self.failureResultOf(w1.get_message(), WormholeClosed)
|
||||
# once closed, all Deferred-yielding API calls get an prompt error
|
||||
yield self.assertFailure(w1.get_welcome(), WormholeClosed)
|
||||
e = yield self.assertFailure(w1.get_code(), WormholeClosed)
|
||||
self.assertEqual(e.args[0], "happy")
|
||||
yield self.assertFailure(w1.get_unverified_key(), WormholeClosed)
|
||||
yield self.assertFailure(w1.get_verifier(), WormholeClosed)
|
||||
yield self.assertFailure(w1.get_versions(), WormholeClosed)
|
||||
yield self.assertFailure(w1.get_message(), WormholeClosed)
|
||||
|
||||
@inlineCallbacks
|
||||
def test_closed_idle(self):
|
||||
|
@ -360,17 +363,18 @@ class Wormholes(ServerBase, unittest.TestCase):
|
|||
|
||||
yield self.assertFailure(w1.close(), LonelyError)
|
||||
|
||||
self.failureResultOf(d_welcome, LonelyError)
|
||||
self.failureResultOf(d_code, LonelyError)
|
||||
self.failureResultOf(d_key, LonelyError)
|
||||
self.failureResultOf(d_verifier, LonelyError)
|
||||
self.failureResultOf(d_versions, LonelyError)
|
||||
self.failureResultOf(d_message, LonelyError)
|
||||
yield self.assertFailure(d_welcome, LonelyError)
|
||||
yield self.assertFailure(d_code, LonelyError)
|
||||
yield self.assertFailure(d_key, LonelyError)
|
||||
yield self.assertFailure(d_verifier, LonelyError)
|
||||
yield self.assertFailure(d_versions, LonelyError)
|
||||
yield self.assertFailure(d_message, LonelyError)
|
||||
|
||||
@inlineCallbacks
|
||||
def test_wrong_password(self):
|
||||
w1 = wormhole.create(APPID, self.relayurl, reactor)
|
||||
w2 = wormhole.create(APPID, self.relayurl, reactor)
|
||||
eq = EventualQueue(reactor)
|
||||
w1 = wormhole.create(APPID, self.relayurl, reactor, _eventual_queue=eq)
|
||||
w2 = wormhole.create(APPID, self.relayurl, reactor, _eventual_queue=eq)
|
||||
w1.allocate_code()
|
||||
code = yield w1.get_code()
|
||||
w2.set_code(code+"not")
|
||||
|
@ -403,9 +407,8 @@ class Wormholes(ServerBase, unittest.TestCase):
|
|||
# wait for each side to notice the failure
|
||||
yield self.assertFailure(w1.get_verifier(), WrongPasswordError)
|
||||
yield self.assertFailure(w2.get_verifier(), WrongPasswordError)
|
||||
# and then wait for the rest of the loops to fire. if we had+used
|
||||
# eventual-send, this wouldn't be a problem
|
||||
yield pause_one_tick()
|
||||
# the rest of the loops should fire within the next tick
|
||||
yield eq.flush()
|
||||
|
||||
# now all the rest should have fired already
|
||||
self.failureResultOf(d1_verified, WrongPasswordError)
|
||||
|
@ -420,27 +423,27 @@ class Wormholes(ServerBase, unittest.TestCase):
|
|||
# before we close
|
||||
|
||||
# any new calls in the error state should immediately fail
|
||||
self.failureResultOf(w1.get_unverified_key(), WrongPasswordError)
|
||||
self.failureResultOf(w1.get_verifier(), WrongPasswordError)
|
||||
self.failureResultOf(w1.get_versions(), WrongPasswordError)
|
||||
self.failureResultOf(w1.get_message(), WrongPasswordError)
|
||||
self.failureResultOf(w2.get_unverified_key(), WrongPasswordError)
|
||||
self.failureResultOf(w2.get_verifier(), WrongPasswordError)
|
||||
self.failureResultOf(w2.get_versions(), WrongPasswordError)
|
||||
self.failureResultOf(w2.get_message(), WrongPasswordError)
|
||||
yield self.assertFailure(w1.get_unverified_key(), WrongPasswordError)
|
||||
yield self.assertFailure(w1.get_verifier(), WrongPasswordError)
|
||||
yield self.assertFailure(w1.get_versions(), WrongPasswordError)
|
||||
yield self.assertFailure(w1.get_message(), WrongPasswordError)
|
||||
yield self.assertFailure(w2.get_unverified_key(), WrongPasswordError)
|
||||
yield self.assertFailure(w2.get_verifier(), WrongPasswordError)
|
||||
yield self.assertFailure(w2.get_versions(), WrongPasswordError)
|
||||
yield self.assertFailure(w2.get_message(), WrongPasswordError)
|
||||
|
||||
yield self.assertFailure(w1.close(), WrongPasswordError)
|
||||
yield self.assertFailure(w2.close(), WrongPasswordError)
|
||||
|
||||
# API calls should still get the error, not WormholeClosed
|
||||
self.failureResultOf(w1.get_unverified_key(), WrongPasswordError)
|
||||
self.failureResultOf(w1.get_verifier(), WrongPasswordError)
|
||||
self.failureResultOf(w1.get_versions(), WrongPasswordError)
|
||||
self.failureResultOf(w1.get_message(), WrongPasswordError)
|
||||
self.failureResultOf(w2.get_unverified_key(), WrongPasswordError)
|
||||
self.failureResultOf(w2.get_verifier(), WrongPasswordError)
|
||||
self.failureResultOf(w2.get_versions(), WrongPasswordError)
|
||||
self.failureResultOf(w2.get_message(), WrongPasswordError)
|
||||
yield self.assertFailure(w1.get_unverified_key(), WrongPasswordError)
|
||||
yield self.assertFailure(w1.get_verifier(), WrongPasswordError)
|
||||
yield self.assertFailure(w1.get_versions(), WrongPasswordError)
|
||||
yield self.assertFailure(w1.get_message(), WrongPasswordError)
|
||||
yield self.assertFailure(w2.get_unverified_key(), WrongPasswordError)
|
||||
yield self.assertFailure(w2.get_verifier(), WrongPasswordError)
|
||||
yield self.assertFailure(w2.get_versions(), WrongPasswordError)
|
||||
yield self.assertFailure(w2.get_message(), WrongPasswordError)
|
||||
|
||||
@inlineCallbacks
|
||||
def test_wrong_password_with_spaces(self):
|
||||
|
@ -493,8 +496,9 @@ class Wormholes(ServerBase, unittest.TestCase):
|
|||
|
||||
@inlineCallbacks
|
||||
def test_verifier(self):
|
||||
w1 = wormhole.create(APPID, self.relayurl, reactor)
|
||||
w2 = wormhole.create(APPID, self.relayurl, reactor)
|
||||
eq = EventualQueue(reactor)
|
||||
w1 = wormhole.create(APPID, self.relayurl, reactor, _eventual_queue=eq)
|
||||
w2 = wormhole.create(APPID, self.relayurl, reactor, _eventual_queue=eq)
|
||||
w1.allocate_code()
|
||||
code = yield w1.get_code()
|
||||
w2.set_code(code)
|
||||
|
@ -510,7 +514,9 @@ class Wormholes(ServerBase, unittest.TestCase):
|
|||
self.assertEqual(dataY, b"data1")
|
||||
|
||||
# calling get_verifier() this late should fire right away
|
||||
v1_late = self.successResultOf(w2.get_verifier())
|
||||
d = w2.get_verifier()
|
||||
yield eq.flush()
|
||||
v1_late = self.successResultOf(d)
|
||||
self.assertEqual(v1_late, v1)
|
||||
|
||||
yield w1.close()
|
||||
|
@ -644,26 +650,30 @@ class Reconnection(ServerBase, unittest.TestCase):
|
|||
self.assertEqual(c2, "happy")
|
||||
|
||||
class InitialFailure(unittest.TestCase):
|
||||
def assertSCEResultOf(self, d, innerType):
|
||||
@inlineCallbacks
|
||||
def assertSCEFailure(self, eq, d, innerType):
|
||||
yield eq.flush()
|
||||
f = self.failureResultOf(d, ServerConnectionError)
|
||||
inner = f.value.reason
|
||||
self.assertIsInstance(inner, innerType)
|
||||
return inner
|
||||
returnValue(inner)
|
||||
|
||||
@inlineCallbacks
|
||||
def test_bad_dns(self):
|
||||
eq = EventualQueue(reactor)
|
||||
# point at a URL that will never connect
|
||||
w = wormhole.create(APPID, "ws://%%%.example.org:4000/v1", reactor)
|
||||
w = wormhole.create(APPID, "ws://%%%.example.org:4000/v1",
|
||||
reactor, _eventual_queue=eq)
|
||||
# that should have already received an error, when it tried to
|
||||
# resolve the bogus DNS name. All API calls will return an error.
|
||||
e = yield self.assertFailure(w.get_unverified_key(),
|
||||
ServerConnectionError)
|
||||
self.assertIsInstance(e.reason, ValueError)
|
||||
|
||||
e = yield self.assertSCEFailure(eq, w.get_unverified_key(), ValueError)
|
||||
self.assertIsInstance(e, ValueError)
|
||||
self.assertEqual(str(e), "invalid hostname: %%%.example.org")
|
||||
self.assertSCEResultOf(w.get_code(), ValueError)
|
||||
self.assertSCEResultOf(w.get_verifier(), ValueError)
|
||||
self.assertSCEResultOf(w.get_versions(), ValueError)
|
||||
self.assertSCEResultOf(w.get_message(), ValueError)
|
||||
yield self.assertSCEFailure(eq, w.get_code(), ValueError)
|
||||
yield self.assertSCEFailure(eq, w.get_verifier(), ValueError)
|
||||
yield self.assertSCEFailure(eq, w.get_versions(), ValueError)
|
||||
yield self.assertSCEFailure(eq, w.get_message(), ValueError)
|
||||
|
||||
@inlineCallbacks
|
||||
def assertSCE(self, d, innerType):
|
||||
|
|
|
@ -3,9 +3,10 @@ import os, sys
|
|||
from attr import attrs, attrib
|
||||
from zope.interface import implementer
|
||||
from twisted.python import failure
|
||||
from twisted.internet import defer
|
||||
from ._interfaces import IWormhole, IDeferredWormhole
|
||||
from .util import bytes_to_hexstr
|
||||
from .eventual import EventualQueue
|
||||
from .observer import OneShotObserver, SequenceObserver
|
||||
from .timing import DebugTiming
|
||||
from .journal import ImmediateJournal
|
||||
from ._boss import Boss
|
||||
|
@ -100,22 +101,16 @@ class _DelegatedWormhole(object):
|
|||
|
||||
@implementer(IWormhole, IDeferredWormhole)
|
||||
class _DeferredWormhole(object):
|
||||
def __init__(self):
|
||||
self._welcome = None
|
||||
self._welcome_observers = []
|
||||
self._code = None
|
||||
self._code_observers = []
|
||||
def __init__(self, eq):
|
||||
self._welcome_observer = OneShotObserver(eq)
|
||||
self._code_observer = OneShotObserver(eq)
|
||||
self._key = None
|
||||
self._key_observers = []
|
||||
self._verifier = None
|
||||
self._verifier_observers = []
|
||||
self._versions = None
|
||||
self._version_observers = []
|
||||
self._received_data = []
|
||||
self._received_observers = []
|
||||
self._observer_result = None
|
||||
self._closed_result = None
|
||||
self._closed_observers = []
|
||||
self._key_observer = OneShotObserver(eq)
|
||||
self._verifier_observer = OneShotObserver(eq)
|
||||
self._version_observer = OneShotObserver(eq)
|
||||
self._received_observer = SequenceObserver(eq)
|
||||
self._closed = False
|
||||
self._closed_observer = OneShotObserver(eq)
|
||||
|
||||
def _set_boss(self, boss):
|
||||
self._boss = boss
|
||||
|
@ -127,58 +122,22 @@ class _DeferredWormhole(object):
|
|||
# the process that will cause it to fire, but forbidding that
|
||||
# ordering would make it easier to cause programming errors that
|
||||
# forget to trigger it entirely.
|
||||
if self._observer_result is not None:
|
||||
return defer.fail(self._observer_result)
|
||||
if self._code is not None:
|
||||
return defer.succeed(self._code)
|
||||
d = defer.Deferred()
|
||||
self._code_observers.append(d)
|
||||
return d
|
||||
return self._code_observer.when_fired()
|
||||
|
||||
def get_welcome(self):
|
||||
if self._observer_result is not None:
|
||||
return defer.fail(self._observer_result)
|
||||
if self._welcome is not None:
|
||||
return defer.succeed(self._welcome)
|
||||
d = defer.Deferred()
|
||||
self._welcome_observers.append(d)
|
||||
return d
|
||||
return self._welcome_observer.when_fired()
|
||||
|
||||
def get_unverified_key(self):
|
||||
if self._observer_result is not None:
|
||||
return defer.fail(self._observer_result)
|
||||
if self._key is not None:
|
||||
return defer.succeed(self._key)
|
||||
d = defer.Deferred()
|
||||
self._key_observers.append(d)
|
||||
return d
|
||||
return self._key_observer.when_fired()
|
||||
|
||||
def get_verifier(self):
|
||||
if self._observer_result is not None:
|
||||
return defer.fail(self._observer_result)
|
||||
if self._verifier is not None:
|
||||
return defer.succeed(self._verifier)
|
||||
d = defer.Deferred()
|
||||
self._verifier_observers.append(d)
|
||||
return d
|
||||
return self._verifier_observer.when_fired()
|
||||
|
||||
def get_versions(self):
|
||||
if self._observer_result is not None:
|
||||
return defer.fail(self._observer_result)
|
||||
if self._versions is not None:
|
||||
return defer.succeed(self._versions)
|
||||
d = defer.Deferred()
|
||||
self._version_observers.append(d)
|
||||
return d
|
||||
return self._version_observer.when_fired()
|
||||
|
||||
def get_message(self):
|
||||
if self._observer_result is not None:
|
||||
return defer.fail(self._observer_result)
|
||||
if self._received_data:
|
||||
return defer.succeed(self._received_data.pop(0))
|
||||
d = defer.Deferred()
|
||||
self._received_observers.append(d)
|
||||
return d
|
||||
return self._received_observer.when_next_event()
|
||||
|
||||
def allocate_code(self, code_length=2):
|
||||
self._boss.allocate_code(code_length)
|
||||
|
@ -207,11 +166,9 @@ class _DeferredWormhole(object):
|
|||
# fails with WormholeError unless we established a connection
|
||||
# (state=="happy"). Fails with WrongPasswordError (a subclass of
|
||||
# WormholeError) if state=="scary".
|
||||
if self._closed_result:
|
||||
return defer.succeed(self._closed_result) # maybe Failure
|
||||
d = defer.Deferred()
|
||||
self._closed_observers.append(d)
|
||||
self._boss.close() # only need to close if it wasn't already
|
||||
d = self._closed_observer.when_fired() # maybe Failure
|
||||
if not self._closed:
|
||||
self._boss.close() # only need to close if it wasn't already
|
||||
return d
|
||||
|
||||
def debug_set_trace(self, client_name,
|
||||
|
@ -221,75 +178,56 @@ class _DeferredWormhole(object):
|
|||
|
||||
# from below
|
||||
def got_welcome(self, welcome):
|
||||
self._welcome = welcome
|
||||
for d in self._welcome_observers:
|
||||
d.callback(welcome)
|
||||
self._welcome_observers[:] = []
|
||||
self._welcome_observer.fire_if_not_fired(welcome)
|
||||
def got_code(self, code):
|
||||
self._code = code
|
||||
for d in self._code_observers:
|
||||
d.callback(code)
|
||||
self._code_observers[:] = []
|
||||
self._code_observer.fire_if_not_fired(code)
|
||||
def got_key(self, key):
|
||||
self._key = key # for derive_key()
|
||||
for d in self._key_observers:
|
||||
d.callback(key)
|
||||
self._key_observers[:] = []
|
||||
self._key_observer.fire_if_not_fired(key)
|
||||
|
||||
def got_verifier(self, verifier):
|
||||
self._verifier = verifier
|
||||
for d in self._verifier_observers:
|
||||
d.callback(verifier)
|
||||
self._verifier_observers[:] = []
|
||||
self._verifier_observer.fire_if_not_fired(verifier)
|
||||
def got_versions(self, versions):
|
||||
self._versions = versions
|
||||
for d in self._version_observers:
|
||||
d.callback(versions)
|
||||
self._version_observers[:] = []
|
||||
self._version_observer.fire_if_not_fired(versions)
|
||||
|
||||
def received(self, plaintext):
|
||||
if self._received_observers:
|
||||
self._received_observers.pop(0).callback(plaintext)
|
||||
return
|
||||
self._received_data.append(plaintext)
|
||||
self._received_observer.fire(plaintext)
|
||||
|
||||
def closed(self, result):
|
||||
self._closed = True
|
||||
#print("closed", result, type(result), file=sys.stderr)
|
||||
if isinstance(result, Exception):
|
||||
self._observer_result = self._closed_result = failure.Failure(result)
|
||||
# everything pending gets an error, including close()
|
||||
f = failure.Failure(result)
|
||||
self._closed_observer.error(f)
|
||||
else:
|
||||
# pending w.key()/w.verify()/w.version()/w.read() get an error
|
||||
self._observer_result = WormholeClosed(result)
|
||||
# everything pending except close() gets an error:
|
||||
# w.get_code()/welcome/unverified_key/verifier/versions/message
|
||||
f = failure.Failure(WormholeClosed(result))
|
||||
# but w.close() only gets error if we're unhappy
|
||||
self._closed_result = result
|
||||
for d in self._welcome_observers:
|
||||
d.errback(self._observer_result)
|
||||
for d in self._code_observers:
|
||||
d.errback(self._observer_result)
|
||||
for d in self._key_observers:
|
||||
d.errback(self._observer_result)
|
||||
for d in self._verifier_observers:
|
||||
d.errback(self._observer_result)
|
||||
for d in self._version_observers:
|
||||
d.errback(self._observer_result)
|
||||
for d in self._received_observers:
|
||||
d.errback(self._observer_result)
|
||||
for d in self._closed_observers:
|
||||
d.callback(self._closed_result)
|
||||
self._closed_observer.fire_if_not_fired(result)
|
||||
self._welcome_observer.error(f)
|
||||
self._code_observer.error(f)
|
||||
self._key_observer.error(f)
|
||||
self._verifier_observer.error(f)
|
||||
self._version_observer.error(f)
|
||||
self._received_observer.fire(f)
|
||||
|
||||
|
||||
def create(appid, relay_url, reactor, # use keyword args for everything else
|
||||
versions={},
|
||||
delegate=None, journal=None, tor=None,
|
||||
timing=None,
|
||||
stderr=sys.stderr):
|
||||
stderr=sys.stderr,
|
||||
_eventual_queue=None):
|
||||
timing = timing or DebugTiming()
|
||||
side = bytes_to_hexstr(os.urandom(5))
|
||||
journal = journal or ImmediateJournal()
|
||||
eq = _eventual_queue or EventualQueue(reactor)
|
||||
if delegate:
|
||||
w = _DelegatedWormhole(delegate)
|
||||
else:
|
||||
w = _DeferredWormhole()
|
||||
w = _DeferredWormhole(eq)
|
||||
wormhole_versions = {} # will be used to indicate Wormhole capabilities
|
||||
wormhole_versions["app_versions"] = versions # app-specific capabilities
|
||||
b = Boss(w, side, relay_url, appid, wormhole_versions,
|
||||
|
|
Loading…
Reference in New Issue
Block a user