Make code pep-8 compliant

This commit is contained in:
Vasudev Kamath 2018-04-21 13:00:08 +05:30
parent 355cc01aee
commit 12dcd6a184
53 changed files with 3260 additions and 1899 deletions

View File

@ -1,9 +1,4 @@
from ._version import get_versions
__version__ = get_versions()['version']
del get_versions
from .wormhole import create
from ._rlcompleter import input_with_completion
from .wormhole import create, __version__
__all__ = ["create", "input_with_completion", "__version__"]

View File

@ -1,8 +1,6 @@
from .cli import cli
if __name__ != "__main__":
raise ImportError('this module should not be imported')
from .cli import cli
cli.wormhole()

View File

@ -1,56 +1,78 @@
from __future__ import print_function, absolute_import, unicode_literals
from zope.interface import implementer
from attr import attrs, attrib
from __future__ import absolute_import, print_function, unicode_literals
from attr import attrib, attrs
from attr.validators import provides
from automat import MethodicalMachine
from zope.interface import implementer
from . import _interfaces
@attrs
@implementer(_interfaces.IAllocator)
class Allocator(object):
_timing = attrib(validator=provides(_interfaces.ITiming))
m = MethodicalMachine()
set_trace = getattr(m, "_setTrace", lambda self, f: None) # pragma: no cover
set_trace = getattr(m, "_setTrace",
lambda self, f: None) # pragma: no cover
def wire(self, rendezvous_connector, code):
self._RC = _interfaces.IRendezvousConnector(rendezvous_connector)
self._C = _interfaces.ICode(code)
@m.state(initial=True)
def S0A_idle(self): pass # pragma: no cover
def S0A_idle(self):
pass # pragma: no cover
@m.state()
def S0B_idle_connected(self): pass # pragma: no cover
def S0B_idle_connected(self):
pass # pragma: no cover
@m.state()
def S1A_allocating(self): pass # pragma: no cover
def S1A_allocating(self):
pass # pragma: no cover
@m.state()
def S1B_allocating_connected(self): pass # pragma: no cover
def S1B_allocating_connected(self):
pass # pragma: no cover
@m.state()
def S2_done(self): pass # pragma: no cover
def S2_done(self):
pass # pragma: no cover
# from Code
@m.input()
def allocate(self, length, wordlist): pass
def allocate(self, length, wordlist):
pass
# from RendezvousConnector
@m.input()
def connected(self): pass
def connected(self):
pass
@m.input()
def lost(self): pass
def lost(self):
pass
@m.input()
def rx_allocated(self, nameplate): pass
def rx_allocated(self, nameplate):
pass
@m.output()
def stash(self, length, wordlist):
self._length = length
self._wordlist = _interfaces.IWordlist(wordlist)
@m.output()
def stash_and_RC_rx_allocate(self, length, wordlist):
self._length = length
self._wordlist = _interfaces.IWordlist(wordlist)
self._RC.tx_allocate()
@m.output()
def RC_tx_allocate(self):
self._RC.tx_allocate()
@m.output()
def build_and_notify(self, nameplate):
words = self._wordlist.choose_words(self._length)
@ -61,15 +83,17 @@ class Allocator(object):
S0B_idle_connected.upon(lost, enter=S0A_idle, outputs=[])
S0A_idle.upon(allocate, enter=S1A_allocating, outputs=[stash])
S0B_idle_connected.upon(allocate, enter=S1B_allocating_connected,
outputs=[stash_and_RC_rx_allocate])
S0B_idle_connected.upon(
allocate,
enter=S1B_allocating_connected,
outputs=[stash_and_RC_rx_allocate])
S1A_allocating.upon(connected, enter=S1B_allocating_connected,
outputs=[RC_tx_allocate])
S1A_allocating.upon(
connected, enter=S1B_allocating_connected, outputs=[RC_tx_allocate])
S1B_allocating_connected.upon(lost, enter=S1A_allocating, outputs=[])
S1B_allocating_connected.upon(rx_allocated, enter=S2_done,
outputs=[build_and_notify])
S1B_allocating_connected.upon(
rx_allocated, enter=S2_done, outputs=[build_and_notify])
S2_done.upon(connected, enter=S2_done, outputs=[])
S2_done.upon(lost, enter=S2_done, outputs=[])

View File

