Merge branch '68-first-failure'

This commit is contained in:
Brian Warner 2017-06-14 11:34:31 +01:00
commit e7bb25907a
6 changed files with 178 additions and 8 deletions

View File

@ -28,7 +28,7 @@ setup(name="magic-wormhole",
install_requires=[ install_requires=[
"spake2==0.7", "pynacl", "spake2==0.7", "pynacl",
"six", "six",
"twisted[tls]", "twisted[tls] >= 17.5.0", # 17.5.0 adds failAfterFailures=
"autobahn[twisted] >= 0.14.1", "autobahn[twisted] >= 0.14.1",
"automat", "automat",
"hkdf", "hkdf",

View File

@ -5,7 +5,7 @@ from attr import attrs, attrib
from attr.validators import provides, instance_of, optional from attr.validators import provides, instance_of, optional
from zope.interface import implementer from zope.interface import implementer
from twisted.python import log from twisted.python import log
from twisted.internet import defer, endpoints from twisted.internet import defer, endpoints, task
from twisted.application import internet from twisted.application import internet
from autobahn.twisted import websocket from autobahn.twisted import websocket
from . import _interfaces, errors from . import _interfaces, errors
@ -69,14 +69,22 @@ class RendezvousConnector(object):
_timing = attrib(validator=provides(_interfaces.ITiming)) _timing = attrib(validator=provides(_interfaces.ITiming))
def __attrs_post_init__(self): def __attrs_post_init__(self):
self._have_made_a_successful_connection = False
self._stopping = False
self._trace = None self._trace = None
self._ws = None self._ws = None
f = WSFactory(self, self._url) f = WSFactory(self, self._url)
f.setProtocolOptions(autoPingInterval=60, autoPingTimeout=600) f.setProtocolOptions(autoPingInterval=60, autoPingTimeout=600)
p = urlparse(self._url) p = urlparse(self._url)
ep = self._make_endpoint(p.hostname, p.port or 80) ep = self._make_endpoint(p.hostname, p.port or 80)
# TODO: change/wrap ClientService to fail if the first attempt fails
self._connector = internet.ClientService(ep, f) self._connector = internet.ClientService(ep, f)
faf = None if self._have_made_a_successful_connection else 1
d = self._connector.whenConnected(failAfterFailures=faf)
# if the initial connection fails, signal an error and shut down. do
# this in a different reactor turn to avoid some hazards
d.addBoth(lambda res: task.deferLater(self._reactor, 0.0, lambda: res))
d.addErrback(self._initial_connection_failed)
self._debug_record_inbound_f = None self._debug_record_inbound_f = None
def set_trace(self, f): def set_trace(self, f):
@ -124,8 +132,11 @@ class RendezvousConnector(object):
def stop(self): def stop(self):
# ClientService.stopService is defined to "Stop attempting to # ClientService.stopService is defined to "Stop attempting to
# reconnect and close any existing connections" # reconnect and close any existing connections"
self._stopping = True # to catch _initial_connection_failed error
d = defer.maybeDeferred(self._connector.stopService) d = defer.maybeDeferred(self._connector.stopService)
d.addErrback(log.err) # TODO: deliver error upstairs? # ClientService.stopService always fires with None, even if the
# initial connection failed, so log.err just in case
d.addErrback(log.err)
d.addBoth(self._stopped) d.addBoth(self._stopped)
@ -137,9 +148,21 @@ class RendezvousConnector(object):
def tx_allocate(self): def tx_allocate(self):
self._tx("allocate") self._tx("allocate")
# from our ClientService
def _initial_connection_failed(self, f):
if not self._stopping:
sce = errors.ServerConnectionError(f.value)
d = defer.maybeDeferred(self._connector.stopService)
# this should happen right away: the ClientService ought to be in
# the "_waiting" state, and everything in the _waiting.stop
# transition is immediate
d.addErrback(log.err) # just in case something goes wrong
d.addCallback(lambda _: self._B.error(sce))
# from our WSClient (the WebSocket protocol) # from our WSClient (the WebSocket protocol)
def ws_open(self, proto): def ws_open(self, proto):
self._debug("R.connected") self._debug("R.connected")
self._have_made_a_successful_connection = True
self._ws = proto self._ws = proto
try: try:
self._tx("bind", appid=self._appid, side=self._side) self._tx("bind", appid=self._appid, side=self._side)

View File

@ -10,7 +10,8 @@ from . import public_relay
from .. import __version__ from .. import __version__
from ..timing import DebugTiming from ..timing import DebugTiming
from ..errors import (WrongPasswordError, WelcomeError, KeyFormatError, from ..errors import (WrongPasswordError, WelcomeError, KeyFormatError,
TransferError, NoTorError, UnsendableFileError) TransferError, NoTorError, UnsendableFileError,
ServerConnectionError)
from twisted.internet.defer import inlineCallbacks, maybeDeferred from twisted.internet.defer import inlineCallbacks, maybeDeferred
from twisted.python.failure import Failure from twisted.python.failure import Failure
from twisted.internet.task import react from twisted.internet.task import react
@ -118,6 +119,11 @@ def _dispatch_command(reactor, cfg, command):
except TransferError as e: except TransferError as e:
print(u"TransferError: %s" % six.text_type(e), file=cfg.stderr) print(u"TransferError: %s" % six.text_type(e), file=cfg.stderr)
raise SystemExit(1) raise SystemExit(1)
except ServerConnectionError as e:
msg = fill("ERROR: " + dedent(e.__doc__))
msg += "\n" + six.text_type(e)
print(msg, file=cfg.stderr)
raise SystemExit(1)
except Exception as e: except Exception as e:
# this prints a proper traceback, whereas # this prints a proper traceback, whereas
# traceback.print_exc() just prints a TB to the "yield" # traceback.print_exc() just prints a TB to the "yield"

View File

@ -16,6 +16,13 @@ class UnsendableFileError(Exception):
class ServerError(WormholeError): class ServerError(WormholeError):
"""The relay server complained about something we did.""" """The relay server complained about something we did."""
class ServerConnectionError(WormholeError):
"""We had a problem connecting to the relay server:"""
def __init__(self, reason):
self.reason = reason
def __str__(self):
return str(self.reason)
class Timeout(WormholeError): class Timeout(WormholeError):
pass pass

View File

@ -10,11 +10,12 @@ from twisted.python import procutils, log
from twisted.internet import endpoints, reactor from twisted.internet import endpoints, reactor
from twisted.internet.utils import getProcessOutputAndValue from twisted.internet.utils import getProcessOutputAndValue
from twisted.internet.defer import gatherResults, inlineCallbacks, returnValue from twisted.internet.defer import gatherResults, inlineCallbacks, returnValue
from twisted.internet.error import ConnectionRefusedError
from .. import __version__ from .. import __version__
from .common import ServerBase, config from .common import ServerBase, config
from ..cli import cmd_send, cmd_receive, welcome, cli from ..cli import cmd_send, cmd_receive, welcome, cli
from ..errors import (TransferError, WrongPasswordError, WelcomeError, from ..errors import (TransferError, WrongPasswordError, WelcomeError,
UnsendableFileError) UnsendableFileError, ServerConnectionError)
from .._interfaces import ITorManager from .._interfaces import ITorManager
from wormhole.server.cmd_server import MyPlugin from wormhole.server.cmd_server import MyPlugin
from wormhole.server.cli import server from wormhole.server.cli import server
@ -874,6 +875,61 @@ class NotWelcome(ServerBase, unittest.TestCase):
f = yield self.assertFailure(receive_d, WelcomeError) f = yield self.assertFailure(receive_d, WelcomeError)
self.assertEqual(str(f), "please upgrade XYZ") self.assertEqual(str(f), "please upgrade XYZ")
class NoServer(ServerBase, unittest.TestCase):
@inlineCallbacks
def setUp(self):
self._setup_relay(None)
yield self._relay_server.disownServiceParent()
@inlineCallbacks
def test_sender(self):
cfg = config("send")
cfg.hide_progress = True
cfg.listen = False
cfg.relay_url = self.relayurl
cfg.transit_helper = ""
cfg.stdout = io.StringIO()
cfg.stderr = io.StringIO()
cfg.text = "hi"
cfg.code = "1-abc"
send_d = cmd_send.send(cfg)
e = yield self.assertFailure(send_d, ServerConnectionError)
self.assertIsInstance(e.reason, ConnectionRefusedError)
@inlineCallbacks
def test_sender_allocation(self):
cfg = config("send")
cfg.hide_progress = True
cfg.listen = False
cfg.relay_url = self.relayurl
cfg.transit_helper = ""
cfg.stdout = io.StringIO()
cfg.stderr = io.StringIO()
cfg.text = "hi"
send_d = cmd_send.send(cfg)
e = yield self.assertFailure(send_d, ServerConnectionError)
self.assertIsInstance(e.reason, ConnectionRefusedError)
@inlineCallbacks
def test_receiver(self):
cfg = config("receive")
cfg.hide_progress = True
cfg.listen = False
cfg.relay_url = self.relayurl
cfg.transit_helper = ""
cfg.stdout = io.StringIO()
cfg.stderr = io.StringIO()
cfg.code = "1-abc"
receive_d = cmd_receive.receive(cfg)
e = yield self.assertFailure(receive_d, ServerConnectionError)
self.assertIsInstance(e.reason, ConnectionRefusedError)
class Cleanup(ServerBase, unittest.TestCase): class Cleanup(ServerBase, unittest.TestCase):
def make_config(self): def make_config(self):
@ -1083,6 +1139,18 @@ class Dispatch(unittest.TestCase):
expected = "TransferError: abcd\n" expected = "TransferError: abcd\n"
self.assertEqual(cfg.stderr.getvalue(), expected) self.assertEqual(cfg.stderr.getvalue(), expected)
@inlineCallbacks
def test_server_connection_error(self):
cfg = config("send")
cfg.stderr = io.StringIO()
def fake():
raise ServerConnectionError(ValueError("abcd"))
yield self.assertFailure(cli._dispatch_command(reactor, cfg, fake),
SystemExit)
expected = fill("ERROR: " + dedent(ServerConnectionError.__doc__))+"\n"
expected += "abcd\n"
self.assertEqual(cfg.stderr.getvalue(), expected)
@inlineCallbacks @inlineCallbacks
def test_other_error(self): def test_other_error(self):
cfg = config("send") cfg = config("send")

View File

@ -3,12 +3,14 @@ import io, re
import mock import mock
from twisted.trial import unittest from twisted.trial import unittest
from twisted.internet import reactor from twisted.internet import reactor
from twisted.internet.defer import gatherResults, inlineCallbacks from twisted.internet.defer import gatherResults, inlineCallbacks, returnValue
from twisted.internet.error import ConnectionRefusedError
from .common import ServerBase, poll_until, pause_one_tick from .common import ServerBase, poll_until, pause_one_tick
from .. import wormhole, _rendezvous from .. import wormhole, _rendezvous
from ..errors import (WrongPasswordError, from ..errors import (WrongPasswordError, ServerConnectionError,
KeyFormatError, WormholeClosed, LonelyError, KeyFormatError, WormholeClosed, LonelyError,
NoKeyError, OnlyOneCodeError) NoKeyError, OnlyOneCodeError)
from ..transit import allocate_tcp_port
APPID = "appid" APPID = "appid"
@ -617,6 +619,70 @@ class Reconnection(ServerBase, unittest.TestCase):
c2 = yield w2.close() c2 = yield w2.close()
self.assertEqual(c2, "happy") self.assertEqual(c2, "happy")
class InitialFailure(unittest.TestCase):
def assertSCEResultOf(self, d, innerType):
f = self.failureResultOf(d, ServerConnectionError)
inner = f.value.reason
self.assertIsInstance(inner, innerType)
return inner
@inlineCallbacks
def test_bad_dns(self):
# point at a URL that will never connect
w = wormhole.create(APPID, "ws://%%%.example.org:4000/v1", reactor)
# that should have already received an error, when it tried to
# resolve the bogus DNS name. All API calls will return an error.
e = yield self.assertFailure(w.get_unverified_key(),
ServerConnectionError)
self.assertIsInstance(e.reason, ValueError)
self.assertEqual(str(e), "invalid hostname: %%%.example.org")
self.assertSCEResultOf(w.get_code(), ValueError)
self.assertSCEResultOf(w.get_verifier(), ValueError)
self.assertSCEResultOf(w.get_versions(), ValueError)
self.assertSCEResultOf(w.get_message(), ValueError)
@inlineCallbacks
def assertSCE(self, d, innerType):
e = yield self.assertFailure(d, ServerConnectionError)
inner = e.reason
self.assertIsInstance(inner, innerType)
returnValue(inner)
@inlineCallbacks
def test_no_connection(self):
# point at a URL that will never connect
port = allocate_tcp_port()
w = wormhole.create(APPID, "ws://127.0.0.1:%d/v1" % port, reactor)
# nothing is listening, but it will take a turn to discover that
d1 = w.get_code()
d2 = w.get_unverified_key()
d3 = w.get_verifier()
d4 = w.get_versions()
d5 = w.get_message()
yield self.assertSCE(d1, ConnectionRefusedError)
yield self.assertSCE(d2, ConnectionRefusedError)
yield self.assertSCE(d3, ConnectionRefusedError)
yield self.assertSCE(d4, ConnectionRefusedError)
yield self.assertSCE(d5, ConnectionRefusedError)
@inlineCallbacks
def test_all_deferreds(self):
# point at a URL that will never connect
port = allocate_tcp_port()
w = wormhole.create(APPID, "ws://127.0.0.1:%d/v1" % port, reactor)
# nothing is listening, but it will take a turn to discover that
w.allocate_code()
d1 = w.get_code()
d2 = w.get_unverified_key()
d3 = w.get_verifier()
d4 = w.get_versions()
d5 = w.get_message()
yield self.assertSCE(d1, ConnectionRefusedError)
yield self.assertSCE(d2, ConnectionRefusedError)
yield self.assertSCE(d3, ConnectionRefusedError)
yield self.assertSCE(d4, ConnectionRefusedError)
yield self.assertSCE(d5, ConnectionRefusedError)
class Trace(unittest.TestCase): class Trace(unittest.TestCase):
def test_basic(self): def test_basic(self):
w1 = wormhole.create(APPID, "ws://localhost:1", reactor) w1 = wormhole.create(APPID, "ws://localhost:1", reactor)