diff --git a/setup.py b/setup.py index 73b4e03..435bda3 100644 --- a/setup.py +++ b/setup.py @@ -28,7 +28,7 @@ setup(name="magic-wormhole", install_requires=[ "spake2==0.7", "pynacl", "six", - "twisted[tls]", + "twisted[tls] >= 17.5.0", # 17.5.0 adds failAfterFailures= "autobahn[twisted] >= 0.14.1", "automat", "hkdf", diff --git a/src/wormhole/_rendezvous.py b/src/wormhole/_rendezvous.py index 269ce17..47198f0 100644 --- a/src/wormhole/_rendezvous.py +++ b/src/wormhole/_rendezvous.py @@ -5,7 +5,7 @@ from attr import attrs, attrib from attr.validators import provides, instance_of, optional from zope.interface import implementer from twisted.python import log -from twisted.internet import defer, endpoints +from twisted.internet import defer, endpoints, task from twisted.application import internet from autobahn.twisted import websocket from . import _interfaces, errors @@ -69,14 +69,22 @@ class RendezvousConnector(object): _timing = attrib(validator=provides(_interfaces.ITiming)) def __attrs_post_init__(self): + self._have_made_a_successful_connection = False + self._stopping = False + self._trace = None self._ws = None f = WSFactory(self, self._url) f.setProtocolOptions(autoPingInterval=60, autoPingTimeout=600) p = urlparse(self._url) ep = self._make_endpoint(p.hostname, p.port or 80) - # TODO: change/wrap ClientService to fail if the first attempt fails self._connector = internet.ClientService(ep, f) + faf = None if self._have_made_a_successful_connection else 1 + d = self._connector.whenConnected(failAfterFailures=faf) + # if the initial connection fails, signal an error and shut down. do + # this in a different reactor turn to avoid some hazards + d.addBoth(lambda res: task.deferLater(self._reactor, 0.0, lambda: res)) + d.addErrback(self._initial_connection_failed) self._debug_record_inbound_f = None def set_trace(self, f): @@ -124,8 +132,11 @@ class RendezvousConnector(object): def stop(self): # ClientService.stopService is defined to "Stop attempting to # reconnect and close any existing connections" + self._stopping = True # to catch _initial_connection_failed error d = defer.maybeDeferred(self._connector.stopService) - d.addErrback(log.err) # TODO: deliver error upstairs? + # ClientService.stopService always fires with None, even if the + # initial connection failed, so log.err just in case + d.addErrback(log.err) d.addBoth(self._stopped) @@ -137,9 +148,21 @@ class RendezvousConnector(object): def tx_allocate(self): self._tx("allocate") + # from our ClientService + def _initial_connection_failed(self, f): + if not self._stopping: + sce = errors.ServerConnectionError(f.value) + d = defer.maybeDeferred(self._connector.stopService) + # this should happen right away: the ClientService ought to be in + # the "_waiting" state, and everything in the _waiting.stop + # transition is immediate + d.addErrback(log.err) # just in case something goes wrong + d.addCallback(lambda _: self._B.error(sce)) + # from our WSClient (the WebSocket protocol) def ws_open(self, proto): self._debug("R.connected") + self._have_made_a_successful_connection = True self._ws = proto try: self._tx("bind", appid=self._appid, side=self._side) diff --git a/src/wormhole/cli/cli.py b/src/wormhole/cli/cli.py index 46b6d68..1b85f15 100644 --- a/src/wormhole/cli/cli.py +++ b/src/wormhole/cli/cli.py @@ -10,7 +10,8 @@ from . import public_relay from .. import __version__ from ..timing import DebugTiming from ..errors import (WrongPasswordError, WelcomeError, KeyFormatError, - TransferError, NoTorError, UnsendableFileError) + TransferError, NoTorError, UnsendableFileError, + ServerConnectionError) from twisted.internet.defer import inlineCallbacks, maybeDeferred from twisted.python.failure import Failure from twisted.internet.task import react @@ -118,6 +119,11 @@ def _dispatch_command(reactor, cfg, command): except TransferError as e: print(u"TransferError: %s" % six.text_type(e), file=cfg.stderr) raise SystemExit(1) + except ServerConnectionError as e: + msg = fill("ERROR: " + dedent(e.__doc__)) + msg += "\n" + six.text_type(e) + print(msg, file=cfg.stderr) + raise SystemExit(1) except Exception as e: # this prints a proper traceback, whereas # traceback.print_exc() just prints a TB to the "yield" diff --git a/src/wormhole/errors.py b/src/wormhole/errors.py index 7763f73..87381b4 100644 --- a/src/wormhole/errors.py +++ b/src/wormhole/errors.py @@ -16,6 +16,13 @@ class UnsendableFileError(Exception): class ServerError(WormholeError): """The relay server complained about something we did.""" +class ServerConnectionError(WormholeError): + """We had a problem connecting to the relay server:""" + def __init__(self, reason): + self.reason = reason + def __str__(self): + return str(self.reason) + class Timeout(WormholeError): pass diff --git a/src/wormhole/test/test_cli.py b/src/wormhole/test/test_cli.py index 44d8cdf..110dada 100644 --- a/src/wormhole/test/test_cli.py +++ b/src/wormhole/test/test_cli.py @@ -10,11 +10,12 @@ from twisted.python import procutils, log from twisted.internet import endpoints, reactor from twisted.internet.utils import getProcessOutputAndValue from twisted.internet.defer import gatherResults, inlineCallbacks, returnValue +from twisted.internet.error import ConnectionRefusedError from .. import __version__ from .common import ServerBase, config from ..cli import cmd_send, cmd_receive, welcome, cli from ..errors import (TransferError, WrongPasswordError, WelcomeError, - UnsendableFileError) + UnsendableFileError, ServerConnectionError) from .._interfaces import ITorManager from wormhole.server.cmd_server import MyPlugin from wormhole.server.cli import server @@ -874,6 +875,61 @@ class NotWelcome(ServerBase, unittest.TestCase): f = yield self.assertFailure(receive_d, WelcomeError) self.assertEqual(str(f), "please upgrade XYZ") +class NoServer(ServerBase, unittest.TestCase): + @inlineCallbacks + def setUp(self): + self._setup_relay(None) + yield self._relay_server.disownServiceParent() + + @inlineCallbacks + def test_sender(self): + cfg = config("send") + cfg.hide_progress = True + cfg.listen = False + cfg.relay_url = self.relayurl + cfg.transit_helper = "" + cfg.stdout = io.StringIO() + cfg.stderr = io.StringIO() + + cfg.text = "hi" + cfg.code = "1-abc" + + send_d = cmd_send.send(cfg) + e = yield self.assertFailure(send_d, ServerConnectionError) + self.assertIsInstance(e.reason, ConnectionRefusedError) + + @inlineCallbacks + def test_sender_allocation(self): + cfg = config("send") + cfg.hide_progress = True + cfg.listen = False + cfg.relay_url = self.relayurl + cfg.transit_helper = "" + cfg.stdout = io.StringIO() + cfg.stderr = io.StringIO() + + cfg.text = "hi" + + send_d = cmd_send.send(cfg) + e = yield self.assertFailure(send_d, ServerConnectionError) + self.assertIsInstance(e.reason, ConnectionRefusedError) + + @inlineCallbacks + def test_receiver(self): + cfg = config("receive") + cfg.hide_progress = True + cfg.listen = False + cfg.relay_url = self.relayurl + cfg.transit_helper = "" + cfg.stdout = io.StringIO() + cfg.stderr = io.StringIO() + + cfg.code = "1-abc" + + receive_d = cmd_receive.receive(cfg) + e = yield self.assertFailure(receive_d, ServerConnectionError) + self.assertIsInstance(e.reason, ConnectionRefusedError) + class Cleanup(ServerBase, unittest.TestCase): def make_config(self): @@ -1083,6 +1139,18 @@ class Dispatch(unittest.TestCase): expected = "TransferError: abcd\n" self.assertEqual(cfg.stderr.getvalue(), expected) + @inlineCallbacks + def test_server_connection_error(self): + cfg = config("send") + cfg.stderr = io.StringIO() + def fake(): + raise ServerConnectionError(ValueError("abcd")) + yield self.assertFailure(cli._dispatch_command(reactor, cfg, fake), + SystemExit) + expected = fill("ERROR: " + dedent(ServerConnectionError.__doc__))+"\n" + expected += "abcd\n" + self.assertEqual(cfg.stderr.getvalue(), expected) + @inlineCallbacks def test_other_error(self): cfg = config("send") diff --git a/src/wormhole/test/test_wormhole.py b/src/wormhole/test/test_wormhole.py index 646bd74..222f0af 100644 --- a/src/wormhole/test/test_wormhole.py +++ b/src/wormhole/test/test_wormhole.py @@ -3,12 +3,14 @@ import io, re import mock from twisted.trial import unittest from twisted.internet import reactor -from twisted.internet.defer import gatherResults, inlineCallbacks +from twisted.internet.defer import gatherResults, inlineCallbacks, returnValue +from twisted.internet.error import ConnectionRefusedError from .common import ServerBase, poll_until, pause_one_tick from .. import wormhole, _rendezvous -from ..errors import (WrongPasswordError, +from ..errors import (WrongPasswordError, ServerConnectionError, KeyFormatError, WormholeClosed, LonelyError, NoKeyError, OnlyOneCodeError) +from ..transit import allocate_tcp_port APPID = "appid" @@ -617,6 +619,70 @@ class Reconnection(ServerBase, unittest.TestCase): c2 = yield w2.close() self.assertEqual(c2, "happy") +class InitialFailure(unittest.TestCase): + def assertSCEResultOf(self, d, innerType): + f = self.failureResultOf(d, ServerConnectionError) + inner = f.value.reason + self.assertIsInstance(inner, innerType) + return inner + + @inlineCallbacks + def test_bad_dns(self): + # point at a URL that will never connect + w = wormhole.create(APPID, "ws://%%%.example.org:4000/v1", reactor) + # that should have already received an error, when it tried to + # resolve the bogus DNS name. All API calls will return an error. + e = yield self.assertFailure(w.get_unverified_key(), + ServerConnectionError) + self.assertIsInstance(e.reason, ValueError) + self.assertEqual(str(e), "invalid hostname: %%%.example.org") + self.assertSCEResultOf(w.get_code(), ValueError) + self.assertSCEResultOf(w.get_verifier(), ValueError) + self.assertSCEResultOf(w.get_versions(), ValueError) + self.assertSCEResultOf(w.get_message(), ValueError) + + @inlineCallbacks + def assertSCE(self, d, innerType): + e = yield self.assertFailure(d, ServerConnectionError) + inner = e.reason + self.assertIsInstance(inner, innerType) + returnValue(inner) + + @inlineCallbacks + def test_no_connection(self): + # point at a URL that will never connect + port = allocate_tcp_port() + w = wormhole.create(APPID, "ws://127.0.0.1:%d/v1" % port, reactor) + # nothing is listening, but it will take a turn to discover that + d1 = w.get_code() + d2 = w.get_unverified_key() + d3 = w.get_verifier() + d4 = w.get_versions() + d5 = w.get_message() + yield self.assertSCE(d1, ConnectionRefusedError) + yield self.assertSCE(d2, ConnectionRefusedError) + yield self.assertSCE(d3, ConnectionRefusedError) + yield self.assertSCE(d4, ConnectionRefusedError) + yield self.assertSCE(d5, ConnectionRefusedError) + + @inlineCallbacks + def test_all_deferreds(self): + # point at a URL that will never connect + port = allocate_tcp_port() + w = wormhole.create(APPID, "ws://127.0.0.1:%d/v1" % port, reactor) + # nothing is listening, but it will take a turn to discover that + w.allocate_code() + d1 = w.get_code() + d2 = w.get_unverified_key() + d3 = w.get_verifier() + d4 = w.get_versions() + d5 = w.get_message() + yield self.assertSCE(d1, ConnectionRefusedError) + yield self.assertSCE(d2, ConnectionRefusedError) + yield self.assertSCE(d3, ConnectionRefusedError) + yield self.assertSCE(d4, ConnectionRefusedError) + yield self.assertSCE(d5, ConnectionRefusedError) + class Trace(unittest.TestCase): def test_basic(self): w1 = wormhole.create(APPID, "ws://localhost:1", reactor)