Refactor to use Click

This commit is contained in:
meejah 2016-06-03 15:17:47 -07:00
parent d89fbd69dd
commit e16b53817e
13 changed files with 594 additions and 540 deletions

View File

@ -19,15 +19,19 @@ setup(name="magic-wormhole",
"wormhole.test", "wormhole.test",
], ],
package_data={"wormhole.server": ["db-schemas/*.sql"]}, package_data={"wormhole.server": ["db-schemas/*.sql"]},
entry_points={"console_scripts": entry_points={
["wormhole = wormhole.cli.runner:entry", "console_scripts":
"wormhole-server = wormhole.server.runner:entry", [
]}, "wormhole = wormhole.cli.cli:wormhole",
"wormhole-server = wormhole.server.cli:server",
]
},
install_requires=["spake2==0.7", "pynacl", "argparse", install_requires=["spake2==0.7", "pynacl", "argparse",
"six", "six",
"twisted", "twisted",
"autobahn[twisted] >= 0.14.1", "autobahn[twisted] >= 0.14.1",
"hkdf", "tqdm", "hkdf", "tqdm",
"click",
], ],
extras_require={':sys_platform=="win32"': ["pypiwin32"], extras_require={':sys_platform=="win32"': ["pypiwin32"],
"tor": ["txtorcon", "ipaddress"]}, "tor": ["txtorcon", "ipaddress"]},

231
src/wormhole/cli/cli.py Normal file
View File

@ -0,0 +1,231 @@
from __future__ import print_function
import os
import time
start = time.time()
import traceback
from textwrap import fill, dedent
from sys import stdout, stderr
from . import public_relay
from .. import __version__
from ..timing import DebugTiming
from ..errors import WrongPasswordError, WelcomeError, KeyFormatError
from twisted.internet.defer import inlineCallbacks, maybeDeferred
from twisted.internet.task import react
import click
top_import_finish = time.time()
class Config(object):
"""
Union of config options that we pass down to (sub) commands.
"""
def __init__(self):
# common options
self.timing = DebugTiming()
self.tor = None
self.listen = None
self.relay_url = u""
self.transit_helper = u""
self.cwd = os.getcwd()
# send/receive commands
self.code = None
self.code_length = 2
self.verify = False
self.hide_progress = False
self.dump_timing = False
self.stdout = stdout
self.stderr = stderr
self.zeromode = False
self.accept_file = None
self.output_file = None
# send only
self.text = None
self.what = None
ALIASES = {
"tx": "send",
"rx": "receive",
}
class AliasedGroup(click.Group):
def get_command(self, ctx, cmd_name):
cmd_name = ALIASES.get(cmd_name, cmd_name)
return click.Group.get_command(self, ctx, cmd_name)
# top-level command ("wormhole ...")
@click.group(cls=AliasedGroup)
@click.option(
"--relay-url", default=public_relay.RENDEZVOUS_RELAY,
metavar="URL",
help="rendezvous relay to use",
)
@click.option(
"--transit-helper", default=public_relay.TRANSIT_RELAY,
metavar="tcp:HOST:PORT",
help="transit relay to use",
)
@click.option(
"-c", "--code-length", default=2,
metavar="NUMWORDS",
help="length of code (in bytes/words)",
)
@click.option(
"-v", "--verify", is_flag=True, default=False,
help="display (and wait for acceptance of) verification string",
)
@click.option(
"--hide-progress", is_flag=True, default=False,
help="supress progress-bar display",
)
@click.option(
"--dump-timing", type=type(u""), # TODO: hide from --help output
default=None,
metavar="FILE.json",
help="(debug) write timing data to file",
)
@click.option(
"--no-listen", is_flag=True, default=False,
help="(debug) don't open a listening socket for Transit",
)
@click.option(
"--tor", is_flag=True, default=True,
help="use Tor when connecting",
)
@click.version_option(
message="magic-wormhole %(version)s",
version=__version__,
)
@click.pass_context
def wormhole(ctx, tor, no_listen, dump_timing, hide_progress,
verify, code_length, transit_helper, relay_url):
"""
Create a Magic Wormhole and communicate through it.
Wormholes are created by speaking the same magic CODE in two
different places at the same time. Wormholes are secure against
anyone who doesn't use the same code.
"""
ctx.obj = cfg = Config()
ctx.tor = tor
if no_listen:
cfg.listen = False
cfg.relay_url = relay_url
cfg.transit_helper = transit_helper
cfg.code_length = code_length
cfg.verify = verify
cfg.hide_progress = hide_progress
cfg.dump_timing = dump_timing
@inlineCallbacks
def _dispatch_command(reactor, cfg, command):
"""
Internal helper. This calls the give command (a no-argument
callable) with the Config instance in cfg and interprets any
errors for the user.
"""
cfg.timing.add("command dispatch")
cfg.timing.add("import", when=start, which="top").finish(when=top_import_finish)
try:
yield maybeDeferred(command)
except WrongPasswordError as e:
msg = fill("ERROR: " + dedent(e.__doc__))
print(msg, file=stderr)
except WelcomeError as e:
msg = fill("ERROR: " + dedent(e.__doc__))
print(msg, file=stderr)
print(file=stderr)
print(str(e), file=stderr)
except KeyFormatError as e:
msg = fill("ERROR: " + dedent(e.__doc__))
print(msg, file=stderr)
except Exception as e:
traceback.print_exc()
print("ERROR:", e, file=stderr)
raise SystemExit(1)
cfg.timing.add("exit")
if cfg.dump_timing:
cfg.timing.write(cfg.dump_timing, stderr)
# wormhole send (or "wormhole tx")
@wormhole.command()
@click.option(
"zeromode", "-0", default=False, is_flag=True,
)
@click.option(
"--code", metavar="CODE",
help="human-generated code phrase",
)
@click.option(
"--text", default=None, metavar="MESSAGE",
help="text message to send, instead of a file. Use '-' to read from stdin.",
)
@click.argument("what", default=u'')
@click.pass_obj
def send(cfg, what, text, code, zeromode):
"""Send a text message, file, or directory"""
with cfg.timing.add("import", which="cmd_send"):
from . import cmd_send
cfg.what = what
cfg.text = text
cfg.zeromode = zeromode
cfg.code = code
react(_dispatch_command, (cfg, lambda: cmd_send.send(cfg)))
# wormhole receive (or "wormhole rx")
@wormhole.command()
@click.option(
"--only-text", "-t", is_flag=True,
help="refuse file transfers, only accept text transfers",
)
@click.option(
"--accept-file", is_flag=True,
help="accept file transfer without asking for confirmation",
)
@click.option(
"--output-file", "-o",
metavar="FILENAME|DIRNAME",
help=("The file or directory to create, overriding the name suggested"
" by the sender."),
)
@click.option(
"-0", "zeromode", is_flag=True,
help="enable no-code anything-goes mode",
)
@click.argument(
"code", nargs=-1, default=None,
# help=("The magic-wormhole code, from the sender. If omitted, the"
# " program will ask for it, using tab-completion."),
)
@click.pass_obj
def receive(cfg, code, zeromode, output_file, accept_file, only_text):
"""
Receive a text message, file, or directory (from 'wormhole send')
"""
with cfg.timing.add("import", which="cmd_receive"):
from . import cmd_receive
cfg.zeromode = zeromode
cfg.output_file = output_file
cfg.accept_file = accept_file
cfg.only_text = only_text
if len(code) == 1:
cfg.code = code[0]
elif len(code) > 1:
print(
"Pass either no code or just one code; you passed"
" {}: {}".format(len(code), ', '.join(code))
)
raise SystemExit(1)
else:
cfg.code = None
react(_dispatch_command, (cfg, lambda: cmd_receive.receive(cfg)))
return

