Merge branch 'pr47'

This commit is contained in:
Brian Warner 2016-06-22 13:23:35 -07:00
commit 38ebc0d8a4
18 changed files with 642 additions and 523 deletions

View File

@ -19,18 +19,29 @@ setup(name="magic-wormhole",
"wormhole.test", "wormhole.test",
], ],
package_data={"wormhole.server": ["db-schemas/*.sql"]}, package_data={"wormhole.server": ["db-schemas/*.sql"]},
entry_points={"console_scripts": entry_points={
["wormhole = wormhole.cli.runner:entry", "console_scripts":
"wormhole-server = wormhole.server.runner:entry", [
]}, "wormhole = wormhole.cli.cli:wormhole",
install_requires=["spake2==0.7", "pynacl", "argparse", "wormhole-server = wormhole.server.cli:server",
"six", ]
"twisted", },
"autobahn[twisted] >= 0.14.1", install_requires=[
"hkdf", "tqdm", "spake2==0.7", "pynacl",
], "six",
extras_require={':sys_platform=="win32"': ["pypiwin32"], "twisted",
"tor": ["txtorcon", "ipaddress"]}, "autobahn[twisted] >= 0.14.1",
"hkdf", "tqdm",
"click",
],
extras_require={
':sys_platform=="win32"': ["pypiwin32"],
"tor": ["txtorcon", "ipaddress"],
"dev": [
"mock",
"tox",
],
},
test_suite="wormhole.test", test_suite="wormhole.test",
cmdclass=commands, cmdclass=commands,
) )

233
src/wormhole/cli/cli.py Normal file
View File

@ -0,0 +1,233 @@
from __future__ import print_function
import os
import time
start = time.time()
import traceback
from textwrap import fill, dedent
from sys import stdout, stderr
from . import public_relay
from .. import __version__
from ..timing import DebugTiming
from ..errors import WrongPasswordError, WelcomeError, KeyFormatError
from twisted.internet.defer import inlineCallbacks, maybeDeferred
from twisted.internet.task import react
import click
top_import_finish = time.time()
class Config(object):
"""
Union of config options that we pass down to (sub) commands.
"""
def __init__(self):
# common options
self.timing = DebugTiming()
self.tor = None
self.listen = None
self.relay_url = u""
self.transit_helper = u""
self.cwd = os.getcwd()
# send/receive commands
self.code = None
self.code_length = 2
self.verify = False
self.hide_progress = False
self.dump_timing = False
self.stdout = stdout
self.stderr = stderr
self.zeromode = False
self.accept_file = None
self.output_file = None
# send only
self.text = None
self.what = None
ALIASES = {
"tx": "send",
"rx": "receive",
}
class AliasedGroup(click.Group):
def get_command(self, ctx, cmd_name):
cmd_name = ALIASES.get(cmd_name, cmd_name)
return click.Group.get_command(self, ctx, cmd_name)
# top-level command ("wormhole ...")
@click.group(cls=AliasedGroup)
@click.option(
"--relay-url", default=public_relay.RENDEZVOUS_RELAY,
metavar="URL",
help="rendezvous relay to use",
)
@click.option(
"--transit-helper", default=public_relay.TRANSIT_RELAY,
metavar="tcp:HOST:PORT",
help="transit relay to use",
)
@click.option(
"-c", "--code-length", default=2,
metavar="NUMWORDS",
help="length of code (in bytes/words)",
)
@click.option(
"-v", "--verify", is_flag=True, default=False,
help="display (and wait for acceptance of) verification string",
)
@click.option(
"--hide-progress", is_flag=True, default=False,
help="supress progress-bar display",
)
@click.option(
"--dump-timing", type=type(u""), # TODO: hide from --help output
default=None,
metavar="FILE.json",
help="(debug) write timing data to file",
)
@click.option(
"--no-listen", is_flag=True, default=False,
help="(debug) don't open a listening socket for Transit",
)
@click.option(
"--tor", is_flag=True, default=True,
help="use Tor when connecting",
)
@click.version_option(
message="magic-wormhole %(version)s",
version=__version__,
)
@click.pass_context
def wormhole(ctx, tor, no_listen, dump_timing, hide_progress,
verify, code_length, transit_helper, relay_url):
"""
Create a Magic Wormhole and communicate through it.
Wormholes are created by speaking the same magic CODE in two
different places at the same time. Wormholes are secure against
anyone who doesn't use the same code.
"""
ctx.obj = cfg = Config()
ctx.tor = tor
if no_listen:
cfg.listen = False
cfg.relay_url = relay_url
cfg.transit_helper = transit_helper
cfg.code_length = code_length
cfg.verify = verify
cfg.hide_progress = hide_progress
cfg.dump_timing = dump_timing
@inlineCallbacks
def _dispatch_command(reactor, cfg, command):
"""
Internal helper. This calls the given command (a no-argument
callable) with the Config instance in cfg and interprets any
errors for the user.
"""
cfg.timing.add("command dispatch")
cfg.timing.add("import", when=start, which="top").finish(when=top_import_finish)
try:
yield maybeDeferred(command)
except WrongPasswordError as e:
msg = fill("ERROR: " + dedent(e.__doc__))
print(msg, file=stderr)
except WelcomeError as e:
msg = fill("ERROR: " + dedent(e.__doc__))
print(msg, file=stderr)
print(file=stderr)
print(str(e), file=stderr)
except KeyFormatError as e:
msg = fill("ERROR: " + dedent(e.__doc__))
print(msg, file=stderr)
except Exception as e:
traceback.print_exc()
print("ERROR:", e, file=stderr)
raise SystemExit(1)
cfg.timing.add("exit")
if cfg.dump_timing:
cfg.timing.write(cfg.dump_timing, stderr)
# wormhole send (or "wormhole tx")
@wormhole.command()
@click.option(
"-0", "zeromode", default=False, is_flag=True,
help="enable no-code anything-goes mode",
)
@click.option(
"--code", metavar="CODE",
help="human-generated code phrase",
)
@click.option(
"--text", default=None, metavar="MESSAGE",
help="text message to send, instead of a file. Use '-' to read from stdin.",
)
@click.argument("what", default=u'')
@click.pass_obj
def send(cfg, what, text, code, zeromode):
"""Send a text message, file, or directory"""
with cfg.timing.add("import", which="cmd_send"):
from . import cmd_send
cfg.what = what
cfg.text = text
cfg.zeromode = zeromode
cfg.code = code
# note: react() does not return
return react(_dispatch_command, (cfg, lambda: cmd_send.send(cfg)))
# wormhole receive (or "wormhole rx")
@wormhole.command()
@click.option(
"-0", "zeromode", default=False, is_flag=True,
help="enable no-code anything-goes mode",
)
@click.option(
"--only-text", "-t", is_flag=True,
help="refuse file transfers, only accept text transfers",
)
@click.option(
"--accept-file", is_flag=True,
help="accept file transfer without asking for confirmation",
)
@click.option(
"--output-file", "-o",
metavar="FILENAME|DIRNAME",
help=("The file or directory to create, overriding the name suggested"
" by the sender."),
)
@click.argument(
"code", nargs=-1, default=None,
# help=("The magic-wormhole code, from the sender. If omitted, the"
# " program will ask for it, using tab-completion."),
)
@click.pass_obj
def receive(cfg, code, zeromode, output_file, accept_file, only_text):
"""
Receive a text message, file, or directory (from 'wormhole send')
"""
with cfg.timing.add("import", which="cmd_receive"):
from . import cmd_receive
cfg.zeromode = zeromode
cfg.output_file = output_file
cfg.accept_file = accept_file
cfg.only_text = only_text
if len(code) == 1:
cfg.code = code[0]
elif len(code) > 1:
print(
"Pass either no code or just one code; you passed"
" {}: {}".format(len(code), ', '.join(code))
)
raise SystemExit(1)
else:
cfg.code = None
# note: react() does not return
return react(_dispatch_command, (cfg, lambda: cmd_receive.receive(cfg)))

View File