@ -1,29 +1,33 @@
from __future__ import print_function, absolute_import, unicode_literals
from __future__ import absolute_import, print_function, unicode_literals
import re
import six
from zope.interface import implementer
from attr import attrs, attrib
from attr.validators import provides, instance_of, optional
from twisted.python import log
from attr import attrib, attrs
from attr.validators import instance_of, optional, provides
from automat import MethodicalMachine
from twisted.python import log
from zope.interface import implementer
from . import _interfaces
from ._nameplate import Nameplate
from ._mailbox import Mailbox
from ._send import Send
from ._order import Order
from ._allocator import Allocator
from ._code import Code, validate_code
from ._input import Input
from ._key import Key
from ._lister import Lister
from ._mailbox import Mailbox
from ._nameplate import Nameplate
from ._order import Order
from ._receive import Receive
from ._rendezvous import RendezvousConnector
from ._lister import Lister
from ._allocator import Allocator
from ._input import Input
from ._code import Code, validate_code
from ._send import Send
from ._terminator import Terminator
from ._wordlist import PGPWordList
from .errors import (ServerError, LonelyError, WrongPasswordError,
OnlyOneCodeError, _UnknownPhaseError, WelcomeError)
from .errors import (LonelyError, OnlyOneCodeError, ServerError, WelcomeError,
WrongPasswordError, _UnknownPhaseError)
from .util import bytes_to_dict
@attrs
@implementer(_interfaces.IBoss)
class Boss(object):
@ -38,7 +42,8 @@ class Boss(object):
_tor = attrib(validator=optional(provides(_interfaces.ITorManager)))
_timing = attrib(validator=provides(_interfaces.ITiming))
m = MethodicalMachine()
set_trace = getattr(m, "_setTrace", lambda self, f: None) # pragma: no cover
set_trace = getattr(m, "_setTrace",
lambda self, f: None) # pragma: no cover
def __attrs_post_init__(self):
self._build_workers()
@ -52,9 +57,8 @@ class Boss(object):
self._K = Key(self._appid, self._versions, self._side, self._timing)
self._R = Receive(self._side, self._timing)
self._RC = RendezvousConnector(self._url, self._appid, self._side,
self._reactor, self._journal,
self._tor, self._timing,
self._client_version)
self._reactor, self._journal, self._tor,
self._timing, self._client_version)
self._L = Lister(self._timing)
self._A = Allocator(self._timing)
self._I = Input(self._timing)
@ -78,7 +82,7 @@ class Boss(object):
self._did_start_code = False
self._next_tx_phase = 0
self._next_rx_phase = 0
self._rx_phases = {} # phase -> plaintext
self._rx_phases = {} # phase -> plaintext
self._result = "empty"
@ -86,32 +90,45 @@ class Boss(object):
def start(self):
self._RC.start()
def _print_trace(self, old_state, input, new_state,
client_name, machine, file):
def _print_trace(self, old_state, input, new_state, client_name, machine,
file):
if new_state:
print("%s.%s[%s].%s -> [%s]" %
(client_name, machine, old_state, input,
new_state), file=file)
print(
"%s.%s[%s].%s -> [%s]" % (client_name, machine, old_state,
input, new_state),
file=file)
else:
# the RendezvousConnector emits message events as if
# they were state transitions, except that old_state
# and new_state are empty strings. "input" is one of
# R.connected, R.rx(type phase+side), R.tx(type
# phase), R.lost .
print("%s.%s.%s" % (client_name, machine, input),
file=file)
print("%s.%s.%s" % (client_name, machine, input), file=file)
file.flush()
def output_tracer(output):
print(" %s.%s.%s()" % (client_name, machine, output),
file=file)
print(" %s.%s.%s()" % (client_name, machine, output), file=file)
file.flush()
return output_tracer
def _set_trace(self, client_name, which, file):
names = {"B": self, "N": self._N, "M": self._M, "S": self._S,
"O": self._O, "K": self._K, "SK": self._K._SK, "R": self._R,
"RC": self._RC, "L": self._L, "A": self._A, "I": self._I,
"C": self._C, "T": self._T}
names = {
"B": self,
"N": self._N,
"M": self._M,
"S": self._S,
"O": self._O,
"K": self._K,
"SK": self._K._SK,
"R": self._R,
"RC": self._RC,
"L": self._L,
"A": self._A,
"I": self._I,
"C": self._C,
"T": self._T
}
for machine in which.split():
t = (lambda old_state, input, new_state, machine=machine:
self._print_trace(old_state, input, new_state,
@ -121,21 +138,30 @@ class Boss(object):
if machine == "I":
self._I.set_debug(t)
## def serialize(self):
## raise NotImplemented
# def serialize(self):
# raise NotImplemented
# and these are the state-machine transition functions, which don't take
# args
@m.state(initial=True)
def S0_empty(self): pass # pragma: no cover
def S0_empty(self):
pass # pragma: no cover
@m.state()
def S1_lonely(self): pass # pragma: no cover
def S1_lonely(self):
pass # pragma: no cover
@m.state()
def S2_happy(self): pass # pragma: no cover
def S2_happy(self):
pass # pragma: no cover
@m.state()
def S3_closing(self): pass # pragma: no cover
def S3_closing(self):
pass # pragma: no cover
@m.state(terminal=True)
def S4_closed(self): pass # pragma: no cover
def S4_closed(self):
pass # pragma: no cover
# from the Wormhole
@ -155,23 +181,28 @@ class Boss(object):
raise OnlyOneCodeError()
self._did_start_code = True
return self._C.input_code()
def allocate_code(self, code_length):
if self._did_start_code:
raise OnlyOneCodeError()
self._did_start_code = True
wl = PGPWordList()
self._C.allocate_code(code_length, wl)
def set_code(self, code):
validate_code(code) # can raise KeyFormatError
validate_code(code) # can raise KeyFormatError
if self._did_start_code:
raise OnlyOneCodeError()
self._did_start_code = True
self._C.set_code(code)
@m.input()
def send(self, plaintext): pass
def send(self, plaintext):
pass
@m.input()
def close(self): pass
def close(self):
pass
# from RendezvousConnector:
# * "rx_welcome" is the Welcome message, which might signal an error, or
@ -190,26 +221,36 @@ class Boss(object):
# delivering a new input (rx_error or something) while in the
# middle of processing the rx_welcome input, and I wasn't sure
# Automat would handle that correctly.
self._W.got_welcome(welcome) # TODO: let this raise WelcomeError?
self._W.got_welcome(welcome) # TODO: let this raise WelcomeError?
except WelcomeError as welcome_error:
self.rx_unwelcome(welcome_error)
@m.input()
def rx_unwelcome(self, welcome_error): pass
def rx_unwelcome(self, welcome_error):
pass
@m.input()
def rx_error(self, errmsg, orig): pass
def rx_error(self, errmsg, orig):
pass
@m.input()
def error(self, err): pass
def error(self, err):
pass
# from Code (provoked by input/allocate/set_code)
@m.input()
def got_code(self, code): pass
def got_code(self, code):
pass
# Key sends (got_key, scared)
# Receive sends (got_message, happy, got_verifier, scared)
@m.input()
def happy(self): pass
def happy(self):
pass
@m.input()
def scared(self): pass
def scared(self):
pass
def got_message(self, phase, plaintext):
assert isinstance(phase, type("")), type(phase)
@ -222,22 +263,32 @@ class Boss(object):
# Ignore unrecognized phases, for forwards-compatibility. Use
# log.err so tests will catch surprises.
log.err(_UnknownPhaseError("received unknown phase '%s'" % phase))
@m.input()
def _got_version(self, plaintext): pass
def _got_version(self, plaintext):
pass
@m.input()
def _got_phase(self, phase, plaintext): pass
def _got_phase(self, phase, plaintext):
pass
@m.input()
def got_key(self, key): pass
def got_key(self, key):
pass
@m.input()
def got_verifier(self, verifier): pass
def got_verifier(self, verifier):
pass
# Terminator sends closed
@m.input()
def closed(self): pass
def closed(self):
pass
@m.output()
def do_got_code(self, code):
self._W.got_code(code)
@m.output()
def process_version(self, plaintext):
# most of this is wormhole-to-wormhole, ignored for now
@ -256,21 +307,25 @@ class Boss(object):
@m.output()
def close_unwelcome(self, welcome_error):
#assert isinstance(err, WelcomeError)
# assert isinstance(err, WelcomeError)
self._result = welcome_error
self._T.close("unwelcome")
@m.output()
def close_error(self, errmsg, orig):
self._result = ServerError(errmsg)
self._T.close("errory")
@m.output()
def close_scared(self):
self._result = WrongPasswordError()
self._T.close("scary")
@m.output()
def close_lonely(self):
self._result = LonelyError()
self._T.close("lonely")
@m.output()
def close_happy(self):
self._result = "happy"
@ -279,9 +334,11 @@ class Boss(object):
@m.output()
def W_got_key(self, key):
self._W.got_key(key)
@m.output()
def W_got_verifier(self, verifier):
self._W.got_verifier(verifier)
@m.output()
def W_received(self, phase, plaintext):
assert isinstance(phase, six.integer_types), type(phase)
@ -293,7 +350,7 @@ class Boss(object):
@m.output()
def W_close_with_error(self, err):
self._result = err # exception
self._result = err # exception
self._W.closed(self._result)
@m.output()

View File

@ -7,21 +7,25 @@ from . import _interfaces
from ._nameplate import validate_nameplate
from .errors import KeyFormatError
def validate_code(code):
if ' ' in code:
raise KeyFormatError("Code '%s' contains spaces." % code)
nameplate = code.split("-", 2)[0]
validate_nameplate(nameplate) # can raise KeyFormatError
validate_nameplate(nameplate) # can raise KeyFormatError
def first(outputs):
return list(outputs)[0]
@attrs
@implementer(_interfaces.ICode)
class Code(object):
_timing = attrib(validator=provides(_interfaces.ITiming))
m = MethodicalMachine()
set_trace = getattr(m, "_setTrace", lambda self, f: None) # pragma: no cover
set_trace = getattr(m, "_setTrace",
lambda self, f: None) # pragma: no cover
def wire(self, boss, allocator, nameplate, key, input):
self._B = _interfaces.IBoss(boss)
@ -31,36 +35,55 @@ class Code(object):
self._I = _interfaces.IInput(input)
@m.state(initial=True)
def S0_idle(self): pass # pragma: no cover
def S0_idle(self):
pass # pragma: no cover
@m.state()
def S1_inputting_nameplate(self): pass # pragma: no cover
def S1_inputting_nameplate(self):
pass # pragma: no cover
@m.state()
def S2_inputting_words(self): pass # pragma: no cover
def S2_inputting_words(self):
pass # pragma: no cover
@m.state()
def S3_allocating(self): pass # pragma: no cover
def S3_allocating(self):
pass # pragma: no cover
@m.state()
def S4_known(self): pass # pragma: no cover
def S4_known(self):
pass # pragma: no cover
# from App
@m.input()
def allocate_code(self, length, wordlist): pass
def allocate_code(self, length, wordlist):
pass
@m.input()
def input_code(self): pass
def input_code(self):
pass
def set_code(self, code):
validate_code(code) # can raise KeyFormatError
validate_code(code) # can raise KeyFormatError
self._set_code(code)
@m.input()
def _set_code(self, code): pass
def _set_code(self, code):
pass
# from Allocator
@m.input()
def allocated(self, nameplate, code): pass
def allocated(self, nameplate, code):
pass
# from Input
@m.input()
def got_nameplate(self, nameplate): pass
def got_nameplate(self, nameplate):
pass
@m.input()
def finished_input(self, code): pass
def finished_input(self, code):
pass
@m.output()
def do_set_code(self, code):
@ -72,9 +95,11 @@ class Code(object):
@m.output()
def do_start_input(self):
return self._I.start()
@m.output()
def do_middle_input(self, nameplate):
self._N.set_nameplate(nameplate)
@m.output()
def do_finish_input(self, code):
self._B.got_code(code)
@ -83,19 +108,24 @@ class Code(object):
@m.output()
def do_start_allocate(self, length, wordlist):
self._A.allocate(length, wordlist)
@m.output()
def do_finish_allocate(self, nameplate, code):
assert code.startswith(nameplate+"-"), (nameplate, code)
assert code.startswith(nameplate + "-"), (nameplate, code)
self._N.set_nameplate(nameplate)
self._B.got_code(code)
self._K.got_code(code)
S0_idle.upon(_set_code, enter=S4_known, outputs=[do_set_code])
S0_idle.upon(input_code, enter=S1_inputting_nameplate,
outputs=[do_start_input], collector=first)
S1_inputting_nameplate.upon(got_nameplate, enter=S2_inputting_words,
outputs=[do_middle_input])
S2_inputting_words.upon(finished_input, enter=S4_known,
outputs=[do_finish_input])
S0_idle.upon(allocate_code, enter=S3_allocating, outputs=[do_start_allocate])
S0_idle.upon(
input_code,
enter=S1_inputting_nameplate,
outputs=[do_start_input],
collector=first)
S1_inputting_nameplate.upon(
got_nameplate, enter=S2_inputting_words, outputs=[do_middle_input])
S2_inputting_words.upon(
finished_input, enter=S4_known, outputs=[do_finish_input])
S0_idle.upon(
allocate_code, enter=S3_allocating, outputs=[do_start_allocate])
S3_allocating.upon(allocated, enter=S4_known, outputs=[do_finish_allocate])

View File

@ -1,25 +1,31 @@
from __future__ import print_function, absolute_import, unicode_literals
from __future__ import absolute_import, print_function, unicode_literals
# We use 'threading' defensively here, to detect if we're being called from a
# non-main thread. _rlcompleter.py is the only internal Wormhole code that
# deliberately creates a new thread.
import threading
from zope.interface import implementer
from attr import attrs, attrib
from attr import attrib, attrs
from attr.validators import provides
from twisted.internet import defer
from automat import MethodicalMachine
from twisted.internet import defer
from zope.interface import implementer
from . import _interfaces, errors
from ._nameplate import validate_nameplate
def first(outputs):
return list(outputs)[0]
@attrs
@implementer(_interfaces.IInput)
class Input(object):
_timing = attrib(validator=provides(_interfaces.ITiming))
m = MethodicalMachine()
set_trace = getattr(m, "_setTrace", lambda self, f: None) # pragma: no cover
set_trace = getattr(m, "_setTrace",
lambda self, f: None) # pragma: no cover
def __attrs_post_init__(self):
self._all_nameplates = set()
@ -30,7 +36,8 @@ class Input(object):
def set_debug(self, f):
self._trace = f
def _debug(self, what): # pragma: no cover
def _debug(self, what): # pragma: no cover
if self._trace:
self._trace(old_state="", input=what, new_state="")
@ -46,55 +53,80 @@ class Input(object):
return d
@m.state(initial=True)
def S0_idle(self): pass # pragma: no cover
def S0_idle(self):
pass # pragma: no cover
@m.state()
def S1_typing_nameplate(self): pass # pragma: no cover
def S1_typing_nameplate(self):
pass # pragma: no cover
@m.state()
def S2_typing_code_no_wordlist(self): pass # pragma: no cover
def S2_typing_code_no_wordlist(self):
pass # pragma: no cover
@m.state()
def S3_typing_code_yes_wordlist(self): pass # pragma: no cover
def S3_typing_code_yes_wordlist(self):
pass # pragma: no cover
@m.state(terminal=True)
def S4_done(self): pass # pragma: no cover
def S4_done(self):
pass # pragma: no cover
# from Code
@m.input()
def start(self): pass
def start(self):
pass
# from Lister
@m.input()
def got_nameplates(self, all_nameplates): pass
def got_nameplates(self, all_nameplates):
pass
# from Nameplate
@m.input()
def got_wordlist(self, wordlist): pass
def got_wordlist(self, wordlist):
pass
# API provided to app as ICodeInputHelper
@m.input()
def refresh_nameplates(self): pass
def refresh_nameplates(self):
pass
@m.input()
def get_nameplate_completions(self, prefix): pass
def get_nameplate_completions(self, prefix):
pass
def choose_nameplate(self, nameplate):
validate_nameplate(nameplate) # can raise KeyFormatError
validate_nameplate(nameplate) # can raise KeyFormatError
self._choose_nameplate(nameplate)
@m.input()
def _choose_nameplate(self, nameplate): pass
def _choose_nameplate(self, nameplate):
pass
@m.input()
def get_word_completions(self, prefix): pass
def get_word_completions(self, prefix):
pass
@m.input()
def choose_words(self, words): pass
def choose_words(self, words):
pass
@m.output()
def do_start(self):
self._start_timing = self._timing.add("input code", waiting="user")
self._L.refresh()
return Helper(self)
@m.output()
def do_refresh(self):
self._L.refresh()
@m.output()
def record_nameplates(self, all_nameplates):
# we get a set of nameplate id strings
self._all_nameplates = all_nameplates
@m.output()
def _get_nameplate_completions(self, prefix):
completions = set()
@ -102,17 +134,20 @@ class Input(object):
if nameplate.startswith(prefix):
# TODO: it's a little weird that Input is responsible for the
# hyphen on nameplates, but WordList owns it for words
completions.add(nameplate+"-")
completions.add(nameplate + "-")
return completions
@m.output()
def record_all_nameplates(self, nameplate):
self._nameplate = nameplate
self._C.got_nameplate(nameplate)
@m.output()
def record_wordlist(self, wordlist):
from ._rlcompleter import debug
debug(" -record_wordlist")
self._wordlist = wordlist
@m.output()
def notify_wordlist_waiters(self, wordlist):
while self._wordlist_waiters:
@ -122,6 +157,7 @@ class Input(object):
@m.output()
def no_word_completions(self, prefix):
return set()
@m.output()
def _get_word_completions(self, prefix):
assert self._wordlist
@ -130,21 +166,27 @@ class Input(object):
@m.output()
def raise_must_choose_nameplate1(self, prefix):
raise errors.MustChooseNameplateFirstError()
@m.output()
def raise_must_choose_nameplate2(self, words):
raise errors.MustChooseNameplateFirstError()
@m.output()
def raise_already_chose_nameplate1(self):
raise errors.AlreadyChoseNameplateError()
@m.output()
def raise_already_chose_nameplate2(self, prefix):
raise errors.AlreadyChoseNameplateError()
@m.output()
def raise_already_chose_nameplate3(self, nameplate):
raise errors.AlreadyChoseNameplateError()
@m.output()
def raise_already_chose_words1(self, prefix):
raise errors.AlreadyChoseWordsError()
@m.output()
def raise_already_chose_words2(self, words):
raise errors.AlreadyChoseWordsError()
@ -155,88 +197,110 @@ class Input(object):
self._start_timing.finish()
self._C.finished_input(code)
S0_idle.upon(start, enter=S1_typing_nameplate,
outputs=[do_start], collector=first)
S0_idle.upon(
start, enter=S1_typing_nameplate, outputs=[do_start], collector=first)
# wormholes that don't use input_code (i.e. they use allocate_code or
# generate_code) will never start() us, but Nameplate will give us a
# wordlist anyways (as soon as the nameplate is claimed), so handle it.
S0_idle.upon(got_wordlist, enter=S0_idle, outputs=[record_wordlist,
notify_wordlist_waiters])
S1_typing_nameplate.upon(got_nameplates, enter=S1_typing_nameplate,
outputs=[record_nameplates])
S0_idle.upon(
got_wordlist,
enter=S0_idle,
outputs=[record_wordlist, notify_wordlist_waiters])
S1_typing_nameplate.upon(
got_nameplates, enter=S1_typing_nameplate, outputs=[record_nameplates])
# but wormholes that *do* use input_code should not get got_wordlist
# until after we tell Code that we got_nameplate, which is the earliest
# it can be claimed
S1_typing_nameplate.upon(refresh_nameplates, enter=S1_typing_nameplate,
outputs=[do_refresh])
S1_typing_nameplate.upon(get_nameplate_completions,
enter=S1_typing_nameplate,
outputs=[_get_nameplate_completions],
collector=first)
S1_typing_nameplate.upon(_choose_nameplate, enter=S2_typing_code_no_wordlist,
outputs=[record_all_nameplates])
S1_typing_nameplate.upon(get_word_completions,
enter=S1_typing_nameplate,
outputs=[raise_must_choose_nameplate1])
S1_typing_nameplate.upon(choose_words, enter=S1_typing_nameplate,
outputs=[raise_must_choose_nameplate2])
S1_typing_nameplate.upon(
refresh_nameplates, enter=S1_typing_nameplate, outputs=[do_refresh])
S1_typing_nameplate.upon(
get_nameplate_completions,
enter=S1_typing_nameplate,
outputs=[_get_nameplate_completions],
collector=first)
S1_typing_nameplate.upon(
_choose_nameplate,
enter=S2_typing_code_no_wordlist,
outputs=[record_all_nameplates])
S1_typing_nameplate.upon(
get_word_completions,
enter=S1_typing_nameplate,
outputs=[raise_must_choose_nameplate1])
S1_typing_nameplate.upon(
choose_words,
enter=S1_typing_nameplate,
outputs=[raise_must_choose_nameplate2])
S2_typing_code_no_wordlist.upon(got_nameplates,
enter=S2_typing_code_no_wordlist, outputs=[])
S2_typing_code_no_wordlist.upon(got_wordlist,
enter=S3_typing_code_yes_wordlist,
outputs=[record_wordlist,
notify_wordlist_waiters])
S2_typing_code_no_wordlist.upon(refresh_nameplates,
enter=S2_typing_code_no_wordlist,
outputs=[raise_already_chose_nameplate1])
S2_typing_code_no_wordlist.upon(get_nameplate_completions,
enter=S2_typing_code_no_wordlist,
outputs=[raise_already_chose_nameplate2])
S2_typing_code_no_wordlist.upon(_choose_nameplate,
enter=S2_typing_code_no_wordlist,
outputs=[raise_already_chose_nameplate3])
S2_typing_code_no_wordlist.upon(get_word_completions,
enter=S2_typing_code_no_wordlist,
outputs=[no_word_completions],
collector=first)
S2_typing_code_no_wordlist.upon(choose_words, enter=S4_done,
outputs=[do_words])
S2_typing_code_no_wordlist.upon(
got_nameplates, enter=S2_typing_code_no_wordlist, outputs=[])
S2_typing_code_no_wordlist.upon(
got_wordlist,
enter=S3_typing_code_yes_wordlist,
outputs=[record_wordlist, notify_wordlist_waiters])
S2_typing_code_no_wordlist.upon(
refresh_nameplates,
enter=S2_typing_code_no_wordlist,
outputs=[raise_already_chose_nameplate1])
S2_typing_code_no_wordlist.upon(
get_nameplate_completions,
enter=S2_typing_code_no_wordlist,
outputs=[raise_already_chose_nameplate2])
S2_typing_code_no_wordlist.upon(
_choose_nameplate,
enter=S2_typing_code_no_wordlist,
outputs=[raise_already_chose_nameplate3])
S2_typing_code_no_wordlist.upon(
get_word_completions,
enter=S2_typing_code_no_wordlist,
outputs=[no_word_completions],
collector=first)
S2_typing_code_no_wordlist.upon(
choose_words, enter=S4_done, outputs=[do_words])
S3_typing_code_yes_wordlist.upon(got_nameplates,
enter=S3_typing_code_yes_wordlist,
outputs=[])
S3_typing_code_yes_wordlist.upon(
got_nameplates, enter=S3_typing_code_yes_wordlist, outputs=[])
# got_wordlist: should never happen
S3_typing_code_yes_wordlist.upon(refresh_nameplates,
enter=S3_typing_code_yes_wordlist,
outputs=[raise_already_chose_nameplate1])
S3_typing_code_yes_wordlist.upon(get_nameplate_completions,
enter=S3_typing_code_yes_wordlist,
outputs=[raise_already_chose_nameplate2])
S3_typing_code_yes_wordlist.upon(_choose_nameplate,
enter=S3_typing_code_yes_wordlist,
outputs=[raise_already_chose_nameplate3])
S3_typing_code_yes_wordlist.upon(get_word_completions,
enter=S3_typing_code_yes_wordlist,
outputs=[_get_word_completions],
collector=first)
S3_typing_code_yes_wordlist.upon(choose_words, enter=S4_done,
outputs=[do_words])
S3_typing_code_yes_wordlist.upon(
refresh_nameplates,
enter=S3_typing_code_yes_wordlist,
outputs=[raise_already_chose_nameplate1])
S3_typing_code_yes_wordlist.upon(
get_nameplate_completions,
enter=S3_typing_code_yes_wordlist,
outputs=[raise_already_chose_nameplate2])
S3_typing_code_yes_wordlist.upon(
_choose_nameplate,
enter=S3_typing_code_yes_wordlist,
outputs=[raise_already_chose_nameplate3])
S3_typing_code_yes_wordlist.upon(
get_word_completions,
enter=S3_typing_code_yes_wordlist,
outputs=[_get_word_completions],
collector=first)
S3_typing_code_yes_wordlist.upon(
choose_words, enter=S4_done, outputs=[do_words])
S4_done.upon(got_nameplates, enter=S4_done, outputs=[])
S4_done.upon(got_wordlist, enter=S4_done, outputs=[])
S4_done.upon(refresh_nameplates,
enter=S4_done,
outputs=[raise_already_chose_nameplate1])
S4_done.upon(get_nameplate_completions,
enter=S4_done,
outputs=[raise_already_chose_nameplate2])
S4_done.upon(_choose_nameplate, enter=S4_done,
outputs=[raise_already_chose_nameplate3])
S4_done.upon(get_word_completions, enter=S4_done,
outputs=[raise_already_chose_words1])
S4_done.upon(choose_words, enter=S4_done,
outputs=[raise_already_chose_words2])
S4_done.upon(
refresh_nameplates,
enter=S4_done,
outputs=[raise_already_chose_nameplate1])
S4_done.upon(
get_nameplate_completions,
enter=S4_done,
outputs=[raise_already_chose_nameplate2])
S4_done.upon(
_choose_nameplate,
enter=S4_done,
outputs=[raise_already_chose_nameplate3])
S4_done.upon(
get_word_completions,
enter=S4_done,
outputs=[raise_already_chose_words1])
S4_done.upon(
choose_words, enter=S4_done, outputs=[raise_already_chose_words2])
# we only expose the Helper to application code, not _Input
@attrs
@ -250,20 +314,25 @@ class Helper(object):
def refresh_nameplates(self):
assert threading.current_thread().ident == self._main_thread
self._input.refresh_nameplates()
def get_nameplate_completions(self, prefix):
assert threading.current_thread().ident == self._main_thread
return self._input.get_nameplate_completions(prefix)
def choose_nameplate(self, nameplate):
assert threading.current_thread().ident == self._main_thread
self._input._debug("I.choose_nameplate")
self._input.choose_nameplate(nameplate)
self._input._debug("I.choose_nameplate finished")
def when_wordlist_is_available(self):
assert threading.current_thread().ident == self._main_thread
return self._input.when_wordlist_is_available()
def get_word_completions(self, prefix):
assert threading.current_thread().ident == self._main_thread
return self._input.get_word_completions(prefix)
def choose_words(self, words):
assert threading.current_thread().ident == self._main_thread
self._input._debug("I.choose_words")

View File

@ -3,64 +3,105 @@ from zope.interface import Interface
# These interfaces are private: we use them as markers to detect
# swapped argument bugs in the various .wire() calls
class IWormhole(Interface):
"""Internal: this contains the methods invoked 'from below'."""
def got_welcome(welcome):
pass
def got_code(code):
pass
def got_key(key):
pass
def got_verifier(verifier):
pass
def got_versions(versions):
pass
def received(plaintext):
pass
def closed(result):
pass
class IBoss(Interface):
pass
class INameplate(Interface):
pass
class IMailbox(Interface):
pass
class ISend(Interface):
pass
class IOrder(Interface):
pass
class IKey(Interface):
pass
class IReceive(Interface):
pass
class IRendezvousConnector(Interface):
pass
class ILister(Interface):
pass
class ICode(Interface):
pass
class IInput(Interface):
pass
class IAllocator(Interface):
pass
class ITerminator(Interface):
pass
class ITiming(Interface):
pass
class ITorManager(Interface):
pass
class IWordlist(Interface):
def choose_words(length):
"""Randomly select LENGTH words, join them with hyphens, return the
result."""
def get_completions(prefix):
"""Return a list of all suffixes that could complete the given
prefix."""
# These interfaces are public, and are re-exported by __init__.py
class IDeferredWormhole(Interface):
def get_welcome():
"""
@ -277,6 +318,7 @@ class IDeferredWormhole(Interface):
:rtype: ``Deferred``
"""
class IInputHelper(Interface):
def refresh_nameplates():
"""
@ -389,5 +431,5 @@ class IInputHelper(Interface):
"""
class IJournal(Interface): # TODO: this needs to be public
class IJournal(Interface): # TODO: this needs to be public
pass

View File

@ -1,41 +1,50 @@
from __future__ import print_function, absolute_import, unicode_literals
from __future__ import absolute_import, print_function, unicode_literals
from hashlib import sha256
import six
from zope.interface import implementer
from attr import attrs, attrib
from attr.validators import provides, instance_of
from spake2 import SPAKE2_Symmetric
from hkdf import Hkdf
from nacl.secret import SecretBox
from nacl.exceptions import CryptoError
from nacl import utils
from attr import attrib, attrs
from attr.validators import instance_of, provides
from automat import MethodicalMachine
from .util import (to_bytes, bytes_to_hexstr, hexstr_to_bytes,
bytes_to_dict, dict_to_bytes)
from hkdf import Hkdf
from nacl import utils
from nacl.exceptions import CryptoError
from nacl.secret import SecretBox
from spake2 import SPAKE2_Symmetric
from zope.interface import implementer
from . import _interfaces
from .util import (bytes_to_dict, bytes_to_hexstr, dict_to_bytes,
hexstr_to_bytes, to_bytes)
CryptoError
__all__ = ["derive_key", "derive_phase_key", "CryptoError",
"Key"]
__all__ = ["derive_key", "derive_phase_key", "CryptoError", "Key"]
def HKDF(skm, outlen, salt=None, CTXinfo=b""):
return Hkdf(salt, skm).expand(CTXinfo, outlen)
def derive_key(key, purpose, length=SecretBox.KEY_SIZE):
if not isinstance(key, type(b"")): raise TypeError(type(key))
if not isinstance(purpose, type(b"")): raise TypeError(type(purpose))
if not isinstance(length, six.integer_types): raise TypeError(type(length))
if not isinstance(key, type(b"")):
raise TypeError(type(key))
if not isinstance(purpose, type(b"")):
raise TypeError(type(purpose))
if not isinstance(length, six.integer_types):
raise TypeError(type(length))
return HKDF(key, length, CTXinfo=purpose)
def derive_phase_key(key, side, phase):
assert isinstance(side, type("")), type(side)
assert isinstance(phase, type("")), type(phase)
side_bytes = side.encode("ascii")
phase_bytes = phase.encode("ascii")
purpose = (b"wormhole:phase:"
+ sha256(side_bytes).digest()
+ sha256(phase_bytes).digest())
purpose = (b"wormhole:phase:" + sha256(side_bytes).digest() +
sha256(phase_bytes).digest())
return derive_key(key, purpose)
def decrypt_data(key, encrypted):
assert isinstance(key, type(b"")), type(key)
assert isinstance(encrypted, type(b"")), type(encrypted)
@ -44,6 +53,7 @@ def decrypt_data(key, encrypted):
data = box.decrypt(encrypted)
return data
def encrypt_data(key, plaintext):
assert isinstance(key, type(b"")), type(key)
assert isinstance(plaintext, type(b"")), type(plaintext)
@ -52,10 +62,12 @@ def encrypt_data(key, plaintext):
nonce = utils.random(SecretBox.NONCE_SIZE)
return box.encrypt(plaintext, nonce)
# the Key we expose to callers (Boss, Ordering) is responsible for sorting
# the two messages (got_code and got_pake), then delivering them to
# _SortedKey in the right order.
@attrs
@implementer(_interfaces.IKey)
class Key(object):
@ -64,40 +76,54 @@ class Key(object):
_side = attrib(validator=instance_of(type(u"")))
_timing = attrib(validator=provides(_interfaces.ITiming))
m = MethodicalMachine()
set_trace = getattr(m, "_setTrace", lambda self, f: None) # pragma: no cover
set_trace = getattr(m, "_setTrace",
lambda self, f: None) # pragma: no cover
def __attrs_post_init__(self):
self._SK = _SortedKey(self._appid, self._versions, self._side,
self._timing)
self._debug_pake_stashed = False # for tests
self._debug_pake_stashed = False # for tests
def wire(self, boss, mailbox, receive):
self._SK.wire(boss, mailbox, receive)
@m.state(initial=True)
def S00(self): pass # pragma: no cover
def S00(self):
pass # pragma: no cover
@m.state()
def S01(self): pass # pragma: no cover
def S01(self):
pass # pragma: no cover
@m.state()
def S10(self): pass # pragma: no cover
def S10(self):
pass # pragma: no cover
@m.state()
def S11(self): pass # pragma: no cover
def S11(self):
pass # pragma: no cover
@m.input()
def got_code(self, code): pass
def got_code(self, code):
pass
@m.input()
def got_pake(self, body): pass
def got_pake(self, body):
pass
@m.output()
def stash_pake(self, body):
self._pake = body
self._debug_pake_stashed = True
@m.output()
def deliver_code(self, code):
self._SK.got_code(code)
@m.output()
def deliver_pake(self, body):
self._SK.got_pake(body)
@m.output()
def deliver_code_and_stashed_pake(self, code):
self._SK.got_code(code)
@ -108,6 +134,7 @@ class Key(object):
S00.upon(got_pake, enter=S01, outputs=[stash_pake])
S01.upon(got_code, enter=S11, outputs=[deliver_code_and_stashed_pake])
@attrs
class _SortedKey(object):
_appid = attrib(validator=instance_of(type(u"")))
@ -115,7 +142,8 @@ class _SortedKey(object):
_side = attrib(validator=instance_of(type(u"")))
_timing = attrib(validator=provides(_interfaces.ITiming))
m = MethodicalMachine()
set_trace = getattr(m, "_setTrace", lambda self, f: None) # pragma: no cover
set_trace = getattr(m, "_setTrace",
lambda self, f: None) # pragma: no cover
def wire(self, boss, mailbox, receive):
self._B = _interfaces.IBoss(boss)
@ -123,17 +151,25 @@ class _SortedKey(object):
self._R = _interfaces.IReceive(receive)
@m.state(initial=True)
def S0_know_nothing(self): pass # pragma: no cover
def S0_know_nothing(self):
pass # pragma: no cover
@m.state()
def S1_know_code(self): pass # pragma: no cover
def S1_know_code(self):
pass # pragma: no cover
@m.state()
def S2_know_key(self): pass # pragma: no cover
def S2_know_key(self):
pass # pragma: no cover
@m.state(terminal=True)
def S3_scared(self): pass # pragma: no cover
def S3_scared(self):
pass # pragma: no cover
# from Boss
@m.input()
def got_code(self, code): pass
def got_code(self, code):
pass
# from Ordering
def got_pake(self, body):
@ -143,16 +179,20 @@ class _SortedKey(object):
self.got_pake_good(hexstr_to_bytes(payload["pake_v1"]))
else:
self.got_pake_bad()
@m.input()
def got_pake_good(self, msg2): pass
def got_pake_good(self, msg2):
pass
@m.input()
def got_pake_bad(self): pass
def got_pake_bad(self):
pass
@m.output()
def build_pake(self, code):
with self._timing.add("pake1", waiting="crypto"):
self._sp = SPAKE2_Symmetric(to_bytes(code),
idSymmetric=to_bytes(self._appid))
self._sp = SPAKE2_Symmetric(
to_bytes(code), idSymmetric=to_bytes(self._appid))
msg1 = self._sp.start()
body = dict_to_bytes({"pake_v1": bytes_to_hexstr(msg1)})
self._M.add_message("pake", body)
@ -160,6 +200,7 @@ class _SortedKey(object):
@m.output()
def scared(self):
self._B.scared()
@m.output()
def compute_key(self, msg2):
assert isinstance(msg2, type(b""))

View File

@ -1,16 +1,20 @@
from __future__ import print_function, absolute_import, unicode_literals
from zope.interface import implementer
from attr import attrs, attrib
from __future__ import absolute_import, print_function, unicode_literals
from attr import attrib, attrs
from attr.validators import provides
from automat import MethodicalMachine
from zope.interface import implementer
from . import _interfaces
@attrs
@implementer(_interfaces.ILister)
class Lister(object):
_timing = attrib(validator=provides(_interfaces.ITiming))
m = MethodicalMachine()
set_trace = getattr(m, "_setTrace", lambda self, f: None) # pragma: no cover
set_trace = getattr(m, "_setTrace",
lambda self, f: None) # pragma: no cover
def wire(self, rendezvous_connector, input):
self._RC = _interfaces.IRendezvousConnector(rendezvous_connector)
@ -26,26 +30,41 @@ class Lister(object):
# request arrives, both requests will be satisfied by the same response.
@m.state(initial=True)
def S0A_idle_disconnected(self): pass # pragma: no cover
def S0A_idle_disconnected(self):
pass # pragma: no cover
@m.state()
def S1A_wanting_disconnected(self): pass # pragma: no cover
def S1A_wanting_disconnected(self):
pass # pragma: no cover
@m.state()
def S0B_idle_connected(self): pass # pragma: no cover
def S0B_idle_connected(self):
pass # pragma: no cover
@m.state()
def S1B_wanting_connected(self): pass # pragma: no cover
def S1B_wanting_connected(self):
pass # pragma: no cover
@m.input()
def connected(self): pass
def connected(self):
pass
@m.input()
def lost(self): pass
def lost(self):
pass
@m.input()
def refresh(self): pass
def refresh(self):
pass
@m.input()
def rx_nameplates(self, all_nameplates): pass
def rx_nameplates(self, all_nameplates):
pass
@m.output()
def RC_tx_list(self):
self._RC.tx_list()
@m.output()
def I_got_nameplates(self, all_nameplates):
# We get a set of nameplate ids. There may be more attributes in the
@ -56,18 +75,19 @@ class Lister(object):
S0A_idle_disconnected.upon(connected, enter=S0B_idle_connected, outputs=[])
S0B_idle_connected.upon(lost, enter=S0A_idle_disconnected, outputs=[])
S0A_idle_disconnected.upon(refresh,
enter=S1A_wanting_disconnected, outputs=[])
S1A_wanting_disconnected.upon(refresh,
enter=S1A_wanting_disconnected, outputs=[])
S1A_wanting_disconnected.upon(connected, enter=S1B_wanting_connected,
outputs=[RC_tx_list])
S0B_idle_connected.upon(refresh, enter=S1B_wanting_connected,
outputs=[RC_tx_list])
S0B_idle_connected.upon(rx_nameplates, enter=S0B_idle_connected,
outputs=[I_got_nameplates])
S1B_wanting_connected.upon(lost, enter=S1A_wanting_disconnected, outputs=[])
S1B_wanting_connected.upon(refresh, enter=S1B_wanting_connected,
outputs=[RC_tx_list])
S1B_wanting_connected.upon(rx_nameplates, enter=S0B_idle_connected,
outputs=[I_got_nameplates])
S0A_idle_disconnected.upon(
refresh, enter=S1A_wanting_disconnected, outputs=[])
S1A_wanting_disconnected.upon(
refresh, enter=S1A_wanting_disconnected, outputs=[])
S1A_wanting_disconnected.upon(
connected, enter=S1B_wanting_connected, outputs=[RC_tx_list])
S0B_idle_connected.upon(
refresh, enter=S1B_wanting_connected, outputs=[RC_tx_list])
S0B_idle_connected.upon(
rx_nameplates, enter=S0B_idle_connected, outputs=[I_got_nameplates])
S1B_wanting_connected.upon(
lost, enter=S1A_wanting_disconnected, outputs=[])
S1B_wanting_connected.upon(
refresh, enter=S1B_wanting_connected, outputs=[RC_tx_list])
S1B_wanting_connected.upon(
rx_nameplates, enter=S0B_idle_connected, outputs=[I_got_nameplates])

View File

@ -1,16 +1,20 @@
from __future__ import print_function, absolute_import, unicode_literals
from zope.interface import implementer
from attr import attrs, attrib
from __future__ import absolute_import, print_function, unicode_literals
from attr import attrib, attrs
from attr.validators import instance_of
from automat import MethodicalMachine
from zope.interface import implementer
from . import _interfaces
@attrs
@implementer(_interfaces.IMailbox)
class Mailbox(object):
_side = attrib(validator=instance_of(type(u"")))
m = MethodicalMachine()
set_trace = getattr(m, "_setTrace", lambda self, f: None) # pragma: no cover
set_trace = getattr(m, "_setTrace",
lambda self, f: None) # pragma: no cover
def __attrs_post_init__(self):
self._mailbox = None
@ -29,52 +33,68 @@ class Mailbox(object):
# S0: know nothing
@m.state(initial=True)
def S0A(self): pass # pragma: no cover
def S0A(self):
pass # pragma: no cover
@m.state()
def S0B(self): pass # pragma: no cover
def S0B(self):
pass # pragma: no cover
# S1: mailbox known, not opened
@m.state()
def S1A(self): pass # pragma: no cover
def S1A(self):
pass # pragma: no cover
# S2: mailbox known, opened
# We've definitely tried to open the mailbox at least once, but it must
# be re-opened with each connection, because open() is also subscribe()
@m.state()
def S2A(self): pass # pragma: no cover
def S2A(self):
pass # pragma: no cover
@m.state()
def S2B(self): pass # pragma: no cover
def S2B(self):
pass # pragma: no cover
# S3: closing
@m.state()
def S3A(self): pass # pragma: no cover
def S3A(self):
pass # pragma: no cover
@m.state()
def S3B(self): pass # pragma: no cover
def S3B(self):
pass # pragma: no cover
# S4: closed. We no longer care whether we're connected or not
#@m.state()
#def S4A(self): pass
#@m.state()
#def S4B(self): pass
# @m.state()
# def S4A(self): pass
# @m.state()
# def S4B(self): pass
@m.state(terminal=True)
def S4(self): pass # pragma: no cover
def S4(self):
pass # pragma: no cover
S4A = S4
S4B = S4
# from Terminator
@m.input()
def close(self, mood): pass
def close(self, mood):
pass
# from Nameplate
@m.input()
def got_mailbox(self, mailbox): pass
def got_mailbox(self, mailbox):
pass
# from RendezvousConnector
@m.input()
def connected(self): pass
def connected(self):
pass
@m.input()
def lost(self): pass
def lost(self):
pass
def rx_message(self, side, phase, body):
assert isinstance(side, type("")), type(side)
@ -84,73 +104,91 @@ class Mailbox(object):
self.rx_message_ours(phase, body)
else:
self.rx_message_theirs(side, phase, body)
@m.input()
def rx_message_ours(self, phase, body): pass
def rx_message_ours(self, phase, body):
pass
@m.input()
def rx_message_theirs(self, side, phase, body): pass
def rx_message_theirs(self, side, phase, body):
pass
@m.input()
def rx_closed(self): pass
def rx_closed(self):
pass
# from Send or Key
@m.input()
def add_message(self, phase, body):
pass
@m.output()
def record_mailbox(self, mailbox):
self._mailbox = mailbox
@m.output()
def RC_tx_open(self):
assert self._mailbox
self._RC.tx_open(self._mailbox)
@m.output()
def queue(self, phase, body):
assert isinstance(phase, type("")), type(phase)
assert isinstance(body, type(b"")), (type(body), phase, body)
self._pending_outbound[phase] = body
@m.output()
def record_mailbox_and_RC_tx_open_and_drain(self, mailbox):
self._mailbox = mailbox
self._RC.tx_open(mailbox)
self._drain()
@m.output()
def drain(self):
self._drain()
def _drain(self):
for phase, body in self._pending_outbound.items():
self._RC.tx_add(phase, body)
@m.output()
def RC_tx_add(self, phase, body):
assert isinstance(phase, type("")), type(phase)
assert isinstance(body, type(b"")), type(body)
self._RC.tx_add(phase, body)
@m.output()
def N_release_and_accept(self, side, phase, body):
self._N.release()
if phase not in self._processed:
self._processed.add(phase)
self._O.got_message(side, phase, body)
@m.output()
def RC_tx_close(self):
assert self._mood
self._RC_tx_close()
def _RC_tx_close(self):
self._RC.tx_close(self._mailbox, self._mood)
@m.output()
def dequeue(self, phase, body):
self._pending_outbound.pop(phase, None)
@m.output()
def record_mood(self, mood):
self._mood = mood
@m.output()
def record_mood_and_RC_tx_close(self, mood):
self._mood = mood
self._RC_tx_close()
@m.output()
def ignore_mood_and_T_mailbox_done(self, mood):
self._T.mailbox_done()
@m.output()
def T_mailbox_done(self):
self._T.mailbox_done()
@ -162,8 +200,10 @@ class Mailbox(object):
S0B.upon(lost, enter=S0A, outputs=[])
S0B.upon(add_message, enter=S0B, outputs=[queue])
S0B.upon(close, enter=S4B, outputs=[ignore_mood_and_T_mailbox_done])
S0B.upon(got_mailbox, enter=S2B,
outputs=[record_mailbox_and_RC_tx_open_and_drain])
S0B.upon(
got_mailbox,
enter=S2B,
outputs=[record_mailbox_and_RC_tx_open_and_drain])
S1A.upon(connected, enter=S2B, outputs=[RC_tx_open, drain])
S1A.upon(add_message, enter=S1A, outputs=[queue])
@ -192,4 +232,3 @@ class Mailbox(object):
S4.upon(rx_message_theirs, enter=S4, outputs=[])
S4.upon(rx_message_ours, enter=S4, outputs=[])
S4.upon(close, enter=S4, outputs=[])

View File

@ -1,7 +1,10 @@
from __future__ import print_function, absolute_import, unicode_literals
from __future__ import absolute_import, print_function, unicode_literals
import re
from zope.interface import implementer
from automat import MethodicalMachine
from zope.interface import implementer
from . import _interfaces
from ._wordlist import PGPWordList
from .errors import KeyFormatError
@ -9,13 +12,15 @@ from .errors import KeyFormatError
def validate_nameplate(nameplate):
if not re.search(r'^\d+$', nameplate):
raise KeyFormatError("Nameplate '%s' must be numeric, with no spaces."
% nameplate)
raise KeyFormatError(
"Nameplate '%s' must be numeric, with no spaces." % nameplate)
@implementer(_interfaces.INameplate)
class Nameplate(object):
m = MethodicalMachine()
set_trace = getattr(m, "_setTrace", lambda self, f: None) # pragma: no cover
set_trace = getattr(m, "_setTrace",
lambda self, f: None) # pragma: no cover
def __init__(self):
self._nameplate = None
@ -32,94 +37,125 @@ class Nameplate(object):
# S0: know nothing
@m.state(initial=True)
def S0A(self): pass # pragma: no cover
def S0A(self):
pass # pragma: no cover
@m.state()
def S0B(self): pass # pragma: no cover
def S0B(self):
pass # pragma: no cover
# S1: nameplate known, never claimed
@m.state()
def S1A(self): pass # pragma: no cover
def S1A(self):
pass # pragma: no cover
# S2: nameplate known, maybe claimed
@m.state()
def S2A(self): pass # pragma: no cover
def S2A(self):
pass # pragma: no cover
@m.state()
def S2B(self): pass # pragma: no cover
def S2B(self):
pass # pragma: no cover
# S3: nameplate claimed
@m.state()
def S3A(self): pass # pragma: no cover
def S3A(self):
pass # pragma: no cover
@m.state()
def S3B(self): pass # pragma: no cover
def S3B(self):
pass # pragma: no cover
# S4: maybe released
@m.state()
def S4A(self): pass # pragma: no cover
def S4A(self):
pass # pragma: no cover
@m.state()
def S4B(self): pass # pragma: no cover
def S4B(self):
pass # pragma: no cover
# S5: released
# we no longer care whether we're connected or not
#@m.state()
#def S5A(self): pass
#@m.state()
#def S5B(self): pass
# @m.state()
# def S5A(self): pass
# @m.state()
# def S5B(self): pass
@m.state()
def S5(self): pass # pragma: no cover
def S5(self):
pass # pragma: no cover
S5A = S5
S5B = S5
# from Boss
def set_nameplate(self, nameplate):
validate_nameplate(nameplate) # can raise KeyFormatError
validate_nameplate(nameplate) # can raise KeyFormatError
self._set_nameplate(nameplate)
@m.input()
def _set_nameplate(self, nameplate): pass
def _set_nameplate(self, nameplate):
pass
# from Mailbox
@m.input()
def release(self): pass
def release(self):
pass
# from Terminator
@m.input()
def close(self): pass
def close(self):
pass
# from RendezvousConnector
@m.input()
def connected(self): pass
@m.input()
def lost(self): pass
def connected(self):
pass
@m.input()
def rx_claimed(self, mailbox): pass
def lost(self):
pass
@m.input()
def rx_released(self): pass
def rx_claimed(self, mailbox):
pass
@m.input()
def rx_released(self):
pass
@m.output()
def record_nameplate(self, nameplate):
validate_nameplate(nameplate)
self._nameplate = nameplate
@m.output()
def record_nameplate_and_RC_tx_claim(self, nameplate):
validate_nameplate(nameplate)
self._nameplate = nameplate
self._RC.tx_claim(self._nameplate)
@m.output()
def RC_tx_claim(self):
# when invoked via M.connected(), we must use the stored nameplate
self._RC.tx_claim(self._nameplate)
@m.output()
def I_got_wordlist(self, mailbox):
# TODO select wordlist based on nameplate properties, in rx_claimed
wordlist = PGPWordList()
self._I.got_wordlist(wordlist)
@m.output()
def M_got_mailbox(self, mailbox):
self._M.got_mailbox(mailbox)
@m.output()
def RC_tx_release(self):
assert self._nameplate
self._RC.tx_release(self._nameplate)
@m.output()
def T_nameplate_done(self):
self._T.nameplate_done()
@ -127,8 +163,8 @@ class Nameplate(object):
S0A.upon(_set_nameplate, enter=S1A, outputs=[record_nameplate])
S0A.upon(connected, enter=S0B, outputs=[])
S0A.upon(close, enter=S5A, outputs=[T_nameplate_done])
S0B.upon(_set_nameplate, enter=S2B,
outputs=[record_nameplate_and_RC_tx_claim])
S0B.upon(
_set_nameplate, enter=S2B, outputs=[record_nameplate_and_RC_tx_claim])
S0B.upon(lost, enter=S0A, outputs=[])
S0B.upon(close, enter=S5A, outputs=[T_nameplate_done])
@ -144,7 +180,7 @@ class Nameplate(object):
S3A.upon(connected, enter=S3B, outputs=[])
S3A.upon(close, enter=S4A, outputs=[])
S3B.upon(lost, enter=S3A, outputs=[])
#S3B.upon(rx_claimed, enter=S3B, outputs=[]) # shouldn't happen
# S3B.upon(rx_claimed, enter=S3B, outputs=[]) # shouldn't happen
S3B.upon(release, enter=S4B, outputs=[RC_tx_release])
S3B.upon(close, enter=S4B, outputs=[RC_tx_release])
@ -153,7 +189,7 @@ class Nameplate(object):
S4B.upon(lost, enter=S4A, outputs=[])
S4B.upon(rx_claimed, enter=S4B, outputs=[])
S4B.upon(rx_released, enter=S5B, outputs=[T_nameplate_done])
S4B.upon(release, enter=S4B, outputs=[]) # mailbox is lazy
S4B.upon(release, enter=S4B, outputs=[]) # mailbox is lazy
# Mailbox doesn't remember how many times it's sent a release, and will
# re-send a new one for each peer message it receives. Ignoring it here
# is easier than adding a new pair of states to Mailbox.
@ -161,5 +197,5 @@ class Nameplate(object):
S5A.upon(connected, enter=S5B, outputs=[])
S5B.upon(lost, enter=S5A, outputs=[])
S5.upon(release, enter=S5, outputs=[]) # mailbox is lazy
S5.upon(release, enter=S5, outputs=[]) # mailbox is lazy
S5.upon(close, enter=S5, outputs=[])

View File

@ -1,32 +1,40 @@
from __future__ import print_function, absolute_import, unicode_literals
from zope.interface import implementer
from attr import attrs, attrib
from attr.validators import provides, instance_of
from __future__ import absolute_import, print_function, unicode_literals
from attr import attrib, attrs
from attr.validators import instance_of, provides
from automat import MethodicalMachine
from zope.interface import implementer
from . import _interfaces
@attrs
@implementer(_interfaces.IOrder)
class Order(object):
_side = attrib(validator=instance_of(type(u"")))
_timing = attrib(validator=provides(_interfaces.ITiming))
m = MethodicalMachine()
set_trace = getattr(m, "_setTrace", lambda self, f: None) # pragma: no cover
set_trace = getattr(m, "_setTrace",
lambda self, f: None) # pragma: no cover
def __attrs_post_init__(self):
self._key = None
self._queue = []
def wire(self, key, receive):
self._K = _interfaces.IKey(key)
self._R = _interfaces.IReceive(receive)
@m.state(initial=True)
def S0_no_pake(self): pass # pragma: no cover
def S0_no_pake(self):
pass # pragma: no cover
@m.state(terminal=True)
def S1_yes_pake(self): pass # pragma: no cover
def S1_yes_pake(self):
pass # pragma: no cover
def got_message(self, side, phase, body):
#print("ORDER[%s].got_message(%s)" % (self._side, phase))
# print("ORDER[%s].got_message(%s)" % (self._side, phase))
assert isinstance(side, type("")), type(phase)
assert isinstance(phase, type("")), type(phase)
assert isinstance(body, type(b"")), type(body)
@ -36,9 +44,12 @@ class Order(object):
self.got_non_pake(side, phase, body)
@m.input()
def got_pake(self, side, phase, body): pass
def got_pake(self, side, phase, body):
pass
@m.input()
def got_non_pake(self, side, phase, body): pass
def got_non_pake(self, side, phase, body):
pass
@m.output()
def queue(self, side, phase, body):
@ -46,9 +57,11 @@ class Order(object):
assert isinstance(phase, type("")), type(phase)
assert isinstance(body, type(b"")), type(body)
self._queue.append((side, phase, body))
@m.output()
def notify_key(self, side, phase, body):
self._K.got_pake(body)
@m.output()
def drain(self, side, phase, body):
del phase
@ -56,6 +69,7 @@ class Order(object):
for (side, phase, body) in self._queue:
self._deliver(side, phase, body)
self._queue[:] = []
@m.output()
def deliver(self, side, phase, body):
self._deliver(side, phase, body)

View File

@ -1,10 +1,13 @@
from __future__ import print_function, absolute_import, unicode_literals
from zope.interface import implementer
from attr import attrs, attrib
from attr.validators import provides, instance_of
from __future__ import absolute_import, print_function, unicode_literals
from attr import attrib, attrs
from attr.validators import instance_of, provides
from automat import MethodicalMachine
from zope.interface import implementer
from . import _interfaces
from ._key import derive_key, derive_phase_key, decrypt_data, CryptoError
from ._key import CryptoError, decrypt_data, derive_key, derive_phase_key
@attrs
@implementer(_interfaces.IReceive)
@ -12,7 +15,8 @@ class Receive(object):
_side = attrib(validator=instance_of(type(u"")))
_timing = attrib(validator=provides(_interfaces.ITiming))
m = MethodicalMachine()
set_trace = getattr(m, "_setTrace", lambda self, f: None) # pragma: no cover
set_trace = getattr(m, "_setTrace",
lambda self, f: None) # pragma: no cover
def __attrs_post_init__(self):
self._key = None
@ -22,13 +26,20 @@ class Receive(object):
self._S = _interfaces.ISend(send)
@m.state(initial=True)
def S0_unknown_key(self): pass # pragma: no cover
def S0_unknown_key(self):
pass # pragma: no cover
@m.state()
def S1_unverified_key(self): pass # pragma: no cover
def S1_unverified_key(self):
pass # pragma: no cover
@m.state()
def S2_verified_key(self): pass # pragma: no cover
def S2_verified_key(self):
pass # pragma: no cover
@m.state(terminal=True)
def S3_scared(self): pass # pragma: no cover
def S3_scared(self):
pass # pragma: no cover
# from Ordering
def got_message(self, side, phase, body):
@ -43,47 +54,56 @@ class Receive(object):
self.got_message_bad()
return
self.got_message_good(phase, plaintext)
@m.input()
def got_message_good(self, phase, plaintext): pass
def got_message_good(self, phase, plaintext):
pass
@m.input()
def got_message_bad(self): pass
def got_message_bad(self):
pass
# from Key
@m.input()
def got_key(self, key): pass
def got_key(self, key):
pass
@m.output()
def record_key(self, key):
self._key = key
@m.output()
def S_got_verified_key(self, phase, plaintext):
assert self._key
self._S.got_verified_key(self._key)
@m.output()
def W_happy(self, phase, plaintext):
self._B.happy()
@m.output()
def W_got_verifier(self, phase, plaintext):
self._B.got_verifier(derive_key(self._key, b"wormhole:verifier"))
@m.output()
def W_got_message(self, phase, plaintext):
assert isinstance(phase, type("")), type(phase)
assert isinstance(plaintext, type(b"")), type(plaintext)
self._B.got_message(phase, plaintext)
@m.output()
def W_scared(self):
self._B.scared()
S0_unknown_key.upon(got_key, enter=S1_unverified_key, outputs=[record_key])
S1_unverified_key.upon(got_message_good, enter=S2_verified_key,
outputs=[S_got_verified_key,
W_happy, W_got_verifier, W_got_message])
S1_unverified_key.upon(got_message_bad, enter=S3_scared,
outputs=[W_scared])
S2_verified_key.upon(got_message_bad, enter=S3_scared,
outputs=[W_scared])
S2_verified_key.upon(got_message_good, enter=S2_verified_key,
outputs=[W_got_message])
S1_unverified_key.upon(
got_message_good,
enter=S2_verified_key,
outputs=[S_got_verified_key, W_happy, W_got_verifier, W_got_message])
S1_unverified_key.upon(
got_message_bad, enter=S3_scared, outputs=[W_scared])
S2_verified_key.upon(got_message_bad, enter=S3_scared, outputs=[W_scared])
S2_verified_key.upon(
got_message_good, enter=S2_verified_key, outputs=[W_got_message])
S3_scared.upon(got_message_good, enter=S3_scared, outputs=[])
S3_scared.upon(got_message_bad, enter=S3_scared, outputs=[])

View File

@ -9,44 +9,47 @@ from twisted.internet import defer, endpoints, task
from twisted.application import internet
from autobahn.twisted import websocket
from . import _interfaces, errors
from .util import (bytes_to_hexstr, hexstr_to_bytes,
bytes_to_dict, dict_to_bytes)
from .util import (bytes_to_hexstr, hexstr_to_bytes, bytes_to_dict,
dict_to_bytes)
class WSClient(websocket.WebSocketClientProtocol):
def onConnect(self, response):
# this fires during WebSocket negotiation, and isn't very useful
# unless you want to modify the protocol settings
#print("onConnect", response)
# print("onConnect", response)
pass
def onOpen(self, *args):
# this fires when the WebSocket is ready to go. No arguments
#print("onOpen", args)
#self.wormhole_open = True
# print("onOpen", args)
# self.wormhole_open = True
self._RC.ws_open(self)
def onMessage(self, payload, isBinary):
assert not isBinary
try:
self._RC.ws_message(payload)
except:
except Exception:
from twisted.python.failure import Failure
print("LOGGING", Failure())
log.err()
raise
def onClose(self, wasClean, code, reason):
#print("onClose")
# print("onClose")
self._RC.ws_close(wasClean, code, reason)
#if self.wormhole_open:
# self.wormhole._ws_closed(wasClean, code, reason)
#else:
# # we closed before establishing a connection (onConnect) or
# # finishing WebSocket negotiation (onOpen): errback
# self.factory.d.errback(error.ConnectError(reason))
# if self.wormhole_open:
# self.wormhole._ws_closed(wasClean, code, reason)
# else:
# # we closed before establishing a connection (onConnect) or
# # finishing WebSocket negotiation (onOpen): errback
# self.factory.d.errback(error.ConnectError(reason))
class WSFactory(websocket.WebSocketClientFactory):
protocol = WSClient
def __init__(self, RC, *args, **kwargs):
websocket.WebSocketClientFactory.__init__(self, *args, **kwargs)
self._RC = RC
@ -54,9 +57,10 @@ class WSFactory(websocket.WebSocketClientFactory):
def buildProtocol(self, addr):
proto = websocket.WebSocketClientFactory.buildProtocol(self, addr)
proto._RC = self._RC
#proto.wormhole_open = False
# proto.wormhole_open = False
return proto
@attrs
@implementer(_interfaces.IRendezvousConnector)
class RendezvousConnector(object):
@ -90,6 +94,7 @@ class RendezvousConnector(object):
def set_trace(self, f):
self._trace = f
def _debug(self, what):
if self._trace:
self._trace(old_state="", input=what, new_state="")
@ -133,14 +138,13 @@ class RendezvousConnector(object):
def stop(self):
# ClientService.stopService is defined to "Stop attempting to
# reconnect and close any existing connections"
self._stopping = True # to catch _initial_connection_failed error
self._stopping = True # to catch _initial_connection_failed error
d = defer.maybeDeferred(self._connector.stopService)
# ClientService.stopService always fires with None, even if the
# initial connection failed, so log.err just in case
d.addErrback(log.err)
d.addBoth(self._stopped)
# from Lister
def tx_list(self):
self._tx("list")
@ -157,7 +161,7 @@ class RendezvousConnector(object):
# this should happen right away: the ClientService ought to be in
# the "_waiting" state, and everything in the _waiting.stop
# transition is immediate
d.addErrback(log.err) # just in case something goes wrong
d.addErrback(log.err) # just in case something goes wrong
d.addCallback(lambda _: self._B.error(sce))
# from our WSClient (the WebSocket protocol)
@ -166,8 +170,11 @@ class RendezvousConnector(object):
self._have_made_a_successful_connection = True
self._ws = proto
try:
self._tx("bind", appid=self._appid, side=self._side,
client_version=self._client_version)
self._tx(
"bind",
appid=self._appid,
side=self._side,
client_version=self._client_version)
self._N.connected()
self._M.connected()
self._L.connected()
@ -180,19 +187,22 @@ class RendezvousConnector(object):
def ws_message(self, payload):
msg = bytes_to_dict(payload)
if msg["type"] != "ack":
self._debug("R.rx(%s %s%s)" %
(msg["type"], msg.get("phase",""),
"[mine]" if msg.get("side","") == self._side else "",
))
self._debug("R.rx(%s %s%s)" % (
msg["type"],
msg.get("phase", ""),
"[mine]" if msg.get("side", "") == self._side else "",
))
self._timing.add("ws_receive", _side=self._side, message=msg)
if self._debug_record_inbound_f:
self._debug_record_inbound_f(msg)
mtype = msg["type"]
meth = getattr(self, "_response_handle_"+mtype, None)
meth = getattr(self, "_response_handle_" + mtype, None)
if not meth:
# make tests fail, but real application will ignore it
log.err(errors._UnknownMessageTypeError("Unknown inbound message type %r" % (msg,)))
log.err(
errors._UnknownMessageTypeError(
"Unknown inbound message type %r" % (msg, )))
return
try:
return meth(msg)
@ -229,7 +239,7 @@ class RendezvousConnector(object):
# valid connection
sce = errors.ServerConnectionError(self._url, reason)
d = defer.maybeDeferred(self._connector.stopService)
d.addErrback(log.err) # just in case something goes wrong
d.addErrback(log.err) # just in case something goes wrong
# tell the Boss to quit and inform the user
d.addCallback(lambda _: self._B.error(sce))
@ -292,7 +302,7 @@ class RendezvousConnector(object):
side = msg["side"]
phase = msg["phase"]
assert isinstance(phase, type("")), type(phase)
body = hexstr_to_bytes(msg["body"]) # bytes
body = hexstr_to_bytes(msg["body"]) # bytes
self._M.rx_message(side, phase, body)
def _response_handle_released(self, msg):
@ -301,5 +311,4 @@ class RendezvousConnector(object):
def _response_handle_closed(self, msg):
self._M.rx_closed()
# record, message, payload, packet, bundle, ciphertext, plaintext

View File

@ -1,33 +1,41 @@
from __future__ import print_function, unicode_literals
import traceback
from sys import stderr
from attr import attrib, attrs
from six.moves import input
from twisted.internet.defer import inlineCallbacks, returnValue
from twisted.internet.threads import blockingCallFromThread, deferToThread
from .errors import AlreadyInputNameplateError, KeyFormatError
try:
import readline
except ImportError:
readline = None
from six.moves import input
from attr import attrs, attrib
from twisted.internet.defer import inlineCallbacks, returnValue
from twisted.internet.threads import deferToThread, blockingCallFromThread
from .errors import KeyFormatError, AlreadyInputNameplateError
errf = None
# uncomment this to enable tab-completion debugging
#import os ; errf = open("err", "w") if os.path.exists("err") else None
def debug(*args, **kwargs): # pragma: no cover
# import os ; errf = open("err", "w") if os.path.exists("err") else None
def debug(*args, **kwargs): # pragma: no cover
if errf:
print(*args, file=errf, **kwargs)
errf.flush()
@attrs
class CodeInputter(object):
_input_helper = attrib()
_reactor = attrib()
def __attrs_post_init__(self):
self.used_completion = False
self._matches = None
# once we've claimed the nameplate, we can't go back
self._committed_nameplate = None # or string
self._committed_nameplate = None # or string
def bcft(self, f, *a, **kw):
return blockingCallFromThread(self._reactor, f, *a, **kw)
@ -66,7 +74,7 @@ class CodeInputter(object):
nameplate, words = text.split("-", 1)
else:
got_nameplate = False
nameplate = text # partial
nameplate = text # partial
# 'text' is one of these categories:
# "" or "12": complete on nameplates (all that match, maybe just one)
@ -83,13 +91,15 @@ class CodeInputter(object):
# they deleted past the committment point: we can't use
# this. For now, bail, but in the future let's find a
# gentler way to encourage them to not do that.
raise AlreadyInputNameplateError("nameplate (%s-) already entered, cannot go back" % self._committed_nameplate)
raise AlreadyInputNameplateError(
"nameplate (%s-) already entered, cannot go back" %
self._committed_nameplate)
if not got_nameplate:
# we're completing on nameplates: "" or "12" or "123"
self.bcft(ih.refresh_nameplates) # results arrive later
self.bcft(ih.refresh_nameplates) # results arrive later
debug(" getting nameplates")
completions = self.bcft(ih.get_nameplate_completions, nameplate)
else: # "123-" or "123-supp"
else: # "123-" or "123-supp"
# time to commit to this nameplate, if they haven't already
if not self._committed_nameplate:
debug(" choose_nameplate(%s)" % nameplate)
@ -112,11 +122,13 @@ class CodeInputter(object):
# heard about it from the server), it can't be helped. But
# for the rest of the code, a simple wait-for-wordlist will
# improve the user experience.
self.bcft(ih.when_wordlist_is_available) # blocks on CLAIM
self.bcft(ih.when_wordlist_is_available) # blocks on CLAIM
# and we're completing on words now
debug(" getting words (%s)" % (words,))
completions = [nameplate+"-"+c
for c in self.bcft(ih.get_word_completions, words)]
debug(" getting words (%s)" % (words, ))
completions = [
nameplate + "-" + c
for c in self.bcft(ih.get_word_completions, words)
]
# rlcompleter wants full strings
return sorted(completions)
@ -131,13 +143,16 @@ class CodeInputter(object):
# they deleted past the committment point: we can't use
# this. For now, bail, but in the future let's find a
# gentler way to encourage them to not do that.
raise AlreadyInputNameplateError("nameplate (%s-) already entered, cannot go back" % self._committed_nameplate)
raise AlreadyInputNameplateError(
"nameplate (%s-) already entered, cannot go back" %
self._committed_nameplate)
else:
debug(" choose_nameplate(%s)" % nameplate)
self.bcft(self._input_helper.choose_nameplate, nameplate)
debug(" choose_words(%s)" % words)
self.bcft(self._input_helper.choose_words, words)
def _input_code_with_completion(prompt, input_helper, reactor):
# reminder: this all occurs in a separate thread. All calls to input_helper
# must go through blockingCallFromThread()
@ -159,6 +174,7 @@ def _input_code_with_completion(prompt, input_helper, reactor):
c.finish(code)
return c.used_completion
def warn_readline():
# When our process receives a SIGINT, Twisted's SIGINT handler will
# stop the reactor and wait for all threads to terminate before the
@ -192,11 +208,12 @@ def warn_readline():
# doesn't see the signal, and we must still wait for stdin to make
# readline finish.
@inlineCallbacks
def input_with_completion(prompt, input_helper, reactor):
t = reactor.addSystemEventTrigger("before", "shutdown", warn_readline)
#input_helper.refresh_nameplates()
used_completion = yield deferToThread(_input_code_with_completion,
prompt, input_helper, reactor)
# input_helper.refresh_nameplates()
used_completion = yield deferToThread(_input_code_with_completion, prompt,
input_helper, reactor)
reactor.removeSystemEventTrigger(t)
returnValue(used_completion)

View File

@ -1,18 +1,22 @@
from __future__ import print_function, absolute_import, unicode_literals
from attr import attrs, attrib
from attr.validators import provides, instance_of
from zope.interface import implementer
from __future__ import absolute_import, print_function, unicode_literals
from attr import attrib, attrs
from attr.validators import instance_of, provides
from automat import MethodicalMachine
from zope.interface import implementer
from . import _interfaces
from ._key import derive_phase_key, encrypt_data
@attrs
@implementer(_interfaces.ISend)
class Send(object):
_side = attrib(validator=instance_of(type(u"")))
_timing = attrib(validator=provides(_interfaces.ITiming))
m = MethodicalMachine()
set_trace = getattr(m, "_setTrace", lambda self, f: None) # pragma: no cover
set_trace = getattr(m, "_setTrace",
lambda self, f: None) # pragma: no cover
def __attrs_post_init__(self):
self._queue = []
@ -21,31 +25,40 @@ class Send(object):
self._M = _interfaces.IMailbox(mailbox)
@m.state(initial=True)
def S0_no_key(self): pass # pragma: no cover
def S0_no_key(self):
pass # pragma: no cover
@m.state(terminal=True)
def S1_verified_key(self): pass # pragma: no cover
def S1_verified_key(self):
pass # pragma: no cover
# from Receive
@m.input()
def got_verified_key(self, key): pass
def got_verified_key(self, key):
pass
# from Boss
@m.input()
def send(self, phase, plaintext): pass
def send(self, phase, plaintext):
pass
@m.output()
def queue(self, phase, plaintext):
assert isinstance(phase, type("")), type(phase)
assert isinstance(plaintext, type(b"")), type(plaintext)
self._queue.append((phase, plaintext))
@m.output()
def record_key(self, key):
self._key = key
@m.output()
def drain(self, key):
del key
for (phase, plaintext) in self._queue:
self._encrypt_and_send(phase, plaintext)
self._queue[:] = []
@m.output()
def deliver(self, phase, plaintext):
assert isinstance(phase, type("")), type(phase)
@ -59,6 +72,6 @@ class Send(object):
self._M.add_message(phase, encrypted)
S0_no_key.upon(send, enter=S0_no_key, outputs=[queue])
S0_no_key.upon(got_verified_key, enter=S1_verified_key,
outputs=[record_key, drain])
S0_no_key.upon(
got_verified_key, enter=S1_verified_key, outputs=[record_key, drain])
S1_verified_key.upon(send, enter=S1_verified_key, outputs=[deliver])

View File

@ -1,12 +1,16 @@
from __future__ import print_function, absolute_import, unicode_literals
from zope.interface import implementer
from __future__ import absolute_import, print_function, unicode_literals
from automat import MethodicalMachine
from zope.interface import implementer
from . import _interfaces
@implementer(_interfaces.ITerminator)
class Terminator(object):
m = MethodicalMachine()
set_trace = getattr(m, "_setTrace", lambda self, f: None) # pragma: no cover
set_trace = getattr(m, "_setTrace",
lambda self, f: None) # pragma: no cover
def __init__(self):
self._mood = None
@ -29,48 +33,68 @@ class Terminator(object):
# done, and we're closing, then we stop the RendezvousConnector
@m.state(initial=True)
def Snmo(self): pass # pragma: no cover
@m.state()
def Smo(self): pass # pragma: no cover
@m.state()
def Sno(self): pass # pragma: no cover
@m.state()
def S0o(self): pass # pragma: no cover
def Snmo(self):
pass # pragma: no cover
@m.state()
def Snm(self): pass # pragma: no cover
@m.state()
def Sm(self): pass # pragma: no cover
@m.state()
def Sn(self): pass # pragma: no cover
#@m.state()
#def S0(self): pass # unused
def Smo(self):
pass # pragma: no cover
@m.state()
def S_stopping(self): pass # pragma: no cover
def Sno(self):
pass # pragma: no cover
@m.state()
def S_stopped(self, terminal=True): pass # pragma: no cover
def S0o(self):
pass # pragma: no cover
@m.state()
def Snm(self):
pass # pragma: no cover
@m.state()
def Sm(self):
pass # pragma: no cover
@m.state()
def Sn(self):
pass # pragma: no cover
# @m.state()
# def S0(self): pass # unused
@m.state()
def S_stopping(self):
pass # pragma: no cover
@m.state()
def S_stopped(self, terminal=True):
pass # pragma: no cover
# from Boss
@m.input()
def close(self, mood): pass
def close(self, mood):
pass
# from Nameplate
@m.input()
def nameplate_done(self): pass
def nameplate_done(self):
pass
# from Mailbox
@m.input()
def mailbox_done(self): pass
def mailbox_done(self):
pass
# from RendezvousConnector
@m.input()
def stopped(self): pass
def stopped(self):
pass
@m.output()
def close_nameplate(self, mood):
self._N.close() # ignores mood
self._N.close() # ignores mood
@m.output()
def close_mailbox(self, mood):
self._M.close(mood)
@ -78,9 +102,11 @@ class Terminator(object):
@m.output()
def ignore_mood_and_RC_stop(self, mood):
self._RC.stop()
@m.output()
def RC_stop(self):
self._RC.stop()
@m.output()
def B_closed(self):
self._B.closed()
@ -99,8 +125,10 @@ class Terminator(object):
Snm.upon(nameplate_done, enter=Sm, outputs=[])
Sn.upon(nameplate_done, enter=S_stopping, outputs=[RC_stop])
S0o.upon(close, enter=S_stopping,
outputs=[close_nameplate, close_mailbox, ignore_mood_and_RC_stop])
S0o.upon(
close,
enter=S_stopping,
outputs=[close_nameplate, close_mailbox, ignore_mood_and_RC_stop])
Sm.upon(mailbox_done, enter=S_stopping, outputs=[RC_stop])
S_stopping.upon(stopped, enter=S_stopped, outputs=[B_closed])

View File

@ -1,6 +1,10 @@
from __future__ import unicode_literals, print_function
from __future__ import print_function, unicode_literals
import os
from binascii import unhexlify
from zope.interface import implementer
from ._interfaces import IWordlist
# The PGP Word List, which maps bytes to phonetically-distinct words. There
@ -10,154 +14,280 @@ from ._interfaces import IWordlist
# Thanks to Warren Guy for transcribing them:
# https://github.com/warrenguy/javascript-pgp-word-list
from binascii import unhexlify
raw_words = {
'00': ['aardvark', 'adroitness'], '01': ['absurd', 'adviser'],
'02': ['accrue', 'aftermath'], '03': ['acme', 'aggregate'],
'04': ['adrift', 'alkali'], '05': ['adult', 'almighty'],
'06': ['afflict', 'amulet'], '07': ['ahead', 'amusement'],
'08': ['aimless', 'antenna'], '09': ['Algol', 'applicant'],
'0A': ['allow', 'Apollo'], '0B': ['alone', 'armistice'],
'0C': ['ammo', 'article'], '0D': ['ancient', 'asteroid'],
'0E': ['apple', 'Atlantic'], '0F': ['artist', 'atmosphere'],
'10': ['assume', 'autopsy'], '11': ['Athens', 'Babylon'],
'12': ['atlas', 'backwater'], '13': ['Aztec', 'barbecue'],
'14': ['baboon', 'belowground'], '15': ['backfield', 'bifocals'],
'16': ['backward', 'bodyguard'], '17': ['banjo', 'bookseller'],
'18': ['beaming', 'borderline'], '19': ['bedlamp', 'bottomless'],
'1A': ['beehive', 'Bradbury'], '1B': ['beeswax', 'bravado'],
'1C': ['befriend', 'Brazilian'], '1D': ['Belfast', 'breakaway'],
'1E': ['berserk', 'Burlington'], '1F': ['billiard', 'businessman'],
'20': ['bison', 'butterfat'], '21': ['blackjack', 'Camelot'],
'22': ['blockade', 'candidate'], '23': ['blowtorch', 'cannonball'],
'24': ['bluebird', 'Capricorn'], '25': ['bombast', 'caravan'],
'26': ['bookshelf', 'caretaker'], '27': ['brackish', 'celebrate'],
'28': ['breadline', 'cellulose'], '29': ['breakup', 'certify'],
'2A': ['brickyard', 'chambermaid'], '2B': ['briefcase', 'Cherokee'],
'2C': ['Burbank', 'Chicago'], '2D': ['button', 'clergyman'],
'2E': ['buzzard', 'coherence'], '2F': ['cement', 'combustion'],
'30': ['chairlift', 'commando'], '31': ['chatter', 'company'],
'32': ['checkup', 'component'], '33': ['chisel', 'concurrent'],
'34': ['choking', 'confidence'], '35': ['chopper', 'conformist'],
'36': ['Christmas', 'congregate'], '37': ['clamshell', 'consensus'],
'38': ['classic', 'consulting'], '39': ['classroom', 'corporate'],
'3A': ['cleanup', 'corrosion'], '3B': ['clockwork', 'councilman'],
'3C': ['cobra', 'crossover'], '3D': ['commence', 'crucifix'],
'3E': ['concert', 'cumbersome'], '3F': ['cowbell', 'customer'],
'40': ['crackdown', 'Dakota'], '41': ['cranky', 'decadence'],
'42': ['crowfoot', 'December'], '43': ['crucial', 'decimal'],
'44': ['crumpled', 'designing'], '45': ['crusade', 'detector'],
'46': ['cubic', 'detergent'], '47': ['dashboard', 'determine'],
'48': ['deadbolt', 'dictator'], '49': ['deckhand', 'dinosaur'],
'4A': ['dogsled', 'direction'], '4B': ['dragnet', 'disable'],
'4C': ['drainage', 'disbelief'], '4D': ['dreadful', 'disruptive'],
'4E': ['drifter', 'distortion'], '4F': ['dropper', 'document'],
'50': ['drumbeat', 'embezzle'], '51': ['drunken', 'enchanting'],
'52': ['Dupont', 'enrollment'], '53': ['dwelling', 'enterprise'],
'54': ['eating', 'equation'], '55': ['edict', 'equipment'],
'56': ['egghead', 'escapade'], '57': ['eightball', 'Eskimo'],
'58': ['endorse', 'everyday'], '59': ['endow', 'examine'],
'5A': ['enlist', 'existence'], '5B': ['erase', 'exodus'],
'5C': ['escape', 'fascinate'], '5D': ['exceed', 'filament'],
'5E': ['eyeglass', 'finicky'], '5F': ['eyetooth', 'forever'],
'60': ['facial', 'fortitude'], '61': ['fallout', 'frequency'],
'62': ['flagpole', 'gadgetry'], '63': ['flatfoot', 'Galveston'],
'64': ['flytrap', 'getaway'], '65': ['fracture', 'glossary'],
'66': ['framework', 'gossamer'], '67': ['freedom', 'graduate'],
'68': ['frighten', 'gravity'], '69': ['gazelle', 'guitarist'],
'6A': ['Geiger', 'hamburger'], '6B': ['glitter', 'Hamilton'],
'6C': ['glucose', 'handiwork'], '6D': ['goggles', 'hazardous'],
'6E': ['goldfish', 'headwaters'], '6F': ['gremlin', 'hemisphere'],
'70': ['guidance', 'hesitate'], '71': ['hamlet', 'hideaway'],
'72': ['highchair', 'holiness'], '73': ['hockey', 'hurricane'],
'74': ['indoors', 'hydraulic'], '75': ['indulge', 'impartial'],
'76': ['inverse', 'impetus'], '77': ['involve', 'inception'],
'78': ['island', 'indigo'], '79': ['jawbone', 'inertia'],
'7A': ['keyboard', 'infancy'], '7B': ['kickoff', 'inferno'],
'7C': ['kiwi', 'informant'], '7D': ['klaxon', 'insincere'],
'7E': ['locale', 'insurgent'], '7F': ['lockup', 'integrate'],
'80': ['merit', 'intention'], '81': ['minnow', 'inventive'],
'82': ['miser', 'Istanbul'], '83': ['Mohawk', 'Jamaica'],
'84': ['mural', 'Jupiter'], '85': ['music', 'leprosy'],
'86': ['necklace', 'letterhead'], '87': ['Neptune', 'liberty'],
'88': ['newborn', 'maritime'], '89': ['nightbird', 'matchmaker'],
'8A': ['Oakland', 'maverick'], '8B': ['obtuse', 'Medusa'],
'8C': ['offload', 'megaton'], '8D': ['optic', 'microscope'],
'8E': ['orca', 'microwave'], '8F': ['payday', 'midsummer'],
'90': ['peachy', 'millionaire'], '91': ['pheasant', 'miracle'],
'92': ['physique', 'misnomer'], '93': ['playhouse', 'molasses'],
'94': ['Pluto', 'molecule'], '95': ['preclude', 'Montana'],
'96': ['prefer', 'monument'], '97': ['preshrunk', 'mosquito'],
'98': ['printer', 'narrative'], '99': ['prowler', 'nebula'],
'9A': ['pupil', 'newsletter'], '9B': ['puppy', 'Norwegian'],
'9C': ['python', 'October'], '9D': ['quadrant', 'Ohio'],
'9E': ['quiver', 'onlooker'], '9F': ['quota', 'opulent'],
'A0': ['ragtime', 'Orlando'], 'A1': ['ratchet', 'outfielder'],
'A2': ['rebirth', 'Pacific'], 'A3': ['reform', 'pandemic'],
'A4': ['regain', 'Pandora'], 'A5': ['reindeer', 'paperweight'],
'A6': ['rematch', 'paragon'], 'A7': ['repay', 'paragraph'],
'A8': ['retouch', 'paramount'], 'A9': ['revenge', 'passenger'],
'AA': ['reward', 'pedigree'], 'AB': ['rhythm', 'Pegasus'],
'AC': ['ribcage', 'penetrate'], 'AD': ['ringbolt', 'perceptive'],
'AE': ['robust', 'performance'], 'AF': ['rocker', 'pharmacy'],
'B0': ['ruffled', 'phonetic'], 'B1': ['sailboat', 'photograph'],
'B2': ['sawdust', 'pioneer'], 'B3': ['scallion', 'pocketful'],
'B4': ['scenic', 'politeness'], 'B5': ['scorecard', 'positive'],
'B6': ['Scotland', 'potato'], 'B7': ['seabird', 'processor'],
'B8': ['select', 'provincial'], 'B9': ['sentence', 'proximate'],
'BA': ['shadow', 'puberty'], 'BB': ['shamrock', 'publisher'],
'BC': ['showgirl', 'pyramid'], 'BD': ['skullcap', 'quantity'],
'BE': ['skydive', 'racketeer'], 'BF': ['slingshot', 'rebellion'],
'C0': ['slowdown', 'recipe'], 'C1': ['snapline', 'recover'],
'C2': ['snapshot', 'repellent'], 'C3': ['snowcap', 'replica'],
'C4': ['snowslide', 'reproduce'], 'C5': ['solo', 'resistor'],
'C6': ['southward', 'responsive'], 'C7': ['soybean', 'retraction'],
'C8': ['spaniel', 'retrieval'], 'C9': ['spearhead', 'retrospect'],
'CA': ['spellbind', 'revenue'], 'CB': ['spheroid', 'revival'],
'CC': ['spigot', 'revolver'], 'CD': ['spindle', 'sandalwood'],
'CE': ['spyglass', 'sardonic'], 'CF': ['stagehand', 'Saturday'],
'D0': ['stagnate', 'savagery'], 'D1': ['stairway', 'scavenger'],
'D2': ['standard', 'sensation'], 'D3': ['stapler', 'sociable'],
'D4': ['steamship', 'souvenir'], 'D5': ['sterling', 'specialist'],
'D6': ['stockman', 'speculate'], 'D7': ['stopwatch', 'stethoscope'],
'D8': ['stormy', 'stupendous'], 'D9': ['sugar', 'supportive'],
'DA': ['surmount', 'surrender'], 'DB': ['suspense', 'suspicious'],
'DC': ['sweatband', 'sympathy'], 'DD': ['swelter', 'tambourine'],
'DE': ['tactics', 'telephone'], 'DF': ['talon', 'therapist'],
'E0': ['tapeworm', 'tobacco'], 'E1': ['tempest', 'tolerance'],
'E2': ['tiger', 'tomorrow'], 'E3': ['tissue', 'torpedo'],
'E4': ['tonic', 'tradition'], 'E5': ['topmost', 'travesty'],
'E6': ['tracker', 'trombonist'], 'E7': ['transit', 'truncated'],
'E8': ['trauma', 'typewriter'], 'E9': ['treadmill', 'ultimate'],
'EA': ['Trojan', 'undaunted'], 'EB': ['trouble', 'underfoot'],
'EC': ['tumor', 'unicorn'], 'ED': ['tunnel', 'unify'],
'EE': ['tycoon', 'universe'], 'EF': ['uncut', 'unravel'],
'F0': ['unearth', 'upcoming'], 'F1': ['unwind', 'vacancy'],
'F2': ['uproot', 'vagabond'], 'F3': ['upset', 'vertigo'],
'F4': ['upshot', 'Virginia'], 'F5': ['vapor', 'visitor'],
'F6': ['village', 'vocalist'], 'F7': ['virus', 'voyager'],
'F8': ['Vulcan', 'warranty'], 'F9': ['waffle', 'Waterloo'],
'FA': ['wallet', 'whimsical'], 'FB': ['watchword', 'Wichita'],
'FC': ['wayside', 'Wilmington'], 'FD': ['willow', 'Wyoming'],
'FE': ['woodlark', 'yesteryear'], 'FF': ['Zulu', 'Yucatan']
};
'00': ['aardvark', 'adroitness'],
'01': ['absurd', 'adviser'],
'02': ['accrue', 'aftermath'],
'03': ['acme', 'aggregate'],
'04': ['adrift', 'alkali'],
'05': ['adult', 'almighty'],
'06': ['afflict', 'amulet'],
'07': ['ahead', 'amusement'],
'08': ['aimless', 'antenna'],
'09': ['Algol', 'applicant'],
'0A': ['allow', 'Apollo'],
'0B': ['alone', 'armistice'],
'0C': ['ammo', 'article'],
'0D': ['ancient', 'asteroid'],
'0E': ['apple', 'Atlantic'],
'0F': ['artist', 'atmosphere'],
'10': ['assume', 'autopsy'],
'11': ['Athens', 'Babylon'],
'12': ['atlas', 'backwater'],
'13': ['Aztec', 'barbecue'],
'14': ['baboon', 'belowground'],
'15': ['backfield', 'bifocals'],
'16': ['backward', 'bodyguard'],
'17': ['banjo', 'bookseller'],
'18': ['beaming', 'borderline'],
'19': ['bedlamp', 'bottomless'],
'1A': ['beehive', 'Bradbury'],
'1B': ['beeswax', 'bravado'],
'1C': ['befriend', 'Brazilian'],
'1D': ['Belfast', 'breakaway'],
'1E': ['berserk', 'Burlington'],
'1F': ['billiard', 'businessman'],
'20': ['bison', 'butterfat'],
'21': ['blackjack', 'Camelot'],
'22': ['blockade', 'candidate'],
'23': ['blowtorch', 'cannonball'],
'24': ['bluebird', 'Capricorn'],
'25': ['bombast', 'caravan'],
'26': ['bookshelf', 'caretaker'],
'27': ['brackish', 'celebrate'],
'28': ['breadline', 'cellulose'],
'29': ['breakup', 'certify'],
'2A': ['brickyard', 'chambermaid'],
'2B': ['briefcase', 'Cherokee'],
'2C': ['Burbank', 'Chicago'],
'2D': ['button', 'clergyman'],
'2E': ['buzzard', 'coherence'],
'2F': ['cement', 'combustion'],
'30': ['chairlift', 'commando'],
'31': ['chatter', 'company'],
'32': ['checkup', 'component'],
'33': ['chisel', 'concurrent'],
'34': ['choking', 'confidence'],
'35': ['chopper', 'conformist'],
'36': ['Christmas', 'congregate'],
'37': ['clamshell', 'consensus'],
'38': ['classic', 'consulting'],
'39': ['classroom', 'corporate'],
'3A': ['cleanup', 'corrosion'],
'3B': ['clockwork', 'councilman'],
'3C': ['cobra', 'crossover'],
'3D': ['commence', 'crucifix'],
'3E': ['concert', 'cumbersome'],
'3F': ['cowbell', 'customer'],
'40': ['crackdown', 'Dakota'],
'41': ['cranky', 'decadence'],
'42': ['crowfoot', 'December'],
'43': ['crucial', 'decimal'],
'44': ['crumpled', 'designing'],
'45': ['crusade', 'detector'],
'46': ['cubic', 'detergent'],
'47': ['dashboard', 'determine'],
'48': ['deadbolt', 'dictator'],
'49': ['deckhand', 'dinosaur'],
'4A': ['dogsled', 'direction'],
'4B': ['dragnet', 'disable'],
'4C': ['drainage', 'disbelief'],
'4D': ['dreadful', 'disruptive'],
'4E': ['drifter', 'distortion'],
'4F': ['dropper', 'document'],
'50': ['drumbeat', 'embezzle'],
'51': ['drunken', 'enchanting'],
'52': ['Dupont', 'enrollment'],
'53': ['dwelling', 'enterprise'],
'54': ['eating', 'equation'],
'55': ['edict', 'equipment'],
'56': ['egghead', 'escapade'],
'57': ['eightball', 'Eskimo'],
'58': ['endorse', 'everyday'],
'59': ['endow', 'examine'],
'5A': ['enlist', 'existence'],
'5B': ['erase', 'exodus'],
'5C': ['escape', 'fascinate'],
'5D': ['exceed', 'filament'],
'5E': ['eyeglass', 'finicky'],
'5F': ['eyetooth', 'forever'],
'60': ['facial', 'fortitude'],
'61': ['fallout', 'frequency'],
'62': ['flagpole', 'gadgetry'],
'63': ['flatfoot', 'Galveston'],
'64': ['flytrap', 'getaway'],
'65': ['fracture', 'glossary'],
'66': ['framework', 'gossamer'],
'67': ['freedom', 'graduate'],
'68': ['frighten', 'gravity'],
'69': ['gazelle', 'guitarist'],
'6A': ['Geiger', 'hamburger'],
'6B': ['glitter', 'Hamilton'],
'6C': ['glucose', 'handiwork'],
'6D': ['goggles', 'hazardous'],
'6E': ['goldfish', 'headwaters'],
'6F': ['gremlin', 'hemisphere'],
'70': ['guidance', 'hesitate'],
'71': ['hamlet', 'hideaway'],
'72': ['highchair', 'holiness'],
'73': ['hockey', 'hurricane'],
'74': ['indoors', 'hydraulic'],
'75': ['indulge', 'impartial'],
'76': ['inverse', 'impetus'],
'77': ['involve', 'inception'],
'78': ['island', 'indigo'],
'79': ['jawbone', 'inertia'],
'7A': ['keyboard', 'infancy'],
'7B': ['kickoff', 'inferno'],
'7C': ['kiwi', 'informant'],
'7D': ['klaxon', 'insincere'],
'7E': ['locale', 'insurgent'],
'7F': ['lockup', 'integrate'],
'80': ['merit', 'intention'],
'81': ['minnow', 'inventive'],
'82': ['miser', 'Istanbul'],
'83': ['Mohawk', 'Jamaica'],
'84': ['mural', 'Jupiter'],
'85': ['music', 'leprosy'],
'86': ['necklace', 'letterhead'],
'87': ['Neptune', 'liberty'],
'88': ['newborn', 'maritime'],
'89': ['nightbird', 'matchmaker'],
'8A': ['Oakland', 'maverick'],
'8B': ['obtuse', 'Medusa'],
'8C': ['offload', 'megaton'],
'8D': ['optic', 'microscope'],
'8E': ['orca', 'microwave'],
'8F': ['payday', 'midsummer'],
'90': ['peachy', 'millionaire'],
'91': ['pheasant', 'miracle'],
'92': ['physique', 'misnomer'],
'93': ['playhouse', 'molasses'],
'94': ['Pluto', 'molecule'],
'95': ['preclude', 'Montana'],
'96': ['prefer', 'monument'],
'97': ['preshrunk', 'mosquito'],
'98': ['printer', 'narrative'],
'99': ['prowler', 'nebula'],
'9A': ['pupil', 'newsletter'],
'9B': ['puppy', 'Norwegian'],
'9C': ['python', 'October'],
'9D': ['quadrant', 'Ohio'],
'9E': ['quiver', 'onlooker'],
'9F': ['quota', 'opulent'],
'A0': ['ragtime', 'Orlando'],
'A1': ['ratchet', 'outfielder'],
'A2': ['rebirth', 'Pacific'],
'A3': ['reform', 'pandemic'],
'A4': ['regain', 'Pandora'],
'A5': ['reindeer', 'paperweight'],
'A6': ['rematch', 'paragon'],
'A7': ['repay', 'paragraph'],
'A8': ['retouch', 'paramount'],
'A9': ['revenge', 'passenger'],
'AA': ['reward', 'pedigree'],
'AB': ['rhythm', 'Pegasus'],
'AC': ['ribcage', 'penetrate'],
'AD': ['ringbolt', 'perceptive'],
'AE': ['robust', 'performance'],
'AF': ['rocker', 'pharmacy'],
'B0': ['ruffled', 'phonetic'],
'B1': ['sailboat', 'photograph'],
'B2': ['sawdust', 'pioneer'],
'B3': ['scallion', 'pocketful'],
'B4': ['scenic', 'politeness'],
'B5': ['scorecard', 'positive'],
'B6': ['Scotland', 'potato'],
'B7': ['seabird', 'processor'],
'B8': ['select', 'provincial'],
'B9': ['sentence', 'proximate'],
'BA': ['shadow', 'puberty'],
'BB': ['shamrock', 'publisher'],
'BC': ['showgirl', 'pyramid'],
'BD': ['skullcap', 'quantity'],
'BE': ['skydive', 'racketeer'],
'BF': ['slingshot', 'rebellion'],
'C0': ['slowdown', 'recipe'],
'C1': ['snapline', 'recover'],
'C2': ['snapshot', 'repellent'],
'C3': ['snowcap', 'replica'],
'C4': ['snowslide', 'reproduce'],
'C5': ['solo', 'resistor'],
'C6': ['southward', 'responsive'],
'C7': ['soybean', 'retraction'],
'C8': ['spaniel', 'retrieval'],
'C9': ['spearhead', 'retrospect'],
'CA': ['spellbind', 'revenue'],
'CB': ['spheroid', 'revival'],
'CC': ['spigot', 'revolver'],
'CD': ['spindle', 'sandalwood'],
'CE': ['spyglass', 'sardonic'],
'CF': ['stagehand', 'Saturday'],
'D0': ['stagnate', 'savagery'],
'D1': ['stairway', 'scavenger'],
'D2': ['standard', 'sensation'],
'D3': ['stapler', 'sociable'],
'D4': ['steamship', 'souvenir'],
'D5': ['sterling', 'specialist'],
'D6': ['stockman', 'speculate'],
'D7': ['stopwatch', 'stethoscope'],
'D8': ['stormy', 'stupendous'],
'D9': ['sugar', 'supportive'],
'DA': ['surmount', 'surrender'],
'DB': ['suspense', 'suspicious'],
'DC': ['sweatband', 'sympathy'],
'DD': ['swelter', 'tambourine'],
'DE': ['tactics', 'telephone'],
'DF': ['talon', 'therapist'],
'E0': ['tapeworm', 'tobacco'],
'E1': ['tempest', 'tolerance'],
'E2': ['tiger', 'tomorrow'],
'E3': ['tissue', 'torpedo'],
'E4': ['tonic', 'tradition'],
'E5': ['topmost', 'travesty'],
'E6': ['tracker', 'trombonist'],
'E7': ['transit', 'truncated'],
'E8': ['trauma', 'typewriter'],
'E9': ['treadmill', 'ultimate'],
'EA': ['Trojan', 'undaunted'],
'EB': ['trouble', 'underfoot'],
'EC': ['tumor', 'unicorn'],
'ED': ['tunnel', 'unify'],
'EE': ['tycoon', 'universe'],
'EF': ['uncut', 'unravel'],
'F0': ['unearth', 'upcoming'],
'F1': ['unwind', 'vacancy'],
'F2': ['uproot', 'vagabond'],
'F3': ['upset', 'vertigo'],
'F4': ['upshot', 'Virginia'],
'F5': ['vapor', 'visitor'],
'F6': ['village', 'vocalist'],
'F7': ['virus', 'voyager'],
'F8': ['Vulcan', 'warranty'],
'F9': ['waffle', 'Waterloo'],
'FA': ['wallet', 'whimsical'],
'FB': ['watchword', 'Wichita'],
'FC': ['wayside', 'Wilmington'],
'FD': ['willow', 'Wyoming'],
'FE': ['woodlark', 'yesteryear'],
'FF': ['Zulu', 'Yucatan']
}
byte_to_even_word = dict([(unhexlify(k.encode("ascii")), both_words[0])
for k,both_words
in raw_words.items()])
for k, both_words in raw_words.items()])
byte_to_odd_word = dict([(unhexlify(k.encode("ascii")), both_words[1])
for k,both_words
in raw_words.items()])
for k, both_words in raw_words.items()])
even_words_lowercase, odd_words_lowercase = set(), set()
for k,both_words in raw_words.items():
for k, both_words in raw_words.items():
even_word, odd_word = both_words
even_words_lowercase.add(even_word.lower())
odd_words_lowercase.add(odd_word.lower())
@implementer(IWordlist)
class PGPWordList(object):
def get_completions(self, prefix, num_words=2):
@ -177,7 +307,7 @@ class PGPWordList(object):
else:
suffix = prefix[:-lp] + word
# append a hyphen if we expect more words
if count+1 < num_words:
if count + 1 < num_words:
suffix += "-"
completions.add(suffix)
return completions