View File

@ -1,179 +0,0 @@
import click
from textwrap import dedent
from . import public_relay
from .. import __version__
class Common:
def __init__(self, stuff):
self.stuff = stuff
ALIASES = {
"tx": "send",
"rx": "receive",
}
class AliasedGroup(click.Group):
def get_command(self, ctx, cmd_name):
cmd_name = ALIASES.get(cmd_name, cmd_name)
return click.Group.get_command(self, ctx, cmd_name)
@click.group()
#@click.command(cls=AliasedGroup)
@click.option("--relay-url", default=public_relay.RENDEZVOUS_RELAY,
metavar="URL",
help="rendezvous relay to use",
)
@click.option("--transit-helper", default=public_relay.TRANSIT_RELAY,
metavar="tcp:HOST:PORT",
help="transit relay to use",
)
@click.option("-c", "--code-length", default=2,
metavar="NUMWORDS",
help="length of code (in bytes/words)",
)
@click.option("-v", "--verify", is_flag=True, default=False,
help="display (and wait for acceptance of) verification string",
)
@click.option("--hide-progress", is_flag=True, default=False,
help="supress progress-bar display",
)
@click.option("--dump-timing", type=type(u""), # TODO: hide from --help output
default=None,
metavar="FILE.json",
help="(debug) write timing data to file")
@click.option("--no-listen", is_flag=True, default=False,
help="(debug) don't open a listening socket for Transit")
@click.option("--tor", is_flag=True, default=True,
help="use Tor when connecting")
@click.version_option(message="magic-wormhole %(version)s", version=__version__)
@click.pass_context
def cli(ctx, relay_url, transit_helper):
"""
Create a Magic Wormhole and communicate through it. Wormholes are created
by speaking the same magic CODE in two different places at the same time.
Wormholes are secure against anyone who doesn't use the same code."""
ctx.obj = Common(relay_url)
@cli.command()
@click.argument("what")
@click.pass_obj
def send(obj, what):
"""Send a text message, file, or directory"""
print obj
print what
@cli.command()
@click.argument("what")
@click.pass_obj
def receive(obj, what):
"""Receive anything sent by 'wormhole send'."""
print obj
print what
# for now, use "python -m wormhole.cli.cli_args --version", etc
if __name__ == "__main__":
cli()
import argparse
parser = argparse.ArgumentParser(
usage="wormhole SUBCOMMAND (subcommand-options)",
description=dedent("""
Create a Magic Wormhole and communicate through it. Wormholes are created
by speaking the same magic CODE in two different places at the same time.
Wormholes are secure against anyone who doesn't use the same code."""),
)
parser.add_argument("--version", action="version",
version="magic-wormhole "+ __version__)
g = parser.add_argument_group("wormhole configuration options")
parser.set_defaults(timing=None)
subparsers = parser.add_subparsers(title="subcommands",
dest="subcommand")
# CLI: run-server
s = subparsers.add_parser("server", description="Start/stop a relay server")
sp = s.add_subparsers(title="subcommands", dest="subcommand")
sp_start = sp.add_parser("start", description="Start a relay server",
usage="wormhole server start [opts] [TWISTD-ARGS..]")
sp_start.add_argument("--rendezvous", default="tcp:3000", metavar="tcp:PORT",
help="endpoint specification for the rendezvous port")
sp_start.add_argument("--transit", default="tcp:3001", metavar="tcp:PORT",
help="endpoint specification for the transit-relay port")
sp_start.add_argument("--advertise-version", metavar="VERSION",
help="version to recommend to clients")
sp_start.add_argument("--blur-usage", default=None, type=int,
metavar="SECONDS",
help="round logged access times to improve privacy")
sp_start.add_argument("-n", "--no-daemon", action="store_true")
#sp_start.add_argument("twistd_args", nargs="*", default=None,
# metavar="[TWISTD-ARGS..]",
# help=dedent("""\
# Additional arguments to pass to twistd"""),
# )
sp_start.set_defaults(func="server/start")
sp_stop = sp.add_parser("stop", description="Stop the relay server",
usage="wormhole server stop")
sp_stop.set_defaults(func="server/stop")
sp_restart = sp.add_parser("restart", description="Restart the relay server",
usage="wormhole server restart")
sp_restart.add_argument("--rendezvous", default="tcp:3000", metavar="tcp:PORT",
help="endpoint specification for the rendezvous port")
sp_restart.add_argument("--transit", default="tcp:3001", metavar="tcp:PORT",
help="endpoint specification for the transit-relay port")
sp_restart.add_argument("--advertise-version", metavar="VERSION",
help="version to recommend to clients")
sp_restart.add_argument("--blur-usage", default=None, type=int,
metavar="SECONDS",
help="round logged access times to improve privacy")
sp_restart.add_argument("-n", "--no-daemon", action="store_true")
sp_restart.set_defaults(func="server/restart")
sp_show_usage = sp.add_parser("show-usage", description="Display usage data",
usage="wormhole server show-usage")
sp_show_usage.add_argument("-n", default=100, type=int,
help="show last N entries")
sp_show_usage.set_defaults(func="usage/usage")
sp_tail_usage = sp.add_parser("tail-usage", description="Follow latest usage",
usage="wormhole server tail-usage")
sp_tail_usage.set_defaults(func="usage/tail")
# CLI: send
p = subparsers.add_parser("send",
description="Send text message, file, or directory",
usage="wormhole send [FILENAME|DIRNAME]")
p.add_argument("--text", metavar="MESSAGE",
help="text message to send, instead of a file. Use '-' to read from stdin.")
p.add_argument("--code", metavar="CODE", help="human-generated code phrase",
type=type(u""))
p.add_argument("-0", dest="zeromode", action="store_true",
help="enable no-code anything-goes mode")
p.add_argument("what", nargs="?", default=None, metavar="[FILENAME|DIRNAME]",
help="the file/directory to send")
p.set_defaults(func="send/send")
# CLI: receive
p = subparsers.add_parser("receive",
description="Receive a text message, file, or directory",
usage="wormhole receive [CODE]")
p.add_argument("-0", dest="zeromode", action="store_true",
help="enable no-code anything-goes mode")
p.add_argument("-t", "--only-text", dest="only_text", action="store_true",
help="refuse file transfers, only accept text transfers")
p.add_argument("--accept-file", dest="accept_file", action="store_true",
help="accept file transfer with asking for confirmation")
p.add_argument("-o", "--output-file", default=None, metavar="FILENAME|DIRNAME",
help=dedent("""\
The file or directory to create, overriding the name suggested
by the sender."""),
)
p.add_argument("code", nargs="?", default=None, metavar="[CODE]",
help=dedent("""\
The magic-wormhole code, from the sender. If omitted, the
program will ask for it, using tab-completion."""),
type=type(u""),
)
p.set_defaults(func="receive/receive")

