wormhole: switch to observers for all APIs

Tests can pass an EventualQueue into wormhole.create(), to override the
default. This lets the tests flush the queue without using a haphazard
real-time delay.

closes #23

(in fact, we added multiple-Deferreds-per-API a while ago, but this does it
in a much cleaner fashion, and with the safety of an eventual-send)
This commit is contained in:
Brian Warner 2018-02-24 17:48:23 -08:00
parent caabb3510c
commit be47f53e7c
3 changed files with 110 additions and 169 deletions

View File

@ -128,10 +128,3 @@ def poll_until(predicate):
d = defer.Deferred() d = defer.Deferred()
reactor.callLater(0.001, d.callback, None) reactor.callLater(0.001, d.callback, None)
yield d yield d
@defer.inlineCallbacks
def pause_one_tick():
# return a Deferred that won't fire until at least the next reactor tick
d = defer.Deferred()
reactor.callLater(0.001, d.callback, None)
yield d

View File

@ -5,12 +5,13 @@ from twisted.trial import unittest
from twisted.internet import reactor from twisted.internet import reactor
from twisted.internet.defer import gatherResults, inlineCallbacks, returnValue from twisted.internet.defer import gatherResults, inlineCallbacks, returnValue
from twisted.internet.error import ConnectionRefusedError from twisted.internet.error import ConnectionRefusedError
from .common import ServerBase, poll_until, pause_one_tick from .common import ServerBase, poll_until
from .. import wormhole, _rendezvous from .. import wormhole, _rendezvous
from ..errors import (WrongPasswordError, ServerConnectionError, from ..errors import (WrongPasswordError, ServerConnectionError,
KeyFormatError, WormholeClosed, LonelyError, KeyFormatError, WormholeClosed, LonelyError,
NoKeyError, OnlyOneCodeError) NoKeyError, OnlyOneCodeError)
from ..transit import allocate_tcp_port from ..transit import allocate_tcp_port
from ..eventual import EventualQueue
APPID = "appid" APPID = "appid"
@ -159,9 +160,6 @@ class Wormholes(ServerBase, unittest.TestCase):
verifier2 = yield w2.get_verifier() verifier2 = yield w2.get_verifier()
self.assertEqual(verifier1, verifier2) self.assertEqual(verifier1, verifier2)
self.successResultOf(w1.get_unverified_key())
self.successResultOf(w2.get_unverified_key())
versions1 = yield w1.get_versions() versions1 = yield w1.get_versions()
versions2 = yield w2.get_versions() versions2 = yield w2.get_versions()
# app-versions are exercised properly in test_versions, this just # app-versions are exercised properly in test_versions, this just
@ -186,18 +184,22 @@ class Wormholes(ServerBase, unittest.TestCase):
@inlineCallbacks @inlineCallbacks
def test_get_code_early(self): def test_get_code_early(self):
w1 = wormhole.create(APPID, self.relayurl, reactor) eq = EventualQueue(reactor)
w1 = wormhole.create(APPID, self.relayurl, reactor, _eventual_queue=eq)
d = w1.get_code() d = w1.get_code()
w1.set_code("1-abc") w1.set_code("1-abc")
yield eq.flush()
code = self.successResultOf(d) code = self.successResultOf(d)
self.assertEqual(code, "1-abc") self.assertEqual(code, "1-abc")
yield self.assertFailure(w1.close(), LonelyError) yield self.assertFailure(w1.close(), LonelyError)
@inlineCallbacks @inlineCallbacks
def test_get_code_late(self): def test_get_code_late(self):
w1 = wormhole.create(APPID, self.relayurl, reactor) eq = EventualQueue(reactor)
w1 = wormhole.create(APPID, self.relayurl, reactor, _eventual_queue=eq)
w1.set_code("1-abc") w1.set_code("1-abc")
d = w1.get_code() d = w1.get_code()
yield eq.flush()
code = self.successResultOf(d) code = self.successResultOf(d)
self.assertEqual(code, "1-abc") self.assertEqual(code, "1-abc")
yield self.assertFailure(w1.close(), LonelyError) yield self.assertFailure(w1.close(), LonelyError)
@ -323,8 +325,9 @@ class Wormholes(ServerBase, unittest.TestCase):
@inlineCallbacks @inlineCallbacks
def test_closed(self): def test_closed(self):
w1 = wormhole.create(APPID, self.relayurl, reactor) eq = EventualQueue(reactor)
w2 = wormhole.create(APPID, self.relayurl, reactor) w1 = wormhole.create(APPID, self.relayurl, reactor, _eventual_queue=eq)
w2 = wormhole.create(APPID, self.relayurl, reactor, _eventual_queue=eq)
w1.set_code("123-foo") w1.set_code("123-foo")
w2.set_code("123-foo") w2.set_code("123-foo")
@ -335,14 +338,14 @@ class Wormholes(ServerBase, unittest.TestCase):
yield w1.close() yield w1.close()
yield w2.close() yield w2.close()
# once closed, all Deferred-yielding API calls get an immediate error # once closed, all Deferred-yielding API calls get an prompt error
self.failureResultOf(w1.get_welcome(), WormholeClosed) yield self.assertFailure(w1.get_welcome(), WormholeClosed)
f = self.failureResultOf(w1.get_code(), WormholeClosed) e = yield self.assertFailure(w1.get_code(), WormholeClosed)
self.assertEqual(f.value.args[0], "happy") self.assertEqual(e.args[0], "happy")
self.failureResultOf(w1.get_unverified_key(), WormholeClosed) yield self.assertFailure(w1.get_unverified_key(), WormholeClosed)
self.failureResultOf(w1.get_verifier(), WormholeClosed) yield self.assertFailure(w1.get_verifier(), WormholeClosed)
self.failureResultOf(w1.get_versions(), WormholeClosed) yield self.assertFailure(w1.get_versions(), WormholeClosed)
self.failureResultOf(w1.get_message(), WormholeClosed) yield self.assertFailure(w1.get_message(), WormholeClosed)
@inlineCallbacks @inlineCallbacks
def test_closed_idle(self): def test_closed_idle(self):
@ -360,17 +363,18 @@ class Wormholes(ServerBase, unittest.TestCase):
yield self.assertFailure(w1.close(), LonelyError) yield self.assertFailure(w1.close(), LonelyError)
self.failureResultOf(d_welcome, LonelyError) yield self.assertFailure(d_welcome, LonelyError)
self.failureResultOf(d_code, LonelyError) yield self.assertFailure(d_code, LonelyError)
self.failureResultOf(d_key, LonelyError) yield self.assertFailure(d_key, LonelyError)
self.failureResultOf(d_verifier, LonelyError) yield self.assertFailure(d_verifier, LonelyError)
self.failureResultOf(d_versions, LonelyError) yield self.assertFailure(d_versions, LonelyError)
self.failureResultOf(d_message, LonelyError) yield self.assertFailure(d_message, LonelyError)
@inlineCallbacks @inlineCallbacks
def test_wrong_password(self): def test_wrong_password(self):
w1 = wormhole.create(APPID, self.relayurl, reactor) eq = EventualQueue(reactor)
w2 = wormhole.create(APPID, self.relayurl, reactor) w1 = wormhole.create(APPID, self.relayurl, reactor, _eventual_queue=eq)
w2 = wormhole.create(APPID, self.relayurl, reactor, _eventual_queue=eq)
w1.allocate_code() w1.allocate_code()
code = yield w1.get_code() code = yield w1.get_code()
w2.set_code(code+"not") w2.set_code(code+"not")
@ -403,9 +407,8 @@ class Wormholes(ServerBase, unittest.TestCase):
# wait for each side to notice the failure # wait for each side to notice the failure
yield self.assertFailure(w1.get_verifier(), WrongPasswordError) yield self.assertFailure(w1.get_verifier(), WrongPasswordError)
yield self.assertFailure(w2.get_verifier(), WrongPasswordError) yield self.assertFailure(w2.get_verifier(), WrongPasswordError)
# and then wait for the rest of the loops to fire. if we had+used # the rest of the loops should fire within the next tick
# eventual-send, this wouldn't be a problem yield eq.flush()
yield pause_one_tick()
# now all the rest should have fired already # now all the rest should have fired already
self.failureResultOf(d1_verified, WrongPasswordError) self.failureResultOf(d1_verified, WrongPasswordError)
@ -420,27 +423,27 @@ class Wormholes(ServerBase, unittest.TestCase):
# before we close # before we close
# any new calls in the error state should immediately fail # any new calls in the error state should immediately fail
self.failureResultOf(w1.get_unverified_key(), WrongPasswordError) yield self.assertFailure(w1.get_unverified_key(), WrongPasswordError)
self.failureResultOf(w1.get_verifier(), WrongPasswordError) yield self.assertFailure(w1.get_verifier(), WrongPasswordError)
self.failureResultOf(w1.get_versions(), WrongPasswordError) yield self.assertFailure(w1.get_versions(), WrongPasswordError)
self.failureResultOf(w1.get_message(), WrongPasswordError) yield self.assertFailure(w1.get_message(), WrongPasswordError)
self.failureResultOf(w2.get_unverified_key(), WrongPasswordError) yield self.assertFailure(w2.get_unverified_key(), WrongPasswordError)
self.failureResultOf(w2.get_verifier(), WrongPasswordError) yield self.assertFailure(w2.get_verifier(), WrongPasswordError)
self.failureResultOf(w2.get_versions(), WrongPasswordError) yield self.assertFailure(w2.get_versions(), WrongPasswordError)
self.failureResultOf(w2.get_message(), WrongPasswordError) yield self.assertFailure(w2.get_message(), WrongPasswordError)
yield self.assertFailure(w1.close(), WrongPasswordError) yield self.assertFailure(w1.close(), WrongPasswordError)
yield self.assertFailure(w2.close(), WrongPasswordError) yield self.assertFailure(w2.close(), WrongPasswordError)
# API calls should still get the error, not WormholeClosed # API calls should still get the error, not WormholeClosed
self.failureResultOf(w1.get_unverified_key(), WrongPasswordError) yield self.assertFailure(w1.get_unverified_key(), WrongPasswordError)
self.failureResultOf(w1.get_verifier(), WrongPasswordError) yield self.assertFailure(w1.get_verifier(), WrongPasswordError)
self.failureResultOf(w1.get_versions(), WrongPasswordError) yield self.assertFailure(w1.get_versions(), WrongPasswordError)
self.failureResultOf(w1.get_message(), WrongPasswordError) yield self.assertFailure(w1.get_message(), WrongPasswordError)
self.failureResultOf(w2.get_unverified_key(), WrongPasswordError) yield self.assertFailure(w2.get_unverified_key(), WrongPasswordError)
self.failureResultOf(w2.get_verifier(), WrongPasswordError) yield self.assertFailure(w2.get_verifier(), WrongPasswordError)
self.failureResultOf(w2.get_versions(), WrongPasswordError) yield self.assertFailure(w2.get_versions(), WrongPasswordError)
self.failureResultOf(w2.get_message(), WrongPasswordError) yield self.assertFailure(w2.get_message(), WrongPasswordError)
@inlineCallbacks @inlineCallbacks
def test_wrong_password_with_spaces(self): def test_wrong_password_with_spaces(self):
@ -493,8 +496,9 @@ class Wormholes(ServerBase, unittest.TestCase):
@inlineCallbacks @inlineCallbacks
def test_verifier(self): def test_verifier(self):
w1 = wormhole.create(APPID, self.relayurl, reactor) eq = EventualQueue(reactor)
w2 = wormhole.create(APPID, self.relayurl, reactor) w1 = wormhole.create(APPID, self.relayurl, reactor, _eventual_queue=eq)
w2 = wormhole.create(APPID, self.relayurl, reactor, _eventual_queue=eq)
w1.allocate_code() w1.allocate_code()
code = yield w1.get_code() code = yield w1.get_code()
w2.set_code(code) w2.set_code(code)
@ -510,7 +514,9 @@ class Wormholes(ServerBase, unittest.TestCase):
self.assertEqual(dataY, b"data1") self.assertEqual(dataY, b"data1")
# calling get_verifier() this late should fire right away # calling get_verifier() this late should fire right away
v1_late = self.successResultOf(w2.get_verifier()) d = w2.get_verifier()
yield eq.flush()
v1_late = self.successResultOf(d)
self.assertEqual(v1_late, v1) self.assertEqual(v1_late, v1)
yield w1.close() yield w1.close()
@ -644,26 +650,30 @@ class Reconnection(ServerBase, unittest.TestCase):
self.assertEqual(c2, "happy") self.assertEqual(c2, "happy")
class InitialFailure(unittest.TestCase): class InitialFailure(unittest.TestCase):
def assertSCEResultOf(self, d, innerType): @inlineCallbacks
def assertSCEFailure(self, eq, d, innerType):
yield eq.flush()
f = self.failureResultOf(d, ServerConnectionError) f = self.failureResultOf(d, ServerConnectionError)
inner = f.value.reason inner = f.value.reason
self.assertIsInstance(inner, innerType) self.assertIsInstance(inner, innerType)
return inner returnValue(inner)
@inlineCallbacks @inlineCallbacks
def test_bad_dns(self): def test_bad_dns(self):
eq = EventualQueue(reactor)
# point at a URL that will never connect # point at a URL that will never connect
w = wormhole.create(APPID, "ws://%%%.example.org:4000/v1", reactor) w = wormhole.create(APPID, "ws://%%%.example.org:4000/v1",
reactor, _eventual_queue=eq)
# that should have already received an error, when it tried to # that should have already received an error, when it tried to
# resolve the bogus DNS name. All API calls will return an error. # resolve the bogus DNS name. All API calls will return an error.
e = yield self.assertFailure(w.get_unverified_key(),
ServerConnectionError) e = yield self.assertSCEFailure(eq, w.get_unverified_key(), ValueError)
self.assertIsInstance(e.reason, ValueError) self.assertIsInstance(e, ValueError)
self.assertEqual(str(e), "invalid hostname: %%%.example.org") self.assertEqual(str(e), "invalid hostname: %%%.example.org")
self.assertSCEResultOf(w.get_code(), ValueError) yield self.assertSCEFailure(eq, w.get_code(), ValueError)
self.assertSCEResultOf(w.get_verifier(), ValueError) yield self.assertSCEFailure(eq, w.get_verifier(), ValueError)
self.assertSCEResultOf(w.get_versions(), ValueError) yield self.assertSCEFailure(eq, w.get_versions(), ValueError)
self.assertSCEResultOf(w.get_message(), ValueError) yield self.assertSCEFailure(eq, w.get_message(), ValueError)
@inlineCallbacks @inlineCallbacks
def assertSCE(self, d, innerType): def assertSCE(self, d, innerType):