View File

@ -2,21 +2,24 @@ from __future__ import print_function
import os
import time
start = time.time()
import six
from textwrap import fill, dedent
from sys import stdout, stderr
from . import public_relay
from .. import __version__
from ..timing import DebugTiming
from ..errors import (WrongPasswordError, WelcomeError, KeyFormatError,
TransferError, NoTorError, UnsendableFileError,
ServerConnectionError)
from twisted.internet.defer import inlineCallbacks, maybeDeferred
from twisted.python.failure import Failure
from twisted.internet.task import react
from sys import stderr, stdout
from textwrap import dedent, fill
import click
import six
from twisted.internet.defer import inlineCallbacks, maybeDeferred
from twisted.internet.task import react
from twisted.python.failure import Failure
from . import public_relay
from .. import __version__
from ..errors import (KeyFormatError, NoTorError, ServerConnectionError,
TransferError, UnsendableFileError, WelcomeError,
WrongPasswordError)
from ..timing import DebugTiming
start = time.time()
top_import_finish = time.time()
@ -24,6 +27,7 @@ class Config(object):
"""
Union of config options that we pass down to (sub) commands.
"""
def __init__(self):
# This only holds attributes which are *not* set by CLI arguments.
# Everything else comes from Click decorators, so we can be sure
@ -34,11 +38,13 @@ class Config(object):
self.stderr = stderr
self.tor = False # XXX?
def _compose(*decorators):
def decorate(f):
for d in reversed(decorators):
f = d(f)
return f
return decorate
@ -48,6 +54,8 @@ ALIASES = {
"recieve": "receive",
"recv": "receive",
}
class AliasedGroup(click.Group):
def get_command(self, ctx, cmd_name):
cmd_name = ALIASES.get(cmd_name, cmd_name)
@ -56,22 +64,24 @@ class AliasedGroup(click.Group):
# top-level command ("wormhole ...")
@click.group(cls=AliasedGroup)
@click.option("--appid", default=None, metavar="APPID", help="appid to use")
@click.option(
"--appid", default=None, metavar="APPID", help="appid to use")
@click.option(
"--relay-url", default=public_relay.RENDEZVOUS_RELAY,
"--relay-url",
default=public_relay.RENDEZVOUS_RELAY,
envvar='WORMHOLE_RELAY_URL',
metavar="URL",
help="rendezvous relay to use",
)
@click.option(
"--transit-helper", default=public_relay.TRANSIT_RELAY,
"--transit-helper",
default=public_relay.TRANSIT_RELAY,
envvar='WORMHOLE_TRANSIT_HELPER',
metavar="tcp:HOST:PORT",
help="transit relay to use",
)
@click.option(
"--dump-timing", type=type(u""), # TODO: hide from --help output
"--dump-timing",
type=type(u""), # TODO: hide from --help output
default=None,
metavar="FILE.json",
help="(debug) write timing data to file",
@ -104,7 +114,8 @@ def _dispatch_command(reactor, cfg, command):
errors for the user.
"""
cfg.timing.add("command dispatch")
cfg.timing.add("import", when=start, which="top").finish(when=top_import_finish)
cfg.timing.add(
"import", when=start, which="top").finish(when=top_import_finish)
try:
yield maybeDeferred(command)
@ -141,56 +152,89 @@ def _dispatch_command(reactor, cfg, command):
CommonArgs = _compose(
click.option("-0", "zeromode", default=False, is_flag=True,
help="enable no-code anything-goes mode",
),
click.option("-c", "--code-length", default=2, metavar="NUMWORDS",
help="length of code (in bytes/words)",
),
click.option("-v", "--verify", is_flag=True, default=False,
help="display verification string (and wait for approval)",
),
click.option("--hide-progress", is_flag=True, default=False,
help="supress progress-bar display",
),
click.option("--listen/--no-listen", default=True,
help="(debug) don't open a listening socket for Transit",
),
click.option(
"-0",
"zeromode",
default=False,
is_flag=True,
help="enable no-code anything-goes mode",
),
click.option(
"-c",
"--code-length",
default=2,
metavar="NUMWORDS",
help="length of code (in bytes/words)",
),
click.option(
"-v",
"--verify",
is_flag=True,
default=False,
help="display verification string (and wait for approval)",
),
click.option(
"--hide-progress",
is_flag=True,
default=False,
help="supress progress-bar display",
),
click.option(
"--listen/--no-listen",
default=True,
help="(debug) don't open a listening socket for Transit",
),
)
TorArgs = _compose(
click.option("--tor", is_flag=True, default=False,
help="use Tor when connecting",
),
click.option("--launch-tor", is_flag=True, default=False,
help="launch Tor, rather than use existing control/socks port",
),
click.option("--tor-control-port", default=None, metavar="ENDPOINT",
help="endpoint descriptor for Tor control port",
),
click.option(
"--tor",
is_flag=True,
default=False,
help="use Tor when connecting",
),
click.option(
"--launch-tor",
is_flag=True,
default=False,
help="launch Tor, rather than use existing control/socks port",
),
click.option(
"--tor-control-port",
default=None,
metavar="ENDPOINT",
help="endpoint descriptor for Tor control port",
),
)
@wormhole.command()
@click.pass_context
def help(context, **kwargs):
print(context.find_root().get_help())
# wormhole send (or "wormhole tx")
@wormhole.command()
@CommonArgs
@TorArgs
@click.option(
"--code", metavar="CODE",
"--code",
metavar="CODE",
help="human-generated code phrase",
)
@click.option(
"--text", default=None, metavar="MESSAGE",
help="text message to send, instead of a file. Use '-' to read from stdin.",
"--text",
default=None,
metavar="MESSAGE",
help=("text message to send, instead of a file."
" Use '-' to read from stdin."),
)
@click.option(
"--ignore-unsendable-files", default=False, is_flag=True,
help="Don't raise an error if a file can't be read."
)
"--ignore-unsendable-files",
default=False,
is_flag=True,
help="Don't raise an error if a file can't be read.")
@click.argument("what", required=False, type=click.Path(path_type=type(u"")))
@click.pass_obj
def send(cfg, **kwargs):
@ -202,6 +246,7 @@ def send(cfg, **kwargs):
return go(cmd_send.send, cfg)
# this intermediate function can be mocked by tests that need to build a
# Config object
def go(f, cfg):
@ -214,23 +259,29 @@ def go(f, cfg):
@CommonArgs
@TorArgs
@click.option(
"--only-text", "-t", is_flag=True,
"--only-text",
"-t",
is_flag=True,
help="refuse file transfers, only accept text transfers",
)
@click.option(
"--accept-file", is_flag=True,
"--accept-file",
is_flag=True,
help="accept file transfer without asking for confirmation",
)
@click.option(
"--output-file", "-o",
"--output-file",
"-o",
metavar="FILENAME|DIRNAME",
help=("The file or directory to create, overriding the name suggested"
" by the sender."),
)
@click.argument(
"code", nargs=-1, default=None,
# help=("The magic-wormhole code, from the sender. If omitted, the"
# " program will ask for it, using tab-completion."),
"code",
nargs=-1,
default=None,
# help=("The magic-wormhole code, from the sender. If omitted, the"
# " program will ask for it, using tab-completion."),
)
@click.pass_obj
def receive(cfg, code, **kwargs):
@ -244,10 +295,8 @@ def receive(cfg, code, **kwargs):
if len(code) == 1:
cfg.code = code[0]
elif len(code) > 1:
print(
"Pass either no code or just one code; you passed"
" {}: {}".format(len(code), ', '.join(code))
)
print("Pass either no code or just one code; you passed"
" {}: {}".format(len(code), ', '.join(code)))
raise SystemExit(1)
else:
cfg.code = None
@ -260,17 +309,19 @@ def ssh():
"""
Facilitate sending/receiving SSH public keys
"""
pass
@ssh.command(name="invite")
@click.option(
"-c", "--code-length", default=2,
"-c",
"--code-length",
default=2,
metavar="NUMWORDS",
help="length of code (in bytes/words)",
)
@click.option(
"--user", "-u",
"--user",
"-u",
default=None,
metavar="USER",
help="Add to USER's ~/.ssh/authorized_keys",
@ -291,15 +342,20 @@ def ssh_invite(ctx, code_length, user, **kwargs):
@ssh.command(name="accept")
@click.argument(
"code", nargs=1, required=True,
"code",
nargs=1,
required=True,
)
@click.option(
"--key-file", "-F",
"--key-file",
"-F",
default=None,
type=click.Path(exists=True),
)
@click.option(
"--yes", "-y", is_flag=True,
"--yes",
"-y",
is_flag=True,
help="Skip confirmation prompt to send key",
)
@TorArgs
@ -318,7 +374,8 @@ def ssh_accept(cfg, code, key_file, yes, **kwargs):
kind, keyid, pubkey = cmd_ssh.find_public_key(key_file)
print("Sending public key type='{}' keyid='{}'".format(kind, keyid))
if yes is not True:
click.confirm("Really send public key '{}' ?".format(keyid), abort=True)
click.confirm(
"Really send public key '{}' ?".format(keyid), abort=True)
cfg.public_key = (kind, keyid, pubkey)
cfg.code = code

View File

@ -1,14 +1,23 @@
from __future__ import print_function
import os, sys, six, tempfile, zipfile, hashlib, shutil
from tqdm import tqdm
import hashlib
import os
import shutil
import sys
import tempfile
import zipfile
import six
from humanize import naturalsize
from tqdm import tqdm
from twisted.internet import reactor
from twisted.internet.defer import inlineCallbacks, returnValue
from twisted.python import log
from wormhole import create, input_with_completion, __version__
from ..transit import TransitReceiver
from wormhole import __version__, create, input_with_completion
from ..errors import TransferError
from ..util import (dict_to_bytes, bytes_to_dict, bytes_to_hexstr,
from ..transit import TransitReceiver
from ..util import (bytes_to_dict, bytes_to_hexstr, dict_to_bytes,
estimate_free_space)
from .welcome import handle_welcome
@ -17,14 +26,17 @@ APPID = u"lothar.com/wormhole/text-or-file-xfer"
KEY_TIMER = float(os.environ.get("_MAGIC_WORMHOLE_TEST_KEY_TIMER", 1.0))
VERIFY_TIMER = float(os.environ.get("_MAGIC_WORMHOLE_TEST_VERIFY_TIMER", 1.0))
class RespondError(Exception):
def __init__(self, response):
self.response = response
class TransferRejectedError(RespondError):
def __init__(self):
RespondError.__init__(self, "transfer rejected")
def receive(args, reactor=reactor, _debug_stash_wormhole=None):
"""I implement 'wormhole receive'. I return a Deferred that fires with
None (for success), or signals one of the following errors:
@ -60,16 +72,19 @@ class Receiver:
# tor in parallel with everything else, make sure the Tor object
# can lazy-provide an endpoint, and overlap the startup process
# with the user handing off the wormhole code
self._tor = yield get_tor(self._reactor,
self.args.launch_tor,
self.args.tor_control_port,
timing=self.args.timing)
self._tor = yield get_tor(
self._reactor,
self.args.launch_tor,
self.args.tor_control_port,
timing=self.args.timing)
w = create(self.args.appid or APPID, self.args.relay_url,
self._reactor,
tor=self._tor,
timing=self.args.timing)
self._w = w # so tests can wait on events too
w = create(
self.args.appid or APPID,
self.args.relay_url,
self._reactor,
tor=self._tor,
timing=self.args.timing)
self._w = w # so tests can wait on events too
# I wanted to do this instead:
#
@ -87,7 +102,7 @@ class Receiver:
# (which might be an error)
@inlineCallbacks
def _good(res):
yield w.close() # wait for ack
yield w.close() # wait for ack
returnValue(res)
# if we raise an error, we should close and then return the original
@ -96,8 +111,8 @@ class Receiver:
@inlineCallbacks
def _bad(f):
try:
yield w.close() # might be an error too
except:
yield w.close() # might be an error too
except Exception:
pass
returnValue(f)
@ -114,6 +129,7 @@ class Receiver:
def on_slow_key():
print(u"Waiting for sender...", file=self.args.stderr)
notify = self._reactor.callLater(KEY_TIMER, on_slow_key)
try:
# We wait here until we connect to the server and see the senders
@ -129,8 +145,10 @@ class Receiver:
notify.cancel()
def on_slow_verification():
print(u"Key established, waiting for confirmation...",
file=self.args.stderr)
print(
u"Key established, waiting for confirmation...",
file=self.args.stderr)
notify = self._reactor.callLater(VERIFY_TIMER, on_slow_verification)
try:
# We wait here until we've seen their VERSION message (which they
@ -155,7 +173,7 @@ class Receiver:
while True:
them_d = yield self._get_data(w)
#print("GOT", them_d)
# print("GOT", them_d)
recognized = False
if u"transit" in them_d:
recognized = True
@ -172,7 +190,7 @@ class Receiver:
raise TransferError(r.response)
returnValue(None)
if not recognized:
log.msg("unrecognized message %r" % (them_d,))
log.msg("unrecognized message %r" % (them_d, ))
def _send_data(self, data, w):
data_bytes = dict_to_bytes(data)
@ -197,12 +215,12 @@ class Receiver:
w.set_code(code)
else:
prompt = "Enter receive wormhole code: "
used_completion = yield input_with_completion(prompt,
w.input_code(),
self._reactor)
used_completion = yield input_with_completion(
prompt, w.input_code(), self._reactor)
if not used_completion:
print(" (note: you can use <Tab> to complete words)",
file=self.args.stderr)
print(
" (note: you can use <Tab> to complete words)",
file=self.args.stderr)
yield w.get_code()
def _show_verifier(self, verifier_bytes):
@ -220,21 +238,24 @@ class Receiver:
@inlineCallbacks
def _build_transit(self, w, sender_transit):
tr = TransitReceiver(self.args.transit_helper,
no_listen=(not self.args.listen),
tor=self._tor,
reactor=self._reactor,
timing=self.args.timing)
tr = TransitReceiver(
self.args.transit_helper,
no_listen=(not self.args.listen),
tor=self._tor,
reactor=self._reactor,
timing=self.args.timing)
self._transit_receiver = tr
transit_key = w.derive_key(APPID+u"/transit-key", tr.TRANSIT_KEY_LENGTH)
transit_key = w.derive_key(APPID + u"/transit-key",
tr.TRANSIT_KEY_LENGTH)
tr.set_transit_key(transit_key)
tr.add_connection_hints(sender_transit.get("hints-v1", []))
receiver_abilities = tr.get_connection_abilities()
receiver_hints = yield tr.get_connection_hints()
receiver_transit = {"abilities-v1": receiver_abilities,
"hints-v1": receiver_hints,
}
receiver_transit = {
"abilities-v1": receiver_abilities,
"hints-v1": receiver_hints,
}
self._send_data({u"transit": receiver_transit}, w)
# TODO: send more hints as the TransitReceiver produces them
@ -260,7 +281,7 @@ class Receiver:
yield self._close_transit(rp, datahash)
else:
self._msg(u"I don't know what they're offering\n")
self._msg(u"Offer details: %r" % (them_d,))
self._msg(u"Offer details: %r" % (them_d, ))
raise RespondError("unknown offer type")
def _handle_text(self, them_d, w):
@ -276,12 +297,13 @@ class Receiver:
self.xfersize = file_data["filesize"]
free = estimate_free_space(self.abs_destname)
if free is not None and free < self.xfersize:
self._msg(u"Error: insufficient free space (%sB) for file (%sB)"
% (free, self.xfersize))
self._msg(u"Error: insufficient free space (%sB) for file (%sB)" %
(free, self.xfersize))
raise TransferRejectedError()
self._msg(u"Receiving file (%s) into: %s" %
(naturalsize(self.xfersize), os.path.basename(self.abs_destname)))
(naturalsize(self.xfersize),
os.path.basename(self.abs_destname)))
self._ask_permission()
tmp_destname = self.abs_destname + ".tmp"
return open(tmp_destname, "wb")
@ -290,19 +312,22 @@ class Receiver:
file_data = them_d["directory"]
zipmode = file_data["mode"]
if zipmode != "zipfile/deflated":
self._msg(u"Error: unknown directory-transfer mode '%s'" % (zipmode,))
self._msg(u"Error: unknown directory-transfer mode '%s'" %
(zipmode, ))
raise RespondError("unknown mode")
self.abs_destname = self._decide_destname("directory",
file_data["dirname"])
self.xfersize = file_data["zipsize"]
free = estimate_free_space(self.abs_destname)
if free is not None and free < file_data["numbytes"]:
self._msg(u"Error: insufficient free space (%sB) for directory (%sB)"
% (free, file_data["numbytes"]))
self._msg(
u"Error: insufficient free space (%sB) for directory (%sB)" %
(free, file_data["numbytes"]))
raise TransferRejectedError()
self._msg(u"Receiving directory (%s) into: %s/" %
(naturalsize(self.xfersize), os.path.basename(self.abs_destname)))
(naturalsize(self.xfersize),
os.path.basename(self.abs_destname)))
self._msg(u"%d files, %s (uncompressed)" %
(file_data["numfiles"], naturalsize(file_data["numbytes"])))
self._ask_permission()
@ -313,23 +338,26 @@ class Receiver:
# "~/.ssh/authorized_keys" and other attacks
destname = os.path.basename(destname)
if self.args.output_file:
destname = self.args.output_file # override
abs_destname = os.path.abspath( os.path.join(self.args.cwd, destname) )
destname = self.args.output_file # override
abs_destname = os.path.abspath(os.path.join(self.args.cwd, destname))
# get confirmation from the user before writing to the local directory
if os.path.exists(abs_destname):
if self.args.output_file: # overwrite is intentional
if self.args.output_file: # overwrite is intentional
self._msg(u"Overwriting '%s'" % destname)
if self.args.accept_file:
self._remove_existing(abs_destname)
else:
self._msg(u"Error: refusing to overwrite existing '%s'" % destname)
self._msg(
u"Error: refusing to overwrite existing '%s'" % destname)
raise TransferRejectedError()
return abs_destname
def _remove_existing(self, path):
if os.path.isfile(path): os.remove(path)
if os.path.isdir(path): shutil.rmtree(path)
if os.path.isfile(path):
os.remove(path)
if os.path.isdir(path):
shutil.rmtree(path)
def _ask_permission(self):
with self.args.timing.add("permission", waiting="user") as t:
@ -345,7 +373,7 @@ class Receiver:
t.detail(answer="yes")
def _send_permission(self, w):
self._send_data({"answer": { "file_ack": "ok" }}, w)
self._send_data({"answer": {"file_ack": "ok"}}, w)
@inlineCallbacks
def _establish_transit(self):
@ -359,14 +387,16 @@ class Receiver:
self._msg(u"Receiving (%s).." % record_pipe.describe())
with self.args.timing.add("rx file"):
progress = tqdm(file=self.args.stderr,
disable=self.args.hide_progress,
unit="B", unit_scale=True, total=self.xfersize)
progress = tqdm(
file=self.args.stderr,
disable=self.args.hide_progress,
unit="B",
unit_scale=True,
total=self.xfersize)
hasher = hashlib.sha256()
with progress:
received = yield record_pipe.writeToFile(f, self.xfersize,
progress.update,
hasher.update)
received = yield record_pipe.writeToFile(
f, self.xfersize, progress.update, hasher.update)
datahash = hasher.digest()
# except TransitError
@ -382,25 +412,26 @@ class Receiver:
tmp_name = f.name
f.close()
os.rename(tmp_name, self.abs_destname)
self._msg(u"Received file written to %s" %
os.path.basename(self.abs_destname))
self._msg(u"Received file written to %s" % os.path.basename(
self.abs_destname))
def _extract_file(self, zf, info, extract_dir):
"""
the zipfile module does not restore file permissions
so we'll do it manually
"""
out_path = os.path.join( extract_dir, info.filename )
out_path = os.path.abspath( out_path )
if not out_path.startswith( extract_dir ):
raise ValueError( "malicious zipfile, %s outside of extract_dir %s"
% (info.filename, extract_dir) )
out_path = os.path.join(extract_dir, info.filename)
out_path = os.path.abspath(out_path)
if not out_path.startswith(extract_dir):
raise ValueError(
"malicious zipfile, %s outside of extract_dir %s" %
(info.filename, extract_dir))
zf.extract( info.filename, path=extract_dir )
zf.extract(info.filename, path=extract_dir)
# not sure why zipfiles store the perms 16 bits away but they do
perm = info.external_attr >> 16
os.chmod( out_path, perm )
os.chmod(out_path, perm)
def _write_directory(self, f):
@ -408,10 +439,10 @@ class Receiver:
with self.args.timing.add("unpack zip"):
with zipfile.ZipFile(f, "r", zipfile.ZIP_DEFLATED) as zf:
for info in zf.infolist():
self._extract_file( zf, info, self.abs_destname )
self._extract_file(zf, info, self.abs_destname)
self._msg(u"Received files written to %s/" %
os.path.basename(self.abs_destname))
self._msg(u"Received files written to %s/" % os.path.basename(
self.abs_destname))
f.close()
@inlineCallbacks