View File

@ -141,7 +141,7 @@ class TwistedReceiver:
@inlineCallbacks @inlineCallbacks
def _build_transit(self, w, sender_transit): def _build_transit(self, w, sender_transit):
tr = TransitReceiver(self.args.transit_helper, tr = TransitReceiver(self.args.transit_helper,
no_listen=self.args.no_listen, no_listen=(not self.args.listen),
tor_manager=self._tor_manager, tor_manager=self._tor_manager,
reactor=self._reactor, reactor=self._reactor,
timing=self.args.timing) timing=self.args.timing)

View File

@ -48,9 +48,10 @@ class Sender:
w = wormhole(APPID, self._args.relay_url, w = wormhole(APPID, self._args.relay_url,
self._reactor, self._tor_manager, self._reactor, self._tor_manager,
timing=self._timing) timing=self._timing)
d = self._go(w) try:
d.addBoth(w.close) yield self._go(w)
yield d finally:
w.close()
def _send_data(self, data, w): def _send_data(self, data, w):
data_bytes = dict_to_bytes(data) data_bytes = dict_to_bytes(data)
@ -100,7 +101,7 @@ class Sender:
if self._fd_to_send: if self._fd_to_send:
ts = TransitSender(args.transit_helper, ts = TransitSender(args.transit_helper,
no_listen=args.no_listen, no_listen=(not args.listen),
tor_manager=self._tor_manager, tor_manager=self._tor_manager,
reactor=self._reactor, reactor=self._reactor,
timing=self._timing) timing=self._timing)

View File

@ -1,81 +0,0 @@
from __future__ import print_function
import time
start = time.time()
import os, sys, textwrap
from twisted.internet.defer import maybeDeferred
from twisted.internet.task import react
from ..errors import (TransferError, WrongPasswordError, WelcomeError, Timeout,
KeyFormatError)
from ..timing import DebugTiming
from .cli_args import parser
top_import_finish = time.time()
def dispatch(args): # returns Deferred
if args.func == "send/send":
with args.timing.add("import", which="cmd_send"):
from . import cmd_send
return cmd_send.send(args)
if args.func == "receive/receive":
with args.timing.add("import", which="cmd_receive"):
from . import cmd_receive
return cmd_receive.receive(args)
raise ValueError("unknown args.func %s" % args.func)
def run(reactor, argv, cwd, stdout, stderr, executable=None):
"""This is invoked by entry() below, and can also be invoked directly by
tests.
"""
args = parser.parse_args(argv)
if not getattr(args, "func", None):
# So far this only works on py3. py2 exits with a really terse
# "error: too few arguments" during parse_args().
parser.print_help()
sys.exit(0)
args.cwd = cwd
args.stdout = stdout
args.stderr = stderr
args.timing = timing = DebugTiming()
timing.add("command dispatch")
timing.add("import", when=start, which="top").finish(when=top_import_finish)
# fires with None, or raises an error
d = maybeDeferred(dispatch, args)
def _maybe_dump_timing(res):
timing.add("exit")
if args.dump_timing:
timing.write(args.dump_timing, stderr)
return res
d.addBoth(_maybe_dump_timing)
def _explain_error(f):
# these errors don't print a traceback, just an explanation
f.trap(TransferError, WrongPasswordError, WelcomeError, Timeout,
KeyFormatError)
if f.check(WrongPasswordError):
msg = textwrap.fill("ERROR: " + textwrap.dedent(f.value.__doc__))
print(msg, file=stderr)
elif f.check(WelcomeError):
msg = textwrap.fill("ERROR: " + textwrap.dedent(f.value.__doc__))
print(msg, file=stderr)
print(file=stderr)
print(str(f.value), file=stderr)
elif f.check(KeyFormatError):
msg = textwrap.fill("ERROR: " + textwrap.dedent(f.value.__doc__))
print(msg, file=stderr)
else:
print("ERROR:", f.value, file=stderr)
raise SystemExit(1)
d.addErrback(_explain_error)
d.addCallback(lambda _: 0)
return d
def entry():
"""This is used by a setuptools entry_point. When invoked this way,
setuptools has already put the installed package on sys.path ."""
react(run, (sys.argv[1:], os.getcwd(), sys.stdout, sys.stderr,
sys.argv[0]))
if __name__ == "__main__":
args = parser.parse_args()
print(args)

162
src/wormhole/server/cli.py Normal file
View File

