Merge branch '68-first-failure'
This commit is contained in:
commit
e7bb25907a
2
setup.py
2
setup.py
|
@ -28,7 +28,7 @@ setup(name="magic-wormhole",
|
||||||
install_requires=[
|
install_requires=[
|
||||||
"spake2==0.7", "pynacl",
|
"spake2==0.7", "pynacl",
|
||||||
"six",
|
"six",
|
||||||
"twisted[tls]",
|
"twisted[tls] >= 17.5.0", # 17.5.0 adds failAfterFailures=
|
||||||
"autobahn[twisted] >= 0.14.1",
|
"autobahn[twisted] >= 0.14.1",
|
||||||
"automat",
|
"automat",
|
||||||
"hkdf",
|
"hkdf",
|
||||||
|
|
|
@ -5,7 +5,7 @@ from attr import attrs, attrib
|
||||||
from attr.validators import provides, instance_of, optional
|
from attr.validators import provides, instance_of, optional
|
||||||
from zope.interface import implementer
|
from zope.interface import implementer
|
||||||
from twisted.python import log
|
from twisted.python import log
|
||||||
from twisted.internet import defer, endpoints
|
from twisted.internet import defer, endpoints, task
|
||||||
from twisted.application import internet
|
from twisted.application import internet
|
||||||
from autobahn.twisted import websocket
|
from autobahn.twisted import websocket
|
||||||
from . import _interfaces, errors
|
from . import _interfaces, errors
|
||||||
|
@ -69,14 +69,22 @@ class RendezvousConnector(object):
|
||||||
_timing = attrib(validator=provides(_interfaces.ITiming))
|
_timing = attrib(validator=provides(_interfaces.ITiming))
|
||||||
|
|
||||||
def __attrs_post_init__(self):
|
def __attrs_post_init__(self):
|
||||||
|
self._have_made_a_successful_connection = False
|
||||||
|
self._stopping = False
|
||||||
|
|
||||||
self._trace = None
|
self._trace = None
|
||||||
self._ws = None
|
self._ws = None
|
||||||
f = WSFactory(self, self._url)
|
f = WSFactory(self, self._url)
|
||||||
f.setProtocolOptions(autoPingInterval=60, autoPingTimeout=600)
|
f.setProtocolOptions(autoPingInterval=60, autoPingTimeout=600)
|
||||||
p = urlparse(self._url)
|
p = urlparse(self._url)
|
||||||
ep = self._make_endpoint(p.hostname, p.port or 80)
|
ep = self._make_endpoint(p.hostname, p.port or 80)
|
||||||
# TODO: change/wrap ClientService to fail if the first attempt fails
|
|
||||||
self._connector = internet.ClientService(ep, f)
|
self._connector = internet.ClientService(ep, f)
|
||||||
|
faf = None if self._have_made_a_successful_connection else 1
|
||||||
|
d = self._connector.whenConnected(failAfterFailures=faf)
|
||||||
|
# if the initial connection fails, signal an error and shut down. do
|
||||||
|
# this in a different reactor turn to avoid some hazards
|
||||||
|
d.addBoth(lambda res: task.deferLater(self._reactor, 0.0, lambda: res))
|
||||||
|
d.addErrback(self._initial_connection_failed)
|
||||||
self._debug_record_inbound_f = None
|
self._debug_record_inbound_f = None
|
||||||
|
|
||||||
def set_trace(self, f):
|
def set_trace(self, f):
|
||||||
|
@ -124,8 +132,11 @@ class RendezvousConnector(object):
|
||||||
def stop(self):
|
def stop(self):
|
||||||
# ClientService.stopService is defined to "Stop attempting to
|
# ClientService.stopService is defined to "Stop attempting to
|
||||||
# reconnect and close any existing connections"
|
# reconnect and close any existing connections"
|
||||||
|
self._stopping = True # to catch _initial_connection_failed error
|
||||||
d = defer.maybeDeferred(self._connector.stopService)
|
d = defer.maybeDeferred(self._connector.stopService)
|
||||||
d.addErrback(log.err) # TODO: deliver error upstairs?
|
# ClientService.stopService always fires with None, even if the
|
||||||
|
# initial connection failed, so log.err just in case
|
||||||
|
d.addErrback(log.err)
|
||||||
d.addBoth(self._stopped)
|
d.addBoth(self._stopped)
|
||||||
|
|
||||||
|
|
||||||
|
@ -137,9 +148,21 @@ class RendezvousConnector(object):
|
||||||
def tx_allocate(self):
|
def tx_allocate(self):
|
||||||
self._tx("allocate")
|
self._tx("allocate")
|
||||||
|
|
||||||
|
# from our ClientService
|
||||||
|
def _initial_connection_failed(self, f):
|
||||||
|
if not self._stopping:
|
||||||
|
sce = errors.ServerConnectionError(f.value)
|
||||||
|
d = defer.maybeDeferred(self._connector.stopService)
|
||||||
|
# this should happen right away: the ClientService ought to be in
|
||||||
|
# the "_waiting" state, and everything in the _waiting.stop
|
||||||
|
# transition is immediate
|
||||||
|
d.addErrback(log.err) # just in case something goes wrong
|
||||||
|
d.addCallback(lambda _: self._B.error(sce))
|
||||||
|
|
||||||
# from our WSClient (the WebSocket protocol)
|
# from our WSClient (the WebSocket protocol)
|
||||||
def ws_open(self, proto):
|
def ws_open(self, proto):
|
||||||
self._debug("R.connected")
|
self._debug("R.connected")
|
||||||
|
self._have_made_a_successful_connection = True
|
||||||
self._ws = proto
|
self._ws = proto
|
||||||
try:
|
try:
|
||||||
self._tx("bind", appid=self._appid, side=self._side)
|
self._tx("bind", appid=self._appid, side=self._side)
|
||||||
|
|
|
@ -10,7 +10,8 @@ from . import public_relay
|
||||||
from .. import __version__
|
from .. import __version__
|
||||||
from ..timing import DebugTiming
|
from ..timing import DebugTiming
|
||||||
from ..errors import (WrongPasswordError, WelcomeError, KeyFormatError,
|
from ..errors import (WrongPasswordError, WelcomeError, KeyFormatError,
|
||||||
TransferError, NoTorError, UnsendableFileError)
|
TransferError, NoTorError, UnsendableFileError,
|
||||||
|
ServerConnectionError)
|
||||||
from twisted.internet.defer import inlineCallbacks, maybeDeferred
|
from twisted.internet.defer import inlineCallbacks, maybeDeferred
|
||||||
from twisted.python.failure import Failure
|
from twisted.python.failure import Failure
|
||||||
from twisted.internet.task import react
|
from twisted.internet.task import react
|
||||||
|
@ -118,6 +119,11 @@ def _dispatch_command(reactor, cfg, command):
|
||||||
except TransferError as e:
|
except TransferError as e:
|
||||||
print(u"TransferError: %s" % six.text_type(e), file=cfg.stderr)
|
print(u"TransferError: %s" % six.text_type(e), file=cfg.stderr)
|
||||||
raise SystemExit(1)
|
raise SystemExit(1)
|
||||||
|
except ServerConnectionError as e:
|
||||||
|
msg = fill("ERROR: " + dedent(e.__doc__))
|
||||||
|
msg += "\n" + six.text_type(e)
|
||||||
|
print(msg, file=cfg.stderr)
|
||||||
|
raise SystemExit(1)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# this prints a proper traceback, whereas
|
# this prints a proper traceback, whereas
|
||||||
# traceback.print_exc() just prints a TB to the "yield"
|
# traceback.print_exc() just prints a TB to the "yield"
|
||||||
|
|
|
@ -16,6 +16,13 @@ class UnsendableFileError(Exception):
|
||||||
class ServerError(WormholeError):
|
class ServerError(WormholeError):
|
||||||
"""The relay server complained about something we did."""
|
"""The relay server complained about something we did."""
|
||||||
|
|
||||||
|
class ServerConnectionError(WormholeError):
|
||||||
|
"""We had a problem connecting to the relay server:"""
|
||||||
|
def __init__(self, reason):
|
||||||
|
self.reason = reason
|
||||||
|
def __str__(self):
|
||||||
|
return str(self.reason)
|
||||||
|
|
||||||
class Timeout(WormholeError):
|
class Timeout(WormholeError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
|
@ -10,11 +10,12 @@ from twisted.python import procutils, log
|
||||||
from twisted.internet import endpoints, reactor
|
from twisted.internet import endpoints, reactor
|
||||||
from twisted.internet.utils import getProcessOutputAndValue
|
from twisted.internet.utils import getProcessOutputAndValue
|
||||||
from twisted.internet.defer import gatherResults, inlineCallbacks, returnValue
|
from twisted.internet.defer import gatherResults, inlineCallbacks, returnValue
|
||||||
|
from twisted.internet.error import ConnectionRefusedError
|
||||||
from .. import __version__
|
from .. import __version__
|
||||||
from .common import ServerBase, config
|
from .common import ServerBase, config
|
||||||
from ..cli import cmd_send, cmd_receive, welcome, cli
|
from ..cli import cmd_send, cmd_receive, welcome, cli
|
||||||
from ..errors import (TransferError, WrongPasswordError, WelcomeError,
|
from ..errors import (TransferError, WrongPasswordError, WelcomeError,
|
||||||
UnsendableFileError)
|
UnsendableFileError, ServerConnectionError)
|
||||||
from .._interfaces import ITorManager
|
from .._interfaces import ITorManager
|
||||||
from wormhole.server.cmd_server import MyPlugin
|
from wormhole.server.cmd_server import MyPlugin
|
||||||
from wormhole.server.cli import server
|
from wormhole.server.cli import server
|
||||||
|
@ -874,6 +875,61 @@ class NotWelcome(ServerBase, unittest.TestCase):
|
||||||
f = yield self.assertFailure(receive_d, WelcomeError)
|
f = yield self.assertFailure(receive_d, WelcomeError)
|
||||||
self.assertEqual(str(f), "please upgrade XYZ")
|
self.assertEqual(str(f), "please upgrade XYZ")
|
||||||
|
|
||||||
|
class NoServer(ServerBase, unittest.TestCase):
|
||||||
|
@inlineCallbacks
|
||||||
|
def setUp(self):
|
||||||
|
self._setup_relay(None)
|
||||||
|
yield self._relay_server.disownServiceParent()
|
||||||
|
|
||||||
|
@inlineCallbacks
|
||||||
|
def test_sender(self):
|
||||||
|
cfg = config("send")
|
||||||
|
cfg.hide_progress = True
|
||||||
|
cfg.listen = False
|
||||||
|
cfg.relay_url = self.relayurl
|
||||||
|
cfg.transit_helper = ""
|
||||||
|
cfg.stdout = io.StringIO()
|
||||||
|
cfg.stderr = io.StringIO()
|
||||||
|
|
||||||
|
cfg.text = "hi"
|
||||||
|
cfg.code = "1-abc"
|
||||||
|
|
||||||
|
send_d = cmd_send.send(cfg)
|
||||||
|
e = yield self.assertFailure(send_d, ServerConnectionError)
|
||||||
|
self.assertIsInstance(e.reason, ConnectionRefusedError)
|
||||||
|
|
||||||
|
@inlineCallbacks
|
||||||
|
def test_sender_allocation(self):
|
||||||
|
cfg = config("send")
|
||||||
|
cfg.hide_progress = True
|
||||||
|
cfg.listen = False
|
||||||
|
cfg.relay_url = self.relayurl
|
||||||
|
cfg.transit_helper = ""
|
||||||
|
cfg.stdout = io.StringIO()
|
||||||
|
cfg.stderr = io.StringIO()
|
||||||
|
|
||||||
|
cfg.text = "hi"
|
||||||
|
|
||||||
|
send_d = cmd_send.send(cfg)
|
||||||
|
e = yield self.assertFailure(send_d, ServerConnectionError)
|
||||||
|
self.assertIsInstance(e.reason, ConnectionRefusedError)
|
||||||
|
|
||||||
|
@inlineCallbacks
|
||||||
|
def test_receiver(self):
|
||||||
|
cfg = config("receive")
|
||||||
|
cfg.hide_progress = True
|
||||||
|
cfg.listen = False
|
||||||
|
cfg.relay_url = self.relayurl
|
||||||
|
cfg.transit_helper = ""
|
||||||
|
cfg.stdout = io.StringIO()
|
||||||
|
cfg.stderr = io.StringIO()
|
||||||
|
|
||||||
|
cfg.code = "1-abc"
|
||||||
|
|
||||||
|
receive_d = cmd_receive.receive(cfg)
|
||||||
|
e = yield self.assertFailure(receive_d, ServerConnectionError)
|
||||||
|
self.assertIsInstance(e.reason, ConnectionRefusedError)
|
||||||
|
|
||||||
class Cleanup(ServerBase, unittest.TestCase):
|
class Cleanup(ServerBase, unittest.TestCase):
|
||||||
|
|
||||||
def make_config(self):
|
def make_config(self):
|
||||||
|
@ -1083,6 +1139,18 @@ class Dispatch(unittest.TestCase):
|
||||||
expected = "TransferError: abcd\n"
|
expected = "TransferError: abcd\n"
|
||||||
self.assertEqual(cfg.stderr.getvalue(), expected)
|
self.assertEqual(cfg.stderr.getvalue(), expected)
|
||||||
|
|
||||||
|
@inlineCallbacks
|
||||||
|
def test_server_connection_error(self):
|
||||||
|
cfg = config("send")
|
||||||
|
cfg.stderr = io.StringIO()
|
||||||
|
def fake():
|
||||||
|
raise ServerConnectionError(ValueError("abcd"))
|
||||||
|
yield self.assertFailure(cli._dispatch_command(reactor, cfg, fake),
|
||||||
|
SystemExit)
|
||||||
|
expected = fill("ERROR: " + dedent(ServerConnectionError.__doc__))+"\n"
|
||||||
|
expected += "abcd\n"
|
||||||
|
self.assertEqual(cfg.stderr.getvalue(), expected)
|
||||||
|
|
||||||
@inlineCallbacks
|
@inlineCallbacks
|
||||||
def test_other_error(self):
|
def test_other_error(self):
|
||||||
cfg = config("send")
|
cfg = config("send")
|
||||||
|
|
|
@ -3,12 +3,14 @@ import io, re
|
||||||
import mock
|
import mock
|
||||||
from twisted.trial import unittest
|
from twisted.trial import unittest
|
||||||
from twisted.internet import reactor
|
from twisted.internet import reactor
|
||||||
from twisted.internet.defer import gatherResults, inlineCallbacks
|
from twisted.internet.defer import gatherResults, inlineCallbacks, returnValue
|
||||||
|
from twisted.internet.error import ConnectionRefusedError
|
||||||
from .common import ServerBase, poll_until, pause_one_tick
|
from .common import ServerBase, poll_until, pause_one_tick
|
||||||
from .. import wormhole, _rendezvous
|
from .. import wormhole, _rendezvous
|
||||||
from ..errors import (WrongPasswordError,
|
from ..errors import (WrongPasswordError, ServerConnectionError,
|
||||||
KeyFormatError, WormholeClosed, LonelyError,
|
KeyFormatError, WormholeClosed, LonelyError,
|
||||||
NoKeyError, OnlyOneCodeError)
|
NoKeyError, OnlyOneCodeError)
|
||||||
|
from ..transit import allocate_tcp_port
|
||||||
|
|
||||||
APPID = "appid"
|
APPID = "appid"
|
||||||
|
|
||||||
|
@ -617,6 +619,70 @@ class Reconnection(ServerBase, unittest.TestCase):
|
||||||
c2 = yield w2.close()
|
c2 = yield w2.close()
|
||||||
self.assertEqual(c2, "happy")
|
self.assertEqual(c2, "happy")
|
||||||
|
|
||||||
|
class InitialFailure(unittest.TestCase):
|
||||||
|
def assertSCEResultOf(self, d, innerType):
|
||||||
|
f = self.failureResultOf(d, ServerConnectionError)
|
||||||
|
inner = f.value.reason
|
||||||
|
self.assertIsInstance(inner, innerType)
|
||||||
|
return inner
|
||||||
|
|
||||||
|
@inlineCallbacks
|
||||||
|
def test_bad_dns(self):
|
||||||
|
# point at a URL that will never connect
|
||||||
|
w = wormhole.create(APPID, "ws://%%%.example.org:4000/v1", reactor)
|
||||||
|
# that should have already received an error, when it tried to
|
||||||
|
# resolve the bogus DNS name. All API calls will return an error.
|
||||||
|
e = yield self.assertFailure(w.get_unverified_key(),
|
||||||
|
ServerConnectionError)
|
||||||
|
self.assertIsInstance(e.reason, ValueError)
|
||||||
|
self.assertEqual(str(e), "invalid hostname: %%%.example.org")
|
||||||
|
self.assertSCEResultOf(w.get_code(), ValueError)
|
||||||
|
self.assertSCEResultOf(w.get_verifier(), ValueError)
|
||||||
|
self.assertSCEResultOf(w.get_versions(), ValueError)
|
||||||
|
self.assertSCEResultOf(w.get_message(), ValueError)
|
||||||
|
|
||||||
|
@inlineCallbacks
|
||||||
|
def assertSCE(self, d, innerType):
|
||||||
|
e = yield self.assertFailure(d, ServerConnectionError)
|
||||||
|
inner = e.reason
|
||||||
|
self.assertIsInstance(inner, innerType)
|
||||||
|
returnValue(inner)
|
||||||
|
|
||||||
|
@inlineCallbacks
|
||||||
|
def test_no_connection(self):
|
||||||
|
# point at a URL that will never connect
|
||||||
|
port = allocate_tcp_port()
|
||||||
|
w = wormhole.create(APPID, "ws://127.0.0.1:%d/v1" % port, reactor)
|
||||||
|
# nothing is listening, but it will take a turn to discover that
|
||||||
|
d1 = w.get_code()
|
||||||
|
d2 = w.get_unverified_key()
|
||||||
|
d3 = w.get_verifier()
|
||||||
|
d4 = w.get_versions()
|
||||||
|
d5 = w.get_message()
|
||||||
|
yield self.assertSCE(d1, ConnectionRefusedError)
|
||||||
|
yield self.assertSCE(d2, ConnectionRefusedError)
|
||||||
|
yield self.assertSCE(d3, ConnectionRefusedError)
|
||||||
|
yield self.assertSCE(d4, ConnectionRefusedError)
|
||||||
|
yield self.assertSCE(d5, ConnectionRefusedError)
|
||||||
|
|
||||||
|
@inlineCallbacks
|
||||||
|
def test_all_deferreds(self):
|
||||||
|
# point at a URL that will never connect
|
||||||
|
port = allocate_tcp_port()
|
||||||
|
w = wormhole.create(APPID, "ws://127.0.0.1:%d/v1" % port, reactor)
|
||||||
|
# nothing is listening, but it will take a turn to discover that
|
||||||
|
w.allocate_code()
|
||||||
|
d1 = w.get_code()
|
||||||
|
d2 = w.get_unverified_key()
|
||||||
|
d3 = w.get_verifier()
|
||||||
|
d4 = w.get_versions()
|
||||||
|
d5 = w.get_message()
|
||||||
|
yield self.assertSCE(d1, ConnectionRefusedError)
|
||||||
|
yield self.assertSCE(d2, ConnectionRefusedError)
|
||||||
|
yield self.assertSCE(d3, ConnectionRefusedError)
|
||||||
|
yield self.assertSCE(d4, ConnectionRefusedError)
|
||||||
|
yield self.assertSCE(d5, ConnectionRefusedError)
|
||||||
|
|
||||||
class Trace(unittest.TestCase):
|
class Trace(unittest.TestCase):
|
||||||
def test_basic(self):
|
def test_basic(self):
|
||||||
w1 = wormhole.create(APPID, "ws://localhost:1", reactor)
|
w1 = wormhole.create(APPID, "ws://localhost:1", reactor)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user