View File

@ -1,20 +1,29 @@
from __future__ import print_function
import os, sys, six, tempfile, zipfile, hashlib
from tqdm import tqdm
import hashlib
import os
import sys
import tempfile
import zipfile
import six
from humanize import naturalsize
from twisted.python import log
from twisted.protocols import basic
from tqdm import tqdm
from twisted.internet import reactor
from twisted.internet.defer import inlineCallbacks, returnValue
from twisted.protocols import basic
from twisted.python import log
from wormhole import __version__, create
from ..errors import TransferError, UnsendableFileError
from wormhole import create, __version__
from ..transit import TransitSender
from ..util import dict_to_bytes, bytes_to_dict, bytes_to_hexstr
from ..util import bytes_to_dict, bytes_to_hexstr, dict_to_bytes
from .welcome import handle_welcome
APPID = u"lothar.com/wormhole/text-or-file-xfer"
VERIFY_TIMER = float(os.environ.get("_MAGIC_WORMHOLE_TEST_VERIFY_TIMER", 1.0))
def send(args, reactor=reactor):
"""I implement 'wormhole send'. I return a Deferred that fires with None
(for success), or signals one of the following errors:
@ -26,6 +35,7 @@ def send(args, reactor=reactor):
"""
return Sender(args, reactor).go()
class Sender:
def __init__(self, args, reactor):
self._args = args
@ -45,22 +55,25 @@ class Sender:
# tor in parallel with everything else, make sure the Tor object
# can lazy-provide an endpoint, and overlap the startup process
# with the user handing off the wormhole code
self._tor = yield get_tor(reactor,
self._args.launch_tor,
self._args.tor_control_port,
timing=self._timing)
self._tor = yield get_tor(
reactor,
self._args.launch_tor,
self._args.tor_control_port,
timing=self._timing)
w = create(self._args.appid or APPID, self._args.relay_url,
self._reactor,
tor=self._tor,
timing=self._timing)
w = create(
self._args.appid or APPID,
self._args.relay_url,
self._reactor,
tor=self._tor,
timing=self._timing)
d = self._go(w)
# if we succeed, we should close and return the w.close results
# (which might be an error)
@inlineCallbacks
def _good(res):
yield w.close() # wait for ack
yield w.close() # wait for ack
returnValue(res)
# if we raise an error, we should close and then return the original
@ -69,8 +82,8 @@ class Sender:
@inlineCallbacks
def _bad(f):
try:
yield w.close() # might be an error too
except:
yield w.close() # might be an error too
except Exception:
pass
returnValue(f)
@ -125,8 +138,10 @@ class Sender:
# TODO: don't stall on w.get_verifier() unless they want it
def on_slow_connection():
print(u"Key established, waiting for confirmation...",
file=args.stderr)
print(
u"Key established, waiting for confirmation...",
file=args.stderr)
notify = self._reactor.callLater(VERIFY_TIMER, on_slow_connection)
try:
# The usual sender-chooses-code sequence means the receiver's
@ -137,32 +152,35 @@ class Sender:
# sitting here for a while, so printing the "waiting" message
# seems like a good idea. It might even be appropriate to give up
# after a while.
verifier_bytes = yield w.get_verifier() # might WrongPasswordError
verifier_bytes = yield w.get_verifier() # might WrongPasswordError
finally:
if not notify.called:
notify.cancel()
if args.verify:
self._check_verifier(w, verifier_bytes) # blocks, can TransferError
self._check_verifier(w,
verifier_bytes) # blocks, can TransferError
if self._fd_to_send:
ts = TransitSender(args.transit_helper,
no_listen=(not args.listen),
tor=self._tor,
reactor=self._reactor,
timing=self._timing)
ts = TransitSender(
args.transit_helper,
no_listen=(not args.listen),
tor=self._tor,
reactor=self._reactor,
timing=self._timing)
self._transit_sender = ts
# for now, send this before the main offer
sender_abilities = ts.get_connection_abilities()
sender_hints = yield ts.get_connection_hints()
sender_transit = {"abilities-v1": sender_abilities,
"hints-v1": sender_hints,
}
sender_transit = {
"abilities-v1": sender_abilities,
"hints-v1": sender_hints,
}
self._send_data({u"transit": sender_transit}, w)
# TODO: move this down below w.get_message()
transit_key = w.derive_key(APPID+"/transit-key",
transit_key = w.derive_key(APPID + "/transit-key",
ts.TRANSIT_KEY_LENGTH)
ts.set_transit_key(transit_key)
@ -175,11 +193,11 @@ class Sender:
# TODO: get_message() fired, so get_verifier must have fired, so
# now it's safe to use w.derive_key()
them_d = bytes_to_dict(them_d_bytes)
#print("GOT", them_d)
# print("GOT", them_d)
recognized = False
if u"error" in them_d:
raise TransferError("remote error, transfer abandoned: %s"
% them_d["error"])
raise TransferError(
"remote error, transfer abandoned: %s" % them_d["error"])
if u"transit" in them_d:
recognized = True
yield self._handle_transit(them_d[u"transit"])
@ -191,7 +209,7 @@ class Sender:
yield self._handle_answer(them_d[u"answer"])
returnValue(None)
if not recognized:
log.msg("unrecognized message %r" % (them_d,))
log.msg("unrecognized message %r" % (them_d, ))
def _check_verifier(self, w, verifier_bytes):
verifier = bytes_to_hexstr(verifier_bytes)
@ -221,9 +239,10 @@ class Sender:
text = six.moves.input("Text to send: ")
if text is not None:
print(u"Sending text message (%s)" % naturalsize(len(text)),
file=args.stderr)
offer = { "message": text }
print(
u"Sending text message (%s)" % naturalsize(len(text)),
file=args.stderr)
offer = {"message": text}
fd_to_send = None
return offer, fd_to_send
@ -244,7 +263,7 @@ class Sender:
# is a symlink to something with a different name. The normpath() is
# there to remove trailing slashes.
basename = os.path.basename(os.path.normpath(what))
assert basename != "", what # normpath shouldn't allow this
assert basename != "", what # normpath shouldn't allow this
# We use realpath() instead of normpath() to locate the actual
# file/directory, because the path might contain symlinks, and
@ -273,8 +292,8 @@ class Sender:
what = os.path.realpath(what)
if not os.path.exists(what):
raise TransferError("Cannot send: no file/directory named '%s'" %
args.what)
raise TransferError(
"Cannot send: no file/directory named '%s'" % args.what)
if os.path.isfile(what):
# we're sending a file
@ -282,10 +301,11 @@ class Sender:
offer["file"] = {
"filename": basename,
"filesize": filesize,
}
print(u"Sending %s file named '%s'"
% (naturalsize(filesize), basename),
file=args.stderr)
}
print(
u"Sending %s file named '%s'" % (naturalsize(filesize),
basename),
file=args.stderr)
fd_to_send = open(what, "rb")
return offer, fd_to_send
@ -297,16 +317,18 @@ class Sender:
num_files = 0
num_bytes = 0
tostrip = len(what.split(os.sep))
with zipfile.ZipFile(fd_to_send, "w",
compression=zipfile.ZIP_DEFLATED,
allowZip64=True) as zf:
for path,dirs,files in os.walk(what):
with zipfile.ZipFile(
fd_to_send,
"w",
compression=zipfile.ZIP_DEFLATED,
allowZip64=True) as zf:
for path, dirs, files in os.walk(what):
# path always starts with args.what, then sometimes might
# have "/subdir" appended. We want the zipfile to contain
# "" or "subdir"
localpath = list(path.split(os.sep)[tostrip:])
for fn in files:
archivename = os.path.join(*tuple(localpath+[fn]))
archivename = os.path.join(*tuple(localpath + [fn]))
localfilename = os.path.join(path, fn)
try:
zf.write(localfilename, archivename)
@ -315,22 +337,25 @@ class Sender:
except OSError as e:
errmsg = u"{}: {}".format(fn, e.strerror)
if self._args.ignore_unsendable_files:
print(u"{} (ignoring error)".format(errmsg),
file=args.stderr)
print(
u"{} (ignoring error)".format(errmsg),
file=args.stderr)
else:
raise UnsendableFileError(errmsg)
fd_to_send.seek(0,2)
fd_to_send.seek(0, 2)
filesize = fd_to_send.tell()
fd_to_send.seek(0,0)
fd_to_send.seek(0, 0)
offer["directory"] = {
"mode": "zipfile/deflated",
"dirname": basename,
"zipsize": filesize,
"numbytes": num_bytes,
"numfiles": num_files,
}
print(u"Sending directory (%s compressed) named '%s'"
% (naturalsize(filesize), basename), file=args.stderr)
}
print(
u"Sending directory (%s compressed) named '%s'" %
(naturalsize(filesize), basename),
file=args.stderr)
return offer, fd_to_send
raise TypeError("'%s' is neither file nor directory" % args.what)
@ -340,23 +365,22 @@ class Sender:
if self._fd_to_send is None:
if them_answer["message_ack"] == "ok":
print(u"text message sent", file=self._args.stderr)
returnValue(None) # terminates this function
raise TransferError("error sending text: %r" % (them_answer,))
returnValue(None) # terminates this function
raise TransferError("error sending text: %r" % (them_answer, ))
if them_answer.get("file_ack") != "ok":
raise TransferError("ambiguous response from remote, "
"transfer abandoned: %s" % (them_answer,))
"transfer abandoned: %s" % (them_answer, ))
yield self._send_file()
@inlineCallbacks
def _send_file(self):
ts = self._transit_sender
self._fd_to_send.seek(0,2)
self._fd_to_send.seek(0, 2)
filesize = self._fd_to_send.tell()
self._fd_to_send.seek(0,0)
self._fd_to_send.seek(0, 0)
record_pipe = yield ts.connect()
self._timing.add("transit connected")
@ -365,21 +389,28 @@ class Sender:
print(u"Sending (%s).." % record_pipe.describe(), file=stderr)
hasher = hashlib.sha256()
progress = tqdm(file=stderr, disable=self._args.hide_progress,
unit="B", unit_scale=True,
total=filesize)
progress = tqdm(
file=stderr,
disable=self._args.hide_progress,
unit="B",
unit_scale=True,
total=filesize)
def _count_and_hash(data):
hasher.update(data)
progress.update(len(data))
return data
fs = basic.FileSender()
with self._timing.add("tx file"):
with progress:
if filesize:
# don't send zero-length files
yield fs.beginFileTransfer(self._fd_to_send, record_pipe,
transform=_count_and_hash)
yield fs.beginFileTransfer(
self._fd_to_send,
record_pipe,
transform=_count_and_hash)
expected_hash = hasher.digest()
expected_hex = bytes_to_hexstr(expected_hash)

