diff --git a/src/wormhole/cli/cmd_receive.py b/src/wormhole/cli/cmd_receive.py index b2fb894..cbe3037 100644 --- a/src/wormhole/cli/cmd_receive.py +++ b/src/wormhole/cli/cmd_receive.py @@ -1,9 +1,9 @@ from __future__ import print_function import os, sys, json, binascii, six, tempfile, zipfile from tqdm import tqdm -from twisted.internet import reactor, defer +from twisted.internet import reactor from twisted.internet.defer import inlineCallbacks, returnValue -from ..twisted.transcribe import Wormhole, WrongPasswordError +from ..twisted.transcribe import Wormhole from ..twisted.transit import TransitReceiver from ..errors import TransferError @@ -14,6 +14,13 @@ class RespondError(Exception): self.response = response def receive(args, reactor=reactor): + """I implement 'wormhole receive'. I return a Deferred that fires with + None (for success), or signals one of the following errors: + * WrongPasswordError: the two sides didn't use matching passwords + * Timeout: something didn't happen fast enough for our tastes + * TransferError: the sender rejected the transfer: verifier mismatch + * any other error: something unexpected happened + """ return TwistedReceiver(args, reactor).go() @@ -26,9 +33,8 @@ class TwistedReceiver: def msg(self, *args, **kwargs): print(*args, file=self.args.stdout, **kwargs) - # TODO: @handle_server_error + @inlineCallbacks def go(self): - d = defer.succeed(None) tor_manager = None if self.args.tor: _start = self.args.timing.add_event("import TorManager") @@ -39,18 +45,10 @@ class TwistedReceiver: # tor in parallel with everything else, make sure the TorManager # can lazy-provide an endpoint, and overlap the startup process # with the user handing off the wormhole code - d.addCallback(lambda _: tor_manager.start()) - def _make_wormhole(_): - self._w = Wormhole(APPID, self.args.relay_url, tor_manager, - timing=self.args.timing, - reactor=self._reactor) - d.addCallback(_make_wormhole) - d.addCallback(lambda _: self._go(self._w, tor_manager)) - def _always_close(res): - d2 = self._w.close() - d2.addBoth(lambda _: res) - return d2 - d.addBoth(_always_close) + yield tor_manager.start() + w = Wormhole(APPID, self.args.relay_url, tor_manager, + timing=self.args.timing, + reactor=self._reactor) # I wanted to do this instead: # # try: @@ -61,7 +59,9 @@ class TwistedReceiver: # but when _go had a UsageError, the stacktrace was always displayed # as coming from the "yield self._go" line, which wasn't very useful # for tracking it down. - return d + d = self._go(w, tor_manager) + d.addBoth(w.close) + yield d @inlineCallbacks def _go(self, w, tor_manager): @@ -72,7 +72,7 @@ class TwistedReceiver: try: if "message" in them_d: yield self.handle_text(them_d, w) - returnValue(0) + returnValue(None) if "file" in them_d: f = self.handle_file(them_d) rp = yield self.establish_transit(w, them_d, tor_manager) @@ -92,8 +92,8 @@ class TwistedReceiver: except RespondError as r: data = json.dumps(r.response).encode("utf-8") yield w.send_data(data) - raise SystemExit(1) - returnValue(0) + raise TransferError(r["error"]) + returnValue(None) @inlineCallbacks def handle_code(self, w): @@ -113,10 +113,8 @@ class TwistedReceiver: @inlineCallbacks def get_data(self, w): - try: - them_bytes = yield w.get_data() - except WrongPasswordError as e: - raise TransferError(u"ERROR: " + e.explain()) + # this may raise WrongPasswordError + them_bytes = yield w.get_data() them_d = json.loads(them_bytes.decode("utf-8")) if "error" in them_d: raise TransferError(u"ERROR: " + them_d["error"]) @@ -229,7 +227,7 @@ class TwistedReceiver: self.msg() self.msg(u"Connection dropped before full file received") self.msg(u"got %d bytes, wanted %d" % (received, self.xfersize)) - returnValue(1) # TODO: exit properly + raise TransferError("Connection dropped before full file received") assert received == self.xfersize def write_file(self, f): diff --git a/src/wormhole/cli/cmd_send.py b/src/wormhole/cli/cmd_send.py index 652a370..ce58372 100644 --- a/src/wormhole/cli/cmd_send.py +++ b/src/wormhole/cli/cmd_send.py @@ -5,13 +5,21 @@ from twisted.protocols import basic from twisted.internet import reactor from twisted.internet.defer import inlineCallbacks, returnValue from ..errors import TransferError -from ..twisted.transcribe import Wormhole, WrongPasswordError +from ..twisted.transcribe import Wormhole from ..twisted.transit import TransitSender APPID = u"lothar.com/wormhole/text-or-file-xfer" @inlineCallbacks def send(args, reactor=reactor): + """I implement 'wormhole send'. I return a Deferred that fires with None + (for success), or signals one of the following errors: + * WrongPasswordError: the two sides didn't use matching passwords + * Timeout: something didn't happen fast enough for our tastes + * TransferError: the receiver rejected the transfer: verifier mismatch, + permission not granted, ack not successful. + * any other error: something unexpected happened + """ assert isinstance(args.relay_url, type(u"")) if args.zeromode: assert not args.code @@ -43,6 +51,13 @@ def send(args, reactor=reactor): w = Wormhole(APPID, args.relay_url, tor_manager, timing=args.timing, reactor=reactor) + d = _send(reactor, w, args, phase1, fd_to_send, tor_manager) + d.addBoth(w.close) + yield d + +@inlineCallbacks +def _send(reactor, w, args, phase1, fd_to_send, tor_manager): + transit_sender = None if fd_to_send: transit_sender = TransitSender(args.transit_helper, no_listen=args.no_listen, @@ -87,18 +102,15 @@ def send(args, reactor=reactor): my_phase1_bytes = json.dumps(phase1).encode("utf-8") yield w.send_data(my_phase1_bytes) - try: - them_phase1_bytes = yield w.get_data() - except WrongPasswordError as e: - raise TransferError(e.explain()) + # this may raise WrongPasswordError + them_phase1_bytes = yield w.get_data() them_phase1 = json.loads(them_phase1_bytes.decode("utf-8")) if fd_to_send is None: if them_phase1["message_ack"] == "ok": print(u"text message sent", file=args.stdout) - yield w.close() - returnValue(0) # terminates this function + returnValue(None) # terminates this function raise TransferError("error sending text: %r" % (them_phase1,)) if "error" in them_phase1: @@ -108,10 +120,12 @@ def send(args, reactor=reactor): raise TransferError("ambiguous response from remote, " "transfer abandoned: %s" % (them_phase1,)) tdata = them_phase1["transit"] - yield w.close() + # XXX the downside of closing above, rather than here, is that it leaves + # the channel claimed for a longer time + #yield w.close() yield _send_file_twisted(tdata, transit_sender, fd_to_send, args.stdout, args.hide_progress, args.timing) - returnValue(0) + returnValue(None) def build_phase1_data(args): phase1 = {} diff --git a/src/wormhole/cli/runner.py b/src/wormhole/cli/runner.py index f2d1bea..1ed73d5 100644 --- a/src/wormhole/cli/runner.py +++ b/src/wormhole/cli/runner.py @@ -2,7 +2,7 @@ from __future__ import print_function import os, sys from twisted.internet.defer import maybeDeferred from twisted.internet.task import react -from ..errors import TransferError +from ..errors import TransferError, WrongPasswordError, Timeout from ..timing import DebugTiming from .cli_args import parser @@ -35,6 +35,7 @@ def run(reactor, argv, cwd, stdout, stderr, executable=None): args.timing = timing = DebugTiming() timing.add_event("command dispatch") + # fires with None, or raises an error d = maybeDeferred(dispatch, args) def _maybe_dump_timing(res): timing.add_event("exit") @@ -43,13 +44,12 @@ def run(reactor, argv, cwd, stdout, stderr, executable=None): return res d.addBoth(_maybe_dump_timing) def _explain_error(f): - f.trap(TransferError) + # these three errors don't print a traceback, just an explanation + f.trap(TransferError, WrongPasswordError, Timeout) print("ERROR:", f.value, file=stderr) raise SystemExit(1) d.addErrback(_explain_error) - def _rc(rc): - raise SystemExit(rc) - d.addCallback(_rc) + d.addCallback(lambda _: 0) return d def entry(): diff --git a/src/wormhole/errors.py b/src/wormhole/errors.py index 2846626..898bafd 100644 --- a/src/wormhole/errors.py +++ b/src/wormhole/errors.py @@ -1,4 +1,4 @@ -import functools, textwrap +import functools class ServerError(Exception): def __init__(self, message, relay): @@ -28,8 +28,6 @@ class WrongPasswordError(Exception): chance. """ # or the data blob was corrupted, and that's why decrypt failed - def explain(self): - return textwrap.dedent(self.__doc__) class ReflectionAttack(Exception): """An attacker (or bug) reflected our outgoing message back to us.""" diff --git a/src/wormhole/test/test_scripts.py b/src/wormhole/test/test_scripts.py index 7f65de6..b39f2c5 100644 --- a/src/wormhole/test/test_scripts.py +++ b/src/wormhole/test/test_scripts.py @@ -3,12 +3,12 @@ import os, sys, re, io, zipfile, six from twisted.trial import unittest from twisted.python import procutils, log from twisted.internet.utils import getProcessOutputAndValue -from twisted.internet.defer import inlineCallbacks +from twisted.internet.defer import gatherResults, inlineCallbacks from .. import __version__ from .common import ServerBase from ..cli import runner, cmd_send, cmd_receive from ..cli.cmd_send import build_phase1_data -from ..errors import TransferError +from ..errors import TransferError, WrongPasswordError from ..timing import DebugTiming class Phase1Data(unittest.TestCase): @@ -306,15 +306,17 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase): path=send_dir) receive_d = getProcessOutputAndValue(wormhole_bin, receive_args, path=receive_dir) - send_res = yield send_d + (send_res, receive_res) = yield gatherResults([send_d, receive_d], + True) send_stdout = send_res[0].decode("utf-8") send_stderr = send_res[1].decode("utf-8") send_rc = send_res[2] - receive_res = yield receive_d receive_stdout = receive_res[0].decode("utf-8") receive_stderr = receive_res[1].decode("utf-8") receive_rc = receive_res[2] NL = os.linesep + self.assertEqual((send_rc, receive_rc), (0, 0), + (send_res, receive_res)) else: sargs = runner.parser.parse_args(send_args) sargs.cwd = send_dir @@ -330,22 +332,11 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase): receive_d = cmd_receive.receive(rargs) # The sender might fail, leaving the receiver hanging, or vice - # versa. If either side fails, cancel the other, so it won't - # matter which Deferred we wait upon first. + # versa. Make sure we don't wait on one side exclusively - def _oops(f, which): - log.msg("test_scripts: %s failed, cancelling both" % which) - send_d.cancel() - receive_d.cancel() - return f - send_d.addErrback(_oops, "send_d") - receive_d.addErrback(_oops, "receive_d") - - send_rc = yield send_d + yield gatherResults([send_d, receive_d], True) send_stdout = sargs.stdout.getvalue() send_stderr = sargs.stderr.getvalue() - - receive_rc = yield receive_d receive_stdout = rargs.stdout.getvalue() receive_stderr = rargs.stderr.getvalue() @@ -355,8 +346,10 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase): self.maxDiff = None # show full output for assertion failures - self.failUnlessEqual(send_stderr, "") - self.failUnlessEqual(receive_stderr, "") + self.failUnlessEqual(send_stderr, "", + (send_stdout, send_stderr)) + self.failUnlessEqual(receive_stderr, "", + (receive_stdout, receive_stderr)) # check sender if mode == "text": @@ -417,9 +410,6 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase): with open(fn, "r") as f: self.failUnlessEqual(f.read(), message(i)) - self.failUnlessEqual(send_rc, 0) - self.failUnlessEqual(receive_rc, 0) - def test_text(self): return self._do_test() def test_text_subprocess(self): @@ -436,3 +426,62 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase): return self._do_test(mode="directory", addslash=True) def test_directory_override(self): return self._do_test(mode="directory", override_filename=True) + +class Cleanup(ServerBase, unittest.TestCase): + @inlineCallbacks + def test_text(self): + # the rendezvous channel should be deleted after success + code = u"1-abc" + common_args = ["--hide-progress", + "--relay-url", self.relayurl, + "--transit-helper", ""] + sargs = runner.parser.parse_args(common_args + + ["send", + "--text", "secret message", + "--code", code]) + sargs.stdout = io.StringIO() + sargs.stderr = io.StringIO() + sargs.timing = DebugTiming() + rargs = runner.parser.parse_args(common_args + + ["receive", code]) + rargs.stdout = io.StringIO() + rargs.stderr = io.StringIO() + rargs.timing = DebugTiming() + send_d = cmd_send.send(sargs) + receive_d = cmd_receive.receive(rargs) + + yield send_d + yield receive_d + + cids = self._rendezvous.get_app(cmd_send.APPID).get_allocated() + self.assertEqual(len(cids), 0) + + @inlineCallbacks + def test_text_wrong_password(self): + # if the password was wrong, the rendezvous channel should still be + # deleted + common_args = ["--hide-progress", + "--relay-url", self.relayurl, + "--transit-helper", ""] + sargs = runner.parser.parse_args(common_args + + ["send", + "--text", "secret message", + "--code", u"1-abc"]) + sargs.stdout = io.StringIO() + sargs.stderr = io.StringIO() + sargs.timing = DebugTiming() + rargs = runner.parser.parse_args(common_args + + ["receive", u"1-WRONG"]) + rargs.stdout = io.StringIO() + rargs.stderr = io.StringIO() + rargs.timing = DebugTiming() + send_d = cmd_send.send(sargs) + receive_d = cmd_receive.receive(rargs) + + # both sides should be capable of detecting the mismatch + yield self.assertFailure(send_d, WrongPasswordError) + yield self.assertFailure(receive_d, WrongPasswordError) + + cids = self._rendezvous.get_app(cmd_send.APPID).get_allocated() + self.assertEqual(len(cids), 0) + diff --git a/src/wormhole/twisted/transcribe.py b/src/wormhole/twisted/transcribe.py index c9f599c..8219834 100644 --- a/src/wormhole/twisted/transcribe.py +++ b/src/wormhole/twisted/transcribe.py @@ -27,30 +27,6 @@ def make_confmsg(confkey, nonce): def to_bytes(u): return unicodedata.normalize("NFC", u).encode("utf-8") -def close_on_error(meth): # method decorator - # Clients report certain errors as "moods", so the server can make a - # rough count failed connections (due to mismatched passwords, attacks, - # or timeouts). We don't report precondition failures, as those are the - # responsibility/fault of the local application code. We count - # non-precondition errors in case they represent server-side problems. - def _wrapper(self, *args, **kwargs): - d = defer.maybeDeferred(meth, self, *args, **kwargs) - def _onerror(f): - if f.check(Timeout): - d2 = self.close(u"lonely") - elif f.check(WrongPasswordError): - d2 = self.close(u"scary") - elif f.check(TypeError, UsageError): - # preconditions don't warrant _close_with_error() - d2 = defer.succeed(None) - else: - d2 = self.close(u"errory") - d2.addBoth(lambda _: f) - return d2 - d.addErrback(_onerror) - return d - return _wrapper - class WSClient(websocket.WebSocketClientProtocol): def onOpen(self): self.wormhole_open = True @@ -181,13 +157,16 @@ class Wormhole: return self._signal_error(welcome["error"]) @inlineCallbacks - def _sleep(self): - if self._error: # don't sleep if the bed's already on fire + def _sleep(self, wake_on_error=True): + if wake_on_error and self._error: + # don't sleep if the bed's already on fire, unless we're waiting + # for the fire department to respond, in which case sure, keep on + # sleeping raise self._error d = defer.Deferred() self._sleepers.append(d) yield d - if self._error: + if wake_on_error and self._error: raise self._error def _wakeup(self): @@ -366,7 +345,6 @@ class Wormhole: self._msg1 = unhexlify(d["msg1"].encode("ascii")) return self - @close_on_error @inlineCallbacks def get_verifier(self): if self._closed: raise UsageError @@ -458,7 +436,6 @@ class Wormhole: data = box.decrypt(encrypted) return data - @close_on_error @inlineCallbacks def send_data(self, outbound_data, phase=u"data", wait=False): if not isinstance(outbound_data, type(b"")): @@ -481,7 +458,6 @@ class Wormhole: yield self._msg_send(phase, outbound_encrypted, wait) self._timing.finish_event(_sent) - @close_on_error @inlineCallbacks def get_data(self, phase=u"data"): if not isinstance(phase, type(u"")): raise TypeError(type(phase)) @@ -506,26 +482,42 @@ class Wormhole: # TODO: schedule reconnect, unless we're done @inlineCallbacks - def close(self, res=None, mood=u"happy"): - if not isinstance(mood, (type(None), type(u""))): - raise TypeError(type(mood)) + def close(self, f=None, mood=None): + """Do d.addBoth(w.close) at the end of your chain.""" if self._closed: returnValue(None) self._closed = True if not self._ws: returnValue(None) + + if mood is None: + mood = u"happy" + if f: + if f.check(Timeout): + mood = u"lonely" + elif f.check(WrongPasswordError): + mood = u"scary" + elif f.check(TypeError, UsageError): + # preconditions don't warrant reporting mood + pass + else: + mood = u"errory" # other errors do + if not isinstance(mood, (type(None), type(u""))): + raise TypeError(type(mood)) + self._timing.finish_event(self._timing_started, mood=mood) yield self._deallocate(mood) # TODO: mark WebSocket as don't-reconnect self._ws.transport.loseConnection() # probably flushes del self._ws + returnValue(f) @inlineCallbacks - def _deallocate(self, mood=None): + def _deallocate(self, mood): _sent = self._timing.add_event("close") yield self._ws_send(u"deallocate", mood=mood) while self._deallocated_status is None: - yield self._sleep() + yield self._sleep(wake_on_error=False) self._timing.finish_event(_sent) # TODO: set a timeout, don't wait forever for an ack # TODO: if the connection is lost, let it go