@ -0,0 +1,162 @@
from __future__ import print_function
import click
# can put this back in to get this command as "wormhole server"
# instead
#from ..cli.cli import wormhole
#@wormhole.group()
@click.group()
@click.pass_context
def server(ctx):
"""
Control a relay server (most users shouldn't need to worry
about this and can use the default server).
"""
# just leaving this pointing to wormhole.cli.cli.Config for now,
# but if we want to keep wormhole-server as a separate command
# should probably have our own Config without all the options the
# server commands don't use
from ..cli.cli import Config
ctx.obj = Config()
@server.command()
@click.option(
"--rendezvous", default="tcp:3000", metavar="tcp:PORT",
help="endpoint specification for the rendezvous port",
)
@click.option(
"--transit", default="tcp:3001", metavar="tcp:PORT",
help="endpoint specification for the transit-relay port",
)
@click.option(
"--advertise-version", metavar="VERSION",
help="version to recommend to clients",
)
@click.option(
"--blur-usage", default=None, type=int,
metavar="SECONDS",
help="round logged access times to improve privacy",
)
@click.option(
"--no-daemon", "-n", is_flag=True,
help="Run in the foreground",
)
@click.option(
"--signal-error", is_flag=True,
help="force all clients to fail with a message",
)
@click.pass_obj
def start(cfg, signal_error, no_daemon, blur_usage, advertise_version, transit, rendezvous):
"""
Start a relay server
"""
with cfg.timing.add("import", which="cmd_server_start"):
from wormhole.server.cmd_server import start_server
cfg.no_daemon = no_daemon
cfg.blur_usage = blur_usage
cfg.advertise_version = advertise_version
cfg.transit = str(transit)
cfg.rendezvous = str(rendezvous)
cfg.signal_error = signal_error
start_server(cfg)
# XXX it would be nice to reduce the duplication between 'restart' and
# 'start' options...
@server.command()
@click.option(
"--rendezvous", default="tcp:3000", metavar="tcp:PORT",
help="endpoint specification for the rendezvous port",
)
@click.option(
"--transit", default="tcp:3001", metavar="tcp:PORT",
help="endpoint specification for the transit-relay port",
)
@click.option(
"--advertise-version", metavar="VERSION",
help="version to recommend to clients",
)
@click.option(
"--blur-usage", default=None, type=int,
metavar="SECONDS",
help="round logged access times to improve privacy",
)
@click.option(
"--no-daemon", "-n", is_flag=True,
help="Run in the foreground",
)
@click.option(
"--signal-error", is_flag=True,
help="force all clients to fail with a message",
)
@click.pass_obj
def restart(cfg, signal_error, no_daemon, blur_usage, advertise_version, transit, rendezvous):
"""
Re-start a relay server
"""
with cfg.timing.add("import", which="cmd_server_restart"):
from wormhole.server.cmd_server import restart_server
cfg.no_daemon = no_daemon
cfg.blur_usage = blur_usage
cfg.advertise_version = advertise_version
cfg.transit = str(transit)
cfg.rendezvous = str(rendezvous)
cfg.signal_error = signal_error
restart_server(cfg)
@server.command()
@click.pass_obj
def stop(cfg):
"""
Stop a relay server
"""
with cfg.timing.add("import", which="cmd_server_stop"):
from wormhole.server.cmd_server import stop_server
stop_server(cfg)
@server.command(name="tail-usage")
@click.pass_obj
def tail_usage(cfg):
"""
Follow the latest usage
"""
with cfg.timing.add("import", which="cmd_tail_usage"):
from wormhole.server.cmd_usage import tail_usage
tail_usage(cfg)
@server.command(name='count-channels')
@click.option(
"--json", is_flag=True,
)
@click.pass_obj
def count_channels(cfg, json):
"""
Count active channels
"""
with cfg.timing.add("import", which="cmd_count_channels"):
from wormhole.server.cmd_usage import count_channels
cfg.json = json
count_channels(cfg)
@server.command(name='count-events')
@click.option(
"--json", is_flag=True,
)
@click.pass_obj
def count_events(cfg, json):
"""
Count events
"""
with cfg.timing.add("import", which="cmd_count_events"):
from wormhole.server.cmd_usage import count_events
cfg.json = json
count_events(cfg)

View File

@ -1,77 +0,0 @@
import argparse
from textwrap import dedent
from .. import __version__
parser = argparse.ArgumentParser(
usage="wormhole-server SUBCOMMAND (subcommand-options)",
description=dedent("""
Create a Magic Wormhole and communicate through it. Wormholes are created
by speaking the same magic CODE in two different places at the same time.
Wormholes are secure against anyone who doesn't use the same code."""),
)
parser.add_argument("--version", action="version",
version="magic-wormhole "+ __version__)
s = parser.add_subparsers(title="subcommands", dest="subcommand")
# CLI: run-server
sp_start = s.add_parser("start", description="Start a relay server",
usage="wormhole server start [opts] [TWISTD-ARGS..]")
sp_start.add_argument("--rendezvous", default="tcp:4000", metavar="tcp:PORT",
help="endpoint specification for the rendezvous port")
sp_start.add_argument("--transit", default="tcp:4001", metavar="tcp:PORT",
help="endpoint specification for the transit-relay port")
sp_start.add_argument("--advertise-version", metavar="VERSION",
help="version to recommend to clients")
sp_start.add_argument("--signal-error", metavar="ERROR",
help="force all clients to fail with a message")
sp_start.add_argument("--blur-usage", default=None, type=int,
metavar="SECONDS",
help="round logged access times to improve privacy")
sp_start.add_argument("-n", "--no-daemon", action="store_true")
#sp_start.add_argument("twistd_args", nargs="*", default=None,
# metavar="[TWISTD-ARGS..]",
# help=dedent("""\
# Additional arguments to pass to twistd"""),
# )
sp_start.set_defaults(func="server/start")
sp_stop = s.add_parser("stop", description="Stop the relay server",
usage="wormhole server stop")
sp_stop.set_defaults(func="server/stop")
sp_restart = s.add_parser("restart", description="Restart the relay server",
usage="wormhole server restart")
sp_restart.add_argument("--rendezvous", default="tcp:4000", metavar="tcp:PORT",
help="endpoint specification for the rendezvous port")
sp_restart.add_argument("--transit", default="tcp:4001", metavar="tcp:PORT",
help="endpoint specification for the transit-relay port")
sp_restart.add_argument("--advertise-version", metavar="VERSION",
help="version to recommend to clients")
sp_restart.add_argument("--signal-error", metavar="ERROR",
help="force all clients to fail with a message")
sp_restart.add_argument("--blur-usage", default=None, type=int,
metavar="SECONDS",
help="round logged access times to improve privacy")
sp_restart.add_argument("-n", "--no-daemon", action="store_true")
sp_restart.set_defaults(func="server/restart")
sp_show_usage = s.add_parser("show-usage", description="Display usage data",
usage="wormhole server show-usage")
sp_show_usage.add_argument("-n", default=100, type=int,
help="show last N entries")
sp_show_usage.set_defaults(func="usage/usage")
sp_tail_usage = s.add_parser("tail-usage", description="Follow latest usage",
usage="wormhole server tail-usage")
sp_tail_usage.set_defaults(func="usage/tail")
sp_count_channels = s.add_parser("count-channels",
description="Count active channels")
sp_count_channels.add_argument("--json", action="store_true")
sp_count_channels.set_defaults(func="usage/count-channels")
sp_count_events = s.add_parser("count-events", description="Count events")
sp_count_events.add_argument("--json", action="store_true")
sp_count_events.set_defaults(func="usage/count-events")

View File

@ -39,15 +39,15 @@ def kill_server():
f = open("twistd.pid", "r") f = open("twistd.pid", "r")
except EnvironmentError: except EnvironmentError:
print("Unable to find twistd.pid . Is this really a server directory?") print("Unable to find twistd.pid . Is this really a server directory?")
return 1 return
pid = int(f.read().strip()) pid = int(f.read().strip())
f.close() f.close()
os.kill(pid, 15) os.kill(pid, 15)
print("server process %d sent SIGTERM" % pid) print("server process %d sent SIGTERM" % pid)
return 0 return
def stop_server(args): def stop_server(args):
return kill_server() kill_server()
def restart_server(args): def restart_server(args):
kill_server() kill_server()