@ -1,124 +0,0 @@
import argparse
from textwrap import dedent
from . import public_relay
from .. import __version__
parser = argparse.ArgumentParser(
usage="wormhole SUBCOMMAND (subcommand-options)",
description=dedent("""
Create a Magic Wormhole and communicate through it. Wormholes are created
by speaking the same magic CODE in two different places at the same time.
Wormholes are secure against anyone who doesn't use the same code."""),
)
parser.add_argument("--version", action="version",
version="magic-wormhole "+ __version__)
g = parser.add_argument_group("wormhole configuration options")
g.add_argument("--relay-url", default=public_relay.RENDEZVOUS_RELAY,
metavar="URL", help="rendezvous relay to use", type=type(u""))
g.add_argument("--transit-helper", default=public_relay.TRANSIT_RELAY,
metavar="tcp:HOST:PORT", help="transit relay to use",
type=type(u""))
g.add_argument("-c", "--code-length", type=int, default=2,
metavar="WORDS", help="length of code (in bytes/words)")
g.add_argument("-v", "--verify", action="store_true",
help="display (and wait for acceptance of) verification string")
g.add_argument("--hide-progress", action="store_true",
help="supress progress-bar display")
g.add_argument("--dump-timing", type=type(u""), # TODO: hide from --help output
metavar="FILE", help="(debug) write timing data to file")
g.add_argument("--no-listen", action="store_true",
help="(debug) don't open a listening socket for Transit")
g.add_argument("--tor", action="store_true",
help="use Tor when connecting")
parser.set_defaults(timing=None)
subparsers = parser.add_subparsers(title="subcommands",
dest="subcommand")
# CLI: run-server
s = subparsers.add_parser("server", description="Start/stop a relay server")
sp = s.add_subparsers(title="subcommands", dest="subcommand")
sp_start = sp.add_parser("start", description="Start a relay server",
usage="wormhole server start [opts] [TWISTD-ARGS..]")
sp_start.add_argument("--rendezvous", default="tcp:3000", metavar="tcp:PORT",
help="endpoint specification for the rendezvous port")
sp_start.add_argument("--transit", default="tcp:3001", metavar="tcp:PORT",
help="endpoint specification for the transit-relay port")
sp_start.add_argument("--advertise-version", metavar="VERSION",
help="version to recommend to clients")
sp_start.add_argument("--blur-usage", default=None, type=int,
metavar="SECONDS",
help="round logged access times to improve privacy")
sp_start.add_argument("-n", "--no-daemon", action="store_true")
#sp_start.add_argument("twistd_args", nargs="*", default=None,
# metavar="[TWISTD-ARGS..]",
# help=dedent("""\
# Additional arguments to pass to twistd"""),
# )
sp_start.set_defaults(func="server/start")
sp_stop = sp.add_parser("stop", description="Stop the relay server",
usage="wormhole server stop")
sp_stop.set_defaults(func="server/stop")
sp_restart = sp.add_parser("restart", description="Restart the relay server",
usage="wormhole server restart")
sp_restart.add_argument("--rendezvous", default="tcp:3000", metavar="tcp:PORT",
help="endpoint specification for the rendezvous port")
sp_restart.add_argument("--transit", default="tcp:3001", metavar="tcp:PORT",
help="endpoint specification for the transit-relay port")
sp_restart.add_argument("--advertise-version", metavar="VERSION",
help="version to recommend to clients")
sp_restart.add_argument("--blur-usage", default=None, type=int,
metavar="SECONDS",
help="round logged access times to improve privacy")
sp_restart.add_argument("-n", "--no-daemon", action="store_true")
sp_restart.set_defaults(func="server/restart")
sp_show_usage = sp.add_parser("show-usage", description="Display usage data",
usage="wormhole server show-usage")
sp_show_usage.add_argument("-n", default=100, type=int,
help="show last N entries")
sp_show_usage.set_defaults(func="usage/usage")
sp_tail_usage = sp.add_parser("tail-usage", description="Follow latest usage",
usage="wormhole server tail-usage")
sp_tail_usage.set_defaults(func="usage/tail")
# CLI: send
p = subparsers.add_parser("send",
description="Send text message, file, or directory",
usage="wormhole send [FILENAME|DIRNAME]")
p.add_argument("--text", metavar="MESSAGE",
help="text message to send, instead of a file. Use '-' to read from stdin.")
p.add_argument("--code", metavar="CODE", help="human-generated code phrase",
type=type(u""))
p.add_argument("-0", dest="zeromode", action="store_true",
help="enable no-code anything-goes mode")
p.add_argument("what", nargs="?", default=None, metavar="[FILENAME|DIRNAME]",
help="the file/directory to send")
p.set_defaults(func="send/send")
# CLI: receive
p = subparsers.add_parser("receive",
description="Receive a text message, file, or directory",
usage="wormhole receive [CODE]")
p.add_argument("-0", dest="zeromode", action="store_true",
help="enable no-code anything-goes mode")
p.add_argument("-t", "--only-text", dest="only_text", action="store_true",
help="refuse file transfers, only accept text transfers")
p.add_argument("--accept-file", dest="accept_file", action="store_true",
help="accept file transfer with asking for confirmation")
p.add_argument("-o", "--output-file", default=None, metavar="FILENAME|DIRNAME",
help=dedent("""\
The file or directory to create, overriding the name suggested
by the sender."""),
)
p.add_argument("code", nargs="?", default=None, metavar="[CODE]",
help=dedent("""\
The magic-wormhole code, from the sender. If omitted, the
program will ask for it, using tab-completion."""),
type=type(u""),
)
p.set_defaults(func="receive/receive")

View File

@ -141,7 +141,7 @@ class TwistedReceiver:
@inlineCallbacks @inlineCallbacks
def _build_transit(self, w, sender_transit): def _build_transit(self, w, sender_transit):
tr = TransitReceiver(self.args.transit_helper, tr = TransitReceiver(self.args.transit_helper,
no_listen=self.args.no_listen, no_listen=(not self.args.listen),
tor_manager=self._tor_manager, tor_manager=self._tor_manager,
reactor=self._reactor, reactor=self._reactor,
timing=self.args.timing) timing=self.args.timing)

View File

@ -48,9 +48,10 @@ class Sender:
w = wormhole(APPID, self._args.relay_url, w = wormhole(APPID, self._args.relay_url,
self._reactor, self._tor_manager, self._reactor, self._tor_manager,
timing=self._timing) timing=self._timing)
d = self._go(w) try:
d.addBoth(w.close) yield self._go(w)
yield d finally:
w.close()
def _send_data(self, data, w): def _send_data(self, data, w):
data_bytes = dict_to_bytes(data) data_bytes = dict_to_bytes(data)
@ -100,7 +101,7 @@ class Sender:
if self._fd_to_send: if self._fd_to_send:
ts = TransitSender(args.transit_helper, ts = TransitSender(args.transit_helper,
no_listen=args.no_listen, no_listen=(not args.listen),
tor_manager=self._tor_manager, tor_manager=self._tor_manager,
reactor=self._reactor, reactor=self._reactor,
timing=self._timing) timing=self._timing)

View File

@ -1,81 +0,0 @@
from __future__ import print_function
import time
start = time.time()
import os, sys, textwrap
from twisted.internet.defer import maybeDeferred
from twisted.internet.task import react
from ..errors import (TransferError, WrongPasswordError, WelcomeError, Timeout,
KeyFormatError)
from ..timing import DebugTiming
from .cli_args import parser
top_import_finish = time.time()
def dispatch(args): # returns Deferred
if args.func == "send/send":
with args.timing.add("import", which="cmd_send"):
from . import cmd_send
return cmd_send.send(args)
if args.func == "receive/receive":
with args.timing.add("import", which="cmd_receive"):
from . import cmd_receive
return cmd_receive.receive(args)
raise ValueError("unknown args.func %s" % args.func)
def run(reactor, argv, cwd, stdout, stderr, executable=None):
"""This is invoked by entry() below, and can also be invoked directly by
tests.
"""
args = parser.parse_args(argv)
if not getattr(args, "func", None):
# So far this only works on py3. py2 exits with a really terse
# "error: too few arguments" during parse_args().
parser.print_help()
sys.exit(0)
args.cwd = cwd
args.stdout = stdout
args.stderr = stderr
args.timing = timing = DebugTiming()
timing.add("command dispatch")
timing.add("import", when=start, which="top").finish(when=top_import_finish)
# fires with None, or raises an error
d = maybeDeferred(dispatch, args)
def _maybe_dump_timing(res):
timing.add("exit")
if args.dump_timing:
timing.write(args.dump_timing, stderr)
return res
d.addBoth(_maybe_dump_timing)
def _explain_error(f):
# these errors don't print a traceback, just an explanation
f.trap(TransferError, WrongPasswordError, WelcomeError, Timeout,
KeyFormatError)
if f.check(WrongPasswordError):
msg = textwrap.fill("ERROR: " + textwrap.dedent(f.value.__doc__))
print(msg, file=stderr)
elif f.check(WelcomeError):
msg = textwrap.fill("ERROR: " + textwrap.dedent(f.value.__doc__))
print(msg, file=stderr)
print(file=stderr)
print(str(f.value), file=stderr)
elif f.check(KeyFormatError):
msg = textwrap.fill("ERROR: " + textwrap.dedent(f.value.__doc__))
print(msg, file=stderr)
else:
print("ERROR:", f.value, file=stderr)
raise SystemExit(1)
d.addErrback(_explain_error)
d.addCallback(lambda _: 0)
return d
def entry():
"""This is used by a setuptools entry_point. When invoked this way,
setuptools has already put the installed package on sys.path ."""
react(run, (sys.argv[1:], os.getcwd(), sys.stdout, sys.stderr,
sys.argv[0]))
if __name__ == "__main__":
args = parser.parse_args()
print(args)

