diff --git a/setup.py b/setup.py index 5d57848..34e06d7 100644 --- a/setup.py +++ b/setup.py @@ -19,18 +19,29 @@ setup(name="magic-wormhole", "wormhole.test", ], package_data={"wormhole.server": ["db-schemas/*.sql"]}, - entry_points={"console_scripts": - ["wormhole = wormhole.cli.runner:entry", - "wormhole-server = wormhole.server.runner:entry", - ]}, - install_requires=["spake2==0.7", "pynacl", "argparse", - "six", - "twisted", - "autobahn[twisted] >= 0.14.1", - "hkdf", "tqdm", - ], - extras_require={':sys_platform=="win32"': ["pypiwin32"], - "tor": ["txtorcon", "ipaddress"]}, + entry_points={ + "console_scripts": + [ + "wormhole = wormhole.cli.cli:wormhole", + "wormhole-server = wormhole.server.cli:server", + ] + }, + install_requires=[ + "spake2==0.7", "pynacl", + "six", + "twisted", + "autobahn[twisted] >= 0.14.1", + "hkdf", "tqdm", + "click", + ], + extras_require={ + ':sys_platform=="win32"': ["pypiwin32"], + "tor": ["txtorcon", "ipaddress"], + "dev": [ + "mock", + "tox", + ], + }, test_suite="wormhole.test", cmdclass=commands, ) diff --git a/src/wormhole/cli/cli.py b/src/wormhole/cli/cli.py new file mode 100644 index 0000000..631b347 --- /dev/null +++ b/src/wormhole/cli/cli.py @@ -0,0 +1,233 @@ +from __future__ import print_function + +import os +import time +start = time.time() +import traceback +from textwrap import fill, dedent +from sys import stdout, stderr +from . import public_relay +from .. import __version__ +from ..timing import DebugTiming +from ..errors import WrongPasswordError, WelcomeError, KeyFormatError +from twisted.internet.defer import inlineCallbacks, maybeDeferred +from twisted.internet.task import react + +import click +top_import_finish = time.time() + + +class Config(object): + """ + Union of config options that we pass down to (sub) commands. + """ + def __init__(self): + # common options + self.timing = DebugTiming() + self.tor = None + self.listen = None + self.relay_url = u"" + self.transit_helper = u"" + self.cwd = os.getcwd() + # send/receive commands + self.code = None + self.code_length = 2 + self.verify = False + self.hide_progress = False + self.dump_timing = False + self.stdout = stdout + self.stderr = stderr + self.zeromode = False + self.accept_file = None + self.output_file = None + # send only + self.text = None + self.what = None + + +ALIASES = { + "tx": "send", + "rx": "receive", +} +class AliasedGroup(click.Group): + def get_command(self, ctx, cmd_name): + cmd_name = ALIASES.get(cmd_name, cmd_name) + return click.Group.get_command(self, ctx, cmd_name) + + +# top-level command ("wormhole ...") +@click.group(cls=AliasedGroup) +@click.option( + "--relay-url", default=public_relay.RENDEZVOUS_RELAY, + metavar="URL", + help="rendezvous relay to use", +) +@click.option( + "--transit-helper", default=public_relay.TRANSIT_RELAY, + metavar="tcp:HOST:PORT", + help="transit relay to use", +) +@click.option( + "-c", "--code-length", default=2, + metavar="NUMWORDS", + help="length of code (in bytes/words)", +) +@click.option( + "-v", "--verify", is_flag=True, default=False, + help="display (and wait for acceptance of) verification string", +) +@click.option( + "--hide-progress", is_flag=True, default=False, + help="supress progress-bar display", +) +@click.option( + "--dump-timing", type=type(u""), # TODO: hide from --help output + default=None, + metavar="FILE.json", + help="(debug) write timing data to file", +) +@click.option( + "--no-listen", is_flag=True, default=False, + help="(debug) don't open a listening socket for Transit", +) +@click.option( + "--tor", is_flag=True, default=True, + help="use Tor when connecting", +) +@click.version_option( + message="magic-wormhole %(version)s", + version=__version__, +) +@click.pass_context +def wormhole(ctx, tor, no_listen, dump_timing, hide_progress, + verify, code_length, transit_helper, relay_url): + """ + Create a Magic Wormhole and communicate through it. + + Wormholes are created by speaking the same magic CODE in two + different places at the same time. Wormholes are secure against + anyone who doesn't use the same code. + """ + ctx.obj = cfg = Config() + ctx.tor = tor + if no_listen: + cfg.listen = False + cfg.relay_url = relay_url + cfg.transit_helper = transit_helper + cfg.code_length = code_length + cfg.verify = verify + cfg.hide_progress = hide_progress + cfg.dump_timing = dump_timing + + +@inlineCallbacks +def _dispatch_command(reactor, cfg, command): + """ + Internal helper. This calls the given command (a no-argument + callable) with the Config instance in cfg and interprets any + errors for the user. + """ + cfg.timing.add("command dispatch") + cfg.timing.add("import", when=start, which="top").finish(when=top_import_finish) + + try: + yield maybeDeferred(command) + except WrongPasswordError as e: + msg = fill("ERROR: " + dedent(e.__doc__)) + print(msg, file=stderr) + except WelcomeError as e: + msg = fill("ERROR: " + dedent(e.__doc__)) + print(msg, file=stderr) + print(file=stderr) + print(str(e), file=stderr) + except KeyFormatError as e: + msg = fill("ERROR: " + dedent(e.__doc__)) + print(msg, file=stderr) + except Exception as e: + traceback.print_exc() + print("ERROR:", e, file=stderr) + raise SystemExit(1) + + cfg.timing.add("exit") + if cfg.dump_timing: + cfg.timing.write(cfg.dump_timing, stderr) + + +# wormhole send (or "wormhole tx") +@wormhole.command() +@click.option( + "-0", "zeromode", default=False, is_flag=True, + help="enable no-code anything-goes mode", +) +@click.option( + "--code", metavar="CODE", + help="human-generated code phrase", +) +@click.option( + "--text", default=None, metavar="MESSAGE", + help="text message to send, instead of a file. Use '-' to read from stdin.", +) +@click.argument("what", default=u'') +@click.pass_obj +def send(cfg, what, text, code, zeromode): + """Send a text message, file, or directory""" + with cfg.timing.add("import", which="cmd_send"): + from . import cmd_send + cfg.what = what + cfg.text = text + cfg.zeromode = zeromode + cfg.code = code + + # note: react() does not return + return react(_dispatch_command, (cfg, lambda: cmd_send.send(cfg))) + + +# wormhole receive (or "wormhole rx") +@wormhole.command() +@click.option( + "-0", "zeromode", default=False, is_flag=True, + help="enable no-code anything-goes mode", +) +@click.option( + "--only-text", "-t", is_flag=True, + help="refuse file transfers, only accept text transfers", +) +@click.option( + "--accept-file", is_flag=True, + help="accept file transfer without asking for confirmation", +) +@click.option( + "--output-file", "-o", + metavar="FILENAME|DIRNAME", + help=("The file or directory to create, overriding the name suggested" + " by the sender."), +) +@click.argument( + "code", nargs=-1, default=None, +# help=("The magic-wormhole code, from the sender. If omitted, the" +# " program will ask for it, using tab-completion."), +) +@click.pass_obj +def receive(cfg, code, zeromode, output_file, accept_file, only_text): + """ + Receive a text message, file, or directory (from 'wormhole send') + """ + with cfg.timing.add("import", which="cmd_receive"): + from . import cmd_receive + cfg.zeromode = zeromode + cfg.output_file = output_file + cfg.accept_file = accept_file + cfg.only_text = only_text + if len(code) == 1: + cfg.code = code[0] + elif len(code) > 1: + print( + "Pass either no code or just one code; you passed" + " {}: {}".format(len(code), ', '.join(code)) + ) + raise SystemExit(1) + else: + cfg.code = None + + # note: react() does not return + return react(_dispatch_command, (cfg, lambda: cmd_receive.receive(cfg))) diff --git a/src/wormhole/cli/cli_args.py b/src/wormhole/cli/cli_args.py deleted file mode 100644 index 793a171..0000000 --- a/src/wormhole/cli/cli_args.py +++ /dev/null @@ -1,124 +0,0 @@ -import argparse -from textwrap import dedent -from . import public_relay -from .. import __version__ - -parser = argparse.ArgumentParser( - usage="wormhole SUBCOMMAND (subcommand-options)", - description=dedent(""" - Create a Magic Wormhole and communicate through it. Wormholes are created - by speaking the same magic CODE in two different places at the same time. - Wormholes are secure against anyone who doesn't use the same code."""), - ) - -parser.add_argument("--version", action="version", - version="magic-wormhole "+ __version__) -g = parser.add_argument_group("wormhole configuration options") -g.add_argument("--relay-url", default=public_relay.RENDEZVOUS_RELAY, - metavar="URL", help="rendezvous relay to use", type=type(u"")) -g.add_argument("--transit-helper", default=public_relay.TRANSIT_RELAY, - metavar="tcp:HOST:PORT", help="transit relay to use", - type=type(u"")) -g.add_argument("-c", "--code-length", type=int, default=2, - metavar="WORDS", help="length of code (in bytes/words)") -g.add_argument("-v", "--verify", action="store_true", - help="display (and wait for acceptance of) verification string") -g.add_argument("--hide-progress", action="store_true", - help="supress progress-bar display") -g.add_argument("--dump-timing", type=type(u""), # TODO: hide from --help output - metavar="FILE", help="(debug) write timing data to file") -g.add_argument("--no-listen", action="store_true", - help="(debug) don't open a listening socket for Transit") -g.add_argument("--tor", action="store_true", - help="use Tor when connecting") -parser.set_defaults(timing=None) -subparsers = parser.add_subparsers(title="subcommands", - dest="subcommand") - - -# CLI: run-server -s = subparsers.add_parser("server", description="Start/stop a relay server") -sp = s.add_subparsers(title="subcommands", dest="subcommand") -sp_start = sp.add_parser("start", description="Start a relay server", - usage="wormhole server start [opts] [TWISTD-ARGS..]") -sp_start.add_argument("--rendezvous", default="tcp:3000", metavar="tcp:PORT", - help="endpoint specification for the rendezvous port") -sp_start.add_argument("--transit", default="tcp:3001", metavar="tcp:PORT", - help="endpoint specification for the transit-relay port") -sp_start.add_argument("--advertise-version", metavar="VERSION", - help="version to recommend to clients") -sp_start.add_argument("--blur-usage", default=None, type=int, - metavar="SECONDS", - help="round logged access times to improve privacy") -sp_start.add_argument("-n", "--no-daemon", action="store_true") -#sp_start.add_argument("twistd_args", nargs="*", default=None, -# metavar="[TWISTD-ARGS..]", -# help=dedent("""\ -# Additional arguments to pass to twistd"""), -# ) -sp_start.set_defaults(func="server/start") - -sp_stop = sp.add_parser("stop", description="Stop the relay server", - usage="wormhole server stop") -sp_stop.set_defaults(func="server/stop") - -sp_restart = sp.add_parser("restart", description="Restart the relay server", - usage="wormhole server restart") -sp_restart.add_argument("--rendezvous", default="tcp:3000", metavar="tcp:PORT", - help="endpoint specification for the rendezvous port") -sp_restart.add_argument("--transit", default="tcp:3001", metavar="tcp:PORT", - help="endpoint specification for the transit-relay port") -sp_restart.add_argument("--advertise-version", metavar="VERSION", - help="version to recommend to clients") -sp_restart.add_argument("--blur-usage", default=None, type=int, - metavar="SECONDS", - help="round logged access times to improve privacy") -sp_restart.add_argument("-n", "--no-daemon", action="store_true") -sp_restart.set_defaults(func="server/restart") - -sp_show_usage = sp.add_parser("show-usage", description="Display usage data", - usage="wormhole server show-usage") -sp_show_usage.add_argument("-n", default=100, type=int, - help="show last N entries") -sp_show_usage.set_defaults(func="usage/usage") - -sp_tail_usage = sp.add_parser("tail-usage", description="Follow latest usage", - usage="wormhole server tail-usage") -sp_tail_usage.set_defaults(func="usage/tail") - -# CLI: send -p = subparsers.add_parser("send", - description="Send text message, file, or directory", - usage="wormhole send [FILENAME|DIRNAME]") -p.add_argument("--text", metavar="MESSAGE", - help="text message to send, instead of a file. Use '-' to read from stdin.") -p.add_argument("--code", metavar="CODE", help="human-generated code phrase", - type=type(u"")) -p.add_argument("-0", dest="zeromode", action="store_true", - help="enable no-code anything-goes mode") -p.add_argument("what", nargs="?", default=None, metavar="[FILENAME|DIRNAME]", - help="the file/directory to send") -p.set_defaults(func="send/send") - -# CLI: receive -p = subparsers.add_parser("receive", - description="Receive a text message, file, or directory", - usage="wormhole receive [CODE]") -p.add_argument("-0", dest="zeromode", action="store_true", - help="enable no-code anything-goes mode") -p.add_argument("-t", "--only-text", dest="only_text", action="store_true", - help="refuse file transfers, only accept text transfers") -p.add_argument("--accept-file", dest="accept_file", action="store_true", - help="accept file transfer with asking for confirmation") -p.add_argument("-o", "--output-file", default=None, metavar="FILENAME|DIRNAME", - help=dedent("""\ - The file or directory to create, overriding the name suggested - by the sender."""), - ) -p.add_argument("code", nargs="?", default=None, metavar="[CODE]", - help=dedent("""\ - The magic-wormhole code, from the sender. If omitted, the - program will ask for it, using tab-completion."""), - type=type(u""), - ) -p.set_defaults(func="receive/receive") diff --git a/src/wormhole/cli/cmd_receive.py b/src/wormhole/cli/cmd_receive.py index 79a6221..89caba2 100644 --- a/src/wormhole/cli/cmd_receive.py +++ b/src/wormhole/cli/cmd_receive.py @@ -141,7 +141,7 @@ class TwistedReceiver: @inlineCallbacks def _build_transit(self, w, sender_transit): tr = TransitReceiver(self.args.transit_helper, - no_listen=self.args.no_listen, + no_listen=(not self.args.listen), tor_manager=self._tor_manager, reactor=self._reactor, timing=self.args.timing) diff --git a/src/wormhole/cli/cmd_send.py b/src/wormhole/cli/cmd_send.py index 1b7f3c2..cf082a9 100644 --- a/src/wormhole/cli/cmd_send.py +++ b/src/wormhole/cli/cmd_send.py @@ -48,9 +48,10 @@ class Sender: w = wormhole(APPID, self._args.relay_url, self._reactor, self._tor_manager, timing=self._timing) - d = self._go(w) - d.addBoth(w.close) - yield d + try: + yield self._go(w) + finally: + w.close() def _send_data(self, data, w): data_bytes = dict_to_bytes(data) @@ -100,7 +101,7 @@ class Sender: if self._fd_to_send: ts = TransitSender(args.transit_helper, - no_listen=args.no_listen, + no_listen=(not args.listen), tor_manager=self._tor_manager, reactor=self._reactor, timing=self._timing) diff --git a/src/wormhole/cli/runner.py b/src/wormhole/cli/runner.py deleted file mode 100644 index 8c373b1..0000000 --- a/src/wormhole/cli/runner.py +++ /dev/null @@ -1,81 +0,0 @@ -from __future__ import print_function -import time -start = time.time() -import os, sys, textwrap -from twisted.internet.defer import maybeDeferred -from twisted.internet.task import react -from ..errors import (TransferError, WrongPasswordError, WelcomeError, Timeout, - KeyFormatError) -from ..timing import DebugTiming -from .cli_args import parser -top_import_finish = time.time() - -def dispatch(args): # returns Deferred - if args.func == "send/send": - with args.timing.add("import", which="cmd_send"): - from . import cmd_send - return cmd_send.send(args) - if args.func == "receive/receive": - with args.timing.add("import", which="cmd_receive"): - from . import cmd_receive - return cmd_receive.receive(args) - - raise ValueError("unknown args.func %s" % args.func) - -def run(reactor, argv, cwd, stdout, stderr, executable=None): - """This is invoked by entry() below, and can also be invoked directly by - tests. - """ - - args = parser.parse_args(argv) - if not getattr(args, "func", None): - # So far this only works on py3. py2 exits with a really terse - # "error: too few arguments" during parse_args(). - parser.print_help() - sys.exit(0) - args.cwd = cwd - args.stdout = stdout - args.stderr = stderr - args.timing = timing = DebugTiming() - - timing.add("command dispatch") - timing.add("import", when=start, which="top").finish(when=top_import_finish) - # fires with None, or raises an error - d = maybeDeferred(dispatch, args) - def _maybe_dump_timing(res): - timing.add("exit") - if args.dump_timing: - timing.write(args.dump_timing, stderr) - return res - d.addBoth(_maybe_dump_timing) - def _explain_error(f): - # these errors don't print a traceback, just an explanation - f.trap(TransferError, WrongPasswordError, WelcomeError, Timeout, - KeyFormatError) - if f.check(WrongPasswordError): - msg = textwrap.fill("ERROR: " + textwrap.dedent(f.value.__doc__)) - print(msg, file=stderr) - elif f.check(WelcomeError): - msg = textwrap.fill("ERROR: " + textwrap.dedent(f.value.__doc__)) - print(msg, file=stderr) - print(file=stderr) - print(str(f.value), file=stderr) - elif f.check(KeyFormatError): - msg = textwrap.fill("ERROR: " + textwrap.dedent(f.value.__doc__)) - print(msg, file=stderr) - else: - print("ERROR:", f.value, file=stderr) - raise SystemExit(1) - d.addErrback(_explain_error) - d.addCallback(lambda _: 0) - return d - -def entry(): - """This is used by a setuptools entry_point. When invoked this way, - setuptools has already put the installed package on sys.path .""" - react(run, (sys.argv[1:], os.getcwd(), sys.stdout, sys.stderr, - sys.argv[0])) - -if __name__ == "__main__": - args = parser.parse_args() - print(args) diff --git a/src/wormhole/errors.py b/src/wormhole/errors.py index 4782830..cf039fd 100644 --- a/src/wormhole/errors.py +++ b/src/wormhole/errors.py @@ -47,10 +47,10 @@ class KeyFormatError(Exception): class ReflectionAttack(Exception): """An attacker (or bug) reflected our outgoing message back to us.""" -class UsageError(Exception): +class InternalError(Exception): """The programmer did something wrong.""" -class WormholeClosedError(UsageError): +class WormholeClosedError(InternalError): """API calls may not be made after close() is called.""" class TransferError(Exception): diff --git a/src/wormhole/server/cli.py b/src/wormhole/server/cli.py new file mode 100644 index 0000000..2408114 --- /dev/null +++ b/src/wormhole/server/cli.py @@ -0,0 +1,156 @@ +from __future__ import print_function + +import click + + +# can put this back in to get this command as "wormhole server" +# instead +#from ..cli.cli import wormhole +#@wormhole.group() +@click.group() +@click.pass_context +def server(ctx): + """ + Control a relay server (most users shouldn't need to worry + about this and can use the default server). + """ + # just leaving this pointing to wormhole.cli.cli.Config for now, + # but if we want to keep wormhole-server as a separate command + # should probably have our own Config without all the options the + # server commands don't use + from ..cli.cli import Config + ctx.obj = Config() + + +@server.command() +@click.option( + "--rendezvous", default="tcp:4000", metavar="tcp:PORT", + help="endpoint specification for the rendezvous port", +) +@click.option( + "--transit", default="tcp:4001", metavar="tcp:PORT", + help="endpoint specification for the transit-relay port", +) +@click.option( + "--advertise-version", metavar="VERSION", + help="version to recommend to clients", +) +@click.option( + "--blur-usage", default=None, type=int, + metavar="SECONDS", + help="round logged access times to improve privacy", +) +@click.option( + "--no-daemon", "-n", is_flag=True, + help="Run in the foreground", +) +@click.option( + "--signal-error", is_flag=True, + help="force all clients to fail with a message", +) +@click.pass_obj +def start(cfg, signal_error, no_daemon, blur_usage, advertise_version, transit, rendezvous): + """ + Start a relay server + """ + from wormhole.server.cmd_server import start_server + cfg.no_daemon = no_daemon + cfg.blur_usage = blur_usage + cfg.advertise_version = advertise_version + cfg.transit = str(transit) + cfg.rendezvous = str(rendezvous) + cfg.signal_error = signal_error + + start_server(cfg) + + +# XXX it would be nice to reduce the duplication between 'restart' and +# 'start' options... +@server.command() +@click.option( + "--rendezvous", default="tcp:4000", metavar="tcp:PORT", + help="endpoint specification for the rendezvous port", +) +@click.option( + "--transit", default="tcp:4001", metavar="tcp:PORT", + help="endpoint specification for the transit-relay port", +) +@click.option( + "--advertise-version", metavar="VERSION", + help="version to recommend to clients", +) +@click.option( + "--blur-usage", default=None, type=int, + metavar="SECONDS", + help="round logged access times to improve privacy", +) +@click.option( + "--no-daemon", "-n", is_flag=True, + help="Run in the foreground", +) +@click.option( + "--signal-error", is_flag=True, + help="force all clients to fail with a message", +) +@click.pass_obj +def restart(cfg, signal_error, no_daemon, blur_usage, advertise_version, transit, rendezvous): + """ + Re-start a relay server + """ + from wormhole.server.cmd_server import restart_server + cfg.no_daemon = no_daemon + cfg.blur_usage = blur_usage + cfg.advertise_version = advertise_version + cfg.transit = str(transit) + cfg.rendezvous = str(rendezvous) + cfg.signal_error = signal_error + + restart_server(cfg) + + +@server.command() +@click.pass_obj +def stop(cfg): + """ + Stop a relay server + """ + from wormhole.server.cmd_server import stop_server + stop_server(cfg) + + +@server.command(name="tail-usage") +@click.pass_obj +def tail_usage(cfg): + """ + Follow the latest usage + """ + from wormhole.server.cmd_usage import tail_usage + tail_usage(cfg) + + +@server.command(name='count-channels') +@click.option( + "--json", is_flag=True, +) +@click.pass_obj +def count_channels(cfg, json): + """ + Count active channels + """ + from wormhole.server.cmd_usage import count_channels + cfg.json = json + count_channels(cfg) + + +@server.command(name='count-events') +@click.option( + "--json", is_flag=True, +) +@click.pass_obj +def count_events(cfg, json): + """ + Count events + """ + from wormhole.server.cmd_usage import count_events + cfg.json = json + count_events(cfg) diff --git a/src/wormhole/server/cli_args.py b/src/wormhole/server/cli_args.py deleted file mode 100644 index bf95565..0000000 --- a/src/wormhole/server/cli_args.py +++ /dev/null @@ -1,77 +0,0 @@ -import argparse -from textwrap import dedent -from .. import __version__ - -parser = argparse.ArgumentParser( - usage="wormhole-server SUBCOMMAND (subcommand-options)", - description=dedent(""" - Create a Magic Wormhole and communicate through it. Wormholes are created - by speaking the same magic CODE in two different places at the same time. - Wormholes are secure against anyone who doesn't use the same code."""), - ) - -parser.add_argument("--version", action="version", - version="magic-wormhole "+ __version__) -s = parser.add_subparsers(title="subcommands", dest="subcommand") - - -# CLI: run-server -sp_start = s.add_parser("start", description="Start a relay server", - usage="wormhole server start [opts] [TWISTD-ARGS..]") -sp_start.add_argument("--rendezvous", default="tcp:4000", metavar="tcp:PORT", - help="endpoint specification for the rendezvous port") -sp_start.add_argument("--transit", default="tcp:4001", metavar="tcp:PORT", - help="endpoint specification for the transit-relay port") -sp_start.add_argument("--advertise-version", metavar="VERSION", - help="version to recommend to clients") -sp_start.add_argument("--signal-error", metavar="ERROR", - help="force all clients to fail with a message") -sp_start.add_argument("--blur-usage", default=None, type=int, - metavar="SECONDS", - help="round logged access times to improve privacy") -sp_start.add_argument("-n", "--no-daemon", action="store_true") -#sp_start.add_argument("twistd_args", nargs="*", default=None, -# metavar="[TWISTD-ARGS..]", -# help=dedent("""\ -# Additional arguments to pass to twistd"""), -# ) -sp_start.set_defaults(func="server/start") - -sp_stop = s.add_parser("stop", description="Stop the relay server", - usage="wormhole server stop") -sp_stop.set_defaults(func="server/stop") - -sp_restart = s.add_parser("restart", description="Restart the relay server", - usage="wormhole server restart") -sp_restart.add_argument("--rendezvous", default="tcp:4000", metavar="tcp:PORT", - help="endpoint specification for the rendezvous port") -sp_restart.add_argument("--transit", default="tcp:4001", metavar="tcp:PORT", - help="endpoint specification for the transit-relay port") -sp_restart.add_argument("--advertise-version", metavar="VERSION", - help="version to recommend to clients") -sp_restart.add_argument("--signal-error", metavar="ERROR", - help="force all clients to fail with a message") -sp_restart.add_argument("--blur-usage", default=None, type=int, - metavar="SECONDS", - help="round logged access times to improve privacy") -sp_restart.add_argument("-n", "--no-daemon", action="store_true") -sp_restart.set_defaults(func="server/restart") - -sp_show_usage = s.add_parser("show-usage", description="Display usage data", - usage="wormhole server show-usage") -sp_show_usage.add_argument("-n", default=100, type=int, - help="show last N entries") -sp_show_usage.set_defaults(func="usage/usage") - -sp_tail_usage = s.add_parser("tail-usage", description="Follow latest usage", - usage="wormhole server tail-usage") -sp_tail_usage.set_defaults(func="usage/tail") - -sp_count_channels = s.add_parser("count-channels", - description="Count active channels") -sp_count_channels.add_argument("--json", action="store_true") -sp_count_channels.set_defaults(func="usage/count-channels") - -sp_count_events = s.add_parser("count-events", description="Count events") -sp_count_events.add_argument("--json", action="store_true") -sp_count_events.set_defaults(func="usage/count-events") diff --git a/src/wormhole/server/cmd_server.py b/src/wormhole/server/cmd_server.py index ed2afe5..6c3c6cc 100644 --- a/src/wormhole/server/cmd_server.py +++ b/src/wormhole/server/cmd_server.py @@ -1,5 +1,6 @@ from __future__ import print_function, unicode_literals import os, time +import click from twisted.python import usage from twisted.scripts import twistd @@ -38,16 +39,17 @@ def kill_server(): try: f = open("twistd.pid", "r") except EnvironmentError: - print("Unable to find twistd.pid . Is this really a server directory?") - return 1 + raise click.UsageError( + "Unable to find 'twistd.pid' -- is this really a server directory?" + ) pid = int(f.read().strip()) f.close() os.kill(pid, 15) print("server process %d sent SIGTERM" % pid) - return 0 + return def stop_server(args): - return kill_server() + kill_server() def restart_server(args): kill_server() diff --git a/src/wormhole/server/cmd_usage.py b/src/wormhole/server/cmd_usage.py index fd294e7..c314313 100644 --- a/src/wormhole/server/cmd_usage.py +++ b/src/wormhole/server/cmd_usage.py @@ -1,8 +1,8 @@ from __future__ import print_function, unicode_literals import os, time, json from collections import defaultdict +import click from .database import get_db -from ..errors import UsageError def abbrev(t): if t is None: @@ -57,7 +57,9 @@ def show_usage(args): print("closed for renovation") return 0 if not os.path.exists("relay.sqlite"): - raise UsageError("cannot find relay.sqlite, please run from the server directory") + raise click.UsageError( + "cannot find relay.sqlite, please run from the server directory" + ) oldest = None newest = None rendezvous_counters = defaultdict(int) @@ -116,7 +118,9 @@ def show_usage(args): def tail_usage(args): if not os.path.exists("relay.sqlite"): - raise UsageError("cannot find relay.sqlite, please run from the server directory") + raise click.UsageError( + "cannot find relay.sqlite, please run from the server directory" + ) db = get_db("relay.sqlite") # we don't seem to have unique row IDs, so this is an inaccurate and # inefficient hack @@ -141,7 +145,9 @@ def tail_usage(args): def count_channels(args): if not os.path.exists("relay.sqlite"): - raise UsageError("cannot find relay.sqlite, please run from the server directory") + raise click.UsageError( + "cannot find relay.sqlite, please run from the server directory" + ) db = get_db("relay.sqlite") c_list = [] c_dict = {} @@ -188,7 +194,9 @@ def count_channels(args): def count_events(args): if not os.path.exists("relay.sqlite"): - raise UsageError("cannot find relay.sqlite, please run from the server directory") + raise click.UsageError( + "cannot find relay.sqlite, please run from the server directory" + ) db = get_db("relay.sqlite") c_list = [] c_dict = {} diff --git a/src/wormhole/test/common.py b/src/wormhole/test/common.py index 07015db..bea06da 100644 --- a/src/wormhole/test/common.py +++ b/src/wormhole/test/common.py @@ -1,6 +1,6 @@ # no unicode_literals untill twisted update from twisted.application import service -from twisted.internet import reactor, defer +from twisted.internet import defer, task from twisted.python import log from ..transit import allocate_tcp_port from ..server.server import RelayServer @@ -36,8 +36,18 @@ class ServerBase: # relay's .stopService() drops all connections, which ought to # encourage those threads to terminate soon. If they don't, print a # warning to ease debugging. + + # XXX FIXME there's something in _noclobber test that's not + # waiting for a close, I think -- was pretty relieably getting + # unclean-reactor, but adding a slight pause here stops it... + from twisted.internet import reactor + tp = reactor.getThreadPool() if not tp.working: + d = defer.succeed(None) + d.addCallback(lambda _: self.sp.stopService()) + d.addCallback(lambda _: task.deferLater(reactor, 0.1, lambda: None)) + return d return self.sp.stopService() # disconnect all callers d = defer.maybeDeferred(self.sp.stopService) diff --git a/src/wormhole/test/test_scripts.py b/src/wormhole/test/test_scripts.py index 40dbf8d..f35f916 100644 --- a/src/wormhole/test/test_scripts.py +++ b/src/wormhole/test/test_scripts.py @@ -7,33 +7,32 @@ from twisted.internet.utils import getProcessOutputAndValue from twisted.internet.defer import gatherResults, inlineCallbacks from .. import __version__ from .common import ServerBase -from ..cli import runner, cmd_send, cmd_receive +from ..cli import cmd_send, cmd_receive +from ..cli.cli import Config from ..errors import TransferError, WrongPasswordError, WelcomeError -from ..timing import DebugTiming + def build_offer(args): s = cmd_send.Sender(args, None) return s._build_offer() + class OfferData(unittest.TestCase): def setUp(self): self._things_to_delete = [] + self.cfg = cfg = Config() + cfg.stdout = io.StringIO() + cfg.stderr = io.StringIO() def tearDown(self): for fn in self._things_to_delete: if os.path.exists(fn): os.unlink(fn) + del self.cfg def test_text(self): - message = "blah blah blah ponies" - - send_args = [ "send", "--text", message ] - args = runner.parser.parse_args(send_args) - args.cwd = os.getcwd() - args.stdout = io.StringIO() - args.stderr = io.StringIO() - - d, fd_to_send = build_offer(args) + self.cfg.text = message = "blah blah blah ponies" + d, fd_to_send = build_offer(self.cfg) self.assertIn("message", d) self.assertNotIn("file", d) @@ -42,7 +41,7 @@ class OfferData(unittest.TestCase): self.assertEqual(fd_to_send, None) def test_file(self): - filename = "my file" + self.cfg.what = filename = "my file" message = b"yay ponies\n" send_dir = self.mktemp() os.mkdir(send_dir) @@ -50,13 +49,8 @@ class OfferData(unittest.TestCase): with open(abs_filename, "wb") as f: f.write(message) - send_args = [ "send", filename ] - args = runner.parser.parse_args(send_args) - args.cwd = send_dir - args.stdout = io.StringIO() - args.stderr = io.StringIO() - - d, fd_to_send = build_offer(args) + self.cfg.cwd = send_dir + d, fd_to_send = build_offer(self.cfg) self.assertNotIn("message", d) self.assertIn("file", d) @@ -67,17 +61,12 @@ class OfferData(unittest.TestCase): self.assertEqual(fd_to_send.read(), message) def test_missing_file(self): - filename = "missing" + self.cfg.what = filename = "missing" send_dir = self.mktemp() os.mkdir(send_dir) + self.cfg.cwd = send_dir - send_args = [ "send", filename ] - args = runner.parser.parse_args(send_args) - args.cwd = send_dir - args.stdout = io.StringIO() - args.stderr = io.StringIO() - - e = self.assertRaises(TransferError, build_offer, args) + e = self.assertRaises(TransferError, build_offer, self.cfg) self.assertEqual(str(e), "Cannot send: no file/directory named '%s'" % filename) @@ -94,13 +83,10 @@ class OfferData(unittest.TestCase): send_dir_arg = send_dir if addslash: send_dir_arg += os.sep - send_args = [ "send", send_dir_arg ] - args = runner.parser.parse_args(send_args) - args.cwd = parent_dir - args.stdout = io.StringIO() - args.stderr = io.StringIO() + self.cfg.what = send_dir_arg + self.cfg.cwd = parent_dir - d, fd_to_send = build_offer(args) + d, fd_to_send = build_offer(self.cfg) self.assertNotIn("message", d) self.assertNotIn("file", d) @@ -130,10 +116,11 @@ class OfferData(unittest.TestCase): return self._do_test_directory(addslash=True) def test_unknown(self): - filename = "unknown" + self.cfg.what = filename = "unknown" send_dir = self.mktemp() os.mkdir(send_dir) abs_filename = os.path.abspath(os.path.join(send_dir, filename)) + self.cfg.cwd = send_dir try: os.mkfifo(abs_filename) @@ -149,13 +136,7 @@ class OfferData(unittest.TestCase): self.assertFalse(os.path.isfile(abs_filename)) self.assertFalse(os.path.isdir(abs_filename)) - send_args = [ "send", filename ] - args = runner.parser.parse_args(send_args) - args.cwd = send_dir - args.stdout = io.StringIO() - args.stderr = io.StringIO() - - e = self.assertRaises(TypeError, build_offer, args) + e = self.assertRaises(TypeError, build_offer, self.cfg) self.assertEqual(str(e), "'%s' is neither file nor directory" % filename) @@ -203,27 +184,24 @@ class ScriptVersion(ServerBase, ScriptsBase, unittest.TestCase): # we need Twisted to run the server, but we run the sender and receiver # with deferToThread() + @inlineCallbacks def test_version(self): # "wormhole" must be on the path, so e.g. "pip install -e ." in a # virtualenv. This guards against an environment where the tests # below might run the wrong executable. + self.maxDiff = None wormhole = self.find_executable() - d = getProcessOutputAndValue(wormhole, ["--version"]) - def _check(res): - out, err, rc = res - # argparse on py2 and py3.3 sends --version to stderr - # argparse on py3.4/py3.5 sends --version to stdout - # aargh - err = err.decode("utf-8") - if "DistributionNotFound" in err: - log.msg("stderr was %s" % err) - last = err.strip().split("\n")[-1] - self.fail("wormhole not runnable: %s" % last) - ver = out.decode("utf-8") or err - self.failUnlessEqual(ver, "magic-wormhole "+__version__+os.linesep) - self.failUnlessEqual(rc, 0) - d.addCallback(_check) - return d + # we must pass on the environment so that "something" doesn't + # get sad about UTF8 vs. ascii encodings + out, err, rc = yield getProcessOutputAndValue(wormhole, ["--version"], env=os.environ) + err = err.decode("utf-8") + if "DistributionNotFound" in err: + log.msg("stderr was %s" % err) + last = err.strip().split("\n")[-1] + self.fail("wormhole not runnable: %s" % last) + ver = out.decode("utf-8") or err + self.failUnlessEqual(ver.strip(), "magic-wormhole {}".format(__version__)) + self.failUnlessEqual(rc, 0) class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase): # we need Twisted to run the server, but we run the sender and receiver @@ -238,20 +216,18 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase): def _do_test(self, as_subprocess=False, mode="text", addslash=False, override_filename=False): assert mode in ("text", "file", "directory") - common_args = ["--hide-progress", - "--relay-url", self.relayurl, - "--transit-helper", ""] - code = "1-abc" - message = "test message" + send_cfg = Config() + recv_cfg = Config() + message = "blah blah blah ponies" - send_args = common_args + [ - "send", - "--code", code, - ] - - receive_args = common_args + [ - "receive", - ] + for cfg in [send_cfg, recv_cfg]: + cfg.hide_progress = True + cfg.relay_url = self.relayurl + cfg.transit_helper = "" + cfg.listen = True + cfg.code = "1-abc" + cfg.stdout = io.StringIO() + cfg.stderr = io.StringIO() send_dir = self.mktemp() os.mkdir(send_dir) @@ -259,19 +235,18 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase): os.mkdir(receive_dir) if mode == "text": - send_args.extend(["--text", message]) + send_cfg.text = message elif mode == "file": send_filename = "testfile" with open(os.path.join(send_dir, send_filename), "w") as f: f.write(message) - send_args.append(send_filename) + send_cfg.what = send_filename receive_filename = send_filename - receive_args.append("--accept-file") + recv_cfg.accept_file = True if override_filename: - receive_args.extend(["-o", "outfile"]) - receive_filename = "outfile" + recv_cfg.output_file = receive_filename = "outfile" elif mode == "directory": # $send_dir/ @@ -299,22 +274,48 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase): send_dirname_arg = os.path.join("middle", send_dirname) if addslash: send_dirname_arg += os.sep - send_args.append(send_dirname_arg) + send_cfg.what = send_dirname_arg receive_dirname = send_dirname - receive_args.append("--accept-file") + recv_cfg.accept_file = True if override_filename: - receive_args.extend(["-o", "outdir"]) - receive_dirname = "outdir" - - receive_args.append(code) + recv_cfg.output_file = receive_dirname = "outdir" if as_subprocess: wormhole_bin = self.find_executable() - send_d = getProcessOutputAndValue(wormhole_bin, send_args, - path=send_dir) - receive_d = getProcessOutputAndValue(wormhole_bin, receive_args, - path=receive_dir) + if send_cfg.text: + content_args = ['--text', send_cfg.text] + elif send_cfg.what: + content_args = [send_cfg.what] + + send_args = [ + '--hide-progress', + '--relay-url', self.relayurl, + '--transit-helper', '', + 'send', + '--code', send_cfg.code, + ] + content_args + + send_d = getProcessOutputAndValue( + wormhole_bin, send_args, + path=send_dir, + ) + recv_args = [ + '--hide-progress', + '--relay-url', self.relayurl, + '--transit-helper', '', + 'receive', + '--accept-file', + recv_cfg.code, + ] + if override_filename: + recv_args.extend(['-o', receive_filename]) + + receive_d = getProcessOutputAndValue( + wormhole_bin, recv_args, + path=receive_dir, + ) + (send_res, receive_res) = yield gatherResults([send_d, receive_d], True) send_stdout = send_res[0].decode("utf-8") @@ -327,27 +328,20 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase): self.assertEqual((send_rc, receive_rc), (0, 0), (send_res, receive_res)) else: - sargs = runner.parser.parse_args(send_args) - sargs.cwd = send_dir - sargs.stdout = io.StringIO() - sargs.stderr = io.StringIO() - sargs.timing = DebugTiming() - rargs = runner.parser.parse_args(receive_args) - rargs.cwd = receive_dir - rargs.stdout = io.StringIO() - rargs.stderr = io.StringIO() - rargs.timing = DebugTiming() - send_d = cmd_send.send(sargs) - receive_d = cmd_receive.receive(rargs) + send_cfg.cwd = send_dir + send_d = cmd_send.send(send_cfg) + + recv_cfg.cwd = receive_dir + receive_d = cmd_receive.receive(recv_cfg) # The sender might fail, leaving the receiver hanging, or vice # versa. Make sure we don't wait on one side exclusively yield gatherResults([send_d, receive_d], True) - send_stdout = sargs.stdout.getvalue() - send_stderr = sargs.stderr.getvalue() - receive_stdout = rargs.stdout.getvalue() - receive_stderr = rargs.stderr.getvalue() + send_stdout = send_cfg.stdout.getvalue() + send_stderr = send_cfg.stderr.getvalue() + receive_stdout = recv_cfg.stdout.getvalue() + receive_stderr = recv_cfg.stderr.getvalue() # all output here comes from a StringIO, which uses \n for # newlines, even if we're on windows @@ -367,7 +361,7 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase): "wormhole receive{NL}" "Wormhole code is: {code}{NL}{NL}" "text message sent{NL}").format(bytes=len(message), - code=code, + code=send_cfg.code, NL=NL) self.failUnlessEqual(send_stdout, expected) elif mode == "file": @@ -377,7 +371,7 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase): self.failUnlessIn("On the other computer, please run: " "wormhole receive{NL}" "Wormhole code is: {code}{NL}{NL}" - .format(code=code, NL=NL), + .format(code=send_cfg.code, NL=NL), send_stdout) self.failUnlessIn("File sent.. waiting for confirmation{NL}" "Confirmation received. Transfer complete.{NL}" @@ -388,7 +382,7 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase): self.failUnlessIn("On the other computer, please run: " "wormhole receive{NL}" "Wormhole code is: {code}{NL}{NL}" - .format(code=code, NL=NL), send_stdout) + .format(code=send_cfg.code, NL=NL), send_stdout) self.failUnlessIn("File sent.. waiting for confirmation{NL}" "Confirmation received. Transfer complete.{NL}" .format(NL=NL), send_stdout) @@ -440,14 +434,21 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase): @inlineCallbacks def test_file_noclobber(self): - common_args = ["--hide-progress", "--no-listen", - "--relay-url", self.relayurl, - "--transit-helper", ""] - code = "1-abc" + send_cfg = Config() + recv_cfg = Config() + + for cfg in [send_cfg, recv_cfg]: + cfg.hide_progress = True + cfg.relay_url = self.relayurl + cfg.transit_helper = "" + cfg.listen = False + cfg.code = code = "1-abc" + cfg.stdout = io.StringIO() + cfg.stderr = io.StringIO() + message = "test message" - send_args = common_args + [ "send", "--code", code ] - receive_args = common_args + [ "receive", "--accept-file", code ] + recv_cfg.accept_file = True send_dir = self.mktemp() os.mkdir(send_dir) @@ -457,26 +458,19 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase): send_filename = "testfile" with open(os.path.join(send_dir, send_filename), "w") as f: f.write(message) - send_args.append(send_filename) - receive_filename = send_filename + send_cfg.what = receive_filename = send_filename + recv_cfg.what = receive_filename PRESERVE = "don't clobber me\n" clobberable = os.path.join(receive_dir, receive_filename) with open(clobberable, "w") as f: f.write(PRESERVE) - sargs = runner.parser.parse_args(send_args) - sargs.cwd = send_dir - sargs.stdout = io.StringIO() - sargs.stderr = io.StringIO() - sargs.timing = DebugTiming() - rargs = runner.parser.parse_args(receive_args) - rargs.cwd = receive_dir - rargs.stdout = io.StringIO() - rargs.stderr = io.StringIO() - rargs.timing = DebugTiming() - send_d = cmd_send.send(sargs) - receive_d = cmd_receive.receive(rargs) + send_cfg.cwd = send_dir + send_d = cmd_send.send(send_cfg) + + recv_cfg.cwd = receive_dir + receive_d = cmd_receive.receive(recv_cfg) # both sides will fail because of the pre-existing file @@ -486,10 +480,10 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase): f = yield self.assertFailure(receive_d, TransferError) self.assertEqual(str(f), "file already exists") - send_stdout = sargs.stdout.getvalue() - send_stderr = sargs.stderr.getvalue() - receive_stdout = rargs.stdout.getvalue() - receive_stderr = rargs.stderr.getvalue() + send_stdout = send_cfg.stdout.getvalue() + send_stderr = send_cfg.stderr.getvalue() + receive_stdout = recv_cfg.stdout.getvalue() + receive_stderr = recv_cfg.stderr.getvalue() # all output here comes from a StringIO, which uses \n for # newlines, even if we're on windows @@ -528,62 +522,54 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase): class NotWelcome(ServerBase, unittest.TestCase): def setUp(self): self._setup_relay(error="please upgrade XYZ") + self.cfg = cfg = Config() + cfg.hide_progress = True + cfg.listen = False + cfg.relay_url = self.relayurl + cfg.transit_helper = "" + cfg.stdout = io.StringIO() + cfg.stderr = io.StringIO() @inlineCallbacks def test_sender(self): - common_args = ["--hide-progress", "--no-listen", - "--relay-url", self.relayurl, - "--transit-helper", ""] - send_args = common_args + [ "send", "--text", "hi", - "--code", "1-abc" ] - sargs = runner.parser.parse_args(send_args) - sargs.cwd = self.mktemp() - sargs.stdout = io.StringIO() - sargs.stderr = io.StringIO() - sargs.timing = DebugTiming() + self.cfg.text = "hi" + self.cfg.code = "1-abc" - send_d = cmd_send.send(sargs) + send_d = cmd_send.send(self.cfg) f = yield self.assertFailure(send_d, WelcomeError) self.assertEqual(str(f), "please upgrade XYZ") @inlineCallbacks def test_receiver(self): - common_args = ["--hide-progress", "--no-listen", - "--relay-url", self.relayurl, - "--transit-helper", ""] - receive_args = common_args + [ "receive", "1-abc" ] - rargs = runner.parser.parse_args(receive_args) - rargs.cwd = self.mktemp() - rargs.stdout = io.StringIO() - rargs.stderr = io.StringIO() - rargs.timing = DebugTiming() + self.cfg.code = "1-abc" - receive_d = cmd_receive.receive(rargs) + receive_d = cmd_receive.receive(self.cfg) f = yield self.assertFailure(receive_d, WelcomeError) self.assertEqual(str(f), "please upgrade XYZ") + class Cleanup(ServerBase, unittest.TestCase): + + def setUp(self): + d = super(Cleanup, self).setUp() + self.cfg = cfg = Config() + # common options for all tests in this suite + cfg.hide_progress = True + cfg.relay_url = self.relayurl + cfg.transit_helper = "" + cfg.stdout = io.StringIO() + cfg.stderr = io.StringIO() + return d + @inlineCallbacks - def test_text(self): + @mock.patch('sys.stdout') + def test_text(self, stdout): # the rendezvous channel should be deleted after success - code = "1-abc" - common_args = ["--hide-progress", - "--relay-url", self.relayurl, - "--transit-helper", ""] - sargs = runner.parser.parse_args(common_args + - ["send", - "--text", "secret message", - "--code", code]) - sargs.stdout = io.StringIO() - sargs.stderr = io.StringIO() - sargs.timing = DebugTiming() - rargs = runner.parser.parse_args(common_args + - ["receive", code]) - rargs.stdout = io.StringIO() - rargs.stderr = io.StringIO() - rargs.timing = DebugTiming() - send_d = cmd_send.send(sargs) - receive_d = cmd_receive.receive(rargs) + self.cfg.text = "hello" + self.cfg.code = "1-abc" + + send_d = cmd_send.send(self.cfg) + receive_d = cmd_receive.receive(self.cfg) yield send_d yield receive_d @@ -595,23 +581,12 @@ class Cleanup(ServerBase, unittest.TestCase): def test_text_wrong_password(self): # if the password was wrong, the rendezvous channel should still be # deleted - common_args = ["--hide-progress", - "--relay-url", self.relayurl, - "--transit-helper", ""] - sargs = runner.parser.parse_args(common_args + - ["send", - "--text", "secret message", - "--code", "1-abc"]) - sargs.stdout = io.StringIO() - sargs.stderr = io.StringIO() - sargs.timing = DebugTiming() - rargs = runner.parser.parse_args(common_args + - ["receive", "1-WRONG"]) - rargs.stdout = io.StringIO() - rargs.stderr = io.StringIO() - rargs.timing = DebugTiming() - send_d = cmd_send.send(sargs) - receive_d = cmd_receive.receive(rargs) + self.cfg.text = "secret message" + self.cfg.code = "1-abc" + send_d = cmd_send.send(self.cfg) + + self.cfg.code = "1-WRONG" + receive_d = cmd_receive.receive(self.cfg) # both sides should be capable of detecting the mismatch yield self.assertFailure(send_d, WrongPasswordError) diff --git a/src/wormhole/test/test_transit.py b/src/wormhole/test/test_transit.py index 6f3f0cd..bdc5806 100644 --- a/src/wormhole/test/test_transit.py +++ b/src/wormhole/test/test_transit.py @@ -6,8 +6,8 @@ from twisted.internet import defer, task, endpoints, protocol, address, error from twisted.internet.defer import gatherResults, inlineCallbacks from twisted.python import log, failure from twisted.test import proto_helpers +from ..errors import InternalError from .. import transit -from ..errors import UsageError from nacl.secret import SecretBox from nacl.exceptions import CryptoError @@ -147,7 +147,7 @@ class Basic(unittest.TestCase): "hostname": "host", "port": 1234}], }]) - self.assertRaises(UsageError, transit.Common, 123) + self.assertRaises(InternalError, transit.Common, 123) @inlineCallbacks def test_no_relay_hints(self): diff --git a/src/wormhole/test/test_wormhole.py b/src/wormhole/test/test_wormhole.py index 2b795f2..67ad14f 100644 --- a/src/wormhole/test/test_wormhole.py +++ b/src/wormhole/test/test_wormhole.py @@ -7,7 +7,7 @@ from twisted.internet import reactor from twisted.internet.defer import Deferred, gatherResults, inlineCallbacks from .common import ServerBase from .. import wormhole -from ..errors import (WrongPasswordError, WelcomeError, UsageError, +from ..errors import (WrongPasswordError, WelcomeError, InternalError, KeyFormatError) from spake2 import SPAKE2_Symmetric from ..timing import DebugTiming @@ -876,20 +876,20 @@ class Errors(ServerBase, unittest.TestCase): def test_codes_1(self): w = wormhole.wormhole(APPID, self.relayurl, reactor) # definitely too early - self.assertRaises(UsageError, w.derive_key, "purpose", 12) + self.assertRaises(InternalError, w.derive_key, "purpose", 12) w.set_code("123-purple-elephant") # code can only be set once - self.assertRaises(UsageError, w.set_code, "123-nope") - yield self.assertFailure(w.get_code(), UsageError) - yield self.assertFailure(w.input_code(), UsageError) + self.assertRaises(InternalError, w.set_code, "123-nope") + yield self.assertFailure(w.get_code(), InternalError) + yield self.assertFailure(w.input_code(), InternalError) yield w.close() @inlineCallbacks def test_codes_2(self): w = wormhole.wormhole(APPID, self.relayurl, reactor) yield w.get_code() - self.assertRaises(UsageError, w.set_code, "123-nope") - yield self.assertFailure(w.get_code(), UsageError) - yield self.assertFailure(w.input_code(), UsageError) + self.assertRaises(InternalError, w.set_code, "123-nope") + yield self.assertFailure(w.get_code(), InternalError) + yield self.assertFailure(w.input_code(), InternalError) yield w.close() diff --git a/src/wormhole/transit.py b/src/wormhole/transit.py index 050cd00..308e92f 100644 --- a/src/wormhole/transit.py +++ b/src/wormhole/transit.py @@ -13,7 +13,7 @@ from twisted.internet.defer import inlineCallbacks, returnValue from twisted.protocols import policies from nacl.secret import SecretBox from hkdf import Hkdf -from .errors import UsageError +from .errors import InternalError from .timing import DebugTiming from . import ipaddrs @@ -276,7 +276,7 @@ class Connection(protocol.Protocol, policies.TimeoutMixin): return self._description def send_record(self, record): - if not isinstance(record, type(b"")): raise UsageError + if not isinstance(record, type(b"")): raise InternalError assert SecretBox.NONCE_SIZE == 24 assert self.send_nonce < 2**(8*24) assert len(record) < 2**(8*4) @@ -577,7 +577,7 @@ class Common: reactor=reactor, timing=None): if transit_relay: if not isinstance(transit_relay, type(u"")): - raise UsageError + raise InternalError relay = RelayV1Hint(hints=[parse_hint_argv(transit_relay)]) self._transit_relays = [relay] else: @@ -813,6 +813,9 @@ class Common: is_relay=True) contenders.append(d) + if not contenders: + raise TransitError("No contenders for connection") + winner = there_can_be_only_one(contenders) return self._not_forever(2*TIMEOUT, winner) diff --git a/src/wormhole/wormhole.py b/src/wormhole/wormhole.py index 3a716b1..11fbd12 100644 --- a/src/wormhole/wormhole.py +++ b/src/wormhole/wormhole.py @@ -14,7 +14,7 @@ from hashlib import sha256 from . import __version__ from . import codes #from .errors import ServerError, Timeout -from .errors import (WrongPasswordError, UsageError, WelcomeError, +from .errors import (WrongPasswordError, InternalError, WelcomeError, WormholeClosedError, KeyFormatError) from .timing import DebugTiming from .util import (to_bytes, bytes_to_hexstr, hexstr_to_bytes, @@ -425,8 +425,8 @@ class _Wormhole: # entry point 1: generate a new code @inlineCallbacks def _API_get_code(self, code_length): - if self._code is not None: raise UsageError - if self._started_get_code: raise UsageError + if self._code is not None: raise InternalError + if self._started_get_code: raise InternalError self._started_get_code = True with self._timing.add("API get_code"): yield self._when_connected() @@ -443,8 +443,8 @@ class _Wormhole: # entry point 2: interactively type in a code, with completion @inlineCallbacks def _API_input_code(self, prompt, code_length): - if self._code is not None: raise UsageError - if self._started_input_code: raise UsageError + if self._code is not None: raise InternalError + if self._started_input_code: raise InternalError self._started_input_code = True with self._timing.add("API input_code"): yield self._when_connected() @@ -462,8 +462,10 @@ class _Wormhole: # entry point 3: paste in a fully-formed code def _API_set_code(self, code): self._timing.add("API set_code") - if not isinstance(code, type("")): raise TypeError(type(code)) - if self._code is not None: raise UsageError + if not isinstance(code, type(u"")): + raise TypeError("Unexpected code type '{}'".format(type(code))) + if self._code is not None: + raise InternalError self._event_learned_code(code) # TODO: entry point 4: restore pre-contact saved state (we haven't heard @@ -526,7 +528,7 @@ class _Wormhole: self._event_learned_mailbox() def _event_learned_mailbox(self): - if not self._mailbox_id: raise UsageError + if not self._mailbox_id: raise InternalError assert self._mailbox_state == CLOSED, self._mailbox_state if self._closing: return @@ -578,7 +580,7 @@ class _Wormhole: def _API_verify(self): if self._error: return defer.fail(self._error) - if self._get_verifier_called: raise UsageError + if self._get_verifier_called: raise InternalError self._get_verifier_called = True if self._verify_result: return defer.succeed(self._verify_result) # bytes or Failure @@ -683,7 +685,7 @@ class _Wormhole: return box.encrypt(data, nonce) def _msg_send(self, phase, body): - if phase in self._sent_phases: raise UsageError + if phase in self._sent_phases: raise InternalError assert self._mailbox_state == OPEN, self._mailbox_state self._sent_phases.add(phase) # TODO: retry on failure, with exponential backoff. We're guarding @@ -700,14 +702,14 @@ class _Wormhole: def _API_derive_key(self, purpose, length): if self._error: raise self._error if self._key is None: - raise UsageError # call derive_key after get_verifier() or get() + raise InternalError # call derive_key after get_verifier() or get() if not isinstance(purpose, type("")): raise TypeError(type(purpose)) return self._derive_key(to_bytes(purpose), length) def _derive_key(self, purpose, length=SecretBox.KEY_SIZE): if not isinstance(purpose, type(b"")): raise TypeError(type(purpose)) if self._key is None: - raise UsageError # call derive_key after get_verifier() or get() + raise InternalError # call derive_key after get_verifier() or get() return HKDF(self._key, length, CTXinfo=purpose) def _response_handle_message(self, msg): @@ -782,7 +784,7 @@ class _Wormhole: @inlineCallbacks def _API_close(self, res, mood="happy"): if self.DEBUG: print("close") - if self._close_called: raise UsageError + if self._close_called: raise InternalError self._close_called = True self._maybe_close(WormholeClosedError(), mood) if self.DEBUG: print("waiting for disconnect") diff --git a/tox.ini b/tox.ini index 36a1650..cea30da 100644 --- a/tox.ini +++ b/tox.ini @@ -4,7 +4,7 @@ # and then run "tox" from this directory. [tox] -envlist = {py27,py33,py34,py35} +envlist = {py27,py33,py34,py35,pypy} skip_missing_interpreters = True # On windows we need "pypiwin32" installed. It's supposedly possible to make