CLI runner: use task.react, remove sync wrapper

This commit is contained in:
Brian Warner 2016-04-20 00:02:05 -07:00
parent 44ca031047
commit 8068aeeca4
3 changed files with 61 additions and 89 deletions

View File

@ -14,62 +14,55 @@ class RespondError(Exception):
def __init__(self, response): def __init__(self, response):
self.response = response self.response = response
def receive_twisted_sync(args): def receive_twisted(args, reactor=reactor):
# try to use twisted.internet.task.react(f) here (but it calls sys.exit return TwistedReceiver(args, reactor).go()
# directly)
d = defer.Deferred()
# don't call receive_twisted() until after the reactor is running, so
# that if it raises an exception synchronously, we won't stop the reactor
# before it starts
reactor.callLater(0, d.callback, None)
d.addCallback(lambda _: receive_twisted(args))
rc = []
def _done(res):
rc.extend([True, res])
reactor.stop()
def _err(f):
rc.extend([False, f])
reactor.stop()
d.addCallbacks(_done, _err)
reactor.run()
if rc[0]:
return rc[1]
print(str(rc[1]))
rc[1].raiseException()
def receive_twisted(args):
return TwistedReceiver(args).go()
class TwistedReceiver: class TwistedReceiver:
def __init__(self, args): def __init__(self, args, reactor=reactor):
assert isinstance(args.relay_url, type(u"")) assert isinstance(args.relay_url, type(u""))
self.args = args self.args = args
self._reactor = reactor
def msg(self, *args, **kwargs): def msg(self, *args, **kwargs):
print(*args, file=self.args.stdout, **kwargs) print(*args, file=self.args.stdout, **kwargs)
# TODO: @handle_server_error # TODO: @handle_server_error
@inlineCallbacks
def go(self): def go(self):
d = defer.succeed(None)
tor_manager = None tor_manager = None
if self.args.tor: if self.args.tor:
_start = self.args.timing.add_event("import TorManager") _start = self.args.timing.add_event("import TorManager")
from txwormhole.tor_manager import TorManager from txwormhole.tor_manager import TorManager
self.args.timing.finish_event(_start) self.args.timing.finish_event(_start)
tor_manager = TorManager(reactor, timing=self.args.timing) tor_manager = TorManager(self._reactor, timing=self.args.timing)
# For now, block everything until Tor has started. Soon: launch # For now, block everything until Tor has started. Soon: launch
# tor in parallel with everything else, make sure the TorManager # tor in parallel with everything else, make sure the TorManager
# can lazy-provide an endpoint, and overlap the startup process # can lazy-provide an endpoint, and overlap the startup process
# with the user handing off the wormhole code # with the user handing off the wormhole code
yield tor_manager.start() d.addCallback(lambda _: tor_manager.start())
def _make_wormhole(_):
w = Wormhole(APPID, self.args.relay_url, tor_manager, self._w = Wormhole(APPID, self.args.relay_url, tor_manager,
timing=self.args.timing) timing=self.args.timing,
reactor=self._reactor)
rc = yield self._go(w, tor_manager) d.addCallback(_make_wormhole)
yield w.close() d.addCallback(lambda _: self._go(self._w, tor_manager))
returnValue(rc) def _always_close(res):
d2 = self._w.close()
d2.addBoth(lambda _: res)
return d2
d.addBoth(_always_close)
# I wanted to do this instead:
#
# try:
# yield self._go(w, tor_manager)
# finally:
# yield w.close()
#
# but when _go had a UsageError, the stacktrace was always displayed
# as coming from the "yield self._go" line, which wasn't very useful
# for tracking it down.
return d
@inlineCallbacks @inlineCallbacks
def _go(self, w, tor_manager): def _go(self, w, tor_manager):
@ -100,7 +93,7 @@ class TwistedReceiver:
except RespondError as r: except RespondError as r:
data = json.dumps(r.response).encode("utf-8") data = json.dumps(r.response).encode("utf-8")
yield w.send_data(data) yield w.send_data(data)
returnValue(1) raise SystemExit(1)
returnValue(0) returnValue(0)
@inlineCallbacks @inlineCallbacks
@ -198,6 +191,7 @@ class TwistedReceiver:
transit_receiver = TransitReceiver(self.args.transit_helper, transit_receiver = TransitReceiver(self.args.transit_helper,
no_listen=self.args.no_listen, no_listen=self.args.no_listen,
tor_manager=tor_manager, tor_manager=tor_manager,
reactor=self._reactor,
timing=self.args.timing) timing=self.args.timing)
transit_receiver.set_transit_key(transit_key) transit_receiver.set_transit_key(transit_key)
direct_hints = yield transit_receiver.get_direct_hints() direct_hints = yield transit_receiver.get_direct_hints()

View File