View File

@ -47,10 +47,10 @@ class KeyFormatError(Exception):
class ReflectionAttack(Exception): class ReflectionAttack(Exception):
"""An attacker (or bug) reflected our outgoing message back to us.""" """An attacker (or bug) reflected our outgoing message back to us."""
class UsageError(Exception): class InternalError(Exception):
"""The programmer did something wrong.""" """The programmer did something wrong."""
class WormholeClosedError(UsageError): class WormholeClosedError(InternalError):
"""API calls may not be made after close() is called.""" """API calls may not be made after close() is called."""
class TransferError(Exception): class TransferError(Exception):

156
src/wormhole/server/cli.py Normal file
View File

@ -0,0 +1,156 @@
from __future__ import print_function
import click
# can put this back in to get this command as "wormhole server"
# instead
#from ..cli.cli import wormhole
#@wormhole.group()
@click.group()
@click.pass_context
def server(ctx):
"""
Control a relay server (most users shouldn't need to worry
about this and can use the default server).
"""
# just leaving this pointing to wormhole.cli.cli.Config for now,
# but if we want to keep wormhole-server as a separate command
# should probably have our own Config without all the options the
# server commands don't use
from ..cli.cli import Config
ctx.obj = Config()
@server.command()
@click.option(
"--rendezvous", default="tcp:4000", metavar="tcp:PORT",
help="endpoint specification for the rendezvous port",
)
@click.option(
"--transit", default="tcp:4001", metavar="tcp:PORT",
help="endpoint specification for the transit-relay port",
)
@click.option(
"--advertise-version", metavar="VERSION",
help="version to recommend to clients",
)
@click.option(
"--blur-usage", default=None, type=int,
metavar="SECONDS",
help="round logged access times to improve privacy",
)
@click.option(
"--no-daemon", "-n", is_flag=True,
help="Run in the foreground",
)
@click.option(
"--signal-error", is_flag=True,
help="force all clients to fail with a message",
)
@click.pass_obj
def start(cfg, signal_error, no_daemon, blur_usage, advertise_version, transit, rendezvous):
"""
Start a relay server
"""
from wormhole.server.cmd_server import start_server
cfg.no_daemon = no_daemon
cfg.blur_usage = blur_usage
cfg.advertise_version = advertise_version
cfg.transit = str(transit)
cfg.rendezvous = str(rendezvous)
cfg.signal_error = signal_error
start_server(cfg)
# XXX it would be nice to reduce the duplication between 'restart' and
# 'start' options...
@server.command()
@click.option(
"--rendezvous", default="tcp:4000", metavar="tcp:PORT",
help="endpoint specification for the rendezvous port",
)
@click.option(
"--transit", default="tcp:4001", metavar="tcp:PORT",
help="endpoint specification for the transit-relay port",
)
@click.option(
"--advertise-version", metavar="VERSION",
help="version to recommend to clients",
)
@click.option(
"--blur-usage", default=None, type=int,
metavar="SECONDS",
help="round logged access times to improve privacy",
)
@click.option(
"--no-daemon", "-n", is_flag=True,
help="Run in the foreground",
)
@click.option(
"--signal-error", is_flag=True,
help="force all clients to fail with a message",
)
@click.pass_obj
def restart(cfg, signal_error, no_daemon, blur_usage, advertise_version, transit, rendezvous):
"""
Re-start a relay server
"""
from wormhole.server.cmd_server import restart_server
cfg.no_daemon = no_daemon
cfg.blur_usage = blur_usage
cfg.advertise_version = advertise_version
cfg.transit = str(transit)
cfg.rendezvous = str(rendezvous)
cfg.signal_error = signal_error
restart_server(cfg)
@server.command()
@click.pass_obj
def stop(cfg):
"""
Stop a relay server
"""
from wormhole.server.cmd_server import stop_server
stop_server(cfg)
@server.command(name="tail-usage")
@click.pass_obj
def tail_usage(cfg):
"""
Follow the latest usage
"""
from wormhole.server.cmd_usage import tail_usage
tail_usage(cfg)
@server.command(name='count-channels')
@click.option(
"--json", is_flag=True,
)
@click.pass_obj
def count_channels(cfg, json):
"""
Count active channels
"""
from wormhole.server.cmd_usage import count_channels
cfg.json = json
count_channels(cfg)
@server.command(name='count-events')
@click.option(
"--json", is_flag=True,
)
@click.pass_obj
def count_events(cfg, json):
"""
Count events
"""
from wormhole.server.cmd_usage import count_events
cfg.json = json
count_events(cfg)

View File

@ -1,77 +0,0 @@
import argparse
from textwrap import dedent
from .. import __version__
parser = argparse.ArgumentParser(
usage="wormhole-server SUBCOMMAND (subcommand-options)",
description=dedent("""
Create a Magic Wormhole and communicate through it. Wormholes are created
by speaking the same magic CODE in two different places at the same time.
Wormholes are secure against anyone who doesn't use the same code."""),
)
parser.add_argument("--version", action="version",
version="magic-wormhole "+ __version__)
s = parser.add_subparsers(title="subcommands", dest="subcommand")
# CLI: run-server
sp_start = s.add_parser("start", description="Start a relay server",
usage="wormhole server start [opts] [TWISTD-ARGS..]")
sp_start.add_argument("--rendezvous", default="tcp:4000", metavar="tcp:PORT",
help="endpoint specification for the rendezvous port")
sp_start.add_argument("--transit", default="tcp:4001", metavar="tcp:PORT",
help="endpoint specification for the transit-relay port")
sp_start.add_argument("--advertise-version", metavar="VERSION",
help="version to recommend to clients")
sp_start.add_argument("--signal-error", metavar="ERROR",
help="force all clients to fail with a message")
sp_start.add_argument("--blur-usage", default=None, type=int,
metavar="SECONDS",
help="round logged access times to improve privacy")
sp_start.add_argument("-n", "--no-daemon", action="store_true")
#sp_start.add_argument("twistd_args", nargs="*", default=None,
# metavar="[TWISTD-ARGS..]",
# help=dedent("""\
# Additional arguments to pass to twistd"""),
# )
sp_start.set_defaults(func="server/start")
sp_stop = s.add_parser("stop", description="Stop the relay server",
usage="wormhole server stop")
sp_stop.set_defaults(func="server/stop")
sp_restart = s.add_parser("restart", description="Restart the relay server",
usage="wormhole server restart")
sp_restart.add_argument("--rendezvous", default="tcp:4000", metavar="tcp:PORT",
help="endpoint specification for the rendezvous port")
sp_restart.add_argument("--transit", default="tcp:4001", metavar="tcp:PORT",
help="endpoint specification for the transit-relay port")
sp_restart.add_argument("--advertise-version", metavar="VERSION",
help="version to recommend to clients")
sp_restart.add_argument("--signal-error", metavar="ERROR",
help="force all clients to fail with a message")
sp_restart.add_argument("--blur-usage", default=None, type=int,
metavar="SECONDS",
help="round logged access times to improve privacy")
sp_restart.add_argument("-n", "--no-daemon", action="store_true")
sp_restart.set_defaults(func="server/restart")
sp_show_usage = s.add_parser("show-usage", description="Display usage data",
usage="wormhole server show-usage")
sp_show_usage.add_argument("-n", default=100, type=int,
help="show last N entries")
sp_show_usage.set_defaults(func="usage/usage")
sp_tail_usage = s.add_parser("tail-usage", description="Follow latest usage",
usage="wormhole server tail-usage")
sp_tail_usage.set_defaults(func="usage/tail")
sp_count_channels = s.add_parser("count-channels",
description="Count active channels")
sp_count_channels.add_argument("--json", action="store_true")
sp_count_channels.set_defaults(func="usage/count-channels")
sp_count_events = s.add_parser("count-events", description="Count events")
sp_count_events.add_argument("--json", action="store_true")
sp_count_events.set_defaults(func="usage/count-events")

