diff --git a/src/wormhole_cli/cmd_receive.py b/src/wormhole_cli/cmd_receive.py index 1e0e1bc..b95290e 100644 --- a/src/wormhole_cli/cmd_receive.py +++ b/src/wormhole_cli/cmd_receive.py @@ -14,62 +14,55 @@ class RespondError(Exception): def __init__(self, response): self.response = response -def receive_twisted_sync(args): - # try to use twisted.internet.task.react(f) here (but it calls sys.exit - # directly) - d = defer.Deferred() - # don't call receive_twisted() until after the reactor is running, so - # that if it raises an exception synchronously, we won't stop the reactor - # before it starts - reactor.callLater(0, d.callback, None) - d.addCallback(lambda _: receive_twisted(args)) - rc = [] - def _done(res): - rc.extend([True, res]) - reactor.stop() - def _err(f): - rc.extend([False, f]) - reactor.stop() - d.addCallbacks(_done, _err) - reactor.run() - if rc[0]: - return rc[1] - print(str(rc[1])) - rc[1].raiseException() - -def receive_twisted(args): - return TwistedReceiver(args).go() +def receive_twisted(args, reactor=reactor): + return TwistedReceiver(args, reactor).go() class TwistedReceiver: - def __init__(self, args): + def __init__(self, args, reactor=reactor): assert isinstance(args.relay_url, type(u"")) self.args = args + self._reactor = reactor def msg(self, *args, **kwargs): print(*args, file=self.args.stdout, **kwargs) # TODO: @handle_server_error - @inlineCallbacks def go(self): + d = defer.succeed(None) tor_manager = None if self.args.tor: _start = self.args.timing.add_event("import TorManager") from txwormhole.tor_manager import TorManager self.args.timing.finish_event(_start) - tor_manager = TorManager(reactor, timing=self.args.timing) + tor_manager = TorManager(self._reactor, timing=self.args.timing) # For now, block everything until Tor has started. Soon: launch # tor in parallel with everything else, make sure the TorManager # can lazy-provide an endpoint, and overlap the startup process # with the user handing off the wormhole code - yield tor_manager.start() - - w = Wormhole(APPID, self.args.relay_url, tor_manager, - timing=self.args.timing) - - rc = yield self._go(w, tor_manager) - yield w.close() - returnValue(rc) + d.addCallback(lambda _: tor_manager.start()) + def _make_wormhole(_): + self._w = Wormhole(APPID, self.args.relay_url, tor_manager, + timing=self.args.timing, + reactor=self._reactor) + d.addCallback(_make_wormhole) + d.addCallback(lambda _: self._go(self._w, tor_manager)) + def _always_close(res): + d2 = self._w.close() + d2.addBoth(lambda _: res) + return d2 + d.addBoth(_always_close) + # I wanted to do this instead: + # + # try: + # yield self._go(w, tor_manager) + # finally: + # yield w.close() + # + # but when _go had a UsageError, the stacktrace was always displayed + # as coming from the "yield self._go" line, which wasn't very useful + # for tracking it down. + return d @inlineCallbacks def _go(self, w, tor_manager): @@ -100,7 +93,7 @@ class TwistedReceiver: except RespondError as r: data = json.dumps(r.response).encode("utf-8") yield w.send_data(data) - returnValue(1) + raise SystemExit(1) returnValue(0) @inlineCallbacks @@ -198,6 +191,7 @@ class TwistedReceiver: transit_receiver = TransitReceiver(self.args.transit_helper, no_listen=self.args.no_listen, tor_manager=tor_manager, + reactor=self._reactor, timing=self.args.timing) transit_receiver.set_transit_key(transit_key) direct_hints = yield transit_receiver.get_direct_hints() diff --git a/src/wormhole_cli/cmd_send.py b/src/wormhole_cli/cmd_send.py index dc93ecd..99791cb 100644 --- a/src/wormhole_cli/cmd_send.py +++ b/src/wormhole_cli/cmd_send.py @@ -1,7 +1,7 @@ from __future__ import print_function import os, sys, io, json, binascii, six, tempfile, zipfile from twisted.protocols import basic -from twisted.internet import reactor, defer +from twisted.internet import reactor from twisted.internet.defer import inlineCallbacks, returnValue from wormhole.errors import TransferError from .progress import ProgressPrinter @@ -95,32 +95,8 @@ def build_phase1_data(args): raise TypeError("'%s' is neither file nor directory" % args.what) - -def send_twisted_sync(args): - # try to use twisted.internet.task.react(f) here (but it calls sys.exit - # directly) - d = defer.Deferred() - # don't call send_twisted() until after the reactor is running, so - # that if it raises an exception synchronously, we won't stop the reactor - # before it starts - reactor.callLater(0, d.callback, None) - d.addCallback(lambda _: send_twisted(args)) - rc = [] - def _done(res): - rc.extend([True, res]) - reactor.stop() - def _err(f): - rc.extend([False, f]) - reactor.stop() - d.addCallbacks(_done, _err) - reactor.run() - if rc[0]: - return rc[1] - print(str(rc[1])) - rc[1].raiseException() - @inlineCallbacks -def send_twisted(args): +def send_twisted(args, reactor=reactor): assert isinstance(args.relay_url, type(u"")) handle_zero(args) phase1, fd_to_send = build_phase1_data(args) @@ -138,12 +114,14 @@ def send_twisted(args): # user handing off the wormhole code yield tor_manager.start() - w = Wormhole(APPID, args.relay_url, tor_manager, timing=args.timing) + w = Wormhole(APPID, args.relay_url, tor_manager, timing=args.timing, + reactor=reactor) if fd_to_send: transit_sender = TransitSender(args.transit_helper, no_listen=args.no_listen, tor_manager=tor_manager, + reactor=reactor, timing=args.timing) phase1["transit"] = transit_data = {} transit_data["relay_connection_hints"] = transit_sender.get_relay_hints() diff --git a/src/wormhole_cli/runner.py b/src/wormhole_cli/runner.py index e15b37e..5bebbde 100644 --- a/src/wormhole_cli/runner.py +++ b/src/wormhole_cli/runner.py @@ -1,26 +1,28 @@ from __future__ import print_function import os, sys -from wormhole.errors import TransferError +from twisted.internet.defer import maybeDeferred +from twisted.internet.task import react from wormhole.timing import DebugTiming from .cli_args import parser -def dispatch(args): +def dispatch(args): # returns Deferred if args.func == "send/send": from . import cmd_send - return cmd_send.send_twisted_sync(args) + return cmd_send.send_twisted(args) if args.func == "receive/receive": _start = args.timing.add_event("import c_r_t") from . import cmd_receive args.timing.finish_event(_start) - return cmd_receive.receive_twisted_sync(args) + return cmd_receive.receive_twisted(args) raise ValueError("unknown args.func %s" % args.func) -def run(args, cwd, stdout, stderr, executable=None): - """This is invoked directly by the 'wormhole' entry-point script. It can - also invoked by entry() below.""" +def run(reactor, argv, cwd, stdout, stderr, executable=None): + """This is invoked by entry() below, and can also be invoked directly by + tests. + """ - args = parser.parse_args() + args = parser.parse_args(argv) if not getattr(args, "func", None): # So far this only works on py3. py2 exits with a really terse # "error: too few arguments" during parse_args(). @@ -31,30 +33,28 @@ def run(args, cwd, stdout, stderr, executable=None): args.stderr = stderr args.timing = timing = DebugTiming() - try: - timing.add_event("command dispatch") - rc = dispatch(args) + timing.add_event("command dispatch") + d = maybeDeferred(dispatch, args) + def _maybe_dump_timing(res): timing.add_event("exit") if args.dump_timing: timing.write(args.dump_timing, stderr) - return rc - except TransferError as e: - print(e, file=stderr) - if args.dump_timing: - timing.write(args.dump_timing, stderr) - return 1 - except ImportError as e: - print("--- ImportError ---", file=stderr) - print(e, file=stderr) - print("Please run 'python setup.py build'", file=stderr) - raise - return 1 + return res + d.addBoth(_maybe_dump_timing) + def _explain_error(f): + print(f.value, file=stderr) + raise SystemExit(1) + d.addErrback(_explain_error) + def _rc(rc): + raise SystemExit(rc) + d.addCallback(_rc) + return d def entry(): """This is used by a setuptools entry_point. When invoked this way, setuptools has already put the installed package on sys.path .""" - return run(sys.argv[1:], os.getcwd(), sys.stdout, sys.stderr, - executable=sys.argv[0]) + react(run, (sys.argv[1:], os.getcwd(), sys.stdout, sys.stderr, + sys.argv[0])) if __name__ == "__main__": args = parser.parse_args()