Clean up error handling, close channel upon error

This commit is contained in:
Brian Warner 2016-04-25 17:57:02 -07:00
commit 1511f96c66
6 changed files with 150 additions and 99 deletions

View File

@ -1,9 +1,9 @@
from __future__ import print_function
import os, sys, json, binascii, six, tempfile, zipfile
from tqdm import tqdm
from twisted.internet import reactor, defer
from twisted.internet import reactor
from twisted.internet.defer import inlineCallbacks, returnValue
from ..twisted.transcribe import Wormhole, WrongPasswordError
from ..twisted.transcribe import Wormhole
from ..twisted.transit import TransitReceiver
from ..errors import TransferError
@ -14,6 +14,13 @@ class RespondError(Exception):
self.response = response
def receive(args, reactor=reactor):
"""I implement 'wormhole receive'. I return a Deferred that fires with
None (for success), or signals one of the following errors:
* WrongPasswordError: the two sides didn't use matching passwords
* Timeout: something didn't happen fast enough for our tastes
* TransferError: the sender rejected the transfer: verifier mismatch
* any other error: something unexpected happened
"""
return TwistedReceiver(args, reactor).go()
@ -26,9 +33,8 @@ class TwistedReceiver:
def msg(self, *args, **kwargs):
print(*args, file=self.args.stdout, **kwargs)
# TODO: @handle_server_error
@inlineCallbacks
def go(self):
d = defer.succeed(None)
tor_manager = None
if self.args.tor:
_start = self.args.timing.add_event("import TorManager")
@ -39,18 +45,10 @@ class TwistedReceiver:
# tor in parallel with everything else, make sure the TorManager
# can lazy-provide an endpoint, and overlap the startup process
# with the user handing off the wormhole code
d.addCallback(lambda _: tor_manager.start())
def _make_wormhole(_):
self._w = Wormhole(APPID, self.args.relay_url, tor_manager,
timing=self.args.timing,
reactor=self._reactor)
d.addCallback(_make_wormhole)
d.addCallback(lambda _: self._go(self._w, tor_manager))
def _always_close(res):
d2 = self._w.close()
d2.addBoth(lambda _: res)
return d2
d.addBoth(_always_close)
yield tor_manager.start()
w = Wormhole(APPID, self.args.relay_url, tor_manager,
timing=self.args.timing,
reactor=self._reactor)
# I wanted to do this instead:
#
# try:
@ -61,7 +59,9 @@ class TwistedReceiver:
# but when _go had a UsageError, the stacktrace was always displayed
# as coming from the "yield self._go" line, which wasn't very useful
# for tracking it down.
return d
d = self._go(w, tor_manager)
d.addBoth(w.close)
yield d
@inlineCallbacks
def _go(self, w, tor_manager):
@ -72,7 +72,7 @@ class TwistedReceiver:
try:
if "message" in them_d:
yield self.handle_text(them_d, w)
returnValue(0)
returnValue(None)
if "file" in them_d:
f = self.handle_file(them_d)
rp = yield self.establish_transit(w, them_d, tor_manager)
@ -92,8 +92,8 @@ class TwistedReceiver:
except RespondError as r:
data = json.dumps(r.response).encode("utf-8")
yield w.send_data(data)
raise SystemExit(1)
returnValue(0)
raise TransferError(r["error"])
returnValue(None)
@inlineCallbacks
def handle_code(self, w):
@ -113,10 +113,8 @@ class TwistedReceiver:
@inlineCallbacks
def get_data(self, w):
try:
them_bytes = yield w.get_data()
except WrongPasswordError as e:
raise TransferError(u"ERROR: " + e.explain())
# this may raise WrongPasswordError
them_bytes = yield w.get_data()
them_d = json.loads(them_bytes.decode("utf-8"))
if "error" in them_d:
raise TransferError(u"ERROR: " + them_d["error"])
@ -229,7 +227,7 @@ class TwistedReceiver:
self.msg()
self.msg(u"Connection dropped before full file received")
self.msg(u"got %d bytes, wanted %d" % (received, self.xfersize))
returnValue(1) # TODO: exit properly
raise TransferError("Connection dropped before full file received")
assert received == self.xfersize
def write_file(self, f):

View File