View File

@ -1,5 +1,6 @@
from __future__ import print_function, unicode_literals from __future__ import print_function, unicode_literals
import os, time import os, time
import click
from twisted.python import usage from twisted.python import usage
from twisted.scripts import twistd from twisted.scripts import twistd
@ -38,16 +39,17 @@ def kill_server():
try: try:
f = open("twistd.pid", "r") f = open("twistd.pid", "r")
except EnvironmentError: except EnvironmentError:
print("Unable to find twistd.pid . Is this really a server directory?") raise click.UsageError(
return 1 "Unable to find 'twistd.pid' -- is this really a server directory?"
)
pid = int(f.read().strip()) pid = int(f.read().strip())
f.close() f.close()
os.kill(pid, 15) os.kill(pid, 15)
print("server process %d sent SIGTERM" % pid) print("server process %d sent SIGTERM" % pid)
return 0 return
def stop_server(args): def stop_server(args):
return kill_server() kill_server()
def restart_server(args): def restart_server(args):
kill_server() kill_server()

View File

@ -1,8 +1,8 @@
from __future__ import print_function, unicode_literals from __future__ import print_function, unicode_literals
import os, time, json import os, time, json
from collections import defaultdict from collections import defaultdict
import click
from .database import get_db from .database import get_db
from ..errors import UsageError
def abbrev(t): def abbrev(t):
if t is None: if t is None:
@ -57,7 +57,9 @@ def show_usage(args):
print("closed for renovation") print("closed for renovation")
return 0 return 0
if not os.path.exists("relay.sqlite"): if not os.path.exists("relay.sqlite"):
raise UsageError("cannot find relay.sqlite, please run from the server directory") raise click.UsageError(
"cannot find relay.sqlite, please run from the server directory"
)
oldest = None oldest = None
newest = None newest = None
rendezvous_counters = defaultdict(int) rendezvous_counters = defaultdict(int)
@ -116,7 +118,9 @@ def show_usage(args):
def tail_usage(args): def tail_usage(args):
if not os.path.exists("relay.sqlite"): if not os.path.exists("relay.sqlite"):
raise UsageError("cannot find relay.sqlite, please run from the server directory") raise click.UsageError(
"cannot find relay.sqlite, please run from the server directory"
)
db = get_db("relay.sqlite") db = get_db("relay.sqlite")
# we don't seem to have unique row IDs, so this is an inaccurate and # we don't seem to have unique row IDs, so this is an inaccurate and
# inefficient hack # inefficient hack
@ -141,7 +145,9 @@ def tail_usage(args):
def count_channels(args): def count_channels(args):
if not os.path.exists("relay.sqlite"): if not os.path.exists("relay.sqlite"):
raise UsageError("cannot find relay.sqlite, please run from the server directory") raise click.UsageError(
"cannot find relay.sqlite, please run from the server directory"
)
db = get_db("relay.sqlite") db = get_db("relay.sqlite")
c_list = [] c_list = []
c_dict = {} c_dict = {}
@ -188,7 +194,9 @@ def count_channels(args):
def count_events(args): def count_events(args):
if not os.path.exists("relay.sqlite"): if not os.path.exists("relay.sqlite"):
raise UsageError("cannot find relay.sqlite, please run from the server directory") raise click.UsageError(
"cannot find relay.sqlite, please run from the server directory"
)
db = get_db("relay.sqlite") db = get_db("relay.sqlite")
c_list = [] c_list = []
c_dict = {} c_dict = {}

View File

@ -1,6 +1,6 @@
# no unicode_literals untill twisted update # no unicode_literals untill twisted update
from twisted.application import service from twisted.application import service
from twisted.internet import reactor, defer from twisted.internet import defer, task
from twisted.python import log from twisted.python import log
from ..transit import allocate_tcp_port from ..transit import allocate_tcp_port
from ..server.server import RelayServer from ..server.server import RelayServer
@ -36,8 +36,18 @@ class ServerBase:
# relay's .stopService() drops all connections, which ought to # relay's .stopService() drops all connections, which ought to
# encourage those threads to terminate soon. If they don't, print a # encourage those threads to terminate soon. If they don't, print a
# warning to ease debugging. # warning to ease debugging.
# XXX FIXME there's something in _noclobber test that's not
# waiting for a close, I think -- was pretty relieably getting
# unclean-reactor, but adding a slight pause here stops it...
from twisted.internet import reactor
tp = reactor.getThreadPool() tp = reactor.getThreadPool()
if not tp.working: if not tp.working:
d = defer.succeed(None)
d.addCallback(lambda _: self.sp.stopService())
d.addCallback(lambda _: task.deferLater(reactor, 0.1, lambda: None))
return d
return self.sp.stopService() return self.sp.stopService()
# disconnect all callers # disconnect all callers
d = defer.maybeDeferred(self.sp.stopService) d = defer.maybeDeferred(self.sp.stopService)

View File