View File

@ -1,6 +1,6 @@
# no unicode_literals untill twisted update # no unicode_literals untill twisted update
from twisted.application import service from twisted.application import service
from twisted.internet import reactor, defer from twisted.internet import defer, task
from twisted.python import log from twisted.python import log
from ..transit import allocate_tcp_port from ..transit import allocate_tcp_port
from ..server.server import RelayServer from ..server.server import RelayServer
@ -36,8 +36,18 @@ class ServerBase:
# relay's .stopService() drops all connections, which ought to # relay's .stopService() drops all connections, which ought to
# encourage those threads to terminate soon. If they don't, print a # encourage those threads to terminate soon. If they don't, print a
# warning to ease debugging. # warning to ease debugging.
# XXX FIXME there's something in _noclobber test that's not
# waiting for a close, I think -- was pretty relieably getting
# unclean-reactor, but adding a slight pause here stops it...
from twisted.internet import reactor
tp = reactor.getThreadPool() tp = reactor.getThreadPool()
if not tp.working: if not tp.working:
d = defer.succeed(None)
d.addCallback(lambda _: self.sp.stopService())
d.addCallback(lambda _: task.deferLater(reactor, 0.1, lambda: None))
return d
return self.sp.stopService() return self.sp.stopService()
# disconnect all callers # disconnect all callers
d = defer.maybeDeferred(self.sp.stopService) d = defer.maybeDeferred(self.sp.stopService)

View File

