if the first connection fails, abandon the wormhole
This provides a clear error in case the user doesn't have an internet connection at all, or something is so broken with their DNS or routing that they can't reach the server. I think this is better than waiting and retrying (silently) forever. If the first connection succeeds, but is then lost, subsequent retries occur without fanfare. closes #68
This commit is contained in:
parent
caa25a05a1
commit
8f97e4e7e2
|
@ -5,7 +5,7 @@ from attr import attrs, attrib
|
|||
from attr.validators import provides, instance_of, optional
|
||||
from zope.interface import implementer
|
||||
from twisted.python import log
|
||||
from twisted.internet import defer, endpoints
|
||||
from twisted.internet import defer, endpoints, task
|
||||
from twisted.application import internet
|
||||
from autobahn.twisted import websocket
|
||||
from . import _interfaces, errors
|
||||
|
@ -69,14 +69,22 @@ class RendezvousConnector(object):
|
|||
_timing = attrib(validator=provides(_interfaces.ITiming))
|
||||
|
||||
def __attrs_post_init__(self):
|
||||
self._have_made_a_successful_connection = False
|
||||
self._stopping = False
|
||||
|
||||
self._trace = None
|
||||
self._ws = None
|
||||
f = WSFactory(self, self._url)
|
||||
f.setProtocolOptions(autoPingInterval=60, autoPingTimeout=600)
|
||||
p = urlparse(self._url)
|
||||
ep = self._make_endpoint(p.hostname, p.port or 80)
|
||||
# TODO: change/wrap ClientService to fail if the first attempt fails
|
||||
self._connector = internet.ClientService(ep, f)
|
||||
faf = None if self._have_made_a_successful_connection else 1
|
||||
d = self._connector.whenConnected(failAfterFailures=faf)
|
||||
# if the initial connection fails, signal an error and shut down. do
|
||||
# this in a different reactor turn to avoid some hazards
|
||||
d.addBoth(lambda res: task.deferLater(self._reactor, 0.0, lambda: res))
|
||||
d.addErrback(self._initial_connection_failed)
|
||||
self._debug_record_inbound_f = None
|
||||
|
||||
def set_trace(self, f):
|
||||
|
@ -124,8 +132,11 @@ class RendezvousConnector(object):
|
|||
def stop(self):
|
||||
# ClientService.stopService is defined to "Stop attempting to
|
||||
# reconnect and close any existing connections"
|
||||
self._stopping = True # to catch _initial_connection_failed error
|
||||
d = defer.maybeDeferred(self._connector.stopService)
|
||||
d.addErrback(log.err) # TODO: deliver error upstairs?
|
||||
# ClientService.stopService always fires with None, even if the
|
||||
# initial connection failed, so log.err just in case
|
||||
d.addErrback(log.err)
|
||||
d.addBoth(self._stopped)
|
||||
|
||||
|
||||
|
@ -137,9 +148,21 @@ class RendezvousConnector(object):
|
|||
def tx_allocate(self):
|
||||
self._tx("allocate")
|
||||
|
||||
# from our ClientService
|
||||
def _initial_connection_failed(self, f):
|
||||
if not self._stopping:
|
||||
sce = errors.ServerConnectionError(f.value)
|
||||
d = defer.maybeDeferred(self._connector.stopService)
|
||||
# this should happen right away: the ClientService ought to be in
|
||||
# the "_waiting" state, and everything in the _waiting.stop
|
||||
# transition is immediate
|
||||
d.addErrback(log.err) # just in case something goes wrong
|
||||
d.addCallback(lambda _: self._B.error(sce))
|
||||
|
||||
# from our WSClient (the WebSocket protocol)
|
||||
def ws_open(self, proto):
|
||||
self._debug("R.connected")
|
||||
self._have_made_a_successful_connection = True
|
||||
self._ws = proto
|
||||
try:
|
||||
self._tx("bind", appid=self._appid, side=self._side)
|
||||
|
|
|
@ -10,7 +10,8 @@ from . import public_relay
|
|||
from .. import __version__
|
||||
from ..timing import DebugTiming
|
||||
from ..errors import (WrongPasswordError, WelcomeError, KeyFormatError,
|
||||
TransferError, NoTorError, UnsendableFileError)
|
||||
TransferError, NoTorError, UnsendableFileError,
|
||||
ServerConnectionError)
|
||||
from twisted.internet.defer import inlineCallbacks, maybeDeferred
|
||||
from twisted.python.failure import Failure
|
||||
from twisted.internet.task import react
|
||||
|
@ -118,6 +119,11 @@ def _dispatch_command(reactor, cfg, command):
|
|||
except TransferError as e:
|
||||
print(u"TransferError: %s" % six.text_type(e), file=cfg.stderr)
|
||||
raise SystemExit(1)
|
||||
except ServerConnectionError as e:
|
||||
msg = fill("ERROR: " + dedent(e.__doc__))
|
||||
msg += "\n" + six.text_type(e)
|
||||
print(msg, file=cfg.stderr)
|
||||
raise SystemExit(1)
|
||||
except Exception as e:
|
||||
# this prints a proper traceback, whereas
|
||||
# traceback.print_exc() just prints a TB to the "yield"
|
||||
|
|
|
@ -16,6 +16,13 @@ class UnsendableFileError(Exception):
|
|||
class ServerError(WormholeError):
|
||||
"""The relay server complained about something we did."""
|
||||
|
||||
class ServerConnectionError(WormholeError):
|
||||
"""We had a problem connecting to the relay server:"""
|
||||
def __init__(self, reason):
|
||||
self.reason = reason
|
||||
def __str__(self):
|
||||
return str(self.reason)
|
||||
|
||||
class Timeout(WormholeError):
|
||||
pass
|
||||
|
||||
|
|
|
@ -10,11 +10,12 @@ from twisted.python import procutils, log
|
|||
from twisted.internet import endpoints, reactor
|
||||
from twisted.internet.utils import getProcessOutputAndValue
|
||||
from twisted.internet.defer import gatherResults, inlineCallbacks, returnValue
|
||||
from twisted.internet.error import ConnectionRefusedError
|
||||
from .. import __version__
|
||||
from .common import ServerBase, config
|
||||
from ..cli import cmd_send, cmd_receive, welcome, cli
|
||||
from ..errors import (TransferError, WrongPasswordError, WelcomeError,
|
||||
UnsendableFileError)
|
||||
UnsendableFileError, ServerConnectionError)
|
||||
from .._interfaces import ITorManager
|
||||
from wormhole.server.cmd_server import MyPlugin
|
||||
from wormhole.server.cli import server
|
||||
|
@ -874,6 +875,61 @@ class NotWelcome(ServerBase, unittest.TestCase):
|
|||
f = yield self.assertFailure(receive_d, WelcomeError)
|
||||
self.assertEqual(str(f), "please upgrade XYZ")
|
||||
|
||||
class NoServer(ServerBase, unittest.TestCase):
|
||||
@inlineCallbacks
|
||||
def setUp(self):
|
||||
self._setup_relay(None)
|
||||
yield self._relay_server.disownServiceParent()
|
||||
|
||||
@inlineCallbacks
|
||||
def test_sender(self):
|
||||
cfg = config("send")
|
||||
cfg.hide_progress = True
|
||||
cfg.listen = False
|
||||
cfg.relay_url = self.relayurl
|
||||
cfg.transit_helper = ""
|
||||
cfg.stdout = io.StringIO()
|
||||
cfg.stderr = io.StringIO()
|
||||
|
||||
cfg.text = "hi"
|
||||
cfg.code = "1-abc"
|
||||
|
||||
send_d = cmd_send.send(cfg)
|
||||
e = yield self.assertFailure(send_d, ServerConnectionError)
|
||||
self.assertIsInstance(e.reason, ConnectionRefusedError)
|
||||
|
||||
@inlineCallbacks
|
||||
def test_sender_allocation(self):
|
||||
cfg = config("send")
|
||||
cfg.hide_progress = True
|
||||
cfg.listen = False
|
||||
cfg.relay_url = self.relayurl
|
||||
cfg.transit_helper = ""
|
||||
cfg.stdout = io.StringIO()
|
||||
cfg.stderr = io.StringIO()
|
||||
|
||||
cfg.text = "hi"
|
||||
|
||||
send_d = cmd_send.send(cfg)
|
||||
e = yield self.assertFailure(send_d, ServerConnectionError)
|
||||
self.assertIsInstance(e.reason, ConnectionRefusedError)
|
||||
|
||||
@inlineCallbacks
|
||||
def test_receiver(self):
|
||||
cfg = config("receive")
|
||||
cfg.hide_progress = True
|
||||
cfg.listen = False
|
||||
cfg.relay_url = self.relayurl
|
||||
cfg.transit_helper = ""
|
||||
cfg.stdout = io.StringIO()
|
||||
cfg.stderr = io.StringIO()
|
||||
|
||||
cfg.code = "1-abc"
|
||||
|
||||
receive_d = cmd_receive.receive(cfg)
|
||||
e = yield self.assertFailure(receive_d, ServerConnectionError)
|
||||
self.assertIsInstance(e.reason, ConnectionRefusedError)
|
||||
|
||||
class Cleanup(ServerBase, unittest.TestCase):
|
||||
|
||||
def make_config(self):
|
||||
|
@ -1083,6 +1139,18 @@ class Dispatch(unittest.TestCase):
|
|||
expected = "TransferError: abcd\n"
|
||||
self.assertEqual(cfg.stderr.getvalue(), expected)
|
||||
|
||||
@inlineCallbacks
|
||||
def test_server_connection_error(self):
|
||||
cfg = config("send")
|
||||
cfg.stderr = io.StringIO()
|
||||
def fake():
|
||||
raise ServerConnectionError(ValueError("abcd"))
|
||||
yield self.assertFailure(cli._dispatch_command(reactor, cfg, fake),
|
||||
SystemExit)
|
||||
expected = fill("ERROR: " + dedent(ServerConnectionError.__doc__))+"\n"
|
||||
expected += "abcd\n"
|
||||
self.assertEqual(cfg.stderr.getvalue(), expected)
|
||||
|
||||
@inlineCallbacks
|
||||
def test_other_error(self):
|
||||
cfg = config("send")
|
||||
|
|
|
@ -3,12 +3,14 @@ import io, re
|
|||
import mock
|
||||
from twisted.trial import unittest
|
||||
from twisted.internet import reactor
|
||||
from twisted.internet.defer import gatherResults, inlineCallbacks
|
||||
from twisted.internet.defer import gatherResults, inlineCallbacks, returnValue
|
||||
from twisted.internet.error import ConnectionRefusedError
|
||||
from .common import ServerBase, poll_until, pause_one_tick
|
||||
from .. import wormhole, _rendezvous
|
||||
from ..errors import (WrongPasswordError,
|
||||
from ..errors import (WrongPasswordError, ServerConnectionError,
|
||||
KeyFormatError, WormholeClosed, LonelyError,
|
||||
NoKeyError, OnlyOneCodeError)
|
||||
from ..transit import allocate_tcp_port
|
||||
|
||||
APPID = "appid"
|
||||
|
||||
|
@ -617,6 +619,70 @@ class Reconnection(ServerBase, unittest.TestCase):
|
|||
c2 = yield w2.close()
|
||||
self.assertEqual(c2, "happy")
|
||||
|
||||
class InitialFailure(unittest.TestCase):
|
||||
def assertSCEResultOf(self, d, innerType):
|
||||
f = self.failureResultOf(d, ServerConnectionError)
|
||||
inner = f.value.reason
|
||||
self.assertIsInstance(inner, innerType)
|
||||
return inner
|
||||
|
||||
@inlineCallbacks
|
||||
def test_bad_dns(self):
|
||||
# point at a URL that will never connect
|
||||
w = wormhole.create(APPID, "ws://%%%.example.org:4000/v1", reactor)
|
||||
# that should have already received an error, when it tried to
|
||||
# resolve the bogus DNS name. All API calls will return an error.
|
||||
e = yield self.assertFailure(w.get_unverified_key(),
|
||||
ServerConnectionError)
|
||||
self.assertIsInstance(e.reason, ValueError)
|
||||
self.assertEqual(str(e), "invalid hostname: %%%.example.org")
|
||||
self.assertSCEResultOf(w.get_code(), ValueError)
|
||||
self.assertSCEResultOf(w.get_verifier(), ValueError)
|
||||
self.assertSCEResultOf(w.get_versions(), ValueError)
|
||||
self.assertSCEResultOf(w.get_message(), ValueError)
|
||||
|
||||
@inlineCallbacks
|
||||
def assertSCE(self, d, innerType):
|
||||
e = yield self.assertFailure(d, ServerConnectionError)
|
||||
inner = e.reason
|
||||
self.assertIsInstance(inner, innerType)
|
||||
returnValue(inner)
|
||||
|
||||
@inlineCallbacks
|
||||
def test_no_connection(self):
|
||||
# point at a URL that will never connect
|
||||
port = allocate_tcp_port()
|
||||
w = wormhole.create(APPID, "ws://127.0.0.1:%d/v1" % port, reactor)
|
||||
# nothing is listening, but it will take a turn to discover that
|
||||
d1 = w.get_code()
|
||||
d2 = w.get_unverified_key()
|
||||
d3 = w.get_verifier()
|
||||
d4 = w.get_versions()
|
||||
d5 = w.get_message()
|
||||
yield self.assertSCE(d1, ConnectionRefusedError)
|
||||
yield self.assertSCE(d2, ConnectionRefusedError)
|
||||
yield self.assertSCE(d3, ConnectionRefusedError)
|
||||
yield self.assertSCE(d4, ConnectionRefusedError)
|
||||
yield self.assertSCE(d5, ConnectionRefusedError)
|
||||
|
||||
@inlineCallbacks
|
||||
def test_all_deferreds(self):
|
||||
# point at a URL that will never connect
|
||||
port = allocate_tcp_port()
|
||||
w = wormhole.create(APPID, "ws://127.0.0.1:%d/v1" % port, reactor)
|
||||
# nothing is listening, but it will take a turn to discover that
|
||||
w.allocate_code()
|
||||
d1 = w.get_code()
|
||||
d2 = w.get_unverified_key()
|
||||
d3 = w.get_verifier()
|
||||
d4 = w.get_versions()
|
||||
d5 = w.get_message()
|
||||
yield self.assertSCE(d1, ConnectionRefusedError)
|
||||
yield self.assertSCE(d2, ConnectionRefusedError)
|
||||
yield self.assertSCE(d3, ConnectionRefusedError)
|
||||
yield self.assertSCE(d4, ConnectionRefusedError)
|
||||
yield self.assertSCE(d5, ConnectionRefusedError)
|
||||
|
||||
class Trace(unittest.TestCase):
|
||||
def test_basic(self):
|
||||
w1 = wormhole.create(APPID, "ws://localhost:1", reactor)
|
||||
|
|
Loading…
Reference in New Issue
Block a user