Clean up error handling, close channel upon error

This commit is contained in:
Brian Warner 2016-04-25 17:57:02 -07:00
commit 1511f96c66
6 changed files with 150 additions and 99 deletions

View File

@ -1,9 +1,9 @@
from __future__ import print_function from __future__ import print_function
import os, sys, json, binascii, six, tempfile, zipfile import os, sys, json, binascii, six, tempfile, zipfile
from tqdm import tqdm from tqdm import tqdm
from twisted.internet import reactor, defer from twisted.internet import reactor
from twisted.internet.defer import inlineCallbacks, returnValue from twisted.internet.defer import inlineCallbacks, returnValue
from ..twisted.transcribe import Wormhole, WrongPasswordError from ..twisted.transcribe import Wormhole
from ..twisted.transit import TransitReceiver from ..twisted.transit import TransitReceiver
from ..errors import TransferError from ..errors import TransferError
@ -14,6 +14,13 @@ class RespondError(Exception):
self.response = response self.response = response
def receive(args, reactor=reactor): def receive(args, reactor=reactor):
"""I implement 'wormhole receive'. I return a Deferred that fires with
None (for success), or signals one of the following errors:
* WrongPasswordError: the two sides didn't use matching passwords
* Timeout: something didn't happen fast enough for our tastes
* TransferError: the sender rejected the transfer: verifier mismatch
* any other error: something unexpected happened
"""
return TwistedReceiver(args, reactor).go() return TwistedReceiver(args, reactor).go()
@ -26,9 +33,8 @@ class TwistedReceiver:
def msg(self, *args, **kwargs): def msg(self, *args, **kwargs):
print(*args, file=self.args.stdout, **kwargs) print(*args, file=self.args.stdout, **kwargs)
# TODO: @handle_server_error @inlineCallbacks
def go(self): def go(self):
d = defer.succeed(None)
tor_manager = None tor_manager = None
if self.args.tor: if self.args.tor:
_start = self.args.timing.add_event("import TorManager") _start = self.args.timing.add_event("import TorManager")
@ -39,18 +45,10 @@ class TwistedReceiver:
# tor in parallel with everything else, make sure the TorManager # tor in parallel with everything else, make sure the TorManager
# can lazy-provide an endpoint, and overlap the startup process # can lazy-provide an endpoint, and overlap the startup process
# with the user handing off the wormhole code # with the user handing off the wormhole code
d.addCallback(lambda _: tor_manager.start()) yield tor_manager.start()
def _make_wormhole(_): w = Wormhole(APPID, self.args.relay_url, tor_manager,
self._w = Wormhole(APPID, self.args.relay_url, tor_manager,
timing=self.args.timing, timing=self.args.timing,
reactor=self._reactor) reactor=self._reactor)
d.addCallback(_make_wormhole)
d.addCallback(lambda _: self._go(self._w, tor_manager))
def _always_close(res):
d2 = self._w.close()
d2.addBoth(lambda _: res)
return d2
d.addBoth(_always_close)
# I wanted to do this instead: # I wanted to do this instead:
# #
# try: # try:
@ -61,7 +59,9 @@ class TwistedReceiver:
# but when _go had a UsageError, the stacktrace was always displayed # but when _go had a UsageError, the stacktrace was always displayed
# as coming from the "yield self._go" line, which wasn't very useful # as coming from the "yield self._go" line, which wasn't very useful
# for tracking it down. # for tracking it down.
return d d = self._go(w, tor_manager)
d.addBoth(w.close)
yield d
@inlineCallbacks @inlineCallbacks
def _go(self, w, tor_manager): def _go(self, w, tor_manager):
@ -72,7 +72,7 @@ class TwistedReceiver:
try: try:
if "message" in them_d: if "message" in them_d:
yield self.handle_text(them_d, w) yield self.handle_text(them_d, w)
returnValue(0) returnValue(None)
if "file" in them_d: if "file" in them_d:
f = self.handle_file(them_d) f = self.handle_file(them_d)
rp = yield self.establish_transit(w, them_d, tor_manager) rp = yield self.establish_transit(w, them_d, tor_manager)
@ -92,8 +92,8 @@ class TwistedReceiver:
except RespondError as r: except RespondError as r:
data = json.dumps(r.response).encode("utf-8") data = json.dumps(r.response).encode("utf-8")
yield w.send_data(data) yield w.send_data(data)
raise SystemExit(1) raise TransferError(r["error"])
returnValue(0) returnValue(None)
@inlineCallbacks @inlineCallbacks
def handle_code(self, w): def handle_code(self, w):
@ -113,10 +113,8 @@ class TwistedReceiver:
@inlineCallbacks @inlineCallbacks
def get_data(self, w): def get_data(self, w):
try: # this may raise WrongPasswordError
them_bytes = yield w.get_data() them_bytes = yield w.get_data()
except WrongPasswordError as e:
raise TransferError(u"ERROR: " + e.explain())
them_d = json.loads(them_bytes.decode("utf-8")) them_d = json.loads(them_bytes.decode("utf-8"))
if "error" in them_d: if "error" in them_d:
raise TransferError(u"ERROR: " + them_d["error"]) raise TransferError(u"ERROR: " + them_d["error"])
@ -229,7 +227,7 @@ class TwistedReceiver:
self.msg() self.msg()
self.msg(u"Connection dropped before full file received") self.msg(u"Connection dropped before full file received")
self.msg(u"got %d bytes, wanted %d" % (received, self.xfersize)) self.msg(u"got %d bytes, wanted %d" % (received, self.xfersize))
returnValue(1) # TODO: exit properly raise TransferError("Connection dropped before full file received")
assert received == self.xfersize assert received == self.xfersize
def write_file(self, f): def write_file(self, f):

