wormhole: switch to observers for all APIs

Tests can pass an EventualQueue into wormhole.create(), to override the
default. This lets the tests flush the queue without using a haphazard
real-time delay.

closes #23

(in fact, we added multiple-Deferreds-per-API a while ago, but this does it
in a much cleaner fashion, and with the safety of an eventual-send)
This commit is contained in:
Brian Warner 2018-02-24 17:48:23 -08:00
parent caabb3510c
commit be47f53e7c
3 changed files with 110 additions and 169 deletions

View File

@ -128,10 +128,3 @@ def poll_until(predicate):
d = defer.Deferred()
reactor.callLater(0.001, d.callback, None)
yield d
@defer.inlineCallbacks
def pause_one_tick():
# return a Deferred that won't fire until at least the next reactor tick
d = defer.Deferred()
reactor.callLater(0.001, d.callback, None)
yield d

View File

@ -5,12 +5,13 @@ from twisted.trial import unittest
from twisted.internet import reactor
from twisted.internet.defer import gatherResults, inlineCallbacks, returnValue
from twisted.internet.error import ConnectionRefusedError
from .common import ServerBase, poll_until, pause_one_tick
from .common import ServerBase, poll_until
from .. import wormhole, _rendezvous
from ..errors import (WrongPasswordError, ServerConnectionError,
KeyFormatError, WormholeClosed, LonelyError,
NoKeyError, OnlyOneCodeError)
from ..transit import allocate_tcp_port
from ..eventual import EventualQueue
APPID = "appid"
@ -159,9 +160,6 @@ class Wormholes(ServerBase, unittest.TestCase):
verifier2 = yield w2.get_verifier()
self.assertEqual(verifier1, verifier2)
self.successResultOf(w1.get_unverified_key())
self.successResultOf(w2.get_unverified_key())
versions1 = yield w1.get_versions()
versions2 = yield w2.get_versions()
# app-versions are exercised properly in test_versions, this just
@ -186,18 +184,22 @@ class Wormholes(ServerBase, unittest.TestCase):
@inlineCallbacks
def test_get_code_early(self):
w1 = wormhole.create(APPID, self.relayurl, reactor)
eq = EventualQueue(reactor)
w1 = wormhole.create(APPID, self.relayurl, reactor, _eventual_queue=eq)
d = w1.get_code()
w1.set_code("1-abc")
yield eq.flush()
code = self.successResultOf(d)
self.assertEqual(code, "1-abc")
yield self.assertFailure(w1.close(), LonelyError)
@inlineCallbacks
def test_get_code_late(self):
w1 = wormhole.create(APPID, self.relayurl, reactor)
eq = EventualQueue(reactor)
w1 = wormhole.create(APPID, self.relayurl, reactor, _eventual_queue=eq)
w1.set_code("1-abc")
d = w1.get_code()
yield eq.flush()
code = self.successResultOf(d)
self.assertEqual(code, "1-abc")
yield self.assertFailure(w1.close(), LonelyError)
@ -323,8 +325,9 @@ class Wormholes(ServerBase, unittest.TestCase):
@inlineCallbacks
def test_closed(self):
w1 = wormhole.create(APPID, self.relayurl, reactor)
w2 = wormhole.create(APPID, self.relayurl, reactor)
eq = EventualQueue(reactor)
w1 = wormhole.create(APPID, self.relayurl, reactor, _eventual_queue=eq)
w2 = wormhole.create(APPID, self.relayurl, reactor, _eventual_queue=eq)
w1.set_code("123-foo")
w2.set_code("123-foo")
@ -335,14 +338,14 @@ class Wormholes(ServerBase, unittest.TestCase):
yield w1.close()
yield w2.close()
# once closed, all Deferred-yielding API calls get an immediate error
self.failureResultOf(w1.get_welcome(), WormholeClosed)
f = self.failureResultOf(w1.get_code(), WormholeClosed)
self.assertEqual(f.value.args[0], "happy")
self.failureResultOf(w1.get_unverified_key(), WormholeClosed)
self.failureResultOf(w1.get_verifier(), WormholeClosed)
self.failureResultOf(w1.get_versions(), WormholeClosed)
self.failureResultOf(w1.get_message(), WormholeClosed)
# once closed, all Deferred-yielding API calls get an prompt error
yield self.assertFailure(w1.get_welcome(), WormholeClosed)
e = yield self.assertFailure(w1.get_code(), WormholeClosed)
self.assertEqual(e.args[0], "happy")
yield self.assertFailure(w1.get_unverified_key(), WormholeClosed)
yield self.assertFailure(w1.get_verifier(), WormholeClosed)
yield self.assertFailure(w1.get_versions(), WormholeClosed)
yield self.assertFailure(w1.get_message(), WormholeClosed)
@inlineCallbacks
def test_closed_idle(self):
@ -360,17 +363,18 @@ class Wormholes(ServerBase, unittest.TestCase):
yield self.assertFailure(w1.close(), LonelyError)
self.failureResultOf(d_welcome, LonelyError)
self.failureResultOf(d_code, LonelyError)
self.failureResultOf(d_key, LonelyError)
self.failureResultOf(d_verifier, LonelyError)
self.failureResultOf(d_versions, LonelyError)
self.failureResultOf(d_message, LonelyError)
yield self.assertFailure(d_welcome, LonelyError)
yield self.assertFailure(d_code, LonelyError)
yield self.assertFailure(d_key, LonelyError)
yield self.assertFailure(d_verifier, LonelyError)
yield self.assertFailure(d_versions, LonelyError)
yield self.assertFailure(d_message, LonelyError)
@inlineCallbacks
def test_wrong_password(self):
w1 = wormhole.create(APPID, self.relayurl, reactor)
w2 = wormhole.create(APPID, self.relayurl, reactor)
eq = EventualQueue(reactor)
w1 = wormhole.create(APPID, self.relayurl, reactor, _eventual_queue=eq)
w2 = wormhole.create(APPID, self.relayurl, reactor, _eventual_queue=eq)
w1.allocate_code()
code = yield w1.get_code()
w2.set_code(code+"not")
@ -403,9 +407,8 @@ class Wormholes(ServerBase, unittest.TestCase):
# wait for each side to notice the failure
yield self.assertFailure(w1.get_verifier(), WrongPasswordError)
yield self.assertFailure(w2.get_verifier(), WrongPasswordError)
# and then wait for the rest of the loops to fire. if we had+used
# eventual-send, this wouldn't be a problem
yield pause_one_tick()
# the rest of the loops should fire within the next tick
yield eq.flush()
# now all the rest should have fired already
self.failureResultOf(d1_verified, WrongPasswordError)
@ -420,27 +423,27 @@ class Wormholes(ServerBase, unittest.TestCase):
# before we close
# any new calls in the error state should immediately fail
self.failureResultOf(w1.get_unverified_key(), WrongPasswordError)
self.failureResultOf(w1.get_verifier(), WrongPasswordError)
self.failureResultOf(w1.get_versions(), WrongPasswordError)
self.failureResultOf(w1.get_message(), WrongPasswordError)
self.failureResultOf(w2.get_unverified_key(), WrongPasswordError)
self.failureResultOf(w2.get_verifier(), WrongPasswordError)
self.failureResultOf(w2.get_versions(), WrongPasswordError)
self.failureResultOf(w2.get_message(), WrongPasswordError)
yield self.assertFailure(w1.get_unverified_key(), WrongPasswordError)
yield self.assertFailure(w1.get_verifier(), WrongPasswordError)
yield self.assertFailure(w1.get_versions(), WrongPasswordError)
yield self.assertFailure(w1.get_message(), WrongPasswordError)
yield self.assertFailure(w2.get_unverified_key(), WrongPasswordError)
yield self.assertFailure(w2.get_verifier(), WrongPasswordError)
yield self.assertFailure(w2.get_versions(), WrongPasswordError)
yield self.assertFailure(w2.get_message(), WrongPasswordError)
yield self.assertFailure(w1.close(), WrongPasswordError)
yield self.assertFailure(w2.close(), WrongPasswordError)
# API calls should still get the error, not WormholeClosed
self.failureResultOf(w1.get_unverified_key(), WrongPasswordError)
self.failureResultOf(w1.get_verifier(), WrongPasswordError)
self.failureResultOf(w1.get_versions(), WrongPasswordError)
self.failureResultOf(w1.get_message(), WrongPasswordError)
self.failureResultOf(w2.get_unverified_key(), WrongPasswordError)
self.failureResultOf(w2.get_verifier(), WrongPasswordError)
self.failureResultOf(w2.get_versions(), WrongPasswordError)
self.failureResultOf(w2.get_message(), WrongPasswordError)
yield self.assertFailure(w1.get_unverified_key(), WrongPasswordError)
yield self.assertFailure(w1.get_verifier(), WrongPasswordError)
yield self.assertFailure(w1.get_versions(), WrongPasswordError)
yield self.assertFailure(w1.get_message(), WrongPasswordError)
yield self.assertFailure(w2.get_unverified_key(), WrongPasswordError)
yield self.assertFailure(w2.get_verifier(), WrongPasswordError)
yield self.assertFailure(w2.get_versions(), WrongPasswordError)
yield self.assertFailure(w2.get_message(), WrongPasswordError)
@inlineCallbacks
def test_wrong_password_with_spaces(self):
@ -493,8 +496,9 @@ class Wormholes(ServerBase, unittest.TestCase):
@inlineCallbacks
def test_verifier(self):
w1 = wormhole.create(APPID, self.relayurl, reactor)
w2 = wormhole.create(APPID, self.relayurl, reactor)
eq = EventualQueue(reactor)
w1 = wormhole.create(APPID, self.relayurl, reactor, _eventual_queue=eq)
w2 = wormhole.create(APPID, self.relayurl, reactor, _eventual_queue=eq)
w1.allocate_code()
code = yield w1.get_code()
w2.set_code(code)
@ -510,7 +514,9 @@ class Wormholes(ServerBase, unittest.TestCase):
self.assertEqual(dataY, b"data1")
# calling get_verifier() this late should fire right away
v1_late = self.successResultOf(w2.get_verifier())
d = w2.get_verifier()
yield eq.flush()
v1_late = self.successResultOf(d)
self.assertEqual(v1_late, v1)
yield w1.close()
@ -644,26 +650,30 @@ class Reconnection(ServerBase, unittest.TestCase):
self.assertEqual(c2, "happy")
class InitialFailure(unittest.TestCase):
def assertSCEResultOf(self, d, innerType):
@inlineCallbacks
def assertSCEFailure(self, eq, d, innerType):
yield eq.flush()
f = self.failureResultOf(d, ServerConnectionError)
inner = f.value.reason
self.assertIsInstance(inner, innerType)
return inner
returnValue(inner)
@inlineCallbacks
def test_bad_dns(self):
eq = EventualQueue(reactor)
# point at a URL that will never connect
w = wormhole.create(APPID, "ws://%%%.example.org:4000/v1", reactor)
w = wormhole.create(APPID, "ws://%%%.example.org:4000/v1",
reactor, _eventual_queue=eq)
# that should have already received an error, when it tried to
# resolve the bogus DNS name. All API calls will return an error.
e = yield self.assertFailure(w.get_unverified_key(),
ServerConnectionError)
self.assertIsInstance(e.reason, ValueError)
e = yield self.assertSCEFailure(eq, w.get_unverified_key(), ValueError)
self.assertIsInstance(e, ValueError)
self.assertEqual(str(e), "invalid hostname: %%%.example.org")
self.assertSCEResultOf(w.get_code(), ValueError)
self.assertSCEResultOf(w.get_verifier(), ValueError)
self.assertSCEResultOf(w.get_versions(), ValueError)
self.assertSCEResultOf(w.get_message(), ValueError)
yield self.assertSCEFailure(eq, w.get_code(), ValueError)
yield self.assertSCEFailure(eq, w.get_verifier(), ValueError)
yield self.assertSCEFailure(eq, w.get_versions(), ValueError)
yield self.assertSCEFailure(eq, w.get_message(), ValueError)
@inlineCallbacks
def assertSCE(self, d, innerType):