@ -7,33 +7,32 @@ from twisted.internet.utils import getProcessOutputAndValue
from twisted.internet.defer import gatherResults, inlineCallbacks from twisted.internet.defer import gatherResults, inlineCallbacks
from .. import __version__ from .. import __version__
from .common import ServerBase from .common import ServerBase
from ..cli import runner, cmd_send, cmd_receive from ..cli import cmd_send, cmd_receive
from ..cli.cli import Config
from ..errors import TransferError, WrongPasswordError, WelcomeError from ..errors import TransferError, WrongPasswordError, WelcomeError
from ..timing import DebugTiming
def build_offer(args): def build_offer(args):
s = cmd_send.Sender(args, None) s = cmd_send.Sender(args, None)
return s._build_offer() return s._build_offer()
class OfferData(unittest.TestCase): class OfferData(unittest.TestCase):
def setUp(self): def setUp(self):
self._things_to_delete = [] self._things_to_delete = []
self.cfg = cfg = Config()
cfg.stdout = io.StringIO()
cfg.stderr = io.StringIO()
def tearDown(self): def tearDown(self):
for fn in self._things_to_delete: for fn in self._things_to_delete:
if os.path.exists(fn): if os.path.exists(fn):
os.unlink(fn) os.unlink(fn)
del self.cfg
def test_text(self): def test_text(self):
message = "blah blah blah ponies" self.cfg.text = message = "blah blah blah ponies"
d, fd_to_send = build_offer(self.cfg)
send_args = [ "send", "--text", message ]
args = runner.parser.parse_args(send_args)
args.cwd = os.getcwd()
args.stdout = io.StringIO()
args.stderr = io.StringIO()
d, fd_to_send = build_offer(args)
self.assertIn("message", d) self.assertIn("message", d)
self.assertNotIn("file", d) self.assertNotIn("file", d)
@ -42,7 +41,7 @@ class OfferData(unittest.TestCase):
self.assertEqual(fd_to_send, None) self.assertEqual(fd_to_send, None)
def test_file(self): def test_file(self):
filename = "my file" self.cfg.what = filename = "my file"
message = b"yay ponies\n" message = b"yay ponies\n"
send_dir = self.mktemp() send_dir = self.mktemp()
os.mkdir(send_dir) os.mkdir(send_dir)
@ -50,13 +49,8 @@ class OfferData(unittest.TestCase):
with open(abs_filename, "wb") as f: with open(abs_filename, "wb") as f:
f.write(message) f.write(message)
send_args = [ "send", filename ] self.cfg.cwd = send_dir
args = runner.parser.parse_args(send_args) d, fd_to_send = build_offer(self.cfg)
args.cwd = send_dir
args.stdout = io.StringIO()
args.stderr = io.StringIO()
d, fd_to_send = build_offer(args)
self.assertNotIn("message", d) self.assertNotIn("message", d)
self.assertIn("file", d) self.assertIn("file", d)
@ -67,17 +61,12 @@ class OfferData(unittest.TestCase):
self.assertEqual(fd_to_send.read(), message) self.assertEqual(fd_to_send.read(), message)
def test_missing_file(self): def test_missing_file(self):
filename = "missing" self.cfg.what = filename = "missing"
send_dir = self.mktemp() send_dir = self.mktemp()
os.mkdir(send_dir) os.mkdir(send_dir)
self.cfg.cwd = send_dir
send_args = [ "send", filename ] e = self.assertRaises(TransferError, build_offer, self.cfg)
args = runner.parser.parse_args(send_args)
args.cwd = send_dir
args.stdout = io.StringIO()
args.stderr = io.StringIO()
e = self.assertRaises(TransferError, build_offer, args)
self.assertEqual(str(e), self.assertEqual(str(e),
"Cannot send: no file/directory named '%s'" % filename) "Cannot send: no file/directory named '%s'" % filename)
@ -94,13 +83,10 @@ class OfferData(unittest.TestCase):
send_dir_arg = send_dir send_dir_arg = send_dir
if addslash: if addslash:
send_dir_arg += os.sep send_dir_arg += os.sep
send_args = [ "send", send_dir_arg ] self.cfg.what = send_dir_arg
args = runner.parser.parse_args(send_args) self.cfg.cwd = parent_dir
args.cwd = parent_dir
args.stdout = io.StringIO()
args.stderr = io.StringIO()
d, fd_to_send = build_offer(args) d, fd_to_send = build_offer(self.cfg)
self.assertNotIn("message", d) self.assertNotIn("message", d)
self.assertNotIn("file", d) self.assertNotIn("file", d)
@ -130,10 +116,11 @@ class OfferData(unittest.TestCase):
return self._do_test_directory(addslash=True) return self._do_test_directory(addslash=True)
def test_unknown(self): def test_unknown(self):
filename = "unknown" self.cfg.what = filename = "unknown"
send_dir = self.mktemp() send_dir = self.mktemp()
os.mkdir(send_dir) os.mkdir(send_dir)
abs_filename = os.path.abspath(os.path.join(send_dir, filename)) abs_filename = os.path.abspath(os.path.join(send_dir, filename))
self.cfg.cwd = send_dir
try: try:
os.mkfifo(abs_filename) os.mkfifo(abs_filename)
@ -149,13 +136,7 @@ class OfferData(unittest.TestCase):
self.assertFalse(os.path.isfile(abs_filename)) self.assertFalse(os.path.isfile(abs_filename))
self.assertFalse(os.path.isdir(abs_filename)) self.assertFalse(os.path.isdir(abs_filename))
send_args = [ "send", filename ] e = self.assertRaises(TypeError, build_offer, self.cfg)
args = runner.parser.parse_args(send_args)
args.cwd = send_dir
args.stdout = io.StringIO()
args.stderr = io.StringIO()
e = self.assertRaises(TypeError, build_offer, args)
self.assertEqual(str(e), self.assertEqual(str(e),
"'%s' is neither file nor directory" % filename) "'%s' is neither file nor directory" % filename)
@ -203,27 +184,24 @@ class ScriptVersion(ServerBase, ScriptsBase, unittest.TestCase):
# we need Twisted to run the server, but we run the sender and receiver # we need Twisted to run the server, but we run the sender and receiver
# with deferToThread() # with deferToThread()
@inlineCallbacks
def test_version(self): def test_version(self):
# "wormhole" must be on the path, so e.g. "pip install -e ." in a # "wormhole" must be on the path, so e.g. "pip install -e ." in a
# virtualenv. This guards against an environment where the tests # virtualenv. This guards against an environment where the tests
# below might run the wrong executable. # below might run the wrong executable.
self.maxDiff = None
wormhole = self.find_executable() wormhole = self.find_executable()
d = getProcessOutputAndValue(wormhole, ["--version"]) # we must pass on the environment so that "something" doesn't
def _check(res): # get sad about UTF8 vs. ascii encodings
out, err, rc = res out, err, rc = yield getProcessOutputAndValue(wormhole, ["--version"], env=os.environ)
# argparse on py2 and py3.3 sends --version to stderr err = err.decode("utf-8")
# argparse on py3.4/py3.5 sends --version to stdout if "DistributionNotFound" in err:
# aargh log.msg("stderr was %s" % err)
err = err.decode("utf-8") last = err.strip().split("\n")[-1]
if "DistributionNotFound" in err: self.fail("wormhole not runnable: %s" % last)
log.msg("stderr was %s" % err) ver = out.decode("utf-8") or err
last = err.strip().split("\n")[-1] self.failUnlessEqual(ver.strip(), "magic-wormhole {}".format(__version__))
self.fail("wormhole not runnable: %s" % last) self.failUnlessEqual(rc, 0)
ver = out.decode("utf-8") or err
self.failUnlessEqual(ver, "magic-wormhole "+__version__+os.linesep)
self.failUnlessEqual(rc, 0)
d.addCallback(_check)
return d
class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase): class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase):
# we need Twisted to run the server, but we run the sender and receiver # we need Twisted to run the server, but we run the sender and receiver
@ -238,20 +216,18 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase):
def _do_test(self, as_subprocess=False, def _do_test(self, as_subprocess=False,
mode="text", addslash=False, override_filename=False): mode="text", addslash=False, override_filename=False):
assert mode in ("text", "file", "directory") assert mode in ("text", "file", "directory")
common_args = ["--hide-progress", send_cfg = Config()
"--relay-url", self.relayurl, recv_cfg = Config()
"--transit-helper", ""] message = "blah blah blah ponies"
code = "1-abc"
message = "test message"
send_args = common_args + [ for cfg in [send_cfg, recv_cfg]:
"send", cfg.hide_progress = True
"--code", code, cfg.relay_url = self.relayurl
] cfg.transit_helper = ""
cfg.listen = True
receive_args = common_args + [ cfg.code = "1-abc"
"receive", cfg.stdout = io.StringIO()
] cfg.stderr = io.StringIO()
send_dir = self.mktemp() send_dir = self.mktemp()
os.mkdir(send_dir) os.mkdir(send_dir)
@ -259,19 +235,18 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase):
os.mkdir(receive_dir) os.mkdir(receive_dir)
if mode == "text": if mode == "text":
send_args.extend(["--text", message]) send_cfg.text = message
elif mode == "file": elif mode == "file":
send_filename = "testfile" send_filename = "testfile"
with open(os.path.join(send_dir, send_filename), "w") as f: with open(os.path.join(send_dir, send_filename), "w") as f:
f.write(message) f.write(message)
send_args.append(send_filename) send_cfg.what = send_filename
receive_filename = send_filename receive_filename = send_filename
receive_args.append("--accept-file") recv_cfg.accept_file = True
if override_filename: if override_filename:
receive_args.extend(["-o", "outfile"]) recv_cfg.output_file = receive_filename = "outfile"
receive_filename = "outfile"
elif mode == "directory": elif mode == "directory":
# $send_dir/ # $send_dir/
@ -299,22 +274,48 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase):
send_dirname_arg = os.path.join("middle", send_dirname) send_dirname_arg = os.path.join("middle", send_dirname)
if addslash: if addslash:
send_dirname_arg += os.sep send_dirname_arg += os.sep
send_args.append(send_dirname_arg) send_cfg.what = send_dirname_arg
receive_dirname = send_dirname receive_dirname = send_dirname
receive_args.append("--accept-file") recv_cfg.accept_file = True
if override_filename: if override_filename:
receive_args.extend(["-o", "outdir"]) recv_cfg.output_file = receive_dirname = "outdir"
receive_dirname = "outdir"
receive_args.append(code)
if as_subprocess: if as_subprocess:
wormhole_bin = self.find_executable() wormhole_bin = self.find_executable()
send_d = getProcessOutputAndValue(wormhole_bin, send_args, if send_cfg.text:
path=send_dir) content_args = ['--text', send_cfg.text]
receive_d = getProcessOutputAndValue(wormhole_bin, receive_args, elif send_cfg.what:
path=receive_dir) content_args = [send_cfg.what]
send_args = [
'--hide-progress',
'--relay-url', self.relayurl,
'--transit-helper', '',
'send',
'--code', send_cfg.code,
] + content_args
send_d = getProcessOutputAndValue(
wormhole_bin, send_args,
path=send_dir,
)
recv_args = [
'--hide-progress',
'--relay-url', self.relayurl,
'--transit-helper', '',
'receive',
'--accept-file',
recv_cfg.code,
]
if override_filename:
recv_args.extend(['-o', receive_filename])
receive_d = getProcessOutputAndValue(
wormhole_bin, recv_args,
path=receive_dir,
)
(send_res, receive_res) = yield gatherResults([send_d, receive_d], (send_res, receive_res) = yield gatherResults([send_d, receive_d],
True) True)
send_stdout = send_res[0].decode("utf-8") send_stdout = send_res[0].decode("utf-8")
@ -327,27 +328,20 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase):
self.assertEqual((send_rc, receive_rc), (0, 0), self.assertEqual((send_rc, receive_rc), (0, 0),
(send_res, receive_res)) (send_res, receive_res))
else: else:
sargs = runner.parser.parse_args(send_args) send_cfg.cwd = send_dir
sargs.cwd = send_dir send_d = cmd_send.send(send_cfg)
sargs.stdout = io.StringIO()
sargs.stderr = io.StringIO() recv_cfg.cwd = receive_dir
sargs.timing = DebugTiming() receive_d = cmd_receive.receive(recv_cfg)
rargs = runner.parser.parse_args(receive_args)
rargs.cwd = receive_dir
rargs.stdout = io.StringIO()
rargs.stderr = io.StringIO()
rargs.timing = DebugTiming()
send_d = cmd_send.send(sargs)
receive_d = cmd_receive.receive(rargs)
# The sender might fail, leaving the receiver hanging, or vice # The sender might fail, leaving the receiver hanging, or vice
# versa. Make sure we don't wait on one side exclusively # versa. Make sure we don't wait on one side exclusively
yield gatherResults([send_d, receive_d], True) yield gatherResults([send_d, receive_d], True)
send_stdout = sargs.stdout.getvalue() send_stdout = send_cfg.stdout.getvalue()
send_stderr = sargs.stderr.getvalue() send_stderr = send_cfg.stderr.getvalue()
receive_stdout = rargs.stdout.getvalue() receive_stdout = recv_cfg.stdout.getvalue()
receive_stderr = rargs.stderr.getvalue() receive_stderr = recv_cfg.stderr.getvalue()
# all output here comes from a StringIO, which uses \n for # all output here comes from a StringIO, which uses \n for
# newlines, even if we're on windows # newlines, even if we're on windows
@ -367,7 +361,7 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase):
"wormhole receive{NL}" "wormhole receive{NL}"
"Wormhole code is: {code}{NL}{NL}" "Wormhole code is: {code}{NL}{NL}"
"text message sent{NL}").format(bytes=len(message), "text message sent{NL}").format(bytes=len(message),
code=code, code=send_cfg.code,
NL=NL) NL=NL)
self.failUnlessEqual(send_stdout, expected) self.failUnlessEqual(send_stdout, expected)
elif mode == "file": elif mode == "file":
@ -377,7 +371,7 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase):
self.failUnlessIn("On the other computer, please run: " self.failUnlessIn("On the other computer, please run: "
"wormhole receive{NL}" "wormhole receive{NL}"
"Wormhole code is: {code}{NL}{NL}" "Wormhole code is: {code}{NL}{NL}"
.format(code=code, NL=NL), .format(code=send_cfg.code, NL=NL),
send_stdout) send_stdout)
self.failUnlessIn("File sent.. waiting for confirmation{NL}" self.failUnlessIn("File sent.. waiting for confirmation{NL}"
"Confirmation received. Transfer complete.{NL}" "Confirmation received. Transfer complete.{NL}"
@ -388,7 +382,7 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase):
self.failUnlessIn("On the other computer, please run: " self.failUnlessIn("On the other computer, please run: "
"wormhole receive{NL}" "wormhole receive{NL}"
"Wormhole code is: {code}{NL}{NL}" "Wormhole code is: {code}{NL}{NL}"
.format(code=code, NL=NL), send_stdout) .format(code=send_cfg.code, NL=NL), send_stdout)
self.failUnlessIn("File sent.. waiting for confirmation{NL}" self.failUnlessIn("File sent.. waiting for confirmation{NL}"
"Confirmation received. Transfer complete.{NL}" "Confirmation received. Transfer complete.{NL}"
.format(NL=NL), send_stdout) .format(NL=NL), send_stdout)
@ -440,14 +434,21 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase):
@inlineCallbacks @inlineCallbacks
def test_file_noclobber(self): def test_file_noclobber(self):
common_args = ["--hide-progress", "--no-listen", send_cfg = Config()
"--relay-url", self.relayurl, recv_cfg = Config()
"--transit-helper", ""]
code = "1-abc" for cfg in [send_cfg, recv_cfg]:
cfg.hide_progress = True
cfg.relay_url = self.relayurl
cfg.transit_helper = ""
cfg.listen = False
cfg.code = code = "1-abc"
cfg.stdout = io.StringIO()
cfg.stderr = io.StringIO()
message = "test message" message = "test message"
send_args = common_args + [ "send", "--code", code ] recv_cfg.accept_file = True
receive_args = common_args + [ "receive", "--accept-file", code ]
send_dir = self.mktemp() send_dir = self.mktemp()
os.mkdir(send_dir) os.mkdir(send_dir)
@ -457,26 +458,19 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase):
send_filename = "testfile" send_filename = "testfile"
with open(os.path.join(send_dir, send_filename), "w") as f: with open(os.path.join(send_dir, send_filename), "w") as f:
f.write(message) f.write(message)
send_args.append(send_filename) send_cfg.what = receive_filename = send_filename
receive_filename = send_filename recv_cfg.what = receive_filename
PRESERVE = "don't clobber me\n" PRESERVE = "don't clobber me\n"
clobberable = os.path.join(receive_dir, receive_filename) clobberable = os.path.join(receive_dir, receive_filename)
with open(clobberable, "w") as f: with open(clobberable, "w") as f:
f.write(PRESERVE) f.write(PRESERVE)
sargs = runner.parser.parse_args(send_args) send_cfg.cwd = send_dir
sargs.cwd = send_dir send_d = cmd_send.send(send_cfg)
sargs.stdout = io.StringIO()
sargs.stderr = io.StringIO() recv_cfg.cwd = receive_dir
sargs.timing = DebugTiming() receive_d = cmd_receive.receive(recv_cfg)
rargs = runner.parser.parse_args(receive_args)
rargs.cwd = receive_dir
rargs.stdout = io.StringIO()
rargs.stderr = io.StringIO()
rargs.timing = DebugTiming()
send_d = cmd_send.send(sargs)
receive_d = cmd_receive.receive(rargs)
# both sides will fail because of the pre-existing file # both sides will fail because of the pre-existing file
@ -486,10 +480,10 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase):
f = yield self.assertFailure(receive_d, TransferError) f = yield self.assertFailure(receive_d, TransferError)
self.assertEqual(str(f), "file already exists") self.assertEqual(str(f), "file already exists")
send_stdout = sargs.stdout.getvalue() send_stdout = send_cfg.stdout.getvalue()
send_stderr = sargs.stderr.getvalue() send_stderr = send_cfg.stderr.getvalue()
receive_stdout = rargs.stdout.getvalue() receive_stdout = recv_cfg.stdout.getvalue()
receive_stderr = rargs.stderr.getvalue() receive_stderr = recv_cfg.stderr.getvalue()
# all output here comes from a StringIO, which uses \n for # all output here comes from a StringIO, which uses \n for
# newlines, even if we're on windows # newlines, even if we're on windows
@ -528,62 +522,54 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase):
class NotWelcome(ServerBase, unittest.TestCase): class NotWelcome(ServerBase, unittest.TestCase):
def setUp(self): def setUp(self):
self._setup_relay(error="please upgrade XYZ") self._setup_relay(error="please upgrade XYZ")
self.cfg = cfg = Config()
cfg.hide_progress = True
cfg.listen = False
cfg.relay_url = self.relayurl
cfg.transit_helper = ""
cfg.stdout = io.StringIO()
cfg.stderr = io.StringIO()
@inlineCallbacks @inlineCallbacks
def test_sender(self): def test_sender(self):
common_args = ["--hide-progress", "--no-listen", self.cfg.text = "hi"
"--relay-url", self.relayurl, self.cfg.code = "1-abc"
"--transit-helper", ""]
send_args = common_args + [ "send", "--text", "hi",
"--code", "1-abc" ]
sargs = runner.parser.parse_args(send_args)
sargs.cwd = self.mktemp()
sargs.stdout = io.StringIO()
sargs.stderr = io.StringIO()
sargs.timing = DebugTiming()
send_d = cmd_send.send(sargs) send_d = cmd_send.send(self.cfg)
f = yield self.assertFailure(send_d, WelcomeError) f = yield self.assertFailure(send_d, WelcomeError)
self.assertEqual(str(f), "please upgrade XYZ") self.assertEqual(str(f), "please upgrade XYZ")
@inlineCallbacks @inlineCallbacks
def test_receiver(self): def test_receiver(self):
common_args = ["--hide-progress", "--no-listen", self.cfg.code = "1-abc"
"--relay-url", self.relayurl,
"--transit-helper", ""]
receive_args = common_args + [ "receive", "1-abc" ]
rargs = runner.parser.parse_args(receive_args)
rargs.cwd = self.mktemp()
rargs.stdout = io.StringIO()
rargs.stderr = io.StringIO()
rargs.timing = DebugTiming()
receive_d = cmd_receive.receive(rargs) receive_d = cmd_receive.receive(self.cfg)
f = yield self.assertFailure(receive_d, WelcomeError) f = yield self.assertFailure(receive_d, WelcomeError)
self.assertEqual(str(f), "please upgrade XYZ") self.assertEqual(str(f), "please upgrade XYZ")
class Cleanup(ServerBase, unittest.TestCase): class Cleanup(ServerBase, unittest.TestCase):
def setUp(self):
d = super(Cleanup, self).setUp()
self.cfg = cfg = Config()
# common options for all tests in this suite
cfg.hide_progress = True
cfg.relay_url = self.relayurl
cfg.transit_helper = ""
cfg.stdout = io.StringIO()
cfg.stderr = io.StringIO()
return d
@inlineCallbacks @inlineCallbacks
def test_text(self): @mock.patch('sys.stdout')
def test_text(self, stdout):
# the rendezvous channel should be deleted after success # the rendezvous channel should be deleted after success
code = "1-abc" self.cfg.text = "hello"
common_args = ["--hide-progress", self.cfg.code = "1-abc"
"--relay-url", self.relayurl,
"--transit-helper", ""] send_d = cmd_send.send(self.cfg)
sargs = runner.parser.parse_args(common_args + receive_d = cmd_receive.receive(self.cfg)
["send",
"--text", "secret message",
"--code", code])
sargs.stdout = io.StringIO()
sargs.stderr = io.StringIO()
sargs.timing = DebugTiming()
rargs = runner.parser.parse_args(common_args +
["receive", code])
rargs.stdout = io.StringIO()
rargs.stderr = io.StringIO()
rargs.timing = DebugTiming()
send_d = cmd_send.send(sargs)
receive_d = cmd_receive.receive(rargs)
yield send_d yield send_d
yield receive_d yield receive_d
@ -595,23 +581,12 @@ class Cleanup(ServerBase, unittest.TestCase):
def test_text_wrong_password(self): def test_text_wrong_password(self):
# if the password was wrong, the rendezvous channel should still be # if the password was wrong, the rendezvous channel should still be
# deleted # deleted
common_args = ["--hide-progress", self.cfg.text = "secret message"
"--relay-url", self.relayurl, self.cfg.code = "1-abc"
"--transit-helper", ""] send_d = cmd_send.send(self.cfg)
sargs = runner.parser.parse_args(common_args +
["send", self.cfg.code = "1-WRONG"
"--text", "secret message", receive_d = cmd_receive.receive(self.cfg)
"--code", "1-abc"])
sargs.stdout = io.StringIO()
sargs.stderr = io.StringIO()
sargs.timing = DebugTiming()
rargs = runner.parser.parse_args(common_args +
["receive", "1-WRONG"])
rargs.stdout = io.StringIO()
rargs.stderr = io.StringIO()
rargs.timing = DebugTiming()
send_d = cmd_send.send(sargs)
receive_d = cmd_receive.receive(rargs)
# both sides should be capable of detecting the mismatch # both sides should be capable of detecting the mismatch
yield self.assertFailure(send_d, WrongPasswordError) yield self.assertFailure(send_d, WrongPasswordError)