View File

@ -5,13 +5,21 @@ from twisted.protocols import basic
from twisted.internet import reactor from twisted.internet import reactor
from twisted.internet.defer import inlineCallbacks, returnValue from twisted.internet.defer import inlineCallbacks, returnValue
from ..errors import TransferError from ..errors import TransferError
from ..twisted.transcribe import Wormhole, WrongPasswordError from ..twisted.transcribe import Wormhole
from ..twisted.transit import TransitSender from ..twisted.transit import TransitSender
APPID = u"lothar.com/wormhole/text-or-file-xfer" APPID = u"lothar.com/wormhole/text-or-file-xfer"
@inlineCallbacks @inlineCallbacks
def send(args, reactor=reactor): def send(args, reactor=reactor):
"""I implement 'wormhole send'. I return a Deferred that fires with None
(for success), or signals one of the following errors:
* WrongPasswordError: the two sides didn't use matching passwords
* Timeout: something didn't happen fast enough for our tastes
* TransferError: the receiver rejected the transfer: verifier mismatch,
permission not granted, ack not successful.
* any other error: something unexpected happened
"""
assert isinstance(args.relay_url, type(u"")) assert isinstance(args.relay_url, type(u""))
if args.zeromode: if args.zeromode:
assert not args.code assert not args.code
@ -43,6 +51,13 @@ def send(args, reactor=reactor):
w = Wormhole(APPID, args.relay_url, tor_manager, timing=args.timing, w = Wormhole(APPID, args.relay_url, tor_manager, timing=args.timing,
reactor=reactor) reactor=reactor)
d = _send(reactor, w, args, phase1, fd_to_send, tor_manager)
d.addBoth(w.close)
yield d
@inlineCallbacks
def _send(reactor, w, args, phase1, fd_to_send, tor_manager):
transit_sender = None
if fd_to_send: if fd_to_send:
transit_sender = TransitSender(args.transit_helper, transit_sender = TransitSender(args.transit_helper,
no_listen=args.no_listen, no_listen=args.no_listen,
@ -87,18 +102,15 @@ def send(args, reactor=reactor):
my_phase1_bytes = json.dumps(phase1).encode("utf-8") my_phase1_bytes = json.dumps(phase1).encode("utf-8")
yield w.send_data(my_phase1_bytes) yield w.send_data(my_phase1_bytes)
try: # this may raise WrongPasswordError
them_phase1_bytes = yield w.get_data() them_phase1_bytes = yield w.get_data()
except WrongPasswordError as e:
raise TransferError(e.explain())
them_phase1 = json.loads(them_phase1_bytes.decode("utf-8")) them_phase1 = json.loads(them_phase1_bytes.decode("utf-8"))
if fd_to_send is None: if fd_to_send is None:
if them_phase1["message_ack"] == "ok": if them_phase1["message_ack"] == "ok":
print(u"text message sent", file=args.stdout) print(u"text message sent", file=args.stdout)
yield w.close() returnValue(None) # terminates this function
returnValue(0) # terminates this function
raise TransferError("error sending text: %r" % (them_phase1,)) raise TransferError("error sending text: %r" % (them_phase1,))
if "error" in them_phase1: if "error" in them_phase1:
@ -108,10 +120,12 @@ def send(args, reactor=reactor):
raise TransferError("ambiguous response from remote, " raise TransferError("ambiguous response from remote, "
"transfer abandoned: %s" % (them_phase1,)) "transfer abandoned: %s" % (them_phase1,))
tdata = them_phase1["transit"] tdata = them_phase1["transit"]
yield w.close() # XXX the downside of closing above, rather than here, is that it leaves
# the channel claimed for a longer time
#yield w.close()
yield _send_file_twisted(tdata, transit_sender, fd_to_send, yield _send_file_twisted(tdata, transit_sender, fd_to_send,
args.stdout, args.hide_progress, args.timing) args.stdout, args.hide_progress, args.timing)
returnValue(0) returnValue(None)
def build_phase1_data(args): def build_phase1_data(args):
phase1 = {} phase1 = {}

View File

@ -2,7 +2,7 @@ from __future__ import print_function
import os, sys import os, sys
from twisted.internet.defer import maybeDeferred from twisted.internet.defer import maybeDeferred
from twisted.internet.task import react from twisted.internet.task import react
from ..errors import TransferError from ..errors import TransferError, WrongPasswordError, Timeout
from ..timing import DebugTiming from ..timing import DebugTiming
from .cli_args import parser from .cli_args import parser
@ -35,6 +35,7 @@ def run(reactor, argv, cwd, stdout, stderr, executable=None):
args.timing = timing = DebugTiming() args.timing = timing = DebugTiming()
timing.add_event("command dispatch") timing.add_event("command dispatch")
# fires with None, or raises an error
d = maybeDeferred(dispatch, args) d = maybeDeferred(dispatch, args)
def _maybe_dump_timing(res): def _maybe_dump_timing(res):
timing.add_event("exit") timing.add_event("exit")
@ -43,13 +44,12 @@ def run(reactor, argv, cwd, stdout, stderr, executable=None):
return res return res
d.addBoth(_maybe_dump_timing) d.addBoth(_maybe_dump_timing)
def _explain_error(f): def _explain_error(f):
f.trap(TransferError) # these three errors don't print a traceback, just an explanation
f.trap(TransferError, WrongPasswordError, Timeout)
print("ERROR:", f.value, file=stderr) print("ERROR:", f.value, file=stderr)
raise SystemExit(1) raise SystemExit(1)
d.addErrback(_explain_error) d.addErrback(_explain_error)
def _rc(rc): d.addCallback(lambda _: 0)
raise SystemExit(rc)
d.addCallback(_rc)
return d return d
def entry(): def entry():

View File

@ -1,4 +1,4 @@
import functools, textwrap import functools
class ServerError(Exception): class ServerError(Exception):
def __init__(self, message, relay): def __init__(self, message, relay):
@ -28,8 +28,6 @@ class WrongPasswordError(Exception):
chance. chance.
""" """
# or the data blob was corrupted, and that's why decrypt failed # or the data blob was corrupted, and that's why decrypt failed
def explain(self):
return textwrap.dedent(self.__doc__)
class ReflectionAttack(Exception): class ReflectionAttack(Exception):
"""An attacker (or bug) reflected our outgoing message back to us.""" """An attacker (or bug) reflected our outgoing message back to us."""

View File

@ -3,12 +3,12 @@ import os, sys, re, io, zipfile, six
from twisted.trial import unittest from twisted.trial import unittest
from twisted.python import procutils, log from twisted.python import procutils, log
from twisted.internet.utils import getProcessOutputAndValue from twisted.internet.utils import getProcessOutputAndValue
from twisted.internet.defer import inlineCallbacks from twisted.internet.defer import gatherResults, inlineCallbacks
from .. import __version__ from .. import __version__
from .common import ServerBase from .common import ServerBase
from ..cli import runner, cmd_send, cmd_receive from ..cli import runner, cmd_send, cmd_receive
from ..cli.cmd_send import build_phase1_data from ..cli.cmd_send import build_phase1_data
from ..errors import TransferError from ..errors import TransferError, WrongPasswordError
from ..timing import DebugTiming from ..timing import DebugTiming
class Phase1Data(unittest.TestCase): class Phase1Data(unittest.TestCase):
@ -306,15 +306,17 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase):
path=send_dir) path=send_dir)
receive_d = getProcessOutputAndValue(wormhole_bin, receive_args, receive_d = getProcessOutputAndValue(wormhole_bin, receive_args,
path=receive_dir) path=receive_dir)
send_res = yield send_d (send_res, receive_res) = yield gatherResults([send_d, receive_d],
True)
send_stdout = send_res[0].decode("utf-8") send_stdout = send_res[0].decode("utf-8")
send_stderr = send_res[1].decode("utf-8") send_stderr = send_res[1].decode("utf-8")
send_rc = send_res[2] send_rc = send_res[2]
receive_res = yield receive_d
receive_stdout = receive_res[0].decode("utf-8") receive_stdout = receive_res[0].decode("utf-8")
receive_stderr = receive_res[1].decode("utf-8") receive_stderr = receive_res[1].decode("utf-8")
receive_rc = receive_res[2] receive_rc = receive_res[2]
NL = os.linesep NL = os.linesep
self.assertEqual((send_rc, receive_rc), (0, 0),
(send_res, receive_res))
else: else:
sargs = runner.parser.parse_args(send_args) sargs = runner.parser.parse_args(send_args)
sargs.cwd = send_dir sargs.cwd = send_dir
@ -330,22 +332,11 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase):
receive_d = cmd_receive.receive(rargs) receive_d = cmd_receive.receive(rargs)
# The sender might fail, leaving the receiver hanging, or vice # The sender might fail, leaving the receiver hanging, or vice
# versa. If either side fails, cancel the other, so it won't # versa. Make sure we don't wait on one side exclusively
# matter which Deferred we wait upon first.
def _oops(f, which): yield gatherResults([send_d, receive_d], True)
log.msg("test_scripts: %s failed, cancelling both" % which)
send_d.cancel()
receive_d.cancel()
return f
send_d.addErrback(_oops, "send_d")
receive_d.addErrback(_oops, "receive_d")
send_rc = yield send_d
send_stdout = sargs.stdout.getvalue() send_stdout = sargs.stdout.getvalue()
send_stderr = sargs.stderr.getvalue() send_stderr = sargs.stderr.getvalue()
receive_rc = yield receive_d
receive_stdout = rargs.stdout.getvalue() receive_stdout = rargs.stdout.getvalue()
receive_stderr = rargs.stderr.getvalue() receive_stderr = rargs.stderr.getvalue()
@ -355,8 +346,10 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase):
self.maxDiff = None # show full output for assertion failures self.maxDiff = None # show full output for assertion failures
self.failUnlessEqual(send_stderr, "") self.failUnlessEqual(send_stderr, "",
self.failUnlessEqual(receive_stderr, "") (send_stdout, send_stderr))
self.failUnlessEqual(receive_stderr, "",
(receive_stdout, receive_stderr))
# check sender # check sender
if mode == "text": if mode == "text":
@ -417,9 +410,6 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase):
with open(fn, "r") as f: with open(fn, "r") as f:
self.failUnlessEqual(f.read(), message(i)) self.failUnlessEqual(f.read(), message(i))
self.failUnlessEqual(send_rc, 0)
self.failUnlessEqual(receive_rc, 0)
def test_text(self): def test_text(self):
return self._do_test() return self._do_test()
def test_text_subprocess(self): def test_text_subprocess(self):
@ -436,3 +426,62 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase):
return self._do_test(mode="directory", addslash=True) return self._do_test(mode="directory", addslash=True)
def test_directory_override(self): def test_directory_override(self):
return self._do_test(mode="directory", override_filename=True) return self._do_test(mode="directory", override_filename=True)
class Cleanup(ServerBase, unittest.TestCase):
@inlineCallbacks
def test_text(self):
# the rendezvous channel should be deleted after success
code = u"1-abc"
common_args = ["--hide-progress",
"--relay-url", self.relayurl,
"--transit-helper", ""]
sargs = runner.parser.parse_args(common_args +
["send",
"--text", "secret message",
"--code", code])
sargs.stdout = io.StringIO()
sargs.stderr = io.StringIO()
sargs.timing = DebugTiming()
rargs = runner.parser.parse_args(common_args +
["receive", code])
rargs.stdout = io.StringIO()
rargs.stderr = io.StringIO()
rargs.timing = DebugTiming()
send_d = cmd_send.send(sargs)
receive_d = cmd_receive.receive(rargs)
yield send_d
yield receive_d
cids = self._rendezvous.get_app(cmd_send.APPID).get_allocated()
self.assertEqual(len(cids), 0)
@inlineCallbacks
def test_text_wrong_password(self):
# if the password was wrong, the rendezvous channel should still be
# deleted
common_args = ["--hide-progress",
"--relay-url", self.relayurl,
"--transit-helper", ""]
sargs = runner.parser.parse_args(common_args +
["send",
"--text", "secret message",
"--code", u"1-abc"])
sargs.stdout = io.StringIO()
sargs.stderr = io.StringIO()
sargs.timing = DebugTiming()
rargs = runner.parser.parse_args(common_args +
["receive", u"1-WRONG"])
rargs.stdout = io.StringIO()
rargs.stderr = io.StringIO()
rargs.timing = DebugTiming()
send_d = cmd_send.send(sargs)
receive_d = cmd_receive.receive(rargs)
# both sides should be capable of detecting the mismatch
yield self.assertFailure(send_d, WrongPasswordError)
yield self.assertFailure(receive_d, WrongPasswordError)
cids = self._rendezvous.get_app(cmd_send.APPID).get_allocated()
self.assertEqual(len(cids), 0)