View File

@ -3,9 +3,10 @@ import os, sys
from attr import attrs, attrib
from zope.interface import implementer
from twisted.python import failure
from twisted.internet import defer
from ._interfaces import IWormhole, IDeferredWormhole
from .util import bytes_to_hexstr
from .eventual import EventualQueue
from .observer import OneShotObserver, SequenceObserver
from .timing import DebugTiming
from .journal import ImmediateJournal
from ._boss import Boss
@ -100,22 +101,16 @@ class _DelegatedWormhole(object):
@implementer(IWormhole, IDeferredWormhole)
class _DeferredWormhole(object):
def __init__(self):
self._welcome = None
self._welcome_observers = []
self._code = None
self._code_observers = []
def __init__(self, eq):
self._welcome_observer = OneShotObserver(eq)
self._code_observer = OneShotObserver(eq)
self._key = None
self._key_observers = []
self._verifier = None
self._verifier_observers = []
self._versions = None
self._version_observers = []
self._received_data = []
self._received_observers = []
self._observer_result = None
self._closed_result = None
self._closed_observers = []
self._key_observer = OneShotObserver(eq)
self._verifier_observer = OneShotObserver(eq)
self._version_observer = OneShotObserver(eq)
self._received_observer = SequenceObserver(eq)
self._closed = False
self._closed_observer = OneShotObserver(eq)
def _set_boss(self, boss):
self._boss = boss
@ -127,58 +122,22 @@ class _DeferredWormhole(object):
# the process that will cause it to fire, but forbidding that
# ordering would make it easier to cause programming errors that
# forget to trigger it entirely.
if self._observer_result is not None:
return defer.fail(self._observer_result)
if self._code is not None:
return defer.succeed(self._code)
d = defer.Deferred()
self._code_observers.append(d)
return d
return self._code_observer.when_fired()
def get_welcome(self):
if self._observer_result is not None:
return defer.fail(self._observer_result)
if self._welcome is not None:
return defer.succeed(self._welcome)
d = defer.Deferred()
self._welcome_observers.append(d)
return d
return self._welcome_observer.when_fired()
def get_unverified_key(self):
if self._observer_result is not None:
return defer.fail(self._observer_result)
if self._key is not None:
return defer.succeed(self._key)
d = defer.Deferred()
self._key_observers.append(d)
return d
return self._key_observer.when_fired()
def get_verifier(self):
if self._observer_result is not None:
return defer.fail(self._observer_result)
if self._verifier is not None:
return defer.succeed(self._verifier)
d = defer.Deferred()
self._verifier_observers.append(d)
return d
return self._verifier_observer.when_fired()
def get_versions(self):
if self._observer_result is not None:
return defer.fail(self._observer_result)
if self._versions is not None:
return defer.succeed(self._versions)
d = defer.Deferred()
self._version_observers.append(d)
return d
return self._version_observer.when_fired()
def get_message(self):
if self._observer_result is not None:
return defer.fail(self._observer_result)
if self._received_data:
return defer.succeed(self._received_data.pop(0))
d = defer.Deferred()
self._received_observers.append(d)
return d
return self._received_observer.when_next_event()
def allocate_code(self, code_length=2):
self._boss.allocate_code(code_length)
@ -207,11 +166,9 @@ class _DeferredWormhole(object):
# fails with WormholeError unless we established a connection
# (state=="happy"). Fails with WrongPasswordError (a subclass of
# WormholeError) if state=="scary".
if self._closed_result:
return defer.succeed(self._closed_result) # maybe Failure
d = defer.Deferred()
self._closed_observers.append(d)
self._boss.close() # only need to close if it wasn't already
d = self._closed_observer.when_fired() # maybe Failure
if not self._closed:
self._boss.close() # only need to close if it wasn't already
return d
def debug_set_trace(self, client_name,
@ -221,75 +178,56 @@ class _DeferredWormhole(object):
# from below
def got_welcome(self, welcome):
self._welcome = welcome
for d in self._welcome_observers:
d.callback(welcome)
self._welcome_observers[:] = []
self._welcome_observer.fire_if_not_fired(welcome)
def got_code(self, code):
self._code = code
for d in self._code_observers:
d.callback(code)
self._code_observers[:] = []
self._code_observer.fire_if_not_fired(code)
def got_key(self, key):
self._key = key # for derive_key()
for d in self._key_observers:
d.callback(key)
self._key_observers[:] = []
self._key_observer.fire_if_not_fired(key)
def got_verifier(self, verifier):
self._verifier = verifier
for d in self._verifier_observers:
d.callback(verifier)
self._verifier_observers[:] = []
self._verifier_observer.fire_if_not_fired(verifier)
def got_versions(self, versions):
self._versions = versions
for d in self._version_observers:
d.callback(versions)
self._version_observers[:] = []
self._version_observer.fire_if_not_fired(versions)
def received(self, plaintext):
if self._received_observers:
self._received_observers.pop(0).callback(plaintext)
return
self._received_data.append(plaintext)
self._received_observer.fire(plaintext)
def closed(self, result):
self._closed = True
#print("closed", result, type(result), file=sys.stderr)
if isinstance(result, Exception):
self._observer_result = self._closed_result = failure.Failure(result)
# everything pending gets an error, including close()
f = failure.Failure(result)
self._closed_observer.error(f)
else:
# pending w.key()/w.verify()/w.version()/w.read() get an error
self._observer_result = WormholeClosed(result)
# everything pending except close() gets an error:
# w.get_code()/welcome/unverified_key/verifier/versions/message
f = failure.Failure(WormholeClosed(result))
# but w.close() only gets error if we're unhappy
self._closed_result = result
for d in self._welcome_observers:
d.errback(self._observer_result)
for d in self._code_observers:
d.errback(self._observer_result)
for d in self._key_observers:
d.errback(self._observer_result)
for d in self._verifier_observers:
d.errback(self._observer_result)
for d in self._version_observers:
d.errback(self._observer_result)
for d in self._received_observers:
d.errback(self._observer_result)
for d in self._closed_observers:
d.callback(self._closed_result)
self._closed_observer.fire_if_not_fired(result)
self._welcome_observer.error(f)
self._code_observer.error(f)
self._key_observer.error(f)
self._verifier_observer.error(f)
self._version_observer.error(f)
self._received_observer.fire(f)
def create(appid, relay_url, reactor, # use keyword args for everything else
versions={},
delegate=None, journal=None, tor=None,
timing=None,
stderr=sys.stderr):
stderr=sys.stderr,
_eventual_queue=None):
timing = timing or DebugTiming()
side = bytes_to_hexstr(os.urandom(5))
journal = journal or ImmediateJournal()
eq = _eventual_queue or EventualQueue(reactor)
if delegate:
w = _DelegatedWormhole(delegate)
else:
w = _DeferredWormhole()
w = _DeferredWormhole(eq)
wormhole_versions = {} # will be used to indicate Wormhole capabilities
wormhole_versions["app_versions"] = versions # app-specific capabilities
b = Boss(w, side, relay_url, appid, wormhole_versions,