diff --git a/src/wormhole/__init__.py b/src/wormhole/__init__.py index c00af56..0cc6a01 100644 --- a/src/wormhole/__init__.py +++ b/src/wormhole/__init__.py @@ -1,9 +1,4 @@ - -from ._version import get_versions -__version__ = get_versions()['version'] -del get_versions - -from .wormhole import create from ._rlcompleter import input_with_completion +from .wormhole import create, __version__ __all__ = ["create", "input_with_completion", "__version__"] diff --git a/src/wormhole/__main__.py b/src/wormhole/__main__.py index a47f2e5..568abf3 100644 --- a/src/wormhole/__main__.py +++ b/src/wormhole/__main__.py @@ -1,8 +1,6 @@ +from .cli import cli + if __name__ != "__main__": raise ImportError('this module should not be imported') - -from .cli import cli - - cli.wormhole() diff --git a/src/wormhole/_allocator.py b/src/wormhole/_allocator.py index 3c98f10..554b101 100644 --- a/src/wormhole/_allocator.py +++ b/src/wormhole/_allocator.py @@ -1,56 +1,78 @@ -from __future__ import print_function, absolute_import, unicode_literals -from zope.interface import implementer -from attr import attrs, attrib +from __future__ import absolute_import, print_function, unicode_literals + +from attr import attrib, attrs from attr.validators import provides from automat import MethodicalMachine +from zope.interface import implementer + from . import _interfaces + @attrs @implementer(_interfaces.IAllocator) class Allocator(object): _timing = attrib(validator=provides(_interfaces.ITiming)) m = MethodicalMachine() - set_trace = getattr(m, "_setTrace", lambda self, f: None) # pragma: no cover + set_trace = getattr(m, "_setTrace", + lambda self, f: None) # pragma: no cover def wire(self, rendezvous_connector, code): self._RC = _interfaces.IRendezvousConnector(rendezvous_connector) self._C = _interfaces.ICode(code) @m.state(initial=True) - def S0A_idle(self): pass # pragma: no cover + def S0A_idle(self): + pass # pragma: no cover + @m.state() - def S0B_idle_connected(self): pass # pragma: no cover + def S0B_idle_connected(self): + pass # pragma: no cover + @m.state() - def S1A_allocating(self): pass # pragma: no cover + def S1A_allocating(self): + pass # pragma: no cover + @m.state() - def S1B_allocating_connected(self): pass # pragma: no cover + def S1B_allocating_connected(self): + pass # pragma: no cover + @m.state() - def S2_done(self): pass # pragma: no cover + def S2_done(self): + pass # pragma: no cover # from Code @m.input() - def allocate(self, length, wordlist): pass + def allocate(self, length, wordlist): + pass # from RendezvousConnector @m.input() - def connected(self): pass + def connected(self): + pass + @m.input() - def lost(self): pass + def lost(self): + pass + @m.input() - def rx_allocated(self, nameplate): pass + def rx_allocated(self, nameplate): + pass @m.output() def stash(self, length, wordlist): self._length = length self._wordlist = _interfaces.IWordlist(wordlist) + @m.output() def stash_and_RC_rx_allocate(self, length, wordlist): self._length = length self._wordlist = _interfaces.IWordlist(wordlist) self._RC.tx_allocate() + @m.output() def RC_tx_allocate(self): self._RC.tx_allocate() + @m.output() def build_and_notify(self, nameplate): words = self._wordlist.choose_words(self._length) @@ -61,15 +83,17 @@ class Allocator(object): S0B_idle_connected.upon(lost, enter=S0A_idle, outputs=[]) S0A_idle.upon(allocate, enter=S1A_allocating, outputs=[stash]) - S0B_idle_connected.upon(allocate, enter=S1B_allocating_connected, - outputs=[stash_and_RC_rx_allocate]) + S0B_idle_connected.upon( + allocate, + enter=S1B_allocating_connected, + outputs=[stash_and_RC_rx_allocate]) - S1A_allocating.upon(connected, enter=S1B_allocating_connected, - outputs=[RC_tx_allocate]) + S1A_allocating.upon( + connected, enter=S1B_allocating_connected, outputs=[RC_tx_allocate]) S1B_allocating_connected.upon(lost, enter=S1A_allocating, outputs=[]) - S1B_allocating_connected.upon(rx_allocated, enter=S2_done, - outputs=[build_and_notify]) + S1B_allocating_connected.upon( + rx_allocated, enter=S2_done, outputs=[build_and_notify]) S2_done.upon(connected, enter=S2_done, outputs=[]) S2_done.upon(lost, enter=S2_done, outputs=[]) diff --git a/src/wormhole/_boss.py b/src/wormhole/_boss.py index c99bb3d..097efe7 100644 --- a/src/wormhole/_boss.py +++ b/src/wormhole/_boss.py @@ -1,29 +1,33 @@ -from __future__ import print_function, absolute_import, unicode_literals +from __future__ import absolute_import, print_function, unicode_literals + import re + import six -from zope.interface import implementer -from attr import attrs, attrib -from attr.validators import provides, instance_of, optional -from twisted.python import log +from attr import attrib, attrs +from attr.validators import instance_of, optional, provides from automat import MethodicalMachine +from twisted.python import log +from zope.interface import implementer + from . import _interfaces -from ._nameplate import Nameplate -from ._mailbox import Mailbox -from ._send import Send -from ._order import Order +from ._allocator import Allocator +from ._code import Code, validate_code +from ._input import Input from ._key import Key +from ._lister import Lister +from ._mailbox import Mailbox +from ._nameplate import Nameplate +from ._order import Order from ._receive import Receive from ._rendezvous import RendezvousConnector -from ._lister import Lister -from ._allocator import Allocator -from ._input import Input -from ._code import Code, validate_code +from ._send import Send from ._terminator import Terminator from ._wordlist import PGPWordList -from .errors import (ServerError, LonelyError, WrongPasswordError, - OnlyOneCodeError, _UnknownPhaseError, WelcomeError) +from .errors import (LonelyError, OnlyOneCodeError, ServerError, WelcomeError, + WrongPasswordError, _UnknownPhaseError) from .util import bytes_to_dict + @attrs @implementer(_interfaces.IBoss) class Boss(object): @@ -38,7 +42,8 @@ class Boss(object): _tor = attrib(validator=optional(provides(_interfaces.ITorManager))) _timing = attrib(validator=provides(_interfaces.ITiming)) m = MethodicalMachine() - set_trace = getattr(m, "_setTrace", lambda self, f: None) # pragma: no cover + set_trace = getattr(m, "_setTrace", + lambda self, f: None) # pragma: no cover def __attrs_post_init__(self): self._build_workers() @@ -52,9 +57,8 @@ class Boss(object): self._K = Key(self._appid, self._versions, self._side, self._timing) self._R = Receive(self._side, self._timing) self._RC = RendezvousConnector(self._url, self._appid, self._side, - self._reactor, self._journal, - self._tor, self._timing, - self._client_version) + self._reactor, self._journal, self._tor, + self._timing, self._client_version) self._L = Lister(self._timing) self._A = Allocator(self._timing) self._I = Input(self._timing) @@ -78,7 +82,7 @@ class Boss(object): self._did_start_code = False self._next_tx_phase = 0 self._next_rx_phase = 0 - self._rx_phases = {} # phase -> plaintext + self._rx_phases = {} # phase -> plaintext self._result = "empty" @@ -86,32 +90,45 @@ class Boss(object): def start(self): self._RC.start() - def _print_trace(self, old_state, input, new_state, - client_name, machine, file): + def _print_trace(self, old_state, input, new_state, client_name, machine, + file): if new_state: - print("%s.%s[%s].%s -> [%s]" % - (client_name, machine, old_state, input, - new_state), file=file) + print( + "%s.%s[%s].%s -> [%s]" % (client_name, machine, old_state, + input, new_state), + file=file) else: # the RendezvousConnector emits message events as if # they were state transitions, except that old_state # and new_state are empty strings. "input" is one of # R.connected, R.rx(type phase+side), R.tx(type # phase), R.lost . - print("%s.%s.%s" % (client_name, machine, input), - file=file) + print("%s.%s.%s" % (client_name, machine, input), file=file) file.flush() + def output_tracer(output): - print(" %s.%s.%s()" % (client_name, machine, output), - file=file) + print(" %s.%s.%s()" % (client_name, machine, output), file=file) file.flush() + return output_tracer def _set_trace(self, client_name, which, file): - names = {"B": self, "N": self._N, "M": self._M, "S": self._S, - "O": self._O, "K": self._K, "SK": self._K._SK, "R": self._R, - "RC": self._RC, "L": self._L, "A": self._A, "I": self._I, - "C": self._C, "T": self._T} + names = { + "B": self, + "N": self._N, + "M": self._M, + "S": self._S, + "O": self._O, + "K": self._K, + "SK": self._K._SK, + "R": self._R, + "RC": self._RC, + "L": self._L, + "A": self._A, + "I": self._I, + "C": self._C, + "T": self._T + } for machine in which.split(): t = (lambda old_state, input, new_state, machine=machine: self._print_trace(old_state, input, new_state, @@ -121,21 +138,30 @@ class Boss(object): if machine == "I": self._I.set_debug(t) - ## def serialize(self): - ## raise NotImplemented + # def serialize(self): + # raise NotImplemented # and these are the state-machine transition functions, which don't take # args @m.state(initial=True) - def S0_empty(self): pass # pragma: no cover + def S0_empty(self): + pass # pragma: no cover + @m.state() - def S1_lonely(self): pass # pragma: no cover + def S1_lonely(self): + pass # pragma: no cover + @m.state() - def S2_happy(self): pass # pragma: no cover + def S2_happy(self): + pass # pragma: no cover + @m.state() - def S3_closing(self): pass # pragma: no cover + def S3_closing(self): + pass # pragma: no cover + @m.state(terminal=True) - def S4_closed(self): pass # pragma: no cover + def S4_closed(self): + pass # pragma: no cover # from the Wormhole @@ -155,23 +181,28 @@ class Boss(object): raise OnlyOneCodeError() self._did_start_code = True return self._C.input_code() + def allocate_code(self, code_length): if self._did_start_code: raise OnlyOneCodeError() self._did_start_code = True wl = PGPWordList() self._C.allocate_code(code_length, wl) + def set_code(self, code): - validate_code(code) # can raise KeyFormatError + validate_code(code) # can raise KeyFormatError if self._did_start_code: raise OnlyOneCodeError() self._did_start_code = True self._C.set_code(code) @m.input() - def send(self, plaintext): pass + def send(self, plaintext): + pass + @m.input() - def close(self): pass + def close(self): + pass # from RendezvousConnector: # * "rx_welcome" is the Welcome message, which might signal an error, or @@ -190,26 +221,36 @@ class Boss(object): # delivering a new input (rx_error or something) while in the # middle of processing the rx_welcome input, and I wasn't sure # Automat would handle that correctly. - self._W.got_welcome(welcome) # TODO: let this raise WelcomeError? + self._W.got_welcome(welcome) # TODO: let this raise WelcomeError? except WelcomeError as welcome_error: self.rx_unwelcome(welcome_error) + @m.input() - def rx_unwelcome(self, welcome_error): pass + def rx_unwelcome(self, welcome_error): + pass + @m.input() - def rx_error(self, errmsg, orig): pass + def rx_error(self, errmsg, orig): + pass + @m.input() - def error(self, err): pass + def error(self, err): + pass # from Code (provoked by input/allocate/set_code) @m.input() - def got_code(self, code): pass + def got_code(self, code): + pass # Key sends (got_key, scared) # Receive sends (got_message, happy, got_verifier, scared) @m.input() - def happy(self): pass + def happy(self): + pass + @m.input() - def scared(self): pass + def scared(self): + pass def got_message(self, phase, plaintext): assert isinstance(phase, type("")), type(phase) @@ -222,22 +263,32 @@ class Boss(object): # Ignore unrecognized phases, for forwards-compatibility. Use # log.err so tests will catch surprises. log.err(_UnknownPhaseError("received unknown phase '%s'" % phase)) + @m.input() - def _got_version(self, plaintext): pass + def _got_version(self, plaintext): + pass + @m.input() - def _got_phase(self, phase, plaintext): pass + def _got_phase(self, phase, plaintext): + pass + @m.input() - def got_key(self, key): pass + def got_key(self, key): + pass + @m.input() - def got_verifier(self, verifier): pass + def got_verifier(self, verifier): + pass # Terminator sends closed @m.input() - def closed(self): pass + def closed(self): + pass @m.output() def do_got_code(self, code): self._W.got_code(code) + @m.output() def process_version(self, plaintext): # most of this is wormhole-to-wormhole, ignored for now @@ -256,21 +307,25 @@ class Boss(object): @m.output() def close_unwelcome(self, welcome_error): - #assert isinstance(err, WelcomeError) + # assert isinstance(err, WelcomeError) self._result = welcome_error self._T.close("unwelcome") + @m.output() def close_error(self, errmsg, orig): self._result = ServerError(errmsg) self._T.close("errory") + @m.output() def close_scared(self): self._result = WrongPasswordError() self._T.close("scary") + @m.output() def close_lonely(self): self._result = LonelyError() self._T.close("lonely") + @m.output() def close_happy(self): self._result = "happy" @@ -279,9 +334,11 @@ class Boss(object): @m.output() def W_got_key(self, key): self._W.got_key(key) + @m.output() def W_got_verifier(self, verifier): self._W.got_verifier(verifier) + @m.output() def W_received(self, phase, plaintext): assert isinstance(phase, six.integer_types), type(phase) @@ -293,7 +350,7 @@ class Boss(object): @m.output() def W_close_with_error(self, err): - self._result = err # exception + self._result = err # exception self._W.closed(self._result) @m.output() diff --git a/src/wormhole/_code.py b/src/wormhole/_code.py index 96314b2..9736698 100644 --- a/src/wormhole/_code.py +++ b/src/wormhole/_code.py @@ -7,21 +7,25 @@ from . import _interfaces from ._nameplate import validate_nameplate from .errors import KeyFormatError + def validate_code(code): if ' ' in code: raise KeyFormatError("Code '%s' contains spaces." % code) nameplate = code.split("-", 2)[0] - validate_nameplate(nameplate) # can raise KeyFormatError + validate_nameplate(nameplate) # can raise KeyFormatError + def first(outputs): return list(outputs)[0] + @attrs @implementer(_interfaces.ICode) class Code(object): _timing = attrib(validator=provides(_interfaces.ITiming)) m = MethodicalMachine() - set_trace = getattr(m, "_setTrace", lambda self, f: None) # pragma: no cover + set_trace = getattr(m, "_setTrace", + lambda self, f: None) # pragma: no cover def wire(self, boss, allocator, nameplate, key, input): self._B = _interfaces.IBoss(boss) @@ -31,36 +35,55 @@ class Code(object): self._I = _interfaces.IInput(input) @m.state(initial=True) - def S0_idle(self): pass # pragma: no cover + def S0_idle(self): + pass # pragma: no cover + @m.state() - def S1_inputting_nameplate(self): pass # pragma: no cover + def S1_inputting_nameplate(self): + pass # pragma: no cover + @m.state() - def S2_inputting_words(self): pass # pragma: no cover + def S2_inputting_words(self): + pass # pragma: no cover + @m.state() - def S3_allocating(self): pass # pragma: no cover + def S3_allocating(self): + pass # pragma: no cover + @m.state() - def S4_known(self): pass # pragma: no cover + def S4_known(self): + pass # pragma: no cover # from App @m.input() - def allocate_code(self, length, wordlist): pass + def allocate_code(self, length, wordlist): + pass + @m.input() - def input_code(self): pass + def input_code(self): + pass + def set_code(self, code): - validate_code(code) # can raise KeyFormatError + validate_code(code) # can raise KeyFormatError self._set_code(code) + @m.input() - def _set_code(self, code): pass + def _set_code(self, code): + pass # from Allocator @m.input() - def allocated(self, nameplate, code): pass + def allocated(self, nameplate, code): + pass # from Input @m.input() - def got_nameplate(self, nameplate): pass + def got_nameplate(self, nameplate): + pass + @m.input() - def finished_input(self, code): pass + def finished_input(self, code): + pass @m.output() def do_set_code(self, code): @@ -72,9 +95,11 @@ class Code(object): @m.output() def do_start_input(self): return self._I.start() + @m.output() def do_middle_input(self, nameplate): self._N.set_nameplate(nameplate) + @m.output() def do_finish_input(self, code): self._B.got_code(code) @@ -83,19 +108,24 @@ class Code(object): @m.output() def do_start_allocate(self, length, wordlist): self._A.allocate(length, wordlist) + @m.output() def do_finish_allocate(self, nameplate, code): - assert code.startswith(nameplate+"-"), (nameplate, code) + assert code.startswith(nameplate + "-"), (nameplate, code) self._N.set_nameplate(nameplate) self._B.got_code(code) self._K.got_code(code) S0_idle.upon(_set_code, enter=S4_known, outputs=[do_set_code]) - S0_idle.upon(input_code, enter=S1_inputting_nameplate, - outputs=[do_start_input], collector=first) - S1_inputting_nameplate.upon(got_nameplate, enter=S2_inputting_words, - outputs=[do_middle_input]) - S2_inputting_words.upon(finished_input, enter=S4_known, - outputs=[do_finish_input]) - S0_idle.upon(allocate_code, enter=S3_allocating, outputs=[do_start_allocate]) + S0_idle.upon( + input_code, + enter=S1_inputting_nameplate, + outputs=[do_start_input], + collector=first) + S1_inputting_nameplate.upon( + got_nameplate, enter=S2_inputting_words, outputs=[do_middle_input]) + S2_inputting_words.upon( + finished_input, enter=S4_known, outputs=[do_finish_input]) + S0_idle.upon( + allocate_code, enter=S3_allocating, outputs=[do_start_allocate]) S3_allocating.upon(allocated, enter=S4_known, outputs=[do_finish_allocate]) diff --git a/src/wormhole/_input.py b/src/wormhole/_input.py index 7c4f74a..7993fda 100644 --- a/src/wormhole/_input.py +++ b/src/wormhole/_input.py @@ -1,25 +1,31 @@ -from __future__ import print_function, absolute_import, unicode_literals +from __future__ import absolute_import, print_function, unicode_literals + # We use 'threading' defensively here, to detect if we're being called from a # non-main thread. _rlcompleter.py is the only internal Wormhole code that # deliberately creates a new thread. import threading -from zope.interface import implementer -from attr import attrs, attrib + +from attr import attrib, attrs from attr.validators import provides -from twisted.internet import defer from automat import MethodicalMachine +from twisted.internet import defer +from zope.interface import implementer + from . import _interfaces, errors from ._nameplate import validate_nameplate + def first(outputs): return list(outputs)[0] + @attrs @implementer(_interfaces.IInput) class Input(object): _timing = attrib(validator=provides(_interfaces.ITiming)) m = MethodicalMachine() - set_trace = getattr(m, "_setTrace", lambda self, f: None) # pragma: no cover + set_trace = getattr(m, "_setTrace", + lambda self, f: None) # pragma: no cover def __attrs_post_init__(self): self._all_nameplates = set() @@ -30,7 +36,8 @@ class Input(object): def set_debug(self, f): self._trace = f - def _debug(self, what): # pragma: no cover + + def _debug(self, what): # pragma: no cover if self._trace: self._trace(old_state="", input=what, new_state="") @@ -46,55 +53,80 @@ class Input(object): return d @m.state(initial=True) - def S0_idle(self): pass # pragma: no cover + def S0_idle(self): + pass # pragma: no cover + @m.state() - def S1_typing_nameplate(self): pass # pragma: no cover + def S1_typing_nameplate(self): + pass # pragma: no cover + @m.state() - def S2_typing_code_no_wordlist(self): pass # pragma: no cover + def S2_typing_code_no_wordlist(self): + pass # pragma: no cover + @m.state() - def S3_typing_code_yes_wordlist(self): pass # pragma: no cover + def S3_typing_code_yes_wordlist(self): + pass # pragma: no cover + @m.state(terminal=True) - def S4_done(self): pass # pragma: no cover + def S4_done(self): + pass # pragma: no cover # from Code @m.input() - def start(self): pass + def start(self): + pass # from Lister @m.input() - def got_nameplates(self, all_nameplates): pass + def got_nameplates(self, all_nameplates): + pass # from Nameplate @m.input() - def got_wordlist(self, wordlist): pass + def got_wordlist(self, wordlist): + pass # API provided to app as ICodeInputHelper @m.input() - def refresh_nameplates(self): pass + def refresh_nameplates(self): + pass + @m.input() - def get_nameplate_completions(self, prefix): pass + def get_nameplate_completions(self, prefix): + pass + def choose_nameplate(self, nameplate): - validate_nameplate(nameplate) # can raise KeyFormatError + validate_nameplate(nameplate) # can raise KeyFormatError self._choose_nameplate(nameplate) + @m.input() - def _choose_nameplate(self, nameplate): pass + def _choose_nameplate(self, nameplate): + pass + @m.input() - def get_word_completions(self, prefix): pass + def get_word_completions(self, prefix): + pass + @m.input() - def choose_words(self, words): pass + def choose_words(self, words): + pass @m.output() def do_start(self): self._start_timing = self._timing.add("input code", waiting="user") self._L.refresh() return Helper(self) + @m.output() def do_refresh(self): self._L.refresh() + @m.output() def record_nameplates(self, all_nameplates): # we get a set of nameplate id strings self._all_nameplates = all_nameplates + @m.output() def _get_nameplate_completions(self, prefix): completions = set() @@ -102,17 +134,20 @@ class Input(object): if nameplate.startswith(prefix): # TODO: it's a little weird that Input is responsible for the # hyphen on nameplates, but WordList owns it for words - completions.add(nameplate+"-") + completions.add(nameplate + "-") return completions + @m.output() def record_all_nameplates(self, nameplate): self._nameplate = nameplate self._C.got_nameplate(nameplate) + @m.output() def record_wordlist(self, wordlist): from ._rlcompleter import debug debug(" -record_wordlist") self._wordlist = wordlist + @m.output() def notify_wordlist_waiters(self, wordlist): while self._wordlist_waiters: @@ -122,6 +157,7 @@ class Input(object): @m.output() def no_word_completions(self, prefix): return set() + @m.output() def _get_word_completions(self, prefix): assert self._wordlist @@ -130,21 +166,27 @@ class Input(object): @m.output() def raise_must_choose_nameplate1(self, prefix): raise errors.MustChooseNameplateFirstError() + @m.output() def raise_must_choose_nameplate2(self, words): raise errors.MustChooseNameplateFirstError() + @m.output() def raise_already_chose_nameplate1(self): raise errors.AlreadyChoseNameplateError() + @m.output() def raise_already_chose_nameplate2(self, prefix): raise errors.AlreadyChoseNameplateError() + @m.output() def raise_already_chose_nameplate3(self, nameplate): raise errors.AlreadyChoseNameplateError() + @m.output() def raise_already_chose_words1(self, prefix): raise errors.AlreadyChoseWordsError() + @m.output() def raise_already_chose_words2(self, words): raise errors.AlreadyChoseWordsError() @@ -155,88 +197,110 @@ class Input(object): self._start_timing.finish() self._C.finished_input(code) - S0_idle.upon(start, enter=S1_typing_nameplate, - outputs=[do_start], collector=first) + S0_idle.upon( + start, enter=S1_typing_nameplate, outputs=[do_start], collector=first) # wormholes that don't use input_code (i.e. they use allocate_code or # generate_code) will never start() us, but Nameplate will give us a # wordlist anyways (as soon as the nameplate is claimed), so handle it. - S0_idle.upon(got_wordlist, enter=S0_idle, outputs=[record_wordlist, - notify_wordlist_waiters]) - S1_typing_nameplate.upon(got_nameplates, enter=S1_typing_nameplate, - outputs=[record_nameplates]) + S0_idle.upon( + got_wordlist, + enter=S0_idle, + outputs=[record_wordlist, notify_wordlist_waiters]) + S1_typing_nameplate.upon( + got_nameplates, enter=S1_typing_nameplate, outputs=[record_nameplates]) # but wormholes that *do* use input_code should not get got_wordlist # until after we tell Code that we got_nameplate, which is the earliest # it can be claimed - S1_typing_nameplate.upon(refresh_nameplates, enter=S1_typing_nameplate, - outputs=[do_refresh]) - S1_typing_nameplate.upon(get_nameplate_completions, - enter=S1_typing_nameplate, - outputs=[_get_nameplate_completions], - collector=first) - S1_typing_nameplate.upon(_choose_nameplate, enter=S2_typing_code_no_wordlist, - outputs=[record_all_nameplates]) - S1_typing_nameplate.upon(get_word_completions, - enter=S1_typing_nameplate, - outputs=[raise_must_choose_nameplate1]) - S1_typing_nameplate.upon(choose_words, enter=S1_typing_nameplate, - outputs=[raise_must_choose_nameplate2]) + S1_typing_nameplate.upon( + refresh_nameplates, enter=S1_typing_nameplate, outputs=[do_refresh]) + S1_typing_nameplate.upon( + get_nameplate_completions, + enter=S1_typing_nameplate, + outputs=[_get_nameplate_completions], + collector=first) + S1_typing_nameplate.upon( + _choose_nameplate, + enter=S2_typing_code_no_wordlist, + outputs=[record_all_nameplates]) + S1_typing_nameplate.upon( + get_word_completions, + enter=S1_typing_nameplate, + outputs=[raise_must_choose_nameplate1]) + S1_typing_nameplate.upon( + choose_words, + enter=S1_typing_nameplate, + outputs=[raise_must_choose_nameplate2]) - S2_typing_code_no_wordlist.upon(got_nameplates, - enter=S2_typing_code_no_wordlist, outputs=[]) - S2_typing_code_no_wordlist.upon(got_wordlist, - enter=S3_typing_code_yes_wordlist, - outputs=[record_wordlist, - notify_wordlist_waiters]) - S2_typing_code_no_wordlist.upon(refresh_nameplates, - enter=S2_typing_code_no_wordlist, - outputs=[raise_already_chose_nameplate1]) - S2_typing_code_no_wordlist.upon(get_nameplate_completions, - enter=S2_typing_code_no_wordlist, - outputs=[raise_already_chose_nameplate2]) - S2_typing_code_no_wordlist.upon(_choose_nameplate, - enter=S2_typing_code_no_wordlist, - outputs=[raise_already_chose_nameplate3]) - S2_typing_code_no_wordlist.upon(get_word_completions, - enter=S2_typing_code_no_wordlist, - outputs=[no_word_completions], - collector=first) - S2_typing_code_no_wordlist.upon(choose_words, enter=S4_done, - outputs=[do_words]) + S2_typing_code_no_wordlist.upon( + got_nameplates, enter=S2_typing_code_no_wordlist, outputs=[]) + S2_typing_code_no_wordlist.upon( + got_wordlist, + enter=S3_typing_code_yes_wordlist, + outputs=[record_wordlist, notify_wordlist_waiters]) + S2_typing_code_no_wordlist.upon( + refresh_nameplates, + enter=S2_typing_code_no_wordlist, + outputs=[raise_already_chose_nameplate1]) + S2_typing_code_no_wordlist.upon( + get_nameplate_completions, + enter=S2_typing_code_no_wordlist, + outputs=[raise_already_chose_nameplate2]) + S2_typing_code_no_wordlist.upon( + _choose_nameplate, + enter=S2_typing_code_no_wordlist, + outputs=[raise_already_chose_nameplate3]) + S2_typing_code_no_wordlist.upon( + get_word_completions, + enter=S2_typing_code_no_wordlist, + outputs=[no_word_completions], + collector=first) + S2_typing_code_no_wordlist.upon( + choose_words, enter=S4_done, outputs=[do_words]) - S3_typing_code_yes_wordlist.upon(got_nameplates, - enter=S3_typing_code_yes_wordlist, - outputs=[]) + S3_typing_code_yes_wordlist.upon( + got_nameplates, enter=S3_typing_code_yes_wordlist, outputs=[]) # got_wordlist: should never happen - S3_typing_code_yes_wordlist.upon(refresh_nameplates, - enter=S3_typing_code_yes_wordlist, - outputs=[raise_already_chose_nameplate1]) - S3_typing_code_yes_wordlist.upon(get_nameplate_completions, - enter=S3_typing_code_yes_wordlist, - outputs=[raise_already_chose_nameplate2]) - S3_typing_code_yes_wordlist.upon(_choose_nameplate, - enter=S3_typing_code_yes_wordlist, - outputs=[raise_already_chose_nameplate3]) - S3_typing_code_yes_wordlist.upon(get_word_completions, - enter=S3_typing_code_yes_wordlist, - outputs=[_get_word_completions], - collector=first) - S3_typing_code_yes_wordlist.upon(choose_words, enter=S4_done, - outputs=[do_words]) + S3_typing_code_yes_wordlist.upon( + refresh_nameplates, + enter=S3_typing_code_yes_wordlist, + outputs=[raise_already_chose_nameplate1]) + S3_typing_code_yes_wordlist.upon( + get_nameplate_completions, + enter=S3_typing_code_yes_wordlist, + outputs=[raise_already_chose_nameplate2]) + S3_typing_code_yes_wordlist.upon( + _choose_nameplate, + enter=S3_typing_code_yes_wordlist, + outputs=[raise_already_chose_nameplate3]) + S3_typing_code_yes_wordlist.upon( + get_word_completions, + enter=S3_typing_code_yes_wordlist, + outputs=[_get_word_completions], + collector=first) + S3_typing_code_yes_wordlist.upon( + choose_words, enter=S4_done, outputs=[do_words]) S4_done.upon(got_nameplates, enter=S4_done, outputs=[]) S4_done.upon(got_wordlist, enter=S4_done, outputs=[]) - S4_done.upon(refresh_nameplates, - enter=S4_done, - outputs=[raise_already_chose_nameplate1]) - S4_done.upon(get_nameplate_completions, - enter=S4_done, - outputs=[raise_already_chose_nameplate2]) - S4_done.upon(_choose_nameplate, enter=S4_done, - outputs=[raise_already_chose_nameplate3]) - S4_done.upon(get_word_completions, enter=S4_done, - outputs=[raise_already_chose_words1]) - S4_done.upon(choose_words, enter=S4_done, - outputs=[raise_already_chose_words2]) + S4_done.upon( + refresh_nameplates, + enter=S4_done, + outputs=[raise_already_chose_nameplate1]) + S4_done.upon( + get_nameplate_completions, + enter=S4_done, + outputs=[raise_already_chose_nameplate2]) + S4_done.upon( + _choose_nameplate, + enter=S4_done, + outputs=[raise_already_chose_nameplate3]) + S4_done.upon( + get_word_completions, + enter=S4_done, + outputs=[raise_already_chose_words1]) + S4_done.upon( + choose_words, enter=S4_done, outputs=[raise_already_chose_words2]) + # we only expose the Helper to application code, not _Input @attrs @@ -250,20 +314,25 @@ class Helper(object): def refresh_nameplates(self): assert threading.current_thread().ident == self._main_thread self._input.refresh_nameplates() + def get_nameplate_completions(self, prefix): assert threading.current_thread().ident == self._main_thread return self._input.get_nameplate_completions(prefix) + def choose_nameplate(self, nameplate): assert threading.current_thread().ident == self._main_thread self._input._debug("I.choose_nameplate") self._input.choose_nameplate(nameplate) self._input._debug("I.choose_nameplate finished") + def when_wordlist_is_available(self): assert threading.current_thread().ident == self._main_thread return self._input.when_wordlist_is_available() + def get_word_completions(self, prefix): assert threading.current_thread().ident == self._main_thread return self._input.get_word_completions(prefix) + def choose_words(self, words): assert threading.current_thread().ident == self._main_thread self._input._debug("I.choose_words") diff --git a/src/wormhole/_interfaces.py b/src/wormhole/_interfaces.py index 060ff22..9f59ff4 100644 --- a/src/wormhole/_interfaces.py +++ b/src/wormhole/_interfaces.py @@ -3,64 +3,105 @@ from zope.interface import Interface # These interfaces are private: we use them as markers to detect # swapped argument bugs in the various .wire() calls + class IWormhole(Interface): """Internal: this contains the methods invoked 'from below'.""" + def got_welcome(welcome): pass + def got_code(code): pass + def got_key(key): pass + def got_verifier(verifier): pass + def got_versions(versions): pass + def received(plaintext): pass + def closed(result): pass + class IBoss(Interface): pass + + class INameplate(Interface): pass + + class IMailbox(Interface): pass + + class ISend(Interface): pass + + class IOrder(Interface): pass + + class IKey(Interface): pass + + class IReceive(Interface): pass + + class IRendezvousConnector(Interface): pass + + class ILister(Interface): pass + + class ICode(Interface): pass + + class IInput(Interface): pass + + class IAllocator(Interface): pass + + class ITerminator(Interface): pass + class ITiming(Interface): pass + + class ITorManager(Interface): pass + + class IWordlist(Interface): def choose_words(length): """Randomly select LENGTH words, join them with hyphens, return the result.""" + def get_completions(prefix): """Return a list of all suffixes that could complete the given prefix.""" + # These interfaces are public, and are re-exported by __init__.py + class IDeferredWormhole(Interface): def get_welcome(): """ @@ -277,6 +318,7 @@ class IDeferredWormhole(Interface): :rtype: ``Deferred`` """ + class IInputHelper(Interface): def refresh_nameplates(): """ @@ -389,5 +431,5 @@ class IInputHelper(Interface): """ -class IJournal(Interface): # TODO: this needs to be public +class IJournal(Interface): # TODO: this needs to be public pass diff --git a/src/wormhole/_key.py b/src/wormhole/_key.py index f76f454..849ed41 100644 --- a/src/wormhole/_key.py +++ b/src/wormhole/_key.py @@ -1,41 +1,50 @@ -from __future__ import print_function, absolute_import, unicode_literals +from __future__ import absolute_import, print_function, unicode_literals + from hashlib import sha256 + import six -from zope.interface import implementer -from attr import attrs, attrib -from attr.validators import provides, instance_of -from spake2 import SPAKE2_Symmetric -from hkdf import Hkdf -from nacl.secret import SecretBox -from nacl.exceptions import CryptoError -from nacl import utils +from attr import attrib, attrs +from attr.validators import instance_of, provides from automat import MethodicalMachine -from .util import (to_bytes, bytes_to_hexstr, hexstr_to_bytes, - bytes_to_dict, dict_to_bytes) +from hkdf import Hkdf +from nacl import utils +from nacl.exceptions import CryptoError +from nacl.secret import SecretBox +from spake2 import SPAKE2_Symmetric +from zope.interface import implementer + from . import _interfaces +from .util import (bytes_to_dict, bytes_to_hexstr, dict_to_bytes, + hexstr_to_bytes, to_bytes) + CryptoError -__all__ = ["derive_key", "derive_phase_key", "CryptoError", - "Key"] +__all__ = ["derive_key", "derive_phase_key", "CryptoError", "Key"] + def HKDF(skm, outlen, salt=None, CTXinfo=b""): return Hkdf(salt, skm).expand(CTXinfo, outlen) + def derive_key(key, purpose, length=SecretBox.KEY_SIZE): - if not isinstance(key, type(b"")): raise TypeError(type(key)) - if not isinstance(purpose, type(b"")): raise TypeError(type(purpose)) - if not isinstance(length, six.integer_types): raise TypeError(type(length)) + if not isinstance(key, type(b"")): + raise TypeError(type(key)) + if not isinstance(purpose, type(b"")): + raise TypeError(type(purpose)) + if not isinstance(length, six.integer_types): + raise TypeError(type(length)) return HKDF(key, length, CTXinfo=purpose) + def derive_phase_key(key, side, phase): assert isinstance(side, type("")), type(side) assert isinstance(phase, type("")), type(phase) side_bytes = side.encode("ascii") phase_bytes = phase.encode("ascii") - purpose = (b"wormhole:phase:" - + sha256(side_bytes).digest() - + sha256(phase_bytes).digest()) + purpose = (b"wormhole:phase:" + sha256(side_bytes).digest() + + sha256(phase_bytes).digest()) return derive_key(key, purpose) + def decrypt_data(key, encrypted): assert isinstance(key, type(b"")), type(key) assert isinstance(encrypted, type(b"")), type(encrypted) @@ -44,6 +53,7 @@ def decrypt_data(key, encrypted): data = box.decrypt(encrypted) return data + def encrypt_data(key, plaintext): assert isinstance(key, type(b"")), type(key) assert isinstance(plaintext, type(b"")), type(plaintext) @@ -52,10 +62,12 @@ def encrypt_data(key, plaintext): nonce = utils.random(SecretBox.NONCE_SIZE) return box.encrypt(plaintext, nonce) + # the Key we expose to callers (Boss, Ordering) is responsible for sorting # the two messages (got_code and got_pake), then delivering them to # _SortedKey in the right order. + @attrs @implementer(_interfaces.IKey) class Key(object): @@ -64,40 +76,54 @@ class Key(object): _side = attrib(validator=instance_of(type(u""))) _timing = attrib(validator=provides(_interfaces.ITiming)) m = MethodicalMachine() - set_trace = getattr(m, "_setTrace", lambda self, f: None) # pragma: no cover + set_trace = getattr(m, "_setTrace", + lambda self, f: None) # pragma: no cover def __attrs_post_init__(self): self._SK = _SortedKey(self._appid, self._versions, self._side, self._timing) - self._debug_pake_stashed = False # for tests + self._debug_pake_stashed = False # for tests def wire(self, boss, mailbox, receive): self._SK.wire(boss, mailbox, receive) @m.state(initial=True) - def S00(self): pass # pragma: no cover + def S00(self): + pass # pragma: no cover + @m.state() - def S01(self): pass # pragma: no cover + def S01(self): + pass # pragma: no cover + @m.state() - def S10(self): pass # pragma: no cover + def S10(self): + pass # pragma: no cover + @m.state() - def S11(self): pass # pragma: no cover + def S11(self): + pass # pragma: no cover @m.input() - def got_code(self, code): pass + def got_code(self, code): + pass + @m.input() - def got_pake(self, body): pass + def got_pake(self, body): + pass @m.output() def stash_pake(self, body): self._pake = body self._debug_pake_stashed = True + @m.output() def deliver_code(self, code): self._SK.got_code(code) + @m.output() def deliver_pake(self, body): self._SK.got_pake(body) + @m.output() def deliver_code_and_stashed_pake(self, code): self._SK.got_code(code) @@ -108,6 +134,7 @@ class Key(object): S00.upon(got_pake, enter=S01, outputs=[stash_pake]) S01.upon(got_code, enter=S11, outputs=[deliver_code_and_stashed_pake]) + @attrs class _SortedKey(object): _appid = attrib(validator=instance_of(type(u""))) @@ -115,7 +142,8 @@ class _SortedKey(object): _side = attrib(validator=instance_of(type(u""))) _timing = attrib(validator=provides(_interfaces.ITiming)) m = MethodicalMachine() - set_trace = getattr(m, "_setTrace", lambda self, f: None) # pragma: no cover + set_trace = getattr(m, "_setTrace", + lambda self, f: None) # pragma: no cover def wire(self, boss, mailbox, receive): self._B = _interfaces.IBoss(boss) @@ -123,17 +151,25 @@ class _SortedKey(object): self._R = _interfaces.IReceive(receive) @m.state(initial=True) - def S0_know_nothing(self): pass # pragma: no cover + def S0_know_nothing(self): + pass # pragma: no cover + @m.state() - def S1_know_code(self): pass # pragma: no cover + def S1_know_code(self): + pass # pragma: no cover + @m.state() - def S2_know_key(self): pass # pragma: no cover + def S2_know_key(self): + pass # pragma: no cover + @m.state(terminal=True) - def S3_scared(self): pass # pragma: no cover + def S3_scared(self): + pass # pragma: no cover # from Boss @m.input() - def got_code(self, code): pass + def got_code(self, code): + pass # from Ordering def got_pake(self, body): @@ -143,16 +179,20 @@ class _SortedKey(object): self.got_pake_good(hexstr_to_bytes(payload["pake_v1"])) else: self.got_pake_bad() + @m.input() - def got_pake_good(self, msg2): pass + def got_pake_good(self, msg2): + pass + @m.input() - def got_pake_bad(self): pass + def got_pake_bad(self): + pass @m.output() def build_pake(self, code): with self._timing.add("pake1", waiting="crypto"): - self._sp = SPAKE2_Symmetric(to_bytes(code), - idSymmetric=to_bytes(self._appid)) + self._sp = SPAKE2_Symmetric( + to_bytes(code), idSymmetric=to_bytes(self._appid)) msg1 = self._sp.start() body = dict_to_bytes({"pake_v1": bytes_to_hexstr(msg1)}) self._M.add_message("pake", body) @@ -160,6 +200,7 @@ class _SortedKey(object): @m.output() def scared(self): self._B.scared() + @m.output() def compute_key(self, msg2): assert isinstance(msg2, type(b"")) diff --git a/src/wormhole/_lister.py b/src/wormhole/_lister.py index ebe7919..de085fa 100644 --- a/src/wormhole/_lister.py +++ b/src/wormhole/_lister.py @@ -1,16 +1,20 @@ -from __future__ import print_function, absolute_import, unicode_literals -from zope.interface import implementer -from attr import attrs, attrib +from __future__ import absolute_import, print_function, unicode_literals + +from attr import attrib, attrs from attr.validators import provides from automat import MethodicalMachine +from zope.interface import implementer + from . import _interfaces + @attrs @implementer(_interfaces.ILister) class Lister(object): _timing = attrib(validator=provides(_interfaces.ITiming)) m = MethodicalMachine() - set_trace = getattr(m, "_setTrace", lambda self, f: None) # pragma: no cover + set_trace = getattr(m, "_setTrace", + lambda self, f: None) # pragma: no cover def wire(self, rendezvous_connector, input): self._RC = _interfaces.IRendezvousConnector(rendezvous_connector) @@ -26,26 +30,41 @@ class Lister(object): # request arrives, both requests will be satisfied by the same response. @m.state(initial=True) - def S0A_idle_disconnected(self): pass # pragma: no cover + def S0A_idle_disconnected(self): + pass # pragma: no cover + @m.state() - def S1A_wanting_disconnected(self): pass # pragma: no cover + def S1A_wanting_disconnected(self): + pass # pragma: no cover + @m.state() - def S0B_idle_connected(self): pass # pragma: no cover + def S0B_idle_connected(self): + pass # pragma: no cover + @m.state() - def S1B_wanting_connected(self): pass # pragma: no cover + def S1B_wanting_connected(self): + pass # pragma: no cover @m.input() - def connected(self): pass + def connected(self): + pass + @m.input() - def lost(self): pass + def lost(self): + pass + @m.input() - def refresh(self): pass + def refresh(self): + pass + @m.input() - def rx_nameplates(self, all_nameplates): pass + def rx_nameplates(self, all_nameplates): + pass @m.output() def RC_tx_list(self): self._RC.tx_list() + @m.output() def I_got_nameplates(self, all_nameplates): # We get a set of nameplate ids. There may be more attributes in the @@ -56,18 +75,19 @@ class Lister(object): S0A_idle_disconnected.upon(connected, enter=S0B_idle_connected, outputs=[]) S0B_idle_connected.upon(lost, enter=S0A_idle_disconnected, outputs=[]) - S0A_idle_disconnected.upon(refresh, - enter=S1A_wanting_disconnected, outputs=[]) - S1A_wanting_disconnected.upon(refresh, - enter=S1A_wanting_disconnected, outputs=[]) - S1A_wanting_disconnected.upon(connected, enter=S1B_wanting_connected, - outputs=[RC_tx_list]) - S0B_idle_connected.upon(refresh, enter=S1B_wanting_connected, - outputs=[RC_tx_list]) - S0B_idle_connected.upon(rx_nameplates, enter=S0B_idle_connected, - outputs=[I_got_nameplates]) - S1B_wanting_connected.upon(lost, enter=S1A_wanting_disconnected, outputs=[]) - S1B_wanting_connected.upon(refresh, enter=S1B_wanting_connected, - outputs=[RC_tx_list]) - S1B_wanting_connected.upon(rx_nameplates, enter=S0B_idle_connected, - outputs=[I_got_nameplates]) + S0A_idle_disconnected.upon( + refresh, enter=S1A_wanting_disconnected, outputs=[]) + S1A_wanting_disconnected.upon( + refresh, enter=S1A_wanting_disconnected, outputs=[]) + S1A_wanting_disconnected.upon( + connected, enter=S1B_wanting_connected, outputs=[RC_tx_list]) + S0B_idle_connected.upon( + refresh, enter=S1B_wanting_connected, outputs=[RC_tx_list]) + S0B_idle_connected.upon( + rx_nameplates, enter=S0B_idle_connected, outputs=[I_got_nameplates]) + S1B_wanting_connected.upon( + lost, enter=S1A_wanting_disconnected, outputs=[]) + S1B_wanting_connected.upon( + refresh, enter=S1B_wanting_connected, outputs=[RC_tx_list]) + S1B_wanting_connected.upon( + rx_nameplates, enter=S0B_idle_connected, outputs=[I_got_nameplates]) diff --git a/src/wormhole/_mailbox.py b/src/wormhole/_mailbox.py index 36167e6..6a6a956 100644 --- a/src/wormhole/_mailbox.py +++ b/src/wormhole/_mailbox.py @@ -1,16 +1,20 @@ -from __future__ import print_function, absolute_import, unicode_literals -from zope.interface import implementer -from attr import attrs, attrib +from __future__ import absolute_import, print_function, unicode_literals + +from attr import attrib, attrs from attr.validators import instance_of from automat import MethodicalMachine +from zope.interface import implementer + from . import _interfaces + @attrs @implementer(_interfaces.IMailbox) class Mailbox(object): _side = attrib(validator=instance_of(type(u""))) m = MethodicalMachine() - set_trace = getattr(m, "_setTrace", lambda self, f: None) # pragma: no cover + set_trace = getattr(m, "_setTrace", + lambda self, f: None) # pragma: no cover def __attrs_post_init__(self): self._mailbox = None @@ -29,52 +33,68 @@ class Mailbox(object): # S0: know nothing @m.state(initial=True) - def S0A(self): pass # pragma: no cover + def S0A(self): + pass # pragma: no cover + @m.state() - def S0B(self): pass # pragma: no cover + def S0B(self): + pass # pragma: no cover # S1: mailbox known, not opened @m.state() - def S1A(self): pass # pragma: no cover + def S1A(self): + pass # pragma: no cover # S2: mailbox known, opened # We've definitely tried to open the mailbox at least once, but it must # be re-opened with each connection, because open() is also subscribe() @m.state() - def S2A(self): pass # pragma: no cover + def S2A(self): + pass # pragma: no cover + @m.state() - def S2B(self): pass # pragma: no cover + def S2B(self): + pass # pragma: no cover # S3: closing @m.state() - def S3A(self): pass # pragma: no cover + def S3A(self): + pass # pragma: no cover + @m.state() - def S3B(self): pass # pragma: no cover + def S3B(self): + pass # pragma: no cover # S4: closed. We no longer care whether we're connected or not - #@m.state() - #def S4A(self): pass - #@m.state() - #def S4B(self): pass + # @m.state() + # def S4A(self): pass + # @m.state() + # def S4B(self): pass @m.state(terminal=True) - def S4(self): pass # pragma: no cover + def S4(self): + pass # pragma: no cover + S4A = S4 S4B = S4 - # from Terminator @m.input() - def close(self, mood): pass + def close(self, mood): + pass # from Nameplate @m.input() - def got_mailbox(self, mailbox): pass + def got_mailbox(self, mailbox): + pass # from RendezvousConnector @m.input() - def connected(self): pass + def connected(self): + pass + @m.input() - def lost(self): pass + def lost(self): + pass def rx_message(self, side, phase, body): assert isinstance(side, type("")), type(side) @@ -84,73 +104,91 @@ class Mailbox(object): self.rx_message_ours(phase, body) else: self.rx_message_theirs(side, phase, body) + @m.input() - def rx_message_ours(self, phase, body): pass + def rx_message_ours(self, phase, body): + pass + @m.input() - def rx_message_theirs(self, side, phase, body): pass + def rx_message_theirs(self, side, phase, body): + pass + @m.input() - def rx_closed(self): pass + def rx_closed(self): + pass # from Send or Key @m.input() def add_message(self, phase, body): pass - @m.output() def record_mailbox(self, mailbox): self._mailbox = mailbox + @m.output() def RC_tx_open(self): assert self._mailbox self._RC.tx_open(self._mailbox) + @m.output() def queue(self, phase, body): assert isinstance(phase, type("")), type(phase) assert isinstance(body, type(b"")), (type(body), phase, body) self._pending_outbound[phase] = body + @m.output() def record_mailbox_and_RC_tx_open_and_drain(self, mailbox): self._mailbox = mailbox self._RC.tx_open(mailbox) self._drain() + @m.output() def drain(self): self._drain() + def _drain(self): for phase, body in self._pending_outbound.items(): self._RC.tx_add(phase, body) + @m.output() def RC_tx_add(self, phase, body): assert isinstance(phase, type("")), type(phase) assert isinstance(body, type(b"")), type(body) self._RC.tx_add(phase, body) + @m.output() def N_release_and_accept(self, side, phase, body): self._N.release() if phase not in self._processed: self._processed.add(phase) self._O.got_message(side, phase, body) + @m.output() def RC_tx_close(self): assert self._mood self._RC_tx_close() + def _RC_tx_close(self): self._RC.tx_close(self._mailbox, self._mood) @m.output() def dequeue(self, phase, body): self._pending_outbound.pop(phase, None) + @m.output() def record_mood(self, mood): self._mood = mood + @m.output() def record_mood_and_RC_tx_close(self, mood): self._mood = mood self._RC_tx_close() + @m.output() def ignore_mood_and_T_mailbox_done(self, mood): self._T.mailbox_done() + @m.output() def T_mailbox_done(self): self._T.mailbox_done() @@ -162,8 +200,10 @@ class Mailbox(object): S0B.upon(lost, enter=S0A, outputs=[]) S0B.upon(add_message, enter=S0B, outputs=[queue]) S0B.upon(close, enter=S4B, outputs=[ignore_mood_and_T_mailbox_done]) - S0B.upon(got_mailbox, enter=S2B, - outputs=[record_mailbox_and_RC_tx_open_and_drain]) + S0B.upon( + got_mailbox, + enter=S2B, + outputs=[record_mailbox_and_RC_tx_open_and_drain]) S1A.upon(connected, enter=S2B, outputs=[RC_tx_open, drain]) S1A.upon(add_message, enter=S1A, outputs=[queue]) @@ -192,4 +232,3 @@ class Mailbox(object): S4.upon(rx_message_theirs, enter=S4, outputs=[]) S4.upon(rx_message_ours, enter=S4, outputs=[]) S4.upon(close, enter=S4, outputs=[]) - diff --git a/src/wormhole/_nameplate.py b/src/wormhole/_nameplate.py index b0d673c..b41fa72 100644 --- a/src/wormhole/_nameplate.py +++ b/src/wormhole/_nameplate.py @@ -1,7 +1,10 @@ -from __future__ import print_function, absolute_import, unicode_literals +from __future__ import absolute_import, print_function, unicode_literals + import re -from zope.interface import implementer + from automat import MethodicalMachine +from zope.interface import implementer + from . import _interfaces from ._wordlist import PGPWordList from .errors import KeyFormatError @@ -9,13 +12,15 @@ from .errors import KeyFormatError def validate_nameplate(nameplate): if not re.search(r'^\d+$', nameplate): - raise KeyFormatError("Nameplate '%s' must be numeric, with no spaces." - % nameplate) + raise KeyFormatError( + "Nameplate '%s' must be numeric, with no spaces." % nameplate) + @implementer(_interfaces.INameplate) class Nameplate(object): m = MethodicalMachine() - set_trace = getattr(m, "_setTrace", lambda self, f: None) # pragma: no cover + set_trace = getattr(m, "_setTrace", + lambda self, f: None) # pragma: no cover def __init__(self): self._nameplate = None @@ -32,94 +37,125 @@ class Nameplate(object): # S0: know nothing @m.state(initial=True) - def S0A(self): pass # pragma: no cover + def S0A(self): + pass # pragma: no cover + @m.state() - def S0B(self): pass # pragma: no cover + def S0B(self): + pass # pragma: no cover # S1: nameplate known, never claimed @m.state() - def S1A(self): pass # pragma: no cover + def S1A(self): + pass # pragma: no cover # S2: nameplate known, maybe claimed @m.state() - def S2A(self): pass # pragma: no cover + def S2A(self): + pass # pragma: no cover + @m.state() - def S2B(self): pass # pragma: no cover + def S2B(self): + pass # pragma: no cover # S3: nameplate claimed @m.state() - def S3A(self): pass # pragma: no cover + def S3A(self): + pass # pragma: no cover + @m.state() - def S3B(self): pass # pragma: no cover + def S3B(self): + pass # pragma: no cover # S4: maybe released @m.state() - def S4A(self): pass # pragma: no cover + def S4A(self): + pass # pragma: no cover + @m.state() - def S4B(self): pass # pragma: no cover + def S4B(self): + pass # pragma: no cover # S5: released # we no longer care whether we're connected or not - #@m.state() - #def S5A(self): pass - #@m.state() - #def S5B(self): pass + # @m.state() + # def S5A(self): pass + # @m.state() + # def S5B(self): pass @m.state() - def S5(self): pass # pragma: no cover + def S5(self): + pass # pragma: no cover + S5A = S5 S5B = S5 # from Boss def set_nameplate(self, nameplate): - validate_nameplate(nameplate) # can raise KeyFormatError + validate_nameplate(nameplate) # can raise KeyFormatError self._set_nameplate(nameplate) + @m.input() - def _set_nameplate(self, nameplate): pass + def _set_nameplate(self, nameplate): + pass # from Mailbox @m.input() - def release(self): pass + def release(self): + pass # from Terminator @m.input() - def close(self): pass + def close(self): + pass # from RendezvousConnector @m.input() - def connected(self): pass - @m.input() - def lost(self): pass + def connected(self): + pass @m.input() - def rx_claimed(self, mailbox): pass + def lost(self): + pass + @m.input() - def rx_released(self): pass + def rx_claimed(self, mailbox): + pass + + @m.input() + def rx_released(self): + pass @m.output() def record_nameplate(self, nameplate): validate_nameplate(nameplate) self._nameplate = nameplate + @m.output() def record_nameplate_and_RC_tx_claim(self, nameplate): validate_nameplate(nameplate) self._nameplate = nameplate self._RC.tx_claim(self._nameplate) + @m.output() def RC_tx_claim(self): # when invoked via M.connected(), we must use the stored nameplate self._RC.tx_claim(self._nameplate) + @m.output() def I_got_wordlist(self, mailbox): # TODO select wordlist based on nameplate properties, in rx_claimed wordlist = PGPWordList() self._I.got_wordlist(wordlist) + @m.output() def M_got_mailbox(self, mailbox): self._M.got_mailbox(mailbox) + @m.output() def RC_tx_release(self): assert self._nameplate self._RC.tx_release(self._nameplate) + @m.output() def T_nameplate_done(self): self._T.nameplate_done() @@ -127,8 +163,8 @@ class Nameplate(object): S0A.upon(_set_nameplate, enter=S1A, outputs=[record_nameplate]) S0A.upon(connected, enter=S0B, outputs=[]) S0A.upon(close, enter=S5A, outputs=[T_nameplate_done]) - S0B.upon(_set_nameplate, enter=S2B, - outputs=[record_nameplate_and_RC_tx_claim]) + S0B.upon( + _set_nameplate, enter=S2B, outputs=[record_nameplate_and_RC_tx_claim]) S0B.upon(lost, enter=S0A, outputs=[]) S0B.upon(close, enter=S5A, outputs=[T_nameplate_done]) @@ -144,7 +180,7 @@ class Nameplate(object): S3A.upon(connected, enter=S3B, outputs=[]) S3A.upon(close, enter=S4A, outputs=[]) S3B.upon(lost, enter=S3A, outputs=[]) - #S3B.upon(rx_claimed, enter=S3B, outputs=[]) # shouldn't happen + # S3B.upon(rx_claimed, enter=S3B, outputs=[]) # shouldn't happen S3B.upon(release, enter=S4B, outputs=[RC_tx_release]) S3B.upon(close, enter=S4B, outputs=[RC_tx_release]) @@ -153,7 +189,7 @@ class Nameplate(object): S4B.upon(lost, enter=S4A, outputs=[]) S4B.upon(rx_claimed, enter=S4B, outputs=[]) S4B.upon(rx_released, enter=S5B, outputs=[T_nameplate_done]) - S4B.upon(release, enter=S4B, outputs=[]) # mailbox is lazy + S4B.upon(release, enter=S4B, outputs=[]) # mailbox is lazy # Mailbox doesn't remember how many times it's sent a release, and will # re-send a new one for each peer message it receives. Ignoring it here # is easier than adding a new pair of states to Mailbox. @@ -161,5 +197,5 @@ class Nameplate(object): S5A.upon(connected, enter=S5B, outputs=[]) S5B.upon(lost, enter=S5A, outputs=[]) - S5.upon(release, enter=S5, outputs=[]) # mailbox is lazy + S5.upon(release, enter=S5, outputs=[]) # mailbox is lazy S5.upon(close, enter=S5, outputs=[]) diff --git a/src/wormhole/_order.py b/src/wormhole/_order.py index e2128df..18f6a8e 100644 --- a/src/wormhole/_order.py +++ b/src/wormhole/_order.py @@ -1,32 +1,40 @@ -from __future__ import print_function, absolute_import, unicode_literals -from zope.interface import implementer -from attr import attrs, attrib -from attr.validators import provides, instance_of +from __future__ import absolute_import, print_function, unicode_literals + +from attr import attrib, attrs +from attr.validators import instance_of, provides from automat import MethodicalMachine +from zope.interface import implementer + from . import _interfaces + @attrs @implementer(_interfaces.IOrder) class Order(object): _side = attrib(validator=instance_of(type(u""))) _timing = attrib(validator=provides(_interfaces.ITiming)) m = MethodicalMachine() - set_trace = getattr(m, "_setTrace", lambda self, f: None) # pragma: no cover + set_trace = getattr(m, "_setTrace", + lambda self, f: None) # pragma: no cover def __attrs_post_init__(self): self._key = None self._queue = [] + def wire(self, key, receive): self._K = _interfaces.IKey(key) self._R = _interfaces.IReceive(receive) @m.state(initial=True) - def S0_no_pake(self): pass # pragma: no cover + def S0_no_pake(self): + pass # pragma: no cover + @m.state(terminal=True) - def S1_yes_pake(self): pass # pragma: no cover + def S1_yes_pake(self): + pass # pragma: no cover def got_message(self, side, phase, body): - #print("ORDER[%s].got_message(%s)" % (self._side, phase)) + # print("ORDER[%s].got_message(%s)" % (self._side, phase)) assert isinstance(side, type("")), type(phase) assert isinstance(phase, type("")), type(phase) assert isinstance(body, type(b"")), type(body) @@ -36,9 +44,12 @@ class Order(object): self.got_non_pake(side, phase, body) @m.input() - def got_pake(self, side, phase, body): pass + def got_pake(self, side, phase, body): + pass + @m.input() - def got_non_pake(self, side, phase, body): pass + def got_non_pake(self, side, phase, body): + pass @m.output() def queue(self, side, phase, body): @@ -46,9 +57,11 @@ class Order(object): assert isinstance(phase, type("")), type(phase) assert isinstance(body, type(b"")), type(body) self._queue.append((side, phase, body)) + @m.output() def notify_key(self, side, phase, body): self._K.got_pake(body) + @m.output() def drain(self, side, phase, body): del phase @@ -56,6 +69,7 @@ class Order(object): for (side, phase, body) in self._queue: self._deliver(side, phase, body) self._queue[:] = [] + @m.output() def deliver(self, side, phase, body): self._deliver(side, phase, body) diff --git a/src/wormhole/_receive.py b/src/wormhole/_receive.py index 4fa2858..8e9de4f 100644 --- a/src/wormhole/_receive.py +++ b/src/wormhole/_receive.py @@ -1,10 +1,13 @@ -from __future__ import print_function, absolute_import, unicode_literals -from zope.interface import implementer -from attr import attrs, attrib -from attr.validators import provides, instance_of +from __future__ import absolute_import, print_function, unicode_literals + +from attr import attrib, attrs +from attr.validators import instance_of, provides from automat import MethodicalMachine +from zope.interface import implementer + from . import _interfaces -from ._key import derive_key, derive_phase_key, decrypt_data, CryptoError +from ._key import CryptoError, decrypt_data, derive_key, derive_phase_key + @attrs @implementer(_interfaces.IReceive) @@ -12,7 +15,8 @@ class Receive(object): _side = attrib(validator=instance_of(type(u""))) _timing = attrib(validator=provides(_interfaces.ITiming)) m = MethodicalMachine() - set_trace = getattr(m, "_setTrace", lambda self, f: None) # pragma: no cover + set_trace = getattr(m, "_setTrace", + lambda self, f: None) # pragma: no cover def __attrs_post_init__(self): self._key = None @@ -22,13 +26,20 @@ class Receive(object): self._S = _interfaces.ISend(send) @m.state(initial=True) - def S0_unknown_key(self): pass # pragma: no cover + def S0_unknown_key(self): + pass # pragma: no cover + @m.state() - def S1_unverified_key(self): pass # pragma: no cover + def S1_unverified_key(self): + pass # pragma: no cover + @m.state() - def S2_verified_key(self): pass # pragma: no cover + def S2_verified_key(self): + pass # pragma: no cover + @m.state(terminal=True) - def S3_scared(self): pass # pragma: no cover + def S3_scared(self): + pass # pragma: no cover # from Ordering def got_message(self, side, phase, body): @@ -43,47 +54,56 @@ class Receive(object): self.got_message_bad() return self.got_message_good(phase, plaintext) + @m.input() - def got_message_good(self, phase, plaintext): pass + def got_message_good(self, phase, plaintext): + pass + @m.input() - def got_message_bad(self): pass + def got_message_bad(self): + pass # from Key @m.input() - def got_key(self, key): pass + def got_key(self, key): + pass @m.output() def record_key(self, key): self._key = key + @m.output() def S_got_verified_key(self, phase, plaintext): assert self._key self._S.got_verified_key(self._key) + @m.output() def W_happy(self, phase, plaintext): self._B.happy() + @m.output() def W_got_verifier(self, phase, plaintext): self._B.got_verifier(derive_key(self._key, b"wormhole:verifier")) + @m.output() def W_got_message(self, phase, plaintext): assert isinstance(phase, type("")), type(phase) assert isinstance(plaintext, type(b"")), type(plaintext) self._B.got_message(phase, plaintext) + @m.output() def W_scared(self): self._B.scared() S0_unknown_key.upon(got_key, enter=S1_unverified_key, outputs=[record_key]) - S1_unverified_key.upon(got_message_good, enter=S2_verified_key, - outputs=[S_got_verified_key, - W_happy, W_got_verifier, W_got_message]) - S1_unverified_key.upon(got_message_bad, enter=S3_scared, - outputs=[W_scared]) - S2_verified_key.upon(got_message_bad, enter=S3_scared, - outputs=[W_scared]) - S2_verified_key.upon(got_message_good, enter=S2_verified_key, - outputs=[W_got_message]) + S1_unverified_key.upon( + got_message_good, + enter=S2_verified_key, + outputs=[S_got_verified_key, W_happy, W_got_verifier, W_got_message]) + S1_unverified_key.upon( + got_message_bad, enter=S3_scared, outputs=[W_scared]) + S2_verified_key.upon(got_message_bad, enter=S3_scared, outputs=[W_scared]) + S2_verified_key.upon( + got_message_good, enter=S2_verified_key, outputs=[W_got_message]) S3_scared.upon(got_message_good, enter=S3_scared, outputs=[]) S3_scared.upon(got_message_bad, enter=S3_scared, outputs=[]) - diff --git a/src/wormhole/_rendezvous.py b/src/wormhole/_rendezvous.py index 963d12a..65d6545 100644 --- a/src/wormhole/_rendezvous.py +++ b/src/wormhole/_rendezvous.py @@ -9,44 +9,47 @@ from twisted.internet import defer, endpoints, task from twisted.application import internet from autobahn.twisted import websocket from . import _interfaces, errors -from .util import (bytes_to_hexstr, hexstr_to_bytes, - bytes_to_dict, dict_to_bytes) +from .util import (bytes_to_hexstr, hexstr_to_bytes, bytes_to_dict, + dict_to_bytes) + class WSClient(websocket.WebSocketClientProtocol): def onConnect(self, response): # this fires during WebSocket negotiation, and isn't very useful # unless you want to modify the protocol settings - #print("onConnect", response) + # print("onConnect", response) pass def onOpen(self, *args): # this fires when the WebSocket is ready to go. No arguments - #print("onOpen", args) - #self.wormhole_open = True + # print("onOpen", args) + # self.wormhole_open = True self._RC.ws_open(self) def onMessage(self, payload, isBinary): assert not isBinary try: self._RC.ws_message(payload) - except: + except Exception: from twisted.python.failure import Failure print("LOGGING", Failure()) log.err() raise def onClose(self, wasClean, code, reason): - #print("onClose") + # print("onClose") self._RC.ws_close(wasClean, code, reason) - #if self.wormhole_open: - # self.wormhole._ws_closed(wasClean, code, reason) - #else: - # # we closed before establishing a connection (onConnect) or - # # finishing WebSocket negotiation (onOpen): errback - # self.factory.d.errback(error.ConnectError(reason)) + # if self.wormhole_open: + # self.wormhole._ws_closed(wasClean, code, reason) + # else: + # # we closed before establishing a connection (onConnect) or + # # finishing WebSocket negotiation (onOpen): errback + # self.factory.d.errback(error.ConnectError(reason)) + class WSFactory(websocket.WebSocketClientFactory): protocol = WSClient + def __init__(self, RC, *args, **kwargs): websocket.WebSocketClientFactory.__init__(self, *args, **kwargs) self._RC = RC @@ -54,9 +57,10 @@ class WSFactory(websocket.WebSocketClientFactory): def buildProtocol(self, addr): proto = websocket.WebSocketClientFactory.buildProtocol(self, addr) proto._RC = self._RC - #proto.wormhole_open = False + # proto.wormhole_open = False return proto + @attrs @implementer(_interfaces.IRendezvousConnector) class RendezvousConnector(object): @@ -90,6 +94,7 @@ class RendezvousConnector(object): def set_trace(self, f): self._trace = f + def _debug(self, what): if self._trace: self._trace(old_state="", input=what, new_state="") @@ -133,14 +138,13 @@ class RendezvousConnector(object): def stop(self): # ClientService.stopService is defined to "Stop attempting to # reconnect and close any existing connections" - self._stopping = True # to catch _initial_connection_failed error + self._stopping = True # to catch _initial_connection_failed error d = defer.maybeDeferred(self._connector.stopService) # ClientService.stopService always fires with None, even if the # initial connection failed, so log.err just in case d.addErrback(log.err) d.addBoth(self._stopped) - # from Lister def tx_list(self): self._tx("list") @@ -157,7 +161,7 @@ class RendezvousConnector(object): # this should happen right away: the ClientService ought to be in # the "_waiting" state, and everything in the _waiting.stop # transition is immediate - d.addErrback(log.err) # just in case something goes wrong + d.addErrback(log.err) # just in case something goes wrong d.addCallback(lambda _: self._B.error(sce)) # from our WSClient (the WebSocket protocol) @@ -166,8 +170,11 @@ class RendezvousConnector(object): self._have_made_a_successful_connection = True self._ws = proto try: - self._tx("bind", appid=self._appid, side=self._side, - client_version=self._client_version) + self._tx( + "bind", + appid=self._appid, + side=self._side, + client_version=self._client_version) self._N.connected() self._M.connected() self._L.connected() @@ -180,19 +187,22 @@ class RendezvousConnector(object): def ws_message(self, payload): msg = bytes_to_dict(payload) if msg["type"] != "ack": - self._debug("R.rx(%s %s%s)" % - (msg["type"], msg.get("phase",""), - "[mine]" if msg.get("side","") == self._side else "", - )) + self._debug("R.rx(%s %s%s)" % ( + msg["type"], + msg.get("phase", ""), + "[mine]" if msg.get("side", "") == self._side else "", + )) self._timing.add("ws_receive", _side=self._side, message=msg) if self._debug_record_inbound_f: self._debug_record_inbound_f(msg) mtype = msg["type"] - meth = getattr(self, "_response_handle_"+mtype, None) + meth = getattr(self, "_response_handle_" + mtype, None) if not meth: # make tests fail, but real application will ignore it - log.err(errors._UnknownMessageTypeError("Unknown inbound message type %r" % (msg,))) + log.err( + errors._UnknownMessageTypeError( + "Unknown inbound message type %r" % (msg, ))) return try: return meth(msg) @@ -229,7 +239,7 @@ class RendezvousConnector(object): # valid connection sce = errors.ServerConnectionError(self._url, reason) d = defer.maybeDeferred(self._connector.stopService) - d.addErrback(log.err) # just in case something goes wrong + d.addErrback(log.err) # just in case something goes wrong # tell the Boss to quit and inform the user d.addCallback(lambda _: self._B.error(sce)) @@ -292,7 +302,7 @@ class RendezvousConnector(object): side = msg["side"] phase = msg["phase"] assert isinstance(phase, type("")), type(phase) - body = hexstr_to_bytes(msg["body"]) # bytes + body = hexstr_to_bytes(msg["body"]) # bytes self._M.rx_message(side, phase, body) def _response_handle_released(self, msg): @@ -301,5 +311,4 @@ class RendezvousConnector(object): def _response_handle_closed(self, msg): self._M.rx_closed() - # record, message, payload, packet, bundle, ciphertext, plaintext diff --git a/src/wormhole/_rlcompleter.py b/src/wormhole/_rlcompleter.py index 853b98f..e9262db 100644 --- a/src/wormhole/_rlcompleter.py +++ b/src/wormhole/_rlcompleter.py @@ -1,33 +1,41 @@ from __future__ import print_function, unicode_literals + import traceback from sys import stderr + +from attr import attrib, attrs +from six.moves import input +from twisted.internet.defer import inlineCallbacks, returnValue +from twisted.internet.threads import blockingCallFromThread, deferToThread + +from .errors import AlreadyInputNameplateError, KeyFormatError + try: import readline except ImportError: readline = None -from six.moves import input -from attr import attrs, attrib -from twisted.internet.defer import inlineCallbacks, returnValue -from twisted.internet.threads import deferToThread, blockingCallFromThread -from .errors import KeyFormatError, AlreadyInputNameplateError errf = None + + # uncomment this to enable tab-completion debugging -#import os ; errf = open("err", "w") if os.path.exists("err") else None -def debug(*args, **kwargs): # pragma: no cover +# import os ; errf = open("err", "w") if os.path.exists("err") else None +def debug(*args, **kwargs): # pragma: no cover if errf: print(*args, file=errf, **kwargs) errf.flush() + @attrs class CodeInputter(object): _input_helper = attrib() _reactor = attrib() + def __attrs_post_init__(self): self.used_completion = False self._matches = None # once we've claimed the nameplate, we can't go back - self._committed_nameplate = None # or string + self._committed_nameplate = None # or string def bcft(self, f, *a, **kw): return blockingCallFromThread(self._reactor, f, *a, **kw) @@ -66,7 +74,7 @@ class CodeInputter(object): nameplate, words = text.split("-", 1) else: got_nameplate = False - nameplate = text # partial + nameplate = text # partial # 'text' is one of these categories: # "" or "12": complete on nameplates (all that match, maybe just one) @@ -83,13 +91,15 @@ class CodeInputter(object): # they deleted past the committment point: we can't use # this. For now, bail, but in the future let's find a # gentler way to encourage them to not do that. - raise AlreadyInputNameplateError("nameplate (%s-) already entered, cannot go back" % self._committed_nameplate) + raise AlreadyInputNameplateError( + "nameplate (%s-) already entered, cannot go back" % + self._committed_nameplate) if not got_nameplate: # we're completing on nameplates: "" or "12" or "123" - self.bcft(ih.refresh_nameplates) # results arrive later + self.bcft(ih.refresh_nameplates) # results arrive later debug(" getting nameplates") completions = self.bcft(ih.get_nameplate_completions, nameplate) - else: # "123-" or "123-supp" + else: # "123-" or "123-supp" # time to commit to this nameplate, if they haven't already if not self._committed_nameplate: debug(" choose_nameplate(%s)" % nameplate) @@ -112,11 +122,13 @@ class CodeInputter(object): # heard about it from the server), it can't be helped. But # for the rest of the code, a simple wait-for-wordlist will # improve the user experience. - self.bcft(ih.when_wordlist_is_available) # blocks on CLAIM + self.bcft(ih.when_wordlist_is_available) # blocks on CLAIM # and we're completing on words now - debug(" getting words (%s)" % (words,)) - completions = [nameplate+"-"+c - for c in self.bcft(ih.get_word_completions, words)] + debug(" getting words (%s)" % (words, )) + completions = [ + nameplate + "-" + c + for c in self.bcft(ih.get_word_completions, words) + ] # rlcompleter wants full strings return sorted(completions) @@ -131,13 +143,16 @@ class CodeInputter(object): # they deleted past the committment point: we can't use # this. For now, bail, but in the future let's find a # gentler way to encourage them to not do that. - raise AlreadyInputNameplateError("nameplate (%s-) already entered, cannot go back" % self._committed_nameplate) + raise AlreadyInputNameplateError( + "nameplate (%s-) already entered, cannot go back" % + self._committed_nameplate) else: debug(" choose_nameplate(%s)" % nameplate) self.bcft(self._input_helper.choose_nameplate, nameplate) debug(" choose_words(%s)" % words) self.bcft(self._input_helper.choose_words, words) + def _input_code_with_completion(prompt, input_helper, reactor): # reminder: this all occurs in a separate thread. All calls to input_helper # must go through blockingCallFromThread() @@ -159,6 +174,7 @@ def _input_code_with_completion(prompt, input_helper, reactor): c.finish(code) return c.used_completion + def warn_readline(): # When our process receives a SIGINT, Twisted's SIGINT handler will # stop the reactor and wait for all threads to terminate before the @@ -192,11 +208,12 @@ def warn_readline(): # doesn't see the signal, and we must still wait for stdin to make # readline finish. + @inlineCallbacks def input_with_completion(prompt, input_helper, reactor): t = reactor.addSystemEventTrigger("before", "shutdown", warn_readline) - #input_helper.refresh_nameplates() - used_completion = yield deferToThread(_input_code_with_completion, - prompt, input_helper, reactor) + # input_helper.refresh_nameplates() + used_completion = yield deferToThread(_input_code_with_completion, prompt, + input_helper, reactor) reactor.removeSystemEventTrigger(t) returnValue(used_completion) diff --git a/src/wormhole/_send.py b/src/wormhole/_send.py index d0c45c0..e0293c6 100644 --- a/src/wormhole/_send.py +++ b/src/wormhole/_send.py @@ -1,18 +1,22 @@ -from __future__ import print_function, absolute_import, unicode_literals -from attr import attrs, attrib -from attr.validators import provides, instance_of -from zope.interface import implementer +from __future__ import absolute_import, print_function, unicode_literals + +from attr import attrib, attrs +from attr.validators import instance_of, provides from automat import MethodicalMachine +from zope.interface import implementer + from . import _interfaces from ._key import derive_phase_key, encrypt_data + @attrs @implementer(_interfaces.ISend) class Send(object): _side = attrib(validator=instance_of(type(u""))) _timing = attrib(validator=provides(_interfaces.ITiming)) m = MethodicalMachine() - set_trace = getattr(m, "_setTrace", lambda self, f: None) # pragma: no cover + set_trace = getattr(m, "_setTrace", + lambda self, f: None) # pragma: no cover def __attrs_post_init__(self): self._queue = [] @@ -21,31 +25,40 @@ class Send(object): self._M = _interfaces.IMailbox(mailbox) @m.state(initial=True) - def S0_no_key(self): pass # pragma: no cover + def S0_no_key(self): + pass # pragma: no cover + @m.state(terminal=True) - def S1_verified_key(self): pass # pragma: no cover + def S1_verified_key(self): + pass # pragma: no cover # from Receive @m.input() - def got_verified_key(self, key): pass + def got_verified_key(self, key): + pass + # from Boss @m.input() - def send(self, phase, plaintext): pass + def send(self, phase, plaintext): + pass @m.output() def queue(self, phase, plaintext): assert isinstance(phase, type("")), type(phase) assert isinstance(plaintext, type(b"")), type(plaintext) self._queue.append((phase, plaintext)) + @m.output() def record_key(self, key): self._key = key + @m.output() def drain(self, key): del key for (phase, plaintext) in self._queue: self._encrypt_and_send(phase, plaintext) self._queue[:] = [] + @m.output() def deliver(self, phase, plaintext): assert isinstance(phase, type("")), type(phase) @@ -59,6 +72,6 @@ class Send(object): self._M.add_message(phase, encrypted) S0_no_key.upon(send, enter=S0_no_key, outputs=[queue]) - S0_no_key.upon(got_verified_key, enter=S1_verified_key, - outputs=[record_key, drain]) + S0_no_key.upon( + got_verified_key, enter=S1_verified_key, outputs=[record_key, drain]) S1_verified_key.upon(send, enter=S1_verified_key, outputs=[deliver]) diff --git a/src/wormhole/_terminator.py b/src/wormhole/_terminator.py index 2cb7327..fe4bdcb 100644 --- a/src/wormhole/_terminator.py +++ b/src/wormhole/_terminator.py @@ -1,12 +1,16 @@ -from __future__ import print_function, absolute_import, unicode_literals -from zope.interface import implementer +from __future__ import absolute_import, print_function, unicode_literals + from automat import MethodicalMachine +from zope.interface import implementer + from . import _interfaces + @implementer(_interfaces.ITerminator) class Terminator(object): m = MethodicalMachine() - set_trace = getattr(m, "_setTrace", lambda self, f: None) # pragma: no cover + set_trace = getattr(m, "_setTrace", + lambda self, f: None) # pragma: no cover def __init__(self): self._mood = None @@ -29,48 +33,68 @@ class Terminator(object): # done, and we're closing, then we stop the RendezvousConnector @m.state(initial=True) - def Snmo(self): pass # pragma: no cover - @m.state() - def Smo(self): pass # pragma: no cover - @m.state() - def Sno(self): pass # pragma: no cover - @m.state() - def S0o(self): pass # pragma: no cover + def Snmo(self): + pass # pragma: no cover @m.state() - def Snm(self): pass # pragma: no cover - @m.state() - def Sm(self): pass # pragma: no cover - @m.state() - def Sn(self): pass # pragma: no cover - #@m.state() - #def S0(self): pass # unused + def Smo(self): + pass # pragma: no cover @m.state() - def S_stopping(self): pass # pragma: no cover + def Sno(self): + pass # pragma: no cover + @m.state() - def S_stopped(self, terminal=True): pass # pragma: no cover + def S0o(self): + pass # pragma: no cover + + @m.state() + def Snm(self): + pass # pragma: no cover + + @m.state() + def Sm(self): + pass # pragma: no cover + + @m.state() + def Sn(self): + pass # pragma: no cover + + # @m.state() + # def S0(self): pass # unused + + @m.state() + def S_stopping(self): + pass # pragma: no cover + + @m.state() + def S_stopped(self, terminal=True): + pass # pragma: no cover # from Boss @m.input() - def close(self, mood): pass + def close(self, mood): + pass # from Nameplate @m.input() - def nameplate_done(self): pass + def nameplate_done(self): + pass # from Mailbox @m.input() - def mailbox_done(self): pass + def mailbox_done(self): + pass # from RendezvousConnector @m.input() - def stopped(self): pass - + def stopped(self): + pass @m.output() def close_nameplate(self, mood): - self._N.close() # ignores mood + self._N.close() # ignores mood + @m.output() def close_mailbox(self, mood): self._M.close(mood) @@ -78,9 +102,11 @@ class Terminator(object): @m.output() def ignore_mood_and_RC_stop(self, mood): self._RC.stop() + @m.output() def RC_stop(self): self._RC.stop() + @m.output() def B_closed(self): self._B.closed() @@ -99,8 +125,10 @@ class Terminator(object): Snm.upon(nameplate_done, enter=Sm, outputs=[]) Sn.upon(nameplate_done, enter=S_stopping, outputs=[RC_stop]) - S0o.upon(close, enter=S_stopping, - outputs=[close_nameplate, close_mailbox, ignore_mood_and_RC_stop]) + S0o.upon( + close, + enter=S_stopping, + outputs=[close_nameplate, close_mailbox, ignore_mood_and_RC_stop]) Sm.upon(mailbox_done, enter=S_stopping, outputs=[RC_stop]) S_stopping.upon(stopped, enter=S_stopped, outputs=[B_closed]) diff --git a/src/wormhole/_wordlist.py b/src/wormhole/_wordlist.py index e14972b..b434c71 100644 --- a/src/wormhole/_wordlist.py +++ b/src/wormhole/_wordlist.py @@ -1,6 +1,10 @@ -from __future__ import unicode_literals, print_function +from __future__ import print_function, unicode_literals + import os +from binascii import unhexlify + from zope.interface import implementer + from ._interfaces import IWordlist # The PGP Word List, which maps bytes to phonetically-distinct words. There @@ -10,154 +14,280 @@ from ._interfaces import IWordlist # Thanks to Warren Guy for transcribing them: # https://github.com/warrenguy/javascript-pgp-word-list -from binascii import unhexlify raw_words = { -'00': ['aardvark', 'adroitness'], '01': ['absurd', 'adviser'], -'02': ['accrue', 'aftermath'], '03': ['acme', 'aggregate'], -'04': ['adrift', 'alkali'], '05': ['adult', 'almighty'], -'06': ['afflict', 'amulet'], '07': ['ahead', 'amusement'], -'08': ['aimless', 'antenna'], '09': ['Algol', 'applicant'], -'0A': ['allow', 'Apollo'], '0B': ['alone', 'armistice'], -'0C': ['ammo', 'article'], '0D': ['ancient', 'asteroid'], -'0E': ['apple', 'Atlantic'], '0F': ['artist', 'atmosphere'], -'10': ['assume', 'autopsy'], '11': ['Athens', 'Babylon'], -'12': ['atlas', 'backwater'], '13': ['Aztec', 'barbecue'], -'14': ['baboon', 'belowground'], '15': ['backfield', 'bifocals'], -'16': ['backward', 'bodyguard'], '17': ['banjo', 'bookseller'], -'18': ['beaming', 'borderline'], '19': ['bedlamp', 'bottomless'], -'1A': ['beehive', 'Bradbury'], '1B': ['beeswax', 'bravado'], -'1C': ['befriend', 'Brazilian'], '1D': ['Belfast', 'breakaway'], -'1E': ['berserk', 'Burlington'], '1F': ['billiard', 'businessman'], -'20': ['bison', 'butterfat'], '21': ['blackjack', 'Camelot'], -'22': ['blockade', 'candidate'], '23': ['blowtorch', 'cannonball'], -'24': ['bluebird', 'Capricorn'], '25': ['bombast', 'caravan'], -'26': ['bookshelf', 'caretaker'], '27': ['brackish', 'celebrate'], -'28': ['breadline', 'cellulose'], '29': ['breakup', 'certify'], -'2A': ['brickyard', 'chambermaid'], '2B': ['briefcase', 'Cherokee'], -'2C': ['Burbank', 'Chicago'], '2D': ['button', 'clergyman'], -'2E': ['buzzard', 'coherence'], '2F': ['cement', 'combustion'], -'30': ['chairlift', 'commando'], '31': ['chatter', 'company'], -'32': ['checkup', 'component'], '33': ['chisel', 'concurrent'], -'34': ['choking', 'confidence'], '35': ['chopper', 'conformist'], -'36': ['Christmas', 'congregate'], '37': ['clamshell', 'consensus'], -'38': ['classic', 'consulting'], '39': ['classroom', 'corporate'], -'3A': ['cleanup', 'corrosion'], '3B': ['clockwork', 'councilman'], -'3C': ['cobra', 'crossover'], '3D': ['commence', 'crucifix'], -'3E': ['concert', 'cumbersome'], '3F': ['cowbell', 'customer'], -'40': ['crackdown', 'Dakota'], '41': ['cranky', 'decadence'], -'42': ['crowfoot', 'December'], '43': ['crucial', 'decimal'], -'44': ['crumpled', 'designing'], '45': ['crusade', 'detector'], -'46': ['cubic', 'detergent'], '47': ['dashboard', 'determine'], -'48': ['deadbolt', 'dictator'], '49': ['deckhand', 'dinosaur'], -'4A': ['dogsled', 'direction'], '4B': ['dragnet', 'disable'], -'4C': ['drainage', 'disbelief'], '4D': ['dreadful', 'disruptive'], -'4E': ['drifter', 'distortion'], '4F': ['dropper', 'document'], -'50': ['drumbeat', 'embezzle'], '51': ['drunken', 'enchanting'], -'52': ['Dupont', 'enrollment'], '53': ['dwelling', 'enterprise'], -'54': ['eating', 'equation'], '55': ['edict', 'equipment'], -'56': ['egghead', 'escapade'], '57': ['eightball', 'Eskimo'], -'58': ['endorse', 'everyday'], '59': ['endow', 'examine'], -'5A': ['enlist', 'existence'], '5B': ['erase', 'exodus'], -'5C': ['escape', 'fascinate'], '5D': ['exceed', 'filament'], -'5E': ['eyeglass', 'finicky'], '5F': ['eyetooth', 'forever'], -'60': ['facial', 'fortitude'], '61': ['fallout', 'frequency'], -'62': ['flagpole', 'gadgetry'], '63': ['flatfoot', 'Galveston'], -'64': ['flytrap', 'getaway'], '65': ['fracture', 'glossary'], -'66': ['framework', 'gossamer'], '67': ['freedom', 'graduate'], -'68': ['frighten', 'gravity'], '69': ['gazelle', 'guitarist'], -'6A': ['Geiger', 'hamburger'], '6B': ['glitter', 'Hamilton'], -'6C': ['glucose', 'handiwork'], '6D': ['goggles', 'hazardous'], -'6E': ['goldfish', 'headwaters'], '6F': ['gremlin', 'hemisphere'], -'70': ['guidance', 'hesitate'], '71': ['hamlet', 'hideaway'], -'72': ['highchair', 'holiness'], '73': ['hockey', 'hurricane'], -'74': ['indoors', 'hydraulic'], '75': ['indulge', 'impartial'], -'76': ['inverse', 'impetus'], '77': ['involve', 'inception'], -'78': ['island', 'indigo'], '79': ['jawbone', 'inertia'], -'7A': ['keyboard', 'infancy'], '7B': ['kickoff', 'inferno'], -'7C': ['kiwi', 'informant'], '7D': ['klaxon', 'insincere'], -'7E': ['locale', 'insurgent'], '7F': ['lockup', 'integrate'], -'80': ['merit', 'intention'], '81': ['minnow', 'inventive'], -'82': ['miser', 'Istanbul'], '83': ['Mohawk', 'Jamaica'], -'84': ['mural', 'Jupiter'], '85': ['music', 'leprosy'], -'86': ['necklace', 'letterhead'], '87': ['Neptune', 'liberty'], -'88': ['newborn', 'maritime'], '89': ['nightbird', 'matchmaker'], -'8A': ['Oakland', 'maverick'], '8B': ['obtuse', 'Medusa'], -'8C': ['offload', 'megaton'], '8D': ['optic', 'microscope'], -'8E': ['orca', 'microwave'], '8F': ['payday', 'midsummer'], -'90': ['peachy', 'millionaire'], '91': ['pheasant', 'miracle'], -'92': ['physique', 'misnomer'], '93': ['playhouse', 'molasses'], -'94': ['Pluto', 'molecule'], '95': ['preclude', 'Montana'], -'96': ['prefer', 'monument'], '97': ['preshrunk', 'mosquito'], -'98': ['printer', 'narrative'], '99': ['prowler', 'nebula'], -'9A': ['pupil', 'newsletter'], '9B': ['puppy', 'Norwegian'], -'9C': ['python', 'October'], '9D': ['quadrant', 'Ohio'], -'9E': ['quiver', 'onlooker'], '9F': ['quota', 'opulent'], -'A0': ['ragtime', 'Orlando'], 'A1': ['ratchet', 'outfielder'], -'A2': ['rebirth', 'Pacific'], 'A3': ['reform', 'pandemic'], -'A4': ['regain', 'Pandora'], 'A5': ['reindeer', 'paperweight'], -'A6': ['rematch', 'paragon'], 'A7': ['repay', 'paragraph'], -'A8': ['retouch', 'paramount'], 'A9': ['revenge', 'passenger'], -'AA': ['reward', 'pedigree'], 'AB': ['rhythm', 'Pegasus'], -'AC': ['ribcage', 'penetrate'], 'AD': ['ringbolt', 'perceptive'], -'AE': ['robust', 'performance'], 'AF': ['rocker', 'pharmacy'], -'B0': ['ruffled', 'phonetic'], 'B1': ['sailboat', 'photograph'], -'B2': ['sawdust', 'pioneer'], 'B3': ['scallion', 'pocketful'], -'B4': ['scenic', 'politeness'], 'B5': ['scorecard', 'positive'], -'B6': ['Scotland', 'potato'], 'B7': ['seabird', 'processor'], -'B8': ['select', 'provincial'], 'B9': ['sentence', 'proximate'], -'BA': ['shadow', 'puberty'], 'BB': ['shamrock', 'publisher'], -'BC': ['showgirl', 'pyramid'], 'BD': ['skullcap', 'quantity'], -'BE': ['skydive', 'racketeer'], 'BF': ['slingshot', 'rebellion'], -'C0': ['slowdown', 'recipe'], 'C1': ['snapline', 'recover'], -'C2': ['snapshot', 'repellent'], 'C3': ['snowcap', 'replica'], -'C4': ['snowslide', 'reproduce'], 'C5': ['solo', 'resistor'], -'C6': ['southward', 'responsive'], 'C7': ['soybean', 'retraction'], -'C8': ['spaniel', 'retrieval'], 'C9': ['spearhead', 'retrospect'], -'CA': ['spellbind', 'revenue'], 'CB': ['spheroid', 'revival'], -'CC': ['spigot', 'revolver'], 'CD': ['spindle', 'sandalwood'], -'CE': ['spyglass', 'sardonic'], 'CF': ['stagehand', 'Saturday'], -'D0': ['stagnate', 'savagery'], 'D1': ['stairway', 'scavenger'], -'D2': ['standard', 'sensation'], 'D3': ['stapler', 'sociable'], -'D4': ['steamship', 'souvenir'], 'D5': ['sterling', 'specialist'], -'D6': ['stockman', 'speculate'], 'D7': ['stopwatch', 'stethoscope'], -'D8': ['stormy', 'stupendous'], 'D9': ['sugar', 'supportive'], -'DA': ['surmount', 'surrender'], 'DB': ['suspense', 'suspicious'], -'DC': ['sweatband', 'sympathy'], 'DD': ['swelter', 'tambourine'], -'DE': ['tactics', 'telephone'], 'DF': ['talon', 'therapist'], -'E0': ['tapeworm', 'tobacco'], 'E1': ['tempest', 'tolerance'], -'E2': ['tiger', 'tomorrow'], 'E3': ['tissue', 'torpedo'], -'E4': ['tonic', 'tradition'], 'E5': ['topmost', 'travesty'], -'E6': ['tracker', 'trombonist'], 'E7': ['transit', 'truncated'], -'E8': ['trauma', 'typewriter'], 'E9': ['treadmill', 'ultimate'], -'EA': ['Trojan', 'undaunted'], 'EB': ['trouble', 'underfoot'], -'EC': ['tumor', 'unicorn'], 'ED': ['tunnel', 'unify'], -'EE': ['tycoon', 'universe'], 'EF': ['uncut', 'unravel'], -'F0': ['unearth', 'upcoming'], 'F1': ['unwind', 'vacancy'], -'F2': ['uproot', 'vagabond'], 'F3': ['upset', 'vertigo'], -'F4': ['upshot', 'Virginia'], 'F5': ['vapor', 'visitor'], -'F6': ['village', 'vocalist'], 'F7': ['virus', 'voyager'], -'F8': ['Vulcan', 'warranty'], 'F9': ['waffle', 'Waterloo'], -'FA': ['wallet', 'whimsical'], 'FB': ['watchword', 'Wichita'], -'FC': ['wayside', 'Wilmington'], 'FD': ['willow', 'Wyoming'], -'FE': ['woodlark', 'yesteryear'], 'FF': ['Zulu', 'Yucatan'] -}; + '00': ['aardvark', 'adroitness'], + '01': ['absurd', 'adviser'], + '02': ['accrue', 'aftermath'], + '03': ['acme', 'aggregate'], + '04': ['adrift', 'alkali'], + '05': ['adult', 'almighty'], + '06': ['afflict', 'amulet'], + '07': ['ahead', 'amusement'], + '08': ['aimless', 'antenna'], + '09': ['Algol', 'applicant'], + '0A': ['allow', 'Apollo'], + '0B': ['alone', 'armistice'], + '0C': ['ammo', 'article'], + '0D': ['ancient', 'asteroid'], + '0E': ['apple', 'Atlantic'], + '0F': ['artist', 'atmosphere'], + '10': ['assume', 'autopsy'], + '11': ['Athens', 'Babylon'], + '12': ['atlas', 'backwater'], + '13': ['Aztec', 'barbecue'], + '14': ['baboon', 'belowground'], + '15': ['backfield', 'bifocals'], + '16': ['backward', 'bodyguard'], + '17': ['banjo', 'bookseller'], + '18': ['beaming', 'borderline'], + '19': ['bedlamp', 'bottomless'], + '1A': ['beehive', 'Bradbury'], + '1B': ['beeswax', 'bravado'], + '1C': ['befriend', 'Brazilian'], + '1D': ['Belfast', 'breakaway'], + '1E': ['berserk', 'Burlington'], + '1F': ['billiard', 'businessman'], + '20': ['bison', 'butterfat'], + '21': ['blackjack', 'Camelot'], + '22': ['blockade', 'candidate'], + '23': ['blowtorch', 'cannonball'], + '24': ['bluebird', 'Capricorn'], + '25': ['bombast', 'caravan'], + '26': ['bookshelf', 'caretaker'], + '27': ['brackish', 'celebrate'], + '28': ['breadline', 'cellulose'], + '29': ['breakup', 'certify'], + '2A': ['brickyard', 'chambermaid'], + '2B': ['briefcase', 'Cherokee'], + '2C': ['Burbank', 'Chicago'], + '2D': ['button', 'clergyman'], + '2E': ['buzzard', 'coherence'], + '2F': ['cement', 'combustion'], + '30': ['chairlift', 'commando'], + '31': ['chatter', 'company'], + '32': ['checkup', 'component'], + '33': ['chisel', 'concurrent'], + '34': ['choking', 'confidence'], + '35': ['chopper', 'conformist'], + '36': ['Christmas', 'congregate'], + '37': ['clamshell', 'consensus'], + '38': ['classic', 'consulting'], + '39': ['classroom', 'corporate'], + '3A': ['cleanup', 'corrosion'], + '3B': ['clockwork', 'councilman'], + '3C': ['cobra', 'crossover'], + '3D': ['commence', 'crucifix'], + '3E': ['concert', 'cumbersome'], + '3F': ['cowbell', 'customer'], + '40': ['crackdown', 'Dakota'], + '41': ['cranky', 'decadence'], + '42': ['crowfoot', 'December'], + '43': ['crucial', 'decimal'], + '44': ['crumpled', 'designing'], + '45': ['crusade', 'detector'], + '46': ['cubic', 'detergent'], + '47': ['dashboard', 'determine'], + '48': ['deadbolt', 'dictator'], + '49': ['deckhand', 'dinosaur'], + '4A': ['dogsled', 'direction'], + '4B': ['dragnet', 'disable'], + '4C': ['drainage', 'disbelief'], + '4D': ['dreadful', 'disruptive'], + '4E': ['drifter', 'distortion'], + '4F': ['dropper', 'document'], + '50': ['drumbeat', 'embezzle'], + '51': ['drunken', 'enchanting'], + '52': ['Dupont', 'enrollment'], + '53': ['dwelling', 'enterprise'], + '54': ['eating', 'equation'], + '55': ['edict', 'equipment'], + '56': ['egghead', 'escapade'], + '57': ['eightball', 'Eskimo'], + '58': ['endorse', 'everyday'], + '59': ['endow', 'examine'], + '5A': ['enlist', 'existence'], + '5B': ['erase', 'exodus'], + '5C': ['escape', 'fascinate'], + '5D': ['exceed', 'filament'], + '5E': ['eyeglass', 'finicky'], + '5F': ['eyetooth', 'forever'], + '60': ['facial', 'fortitude'], + '61': ['fallout', 'frequency'], + '62': ['flagpole', 'gadgetry'], + '63': ['flatfoot', 'Galveston'], + '64': ['flytrap', 'getaway'], + '65': ['fracture', 'glossary'], + '66': ['framework', 'gossamer'], + '67': ['freedom', 'graduate'], + '68': ['frighten', 'gravity'], + '69': ['gazelle', 'guitarist'], + '6A': ['Geiger', 'hamburger'], + '6B': ['glitter', 'Hamilton'], + '6C': ['glucose', 'handiwork'], + '6D': ['goggles', 'hazardous'], + '6E': ['goldfish', 'headwaters'], + '6F': ['gremlin', 'hemisphere'], + '70': ['guidance', 'hesitate'], + '71': ['hamlet', 'hideaway'], + '72': ['highchair', 'holiness'], + '73': ['hockey', 'hurricane'], + '74': ['indoors', 'hydraulic'], + '75': ['indulge', 'impartial'], + '76': ['inverse', 'impetus'], + '77': ['involve', 'inception'], + '78': ['island', 'indigo'], + '79': ['jawbone', 'inertia'], + '7A': ['keyboard', 'infancy'], + '7B': ['kickoff', 'inferno'], + '7C': ['kiwi', 'informant'], + '7D': ['klaxon', 'insincere'], + '7E': ['locale', 'insurgent'], + '7F': ['lockup', 'integrate'], + '80': ['merit', 'intention'], + '81': ['minnow', 'inventive'], + '82': ['miser', 'Istanbul'], + '83': ['Mohawk', 'Jamaica'], + '84': ['mural', 'Jupiter'], + '85': ['music', 'leprosy'], + '86': ['necklace', 'letterhead'], + '87': ['Neptune', 'liberty'], + '88': ['newborn', 'maritime'], + '89': ['nightbird', 'matchmaker'], + '8A': ['Oakland', 'maverick'], + '8B': ['obtuse', 'Medusa'], + '8C': ['offload', 'megaton'], + '8D': ['optic', 'microscope'], + '8E': ['orca', 'microwave'], + '8F': ['payday', 'midsummer'], + '90': ['peachy', 'millionaire'], + '91': ['pheasant', 'miracle'], + '92': ['physique', 'misnomer'], + '93': ['playhouse', 'molasses'], + '94': ['Pluto', 'molecule'], + '95': ['preclude', 'Montana'], + '96': ['prefer', 'monument'], + '97': ['preshrunk', 'mosquito'], + '98': ['printer', 'narrative'], + '99': ['prowler', 'nebula'], + '9A': ['pupil', 'newsletter'], + '9B': ['puppy', 'Norwegian'], + '9C': ['python', 'October'], + '9D': ['quadrant', 'Ohio'], + '9E': ['quiver', 'onlooker'], + '9F': ['quota', 'opulent'], + 'A0': ['ragtime', 'Orlando'], + 'A1': ['ratchet', 'outfielder'], + 'A2': ['rebirth', 'Pacific'], + 'A3': ['reform', 'pandemic'], + 'A4': ['regain', 'Pandora'], + 'A5': ['reindeer', 'paperweight'], + 'A6': ['rematch', 'paragon'], + 'A7': ['repay', 'paragraph'], + 'A8': ['retouch', 'paramount'], + 'A9': ['revenge', 'passenger'], + 'AA': ['reward', 'pedigree'], + 'AB': ['rhythm', 'Pegasus'], + 'AC': ['ribcage', 'penetrate'], + 'AD': ['ringbolt', 'perceptive'], + 'AE': ['robust', 'performance'], + 'AF': ['rocker', 'pharmacy'], + 'B0': ['ruffled', 'phonetic'], + 'B1': ['sailboat', 'photograph'], + 'B2': ['sawdust', 'pioneer'], + 'B3': ['scallion', 'pocketful'], + 'B4': ['scenic', 'politeness'], + 'B5': ['scorecard', 'positive'], + 'B6': ['Scotland', 'potato'], + 'B7': ['seabird', 'processor'], + 'B8': ['select', 'provincial'], + 'B9': ['sentence', 'proximate'], + 'BA': ['shadow', 'puberty'], + 'BB': ['shamrock', 'publisher'], + 'BC': ['showgirl', 'pyramid'], + 'BD': ['skullcap', 'quantity'], + 'BE': ['skydive', 'racketeer'], + 'BF': ['slingshot', 'rebellion'], + 'C0': ['slowdown', 'recipe'], + 'C1': ['snapline', 'recover'], + 'C2': ['snapshot', 'repellent'], + 'C3': ['snowcap', 'replica'], + 'C4': ['snowslide', 'reproduce'], + 'C5': ['solo', 'resistor'], + 'C6': ['southward', 'responsive'], + 'C7': ['soybean', 'retraction'], + 'C8': ['spaniel', 'retrieval'], + 'C9': ['spearhead', 'retrospect'], + 'CA': ['spellbind', 'revenue'], + 'CB': ['spheroid', 'revival'], + 'CC': ['spigot', 'revolver'], + 'CD': ['spindle', 'sandalwood'], + 'CE': ['spyglass', 'sardonic'], + 'CF': ['stagehand', 'Saturday'], + 'D0': ['stagnate', 'savagery'], + 'D1': ['stairway', 'scavenger'], + 'D2': ['standard', 'sensation'], + 'D3': ['stapler', 'sociable'], + 'D4': ['steamship', 'souvenir'], + 'D5': ['sterling', 'specialist'], + 'D6': ['stockman', 'speculate'], + 'D7': ['stopwatch', 'stethoscope'], + 'D8': ['stormy', 'stupendous'], + 'D9': ['sugar', 'supportive'], + 'DA': ['surmount', 'surrender'], + 'DB': ['suspense', 'suspicious'], + 'DC': ['sweatband', 'sympathy'], + 'DD': ['swelter', 'tambourine'], + 'DE': ['tactics', 'telephone'], + 'DF': ['talon', 'therapist'], + 'E0': ['tapeworm', 'tobacco'], + 'E1': ['tempest', 'tolerance'], + 'E2': ['tiger', 'tomorrow'], + 'E3': ['tissue', 'torpedo'], + 'E4': ['tonic', 'tradition'], + 'E5': ['topmost', 'travesty'], + 'E6': ['tracker', 'trombonist'], + 'E7': ['transit', 'truncated'], + 'E8': ['trauma', 'typewriter'], + 'E9': ['treadmill', 'ultimate'], + 'EA': ['Trojan', 'undaunted'], + 'EB': ['trouble', 'underfoot'], + 'EC': ['tumor', 'unicorn'], + 'ED': ['tunnel', 'unify'], + 'EE': ['tycoon', 'universe'], + 'EF': ['uncut', 'unravel'], + 'F0': ['unearth', 'upcoming'], + 'F1': ['unwind', 'vacancy'], + 'F2': ['uproot', 'vagabond'], + 'F3': ['upset', 'vertigo'], + 'F4': ['upshot', 'Virginia'], + 'F5': ['vapor', 'visitor'], + 'F6': ['village', 'vocalist'], + 'F7': ['virus', 'voyager'], + 'F8': ['Vulcan', 'warranty'], + 'F9': ['waffle', 'Waterloo'], + 'FA': ['wallet', 'whimsical'], + 'FB': ['watchword', 'Wichita'], + 'FC': ['wayside', 'Wilmington'], + 'FD': ['willow', 'Wyoming'], + 'FE': ['woodlark', 'yesteryear'], + 'FF': ['Zulu', 'Yucatan'] +} byte_to_even_word = dict([(unhexlify(k.encode("ascii")), both_words[0]) - for k,both_words - in raw_words.items()]) + for k, both_words in raw_words.items()]) byte_to_odd_word = dict([(unhexlify(k.encode("ascii")), both_words[1]) - for k,both_words - in raw_words.items()]) + for k, both_words in raw_words.items()]) even_words_lowercase, odd_words_lowercase = set(), set() -for k,both_words in raw_words.items(): +for k, both_words in raw_words.items(): even_word, odd_word = both_words even_words_lowercase.add(even_word.lower()) odd_words_lowercase.add(odd_word.lower()) + @implementer(IWordlist) class PGPWordList(object): def get_completions(self, prefix, num_words=2): @@ -177,7 +307,7 @@ class PGPWordList(object): else: suffix = prefix[:-lp] + word # append a hyphen if we expect more words - if count+1 < num_words: + if count + 1 < num_words: suffix += "-" completions.add(suffix) return completions diff --git a/src/wormhole/cli/cli.py b/src/wormhole/cli/cli.py index 58ccd8f..ba74157 100644 --- a/src/wormhole/cli/cli.py +++ b/src/wormhole/cli/cli.py @@ -2,21 +2,24 @@ from __future__ import print_function import os import time -start = time.time() -import six -from textwrap import fill, dedent -from sys import stdout, stderr -from . import public_relay -from .. import __version__ -from ..timing import DebugTiming -from ..errors import (WrongPasswordError, WelcomeError, KeyFormatError, - TransferError, NoTorError, UnsendableFileError, - ServerConnectionError) -from twisted.internet.defer import inlineCallbacks, maybeDeferred -from twisted.python.failure import Failure -from twisted.internet.task import react +from sys import stderr, stdout +from textwrap import dedent, fill import click +import six +from twisted.internet.defer import inlineCallbacks, maybeDeferred +from twisted.internet.task import react +from twisted.python.failure import Failure + +from . import public_relay +from .. import __version__ +from ..errors import (KeyFormatError, NoTorError, ServerConnectionError, + TransferError, UnsendableFileError, WelcomeError, + WrongPasswordError) +from ..timing import DebugTiming + +start = time.time() + top_import_finish = time.time() @@ -24,6 +27,7 @@ class Config(object): """ Union of config options that we pass down to (sub) commands. """ + def __init__(self): # This only holds attributes which are *not* set by CLI arguments. # Everything else comes from Click decorators, so we can be sure @@ -34,11 +38,13 @@ class Config(object): self.stderr = stderr self.tor = False # XXX? + def _compose(*decorators): def decorate(f): for d in reversed(decorators): f = d(f) return f + return decorate @@ -48,6 +54,8 @@ ALIASES = { "recieve": "receive", "recv": "receive", } + + class AliasedGroup(click.Group): def get_command(self, ctx, cmd_name): cmd_name = ALIASES.get(cmd_name, cmd_name) @@ -56,22 +64,24 @@ class AliasedGroup(click.Group): # top-level command ("wormhole ...") @click.group(cls=AliasedGroup) +@click.option("--appid", default=None, metavar="APPID", help="appid to use") @click.option( - "--appid", default=None, metavar="APPID", help="appid to use") -@click.option( - "--relay-url", default=public_relay.RENDEZVOUS_RELAY, + "--relay-url", + default=public_relay.RENDEZVOUS_RELAY, envvar='WORMHOLE_RELAY_URL', metavar="URL", help="rendezvous relay to use", ) @click.option( - "--transit-helper", default=public_relay.TRANSIT_RELAY, + "--transit-helper", + default=public_relay.TRANSIT_RELAY, envvar='WORMHOLE_TRANSIT_HELPER', metavar="tcp:HOST:PORT", help="transit relay to use", ) @click.option( - "--dump-timing", type=type(u""), # TODO: hide from --help output + "--dump-timing", + type=type(u""), # TODO: hide from --help output default=None, metavar="FILE.json", help="(debug) write timing data to file", @@ -104,7 +114,8 @@ def _dispatch_command(reactor, cfg, command): errors for the user. """ cfg.timing.add("command dispatch") - cfg.timing.add("import", when=start, which="top").finish(when=top_import_finish) + cfg.timing.add( + "import", when=start, which="top").finish(when=top_import_finish) try: yield maybeDeferred(command) @@ -141,56 +152,89 @@ def _dispatch_command(reactor, cfg, command): CommonArgs = _compose( - click.option("-0", "zeromode", default=False, is_flag=True, - help="enable no-code anything-goes mode", - ), - click.option("-c", "--code-length", default=2, metavar="NUMWORDS", - help="length of code (in bytes/words)", - ), - click.option("-v", "--verify", is_flag=True, default=False, - help="display verification string (and wait for approval)", - ), - click.option("--hide-progress", is_flag=True, default=False, - help="supress progress-bar display", - ), - click.option("--listen/--no-listen", default=True, - help="(debug) don't open a listening socket for Transit", - ), + click.option( + "-0", + "zeromode", + default=False, + is_flag=True, + help="enable no-code anything-goes mode", + ), + click.option( + "-c", + "--code-length", + default=2, + metavar="NUMWORDS", + help="length of code (in bytes/words)", + ), + click.option( + "-v", + "--verify", + is_flag=True, + default=False, + help="display verification string (and wait for approval)", + ), + click.option( + "--hide-progress", + is_flag=True, + default=False, + help="supress progress-bar display", + ), + click.option( + "--listen/--no-listen", + default=True, + help="(debug) don't open a listening socket for Transit", + ), ) TorArgs = _compose( - click.option("--tor", is_flag=True, default=False, - help="use Tor when connecting", - ), - click.option("--launch-tor", is_flag=True, default=False, - help="launch Tor, rather than use existing control/socks port", - ), - click.option("--tor-control-port", default=None, metavar="ENDPOINT", - help="endpoint descriptor for Tor control port", - ), + click.option( + "--tor", + is_flag=True, + default=False, + help="use Tor when connecting", + ), + click.option( + "--launch-tor", + is_flag=True, + default=False, + help="launch Tor, rather than use existing control/socks port", + ), + click.option( + "--tor-control-port", + default=None, + metavar="ENDPOINT", + help="endpoint descriptor for Tor control port", + ), ) + @wormhole.command() @click.pass_context def help(context, **kwargs): print(context.find_root().get_help()) + # wormhole send (or "wormhole tx") @wormhole.command() @CommonArgs @TorArgs @click.option( - "--code", metavar="CODE", + "--code", + metavar="CODE", help="human-generated code phrase", ) @click.option( - "--text", default=None, metavar="MESSAGE", - help="text message to send, instead of a file. Use '-' to read from stdin.", + "--text", + default=None, + metavar="MESSAGE", + help=("text message to send, instead of a file." + " Use '-' to read from stdin."), ) @click.option( - "--ignore-unsendable-files", default=False, is_flag=True, - help="Don't raise an error if a file can't be read." -) + "--ignore-unsendable-files", + default=False, + is_flag=True, + help="Don't raise an error if a file can't be read.") @click.argument("what", required=False, type=click.Path(path_type=type(u""))) @click.pass_obj def send(cfg, **kwargs): @@ -202,6 +246,7 @@ def send(cfg, **kwargs): return go(cmd_send.send, cfg) + # this intermediate function can be mocked by tests that need to build a # Config object def go(f, cfg): @@ -214,23 +259,29 @@ def go(f, cfg): @CommonArgs @TorArgs @click.option( - "--only-text", "-t", is_flag=True, + "--only-text", + "-t", + is_flag=True, help="refuse file transfers, only accept text transfers", ) @click.option( - "--accept-file", is_flag=True, + "--accept-file", + is_flag=True, help="accept file transfer without asking for confirmation", ) @click.option( - "--output-file", "-o", + "--output-file", + "-o", metavar="FILENAME|DIRNAME", help=("The file or directory to create, overriding the name suggested" " by the sender."), ) @click.argument( - "code", nargs=-1, default=None, -# help=("The magic-wormhole code, from the sender. If omitted, the" -# " program will ask for it, using tab-completion."), + "code", + nargs=-1, + default=None, + # help=("The magic-wormhole code, from the sender. If omitted, the" + # " program will ask for it, using tab-completion."), ) @click.pass_obj def receive(cfg, code, **kwargs): @@ -244,10 +295,8 @@ def receive(cfg, code, **kwargs): if len(code) == 1: cfg.code = code[0] elif len(code) > 1: - print( - "Pass either no code or just one code; you passed" - " {}: {}".format(len(code), ', '.join(code)) - ) + print("Pass either no code or just one code; you passed" + " {}: {}".format(len(code), ', '.join(code))) raise SystemExit(1) else: cfg.code = None @@ -260,17 +309,19 @@ def ssh(): """ Facilitate sending/receiving SSH public keys """ - pass @ssh.command(name="invite") @click.option( - "-c", "--code-length", default=2, + "-c", + "--code-length", + default=2, metavar="NUMWORDS", help="length of code (in bytes/words)", ) @click.option( - "--user", "-u", + "--user", + "-u", default=None, metavar="USER", help="Add to USER's ~/.ssh/authorized_keys", @@ -291,15 +342,20 @@ def ssh_invite(ctx, code_length, user, **kwargs): @ssh.command(name="accept") @click.argument( - "code", nargs=1, required=True, + "code", + nargs=1, + required=True, ) @click.option( - "--key-file", "-F", + "--key-file", + "-F", default=None, type=click.Path(exists=True), ) @click.option( - "--yes", "-y", is_flag=True, + "--yes", + "-y", + is_flag=True, help="Skip confirmation prompt to send key", ) @TorArgs @@ -318,7 +374,8 @@ def ssh_accept(cfg, code, key_file, yes, **kwargs): kind, keyid, pubkey = cmd_ssh.find_public_key(key_file) print("Sending public key type='{}' keyid='{}'".format(kind, keyid)) if yes is not True: - click.confirm("Really send public key '{}' ?".format(keyid), abort=True) + click.confirm( + "Really send public key '{}' ?".format(keyid), abort=True) cfg.public_key = (kind, keyid, pubkey) cfg.code = code diff --git a/src/wormhole/cli/cmd_receive.py b/src/wormhole/cli/cmd_receive.py index 1cb457b..4b18822 100644 --- a/src/wormhole/cli/cmd_receive.py +++ b/src/wormhole/cli/cmd_receive.py @@ -1,14 +1,23 @@ from __future__ import print_function -import os, sys, six, tempfile, zipfile, hashlib, shutil -from tqdm import tqdm + +import hashlib +import os +import shutil +import sys +import tempfile +import zipfile + +import six from humanize import naturalsize +from tqdm import tqdm from twisted.internet import reactor from twisted.internet.defer import inlineCallbacks, returnValue from twisted.python import log -from wormhole import create, input_with_completion, __version__ -from ..transit import TransitReceiver +from wormhole import __version__, create, input_with_completion + from ..errors import TransferError -from ..util import (dict_to_bytes, bytes_to_dict, bytes_to_hexstr, +from ..transit import TransitReceiver +from ..util import (bytes_to_dict, bytes_to_hexstr, dict_to_bytes, estimate_free_space) from .welcome import handle_welcome @@ -17,14 +26,17 @@ APPID = u"lothar.com/wormhole/text-or-file-xfer" KEY_TIMER = float(os.environ.get("_MAGIC_WORMHOLE_TEST_KEY_TIMER", 1.0)) VERIFY_TIMER = float(os.environ.get("_MAGIC_WORMHOLE_TEST_VERIFY_TIMER", 1.0)) + class RespondError(Exception): def __init__(self, response): self.response = response + class TransferRejectedError(RespondError): def __init__(self): RespondError.__init__(self, "transfer rejected") + def receive(args, reactor=reactor, _debug_stash_wormhole=None): """I implement 'wormhole receive'. I return a Deferred that fires with None (for success), or signals one of the following errors: @@ -60,16 +72,19 @@ class Receiver: # tor in parallel with everything else, make sure the Tor object # can lazy-provide an endpoint, and overlap the startup process # with the user handing off the wormhole code - self._tor = yield get_tor(self._reactor, - self.args.launch_tor, - self.args.tor_control_port, - timing=self.args.timing) + self._tor = yield get_tor( + self._reactor, + self.args.launch_tor, + self.args.tor_control_port, + timing=self.args.timing) - w = create(self.args.appid or APPID, self.args.relay_url, - self._reactor, - tor=self._tor, - timing=self.args.timing) - self._w = w # so tests can wait on events too + w = create( + self.args.appid or APPID, + self.args.relay_url, + self._reactor, + tor=self._tor, + timing=self.args.timing) + self._w = w # so tests can wait on events too # I wanted to do this instead: # @@ -87,7 +102,7 @@ class Receiver: # (which might be an error) @inlineCallbacks def _good(res): - yield w.close() # wait for ack + yield w.close() # wait for ack returnValue(res) # if we raise an error, we should close and then return the original @@ -96,8 +111,8 @@ class Receiver: @inlineCallbacks def _bad(f): try: - yield w.close() # might be an error too - except: + yield w.close() # might be an error too + except Exception: pass returnValue(f) @@ -114,6 +129,7 @@ class Receiver: def on_slow_key(): print(u"Waiting for sender...", file=self.args.stderr) + notify = self._reactor.callLater(KEY_TIMER, on_slow_key) try: # We wait here until we connect to the server and see the senders @@ -129,8 +145,10 @@ class Receiver: notify.cancel() def on_slow_verification(): - print(u"Key established, waiting for confirmation...", - file=self.args.stderr) + print( + u"Key established, waiting for confirmation...", + file=self.args.stderr) + notify = self._reactor.callLater(VERIFY_TIMER, on_slow_verification) try: # We wait here until we've seen their VERSION message (which they @@ -155,7 +173,7 @@ class Receiver: while True: them_d = yield self._get_data(w) - #print("GOT", them_d) + # print("GOT", them_d) recognized = False if u"transit" in them_d: recognized = True @@ -172,7 +190,7 @@ class Receiver: raise TransferError(r.response) returnValue(None) if not recognized: - log.msg("unrecognized message %r" % (them_d,)) + log.msg("unrecognized message %r" % (them_d, )) def _send_data(self, data, w): data_bytes = dict_to_bytes(data) @@ -197,12 +215,12 @@ class Receiver: w.set_code(code) else: prompt = "Enter receive wormhole code: " - used_completion = yield input_with_completion(prompt, - w.input_code(), - self._reactor) + used_completion = yield input_with_completion( + prompt, w.input_code(), self._reactor) if not used_completion: - print(" (note: you can use to complete words)", - file=self.args.stderr) + print( + " (note: you can use to complete words)", + file=self.args.stderr) yield w.get_code() def _show_verifier(self, verifier_bytes): @@ -220,21 +238,24 @@ class Receiver: @inlineCallbacks def _build_transit(self, w, sender_transit): - tr = TransitReceiver(self.args.transit_helper, - no_listen=(not self.args.listen), - tor=self._tor, - reactor=self._reactor, - timing=self.args.timing) + tr = TransitReceiver( + self.args.transit_helper, + no_listen=(not self.args.listen), + tor=self._tor, + reactor=self._reactor, + timing=self.args.timing) self._transit_receiver = tr - transit_key = w.derive_key(APPID+u"/transit-key", tr.TRANSIT_KEY_LENGTH) + transit_key = w.derive_key(APPID + u"/transit-key", + tr.TRANSIT_KEY_LENGTH) tr.set_transit_key(transit_key) tr.add_connection_hints(sender_transit.get("hints-v1", [])) receiver_abilities = tr.get_connection_abilities() receiver_hints = yield tr.get_connection_hints() - receiver_transit = {"abilities-v1": receiver_abilities, - "hints-v1": receiver_hints, - } + receiver_transit = { + "abilities-v1": receiver_abilities, + "hints-v1": receiver_hints, + } self._send_data({u"transit": receiver_transit}, w) # TODO: send more hints as the TransitReceiver produces them @@ -260,7 +281,7 @@ class Receiver: yield self._close_transit(rp, datahash) else: self._msg(u"I don't know what they're offering\n") - self._msg(u"Offer details: %r" % (them_d,)) + self._msg(u"Offer details: %r" % (them_d, )) raise RespondError("unknown offer type") def _handle_text(self, them_d, w): @@ -276,12 +297,13 @@ class Receiver: self.xfersize = file_data["filesize"] free = estimate_free_space(self.abs_destname) if free is not None and free < self.xfersize: - self._msg(u"Error: insufficient free space (%sB) for file (%sB)" - % (free, self.xfersize)) + self._msg(u"Error: insufficient free space (%sB) for file (%sB)" % + (free, self.xfersize)) raise TransferRejectedError() self._msg(u"Receiving file (%s) into: %s" % - (naturalsize(self.xfersize), os.path.basename(self.abs_destname))) + (naturalsize(self.xfersize), + os.path.basename(self.abs_destname))) self._ask_permission() tmp_destname = self.abs_destname + ".tmp" return open(tmp_destname, "wb") @@ -290,19 +312,22 @@ class Receiver: file_data = them_d["directory"] zipmode = file_data["mode"] if zipmode != "zipfile/deflated": - self._msg(u"Error: unknown directory-transfer mode '%s'" % (zipmode,)) + self._msg(u"Error: unknown directory-transfer mode '%s'" % + (zipmode, )) raise RespondError("unknown mode") self.abs_destname = self._decide_destname("directory", file_data["dirname"]) self.xfersize = file_data["zipsize"] free = estimate_free_space(self.abs_destname) if free is not None and free < file_data["numbytes"]: - self._msg(u"Error: insufficient free space (%sB) for directory (%sB)" - % (free, file_data["numbytes"])) + self._msg( + u"Error: insufficient free space (%sB) for directory (%sB)" % + (free, file_data["numbytes"])) raise TransferRejectedError() self._msg(u"Receiving directory (%s) into: %s/" % - (naturalsize(self.xfersize), os.path.basename(self.abs_destname))) + (naturalsize(self.xfersize), + os.path.basename(self.abs_destname))) self._msg(u"%d files, %s (uncompressed)" % (file_data["numfiles"], naturalsize(file_data["numbytes"]))) self._ask_permission() @@ -313,23 +338,26 @@ class Receiver: # "~/.ssh/authorized_keys" and other attacks destname = os.path.basename(destname) if self.args.output_file: - destname = self.args.output_file # override - abs_destname = os.path.abspath( os.path.join(self.args.cwd, destname) ) + destname = self.args.output_file # override + abs_destname = os.path.abspath(os.path.join(self.args.cwd, destname)) # get confirmation from the user before writing to the local directory if os.path.exists(abs_destname): - if self.args.output_file: # overwrite is intentional + if self.args.output_file: # overwrite is intentional self._msg(u"Overwriting '%s'" % destname) if self.args.accept_file: self._remove_existing(abs_destname) else: - self._msg(u"Error: refusing to overwrite existing '%s'" % destname) + self._msg( + u"Error: refusing to overwrite existing '%s'" % destname) raise TransferRejectedError() return abs_destname def _remove_existing(self, path): - if os.path.isfile(path): os.remove(path) - if os.path.isdir(path): shutil.rmtree(path) + if os.path.isfile(path): + os.remove(path) + if os.path.isdir(path): + shutil.rmtree(path) def _ask_permission(self): with self.args.timing.add("permission", waiting="user") as t: @@ -345,7 +373,7 @@ class Receiver: t.detail(answer="yes") def _send_permission(self, w): - self._send_data({"answer": { "file_ack": "ok" }}, w) + self._send_data({"answer": {"file_ack": "ok"}}, w) @inlineCallbacks def _establish_transit(self): @@ -359,14 +387,16 @@ class Receiver: self._msg(u"Receiving (%s).." % record_pipe.describe()) with self.args.timing.add("rx file"): - progress = tqdm(file=self.args.stderr, - disable=self.args.hide_progress, - unit="B", unit_scale=True, total=self.xfersize) + progress = tqdm( + file=self.args.stderr, + disable=self.args.hide_progress, + unit="B", + unit_scale=True, + total=self.xfersize) hasher = hashlib.sha256() with progress: - received = yield record_pipe.writeToFile(f, self.xfersize, - progress.update, - hasher.update) + received = yield record_pipe.writeToFile( + f, self.xfersize, progress.update, hasher.update) datahash = hasher.digest() # except TransitError @@ -382,25 +412,26 @@ class Receiver: tmp_name = f.name f.close() os.rename(tmp_name, self.abs_destname) - self._msg(u"Received file written to %s" % - os.path.basename(self.abs_destname)) + self._msg(u"Received file written to %s" % os.path.basename( + self.abs_destname)) def _extract_file(self, zf, info, extract_dir): """ the zipfile module does not restore file permissions so we'll do it manually """ - out_path = os.path.join( extract_dir, info.filename ) - out_path = os.path.abspath( out_path ) - if not out_path.startswith( extract_dir ): - raise ValueError( "malicious zipfile, %s outside of extract_dir %s" - % (info.filename, extract_dir) ) + out_path = os.path.join(extract_dir, info.filename) + out_path = os.path.abspath(out_path) + if not out_path.startswith(extract_dir): + raise ValueError( + "malicious zipfile, %s outside of extract_dir %s" % + (info.filename, extract_dir)) - zf.extract( info.filename, path=extract_dir ) + zf.extract(info.filename, path=extract_dir) # not sure why zipfiles store the perms 16 bits away but they do perm = info.external_attr >> 16 - os.chmod( out_path, perm ) + os.chmod(out_path, perm) def _write_directory(self, f): @@ -408,10 +439,10 @@ class Receiver: with self.args.timing.add("unpack zip"): with zipfile.ZipFile(f, "r", zipfile.ZIP_DEFLATED) as zf: for info in zf.infolist(): - self._extract_file( zf, info, self.abs_destname ) + self._extract_file(zf, info, self.abs_destname) - self._msg(u"Received files written to %s/" % - os.path.basename(self.abs_destname)) + self._msg(u"Received files written to %s/" % os.path.basename( + self.abs_destname)) f.close() @inlineCallbacks diff --git a/src/wormhole/cli/cmd_send.py b/src/wormhole/cli/cmd_send.py index 969ab0f..ea5d69e 100644 --- a/src/wormhole/cli/cmd_send.py +++ b/src/wormhole/cli/cmd_send.py @@ -1,20 +1,29 @@ from __future__ import print_function -import os, sys, six, tempfile, zipfile, hashlib -from tqdm import tqdm + +import hashlib +import os +import sys +import tempfile +import zipfile + +import six from humanize import naturalsize -from twisted.python import log -from twisted.protocols import basic +from tqdm import tqdm from twisted.internet import reactor from twisted.internet.defer import inlineCallbacks, returnValue +from twisted.protocols import basic +from twisted.python import log +from wormhole import __version__, create + from ..errors import TransferError, UnsendableFileError -from wormhole import create, __version__ from ..transit import TransitSender -from ..util import dict_to_bytes, bytes_to_dict, bytes_to_hexstr +from ..util import bytes_to_dict, bytes_to_hexstr, dict_to_bytes from .welcome import handle_welcome APPID = u"lothar.com/wormhole/text-or-file-xfer" VERIFY_TIMER = float(os.environ.get("_MAGIC_WORMHOLE_TEST_VERIFY_TIMER", 1.0)) + def send(args, reactor=reactor): """I implement 'wormhole send'. I return a Deferred that fires with None (for success), or signals one of the following errors: @@ -26,6 +35,7 @@ def send(args, reactor=reactor): """ return Sender(args, reactor).go() + class Sender: def __init__(self, args, reactor): self._args = args @@ -45,22 +55,25 @@ class Sender: # tor in parallel with everything else, make sure the Tor object # can lazy-provide an endpoint, and overlap the startup process # with the user handing off the wormhole code - self._tor = yield get_tor(reactor, - self._args.launch_tor, - self._args.tor_control_port, - timing=self._timing) + self._tor = yield get_tor( + reactor, + self._args.launch_tor, + self._args.tor_control_port, + timing=self._timing) - w = create(self._args.appid or APPID, self._args.relay_url, - self._reactor, - tor=self._tor, - timing=self._timing) + w = create( + self._args.appid or APPID, + self._args.relay_url, + self._reactor, + tor=self._tor, + timing=self._timing) d = self._go(w) # if we succeed, we should close and return the w.close results # (which might be an error) @inlineCallbacks def _good(res): - yield w.close() # wait for ack + yield w.close() # wait for ack returnValue(res) # if we raise an error, we should close and then return the original @@ -69,8 +82,8 @@ class Sender: @inlineCallbacks def _bad(f): try: - yield w.close() # might be an error too - except: + yield w.close() # might be an error too + except Exception: pass returnValue(f) @@ -125,8 +138,10 @@ class Sender: # TODO: don't stall on w.get_verifier() unless they want it def on_slow_connection(): - print(u"Key established, waiting for confirmation...", - file=args.stderr) + print( + u"Key established, waiting for confirmation...", + file=args.stderr) + notify = self._reactor.callLater(VERIFY_TIMER, on_slow_connection) try: # The usual sender-chooses-code sequence means the receiver's @@ -137,32 +152,35 @@ class Sender: # sitting here for a while, so printing the "waiting" message # seems like a good idea. It might even be appropriate to give up # after a while. - verifier_bytes = yield w.get_verifier() # might WrongPasswordError + verifier_bytes = yield w.get_verifier() # might WrongPasswordError finally: if not notify.called: notify.cancel() if args.verify: - self._check_verifier(w, verifier_bytes) # blocks, can TransferError + self._check_verifier(w, + verifier_bytes) # blocks, can TransferError if self._fd_to_send: - ts = TransitSender(args.transit_helper, - no_listen=(not args.listen), - tor=self._tor, - reactor=self._reactor, - timing=self._timing) + ts = TransitSender( + args.transit_helper, + no_listen=(not args.listen), + tor=self._tor, + reactor=self._reactor, + timing=self._timing) self._transit_sender = ts # for now, send this before the main offer sender_abilities = ts.get_connection_abilities() sender_hints = yield ts.get_connection_hints() - sender_transit = {"abilities-v1": sender_abilities, - "hints-v1": sender_hints, - } + sender_transit = { + "abilities-v1": sender_abilities, + "hints-v1": sender_hints, + } self._send_data({u"transit": sender_transit}, w) # TODO: move this down below w.get_message() - transit_key = w.derive_key(APPID+"/transit-key", + transit_key = w.derive_key(APPID + "/transit-key", ts.TRANSIT_KEY_LENGTH) ts.set_transit_key(transit_key) @@ -175,11 +193,11 @@ class Sender: # TODO: get_message() fired, so get_verifier must have fired, so # now it's safe to use w.derive_key() them_d = bytes_to_dict(them_d_bytes) - #print("GOT", them_d) + # print("GOT", them_d) recognized = False if u"error" in them_d: - raise TransferError("remote error, transfer abandoned: %s" - % them_d["error"]) + raise TransferError( + "remote error, transfer abandoned: %s" % them_d["error"]) if u"transit" in them_d: recognized = True yield self._handle_transit(them_d[u"transit"]) @@ -191,7 +209,7 @@ class Sender: yield self._handle_answer(them_d[u"answer"]) returnValue(None) if not recognized: - log.msg("unrecognized message %r" % (them_d,)) + log.msg("unrecognized message %r" % (them_d, )) def _check_verifier(self, w, verifier_bytes): verifier = bytes_to_hexstr(verifier_bytes) @@ -221,9 +239,10 @@ class Sender: text = six.moves.input("Text to send: ") if text is not None: - print(u"Sending text message (%s)" % naturalsize(len(text)), - file=args.stderr) - offer = { "message": text } + print( + u"Sending text message (%s)" % naturalsize(len(text)), + file=args.stderr) + offer = {"message": text} fd_to_send = None return offer, fd_to_send @@ -244,7 +263,7 @@ class Sender: # is a symlink to something with a different name. The normpath() is # there to remove trailing slashes. basename = os.path.basename(os.path.normpath(what)) - assert basename != "", what # normpath shouldn't allow this + assert basename != "", what # normpath shouldn't allow this # We use realpath() instead of normpath() to locate the actual # file/directory, because the path might contain symlinks, and @@ -273,8 +292,8 @@ class Sender: what = os.path.realpath(what) if not os.path.exists(what): - raise TransferError("Cannot send: no file/directory named '%s'" % - args.what) + raise TransferError( + "Cannot send: no file/directory named '%s'" % args.what) if os.path.isfile(what): # we're sending a file @@ -282,10 +301,11 @@ class Sender: offer["file"] = { "filename": basename, "filesize": filesize, - } - print(u"Sending %s file named '%s'" - % (naturalsize(filesize), basename), - file=args.stderr) + } + print( + u"Sending %s file named '%s'" % (naturalsize(filesize), + basename), + file=args.stderr) fd_to_send = open(what, "rb") return offer, fd_to_send @@ -297,16 +317,18 @@ class Sender: num_files = 0 num_bytes = 0 tostrip = len(what.split(os.sep)) - with zipfile.ZipFile(fd_to_send, "w", - compression=zipfile.ZIP_DEFLATED, - allowZip64=True) as zf: - for path,dirs,files in os.walk(what): + with zipfile.ZipFile( + fd_to_send, + "w", + compression=zipfile.ZIP_DEFLATED, + allowZip64=True) as zf: + for path, dirs, files in os.walk(what): # path always starts with args.what, then sometimes might # have "/subdir" appended. We want the zipfile to contain # "" or "subdir" localpath = list(path.split(os.sep)[tostrip:]) for fn in files: - archivename = os.path.join(*tuple(localpath+[fn])) + archivename = os.path.join(*tuple(localpath + [fn])) localfilename = os.path.join(path, fn) try: zf.write(localfilename, archivename) @@ -315,22 +337,25 @@ class Sender: except OSError as e: errmsg = u"{}: {}".format(fn, e.strerror) if self._args.ignore_unsendable_files: - print(u"{} (ignoring error)".format(errmsg), - file=args.stderr) + print( + u"{} (ignoring error)".format(errmsg), + file=args.stderr) else: raise UnsendableFileError(errmsg) - fd_to_send.seek(0,2) + fd_to_send.seek(0, 2) filesize = fd_to_send.tell() - fd_to_send.seek(0,0) + fd_to_send.seek(0, 0) offer["directory"] = { "mode": "zipfile/deflated", "dirname": basename, "zipsize": filesize, "numbytes": num_bytes, "numfiles": num_files, - } - print(u"Sending directory (%s compressed) named '%s'" - % (naturalsize(filesize), basename), file=args.stderr) + } + print( + u"Sending directory (%s compressed) named '%s'" % + (naturalsize(filesize), basename), + file=args.stderr) return offer, fd_to_send raise TypeError("'%s' is neither file nor directory" % args.what) @@ -340,23 +365,22 @@ class Sender: if self._fd_to_send is None: if them_answer["message_ack"] == "ok": print(u"text message sent", file=self._args.stderr) - returnValue(None) # terminates this function - raise TransferError("error sending text: %r" % (them_answer,)) + returnValue(None) # terminates this function + raise TransferError("error sending text: %r" % (them_answer, )) if them_answer.get("file_ack") != "ok": raise TransferError("ambiguous response from remote, " - "transfer abandoned: %s" % (them_answer,)) + "transfer abandoned: %s" % (them_answer, )) yield self._send_file() - @inlineCallbacks def _send_file(self): ts = self._transit_sender - self._fd_to_send.seek(0,2) + self._fd_to_send.seek(0, 2) filesize = self._fd_to_send.tell() - self._fd_to_send.seek(0,0) + self._fd_to_send.seek(0, 0) record_pipe = yield ts.connect() self._timing.add("transit connected") @@ -365,21 +389,28 @@ class Sender: print(u"Sending (%s).." % record_pipe.describe(), file=stderr) hasher = hashlib.sha256() - progress = tqdm(file=stderr, disable=self._args.hide_progress, - unit="B", unit_scale=True, - total=filesize) + progress = tqdm( + file=stderr, + disable=self._args.hide_progress, + unit="B", + unit_scale=True, + total=filesize) + def _count_and_hash(data): hasher.update(data) progress.update(len(data)) return data + fs = basic.FileSender() with self._timing.add("tx file"): with progress: if filesize: # don't send zero-length files - yield fs.beginFileTransfer(self._fd_to_send, record_pipe, - transform=_count_and_hash) + yield fs.beginFileTransfer( + self._fd_to_send, + record_pipe, + transform=_count_and_hash) expected_hash = hasher.digest() expected_hex = bytes_to_hexstr(expected_hash) diff --git a/src/wormhole/cli/cmd_ssh.py b/src/wormhole/cli/cmd_ssh.py index 8ac60a1..cbd5556 100644 --- a/src/wormhole/cli/cmd_ssh.py +++ b/src/wormhole/cli/cmd_ssh.py @@ -1,16 +1,19 @@ from __future__ import print_function import os -from os.path import expanduser, exists, join -from twisted.internet.defer import inlineCallbacks -from twisted.internet import reactor +from os.path import exists, expanduser, join + import click +from twisted.internet import reactor +from twisted.internet.defer import inlineCallbacks from .. import xfer_util + class PubkeyError(Exception): pass + def find_public_key(hint=None): """ This looks for an appropriate SSH key to send, possibly querying @@ -34,8 +37,9 @@ def find_public_key(hint=None): got_key = False while not got_key: ans = click.prompt( - "Multiple public-keys found:\n" + \ - "\n".join([" {}: {}".format(a, b) for a, b in enumerate(pubkeys)]) + \ + "Multiple public-keys found:\n" + + "\n".join([" {}: {}".format(a, b) + for a, b in enumerate(pubkeys)]) + "\nSend which one?" ) try: @@ -76,7 +80,6 @@ def accept(cfg, reactor=reactor): @inlineCallbacks def invite(cfg, reactor=reactor): - def on_code_created(code): print("Now tell the other user to run:") print() diff --git a/src/wormhole/cli/public_relay.py b/src/wormhole/cli/public_relay.py index 3cb0e62..c816bb2 100644 --- a/src/wormhole/cli/public_relay.py +++ b/src/wormhole/cli/public_relay.py @@ -1,4 +1,3 @@ - # This is a relay I run on a personal server. If it gets too expensive to # run, I'll shut it down. RENDEZVOUS_RELAY = u"ws://relay.magic-wormhole.io:4000/v1" diff --git a/src/wormhole/cli/welcome.py b/src/wormhole/cli/welcome.py index 93ac1cb..2bbfc7e 100644 --- a/src/wormhole/cli/welcome.py +++ b/src/wormhole/cli/welcome.py @@ -1,18 +1,23 @@ -from __future__ import print_function, absolute_import, unicode_literals +from __future__ import absolute_import, print_function, unicode_literals + def handle_welcome(welcome, relay_url, my_version, stderr): if "motd" in welcome: motd_lines = welcome["motd"].splitlines() motd_formatted = "\n ".join(motd_lines) - print("Server (at %s) says:\n %s" % (relay_url, motd_formatted), - file=stderr) + print( + "Server (at %s) says:\n %s" % (relay_url, motd_formatted), + file=stderr) # Only warn if we're running a release version (e.g. 0.0.6, not # 0.0.6+DISTANCE.gHASH). Only warn once. - if ("current_cli_version" in welcome - and "+" not in my_version - and welcome["current_cli_version"] != my_version): - print("Warning: errors may occur unless both sides are running the same version", file=stderr) - print("Server claims %s is current, but ours is %s" - % (welcome["current_cli_version"], my_version), - file=stderr) + if ("current_cli_version" in welcome and "+" not in my_version + and welcome["current_cli_version"] != my_version): + print( + ("Warning: errors may occur unless both sides are running the" + " same version"), + file=stderr) + print( + "Server claims %s is current, but ours is %s" % + (welcome["current_cli_version"], my_version), + file=stderr) diff --git a/src/wormhole/errors.py b/src/wormhole/errors.py index 8bd9718..64edf37 100644 --- a/src/wormhole/errors.py +++ b/src/wormhole/errors.py @@ -1,8 +1,10 @@ from __future__ import unicode_literals + class WormholeError(Exception): """Parent class for all wormhole-related errors""" + class UnsendableFileError(Exception): """ A file you wanted to send couldn't be read, maybe because it's not @@ -13,30 +15,38 @@ class UnsendableFileError(Exception): --ignore-unsendable-files flag. """ + class ServerError(WormholeError): """The relay server complained about something we did.""" + class ServerConnectionError(WormholeError): """We had a problem connecting to the relay server:""" + def __init__(self, url, reason): self.url = url self.reason = reason + def __str__(self): return str(self.reason) + class Timeout(WormholeError): pass + class WelcomeError(WormholeError): """ The relay server told us to signal an error, probably because our version is too old to possibly work. The server said:""" pass + class LonelyError(WormholeError): """wormhole.close() was called before the peer connection could be established""" + class WrongPasswordError(WormholeError): """ Key confirmation failed. Either you or your correspondent typed the code @@ -47,6 +57,7 @@ class WrongPasswordError(WormholeError): # or the data blob was corrupted, and that's why decrypt failed pass + class KeyFormatError(WormholeError): """ The key you entered contains spaces or was missing a dash. Magic-wormhole @@ -55,43 +66,61 @@ class KeyFormatError(WormholeError): dashes. """ + class ReflectionAttack(WormholeError): """An attacker (or bug) reflected our outgoing message back to us.""" + class InternalError(WormholeError): """The programmer did something wrong.""" + class TransferError(WormholeError): """Something bad happened and the transfer failed.""" + class NoTorError(WormholeError): """--tor was requested, but 'txtorcon' is not installed.""" + class NoKeyError(WormholeError): """w.derive_key() was called before got_verifier() fired""" + class OnlyOneCodeError(WormholeError): """Only one w.generate_code/w.set_code/w.input_code may be called""" + class MustChooseNameplateFirstError(WormholeError): """The InputHelper was asked to do get_word_completions() or choose_words() before the nameplate was chosen.""" + + class AlreadyChoseNameplateError(WormholeError): """The InputHelper was asked to do get_nameplate_completions() after choose_nameplate() was called, or choose_nameplate() was called a second time.""" + + class AlreadyChoseWordsError(WormholeError): """The InputHelper was asked to do get_word_completions() after choose_words() was called, or choose_words() was called a second time.""" + + class AlreadyInputNameplateError(WormholeError): """The CodeInputter was asked to do completion on a nameplate, when we had already committed to a different one.""" + + class WormholeClosed(Exception): """Deferred-returning API calls errback with WormholeClosed if the wormhole was already closed, or if it closes before a real result can be obtained.""" + class _UnknownPhaseError(Exception): """internal exception type, for tests.""" + + class _UnknownMessageTypeError(Exception): """internal exception type, for tests.""" diff --git a/src/wormhole/eventual.py b/src/wormhole/eventual.py index 4fc8731..1d8e061 100644 --- a/src/wormhole/eventual.py +++ b/src/wormhole/eventual.py @@ -5,6 +5,7 @@ from twisted.internet.defer import Deferred from twisted.internet.interfaces import IReactorTime from twisted.python import log + class EventualQueue(object): def __init__(self, clock): # pass clock=reactor unless you're testing @@ -14,7 +15,7 @@ class EventualQueue(object): self._timer = None def eventually(self, f, *args, **kwargs): - self._calls.append( (f, args, kwargs) ) + self._calls.append((f, args, kwargs)) if not self._timer: self._timer = self._clock.callLater(0, self._turn) @@ -28,7 +29,7 @@ class EventualQueue(object): (f, args, kwargs) = self._calls.pop(0) try: f(*args, **kwargs) - except: + except Exception: log.err() self._timer = None d, self._flush_d = self._flush_d, None diff --git a/src/wormhole/ipaddrs.py b/src/wormhole/ipaddrs.py index 2f35d55..f6c8bb3 100644 --- a/src/wormhole/ipaddrs.py +++ b/src/wormhole/ipaddrs.py @@ -1,27 +1,37 @@ # no unicode_literals # Find all of our ip addresses. From tahoe's src/allmydata/util/iputil.py -import os, re, subprocess, errno +import errno +import os +import re +import subprocess from sys import platform + from twisted.python.procutils import which # Wow, I'm really amazed at home much mileage we've gotten out of calling # the external route.exe program on windows... It appears to work on all # versions so far. Still, the real system calls would much be preferred... # ... thus wrote Greg Smith in time immemorial... -_win32_re = re.compile(r'^\s*\d+\.\d+\.\d+\.\d+\s.+\s(?P
\d+\.\d+\.\d+\.\d+)\s+(?P\d+)\s*$', flags=re.M|re.I|re.S) -_win32_commands = (('route.exe', ('print',), _win32_re),) +_win32_re = re.compile( + (r'^\s*\d+\.\d+\.\d+\.\d+\s.+\s' + r'(?P
\d+\.\d+\.\d+\.\d+)\s+(?P\d+)\s*$'), + flags=re.M | re.I | re.S) +_win32_commands = (('route.exe', ('print', ), _win32_re), ) # These work in most Unices. -_addr_re = re.compile(r'^\s*inet [a-zA-Z]*:?(?P
\d+\.\d+\.\d+\.\d+)[\s/].+$', flags=re.M|re.I|re.S) -_unix_commands = (('/bin/ip', ('addr',), _addr_re), - ('/sbin/ip', ('addr',), _addr_re), - ('/sbin/ifconfig', ('-a',), _addr_re), - ('/usr/sbin/ifconfig', ('-a',), _addr_re), - ('/usr/etc/ifconfig', ('-a',), _addr_re), - ('ifconfig', ('-a',), _addr_re), - ('/sbin/ifconfig', (), _addr_re), - ) +_addr_re = re.compile( + r'^\s*inet [a-zA-Z]*:?(?P
\d+\.\d+\.\d+\.\d+)[\s/].+$', + flags=re.M | re.I | re.S) +_unix_commands = ( + ('/bin/ip', ('addr', ), _addr_re), + ('/sbin/ip', ('addr', ), _addr_re), + ('/sbin/ifconfig', ('-a', ), _addr_re), + ('/usr/sbin/ifconfig', ('-a', ), _addr_re), + ('/usr/etc/ifconfig', ('-a', ), _addr_re), + ('ifconfig', ('-a', ), _addr_re), + ('/sbin/ifconfig', (), _addr_re), +) def find_addresses(): @@ -54,17 +64,19 @@ def find_addresses(): return ["127.0.0.1"] + def _query(path, args, regex): env = {'LANG': 'en_US.UTF-8'} trial = 0 while True: trial += 1 try: - p = subprocess.Popen([path] + list(args), - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - env=env, - universal_newlines=True) + p = subprocess.Popen( + [path] + list(args), + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + env=env, + universal_newlines=True) (output, err) = p.communicate() break except OSError as e: diff --git a/src/wormhole/journal.py b/src/wormhole/journal.py index f7bf0f3..b7cffb5 100644 --- a/src/wormhole/journal.py +++ b/src/wormhole/journal.py @@ -1,8 +1,12 @@ -from __future__ import print_function, absolute_import, unicode_literals -from zope.interface import implementer +from __future__ import absolute_import, print_function, unicode_literals + import contextlib + +from zope.interface import implementer + from ._interfaces import IJournal + @implementer(IJournal) class Journal(object): def __init__(self, save_checkpoint): @@ -19,7 +23,7 @@ class Journal(object): assert not self._processing assert not self._outbound_queue self._processing = True - yield # process inbound messages, change state, queue outbound + yield # process inbound messages, change state, queue outbound self._save_checkpoint() for (fn, args, kwargs) in self._outbound_queue: fn(*args, **kwargs) @@ -31,8 +35,10 @@ class Journal(object): class ImmediateJournal(object): def __init__(self): pass + def queue_outbound(self, fn, *args, **kwargs): fn(*args, **kwargs) + @contextlib.contextmanager def process(self): yield diff --git a/src/wormhole/observer.py b/src/wormhole/observer.py index 99a22a0..43afa70 100644 --- a/src/wormhole/observer.py +++ b/src/wormhole/observer.py @@ -1,14 +1,16 @@ -from __future__ import unicode_literals, print_function +from __future__ import print_function, unicode_literals + from twisted.internet.defer import Deferred from twisted.python.failure import Failure NoResult = object() + class OneShotObserver(object): def __init__(self, eventual_queue): self._eq = eventual_queue self._result = NoResult - self._observers = [] # list of Deferreds + self._observers = [] # list of Deferreds def when_fired(self): d = Deferred() @@ -38,6 +40,7 @@ class OneShotObserver(object): if self._result is NoResult: self.fire(result) + class SequenceObserver(object): def __init__(self, eventual_queue): self._eq = eventual_queue diff --git a/src/wormhole/test/common.py b/src/wormhole/test/common.py index 0b3897e..3521299 100644 --- a/src/wormhole/test/common.py +++ b/src/wormhole/test/common.py @@ -1,16 +1,19 @@ # no unicode_literals untill twisted update -from twisted.application import service, internet -from twisted.internet import defer, task, reactor, endpoints -from twisted.python import log from click.testing import CliRunner +from twisted.application import internet, service +from twisted.internet import defer, endpoints, reactor, task +from twisted.python import log + import mock -from ..cli import cli -from ..transit import allocate_tcp_port +from wormhole_mailbox_server.database import create_channel_db, create_usage_db from wormhole_mailbox_server.server import make_server from wormhole_mailbox_server.web import make_web_server -from wormhole_mailbox_server.database import create_channel_db, create_usage_db from wormhole_transit_relay.transit_server import Transit +from ..cli import cli +from ..transit import allocate_tcp_port + + class MyInternetService(service.Service, object): # like StreamServerEndpointService, but you can retrieve the port def __init__(self, endpoint, factory): @@ -22,12 +25,15 @@ class MyInternetService(service.Service, object): def startService(self): super(MyInternetService, self).startService() d = self.endpoint.listen(self.factory) + def good(lp): self._lp = lp self._port_d.callback(lp.getHost().port) + def bad(f): log.err(f) self._port_d.errback(f) + d.addCallbacks(good, bad) @defer.inlineCallbacks @@ -35,9 +41,10 @@ class MyInternetService(service.Service, object): if self._lp: yield self._lp.stopListening() - def getPort(self): # only call once! + def getPort(self): # only call once! return self._port_d + class ServerBase: @defer.inlineCallbacks def setUp(self): @@ -51,27 +58,27 @@ class ServerBase: # endpoints.serverFromString db = create_channel_db(":memory:") self._usage_db = create_usage_db(":memory:") - self._rendezvous = make_server(db, - advertise_version=advertise_version, - signal_error=error, - usage_db=self._usage_db) + self._rendezvous = make_server( + db, + advertise_version=advertise_version, + signal_error=error, + usage_db=self._usage_db) ep = endpoints.TCP4ServerEndpoint(reactor, 0, interface="127.0.0.1") site = make_web_server(self._rendezvous, log_requests=False) - #self._lp = yield ep.listen(site) + # self._lp = yield ep.listen(site) s = MyInternetService(ep, site) s.setServiceParent(self.sp) self.rdv_ws_port = yield s.getPort() self._relay_server = s - #self._rendezvous = s._rendezvous + # self._rendezvous = s._rendezvous self.relayurl = u"ws://127.0.0.1:%d/v1" % self.rdv_ws_port # ws://127.0.0.1:%d/wormhole-relay/ws self.transitport = allocate_tcp_port() - ep = endpoints.serverFromString(reactor, - "tcp:%d:interface=127.0.0.1" % - self.transitport) - self._transit_server = f = Transit(blur_usage=None, log_file=None, - usage_db=None) + ep = endpoints.serverFromString( + reactor, "tcp:%d:interface=127.0.0.1" % self.transitport) + self._transit_server = f = Transit( + blur_usage=None, log_file=None, usage_db=None) internet.StreamServerEndpointService(ep, f).setServiceParent(self.sp) self.transit = u"tcp:127.0.0.1:%d" % self.transitport @@ -109,6 +116,7 @@ class ServerBase: " I convinced all threads to exit.") yield d + def config(*argv): r = CliRunner() with mock.patch("wormhole.cli.cli.go") as go: @@ -121,6 +129,7 @@ def config(*argv): cfg = go.call_args[0][1] return cfg + @defer.inlineCallbacks def poll_until(predicate): # return a Deferred that won't fire until the predicate is True diff --git a/src/wormhole/test/run_trial.py b/src/wormhole/test/run_trial.py index 1c9ffee..1e218cd 100644 --- a/src/wormhole/test/run_trial.py +++ b/src/wormhole/test/run_trial.py @@ -1,4 +1,5 @@ from __future__ import unicode_literals + # This is a tiny helper module, to let "python -m wormhole.test.run_trial # ARGS" does the same thing as running "trial ARGS" (unfortunately # twisted/scripts/trial.py does not have a '__name__=="__main__"' clause). diff --git a/src/wormhole/test/test_args.py b/src/wormhole/test/test_args.py index ff6f060..87f9189 100644 --- a/src/wormhole/test/test_args.py +++ b/src/wormhole/test/test_args.py @@ -1,15 +1,17 @@ import os import sys -import mock + from twisted.trial import unittest + +import mock + from ..cli.public_relay import RENDEZVOUS_RELAY, TRANSIT_RELAY from .common import config -#from pprint import pprint + class Send(unittest.TestCase): def test_baseline(self): cfg = config("send", "--text", "hi") - #pprint(cfg.__dict__) self.assertEqual(cfg.what, None) self.assertEqual(cfg.code, None) self.assertEqual(cfg.code_length, 2) @@ -32,7 +34,6 @@ class Send(unittest.TestCase): def test_file(self): cfg = config("send", "fn") - #pprint(cfg.__dict__) self.assertEqual(cfg.what, u"fn") self.assertEqual(cfg.text, None) @@ -101,7 +102,6 @@ class Send(unittest.TestCase): class Receive(unittest.TestCase): def test_baseline(self): cfg = config("receive") - #pprint(cfg.__dict__) self.assertEqual(cfg.accept_file, False) self.assertEqual(cfg.code, None) self.assertEqual(cfg.code_length, 2) @@ -191,10 +191,12 @@ class Receive(unittest.TestCase): cfg = config("--transit-helper", transit_url_2, "receive") self.assertEqual(cfg.transit_helper, transit_url_2) + class Config(unittest.TestCase): def test_send(self): cfg = config("send") self.assertEqual(cfg.stdout, sys.stdout) + def test_receive(self): cfg = config("receive") self.assertEqual(cfg.stdout, sys.stdout) diff --git a/src/wormhole/test/test_cli.py b/src/wormhole/test/test_cli.py index 5cc8b60..a5d5478 100644 --- a/src/wormhole/test/test_cli.py +++ b/src/wormhole/test/test_cli.py @@ -1,22 +1,32 @@ from __future__ import print_function -import os, sys, re, io, zipfile, six, stat -from textwrap import fill, dedent -from humanize import naturalsize -import mock + +import io +import os +import re +import stat +import sys +import zipfile +from textwrap import dedent, fill + +import six from click.testing import CliRunner -from zope.interface import implementer -from twisted.trial import unittest -from twisted.python import procutils, log +from humanize import naturalsize from twisted.internet import endpoints, reactor -from twisted.internet.utils import getProcessOutputAndValue from twisted.internet.defer import gatherResults, inlineCallbacks, returnValue from twisted.internet.error import ConnectionRefusedError +from twisted.internet.utils import getProcessOutputAndValue +from twisted.python import log, procutils +from twisted.trial import unittest +from zope.interface import implementer + +import mock + from .. import __version__ -from .common import ServerBase, config -from ..cli import cmd_send, cmd_receive, welcome, cli -from ..errors import (TransferError, WrongPasswordError, WelcomeError, - UnsendableFileError, ServerConnectionError) from .._interfaces import ITorManager +from ..cli import cli, cmd_receive, cmd_send, welcome +from ..errors import (ServerConnectionError, TransferError, + UnsendableFileError, WelcomeError, WrongPasswordError) +from .common import ServerBase, config def build_offer(args): @@ -108,8 +118,8 @@ class OfferData(unittest.TestCase): self.cfg.cwd = send_dir e = self.assertRaises(TransferError, build_offer, self.cfg) - self.assertEqual(str(e), - "Cannot send: no file/directory named '%s'" % filename) + self.assertEqual( + str(e), "Cannot send: no file/directory named '%s'" % filename) def _do_test_directory(self, addslash): parent_dir = self.mktemp() @@ -178,8 +188,8 @@ class OfferData(unittest.TestCase): self.assertFalse(os.path.isdir(abs_filename)) e = self.assertRaises(TypeError, build_offer, self.cfg) - self.assertEqual(str(e), - "'%s' is neither file nor directory" % filename) + self.assertEqual( + str(e), "'%s' is neither file nor directory" % filename) def test_symlink(self): if not hasattr(os, 'symlink'): @@ -213,8 +223,9 @@ class OfferData(unittest.TestCase): os.mkdir(os.path.join(parent_dir, "B2", "C2")) with open(os.path.join(parent_dir, "B2", "D.txt"), "wb") as f: f.write(b"success") - os.symlink(os.path.abspath(os.path.join(parent_dir, "B2", "C2")), - os.path.join(parent_dir, "B1", "C1")) + os.symlink( + os.path.abspath(os.path.join(parent_dir, "B2", "C2")), + os.path.join(parent_dir, "B1", "C1")) # Now send "B1/C1/../D.txt" from A. The correct traversal will be: # * start: A # * B1: A/B1 @@ -231,6 +242,7 @@ class OfferData(unittest.TestCase): d, fd_to_send = build_offer(self.cfg) self.assertEqual(d["file"]["filename"], "D.txt") self.assertEqual(fd_to_send.read(), b"success") + if os.name == "nt": test_symlink_collapse.todo = "host OS has broken os.path.realpath()" # ntpath.py's realpath() is built out of normpath(), and does not @@ -241,6 +253,7 @@ class OfferData(unittest.TestCase): # misbehavior (albeit in rare circumstances), 2: it probably used to # work (sometimes, but not in #251). See cmd_send.py for more notes. + class LocaleFinder: def __init__(self): self._run_once = False @@ -266,10 +279,10 @@ class LocaleFinder: # twisted.python.usage to avoid this problem in the future. (out, err, rc) = yield getProcessOutputAndValue("locale", ["-a"]) if rc != 0: - log.msg("error running 'locale -a', rc=%s" % (rc,)) - log.msg("stderr: %s" % (err,)) + log.msg("error running 'locale -a', rc=%s" % (rc, )) + log.msg("stderr: %s" % (err, )) returnValue(None) - out = out.decode("utf-8") # make sure we get a string + out = out.decode("utf-8") # make sure we get a string utf8_locales = {} for locale in out.splitlines(): locale = locale.strip() @@ -281,8 +294,11 @@ class LocaleFinder: if utf8_locales: returnValue(list(utf8_locales.values())[0]) returnValue(None) + + locale_finder = LocaleFinder() + class ScriptsBase: def find_executable(self): # to make sure we're running the right executable (in a virtualenv), @@ -292,12 +308,13 @@ class ScriptsBase: if not locations: raise unittest.SkipTest("unable to find 'wormhole' in $PATH") wormhole = locations[0] - if (os.path.dirname(os.path.abspath(wormhole)) != - os.path.dirname(sys.executable)): - log.msg("locations: %s" % (locations,)) - log.msg("sys.executable: %s" % (sys.executable,)) - raise unittest.SkipTest("found the wrong 'wormhole' in $PATH: %s %s" - % (wormhole, sys.executable)) + if (os.path.dirname(os.path.abspath(wormhole)) != os.path.dirname( + sys.executable)): + log.msg("locations: %s" % (locations, )) + log.msg("sys.executable: %s" % (sys.executable, )) + raise unittest.SkipTest( + "found the wrong 'wormhole' in $PATH: %s %s" % + (wormhole, sys.executable)) return wormhole @inlineCallbacks @@ -323,8 +340,8 @@ class ScriptsBase: raise unittest.SkipTest("unable to find UTF-8 locale") locale_env = dict(LC_ALL=locale, LANG=locale) wormhole = self.find_executable() - res = yield getProcessOutputAndValue(wormhole, ["--version"], - env=locale_env) + res = yield getProcessOutputAndValue( + wormhole, ["--version"], env=locale_env) out, err, rc = res if rc != 0: log.msg("wormhole not runnable in this tree:") @@ -334,6 +351,7 @@ class ScriptsBase: raise unittest.SkipTest("wormhole is not runnable in this tree") returnValue(locale_env) + class ScriptVersion(ServerBase, ScriptsBase, unittest.TestCase): # we need Twisted to run the server, but we run the sender and receiver # with deferToThread() @@ -347,26 +365,30 @@ class ScriptVersion(ServerBase, ScriptsBase, unittest.TestCase): wormhole = self.find_executable() # we must pass on the environment so that "something" doesn't # get sad about UTF8 vs. ascii encodings - out, err, rc = yield getProcessOutputAndValue(wormhole, ["--version"], - env=os.environ) + out, err, rc = yield getProcessOutputAndValue( + wormhole, ["--version"], env=os.environ) err = err.decode("utf-8") if "DistributionNotFound" in err: log.msg("stderr was %s" % err) last = err.strip().split("\n")[-1] self.fail("wormhole not runnable: %s" % last) ver = out.decode("utf-8") or err - self.failUnlessEqual(ver.strip(), "magic-wormhole {}".format(__version__)) + self.failUnlessEqual(ver.strip(), + "magic-wormhole {}".format(__version__)) self.failUnlessEqual(rc, 0) + @implementer(ITorManager) class FakeTor: # use normal endpoints, but record the fact that we were asked def __init__(self): self.endpoints = [] + def stream_via(self, host, port): self.endpoints.append((host, port)) return endpoints.HostnameEndpoint(reactor, host, port) + class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase): # we need Twisted to run the server, but we run the sender and receiver # with deferToThread() @@ -377,11 +399,16 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase): yield ServerBase.setUp(self) @inlineCallbacks - def _do_test(self, as_subprocess=False, - mode="text", addslash=False, override_filename=False, - fake_tor=False, overwrite=False, mock_accept=False): - assert mode in ("text", "file", "empty-file", "directory", - "slow-text", "slow-sender-text") + def _do_test(self, + as_subprocess=False, + mode="text", + addslash=False, + override_filename=False, + fake_tor=False, + overwrite=False, + mock_accept=False): + assert mode in ("text", "file", "empty-file", "directory", "slow-text", + "slow-sender-text") if fake_tor: assert not as_subprocess send_cfg = config("send") @@ -408,7 +435,7 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase): elif mode in ("file", "empty-file"): if mode == "empty-file": message = "" - send_filename = u"testfil\u00EB" # e-with-diaeresis + send_filename = u"testfil\u00EB" # e-with-diaeresis with open(os.path.join(send_dir, send_filename), "w") as f: f.write(message) send_cfg.what = send_filename @@ -433,8 +460,10 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase): # expect: $receive_dir/$dirname/[12345] send_dirname = u"testdir" + def message(i): return "test message %d\n" % i + os.mkdir(os.path.join(send_dir, u"middle")) source_dir = os.path.join(send_dir, u"middle", send_dirname) os.mkdir(source_dir) @@ -476,21 +505,27 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase): env["_MAGIC_WORMHOLE_TEST_KEY_TIMER"] = "999999" env["_MAGIC_WORMHOLE_TEST_VERIFY_TIMER"] = "999999" send_args = [ - '--relay-url', self.relayurl, - '--transit-helper', '', - 'send', - '--hide-progress', - '--code', send_cfg.code, - ] + content_args + '--relay-url', + self.relayurl, + '--transit-helper', + '', + 'send', + '--hide-progress', + '--code', + send_cfg.code, + ] + content_args send_d = getProcessOutputAndValue( - wormhole_bin, send_args, + wormhole_bin, + send_args, path=send_dir, env=env, ) recv_args = [ - '--relay-url', self.relayurl, - '--transit-helper', '', + '--relay-url', + self.relayurl, + '--transit-helper', + '', 'receive', '--hide-progress', '--accept-file', @@ -500,7 +535,8 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase): recv_args.extend(['-o', receive_filename]) receive_d = getProcessOutputAndValue( - wormhole_bin, recv_args, + wormhole_bin, + recv_args, path=receive_dir, env=env, ) @@ -524,25 +560,27 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase): send_cfg.tor = True send_cfg.transit_helper = self.transit tx_tm = FakeTor() - with mock.patch("wormhole.tor_manager.get_tor", - return_value=tx_tm, - ) as mtx_tm: + with mock.patch( + "wormhole.tor_manager.get_tor", + return_value=tx_tm, + ) as mtx_tm: send_d = cmd_send.send(send_cfg) recv_cfg.tor = True recv_cfg.transit_helper = self.transit rx_tm = FakeTor() - with mock.patch("wormhole.tor_manager.get_tor", - return_value=rx_tm, - ) as mrx_tm: + with mock.patch( + "wormhole.tor_manager.get_tor", + return_value=rx_tm, + ) as mrx_tm: receive_d = cmd_receive.receive(recv_cfg) else: KEY_TIMER = 0 if mode == "slow-sender-text" else 99999 rxw = [] with mock.patch.object(cmd_receive, "KEY_TIMER", KEY_TIMER): send_d = cmd_send.send(send_cfg) - receive_d = cmd_receive.receive(recv_cfg, - _debug_stash_wormhole=rxw) + receive_d = cmd_receive.receive( + recv_cfg, _debug_stash_wormhole=rxw) # we need to keep KEY_TIMER patched until the receiver # gets far enough to start the timer, which happens after # the code is set @@ -555,8 +593,9 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase): with mock.patch.object(cmd_receive, "VERIFY_TIMER", VERIFY_TIMER): with mock.patch.object(cmd_send, "VERIFY_TIMER", VERIFY_TIMER): if mock_accept: - with mock.patch.object(cmd_receive.six.moves, - 'input', return_value='y'): + with mock.patch.object( + cmd_receive.six.moves, 'input', + return_value='y'): yield gatherResults([send_d, receive_d], True) else: yield gatherResults([send_d, receive_d], True) @@ -567,14 +606,14 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase): expected_endpoints.append(("127.0.0.1", self.transitport)) tx_timing = mtx_tm.call_args[1]["timing"] self.assertEqual(tx_tm.endpoints, expected_endpoints) - self.assertEqual(mtx_tm.mock_calls, - [mock.call(reactor, False, None, - timing=tx_timing)]) + self.assertEqual( + mtx_tm.mock_calls, + [mock.call(reactor, False, None, timing=tx_timing)]) rx_timing = mrx_tm.call_args[1]["timing"] self.assertEqual(rx_tm.endpoints, expected_endpoints) - self.assertEqual(mrx_tm.mock_calls, - [mock.call(reactor, False, None, - timing=rx_timing)]) + self.assertEqual( + mrx_tm.mock_calls, + [mock.call(reactor, False, None, timing=rx_timing)]) send_stdout = send_cfg.stdout.getvalue() send_stderr = send_cfg.stderr.getvalue() @@ -585,7 +624,7 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase): # newlines, even if we're on windows NL = "\n" - self.maxDiff = None # show full output for assertion failures + self.maxDiff = None # show full output for assertion failures key_established = "" if mode == "slow-text": @@ -600,38 +639,41 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase): "On the other computer, please run:{NL}{NL}" "wormhole receive {code}{NL}{NL}" "{KE}" - "text message sent{NL}").format(bytes=len(message), - code=send_cfg.code, - NL=NL, - KE=key_established) + "text message sent{NL}").format( + bytes=len(message), + code=send_cfg.code, + NL=NL, + KE=key_established) self.failUnlessEqual(send_stderr, expected) elif mode == "file": self.failUnlessIn(u"Sending {size:s} file named '{name}'{NL}" - .format(size=naturalsize(len(message)), - name=send_filename, - NL=NL), send_stderr) + .format( + size=naturalsize(len(message)), + name=send_filename, + NL=NL), send_stderr) self.failUnlessIn(u"Wormhole code is: {code}{NL}" "On the other computer, please run:{NL}{NL}" - "wormhole receive {code}{NL}{NL}" - .format(code=send_cfg.code, NL=NL), - send_stderr) - self.failUnlessIn(u"File sent.. waiting for confirmation{NL}" - "Confirmation received. Transfer complete.{NL}" - .format(NL=NL), send_stderr) + "wormhole receive {code}{NL}{NL}".format( + code=send_cfg.code, NL=NL), send_stderr) + self.failUnlessIn( + u"File sent.. waiting for confirmation{NL}" + "Confirmation received. Transfer complete.{NL}".format(NL=NL), + send_stderr) elif mode == "directory": self.failUnlessIn(u"Sending directory", send_stderr) self.failUnlessIn(u"named 'testdir'", send_stderr) self.failUnlessIn(u"Wormhole code is: {code}{NL}" "On the other computer, please run:{NL}{NL}" - "wormhole receive {code}{NL}{NL}" - .format(code=send_cfg.code, NL=NL), send_stderr) - self.failUnlessIn(u"File sent.. waiting for confirmation{NL}" - "Confirmation received. Transfer complete.{NL}" - .format(NL=NL), send_stderr) + "wormhole receive {code}{NL}{NL}".format( + code=send_cfg.code, NL=NL), send_stderr) + self.failUnlessIn( + u"File sent.. waiting for confirmation{NL}" + "Confirmation received. Transfer complete.{NL}".format(NL=NL), + send_stderr) # check receiver if mode in ("text", "slow-text", "slow-sender-text"): - self.assertEqual(receive_stdout, message+NL) + self.assertEqual(receive_stdout, message + NL) if mode == "text": self.assertEqual(receive_stderr, "") elif mode == "slow-text": @@ -640,9 +682,9 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase): self.assertEqual(receive_stderr, "Waiting for sender...\n") elif mode == "file": self.failUnlessEqual(receive_stdout, "") - self.failUnlessIn(u"Receiving file ({size:s}) into: {name}" - .format(size=naturalsize(len(message)), - name=receive_filename), receive_stderr) + self.failUnlessIn(u"Receiving file ({size:s}) into: {name}".format( + size=naturalsize(len(message)), name=receive_filename), + receive_stderr) self.failUnlessIn(u"Received file written to ", receive_stderr) fn = os.path.join(receive_dir, receive_filename) self.failUnless(os.path.exists(fn)) @@ -652,52 +694,67 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase): self.failUnlessEqual(receive_stdout, "") want = (r"Receiving directory \(\d+ \w+\) into: {name}/" .format(name=receive_dirname)) - self.failUnless(re.search(want, receive_stderr), - (want, receive_stderr)) - self.failUnlessIn(u"Received files written to {name}" - .format(name=receive_dirname), receive_stderr) + self.failUnless( + re.search(want, receive_stderr), (want, receive_stderr)) + self.failUnlessIn( + u"Received files written to {name}" + .format(name=receive_dirname), + receive_stderr) fn = os.path.join(receive_dir, receive_dirname) self.failUnless(os.path.exists(fn), fn) for i in range(5): fn = os.path.join(receive_dir, receive_dirname, str(i)) with open(fn, "r") as f: self.failUnlessEqual(f.read(), message(i)) - self.failUnlessEqual(modes[i], - stat.S_IMODE(os.stat(fn).st_mode)) + self.failUnlessEqual(modes[i], stat.S_IMODE( + os.stat(fn).st_mode)) def test_text(self): return self._do_test() + def test_text_subprocess(self): return self._do_test(as_subprocess=True) + def test_text_tor(self): return self._do_test(fake_tor=True) def test_file(self): return self._do_test(mode="file") + def test_file_override(self): return self._do_test(mode="file", override_filename=True) + def test_file_overwrite(self): return self._do_test(mode="file", overwrite=True) + def test_file_overwrite_mock_accept(self): return self._do_test(mode="file", overwrite=True, mock_accept=True) + def test_file_tor(self): return self._do_test(mode="file", fake_tor=True) + def test_empty_file(self): return self._do_test(mode="empty-file") def test_directory(self): return self._do_test(mode="directory") + def test_directory_addslash(self): return self._do_test(mode="directory", addslash=True) + def test_directory_override(self): return self._do_test(mode="directory", override_filename=True) + def test_directory_overwrite(self): return self._do_test(mode="directory", overwrite=True) + def test_directory_overwrite_mock_accept(self): - return self._do_test(mode="directory", overwrite=True, mock_accept=True) + return self._do_test( + mode="directory", overwrite=True, mock_accept=True) def test_slow_text(self): return self._do_test(mode="slow-text") + def test_slow_sender_text(self): return self._do_test(mode="slow-sender-text") @@ -721,7 +778,7 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase): os.mkdir(send_dir) receive_dir = self.mktemp() os.mkdir(receive_dir) - recv_cfg.accept_file = True # don't ask for permission + recv_cfg.accept_file = True # don't ask for permission if mode == "file": message = "test message\n" @@ -765,10 +822,12 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase): free_space = 10000000 else: free_space = 0 - with mock.patch("wormhole.cli.cmd_receive.estimate_free_space", - return_value=free_space): + with mock.patch( + "wormhole.cli.cmd_receive.estimate_free_space", + return_value=free_space): f = yield self.assertFailure(send_d, TransferError) - self.assertEqual(str(f), "remote error, transfer abandoned: transfer rejected") + self.assertEqual( + str(f), "remote error, transfer abandoned: transfer rejected") f = yield self.assertFailure(receive_d, TransferError) self.assertEqual(str(f), "transfer rejected") @@ -781,7 +840,7 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase): # newlines, even if we're on windows NL = "\n" - self.maxDiff = None # show full output for assertion failures + self.maxDiff = None # show full output for assertion failures self.assertEqual(send_stdout, "") self.assertEqual(receive_stdout, "") @@ -789,54 +848,63 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase): # check sender if mode == "file": self.failUnlessIn("Sending {size:s} file named '{name}'{NL}" - .format(size=naturalsize(size), - name=send_filename, - NL=NL), send_stderr) + .format( + size=naturalsize(size), + name=send_filename, + NL=NL), send_stderr) self.failUnlessIn("Wormhole code is: {code}{NL}" "On the other computer, please run:{NL}{NL}" - "wormhole receive {code}{NL}" - .format(code=send_cfg.code, NL=NL), - send_stderr) - self.failIfIn("File sent.. waiting for confirmation{NL}" - "Confirmation received. Transfer complete.{NL}" - .format(NL=NL), send_stderr) + "wormhole receive {code}{NL}".format( + code=send_cfg.code, NL=NL), send_stderr) + self.failIfIn( + "File sent.. waiting for confirmation{NL}" + "Confirmation received. Transfer complete.{NL}".format(NL=NL), + send_stderr) elif mode == "directory": self.failUnlessIn("Sending directory", send_stderr) self.failUnlessIn("named 'testdir'", send_stderr) self.failUnlessIn("Wormhole code is: {code}{NL}" "On the other computer, please run:{NL}{NL}" - "wormhole receive {code}{NL}" - .format(code=send_cfg.code, NL=NL), send_stderr) - self.failIfIn("File sent.. waiting for confirmation{NL}" - "Confirmation received. Transfer complete.{NL}" - .format(NL=NL), send_stderr) + "wormhole receive {code}{NL}".format( + code=send_cfg.code, NL=NL), send_stderr) + self.failIfIn( + "File sent.. waiting for confirmation{NL}" + "Confirmation received. Transfer complete.{NL}".format(NL=NL), + send_stderr) # check receiver if mode == "file": self.failIfIn("Received file written to ", receive_stderr) if failmode == "noclobber": - self.failUnlessIn("Error: " - "refusing to overwrite existing 'testfile'{NL}" - .format(NL=NL), receive_stderr) + self.failUnlessIn( + "Error: " + "refusing to overwrite existing 'testfile'{NL}" + .format(NL=NL), + receive_stderr) else: - self.failUnlessIn("Error: " - "insufficient free space (0B) for file ({size:d}B){NL}" - .format(NL=NL, size=size), receive_stderr) + self.failUnlessIn( + "Error: " + "insufficient free space (0B) for file ({size:d}B){NL}" + .format(NL=NL, size=size), receive_stderr) elif mode == "directory": - self.failIfIn("Received files written to {name}" - .format(name=receive_name), receive_stderr) - #want = (r"Receiving directory \(\d+ \w+\) into: {name}/" + self.failIfIn( + "Received files written to {name}".format(name=receive_name), + receive_stderr) + # want = (r"Receiving directory \(\d+ \w+\) into: {name}/" # .format(name=receive_name)) - #self.failUnless(re.search(want, receive_stderr), + # self.failUnless(re.search(want, receive_stderr), # (want, receive_stderr)) if failmode == "noclobber": - self.failUnlessIn("Error: " - "refusing to overwrite existing 'testdir'{NL}" - .format(NL=NL), receive_stderr) + self.failUnlessIn( + "Error: " + "refusing to overwrite existing 'testdir'{NL}" + .format(NL=NL), + receive_stderr) else: - self.failUnlessIn("Error: " - "insufficient free space (0B) for directory ({size:d}B){NL}" - .format(NL=NL, size=size), receive_stderr) + self.failUnlessIn(("Error: " + "insufficient free space (0B) for directory" + " ({size:d}B){NL}").format( + NL=NL, size=size), receive_stderr) if failmode == "noclobber": fn = os.path.join(receive_dir, receive_name) @@ -846,13 +914,17 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase): def test_fail_file_noclobber(self): return self._do_test_fail("file", "noclobber") + def test_fail_directory_noclobber(self): return self._do_test_fail("directory", "noclobber") + def test_fail_file_toobig(self): return self._do_test_fail("file", "toobig") + def test_fail_directory_toobig(self): return self._do_test_fail("directory", "toobig") + class ZeroMode(ServerBase, unittest.TestCase): @inlineCallbacks def test_text(self): @@ -871,8 +943,8 @@ class ZeroMode(ServerBase, unittest.TestCase): send_cfg.text = message - #send_cfg.cwd = send_dir - #recv_cfg.cwd = receive_dir + # send_cfg.cwd = send_dir + # recv_cfg.cwd = receive_dir send_d = cmd_send.send(send_cfg) receive_d = cmd_receive.receive(recv_cfg) @@ -888,7 +960,7 @@ class ZeroMode(ServerBase, unittest.TestCase): # newlines, even if we're on windows NL = "\n" - self.maxDiff = None # show full output for assertion failures + self.maxDiff = None # show full output for assertion failures self.assertEqual(send_stdout, "") @@ -898,15 +970,15 @@ class ZeroMode(ServerBase, unittest.TestCase): "{NL}" "wormhole receive -0{NL}" "{NL}" - "text message sent{NL}").format(bytes=len(message), - code=send_cfg.code, - NL=NL) + "text message sent{NL}").format( + bytes=len(message), code=send_cfg.code, NL=NL) self.failUnlessEqual(send_stderr, expected) # check receiver - self.assertEqual(receive_stdout, message+NL) + self.assertEqual(receive_stdout, message + NL) self.assertEqual(receive_stderr, "") + class NotWelcome(ServerBase, unittest.TestCase): @inlineCallbacks def setUp(self): @@ -936,6 +1008,7 @@ class NotWelcome(ServerBase, unittest.TestCase): f = yield self.assertFailure(receive_d, WelcomeError) self.assertEqual(str(f), "please upgrade XYZ") + class NoServer(ServerBase, unittest.TestCase): @inlineCallbacks def setUp(self): @@ -991,8 +1064,8 @@ class NoServer(ServerBase, unittest.TestCase): e = yield self.assertFailure(receive_d, ServerConnectionError) self.assertIsInstance(e.reason, ConnectionRefusedError) -class Cleanup(ServerBase, unittest.TestCase): +class Cleanup(ServerBase, unittest.TestCase): def make_config(self): cfg = config("send") # common options for all tests in this suite @@ -1040,6 +1113,7 @@ class Cleanup(ServerBase, unittest.TestCase): cids = self._rendezvous.get_app(cmd_send.APPID).get_nameplate_ids() self.assertEqual(len(cids), 0) + class ExtractFile(unittest.TestCase): def test_filenames(self): args = mock.Mock() @@ -1066,8 +1140,8 @@ class ExtractFile(unittest.TestCase): zf = mock.Mock() zi = mock.Mock() - zi.filename = "haha//root" # abspath squashes this, hopefully zipfile - # does too + zi.filename = "haha//root" # abspath squashes this, hopefully zipfile + # does too zi.external_attr = 5 << 16 expected = os.path.join(extract_dir, "haha", "root") with mock.patch.object(cmd_receive.os, "chmod") as chmod: @@ -1082,6 +1156,7 @@ class ExtractFile(unittest.TestCase): e = self.assertRaises(ValueError, ef, zf, zi, extract_dir) self.assertIn("malicious zipfile", str(e)) + class AppID(ServerBase, unittest.TestCase): @inlineCallbacks def setUp(self): @@ -1108,11 +1183,11 @@ class AppID(ServerBase, unittest.TestCase): yield receive_d used = self._usage_db.execute("SELECT DISTINCT `app_id`" - " FROM `nameplates`" - ).fetchall() + " FROM `nameplates`").fetchall() self.assertEqual(len(used), 1, used) self.assertEqual(used[0]["app_id"], u"appid2") + class Welcome(unittest.TestCase): def do(self, welcome_message, my_version="2.0"): stderr = io.StringIO() @@ -1129,27 +1204,33 @@ class Welcome(unittest.TestCase): def test_version_old(self): stderr = self.do({"current_cli_version": "3.0"}) - expected = ("Warning: errors may occur unless both sides are running the same version\n" + + expected = ("Warning: errors may occur unless both sides are" + " running the same version\n" "Server claims 3.0 is current, but ours is 2.0\n") self.assertEqual(stderr, expected) def test_version_unreleased(self): - stderr = self.do({"current_cli_version": "3.0"}, - my_version="2.5+middle.something") + stderr = self.do( + { + "current_cli_version": "3.0" + }, my_version="2.5+middle.something") self.assertEqual(stderr, "") def test_motd(self): stderr = self.do({"motd": "hello"}) self.assertEqual(stderr, "Server (at url) says:\n hello\n") + class Dispatch(unittest.TestCase): @inlineCallbacks def test_success(self): cfg = config("send") cfg.stderr = io.StringIO() called = [] + def fake(): called.append(1) + yield cli._dispatch_command(reactor, cfg, fake) self.assertEqual(called, [1]) self.assertEqual(cfg.stderr.getvalue(), "") @@ -1160,8 +1241,10 @@ class Dispatch(unittest.TestCase): cfg.stderr = io.StringIO() cfg.timing = mock.Mock() cfg.dump_timing = "filename" + def fake(): pass + yield cli._dispatch_command(reactor, cfg, fake) self.assertEqual(cfg.stderr.getvalue(), "") self.assertEqual(cfg.timing.mock_calls[-1], @@ -1171,32 +1254,39 @@ class Dispatch(unittest.TestCase): def test_wrong_password_error(self): cfg = config("send") cfg.stderr = io.StringIO() + def fake(): raise WrongPasswordError("abcd") - yield self.assertFailure(cli._dispatch_command(reactor, cfg, fake), - SystemExit) - expected = fill("ERROR: " + dedent(WrongPasswordError.__doc__))+"\n" + + yield self.assertFailure( + cli._dispatch_command(reactor, cfg, fake), SystemExit) + expected = fill("ERROR: " + dedent(WrongPasswordError.__doc__)) + "\n" self.assertEqual(cfg.stderr.getvalue(), expected) @inlineCallbacks def test_welcome_error(self): cfg = config("send") cfg.stderr = io.StringIO() + def fake(): raise WelcomeError("abcd") - yield self.assertFailure(cli._dispatch_command(reactor, cfg, fake), - SystemExit) - expected = fill("ERROR: " + dedent(WelcomeError.__doc__))+"\n\nabcd\n" + + yield self.assertFailure( + cli._dispatch_command(reactor, cfg, fake), SystemExit) + expected = ( + fill("ERROR: " + dedent(WelcomeError.__doc__)) + "\n\nabcd\n") self.assertEqual(cfg.stderr.getvalue(), expected) @inlineCallbacks def test_transfer_error(self): cfg = config("send") cfg.stderr = io.StringIO() + def fake(): raise TransferError("abcd") - yield self.assertFailure(cli._dispatch_command(reactor, cfg, fake), - SystemExit) + + yield self.assertFailure( + cli._dispatch_command(reactor, cfg, fake), SystemExit) expected = "TransferError: abcd\n" self.assertEqual(cfg.stderr.getvalue(), expected) @@ -1204,11 +1294,14 @@ class Dispatch(unittest.TestCase): def test_server_connection_error(self): cfg = config("send") cfg.stderr = io.StringIO() + def fake(): raise ServerConnectionError("URL", ValueError("abcd")) - yield self.assertFailure(cli._dispatch_command(reactor, cfg, fake), - SystemExit) - expected = fill("ERROR: " + dedent(ServerConnectionError.__doc__))+"\n" + + yield self.assertFailure( + cli._dispatch_command(reactor, cfg, fake), SystemExit) + expected = fill( + "ERROR: " + dedent(ServerConnectionError.__doc__)) + "\n" expected += "(relay URL was URL)\n" expected += "abcd\n" self.assertEqual(cfg.stderr.getvalue(), expected) @@ -1217,21 +1310,26 @@ class Dispatch(unittest.TestCase): def test_other_error(self): cfg = config("send") cfg.stderr = io.StringIO() + def fake(): raise ValueError("abcd") + # I'm seeing unicode problems with the Failure().printTraceback, and # the output would be kind of unpredictable anyways, so we'll mock it # out here. f = mock.Mock() + def mock_print(file): file.write(u"\n") + f.printTraceback = mock_print with mock.patch("wormhole.cli.cli.Failure", return_value=f): - yield self.assertFailure(cli._dispatch_command(reactor, cfg, fake), - SystemExit) + yield self.assertFailure( + cli._dispatch_command(reactor, cfg, fake), SystemExit) expected = "\nERROR: abcd\n" self.assertEqual(cfg.stderr.getvalue(), expected) + class Help(unittest.TestCase): def _check_top_level_help(self, got): # the main wormhole.cli.cli.wormhole docstring should be in the diff --git a/src/wormhole/test/test_eventual.py b/src/wormhole/test/test_eventual.py index 4813198..cb0c896 100644 --- a/src/wormhole/test/test_eventual.py +++ b/src/wormhole/test/test_eventual.py @@ -1,14 +1,19 @@ from __future__ import print_function, unicode_literals -import mock -from twisted.trial import unittest + from twisted.internet import reactor -from twisted.internet.task import Clock from twisted.internet.defer import Deferred, inlineCallbacks +from twisted.internet.task import Clock +from twisted.trial import unittest + +import mock + from ..eventual import EventualQueue + class IntentionalError(Exception): pass + class Eventual(unittest.TestCase, object): def test_eventually(self): c = Clock() @@ -23,9 +28,10 @@ class Eventual(unittest.TestCase, object): self.assertNoResult(d3) eq.flush_sync() - self.assertEqual(c1.mock_calls, - [mock.call("arg1", "arg2", kwarg1="kw1"), - mock.call("arg3", "arg4", kwarg5="kw5")]) + self.assertEqual(c1.mock_calls, [ + mock.call("arg1", "arg2", kwarg1="kw1"), + mock.call("arg3", "arg4", kwarg5="kw5") + ]) self.assertEqual(self.successResultOf(d2), None) self.assertEqual(self.successResultOf(d3), "value") @@ -47,11 +53,12 @@ class Eventual(unittest.TestCase, object): eq = EventualQueue(reactor) d1 = eq.fire_eventually() d2 = Deferred() + def _more(res): eq.eventually(d2.callback, None) + d1.addCallback(_more) yield eq.flush() # d1 will fire, which will queue d2 to fire, and the flush() ought to # wait for d2 too self.successResultOf(d2) - diff --git a/src/wormhole/test/test_hkdf.py b/src/wormhole/test/test_hkdf.py index 7751383..1628646 100644 --- a/src/wormhole/test/test_hkdf.py +++ b/src/wormhole/test/test_hkdf.py @@ -1,44 +1,47 @@ from __future__ import print_function, unicode_literals + import unittest -from binascii import unhexlify #, hexlify +from binascii import unhexlify # , hexlify + from hkdf import Hkdf -#def generate_KAT(): -# print("KAT = [") -# for salt in (b"", b"salt"): -# for context in (b"", b"context"): -# skm = b"secret" -# out = HKDF(skm, 64, XTS=salt, CTXinfo=context) -# hexout = " '%s' +\n '%s'" % (hexlify(out[:32]), -# hexlify(out[32:])) -# print(" (%r, %r, %r,\n%s)," % (salt, context, skm, hexout)) -# print("]") +# def generate_KAT(): +# print("KAT = [") +# for salt in (b"", b"salt"): +# for context in (b"", b"context"): +# skm = b"secret" +# out = HKDF(skm, 64, XTS=salt, CTXinfo=context) +# hexout = " '%s' +\n '%s'" % (hexlify(out[:32]), +# hexlify(out[32:])) +# print(" (%r, %r, %r,\n%s)," % (salt, context, skm, hexout)) +# print("]") KAT = [ - ('', '', 'secret', - '2f34e5ff91ec85d53ca9b543683174d0cf550b60d5f52b24c97b386cfcf6cbbf' + - '9cfd42fd37e1e5a214d15f03058d7fee63dc28f564b7b9fe3da514f80daad4bf'), - ('', 'context', 'secret', - 'c24c303a1adfb4c3e2b092e6254ed481c41d8955ba8ec3f6a1473493a60c957b' + - '31b723018ca75557214d3d5c61c0c7a5315b103b21ff00cb03ebe023dc347a47'), - ('salt', '', 'secret', - 'f1156507c39b0e326159e778696253122de430899a8df2484040a85a5f95ceb1' + - 'dfca555d4cc603bdf7153ed1560de8cbc3234b27a6d2be8e8ca202d90649679a'), - ('salt', 'context', 'secret', - '61a4f201a867bcc12381ddb180d27074408d03ee9d5750855e5a12d967fa060f' + - '10336ead9370927eaabb0d60b259346ee5f57eb7ceba8c72f1ed3f2932b1bf19'), + ('', '', 'secret', + '2f34e5ff91ec85d53ca9b543683174d0cf550b60d5f52b24c97b386cfcf6cbbf' + + '9cfd42fd37e1e5a214d15f03058d7fee63dc28f564b7b9fe3da514f80daad4bf'), + ('', 'context', 'secret', + 'c24c303a1adfb4c3e2b092e6254ed481c41d8955ba8ec3f6a1473493a60c957b' + + '31b723018ca75557214d3d5c61c0c7a5315b103b21ff00cb03ebe023dc347a47'), + ('salt', '', 'secret', + 'f1156507c39b0e326159e778696253122de430899a8df2484040a85a5f95ceb1' + + 'dfca555d4cc603bdf7153ed1560de8cbc3234b27a6d2be8e8ca202d90649679a'), + ('salt', 'context', 'secret', + '61a4f201a867bcc12381ddb180d27074408d03ee9d5750855e5a12d967fa060f' + + '10336ead9370927eaabb0d60b259346ee5f57eb7ceba8c72f1ed3f2932b1bf19'), ] + class TestKAT(unittest.TestCase): # note: this uses SHA256 def test_kat(self): for (salt, context, skm, expected_hexout) in KAT: expected_out = unhexlify(expected_hexout) for outlen in range(0, len(expected_out)): - out = Hkdf(salt.encode("ascii"), - skm.encode("ascii")).expand(context.encode("ascii"), - outlen) + out = Hkdf(salt.encode("ascii"), skm.encode("ascii")).expand( + context.encode("ascii"), outlen) self.assertEqual(out, expected_out[:outlen]) -#if __name__ == '__main__': -# generate_KAT() + +# if __name__ == '__main__': +# generate_KAT() diff --git a/src/wormhole/test/test_ipaddrs.py b/src/wormhole/test/test_ipaddrs.py index 90abcce..9e2cec5 100644 --- a/src/wormhole/test/test_ipaddrs.py +++ b/src/wormhole/test/test_ipaddrs.py @@ -1,9 +1,13 @@ +import errno +import os +import re +import subprocess -import re, errno, subprocess, os from twisted.trial import unittest + from .. import ipaddrs -DOTTED_QUAD_RE=re.compile("^[0-9]+\.[0-9]+\.[0-9]+\.[0-9]+$") +DOTTED_QUAD_RE = re.compile("^[0-9]+\.[0-9]+\.[0-9]+\.[0-9]+$") MOCK_IPADDR_OUTPUT = """\ 1: lo: mtu 16436 qdisc noqueue state UNKNOWN \n\ @@ -11,12 +15,14 @@ MOCK_IPADDR_OUTPUT = """\ inet 127.0.0.1/8 scope host lo inet6 ::1/128 scope host \n\ valid_lft forever preferred_lft forever -2: eth1: mtu 1500 qdisc pfifo_fast state UP qlen 1000 +2: eth1: mtu 1500 qdisc pfifo_fast state UP \ +qlen 1000 link/ether d4:3d:7e:01:b4:3e brd ff:ff:ff:ff:ff:ff inet 192.168.0.6/24 brd 192.168.0.255 scope global eth1 inet6 fe80::d63d:7eff:fe01:b43e/64 scope link \n\ valid_lft forever preferred_lft forever -3: wlan0: mtu 1500 qdisc mq state UP qlen 1000 +3: wlan0: mtu 1500 qdisc mq state UP qlen\ + 1000 link/ether 90:f6:52:27:15:0a brd ff:ff:ff:ff:ff:ff inet 192.168.0.2/24 brd 192.168.0.255 scope global wlan0 inet6 fe80::92f6:52ff:fe27:150a/64 scope link \n\ @@ -58,7 +64,8 @@ MOCK_ROUTE_OUTPUT = """\ =========================================================================== Interface List 0x1 ........................... MS TCP Loopback interface -0x2 ...08 00 27 c3 80 ad ...... AMD PCNET Family PCI Ethernet Adapter - Packet Scheduler Miniport +0x2 ...08 00 27 c3 80 ad ...... AMD PCNET Family PCI Ethernet Adapter - \ +Packet Scheduler Miniport =========================================================================== =========================================================================== Active Routes: @@ -85,6 +92,7 @@ class FakeProcess: def __init__(self, output, err): self.output = output self.err = err + def communicate(self): return (self.output, self.err) @@ -94,16 +102,28 @@ class ListAddresses(unittest.TestCase): addresses = ipaddrs.find_addresses() self.failUnlessIn("127.0.0.1", addresses) self.failIfIn("0.0.0.0", addresses) + # David A.'s OpenSolaris box timed out on this test one time when it was at # 2s. - test_list.timeout=4 + test_list.timeout = 4 def _test_list_mock(self, command, output, expected): self.first = True - def call_Popen(args, bufsize=0, executable=None, stdin=None, stdout=None, stderr=None, - preexec_fn=None, close_fds=False, shell=False, cwd=None, env=None, - universal_newlines=False, startupinfo=None, creationflags=0): + def call_Popen(args, + bufsize=0, + executable=None, + stdin=None, + stdout=None, + stderr=None, + preexec_fn=None, + close_fds=False, + shell=False, + cwd=None, + env=None, + universal_newlines=False, + startupinfo=None, + creationflags=0): if self.first: self.first = False e = OSError("EINTR") @@ -115,11 +135,13 @@ class ListAddresses(unittest.TestCase): e = OSError("[Errno 2] No such file or directory") e.errno = errno.ENOENT raise e + self.patch(subprocess, 'Popen', call_Popen) self.patch(os.path, 'isfile', lambda x: True) def call_which(name): return [name] + self.patch(ipaddrs, 'which', call_which) addresses = ipaddrs.find_addresses() @@ -131,11 +153,13 @@ class ListAddresses(unittest.TestCase): def test_list_mock_ifconfig(self): self.patch(ipaddrs, 'platform', "linux2") - self._test_list_mock("ifconfig", MOCK_IFCONFIG_OUTPUT, UNIX_TEST_ADDRESSES) + self._test_list_mock("ifconfig", MOCK_IFCONFIG_OUTPUT, + UNIX_TEST_ADDRESSES) def test_list_mock_route(self): self.patch(ipaddrs, 'platform', "win32") - self._test_list_mock("route.exe", MOCK_ROUTE_OUTPUT, WINDOWS_TEST_ADDRESSES) + self._test_list_mock("route.exe", MOCK_ROUTE_OUTPUT, + WINDOWS_TEST_ADDRESSES) def test_list_mock_cygwin(self): self.patch(ipaddrs, 'platform', "cygwin") diff --git a/src/wormhole/test/test_journal.py b/src/wormhole/test/test_journal.py index 96b9319..372da0f 100644 --- a/src/wormhole/test/test_journal.py +++ b/src/wormhole/test/test_journal.py @@ -1,8 +1,11 @@ -from __future__ import print_function, absolute_import, unicode_literals +from __future__ import absolute_import, print_function, unicode_literals + from twisted.trial import unittest + from .. import journal from .._interfaces import IJournal + class Journal(unittest.TestCase): def test_journal(self): events = [] diff --git a/src/wormhole/test/test_machines.py b/src/wormhole/test/test_machines.py index ec96fc0..38e351e 100644 --- a/src/wormhole/test/test_machines.py +++ b/src/wormhole/test/test_machines.py @@ -1,29 +1,36 @@ from __future__ import print_function, unicode_literals + import json -import mock -from zope.interface import directlyProvides, implementer + +from nacl.secret import SecretBox +from spake2 import SPAKE2_Symmetric from twisted.trial import unittest -from .. import (errors, timing, _order, _receive, _key, _code, _lister, _boss, - _input, _allocator, _send, _terminator, _nameplate, _mailbox, - _rendezvous, __version__) -from .._interfaces import (IKey, IReceive, IBoss, ISend, IMailbox, IOrder, - IRendezvousConnector, ILister, IInput, IAllocator, - INameplate, ICode, IWordlist, ITerminator) +from zope.interface import directlyProvides, implementer + +import mock + +from .. import (__version__, _allocator, _boss, _code, _input, _key, _lister, + _mailbox, _nameplate, _order, _receive, _rendezvous, _send, + _terminator, errors, timing) +from .._interfaces import (IAllocator, IBoss, ICode, IInput, IKey, ILister, + IMailbox, INameplate, IOrder, IReceive, + IRendezvousConnector, ISend, ITerminator, IWordlist) from .._key import derive_key, derive_phase_key, encrypt_data from ..journal import ImmediateJournal -from ..util import (dict_to_bytes, bytes_to_dict, - hexstr_to_bytes, bytes_to_hexstr, to_bytes) -from spake2 import SPAKE2_Symmetric -from nacl.secret import SecretBox +from ..util import (bytes_to_dict, bytes_to_hexstr, dict_to_bytes, + hexstr_to_bytes, to_bytes) + @implementer(IWordlist) class FakeWordList(object): def choose_words(self, length): return "-".join(["word"] * length) + def get_completions(self, prefix): self._get_completions_prefix = prefix return self._completions + class Dummy: def __init__(self, name, events, iface, *meths): self.name = name @@ -33,12 +40,15 @@ class Dummy: for meth in meths: self.mock(meth) self.retval = None + def mock(self, meth): def log(*args): - self.events.append(("%s.%s" % (self.name, meth),) + args) + self.events.append(("%s.%s" % (self.name, meth), ) + args) return self.retval + setattr(self, meth, log) + class Send(unittest.TestCase): def build(self): events = [] @@ -56,8 +66,10 @@ class Send(unittest.TestCase): with mock.patch("nacl.utils.random", side_effect=[nonce1]) as r: s.got_verified_key(key) self.assertEqual(r.mock_calls, [mock.call(SecretBox.NONCE_SIZE)]) - #print(bytes_to_hexstr(events[0][2])) - enc1 = hexstr_to_bytes("00000000000000000000000000000000000000000000000022f1a46c3c3496423c394621a2a5a8cf275b08") + # print(bytes_to_hexstr(events[0][2])) + enc1 = hexstr_to_bytes( + ("000000000000000000000000000000000000000000000000" + "22f1a46c3c3496423c394621a2a5a8cf275b08")) self.assertEqual(events, [("m.add_message", "phase1", enc1)]) events[:] = [] @@ -65,7 +77,9 @@ class Send(unittest.TestCase): with mock.patch("nacl.utils.random", side_effect=[nonce2]) as r: s.send("phase2", b"msg") self.assertEqual(r.mock_calls, [mock.call(SecretBox.NONCE_SIZE)]) - enc2 = hexstr_to_bytes("0202020202020202020202020202020202020202020202026660337c3eac6513c0dac9818b62ef16d9cd7e") + enc2 = hexstr_to_bytes( + ("0202020202020202020202020202020202020202" + "020202026660337c3eac6513c0dac9818b62ef16d9cd7e")) self.assertEqual(events, [("m.add_message", "phase2", enc2)]) def test_key_first(self): @@ -78,7 +92,8 @@ class Send(unittest.TestCase): with mock.patch("nacl.utils.random", side_effect=[nonce1]) as r: s.send("phase1", b"msg") self.assertEqual(r.mock_calls, [mock.call(SecretBox.NONCE_SIZE)]) - enc1 = hexstr_to_bytes("00000000000000000000000000000000000000000000000022f1a46c3c3496423c394621a2a5a8cf275b08") + enc1 = hexstr_to_bytes(("00000000000000000000000000000000000000000000" + "000022f1a46c3c3496423c394621a2a5a8cf275b08")) self.assertEqual(events, [("m.add_message", "phase1", enc1)]) events[:] = [] @@ -86,11 +101,12 @@ class Send(unittest.TestCase): with mock.patch("nacl.utils.random", side_effect=[nonce2]) as r: s.send("phase2", b"msg") self.assertEqual(r.mock_calls, [mock.call(SecretBox.NONCE_SIZE)]) - enc2 = hexstr_to_bytes("0202020202020202020202020202020202020202020202026660337c3eac6513c0dac9818b62ef16d9cd7e") + enc2 = hexstr_to_bytes( + ("0202020202020202020202020202020202020" + "202020202026660337c3eac6513c0dac9818b62ef16d9cd7e")) self.assertEqual(events, [("m.add_message", "phase2", enc2)]) - class Order(unittest.TestCase): def build(self): events = [] @@ -103,35 +119,36 @@ class Order(unittest.TestCase): def test_in_order(self): o, k, r, events = self.build() o.got_message(u"side", u"pake", b"body") - self.assertEqual(events, [("k.got_pake", b"body")]) # right away + self.assertEqual(events, [("k.got_pake", b"body")]) # right away o.got_message(u"side", u"version", b"body") o.got_message(u"side", u"1", b"body") - self.assertEqual(events, - [("k.got_pake", b"body"), - ("r.got_message", u"side", u"version", b"body"), - ("r.got_message", u"side", u"1", b"body"), - ]) + self.assertEqual(events, [ + ("k.got_pake", b"body"), + ("r.got_message", u"side", u"version", b"body"), + ("r.got_message", u"side", u"1", b"body"), + ]) def test_out_of_order(self): o, k, r, events = self.build() o.got_message(u"side", u"version", b"body") - self.assertEqual(events, []) # nothing yet + self.assertEqual(events, []) # nothing yet o.got_message(u"side", u"1", b"body") - self.assertEqual(events, []) # nothing yet + self.assertEqual(events, []) # nothing yet o.got_message(u"side", u"pake", b"body") # got_pake is delivered first - self.assertEqual(events, - [("k.got_pake", b"body"), - ("r.got_message", u"side", u"version", b"body"), - ("r.got_message", u"side", u"1", b"body"), - ]) + self.assertEqual(events, [ + ("k.got_pake", b"body"), + ("r.got_message", u"side", u"version", b"body"), + ("r.got_message", u"side", u"1", b"body"), + ]) + class Receive(unittest.TestCase): def build(self): events = [] r = _receive.Receive(u"side", timing.DebugTiming()) - b = Dummy("b", events, IBoss, - "happy", "scared", "got_verifier", "got_message") + b = Dummy("b", events, IBoss, "happy", "scared", "got_verifier", + "got_message") s = Dummy("s", events, ISend, "got_verified_key") r.wire(b, s) return r, b, s, events @@ -146,22 +163,24 @@ class Receive(unittest.TestCase): data1 = b"data1" good_body = encrypt_data(phase1_key, data1) r.got_message(u"side", u"phase1", good_body) - self.assertEqual(events, [("s.got_verified_key", key), - ("b.happy",), - ("b.got_verifier", verifier), - ("b.got_message", u"phase1", data1), - ]) + self.assertEqual(events, [ + ("s.got_verified_key", key), + ("b.happy", ), + ("b.got_verifier", verifier), + ("b.got_message", u"phase1", data1), + ]) phase2_key = derive_phase_key(key, u"side", u"phase2") data2 = b"data2" good_body = encrypt_data(phase2_key, data2) r.got_message(u"side", u"phase2", good_body) - self.assertEqual(events, [("s.got_verified_key", key), - ("b.happy",), - ("b.got_verifier", verifier), - ("b.got_message", u"phase1", data1), - ("b.got_message", u"phase2", data2), - ]) + self.assertEqual(events, [ + ("s.got_verified_key", key), + ("b.happy", ), + ("b.got_verifier", verifier), + ("b.got_message", u"phase1", data1), + ("b.got_message", u"phase2", data2), + ]) def test_early_bad(self): r, b, s, events = self.build() @@ -172,15 +191,17 @@ class Receive(unittest.TestCase): data1 = b"data1" bad_body = encrypt_data(phase1_key, data1) r.got_message(u"side", u"phase1", bad_body) - self.assertEqual(events, [("b.scared",), - ]) + self.assertEqual(events, [ + ("b.scared", ), + ]) phase2_key = derive_phase_key(key, u"side", u"phase2") data2 = b"data2" good_body = encrypt_data(phase2_key, data2) r.got_message(u"side", u"phase2", good_body) - self.assertEqual(events, [("b.scared",), - ]) + self.assertEqual(events, [ + ("b.scared", ), + ]) def test_late_bad(self): r, b, s, events = self.build() @@ -192,30 +213,34 @@ class Receive(unittest.TestCase): data1 = b"data1" good_body = encrypt_data(phase1_key, data1) r.got_message(u"side", u"phase1", good_body) - self.assertEqual(events, [("s.got_verified_key", key), - ("b.happy",), - ("b.got_verifier", verifier), - ("b.got_message", u"phase1", data1), - ]) + self.assertEqual(events, [ + ("s.got_verified_key", key), + ("b.happy", ), + ("b.got_verifier", verifier), + ("b.got_message", u"phase1", data1), + ]) phase2_key = derive_phase_key(key, u"side", u"bad") data2 = b"data2" bad_body = encrypt_data(phase2_key, data2) r.got_message(u"side", u"phase2", bad_body) - self.assertEqual(events, [("s.got_verified_key", key), - ("b.happy",), - ("b.got_verifier", verifier), - ("b.got_message", u"phase1", data1), - ("b.scared",), - ]) + self.assertEqual(events, [ + ("s.got_verified_key", key), + ("b.happy", ), + ("b.got_verifier", verifier), + ("b.got_message", u"phase1", data1), + ("b.scared", ), + ]) r.got_message(u"side", u"phase1", good_body) r.got_message(u"side", u"phase2", bad_body) - self.assertEqual(events, [("s.got_verified_key", key), - ("b.happy",), - ("b.got_verifier", verifier), - ("b.got_message", u"phase1", data1), - ("b.scared",), - ]) + self.assertEqual(events, [ + ("s.got_verified_key", key), + ("b.happy", ), + ("b.got_verifier", verifier), + ("b.got_message", u"phase1", data1), + ("b.scared", ), + ]) + class Key(unittest.TestCase): def test_derive_errors(self): @@ -260,11 +285,12 @@ class Key(unittest.TestCase): self.assertEqual(events[0][:2], ("m.add_message", "pake")) pake_1_json = events[0][2].decode("utf-8") pake_1 = json.loads(pake_1_json) - self.assertEqual(list(pake_1.keys()), ["pake_v1"]) # value is PAKE stuff + self.assertEqual(list(pake_1.keys()), + ["pake_v1"]) # value is PAKE stuff events[:] = [] bad_pake_d = {"not_pake_v1": "stuff"} k.got_pake(dict_to_bytes(bad_pake_d)) - self.assertEqual(events, [("b.scared",)]) + self.assertEqual(events, [("b.scared", )]) def test_reversed(self): # A receiver using input_code() will choose the nameplate first, then @@ -293,6 +319,7 @@ class Key(unittest.TestCase): self.assertEqual(events[2][:2], ("m.add_message", "version")) self.assertEqual(events[3], ("r.got_key", key2)) + class Code(unittest.TestCase): def build(self): events = [] @@ -308,10 +335,11 @@ class Code(unittest.TestCase): def test_set_code(self): c, b, a, n, k, i, events = self.build() c.set_code(u"1-code") - self.assertEqual(events, [("n.set_nameplate", u"1"), - ("b.got_code", u"1-code"), - ("k.got_code", u"1-code"), - ]) + self.assertEqual(events, [ + ("n.set_nameplate", u"1"), + ("b.got_code", u"1-code"), + ("k.got_code", u"1-code"), + ]) def test_set_code_invalid(self): c, b, a, n, k, i, events = self.build() @@ -323,15 +351,17 @@ class Code(unittest.TestCase): self.assertEqual(str(e.exception), "Code ' 1-code' contains spaces.") with self.assertRaises(errors.KeyFormatError) as e: c.set_code(u"code-code") - self.assertEqual(str(e.exception), - "Nameplate 'code' must be numeric, with no spaces.") + self.assertEqual( + str(e.exception), + "Nameplate 'code' must be numeric, with no spaces.") # it should still be possible to use the wormhole at this point c.set_code(u"1-code") - self.assertEqual(events, [("n.set_nameplate", u"1"), - ("b.got_code", u"1-code"), - ("k.got_code", u"1-code"), - ]) + self.assertEqual(events, [ + ("n.set_nameplate", u"1"), + ("b.got_code", u"1-code"), + ("k.got_code", u"1-code"), + ]) def test_allocate_code(self): c, b, a, n, k, i, events = self.build() @@ -340,30 +370,35 @@ class Code(unittest.TestCase): self.assertEqual(events, [("a.allocate", 2, wl)]) events[:] = [] c.allocated("1", "1-code") - self.assertEqual(events, [("n.set_nameplate", u"1"), - ("b.got_code", u"1-code"), - ("k.got_code", u"1-code"), - ]) + self.assertEqual(events, [ + ("n.set_nameplate", u"1"), + ("b.got_code", u"1-code"), + ("k.got_code", u"1-code"), + ]) def test_input_code(self): c, b, a, n, k, i, events = self.build() c.input_code() - self.assertEqual(events, [("i.start",)]) + self.assertEqual(events, [("i.start", )]) events[:] = [] c.got_nameplate("1") - self.assertEqual(events, [("n.set_nameplate", u"1"), - ]) + self.assertEqual(events, [ + ("n.set_nameplate", u"1"), + ]) events[:] = [] c.finished_input("1-code") - self.assertEqual(events, [("b.got_code", u"1-code"), - ("k.got_code", u"1-code"), - ]) + self.assertEqual(events, [ + ("b.got_code", u"1-code"), + ("k.got_code", u"1-code"), + ]) + class Input(unittest.TestCase): def build(self): events = [] i = _input.Input(timing.DebugTiming()) c = Dummy("c", events, ICode, "got_nameplate", "finished_input") + # renamed from l as l is indistinguishable from 1 in some fonts. l = Dummy("l", events, ILister, "refresh") i.wire(c, l) return i, c, l, events @@ -372,7 +407,7 @@ class Input(unittest.TestCase): i, c, l, events = self.build() helper = i.start() self.assertIsInstance(helper, _input.Helper) - self.assertEqual(events, [("l.refresh",)]) + self.assertEqual(events, [("l.refresh", )]) events[:] = [] with self.assertRaises(errors.MustChooseNameplateFirstError): helper.choose_words("word-word") @@ -390,7 +425,7 @@ class Input(unittest.TestCase): i, c, l, events = self.build() helper = i.start() self.assertIsInstance(helper, _input.Helper) - self.assertEqual(events, [("l.refresh",)]) + self.assertEqual(events, [("l.refresh", )]) events[:] = [] with self.assertRaises(errors.MustChooseNameplateFirstError): helper.choose_words("word-word") @@ -411,24 +446,24 @@ class Input(unittest.TestCase): i, c, l, events = self.build() helper = i.start() self.assertIsInstance(helper, _input.Helper) - self.assertEqual(events, [("l.refresh",)]) + self.assertEqual(events, [("l.refresh", )]) events[:] = [] d = helper.when_wordlist_is_available() self.assertNoResult(d) helper.refresh_nameplates() - self.assertEqual(events, [("l.refresh",)]) + self.assertEqual(events, [("l.refresh", )]) events[:] = [] with self.assertRaises(errors.MustChooseNameplateFirstError): helper.get_word_completions("prefix") i.got_nameplates({"1", "12", "34", "35", "367"}) self.assertNoResult(d) - self.assertEqual(helper.get_nameplate_completions(""), - {"1-", "12-", "34-", "35-", "367-"}) - self.assertEqual(helper.get_nameplate_completions("1"), - {"1-", "12-"}) + self.assertEqual( + helper.get_nameplate_completions(""), + {"1-", "12-", "34-", "35-", "367-"}) + self.assertEqual(helper.get_nameplate_completions("1"), {"1-", "12-"}) self.assertEqual(helper.get_nameplate_completions("2"), set()) - self.assertEqual(helper.get_nameplate_completions("3"), - {"34-", "35-", "367-"}) + self.assertEqual( + helper.get_nameplate_completions("3"), {"34-", "35-", "367-"}) helper.choose_nameplate("34") with self.assertRaises(errors.AlreadyChoseNameplateError): helper.refresh_nameplates() @@ -461,15 +496,14 @@ class Input(unittest.TestCase): self.assertEqual(events, [("c.finished_input", "34-word-word")]) - class Lister(unittest.TestCase): def build(self): events = [] - l = _lister.Lister(timing.DebugTiming()) + lister = _lister.Lister(timing.DebugTiming()) rc = Dummy("rc", events, IRendezvousConnector, "tx_list") i = Dummy("i", events, IInput, "got_nameplates") - l.wire(rc, i) - return l, rc, i, events + lister.wire(rc, i) + return lister, rc, i, events def test_connect_first(self): l, rc, i, events = self.build() @@ -478,12 +512,14 @@ class Lister(unittest.TestCase): l.connected() self.assertEqual(events, []) l.refresh() - self.assertEqual(events, [("rc.tx_list",), - ]) + self.assertEqual(events, [ + ("rc.tx_list", ), + ]) events[:] = [] l.rx_nameplates({"1", "2", "3"}) - self.assertEqual(events, [("i.got_nameplates", {"1", "2", "3"}), - ]) + self.assertEqual(events, [ + ("i.got_nameplates", {"1", "2", "3"}), + ]) events[:] = [] # now we're satisfied: disconnecting and reconnecting won't ask again l.lost() @@ -492,8 +528,9 @@ class Lister(unittest.TestCase): # but if we're told to refresh, we'll do so l.refresh() - self.assertEqual(events, [("rc.tx_list",), - ]) + self.assertEqual(events, [ + ("rc.tx_list", ), + ]) def test_connect_first_ask_twice(self): l, rc, i, events = self.build() @@ -501,44 +538,51 @@ class Lister(unittest.TestCase): self.assertEqual(events, []) l.refresh() l.refresh() - self.assertEqual(events, [("rc.tx_list",), - ("rc.tx_list",), - ]) + self.assertEqual(events, [ + ("rc.tx_list", ), + ("rc.tx_list", ), + ]) l.rx_nameplates({"1", "2", "3"}) - self.assertEqual(events, [("rc.tx_list",), - ("rc.tx_list",), - ("i.got_nameplates", {"1", "2", "3"}), - ]) - l.rx_nameplates({"1" ,"2", "3", "4"}) - self.assertEqual(events, [("rc.tx_list",), - ("rc.tx_list",), - ("i.got_nameplates", {"1", "2", "3"}), - ("i.got_nameplates", {"1", "2", "3", "4"}), - ]) + self.assertEqual(events, [ + ("rc.tx_list", ), + ("rc.tx_list", ), + ("i.got_nameplates", {"1", "2", "3"}), + ]) + l.rx_nameplates({"1", "2", "3", "4"}) + self.assertEqual(events, [ + ("rc.tx_list", ), + ("rc.tx_list", ), + ("i.got_nameplates", {"1", "2", "3"}), + ("i.got_nameplates", {"1", "2", "3", "4"}), + ]) def test_reconnect(self): l, rc, i, events = self.build() l.refresh() l.connected() - self.assertEqual(events, [("rc.tx_list",), - ]) + self.assertEqual(events, [ + ("rc.tx_list", ), + ]) events[:] = [] l.lost() l.connected() - self.assertEqual(events, [("rc.tx_list",), - ]) + self.assertEqual(events, [ + ("rc.tx_list", ), + ]) def test_refresh_first(self): l, rc, i, events = self.build() l.refresh() self.assertEqual(events, []) l.connected() - self.assertEqual(events, [("rc.tx_list",), - ]) + self.assertEqual(events, [ + ("rc.tx_list", ), + ]) l.rx_nameplates({"1", "2", "3"}) - self.assertEqual(events, [("rc.tx_list",), - ("i.got_nameplates", {"1", "2", "3"}), - ]) + self.assertEqual(events, [ + ("rc.tx_list", ), + ("i.got_nameplates", {"1", "2", "3"}), + ]) def test_unrefreshed(self): l, rc, i, events = self.build() @@ -547,8 +591,10 @@ class Lister(unittest.TestCase): l.connected() self.assertEqual(events, []) l.rx_nameplates({"1", "2", "3"}) - self.assertEqual(events, [("i.got_nameplates", {"1", "2", "3"}), - ]) + self.assertEqual(events, [ + ("i.got_nameplates", {"1", "2", "3"}), + ]) + class Allocator(unittest.TestCase): def build(self): @@ -569,32 +615,37 @@ class Allocator(unittest.TestCase): a.allocate(2, FakeWordList()) self.assertEqual(events, []) a.connected() - self.assertEqual(events, [("rc.tx_allocate",)]) + self.assertEqual(events, [("rc.tx_allocate", )]) events[:] = [] a.lost() a.connected() - self.assertEqual(events, [("rc.tx_allocate",), - ]) + self.assertEqual(events, [ + ("rc.tx_allocate", ), + ]) events[:] = [] a.rx_allocated("1") - self.assertEqual(events, [("c.allocated", "1", "1-word-word"), - ]) + self.assertEqual(events, [ + ("c.allocated", "1", "1-word-word"), + ]) def test_connect_first(self): a, rc, c, events = self.build() a.connected() self.assertEqual(events, []) a.allocate(2, FakeWordList()) - self.assertEqual(events, [("rc.tx_allocate",)]) + self.assertEqual(events, [("rc.tx_allocate", )]) events[:] = [] a.lost() a.connected() - self.assertEqual(events, [("rc.tx_allocate",), - ]) + self.assertEqual(events, [ + ("rc.tx_allocate", ), + ]) events[:] = [] a.rx_allocated("1") - self.assertEqual(events, [("c.allocated", "1", "1-word-word"), - ]) + self.assertEqual(events, [ + ("c.allocated", "1", "1-word-word"), + ]) + class Nameplate(unittest.TestCase): def build(self): @@ -602,7 +653,8 @@ class Nameplate(unittest.TestCase): n = _nameplate.Nameplate() m = Dummy("m", events, IMailbox, "got_mailbox") i = Dummy("i", events, IInput, "got_wordlist") - rc = Dummy("rc", events, IRendezvousConnector, "tx_claim", "tx_release") + rc = Dummy("rc", events, IRendezvousConnector, "tx_claim", + "tx_release") t = Dummy("t", events, ITerminator, "nameplate_done") n.wire(m, i, rc, t) return n, m, i, rc, t, events @@ -611,12 +663,14 @@ class Nameplate(unittest.TestCase): n, m, i, rc, t, events = self.build() with self.assertRaises(errors.KeyFormatError) as e: n.set_nameplate(" 1") - self.assertEqual(str(e.exception), - "Nameplate ' 1' must be numeric, with no spaces.") + self.assertEqual( + str(e.exception), + "Nameplate ' 1' must be numeric, with no spaces.") with self.assertRaises(errors.KeyFormatError) as e: n.set_nameplate("one") - self.assertEqual(str(e.exception), - "Nameplate 'one' must be numeric, with no spaces.") + self.assertEqual( + str(e.exception), + "Nameplate 'one' must be numeric, with no spaces.") # wormhole should still be usable n.set_nameplate("1") @@ -636,9 +690,10 @@ class Nameplate(unittest.TestCase): wl = object() with mock.patch("wormhole._nameplate.PGPWordList", return_value=wl): n.rx_claimed("mbox1") - self.assertEqual(events, [("i.got_wordlist", wl), - ("m.got_mailbox", "mbox1"), - ]) + self.assertEqual(events, [ + ("i.got_wordlist", wl), + ("m.got_mailbox", "mbox1"), + ]) events[:] = [] n.release() @@ -646,7 +701,7 @@ class Nameplate(unittest.TestCase): events[:] = [] n.rx_released() - self.assertEqual(events, [("t.nameplate_done",)]) + self.assertEqual(events, [("t.nameplate_done", )]) def test_connect_first(self): # connection remains up throughout @@ -661,9 +716,10 @@ class Nameplate(unittest.TestCase): wl = object() with mock.patch("wormhole._nameplate.PGPWordList", return_value=wl): n.rx_claimed("mbox1") - self.assertEqual(events, [("i.got_wordlist", wl), - ("m.got_mailbox", "mbox1"), - ]) + self.assertEqual(events, [ + ("i.got_wordlist", wl), + ("m.got_mailbox", "mbox1"), + ]) events[:] = [] n.release() @@ -671,7 +727,7 @@ class Nameplate(unittest.TestCase): events[:] = [] n.rx_released() - self.assertEqual(events, [("t.nameplate_done",)]) + self.assertEqual(events, [("t.nameplate_done", )]) def test_reconnect_while_claiming(self): # connection bounced while waiting for rx_claimed @@ -700,9 +756,10 @@ class Nameplate(unittest.TestCase): wl = object() with mock.patch("wormhole._nameplate.PGPWordList", return_value=wl): n.rx_claimed("mbox1") - self.assertEqual(events, [("i.got_wordlist", wl), - ("m.got_mailbox", "mbox1"), - ]) + self.assertEqual(events, [ + ("i.got_wordlist", wl), + ("m.got_mailbox", "mbox1"), + ]) events[:] = [] n.lost() @@ -722,9 +779,10 @@ class Nameplate(unittest.TestCase): wl = object() with mock.patch("wormhole._nameplate.PGPWordList", return_value=wl): n.rx_claimed("mbox1") - self.assertEqual(events, [("i.got_wordlist", wl), - ("m.got_mailbox", "mbox1"), - ]) + self.assertEqual(events, [ + ("i.got_wordlist", wl), + ("m.got_mailbox", "mbox1"), + ]) events[:] = [] n.release() @@ -748,9 +806,10 @@ class Nameplate(unittest.TestCase): wl = object() with mock.patch("wormhole._nameplate.PGPWordList", return_value=wl): n.rx_claimed("mbox1") - self.assertEqual(events, [("i.got_wordlist", wl), - ("m.got_mailbox", "mbox1"), - ]) + self.assertEqual(events, [ + ("i.got_wordlist", wl), + ("m.got_mailbox", "mbox1"), + ]) events[:] = [] n.release() @@ -758,7 +817,7 @@ class Nameplate(unittest.TestCase): events[:] = [] n.rx_released() - self.assertEqual(events, [("t.nameplate_done",)]) + self.assertEqual(events, [("t.nameplate_done", )]) events[:] = [] n.lost() @@ -768,20 +827,20 @@ class Nameplate(unittest.TestCase): def test_close_while_idle(self): n, m, i, rc, t, events = self.build() n.close() - self.assertEqual(events, [("t.nameplate_done",)]) + self.assertEqual(events, [("t.nameplate_done", )]) def test_close_while_idle_connected(self): n, m, i, rc, t, events = self.build() n.connected() self.assertEqual(events, []) n.close() - self.assertEqual(events, [("t.nameplate_done",)]) + self.assertEqual(events, [("t.nameplate_done", )]) def test_close_while_unclaimed(self): n, m, i, rc, t, events = self.build() n.set_nameplate("1") - n.close() # before ever being connected - self.assertEqual(events, [("t.nameplate_done",)]) + n.close() # before ever being connected + self.assertEqual(events, [("t.nameplate_done", )]) def test_close_while_claiming(self): n, m, i, rc, t, events = self.build() @@ -796,7 +855,7 @@ class Nameplate(unittest.TestCase): events[:] = [] n.rx_released() - self.assertEqual(events, [("t.nameplate_done",)]) + self.assertEqual(events, [("t.nameplate_done", )]) def test_close_while_claiming_but_disconnected(self): n, m, i, rc, t, events = self.build() @@ -815,7 +874,7 @@ class Nameplate(unittest.TestCase): events[:] = [] n.rx_released() - self.assertEqual(events, [("t.nameplate_done",)]) + self.assertEqual(events, [("t.nameplate_done", )]) def test_close_while_claimed(self): n, m, i, rc, t, events = self.build() @@ -828,9 +887,10 @@ class Nameplate(unittest.TestCase): wl = object() with mock.patch("wormhole._nameplate.PGPWordList", return_value=wl): n.rx_claimed("mbox1") - self.assertEqual(events, [("i.got_wordlist", wl), - ("m.got_mailbox", "mbox1"), - ]) + self.assertEqual(events, [ + ("i.got_wordlist", wl), + ("m.got_mailbox", "mbox1"), + ]) events[:] = [] n.close() @@ -839,7 +899,7 @@ class Nameplate(unittest.TestCase): events[:] = [] n.rx_released() - self.assertEqual(events, [("t.nameplate_done",)]) + self.assertEqual(events, [("t.nameplate_done", )]) def test_close_while_claimed_but_disconnected(self): n, m, i, rc, t, events = self.build() @@ -852,9 +912,10 @@ class Nameplate(unittest.TestCase): wl = object() with mock.patch("wormhole._nameplate.PGPWordList", return_value=wl): n.rx_claimed("mbox1") - self.assertEqual(events, [("i.got_wordlist", wl), - ("m.got_mailbox", "mbox1"), - ]) + self.assertEqual(events, [ + ("i.got_wordlist", wl), + ("m.got_mailbox", "mbox1"), + ]) events[:] = [] n.lost() @@ -865,7 +926,7 @@ class Nameplate(unittest.TestCase): events[:] = [] n.rx_released() - self.assertEqual(events, [("t.nameplate_done",)]) + self.assertEqual(events, [("t.nameplate_done", )]) def test_close_while_releasing(self): n, m, i, rc, t, events = self.build() @@ -878,19 +939,20 @@ class Nameplate(unittest.TestCase): wl = object() with mock.patch("wormhole._nameplate.PGPWordList", return_value=wl): n.rx_claimed("mbox1") - self.assertEqual(events, [("i.got_wordlist", wl), - ("m.got_mailbox", "mbox1"), - ]) + self.assertEqual(events, [ + ("i.got_wordlist", wl), + ("m.got_mailbox", "mbox1"), + ]) events[:] = [] n.release() self.assertEqual(events, [("rc.tx_release", "1")]) events[:] = [] - n.close() # ignored, we're already on our way out the door + n.close() # ignored, we're already on our way out the door self.assertEqual(events, []) n.rx_released() - self.assertEqual(events, [("t.nameplate_done",)]) + self.assertEqual(events, [("t.nameplate_done", )]) def test_close_while_releasing_but_disconnecteda(self): n, m, i, rc, t, events = self.build() @@ -903,9 +965,10 @@ class Nameplate(unittest.TestCase): wl = object() with mock.patch("wormhole._nameplate.PGPWordList", return_value=wl): n.rx_claimed("mbox1") - self.assertEqual(events, [("i.got_wordlist", wl), - ("m.got_mailbox", "mbox1"), - ]) + self.assertEqual(events, [ + ("i.got_wordlist", wl), + ("m.got_mailbox", "mbox1"), + ]) events[:] = [] n.release() @@ -922,7 +985,7 @@ class Nameplate(unittest.TestCase): events[:] = [] n.rx_released() - self.assertEqual(events, [("t.nameplate_done",)]) + self.assertEqual(events, [("t.nameplate_done", )]) def test_close_while_done(self): # connection remains up throughout @@ -937,9 +1000,10 @@ class Nameplate(unittest.TestCase): wl = object() with mock.patch("wormhole._nameplate.PGPWordList", return_value=wl): n.rx_claimed("mbox1") - self.assertEqual(events, [("i.got_wordlist", wl), - ("m.got_mailbox", "mbox1"), - ]) + self.assertEqual(events, [ + ("i.got_wordlist", wl), + ("m.got_mailbox", "mbox1"), + ]) events[:] = [] n.release() @@ -947,10 +1011,10 @@ class Nameplate(unittest.TestCase): events[:] = [] n.rx_released() - self.assertEqual(events, [("t.nameplate_done",)]) + self.assertEqual(events, [("t.nameplate_done", )]) events[:] = [] - n.close() # NOP + n.close() # NOP self.assertEqual(events, []) def test_close_while_done_but_disconnected(self): @@ -966,9 +1030,10 @@ class Nameplate(unittest.TestCase): wl = object() with mock.patch("wormhole._nameplate.PGPWordList", return_value=wl): n.rx_claimed("mbox1") - self.assertEqual(events, [("i.got_wordlist", wl), - ("m.got_mailbox", "mbox1"), - ]) + self.assertEqual(events, [ + ("i.got_wordlist", wl), + ("m.got_mailbox", "mbox1"), + ]) events[:] = [] n.release() @@ -976,20 +1041,21 @@ class Nameplate(unittest.TestCase): events[:] = [] n.rx_released() - self.assertEqual(events, [("t.nameplate_done",)]) + self.assertEqual(events, [("t.nameplate_done", )]) events[:] = [] n.lost() - n.close() # NOP + n.close() # NOP self.assertEqual(events, []) + class Mailbox(unittest.TestCase): def build(self): events = [] m = _mailbox.Mailbox("side1") n = Dummy("n", events, INameplate, "release") - rc = Dummy("rc", events, IRendezvousConnector, - "tx_add", "tx_open", "tx_close") + rc = Dummy("rc", events, IRendezvousConnector, "tx_add", "tx_open", + "tx_close") o = Dummy("o", events, IOrder, "got_message") t = Dummy("t", events, ITerminator, "mailbox_done") m.wire(n, rc, o, t) @@ -998,12 +1064,13 @@ class Mailbox(unittest.TestCase): # TODO: test moods def assert_events(self, events, initial_events, tx_add_events): - self.assertEqual(len(events), len(initial_events)+len(tx_add_events), - events) + self.assertEqual( + len(events), + len(initial_events) + len(tx_add_events), events) self.assertEqual(events[:len(initial_events)], initial_events) self.assertEqual(set(events[len(initial_events):]), tx_add_events) - def test_connect_first(self): # connect before got_mailbox + def test_connect_first(self): # connect before got_mailbox m, n, rc, o, t, events = self.build() m.add_message("phase1", b"msg1") self.assertEqual(events, []) @@ -1029,38 +1096,42 @@ class Mailbox(unittest.TestCase): m.connected() # the other messages are allowed to be sent in any order - self.assert_events(events, [("rc.tx_open", "mbox1")], - { ("rc.tx_add", "phase1", b"msg1"), - ("rc.tx_add", "phase2", b"msg2"), - ("rc.tx_add", "phase3", b"msg3"), - }) + self.assert_events( + events, [("rc.tx_open", "mbox1")], { + ("rc.tx_add", "phase1", b"msg1"), + ("rc.tx_add", "phase2", b"msg2"), + ("rc.tx_add", "phase3", b"msg3"), + }) events[:] = [] - m.rx_message("side1", "phase1", b"msg1") # echo of our message, dequeue + m.rx_message("side1", "phase1", + b"msg1") # echo of our message, dequeue self.assertEqual(events, []) m.lost() m.connected() - self.assert_events(events, [("rc.tx_open", "mbox1")], - {("rc.tx_add", "phase2", b"msg2"), - ("rc.tx_add", "phase3", b"msg3"), - }) + self.assert_events(events, [("rc.tx_open", "mbox1")], { + ("rc.tx_add", "phase2", b"msg2"), + ("rc.tx_add", "phase3", b"msg3"), + }) events[:] = [] # a new message from the peer gets delivered, and the Nameplate is # released since the message proves that our peer opened the Mailbox # and therefore no longer needs the Nameplate - m.rx_message("side2", "phase1", b"msg1them") # new message from peer - self.assertEqual(events, [("n.release",), - ("o.got_message", "side2", "phase1", b"msg1them"), - ]) + m.rx_message("side2", "phase1", b"msg1them") # new message from peer + self.assertEqual(events, [ + ("n.release", ), + ("o.got_message", "side2", "phase1", b"msg1them"), + ]) events[:] = [] # we de-duplicate peer messages, but still re-release the nameplate # since Nameplate is smart enough to ignore that m.rx_message("side2", "phase1", b"msg1them") - self.assertEqual(events, [("n.release",), - ]) + self.assertEqual(events, [ + ("n.release", ), + ]) events[:] = [] m.close("happy") @@ -1081,7 +1152,7 @@ class Mailbox(unittest.TestCase): events[:] = [] m.rx_closed() - self.assertEqual(events, [("t.mailbox_done",)]) + self.assertEqual(events, [("t.mailbox_done", )]) events[:] = [] # while closed, we ignore everything @@ -1092,7 +1163,7 @@ class Mailbox(unittest.TestCase): m.connected() self.assertEqual(events, []) - def test_mailbox_first(self): # got_mailbox before connect + def test_mailbox_first(self): # got_mailbox before connect m, n, rc, o, t, events = self.build() m.add_message("phase1", b"msg1") self.assertEqual(events, []) @@ -1103,27 +1174,27 @@ class Mailbox(unittest.TestCase): m.connected() - self.assert_events(events, [("rc.tx_open", "mbox1")], - { ("rc.tx_add", "phase1", b"msg1"), - ("rc.tx_add", "phase2", b"msg2"), - }) + self.assert_events(events, [("rc.tx_open", "mbox1")], { + ("rc.tx_add", "phase1", b"msg1"), + ("rc.tx_add", "phase2", b"msg2"), + }) def test_close_while_idle(self): m, n, rc, o, t, events = self.build() m.close("happy") - self.assertEqual(events, [("t.mailbox_done",)]) + self.assertEqual(events, [("t.mailbox_done", )]) def test_close_while_idle_but_connected(self): m, n, rc, o, t, events = self.build() m.connected() m.close("happy") - self.assertEqual(events, [("t.mailbox_done",)]) + self.assertEqual(events, [("t.mailbox_done", )]) def test_close_while_mailbox_disconnected(self): m, n, rc, o, t, events = self.build() m.got_mailbox("mbox1") m.close("happy") - self.assertEqual(events, [("t.mailbox_done",)]) + self.assertEqual(events, [("t.mailbox_done", )]) def test_close_while_reconnecting(self): m, n, rc, o, t, events = self.build() @@ -1143,9 +1214,10 @@ class Mailbox(unittest.TestCase): events[:] = [] m.rx_closed() - self.assertEqual(events, [("t.mailbox_done",)]) + self.assertEqual(events, [("t.mailbox_done", )]) events[:] = [] + class Terminator(unittest.TestCase): def build(self): events = [] @@ -1160,13 +1232,15 @@ class Terminator(unittest.TestCase): # there are three events, and we need to test all orderings of them def _do_test(self, ev1, ev2, ev3): t, b, rc, n, m, events = self.build() - input_events = {"mailbox": lambda: t.mailbox_done(), - "nameplate": lambda: t.nameplate_done(), - "close": lambda: t.close("happy"), - } - close_events = [("n.close",), - ("m.close", "happy"), - ] + input_events = { + "mailbox": lambda: t.mailbox_done(), + "nameplate": lambda: t.nameplate_done(), + "close": lambda: t.close("happy"), + } + close_events = [ + ("n.close", ), + ("m.close", "happy"), + ] input_events[ev1]() expected = [] @@ -1186,12 +1260,12 @@ class Terminator(unittest.TestCase): expected = [] if ev3 == "close": expected.extend(close_events) - expected.append(("rc.stop",)) + expected.append(("rc.stop", )) self.assertEqual(events, expected) events[:] = [] t.stopped() - self.assertEqual(events, [("b.closed",)]) + self.assertEqual(events, [("b.closed", )]) def test_terminate(self): self._do_test("mailbox", "nameplate", "close") @@ -1203,18 +1277,19 @@ class Terminator(unittest.TestCase): # TODO: test moods + class MockBoss(_boss.Boss): def __attrs_post_init__(self): - #self._build_workers() + # self._build_workers() self._init_other_state() + class Boss(unittest.TestCase): def build(self): events = [] - wormhole = Dummy("w", events, None, - "got_welcome", - "got_code", "got_key", "got_verifier", "got_versions", - "received", "closed") + wormhole = Dummy("w", events, None, "got_welcome", "got_code", + "got_key", "got_verifier", "got_versions", "received", + "closed") versions = {"app": "version1"} reactor = None journal = ImmediateJournal() @@ -1226,8 +1301,8 @@ class Boss(unittest.TestCase): b._T = Dummy("t", events, ITerminator, "close") b._S = Dummy("s", events, ISend, "send") b._RC = Dummy("rc", events, IRendezvousConnector, "start") - b._C = Dummy("c", events, ICode, - "allocate_code", "input_code", "set_code") + b._C = Dummy("c", events, ICode, "allocate_code", "input_code", + "set_code") return b, events def test_basic(self): @@ -1242,8 +1317,9 @@ class Boss(unittest.TestCase): welcome = {"howdy": "how are ya"} b.rx_welcome(welcome) - self.assertEqual(events, [("w.got_welcome", welcome), - ]) + self.assertEqual(events, [ + ("w.got_welcome", welcome), + ]) events[:] = [] # pretend a peer message was correctly decrypted @@ -1252,11 +1328,12 @@ class Boss(unittest.TestCase): b.got_verifier(b"verifier") b.got_message("version", b"{}") b.got_message("0", b"msg1") - self.assertEqual(events, [("w.got_key", b"key"), - ("w.got_verifier", b"verifier"), - ("w.got_versions", {}), - ("w.received", b"msg1"), - ]) + self.assertEqual(events, [ + ("w.got_key", b"key"), + ("w.got_verifier", b"verifier"), + ("w.got_versions", {}), + ("w.received", b"msg1"), + ]) events[:] = [] b.send(b"msg2") @@ -1330,7 +1407,7 @@ class Boss(unittest.TestCase): self.assertEqual(events, [("c.set_code", "1-code")]) events[:] = [] - b.close() # before even w.got_code + b.close() # before even w.got_code self.assertEqual(events, [("t.close", "lonely")]) events[:] = [] @@ -1383,9 +1460,9 @@ class Boss(unittest.TestCase): self.assertEqual(events, [("w.got_code", "1-code")]) events[:] = [] - b.happy() # phase=version + b.happy() # phase=version - b.scared() # phase=0 + b.scared() # phase=0 self.assertEqual(events, [("t.close", "scary")]) events[:] = [] @@ -1404,7 +1481,7 @@ class Boss(unittest.TestCase): self.assertEqual(events, [("w.got_code", "1-code")]) events[:] = [] - b.happy() # phase=version + b.happy() # phase=version b.got_message("unknown-phase", b"spooky") self.assertEqual(events, []) @@ -1429,7 +1506,7 @@ class Boss(unittest.TestCase): b, events = self.build() b._C.retval = "helper" helper = b.input_code() - self.assertEqual(events, [("c.input_code",)]) + self.assertEqual(events, [("c.input_code", )]) self.assertEqual(helper, "helper") with self.assertRaises(errors.OnlyOneCodeError): b.input_code() @@ -1451,11 +1528,9 @@ class Rendezvous(unittest.TestCase): journal = ImmediateJournal() tor_manager = None client_version = ("python", __version__) - rc = _rendezvous.RendezvousConnector("ws://host:4000/v1", "appid", - "side", reactor, - journal, tor_manager, - timing.DebugTiming(), - client_version) + rc = _rendezvous.RendezvousConnector( + "ws://host:4000/v1", "appid", "side", reactor, journal, + tor_manager, timing.DebugTiming(), client_version) b = Dummy("b", events, IBoss, "error") n = Dummy("n", events, INameplate, "connected", "lost") m = Dummy("m", events, IMailbox, "connected", "lost") @@ -1488,34 +1563,43 @@ class Rendezvous(unittest.TestCase): rc, events = self.build() ws = mock.Mock() + def notrandom(length): return b"\x00" * length + with mock.patch("os.urandom", notrandom): rc.ws_open(ws) - self.assertEqual(events, [("n.connected", ), - ("m.connected", ), - ("l.connected", ), - ("a.connected", ), - ]) + self.assertEqual(events, [ + ("n.connected", ), + ("m.connected", ), + ("l.connected", ), + ("a.connected", ), + ]) events[:] = [] + def sent_messages(ws): for c in ws.mock_calls: self.assertEqual(c[0], "sendMessage", ws.mock_calls) self.assertEqual(c[1][1], False, ws.mock_calls) yield bytes_to_dict(c[1][0]) - self.assertEqual(list(sent_messages(ws)), - [dict(appid="appid", side="side", - client_version=["python", __version__], - id="0000", type="bind"), - ]) + + self.assertEqual( + list(sent_messages(ws)), [ + dict( + appid="appid", + side="side", + client_version=["python", __version__], + id="0000", + type="bind"), + ]) rc.ws_close(True, None, None) - self.assertEqual(events, [("n.lost", ), - ("m.lost", ), - ("l.lost", ), - ("a.lost", ), - ]) - + self.assertEqual(events, [ + ("n.lost", ), + ("m.lost", ), + ("l.lost", ), + ("a.lost", ), + ]) # TODO diff --git a/src/wormhole/test/test_observer.py b/src/wormhole/test/test_observer.py index d974cc5..8d89ddd 100644 --- a/src/wormhole/test/test_observer.py +++ b/src/wormhole/test/test_observer.py @@ -1,9 +1,11 @@ -from twisted.trial import unittest from twisted.internet.task import Clock from twisted.python.failure import Failure +from twisted.trial import unittest + from ..eventual import EventualQueue from ..observer import OneShotObserver, SequenceObserver + class OneShot(unittest.TestCase): def test_fire(self): c = Clock() @@ -119,4 +121,3 @@ class Sequence(unittest.TestCase): d2 = o.when_next_event() eq.flush_sync() self.assertIdentical(self.failureResultOf(d2), f) - diff --git a/src/wormhole/test/test_rlcompleter.py b/src/wormhole/test/test_rlcompleter.py index 41f0abf..dd050d6 100644 --- a/src/wormhole/test/test_rlcompleter.py +++ b/src/wormhole/test/test_rlcompleter.py @@ -1,39 +1,44 @@ -from __future__ import print_function, absolute_import, unicode_literals -import mock +from __future__ import absolute_import, print_function, unicode_literals + from itertools import count -from twisted.trial import unittest + from twisted.internet import reactor from twisted.internet.defer import inlineCallbacks from twisted.internet.threads import deferToThread -from .._rlcompleter import (input_with_completion, - _input_code_with_completion, - CodeInputter, warn_readline) -from ..errors import KeyFormatError, AlreadyInputNameplateError +from twisted.trial import unittest + +import mock + +from .._rlcompleter import (CodeInputter, _input_code_with_completion, + input_with_completion, warn_readline) +from ..errors import AlreadyInputNameplateError, KeyFormatError + APPID = "appid" + class Input(unittest.TestCase): @inlineCallbacks def test_wrapper(self): helper = object() trueish = object() - with mock.patch("wormhole._rlcompleter._input_code_with_completion", - return_value=trueish) as m: - used_completion = yield input_with_completion("prompt:", helper, - reactor) + with mock.patch( + "wormhole._rlcompleter._input_code_with_completion", + return_value=trueish) as m: + used_completion = yield input_with_completion( + "prompt:", helper, reactor) self.assertIs(used_completion, trueish) - self.assertEqual(m.mock_calls, - [mock.call("prompt:", helper, reactor)]) + self.assertEqual(m.mock_calls, [mock.call("prompt:", helper, reactor)]) # note: if this test fails, the warn_readline() message will probably # get written to stderr + class Sync(unittest.TestCase): # exercise _input_code_with_completion, which uses the blocking builtin # "input()" function, hence _input_code_with_completion is usually in a # thread with deferToThread @mock.patch("wormhole._rlcompleter.CodeInputter") - @mock.patch("wormhole._rlcompleter.readline", - __doc__="I am GNU readline") + @mock.patch("wormhole._rlcompleter.readline", __doc__="I am GNU readline") @mock.patch("wormhole._rlcompleter.input", return_value="code") def test_readline(self, input, readline, ci): c = mock.Mock(name="inhibit parenting") @@ -49,17 +54,17 @@ class Sync(unittest.TestCase): self.assertEqual(ci.mock_calls, [mock.call(input_helper, reactor)]) self.assertEqual(c.mock_calls, [mock.call.finish("code")]) self.assertEqual(input.mock_calls, [mock.call(prompt)]) - self.assertEqual(readline.mock_calls, - [mock.call.parse_and_bind("tab: complete"), - mock.call.set_completer(c.completer), - mock.call.set_completer_delims(""), - ]) + self.assertEqual(readline.mock_calls, [ + mock.call.parse_and_bind("tab: complete"), + mock.call.set_completer(c.completer), + mock.call.set_completer_delims(""), + ]) @mock.patch("wormhole._rlcompleter.CodeInputter") @mock.patch("wormhole._rlcompleter.readline") @mock.patch("wormhole._rlcompleter.input", return_value="code") def test_readline_no_docstring(self, input, readline, ci): - del readline.__doc__ # when in doubt, it assumes GNU readline + del readline.__doc__ # when in doubt, it assumes GNU readline c = mock.Mock(name="inhibit parenting") c.completer = object() trueish = object() @@ -73,15 +78,14 @@ class Sync(unittest.TestCase): self.assertEqual(ci.mock_calls, [mock.call(input_helper, reactor)]) self.assertEqual(c.mock_calls, [mock.call.finish("code")]) self.assertEqual(input.mock_calls, [mock.call(prompt)]) - self.assertEqual(readline.mock_calls, - [mock.call.parse_and_bind("tab: complete"), - mock.call.set_completer(c.completer), - mock.call.set_completer_delims(""), - ]) + self.assertEqual(readline.mock_calls, [ + mock.call.parse_and_bind("tab: complete"), + mock.call.set_completer(c.completer), + mock.call.set_completer_delims(""), + ]) @mock.patch("wormhole._rlcompleter.CodeInputter") - @mock.patch("wormhole._rlcompleter.readline", - __doc__="I am libedit") + @mock.patch("wormhole._rlcompleter.readline", __doc__="I am libedit") @mock.patch("wormhole._rlcompleter.input", return_value="code") def test_libedit(self, input, readline, ci): c = mock.Mock(name="inhibit parenting") @@ -97,11 +101,11 @@ class Sync(unittest.TestCase): self.assertEqual(ci.mock_calls, [mock.call(input_helper, reactor)]) self.assertEqual(c.mock_calls, [mock.call.finish("code")]) self.assertEqual(input.mock_calls, [mock.call(prompt)]) - self.assertEqual(readline.mock_calls, - [mock.call.parse_and_bind("bind ^I rl_complete"), - mock.call.set_completer(c.completer), - mock.call.set_completer_delims(""), - ]) + self.assertEqual(readline.mock_calls, [ + mock.call.parse_and_bind("bind ^I rl_complete"), + mock.call.set_completer(c.completer), + mock.call.set_completer_delims(""), + ]) @mock.patch("wormhole._rlcompleter.CodeInputter") @mock.patch("wormhole._rlcompleter.readline", None) @@ -139,6 +143,7 @@ class Sync(unittest.TestCase): self.assertEqual(c.mock_calls, [mock.call.finish(u"code")]) self.assertEqual(input.mock_calls, [mock.call(prompt)]) + def get_completions(c, prefix): completions = [] for state in count(0): @@ -147,9 +152,11 @@ def get_completions(c, prefix): return completions completions.append(text) + def fake_blockingCallFromThread(f, *a, **kw): return f(*a, **kw) + class Completion(unittest.TestCase): def test_simple(self): # no actual completion @@ -158,12 +165,14 @@ class Completion(unittest.TestCase): c.bcft = fake_blockingCallFromThread c.finish("1-code-ghost") self.assertFalse(c.used_completion) - self.assertEqual(helper.mock_calls, - [mock.call.choose_nameplate("1"), - mock.call.choose_words("code-ghost")]) + self.assertEqual(helper.mock_calls, [ + mock.call.choose_nameplate("1"), + mock.call.choose_words("code-ghost") + ]) - @mock.patch("wormhole._rlcompleter.readline", - get_completion_type=mock.Mock(return_value=0)) + @mock.patch( + "wormhole._rlcompleter.readline", + get_completion_type=mock.Mock(return_value=0)) def test_call(self, readline): # check that it calls _commit_and_build_completions correctly helper = mock.Mock() @@ -188,14 +197,14 @@ class Completion(unittest.TestCase): # now we have three "a" words: "and", "ark", "aaah!zombies!!" cabc.reset_mock() cabc.configure_mock(return_value=["aargh", "ark", "aaah!zombies!!"]) - self.assertEqual(get_completions(c, "12-a"), - ["aargh", "ark", "aaah!zombies!!"]) + self.assertEqual( + get_completions(c, "12-a"), ["aargh", "ark", "aaah!zombies!!"]) self.assertEqual(cabc.mock_calls, [mock.call("12-a")]) cabc.reset_mock() cabc.configure_mock(return_value=["aargh", "aaah!zombies!!"]) - self.assertEqual(get_completions(c, "12-aa"), - ["aargh", "aaah!zombies!!"]) + self.assertEqual( + get_completions(c, "12-aa"), ["aargh", "aaah!zombies!!"]) self.assertEqual(cabc.mock_calls, [mock.call("12-aa")]) cabc.reset_mock() @@ -223,16 +232,17 @@ class Completion(unittest.TestCase): def test_build_completions(self): rn = mock.Mock() # InputHelper.get_nameplate_completions returns just the suffixes - gnc = mock.Mock() # get_nameplate_completions - cn = mock.Mock() # choose_nameplate - gwc = mock.Mock() # get_word_completions - cw = mock.Mock() # choose_words - helper = mock.Mock(refresh_nameplates=rn, - get_nameplate_completions=gnc, - choose_nameplate=cn, - get_word_completions=gwc, - choose_words=cw, - ) + gnc = mock.Mock() # get_nameplate_completions + cn = mock.Mock() # choose_nameplate + gwc = mock.Mock() # get_word_completions + cw = mock.Mock() # choose_words + helper = mock.Mock( + refresh_nameplates=rn, + get_nameplate_completions=gnc, + choose_nameplate=cn, + get_word_completions=gwc, + choose_words=cw, + ) # this needs a real reactor, for blockingCallFromThread c = CodeInputter(helper, reactor) cabc = c._commit_and_build_completions @@ -327,17 +337,18 @@ class Completion(unittest.TestCase): gwc.configure_mock(return_value=["code", "court"]) c = CodeInputter(helper, reactor) cabc = c._commit_and_build_completions - matches = yield deferToThread(cabc, "1-co") # this commits us to 1- - self.assertEqual(helper.mock_calls, - [mock.call.choose_nameplate("1"), - mock.call.when_wordlist_is_available(), - mock.call.get_word_completions("co")]) + matches = yield deferToThread(cabc, "1-co") # this commits us to 1- + self.assertEqual(helper.mock_calls, [ + mock.call.choose_nameplate("1"), + mock.call.when_wordlist_is_available(), + mock.call.get_word_completions("co") + ]) self.assertEqual(matches, ["1-code", "1-court"]) helper.reset_mock() with self.assertRaises(AlreadyInputNameplateError) as e: yield deferToThread(cabc, "2-co") - self.assertEqual(str(e.exception), - "nameplate (1-) already entered, cannot go back") + self.assertEqual( + str(e.exception), "nameplate (1-) already entered, cannot go back") self.assertEqual(helper.mock_calls, []) @inlineCallbacks @@ -347,17 +358,18 @@ class Completion(unittest.TestCase): gwc.configure_mock(return_value=["code", "court"]) c = CodeInputter(helper, reactor) cabc = c._commit_and_build_completions - matches = yield deferToThread(cabc, "1-co") # this commits us to 1- - self.assertEqual(helper.mock_calls, - [mock.call.choose_nameplate("1"), - mock.call.when_wordlist_is_available(), - mock.call.get_word_completions("co")]) + matches = yield deferToThread(cabc, "1-co") # this commits us to 1- + self.assertEqual(helper.mock_calls, [ + mock.call.choose_nameplate("1"), + mock.call.when_wordlist_is_available(), + mock.call.get_word_completions("co") + ]) self.assertEqual(matches, ["1-code", "1-court"]) helper.reset_mock() with self.assertRaises(AlreadyInputNameplateError) as e: yield deferToThread(c.finish, "2-code") - self.assertEqual(str(e.exception), - "nameplate (1-) already entered, cannot go back") + self.assertEqual( + str(e.exception), "nameplate (1-) already entered, cannot go back") self.assertEqual(helper.mock_calls, []) @mock.patch("wormhole._rlcompleter.stderr") @@ -366,6 +378,7 @@ class Completion(unittest.TestCase): # right time, since it involves a reactor and a "system event # trigger", but let's at least make sure it's invocable warn_readline() - expected ="\nCommand interrupted: please press Return to quit" - self.assertEqual(stderr.mock_calls, [mock.call.write(expected), - mock.call.write("\n")]) + expected = "\nCommand interrupted: please press Return to quit" + self.assertEqual(stderr.mock_calls, + [mock.call.write(expected), + mock.call.write("\n")]) diff --git a/src/wormhole/test/test_ssh.py b/src/wormhole/test/test_ssh.py index 775c22a..9f86eed 100644 --- a/src/wormhole/test/test_ssh.py +++ b/src/wormhole/test/test_ssh.py @@ -1,10 +1,15 @@ -import os, io -import mock +import io +import os + from twisted.trial import unittest + +import mock + from ..cli import cmd_ssh OTHERS = ["config", "config~", "known_hosts", "known_hosts~"] + class FindPubkey(unittest.TestCase): def test_find_one(self): files = OTHERS + ["id_rsa.pub", "id_rsa"] @@ -12,8 +17,8 @@ class FindPubkey(unittest.TestCase): pubkey_file = io.StringIO(pubkey_data) with mock.patch("wormhole.cli.cmd_ssh.exists", return_value=True): with mock.patch("os.listdir", return_value=files) as ld: - with mock.patch("wormhole.cli.cmd_ssh.open", - return_value=pubkey_file): + with mock.patch( + "wormhole.cli.cmd_ssh.open", return_value=pubkey_file): res = cmd_ssh.find_public_key() self.assertEqual(ld.mock_calls, [mock.call(os.path.expanduser("~/.ssh/"))]) @@ -24,7 +29,7 @@ class FindPubkey(unittest.TestCase): self.assertEqual(pubkey, pubkey_data) def test_find_none(self): - files = OTHERS # no pubkey + files = OTHERS # no pubkey with mock.patch("wormhole.cli.cmd_ssh.exists", return_value=True): with mock.patch("os.listdir", return_value=files): e = self.assertRaises(cmd_ssh.PubkeyError, @@ -34,12 +39,12 @@ class FindPubkey(unittest.TestCase): def test_bad_hint(self): with mock.patch("wormhole.cli.cmd_ssh.exists", return_value=False): - e = self.assertRaises(cmd_ssh.PubkeyError, - cmd_ssh.find_public_key, - hint="bogus/path") + e = self.assertRaises( + cmd_ssh.PubkeyError, + cmd_ssh.find_public_key, + hint="bogus/path") self.assertEqual(str(e), "Can't find 'bogus/path'") - def test_find_multiple(self): files = OTHERS + ["id_rsa.pub", "id_rsa", "id_dsa.pub", "id_dsa"] pubkey_data = u"ssh-rsa AAAAkeystuff email@host\n" @@ -47,10 +52,11 @@ class FindPubkey(unittest.TestCase): with mock.patch("wormhole.cli.cmd_ssh.exists", return_value=True): with mock.patch("os.listdir", return_value=files): responses = iter(["frog", "NaN", "-1", "0"]) - with mock.patch("click.prompt", - side_effect=lambda p: next(responses)): - with mock.patch("wormhole.cli.cmd_ssh.open", - return_value=pubkey_file): + with mock.patch( + "click.prompt", side_effect=lambda p: next(responses)): + with mock.patch( + "wormhole.cli.cmd_ssh.open", + return_value=pubkey_file): res = cmd_ssh.find_public_key() self.assertEqual(len(res), 3, res) kind, keyid, pubkey = res diff --git a/src/wormhole/test/test_tor_manager.py b/src/wormhole/test/test_tor_manager.py index b758e71..386698c 100644 --- a/src/wormhole/test/test_tor_manager.py +++ b/src/wormhole/test/test_tor_manager.py @@ -1,42 +1,51 @@ from __future__ import print_function, unicode_literals -import mock, io -from twisted.trial import unittest + +import io + from twisted.internet import defer from twisted.internet.error import ConnectError +from twisted.trial import unittest + +import mock -from ..tor_manager import get_tor, SocksOnlyTor -from ..errors import NoTorError from .._interfaces import ITorManager +from ..errors import NoTorError +from ..tor_manager import SocksOnlyTor, get_tor + class X(): pass + class Tor(unittest.TestCase): def test_no_txtorcon(self): with mock.patch("wormhole.tor_manager.txtorcon", None): self.failureResultOf(get_tor(None), NoTorError) def test_bad_args(self): - f = self.failureResultOf(get_tor(None, launch_tor="not boolean"), - TypeError) + f = self.failureResultOf( + get_tor(None, launch_tor="not boolean"), TypeError) self.assertEqual(str(f.value), "launch_tor= must be boolean") - f = self.failureResultOf(get_tor(None, tor_control_port=1234), - TypeError) + f = self.failureResultOf( + get_tor(None, tor_control_port=1234), TypeError) self.assertEqual(str(f.value), "tor_control_port= must be str or None") - f = self.failureResultOf(get_tor(None, launch_tor=True, - tor_control_port="tcp:127.0.0.1:1234"), - ValueError) - self.assertEqual(str(f.value), - "cannot combine --launch-tor and --tor-control-port=") + f = self.failureResultOf( + get_tor( + None, launch_tor=True, tor_control_port="tcp:127.0.0.1:1234"), + ValueError) + self.assertEqual( + str(f.value), + "cannot combine --launch-tor and --tor-control-port=") def test_launch(self): reactor = object() - my_tor = X() # object() didn't like providedBy() + my_tor = X() # object() didn't like providedBy() launch_d = defer.Deferred() stderr = io.StringIO() - with mock.patch("wormhole.tor_manager.txtorcon.launch", - side_effect=launch_d) as launch: + with mock.patch( + "wormhole.tor_manager.txtorcon.launch", + side_effect=launch_d) as launch: d = get_tor(reactor, launch_tor=True, stderr=stderr) self.assertNoResult(d) self.assertEqual(launch.mock_calls, [mock.call(reactor)]) @@ -44,18 +53,21 @@ class Tor(unittest.TestCase): tor = self.successResultOf(d) self.assertIs(tor, my_tor) self.assert_(ITorManager.providedBy(tor)) - self.assertEqual(stderr.getvalue(), - " launching a new Tor process, this may take a while..\n") + self.assertEqual( + stderr.getvalue(), + " launching a new Tor process, this may take a while..\n") def test_connect(self): reactor = object() - my_tor = X() # object() didn't like providedBy() + my_tor = X() # object() didn't like providedBy() connect_d = defer.Deferred() stderr = io.StringIO() - with mock.patch("wormhole.tor_manager.txtorcon.connect", - side_effect=connect_d) as connect: - with mock.patch("wormhole.tor_manager.clientFromString", - side_effect=["foo"]) as sfs: + with mock.patch( + "wormhole.tor_manager.txtorcon.connect", + side_effect=connect_d) as connect: + with mock.patch( + "wormhole.tor_manager.clientFromString", + side_effect=["foo"]) as sfs: d = get_tor(reactor, stderr=stderr) self.assertEqual(sfs.mock_calls, []) self.assertNoResult(d) @@ -71,10 +83,12 @@ class Tor(unittest.TestCase): reactor = object() connect_d = defer.Deferred() stderr = io.StringIO() - with mock.patch("wormhole.tor_manager.txtorcon.connect", - side_effect=connect_d) as connect: - with mock.patch("wormhole.tor_manager.clientFromString", - side_effect=["foo"]) as sfs: + with mock.patch( + "wormhole.tor_manager.txtorcon.connect", + side_effect=connect_d) as connect: + with mock.patch( + "wormhole.tor_manager.clientFromString", + side_effect=["foo"]) as sfs: d = get_tor(reactor, stderr=stderr) self.assertEqual(sfs.mock_calls, []) self.assertNoResult(d) @@ -85,20 +99,23 @@ class Tor(unittest.TestCase): self.assertIsInstance(tor, SocksOnlyTor) self.assert_(ITorManager.providedBy(tor)) self.assertEqual(tor._reactor, reactor) - self.assertEqual(stderr.getvalue(), - " unable to find default Tor control port, using SOCKS\n") + self.assertEqual( + stderr.getvalue(), + " unable to find default Tor control port, using SOCKS\n") def test_connect_custom_control_port(self): reactor = object() - my_tor = X() # object() didn't like providedBy() + my_tor = X() # object() didn't like providedBy() tcp = "PORT" ep = object() connect_d = defer.Deferred() stderr = io.StringIO() - with mock.patch("wormhole.tor_manager.txtorcon.connect", - side_effect=connect_d) as connect: - with mock.patch("wormhole.tor_manager.clientFromString", - side_effect=[ep]) as sfs: + with mock.patch( + "wormhole.tor_manager.txtorcon.connect", + side_effect=connect_d) as connect: + with mock.patch( + "wormhole.tor_manager.clientFromString", + side_effect=[ep]) as sfs: d = get_tor(reactor, tor_control_port=tcp, stderr=stderr) self.assertEqual(sfs.mock_calls, [mock.call(reactor, tcp)]) self.assertNoResult(d) @@ -116,10 +133,12 @@ class Tor(unittest.TestCase): ep = object() connect_d = defer.Deferred() stderr = io.StringIO() - with mock.patch("wormhole.tor_manager.txtorcon.connect", - side_effect=connect_d) as connect: - with mock.patch("wormhole.tor_manager.clientFromString", - side_effect=[ep]) as sfs: + with mock.patch( + "wormhole.tor_manager.txtorcon.connect", + side_effect=connect_d) as connect: + with mock.patch( + "wormhole.tor_manager.clientFromString", + side_effect=[ep]) as sfs: d = get_tor(reactor, tor_control_port=tcp, stderr=stderr) self.assertEqual(sfs.mock_calls, [mock.call(reactor, tcp)]) self.assertNoResult(d) @@ -129,18 +148,22 @@ class Tor(unittest.TestCase): self.failureResultOf(d, ConnectError) self.assertEqual(stderr.getvalue(), "") + class SocksOnly(unittest.TestCase): def test_tor(self): reactor = object() sot = SocksOnlyTor(reactor) fake_ep = object() - with mock.patch("wormhole.tor_manager.txtorcon.TorClientEndpoint", - return_value=fake_ep) as tce: + with mock.patch( + "wormhole.tor_manager.txtorcon.TorClientEndpoint", + return_value=fake_ep) as tce: ep = sot.stream_via("host", "port") self.assertIs(ep, fake_ep) - self.assertEqual(tce.mock_calls, [mock.call("host", "port", - socks_endpoint=None, - tls=False, - reactor=reactor)]) - - + self.assertEqual(tce.mock_calls, [ + mock.call( + "host", + "port", + socks_endpoint=None, + tls=False, + reactor=reactor) + ]) diff --git a/src/wormhole/test/test_transit.py b/src/wormhole/test/test_transit.py index fb03f38..63216ae 100644 --- a/src/wormhole/test/test_transit.py +++ b/src/wormhole/test/test_transit.py @@ -1,27 +1,33 @@ from __future__ import print_function, unicode_literals -import six -import io + import gc -import mock +import io from binascii import hexlify, unhexlify from collections import namedtuple -from twisted.trial import unittest -from twisted.internet import defer, task, endpoints, protocol, address, error + +import six +from nacl.exceptions import CryptoError +from nacl.secret import SecretBox +from twisted.internet import address, defer, endpoints, error, protocol, task from twisted.internet.defer import gatherResults, inlineCallbacks from twisted.python import log from twisted.test import proto_helpers +from twisted.trial import unittest + +import mock from wormhole_transit_relay import transit_server -from ..errors import InternalError + from .. import transit +from ..errors import InternalError from .common import ServerBase -from nacl.secret import SecretBox -from nacl.exceptions import CryptoError + class Highlander(unittest.TestCase): def test_one_winner(self): cancelled = set() - contenders = [defer.Deferred(lambda d, i=i: cancelled.add(i)) - for i in range(5)] + contenders = [ + defer.Deferred(lambda d, i=i: cancelled.add(i)) for i in range(5) + ] d = transit.there_can_be_only_one(contenders) self.assertNoResult(d) contenders[0].errback(ValueError()) @@ -30,12 +36,13 @@ class Highlander(unittest.TestCase): self.assertNoResult(d) contenders[2].callback("yay") self.assertEqual(self.successResultOf(d), "yay") - self.assertEqual(cancelled, set([3,4])) + self.assertEqual(cancelled, set([3, 4])) def test_there_might_also_be_none(self): cancelled = set() - contenders = [defer.Deferred(lambda d, i=i: cancelled.add(i)) - for i in range(4)] + contenders = [ + defer.Deferred(lambda d, i=i: cancelled.add(i)) for i in range(4) + ] d = transit.there_can_be_only_one(contenders) self.assertNoResult(d) contenders[0].errback(ValueError()) @@ -45,13 +52,14 @@ class Highlander(unittest.TestCase): contenders[2].errback(TypeError()) self.assertNoResult(d) contenders[3].errback(NameError()) - self.failureResultOf(d, ValueError) # first failure is recorded + self.failureResultOf(d, ValueError) # first failure is recorded self.assertEqual(cancelled, set()) def test_cancel_early(self): cancelled = set() - contenders = [defer.Deferred(lambda d, i=i: cancelled.add(i)) - for i in range(4)] + contenders = [ + defer.Deferred(lambda d, i=i: cancelled.add(i)) for i in range(4) + ] d = transit.there_can_be_only_one(contenders) self.assertNoResult(d) self.assertEqual(cancelled, set()) @@ -61,15 +69,17 @@ class Highlander(unittest.TestCase): def test_cancel_after_one_failure(self): cancelled = set() - contenders = [defer.Deferred(lambda d, i=i: cancelled.add(i)) - for i in range(4)] + contenders = [ + defer.Deferred(lambda d, i=i: cancelled.add(i)) for i in range(4) + ] d = transit.there_can_be_only_one(contenders) self.assertNoResult(d) self.assertEqual(cancelled, set()) contenders[0].errback(ValueError()) d.cancel() self.failureResultOf(d, ValueError) - self.assertEqual(cancelled, set([1,2,3])) + self.assertEqual(cancelled, set([1, 2, 3])) + class Forever(unittest.TestCase): def _forever_setup(self): @@ -116,6 +126,7 @@ class Forever(unittest.TestCase): self.failureResultOf(d, defer.CancelledError) self.assertNot(clock.getDelayedCalls()) + class Misc(unittest.TestCase): def test_allocate_port(self): portno = transit.allocate_tcp_port() @@ -128,29 +139,36 @@ class Misc(unittest.TestCase): portno = transit.allocate_tcp_port() self.assertIsInstance(portno, int) + UnknownHint = namedtuple("UnknownHint", ["stuff"]) + class Hints(unittest.TestCase): def test_endpoint_from_hint_obj(self): c = transit.Common("") efho = c._endpoint_from_hint_obj - self.assertIsInstance(efho(transit.DirectTCPV1Hint("host", 1234, 0.0)), - endpoints.HostnameEndpoint) + self.assertIsInstance( + efho(transit.DirectTCPV1Hint("host", 1234, 0.0)), + endpoints.HostnameEndpoint) self.assertEqual(efho("unknown:stuff:yowza:pivlor"), None) # c._tor is currently None self.assertEqual(efho(transit.TorTCPV1Hint("host", "port", 0)), None) c._tor = mock.Mock() + def tor_ep(hostname, port): if hostname == "non-public": return None return ("tor_ep", hostname, port) + c._tor.stream_via = mock.Mock(side_effect=tor_ep) - self.assertEqual(efho(transit.DirectTCPV1Hint("host", 1234, 0.0)), - ("tor_ep", "host", 1234)) - self.assertEqual(efho(transit.TorTCPV1Hint("host2.onion", 1234, 0.0)), - ("tor_ep", "host2.onion", 1234)) - self.assertEqual(efho(transit.DirectTCPV1Hint("non-public", 1234, 0.0)), - None) + self.assertEqual( + efho(transit.DirectTCPV1Hint("host", 1234, 0.0)), + ("tor_ep", "host", 1234)) + self.assertEqual( + efho(transit.TorTCPV1Hint("host2.onion", 1234, 0.0)), + ("tor_ep", "host2.onion", 1234)) + self.assertEqual( + efho(transit.DirectTCPV1Hint("non-public", 1234, 0.0)), None) self.assertEqual(efho(UnknownHint("foo")), None) def test_comparable(self): @@ -170,94 +188,120 @@ class Hints(unittest.TestCase): self.assertEqual(p({"type": "unknown"}), None) h = p({"type": "direct-tcp-v1", "hostname": "foo", "port": 1234}) self.assertEqual(h, transit.DirectTCPV1Hint("foo", 1234, 0.0)) - h = p({"type": "direct-tcp-v1", "hostname": "foo", "port": 1234, - "priority": 2.5}) + h = p({ + "type": "direct-tcp-v1", + "hostname": "foo", + "port": 1234, + "priority": 2.5 + }) self.assertEqual(h, transit.DirectTCPV1Hint("foo", 1234, 2.5)) h = p({"type": "tor-tcp-v1", "hostname": "foo", "port": 1234}) self.assertEqual(h, transit.TorTCPV1Hint("foo", 1234, 0.0)) - h = p({"type": "tor-tcp-v1", "hostname": "foo", "port": 1234, - "priority": 2.5}) + h = p({ + "type": "tor-tcp-v1", + "hostname": "foo", + "port": 1234, + "priority": 2.5 + }) self.assertEqual(h, transit.TorTCPV1Hint("foo", 1234, 2.5)) - self.assertEqual(p({"type": "direct-tcp-v1"}), - None) # missing hostname - self.assertEqual(p({"type": "direct-tcp-v1", "hostname": 12}), - None) # invalid hostname - self.assertEqual(p({"type": "direct-tcp-v1", "hostname": "foo"}), - None) # missing port - self.assertEqual(p({"type": "direct-tcp-v1", "hostname": "foo", - "port": "not a number"}), - None) # invalid port + self.assertEqual(p({ + "type": "direct-tcp-v1" + }), None) # missing hostname + self.assertEqual(p({ + "type": "direct-tcp-v1", + "hostname": 12 + }), None) # invalid hostname + self.assertEqual( + p({ + "type": "direct-tcp-v1", + "hostname": "foo" + }), None) # missing port + self.assertEqual( + p({ + "type": "direct-tcp-v1", + "hostname": "foo", + "port": "not a number" + }), None) # invalid port def test_parse_hint_argv(self): def p(hint): stderr = io.StringIO() value = transit.parse_hint_argv(hint, stderr=stderr) return value, stderr.getvalue() - h,stderr = p("tcp:host:1234") + + h, stderr = p("tcp:host:1234") self.assertEqual(h, transit.DirectTCPV1Hint("host", 1234, 0.0)) self.assertEqual(stderr, "") - h,stderr = p("tcp:host:1234:priority=2.6") + h, stderr = p("tcp:host:1234:priority=2.6") self.assertEqual(h, transit.DirectTCPV1Hint("host", 1234, 2.6)) self.assertEqual(stderr, "") - h,stderr = p("tcp:host:1234:unknown=stuff") + h, stderr = p("tcp:host:1234:unknown=stuff") self.assertEqual(h, transit.DirectTCPV1Hint("host", 1234, 0.0)) self.assertEqual(stderr, "") - h,stderr = p("$!@#^") + h, stderr = p("$!@#^") self.assertEqual(h, None) self.assertEqual(stderr, "unparseable hint '$!@#^'\n") - h,stderr = p("unknown:stuff") + h, stderr = p("unknown:stuff") self.assertEqual(h, None) self.assertEqual(stderr, "unknown hint type 'unknown' in 'unknown:stuff'\n") - h,stderr = p("tcp:just-a-hostname") + h, stderr = p("tcp:just-a-hostname") self.assertEqual(h, None) - self.assertEqual(stderr, - "unparseable TCP hint (need more colons) 'tcp:just-a-hostname'\n") + self.assertEqual( + stderr, + "unparseable TCP hint (need more colons) 'tcp:just-a-hostname'\n") - h,stderr = p("tcp:host:number") + h, stderr = p("tcp:host:number") self.assertEqual(h, None) self.assertEqual(stderr, "non-numeric port in TCP hint 'tcp:host:number'\n") - h,stderr = p("tcp:host:1234:priority=bad") + h, stderr = p("tcp:host:1234:priority=bad") self.assertEqual(h, None) - self.assertEqual(stderr, - "non-float priority= in TCP hint 'tcp:host:1234:priority=bad'\n") + self.assertEqual( + stderr, + "non-float priority= in TCP hint 'tcp:host:1234:priority=bad'\n") def test_describe_hint_obj(self): d = transit.describe_hint_obj - self.assertEqual(d(transit.DirectTCPV1Hint("host", 1234, 0.0)), - "tcp:host:1234") - self.assertEqual(d(transit.TorTCPV1Hint("host", 1234, 0.0)), - "tor:host:1234") + self.assertEqual( + d(transit.DirectTCPV1Hint("host", 1234, 0.0)), "tcp:host:1234") + self.assertEqual( + d(transit.TorTCPV1Hint("host", 1234, 0.0)), "tor:host:1234") self.assertEqual(d(UnknownHint("stuff")), str(UnknownHint("stuff"))) + # ipaddrs.py currently uses native strings: bytes on py2, unicode on # py3 if six.PY2: LOOPADDR = b"127.0.0.1" OTHERADDR = b"1.2.3.4" else: - LOOPADDR = "127.0.0.1" # unicode_literals + LOOPADDR = "127.0.0.1" # unicode_literals OTHERADDR = "1.2.3.4" + class Basic(unittest.TestCase): @inlineCallbacks def test_relay_hints(self): URL = "tcp:host:1234" c = transit.Common(URL, no_listen=True) hints = yield c.get_connection_hints() - self.assertEqual(hints, [{"type": "relay-v1", - "hints": [{"type": "direct-tcp-v1", - "hostname": "host", - "port": 1234, - "priority": 0.0}], - }]) + self.assertEqual(hints, [{ + "type": + "relay-v1", + "hints": [{ + "type": "direct-tcp-v1", + "hostname": "host", + "port": 1234, + "priority": 0.0 + }], + }]) self.assertRaises(InternalError, transit.Common, 123) @inlineCallbacks @@ -269,8 +313,12 @@ class Basic(unittest.TestCase): def test_ignore_bad_hints(self): c = transit.Common("") c.add_connection_hints([{"type": "unknown"}]) - c.add_connection_hints([{"type": "relay-v1", - "hints": [{"type": "unknown"}]}]) + c.add_connection_hints([{ + "type": "relay-v1", + "hints": [{ + "type": "unknown" + }] + }]) self.assertEqual(c._their_direct_hints, []) self.assertEqual(c._our_relay_hints, set()) @@ -290,8 +338,9 @@ class Basic(unittest.TestCase): def test_ignore_localhost_hint(self): # this actually starts the listener c = transit.TransitSender("") - with mock.patch("wormhole.ipaddrs.find_addresses", - return_value=[LOOPADDR, OTHERADDR]): + with mock.patch( + "wormhole.ipaddrs.find_addresses", + return_value=[LOOPADDR, OTHERADDR]): hints = self.successResultOf(c.get_connection_hints()) c._stop_listening() # If there are non-localhost hints, then localhost hints should be @@ -302,8 +351,8 @@ class Basic(unittest.TestCase): def test_keep_only_localhost_hint(self): # this actually starts the listener c = transit.TransitSender("") - with mock.patch("wormhole.ipaddrs.find_addresses", - return_value=[LOOPADDR]): + with mock.patch( + "wormhole.ipaddrs.find_addresses", return_value=[LOOPADDR]): hints = self.successResultOf(c.get_connection_hints()) c._stop_listening() # If the only hint is localhost, it should stay. @@ -313,9 +362,14 @@ class Basic(unittest.TestCase): def test_abilities(self): c = transit.Common(None, no_listen=True) abilities = c.get_connection_abilities() - self.assertEqual(abilities, [{"type": "direct-tcp-v1"}, - {"type": "relay-v1"}, - ]) + self.assertEqual(abilities, [ + { + "type": "direct-tcp-v1" + }, + { + "type": "relay-v1" + }, + ]) def test_transit_key_wait(self): KEY = b"123" @@ -339,19 +393,31 @@ class Basic(unittest.TestCase): r = transit.TransitReceiver("") r.set_transit_key(KEY) - self.assertEqual(s._send_this(), b"transit sender 559bdeae4b49fa6a23378d2b68f4c7e69378615d4af049c371c6a26e82391089 ready\n\n") + self.assertEqual(s._send_this(), ( + b"transit sender " + b"559bdeae4b49fa6a23378d2b68f4c7e69378615d4af049c371c6a26e82391089" + b" ready\n\n")) self.assertEqual(s._send_this(), r._expect_this()) - self.assertEqual(r._send_this(), b"transit receiver ed447528194bac4c00d0c854b12a97ce51413d89aa74d6304475f516fdc23a1b ready\n\n") + self.assertEqual(r._send_this(), ( + b"transit receiver " + b"ed447528194bac4c00d0c854b12a97ce51413d89aa74d6304475f516fdc23a1b" + b" ready\n\n")) self.assertEqual(r._send_this(), s._expect_this()) - self.assertEqual(hexlify(s._sender_record_key()), b"5a2fba3a9e524ab2e2823ff53b05f946896f6e4ce4e282ffd8e3ac0e5e9e0cda") - self.assertEqual(hexlify(s._sender_record_key()), - hexlify(r._receiver_record_key())) + self.assertEqual( + hexlify(s._sender_record_key()), + b"5a2fba3a9e524ab2e2823ff53b05f946896f6e4ce4e282ffd8e3ac0e5e9e0cda" + ) + self.assertEqual( + hexlify(s._sender_record_key()), hexlify(r._receiver_record_key())) - self.assertEqual(hexlify(r._sender_record_key()), b"eedb143117249f45b39da324decf6bd9aae33b7ccd58487436de611a3c6b871d") - self.assertEqual(hexlify(r._sender_record_key()), - hexlify(s._receiver_record_key())) + self.assertEqual( + hexlify(r._sender_record_key()), + b"eedb143117249f45b39da324decf6bd9aae33b7ccd58487436de611a3c6b871d" + ) + self.assertEqual( + hexlify(r._sender_record_key()), hexlify(s._receiver_record_key())) def test_connection_ready(self): s = transit.TransitSender("") @@ -407,7 +473,7 @@ class DummyProtocol(protocol.Protocol): def dataReceived(self, data): self.buf += data - #print("oDR", self._count, len(self.buf)) + # print("oDR", self._count, len(self.buf)) if self._count is not None and len(self.buf) >= self._count: got = self.buf[:self._count] self.buf = self.buf[self._count:] @@ -422,19 +488,24 @@ class DummyProtocol(protocol.Protocol): if self._d2: self._d2.callback(None) + class FakeTransport: signalConnectionLost = True + def __init__(self, p, peeraddr): self.protocol = p self._peeraddr = peeraddr self._buf = b"" self._connected = True + def write(self, data): self._buf += data + def loseConnection(self): self._connected = False if self.signalConnectionLost: self.protocol.connectionLost() + def getPeer(self): return self._peeraddr @@ -443,17 +514,21 @@ class FakeTransport: self._buf = b"" return b + class RandomError(Exception): pass + class MockConnection: def __init__(self, owner, relay_handshake, start, description): self.owner = owner self.relay_handshake = relay_handshake self.start = start self._description = description + def cancel(d): self._cancelled = True + self._d = defer.Deferred(cancel) self._start_negotiation_called = False self._cancelled = False @@ -462,6 +537,7 @@ class MockConnection: self._start_negotiation_called = True return self._d + class InboundConnectionFactory(unittest.TestCase): def test_describe(self): f = transit.InboundConnectionFactory(None) @@ -472,8 +548,8 @@ class InboundConnectionFactory(unittest.TestCase): addr6 = address.IPv6Address("TCP", "::1", 1234) self.assertEqual(f._describePeer(addr6), "<-::1:1234") addrU = address.UNIXAddress("/dev/unlikely") - self.assertEqual(f._describePeer(addrU), - "<-UNIXAddress('/dev/unlikely')") + self.assertEqual( + f._describePeer(addrU), "<-UNIXAddress('/dev/unlikely')") def test_success(self): f = transit.InboundConnectionFactory("owner") @@ -586,8 +662,10 @@ class InboundConnectionFactory(unittest.TestCase): self.assertEqual(p1._cancelled, True) self.assertEqual(p2._cancelled, True) + # XXX check descriptions + class OutboundConnectionFactory(unittest.TestCase): def test_success(self): f = transit.OutboundConnectionFactory("owner", "relay_handshake", @@ -603,31 +681,39 @@ class OutboundConnectionFactory(unittest.TestCase): # meh .start # this is normally called from Connection.connectionMade - f.connectionWasMade(p) # no-op for outbound + f.connectionWasMade(p) # no-op for outbound self.assertEqual(p._start_negotiation_called, False) class MockOwner: _connection_ready_called = False + def connection_ready(self, connection): self._connection_ready_called = True self._connection = connection return self._state + def _send_this(self): return b"send_this" + def _expect_this(self): return b"expect_this" + def _sender_record_key(self): - return b"s"*32 + return b"s" * 32 + def _receiver_record_key(self): - return b"r"*32 + return b"r" * 32 + class MockFactory: _connectionWasMade_called = False + def connectionWasMade(self, p): self._connectionWasMade_called = True self._p = p + class Connection(unittest.TestCase): # exercise the Connection protocol class @@ -640,8 +726,8 @@ class Connection(unittest.TestCase): c.buf = b"unexpected" e = self.assertRaises(transit.BadHandshake, c._check_and_remove, EXP) - self.assertEqual(str(e), - "got %r want %r" % (b'unexpected', b'expectation')) + self.assertEqual( + str(e), "got %r want %r" % (b'unexpected', b'expectation')) self.assertEqual(c.buf, b"unexpected") c.buf = b"expect" @@ -770,12 +856,12 @@ class Connection(unittest.TestCase): c.connectionMade() self.assertEqual(factory._connectionWasMade_called, True) self.assertEqual(factory._p, c) - self.assertEqual(t.read_buf(), b"") # quiet until startNegotiation + self.assertEqual(t.read_buf(), b"") # quiet until startNegotiation owner._state = "go" d = c.startNegotiation() self.assertEqual(t.read_buf(), relay_handshake) - self.assertEqual(c.state, "relay") # waiting for OK from relay + self.assertEqual(c.state, "relay") # waiting for OK from relay c.dataReceived(b"ok\n") self.assertEqual(t.read_buf(), b"send_this") @@ -801,20 +887,20 @@ class Connection(unittest.TestCase): c.connectionMade() self.assertEqual(factory._connectionWasMade_called, True) self.assertEqual(factory._p, c) - self.assertEqual(t.read_buf(), b"") # quiet until startNegotiation + self.assertEqual(t.read_buf(), b"") # quiet until startNegotiation owner._state = "go" d = c.startNegotiation() self.assertEqual(t.read_buf(), relay_handshake) - self.assertEqual(c.state, "relay") # waiting for OK from relay + self.assertEqual(c.state, "relay") # waiting for OK from relay c.dataReceived(b"not ok\n") self.assertEqual(t._connected, False) self.assertEqual(c.state, "hung up") f = self.failureResultOf(d, transit.BadHandshake) - self.assertEqual(str(f.value), - "got %r want %r" % (b"not ok\n", b"ok\n")) + self.assertEqual( + str(f.value), "got %r want %r" % (b"not ok\n", b"ok\n")) def test_receiver_accepted(self): # we're on the receiving side, so we wait for the sender to decide @@ -866,12 +952,12 @@ class Connection(unittest.TestCase): self.assertEqual(c.state, "wait-for-decision") self.assertNoResult(d) - c.dataReceived(b"nevermind\n") # polite rejection + c.dataReceived(b"nevermind\n") # polite rejection self.assertEqual(t._connected, False) self.assertEqual(c.state, "hung up") f = self.failureResultOf(d, transit.BadHandshake) - self.assertEqual(str(f.value), - "got %r want %r" % (b"nevermind\n", b"go\n")) + self.assertEqual( + str(f.value), "got %r want %r" % (b"nevermind\n", b"go\n")) def test_receiver_rejected_rudely(self): # we're on the receiving side, so we wait for the sender to decide @@ -901,7 +987,6 @@ class Connection(unittest.TestCase): f = self.failureResultOf(d, transit.BadHandshake) self.assertEqual(str(f.value), "connection lost") - def test_cancel(self): owner = MockOwner() factory = MockFactory() @@ -926,8 +1011,10 @@ class Connection(unittest.TestCase): factory = MockFactory() addr = address.HostnameAddress("example.com", 1234) c = transit.Connection(owner, None, None, "description") + def _callLater(period, func): clock.callLater(period, func) + c.callLater = _callLater self.assertEqual(c.state, "too-early") t = c.transport = FakeTransport(c, addr) @@ -955,7 +1042,7 @@ class Connection(unittest.TestCase): d = c.startNegotiation() c.dataReceived(b"expect_this") self.assertEqual(self.successResultOf(d), c) - t.read_buf() # flush input buffer, prepare for encrypted records + t.read_buf() # flush input buffer, prepare for encrypted records return t, c, owner @@ -973,13 +1060,13 @@ class Connection(unittest.TestCase): RECORD1 = b"record" c.send_record(RECORD1) buf = t.read_buf() - expected = ("%08x" % (24+len(RECORD1)+16)).encode("ascii") + expected = ("%08x" % (24 + len(RECORD1) + 16)).encode("ascii") self.assertEqual(hexlify(buf[:4]), expected) encrypted = buf[4:] receive_box = SecretBox(owner._sender_record_key()) - nonce_buf = encrypted[:SecretBox.NONCE_SIZE] # assume it's prepended + nonce_buf = encrypted[:SecretBox.NONCE_SIZE] # assume it's prepended nonce = int(hexlify(nonce_buf), 16) - self.assertEqual(nonce, 0) # first message gets nonce 0 + self.assertEqual(nonce, 0) # first message gets nonce 0 decrypted = receive_box.decrypt(encrypted) self.assertEqual(decrypted, RECORD1) @@ -987,11 +1074,11 @@ class Connection(unittest.TestCase): RECORD2 = b"record2" c.send_record(RECORD2) buf = t.read_buf() - expected = ("%08x" % (24+len(RECORD2)+16)).encode("ascii") + expected = ("%08x" % (24 + len(RECORD2) + 16)).encode("ascii") self.assertEqual(hexlify(buf[:4]), expected) encrypted = buf[4:] receive_box = SecretBox(owner._sender_record_key()) - nonce_buf = encrypted[:SecretBox.NONCE_SIZE] # assume it's prepended + nonce_buf = encrypted[:SecretBox.NONCE_SIZE] # assume it's prepended nonce = int(hexlify(nonce_buf), 16) self.assertEqual(nonce, 1) decrypted = receive_box.decrypt(encrypted) @@ -1003,9 +1090,9 @@ class Connection(unittest.TestCase): send_box = SecretBox(owner._receiver_record_key()) RECORD3 = b"record3" - nonce_buf = unhexlify("%048x" % 0) # first nonce must be 0 + nonce_buf = unhexlify("%048x" % 0) # first nonce must be 0 encrypted = send_box.encrypt(RECORD3, nonce_buf) - length = unhexlify("%08x" % len(encrypted)) # always 4 bytes long + length = unhexlify("%08x" % len(encrypted)) # always 4 bytes long c.dataReceived(length[:2]) c.dataReceived(length[2:]) c.dataReceived(encrypted[:-2]) @@ -1014,9 +1101,9 @@ class Connection(unittest.TestCase): self.assertEqual(inbound_records, [RECORD3]) RECORD4 = b"record4" - nonce_buf = unhexlify("%048x" % 1) # nonces increment + nonce_buf = unhexlify("%048x" % 1) # nonces increment encrypted = send_box.encrypt(RECORD4, nonce_buf) - length = unhexlify("%08x" % len(encrypted)) # always 4 bytes long + length = unhexlify("%08x" % len(encrypted)) # always 4 bytes long c.dataReceived(length[:2]) c.dataReceived(length[2:]) c.dataReceived(encrypted[:-2]) @@ -1027,16 +1114,16 @@ class Connection(unittest.TestCase): # receiving two records at the same time: deliver both inbound_records[:] = [] RECORD5 = b"record5" - nonce_buf = unhexlify("%048x" % 2) # nonces increment + nonce_buf = unhexlify("%048x" % 2) # nonces increment encrypted = send_box.encrypt(RECORD5, nonce_buf) - length = unhexlify("%08x" % len(encrypted)) # always 4 bytes long - r5 = length+encrypted + length = unhexlify("%08x" % len(encrypted)) # always 4 bytes long + r5 = length + encrypted RECORD6 = b"record6" - nonce_buf = unhexlify("%048x" % 3) # nonces increment + nonce_buf = unhexlify("%048x" % 3) # nonces increment encrypted = send_box.encrypt(RECORD6, nonce_buf) - length = unhexlify("%08x" % len(encrypted)) # always 4 bytes long - r6 = length+encrypted - c.dataReceived(r5+r6) + length = unhexlify("%08x" % len(encrypted)) # always 4 bytes long + r6 = length + encrypted + c.dataReceived(r5 + r6) self.assertEqual(inbound_records, [RECORD5, RECORD6]) def corrupt(self, orig): @@ -1055,9 +1142,9 @@ class Connection(unittest.TestCase): RECORD = b"record" send_box = SecretBox(owner._receiver_record_key()) - nonce_buf = unhexlify("%048x" % 0) # first nonce must be 0 + nonce_buf = unhexlify("%048x" % 0) # first nonce must be 0 encrypted = self.corrupt(send_box.encrypt(RECORD, nonce_buf)) - length = unhexlify("%08x" % len(encrypted)) # always 4 bytes long + length = unhexlify("%08x" % len(encrypted)) # always 4 bytes long c.dataReceived(length) c.dataReceived(encrypted[:-2]) self.assertEqual(inbound_records, []) @@ -1075,9 +1162,9 @@ class Connection(unittest.TestCase): RECORD = b"record" send_box = SecretBox(owner._receiver_record_key()) - nonce_buf = unhexlify("%048x" % 1) # first nonce must be 0 + nonce_buf = unhexlify("%048x" % 1) # first nonce must be 0 encrypted = send_box.encrypt(RECORD, nonce_buf) - length = unhexlify("%08x" % len(encrypted)) # always 4 bytes long + length = unhexlify("%08x" % len(encrypted)) # always 4 bytes long c.dataReceived(length) c.dataReceived(encrypted[:-2]) self.assertEqual(inbound_records, []) @@ -1159,7 +1246,7 @@ class Connection(unittest.TestCase): # connectConsumer() takes an optional number of bytes to expect, and # fires a Deferred when that many have been written c = transit.Connection(None, None, None, "description") - c._negotiation_d.addErrback(lambda err: None) # eat it + c._negotiation_d.addErrback(lambda err: None) # eat it c.transport = proto_helpers.StringTransport() c.recordReceived(b"r1.") @@ -1197,7 +1284,7 @@ class Connection(unittest.TestCase): # zero-length file), make sure it gets woken up right away, so it can # disconnect itself, even though no bytes will actually arrive c = transit.Connection(None, None, None, "description") - c._negotiation_d.addErrback(lambda err: None) # eat it + c._negotiation_d.addErrback(lambda err: None) # eat it c.transport = proto_helpers.StringTransport() consumer = proto_helpers.StringTransport() @@ -1208,7 +1295,7 @@ class Connection(unittest.TestCase): def test_writeToFile(self): c = transit.Connection(None, None, None, "description") - c._negotiation_d.addErrback(lambda err: None) # eat it + c._negotiation_d.addErrback(lambda err: None) # eat it c.transport = proto_helpers.StringTransport() c.recordReceived(b"r1.") @@ -1242,11 +1329,11 @@ class Connection(unittest.TestCase): self.assertEqual(progress, [3, 3, 3, 1]) # test what happens when enough data is queued ahead of time - c.recordReceived(b"second.") # now "overflow.second." - c.recordReceived(b"third.") # now "overflow.second.third." + c.recordReceived(b"second.") # now "overflow.second." + c.recordReceived(b"third.") # now "overflow.second.third." f = io.BytesIO() d = c.writeToFile(f, 10) - self.assertEqual(f.getvalue(), b"overflow.second.") # whole records + self.assertEqual(f.getvalue(), b"overflow.second.") # whole records self.assertEqual(self.successResultOf(d), 16) self.assertEqual(list(c._inbound_records), [b"third."]) @@ -1273,6 +1360,7 @@ class Connection(unittest.TestCase): c.unregisterProducer() self.assertEqual(c.transport.producer, None) + class FileConsumer(unittest.TestCase): def test_basic(self): f = io.BytesIO() @@ -1280,12 +1368,12 @@ class FileConsumer(unittest.TestCase): fc = transit.FileConsumer(f, progress.append) self.assertEqual(progress, []) self.assertEqual(f.getvalue(), b"") - fc.write(b"."* 99) + fc.write(b"." * 99) self.assertEqual(progress, [99]) - self.assertEqual(f.getvalue(), b"."*99) + self.assertEqual(f.getvalue(), b"." * 99) fc.write(b"!") self.assertEqual(progress, [99, 1]) - self.assertEqual(f.getvalue(), b"."*99+b"!") + self.assertEqual(f.getvalue(), b"." * 99 + b"!") def test_hasher(self): hashee = [] @@ -1295,32 +1383,53 @@ class FileConsumer(unittest.TestCase): self.assertEqual(progress, []) self.assertEqual(f.getvalue(), b"") self.assertEqual(hashee, []) - fc.write(b"."* 99) + fc.write(b"." * 99) self.assertEqual(progress, [99]) - self.assertEqual(f.getvalue(), b"."*99) - self.assertEqual(hashee, [b"."*99]) + self.assertEqual(f.getvalue(), b"." * 99) + self.assertEqual(hashee, [b"." * 99]) fc.write(b"!") self.assertEqual(progress, [99, 1]) - self.assertEqual(f.getvalue(), b"."*99+b"!") - self.assertEqual(hashee, [b"."*99, b"!"]) + self.assertEqual(f.getvalue(), b"." * 99 + b"!") + self.assertEqual(hashee, [b"." * 99, b"!"]) -DIRECT_HINT_JSON = {"type": "direct-tcp-v1", - "hostname": "direct", "port": 1234} -RELAY_HINT_JSON = {"type": "relay-v1", - "hints": [{"type": "direct-tcp-v1", - "hostname": "relay", "port": 1234}]} -UNRECOGNIZED_DIRECT_HINT_JSON = {"type": "direct-tcp-v1", - "hostname": ["cannot", "parse", "list"]} +DIRECT_HINT_JSON = { + "type": "direct-tcp-v1", + "hostname": "direct", + "port": 1234 +} +RELAY_HINT_JSON = { + "type": "relay-v1", + "hints": [{ + "type": "direct-tcp-v1", + "hostname": "relay", + "port": 1234 + }] +} +UNRECOGNIZED_DIRECT_HINT_JSON = { + "type": "direct-tcp-v1", + "hostname": ["cannot", "parse", "list"] +} UNRECOGNIZED_HINT_JSON = {"type": "unknown"} -UNAVAILABLE_HINT_JSON = {"type": "direct-tcp-v1", # e.g. Tor without txtorcon - "hostname": "unavailable", "port": 1234} -RELAY_HINT2_JSON = {"type": "relay-v1", - "hints": [{"type": "direct-tcp-v1", - "hostname": "relay", "port": 1234}, - UNRECOGNIZED_HINT_JSON]} -UNAVAILABLE_RELAY_HINT_JSON = {"type": "relay-v1", - "hints": [UNAVAILABLE_HINT_JSON]} +UNAVAILABLE_HINT_JSON = { + "type": "direct-tcp-v1", # e.g. Tor without txtorcon + "hostname": "unavailable", + "port": 1234 +} +RELAY_HINT2_JSON = { + "type": + "relay-v1", + "hints": [{ + "type": "direct-tcp-v1", + "hostname": "relay", + "port": 1234 + }, UNRECOGNIZED_HINT_JSON] +} +UNAVAILABLE_RELAY_HINT_JSON = { + "type": "relay-v1", + "hints": [UNAVAILABLE_HINT_JSON] +} + class Transit(unittest.TestCase): def setUp(self): @@ -1340,11 +1449,12 @@ class Transit(unittest.TestCase): clock = task.Clock() s = transit.TransitSender("", reactor=clock) s.set_transit_key(b"key") - hints = yield s.get_connection_hints() # start the listener + hints = yield s.get_connection_hints() # start the listener del hints - s.add_connection_hints([DIRECT_HINT_JSON, - UNRECOGNIZED_DIRECT_HINT_JSON, - UNRECOGNIZED_HINT_JSON]) + s.add_connection_hints([ + DIRECT_HINT_JSON, UNRECOGNIZED_DIRECT_HINT_JSON, + UNRECOGNIZED_HINT_JSON + ]) s._start_connector = self._start_connector d = s.connect() @@ -1361,7 +1471,7 @@ class Transit(unittest.TestCase): clock = task.Clock() s = transit.TransitSender("", tor=mock.Mock(), reactor=clock) s.set_transit_key(b"key") - hints = yield s.get_connection_hints() # start the listener + hints = yield s.get_connection_hints() # start the listener del hints s.add_connection_hints([DIRECT_HINT_JSON]) @@ -1380,7 +1490,7 @@ class Transit(unittest.TestCase): clock = task.Clock() s = transit.TransitSender("", tor=mock.Mock(), reactor=clock) s.set_transit_key(b"key") - hints = yield s.get_connection_hints() # start the listener + hints = yield s.get_connection_hints() # start the listener del hints s.add_connection_hints([RELAY_HINT_JSON]) @@ -1411,9 +1521,8 @@ class Transit(unittest.TestCase): s.set_transit_key(b"key") hints = yield s.get_connection_hints() del hints - s.add_connection_hints([DIRECT_HINT_JSON, - UNRECOGNIZED_HINT_JSON, - RELAY_HINT_JSON]) + s.add_connection_hints( + [DIRECT_HINT_JSON, UNRECOGNIZED_HINT_JSON, RELAY_HINT_JSON]) s._endpoint_from_hint_obj = self._endpoint_from_hint_obj s._start_connector = self._start_connector @@ -1437,20 +1546,46 @@ class Transit(unittest.TestCase): hints = yield s.get_connection_hints() del hints s.add_connection_hints([ - {"type": "relay-v1", - "hints": [{"type": "direct-tcp-v1", - "hostname": "relay", "port": 1234}]}, - {"type": "direct-tcp-v1", - "hostname": "direct", "port": 1234}, - {"type": "relay-v1", - "hints": [{"type": "direct-tcp-v1", "priority": 2.0, - "hostname": "relay2", "port": 1234}, - {"type": "direct-tcp-v1", "priority": 3.0, - "hostname": "relay3", "port": 1234}]}, - {"type": "relay-v1", - "hints": [{"type": "direct-tcp-v1", "priority": 2.0, - "hostname": "relay4", "port": 1234}]}, - ]) + { + "type": + "relay-v1", + "hints": [{ + "type": "direct-tcp-v1", + "hostname": "relay", + "port": 1234 + }] + }, + { + "type": "direct-tcp-v1", + "hostname": "direct", + "port": 1234 + }, + { + "type": + "relay-v1", + "hints": [{ + "type": "direct-tcp-v1", + "priority": 2.0, + "hostname": "relay2", + "port": 1234 + }, { + "type": "direct-tcp-v1", + "priority": 3.0, + "hostname": "relay3", + "port": 1234 + }] + }, + { + "type": + "relay-v1", + "hints": [{ + "type": "direct-tcp-v1", + "priority": 2.0, + "hostname": "relay4", + "port": 1234 + }] + }, + ]) s._endpoint_from_hint_obj = self._endpoint_from_hint_obj s._start_connector = self._start_connector @@ -1482,13 +1617,13 @@ class Transit(unittest.TestCase): clock = task.Clock() s = transit.TransitSender("", reactor=clock, no_listen=True) s.set_transit_key(b"key") - hints = yield s.get_connection_hints() # start the listener + hints = yield s.get_connection_hints() # start the listener del hints # include hints that can't be turned into an endpoint at runtime - s.add_connection_hints([UNRECOGNIZED_HINT_JSON, - UNAVAILABLE_HINT_JSON, - RELAY_HINT2_JSON, - UNAVAILABLE_RELAY_HINT_JSON]) + s.add_connection_hints([ + UNRECOGNIZED_HINT_JSON, UNAVAILABLE_HINT_JSON, RELAY_HINT2_JSON, + UNAVAILABLE_RELAY_HINT_JSON + ]) s._endpoint_from_hint_obj = self._endpoint_from_hint_obj s._start_connector = self._start_connector @@ -1509,9 +1644,9 @@ class Transit(unittest.TestCase): clock = task.Clock() s = transit.TransitSender("", reactor=clock, no_listen=True) s.set_transit_key(b"key") - hints = yield s.get_connection_hints() # start the listener + hints = yield s.get_connection_hints() # start the listener del hints - s.add_connection_hints([]) # no hints at all + s.add_connection_hints([]) # no hints at all s._endpoint_from_hint_obj = self._endpoint_from_hint_obj s._start_connector = self._start_connector @@ -1519,10 +1654,11 @@ class Transit(unittest.TestCase): f = self.failureResultOf(d, transit.TransitError) self.assertEqual(str(f.value), "No contenders for connection") + class RelayHandshake(unittest.TestCase): def old_build_relay_handshake(self, key): token = transit.HKDF(key, 32, CTXinfo=b"transit_relay_token") - return (token, b"please relay "+hexlify(token)+b"\n") + return (token, b"please relay " + hexlify(token) + b"\n") def test_old(self): key = b"\x00" @@ -1548,8 +1684,9 @@ class RelayHandshake(unittest.TestCase): tc.dataReceived(new_handshake[:-1]) self.assertEqual(tc.factory.connection_got_token.mock_calls, []) tc.dataReceived(new_handshake[-1:]) - self.assertEqual(tc.factory.connection_got_token.mock_calls, - [mock.call(hexlify(token), c._side.encode("ascii"), tc)]) + self.assertEqual( + tc.factory.connection_got_token.mock_calls, + [mock.call(hexlify(token), c._side.encode("ascii"), tc)]) class Full(ServerBase, unittest.TestCase): @@ -1558,7 +1695,7 @@ class Full(ServerBase, unittest.TestCase): @inlineCallbacks def test_direct(self): - KEY = b"k"*32 + KEY = b"k" * 32 s = transit.TransitSender(None) r = transit.TransitReceiver(None) @@ -1571,7 +1708,7 @@ class Full(ServerBase, unittest.TestCase): s.add_connection_hints(rhints) r.add_connection_hints(shints) - (x,y) = yield self.doBoth(s.connect(), r.connect()) + (x, y) = yield self.doBoth(s.connect(), r.connect()) self.assertIsInstance(x, transit.Connection) self.assertIsInstance(y, transit.Connection) @@ -1586,7 +1723,7 @@ class Full(ServerBase, unittest.TestCase): @inlineCallbacks def test_relay(self): - KEY = b"k"*32 + KEY = b"k" * 32 s = transit.TransitSender(self.transit, no_listen=True) r = transit.TransitReceiver(self.transit, no_listen=True) @@ -1599,7 +1736,7 @@ class Full(ServerBase, unittest.TestCase): s.add_connection_hints(rhints) r.add_connection_hints(shints) - (x,y) = yield self.doBoth(s.connect(), r.connect()) + (x, y) = yield self.doBoth(s.connect(), r.connect()) self.assertIsInstance(x, transit.Connection) self.assertIsInstance(y, transit.Connection) diff --git a/src/wormhole/test/test_util.py b/src/wormhole/test/test_util.py index a939b91..e3c7d96 100644 --- a/src/wormhole/test/test_util.py +++ b/src/wormhole/test/test_util.py @@ -1,10 +1,15 @@ from __future__ import unicode_literals -import six -import mock + import unicodedata + +import six from twisted.trial import unittest + +import mock + from .. import util + class Utils(unittest.TestCase): def test_to_bytes(self): b = util.to_bytes("abc") @@ -41,11 +46,12 @@ class Utils(unittest.TestCase): self.assertIsInstance(d, dict) self.assertEqual(d, {"a": "b", "c": 2}) + class Space(unittest.TestCase): def test_free_space(self): free = util.estimate_free_space(".") - self.assert_(isinstance(free, six.integer_types + (type(None),)), - repr(free)) + self.assert_( + isinstance(free, six.integer_types + (type(None), )), repr(free)) # some platforms (I think the VMs used by travis are in this # category) return 0, and windows will return None, so don't assert # anything more specific about the return value @@ -56,5 +62,5 @@ class Space(unittest.TestCase): try: with mock.patch("os.statvfs", side_effect=AttributeError()): self.assertEqual(util.estimate_free_space("."), None) - except AttributeError: # raised by mock.get_original() + except AttributeError: # raised by mock.get_original() pass diff --git a/src/wormhole/test/test_wordlist.py b/src/wormhole/test/test_wordlist.py index 6b86cdb..06eb13d 100644 --- a/src/wormhole/test/test_wordlist.py +++ b/src/wormhole/test/test_wordlist.py @@ -1,8 +1,12 @@ from __future__ import print_function, unicode_literals -import mock + from twisted.trial import unittest + +import mock + from .._wordlist import PGPWordList + class Completions(unittest.TestCase): def test_completions(self): wl = PGPWordList() @@ -14,16 +18,21 @@ class Completions(unittest.TestCase): self.assertEqual(len(lots), 256, lots) first = list(lots)[0] self.assert_(first.startswith("armistice-"), first) - self.assertEqual(gc("armistice-ba", 2), - {"armistice-baboon", "armistice-backfield", - "armistice-backward", "armistice-banjo"}) - self.assertEqual(gc("armistice-ba", 3), - {"armistice-baboon-", "armistice-backfield-", - "armistice-backward-", "armistice-banjo-"}) + self.assertEqual( + gc("armistice-ba", 2), { + "armistice-baboon", "armistice-backfield", + "armistice-backward", "armistice-banjo" + }) + self.assertEqual( + gc("armistice-ba", 3), { + "armistice-baboon-", "armistice-backfield-", + "armistice-backward-", "armistice-banjo-" + }) self.assertEqual(gc("armistice-baboon", 2), {"armistice-baboon"}) self.assertEqual(gc("armistice-baboon", 3), {"armistice-baboon-"}) self.assertEqual(gc("armistice-baboon", 4), {"armistice-baboon-"}) + class Choose(unittest.TestCase): def test_choose_words(self): wl = PGPWordList() diff --git a/src/wormhole/test/test_wormhole.py b/src/wormhole/test/test_wormhole.py index df5830f..0362c65 100644 --- a/src/wormhole/test/test_wormhole.py +++ b/src/wormhole/test/test_wormhole.py @@ -1,17 +1,22 @@ from __future__ import print_function, unicode_literals -import io, re -import mock -from twisted.trial import unittest + +import io +import re + from twisted.internet import reactor from twisted.internet.defer import gatherResults, inlineCallbacks, returnValue from twisted.internet.error import ConnectionRefusedError -from .common import ServerBase, poll_until -from .. import wormhole, _rendezvous -from ..errors import (WrongPasswordError, ServerConnectionError, - KeyFormatError, WormholeClosed, LonelyError, - NoKeyError, OnlyOneCodeError) -from ..transit import allocate_tcp_port +from twisted.trial import unittest + +import mock + +from .. import _rendezvous, wormhole +from ..errors import (KeyFormatError, LonelyError, NoKeyError, + OnlyOneCodeError, ServerConnectionError, WormholeClosed, + WrongPasswordError) from ..eventual import EventualQueue +from ..transit import allocate_tcp_port +from .common import ServerBase, poll_until APPID = "appid" @@ -26,6 +31,7 @@ APPID = "appid" # * set_code, then connected # * connected, receive_pake, send_phase, set_code + class Delegate: def __init__(self): self.welcome = None @@ -35,28 +41,35 @@ class Delegate: self.versions = None self.messages = [] self.closed = None + def wormhole_got_welcome(self, welcome): self.welcome = welcome + def wormhole_got_code(self, code): self.code = code + def wormhole_got_unverified_key(self, key): self.key = key + def wormhole_got_verifier(self, verifier): self.verifier = verifier + def wormhole_got_versions(self, versions): self.versions = versions + def wormhole_got_message(self, data): self.messages.append(data) + def wormhole_closed(self, result): self.closed = result -class Delegated(ServerBase, unittest.TestCase): +class Delegated(ServerBase, unittest.TestCase): @inlineCallbacks def test_delegated(self): dg = Delegate() w1 = wormhole.create(APPID, self.relayurl, reactor, delegate=dg) - #w1.debug_set_trace("W1") + # w1.debug_set_trace("W1") with self.assertRaises(NoKeyError): w1.derive_key("purpose", 12) w1.set_code("1-abc") @@ -103,6 +116,7 @@ class Delegated(ServerBase, unittest.TestCase): yield poll_until(lambda: dg.code is not None) w1.close() + class Wormholes(ServerBase, unittest.TestCase): # integration test, with a real server @@ -135,12 +149,12 @@ class Wormholes(ServerBase, unittest.TestCase): @inlineCallbacks def test_basic(self): w1 = wormhole.create(APPID, self.relayurl, reactor) - #w1.debug_set_trace("W1") + # w1.debug_set_trace("W1") with self.assertRaises(NoKeyError): w1.derive_key("purpose", 12) w2 = wormhole.create(APPID, self.relayurl, reactor) - #w2.debug_set_trace(" W2") + # w2.debug_set_trace(" W2") w1.allocate_code() code = yield w1.get_code() w2.set_code(code) @@ -302,7 +316,6 @@ class Wormholes(ServerBase, unittest.TestCase): yield w1.close() yield w2.close() - @inlineCallbacks def test_multiple_messages(self): w1 = wormhole.create(APPID, self.relayurl, reactor) @@ -322,7 +335,6 @@ class Wormholes(ServerBase, unittest.TestCase): yield w1.close() yield w2.close() - @inlineCallbacks def test_closed(self): eq = EventualQueue(reactor) @@ -377,7 +389,7 @@ class Wormholes(ServerBase, unittest.TestCase): w2 = wormhole.create(APPID, self.relayurl, reactor, _eventual_queue=eq) w1.allocate_code() code = yield w1.get_code() - w2.set_code(code+"not") + w2.set_code(code + "not") code2 = yield w2.get_code() self.assertNotEqual(code, code2) # That's enough to allow both sides to discover the mismatch, but @@ -387,9 +399,9 @@ class Wormholes(ServerBase, unittest.TestCase): w1.send_message(b"should still work") w2.send_message(b"should still work") - key2 = yield w2.get_unverified_key() # should work + key2 = yield w2.get_unverified_key() # should work # w2 has just received w1.PAKE, and is about to send w2.VERSION - key1 = yield w1.get_unverified_key() # should work + key1 = yield w1.get_unverified_key() # should work # w1 has just received w2.PAKE, and is about to send w1.VERSION, and # then will receive w2.VERSION. When it sees w2.VERSION, it will # learn about the WrongPasswordError. @@ -451,7 +463,7 @@ class Wormholes(ServerBase, unittest.TestCase): badcode = "4 oops spaces" with self.assertRaises(KeyFormatError) as ex: w.set_code(badcode) - expected_msg = "Code '%s' contains spaces." % (badcode,) + expected_msg = "Code '%s' contains spaces." % (badcode, ) self.assertEqual(expected_msg, str(ex.exception)) yield self.assertFailure(w.close(), LonelyError) @@ -461,7 +473,7 @@ class Wormholes(ServerBase, unittest.TestCase): badcode = " 4-oops-space" with self.assertRaises(KeyFormatError) as ex: w.set_code(badcode) - expected_msg = "Code '%s' contains spaces." % (badcode,) + expected_msg = "Code '%s' contains spaces." % (badcode, ) self.assertEqual(expected_msg, str(ex.exception)) yield self.assertFailure(w.close(), LonelyError) @@ -478,8 +490,8 @@ class Wormholes(ServerBase, unittest.TestCase): @inlineCallbacks def test_welcome(self): w1 = wormhole.create(APPID, self.relayurl, reactor) - wel1 = yield w1.get_welcome() # early: before connection established - wel2 = yield w1.get_welcome() # late: already received welcome + wel1 = yield w1.get_welcome() # early: before connection established + wel2 = yield w1.get_welcome() # late: already received welcome self.assertEqual(wel1, wel2) self.assertIn("current_cli_version", wel1) @@ -489,7 +501,7 @@ class Wormholes(ServerBase, unittest.TestCase): w2.set_code("123-NOT") yield self.assertFailure(w1.get_verifier(), WrongPasswordError) - yield self.assertFailure(w1.get_welcome(), WrongPasswordError) # late + yield self.assertFailure(w1.get_welcome(), WrongPasswordError) # late yield self.assertFailure(w1.close(), WrongPasswordError) yield self.assertFailure(w2.close(), WrongPasswordError) @@ -502,7 +514,7 @@ class Wormholes(ServerBase, unittest.TestCase): w1.allocate_code() code = yield w1.get_code() w2.set_code(code) - v1 = yield w1.get_verifier() # early + v1 = yield w1.get_verifier() # early v2 = yield w2.get_verifier() self.failUnlessEqual(type(v1), type(b"")) self.failUnlessEqual(v1, v2) @@ -525,10 +537,10 @@ class Wormholes(ServerBase, unittest.TestCase): @inlineCallbacks def test_versions(self): # there's no API for this yet, but make sure the internals work - w1 = wormhole.create(APPID, self.relayurl, reactor, - versions={"w1": 123}) - w2 = wormhole.create(APPID, self.relayurl, reactor, - versions={"w2": 456}) + w1 = wormhole.create( + APPID, self.relayurl, reactor, versions={"w1": 123}) + w2 = wormhole.create( + APPID, self.relayurl, reactor, versions={"w2": 456}) w1.allocate_code() code = yield w1.get_code() w2.set_code(code) @@ -564,17 +576,19 @@ class Wormholes(ServerBase, unittest.TestCase): yield w1.close() yield w2.close() + class MessageDoubler(_rendezvous.RendezvousConnector): # we could double messages on the sending side, but a future server will # strip those duplicates, so to really exercise the receiver, we must # double them on the inbound side instead - #def _msg_send(self, phase, body): - # wormhole._Wormhole._msg_send(self, phase, body) - # self._ws_send_command("add", phase=phase, body=bytes_to_hexstr(body)) + # def _msg_send(self, phase, body): + # wormhole._Wormhole._msg_send(self, phase, body) + # self._ws_send_command("add", phase=phase, body=bytes_to_hexstr(body)) def _response_handle_message(self, msg): _rendezvous.RendezvousConnector._response_handle_message(self, msg) _rendezvous.RendezvousConnector._response_handle_message(self, msg) + class Errors(ServerBase, unittest.TestCase): @inlineCallbacks def test_derive_key_early(self): @@ -602,16 +616,17 @@ class Errors(ServerBase, unittest.TestCase): w.set_code("123-nope") yield self.assertFailure(w.close(), LonelyError) + class Reconnection(ServerBase, unittest.TestCase): @inlineCallbacks def test_basic(self): w1 = wormhole.create(APPID, self.relayurl, reactor) w1_in = [] w1._boss._RC._debug_record_inbound_f = w1_in.append - #w1.debug_set_trace("W1") + # w1.debug_set_trace("W1") w1.allocate_code() code = yield w1.get_code() - w1.send_message(b"data1") # queued until wormhole is established + w1.send_message(b"data1") # queued until wormhole is established # now wait until we've deposited all our messages on the server def seen_our_pake(): @@ -619,6 +634,7 @@ class Reconnection(ServerBase, unittest.TestCase): if m["type"] == "message" and m["phase"] == "pake": return True return False + yield poll_until(seen_our_pake) w1_in[:] = [] @@ -634,7 +650,7 @@ class Reconnection(ServerBase, unittest.TestCase): # receiver has started w2 = wormhole.create(APPID, self.relayurl, reactor) - #w2.debug_set_trace(" W2") + # w2.debug_set_trace(" W2") w2.set_code(code) dataY = yield w2.get_message() @@ -649,6 +665,7 @@ class Reconnection(ServerBase, unittest.TestCase): c2 = yield w2.close() self.assertEqual(c2, "happy") + class InitialFailure(unittest.TestCase): @inlineCallbacks def assertSCEFailure(self, eq, d, innerType): @@ -662,8 +679,8 @@ class InitialFailure(unittest.TestCase): def test_bad_dns(self): eq = EventualQueue(reactor) # point at a URL that will never connect - w = wormhole.create(APPID, "ws://%%%.example.org:4000/v1", - reactor, _eventual_queue=eq) + w = wormhole.create( + APPID, "ws://%%%.example.org:4000/v1", reactor, _eventual_queue=eq) # that should have already received an error, when it tried to # resolve the bogus DNS name. All API calls will return an error. @@ -717,6 +734,7 @@ class InitialFailure(unittest.TestCase): yield self.assertSCE(d4, ConnectionRefusedError) yield self.assertSCE(d5, ConnectionRefusedError) + class Trace(unittest.TestCase): def test_basic(self): w1 = wormhole.create(APPID, "ws://localhost:1", reactor) @@ -734,13 +752,11 @@ class Trace(unittest.TestCase): ["C1.M1[OLD].IN -> [NEW]"]) out("OUT1") self.assertEqual(stderr.getvalue().splitlines(), - ["C1.M1[OLD].IN -> [NEW]", - " C1.M1.OUT1()"]) + ["C1.M1[OLD].IN -> [NEW]", " C1.M1.OUT1()"]) w1._boss._print_trace("", "R.connected", "", "C1", "RC1", stderr) - self.assertEqual(stderr.getvalue().splitlines(), - ["C1.M1[OLD].IN -> [NEW]", - " C1.M1.OUT1()", - "C1.RC1.R.connected"]) + self.assertEqual( + stderr.getvalue().splitlines(), + ["C1.M1[OLD].IN -> [NEW]", " C1.M1.OUT1()", "C1.RC1.R.connected"]) def test_delegated(self): dg = Delegate() diff --git a/src/wormhole/test/test_xfer_util.py b/src/wormhole/test/test_xfer_util.py index 1c08f3a..673910a 100644 --- a/src/wormhole/test/test_xfer_util.py +++ b/src/wormhole/test/test_xfer_util.py @@ -1,11 +1,13 @@ -from twisted.trial import unittest -from twisted.internet import reactor, defer +from twisted.internet import defer, reactor from twisted.internet.defer import inlineCallbacks +from twisted.trial import unittest + from .. import xfer_util from .common import ServerBase APPID = u"appid" + class Xfer(ServerBase, unittest.TestCase): @inlineCallbacks def test_xfer(self): @@ -24,10 +26,15 @@ class Xfer(ServerBase, unittest.TestCase): data = u"data" send_code = [] receive_code = [] - d1 = xfer_util.send(reactor, APPID, self.relayurl, data, code, - on_code=send_code.append) - d2 = xfer_util.receive(reactor, APPID, self.relayurl, code, - on_code=receive_code.append) + d1 = xfer_util.send( + reactor, + APPID, + self.relayurl, + data, + code, + on_code=send_code.append) + d2 = xfer_util.receive( + reactor, APPID, self.relayurl, code, on_code=receive_code.append) send_result = yield d1 receive_result = yield d2 self.assertEqual(send_code, [code]) @@ -39,8 +46,13 @@ class Xfer(ServerBase, unittest.TestCase): def test_make_code(self): data = u"data" got_code = defer.Deferred() - d1 = xfer_util.send(reactor, APPID, self.relayurl, data, code=None, - on_code=got_code.callback) + d1 = xfer_util.send( + reactor, + APPID, + self.relayurl, + data, + code=None, + on_code=got_code.callback) code = yield got_code d2 = xfer_util.receive(reactor, APPID, self.relayurl, code) send_result = yield d1 diff --git a/src/wormhole/timing.py b/src/wormhole/timing.py index 8cb18e5..ecb29be 100644 --- a/src/wormhole/timing.py +++ b/src/wormhole/timing.py @@ -1,8 +1,13 @@ -from __future__ import print_function, absolute_import, unicode_literals -import json, time +from __future__ import absolute_import, print_function, unicode_literals + +import json +import time + from zope.interface import implementer + from ._interfaces import ITiming + class Event: def __init__(self, name, when, **details): # data fields that will be dumped to JSON later @@ -35,6 +40,7 @@ class Event: else: self.finish() + @implementer(ITiming) class DebugTiming: def __init__(self): @@ -47,11 +53,14 @@ class DebugTiming: def write(self, fn, stderr): with open(fn, "wt") as f: - data = [ dict(name=e._name, - start=e._start, stop=e._stop, - details=e._details, - ) - for e in self._events ] + data = [ + dict( + name=e._name, + start=e._start, + stop=e._stop, + details=e._details, + ) for e in self._events + ] json.dump(data, f, indent=1) f.write("\n") print("Timing data written to %s" % fn, file=stderr) diff --git a/src/wormhole/tor_manager.py b/src/wormhole/tor_manager.py index 0f8a449..eea12d6 100644 --- a/src/wormhole/tor_manager.py +++ b/src/wormhole/tor_manager.py @@ -1,15 +1,20 @@ from __future__ import print_function, unicode_literals + import sys -from attr import attrs, attrib -from zope.interface.declarations import directlyProvides + +from attr import attrib, attrs from twisted.internet.defer import inlineCallbacks, returnValue from twisted.internet.endpoints import clientFromString +from zope.interface.declarations import directlyProvides + +from . import _interfaces, errors +from .timing import DebugTiming + try: import txtorcon except ImportError: txtorcon = None -from . import _interfaces, errors -from .timing import DebugTiming + @attrs class SocksOnlyTor(object): @@ -17,15 +22,20 @@ class SocksOnlyTor(object): def stream_via(self, host, port, tls=False): return txtorcon.TorClientEndpoint( - host, port, - socks_endpoint=None, # tries localhost:9050 and 9150 + host, + port, + socks_endpoint=None, # tries localhost:9050 and 9150 tls=tls, reactor=self._reactor, ) + @inlineCallbacks -def get_tor(reactor, launch_tor=False, tor_control_port=None, - timing=None, stderr=sys.stderr): +def get_tor(reactor, + launch_tor=False, + tor_control_port=None, + timing=None, + stderr=sys.stderr): """ If launch_tor=True, I will try to launch a new Tor process, ask it for its SOCKS and control ports, and use those for outbound @@ -59,7 +69,7 @@ def get_tor(reactor, launch_tor=False, tor_control_port=None, if not txtorcon: raise errors.NoTorError() - if not isinstance(launch_tor, bool): # note: False is int + if not isinstance(launch_tor, bool): # note: False is int raise TypeError("launch_tor= must be boolean") if not isinstance(tor_control_port, (type(""), type(None))): raise TypeError("tor_control_port= must be str or None") @@ -74,19 +84,21 @@ def get_tor(reactor, launch_tor=False, tor_control_port=None, # need the control port. if launch_tor: - print(" launching a new Tor process, this may take a while..", - file=stderr) + print( + " launching a new Tor process, this may take a while..", + file=stderr) with timing.add("launch tor"): tor = yield txtorcon.launch(reactor, - #data_directory=, - #tor_binary=, + # data_directory=, + # tor_binary=, ) elif tor_control_port: with timing.add("find tor"): control_ep = clientFromString(reactor, tor_control_port) - tor = yield txtorcon.connect(reactor, control_ep) # might raise - print(" using Tor via control port at %s" % tor_control_port, - file=stderr) + tor = yield txtorcon.connect(reactor, control_ep) # might raise + print( + " using Tor via control port at %s" % tor_control_port, + file=stderr) else: # Let txtorcon look through a list of usual places. If that fails, # we'll arrange to attempt the default SOCKS port @@ -98,8 +110,9 @@ def get_tor(reactor, launch_tor=False, tor_control_port=None, # TODO: make this more specific. I think connect() is # likely to throw a reactor.connectTCP -type error, like # ConnectionFailed or ConnectionRefused or something - print(" unable to find default Tor control port, using SOCKS", - file=stderr) + print( + " unable to find default Tor control port, using SOCKS", + file=stderr) tor = SocksOnlyTor(reactor) directlyProvides(tor, _interfaces.ITorManager) returnValue(tor) diff --git a/src/wormhole/transit.py b/src/wormhole/transit.py index f8f824a..6363c6d 100644 --- a/src/wormhole/transit.py +++ b/src/wormhole/transit.py @@ -1,38 +1,51 @@ # no unicode_literals, revisit after twisted patch -from __future__ import print_function, absolute_import -import os, re, sys, time, socket -from collections import namedtuple, deque +from __future__ import absolute_import, print_function + +import os +import re +import socket +import sys +import time from binascii import hexlify, unhexlify +from collections import deque, namedtuple + import six -from zope.interface import implementer -from twisted.python import log -from twisted.python.runtime import platformType -from twisted.internet import (reactor, interfaces, defer, protocol, - endpoints, task, address, error) +from hkdf import Hkdf +from nacl.secret import SecretBox +from twisted.internet import (address, defer, endpoints, error, interfaces, + protocol, reactor, task) from twisted.internet.defer import inlineCallbacks, returnValue from twisted.protocols import policies -from nacl.secret import SecretBox -from hkdf import Hkdf +from twisted.python import log +from twisted.python.runtime import platformType +from zope.interface import implementer + +from . import ipaddrs from .errors import InternalError from .timing import DebugTiming from .util import bytes_to_hexstr -from . import ipaddrs + def HKDF(skm, outlen, salt=None, CTXinfo=b""): return Hkdf(salt, skm).expand(CTXinfo, outlen) + class TransitError(Exception): pass + class BadHandshake(Exception): pass + class TransitClosed(TransitError): pass + class BadNonce(TransitError): pass + # The beginning of each TCP connection consists of the following handshake # messages. The sender transmits the same text regardless of whether it is on # the initiating/connecting end of the TCP connection, or on the @@ -63,19 +76,23 @@ class BadNonce(TransitError): # RXID_HEX ready\n\n" and then makes a first/not-first decision about sending # "go\n" or "nevermind\n"+close(). + def build_receiver_handshake(key): hexid = HKDF(key, 32, CTXinfo=b"transit_receiver") - return b"transit receiver "+hexlify(hexid)+b" ready\n\n" + return b"transit receiver " + hexlify(hexid) + b" ready\n\n" + def build_sender_handshake(key): hexid = HKDF(key, 32, CTXinfo=b"transit_sender") - return b"transit sender "+hexlify(hexid)+b" ready\n\n" + return b"transit sender " + hexlify(hexid) + b" ready\n\n" + def build_sided_relay_handshake(key, side): assert isinstance(side, type(u"")) - assert len(side) == 8*2 + assert len(side) == 8 * 2 token = HKDF(key, 32, CTXinfo=b"transit_relay_token") - return b"please relay "+hexlify(token)+b" for side "+side.encode("ascii")+b"\n" + return b"please relay " + hexlify(token) + b" for side " + side.encode( + "ascii") + b"\n" # These namedtuples are "hint objects". The JSON-serializable dictionaries @@ -87,7 +104,8 @@ def build_sided_relay_handshake(key, side): # * expect to see the receiver/sender handshake bytes from the other side # * the sender writes "go\n", the receiver waits for "go\n" # * the rest of the connection contains transit data -DirectTCPV1Hint = namedtuple("DirectTCPV1Hint", ["hostname", "port", "priority"]) +DirectTCPV1Hint = namedtuple("DirectTCPV1Hint", + ["hostname", "port", "priority"]) TorTCPV1Hint = namedtuple("TorTCPV1Hint", ["hostname", "port", "priority"]) # RelayV1Hint contains a tuple of DirectTCPV1Hint and TorTCPV1Hint hints (we # use a tuple rather than a list so they'll be hashable into a set). For each @@ -95,6 +113,7 @@ TorTCPV1Hint = namedtuple("TorTCPV1Hint", ["hostname", "port", "priority"]) # rest of the V1 protocol. Only one hint per relay is useful. RelayV1Hint = namedtuple("RelayV1Hint", ["hints"]) + def describe_hint_obj(hint): if isinstance(hint, DirectTCPV1Hint): return u"tcp:%s:%d" % (hint.hostname, hint.port) @@ -103,27 +122,30 @@ def describe_hint_obj(hint): else: return str(hint) + def parse_hint_argv(hint, stderr=sys.stderr): assert isinstance(hint, type(u"")) # return tuple or None for an unparseable hint priority = 0.0 mo = re.search(r'^([a-zA-Z0-9]+):(.*)$', hint) if not mo: - print("unparseable hint '%s'" % (hint,), file=stderr) + print("unparseable hint '%s'" % (hint, ), file=stderr) return None hint_type = mo.group(1) if hint_type != "tcp": - print("unknown hint type '%s' in '%s'" % (hint_type, hint), file=stderr) + print( + "unknown hint type '%s' in '%s'" % (hint_type, hint), file=stderr) return None hint_value = mo.group(2) pieces = hint_value.split(":") if len(pieces) < 2: - print("unparseable TCP hint (need more colons) '%s'" % (hint,), - file=stderr) + print( + "unparseable TCP hint (need more colons) '%s'" % (hint, ), + file=stderr) return None mo = re.search(r'^(\d+)$', pieces[1]) if not mo: - print("non-numeric port in TCP hint '%s'" % (hint,), file=stderr) + print("non-numeric port in TCP hint '%s'" % (hint, ), file=stderr) return None hint_host = pieces[0] hint_port = int(pieces[1]) @@ -133,12 +155,15 @@ def parse_hint_argv(hint, stderr=sys.stderr): try: priority = float(more_pieces[1]) except ValueError: - print("non-float priority= in TCP hint '%s'" % (hint,), - file=stderr) + print( + "non-float priority= in TCP hint '%s'" % (hint, ), + file=stderr) return None return DirectTCPV1Hint(hint_host, hint_port, priority) -TIMEOUT = 60 # seconds + +TIMEOUT = 60 # seconds + @implementer(interfaces.IProducer, interfaces.IConsumer) class Connection(protocol.Protocol, policies.TimeoutMixin): @@ -159,7 +184,7 @@ class Connection(protocol.Protocol, policies.TimeoutMixin): self._waiting_reads = deque() def connectionMade(self): - self.setTimeout(TIMEOUT) # does timeoutConnection() when it expires + self.setTimeout(TIMEOUT) # does timeoutConnection() when it expires self.factory.connectionWasMade(self) def startNegotiation(self): @@ -168,11 +193,11 @@ class Connection(protocol.Protocol, policies.TimeoutMixin): self.state = "relay" else: self.state = "start" - self.dataReceived(b"") # cycle the state machine + self.dataReceived(b"") # cycle the state machine return self._negotiation_d def _cancel(self, d): - self.state = "hung up" # stop reacting to anything further + self.state = "hung up" # stop reacting to anything further self._error = defer.CancelledError() self.transport.loseConnection() # if connectionLost isn't called synchronously, then our @@ -181,7 +206,6 @@ class Connection(protocol.Protocol, policies.TimeoutMixin): if self._negotiation_d: self._negotiation_d = None - def dataReceived(self, data): try: self._dataReceived(data) @@ -198,7 +222,7 @@ class Connection(protocol.Protocol, policies.TimeoutMixin): if not self.buf.startswith(expected[:len(self.buf)]): raise BadHandshake("got %r want %r" % (self.buf, expected)) if len(self.buf) < len(expected): - return False # keep waiting + return False # keep waiting self.buf = self.buf[len(expected):] return True @@ -245,9 +269,9 @@ class Connection(protocol.Protocol, policies.TimeoutMixin): return self.dataReceivedRECORDS() if self.state == "hung up": return - if isinstance(self.state, Exception): # for tests + if isinstance(self.state, Exception): # for tests raise self.state - raise ValueError("internal error: unknown state %s" % (self.state,)) + raise ValueError("internal error: unknown state %s" % (self.state, )) def _negotiationSuccessful(self): self.state = "records" @@ -266,19 +290,20 @@ class Connection(protocol.Protocol, policies.TimeoutMixin): if len(self.buf) < 4: return length = int(hexlify(self.buf[:4]), 16) - if len(self.buf) < 4+length: + if len(self.buf) < 4 + length: return - encrypted, self.buf = self.buf[4:4+length], self.buf[4+length:] + encrypted, self.buf = self.buf[4:4 + length], self.buf[4 + length:] record = self._decrypt_record(encrypted) self.recordReceived(record) def _decrypt_record(self, encrypted): - nonce_buf = encrypted[:SecretBox.NONCE_SIZE] # assume it's prepended + nonce_buf = encrypted[:SecretBox.NONCE_SIZE] # assume it's prepended nonce = int(hexlify(nonce_buf), 16) if nonce != self.next_receive_nonce: - raise BadNonce("received out-of-order record: got %d, expected %d" - % (nonce, self.next_receive_nonce)) + raise BadNonce( + "received out-of-order record: got %d, expected %d" % + (nonce, self.next_receive_nonce)) self.next_receive_nonce += 1 record = self.receive_box.decrypt(encrypted) return record @@ -287,14 +312,15 @@ class Connection(protocol.Protocol, policies.TimeoutMixin): return self._description def send_record(self, record): - if not isinstance(record, type(b"")): raise InternalError + if not isinstance(record, type(b"")): + raise InternalError assert SecretBox.NONCE_SIZE == 24 - assert self.send_nonce < 2**(8*24) - assert len(record) < 2**(8*4) - nonce = unhexlify("%048x" % self.send_nonce) # big-endian + assert self.send_nonce < 2**(8 * 24) + assert len(record) < 2**(8 * 4) + nonce = unhexlify("%048x" % self.send_nonce) # big-endian self.send_nonce += 1 encrypted = self.send_box.encrypt(record, nonce) - length = unhexlify("%08x" % len(encrypted)) # always 4 bytes long + length = unhexlify("%08x" % len(encrypted)) # always 4 bytes long self.transport.write(length) self.transport.write(encrypted) @@ -353,8 +379,10 @@ class Connection(protocol.Protocol, policies.TimeoutMixin): def registerProducer(self, producer, streaming): assert interfaces.IConsumer.providedBy(self.transport) self.transport.registerProducer(producer, streaming) + def unregisterProducer(self): self.transport.unregisterProducer() + def write(self, data): self.send_record(data) @@ -362,8 +390,10 @@ class Connection(protocol.Protocol, policies.TimeoutMixin): # the transport. def stopProducing(self): self.transport.stopProducing() + def pauseProducing(self): self.transport.pauseProducing() + def resumeProducing(self): self.transport.resumeProducing() @@ -384,8 +414,8 @@ class Connection(protocol.Protocol, policies.TimeoutMixin): Deferred, and you must call disconnectConsumer() when you are done.""" if self._consumer: - raise RuntimeError("A consumer is already attached: %r" % - self._consumer) + raise RuntimeError( + "A consumer is already attached: %r" % self._consumer) # be aware of an ordering hazard: when we call the consumer's # .registerProducer method, they are likely to immediately call @@ -440,6 +470,7 @@ class Connection(protocol.Protocol, policies.TimeoutMixin): fc = FileConsumer(f, progress, hasher) return self.connectConsumer(fc, expected) + class OutboundConnectionFactory(protocol.ClientFactory): protocol = Connection @@ -478,7 +509,7 @@ class InboundConnectionFactory(protocol.ClientFactory): def _shutdown(self): for d in list(self._pending_connections): - d.cancel() # that fires _remove and _proto_failed + d.cancel() # that fires _remove and _proto_failed def _describePeer(self, addr): if isinstance(addr, address.HostnameAddress): @@ -511,6 +542,7 @@ class InboundConnectionFactory(protocol.ClientFactory): # ignore these two, let Twisted log everything else f.trap(BadHandshake, defer.CancelledError) + def allocate_tcp_port(): """Return an (integer) available TCP port on localhost. This briefly listens on the port in question, then closes it right away.""" @@ -527,6 +559,7 @@ def allocate_tcp_port(): s.close() return port + class _ThereCanBeOnlyOne: """Accept a list of contender Deferreds, and return a summary Deferred. When the first contender fires successfully, cancel the rest and fire the @@ -535,6 +568,7 @@ class _ThereCanBeOnlyOne: status_cb=? """ + def __init__(self, contenders): self._remaining = set(contenders) self._winner_d = defer.Deferred(self._cancel) @@ -581,26 +615,32 @@ class _ThereCanBeOnlyOne: else: self._winner_d.errback(self._first_failure) + def there_can_be_only_one(contenders): return _ThereCanBeOnlyOne(contenders).run() + class Common: RELAY_DELAY = 2.0 TRANSIT_KEY_LENGTH = SecretBox.KEY_SIZE - def __init__(self, transit_relay, no_listen=False, tor=None, - reactor=reactor, timing=None): - self._side = bytes_to_hexstr(os.urandom(8)) # unicode + def __init__(self, + transit_relay, + no_listen=False, + tor=None, + reactor=reactor, + timing=None): + self._side = bytes_to_hexstr(os.urandom(8)) # unicode if transit_relay: if not isinstance(transit_relay, type(u"")): raise InternalError # TODO: allow multiple hints for a single relay relay_hint = parse_hint_argv(transit_relay) - relay = RelayV1Hint(hints=(relay_hint,)) + relay = RelayV1Hint(hints=(relay_hint, )) self._transit_relays = [relay] else: self._transit_relays = [] - self._their_direct_hints = [] # hintobjs + self._their_direct_hints = [] # hintobjs self._our_relay_hints = set(self._transit_relays) self._tor = tor self._transit_key = None @@ -622,33 +662,42 @@ class Common: # some test hosts, including the appveyor VMs, *only* have # 127.0.0.1, and the tests will hang badly if we remove it. addresses = non_loopback_addresses - direct_hints = [DirectTCPV1Hint(six.u(addr), portnum, 0.0) - for addr in addresses] + direct_hints = [ + DirectTCPV1Hint(six.u(addr), portnum, 0.0) for addr in addresses + ] ep = endpoints.serverFromString(reactor, "tcp:%d" % portnum) return direct_hints, ep def get_connection_abilities(self): - return [{u"type": u"direct-tcp-v1"}, - {u"type": u"relay-v1"}, - ] + return [ + { + u"type": u"direct-tcp-v1" + }, + { + u"type": u"relay-v1" + }, + ] @inlineCallbacks def get_connection_hints(self): hints = [] direct_hints = yield self._get_direct_hints() for dh in direct_hints: - hints.append({u"type": u"direct-tcp-v1", - u"priority": dh.priority, - u"hostname": dh.hostname, - u"port": dh.port, # integer - }) + hints.append({ + u"type": u"direct-tcp-v1", + u"priority": dh.priority, + u"hostname": dh.hostname, + u"port": dh.port, # integer + }) for relay in self._transit_relays: rhint = {u"type": u"relay-v1", u"hints": []} for rh in relay.hints: - rhint[u"hints"].append({u"type": u"direct-tcp-v1", - u"priority": rh.priority, - u"hostname": rh.hostname, - u"port": rh.port}) + rhint[u"hints"].append({ + u"type": u"direct-tcp-v1", + u"priority": rh.priority, + u"hostname": rh.hostname, + u"port": rh.port + }) hints.append(rhint) returnValue(hints) @@ -665,24 +714,27 @@ class Common: # listener will win. self._my_direct_hints, self._listener = self._build_listener() - if self._listener is None: # don't listen + if self._listener is None: # don't listen self._listener_d = None - return defer.succeed(self._my_direct_hints) # empty + return defer.succeed(self._my_direct_hints) # empty # Start the server, so it will be running by the time anyone tries to # connect to the direct hints we return. f = InboundConnectionFactory(self) - self._listener_f = f # for tests # XX move to __init__ ? + self._listener_f = f # for tests # XX move to __init__ ? self._listener_d = f.whenDone() d = self._listener.listen(f) + def _listening(lp): # lp is an IListeningPort - #self._listener_port = lp # for tests + # self._listener_port = lp # for tests def _stop_listening(res): lp.stopListening() return res + self._listener_d.addBoth(_stop_listening) return self._my_direct_hints + d.addCallback(_listening) return d @@ -694,18 +746,18 @@ class Common: self._listener_d.addErrback(lambda f: None) self._listener_d.cancel() - def _parse_tcp_v1_hint(self, hint): # hint_struct -> hint_obj + def _parse_tcp_v1_hint(self, hint): # hint_struct -> hint_obj hint_type = hint.get(u"type", u"") if hint_type not in [u"direct-tcp-v1", u"tor-tcp-v1"]: - log.msg("unknown hint type: %r" % (hint,)) + log.msg("unknown hint type: %r" % (hint, )) return None - if not(u"hostname" in hint - and isinstance(hint[u"hostname"], type(u""))): - log.msg("invalid hostname in hint: %r" % (hint,)) + if not (u"hostname" in hint + and isinstance(hint[u"hostname"], type(u""))): + log.msg("invalid hostname in hint: %r" % (hint, )) return None - if not(u"port" in hint - and isinstance(hint[u"port"], six.integer_types)): - log.msg("invalid port in hint: %r" % (hint,)) + if not (u"port" in hint + and isinstance(hint[u"port"], six.integer_types)): + log.msg("invalid port in hint: %r" % (hint, )) return None priority = hint.get(u"priority", 0.0) if hint_type == u"direct-tcp-v1": @@ -714,12 +766,12 @@ class Common: return TorTCPV1Hint(hint[u"hostname"], hint[u"port"], priority) def add_connection_hints(self, hints): - for h in hints: # hint structs + for h in hints: # hint structs hint_type = h.get(u"type", u"") if hint_type in [u"direct-tcp-v1", u"tor-tcp-v1"]: dh = self._parse_tcp_v1_hint(h) if dh: - self._their_direct_hints.append(dh) # hint_obj + self._their_direct_hints.append(dh) # hint_obj elif hint_type == u"relay-v1": # TODO: each relay-v1 clause describes a different relay, # with a set of equally-valid ways to connect to it. Treat @@ -734,7 +786,7 @@ class Common: rh = RelayV1Hint(hints=tuple(sorted(relay_hints))) self._our_relay_hints.add(rh) else: - log.msg("unknown hint type: %r" % (h,)) + log.msg("unknown hint type: %r" % (h, )) def _send_this(self): assert self._transit_key @@ -748,25 +800,33 @@ class Common: if self.is_sender: return build_receiver_handshake(self._transit_key) else: - return build_sender_handshake(self._transit_key)# + b"go\n" + return build_sender_handshake(self._transit_key) # + b"go\n" def _sender_record_key(self): assert self._transit_key if self.is_sender: - return HKDF(self._transit_key, SecretBox.KEY_SIZE, - CTXinfo=b"transit_record_sender_key") + return HKDF( + self._transit_key, + SecretBox.KEY_SIZE, + CTXinfo=b"transit_record_sender_key") else: - return HKDF(self._transit_key, SecretBox.KEY_SIZE, - CTXinfo=b"transit_record_receiver_key") + return HKDF( + self._transit_key, + SecretBox.KEY_SIZE, + CTXinfo=b"transit_record_receiver_key") def _receiver_record_key(self): assert self._transit_key if self.is_sender: - return HKDF(self._transit_key, SecretBox.KEY_SIZE, - CTXinfo=b"transit_record_receiver_key") + return HKDF( + self._transit_key, + SecretBox.KEY_SIZE, + CTXinfo=b"transit_record_receiver_key") else: - return HKDF(self._transit_key, SecretBox.KEY_SIZE, - CTXinfo=b"transit_record_sender_key") + return HKDF( + self._transit_key, + SecretBox.KEY_SIZE, + CTXinfo=b"transit_record_sender_key") def set_transit_key(self, key): assert isinstance(key, type(b"")), type(key) @@ -848,9 +908,13 @@ class Common: description = "->relay:%s" % describe_hint_obj(hint_obj) if self._tor: description = "tor" + description - d = task.deferLater(self._reactor, relay_delay, - self._start_connector, ep, description, - is_relay=True) + d = task.deferLater( + self._reactor, + relay_delay, + self._start_connector, + ep, + description, + is_relay=True) contenders.append(d) relay_delay += self.RELAY_DELAY @@ -858,16 +922,18 @@ class Common: raise TransitError("No contenders for connection") winner = there_can_be_only_one(contenders) - return self._not_forever(2*TIMEOUT, winner) + return self._not_forever(2 * TIMEOUT, winner) def _not_forever(self, timeout, d): """If the timer fires first, cancel the deferred. If the deferred fires first, cancel the timer.""" t = self._reactor.callLater(timeout, d.cancel) + def _done(res): if t.active(): t.cancel() return res + d.addBoth(_done) return d @@ -896,8 +962,8 @@ class Common: return None return None if isinstance(hint, DirectTCPV1Hint): - return endpoints.HostnameEndpoint(self._reactor, - hint.hostname, hint.port) + return endpoints.HostnameEndpoint(self._reactor, hint.hostname, + hint.port) return None def connection_ready(self, p): @@ -915,9 +981,11 @@ class Common: self._winner = p return "go" + class TransitSender(Common): is_sender = True + class TransitReceiver(Common): is_sender = False @@ -926,6 +994,7 @@ class TransitReceiver(Common): # when done, and add a progress function that gets called with the length of # each write, and a hasher function that gets called with the data. + @implementer(interfaces.IConsumer) class FileConsumer: def __init__(self, f, progress=None, hasher=None): @@ -950,6 +1019,7 @@ class FileConsumer: assert self._producer self._producer = None + # the TransitSender/Receiver.connect() yields a Connection, on which you can # do send_record(), but what should the receive API be? set a callback for # inbound records? get a Deferred for the next record? The producer/consumer diff --git a/src/wormhole/util.py b/src/wormhole/util.py index b5e39fb..0b57c5e 100644 --- a/src/wormhole/util.py +++ b/src/wormhole/util.py @@ -1,30 +1,42 @@ # No unicode_literals -import os, json, unicodedata +import json +import os +import unicodedata from binascii import hexlify, unhexlify + def to_bytes(u): return unicodedata.normalize("NFC", u).encode("utf-8") + + def bytes_to_hexstr(b): assert isinstance(b, type(b"")) hexstr = hexlify(b).decode("ascii") assert isinstance(hexstr, type(u"")) return hexstr + + def hexstr_to_bytes(hexstr): assert isinstance(hexstr, type(u"")) b = unhexlify(hexstr.encode("ascii")) assert isinstance(b, type(b"")) return b + + def dict_to_bytes(d): assert isinstance(d, dict) b = json.dumps(d).encode("utf-8") assert isinstance(b, type(b"")) return b + + def bytes_to_dict(b): assert isinstance(b, type(b"")) d = json.loads(b.decode("utf-8")) assert isinstance(d, dict) return d + def estimate_free_space(target): # f_bfree is the blocks available to a root user. It might be more # accurate to use f_bavail (blocks available to non-root user), but we diff --git a/src/wormhole/wormhole.py b/src/wormhole/wormhole.py index 311e3fc..610069c 100644 --- a/src/wormhole/wormhole.py +++ b/src/wormhole/wormhole.py @@ -1,19 +1,25 @@ -from __future__ import print_function, absolute_import, unicode_literals -import os, sys -from attr import attrs, attrib -from zope.interface import implementer +from __future__ import absolute_import, print_function, unicode_literals + +import os +import sys + +from attr import attrib, attrs from twisted.python import failure -from . import __version__ -from ._interfaces import IWormhole, IDeferredWormhole -from .util import bytes_to_hexstr -from .eventual import EventualQueue -from .observer import OneShotObserver, SequenceObserver -from .timing import DebugTiming -from .journal import ImmediateJournal +from zope.interface import implementer + from ._boss import Boss +from ._interfaces import IDeferredWormhole, IWormhole from ._key import derive_key from .errors import NoKeyError, WormholeClosed -from .util import to_bytes +from .eventual import EventualQueue +from .journal import ImmediateJournal +from .observer import OneShotObserver, SequenceObserver +from .timing import DebugTiming +from .util import bytes_to_hexstr, to_bytes +from ._version import get_versions + +__version__ = get_versions()['version'] +del get_versions # We can provide different APIs to different apps: # * Deferreds @@ -36,6 +42,7 @@ from .util import to_bytes # wormhole(delegate=app, delegate_prefix="wormhole_", # delegate_args=(args, kwargs)) + @attrs @implementer(IWormhole) class _DelegatedWormhole(object): @@ -51,16 +58,18 @@ class _DelegatedWormhole(object): def allocate_code(self, code_length=2): self._boss.allocate_code(code_length) + def input_code(self): return self._boss.input_code() + def set_code(self, code): self._boss.set_code(code) - ## def serialize(self): - ## s = {"serialized_wormhole_version": 1, - ## "boss": self._boss.serialize(), - ## } - ## return s + # def serialize(self): + # s = {"serialized_wormhole_version": 1, + # "boss": self._boss.serialize(), + # } + # return s def send_message(self, plaintext): self._boss.send(plaintext) @@ -72,34 +81,45 @@ class _DelegatedWormhole(object): cannot be called until when_verifier() has fired, nor after close() was called. """ - if not isinstance(purpose, type("")): raise TypeError(type(purpose)) - if not self._key: raise NoKeyError() + if not isinstance(purpose, type("")): + raise TypeError(type(purpose)) + if not self._key: + raise NoKeyError() return derive_key(self._key, to_bytes(purpose), length) def close(self): self._boss.close() - def debug_set_trace(self, client_name, which="B N M S O K SK R RC L C T", + def debug_set_trace(self, + client_name, + which="B N M S O K SK R RC L C T", file=sys.stderr): self._boss._set_trace(client_name, which, file) # from below def got_welcome(self, welcome): self._delegate.wormhole_got_welcome(welcome) + def got_code(self, code): self._delegate.wormhole_got_code(code) + def got_key(self, key): self._delegate.wormhole_got_unverified_key(key) - self._key = key # for derive_key() + self._key = key # for derive_key() + def got_verifier(self, verifier): self._delegate.wormhole_got_verifier(verifier) + def got_versions(self, versions): self._delegate.wormhole_got_versions(versions) + def received(self, plaintext): self._delegate.wormhole_got_message(plaintext) + def closed(self, result): self._delegate.wormhole_closed(result) + @implementer(IWormhole, IDeferredWormhole) class _DeferredWormhole(object): def __init__(self, eq): @@ -142,8 +162,10 @@ class _DeferredWormhole(object): def allocate_code(self, code_length=2): self._boss.allocate_code(code_length) + def input_code(self): return self._boss.input_code() + def set_code(self, code): self._boss.set_code(code) @@ -159,20 +181,23 @@ class _DeferredWormhole(object): cannot be called until when_verified() has fired, nor after close() was called. """ - if not isinstance(purpose, type("")): raise TypeError(type(purpose)) - if not self._key: raise NoKeyError() + if not isinstance(purpose, type("")): + raise TypeError(type(purpose)) + if not self._key: + raise NoKeyError() return derive_key(self._key, to_bytes(purpose), length) def close(self): # fails with WormholeError unless we established a connection # (state=="happy"). Fails with WrongPasswordError (a subclass of # WormholeError) if state=="scary". - d = self._closed_observer.when_fired() # maybe Failure + d = self._closed_observer.when_fired() # maybe Failure if not self._closed: - self._boss.close() # only need to close if it wasn't already + self._boss.close() # only need to close if it wasn't already return d - def debug_set_trace(self, client_name, + def debug_set_trace(self, + client_name, which="B N M S O K SK R RC L A I C T", file=sys.stderr): self._boss._set_trace(client_name, which, file) @@ -180,14 +205,17 @@ class _DeferredWormhole(object): # from below def got_welcome(self, welcome): self._welcome_observer.fire_if_not_fired(welcome) + def got_code(self, code): self._code_observer.fire_if_not_fired(code) + def got_key(self, key): - self._key = key # for derive_key() + self._key = key # for derive_key() self._key_observer.fire_if_not_fired(key) def got_verifier(self, verifier): self._verifier_observer.fire_if_not_fired(verifier) + def got_versions(self, versions): self._version_observer.fire_if_not_fired(versions) @@ -196,7 +224,7 @@ class _DeferredWormhole(object): def closed(self, result): self._closed = True - #print("closed", result, type(result), file=sys.stderr) + # print("closed", result, type(result), file=sys.stderr) if isinstance(result, Exception): # everything pending gets an error, including close() f = failure.Failure(result) @@ -215,12 +243,17 @@ class _DeferredWormhole(object): self._received_observer.fire(f) -def create(appid, relay_url, reactor, # use keyword args for everything else - versions={}, - delegate=None, journal=None, tor=None, - timing=None, - stderr=sys.stderr, - _eventual_queue=None): +def create( + appid, + relay_url, + reactor, # use keyword args for everything else + versions={}, + delegate=None, + journal=None, + tor=None, + timing=None, + stderr=sys.stderr, + _eventual_queue=None): timing = timing or DebugTiming() side = bytes_to_hexstr(os.urandom(5)) journal = journal or ImmediateJournal() @@ -229,27 +262,28 @@ def create(appid, relay_url, reactor, # use keyword args for everything else w = _DelegatedWormhole(delegate) else: w = _DeferredWormhole(eq) - wormhole_versions = {} # will be used to indicate Wormhole capabilities - wormhole_versions["app_versions"] = versions # app-specific capabilities + wormhole_versions = {} # will be used to indicate Wormhole capabilities + wormhole_versions["app_versions"] = versions # app-specific capabilities v = __version__ if isinstance(v, type(b"")): v = v.decode("utf-8", errors="replace") client_version = ("python", v) b = Boss(w, side, relay_url, appid, wormhole_versions, client_version, - reactor, journal, tor, timing) + reactor, journal, tor, timing) w._set_boss(b) b.start() return w -## def from_serialized(serialized, reactor, delegate, -## journal=None, tor=None, -## timing=None, stderr=sys.stderr): -## assert serialized["serialized_wormhole_version"] == 1 -## timing = timing or DebugTiming() -## w = _DelegatedWormhole(delegate) -## # now unpack state machines, including the SPAKE2 in Key -## b = Boss.from_serialized(w, serialized["boss"], reactor, journal, timing) -## w._set_boss(b) -## b.start() # ?? -## raise NotImplemented -## # should the new Wormhole call got_code? only if it wasn't called before. + +# def from_serialized(serialized, reactor, delegate, +# journal=None, tor=None, +# timing=None, stderr=sys.stderr): +# assert serialized["serialized_wormhole_version"] == 1 +# timing = timing or DebugTiming() +# w = _DelegatedWormhole(delegate) +# # now unpack state machines, including the SPAKE2 in Key +# b = Boss.from_serialized(w, serialized["boss"], reactor, journal, timing) +# w._set_boss(b) +# b.start() # ?? +# raise NotImplemented +# # should the new Wormhole call got_code? only if it wasn't called before. diff --git a/src/wormhole/xfer_util.py b/src/wormhole/xfer_util.py index c62aa12..465c05d 100644 --- a/src/wormhole/xfer_util.py +++ b/src/wormhole/xfer_util.py @@ -1,12 +1,19 @@ import json + from twisted.internet.defer import inlineCallbacks, returnValue from . import wormhole from .tor_manager import get_tor + @inlineCallbacks -def receive(reactor, appid, relay_url, code, - use_tor=False, launch_tor=False, tor_control_port=None, +def receive(reactor, + appid, + relay_url, + code, + use_tor=False, + launch_tor=False, + tor_control_port=None, on_code=None): """ This is a convenience API which returns a Deferred that callbacks @@ -21,9 +28,11 @@ def receive(reactor, appid, relay_url, code, :param unicode code: a pre-existing code to use, or None - :param bool use_tor: True if we should use Tor, False to not use it (None for default) + :param bool use_tor: True if we should use Tor, False to not use it (None + for default) - :param on_code: if not None, this is called when we have a code (even if you passed in one explicitly) + :param on_code: if not None, this is called when we have a code (even if + you passed in one explicitly) :type on_code: single-argument callable """ tor = None @@ -48,27 +57,33 @@ def receive(reactor, appid, relay_url, code, data = json.loads(data.decode("utf-8")) offer = data.get('offer', None) if not offer: - raise Exception( - "Do not understand response: {}".format(data) - ) + raise Exception("Do not understand response: {}".format(data)) msg = None if 'message' in offer: msg = offer['message'] - wh.send_message(json.dumps({"answer": - {"message_ack": "ok"}}).encode("utf-8")) + wh.send_message( + json.dumps({ + "answer": { + "message_ack": "ok" + } + }).encode("utf-8")) else: - raise Exception( - "Unknown offer type: {}".format(offer.keys()) - ) + raise Exception("Unknown offer type: {}".format(offer.keys())) yield wh.close() returnValue(msg) @inlineCallbacks -def send(reactor, appid, relay_url, data, code, - use_tor=False, launch_tor=False, tor_control_port=None, +def send(reactor, + appid, + relay_url, + data, + code, + use_tor=False, + launch_tor=False, + tor_control_port=None, on_code=None): """ This is a convenience API which returns a Deferred that callbacks @@ -83,9 +98,12 @@ def send(reactor, appid, relay_url, data, code, :param unicode code: a pre-existing code to use, or None - :param bool use_tor: True if we should use Tor, False to not use it (None for default) + :param bool use_tor: True if we should use Tor, False to not use it (None + for default) + + :param on_code: if not None, this is called when we have a code (even if + you passed in one explicitly) - :param on_code: if not None, this is called when we have a code (even if you passed in one explicitly) :type on_code: single-argument callable """ tor = None @@ -104,13 +122,7 @@ def send(reactor, appid, relay_url, data, code, if on_code: on_code(code) - wh.send_message( - json.dumps({ - "offer": { - "message": data - } - }).encode("utf-8") - ) + wh.send_message(json.dumps({"offer": {"message": data}}).encode("utf-8")) data = yield wh.get_message() data = json.loads(data.decode("utf-8")) answer = data.get('answer', None) @@ -118,6 +130,4 @@ def send(reactor, appid, relay_url, data, code, if answer: returnValue(None) else: - raise Exception( - "Unknown answer: {}".format(data) - ) + raise Exception("Unknown answer: {}".format(data))