@ -5,13 +5,21 @@ from twisted.protocols import basic
from twisted.internet import reactor
from twisted.internet.defer import inlineCallbacks, returnValue
from ..errors import TransferError
from ..twisted.transcribe import Wormhole, WrongPasswordError
from ..twisted.transcribe import Wormhole
from ..twisted.transit import TransitSender
APPID = u"lothar.com/wormhole/text-or-file-xfer"
@inlineCallbacks
def send(args, reactor=reactor):
"""I implement 'wormhole send'. I return a Deferred that fires with None
(for success), or signals one of the following errors:
* WrongPasswordError: the two sides didn't use matching passwords
* Timeout: something didn't happen fast enough for our tastes
* TransferError: the receiver rejected the transfer: verifier mismatch,
permission not granted, ack not successful.
* any other error: something unexpected happened
"""
assert isinstance(args.relay_url, type(u""))
if args.zeromode:
assert not args.code
@ -43,6 +51,13 @@ def send(args, reactor=reactor):
w = Wormhole(APPID, args.relay_url, tor_manager, timing=args.timing,
reactor=reactor)
d = _send(reactor, w, args, phase1, fd_to_send, tor_manager)
d.addBoth(w.close)
yield d
@inlineCallbacks
def _send(reactor, w, args, phase1, fd_to_send, tor_manager):
transit_sender = None
if fd_to_send:
transit_sender = TransitSender(args.transit_helper,
no_listen=args.no_listen,
@ -87,18 +102,15 @@ def send(args, reactor=reactor):
my_phase1_bytes = json.dumps(phase1).encode("utf-8")
yield w.send_data(my_phase1_bytes)
try:
them_phase1_bytes = yield w.get_data()
except WrongPasswordError as e:
raise TransferError(e.explain())
# this may raise WrongPasswordError
them_phase1_bytes = yield w.get_data()
them_phase1 = json.loads(them_phase1_bytes.decode("utf-8"))
if fd_to_send is None:
if them_phase1["message_ack"] == "ok":
print(u"text message sent", file=args.stdout)
yield w.close()
returnValue(0) # terminates this function
returnValue(None) # terminates this function
raise TransferError("error sending text: %r" % (them_phase1,))
if "error" in them_phase1:
@ -108,10 +120,12 @@ def send(args, reactor=reactor):
raise TransferError("ambiguous response from remote, "
"transfer abandoned: %s" % (them_phase1,))
tdata = them_phase1["transit"]
yield w.close()
# XXX the downside of closing above, rather than here, is that it leaves
# the channel claimed for a longer time
#yield w.close()
yield _send_file_twisted(tdata, transit_sender, fd_to_send,
args.stdout, args.hide_progress, args.timing)
returnValue(0)
returnValue(None)
def build_phase1_data(args):
phase1 = {}

View File

@ -2,7 +2,7 @@ from __future__ import print_function
import os, sys
from twisted.internet.defer import maybeDeferred
from twisted.internet.task import react
from ..errors import TransferError
from ..errors import TransferError, WrongPasswordError, Timeout
from ..timing import DebugTiming
from .cli_args import parser
@ -35,6 +35,7 @@ def run(reactor, argv, cwd, stdout, stderr, executable=None):
args.timing = timing = DebugTiming()
timing.add_event("command dispatch")
# fires with None, or raises an error
d = maybeDeferred(dispatch, args)
def _maybe_dump_timing(res):
timing.add_event("exit")
@ -43,13 +44,12 @@ def run(reactor, argv, cwd, stdout, stderr, executable=None):
return res
d.addBoth(_maybe_dump_timing)
def _explain_error(f):
f.trap(TransferError)
# these three errors don't print a traceback, just an explanation
f.trap(TransferError, WrongPasswordError, Timeout)
print("ERROR:", f.value, file=stderr)
raise SystemExit(1)
d.addErrback(_explain_error)
def _rc(rc):
raise SystemExit(rc)
d.addCallback(_rc)
d.addCallback(lambda _: 0)
return d
def entry():

View File

@ -1,4 +1,4 @@
import functools, textwrap
import functools
class ServerError(Exception):
def __init__(self, message, relay):
@ -28,8 +28,6 @@ class WrongPasswordError(Exception):
chance.
"""
# or the data blob was corrupted, and that's why decrypt failed
def explain(self):
return textwrap.dedent(self.__doc__)
class ReflectionAttack(Exception):
"""An attacker (or bug) reflected our outgoing message back to us."""

View File

@ -3,12 +3,12 @@ import os, sys, re, io, zipfile, six
from twisted.trial import unittest
from twisted.python import procutils, log
from twisted.internet.utils import getProcessOutputAndValue
from twisted.internet.defer import inlineCallbacks
from twisted.internet.defer import gatherResults, inlineCallbacks
from .. import __version__
from .common import ServerBase
from ..cli import runner, cmd_send, cmd_receive
from ..cli.cmd_send import build_phase1_data
from ..errors import TransferError
from ..errors import TransferError, WrongPasswordError
from ..timing import DebugTiming
class Phase1Data(unittest.TestCase):
@ -306,15 +306,17 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase):
path=send_dir)
receive_d = getProcessOutputAndValue(wormhole_bin, receive_args,
path=receive_dir)
send_res = yield send_d
(send_res, receive_res) = yield gatherResults([send_d, receive_d],
True)
send_stdout = send_res[0].decode("utf-8")
send_stderr = send_res[1].decode("utf-8")
send_rc = send_res[2]
receive_res = yield receive_d
receive_stdout = receive_res[0].decode("utf-8")
receive_stderr = receive_res[1].decode("utf-8")
receive_rc = receive_res[2]
NL = os.linesep
self.assertEqual((send_rc, receive_rc), (0, 0),
(send_res, receive_res))
else:
sargs = runner.parser.parse_args(send_args)
sargs.cwd = send_dir
@ -330,22 +332,11 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase):
receive_d = cmd_receive.receive(rargs)
# The sender might fail, leaving the receiver hanging, or vice
# versa. If either side fails, cancel the other, so it won't
# matter which Deferred we wait upon first.
# versa. Make sure we don't wait on one side exclusively
def _oops(f, which):
log.msg("test_scripts: %s failed, cancelling both" % which)
send_d.cancel()
receive_d.cancel()
return f
send_d.addErrback(_oops, "send_d")
receive_d.addErrback(_oops, "receive_d")
send_rc = yield send_d
yield gatherResults([send_d, receive_d], True)
send_stdout = sargs.stdout.getvalue()
send_stderr = sargs.stderr.getvalue()
receive_rc = yield receive_d
receive_stdout = rargs.stdout.getvalue()
receive_stderr = rargs.stderr.getvalue()
@ -355,8 +346,10 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase):
self.maxDiff = None # show full output for assertion failures
self.failUnlessEqual(send_stderr, "")
self.failUnlessEqual(receive_stderr, "")
self.failUnlessEqual(send_stderr, "",
(send_stdout, send_stderr))
self.failUnlessEqual(receive_stderr, "",
(receive_stdout, receive_stderr))
# check sender
if mode == "text":
@ -417,9 +410,6 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase):
with open(fn, "r") as f:
self.failUnlessEqual(f.read(), message(i))
self.failUnlessEqual(send_rc, 0)
self.failUnlessEqual(receive_rc, 0)
def test_text(self):
return self._do_test()
def test_text_subprocess(self):
@ -436,3 +426,62 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase):
return self._do_test(mode="directory", addslash=True)
def test_directory_override(self):
return self._do_test(mode="directory", override_filename=True)
class Cleanup(ServerBase, unittest.TestCase):
@inlineCallbacks
def test_text(self):
# the rendezvous channel should be deleted after success
code = u"1-abc"
common_args = ["--hide-progress",
"--relay-url", self.relayurl,
"--transit-helper", ""]
sargs = runner.parser.parse_args(common_args +
["send",
"--text", "secret message",
"--code", code])
sargs.stdout = io.StringIO()
sargs.stderr = io.StringIO()
sargs.timing = DebugTiming()
rargs = runner.parser.parse_args(common_args +
["receive", code])
rargs.stdout = io.StringIO()
rargs.stderr = io.StringIO()
rargs.timing = DebugTiming()
send_d = cmd_send.send(sargs)
receive_d = cmd_receive.receive(rargs)
yield send_d
yield receive_d
cids = self._rendezvous.get_app(cmd_send.APPID).get_allocated()
self.assertEqual(len(cids), 0)
@inlineCallbacks
def test_text_wrong_password(self):
# if the password was wrong, the rendezvous channel should still be
# deleted
common_args = ["--hide-progress",
"--relay-url", self.relayurl,
"--transit-helper", ""]
sargs = runner.parser.parse_args(common_args +
["send",
"--text", "secret message",
"--code", u"1-abc"])
sargs.stdout = io.StringIO()
sargs.stderr = io.StringIO()
sargs.timing = DebugTiming()
rargs = runner.parser.parse_args(common_args +
["receive", u"1-WRONG"])
rargs.stdout = io.StringIO()
rargs.stderr = io.StringIO()
rargs.timing = DebugTiming()
send_d = cmd_send.send(sargs)
receive_d = cmd_receive.receive(rargs)
# both sides should be capable of detecting the mismatch
yield self.assertFailure(send_d, WrongPasswordError)
yield self.assertFailure(receive_d, WrongPasswordError)
cids = self._rendezvous.get_app(cmd_send.APPID).get_allocated()
self.assertEqual(len(cids), 0)

View File

@ -27,30 +27,6 @@ def make_confmsg(confkey, nonce):
def to_bytes(u):
return unicodedata.normalize("NFC", u).encode("utf-8")
def close_on_error(meth): # method decorator
# Clients report certain errors as "moods", so the server can make a
# rough count failed connections (due to mismatched passwords, attacks,
# or timeouts). We don't report precondition failures, as those are the
# responsibility/fault of the local application code. We count
# non-precondition errors in case they represent server-side problems.
def _wrapper(self, *args, **kwargs):
d = defer.maybeDeferred(meth, self, *args, **kwargs)
def _onerror(f):
if f.check(Timeout):
d2 = self.close(u"lonely")
elif f.check(WrongPasswordError):
d2 = self.close(u"scary")
elif f.check(TypeError, UsageError):
# preconditions don't warrant _close_with_error()
d2 = defer.succeed(None)
else:
d2 = self.close(u"errory")
d2.addBoth(lambda _: f)
return d2
d.addErrback(_onerror)
return d
return _wrapper
class WSClient(websocket.WebSocketClientProtocol):
def onOpen(self):
self.wormhole_open = True
@ -181,13 +157,16 @@ class Wormhole:
return self._signal_error(welcome["error"])
@inlineCallbacks
def _sleep(self):
if self._error: # don't sleep if the bed's already on fire
def _sleep(self, wake_on_error=True):
if wake_on_error and self._error:
# don't sleep if the bed's already on fire, unless we're waiting
# for the fire department to respond, in which case sure, keep on
# sleeping
raise self._error
d = defer.Deferred()
self._sleepers.append(d)
yield d
if self._error:
if wake_on_error and self._error:
raise self._error
def _wakeup(self):
@ -366,7 +345,6 @@ class Wormhole:
self._msg1 = unhexlify(d["msg1"].encode("ascii"))
return self
@close_on_error
@inlineCallbacks
def get_verifier(self):
if self._closed: raise UsageError
@ -458,7 +436,6 @@ class Wormhole:
data = box.decrypt(encrypted)
return data
@close_on_error
@inlineCallbacks
def send_data(self, outbound_data, phase=u"data", wait=False):
if not isinstance(outbound_data, type(b"")):
@ -481,7 +458,6 @@ class Wormhole:
yield self._msg_send(phase, outbound_encrypted, wait)
self._timing.finish_event(_sent)
@close_on_error
@inlineCallbacks
def get_data(self, phase=u"data"):
if not isinstance(phase, type(u"")): raise TypeError(type(phase))
@ -506,26 +482,42 @@ class Wormhole:
# TODO: schedule reconnect, unless we're done
@inlineCallbacks
def close(self, res=None, mood=u"happy"):
if not isinstance(mood, (type(None), type(u""))):
raise TypeError(type(mood))
def close(self, f=None, mood=None):
"""Do d.addBoth(w.close) at the end of your chain."""
if self._closed:
returnValue(None)
self._closed = True
if not self._ws:
returnValue(None)
if mood is None:
mood = u"happy"
if f:
if f.check(Timeout):
mood = u"lonely"
elif f.check(WrongPasswordError):
mood = u"scary"
elif f.check(TypeError, UsageError):
# preconditions don't warrant reporting mood
pass
else:
mood = u"errory" # other errors do
if not isinstance(mood, (type(None), type(u""))):
raise TypeError(type(mood))
self._timing.finish_event(self._timing_started, mood=mood)
yield self._deallocate(mood)
# TODO: mark WebSocket as don't-reconnect
self._ws.transport.loseConnection() # probably flushes
del self._ws
returnValue(f)
@inlineCallbacks
def _deallocate(self, mood=None):
def _deallocate(self, mood):
_sent = self._timing.add_event("close")
yield self._ws_send(u"deallocate", mood=mood)
while self._deallocated_status is None:
yield self._sleep()
yield self._sleep(wake_on_error=False)
self._timing.finish_event(_sent)
# TODO: set a timeout, don't wait forever for an ack
# TODO: if the connection is lost, let it go