View File

@ -1,16 +1,19 @@
from __future__ import print_function
import os
from os.path import expanduser, exists, join
from twisted.internet.defer import inlineCallbacks
from twisted.internet import reactor
from os.path import exists, expanduser, join
import click
from twisted.internet import reactor
from twisted.internet.defer import inlineCallbacks
from .. import xfer_util
class PubkeyError(Exception):
pass
def find_public_key(hint=None):
"""
This looks for an appropriate SSH key to send, possibly querying
@ -34,8 +37,9 @@ def find_public_key(hint=None):
got_key = False
while not got_key:
ans = click.prompt(
"Multiple public-keys found:\n" + \
"\n".join([" {}: {}".format(a, b) for a, b in enumerate(pubkeys)]) + \
"Multiple public-keys found:\n" +
"\n".join([" {}: {}".format(a, b)
for a, b in enumerate(pubkeys)]) +
"\nSend which one?"
)
try:
@ -76,7 +80,6 @@ def accept(cfg, reactor=reactor):
@inlineCallbacks
def invite(cfg, reactor=reactor):
def on_code_created(code):
print("Now tell the other user to run:")
print()

View File

@ -1,4 +1,3 @@
# This is a relay I run on a personal server. If it gets too expensive to
# run, I'll shut it down.
RENDEZVOUS_RELAY = u"ws://relay.magic-wormhole.io:4000/v1"

View File

@ -1,18 +1,23 @@
from __future__ import print_function, absolute_import, unicode_literals
from __future__ import absolute_import, print_function, unicode_literals
def handle_welcome(welcome, relay_url, my_version, stderr):
if "motd" in welcome:
motd_lines = welcome["motd"].splitlines()
motd_formatted = "\n ".join(motd_lines)
print("Server (at %s) says:\n %s" % (relay_url, motd_formatted),
file=stderr)
print(
"Server (at %s) says:\n %s" % (relay_url, motd_formatted),
file=stderr)
# Only warn if we're running a release version (e.g. 0.0.6, not
# 0.0.6+DISTANCE.gHASH). Only warn once.
if ("current_cli_version" in welcome
and "+" not in my_version
and welcome["current_cli_version"] != my_version):
print("Warning: errors may occur unless both sides are running the same version", file=stderr)
print("Server claims %s is current, but ours is %s"
% (welcome["current_cli_version"], my_version),
file=stderr)
if ("current_cli_version" in welcome and "+" not in my_version
and welcome["current_cli_version"] != my_version):
print(
("Warning: errors may occur unless both sides are running the"
" same version"),
file=stderr)
print(
"Server claims %s is current, but ours is %s" %
(welcome["current_cli_version"], my_version),
file=stderr)

View File

@ -1,8 +1,10 @@
from __future__ import unicode_literals
class WormholeError(Exception):
"""Parent class for all wormhole-related errors"""
class UnsendableFileError(Exception):
"""
A file you wanted to send couldn't be read, maybe because it's not
@ -13,30 +15,38 @@ class UnsendableFileError(Exception):
--ignore-unsendable-files flag.
"""
class ServerError(WormholeError):
"""The relay server complained about something we did."""
class ServerConnectionError(WormholeError):
"""We had a problem connecting to the relay server:"""
def __init__(self, url, reason):
self.url = url
self.reason = reason
def __str__(self):
return str(self.reason)
class Timeout(WormholeError):
pass
class WelcomeError(WormholeError):
"""
The relay server told us to signal an error, probably because our version
is too old to possibly work. The server said:"""
pass
class LonelyError(WormholeError):
"""wormhole.close() was called before the peer connection could be
established"""
class WrongPasswordError(WormholeError):
"""
Key confirmation failed. Either you or your correspondent typed the code
@ -47,6 +57,7 @@ class WrongPasswordError(WormholeError):
# or the data blob was corrupted, and that's why decrypt failed
pass
class KeyFormatError(WormholeError):
"""
The key you entered contains spaces or was missing a dash. Magic-wormhole
@ -55,43 +66,61 @@ class KeyFormatError(WormholeError):
dashes.
"""
class ReflectionAttack(WormholeError):
"""An attacker (or bug) reflected our outgoing message back to us."""
class InternalError(WormholeError):
"""The programmer did something wrong."""
class TransferError(WormholeError):
"""Something bad happened and the transfer failed."""
class NoTorError(WormholeError):
"""--tor was requested, but 'txtorcon' is not installed."""
class NoKeyError(WormholeError):
"""w.derive_key() was called before got_verifier() fired"""
class OnlyOneCodeError(WormholeError):
"""Only one w.generate_code/w.set_code/w.input_code may be called"""
class MustChooseNameplateFirstError(WormholeError):
"""The InputHelper was asked to do get_word_completions() or
choose_words() before the nameplate was chosen."""
class AlreadyChoseNameplateError(WormholeError):
"""The InputHelper was asked to do get_nameplate_completions() after
choose_nameplate() was called, or choose_nameplate() was called a second
time."""
class AlreadyChoseWordsError(WormholeError):
"""The InputHelper was asked to do get_word_completions() after
choose_words() was called, or choose_words() was called a second time."""
class AlreadyInputNameplateError(WormholeError):
"""The CodeInputter was asked to do completion on a nameplate, when we
had already committed to a different one."""
class WormholeClosed(Exception):
"""Deferred-returning API calls errback with WormholeClosed if the
wormhole was already closed, or if it closes before a real result can be
obtained."""
class _UnknownPhaseError(Exception):
"""internal exception type, for tests."""
class _UnknownMessageTypeError(Exception):
"""internal exception type, for tests."""

View File

@ -5,6 +5,7 @@ from twisted.internet.defer import Deferred
from twisted.internet.interfaces import IReactorTime
from twisted.python import log
class EventualQueue(object):
def __init__(self, clock):
# pass clock=reactor unless you're testing
@ -14,7 +15,7 @@ class EventualQueue(object):
self._timer = None
def eventually(self, f, *args, **kwargs):
self._calls.append( (f, args, kwargs) )
self._calls.append((f, args, kwargs))
if not self._timer:
self._timer = self._clock.callLater(0, self._turn)
@ -28,7 +29,7 @@ class EventualQueue(object):
(f, args, kwargs) = self._calls.pop(0)
try:
f(*args, **kwargs)
except:
except Exception:
log.err()
self._timer = None
d, self._flush_d = self._flush_d, None

View File

@ -1,27 +1,37 @@
# no unicode_literals
# Find all of our ip addresses. From tahoe's src/allmydata/util/iputil.py
import os, re, subprocess, errno
import errno
import os
import re
import subprocess
from sys import platform
from twisted.python.procutils import which
# Wow, I'm really amazed at home much mileage we've gotten out of calling
# the external route.exe program on windows... It appears to work on all
# versions so far. Still, the real system calls would much be preferred...
# ... thus wrote Greg Smith in time immemorial...
_win32_re = re.compile(r'^\s*\d+\.\d+\.\d+\.\d+\s.+\s(?P<address>\d+\.\d+\.\d+\.\d+)\s+(?P<metric>\d+)\s*$', flags=re.M|re.I|re.S)
_win32_commands = (('route.exe', ('print',), _win32_re),)
_win32_re = re.compile(
(r'^\s*\d+\.\d+\.\d+\.\d+\s.+\s'
r'(?P<address>\d+\.\d+\.\d+\.\d+)\s+(?P<metric>\d+)\s*$'),
flags=re.M | re.I | re.S)
_win32_commands = (('route.exe', ('print', ), _win32_re), )
# These work in most Unices.
_addr_re = re.compile(r'^\s*inet [a-zA-Z]*:?(?P<address>\d+\.\d+\.\d+\.\d+)[\s/].+$', flags=re.M|re.I|re.S)
_unix_commands = (('/bin/ip', ('addr',), _addr_re),
('/sbin/ip', ('addr',), _addr_re),
('/sbin/ifconfig', ('-a',), _addr_re),
('/usr/sbin/ifconfig', ('-a',), _addr_re),
('/usr/etc/ifconfig', ('-a',), _addr_re),
('ifconfig', ('-a',), _addr_re),
('/sbin/ifconfig', (), _addr_re),
)
_addr_re = re.compile(
r'^\s*inet [a-zA-Z]*:?(?P<address>\d+\.\d+\.\d+\.\d+)[\s/].+$',
flags=re.M | re.I | re.S)
_unix_commands = (
('/bin/ip', ('addr', ), _addr_re),
('/sbin/ip', ('addr', ), _addr_re),
('/sbin/ifconfig', ('-a', ), _addr_re),
('/usr/sbin/ifconfig', ('-a', ), _addr_re),
('/usr/etc/ifconfig', ('-a', ), _addr_re),
('ifconfig', ('-a', ), _addr_re),
('/sbin/ifconfig', (), _addr_re),
)
def find_addresses():
@ -54,17 +64,19 @@ def find_addresses():
return ["127.0.0.1"]
def _query(path, args, regex):
env = {'LANG': 'en_US.UTF-8'}
trial = 0
while True:
trial += 1
try:
p = subprocess.Popen([path] + list(args),
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
env=env,
universal_newlines=True)
p = subprocess.Popen(
[path] + list(args),
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
env=env,
universal_newlines=True)
(output, err) = p.communicate()
break
except OSError as e:

View File

@ -1,8 +1,12 @@
from __future__ import print_function, absolute_import, unicode_literals
from zope.interface import implementer
from __future__ import absolute_import, print_function, unicode_literals
import contextlib
from zope.interface import implementer
from ._interfaces import IJournal
@implementer(IJournal)
class Journal(object):
def __init__(self, save_checkpoint):
@ -19,7 +23,7 @@ class Journal(object):
assert not self._processing
assert not self._outbound_queue
self._processing = True
yield # process inbound messages, change state, queue outbound
yield # process inbound messages, change state, queue outbound
self._save_checkpoint()
for (fn, args, kwargs) in self._outbound_queue:
fn(*args, **kwargs)
@ -31,8 +35,10 @@ class Journal(object):
class ImmediateJournal(object):
def __init__(self):
pass
def queue_outbound(self, fn, *args, **kwargs):
fn(*args, **kwargs)
@contextlib.contextmanager
def process(self):
yield

View File

@ -1,14 +1,16 @@
from __future__ import unicode_literals, print_function
from __future__ import print_function, unicode_literals
from twisted.internet.defer import Deferred
from twisted.python.failure import Failure
NoResult = object()
class OneShotObserver(object):
def __init__(self, eventual_queue):
self._eq = eventual_queue
self._result = NoResult
self._observers = [] # list of Deferreds
self._observers = [] # list of Deferreds
def when_fired(self):
d = Deferred()
@ -38,6 +40,7 @@ class OneShotObserver(object):
if self._result is NoResult:
self.fire(result)
class SequenceObserver(object):
def __init__(self, eventual_queue):
self._eq = eventual_queue

View File

@ -1,16 +1,19 @@
# no unicode_literals untill twisted update
from twisted.application import service, internet
from twisted.internet import defer, task, reactor, endpoints
from twisted.python import log
from click.testing import CliRunner
from twisted.application import internet, service
from twisted.internet import defer, endpoints, reactor, task
from twisted.python import log
import mock
from ..cli import cli
from ..transit import allocate_tcp_port
from wormhole_mailbox_server.database import create_channel_db, create_usage_db
from wormhole_mailbox_server.server import make_server
from wormhole_mailbox_server.web import make_web_server
from wormhole_mailbox_server.database import create_channel_db, create_usage_db
from wormhole_transit_relay.transit_server import Transit
from ..cli import cli
from ..transit import allocate_tcp_port
class MyInternetService(service.Service, object):
# like StreamServerEndpointService, but you can retrieve the port
def __init__(self, endpoint, factory):
@ -22,12 +25,15 @@ class MyInternetService(service.Service, object):
def startService(self):
super(MyInternetService, self).startService()
d = self.endpoint.listen(self.factory)
def good(lp):
self._lp = lp
self._port_d.callback(lp.getHost().port)
def bad(f):
log.err(f)
self._port_d.errback(f)
d.addCallbacks(good, bad)
@defer.inlineCallbacks
@ -35,9 +41,10 @@ class MyInternetService(service.Service, object):
if self._lp:
yield self._lp.stopListening()
def getPort(self): # only call once!
def getPort(self): # only call once!
return self._port_d
class ServerBase:
@defer.inlineCallbacks
def setUp(self):
@ -51,27 +58,27 @@ class ServerBase:
# endpoints.serverFromString
db = create_channel_db(":memory:")
self._usage_db = create_usage_db(":memory:")
self._rendezvous = make_server(db,
advertise_version=advertise_version,
signal_error=error,
usage_db=self._usage_db)
self._rendezvous = make_server(
db,
advertise_version=advertise_version,
signal_error=error,
usage_db=self._usage_db)
ep = endpoints.TCP4ServerEndpoint(reactor, 0, interface="127.0.0.1")
site = make_web_server(self._rendezvous, log_requests=False)
#self._lp = yield ep.listen(site)
# self._lp = yield ep.listen(site)
s = MyInternetService(ep, site)
s.setServiceParent(self.sp)
self.rdv_ws_port = yield s.getPort()
self._relay_server = s
#self._rendezvous = s._rendezvous
# self._rendezvous = s._rendezvous
self.relayurl = u"ws://127.0.0.1:%d/v1" % self.rdv_ws_port
# ws://127.0.0.1:%d/wormhole-relay/ws
self.transitport = allocate_tcp_port()
ep = endpoints.serverFromString(reactor,
"tcp:%d:interface=127.0.0.1" %
self.transitport)
self._transit_server = f = Transit(blur_usage=None, log_file=None,
usage_db=None)
ep = endpoints.serverFromString(
reactor, "tcp:%d:interface=127.0.0.1" % self.transitport)
self._transit_server = f = Transit(
blur_usage=None, log_file=None, usage_db=None)
internet.StreamServerEndpointService(ep, f).setServiceParent(self.sp)
self.transit = u"tcp:127.0.0.1:%d" % self.transitport
@ -109,6 +116,7 @@ class ServerBase:
" I convinced all threads to exit.")
yield d
def config(*argv):
r = CliRunner()
with mock.patch("wormhole.cli.cli.go") as go:
@ -121,6 +129,7 @@ def config(*argv):
cfg = go.call_args[0][1]
return cfg
@defer.inlineCallbacks
def poll_until(predicate):
# return a Deferred that won't fire until the predicate is True

View File

@ -1,4 +1,5 @@
from __future__ import unicode_literals
# This is a tiny helper module, to let "python -m wormhole.test.run_trial
# ARGS" does the same thing as running "trial ARGS" (unfortunately
# twisted/scripts/trial.py does not have a '__name__=="__main__"' clause).

View File

@ -1,15 +1,17 @@
import os
import sys
import mock
from twisted.trial import unittest
import mock
from ..cli.public_relay import RENDEZVOUS_RELAY, TRANSIT_RELAY
from .common import config
#from pprint import pprint
class Send(unittest.TestCase):
def test_baseline(self):
cfg = config("send", "--text", "hi")
#pprint(cfg.__dict__)
self.assertEqual(cfg.what, None)
self.assertEqual(cfg.code, None)
self.assertEqual(cfg.code_length, 2)
@ -32,7 +34,6 @@ class Send(unittest.TestCase):
def test_file(self):
cfg = config("send", "fn")
#pprint(cfg.__dict__)
self.assertEqual(cfg.what, u"fn")
self.assertEqual(cfg.text, None)
@ -101,7 +102,6 @@ class Send(unittest.TestCase):
class Receive(unittest.TestCase):
def test_baseline(self):
cfg = config("receive")
#pprint(cfg.__dict__)
self.assertEqual(cfg.accept_file, False)
self.assertEqual(cfg.code, None)
self.assertEqual(cfg.code_length, 2)
@ -191,10 +191,12 @@ class Receive(unittest.TestCase):
cfg = config("--transit-helper", transit_url_2, "receive")
self.assertEqual(cfg.transit_helper, transit_url_2)
class Config(unittest.TestCase):
def test_send(self):
cfg = config("send")
self.assertEqual(cfg.stdout, sys.stdout)
def test_receive(self):
cfg = config("receive")
self.assertEqual(cfg.stdout, sys.stdout)

View File