@ -7,33 +7,32 @@ from twisted.internet.utils import getProcessOutputAndValue
from twisted.internet.defer import gatherResults, inlineCallbacks from twisted.internet.defer import gatherResults, inlineCallbacks
from .. import __version__ from .. import __version__
from .common import ServerBase from .common import ServerBase
from ..cli import runner, cmd_send, cmd_receive from ..cli import cmd_send, cmd_receive
from ..cli.cli import Config
from ..errors import TransferError, WrongPasswordError, WelcomeError from ..errors import TransferError, WrongPasswordError, WelcomeError
from ..timing import DebugTiming
def build_offer(args): def build_offer(args):
s = cmd_send.Sender(args, None) s = cmd_send.Sender(args, None)
return s._build_offer() return s._build_offer()
class OfferData(unittest.TestCase): class OfferData(unittest.TestCase):
def setUp(self): def setUp(self):
self._things_to_delete = [] self._things_to_delete = []
self.cfg = cfg = Config()
cfg.stdout = io.StringIO()
cfg.stderr = io.StringIO()
def tearDown(self): def tearDown(self):
for fn in self._things_to_delete: for fn in self._things_to_delete:
if os.path.exists(fn): if os.path.exists(fn):
os.unlink(fn) os.unlink(fn)
del self.cfg
def test_text(self): def test_text(self):
message = "blah blah blah ponies" self.cfg.text = message = "blah blah blah ponies"
d, fd_to_send = build_offer(self.cfg)
send_args = [ "send", "--text", message ]
args = runner.parser.parse_args(send_args)
args.cwd = os.getcwd()
args.stdout = io.StringIO()
args.stderr = io.StringIO()
d, fd_to_send = build_offer(args)
self.assertIn("message", d) self.assertIn("message", d)
self.assertNotIn("file", d) self.assertNotIn("file", d)
@ -42,7 +41,7 @@ class OfferData(unittest.TestCase):
self.assertEqual(fd_to_send, None) self.assertEqual(fd_to_send, None)
def test_file(self): def test_file(self):
filename = "my file" self.cfg.what = filename = "my file"
message = b"yay ponies\n" message = b"yay ponies\n"
send_dir = self.mktemp() send_dir = self.mktemp()
os.mkdir(send_dir) os.mkdir(send_dir)
@ -50,13 +49,8 @@ class OfferData(unittest.TestCase):
with open(abs_filename, "wb") as f: with open(abs_filename, "wb") as f:
f.write(message) f.write(message)
send_args = [ "send", filename ] self.cfg.cwd = send_dir
args = runner.parser.parse_args(send_args) d, fd_to_send = build_offer(self.cfg)
args.cwd = send_dir
args.stdout = io.StringIO()
args.stderr = io.StringIO()
d, fd_to_send = build_offer(args)
self.assertNotIn("message", d) self.assertNotIn("message", d)
self.assertIn("file", d) self.assertIn("file", d)
@ -67,17 +61,12 @@ class OfferData(unittest.TestCase):
self.assertEqual(fd_to_send.read(), message) self.assertEqual(fd_to_send.read(), message)
def test_missing_file(self): def test_missing_file(self):
filename = "missing" self.cfg.what = filename = "missing"
send_dir = self.mktemp() send_dir = self.mktemp()
os.mkdir(send_dir) os.mkdir(send_dir)
self.cfg.cwd = send_dir
send_args = [ "send", filename ] e = self.assertRaises(TransferError, build_offer, self.cfg)
args = runner.parser.parse_args(send_args)
args.cwd = send_dir
args.stdout = io.StringIO()
args.stderr = io.StringIO()
e = self.assertRaises(TransferError, build_offer, args)
self.assertEqual(str(e), self.assertEqual(str(e),
"Cannot send: no file/directory named '%s'" % filename) "Cannot send: no file/directory named '%s'" % filename)
@ -94,13 +83,10 @@ class OfferData(unittest.TestCase):
send_dir_arg = send_dir send_dir_arg = send_dir
if addslash: if addslash:
send_dir_arg += os.sep send_dir_arg += os.sep
send_args = [ "send", send_dir_arg ] self.cfg.what = send_dir_arg
args = runner.parser.parse_args(send_args) self.cfg.cwd = parent_dir
args.cwd = parent_dir
args.stdout = io.StringIO()
args.stderr = io.StringIO()
d, fd_to_send = build_offer(args) d, fd_to_send = build_offer(self.cfg)
self.assertNotIn("message", d) self.assertNotIn("message", d)
self.assertNotIn("file", d) self.assertNotIn("file", d)
@ -130,10 +116,11 @@ class OfferData(unittest.TestCase):
return self._do_test_directory(addslash=True) return self._do_test_directory(addslash=True)
def test_unknown(self): def test_unknown(self):
filename = "unknown" self.cfg.what = filename = "unknown"
send_dir = self.mktemp() send_dir = self.mktemp()
os.mkdir(send_dir) os.mkdir(send_dir)
abs_filename = os.path.abspath(os.path.join(send_dir, filename)) abs_filename = os.path.abspath(os.path.join(send_dir, filename))
self.cfg.cwd = send_dir
try: try:
os.mkfifo(abs_filename) os.mkfifo(abs_filename)
@ -149,13 +136,7 @@ class OfferData(unittest.TestCase):
self.assertFalse(os.path.isfile(abs_filename)) self.assertFalse(os.path.isfile(abs_filename))
self.assertFalse(os.path.isdir(abs_filename)) self.assertFalse(os.path.isdir(abs_filename))
send_args = [ "send", filename ] e = self.assertRaises(TypeError, build_offer, self.cfg)
args = runner.parser.parse_args(send_args)
args.cwd = send_dir
args.stdout = io.StringIO()
args.stderr = io.StringIO()
e = self.assertRaises(TypeError, build_offer, args)
self.assertEqual(str(e), self.assertEqual(str(e),
"'%s' is neither file nor directory" % filename) "'%s' is neither file nor directory" % filename)
@ -203,27 +184,24 @@ class ScriptVersion(ServerBase, ScriptsBase, unittest.TestCase):
# we need Twisted to run the server, but we run the sender and receiver # we need Twisted to run the server, but we run the sender and receiver
# with deferToThread() # with deferToThread()
@inlineCallbacks
def test_version(self): def test_version(self):
# "wormhole" must be on the path, so e.g. "pip install -e ." in a # "wormhole" must be on the path, so e.g. "pip install -e ." in a
# virtualenv. This guards against an environment where the tests # virtualenv. This guards against an environment where the tests
# below might run the wrong executable. # below might run the wrong executable.
self.maxDiff = None
wormhole = self.find_executable() wormhole = self.find_executable()
d = getProcessOutputAndValue(wormhole, ["--version"]) # we must pass on the environment so that "something" doesn't
def _check(res): # get sad about UTF8 vs. ascii encodings
out, err, rc = res out, err, rc = yield getProcessOutputAndValue(wormhole, ["--version"], env=os.environ)
# argparse on py2 and py3.3 sends --version to stderr err = err.decode("utf-8")
# argparse on py3.4/py3.5 sends --version to stdout if "DistributionNotFound" in err:
# aargh log.msg("stderr was %s" % err)
err = err.decode("utf-8") last = err.strip().split("\n")[-1]
if "DistributionNotFound" in err: self.fail("wormhole not runnable: %s" % last)
log.msg("stderr was %s" % err) ver = out.decode("utf-8") or err
last = err.strip().split("\n")[-1] self.failUnlessEqual(ver.strip(), "magic-wormhole {}".format(__version__))
self.fail("wormhole not runnable: %s" % last) self.failUnlessEqual(rc, 0)
ver = out.decode("utf-8") or err
self.failUnlessEqual(ver, "magic-wormhole "+__version__+os.linesep)
self.failUnlessEqual(rc, 0)
d.addCallback(_check)
return d
class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase): class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase):
# we need Twisted to run the server, but we run the sender and receiver # we need Twisted to run the server, but we run the sender and receiver
@ -238,20 +216,18 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase):
def _do_test(self, as_subprocess=False, def _do_test(self, as_subprocess=False,
mode="text", addslash=False, override_filename=False): mode="text", addslash=False, override_filename=False):
assert mode in ("text", "file", "directory") assert mode in ("text", "file", "directory")
common_args = ["--hide-progress", send_cfg = Config()
"--relay-url", self.relayurl, recv_cfg = Config()
"--transit-helper", ""] message = "blah blah blah ponies"
code = "1-abc"
message = "test message"
send_args = common_args + [ for cfg in [send_cfg, recv_cfg]:
"send", cfg.hide_progress = True
"--code", code, cfg.relay_url = self.relayurl
] cfg.transit_helper = ""
cfg.listen = True
receive_args = common_args + [ cfg.code = "1-abc"
"receive", cfg.stdout = io.StringIO()
] cfg.stderr = io.StringIO()
send_dir = self.mktemp() send_dir = self.mktemp()
os.mkdir(send_dir) os.mkdir(send_dir)
@ -259,19 +235,18 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase):
os.mkdir(receive_dir) os.mkdir(receive_dir)
if mode == "text": if mode == "text":
send_args.extend(["--text", message]) send_cfg.text = message
elif mode == "file": elif mode == "file":
send_filename = "testfile" send_filename = "testfile"
with open(os.path.join(send_dir, send_filename), "w") as f: with open(os.path.join(send_dir, send_filename), "w") as f:
f.write(message) f.write(message)
send_args.append(send_filename) send_cfg.what = send_filename
receive_filename = send_filename receive_filename = send_filename
receive_args.append("--accept-file") recv_cfg.accept_file = True
if override_filename: if override_filename:
receive_args.extend(["-o", "outfile"]) recv_cfg.output_file = receive_filename = "outfile"
receive_filename = "outfile"
elif mode == "directory": elif mode == "directory":
# $send_dir/ # $send_dir/
@ -299,22 +274,48 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase):
send_dirname_arg = os.path.join("middle", send_dirname) send_dirname_arg = os.path.join("middle", send_dirname)
if addslash: if addslash:
send_dirname_arg += os.sep send_dirname_arg += os.sep
send_args.append(send_dirname_arg) send_cfg.what = send_dirname_arg
receive_dirname = send_dirname receive_dirname = send_dirname
receive_args.append("--accept-file") recv_cfg.accept_file = True
if override_filename: if override_filename:
receive_args.extend(["-o", "outdir"]) recv_cfg.output_file = receive_dirname = "outdir"
receive_dirname = "outdir"
receive_args.append(code)
if as_subprocess: if as_subprocess:
wormhole_bin = self.find_executable() wormhole_bin = self.find_executable()
send_d = getProcessOutputAndValue(wormhole_bin, send_args, if send_cfg.text:
path=send_dir) content_args = ['--text', send_cfg.text]
receive_d = getProcessOutputAndValue(wormhole_bin, receive_args, elif send_cfg.what:
path=receive_dir) content_args = [send_cfg.what]
send_args = [
'--hide-progress',
'--relay-url', self.relayurl,
'--transit-helper', '',
'send',
'--code', send_cfg.code,
] + content_args
send_d = getProcessOutputAndValue(
wormhole_bin, send_args,
path=send_dir,
)
recv_args = [
'--hide-progress',
'--relay-url', self.relayurl,
'--transit-helper', '',
'receive',
'--accept-file',
recv_cfg.code,
]
if override_filename:
recv_args.extend(['-o', receive_filename])
receive_d = getProcessOutputAndValue(
wormhole_bin, recv_args,
path=receive_dir,
)
(send_res, receive_res) = yield gatherResults([send_d, receive_d], (send_res, receive_res) = yield gatherResults([send_d, receive_d],
True) True)
send_stdout = send_res[0].decode("utf-8") send_stdout = send_res[0].decode("utf-8")
@ -327,27 +328,21 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase):
self.assertEqual((send_rc, receive_rc), (0, 0), self.assertEqual((send_rc, receive_rc), (0, 0),
(send_res, receive_res)) (send_res, receive_res))
else: else:
sargs = runner.parser.parse_args(send_args) send_cfg.cwd = send_dir
sargs.cwd = send_dir send_d = cmd_send.send(send_cfg)
sargs.stdout = io.StringIO()
sargs.stderr = io.StringIO() recv_cfg.cwd = receive_dir
sargs.timing = DebugTiming() receive_d = cmd_receive.receive(recv_cfg)
rargs = runner.parser.parse_args(receive_args)
rargs.cwd = receive_dir
rargs.stdout = io.StringIO()
rargs.stderr = io.StringIO()
rargs.timing = DebugTiming()
send_d = cmd_send.send(sargs)
receive_d = cmd_receive.receive(rargs)
# The sender might fail, leaving the receiver hanging, or vice # The sender might fail, leaving the receiver hanging, or vice
# versa. Make sure we don't wait on one side exclusively # versa. Make sure we don't wait on one side exclusively
yield gatherResults([send_d, receive_d], True) yield gatherResults([send_d, receive_d], True)
send_stdout = sargs.stdout.getvalue() # XXX need captured stdin/stdout from sender/receiver
send_stderr = sargs.stderr.getvalue() send_stdout = send_cfg.stdout.getvalue()
receive_stdout = rargs.stdout.getvalue() send_stderr = send_cfg.stderr.getvalue()
receive_stderr = rargs.stderr.getvalue() receive_stdout = recv_cfg.stdout.getvalue()
receive_stderr = recv_cfg.stderr.getvalue()
# all output here comes from a StringIO, which uses \n for # all output here comes from a StringIO, which uses \n for
# newlines, even if we're on windows # newlines, even if we're on windows
@ -367,7 +362,7 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase):
"wormhole receive{NL}" "wormhole receive{NL}"
"Wormhole code is: {code}{NL}{NL}" "Wormhole code is: {code}{NL}{NL}"
"text message sent{NL}").format(bytes=len(message), "text message sent{NL}").format(bytes=len(message),
code=code, code=send_cfg.code,
NL=NL) NL=NL)
self.failUnlessEqual(send_stdout, expected) self.failUnlessEqual(send_stdout, expected)
elif mode == "file": elif mode == "file":
@ -377,7 +372,7 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase):
self.failUnlessIn("On the other computer, please run: " self.failUnlessIn("On the other computer, please run: "
"wormhole receive{NL}" "wormhole receive{NL}"
"Wormhole code is: {code}{NL}{NL}" "Wormhole code is: {code}{NL}{NL}"
.format(code=code, NL=NL), .format(code=send_cfg.code, NL=NL),
send_stdout) send_stdout)
self.failUnlessIn("File sent.. waiting for confirmation{NL}" self.failUnlessIn("File sent.. waiting for confirmation{NL}"
"Confirmation received. Transfer complete.{NL}" "Confirmation received. Transfer complete.{NL}"
@ -388,7 +383,7 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase):
self.failUnlessIn("On the other computer, please run: " self.failUnlessIn("On the other computer, please run: "
"wormhole receive{NL}" "wormhole receive{NL}"
"Wormhole code is: {code}{NL}{NL}" "Wormhole code is: {code}{NL}{NL}"
.format(code=code, NL=NL), send_stdout) .format(code=send_cfg.code, NL=NL), send_stdout)
self.failUnlessIn("File sent.. waiting for confirmation{NL}" self.failUnlessIn("File sent.. waiting for confirmation{NL}"
"Confirmation received. Transfer complete.{NL}" "Confirmation received. Transfer complete.{NL}"
.format(NL=NL), send_stdout) .format(NL=NL), send_stdout)
@ -440,14 +435,21 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase):
@inlineCallbacks @inlineCallbacks
def test_file_noclobber(self): def test_file_noclobber(self):
common_args = ["--hide-progress", "--no-listen", send_cfg = Config()
"--relay-url", self.relayurl, recv_cfg = Config()
"--transit-helper", ""]
code = "1-abc" for cfg in [send_cfg, recv_cfg]:
cfg.hide_progress = True
cfg.relay_url = self.relayurl
cfg.transit_helper = ""
cfg.listen = False
cfg.code = code = "1-abc"
cfg.stdout = io.StringIO()
cfg.stderr = io.StringIO()
message = "test message" message = "test message"
send_args = common_args + [ "send", "--code", code ] recv_cfg.accept_file = True
receive_args = common_args + [ "receive", "--accept-file", code ]
send_dir = self.mktemp() send_dir = self.mktemp()
os.mkdir(send_dir) os.mkdir(send_dir)
@ -457,26 +459,19 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase):
send_filename = "testfile" send_filename = "testfile"
with open(os.path.join(send_dir, send_filename), "w") as f: with open(os.path.join(send_dir, send_filename), "w") as f:
f.write(message) f.write(message)
send_args.append(send_filename) send_cfg.what = receive_filename = send_filename
receive_filename = send_filename recv_cfg.what = receive_filename
PRESERVE = "don't clobber me\n" PRESERVE = "don't clobber me\n"
clobberable = os.path.join(receive_dir, receive_filename) clobberable = os.path.join(receive_dir, receive_filename)
with open(clobberable, "w") as f: with open(clobberable, "w") as f:
f.write(PRESERVE) f.write(PRESERVE)
sargs = runner.parser.parse_args(send_args) send_cfg.cwd = send_dir
sargs.cwd = send_dir send_d = cmd_send.send(send_cfg)
sargs.stdout = io.StringIO()
sargs.stderr = io.StringIO() recv_cfg.cwd = receive_dir
sargs.timing = DebugTiming() receive_d = cmd_receive.receive(recv_cfg)
rargs = runner.parser.parse_args(receive_args)
rargs.cwd = receive_dir
rargs.stdout = io.StringIO()
rargs.stderr = io.StringIO()
rargs.timing = DebugTiming()
send_d = cmd_send.send(sargs)
receive_d = cmd_receive.receive(rargs)
# both sides will fail because of the pre-existing file # both sides will fail because of the pre-existing file
@ -486,10 +481,10 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase):
f = yield self.assertFailure(receive_d, TransferError) f = yield self.assertFailure(receive_d, TransferError)
self.assertEqual(str(f), "file already exists") self.assertEqual(str(f), "file already exists")
send_stdout = sargs.stdout.getvalue() send_stdout = send_cfg.stdout.getvalue()
send_stderr = sargs.stderr.getvalue() send_stderr = send_cfg.stderr.getvalue()
receive_stdout = rargs.stdout.getvalue() receive_stdout = recv_cfg.stdout.getvalue()
receive_stderr = rargs.stderr.getvalue() receive_stderr = recv_cfg.stderr.getvalue()
# all output here comes from a StringIO, which uses \n for # all output here comes from a StringIO, which uses \n for
# newlines, even if we're on windows # newlines, even if we're on windows
@ -528,63 +523,56 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase):
class NotWelcome(ServerBase, unittest.TestCase): class NotWelcome(ServerBase, unittest.TestCase):
def setUp(self): def setUp(self):
self._setup_relay(error="please upgrade XYZ") self._setup_relay(error="please upgrade XYZ")
self.cfg = cfg = Config()
cfg.hide_progress = True
cfg.listen = False
cfg.relay_url = self.relayurl
cfg.transit_helper = ""
cfg.stdout = io.StringIO()
cfg.stderr = io.StringIO()
@inlineCallbacks @inlineCallbacks
def test_sender(self): def test_sender(self):
common_args = ["--hide-progress", "--no-listen", self.cfg.text = "hi"
"--relay-url", self.relayurl, self.cfg.code = "1-abc"
"--transit-helper", ""]
send_args = common_args + [ "send", "--text", "hi",
"--code", "1-abc" ]
sargs = runner.parser.parse_args(send_args)
sargs.cwd = self.mktemp()
sargs.stdout = io.StringIO()
sargs.stderr = io.StringIO()
sargs.timing = DebugTiming()
send_d = cmd_send.send(sargs) send_d = cmd_send.send(self.cfg)
f = yield self.assertFailure(send_d, WelcomeError) f = yield self.assertFailure(send_d, WelcomeError)
self.assertEqual(str(f), "please upgrade XYZ") self.assertEqual(str(f), "please upgrade XYZ")
@inlineCallbacks @inlineCallbacks
def test_receiver(self): def test_receiver(self):
common_args = ["--hide-progress", "--no-listen", self.cfg.code = "1-abc"
"--relay-url", self.relayurl,
"--transit-helper", ""]
receive_args = common_args + [ "receive", "1-abc" ]
rargs = runner.parser.parse_args(receive_args)
rargs.cwd = self.mktemp()
rargs.stdout = io.StringIO()
rargs.stderr = io.StringIO()
rargs.timing = DebugTiming()
receive_d = cmd_receive.receive(rargs) receive_d = cmd_receive.receive(self.cfg)
f = yield self.assertFailure(receive_d, WelcomeError) f = yield self.assertFailure(receive_d, WelcomeError)
self.assertEqual(str(f), "please upgrade XYZ") self.assertEqual(str(f), "please upgrade XYZ")
class Cleanup(ServerBase, unittest.TestCase):
@inlineCallbacks
def test_text(self):
# the rendezvous channel should be deleted after success
code = "1-abc"
common_args = ["--hide-progress",
"--relay-url", self.relayurl,
"--transit-helper", ""]
sargs = runner.parser.parse_args(common_args +
["send",
"--text", "secret message",
"--code", code])
sargs.stdout = io.StringIO()
sargs.stderr = io.StringIO()
sargs.timing = DebugTiming()
rargs = runner.parser.parse_args(common_args +
["receive", code])
rargs.stdout = io.StringIO()
rargs.stderr = io.StringIO()
rargs.timing = DebugTiming()
send_d = cmd_send.send(sargs)
receive_d = cmd_receive.receive(rargs)
class Cleanup(ServerBase, unittest.TestCase):
def setUp(self):
d = super(Cleanup, self).setUp()
self.cfg = cfg = Config()
# common options for all tests in this suite
cfg.hide_progress = True
cfg.relay_url = self.relayurl
cfg.transit_helper = ""
cfg.stdout = io.StringIO()
cfg.stderr = io.StringIO()
return d
@inlineCallbacks
@mock.patch('sys.stdout')
def test_text(self, stdout):
# the rendezvous channel should be deleted after success
self.cfg.text = "hello"
self.cfg.code = "1-abc"
send_d = cmd_send.send(self.cfg)
receive_d = cmd_receive.receive(self.cfg)
# XXX DeferredList?
yield send_d yield send_d
yield receive_d yield receive_d
@ -595,23 +583,12 @@ class Cleanup(ServerBase, unittest.TestCase):
def test_text_wrong_password(self): def test_text_wrong_password(self):
# if the password was wrong, the rendezvous channel should still be # if the password was wrong, the rendezvous channel should still be
# deleted # deleted
common_args = ["--hide-progress", self.cfg.text = "secret message"
"--relay-url", self.relayurl, self.cfg.code = "1-abc"
"--transit-helper", ""] send_d = cmd_send.send(self.cfg)
sargs = runner.parser.parse_args(common_args +
["send", self.cfg.code = "1-WRONG"
"--text", "secret message", receive_d = cmd_receive.receive(self.cfg)
"--code", "1-abc"])
sargs.stdout = io.StringIO()
sargs.stderr = io.StringIO()
sargs.timing = DebugTiming()
rargs = runner.parser.parse_args(common_args +
["receive", "1-WRONG"])
rargs.stdout = io.StringIO()
rargs.stderr = io.StringIO()
rargs.timing = DebugTiming()
send_d = cmd_send.send(sargs)
receive_d = cmd_receive.receive(rargs)
# both sides should be capable of detecting the mismatch # both sides should be capable of detecting the mismatch
yield self.assertFailure(send_d, WrongPasswordError) yield self.assertFailure(send_d, WrongPasswordError)

View File

@ -813,6 +813,10 @@ class Common:
is_relay=True) is_relay=True)
contenders.append(d) contenders.append(d)
if not contenders:
raise RuntimeError("No contenders for connection")
# else:
# print("contend", contenders)
winner = there_can_be_only_one(contenders) winner = there_can_be_only_one(contenders)
return self._not_forever(2*TIMEOUT, winner) return self._not_forever(2*TIMEOUT, winner)

View File

@ -462,8 +462,10 @@ class _Wormhole:
# entry point 3: paste in a fully-formed code # entry point 3: paste in a fully-formed code
def _API_set_code(self, code): def _API_set_code(self, code):
self._timing.add("API set_code") self._timing.add("API set_code")
if not isinstance(code, type("")): raise TypeError(type(code)) if not isinstance(code, type(u"")):
if self._code is not None: raise UsageError raise TypeError("Unexpected code type '{}'".format(type(code)))
if self._code is not None:
raise UsageError
self._event_learned_code(code) self._event_learned_code(code)
# TODO: entry point 4: restore pre-contact saved state (we haven't heard # TODO: entry point 4: restore pre-contact saved state (we haven't heard