View File

@ -6,8 +6,8 @@ from twisted.internet import defer, task, endpoints, protocol, address, error
from twisted.internet.defer import gatherResults, inlineCallbacks from twisted.internet.defer import gatherResults, inlineCallbacks
from twisted.python import log, failure from twisted.python import log, failure
from twisted.test import proto_helpers from twisted.test import proto_helpers
from ..errors import InternalError
from .. import transit from .. import transit
from ..errors import UsageError
from nacl.secret import SecretBox from nacl.secret import SecretBox
from nacl.exceptions import CryptoError from nacl.exceptions import CryptoError
@ -147,7 +147,7 @@ class Basic(unittest.TestCase):
"hostname": "host", "hostname": "host",
"port": 1234}], "port": 1234}],
}]) }])
self.assertRaises(UsageError, transit.Common, 123) self.assertRaises(InternalError, transit.Common, 123)
@inlineCallbacks @inlineCallbacks
def test_no_relay_hints(self): def test_no_relay_hints(self):

View File

@ -7,7 +7,7 @@ from twisted.internet import reactor
from twisted.internet.defer import Deferred, gatherResults, inlineCallbacks from twisted.internet.defer import Deferred, gatherResults, inlineCallbacks
from .common import ServerBase from .common import ServerBase
from .. import wormhole from .. import wormhole
from ..errors import (WrongPasswordError, WelcomeError, UsageError, from ..errors import (WrongPasswordError, WelcomeError, InternalError,
KeyFormatError) KeyFormatError)
from spake2 import SPAKE2_Symmetric from spake2 import SPAKE2_Symmetric
from ..timing import DebugTiming from ..timing import DebugTiming
@ -876,20 +876,20 @@ class Errors(ServerBase, unittest.TestCase):
def test_codes_1(self): def test_codes_1(self):
w = wormhole.wormhole(APPID, self.relayurl, reactor) w = wormhole.wormhole(APPID, self.relayurl, reactor)
# definitely too early # definitely too early
self.assertRaises(UsageError, w.derive_key, "purpose", 12) self.assertRaises(InternalError, w.derive_key, "purpose", 12)
w.set_code("123-purple-elephant") w.set_code("123-purple-elephant")
# code can only be set once # code can only be set once
self.assertRaises(UsageError, w.set_code, "123-nope") self.assertRaises(InternalError, w.set_code, "123-nope")
yield self.assertFailure(w.get_code(), UsageError) yield self.assertFailure(w.get_code(), InternalError)
yield self.assertFailure(w.input_code(), UsageError) yield self.assertFailure(w.input_code(), InternalError)
yield w.close() yield w.close()
@inlineCallbacks @inlineCallbacks
def test_codes_2(self): def test_codes_2(self):
w = wormhole.wormhole(APPID, self.relayurl, reactor) w = wormhole.wormhole(APPID, self.relayurl, reactor)
yield w.get_code() yield w.get_code()
self.assertRaises(UsageError, w.set_code, "123-nope") self.assertRaises(InternalError, w.set_code, "123-nope")
yield self.assertFailure(w.get_code(), UsageError) yield self.assertFailure(w.get_code(), InternalError)
yield self.assertFailure(w.input_code(), UsageError) yield self.assertFailure(w.input_code(), InternalError)
yield w.close() yield w.close()

