From e16b53817eee12ed524fdcf9263deedce72419a6 Mon Sep 17 00:00:00 2001 From: meejah Date: Fri, 3 Jun 2016 15:17:47 -0700 Subject: [PATCH] Refactor to use Click --- setup.py | 12 +- src/wormhole/cli/cli.py | 231 +++++++++++++++++++ src/wormhole/cli/cli_args.py | 179 --------------- src/wormhole/cli/cmd_receive.py | 2 +- src/wormhole/cli/cmd_send.py | 9 +- src/wormhole/cli/runner.py | 81 ------- src/wormhole/server/cli.py | 162 ++++++++++++++ src/wormhole/server/cli_args.py | 77 ------- src/wormhole/server/cmd_server.py | 6 +- src/wormhole/test/common.py | 12 +- src/wormhole/test/test_scripts.py | 353 ++++++++++++++---------------- src/wormhole/transit.py | 4 + src/wormhole/wormhole.py | 6 +- 13 files changed, 594 insertions(+), 540 deletions(-) create mode 100644 src/wormhole/cli/cli.py delete mode 100644 src/wormhole/cli/cli_args.py delete mode 100644 src/wormhole/cli/runner.py create mode 100644 src/wormhole/server/cli.py delete mode 100644 src/wormhole/server/cli_args.py diff --git a/setup.py b/setup.py index 5d57848..7554165 100644 --- a/setup.py +++ b/setup.py @@ -19,15 +19,19 @@ setup(name="magic-wormhole", "wormhole.test", ], package_data={"wormhole.server": ["db-schemas/*.sql"]}, - entry_points={"console_scripts": - ["wormhole = wormhole.cli.runner:entry", - "wormhole-server = wormhole.server.runner:entry", - ]}, + entry_points={ + "console_scripts": + [ + "wormhole = wormhole.cli.cli:wormhole", + "wormhole-server = wormhole.server.cli:server", + ] + }, install_requires=["spake2==0.7", "pynacl", "argparse", "six", "twisted", "autobahn[twisted] >= 0.14.1", "hkdf", "tqdm", + "click", ], extras_require={':sys_platform=="win32"': ["pypiwin32"], "tor": ["txtorcon", "ipaddress"]}, diff --git a/src/wormhole/cli/cli.py b/src/wormhole/cli/cli.py new file mode 100644 index 0000000..4085889 --- /dev/null +++ b/src/wormhole/cli/cli.py @@ -0,0 +1,231 @@ +from __future__ import print_function + +import os +import time +start = time.time() +import traceback +from textwrap import fill, dedent +from sys import stdout, stderr +from . import public_relay +from .. import __version__ +from ..timing import DebugTiming +from ..errors import WrongPasswordError, WelcomeError, KeyFormatError +from twisted.internet.defer import inlineCallbacks, maybeDeferred +from twisted.internet.task import react + +import click +top_import_finish = time.time() + + +class Config(object): + """ + Union of config options that we pass down to (sub) commands. + """ + def __init__(self): + # common options + self.timing = DebugTiming() + self.tor = None + self.listen = None + self.relay_url = u"" + self.transit_helper = u"" + self.cwd = os.getcwd() + # send/receive commands + self.code = None + self.code_length = 2 + self.verify = False + self.hide_progress = False + self.dump_timing = False + self.stdout = stdout + self.stderr = stderr + self.zeromode = False + self.accept_file = None + self.output_file = None + # send only + self.text = None + self.what = None + + +ALIASES = { + "tx": "send", + "rx": "receive", +} +class AliasedGroup(click.Group): + def get_command(self, ctx, cmd_name): + cmd_name = ALIASES.get(cmd_name, cmd_name) + return click.Group.get_command(self, ctx, cmd_name) + + +# top-level command ("wormhole ...") +@click.group(cls=AliasedGroup) +@click.option( + "--relay-url", default=public_relay.RENDEZVOUS_RELAY, + metavar="URL", + help="rendezvous relay to use", +) +@click.option( + "--transit-helper", default=public_relay.TRANSIT_RELAY, + metavar="tcp:HOST:PORT", + help="transit relay to use", +) +@click.option( + "-c", "--code-length", default=2, + metavar="NUMWORDS", + help="length of code (in bytes/words)", +) +@click.option( + "-v", "--verify", is_flag=True, default=False, + help="display (and wait for acceptance of) verification string", +) +@click.option( + "--hide-progress", is_flag=True, default=False, + help="supress progress-bar display", +) +@click.option( + "--dump-timing", type=type(u""), # TODO: hide from --help output + default=None, + metavar="FILE.json", + help="(debug) write timing data to file", +) +@click.option( + "--no-listen", is_flag=True, default=False, + help="(debug) don't open a listening socket for Transit", +) +@click.option( + "--tor", is_flag=True, default=True, + help="use Tor when connecting", +) +@click.version_option( + message="magic-wormhole %(version)s", + version=__version__, +) +@click.pass_context +def wormhole(ctx, tor, no_listen, dump_timing, hide_progress, + verify, code_length, transit_helper, relay_url): + """ + Create a Magic Wormhole and communicate through it. + + Wormholes are created by speaking the same magic CODE in two + different places at the same time. Wormholes are secure against + anyone who doesn't use the same code. + """ + ctx.obj = cfg = Config() + ctx.tor = tor + if no_listen: + cfg.listen = False + cfg.relay_url = relay_url + cfg.transit_helper = transit_helper + cfg.code_length = code_length + cfg.verify = verify + cfg.hide_progress = hide_progress + cfg.dump_timing = dump_timing + + +@inlineCallbacks +def _dispatch_command(reactor, cfg, command): + """ + Internal helper. This calls the give command (a no-argument + callable) with the Config instance in cfg and interprets any + errors for the user. + """ + cfg.timing.add("command dispatch") + cfg.timing.add("import", when=start, which="top").finish(when=top_import_finish) + + try: + yield maybeDeferred(command) + except WrongPasswordError as e: + msg = fill("ERROR: " + dedent(e.__doc__)) + print(msg, file=stderr) + except WelcomeError as e: + msg = fill("ERROR: " + dedent(e.__doc__)) + print(msg, file=stderr) + print(file=stderr) + print(str(e), file=stderr) + except KeyFormatError as e: + msg = fill("ERROR: " + dedent(e.__doc__)) + print(msg, file=stderr) + except Exception as e: + traceback.print_exc() + print("ERROR:", e, file=stderr) + raise SystemExit(1) + + cfg.timing.add("exit") + if cfg.dump_timing: + cfg.timing.write(cfg.dump_timing, stderr) + + +# wormhole send (or "wormhole tx") +@wormhole.command() +@click.option( + "zeromode", "-0", default=False, is_flag=True, +) +@click.option( + "--code", metavar="CODE", + help="human-generated code phrase", +) +@click.option( + "--text", default=None, metavar="MESSAGE", + help="text message to send, instead of a file. Use '-' to read from stdin.", +) +@click.argument("what", default=u'') +@click.pass_obj +def send(cfg, what, text, code, zeromode): + """Send a text message, file, or directory""" + with cfg.timing.add("import", which="cmd_send"): + from . import cmd_send + cfg.what = what + cfg.text = text + cfg.zeromode = zeromode + cfg.code = code + + react(_dispatch_command, (cfg, lambda: cmd_send.send(cfg))) + + +# wormhole receive (or "wormhole rx") +@wormhole.command() +@click.option( + "--only-text", "-t", is_flag=True, + help="refuse file transfers, only accept text transfers", +) +@click.option( + "--accept-file", is_flag=True, + help="accept file transfer without asking for confirmation", +) +@click.option( + "--output-file", "-o", + metavar="FILENAME|DIRNAME", + help=("The file or directory to create, overriding the name suggested" + " by the sender."), +) +@click.option( + "-0", "zeromode", is_flag=True, + help="enable no-code anything-goes mode", +) +@click.argument( + "code", nargs=-1, default=None, +# help=("The magic-wormhole code, from the sender. If omitted, the" +# " program will ask for it, using tab-completion."), +) +@click.pass_obj +def receive(cfg, code, zeromode, output_file, accept_file, only_text): + """ + Receive a text message, file, or directory (from 'wormhole send') + """ + with cfg.timing.add("import", which="cmd_receive"): + from . import cmd_receive + cfg.zeromode = zeromode + cfg.output_file = output_file + cfg.accept_file = accept_file + cfg.only_text = only_text + if len(code) == 1: + cfg.code = code[0] + elif len(code) > 1: + print( + "Pass either no code or just one code; you passed" + " {}: {}".format(len(code), ', '.join(code)) + ) + raise SystemExit(1) + else: + cfg.code = None + + react(_dispatch_command, (cfg, lambda: cmd_receive.receive(cfg))) + return diff --git a/src/wormhole/cli/cli_args.py b/src/wormhole/cli/cli_args.py deleted file mode 100644 index bf51ea8..0000000 --- a/src/wormhole/cli/cli_args.py +++ /dev/null @@ -1,179 +0,0 @@ -import click -from textwrap import dedent -from . import public_relay -from .. import __version__ - -class Common: - def __init__(self, stuff): - self.stuff = stuff - -ALIASES = { - "tx": "send", - "rx": "receive", - } -class AliasedGroup(click.Group): - def get_command(self, ctx, cmd_name): - cmd_name = ALIASES.get(cmd_name, cmd_name) - return click.Group.get_command(self, ctx, cmd_name) - -@click.group() -#@click.command(cls=AliasedGroup) -@click.option("--relay-url", default=public_relay.RENDEZVOUS_RELAY, - metavar="URL", - help="rendezvous relay to use", - ) -@click.option("--transit-helper", default=public_relay.TRANSIT_RELAY, - metavar="tcp:HOST:PORT", - help="transit relay to use", - ) -@click.option("-c", "--code-length", default=2, - metavar="NUMWORDS", - help="length of code (in bytes/words)", - ) -@click.option("-v", "--verify", is_flag=True, default=False, - help="display (and wait for acceptance of) verification string", - ) -@click.option("--hide-progress", is_flag=True, default=False, - help="supress progress-bar display", - ) -@click.option("--dump-timing", type=type(u""), # TODO: hide from --help output - default=None, - metavar="FILE.json", - help="(debug) write timing data to file") -@click.option("--no-listen", is_flag=True, default=False, - help="(debug) don't open a listening socket for Transit") -@click.option("--tor", is_flag=True, default=True, - help="use Tor when connecting") -@click.version_option(message="magic-wormhole %(version)s", version=__version__) -@click.pass_context -def cli(ctx, relay_url, transit_helper): - """ - Create a Magic Wormhole and communicate through it. Wormholes are created - by speaking the same magic CODE in two different places at the same time. - Wormholes are secure against anyone who doesn't use the same code.""" - ctx.obj = Common(relay_url) - -@cli.command() -@click.argument("what") -@click.pass_obj -def send(obj, what): - """Send a text message, file, or directory""" - print obj - print what - -@cli.command() -@click.argument("what") -@click.pass_obj -def receive(obj, what): - """Receive anything sent by 'wormhole send'.""" - print obj - print what - -# for now, use "python -m wormhole.cli.cli_args --version", etc -if __name__ == "__main__": - cli() - - -import argparse -parser = argparse.ArgumentParser( - usage="wormhole SUBCOMMAND (subcommand-options)", - description=dedent(""" - Create a Magic Wormhole and communicate through it. Wormholes are created - by speaking the same magic CODE in two different places at the same time. - Wormholes are secure against anyone who doesn't use the same code."""), - ) - -parser.add_argument("--version", action="version", - version="magic-wormhole "+ __version__) -g = parser.add_argument_group("wormhole configuration options") -parser.set_defaults(timing=None) -subparsers = parser.add_subparsers(title="subcommands", - dest="subcommand") - - -# CLI: run-server -s = subparsers.add_parser("server", description="Start/stop a relay server") -sp = s.add_subparsers(title="subcommands", dest="subcommand") -sp_start = sp.add_parser("start", description="Start a relay server", - usage="wormhole server start [opts] [TWISTD-ARGS..]") -sp_start.add_argument("--rendezvous", default="tcp:3000", metavar="tcp:PORT", - help="endpoint specification for the rendezvous port") -sp_start.add_argument("--transit", default="tcp:3001", metavar="tcp:PORT", - help="endpoint specification for the transit-relay port") -sp_start.add_argument("--advertise-version", metavar="VERSION", - help="version to recommend to clients") -sp_start.add_argument("--blur-usage", default=None, type=int, - metavar="SECONDS", - help="round logged access times to improve privacy") -sp_start.add_argument("-n", "--no-daemon", action="store_true") -#sp_start.add_argument("twistd_args", nargs="*", default=None, -# metavar="[TWISTD-ARGS..]", -# help=dedent("""\ -# Additional arguments to pass to twistd"""), -# ) -sp_start.set_defaults(func="server/start") - -sp_stop = sp.add_parser("stop", description="Stop the relay server", - usage="wormhole server stop") -sp_stop.set_defaults(func="server/stop") - -sp_restart = sp.add_parser("restart", description="Restart the relay server", - usage="wormhole server restart") -sp_restart.add_argument("--rendezvous", default="tcp:3000", metavar="tcp:PORT", - help="endpoint specification for the rendezvous port") -sp_restart.add_argument("--transit", default="tcp:3001", metavar="tcp:PORT", - help="endpoint specification for the transit-relay port") -sp_restart.add_argument("--advertise-version", metavar="VERSION", - help="version to recommend to clients") -sp_restart.add_argument("--blur-usage", default=None, type=int, - metavar="SECONDS", - help="round logged access times to improve privacy") -sp_restart.add_argument("-n", "--no-daemon", action="store_true") -sp_restart.set_defaults(func="server/restart") - -sp_show_usage = sp.add_parser("show-usage", description="Display usage data", - usage="wormhole server show-usage") -sp_show_usage.add_argument("-n", default=100, type=int, - help="show last N entries") -sp_show_usage.set_defaults(func="usage/usage") - -sp_tail_usage = sp.add_parser("tail-usage", description="Follow latest usage", - usage="wormhole server tail-usage") -sp_tail_usage.set_defaults(func="usage/tail") - -# CLI: send -p = subparsers.add_parser("send", - description="Send text message, file, or directory", - usage="wormhole send [FILENAME|DIRNAME]") -p.add_argument("--text", metavar="MESSAGE", - help="text message to send, instead of a file. Use '-' to read from stdin.") -p.add_argument("--code", metavar="CODE", help="human-generated code phrase", - type=type(u"")) -p.add_argument("-0", dest="zeromode", action="store_true", - help="enable no-code anything-goes mode") -p.add_argument("what", nargs="?", default=None, metavar="[FILENAME|DIRNAME]", - help="the file/directory to send") -p.set_defaults(func="send/send") - -# CLI: receive -p = subparsers.add_parser("receive", - description="Receive a text message, file, or directory", - usage="wormhole receive [CODE]") -p.add_argument("-0", dest="zeromode", action="store_true", - help="enable no-code anything-goes mode") -p.add_argument("-t", "--only-text", dest="only_text", action="store_true", - help="refuse file transfers, only accept text transfers") -p.add_argument("--accept-file", dest="accept_file", action="store_true", - help="accept file transfer with asking for confirmation") -p.add_argument("-o", "--output-file", default=None, metavar="FILENAME|DIRNAME", - help=dedent("""\ - The file or directory to create, overriding the name suggested - by the sender."""), - ) -p.add_argument("code", nargs="?", default=None, metavar="[CODE]", - help=dedent("""\ - The magic-wormhole code, from the sender. If omitted, the - program will ask for it, using tab-completion."""), - type=type(u""), - ) -p.set_defaults(func="receive/receive") diff --git a/src/wormhole/cli/cmd_receive.py b/src/wormhole/cli/cmd_receive.py index 79a6221..89caba2 100644 --- a/src/wormhole/cli/cmd_receive.py +++ b/src/wormhole/cli/cmd_receive.py @@ -141,7 +141,7 @@ class TwistedReceiver: @inlineCallbacks def _build_transit(self, w, sender_transit): tr = TransitReceiver(self.args.transit_helper, - no_listen=self.args.no_listen, + no_listen=(not self.args.listen), tor_manager=self._tor_manager, reactor=self._reactor, timing=self.args.timing) diff --git a/src/wormhole/cli/cmd_send.py b/src/wormhole/cli/cmd_send.py index 1b7f3c2..cf082a9 100644 --- a/src/wormhole/cli/cmd_send.py +++ b/src/wormhole/cli/cmd_send.py @@ -48,9 +48,10 @@ class Sender: w = wormhole(APPID, self._args.relay_url, self._reactor, self._tor_manager, timing=self._timing) - d = self._go(w) - d.addBoth(w.close) - yield d + try: + yield self._go(w) + finally: + w.close() def _send_data(self, data, w): data_bytes = dict_to_bytes(data) @@ -100,7 +101,7 @@ class Sender: if self._fd_to_send: ts = TransitSender(args.transit_helper, - no_listen=args.no_listen, + no_listen=(not args.listen), tor_manager=self._tor_manager, reactor=self._reactor, timing=self._timing) diff --git a/src/wormhole/cli/runner.py b/src/wormhole/cli/runner.py deleted file mode 100644 index 8c373b1..0000000 --- a/src/wormhole/cli/runner.py +++ /dev/null @@ -1,81 +0,0 @@ -from __future__ import print_function -import time -start = time.time() -import os, sys, textwrap -from twisted.internet.defer import maybeDeferred -from twisted.internet.task import react -from ..errors import (TransferError, WrongPasswordError, WelcomeError, Timeout, - KeyFormatError) -from ..timing import DebugTiming -from .cli_args import parser -top_import_finish = time.time() - -def dispatch(args): # returns Deferred - if args.func == "send/send": - with args.timing.add("import", which="cmd_send"): - from . import cmd_send - return cmd_send.send(args) - if args.func == "receive/receive": - with args.timing.add("import", which="cmd_receive"): - from . import cmd_receive - return cmd_receive.receive(args) - - raise ValueError("unknown args.func %s" % args.func) - -def run(reactor, argv, cwd, stdout, stderr, executable=None): - """This is invoked by entry() below, and can also be invoked directly by - tests. - """ - - args = parser.parse_args(argv) - if not getattr(args, "func", None): - # So far this only works on py3. py2 exits with a really terse - # "error: too few arguments" during parse_args(). - parser.print_help() - sys.exit(0) - args.cwd = cwd - args.stdout = stdout - args.stderr = stderr - args.timing = timing = DebugTiming() - - timing.add("command dispatch") - timing.add("import", when=start, which="top").finish(when=top_import_finish) - # fires with None, or raises an error - d = maybeDeferred(dispatch, args) - def _maybe_dump_timing(res): - timing.add("exit") - if args.dump_timing: - timing.write(args.dump_timing, stderr) - return res - d.addBoth(_maybe_dump_timing) - def _explain_error(f): - # these errors don't print a traceback, just an explanation - f.trap(TransferError, WrongPasswordError, WelcomeError, Timeout, - KeyFormatError) - if f.check(WrongPasswordError): - msg = textwrap.fill("ERROR: " + textwrap.dedent(f.value.__doc__)) - print(msg, file=stderr) - elif f.check(WelcomeError): - msg = textwrap.fill("ERROR: " + textwrap.dedent(f.value.__doc__)) - print(msg, file=stderr) - print(file=stderr) - print(str(f.value), file=stderr) - elif f.check(KeyFormatError): - msg = textwrap.fill("ERROR: " + textwrap.dedent(f.value.__doc__)) - print(msg, file=stderr) - else: - print("ERROR:", f.value, file=stderr) - raise SystemExit(1) - d.addErrback(_explain_error) - d.addCallback(lambda _: 0) - return d - -def entry(): - """This is used by a setuptools entry_point. When invoked this way, - setuptools has already put the installed package on sys.path .""" - react(run, (sys.argv[1:], os.getcwd(), sys.stdout, sys.stderr, - sys.argv[0])) - -if __name__ == "__main__": - args = parser.parse_args() - print(args) diff --git a/src/wormhole/server/cli.py b/src/wormhole/server/cli.py new file mode 100644 index 0000000..cb8538a --- /dev/null +++ b/src/wormhole/server/cli.py @@ -0,0 +1,162 @@ +from __future__ import print_function + +import click + + +# can put this back in to get this command as "wormhole server" +# instead +#from ..cli.cli import wormhole +#@wormhole.group() +@click.group() +@click.pass_context +def server(ctx): + """ + Control a relay server (most users shouldn't need to worry + about this and can use the default server). + """ + # just leaving this pointing to wormhole.cli.cli.Config for now, + # but if we want to keep wormhole-server as a separate command + # should probably have our own Config without all the options the + # server commands don't use + from ..cli.cli import Config + ctx.obj = Config() + + +@server.command() +@click.option( + "--rendezvous", default="tcp:3000", metavar="tcp:PORT", + help="endpoint specification for the rendezvous port", +) +@click.option( + "--transit", default="tcp:3001", metavar="tcp:PORT", + help="endpoint specification for the transit-relay port", +) +@click.option( + "--advertise-version", metavar="VERSION", + help="version to recommend to clients", +) +@click.option( + "--blur-usage", default=None, type=int, + metavar="SECONDS", + help="round logged access times to improve privacy", +) +@click.option( + "--no-daemon", "-n", is_flag=True, + help="Run in the foreground", +) +@click.option( + "--signal-error", is_flag=True, + help="force all clients to fail with a message", +) +@click.pass_obj +def start(cfg, signal_error, no_daemon, blur_usage, advertise_version, transit, rendezvous): + """ + Start a relay server + """ + with cfg.timing.add("import", which="cmd_server_start"): + from wormhole.server.cmd_server import start_server + cfg.no_daemon = no_daemon + cfg.blur_usage = blur_usage + cfg.advertise_version = advertise_version + cfg.transit = str(transit) + cfg.rendezvous = str(rendezvous) + cfg.signal_error = signal_error + + start_server(cfg) + + +# XXX it would be nice to reduce the duplication between 'restart' and +# 'start' options... +@server.command() +@click.option( + "--rendezvous", default="tcp:3000", metavar="tcp:PORT", + help="endpoint specification for the rendezvous port", +) +@click.option( + "--transit", default="tcp:3001", metavar="tcp:PORT", + help="endpoint specification for the transit-relay port", +) +@click.option( + "--advertise-version", metavar="VERSION", + help="version to recommend to clients", +) +@click.option( + "--blur-usage", default=None, type=int, + metavar="SECONDS", + help="round logged access times to improve privacy", +) +@click.option( + "--no-daemon", "-n", is_flag=True, + help="Run in the foreground", +) +@click.option( + "--signal-error", is_flag=True, + help="force all clients to fail with a message", +) +@click.pass_obj +def restart(cfg, signal_error, no_daemon, blur_usage, advertise_version, transit, rendezvous): + """ + Re-start a relay server + """ + with cfg.timing.add("import", which="cmd_server_restart"): + from wormhole.server.cmd_server import restart_server + cfg.no_daemon = no_daemon + cfg.blur_usage = blur_usage + cfg.advertise_version = advertise_version + cfg.transit = str(transit) + cfg.rendezvous = str(rendezvous) + cfg.signal_error = signal_error + + restart_server(cfg) + + +@server.command() +@click.pass_obj +def stop(cfg): + """ + Stop a relay server + """ + with cfg.timing.add("import", which="cmd_server_stop"): + from wormhole.server.cmd_server import stop_server + stop_server(cfg) + + +@server.command(name="tail-usage") +@click.pass_obj +def tail_usage(cfg): + """ + Follow the latest usage + """ + with cfg.timing.add("import", which="cmd_tail_usage"): + from wormhole.server.cmd_usage import tail_usage + tail_usage(cfg) + + +@server.command(name='count-channels') +@click.option( + "--json", is_flag=True, +) +@click.pass_obj +def count_channels(cfg, json): + """ + Count active channels + """ + with cfg.timing.add("import", which="cmd_count_channels"): + from wormhole.server.cmd_usage import count_channels + cfg.json = json + count_channels(cfg) + + +@server.command(name='count-events') +@click.option( + "--json", is_flag=True, +) +@click.pass_obj +def count_events(cfg, json): + """ + Count events + """ + with cfg.timing.add("import", which="cmd_count_events"): + from wormhole.server.cmd_usage import count_events + cfg.json = json + count_events(cfg) diff --git a/src/wormhole/server/cli_args.py b/src/wormhole/server/cli_args.py deleted file mode 100644 index bf95565..0000000 --- a/src/wormhole/server/cli_args.py +++ /dev/null @@ -1,77 +0,0 @@ -import argparse -from textwrap import dedent -from .. import __version__ - -parser = argparse.ArgumentParser( - usage="wormhole-server SUBCOMMAND (subcommand-options)", - description=dedent(""" - Create a Magic Wormhole and communicate through it. Wormholes are created - by speaking the same magic CODE in two different places at the same time. - Wormholes are secure against anyone who doesn't use the same code."""), - ) - -parser.add_argument("--version", action="version", - version="magic-wormhole "+ __version__) -s = parser.add_subparsers(title="subcommands", dest="subcommand") - - -# CLI: run-server -sp_start = s.add_parser("start", description="Start a relay server", - usage="wormhole server start [opts] [TWISTD-ARGS..]") -sp_start.add_argument("--rendezvous", default="tcp:4000", metavar="tcp:PORT", - help="endpoint specification for the rendezvous port") -sp_start.add_argument("--transit", default="tcp:4001", metavar="tcp:PORT", - help="endpoint specification for the transit-relay port") -sp_start.add_argument("--advertise-version", metavar="VERSION", - help="version to recommend to clients") -sp_start.add_argument("--signal-error", metavar="ERROR", - help="force all clients to fail with a message") -sp_start.add_argument("--blur-usage", default=None, type=int, - metavar="SECONDS", - help="round logged access times to improve privacy") -sp_start.add_argument("-n", "--no-daemon", action="store_true") -#sp_start.add_argument("twistd_args", nargs="*", default=None, -# metavar="[TWISTD-ARGS..]", -# help=dedent("""\ -# Additional arguments to pass to twistd"""), -# ) -sp_start.set_defaults(func="server/start") - -sp_stop = s.add_parser("stop", description="Stop the relay server", - usage="wormhole server stop") -sp_stop.set_defaults(func="server/stop") - -sp_restart = s.add_parser("restart", description="Restart the relay server", - usage="wormhole server restart") -sp_restart.add_argument("--rendezvous", default="tcp:4000", metavar="tcp:PORT", - help="endpoint specification for the rendezvous port") -sp_restart.add_argument("--transit", default="tcp:4001", metavar="tcp:PORT", - help="endpoint specification for the transit-relay port") -sp_restart.add_argument("--advertise-version", metavar="VERSION", - help="version to recommend to clients") -sp_restart.add_argument("--signal-error", metavar="ERROR", - help="force all clients to fail with a message") -sp_restart.add_argument("--blur-usage", default=None, type=int, - metavar="SECONDS", - help="round logged access times to improve privacy") -sp_restart.add_argument("-n", "--no-daemon", action="store_true") -sp_restart.set_defaults(func="server/restart") - -sp_show_usage = s.add_parser("show-usage", description="Display usage data", - usage="wormhole server show-usage") -sp_show_usage.add_argument("-n", default=100, type=int, - help="show last N entries") -sp_show_usage.set_defaults(func="usage/usage") - -sp_tail_usage = s.add_parser("tail-usage", description="Follow latest usage", - usage="wormhole server tail-usage") -sp_tail_usage.set_defaults(func="usage/tail") - -sp_count_channels = s.add_parser("count-channels", - description="Count active channels") -sp_count_channels.add_argument("--json", action="store_true") -sp_count_channels.set_defaults(func="usage/count-channels") - -sp_count_events = s.add_parser("count-events", description="Count events") -sp_count_events.add_argument("--json", action="store_true") -sp_count_events.set_defaults(func="usage/count-events") diff --git a/src/wormhole/server/cmd_server.py b/src/wormhole/server/cmd_server.py index ed2afe5..d149b82 100644 --- a/src/wormhole/server/cmd_server.py +++ b/src/wormhole/server/cmd_server.py @@ -39,15 +39,15 @@ def kill_server(): f = open("twistd.pid", "r") except EnvironmentError: print("Unable to find twistd.pid . Is this really a server directory?") - return 1 + return pid = int(f.read().strip()) f.close() os.kill(pid, 15) print("server process %d sent SIGTERM" % pid) - return 0 + return def stop_server(args): - return kill_server() + kill_server() def restart_server(args): kill_server() diff --git a/src/wormhole/test/common.py b/src/wormhole/test/common.py index 07015db..bea06da 100644 --- a/src/wormhole/test/common.py +++ b/src/wormhole/test/common.py @@ -1,6 +1,6 @@ # no unicode_literals untill twisted update from twisted.application import service -from twisted.internet import reactor, defer +from twisted.internet import defer, task from twisted.python import log from ..transit import allocate_tcp_port from ..server.server import RelayServer @@ -36,8 +36,18 @@ class ServerBase: # relay's .stopService() drops all connections, which ought to # encourage those threads to terminate soon. If they don't, print a # warning to ease debugging. + + # XXX FIXME there's something in _noclobber test that's not + # waiting for a close, I think -- was pretty relieably getting + # unclean-reactor, but adding a slight pause here stops it... + from twisted.internet import reactor + tp = reactor.getThreadPool() if not tp.working: + d = defer.succeed(None) + d.addCallback(lambda _: self.sp.stopService()) + d.addCallback(lambda _: task.deferLater(reactor, 0.1, lambda: None)) + return d return self.sp.stopService() # disconnect all callers d = defer.maybeDeferred(self.sp.stopService) diff --git a/src/wormhole/test/test_scripts.py b/src/wormhole/test/test_scripts.py index 40dbf8d..549f590 100644 --- a/src/wormhole/test/test_scripts.py +++ b/src/wormhole/test/test_scripts.py @@ -7,33 +7,32 @@ from twisted.internet.utils import getProcessOutputAndValue from twisted.internet.defer import gatherResults, inlineCallbacks from .. import __version__ from .common import ServerBase -from ..cli import runner, cmd_send, cmd_receive +from ..cli import cmd_send, cmd_receive +from ..cli.cli import Config from ..errors import TransferError, WrongPasswordError, WelcomeError -from ..timing import DebugTiming + def build_offer(args): s = cmd_send.Sender(args, None) return s._build_offer() + class OfferData(unittest.TestCase): def setUp(self): self._things_to_delete = [] + self.cfg = cfg = Config() + cfg.stdout = io.StringIO() + cfg.stderr = io.StringIO() def tearDown(self): for fn in self._things_to_delete: if os.path.exists(fn): os.unlink(fn) + del self.cfg def test_text(self): - message = "blah blah blah ponies" - - send_args = [ "send", "--text", message ] - args = runner.parser.parse_args(send_args) - args.cwd = os.getcwd() - args.stdout = io.StringIO() - args.stderr = io.StringIO() - - d, fd_to_send = build_offer(args) + self.cfg.text = message = "blah blah blah ponies" + d, fd_to_send = build_offer(self.cfg) self.assertIn("message", d) self.assertNotIn("file", d) @@ -42,7 +41,7 @@ class OfferData(unittest.TestCase): self.assertEqual(fd_to_send, None) def test_file(self): - filename = "my file" + self.cfg.what = filename = "my file" message = b"yay ponies\n" send_dir = self.mktemp() os.mkdir(send_dir) @@ -50,13 +49,8 @@ class OfferData(unittest.TestCase): with open(abs_filename, "wb") as f: f.write(message) - send_args = [ "send", filename ] - args = runner.parser.parse_args(send_args) - args.cwd = send_dir - args.stdout = io.StringIO() - args.stderr = io.StringIO() - - d, fd_to_send = build_offer(args) + self.cfg.cwd = send_dir + d, fd_to_send = build_offer(self.cfg) self.assertNotIn("message", d) self.assertIn("file", d) @@ -67,17 +61,12 @@ class OfferData(unittest.TestCase): self.assertEqual(fd_to_send.read(), message) def test_missing_file(self): - filename = "missing" + self.cfg.what = filename = "missing" send_dir = self.mktemp() os.mkdir(send_dir) + self.cfg.cwd = send_dir - send_args = [ "send", filename ] - args = runner.parser.parse_args(send_args) - args.cwd = send_dir - args.stdout = io.StringIO() - args.stderr = io.StringIO() - - e = self.assertRaises(TransferError, build_offer, args) + e = self.assertRaises(TransferError, build_offer, self.cfg) self.assertEqual(str(e), "Cannot send: no file/directory named '%s'" % filename) @@ -94,13 +83,10 @@ class OfferData(unittest.TestCase): send_dir_arg = send_dir if addslash: send_dir_arg += os.sep - send_args = [ "send", send_dir_arg ] - args = runner.parser.parse_args(send_args) - args.cwd = parent_dir - args.stdout = io.StringIO() - args.stderr = io.StringIO() + self.cfg.what = send_dir_arg + self.cfg.cwd = parent_dir - d, fd_to_send = build_offer(args) + d, fd_to_send = build_offer(self.cfg) self.assertNotIn("message", d) self.assertNotIn("file", d) @@ -130,10 +116,11 @@ class OfferData(unittest.TestCase): return self._do_test_directory(addslash=True) def test_unknown(self): - filename = "unknown" + self.cfg.what = filename = "unknown" send_dir = self.mktemp() os.mkdir(send_dir) abs_filename = os.path.abspath(os.path.join(send_dir, filename)) + self.cfg.cwd = send_dir try: os.mkfifo(abs_filename) @@ -149,13 +136,7 @@ class OfferData(unittest.TestCase): self.assertFalse(os.path.isfile(abs_filename)) self.assertFalse(os.path.isdir(abs_filename)) - send_args = [ "send", filename ] - args = runner.parser.parse_args(send_args) - args.cwd = send_dir - args.stdout = io.StringIO() - args.stderr = io.StringIO() - - e = self.assertRaises(TypeError, build_offer, args) + e = self.assertRaises(TypeError, build_offer, self.cfg) self.assertEqual(str(e), "'%s' is neither file nor directory" % filename) @@ -203,27 +184,24 @@ class ScriptVersion(ServerBase, ScriptsBase, unittest.TestCase): # we need Twisted to run the server, but we run the sender and receiver # with deferToThread() + @inlineCallbacks def test_version(self): # "wormhole" must be on the path, so e.g. "pip install -e ." in a # virtualenv. This guards against an environment where the tests # below might run the wrong executable. + self.maxDiff = None wormhole = self.find_executable() - d = getProcessOutputAndValue(wormhole, ["--version"]) - def _check(res): - out, err, rc = res - # argparse on py2 and py3.3 sends --version to stderr - # argparse on py3.4/py3.5 sends --version to stdout - # aargh - err = err.decode("utf-8") - if "DistributionNotFound" in err: - log.msg("stderr was %s" % err) - last = err.strip().split("\n")[-1] - self.fail("wormhole not runnable: %s" % last) - ver = out.decode("utf-8") or err - self.failUnlessEqual(ver, "magic-wormhole "+__version__+os.linesep) - self.failUnlessEqual(rc, 0) - d.addCallback(_check) - return d + # we must pass on the environment so that "something" doesn't + # get sad about UTF8 vs. ascii encodings + out, err, rc = yield getProcessOutputAndValue(wormhole, ["--version"], env=os.environ) + err = err.decode("utf-8") + if "DistributionNotFound" in err: + log.msg("stderr was %s" % err) + last = err.strip().split("\n")[-1] + self.fail("wormhole not runnable: %s" % last) + ver = out.decode("utf-8") or err + self.failUnlessEqual(ver.strip(), "magic-wormhole {}".format(__version__)) + self.failUnlessEqual(rc, 0) class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase): # we need Twisted to run the server, but we run the sender and receiver @@ -238,20 +216,18 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase): def _do_test(self, as_subprocess=False, mode="text", addslash=False, override_filename=False): assert mode in ("text", "file", "directory") - common_args = ["--hide-progress", - "--relay-url", self.relayurl, - "--transit-helper", ""] - code = "1-abc" - message = "test message" + send_cfg = Config() + recv_cfg = Config() + message = "blah blah blah ponies" - send_args = common_args + [ - "send", - "--code", code, - ] - - receive_args = common_args + [ - "receive", - ] + for cfg in [send_cfg, recv_cfg]: + cfg.hide_progress = True + cfg.relay_url = self.relayurl + cfg.transit_helper = "" + cfg.listen = True + cfg.code = "1-abc" + cfg.stdout = io.StringIO() + cfg.stderr = io.StringIO() send_dir = self.mktemp() os.mkdir(send_dir) @@ -259,19 +235,18 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase): os.mkdir(receive_dir) if mode == "text": - send_args.extend(["--text", message]) + send_cfg.text = message elif mode == "file": send_filename = "testfile" with open(os.path.join(send_dir, send_filename), "w") as f: f.write(message) - send_args.append(send_filename) + send_cfg.what = send_filename receive_filename = send_filename - receive_args.append("--accept-file") + recv_cfg.accept_file = True if override_filename: - receive_args.extend(["-o", "outfile"]) - receive_filename = "outfile" + recv_cfg.output_file = receive_filename = "outfile" elif mode == "directory": # $send_dir/ @@ -299,22 +274,48 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase): send_dirname_arg = os.path.join("middle", send_dirname) if addslash: send_dirname_arg += os.sep - send_args.append(send_dirname_arg) + send_cfg.what = send_dirname_arg receive_dirname = send_dirname - receive_args.append("--accept-file") + recv_cfg.accept_file = True if override_filename: - receive_args.extend(["-o", "outdir"]) - receive_dirname = "outdir" - - receive_args.append(code) + recv_cfg.output_file = receive_dirname = "outdir" if as_subprocess: wormhole_bin = self.find_executable() - send_d = getProcessOutputAndValue(wormhole_bin, send_args, - path=send_dir) - receive_d = getProcessOutputAndValue(wormhole_bin, receive_args, - path=receive_dir) + if send_cfg.text: + content_args = ['--text', send_cfg.text] + elif send_cfg.what: + content_args = [send_cfg.what] + + send_args = [ + '--hide-progress', + '--relay-url', self.relayurl, + '--transit-helper', '', + 'send', + '--code', send_cfg.code, + ] + content_args + + send_d = getProcessOutputAndValue( + wormhole_bin, send_args, + path=send_dir, + ) + recv_args = [ + '--hide-progress', + '--relay-url', self.relayurl, + '--transit-helper', '', + 'receive', + '--accept-file', + recv_cfg.code, + ] + if override_filename: + recv_args.extend(['-o', receive_filename]) + + receive_d = getProcessOutputAndValue( + wormhole_bin, recv_args, + path=receive_dir, + ) + (send_res, receive_res) = yield gatherResults([send_d, receive_d], True) send_stdout = send_res[0].decode("utf-8") @@ -327,27 +328,21 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase): self.assertEqual((send_rc, receive_rc), (0, 0), (send_res, receive_res)) else: - sargs = runner.parser.parse_args(send_args) - sargs.cwd = send_dir - sargs.stdout = io.StringIO() - sargs.stderr = io.StringIO() - sargs.timing = DebugTiming() - rargs = runner.parser.parse_args(receive_args) - rargs.cwd = receive_dir - rargs.stdout = io.StringIO() - rargs.stderr = io.StringIO() - rargs.timing = DebugTiming() - send_d = cmd_send.send(sargs) - receive_d = cmd_receive.receive(rargs) + send_cfg.cwd = send_dir + send_d = cmd_send.send(send_cfg) + + recv_cfg.cwd = receive_dir + receive_d = cmd_receive.receive(recv_cfg) # The sender might fail, leaving the receiver hanging, or vice # versa. Make sure we don't wait on one side exclusively yield gatherResults([send_d, receive_d], True) - send_stdout = sargs.stdout.getvalue() - send_stderr = sargs.stderr.getvalue() - receive_stdout = rargs.stdout.getvalue() - receive_stderr = rargs.stderr.getvalue() + # XXX need captured stdin/stdout from sender/receiver + send_stdout = send_cfg.stdout.getvalue() + send_stderr = send_cfg.stderr.getvalue() + receive_stdout = recv_cfg.stdout.getvalue() + receive_stderr = recv_cfg.stderr.getvalue() # all output here comes from a StringIO, which uses \n for # newlines, even if we're on windows @@ -367,7 +362,7 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase): "wormhole receive{NL}" "Wormhole code is: {code}{NL}{NL}" "text message sent{NL}").format(bytes=len(message), - code=code, + code=send_cfg.code, NL=NL) self.failUnlessEqual(send_stdout, expected) elif mode == "file": @@ -377,7 +372,7 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase): self.failUnlessIn("On the other computer, please run: " "wormhole receive{NL}" "Wormhole code is: {code}{NL}{NL}" - .format(code=code, NL=NL), + .format(code=send_cfg.code, NL=NL), send_stdout) self.failUnlessIn("File sent.. waiting for confirmation{NL}" "Confirmation received. Transfer complete.{NL}" @@ -388,7 +383,7 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase): self.failUnlessIn("On the other computer, please run: " "wormhole receive{NL}" "Wormhole code is: {code}{NL}{NL}" - .format(code=code, NL=NL), send_stdout) + .format(code=send_cfg.code, NL=NL), send_stdout) self.failUnlessIn("File sent.. waiting for confirmation{NL}" "Confirmation received. Transfer complete.{NL}" .format(NL=NL), send_stdout) @@ -440,14 +435,21 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase): @inlineCallbacks def test_file_noclobber(self): - common_args = ["--hide-progress", "--no-listen", - "--relay-url", self.relayurl, - "--transit-helper", ""] - code = "1-abc" + send_cfg = Config() + recv_cfg = Config() + + for cfg in [send_cfg, recv_cfg]: + cfg.hide_progress = True + cfg.relay_url = self.relayurl + cfg.transit_helper = "" + cfg.listen = False + cfg.code = code = "1-abc" + cfg.stdout = io.StringIO() + cfg.stderr = io.StringIO() + message = "test message" - send_args = common_args + [ "send", "--code", code ] - receive_args = common_args + [ "receive", "--accept-file", code ] + recv_cfg.accept_file = True send_dir = self.mktemp() os.mkdir(send_dir) @@ -457,26 +459,19 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase): send_filename = "testfile" with open(os.path.join(send_dir, send_filename), "w") as f: f.write(message) - send_args.append(send_filename) - receive_filename = send_filename + send_cfg.what = receive_filename = send_filename + recv_cfg.what = receive_filename PRESERVE = "don't clobber me\n" clobberable = os.path.join(receive_dir, receive_filename) with open(clobberable, "w") as f: f.write(PRESERVE) - sargs = runner.parser.parse_args(send_args) - sargs.cwd = send_dir - sargs.stdout = io.StringIO() - sargs.stderr = io.StringIO() - sargs.timing = DebugTiming() - rargs = runner.parser.parse_args(receive_args) - rargs.cwd = receive_dir - rargs.stdout = io.StringIO() - rargs.stderr = io.StringIO() - rargs.timing = DebugTiming() - send_d = cmd_send.send(sargs) - receive_d = cmd_receive.receive(rargs) + send_cfg.cwd = send_dir + send_d = cmd_send.send(send_cfg) + + recv_cfg.cwd = receive_dir + receive_d = cmd_receive.receive(recv_cfg) # both sides will fail because of the pre-existing file @@ -486,10 +481,10 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase): f = yield self.assertFailure(receive_d, TransferError) self.assertEqual(str(f), "file already exists") - send_stdout = sargs.stdout.getvalue() - send_stderr = sargs.stderr.getvalue() - receive_stdout = rargs.stdout.getvalue() - receive_stderr = rargs.stderr.getvalue() + send_stdout = send_cfg.stdout.getvalue() + send_stderr = send_cfg.stderr.getvalue() + receive_stdout = recv_cfg.stdout.getvalue() + receive_stderr = recv_cfg.stderr.getvalue() # all output here comes from a StringIO, which uses \n for # newlines, even if we're on windows @@ -528,63 +523,56 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase): class NotWelcome(ServerBase, unittest.TestCase): def setUp(self): self._setup_relay(error="please upgrade XYZ") + self.cfg = cfg = Config() + cfg.hide_progress = True + cfg.listen = False + cfg.relay_url = self.relayurl + cfg.transit_helper = "" + cfg.stdout = io.StringIO() + cfg.stderr = io.StringIO() @inlineCallbacks def test_sender(self): - common_args = ["--hide-progress", "--no-listen", - "--relay-url", self.relayurl, - "--transit-helper", ""] - send_args = common_args + [ "send", "--text", "hi", - "--code", "1-abc" ] - sargs = runner.parser.parse_args(send_args) - sargs.cwd = self.mktemp() - sargs.stdout = io.StringIO() - sargs.stderr = io.StringIO() - sargs.timing = DebugTiming() + self.cfg.text = "hi" + self.cfg.code = "1-abc" - send_d = cmd_send.send(sargs) + send_d = cmd_send.send(self.cfg) f = yield self.assertFailure(send_d, WelcomeError) self.assertEqual(str(f), "please upgrade XYZ") @inlineCallbacks def test_receiver(self): - common_args = ["--hide-progress", "--no-listen", - "--relay-url", self.relayurl, - "--transit-helper", ""] - receive_args = common_args + [ "receive", "1-abc" ] - rargs = runner.parser.parse_args(receive_args) - rargs.cwd = self.mktemp() - rargs.stdout = io.StringIO() - rargs.stderr = io.StringIO() - rargs.timing = DebugTiming() + self.cfg.code = "1-abc" - receive_d = cmd_receive.receive(rargs) + receive_d = cmd_receive.receive(self.cfg) f = yield self.assertFailure(receive_d, WelcomeError) self.assertEqual(str(f), "please upgrade XYZ") -class Cleanup(ServerBase, unittest.TestCase): - @inlineCallbacks - def test_text(self): - # the rendezvous channel should be deleted after success - code = "1-abc" - common_args = ["--hide-progress", - "--relay-url", self.relayurl, - "--transit-helper", ""] - sargs = runner.parser.parse_args(common_args + - ["send", - "--text", "secret message", - "--code", code]) - sargs.stdout = io.StringIO() - sargs.stderr = io.StringIO() - sargs.timing = DebugTiming() - rargs = runner.parser.parse_args(common_args + - ["receive", code]) - rargs.stdout = io.StringIO() - rargs.stderr = io.StringIO() - rargs.timing = DebugTiming() - send_d = cmd_send.send(sargs) - receive_d = cmd_receive.receive(rargs) +class Cleanup(ServerBase, unittest.TestCase): + + def setUp(self): + d = super(Cleanup, self).setUp() + self.cfg = cfg = Config() + # common options for all tests in this suite + cfg.hide_progress = True + cfg.relay_url = self.relayurl + cfg.transit_helper = "" + cfg.stdout = io.StringIO() + cfg.stderr = io.StringIO() + return d + + @inlineCallbacks + @mock.patch('sys.stdout') + def test_text(self, stdout): + # the rendezvous channel should be deleted after success + self.cfg.text = "hello" + self.cfg.code = "1-abc" + + send_d = cmd_send.send(self.cfg) + receive_d = cmd_receive.receive(self.cfg) + + # XXX DeferredList? yield send_d yield receive_d @@ -595,23 +583,12 @@ class Cleanup(ServerBase, unittest.TestCase): def test_text_wrong_password(self): # if the password was wrong, the rendezvous channel should still be # deleted - common_args = ["--hide-progress", - "--relay-url", self.relayurl, - "--transit-helper", ""] - sargs = runner.parser.parse_args(common_args + - ["send", - "--text", "secret message", - "--code", "1-abc"]) - sargs.stdout = io.StringIO() - sargs.stderr = io.StringIO() - sargs.timing = DebugTiming() - rargs = runner.parser.parse_args(common_args + - ["receive", "1-WRONG"]) - rargs.stdout = io.StringIO() - rargs.stderr = io.StringIO() - rargs.timing = DebugTiming() - send_d = cmd_send.send(sargs) - receive_d = cmd_receive.receive(rargs) + self.cfg.text = "secret message" + self.cfg.code = "1-abc" + send_d = cmd_send.send(self.cfg) + + self.cfg.code = "1-WRONG" + receive_d = cmd_receive.receive(self.cfg) # both sides should be capable of detecting the mismatch yield self.assertFailure(send_d, WrongPasswordError) diff --git a/src/wormhole/transit.py b/src/wormhole/transit.py index 050cd00..92ac4df 100644 --- a/src/wormhole/transit.py +++ b/src/wormhole/transit.py @@ -813,6 +813,10 @@ class Common: is_relay=True) contenders.append(d) + if not contenders: + raise RuntimeError("No contenders for connection") +# else: +# print("contend", contenders) winner = there_can_be_only_one(contenders) return self._not_forever(2*TIMEOUT, winner) diff --git a/src/wormhole/wormhole.py b/src/wormhole/wormhole.py index 3a716b1..0c77e69 100644 --- a/src/wormhole/wormhole.py +++ b/src/wormhole/wormhole.py @@ -462,8 +462,10 @@ class _Wormhole: # entry point 3: paste in a fully-formed code def _API_set_code(self, code): self._timing.add("API set_code") - if not isinstance(code, type("")): raise TypeError(type(code)) - if self._code is not None: raise UsageError + if not isinstance(code, type(u"")): + raise TypeError("Unexpected code type '{}'".format(type(code))) + if self._code is not None: + raise UsageError self._event_learned_code(code) # TODO: entry point 4: restore pre-contact saved state (we haven't heard