From 12dcd6a1848397008057d3a3e88070ee9919e84a Mon Sep 17 00:00:00 2001 From: Vasudev Kamath Date: Sat, 21 Apr 2018 13:00:08 +0530 Subject: [PATCH 1/7] Make code pep-8 compliant --- src/wormhole/__init__.py | 7 +- src/wormhole/__main__.py | 6 +- src/wormhole/_allocator.py | 62 ++- src/wormhole/_boss.py | 171 ++++--- src/wormhole/_code.py | 74 ++- src/wormhole/_input.py | 251 ++++++---- src/wormhole/_interfaces.py | 44 +- src/wormhole/_key.py | 115 +++-- src/wormhole/_lister.py | 74 +-- src/wormhole/_mailbox.py | 95 ++-- src/wormhole/_nameplate.py | 102 ++-- src/wormhole/_order.py | 34 +- src/wormhole/_receive.py | 66 ++- src/wormhole/_rendezvous.py | 65 +-- src/wormhole/_rlcompleter.py | 57 ++- src/wormhole/_send.py | 35 +- src/wormhole/_terminator.py | 82 ++-- src/wormhole/_wordlist.py | 404 ++++++++++------ src/wormhole/cli/cli.py | 189 +++++--- src/wormhole/cli/cmd_receive.py | 167 ++++--- src/wormhole/cli/cmd_send.py | 165 ++++--- src/wormhole/cli/cmd_ssh.py | 15 +- src/wormhole/cli/public_relay.py | 1 - src/wormhole/cli/welcome.py | 25 +- src/wormhole/errors.py | 29 ++ src/wormhole/eventual.py | 5 +- src/wormhole/ipaddrs.py | 46 +- src/wormhole/journal.py | 12 +- src/wormhole/observer.py | 7 +- src/wormhole/test/common.py | 45 +- src/wormhole/test/run_trial.py | 1 + src/wormhole/test/test_args.py | 12 +- src/wormhole/test/test_cli.py | 412 ++++++++++------- src/wormhole/test/test_eventual.py | 21 +- src/wormhole/test/test_hkdf.py | 59 +-- src/wormhole/test/test_ipaddrs.py | 46 +- src/wormhole/test/test_journal.py | 5 +- src/wormhole/test/test_machines.py | 640 +++++++++++++++----------- src/wormhole/test/test_observer.py | 5 +- src/wormhole/test/test_rlcompleter.py | 151 +++--- src/wormhole/test/test_ssh.py | 32 +- src/wormhole/test/test_tor_manager.py | 115 +++-- src/wormhole/test/test_transit.py | 507 ++++++++++++-------- src/wormhole/test/test_util.py | 16 +- src/wormhole/test/test_wordlist.py | 23 +- src/wormhole/test/test_wormhole.py | 100 ++-- src/wormhole/test/test_xfer_util.py | 28 +- src/wormhole/timing.py | 23 +- src/wormhole/tor_manager.py | 49 +- src/wormhole/transit.py | 256 +++++++---- src/wormhole/util.py | 14 +- src/wormhole/wormhole.py | 132 ++++-- src/wormhole/xfer_util.py | 62 +-- 53 files changed, 3260 insertions(+), 1899 deletions(-) diff --git a/src/wormhole/__init__.py b/src/wormhole/__init__.py index c00af56..0cc6a01 100644 --- a/src/wormhole/__init__.py +++ b/src/wormhole/__init__.py @@ -1,9 +1,4 @@ - -from ._version import get_versions -__version__ = get_versions()['version'] -del get_versions - -from .wormhole import create from ._rlcompleter import input_with_completion +from .wormhole import create, __version__ __all__ = ["create", "input_with_completion", "__version__"] diff --git a/src/wormhole/__main__.py b/src/wormhole/__main__.py index a47f2e5..568abf3 100644 --- a/src/wormhole/__main__.py +++ b/src/wormhole/__main__.py @@ -1,8 +1,6 @@ +from .cli import cli + if __name__ != "__main__": raise ImportError('this module should not be imported') - -from .cli import cli - - cli.wormhole() diff --git a/src/wormhole/_allocator.py b/src/wormhole/_allocator.py index 3c98f10..554b101 100644 --- a/src/wormhole/_allocator.py +++ b/src/wormhole/_allocator.py @@ -1,56 +1,78 @@ -from __future__ import print_function, absolute_import, unicode_literals -from zope.interface import implementer -from attr import attrs, attrib +from __future__ import absolute_import, print_function, unicode_literals + +from attr import attrib, attrs from attr.validators import provides from automat import MethodicalMachine +from zope.interface import implementer + from . import _interfaces + @attrs @implementer(_interfaces.IAllocator) class Allocator(object): _timing = attrib(validator=provides(_interfaces.ITiming)) m = MethodicalMachine() - set_trace = getattr(m, "_setTrace", lambda self, f: None) # pragma: no cover + set_trace = getattr(m, "_setTrace", + lambda self, f: None) # pragma: no cover def wire(self, rendezvous_connector, code): self._RC = _interfaces.IRendezvousConnector(rendezvous_connector) self._C = _interfaces.ICode(code) @m.state(initial=True) - def S0A_idle(self): pass # pragma: no cover + def S0A_idle(self): + pass # pragma: no cover + @m.state() - def S0B_idle_connected(self): pass # pragma: no cover + def S0B_idle_connected(self): + pass # pragma: no cover + @m.state() - def S1A_allocating(self): pass # pragma: no cover + def S1A_allocating(self): + pass # pragma: no cover + @m.state() - def S1B_allocating_connected(self): pass # pragma: no cover + def S1B_allocating_connected(self): + pass # pragma: no cover + @m.state() - def S2_done(self): pass # pragma: no cover + def S2_done(self): + pass # pragma: no cover # from Code @m.input() - def allocate(self, length, wordlist): pass + def allocate(self, length, wordlist): + pass # from RendezvousConnector @m.input() - def connected(self): pass + def connected(self): + pass + @m.input() - def lost(self): pass + def lost(self): + pass + @m.input() - def rx_allocated(self, nameplate): pass + def rx_allocated(self, nameplate): + pass @m.output() def stash(self, length, wordlist): self._length = length self._wordlist = _interfaces.IWordlist(wordlist) + @m.output() def stash_and_RC_rx_allocate(self, length, wordlist): self._length = length self._wordlist = _interfaces.IWordlist(wordlist) self._RC.tx_allocate() + @m.output() def RC_tx_allocate(self): self._RC.tx_allocate() + @m.output() def build_and_notify(self, nameplate): words = self._wordlist.choose_words(self._length) @@ -61,15 +83,17 @@ class Allocator(object): S0B_idle_connected.upon(lost, enter=S0A_idle, outputs=[]) S0A_idle.upon(allocate, enter=S1A_allocating, outputs=[stash]) - S0B_idle_connected.upon(allocate, enter=S1B_allocating_connected, - outputs=[stash_and_RC_rx_allocate]) + S0B_idle_connected.upon( + allocate, + enter=S1B_allocating_connected, + outputs=[stash_and_RC_rx_allocate]) - S1A_allocating.upon(connected, enter=S1B_allocating_connected, - outputs=[RC_tx_allocate]) + S1A_allocating.upon( + connected, enter=S1B_allocating_connected, outputs=[RC_tx_allocate]) S1B_allocating_connected.upon(lost, enter=S1A_allocating, outputs=[]) - S1B_allocating_connected.upon(rx_allocated, enter=S2_done, - outputs=[build_and_notify]) + S1B_allocating_connected.upon( + rx_allocated, enter=S2_done, outputs=[build_and_notify]) S2_done.upon(connected, enter=S2_done, outputs=[]) S2_done.upon(lost, enter=S2_done, outputs=[]) diff --git a/src/wormhole/_boss.py b/src/wormhole/_boss.py index c99bb3d..097efe7 100644 --- a/src/wormhole/_boss.py +++ b/src/wormhole/_boss.py @@ -1,29 +1,33 @@ -from __future__ import print_function, absolute_import, unicode_literals +from __future__ import absolute_import, print_function, unicode_literals + import re + import six -from zope.interface import implementer -from attr import attrs, attrib -from attr.validators import provides, instance_of, optional -from twisted.python import log +from attr import attrib, attrs +from attr.validators import instance_of, optional, provides from automat import MethodicalMachine +from twisted.python import log +from zope.interface import implementer + from . import _interfaces -from ._nameplate import Nameplate -from ._mailbox import Mailbox -from ._send import Send -from ._order import Order +from ._allocator import Allocator +from ._code import Code, validate_code +from ._input import Input from ._key import Key +from ._lister import Lister +from ._mailbox import Mailbox +from ._nameplate import Nameplate +from ._order import Order from ._receive import Receive from ._rendezvous import RendezvousConnector -from ._lister import Lister -from ._allocator import Allocator -from ._input import Input -from ._code import Code, validate_code +from ._send import Send from ._terminator import Terminator from ._wordlist import PGPWordList -from .errors import (ServerError, LonelyError, WrongPasswordError, - OnlyOneCodeError, _UnknownPhaseError, WelcomeError) +from .errors import (LonelyError, OnlyOneCodeError, ServerError, WelcomeError, + WrongPasswordError, _UnknownPhaseError) from .util import bytes_to_dict + @attrs @implementer(_interfaces.IBoss) class Boss(object): @@ -38,7 +42,8 @@ class Boss(object): _tor = attrib(validator=optional(provides(_interfaces.ITorManager))) _timing = attrib(validator=provides(_interfaces.ITiming)) m = MethodicalMachine() - set_trace = getattr(m, "_setTrace", lambda self, f: None) # pragma: no cover + set_trace = getattr(m, "_setTrace", + lambda self, f: None) # pragma: no cover def __attrs_post_init__(self): self._build_workers() @@ -52,9 +57,8 @@ class Boss(object): self._K = Key(self._appid, self._versions, self._side, self._timing) self._R = Receive(self._side, self._timing) self._RC = RendezvousConnector(self._url, self._appid, self._side, - self._reactor, self._journal, - self._tor, self._timing, - self._client_version) + self._reactor, self._journal, self._tor, + self._timing, self._client_version) self._L = Lister(self._timing) self._A = Allocator(self._timing) self._I = Input(self._timing) @@ -78,7 +82,7 @@ class Boss(object): self._did_start_code = False self._next_tx_phase = 0 self._next_rx_phase = 0 - self._rx_phases = {} # phase -> plaintext + self._rx_phases = {} # phase -> plaintext self._result = "empty" @@ -86,32 +90,45 @@ class Boss(object): def start(self): self._RC.start() - def _print_trace(self, old_state, input, new_state, - client_name, machine, file): + def _print_trace(self, old_state, input, new_state, client_name, machine, + file): if new_state: - print("%s.%s[%s].%s -> [%s]" % - (client_name, machine, old_state, input, - new_state), file=file) + print( + "%s.%s[%s].%s -> [%s]" % (client_name, machine, old_state, + input, new_state), + file=file) else: # the RendezvousConnector emits message events as if # they were state transitions, except that old_state # and new_state are empty strings. "input" is one of # R.connected, R.rx(type phase+side), R.tx(type # phase), R.lost . - print("%s.%s.%s" % (client_name, machine, input), - file=file) + print("%s.%s.%s" % (client_name, machine, input), file=file) file.flush() + def output_tracer(output): - print(" %s.%s.%s()" % (client_name, machine, output), - file=file) + print(" %s.%s.%s()" % (client_name, machine, output), file=file) file.flush() + return output_tracer def _set_trace(self, client_name, which, file): - names = {"B": self, "N": self._N, "M": self._M, "S": self._S, - "O": self._O, "K": self._K, "SK": self._K._SK, "R": self._R, - "RC": self._RC, "L": self._L, "A": self._A, "I": self._I, - "C": self._C, "T": self._T} + names = { + "B": self, + "N": self._N, + "M": self._M, + "S": self._S, + "O": self._O, + "K": self._K, + "SK": self._K._SK, + "R": self._R, + "RC": self._RC, + "L": self._L, + "A": self._A, + "I": self._I, + "C": self._C, + "T": self._T + } for machine in which.split(): t = (lambda old_state, input, new_state, machine=machine: self._print_trace(old_state, input, new_state, @@ -121,21 +138,30 @@ class Boss(object): if machine == "I": self._I.set_debug(t) - ## def serialize(self): - ## raise NotImplemented + # def serialize(self): + # raise NotImplemented # and these are the state-machine transition functions, which don't take # args @m.state(initial=True) - def S0_empty(self): pass # pragma: no cover + def S0_empty(self): + pass # pragma: no cover + @m.state() - def S1_lonely(self): pass # pragma: no cover + def S1_lonely(self): + pass # pragma: no cover + @m.state() - def S2_happy(self): pass # pragma: no cover + def S2_happy(self): + pass # pragma: no cover + @m.state() - def S3_closing(self): pass # pragma: no cover + def S3_closing(self): + pass # pragma: no cover + @m.state(terminal=True) - def S4_closed(self): pass # pragma: no cover + def S4_closed(self): + pass # pragma: no cover # from the Wormhole @@ -155,23 +181,28 @@ class Boss(object): raise OnlyOneCodeError() self._did_start_code = True return self._C.input_code() + def allocate_code(self, code_length): if self._did_start_code: raise OnlyOneCodeError() self._did_start_code = True wl = PGPWordList() self._C.allocate_code(code_length, wl) + def set_code(self, code): - validate_code(code) # can raise KeyFormatError + validate_code(code) # can raise KeyFormatError if self._did_start_code: raise OnlyOneCodeError() self._did_start_code = True self._C.set_code(code) @m.input() - def send(self, plaintext): pass + def send(self, plaintext): + pass + @m.input() - def close(self): pass + def close(self): + pass # from RendezvousConnector: # * "rx_welcome" is the Welcome message, which might signal an error, or @@ -190,26 +221,36 @@ class Boss(object): # delivering a new input (rx_error or something) while in the # middle of processing the rx_welcome input, and I wasn't sure # Automat would handle that correctly. - self._W.got_welcome(welcome) # TODO: let this raise WelcomeError? + self._W.got_welcome(welcome) # TODO: let this raise WelcomeError? except WelcomeError as welcome_error: self.rx_unwelcome(welcome_error) + @m.input() - def rx_unwelcome(self, welcome_error): pass + def rx_unwelcome(self, welcome_error): + pass + @m.input() - def rx_error(self, errmsg, orig): pass + def rx_error(self, errmsg, orig): + pass + @m.input() - def error(self, err): pass + def error(self, err): + pass # from Code (provoked by input/allocate/set_code) @m.input() - def got_code(self, code): pass + def got_code(self, code): + pass # Key sends (got_key, scared) # Receive sends (got_message, happy, got_verifier, scared) @m.input() - def happy(self): pass + def happy(self): + pass + @m.input() - def scared(self): pass + def scared(self): + pass def got_message(self, phase, plaintext): assert isinstance(phase, type("")), type(phase) @@ -222,22 +263,32 @@ class Boss(object): # Ignore unrecognized phases, for forwards-compatibility. Use # log.err so tests will catch surprises. log.err(_UnknownPhaseError("received unknown phase '%s'" % phase)) + @m.input() - def _got_version(self, plaintext): pass + def _got_version(self, plaintext): + pass + @m.input() - def _got_phase(self, phase, plaintext): pass + def _got_phase(self, phase, plaintext): + pass + @m.input() - def got_key(self, key): pass + def got_key(self, key): + pass + @m.input() - def got_verifier(self, verifier): pass + def got_verifier(self, verifier): + pass # Terminator sends closed @m.input() - def closed(self): pass + def closed(self): + pass @m.output() def do_got_code(self, code): self._W.got_code(code) + @m.output() def process_version(self, plaintext): # most of this is wormhole-to-wormhole, ignored for now @@ -256,21 +307,25 @@ class Boss(object): @m.output() def close_unwelcome(self, welcome_error): - #assert isinstance(err, WelcomeError) + # assert isinstance(err, WelcomeError) self._result = welcome_error self._T.close("unwelcome") + @m.output() def close_error(self, errmsg, orig): self._result = ServerError(errmsg) self._T.close("errory") + @m.output() def close_scared(self): self._result = WrongPasswordError() self._T.close("scary") + @m.output() def close_lonely(self): self._result = LonelyError() self._T.close("lonely") + @m.output() def close_happy(self): self._result = "happy" @@ -279,9 +334,11 @@ class Boss(object): @m.output() def W_got_key(self, key): self._W.got_key(key) + @m.output() def W_got_verifier(self, verifier): self._W.got_verifier(verifier) + @m.output() def W_received(self, phase, plaintext): assert isinstance(phase, six.integer_types), type(phase) @@ -293,7 +350,7 @@ class Boss(object): @m.output() def W_close_with_error(self, err): - self._result = err # exception + self._result = err # exception self._W.closed(self._result) @m.output() diff --git a/src/wormhole/_code.py b/src/wormhole/_code.py index 96314b2..9736698 100644 --- a/src/wormhole/_code.py +++ b/src/wormhole/_code.py @@ -7,21 +7,25 @@ from . import _interfaces from ._nameplate import validate_nameplate from .errors import KeyFormatError + def validate_code(code): if ' ' in code: raise KeyFormatError("Code '%s' contains spaces." % code) nameplate = code.split("-", 2)[0] - validate_nameplate(nameplate) # can raise KeyFormatError + validate_nameplate(nameplate) # can raise KeyFormatError + def first(outputs): return list(outputs)[0] + @attrs @implementer(_interfaces.ICode) class Code(object): _timing = attrib(validator=provides(_interfaces.ITiming)) m = MethodicalMachine() - set_trace = getattr(m, "_setTrace", lambda self, f: None) # pragma: no cover + set_trace = getattr(m, "_setTrace", + lambda self, f: None) # pragma: no cover def wire(self, boss, allocator, nameplate, key, input): self._B = _interfaces.IBoss(boss) @@ -31,36 +35,55 @@ class Code(object): self._I = _interfaces.IInput(input) @m.state(initial=True) - def S0_idle(self): pass # pragma: no cover + def S0_idle(self): + pass # pragma: no cover + @m.state() - def S1_inputting_nameplate(self): pass # pragma: no cover + def S1_inputting_nameplate(self): + pass # pragma: no cover + @m.state() - def S2_inputting_words(self): pass # pragma: no cover + def S2_inputting_words(self): + pass # pragma: no cover + @m.state() - def S3_allocating(self): pass # pragma: no cover + def S3_allocating(self): + pass # pragma: no cover + @m.state() - def S4_known(self): pass # pragma: no cover + def S4_known(self): + pass # pragma: no cover # from App @m.input() - def allocate_code(self, length, wordlist): pass + def allocate_code(self, length, wordlist): + pass + @m.input() - def input_code(self): pass + def input_code(self): + pass + def set_code(self, code): - validate_code(code) # can raise KeyFormatError + validate_code(code) # can raise KeyFormatError self._set_code(code) + @m.input() - def _set_code(self, code): pass + def _set_code(self, code): + pass # from Allocator @m.input() - def allocated(self, nameplate, code): pass + def allocated(self, nameplate, code): + pass # from Input @m.input() - def got_nameplate(self, nameplate): pass + def got_nameplate(self, nameplate): + pass + @m.input() - def finished_input(self, code): pass + def finished_input(self, code): + pass @m.output() def do_set_code(self, code): @@ -72,9 +95,11 @@ class Code(object): @m.output() def do_start_input(self): return self._I.start() + @m.output() def do_middle_input(self, nameplate): self._N.set_nameplate(nameplate) + @m.output() def do_finish_input(self, code): self._B.got_code(code) @@ -83,19 +108,24 @@ class Code(object): @m.output() def do_start_allocate(self, length, wordlist): self._A.allocate(length, wordlist) + @m.output() def do_finish_allocate(self, nameplate, code): - assert code.startswith(nameplate+"-"), (nameplate, code) + assert code.startswith(nameplate + "-"), (nameplate, code) self._N.set_nameplate(nameplate) self._B.got_code(code) self._K.got_code(code) S0_idle.upon(_set_code, enter=S4_known, outputs=[do_set_code]) - S0_idle.upon(input_code, enter=S1_inputting_nameplate, - outputs=[do_start_input], collector=first) - S1_inputting_nameplate.upon(got_nameplate, enter=S2_inputting_words, - outputs=[do_middle_input]) - S2_inputting_words.upon(finished_input, enter=S4_known, - outputs=[do_finish_input]) - S0_idle.upon(allocate_code, enter=S3_allocating, outputs=[do_start_allocate]) + S0_idle.upon( + input_code, + enter=S1_inputting_nameplate, + outputs=[do_start_input], + collector=first) + S1_inputting_nameplate.upon( + got_nameplate, enter=S2_inputting_words, outputs=[do_middle_input]) + S2_inputting_words.upon( + finished_input, enter=S4_known, outputs=[do_finish_input]) + S0_idle.upon( + allocate_code, enter=S3_allocating, outputs=[do_start_allocate]) S3_allocating.upon(allocated, enter=S4_known, outputs=[do_finish_allocate]) diff --git a/src/wormhole/_input.py b/src/wormhole/_input.py index 7c4f74a..7993fda 100644 --- a/src/wormhole/_input.py +++ b/src/wormhole/_input.py @@ -1,25 +1,31 @@ -from __future__ import print_function, absolute_import, unicode_literals +from __future__ import absolute_import, print_function, unicode_literals + # We use 'threading' defensively here, to detect if we're being called from a # non-main thread. _rlcompleter.py is the only internal Wormhole code that # deliberately creates a new thread. import threading -from zope.interface import implementer -from attr import attrs, attrib + +from attr import attrib, attrs from attr.validators import provides -from twisted.internet import defer from automat import MethodicalMachine +from twisted.internet import defer +from zope.interface import implementer + from . import _interfaces, errors from ._nameplate import validate_nameplate + def first(outputs): return list(outputs)[0] + @attrs @implementer(_interfaces.IInput) class Input(object): _timing = attrib(validator=provides(_interfaces.ITiming)) m = MethodicalMachine() - set_trace = getattr(m, "_setTrace", lambda self, f: None) # pragma: no cover + set_trace = getattr(m, "_setTrace", + lambda self, f: None) # pragma: no cover def __attrs_post_init__(self): self._all_nameplates = set() @@ -30,7 +36,8 @@ class Input(object): def set_debug(self, f): self._trace = f - def _debug(self, what): # pragma: no cover + + def _debug(self, what): # pragma: no cover if self._trace: self._trace(old_state="", input=what, new_state="") @@ -46,55 +53,80 @@ class Input(object): return d @m.state(initial=True) - def S0_idle(self): pass # pragma: no cover + def S0_idle(self): + pass # pragma: no cover + @m.state() - def S1_typing_nameplate(self): pass # pragma: no cover + def S1_typing_nameplate(self): + pass # pragma: no cover + @m.state() - def S2_typing_code_no_wordlist(self): pass # pragma: no cover + def S2_typing_code_no_wordlist(self): + pass # pragma: no cover + @m.state() - def S3_typing_code_yes_wordlist(self): pass # pragma: no cover + def S3_typing_code_yes_wordlist(self): + pass # pragma: no cover + @m.state(terminal=True) - def S4_done(self): pass # pragma: no cover + def S4_done(self): + pass # pragma: no cover # from Code @m.input() - def start(self): pass + def start(self): + pass # from Lister @m.input() - def got_nameplates(self, all_nameplates): pass + def got_nameplates(self, all_nameplates): + pass # from Nameplate @m.input() - def got_wordlist(self, wordlist): pass + def got_wordlist(self, wordlist): + pass # API provided to app as ICodeInputHelper @m.input() - def refresh_nameplates(self): pass + def refresh_nameplates(self): + pass + @m.input() - def get_nameplate_completions(self, prefix): pass + def get_nameplate_completions(self, prefix): + pass + def choose_nameplate(self, nameplate): - validate_nameplate(nameplate) # can raise KeyFormatError + validate_nameplate(nameplate) # can raise KeyFormatError self._choose_nameplate(nameplate) + @m.input() - def _choose_nameplate(self, nameplate): pass + def _choose_nameplate(self, nameplate): + pass + @m.input() - def get_word_completions(self, prefix): pass + def get_word_completions(self, prefix): + pass + @m.input() - def choose_words(self, words): pass + def choose_words(self, words): + pass @m.output() def do_start(self): self._start_timing = self._timing.add("input code", waiting="user") self._L.refresh() return Helper(self) + @m.output() def do_refresh(self): self._L.refresh() + @m.output() def record_nameplates(self, all_nameplates): # we get a set of nameplate id strings self._all_nameplates = all_nameplates + @m.output() def _get_nameplate_completions(self, prefix): completions = set() @@ -102,17 +134,20 @@ class Input(object): if nameplate.startswith(prefix): # TODO: it's a little weird that Input is responsible for the # hyphen on nameplates, but WordList owns it for words - completions.add(nameplate+"-") + completions.add(nameplate + "-") return completions + @m.output() def record_all_nameplates(self, nameplate): self._nameplate = nameplate self._C.got_nameplate(nameplate) + @m.output() def record_wordlist(self, wordlist): from ._rlcompleter import debug debug(" -record_wordlist") self._wordlist = wordlist + @m.output() def notify_wordlist_waiters(self, wordlist): while self._wordlist_waiters: @@ -122,6 +157,7 @@ class Input(object): @m.output() def no_word_completions(self, prefix): return set() + @m.output() def _get_word_completions(self, prefix): assert self._wordlist @@ -130,21 +166,27 @@ class Input(object): @m.output() def raise_must_choose_nameplate1(self, prefix): raise errors.MustChooseNameplateFirstError() + @m.output() def raise_must_choose_nameplate2(self, words): raise errors.MustChooseNameplateFirstError() + @m.output() def raise_already_chose_nameplate1(self): raise errors.AlreadyChoseNameplateError() + @m.output() def raise_already_chose_nameplate2(self, prefix): raise errors.AlreadyChoseNameplateError() + @m.output() def raise_already_chose_nameplate3(self, nameplate): raise errors.AlreadyChoseNameplateError() + @m.output() def raise_already_chose_words1(self, prefix): raise errors.AlreadyChoseWordsError() + @m.output() def raise_already_chose_words2(self, words): raise errors.AlreadyChoseWordsError() @@ -155,88 +197,110 @@ class Input(object): self._start_timing.finish() self._C.finished_input(code) - S0_idle.upon(start, enter=S1_typing_nameplate, - outputs=[do_start], collector=first) + S0_idle.upon( + start, enter=S1_typing_nameplate, outputs=[do_start], collector=first) # wormholes that don't use input_code (i.e. they use allocate_code or # generate_code) will never start() us, but Nameplate will give us a # wordlist anyways (as soon as the nameplate is claimed), so handle it. - S0_idle.upon(got_wordlist, enter=S0_idle, outputs=[record_wordlist, - notify_wordlist_waiters]) - S1_typing_nameplate.upon(got_nameplates, enter=S1_typing_nameplate, - outputs=[record_nameplates]) + S0_idle.upon( + got_wordlist, + enter=S0_idle, + outputs=[record_wordlist, notify_wordlist_waiters]) + S1_typing_nameplate.upon( + got_nameplates, enter=S1_typing_nameplate, outputs=[record_nameplates]) # but wormholes that *do* use input_code should not get got_wordlist # until after we tell Code that we got_nameplate, which is the earliest # it can be claimed - S1_typing_nameplate.upon(refresh_nameplates, enter=S1_typing_nameplate, - outputs=[do_refresh]) - S1_typing_nameplate.upon(get_nameplate_completions, - enter=S1_typing_nameplate, - outputs=[_get_nameplate_completions], - collector=first) - S1_typing_nameplate.upon(_choose_nameplate, enter=S2_typing_code_no_wordlist, - outputs=[record_all_nameplates]) - S1_typing_nameplate.upon(get_word_completions, - enter=S1_typing_nameplate, - outputs=[raise_must_choose_nameplate1]) - S1_typing_nameplate.upon(choose_words, enter=S1_typing_nameplate, - outputs=[raise_must_choose_nameplate2]) + S1_typing_nameplate.upon( + refresh_nameplates, enter=S1_typing_nameplate, outputs=[do_refresh]) + S1_typing_nameplate.upon( + get_nameplate_completions, + enter=S1_typing_nameplate, + outputs=[_get_nameplate_completions], + collector=first) + S1_typing_nameplate.upon( + _choose_nameplate, + enter=S2_typing_code_no_wordlist, + outputs=[record_all_nameplates]) + S1_typing_nameplate.upon( + get_word_completions, + enter=S1_typing_nameplate, + outputs=[raise_must_choose_nameplate1]) + S1_typing_nameplate.upon( + choose_words, + enter=S1_typing_nameplate, + outputs=[raise_must_choose_nameplate2]) - S2_typing_code_no_wordlist.upon(got_nameplates, - enter=S2_typing_code_no_wordlist, outputs=[]) - S2_typing_code_no_wordlist.upon(got_wordlist, - enter=S3_typing_code_yes_wordlist, - outputs=[record_wordlist, - notify_wordlist_waiters]) - S2_typing_code_no_wordlist.upon(refresh_nameplates, - enter=S2_typing_code_no_wordlist, - outputs=[raise_already_chose_nameplate1]) - S2_typing_code_no_wordlist.upon(get_nameplate_completions, - enter=S2_typing_code_no_wordlist, - outputs=[raise_already_chose_nameplate2]) - S2_typing_code_no_wordlist.upon(_choose_nameplate, - enter=S2_typing_code_no_wordlist, - outputs=[raise_already_chose_nameplate3]) - S2_typing_code_no_wordlist.upon(get_word_completions, - enter=S2_typing_code_no_wordlist, - outputs=[no_word_completions], - collector=first) - S2_typing_code_no_wordlist.upon(choose_words, enter=S4_done, - outputs=[do_words]) + S2_typing_code_no_wordlist.upon( + got_nameplates, enter=S2_typing_code_no_wordlist, outputs=[]) + S2_typing_code_no_wordlist.upon( + got_wordlist, + enter=S3_typing_code_yes_wordlist, + outputs=[record_wordlist, notify_wordlist_waiters]) + S2_typing_code_no_wordlist.upon( + refresh_nameplates, + enter=S2_typing_code_no_wordlist, + outputs=[raise_already_chose_nameplate1]) + S2_typing_code_no_wordlist.upon( + get_nameplate_completions, + enter=S2_typing_code_no_wordlist, + outputs=[raise_already_chose_nameplate2]) + S2_typing_code_no_wordlist.upon( + _choose_nameplate, + enter=S2_typing_code_no_wordlist, + outputs=[raise_already_chose_nameplate3]) + S2_typing_code_no_wordlist.upon( + get_word_completions, + enter=S2_typing_code_no_wordlist, + outputs=[no_word_completions], + collector=first) + S2_typing_code_no_wordlist.upon( + choose_words, enter=S4_done, outputs=[do_words]) - S3_typing_code_yes_wordlist.upon(got_nameplates, - enter=S3_typing_code_yes_wordlist, - outputs=[]) + S3_typing_code_yes_wordlist.upon( + got_nameplates, enter=S3_typing_code_yes_wordlist, outputs=[]) # got_wordlist: should never happen - S3_typing_code_yes_wordlist.upon(refresh_nameplates, - enter=S3_typing_code_yes_wordlist, - outputs=[raise_already_chose_nameplate1]) - S3_typing_code_yes_wordlist.upon(get_nameplate_completions, - enter=S3_typing_code_yes_wordlist, - outputs=[raise_already_chose_nameplate2]) - S3_typing_code_yes_wordlist.upon(_choose_nameplate, - enter=S3_typing_code_yes_wordlist, - outputs=[raise_already_chose_nameplate3]) - S3_typing_code_yes_wordlist.upon(get_word_completions, - enter=S3_typing_code_yes_wordlist, - outputs=[_get_word_completions], - collector=first) - S3_typing_code_yes_wordlist.upon(choose_words, enter=S4_done, - outputs=[do_words]) + S3_typing_code_yes_wordlist.upon( + refresh_nameplates, + enter=S3_typing_code_yes_wordlist, + outputs=[raise_already_chose_nameplate1]) + S3_typing_code_yes_wordlist.upon( + get_nameplate_completions, + enter=S3_typing_code_yes_wordlist, + outputs=[raise_already_chose_nameplate2]) + S3_typing_code_yes_wordlist.upon( + _choose_nameplate, + enter=S3_typing_code_yes_wordlist, + outputs=[raise_already_chose_nameplate3]) + S3_typing_code_yes_wordlist.upon( + get_word_completions, + enter=S3_typing_code_yes_wordlist, + outputs=[_get_word_completions], + collector=first) + S3_typing_code_yes_wordlist.upon( + choose_words, enter=S4_done, outputs=[do_words]) S4_done.upon(got_nameplates, enter=S4_done, outputs=[]) S4_done.upon(got_wordlist, enter=S4_done, outputs=[]) - S4_done.upon(refresh_nameplates, - enter=S4_done, - outputs=[raise_already_chose_nameplate1]) - S4_done.upon(get_nameplate_completions, - enter=S4_done, - outputs=[raise_already_chose_nameplate2]) - S4_done.upon(_choose_nameplate, enter=S4_done, - outputs=[raise_already_chose_nameplate3]) - S4_done.upon(get_word_completions, enter=S4_done, - outputs=[raise_already_chose_words1]) - S4_done.upon(choose_words, enter=S4_done, - outputs=[raise_already_chose_words2]) + S4_done.upon( + refresh_nameplates, + enter=S4_done, + outputs=[raise_already_chose_nameplate1]) + S4_done.upon( + get_nameplate_completions, + enter=S4_done, + outputs=[raise_already_chose_nameplate2]) + S4_done.upon( + _choose_nameplate, + enter=S4_done, + outputs=[raise_already_chose_nameplate3]) + S4_done.upon( + get_word_completions, + enter=S4_done, + outputs=[raise_already_chose_words1]) + S4_done.upon( + choose_words, enter=S4_done, outputs=[raise_already_chose_words2]) + # we only expose the Helper to application code, not _Input @attrs @@ -250,20 +314,25 @@ class Helper(object): def refresh_nameplates(self): assert threading.current_thread().ident == self._main_thread self._input.refresh_nameplates() + def get_nameplate_completions(self, prefix): assert threading.current_thread().ident == self._main_thread return self._input.get_nameplate_completions(prefix) + def choose_nameplate(self, nameplate): assert threading.current_thread().ident == self._main_thread self._input._debug("I.choose_nameplate") self._input.choose_nameplate(nameplate) self._input._debug("I.choose_nameplate finished") + def when_wordlist_is_available(self): assert threading.current_thread().ident == self._main_thread return self._input.when_wordlist_is_available() + def get_word_completions(self, prefix): assert threading.current_thread().ident == self._main_thread return self._input.get_word_completions(prefix) + def choose_words(self, words): assert threading.current_thread().ident == self._main_thread self._input._debug("I.choose_words") diff --git a/src/wormhole/_interfaces.py b/src/wormhole/_interfaces.py index 060ff22..9f59ff4 100644 --- a/src/wormhole/_interfaces.py +++ b/src/wormhole/_interfaces.py @@ -3,64 +3,105 @@ from zope.interface import Interface # These interfaces are private: we use them as markers to detect # swapped argument bugs in the various .wire() calls + class IWormhole(Interface): """Internal: this contains the methods invoked 'from below'.""" + def got_welcome(welcome): pass + def got_code(code): pass + def got_key(key): pass + def got_verifier(verifier): pass + def got_versions(versions): pass + def received(plaintext): pass + def closed(result): pass + class IBoss(Interface): pass + + class INameplate(Interface): pass + + class IMailbox(Interface): pass + + class ISend(Interface): pass + + class IOrder(Interface): pass + + class IKey(Interface): pass + + class IReceive(Interface): pass + + class IRendezvousConnector(Interface): pass + + class ILister(Interface): pass + + class ICode(Interface): pass + + class IInput(Interface): pass + + class IAllocator(Interface): pass + + class ITerminator(Interface): pass + class ITiming(Interface): pass + + class ITorManager(Interface): pass + + class IWordlist(Interface): def choose_words(length): """Randomly select LENGTH words, join them with hyphens, return the result.""" + def get_completions(prefix): """Return a list of all suffixes that could complete the given prefix.""" + # These interfaces are public, and are re-exported by __init__.py + class IDeferredWormhole(Interface): def get_welcome(): """ @@ -277,6 +318,7 @@ class IDeferredWormhole(Interface): :rtype: ``Deferred`` """ + class IInputHelper(Interface): def refresh_nameplates(): """ @@ -389,5 +431,5 @@ class IInputHelper(Interface): """ -class IJournal(Interface): # TODO: this needs to be public +class IJournal(Interface): # TODO: this needs to be public pass diff --git a/src/wormhole/_key.py b/src/wormhole/_key.py index f76f454..849ed41 100644 --- a/src/wormhole/_key.py +++ b/src/wormhole/_key.py @@ -1,41 +1,50 @@ -from __future__ import print_function, absolute_import, unicode_literals +from __future__ import absolute_import, print_function, unicode_literals + from hashlib import sha256 + import six -from zope.interface import implementer -from attr import attrs, attrib -from attr.validators import provides, instance_of -from spake2 import SPAKE2_Symmetric -from hkdf import Hkdf -from nacl.secret import SecretBox -from nacl.exceptions import CryptoError -from nacl import utils +from attr import attrib, attrs +from attr.validators import instance_of, provides from automat import MethodicalMachine -from .util import (to_bytes, bytes_to_hexstr, hexstr_to_bytes, - bytes_to_dict, dict_to_bytes) +from hkdf import Hkdf +from nacl import utils +from nacl.exceptions import CryptoError +from nacl.secret import SecretBox +from spake2 import SPAKE2_Symmetric +from zope.interface import implementer + from . import _interfaces +from .util import (bytes_to_dict, bytes_to_hexstr, dict_to_bytes, + hexstr_to_bytes, to_bytes) + CryptoError -__all__ = ["derive_key", "derive_phase_key", "CryptoError", - "Key"] +__all__ = ["derive_key", "derive_phase_key", "CryptoError", "Key"] + def HKDF(skm, outlen, salt=None, CTXinfo=b""): return Hkdf(salt, skm).expand(CTXinfo, outlen) + def derive_key(key, purpose, length=SecretBox.KEY_SIZE): - if not isinstance(key, type(b"")): raise TypeError(type(key)) - if not isinstance(purpose, type(b"")): raise TypeError(type(purpose)) - if not isinstance(length, six.integer_types): raise TypeError(type(length)) + if not isinstance(key, type(b"")): + raise TypeError(type(key)) + if not isinstance(purpose, type(b"")): + raise TypeError(type(purpose)) + if not isinstance(length, six.integer_types): + raise TypeError(type(length)) return HKDF(key, length, CTXinfo=purpose) + def derive_phase_key(key, side, phase): assert isinstance(side, type("")), type(side) assert isinstance(phase, type("")), type(phase) side_bytes = side.encode("ascii") phase_bytes = phase.encode("ascii") - purpose = (b"wormhole:phase:" - + sha256(side_bytes).digest() - + sha256(phase_bytes).digest()) + purpose = (b"wormhole:phase:" + sha256(side_bytes).digest() + + sha256(phase_bytes).digest()) return derive_key(key, purpose) + def decrypt_data(key, encrypted): assert isinstance(key, type(b"")), type(key) assert isinstance(encrypted, type(b"")), type(encrypted) @@ -44,6 +53,7 @@ def decrypt_data(key, encrypted): data = box.decrypt(encrypted) return data + def encrypt_data(key, plaintext): assert isinstance(key, type(b"")), type(key) assert isinstance(plaintext, type(b"")), type(plaintext) @@ -52,10 +62,12 @@ def encrypt_data(key, plaintext): nonce = utils.random(SecretBox.NONCE_SIZE) return box.encrypt(plaintext, nonce) + # the Key we expose to callers (Boss, Ordering) is responsible for sorting # the two messages (got_code and got_pake), then delivering them to # _SortedKey in the right order. + @attrs @implementer(_interfaces.IKey) class Key(object): @@ -64,40 +76,54 @@ class Key(object): _side = attrib(validator=instance_of(type(u""))) _timing = attrib(validator=provides(_interfaces.ITiming)) m = MethodicalMachine() - set_trace = getattr(m, "_setTrace", lambda self, f: None) # pragma: no cover + set_trace = getattr(m, "_setTrace", + lambda self, f: None) # pragma: no cover def __attrs_post_init__(self): self._SK = _SortedKey(self._appid, self._versions, self._side, self._timing) - self._debug_pake_stashed = False # for tests + self._debug_pake_stashed = False # for tests def wire(self, boss, mailbox, receive): self._SK.wire(boss, mailbox, receive) @m.state(initial=True) - def S00(self): pass # pragma: no cover + def S00(self): + pass # pragma: no cover + @m.state() - def S01(self): pass # pragma: no cover + def S01(self): + pass # pragma: no cover + @m.state() - def S10(self): pass # pragma: no cover + def S10(self): + pass # pragma: no cover + @m.state() - def S11(self): pass # pragma: no cover + def S11(self): + pass # pragma: no cover @m.input() - def got_code(self, code): pass + def got_code(self, code): + pass + @m.input() - def got_pake(self, body): pass + def got_pake(self, body): + pass @m.output() def stash_pake(self, body): self._pake = body self._debug_pake_stashed = True + @m.output() def deliver_code(self, code): self._SK.got_code(code) + @m.output() def deliver_pake(self, body): self._SK.got_pake(body) + @m.output() def deliver_code_and_stashed_pake(self, code): self._SK.got_code(code) @@ -108,6 +134,7 @@ class Key(object): S00.upon(got_pake, enter=S01, outputs=[stash_pake]) S01.upon(got_code, enter=S11, outputs=[deliver_code_and_stashed_pake]) + @attrs class _SortedKey(object): _appid = attrib(validator=instance_of(type(u""))) @@ -115,7 +142,8 @@ class _SortedKey(object): _side = attrib(validator=instance_of(type(u""))) _timing = attrib(validator=provides(_interfaces.ITiming)) m = MethodicalMachine() - set_trace = getattr(m, "_setTrace", lambda self, f: None) # pragma: no cover + set_trace = getattr(m, "_setTrace", + lambda self, f: None) # pragma: no cover def wire(self, boss, mailbox, receive): self._B = _interfaces.IBoss(boss) @@ -123,17 +151,25 @@ class _SortedKey(object): self._R = _interfaces.IReceive(receive) @m.state(initial=True) - def S0_know_nothing(self): pass # pragma: no cover + def S0_know_nothing(self): + pass # pragma: no cover + @m.state() - def S1_know_code(self): pass # pragma: no cover + def S1_know_code(self): + pass # pragma: no cover + @m.state() - def S2_know_key(self): pass # pragma: no cover + def S2_know_key(self): + pass # pragma: no cover + @m.state(terminal=True) - def S3_scared(self): pass # pragma: no cover + def S3_scared(self): + pass # pragma: no cover # from Boss @m.input() - def got_code(self, code): pass + def got_code(self, code): + pass # from Ordering def got_pake(self, body): @@ -143,16 +179,20 @@ class _SortedKey(object): self.got_pake_good(hexstr_to_bytes(payload["pake_v1"])) else: self.got_pake_bad() + @m.input() - def got_pake_good(self, msg2): pass + def got_pake_good(self, msg2): + pass + @m.input() - def got_pake_bad(self): pass + def got_pake_bad(self): + pass @m.output() def build_pake(self, code): with self._timing.add("pake1", waiting="crypto"): - self._sp = SPAKE2_Symmetric(to_bytes(code), - idSymmetric=to_bytes(self._appid)) + self._sp = SPAKE2_Symmetric( + to_bytes(code), idSymmetric=to_bytes(self._appid)) msg1 = self._sp.start() body = dict_to_bytes({"pake_v1": bytes_to_hexstr(msg1)}) self._M.add_message("pake", body) @@ -160,6 +200,7 @@ class _SortedKey(object): @m.output() def scared(self): self._B.scared() + @m.output() def compute_key(self, msg2): assert isinstance(msg2, type(b"")) diff --git a/src/wormhole/_lister.py b/src/wormhole/_lister.py index ebe7919..de085fa 100644 --- a/src/wormhole/_lister.py +++ b/src/wormhole/_lister.py @@ -1,16 +1,20 @@ -from __future__ import print_function, absolute_import, unicode_literals -from zope.interface import implementer -from attr import attrs, attrib +from __future__ import absolute_import, print_function, unicode_literals + +from attr import attrib, attrs from attr.validators import provides from automat import MethodicalMachine +from zope.interface import implementer + from . import _interfaces + @attrs @implementer(_interfaces.ILister) class Lister(object): _timing = attrib(validator=provides(_interfaces.ITiming)) m = MethodicalMachine() - set_trace = getattr(m, "_setTrace", lambda self, f: None) # pragma: no cover + set_trace = getattr(m, "_setTrace", + lambda self, f: None) # pragma: no cover def wire(self, rendezvous_connector, input): self._RC = _interfaces.IRendezvousConnector(rendezvous_connector) @@ -26,26 +30,41 @@ class Lister(object): # request arrives, both requests will be satisfied by the same response. @m.state(initial=True) - def S0A_idle_disconnected(self): pass # pragma: no cover + def S0A_idle_disconnected(self): + pass # pragma: no cover + @m.state() - def S1A_wanting_disconnected(self): pass # pragma: no cover + def S1A_wanting_disconnected(self): + pass # pragma: no cover + @m.state() - def S0B_idle_connected(self): pass # pragma: no cover + def S0B_idle_connected(self): + pass # pragma: no cover + @m.state() - def S1B_wanting_connected(self): pass # pragma: no cover + def S1B_wanting_connected(self): + pass # pragma: no cover @m.input() - def connected(self): pass + def connected(self): + pass + @m.input() - def lost(self): pass + def lost(self): + pass + @m.input() - def refresh(self): pass + def refresh(self): + pass + @m.input() - def rx_nameplates(self, all_nameplates): pass + def rx_nameplates(self, all_nameplates): + pass @m.output() def RC_tx_list(self): self._RC.tx_list() + @m.output() def I_got_nameplates(self, all_nameplates): # We get a set of nameplate ids. There may be more attributes in the @@ -56,18 +75,19 @@ class Lister(object): S0A_idle_disconnected.upon(connected, enter=S0B_idle_connected, outputs=[]) S0B_idle_connected.upon(lost, enter=S0A_idle_disconnected, outputs=[]) - S0A_idle_disconnected.upon(refresh, - enter=S1A_wanting_disconnected, outputs=[]) - S1A_wanting_disconnected.upon(refresh, - enter=S1A_wanting_disconnected, outputs=[]) - S1A_wanting_disconnected.upon(connected, enter=S1B_wanting_connected, - outputs=[RC_tx_list]) - S0B_idle_connected.upon(refresh, enter=S1B_wanting_connected, - outputs=[RC_tx_list]) - S0B_idle_connected.upon(rx_nameplates, enter=S0B_idle_connected, - outputs=[I_got_nameplates]) - S1B_wanting_connected.upon(lost, enter=S1A_wanting_disconnected, outputs=[]) - S1B_wanting_connected.upon(refresh, enter=S1B_wanting_connected, - outputs=[RC_tx_list]) - S1B_wanting_connected.upon(rx_nameplates, enter=S0B_idle_connected, - outputs=[I_got_nameplates]) + S0A_idle_disconnected.upon( + refresh, enter=S1A_wanting_disconnected, outputs=[]) + S1A_wanting_disconnected.upon( + refresh, enter=S1A_wanting_disconnected, outputs=[]) + S1A_wanting_disconnected.upon( + connected, enter=S1B_wanting_connected, outputs=[RC_tx_list]) + S0B_idle_connected.upon( + refresh, enter=S1B_wanting_connected, outputs=[RC_tx_list]) + S0B_idle_connected.upon( + rx_nameplates, enter=S0B_idle_connected, outputs=[I_got_nameplates]) + S1B_wanting_connected.upon( + lost, enter=S1A_wanting_disconnected, outputs=[]) + S1B_wanting_connected.upon( + refresh, enter=S1B_wanting_connected, outputs=[RC_tx_list]) + S1B_wanting_connected.upon( + rx_nameplates, enter=S0B_idle_connected, outputs=[I_got_nameplates]) diff --git a/src/wormhole/_mailbox.py b/src/wormhole/_mailbox.py index 36167e6..6a6a956 100644 --- a/src/wormhole/_mailbox.py +++ b/src/wormhole/_mailbox.py @@ -1,16 +1,20 @@ -from __future__ import print_function, absolute_import, unicode_literals -from zope.interface import implementer -from attr import attrs, attrib +from __future__ import absolute_import, print_function, unicode_literals + +from attr import attrib, attrs from attr.validators import instance_of from automat import MethodicalMachine +from zope.interface import implementer + from . import _interfaces + @attrs @implementer(_interfaces.IMailbox) class Mailbox(object): _side = attrib(validator=instance_of(type(u""))) m = MethodicalMachine() - set_trace = getattr(m, "_setTrace", lambda self, f: None) # pragma: no cover + set_trace = getattr(m, "_setTrace", + lambda self, f: None) # pragma: no cover def __attrs_post_init__(self): self._mailbox = None @@ -29,52 +33,68 @@ class Mailbox(object): # S0: know nothing @m.state(initial=True) - def S0A(self): pass # pragma: no cover + def S0A(self): + pass # pragma: no cover + @m.state() - def S0B(self): pass # pragma: no cover + def S0B(self): + pass # pragma: no cover # S1: mailbox known, not opened @m.state() - def S1A(self): pass # pragma: no cover + def S1A(self): + pass # pragma: no cover # S2: mailbox known, opened # We've definitely tried to open the mailbox at least once, but it must # be re-opened with each connection, because open() is also subscribe() @m.state() - def S2A(self): pass # pragma: no cover + def S2A(self): + pass # pragma: no cover + @m.state() - def S2B(self): pass # pragma: no cover + def S2B(self): + pass # pragma: no cover # S3: closing @m.state() - def S3A(self): pass # pragma: no cover + def S3A(self): + pass # pragma: no cover + @m.state() - def S3B(self): pass # pragma: no cover + def S3B(self): + pass # pragma: no cover # S4: closed. We no longer care whether we're connected or not - #@m.state() - #def S4A(self): pass - #@m.state() - #def S4B(self): pass + # @m.state() + # def S4A(self): pass + # @m.state() + # def S4B(self): pass @m.state(terminal=True) - def S4(self): pass # pragma: no cover + def S4(self): + pass # pragma: no cover + S4A = S4 S4B = S4 - # from Terminator @m.input() - def close(self, mood): pass + def close(self, mood): + pass # from Nameplate @m.input() - def got_mailbox(self, mailbox): pass + def got_mailbox(self, mailbox): + pass # from RendezvousConnector @m.input() - def connected(self): pass + def connected(self): + pass + @m.input() - def lost(self): pass + def lost(self): + pass def rx_message(self, side, phase, body): assert isinstance(side, type("")), type(side) @@ -84,73 +104,91 @@ class Mailbox(object): self.rx_message_ours(phase, body) else: self.rx_message_theirs(side, phase, body) + @m.input() - def rx_message_ours(self, phase, body): pass + def rx_message_ours(self, phase, body): + pass + @m.input() - def rx_message_theirs(self, side, phase, body): pass + def rx_message_theirs(self, side, phase, body): + pass + @m.input() - def rx_closed(self): pass + def rx_closed(self): + pass # from Send or Key @m.input() def add_message(self, phase, body): pass - @m.output() def record_mailbox(self, mailbox): self._mailbox = mailbox + @m.output() def RC_tx_open(self): assert self._mailbox self._RC.tx_open(self._mailbox) + @m.output() def queue(self, phase, body): assert isinstance(phase, type("")), type(phase) assert isinstance(body, type(b"")), (type(body), phase, body) self._pending_outbound[phase] = body + @m.output() def record_mailbox_and_RC_tx_open_and_drain(self, mailbox): self._mailbox = mailbox self._RC.tx_open(mailbox) self._drain() + @m.output() def drain(self): self._drain() + def _drain(self): for phase, body in self._pending_outbound.items(): self._RC.tx_add(phase, body) + @m.output() def RC_tx_add(self, phase, body): assert isinstance(phase, type("")), type(phase) assert isinstance(body, type(b"")), type(body) self._RC.tx_add(phase, body) + @m.output() def N_release_and_accept(self, side, phase, body): self._N.release() if phase not in self._processed: self._processed.add(phase) self._O.got_message(side, phase, body) + @m.output() def RC_tx_close(self): assert self._mood self._RC_tx_close() + def _RC_tx_close(self): self._RC.tx_close(self._mailbox, self._mood) @m.output() def dequeue(self, phase, body): self._pending_outbound.pop(phase, None) + @m.output() def record_mood(self, mood): self._mood = mood + @m.output() def record_mood_and_RC_tx_close(self, mood): self._mood = mood self._RC_tx_close() + @m.output() def ignore_mood_and_T_mailbox_done(self, mood): self._T.mailbox_done() + @m.output() def T_mailbox_done(self): self._T.mailbox_done() @@ -162,8 +200,10 @@ class Mailbox(object): S0B.upon(lost, enter=S0A, outputs=[]) S0B.upon(add_message, enter=S0B, outputs=[queue]) S0B.upon(close, enter=S4B, outputs=[ignore_mood_and_T_mailbox_done]) - S0B.upon(got_mailbox, enter=S2B, - outputs=[record_mailbox_and_RC_tx_open_and_drain]) + S0B.upon( + got_mailbox, + enter=S2B, + outputs=[record_mailbox_and_RC_tx_open_and_drain]) S1A.upon(connected, enter=S2B, outputs=[RC_tx_open, drain]) S1A.upon(add_message, enter=S1A, outputs=[queue]) @@ -192,4 +232,3 @@ class Mailbox(object): S4.upon(rx_message_theirs, enter=S4, outputs=[]) S4.upon(rx_message_ours, enter=S4, outputs=[]) S4.upon(close, enter=S4, outputs=[]) - diff --git a/src/wormhole/_nameplate.py b/src/wormhole/_nameplate.py index b0d673c..b41fa72 100644 --- a/src/wormhole/_nameplate.py +++ b/src/wormhole/_nameplate.py @@ -1,7 +1,10 @@ -from __future__ import print_function, absolute_import, unicode_literals +from __future__ import absolute_import, print_function, unicode_literals + import re -from zope.interface import implementer + from automat import MethodicalMachine +from zope.interface import implementer + from . import _interfaces from ._wordlist import PGPWordList from .errors import KeyFormatError @@ -9,13 +12,15 @@ from .errors import KeyFormatError def validate_nameplate(nameplate): if not re.search(r'^\d+$', nameplate): - raise KeyFormatError("Nameplate '%s' must be numeric, with no spaces." - % nameplate) + raise KeyFormatError( + "Nameplate '%s' must be numeric, with no spaces." % nameplate) + @implementer(_interfaces.INameplate) class Nameplate(object): m = MethodicalMachine() - set_trace = getattr(m, "_setTrace", lambda self, f: None) # pragma: no cover + set_trace = getattr(m, "_setTrace", + lambda self, f: None) # pragma: no cover def __init__(self): self._nameplate = None @@ -32,94 +37,125 @@ class Nameplate(object): # S0: know nothing @m.state(initial=True) - def S0A(self): pass # pragma: no cover + def S0A(self): + pass # pragma: no cover + @m.state() - def S0B(self): pass # pragma: no cover + def S0B(self): + pass # pragma: no cover # S1: nameplate known, never claimed @m.state() - def S1A(self): pass # pragma: no cover + def S1A(self): + pass # pragma: no cover # S2: nameplate known, maybe claimed @m.state() - def S2A(self): pass # pragma: no cover + def S2A(self): + pass # pragma: no cover + @m.state() - def S2B(self): pass # pragma: no cover + def S2B(self): + pass # pragma: no cover # S3: nameplate claimed @m.state() - def S3A(self): pass # pragma: no cover + def S3A(self): + pass # pragma: no cover + @m.state() - def S3B(self): pass # pragma: no cover + def S3B(self): + pass # pragma: no cover # S4: maybe released @m.state() - def S4A(self): pass # pragma: no cover + def S4A(self): + pass # pragma: no cover + @m.state() - def S4B(self): pass # pragma: no cover + def S4B(self): + pass # pragma: no cover # S5: released # we no longer care whether we're connected or not - #@m.state() - #def S5A(self): pass - #@m.state() - #def S5B(self): pass + # @m.state() + # def S5A(self): pass + # @m.state() + # def S5B(self): pass @m.state() - def S5(self): pass # pragma: no cover + def S5(self): + pass # pragma: no cover + S5A = S5 S5B = S5 # from Boss def set_nameplate(self, nameplate): - validate_nameplate(nameplate) # can raise KeyFormatError + validate_nameplate(nameplate) # can raise KeyFormatError self._set_nameplate(nameplate) + @m.input() - def _set_nameplate(self, nameplate): pass + def _set_nameplate(self, nameplate): + pass # from Mailbox @m.input() - def release(self): pass + def release(self): + pass # from Terminator @m.input() - def close(self): pass + def close(self): + pass # from RendezvousConnector @m.input() - def connected(self): pass - @m.input() - def lost(self): pass + def connected(self): + pass @m.input() - def rx_claimed(self, mailbox): pass + def lost(self): + pass + @m.input() - def rx_released(self): pass + def rx_claimed(self, mailbox): + pass + + @m.input() + def rx_released(self): + pass @m.output() def record_nameplate(self, nameplate): validate_nameplate(nameplate) self._nameplate = nameplate + @m.output() def record_nameplate_and_RC_tx_claim(self, nameplate): validate_nameplate(nameplate) self._nameplate = nameplate self._RC.tx_claim(self._nameplate) + @m.output() def RC_tx_claim(self): # when invoked via M.connected(), we must use the stored nameplate self._RC.tx_claim(self._nameplate) + @m.output() def I_got_wordlist(self, mailbox): # TODO select wordlist based on nameplate properties, in rx_claimed wordlist = PGPWordList() self._I.got_wordlist(wordlist) + @m.output() def M_got_mailbox(self, mailbox): self._M.got_mailbox(mailbox) + @m.output() def RC_tx_release(self): assert self._nameplate self._RC.tx_release(self._nameplate) + @m.output() def T_nameplate_done(self): self._T.nameplate_done() @@ -127,8 +163,8 @@ class Nameplate(object): S0A.upon(_set_nameplate, enter=S1A, outputs=[record_nameplate]) S0A.upon(connected, enter=S0B, outputs=[]) S0A.upon(close, enter=S5A, outputs=[T_nameplate_done]) - S0B.upon(_set_nameplate, enter=S2B, - outputs=[record_nameplate_and_RC_tx_claim]) + S0B.upon( + _set_nameplate, enter=S2B, outputs=[record_nameplate_and_RC_tx_claim]) S0B.upon(lost, enter=S0A, outputs=[]) S0B.upon(close, enter=S5A, outputs=[T_nameplate_done]) @@ -144,7 +180,7 @@ class Nameplate(object): S3A.upon(connected, enter=S3B, outputs=[]) S3A.upon(close, enter=S4A, outputs=[]) S3B.upon(lost, enter=S3A, outputs=[]) - #S3B.upon(rx_claimed, enter=S3B, outputs=[]) # shouldn't happen + # S3B.upon(rx_claimed, enter=S3B, outputs=[]) # shouldn't happen S3B.upon(release, enter=S4B, outputs=[RC_tx_release]) S3B.upon(close, enter=S4B, outputs=[RC_tx_release]) @@ -153,7 +189,7 @@ class Nameplate(object): S4B.upon(lost, enter=S4A, outputs=[]) S4B.upon(rx_claimed, enter=S4B, outputs=[]) S4B.upon(rx_released, enter=S5B, outputs=[T_nameplate_done]) - S4B.upon(release, enter=S4B, outputs=[]) # mailbox is lazy + S4B.upon(release, enter=S4B, outputs=[]) # mailbox is lazy # Mailbox doesn't remember how many times it's sent a release, and will # re-send a new one for each peer message it receives. Ignoring it here # is easier than adding a new pair of states to Mailbox. @@ -161,5 +197,5 @@ class Nameplate(object): S5A.upon(connected, enter=S5B, outputs=[]) S5B.upon(lost, enter=S5A, outputs=[]) - S5.upon(release, enter=S5, outputs=[]) # mailbox is lazy + S5.upon(release, enter=S5, outputs=[]) # mailbox is lazy S5.upon(close, enter=S5, outputs=[]) diff --git a/src/wormhole/_order.py b/src/wormhole/_order.py index e2128df..18f6a8e 100644 --- a/src/wormhole/_order.py +++ b/src/wormhole/_order.py @@ -1,32 +1,40 @@ -from __future__ import print_function, absolute_import, unicode_literals -from zope.interface import implementer -from attr import attrs, attrib -from attr.validators import provides, instance_of +from __future__ import absolute_import, print_function, unicode_literals + +from attr import attrib, attrs +from attr.validators import instance_of, provides from automat import MethodicalMachine +from zope.interface import implementer + from . import _interfaces + @attrs @implementer(_interfaces.IOrder) class Order(object): _side = attrib(validator=instance_of(type(u""))) _timing = attrib(validator=provides(_interfaces.ITiming)) m = MethodicalMachine() - set_trace = getattr(m, "_setTrace", lambda self, f: None) # pragma: no cover + set_trace = getattr(m, "_setTrace", + lambda self, f: None) # pragma: no cover def __attrs_post_init__(self): self._key = None self._queue = [] + def wire(self, key, receive): self._K = _interfaces.IKey(key) self._R = _interfaces.IReceive(receive) @m.state(initial=True) - def S0_no_pake(self): pass # pragma: no cover + def S0_no_pake(self): + pass # pragma: no cover + @m.state(terminal=True) - def S1_yes_pake(self): pass # pragma: no cover + def S1_yes_pake(self): + pass # pragma: no cover def got_message(self, side, phase, body): - #print("ORDER[%s].got_message(%s)" % (self._side, phase)) + # print("ORDER[%s].got_message(%s)" % (self._side, phase)) assert isinstance(side, type("")), type(phase) assert isinstance(phase, type("")), type(phase) assert isinstance(body, type(b"")), type(body) @@ -36,9 +44,12 @@ class Order(object): self.got_non_pake(side, phase, body) @m.input() - def got_pake(self, side, phase, body): pass + def got_pake(self, side, phase, body): + pass + @m.input() - def got_non_pake(self, side, phase, body): pass + def got_non_pake(self, side, phase, body): + pass @m.output() def queue(self, side, phase, body): @@ -46,9 +57,11 @@ class Order(object): assert isinstance(phase, type("")), type(phase) assert isinstance(body, type(b"")), type(body) self._queue.append((side, phase, body)) + @m.output() def notify_key(self, side, phase, body): self._K.got_pake(body) + @m.output() def drain(self, side, phase, body): del phase @@ -56,6 +69,7 @@ class Order(object): for (side, phase, body) in self._queue: self._deliver(side, phase, body) self._queue[:] = [] + @m.output() def deliver(self, side, phase, body): self._deliver(side, phase, body) diff --git a/src/wormhole/_receive.py b/src/wormhole/_receive.py index 4fa2858..8e9de4f 100644 --- a/src/wormhole/_receive.py +++ b/src/wormhole/_receive.py @@ -1,10 +1,13 @@ -from __future__ import print_function, absolute_import, unicode_literals -from zope.interface import implementer -from attr import attrs, attrib -from attr.validators import provides, instance_of +from __future__ import absolute_import, print_function, unicode_literals + +from attr import attrib, attrs +from attr.validators import instance_of, provides from automat import MethodicalMachine +from zope.interface import implementer + from . import _interfaces -from ._key import derive_key, derive_phase_key, decrypt_data, CryptoError +from ._key import CryptoError, decrypt_data, derive_key, derive_phase_key + @attrs @implementer(_interfaces.IReceive) @@ -12,7 +15,8 @@ class Receive(object): _side = attrib(validator=instance_of(type(u""))) _timing = attrib(validator=provides(_interfaces.ITiming)) m = MethodicalMachine() - set_trace = getattr(m, "_setTrace", lambda self, f: None) # pragma: no cover + set_trace = getattr(m, "_setTrace", + lambda self, f: None) # pragma: no cover def __attrs_post_init__(self): self._key = None @@ -22,13 +26,20 @@ class Receive(object): self._S = _interfaces.ISend(send) @m.state(initial=True) - def S0_unknown_key(self): pass # pragma: no cover + def S0_unknown_key(self): + pass # pragma: no cover + @m.state() - def S1_unverified_key(self): pass # pragma: no cover + def S1_unverified_key(self): + pass # pragma: no cover + @m.state() - def S2_verified_key(self): pass # pragma: no cover + def S2_verified_key(self): + pass # pragma: no cover + @m.state(terminal=True) - def S3_scared(self): pass # pragma: no cover + def S3_scared(self): + pass # pragma: no cover # from Ordering def got_message(self, side, phase, body): @@ -43,47 +54,56 @@ class Receive(object): self.got_message_bad() return self.got_message_good(phase, plaintext) + @m.input() - def got_message_good(self, phase, plaintext): pass + def got_message_good(self, phase, plaintext): + pass + @m.input() - def got_message_bad(self): pass + def got_message_bad(self): + pass # from Key @m.input() - def got_key(self, key): pass + def got_key(self, key): + pass @m.output() def record_key(self, key): self._key = key + @m.output() def S_got_verified_key(self, phase, plaintext): assert self._key self._S.got_verified_key(self._key) + @m.output() def W_happy(self, phase, plaintext): self._B.happy() + @m.output() def W_got_verifier(self, phase, plaintext): self._B.got_verifier(derive_key(self._key, b"wormhole:verifier")) + @m.output() def W_got_message(self, phase, plaintext): assert isinstance(phase, type("")), type(phase) assert isinstance(plaintext, type(b"")), type(plaintext) self._B.got_message(phase, plaintext) + @m.output() def W_scared(self): self._B.scared() S0_unknown_key.upon(got_key, enter=S1_unverified_key, outputs=[record_key]) - S1_unverified_key.upon(got_message_good, enter=S2_verified_key, - outputs=[S_got_verified_key, - W_happy, W_got_verifier, W_got_message]) - S1_unverified_key.upon(got_message_bad, enter=S3_scared, - outputs=[W_scared]) - S2_verified_key.upon(got_message_bad, enter=S3_scared, - outputs=[W_scared]) - S2_verified_key.upon(got_message_good, enter=S2_verified_key, - outputs=[W_got_message]) + S1_unverified_key.upon( + got_message_good, + enter=S2_verified_key, + outputs=[S_got_verified_key, W_happy, W_got_verifier, W_got_message]) + S1_unverified_key.upon( + got_message_bad, enter=S3_scared, outputs=[W_scared]) + S2_verified_key.upon(got_message_bad, enter=S3_scared, outputs=[W_scared]) + S2_verified_key.upon( + got_message_good, enter=S2_verified_key, outputs=[W_got_message]) S3_scared.upon(got_message_good, enter=S3_scared, outputs=[]) S3_scared.upon(got_message_bad, enter=S3_scared, outputs=[]) - diff --git a/src/wormhole/_rendezvous.py b/src/wormhole/_rendezvous.py index 963d12a..65d6545 100644 --- a/src/wormhole/_rendezvous.py +++ b/src/wormhole/_rendezvous.py @@ -9,44 +9,47 @@ from twisted.internet import defer, endpoints, task from twisted.application import internet from autobahn.twisted import websocket from . import _interfaces, errors -from .util import (bytes_to_hexstr, hexstr_to_bytes, - bytes_to_dict, dict_to_bytes) +from .util import (bytes_to_hexstr, hexstr_to_bytes, bytes_to_dict, + dict_to_bytes) + class WSClient(websocket.WebSocketClientProtocol): def onConnect(self, response): # this fires during WebSocket negotiation, and isn't very useful # unless you want to modify the protocol settings - #print("onConnect", response) + # print("onConnect", response) pass def onOpen(self, *args): # this fires when the WebSocket is ready to go. No arguments - #print("onOpen", args) - #self.wormhole_open = True + # print("onOpen", args) + # self.wormhole_open = True self._RC.ws_open(self) def onMessage(self, payload, isBinary): assert not isBinary try: self._RC.ws_message(payload) - except: + except Exception: from twisted.python.failure import Failure print("LOGGING", Failure()) log.err() raise def onClose(self, wasClean, code, reason): - #print("onClose") + # print("onClose") self._RC.ws_close(wasClean, code, reason) - #if self.wormhole_open: - # self.wormhole._ws_closed(wasClean, code, reason) - #else: - # # we closed before establishing a connection (onConnect) or - # # finishing WebSocket negotiation (onOpen): errback - # self.factory.d.errback(error.ConnectError(reason)) + # if self.wormhole_open: + # self.wormhole._ws_closed(wasClean, code, reason) + # else: + # # we closed before establishing a connection (onConnect) or + # # finishing WebSocket negotiation (onOpen): errback + # self.factory.d.errback(error.ConnectError(reason)) + class WSFactory(websocket.WebSocketClientFactory): protocol = WSClient + def __init__(self, RC, *args, **kwargs): websocket.WebSocketClientFactory.__init__(self, *args, **kwargs) self._RC = RC @@ -54,9 +57,10 @@ class WSFactory(websocket.WebSocketClientFactory): def buildProtocol(self, addr): proto = websocket.WebSocketClientFactory.buildProtocol(self, addr) proto._RC = self._RC - #proto.wormhole_open = False + # proto.wormhole_open = False return proto + @attrs @implementer(_interfaces.IRendezvousConnector) class RendezvousConnector(object): @@ -90,6 +94,7 @@ class RendezvousConnector(object): def set_trace(self, f): self._trace = f + def _debug(self, what): if self._trace: self._trace(old_state="", input=what, new_state="") @@ -133,14 +138,13 @@ class RendezvousConnector(object): def stop(self): # ClientService.stopService is defined to "Stop attempting to # reconnect and close any existing connections" - self._stopping = True # to catch _initial_connection_failed error + self._stopping = True # to catch _initial_connection_failed error d = defer.maybeDeferred(self._connector.stopService) # ClientService.stopService always fires with None, even if the # initial connection failed, so log.err just in case d.addErrback(log.err) d.addBoth(self._stopped) - # from Lister def tx_list(self): self._tx("list") @@ -157,7 +161,7 @@ class RendezvousConnector(object): # this should happen right away: the ClientService ought to be in # the "_waiting" state, and everything in the _waiting.stop # transition is immediate - d.addErrback(log.err) # just in case something goes wrong + d.addErrback(log.err) # just in case something goes wrong d.addCallback(lambda _: self._B.error(sce)) # from our WSClient (the WebSocket protocol) @@ -166,8 +170,11 @@ class RendezvousConnector(object): self._have_made_a_successful_connection = True self._ws = proto try: - self._tx("bind", appid=self._appid, side=self._side, - client_version=self._client_version) + self._tx( + "bind", + appid=self._appid, + side=self._side, + client_version=self._client_version) self._N.connected() self._M.connected() self._L.connected() @@ -180,19 +187,22 @@ class RendezvousConnector(object): def ws_message(self, payload): msg = bytes_to_dict(payload) if msg["type"] != "ack": - self._debug("R.rx(%s %s%s)" % - (msg["type"], msg.get("phase",""), - "[mine]" if msg.get("side","") == self._side else "", - )) + self._debug("R.rx(%s %s%s)" % ( + msg["type"], + msg.get("phase", ""), + "[mine]" if msg.get("side", "") == self._side else "", + )) self._timing.add("ws_receive", _side=self._side, message=msg) if self._debug_record_inbound_f: self._debug_record_inbound_f(msg) mtype = msg["type"] - meth = getattr(self, "_response_handle_"+mtype, None) + meth = getattr(self, "_response_handle_" + mtype, None) if not meth: # make tests fail, but real application will ignore it - log.err(errors._UnknownMessageTypeError("Unknown inbound message type %r" % (msg,))) + log.err( + errors._UnknownMessageTypeError( + "Unknown inbound message type %r" % (msg, ))) return try: return meth(msg) @@ -229,7 +239,7 @@ class RendezvousConnector(object): # valid connection sce = errors.ServerConnectionError(self._url, reason) d = defer.maybeDeferred(self._connector.stopService) - d.addErrback(log.err) # just in case something goes wrong + d.addErrback(log.err) # just in case something goes wrong # tell the Boss to quit and inform the user d.addCallback(lambda _: self._B.error(sce)) @@ -292,7 +302,7 @@ class RendezvousConnector(object): side = msg["side"] phase = msg["phase"] assert isinstance(phase, type("")), type(phase) - body = hexstr_to_bytes(msg["body"]) # bytes + body = hexstr_to_bytes(msg["body"]) # bytes self._M.rx_message(side, phase, body) def _response_handle_released(self, msg): @@ -301,5 +311,4 @@ class RendezvousConnector(object): def _response_handle_closed(self, msg): self._M.rx_closed() - # record, message, payload, packet, bundle, ciphertext, plaintext diff --git a/src/wormhole/_rlcompleter.py b/src/wormhole/_rlcompleter.py index 853b98f..e9262db 100644 --- a/src/wormhole/_rlcompleter.py +++ b/src/wormhole/_rlcompleter.py @@ -1,33 +1,41 @@ from __future__ import print_function, unicode_literals + import traceback from sys import stderr + +from attr import attrib, attrs +from six.moves import input +from twisted.internet.defer import inlineCallbacks, returnValue +from twisted.internet.threads import blockingCallFromThread, deferToThread + +from .errors import AlreadyInputNameplateError, KeyFormatError + try: import readline except ImportError: readline = None -from six.moves import input -from attr import attrs, attrib -from twisted.internet.defer import inlineCallbacks, returnValue -from twisted.internet.threads import deferToThread, blockingCallFromThread -from .errors import KeyFormatError, AlreadyInputNameplateError errf = None + + # uncomment this to enable tab-completion debugging -#import os ; errf = open("err", "w") if os.path.exists("err") else None -def debug(*args, **kwargs): # pragma: no cover +# import os ; errf = open("err", "w") if os.path.exists("err") else None +def debug(*args, **kwargs): # pragma: no cover if errf: print(*args, file=errf, **kwargs) errf.flush() + @attrs class CodeInputter(object): _input_helper = attrib() _reactor = attrib() + def __attrs_post_init__(self): self.used_completion = False self._matches = None # once we've claimed the nameplate, we can't go back - self._committed_nameplate = None # or string + self._committed_nameplate = None # or string def bcft(self, f, *a, **kw): return blockingCallFromThread(self._reactor, f, *a, **kw) @@ -66,7 +74,7 @@ class CodeInputter(object): nameplate, words = text.split("-", 1) else: got_nameplate = False - nameplate = text # partial + nameplate = text # partial # 'text' is one of these categories: # "" or "12": complete on nameplates (all that match, maybe just one) @@ -83,13 +91,15 @@ class CodeInputter(object): # they deleted past the committment point: we can't use # this. For now, bail, but in the future let's find a # gentler way to encourage them to not do that. - raise AlreadyInputNameplateError("nameplate (%s-) already entered, cannot go back" % self._committed_nameplate) + raise AlreadyInputNameplateError( + "nameplate (%s-) already entered, cannot go back" % + self._committed_nameplate) if not got_nameplate: # we're completing on nameplates: "" or "12" or "123" - self.bcft(ih.refresh_nameplates) # results arrive later + self.bcft(ih.refresh_nameplates) # results arrive later debug(" getting nameplates") completions = self.bcft(ih.get_nameplate_completions, nameplate) - else: # "123-" or "123-supp" + else: # "123-" or "123-supp" # time to commit to this nameplate, if they haven't already if not self._committed_nameplate: debug(" choose_nameplate(%s)" % nameplate) @@ -112,11 +122,13 @@ class CodeInputter(object): # heard about it from the server), it can't be helped. But # for the rest of the code, a simple wait-for-wordlist will # improve the user experience. - self.bcft(ih.when_wordlist_is_available) # blocks on CLAIM + self.bcft(ih.when_wordlist_is_available) # blocks on CLAIM # and we're completing on words now - debug(" getting words (%s)" % (words,)) - completions = [nameplate+"-"+c - for c in self.bcft(ih.get_word_completions, words)] + debug(" getting words (%s)" % (words, )) + completions = [ + nameplate + "-" + c + for c in self.bcft(ih.get_word_completions, words) + ] # rlcompleter wants full strings return sorted(completions) @@ -131,13 +143,16 @@ class CodeInputter(object): # they deleted past the committment point: we can't use # this. For now, bail, but in the future let's find a # gentler way to encourage them to not do that. - raise AlreadyInputNameplateError("nameplate (%s-) already entered, cannot go back" % self._committed_nameplate) + raise AlreadyInputNameplateError( + "nameplate (%s-) already entered, cannot go back" % + self._committed_nameplate) else: debug(" choose_nameplate(%s)" % nameplate) self.bcft(self._input_helper.choose_nameplate, nameplate) debug(" choose_words(%s)" % words) self.bcft(self._input_helper.choose_words, words) + def _input_code_with_completion(prompt, input_helper, reactor): # reminder: this all occurs in a separate thread. All calls to input_helper # must go through blockingCallFromThread() @@ -159,6 +174,7 @@ def _input_code_with_completion(prompt, input_helper, reactor): c.finish(code) return c.used_completion + def warn_readline(): # When our process receives a SIGINT, Twisted's SIGINT handler will # stop the reactor and wait for all threads to terminate before the @@ -192,11 +208,12 @@ def warn_readline(): # doesn't see the signal, and we must still wait for stdin to make # readline finish. + @inlineCallbacks def input_with_completion(prompt, input_helper, reactor): t = reactor.addSystemEventTrigger("before", "shutdown", warn_readline) - #input_helper.refresh_nameplates() - used_completion = yield deferToThread(_input_code_with_completion, - prompt, input_helper, reactor) + # input_helper.refresh_nameplates() + used_completion = yield deferToThread(_input_code_with_completion, prompt, + input_helper, reactor) reactor.removeSystemEventTrigger(t) returnValue(used_completion) diff --git a/src/wormhole/_send.py b/src/wormhole/_send.py index d0c45c0..e0293c6 100644 --- a/src/wormhole/_send.py +++ b/src/wormhole/_send.py @@ -1,18 +1,22 @@ -from __future__ import print_function, absolute_import, unicode_literals -from attr import attrs, attrib -from attr.validators import provides, instance_of -from zope.interface import implementer +from __future__ import absolute_import, print_function, unicode_literals + +from attr import attrib, attrs +from attr.validators import instance_of, provides from automat import MethodicalMachine +from zope.interface import implementer + from . import _interfaces from ._key import derive_phase_key, encrypt_data + @attrs @implementer(_interfaces.ISend) class Send(object): _side = attrib(validator=instance_of(type(u""))) _timing = attrib(validator=provides(_interfaces.ITiming)) m = MethodicalMachine() - set_trace = getattr(m, "_setTrace", lambda self, f: None) # pragma: no cover + set_trace = getattr(m, "_setTrace", + lambda self, f: None) # pragma: no cover def __attrs_post_init__(self): self._queue = [] @@ -21,31 +25,40 @@ class Send(object): self._M = _interfaces.IMailbox(mailbox) @m.state(initial=True) - def S0_no_key(self): pass # pragma: no cover + def S0_no_key(self): + pass # pragma: no cover + @m.state(terminal=True) - def S1_verified_key(self): pass # pragma: no cover + def S1_verified_key(self): + pass # pragma: no cover # from Receive @m.input() - def got_verified_key(self, key): pass + def got_verified_key(self, key): + pass + # from Boss @m.input() - def send(self, phase, plaintext): pass + def send(self, phase, plaintext): + pass @m.output() def queue(self, phase, plaintext): assert isinstance(phase, type("")), type(phase) assert isinstance(plaintext, type(b"")), type(plaintext) self._queue.append((phase, plaintext)) + @m.output() def record_key(self, key): self._key = key + @m.output() def drain(self, key): del key for (phase, plaintext) in self._queue: self._encrypt_and_send(phase, plaintext) self._queue[:] = [] + @m.output() def deliver(self, phase, plaintext): assert isinstance(phase, type("")), type(phase) @@ -59,6 +72,6 @@ class Send(object): self._M.add_message(phase, encrypted) S0_no_key.upon(send, enter=S0_no_key, outputs=[queue]) - S0_no_key.upon(got_verified_key, enter=S1_verified_key, - outputs=[record_key, drain]) + S0_no_key.upon( + got_verified_key, enter=S1_verified_key, outputs=[record_key, drain]) S1_verified_key.upon(send, enter=S1_verified_key, outputs=[deliver]) diff --git a/src/wormhole/_terminator.py b/src/wormhole/_terminator.py index 2cb7327..fe4bdcb 100644 --- a/src/wormhole/_terminator.py +++ b/src/wormhole/_terminator.py @@ -1,12 +1,16 @@ -from __future__ import print_function, absolute_import, unicode_literals -from zope.interface import implementer +from __future__ import absolute_import, print_function, unicode_literals + from automat import MethodicalMachine +from zope.interface import implementer + from . import _interfaces + @implementer(_interfaces.ITerminator) class Terminator(object): m = MethodicalMachine() - set_trace = getattr(m, "_setTrace", lambda self, f: None) # pragma: no cover + set_trace = getattr(m, "_setTrace", + lambda self, f: None) # pragma: no cover def __init__(self): self._mood = None @@ -29,48 +33,68 @@ class Terminator(object): # done, and we're closing, then we stop the RendezvousConnector @m.state(initial=True) - def Snmo(self): pass # pragma: no cover - @m.state() - def Smo(self): pass # pragma: no cover - @m.state() - def Sno(self): pass # pragma: no cover - @m.state() - def S0o(self): pass # pragma: no cover + def Snmo(self): + pass # pragma: no cover @m.state() - def Snm(self): pass # pragma: no cover - @m.state() - def Sm(self): pass # pragma: no cover - @m.state() - def Sn(self): pass # pragma: no cover - #@m.state() - #def S0(self): pass # unused + def Smo(self): + pass # pragma: no cover @m.state() - def S_stopping(self): pass # pragma: no cover + def Sno(self): + pass # pragma: no cover + @m.state() - def S_stopped(self, terminal=True): pass # pragma: no cover + def S0o(self): + pass # pragma: no cover + + @m.state() + def Snm(self): + pass # pragma: no cover + + @m.state() + def Sm(self): + pass # pragma: no cover + + @m.state() + def Sn(self): + pass # pragma: no cover + + # @m.state() + # def S0(self): pass # unused + + @m.state() + def S_stopping(self): + pass # pragma: no cover + + @m.state() + def S_stopped(self, terminal=True): + pass # pragma: no cover # from Boss @m.input() - def close(self, mood): pass + def close(self, mood): + pass # from Nameplate @m.input() - def nameplate_done(self): pass + def nameplate_done(self): + pass # from Mailbox @m.input() - def mailbox_done(self): pass + def mailbox_done(self): + pass # from RendezvousConnector @m.input() - def stopped(self): pass - + def stopped(self): + pass @m.output() def close_nameplate(self, mood): - self._N.close() # ignores mood + self._N.close() # ignores mood + @m.output() def close_mailbox(self, mood): self._M.close(mood) @@ -78,9 +102,11 @@ class Terminator(object): @m.output() def ignore_mood_and_RC_stop(self, mood): self._RC.stop() + @m.output() def RC_stop(self): self._RC.stop() + @m.output() def B_closed(self): self._B.closed() @@ -99,8 +125,10 @@ class Terminator(object): Snm.upon(nameplate_done, enter=Sm, outputs=[]) Sn.upon(nameplate_done, enter=S_stopping, outputs=[RC_stop]) - S0o.upon(close, enter=S_stopping, - outputs=[close_nameplate, close_mailbox, ignore_mood_and_RC_stop]) + S0o.upon( + close, + enter=S_stopping, + outputs=[close_nameplate, close_mailbox, ignore_mood_and_RC_stop]) Sm.upon(mailbox_done, enter=S_stopping, outputs=[RC_stop]) S_stopping.upon(stopped, enter=S_stopped, outputs=[B_closed]) diff --git a/src/wormhole/_wordlist.py b/src/wormhole/_wordlist.py index e14972b..b434c71 100644 --- a/src/wormhole/_wordlist.py +++ b/src/wormhole/_wordlist.py @@ -1,6 +1,10 @@ -from __future__ import unicode_literals, print_function +from __future__ import print_function, unicode_literals + import os +from binascii import unhexlify + from zope.interface import implementer + from ._interfaces import IWordlist # The PGP Word List, which maps bytes to phonetically-distinct words. There @@ -10,154 +14,280 @@ from ._interfaces import IWordlist # Thanks to Warren Guy for transcribing them: # https://github.com/warrenguy/javascript-pgp-word-list -from binascii import unhexlify raw_words = { -'00': ['aardvark', 'adroitness'], '01': ['absurd', 'adviser'], -'02': ['accrue', 'aftermath'], '03': ['acme', 'aggregate'], -'04': ['adrift', 'alkali'], '05': ['adult', 'almighty'], -'06': ['afflict', 'amulet'], '07': ['ahead', 'amusement'], -'08': ['aimless', 'antenna'], '09': ['Algol', 'applicant'], -'0A': ['allow', 'Apollo'], '0B': ['alone', 'armistice'], -'0C': ['ammo', 'article'], '0D': ['ancient', 'asteroid'], -'0E': ['apple', 'Atlantic'], '0F': ['artist', 'atmosphere'], -'10': ['assume', 'autopsy'], '11': ['Athens', 'Babylon'], -'12': ['atlas', 'backwater'], '13': ['Aztec', 'barbecue'], -'14': ['baboon', 'belowground'], '15': ['backfield', 'bifocals'], -'16': ['backward', 'bodyguard'], '17': ['banjo', 'bookseller'], -'18': ['beaming', 'borderline'], '19': ['bedlamp', 'bottomless'], -'1A': ['beehive', 'Bradbury'], '1B': ['beeswax', 'bravado'], -'1C': ['befriend', 'Brazilian'], '1D': ['Belfast', 'breakaway'], -'1E': ['berserk', 'Burlington'], '1F': ['billiard', 'businessman'], -'20': ['bison', 'butterfat'], '21': ['blackjack', 'Camelot'], -'22': ['blockade', 'candidate'], '23': ['blowtorch', 'cannonball'], -'24': ['bluebird', 'Capricorn'], '25': ['bombast', 'caravan'], -'26': ['bookshelf', 'caretaker'], '27': ['brackish', 'celebrate'], -'28': ['breadline', 'cellulose'], '29': ['breakup', 'certify'], -'2A': ['brickyard', 'chambermaid'], '2B': ['briefcase', 'Cherokee'], -'2C': ['Burbank', 'Chicago'], '2D': ['button', 'clergyman'], -'2E': ['buzzard', 'coherence'], '2F': ['cement', 'combustion'], -'30': ['chairlift', 'commando'], '31': ['chatter', 'company'], -'32': ['checkup', 'component'], '33': ['chisel', 'concurrent'], -'34': ['choking', 'confidence'], '35': ['chopper', 'conformist'], -'36': ['Christmas', 'congregate'], '37': ['clamshell', 'consensus'], -'38': ['classic', 'consulting'], '39': ['classroom', 'corporate'], -'3A': ['cleanup', 'corrosion'], '3B': ['clockwork', 'councilman'], -'3C': ['cobra', 'crossover'], '3D': ['commence', 'crucifix'], -'3E': ['concert', 'cumbersome'], '3F': ['cowbell', 'customer'], -'40': ['crackdown', 'Dakota'], '41': ['cranky', 'decadence'], -'42': ['crowfoot', 'December'], '43': ['crucial', 'decimal'], -'44': ['crumpled', 'designing'], '45': ['crusade', 'detector'], -'46': ['cubic', 'detergent'], '47': ['dashboard', 'determine'], -'48': ['deadbolt', 'dictator'], '49': ['deckhand', 'dinosaur'], -'4A': ['dogsled', 'direction'], '4B': ['dragnet', 'disable'], -'4C': ['drainage', 'disbelief'], '4D': ['dreadful', 'disruptive'], -'4E': ['drifter', 'distortion'], '4F': ['dropper', 'document'], -'50': ['drumbeat', 'embezzle'], '51': ['drunken', 'enchanting'], -'52': ['Dupont', 'enrollment'], '53': ['dwelling', 'enterprise'], -'54': ['eating', 'equation'], '55': ['edict', 'equipment'], -'56': ['egghead', 'escapade'], '57': ['eightball', 'Eskimo'], -'58': ['endorse', 'everyday'], '59': ['endow', 'examine'], -'5A': ['enlist', 'existence'], '5B': ['erase', 'exodus'], -'5C': ['escape', 'fascinate'], '5D': ['exceed', 'filament'], -'5E': ['eyeglass', 'finicky'], '5F': ['eyetooth', 'forever'], -'60': ['facial', 'fortitude'], '61': ['fallout', 'frequency'], -'62': ['flagpole', 'gadgetry'], '63': ['flatfoot', 'Galveston'], -'64': ['flytrap', 'getaway'], '65': ['fracture', 'glossary'], -'66': ['framework', 'gossamer'], '67': ['freedom', 'graduate'], -'68': ['frighten', 'gravity'], '69': ['gazelle', 'guitarist'], -'6A': ['Geiger', 'hamburger'], '6B': ['glitter', 'Hamilton'], -'6C': ['glucose', 'handiwork'], '6D': ['goggles', 'hazardous'], -'6E': ['goldfish', 'headwaters'], '6F': ['gremlin', 'hemisphere'], -'70': ['guidance', 'hesitate'], '71': ['hamlet', 'hideaway'], -'72': ['highchair', 'holiness'], '73': ['hockey', 'hurricane'], -'74': ['indoors', 'hydraulic'], '75': ['indulge', 'impartial'], -'76': ['inverse', 'impetus'], '77': ['involve', 'inception'], -'78': ['island', 'indigo'], '79': ['jawbone', 'inertia'], -'7A': ['keyboard', 'infancy'], '7B': ['kickoff', 'inferno'], -'7C': ['kiwi', 'informant'], '7D': ['klaxon', 'insincere'], -'7E': ['locale', 'insurgent'], '7F': ['lockup', 'integrate'], -'80': ['merit', 'intention'], '81': ['minnow', 'inventive'], -'82': ['miser', 'Istanbul'], '83': ['Mohawk', 'Jamaica'], -'84': ['mural', 'Jupiter'], '85': ['music', 'leprosy'], -'86': ['necklace', 'letterhead'], '87': ['Neptune', 'liberty'], -'88': ['newborn', 'maritime'], '89': ['nightbird', 'matchmaker'], -'8A': ['Oakland', 'maverick'], '8B': ['obtuse', 'Medusa'], -'8C': ['offload', 'megaton'], '8D': ['optic', 'microscope'], -'8E': ['orca', 'microwave'], '8F': ['payday', 'midsummer'], -'90': ['peachy', 'millionaire'], '91': ['pheasant', 'miracle'], -'92': ['physique', 'misnomer'], '93': ['playhouse', 'molasses'], -'94': ['Pluto', 'molecule'], '95': ['preclude', 'Montana'], -'96': ['prefer', 'monument'], '97': ['preshrunk', 'mosquito'], -'98': ['printer', 'narrative'], '99': ['prowler', 'nebula'], -'9A': ['pupil', 'newsletter'], '9B': ['puppy', 'Norwegian'], -'9C': ['python', 'October'], '9D': ['quadrant', 'Ohio'], -'9E': ['quiver', 'onlooker'], '9F': ['quota', 'opulent'], -'A0': ['ragtime', 'Orlando'], 'A1': ['ratchet', 'outfielder'], -'A2': ['rebirth', 'Pacific'], 'A3': ['reform', 'pandemic'], -'A4': ['regain', 'Pandora'], 'A5': ['reindeer', 'paperweight'], -'A6': ['rematch', 'paragon'], 'A7': ['repay', 'paragraph'], -'A8': ['retouch', 'paramount'], 'A9': ['revenge', 'passenger'], -'AA': ['reward', 'pedigree'], 'AB': ['rhythm', 'Pegasus'], -'AC': ['ribcage', 'penetrate'], 'AD': ['ringbolt', 'perceptive'], -'AE': ['robust', 'performance'], 'AF': ['rocker', 'pharmacy'], -'B0': ['ruffled', 'phonetic'], 'B1': ['sailboat', 'photograph'], -'B2': ['sawdust', 'pioneer'], 'B3': ['scallion', 'pocketful'], -'B4': ['scenic', 'politeness'], 'B5': ['scorecard', 'positive'], -'B6': ['Scotland', 'potato'], 'B7': ['seabird', 'processor'], -'B8': ['select', 'provincial'], 'B9': ['sentence', 'proximate'], -'BA': ['shadow', 'puberty'], 'BB': ['shamrock', 'publisher'], -'BC': ['showgirl', 'pyramid'], 'BD': ['skullcap', 'quantity'], -'BE': ['skydive', 'racketeer'], 'BF': ['slingshot', 'rebellion'], -'C0': ['slowdown', 'recipe'], 'C1': ['snapline', 'recover'], -'C2': ['snapshot', 'repellent'], 'C3': ['snowcap', 'replica'], -'C4': ['snowslide', 'reproduce'], 'C5': ['solo', 'resistor'], -'C6': ['southward', 'responsive'], 'C7': ['soybean', 'retraction'], -'C8': ['spaniel', 'retrieval'], 'C9': ['spearhead', 'retrospect'], -'CA': ['spellbind', 'revenue'], 'CB': ['spheroid', 'revival'], -'CC': ['spigot', 'revolver'], 'CD': ['spindle', 'sandalwood'], -'CE': ['spyglass', 'sardonic'], 'CF': ['stagehand', 'Saturday'], -'D0': ['stagnate', 'savagery'], 'D1': ['stairway', 'scavenger'], -'D2': ['standard', 'sensation'], 'D3': ['stapler', 'sociable'], -'D4': ['steamship', 'souvenir'], 'D5': ['sterling', 'specialist'], -'D6': ['stockman', 'speculate'], 'D7': ['stopwatch', 'stethoscope'], -'D8': ['stormy', 'stupendous'], 'D9': ['sugar', 'supportive'], -'DA': ['surmount', 'surrender'], 'DB': ['suspense', 'suspicious'], -'DC': ['sweatband', 'sympathy'], 'DD': ['swelter', 'tambourine'], -'DE': ['tactics', 'telephone'], 'DF': ['talon', 'therapist'], -'E0': ['tapeworm', 'tobacco'], 'E1': ['tempest', 'tolerance'], -'E2': ['tiger', 'tomorrow'], 'E3': ['tissue', 'torpedo'], -'E4': ['tonic', 'tradition'], 'E5': ['topmost', 'travesty'], -'E6': ['tracker', 'trombonist'], 'E7': ['transit', 'truncated'], -'E8': ['trauma', 'typewriter'], 'E9': ['treadmill', 'ultimate'], -'EA': ['Trojan', 'undaunted'], 'EB': ['trouble', 'underfoot'], -'EC': ['tumor', 'unicorn'], 'ED': ['tunnel', 'unify'], -'EE': ['tycoon', 'universe'], 'EF': ['uncut', 'unravel'], -'F0': ['unearth', 'upcoming'], 'F1': ['unwind', 'vacancy'], -'F2': ['uproot', 'vagabond'], 'F3': ['upset', 'vertigo'], -'F4': ['upshot', 'Virginia'], 'F5': ['vapor', 'visitor'], -'F6': ['village', 'vocalist'], 'F7': ['virus', 'voyager'], -'F8': ['Vulcan', 'warranty'], 'F9': ['waffle', 'Waterloo'], -'FA': ['wallet', 'whimsical'], 'FB': ['watchword', 'Wichita'], -'FC': ['wayside', 'Wilmington'], 'FD': ['willow', 'Wyoming'], -'FE': ['woodlark', 'yesteryear'], 'FF': ['Zulu', 'Yucatan'] -}; + '00': ['aardvark', 'adroitness'], + '01': ['absurd', 'adviser'], + '02': ['accrue', 'aftermath'], + '03': ['acme', 'aggregate'], + '04': ['adrift', 'alkali'], + '05': ['adult', 'almighty'], + '06': ['afflict', 'amulet'], + '07': ['ahead', 'amusement'], + '08': ['aimless', 'antenna'], + '09': ['Algol', 'applicant'], + '0A': ['allow', 'Apollo'], + '0B': ['alone', 'armistice'], + '0C': ['ammo', 'article'], + '0D': ['ancient', 'asteroid'], + '0E': ['apple', 'Atlantic'], + '0F': ['artist', 'atmosphere'], + '10': ['assume', 'autopsy'], + '11': ['Athens', 'Babylon'], + '12': ['atlas', 'backwater'], + '13': ['Aztec', 'barbecue'], + '14': ['baboon', 'belowground'], + '15': ['backfield', 'bifocals'], + '16': ['backward', 'bodyguard'], + '17': ['banjo', 'bookseller'], + '18': ['beaming', 'borderline'], + '19': ['bedlamp', 'bottomless'], + '1A': ['beehive', 'Bradbury'], + '1B': ['beeswax', 'bravado'], + '1C': ['befriend', 'Brazilian'], + '1D': ['Belfast', 'breakaway'], + '1E': ['berserk', 'Burlington'], + '1F': ['billiard', 'businessman'], + '20': ['bison', 'butterfat'], + '21': ['blackjack', 'Camelot'], + '22': ['blockade', 'candidate'], + '23': ['blowtorch', 'cannonball'], + '24': ['bluebird', 'Capricorn'], + '25': ['bombast', 'caravan'], + '26': ['bookshelf', 'caretaker'], + '27': ['brackish', 'celebrate'], + '28': ['breadline', 'cellulose'], + '29': ['breakup', 'certify'], + '2A': ['brickyard', 'chambermaid'], + '2B': ['briefcase', 'Cherokee'], + '2C': ['Burbank', 'Chicago'], + '2D': ['button', 'clergyman'], + '2E': ['buzzard', 'coherence'], + '2F': ['cement', 'combustion'], + '30': ['chairlift', 'commando'], + '31': ['chatter', 'company'], + '32': ['checkup', 'component'], + '33': ['chisel', 'concurrent'], + '34': ['choking', 'confidence'], + '35': ['chopper', 'conformist'], + '36': ['Christmas', 'congregate'], + '37': ['clamshell', 'consensus'], + '38': ['classic', 'consulting'], + '39': ['classroom', 'corporate'], + '3A': ['cleanup', 'corrosion'], + '3B': ['clockwork', 'councilman'], + '3C': ['cobra', 'crossover'], + '3D': ['commence', 'crucifix'], + '3E': ['concert', 'cumbersome'], + '3F': ['cowbell', 'customer'], + '40': ['crackdown', 'Dakota'], + '41': ['cranky', 'decadence'], + '42': ['crowfoot', 'December'], + '43': ['crucial', 'decimal'], + '44': ['crumpled', 'designing'], + '45': ['crusade', 'detector'], + '46': ['cubic', 'detergent'], + '47': ['dashboard', 'determine'], + '48': ['deadbolt', 'dictator'], + '49': ['deckhand', 'dinosaur'], + '4A': ['dogsled', 'direction'], + '4B': ['dragnet', 'disable'], + '4C': ['drainage', 'disbelief'], + '4D': ['dreadful', 'disruptive'], + '4E': ['drifter', 'distortion'], + '4F': ['dropper', 'document'], + '50': ['drumbeat', 'embezzle'], + '51': ['drunken', 'enchanting'], + '52': ['Dupont', 'enrollment'], + '53': ['dwelling', 'enterprise'], + '54': ['eating', 'equation'], + '55': ['edict', 'equipment'], + '56': ['egghead', 'escapade'], + '57': ['eightball', 'Eskimo'], + '58': ['endorse', 'everyday'], + '59': ['endow', 'examine'], + '5A': ['enlist', 'existence'], + '5B': ['erase', 'exodus'], + '5C': ['escape', 'fascinate'], + '5D': ['exceed', 'filament'], + '5E': ['eyeglass', 'finicky'], + '5F': ['eyetooth', 'forever'], + '60': ['facial', 'fortitude'], + '61': ['fallout', 'frequency'], + '62': ['flagpole', 'gadgetry'], + '63': ['flatfoot', 'Galveston'], + '64': ['flytrap', 'getaway'], + '65': ['fracture', 'glossary'], + '66': ['framework', 'gossamer'], + '67': ['freedom', 'graduate'], + '68': ['frighten', 'gravity'], + '69': ['gazelle', 'guitarist'], + '6A': ['Geiger', 'hamburger'], + '6B': ['glitter', 'Hamilton'], + '6C': ['glucose', 'handiwork'], + '6D': ['goggles', 'hazardous'], + '6E': ['goldfish', 'headwaters'], + '6F': ['gremlin', 'hemisphere'], + '70': ['guidance', 'hesitate'], + '71': ['hamlet', 'hideaway'], + '72': ['highchair', 'holiness'], + '73': ['hockey', 'hurricane'], + '74': ['indoors', 'hydraulic'], + '75': ['indulge', 'impartial'], + '76': ['inverse', 'impetus'], + '77': ['involve', 'inception'], + '78': ['island', 'indigo'], + '79': ['jawbone', 'inertia'], + '7A': ['keyboard', 'infancy'], + '7B': ['kickoff', 'inferno'], + '7C': ['kiwi', 'informant'], + '7D': ['klaxon', 'insincere'], + '7E': ['locale', 'insurgent'], + '7F': ['lockup', 'integrate'], + '80': ['merit', 'intention'], + '81': ['minnow', 'inventive'], + '82': ['miser', 'Istanbul'], + '83': ['Mohawk', 'Jamaica'], + '84': ['mural', 'Jupiter'], + '85': ['music', 'leprosy'], + '86': ['necklace', 'letterhead'], + '87': ['Neptune', 'liberty'], + '88': ['newborn', 'maritime'], + '89': ['nightbird', 'matchmaker'], + '8A': ['Oakland', 'maverick'], + '8B': ['obtuse', 'Medusa'], + '8C': ['offload', 'megaton'], + '8D': ['optic', 'microscope'], + '8E': ['orca', 'microwave'], + '8F': ['payday', 'midsummer'], + '90': ['peachy', 'millionaire'], + '91': ['pheasant', 'miracle'], + '92': ['physique', 'misnomer'], + '93': ['playhouse', 'molasses'], + '94': ['Pluto', 'molecule'], + '95': ['preclude', 'Montana'], + '96': ['prefer', 'monument'], + '97': ['preshrunk', 'mosquito'], + '98': ['printer', 'narrative'], + '99': ['prowler', 'nebula'], + '9A': ['pupil', 'newsletter'], + '9B': ['puppy', 'Norwegian'], + '9C': ['python', 'October'], + '9D': ['quadrant', 'Ohio'], + '9E': ['quiver', 'onlooker'], + '9F': ['quota', 'opulent'], + 'A0': ['ragtime', 'Orlando'], + 'A1': ['ratchet', 'outfielder'], + 'A2': ['rebirth', 'Pacific'], + 'A3': ['reform', 'pandemic'], + 'A4': ['regain', 'Pandora'], + 'A5': ['reindeer', 'paperweight'], + 'A6': ['rematch', 'paragon'], + 'A7': ['repay', 'paragraph'], + 'A8': ['retouch', 'paramount'], + 'A9': ['revenge', 'passenger'], + 'AA': ['reward', 'pedigree'], + 'AB': ['rhythm', 'Pegasus'], + 'AC': ['ribcage', 'penetrate'], + 'AD': ['ringbolt', 'perceptive'], + 'AE': ['robust', 'performance'], + 'AF': ['rocker', 'pharmacy'], + 'B0': ['ruffled', 'phonetic'], + 'B1': ['sailboat', 'photograph'], + 'B2': ['sawdust', 'pioneer'], + 'B3': ['scallion', 'pocketful'], + 'B4': ['scenic', 'politeness'], + 'B5': ['scorecard', 'positive'], + 'B6': ['Scotland', 'potato'], + 'B7': ['seabird', 'processor'], + 'B8': ['select', 'provincial'], + 'B9': ['sentence', 'proximate'], + 'BA': ['shadow', 'puberty'], + 'BB': ['shamrock', 'publisher'], + 'BC': ['showgirl', 'pyramid'], + 'BD': ['skullcap', 'quantity'], + 'BE': ['skydive', 'racketeer'], + 'BF': ['slingshot', 'rebellion'], + 'C0': ['slowdown', 'recipe'], + 'C1': ['snapline', 'recover'], + 'C2': ['snapshot', 'repellent'], + 'C3': ['snowcap', 'replica'], + 'C4': ['snowslide', 'reproduce'], + 'C5': ['solo', 'resistor'], + 'C6': ['southward', 'responsive'], + 'C7': ['soybean', 'retraction'], + 'C8': ['spaniel', 'retrieval'], + 'C9': ['spearhead', 'retrospect'], + 'CA': ['spellbind', 'revenue'], + 'CB': ['spheroid', 'revival'], + 'CC': ['spigot', 'revolver'], + 'CD': ['spindle', 'sandalwood'], + 'CE': ['spyglass', 'sardonic'], + 'CF': ['stagehand', 'Saturday'], + 'D0': ['stagnate', 'savagery'], + 'D1': ['stairway', 'scavenger'], + 'D2': ['standard', 'sensation'], + 'D3': ['stapler', 'sociable'], + 'D4': ['steamship', 'souvenir'], + 'D5': ['sterling', 'specialist'], + 'D6': ['stockman', 'speculate'], + 'D7': ['stopwatch', 'stethoscope'], + 'D8': ['stormy', 'stupendous'], + 'D9': ['sugar', 'supportive'], + 'DA': ['surmount', 'surrender'], + 'DB': ['suspense', 'suspicious'], + 'DC': ['sweatband', 'sympathy'], + 'DD': ['swelter', 'tambourine'], + 'DE': ['tactics', 'telephone'], + 'DF': ['talon', 'therapist'], + 'E0': ['tapeworm', 'tobacco'], + 'E1': ['tempest', 'tolerance'], + 'E2': ['tiger', 'tomorrow'], + 'E3': ['tissue', 'torpedo'], + 'E4': ['tonic', 'tradition'], + 'E5': ['topmost', 'travesty'], + 'E6': ['tracker', 'trombonist'], + 'E7': ['transit', 'truncated'], + 'E8': ['trauma', 'typewriter'], + 'E9': ['treadmill', 'ultimate'], + 'EA': ['Trojan', 'undaunted'], + 'EB': ['trouble', 'underfoot'], + 'EC': ['tumor', 'unicorn'], + 'ED': ['tunnel', 'unify'], + 'EE': ['tycoon', 'universe'], + 'EF': ['uncut', 'unravel'], + 'F0': ['unearth', 'upcoming'], + 'F1': ['unwind', 'vacancy'], + 'F2': ['uproot', 'vagabond'], + 'F3': ['upset', 'vertigo'], + 'F4': ['upshot', 'Virginia'], + 'F5': ['vapor', 'visitor'], + 'F6': ['village', 'vocalist'], + 'F7': ['virus', 'voyager'], + 'F8': ['Vulcan', 'warranty'], + 'F9': ['waffle', 'Waterloo'], + 'FA': ['wallet', 'whimsical'], + 'FB': ['watchword', 'Wichita'], + 'FC': ['wayside', 'Wilmington'], + 'FD': ['willow', 'Wyoming'], + 'FE': ['woodlark', 'yesteryear'], + 'FF': ['Zulu', 'Yucatan'] +} byte_to_even_word = dict([(unhexlify(k.encode("ascii")), both_words[0]) - for k,both_words - in raw_words.items()]) + for k, both_words in raw_words.items()]) byte_to_odd_word = dict([(unhexlify(k.encode("ascii")), both_words[1]) - for k,both_words - in raw_words.items()]) + for k, both_words in raw_words.items()]) even_words_lowercase, odd_words_lowercase = set(), set() -for k,both_words in raw_words.items(): +for k, both_words in raw_words.items(): even_word, odd_word = both_words even_words_lowercase.add(even_word.lower()) odd_words_lowercase.add(odd_word.lower()) + @implementer(IWordlist) class PGPWordList(object): def get_completions(self, prefix, num_words=2): @@ -177,7 +307,7 @@ class PGPWordList(object): else: suffix = prefix[:-lp] + word # append a hyphen if we expect more words - if count+1 < num_words: + if count + 1 < num_words: suffix += "-" completions.add(suffix) return completions diff --git a/src/wormhole/cli/cli.py b/src/wormhole/cli/cli.py index 58ccd8f..ba74157 100644 --- a/src/wormhole/cli/cli.py +++ b/src/wormhole/cli/cli.py @@ -2,21 +2,24 @@ from __future__ import print_function import os import time -start = time.time() -import six -from textwrap import fill, dedent -from sys import stdout, stderr -from . import public_relay -from .. import __version__ -from ..timing import DebugTiming -from ..errors import (WrongPasswordError, WelcomeError, KeyFormatError, - TransferError, NoTorError, UnsendableFileError, - ServerConnectionError) -from twisted.internet.defer import inlineCallbacks, maybeDeferred -from twisted.python.failure import Failure -from twisted.internet.task import react +from sys import stderr, stdout +from textwrap import dedent, fill import click +import six +from twisted.internet.defer import inlineCallbacks, maybeDeferred +from twisted.internet.task import react +from twisted.python.failure import Failure + +from . import public_relay +from .. import __version__ +from ..errors import (KeyFormatError, NoTorError, ServerConnectionError, + TransferError, UnsendableFileError, WelcomeError, + WrongPasswordError) +from ..timing import DebugTiming + +start = time.time() + top_import_finish = time.time() @@ -24,6 +27,7 @@ class Config(object): """ Union of config options that we pass down to (sub) commands. """ + def __init__(self): # This only holds attributes which are *not* set by CLI arguments. # Everything else comes from Click decorators, so we can be sure @@ -34,11 +38,13 @@ class Config(object): self.stderr = stderr self.tor = False # XXX? + def _compose(*decorators): def decorate(f): for d in reversed(decorators): f = d(f) return f + return decorate @@ -48,6 +54,8 @@ ALIASES = { "recieve": "receive", "recv": "receive", } + + class AliasedGroup(click.Group): def get_command(self, ctx, cmd_name): cmd_name = ALIASES.get(cmd_name, cmd_name) @@ -56,22 +64,24 @@ class AliasedGroup(click.Group): # top-level command ("wormhole ...") @click.group(cls=AliasedGroup) +@click.option("--appid", default=None, metavar="APPID", help="appid to use") @click.option( - "--appid", default=None, metavar="APPID", help="appid to use") -@click.option( - "--relay-url", default=public_relay.RENDEZVOUS_RELAY, + "--relay-url", + default=public_relay.RENDEZVOUS_RELAY, envvar='WORMHOLE_RELAY_URL', metavar="URL", help="rendezvous relay to use", ) @click.option( - "--transit-helper", default=public_relay.TRANSIT_RELAY, + "--transit-helper", + default=public_relay.TRANSIT_RELAY, envvar='WORMHOLE_TRANSIT_HELPER', metavar="tcp:HOST:PORT", help="transit relay to use", ) @click.option( - "--dump-timing", type=type(u""), # TODO: hide from --help output + "--dump-timing", + type=type(u""), # TODO: hide from --help output default=None, metavar="FILE.json", help="(debug) write timing data to file", @@ -104,7 +114,8 @@ def _dispatch_command(reactor, cfg, command): errors for the user. """ cfg.timing.add("command dispatch") - cfg.timing.add("import", when=start, which="top").finish(when=top_import_finish) + cfg.timing.add( + "import", when=start, which="top").finish(when=top_import_finish) try: yield maybeDeferred(command) @@ -141,56 +152,89 @@ def _dispatch_command(reactor, cfg, command): CommonArgs = _compose( - click.option("-0", "zeromode", default=False, is_flag=True, - help="enable no-code anything-goes mode", - ), - click.option("-c", "--code-length", default=2, metavar="NUMWORDS", - help="length of code (in bytes/words)", - ), - click.option("-v", "--verify", is_flag=True, default=False, - help="display verification string (and wait for approval)", - ), - click.option("--hide-progress", is_flag=True, default=False, - help="supress progress-bar display", - ), - click.option("--listen/--no-listen", default=True, - help="(debug) don't open a listening socket for Transit", - ), + click.option( + "-0", + "zeromode", + default=False, + is_flag=True, + help="enable no-code anything-goes mode", + ), + click.option( + "-c", + "--code-length", + default=2, + metavar="NUMWORDS", + help="length of code (in bytes/words)", + ), + click.option( + "-v", + "--verify", + is_flag=True, + default=False, + help="display verification string (and wait for approval)", + ), + click.option( + "--hide-progress", + is_flag=True, + default=False, + help="supress progress-bar display", + ), + click.option( + "--listen/--no-listen", + default=True, + help="(debug) don't open a listening socket for Transit", + ), ) TorArgs = _compose( - click.option("--tor", is_flag=True, default=False, - help="use Tor when connecting", - ), - click.option("--launch-tor", is_flag=True, default=False, - help="launch Tor, rather than use existing control/socks port", - ), - click.option("--tor-control-port", default=None, metavar="ENDPOINT", - help="endpoint descriptor for Tor control port", - ), + click.option( + "--tor", + is_flag=True, + default=False, + help="use Tor when connecting", + ), + click.option( + "--launch-tor", + is_flag=True, + default=False, + help="launch Tor, rather than use existing control/socks port", + ), + click.option( + "--tor-control-port", + default=None, + metavar="ENDPOINT", + help="endpoint descriptor for Tor control port", + ), ) + @wormhole.command() @click.pass_context def help(context, **kwargs): print(context.find_root().get_help()) + # wormhole send (or "wormhole tx") @wormhole.command() @CommonArgs @TorArgs @click.option( - "--code", metavar="CODE", + "--code", + metavar="CODE", help="human-generated code phrase", ) @click.option( - "--text", default=None, metavar="MESSAGE", - help="text message to send, instead of a file. Use '-' to read from stdin.", + "--text", + default=None, + metavar="MESSAGE", + help=("text message to send, instead of a file." + " Use '-' to read from stdin."), ) @click.option( - "--ignore-unsendable-files", default=False, is_flag=True, - help="Don't raise an error if a file can't be read." -) + "--ignore-unsendable-files", + default=False, + is_flag=True, + help="Don't raise an error if a file can't be read.") @click.argument("what", required=False, type=click.Path(path_type=type(u""))) @click.pass_obj def send(cfg, **kwargs): @@ -202,6 +246,7 @@ def send(cfg, **kwargs): return go(cmd_send.send, cfg) + # this intermediate function can be mocked by tests that need to build a # Config object def go(f, cfg): @@ -214,23 +259,29 @@ def go(f, cfg): @CommonArgs @TorArgs @click.option( - "--only-text", "-t", is_flag=True, + "--only-text", + "-t", + is_flag=True, help="refuse file transfers, only accept text transfers", ) @click.option( - "--accept-file", is_flag=True, + "--accept-file", + is_flag=True, help="accept file transfer without asking for confirmation", ) @click.option( - "--output-file", "-o", + "--output-file", + "-o", metavar="FILENAME|DIRNAME", help=("The file or directory to create, overriding the name suggested" " by the sender."), ) @click.argument( - "code", nargs=-1, default=None, -# help=("The magic-wormhole code, from the sender. If omitted, the" -# " program will ask for it, using tab-completion."), + "code", + nargs=-1, + default=None, + # help=("The magic-wormhole code, from the sender. If omitted, the" + # " program will ask for it, using tab-completion."), ) @click.pass_obj def receive(cfg, code, **kwargs): @@ -244,10 +295,8 @@ def receive(cfg, code, **kwargs): if len(code) == 1: cfg.code = code[0] elif len(code) > 1: - print( - "Pass either no code or just one code; you passed" - " {}: {}".format(len(code), ', '.join(code)) - ) + print("Pass either no code or just one code; you passed" + " {}: {}".format(len(code), ', '.join(code))) raise SystemExit(1) else: cfg.code = None @@ -260,17 +309,19 @@ def ssh(): """ Facilitate sending/receiving SSH public keys """ - pass @ssh.command(name="invite") @click.option( - "-c", "--code-length", default=2, + "-c", + "--code-length", + default=2, metavar="NUMWORDS", help="length of code (in bytes/words)", ) @click.option( - "--user", "-u", + "--user", + "-u", default=None, metavar="USER", help="Add to USER's ~/.ssh/authorized_keys", @@ -291,15 +342,20 @@ def ssh_invite(ctx, code_length, user, **kwargs): @ssh.command(name="accept") @click.argument( - "code", nargs=1, required=True, + "code", + nargs=1, + required=True, ) @click.option( - "--key-file", "-F", + "--key-file", + "-F", default=None, type=click.Path(exists=True), ) @click.option( - "--yes", "-y", is_flag=True, + "--yes", + "-y", + is_flag=True, help="Skip confirmation prompt to send key", ) @TorArgs @@ -318,7 +374,8 @@ def ssh_accept(cfg, code, key_file, yes, **kwargs): kind, keyid, pubkey = cmd_ssh.find_public_key(key_file) print("Sending public key type='{}' keyid='{}'".format(kind, keyid)) if yes is not True: - click.confirm("Really send public key '{}' ?".format(keyid), abort=True) + click.confirm( + "Really send public key '{}' ?".format(keyid), abort=True) cfg.public_key = (kind, keyid, pubkey) cfg.code = code diff --git a/src/wormhole/cli/cmd_receive.py b/src/wormhole/cli/cmd_receive.py index 1cb457b..4b18822 100644 --- a/src/wormhole/cli/cmd_receive.py +++ b/src/wormhole/cli/cmd_receive.py @@ -1,14 +1,23 @@ from __future__ import print_function -import os, sys, six, tempfile, zipfile, hashlib, shutil -from tqdm import tqdm + +import hashlib +import os +import shutil +import sys +import tempfile +import zipfile + +import six from humanize import naturalsize +from tqdm import tqdm from twisted.internet import reactor from twisted.internet.defer import inlineCallbacks, returnValue from twisted.python import log -from wormhole import create, input_with_completion, __version__ -from ..transit import TransitReceiver +from wormhole import __version__, create, input_with_completion + from ..errors import TransferError -from ..util import (dict_to_bytes, bytes_to_dict, bytes_to_hexstr, +from ..transit import TransitReceiver +from ..util import (bytes_to_dict, bytes_to_hexstr, dict_to_bytes, estimate_free_space) from .welcome import handle_welcome @@ -17,14 +26,17 @@ APPID = u"lothar.com/wormhole/text-or-file-xfer" KEY_TIMER = float(os.environ.get("_MAGIC_WORMHOLE_TEST_KEY_TIMER", 1.0)) VERIFY_TIMER = float(os.environ.get("_MAGIC_WORMHOLE_TEST_VERIFY_TIMER", 1.0)) + class RespondError(Exception): def __init__(self, response): self.response = response + class TransferRejectedError(RespondError): def __init__(self): RespondError.__init__(self, "transfer rejected") + def receive(args, reactor=reactor, _debug_stash_wormhole=None): """I implement 'wormhole receive'. I return a Deferred that fires with None (for success), or signals one of the following errors: @@ -60,16 +72,19 @@ class Receiver: # tor in parallel with everything else, make sure the Tor object # can lazy-provide an endpoint, and overlap the startup process # with the user handing off the wormhole code - self._tor = yield get_tor(self._reactor, - self.args.launch_tor, - self.args.tor_control_port, - timing=self.args.timing) + self._tor = yield get_tor( + self._reactor, + self.args.launch_tor, + self.args.tor_control_port, + timing=self.args.timing) - w = create(self.args.appid or APPID, self.args.relay_url, - self._reactor, - tor=self._tor, - timing=self.args.timing) - self._w = w # so tests can wait on events too + w = create( + self.args.appid or APPID, + self.args.relay_url, + self._reactor, + tor=self._tor, + timing=self.args.timing) + self._w = w # so tests can wait on events too # I wanted to do this instead: # @@ -87,7 +102,7 @@ class Receiver: # (which might be an error) @inlineCallbacks def _good(res): - yield w.close() # wait for ack + yield w.close() # wait for ack returnValue(res) # if we raise an error, we should close and then return the original @@ -96,8 +111,8 @@ class Receiver: @inlineCallbacks def _bad(f): try: - yield w.close() # might be an error too - except: + yield w.close() # might be an error too + except Exception: pass returnValue(f) @@ -114,6 +129,7 @@ class Receiver: def on_slow_key(): print(u"Waiting for sender...", file=self.args.stderr) + notify = self._reactor.callLater(KEY_TIMER, on_slow_key) try: # We wait here until we connect to the server and see the senders @@ -129,8 +145,10 @@ class Receiver: notify.cancel() def on_slow_verification(): - print(u"Key established, waiting for confirmation...", - file=self.args.stderr) + print( + u"Key established, waiting for confirmation...", + file=self.args.stderr) + notify = self._reactor.callLater(VERIFY_TIMER, on_slow_verification) try: # We wait here until we've seen their VERSION message (which they @@ -155,7 +173,7 @@ class Receiver: while True: them_d = yield self._get_data(w) - #print("GOT", them_d) + # print("GOT", them_d) recognized = False if u"transit" in them_d: recognized = True @@ -172,7 +190,7 @@ class Receiver: raise TransferError(r.response) returnValue(None) if not recognized: - log.msg("unrecognized message %r" % (them_d,)) + log.msg("unrecognized message %r" % (them_d, )) def _send_data(self, data, w): data_bytes = dict_to_bytes(data) @@ -197,12 +215,12 @@ class Receiver: w.set_code(code) else: prompt = "Enter receive wormhole code: " - used_completion = yield input_with_completion(prompt, - w.input_code(), - self._reactor) + used_completion = yield input_with_completion( + prompt, w.input_code(), self._reactor) if not used_completion: - print(" (note: you can use to complete words)", - file=self.args.stderr) + print( + " (note: you can use to complete words)", + file=self.args.stderr) yield w.get_code() def _show_verifier(self, verifier_bytes): @@ -220,21 +238,24 @@ class Receiver: @inlineCallbacks def _build_transit(self, w, sender_transit): - tr = TransitReceiver(self.args.transit_helper, - no_listen=(not self.args.listen), - tor=self._tor, - reactor=self._reactor, - timing=self.args.timing) + tr = TransitReceiver( + self.args.transit_helper, + no_listen=(not self.args.listen), + tor=self._tor, + reactor=self._reactor, + timing=self.args.timing) self._transit_receiver = tr - transit_key = w.derive_key(APPID+u"/transit-key", tr.TRANSIT_KEY_LENGTH) + transit_key = w.derive_key(APPID + u"/transit-key", + tr.TRANSIT_KEY_LENGTH) tr.set_transit_key(transit_key) tr.add_connection_hints(sender_transit.get("hints-v1", [])) receiver_abilities = tr.get_connection_abilities() receiver_hints = yield tr.get_connection_hints() - receiver_transit = {"abilities-v1": receiver_abilities, - "hints-v1": receiver_hints, - } + receiver_transit = { + "abilities-v1": receiver_abilities, + "hints-v1": receiver_hints, + } self._send_data({u"transit": receiver_transit}, w) # TODO: send more hints as the TransitReceiver produces them @@ -260,7 +281,7 @@ class Receiver: yield self._close_transit(rp, datahash) else: self._msg(u"I don't know what they're offering\n") - self._msg(u"Offer details: %r" % (them_d,)) + self._msg(u"Offer details: %r" % (them_d, )) raise RespondError("unknown offer type") def _handle_text(self, them_d, w): @@ -276,12 +297,13 @@ class Receiver: self.xfersize = file_data["filesize"] free = estimate_free_space(self.abs_destname) if free is not None and free < self.xfersize: - self._msg(u"Error: insufficient free space (%sB) for file (%sB)" - % (free, self.xfersize)) + self._msg(u"Error: insufficient free space (%sB) for file (%sB)" % + (free, self.xfersize)) raise TransferRejectedError() self._msg(u"Receiving file (%s) into: %s" % - (naturalsize(self.xfersize), os.path.basename(self.abs_destname))) + (naturalsize(self.xfersize), + os.path.basename(self.abs_destname))) self._ask_permission() tmp_destname = self.abs_destname + ".tmp" return open(tmp_destname, "wb") @@ -290,19 +312,22 @@ class Receiver: file_data = them_d["directory"] zipmode = file_data["mode"] if zipmode != "zipfile/deflated": - self._msg(u"Error: unknown directory-transfer mode '%s'" % (zipmode,)) + self._msg(u"Error: unknown directory-transfer mode '%s'" % + (zipmode, )) raise RespondError("unknown mode") self.abs_destname = self._decide_destname("directory", file_data["dirname"]) self.xfersize = file_data["zipsize"] free = estimate_free_space(self.abs_destname) if free is not None and free < file_data["numbytes"]: - self._msg(u"Error: insufficient free space (%sB) for directory (%sB)" - % (free, file_data["numbytes"])) + self._msg( + u"Error: insufficient free space (%sB) for directory (%sB)" % + (free, file_data["numbytes"])) raise TransferRejectedError() self._msg(u"Receiving directory (%s) into: %s/" % - (naturalsize(self.xfersize), os.path.basename(self.abs_destname))) + (naturalsize(self.xfersize), + os.path.basename(self.abs_destname))) self._msg(u"%d files, %s (uncompressed)" % (file_data["numfiles"], naturalsize(file_data["numbytes"]))) self._ask_permission() @@ -313,23 +338,26 @@ class Receiver: # "~/.ssh/authorized_keys" and other attacks destname = os.path.basename(destname) if self.args.output_file: - destname = self.args.output_file # override - abs_destname = os.path.abspath( os.path.join(self.args.cwd, destname) ) + destname = self.args.output_file # override + abs_destname = os.path.abspath(os.path.join(self.args.cwd, destname)) # get confirmation from the user before writing to the local directory if os.path.exists(abs_destname): - if self.args.output_file: # overwrite is intentional + if self.args.output_file: # overwrite is intentional self._msg(u"Overwriting '%s'" % destname) if self.args.accept_file: self._remove_existing(abs_destname) else: - self._msg(u"Error: refusing to overwrite existing '%s'" % destname) + self._msg( + u"Error: refusing to overwrite existing '%s'" % destname) raise TransferRejectedError() return abs_destname def _remove_existing(self, path): - if os.path.isfile(path): os.remove(path) - if os.path.isdir(path): shutil.rmtree(path) + if os.path.isfile(path): + os.remove(path) + if os.path.isdir(path): + shutil.rmtree(path) def _ask_permission(self): with self.args.timing.add("permission", waiting="user") as t: @@ -345,7 +373,7 @@ class Receiver: t.detail(answer="yes") def _send_permission(self, w): - self._send_data({"answer": { "file_ack": "ok" }}, w) + self._send_data({"answer": {"file_ack": "ok"}}, w) @inlineCallbacks def _establish_transit(self): @@ -359,14 +387,16 @@ class Receiver: self._msg(u"Receiving (%s).." % record_pipe.describe()) with self.args.timing.add("rx file"): - progress = tqdm(file=self.args.stderr, - disable=self.args.hide_progress, - unit="B", unit_scale=True, total=self.xfersize) + progress = tqdm( + file=self.args.stderr, + disable=self.args.hide_progress, + unit="B", + unit_scale=True, + total=self.xfersize) hasher = hashlib.sha256() with progress: - received = yield record_pipe.writeToFile(f, self.xfersize, - progress.update, - hasher.update) + received = yield record_pipe.writeToFile( + f, self.xfersize, progress.update, hasher.update) datahash = hasher.digest() # except TransitError @@ -382,25 +412,26 @@ class Receiver: tmp_name = f.name f.close() os.rename(tmp_name, self.abs_destname) - self._msg(u"Received file written to %s" % - os.path.basename(self.abs_destname)) + self._msg(u"Received file written to %s" % os.path.basename( + self.abs_destname)) def _extract_file(self, zf, info, extract_dir): """ the zipfile module does not restore file permissions so we'll do it manually """ - out_path = os.path.join( extract_dir, info.filename ) - out_path = os.path.abspath( out_path ) - if not out_path.startswith( extract_dir ): - raise ValueError( "malicious zipfile, %s outside of extract_dir %s" - % (info.filename, extract_dir) ) + out_path = os.path.join(extract_dir, info.filename) + out_path = os.path.abspath(out_path) + if not out_path.startswith(extract_dir): + raise ValueError( + "malicious zipfile, %s outside of extract_dir %s" % + (info.filename, extract_dir)) - zf.extract( info.filename, path=extract_dir ) + zf.extract(info.filename, path=extract_dir) # not sure why zipfiles store the perms 16 bits away but they do perm = info.external_attr >> 16 - os.chmod( out_path, perm ) + os.chmod(out_path, perm) def _write_directory(self, f): @@ -408,10 +439,10 @@ class Receiver: with self.args.timing.add("unpack zip"): with zipfile.ZipFile(f, "r", zipfile.ZIP_DEFLATED) as zf: for info in zf.infolist(): - self._extract_file( zf, info, self.abs_destname ) + self._extract_file(zf, info, self.abs_destname) - self._msg(u"Received files written to %s/" % - os.path.basename(self.abs_destname)) + self._msg(u"Received files written to %s/" % os.path.basename( + self.abs_destname)) f.close() @inlineCallbacks diff --git a/src/wormhole/cli/cmd_send.py b/src/wormhole/cli/cmd_send.py index 969ab0f..ea5d69e 100644 --- a/src/wormhole/cli/cmd_send.py +++ b/src/wormhole/cli/cmd_send.py @@ -1,20 +1,29 @@ from __future__ import print_function -import os, sys, six, tempfile, zipfile, hashlib -from tqdm import tqdm + +import hashlib +import os +import sys +import tempfile +import zipfile + +import six from humanize import naturalsize -from twisted.python import log -from twisted.protocols import basic +from tqdm import tqdm from twisted.internet import reactor from twisted.internet.defer import inlineCallbacks, returnValue +from twisted.protocols import basic +from twisted.python import log +from wormhole import __version__, create + from ..errors import TransferError, UnsendableFileError -from wormhole import create, __version__ from ..transit import TransitSender -from ..util import dict_to_bytes, bytes_to_dict, bytes_to_hexstr +from ..util import bytes_to_dict, bytes_to_hexstr, dict_to_bytes from .welcome import handle_welcome APPID = u"lothar.com/wormhole/text-or-file-xfer" VERIFY_TIMER = float(os.environ.get("_MAGIC_WORMHOLE_TEST_VERIFY_TIMER", 1.0)) + def send(args, reactor=reactor): """I implement 'wormhole send'. I return a Deferred that fires with None (for success), or signals one of the following errors: @@ -26,6 +35,7 @@ def send(args, reactor=reactor): """ return Sender(args, reactor).go() + class Sender: def __init__(self, args, reactor): self._args = args @@ -45,22 +55,25 @@ class Sender: # tor in parallel with everything else, make sure the Tor object # can lazy-provide an endpoint, and overlap the startup process # with the user handing off the wormhole code - self._tor = yield get_tor(reactor, - self._args.launch_tor, - self._args.tor_control_port, - timing=self._timing) + self._tor = yield get_tor( + reactor, + self._args.launch_tor, + self._args.tor_control_port, + timing=self._timing) - w = create(self._args.appid or APPID, self._args.relay_url, - self._reactor, - tor=self._tor, - timing=self._timing) + w = create( + self._args.appid or APPID, + self._args.relay_url, + self._reactor, + tor=self._tor, + timing=self._timing) d = self._go(w) # if we succeed, we should close and return the w.close results # (which might be an error) @inlineCallbacks def _good(res): - yield w.close() # wait for ack + yield w.close() # wait for ack returnValue(res) # if we raise an error, we should close and then return the original @@ -69,8 +82,8 @@ class Sender: @inlineCallbacks def _bad(f): try: - yield w.close() # might be an error too - except: + yield w.close() # might be an error too + except Exception: pass returnValue(f) @@ -125,8 +138,10 @@ class Sender: # TODO: don't stall on w.get_verifier() unless they want it def on_slow_connection(): - print(u"Key established, waiting for confirmation...", - file=args.stderr) + print( + u"Key established, waiting for confirmation...", + file=args.stderr) + notify = self._reactor.callLater(VERIFY_TIMER, on_slow_connection) try: # The usual sender-chooses-code sequence means the receiver's @@ -137,32 +152,35 @@ class Sender: # sitting here for a while, so printing the "waiting" message # seems like a good idea. It might even be appropriate to give up # after a while. - verifier_bytes = yield w.get_verifier() # might WrongPasswordError + verifier_bytes = yield w.get_verifier() # might WrongPasswordError finally: if not notify.called: notify.cancel() if args.verify: - self._check_verifier(w, verifier_bytes) # blocks, can TransferError + self._check_verifier(w, + verifier_bytes) # blocks, can TransferError if self._fd_to_send: - ts = TransitSender(args.transit_helper, - no_listen=(not args.listen), - tor=self._tor, - reactor=self._reactor, - timing=self._timing) + ts = TransitSender( + args.transit_helper, + no_listen=(not args.listen), + tor=self._tor, + reactor=self._reactor, + timing=self._timing) self._transit_sender = ts # for now, send this before the main offer sender_abilities = ts.get_connection_abilities() sender_hints = yield ts.get_connection_hints() - sender_transit = {"abilities-v1": sender_abilities, - "hints-v1": sender_hints, - } + sender_transit = { + "abilities-v1": sender_abilities, + "hints-v1": sender_hints, + } self._send_data({u"transit": sender_transit}, w) # TODO: move this down below w.get_message() - transit_key = w.derive_key(APPID+"/transit-key", + transit_key = w.derive_key(APPID + "/transit-key", ts.TRANSIT_KEY_LENGTH) ts.set_transit_key(transit_key) @@ -175,11 +193,11 @@ class Sender: # TODO: get_message() fired, so get_verifier must have fired, so # now it's safe to use w.derive_key() them_d = bytes_to_dict(them_d_bytes) - #print("GOT", them_d) + # print("GOT", them_d) recognized = False if u"error" in them_d: - raise TransferError("remote error, transfer abandoned: %s" - % them_d["error"]) + raise TransferError( + "remote error, transfer abandoned: %s" % them_d["error"]) if u"transit" in them_d: recognized = True yield self._handle_transit(them_d[u"transit"]) @@ -191,7 +209,7 @@ class Sender: yield self._handle_answer(them_d[u"answer"]) returnValue(None) if not recognized: - log.msg("unrecognized message %r" % (them_d,)) + log.msg("unrecognized message %r" % (them_d, )) def _check_verifier(self, w, verifier_bytes): verifier = bytes_to_hexstr(verifier_bytes) @@ -221,9 +239,10 @@ class Sender: text = six.moves.input("Text to send: ") if text is not None: - print(u"Sending text message (%s)" % naturalsize(len(text)), - file=args.stderr) - offer = { "message": text } + print( + u"Sending text message (%s)" % naturalsize(len(text)), + file=args.stderr) + offer = {"message": text} fd_to_send = None return offer, fd_to_send @@ -244,7 +263,7 @@ class Sender: # is a symlink to something with a different name. The normpath() is # there to remove trailing slashes. basename = os.path.basename(os.path.normpath(what)) - assert basename != "", what # normpath shouldn't allow this + assert basename != "", what # normpath shouldn't allow this # We use realpath() instead of normpath() to locate the actual # file/directory, because the path might contain symlinks, and @@ -273,8 +292,8 @@ class Sender: what = os.path.realpath(what) if not os.path.exists(what): - raise TransferError("Cannot send: no file/directory named '%s'" % - args.what) + raise TransferError( + "Cannot send: no file/directory named '%s'" % args.what) if os.path.isfile(what): # we're sending a file @@ -282,10 +301,11 @@ class Sender: offer["file"] = { "filename": basename, "filesize": filesize, - } - print(u"Sending %s file named '%s'" - % (naturalsize(filesize), basename), - file=args.stderr) + } + print( + u"Sending %s file named '%s'" % (naturalsize(filesize), + basename), + file=args.stderr) fd_to_send = open(what, "rb") return offer, fd_to_send @@ -297,16 +317,18 @@ class Sender: num_files = 0 num_bytes = 0 tostrip = len(what.split(os.sep)) - with zipfile.ZipFile(fd_to_send, "w", - compression=zipfile.ZIP_DEFLATED, - allowZip64=True) as zf: - for path,dirs,files in os.walk(what): + with zipfile.ZipFile( + fd_to_send, + "w", + compression=zipfile.ZIP_DEFLATED, + allowZip64=True) as zf: + for path, dirs, files in os.walk(what): # path always starts with args.what, then sometimes might # have "/subdir" appended. We want the zipfile to contain # "" or "subdir" localpath = list(path.split(os.sep)[tostrip:]) for fn in files: - archivename = os.path.join(*tuple(localpath+[fn])) + archivename = os.path.join(*tuple(localpath + [fn])) localfilename = os.path.join(path, fn) try: zf.write(localfilename, archivename) @@ -315,22 +337,25 @@ class Sender: except OSError as e: errmsg = u"{}: {}".format(fn, e.strerror) if self._args.ignore_unsendable_files: - print(u"{} (ignoring error)".format(errmsg), - file=args.stderr) + print( + u"{} (ignoring error)".format(errmsg), + file=args.stderr) else: raise UnsendableFileError(errmsg) - fd_to_send.seek(0,2) + fd_to_send.seek(0, 2) filesize = fd_to_send.tell() - fd_to_send.seek(0,0) + fd_to_send.seek(0, 0) offer["directory"] = { "mode": "zipfile/deflated", "dirname": basename, "zipsize": filesize, "numbytes": num_bytes, "numfiles": num_files, - } - print(u"Sending directory (%s compressed) named '%s'" - % (naturalsize(filesize), basename), file=args.stderr) + } + print( + u"Sending directory (%s compressed) named '%s'" % + (naturalsize(filesize), basename), + file=args.stderr) return offer, fd_to_send raise TypeError("'%s' is neither file nor directory" % args.what) @@ -340,23 +365,22 @@ class Sender: if self._fd_to_send is None: if them_answer["message_ack"] == "ok": print(u"text message sent", file=self._args.stderr) - returnValue(None) # terminates this function - raise TransferError("error sending text: %r" % (them_answer,)) + returnValue(None) # terminates this function + raise TransferError("error sending text: %r" % (them_answer, )) if them_answer.get("file_ack") != "ok": raise TransferError("ambiguous response from remote, " - "transfer abandoned: %s" % (them_answer,)) + "transfer abandoned: %s" % (them_answer, )) yield self._send_file() - @inlineCallbacks def _send_file(self): ts = self._transit_sender - self._fd_to_send.seek(0,2) + self._fd_to_send.seek(0, 2) filesize = self._fd_to_send.tell() - self._fd_to_send.seek(0,0) + self._fd_to_send.seek(0, 0) record_pipe = yield ts.connect() self._timing.add("transit connected") @@ -365,21 +389,28 @@ class Sender: print(u"Sending (%s).." % record_pipe.describe(), file=stderr) hasher = hashlib.sha256() - progress = tqdm(file=stderr, disable=self._args.hide_progress, - unit="B", unit_scale=True, - total=filesize) + progress = tqdm( + file=stderr, + disable=self._args.hide_progress, + unit="B", + unit_scale=True, + total=filesize) + def _count_and_hash(data): hasher.update(data) progress.update(len(data)) return data + fs = basic.FileSender() with self._timing.add("tx file"): with progress: if filesize: # don't send zero-length files - yield fs.beginFileTransfer(self._fd_to_send, record_pipe, - transform=_count_and_hash) + yield fs.beginFileTransfer( + self._fd_to_send, + record_pipe, + transform=_count_and_hash) expected_hash = hasher.digest() expected_hex = bytes_to_hexstr(expected_hash) diff --git a/src/wormhole/cli/cmd_ssh.py b/src/wormhole/cli/cmd_ssh.py index 8ac60a1..cbd5556 100644 --- a/src/wormhole/cli/cmd_ssh.py +++ b/src/wormhole/cli/cmd_ssh.py @@ -1,16 +1,19 @@ from __future__ import print_function import os -from os.path import expanduser, exists, join -from twisted.internet.defer import inlineCallbacks -from twisted.internet import reactor +from os.path import exists, expanduser, join + import click +from twisted.internet import reactor +from twisted.internet.defer import inlineCallbacks from .. import xfer_util + class PubkeyError(Exception): pass + def find_public_key(hint=None): """ This looks for an appropriate SSH key to send, possibly querying @@ -34,8 +37,9 @@ def find_public_key(hint=None): got_key = False while not got_key: ans = click.prompt( - "Multiple public-keys found:\n" + \ - "\n".join([" {}: {}".format(a, b) for a, b in enumerate(pubkeys)]) + \ + "Multiple public-keys found:\n" + + "\n".join([" {}: {}".format(a, b) + for a, b in enumerate(pubkeys)]) + "\nSend which one?" ) try: @@ -76,7 +80,6 @@ def accept(cfg, reactor=reactor): @inlineCallbacks def invite(cfg, reactor=reactor): - def on_code_created(code): print("Now tell the other user to run:") print() diff --git a/src/wormhole/cli/public_relay.py b/src/wormhole/cli/public_relay.py index 3cb0e62..c816bb2 100644 --- a/src/wormhole/cli/public_relay.py +++ b/src/wormhole/cli/public_relay.py @@ -1,4 +1,3 @@ - # This is a relay I run on a personal server. If it gets too expensive to # run, I'll shut it down. RENDEZVOUS_RELAY = u"ws://relay.magic-wormhole.io:4000/v1" diff --git a/src/wormhole/cli/welcome.py b/src/wormhole/cli/welcome.py index 93ac1cb..2bbfc7e 100644 --- a/src/wormhole/cli/welcome.py +++ b/src/wormhole/cli/welcome.py @@ -1,18 +1,23 @@ -from __future__ import print_function, absolute_import, unicode_literals +from __future__ import absolute_import, print_function, unicode_literals + def handle_welcome(welcome, relay_url, my_version, stderr): if "motd" in welcome: motd_lines = welcome["motd"].splitlines() motd_formatted = "\n ".join(motd_lines) - print("Server (at %s) says:\n %s" % (relay_url, motd_formatted), - file=stderr) + print( + "Server (at %s) says:\n %s" % (relay_url, motd_formatted), + file=stderr) # Only warn if we're running a release version (e.g. 0.0.6, not # 0.0.6+DISTANCE.gHASH). Only warn once. - if ("current_cli_version" in welcome - and "+" not in my_version - and welcome["current_cli_version"] != my_version): - print("Warning: errors may occur unless both sides are running the same version", file=stderr) - print("Server claims %s is current, but ours is %s" - % (welcome["current_cli_version"], my_version), - file=stderr) + if ("current_cli_version" in welcome and "+" not in my_version + and welcome["current_cli_version"] != my_version): + print( + ("Warning: errors may occur unless both sides are running the" + " same version"), + file=stderr) + print( + "Server claims %s is current, but ours is %s" % + (welcome["current_cli_version"], my_version), + file=stderr) diff --git a/src/wormhole/errors.py b/src/wormhole/errors.py index 8bd9718..64edf37 100644 --- a/src/wormhole/errors.py +++ b/src/wormhole/errors.py @@ -1,8 +1,10 @@ from __future__ import unicode_literals + class WormholeError(Exception): """Parent class for all wormhole-related errors""" + class UnsendableFileError(Exception): """ A file you wanted to send couldn't be read, maybe because it's not @@ -13,30 +15,38 @@ class UnsendableFileError(Exception): --ignore-unsendable-files flag. """ + class ServerError(WormholeError): """The relay server complained about something we did.""" + class ServerConnectionError(WormholeError): """We had a problem connecting to the relay server:""" + def __init__(self, url, reason): self.url = url self.reason = reason + def __str__(self): return str(self.reason) + class Timeout(WormholeError): pass + class WelcomeError(WormholeError): """ The relay server told us to signal an error, probably because our version is too old to possibly work. The server said:""" pass + class LonelyError(WormholeError): """wormhole.close() was called before the peer connection could be established""" + class WrongPasswordError(WormholeError): """ Key confirmation failed. Either you or your correspondent typed the code @@ -47,6 +57,7 @@ class WrongPasswordError(WormholeError): # or the data blob was corrupted, and that's why decrypt failed pass + class KeyFormatError(WormholeError): """ The key you entered contains spaces or was missing a dash. Magic-wormhole @@ -55,43 +66,61 @@ class KeyFormatError(WormholeError): dashes. """ + class ReflectionAttack(WormholeError): """An attacker (or bug) reflected our outgoing message back to us.""" + class InternalError(WormholeError): """The programmer did something wrong.""" + class TransferError(WormholeError): """Something bad happened and the transfer failed.""" + class NoTorError(WormholeError): """--tor was requested, but 'txtorcon' is not installed.""" + class NoKeyError(WormholeError): """w.derive_key() was called before got_verifier() fired""" + class OnlyOneCodeError(WormholeError): """Only one w.generate_code/w.set_code/w.input_code may be called""" + class MustChooseNameplateFirstError(WormholeError): """The InputHelper was asked to do get_word_completions() or choose_words() before the nameplate was chosen.""" + + class AlreadyChoseNameplateError(WormholeError): """The InputHelper was asked to do get_nameplate_completions() after choose_nameplate() was called, or choose_nameplate() was called a second time.""" + + class AlreadyChoseWordsError(WormholeError): """The InputHelper was asked to do get_word_completions() after choose_words() was called, or choose_words() was called a second time.""" + + class AlreadyInputNameplateError(WormholeError): """The CodeInputter was asked to do completion on a nameplate, when we had already committed to a different one.""" + + class WormholeClosed(Exception): """Deferred-returning API calls errback with WormholeClosed if the wormhole was already closed, or if it closes before a real result can be obtained.""" + class _UnknownPhaseError(Exception): """internal exception type, for tests.""" + + class _UnknownMessageTypeError(Exception): """internal exception type, for tests.""" diff --git a/src/wormhole/eventual.py b/src/wormhole/eventual.py index 4fc8731..1d8e061 100644 --- a/src/wormhole/eventual.py +++ b/src/wormhole/eventual.py @@ -5,6 +5,7 @@ from twisted.internet.defer import Deferred from twisted.internet.interfaces import IReactorTime from twisted.python import log + class EventualQueue(object): def __init__(self, clock): # pass clock=reactor unless you're testing @@ -14,7 +15,7 @@ class EventualQueue(object): self._timer = None def eventually(self, f, *args, **kwargs): - self._calls.append( (f, args, kwargs) ) + self._calls.append((f, args, kwargs)) if not self._timer: self._timer = self._clock.callLater(0, self._turn) @@ -28,7 +29,7 @@ class EventualQueue(object): (f, args, kwargs) = self._calls.pop(0) try: f(*args, **kwargs) - except: + except Exception: log.err() self._timer = None d, self._flush_d = self._flush_d, None diff --git a/src/wormhole/ipaddrs.py b/src/wormhole/ipaddrs.py index 2f35d55..f6c8bb3 100644 --- a/src/wormhole/ipaddrs.py +++ b/src/wormhole/ipaddrs.py @@ -1,27 +1,37 @@ # no unicode_literals # Find all of our ip addresses. From tahoe's src/allmydata/util/iputil.py -import os, re, subprocess, errno +import errno +import os +import re +import subprocess from sys import platform + from twisted.python.procutils import which # Wow, I'm really amazed at home much mileage we've gotten out of calling # the external route.exe program on windows... It appears to work on all # versions so far. Still, the real system calls would much be preferred... # ... thus wrote Greg Smith in time immemorial... -_win32_re = re.compile(r'^\s*\d+\.\d+\.\d+\.\d+\s.+\s(?P
\d+\.\d+\.\d+\.\d+)\s+(?P\d+)\s*$', flags=re.M|re.I|re.S) -_win32_commands = (('route.exe', ('print',), _win32_re),) +_win32_re = re.compile( + (r'^\s*\d+\.\d+\.\d+\.\d+\s.+\s' + r'(?P
\d+\.\d+\.\d+\.\d+)\s+(?P\d+)\s*$'), + flags=re.M | re.I | re.S) +_win32_commands = (('route.exe', ('print', ), _win32_re), ) # These work in most Unices. -_addr_re = re.compile(r'^\s*inet [a-zA-Z]*:?(?P
\d+\.\d+\.\d+\.\d+)[\s/].+$', flags=re.M|re.I|re.S) -_unix_commands = (('/bin/ip', ('addr',), _addr_re), - ('/sbin/ip', ('addr',), _addr_re), - ('/sbin/ifconfig', ('-a',), _addr_re), - ('/usr/sbin/ifconfig', ('-a',), _addr_re), - ('/usr/etc/ifconfig', ('-a',), _addr_re), - ('ifconfig', ('-a',), _addr_re), - ('/sbin/ifconfig', (), _addr_re), - ) +_addr_re = re.compile( + r'^\s*inet [a-zA-Z]*:?(?P
\d+\.\d+\.\d+\.\d+)[\s/].+$', + flags=re.M | re.I | re.S) +_unix_commands = ( + ('/bin/ip', ('addr', ), _addr_re), + ('/sbin/ip', ('addr', ), _addr_re), + ('/sbin/ifconfig', ('-a', ), _addr_re), + ('/usr/sbin/ifconfig', ('-a', ), _addr_re), + ('/usr/etc/ifconfig', ('-a', ), _addr_re), + ('ifconfig', ('-a', ), _addr_re), + ('/sbin/ifconfig', (), _addr_re), +) def find_addresses(): @@ -54,17 +64,19 @@ def find_addresses(): return ["127.0.0.1"] + def _query(path, args, regex): env = {'LANG': 'en_US.UTF-8'} trial = 0 while True: trial += 1 try: - p = subprocess.Popen([path] + list(args), - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - env=env, - universal_newlines=True) + p = subprocess.Popen( + [path] + list(args), + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + env=env, + universal_newlines=True) (output, err) = p.communicate() break except OSError as e: diff --git a/src/wormhole/journal.py b/src/wormhole/journal.py index f7bf0f3..b7cffb5 100644 --- a/src/wormhole/journal.py +++ b/src/wormhole/journal.py @@ -1,8 +1,12 @@ -from __future__ import print_function, absolute_import, unicode_literals -from zope.interface import implementer +from __future__ import absolute_import, print_function, unicode_literals + import contextlib + +from zope.interface import implementer + from ._interfaces import IJournal + @implementer(IJournal) class Journal(object): def __init__(self, save_checkpoint): @@ -19,7 +23,7 @@ class Journal(object): assert not self._processing assert not self._outbound_queue self._processing = True - yield # process inbound messages, change state, queue outbound + yield # process inbound messages, change state, queue outbound self._save_checkpoint() for (fn, args, kwargs) in self._outbound_queue: fn(*args, **kwargs) @@ -31,8 +35,10 @@ class Journal(object): class ImmediateJournal(object): def __init__(self): pass + def queue_outbound(self, fn, *args, **kwargs): fn(*args, **kwargs) + @contextlib.contextmanager def process(self): yield diff --git a/src/wormhole/observer.py b/src/wormhole/observer.py index 99a22a0..43afa70 100644 --- a/src/wormhole/observer.py +++ b/src/wormhole/observer.py @@ -1,14 +1,16 @@ -from __future__ import unicode_literals, print_function +from __future__ import print_function, unicode_literals + from twisted.internet.defer import Deferred from twisted.python.failure import Failure NoResult = object() + class OneShotObserver(object): def __init__(self, eventual_queue): self._eq = eventual_queue self._result = NoResult - self._observers = [] # list of Deferreds + self._observers = [] # list of Deferreds def when_fired(self): d = Deferred() @@ -38,6 +40,7 @@ class OneShotObserver(object): if self._result is NoResult: self.fire(result) + class SequenceObserver(object): def __init__(self, eventual_queue): self._eq = eventual_queue diff --git a/src/wormhole/test/common.py b/src/wormhole/test/common.py index 0b3897e..3521299 100644 --- a/src/wormhole/test/common.py +++ b/src/wormhole/test/common.py @@ -1,16 +1,19 @@ # no unicode_literals untill twisted update -from twisted.application import service, internet -from twisted.internet import defer, task, reactor, endpoints -from twisted.python import log from click.testing import CliRunner +from twisted.application import internet, service +from twisted.internet import defer, endpoints, reactor, task +from twisted.python import log + import mock -from ..cli import cli -from ..transit import allocate_tcp_port +from wormhole_mailbox_server.database import create_channel_db, create_usage_db from wormhole_mailbox_server.server import make_server from wormhole_mailbox_server.web import make_web_server -from wormhole_mailbox_server.database import create_channel_db, create_usage_db from wormhole_transit_relay.transit_server import Transit +from ..cli import cli +from ..transit import allocate_tcp_port + + class MyInternetService(service.Service, object): # like StreamServerEndpointService, but you can retrieve the port def __init__(self, endpoint, factory): @@ -22,12 +25,15 @@ class MyInternetService(service.Service, object): def startService(self): super(MyInternetService, self).startService() d = self.endpoint.listen(self.factory) + def good(lp): self._lp = lp self._port_d.callback(lp.getHost().port) + def bad(f): log.err(f) self._port_d.errback(f) + d.addCallbacks(good, bad) @defer.inlineCallbacks @@ -35,9 +41,10 @@ class MyInternetService(service.Service, object): if self._lp: yield self._lp.stopListening() - def getPort(self): # only call once! + def getPort(self): # only call once! return self._port_d + class ServerBase: @defer.inlineCallbacks def setUp(self): @@ -51,27 +58,27 @@ class ServerBase: # endpoints.serverFromString db = create_channel_db(":memory:") self._usage_db = create_usage_db(":memory:") - self._rendezvous = make_server(db, - advertise_version=advertise_version, - signal_error=error, - usage_db=self._usage_db) + self._rendezvous = make_server( + db, + advertise_version=advertise_version, + signal_error=error, + usage_db=self._usage_db) ep = endpoints.TCP4ServerEndpoint(reactor, 0, interface="127.0.0.1") site = make_web_server(self._rendezvous, log_requests=False) - #self._lp = yield ep.listen(site) + # self._lp = yield ep.listen(site) s = MyInternetService(ep, site) s.setServiceParent(self.sp) self.rdv_ws_port = yield s.getPort() self._relay_server = s - #self._rendezvous = s._rendezvous + # self._rendezvous = s._rendezvous self.relayurl = u"ws://127.0.0.1:%d/v1" % self.rdv_ws_port # ws://127.0.0.1:%d/wormhole-relay/ws self.transitport = allocate_tcp_port() - ep = endpoints.serverFromString(reactor, - "tcp:%d:interface=127.0.0.1" % - self.transitport) - self._transit_server = f = Transit(blur_usage=None, log_file=None, - usage_db=None) + ep = endpoints.serverFromString( + reactor, "tcp:%d:interface=127.0.0.1" % self.transitport) + self._transit_server = f = Transit( + blur_usage=None, log_file=None, usage_db=None) internet.StreamServerEndpointService(ep, f).setServiceParent(self.sp) self.transit = u"tcp:127.0.0.1:%d" % self.transitport @@ -109,6 +116,7 @@ class ServerBase: " I convinced all threads to exit.") yield d + def config(*argv): r = CliRunner() with mock.patch("wormhole.cli.cli.go") as go: @@ -121,6 +129,7 @@ def config(*argv): cfg = go.call_args[0][1] return cfg + @defer.inlineCallbacks def poll_until(predicate): # return a Deferred that won't fire until the predicate is True diff --git a/src/wormhole/test/run_trial.py b/src/wormhole/test/run_trial.py index 1c9ffee..1e218cd 100644 --- a/src/wormhole/test/run_trial.py +++ b/src/wormhole/test/run_trial.py @@ -1,4 +1,5 @@ from __future__ import unicode_literals + # This is a tiny helper module, to let "python -m wormhole.test.run_trial # ARGS" does the same thing as running "trial ARGS" (unfortunately # twisted/scripts/trial.py does not have a '__name__=="__main__"' clause). diff --git a/src/wormhole/test/test_args.py b/src/wormhole/test/test_args.py index ff6f060..87f9189 100644 --- a/src/wormhole/test/test_args.py +++ b/src/wormhole/test/test_args.py @@ -1,15 +1,17 @@ import os import sys -import mock + from twisted.trial import unittest + +import mock + from ..cli.public_relay import RENDEZVOUS_RELAY, TRANSIT_RELAY from .common import config -#from pprint import pprint + class Send(unittest.TestCase): def test_baseline(self): cfg = config("send", "--text", "hi") - #pprint(cfg.__dict__) self.assertEqual(cfg.what, None) self.assertEqual(cfg.code, None) self.assertEqual(cfg.code_length, 2) @@ -32,7 +34,6 @@ class Send(unittest.TestCase): def test_file(self): cfg = config("send", "fn") - #pprint(cfg.__dict__) self.assertEqual(cfg.what, u"fn") self.assertEqual(cfg.text, None) @@ -101,7 +102,6 @@ class Send(unittest.TestCase): class Receive(unittest.TestCase): def test_baseline(self): cfg = config("receive") - #pprint(cfg.__dict__) self.assertEqual(cfg.accept_file, False) self.assertEqual(cfg.code, None) self.assertEqual(cfg.code_length, 2) @@ -191,10 +191,12 @@ class Receive(unittest.TestCase): cfg = config("--transit-helper", transit_url_2, "receive") self.assertEqual(cfg.transit_helper, transit_url_2) + class Config(unittest.TestCase): def test_send(self): cfg = config("send") self.assertEqual(cfg.stdout, sys.stdout) + def test_receive(self): cfg = config("receive") self.assertEqual(cfg.stdout, sys.stdout) diff --git a/src/wormhole/test/test_cli.py b/src/wormhole/test/test_cli.py index 5cc8b60..a5d5478 100644 --- a/src/wormhole/test/test_cli.py +++ b/src/wormhole/test/test_cli.py @@ -1,22 +1,32 @@ from __future__ import print_function -import os, sys, re, io, zipfile, six, stat -from textwrap import fill, dedent -from humanize import naturalsize -import mock + +import io +import os +import re +import stat +import sys +import zipfile +from textwrap import dedent, fill + +import six from click.testing import CliRunner -from zope.interface import implementer -from twisted.trial import unittest -from twisted.python import procutils, log +from humanize import naturalsize from twisted.internet import endpoints, reactor -from twisted.internet.utils import getProcessOutputAndValue from twisted.internet.defer import gatherResults, inlineCallbacks, returnValue from twisted.internet.error import ConnectionRefusedError +from twisted.internet.utils import getProcessOutputAndValue +from twisted.python import log, procutils +from twisted.trial import unittest +from zope.interface import implementer + +import mock + from .. import __version__ -from .common import ServerBase, config -from ..cli import cmd_send, cmd_receive, welcome, cli -from ..errors import (TransferError, WrongPasswordError, WelcomeError, - UnsendableFileError, ServerConnectionError) from .._interfaces import ITorManager +from ..cli import cli, cmd_receive, cmd_send, welcome +from ..errors import (ServerConnectionError, TransferError, + UnsendableFileError, WelcomeError, WrongPasswordError) +from .common import ServerBase, config def build_offer(args): @@ -108,8 +118,8 @@ class OfferData(unittest.TestCase): self.cfg.cwd = send_dir e = self.assertRaises(TransferError, build_offer, self.cfg) - self.assertEqual(str(e), - "Cannot send: no file/directory named '%s'" % filename) + self.assertEqual( + str(e), "Cannot send: no file/directory named '%s'" % filename) def _do_test_directory(self, addslash): parent_dir = self.mktemp() @@ -178,8 +188,8 @@ class OfferData(unittest.TestCase): self.assertFalse(os.path.isdir(abs_filename)) e = self.assertRaises(TypeError, build_offer, self.cfg) - self.assertEqual(str(e), - "'%s' is neither file nor directory" % filename) + self.assertEqual( + str(e), "'%s' is neither file nor directory" % filename) def test_symlink(self): if not hasattr(os, 'symlink'): @@ -213,8 +223,9 @@ class OfferData(unittest.TestCase): os.mkdir(os.path.join(parent_dir, "B2", "C2")) with open(os.path.join(parent_dir, "B2", "D.txt"), "wb") as f: f.write(b"success") - os.symlink(os.path.abspath(os.path.join(parent_dir, "B2", "C2")), - os.path.join(parent_dir, "B1", "C1")) + os.symlink( + os.path.abspath(os.path.join(parent_dir, "B2", "C2")), + os.path.join(parent_dir, "B1", "C1")) # Now send "B1/C1/../D.txt" from A. The correct traversal will be: # * start: A # * B1: A/B1 @@ -231,6 +242,7 @@ class OfferData(unittest.TestCase): d, fd_to_send = build_offer(self.cfg) self.assertEqual(d["file"]["filename"], "D.txt") self.assertEqual(fd_to_send.read(), b"success") + if os.name == "nt": test_symlink_collapse.todo = "host OS has broken os.path.realpath()" # ntpath.py's realpath() is built out of normpath(), and does not @@ -241,6 +253,7 @@ class OfferData(unittest.TestCase): # misbehavior (albeit in rare circumstances), 2: it probably used to # work (sometimes, but not in #251). See cmd_send.py for more notes. + class LocaleFinder: def __init__(self): self._run_once = False @@ -266,10 +279,10 @@ class LocaleFinder: # twisted.python.usage to avoid this problem in the future. (out, err, rc) = yield getProcessOutputAndValue("locale", ["-a"]) if rc != 0: - log.msg("error running 'locale -a', rc=%s" % (rc,)) - log.msg("stderr: %s" % (err,)) + log.msg("error running 'locale -a', rc=%s" % (rc, )) + log.msg("stderr: %s" % (err, )) returnValue(None) - out = out.decode("utf-8") # make sure we get a string + out = out.decode("utf-8") # make sure we get a string utf8_locales = {} for locale in out.splitlines(): locale = locale.strip() @@ -281,8 +294,11 @@ class LocaleFinder: if utf8_locales: returnValue(list(utf8_locales.values())[0]) returnValue(None) + + locale_finder = LocaleFinder() + class ScriptsBase: def find_executable(self): # to make sure we're running the right executable (in a virtualenv), @@ -292,12 +308,13 @@ class ScriptsBase: if not locations: raise unittest.SkipTest("unable to find 'wormhole' in $PATH") wormhole = locations[0] - if (os.path.dirname(os.path.abspath(wormhole)) != - os.path.dirname(sys.executable)): - log.msg("locations: %s" % (locations,)) - log.msg("sys.executable: %s" % (sys.executable,)) - raise unittest.SkipTest("found the wrong 'wormhole' in $PATH: %s %s" - % (wormhole, sys.executable)) + if (os.path.dirname(os.path.abspath(wormhole)) != os.path.dirname( + sys.executable)): + log.msg("locations: %s" % (locations, )) + log.msg("sys.executable: %s" % (sys.executable, )) + raise unittest.SkipTest( + "found the wrong 'wormhole' in $PATH: %s %s" % + (wormhole, sys.executable)) return wormhole @inlineCallbacks @@ -323,8 +340,8 @@ class ScriptsBase: raise unittest.SkipTest("unable to find UTF-8 locale") locale_env = dict(LC_ALL=locale, LANG=locale) wormhole = self.find_executable() - res = yield getProcessOutputAndValue(wormhole, ["--version"], - env=locale_env) + res = yield getProcessOutputAndValue( + wormhole, ["--version"], env=locale_env) out, err, rc = res if rc != 0: log.msg("wormhole not runnable in this tree:") @@ -334,6 +351,7 @@ class ScriptsBase: raise unittest.SkipTest("wormhole is not runnable in this tree") returnValue(locale_env) + class ScriptVersion(ServerBase, ScriptsBase, unittest.TestCase): # we need Twisted to run the server, but we run the sender and receiver # with deferToThread() @@ -347,26 +365,30 @@ class ScriptVersion(ServerBase, ScriptsBase, unittest.TestCase): wormhole = self.find_executable() # we must pass on the environment so that "something" doesn't # get sad about UTF8 vs. ascii encodings - out, err, rc = yield getProcessOutputAndValue(wormhole, ["--version"], - env=os.environ) + out, err, rc = yield getProcessOutputAndValue( + wormhole, ["--version"], env=os.environ) err = err.decode("utf-8") if "DistributionNotFound" in err: log.msg("stderr was %s" % err) last = err.strip().split("\n")[-1] self.fail("wormhole not runnable: %s" % last) ver = out.decode("utf-8") or err - self.failUnlessEqual(ver.strip(), "magic-wormhole {}".format(__version__)) + self.failUnlessEqual(ver.strip(), + "magic-wormhole {}".format(__version__)) self.failUnlessEqual(rc, 0) + @implementer(ITorManager) class FakeTor: # use normal endpoints, but record the fact that we were asked def __init__(self): self.endpoints = [] + def stream_via(self, host, port): self.endpoints.append((host, port)) return endpoints.HostnameEndpoint(reactor, host, port) + class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase): # we need Twisted to run the server, but we run the sender and receiver # with deferToThread() @@ -377,11 +399,16 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase): yield ServerBase.setUp(self) @inlineCallbacks - def _do_test(self, as_subprocess=False, - mode="text", addslash=False, override_filename=False, - fake_tor=False, overwrite=False, mock_accept=False): - assert mode in ("text", "file", "empty-file", "directory", - "slow-text", "slow-sender-text") + def _do_test(self, + as_subprocess=False, + mode="text", + addslash=False, + override_filename=False, + fake_tor=False, + overwrite=False, + mock_accept=False): + assert mode in ("text", "file", "empty-file", "directory", "slow-text", + "slow-sender-text") if fake_tor: assert not as_subprocess send_cfg = config("send") @@ -408,7 +435,7 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase): elif mode in ("file", "empty-file"): if mode == "empty-file": message = "" - send_filename = u"testfil\u00EB" # e-with-diaeresis + send_filename = u"testfil\u00EB" # e-with-diaeresis with open(os.path.join(send_dir, send_filename), "w") as f: f.write(message) send_cfg.what = send_filename @@ -433,8 +460,10 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase): # expect: $receive_dir/$dirname/[12345] send_dirname = u"testdir" + def message(i): return "test message %d\n" % i + os.mkdir(os.path.join(send_dir, u"middle")) source_dir = os.path.join(send_dir, u"middle", send_dirname) os.mkdir(source_dir) @@ -476,21 +505,27 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase): env["_MAGIC_WORMHOLE_TEST_KEY_TIMER"] = "999999" env["_MAGIC_WORMHOLE_TEST_VERIFY_TIMER"] = "999999" send_args = [ - '--relay-url', self.relayurl, - '--transit-helper', '', - 'send', - '--hide-progress', - '--code', send_cfg.code, - ] + content_args + '--relay-url', + self.relayurl, + '--transit-helper', + '', + 'send', + '--hide-progress', + '--code', + send_cfg.code, + ] + content_args send_d = getProcessOutputAndValue( - wormhole_bin, send_args, + wormhole_bin, + send_args, path=send_dir, env=env, ) recv_args = [ - '--relay-url', self.relayurl, - '--transit-helper', '', + '--relay-url', + self.relayurl, + '--transit-helper', + '', 'receive', '--hide-progress', '--accept-file', @@ -500,7 +535,8 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase): recv_args.extend(['-o', receive_filename]) receive_d = getProcessOutputAndValue( - wormhole_bin, recv_args, + wormhole_bin, + recv_args, path=receive_dir, env=env, ) @@ -524,25 +560,27 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase): send_cfg.tor = True send_cfg.transit_helper = self.transit tx_tm = FakeTor() - with mock.patch("wormhole.tor_manager.get_tor", - return_value=tx_tm, - ) as mtx_tm: + with mock.patch( + "wormhole.tor_manager.get_tor", + return_value=tx_tm, + ) as mtx_tm: send_d = cmd_send.send(send_cfg) recv_cfg.tor = True recv_cfg.transit_helper = self.transit rx_tm = FakeTor() - with mock.patch("wormhole.tor_manager.get_tor", - return_value=rx_tm, - ) as mrx_tm: + with mock.patch( + "wormhole.tor_manager.get_tor", + return_value=rx_tm, + ) as mrx_tm: receive_d = cmd_receive.receive(recv_cfg) else: KEY_TIMER = 0 if mode == "slow-sender-text" else 99999 rxw = [] with mock.patch.object(cmd_receive, "KEY_TIMER", KEY_TIMER): send_d = cmd_send.send(send_cfg) - receive_d = cmd_receive.receive(recv_cfg, - _debug_stash_wormhole=rxw) + receive_d = cmd_receive.receive( + recv_cfg, _debug_stash_wormhole=rxw) # we need to keep KEY_TIMER patched until the receiver # gets far enough to start the timer, which happens after # the code is set @@ -555,8 +593,9 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase): with mock.patch.object(cmd_receive, "VERIFY_TIMER", VERIFY_TIMER): with mock.patch.object(cmd_send, "VERIFY_TIMER", VERIFY_TIMER): if mock_accept: - with mock.patch.object(cmd_receive.six.moves, - 'input', return_value='y'): + with mock.patch.object( + cmd_receive.six.moves, 'input', + return_value='y'): yield gatherResults([send_d, receive_d], True) else: yield gatherResults([send_d, receive_d], True) @@ -567,14 +606,14 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase): expected_endpoints.append(("127.0.0.1", self.transitport)) tx_timing = mtx_tm.call_args[1]["timing"] self.assertEqual(tx_tm.endpoints, expected_endpoints) - self.assertEqual(mtx_tm.mock_calls, - [mock.call(reactor, False, None, - timing=tx_timing)]) + self.assertEqual( + mtx_tm.mock_calls, + [mock.call(reactor, False, None, timing=tx_timing)]) rx_timing = mrx_tm.call_args[1]["timing"] self.assertEqual(rx_tm.endpoints, expected_endpoints) - self.assertEqual(mrx_tm.mock_calls, - [mock.call(reactor, False, None, - timing=rx_timing)]) + self.assertEqual( + mrx_tm.mock_calls, + [mock.call(reactor, False, None, timing=rx_timing)]) send_stdout = send_cfg.stdout.getvalue() send_stderr = send_cfg.stderr.getvalue() @@ -585,7 +624,7 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase): # newlines, even if we're on windows NL = "\n" - self.maxDiff = None # show full output for assertion failures + self.maxDiff = None # show full output for assertion failures key_established = "" if mode == "slow-text": @@ -600,38 +639,41 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase): "On the other computer, please run:{NL}{NL}" "wormhole receive {code}{NL}{NL}" "{KE}" - "text message sent{NL}").format(bytes=len(message), - code=send_cfg.code, - NL=NL, - KE=key_established) + "text message sent{NL}").format( + bytes=len(message), + code=send_cfg.code, + NL=NL, + KE=key_established) self.failUnlessEqual(send_stderr, expected) elif mode == "file": self.failUnlessIn(u"Sending {size:s} file named '{name}'{NL}" - .format(size=naturalsize(len(message)), - name=send_filename, - NL=NL), send_stderr) + .format( + size=naturalsize(len(message)), + name=send_filename, + NL=NL), send_stderr) self.failUnlessIn(u"Wormhole code is: {code}{NL}" "On the other computer, please run:{NL}{NL}" - "wormhole receive {code}{NL}{NL}" - .format(code=send_cfg.code, NL=NL), - send_stderr) - self.failUnlessIn(u"File sent.. waiting for confirmation{NL}" - "Confirmation received. Transfer complete.{NL}" - .format(NL=NL), send_stderr) + "wormhole receive {code}{NL}{NL}".format( + code=send_cfg.code, NL=NL), send_stderr) + self.failUnlessIn( + u"File sent.. waiting for confirmation{NL}" + "Confirmation received. Transfer complete.{NL}".format(NL=NL), + send_stderr) elif mode == "directory": self.failUnlessIn(u"Sending directory", send_stderr) self.failUnlessIn(u"named 'testdir'", send_stderr) self.failUnlessIn(u"Wormhole code is: {code}{NL}" "On the other computer, please run:{NL}{NL}" - "wormhole receive {code}{NL}{NL}" - .format(code=send_cfg.code, NL=NL), send_stderr) - self.failUnlessIn(u"File sent.. waiting for confirmation{NL}" - "Confirmation received. Transfer complete.{NL}" - .format(NL=NL), send_stderr) + "wormhole receive {code}{NL}{NL}".format( + code=send_cfg.code, NL=NL), send_stderr) + self.failUnlessIn( + u"File sent.. waiting for confirmation{NL}" + "Confirmation received. Transfer complete.{NL}".format(NL=NL), + send_stderr) # check receiver if mode in ("text", "slow-text", "slow-sender-text"): - self.assertEqual(receive_stdout, message+NL) + self.assertEqual(receive_stdout, message + NL) if mode == "text": self.assertEqual(receive_stderr, "") elif mode == "slow-text": @@ -640,9 +682,9 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase): self.assertEqual(receive_stderr, "Waiting for sender...\n") elif mode == "file": self.failUnlessEqual(receive_stdout, "") - self.failUnlessIn(u"Receiving file ({size:s}) into: {name}" - .format(size=naturalsize(len(message)), - name=receive_filename), receive_stderr) + self.failUnlessIn(u"Receiving file ({size:s}) into: {name}".format( + size=naturalsize(len(message)), name=receive_filename), + receive_stderr) self.failUnlessIn(u"Received file written to ", receive_stderr) fn = os.path.join(receive_dir, receive_filename) self.failUnless(os.path.exists(fn)) @@ -652,52 +694,67 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase): self.failUnlessEqual(receive_stdout, "") want = (r"Receiving directory \(\d+ \w+\) into: {name}/" .format(name=receive_dirname)) - self.failUnless(re.search(want, receive_stderr), - (want, receive_stderr)) - self.failUnlessIn(u"Received files written to {name}" - .format(name=receive_dirname), receive_stderr) + self.failUnless( + re.search(want, receive_stderr), (want, receive_stderr)) + self.failUnlessIn( + u"Received files written to {name}" + .format(name=receive_dirname), + receive_stderr) fn = os.path.join(receive_dir, receive_dirname) self.failUnless(os.path.exists(fn), fn) for i in range(5): fn = os.path.join(receive_dir, receive_dirname, str(i)) with open(fn, "r") as f: self.failUnlessEqual(f.read(), message(i)) - self.failUnlessEqual(modes[i], - stat.S_IMODE(os.stat(fn).st_mode)) + self.failUnlessEqual(modes[i], stat.S_IMODE( + os.stat(fn).st_mode)) def test_text(self): return self._do_test() + def test_text_subprocess(self): return self._do_test(as_subprocess=True) + def test_text_tor(self): return self._do_test(fake_tor=True) def test_file(self): return self._do_test(mode="file") + def test_file_override(self): return self._do_test(mode="file", override_filename=True) + def test_file_overwrite(self): return self._do_test(mode="file", overwrite=True) + def test_file_overwrite_mock_accept(self): return self._do_test(mode="file", overwrite=True, mock_accept=True) + def test_file_tor(self): return self._do_test(mode="file", fake_tor=True) + def test_empty_file(self): return self._do_test(mode="empty-file") def test_directory(self): return self._do_test(mode="directory") + def test_directory_addslash(self): return self._do_test(mode="directory", addslash=True) + def test_directory_override(self): return self._do_test(mode="directory", override_filename=True) + def test_directory_overwrite(self): return self._do_test(mode="directory", overwrite=True) + def test_directory_overwrite_mock_accept(self): - return self._do_test(mode="directory", overwrite=True, mock_accept=True) + return self._do_test( + mode="directory", overwrite=True, mock_accept=True) def test_slow_text(self): return self._do_test(mode="slow-text") + def test_slow_sender_text(self): return self._do_test(mode="slow-sender-text") @@ -721,7 +778,7 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase): os.mkdir(send_dir) receive_dir = self.mktemp() os.mkdir(receive_dir) - recv_cfg.accept_file = True # don't ask for permission + recv_cfg.accept_file = True # don't ask for permission if mode == "file": message = "test message\n" @@ -765,10 +822,12 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase): free_space = 10000000 else: free_space = 0 - with mock.patch("wormhole.cli.cmd_receive.estimate_free_space", - return_value=free_space): + with mock.patch( + "wormhole.cli.cmd_receive.estimate_free_space", + return_value=free_space): f = yield self.assertFailure(send_d, TransferError) - self.assertEqual(str(f), "remote error, transfer abandoned: transfer rejected") + self.assertEqual( + str(f), "remote error, transfer abandoned: transfer rejected") f = yield self.assertFailure(receive_d, TransferError) self.assertEqual(str(f), "transfer rejected") @@ -781,7 +840,7 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase): # newlines, even if we're on windows NL = "\n" - self.maxDiff = None # show full output for assertion failures + self.maxDiff = None # show full output for assertion failures self.assertEqual(send_stdout, "") self.assertEqual(receive_stdout, "") @@ -789,54 +848,63 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase): # check sender if mode == "file": self.failUnlessIn("Sending {size:s} file named '{name}'{NL}" - .format(size=naturalsize(size), - name=send_filename, - NL=NL), send_stderr) + .format( + size=naturalsize(size), + name=send_filename, + NL=NL), send_stderr) self.failUnlessIn("Wormhole code is: {code}{NL}" "On the other computer, please run:{NL}{NL}" - "wormhole receive {code}{NL}" - .format(code=send_cfg.code, NL=NL), - send_stderr) - self.failIfIn("File sent.. waiting for confirmation{NL}" - "Confirmation received. Transfer complete.{NL}" - .format(NL=NL), send_stderr) + "wormhole receive {code}{NL}".format( + code=send_cfg.code, NL=NL), send_stderr) + self.failIfIn( + "File sent.. waiting for confirmation{NL}" + "Confirmation received. Transfer complete.{NL}".format(NL=NL), + send_stderr) elif mode == "directory": self.failUnlessIn("Sending directory", send_stderr) self.failUnlessIn("named 'testdir'", send_stderr) self.failUnlessIn("Wormhole code is: {code}{NL}" "On the other computer, please run:{NL}{NL}" - "wormhole receive {code}{NL}" - .format(code=send_cfg.code, NL=NL), send_stderr) - self.failIfIn("File sent.. waiting for confirmation{NL}" - "Confirmation received. Transfer complete.{NL}" - .format(NL=NL), send_stderr) + "wormhole receive {code}{NL}".format( + code=send_cfg.code, NL=NL), send_stderr) + self.failIfIn( + "File sent.. waiting for confirmation{NL}" + "Confirmation received. Transfer complete.{NL}".format(NL=NL), + send_stderr) # check receiver if mode == "file": self.failIfIn("Received file written to ", receive_stderr) if failmode == "noclobber": - self.failUnlessIn("Error: " - "refusing to overwrite existing 'testfile'{NL}" - .format(NL=NL), receive_stderr) + self.failUnlessIn( + "Error: " + "refusing to overwrite existing 'testfile'{NL}" + .format(NL=NL), + receive_stderr) else: - self.failUnlessIn("Error: " - "insufficient free space (0B) for file ({size:d}B){NL}" - .format(NL=NL, size=size), receive_stderr) + self.failUnlessIn( + "Error: " + "insufficient free space (0B) for file ({size:d}B){NL}" + .format(NL=NL, size=size), receive_stderr) elif mode == "directory": - self.failIfIn("Received files written to {name}" - .format(name=receive_name), receive_stderr) - #want = (r"Receiving directory \(\d+ \w+\) into: {name}/" + self.failIfIn( + "Received files written to {name}".format(name=receive_name), + receive_stderr) + # want = (r"Receiving directory \(\d+ \w+\) into: {name}/" # .format(name=receive_name)) - #self.failUnless(re.search(want, receive_stderr), + # self.failUnless(re.search(want, receive_stderr), # (want, receive_stderr)) if failmode == "noclobber": - self.failUnlessIn("Error: " - "refusing to overwrite existing 'testdir'{NL}" - .format(NL=NL), receive_stderr) + self.failUnlessIn( + "Error: " + "refusing to overwrite existing 'testdir'{NL}" + .format(NL=NL), + receive_stderr) else: - self.failUnlessIn("Error: " - "insufficient free space (0B) for directory ({size:d}B){NL}" - .format(NL=NL, size=size), receive_stderr) + self.failUnlessIn(("Error: " + "insufficient free space (0B) for directory" + " ({size:d}B){NL}").format( + NL=NL, size=size), receive_stderr) if failmode == "noclobber": fn = os.path.join(receive_dir, receive_name) @@ -846,13 +914,17 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase): def test_fail_file_noclobber(self): return self._do_test_fail("file", "noclobber") + def test_fail_directory_noclobber(self): return self._do_test_fail("directory", "noclobber") + def test_fail_file_toobig(self): return self._do_test_fail("file", "toobig") + def test_fail_directory_toobig(self): return self._do_test_fail("directory", "toobig") + class ZeroMode(ServerBase, unittest.TestCase): @inlineCallbacks def test_text(self): @@ -871,8 +943,8 @@ class ZeroMode(ServerBase, unittest.TestCase): send_cfg.text = message - #send_cfg.cwd = send_dir - #recv_cfg.cwd = receive_dir + # send_cfg.cwd = send_dir + # recv_cfg.cwd = receive_dir send_d = cmd_send.send(send_cfg) receive_d = cmd_receive.receive(recv_cfg) @@ -888,7 +960,7 @@ class ZeroMode(ServerBase, unittest.TestCase): # newlines, even if we're on windows NL = "\n" - self.maxDiff = None # show full output for assertion failures + self.maxDiff = None # show full output for assertion failures self.assertEqual(send_stdout, "") @@ -898,15 +970,15 @@ class ZeroMode(ServerBase, unittest.TestCase): "{NL}" "wormhole receive -0{NL}" "{NL}" - "text message sent{NL}").format(bytes=len(message), - code=send_cfg.code, - NL=NL) + "text message sent{NL}").format( + bytes=len(message), code=send_cfg.code, NL=NL) self.failUnlessEqual(send_stderr, expected) # check receiver - self.assertEqual(receive_stdout, message+NL) + self.assertEqual(receive_stdout, message + NL) self.assertEqual(receive_stderr, "") + class NotWelcome(ServerBase, unittest.TestCase): @inlineCallbacks def setUp(self): @@ -936,6 +1008,7 @@ class NotWelcome(ServerBase, unittest.TestCase): f = yield self.assertFailure(receive_d, WelcomeError) self.assertEqual(str(f), "please upgrade XYZ") + class NoServer(ServerBase, unittest.TestCase): @inlineCallbacks def setUp(self): @@ -991,8 +1064,8 @@ class NoServer(ServerBase, unittest.TestCase): e = yield self.assertFailure(receive_d, ServerConnectionError) self.assertIsInstance(e.reason, ConnectionRefusedError) -class Cleanup(ServerBase, unittest.TestCase): +class Cleanup(ServerBase, unittest.TestCase): def make_config(self): cfg = config("send") # common options for all tests in this suite @@ -1040,6 +1113,7 @@ class Cleanup(ServerBase, unittest.TestCase): cids = self._rendezvous.get_app(cmd_send.APPID).get_nameplate_ids() self.assertEqual(len(cids), 0) + class ExtractFile(unittest.TestCase): def test_filenames(self): args = mock.Mock() @@ -1066,8 +1140,8 @@ class ExtractFile(unittest.TestCase): zf = mock.Mock() zi = mock.Mock() - zi.filename = "haha//root" # abspath squashes this, hopefully zipfile - # does too + zi.filename = "haha//root" # abspath squashes this, hopefully zipfile + # does too zi.external_attr = 5 << 16 expected = os.path.join(extract_dir, "haha", "root") with mock.patch.object(cmd_receive.os, "chmod") as chmod: @@ -1082,6 +1156,7 @@ class ExtractFile(unittest.TestCase): e = self.assertRaises(ValueError, ef, zf, zi, extract_dir) self.assertIn("malicious zipfile", str(e)) + class AppID(ServerBase, unittest.TestCase): @inlineCallbacks def setUp(self): @@ -1108,11 +1183,11 @@ class AppID(ServerBase, unittest.TestCase): yield receive_d used = self._usage_db.execute("SELECT DISTINCT `app_id`" - " FROM `nameplates`" - ).fetchall() + " FROM `nameplates`").fetchall() self.assertEqual(len(used), 1, used) self.assertEqual(used[0]["app_id"], u"appid2") + class Welcome(unittest.TestCase): def do(self, welcome_message, my_version="2.0"): stderr = io.StringIO() @@ -1129,27 +1204,33 @@ class Welcome(unittest.TestCase): def test_version_old(self): stderr = self.do({"current_cli_version": "3.0"}) - expected = ("Warning: errors may occur unless both sides are running the same version\n" + + expected = ("Warning: errors may occur unless both sides are" + " running the same version\n" "Server claims 3.0 is current, but ours is 2.0\n") self.assertEqual(stderr, expected) def test_version_unreleased(self): - stderr = self.do({"current_cli_version": "3.0"}, - my_version="2.5+middle.something") + stderr = self.do( + { + "current_cli_version": "3.0" + }, my_version="2.5+middle.something") self.assertEqual(stderr, "") def test_motd(self): stderr = self.do({"motd": "hello"}) self.assertEqual(stderr, "Server (at url) says:\n hello\n") + class Dispatch(unittest.TestCase): @inlineCallbacks def test_success(self): cfg = config("send") cfg.stderr = io.StringIO() called = [] + def fake(): called.append(1) + yield cli._dispatch_command(reactor, cfg, fake) self.assertEqual(called, [1]) self.assertEqual(cfg.stderr.getvalue(), "") @@ -1160,8 +1241,10 @@ class Dispatch(unittest.TestCase): cfg.stderr = io.StringIO() cfg.timing = mock.Mock() cfg.dump_timing = "filename" + def fake(): pass + yield cli._dispatch_command(reactor, cfg, fake) self.assertEqual(cfg.stderr.getvalue(), "") self.assertEqual(cfg.timing.mock_calls[-1], @@ -1171,32 +1254,39 @@ class Dispatch(unittest.TestCase): def test_wrong_password_error(self): cfg = config("send") cfg.stderr = io.StringIO() + def fake(): raise WrongPasswordError("abcd") - yield self.assertFailure(cli._dispatch_command(reactor, cfg, fake), - SystemExit) - expected = fill("ERROR: " + dedent(WrongPasswordError.__doc__))+"\n" + + yield self.assertFailure( + cli._dispatch_command(reactor, cfg, fake), SystemExit) + expected = fill("ERROR: " + dedent(WrongPasswordError.__doc__)) + "\n" self.assertEqual(cfg.stderr.getvalue(), expected) @inlineCallbacks def test_welcome_error(self): cfg = config("send") cfg.stderr = io.StringIO() + def fake(): raise WelcomeError("abcd") - yield self.assertFailure(cli._dispatch_command(reactor, cfg, fake), - SystemExit) - expected = fill("ERROR: " + dedent(WelcomeError.__doc__))+"\n\nabcd\n" + + yield self.assertFailure( + cli._dispatch_command(reactor, cfg, fake), SystemExit) + expected = ( + fill("ERROR: " + dedent(WelcomeError.__doc__)) + "\n\nabcd\n") self.assertEqual(cfg.stderr.getvalue(), expected) @inlineCallbacks def test_transfer_error(self): cfg = config("send") cfg.stderr = io.StringIO() + def fake(): raise TransferError("abcd") - yield self.assertFailure(cli._dispatch_command(reactor, cfg, fake), - SystemExit) + + yield self.assertFailure( + cli._dispatch_command(reactor, cfg, fake), SystemExit) expected = "TransferError: abcd\n" self.assertEqual(cfg.stderr.getvalue(), expected) @@ -1204,11 +1294,14 @@ class Dispatch(unittest.TestCase): def test_server_connection_error(self): cfg = config("send") cfg.stderr = io.StringIO() + def fake(): raise ServerConnectionError("URL", ValueError("abcd")) - yield self.assertFailure(cli._dispatch_command(reactor, cfg, fake), - SystemExit) - expected = fill("ERROR: " + dedent(ServerConnectionError.__doc__))+"\n" + + yield self.assertFailure( + cli._dispatch_command(reactor, cfg, fake), SystemExit) + expected = fill( + "ERROR: " + dedent(ServerConnectionError.__doc__)) + "\n" expected += "(relay URL was URL)\n" expected += "abcd\n" self.assertEqual(cfg.stderr.getvalue(), expected) @@ -1217,21 +1310,26 @@ class Dispatch(unittest.TestCase): def test_other_error(self): cfg = config("send") cfg.stderr = io.StringIO() + def fake(): raise ValueError("abcd") + # I'm seeing unicode problems with the Failure().printTraceback, and # the output would be kind of unpredictable anyways, so we'll mock it # out here. f = mock.Mock() + def mock_print(file): file.write(u"\n") + f.printTraceback = mock_print with mock.patch("wormhole.cli.cli.Failure", return_value=f): - yield self.assertFailure(cli._dispatch_command(reactor, cfg, fake), - SystemExit) + yield self.assertFailure( + cli._dispatch_command(reactor, cfg, fake), SystemExit) expected = "\nERROR: abcd\n" self.assertEqual(cfg.stderr.getvalue(), expected) + class Help(unittest.TestCase): def _check_top_level_help(self, got): # the main wormhole.cli.cli.wormhole docstring should be in the diff --git a/src/wormhole/test/test_eventual.py b/src/wormhole/test/test_eventual.py index 4813198..cb0c896 100644 --- a/src/wormhole/test/test_eventual.py +++ b/src/wormhole/test/test_eventual.py @@ -1,14 +1,19 @@ from __future__ import print_function, unicode_literals -import mock -from twisted.trial import unittest + from twisted.internet import reactor -from twisted.internet.task import Clock from twisted.internet.defer import Deferred, inlineCallbacks +from twisted.internet.task import Clock +from twisted.trial import unittest + +import mock + from ..eventual import EventualQueue + class IntentionalError(Exception): pass + class Eventual(unittest.TestCase, object): def test_eventually(self): c = Clock() @@ -23,9 +28,10 @@ class Eventual(unittest.TestCase, object): self.assertNoResult(d3) eq.flush_sync() - self.assertEqual(c1.mock_calls, - [mock.call("arg1", "arg2", kwarg1="kw1"), - mock.call("arg3", "arg4", kwarg5="kw5")]) + self.assertEqual(c1.mock_calls, [ + mock.call("arg1", "arg2", kwarg1="kw1"), + mock.call("arg3", "arg4", kwarg5="kw5") + ]) self.assertEqual(self.successResultOf(d2), None) self.assertEqual(self.successResultOf(d3), "value") @@ -47,11 +53,12 @@ class Eventual(unittest.TestCase, object): eq = EventualQueue(reactor) d1 = eq.fire_eventually() d2 = Deferred() + def _more(res): eq.eventually(d2.callback, None) + d1.addCallback(_more) yield eq.flush() # d1 will fire, which will queue d2 to fire, and the flush() ought to # wait for d2 too self.successResultOf(d2) - diff --git a/src/wormhole/test/test_hkdf.py b/src/wormhole/test/test_hkdf.py index 7751383..1628646 100644 --- a/src/wormhole/test/test_hkdf.py +++ b/src/wormhole/test/test_hkdf.py @@ -1,44 +1,47 @@ from __future__ import print_function, unicode_literals + import unittest -from binascii import unhexlify #, hexlify +from binascii import unhexlify # , hexlify + from hkdf import Hkdf -#def generate_KAT(): -# print("KAT = [") -# for salt in (b"", b"salt"): -# for context in (b"", b"context"): -# skm = b"secret" -# out = HKDF(skm, 64, XTS=salt, CTXinfo=context) -# hexout = " '%s' +\n '%s'" % (hexlify(out[:32]), -# hexlify(out[32:])) -# print(" (%r, %r, %r,\n%s)," % (salt, context, skm, hexout)) -# print("]") +# def generate_KAT(): +# print("KAT = [") +# for salt in (b"", b"salt"): +# for context in (b"", b"context"): +# skm = b"secret" +# out = HKDF(skm, 64, XTS=salt, CTXinfo=context) +# hexout = " '%s' +\n '%s'" % (hexlify(out[:32]), +# hexlify(out[32:])) +# print(" (%r, %r, %r,\n%s)," % (salt, context, skm, hexout)) +# print("]") KAT = [ - ('', '', 'secret', - '2f34e5ff91ec85d53ca9b543683174d0cf550b60d5f52b24c97b386cfcf6cbbf' + - '9cfd42fd37e1e5a214d15f03058d7fee63dc28f564b7b9fe3da514f80daad4bf'), - ('', 'context', 'secret', - 'c24c303a1adfb4c3e2b092e6254ed481c41d8955ba8ec3f6a1473493a60c957b' + - '31b723018ca75557214d3d5c61c0c7a5315b103b21ff00cb03ebe023dc347a47'), - ('salt', '', 'secret', - 'f1156507c39b0e326159e778696253122de430899a8df2484040a85a5f95ceb1' + - 'dfca555d4cc603bdf7153ed1560de8cbc3234b27a6d2be8e8ca202d90649679a'), - ('salt', 'context', 'secret', - '61a4f201a867bcc12381ddb180d27074408d03ee9d5750855e5a12d967fa060f' + - '10336ead9370927eaabb0d60b259346ee5f57eb7ceba8c72f1ed3f2932b1bf19'), + ('', '', 'secret', + '2f34e5ff91ec85d53ca9b543683174d0cf550b60d5f52b24c97b386cfcf6cbbf' + + '9cfd42fd37e1e5a214d15f03058d7fee63dc28f564b7b9fe3da514f80daad4bf'), + ('', 'context', 'secret', + 'c24c303a1adfb4c3e2b092e6254ed481c41d8955ba8ec3f6a1473493a60c957b' + + '31b723018ca75557214d3d5c61c0c7a5315b103b21ff00cb03ebe023dc347a47'), + ('salt', '', 'secret', + 'f1156507c39b0e326159e778696253122de430899a8df2484040a85a5f95ceb1' + + 'dfca555d4cc603bdf7153ed1560de8cbc3234b27a6d2be8e8ca202d90649679a'), + ('salt', 'context', 'secret', + '61a4f201a867bcc12381ddb180d27074408d03ee9d5750855e5a12d967fa060f' + + '10336ead9370927eaabb0d60b259346ee5f57eb7ceba8c72f1ed3f2932b1bf19'), ] + class TestKAT(unittest.TestCase): # note: this uses SHA256 def test_kat(self): for (salt, context, skm, expected_hexout) in KAT: expected_out = unhexlify(expected_hexout) for outlen in range(0, len(expected_out)): - out = Hkdf(salt.encode("ascii"), - skm.encode("ascii")).expand(context.encode("ascii"), - outlen) + out = Hkdf(salt.encode("ascii"), skm.encode("ascii")).expand( + context.encode("ascii"), outlen) self.assertEqual(out, expected_out[:outlen]) -#if __name__ == '__main__': -# generate_KAT() + +# if __name__ == '__main__': +# generate_KAT() diff --git a/src/wormhole/test/test_ipaddrs.py b/src/wormhole/test/test_ipaddrs.py index 90abcce..9e2cec5 100644 --- a/src/wormhole/test/test_ipaddrs.py +++ b/src/wormhole/test/test_ipaddrs.py @@ -1,9 +1,13 @@ +import errno +import os +import re +import subprocess -import re, errno, subprocess, os from twisted.trial import unittest + from .. import ipaddrs -DOTTED_QUAD_RE=re.compile("^[0-9]+\.[0-9]+\.[0-9]+\.[0-9]+$") +DOTTED_QUAD_RE = re.compile("^[0-9]+\.[0-9]+\.[0-9]+\.[0-9]+$") MOCK_IPADDR_OUTPUT = """\ 1: lo: mtu 16436 qdisc noqueue state UNKNOWN \n\ @@ -11,12 +15,14 @@ MOCK_IPADDR_OUTPUT = """\ inet 127.0.0.1/8 scope host lo inet6 ::1/128 scope host \n\ valid_lft forever preferred_lft forever -2: eth1: mtu 1500 qdisc pfifo_fast state UP qlen 1000 +2: eth1: mtu 1500 qdisc pfifo_fast state UP \ +qlen 1000 link/ether d4:3d:7e:01:b4:3e brd ff:ff:ff:ff:ff:ff inet 192.168.0.6/24 brd 192.168.0.255 scope global eth1 inet6 fe80::d63d:7eff:fe01:b43e/64 scope link \n\ valid_lft forever preferred_lft forever -3: wlan0: mtu 1500 qdisc mq state UP qlen 1000 +3: wlan0: mtu 1500 qdisc mq state UP qlen\ + 1000 link/ether 90:f6:52:27:15:0a brd ff:ff:ff:ff:ff:ff inet 192.168.0.2/24 brd 192.168.0.255 scope global wlan0 inet6 fe80::92f6:52ff:fe27:150a/64 scope link \n\ @@ -58,7 +64,8 @@ MOCK_ROUTE_OUTPUT = """\ =========================================================================== Interface List 0x1 ........................... MS TCP Loopback interface -0x2 ...08 00 27 c3 80 ad ...... AMD PCNET Family PCI Ethernet Adapter - Packet Scheduler Miniport +0x2 ...08 00 27 c3 80 ad ...... AMD PCNET Family PCI Ethernet Adapter - \ +Packet Scheduler Miniport =========================================================================== =========================================================================== Active Routes: @@ -85,6 +92,7 @@ class FakeProcess: def __init__(self, output, err): self.output = output self.err = err + def communicate(self): return (self.output, self.err) @@ -94,16 +102,28 @@ class ListAddresses(unittest.TestCase): addresses = ipaddrs.find_addresses() self.failUnlessIn("127.0.0.1", addresses) self.failIfIn("0.0.0.0", addresses) + # David A.'s OpenSolaris box timed out on this test one time when it was at # 2s. - test_list.timeout=4 + test_list.timeout = 4 def _test_list_mock(self, command, output, expected): self.first = True - def call_Popen(args, bufsize=0, executable=None, stdin=None, stdout=None, stderr=None, - preexec_fn=None, close_fds=False, shell=False, cwd=None, env=None, - universal_newlines=False, startupinfo=None, creationflags=0): + def call_Popen(args, + bufsize=0, + executable=None, + stdin=None, + stdout=None, + stderr=None, + preexec_fn=None, + close_fds=False, + shell=False, + cwd=None, + env=None, + universal_newlines=False, + startupinfo=None, + creationflags=0): if self.first: self.first = False e = OSError("EINTR") @@ -115,11 +135,13 @@ class ListAddresses(unittest.TestCase): e = OSError("[Errno 2] No such file or directory") e.errno = errno.ENOENT raise e + self.patch(subprocess, 'Popen', call_Popen) self.patch(os.path, 'isfile', lambda x: True) def call_which(name): return [name] + self.patch(ipaddrs, 'which', call_which) addresses = ipaddrs.find_addresses() @@ -131,11 +153,13 @@ class ListAddresses(unittest.TestCase): def test_list_mock_ifconfig(self): self.patch(ipaddrs, 'platform', "linux2") - self._test_list_mock("ifconfig", MOCK_IFCONFIG_OUTPUT, UNIX_TEST_ADDRESSES) + self._test_list_mock("ifconfig", MOCK_IFCONFIG_OUTPUT, + UNIX_TEST_ADDRESSES) def test_list_mock_route(self): self.patch(ipaddrs, 'platform', "win32") - self._test_list_mock("route.exe", MOCK_ROUTE_OUTPUT, WINDOWS_TEST_ADDRESSES) + self._test_list_mock("route.exe", MOCK_ROUTE_OUTPUT, + WINDOWS_TEST_ADDRESSES) def test_list_mock_cygwin(self): self.patch(ipaddrs, 'platform', "cygwin") diff --git a/src/wormhole/test/test_journal.py b/src/wormhole/test/test_journal.py index 96b9319..372da0f 100644 --- a/src/wormhole/test/test_journal.py +++ b/src/wormhole/test/test_journal.py @@ -1,8 +1,11 @@ -from __future__ import print_function, absolute_import, unicode_literals +from __future__ import absolute_import, print_function, unicode_literals + from twisted.trial import unittest + from .. import journal from .._interfaces import IJournal + class Journal(unittest.TestCase): def test_journal(self): events = [] diff --git a/src/wormhole/test/test_machines.py b/src/wormhole/test/test_machines.py index ec96fc0..38e351e 100644 --- a/src/wormhole/test/test_machines.py +++ b/src/wormhole/test/test_machines.py @@ -1,29 +1,36 @@ from __future__ import print_function, unicode_literals + import json -import mock -from zope.interface import directlyProvides, implementer + +from nacl.secret import SecretBox +from spake2 import SPAKE2_Symmetric from twisted.trial import unittest -from .. import (errors, timing, _order, _receive, _key, _code, _lister, _boss, - _input, _allocator, _send, _terminator, _nameplate, _mailbox, - _rendezvous, __version__) -from .._interfaces import (IKey, IReceive, IBoss, ISend, IMailbox, IOrder, - IRendezvousConnector, ILister, IInput, IAllocator, - INameplate, ICode, IWordlist, ITerminator) +from zope.interface import directlyProvides, implementer + +import mock + +from .. import (__version__, _allocator, _boss, _code, _input, _key, _lister, + _mailbox, _nameplate, _order, _receive, _rendezvous, _send, + _terminator, errors, timing) +from .._interfaces import (IAllocator, IBoss, ICode, IInput, IKey, ILister, + IMailbox, INameplate, IOrder, IReceive, + IRendezvousConnector, ISend, ITerminator, IWordlist) from .._key import derive_key, derive_phase_key, encrypt_data from ..journal import ImmediateJournal -from ..util import (dict_to_bytes, bytes_to_dict, - hexstr_to_bytes, bytes_to_hexstr, to_bytes) -from spake2 import SPAKE2_Symmetric -from nacl.secret import SecretBox +from ..util import (bytes_to_dict, bytes_to_hexstr, dict_to_bytes, + hexstr_to_bytes, to_bytes) + @implementer(IWordlist) class FakeWordList(object): def choose_words(self, length): return "-".join(["word"] * length) + def get_completions(self, prefix): self._get_completions_prefix = prefix return self._completions + class Dummy: def __init__(self, name, events, iface, *meths): self.name = name @@ -33,12 +40,15 @@ class Dummy: for meth in meths: self.mock(meth) self.retval = None + def mock(self, meth): def log(*args): - self.events.append(("%s.%s" % (self.name, meth),) + args) + self.events.append(("%s.%s" % (self.name, meth), ) + args) return self.retval + setattr(self, meth, log) + class Send(unittest.TestCase): def build(self): events = [] @@ -56,8 +66,10 @@ class Send(unittest.TestCase): with mock.patch("nacl.utils.random", side_effect=[nonce1]) as r: s.got_verified_key(key) self.assertEqual(r.mock_calls, [mock.call(SecretBox.NONCE_SIZE)]) - #print(bytes_to_hexstr(events[0][2])) - enc1 = hexstr_to_bytes("00000000000000000000000000000000000000000000000022f1a46c3c3496423c394621a2a5a8cf275b08") + # print(bytes_to_hexstr(events[0][2])) + enc1 = hexstr_to_bytes( + ("000000000000000000000000000000000000000000000000" + "22f1a46c3c3496423c394621a2a5a8cf275b08")) self.assertEqual(events, [("m.add_message", "phase1", enc1)]) events[:] = [] @@ -65,7 +77,9 @@ class Send(unittest.TestCase): with mock.patch("nacl.utils.random", side_effect=[nonce2]) as r: s.send("phase2", b"msg") self.assertEqual(r.mock_calls, [mock.call(SecretBox.NONCE_SIZE)]) - enc2 = hexstr_to_bytes("0202020202020202020202020202020202020202020202026660337c3eac6513c0dac9818b62ef16d9cd7e") + enc2 = hexstr_to_bytes( + ("0202020202020202020202020202020202020202" + "020202026660337c3eac6513c0dac9818b62ef16d9cd7e")) self.assertEqual(events, [("m.add_message", "phase2", enc2)]) def test_key_first(self): @@ -78,7 +92,8 @@ class Send(unittest.TestCase): with mock.patch("nacl.utils.random", side_effect=[nonce1]) as r: s.send("phase1", b"msg") self.assertEqual(r.mock_calls, [mock.call(SecretBox.NONCE_SIZE)]) - enc1 = hexstr_to_bytes("00000000000000000000000000000000000000000000000022f1a46c3c3496423c394621a2a5a8cf275b08") + enc1 = hexstr_to_bytes(("00000000000000000000000000000000000000000000" + "000022f1a46c3c3496423c394621a2a5a8cf275b08")) self.assertEqual(events, [("m.add_message", "phase1", enc1)]) events[:] = [] @@ -86,11 +101,12 @@ class Send(unittest.TestCase): with mock.patch("nacl.utils.random", side_effect=[nonce2]) as r: s.send("phase2", b"msg") self.assertEqual(r.mock_calls, [mock.call(SecretBox.NONCE_SIZE)]) - enc2 = hexstr_to_bytes("0202020202020202020202020202020202020202020202026660337c3eac6513c0dac9818b62ef16d9cd7e") + enc2 = hexstr_to_bytes( + ("0202020202020202020202020202020202020" + "202020202026660337c3eac6513c0dac9818b62ef16d9cd7e")) self.assertEqual(events, [("m.add_message", "phase2", enc2)]) - class Order(unittest.TestCase): def build(self): events = [] @@ -103,35 +119,36 @@ class Order(unittest.TestCase): def test_in_order(self): o, k, r, events = self.build() o.got_message(u"side", u"pake", b"body") - self.assertEqual(events, [("k.got_pake", b"body")]) # right away + self.assertEqual(events, [("k.got_pake", b"body")]) # right away o.got_message(u"side", u"version", b"body") o.got_message(u"side", u"1", b"body") - self.assertEqual(events, - [("k.got_pake", b"body"), - ("r.got_message", u"side", u"version", b"body"), - ("r.got_message", u"side", u"1", b"body"), - ]) + self.assertEqual(events, [ + ("k.got_pake", b"body"), + ("r.got_message", u"side", u"version", b"body"), + ("r.got_message", u"side", u"1", b"body"), + ]) def test_out_of_order(self): o, k, r, events = self.build() o.got_message(u"side", u"version", b"body") - self.assertEqual(events, []) # nothing yet + self.assertEqual(events, []) # nothing yet o.got_message(u"side", u"1", b"body") - self.assertEqual(events, []) # nothing yet + self.assertEqual(events, []) # nothing yet o.got_message(u"side", u"pake", b"body") # got_pake is delivered first - self.assertEqual(events, - [("k.got_pake", b"body"), - ("r.got_message", u"side", u"version", b"body"), - ("r.got_message", u"side", u"1", b"body"), - ]) + self.assertEqual(events, [ + ("k.got_pake", b"body"), + ("r.got_message", u"side", u"version", b"body"), + ("r.got_message", u"side", u"1", b"body"), + ]) + class Receive(unittest.TestCase): def build(self): events = [] r = _receive.Receive(u"side", timing.DebugTiming()) - b = Dummy("b", events, IBoss, - "happy", "scared", "got_verifier", "got_message") + b = Dummy("b", events, IBoss, "happy", "scared", "got_verifier", + "got_message") s = Dummy("s", events, ISend, "got_verified_key") r.wire(b, s) return r, b, s, events @@ -146,22 +163,24 @@ class Receive(unittest.TestCase): data1 = b"data1" good_body = encrypt_data(phase1_key, data1) r.got_message(u"side", u"phase1", good_body) - self.assertEqual(events, [("s.got_verified_key", key), - ("b.happy",), - ("b.got_verifier", verifier), - ("b.got_message", u"phase1", data1), - ]) + self.assertEqual(events, [ + ("s.got_verified_key", key), + ("b.happy", ), + ("b.got_verifier", verifier), + ("b.got_message", u"phase1", data1), + ]) phase2_key = derive_phase_key(key, u"side", u"phase2") data2 = b"data2" good_body = encrypt_data(phase2_key, data2) r.got_message(u"side", u"phase2", good_body) - self.assertEqual(events, [("s.got_verified_key", key), - ("b.happy",), - ("b.got_verifier", verifier), - ("b.got_message", u"phase1", data1), - ("b.got_message", u"phase2", data2), - ]) + self.assertEqual(events, [ + ("s.got_verified_key", key), + ("b.happy", ), + ("b.got_verifier", verifier), + ("b.got_message", u"phase1", data1), + ("b.got_message", u"phase2", data2), + ]) def test_early_bad(self): r, b, s, events = self.build() @@ -172,15 +191,17 @@ class Receive(unittest.TestCase): data1 = b"data1" bad_body = encrypt_data(phase1_key, data1) r.got_message(u"side", u"phase1", bad_body) - self.assertEqual(events, [("b.scared",), - ]) + self.assertEqual(events, [ + ("b.scared", ), + ]) phase2_key = derive_phase_key(key, u"side", u"phase2") data2 = b"data2" good_body = encrypt_data(phase2_key, data2) r.got_message(u"side", u"phase2", good_body) - self.assertEqual(events, [("b.scared",), - ]) + self.assertEqual(events, [ + ("b.scared", ), + ]) def test_late_bad(self): r, b, s, events = self.build() @@ -192,30 +213,34 @@ class Receive(unittest.TestCase): data1 = b"data1" good_body = encrypt_data(phase1_key, data1) r.got_message(u"side", u"phase1", good_body) - self.assertEqual(events, [("s.got_verified_key", key), - ("b.happy",), - ("b.got_verifier", verifier), - ("b.got_message", u"phase1", data1), - ]) + self.assertEqual(events, [ + ("s.got_verified_key", key), + ("b.happy", ), + ("b.got_verifier", verifier), + ("b.got_message", u"phase1", data1), + ]) phase2_key = derive_phase_key(key, u"side", u"bad") data2 = b"data2" bad_body = encrypt_data(phase2_key, data2) r.got_message(u"side", u"phase2", bad_body) - self.assertEqual(events, [("s.got_verified_key", key), - ("b.happy",), - ("b.got_verifier", verifier), - ("b.got_message", u"phase1", data1), - ("b.scared",), - ]) + self.assertEqual(events, [ + ("s.got_verified_key", key), + ("b.happy", ), + ("b.got_verifier", verifier), + ("b.got_message", u"phase1", data1), + ("b.scared", ), + ]) r.got_message(u"side", u"phase1", good_body) r.got_message(u"side", u"phase2", bad_body) - self.assertEqual(events, [("s.got_verified_key", key), - ("b.happy",), - ("b.got_verifier", verifier), - ("b.got_message", u"phase1", data1), - ("b.scared",), - ]) + self.assertEqual(events, [ + ("s.got_verified_key", key), + ("b.happy", ), + ("b.got_verifier", verifier), + ("b.got_message", u"phase1", data1), + ("b.scared", ), + ]) + class Key(unittest.TestCase): def test_derive_errors(self): @@ -260,11 +285,12 @@ class Key(unittest.TestCase): self.assertEqual(events[0][:2], ("m.add_message", "pake")) pake_1_json = events[0][2].decode("utf-8") pake_1 = json.loads(pake_1_json) - self.assertEqual(list(pake_1.keys()), ["pake_v1"]) # value is PAKE stuff + self.assertEqual(list(pake_1.keys()), + ["pake_v1"]) # value is PAKE stuff events[:] = [] bad_pake_d = {"not_pake_v1": "stuff"} k.got_pake(dict_to_bytes(bad_pake_d)) - self.assertEqual(events, [("b.scared",)]) + self.assertEqual(events, [("b.scared", )]) def test_reversed(self): # A receiver using input_code() will choose the nameplate first, then @@ -293,6 +319,7 @@ class Key(unittest.TestCase): self.assertEqual(events[2][:2], ("m.add_message", "version")) self.assertEqual(events[3], ("r.got_key", key2)) + class Code(unittest.TestCase): def build(self): events = [] @@ -308,10 +335,11 @@ class Code(unittest.TestCase): def test_set_code(self): c, b, a, n, k, i, events = self.build() c.set_code(u"1-code") - self.assertEqual(events, [("n.set_nameplate", u"1"), - ("b.got_code", u"1-code"), - ("k.got_code", u"1-code"), - ]) + self.assertEqual(events, [ + ("n.set_nameplate", u"1"), + ("b.got_code", u"1-code"), + ("k.got_code", u"1-code"), + ]) def test_set_code_invalid(self): c, b, a, n, k, i, events = self.build() @@ -323,15 +351,17 @@ class Code(unittest.TestCase): self.assertEqual(str(e.exception), "Code ' 1-code' contains spaces.") with self.assertRaises(errors.KeyFormatError) as e: c.set_code(u"code-code") - self.assertEqual(str(e.exception), - "Nameplate 'code' must be numeric, with no spaces.") + self.assertEqual( + str(e.exception), + "Nameplate 'code' must be numeric, with no spaces.") # it should still be possible to use the wormhole at this point c.set_code(u"1-code") - self.assertEqual(events, [("n.set_nameplate", u"1"), - ("b.got_code", u"1-code"), - ("k.got_code", u"1-code"), - ]) + self.assertEqual(events, [ + ("n.set_nameplate", u"1"), + ("b.got_code", u"1-code"), + ("k.got_code", u"1-code"), + ]) def test_allocate_code(self): c, b, a, n, k, i, events = self.build() @@ -340,30 +370,35 @@ class Code(unittest.TestCase): self.assertEqual(events, [("a.allocate", 2, wl)]) events[:] = [] c.allocated("1", "1-code") - self.assertEqual(events, [("n.set_nameplate", u"1"), - ("b.got_code", u"1-code"), - ("k.got_code", u"1-code"), - ]) + self.assertEqual(events, [ + ("n.set_nameplate", u"1"), + ("b.got_code", u"1-code"), + ("k.got_code", u"1-code"), + ]) def test_input_code(self): c, b, a, n, k, i, events = self.build() c.input_code() - self.assertEqual(events, [("i.start",)]) + self.assertEqual(events, [("i.start", )]) events[:] = [] c.got_nameplate("1") - self.assertEqual(events, [("n.set_nameplate", u"1"), - ]) + self.assertEqual(events, [ + ("n.set_nameplate", u"1"), + ]) events[:] = [] c.finished_input("1-code") - self.assertEqual(events, [("b.got_code", u"1-code"), - ("k.got_code", u"1-code"), - ]) + self.assertEqual(events, [ + ("b.got_code", u"1-code"), + ("k.got_code", u"1-code"), + ]) + class Input(unittest.TestCase): def build(self): events = [] i = _input.Input(timing.DebugTiming()) c = Dummy("c", events, ICode, "got_nameplate", "finished_input") + # renamed from l as l is indistinguishable from 1 in some fonts. l = Dummy("l", events, ILister, "refresh") i.wire(c, l) return i, c, l, events @@ -372,7 +407,7 @@ class Input(unittest.TestCase): i, c, l, events = self.build() helper = i.start() self.assertIsInstance(helper, _input.Helper) - self.assertEqual(events, [("l.refresh",)]) + self.assertEqual(events, [("l.refresh", )]) events[:] = [] with self.assertRaises(errors.MustChooseNameplateFirstError): helper.choose_words("word-word") @@ -390,7 +425,7 @@ class Input(unittest.TestCase): i, c, l, events = self.build() helper = i.start() self.assertIsInstance(helper, _input.Helper) - self.assertEqual(events, [("l.refresh",)]) + self.assertEqual(events, [("l.refresh", )]) events[:] = [] with self.assertRaises(errors.MustChooseNameplateFirstError): helper.choose_words("word-word") @@ -411,24 +446,24 @@ class Input(unittest.TestCase): i, c, l, events = self.build() helper = i.start() self.assertIsInstance(helper, _input.Helper) - self.assertEqual(events, [("l.refresh",)]) + self.assertEqual(events, [("l.refresh", )]) events[:] = [] d = helper.when_wordlist_is_available() self.assertNoResult(d) helper.refresh_nameplates() - self.assertEqual(events, [("l.refresh",)]) + self.assertEqual(events, [("l.refresh", )]) events[:] = [] with self.assertRaises(errors.MustChooseNameplateFirstError): helper.get_word_completions("prefix") i.got_nameplates({"1", "12", "34", "35", "367"}) self.assertNoResult(d) - self.assertEqual(helper.get_nameplate_completions(""), - {"1-", "12-", "34-", "35-", "367-"}) - self.assertEqual(helper.get_nameplate_completions("1"), - {"1-", "12-"}) + self.assertEqual( + helper.get_nameplate_completions(""), + {"1-", "12-", "34-", "35-", "367-"}) + self.assertEqual(helper.get_nameplate_completions("1"), {"1-", "12-"}) self.assertEqual(helper.get_nameplate_completions("2"), set()) - self.assertEqual(helper.get_nameplate_completions("3"), - {"34-", "35-", "367-"}) + self.assertEqual( + helper.get_nameplate_completions("3"), {"34-", "35-", "367-"}) helper.choose_nameplate("34") with self.assertRaises(errors.AlreadyChoseNameplateError): helper.refresh_nameplates() @@ -461,15 +496,14 @@ class Input(unittest.TestCase): self.assertEqual(events, [("c.finished_input", "34-word-word")]) - class Lister(unittest.TestCase): def build(self): events = [] - l = _lister.Lister(timing.DebugTiming()) + lister = _lister.Lister(timing.DebugTiming()) rc = Dummy("rc", events, IRendezvousConnector, "tx_list") i = Dummy("i", events, IInput, "got_nameplates") - l.wire(rc, i) - return l, rc, i, events + lister.wire(rc, i) + return lister, rc, i, events def test_connect_first(self): l, rc, i, events = self.build() @@ -478,12 +512,14 @@ class Lister(unittest.TestCase): l.connected() self.assertEqual(events, []) l.refresh() - self.assertEqual(events, [("rc.tx_list",), - ]) + self.assertEqual(events, [ + ("rc.tx_list", ), + ]) events[:] = [] l.rx_nameplates({"1", "2", "3"}) - self.assertEqual(events, [("i.got_nameplates", {"1", "2", "3"}), - ]) + self.assertEqual(events, [ + ("i.got_nameplates", {"1", "2", "3"}), + ]) events[:] = [] # now we're satisfied: disconnecting and reconnecting won't ask again l.lost() @@ -492,8 +528,9 @@ class Lister(unittest.TestCase): # but if we're told to refresh, we'll do so l.refresh() - self.assertEqual(events, [("rc.tx_list",), - ]) + self.assertEqual(events, [ + ("rc.tx_list", ), + ]) def test_connect_first_ask_twice(self): l, rc, i, events = self.build() @@ -501,44 +538,51 @@ class Lister(unittest.TestCase): self.assertEqual(events, []) l.refresh() l.refresh() - self.assertEqual(events, [("rc.tx_list",), - ("rc.tx_list",), - ]) + self.assertEqual(events, [ + ("rc.tx_list", ), + ("rc.tx_list", ), + ]) l.rx_nameplates({"1", "2", "3"}) - self.assertEqual(events, [("rc.tx_list",), - ("rc.tx_list",), - ("i.got_nameplates", {"1", "2", "3"}), - ]) - l.rx_nameplates({"1" ,"2", "3", "4"}) - self.assertEqual(events, [("rc.tx_list",), - ("rc.tx_list",), - ("i.got_nameplates", {"1", "2", "3"}), - ("i.got_nameplates", {"1", "2", "3", "4"}), - ]) + self.assertEqual(events, [ + ("rc.tx_list", ), + ("rc.tx_list", ), + ("i.got_nameplates", {"1", "2", "3"}), + ]) + l.rx_nameplates({"1", "2", "3", "4"}) + self.assertEqual(events, [ + ("rc.tx_list", ), + ("rc.tx_list", ), + ("i.got_nameplates", {"1", "2", "3"}), + ("i.got_nameplates", {"1", "2", "3", "4"}), + ]) def test_reconnect(self): l, rc, i, events = self.build() l.refresh() l.connected() - self.assertEqual(events, [("rc.tx_list",), - ]) + self.assertEqual(events, [ + ("rc.tx_list", ), + ]) events[:] = [] l.lost() l.connected() - self.assertEqual(events, [("rc.tx_list",), - ]) + self.assertEqual(events, [ + ("rc.tx_list", ), + ]) def test_refresh_first(self): l, rc, i, events = self.build() l.refresh() self.assertEqual(events, []) l.connected() - self.assertEqual(events, [("rc.tx_list",), - ]) + self.assertEqual(events, [ + ("rc.tx_list", ), + ]) l.rx_nameplates({"1", "2", "3"}) - self.assertEqual(events, [("rc.tx_list",), - ("i.got_nameplates", {"1", "2", "3"}), - ]) + self.assertEqual(events, [ + ("rc.tx_list", ), + ("i.got_nameplates", {"1", "2", "3"}), + ]) def test_unrefreshed(self): l, rc, i, events = self.build() @@ -547,8 +591,10 @@ class Lister(unittest.TestCase): l.connected() self.assertEqual(events, []) l.rx_nameplates({"1", "2", "3"}) - self.assertEqual(events, [("i.got_nameplates", {"1", "2", "3"}), - ]) + self.assertEqual(events, [ + ("i.got_nameplates", {"1", "2", "3"}), + ]) + class Allocator(unittest.TestCase): def build(self): @@ -569,32 +615,37 @@ class Allocator(unittest.TestCase): a.allocate(2, FakeWordList()) self.assertEqual(events, []) a.connected() - self.assertEqual(events, [("rc.tx_allocate",)]) + self.assertEqual(events, [("rc.tx_allocate", )]) events[:] = [] a.lost() a.connected() - self.assertEqual(events, [("rc.tx_allocate",), - ]) + self.assertEqual(events, [ + ("rc.tx_allocate", ), + ]) events[:] = [] a.rx_allocated("1") - self.assertEqual(events, [("c.allocated", "1", "1-word-word"), - ]) + self.assertEqual(events, [ + ("c.allocated", "1", "1-word-word"), + ]) def test_connect_first(self): a, rc, c, events = self.build() a.connected() self.assertEqual(events, []) a.allocate(2, FakeWordList()) - self.assertEqual(events, [("rc.tx_allocate",)]) + self.assertEqual(events, [("rc.tx_allocate", )]) events[:] = [] a.lost() a.connected() - self.assertEqual(events, [("rc.tx_allocate",), - ]) + self.assertEqual(events, [ + ("rc.tx_allocate", ), + ]) events[:] = [] a.rx_allocated("1") - self.assertEqual(events, [("c.allocated", "1", "1-word-word"), - ]) + self.assertEqual(events, [ + ("c.allocated", "1", "1-word-word"), + ]) + class Nameplate(unittest.TestCase): def build(self): @@ -602,7 +653,8 @@ class Nameplate(unittest.TestCase): n = _nameplate.Nameplate() m = Dummy("m", events, IMailbox, "got_mailbox") i = Dummy("i", events, IInput, "got_wordlist") - rc = Dummy("rc", events, IRendezvousConnector, "tx_claim", "tx_release") + rc = Dummy("rc", events, IRendezvousConnector, "tx_claim", + "tx_release") t = Dummy("t", events, ITerminator, "nameplate_done") n.wire(m, i, rc, t) return n, m, i, rc, t, events @@ -611,12 +663,14 @@ class Nameplate(unittest.TestCase): n, m, i, rc, t, events = self.build() with self.assertRaises(errors.KeyFormatError) as e: n.set_nameplate(" 1") - self.assertEqual(str(e.exception), - "Nameplate ' 1' must be numeric, with no spaces.") + self.assertEqual( + str(e.exception), + "Nameplate ' 1' must be numeric, with no spaces.") with self.assertRaises(errors.KeyFormatError) as e: n.set_nameplate("one") - self.assertEqual(str(e.exception), - "Nameplate 'one' must be numeric, with no spaces.") + self.assertEqual( + str(e.exception), + "Nameplate 'one' must be numeric, with no spaces.") # wormhole should still be usable n.set_nameplate("1") @@ -636,9 +690,10 @@ class Nameplate(unittest.TestCase): wl = object() with mock.patch("wormhole._nameplate.PGPWordList", return_value=wl): n.rx_claimed("mbox1") - self.assertEqual(events, [("i.got_wordlist", wl), - ("m.got_mailbox", "mbox1"), - ]) + self.assertEqual(events, [ + ("i.got_wordlist", wl), + ("m.got_mailbox", "mbox1"), + ]) events[:] = [] n.release() @@ -646,7 +701,7 @@ class Nameplate(unittest.TestCase): events[:] = [] n.rx_released() - self.assertEqual(events, [("t.nameplate_done",)]) + self.assertEqual(events, [("t.nameplate_done", )]) def test_connect_first(self): # connection remains up throughout @@ -661,9 +716,10 @@ class Nameplate(unittest.TestCase): wl = object() with mock.patch("wormhole._nameplate.PGPWordList", return_value=wl): n.rx_claimed("mbox1") - self.assertEqual(events, [("i.got_wordlist", wl), - ("m.got_mailbox", "mbox1"), - ]) + self.assertEqual(events, [ + ("i.got_wordlist", wl), + ("m.got_mailbox", "mbox1"), + ]) events[:] = [] n.release() @@ -671,7 +727,7 @@ class Nameplate(unittest.TestCase): events[:] = [] n.rx_released() - self.assertEqual(events, [("t.nameplate_done",)]) + self.assertEqual(events, [("t.nameplate_done", )]) def test_reconnect_while_claiming(self): # connection bounced while waiting for rx_claimed @@ -700,9 +756,10 @@ class Nameplate(unittest.TestCase): wl = object() with mock.patch("wormhole._nameplate.PGPWordList", return_value=wl): n.rx_claimed("mbox1") - self.assertEqual(events, [("i.got_wordlist", wl), - ("m.got_mailbox", "mbox1"), - ]) + self.assertEqual(events, [ + ("i.got_wordlist", wl), + ("m.got_mailbox", "mbox1"), + ]) events[:] = [] n.lost() @@ -722,9 +779,10 @@ class Nameplate(unittest.TestCase): wl = object() with mock.patch("wormhole._nameplate.PGPWordList", return_value=wl): n.rx_claimed("mbox1") - self.assertEqual(events, [("i.got_wordlist", wl), - ("m.got_mailbox", "mbox1"), - ]) + self.assertEqual(events, [ + ("i.got_wordlist", wl), + ("m.got_mailbox", "mbox1"), + ]) events[:] = [] n.release() @@ -748,9 +806,10 @@ class Nameplate(unittest.TestCase): wl = object() with mock.patch("wormhole._nameplate.PGPWordList", return_value=wl): n.rx_claimed("mbox1") - self.assertEqual(events, [("i.got_wordlist", wl), - ("m.got_mailbox", "mbox1"), - ]) + self.assertEqual(events, [ + ("i.got_wordlist", wl), + ("m.got_mailbox", "mbox1"), + ]) events[:] = [] n.release() @@ -758,7 +817,7 @@ class Nameplate(unittest.TestCase): events[:] = [] n.rx_released() - self.assertEqual(events, [("t.nameplate_done",)]) + self.assertEqual(events, [("t.nameplate_done", )]) events[:] = [] n.lost() @@ -768,20 +827,20 @@ class Nameplate(unittest.TestCase): def test_close_while_idle(self): n, m, i, rc, t, events = self.build() n.close() - self.assertEqual(events, [("t.nameplate_done",)]) + self.assertEqual(events, [("t.nameplate_done", )]) def test_close_while_idle_connected(self): n, m, i, rc, t, events = self.build() n.connected() self.assertEqual(events, []) n.close() - self.assertEqual(events, [("t.nameplate_done",)]) + self.assertEqual(events, [("t.nameplate_done", )]) def test_close_while_unclaimed(self): n, m, i, rc, t, events = self.build() n.set_nameplate("1") - n.close() # before ever being connected - self.assertEqual(events, [("t.nameplate_done",)]) + n.close() # before ever being connected + self.assertEqual(events, [("t.nameplate_done", )]) def test_close_while_claiming(self): n, m, i, rc, t, events = self.build() @@ -796,7 +855,7 @@ class Nameplate(unittest.TestCase): events[:] = [] n.rx_released() - self.assertEqual(events, [("t.nameplate_done",)]) + self.assertEqual(events, [("t.nameplate_done", )]) def test_close_while_claiming_but_disconnected(self): n, m, i, rc, t, events = self.build() @@ -815,7 +874,7 @@ class Nameplate(unittest.TestCase): events[:] = [] n.rx_released() - self.assertEqual(events, [("t.nameplate_done",)]) + self.assertEqual(events, [("t.nameplate_done", )]) def test_close_while_claimed(self): n, m, i, rc, t, events = self.build() @@ -828,9 +887,10 @@ class Nameplate(unittest.TestCase): wl = object() with mock.patch("wormhole._nameplate.PGPWordList", return_value=wl): n.rx_claimed("mbox1") - self.assertEqual(events, [("i.got_wordlist", wl), - ("m.got_mailbox", "mbox1"), - ]) + self.assertEqual(events, [ + ("i.got_wordlist", wl), + ("m.got_mailbox", "mbox1"), + ]) events[:] = [] n.close() @@ -839,7 +899,7 @@ class Nameplate(unittest.TestCase): events[:] = [] n.rx_released() - self.assertEqual(events, [("t.nameplate_done",)]) + self.assertEqual(events, [("t.nameplate_done", )]) def test_close_while_claimed_but_disconnected(self): n, m, i, rc, t, events = self.build() @@ -852,9 +912,10 @@ class Nameplate(unittest.TestCase): wl = object() with mock.patch("wormhole._nameplate.PGPWordList", return_value=wl): n.rx_claimed("mbox1") - self.assertEqual(events, [("i.got_wordlist", wl), - ("m.got_mailbox", "mbox1"), - ]) + self.assertEqual(events, [ + ("i.got_wordlist", wl), + ("m.got_mailbox", "mbox1"), + ]) events[:] = [] n.lost() @@ -865,7 +926,7 @@ class Nameplate(unittest.TestCase): events[:] = [] n.rx_released() - self.assertEqual(events, [("t.nameplate_done",)]) + self.assertEqual(events, [("t.nameplate_done", )]) def test_close_while_releasing(self): n, m, i, rc, t, events = self.build() @@ -878,19 +939,20 @@ class Nameplate(unittest.TestCase): wl = object() with mock.patch("wormhole._nameplate.PGPWordList", return_value=wl): n.rx_claimed("mbox1") - self.assertEqual(events, [("i.got_wordlist", wl), - ("m.got_mailbox", "mbox1"), - ]) + self.assertEqual(events, [ + ("i.got_wordlist", wl), + ("m.got_mailbox", "mbox1"), + ]) events[:] = [] n.release() self.assertEqual(events, [("rc.tx_release", "1")]) events[:] = [] - n.close() # ignored, we're already on our way out the door + n.close() # ignored, we're already on our way out the door self.assertEqual(events, []) n.rx_released() - self.assertEqual(events, [("t.nameplate_done",)]) + self.assertEqual(events, [("t.nameplate_done", )]) def test_close_while_releasing_but_disconnecteda(self): n, m, i, rc, t, events = self.build() @@ -903,9 +965,10 @@ class Nameplate(unittest.TestCase): wl = object() with mock.patch("wormhole._nameplate.PGPWordList", return_value=wl): n.rx_claimed("mbox1") - self.assertEqual(events, [("i.got_wordlist", wl), - ("m.got_mailbox", "mbox1"), - ]) + self.assertEqual(events, [ + ("i.got_wordlist", wl), + ("m.got_mailbox", "mbox1"), + ]) events[:] = [] n.release() @@ -922,7 +985,7 @@ class Nameplate(unittest.TestCase): events[:] = [] n.rx_released() - self.assertEqual(events, [("t.nameplate_done",)]) + self.assertEqual(events, [("t.nameplate_done", )]) def test_close_while_done(self): # connection remains up throughout @@ -937,9 +1000,10 @@ class Nameplate(unittest.TestCase): wl = object() with mock.patch("wormhole._nameplate.PGPWordList", return_value=wl): n.rx_claimed("mbox1") - self.assertEqual(events, [("i.got_wordlist", wl), - ("m.got_mailbox", "mbox1"), - ]) + self.assertEqual(events, [ + ("i.got_wordlist", wl), + ("m.got_mailbox", "mbox1"), + ]) events[:] = [] n.release() @@ -947,10 +1011,10 @@ class Nameplate(unittest.TestCase): events[:] = [] n.rx_released() - self.assertEqual(events, [("t.nameplate_done",)]) + self.assertEqual(events, [("t.nameplate_done", )]) events[:] = [] - n.close() # NOP + n.close() # NOP self.assertEqual(events, []) def test_close_while_done_but_disconnected(self): @@ -966,9 +1030,10 @@ class Nameplate(unittest.TestCase): wl = object() with mock.patch("wormhole._nameplate.PGPWordList", return_value=wl): n.rx_claimed("mbox1") - self.assertEqual(events, [("i.got_wordlist", wl), - ("m.got_mailbox", "mbox1"), - ]) + self.assertEqual(events, [ + ("i.got_wordlist", wl), + ("m.got_mailbox", "mbox1"), + ]) events[:] = [] n.release() @@ -976,20 +1041,21 @@ class Nameplate(unittest.TestCase): events[:] = [] n.rx_released() - self.assertEqual(events, [("t.nameplate_done",)]) + self.assertEqual(events, [("t.nameplate_done", )]) events[:] = [] n.lost() - n.close() # NOP + n.close() # NOP self.assertEqual(events, []) + class Mailbox(unittest.TestCase): def build(self): events = [] m = _mailbox.Mailbox("side1") n = Dummy("n", events, INameplate, "release") - rc = Dummy("rc", events, IRendezvousConnector, - "tx_add", "tx_open", "tx_close") + rc = Dummy("rc", events, IRendezvousConnector, "tx_add", "tx_open", + "tx_close") o = Dummy("o", events, IOrder, "got_message") t = Dummy("t", events, ITerminator, "mailbox_done") m.wire(n, rc, o, t) @@ -998,12 +1064,13 @@ class Mailbox(unittest.TestCase): # TODO: test moods def assert_events(self, events, initial_events, tx_add_events): - self.assertEqual(len(events), len(initial_events)+len(tx_add_events), - events) + self.assertEqual( + len(events), + len(initial_events) + len(tx_add_events), events) self.assertEqual(events[:len(initial_events)], initial_events) self.assertEqual(set(events[len(initial_events):]), tx_add_events) - def test_connect_first(self): # connect before got_mailbox + def test_connect_first(self): # connect before got_mailbox m, n, rc, o, t, events = self.build() m.add_message("phase1", b"msg1") self.assertEqual(events, []) @@ -1029,38 +1096,42 @@ class Mailbox(unittest.TestCase): m.connected() # the other messages are allowed to be sent in any order - self.assert_events(events, [("rc.tx_open", "mbox1")], - { ("rc.tx_add", "phase1", b"msg1"), - ("rc.tx_add", "phase2", b"msg2"), - ("rc.tx_add", "phase3", b"msg3"), - }) + self.assert_events( + events, [("rc.tx_open", "mbox1")], { + ("rc.tx_add", "phase1", b"msg1"), + ("rc.tx_add", "phase2", b"msg2"), + ("rc.tx_add", "phase3", b"msg3"), + }) events[:] = [] - m.rx_message("side1", "phase1", b"msg1") # echo of our message, dequeue + m.rx_message("side1", "phase1", + b"msg1") # echo of our message, dequeue self.assertEqual(events, []) m.lost() m.connected() - self.assert_events(events, [("rc.tx_open", "mbox1")], - {("rc.tx_add", "phase2", b"msg2"), - ("rc.tx_add", "phase3", b"msg3"), - }) + self.assert_events(events, [("rc.tx_open", "mbox1")], { + ("rc.tx_add", "phase2", b"msg2"), + ("rc.tx_add", "phase3", b"msg3"), + }) events[:] = [] # a new message from the peer gets delivered, and the Nameplate is # released since the message proves that our peer opened the Mailbox # and therefore no longer needs the Nameplate - m.rx_message("side2", "phase1", b"msg1them") # new message from peer - self.assertEqual(events, [("n.release",), - ("o.got_message", "side2", "phase1", b"msg1them"), - ]) + m.rx_message("side2", "phase1", b"msg1them") # new message from peer + self.assertEqual(events, [ + ("n.release", ), + ("o.got_message", "side2", "phase1", b"msg1them"), + ]) events[:] = [] # we de-duplicate peer messages, but still re-release the nameplate # since Nameplate is smart enough to ignore that m.rx_message("side2", "phase1", b"msg1them") - self.assertEqual(events, [("n.release",), - ]) + self.assertEqual(events, [ + ("n.release", ), + ]) events[:] = [] m.close("happy") @@ -1081,7 +1152,7 @@ class Mailbox(unittest.TestCase): events[:] = [] m.rx_closed() - self.assertEqual(events, [("t.mailbox_done",)]) + self.assertEqual(events, [("t.mailbox_done", )]) events[:] = [] # while closed, we ignore everything @@ -1092,7 +1163,7 @@ class Mailbox(unittest.TestCase): m.connected() self.assertEqual(events, []) - def test_mailbox_first(self): # got_mailbox before connect + def test_mailbox_first(self): # got_mailbox before connect m, n, rc, o, t, events = self.build() m.add_message("phase1", b"msg1") self.assertEqual(events, []) @@ -1103,27 +1174,27 @@ class Mailbox(unittest.TestCase): m.connected() - self.assert_events(events, [("rc.tx_open", "mbox1")], - { ("rc.tx_add", "phase1", b"msg1"), - ("rc.tx_add", "phase2", b"msg2"), - }) + self.assert_events(events, [("rc.tx_open", "mbox1")], { + ("rc.tx_add", "phase1", b"msg1"), + ("rc.tx_add", "phase2", b"msg2"), + }) def test_close_while_idle(self): m, n, rc, o, t, events = self.build() m.close("happy") - self.assertEqual(events, [("t.mailbox_done",)]) + self.assertEqual(events, [("t.mailbox_done", )]) def test_close_while_idle_but_connected(self): m, n, rc, o, t, events = self.build() m.connected() m.close("happy") - self.assertEqual(events, [("t.mailbox_done",)]) + self.assertEqual(events, [("t.mailbox_done", )]) def test_close_while_mailbox_disconnected(self): m, n, rc, o, t, events = self.build() m.got_mailbox("mbox1") m.close("happy") - self.assertEqual(events, [("t.mailbox_done",)]) + self.assertEqual(events, [("t.mailbox_done", )]) def test_close_while_reconnecting(self): m, n, rc, o, t, events = self.build() @@ -1143,9 +1214,10 @@ class Mailbox(unittest.TestCase): events[:] = [] m.rx_closed() - self.assertEqual(events, [("t.mailbox_done",)]) + self.assertEqual(events, [("t.mailbox_done", )]) events[:] = [] + class Terminator(unittest.TestCase): def build(self): events = [] @@ -1160,13 +1232,15 @@ class Terminator(unittest.TestCase): # there are three events, and we need to test all orderings of them def _do_test(self, ev1, ev2, ev3): t, b, rc, n, m, events = self.build() - input_events = {"mailbox": lambda: t.mailbox_done(), - "nameplate": lambda: t.nameplate_done(), - "close": lambda: t.close("happy"), - } - close_events = [("n.close",), - ("m.close", "happy"), - ] + input_events = { + "mailbox": lambda: t.mailbox_done(), + "nameplate": lambda: t.nameplate_done(), + "close": lambda: t.close("happy"), + } + close_events = [ + ("n.close", ), + ("m.close", "happy"), + ] input_events[ev1]() expected = [] @@ -1186,12 +1260,12 @@ class Terminator(unittest.TestCase): expected = [] if ev3 == "close": expected.extend(close_events) - expected.append(("rc.stop",)) + expected.append(("rc.stop", )) self.assertEqual(events, expected) events[:] = [] t.stopped() - self.assertEqual(events, [("b.closed",)]) + self.assertEqual(events, [("b.closed", )]) def test_terminate(self): self._do_test("mailbox", "nameplate", "close") @@ -1203,18 +1277,19 @@ class Terminator(unittest.TestCase): # TODO: test moods + class MockBoss(_boss.Boss): def __attrs_post_init__(self): - #self._build_workers() + # self._build_workers() self._init_other_state() + class Boss(unittest.TestCase): def build(self): events = [] - wormhole = Dummy("w", events, None, - "got_welcome", - "got_code", "got_key", "got_verifier", "got_versions", - "received", "closed") + wormhole = Dummy("w", events, None, "got_welcome", "got_code", + "got_key", "got_verifier", "got_versions", "received", + "closed") versions = {"app": "version1"} reactor = None journal = ImmediateJournal() @@ -1226,8 +1301,8 @@ class Boss(unittest.TestCase): b._T = Dummy("t", events, ITerminator, "close") b._S = Dummy("s", events, ISend, "send") b._RC = Dummy("rc", events, IRendezvousConnector, "start") - b._C = Dummy("c", events, ICode, - "allocate_code", "input_code", "set_code") + b._C = Dummy("c", events, ICode, "allocate_code", "input_code", + "set_code") return b, events def test_basic(self): @@ -1242,8 +1317,9 @@ class Boss(unittest.TestCase): welcome = {"howdy": "how are ya"} b.rx_welcome(welcome) - self.assertEqual(events, [("w.got_welcome", welcome), - ]) + self.assertEqual(events, [ + ("w.got_welcome", welcome), + ]) events[:] = [] # pretend a peer message was correctly decrypted @@ -1252,11 +1328,12 @@ class Boss(unittest.TestCase): b.got_verifier(b"verifier") b.got_message("version", b"{}") b.got_message("0", b"msg1") - self.assertEqual(events, [("w.got_key", b"key"), - ("w.got_verifier", b"verifier"), - ("w.got_versions", {}), - ("w.received", b"msg1"), - ]) + self.assertEqual(events, [ + ("w.got_key", b"key"), + ("w.got_verifier", b"verifier"), + ("w.got_versions", {}), + ("w.received", b"msg1"), + ]) events[:] = [] b.send(b"msg2") @@ -1330,7 +1407,7 @@ class Boss(unittest.TestCase): self.assertEqual(events, [("c.set_code", "1-code")]) events[:] = [] - b.close() # before even w.got_code + b.close() # before even w.got_code self.assertEqual(events, [("t.close", "lonely")]) events[:] = [] @@ -1383,9 +1460,9 @@ class Boss(unittest.TestCase): self.assertEqual(events, [("w.got_code", "1-code")]) events[:] = [] - b.happy() # phase=version + b.happy() # phase=version - b.scared() # phase=0 + b.scared() # phase=0 self.assertEqual(events, [("t.close", "scary")]) events[:] = [] @@ -1404,7 +1481,7 @@ class Boss(unittest.TestCase): self.assertEqual(events, [("w.got_code", "1-code")]) events[:] = [] - b.happy() # phase=version + b.happy() # phase=version b.got_message("unknown-phase", b"spooky") self.assertEqual(events, []) @@ -1429,7 +1506,7 @@ class Boss(unittest.TestCase): b, events = self.build() b._C.retval = "helper" helper = b.input_code() - self.assertEqual(events, [("c.input_code",)]) + self.assertEqual(events, [("c.input_code", )]) self.assertEqual(helper, "helper") with self.assertRaises(errors.OnlyOneCodeError): b.input_code() @@ -1451,11 +1528,9 @@ class Rendezvous(unittest.TestCase): journal = ImmediateJournal() tor_manager = None client_version = ("python", __version__) - rc = _rendezvous.RendezvousConnector("ws://host:4000/v1", "appid", - "side", reactor, - journal, tor_manager, - timing.DebugTiming(), - client_version) + rc = _rendezvous.RendezvousConnector( + "ws://host:4000/v1", "appid", "side", reactor, journal, + tor_manager, timing.DebugTiming(), client_version) b = Dummy("b", events, IBoss, "error") n = Dummy("n", events, INameplate, "connected", "lost") m = Dummy("m", events, IMailbox, "connected", "lost") @@ -1488,34 +1563,43 @@ class Rendezvous(unittest.TestCase): rc, events = self.build() ws = mock.Mock() + def notrandom(length): return b"\x00" * length + with mock.patch("os.urandom", notrandom): rc.ws_open(ws) - self.assertEqual(events, [("n.connected", ), - ("m.connected", ), - ("l.connected", ), - ("a.connected", ), - ]) + self.assertEqual(events, [ + ("n.connected", ), + ("m.connected", ), + ("l.connected", ), + ("a.connected", ), + ]) events[:] = [] + def sent_messages(ws): for c in ws.mock_calls: self.assertEqual(c[0], "sendMessage", ws.mock_calls) self.assertEqual(c[1][1], False, ws.mock_calls) yield bytes_to_dict(c[1][0]) - self.assertEqual(list(sent_messages(ws)), - [dict(appid="appid", side="side", - client_version=["python", __version__], - id="0000", type="bind"), - ]) + + self.assertEqual( + list(sent_messages(ws)), [ + dict( + appid="appid", + side="side", + client_version=["python", __version__], + id="0000", + type="bind"), + ]) rc.ws_close(True, None, None) - self.assertEqual(events, [("n.lost", ), - ("m.lost", ), - ("l.lost", ), - ("a.lost", ), - ]) - + self.assertEqual(events, [ + ("n.lost", ), + ("m.lost", ), + ("l.lost", ), + ("a.lost", ), + ]) # TODO diff --git a/src/wormhole/test/test_observer.py b/src/wormhole/test/test_observer.py index d974cc5..8d89ddd 100644 --- a/src/wormhole/test/test_observer.py +++ b/src/wormhole/test/test_observer.py @@ -1,9 +1,11 @@ -from twisted.trial import unittest from twisted.internet.task import Clock from twisted.python.failure import Failure +from twisted.trial import unittest + from ..eventual import EventualQueue from ..observer import OneShotObserver, SequenceObserver + class OneShot(unittest.TestCase): def test_fire(self): c = Clock() @@ -119,4 +121,3 @@ class Sequence(unittest.TestCase): d2 = o.when_next_event() eq.flush_sync() self.assertIdentical(self.failureResultOf(d2), f) - diff --git a/src/wormhole/test/test_rlcompleter.py b/src/wormhole/test/test_rlcompleter.py index 41f0abf..dd050d6 100644 --- a/src/wormhole/test/test_rlcompleter.py +++ b/src/wormhole/test/test_rlcompleter.py @@ -1,39 +1,44 @@ -from __future__ import print_function, absolute_import, unicode_literals -import mock +from __future__ import absolute_import, print_function, unicode_literals + from itertools import count -from twisted.trial import unittest + from twisted.internet import reactor from twisted.internet.defer import inlineCallbacks from twisted.internet.threads import deferToThread -from .._rlcompleter import (input_with_completion, - _input_code_with_completion, - CodeInputter, warn_readline) -from ..errors import KeyFormatError, AlreadyInputNameplateError +from twisted.trial import unittest + +import mock + +from .._rlcompleter import (CodeInputter, _input_code_with_completion, + input_with_completion, warn_readline) +from ..errors import AlreadyInputNameplateError, KeyFormatError + APPID = "appid" + class Input(unittest.TestCase): @inlineCallbacks def test_wrapper(self): helper = object() trueish = object() - with mock.patch("wormhole._rlcompleter._input_code_with_completion", - return_value=trueish) as m: - used_completion = yield input_with_completion("prompt:", helper, - reactor) + with mock.patch( + "wormhole._rlcompleter._input_code_with_completion", + return_value=trueish) as m: + used_completion = yield input_with_completion( + "prompt:", helper, reactor) self.assertIs(used_completion, trueish) - self.assertEqual(m.mock_calls, - [mock.call("prompt:", helper, reactor)]) + self.assertEqual(m.mock_calls, [mock.call("prompt:", helper, reactor)]) # note: if this test fails, the warn_readline() message will probably # get written to stderr + class Sync(unittest.TestCase): # exercise _input_code_with_completion, which uses the blocking builtin # "input()" function, hence _input_code_with_completion is usually in a # thread with deferToThread @mock.patch("wormhole._rlcompleter.CodeInputter") - @mock.patch("wormhole._rlcompleter.readline", - __doc__="I am GNU readline") + @mock.patch("wormhole._rlcompleter.readline", __doc__="I am GNU readline") @mock.patch("wormhole._rlcompleter.input", return_value="code") def test_readline(self, input, readline, ci): c = mock.Mock(name="inhibit parenting") @@ -49,17 +54,17 @@ class Sync(unittest.TestCase): self.assertEqual(ci.mock_calls, [mock.call(input_helper, reactor)]) self.assertEqual(c.mock_calls, [mock.call.finish("code")]) self.assertEqual(input.mock_calls, [mock.call(prompt)]) - self.assertEqual(readline.mock_calls, - [mock.call.parse_and_bind("tab: complete"), - mock.call.set_completer(c.completer), - mock.call.set_completer_delims(""), - ]) + self.assertEqual(readline.mock_calls, [ + mock.call.parse_and_bind("tab: complete"), + mock.call.set_completer(c.completer), + mock.call.set_completer_delims(""), + ]) @mock.patch("wormhole._rlcompleter.CodeInputter") @mock.patch("wormhole._rlcompleter.readline") @mock.patch("wormhole._rlcompleter.input", return_value="code") def test_readline_no_docstring(self, input, readline, ci): - del readline.__doc__ # when in doubt, it assumes GNU readline + del readline.__doc__ # when in doubt, it assumes GNU readline c = mock.Mock(name="inhibit parenting") c.completer = object() trueish = object() @@ -73,15 +78,14 @@ class Sync(unittest.TestCase): self.assertEqual(ci.mock_calls, [mock.call(input_helper, reactor)]) self.assertEqual(c.mock_calls, [mock.call.finish("code")]) self.assertEqual(input.mock_calls, [mock.call(prompt)]) - self.assertEqual(readline.mock_calls, - [mock.call.parse_and_bind("tab: complete"), - mock.call.set_completer(c.completer), - mock.call.set_completer_delims(""), - ]) + self.assertEqual(readline.mock_calls, [ + mock.call.parse_and_bind("tab: complete"), + mock.call.set_completer(c.completer), + mock.call.set_completer_delims(""), + ]) @mock.patch("wormhole._rlcompleter.CodeInputter") - @mock.patch("wormhole._rlcompleter.readline", - __doc__="I am libedit") + @mock.patch("wormhole._rlcompleter.readline", __doc__="I am libedit") @mock.patch("wormhole._rlcompleter.input", return_value="code") def test_libedit(self, input, readline, ci): c = mock.Mock(name="inhibit parenting") @@ -97,11 +101,11 @@ class Sync(unittest.TestCase): self.assertEqual(ci.mock_calls, [mock.call(input_helper, reactor)]) self.assertEqual(c.mock_calls, [mock.call.finish("code")]) self.assertEqual(input.mock_calls, [mock.call(prompt)]) - self.assertEqual(readline.mock_calls, - [mock.call.parse_and_bind("bind ^I rl_complete"), - mock.call.set_completer(c.completer), - mock.call.set_completer_delims(""), - ]) + self.assertEqual(readline.mock_calls, [ + mock.call.parse_and_bind("bind ^I rl_complete"), + mock.call.set_completer(c.completer), + mock.call.set_completer_delims(""), + ]) @mock.patch("wormhole._rlcompleter.CodeInputter") @mock.patch("wormhole._rlcompleter.readline", None) @@ -139,6 +143,7 @@ class Sync(unittest.TestCase): self.assertEqual(c.mock_calls, [mock.call.finish(u"code")]) self.assertEqual(input.mock_calls, [mock.call(prompt)]) + def get_completions(c, prefix): completions = [] for state in count(0): @@ -147,9 +152,11 @@ def get_completions(c, prefix): return completions completions.append(text) + def fake_blockingCallFromThread(f, *a, **kw): return f(*a, **kw) + class Completion(unittest.TestCase): def test_simple(self): # no actual completion @@ -158,12 +165,14 @@ class Completion(unittest.TestCase): c.bcft = fake_blockingCallFromThread c.finish("1-code-ghost") self.assertFalse(c.used_completion) - self.assertEqual(helper.mock_calls, - [mock.call.choose_nameplate("1"), - mock.call.choose_words("code-ghost")]) + self.assertEqual(helper.mock_calls, [ + mock.call.choose_nameplate("1"), + mock.call.choose_words("code-ghost") + ]) - @mock.patch("wormhole._rlcompleter.readline", - get_completion_type=mock.Mock(return_value=0)) + @mock.patch( + "wormhole._rlcompleter.readline", + get_completion_type=mock.Mock(return_value=0)) def test_call(self, readline): # check that it calls _commit_and_build_completions correctly helper = mock.Mock() @@ -188,14 +197,14 @@ class Completion(unittest.TestCase): # now we have three "a" words: "and", "ark", "aaah!zombies!!" cabc.reset_mock() cabc.configure_mock(return_value=["aargh", "ark", "aaah!zombies!!"]) - self.assertEqual(get_completions(c, "12-a"), - ["aargh", "ark", "aaah!zombies!!"]) + self.assertEqual( + get_completions(c, "12-a"), ["aargh", "ark", "aaah!zombies!!"]) self.assertEqual(cabc.mock_calls, [mock.call("12-a")]) cabc.reset_mock() cabc.configure_mock(return_value=["aargh", "aaah!zombies!!"]) - self.assertEqual(get_completions(c, "12-aa"), - ["aargh", "aaah!zombies!!"]) + self.assertEqual( + get_completions(c, "12-aa"), ["aargh", "aaah!zombies!!"]) self.assertEqual(cabc.mock_calls, [mock.call("12-aa")]) cabc.reset_mock() @@ -223,16 +232,17 @@ class Completion(unittest.TestCase): def test_build_completions(self): rn = mock.Mock() # InputHelper.get_nameplate_completions returns just the suffixes - gnc = mock.Mock() # get_nameplate_completions - cn = mock.Mock() # choose_nameplate - gwc = mock.Mock() # get_word_completions - cw = mock.Mock() # choose_words - helper = mock.Mock(refresh_nameplates=rn, - get_nameplate_completions=gnc, - choose_nameplate=cn, - get_word_completions=gwc, - choose_words=cw, - ) + gnc = mock.Mock() # get_nameplate_completions + cn = mock.Mock() # choose_nameplate + gwc = mock.Mock() # get_word_completions + cw = mock.Mock() # choose_words + helper = mock.Mock( + refresh_nameplates=rn, + get_nameplate_completions=gnc, + choose_nameplate=cn, + get_word_completions=gwc, + choose_words=cw, + ) # this needs a real reactor, for blockingCallFromThread c = CodeInputter(helper, reactor) cabc = c._commit_and_build_completions @@ -327,17 +337,18 @@ class Completion(unittest.TestCase): gwc.configure_mock(return_value=["code", "court"]) c = CodeInputter(helper, reactor) cabc = c._commit_and_build_completions - matches = yield deferToThread(cabc, "1-co") # this commits us to 1- - self.assertEqual(helper.mock_calls, - [mock.call.choose_nameplate("1"), - mock.call.when_wordlist_is_available(), - mock.call.get_word_completions("co")]) + matches = yield deferToThread(cabc, "1-co") # this commits us to 1- + self.assertEqual(helper.mock_calls, [ + mock.call.choose_nameplate("1"), + mock.call.when_wordlist_is_available(), + mock.call.get_word_completions("co") + ]) self.assertEqual(matches, ["1-code", "1-court"]) helper.reset_mock() with self.assertRaises(AlreadyInputNameplateError) as e: yield deferToThread(cabc, "2-co") - self.assertEqual(str(e.exception), - "nameplate (1-) already entered, cannot go back") + self.assertEqual( + str(e.exception), "nameplate (1-) already entered, cannot go back") self.assertEqual(helper.mock_calls, []) @inlineCallbacks @@ -347,17 +358,18 @@ class Completion(unittest.TestCase): gwc.configure_mock(return_value=["code", "court"]) c = CodeInputter(helper, reactor) cabc = c._commit_and_build_completions - matches = yield deferToThread(cabc, "1-co") # this commits us to 1- - self.assertEqual(helper.mock_calls, - [mock.call.choose_nameplate("1"), - mock.call.when_wordlist_is_available(), - mock.call.get_word_completions("co")]) + matches = yield deferToThread(cabc, "1-co") # this commits us to 1- + self.assertEqual(helper.mock_calls, [ + mock.call.choose_nameplate("1"), + mock.call.when_wordlist_is_available(), + mock.call.get_word_completions("co") + ]) self.assertEqual(matches, ["1-code", "1-court"]) helper.reset_mock() with self.assertRaises(AlreadyInputNameplateError) as e: yield deferToThread(c.finish, "2-code") - self.assertEqual(str(e.exception), - "nameplate (1-) already entered, cannot go back") + self.assertEqual( + str(e.exception), "nameplate (1-) already entered, cannot go back") self.assertEqual(helper.mock_calls, []) @mock.patch("wormhole._rlcompleter.stderr") @@ -366,6 +378,7 @@ class Completion(unittest.TestCase): # right time, since it involves a reactor and a "system event # trigger", but let's at least make sure it's invocable warn_readline() - expected ="\nCommand interrupted: please press Return to quit" - self.assertEqual(stderr.mock_calls, [mock.call.write(expected), - mock.call.write("\n")]) + expected = "\nCommand interrupted: please press Return to quit" + self.assertEqual(stderr.mock_calls, + [mock.call.write(expected), + mock.call.write("\n")]) diff --git a/src/wormhole/test/test_ssh.py b/src/wormhole/test/test_ssh.py index 775c22a..9f86eed 100644 --- a/src/wormhole/test/test_ssh.py +++ b/src/wormhole/test/test_ssh.py @@ -1,10 +1,15 @@ -import os, io -import mock +import io +import os + from twisted.trial import unittest + +import mock + from ..cli import cmd_ssh OTHERS = ["config", "config~", "known_hosts", "known_hosts~"] + class FindPubkey(unittest.TestCase): def test_find_one(self): files = OTHERS + ["id_rsa.pub", "id_rsa"] @@ -12,8 +17,8 @@ class FindPubkey(unittest.TestCase): pubkey_file = io.StringIO(pubkey_data) with mock.patch("wormhole.cli.cmd_ssh.exists", return_value=True): with mock.patch("os.listdir", return_value=files) as ld: - with mock.patch("wormhole.cli.cmd_ssh.open", - return_value=pubkey_file): + with mock.patch( + "wormhole.cli.cmd_ssh.open", return_value=pubkey_file): res = cmd_ssh.find_public_key() self.assertEqual(ld.mock_calls, [mock.call(os.path.expanduser("~/.ssh/"))]) @@ -24,7 +29,7 @@ class FindPubkey(unittest.TestCase): self.assertEqual(pubkey, pubkey_data) def test_find_none(self): - files = OTHERS # no pubkey + files = OTHERS # no pubkey with mock.patch("wormhole.cli.cmd_ssh.exists", return_value=True): with mock.patch("os.listdir", return_value=files): e = self.assertRaises(cmd_ssh.PubkeyError, @@ -34,12 +39,12 @@ class FindPubkey(unittest.TestCase): def test_bad_hint(self): with mock.patch("wormhole.cli.cmd_ssh.exists", return_value=False): - e = self.assertRaises(cmd_ssh.PubkeyError, - cmd_ssh.find_public_key, - hint="bogus/path") + e = self.assertRaises( + cmd_ssh.PubkeyError, + cmd_ssh.find_public_key, + hint="bogus/path") self.assertEqual(str(e), "Can't find 'bogus/path'") - def test_find_multiple(self): files = OTHERS + ["id_rsa.pub", "id_rsa", "id_dsa.pub", "id_dsa"] pubkey_data = u"ssh-rsa AAAAkeystuff email@host\n" @@ -47,10 +52,11 @@ class FindPubkey(unittest.TestCase): with mock.patch("wormhole.cli.cmd_ssh.exists", return_value=True): with mock.patch("os.listdir", return_value=files): responses = iter(["frog", "NaN", "-1", "0"]) - with mock.patch("click.prompt", - side_effect=lambda p: next(responses)): - with mock.patch("wormhole.cli.cmd_ssh.open", - return_value=pubkey_file): + with mock.patch( + "click.prompt", side_effect=lambda p: next(responses)): + with mock.patch( + "wormhole.cli.cmd_ssh.open", + return_value=pubkey_file): res = cmd_ssh.find_public_key() self.assertEqual(len(res), 3, res) kind, keyid, pubkey = res diff --git a/src/wormhole/test/test_tor_manager.py b/src/wormhole/test/test_tor_manager.py index b758e71..386698c 100644 --- a/src/wormhole/test/test_tor_manager.py +++ b/src/wormhole/test/test_tor_manager.py @@ -1,42 +1,51 @@ from __future__ import print_function, unicode_literals -import mock, io -from twisted.trial import unittest + +import io + from twisted.internet import defer from twisted.internet.error import ConnectError +from twisted.trial import unittest + +import mock -from ..tor_manager import get_tor, SocksOnlyTor -from ..errors import NoTorError from .._interfaces import ITorManager +from ..errors import NoTorError +from ..tor_manager import SocksOnlyTor, get_tor + class X(): pass + class Tor(unittest.TestCase): def test_no_txtorcon(self): with mock.patch("wormhole.tor_manager.txtorcon", None): self.failureResultOf(get_tor(None), NoTorError) def test_bad_args(self): - f = self.failureResultOf(get_tor(None, launch_tor="not boolean"), - TypeError) + f = self.failureResultOf( + get_tor(None, launch_tor="not boolean"), TypeError) self.assertEqual(str(f.value), "launch_tor= must be boolean") - f = self.failureResultOf(get_tor(None, tor_control_port=1234), - TypeError) + f = self.failureResultOf( + get_tor(None, tor_control_port=1234), TypeError) self.assertEqual(str(f.value), "tor_control_port= must be str or None") - f = self.failureResultOf(get_tor(None, launch_tor=True, - tor_control_port="tcp:127.0.0.1:1234"), - ValueError) - self.assertEqual(str(f.value), - "cannot combine --launch-tor and --tor-control-port=") + f = self.failureResultOf( + get_tor( + None, launch_tor=True, tor_control_port="tcp:127.0.0.1:1234"), + ValueError) + self.assertEqual( + str(f.value), + "cannot combine --launch-tor and --tor-control-port=") def test_launch(self): reactor = object() - my_tor = X() # object() didn't like providedBy() + my_tor = X() # object() didn't like providedBy() launch_d = defer.Deferred() stderr = io.StringIO() - with mock.patch("wormhole.tor_manager.txtorcon.launch", - side_effect=launch_d) as launch: + with mock.patch( + "wormhole.tor_manager.txtorcon.launch", + side_effect=launch_d) as launch: d = get_tor(reactor, launch_tor=True, stderr=stderr) self.assertNoResult(d) self.assertEqual(launch.mock_calls, [mock.call(reactor)]) @@ -44,18 +53,21 @@ class Tor(unittest.TestCase): tor = self.successResultOf(d) self.assertIs(tor, my_tor) self.assert_(ITorManager.providedBy(tor)) - self.assertEqual(stderr.getvalue(), - " launching a new Tor process, this may take a while..\n") + self.assertEqual( + stderr.getvalue(), + " launching a new Tor process, this may take a while..\n") def test_connect(self): reactor = object() - my_tor = X() # object() didn't like providedBy() + my_tor = X() # object() didn't like providedBy() connect_d = defer.Deferred() stderr = io.StringIO() - with mock.patch("wormhole.tor_manager.txtorcon.connect", - side_effect=connect_d) as connect: - with mock.patch("wormhole.tor_manager.clientFromString", - side_effect=["foo"]) as sfs: + with mock.patch( + "wormhole.tor_manager.txtorcon.connect", + side_effect=connect_d) as connect: + with mock.patch( + "wormhole.tor_manager.clientFromString", + side_effect=["foo"]) as sfs: d = get_tor(reactor, stderr=stderr) self.assertEqual(sfs.mock_calls, []) self.assertNoResult(d) @@ -71,10 +83,12 @@ class Tor(unittest.TestCase): reactor = object() connect_d = defer.Deferred() stderr = io.StringIO() - with mock.patch("wormhole.tor_manager.txtorcon.connect", - side_effect=connect_d) as connect: - with mock.patch("wormhole.tor_manager.clientFromString", - side_effect=["foo"]) as sfs: + with mock.patch( + "wormhole.tor_manager.txtorcon.connect", + side_effect=connect_d) as connect: + with mock.patch( + "wormhole.tor_manager.clientFromString", + side_effect=["foo"]) as sfs: d = get_tor(reactor, stderr=stderr) self.assertEqual(sfs.mock_calls, []) self.assertNoResult(d) @@ -85,20 +99,23 @@ class Tor(unittest.TestCase): self.assertIsInstance(tor, SocksOnlyTor) self.assert_(ITorManager.providedBy(tor)) self.assertEqual(tor._reactor, reactor) - self.assertEqual(stderr.getvalue(), - " unable to find default Tor control port, using SOCKS\n") + self.assertEqual( + stderr.getvalue(), + " unable to find default Tor control port, using SOCKS\n") def test_connect_custom_control_port(self): reactor = object() - my_tor = X() # object() didn't like providedBy() + my_tor = X() # object() didn't like providedBy() tcp = "PORT" ep = object() connect_d = defer.Deferred() stderr = io.StringIO() - with mock.patch("wormhole.tor_manager.txtorcon.connect", - side_effect=connect_d) as connect: - with mock.patch("wormhole.tor_manager.clientFromString", - side_effect=[ep]) as sfs: + with mock.patch( + "wormhole.tor_manager.txtorcon.connect", + side_effect=connect_d) as connect: + with mock.patch( + "wormhole.tor_manager.clientFromString", + side_effect=[ep]) as sfs: d = get_tor(reactor, tor_control_port=tcp, stderr=stderr) self.assertEqual(sfs.mock_calls, [mock.call(reactor, tcp)]) self.assertNoResult(d) @@ -116,10 +133,12 @@ class Tor(unittest.TestCase): ep = object() connect_d = defer.Deferred() stderr = io.StringIO() - with mock.patch("wormhole.tor_manager.txtorcon.connect", - side_effect=connect_d) as connect: - with mock.patch("wormhole.tor_manager.clientFromString", - side_effect=[ep]) as sfs: + with mock.patch( + "wormhole.tor_manager.txtorcon.connect", + side_effect=connect_d) as connect: + with mock.patch( + "wormhole.tor_manager.clientFromString", + side_effect=[ep]) as sfs: d = get_tor(reactor, tor_control_port=tcp, stderr=stderr) self.assertEqual(sfs.mock_calls, [mock.call(reactor, tcp)]) self.assertNoResult(d) @@ -129,18 +148,22 @@ class Tor(unittest.TestCase): self.failureResultOf(d, ConnectError) self.assertEqual(stderr.getvalue(), "") + class SocksOnly(unittest.TestCase): def test_tor(self): reactor = object() sot = SocksOnlyTor(reactor) fake_ep = object() - with mock.patch("wormhole.tor_manager.txtorcon.TorClientEndpoint", - return_value=fake_ep) as tce: + with mock.patch( + "wormhole.tor_manager.txtorcon.TorClientEndpoint", + return_value=fake_ep) as tce: ep = sot.stream_via("host", "port") self.assertIs(ep, fake_ep) - self.assertEqual(tce.mock_calls, [mock.call("host", "port", - socks_endpoint=None, - tls=False, - reactor=reactor)]) - - + self.assertEqual(tce.mock_calls, [ + mock.call( + "host", + "port", + socks_endpoint=None, + tls=False, + reactor=reactor) + ]) diff --git a/src/wormhole/test/test_transit.py b/src/wormhole/test/test_transit.py index fb03f38..63216ae 100644 --- a/src/wormhole/test/test_transit.py +++ b/src/wormhole/test/test_transit.py @@ -1,27 +1,33 @@ from __future__ import print_function, unicode_literals -import six -import io + import gc -import mock +import io from binascii import hexlify, unhexlify from collections import namedtuple -from twisted.trial import unittest -from twisted.internet import defer, task, endpoints, protocol, address, error + +import six +from nacl.exceptions import CryptoError +from nacl.secret import SecretBox +from twisted.internet import address, defer, endpoints, error, protocol, task from twisted.internet.defer import gatherResults, inlineCallbacks from twisted.python import log from twisted.test import proto_helpers +from twisted.trial import unittest + +import mock from wormhole_transit_relay import transit_server -from ..errors import InternalError + from .. import transit +from ..errors import InternalError from .common import ServerBase -from nacl.secret import SecretBox -from nacl.exceptions import CryptoError + class Highlander(unittest.TestCase): def test_one_winner(self): cancelled = set() - contenders = [defer.Deferred(lambda d, i=i: cancelled.add(i)) - for i in range(5)] + contenders = [ + defer.Deferred(lambda d, i=i: cancelled.add(i)) for i in range(5) + ] d = transit.there_can_be_only_one(contenders) self.assertNoResult(d) contenders[0].errback(ValueError()) @@ -30,12 +36,13 @@ class Highlander(unittest.TestCase): self.assertNoResult(d) contenders[2].callback("yay") self.assertEqual(self.successResultOf(d), "yay") - self.assertEqual(cancelled, set([3,4])) + self.assertEqual(cancelled, set([3, 4])) def test_there_might_also_be_none(self): cancelled = set() - contenders = [defer.Deferred(lambda d, i=i: cancelled.add(i)) - for i in range(4)] + contenders = [ + defer.Deferred(lambda d, i=i: cancelled.add(i)) for i in range(4) + ] d = transit.there_can_be_only_one(contenders) self.assertNoResult(d) contenders[0].errback(ValueError()) @@ -45,13 +52,14 @@ class Highlander(unittest.TestCase): contenders[2].errback(TypeError()) self.assertNoResult(d) contenders[3].errback(NameError()) - self.failureResultOf(d, ValueError) # first failure is recorded + self.failureResultOf(d, ValueError) # first failure is recorded self.assertEqual(cancelled, set()) def test_cancel_early(self): cancelled = set() - contenders = [defer.Deferred(lambda d, i=i: cancelled.add(i)) - for i in range(4)] + contenders = [ + defer.Deferred(lambda d, i=i: cancelled.add(i)) for i in range(4) + ] d = transit.there_can_be_only_one(contenders) self.assertNoResult(d) self.assertEqual(cancelled, set()) @@ -61,15 +69,17 @@ class Highlander(unittest.TestCase): def test_cancel_after_one_failure(self): cancelled = set() - contenders = [defer.Deferred(lambda d, i=i: cancelled.add(i)) - for i in range(4)] + contenders = [ + defer.Deferred(lambda d, i=i: cancelled.add(i)) for i in range(4) + ] d = transit.there_can_be_only_one(contenders) self.assertNoResult(d) self.assertEqual(cancelled, set()) contenders[0].errback(ValueError()) d.cancel() self.failureResultOf(d, ValueError) - self.assertEqual(cancelled, set([1,2,3])) + self.assertEqual(cancelled, set([1, 2, 3])) + class Forever(unittest.TestCase): def _forever_setup(self): @@ -116,6 +126,7 @@ class Forever(unittest.TestCase): self.failureResultOf(d, defer.CancelledError) self.assertNot(clock.getDelayedCalls()) + class Misc(unittest.TestCase): def test_allocate_port(self): portno = transit.allocate_tcp_port() @@ -128,29 +139,36 @@ class Misc(unittest.TestCase): portno = transit.allocate_tcp_port() self.assertIsInstance(portno, int) + UnknownHint = namedtuple("UnknownHint", ["stuff"]) + class Hints(unittest.TestCase): def test_endpoint_from_hint_obj(self): c = transit.Common("") efho = c._endpoint_from_hint_obj - self.assertIsInstance(efho(transit.DirectTCPV1Hint("host", 1234, 0.0)), - endpoints.HostnameEndpoint) + self.assertIsInstance( + efho(transit.DirectTCPV1Hint("host", 1234, 0.0)), + endpoints.HostnameEndpoint) self.assertEqual(efho("unknown:stuff:yowza:pivlor"), None) # c._tor is currently None self.assertEqual(efho(transit.TorTCPV1Hint("host", "port", 0)), None) c._tor = mock.Mock() + def tor_ep(hostname, port): if hostname == "non-public": return None return ("tor_ep", hostname, port) + c._tor.stream_via = mock.Mock(side_effect=tor_ep) - self.assertEqual(efho(transit.DirectTCPV1Hint("host", 1234, 0.0)), - ("tor_ep", "host", 1234)) - self.assertEqual(efho(transit.TorTCPV1Hint("host2.onion", 1234, 0.0)), - ("tor_ep", "host2.onion", 1234)) - self.assertEqual(efho(transit.DirectTCPV1Hint("non-public", 1234, 0.0)), - None) + self.assertEqual( + efho(transit.DirectTCPV1Hint("host", 1234, 0.0)), + ("tor_ep", "host", 1234)) + self.assertEqual( + efho(transit.TorTCPV1Hint("host2.onion", 1234, 0.0)), + ("tor_ep", "host2.onion", 1234)) + self.assertEqual( + efho(transit.DirectTCPV1Hint("non-public", 1234, 0.0)), None) self.assertEqual(efho(UnknownHint("foo")), None) def test_comparable(self): @@ -170,94 +188,120 @@ class Hints(unittest.TestCase): self.assertEqual(p({"type": "unknown"}), None) h = p({"type": "direct-tcp-v1", "hostname": "foo", "port": 1234}) self.assertEqual(h, transit.DirectTCPV1Hint("foo", 1234, 0.0)) - h = p({"type": "direct-tcp-v1", "hostname": "foo", "port": 1234, - "priority": 2.5}) + h = p({ + "type": "direct-tcp-v1", + "hostname": "foo", + "port": 1234, + "priority": 2.5 + }) self.assertEqual(h, transit.DirectTCPV1Hint("foo", 1234, 2.5)) h = p({"type": "tor-tcp-v1", "hostname": "foo", "port": 1234}) self.assertEqual(h, transit.TorTCPV1Hint("foo", 1234, 0.0)) - h = p({"type": "tor-tcp-v1", "hostname": "foo", "port": 1234, - "priority": 2.5}) + h = p({ + "type": "tor-tcp-v1", + "hostname": "foo", + "port": 1234, + "priority": 2.5 + }) self.assertEqual(h, transit.TorTCPV1Hint("foo", 1234, 2.5)) - self.assertEqual(p({"type": "direct-tcp-v1"}), - None) # missing hostname - self.assertEqual(p({"type": "direct-tcp-v1", "hostname": 12}), - None) # invalid hostname - self.assertEqual(p({"type": "direct-tcp-v1", "hostname": "foo"}), - None) # missing port - self.assertEqual(p({"type": "direct-tcp-v1", "hostname": "foo", - "port": "not a number"}), - None) # invalid port + self.assertEqual(p({ + "type": "direct-tcp-v1" + }), None) # missing hostname + self.assertEqual(p({ + "type": "direct-tcp-v1", + "hostname": 12 + }), None) # invalid hostname + self.assertEqual( + p({ + "type": "direct-tcp-v1", + "hostname": "foo" + }), None) # missing port + self.assertEqual( + p({ + "type": "direct-tcp-v1", + "hostname": "foo", + "port": "not a number" + }), None) # invalid port def test_parse_hint_argv(self): def p(hint): stderr = io.StringIO() value = transit.parse_hint_argv(hint, stderr=stderr) return value, stderr.getvalue() - h,stderr = p("tcp:host:1234") + + h, stderr = p("tcp:host:1234") self.assertEqual(h, transit.DirectTCPV1Hint("host", 1234, 0.0)) self.assertEqual(stderr, "") - h,stderr = p("tcp:host:1234:priority=2.6") + h, stderr = p("tcp:host:1234:priority=2.6") self.assertEqual(h, transit.DirectTCPV1Hint("host", 1234, 2.6)) self.assertEqual(stderr, "") - h,stderr = p("tcp:host:1234:unknown=stuff") + h, stderr = p("tcp:host:1234:unknown=stuff") self.assertEqual(h, transit.DirectTCPV1Hint("host", 1234, 0.0)) self.assertEqual(stderr, "") - h,stderr = p("$!@#^") + h, stderr = p("$!@#^") self.assertEqual(h, None) self.assertEqual(stderr, "unparseable hint '$!@#^'\n") - h,stderr = p("unknown:stuff") + h, stderr = p("unknown:stuff") self.assertEqual(h, None) self.assertEqual(stderr, "unknown hint type 'unknown' in 'unknown:stuff'\n") - h,stderr = p("tcp:just-a-hostname") + h, stderr = p("tcp:just-a-hostname") self.assertEqual(h, None) - self.assertEqual(stderr, - "unparseable TCP hint (need more colons) 'tcp:just-a-hostname'\n") + self.assertEqual( + stderr, + "unparseable TCP hint (need more colons) 'tcp:just-a-hostname'\n") - h,stderr = p("tcp:host:number") + h, stderr = p("tcp:host:number") self.assertEqual(h, None) self.assertEqual(stderr, "non-numeric port in TCP hint 'tcp:host:number'\n") - h,stderr = p("tcp:host:1234:priority=bad") + h, stderr = p("tcp:host:1234:priority=bad") self.assertEqual(h, None) - self.assertEqual(stderr, - "non-float priority= in TCP hint 'tcp:host:1234:priority=bad'\n") + self.assertEqual( + stderr, + "non-float priority= in TCP hint 'tcp:host:1234:priority=bad'\n") def test_describe_hint_obj(self): d = transit.describe_hint_obj - self.assertEqual(d(transit.DirectTCPV1Hint("host", 1234, 0.0)), - "tcp:host:1234") - self.assertEqual(d(transit.TorTCPV1Hint("host", 1234, 0.0)), - "tor:host:1234") + self.assertEqual( + d(transit.DirectTCPV1Hint("host", 1234, 0.0)), "tcp:host:1234") + self.assertEqual( + d(transit.TorTCPV1Hint("host", 1234, 0.0)), "tor:host:1234") self.assertEqual(d(UnknownHint("stuff")), str(UnknownHint("stuff"))) + # ipaddrs.py currently uses native strings: bytes on py2, unicode on # py3 if six.PY2: LOOPADDR = b"127.0.0.1" OTHERADDR = b"1.2.3.4" else: - LOOPADDR = "127.0.0.1" # unicode_literals + LOOPADDR = "127.0.0.1" # unicode_literals OTHERADDR = "1.2.3.4" + class Basic(unittest.TestCase): @inlineCallbacks def test_relay_hints(self): URL = "tcp:host:1234" c = transit.Common(URL, no_listen=True) hints = yield c.get_connection_hints() - self.assertEqual(hints, [{"type": "relay-v1", - "hints": [{"type": "direct-tcp-v1", - "hostname": "host", - "port": 1234, - "priority": 0.0}], - }]) + self.assertEqual(hints, [{ + "type": + "relay-v1", + "hints": [{ + "type": "direct-tcp-v1", + "hostname": "host", + "port": 1234, + "priority": 0.0 + }], + }]) self.assertRaises(InternalError, transit.Common, 123) @inlineCallbacks @@ -269,8 +313,12 @@ class Basic(unittest.TestCase): def test_ignore_bad_hints(self): c = transit.Common("") c.add_connection_hints([{"type": "unknown"}]) - c.add_connection_hints([{"type": "relay-v1", - "hints": [{"type": "unknown"}]}]) + c.add_connection_hints([{ + "type": "relay-v1", + "hints": [{ + "type": "unknown" + }] + }]) self.assertEqual(c._their_direct_hints, []) self.assertEqual(c._our_relay_hints, set()) @@ -290,8 +338,9 @@ class Basic(unittest.TestCase): def test_ignore_localhost_hint(self): # this actually starts the listener c = transit.TransitSender("") - with mock.patch("wormhole.ipaddrs.find_addresses", - return_value=[LOOPADDR, OTHERADDR]): + with mock.patch( + "wormhole.ipaddrs.find_addresses", + return_value=[LOOPADDR, OTHERADDR]): hints = self.successResultOf(c.get_connection_hints()) c._stop_listening() # If there are non-localhost hints, then localhost hints should be @@ -302,8 +351,8 @@ class Basic(unittest.TestCase): def test_keep_only_localhost_hint(self): # this actually starts the listener c = transit.TransitSender("") - with mock.patch("wormhole.ipaddrs.find_addresses", - return_value=[LOOPADDR]): + with mock.patch( + "wormhole.ipaddrs.find_addresses", return_value=[LOOPADDR]): hints = self.successResultOf(c.get_connection_hints()) c._stop_listening() # If the only hint is localhost, it should stay. @@ -313,9 +362,14 @@ class Basic(unittest.TestCase): def test_abilities(self): c = transit.Common(None, no_listen=True) abilities = c.get_connection_abilities() - self.assertEqual(abilities, [{"type": "direct-tcp-v1"}, - {"type": "relay-v1"}, - ]) + self.assertEqual(abilities, [ + { + "type": "direct-tcp-v1" + }, + { + "type": "relay-v1" + }, + ]) def test_transit_key_wait(self): KEY = b"123" @@ -339,19 +393,31 @@ class Basic(unittest.TestCase): r = transit.TransitReceiver("") r.set_transit_key(KEY) - self.assertEqual(s._send_this(), b"transit sender 559bdeae4b49fa6a23378d2b68f4c7e69378615d4af049c371c6a26e82391089 ready\n\n") + self.assertEqual(s._send_this(), ( + b"transit sender " + b"559bdeae4b49fa6a23378d2b68f4c7e69378615d4af049c371c6a26e82391089" + b" ready\n\n")) self.assertEqual(s._send_this(), r._expect_this()) - self.assertEqual(r._send_this(), b"transit receiver ed447528194bac4c00d0c854b12a97ce51413d89aa74d6304475f516fdc23a1b ready\n\n") + self.assertEqual(r._send_this(), ( + b"transit receiver " + b"ed447528194bac4c00d0c854b12a97ce51413d89aa74d6304475f516fdc23a1b" + b" ready\n\n")) self.assertEqual(r._send_this(), s._expect_this()) - self.assertEqual(hexlify(s._sender_record_key()), b"5a2fba3a9e524ab2e2823ff53b05f946896f6e4ce4e282ffd8e3ac0e5e9e0cda") - self.assertEqual(hexlify(s._sender_record_key()), - hexlify(r._receiver_record_key())) + self.assertEqual( + hexlify(s._sender_record_key()), + b"5a2fba3a9e524ab2e2823ff53b05f946896f6e4ce4e282ffd8e3ac0e5e9e0cda" + ) + self.assertEqual( + hexlify(s._sender_record_key()), hexlify(r._receiver_record_key())) - self.assertEqual(hexlify(r._sender_record_key()), b"eedb143117249f45b39da324decf6bd9aae33b7ccd58487436de611a3c6b871d") - self.assertEqual(hexlify(r._sender_record_key()), - hexlify(s._receiver_record_key())) + self.assertEqual( + hexlify(r._sender_record_key()), + b"eedb143117249f45b39da324decf6bd9aae33b7ccd58487436de611a3c6b871d" + ) + self.assertEqual( + hexlify(r._sender_record_key()), hexlify(s._receiver_record_key())) def test_connection_ready(self): s = transit.TransitSender("") @@ -407,7 +473,7 @@ class DummyProtocol(protocol.Protocol): def dataReceived(self, data): self.buf += data - #print("oDR", self._count, len(self.buf)) + # print("oDR", self._count, len(self.buf)) if self._count is not None and len(self.buf) >= self._count: got = self.buf[:self._count] self.buf = self.buf[self._count:] @@ -422,19 +488,24 @@ class DummyProtocol(protocol.Protocol): if self._d2: self._d2.callback(None) + class FakeTransport: signalConnectionLost = True + def __init__(self, p, peeraddr): self.protocol = p self._peeraddr = peeraddr self._buf = b"" self._connected = True + def write(self, data): self._buf += data + def loseConnection(self): self._connected = False if self.signalConnectionLost: self.protocol.connectionLost() + def getPeer(self): return self._peeraddr @@ -443,17 +514,21 @@ class FakeTransport: self._buf = b"" return b + class RandomError(Exception): pass + class MockConnection: def __init__(self, owner, relay_handshake, start, description): self.owner = owner self.relay_handshake = relay_handshake self.start = start self._description = description + def cancel(d): self._cancelled = True + self._d = defer.Deferred(cancel) self._start_negotiation_called = False self._cancelled = False @@ -462,6 +537,7 @@ class MockConnection: self._start_negotiation_called = True return self._d + class InboundConnectionFactory(unittest.TestCase): def test_describe(self): f = transit.InboundConnectionFactory(None) @@ -472,8 +548,8 @@ class InboundConnectionFactory(unittest.TestCase): addr6 = address.IPv6Address("TCP", "::1", 1234) self.assertEqual(f._describePeer(addr6), "<-::1:1234") addrU = address.UNIXAddress("/dev/unlikely") - self.assertEqual(f._describePeer(addrU), - "<-UNIXAddress('/dev/unlikely')") + self.assertEqual( + f._describePeer(addrU), "<-UNIXAddress('/dev/unlikely')") def test_success(self): f = transit.InboundConnectionFactory("owner") @@ -586,8 +662,10 @@ class InboundConnectionFactory(unittest.TestCase): self.assertEqual(p1._cancelled, True) self.assertEqual(p2._cancelled, True) + # XXX check descriptions + class OutboundConnectionFactory(unittest.TestCase): def test_success(self): f = transit.OutboundConnectionFactory("owner", "relay_handshake", @@ -603,31 +681,39 @@ class OutboundConnectionFactory(unittest.TestCase): # meh .start # this is normally called from Connection.connectionMade - f.connectionWasMade(p) # no-op for outbound + f.connectionWasMade(p) # no-op for outbound self.assertEqual(p._start_negotiation_called, False) class MockOwner: _connection_ready_called = False + def connection_ready(self, connection): self._connection_ready_called = True self._connection = connection return self._state + def _send_this(self): return b"send_this" + def _expect_this(self): return b"expect_this" + def _sender_record_key(self): - return b"s"*32 + return b"s" * 32 + def _receiver_record_key(self): - return b"r"*32 + return b"r" * 32 + class MockFactory: _connectionWasMade_called = False + def connectionWasMade(self, p): self._connectionWasMade_called = True self._p = p + class Connection(unittest.TestCase): # exercise the Connection protocol class @@ -640,8 +726,8 @@ class Connection(unittest.TestCase): c.buf = b"unexpected" e = self.assertRaises(transit.BadHandshake, c._check_and_remove, EXP) - self.assertEqual(str(e), - "got %r want %r" % (b'unexpected', b'expectation')) + self.assertEqual( + str(e), "got %r want %r" % (b'unexpected', b'expectation')) self.assertEqual(c.buf, b"unexpected") c.buf = b"expect" @@ -770,12 +856,12 @@ class Connection(unittest.TestCase): c.connectionMade() self.assertEqual(factory._connectionWasMade_called, True) self.assertEqual(factory._p, c) - self.assertEqual(t.read_buf(), b"") # quiet until startNegotiation + self.assertEqual(t.read_buf(), b"") # quiet until startNegotiation owner._state = "go" d = c.startNegotiation() self.assertEqual(t.read_buf(), relay_handshake) - self.assertEqual(c.state, "relay") # waiting for OK from relay + self.assertEqual(c.state, "relay") # waiting for OK from relay c.dataReceived(b"ok\n") self.assertEqual(t.read_buf(), b"send_this") @@ -801,20 +887,20 @@ class Connection(unittest.TestCase): c.connectionMade() self.assertEqual(factory._connectionWasMade_called, True) self.assertEqual(factory._p, c) - self.assertEqual(t.read_buf(), b"") # quiet until startNegotiation + self.assertEqual(t.read_buf(), b"") # quiet until startNegotiation owner._state = "go" d = c.startNegotiation() self.assertEqual(t.read_buf(), relay_handshake) - self.assertEqual(c.state, "relay") # waiting for OK from relay + self.assertEqual(c.state, "relay") # waiting for OK from relay c.dataReceived(b"not ok\n") self.assertEqual(t._connected, False) self.assertEqual(c.state, "hung up") f = self.failureResultOf(d, transit.BadHandshake) - self.assertEqual(str(f.value), - "got %r want %r" % (b"not ok\n", b"ok\n")) + self.assertEqual( + str(f.value), "got %r want %r" % (b"not ok\n", b"ok\n")) def test_receiver_accepted(self): # we're on the receiving side, so we wait for the sender to decide @@ -866,12 +952,12 @@ class Connection(unittest.TestCase): self.assertEqual(c.state, "wait-for-decision") self.assertNoResult(d) - c.dataReceived(b"nevermind\n") # polite rejection + c.dataReceived(b"nevermind\n") # polite rejection self.assertEqual(t._connected, False) self.assertEqual(c.state, "hung up") f = self.failureResultOf(d, transit.BadHandshake) - self.assertEqual(str(f.value), - "got %r want %r" % (b"nevermind\n", b"go\n")) + self.assertEqual( + str(f.value), "got %r want %r" % (b"nevermind\n", b"go\n")) def test_receiver_rejected_rudely(self): # we're on the receiving side, so we wait for the sender to decide @@ -901,7 +987,6 @@ class Connection(unittest.TestCase): f = self.failureResultOf(d, transit.BadHandshake) self.assertEqual(str(f.value), "connection lost") - def test_cancel(self): owner = MockOwner() factory = MockFactory() @@ -926,8 +1011,10 @@ class Connection(unittest.TestCase): factory = MockFactory() addr = address.HostnameAddress("example.com", 1234) c = transit.Connection(owner, None, None, "description") + def _callLater(period, func): clock.callLater(period, func) + c.callLater = _callLater self.assertEqual(c.state, "too-early") t = c.transport = FakeTransport(c, addr) @@ -955,7 +1042,7 @@ class Connection(unittest.TestCase): d = c.startNegotiation() c.dataReceived(b"expect_this") self.assertEqual(self.successResultOf(d), c) - t.read_buf() # flush input buffer, prepare for encrypted records + t.read_buf() # flush input buffer, prepare for encrypted records return t, c, owner @@ -973,13 +1060,13 @@ class Connection(unittest.TestCase): RECORD1 = b"record" c.send_record(RECORD1) buf = t.read_buf() - expected = ("%08x" % (24+len(RECORD1)+16)).encode("ascii") + expected = ("%08x" % (24 + len(RECORD1) + 16)).encode("ascii") self.assertEqual(hexlify(buf[:4]), expected) encrypted = buf[4:] receive_box = SecretBox(owner._sender_record_key()) - nonce_buf = encrypted[:SecretBox.NONCE_SIZE] # assume it's prepended + nonce_buf = encrypted[:SecretBox.NONCE_SIZE] # assume it's prepended nonce = int(hexlify(nonce_buf), 16) - self.assertEqual(nonce, 0) # first message gets nonce 0 + self.assertEqual(nonce, 0) # first message gets nonce 0 decrypted = receive_box.decrypt(encrypted) self.assertEqual(decrypted, RECORD1) @@ -987,11 +1074,11 @@ class Connection(unittest.TestCase): RECORD2 = b"record2" c.send_record(RECORD2) buf = t.read_buf() - expected = ("%08x" % (24+len(RECORD2)+16)).encode("ascii") + expected = ("%08x" % (24 + len(RECORD2) + 16)).encode("ascii") self.assertEqual(hexlify(buf[:4]), expected) encrypted = buf[4:] receive_box = SecretBox(owner._sender_record_key()) - nonce_buf = encrypted[:SecretBox.NONCE_SIZE] # assume it's prepended + nonce_buf = encrypted[:SecretBox.NONCE_SIZE] # assume it's prepended nonce = int(hexlify(nonce_buf), 16) self.assertEqual(nonce, 1) decrypted = receive_box.decrypt(encrypted) @@ -1003,9 +1090,9 @@ class Connection(unittest.TestCase): send_box = SecretBox(owner._receiver_record_key()) RECORD3 = b"record3" - nonce_buf = unhexlify("%048x" % 0) # first nonce must be 0 + nonce_buf = unhexlify("%048x" % 0) # first nonce must be 0 encrypted = send_box.encrypt(RECORD3, nonce_buf) - length = unhexlify("%08x" % len(encrypted)) # always 4 bytes long + length = unhexlify("%08x" % len(encrypted)) # always 4 bytes long c.dataReceived(length[:2]) c.dataReceived(length[2:]) c.dataReceived(encrypted[:-2]) @@ -1014,9 +1101,9 @@ class Connection(unittest.TestCase): self.assertEqual(inbound_records, [RECORD3]) RECORD4 = b"record4" - nonce_buf = unhexlify("%048x" % 1) # nonces increment + nonce_buf = unhexlify("%048x" % 1) # nonces increment encrypted = send_box.encrypt(RECORD4, nonce_buf) - length = unhexlify("%08x" % len(encrypted)) # always 4 bytes long + length = unhexlify("%08x" % len(encrypted)) # always 4 bytes long c.dataReceived(length[:2]) c.dataReceived(length[2:]) c.dataReceived(encrypted[:-2]) @@ -1027,16 +1114,16 @@ class Connection(unittest.TestCase): # receiving two records at the same time: deliver both inbound_records[:] = [] RECORD5 = b"record5" - nonce_buf = unhexlify("%048x" % 2) # nonces increment + nonce_buf = unhexlify("%048x" % 2) # nonces increment encrypted = send_box.encrypt(RECORD5, nonce_buf) - length = unhexlify("%08x" % len(encrypted)) # always 4 bytes long - r5 = length+encrypted + length = unhexlify("%08x" % len(encrypted)) # always 4 bytes long + r5 = length + encrypted RECORD6 = b"record6" - nonce_buf = unhexlify("%048x" % 3) # nonces increment + nonce_buf = unhexlify("%048x" % 3) # nonces increment encrypted = send_box.encrypt(RECORD6, nonce_buf) - length = unhexlify("%08x" % len(encrypted)) # always 4 bytes long - r6 = length+encrypted - c.dataReceived(r5+r6) + length = unhexlify("%08x" % len(encrypted)) # always 4 bytes long + r6 = length + encrypted + c.dataReceived(r5 + r6) self.assertEqual(inbound_records, [RECORD5, RECORD6]) def corrupt(self, orig): @@ -1055,9 +1142,9 @@ class Connection(unittest.TestCase): RECORD = b"record" send_box = SecretBox(owner._receiver_record_key()) - nonce_buf = unhexlify("%048x" % 0) # first nonce must be 0 + nonce_buf = unhexlify("%048x" % 0) # first nonce must be 0 encrypted = self.corrupt(send_box.encrypt(RECORD, nonce_buf)) - length = unhexlify("%08x" % len(encrypted)) # always 4 bytes long + length = unhexlify("%08x" % len(encrypted)) # always 4 bytes long c.dataReceived(length) c.dataReceived(encrypted[:-2]) self.assertEqual(inbound_records, []) @@ -1075,9 +1162,9 @@ class Connection(unittest.TestCase): RECORD = b"record" send_box = SecretBox(owner._receiver_record_key()) - nonce_buf = unhexlify("%048x" % 1) # first nonce must be 0 + nonce_buf = unhexlify("%048x" % 1) # first nonce must be 0 encrypted = send_box.encrypt(RECORD, nonce_buf) - length = unhexlify("%08x" % len(encrypted)) # always 4 bytes long + length = unhexlify("%08x" % len(encrypted)) # always 4 bytes long c.dataReceived(length) c.dataReceived(encrypted[:-2]) self.assertEqual(inbound_records, []) @@ -1159,7 +1246,7 @@ class Connection(unittest.TestCase): # connectConsumer() takes an optional number of bytes to expect, and # fires a Deferred when that many have been written c = transit.Connection(None, None, None, "description") - c._negotiation_d.addErrback(lambda err: None) # eat it + c._negotiation_d.addErrback(lambda err: None) # eat it c.transport = proto_helpers.StringTransport() c.recordReceived(b"r1.") @@ -1197,7 +1284,7 @@ class Connection(unittest.TestCase): # zero-length file), make sure it gets woken up right away, so it can # disconnect itself, even though no bytes will actually arrive c = transit.Connection(None, None, None, "description") - c._negotiation_d.addErrback(lambda err: None) # eat it + c._negotiation_d.addErrback(lambda err: None) # eat it c.transport = proto_helpers.StringTransport() consumer = proto_helpers.StringTransport() @@ -1208,7 +1295,7 @@ class Connection(unittest.TestCase): def test_writeToFile(self): c = transit.Connection(None, None, None, "description") - c._negotiation_d.addErrback(lambda err: None) # eat it + c._negotiation_d.addErrback(lambda err: None) # eat it c.transport = proto_helpers.StringTransport() c.recordReceived(b"r1.") @@ -1242,11 +1329,11 @@ class Connection(unittest.TestCase): self.assertEqual(progress, [3, 3, 3, 1]) # test what happens when enough data is queued ahead of time - c.recordReceived(b"second.") # now "overflow.second." - c.recordReceived(b"third.") # now "overflow.second.third." + c.recordReceived(b"second.") # now "overflow.second." + c.recordReceived(b"third.") # now "overflow.second.third." f = io.BytesIO() d = c.writeToFile(f, 10) - self.assertEqual(f.getvalue(), b"overflow.second.") # whole records + self.assertEqual(f.getvalue(), b"overflow.second.") # whole records self.assertEqual(self.successResultOf(d), 16) self.assertEqual(list(c._inbound_records), [b"third."]) @@ -1273,6 +1360,7 @@ class Connection(unittest.TestCase): c.unregisterProducer() self.assertEqual(c.transport.producer, None) + class FileConsumer(unittest.TestCase): def test_basic(self): f = io.BytesIO() @@ -1280,12 +1368,12 @@ class FileConsumer(unittest.TestCase): fc = transit.FileConsumer(f, progress.append) self.assertEqual(progress, []) self.assertEqual(f.getvalue(), b"") - fc.write(b"."* 99) + fc.write(b"." * 99) self.assertEqual(progress, [99]) - self.assertEqual(f.getvalue(), b"."*99) + self.assertEqual(f.getvalue(), b"." * 99) fc.write(b"!") self.assertEqual(progress, [99, 1]) - self.assertEqual(f.getvalue(), b"."*99+b"!") + self.assertEqual(f.getvalue(), b"." * 99 + b"!") def test_hasher(self): hashee = [] @@ -1295,32 +1383,53 @@ class FileConsumer(unittest.TestCase): self.assertEqual(progress, []) self.assertEqual(f.getvalue(), b"") self.assertEqual(hashee, []) - fc.write(b"."* 99) + fc.write(b"." * 99) self.assertEqual(progress, [99]) - self.assertEqual(f.getvalue(), b"."*99) - self.assertEqual(hashee, [b"."*99]) + self.assertEqual(f.getvalue(), b"." * 99) + self.assertEqual(hashee, [b"." * 99]) fc.write(b"!") self.assertEqual(progress, [99, 1]) - self.assertEqual(f.getvalue(), b"."*99+b"!") - self.assertEqual(hashee, [b"."*99, b"!"]) + self.assertEqual(f.getvalue(), b"." * 99 + b"!") + self.assertEqual(hashee, [b"." * 99, b"!"]) -DIRECT_HINT_JSON = {"type": "direct-tcp-v1", - "hostname": "direct", "port": 1234} -RELAY_HINT_JSON = {"type": "relay-v1", - "hints": [{"type": "direct-tcp-v1", - "hostname": "relay", "port": 1234}]} -UNRECOGNIZED_DIRECT_HINT_JSON = {"type": "direct-tcp-v1", - "hostname": ["cannot", "parse", "list"]} +DIRECT_HINT_JSON = { + "type": "direct-tcp-v1", + "hostname": "direct", + "port": 1234 +} +RELAY_HINT_JSON = { + "type": "relay-v1", + "hints": [{ + "type": "direct-tcp-v1", + "hostname": "relay", + "port": 1234 + }] +} +UNRECOGNIZED_DIRECT_HINT_JSON = { + "type": "direct-tcp-v1", + "hostname": ["cannot", "parse", "list"] +} UNRECOGNIZED_HINT_JSON = {"type": "unknown"} -UNAVAILABLE_HINT_JSON = {"type": "direct-tcp-v1", # e.g. Tor without txtorcon - "hostname": "unavailable", "port": 1234} -RELAY_HINT2_JSON = {"type": "relay-v1", - "hints": [{"type": "direct-tcp-v1", - "hostname": "relay", "port": 1234}, - UNRECOGNIZED_HINT_JSON]} -UNAVAILABLE_RELAY_HINT_JSON = {"type": "relay-v1", - "hints": [UNAVAILABLE_HINT_JSON]} +UNAVAILABLE_HINT_JSON = { + "type": "direct-tcp-v1", # e.g. Tor without txtorcon + "hostname": "unavailable", + "port": 1234 +} +RELAY_HINT2_JSON = { + "type": + "relay-v1", + "hints": [{ + "type": "direct-tcp-v1", + "hostname": "relay", + "port": 1234 + }, UNRECOGNIZED_HINT_JSON] +} +UNAVAILABLE_RELAY_HINT_JSON = { + "type": "relay-v1", + "hints": [UNAVAILABLE_HINT_JSON] +} + class Transit(unittest.TestCase): def setUp(self): @@ -1340,11 +1449,12 @@ class Transit(unittest.TestCase): clock = task.Clock() s = transit.TransitSender("", reactor=clock) s.set_transit_key(b"key") - hints = yield s.get_connection_hints() # start the listener + hints = yield s.get_connection_hints() # start the listener del hints - s.add_connection_hints([DIRECT_HINT_JSON, - UNRECOGNIZED_DIRECT_HINT_JSON, - UNRECOGNIZED_HINT_JSON]) + s.add_connection_hints([ + DIRECT_HINT_JSON, UNRECOGNIZED_DIRECT_HINT_JSON, + UNRECOGNIZED_HINT_JSON + ]) s._start_connector = self._start_connector d = s.connect() @@ -1361,7 +1471,7 @@ class Transit(unittest.TestCase): clock = task.Clock() s = transit.TransitSender("", tor=mock.Mock(), reactor=clock) s.set_transit_key(b"key") - hints = yield s.get_connection_hints() # start the listener + hints = yield s.get_connection_hints() # start the listener del hints s.add_connection_hints([DIRECT_HINT_JSON]) @@ -1380,7 +1490,7 @@ class Transit(unittest.TestCase): clock = task.Clock() s = transit.TransitSender("", tor=mock.Mock(), reactor=clock) s.set_transit_key(b"key") - hints = yield s.get_connection_hints() # start the listener + hints = yield s.get_connection_hints() # start the listener del hints s.add_connection_hints([RELAY_HINT_JSON]) @@ -1411,9 +1521,8 @@ class Transit(unittest.TestCase): s.set_transit_key(b"key") hints = yield s.get_connection_hints() del hints - s.add_connection_hints([DIRECT_HINT_JSON, - UNRECOGNIZED_HINT_JSON, - RELAY_HINT_JSON]) + s.add_connection_hints( + [DIRECT_HINT_JSON, UNRECOGNIZED_HINT_JSON, RELAY_HINT_JSON]) s._endpoint_from_hint_obj = self._endpoint_from_hint_obj s._start_connector = self._start_connector @@ -1437,20 +1546,46 @@ class Transit(unittest.TestCase): hints = yield s.get_connection_hints() del hints s.add_connection_hints([ - {"type": "relay-v1", - "hints": [{"type": "direct-tcp-v1", - "hostname": "relay", "port": 1234}]}, - {"type": "direct-tcp-v1", - "hostname": "direct", "port": 1234}, - {"type": "relay-v1", - "hints": [{"type": "direct-tcp-v1", "priority": 2.0, - "hostname": "relay2", "port": 1234}, - {"type": "direct-tcp-v1", "priority": 3.0, - "hostname": "relay3", "port": 1234}]}, - {"type": "relay-v1", - "hints": [{"type": "direct-tcp-v1", "priority": 2.0, - "hostname": "relay4", "port": 1234}]}, - ]) + { + "type": + "relay-v1", + "hints": [{ + "type": "direct-tcp-v1", + "hostname": "relay", + "port": 1234 + }] + }, + { + "type": "direct-tcp-v1", + "hostname": "direct", + "port": 1234 + }, + { + "type": + "relay-v1", + "hints": [{ + "type": "direct-tcp-v1", + "priority": 2.0, + "hostname": "relay2", + "port": 1234 + }, { + "type": "direct-tcp-v1", + "priority": 3.0, + "hostname": "relay3", + "port": 1234 + }] + }, + { + "type": + "relay-v1", + "hints": [{ + "type": "direct-tcp-v1", + "priority": 2.0, + "hostname": "relay4", + "port": 1234 + }] + }, + ]) s._endpoint_from_hint_obj = self._endpoint_from_hint_obj s._start_connector = self._start_connector @@ -1482,13 +1617,13 @@ class Transit(unittest.TestCase): clock = task.Clock() s = transit.TransitSender("", reactor=clock, no_listen=True) s.set_transit_key(b"key") - hints = yield s.get_connection_hints() # start the listener + hints = yield s.get_connection_hints() # start the listener del hints # include hints that can't be turned into an endpoint at runtime - s.add_connection_hints([UNRECOGNIZED_HINT_JSON, - UNAVAILABLE_HINT_JSON, - RELAY_HINT2_JSON, - UNAVAILABLE_RELAY_HINT_JSON]) + s.add_connection_hints([ + UNRECOGNIZED_HINT_JSON, UNAVAILABLE_HINT_JSON, RELAY_HINT2_JSON, + UNAVAILABLE_RELAY_HINT_JSON + ]) s._endpoint_from_hint_obj = self._endpoint_from_hint_obj s._start_connector = self._start_connector @@ -1509,9 +1644,9 @@ class Transit(unittest.TestCase): clock = task.Clock() s = transit.TransitSender("", reactor=clock, no_listen=True) s.set_transit_key(b"key") - hints = yield s.get_connection_hints() # start the listener + hints = yield s.get_connection_hints() # start the listener del hints - s.add_connection_hints([]) # no hints at all + s.add_connection_hints([]) # no hints at all s._endpoint_from_hint_obj = self._endpoint_from_hint_obj s._start_connector = self._start_connector @@ -1519,10 +1654,11 @@ class Transit(unittest.TestCase): f = self.failureResultOf(d, transit.TransitError) self.assertEqual(str(f.value), "No contenders for connection") + class RelayHandshake(unittest.TestCase): def old_build_relay_handshake(self, key): token = transit.HKDF(key, 32, CTXinfo=b"transit_relay_token") - return (token, b"please relay "+hexlify(token)+b"\n") + return (token, b"please relay " + hexlify(token) + b"\n") def test_old(self): key = b"\x00" @@ -1548,8 +1684,9 @@ class RelayHandshake(unittest.TestCase): tc.dataReceived(new_handshake[:-1]) self.assertEqual(tc.factory.connection_got_token.mock_calls, []) tc.dataReceived(new_handshake[-1:]) - self.assertEqual(tc.factory.connection_got_token.mock_calls, - [mock.call(hexlify(token), c._side.encode("ascii"), tc)]) + self.assertEqual( + tc.factory.connection_got_token.mock_calls, + [mock.call(hexlify(token), c._side.encode("ascii"), tc)]) class Full(ServerBase, unittest.TestCase): @@ -1558,7 +1695,7 @@ class Full(ServerBase, unittest.TestCase): @inlineCallbacks def test_direct(self): - KEY = b"k"*32 + KEY = b"k" * 32 s = transit.TransitSender(None) r = transit.TransitReceiver(None) @@ -1571,7 +1708,7 @@ class Full(ServerBase, unittest.TestCase): s.add_connection_hints(rhints) r.add_connection_hints(shints) - (x,y) = yield self.doBoth(s.connect(), r.connect()) + (x, y) = yield self.doBoth(s.connect(), r.connect()) self.assertIsInstance(x, transit.Connection) self.assertIsInstance(y, transit.Connection) @@ -1586,7 +1723,7 @@ class Full(ServerBase, unittest.TestCase): @inlineCallbacks def test_relay(self): - KEY = b"k"*32 + KEY = b"k" * 32 s = transit.TransitSender(self.transit, no_listen=True) r = transit.TransitReceiver(self.transit, no_listen=True) @@ -1599,7 +1736,7 @@ class Full(ServerBase, unittest.TestCase): s.add_connection_hints(rhints) r.add_connection_hints(shints) - (x,y) = yield self.doBoth(s.connect(), r.connect()) + (x, y) = yield self.doBoth(s.connect(), r.connect()) self.assertIsInstance(x, transit.Connection) self.assertIsInstance(y, transit.Connection) diff --git a/src/wormhole/test/test_util.py b/src/wormhole/test/test_util.py index a939b91..e3c7d96 100644 --- a/src/wormhole/test/test_util.py +++ b/src/wormhole/test/test_util.py @@ -1,10 +1,15 @@ from __future__ import unicode_literals -import six -import mock + import unicodedata + +import six from twisted.trial import unittest + +import mock + from .. import util + class Utils(unittest.TestCase): def test_to_bytes(self): b = util.to_bytes("abc") @@ -41,11 +46,12 @@ class Utils(unittest.TestCase): self.assertIsInstance(d, dict) self.assertEqual(d, {"a": "b", "c": 2}) + class Space(unittest.TestCase): def test_free_space(self): free = util.estimate_free_space(".") - self.assert_(isinstance(free, six.integer_types + (type(None),)), - repr(free)) + self.assert_( + isinstance(free, six.integer_types + (type(None), )), repr(free)) # some platforms (I think the VMs used by travis are in this # category) return 0, and windows will return None, so don't assert # anything more specific about the return value @@ -56,5 +62,5 @@ class Space(unittest.TestCase): try: with mock.patch("os.statvfs", side_effect=AttributeError()): self.assertEqual(util.estimate_free_space("."), None) - except AttributeError: # raised by mock.get_original() + except AttributeError: # raised by mock.get_original() pass diff --git a/src/wormhole/test/test_wordlist.py b/src/wormhole/test/test_wordlist.py index 6b86cdb..06eb13d 100644 --- a/src/wormhole/test/test_wordlist.py +++ b/src/wormhole/test/test_wordlist.py @@ -1,8 +1,12 @@ from __future__ import print_function, unicode_literals -import mock + from twisted.trial import unittest + +import mock + from .._wordlist import PGPWordList + class Completions(unittest.TestCase): def test_completions(self): wl = PGPWordList() @@ -14,16 +18,21 @@ class Completions(unittest.TestCase): self.assertEqual(len(lots), 256, lots) first = list(lots)[0] self.assert_(first.startswith("armistice-"), first) - self.assertEqual(gc("armistice-ba", 2), - {"armistice-baboon", "armistice-backfield", - "armistice-backward", "armistice-banjo"}) - self.assertEqual(gc("armistice-ba", 3), - {"armistice-baboon-", "armistice-backfield-", - "armistice-backward-", "armistice-banjo-"}) + self.assertEqual( + gc("armistice-ba", 2), { + "armistice-baboon", "armistice-backfield", + "armistice-backward", "armistice-banjo" + }) + self.assertEqual( + gc("armistice-ba", 3), { + "armistice-baboon-", "armistice-backfield-", + "armistice-backward-", "armistice-banjo-" + }) self.assertEqual(gc("armistice-baboon", 2), {"armistice-baboon"}) self.assertEqual(gc("armistice-baboon", 3), {"armistice-baboon-"}) self.assertEqual(gc("armistice-baboon", 4), {"armistice-baboon-"}) + class Choose(unittest.TestCase): def test_choose_words(self): wl = PGPWordList() diff --git a/src/wormhole/test/test_wormhole.py b/src/wormhole/test/test_wormhole.py index df5830f..0362c65 100644 --- a/src/wormhole/test/test_wormhole.py +++ b/src/wormhole/test/test_wormhole.py @@ -1,17 +1,22 @@ from __future__ import print_function, unicode_literals -import io, re -import mock -from twisted.trial import unittest + +import io +import re + from twisted.internet import reactor from twisted.internet.defer import gatherResults, inlineCallbacks, returnValue from twisted.internet.error import ConnectionRefusedError -from .common import ServerBase, poll_until -from .. import wormhole, _rendezvous -from ..errors import (WrongPasswordError, ServerConnectionError, - KeyFormatError, WormholeClosed, LonelyError, - NoKeyError, OnlyOneCodeError) -from ..transit import allocate_tcp_port +from twisted.trial import unittest + +import mock + +from .. import _rendezvous, wormhole +from ..errors import (KeyFormatError, LonelyError, NoKeyError, + OnlyOneCodeError, ServerConnectionError, WormholeClosed, + WrongPasswordError) from ..eventual import EventualQueue +from ..transit import allocate_tcp_port +from .common import ServerBase, poll_until APPID = "appid" @@ -26,6 +31,7 @@ APPID = "appid" # * set_code, then connected # * connected, receive_pake, send_phase, set_code + class Delegate: def __init__(self): self.welcome = None @@ -35,28 +41,35 @@ class Delegate: self.versions = None self.messages = [] self.closed = None + def wormhole_got_welcome(self, welcome): self.welcome = welcome + def wormhole_got_code(self, code): self.code = code + def wormhole_got_unverified_key(self, key): self.key = key + def wormhole_got_verifier(self, verifier): self.verifier = verifier + def wormhole_got_versions(self, versions): self.versions = versions + def wormhole_got_message(self, data): self.messages.append(data) + def wormhole_closed(self, result): self.closed = result -class Delegated(ServerBase, unittest.TestCase): +class Delegated(ServerBase, unittest.TestCase): @inlineCallbacks def test_delegated(self): dg = Delegate() w1 = wormhole.create(APPID, self.relayurl, reactor, delegate=dg) - #w1.debug_set_trace("W1") + # w1.debug_set_trace("W1") with self.assertRaises(NoKeyError): w1.derive_key("purpose", 12) w1.set_code("1-abc") @@ -103,6 +116,7 @@ class Delegated(ServerBase, unittest.TestCase): yield poll_until(lambda: dg.code is not None) w1.close() + class Wormholes(ServerBase, unittest.TestCase): # integration test, with a real server @@ -135,12 +149,12 @@ class Wormholes(ServerBase, unittest.TestCase): @inlineCallbacks def test_basic(self): w1 = wormhole.create(APPID, self.relayurl, reactor) - #w1.debug_set_trace("W1") + # w1.debug_set_trace("W1") with self.assertRaises(NoKeyError): w1.derive_key("purpose", 12) w2 = wormhole.create(APPID, self.relayurl, reactor) - #w2.debug_set_trace(" W2") + # w2.debug_set_trace(" W2") w1.allocate_code() code = yield w1.get_code() w2.set_code(code) @@ -302,7 +316,6 @@ class Wormholes(ServerBase, unittest.TestCase): yield w1.close() yield w2.close() - @inlineCallbacks def test_multiple_messages(self): w1 = wormhole.create(APPID, self.relayurl, reactor) @@ -322,7 +335,6 @@ class Wormholes(ServerBase, unittest.TestCase): yield w1.close() yield w2.close() - @inlineCallbacks def test_closed(self): eq = EventualQueue(reactor) @@ -377,7 +389,7 @@ class Wormholes(ServerBase, unittest.TestCase): w2 = wormhole.create(APPID, self.relayurl, reactor, _eventual_queue=eq) w1.allocate_code() code = yield w1.get_code() - w2.set_code(code+"not") + w2.set_code(code + "not") code2 = yield w2.get_code() self.assertNotEqual(code, code2) # That's enough to allow both sides to discover the mismatch, but @@ -387,9 +399,9 @@ class Wormholes(ServerBase, unittest.TestCase): w1.send_message(b"should still work") w2.send_message(b"should still work") - key2 = yield w2.get_unverified_key() # should work + key2 = yield w2.get_unverified_key() # should work # w2 has just received w1.PAKE, and is about to send w2.VERSION - key1 = yield w1.get_unverified_key() # should work + key1 = yield w1.get_unverified_key() # should work # w1 has just received w2.PAKE, and is about to send w1.VERSION, and # then will receive w2.VERSION. When it sees w2.VERSION, it will # learn about the WrongPasswordError. @@ -451,7 +463,7 @@ class Wormholes(ServerBase, unittest.TestCase): badcode = "4 oops spaces" with self.assertRaises(KeyFormatError) as ex: w.set_code(badcode) - expected_msg = "Code '%s' contains spaces." % (badcode,) + expected_msg = "Code '%s' contains spaces." % (badcode, ) self.assertEqual(expected_msg, str(ex.exception)) yield self.assertFailure(w.close(), LonelyError) @@ -461,7 +473,7 @@ class Wormholes(ServerBase, unittest.TestCase): badcode = " 4-oops-space" with self.assertRaises(KeyFormatError) as ex: w.set_code(badcode) - expected_msg = "Code '%s' contains spaces." % (badcode,) + expected_msg = "Code '%s' contains spaces." % (badcode, ) self.assertEqual(expected_msg, str(ex.exception)) yield self.assertFailure(w.close(), LonelyError) @@ -478,8 +490,8 @@ class Wormholes(ServerBase, unittest.TestCase): @inlineCallbacks def test_welcome(self): w1 = wormhole.create(APPID, self.relayurl, reactor) - wel1 = yield w1.get_welcome() # early: before connection established - wel2 = yield w1.get_welcome() # late: already received welcome + wel1 = yield w1.get_welcome() # early: before connection established + wel2 = yield w1.get_welcome() # late: already received welcome self.assertEqual(wel1, wel2) self.assertIn("current_cli_version", wel1) @@ -489,7 +501,7 @@ class Wormholes(ServerBase, unittest.TestCase): w2.set_code("123-NOT") yield self.assertFailure(w1.get_verifier(), WrongPasswordError) - yield self.assertFailure(w1.get_welcome(), WrongPasswordError) # late + yield self.assertFailure(w1.get_welcome(), WrongPasswordError) # late yield self.assertFailure(w1.close(), WrongPasswordError) yield self.assertFailure(w2.close(), WrongPasswordError) @@ -502,7 +514,7 @@ class Wormholes(ServerBase, unittest.TestCase): w1.allocate_code() code = yield w1.get_code() w2.set_code(code) - v1 = yield w1.get_verifier() # early + v1 = yield w1.get_verifier() # early v2 = yield w2.get_verifier() self.failUnlessEqual(type(v1), type(b"")) self.failUnlessEqual(v1, v2) @@ -525,10 +537,10 @@ class Wormholes(ServerBase, unittest.TestCase): @inlineCallbacks def test_versions(self): # there's no API for this yet, but make sure the internals work - w1 = wormhole.create(APPID, self.relayurl, reactor, - versions={"w1": 123}) - w2 = wormhole.create(APPID, self.relayurl, reactor, - versions={"w2": 456}) + w1 = wormhole.create( + APPID, self.relayurl, reactor, versions={"w1": 123}) + w2 = wormhole.create( + APPID, self.relayurl, reactor, versions={"w2": 456}) w1.allocate_code() code = yield w1.get_code() w2.set_code(code) @@ -564,17 +576,19 @@ class Wormholes(ServerBase, unittest.TestCase): yield w1.close() yield w2.close() + class MessageDoubler(_rendezvous.RendezvousConnector): # we could double messages on the sending side, but a future server will # strip those duplicates, so to really exercise the receiver, we must # double them on the inbound side instead - #def _msg_send(self, phase, body): - # wormhole._Wormhole._msg_send(self, phase, body) - # self._ws_send_command("add", phase=phase, body=bytes_to_hexstr(body)) + # def _msg_send(self, phase, body): + # wormhole._Wormhole._msg_send(self, phase, body) + # self._ws_send_command("add", phase=phase, body=bytes_to_hexstr(body)) def _response_handle_message(self, msg): _rendezvous.RendezvousConnector._response_handle_message(self, msg) _rendezvous.RendezvousConnector._response_handle_message(self, msg) + class Errors(ServerBase, unittest.TestCase): @inlineCallbacks def test_derive_key_early(self): @@ -602,16 +616,17 @@ class Errors(ServerBase, unittest.TestCase): w.set_code("123-nope") yield self.assertFailure(w.close(), LonelyError) + class Reconnection(ServerBase, unittest.TestCase): @inlineCallbacks def test_basic(self): w1 = wormhole.create(APPID, self.relayurl, reactor) w1_in = [] w1._boss._RC._debug_record_inbound_f = w1_in.append - #w1.debug_set_trace("W1") + # w1.debug_set_trace("W1") w1.allocate_code() code = yield w1.get_code() - w1.send_message(b"data1") # queued until wormhole is established + w1.send_message(b"data1") # queued until wormhole is established # now wait until we've deposited all our messages on the server def seen_our_pake(): @@ -619,6 +634,7 @@ class Reconnection(ServerBase, unittest.TestCase): if m["type"] == "message" and m["phase"] == "pake": return True return False + yield poll_until(seen_our_pake) w1_in[:] = [] @@ -634,7 +650,7 @@ class Reconnection(ServerBase, unittest.TestCase): # receiver has started w2 = wormhole.create(APPID, self.relayurl, reactor) - #w2.debug_set_trace(" W2") + # w2.debug_set_trace(" W2") w2.set_code(code) dataY = yield w2.get_message() @@ -649,6 +665,7 @@ class Reconnection(ServerBase, unittest.TestCase): c2 = yield w2.close() self.assertEqual(c2, "happy") + class InitialFailure(unittest.TestCase): @inlineCallbacks def assertSCEFailure(self, eq, d, innerType): @@ -662,8 +679,8 @@ class InitialFailure(unittest.TestCase): def test_bad_dns(self): eq = EventualQueue(reactor) # point at a URL that will never connect - w = wormhole.create(APPID, "ws://%%%.example.org:4000/v1", - reactor, _eventual_queue=eq) + w = wormhole.create( + APPID, "ws://%%%.example.org:4000/v1", reactor, _eventual_queue=eq) # that should have already received an error, when it tried to # resolve the bogus DNS name. All API calls will return an error. @@ -717,6 +734,7 @@ class InitialFailure(unittest.TestCase): yield self.assertSCE(d4, ConnectionRefusedError) yield self.assertSCE(d5, ConnectionRefusedError) + class Trace(unittest.TestCase): def test_basic(self): w1 = wormhole.create(APPID, "ws://localhost:1", reactor) @@ -734,13 +752,11 @@ class Trace(unittest.TestCase): ["C1.M1[OLD].IN -> [NEW]"]) out("OUT1") self.assertEqual(stderr.getvalue().splitlines(), - ["C1.M1[OLD].IN -> [NEW]", - " C1.M1.OUT1()"]) + ["C1.M1[OLD].IN -> [NEW]", " C1.M1.OUT1()"]) w1._boss._print_trace("", "R.connected", "", "C1", "RC1", stderr) - self.assertEqual(stderr.getvalue().splitlines(), - ["C1.M1[OLD].IN -> [NEW]", - " C1.M1.OUT1()", - "C1.RC1.R.connected"]) + self.assertEqual( + stderr.getvalue().splitlines(), + ["C1.M1[OLD].IN -> [NEW]", " C1.M1.OUT1()", "C1.RC1.R.connected"]) def test_delegated(self): dg = Delegate() diff --git a/src/wormhole/test/test_xfer_util.py b/src/wormhole/test/test_xfer_util.py index 1c08f3a..673910a 100644 --- a/src/wormhole/test/test_xfer_util.py +++ b/src/wormhole/test/test_xfer_util.py @@ -1,11 +1,13 @@ -from twisted.trial import unittest -from twisted.internet import reactor, defer +from twisted.internet import defer, reactor from twisted.internet.defer import inlineCallbacks +from twisted.trial import unittest + from .. import xfer_util from .common import ServerBase APPID = u"appid" + class Xfer(ServerBase, unittest.TestCase): @inlineCallbacks def test_xfer(self): @@ -24,10 +26,15 @@ class Xfer(ServerBase, unittest.TestCase): data = u"data" send_code = [] receive_code = [] - d1 = xfer_util.send(reactor, APPID, self.relayurl, data, code, - on_code=send_code.append) - d2 = xfer_util.receive(reactor, APPID, self.relayurl, code, - on_code=receive_code.append) + d1 = xfer_util.send( + reactor, + APPID, + self.relayurl, + data, + code, + on_code=send_code.append) + d2 = xfer_util.receive( + reactor, APPID, self.relayurl, code, on_code=receive_code.append) send_result = yield d1 receive_result = yield d2 self.assertEqual(send_code, [code]) @@ -39,8 +46,13 @@ class Xfer(ServerBase, unittest.TestCase): def test_make_code(self): data = u"data" got_code = defer.Deferred() - d1 = xfer_util.send(reactor, APPID, self.relayurl, data, code=None, - on_code=got_code.callback) + d1 = xfer_util.send( + reactor, + APPID, + self.relayurl, + data, + code=None, + on_code=got_code.callback) code = yield got_code d2 = xfer_util.receive(reactor, APPID, self.relayurl, code) send_result = yield d1 diff --git a/src/wormhole/timing.py b/src/wormhole/timing.py index 8cb18e5..ecb29be 100644 --- a/src/wormhole/timing.py +++ b/src/wormhole/timing.py @@ -1,8 +1,13 @@ -from __future__ import print_function, absolute_import, unicode_literals -import json, time +from __future__ import absolute_import, print_function, unicode_literals + +import json +import time + from zope.interface import implementer + from ._interfaces import ITiming + class Event: def __init__(self, name, when, **details): # data fields that will be dumped to JSON later @@ -35,6 +40,7 @@ class Event: else: self.finish() + @implementer(ITiming) class DebugTiming: def __init__(self): @@ -47,11 +53,14 @@ class DebugTiming: def write(self, fn, stderr): with open(fn, "wt") as f: - data = [ dict(name=e._name, - start=e._start, stop=e._stop, - details=e._details, - ) - for e in self._events ] + data = [ + dict( + name=e._name, + start=e._start, + stop=e._stop, + details=e._details, + ) for e in self._events + ] json.dump(data, f, indent=1) f.write("\n") print("Timing data written to %s" % fn, file=stderr) diff --git a/src/wormhole/tor_manager.py b/src/wormhole/tor_manager.py index 0f8a449..eea12d6 100644 --- a/src/wormhole/tor_manager.py +++ b/src/wormhole/tor_manager.py @@ -1,15 +1,20 @@ from __future__ import print_function, unicode_literals + import sys -from attr import attrs, attrib -from zope.interface.declarations import directlyProvides + +from attr import attrib, attrs from twisted.internet.defer import inlineCallbacks, returnValue from twisted.internet.endpoints import clientFromString +from zope.interface.declarations import directlyProvides + +from . import _interfaces, errors +from .timing import DebugTiming + try: import txtorcon except ImportError: txtorcon = None -from . import _interfaces, errors -from .timing import DebugTiming + @attrs class SocksOnlyTor(object): @@ -17,15 +22,20 @@ class SocksOnlyTor(object): def stream_via(self, host, port, tls=False): return txtorcon.TorClientEndpoint( - host, port, - socks_endpoint=None, # tries localhost:9050 and 9150 + host, + port, + socks_endpoint=None, # tries localhost:9050 and 9150 tls=tls, reactor=self._reactor, ) + @inlineCallbacks -def get_tor(reactor, launch_tor=False, tor_control_port=None, - timing=None, stderr=sys.stderr): +def get_tor(reactor, + launch_tor=False, + tor_control_port=None, + timing=None, + stderr=sys.stderr): """ If launch_tor=True, I will try to launch a new Tor process, ask it for its SOCKS and control ports, and use those for outbound @@ -59,7 +69,7 @@ def get_tor(reactor, launch_tor=False, tor_control_port=None, if not txtorcon: raise errors.NoTorError() - if not isinstance(launch_tor, bool): # note: False is int + if not isinstance(launch_tor, bool): # note: False is int raise TypeError("launch_tor= must be boolean") if not isinstance(tor_control_port, (type(""), type(None))): raise TypeError("tor_control_port= must be str or None") @@ -74,19 +84,21 @@ def get_tor(reactor, launch_tor=False, tor_control_port=None, # need the control port. if launch_tor: - print(" launching a new Tor process, this may take a while..", - file=stderr) + print( + " launching a new Tor process, this may take a while..", + file=stderr) with timing.add("launch tor"): tor = yield txtorcon.launch(reactor, - #data_directory=, - #tor_binary=, + # data_directory=, + # tor_binary=, ) elif tor_control_port: with timing.add("find tor"): control_ep = clientFromString(reactor, tor_control_port) - tor = yield txtorcon.connect(reactor, control_ep) # might raise - print(" using Tor via control port at %s" % tor_control_port, - file=stderr) + tor = yield txtorcon.connect(reactor, control_ep) # might raise + print( + " using Tor via control port at %s" % tor_control_port, + file=stderr) else: # Let txtorcon look through a list of usual places. If that fails, # we'll arrange to attempt the default SOCKS port @@ -98,8 +110,9 @@ def get_tor(reactor, launch_tor=False, tor_control_port=None, # TODO: make this more specific. I think connect() is # likely to throw a reactor.connectTCP -type error, like # ConnectionFailed or ConnectionRefused or something - print(" unable to find default Tor control port, using SOCKS", - file=stderr) + print( + " unable to find default Tor control port, using SOCKS", + file=stderr) tor = SocksOnlyTor(reactor) directlyProvides(tor, _interfaces.ITorManager) returnValue(tor) diff --git a/src/wormhole/transit.py b/src/wormhole/transit.py index f8f824a..6363c6d 100644 --- a/src/wormhole/transit.py +++ b/src/wormhole/transit.py @@ -1,38 +1,51 @@ # no unicode_literals, revisit after twisted patch -from __future__ import print_function, absolute_import -import os, re, sys, time, socket -from collections import namedtuple, deque +from __future__ import absolute_import, print_function + +import os +import re +import socket +import sys +import time from binascii import hexlify, unhexlify +from collections import deque, namedtuple + import six -from zope.interface import implementer -from twisted.python import log -from twisted.python.runtime import platformType -from twisted.internet import (reactor, interfaces, defer, protocol, - endpoints, task, address, error) +from hkdf import Hkdf +from nacl.secret import SecretBox +from twisted.internet import (address, defer, endpoints, error, interfaces, + protocol, reactor, task) from twisted.internet.defer import inlineCallbacks, returnValue from twisted.protocols import policies -from nacl.secret import SecretBox -from hkdf import Hkdf +from twisted.python import log +from twisted.python.runtime import platformType +from zope.interface import implementer + +from . import ipaddrs from .errors import InternalError from .timing import DebugTiming from .util import bytes_to_hexstr -from . import ipaddrs + def HKDF(skm, outlen, salt=None, CTXinfo=b""): return Hkdf(salt, skm).expand(CTXinfo, outlen) + class TransitError(Exception): pass + class BadHandshake(Exception): pass + class TransitClosed(TransitError): pass + class BadNonce(TransitError): pass + # The beginning of each TCP connection consists of the following handshake # messages. The sender transmits the same text regardless of whether it is on # the initiating/connecting end of the TCP connection, or on the @@ -63,19 +76,23 @@ class BadNonce(TransitError): # RXID_HEX ready\n\n" and then makes a first/not-first decision about sending # "go\n" or "nevermind\n"+close(). + def build_receiver_handshake(key): hexid = HKDF(key, 32, CTXinfo=b"transit_receiver") - return b"transit receiver "+hexlify(hexid)+b" ready\n\n" + return b"transit receiver " + hexlify(hexid) + b" ready\n\n" + def build_sender_handshake(key): hexid = HKDF(key, 32, CTXinfo=b"transit_sender") - return b"transit sender "+hexlify(hexid)+b" ready\n\n" + return b"transit sender " + hexlify(hexid) + b" ready\n\n" + def build_sided_relay_handshake(key, side): assert isinstance(side, type(u"")) - assert len(side) == 8*2 + assert len(side) == 8 * 2 token = HKDF(key, 32, CTXinfo=b"transit_relay_token") - return b"please relay "+hexlify(token)+b" for side "+side.encode("ascii")+b"\n" + return b"please relay " + hexlify(token) + b" for side " + side.encode( + "ascii") + b"\n" # These namedtuples are "hint objects". The JSON-serializable dictionaries @@ -87,7 +104,8 @@ def build_sided_relay_handshake(key, side): # * expect to see the receiver/sender handshake bytes from the other side # * the sender writes "go\n", the receiver waits for "go\n" # * the rest of the connection contains transit data -DirectTCPV1Hint = namedtuple("DirectTCPV1Hint", ["hostname", "port", "priority"]) +DirectTCPV1Hint = namedtuple("DirectTCPV1Hint", + ["hostname", "port", "priority"]) TorTCPV1Hint = namedtuple("TorTCPV1Hint", ["hostname", "port", "priority"]) # RelayV1Hint contains a tuple of DirectTCPV1Hint and TorTCPV1Hint hints (we # use a tuple rather than a list so they'll be hashable into a set). For each @@ -95,6 +113,7 @@ TorTCPV1Hint = namedtuple("TorTCPV1Hint", ["hostname", "port", "priority"]) # rest of the V1 protocol. Only one hint per relay is useful. RelayV1Hint = namedtuple("RelayV1Hint", ["hints"]) + def describe_hint_obj(hint): if isinstance(hint, DirectTCPV1Hint): return u"tcp:%s:%d" % (hint.hostname, hint.port) @@ -103,27 +122,30 @@ def describe_hint_obj(hint): else: return str(hint) + def parse_hint_argv(hint, stderr=sys.stderr): assert isinstance(hint, type(u"")) # return tuple or None for an unparseable hint priority = 0.0 mo = re.search(r'^([a-zA-Z0-9]+):(.*)$', hint) if not mo: - print("unparseable hint '%s'" % (hint,), file=stderr) + print("unparseable hint '%s'" % (hint, ), file=stderr) return None hint_type = mo.group(1) if hint_type != "tcp": - print("unknown hint type '%s' in '%s'" % (hint_type, hint), file=stderr) + print( + "unknown hint type '%s' in '%s'" % (hint_type, hint), file=stderr) return None hint_value = mo.group(2) pieces = hint_value.split(":") if len(pieces) < 2: - print("unparseable TCP hint (need more colons) '%s'" % (hint,), - file=stderr) + print( + "unparseable TCP hint (need more colons) '%s'" % (hint, ), + file=stderr) return None mo = re.search(r'^(\d+)$', pieces[1]) if not mo: - print("non-numeric port in TCP hint '%s'" % (hint,), file=stderr) + print("non-numeric port in TCP hint '%s'" % (hint, ), file=stderr) return None hint_host = pieces[0] hint_port = int(pieces[1]) @@ -133,12 +155,15 @@ def parse_hint_argv(hint, stderr=sys.stderr): try: priority = float(more_pieces[1]) except ValueError: - print("non-float priority= in TCP hint '%s'" % (hint,), - file=stderr) + print( + "non-float priority= in TCP hint '%s'" % (hint, ), + file=stderr) return None return DirectTCPV1Hint(hint_host, hint_port, priority) -TIMEOUT = 60 # seconds + +TIMEOUT = 60 # seconds + @implementer(interfaces.IProducer, interfaces.IConsumer) class Connection(protocol.Protocol, policies.TimeoutMixin): @@ -159,7 +184,7 @@ class Connection(protocol.Protocol, policies.TimeoutMixin): self._waiting_reads = deque() def connectionMade(self): - self.setTimeout(TIMEOUT) # does timeoutConnection() when it expires + self.setTimeout(TIMEOUT) # does timeoutConnection() when it expires self.factory.connectionWasMade(self) def startNegotiation(self): @@ -168,11 +193,11 @@ class Connection(protocol.Protocol, policies.TimeoutMixin): self.state = "relay" else: self.state = "start" - self.dataReceived(b"") # cycle the state machine + self.dataReceived(b"") # cycle the state machine return self._negotiation_d def _cancel(self, d): - self.state = "hung up" # stop reacting to anything further + self.state = "hung up" # stop reacting to anything further self._error = defer.CancelledError() self.transport.loseConnection() # if connectionLost isn't called synchronously, then our @@ -181,7 +206,6 @@ class Connection(protocol.Protocol, policies.TimeoutMixin): if self._negotiation_d: self._negotiation_d = None - def dataReceived(self, data): try: self._dataReceived(data) @@ -198,7 +222,7 @@ class Connection(protocol.Protocol, policies.TimeoutMixin): if not self.buf.startswith(expected[:len(self.buf)]): raise BadHandshake("got %r want %r" % (self.buf, expected)) if len(self.buf) < len(expected): - return False # keep waiting + return False # keep waiting self.buf = self.buf[len(expected):] return True @@ -245,9 +269,9 @@ class Connection(protocol.Protocol, policies.TimeoutMixin): return self.dataReceivedRECORDS() if self.state == "hung up": return - if isinstance(self.state, Exception): # for tests + if isinstance(self.state, Exception): # for tests raise self.state - raise ValueError("internal error: unknown state %s" % (self.state,)) + raise ValueError("internal error: unknown state %s" % (self.state, )) def _negotiationSuccessful(self): self.state = "records" @@ -266,19 +290,20 @@ class Connection(protocol.Protocol, policies.TimeoutMixin): if len(self.buf) < 4: return length = int(hexlify(self.buf[:4]), 16) - if len(self.buf) < 4+length: + if len(self.buf) < 4 + length: return - encrypted, self.buf = self.buf[4:4+length], self.buf[4+length:] + encrypted, self.buf = self.buf[4:4 + length], self.buf[4 + length:] record = self._decrypt_record(encrypted) self.recordReceived(record) def _decrypt_record(self, encrypted): - nonce_buf = encrypted[:SecretBox.NONCE_SIZE] # assume it's prepended + nonce_buf = encrypted[:SecretBox.NONCE_SIZE] # assume it's prepended nonce = int(hexlify(nonce_buf), 16) if nonce != self.next_receive_nonce: - raise BadNonce("received out-of-order record: got %d, expected %d" - % (nonce, self.next_receive_nonce)) + raise BadNonce( + "received out-of-order record: got %d, expected %d" % + (nonce, self.next_receive_nonce)) self.next_receive_nonce += 1 record = self.receive_box.decrypt(encrypted) return record @@ -287,14 +312,15 @@ class Connection(protocol.Protocol, policies.TimeoutMixin): return self._description def send_record(self, record): - if not isinstance(record, type(b"")): raise InternalError + if not isinstance(record, type(b"")): + raise InternalError assert SecretBox.NONCE_SIZE == 24 - assert self.send_nonce < 2**(8*24) - assert len(record) < 2**(8*4) - nonce = unhexlify("%048x" % self.send_nonce) # big-endian + assert self.send_nonce < 2**(8 * 24) + assert len(record) < 2**(8 * 4) + nonce = unhexlify("%048x" % self.send_nonce) # big-endian self.send_nonce += 1 encrypted = self.send_box.encrypt(record, nonce) - length = unhexlify("%08x" % len(encrypted)) # always 4 bytes long + length = unhexlify("%08x" % len(encrypted)) # always 4 bytes long self.transport.write(length) self.transport.write(encrypted) @@ -353,8 +379,10 @@ class Connection(protocol.Protocol, policies.TimeoutMixin): def registerProducer(self, producer, streaming): assert interfaces.IConsumer.providedBy(self.transport) self.transport.registerProducer(producer, streaming) + def unregisterProducer(self): self.transport.unregisterProducer() + def write(self, data): self.send_record(data) @@ -362,8 +390,10 @@ class Connection(protocol.Protocol, policies.TimeoutMixin): # the transport. def stopProducing(self): self.transport.stopProducing() + def pauseProducing(self): self.transport.pauseProducing() + def resumeProducing(self): self.transport.resumeProducing() @@ -384,8 +414,8 @@ class Connection(protocol.Protocol, policies.TimeoutMixin): Deferred, and you must call disconnectConsumer() when you are done.""" if self._consumer: - raise RuntimeError("A consumer is already attached: %r" % - self._consumer) + raise RuntimeError( + "A consumer is already attached: %r" % self._consumer) # be aware of an ordering hazard: when we call the consumer's # .registerProducer method, they are likely to immediately call @@ -440,6 +470,7 @@ class Connection(protocol.Protocol, policies.TimeoutMixin): fc = FileConsumer(f, progress, hasher) return self.connectConsumer(fc, expected) + class OutboundConnectionFactory(protocol.ClientFactory): protocol = Connection @@ -478,7 +509,7 @@ class InboundConnectionFactory(protocol.ClientFactory): def _shutdown(self): for d in list(self._pending_connections): - d.cancel() # that fires _remove and _proto_failed + d.cancel() # that fires _remove and _proto_failed def _describePeer(self, addr): if isinstance(addr, address.HostnameAddress): @@ -511,6 +542,7 @@ class InboundConnectionFactory(protocol.ClientFactory): # ignore these two, let Twisted log everything else f.trap(BadHandshake, defer.CancelledError) + def allocate_tcp_port(): """Return an (integer) available TCP port on localhost. This briefly listens on the port in question, then closes it right away.""" @@ -527,6 +559,7 @@ def allocate_tcp_port(): s.close() return port + class _ThereCanBeOnlyOne: """Accept a list of contender Deferreds, and return a summary Deferred. When the first contender fires successfully, cancel the rest and fire the @@ -535,6 +568,7 @@ class _ThereCanBeOnlyOne: status_cb=? """ + def __init__(self, contenders): self._remaining = set(contenders) self._winner_d = defer.Deferred(self._cancel) @@ -581,26 +615,32 @@ class _ThereCanBeOnlyOne: else: self._winner_d.errback(self._first_failure) + def there_can_be_only_one(contenders): return _ThereCanBeOnlyOne(contenders).run() + class Common: RELAY_DELAY = 2.0 TRANSIT_KEY_LENGTH = SecretBox.KEY_SIZE - def __init__(self, transit_relay, no_listen=False, tor=None, - reactor=reactor, timing=None): - self._side = bytes_to_hexstr(os.urandom(8)) # unicode + def __init__(self, + transit_relay, + no_listen=False, + tor=None, + reactor=reactor, + timing=None): + self._side = bytes_to_hexstr(os.urandom(8)) # unicode if transit_relay: if not isinstance(transit_relay, type(u"")): raise InternalError # TODO: allow multiple hints for a single relay relay_hint = parse_hint_argv(transit_relay) - relay = RelayV1Hint(hints=(relay_hint,)) + relay = RelayV1Hint(hints=(relay_hint, )) self._transit_relays = [relay] else: self._transit_relays = [] - self._their_direct_hints = [] # hintobjs + self._their_direct_hints = [] # hintobjs self._our_relay_hints = set(self._transit_relays) self._tor = tor self._transit_key = None @@ -622,33 +662,42 @@ class Common: # some test hosts, including the appveyor VMs, *only* have # 127.0.0.1, and the tests will hang badly if we remove it. addresses = non_loopback_addresses - direct_hints = [DirectTCPV1Hint(six.u(addr), portnum, 0.0) - for addr in addresses] + direct_hints = [ + DirectTCPV1Hint(six.u(addr), portnum, 0.0) for addr in addresses + ] ep = endpoints.serverFromString(reactor, "tcp:%d" % portnum) return direct_hints, ep def get_connection_abilities(self): - return [{u"type": u"direct-tcp-v1"}, - {u"type": u"relay-v1"}, - ] + return [ + { + u"type": u"direct-tcp-v1" + }, + { + u"type": u"relay-v1" + }, + ] @inlineCallbacks def get_connection_hints(self): hints = [] direct_hints = yield self._get_direct_hints() for dh in direct_hints: - hints.append({u"type": u"direct-tcp-v1", - u"priority": dh.priority, - u"hostname": dh.hostname, - u"port": dh.port, # integer - }) + hints.append({ + u"type": u"direct-tcp-v1", + u"priority": dh.priority, + u"hostname": dh.hostname, + u"port": dh.port, # integer + }) for relay in self._transit_relays: rhint = {u"type": u"relay-v1", u"hints": []} for rh in relay.hints: - rhint[u"hints"].append({u"type": u"direct-tcp-v1", - u"priority": rh.priority, - u"hostname": rh.hostname, - u"port": rh.port}) + rhint[u"hints"].append({ + u"type": u"direct-tcp-v1", + u"priority": rh.priority, + u"hostname": rh.hostname, + u"port": rh.port + }) hints.append(rhint) returnValue(hints) @@ -665,24 +714,27 @@ class Common: # listener will win. self._my_direct_hints, self._listener = self._build_listener() - if self._listener is None: # don't listen + if self._listener is None: # don't listen self._listener_d = None - return defer.succeed(self._my_direct_hints) # empty + return defer.succeed(self._my_direct_hints) # empty # Start the server, so it will be running by the time anyone tries to # connect to the direct hints we return. f = InboundConnectionFactory(self) - self._listener_f = f # for tests # XX move to __init__ ? + self._listener_f = f # for tests # XX move to __init__ ? self._listener_d = f.whenDone() d = self._listener.listen(f) + def _listening(lp): # lp is an IListeningPort - #self._listener_port = lp # for tests + # self._listener_port = lp # for tests def _stop_listening(res): lp.stopListening() return res + self._listener_d.addBoth(_stop_listening) return self._my_direct_hints + d.addCallback(_listening) return d @@ -694,18 +746,18 @@ class Common: self._listener_d.addErrback(lambda f: None) self._listener_d.cancel() - def _parse_tcp_v1_hint(self, hint): # hint_struct -> hint_obj + def _parse_tcp_v1_hint(self, hint): # hint_struct -> hint_obj hint_type = hint.get(u"type", u"") if hint_type not in [u"direct-tcp-v1", u"tor-tcp-v1"]: - log.msg("unknown hint type: %r" % (hint,)) + log.msg("unknown hint type: %r" % (hint, )) return None - if not(u"hostname" in hint - and isinstance(hint[u"hostname"], type(u""))): - log.msg("invalid hostname in hint: %r" % (hint,)) + if not (u"hostname" in hint + and isinstance(hint[u"hostname"], type(u""))): + log.msg("invalid hostname in hint: %r" % (hint, )) return None - if not(u"port" in hint - and isinstance(hint[u"port"], six.integer_types)): - log.msg("invalid port in hint: %r" % (hint,)) + if not (u"port" in hint + and isinstance(hint[u"port"], six.integer_types)): + log.msg("invalid port in hint: %r" % (hint, )) return None priority = hint.get(u"priority", 0.0) if hint_type == u"direct-tcp-v1": @@ -714,12 +766,12 @@ class Common: return TorTCPV1Hint(hint[u"hostname"], hint[u"port"], priority) def add_connection_hints(self, hints): - for h in hints: # hint structs + for h in hints: # hint structs hint_type = h.get(u"type", u"") if hint_type in [u"direct-tcp-v1", u"tor-tcp-v1"]: dh = self._parse_tcp_v1_hint(h) if dh: - self._their_direct_hints.append(dh) # hint_obj + self._their_direct_hints.append(dh) # hint_obj elif hint_type == u"relay-v1": # TODO: each relay-v1 clause describes a different relay, # with a set of equally-valid ways to connect to it. Treat @@ -734,7 +786,7 @@ class Common: rh = RelayV1Hint(hints=tuple(sorted(relay_hints))) self._our_relay_hints.add(rh) else: - log.msg("unknown hint type: %r" % (h,)) + log.msg("unknown hint type: %r" % (h, )) def _send_this(self): assert self._transit_key @@ -748,25 +800,33 @@ class Common: if self.is_sender: return build_receiver_handshake(self._transit_key) else: - return build_sender_handshake(self._transit_key)# + b"go\n" + return build_sender_handshake(self._transit_key) # + b"go\n" def _sender_record_key(self): assert self._transit_key if self.is_sender: - return HKDF(self._transit_key, SecretBox.KEY_SIZE, - CTXinfo=b"transit_record_sender_key") + return HKDF( + self._transit_key, + SecretBox.KEY_SIZE, + CTXinfo=b"transit_record_sender_key") else: - return HKDF(self._transit_key, SecretBox.KEY_SIZE, - CTXinfo=b"transit_record_receiver_key") + return HKDF( + self._transit_key, + SecretBox.KEY_SIZE, + CTXinfo=b"transit_record_receiver_key") def _receiver_record_key(self): assert self._transit_key if self.is_sender: - return HKDF(self._transit_key, SecretBox.KEY_SIZE, - CTXinfo=b"transit_record_receiver_key") + return HKDF( + self._transit_key, + SecretBox.KEY_SIZE, + CTXinfo=b"transit_record_receiver_key") else: - return HKDF(self._transit_key, SecretBox.KEY_SIZE, - CTXinfo=b"transit_record_sender_key") + return HKDF( + self._transit_key, + SecretBox.KEY_SIZE, + CTXinfo=b"transit_record_sender_key") def set_transit_key(self, key): assert isinstance(key, type(b"")), type(key) @@ -848,9 +908,13 @@ class Common: description = "->relay:%s" % describe_hint_obj(hint_obj) if self._tor: description = "tor" + description - d = task.deferLater(self._reactor, relay_delay, - self._start_connector, ep, description, - is_relay=True) + d = task.deferLater( + self._reactor, + relay_delay, + self._start_connector, + ep, + description, + is_relay=True) contenders.append(d) relay_delay += self.RELAY_DELAY @@ -858,16 +922,18 @@ class Common: raise TransitError("No contenders for connection") winner = there_can_be_only_one(contenders) - return self._not_forever(2*TIMEOUT, winner) + return self._not_forever(2 * TIMEOUT, winner) def _not_forever(self, timeout, d): """If the timer fires first, cancel the deferred. If the deferred fires first, cancel the timer.""" t = self._reactor.callLater(timeout, d.cancel) + def _done(res): if t.active(): t.cancel() return res + d.addBoth(_done) return d @@ -896,8 +962,8 @@ class Common: return None return None if isinstance(hint, DirectTCPV1Hint): - return endpoints.HostnameEndpoint(self._reactor, - hint.hostname, hint.port) + return endpoints.HostnameEndpoint(self._reactor, hint.hostname, + hint.port) return None def connection_ready(self, p): @@ -915,9 +981,11 @@ class Common: self._winner = p return "go" + class TransitSender(Common): is_sender = True + class TransitReceiver(Common): is_sender = False @@ -926,6 +994,7 @@ class TransitReceiver(Common): # when done, and add a progress function that gets called with the length of # each write, and a hasher function that gets called with the data. + @implementer(interfaces.IConsumer) class FileConsumer: def __init__(self, f, progress=None, hasher=None): @@ -950,6 +1019,7 @@ class FileConsumer: assert self._producer self._producer = None + # the TransitSender/Receiver.connect() yields a Connection, on which you can # do send_record(), but what should the receive API be? set a callback for # inbound records? get a Deferred for the next record? The producer/consumer diff --git a/src/wormhole/util.py b/src/wormhole/util.py index b5e39fb..0b57c5e 100644 --- a/src/wormhole/util.py +++ b/src/wormhole/util.py @@ -1,30 +1,42 @@ # No unicode_literals -import os, json, unicodedata +import json +import os +import unicodedata from binascii import hexlify, unhexlify + def to_bytes(u): return unicodedata.normalize("NFC", u).encode("utf-8") + + def bytes_to_hexstr(b): assert isinstance(b, type(b"")) hexstr = hexlify(b).decode("ascii") assert isinstance(hexstr, type(u"")) return hexstr + + def hexstr_to_bytes(hexstr): assert isinstance(hexstr, type(u"")) b = unhexlify(hexstr.encode("ascii")) assert isinstance(b, type(b"")) return b + + def dict_to_bytes(d): assert isinstance(d, dict) b = json.dumps(d).encode("utf-8") assert isinstance(b, type(b"")) return b + + def bytes_to_dict(b): assert isinstance(b, type(b"")) d = json.loads(b.decode("utf-8")) assert isinstance(d, dict) return d + def estimate_free_space(target): # f_bfree is the blocks available to a root user. It might be more # accurate to use f_bavail (blocks available to non-root user), but we diff --git a/src/wormhole/wormhole.py b/src/wormhole/wormhole.py index 311e3fc..610069c 100644 --- a/src/wormhole/wormhole.py +++ b/src/wormhole/wormhole.py @@ -1,19 +1,25 @@ -from __future__ import print_function, absolute_import, unicode_literals -import os, sys -from attr import attrs, attrib -from zope.interface import implementer +from __future__ import absolute_import, print_function, unicode_literals + +import os +import sys + +from attr import attrib, attrs from twisted.python import failure -from . import __version__ -from ._interfaces import IWormhole, IDeferredWormhole -from .util import bytes_to_hexstr -from .eventual import EventualQueue -from .observer import OneShotObserver, SequenceObserver -from .timing import DebugTiming -from .journal import ImmediateJournal +from zope.interface import implementer + from ._boss import Boss +from ._interfaces import IDeferredWormhole, IWormhole from ._key import derive_key from .errors import NoKeyError, WormholeClosed -from .util import to_bytes +from .eventual import EventualQueue +from .journal import ImmediateJournal +from .observer import OneShotObserver, SequenceObserver +from .timing import DebugTiming +from .util import bytes_to_hexstr, to_bytes +from ._version import get_versions + +__version__ = get_versions()['version'] +del get_versions # We can provide different APIs to different apps: # * Deferreds @@ -36,6 +42,7 @@ from .util import to_bytes # wormhole(delegate=app, delegate_prefix="wormhole_", # delegate_args=(args, kwargs)) + @attrs @implementer(IWormhole) class _DelegatedWormhole(object): @@ -51,16 +58,18 @@ class _DelegatedWormhole(object): def allocate_code(self, code_length=2): self._boss.allocate_code(code_length) + def input_code(self): return self._boss.input_code() + def set_code(self, code): self._boss.set_code(code) - ## def serialize(self): - ## s = {"serialized_wormhole_version": 1, - ## "boss": self._boss.serialize(), - ## } - ## return s + # def serialize(self): + # s = {"serialized_wormhole_version": 1, + # "boss": self._boss.serialize(), + # } + # return s def send_message(self, plaintext): self._boss.send(plaintext) @@ -72,34 +81,45 @@ class _DelegatedWormhole(object): cannot be called until when_verifier() has fired, nor after close() was called. """ - if not isinstance(purpose, type("")): raise TypeError(type(purpose)) - if not self._key: raise NoKeyError() + if not isinstance(purpose, type("")): + raise TypeError(type(purpose)) + if not self._key: + raise NoKeyError() return derive_key(self._key, to_bytes(purpose), length) def close(self): self._boss.close() - def debug_set_trace(self, client_name, which="B N M S O K SK R RC L C T", + def debug_set_trace(self, + client_name, + which="B N M S O K SK R RC L C T", file=sys.stderr): self._boss._set_trace(client_name, which, file) # from below def got_welcome(self, welcome): self._delegate.wormhole_got_welcome(welcome) + def got_code(self, code): self._delegate.wormhole_got_code(code) + def got_key(self, key): self._delegate.wormhole_got_unverified_key(key) - self._key = key # for derive_key() + self._key = key # for derive_key() + def got_verifier(self, verifier): self._delegate.wormhole_got_verifier(verifier) + def got_versions(self, versions): self._delegate.wormhole_got_versions(versions) + def received(self, plaintext): self._delegate.wormhole_got_message(plaintext) + def closed(self, result): self._delegate.wormhole_closed(result) + @implementer(IWormhole, IDeferredWormhole) class _DeferredWormhole(object): def __init__(self, eq): @@ -142,8 +162,10 @@ class _DeferredWormhole(object): def allocate_code(self, code_length=2): self._boss.allocate_code(code_length) + def input_code(self): return self._boss.input_code() + def set_code(self, code): self._boss.set_code(code) @@ -159,20 +181,23 @@ class _DeferredWormhole(object): cannot be called until when_verified() has fired, nor after close() was called. """ - if not isinstance(purpose, type("")): raise TypeError(type(purpose)) - if not self._key: raise NoKeyError() + if not isinstance(purpose, type("")): + raise TypeError(type(purpose)) + if not self._key: + raise NoKeyError() return derive_key(self._key, to_bytes(purpose), length) def close(self): # fails with WormholeError unless we established a connection # (state=="happy"). Fails with WrongPasswordError (a subclass of # WormholeError) if state=="scary". - d = self._closed_observer.when_fired() # maybe Failure + d = self._closed_observer.when_fired() # maybe Failure if not self._closed: - self._boss.close() # only need to close if it wasn't already + self._boss.close() # only need to close if it wasn't already return d - def debug_set_trace(self, client_name, + def debug_set_trace(self, + client_name, which="B N M S O K SK R RC L A I C T", file=sys.stderr): self._boss._set_trace(client_name, which, file) @@ -180,14 +205,17 @@ class _DeferredWormhole(object): # from below def got_welcome(self, welcome): self._welcome_observer.fire_if_not_fired(welcome) + def got_code(self, code): self._code_observer.fire_if_not_fired(code) + def got_key(self, key): - self._key = key # for derive_key() + self._key = key # for derive_key() self._key_observer.fire_if_not_fired(key) def got_verifier(self, verifier): self._verifier_observer.fire_if_not_fired(verifier) + def got_versions(self, versions): self._version_observer.fire_if_not_fired(versions) @@ -196,7 +224,7 @@ class _DeferredWormhole(object): def closed(self, result): self._closed = True - #print("closed", result, type(result), file=sys.stderr) + # print("closed", result, type(result), file=sys.stderr) if isinstance(result, Exception): # everything pending gets an error, including close() f = failure.Failure(result) @@ -215,12 +243,17 @@ class _DeferredWormhole(object): self._received_observer.fire(f) -def create(appid, relay_url, reactor, # use keyword args for everything else - versions={}, - delegate=None, journal=None, tor=None, - timing=None, - stderr=sys.stderr, - _eventual_queue=None): +def create( + appid, + relay_url, + reactor, # use keyword args for everything else + versions={}, + delegate=None, + journal=None, + tor=None, + timing=None, + stderr=sys.stderr, + _eventual_queue=None): timing = timing or DebugTiming() side = bytes_to_hexstr(os.urandom(5)) journal = journal or ImmediateJournal() @@ -229,27 +262,28 @@ def create(appid, relay_url, reactor, # use keyword args for everything else w = _DelegatedWormhole(delegate) else: w = _DeferredWormhole(eq) - wormhole_versions = {} # will be used to indicate Wormhole capabilities - wormhole_versions["app_versions"] = versions # app-specific capabilities + wormhole_versions = {} # will be used to indicate Wormhole capabilities + wormhole_versions["app_versions"] = versions # app-specific capabilities v = __version__ if isinstance(v, type(b"")): v = v.decode("utf-8", errors="replace") client_version = ("python", v) b = Boss(w, side, relay_url, appid, wormhole_versions, client_version, - reactor, journal, tor, timing) + reactor, journal, tor, timing) w._set_boss(b) b.start() return w -## def from_serialized(serialized, reactor, delegate, -## journal=None, tor=None, -## timing=None, stderr=sys.stderr): -## assert serialized["serialized_wormhole_version"] == 1 -## timing = timing or DebugTiming() -## w = _DelegatedWormhole(delegate) -## # now unpack state machines, including the SPAKE2 in Key -## b = Boss.from_serialized(w, serialized["boss"], reactor, journal, timing) -## w._set_boss(b) -## b.start() # ?? -## raise NotImplemented -## # should the new Wormhole call got_code? only if it wasn't called before. + +# def from_serialized(serialized, reactor, delegate, +# journal=None, tor=None, +# timing=None, stderr=sys.stderr): +# assert serialized["serialized_wormhole_version"] == 1 +# timing = timing or DebugTiming() +# w = _DelegatedWormhole(delegate) +# # now unpack state machines, including the SPAKE2 in Key +# b = Boss.from_serialized(w, serialized["boss"], reactor, journal, timing) +# w._set_boss(b) +# b.start() # ?? +# raise NotImplemented +# # should the new Wormhole call got_code? only if it wasn't called before. diff --git a/src/wormhole/xfer_util.py b/src/wormhole/xfer_util.py index c62aa12..465c05d 100644 --- a/src/wormhole/xfer_util.py +++ b/src/wormhole/xfer_util.py @@ -1,12 +1,19 @@ import json + from twisted.internet.defer import inlineCallbacks, returnValue from . import wormhole from .tor_manager import get_tor + @inlineCallbacks -def receive(reactor, appid, relay_url, code, - use_tor=False, launch_tor=False, tor_control_port=None, +def receive(reactor, + appid, + relay_url, + code, + use_tor=False, + launch_tor=False, + tor_control_port=None, on_code=None): """ This is a convenience API which returns a Deferred that callbacks @@ -21,9 +28,11 @@ def receive(reactor, appid, relay_url, code, :param unicode code: a pre-existing code to use, or None - :param bool use_tor: True if we should use Tor, False to not use it (None for default) + :param bool use_tor: True if we should use Tor, False to not use it (None + for default) - :param on_code: if not None, this is called when we have a code (even if you passed in one explicitly) + :param on_code: if not None, this is called when we have a code (even if + you passed in one explicitly) :type on_code: single-argument callable """ tor = None @@ -48,27 +57,33 @@ def receive(reactor, appid, relay_url, code, data = json.loads(data.decode("utf-8")) offer = data.get('offer', None) if not offer: - raise Exception( - "Do not understand response: {}".format(data) - ) + raise Exception("Do not understand response: {}".format(data)) msg = None if 'message' in offer: msg = offer['message'] - wh.send_message(json.dumps({"answer": - {"message_ack": "ok"}}).encode("utf-8")) + wh.send_message( + json.dumps({ + "answer": { + "message_ack": "ok" + } + }).encode("utf-8")) else: - raise Exception( - "Unknown offer type: {}".format(offer.keys()) - ) + raise Exception("Unknown offer type: {}".format(offer.keys())) yield wh.close() returnValue(msg) @inlineCallbacks -def send(reactor, appid, relay_url, data, code, - use_tor=False, launch_tor=False, tor_control_port=None, +def send(reactor, + appid, + relay_url, + data, + code, + use_tor=False, + launch_tor=False, + tor_control_port=None, on_code=None): """ This is a convenience API which returns a Deferred that callbacks @@ -83,9 +98,12 @@ def send(reactor, appid, relay_url, data, code, :param unicode code: a pre-existing code to use, or None - :param bool use_tor: True if we should use Tor, False to not use it (None for default) + :param bool use_tor: True if we should use Tor, False to not use it (None + for default) + + :param on_code: if not None, this is called when we have a code (even if + you passed in one explicitly) - :param on_code: if not None, this is called when we have a code (even if you passed in one explicitly) :type on_code: single-argument callable """ tor = None @@ -104,13 +122,7 @@ def send(reactor, appid, relay_url, data, code, if on_code: on_code(code) - wh.send_message( - json.dumps({ - "offer": { - "message": data - } - }).encode("utf-8") - ) + wh.send_message(json.dumps({"offer": {"message": data}}).encode("utf-8")) data = yield wh.get_message() data = json.loads(data.decode("utf-8")) answer = data.get('answer', None) @@ -118,6 +130,4 @@ def send(reactor, appid, relay_url, data, code, if answer: returnValue(None) else: - raise Exception( - "Unknown answer: {}".format(data) - ) + raise Exception("Unknown answer: {}".format(data)) From aac5980bf452c0bec851417e0d3a38c0d00ee2ab Mon Sep 17 00:00:00 2001 From: Vasudev Kamath Date: Sat, 21 Apr 2018 13:00:25 +0530 Subject: [PATCH 2/7] Enable checking for pep8 confirmance in tox Also make sure to ignore E741 naming a variable as l will raise this error and reason is l is similar to 1 and people might get confused. For me it doesn't look like an error hence ignored in tox.ini --- tox.ini | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/tox.ini b/tox.ini index 1943025..7944deb 100644 --- a/tox.ini +++ b/tox.ini @@ -4,7 +4,7 @@ # and then run "tox" from this directory. [tox] -envlist = {py27,py34,py35,py36,pypy} +envlist = {py27,py34,py35,py36,pypy,flake8} skip_missing_interpreters = True minversion = 2.4.0 @@ -36,3 +36,12 @@ commands = wormhole --version coverage run --branch -m wormhole.test.run_trial {posargs:wormhole} coverage xml + +[testenv:flake8] +deps = flake8 +commands = flake8 src/wormhole + +[flake8] +ignore = E741 +exclude = .git,__pycache__,docs/source/conf.py,old,build,dist +max-complexity = 10 \ No newline at end of file From b216c0adf3512d677ed6e6f016ee1dcee883e906 Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Sat, 16 Jun 2018 16:13:22 -0700 Subject: [PATCH 3/7] tox: increase flake8 complexity limit to 40, we can lower later --- tox.ini | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tox.ini b/tox.ini index 7944deb..a9f863f 100644 --- a/tox.ini +++ b/tox.ini @@ -44,4 +44,4 @@ commands = flake8 src/wormhole [flake8] ignore = E741 exclude = .git,__pycache__,docs/source/conf.py,old,build,dist -max-complexity = 10 \ No newline at end of file +max-complexity = 40 From dada79d85cf7d896ed2d4813268fec913c08ffbf Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Sat, 16 Jun 2018 16:17:26 -0700 Subject: [PATCH 4/7] fix remaining pep8 complaints --- src/wormhole/cli/welcome.py | 4 ++-- src/wormhole/test/test_cli.py | 2 +- src/wormhole/transit.py | 8 ++++---- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/wormhole/cli/welcome.py b/src/wormhole/cli/welcome.py index 2bbfc7e..32c1f84 100644 --- a/src/wormhole/cli/welcome.py +++ b/src/wormhole/cli/welcome.py @@ -11,8 +11,8 @@ def handle_welcome(welcome, relay_url, my_version, stderr): # Only warn if we're running a release version (e.g. 0.0.6, not # 0.0.6+DISTANCE.gHASH). Only warn once. - if ("current_cli_version" in welcome and "+" not in my_version - and welcome["current_cli_version"] != my_version): + if ("current_cli_version" in welcome and "+" not in my_version and + welcome["current_cli_version"] != my_version): print( ("Warning: errors may occur unless both sides are running the" " same version"), diff --git a/src/wormhole/test/test_cli.py b/src/wormhole/test/test_cli.py index a5d5478..d983d51 100644 --- a/src/wormhole/test/test_cli.py +++ b/src/wormhole/test/test_cli.py @@ -684,7 +684,7 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase): self.failUnlessEqual(receive_stdout, "") self.failUnlessIn(u"Receiving file ({size:s}) into: {name}".format( size=naturalsize(len(message)), name=receive_filename), - receive_stderr) + receive_stderr) self.failUnlessIn(u"Received file written to ", receive_stderr) fn = os.path.join(receive_dir, receive_filename) self.failUnless(os.path.exists(fn)) diff --git a/src/wormhole/transit.py b/src/wormhole/transit.py index 6363c6d..d2b2c98 100644 --- a/src/wormhole/transit.py +++ b/src/wormhole/transit.py @@ -751,12 +751,12 @@ class Common: if hint_type not in [u"direct-tcp-v1", u"tor-tcp-v1"]: log.msg("unknown hint type: %r" % (hint, )) return None - if not (u"hostname" in hint - and isinstance(hint[u"hostname"], type(u""))): + if not (u"hostname" in hint and + isinstance(hint[u"hostname"], type(u""))): log.msg("invalid hostname in hint: %r" % (hint, )) return None - if not (u"port" in hint - and isinstance(hint[u"port"], six.integer_types)): + if not (u"port" in hint and + isinstance(hint[u"port"], six.integer_types)): log.msg("invalid port in hint: %r" % (hint, )) return None priority = hint.get(u"priority", 0.0) From 1444e32746f6ac9f420f9ff187ce2381b5aebcdb Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Sat, 16 Jun 2018 16:22:14 -0700 Subject: [PATCH 5/7] extreme measures to appease last pep8 complaint a singly-parenthesized 'if' condition will always line up with the 'then' body, won't it --- src/wormhole/cli/welcome.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/wormhole/cli/welcome.py b/src/wormhole/cli/welcome.py index 32c1f84..2380ce8 100644 --- a/src/wormhole/cli/welcome.py +++ b/src/wormhole/cli/welcome.py @@ -11,8 +11,9 @@ def handle_welcome(welcome, relay_url, my_version, stderr): # Only warn if we're running a release version (e.g. 0.0.6, not # 0.0.6+DISTANCE.gHASH). Only warn once. - if ("current_cli_version" in welcome and "+" not in my_version and - welcome["current_cli_version"] != my_version): + if (("current_cli_version" in welcome and + "+" not in my_version and + welcome["current_cli_version"] != my_version)): print( ("Warning: errors may occur unless both sides are running the" " same version"), From fc177726e1779eb8408abed0593587da591fab4d Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Sat, 16 Jun 2018 16:27:11 -0700 Subject: [PATCH 6/7] cli.py: move timing check back to top We care about how long it takes to import all the wormhole-specific things, to investigate user-perceived latency from the time the command is launched to the time they can actually interact with it. So we need to record `time.time()` before doing the rest of the imports, even though pep8 says all imports should be done before any non-importing statements. --- src/wormhole/cli/cli.py | 27 ++++++++++++++------------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/src/wormhole/cli/cli.py b/src/wormhole/cli/cli.py index ba74157..292d662 100644 --- a/src/wormhole/cli/cli.py +++ b/src/wormhole/cli/cli.py @@ -2,23 +2,24 @@ from __future__ import print_function import os import time -from sys import stderr, stdout -from textwrap import dedent, fill +start = time.time() -import click -import six -from twisted.internet.defer import inlineCallbacks, maybeDeferred -from twisted.internet.task import react -from twisted.python.failure import Failure +from sys import stderr, stdout # noqa: E402 +from textwrap import dedent, fill # noqa: E402 -from . import public_relay -from .. import __version__ -from ..errors import (KeyFormatError, NoTorError, ServerConnectionError, +import click # noqa: E402 +import six # noqa: E402 +from twisted.internet.defer import inlineCallbacks, maybeDeferred # noqa: E402 +from twisted.internet.task import react # noqa: E402 +from twisted.python.failure import Failure # noqa: E402 + +from . import public_relay # noqa: E402 +from .. import __version__ # noqa: E402 +from ..errors import (KeyFormatError, NoTorError, # noqa: E402 + ServerConnectionError, TransferError, UnsendableFileError, WelcomeError, WrongPasswordError) -from ..timing import DebugTiming - -start = time.time() +from ..timing import DebugTiming # noqa: E402 top_import_finish = time.time() From 41fabd39ba0144475877df5c3940cff8dd352a3b Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Sat, 16 Jun 2018 16:33:27 -0700 Subject: [PATCH 7/7] test_machines: remove no-longer relevant comment The flake8 config excludes E741, which would complain about using 'l' (lower-case ell) as a variable name. We use this for the Lister object in one test that uses single-character variable names for all the machines ('b' for Boss, 'm' for Mailbox, etc). That comment was added before excluding E741. If we ever restore that warning, we might want to rename the variable. --- src/wormhole/test/test_machines.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/wormhole/test/test_machines.py b/src/wormhole/test/test_machines.py index 38e351e..8227580 100644 --- a/src/wormhole/test/test_machines.py +++ b/src/wormhole/test/test_machines.py @@ -398,7 +398,6 @@ class Input(unittest.TestCase): events = [] i = _input.Input(timing.DebugTiming()) c = Dummy("c", events, ICode, "got_nameplate", "finished_input") - # renamed from l as l is indistinguishable from 1 in some fonts. l = Dummy("l", events, ILister, "refresh") i.wire(c, l) return i, c, l, events