View File

@ -13,7 +13,7 @@ from twisted.internet.defer import inlineCallbacks, returnValue
from twisted.protocols import policies from twisted.protocols import policies
from nacl.secret import SecretBox from nacl.secret import SecretBox
from hkdf import Hkdf from hkdf import Hkdf
from .errors import UsageError from .errors import InternalError
from .timing import DebugTiming from .timing import DebugTiming
from . import ipaddrs from . import ipaddrs
@ -276,7 +276,7 @@ class Connection(protocol.Protocol, policies.TimeoutMixin):
return self._description return self._description
def send_record(self, record): def send_record(self, record):
if not isinstance(record, type(b"")): raise UsageError if not isinstance(record, type(b"")): raise InternalError
assert SecretBox.NONCE_SIZE == 24 assert SecretBox.NONCE_SIZE == 24
assert self.send_nonce < 2**(8*24) assert self.send_nonce < 2**(8*24)
assert len(record) < 2**(8*4) assert len(record) < 2**(8*4)
@ -577,7 +577,7 @@ class Common:
reactor=reactor, timing=None): reactor=reactor, timing=None):
if transit_relay: if transit_relay:
if not isinstance(transit_relay, type(u"")): if not isinstance(transit_relay, type(u"")):
raise UsageError raise InternalError
relay = RelayV1Hint(hints=[parse_hint_argv(transit_relay)]) relay = RelayV1Hint(hints=[parse_hint_argv(transit_relay)])
self._transit_relays = [relay] self._transit_relays = [relay]
else: else:
@ -813,6 +813,9 @@ class Common:
is_relay=True) is_relay=True)
contenders.append(d) contenders.append(d)
if not contenders:
raise TransitError("No contenders for connection")
winner = there_can_be_only_one(contenders) winner = there_can_be_only_one(contenders)
return self._not_forever(2*TIMEOUT, winner) return self._not_forever(2*TIMEOUT, winner)