@ -1,7 +1,7 @@
from __future__ import print_function from __future__ import print_function
import os, sys, io, json, binascii, six, tempfile, zipfile import os, sys, io, json, binascii, six, tempfile, zipfile
from twisted.protocols import basic from twisted.protocols import basic
from twisted.internet import reactor, defer from twisted.internet import reactor
from twisted.internet.defer import inlineCallbacks, returnValue from twisted.internet.defer import inlineCallbacks, returnValue
from wormhole.errors import TransferError from wormhole.errors import TransferError
from .progress import ProgressPrinter from .progress import ProgressPrinter
@ -95,32 +95,8 @@ def build_phase1_data(args):
raise TypeError("'%s' is neither file nor directory" % args.what) raise TypeError("'%s' is neither file nor directory" % args.what)
def send_twisted_sync(args):
# try to use twisted.internet.task.react(f) here (but it calls sys.exit
# directly)
d = defer.Deferred()
# don't call send_twisted() until after the reactor is running, so
# that if it raises an exception synchronously, we won't stop the reactor
# before it starts
reactor.callLater(0, d.callback, None)
d.addCallback(lambda _: send_twisted(args))
rc = []
def _done(res):
rc.extend([True, res])
reactor.stop()
def _err(f):
rc.extend([False, f])
reactor.stop()
d.addCallbacks(_done, _err)
reactor.run()
if rc[0]:
return rc[1]
print(str(rc[1]))
rc[1].raiseException()
@inlineCallbacks @inlineCallbacks
def send_twisted(args): def send_twisted(args, reactor=reactor):
assert isinstance(args.relay_url, type(u"")) assert isinstance(args.relay_url, type(u""))
handle_zero(args) handle_zero(args)
phase1, fd_to_send = build_phase1_data(args) phase1, fd_to_send = build_phase1_data(args)
@ -138,12 +114,14 @@ def send_twisted(args):
# user handing off the wormhole code # user handing off the wormhole code
yield tor_manager.start() yield tor_manager.start()
w = Wormhole(APPID, args.relay_url, tor_manager, timing=args.timing) w = Wormhole(APPID, args.relay_url, tor_manager, timing=args.timing,
reactor=reactor)
if fd_to_send: if fd_to_send:
transit_sender = TransitSender(args.transit_helper, transit_sender = TransitSender(args.transit_helper,
no_listen=args.no_listen, no_listen=args.no_listen,
tor_manager=tor_manager, tor_manager=tor_manager,
reactor=reactor,
timing=args.timing) timing=args.timing)
phase1["transit"] = transit_data = {} phase1["transit"] = transit_data = {}
transit_data["relay_connection_hints"] = transit_sender.get_relay_hints() transit_data["relay_connection_hints"] = transit_sender.get_relay_hints()

View File

@ -1,26 +1,28 @@
from __future__ import print_function from __future__ import print_function
import os, sys import os, sys
from wormhole.errors import TransferError from twisted.internet.defer import maybeDeferred
from twisted.internet.task import react
from wormhole.timing import DebugTiming from wormhole.timing import DebugTiming
from .cli_args import parser from .cli_args import parser
def dispatch(args): def dispatch(args): # returns Deferred
if args.func == "send/send": if args.func == "send/send":
from . import cmd_send from . import cmd_send
return cmd_send.send_twisted_sync(args) return cmd_send.send_twisted(args)
if args.func == "receive/receive": if args.func == "receive/receive":
_start = args.timing.add_event("import c_r_t") _start = args.timing.add_event("import c_r_t")
from . import cmd_receive from . import cmd_receive
args.timing.finish_event(_start) args.timing.finish_event(_start)
return cmd_receive.receive_twisted_sync(args) return cmd_receive.receive_twisted(args)
raise ValueError("unknown args.func %s" % args.func) raise ValueError("unknown args.func %s" % args.func)
def run(args, cwd, stdout, stderr, executable=None): def run(reactor, argv, cwd, stdout, stderr, executable=None):
"""This is invoked directly by the 'wormhole' entry-point script. It can """This is invoked by entry() below, and can also be invoked directly by
also invoked by entry() below.""" tests.
"""
args = parser.parse_args() args = parser.parse_args(argv)
if not getattr(args, "func", None): if not getattr(args, "func", None):
# So far this only works on py3. py2 exits with a really terse # So far this only works on py3. py2 exits with a really terse
# "error: too few arguments" during parse_args(). # "error: too few arguments" during parse_args().
@ -31,30 +33,28 @@ def run(args, cwd, stdout, stderr, executable=None):
args.stderr = stderr args.stderr = stderr
args.timing = timing = DebugTiming() args.timing = timing = DebugTiming()
try: timing.add_event("command dispatch")
timing.add_event("command dispatch") d = maybeDeferred(dispatch, args)
rc = dispatch(args) def _maybe_dump_timing(res):
timing.add_event("exit") timing.add_event("exit")
if args.dump_timing: if args.dump_timing:
timing.write(args.dump_timing, stderr) timing.write(args.dump_timing, stderr)
return rc return res
except TransferError as e: d.addBoth(_maybe_dump_timing)
print(e, file=stderr) def _explain_error(f):
if args.dump_timing: print(f.value, file=stderr)
timing.write(args.dump_timing, stderr) raise SystemExit(1)
return 1 d.addErrback(_explain_error)
except ImportError as e: def _rc(rc):
print("--- ImportError ---", file=stderr) raise SystemExit(rc)
print(e, file=stderr) d.addCallback(_rc)
print("Please run 'python setup.py build'", file=stderr) return d
raise
return 1
def entry(): def entry():
"""This is used by a setuptools entry_point. When invoked this way, """This is used by a setuptools entry_point. When invoked this way,
setuptools has already put the installed package on sys.path .""" setuptools has already put the installed package on sys.path ."""
return run(sys.argv[1:], os.getcwd(), sys.stdout, sys.stderr, react(run, (sys.argv[1:], os.getcwd(), sys.stdout, sys.stderr,
executable=sys.argv[0]) sys.argv[0]))
if __name__ == "__main__": if __name__ == "__main__":
args = parser.parse_args() args = parser.parse_args()