View File

@ -27,30 +27,6 @@ def make_confmsg(confkey, nonce):
def to_bytes(u): def to_bytes(u):
return unicodedata.normalize("NFC", u).encode("utf-8") return unicodedata.normalize("NFC", u).encode("utf-8")
def close_on_error(meth): # method decorator
# Clients report certain errors as "moods", so the server can make a
# rough count failed connections (due to mismatched passwords, attacks,
# or timeouts). We don't report precondition failures, as those are the
# responsibility/fault of the local application code. We count
# non-precondition errors in case they represent server-side problems.
def _wrapper(self, *args, **kwargs):
d = defer.maybeDeferred(meth, self, *args, **kwargs)
def _onerror(f):
if f.check(Timeout):
d2 = self.close(u"lonely")
elif f.check(WrongPasswordError):
d2 = self.close(u"scary")
elif f.check(TypeError, UsageError):
# preconditions don't warrant _close_with_error()
d2 = defer.succeed(None)
else:
d2 = self.close(u"errory")
d2.addBoth(lambda _: f)
return d2
d.addErrback(_onerror)
return d
return _wrapper
class WSClient(websocket.WebSocketClientProtocol): class WSClient(websocket.WebSocketClientProtocol):
def onOpen(self): def onOpen(self):
self.wormhole_open = True self.wormhole_open = True
@ -181,13 +157,16 @@ class Wormhole:
return self._signal_error(welcome["error"]) return self._signal_error(welcome["error"])
@inlineCallbacks @inlineCallbacks
def _sleep(self): def _sleep(self, wake_on_error=True):
if self._error: # don't sleep if the bed's already on fire if wake_on_error and self._error:
# don't sleep if the bed's already on fire, unless we're waiting
# for the fire department to respond, in which case sure, keep on
# sleeping
raise self._error raise self._error
d = defer.Deferred() d = defer.Deferred()
self._sleepers.append(d) self._sleepers.append(d)
yield d yield d
if self._error: if wake_on_error and self._error:
raise self._error raise self._error
def _wakeup(self): def _wakeup(self):
@ -366,7 +345,6 @@ class Wormhole:
self._msg1 = unhexlify(d["msg1"].encode("ascii")) self._msg1 = unhexlify(d["msg1"].encode("ascii"))
return self return self
@close_on_error
@inlineCallbacks @inlineCallbacks
def get_verifier(self): def get_verifier(self):
if self._closed: raise UsageError if self._closed: raise UsageError
@ -458,7 +436,6 @@ class Wormhole:
data = box.decrypt(encrypted) data = box.decrypt(encrypted)
return data return data
@close_on_error
@inlineCallbacks @inlineCallbacks
def send_data(self, outbound_data, phase=u"data", wait=False): def send_data(self, outbound_data, phase=u"data", wait=False):
if not isinstance(outbound_data, type(b"")): if not isinstance(outbound_data, type(b"")):
@ -481,7 +458,6 @@ class Wormhole:
yield self._msg_send(phase, outbound_encrypted, wait) yield self._msg_send(phase, outbound_encrypted, wait)
self._timing.finish_event(_sent) self._timing.finish_event(_sent)
@close_on_error
@inlineCallbacks @inlineCallbacks
def get_data(self, phase=u"data"): def get_data(self, phase=u"data"):
if not isinstance(phase, type(u"")): raise TypeError(type(phase)) if not isinstance(phase, type(u"")): raise TypeError(type(phase))
@ -506,26 +482,42 @@ class Wormhole:
# TODO: schedule reconnect, unless we're done # TODO: schedule reconnect, unless we're done
@inlineCallbacks @inlineCallbacks
def close(self, res=None, mood=u"happy"): def close(self, f=None, mood=None):
if not isinstance(mood, (type(None), type(u""))): """Do d.addBoth(w.close) at the end of your chain."""
raise TypeError(type(mood))
if self._closed: if self._closed:
returnValue(None) returnValue(None)
self._closed = True self._closed = True
if not self._ws: if not self._ws:
returnValue(None) returnValue(None)
if mood is None:
mood = u"happy"
if f:
if f.check(Timeout):
mood = u"lonely"
elif f.check(WrongPasswordError):
mood = u"scary"
elif f.check(TypeError, UsageError):
# preconditions don't warrant reporting mood
pass
else:
mood = u"errory" # other errors do
if not isinstance(mood, (type(None), type(u""))):
raise TypeError(type(mood))
self._timing.finish_event(self._timing_started, mood=mood) self._timing.finish_event(self._timing_started, mood=mood)
yield self._deallocate(mood) yield self._deallocate(mood)
# TODO: mark WebSocket as don't-reconnect # TODO: mark WebSocket as don't-reconnect
self._ws.transport.loseConnection() # probably flushes self._ws.transport.loseConnection() # probably flushes
del self._ws del self._ws
returnValue(f)
@inlineCallbacks @inlineCallbacks
def _deallocate(self, mood=None): def _deallocate(self, mood):
_sent = self._timing.add_event("close") _sent = self._timing.add_event("close")
yield self._ws_send(u"deallocate", mood=mood) yield self._ws_send(u"deallocate", mood=mood)
while self._deallocated_status is None: while self._deallocated_status is None:
yield self._sleep() yield self._sleep(wake_on_error=False)
self._timing.finish_event(_sent) self._timing.finish_event(_sent)
# TODO: set a timeout, don't wait forever for an ack # TODO: set a timeout, don't wait forever for an ack
# TODO: if the connection is lost, let it go # TODO: if the connection is lost, let it go