View File

@ -14,7 +14,7 @@ from hashlib import sha256
from . import __version__ from . import __version__
from . import codes from . import codes
#from .errors import ServerError, Timeout #from .errors import ServerError, Timeout
from .errors import (WrongPasswordError, UsageError, WelcomeError, from .errors import (WrongPasswordError, InternalError, WelcomeError,
WormholeClosedError, KeyFormatError) WormholeClosedError, KeyFormatError)
from .timing import DebugTiming from .timing import DebugTiming
from .util import (to_bytes, bytes_to_hexstr, hexstr_to_bytes, from .util import (to_bytes, bytes_to_hexstr, hexstr_to_bytes,
@ -425,8 +425,8 @@ class _Wormhole:
# entry point 1: generate a new code # entry point 1: generate a new code
@inlineCallbacks @inlineCallbacks
def _API_get_code(self, code_length): def _API_get_code(self, code_length):
if self._code is not None: raise UsageError if self._code is not None: raise InternalError
if self._started_get_code: raise UsageError if self._started_get_code: raise InternalError
self._started_get_code = True self._started_get_code = True
with self._timing.add("API get_code"): with self._timing.add("API get_code"):
yield self._when_connected() yield self._when_connected()
@ -443,8 +443,8 @@ class _Wormhole:
# entry point 2: interactively type in a code, with completion # entry point 2: interactively type in a code, with completion
@inlineCallbacks @inlineCallbacks
def _API_input_code(self, prompt, code_length): def _API_input_code(self, prompt, code_length):
if self._code is not None: raise UsageError if self._code is not None: raise InternalError
if self._started_input_code: raise UsageError if self._started_input_code: raise InternalError
self._started_input_code = True self._started_input_code = True
with self._timing.add("API input_code"): with self._timing.add("API input_code"):
yield self._when_connected() yield self._when_connected()
@ -462,8 +462,10 @@ class _Wormhole:
# entry point 3: paste in a fully-formed code # entry point 3: paste in a fully-formed code
def _API_set_code(self, code): def _API_set_code(self, code):
self._timing.add("API set_code") self._timing.add("API set_code")
if not isinstance(code, type("")): raise TypeError(type(code)) if not isinstance(code, type(u"")):
if self._code is not None: raise UsageError raise TypeError("Unexpected code type '{}'".format(type(code)))
if self._code is not None:
raise InternalError
self._event_learned_code(code) self._event_learned_code(code)
# TODO: entry point 4: restore pre-contact saved state (we haven't heard # TODO: entry point 4: restore pre-contact saved state (we haven't heard
@ -526,7 +528,7 @@ class _Wormhole:
self._event_learned_mailbox() self._event_learned_mailbox()
def _event_learned_mailbox(self): def _event_learned_mailbox(self):
if not self._mailbox_id: raise UsageError if not self._mailbox_id: raise InternalError
assert self._mailbox_state == CLOSED, self._mailbox_state assert self._mailbox_state == CLOSED, self._mailbox_state
if self._closing: if self._closing:
return return
@ -578,7 +580,7 @@ class _Wormhole:
def _API_verify(self): def _API_verify(self):
if self._error: return defer.fail(self._error) if self._error: return defer.fail(self._error)
if self._get_verifier_called: raise UsageError if self._get_verifier_called: raise InternalError
self._get_verifier_called = True self._get_verifier_called = True
if self._verify_result: if self._verify_result:
return defer.succeed(self._verify_result) # bytes or Failure return defer.succeed(self._verify_result) # bytes or Failure
@ -683,7 +685,7 @@ class _Wormhole:
return box.encrypt(data, nonce) return box.encrypt(data, nonce)
def _msg_send(self, phase, body): def _msg_send(self, phase, body):
if phase in self._sent_phases: raise UsageError if phase in self._sent_phases: raise InternalError
assert self._mailbox_state == OPEN, self._mailbox_state assert self._mailbox_state == OPEN, self._mailbox_state
self._sent_phases.add(phase) self._sent_phases.add(phase)
# TODO: retry on failure, with exponential backoff. We're guarding # TODO: retry on failure, with exponential backoff. We're guarding
@ -700,14 +702,14 @@ class _Wormhole:
def _API_derive_key(self, purpose, length): def _API_derive_key(self, purpose, length):
if self._error: raise self._error if self._error: raise self._error
if self._key is None: if self._key is None:
raise UsageError # call derive_key after get_verifier() or get() raise InternalError # call derive_key after get_verifier() or get()
if not isinstance(purpose, type("")): raise TypeError(type(purpose)) if not isinstance(purpose, type("")): raise TypeError(type(purpose))
return self._derive_key(to_bytes(purpose), length) return self._derive_key(to_bytes(purpose), length)
def _derive_key(self, purpose, length=SecretBox.KEY_SIZE): def _derive_key(self, purpose, length=SecretBox.KEY_SIZE):
if not isinstance(purpose, type(b"")): raise TypeError(type(purpose)) if not isinstance(purpose, type(b"")): raise TypeError(type(purpose))
if self._key is None: if self._key is None:
raise UsageError # call derive_key after get_verifier() or get() raise InternalError # call derive_key after get_verifier() or get()
return HKDF(self._key, length, CTXinfo=purpose) return HKDF(self._key, length, CTXinfo=purpose)
def _response_handle_message(self, msg): def _response_handle_message(self, msg):
@ -782,7 +784,7 @@ class _Wormhole:
@inlineCallbacks @inlineCallbacks
def _API_close(self, res, mood="happy"): def _API_close(self, res, mood="happy"):
if self.DEBUG: print("close") if self.DEBUG: print("close")
if self._close_called: raise UsageError if self._close_called: raise InternalError
self._close_called = True self._close_called = True
self._maybe_close(WormholeClosedError(), mood) self._maybe_close(WormholeClosedError(), mood)
if self.DEBUG: print("waiting for disconnect") if self.DEBUG: print("waiting for disconnect")

View File

@ -4,7 +4,7 @@
# and then run "tox" from this directory. # and then run "tox" from this directory.
[tox] [tox]
envlist = {py27,py33,py34,py35} envlist = {py27,py33,py34,py35,pypy}
skip_missing_interpreters = True skip_missing_interpreters = True
# On windows we need "pypiwin32" installed. It's supposedly possible to make # On windows we need "pypiwin32" installed. It's supposedly possible to make