@ -1,22 +1,32 @@
from __future__ import print_function
import os, sys, re, io, zipfile, six, stat
from textwrap import fill, dedent
from humanize import naturalsize
import mock
import io
import os
import re
import stat
import sys
import zipfile
from textwrap import dedent, fill
import six
from click.testing import CliRunner
from zope.interface import implementer
from twisted.trial import unittest
from twisted.python import procutils, log
from humanize import naturalsize
from twisted.internet import endpoints, reactor
from twisted.internet.utils import getProcessOutputAndValue
from twisted.internet.defer import gatherResults, inlineCallbacks, returnValue
from twisted.internet.error import ConnectionRefusedError
from twisted.internet.utils import getProcessOutputAndValue
from twisted.python import log, procutils
from twisted.trial import unittest
from zope.interface import implementer
import mock
from .. import __version__
from .common import ServerBase, config
from ..cli import cmd_send, cmd_receive, welcome, cli
from ..errors import (TransferError, WrongPasswordError, WelcomeError,
UnsendableFileError, ServerConnectionError)
from .._interfaces import ITorManager
from ..cli import cli, cmd_receive, cmd_send, welcome
from ..errors import (ServerConnectionError, TransferError,
UnsendableFileError, WelcomeError, WrongPasswordError)
from .common import ServerBase, config
def build_offer(args):
@ -108,8 +118,8 @@ class OfferData(unittest.TestCase):
self.cfg.cwd = send_dir
e = self.assertRaises(TransferError, build_offer, self.cfg)
self.assertEqual(str(e),
"Cannot send: no file/directory named '%s'" % filename)
self.assertEqual(
str(e), "Cannot send: no file/directory named '%s'" % filename)
def _do_test_directory(self, addslash):
parent_dir = self.mktemp()
@ -178,8 +188,8 @@ class OfferData(unittest.TestCase):
self.assertFalse(os.path.isdir(abs_filename))
e = self.assertRaises(TypeError, build_offer, self.cfg)
self.assertEqual(str(e),
"'%s' is neither file nor directory" % filename)
self.assertEqual(
str(e), "'%s' is neither file nor directory" % filename)
def test_symlink(self):
if not hasattr(os, 'symlink'):
@ -213,8 +223,9 @@ class OfferData(unittest.TestCase):
os.mkdir(os.path.join(parent_dir, "B2", "C2"))
with open(os.path.join(parent_dir, "B2", "D.txt"), "wb") as f:
f.write(b"success")
os.symlink(os.path.abspath(os.path.join(parent_dir, "B2", "C2")),
os.path.join(parent_dir, "B1", "C1"))
os.symlink(
os.path.abspath(os.path.join(parent_dir, "B2", "C2")),
os.path.join(parent_dir, "B1", "C1"))
# Now send "B1/C1/../D.txt" from A. The correct traversal will be:
# * start: A
# * B1: A/B1
@ -231,6 +242,7 @@ class OfferData(unittest.TestCase):
d, fd_to_send = build_offer(self.cfg)
self.assertEqual(d["file"]["filename"], "D.txt")
self.assertEqual(fd_to_send.read(), b"success")
if os.name == "nt":
test_symlink_collapse.todo = "host OS has broken os.path.realpath()"
# ntpath.py's realpath() is built out of normpath(), and does not
@ -241,6 +253,7 @@ class OfferData(unittest.TestCase):
# misbehavior (albeit in rare circumstances), 2: it probably used to
# work (sometimes, but not in #251). See cmd_send.py for more notes.
class LocaleFinder:
def __init__(self):
self._run_once = False
@ -266,10 +279,10 @@ class LocaleFinder:
# twisted.python.usage to avoid this problem in the future.
(out, err, rc) = yield getProcessOutputAndValue("locale", ["-a"])
if rc != 0:
log.msg("error running 'locale -a', rc=%s" % (rc,))
log.msg("stderr: %s" % (err,))
log.msg("error running 'locale -a', rc=%s" % (rc, ))
log.msg("stderr: %s" % (err, ))
returnValue(None)
out = out.decode("utf-8") # make sure we get a string
out = out.decode("utf-8") # make sure we get a string
utf8_locales = {}
for locale in out.splitlines():
locale = locale.strip()
@ -281,8 +294,11 @@ class LocaleFinder:
if utf8_locales:
returnValue(list(utf8_locales.values())[0])
returnValue(None)
locale_finder = LocaleFinder()
class ScriptsBase:
def find_executable(self):
# to make sure we're running the right executable (in a virtualenv),
@ -292,12 +308,13 @@ class ScriptsBase:
if not locations:
raise unittest.SkipTest("unable to find 'wormhole' in $PATH")
wormhole = locations[0]
if (os.path.dirname(os.path.abspath(wormhole)) !=
os.path.dirname(sys.executable)):
log.msg("locations: %s" % (locations,))
log.msg("sys.executable: %s" % (sys.executable,))
raise unittest.SkipTest("found the wrong 'wormhole' in $PATH: %s %s"
% (wormhole, sys.executable))
if (os.path.dirname(os.path.abspath(wormhole)) != os.path.dirname(
sys.executable)):
log.msg("locations: %s" % (locations, ))
log.msg("sys.executable: %s" % (sys.executable, ))
raise unittest.SkipTest(
"found the wrong 'wormhole' in $PATH: %s %s" %
(wormhole, sys.executable))
return wormhole
@inlineCallbacks
@ -323,8 +340,8 @@ class ScriptsBase:
raise unittest.SkipTest("unable to find UTF-8 locale")
locale_env = dict(LC_ALL=locale, LANG=locale)
wormhole = self.find_executable()
res = yield getProcessOutputAndValue(wormhole, ["--version"],
env=locale_env)
res = yield getProcessOutputAndValue(
wormhole, ["--version"], env=locale_env)
out, err, rc = res
if rc != 0:
log.msg("wormhole not runnable in this tree:")
@ -334,6 +351,7 @@ class ScriptsBase:
raise unittest.SkipTest("wormhole is not runnable in this tree")
returnValue(locale_env)
class ScriptVersion(ServerBase, ScriptsBase, unittest.TestCase):
# we need Twisted to run the server, but we run the sender and receiver
# with deferToThread()
@ -347,26 +365,30 @@ class ScriptVersion(ServerBase, ScriptsBase, unittest.TestCase):
wormhole = self.find_executable()
# we must pass on the environment so that "something" doesn't
# get sad about UTF8 vs. ascii encodings
out, err, rc = yield getProcessOutputAndValue(wormhole, ["--version"],
env=os.environ)
out, err, rc = yield getProcessOutputAndValue(
wormhole, ["--version"], env=os.environ)
err = err.decode("utf-8")
if "DistributionNotFound" in err:
log.msg("stderr was %s" % err)
last = err.strip().split("\n")[-1]
self.fail("wormhole not runnable: %s" % last)
ver = out.decode("utf-8") or err
self.failUnlessEqual(ver.strip(), "magic-wormhole {}".format(__version__))
self.failUnlessEqual(ver.strip(),
"magic-wormhole {}".format(__version__))
self.failUnlessEqual(rc, 0)
@implementer(ITorManager)
class FakeTor:
# use normal endpoints, but record the fact that we were asked
def __init__(self):
self.endpoints = []
def stream_via(self, host, port):
self.endpoints.append((host, port))
return endpoints.HostnameEndpoint(reactor, host, port)
class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase):
# we need Twisted to run the server, but we run the sender and receiver
# with deferToThread()
@ -377,11 +399,16 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase):
yield ServerBase.setUp(self)
@inlineCallbacks
def _do_test(self, as_subprocess=False,
mode="text", addslash=False, override_filename=False,
fake_tor=False, overwrite=False, mock_accept=False):
assert mode in ("text", "file", "empty-file", "directory",
"slow-text", "slow-sender-text")
def _do_test(self,
as_subprocess=False,
mode="text",
addslash=False,
override_filename=False,
fake_tor=False,
overwrite=False,
mock_accept=False):
assert mode in ("text", "file", "empty-file", "directory", "slow-text",
"slow-sender-text")
if fake_tor:
assert not as_subprocess
send_cfg = config("send")
@ -408,7 +435,7 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase):
elif mode in ("file", "empty-file"):
if mode == "empty-file":
message = ""
send_filename = u"testfil\u00EB" # e-with-diaeresis
send_filename = u"testfil\u00EB" # e-with-diaeresis
with open(os.path.join(send_dir, send_filename), "w") as f:
f.write(message)
send_cfg.what = send_filename
@ -433,8 +460,10 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase):
# expect: $receive_dir/$dirname/[12345]
send_dirname = u"testdir"
def message(i):
return "test message %d\n" % i
os.mkdir(os.path.join(send_dir, u"middle"))
source_dir = os.path.join(send_dir, u"middle", send_dirname)
os.mkdir(source_dir)
@ -476,21 +505,27 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase):
env["_MAGIC_WORMHOLE_TEST_KEY_TIMER"] = "999999"
env["_MAGIC_WORMHOLE_TEST_VERIFY_TIMER"] = "999999"
send_args = [
'--relay-url', self.relayurl,
'--transit-helper', '',
'send',
'--hide-progress',
'--code', send_cfg.code,
] + content_args
'--relay-url',
self.relayurl,
'--transit-helper',
'',
'send',
'--hide-progress',
'--code',
send_cfg.code,
] + content_args
send_d = getProcessOutputAndValue(
wormhole_bin, send_args,
wormhole_bin,
send_args,
path=send_dir,
env=env,
)
recv_args = [
'--relay-url', self.relayurl,
'--transit-helper', '',
'--relay-url',
self.relayurl,
'--transit-helper',
'',
'receive',
'--hide-progress',
'--accept-file',
@ -500,7 +535,8 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase):
recv_args.extend(['-o', receive_filename])
receive_d = getProcessOutputAndValue(
wormhole_bin, recv_args,
wormhole_bin,
recv_args,
path=receive_dir,
env=env,
)
@ -524,25 +560,27 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase):
send_cfg.tor = True
send_cfg.transit_helper = self.transit
tx_tm = FakeTor()
with mock.patch("wormhole.tor_manager.get_tor",
return_value=tx_tm,
) as mtx_tm:
with mock.patch(
"wormhole.tor_manager.get_tor",
return_value=tx_tm,
) as mtx_tm:
send_d = cmd_send.send(send_cfg)
recv_cfg.tor = True
recv_cfg.transit_helper = self.transit
rx_tm = FakeTor()
with mock.patch("wormhole.tor_manager.get_tor",
return_value=rx_tm,
) as mrx_tm:
with mock.patch(
"wormhole.tor_manager.get_tor",
return_value=rx_tm,
) as mrx_tm:
receive_d = cmd_receive.receive(recv_cfg)
else:
KEY_TIMER = 0 if mode == "slow-sender-text" else 99999
rxw = []
with mock.patch.object(cmd_receive, "KEY_TIMER", KEY_TIMER):
send_d = cmd_send.send(send_cfg)
receive_d = cmd_receive.receive(recv_cfg,
_debug_stash_wormhole=rxw)
receive_d = cmd_receive.receive(
recv_cfg, _debug_stash_wormhole=rxw)
# we need to keep KEY_TIMER patched until the receiver
# gets far enough to start the timer, which happens after
# the code is set
@ -555,8 +593,9 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase):
with mock.patch.object(cmd_receive, "VERIFY_TIMER", VERIFY_TIMER):
with mock.patch.object(cmd_send, "VERIFY_TIMER", VERIFY_TIMER):
if mock_accept:
with mock.patch.object(cmd_receive.six.moves,
'input', return_value='y'):
with mock.patch.object(
cmd_receive.six.moves, 'input',
return_value='y'):
yield gatherResults([send_d, receive_d], True)
else:
yield gatherResults([send_d, receive_d], True)
@ -567,14 +606,14 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase):
expected_endpoints.append(("127.0.0.1", self.transitport))
tx_timing = mtx_tm.call_args[1]["timing"]
self.assertEqual(tx_tm.endpoints, expected_endpoints)
self.assertEqual(mtx_tm.mock_calls,
[mock.call(reactor, False, None,
timing=tx_timing)])
self.assertEqual(
mtx_tm.mock_calls,
[mock.call(reactor, False, None, timing=tx_timing)])
rx_timing = mrx_tm.call_args[1]["timing"]
self.assertEqual(rx_tm.endpoints, expected_endpoints)
self.assertEqual(mrx_tm.mock_calls,
[mock.call(reactor, False, None,
timing=rx_timing)])
self.assertEqual(
mrx_tm.mock_calls,
[mock.call(reactor, False, None, timing=rx_timing)])
send_stdout = send_cfg.stdout.getvalue()
send_stderr = send_cfg.stderr.getvalue()
@ -585,7 +624,7 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase):
# newlines, even if we're on windows
NL = "\n"
self.maxDiff = None # show full output for assertion failures
self.maxDiff = None # show full output for assertion failures
key_established = ""
if mode == "slow-text":
@ -600,38 +639,41 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase):
"On the other computer, please run:{NL}{NL}"
"wormhole receive {code}{NL}{NL}"
"{KE}"
"text message sent{NL}").format(bytes=len(message),
code=send_cfg.code,
NL=NL,
KE=key_established)
"text message sent{NL}").format(
bytes=len(message),
code=send_cfg.code,
NL=NL,
KE=key_established)
self.failUnlessEqual(send_stderr, expected)
elif mode == "file":
self.failUnlessIn(u"Sending {size:s} file named '{name}'{NL}"
.format(size=naturalsize(len(message)),
name=send_filename,
NL=NL), send_stderr)
.format(
size=naturalsize(len(message)),
name=send_filename,
NL=NL), send_stderr)
self.failUnlessIn(u"Wormhole code is: {code}{NL}"
"On the other computer, please run:{NL}{NL}"
"wormhole receive {code}{NL}{NL}"
.format(code=send_cfg.code, NL=NL),
send_stderr)
self.failUnlessIn(u"File sent.. waiting for confirmation{NL}"
"Confirmation received. Transfer complete.{NL}"
.format(NL=NL), send_stderr)
"wormhole receive {code}{NL}{NL}".format(
code=send_cfg.code, NL=NL), send_stderr)
self.failUnlessIn(
u"File sent.. waiting for confirmation{NL}"
"Confirmation received. Transfer complete.{NL}".format(NL=NL),
send_stderr)
elif mode == "directory":
self.failUnlessIn(u"Sending directory", send_stderr)
self.failUnlessIn(u"named 'testdir'", send_stderr)
self.failUnlessIn(u"Wormhole code is: {code}{NL}"
"On the other computer, please run:{NL}{NL}"
"wormhole receive {code}{NL}{NL}"
.format(code=send_cfg.code, NL=NL), send_stderr)
self.failUnlessIn(u"File sent.. waiting for confirmation{NL}"
"Confirmation received. Transfer complete.{NL}"
.format(NL=NL), send_stderr)
"wormhole receive {code}{NL}{NL}".format(
code=send_cfg.code, NL=NL), send_stderr)
self.failUnlessIn(
u"File sent.. waiting for confirmation{NL}"
"Confirmation received. Transfer complete.{NL}".format(NL=NL),
send_stderr)
# check receiver
if mode in ("text", "slow-text", "slow-sender-text"):
self.assertEqual(receive_stdout, message+NL)
self.assertEqual(receive_stdout, message + NL)
if mode == "text":
self.assertEqual(receive_stderr, "")
elif mode == "slow-text":
@ -640,9 +682,9 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase):
self.assertEqual(receive_stderr, "Waiting for sender...\n")
elif mode == "file":
self.failUnlessEqual(receive_stdout, "")
self.failUnlessIn(u"Receiving file ({size:s}) into: {name}"
.format(size=naturalsize(len(message)),
name=receive_filename), receive_stderr)
self.failUnlessIn(u"Receiving file ({size:s}) into: {name}".format(
size=naturalsize(len(message)), name=receive_filename),
receive_stderr)
self.failUnlessIn(u"Received file written to ", receive_stderr)
fn = os.path.join(receive_dir, receive_filename)
self.failUnless(os.path.exists(fn))
@ -652,52 +694,67 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase):
self.failUnlessEqual(receive_stdout, "")
want = (r"Receiving directory \(\d+ \w+\) into: {name}/"
.format(name=receive_dirname))
self.failUnless(re.search(want, receive_stderr),
(want, receive_stderr))
self.failUnlessIn(u"Received files written to {name}"
.format(name=receive_dirname), receive_stderr)
self.failUnless(
re.search(want, receive_stderr), (want, receive_stderr))
self.failUnlessIn(
u"Received files written to {name}"
.format(name=receive_dirname),
receive_stderr)
fn = os.path.join(receive_dir, receive_dirname)
self.failUnless(os.path.exists(fn), fn)
for i in range(5):
fn = os.path.join(receive_dir, receive_dirname, str(i))
with open(fn, "r") as f:
self.failUnlessEqual(f.read(), message(i))
self.failUnlessEqual(modes[i],
stat.S_IMODE(os.stat(fn).st_mode))
self.failUnlessEqual(modes[i], stat.S_IMODE(
os.stat(fn).st_mode))
def test_text(self):
return self._do_test()
def test_text_subprocess(self):
return self._do_test(as_subprocess=True)
def test_text_tor(self):
return self._do_test(fake_tor=True)
def test_file(self):
return self._do_test(mode="file")
def test_file_override(self):
return self._do_test(mode="file", override_filename=True)
def test_file_overwrite(self):
return self._do_test(mode="file", overwrite=True)
def test_file_overwrite_mock_accept(self):
return self._do_test(mode="file", overwrite=True, mock_accept=True)
def test_file_tor(self):
return self._do_test(mode="file", fake_tor=True)
def test_empty_file(self):
return self._do_test(mode="empty-file")
def test_directory(self):
return self._do_test(mode="directory")
def test_directory_addslash(self):
return self._do_test(mode="directory", addslash=True)
def test_directory_override(self):
return self._do_test(mode="directory", override_filename=True)
def test_directory_overwrite(self):
return self._do_test(mode="directory", overwrite=True)
def test_directory_overwrite_mock_accept(self):
return self._do_test(mode="directory", overwrite=True, mock_accept=True)
return self._do_test(
mode="directory", overwrite=True, mock_accept=True)
def test_slow_text(self):
return self._do_test(mode="slow-text")
def test_slow_sender_text(self):
return self._do_test(mode="slow-sender-text")
@ -721,7 +778,7 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase):
os.mkdir(send_dir)
receive_dir = self.mktemp()
os.mkdir(receive_dir)
recv_cfg.accept_file = True # don't ask for permission
recv_cfg.accept_file = True # don't ask for permission
if mode == "file":
message = "test message\n"
@ -765,10 +822,12 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase):
free_space = 10000000
else:
free_space = 0
with mock.patch("wormhole.cli.cmd_receive.estimate_free_space",
return_value=free_space):
with mock.patch(
"wormhole.cli.cmd_receive.estimate_free_space",
return_value=free_space):
f = yield self.assertFailure(send_d, TransferError)
self.assertEqual(str(f), "remote error, transfer abandoned: transfer rejected")
self.assertEqual(
str(f), "remote error, transfer abandoned: transfer rejected")
f = yield self.assertFailure(receive_d, TransferError)
self.assertEqual(str(f), "transfer rejected")
@ -781,7 +840,7 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase):
# newlines, even if we're on windows
NL = "\n"
self.maxDiff = None # show full output for assertion failures
self.maxDiff = None # show full output for assertion failures
self.assertEqual(send_stdout, "")
self.assertEqual(receive_stdout, "")
@ -789,54 +848,63 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase):
# check sender
if mode == "file":
self.failUnlessIn("Sending {size:s} file named '{name}'{NL}"
.format(size=naturalsize(size),
name=send_filename,
NL=NL), send_stderr)
.format(
size=naturalsize(size),
name=send_filename,
NL=NL), send_stderr)
self.failUnlessIn("Wormhole code is: {code}{NL}"
"On the other computer, please run:{NL}{NL}"
"wormhole receive {code}{NL}"
.format(code=send_cfg.code, NL=NL),
send_stderr)
self.failIfIn("File sent.. waiting for confirmation{NL}"
"Confirmation received. Transfer complete.{NL}"
.format(NL=NL), send_stderr)
"wormhole receive {code}{NL}".format(
code=send_cfg.code, NL=NL), send_stderr)
self.failIfIn(
"File sent.. waiting for confirmation{NL}"
"Confirmation received. Transfer complete.{NL}".format(NL=NL),
send_stderr)
elif mode == "directory":
self.failUnlessIn("Sending directory", send_stderr)
self.failUnlessIn("named 'testdir'", send_stderr)
self.failUnlessIn("Wormhole code is: {code}{NL}"
"On the other computer, please run:{NL}{NL}"
"wormhole receive {code}{NL}"
.format(code=send_cfg.code, NL=NL), send_stderr)
self.failIfIn("File sent.. waiting for confirmation{NL}"
"Confirmation received. Transfer complete.{NL}"
.format(NL=NL), send_stderr)
"wormhole receive {code}{NL}".format(
code=send_cfg.code, NL=NL), send_stderr)
self.failIfIn(
"File sent.. waiting for confirmation{NL}"
"Confirmation received. Transfer complete.{NL}".format(NL=NL),
send_stderr)
# check receiver
if mode == "file":
self.failIfIn("Received file written to ", receive_stderr)
if failmode == "noclobber":
self.failUnlessIn("Error: "
"refusing to overwrite existing 'testfile'{NL}"
.format(NL=NL), receive_stderr)
self.failUnlessIn(
"Error: "
"refusing to overwrite existing 'testfile'{NL}"
.format(NL=NL),
receive_stderr)
else:
self.failUnlessIn("Error: "
"insufficient free space (0B) for file ({size:d}B){NL}"
.format(NL=NL, size=size), receive_stderr)
self.failUnlessIn(
"Error: "
"insufficient free space (0B) for file ({size:d}B){NL}"
.format(NL=NL, size=size), receive_stderr)
elif mode == "directory":
self.failIfIn("Received files written to {name}"
.format(name=receive_name), receive_stderr)
#want = (r"Receiving directory \(\d+ \w+\) into: {name}/"
self.failIfIn(
"Received files written to {name}".format(name=receive_name),
receive_stderr)
# want = (r"Receiving directory \(\d+ \w+\) into: {name}/"
# .format(name=receive_name))
#self.failUnless(re.search(want, receive_stderr),
# self.failUnless(re.search(want, receive_stderr),
# (want, receive_stderr))
if failmode == "noclobber":
self.failUnlessIn("Error: "
"refusing to overwrite existing 'testdir'{NL}"
.format(NL=NL), receive_stderr)
self.failUnlessIn(
"Error: "
"refusing to overwrite existing 'testdir'{NL}"
.format(NL=NL),
receive_stderr)
else:
self.failUnlessIn("Error: "
"insufficient free space (0B) for directory ({size:d}B){NL}"
.format(NL=NL, size=size), receive_stderr)
self.failUnlessIn(("Error: "
"insufficient free space (0B) for directory"
" ({size:d}B){NL}").format(
NL=NL, size=size), receive_stderr)
if failmode == "noclobber":
fn = os.path.join(receive_dir, receive_name)
@ -846,13 +914,17 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase):
def test_fail_file_noclobber(self):
return self._do_test_fail("file", "noclobber")
def test_fail_directory_noclobber(self):
return self._do_test_fail("directory", "noclobber")
def test_fail_file_toobig(self):
return self._do_test_fail("file", "toobig")
def test_fail_directory_toobig(self):
return self._do_test_fail("directory", "toobig")
class ZeroMode(ServerBase, unittest.TestCase):
@inlineCallbacks
def test_text(self):
@ -871,8 +943,8 @@ class ZeroMode(ServerBase, unittest.TestCase):
send_cfg.text = message
#send_cfg.cwd = send_dir
#recv_cfg.cwd = receive_dir
# send_cfg.cwd = send_dir
# recv_cfg.cwd = receive_dir
send_d = cmd_send.send(send_cfg)
receive_d = cmd_receive.receive(recv_cfg)
@ -888,7 +960,7 @@ class ZeroMode(ServerBase, unittest.TestCase):
# newlines, even if we're on windows
NL = "\n"
self.maxDiff = None # show full output for assertion failures
self.maxDiff = None # show full output for assertion failures
self.assertEqual(send_stdout, "")
@ -898,15 +970,15 @@ class ZeroMode(ServerBase, unittest.TestCase):
"{NL}"
"wormhole receive -0{NL}"
"{NL}"
"text message sent{NL}").format(bytes=len(message),
code=send_cfg.code,
NL=NL)
"text message sent{NL}").format(
bytes=len(message), code=send_cfg.code, NL=NL)
self.failUnlessEqual(send_stderr, expected)
# check receiver
self.assertEqual(receive_stdout, message+NL)
self.assertEqual(receive_stdout, message + NL)
self.assertEqual(receive_stderr, "")
class NotWelcome(ServerBase, unittest.TestCase):
@inlineCallbacks
def setUp(self):
@ -936,6 +1008,7 @@ class NotWelcome(ServerBase, unittest.TestCase):
f = yield self.assertFailure(receive_d, WelcomeError)
self.assertEqual(str(f), "please upgrade XYZ")
class NoServer(ServerBase, unittest.TestCase):
@inlineCallbacks
def setUp(self):
@ -991,8 +1064,8 @@ class NoServer(ServerBase, unittest.TestCase):
e = yield self.assertFailure(receive_d, ServerConnectionError)
self.assertIsInstance(e.reason, ConnectionRefusedError)
class Cleanup(ServerBase, unittest.TestCase):
class Cleanup(ServerBase, unittest.TestCase):
def make_config(self):
cfg = config("send")
# common options for all tests in this suite
@ -1040,6 +1113,7 @@ class Cleanup(ServerBase, unittest.TestCase):
cids = self._rendezvous.get_app(cmd_send.APPID).get_nameplate_ids()
self.assertEqual(len(cids), 0)
class ExtractFile(unittest.TestCase):
def test_filenames(self):
args = mock.Mock()
@ -1066,8 +1140,8 @@ class ExtractFile(unittest.TestCase):
zf = mock.Mock()
zi = mock.Mock()
zi.filename = "haha//root" # abspath squashes this, hopefully zipfile
# does too
zi.filename = "haha//root" # abspath squashes this, hopefully zipfile
# does too
zi.external_attr = 5 << 16
expected = os.path.join(extract_dir, "haha", "root")
with mock.patch.object(cmd_receive.os, "chmod") as chmod:
@ -1082,6 +1156,7 @@ class ExtractFile(unittest.TestCase):
e = self.assertRaises(ValueError, ef, zf, zi, extract_dir)
self.assertIn("malicious zipfile", str(e))
class AppID(ServerBase, unittest.TestCase):
@inlineCallbacks
def setUp(self):
@ -1108,11 +1183,11 @@ class AppID(ServerBase, unittest.TestCase):
yield receive_d
used = self._usage_db.execute("SELECT DISTINCT `app_id`"
" FROM `nameplates`"
).fetchall()
" FROM `nameplates`").fetchall()
self.assertEqual(len(used), 1, used)
self.assertEqual(used[0]["app_id"], u"appid2")
class Welcome(unittest.TestCase):
def do(self, welcome_message, my_version="2.0"):
stderr = io.StringIO()
@ -1129,27 +1204,33 @@ class Welcome(unittest.TestCase):
def test_version_old(self):
stderr = self.do({"current_cli_version": "3.0"})
expected = ("Warning: errors may occur unless both sides are running the same version\n" +
expected = ("Warning: errors may occur unless both sides are"
" running the same version\n"
"Server claims 3.0 is current, but ours is 2.0\n")
self.assertEqual(stderr, expected)
def test_version_unreleased(self):
stderr = self.do({"current_cli_version": "3.0"},
my_version="2.5+middle.something")
stderr = self.do(
{
"current_cli_version": "3.0"
}, my_version="2.5+middle.something")
self.assertEqual(stderr, "")
def test_motd(self):
stderr = self.do({"motd": "hello"})
self.assertEqual(stderr, "Server (at url) says:\n hello\n")
class Dispatch(unittest.TestCase):
@inlineCallbacks
def test_success(self):
cfg = config("send")
cfg.stderr = io.StringIO()
called = []
def fake():
called.append(1)
yield cli._dispatch_command(reactor, cfg, fake)
self.assertEqual(called, [1])
self.assertEqual(cfg.stderr.getvalue(), "")
@ -1160,8 +1241,10 @@ class Dispatch(unittest.TestCase):
cfg.stderr = io.StringIO()
cfg.timing = mock.Mock()
cfg.dump_timing = "filename"
def fake():
pass
yield cli._dispatch_command(reactor, cfg, fake)
self.assertEqual(cfg.stderr.getvalue(), "")
self.assertEqual(cfg.timing.mock_calls[-1],
@ -1171,32 +1254,39 @@ class Dispatch(unittest.TestCase):
def test_wrong_password_error(self):
cfg = config("send")
cfg.stderr = io.StringIO()
def fake():
raise WrongPasswordError("abcd")
yield self.assertFailure(cli._dispatch_command(reactor, cfg, fake),
SystemExit)
expected = fill("ERROR: " + dedent(WrongPasswordError.__doc__))+"\n"
yield self.assertFailure(
cli._dispatch_command(reactor, cfg, fake), SystemExit)
expected = fill("ERROR: " + dedent(WrongPasswordError.__doc__)) + "\n"
self.assertEqual(cfg.stderr.getvalue(), expected)
@inlineCallbacks
def test_welcome_error(self):
cfg = config("send")
cfg.stderr = io.StringIO()
def fake():
raise WelcomeError("abcd")
yield self.assertFailure(cli._dispatch_command(reactor, cfg, fake),
SystemExit)
expected = fill("ERROR: " + dedent(WelcomeError.__doc__))+"\n\nabcd\n"
yield self.assertFailure(
cli._dispatch_command(reactor, cfg, fake), SystemExit)
expected = (
fill("ERROR: " + dedent(WelcomeError.__doc__)) + "\n\nabcd\n")
self.assertEqual(cfg.stderr.getvalue(), expected)
@inlineCallbacks
def test_transfer_error(self):
cfg = config("send")
cfg.stderr = io.StringIO()
def fake():
raise TransferError("abcd")
yield self.assertFailure(cli._dispatch_command(reactor, cfg, fake),
SystemExit)
yield self.assertFailure(
cli._dispatch_command(reactor, cfg, fake), SystemExit)
expected = "TransferError: abcd\n"
self.assertEqual(cfg.stderr.getvalue(), expected)
@ -1204,11 +1294,14 @@ class Dispatch(unittest.TestCase):
def test_server_connection_error(self):
cfg = config("send")
cfg.stderr = io.StringIO()
def fake():
raise ServerConnectionError("URL", ValueError("abcd"))
yield self.assertFailure(cli._dispatch_command(reactor, cfg, fake),
SystemExit)
expected = fill("ERROR: " + dedent(ServerConnectionError.__doc__))+"\n"
yield self.assertFailure(
cli._dispatch_command(reactor, cfg, fake), SystemExit)
expected = fill(
"ERROR: " + dedent(ServerConnectionError.__doc__)) + "\n"
expected += "(relay URL was URL)\n"
expected += "abcd\n"
self.assertEqual(cfg.stderr.getvalue(), expected)
@ -1217,21 +1310,26 @@ class Dispatch(unittest.TestCase):
def test_other_error(self):
cfg = config("send")
cfg.stderr = io.StringIO()
def fake():
raise ValueError("abcd")
# I'm seeing unicode problems with the Failure().printTraceback, and
# the output would be kind of unpredictable anyways, so we'll mock it
# out here.
f = mock.Mock()
def mock_print(file):
file.write(u"<TRACEBACK>\n")
f.printTraceback = mock_print
with mock.patch("wormhole.cli.cli.Failure", return_value=f):
yield self.assertFailure(cli._dispatch_command(reactor, cfg, fake),
SystemExit)
yield self.assertFailure(
cli._dispatch_command(reactor, cfg, fake), SystemExit)
expected = "<TRACEBACK>\nERROR: abcd\n"
self.assertEqual(cfg.stderr.getvalue(), expected)
class Help(unittest.TestCase):
def _check_top_level_help(self, got):
# the main wormhole.cli.cli.wormhole docstring should be in the

View File

@ -1,14 +1,19 @@
from __future__ import print_function, unicode_literals
import mock
from twisted.trial import unittest
from twisted.internet import reactor
from twisted.internet.task import Clock
from twisted.internet.defer import Deferred, inlineCallbacks
from twisted.internet.task import Clock
from twisted.trial import unittest
import mock
from ..eventual import EventualQueue
class IntentionalError(Exception):
pass
class Eventual(unittest.TestCase, object):
def test_eventually(self):
c = Clock()
@ -23,9 +28,10 @@ class Eventual(unittest.TestCase, object):
self.assertNoResult(d3)
eq.flush_sync()
self.assertEqual(c1.mock_calls,
[mock.call("arg1", "arg2", kwarg1="kw1"),
mock.call("arg3", "arg4", kwarg5="kw5")])
self.assertEqual(c1.mock_calls, [
mock.call("arg1", "arg2", kwarg1="kw1"),
mock.call("arg3", "arg4", kwarg5="kw5")
])
self.assertEqual(self.successResultOf(d2), None)
self.assertEqual(self.successResultOf(d3), "value")
@ -47,11 +53,12 @@ class Eventual(unittest.TestCase, object):
eq = EventualQueue(reactor)
d1 = eq.fire_eventually()
d2 = Deferred()
def _more(res):
eq.eventually(d2.callback, None)
d1.addCallback(_more)
yield eq.flush()
# d1 will fire, which will queue d2 to fire, and the flush() ought to
# wait for d2 too
self.successResultOf(d2)

View File

@ -1,44 +1,47 @@
from __future__ import print_function, unicode_literals
import unittest
from binascii import unhexlify #, hexlify
from binascii import unhexlify # , hexlify
from hkdf import Hkdf
#def generate_KAT():
# print("KAT = [")
# for salt in (b"", b"salt"):
# for context in (b"", b"context"):
# skm = b"secret"
# out = HKDF(skm, 64, XTS=salt, CTXinfo=context)
# hexout = " '%s' +\n '%s'" % (hexlify(out[:32]),
# hexlify(out[32:]))
# print(" (%r, %r, %r,\n%s)," % (salt, context, skm, hexout))
# print("]")
# def generate_KAT():
# print("KAT = [")
# for salt in (b"", b"salt"):
# for context in (b"", b"context"):
# skm = b"secret"
# out = HKDF(skm, 64, XTS=salt, CTXinfo=context)
# hexout = " '%s' +\n '%s'" % (hexlify(out[:32]),
# hexlify(out[32:]))
# print(" (%r, %r, %r,\n%s)," % (salt, context, skm, hexout))
# print("]")
KAT = [
('', '', 'secret',
'2f34e5ff91ec85d53ca9b543683174d0cf550b60d5f52b24c97b386cfcf6cbbf' +
'9cfd42fd37e1e5a214d15f03058d7fee63dc28f564b7b9fe3da514f80daad4bf'),
('', 'context', 'secret',
'c24c303a1adfb4c3e2b092e6254ed481c41d8955ba8ec3f6a1473493a60c957b' +
'31b723018ca75557214d3d5c61c0c7a5315b103b21ff00cb03ebe023dc347a47'),
('salt', '', 'secret',
'f1156507c39b0e326159e778696253122de430899a8df2484040a85a5f95ceb1' +
'dfca555d4cc603bdf7153ed1560de8cbc3234b27a6d2be8e8ca202d90649679a'),
('salt', 'context', 'secret',
'61a4f201a867bcc12381ddb180d27074408d03ee9d5750855e5a12d967fa060f' +
'10336ead9370927eaabb0d60b259346ee5f57eb7ceba8c72f1ed3f2932b1bf19'),
('', '', 'secret',
'2f34e5ff91ec85d53ca9b543683174d0cf550b60d5f52b24c97b386cfcf6cbbf' +
'9cfd42fd37e1e5a214d15f03058d7fee63dc28f564b7b9fe3da514f80daad4bf'),
('', 'context', 'secret',
'c24c303a1adfb4c3e2b092e6254ed481c41d8955ba8ec3f6a1473493a60c957b' +
'31b723018ca75557214d3d5c61c0c7a5315b103b21ff00cb03ebe023dc347a47'),
('salt', '', 'secret',
'f1156507c39b0e326159e778696253122de430899a8df2484040a85a5f95ceb1' +
'dfca555d4cc603bdf7153ed1560de8cbc3234b27a6d2be8e8ca202d90649679a'),
('salt', 'context', 'secret',
'61a4f201a867bcc12381ddb180d27074408d03ee9d5750855e5a12d967fa060f' +
'10336ead9370927eaabb0d60b259346ee5f57eb7ceba8c72f1ed3f2932b1bf19'),
]
class TestKAT(unittest.TestCase):
# note: this uses SHA256
def test_kat(self):
for (salt, context, skm, expected_hexout) in KAT:
expected_out = unhexlify(expected_hexout)
for outlen in range(0, len(expected_out)):
out = Hkdf(salt.encode("ascii"),
skm.encode("ascii")).expand(context.encode("ascii"),
outlen)
out = Hkdf(salt.encode("ascii"), skm.encode("ascii")).expand(
context.encode("ascii"), outlen)
self.assertEqual(out, expected_out[:outlen])
#if __name__ == '__main__':
# generate_KAT()
# if __name__ == '__main__':
# generate_KAT()

View File

@ -1,9 +1,13 @@
import errno
import os
import re
import subprocess
import re, errno, subprocess, os
from twisted.trial import unittest
from .. import ipaddrs
DOTTED_QUAD_RE=re.compile("^[0-9]+\.[0-9]+\.[0-9]+\.[0-9]+$")
DOTTED_QUAD_RE = re.compile("^[0-9]+\.[0-9]+\.[0-9]+\.[0-9]+$")
MOCK_IPADDR_OUTPUT = """\
1: lo: <LOOPBACK,UP,LOWER_UP> mtu 16436 qdisc noqueue state UNKNOWN \n\
@ -11,12 +15,14 @@ MOCK_IPADDR_OUTPUT = """\
inet 127.0.0.1/8 scope host lo
inet6 ::1/128 scope host \n\
valid_lft forever preferred_lft forever
2: eth1: <BROADCAST,MULTICAST,UP,LOWER_UP> mtu 1500 qdisc pfifo_fast state UP qlen 1000
2: eth1: <BROADCAST,MULTICAST,UP,LOWER_UP> mtu 1500 qdisc pfifo_fast state UP \
qlen 1000
link/ether d4:3d:7e:01:b4:3e brd ff:ff:ff:ff:ff:ff
inet 192.168.0.6/24 brd 192.168.0.255 scope global eth1
inet6 fe80::d63d:7eff:fe01:b43e/64 scope link \n\
valid_lft forever preferred_lft forever
3: wlan0: <BROADCAST,MULTICAST,UP,LOWER_UP> mtu 1500 qdisc mq state UP qlen 1000
3: wlan0: <BROADCAST,MULTICAST,UP,LOWER_UP> mtu 1500 qdisc mq state UP qlen\
1000
link/ether 90:f6:52:27:15:0a brd ff:ff:ff:ff:ff:ff
inet 192.168.0.2/24 brd 192.168.0.255 scope global wlan0
inet6 fe80::92f6:52ff:fe27:150a/64 scope link \n\
@ -58,7 +64,8 @@ MOCK_ROUTE_OUTPUT = """\
===========================================================================
Interface List
0x1 ........................... MS TCP Loopback interface
0x2 ...08 00 27 c3 80 ad ...... AMD PCNET Family PCI Ethernet Adapter - Packet Scheduler Miniport
0x2 ...08 00 27 c3 80 ad ...... AMD PCNET Family PCI Ethernet Adapter - \
Packet Scheduler Miniport
===========================================================================
===========================================================================
Active Routes:
@ -85,6 +92,7 @@ class FakeProcess:
def __init__(self, output, err):
self.output = output
self.err = err
def communicate(self):
return (self.output, self.err)
@ -94,16 +102,28 @@ class ListAddresses(unittest.TestCase):
addresses = ipaddrs.find_addresses()
self.failUnlessIn("127.0.0.1", addresses)
self.failIfIn("0.0.0.0", addresses)
# David A.'s OpenSolaris box timed out on this test one time when it was at
# 2s.
test_list.timeout=4
test_list.timeout = 4
def _test_list_mock(self, command, output, expected):
self.first = True
def call_Popen(args, bufsize=0, executable=None, stdin=None, stdout=None, stderr=None,
preexec_fn=None, close_fds=False, shell=False, cwd=None, env=None,
universal_newlines=False, startupinfo=None, creationflags=0):
def call_Popen(args,
bufsize=0,
executable=None,
stdin=None,
stdout=None,
stderr=None,
preexec_fn=None,
close_fds=False,
shell=False,
cwd=None,
env=None,
universal_newlines=False,
startupinfo=None,
creationflags=0):
if self.first:
self.first = False
e = OSError("EINTR")
@ -115,11 +135,13 @@ class ListAddresses(unittest.TestCase):
e = OSError("[Errno 2] No such file or directory")
e.errno = errno.ENOENT
raise e
self.patch(subprocess, 'Popen', call_Popen)
self.patch(os.path, 'isfile', lambda x: True)
def call_which(name):
return [name]
self.patch(ipaddrs, 'which', call_which)
addresses = ipaddrs.find_addresses()
@ -131,11 +153,13 @@ class ListAddresses(unittest.TestCase):
def test_list_mock_ifconfig(self):
self.patch(ipaddrs, 'platform', "linux2")
self._test_list_mock("ifconfig", MOCK_IFCONFIG_OUTPUT, UNIX_TEST_ADDRESSES)
self._test_list_mock("ifconfig", MOCK_IFCONFIG_OUTPUT,
UNIX_TEST_ADDRESSES)
def test_list_mock_route(self):
self.patch(ipaddrs, 'platform', "win32")
self._test_list_mock("route.exe", MOCK_ROUTE_OUTPUT, WINDOWS_TEST_ADDRESSES)
self._test_list_mock("route.exe", MOCK_ROUTE_OUTPUT,
WINDOWS_TEST_ADDRESSES)
def test_list_mock_cygwin(self):
self.patch(ipaddrs, 'platform', "cygwin")

View File

@ -1,8 +1,11 @@
from __future__ import print_function, absolute_import, unicode_literals
from __future__ import absolute_import, print_function, unicode_literals
from twisted.trial import unittest
from .. import journal
from .._interfaces import IJournal
class Journal(unittest.TestCase):
def test_journal(self):
events = []

File diff suppressed because it is too large Load Diff

View File

@ -1,9 +1,11 @@
from twisted.trial import unittest
from twisted.internet.task import Clock
from twisted.python.failure import Failure
from twisted.trial import unittest
from ..eventual import EventualQueue
from ..observer import OneShotObserver, SequenceObserver
class OneShot(unittest.TestCase):
def test_fire(self):
c = Clock()
@ -119,4 +121,3 @@ class Sequence(unittest.TestCase):
d2 = o.when_next_event()
eq.flush_sync()
self.assertIdentical(self.failureResultOf(d2), f)

View File

@ -1,39 +1,44 @@
from __future__ import print_function, absolute_import, unicode_literals
import mock
from __future__ import absolute_import, print_function, unicode_literals
from itertools import count
from twisted.trial import unittest
from twisted.internet import reactor
from twisted.internet.defer import inlineCallbacks
from twisted.internet.threads import deferToThread
from .._rlcompleter import (input_with_completion,
_input_code_with_completion,
CodeInputter, warn_readline)
from ..errors import KeyFormatError, AlreadyInputNameplateError
from twisted.trial import unittest
import mock
from .._rlcompleter import (CodeInputter, _input_code_with_completion,
input_with_completion, warn_readline)
from ..errors import AlreadyInputNameplateError, KeyFormatError
APPID = "appid"
class Input(unittest.TestCase):
@inlineCallbacks
def test_wrapper(self):
helper = object()
trueish = object()
with mock.patch("wormhole._rlcompleter._input_code_with_completion",
return_value=trueish) as m:
used_completion = yield input_with_completion("prompt:", helper,
reactor)
with mock.patch(
"wormhole._rlcompleter._input_code_with_completion",
return_value=trueish) as m:
used_completion = yield input_with_completion(
"prompt:", helper, reactor)
self.assertIs(used_completion, trueish)
self.assertEqual(m.mock_calls,
[mock.call("prompt:", helper, reactor)])
self.assertEqual(m.mock_calls, [mock.call("prompt:", helper, reactor)])
# note: if this test fails, the warn_readline() message will probably
# get written to stderr
class Sync(unittest.TestCase):
# exercise _input_code_with_completion, which uses the blocking builtin
# "input()" function, hence _input_code_with_completion is usually in a
# thread with deferToThread
@mock.patch("wormhole._rlcompleter.CodeInputter")
@mock.patch("wormhole._rlcompleter.readline",
__doc__="I am GNU readline")
@mock.patch("wormhole._rlcompleter.readline", __doc__="I am GNU readline")
@mock.patch("wormhole._rlcompleter.input", return_value="code")
def test_readline(self, input, readline, ci):
c = mock.Mock(name="inhibit parenting")
@ -49,17 +54,17 @@ class Sync(unittest.TestCase):
self.assertEqual(ci.mock_calls, [mock.call(input_helper, reactor)])
self.assertEqual(c.mock_calls, [mock.call.finish("code")])
self.assertEqual(input.mock_calls, [mock.call(prompt)])
self.assertEqual(readline.mock_calls,
[mock.call.parse_and_bind("tab: complete"),
mock.call.set_completer(c.completer),
mock.call.set_completer_delims(""),
])
self.assertEqual(readline.mock_calls, [
mock.call.parse_and_bind("tab: complete"),
mock.call.set_completer(c.completer),
mock.call.set_completer_delims(""),
])
@mock.patch("wormhole._rlcompleter.CodeInputter")
@mock.patch("wormhole._rlcompleter.readline")
@mock.patch("wormhole._rlcompleter.input", return_value="code")
def test_readline_no_docstring(self, input, readline, ci):
del readline.__doc__ # when in doubt, it assumes GNU readline
del readline.__doc__ # when in doubt, it assumes GNU readline
c = mock.Mock(name="inhibit parenting")
c.completer = object()
trueish = object()
@ -73,15 +78,14 @@ class Sync(unittest.TestCase):
self.assertEqual(ci.mock_calls, [mock.call(input_helper, reactor)])
self.assertEqual(c.mock_calls, [mock.call.finish("code")])
self.assertEqual(input.mock_calls, [mock.call(prompt)])
self.assertEqual(readline.mock_calls,
[mock.call.parse_and_bind("tab: complete"),
mock.call.set_completer(c.completer),
mock.call.set_completer_delims(""),
])
self.assertEqual(readline.mock_calls, [
mock.call.parse_and_bind("tab: complete"),
mock.call.set_completer(c.completer),
mock.call.set_completer_delims(""),
])
@mock.patch("wormhole._rlcompleter.CodeInputter")
@mock.patch("wormhole._rlcompleter.readline",
__doc__="I am libedit")
@mock.patch("wormhole._rlcompleter.readline", __doc__="I am libedit")
@mock.patch("wormhole._rlcompleter.input", return_value="code")
def test_libedit(self, input, readline, ci):
c = mock.Mock(name="inhibit parenting")
@ -97,11 +101,11 @@ class Sync(unittest.TestCase):
self.assertEqual(ci.mock_calls, [mock.call(input_helper, reactor)])
self.assertEqual(c.mock_calls, [mock.call.finish("code")])
self.assertEqual(input.mock_calls, [mock.call(prompt)])
self.assertEqual(readline.mock_calls,
[mock.call.parse_and_bind("bind ^I rl_complete"),
mock.call.set_completer(c.completer),
mock.call.set_completer_delims(""),
])
self.assertEqual(readline.mock_calls, [
mock.call.parse_and_bind("bind ^I rl_complete"),
mock.call.set_completer(c.completer),
mock.call.set_completer_delims(""),
])
@mock.patch("wormhole._rlcompleter.CodeInputter")
@mock.patch("wormhole._rlcompleter.readline", None)
@ -139,6 +143,7 @@ class Sync(unittest.TestCase):
self.assertEqual(c.mock_calls, [mock.call.finish(u"code")])
self.assertEqual(input.mock_calls, [mock.call(prompt)])
def get_completions(c, prefix):
completions = []
for state in count(0):
@ -147,9 +152,11 @@ def get_completions(c, prefix):
return completions
completions.append(text)
def fake_blockingCallFromThread(f, *a, **kw):
return f(*a, **kw)
class Completion(unittest.TestCase):
def test_simple(self):
# no actual completion
@ -158,12 +165,14 @@ class Completion(unittest.TestCase):
c.bcft = fake_blockingCallFromThread
c.finish("1-code-ghost")
self.assertFalse(c.used_completion)
self.assertEqual(helper.mock_calls,
[mock.call.choose_nameplate("1"),
mock.call.choose_words("code-ghost")])
self.assertEqual(helper.mock_calls, [
mock.call.choose_nameplate("1"),
mock.call.choose_words("code-ghost")
])
@mock.patch("wormhole._rlcompleter.readline",
get_completion_type=mock.Mock(return_value=0))
@mock.patch(
"wormhole._rlcompleter.readline",
get_completion_type=mock.Mock(return_value=0))
def test_call(self, readline):
# check that it calls _commit_and_build_completions correctly
helper = mock.Mock()
@ -188,14 +197,14 @@ class Completion(unittest.TestCase):
# now we have three "a" words: "and", "ark", "aaah!zombies!!"
cabc.reset_mock()
cabc.configure_mock(return_value=["aargh", "ark", "aaah!zombies!!"])
self.assertEqual(get_completions(c, "12-a"),
["aargh", "ark", "aaah!zombies!!"])
self.assertEqual(
get_completions(c, "12-a"), ["aargh", "ark", "aaah!zombies!!"])
self.assertEqual(cabc.mock_calls, [mock.call("12-a")])
cabc.reset_mock()
cabc.configure_mock(return_value=["aargh", "aaah!zombies!!"])
self.assertEqual(get_completions(c, "12-aa"),
["aargh", "aaah!zombies!!"])
self.assertEqual(
get_completions(c, "12-aa"), ["aargh", "aaah!zombies!!"])
self.assertEqual(cabc.mock_calls, [mock.call("12-aa")])
cabc.reset_mock()
@ -223,16 +232,17 @@ class Completion(unittest.TestCase):
def test_build_completions(self):
rn = mock.Mock()
# InputHelper.get_nameplate_completions returns just the suffixes
gnc = mock.Mock() # get_nameplate_completions
cn = mock.Mock() # choose_nameplate
gwc = mock.Mock() # get_word_completions
cw = mock.Mock() # choose_words
helper = mock.Mock(refresh_nameplates=rn,
get_nameplate_completions=gnc,
choose_nameplate=cn,
get_word_completions=gwc,
choose_words=cw,
)
gnc = mock.Mock() # get_nameplate_completions
cn = mock.Mock() # choose_nameplate
gwc = mock.Mock() # get_word_completions
cw = mock.Mock() # choose_words
helper = mock.Mock(
refresh_nameplates=rn,
get_nameplate_completions=gnc,
choose_nameplate=cn,
get_word_completions=gwc,
choose_words=cw,
)
# this needs a real reactor, for blockingCallFromThread
c = CodeInputter(helper, reactor)
cabc = c._commit_and_build_completions
@ -327,17 +337,18 @@ class Completion(unittest.TestCase):
gwc.configure_mock(return_value=["code", "court"])
c = CodeInputter(helper, reactor)
cabc = c._commit_and_build_completions
matches = yield deferToThread(cabc, "1-co") # this commits us to 1-
self.assertEqual(helper.mock_calls,
[mock.call.choose_nameplate("1"),
mock.call.when_wordlist_is_available(),
mock.call.get_word_completions("co")])
matches = yield deferToThread(cabc, "1-co") # this commits us to 1-
self.assertEqual(helper.mock_calls, [
mock.call.choose_nameplate("1"),
mock.call.when_wordlist_is_available(),
mock.call.get_word_completions("co")
])
self.assertEqual(matches, ["1-code", "1-court"])
helper.reset_mock()
with self.assertRaises(AlreadyInputNameplateError) as e:
yield deferToThread(cabc, "2-co")
self.assertEqual(str(e.exception),
"nameplate (1-) already entered, cannot go back")
self.assertEqual(
str(e.exception), "nameplate (1-) already entered, cannot go back")
self.assertEqual(helper.mock_calls, [])
@inlineCallbacks
@ -347,17 +358,18 @@ class Completion(unittest.TestCase):
gwc.configure_mock(return_value=["code", "court"])
c = CodeInputter(helper, reactor)
cabc = c._commit_and_build_completions
matches = yield deferToThread(cabc, "1-co") # this commits us to 1-
self.assertEqual(helper.mock_calls,
[mock.call.choose_nameplate("1"),
mock.call.when_wordlist_is_available(),
mock.call.get_word_completions("co")])
matches = yield deferToThread(cabc, "1-co") # this commits us to 1-
self.assertEqual(helper.mock_calls, [
mock.call.choose_nameplate("1"),
mock.call.when_wordlist_is_available(),
mock.call.get_word_completions("co")
])
self.assertEqual(matches, ["1-code", "1-court"])
helper.reset_mock()
with self.assertRaises(AlreadyInputNameplateError) as e:
yield deferToThread(c.finish, "2-code")
self.assertEqual(str(e.exception),
"nameplate (1-) already entered, cannot go back")
self.assertEqual(
str(e.exception), "nameplate (1-) already entered, cannot go back")
self.assertEqual(helper.mock_calls, [])
@mock.patch("wormhole._rlcompleter.stderr")
@ -366,6 +378,7 @@ class Completion(unittest.TestCase):
# right time, since it involves a reactor and a "system event
# trigger", but let's at least make sure it's invocable
warn_readline()
expected ="\nCommand interrupted: please press Return to quit"
self.assertEqual(stderr.mock_calls, [mock.call.write(expected),
mock.call.write("\n")])
expected = "\nCommand interrupted: please press Return to quit"
self.assertEqual(stderr.mock_calls,
[mock.call.write(expected),
mock.call.write("\n")])

View File

@ -1,10 +1,15 @@
import os, io
import mock
import io
import os
from twisted.trial import unittest
import mock
from ..cli import cmd_ssh
OTHERS = ["config", "config~", "known_hosts", "known_hosts~"]
class FindPubkey(unittest.TestCase):
def test_find_one(self):
files = OTHERS + ["id_rsa.pub", "id_rsa"]
@ -12,8 +17,8 @@ class FindPubkey(unittest.TestCase):
pubkey_file = io.StringIO(pubkey_data)
with mock.patch("wormhole.cli.cmd_ssh.exists", return_value=True):
with mock.patch("os.listdir", return_value=files) as ld:
with mock.patch("wormhole.cli.cmd_ssh.open",
return_value=pubkey_file):
with mock.patch(
"wormhole.cli.cmd_ssh.open", return_value=pubkey_file):
res = cmd_ssh.find_public_key()
self.assertEqual(ld.mock_calls,
[mock.call(os.path.expanduser("~/.ssh/"))])
@ -24,7 +29,7 @@ class FindPubkey(unittest.TestCase):
self.assertEqual(pubkey, pubkey_data)
def test_find_none(self):
files = OTHERS # no pubkey
files = OTHERS # no pubkey
with mock.patch("wormhole.cli.cmd_ssh.exists", return_value=True):
with mock.patch("os.listdir", return_value=files):
e = self.assertRaises(cmd_ssh.PubkeyError,
@ -34,12 +39,12 @@ class FindPubkey(unittest.TestCase):
def test_bad_hint(self):
with mock.patch("wormhole.cli.cmd_ssh.exists", return_value=False):
e = self.assertRaises(cmd_ssh.PubkeyError,
cmd_ssh.find_public_key,
hint="bogus/path")
e = self.assertRaises(
cmd_ssh.PubkeyError,
cmd_ssh.find_public_key,
hint="bogus/path")
self.assertEqual(str(e), "Can't find 'bogus/path'")
def test_find_multiple(self):
files = OTHERS + ["id_rsa.pub", "id_rsa", "id_dsa.pub", "id_dsa"]
pubkey_data = u"ssh-rsa AAAAkeystuff email@host\n"
@ -47,10 +52,11 @@ class FindPubkey(unittest.TestCase):
with mock.patch("wormhole.cli.cmd_ssh.exists", return_value=True):
with mock.patch("os.listdir", return_value=files):
responses = iter(["frog", "NaN", "-1", "0"])
with mock.patch("click.prompt",
side_effect=lambda p: next(responses)):
with mock.patch("wormhole.cli.cmd_ssh.open",
return_value=pubkey_file):
with mock.patch(
"click.prompt", side_effect=lambda p: next(responses)):
with mock.patch(
"wormhole.cli.cmd_ssh.open",
return_value=pubkey_file):
res = cmd_ssh.find_public_key()
self.assertEqual(len(res), 3, res)
kind, keyid, pubkey = res

View File

@ -1,42 +1,51 @@
from __future__ import print_function, unicode_literals
import mock, io
from twisted.trial import unittest
import io
from twisted.internet import defer
from twisted.internet.error import ConnectError
from twisted.trial import unittest
import mock
from ..tor_manager import get_tor, SocksOnlyTor
from ..errors import NoTorError
from .._interfaces import ITorManager
from ..errors import NoTorError
from ..tor_manager import SocksOnlyTor, get_tor
class X():
pass
class Tor(unittest.TestCase):
def test_no_txtorcon(self):
with mock.patch("wormhole.tor_manager.txtorcon", None):
self.failureResultOf(get_tor(None), NoTorError)
def test_bad_args(self):
f = self.failureResultOf(get_tor(None, launch_tor="not boolean"),
TypeError)
f = self.failureResultOf(
get_tor(None, launch_tor="not boolean"), TypeError)
self.assertEqual(str(f.value), "launch_tor= must be boolean")
f = self.failureResultOf(get_tor(None, tor_control_port=1234),
TypeError)
f = self.failureResultOf(
get_tor(None, tor_control_port=1234), TypeError)
self.assertEqual(str(f.value), "tor_control_port= must be str or None")
f = self.failureResultOf(get_tor(None, launch_tor=True,
tor_control_port="tcp:127.0.0.1:1234"),
ValueError)
self.assertEqual(str(f.value),
"cannot combine --launch-tor and --tor-control-port=")
f = self.failureResultOf(
get_tor(
None, launch_tor=True, tor_control_port="tcp:127.0.0.1:1234"),
ValueError)
self.assertEqual(
str(f.value),
"cannot combine --launch-tor and --tor-control-port=")
def test_launch(self):
reactor = object()
my_tor = X() # object() didn't like providedBy()
my_tor = X() # object() didn't like providedBy()
launch_d = defer.Deferred()
stderr = io.StringIO()
with mock.patch("wormhole.tor_manager.txtorcon.launch",
side_effect=launch_d) as launch:
with mock.patch(
"wormhole.tor_manager.txtorcon.launch",
side_effect=launch_d) as launch:
d = get_tor(reactor, launch_tor=True, stderr=stderr)
self.assertNoResult(d)
self.assertEqual(launch.mock_calls, [mock.call(reactor)])
@ -44,18 +53,21 @@ class Tor(unittest.TestCase):
tor = self.successResultOf(d)
self.assertIs(tor, my_tor)
self.assert_(ITorManager.providedBy(tor))
self.assertEqual(stderr.getvalue(),
" launching a new Tor process, this may take a while..\n")
self.assertEqual(
stderr.getvalue(),
" launching a new Tor process, this may take a while..\n")
def test_connect(self):
reactor = object()
my_tor = X() # object() didn't like providedBy()
my_tor = X() # object() didn't like providedBy()
connect_d = defer.Deferred()
stderr = io.StringIO()
with mock.patch("wormhole.tor_manager.txtorcon.connect",
side_effect=connect_d) as connect:
with mock.patch("wormhole.tor_manager.clientFromString",
side_effect=["foo"]) as sfs:
with mock.patch(
"wormhole.tor_manager.txtorcon.connect",
side_effect=connect_d) as connect:
with mock.patch(
"wormhole.tor_manager.clientFromString",
side_effect=["foo"]) as sfs:
d = get_tor(reactor, stderr=stderr)
self.assertEqual(sfs.mock_calls, [])
self.assertNoResult(d)
@ -71,10 +83,12 @@ class Tor(unittest.TestCase):
reactor = object()
connect_d = defer.Deferred()
stderr = io.StringIO()
with mock.patch("wormhole.tor_manager.txtorcon.connect",
side_effect=connect_d) as connect:
with mock.patch("wormhole.tor_manager.clientFromString",
side_effect=["foo"]) as sfs:
with mock.patch(
"wormhole.tor_manager.txtorcon.connect",
side_effect=connect_d) as connect:
with mock.patch(
"wormhole.tor_manager.clientFromString",
side_effect=["foo"]) as sfs:
d = get_tor(reactor, stderr=stderr)
self.assertEqual(sfs.mock_calls, [])
self.assertNoResult(d)
@ -85,20 +99,23 @@ class Tor(unittest.TestCase):
self.assertIsInstance(tor, SocksOnlyTor)
self.assert_(ITorManager.providedBy(tor))
self.assertEqual(tor._reactor, reactor)
self.assertEqual(stderr.getvalue(),
" unable to find default Tor control port, using SOCKS\n")
self.assertEqual(
stderr.getvalue(),
" unable to find default Tor control port, using SOCKS\n")
def test_connect_custom_control_port(self):
reactor = object()
my_tor = X() # object() didn't like providedBy()
my_tor = X() # object() didn't like providedBy()
tcp = "PORT"
ep = object()
connect_d = defer.Deferred()
stderr = io.StringIO()
with mock.patch("wormhole.tor_manager.txtorcon.connect",
side_effect=connect_d) as connect:
with mock.patch("wormhole.tor_manager.clientFromString",
side_effect=[ep]) as sfs:
with mock.patch(
"wormhole.tor_manager.txtorcon.connect",
side_effect=connect_d) as connect:
with mock.patch(
"wormhole.tor_manager.clientFromString",
side_effect=[ep]) as sfs:
d = get_tor(reactor, tor_control_port=tcp, stderr=stderr)
self.assertEqual(sfs.mock_calls, [mock.call(reactor, tcp)])
self.assertNoResult(d)
@ -116,10 +133,12 @@ class Tor(unittest.TestCase):
ep = object()
connect_d = defer.Deferred()
stderr = io.StringIO()
with mock.patch("wormhole.tor_manager.txtorcon.connect",
side_effect=connect_d) as connect:
with mock.patch("wormhole.tor_manager.clientFromString",
side_effect=[ep]) as sfs:
with mock.patch(
"wormhole.tor_manager.txtorcon.connect",
side_effect=connect_d) as connect:
with mock.patch(
"wormhole.tor_manager.clientFromString",
side_effect=[ep]) as sfs:
d = get_tor(reactor, tor_control_port=tcp, stderr=stderr)
self.assertEqual(sfs.mock_calls, [mock.call(reactor, tcp)])
self.assertNoResult(d)
@ -129,18 +148,22 @@ class Tor(unittest.TestCase):
self.failureResultOf(d, ConnectError)
self.assertEqual(stderr.getvalue(), "")
class SocksOnly(unittest.TestCase):
def test_tor(self):
reactor = object()
sot = SocksOnlyTor(reactor)
fake_ep = object()
with mock.patch("wormhole.tor_manager.txtorcon.TorClientEndpoint",
return_value=fake_ep) as tce:
with mock.patch(
"wormhole.tor_manager.txtorcon.TorClientEndpoint",
return_value=fake_ep) as tce:
ep = sot.stream_via("host", "port")
self.assertIs(ep, fake_ep)
self.assertEqual(tce.mock_calls, [mock.call("host", "port",
socks_endpoint=None,
tls=False,
reactor=reactor)])
self.assertEqual(tce.mock_calls, [
mock.call(
"host",
"port",
socks_endpoint=None,
tls=False,
reactor=reactor)
])

File diff suppressed because it is too large Load Diff

View File

@ -1,10 +1,15 @@
from __future__ import unicode_literals
import six
import mock
import unicodedata
import six
from twisted.trial import unittest
import mock
from .. import util
class Utils(unittest.TestCase):
def test_to_bytes(self):
b = util.to_bytes("abc")
@ -41,11 +46,12 @@ class Utils(unittest.TestCase):
self.assertIsInstance(d, dict)
self.assertEqual(d, {"a": "b", "c": 2})
class Space(unittest.TestCase):
def test_free_space(self):
free = util.estimate_free_space(".")
self.assert_(isinstance(free, six.integer_types + (type(None),)),
repr(free))
self.assert_(
isinstance(free, six.integer_types + (type(None), )), repr(free))
# some platforms (I think the VMs used by travis are in this
# category) return 0, and windows will return None, so don't assert
# anything more specific about the return value
@ -56,5 +62,5 @@ class Space(unittest.TestCase):
try:
with mock.patch("os.statvfs", side_effect=AttributeError()):
self.assertEqual(util.estimate_free_space("."), None)
except AttributeError: # raised by mock.get_original()
except AttributeError: # raised by mock.get_original()
pass

View File

@ -1,8 +1,12 @@
from __future__ import print_function, unicode_literals
import mock
from twisted.trial import unittest
import mock
from .._wordlist import PGPWordList
class Completions(unittest.TestCase):
def test_completions(self):
wl = PGPWordList()
@ -14,16 +18,21 @@ class Completions(unittest.TestCase):
self.assertEqual(len(lots), 256, lots)
first = list(lots)[0]
self.assert_(first.startswith("armistice-"), first)
self.assertEqual(gc("armistice-ba", 2),
{"armistice-baboon", "armistice-backfield",
"armistice-backward", "armistice-banjo"})
self.assertEqual(gc("armistice-ba", 3),
{"armistice-baboon-", "armistice-backfield-",
"armistice-backward-", "armistice-banjo-"})
self.assertEqual(
gc("armistice-ba", 2), {
"armistice-baboon", "armistice-backfield",
"armistice-backward", "armistice-banjo"
})
self.assertEqual(
gc("armistice-ba", 3), {
"armistice-baboon-", "armistice-backfield-",
"armistice-backward-", "armistice-banjo-"
})
self.assertEqual(gc("armistice-baboon", 2), {"armistice-baboon"})
self.assertEqual(gc("armistice-baboon", 3), {"armistice-baboon-"})
self.assertEqual(gc("armistice-baboon", 4), {"armistice-baboon-"})
class Choose(unittest.TestCase):
def test_choose_words(self):
wl = PGPWordList()

View File

@ -1,17 +1,22 @@
from __future__ import print_function, unicode_literals
import io, re
import mock
from twisted.trial import unittest
import io
import re
from twisted.internet import reactor
from twisted.internet.defer import gatherResults, inlineCallbacks, returnValue
from twisted.internet.error import ConnectionRefusedError
from .common import ServerBase, poll_until
from .. import wormhole, _rendezvous
from ..errors import (WrongPasswordError, ServerConnectionError,
KeyFormatError, WormholeClosed, LonelyError,
NoKeyError, OnlyOneCodeError)
from ..transit import allocate_tcp_port
from twisted.trial import unittest
import mock
from .. import _rendezvous, wormhole
from ..errors import (KeyFormatError, LonelyError, NoKeyError,
OnlyOneCodeError, ServerConnectionError, WormholeClosed,
WrongPasswordError)
from ..eventual import EventualQueue
from ..transit import allocate_tcp_port
from .common import ServerBase, poll_until
APPID = "appid"
@ -26,6 +31,7 @@ APPID = "appid"
# * set_code, then connected
# * connected, receive_pake, send_phase, set_code
class Delegate:
def __init__(self):
self.welcome = None
@ -35,28 +41,35 @@ class Delegate:
self.versions = None
self.messages = []
self.closed = None
def wormhole_got_welcome(self, welcome):
self.welcome = welcome
def wormhole_got_code(self, code):
self.code = code
def wormhole_got_unverified_key(self, key):
self.key = key
def wormhole_got_verifier(self, verifier):
self.verifier = verifier
def wormhole_got_versions(self, versions):
self.versions = versions
def wormhole_got_message(self, data):
self.messages.append(data)
def wormhole_closed(self, result):
self.closed = result
class Delegated(ServerBase, unittest.TestCase):
class Delegated(ServerBase, unittest.TestCase):
@inlineCallbacks
def test_delegated(self):
dg = Delegate()
w1 = wormhole.create(APPID, self.relayurl, reactor, delegate=dg)
#w1.debug_set_trace("W1")
# w1.debug_set_trace("W1")
with self.assertRaises(NoKeyError):
w1.derive_key("purpose", 12)
w1.set_code("1-abc")
@ -103,6 +116,7 @@ class Delegated(ServerBase, unittest.TestCase):
yield poll_until(lambda: dg.code is not None)
w1.close()
class Wormholes(ServerBase, unittest.TestCase):
# integration test, with a real server
@ -135,12 +149,12 @@ class Wormholes(ServerBase, unittest.TestCase):
@inlineCallbacks
def test_basic(self):
w1 = wormhole.create(APPID, self.relayurl, reactor)
#w1.debug_set_trace("W1")
# w1.debug_set_trace("W1")
with self.assertRaises(NoKeyError):
w1.derive_key("purpose", 12)
w2 = wormhole.create(APPID, self.relayurl, reactor)
#w2.debug_set_trace(" W2")
# w2.debug_set_trace(" W2")
w1.allocate_code()
code = yield w1.get_code()
w2.set_code(code)
@ -302,7 +316,6 @@ class Wormholes(ServerBase, unittest.TestCase):
yield w1.close()
yield w2.close()
@inlineCallbacks
def test_multiple_messages(self):
w1 = wormhole.create(APPID, self.relayurl, reactor)
@ -322,7 +335,6 @@ class Wormholes(ServerBase, unittest.TestCase):
yield w1.close()
yield w2.close()
@inlineCallbacks
def test_closed(self):
eq = EventualQueue(reactor)
@ -377,7 +389,7 @@ class Wormholes(ServerBase, unittest.TestCase):
w2 = wormhole.create(APPID, self.relayurl, reactor, _eventual_queue=eq)
w1.allocate_code()
code = yield w1.get_code()
w2.set_code(code+"not")
w2.set_code(code + "not")
code2 = yield w2.get_code()
self.assertNotEqual(code, code2)
# That's enough to allow both sides to discover the mismatch, but
@ -387,9 +399,9 @@ class Wormholes(ServerBase, unittest.TestCase):
w1.send_message(b"should still work")
w2.send_message(b"should still work")
key2 = yield w2.get_unverified_key() # should work
key2 = yield w2.get_unverified_key() # should work
# w2 has just received w1.PAKE, and is about to send w2.VERSION
key1 = yield w1.get_unverified_key() # should work
key1 = yield w1.get_unverified_key() # should work
# w1 has just received w2.PAKE, and is about to send w1.VERSION, and
# then will receive w2.VERSION. When it sees w2.VERSION, it will
# learn about the WrongPasswordError.
@ -451,7 +463,7 @@ class Wormholes(ServerBase, unittest.TestCase):
badcode = "4 oops spaces"
with self.assertRaises(KeyFormatError) as ex:
w.set_code(badcode)
expected_msg = "Code '%s' contains spaces." % (badcode,)
expected_msg = "Code '%s' contains spaces." % (badcode, )
self.assertEqual(expected_msg, str(ex.exception))
yield self.assertFailure(w.close(), LonelyError)
@ -461,7 +473,7 @@ class Wormholes(ServerBase, unittest.TestCase):
badcode = " 4-oops-space"
with self.assertRaises(KeyFormatError) as ex:
w.set_code(badcode)
expected_msg = "Code '%s' contains spaces." % (badcode,)
expected_msg = "Code '%s' contains spaces." % (badcode, )
self.assertEqual(expected_msg, str(ex.exception))
yield self.assertFailure(w.close(), LonelyError)
@ -478,8 +490,8 @@ class Wormholes(ServerBase, unittest.TestCase):
@inlineCallbacks
def test_welcome(self):
w1 = wormhole.create(APPID, self.relayurl, reactor)
wel1 = yield w1.get_welcome() # early: before connection established
wel2 = yield w1.get_welcome() # late: already received welcome
wel1 = yield w1.get_welcome() # early: before connection established
wel2 = yield w1.get_welcome() # late: already received welcome
self.assertEqual(wel1, wel2)
self.assertIn("current_cli_version", wel1)
@ -489,7 +501,7 @@ class Wormholes(ServerBase, unittest.TestCase):
w2.set_code("123-NOT")
yield self.assertFailure(w1.get_verifier(), WrongPasswordError)
yield self.assertFailure(w1.get_welcome(), WrongPasswordError) # late
yield self.assertFailure(w1.get_welcome(), WrongPasswordError) # late
yield self.assertFailure(w1.close(), WrongPasswordError)
yield self.assertFailure(w2.close(), WrongPasswordError)
@ -502,7 +514,7 @@ class Wormholes(ServerBase, unittest.TestCase):
w1.allocate_code()
code = yield w1.get_code()
w2.set_code(code)
v1 = yield w1.get_verifier() # early
v1 = yield w1.get_verifier() # early
v2 = yield w2.get_verifier()
self.failUnlessEqual(type(v1), type(b""))
self.failUnlessEqual(v1, v2)
@ -525,10 +537,10 @@ class Wormholes(ServerBase, unittest.TestCase):
@inlineCallbacks
def test_versions(self):
# there's no API for this yet, but make sure the internals work
w1 = wormhole.create(APPID, self.relayurl, reactor,
versions={"w1": 123})
w2 = wormhole.create(APPID, self.relayurl, reactor,
versions={"w2": 456})
w1 = wormhole.create(
APPID, self.relayurl, reactor, versions={"w1": 123})
w2 = wormhole.create(
APPID, self.relayurl, reactor, versions={"w2": 456})
w1.allocate_code()
code = yield w1.get_code()
w2.set_code(code)
@ -564,17 +576,19 @@ class Wormholes(ServerBase, unittest.TestCase):
yield w1.close()
yield w2.close()
class MessageDoubler(_rendezvous.RendezvousConnector):
# we could double messages on the sending side, but a future server will
# strip those duplicates, so to really exercise the receiver, we must
# double them on the inbound side instead
#def _msg_send(self, phase, body):
# wormhole._Wormhole._msg_send(self, phase, body)
# self._ws_send_command("add", phase=phase, body=bytes_to_hexstr(body))
# def _msg_send(self, phase, body):
# wormhole._Wormhole._msg_send(self, phase, body)
# self._ws_send_command("add", phase=phase, body=bytes_to_hexstr(body))
def _response_handle_message(self, msg):
_rendezvous.RendezvousConnector._response_handle_message(self, msg)
_rendezvous.RendezvousConnector._response_handle_message(self, msg)
class Errors(ServerBase, unittest.TestCase):
@inlineCallbacks
def test_derive_key_early(self):
@ -602,16 +616,17 @@ class Errors(ServerBase, unittest.TestCase):
w.set_code("123-nope")
yield self.assertFailure(w.close(), LonelyError)
class Reconnection(ServerBase, unittest.TestCase):
@inlineCallbacks
def test_basic(self):
w1 = wormhole.create(APPID, self.relayurl, reactor)
w1_in = []
w1._boss._RC._debug_record_inbound_f = w1_in.append
#w1.debug_set_trace("W1")
# w1.debug_set_trace("W1")
w1.allocate_code()
code = yield w1.get_code()
w1.send_message(b"data1") # queued until wormhole is established
w1.send_message(b"data1") # queued until wormhole is established
# now wait until we've deposited all our messages on the server
def seen_our_pake():
@ -619,6 +634,7 @@ class Reconnection(ServerBase, unittest.TestCase):
if m["type"] == "message" and m["phase"] == "pake":
return True
return False
yield poll_until(seen_our_pake)
w1_in[:] = []
@ -634,7 +650,7 @@ class Reconnection(ServerBase, unittest.TestCase):
# receiver has started
w2 = wormhole.create(APPID, self.relayurl, reactor)
#w2.debug_set_trace(" W2")
# w2.debug_set_trace(" W2")
w2.set_code(code)
dataY = yield w2.get_message()
@ -649,6 +665,7 @@ class Reconnection(ServerBase, unittest.TestCase):
c2 = yield w2.close()
self.assertEqual(c2, "happy")
class InitialFailure(unittest.TestCase):
@inlineCallbacks
def assertSCEFailure(self, eq, d, innerType):
@ -662,8 +679,8 @@ class InitialFailure(unittest.TestCase):
def test_bad_dns(self):
eq = EventualQueue(reactor)
# point at a URL that will never connect
w = wormhole.create(APPID, "ws://%%%.example.org:4000/v1",
reactor, _eventual_queue=eq)
w = wormhole.create(
APPID, "ws://%%%.example.org:4000/v1", reactor, _eventual_queue=eq)
# that should have already received an error, when it tried to
# resolve the bogus DNS name. All API calls will return an error.
@ -717,6 +734,7 @@ class InitialFailure(unittest.TestCase):
yield self.assertSCE(d4, ConnectionRefusedError)
yield self.assertSCE(d5, ConnectionRefusedError)
class Trace(unittest.TestCase):
def test_basic(self):
w1 = wormhole.create(APPID, "ws://localhost:1", reactor)
@ -734,13 +752,11 @@ class Trace(unittest.TestCase):
["C1.M1[OLD].IN -> [NEW]"])
out("OUT1")
self.assertEqual(stderr.getvalue().splitlines(),
["C1.M1[OLD].IN -> [NEW]",
" C1.M1.OUT1()"])
["C1.M1[OLD].IN -> [NEW]", " C1.M1.OUT1()"])
w1._boss._print_trace("", "R.connected", "", "C1", "RC1", stderr)
self.assertEqual(stderr.getvalue().splitlines(),
["C1.M1[OLD].IN -> [NEW]",
" C1.M1.OUT1()",
"C1.RC1.R.connected"])
self.assertEqual(
stderr.getvalue().splitlines(),
["C1.M1[OLD].IN -> [NEW]", " C1.M1.OUT1()", "C1.RC1.R.connected"])
def test_delegated(self):
dg = Delegate()

View File

@ -1,11 +1,13 @@
from twisted.trial import unittest
from twisted.internet import reactor, defer
from twisted.internet import defer, reactor
from twisted.internet.defer import inlineCallbacks
from twisted.trial import unittest
from .. import xfer_util
from .common import ServerBase
APPID = u"appid"
class Xfer(ServerBase, unittest.TestCase):
@inlineCallbacks
def test_xfer(self):
@ -24,10 +26,15 @@ class Xfer(ServerBase, unittest.TestCase):
data = u"data"
send_code = []
receive_code = []
d1 = xfer_util.send(reactor, APPID, self.relayurl, data, code,
on_code=send_code.append)
d2 = xfer_util.receive(reactor, APPID, self.relayurl, code,
on_code=receive_code.append)
d1 = xfer_util.send(
reactor,
APPID,
self.relayurl,
data,
code,
on_code=send_code.append)
d2 = xfer_util.receive(
reactor, APPID, self.relayurl, code, on_code=receive_code.append)
send_result = yield d1
receive_result = yield d2
self.assertEqual(send_code, [code])
@ -39,8 +46,13 @@ class Xfer(ServerBase, unittest.TestCase):
def test_make_code(self):
data = u"data"
got_code = defer.Deferred()
d1 = xfer_util.send(reactor, APPID, self.relayurl, data, code=None,
on_code=got_code.callback)
d1 = xfer_util.send(
reactor,
APPID,
self.relayurl,
data,
code=None,
on_code=got_code.callback)
code = yield got_code
d2 = xfer_util.receive(reactor, APPID, self.relayurl, code)
send_result = yield d1

View File

@ -1,8 +1,13 @@
from __future__ import print_function, absolute_import, unicode_literals
import json, time
from __future__ import absolute_import, print_function, unicode_literals
import json
import time
from zope.interface import implementer
from ._interfaces import ITiming
class Event:
def __init__(self, name, when, **details):
# data fields that will be dumped to JSON later
@ -35,6 +40,7 @@ class Event:
else:
self.finish()
@implementer(ITiming)
class DebugTiming:
def __init__(self):
@ -47,11 +53,14 @@ class DebugTiming:
def write(self, fn, stderr):
with open(fn, "wt") as f:
data = [ dict(name=e._name,
start=e._start, stop=e._stop,
details=e._details,
)
for e in self._events ]
data = [
dict(
name=e._name,
start=e._start,
stop=e._stop,
details=e._details,
) for e in self._events
]
json.dump(data, f, indent=1)
f.write("\n")
print("Timing data written to %s" % fn, file=stderr)

View File

@ -1,15 +1,20 @@
from __future__ import print_function, unicode_literals
import sys
from attr import attrs, attrib
from zope.interface.declarations import directlyProvides
from attr import attrib, attrs
from twisted.internet.defer import inlineCallbacks, returnValue
from twisted.internet.endpoints import clientFromString
from zope.interface.declarations import directlyProvides
from . import _interfaces, errors
from .timing import DebugTiming
try:
import txtorcon
except ImportError:
txtorcon = None
from . import _interfaces, errors
from .timing import DebugTiming
@attrs
class SocksOnlyTor(object):
@ -17,15 +22,20 @@ class SocksOnlyTor(object):
def stream_via(self, host, port, tls=False):
return txtorcon.TorClientEndpoint(
host, port,
socks_endpoint=None, # tries localhost:9050 and 9150
host,
port,
socks_endpoint=None, # tries localhost:9050 and 9150
tls=tls,
reactor=self._reactor,
)
@inlineCallbacks
def get_tor(reactor, launch_tor=False, tor_control_port=None,
timing=None, stderr=sys.stderr):
def get_tor(reactor,
launch_tor=False,
tor_control_port=None,
timing=None,
stderr=sys.stderr):
"""
If launch_tor=True, I will try to launch a new Tor process, ask it
for its SOCKS and control ports, and use those for outbound
@ -59,7 +69,7 @@ def get_tor(reactor, launch_tor=False, tor_control_port=None,
if not txtorcon:
raise errors.NoTorError()
if not isinstance(launch_tor, bool): # note: False is int
if not isinstance(launch_tor, bool): # note: False is int
raise TypeError("launch_tor= must be boolean")
if not isinstance(tor_control_port, (type(""), type(None))):
raise TypeError("tor_control_port= must be str or None")
@ -74,19 +84,21 @@ def get_tor(reactor, launch_tor=False, tor_control_port=None,
# need the control port.
if launch_tor:
print(" launching a new Tor process, this may take a while..",
file=stderr)
print(
" launching a new Tor process, this may take a while..",
file=stderr)
with timing.add("launch tor"):
tor = yield txtorcon.launch(reactor,
#data_directory=,
#tor_binary=,
# data_directory=,
# tor_binary=,
)
elif tor_control_port:
with timing.add("find tor"):
control_ep = clientFromString(reactor, tor_control_port)
tor = yield txtorcon.connect(reactor, control_ep) # might raise
print(" using Tor via control port at %s" % tor_control_port,
file=stderr)
tor = yield txtorcon.connect(reactor, control_ep) # might raise
print(
" using Tor via control port at %s" % tor_control_port,
file=stderr)
else:
# Let txtorcon look through a list of usual places. If that fails,
# we'll arrange to attempt the default SOCKS port
@ -98,8 +110,9 @@ def get_tor(reactor, launch_tor=False, tor_control_port=None,
# TODO: make this more specific. I think connect() is
# likely to throw a reactor.connectTCP -type error, like
# ConnectionFailed or ConnectionRefused or something
print(" unable to find default Tor control port, using SOCKS",
file=stderr)
print(
" unable to find default Tor control port, using SOCKS",
file=stderr)
tor = SocksOnlyTor(reactor)
directlyProvides(tor, _interfaces.ITorManager)
returnValue(tor)

View File

@ -1,38 +1,51 @@
# no unicode_literals, revisit after twisted patch
from __future__ import print_function, absolute_import
import os, re, sys, time, socket
from collections import namedtuple, deque
from __future__ import absolute_import, print_function
import os
import re
import socket
import sys
import time
from binascii import hexlify, unhexlify
from collections import deque, namedtuple
import six
from zope.interface import implementer
from twisted.python import log
from twisted.python.runtime import platformType
from twisted.internet import (reactor, interfaces, defer, protocol,
endpoints, task, address, error)
from hkdf import Hkdf
from nacl.secret import SecretBox
from twisted.internet import (address, defer, endpoints, error, interfaces,
protocol, reactor, task)
from twisted.internet.defer import inlineCallbacks, returnValue
from twisted.protocols import policies
from nacl.secret import SecretBox
from hkdf import Hkdf
from twisted.python import log
from twisted.python.runtime import platformType
from zope.interface import implementer
from . import ipaddrs
from .errors import InternalError
from .timing import DebugTiming
from .util import bytes_to_hexstr
from . import ipaddrs
def HKDF(skm, outlen, salt=None, CTXinfo=b""):
return Hkdf(salt, skm).expand(CTXinfo, outlen)
class TransitError(Exception):
pass
class BadHandshake(Exception):
pass
class TransitClosed(TransitError):
pass
class BadNonce(TransitError):
pass
# The beginning of each TCP connection consists of the following handshake
# messages. The sender transmits the same text regardless of whether it is on
# the initiating/connecting end of the TCP connection, or on the
@ -63,19 +76,23 @@ class BadNonce(TransitError):
# RXID_HEX ready\n\n" and then makes a first/not-first decision about sending
# "go\n" or "nevermind\n"+close().
def build_receiver_handshake(key):
hexid = HKDF(key, 32, CTXinfo=b"transit_receiver")
return b"transit receiver "+hexlify(hexid)+b" ready\n\n"
return b"transit receiver " + hexlify(hexid) + b" ready\n\n"
def build_sender_handshake(key):
hexid = HKDF(key, 32, CTXinfo=b"transit_sender")
return b"transit sender "+hexlify(hexid)+b" ready\n\n"
return b"transit sender " + hexlify(hexid) + b" ready\n\n"
def build_sided_relay_handshake(key, side):
assert isinstance(side, type(u""))
assert len(side) == 8*2
assert len(side) == 8 * 2
token = HKDF(key, 32, CTXinfo=b"transit_relay_token")
return b"please relay "+hexlify(token)+b" for side "+side.encode("ascii")+b"\n"
return b"please relay " + hexlify(token) + b" for side " + side.encode(
"ascii") + b"\n"
# These namedtuples are "hint objects". The JSON-serializable dictionaries
@ -87,7 +104,8 @@ def build_sided_relay_handshake(key, side):
# * expect to see the receiver/sender handshake bytes from the other side
# * the sender writes "go\n", the receiver waits for "go\n"
# * the rest of the connection contains transit data
DirectTCPV1Hint = namedtuple("DirectTCPV1Hint", ["hostname", "port", "priority"])
DirectTCPV1Hint = namedtuple("DirectTCPV1Hint",
["hostname", "port", "priority"])
TorTCPV1Hint = namedtuple("TorTCPV1Hint", ["hostname", "port", "priority"])
# RelayV1Hint contains a tuple of DirectTCPV1Hint and TorTCPV1Hint hints (we
# use a tuple rather than a list so they'll be hashable into a set). For each
@ -95,6 +113,7 @@ TorTCPV1Hint = namedtuple("TorTCPV1Hint", ["hostname", "port", "priority"])
# rest of the V1 protocol. Only one hint per relay is useful.
RelayV1Hint = namedtuple("RelayV1Hint", ["hints"])
def describe_hint_obj(hint):
if isinstance(hint, DirectTCPV1Hint):
return u"tcp:%s:%d" % (hint.hostname, hint.port)
@ -103,27 +122,30 @@ def describe_hint_obj(hint):
else:
return str(hint)
def parse_hint_argv(hint, stderr=sys.stderr):
assert isinstance(hint, type(u""))
# return tuple or None for an unparseable hint
priority = 0.0
mo = re.search(r'^([a-zA-Z0-9]+):(.*)$', hint)
if not mo:
print("unparseable hint '%s'" % (hint,), file=stderr)
print("unparseable hint '%s'" % (hint, ), file=stderr)
return None
hint_type = mo.group(1)
if hint_type != "tcp":
print("unknown hint type '%s' in '%s'" % (hint_type, hint), file=stderr)
print(
"unknown hint type '%s' in '%s'" % (hint_type, hint), file=stderr)
return None
hint_value = mo.group(2)
pieces = hint_value.split(":")
if len(pieces) < 2:
print("unparseable TCP hint (need more colons) '%s'" % (hint,),
file=stderr)
print(
"unparseable TCP hint (need more colons) '%s'" % (hint, ),
file=stderr)
return None
mo = re.search(r'^(\d+)$', pieces[1])
if not mo:
print("non-numeric port in TCP hint '%s'" % (hint,), file=stderr)
print("non-numeric port in TCP hint '%s'" % (hint, ), file=stderr)
return None
hint_host = pieces[0]
hint_port = int(pieces[1])
@ -133,12 +155,15 @@ def parse_hint_argv(hint, stderr=sys.stderr):
try:
priority = float(more_pieces[1])
except ValueError:
print("non-float priority= in TCP hint '%s'" % (hint,),
file=stderr)
print(
"non-float priority= in TCP hint '%s'" % (hint, ),
file=stderr)
return None
return DirectTCPV1Hint(hint_host, hint_port, priority)
TIMEOUT = 60 # seconds
TIMEOUT = 60 # seconds
@implementer(interfaces.IProducer, interfaces.IConsumer)
class Connection(protocol.Protocol, policies.TimeoutMixin):
@ -159,7 +184,7 @@ class Connection(protocol.Protocol, policies.TimeoutMixin):
self._waiting_reads = deque()
def connectionMade(self):
self.setTimeout(TIMEOUT) # does timeoutConnection() when it expires
self.setTimeout(TIMEOUT) # does timeoutConnection() when it expires
self.factory.connectionWasMade(self)
def startNegotiation(self):
@ -168,11 +193,11 @@ class Connection(protocol.Protocol, policies.TimeoutMixin):
self.state = "relay"
else:
self.state = "start"
self.dataReceived(b"") # cycle the state machine
self.dataReceived(b"") # cycle the state machine
return self._negotiation_d
def _cancel(self, d):
self.state = "hung up" # stop reacting to anything further
self.state = "hung up" # stop reacting to anything further
self._error = defer.CancelledError()
self.transport.loseConnection()
# if connectionLost isn't called synchronously, then our
@ -181,7 +206,6 @@ class Connection(protocol.Protocol, policies.TimeoutMixin):
if self._negotiation_d:
self._negotiation_d = None
def dataReceived(self, data):
try:
self._dataReceived(data)
@ -198,7 +222,7 @@ class Connection(protocol.Protocol, policies.TimeoutMixin):
if not self.buf.startswith(expected[:len(self.buf)]):
raise BadHandshake("got %r want %r" % (self.buf, expected))
if len(self.buf) < len(expected):
return False # keep waiting
return False # keep waiting
self.buf = self.buf[len(expected):]
return True
@ -245,9 +269,9 @@ class Connection(protocol.Protocol, policies.TimeoutMixin):
return self.dataReceivedRECORDS()
if self.state == "hung up":
return
if isinstance(self.state, Exception): # for tests
if isinstance(self.state, Exception): # for tests
raise self.state
raise ValueError("internal error: unknown state %s" % (self.state,))
raise ValueError("internal error: unknown state %s" % (self.state, ))
def _negotiationSuccessful(self):
self.state = "records"
@ -266,19 +290,20 @@ class Connection(protocol.Protocol, policies.TimeoutMixin):
if len(self.buf) < 4:
return
length = int(hexlify(self.buf[:4]), 16)
if len(self.buf) < 4+length:
if len(self.buf) < 4 + length:
return
encrypted, self.buf = self.buf[4:4+length], self.buf[4+length:]
encrypted, self.buf = self.buf[4:4 + length], self.buf[4 + length:]
record = self._decrypt_record(encrypted)
self.recordReceived(record)
def _decrypt_record(self, encrypted):
nonce_buf = encrypted[:SecretBox.NONCE_SIZE] # assume it's prepended
nonce_buf = encrypted[:SecretBox.NONCE_SIZE] # assume it's prepended
nonce = int(hexlify(nonce_buf), 16)
if nonce != self.next_receive_nonce:
raise BadNonce("received out-of-order record: got %d, expected %d"
% (nonce, self.next_receive_nonce))
raise BadNonce(
"received out-of-order record: got %d, expected %d" %
(nonce, self.next_receive_nonce))
self.next_receive_nonce += 1
record = self.receive_box.decrypt(encrypted)
return record
@ -287,14 +312,15 @@ class Connection(protocol.Protocol, policies.TimeoutMixin):
return self._description
def send_record(self, record):
if not isinstance(record, type(b"")): raise InternalError
if not isinstance(record, type(b"")):
raise InternalError
assert SecretBox.NONCE_SIZE == 24
assert self.send_nonce < 2**(8*24)
assert len(record) < 2**(8*4)
nonce = unhexlify("%048x" % self.send_nonce) # big-endian
assert self.send_nonce < 2**(8 * 24)
assert len(record) < 2**(8 * 4)
nonce = unhexlify("%048x" % self.send_nonce) # big-endian
self.send_nonce += 1
encrypted = self.send_box.encrypt(record, nonce)
length = unhexlify("%08x" % len(encrypted)) # always 4 bytes long
length = unhexlify("%08x" % len(encrypted)) # always 4 bytes long
self.transport.write(length)
self.transport.write(encrypted)
@ -353,8 +379,10 @@ class Connection(protocol.Protocol, policies.TimeoutMixin):
def registerProducer(self, producer, streaming):
assert interfaces.IConsumer.providedBy(self.transport)
self.transport.registerProducer(producer, streaming)
def unregisterProducer(self):
self.transport.unregisterProducer()
def write(self, data):
self.send_record(data)
@ -362,8 +390,10 @@ class Connection(protocol.Protocol, policies.TimeoutMixin):
# the transport.
def stopProducing(self):
self.transport.stopProducing()
def pauseProducing(self):
self.transport.pauseProducing()
def resumeProducing(self):
self.transport.resumeProducing()
@ -384,8 +414,8 @@ class Connection(protocol.Protocol, policies.TimeoutMixin):
Deferred, and you must call disconnectConsumer() when you are done."""
if self._consumer:
raise RuntimeError("A consumer is already attached: %r" %
self._consumer)
raise RuntimeError(
"A consumer is already attached: %r" % self._consumer)
# be aware of an ordering hazard: when we call the consumer's
# .registerProducer method, they are likely to immediately call
@ -440,6 +470,7 @@ class Connection(protocol.Protocol, policies.TimeoutMixin):
fc = FileConsumer(f, progress, hasher)
return self.connectConsumer(fc, expected)
class OutboundConnectionFactory(protocol.ClientFactory):
protocol = Connection
@ -478,7 +509,7 @@ class InboundConnectionFactory(protocol.ClientFactory):
def _shutdown(self):
for d in list(self._pending_connections):
d.cancel() # that fires _remove and _proto_failed
d.cancel() # that fires _remove and _proto_failed
def _describePeer(self, addr):
if isinstance(addr, address.HostnameAddress):
@ -511,6 +542,7 @@ class InboundConnectionFactory(protocol.ClientFactory):
# ignore these two, let Twisted log everything else
f.trap(BadHandshake, defer.CancelledError)
def allocate_tcp_port():
"""Return an (integer) available TCP port on localhost. This briefly
listens on the port in question, then closes it right away."""
@ -527,6 +559,7 @@ def allocate_tcp_port():
s.close()
return port
class _ThereCanBeOnlyOne:
"""Accept a list of contender Deferreds, and return a summary Deferred.
When the first contender fires successfully, cancel the rest and fire the
@ -535,6 +568,7 @@ class _ThereCanBeOnlyOne:
status_cb=?
"""
def __init__(self, contenders):
self._remaining = set(contenders)
self._winner_d = defer.Deferred(self._cancel)
@ -581,26 +615,32 @@ class _ThereCanBeOnlyOne:
else:
self._winner_d.errback(self._first_failure)
def there_can_be_only_one(contenders):
return _ThereCanBeOnlyOne(contenders).run()
class Common:
RELAY_DELAY = 2.0
TRANSIT_KEY_LENGTH = SecretBox.KEY_SIZE
def __init__(self, transit_relay, no_listen=False, tor=None,
reactor=reactor, timing=None):
self._side = bytes_to_hexstr(os.urandom(8)) # unicode
def __init__(self,
transit_relay,
no_listen=False,
tor=None,
reactor=reactor,
timing=None):
self._side = bytes_to_hexstr(os.urandom(8)) # unicode
if transit_relay:
if not isinstance(transit_relay, type(u"")):
raise InternalError
# TODO: allow multiple hints for a single relay
relay_hint = parse_hint_argv(transit_relay)
relay = RelayV1Hint(hints=(relay_hint,))
relay = RelayV1Hint(hints=(relay_hint, ))
self._transit_relays = [relay]
else:
self._transit_relays = []
self._their_direct_hints = [] # hintobjs
self._their_direct_hints = [] # hintobjs
self._our_relay_hints = set(self._transit_relays)
self._tor = tor
self._transit_key = None
@ -622,33 +662,42 @@ class Common:
# some test hosts, including the appveyor VMs, *only* have
# 127.0.0.1, and the tests will hang badly if we remove it.
addresses = non_loopback_addresses
direct_hints = [DirectTCPV1Hint(six.u(addr), portnum, 0.0)
for addr in addresses]
direct_hints = [
DirectTCPV1Hint(six.u(addr), portnum, 0.0) for addr in addresses
]
ep = endpoints.serverFromString(reactor, "tcp:%d" % portnum)
return direct_hints, ep
def get_connection_abilities(self):
return [{u"type": u"direct-tcp-v1"},
{u"type": u"relay-v1"},
]
return [
{
u"type": u"direct-tcp-v1"
},
{
u"type": u"relay-v1"
},
]
@inlineCallbacks
def get_connection_hints(self):
hints = []
direct_hints = yield self._get_direct_hints()
for dh in direct_hints:
hints.append({u"type": u"direct-tcp-v1",
u"priority": dh.priority,
u"hostname": dh.hostname,
u"port": dh.port, # integer
})
hints.append({
u"type": u"direct-tcp-v1",
u"priority": dh.priority,
u"hostname": dh.hostname,
u"port": dh.port, # integer
})
for relay in self._transit_relays:
rhint = {u"type": u"relay-v1", u"hints": []}
for rh in relay.hints:
rhint[u"hints"].append({u"type": u"direct-tcp-v1",
u"priority": rh.priority,
u"hostname": rh.hostname,
u"port": rh.port})
rhint[u"hints"].append({
u"type": u"direct-tcp-v1",
u"priority": rh.priority,
u"hostname": rh.hostname,
u"port": rh.port
})
hints.append(rhint)
returnValue(hints)
@ -665,24 +714,27 @@ class Common:
# listener will win.
self._my_direct_hints, self._listener = self._build_listener()
if self._listener is None: # don't listen
if self._listener is None: # don't listen
self._listener_d = None
return defer.succeed(self._my_direct_hints) # empty
return defer.succeed(self._my_direct_hints) # empty
# Start the server, so it will be running by the time anyone tries to
# connect to the direct hints we return.
f = InboundConnectionFactory(self)
self._listener_f = f # for tests # XX move to __init__ ?
self._listener_f = f # for tests # XX move to __init__ ?
self._listener_d = f.whenDone()
d = self._listener.listen(f)
def _listening(lp):
# lp is an IListeningPort
#self._listener_port = lp # for tests
# self._listener_port = lp # for tests
def _stop_listening(res):
lp.stopListening()
return res
self._listener_d.addBoth(_stop_listening)
return self._my_direct_hints
d.addCallback(_listening)
return d
@ -694,18 +746,18 @@ class Common:
self._listener_d.addErrback(lambda f: None)
self._listener_d.cancel()
def _parse_tcp_v1_hint(self, hint): # hint_struct -> hint_obj
def _parse_tcp_v1_hint(self, hint): # hint_struct -> hint_obj
hint_type = hint.get(u"type", u"")
if hint_type not in [u"direct-tcp-v1", u"tor-tcp-v1"]:
log.msg("unknown hint type: %r" % (hint,))
log.msg("unknown hint type: %r" % (hint, ))
return None
if not(u"hostname" in hint
and isinstance(hint[u"hostname"], type(u""))):
log.msg("invalid hostname in hint: %r" % (hint,))
if not (u"hostname" in hint
and isinstance(hint[u"hostname"], type(u""))):
log.msg("invalid hostname in hint: %r" % (hint, ))
return None
if not(u"port" in hint
and isinstance(hint[u"port"], six.integer_types)):
log.msg("invalid port in hint: %r" % (hint,))
if not (u"port" in hint
and isinstance(hint[u"port"], six.integer_types)):
log.msg("invalid port in hint: %r" % (hint, ))
return None
priority = hint.get(u"priority", 0.0)
if hint_type == u"direct-tcp-v1":
@ -714,12 +766,12 @@ class Common:
return TorTCPV1Hint(hint[u"hostname"], hint[u"port"], priority)
def add_connection_hints(self, hints):
for h in hints: # hint structs
for h in hints: # hint structs
hint_type = h.get(u"type", u"")
if hint_type in [u"direct-tcp-v1", u"tor-tcp-v1"]:
dh = self._parse_tcp_v1_hint(h)
if dh:
self._their_direct_hints.append(dh) # hint_obj
self._their_direct_hints.append(dh) # hint_obj
elif hint_type == u"relay-v1":
# TODO: each relay-v1 clause describes a different relay,
# with a set of equally-valid ways to connect to it. Treat
@ -734,7 +786,7 @@ class Common:
rh = RelayV1Hint(hints=tuple(sorted(relay_hints)))
self._our_relay_hints.add(rh)
else:
log.msg("unknown hint type: %r" % (h,))
log.msg("unknown hint type: %r" % (h, ))
def _send_this(self):
assert self._transit_key
@ -748,25 +800,33 @@ class Common:
if self.is_sender:
return build_receiver_handshake(self._transit_key)
else:
return build_sender_handshake(self._transit_key)# + b"go\n"
return build_sender_handshake(self._transit_key) # + b"go\n"
def _sender_record_key(self):
assert self._transit_key
if self.is_sender:
return HKDF(self._transit_key, SecretBox.KEY_SIZE,
CTXinfo=b"transit_record_sender_key")
return HKDF(
self._transit_key,
SecretBox.KEY_SIZE,
CTXinfo=b"transit_record_sender_key")
else:
return HKDF(self._transit_key, SecretBox.KEY_SIZE,
CTXinfo=b"transit_record_receiver_key")
return HKDF(
self._transit_key,
SecretBox.KEY_SIZE,
CTXinfo=b"transit_record_receiver_key")
def _receiver_record_key(self):
assert self._transit_key
if self.is_sender:
return HKDF(self._transit_key, SecretBox.KEY_SIZE,
CTXinfo=b"transit_record_receiver_key")
return HKDF(
self._transit_key,
SecretBox.KEY_SIZE,
CTXinfo=b"transit_record_receiver_key")
else:
return HKDF(self._transit_key, SecretBox.KEY_SIZE,
CTXinfo=b"transit_record_sender_key")
return HKDF(
self._transit_key,
SecretBox.KEY_SIZE,
CTXinfo=b"transit_record_sender_key")
def set_transit_key(self, key):
assert isinstance(key, type(b"")), type(key)
@ -848,9 +908,13 @@ class Common:
description = "->relay:%s" % describe_hint_obj(hint_obj)
if self._tor:
description = "tor" + description
d = task.deferLater(self._reactor, relay_delay,
self._start_connector, ep, description,
is_relay=True)
d = task.deferLater(
self._reactor,
relay_delay,
self._start_connector,
ep,
description,
is_relay=True)
contenders.append(d)
relay_delay += self.RELAY_DELAY
@ -858,16 +922,18 @@ class Common:
raise TransitError("No contenders for connection")
winner = there_can_be_only_one(contenders)
return self._not_forever(2*TIMEOUT, winner)
return self._not_forever(2 * TIMEOUT, winner)
def _not_forever(self, timeout, d):
"""If the timer fires first, cancel the deferred. If the deferred fires
first, cancel the timer."""
t = self._reactor.callLater(timeout, d.cancel)
def _done(res):
if t.active():
t.cancel()
return res
d.addBoth(_done)
return d
@ -896,8 +962,8 @@ class Common:
return None
return None
if isinstance(hint, DirectTCPV1Hint):
return endpoints.HostnameEndpoint(self._reactor,
hint.hostname, hint.port)
return endpoints.HostnameEndpoint(self._reactor, hint.hostname,
hint.port)
return None
def connection_ready(self, p):
@ -915,9 +981,11 @@ class Common:
self._winner = p
return "go"
class TransitSender(Common):
is_sender = True
class TransitReceiver(Common):
is_sender = False
@ -926,6 +994,7 @@ class TransitReceiver(Common):
# when done, and add a progress function that gets called with the length of
# each write, and a hasher function that gets called with the data.
@implementer(interfaces.IConsumer)
class FileConsumer:
def __init__(self, f, progress=None, hasher=None):
@ -950,6 +1019,7 @@ class FileConsumer:
assert self._producer
self._producer = None
# the TransitSender/Receiver.connect() yields a Connection, on which you can
# do send_record(), but what should the receive API be? set a callback for
# inbound records? get a Deferred for the next record? The producer/consumer

View File

@ -1,30 +1,42 @@
# No unicode_literals
import os, json, unicodedata
import json
import os
import unicodedata
from binascii import hexlify, unhexlify
def to_bytes(u):
return unicodedata.normalize("NFC", u).encode("utf-8")
def bytes_to_hexstr(b):
assert isinstance(b, type(b""))
hexstr = hexlify(b).decode("ascii")
assert isinstance(hexstr, type(u""))
return hexstr
def hexstr_to_bytes(hexstr):
assert isinstance(hexstr, type(u""))
b = unhexlify(hexstr.encode("ascii"))
assert isinstance(b, type(b""))
return b
def dict_to_bytes(d):
assert isinstance(d, dict)
b = json.dumps(d).encode("utf-8")
assert isinstance(b, type(b""))
return b
def bytes_to_dict(b):
assert isinstance(b, type(b""))
d = json.loads(b.decode("utf-8"))
assert isinstance(d, dict)
return d
def estimate_free_space(target):
# f_bfree is the blocks available to a root user. It might be more
# accurate to use f_bavail (blocks available to non-root user), but we

View File

@ -1,19 +1,25 @@
from __future__ import print_function, absolute_import, unicode_literals
import os, sys
from attr import attrs, attrib
from zope.interface import implementer
from __future__ import absolute_import, print_function, unicode_literals
import os
import sys
from attr import attrib, attrs
from twisted.python import failure
from . import __version__
from ._interfaces import IWormhole, IDeferredWormhole
from .util import bytes_to_hexstr
from .eventual import EventualQueue
from .observer import OneShotObserver, SequenceObserver
from .timing import DebugTiming
from .journal import ImmediateJournal
from zope.interface import implementer
from ._boss import Boss
from ._interfaces import IDeferredWormhole, IWormhole
from ._key import derive_key
from .errors import NoKeyError, WormholeClosed
from .util import to_bytes
from .eventual import EventualQueue
from .journal import ImmediateJournal
from .observer import OneShotObserver, SequenceObserver
from .timing import DebugTiming
from .util import bytes_to_hexstr, to_bytes
from ._version import get_versions
__version__ = get_versions()['version']
del get_versions
# We can provide different APIs to different apps:
# * Deferreds
@ -36,6 +42,7 @@ from .util import to_bytes
# wormhole(delegate=app, delegate_prefix="wormhole_",
# delegate_args=(args, kwargs))
@attrs
@implementer(IWormhole)
class _DelegatedWormhole(object):
@ -51,16 +58,18 @@ class _DelegatedWormhole(object):
def allocate_code(self, code_length=2):
self._boss.allocate_code(code_length)
def input_code(self):
return self._boss.input_code()
def set_code(self, code):
self._boss.set_code(code)
## def serialize(self):
## s = {"serialized_wormhole_version": 1,
## "boss": self._boss.serialize(),
## }
## return s
# def serialize(self):
# s = {"serialized_wormhole_version": 1,
# "boss": self._boss.serialize(),
# }
# return s
def send_message(self, plaintext):
self._boss.send(plaintext)
@ -72,34 +81,45 @@ class _DelegatedWormhole(object):
cannot be called until when_verifier() has fired, nor after close()
was called.
"""
if not isinstance(purpose, type("")): raise TypeError(type(purpose))
if not self._key: raise NoKeyError()
if not isinstance(purpose, type("")):
raise TypeError(type(purpose))
if not self._key:
raise NoKeyError()
return derive_key(self._key, to_bytes(purpose), length)
def close(self):
self._boss.close()
def debug_set_trace(self, client_name, which="B N M S O K SK R RC L C T",
def debug_set_trace(self,
client_name,
which="B N M S O K SK R RC L C T",
file=sys.stderr):
self._boss._set_trace(client_name, which, file)
# from below
def got_welcome(self, welcome):
self._delegate.wormhole_got_welcome(welcome)
def got_code(self, code):
self._delegate.wormhole_got_code(code)
def got_key(self, key):
self._delegate.wormhole_got_unverified_key(key)
self._key = key # for derive_key()
self._key = key # for derive_key()
def got_verifier(self, verifier):
self._delegate.wormhole_got_verifier(verifier)
def got_versions(self, versions):
self._delegate.wormhole_got_versions(versions)
def received(self, plaintext):
self._delegate.wormhole_got_message(plaintext)
def closed(self, result):
self._delegate.wormhole_closed(result)
@implementer(IWormhole, IDeferredWormhole)
class _DeferredWormhole(object):
def __init__(self, eq):
@ -142,8 +162,10 @@ class _DeferredWormhole(object):
def allocate_code(self, code_length=2):
self._boss.allocate_code(code_length)
def input_code(self):
return self._boss.input_code()
def set_code(self, code):
self._boss.set_code(code)
@ -159,20 +181,23 @@ class _DeferredWormhole(object):
cannot be called until when_verified() has fired, nor after close()
was called.
"""
if not isinstance(purpose, type("")): raise TypeError(type(purpose))
if not self._key: raise NoKeyError()
if not isinstance(purpose, type("")):
raise TypeError(type(purpose))
if not self._key:
raise NoKeyError()
return derive_key(self._key, to_bytes(purpose), length)
def close(self):
# fails with WormholeError unless we established a connection
# (state=="happy"). Fails with WrongPasswordError (a subclass of
# WormholeError) if state=="scary".
d = self._closed_observer.when_fired() # maybe Failure
d = self._closed_observer.when_fired() # maybe Failure
if not self._closed:
self._boss.close() # only need to close if it wasn't already
self._boss.close() # only need to close if it wasn't already
return d
def debug_set_trace(self, client_name,
def debug_set_trace(self,
client_name,
which="B N M S O K SK R RC L A I C T",
file=sys.stderr):
self._boss._set_trace(client_name, which, file)
@ -180,14 +205,17 @@ class _DeferredWormhole(object):
# from below
def got_welcome(self, welcome):
self._welcome_observer.fire_if_not_fired(welcome)
def got_code(self, code):
self._code_observer.fire_if_not_fired(code)
def got_key(self, key):
self._key = key # for derive_key()
self._key = key # for derive_key()
self._key_observer.fire_if_not_fired(key)
def got_verifier(self, verifier):
self._verifier_observer.fire_if_not_fired(verifier)
def got_versions(self, versions):
self._version_observer.fire_if_not_fired(versions)
@ -196,7 +224,7 @@ class _DeferredWormhole(object):
def closed(self, result):
self._closed = True
#print("closed", result, type(result), file=sys.stderr)
# print("closed", result, type(result), file=sys.stderr)
if isinstance(result, Exception):
# everything pending gets an error, including close()
f = failure.Failure(result)
@ -215,12 +243,17 @@ class _DeferredWormhole(object):
self._received_observer.fire(f)
def create(appid, relay_url, reactor, # use keyword args for everything else
versions={},
delegate=None, journal=None, tor=None,
timing=None,
stderr=sys.stderr,
_eventual_queue=None):
def create(
appid,
relay_url,
reactor, # use keyword args for everything else
versions={},
delegate=None,
journal=None,
tor=None,
timing=None,
stderr=sys.stderr,
_eventual_queue=None):
timing = timing or DebugTiming()
side = bytes_to_hexstr(os.urandom(5))
journal = journal or ImmediateJournal()
@ -229,27 +262,28 @@ def create(appid, relay_url, reactor, # use keyword args for everything else
w = _DelegatedWormhole(delegate)
else:
w = _DeferredWormhole(eq)
wormhole_versions = {} # will be used to indicate Wormhole capabilities
wormhole_versions["app_versions"] = versions # app-specific capabilities
wormhole_versions = {} # will be used to indicate Wormhole capabilities
wormhole_versions["app_versions"] = versions # app-specific capabilities
v = __version__
if isinstance(v, type(b"")):
v = v.decode("utf-8", errors="replace")
client_version = ("python", v)
b = Boss(w, side, relay_url, appid, wormhole_versions, client_version,
reactor, journal, tor, timing)
reactor, journal, tor, timing)
w._set_boss(b)
b.start()
return w
## def from_serialized(serialized, reactor, delegate,
## journal=None, tor=None,
## timing=None, stderr=sys.stderr):
## assert serialized["serialized_wormhole_version"] == 1
## timing = timing or DebugTiming()
## w = _DelegatedWormhole(delegate)
## # now unpack state machines, including the SPAKE2 in Key
## b = Boss.from_serialized(w, serialized["boss"], reactor, journal, timing)
## w._set_boss(b)
## b.start() # ??
## raise NotImplemented
## # should the new Wormhole call got_code? only if it wasn't called before.
# def from_serialized(serialized, reactor, delegate,
# journal=None, tor=None,
# timing=None, stderr=sys.stderr):
# assert serialized["serialized_wormhole_version"] == 1
# timing = timing or DebugTiming()
# w = _DelegatedWormhole(delegate)
# # now unpack state machines, including the SPAKE2 in Key
# b = Boss.from_serialized(w, serialized["boss"], reactor, journal, timing)
# w._set_boss(b)
# b.start() # ??
# raise NotImplemented
# # should the new Wormhole call got_code? only if it wasn't called before.

View File

@ -1,12 +1,19 @@
import json
from twisted.internet.defer import inlineCallbacks, returnValue
from . import wormhole
from .tor_manager import get_tor
@inlineCallbacks
def receive(reactor, appid, relay_url, code,
use_tor=False, launch_tor=False, tor_control_port=None,
def receive(reactor,
appid,
relay_url,
code,
use_tor=False,
launch_tor=False,
tor_control_port=None,
on_code=None):
"""
This is a convenience API which returns a Deferred that callbacks
@ -21,9 +28,11 @@ def receive(reactor, appid, relay_url, code,
:param unicode code: a pre-existing code to use, or None
:param bool use_tor: True if we should use Tor, False to not use it (None for default)
:param bool use_tor: True if we should use Tor, False to not use it (None
for default)
:param on_code: if not None, this is called when we have a code (even if you passed in one explicitly)
:param on_code: if not None, this is called when we have a code (even if
you passed in one explicitly)
:type on_code: single-argument callable
"""
tor = None
@ -48,27 +57,33 @@ def receive(reactor, appid, relay_url, code,
data = json.loads(data.decode("utf-8"))
offer = data.get('offer', None)
if not offer:
raise Exception(
"Do not understand response: {}".format(data)
)
raise Exception("Do not understand response: {}".format(data))
msg = None
if 'message' in offer:
msg = offer['message']
wh.send_message(json.dumps({"answer":
{"message_ack": "ok"}}).encode("utf-8"))
wh.send_message(
json.dumps({
"answer": {
"message_ack": "ok"
}
}).encode("utf-8"))
else:
raise Exception(
"Unknown offer type: {}".format(offer.keys())
)
raise Exception("Unknown offer type: {}".format(offer.keys()))
yield wh.close()
returnValue(msg)
@inlineCallbacks
def send(reactor, appid, relay_url, data, code,
use_tor=False, launch_tor=False, tor_control_port=None,
def send(reactor,
appid,
relay_url,
data,
code,
use_tor=False,
launch_tor=False,
tor_control_port=None,
on_code=None):
"""
This is a convenience API which returns a Deferred that callbacks
@ -83,9 +98,12 @@ def send(reactor, appid, relay_url, data, code,
:param unicode code: a pre-existing code to use, or None
:param bool use_tor: True if we should use Tor, False to not use it (None for default)
:param bool use_tor: True if we should use Tor, False to not use it (None
for default)
:param on_code: if not None, this is called when we have a code (even if
you passed in one explicitly)
:param on_code: if not None, this is called when we have a code (even if you passed in one explicitly)
:type on_code: single-argument callable
"""
tor = None
@ -104,13 +122,7 @@ def send(reactor, appid, relay_url, data, code,
if on_code:
on_code(code)
wh.send_message(
json.dumps({
"offer": {
"message": data
}
}).encode("utf-8")
)
wh.send_message(json.dumps({"offer": {"message": data}}).encode("utf-8"))
data = yield wh.get_message()
data = json.loads(data.decode("utf-8"))
answer = data.get('answer', None)
@ -118,6 +130,4 @@ def send(reactor, appid, relay_url, data, code,
if answer:
returnValue(None)
else:
raise Exception(
"Unknown answer: {}".format(data)
)
raise Exception("Unknown answer: {}".format(data))