Clean up error handling, close channel upon error
This commit is contained in:
commit
1511f96c66
|
@ -1,9 +1,9 @@
|
|||
from __future__ import print_function
|
||||
import os, sys, json, binascii, six, tempfile, zipfile
|
||||
from tqdm import tqdm
|
||||
from twisted.internet import reactor, defer
|
||||
from twisted.internet import reactor
|
||||
from twisted.internet.defer import inlineCallbacks, returnValue
|
||||
from ..twisted.transcribe import Wormhole, WrongPasswordError
|
||||
from ..twisted.transcribe import Wormhole
|
||||
from ..twisted.transit import TransitReceiver
|
||||
from ..errors import TransferError
|
||||
|
||||
|
@ -14,6 +14,13 @@ class RespondError(Exception):
|
|||
self.response = response
|
||||
|
||||
def receive(args, reactor=reactor):
|
||||
"""I implement 'wormhole receive'. I return a Deferred that fires with
|
||||
None (for success), or signals one of the following errors:
|
||||
* WrongPasswordError: the two sides didn't use matching passwords
|
||||
* Timeout: something didn't happen fast enough for our tastes
|
||||
* TransferError: the sender rejected the transfer: verifier mismatch
|
||||
* any other error: something unexpected happened
|
||||
"""
|
||||
return TwistedReceiver(args, reactor).go()
|
||||
|
||||
|
||||
|
@ -26,9 +33,8 @@ class TwistedReceiver:
|
|||
def msg(self, *args, **kwargs):
|
||||
print(*args, file=self.args.stdout, **kwargs)
|
||||
|
||||
# TODO: @handle_server_error
|
||||
@inlineCallbacks
|
||||
def go(self):
|
||||
d = defer.succeed(None)
|
||||
tor_manager = None
|
||||
if self.args.tor:
|
||||
_start = self.args.timing.add_event("import TorManager")
|
||||
|
@ -39,18 +45,10 @@ class TwistedReceiver:
|
|||
# tor in parallel with everything else, make sure the TorManager
|
||||
# can lazy-provide an endpoint, and overlap the startup process
|
||||
# with the user handing off the wormhole code
|
||||
d.addCallback(lambda _: tor_manager.start())
|
||||
def _make_wormhole(_):
|
||||
self._w = Wormhole(APPID, self.args.relay_url, tor_manager,
|
||||
yield tor_manager.start()
|
||||
w = Wormhole(APPID, self.args.relay_url, tor_manager,
|
||||
timing=self.args.timing,
|
||||
reactor=self._reactor)
|
||||
d.addCallback(_make_wormhole)
|
||||
d.addCallback(lambda _: self._go(self._w, tor_manager))
|
||||
def _always_close(res):
|
||||
d2 = self._w.close()
|
||||
d2.addBoth(lambda _: res)
|
||||
return d2
|
||||
d.addBoth(_always_close)
|
||||
# I wanted to do this instead:
|
||||
#
|
||||
# try:
|
||||
|
@ -61,7 +59,9 @@ class TwistedReceiver:
|
|||
# but when _go had a UsageError, the stacktrace was always displayed
|
||||
# as coming from the "yield self._go" line, which wasn't very useful
|
||||
# for tracking it down.
|
||||
return d
|
||||
d = self._go(w, tor_manager)
|
||||
d.addBoth(w.close)
|
||||
yield d
|
||||
|
||||
@inlineCallbacks
|
||||
def _go(self, w, tor_manager):
|
||||
|
@ -72,7 +72,7 @@ class TwistedReceiver:
|
|||
try:
|
||||
if "message" in them_d:
|
||||
yield self.handle_text(them_d, w)
|
||||
returnValue(0)
|
||||
returnValue(None)
|
||||
if "file" in them_d:
|
||||
f = self.handle_file(them_d)
|
||||
rp = yield self.establish_transit(w, them_d, tor_manager)
|
||||
|
@ -92,8 +92,8 @@ class TwistedReceiver:
|
|||
except RespondError as r:
|
||||
data = json.dumps(r.response).encode("utf-8")
|
||||
yield w.send_data(data)
|
||||
raise SystemExit(1)
|
||||
returnValue(0)
|
||||
raise TransferError(r["error"])
|
||||
returnValue(None)
|
||||
|
||||
@inlineCallbacks
|
||||
def handle_code(self, w):
|
||||
|
@ -113,10 +113,8 @@ class TwistedReceiver:
|
|||
|
||||
@inlineCallbacks
|
||||
def get_data(self, w):
|
||||
try:
|
||||
# this may raise WrongPasswordError
|
||||
them_bytes = yield w.get_data()
|
||||
except WrongPasswordError as e:
|
||||
raise TransferError(u"ERROR: " + e.explain())
|
||||
them_d = json.loads(them_bytes.decode("utf-8"))
|
||||
if "error" in them_d:
|
||||
raise TransferError(u"ERROR: " + them_d["error"])
|
||||
|
@ -229,7 +227,7 @@ class TwistedReceiver:
|
|||
self.msg()
|
||||
self.msg(u"Connection dropped before full file received")
|
||||
self.msg(u"got %d bytes, wanted %d" % (received, self.xfersize))
|
||||
returnValue(1) # TODO: exit properly
|
||||
raise TransferError("Connection dropped before full file received")
|
||||
assert received == self.xfersize
|
||||
|
||||
def write_file(self, f):
|
||||
|
|
|
@ -5,13 +5,21 @@ from twisted.protocols import basic
|
|||
from twisted.internet import reactor
|
||||
from twisted.internet.defer import inlineCallbacks, returnValue
|
||||
from ..errors import TransferError
|
||||
from ..twisted.transcribe import Wormhole, WrongPasswordError
|
||||
from ..twisted.transcribe import Wormhole
|
||||
from ..twisted.transit import TransitSender
|
||||
|
||||
APPID = u"lothar.com/wormhole/text-or-file-xfer"
|
||||
|
||||
@inlineCallbacks
|
||||
def send(args, reactor=reactor):
|
||||
"""I implement 'wormhole send'. I return a Deferred that fires with None
|
||||
(for success), or signals one of the following errors:
|
||||
* WrongPasswordError: the two sides didn't use matching passwords
|
||||
* Timeout: something didn't happen fast enough for our tastes
|
||||
* TransferError: the receiver rejected the transfer: verifier mismatch,
|
||||
permission not granted, ack not successful.
|
||||
* any other error: something unexpected happened
|
||||
"""
|
||||
assert isinstance(args.relay_url, type(u""))
|
||||
if args.zeromode:
|
||||
assert not args.code
|
||||
|
@ -43,6 +51,13 @@ def send(args, reactor=reactor):
|
|||
w = Wormhole(APPID, args.relay_url, tor_manager, timing=args.timing,
|
||||
reactor=reactor)
|
||||
|
||||
d = _send(reactor, w, args, phase1, fd_to_send, tor_manager)
|
||||
d.addBoth(w.close)
|
||||
yield d
|
||||
|
||||
@inlineCallbacks
|
||||
def _send(reactor, w, args, phase1, fd_to_send, tor_manager):
|
||||
transit_sender = None
|
||||
if fd_to_send:
|
||||
transit_sender = TransitSender(args.transit_helper,
|
||||
no_listen=args.no_listen,
|
||||
|
@ -87,18 +102,15 @@ def send(args, reactor=reactor):
|
|||
my_phase1_bytes = json.dumps(phase1).encode("utf-8")
|
||||
yield w.send_data(my_phase1_bytes)
|
||||
|
||||
try:
|
||||
# this may raise WrongPasswordError
|
||||
them_phase1_bytes = yield w.get_data()
|
||||
except WrongPasswordError as e:
|
||||
raise TransferError(e.explain())
|
||||
|
||||
them_phase1 = json.loads(them_phase1_bytes.decode("utf-8"))
|
||||
|
||||
if fd_to_send is None:
|
||||
if them_phase1["message_ack"] == "ok":
|
||||
print(u"text message sent", file=args.stdout)
|
||||
yield w.close()
|
||||
returnValue(0) # terminates this function
|
||||
returnValue(None) # terminates this function
|
||||
raise TransferError("error sending text: %r" % (them_phase1,))
|
||||
|
||||
if "error" in them_phase1:
|
||||
|
@ -108,10 +120,12 @@ def send(args, reactor=reactor):
|
|||
raise TransferError("ambiguous response from remote, "
|
||||
"transfer abandoned: %s" % (them_phase1,))
|
||||
tdata = them_phase1["transit"]
|
||||
yield w.close()
|
||||
# XXX the downside of closing above, rather than here, is that it leaves
|
||||
# the channel claimed for a longer time
|
||||
#yield w.close()
|
||||
yield _send_file_twisted(tdata, transit_sender, fd_to_send,
|
||||
args.stdout, args.hide_progress, args.timing)
|
||||
returnValue(0)
|
||||
returnValue(None)
|
||||
|
||||
def build_phase1_data(args):
|
||||
phase1 = {}
|
||||
|
|
|
@ -2,7 +2,7 @@ from __future__ import print_function
|
|||
import os, sys
|
||||
from twisted.internet.defer import maybeDeferred
|
||||
from twisted.internet.task import react
|
||||
from ..errors import TransferError
|
||||
from ..errors import TransferError, WrongPasswordError, Timeout
|
||||
from ..timing import DebugTiming
|
||||
from .cli_args import parser
|
||||
|
||||
|
@ -35,6 +35,7 @@ def run(reactor, argv, cwd, stdout, stderr, executable=None):
|
|||
args.timing = timing = DebugTiming()
|
||||
|
||||
timing.add_event("command dispatch")
|
||||
# fires with None, or raises an error
|
||||
d = maybeDeferred(dispatch, args)
|
||||
def _maybe_dump_timing(res):
|
||||
timing.add_event("exit")
|
||||
|
@ -43,13 +44,12 @@ def run(reactor, argv, cwd, stdout, stderr, executable=None):
|
|||
return res
|
||||
d.addBoth(_maybe_dump_timing)
|
||||
def _explain_error(f):
|
||||
f.trap(TransferError)
|
||||
# these three errors don't print a traceback, just an explanation
|
||||
f.trap(TransferError, WrongPasswordError, Timeout)
|
||||
print("ERROR:", f.value, file=stderr)
|
||||
raise SystemExit(1)
|
||||
d.addErrback(_explain_error)
|
||||
def _rc(rc):
|
||||
raise SystemExit(rc)
|
||||
d.addCallback(_rc)
|
||||
d.addCallback(lambda _: 0)
|
||||
return d
|
||||
|
||||
def entry():
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
import functools, textwrap
|
||||
import functools
|
||||
|
||||
class ServerError(Exception):
|
||||
def __init__(self, message, relay):
|
||||
|
@ -28,8 +28,6 @@ class WrongPasswordError(Exception):
|
|||
chance.
|
||||
"""
|
||||
# or the data blob was corrupted, and that's why decrypt failed
|
||||
def explain(self):
|
||||
return textwrap.dedent(self.__doc__)
|
||||
|
||||
class ReflectionAttack(Exception):
|
||||
"""An attacker (or bug) reflected our outgoing message back to us."""
|
||||
|
|
|
@ -3,12 +3,12 @@ import os, sys, re, io, zipfile, six
|
|||
from twisted.trial import unittest
|
||||
from twisted.python import procutils, log
|
||||
from twisted.internet.utils import getProcessOutputAndValue
|
||||
from twisted.internet.defer import inlineCallbacks
|
||||
from twisted.internet.defer import gatherResults, inlineCallbacks
|
||||
from .. import __version__
|
||||
from .common import ServerBase
|
||||
from ..cli import runner, cmd_send, cmd_receive
|
||||
from ..cli.cmd_send import build_phase1_data
|
||||
from ..errors import TransferError
|
||||
from ..errors import TransferError, WrongPasswordError
|
||||
from ..timing import DebugTiming
|
||||
|
||||
class Phase1Data(unittest.TestCase):
|
||||
|
@ -306,15 +306,17 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase):
|
|||
path=send_dir)
|
||||
receive_d = getProcessOutputAndValue(wormhole_bin, receive_args,
|
||||
path=receive_dir)
|
||||
send_res = yield send_d
|
||||
(send_res, receive_res) = yield gatherResults([send_d, receive_d],
|
||||
True)
|
||||
send_stdout = send_res[0].decode("utf-8")
|
||||
send_stderr = send_res[1].decode("utf-8")
|
||||
send_rc = send_res[2]
|
||||
receive_res = yield receive_d
|
||||
receive_stdout = receive_res[0].decode("utf-8")
|
||||
receive_stderr = receive_res[1].decode("utf-8")
|
||||
receive_rc = receive_res[2]
|
||||
NL = os.linesep
|
||||
self.assertEqual((send_rc, receive_rc), (0, 0),
|
||||
(send_res, receive_res))
|
||||
else:
|
||||
sargs = runner.parser.parse_args(send_args)
|
||||
sargs.cwd = send_dir
|
||||
|
@ -330,22 +332,11 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase):
|
|||
receive_d = cmd_receive.receive(rargs)
|
||||
|
||||
# The sender might fail, leaving the receiver hanging, or vice
|
||||
# versa. If either side fails, cancel the other, so it won't
|
||||
# matter which Deferred we wait upon first.
|
||||
# versa. Make sure we don't wait on one side exclusively
|
||||
|
||||
def _oops(f, which):
|
||||
log.msg("test_scripts: %s failed, cancelling both" % which)
|
||||
send_d.cancel()
|
||||
receive_d.cancel()
|
||||
return f
|
||||
send_d.addErrback(_oops, "send_d")
|
||||
receive_d.addErrback(_oops, "receive_d")
|
||||
|
||||
send_rc = yield send_d
|
||||
yield gatherResults([send_d, receive_d], True)
|
||||
send_stdout = sargs.stdout.getvalue()
|
||||
send_stderr = sargs.stderr.getvalue()
|
||||
|
||||
receive_rc = yield receive_d
|
||||
receive_stdout = rargs.stdout.getvalue()
|
||||
receive_stderr = rargs.stderr.getvalue()
|
||||
|
||||
|
@ -355,8 +346,10 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase):
|
|||
|
||||
self.maxDiff = None # show full output for assertion failures
|
||||
|
||||
self.failUnlessEqual(send_stderr, "")
|
||||
self.failUnlessEqual(receive_stderr, "")
|
||||
self.failUnlessEqual(send_stderr, "",
|
||||
(send_stdout, send_stderr))
|
||||
self.failUnlessEqual(receive_stderr, "",
|
||||
(receive_stdout, receive_stderr))
|
||||
|
||||
# check sender
|
||||
if mode == "text":
|
||||
|
@ -417,9 +410,6 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase):
|
|||
with open(fn, "r") as f:
|
||||
self.failUnlessEqual(f.read(), message(i))
|
||||
|
||||
self.failUnlessEqual(send_rc, 0)
|
||||
self.failUnlessEqual(receive_rc, 0)
|
||||
|
||||
def test_text(self):
|
||||
return self._do_test()
|
||||
def test_text_subprocess(self):
|
||||
|
@ -436,3 +426,62 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase):
|
|||
return self._do_test(mode="directory", addslash=True)
|
||||
def test_directory_override(self):
|
||||
return self._do_test(mode="directory", override_filename=True)
|
||||
|
||||
class Cleanup(ServerBase, unittest.TestCase):
|
||||
@inlineCallbacks
|
||||
def test_text(self):
|
||||
# the rendezvous channel should be deleted after success
|
||||
code = u"1-abc"
|
||||
common_args = ["--hide-progress",
|
||||
"--relay-url", self.relayurl,
|
||||
"--transit-helper", ""]
|
||||
sargs = runner.parser.parse_args(common_args +
|
||||
["send",
|
||||
"--text", "secret message",
|
||||
"--code", code])
|
||||
sargs.stdout = io.StringIO()
|
||||
sargs.stderr = io.StringIO()
|
||||
sargs.timing = DebugTiming()
|
||||
rargs = runner.parser.parse_args(common_args +
|
||||
["receive", code])
|
||||
rargs.stdout = io.StringIO()
|
||||
rargs.stderr = io.StringIO()
|
||||
rargs.timing = DebugTiming()
|
||||
send_d = cmd_send.send(sargs)
|
||||
receive_d = cmd_receive.receive(rargs)
|
||||
|
||||
yield send_d
|
||||
yield receive_d
|
||||
|
||||
cids = self._rendezvous.get_app(cmd_send.APPID).get_allocated()
|
||||
self.assertEqual(len(cids), 0)
|
||||
|
||||
@inlineCallbacks
|
||||
def test_text_wrong_password(self):
|
||||
# if the password was wrong, the rendezvous channel should still be
|
||||
# deleted
|
||||
common_args = ["--hide-progress",
|
||||
"--relay-url", self.relayurl,
|
||||
"--transit-helper", ""]
|
||||
sargs = runner.parser.parse_args(common_args +
|
||||
["send",
|
||||
"--text", "secret message",
|
||||
"--code", u"1-abc"])
|
||||
sargs.stdout = io.StringIO()
|
||||
sargs.stderr = io.StringIO()
|
||||
sargs.timing = DebugTiming()
|
||||
rargs = runner.parser.parse_args(common_args +
|
||||
["receive", u"1-WRONG"])
|
||||
rargs.stdout = io.StringIO()
|
||||
rargs.stderr = io.StringIO()
|
||||
rargs.timing = DebugTiming()
|
||||
send_d = cmd_send.send(sargs)
|
||||
receive_d = cmd_receive.receive(rargs)
|
||||
|
||||
# both sides should be capable of detecting the mismatch
|
||||
yield self.assertFailure(send_d, WrongPasswordError)
|
||||
yield self.assertFailure(receive_d, WrongPasswordError)
|
||||
|
||||
cids = self._rendezvous.get_app(cmd_send.APPID).get_allocated()
|
||||
self.assertEqual(len(cids), 0)
|
||||
|
||||
|
|
|
@ -27,30 +27,6 @@ def make_confmsg(confkey, nonce):
|
|||
def to_bytes(u):
|
||||
return unicodedata.normalize("NFC", u).encode("utf-8")
|
||||
|
||||
def close_on_error(meth): # method decorator
|
||||
# Clients report certain errors as "moods", so the server can make a
|
||||
# rough count failed connections (due to mismatched passwords, attacks,
|
||||
# or timeouts). We don't report precondition failures, as those are the
|
||||
# responsibility/fault of the local application code. We count
|
||||
# non-precondition errors in case they represent server-side problems.
|
||||
def _wrapper(self, *args, **kwargs):
|
||||
d = defer.maybeDeferred(meth, self, *args, **kwargs)
|
||||
def _onerror(f):
|
||||
if f.check(Timeout):
|
||||
d2 = self.close(u"lonely")
|
||||
elif f.check(WrongPasswordError):
|
||||
d2 = self.close(u"scary")
|
||||
elif f.check(TypeError, UsageError):
|
||||
# preconditions don't warrant _close_with_error()
|
||||
d2 = defer.succeed(None)
|
||||
else:
|
||||
d2 = self.close(u"errory")
|
||||
d2.addBoth(lambda _: f)
|
||||
return d2
|
||||
d.addErrback(_onerror)
|
||||
return d
|
||||
return _wrapper
|
||||
|
||||
class WSClient(websocket.WebSocketClientProtocol):
|
||||
def onOpen(self):
|
||||
self.wormhole_open = True
|
||||
|
@ -181,13 +157,16 @@ class Wormhole:
|
|||
return self._signal_error(welcome["error"])
|
||||
|
||||
@inlineCallbacks
|
||||
def _sleep(self):
|
||||
if self._error: # don't sleep if the bed's already on fire
|
||||
def _sleep(self, wake_on_error=True):
|
||||
if wake_on_error and self._error:
|
||||
# don't sleep if the bed's already on fire, unless we're waiting
|
||||
# for the fire department to respond, in which case sure, keep on
|
||||
# sleeping
|
||||
raise self._error
|
||||
d = defer.Deferred()
|
||||
self._sleepers.append(d)
|
||||
yield d
|
||||
if self._error:
|
||||
if wake_on_error and self._error:
|
||||
raise self._error
|
||||
|
||||
def _wakeup(self):
|
||||
|
@ -366,7 +345,6 @@ class Wormhole:
|
|||
self._msg1 = unhexlify(d["msg1"].encode("ascii"))
|
||||
return self
|
||||
|
||||
@close_on_error
|
||||
@inlineCallbacks
|
||||
def get_verifier(self):
|
||||
if self._closed: raise UsageError
|
||||
|
@ -458,7 +436,6 @@ class Wormhole:
|
|||
data = box.decrypt(encrypted)
|
||||
return data
|
||||
|
||||
@close_on_error
|
||||
@inlineCallbacks
|
||||
def send_data(self, outbound_data, phase=u"data", wait=False):
|
||||
if not isinstance(outbound_data, type(b"")):
|
||||
|
@ -481,7 +458,6 @@ class Wormhole:
|
|||
yield self._msg_send(phase, outbound_encrypted, wait)
|
||||
self._timing.finish_event(_sent)
|
||||
|
||||
@close_on_error
|
||||
@inlineCallbacks
|
||||
def get_data(self, phase=u"data"):
|
||||
if not isinstance(phase, type(u"")): raise TypeError(type(phase))
|
||||
|
@ -506,26 +482,42 @@ class Wormhole:
|
|||
# TODO: schedule reconnect, unless we're done
|
||||
|
||||
@inlineCallbacks
|
||||
def close(self, res=None, mood=u"happy"):
|
||||
if not isinstance(mood, (type(None), type(u""))):
|
||||
raise TypeError(type(mood))
|
||||
def close(self, f=None, mood=None):
|
||||
"""Do d.addBoth(w.close) at the end of your chain."""
|
||||
if self._closed:
|
||||
returnValue(None)
|
||||
self._closed = True
|
||||
if not self._ws:
|
||||
returnValue(None)
|
||||
|
||||
if mood is None:
|
||||
mood = u"happy"
|
||||
if f:
|
||||
if f.check(Timeout):
|
||||
mood = u"lonely"
|
||||
elif f.check(WrongPasswordError):
|
||||
mood = u"scary"
|
||||
elif f.check(TypeError, UsageError):
|
||||
# preconditions don't warrant reporting mood
|
||||
pass
|
||||
else:
|
||||
mood = u"errory" # other errors do
|
||||
if not isinstance(mood, (type(None), type(u""))):
|
||||
raise TypeError(type(mood))
|
||||
|
||||
self._timing.finish_event(self._timing_started, mood=mood)
|
||||
yield self._deallocate(mood)
|
||||
# TODO: mark WebSocket as don't-reconnect
|
||||
self._ws.transport.loseConnection() # probably flushes
|
||||
del self._ws
|
||||
returnValue(f)
|
||||
|
||||
@inlineCallbacks
|
||||
def _deallocate(self, mood=None):
|
||||
def _deallocate(self, mood):
|
||||
_sent = self._timing.add_event("close")
|
||||
yield self._ws_send(u"deallocate", mood=mood)
|
||||
while self._deallocated_status is None:
|
||||
yield self._sleep()
|
||||
yield self._sleep(wake_on_error=False)
|
||||
self._timing.finish_event(_sent)
|
||||
# TODO: set a timeout, don't wait forever for an ack
|
||||
# TODO: if the connection is lost, let it go
|
||||
|
|
Loading…
Reference in New Issue
Block a user