diff --git a/src/wormhole/test/common.py b/src/wormhole/test/common.py index 2521b38..0b3897e 100644 --- a/src/wormhole/test/common.py +++ b/src/wormhole/test/common.py @@ -128,10 +128,3 @@ def poll_until(predicate): d = defer.Deferred() reactor.callLater(0.001, d.callback, None) yield d - -@defer.inlineCallbacks -def pause_one_tick(): - # return a Deferred that won't fire until at least the next reactor tick - d = defer.Deferred() - reactor.callLater(0.001, d.callback, None) - yield d diff --git a/src/wormhole/test/test_wormhole.py b/src/wormhole/test/test_wormhole.py index 012bc15..df5830f 100644 --- a/src/wormhole/test/test_wormhole.py +++ b/src/wormhole/test/test_wormhole.py @@ -5,12 +5,13 @@ from twisted.trial import unittest from twisted.internet import reactor from twisted.internet.defer import gatherResults, inlineCallbacks, returnValue from twisted.internet.error import ConnectionRefusedError -from .common import ServerBase, poll_until, pause_one_tick +from .common import ServerBase, poll_until from .. import wormhole, _rendezvous from ..errors import (WrongPasswordError, ServerConnectionError, KeyFormatError, WormholeClosed, LonelyError, NoKeyError, OnlyOneCodeError) from ..transit import allocate_tcp_port +from ..eventual import EventualQueue APPID = "appid" @@ -159,9 +160,6 @@ class Wormholes(ServerBase, unittest.TestCase): verifier2 = yield w2.get_verifier() self.assertEqual(verifier1, verifier2) - self.successResultOf(w1.get_unverified_key()) - self.successResultOf(w2.get_unverified_key()) - versions1 = yield w1.get_versions() versions2 = yield w2.get_versions() # app-versions are exercised properly in test_versions, this just @@ -186,18 +184,22 @@ class Wormholes(ServerBase, unittest.TestCase): @inlineCallbacks def test_get_code_early(self): - w1 = wormhole.create(APPID, self.relayurl, reactor) + eq = EventualQueue(reactor) + w1 = wormhole.create(APPID, self.relayurl, reactor, _eventual_queue=eq) d = w1.get_code() w1.set_code("1-abc") + yield eq.flush() code = self.successResultOf(d) self.assertEqual(code, "1-abc") yield self.assertFailure(w1.close(), LonelyError) @inlineCallbacks def test_get_code_late(self): - w1 = wormhole.create(APPID, self.relayurl, reactor) + eq = EventualQueue(reactor) + w1 = wormhole.create(APPID, self.relayurl, reactor, _eventual_queue=eq) w1.set_code("1-abc") d = w1.get_code() + yield eq.flush() code = self.successResultOf(d) self.assertEqual(code, "1-abc") yield self.assertFailure(w1.close(), LonelyError) @@ -323,8 +325,9 @@ class Wormholes(ServerBase, unittest.TestCase): @inlineCallbacks def test_closed(self): - w1 = wormhole.create(APPID, self.relayurl, reactor) - w2 = wormhole.create(APPID, self.relayurl, reactor) + eq = EventualQueue(reactor) + w1 = wormhole.create(APPID, self.relayurl, reactor, _eventual_queue=eq) + w2 = wormhole.create(APPID, self.relayurl, reactor, _eventual_queue=eq) w1.set_code("123-foo") w2.set_code("123-foo") @@ -335,14 +338,14 @@ class Wormholes(ServerBase, unittest.TestCase): yield w1.close() yield w2.close() - # once closed, all Deferred-yielding API calls get an immediate error - self.failureResultOf(w1.get_welcome(), WormholeClosed) - f = self.failureResultOf(w1.get_code(), WormholeClosed) - self.assertEqual(f.value.args[0], "happy") - self.failureResultOf(w1.get_unverified_key(), WormholeClosed) - self.failureResultOf(w1.get_verifier(), WormholeClosed) - self.failureResultOf(w1.get_versions(), WormholeClosed) - self.failureResultOf(w1.get_message(), WormholeClosed) + # once closed, all Deferred-yielding API calls get an prompt error + yield self.assertFailure(w1.get_welcome(), WormholeClosed) + e = yield self.assertFailure(w1.get_code(), WormholeClosed) + self.assertEqual(e.args[0], "happy") + yield self.assertFailure(w1.get_unverified_key(), WormholeClosed) + yield self.assertFailure(w1.get_verifier(), WormholeClosed) + yield self.assertFailure(w1.get_versions(), WormholeClosed) + yield self.assertFailure(w1.get_message(), WormholeClosed) @inlineCallbacks def test_closed_idle(self): @@ -360,17 +363,18 @@ class Wormholes(ServerBase, unittest.TestCase): yield self.assertFailure(w1.close(), LonelyError) - self.failureResultOf(d_welcome, LonelyError) - self.failureResultOf(d_code, LonelyError) - self.failureResultOf(d_key, LonelyError) - self.failureResultOf(d_verifier, LonelyError) - self.failureResultOf(d_versions, LonelyError) - self.failureResultOf(d_message, LonelyError) + yield self.assertFailure(d_welcome, LonelyError) + yield self.assertFailure(d_code, LonelyError) + yield self.assertFailure(d_key, LonelyError) + yield self.assertFailure(d_verifier, LonelyError) + yield self.assertFailure(d_versions, LonelyError) + yield self.assertFailure(d_message, LonelyError) @inlineCallbacks def test_wrong_password(self): - w1 = wormhole.create(APPID, self.relayurl, reactor) - w2 = wormhole.create(APPID, self.relayurl, reactor) + eq = EventualQueue(reactor) + w1 = wormhole.create(APPID, self.relayurl, reactor, _eventual_queue=eq) + w2 = wormhole.create(APPID, self.relayurl, reactor, _eventual_queue=eq) w1.allocate_code() code = yield w1.get_code() w2.set_code(code+"not") @@ -403,9 +407,8 @@ class Wormholes(ServerBase, unittest.TestCase): # wait for each side to notice the failure yield self.assertFailure(w1.get_verifier(), WrongPasswordError) yield self.assertFailure(w2.get_verifier(), WrongPasswordError) - # and then wait for the rest of the loops to fire. if we had+used - # eventual-send, this wouldn't be a problem - yield pause_one_tick() + # the rest of the loops should fire within the next tick + yield eq.flush() # now all the rest should have fired already self.failureResultOf(d1_verified, WrongPasswordError) @@ -420,27 +423,27 @@ class Wormholes(ServerBase, unittest.TestCase): # before we close # any new calls in the error state should immediately fail - self.failureResultOf(w1.get_unverified_key(), WrongPasswordError) - self.failureResultOf(w1.get_verifier(), WrongPasswordError) - self.failureResultOf(w1.get_versions(), WrongPasswordError) - self.failureResultOf(w1.get_message(), WrongPasswordError) - self.failureResultOf(w2.get_unverified_key(), WrongPasswordError) - self.failureResultOf(w2.get_verifier(), WrongPasswordError) - self.failureResultOf(w2.get_versions(), WrongPasswordError) - self.failureResultOf(w2.get_message(), WrongPasswordError) + yield self.assertFailure(w1.get_unverified_key(), WrongPasswordError) + yield self.assertFailure(w1.get_verifier(), WrongPasswordError) + yield self.assertFailure(w1.get_versions(), WrongPasswordError) + yield self.assertFailure(w1.get_message(), WrongPasswordError) + yield self.assertFailure(w2.get_unverified_key(), WrongPasswordError) + yield self.assertFailure(w2.get_verifier(), WrongPasswordError) + yield self.assertFailure(w2.get_versions(), WrongPasswordError) + yield self.assertFailure(w2.get_message(), WrongPasswordError) yield self.assertFailure(w1.close(), WrongPasswordError) yield self.assertFailure(w2.close(), WrongPasswordError) # API calls should still get the error, not WormholeClosed - self.failureResultOf(w1.get_unverified_key(), WrongPasswordError) - self.failureResultOf(w1.get_verifier(), WrongPasswordError) - self.failureResultOf(w1.get_versions(), WrongPasswordError) - self.failureResultOf(w1.get_message(), WrongPasswordError) - self.failureResultOf(w2.get_unverified_key(), WrongPasswordError) - self.failureResultOf(w2.get_verifier(), WrongPasswordError) - self.failureResultOf(w2.get_versions(), WrongPasswordError) - self.failureResultOf(w2.get_message(), WrongPasswordError) + yield self.assertFailure(w1.get_unverified_key(), WrongPasswordError) + yield self.assertFailure(w1.get_verifier(), WrongPasswordError) + yield self.assertFailure(w1.get_versions(), WrongPasswordError) + yield self.assertFailure(w1.get_message(), WrongPasswordError) + yield self.assertFailure(w2.get_unverified_key(), WrongPasswordError) + yield self.assertFailure(w2.get_verifier(), WrongPasswordError) + yield self.assertFailure(w2.get_versions(), WrongPasswordError) + yield self.assertFailure(w2.get_message(), WrongPasswordError) @inlineCallbacks def test_wrong_password_with_spaces(self): @@ -493,8 +496,9 @@ class Wormholes(ServerBase, unittest.TestCase): @inlineCallbacks def test_verifier(self): - w1 = wormhole.create(APPID, self.relayurl, reactor) - w2 = wormhole.create(APPID, self.relayurl, reactor) + eq = EventualQueue(reactor) + w1 = wormhole.create(APPID, self.relayurl, reactor, _eventual_queue=eq) + w2 = wormhole.create(APPID, self.relayurl, reactor, _eventual_queue=eq) w1.allocate_code() code = yield w1.get_code() w2.set_code(code) @@ -510,7 +514,9 @@ class Wormholes(ServerBase, unittest.TestCase): self.assertEqual(dataY, b"data1") # calling get_verifier() this late should fire right away - v1_late = self.successResultOf(w2.get_verifier()) + d = w2.get_verifier() + yield eq.flush() + v1_late = self.successResultOf(d) self.assertEqual(v1_late, v1) yield w1.close() @@ -644,26 +650,30 @@ class Reconnection(ServerBase, unittest.TestCase): self.assertEqual(c2, "happy") class InitialFailure(unittest.TestCase): - def assertSCEResultOf(self, d, innerType): + @inlineCallbacks + def assertSCEFailure(self, eq, d, innerType): + yield eq.flush() f = self.failureResultOf(d, ServerConnectionError) inner = f.value.reason self.assertIsInstance(inner, innerType) - return inner + returnValue(inner) @inlineCallbacks def test_bad_dns(self): + eq = EventualQueue(reactor) # point at a URL that will never connect - w = wormhole.create(APPID, "ws://%%%.example.org:4000/v1", reactor) + w = wormhole.create(APPID, "ws://%%%.example.org:4000/v1", + reactor, _eventual_queue=eq) # that should have already received an error, when it tried to # resolve the bogus DNS name. All API calls will return an error. - e = yield self.assertFailure(w.get_unverified_key(), - ServerConnectionError) - self.assertIsInstance(e.reason, ValueError) + + e = yield self.assertSCEFailure(eq, w.get_unverified_key(), ValueError) + self.assertIsInstance(e, ValueError) self.assertEqual(str(e), "invalid hostname: %%%.example.org") - self.assertSCEResultOf(w.get_code(), ValueError) - self.assertSCEResultOf(w.get_verifier(), ValueError) - self.assertSCEResultOf(w.get_versions(), ValueError) - self.assertSCEResultOf(w.get_message(), ValueError) + yield self.assertSCEFailure(eq, w.get_code(), ValueError) + yield self.assertSCEFailure(eq, w.get_verifier(), ValueError) + yield self.assertSCEFailure(eq, w.get_versions(), ValueError) + yield self.assertSCEFailure(eq, w.get_message(), ValueError) @inlineCallbacks def assertSCE(self, d, innerType): diff --git a/src/wormhole/wormhole.py b/src/wormhole/wormhole.py index e58107f..72feb07 100644 --- a/src/wormhole/wormhole.py +++ b/src/wormhole/wormhole.py @@ -3,9 +3,10 @@ import os, sys from attr import attrs, attrib from zope.interface import implementer from twisted.python import failure -from twisted.internet import defer from ._interfaces import IWormhole, IDeferredWormhole from .util import bytes_to_hexstr +from .eventual import EventualQueue +from .observer import OneShotObserver, SequenceObserver from .timing import DebugTiming from .journal import ImmediateJournal from ._boss import Boss @@ -100,22 +101,16 @@ class _DelegatedWormhole(object): @implementer(IWormhole, IDeferredWormhole) class _DeferredWormhole(object): - def __init__(self): - self._welcome = None - self._welcome_observers = [] - self._code = None - self._code_observers = [] + def __init__(self, eq): + self._welcome_observer = OneShotObserver(eq) + self._code_observer = OneShotObserver(eq) self._key = None - self._key_observers = [] - self._verifier = None - self._verifier_observers = [] - self._versions = None - self._version_observers = [] - self._received_data = [] - self._received_observers = [] - self._observer_result = None - self._closed_result = None - self._closed_observers = [] + self._key_observer = OneShotObserver(eq) + self._verifier_observer = OneShotObserver(eq) + self._version_observer = OneShotObserver(eq) + self._received_observer = SequenceObserver(eq) + self._closed = False + self._closed_observer = OneShotObserver(eq) def _set_boss(self, boss): self._boss = boss @@ -127,58 +122,22 @@ class _DeferredWormhole(object): # the process that will cause it to fire, but forbidding that # ordering would make it easier to cause programming errors that # forget to trigger it entirely. - if self._observer_result is not None: - return defer.fail(self._observer_result) - if self._code is not None: - return defer.succeed(self._code) - d = defer.Deferred() - self._code_observers.append(d) - return d + return self._code_observer.when_fired() def get_welcome(self): - if self._observer_result is not None: - return defer.fail(self._observer_result) - if self._welcome is not None: - return defer.succeed(self._welcome) - d = defer.Deferred() - self._welcome_observers.append(d) - return d + return self._welcome_observer.when_fired() def get_unverified_key(self): - if self._observer_result is not None: - return defer.fail(self._observer_result) - if self._key is not None: - return defer.succeed(self._key) - d = defer.Deferred() - self._key_observers.append(d) - return d + return self._key_observer.when_fired() def get_verifier(self): - if self._observer_result is not None: - return defer.fail(self._observer_result) - if self._verifier is not None: - return defer.succeed(self._verifier) - d = defer.Deferred() - self._verifier_observers.append(d) - return d + return self._verifier_observer.when_fired() def get_versions(self): - if self._observer_result is not None: - return defer.fail(self._observer_result) - if self._versions is not None: - return defer.succeed(self._versions) - d = defer.Deferred() - self._version_observers.append(d) - return d + return self._version_observer.when_fired() def get_message(self): - if self._observer_result is not None: - return defer.fail(self._observer_result) - if self._received_data: - return defer.succeed(self._received_data.pop(0)) - d = defer.Deferred() - self._received_observers.append(d) - return d + return self._received_observer.when_next_event() def allocate_code(self, code_length=2): self._boss.allocate_code(code_length) @@ -207,11 +166,9 @@ class _DeferredWormhole(object): # fails with WormholeError unless we established a connection # (state=="happy"). Fails with WrongPasswordError (a subclass of # WormholeError) if state=="scary". - if self._closed_result: - return defer.succeed(self._closed_result) # maybe Failure - d = defer.Deferred() - self._closed_observers.append(d) - self._boss.close() # only need to close if it wasn't already + d = self._closed_observer.when_fired() # maybe Failure + if not self._closed: + self._boss.close() # only need to close if it wasn't already return d def debug_set_trace(self, client_name, @@ -221,75 +178,56 @@ class _DeferredWormhole(object): # from below def got_welcome(self, welcome): - self._welcome = welcome - for d in self._welcome_observers: - d.callback(welcome) - self._welcome_observers[:] = [] + self._welcome_observer.fire_if_not_fired(welcome) def got_code(self, code): - self._code = code - for d in self._code_observers: - d.callback(code) - self._code_observers[:] = [] + self._code_observer.fire_if_not_fired(code) def got_key(self, key): self._key = key # for derive_key() - for d in self._key_observers: - d.callback(key) - self._key_observers[:] = [] + self._key_observer.fire_if_not_fired(key) def got_verifier(self, verifier): - self._verifier = verifier - for d in self._verifier_observers: - d.callback(verifier) - self._verifier_observers[:] = [] + self._verifier_observer.fire_if_not_fired(verifier) def got_versions(self, versions): - self._versions = versions - for d in self._version_observers: - d.callback(versions) - self._version_observers[:] = [] + self._version_observer.fire_if_not_fired(versions) def received(self, plaintext): - if self._received_observers: - self._received_observers.pop(0).callback(plaintext) - return - self._received_data.append(plaintext) + self._received_observer.fire(plaintext) def closed(self, result): + self._closed = True #print("closed", result, type(result), file=sys.stderr) if isinstance(result, Exception): - self._observer_result = self._closed_result = failure.Failure(result) + # everything pending gets an error, including close() + f = failure.Failure(result) + self._closed_observer.error(f) else: - # pending w.key()/w.verify()/w.version()/w.read() get an error - self._observer_result = WormholeClosed(result) + # everything pending except close() gets an error: + # w.get_code()/welcome/unverified_key/verifier/versions/message + f = failure.Failure(WormholeClosed(result)) # but w.close() only gets error if we're unhappy - self._closed_result = result - for d in self._welcome_observers: - d.errback(self._observer_result) - for d in self._code_observers: - d.errback(self._observer_result) - for d in self._key_observers: - d.errback(self._observer_result) - for d in self._verifier_observers: - d.errback(self._observer_result) - for d in self._version_observers: - d.errback(self._observer_result) - for d in self._received_observers: - d.errback(self._observer_result) - for d in self._closed_observers: - d.callback(self._closed_result) + self._closed_observer.fire_if_not_fired(result) + self._welcome_observer.error(f) + self._code_observer.error(f) + self._key_observer.error(f) + self._verifier_observer.error(f) + self._version_observer.error(f) + self._received_observer.fire(f) def create(appid, relay_url, reactor, # use keyword args for everything else versions={}, delegate=None, journal=None, tor=None, timing=None, - stderr=sys.stderr): + stderr=sys.stderr, + _eventual_queue=None): timing = timing or DebugTiming() side = bytes_to_hexstr(os.urandom(5)) journal = journal or ImmediateJournal() + eq = _eventual_queue or EventualQueue(reactor) if delegate: w = _DelegatedWormhole(delegate) else: - w = _DeferredWormhole() + w = _DeferredWormhole(eq) wormhole_versions = {} # will be used to indicate Wormhole capabilities wormhole_versions["app_versions"] = versions # app-specific capabilities b = Boss(w, side, relay_url, appid, wormhole_versions,