View File

@ -3,9 +3,10 @@ import os, sys
from attr import attrs, attrib from attr import attrs, attrib
from zope.interface import implementer from zope.interface import implementer
from twisted.python import failure from twisted.python import failure
from twisted.internet import defer
from ._interfaces import IWormhole, IDeferredWormhole from ._interfaces import IWormhole, IDeferredWormhole
from .util import bytes_to_hexstr from .util import bytes_to_hexstr
from .eventual import EventualQueue
from .observer import OneShotObserver, SequenceObserver
from .timing import DebugTiming from .timing import DebugTiming
from .journal import ImmediateJournal from .journal import ImmediateJournal
from ._boss import Boss from ._boss import Boss
@ -100,22 +101,16 @@ class _DelegatedWormhole(object):
@implementer(IWormhole, IDeferredWormhole) @implementer(IWormhole, IDeferredWormhole)
class _DeferredWormhole(object): class _DeferredWormhole(object):
def __init__(self): def __init__(self, eq):
self._welcome = None self._welcome_observer = OneShotObserver(eq)
self._welcome_observers = [] self._code_observer = OneShotObserver(eq)
self._code = None
self._code_observers = []
self._key = None self._key = None
self._key_observers = [] self._key_observer = OneShotObserver(eq)
self._verifier = None self._verifier_observer = OneShotObserver(eq)
self._verifier_observers = [] self._version_observer = OneShotObserver(eq)
self._versions = None self._received_observer = SequenceObserver(eq)
self._version_observers = [] self._closed = False
self._received_data = [] self._closed_observer = OneShotObserver(eq)
self._received_observers = []
self._observer_result = None
self._closed_result = None
self._closed_observers = []
def _set_boss(self, boss): def _set_boss(self, boss):
self._boss = boss self._boss = boss
@ -127,58 +122,22 @@ class _DeferredWormhole(object):
# the process that will cause it to fire, but forbidding that # the process that will cause it to fire, but forbidding that
# ordering would make it easier to cause programming errors that # ordering would make it easier to cause programming errors that
# forget to trigger it entirely. # forget to trigger it entirely.
if self._observer_result is not None: return self._code_observer.when_fired()
return defer.fail(self._observer_result)
if self._code is not None:
return defer.succeed(self._code)
d = defer.Deferred()
self._code_observers.append(d)
return d
def get_welcome(self): def get_welcome(self):
if self._observer_result is not None: return self._welcome_observer.when_fired()
return defer.fail(self._observer_result)
if self._welcome is not None:
return defer.succeed(self._welcome)
d = defer.Deferred()
self._welcome_observers.append(d)
return d
def get_unverified_key(self): def get_unverified_key(self):
if self._observer_result is not None: return self._key_observer.when_fired()
return defer.fail(self._observer_result)
if self._key is not None:
return defer.succeed(self._key)
d = defer.Deferred()
self._key_observers.append(d)
return d
def get_verifier(self): def get_verifier(self):
if self._observer_result is not None: return self._verifier_observer.when_fired()
return defer.fail(self._observer_result)
if self._verifier is not None:
return defer.succeed(self._verifier)
d = defer.Deferred()
self._verifier_observers.append(d)
return d
def get_versions(self): def get_versions(self):
if self._observer_result is not None: return self._version_observer.when_fired()
return defer.fail(self._observer_result)
if self._versions is not None:
return defer.succeed(self._versions)
d = defer.Deferred()
self._version_observers.append(d)
return d
def get_message(self): def get_message(self):
if self._observer_result is not None: return self._received_observer.when_next_event()
return defer.fail(self._observer_result)
if self._received_data:
return defer.succeed(self._received_data.pop(0))
d = defer.Deferred()
self._received_observers.append(d)
return d
def allocate_code(self, code_length=2): def allocate_code(self, code_length=2):
self._boss.allocate_code(code_length) self._boss.allocate_code(code_length)
@ -207,11 +166,9 @@ class _DeferredWormhole(object):
# fails with WormholeError unless we established a connection # fails with WormholeError unless we established a connection
# (state=="happy"). Fails with WrongPasswordError (a subclass of # (state=="happy"). Fails with WrongPasswordError (a subclass of
# WormholeError) if state=="scary". # WormholeError) if state=="scary".
if self._closed_result: d = self._closed_observer.when_fired() # maybe Failure
return defer.succeed(self._closed_result) # maybe Failure if not self._closed:
d = defer.Deferred() self._boss.close() # only need to close if it wasn't already
self._closed_observers.append(d)
self._boss.close() # only need to close if it wasn't already
return d return d
def debug_set_trace(self, client_name, def debug_set_trace(self, client_name,
@ -221,75 +178,56 @@ class _DeferredWormhole(object):
# from below # from below
def got_welcome(self, welcome): def got_welcome(self, welcome):
self._welcome = welcome self._welcome_observer.fire_if_not_fired(welcome)
for d in self._welcome_observers:
d.callback(welcome)
self._welcome_observers[:] = []
def got_code(self, code): def got_code(self, code):
self._code = code self._code_observer.fire_if_not_fired(code)
for d in self._code_observers:
d.callback(code)
self._code_observers[:] = []
def got_key(self, key): def got_key(self, key):
self._key = key # for derive_key() self._key = key # for derive_key()
for d in self._key_observers: self._key_observer.fire_if_not_fired(key)
d.callback(key)
self._key_observers[:] = []
def got_verifier(self, verifier): def got_verifier(self, verifier):
self._verifier = verifier self._verifier_observer.fire_if_not_fired(verifier)
for d in self._verifier_observers:
d.callback(verifier)
self._verifier_observers[:] = []
def got_versions(self, versions): def got_versions(self, versions):
self._versions = versions self._version_observer.fire_if_not_fired(versions)
for d in self._version_observers:
d.callback(versions)
self._version_observers[:] = []
def received(self, plaintext): def received(self, plaintext):
if self._received_observers: self._received_observer.fire(plaintext)
self._received_observers.pop(0).callback(plaintext)
return
self._received_data.append(plaintext)
def closed(self, result): def closed(self, result):
self._closed = True
#print("closed", result, type(result), file=sys.stderr) #print("closed", result, type(result), file=sys.stderr)
if isinstance(result, Exception): if isinstance(result, Exception):
self._observer_result = self._closed_result = failure.Failure(result) # everything pending gets an error, including close()
f = failure.Failure(result)
self._closed_observer.error(f)
else: else:
# pending w.key()/w.verify()/w.version()/w.read() get an error # everything pending except close() gets an error:
self._observer_result = WormholeClosed(result) # w.get_code()/welcome/unverified_key/verifier/versions/message
f = failure.Failure(WormholeClosed(result))
# but w.close() only gets error if we're unhappy # but w.close() only gets error if we're unhappy
self._closed_result = result self._closed_observer.fire_if_not_fired(result)
for d in self._welcome_observers: self._welcome_observer.error(f)
d.errback(self._observer_result) self._code_observer.error(f)
for d in self._code_observers: self._key_observer.error(f)
d.errback(self._observer_result) self._verifier_observer.error(f)
for d in self._key_observers: self._version_observer.error(f)
d.errback(self._observer_result) self._received_observer.fire(f)
for d in self._verifier_observers:
d.errback(self._observer_result)
for d in self._version_observers:
d.errback(self._observer_result)
for d in self._received_observers:
d.errback(self._observer_result)
for d in self._closed_observers:
d.callback(self._closed_result)
def create(appid, relay_url, reactor, # use keyword args for everything else def create(appid, relay_url, reactor, # use keyword args for everything else
versions={}, versions={},
delegate=None, journal=None, tor=None, delegate=None, journal=None, tor=None,
timing=None, timing=None,
stderr=sys.stderr): stderr=sys.stderr,
_eventual_queue=None):
timing = timing or DebugTiming() timing = timing or DebugTiming()
side = bytes_to_hexstr(os.urandom(5)) side = bytes_to_hexstr(os.urandom(5))
journal = journal or ImmediateJournal() journal = journal or ImmediateJournal()
eq = _eventual_queue or EventualQueue(reactor)
if delegate: if delegate:
w = _DelegatedWormhole(delegate) w = _DelegatedWormhole(delegate)
else: else:
w = _DeferredWormhole() w = _DeferredWormhole(eq)
wormhole_versions = {} # will be used to indicate Wormhole capabilities wormhole_versions = {} # will be used to indicate Wormhole capabilities
wormhole_versions["app_versions"] = versions # app-specific capabilities wormhole_versions["app_versions"] = versions # app-specific capabilities
b = Boss(w, side, relay_url, appid, wormhole_versions, b = Boss(w, side, relay_url, appid, wormhole_versions,