Merge PR296: require pep8 formatting

This commit is contained in:
Brian Warner 2018-06-16 16:46:44 -07:00
commit 7278925b11
54 changed files with 3271 additions and 1900 deletions

View File

@ -1,9 +1,4 @@
from ._version import get_versions
__version__ = get_versions()['version']
del get_versions
from .wormhole import create
from ._rlcompleter import input_with_completion from ._rlcompleter import input_with_completion
from .wormhole import create, __version__
__all__ = ["create", "input_with_completion", "__version__"] __all__ = ["create", "input_with_completion", "__version__"]

View File

@ -1,8 +1,6 @@
from wormhole.cli import cli
if __name__ != "__main__": if __name__ != "__main__":
raise ImportError('this module should not be imported') raise ImportError('this module should not be imported')
from wormhole.cli import cli
cli.wormhole() cli.wormhole()

View File

@ -1,56 +1,78 @@
from __future__ import print_function, absolute_import, unicode_literals from __future__ import absolute_import, print_function, unicode_literals
from zope.interface import implementer
from attr import attrs, attrib from attr import attrib, attrs
from attr.validators import provides from attr.validators import provides
from automat import MethodicalMachine from automat import MethodicalMachine
from zope.interface import implementer
from . import _interfaces from . import _interfaces
@attrs @attrs
@implementer(_interfaces.IAllocator) @implementer(_interfaces.IAllocator)
class Allocator(object): class Allocator(object):
_timing = attrib(validator=provides(_interfaces.ITiming)) _timing = attrib(validator=provides(_interfaces.ITiming))
m = MethodicalMachine() m = MethodicalMachine()
set_trace = getattr(m, "_setTrace", lambda self, f: None) # pragma: no cover set_trace = getattr(m, "_setTrace",
lambda self, f: None) # pragma: no cover
def wire(self, rendezvous_connector, code): def wire(self, rendezvous_connector, code):
self._RC = _interfaces.IRendezvousConnector(rendezvous_connector) self._RC = _interfaces.IRendezvousConnector(rendezvous_connector)
self._C = _interfaces.ICode(code) self._C = _interfaces.ICode(code)
@m.state(initial=True) @m.state(initial=True)
def S0A_idle(self): pass # pragma: no cover def S0A_idle(self):
pass # pragma: no cover
@m.state() @m.state()
def S0B_idle_connected(self): pass # pragma: no cover def S0B_idle_connected(self):
pass # pragma: no cover
@m.state() @m.state()
def S1A_allocating(self): pass # pragma: no cover def S1A_allocating(self):
pass # pragma: no cover
@m.state() @m.state()
def S1B_allocating_connected(self): pass # pragma: no cover def S1B_allocating_connected(self):
pass # pragma: no cover
@m.state() @m.state()
def S2_done(self): pass # pragma: no cover def S2_done(self):
pass # pragma: no cover
# from Code # from Code
@m.input() @m.input()
def allocate(self, length, wordlist): pass def allocate(self, length, wordlist):
pass
# from RendezvousConnector # from RendezvousConnector
@m.input() @m.input()
def connected(self): pass def connected(self):
pass
@m.input() @m.input()
def lost(self): pass def lost(self):
pass
@m.input() @m.input()
def rx_allocated(self, nameplate): pass def rx_allocated(self, nameplate):
pass
@m.output() @m.output()
def stash(self, length, wordlist): def stash(self, length, wordlist):
self._length = length self._length = length
self._wordlist = _interfaces.IWordlist(wordlist) self._wordlist = _interfaces.IWordlist(wordlist)
@m.output() @m.output()
def stash_and_RC_rx_allocate(self, length, wordlist): def stash_and_RC_rx_allocate(self, length, wordlist):
self._length = length self._length = length
self._wordlist = _interfaces.IWordlist(wordlist) self._wordlist = _interfaces.IWordlist(wordlist)
self._RC.tx_allocate() self._RC.tx_allocate()
@m.output() @m.output()
def RC_tx_allocate(self): def RC_tx_allocate(self):
self._RC.tx_allocate() self._RC.tx_allocate()
@m.output() @m.output()
def build_and_notify(self, nameplate): def build_and_notify(self, nameplate):
words = self._wordlist.choose_words(self._length) words = self._wordlist.choose_words(self._length)
@ -61,15 +83,17 @@ class Allocator(object):
S0B_idle_connected.upon(lost, enter=S0A_idle, outputs=[]) S0B_idle_connected.upon(lost, enter=S0A_idle, outputs=[])
S0A_idle.upon(allocate, enter=S1A_allocating, outputs=[stash]) S0A_idle.upon(allocate, enter=S1A_allocating, outputs=[stash])
S0B_idle_connected.upon(allocate, enter=S1B_allocating_connected, S0B_idle_connected.upon(
allocate,
enter=S1B_allocating_connected,
outputs=[stash_and_RC_rx_allocate]) outputs=[stash_and_RC_rx_allocate])
S1A_allocating.upon(connected, enter=S1B_allocating_connected, S1A_allocating.upon(
outputs=[RC_tx_allocate]) connected, enter=S1B_allocating_connected, outputs=[RC_tx_allocate])
S1B_allocating_connected.upon(lost, enter=S1A_allocating, outputs=[]) S1B_allocating_connected.upon(lost, enter=S1A_allocating, outputs=[])
S1B_allocating_connected.upon(rx_allocated, enter=S2_done, S1B_allocating_connected.upon(
outputs=[build_and_notify]) rx_allocated, enter=S2_done, outputs=[build_and_notify])
S2_done.upon(connected, enter=S2_done, outputs=[]) S2_done.upon(connected, enter=S2_done, outputs=[])
S2_done.upon(lost, enter=S2_done, outputs=[]) S2_done.upon(lost, enter=S2_done, outputs=[])

View File

@ -1,29 +1,33 @@
from __future__ import print_function, absolute_import, unicode_literals from __future__ import absolute_import, print_function, unicode_literals
import re import re
import six import six
from zope.interface import implementer from attr import attrib, attrs
from attr import attrs, attrib from attr.validators import instance_of, optional, provides
from attr.validators import provides, instance_of, optional
from twisted.python import log
from automat import MethodicalMachine from automat import MethodicalMachine
from twisted.python import log
from zope.interface import implementer
from . import _interfaces from . import _interfaces
from ._nameplate import Nameplate from ._allocator import Allocator
from ._mailbox import Mailbox from ._code import Code, validate_code
from ._send import Send from ._input import Input
from ._order import Order
from ._key import Key from ._key import Key
from ._lister import Lister
from ._mailbox import Mailbox
from ._nameplate import Nameplate
from ._order import Order
from ._receive import Receive from ._receive import Receive
from ._rendezvous import RendezvousConnector from ._rendezvous import RendezvousConnector
from ._lister import Lister from ._send import Send
from ._allocator import Allocator
from ._input import Input
from ._code import Code, validate_code
from ._terminator import Terminator from ._terminator import Terminator
from ._wordlist import PGPWordList from ._wordlist import PGPWordList
from .errors import (ServerError, LonelyError, WrongPasswordError, from .errors import (LonelyError, OnlyOneCodeError, ServerError, WelcomeError,
OnlyOneCodeError, _UnknownPhaseError, WelcomeError) WrongPasswordError, _UnknownPhaseError)
from .util import bytes_to_dict from .util import bytes_to_dict
@attrs @attrs
@implementer(_interfaces.IBoss) @implementer(_interfaces.IBoss)
class Boss(object): class Boss(object):
@ -38,7 +42,8 @@ class Boss(object):
_tor = attrib(validator=optional(provides(_interfaces.ITorManager))) _tor = attrib(validator=optional(provides(_interfaces.ITorManager)))
_timing = attrib(validator=provides(_interfaces.ITiming)) _timing = attrib(validator=provides(_interfaces.ITiming))
m = MethodicalMachine() m = MethodicalMachine()
set_trace = getattr(m, "_setTrace", lambda self, f: None) # pragma: no cover set_trace = getattr(m, "_setTrace",
lambda self, f: None) # pragma: no cover
def __attrs_post_init__(self): def __attrs_post_init__(self):
self._build_workers() self._build_workers()
@ -52,9 +57,8 @@ class Boss(object):
self._K = Key(self._appid, self._versions, self._side, self._timing) self._K = Key(self._appid, self._versions, self._side, self._timing)
self._R = Receive(self._side, self._timing) self._R = Receive(self._side, self._timing)
self._RC = RendezvousConnector(self._url, self._appid, self._side, self._RC = RendezvousConnector(self._url, self._appid, self._side,
self._reactor, self._journal, self._reactor, self._journal, self._tor,
self._tor, self._timing, self._timing, self._client_version)
self._client_version)
self._L = Lister(self._timing) self._L = Lister(self._timing)
self._A = Allocator(self._timing) self._A = Allocator(self._timing)
self._I = Input(self._timing) self._I = Input(self._timing)
@ -86,32 +90,45 @@ class Boss(object):
def start(self): def start(self):
self._RC.start() self._RC.start()
def _print_trace(self, old_state, input, new_state, def _print_trace(self, old_state, input, new_state, client_name, machine,
client_name, machine, file): file):
if new_state: if new_state:
print("%s.%s[%s].%s -> [%s]" % print(
(client_name, machine, old_state, input, "%s.%s[%s].%s -> [%s]" % (client_name, machine, old_state,
new_state), file=file) input, new_state),
file=file)
else: else:
# the RendezvousConnector emits message events as if # the RendezvousConnector emits message events as if
# they were state transitions, except that old_state # they were state transitions, except that old_state
# and new_state are empty strings. "input" is one of # and new_state are empty strings. "input" is one of
# R.connected, R.rx(type phase+side), R.tx(type # R.connected, R.rx(type phase+side), R.tx(type
# phase), R.lost . # phase), R.lost .
print("%s.%s.%s" % (client_name, machine, input), print("%s.%s.%s" % (client_name, machine, input), file=file)
file=file)
file.flush() file.flush()
def output_tracer(output): def output_tracer(output):
print(" %s.%s.%s()" % (client_name, machine, output), print(" %s.%s.%s()" % (client_name, machine, output), file=file)
file=file)
file.flush() file.flush()
return output_tracer return output_tracer
def _set_trace(self, client_name, which, file): def _set_trace(self, client_name, which, file):
names = {"B": self, "N": self._N, "M": self._M, "S": self._S, names = {
"O": self._O, "K": self._K, "SK": self._K._SK, "R": self._R, "B": self,
"RC": self._RC, "L": self._L, "A": self._A, "I": self._I, "N": self._N,
"C": self._C, "T": self._T} "M": self._M,
"S": self._S,
"O": self._O,
"K": self._K,
"SK": self._K._SK,
"R": self._R,
"RC": self._RC,
"L": self._L,
"A": self._A,
"I": self._I,
"C": self._C,
"T": self._T
}
for machine in which.split(): for machine in which.split():
t = (lambda old_state, input, new_state, machine=machine: t = (lambda old_state, input, new_state, machine=machine:
self._print_trace(old_state, input, new_state, self._print_trace(old_state, input, new_state,
@ -121,21 +138,30 @@ class Boss(object):
if machine == "I": if machine == "I":
self._I.set_debug(t) self._I.set_debug(t)
## def serialize(self): # def serialize(self):
## raise NotImplemented # raise NotImplemented
# and these are the state-machine transition functions, which don't take # and these are the state-machine transition functions, which don't take
# args # args
@m.state(initial=True) @m.state(initial=True)
def S0_empty(self): pass # pragma: no cover def S0_empty(self):
pass # pragma: no cover
@m.state() @m.state()
def S1_lonely(self): pass # pragma: no cover def S1_lonely(self):
pass # pragma: no cover
@m.state() @m.state()
def S2_happy(self): pass # pragma: no cover def S2_happy(self):
pass # pragma: no cover
@m.state() @m.state()
def S3_closing(self): pass # pragma: no cover def S3_closing(self):
pass # pragma: no cover
@m.state(terminal=True) @m.state(terminal=True)
def S4_closed(self): pass # pragma: no cover def S4_closed(self):
pass # pragma: no cover
# from the Wormhole # from the Wormhole
@ -155,12 +181,14 @@ class Boss(object):
raise OnlyOneCodeError() raise OnlyOneCodeError()
self._did_start_code = True self._did_start_code = True
return self._C.input_code() return self._C.input_code()
def allocate_code(self, code_length): def allocate_code(self, code_length):
if self._did_start_code: if self._did_start_code:
raise OnlyOneCodeError() raise OnlyOneCodeError()
self._did_start_code = True self._did_start_code = True
wl = PGPWordList() wl = PGPWordList()
self._C.allocate_code(code_length, wl) self._C.allocate_code(code_length, wl)
def set_code(self, code): def set_code(self, code):
validate_code(code) # can raise KeyFormatError validate_code(code) # can raise KeyFormatError
if self._did_start_code: if self._did_start_code:
@ -169,9 +197,12 @@ class Boss(object):
self._C.set_code(code) self._C.set_code(code)
@m.input() @m.input()
def send(self, plaintext): pass def send(self, plaintext):
pass
@m.input() @m.input()
def close(self): pass def close(self):
pass
# from RendezvousConnector: # from RendezvousConnector:
# * "rx_welcome" is the Welcome message, which might signal an error, or # * "rx_welcome" is the Welcome message, which might signal an error, or
@ -193,23 +224,33 @@ class Boss(object):
self._W.got_welcome(welcome) # TODO: let this raise WelcomeError? self._W.got_welcome(welcome) # TODO: let this raise WelcomeError?
except WelcomeError as welcome_error: except WelcomeError as welcome_error:
self.rx_unwelcome(welcome_error) self.rx_unwelcome(welcome_error)
@m.input() @m.input()
def rx_unwelcome(self, welcome_error): pass def rx_unwelcome(self, welcome_error):
pass
@m.input() @m.input()
def rx_error(self, errmsg, orig): pass def rx_error(self, errmsg, orig):
pass
@m.input() @m.input()
def error(self, err): pass def error(self, err):
pass
# from Code (provoked by input/allocate/set_code) # from Code (provoked by input/allocate/set_code)
@m.input() @m.input()
def got_code(self, code): pass def got_code(self, code):
pass
# Key sends (got_key, scared) # Key sends (got_key, scared)
# Receive sends (got_message, happy, got_verifier, scared) # Receive sends (got_message, happy, got_verifier, scared)
@m.input() @m.input()
def happy(self): pass def happy(self):
pass
@m.input() @m.input()
def scared(self): pass def scared(self):
pass
def got_message(self, phase, plaintext): def got_message(self, phase, plaintext):
assert isinstance(phase, type("")), type(phase) assert isinstance(phase, type("")), type(phase)
@ -222,22 +263,32 @@ class Boss(object):
# Ignore unrecognized phases, for forwards-compatibility. Use # Ignore unrecognized phases, for forwards-compatibility. Use
# log.err so tests will catch surprises. # log.err so tests will catch surprises.
log.err(_UnknownPhaseError("received unknown phase '%s'" % phase)) log.err(_UnknownPhaseError("received unknown phase '%s'" % phase))
@m.input() @m.input()
def _got_version(self, plaintext): pass def _got_version(self, plaintext):
pass
@m.input() @m.input()
def _got_phase(self, phase, plaintext): pass def _got_phase(self, phase, plaintext):
pass
@m.input() @m.input()
def got_key(self, key): pass def got_key(self, key):
pass
@m.input() @m.input()
def got_verifier(self, verifier): pass def got_verifier(self, verifier):
pass
# Terminator sends closed # Terminator sends closed
@m.input() @m.input()
def closed(self): pass def closed(self):
pass
@m.output() @m.output()
def do_got_code(self, code): def do_got_code(self, code):
self._W.got_code(code) self._W.got_code(code)
@m.output() @m.output()
def process_version(self, plaintext): def process_version(self, plaintext):
# most of this is wormhole-to-wormhole, ignored for now # most of this is wormhole-to-wormhole, ignored for now
@ -259,18 +310,22 @@ class Boss(object):
# assert isinstance(err, WelcomeError) # assert isinstance(err, WelcomeError)
self._result = welcome_error self._result = welcome_error
self._T.close("unwelcome") self._T.close("unwelcome")
@m.output() @m.output()
def close_error(self, errmsg, orig): def close_error(self, errmsg, orig):
self._result = ServerError(errmsg) self._result = ServerError(errmsg)
self._T.close("errory") self._T.close("errory")
@m.output() @m.output()
def close_scared(self): def close_scared(self):
self._result = WrongPasswordError() self._result = WrongPasswordError()
self._T.close("scary") self._T.close("scary")
@m.output() @m.output()
def close_lonely(self): def close_lonely(self):
self._result = LonelyError() self._result = LonelyError()
self._T.close("lonely") self._T.close("lonely")
@m.output() @m.output()
def close_happy(self): def close_happy(self):
self._result = "happy" self._result = "happy"
@ -279,9 +334,11 @@ class Boss(object):
@m.output() @m.output()
def W_got_key(self, key): def W_got_key(self, key):
self._W.got_key(key) self._W.got_key(key)
@m.output() @m.output()
def W_got_verifier(self, verifier): def W_got_verifier(self, verifier):
self._W.got_verifier(verifier) self._W.got_verifier(verifier)
@m.output() @m.output()
def W_received(self, phase, plaintext): def W_received(self, phase, plaintext):
assert isinstance(phase, six.integer_types), type(phase) assert isinstance(phase, six.integer_types), type(phase)

View File

@ -7,21 +7,25 @@ from . import _interfaces
from ._nameplate import validate_nameplate from ._nameplate import validate_nameplate
from .errors import KeyFormatError from .errors import KeyFormatError
def validate_code(code): def validate_code(code):
if ' ' in code: if ' ' in code:
raise KeyFormatError("Code '%s' contains spaces." % code) raise KeyFormatError("Code '%s' contains spaces." % code)
nameplate = code.split("-", 2)[0] nameplate = code.split("-", 2)[0]
validate_nameplate(nameplate) # can raise KeyFormatError validate_nameplate(nameplate) # can raise KeyFormatError
def first(outputs): def first(outputs):
return list(outputs)[0] return list(outputs)[0]
@attrs @attrs
@implementer(_interfaces.ICode) @implementer(_interfaces.ICode)
class Code(object): class Code(object):
_timing = attrib(validator=provides(_interfaces.ITiming)) _timing = attrib(validator=provides(_interfaces.ITiming))
m = MethodicalMachine() m = MethodicalMachine()
set_trace = getattr(m, "_setTrace", lambda self, f: None) # pragma: no cover set_trace = getattr(m, "_setTrace",
lambda self, f: None) # pragma: no cover
def wire(self, boss, allocator, nameplate, key, input): def wire(self, boss, allocator, nameplate, key, input):
self._B = _interfaces.IBoss(boss) self._B = _interfaces.IBoss(boss)
@ -31,36 +35,55 @@ class Code(object):
self._I = _interfaces.IInput(input) self._I = _interfaces.IInput(input)
@m.state(initial=True) @m.state(initial=True)
def S0_idle(self): pass # pragma: no cover def S0_idle(self):
pass # pragma: no cover
@m.state() @m.state()
def S1_inputting_nameplate(self): pass # pragma: no cover def S1_inputting_nameplate(self):
pass # pragma: no cover
@m.state() @m.state()
def S2_inputting_words(self): pass # pragma: no cover def S2_inputting_words(self):
pass # pragma: no cover
@m.state() @m.state()
def S3_allocating(self): pass # pragma: no cover def S3_allocating(self):
pass # pragma: no cover
@m.state() @m.state()
def S4_known(self): pass # pragma: no cover def S4_known(self):
pass # pragma: no cover
# from App # from App
@m.input() @m.input()
def allocate_code(self, length, wordlist): pass def allocate_code(self, length, wordlist):
pass
@m.input() @m.input()
def input_code(self): pass def input_code(self):
pass
def set_code(self, code): def set_code(self, code):
validate_code(code) # can raise KeyFormatError validate_code(code) # can raise KeyFormatError
self._set_code(code) self._set_code(code)
@m.input() @m.input()
def _set_code(self, code): pass def _set_code(self, code):
pass
# from Allocator # from Allocator
@m.input() @m.input()
def allocated(self, nameplate, code): pass def allocated(self, nameplate, code):
pass
# from Input # from Input
@m.input() @m.input()
def got_nameplate(self, nameplate): pass def got_nameplate(self, nameplate):
pass
@m.input() @m.input()
def finished_input(self, code): pass def finished_input(self, code):
pass
@m.output() @m.output()
def do_set_code(self, code): def do_set_code(self, code):
@ -72,9 +95,11 @@ class Code(object):
@m.output() @m.output()
def do_start_input(self): def do_start_input(self):
return self._I.start() return self._I.start()
@m.output() @m.output()
def do_middle_input(self, nameplate): def do_middle_input(self, nameplate):
self._N.set_nameplate(nameplate) self._N.set_nameplate(nameplate)
@m.output() @m.output()
def do_finish_input(self, code): def do_finish_input(self, code):
self._B.got_code(code) self._B.got_code(code)
@ -83,6 +108,7 @@ class Code(object):
@m.output() @m.output()
def do_start_allocate(self, length, wordlist): def do_start_allocate(self, length, wordlist):
self._A.allocate(length, wordlist) self._A.allocate(length, wordlist)
@m.output() @m.output()
def do_finish_allocate(self, nameplate, code): def do_finish_allocate(self, nameplate, code):
assert code.startswith(nameplate + "-"), (nameplate, code) assert code.startswith(nameplate + "-"), (nameplate, code)
@ -91,11 +117,15 @@ class Code(object):
self._K.got_code(code) self._K.got_code(code)
S0_idle.upon(_set_code, enter=S4_known, outputs=[do_set_code]) S0_idle.upon(_set_code, enter=S4_known, outputs=[do_set_code])
S0_idle.upon(input_code, enter=S1_inputting_nameplate, S0_idle.upon(
outputs=[do_start_input], collector=first) input_code,
S1_inputting_nameplate.upon(got_nameplate, enter=S2_inputting_words, enter=S1_inputting_nameplate,
outputs=[do_middle_input]) outputs=[do_start_input],
S2_inputting_words.upon(finished_input, enter=S4_known, collector=first)
outputs=[do_finish_input]) S1_inputting_nameplate.upon(
S0_idle.upon(allocate_code, enter=S3_allocating, outputs=[do_start_allocate]) got_nameplate, enter=S2_inputting_words, outputs=[do_middle_input])
S2_inputting_words.upon(
finished_input, enter=S4_known, outputs=[do_finish_input])
S0_idle.upon(
allocate_code, enter=S3_allocating, outputs=[do_start_allocate])
S3_allocating.upon(allocated, enter=S4_known, outputs=[do_finish_allocate]) S3_allocating.upon(allocated, enter=S4_known, outputs=[do_finish_allocate])

View File

@ -1,25 +1,31 @@
from __future__ import print_function, absolute_import, unicode_literals from __future__ import absolute_import, print_function, unicode_literals
# We use 'threading' defensively here, to detect if we're being called from a # We use 'threading' defensively here, to detect if we're being called from a
# non-main thread. _rlcompleter.py is the only internal Wormhole code that # non-main thread. _rlcompleter.py is the only internal Wormhole code that
# deliberately creates a new thread. # deliberately creates a new thread.
import threading import threading
from zope.interface import implementer
from attr import attrs, attrib from attr import attrib, attrs
from attr.validators import provides from attr.validators import provides
from twisted.internet import defer
from automat import MethodicalMachine from automat import MethodicalMachine
from twisted.internet import defer
from zope.interface import implementer
from . import _interfaces, errors from . import _interfaces, errors
from ._nameplate import validate_nameplate from ._nameplate import validate_nameplate
def first(outputs): def first(outputs):
return list(outputs)[0] return list(outputs)[0]
@attrs @attrs
@implementer(_interfaces.IInput) @implementer(_interfaces.IInput)
class Input(object): class Input(object):
_timing = attrib(validator=provides(_interfaces.ITiming)) _timing = attrib(validator=provides(_interfaces.ITiming))
m = MethodicalMachine() m = MethodicalMachine()
set_trace = getattr(m, "_setTrace", lambda self, f: None) # pragma: no cover set_trace = getattr(m, "_setTrace",
lambda self, f: None) # pragma: no cover
def __attrs_post_init__(self): def __attrs_post_init__(self):
self._all_nameplates = set() self._all_nameplates = set()
@ -30,6 +36,7 @@ class Input(object):
def set_debug(self, f): def set_debug(self, f):
self._trace = f self._trace = f
def _debug(self, what): # pragma: no cover def _debug(self, what): # pragma: no cover
if self._trace: if self._trace:
self._trace(old_state="", input=what, new_state="") self._trace(old_state="", input=what, new_state="")
@ -46,55 +53,80 @@ class Input(object):
return d return d
@m.state(initial=True) @m.state(initial=True)
def S0_idle(self): pass # pragma: no cover def S0_idle(self):
pass # pragma: no cover
@m.state() @m.state()
def S1_typing_nameplate(self): pass # pragma: no cover def S1_typing_nameplate(self):
pass # pragma: no cover
@m.state() @m.state()
def S2_typing_code_no_wordlist(self): pass # pragma: no cover def S2_typing_code_no_wordlist(self):
pass # pragma: no cover
@m.state() @m.state()
def S3_typing_code_yes_wordlist(self): pass # pragma: no cover def S3_typing_code_yes_wordlist(self):
pass # pragma: no cover
@m.state(terminal=True) @m.state(terminal=True)
def S4_done(self): pass # pragma: no cover def S4_done(self):
pass # pragma: no cover
# from Code # from Code
@m.input() @m.input()
def start(self): pass def start(self):
pass
# from Lister # from Lister
@m.input() @m.input()
def got_nameplates(self, all_nameplates): pass def got_nameplates(self, all_nameplates):
pass
# from Nameplate # from Nameplate
@m.input() @m.input()
def got_wordlist(self, wordlist): pass def got_wordlist(self, wordlist):
pass
# API provided to app as ICodeInputHelper # API provided to app as ICodeInputHelper
@m.input() @m.input()
def refresh_nameplates(self): pass def refresh_nameplates(self):
pass
@m.input() @m.input()
def get_nameplate_completions(self, prefix): pass def get_nameplate_completions(self, prefix):
pass
def choose_nameplate(self, nameplate): def choose_nameplate(self, nameplate):
validate_nameplate(nameplate) # can raise KeyFormatError validate_nameplate(nameplate) # can raise KeyFormatError
self._choose_nameplate(nameplate) self._choose_nameplate(nameplate)
@m.input() @m.input()
def _choose_nameplate(self, nameplate): pass def _choose_nameplate(self, nameplate):
pass
@m.input() @m.input()
def get_word_completions(self, prefix): pass def get_word_completions(self, prefix):
pass
@m.input() @m.input()
def choose_words(self, words): pass def choose_words(self, words):
pass
@m.output() @m.output()
def do_start(self): def do_start(self):
self._start_timing = self._timing.add("input code", waiting="user") self._start_timing = self._timing.add("input code", waiting="user")
self._L.refresh() self._L.refresh()
return Helper(self) return Helper(self)
@m.output() @m.output()
def do_refresh(self): def do_refresh(self):
self._L.refresh() self._L.refresh()
@m.output() @m.output()
def record_nameplates(self, all_nameplates): def record_nameplates(self, all_nameplates):
# we get a set of nameplate id strings # we get a set of nameplate id strings
self._all_nameplates = all_nameplates self._all_nameplates = all_nameplates
@m.output() @m.output()
def _get_nameplate_completions(self, prefix): def _get_nameplate_completions(self, prefix):
completions = set() completions = set()
@ -104,15 +136,18 @@ class Input(object):
# hyphen on nameplates, but WordList owns it for words # hyphen on nameplates, but WordList owns it for words
completions.add(nameplate + "-") completions.add(nameplate + "-")
return completions return completions
@m.output() @m.output()
def record_all_nameplates(self, nameplate): def record_all_nameplates(self, nameplate):
self._nameplate = nameplate self._nameplate = nameplate
self._C.got_nameplate(nameplate) self._C.got_nameplate(nameplate)
@m.output() @m.output()
def record_wordlist(self, wordlist): def record_wordlist(self, wordlist):
from ._rlcompleter import debug from ._rlcompleter import debug
debug(" -record_wordlist") debug(" -record_wordlist")
self._wordlist = wordlist self._wordlist = wordlist
@m.output() @m.output()
def notify_wordlist_waiters(self, wordlist): def notify_wordlist_waiters(self, wordlist):
while self._wordlist_waiters: while self._wordlist_waiters:
@ -122,6 +157,7 @@ class Input(object):
@m.output() @m.output()
def no_word_completions(self, prefix): def no_word_completions(self, prefix):
return set() return set()
@m.output() @m.output()
def _get_word_completions(self, prefix): def _get_word_completions(self, prefix):
assert self._wordlist assert self._wordlist
@ -130,21 +166,27 @@ class Input(object):
@m.output() @m.output()
def raise_must_choose_nameplate1(self, prefix): def raise_must_choose_nameplate1(self, prefix):
raise errors.MustChooseNameplateFirstError() raise errors.MustChooseNameplateFirstError()
@m.output() @m.output()
def raise_must_choose_nameplate2(self, words): def raise_must_choose_nameplate2(self, words):
raise errors.MustChooseNameplateFirstError() raise errors.MustChooseNameplateFirstError()
@m.output() @m.output()
def raise_already_chose_nameplate1(self): def raise_already_chose_nameplate1(self):
raise errors.AlreadyChoseNameplateError() raise errors.AlreadyChoseNameplateError()
@m.output() @m.output()
def raise_already_chose_nameplate2(self, prefix): def raise_already_chose_nameplate2(self, prefix):
raise errors.AlreadyChoseNameplateError() raise errors.AlreadyChoseNameplateError()
@m.output() @m.output()
def raise_already_chose_nameplate3(self, nameplate): def raise_already_chose_nameplate3(self, nameplate):
raise errors.AlreadyChoseNameplateError() raise errors.AlreadyChoseNameplateError()
@m.output() @m.output()
def raise_already_chose_words1(self, prefix): def raise_already_chose_words1(self, prefix):
raise errors.AlreadyChoseWordsError() raise errors.AlreadyChoseWordsError()
@m.output() @m.output()
def raise_already_chose_words2(self, words): def raise_already_chose_words2(self, words):
raise errors.AlreadyChoseWordsError() raise errors.AlreadyChoseWordsError()
@ -155,88 +197,110 @@ class Input(object):
self._start_timing.finish() self._start_timing.finish()
self._C.finished_input(code) self._C.finished_input(code)
S0_idle.upon(start, enter=S1_typing_nameplate, S0_idle.upon(
outputs=[do_start], collector=first) start, enter=S1_typing_nameplate, outputs=[do_start], collector=first)
# wormholes that don't use input_code (i.e. they use allocate_code or # wormholes that don't use input_code (i.e. they use allocate_code or
# generate_code) will never start() us, but Nameplate will give us a # generate_code) will never start() us, but Nameplate will give us a
# wordlist anyways (as soon as the nameplate is claimed), so handle it. # wordlist anyways (as soon as the nameplate is claimed), so handle it.
S0_idle.upon(got_wordlist, enter=S0_idle, outputs=[record_wordlist, S0_idle.upon(
notify_wordlist_waiters]) got_wordlist,
S1_typing_nameplate.upon(got_nameplates, enter=S1_typing_nameplate, enter=S0_idle,
outputs=[record_nameplates]) outputs=[record_wordlist, notify_wordlist_waiters])
S1_typing_nameplate.upon(
got_nameplates, enter=S1_typing_nameplate, outputs=[record_nameplates])
# but wormholes that *do* use input_code should not get got_wordlist # but wormholes that *do* use input_code should not get got_wordlist
# until after we tell Code that we got_nameplate, which is the earliest # until after we tell Code that we got_nameplate, which is the earliest
# it can be claimed # it can be claimed
S1_typing_nameplate.upon(refresh_nameplates, enter=S1_typing_nameplate, S1_typing_nameplate.upon(
outputs=[do_refresh]) refresh_nameplates, enter=S1_typing_nameplate, outputs=[do_refresh])
S1_typing_nameplate.upon(get_nameplate_completions, S1_typing_nameplate.upon(
get_nameplate_completions,
enter=S1_typing_nameplate, enter=S1_typing_nameplate,
outputs=[_get_nameplate_completions], outputs=[_get_nameplate_completions],
collector=first) collector=first)
S1_typing_nameplate.upon(_choose_nameplate, enter=S2_typing_code_no_wordlist, S1_typing_nameplate.upon(
_choose_nameplate,
enter=S2_typing_code_no_wordlist,
outputs=[record_all_nameplates]) outputs=[record_all_nameplates])
S1_typing_nameplate.upon(get_word_completions, S1_typing_nameplate.upon(
get_word_completions,
enter=S1_typing_nameplate, enter=S1_typing_nameplate,
outputs=[raise_must_choose_nameplate1]) outputs=[raise_must_choose_nameplate1])
S1_typing_nameplate.upon(choose_words, enter=S1_typing_nameplate, S1_typing_nameplate.upon(
choose_words,
enter=S1_typing_nameplate,
outputs=[raise_must_choose_nameplate2]) outputs=[raise_must_choose_nameplate2])
S2_typing_code_no_wordlist.upon(got_nameplates, S2_typing_code_no_wordlist.upon(
enter=S2_typing_code_no_wordlist, outputs=[]) got_nameplates, enter=S2_typing_code_no_wordlist, outputs=[])
S2_typing_code_no_wordlist.upon(got_wordlist, S2_typing_code_no_wordlist.upon(
got_wordlist,
enter=S3_typing_code_yes_wordlist, enter=S3_typing_code_yes_wordlist,
outputs=[record_wordlist, outputs=[record_wordlist, notify_wordlist_waiters])
notify_wordlist_waiters]) S2_typing_code_no_wordlist.upon(
S2_typing_code_no_wordlist.upon(refresh_nameplates, refresh_nameplates,
enter=S2_typing_code_no_wordlist, enter=S2_typing_code_no_wordlist,
outputs=[raise_already_chose_nameplate1]) outputs=[raise_already_chose_nameplate1])
S2_typing_code_no_wordlist.upon(get_nameplate_completions, S2_typing_code_no_wordlist.upon(
get_nameplate_completions,
enter=S2_typing_code_no_wordlist, enter=S2_typing_code_no_wordlist,
outputs=[raise_already_chose_nameplate2]) outputs=[raise_already_chose_nameplate2])
S2_typing_code_no_wordlist.upon(_choose_nameplate, S2_typing_code_no_wordlist.upon(
_choose_nameplate,
enter=S2_typing_code_no_wordlist, enter=S2_typing_code_no_wordlist,
outputs=[raise_already_chose_nameplate3]) outputs=[raise_already_chose_nameplate3])
S2_typing_code_no_wordlist.upon(get_word_completions, S2_typing_code_no_wordlist.upon(
get_word_completions,
enter=S2_typing_code_no_wordlist, enter=S2_typing_code_no_wordlist,
outputs=[no_word_completions], outputs=[no_word_completions],
collector=first) collector=first)
S2_typing_code_no_wordlist.upon(choose_words, enter=S4_done, S2_typing_code_no_wordlist.upon(
outputs=[do_words]) choose_words, enter=S4_done, outputs=[do_words])
S3_typing_code_yes_wordlist.upon(got_nameplates, S3_typing_code_yes_wordlist.upon(
enter=S3_typing_code_yes_wordlist, got_nameplates, enter=S3_typing_code_yes_wordlist, outputs=[])
outputs=[])
# got_wordlist: should never happen # got_wordlist: should never happen
S3_typing_code_yes_wordlist.upon(refresh_nameplates, S3_typing_code_yes_wordlist.upon(
refresh_nameplates,
enter=S3_typing_code_yes_wordlist, enter=S3_typing_code_yes_wordlist,
outputs=[raise_already_chose_nameplate1]) outputs=[raise_already_chose_nameplate1])
S3_typing_code_yes_wordlist.upon(get_nameplate_completions, S3_typing_code_yes_wordlist.upon(
get_nameplate_completions,
enter=S3_typing_code_yes_wordlist, enter=S3_typing_code_yes_wordlist,
outputs=[raise_already_chose_nameplate2]) outputs=[raise_already_chose_nameplate2])
S3_typing_code_yes_wordlist.upon(_choose_nameplate, S3_typing_code_yes_wordlist.upon(
_choose_nameplate,
enter=S3_typing_code_yes_wordlist, enter=S3_typing_code_yes_wordlist,
outputs=[raise_already_chose_nameplate3]) outputs=[raise_already_chose_nameplate3])
S3_typing_code_yes_wordlist.upon(get_word_completions, S3_typing_code_yes_wordlist.upon(
get_word_completions,
enter=S3_typing_code_yes_wordlist, enter=S3_typing_code_yes_wordlist,
outputs=[_get_word_completions], outputs=[_get_word_completions],
collector=first) collector=first)
S3_typing_code_yes_wordlist.upon(choose_words, enter=S4_done, S3_typing_code_yes_wordlist.upon(
outputs=[do_words]) choose_words, enter=S4_done, outputs=[do_words])
S4_done.upon(got_nameplates, enter=S4_done, outputs=[]) S4_done.upon(got_nameplates, enter=S4_done, outputs=[])
S4_done.upon(got_wordlist, enter=S4_done, outputs=[]) S4_done.upon(got_wordlist, enter=S4_done, outputs=[])
S4_done.upon(refresh_nameplates, S4_done.upon(
refresh_nameplates,
enter=S4_done, enter=S4_done,
outputs=[raise_already_chose_nameplate1]) outputs=[raise_already_chose_nameplate1])
S4_done.upon(get_nameplate_completions, S4_done.upon(
get_nameplate_completions,
enter=S4_done, enter=S4_done,
outputs=[raise_already_chose_nameplate2]) outputs=[raise_already_chose_nameplate2])
S4_done.upon(_choose_nameplate, enter=S4_done, S4_done.upon(
_choose_nameplate,
enter=S4_done,
outputs=[raise_already_chose_nameplate3]) outputs=[raise_already_chose_nameplate3])
S4_done.upon(get_word_completions, enter=S4_done, S4_done.upon(
get_word_completions,
enter=S4_done,
outputs=[raise_already_chose_words1]) outputs=[raise_already_chose_words1])
S4_done.upon(choose_words, enter=S4_done, S4_done.upon(
outputs=[raise_already_chose_words2]) choose_words, enter=S4_done, outputs=[raise_already_chose_words2])
# we only expose the Helper to application code, not _Input # we only expose the Helper to application code, not _Input
@attrs @attrs
@ -250,20 +314,25 @@ class Helper(object):
def refresh_nameplates(self): def refresh_nameplates(self):
assert threading.current_thread().ident == self._main_thread assert threading.current_thread().ident == self._main_thread
self._input.refresh_nameplates() self._input.refresh_nameplates()
def get_nameplate_completions(self, prefix): def get_nameplate_completions(self, prefix):
assert threading.current_thread().ident == self._main_thread assert threading.current_thread().ident == self._main_thread
return self._input.get_nameplate_completions(prefix) return self._input.get_nameplate_completions(prefix)
def choose_nameplate(self, nameplate): def choose_nameplate(self, nameplate):
assert threading.current_thread().ident == self._main_thread assert threading.current_thread().ident == self._main_thread
self._input._debug("I.choose_nameplate") self._input._debug("I.choose_nameplate")
self._input.choose_nameplate(nameplate) self._input.choose_nameplate(nameplate)
self._input._debug("I.choose_nameplate finished") self._input._debug("I.choose_nameplate finished")
def when_wordlist_is_available(self): def when_wordlist_is_available(self):
assert threading.current_thread().ident == self._main_thread assert threading.current_thread().ident == self._main_thread
return self._input.when_wordlist_is_available() return self._input.when_wordlist_is_available()
def get_word_completions(self, prefix): def get_word_completions(self, prefix):
assert threading.current_thread().ident == self._main_thread assert threading.current_thread().ident == self._main_thread
return self._input.get_word_completions(prefix) return self._input.get_word_completions(prefix)
def choose_words(self, words): def choose_words(self, words):
assert threading.current_thread().ident == self._main_thread assert threading.current_thread().ident == self._main_thread
self._input._debug("I.choose_words") self._input._debug("I.choose_words")

View File

@ -3,64 +3,105 @@ from zope.interface import Interface
# These interfaces are private: we use them as markers to detect # These interfaces are private: we use them as markers to detect
# swapped argument bugs in the various .wire() calls # swapped argument bugs in the various .wire() calls
class IWormhole(Interface): class IWormhole(Interface):
"""Internal: this contains the methods invoked 'from below'.""" """Internal: this contains the methods invoked 'from below'."""
def got_welcome(welcome): def got_welcome(welcome):
pass pass
def got_code(code): def got_code(code):
pass pass
def got_key(key): def got_key(key):
pass pass
def got_verifier(verifier): def got_verifier(verifier):
pass pass
def got_versions(versions): def got_versions(versions):
pass pass
def received(plaintext): def received(plaintext):
pass pass
def closed(result): def closed(result):
pass pass
class IBoss(Interface): class IBoss(Interface):
pass pass
class INameplate(Interface): class INameplate(Interface):
pass pass
class IMailbox(Interface): class IMailbox(Interface):
pass pass
class ISend(Interface): class ISend(Interface):
pass pass
class IOrder(Interface): class IOrder(Interface):
pass pass
class IKey(Interface): class IKey(Interface):
pass pass
class IReceive(Interface): class IReceive(Interface):
pass pass
class IRendezvousConnector(Interface): class IRendezvousConnector(Interface):
pass pass
class ILister(Interface): class ILister(Interface):
pass pass
class ICode(Interface): class ICode(Interface):
pass pass
class IInput(Interface): class IInput(Interface):
pass pass
class IAllocator(Interface): class IAllocator(Interface):
pass pass
class ITerminator(Interface): class ITerminator(Interface):
pass pass
class ITiming(Interface): class ITiming(Interface):
pass pass
class ITorManager(Interface): class ITorManager(Interface):
pass pass
class IWordlist(Interface): class IWordlist(Interface):
def choose_words(length): def choose_words(length):
"""Randomly select LENGTH words, join them with hyphens, return the """Randomly select LENGTH words, join them with hyphens, return the
result.""" result."""
def get_completions(prefix): def get_completions(prefix):
"""Return a list of all suffixes that could complete the given """Return a list of all suffixes that could complete the given
prefix.""" prefix."""
# These interfaces are public, and are re-exported by __init__.py # These interfaces are public, and are re-exported by __init__.py
class IDeferredWormhole(Interface): class IDeferredWormhole(Interface):
def get_welcome(): def get_welcome():
""" """
@ -277,6 +318,7 @@ class IDeferredWormhole(Interface):
:rtype: ``Deferred`` :rtype: ``Deferred``
""" """
class IInputHelper(Interface): class IInputHelper(Interface):
def refresh_nameplates(): def refresh_nameplates():
""" """

View File

@ -1,41 +1,50 @@
from __future__ import print_function, absolute_import, unicode_literals from __future__ import absolute_import, print_function, unicode_literals
from hashlib import sha256 from hashlib import sha256
import six import six
from zope.interface import implementer from attr import attrib, attrs
from attr import attrs, attrib from attr.validators import instance_of, provides
from attr.validators import provides, instance_of
from spake2 import SPAKE2_Symmetric
from hkdf import Hkdf
from nacl.secret import SecretBox
from nacl.exceptions import CryptoError
from nacl import utils
from automat import MethodicalMachine from automat import MethodicalMachine
from .util import (to_bytes, bytes_to_hexstr, hexstr_to_bytes, from hkdf import Hkdf
bytes_to_dict, dict_to_bytes) from nacl import utils
from nacl.exceptions import CryptoError
from nacl.secret import SecretBox
from spake2 import SPAKE2_Symmetric
from zope.interface import implementer
from . import _interfaces from . import _interfaces
from .util import (bytes_to_dict, bytes_to_hexstr, dict_to_bytes,
hexstr_to_bytes, to_bytes)
CryptoError CryptoError
__all__ = ["derive_key", "derive_phase_key", "CryptoError", __all__ = ["derive_key", "derive_phase_key", "CryptoError", "Key"]
"Key"]
def HKDF(skm, outlen, salt=None, CTXinfo=b""): def HKDF(skm, outlen, salt=None, CTXinfo=b""):
return Hkdf(salt, skm).expand(CTXinfo, outlen) return Hkdf(salt, skm).expand(CTXinfo, outlen)
def derive_key(key, purpose, length=SecretBox.KEY_SIZE): def derive_key(key, purpose, length=SecretBox.KEY_SIZE):
if not isinstance(key, type(b"")): raise TypeError(type(key)) if not isinstance(key, type(b"")):
if not isinstance(purpose, type(b"")): raise TypeError(type(purpose)) raise TypeError(type(key))
if not isinstance(length, six.integer_types): raise TypeError(type(length)) if not isinstance(purpose, type(b"")):
raise TypeError(type(purpose))
if not isinstance(length, six.integer_types):
raise TypeError(type(length))
return HKDF(key, length, CTXinfo=purpose) return HKDF(key, length, CTXinfo=purpose)
def derive_phase_key(key, side, phase): def derive_phase_key(key, side, phase):
assert isinstance(side, type("")), type(side) assert isinstance(side, type("")), type(side)
assert isinstance(phase, type("")), type(phase) assert isinstance(phase, type("")), type(phase)
side_bytes = side.encode("ascii") side_bytes = side.encode("ascii")
phase_bytes = phase.encode("ascii") phase_bytes = phase.encode("ascii")
purpose = (b"wormhole:phase:" purpose = (b"wormhole:phase:" + sha256(side_bytes).digest() +
+ sha256(side_bytes).digest() sha256(phase_bytes).digest())
+ sha256(phase_bytes).digest())
return derive_key(key, purpose) return derive_key(key, purpose)
def decrypt_data(key, encrypted): def decrypt_data(key, encrypted):
assert isinstance(key, type(b"")), type(key) assert isinstance(key, type(b"")), type(key)
assert isinstance(encrypted, type(b"")), type(encrypted) assert isinstance(encrypted, type(b"")), type(encrypted)
@ -44,6 +53,7 @@ def decrypt_data(key, encrypted):
data = box.decrypt(encrypted) data = box.decrypt(encrypted)
return data return data
def encrypt_data(key, plaintext): def encrypt_data(key, plaintext):
assert isinstance(key, type(b"")), type(key) assert isinstance(key, type(b"")), type(key)
assert isinstance(plaintext, type(b"")), type(plaintext) assert isinstance(plaintext, type(b"")), type(plaintext)
@ -52,10 +62,12 @@ def encrypt_data(key, plaintext):
nonce = utils.random(SecretBox.NONCE_SIZE) nonce = utils.random(SecretBox.NONCE_SIZE)
return box.encrypt(plaintext, nonce) return box.encrypt(plaintext, nonce)
# the Key we expose to callers (Boss, Ordering) is responsible for sorting # the Key we expose to callers (Boss, Ordering) is responsible for sorting
# the two messages (got_code and got_pake), then delivering them to # the two messages (got_code and got_pake), then delivering them to
# _SortedKey in the right order. # _SortedKey in the right order.
@attrs @attrs
@implementer(_interfaces.IKey) @implementer(_interfaces.IKey)
class Key(object): class Key(object):
@ -64,7 +76,8 @@ class Key(object):
_side = attrib(validator=instance_of(type(u""))) _side = attrib(validator=instance_of(type(u"")))
_timing = attrib(validator=provides(_interfaces.ITiming)) _timing = attrib(validator=provides(_interfaces.ITiming))
m = MethodicalMachine() m = MethodicalMachine()
set_trace = getattr(m, "_setTrace", lambda self, f: None) # pragma: no cover set_trace = getattr(m, "_setTrace",
lambda self, f: None) # pragma: no cover
def __attrs_post_init__(self): def __attrs_post_init__(self):
self._SK = _SortedKey(self._appid, self._versions, self._side, self._SK = _SortedKey(self._appid, self._versions, self._side,
@ -75,29 +88,42 @@ class Key(object):
self._SK.wire(boss, mailbox, receive) self._SK.wire(boss, mailbox, receive)
@m.state(initial=True) @m.state(initial=True)
def S00(self): pass # pragma: no cover def S00(self):
pass # pragma: no cover
@m.state() @m.state()
def S01(self): pass # pragma: no cover def S01(self):
pass # pragma: no cover
@m.state() @m.state()
def S10(self): pass # pragma: no cover def S10(self):
pass # pragma: no cover
@m.state() @m.state()
def S11(self): pass # pragma: no cover def S11(self):
pass # pragma: no cover
@m.input() @m.input()
def got_code(self, code): pass def got_code(self, code):
pass
@m.input() @m.input()
def got_pake(self, body): pass def got_pake(self, body):
pass
@m.output() @m.output()
def stash_pake(self, body): def stash_pake(self, body):
self._pake = body self._pake = body
self._debug_pake_stashed = True self._debug_pake_stashed = True
@m.output() @m.output()
def deliver_code(self, code): def deliver_code(self, code):
self._SK.got_code(code) self._SK.got_code(code)
@m.output() @m.output()
def deliver_pake(self, body): def deliver_pake(self, body):
self._SK.got_pake(body) self._SK.got_pake(body)
@m.output() @m.output()
def deliver_code_and_stashed_pake(self, code): def deliver_code_and_stashed_pake(self, code):
self._SK.got_code(code) self._SK.got_code(code)
@ -108,6 +134,7 @@ class Key(object):
S00.upon(got_pake, enter=S01, outputs=[stash_pake]) S00.upon(got_pake, enter=S01, outputs=[stash_pake])
S01.upon(got_code, enter=S11, outputs=[deliver_code_and_stashed_pake]) S01.upon(got_code, enter=S11, outputs=[deliver_code_and_stashed_pake])
@attrs @attrs
class _SortedKey(object): class _SortedKey(object):
_appid = attrib(validator=instance_of(type(u""))) _appid = attrib(validator=instance_of(type(u"")))
@ -115,7 +142,8 @@ class _SortedKey(object):
_side = attrib(validator=instance_of(type(u""))) _side = attrib(validator=instance_of(type(u"")))
_timing = attrib(validator=provides(_interfaces.ITiming)) _timing = attrib(validator=provides(_interfaces.ITiming))
m = MethodicalMachine() m = MethodicalMachine()
set_trace = getattr(m, "_setTrace", lambda self, f: None) # pragma: no cover set_trace = getattr(m, "_setTrace",
lambda self, f: None) # pragma: no cover
def wire(self, boss, mailbox, receive): def wire(self, boss, mailbox, receive):
self._B = _interfaces.IBoss(boss) self._B = _interfaces.IBoss(boss)
@ -123,17 +151,25 @@ class _SortedKey(object):
self._R = _interfaces.IReceive(receive) self._R = _interfaces.IReceive(receive)
@m.state(initial=True) @m.state(initial=True)
def S0_know_nothing(self): pass # pragma: no cover def S0_know_nothing(self):
pass # pragma: no cover
@m.state() @m.state()
def S1_know_code(self): pass # pragma: no cover def S1_know_code(self):
pass # pragma: no cover
@m.state() @m.state()
def S2_know_key(self): pass # pragma: no cover def S2_know_key(self):
pass # pragma: no cover
@m.state(terminal=True) @m.state(terminal=True)
def S3_scared(self): pass # pragma: no cover def S3_scared(self):
pass # pragma: no cover
# from Boss # from Boss
@m.input() @m.input()
def got_code(self, code): pass def got_code(self, code):
pass
# from Ordering # from Ordering
def got_pake(self, body): def got_pake(self, body):
@ -143,16 +179,20 @@ class _SortedKey(object):
self.got_pake_good(hexstr_to_bytes(payload["pake_v1"])) self.got_pake_good(hexstr_to_bytes(payload["pake_v1"]))
else: else:
self.got_pake_bad() self.got_pake_bad()
@m.input() @m.input()
def got_pake_good(self, msg2): pass def got_pake_good(self, msg2):
pass
@m.input() @m.input()
def got_pake_bad(self): pass def got_pake_bad(self):
pass
@m.output() @m.output()
def build_pake(self, code): def build_pake(self, code):
with self._timing.add("pake1", waiting="crypto"): with self._timing.add("pake1", waiting="crypto"):
self._sp = SPAKE2_Symmetric(to_bytes(code), self._sp = SPAKE2_Symmetric(
idSymmetric=to_bytes(self._appid)) to_bytes(code), idSymmetric=to_bytes(self._appid))
msg1 = self._sp.start() msg1 = self._sp.start()
body = dict_to_bytes({"pake_v1": bytes_to_hexstr(msg1)}) body = dict_to_bytes({"pake_v1": bytes_to_hexstr(msg1)})
self._M.add_message("pake", body) self._M.add_message("pake", body)
@ -160,6 +200,7 @@ class _SortedKey(object):
@m.output() @m.output()
def scared(self): def scared(self):
self._B.scared() self._B.scared()
@m.output() @m.output()
def compute_key(self, msg2): def compute_key(self, msg2):
assert isinstance(msg2, type(b"")) assert isinstance(msg2, type(b""))

View File

@ -1,16 +1,20 @@
from __future__ import print_function, absolute_import, unicode_literals from __future__ import absolute_import, print_function, unicode_literals
from zope.interface import implementer
from attr import attrs, attrib from attr import attrib, attrs
from attr.validators import provides from attr.validators import provides
from automat import MethodicalMachine from automat import MethodicalMachine
from zope.interface import implementer
from . import _interfaces from . import _interfaces
@attrs @attrs
@implementer(_interfaces.ILister) @implementer(_interfaces.ILister)
class Lister(object): class Lister(object):
_timing = attrib(validator=provides(_interfaces.ITiming)) _timing = attrib(validator=provides(_interfaces.ITiming))
m = MethodicalMachine() m = MethodicalMachine()
set_trace = getattr(m, "_setTrace", lambda self, f: None) # pragma: no cover set_trace = getattr(m, "_setTrace",
lambda self, f: None) # pragma: no cover
def wire(self, rendezvous_connector, input): def wire(self, rendezvous_connector, input):
self._RC = _interfaces.IRendezvousConnector(rendezvous_connector) self._RC = _interfaces.IRendezvousConnector(rendezvous_connector)
@ -26,26 +30,41 @@ class Lister(object):
# request arrives, both requests will be satisfied by the same response. # request arrives, both requests will be satisfied by the same response.
@m.state(initial=True) @m.state(initial=True)
def S0A_idle_disconnected(self): pass # pragma: no cover def S0A_idle_disconnected(self):
pass # pragma: no cover
@m.state() @m.state()
def S1A_wanting_disconnected(self): pass # pragma: no cover def S1A_wanting_disconnected(self):
pass # pragma: no cover
@m.state() @m.state()
def S0B_idle_connected(self): pass # pragma: no cover def S0B_idle_connected(self):
pass # pragma: no cover
@m.state() @m.state()
def S1B_wanting_connected(self): pass # pragma: no cover def S1B_wanting_connected(self):
pass # pragma: no cover
@m.input() @m.input()
def connected(self): pass def connected(self):
pass
@m.input() @m.input()
def lost(self): pass def lost(self):
pass
@m.input() @m.input()
def refresh(self): pass def refresh(self):
pass
@m.input() @m.input()
def rx_nameplates(self, all_nameplates): pass def rx_nameplates(self, all_nameplates):
pass
@m.output() @m.output()
def RC_tx_list(self): def RC_tx_list(self):
self._RC.tx_list() self._RC.tx_list()
@m.output() @m.output()
def I_got_nameplates(self, all_nameplates): def I_got_nameplates(self, all_nameplates):
# We get a set of nameplate ids. There may be more attributes in the # We get a set of nameplate ids. There may be more attributes in the
@ -56,18 +75,19 @@ class Lister(object):
S0A_idle_disconnected.upon(connected, enter=S0B_idle_connected, outputs=[]) S0A_idle_disconnected.upon(connected, enter=S0B_idle_connected, outputs=[])
S0B_idle_connected.upon(lost, enter=S0A_idle_disconnected, outputs=[]) S0B_idle_connected.upon(lost, enter=S0A_idle_disconnected, outputs=[])
S0A_idle_disconnected.upon(refresh, S0A_idle_disconnected.upon(
enter=S1A_wanting_disconnected, outputs=[]) refresh, enter=S1A_wanting_disconnected, outputs=[])
S1A_wanting_disconnected.upon(refresh, S1A_wanting_disconnected.upon(
enter=S1A_wanting_disconnected, outputs=[]) refresh, enter=S1A_wanting_disconnected, outputs=[])
S1A_wanting_disconnected.upon(connected, enter=S1B_wanting_connected, S1A_wanting_disconnected.upon(
outputs=[RC_tx_list]) connected, enter=S1B_wanting_connected, outputs=[RC_tx_list])
S0B_idle_connected.upon(refresh, enter=S1B_wanting_connected, S0B_idle_connected.upon(
outputs=[RC_tx_list]) refresh, enter=S1B_wanting_connected, outputs=[RC_tx_list])
S0B_idle_connected.upon(rx_nameplates, enter=S0B_idle_connected, S0B_idle_connected.upon(
outputs=[I_got_nameplates]) rx_nameplates, enter=S0B_idle_connected, outputs=[I_got_nameplates])
S1B_wanting_connected.upon(lost, enter=S1A_wanting_disconnected, outputs=[]) S1B_wanting_connected.upon(
S1B_wanting_connected.upon(refresh, enter=S1B_wanting_connected, lost, enter=S1A_wanting_disconnected, outputs=[])
outputs=[RC_tx_list]) S1B_wanting_connected.upon(
S1B_wanting_connected.upon(rx_nameplates, enter=S0B_idle_connected, refresh, enter=S1B_wanting_connected, outputs=[RC_tx_list])
outputs=[I_got_nameplates]) S1B_wanting_connected.upon(
rx_nameplates, enter=S0B_idle_connected, outputs=[I_got_nameplates])

View File

@ -1,16 +1,20 @@
from __future__ import print_function, absolute_import, unicode_literals from __future__ import absolute_import, print_function, unicode_literals
from zope.interface import implementer
from attr import attrs, attrib from attr import attrib, attrs
from attr.validators import instance_of from attr.validators import instance_of
from automat import MethodicalMachine from automat import MethodicalMachine
from zope.interface import implementer
from . import _interfaces from . import _interfaces
@attrs @attrs
@implementer(_interfaces.IMailbox) @implementer(_interfaces.IMailbox)
class Mailbox(object): class Mailbox(object):
_side = attrib(validator=instance_of(type(u""))) _side = attrib(validator=instance_of(type(u"")))
m = MethodicalMachine() m = MethodicalMachine()
set_trace = getattr(m, "_setTrace", lambda self, f: None) # pragma: no cover set_trace = getattr(m, "_setTrace",
lambda self, f: None) # pragma: no cover
def __attrs_post_init__(self): def __attrs_post_init__(self):
self._mailbox = None self._mailbox = None
@ -29,27 +33,37 @@ class Mailbox(object):
# S0: know nothing # S0: know nothing
@m.state(initial=True) @m.state(initial=True)
def S0A(self): pass # pragma: no cover def S0A(self):
pass # pragma: no cover
@m.state() @m.state()
def S0B(self): pass # pragma: no cover def S0B(self):
pass # pragma: no cover
# S1: mailbox known, not opened # S1: mailbox known, not opened
@m.state() @m.state()
def S1A(self): pass # pragma: no cover def S1A(self):
pass # pragma: no cover
# S2: mailbox known, opened # S2: mailbox known, opened
# We've definitely tried to open the mailbox at least once, but it must # We've definitely tried to open the mailbox at least once, but it must
# be re-opened with each connection, because open() is also subscribe() # be re-opened with each connection, because open() is also subscribe()
@m.state() @m.state()
def S2A(self): pass # pragma: no cover def S2A(self):
pass # pragma: no cover
@m.state() @m.state()
def S2B(self): pass # pragma: no cover def S2B(self):
pass # pragma: no cover
# S3: closing # S3: closing
@m.state() @m.state()
def S3A(self): pass # pragma: no cover def S3A(self):
pass # pragma: no cover
@m.state() @m.state()
def S3B(self): pass # pragma: no cover def S3B(self):
pass # pragma: no cover
# S4: closed. We no longer care whether we're connected or not # S4: closed. We no longer care whether we're connected or not
# @m.state() # @m.state()
@ -57,24 +71,30 @@ class Mailbox(object):
# @m.state() # @m.state()
# def S4B(self): pass # def S4B(self): pass
@m.state(terminal=True) @m.state(terminal=True)
def S4(self): pass # pragma: no cover def S4(self):
pass # pragma: no cover
S4A = S4 S4A = S4
S4B = S4 S4B = S4
# from Terminator # from Terminator
@m.input() @m.input()
def close(self, mood): pass def close(self, mood):
pass
# from Nameplate # from Nameplate
@m.input() @m.input()
def got_mailbox(self, mailbox): pass def got_mailbox(self, mailbox):
pass
# from RendezvousConnector # from RendezvousConnector
@m.input() @m.input()
def connected(self): pass def connected(self):
pass
@m.input() @m.input()
def lost(self): pass def lost(self):
pass
def rx_message(self, side, phase, body): def rx_message(self, side, phase, body):
assert isinstance(side, type("")), type(side) assert isinstance(side, type("")), type(side)
@ -84,73 +104,91 @@ class Mailbox(object):
self.rx_message_ours(phase, body) self.rx_message_ours(phase, body)
else: else:
self.rx_message_theirs(side, phase, body) self.rx_message_theirs(side, phase, body)
@m.input() @m.input()
def rx_message_ours(self, phase, body): pass def rx_message_ours(self, phase, body):
pass
@m.input() @m.input()
def rx_message_theirs(self, side, phase, body): pass def rx_message_theirs(self, side, phase, body):
pass
@m.input() @m.input()
def rx_closed(self): pass def rx_closed(self):
pass
# from Send or Key # from Send or Key
@m.input() @m.input()
def add_message(self, phase, body): def add_message(self, phase, body):
pass pass
@m.output() @m.output()
def record_mailbox(self, mailbox): def record_mailbox(self, mailbox):
self._mailbox = mailbox self._mailbox = mailbox
@m.output() @m.output()
def RC_tx_open(self): def RC_tx_open(self):
assert self._mailbox assert self._mailbox
self._RC.tx_open(self._mailbox) self._RC.tx_open(self._mailbox)
@m.output() @m.output()
def queue(self, phase, body): def queue(self, phase, body):
assert isinstance(phase, type("")), type(phase) assert isinstance(phase, type("")), type(phase)
assert isinstance(body, type(b"")), (type(body), phase, body) assert isinstance(body, type(b"")), (type(body), phase, body)
self._pending_outbound[phase] = body self._pending_outbound[phase] = body
@m.output() @m.output()
def record_mailbox_and_RC_tx_open_and_drain(self, mailbox): def record_mailbox_and_RC_tx_open_and_drain(self, mailbox):
self._mailbox = mailbox self._mailbox = mailbox
self._RC.tx_open(mailbox) self._RC.tx_open(mailbox)
self._drain() self._drain()
@m.output() @m.output()
def drain(self): def drain(self):
self._drain() self._drain()
def _drain(self): def _drain(self):
for phase, body in self._pending_outbound.items(): for phase, body in self._pending_outbound.items():
self._RC.tx_add(phase, body) self._RC.tx_add(phase, body)
@m.output() @m.output()
def RC_tx_add(self, phase, body): def RC_tx_add(self, phase, body):
assert isinstance(phase, type("")), type(phase) assert isinstance(phase, type("")), type(phase)
assert isinstance(body, type(b"")), type(body) assert isinstance(body, type(b"")), type(body)
self._RC.tx_add(phase, body) self._RC.tx_add(phase, body)
@m.output() @m.output()
def N_release_and_accept(self, side, phase, body): def N_release_and_accept(self, side, phase, body):
self._N.release() self._N.release()
if phase not in self._processed: if phase not in self._processed:
self._processed.add(phase) self._processed.add(phase)
self._O.got_message(side, phase, body) self._O.got_message(side, phase, body)
@m.output() @m.output()
def RC_tx_close(self): def RC_tx_close(self):
assert self._mood assert self._mood
self._RC_tx_close() self._RC_tx_close()
def _RC_tx_close(self): def _RC_tx_close(self):
self._RC.tx_close(self._mailbox, self._mood) self._RC.tx_close(self._mailbox, self._mood)
@m.output() @m.output()
def dequeue(self, phase, body): def dequeue(self, phase, body):
self._pending_outbound.pop(phase, None) self._pending_outbound.pop(phase, None)
@m.output() @m.output()
def record_mood(self, mood): def record_mood(self, mood):
self._mood = mood self._mood = mood
@m.output() @m.output()
def record_mood_and_RC_tx_close(self, mood): def record_mood_and_RC_tx_close(self, mood):
self._mood = mood self._mood = mood
self._RC_tx_close() self._RC_tx_close()
@m.output() @m.output()
def ignore_mood_and_T_mailbox_done(self, mood): def ignore_mood_and_T_mailbox_done(self, mood):
self._T.mailbox_done() self._T.mailbox_done()
@m.output() @m.output()
def T_mailbox_done(self): def T_mailbox_done(self):
self._T.mailbox_done() self._T.mailbox_done()
@ -162,7 +200,9 @@ class Mailbox(object):
S0B.upon(lost, enter=S0A, outputs=[]) S0B.upon(lost, enter=S0A, outputs=[])
S0B.upon(add_message, enter=S0B, outputs=[queue]) S0B.upon(add_message, enter=S0B, outputs=[queue])
S0B.upon(close, enter=S4B, outputs=[ignore_mood_and_T_mailbox_done]) S0B.upon(close, enter=S4B, outputs=[ignore_mood_and_T_mailbox_done])
S0B.upon(got_mailbox, enter=S2B, S0B.upon(
got_mailbox,
enter=S2B,
outputs=[record_mailbox_and_RC_tx_open_and_drain]) outputs=[record_mailbox_and_RC_tx_open_and_drain])
S1A.upon(connected, enter=S2B, outputs=[RC_tx_open, drain]) S1A.upon(connected, enter=S2B, outputs=[RC_tx_open, drain])
@ -192,4 +232,3 @@ class Mailbox(object):
S4.upon(rx_message_theirs, enter=S4, outputs=[]) S4.upon(rx_message_theirs, enter=S4, outputs=[])
S4.upon(rx_message_ours, enter=S4, outputs=[]) S4.upon(rx_message_ours, enter=S4, outputs=[])
S4.upon(close, enter=S4, outputs=[]) S4.upon(close, enter=S4, outputs=[])

View File

@ -1,7 +1,10 @@
from __future__ import print_function, absolute_import, unicode_literals from __future__ import absolute_import, print_function, unicode_literals
import re import re
from zope.interface import implementer
from automat import MethodicalMachine from automat import MethodicalMachine
from zope.interface import implementer
from . import _interfaces from . import _interfaces
from ._wordlist import PGPWordList from ._wordlist import PGPWordList
from .errors import KeyFormatError from .errors import KeyFormatError
@ -9,13 +12,15 @@ from .errors import KeyFormatError
def validate_nameplate(nameplate): def validate_nameplate(nameplate):
if not re.search(r'^\d+$', nameplate): if not re.search(r'^\d+$', nameplate):
raise KeyFormatError("Nameplate '%s' must be numeric, with no spaces." raise KeyFormatError(
% nameplate) "Nameplate '%s' must be numeric, with no spaces." % nameplate)
@implementer(_interfaces.INameplate) @implementer(_interfaces.INameplate)
class Nameplate(object): class Nameplate(object):
m = MethodicalMachine() m = MethodicalMachine()
set_trace = getattr(m, "_setTrace", lambda self, f: None) # pragma: no cover set_trace = getattr(m, "_setTrace",
lambda self, f: None) # pragma: no cover
def __init__(self): def __init__(self):
self._nameplate = None self._nameplate = None
@ -32,31 +37,44 @@ class Nameplate(object):
# S0: know nothing # S0: know nothing
@m.state(initial=True) @m.state(initial=True)
def S0A(self): pass # pragma: no cover def S0A(self):
pass # pragma: no cover
@m.state() @m.state()
def S0B(self): pass # pragma: no cover def S0B(self):
pass # pragma: no cover
# S1: nameplate known, never claimed # S1: nameplate known, never claimed
@m.state() @m.state()
def S1A(self): pass # pragma: no cover def S1A(self):
pass # pragma: no cover
# S2: nameplate known, maybe claimed # S2: nameplate known, maybe claimed
@m.state() @m.state()
def S2A(self): pass # pragma: no cover def S2A(self):
pass # pragma: no cover
@m.state() @m.state()
def S2B(self): pass # pragma: no cover def S2B(self):
pass # pragma: no cover
# S3: nameplate claimed # S3: nameplate claimed
@m.state() @m.state()
def S3A(self): pass # pragma: no cover def S3A(self):
pass # pragma: no cover
@m.state() @m.state()
def S3B(self): pass # pragma: no cover def S3B(self):
pass # pragma: no cover
# S4: maybe released # S4: maybe released
@m.state() @m.state()
def S4A(self): pass # pragma: no cover def S4A(self):
pass # pragma: no cover
@m.state() @m.state()
def S4B(self): pass # pragma: no cover def S4B(self):
pass # pragma: no cover
# S5: released # S5: released
# we no longer care whether we're connected or not # we no longer care whether we're connected or not
@ -65,7 +83,9 @@ class Nameplate(object):
# @m.state() # @m.state()
# def S5B(self): pass # def S5B(self): pass
@m.state() @m.state()
def S5(self): pass # pragma: no cover def S5(self):
pass # pragma: no cover
S5A = S5 S5A = S5
S5B = S5 S5B = S5
@ -73,53 +93,69 @@ class Nameplate(object):
def set_nameplate(self, nameplate): def set_nameplate(self, nameplate):
validate_nameplate(nameplate) # can raise KeyFormatError validate_nameplate(nameplate) # can raise KeyFormatError
self._set_nameplate(nameplate) self._set_nameplate(nameplate)
@m.input() @m.input()
def _set_nameplate(self, nameplate): pass def _set_nameplate(self, nameplate):
pass
# from Mailbox # from Mailbox
@m.input() @m.input()
def release(self): pass def release(self):
pass
# from Terminator # from Terminator
@m.input() @m.input()
def close(self): pass def close(self):
pass
# from RendezvousConnector # from RendezvousConnector
@m.input() @m.input()
def connected(self): pass def connected(self):
@m.input() pass
def lost(self): pass
@m.input() @m.input()
def rx_claimed(self, mailbox): pass def lost(self):
pass
@m.input() @m.input()
def rx_released(self): pass def rx_claimed(self, mailbox):
pass
@m.input()
def rx_released(self):
pass
@m.output() @m.output()
def record_nameplate(self, nameplate): def record_nameplate(self, nameplate):
validate_nameplate(nameplate) validate_nameplate(nameplate)
self._nameplate = nameplate self._nameplate = nameplate
@m.output() @m.output()
def record_nameplate_and_RC_tx_claim(self, nameplate): def record_nameplate_and_RC_tx_claim(self, nameplate):
validate_nameplate(nameplate) validate_nameplate(nameplate)
self._nameplate = nameplate self._nameplate = nameplate
self._RC.tx_claim(self._nameplate) self._RC.tx_claim(self._nameplate)
@m.output() @m.output()
def RC_tx_claim(self): def RC_tx_claim(self):
# when invoked via M.connected(), we must use the stored nameplate # when invoked via M.connected(), we must use the stored nameplate
self._RC.tx_claim(self._nameplate) self._RC.tx_claim(self._nameplate)
@m.output() @m.output()
def I_got_wordlist(self, mailbox): def I_got_wordlist(self, mailbox):
# TODO select wordlist based on nameplate properties, in rx_claimed # TODO select wordlist based on nameplate properties, in rx_claimed
wordlist = PGPWordList() wordlist = PGPWordList()
self._I.got_wordlist(wordlist) self._I.got_wordlist(wordlist)
@m.output() @m.output()
def M_got_mailbox(self, mailbox): def M_got_mailbox(self, mailbox):
self._M.got_mailbox(mailbox) self._M.got_mailbox(mailbox)
@m.output() @m.output()
def RC_tx_release(self): def RC_tx_release(self):
assert self._nameplate assert self._nameplate
self._RC.tx_release(self._nameplate) self._RC.tx_release(self._nameplate)
@m.output() @m.output()
def T_nameplate_done(self): def T_nameplate_done(self):
self._T.nameplate_done() self._T.nameplate_done()
@ -127,8 +163,8 @@ class Nameplate(object):
S0A.upon(_set_nameplate, enter=S1A, outputs=[record_nameplate]) S0A.upon(_set_nameplate, enter=S1A, outputs=[record_nameplate])
S0A.upon(connected, enter=S0B, outputs=[]) S0A.upon(connected, enter=S0B, outputs=[])
S0A.upon(close, enter=S5A, outputs=[T_nameplate_done]) S0A.upon(close, enter=S5A, outputs=[T_nameplate_done])
S0B.upon(_set_nameplate, enter=S2B, S0B.upon(
outputs=[record_nameplate_and_RC_tx_claim]) _set_nameplate, enter=S2B, outputs=[record_nameplate_and_RC_tx_claim])
S0B.upon(lost, enter=S0A, outputs=[]) S0B.upon(lost, enter=S0A, outputs=[])
S0B.upon(close, enter=S5A, outputs=[T_nameplate_done]) S0B.upon(close, enter=S5A, outputs=[T_nameplate_done])

View File

@ -1,29 +1,37 @@
from __future__ import print_function, absolute_import, unicode_literals from __future__ import absolute_import, print_function, unicode_literals
from zope.interface import implementer
from attr import attrs, attrib from attr import attrib, attrs
from attr.validators import provides, instance_of from attr.validators import instance_of, provides
from automat import MethodicalMachine from automat import MethodicalMachine
from zope.interface import implementer
from . import _interfaces from . import _interfaces
@attrs @attrs
@implementer(_interfaces.IOrder) @implementer(_interfaces.IOrder)
class Order(object): class Order(object):
_side = attrib(validator=instance_of(type(u""))) _side = attrib(validator=instance_of(type(u"")))
_timing = attrib(validator=provides(_interfaces.ITiming)) _timing = attrib(validator=provides(_interfaces.ITiming))
m = MethodicalMachine() m = MethodicalMachine()
set_trace = getattr(m, "_setTrace", lambda self, f: None) # pragma: no cover set_trace = getattr(m, "_setTrace",
lambda self, f: None) # pragma: no cover
def __attrs_post_init__(self): def __attrs_post_init__(self):
self._key = None self._key = None
self._queue = [] self._queue = []
def wire(self, key, receive): def wire(self, key, receive):
self._K = _interfaces.IKey(key) self._K = _interfaces.IKey(key)
self._R = _interfaces.IReceive(receive) self._R = _interfaces.IReceive(receive)
@m.state(initial=True) @m.state(initial=True)
def S0_no_pake(self): pass # pragma: no cover def S0_no_pake(self):
pass # pragma: no cover
@m.state(terminal=True) @m.state(terminal=True)
def S1_yes_pake(self): pass # pragma: no cover def S1_yes_pake(self):
pass # pragma: no cover
def got_message(self, side, phase, body): def got_message(self, side, phase, body):
# print("ORDER[%s].got_message(%s)" % (self._side, phase)) # print("ORDER[%s].got_message(%s)" % (self._side, phase))
@ -36,9 +44,12 @@ class Order(object):
self.got_non_pake(side, phase, body) self.got_non_pake(side, phase, body)
@m.input() @m.input()
def got_pake(self, side, phase, body): pass def got_pake(self, side, phase, body):
pass
@m.input() @m.input()
def got_non_pake(self, side, phase, body): pass def got_non_pake(self, side, phase, body):
pass
@m.output() @m.output()
def queue(self, side, phase, body): def queue(self, side, phase, body):
@ -46,9 +57,11 @@ class Order(object):
assert isinstance(phase, type("")), type(phase) assert isinstance(phase, type("")), type(phase)
assert isinstance(body, type(b"")), type(body) assert isinstance(body, type(b"")), type(body)
self._queue.append((side, phase, body)) self._queue.append((side, phase, body))
@m.output() @m.output()
def notify_key(self, side, phase, body): def notify_key(self, side, phase, body):
self._K.got_pake(body) self._K.got_pake(body)
@m.output() @m.output()
def drain(self, side, phase, body): def drain(self, side, phase, body):
del phase del phase
@ -56,6 +69,7 @@ class Order(object):
for (side, phase, body) in self._queue: for (side, phase, body) in self._queue:
self._deliver(side, phase, body) self._deliver(side, phase, body)
self._queue[:] = [] self._queue[:] = []
@m.output() @m.output()
def deliver(self, side, phase, body): def deliver(self, side, phase, body):
self._deliver(side, phase, body) self._deliver(side, phase, body)

View File

@ -1,10 +1,13 @@
from __future__ import print_function, absolute_import, unicode_literals from __future__ import absolute_import, print_function, unicode_literals
from zope.interface import implementer
from attr import attrs, attrib from attr import attrib, attrs
from attr.validators import provides, instance_of from attr.validators import instance_of, provides
from automat import MethodicalMachine from automat import MethodicalMachine
from zope.interface import implementer
from . import _interfaces from . import _interfaces
from ._key import derive_key, derive_phase_key, decrypt_data, CryptoError from ._key import CryptoError, decrypt_data, derive_key, derive_phase_key
@attrs @attrs
@implementer(_interfaces.IReceive) @implementer(_interfaces.IReceive)
@ -12,7 +15,8 @@ class Receive(object):
_side = attrib(validator=instance_of(type(u""))) _side = attrib(validator=instance_of(type(u"")))
_timing = attrib(validator=provides(_interfaces.ITiming)) _timing = attrib(validator=provides(_interfaces.ITiming))
m = MethodicalMachine() m = MethodicalMachine()
set_trace = getattr(m, "_setTrace", lambda self, f: None) # pragma: no cover set_trace = getattr(m, "_setTrace",
lambda self, f: None) # pragma: no cover
def __attrs_post_init__(self): def __attrs_post_init__(self):
self._key = None self._key = None
@ -22,13 +26,20 @@ class Receive(object):
self._S = _interfaces.ISend(send) self._S = _interfaces.ISend(send)
@m.state(initial=True) @m.state(initial=True)
def S0_unknown_key(self): pass # pragma: no cover def S0_unknown_key(self):
pass # pragma: no cover
@m.state() @m.state()
def S1_unverified_key(self): pass # pragma: no cover def S1_unverified_key(self):
pass # pragma: no cover
@m.state() @m.state()
def S2_verified_key(self): pass # pragma: no cover def S2_verified_key(self):
pass # pragma: no cover
@m.state(terminal=True) @m.state(terminal=True)
def S3_scared(self): pass # pragma: no cover def S3_scared(self):
pass # pragma: no cover
# from Ordering # from Ordering
def got_message(self, side, phase, body): def got_message(self, side, phase, body):
@ -43,47 +54,56 @@ class Receive(object):
self.got_message_bad() self.got_message_bad()
return return
self.got_message_good(phase, plaintext) self.got_message_good(phase, plaintext)
@m.input() @m.input()
def got_message_good(self, phase, plaintext): pass def got_message_good(self, phase, plaintext):
pass
@m.input() @m.input()
def got_message_bad(self): pass def got_message_bad(self):
pass
# from Key # from Key
@m.input() @m.input()
def got_key(self, key): pass def got_key(self, key):
pass
@m.output() @m.output()
def record_key(self, key): def record_key(self, key):
self._key = key self._key = key
@m.output() @m.output()
def S_got_verified_key(self, phase, plaintext): def S_got_verified_key(self, phase, plaintext):
assert self._key assert self._key
self._S.got_verified_key(self._key) self._S.got_verified_key(self._key)
@m.output() @m.output()
def W_happy(self, phase, plaintext): def W_happy(self, phase, plaintext):
self._B.happy() self._B.happy()
@m.output() @m.output()
def W_got_verifier(self, phase, plaintext): def W_got_verifier(self, phase, plaintext):
self._B.got_verifier(derive_key(self._key, b"wormhole:verifier")) self._B.got_verifier(derive_key(self._key, b"wormhole:verifier"))
@m.output() @m.output()
def W_got_message(self, phase, plaintext): def W_got_message(self, phase, plaintext):
assert isinstance(phase, type("")), type(phase) assert isinstance(phase, type("")), type(phase)
assert isinstance(plaintext, type(b"")), type(plaintext) assert isinstance(plaintext, type(b"")), type(plaintext)
self._B.got_message(phase, plaintext) self._B.got_message(phase, plaintext)
@m.output() @m.output()
def W_scared(self): def W_scared(self):
self._B.scared() self._B.scared()
S0_unknown_key.upon(got_key, enter=S1_unverified_key, outputs=[record_key]) S0_unknown_key.upon(got_key, enter=S1_unverified_key, outputs=[record_key])
S1_unverified_key.upon(got_message_good, enter=S2_verified_key, S1_unverified_key.upon(
outputs=[S_got_verified_key, got_message_good,
W_happy, W_got_verifier, W_got_message]) enter=S2_verified_key,
S1_unverified_key.upon(got_message_bad, enter=S3_scared, outputs=[S_got_verified_key, W_happy, W_got_verifier, W_got_message])
outputs=[W_scared]) S1_unverified_key.upon(
S2_verified_key.upon(got_message_bad, enter=S3_scared, got_message_bad, enter=S3_scared, outputs=[W_scared])
outputs=[W_scared]) S2_verified_key.upon(got_message_bad, enter=S3_scared, outputs=[W_scared])
S2_verified_key.upon(got_message_good, enter=S2_verified_key, S2_verified_key.upon(
outputs=[W_got_message]) got_message_good, enter=S2_verified_key, outputs=[W_got_message])
S3_scared.upon(got_message_good, enter=S3_scared, outputs=[]) S3_scared.upon(got_message_good, enter=S3_scared, outputs=[])
S3_scared.upon(got_message_bad, enter=S3_scared, outputs=[]) S3_scared.upon(got_message_bad, enter=S3_scared, outputs=[])

View File

@ -9,8 +9,9 @@ from twisted.internet import defer, endpoints, task
from twisted.application import internet from twisted.application import internet
from autobahn.twisted import websocket from autobahn.twisted import websocket
from . import _interfaces, errors from . import _interfaces, errors
from .util import (bytes_to_hexstr, hexstr_to_bytes, from .util import (bytes_to_hexstr, hexstr_to_bytes, bytes_to_dict,
bytes_to_dict, dict_to_bytes) dict_to_bytes)
class WSClient(websocket.WebSocketClientProtocol): class WSClient(websocket.WebSocketClientProtocol):
def onConnect(self, response): def onConnect(self, response):
@ -29,7 +30,7 @@ class WSClient(websocket.WebSocketClientProtocol):
assert not isBinary assert not isBinary
try: try:
self._RC.ws_message(payload) self._RC.ws_message(payload)
except: except Exception:
from twisted.python.failure import Failure from twisted.python.failure import Failure
print("LOGGING", Failure()) print("LOGGING", Failure())
log.err() log.err()
@ -45,8 +46,10 @@ class WSClient(websocket.WebSocketClientProtocol):
# # finishing WebSocket negotiation (onOpen): errback # # finishing WebSocket negotiation (onOpen): errback
# self.factory.d.errback(error.ConnectError(reason)) # self.factory.d.errback(error.ConnectError(reason))
class WSFactory(websocket.WebSocketClientFactory): class WSFactory(websocket.WebSocketClientFactory):
protocol = WSClient protocol = WSClient
def __init__(self, RC, *args, **kwargs): def __init__(self, RC, *args, **kwargs):
websocket.WebSocketClientFactory.__init__(self, *args, **kwargs) websocket.WebSocketClientFactory.__init__(self, *args, **kwargs)
self._RC = RC self._RC = RC
@ -57,6 +60,7 @@ class WSFactory(websocket.WebSocketClientFactory):
# proto.wormhole_open = False # proto.wormhole_open = False
return proto return proto
@attrs @attrs
@implementer(_interfaces.IRendezvousConnector) @implementer(_interfaces.IRendezvousConnector)
class RendezvousConnector(object): class RendezvousConnector(object):
@ -90,6 +94,7 @@ class RendezvousConnector(object):
def set_trace(self, f): def set_trace(self, f):
self._trace = f self._trace = f
def _debug(self, what): def _debug(self, what):
if self._trace: if self._trace:
self._trace(old_state="", input=what, new_state="") self._trace(old_state="", input=what, new_state="")
@ -140,7 +145,6 @@ class RendezvousConnector(object):
d.addErrback(log.err) d.addErrback(log.err)
d.addBoth(self._stopped) d.addBoth(self._stopped)
# from Lister # from Lister
def tx_list(self): def tx_list(self):
self._tx("list") self._tx("list")
@ -166,7 +170,10 @@ class RendezvousConnector(object):
self._have_made_a_successful_connection = True self._have_made_a_successful_connection = True
self._ws = proto self._ws = proto
try: try:
self._tx("bind", appid=self._appid, side=self._side, self._tx(
"bind",
appid=self._appid,
side=self._side,
client_version=self._client_version) client_version=self._client_version)
self._N.connected() self._N.connected()
self._M.connected() self._M.connected()
@ -180,8 +187,9 @@ class RendezvousConnector(object):
def ws_message(self, payload): def ws_message(self, payload):
msg = bytes_to_dict(payload) msg = bytes_to_dict(payload)
if msg["type"] != "ack": if msg["type"] != "ack":
self._debug("R.rx(%s %s%s)" % self._debug("R.rx(%s %s%s)" % (
(msg["type"], msg.get("phase",""), msg["type"],
msg.get("phase", ""),
"[mine]" if msg.get("side", "") == self._side else "", "[mine]" if msg.get("side", "") == self._side else "",
)) ))
@ -192,7 +200,9 @@ class RendezvousConnector(object):
meth = getattr(self, "_response_handle_" + mtype, None) meth = getattr(self, "_response_handle_" + mtype, None)
if not meth: if not meth:
# make tests fail, but real application will ignore it # make tests fail, but real application will ignore it
log.err(errors._UnknownMessageTypeError("Unknown inbound message type %r" % (msg,))) log.err(
errors._UnknownMessageTypeError(
"Unknown inbound message type %r" % (msg, )))
return return
try: try:
return meth(msg) return meth(msg)
@ -301,5 +311,4 @@ class RendezvousConnector(object):
def _response_handle_closed(self, msg): def _response_handle_closed(self, msg):
self._M.rx_closed() self._M.rx_closed()
# record, message, payload, packet, bundle, ciphertext, plaintext # record, message, payload, packet, bundle, ciphertext, plaintext

View File

@ -1,17 +1,23 @@
from __future__ import print_function, unicode_literals from __future__ import print_function, unicode_literals
import traceback import traceback
from sys import stderr from sys import stderr
from attr import attrib, attrs
from six.moves import input
from twisted.internet.defer import inlineCallbacks, returnValue
from twisted.internet.threads import blockingCallFromThread, deferToThread
from .errors import AlreadyInputNameplateError, KeyFormatError
try: try:
import readline import readline
except ImportError: except ImportError:
readline = None readline = None
from six.moves import input
from attr import attrs, attrib
from twisted.internet.defer import inlineCallbacks, returnValue
from twisted.internet.threads import deferToThread, blockingCallFromThread
from .errors import KeyFormatError, AlreadyInputNameplateError
errf = None errf = None
# uncomment this to enable tab-completion debugging # uncomment this to enable tab-completion debugging
# import os ; errf = open("err", "w") if os.path.exists("err") else None # import os ; errf = open("err", "w") if os.path.exists("err") else None
def debug(*args, **kwargs): # pragma: no cover def debug(*args, **kwargs): # pragma: no cover
@ -19,10 +25,12 @@ def debug(*args, **kwargs): # pragma: no cover
print(*args, file=errf, **kwargs) print(*args, file=errf, **kwargs)
errf.flush() errf.flush()
@attrs @attrs
class CodeInputter(object): class CodeInputter(object):
_input_helper = attrib() _input_helper = attrib()
_reactor = attrib() _reactor = attrib()
def __attrs_post_init__(self): def __attrs_post_init__(self):
self.used_completion = False self.used_completion = False
self._matches = None self._matches = None
@ -83,7 +91,9 @@ class CodeInputter(object):
# they deleted past the committment point: we can't use # they deleted past the committment point: we can't use
# this. For now, bail, but in the future let's find a # this. For now, bail, but in the future let's find a
# gentler way to encourage them to not do that. # gentler way to encourage them to not do that.
raise AlreadyInputNameplateError("nameplate (%s-) already entered, cannot go back" % self._committed_nameplate) raise AlreadyInputNameplateError(
"nameplate (%s-) already entered, cannot go back" %
self._committed_nameplate)
if not got_nameplate: if not got_nameplate:
# we're completing on nameplates: "" or "12" or "123" # we're completing on nameplates: "" or "12" or "123"
self.bcft(ih.refresh_nameplates) # results arrive later self.bcft(ih.refresh_nameplates) # results arrive later
@ -115,8 +125,10 @@ class CodeInputter(object):
self.bcft(ih.when_wordlist_is_available) # blocks on CLAIM self.bcft(ih.when_wordlist_is_available) # blocks on CLAIM
# and we're completing on words now # and we're completing on words now
debug(" getting words (%s)" % (words, )) debug(" getting words (%s)" % (words, ))
completions = [nameplate+"-"+c completions = [
for c in self.bcft(ih.get_word_completions, words)] nameplate + "-" + c
for c in self.bcft(ih.get_word_completions, words)
]
# rlcompleter wants full strings # rlcompleter wants full strings
return sorted(completions) return sorted(completions)
@ -131,13 +143,16 @@ class CodeInputter(object):
# they deleted past the committment point: we can't use # they deleted past the committment point: we can't use
# this. For now, bail, but in the future let's find a # this. For now, bail, but in the future let's find a
# gentler way to encourage them to not do that. # gentler way to encourage them to not do that.
raise AlreadyInputNameplateError("nameplate (%s-) already entered, cannot go back" % self._committed_nameplate) raise AlreadyInputNameplateError(
"nameplate (%s-) already entered, cannot go back" %
self._committed_nameplate)
else: else:
debug(" choose_nameplate(%s)" % nameplate) debug(" choose_nameplate(%s)" % nameplate)
self.bcft(self._input_helper.choose_nameplate, nameplate) self.bcft(self._input_helper.choose_nameplate, nameplate)
debug(" choose_words(%s)" % words) debug(" choose_words(%s)" % words)
self.bcft(self._input_helper.choose_words, words) self.bcft(self._input_helper.choose_words, words)
def _input_code_with_completion(prompt, input_helper, reactor): def _input_code_with_completion(prompt, input_helper, reactor):
# reminder: this all occurs in a separate thread. All calls to input_helper # reminder: this all occurs in a separate thread. All calls to input_helper
# must go through blockingCallFromThread() # must go through blockingCallFromThread()
@ -159,6 +174,7 @@ def _input_code_with_completion(prompt, input_helper, reactor):
c.finish(code) c.finish(code)
return c.used_completion return c.used_completion
def warn_readline(): def warn_readline():
# When our process receives a SIGINT, Twisted's SIGINT handler will # When our process receives a SIGINT, Twisted's SIGINT handler will
# stop the reactor and wait for all threads to terminate before the # stop the reactor and wait for all threads to terminate before the
@ -192,11 +208,12 @@ def warn_readline():
# doesn't see the signal, and we must still wait for stdin to make # doesn't see the signal, and we must still wait for stdin to make
# readline finish. # readline finish.
@inlineCallbacks @inlineCallbacks
def input_with_completion(prompt, input_helper, reactor): def input_with_completion(prompt, input_helper, reactor):
t = reactor.addSystemEventTrigger("before", "shutdown", warn_readline) t = reactor.addSystemEventTrigger("before", "shutdown", warn_readline)
# input_helper.refresh_nameplates() # input_helper.refresh_nameplates()
used_completion = yield deferToThread(_input_code_with_completion, used_completion = yield deferToThread(_input_code_with_completion, prompt,
prompt, input_helper, reactor) input_helper, reactor)
reactor.removeSystemEventTrigger(t) reactor.removeSystemEventTrigger(t)
returnValue(used_completion) returnValue(used_completion)

View File

@ -1,18 +1,22 @@
from __future__ import print_function, absolute_import, unicode_literals from __future__ import absolute_import, print_function, unicode_literals
from attr import attrs, attrib
from attr.validators import provides, instance_of from attr import attrib, attrs
from zope.interface import implementer from attr.validators import instance_of, provides
from automat import MethodicalMachine from automat import MethodicalMachine
from zope.interface import implementer
from . import _interfaces from . import _interfaces
from ._key import derive_phase_key, encrypt_data from ._key import derive_phase_key, encrypt_data
@attrs @attrs
@implementer(_interfaces.ISend) @implementer(_interfaces.ISend)
class Send(object): class Send(object):
_side = attrib(validator=instance_of(type(u""))) _side = attrib(validator=instance_of(type(u"")))
_timing = attrib(validator=provides(_interfaces.ITiming)) _timing = attrib(validator=provides(_interfaces.ITiming))
m = MethodicalMachine() m = MethodicalMachine()
set_trace = getattr(m, "_setTrace", lambda self, f: None) # pragma: no cover set_trace = getattr(m, "_setTrace",
lambda self, f: None) # pragma: no cover
def __attrs_post_init__(self): def __attrs_post_init__(self):
self._queue = [] self._queue = []
@ -21,31 +25,40 @@ class Send(object):
self._M = _interfaces.IMailbox(mailbox) self._M = _interfaces.IMailbox(mailbox)
@m.state(initial=True) @m.state(initial=True)
def S0_no_key(self): pass # pragma: no cover def S0_no_key(self):
pass # pragma: no cover
@m.state(terminal=True) @m.state(terminal=True)
def S1_verified_key(self): pass # pragma: no cover def S1_verified_key(self):
pass # pragma: no cover
# from Receive # from Receive
@m.input() @m.input()
def got_verified_key(self, key): pass def got_verified_key(self, key):
pass
# from Boss # from Boss
@m.input() @m.input()
def send(self, phase, plaintext): pass def send(self, phase, plaintext):
pass
@m.output() @m.output()
def queue(self, phase, plaintext): def queue(self, phase, plaintext):
assert isinstance(phase, type("")), type(phase) assert isinstance(phase, type("")), type(phase)
assert isinstance(plaintext, type(b"")), type(plaintext) assert isinstance(plaintext, type(b"")), type(plaintext)
self._queue.append((phase, plaintext)) self._queue.append((phase, plaintext))
@m.output() @m.output()
def record_key(self, key): def record_key(self, key):
self._key = key self._key = key
@m.output() @m.output()
def drain(self, key): def drain(self, key):
del key del key
for (phase, plaintext) in self._queue: for (phase, plaintext) in self._queue:
self._encrypt_and_send(phase, plaintext) self._encrypt_and_send(phase, plaintext)
self._queue[:] = [] self._queue[:] = []
@m.output() @m.output()
def deliver(self, phase, plaintext): def deliver(self, phase, plaintext):
assert isinstance(phase, type("")), type(phase) assert isinstance(phase, type("")), type(phase)
@ -59,6 +72,6 @@ class Send(object):
self._M.add_message(phase, encrypted) self._M.add_message(phase, encrypted)
S0_no_key.upon(send, enter=S0_no_key, outputs=[queue]) S0_no_key.upon(send, enter=S0_no_key, outputs=[queue])
S0_no_key.upon(got_verified_key, enter=S1_verified_key, S0_no_key.upon(
outputs=[record_key, drain]) got_verified_key, enter=S1_verified_key, outputs=[record_key, drain])
S1_verified_key.upon(send, enter=S1_verified_key, outputs=[deliver]) S1_verified_key.upon(send, enter=S1_verified_key, outputs=[deliver])

View File

@ -1,12 +1,16 @@
from __future__ import print_function, absolute_import, unicode_literals from __future__ import absolute_import, print_function, unicode_literals
from zope.interface import implementer
from automat import MethodicalMachine from automat import MethodicalMachine
from zope.interface import implementer
from . import _interfaces from . import _interfaces
@implementer(_interfaces.ITerminator) @implementer(_interfaces.ITerminator)
class Terminator(object): class Terminator(object):
m = MethodicalMachine() m = MethodicalMachine()
set_trace = getattr(m, "_setTrace", lambda self, f: None) # pragma: no cover set_trace = getattr(m, "_setTrace",
lambda self, f: None) # pragma: no cover
def __init__(self): def __init__(self):
self._mood = None self._mood = None
@ -29,48 +33,68 @@ class Terminator(object):
# done, and we're closing, then we stop the RendezvousConnector # done, and we're closing, then we stop the RendezvousConnector
@m.state(initial=True) @m.state(initial=True)
def Snmo(self): pass # pragma: no cover def Snmo(self):
@m.state() pass # pragma: no cover
def Smo(self): pass # pragma: no cover
@m.state()
def Sno(self): pass # pragma: no cover
@m.state()
def S0o(self): pass # pragma: no cover
@m.state() @m.state()
def Snm(self): pass # pragma: no cover def Smo(self):
pass # pragma: no cover
@m.state() @m.state()
def Sm(self): pass # pragma: no cover def Sno(self):
pass # pragma: no cover
@m.state() @m.state()
def Sn(self): pass # pragma: no cover def S0o(self):
pass # pragma: no cover
@m.state()
def Snm(self):
pass # pragma: no cover
@m.state()
def Sm(self):
pass # pragma: no cover
@m.state()
def Sn(self):
pass # pragma: no cover
# @m.state() # @m.state()
# def S0(self): pass # unused # def S0(self): pass # unused
@m.state() @m.state()
def S_stopping(self): pass # pragma: no cover def S_stopping(self):
pass # pragma: no cover
@m.state() @m.state()
def S_stopped(self, terminal=True): pass # pragma: no cover def S_stopped(self, terminal=True):
pass # pragma: no cover
# from Boss # from Boss
@m.input() @m.input()
def close(self, mood): pass def close(self, mood):
pass
# from Nameplate # from Nameplate
@m.input() @m.input()
def nameplate_done(self): pass def nameplate_done(self):
pass
# from Mailbox # from Mailbox
@m.input() @m.input()
def mailbox_done(self): pass def mailbox_done(self):
pass
# from RendezvousConnector # from RendezvousConnector
@m.input() @m.input()
def stopped(self): pass def stopped(self):
pass
@m.output() @m.output()
def close_nameplate(self, mood): def close_nameplate(self, mood):
self._N.close() # ignores mood self._N.close() # ignores mood
@m.output() @m.output()
def close_mailbox(self, mood): def close_mailbox(self, mood):
self._M.close(mood) self._M.close(mood)
@ -78,9 +102,11 @@ class Terminator(object):
@m.output() @m.output()
def ignore_mood_and_RC_stop(self, mood): def ignore_mood_and_RC_stop(self, mood):
self._RC.stop() self._RC.stop()
@m.output() @m.output()
def RC_stop(self): def RC_stop(self):
self._RC.stop() self._RC.stop()
@m.output() @m.output()
def B_closed(self): def B_closed(self):
self._B.closed() self._B.closed()
@ -99,7 +125,9 @@ class Terminator(object):
Snm.upon(nameplate_done, enter=Sm, outputs=[]) Snm.upon(nameplate_done, enter=Sm, outputs=[])
Sn.upon(nameplate_done, enter=S_stopping, outputs=[RC_stop]) Sn.upon(nameplate_done, enter=S_stopping, outputs=[RC_stop])
S0o.upon(close, enter=S_stopping, S0o.upon(
close,
enter=S_stopping,
outputs=[close_nameplate, close_mailbox, ignore_mood_and_RC_stop]) outputs=[close_nameplate, close_mailbox, ignore_mood_and_RC_stop])
Sm.upon(mailbox_done, enter=S_stopping, outputs=[RC_stop]) Sm.upon(mailbox_done, enter=S_stopping, outputs=[RC_stop])

View File

@ -1,6 +1,10 @@
from __future__ import unicode_literals, print_function from __future__ import print_function, unicode_literals
import os import os
from binascii import unhexlify
from zope.interface import implementer from zope.interface import implementer
from ._interfaces import IWordlist from ._interfaces import IWordlist
# The PGP Word List, which maps bytes to phonetically-distinct words. There # The PGP Word List, which maps bytes to phonetically-distinct words. There
@ -10,146 +14,271 @@ from ._interfaces import IWordlist
# Thanks to Warren Guy for transcribing them: # Thanks to Warren Guy for transcribing them:
# https://github.com/warrenguy/javascript-pgp-word-list # https://github.com/warrenguy/javascript-pgp-word-list
from binascii import unhexlify
raw_words = { raw_words = {
'00': ['aardvark', 'adroitness'], '01': ['absurd', 'adviser'], '00': ['aardvark', 'adroitness'],
'02': ['accrue', 'aftermath'], '03': ['acme', 'aggregate'], '01': ['absurd', 'adviser'],
'04': ['adrift', 'alkali'], '05': ['adult', 'almighty'], '02': ['accrue', 'aftermath'],
'06': ['afflict', 'amulet'], '07': ['ahead', 'amusement'], '03': ['acme', 'aggregate'],
'08': ['aimless', 'antenna'], '09': ['Algol', 'applicant'], '04': ['adrift', 'alkali'],
'0A': ['allow', 'Apollo'], '0B': ['alone', 'armistice'], '05': ['adult', 'almighty'],
'0C': ['ammo', 'article'], '0D': ['ancient', 'asteroid'], '06': ['afflict', 'amulet'],
'0E': ['apple', 'Atlantic'], '0F': ['artist', 'atmosphere'], '07': ['ahead', 'amusement'],
'10': ['assume', 'autopsy'], '11': ['Athens', 'Babylon'], '08': ['aimless', 'antenna'],
'12': ['atlas', 'backwater'], '13': ['Aztec', 'barbecue'], '09': ['Algol', 'applicant'],
'14': ['baboon', 'belowground'], '15': ['backfield', 'bifocals'], '0A': ['allow', 'Apollo'],
'16': ['backward', 'bodyguard'], '17': ['banjo', 'bookseller'], '0B': ['alone', 'armistice'],
'18': ['beaming', 'borderline'], '19': ['bedlamp', 'bottomless'], '0C': ['ammo', 'article'],
'1A': ['beehive', 'Bradbury'], '1B': ['beeswax', 'bravado'], '0D': ['ancient', 'asteroid'],
'1C': ['befriend', 'Brazilian'], '1D': ['Belfast', 'breakaway'], '0E': ['apple', 'Atlantic'],
'1E': ['berserk', 'Burlington'], '1F': ['billiard', 'businessman'], '0F': ['artist', 'atmosphere'],
'20': ['bison', 'butterfat'], '21': ['blackjack', 'Camelot'], '10': ['assume', 'autopsy'],
'22': ['blockade', 'candidate'], '23': ['blowtorch', 'cannonball'], '11': ['Athens', 'Babylon'],
'24': ['bluebird', 'Capricorn'], '25': ['bombast', 'caravan'], '12': ['atlas', 'backwater'],
'26': ['bookshelf', 'caretaker'], '27': ['brackish', 'celebrate'], '13': ['Aztec', 'barbecue'],
'28': ['breadline', 'cellulose'], '29': ['breakup', 'certify'], '14': ['baboon', 'belowground'],
'2A': ['brickyard', 'chambermaid'], '2B': ['briefcase', 'Cherokee'], '15': ['backfield', 'bifocals'],
'2C': ['Burbank', 'Chicago'], '2D': ['button', 'clergyman'], '16': ['backward', 'bodyguard'],
'2E': ['buzzard', 'coherence'], '2F': ['cement', 'combustion'], '17': ['banjo', 'bookseller'],
'30': ['chairlift', 'commando'], '31': ['chatter', 'company'], '18': ['beaming', 'borderline'],
'32': ['checkup', 'component'], '33': ['chisel', 'concurrent'], '19': ['bedlamp', 'bottomless'],
'34': ['choking', 'confidence'], '35': ['chopper', 'conformist'], '1A': ['beehive', 'Bradbury'],
'36': ['Christmas', 'congregate'], '37': ['clamshell', 'consensus'], '1B': ['beeswax', 'bravado'],
'38': ['classic', 'consulting'], '39': ['classroom', 'corporate'], '1C': ['befriend', 'Brazilian'],
'3A': ['cleanup', 'corrosion'], '3B': ['clockwork', 'councilman'], '1D': ['Belfast', 'breakaway'],
'3C': ['cobra', 'crossover'], '3D': ['commence', 'crucifix'], '1E': ['berserk', 'Burlington'],
'3E': ['concert', 'cumbersome'], '3F': ['cowbell', 'customer'], '1F': ['billiard', 'businessman'],
'40': ['crackdown', 'Dakota'], '41': ['cranky', 'decadence'], '20': ['bison', 'butterfat'],
'42': ['crowfoot', 'December'], '43': ['crucial', 'decimal'], '21': ['blackjack', 'Camelot'],
'44': ['crumpled', 'designing'], '45': ['crusade', 'detector'], '22': ['blockade', 'candidate'],
'46': ['cubic', 'detergent'], '47': ['dashboard', 'determine'], '23': ['blowtorch', 'cannonball'],
'48': ['deadbolt', 'dictator'], '49': ['deckhand', 'dinosaur'], '24': ['bluebird', 'Capricorn'],
'4A': ['dogsled', 'direction'], '4B': ['dragnet', 'disable'], '25': ['bombast', 'caravan'],
'4C': ['drainage', 'disbelief'], '4D': ['dreadful', 'disruptive'], '26': ['bookshelf', 'caretaker'],
'4E': ['drifter', 'distortion'], '4F': ['dropper', 'document'], '27': ['brackish', 'celebrate'],
'50': ['drumbeat', 'embezzle'], '51': ['drunken', 'enchanting'], '28': ['breadline', 'cellulose'],
'52': ['Dupont', 'enrollment'], '53': ['dwelling', 'enterprise'], '29': ['breakup', 'certify'],
'54': ['eating', 'equation'], '55': ['edict', 'equipment'], '2A': ['brickyard', 'chambermaid'],
'56': ['egghead', 'escapade'], '57': ['eightball', 'Eskimo'], '2B': ['briefcase', 'Cherokee'],
'58': ['endorse', 'everyday'], '59': ['endow', 'examine'], '2C': ['Burbank', 'Chicago'],
'5A': ['enlist', 'existence'], '5B': ['erase', 'exodus'], '2D': ['button', 'clergyman'],
'5C': ['escape', 'fascinate'], '5D': ['exceed', 'filament'], '2E': ['buzzard', 'coherence'],
'5E': ['eyeglass', 'finicky'], '5F': ['eyetooth', 'forever'], '2F': ['cement', 'combustion'],
'60': ['facial', 'fortitude'], '61': ['fallout', 'frequency'], '30': ['chairlift', 'commando'],
'62': ['flagpole', 'gadgetry'], '63': ['flatfoot', 'Galveston'], '31': ['chatter', 'company'],
'64': ['flytrap', 'getaway'], '65': ['fracture', 'glossary'], '32': ['checkup', 'component'],
'66': ['framework', 'gossamer'], '67': ['freedom', 'graduate'], '33': ['chisel', 'concurrent'],
'68': ['frighten', 'gravity'], '69': ['gazelle', 'guitarist'], '34': ['choking', 'confidence'],
'6A': ['Geiger', 'hamburger'], '6B': ['glitter', 'Hamilton'], '35': ['chopper', 'conformist'],
'6C': ['glucose', 'handiwork'], '6D': ['goggles', 'hazardous'], '36': ['Christmas', 'congregate'],
'6E': ['goldfish', 'headwaters'], '6F': ['gremlin', 'hemisphere'], '37': ['clamshell', 'consensus'],
'70': ['guidance', 'hesitate'], '71': ['hamlet', 'hideaway'], '38': ['classic', 'consulting'],
'72': ['highchair', 'holiness'], '73': ['hockey', 'hurricane'], '39': ['classroom', 'corporate'],
'74': ['indoors', 'hydraulic'], '75': ['indulge', 'impartial'], '3A': ['cleanup', 'corrosion'],
'76': ['inverse', 'impetus'], '77': ['involve', 'inception'], '3B': ['clockwork', 'councilman'],
'78': ['island', 'indigo'], '79': ['jawbone', 'inertia'], '3C': ['cobra', 'crossover'],
'7A': ['keyboard', 'infancy'], '7B': ['kickoff', 'inferno'], '3D': ['commence', 'crucifix'],
'7C': ['kiwi', 'informant'], '7D': ['klaxon', 'insincere'], '3E': ['concert', 'cumbersome'],
'7E': ['locale', 'insurgent'], '7F': ['lockup', 'integrate'], '3F': ['cowbell', 'customer'],
'80': ['merit', 'intention'], '81': ['minnow', 'inventive'], '40': ['crackdown', 'Dakota'],
'82': ['miser', 'Istanbul'], '83': ['Mohawk', 'Jamaica'], '41': ['cranky', 'decadence'],
'84': ['mural', 'Jupiter'], '85': ['music', 'leprosy'], '42': ['crowfoot', 'December'],
'86': ['necklace', 'letterhead'], '87': ['Neptune', 'liberty'], '43': ['crucial', 'decimal'],
'88': ['newborn', 'maritime'], '89': ['nightbird', 'matchmaker'], '44': ['crumpled', 'designing'],
'8A': ['Oakland', 'maverick'], '8B': ['obtuse', 'Medusa'], '45': ['crusade', 'detector'],
'8C': ['offload', 'megaton'], '8D': ['optic', 'microscope'], '46': ['cubic', 'detergent'],
'8E': ['orca', 'microwave'], '8F': ['payday', 'midsummer'], '47': ['dashboard', 'determine'],
'90': ['peachy', 'millionaire'], '91': ['pheasant', 'miracle'], '48': ['deadbolt', 'dictator'],
'92': ['physique', 'misnomer'], '93': ['playhouse', 'molasses'], '49': ['deckhand', 'dinosaur'],
'94': ['Pluto', 'molecule'], '95': ['preclude', 'Montana'], '4A': ['dogsled', 'direction'],
'96': ['prefer', 'monument'], '97': ['preshrunk', 'mosquito'], '4B': ['dragnet', 'disable'],
'98': ['printer', 'narrative'], '99': ['prowler', 'nebula'], '4C': ['drainage', 'disbelief'],
'9A': ['pupil', 'newsletter'], '9B': ['puppy', 'Norwegian'], '4D': ['dreadful', 'disruptive'],
'9C': ['python', 'October'], '9D': ['quadrant', 'Ohio'], '4E': ['drifter', 'distortion'],
'9E': ['quiver', 'onlooker'], '9F': ['quota', 'opulent'], '4F': ['dropper', 'document'],
'A0': ['ragtime', 'Orlando'], 'A1': ['ratchet', 'outfielder'], '50': ['drumbeat', 'embezzle'],
'A2': ['rebirth', 'Pacific'], 'A3': ['reform', 'pandemic'], '51': ['drunken', 'enchanting'],
'A4': ['regain', 'Pandora'], 'A5': ['reindeer', 'paperweight'], '52': ['Dupont', 'enrollment'],
'A6': ['rematch', 'paragon'], 'A7': ['repay', 'paragraph'], '53': ['dwelling', 'enterprise'],
'A8': ['retouch', 'paramount'], 'A9': ['revenge', 'passenger'], '54': ['eating', 'equation'],
'AA': ['reward', 'pedigree'], 'AB': ['rhythm', 'Pegasus'], '55': ['edict', 'equipment'],
'AC': ['ribcage', 'penetrate'], 'AD': ['ringbolt', 'perceptive'], '56': ['egghead', 'escapade'],
'AE': ['robust', 'performance'], 'AF': ['rocker', 'pharmacy'], '57': ['eightball', 'Eskimo'],
'B0': ['ruffled', 'phonetic'], 'B1': ['sailboat', 'photograph'], '58': ['endorse', 'everyday'],
'B2': ['sawdust', 'pioneer'], 'B3': ['scallion', 'pocketful'], '59': ['endow', 'examine'],
'B4': ['scenic', 'politeness'], 'B5': ['scorecard', 'positive'], '5A': ['enlist', 'existence'],
'B6': ['Scotland', 'potato'], 'B7': ['seabird', 'processor'], '5B': ['erase', 'exodus'],
'B8': ['select', 'provincial'], 'B9': ['sentence', 'proximate'], '5C': ['escape', 'fascinate'],
'BA': ['shadow', 'puberty'], 'BB': ['shamrock', 'publisher'], '5D': ['exceed', 'filament'],
'BC': ['showgirl', 'pyramid'], 'BD': ['skullcap', 'quantity'], '5E': ['eyeglass', 'finicky'],
'BE': ['skydive', 'racketeer'], 'BF': ['slingshot', 'rebellion'], '5F': ['eyetooth', 'forever'],
'C0': ['slowdown', 'recipe'], 'C1': ['snapline', 'recover'], '60': ['facial', 'fortitude'],
'C2': ['snapshot', 'repellent'], 'C3': ['snowcap', 'replica'], '61': ['fallout', 'frequency'],
'C4': ['snowslide', 'reproduce'], 'C5': ['solo', 'resistor'], '62': ['flagpole', 'gadgetry'],
'C6': ['southward', 'responsive'], 'C7': ['soybean', 'retraction'], '63': ['flatfoot', 'Galveston'],
'C8': ['spaniel', 'retrieval'], 'C9': ['spearhead', 'retrospect'], '64': ['flytrap', 'getaway'],
'CA': ['spellbind', 'revenue'], 'CB': ['spheroid', 'revival'], '65': ['fracture', 'glossary'],
'CC': ['spigot', 'revolver'], 'CD': ['spindle', 'sandalwood'], '66': ['framework', 'gossamer'],
'CE': ['spyglass', 'sardonic'], 'CF': ['stagehand', 'Saturday'], '67': ['freedom', 'graduate'],
'D0': ['stagnate', 'savagery'], 'D1': ['stairway', 'scavenger'], '68': ['frighten', 'gravity'],
'D2': ['standard', 'sensation'], 'D3': ['stapler', 'sociable'], '69': ['gazelle', 'guitarist'],
'D4': ['steamship', 'souvenir'], 'D5': ['sterling', 'specialist'], '6A': ['Geiger', 'hamburger'],
'D6': ['stockman', 'speculate'], 'D7': ['stopwatch', 'stethoscope'], '6B': ['glitter', 'Hamilton'],
'D8': ['stormy', 'stupendous'], 'D9': ['sugar', 'supportive'], '6C': ['glucose', 'handiwork'],
'DA': ['surmount', 'surrender'], 'DB': ['suspense', 'suspicious'], '6D': ['goggles', 'hazardous'],
'DC': ['sweatband', 'sympathy'], 'DD': ['swelter', 'tambourine'], '6E': ['goldfish', 'headwaters'],
'DE': ['tactics', 'telephone'], 'DF': ['talon', 'therapist'], '6F': ['gremlin', 'hemisphere'],
'E0': ['tapeworm', 'tobacco'], 'E1': ['tempest', 'tolerance'], '70': ['guidance', 'hesitate'],
'E2': ['tiger', 'tomorrow'], 'E3': ['tissue', 'torpedo'], '71': ['hamlet', 'hideaway'],
'E4': ['tonic', 'tradition'], 'E5': ['topmost', 'travesty'], '72': ['highchair', 'holiness'],
'E6': ['tracker', 'trombonist'], 'E7': ['transit', 'truncated'], '73': ['hockey', 'hurricane'],
'E8': ['trauma', 'typewriter'], 'E9': ['treadmill', 'ultimate'], '74': ['indoors', 'hydraulic'],
'EA': ['Trojan', 'undaunted'], 'EB': ['trouble', 'underfoot'], '75': ['indulge', 'impartial'],
'EC': ['tumor', 'unicorn'], 'ED': ['tunnel', 'unify'], '76': ['inverse', 'impetus'],
'EE': ['tycoon', 'universe'], 'EF': ['uncut', 'unravel'], '77': ['involve', 'inception'],
'F0': ['unearth', 'upcoming'], 'F1': ['unwind', 'vacancy'], '78': ['island', 'indigo'],
'F2': ['uproot', 'vagabond'], 'F3': ['upset', 'vertigo'], '79': ['jawbone', 'inertia'],
'F4': ['upshot', 'Virginia'], 'F5': ['vapor', 'visitor'], '7A': ['keyboard', 'infancy'],
'F6': ['village', 'vocalist'], 'F7': ['virus', 'voyager'], '7B': ['kickoff', 'inferno'],
'F8': ['Vulcan', 'warranty'], 'F9': ['waffle', 'Waterloo'], '7C': ['kiwi', 'informant'],
'FA': ['wallet', 'whimsical'], 'FB': ['watchword', 'Wichita'], '7D': ['klaxon', 'insincere'],
'FC': ['wayside', 'Wilmington'], 'FD': ['willow', 'Wyoming'], '7E': ['locale', 'insurgent'],
'FE': ['woodlark', 'yesteryear'], 'FF': ['Zulu', 'Yucatan'] '7F': ['lockup', 'integrate'],
}; '80': ['merit', 'intention'],
'81': ['minnow', 'inventive'],
'82': ['miser', 'Istanbul'],
'83': ['Mohawk', 'Jamaica'],
'84': ['mural', 'Jupiter'],
'85': ['music', 'leprosy'],
'86': ['necklace', 'letterhead'],
'87': ['Neptune', 'liberty'],
'88': ['newborn', 'maritime'],
'89': ['nightbird', 'matchmaker'],
'8A': ['Oakland', 'maverick'],
'8B': ['obtuse', 'Medusa'],
'8C': ['offload', 'megaton'],
'8D': ['optic', 'microscope'],
'8E': ['orca', 'microwave'],
'8F': ['payday', 'midsummer'],
'90': ['peachy', 'millionaire'],
'91': ['pheasant', 'miracle'],
'92': ['physique', 'misnomer'],
'93': ['playhouse', 'molasses'],
'94': ['Pluto', 'molecule'],
'95': ['preclude', 'Montana'],
'96': ['prefer', 'monument'],
'97': ['preshrunk', 'mosquito'],
'98': ['printer', 'narrative'],
'99': ['prowler', 'nebula'],
'9A': ['pupil', 'newsletter'],
'9B': ['puppy', 'Norwegian'],
'9C': ['python', 'October'],
'9D': ['quadrant', 'Ohio'],
'9E': ['quiver', 'onlooker'],
'9F': ['quota', 'opulent'],
'A0': ['ragtime', 'Orlando'],
'A1': ['ratchet', 'outfielder'],
'A2': ['rebirth', 'Pacific'],
'A3': ['reform', 'pandemic'],
'A4': ['regain', 'Pandora'],
'A5': ['reindeer', 'paperweight'],
'A6': ['rematch', 'paragon'],
'A7': ['repay', 'paragraph'],
'A8': ['retouch', 'paramount'],
'A9': ['revenge', 'passenger'],
'AA': ['reward', 'pedigree'],
'AB': ['rhythm', 'Pegasus'],
'AC': ['ribcage', 'penetrate'],
'AD': ['ringbolt', 'perceptive'],
'AE': ['robust', 'performance'],
'AF': ['rocker', 'pharmacy'],
'B0': ['ruffled', 'phonetic'],
'B1': ['sailboat', 'photograph'],
'B2': ['sawdust', 'pioneer'],
'B3': ['scallion', 'pocketful'],
'B4': ['scenic', 'politeness'],
'B5': ['scorecard', 'positive'],
'B6': ['Scotland', 'potato'],
'B7': ['seabird', 'processor'],
'B8': ['select', 'provincial'],
'B9': ['sentence', 'proximate'],
'BA': ['shadow', 'puberty'],
'BB': ['shamrock', 'publisher'],
'BC': ['showgirl', 'pyramid'],
'BD': ['skullcap', 'quantity'],
'BE': ['skydive', 'racketeer'],
'BF': ['slingshot', 'rebellion'],
'C0': ['slowdown', 'recipe'],
'C1': ['snapline', 'recover'],
'C2': ['snapshot', 'repellent'],
'C3': ['snowcap', 'replica'],
'C4': ['snowslide', 'reproduce'],
'C5': ['solo', 'resistor'],
'C6': ['southward', 'responsive'],
'C7': ['soybean', 'retraction'],
'C8': ['spaniel', 'retrieval'],
'C9': ['spearhead', 'retrospect'],
'CA': ['spellbind', 'revenue'],
'CB': ['spheroid', 'revival'],
'CC': ['spigot', 'revolver'],
'CD': ['spindle', 'sandalwood'],
'CE': ['spyglass', 'sardonic'],
'CF': ['stagehand', 'Saturday'],
'D0': ['stagnate', 'savagery'],
'D1': ['stairway', 'scavenger'],
'D2': ['standard', 'sensation'],
'D3': ['stapler', 'sociable'],
'D4': ['steamship', 'souvenir'],
'D5': ['sterling', 'specialist'],
'D6': ['stockman', 'speculate'],
'D7': ['stopwatch', 'stethoscope'],
'D8': ['stormy', 'stupendous'],
'D9': ['sugar', 'supportive'],
'DA': ['surmount', 'surrender'],
'DB': ['suspense', 'suspicious'],
'DC': ['sweatband', 'sympathy'],
'DD': ['swelter', 'tambourine'],
'DE': ['tactics', 'telephone'],
'DF': ['talon', 'therapist'],
'E0': ['tapeworm', 'tobacco'],
'E1': ['tempest', 'tolerance'],
'E2': ['tiger', 'tomorrow'],
'E3': ['tissue', 'torpedo'],
'E4': ['tonic', 'tradition'],
'E5': ['topmost', 'travesty'],
'E6': ['tracker', 'trombonist'],
'E7': ['transit', 'truncated'],
'E8': ['trauma', 'typewriter'],
'E9': ['treadmill', 'ultimate'],
'EA': ['Trojan', 'undaunted'],
'EB': ['trouble', 'underfoot'],
'EC': ['tumor', 'unicorn'],
'ED': ['tunnel', 'unify'],
'EE': ['tycoon', 'universe'],
'EF': ['uncut', 'unravel'],
'F0': ['unearth', 'upcoming'],
'F1': ['unwind', 'vacancy'],
'F2': ['uproot', 'vagabond'],
'F3': ['upset', 'vertigo'],
'F4': ['upshot', 'Virginia'],
'F5': ['vapor', 'visitor'],
'F6': ['village', 'vocalist'],
'F7': ['virus', 'voyager'],
'F8': ['Vulcan', 'warranty'],
'F9': ['waffle', 'Waterloo'],
'FA': ['wallet', 'whimsical'],
'FB': ['watchword', 'Wichita'],
'FC': ['wayside', 'Wilmington'],
'FD': ['willow', 'Wyoming'],
'FE': ['woodlark', 'yesteryear'],
'FF': ['Zulu', 'Yucatan']
}
byte_to_even_word = dict([(unhexlify(k.encode("ascii")), both_words[0]) byte_to_even_word = dict([(unhexlify(k.encode("ascii")), both_words[0])
for k,both_words for k, both_words in raw_words.items()])
in raw_words.items()])
byte_to_odd_word = dict([(unhexlify(k.encode("ascii")), both_words[1]) byte_to_odd_word = dict([(unhexlify(k.encode("ascii")), both_words[1])
for k,both_words for k, both_words in raw_words.items()])
in raw_words.items()])
even_words_lowercase, odd_words_lowercase = set(), set() even_words_lowercase, odd_words_lowercase = set(), set()
@ -158,6 +287,7 @@ for k,both_words in raw_words.items():
even_words_lowercase.add(even_word.lower()) even_words_lowercase.add(even_word.lower())
odd_words_lowercase.add(odd_word.lower()) odd_words_lowercase.add(odd_word.lower())
@implementer(IWordlist) @implementer(IWordlist)
class PGPWordList(object): class PGPWordList(object):
def get_completions(self, prefix, num_words=2): def get_completions(self, prefix, num_words=2):

View File

@ -3,20 +3,24 @@ from __future__ import print_function
import os import os
import time import time
start = time.time() start = time.time()
import six
from textwrap import fill, dedent
from sys import stdout, stderr
from . import public_relay
from .. import __version__
from ..timing import DebugTiming
from ..errors import (WrongPasswordError, WelcomeError, KeyFormatError,
TransferError, NoTorError, UnsendableFileError,
ServerConnectionError)
from twisted.internet.defer import inlineCallbacks, maybeDeferred
from twisted.python.failure import Failure
from twisted.internet.task import react
import click from sys import stderr, stdout # noqa: E402
from textwrap import dedent, fill # noqa: E402
import click # noqa: E402
import six # noqa: E402
from twisted.internet.defer import inlineCallbacks, maybeDeferred # noqa: E402
from twisted.internet.task import react # noqa: E402
from twisted.python.failure import Failure # noqa: E402
from . import public_relay # noqa: E402
from .. import __version__ # noqa: E402
from ..errors import (KeyFormatError, NoTorError, # noqa: E402
ServerConnectionError,
TransferError, UnsendableFileError, WelcomeError,
WrongPasswordError)
from ..timing import DebugTiming # noqa: E402
top_import_finish = time.time() top_import_finish = time.time()
@ -24,6 +28,7 @@ class Config(object):
""" """
Union of config options that we pass down to (sub) commands. Union of config options that we pass down to (sub) commands.
""" """
def __init__(self): def __init__(self):
# This only holds attributes which are *not* set by CLI arguments. # This only holds attributes which are *not* set by CLI arguments.
# Everything else comes from Click decorators, so we can be sure # Everything else comes from Click decorators, so we can be sure
@ -34,11 +39,13 @@ class Config(object):
self.stderr = stderr self.stderr = stderr
self.tor = False # XXX? self.tor = False # XXX?
def _compose(*decorators): def _compose(*decorators):
def decorate(f): def decorate(f):
for d in reversed(decorators): for d in reversed(decorators):
f = d(f) f = d(f)
return f return f
return decorate return decorate
@ -48,6 +55,8 @@ ALIASES = {
"recieve": "receive", "recieve": "receive",
"recv": "receive", "recv": "receive",
} }
class AliasedGroup(click.Group): class AliasedGroup(click.Group):
def get_command(self, ctx, cmd_name): def get_command(self, ctx, cmd_name):
cmd_name = ALIASES.get(cmd_name, cmd_name) cmd_name = ALIASES.get(cmd_name, cmd_name)
@ -56,22 +65,24 @@ class AliasedGroup(click.Group):
# top-level command ("wormhole ...") # top-level command ("wormhole ...")
@click.group(cls=AliasedGroup) @click.group(cls=AliasedGroup)
@click.option("--appid", default=None, metavar="APPID", help="appid to use")
@click.option( @click.option(
"--appid", default=None, metavar="APPID", help="appid to use") "--relay-url",
@click.option( default=public_relay.RENDEZVOUS_RELAY,
"--relay-url", default=public_relay.RENDEZVOUS_RELAY,
envvar='WORMHOLE_RELAY_URL', envvar='WORMHOLE_RELAY_URL',
metavar="URL", metavar="URL",
help="rendezvous relay to use", help="rendezvous relay to use",
) )
@click.option( @click.option(
"--transit-helper", default=public_relay.TRANSIT_RELAY, "--transit-helper",
default=public_relay.TRANSIT_RELAY,
envvar='WORMHOLE_TRANSIT_HELPER', envvar='WORMHOLE_TRANSIT_HELPER',
metavar="tcp:HOST:PORT", metavar="tcp:HOST:PORT",
help="transit relay to use", help="transit relay to use",
) )
@click.option( @click.option(
"--dump-timing", type=type(u""), # TODO: hide from --help output "--dump-timing",
type=type(u""), # TODO: hide from --help output
default=None, default=None,
metavar="FILE.json", metavar="FILE.json",
help="(debug) write timing data to file", help="(debug) write timing data to file",
@ -104,7 +115,8 @@ def _dispatch_command(reactor, cfg, command):
errors for the user. errors for the user.
""" """
cfg.timing.add("command dispatch") cfg.timing.add("command dispatch")
cfg.timing.add("import", when=start, which="top").finish(when=top_import_finish) cfg.timing.add(
"import", when=start, which="top").finish(when=top_import_finish)
try: try:
yield maybeDeferred(command) yield maybeDeferred(command)
@ -141,56 +153,89 @@ def _dispatch_command(reactor, cfg, command):
CommonArgs = _compose( CommonArgs = _compose(
click.option("-0", "zeromode", default=False, is_flag=True, click.option(
"-0",
"zeromode",
default=False,
is_flag=True,
help="enable no-code anything-goes mode", help="enable no-code anything-goes mode",
), ),
click.option("-c", "--code-length", default=2, metavar="NUMWORDS", click.option(
"-c",
"--code-length",
default=2,
metavar="NUMWORDS",
help="length of code (in bytes/words)", help="length of code (in bytes/words)",
), ),
click.option("-v", "--verify", is_flag=True, default=False, click.option(
"-v",
"--verify",
is_flag=True,
default=False,
help="display verification string (and wait for approval)", help="display verification string (and wait for approval)",
), ),
click.option("--hide-progress", is_flag=True, default=False, click.option(
"--hide-progress",
is_flag=True,
default=False,
help="supress progress-bar display", help="supress progress-bar display",
), ),
click.option("--listen/--no-listen", default=True, click.option(
"--listen/--no-listen",
default=True,
help="(debug) don't open a listening socket for Transit", help="(debug) don't open a listening socket for Transit",
), ),
) )
TorArgs = _compose( TorArgs = _compose(
click.option("--tor", is_flag=True, default=False, click.option(
"--tor",
is_flag=True,
default=False,
help="use Tor when connecting", help="use Tor when connecting",
), ),
click.option("--launch-tor", is_flag=True, default=False, click.option(
"--launch-tor",
is_flag=True,
default=False,
help="launch Tor, rather than use existing control/socks port", help="launch Tor, rather than use existing control/socks port",
), ),
click.option("--tor-control-port", default=None, metavar="ENDPOINT", click.option(
"--tor-control-port",
default=None,
metavar="ENDPOINT",
help="endpoint descriptor for Tor control port", help="endpoint descriptor for Tor control port",
), ),
) )
@wormhole.command() @wormhole.command()
@click.pass_context @click.pass_context
def help(context, **kwargs): def help(context, **kwargs):
print(context.find_root().get_help()) print(context.find_root().get_help())
# wormhole send (or "wormhole tx") # wormhole send (or "wormhole tx")
@wormhole.command() @wormhole.command()
@CommonArgs @CommonArgs
@TorArgs @TorArgs
@click.option( @click.option(
"--code", metavar="CODE", "--code",
metavar="CODE",
help="human-generated code phrase", help="human-generated code phrase",
) )
@click.option( @click.option(
"--text", default=None, metavar="MESSAGE", "--text",
help="text message to send, instead of a file. Use '-' to read from stdin.", default=None,
metavar="MESSAGE",
help=("text message to send, instead of a file."
" Use '-' to read from stdin."),
) )
@click.option( @click.option(
"--ignore-unsendable-files", default=False, is_flag=True, "--ignore-unsendable-files",
help="Don't raise an error if a file can't be read." default=False,
) is_flag=True,
help="Don't raise an error if a file can't be read.")
@click.argument("what", required=False, type=click.Path(path_type=type(u""))) @click.argument("what", required=False, type=click.Path(path_type=type(u"")))
@click.pass_obj @click.pass_obj
def send(cfg, **kwargs): def send(cfg, **kwargs):
@ -202,6 +247,7 @@ def send(cfg, **kwargs):
return go(cmd_send.send, cfg) return go(cmd_send.send, cfg)
# this intermediate function can be mocked by tests that need to build a # this intermediate function can be mocked by tests that need to build a
# Config object # Config object
def go(f, cfg): def go(f, cfg):
@ -214,21 +260,27 @@ def go(f, cfg):
@CommonArgs @CommonArgs
@TorArgs @TorArgs
@click.option( @click.option(
"--only-text", "-t", is_flag=True, "--only-text",
"-t",
is_flag=True,
help="refuse file transfers, only accept text transfers", help="refuse file transfers, only accept text transfers",
) )
@click.option( @click.option(
"--accept-file", is_flag=True, "--accept-file",
is_flag=True,
help="accept file transfer without asking for confirmation", help="accept file transfer without asking for confirmation",
) )
@click.option( @click.option(
"--output-file", "-o", "--output-file",
"-o",
metavar="FILENAME|DIRNAME", metavar="FILENAME|DIRNAME",
help=("The file or directory to create, overriding the name suggested" help=("The file or directory to create, overriding the name suggested"
" by the sender."), " by the sender."),
) )
@click.argument( @click.argument(
"code", nargs=-1, default=None, "code",
nargs=-1,
default=None,
# help=("The magic-wormhole code, from the sender. If omitted, the" # help=("The magic-wormhole code, from the sender. If omitted, the"
# " program will ask for it, using tab-completion."), # " program will ask for it, using tab-completion."),
) )
@ -244,10 +296,8 @@ def receive(cfg, code, **kwargs):
if len(code) == 1: if len(code) == 1:
cfg.code = code[0] cfg.code = code[0]
elif len(code) > 1: elif len(code) > 1:
print( print("Pass either no code or just one code; you passed"
"Pass either no code or just one code; you passed" " {}: {}".format(len(code), ', '.join(code)))
" {}: {}".format(len(code), ', '.join(code))
)
raise SystemExit(1) raise SystemExit(1)
else: else:
cfg.code = None cfg.code = None
@ -260,17 +310,19 @@ def ssh():
""" """
Facilitate sending/receiving SSH public keys Facilitate sending/receiving SSH public keys
""" """
pass
@ssh.command(name="invite") @ssh.command(name="invite")
@click.option( @click.option(
"-c", "--code-length", default=2, "-c",
"--code-length",
default=2,
metavar="NUMWORDS", metavar="NUMWORDS",
help="length of code (in bytes/words)", help="length of code (in bytes/words)",
) )
@click.option( @click.option(
"--user", "-u", "--user",
"-u",
default=None, default=None,
metavar="USER", metavar="USER",
help="Add to USER's ~/.ssh/authorized_keys", help="Add to USER's ~/.ssh/authorized_keys",
@ -291,15 +343,20 @@ def ssh_invite(ctx, code_length, user, **kwargs):
@ssh.command(name="accept") @ssh.command(name="accept")
@click.argument( @click.argument(
"code", nargs=1, required=True, "code",
nargs=1,
required=True,
) )
@click.option( @click.option(
"--key-file", "-F", "--key-file",
"-F",
default=None, default=None,
type=click.Path(exists=True), type=click.Path(exists=True),
) )
@click.option( @click.option(
"--yes", "-y", is_flag=True, "--yes",
"-y",
is_flag=True,
help="Skip confirmation prompt to send key", help="Skip confirmation prompt to send key",
) )
@TorArgs @TorArgs
@ -318,7 +375,8 @@ def ssh_accept(cfg, code, key_file, yes, **kwargs):
kind, keyid, pubkey = cmd_ssh.find_public_key(key_file) kind, keyid, pubkey = cmd_ssh.find_public_key(key_file)
print("Sending public key type='{}' keyid='{}'".format(kind, keyid)) print("Sending public key type='{}' keyid='{}'".format(kind, keyid))
if yes is not True: if yes is not True:
click.confirm("Really send public key '{}' ?".format(keyid), abort=True) click.confirm(
"Really send public key '{}' ?".format(keyid), abort=True)
cfg.public_key = (kind, keyid, pubkey) cfg.public_key = (kind, keyid, pubkey)
cfg.code = code cfg.code = code

View File

@ -1,14 +1,23 @@
from __future__ import print_function from __future__ import print_function
import os, sys, six, tempfile, zipfile, hashlib, shutil
from tqdm import tqdm import hashlib
import os
import shutil
import sys
import tempfile
import zipfile
import six
from humanize import naturalsize from humanize import naturalsize
from tqdm import tqdm
from twisted.internet import reactor from twisted.internet import reactor
from twisted.internet.defer import inlineCallbacks, returnValue from twisted.internet.defer import inlineCallbacks, returnValue
from twisted.python import log from twisted.python import log
from wormhole import create, input_with_completion, __version__ from wormhole import __version__, create, input_with_completion
from ..transit import TransitReceiver
from ..errors import TransferError from ..errors import TransferError
from ..util import (dict_to_bytes, bytes_to_dict, bytes_to_hexstr, from ..transit import TransitReceiver
from ..util import (bytes_to_dict, bytes_to_hexstr, dict_to_bytes,
estimate_free_space) estimate_free_space)
from .welcome import handle_welcome from .welcome import handle_welcome
@ -17,14 +26,17 @@ APPID = u"lothar.com/wormhole/text-or-file-xfer"
KEY_TIMER = float(os.environ.get("_MAGIC_WORMHOLE_TEST_KEY_TIMER", 1.0)) KEY_TIMER = float(os.environ.get("_MAGIC_WORMHOLE_TEST_KEY_TIMER", 1.0))
VERIFY_TIMER = float(os.environ.get("_MAGIC_WORMHOLE_TEST_VERIFY_TIMER", 1.0)) VERIFY_TIMER = float(os.environ.get("_MAGIC_WORMHOLE_TEST_VERIFY_TIMER", 1.0))
class RespondError(Exception): class RespondError(Exception):
def __init__(self, response): def __init__(self, response):
self.response = response self.response = response
class TransferRejectedError(RespondError): class TransferRejectedError(RespondError):
def __init__(self): def __init__(self):
RespondError.__init__(self, "transfer rejected") RespondError.__init__(self, "transfer rejected")
def receive(args, reactor=reactor, _debug_stash_wormhole=None): def receive(args, reactor=reactor, _debug_stash_wormhole=None):
"""I implement 'wormhole receive'. I return a Deferred that fires with """I implement 'wormhole receive'. I return a Deferred that fires with
None (for success), or signals one of the following errors: None (for success), or signals one of the following errors:
@ -60,12 +72,15 @@ class Receiver:
# tor in parallel with everything else, make sure the Tor object # tor in parallel with everything else, make sure the Tor object
# can lazy-provide an endpoint, and overlap the startup process # can lazy-provide an endpoint, and overlap the startup process
# with the user handing off the wormhole code # with the user handing off the wormhole code
self._tor = yield get_tor(self._reactor, self._tor = yield get_tor(
self._reactor,
self.args.launch_tor, self.args.launch_tor,
self.args.tor_control_port, self.args.tor_control_port,
timing=self.args.timing) timing=self.args.timing)
w = create(self.args.appid or APPID, self.args.relay_url, w = create(
self.args.appid or APPID,
self.args.relay_url,
self._reactor, self._reactor,
tor=self._tor, tor=self._tor,
timing=self.args.timing) timing=self.args.timing)
@ -97,7 +112,7 @@ class Receiver:
def _bad(f): def _bad(f):
try: try:
yield w.close() # might be an error too yield w.close() # might be an error too
except: except Exception:
pass pass
returnValue(f) returnValue(f)
@ -114,6 +129,7 @@ class Receiver:
def on_slow_key(): def on_slow_key():
print(u"Waiting for sender...", file=self.args.stderr) print(u"Waiting for sender...", file=self.args.stderr)
notify = self._reactor.callLater(KEY_TIMER, on_slow_key) notify = self._reactor.callLater(KEY_TIMER, on_slow_key)
try: try:
# We wait here until we connect to the server and see the senders # We wait here until we connect to the server and see the senders
@ -129,8 +145,10 @@ class Receiver:
notify.cancel() notify.cancel()
def on_slow_verification(): def on_slow_verification():
print(u"Key established, waiting for confirmation...", print(
u"Key established, waiting for confirmation...",
file=self.args.stderr) file=self.args.stderr)
notify = self._reactor.callLater(VERIFY_TIMER, on_slow_verification) notify = self._reactor.callLater(VERIFY_TIMER, on_slow_verification)
try: try:
# We wait here until we've seen their VERSION message (which they # We wait here until we've seen their VERSION message (which they
@ -197,11 +215,11 @@ class Receiver:
w.set_code(code) w.set_code(code)
else: else:
prompt = "Enter receive wormhole code: " prompt = "Enter receive wormhole code: "
used_completion = yield input_with_completion(prompt, used_completion = yield input_with_completion(
w.input_code(), prompt, w.input_code(), self._reactor)
self._reactor)
if not used_completion: if not used_completion:
print(" (note: you can use <Tab> to complete words)", print(
" (note: you can use <Tab> to complete words)",
file=self.args.stderr) file=self.args.stderr)
yield w.get_code() yield w.get_code()
@ -220,19 +238,22 @@ class Receiver:
@inlineCallbacks @inlineCallbacks
def _build_transit(self, w, sender_transit): def _build_transit(self, w, sender_transit):
tr = TransitReceiver(self.args.transit_helper, tr = TransitReceiver(
self.args.transit_helper,
no_listen=(not self.args.listen), no_listen=(not self.args.listen),
tor=self._tor, tor=self._tor,
reactor=self._reactor, reactor=self._reactor,
timing=self.args.timing) timing=self.args.timing)
self._transit_receiver = tr self._transit_receiver = tr
transit_key = w.derive_key(APPID+u"/transit-key", tr.TRANSIT_KEY_LENGTH) transit_key = w.derive_key(APPID + u"/transit-key",
tr.TRANSIT_KEY_LENGTH)
tr.set_transit_key(transit_key) tr.set_transit_key(transit_key)
tr.add_connection_hints(sender_transit.get("hints-v1", [])) tr.add_connection_hints(sender_transit.get("hints-v1", []))
receiver_abilities = tr.get_connection_abilities() receiver_abilities = tr.get_connection_abilities()
receiver_hints = yield tr.get_connection_hints() receiver_hints = yield tr.get_connection_hints()
receiver_transit = {"abilities-v1": receiver_abilities, receiver_transit = {
"abilities-v1": receiver_abilities,
"hints-v1": receiver_hints, "hints-v1": receiver_hints,
} }
self._send_data({u"transit": receiver_transit}, w) self._send_data({u"transit": receiver_transit}, w)
@ -276,12 +297,13 @@ class Receiver:
self.xfersize = file_data["filesize"] self.xfersize = file_data["filesize"]
free = estimate_free_space(self.abs_destname) free = estimate_free_space(self.abs_destname)
if free is not None and free < self.xfersize: if free is not None and free < self.xfersize:
self._msg(u"Error: insufficient free space (%sB) for file (%sB)" self._msg(u"Error: insufficient free space (%sB) for file (%sB)" %
% (free, self.xfersize)) (free, self.xfersize))
raise TransferRejectedError() raise TransferRejectedError()
self._msg(u"Receiving file (%s) into: %s" % self._msg(u"Receiving file (%s) into: %s" %
(naturalsize(self.xfersize), os.path.basename(self.abs_destname))) (naturalsize(self.xfersize),
os.path.basename(self.abs_destname)))
self._ask_permission() self._ask_permission()
tmp_destname = self.abs_destname + ".tmp" tmp_destname = self.abs_destname + ".tmp"
return open(tmp_destname, "wb") return open(tmp_destname, "wb")
@ -290,19 +312,22 @@ class Receiver:
file_data = them_d["directory"] file_data = them_d["directory"]
zipmode = file_data["mode"] zipmode = file_data["mode"]
if zipmode != "zipfile/deflated": if zipmode != "zipfile/deflated":
self._msg(u"Error: unknown directory-transfer mode '%s'" % (zipmode,)) self._msg(u"Error: unknown directory-transfer mode '%s'" %
(zipmode, ))
raise RespondError("unknown mode") raise RespondError("unknown mode")
self.abs_destname = self._decide_destname("directory", self.abs_destname = self._decide_destname("directory",
file_data["dirname"]) file_data["dirname"])
self.xfersize = file_data["zipsize"] self.xfersize = file_data["zipsize"]
free = estimate_free_space(self.abs_destname) free = estimate_free_space(self.abs_destname)
if free is not None and free < file_data["numbytes"]: if free is not None and free < file_data["numbytes"]:
self._msg(u"Error: insufficient free space (%sB) for directory (%sB)" self._msg(
% (free, file_data["numbytes"])) u"Error: insufficient free space (%sB) for directory (%sB)" %
(free, file_data["numbytes"]))
raise TransferRejectedError() raise TransferRejectedError()
self._msg(u"Receiving directory (%s) into: %s/" % self._msg(u"Receiving directory (%s) into: %s/" %
(naturalsize(self.xfersize), os.path.basename(self.abs_destname))) (naturalsize(self.xfersize),
os.path.basename(self.abs_destname)))
self._msg(u"%d files, %s (uncompressed)" % self._msg(u"%d files, %s (uncompressed)" %
(file_data["numfiles"], naturalsize(file_data["numbytes"]))) (file_data["numfiles"], naturalsize(file_data["numbytes"])))
self._ask_permission() self._ask_permission()
@ -323,13 +348,16 @@ class Receiver:
if self.args.accept_file: if self.args.accept_file:
self._remove_existing(abs_destname) self._remove_existing(abs_destname)
else: else:
self._msg(u"Error: refusing to overwrite existing '%s'" % destname) self._msg(
u"Error: refusing to overwrite existing '%s'" % destname)
raise TransferRejectedError() raise TransferRejectedError()
return abs_destname return abs_destname
def _remove_existing(self, path): def _remove_existing(self, path):
if os.path.isfile(path): os.remove(path) if os.path.isfile(path):
if os.path.isdir(path): shutil.rmtree(path) os.remove(path)
if os.path.isdir(path):
shutil.rmtree(path)
def _ask_permission(self): def _ask_permission(self):
with self.args.timing.add("permission", waiting="user") as t: with self.args.timing.add("permission", waiting="user") as t:
@ -359,14 +387,16 @@ class Receiver:
self._msg(u"Receiving (%s).." % record_pipe.describe()) self._msg(u"Receiving (%s).." % record_pipe.describe())
with self.args.timing.add("rx file"): with self.args.timing.add("rx file"):
progress = tqdm(file=self.args.stderr, progress = tqdm(
file=self.args.stderr,
disable=self.args.hide_progress, disable=self.args.hide_progress,
unit="B", unit_scale=True, total=self.xfersize) unit="B",
unit_scale=True,
total=self.xfersize)
hasher = hashlib.sha256() hasher = hashlib.sha256()
with progress: with progress:
received = yield record_pipe.writeToFile(f, self.xfersize, received = yield record_pipe.writeToFile(
progress.update, f, self.xfersize, progress.update, hasher.update)
hasher.update)
datahash = hasher.digest() datahash = hasher.digest()
# except TransitError # except TransitError
@ -382,8 +412,8 @@ class Receiver:
tmp_name = f.name tmp_name = f.name
f.close() f.close()
os.rename(tmp_name, self.abs_destname) os.rename(tmp_name, self.abs_destname)
self._msg(u"Received file written to %s" % self._msg(u"Received file written to %s" % os.path.basename(
os.path.basename(self.abs_destname)) self.abs_destname))
def _extract_file(self, zf, info, extract_dir): def _extract_file(self, zf, info, extract_dir):
""" """
@ -393,8 +423,9 @@ class Receiver:
out_path = os.path.join(extract_dir, info.filename) out_path = os.path.join(extract_dir, info.filename)
out_path = os.path.abspath(out_path) out_path = os.path.abspath(out_path)
if not out_path.startswith(extract_dir): if not out_path.startswith(extract_dir):
raise ValueError( "malicious zipfile, %s outside of extract_dir %s" raise ValueError(
% (info.filename, extract_dir) ) "malicious zipfile, %s outside of extract_dir %s" %
(info.filename, extract_dir))
zf.extract(info.filename, path=extract_dir) zf.extract(info.filename, path=extract_dir)
@ -410,8 +441,8 @@ class Receiver:
for info in zf.infolist(): for info in zf.infolist():
self._extract_file(zf, info, self.abs_destname) self._extract_file(zf, info, self.abs_destname)
self._msg(u"Received files written to %s/" % self._msg(u"Received files written to %s/" % os.path.basename(
os.path.basename(self.abs_destname)) self.abs_destname))
f.close() f.close()
@inlineCallbacks @inlineCallbacks

View File

@ -1,20 +1,29 @@
from __future__ import print_function from __future__ import print_function
import os, sys, six, tempfile, zipfile, hashlib
from tqdm import tqdm import hashlib
import os
import sys
import tempfile
import zipfile
import six
from humanize import naturalsize from humanize import naturalsize
from twisted.python import log from tqdm import tqdm
from twisted.protocols import basic
from twisted.internet import reactor from twisted.internet import reactor
from twisted.internet.defer import inlineCallbacks, returnValue from twisted.internet.defer import inlineCallbacks, returnValue
from twisted.protocols import basic
from twisted.python import log
from wormhole import __version__, create
from ..errors import TransferError, UnsendableFileError from ..errors import TransferError, UnsendableFileError
from wormhole import create, __version__
from ..transit import TransitSender from ..transit import TransitSender
from ..util import dict_to_bytes, bytes_to_dict, bytes_to_hexstr from ..util import bytes_to_dict, bytes_to_hexstr, dict_to_bytes
from .welcome import handle_welcome from .welcome import handle_welcome
APPID = u"lothar.com/wormhole/text-or-file-xfer" APPID = u"lothar.com/wormhole/text-or-file-xfer"
VERIFY_TIMER = float(os.environ.get("_MAGIC_WORMHOLE_TEST_VERIFY_TIMER", 1.0)) VERIFY_TIMER = float(os.environ.get("_MAGIC_WORMHOLE_TEST_VERIFY_TIMER", 1.0))
def send(args, reactor=reactor): def send(args, reactor=reactor):
"""I implement 'wormhole send'. I return a Deferred that fires with None """I implement 'wormhole send'. I return a Deferred that fires with None
(for success), or signals one of the following errors: (for success), or signals one of the following errors:
@ -26,6 +35,7 @@ def send(args, reactor=reactor):
""" """
return Sender(args, reactor).go() return Sender(args, reactor).go()
class Sender: class Sender:
def __init__(self, args, reactor): def __init__(self, args, reactor):
self._args = args self._args = args
@ -45,12 +55,15 @@ class Sender:
# tor in parallel with everything else, make sure the Tor object # tor in parallel with everything else, make sure the Tor object
# can lazy-provide an endpoint, and overlap the startup process # can lazy-provide an endpoint, and overlap the startup process
# with the user handing off the wormhole code # with the user handing off the wormhole code
self._tor = yield get_tor(reactor, self._tor = yield get_tor(
reactor,
self._args.launch_tor, self._args.launch_tor,
self._args.tor_control_port, self._args.tor_control_port,
timing=self._timing) timing=self._timing)
w = create(self._args.appid or APPID, self._args.relay_url, w = create(
self._args.appid or APPID,
self._args.relay_url,
self._reactor, self._reactor,
tor=self._tor, tor=self._tor,
timing=self._timing) timing=self._timing)
@ -70,7 +83,7 @@ class Sender:
def _bad(f): def _bad(f):
try: try:
yield w.close() # might be an error too yield w.close() # might be an error too
except: except Exception:
pass pass
returnValue(f) returnValue(f)
@ -125,8 +138,10 @@ class Sender:
# TODO: don't stall on w.get_verifier() unless they want it # TODO: don't stall on w.get_verifier() unless they want it
def on_slow_connection(): def on_slow_connection():
print(u"Key established, waiting for confirmation...", print(
u"Key established, waiting for confirmation...",
file=args.stderr) file=args.stderr)
notify = self._reactor.callLater(VERIFY_TIMER, on_slow_connection) notify = self._reactor.callLater(VERIFY_TIMER, on_slow_connection)
try: try:
# The usual sender-chooses-code sequence means the receiver's # The usual sender-chooses-code sequence means the receiver's
@ -143,10 +158,12 @@ class Sender:
notify.cancel() notify.cancel()
if args.verify: if args.verify:
self._check_verifier(w, verifier_bytes) # blocks, can TransferError self._check_verifier(w,
verifier_bytes) # blocks, can TransferError
if self._fd_to_send: if self._fd_to_send:
ts = TransitSender(args.transit_helper, ts = TransitSender(
args.transit_helper,
no_listen=(not args.listen), no_listen=(not args.listen),
tor=self._tor, tor=self._tor,
reactor=self._reactor, reactor=self._reactor,
@ -156,7 +173,8 @@ class Sender:
# for now, send this before the main offer # for now, send this before the main offer
sender_abilities = ts.get_connection_abilities() sender_abilities = ts.get_connection_abilities()
sender_hints = yield ts.get_connection_hints() sender_hints = yield ts.get_connection_hints()
sender_transit = {"abilities-v1": sender_abilities, sender_transit = {
"abilities-v1": sender_abilities,
"hints-v1": sender_hints, "hints-v1": sender_hints,
} }
self._send_data({u"transit": sender_transit}, w) self._send_data({u"transit": sender_transit}, w)
@ -178,8 +196,8 @@ class Sender:
# print("GOT", them_d) # print("GOT", them_d)
recognized = False recognized = False
if u"error" in them_d: if u"error" in them_d:
raise TransferError("remote error, transfer abandoned: %s" raise TransferError(
% them_d["error"]) "remote error, transfer abandoned: %s" % them_d["error"])
if u"transit" in them_d: if u"transit" in them_d:
recognized = True recognized = True
yield self._handle_transit(them_d[u"transit"]) yield self._handle_transit(them_d[u"transit"])
@ -221,7 +239,8 @@ class Sender:
text = six.moves.input("Text to send: ") text = six.moves.input("Text to send: ")
if text is not None: if text is not None:
print(u"Sending text message (%s)" % naturalsize(len(text)), print(
u"Sending text message (%s)" % naturalsize(len(text)),
file=args.stderr) file=args.stderr)
offer = {"message": text} offer = {"message": text}
fd_to_send = None fd_to_send = None
@ -273,8 +292,8 @@ class Sender:
what = os.path.realpath(what) what = os.path.realpath(what)
if not os.path.exists(what): if not os.path.exists(what):
raise TransferError("Cannot send: no file/directory named '%s'" % raise TransferError(
args.what) "Cannot send: no file/directory named '%s'" % args.what)
if os.path.isfile(what): if os.path.isfile(what):
# we're sending a file # we're sending a file
@ -283,8 +302,9 @@ class Sender:
"filename": basename, "filename": basename,
"filesize": filesize, "filesize": filesize,
} }
print(u"Sending %s file named '%s'" print(
% (naturalsize(filesize), basename), u"Sending %s file named '%s'" % (naturalsize(filesize),
basename),
file=args.stderr) file=args.stderr)
fd_to_send = open(what, "rb") fd_to_send = open(what, "rb")
return offer, fd_to_send return offer, fd_to_send
@ -297,7 +317,9 @@ class Sender:
num_files = 0 num_files = 0
num_bytes = 0 num_bytes = 0
tostrip = len(what.split(os.sep)) tostrip = len(what.split(os.sep))
with zipfile.ZipFile(fd_to_send, "w", with zipfile.ZipFile(
fd_to_send,
"w",
compression=zipfile.ZIP_DEFLATED, compression=zipfile.ZIP_DEFLATED,
allowZip64=True) as zf: allowZip64=True) as zf:
for path, dirs, files in os.walk(what): for path, dirs, files in os.walk(what):
@ -315,7 +337,8 @@ class Sender:
except OSError as e: except OSError as e:
errmsg = u"{}: {}".format(fn, e.strerror) errmsg = u"{}: {}".format(fn, e.strerror)
if self._args.ignore_unsendable_files: if self._args.ignore_unsendable_files:
print(u"{} (ignoring error)".format(errmsg), print(
u"{} (ignoring error)".format(errmsg),
file=args.stderr) file=args.stderr)
else: else:
raise UnsendableFileError(errmsg) raise UnsendableFileError(errmsg)
@ -329,8 +352,10 @@ class Sender:
"numbytes": num_bytes, "numbytes": num_bytes,
"numfiles": num_files, "numfiles": num_files,
} }
print(u"Sending directory (%s compressed) named '%s'" print(
% (naturalsize(filesize), basename), file=args.stderr) u"Sending directory (%s compressed) named '%s'" %
(naturalsize(filesize), basename),
file=args.stderr)
return offer, fd_to_send return offer, fd_to_send
raise TypeError("'%s' is neither file nor directory" % args.what) raise TypeError("'%s' is neither file nor directory" % args.what)
@ -349,7 +374,6 @@ class Sender:
yield self._send_file() yield self._send_file()
@inlineCallbacks @inlineCallbacks
def _send_file(self): def _send_file(self):
ts = self._transit_sender ts = self._transit_sender
@ -365,20 +389,27 @@ class Sender:
print(u"Sending (%s).." % record_pipe.describe(), file=stderr) print(u"Sending (%s).." % record_pipe.describe(), file=stderr)
hasher = hashlib.sha256() hasher = hashlib.sha256()
progress = tqdm(file=stderr, disable=self._args.hide_progress, progress = tqdm(
unit="B", unit_scale=True, file=stderr,
disable=self._args.hide_progress,
unit="B",
unit_scale=True,
total=filesize) total=filesize)
def _count_and_hash(data): def _count_and_hash(data):
hasher.update(data) hasher.update(data)
progress.update(len(data)) progress.update(len(data))
return data return data
fs = basic.FileSender() fs = basic.FileSender()
with self._timing.add("tx file"): with self._timing.add("tx file"):
with progress: with progress:
if filesize: if filesize:
# don't send zero-length files # don't send zero-length files
yield fs.beginFileTransfer(self._fd_to_send, record_pipe, yield fs.beginFileTransfer(
self._fd_to_send,
record_pipe,
transform=_count_and_hash) transform=_count_and_hash)
expected_hash = hasher.digest() expected_hash = hasher.digest()

View File

@ -1,16 +1,19 @@
from __future__ import print_function from __future__ import print_function
import os import os
from os.path import expanduser, exists, join from os.path import exists, expanduser, join
from twisted.internet.defer import inlineCallbacks
from twisted.internet import reactor
import click import click
from twisted.internet import reactor
from twisted.internet.defer import inlineCallbacks
from .. import xfer_util from .. import xfer_util
class PubkeyError(Exception): class PubkeyError(Exception):
pass pass
def find_public_key(hint=None): def find_public_key(hint=None):
""" """
This looks for an appropriate SSH key to send, possibly querying This looks for an appropriate SSH key to send, possibly querying
@ -34,8 +37,9 @@ def find_public_key(hint=None):
got_key = False got_key = False
while not got_key: while not got_key:
ans = click.prompt( ans = click.prompt(
"Multiple public-keys found:\n" + \ "Multiple public-keys found:\n" +
"\n".join([" {}: {}".format(a, b) for a, b in enumerate(pubkeys)]) + \ "\n".join([" {}: {}".format(a, b)
for a, b in enumerate(pubkeys)]) +
"\nSend which one?" "\nSend which one?"
) )
try: try:
@ -76,7 +80,6 @@ def accept(cfg, reactor=reactor):
@inlineCallbacks @inlineCallbacks
def invite(cfg, reactor=reactor): def invite(cfg, reactor=reactor):
def on_code_created(code): def on_code_created(code):
print("Now tell the other user to run:") print("Now tell the other user to run:")
print() print()

View File

@ -1,4 +1,3 @@
# This is a relay I run on a personal server. If it gets too expensive to # This is a relay I run on a personal server. If it gets too expensive to
# run, I'll shut it down. # run, I'll shut it down.
RENDEZVOUS_RELAY = u"ws://relay.magic-wormhole.io:4000/v1" RENDEZVOUS_RELAY = u"ws://relay.magic-wormhole.io:4000/v1"

View File

@ -1,18 +1,24 @@
from __future__ import print_function, absolute_import, unicode_literals from __future__ import absolute_import, print_function, unicode_literals
def handle_welcome(welcome, relay_url, my_version, stderr): def handle_welcome(welcome, relay_url, my_version, stderr):
if "motd" in welcome: if "motd" in welcome:
motd_lines = welcome["motd"].splitlines() motd_lines = welcome["motd"].splitlines()
motd_formatted = "\n ".join(motd_lines) motd_formatted = "\n ".join(motd_lines)
print("Server (at %s) says:\n %s" % (relay_url, motd_formatted), print(
"Server (at %s) says:\n %s" % (relay_url, motd_formatted),
file=stderr) file=stderr)
# Only warn if we're running a release version (e.g. 0.0.6, not # Only warn if we're running a release version (e.g. 0.0.6, not
# 0.0.6+DISTANCE.gHASH). Only warn once. # 0.0.6+DISTANCE.gHASH). Only warn once.
if ("current_cli_version" in welcome if (("current_cli_version" in welcome and
and "+" not in my_version "+" not in my_version and
and welcome["current_cli_version"] != my_version): welcome["current_cli_version"] != my_version)):
print("Warning: errors may occur unless both sides are running the same version", file=stderr) print(
print("Server claims %s is current, but ours is %s" ("Warning: errors may occur unless both sides are running the"
% (welcome["current_cli_version"], my_version), " same version"),
file=stderr)
print(
"Server claims %s is current, but ours is %s" %
(welcome["current_cli_version"], my_version),
file=stderr) file=stderr)

View File

@ -1,8 +1,10 @@
from __future__ import unicode_literals from __future__ import unicode_literals
class WormholeError(Exception): class WormholeError(Exception):
"""Parent class for all wormhole-related errors""" """Parent class for all wormhole-related errors"""
class UnsendableFileError(Exception): class UnsendableFileError(Exception):
""" """
A file you wanted to send couldn't be read, maybe because it's not A file you wanted to send couldn't be read, maybe because it's not
@ -13,30 +15,38 @@ class UnsendableFileError(Exception):
--ignore-unsendable-files flag. --ignore-unsendable-files flag.
""" """
class ServerError(WormholeError): class ServerError(WormholeError):
"""The relay server complained about something we did.""" """The relay server complained about something we did."""
class ServerConnectionError(WormholeError): class ServerConnectionError(WormholeError):
"""We had a problem connecting to the relay server:""" """We had a problem connecting to the relay server:"""
def __init__(self, url, reason): def __init__(self, url, reason):
self.url = url self.url = url
self.reason = reason self.reason = reason
def __str__(self): def __str__(self):
return str(self.reason) return str(self.reason)
class Timeout(WormholeError): class Timeout(WormholeError):
pass pass
class WelcomeError(WormholeError): class WelcomeError(WormholeError):
""" """
The relay server told us to signal an error, probably because our version The relay server told us to signal an error, probably because our version
is too old to possibly work. The server said:""" is too old to possibly work. The server said:"""
pass pass
class LonelyError(WormholeError): class LonelyError(WormholeError):
"""wormhole.close() was called before the peer connection could be """wormhole.close() was called before the peer connection could be
established""" established"""
class WrongPasswordError(WormholeError): class WrongPasswordError(WormholeError):
""" """
Key confirmation failed. Either you or your correspondent typed the code Key confirmation failed. Either you or your correspondent typed the code
@ -47,6 +57,7 @@ class WrongPasswordError(WormholeError):
# or the data blob was corrupted, and that's why decrypt failed # or the data blob was corrupted, and that's why decrypt failed
pass pass
class KeyFormatError(WormholeError): class KeyFormatError(WormholeError):
""" """
The key you entered contains spaces or was missing a dash. Magic-wormhole The key you entered contains spaces or was missing a dash. Magic-wormhole
@ -55,43 +66,61 @@ class KeyFormatError(WormholeError):
dashes. dashes.
""" """
class ReflectionAttack(WormholeError): class ReflectionAttack(WormholeError):
"""An attacker (or bug) reflected our outgoing message back to us.""" """An attacker (or bug) reflected our outgoing message back to us."""
class InternalError(WormholeError): class InternalError(WormholeError):
"""The programmer did something wrong.""" """The programmer did something wrong."""
class TransferError(WormholeError): class TransferError(WormholeError):
"""Something bad happened and the transfer failed.""" """Something bad happened and the transfer failed."""
class NoTorError(WormholeError): class NoTorError(WormholeError):
"""--tor was requested, but 'txtorcon' is not installed.""" """--tor was requested, but 'txtorcon' is not installed."""
class NoKeyError(WormholeError): class NoKeyError(WormholeError):
"""w.derive_key() was called before got_verifier() fired""" """w.derive_key() was called before got_verifier() fired"""
class OnlyOneCodeError(WormholeError): class OnlyOneCodeError(WormholeError):
"""Only one w.generate_code/w.set_code/w.input_code may be called""" """Only one w.generate_code/w.set_code/w.input_code may be called"""
class MustChooseNameplateFirstError(WormholeError): class MustChooseNameplateFirstError(WormholeError):
"""The InputHelper was asked to do get_word_completions() or """The InputHelper was asked to do get_word_completions() or
choose_words() before the nameplate was chosen.""" choose_words() before the nameplate was chosen."""
class AlreadyChoseNameplateError(WormholeError): class AlreadyChoseNameplateError(WormholeError):
"""The InputHelper was asked to do get_nameplate_completions() after """The InputHelper was asked to do get_nameplate_completions() after
choose_nameplate() was called, or choose_nameplate() was called a second choose_nameplate() was called, or choose_nameplate() was called a second
time.""" time."""
class AlreadyChoseWordsError(WormholeError): class AlreadyChoseWordsError(WormholeError):
"""The InputHelper was asked to do get_word_completions() after """The InputHelper was asked to do get_word_completions() after
choose_words() was called, or choose_words() was called a second time.""" choose_words() was called, or choose_words() was called a second time."""
class AlreadyInputNameplateError(WormholeError): class AlreadyInputNameplateError(WormholeError):
"""The CodeInputter was asked to do completion on a nameplate, when we """The CodeInputter was asked to do completion on a nameplate, when we
had already committed to a different one.""" had already committed to a different one."""
class WormholeClosed(Exception): class WormholeClosed(Exception):
"""Deferred-returning API calls errback with WormholeClosed if the """Deferred-returning API calls errback with WormholeClosed if the
wormhole was already closed, or if it closes before a real result can be wormhole was already closed, or if it closes before a real result can be
obtained.""" obtained."""
class _UnknownPhaseError(Exception): class _UnknownPhaseError(Exception):
"""internal exception type, for tests.""" """internal exception type, for tests."""
class _UnknownMessageTypeError(Exception): class _UnknownMessageTypeError(Exception):
"""internal exception type, for tests.""" """internal exception type, for tests."""

View File

@ -5,6 +5,7 @@ from twisted.internet.defer import Deferred
from twisted.internet.interfaces import IReactorTime from twisted.internet.interfaces import IReactorTime
from twisted.python import log from twisted.python import log
class EventualQueue(object): class EventualQueue(object):
def __init__(self, clock): def __init__(self, clock):
# pass clock=reactor unless you're testing # pass clock=reactor unless you're testing
@ -28,7 +29,7 @@ class EventualQueue(object):
(f, args, kwargs) = self._calls.pop(0) (f, args, kwargs) = self._calls.pop(0)
try: try:
f(*args, **kwargs) f(*args, **kwargs)
except: except Exception:
log.err() log.err()
self._timer = None self._timer = None
d, self._flush_d = self._flush_d, None d, self._flush_d = self._flush_d, None

View File

@ -1,20 +1,30 @@
# no unicode_literals # no unicode_literals
# Find all of our ip addresses. From tahoe's src/allmydata/util/iputil.py # Find all of our ip addresses. From tahoe's src/allmydata/util/iputil.py
import os, re, subprocess, errno import errno
import os
import re
import subprocess
from sys import platform from sys import platform
from twisted.python.procutils import which from twisted.python.procutils import which
# Wow, I'm really amazed at home much mileage we've gotten out of calling # Wow, I'm really amazed at home much mileage we've gotten out of calling
# the external route.exe program on windows... It appears to work on all # the external route.exe program on windows... It appears to work on all
# versions so far. Still, the real system calls would much be preferred... # versions so far. Still, the real system calls would much be preferred...
# ... thus wrote Greg Smith in time immemorial... # ... thus wrote Greg Smith in time immemorial...
_win32_re = re.compile(r'^\s*\d+\.\d+\.\d+\.\d+\s.+\s(?P<address>\d+\.\d+\.\d+\.\d+)\s+(?P<metric>\d+)\s*$', flags=re.M|re.I|re.S) _win32_re = re.compile(
(r'^\s*\d+\.\d+\.\d+\.\d+\s.+\s'
r'(?P<address>\d+\.\d+\.\d+\.\d+)\s+(?P<metric>\d+)\s*$'),
flags=re.M | re.I | re.S)
_win32_commands = (('route.exe', ('print', ), _win32_re), ) _win32_commands = (('route.exe', ('print', ), _win32_re), )
# These work in most Unices. # These work in most Unices.
_addr_re = re.compile(r'^\s*inet [a-zA-Z]*:?(?P<address>\d+\.\d+\.\d+\.\d+)[\s/].+$', flags=re.M|re.I|re.S) _addr_re = re.compile(
_unix_commands = (('/bin/ip', ('addr',), _addr_re), r'^\s*inet [a-zA-Z]*:?(?P<address>\d+\.\d+\.\d+\.\d+)[\s/].+$',
flags=re.M | re.I | re.S)
_unix_commands = (
('/bin/ip', ('addr', ), _addr_re),
('/sbin/ip', ('addr', ), _addr_re), ('/sbin/ip', ('addr', ), _addr_re),
('/sbin/ifconfig', ('-a', ), _addr_re), ('/sbin/ifconfig', ('-a', ), _addr_re),
('/usr/sbin/ifconfig', ('-a', ), _addr_re), ('/usr/sbin/ifconfig', ('-a', ), _addr_re),
@ -54,13 +64,15 @@ def find_addresses():
return ["127.0.0.1"] return ["127.0.0.1"]
def _query(path, args, regex): def _query(path, args, regex):
env = {'LANG': 'en_US.UTF-8'} env = {'LANG': 'en_US.UTF-8'}
trial = 0 trial = 0
while True: while True:
trial += 1 trial += 1
try: try:
p = subprocess.Popen([path] + list(args), p = subprocess.Popen(
[path] + list(args),
stdout=subprocess.PIPE, stdout=subprocess.PIPE,
stderr=subprocess.PIPE, stderr=subprocess.PIPE,
env=env, env=env,

View File

@ -1,8 +1,12 @@
from __future__ import print_function, absolute_import, unicode_literals from __future__ import absolute_import, print_function, unicode_literals
from zope.interface import implementer
import contextlib import contextlib
from zope.interface import implementer
from ._interfaces import IJournal from ._interfaces import IJournal
@implementer(IJournal) @implementer(IJournal)
class Journal(object): class Journal(object):
def __init__(self, save_checkpoint): def __init__(self, save_checkpoint):
@ -31,8 +35,10 @@ class Journal(object):
class ImmediateJournal(object): class ImmediateJournal(object):
def __init__(self): def __init__(self):
pass pass
def queue_outbound(self, fn, *args, **kwargs): def queue_outbound(self, fn, *args, **kwargs):
fn(*args, **kwargs) fn(*args, **kwargs)
@contextlib.contextmanager @contextlib.contextmanager
def process(self): def process(self):
yield yield

View File

@ -1,9 +1,11 @@
from __future__ import unicode_literals, print_function from __future__ import print_function, unicode_literals
from twisted.internet.defer import Deferred from twisted.internet.defer import Deferred
from twisted.python.failure import Failure from twisted.python.failure import Failure
NoResult = object() NoResult = object()
class OneShotObserver(object): class OneShotObserver(object):
def __init__(self, eventual_queue): def __init__(self, eventual_queue):
self._eq = eventual_queue self._eq = eventual_queue
@ -38,6 +40,7 @@ class OneShotObserver(object):
if self._result is NoResult: if self._result is NoResult:
self.fire(result) self.fire(result)
class SequenceObserver(object): class SequenceObserver(object):
def __init__(self, eventual_queue): def __init__(self, eventual_queue):
self._eq = eventual_queue self._eq = eventual_queue

View File

@ -1,16 +1,19 @@
# no unicode_literals untill twisted update # no unicode_literals untill twisted update
from twisted.application import service, internet
from twisted.internet import defer, task, reactor, endpoints
from twisted.python import log
from click.testing import CliRunner from click.testing import CliRunner
from twisted.application import internet, service
from twisted.internet import defer, endpoints, reactor, task
from twisted.python import log
import mock import mock
from ..cli import cli from wormhole_mailbox_server.database import create_channel_db, create_usage_db
from ..transit import allocate_tcp_port
from wormhole_mailbox_server.server import make_server from wormhole_mailbox_server.server import make_server
from wormhole_mailbox_server.web import make_web_server from wormhole_mailbox_server.web import make_web_server
from wormhole_mailbox_server.database import create_channel_db, create_usage_db
from wormhole_transit_relay.transit_server import Transit from wormhole_transit_relay.transit_server import Transit
from ..cli import cli
from ..transit import allocate_tcp_port
class MyInternetService(service.Service, object): class MyInternetService(service.Service, object):
# like StreamServerEndpointService, but you can retrieve the port # like StreamServerEndpointService, but you can retrieve the port
def __init__(self, endpoint, factory): def __init__(self, endpoint, factory):
@ -22,12 +25,15 @@ class MyInternetService(service.Service, object):
def startService(self): def startService(self):
super(MyInternetService, self).startService() super(MyInternetService, self).startService()
d = self.endpoint.listen(self.factory) d = self.endpoint.listen(self.factory)
def good(lp): def good(lp):
self._lp = lp self._lp = lp
self._port_d.callback(lp.getHost().port) self._port_d.callback(lp.getHost().port)
def bad(f): def bad(f):
log.err(f) log.err(f)
self._port_d.errback(f) self._port_d.errback(f)
d.addCallbacks(good, bad) d.addCallbacks(good, bad)
@defer.inlineCallbacks @defer.inlineCallbacks
@ -38,6 +44,7 @@ class MyInternetService(service.Service, object):
def getPort(self): # only call once! def getPort(self): # only call once!
return self._port_d return self._port_d
class ServerBase: class ServerBase:
@defer.inlineCallbacks @defer.inlineCallbacks
def setUp(self): def setUp(self):
@ -51,7 +58,8 @@ class ServerBase:
# endpoints.serverFromString # endpoints.serverFromString
db = create_channel_db(":memory:") db = create_channel_db(":memory:")
self._usage_db = create_usage_db(":memory:") self._usage_db = create_usage_db(":memory:")
self._rendezvous = make_server(db, self._rendezvous = make_server(
db,
advertise_version=advertise_version, advertise_version=advertise_version,
signal_error=error, signal_error=error,
usage_db=self._usage_db) usage_db=self._usage_db)
@ -67,11 +75,10 @@ class ServerBase:
# ws://127.0.0.1:%d/wormhole-relay/ws # ws://127.0.0.1:%d/wormhole-relay/ws
self.transitport = allocate_tcp_port() self.transitport = allocate_tcp_port()
ep = endpoints.serverFromString(reactor, ep = endpoints.serverFromString(
"tcp:%d:interface=127.0.0.1" % reactor, "tcp:%d:interface=127.0.0.1" % self.transitport)
self.transitport) self._transit_server = f = Transit(
self._transit_server = f = Transit(blur_usage=None, log_file=None, blur_usage=None, log_file=None, usage_db=None)
usage_db=None)
internet.StreamServerEndpointService(ep, f).setServiceParent(self.sp) internet.StreamServerEndpointService(ep, f).setServiceParent(self.sp)
self.transit = u"tcp:127.0.0.1:%d" % self.transitport self.transit = u"tcp:127.0.0.1:%d" % self.transitport
@ -109,6 +116,7 @@ class ServerBase:
" I convinced all threads to exit.") " I convinced all threads to exit.")
yield d yield d
def config(*argv): def config(*argv):
r = CliRunner() r = CliRunner()
with mock.patch("wormhole.cli.cli.go") as go: with mock.patch("wormhole.cli.cli.go") as go:
@ -121,6 +129,7 @@ def config(*argv):
cfg = go.call_args[0][1] cfg = go.call_args[0][1]
return cfg return cfg
@defer.inlineCallbacks @defer.inlineCallbacks
def poll_until(predicate): def poll_until(predicate):
# return a Deferred that won't fire until the predicate is True # return a Deferred that won't fire until the predicate is True

View File

@ -1,4 +1,5 @@
from __future__ import unicode_literals from __future__ import unicode_literals
# This is a tiny helper module, to let "python -m wormhole.test.run_trial # This is a tiny helper module, to let "python -m wormhole.test.run_trial
# ARGS" does the same thing as running "trial ARGS" (unfortunately # ARGS" does the same thing as running "trial ARGS" (unfortunately
# twisted/scripts/trial.py does not have a '__name__=="__main__"' clause). # twisted/scripts/trial.py does not have a '__name__=="__main__"' clause).

View File

@ -1,15 +1,17 @@
import os import os
import sys import sys
import mock
from twisted.trial import unittest from twisted.trial import unittest
import mock
from ..cli.public_relay import RENDEZVOUS_RELAY, TRANSIT_RELAY from ..cli.public_relay import RENDEZVOUS_RELAY, TRANSIT_RELAY
from .common import config from .common import config
#from pprint import pprint
class Send(unittest.TestCase): class Send(unittest.TestCase):
def test_baseline(self): def test_baseline(self):
cfg = config("send", "--text", "hi") cfg = config("send", "--text", "hi")
#pprint(cfg.__dict__)
self.assertEqual(cfg.what, None) self.assertEqual(cfg.what, None)
self.assertEqual(cfg.code, None) self.assertEqual(cfg.code, None)
self.assertEqual(cfg.code_length, 2) self.assertEqual(cfg.code_length, 2)
@ -32,7 +34,6 @@ class Send(unittest.TestCase):
def test_file(self): def test_file(self):
cfg = config("send", "fn") cfg = config("send", "fn")
#pprint(cfg.__dict__)
self.assertEqual(cfg.what, u"fn") self.assertEqual(cfg.what, u"fn")
self.assertEqual(cfg.text, None) self.assertEqual(cfg.text, None)
@ -101,7 +102,6 @@ class Send(unittest.TestCase):
class Receive(unittest.TestCase): class Receive(unittest.TestCase):
def test_baseline(self): def test_baseline(self):
cfg = config("receive") cfg = config("receive")
#pprint(cfg.__dict__)
self.assertEqual(cfg.accept_file, False) self.assertEqual(cfg.accept_file, False)
self.assertEqual(cfg.code, None) self.assertEqual(cfg.code, None)
self.assertEqual(cfg.code_length, 2) self.assertEqual(cfg.code_length, 2)
@ -191,10 +191,12 @@ class Receive(unittest.TestCase):
cfg = config("--transit-helper", transit_url_2, "receive") cfg = config("--transit-helper", transit_url_2, "receive")
self.assertEqual(cfg.transit_helper, transit_url_2) self.assertEqual(cfg.transit_helper, transit_url_2)
class Config(unittest.TestCase): class Config(unittest.TestCase):
def test_send(self): def test_send(self):
cfg = config("send") cfg = config("send")
self.assertEqual(cfg.stdout, sys.stdout) self.assertEqual(cfg.stdout, sys.stdout)
def test_receive(self): def test_receive(self):
cfg = config("receive") cfg = config("receive")
self.assertEqual(cfg.stdout, sys.stdout) self.assertEqual(cfg.stdout, sys.stdout)

View File

@ -1,22 +1,32 @@
from __future__ import print_function from __future__ import print_function
import os, sys, re, io, zipfile, six, stat
from textwrap import fill, dedent import io
from humanize import naturalsize import os
import mock import re
import stat
import sys
import zipfile
from textwrap import dedent, fill
import six
from click.testing import CliRunner from click.testing import CliRunner
from zope.interface import implementer from humanize import naturalsize
from twisted.trial import unittest
from twisted.python import procutils, log
from twisted.internet import endpoints, reactor from twisted.internet import endpoints, reactor
from twisted.internet.utils import getProcessOutputAndValue
from twisted.internet.defer import gatherResults, inlineCallbacks, returnValue from twisted.internet.defer import gatherResults, inlineCallbacks, returnValue
from twisted.internet.error import ConnectionRefusedError from twisted.internet.error import ConnectionRefusedError
from twisted.internet.utils import getProcessOutputAndValue
from twisted.python import log, procutils
from twisted.trial import unittest
from zope.interface import implementer
import mock
from .. import __version__ from .. import __version__
from .common import ServerBase, config
from ..cli import cmd_send, cmd_receive, welcome, cli
from ..errors import (TransferError, WrongPasswordError, WelcomeError,
UnsendableFileError, ServerConnectionError)
from .._interfaces import ITorManager from .._interfaces import ITorManager
from ..cli import cli, cmd_receive, cmd_send, welcome
from ..errors import (ServerConnectionError, TransferError,
UnsendableFileError, WelcomeError, WrongPasswordError)
from .common import ServerBase, config
def build_offer(args): def build_offer(args):
@ -108,8 +118,8 @@ class OfferData(unittest.TestCase):
self.cfg.cwd = send_dir self.cfg.cwd = send_dir
e = self.assertRaises(TransferError, build_offer, self.cfg) e = self.assertRaises(TransferError, build_offer, self.cfg)
self.assertEqual(str(e), self.assertEqual(
"Cannot send: no file/directory named '%s'" % filename) str(e), "Cannot send: no file/directory named '%s'" % filename)
def _do_test_directory(self, addslash): def _do_test_directory(self, addslash):
parent_dir = self.mktemp() parent_dir = self.mktemp()
@ -178,8 +188,8 @@ class OfferData(unittest.TestCase):
self.assertFalse(os.path.isdir(abs_filename)) self.assertFalse(os.path.isdir(abs_filename))
e = self.assertRaises(TypeError, build_offer, self.cfg) e = self.assertRaises(TypeError, build_offer, self.cfg)
self.assertEqual(str(e), self.assertEqual(
"'%s' is neither file nor directory" % filename) str(e), "'%s' is neither file nor directory" % filename)
def test_symlink(self): def test_symlink(self):
if not hasattr(os, 'symlink'): if not hasattr(os, 'symlink'):
@ -213,7 +223,8 @@ class OfferData(unittest.TestCase):
os.mkdir(os.path.join(parent_dir, "B2", "C2")) os.mkdir(os.path.join(parent_dir, "B2", "C2"))
with open(os.path.join(parent_dir, "B2", "D.txt"), "wb") as f: with open(os.path.join(parent_dir, "B2", "D.txt"), "wb") as f:
f.write(b"success") f.write(b"success")
os.symlink(os.path.abspath(os.path.join(parent_dir, "B2", "C2")), os.symlink(
os.path.abspath(os.path.join(parent_dir, "B2", "C2")),
os.path.join(parent_dir, "B1", "C1")) os.path.join(parent_dir, "B1", "C1"))
# Now send "B1/C1/../D.txt" from A. The correct traversal will be: # Now send "B1/C1/../D.txt" from A. The correct traversal will be:
# * start: A # * start: A
@ -231,6 +242,7 @@ class OfferData(unittest.TestCase):
d, fd_to_send = build_offer(self.cfg) d, fd_to_send = build_offer(self.cfg)
self.assertEqual(d["file"]["filename"], "D.txt") self.assertEqual(d["file"]["filename"], "D.txt")
self.assertEqual(fd_to_send.read(), b"success") self.assertEqual(fd_to_send.read(), b"success")
if os.name == "nt": if os.name == "nt":
test_symlink_collapse.todo = "host OS has broken os.path.realpath()" test_symlink_collapse.todo = "host OS has broken os.path.realpath()"
# ntpath.py's realpath() is built out of normpath(), and does not # ntpath.py's realpath() is built out of normpath(), and does not
@ -241,6 +253,7 @@ class OfferData(unittest.TestCase):
# misbehavior (albeit in rare circumstances), 2: it probably used to # misbehavior (albeit in rare circumstances), 2: it probably used to
# work (sometimes, but not in #251). See cmd_send.py for more notes. # work (sometimes, but not in #251). See cmd_send.py for more notes.
class LocaleFinder: class LocaleFinder:
def __init__(self): def __init__(self):
self._run_once = False self._run_once = False
@ -281,8 +294,11 @@ class LocaleFinder:
if utf8_locales: if utf8_locales:
returnValue(list(utf8_locales.values())[0]) returnValue(list(utf8_locales.values())[0])
returnValue(None) returnValue(None)
locale_finder = LocaleFinder() locale_finder = LocaleFinder()
class ScriptsBase: class ScriptsBase:
def find_executable(self): def find_executable(self):
# to make sure we're running the right executable (in a virtualenv), # to make sure we're running the right executable (in a virtualenv),
@ -292,12 +308,13 @@ class ScriptsBase:
if not locations: if not locations:
raise unittest.SkipTest("unable to find 'wormhole' in $PATH") raise unittest.SkipTest("unable to find 'wormhole' in $PATH")
wormhole = locations[0] wormhole = locations[0]
if (os.path.dirname(os.path.abspath(wormhole)) != if (os.path.dirname(os.path.abspath(wormhole)) != os.path.dirname(
os.path.dirname(sys.executable)): sys.executable)):
log.msg("locations: %s" % (locations, )) log.msg("locations: %s" % (locations, ))
log.msg("sys.executable: %s" % (sys.executable, )) log.msg("sys.executable: %s" % (sys.executable, ))
raise unittest.SkipTest("found the wrong 'wormhole' in $PATH: %s %s" raise unittest.SkipTest(
% (wormhole, sys.executable)) "found the wrong 'wormhole' in $PATH: %s %s" %
(wormhole, sys.executable))
return wormhole return wormhole
@inlineCallbacks @inlineCallbacks
@ -323,8 +340,8 @@ class ScriptsBase:
raise unittest.SkipTest("unable to find UTF-8 locale") raise unittest.SkipTest("unable to find UTF-8 locale")
locale_env = dict(LC_ALL=locale, LANG=locale) locale_env = dict(LC_ALL=locale, LANG=locale)
wormhole = self.find_executable() wormhole = self.find_executable()
res = yield getProcessOutputAndValue(wormhole, ["--version"], res = yield getProcessOutputAndValue(
env=locale_env) wormhole, ["--version"], env=locale_env)
out, err, rc = res out, err, rc = res
if rc != 0: if rc != 0:
log.msg("wormhole not runnable in this tree:") log.msg("wormhole not runnable in this tree:")
@ -334,6 +351,7 @@ class ScriptsBase:
raise unittest.SkipTest("wormhole is not runnable in this tree") raise unittest.SkipTest("wormhole is not runnable in this tree")
returnValue(locale_env) returnValue(locale_env)
class ScriptVersion(ServerBase, ScriptsBase, unittest.TestCase): class ScriptVersion(ServerBase, ScriptsBase, unittest.TestCase):
# we need Twisted to run the server, but we run the sender and receiver # we need Twisted to run the server, but we run the sender and receiver
# with deferToThread() # with deferToThread()
@ -347,26 +365,30 @@ class ScriptVersion(ServerBase, ScriptsBase, unittest.TestCase):
wormhole = self.find_executable() wormhole = self.find_executable()
# we must pass on the environment so that "something" doesn't # we must pass on the environment so that "something" doesn't
# get sad about UTF8 vs. ascii encodings # get sad about UTF8 vs. ascii encodings
out, err, rc = yield getProcessOutputAndValue(wormhole, ["--version"], out, err, rc = yield getProcessOutputAndValue(
env=os.environ) wormhole, ["--version"], env=os.environ)
err = err.decode("utf-8") err = err.decode("utf-8")
if "DistributionNotFound" in err: if "DistributionNotFound" in err:
log.msg("stderr was %s" % err) log.msg("stderr was %s" % err)
last = err.strip().split("\n")[-1] last = err.strip().split("\n")[-1]
self.fail("wormhole not runnable: %s" % last) self.fail("wormhole not runnable: %s" % last)
ver = out.decode("utf-8") or err ver = out.decode("utf-8") or err
self.failUnlessEqual(ver.strip(), "magic-wormhole {}".format(__version__)) self.failUnlessEqual(ver.strip(),
"magic-wormhole {}".format(__version__))
self.failUnlessEqual(rc, 0) self.failUnlessEqual(rc, 0)
@implementer(ITorManager) @implementer(ITorManager)
class FakeTor: class FakeTor:
# use normal endpoints, but record the fact that we were asked # use normal endpoints, but record the fact that we were asked
def __init__(self): def __init__(self):
self.endpoints = [] self.endpoints = []
def stream_via(self, host, port): def stream_via(self, host, port):
self.endpoints.append((host, port)) self.endpoints.append((host, port))
return endpoints.HostnameEndpoint(reactor, host, port) return endpoints.HostnameEndpoint(reactor, host, port)
class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase): class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase):
# we need Twisted to run the server, but we run the sender and receiver # we need Twisted to run the server, but we run the sender and receiver
# with deferToThread() # with deferToThread()
@ -377,11 +399,16 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase):
yield ServerBase.setUp(self) yield ServerBase.setUp(self)
@inlineCallbacks @inlineCallbacks
def _do_test(self, as_subprocess=False, def _do_test(self,
mode="text", addslash=False, override_filename=False, as_subprocess=False,
fake_tor=False, overwrite=False, mock_accept=False): mode="text",
assert mode in ("text", "file", "empty-file", "directory", addslash=False,
"slow-text", "slow-sender-text") override_filename=False,
fake_tor=False,
overwrite=False,
mock_accept=False):
assert mode in ("text", "file", "empty-file", "directory", "slow-text",
"slow-sender-text")
if fake_tor: if fake_tor:
assert not as_subprocess assert not as_subprocess
send_cfg = config("send") send_cfg = config("send")
@ -433,8 +460,10 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase):
# expect: $receive_dir/$dirname/[12345] # expect: $receive_dir/$dirname/[12345]
send_dirname = u"testdir" send_dirname = u"testdir"
def message(i): def message(i):
return "test message %d\n" % i return "test message %d\n" % i
os.mkdir(os.path.join(send_dir, u"middle")) os.mkdir(os.path.join(send_dir, u"middle"))
source_dir = os.path.join(send_dir, u"middle", send_dirname) source_dir = os.path.join(send_dir, u"middle", send_dirname)
os.mkdir(source_dir) os.mkdir(source_dir)
@ -476,21 +505,27 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase):
env["_MAGIC_WORMHOLE_TEST_KEY_TIMER"] = "999999" env["_MAGIC_WORMHOLE_TEST_KEY_TIMER"] = "999999"
env["_MAGIC_WORMHOLE_TEST_VERIFY_TIMER"] = "999999" env["_MAGIC_WORMHOLE_TEST_VERIFY_TIMER"] = "999999"
send_args = [ send_args = [
'--relay-url', self.relayurl, '--relay-url',
'--transit-helper', '', self.relayurl,
'--transit-helper',
'',
'send', 'send',
'--hide-progress', '--hide-progress',
'--code', send_cfg.code, '--code',
send_cfg.code,
] + content_args ] + content_args
send_d = getProcessOutputAndValue( send_d = getProcessOutputAndValue(
wormhole_bin, send_args, wormhole_bin,
send_args,
path=send_dir, path=send_dir,
env=env, env=env,
) )
recv_args = [ recv_args = [
'--relay-url', self.relayurl, '--relay-url',
'--transit-helper', '', self.relayurl,
'--transit-helper',
'',
'receive', 'receive',
'--hide-progress', '--hide-progress',
'--accept-file', '--accept-file',
@ -500,7 +535,8 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase):
recv_args.extend(['-o', receive_filename]) recv_args.extend(['-o', receive_filename])
receive_d = getProcessOutputAndValue( receive_d = getProcessOutputAndValue(
wormhole_bin, recv_args, wormhole_bin,
recv_args,
path=receive_dir, path=receive_dir,
env=env, env=env,
) )
@ -524,7 +560,8 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase):
send_cfg.tor = True send_cfg.tor = True
send_cfg.transit_helper = self.transit send_cfg.transit_helper = self.transit
tx_tm = FakeTor() tx_tm = FakeTor()
with mock.patch("wormhole.tor_manager.get_tor", with mock.patch(
"wormhole.tor_manager.get_tor",
return_value=tx_tm, return_value=tx_tm,
) as mtx_tm: ) as mtx_tm:
send_d = cmd_send.send(send_cfg) send_d = cmd_send.send(send_cfg)
@ -532,7 +569,8 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase):
recv_cfg.tor = True recv_cfg.tor = True
recv_cfg.transit_helper = self.transit recv_cfg.transit_helper = self.transit
rx_tm = FakeTor() rx_tm = FakeTor()
with mock.patch("wormhole.tor_manager.get_tor", with mock.patch(
"wormhole.tor_manager.get_tor",
return_value=rx_tm, return_value=rx_tm,
) as mrx_tm: ) as mrx_tm:
receive_d = cmd_receive.receive(recv_cfg) receive_d = cmd_receive.receive(recv_cfg)
@ -541,8 +579,8 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase):
rxw = [] rxw = []
with mock.patch.object(cmd_receive, "KEY_TIMER", KEY_TIMER): with mock.patch.object(cmd_receive, "KEY_TIMER", KEY_TIMER):
send_d = cmd_send.send(send_cfg) send_d = cmd_send.send(send_cfg)
receive_d = cmd_receive.receive(recv_cfg, receive_d = cmd_receive.receive(
_debug_stash_wormhole=rxw) recv_cfg, _debug_stash_wormhole=rxw)
# we need to keep KEY_TIMER patched until the receiver # we need to keep KEY_TIMER patched until the receiver
# gets far enough to start the timer, which happens after # gets far enough to start the timer, which happens after
# the code is set # the code is set
@ -555,8 +593,9 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase):
with mock.patch.object(cmd_receive, "VERIFY_TIMER", VERIFY_TIMER): with mock.patch.object(cmd_receive, "VERIFY_TIMER", VERIFY_TIMER):
with mock.patch.object(cmd_send, "VERIFY_TIMER", VERIFY_TIMER): with mock.patch.object(cmd_send, "VERIFY_TIMER", VERIFY_TIMER):
if mock_accept: if mock_accept:
with mock.patch.object(cmd_receive.six.moves, with mock.patch.object(
'input', return_value='y'): cmd_receive.six.moves, 'input',
return_value='y'):
yield gatherResults([send_d, receive_d], True) yield gatherResults([send_d, receive_d], True)
else: else:
yield gatherResults([send_d, receive_d], True) yield gatherResults([send_d, receive_d], True)
@ -567,14 +606,14 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase):
expected_endpoints.append(("127.0.0.1", self.transitport)) expected_endpoints.append(("127.0.0.1", self.transitport))
tx_timing = mtx_tm.call_args[1]["timing"] tx_timing = mtx_tm.call_args[1]["timing"]
self.assertEqual(tx_tm.endpoints, expected_endpoints) self.assertEqual(tx_tm.endpoints, expected_endpoints)
self.assertEqual(mtx_tm.mock_calls, self.assertEqual(
[mock.call(reactor, False, None, mtx_tm.mock_calls,
timing=tx_timing)]) [mock.call(reactor, False, None, timing=tx_timing)])
rx_timing = mrx_tm.call_args[1]["timing"] rx_timing = mrx_tm.call_args[1]["timing"]
self.assertEqual(rx_tm.endpoints, expected_endpoints) self.assertEqual(rx_tm.endpoints, expected_endpoints)
self.assertEqual(mrx_tm.mock_calls, self.assertEqual(
[mock.call(reactor, False, None, mrx_tm.mock_calls,
timing=rx_timing)]) [mock.call(reactor, False, None, timing=rx_timing)])
send_stdout = send_cfg.stdout.getvalue() send_stdout = send_cfg.stdout.getvalue()
send_stderr = send_cfg.stderr.getvalue() send_stderr = send_cfg.stderr.getvalue()
@ -600,34 +639,37 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase):
"On the other computer, please run:{NL}{NL}" "On the other computer, please run:{NL}{NL}"
"wormhole receive {code}{NL}{NL}" "wormhole receive {code}{NL}{NL}"
"{KE}" "{KE}"
"text message sent{NL}").format(bytes=len(message), "text message sent{NL}").format(
bytes=len(message),
code=send_cfg.code, code=send_cfg.code,
NL=NL, NL=NL,
KE=key_established) KE=key_established)
self.failUnlessEqual(send_stderr, expected) self.failUnlessEqual(send_stderr, expected)
elif mode == "file": elif mode == "file":
self.failUnlessIn(u"Sending {size:s} file named '{name}'{NL}" self.failUnlessIn(u"Sending {size:s} file named '{name}'{NL}"
.format(size=naturalsize(len(message)), .format(
size=naturalsize(len(message)),
name=send_filename, name=send_filename,
NL=NL), send_stderr) NL=NL), send_stderr)
self.failUnlessIn(u"Wormhole code is: {code}{NL}" self.failUnlessIn(u"Wormhole code is: {code}{NL}"
"On the other computer, please run:{NL}{NL}" "On the other computer, please run:{NL}{NL}"
"wormhole receive {code}{NL}{NL}" "wormhole receive {code}{NL}{NL}".format(
.format(code=send_cfg.code, NL=NL), code=send_cfg.code, NL=NL), send_stderr)
self.failUnlessIn(
u"File sent.. waiting for confirmation{NL}"
"Confirmation received. Transfer complete.{NL}".format(NL=NL),
send_stderr) send_stderr)
self.failUnlessIn(u"File sent.. waiting for confirmation{NL}"
"Confirmation received. Transfer complete.{NL}"
.format(NL=NL), send_stderr)
elif mode == "directory": elif mode == "directory":
self.failUnlessIn(u"Sending directory", send_stderr) self.failUnlessIn(u"Sending directory", send_stderr)
self.failUnlessIn(u"named 'testdir'", send_stderr) self.failUnlessIn(u"named 'testdir'", send_stderr)
self.failUnlessIn(u"Wormhole code is: {code}{NL}" self.failUnlessIn(u"Wormhole code is: {code}{NL}"
"On the other computer, please run:{NL}{NL}" "On the other computer, please run:{NL}{NL}"
"wormhole receive {code}{NL}{NL}" "wormhole receive {code}{NL}{NL}".format(
.format(code=send_cfg.code, NL=NL), send_stderr) code=send_cfg.code, NL=NL), send_stderr)
self.failUnlessIn(u"File sent.. waiting for confirmation{NL}" self.failUnlessIn(
"Confirmation received. Transfer complete.{NL}" u"File sent.. waiting for confirmation{NL}"
.format(NL=NL), send_stderr) "Confirmation received. Transfer complete.{NL}".format(NL=NL),
send_stderr)
# check receiver # check receiver
if mode in ("text", "slow-text", "slow-sender-text"): if mode in ("text", "slow-text", "slow-sender-text"):
@ -640,9 +682,9 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase):
self.assertEqual(receive_stderr, "Waiting for sender...\n") self.assertEqual(receive_stderr, "Waiting for sender...\n")
elif mode == "file": elif mode == "file":
self.failUnlessEqual(receive_stdout, "") self.failUnlessEqual(receive_stdout, "")
self.failUnlessIn(u"Receiving file ({size:s}) into: {name}" self.failUnlessIn(u"Receiving file ({size:s}) into: {name}".format(
.format(size=naturalsize(len(message)), size=naturalsize(len(message)), name=receive_filename),
name=receive_filename), receive_stderr) receive_stderr)
self.failUnlessIn(u"Received file written to ", receive_stderr) self.failUnlessIn(u"Received file written to ", receive_stderr)
fn = os.path.join(receive_dir, receive_filename) fn = os.path.join(receive_dir, receive_filename)
self.failUnless(os.path.exists(fn)) self.failUnless(os.path.exists(fn))
@ -652,52 +694,67 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase):
self.failUnlessEqual(receive_stdout, "") self.failUnlessEqual(receive_stdout, "")
want = (r"Receiving directory \(\d+ \w+\) into: {name}/" want = (r"Receiving directory \(\d+ \w+\) into: {name}/"
.format(name=receive_dirname)) .format(name=receive_dirname))
self.failUnless(re.search(want, receive_stderr), self.failUnless(
(want, receive_stderr)) re.search(want, receive_stderr), (want, receive_stderr))
self.failUnlessIn(u"Received files written to {name}" self.failUnlessIn(
.format(name=receive_dirname), receive_stderr) u"Received files written to {name}"
.format(name=receive_dirname),
receive_stderr)
fn = os.path.join(receive_dir, receive_dirname) fn = os.path.join(receive_dir, receive_dirname)
self.failUnless(os.path.exists(fn), fn) self.failUnless(os.path.exists(fn), fn)
for i in range(5): for i in range(5):
fn = os.path.join(receive_dir, receive_dirname, str(i)) fn = os.path.join(receive_dir, receive_dirname, str(i))
with open(fn, "r") as f: with open(fn, "r") as f:
self.failUnlessEqual(f.read(), message(i)) self.failUnlessEqual(f.read(), message(i))
self.failUnlessEqual(modes[i], self.failUnlessEqual(modes[i], stat.S_IMODE(
stat.S_IMODE(os.stat(fn).st_mode)) os.stat(fn).st_mode))
def test_text(self): def test_text(self):
return self._do_test() return self._do_test()
def test_text_subprocess(self): def test_text_subprocess(self):
return self._do_test(as_subprocess=True) return self._do_test(as_subprocess=True)
def test_text_tor(self): def test_text_tor(self):
return self._do_test(fake_tor=True) return self._do_test(fake_tor=True)
def test_file(self): def test_file(self):
return self._do_test(mode="file") return self._do_test(mode="file")
def test_file_override(self): def test_file_override(self):
return self._do_test(mode="file", override_filename=True) return self._do_test(mode="file", override_filename=True)
def test_file_overwrite(self): def test_file_overwrite(self):
return self._do_test(mode="file", overwrite=True) return self._do_test(mode="file", overwrite=True)
def test_file_overwrite_mock_accept(self): def test_file_overwrite_mock_accept(self):
return self._do_test(mode="file", overwrite=True, mock_accept=True) return self._do_test(mode="file", overwrite=True, mock_accept=True)
def test_file_tor(self): def test_file_tor(self):
return self._do_test(mode="file", fake_tor=True) return self._do_test(mode="file", fake_tor=True)
def test_empty_file(self): def test_empty_file(self):
return self._do_test(mode="empty-file") return self._do_test(mode="empty-file")
def test_directory(self): def test_directory(self):
return self._do_test(mode="directory") return self._do_test(mode="directory")
def test_directory_addslash(self): def test_directory_addslash(self):
return self._do_test(mode="directory", addslash=True) return self._do_test(mode="directory", addslash=True)
def test_directory_override(self): def test_directory_override(self):
return self._do_test(mode="directory", override_filename=True) return self._do_test(mode="directory", override_filename=True)
def test_directory_overwrite(self): def test_directory_overwrite(self):
return self._do_test(mode="directory", overwrite=True) return self._do_test(mode="directory", overwrite=True)
def test_directory_overwrite_mock_accept(self): def test_directory_overwrite_mock_accept(self):
return self._do_test(mode="directory", overwrite=True, mock_accept=True) return self._do_test(
mode="directory", overwrite=True, mock_accept=True)
def test_slow_text(self): def test_slow_text(self):
return self._do_test(mode="slow-text") return self._do_test(mode="slow-text")
def test_slow_sender_text(self): def test_slow_sender_text(self):
return self._do_test(mode="slow-sender-text") return self._do_test(mode="slow-sender-text")
@ -765,10 +822,12 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase):
free_space = 10000000 free_space = 10000000
else: else:
free_space = 0 free_space = 0
with mock.patch("wormhole.cli.cmd_receive.estimate_free_space", with mock.patch(
"wormhole.cli.cmd_receive.estimate_free_space",
return_value=free_space): return_value=free_space):
f = yield self.assertFailure(send_d, TransferError) f = yield self.assertFailure(send_d, TransferError)
self.assertEqual(str(f), "remote error, transfer abandoned: transfer rejected") self.assertEqual(
str(f), "remote error, transfer abandoned: transfer rejected")
f = yield self.assertFailure(receive_d, TransferError) f = yield self.assertFailure(receive_d, TransferError)
self.assertEqual(str(f), "transfer rejected") self.assertEqual(str(f), "transfer rejected")
@ -789,54 +848,63 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase):
# check sender # check sender
if mode == "file": if mode == "file":
self.failUnlessIn("Sending {size:s} file named '{name}'{NL}" self.failUnlessIn("Sending {size:s} file named '{name}'{NL}"
.format(size=naturalsize(size), .format(
size=naturalsize(size),
name=send_filename, name=send_filename,
NL=NL), send_stderr) NL=NL), send_stderr)
self.failUnlessIn("Wormhole code is: {code}{NL}" self.failUnlessIn("Wormhole code is: {code}{NL}"
"On the other computer, please run:{NL}{NL}" "On the other computer, please run:{NL}{NL}"
"wormhole receive {code}{NL}" "wormhole receive {code}{NL}".format(
.format(code=send_cfg.code, NL=NL), code=send_cfg.code, NL=NL), send_stderr)
self.failIfIn(
"File sent.. waiting for confirmation{NL}"
"Confirmation received. Transfer complete.{NL}".format(NL=NL),
send_stderr) send_stderr)
self.failIfIn("File sent.. waiting for confirmation{NL}"
"Confirmation received. Transfer complete.{NL}"
.format(NL=NL), send_stderr)
elif mode == "directory": elif mode == "directory":
self.failUnlessIn("Sending directory", send_stderr) self.failUnlessIn("Sending directory", send_stderr)
self.failUnlessIn("named 'testdir'", send_stderr) self.failUnlessIn("named 'testdir'", send_stderr)
self.failUnlessIn("Wormhole code is: {code}{NL}" self.failUnlessIn("Wormhole code is: {code}{NL}"
"On the other computer, please run:{NL}{NL}" "On the other computer, please run:{NL}{NL}"
"wormhole receive {code}{NL}" "wormhole receive {code}{NL}".format(
.format(code=send_cfg.code, NL=NL), send_stderr) code=send_cfg.code, NL=NL), send_stderr)
self.failIfIn("File sent.. waiting for confirmation{NL}" self.failIfIn(
"Confirmation received. Transfer complete.{NL}" "File sent.. waiting for confirmation{NL}"
.format(NL=NL), send_stderr) "Confirmation received. Transfer complete.{NL}".format(NL=NL),
send_stderr)
# check receiver # check receiver
if mode == "file": if mode == "file":
self.failIfIn("Received file written to ", receive_stderr) self.failIfIn("Received file written to ", receive_stderr)
if failmode == "noclobber": if failmode == "noclobber":
self.failUnlessIn("Error: " self.failUnlessIn(
"Error: "
"refusing to overwrite existing 'testfile'{NL}" "refusing to overwrite existing 'testfile'{NL}"
.format(NL=NL), receive_stderr) .format(NL=NL),
receive_stderr)
else: else:
self.failUnlessIn("Error: " self.failUnlessIn(
"Error: "
"insufficient free space (0B) for file ({size:d}B){NL}" "insufficient free space (0B) for file ({size:d}B){NL}"
.format(NL=NL, size=size), receive_stderr) .format(NL=NL, size=size), receive_stderr)
elif mode == "directory": elif mode == "directory":
self.failIfIn("Received files written to {name}" self.failIfIn(
.format(name=receive_name), receive_stderr) "Received files written to {name}".format(name=receive_name),
receive_stderr)
# want = (r"Receiving directory \(\d+ \w+\) into: {name}/" # want = (r"Receiving directory \(\d+ \w+\) into: {name}/"
# .format(name=receive_name)) # .format(name=receive_name))
# self.failUnless(re.search(want, receive_stderr), # self.failUnless(re.search(want, receive_stderr),
# (want, receive_stderr)) # (want, receive_stderr))
if failmode == "noclobber": if failmode == "noclobber":
self.failUnlessIn("Error: " self.failUnlessIn(
"Error: "
"refusing to overwrite existing 'testdir'{NL}" "refusing to overwrite existing 'testdir'{NL}"
.format(NL=NL), receive_stderr) .format(NL=NL),
receive_stderr)
else: else:
self.failUnlessIn("Error: " self.failUnlessIn(("Error: "
"insufficient free space (0B) for directory ({size:d}B){NL}" "insufficient free space (0B) for directory"
.format(NL=NL, size=size), receive_stderr) " ({size:d}B){NL}").format(
NL=NL, size=size), receive_stderr)
if failmode == "noclobber": if failmode == "noclobber":
fn = os.path.join(receive_dir, receive_name) fn = os.path.join(receive_dir, receive_name)
@ -846,13 +914,17 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase):
def test_fail_file_noclobber(self): def test_fail_file_noclobber(self):
return self._do_test_fail("file", "noclobber") return self._do_test_fail("file", "noclobber")
def test_fail_directory_noclobber(self): def test_fail_directory_noclobber(self):
return self._do_test_fail("directory", "noclobber") return self._do_test_fail("directory", "noclobber")
def test_fail_file_toobig(self): def test_fail_file_toobig(self):
return self._do_test_fail("file", "toobig") return self._do_test_fail("file", "toobig")
def test_fail_directory_toobig(self): def test_fail_directory_toobig(self):
return self._do_test_fail("directory", "toobig") return self._do_test_fail("directory", "toobig")
class ZeroMode(ServerBase, unittest.TestCase): class ZeroMode(ServerBase, unittest.TestCase):
@inlineCallbacks @inlineCallbacks
def test_text(self): def test_text(self):
@ -898,15 +970,15 @@ class ZeroMode(ServerBase, unittest.TestCase):
"{NL}" "{NL}"
"wormhole receive -0{NL}" "wormhole receive -0{NL}"
"{NL}" "{NL}"
"text message sent{NL}").format(bytes=len(message), "text message sent{NL}").format(
code=send_cfg.code, bytes=len(message), code=send_cfg.code, NL=NL)
NL=NL)
self.failUnlessEqual(send_stderr, expected) self.failUnlessEqual(send_stderr, expected)
# check receiver # check receiver
self.assertEqual(receive_stdout, message + NL) self.assertEqual(receive_stdout, message + NL)
self.assertEqual(receive_stderr, "") self.assertEqual(receive_stderr, "")
class NotWelcome(ServerBase, unittest.TestCase): class NotWelcome(ServerBase, unittest.TestCase):
@inlineCallbacks @inlineCallbacks
def setUp(self): def setUp(self):
@ -936,6 +1008,7 @@ class NotWelcome(ServerBase, unittest.TestCase):
f = yield self.assertFailure(receive_d, WelcomeError) f = yield self.assertFailure(receive_d, WelcomeError)
self.assertEqual(str(f), "please upgrade XYZ") self.assertEqual(str(f), "please upgrade XYZ")
class NoServer(ServerBase, unittest.TestCase): class NoServer(ServerBase, unittest.TestCase):
@inlineCallbacks @inlineCallbacks
def setUp(self): def setUp(self):
@ -991,8 +1064,8 @@ class NoServer(ServerBase, unittest.TestCase):
e = yield self.assertFailure(receive_d, ServerConnectionError) e = yield self.assertFailure(receive_d, ServerConnectionError)
self.assertIsInstance(e.reason, ConnectionRefusedError) self.assertIsInstance(e.reason, ConnectionRefusedError)
class Cleanup(ServerBase, unittest.TestCase):
class Cleanup(ServerBase, unittest.TestCase):
def make_config(self): def make_config(self):
cfg = config("send") cfg = config("send")
# common options for all tests in this suite # common options for all tests in this suite
@ -1040,6 +1113,7 @@ class Cleanup(ServerBase, unittest.TestCase):
cids = self._rendezvous.get_app(cmd_send.APPID).get_nameplate_ids() cids = self._rendezvous.get_app(cmd_send.APPID).get_nameplate_ids()
self.assertEqual(len(cids), 0) self.assertEqual(len(cids), 0)
class ExtractFile(unittest.TestCase): class ExtractFile(unittest.TestCase):
def test_filenames(self): def test_filenames(self):
args = mock.Mock() args = mock.Mock()
@ -1082,6 +1156,7 @@ class ExtractFile(unittest.TestCase):
e = self.assertRaises(ValueError, ef, zf, zi, extract_dir) e = self.assertRaises(ValueError, ef, zf, zi, extract_dir)
self.assertIn("malicious zipfile", str(e)) self.assertIn("malicious zipfile", str(e))
class AppID(ServerBase, unittest.TestCase): class AppID(ServerBase, unittest.TestCase):
@inlineCallbacks @inlineCallbacks
def setUp(self): def setUp(self):
@ -1108,11 +1183,11 @@ class AppID(ServerBase, unittest.TestCase):
yield receive_d yield receive_d
used = self._usage_db.execute("SELECT DISTINCT `app_id`" used = self._usage_db.execute("SELECT DISTINCT `app_id`"
" FROM `nameplates`" " FROM `nameplates`").fetchall()
).fetchall()
self.assertEqual(len(used), 1, used) self.assertEqual(len(used), 1, used)
self.assertEqual(used[0]["app_id"], u"appid2") self.assertEqual(used[0]["app_id"], u"appid2")
class Welcome(unittest.TestCase): class Welcome(unittest.TestCase):
def do(self, welcome_message, my_version="2.0"): def do(self, welcome_message, my_version="2.0"):
stderr = io.StringIO() stderr = io.StringIO()
@ -1129,27 +1204,33 @@ class Welcome(unittest.TestCase):
def test_version_old(self): def test_version_old(self):
stderr = self.do({"current_cli_version": "3.0"}) stderr = self.do({"current_cli_version": "3.0"})
expected = ("Warning: errors may occur unless both sides are running the same version\n" + expected = ("Warning: errors may occur unless both sides are"
" running the same version\n"
"Server claims 3.0 is current, but ours is 2.0\n") "Server claims 3.0 is current, but ours is 2.0\n")
self.assertEqual(stderr, expected) self.assertEqual(stderr, expected)
def test_version_unreleased(self): def test_version_unreleased(self):
stderr = self.do({"current_cli_version": "3.0"}, stderr = self.do(
my_version="2.5+middle.something") {
"current_cli_version": "3.0"
}, my_version="2.5+middle.something")
self.assertEqual(stderr, "") self.assertEqual(stderr, "")
def test_motd(self): def test_motd(self):
stderr = self.do({"motd": "hello"}) stderr = self.do({"motd": "hello"})
self.assertEqual(stderr, "Server (at url) says:\n hello\n") self.assertEqual(stderr, "Server (at url) says:\n hello\n")
class Dispatch(unittest.TestCase): class Dispatch(unittest.TestCase):
@inlineCallbacks @inlineCallbacks
def test_success(self): def test_success(self):
cfg = config("send") cfg = config("send")
cfg.stderr = io.StringIO() cfg.stderr = io.StringIO()
called = [] called = []
def fake(): def fake():
called.append(1) called.append(1)
yield cli._dispatch_command(reactor, cfg, fake) yield cli._dispatch_command(reactor, cfg, fake)
self.assertEqual(called, [1]) self.assertEqual(called, [1])
self.assertEqual(cfg.stderr.getvalue(), "") self.assertEqual(cfg.stderr.getvalue(), "")
@ -1160,8 +1241,10 @@ class Dispatch(unittest.TestCase):
cfg.stderr = io.StringIO() cfg.stderr = io.StringIO()
cfg.timing = mock.Mock() cfg.timing = mock.Mock()
cfg.dump_timing = "filename" cfg.dump_timing = "filename"
def fake(): def fake():
pass pass
yield cli._dispatch_command(reactor, cfg, fake) yield cli._dispatch_command(reactor, cfg, fake)
self.assertEqual(cfg.stderr.getvalue(), "") self.assertEqual(cfg.stderr.getvalue(), "")
self.assertEqual(cfg.timing.mock_calls[-1], self.assertEqual(cfg.timing.mock_calls[-1],
@ -1171,10 +1254,12 @@ class Dispatch(unittest.TestCase):
def test_wrong_password_error(self): def test_wrong_password_error(self):
cfg = config("send") cfg = config("send")
cfg.stderr = io.StringIO() cfg.stderr = io.StringIO()
def fake(): def fake():
raise WrongPasswordError("abcd") raise WrongPasswordError("abcd")
yield self.assertFailure(cli._dispatch_command(reactor, cfg, fake),
SystemExit) yield self.assertFailure(
cli._dispatch_command(reactor, cfg, fake), SystemExit)
expected = fill("ERROR: " + dedent(WrongPasswordError.__doc__)) + "\n" expected = fill("ERROR: " + dedent(WrongPasswordError.__doc__)) + "\n"
self.assertEqual(cfg.stderr.getvalue(), expected) self.assertEqual(cfg.stderr.getvalue(), expected)
@ -1182,21 +1267,26 @@ class Dispatch(unittest.TestCase):
def test_welcome_error(self): def test_welcome_error(self):
cfg = config("send") cfg = config("send")
cfg.stderr = io.StringIO() cfg.stderr = io.StringIO()
def fake(): def fake():
raise WelcomeError("abcd") raise WelcomeError("abcd")
yield self.assertFailure(cli._dispatch_command(reactor, cfg, fake),
SystemExit) yield self.assertFailure(
expected = fill("ERROR: " + dedent(WelcomeError.__doc__))+"\n\nabcd\n" cli._dispatch_command(reactor, cfg, fake), SystemExit)
expected = (
fill("ERROR: " + dedent(WelcomeError.__doc__)) + "\n\nabcd\n")
self.assertEqual(cfg.stderr.getvalue(), expected) self.assertEqual(cfg.stderr.getvalue(), expected)
@inlineCallbacks @inlineCallbacks
def test_transfer_error(self): def test_transfer_error(self):
cfg = config("send") cfg = config("send")
cfg.stderr = io.StringIO() cfg.stderr = io.StringIO()
def fake(): def fake():
raise TransferError("abcd") raise TransferError("abcd")
yield self.assertFailure(cli._dispatch_command(reactor, cfg, fake),
SystemExit) yield self.assertFailure(
cli._dispatch_command(reactor, cfg, fake), SystemExit)
expected = "TransferError: abcd\n" expected = "TransferError: abcd\n"
self.assertEqual(cfg.stderr.getvalue(), expected) self.assertEqual(cfg.stderr.getvalue(), expected)
@ -1204,11 +1294,14 @@ class Dispatch(unittest.TestCase):
def test_server_connection_error(self): def test_server_connection_error(self):
cfg = config("send") cfg = config("send")
cfg.stderr = io.StringIO() cfg.stderr = io.StringIO()
def fake(): def fake():
raise ServerConnectionError("URL", ValueError("abcd")) raise ServerConnectionError("URL", ValueError("abcd"))
yield self.assertFailure(cli._dispatch_command(reactor, cfg, fake),
SystemExit) yield self.assertFailure(
expected = fill("ERROR: " + dedent(ServerConnectionError.__doc__))+"\n" cli._dispatch_command(reactor, cfg, fake), SystemExit)
expected = fill(
"ERROR: " + dedent(ServerConnectionError.__doc__)) + "\n"
expected += "(relay URL was URL)\n" expected += "(relay URL was URL)\n"
expected += "abcd\n" expected += "abcd\n"
self.assertEqual(cfg.stderr.getvalue(), expected) self.assertEqual(cfg.stderr.getvalue(), expected)
@ -1217,21 +1310,26 @@ class Dispatch(unittest.TestCase):
def test_other_error(self): def test_other_error(self):
cfg = config("send") cfg = config("send")
cfg.stderr = io.StringIO() cfg.stderr = io.StringIO()
def fake(): def fake():
raise ValueError("abcd") raise ValueError("abcd")
# I'm seeing unicode problems with the Failure().printTraceback, and # I'm seeing unicode problems with the Failure().printTraceback, and
# the output would be kind of unpredictable anyways, so we'll mock it # the output would be kind of unpredictable anyways, so we'll mock it
# out here. # out here.
f = mock.Mock() f = mock.Mock()
def mock_print(file): def mock_print(file):
file.write(u"<TRACEBACK>\n") file.write(u"<TRACEBACK>\n")
f.printTraceback = mock_print f.printTraceback = mock_print
with mock.patch("wormhole.cli.cli.Failure", return_value=f): with mock.patch("wormhole.cli.cli.Failure", return_value=f):
yield self.assertFailure(cli._dispatch_command(reactor, cfg, fake), yield self.assertFailure(
SystemExit) cli._dispatch_command(reactor, cfg, fake), SystemExit)
expected = "<TRACEBACK>\nERROR: abcd\n" expected = "<TRACEBACK>\nERROR: abcd\n"
self.assertEqual(cfg.stderr.getvalue(), expected) self.assertEqual(cfg.stderr.getvalue(), expected)
class Help(unittest.TestCase): class Help(unittest.TestCase):
def _check_top_level_help(self, got): def _check_top_level_help(self, got):
# the main wormhole.cli.cli.wormhole docstring should be in the # the main wormhole.cli.cli.wormhole docstring should be in the

View File

@ -1,14 +1,19 @@
from __future__ import print_function, unicode_literals from __future__ import print_function, unicode_literals
import mock
from twisted.trial import unittest
from twisted.internet import reactor from twisted.internet import reactor
from twisted.internet.task import Clock
from twisted.internet.defer import Deferred, inlineCallbacks from twisted.internet.defer import Deferred, inlineCallbacks
from twisted.internet.task import Clock
from twisted.trial import unittest
import mock
from ..eventual import EventualQueue from ..eventual import EventualQueue
class IntentionalError(Exception): class IntentionalError(Exception):
pass pass
class Eventual(unittest.TestCase, object): class Eventual(unittest.TestCase, object):
def test_eventually(self): def test_eventually(self):
c = Clock() c = Clock()
@ -23,9 +28,10 @@ class Eventual(unittest.TestCase, object):
self.assertNoResult(d3) self.assertNoResult(d3)
eq.flush_sync() eq.flush_sync()
self.assertEqual(c1.mock_calls, self.assertEqual(c1.mock_calls, [
[mock.call("arg1", "arg2", kwarg1="kw1"), mock.call("arg1", "arg2", kwarg1="kw1"),
mock.call("arg3", "arg4", kwarg5="kw5")]) mock.call("arg3", "arg4", kwarg5="kw5")
])
self.assertEqual(self.successResultOf(d2), None) self.assertEqual(self.successResultOf(d2), None)
self.assertEqual(self.successResultOf(d3), "value") self.assertEqual(self.successResultOf(d3), "value")
@ -47,11 +53,12 @@ class Eventual(unittest.TestCase, object):
eq = EventualQueue(reactor) eq = EventualQueue(reactor)
d1 = eq.fire_eventually() d1 = eq.fire_eventually()
d2 = Deferred() d2 = Deferred()
def _more(res): def _more(res):
eq.eventually(d2.callback, None) eq.eventually(d2.callback, None)
d1.addCallback(_more) d1.addCallback(_more)
yield eq.flush() yield eq.flush()
# d1 will fire, which will queue d2 to fire, and the flush() ought to # d1 will fire, which will queue d2 to fire, and the flush() ought to
# wait for d2 too # wait for d2 too
self.successResultOf(d2) self.successResultOf(d2)

View File

@ -1,6 +1,8 @@
from __future__ import print_function, unicode_literals from __future__ import print_function, unicode_literals
import unittest import unittest
from binascii import unhexlify # , hexlify from binascii import unhexlify # , hexlify
from hkdf import Hkdf from hkdf import Hkdf
# def generate_KAT(): # def generate_KAT():
@ -29,16 +31,17 @@ KAT = [
'10336ead9370927eaabb0d60b259346ee5f57eb7ceba8c72f1ed3f2932b1bf19'), '10336ead9370927eaabb0d60b259346ee5f57eb7ceba8c72f1ed3f2932b1bf19'),
] ]
class TestKAT(unittest.TestCase): class TestKAT(unittest.TestCase):
# note: this uses SHA256 # note: this uses SHA256
def test_kat(self): def test_kat(self):
for (salt, context, skm, expected_hexout) in KAT: for (salt, context, skm, expected_hexout) in KAT:
expected_out = unhexlify(expected_hexout) expected_out = unhexlify(expected_hexout)
for outlen in range(0, len(expected_out)): for outlen in range(0, len(expected_out)):
out = Hkdf(salt.encode("ascii"), out = Hkdf(salt.encode("ascii"), skm.encode("ascii")).expand(
skm.encode("ascii")).expand(context.encode("ascii"), context.encode("ascii"), outlen)
outlen)
self.assertEqual(out, expected_out[:outlen]) self.assertEqual(out, expected_out[:outlen])
# if __name__ == '__main__': # if __name__ == '__main__':
# generate_KAT() # generate_KAT()

View File

@ -1,6 +1,10 @@
import errno
import os
import re
import subprocess
import re, errno, subprocess, os
from twisted.trial import unittest from twisted.trial import unittest
from .. import ipaddrs from .. import ipaddrs
DOTTED_QUAD_RE = re.compile("^[0-9]+\.[0-9]+\.[0-9]+\.[0-9]+$") DOTTED_QUAD_RE = re.compile("^[0-9]+\.[0-9]+\.[0-9]+\.[0-9]+$")
@ -11,12 +15,14 @@ MOCK_IPADDR_OUTPUT = """\
inet 127.0.0.1/8 scope host lo inet 127.0.0.1/8 scope host lo
inet6 ::1/128 scope host \n\ inet6 ::1/128 scope host \n\
valid_lft forever preferred_lft forever valid_lft forever preferred_lft forever
2: eth1: <BROADCAST,MULTICAST,UP,LOWER_UP> mtu 1500 qdisc pfifo_fast state UP qlen 1000 2: eth1: <BROADCAST,MULTICAST,UP,LOWER_UP> mtu 1500 qdisc pfifo_fast state UP \
qlen 1000
link/ether d4:3d:7e:01:b4:3e brd ff:ff:ff:ff:ff:ff link/ether d4:3d:7e:01:b4:3e brd ff:ff:ff:ff:ff:ff
inet 192.168.0.6/24 brd 192.168.0.255 scope global eth1 inet 192.168.0.6/24 brd 192.168.0.255 scope global eth1
inet6 fe80::d63d:7eff:fe01:b43e/64 scope link \n\ inet6 fe80::d63d:7eff:fe01:b43e/64 scope link \n\
valid_lft forever preferred_lft forever valid_lft forever preferred_lft forever
3: wlan0: <BROADCAST,MULTICAST,UP,LOWER_UP> mtu 1500 qdisc mq state UP qlen 1000 3: wlan0: <BROADCAST,MULTICAST,UP,LOWER_UP> mtu 1500 qdisc mq state UP qlen\
1000
link/ether 90:f6:52:27:15:0a brd ff:ff:ff:ff:ff:ff link/ether 90:f6:52:27:15:0a brd ff:ff:ff:ff:ff:ff
inet 192.168.0.2/24 brd 192.168.0.255 scope global wlan0 inet 192.168.0.2/24 brd 192.168.0.255 scope global wlan0
inet6 fe80::92f6:52ff:fe27:150a/64 scope link \n\ inet6 fe80::92f6:52ff:fe27:150a/64 scope link \n\
@ -58,7 +64,8 @@ MOCK_ROUTE_OUTPUT = """\
=========================================================================== ===========================================================================
Interface List Interface List
0x1 ........................... MS TCP Loopback interface 0x1 ........................... MS TCP Loopback interface
0x2 ...08 00 27 c3 80 ad ...... AMD PCNET Family PCI Ethernet Adapter - Packet Scheduler Miniport 0x2 ...08 00 27 c3 80 ad ...... AMD PCNET Family PCI Ethernet Adapter - \
Packet Scheduler Miniport
=========================================================================== ===========================================================================
=========================================================================== ===========================================================================
Active Routes: Active Routes:
@ -85,6 +92,7 @@ class FakeProcess:
def __init__(self, output, err): def __init__(self, output, err):
self.output = output self.output = output
self.err = err self.err = err
def communicate(self): def communicate(self):
return (self.output, self.err) return (self.output, self.err)
@ -94,6 +102,7 @@ class ListAddresses(unittest.TestCase):
addresses = ipaddrs.find_addresses() addresses = ipaddrs.find_addresses()
self.failUnlessIn("127.0.0.1", addresses) self.failUnlessIn("127.0.0.1", addresses)
self.failIfIn("0.0.0.0", addresses) self.failIfIn("0.0.0.0", addresses)
# David A.'s OpenSolaris box timed out on this test one time when it was at # David A.'s OpenSolaris box timed out on this test one time when it was at
# 2s. # 2s.
test_list.timeout = 4 test_list.timeout = 4
@ -101,9 +110,20 @@ class ListAddresses(unittest.TestCase):
def _test_list_mock(self, command, output, expected): def _test_list_mock(self, command, output, expected):
self.first = True self.first = True
def call_Popen(args, bufsize=0, executable=None, stdin=None, stdout=None, stderr=None, def call_Popen(args,
preexec_fn=None, close_fds=False, shell=False, cwd=None, env=None, bufsize=0,
universal_newlines=False, startupinfo=None, creationflags=0): executable=None,
stdin=None,
stdout=None,
stderr=None,
preexec_fn=None,
close_fds=False,
shell=False,
cwd=None,
env=None,
universal_newlines=False,
startupinfo=None,
creationflags=0):
if self.first: if self.first:
self.first = False self.first = False
e = OSError("EINTR") e = OSError("EINTR")
@ -115,11 +135,13 @@ class ListAddresses(unittest.TestCase):
e = OSError("[Errno 2] No such file or directory") e = OSError("[Errno 2] No such file or directory")
e.errno = errno.ENOENT e.errno = errno.ENOENT
raise e raise e
self.patch(subprocess, 'Popen', call_Popen) self.patch(subprocess, 'Popen', call_Popen)
self.patch(os.path, 'isfile', lambda x: True) self.patch(os.path, 'isfile', lambda x: True)
def call_which(name): def call_which(name):
return [name] return [name]
self.patch(ipaddrs, 'which', call_which) self.patch(ipaddrs, 'which', call_which)
addresses = ipaddrs.find_addresses() addresses = ipaddrs.find_addresses()
@ -131,11 +153,13 @@ class ListAddresses(unittest.TestCase):
def test_list_mock_ifconfig(self): def test_list_mock_ifconfig(self):
self.patch(ipaddrs, 'platform', "linux2") self.patch(ipaddrs, 'platform', "linux2")
self._test_list_mock("ifconfig", MOCK_IFCONFIG_OUTPUT, UNIX_TEST_ADDRESSES) self._test_list_mock("ifconfig", MOCK_IFCONFIG_OUTPUT,
UNIX_TEST_ADDRESSES)
def test_list_mock_route(self): def test_list_mock_route(self):
self.patch(ipaddrs, 'platform', "win32") self.patch(ipaddrs, 'platform', "win32")
self._test_list_mock("route.exe", MOCK_ROUTE_OUTPUT, WINDOWS_TEST_ADDRESSES) self._test_list_mock("route.exe", MOCK_ROUTE_OUTPUT,
WINDOWS_TEST_ADDRESSES)
def test_list_mock_cygwin(self): def test_list_mock_cygwin(self):
self.patch(ipaddrs, 'platform', "cygwin") self.patch(ipaddrs, 'platform', "cygwin")

View File

@ -1,8 +1,11 @@
from __future__ import print_function, absolute_import, unicode_literals from __future__ import absolute_import, print_function, unicode_literals
from twisted.trial import unittest from twisted.trial import unittest
from .. import journal from .. import journal
from .._interfaces import IJournal from .._interfaces import IJournal
class Journal(unittest.TestCase): class Journal(unittest.TestCase):
def test_journal(self): def test_journal(self):
events = [] events = []

View File

@ -1,29 +1,36 @@
from __future__ import print_function, unicode_literals from __future__ import print_function, unicode_literals
import json import json
import mock
from zope.interface import directlyProvides, implementer from nacl.secret import SecretBox
from spake2 import SPAKE2_Symmetric
from twisted.trial import unittest from twisted.trial import unittest
from .. import (errors, timing, _order, _receive, _key, _code, _lister, _boss, from zope.interface import directlyProvides, implementer
_input, _allocator, _send, _terminator, _nameplate, _mailbox,
_rendezvous, __version__) import mock
from .._interfaces import (IKey, IReceive, IBoss, ISend, IMailbox, IOrder,
IRendezvousConnector, ILister, IInput, IAllocator, from .. import (__version__, _allocator, _boss, _code, _input, _key, _lister,
INameplate, ICode, IWordlist, ITerminator) _mailbox, _nameplate, _order, _receive, _rendezvous, _send,
_terminator, errors, timing)
from .._interfaces import (IAllocator, IBoss, ICode, IInput, IKey, ILister,
IMailbox, INameplate, IOrder, IReceive,
IRendezvousConnector, ISend, ITerminator, IWordlist)
from .._key import derive_key, derive_phase_key, encrypt_data from .._key import derive_key, derive_phase_key, encrypt_data
from ..journal import ImmediateJournal from ..journal import ImmediateJournal
from ..util import (dict_to_bytes, bytes_to_dict, from ..util import (bytes_to_dict, bytes_to_hexstr, dict_to_bytes,
hexstr_to_bytes, bytes_to_hexstr, to_bytes) hexstr_to_bytes, to_bytes)
from spake2 import SPAKE2_Symmetric
from nacl.secret import SecretBox
@implementer(IWordlist) @implementer(IWordlist)
class FakeWordList(object): class FakeWordList(object):
def choose_words(self, length): def choose_words(self, length):
return "-".join(["word"] * length) return "-".join(["word"] * length)
def get_completions(self, prefix): def get_completions(self, prefix):
self._get_completions_prefix = prefix self._get_completions_prefix = prefix
return self._completions return self._completions
class Dummy: class Dummy:
def __init__(self, name, events, iface, *meths): def __init__(self, name, events, iface, *meths):
self.name = name self.name = name
@ -33,12 +40,15 @@ class Dummy:
for meth in meths: for meth in meths:
self.mock(meth) self.mock(meth)
self.retval = None self.retval = None
def mock(self, meth): def mock(self, meth):
def log(*args): def log(*args):
self.events.append(("%s.%s" % (self.name, meth), ) + args) self.events.append(("%s.%s" % (self.name, meth), ) + args)
return self.retval return self.retval
setattr(self, meth, log) setattr(self, meth, log)
class Send(unittest.TestCase): class Send(unittest.TestCase):
def build(self): def build(self):
events = [] events = []
@ -57,7 +67,9 @@ class Send(unittest.TestCase):
s.got_verified_key(key) s.got_verified_key(key)
self.assertEqual(r.mock_calls, [mock.call(SecretBox.NONCE_SIZE)]) self.assertEqual(r.mock_calls, [mock.call(SecretBox.NONCE_SIZE)])
# print(bytes_to_hexstr(events[0][2])) # print(bytes_to_hexstr(events[0][2]))
enc1 = hexstr_to_bytes("00000000000000000000000000000000000000000000000022f1a46c3c3496423c394621a2a5a8cf275b08") enc1 = hexstr_to_bytes(
("000000000000000000000000000000000000000000000000"
"22f1a46c3c3496423c394621a2a5a8cf275b08"))
self.assertEqual(events, [("m.add_message", "phase1", enc1)]) self.assertEqual(events, [("m.add_message", "phase1", enc1)])
events[:] = [] events[:] = []
@ -65,7 +77,9 @@ class Send(unittest.TestCase):
with mock.patch("nacl.utils.random", side_effect=[nonce2]) as r: with mock.patch("nacl.utils.random", side_effect=[nonce2]) as r:
s.send("phase2", b"msg") s.send("phase2", b"msg")
self.assertEqual(r.mock_calls, [mock.call(SecretBox.NONCE_SIZE)]) self.assertEqual(r.mock_calls, [mock.call(SecretBox.NONCE_SIZE)])
enc2 = hexstr_to_bytes("0202020202020202020202020202020202020202020202026660337c3eac6513c0dac9818b62ef16d9cd7e") enc2 = hexstr_to_bytes(
("0202020202020202020202020202020202020202"
"020202026660337c3eac6513c0dac9818b62ef16d9cd7e"))
self.assertEqual(events, [("m.add_message", "phase2", enc2)]) self.assertEqual(events, [("m.add_message", "phase2", enc2)])
def test_key_first(self): def test_key_first(self):
@ -78,7 +92,8 @@ class Send(unittest.TestCase):
with mock.patch("nacl.utils.random", side_effect=[nonce1]) as r: with mock.patch("nacl.utils.random", side_effect=[nonce1]) as r:
s.send("phase1", b"msg") s.send("phase1", b"msg")
self.assertEqual(r.mock_calls, [mock.call(SecretBox.NONCE_SIZE)]) self.assertEqual(r.mock_calls, [mock.call(SecretBox.NONCE_SIZE)])
enc1 = hexstr_to_bytes("00000000000000000000000000000000000000000000000022f1a46c3c3496423c394621a2a5a8cf275b08") enc1 = hexstr_to_bytes(("00000000000000000000000000000000000000000000"
"000022f1a46c3c3496423c394621a2a5a8cf275b08"))
self.assertEqual(events, [("m.add_message", "phase1", enc1)]) self.assertEqual(events, [("m.add_message", "phase1", enc1)])
events[:] = [] events[:] = []
@ -86,11 +101,12 @@ class Send(unittest.TestCase):
with mock.patch("nacl.utils.random", side_effect=[nonce2]) as r: with mock.patch("nacl.utils.random", side_effect=[nonce2]) as r:
s.send("phase2", b"msg") s.send("phase2", b"msg")
self.assertEqual(r.mock_calls, [mock.call(SecretBox.NONCE_SIZE)]) self.assertEqual(r.mock_calls, [mock.call(SecretBox.NONCE_SIZE)])
enc2 = hexstr_to_bytes("0202020202020202020202020202020202020202020202026660337c3eac6513c0dac9818b62ef16d9cd7e") enc2 = hexstr_to_bytes(
("0202020202020202020202020202020202020"
"202020202026660337c3eac6513c0dac9818b62ef16d9cd7e"))
self.assertEqual(events, [("m.add_message", "phase2", enc2)]) self.assertEqual(events, [("m.add_message", "phase2", enc2)])
class Order(unittest.TestCase): class Order(unittest.TestCase):
def build(self): def build(self):
events = [] events = []
@ -106,8 +122,8 @@ class Order(unittest.TestCase):
self.assertEqual(events, [("k.got_pake", b"body")]) # right away self.assertEqual(events, [("k.got_pake", b"body")]) # right away
o.got_message(u"side", u"version", b"body") o.got_message(u"side", u"version", b"body")
o.got_message(u"side", u"1", b"body") o.got_message(u"side", u"1", b"body")
self.assertEqual(events, self.assertEqual(events, [
[("k.got_pake", b"body"), ("k.got_pake", b"body"),
("r.got_message", u"side", u"version", b"body"), ("r.got_message", u"side", u"version", b"body"),
("r.got_message", u"side", u"1", b"body"), ("r.got_message", u"side", u"1", b"body"),
]) ])
@ -120,18 +136,19 @@ class Order(unittest.TestCase):
self.assertEqual(events, []) # nothing yet self.assertEqual(events, []) # nothing yet
o.got_message(u"side", u"pake", b"body") o.got_message(u"side", u"pake", b"body")
# got_pake is delivered first # got_pake is delivered first
self.assertEqual(events, self.assertEqual(events, [
[("k.got_pake", b"body"), ("k.got_pake", b"body"),
("r.got_message", u"side", u"version", b"body"), ("r.got_message", u"side", u"version", b"body"),
("r.got_message", u"side", u"1", b"body"), ("r.got_message", u"side", u"1", b"body"),
]) ])
class Receive(unittest.TestCase): class Receive(unittest.TestCase):
def build(self): def build(self):
events = [] events = []
r = _receive.Receive(u"side", timing.DebugTiming()) r = _receive.Receive(u"side", timing.DebugTiming())
b = Dummy("b", events, IBoss, b = Dummy("b", events, IBoss, "happy", "scared", "got_verifier",
"happy", "scared", "got_verifier", "got_message") "got_message")
s = Dummy("s", events, ISend, "got_verified_key") s = Dummy("s", events, ISend, "got_verified_key")
r.wire(b, s) r.wire(b, s)
return r, b, s, events return r, b, s, events
@ -146,7 +163,8 @@ class Receive(unittest.TestCase):
data1 = b"data1" data1 = b"data1"
good_body = encrypt_data(phase1_key, data1) good_body = encrypt_data(phase1_key, data1)
r.got_message(u"side", u"phase1", good_body) r.got_message(u"side", u"phase1", good_body)
self.assertEqual(events, [("s.got_verified_key", key), self.assertEqual(events, [
("s.got_verified_key", key),
("b.happy", ), ("b.happy", ),
("b.got_verifier", verifier), ("b.got_verifier", verifier),
("b.got_message", u"phase1", data1), ("b.got_message", u"phase1", data1),
@ -156,7 +174,8 @@ class Receive(unittest.TestCase):
data2 = b"data2" data2 = b"data2"
good_body = encrypt_data(phase2_key, data2) good_body = encrypt_data(phase2_key, data2)
r.got_message(u"side", u"phase2", good_body) r.got_message(u"side", u"phase2", good_body)
self.assertEqual(events, [("s.got_verified_key", key), self.assertEqual(events, [
("s.got_verified_key", key),
("b.happy", ), ("b.happy", ),
("b.got_verifier", verifier), ("b.got_verifier", verifier),
("b.got_message", u"phase1", data1), ("b.got_message", u"phase1", data1),
@ -172,14 +191,16 @@ class Receive(unittest.TestCase):
data1 = b"data1" data1 = b"data1"
bad_body = encrypt_data(phase1_key, data1) bad_body = encrypt_data(phase1_key, data1)
r.got_message(u"side", u"phase1", bad_body) r.got_message(u"side", u"phase1", bad_body)
self.assertEqual(events, [("b.scared",), self.assertEqual(events, [
("b.scared", ),
]) ])
phase2_key = derive_phase_key(key, u"side", u"phase2") phase2_key = derive_phase_key(key, u"side", u"phase2")
data2 = b"data2" data2 = b"data2"
good_body = encrypt_data(phase2_key, data2) good_body = encrypt_data(phase2_key, data2)
r.got_message(u"side", u"phase2", good_body) r.got_message(u"side", u"phase2", good_body)
self.assertEqual(events, [("b.scared",), self.assertEqual(events, [
("b.scared", ),
]) ])
def test_late_bad(self): def test_late_bad(self):
@ -192,7 +213,8 @@ class Receive(unittest.TestCase):
data1 = b"data1" data1 = b"data1"
good_body = encrypt_data(phase1_key, data1) good_body = encrypt_data(phase1_key, data1)
r.got_message(u"side", u"phase1", good_body) r.got_message(u"side", u"phase1", good_body)
self.assertEqual(events, [("s.got_verified_key", key), self.assertEqual(events, [
("s.got_verified_key", key),
("b.happy", ), ("b.happy", ),
("b.got_verifier", verifier), ("b.got_verifier", verifier),
("b.got_message", u"phase1", data1), ("b.got_message", u"phase1", data1),
@ -202,7 +224,8 @@ class Receive(unittest.TestCase):
data2 = b"data2" data2 = b"data2"
bad_body = encrypt_data(phase2_key, data2) bad_body = encrypt_data(phase2_key, data2)
r.got_message(u"side", u"phase2", bad_body) r.got_message(u"side", u"phase2", bad_body)
self.assertEqual(events, [("s.got_verified_key", key), self.assertEqual(events, [
("s.got_verified_key", key),
("b.happy", ), ("b.happy", ),
("b.got_verifier", verifier), ("b.got_verifier", verifier),
("b.got_message", u"phase1", data1), ("b.got_message", u"phase1", data1),
@ -210,13 +233,15 @@ class Receive(unittest.TestCase):
]) ])
r.got_message(u"side", u"phase1", good_body) r.got_message(u"side", u"phase1", good_body)
r.got_message(u"side", u"phase2", bad_body) r.got_message(u"side", u"phase2", bad_body)
self.assertEqual(events, [("s.got_verified_key", key), self.assertEqual(events, [
("s.got_verified_key", key),
("b.happy", ), ("b.happy", ),
("b.got_verifier", verifier), ("b.got_verifier", verifier),
("b.got_message", u"phase1", data1), ("b.got_message", u"phase1", data1),
("b.scared", ), ("b.scared", ),
]) ])
class Key(unittest.TestCase): class Key(unittest.TestCase):
def test_derive_errors(self): def test_derive_errors(self):
self.assertRaises(TypeError, derive_key, 123, b"purpose") self.assertRaises(TypeError, derive_key, 123, b"purpose")
@ -260,7 +285,8 @@ class Key(unittest.TestCase):
self.assertEqual(events[0][:2], ("m.add_message", "pake")) self.assertEqual(events[0][:2], ("m.add_message", "pake"))
pake_1_json = events[0][2].decode("utf-8") pake_1_json = events[0][2].decode("utf-8")
pake_1 = json.loads(pake_1_json) pake_1 = json.loads(pake_1_json)
self.assertEqual(list(pake_1.keys()), ["pake_v1"]) # value is PAKE stuff self.assertEqual(list(pake_1.keys()),
["pake_v1"]) # value is PAKE stuff
events[:] = [] events[:] = []
bad_pake_d = {"not_pake_v1": "stuff"} bad_pake_d = {"not_pake_v1": "stuff"}
k.got_pake(dict_to_bytes(bad_pake_d)) k.got_pake(dict_to_bytes(bad_pake_d))
@ -293,6 +319,7 @@ class Key(unittest.TestCase):
self.assertEqual(events[2][:2], ("m.add_message", "version")) self.assertEqual(events[2][:2], ("m.add_message", "version"))
self.assertEqual(events[3], ("r.got_key", key2)) self.assertEqual(events[3], ("r.got_key", key2))
class Code(unittest.TestCase): class Code(unittest.TestCase):
def build(self): def build(self):
events = [] events = []
@ -308,7 +335,8 @@ class Code(unittest.TestCase):
def test_set_code(self): def test_set_code(self):
c, b, a, n, k, i, events = self.build() c, b, a, n, k, i, events = self.build()
c.set_code(u"1-code") c.set_code(u"1-code")
self.assertEqual(events, [("n.set_nameplate", u"1"), self.assertEqual(events, [
("n.set_nameplate", u"1"),
("b.got_code", u"1-code"), ("b.got_code", u"1-code"),
("k.got_code", u"1-code"), ("k.got_code", u"1-code"),
]) ])
@ -323,12 +351,14 @@ class Code(unittest.TestCase):
self.assertEqual(str(e.exception), "Code ' 1-code' contains spaces.") self.assertEqual(str(e.exception), "Code ' 1-code' contains spaces.")
with self.assertRaises(errors.KeyFormatError) as e: with self.assertRaises(errors.KeyFormatError) as e:
c.set_code(u"code-code") c.set_code(u"code-code")
self.assertEqual(str(e.exception), self.assertEqual(
str(e.exception),
"Nameplate 'code' must be numeric, with no spaces.") "Nameplate 'code' must be numeric, with no spaces.")
# it should still be possible to use the wormhole at this point # it should still be possible to use the wormhole at this point
c.set_code(u"1-code") c.set_code(u"1-code")
self.assertEqual(events, [("n.set_nameplate", u"1"), self.assertEqual(events, [
("n.set_nameplate", u"1"),
("b.got_code", u"1-code"), ("b.got_code", u"1-code"),
("k.got_code", u"1-code"), ("k.got_code", u"1-code"),
]) ])
@ -340,7 +370,8 @@ class Code(unittest.TestCase):
self.assertEqual(events, [("a.allocate", 2, wl)]) self.assertEqual(events, [("a.allocate", 2, wl)])
events[:] = [] events[:] = []
c.allocated("1", "1-code") c.allocated("1", "1-code")
self.assertEqual(events, [("n.set_nameplate", u"1"), self.assertEqual(events, [
("n.set_nameplate", u"1"),
("b.got_code", u"1-code"), ("b.got_code", u"1-code"),
("k.got_code", u"1-code"), ("k.got_code", u"1-code"),
]) ])
@ -351,14 +382,17 @@ class Code(unittest.TestCase):
self.assertEqual(events, [("i.start", )]) self.assertEqual(events, [("i.start", )])
events[:] = [] events[:] = []
c.got_nameplate("1") c.got_nameplate("1")
self.assertEqual(events, [("n.set_nameplate", u"1"), self.assertEqual(events, [
("n.set_nameplate", u"1"),
]) ])
events[:] = [] events[:] = []
c.finished_input("1-code") c.finished_input("1-code")
self.assertEqual(events, [("b.got_code", u"1-code"), self.assertEqual(events, [
("b.got_code", u"1-code"),
("k.got_code", u"1-code"), ("k.got_code", u"1-code"),
]) ])
class Input(unittest.TestCase): class Input(unittest.TestCase):
def build(self): def build(self):
events = [] events = []
@ -422,13 +456,13 @@ class Input(unittest.TestCase):
helper.get_word_completions("prefix") helper.get_word_completions("prefix")
i.got_nameplates({"1", "12", "34", "35", "367"}) i.got_nameplates({"1", "12", "34", "35", "367"})
self.assertNoResult(d) self.assertNoResult(d)
self.assertEqual(helper.get_nameplate_completions(""), self.assertEqual(
helper.get_nameplate_completions(""),
{"1-", "12-", "34-", "35-", "367-"}) {"1-", "12-", "34-", "35-", "367-"})
self.assertEqual(helper.get_nameplate_completions("1"), self.assertEqual(helper.get_nameplate_completions("1"), {"1-", "12-"})
{"1-", "12-"})
self.assertEqual(helper.get_nameplate_completions("2"), set()) self.assertEqual(helper.get_nameplate_completions("2"), set())
self.assertEqual(helper.get_nameplate_completions("3"), self.assertEqual(
{"34-", "35-", "367-"}) helper.get_nameplate_completions("3"), {"34-", "35-", "367-"})
helper.choose_nameplate("34") helper.choose_nameplate("34")
with self.assertRaises(errors.AlreadyChoseNameplateError): with self.assertRaises(errors.AlreadyChoseNameplateError):
helper.refresh_nameplates() helper.refresh_nameplates()
@ -461,15 +495,14 @@ class Input(unittest.TestCase):
self.assertEqual(events, [("c.finished_input", "34-word-word")]) self.assertEqual(events, [("c.finished_input", "34-word-word")])
class Lister(unittest.TestCase): class Lister(unittest.TestCase):
def build(self): def build(self):
events = [] events = []
l = _lister.Lister(timing.DebugTiming()) lister = _lister.Lister(timing.DebugTiming())
rc = Dummy("rc", events, IRendezvousConnector, "tx_list") rc = Dummy("rc", events, IRendezvousConnector, "tx_list")
i = Dummy("i", events, IInput, "got_nameplates") i = Dummy("i", events, IInput, "got_nameplates")
l.wire(rc, i) lister.wire(rc, i)
return l, rc, i, events return lister, rc, i, events
def test_connect_first(self): def test_connect_first(self):
l, rc, i, events = self.build() l, rc, i, events = self.build()
@ -478,11 +511,13 @@ class Lister(unittest.TestCase):
l.connected() l.connected()
self.assertEqual(events, []) self.assertEqual(events, [])
l.refresh() l.refresh()
self.assertEqual(events, [("rc.tx_list",), self.assertEqual(events, [
("rc.tx_list", ),
]) ])
events[:] = [] events[:] = []
l.rx_nameplates({"1", "2", "3"}) l.rx_nameplates({"1", "2", "3"})
self.assertEqual(events, [("i.got_nameplates", {"1", "2", "3"}), self.assertEqual(events, [
("i.got_nameplates", {"1", "2", "3"}),
]) ])
events[:] = [] events[:] = []
# now we're satisfied: disconnecting and reconnecting won't ask again # now we're satisfied: disconnecting and reconnecting won't ask again
@ -492,7 +527,8 @@ class Lister(unittest.TestCase):
# but if we're told to refresh, we'll do so # but if we're told to refresh, we'll do so
l.refresh() l.refresh()
self.assertEqual(events, [("rc.tx_list",), self.assertEqual(events, [
("rc.tx_list", ),
]) ])
def test_connect_first_ask_twice(self): def test_connect_first_ask_twice(self):
@ -501,16 +537,19 @@ class Lister(unittest.TestCase):
self.assertEqual(events, []) self.assertEqual(events, [])
l.refresh() l.refresh()
l.refresh() l.refresh()
self.assertEqual(events, [("rc.tx_list",), self.assertEqual(events, [
("rc.tx_list", ),
("rc.tx_list", ), ("rc.tx_list", ),
]) ])
l.rx_nameplates({"1", "2", "3"}) l.rx_nameplates({"1", "2", "3"})
self.assertEqual(events, [("rc.tx_list",), self.assertEqual(events, [
("rc.tx_list", ),
("rc.tx_list", ), ("rc.tx_list", ),
("i.got_nameplates", {"1", "2", "3"}), ("i.got_nameplates", {"1", "2", "3"}),
]) ])
l.rx_nameplates({"1", "2", "3", "4"}) l.rx_nameplates({"1", "2", "3", "4"})
self.assertEqual(events, [("rc.tx_list",), self.assertEqual(events, [
("rc.tx_list", ),
("rc.tx_list", ), ("rc.tx_list", ),
("i.got_nameplates", {"1", "2", "3"}), ("i.got_nameplates", {"1", "2", "3"}),
("i.got_nameplates", {"1", "2", "3", "4"}), ("i.got_nameplates", {"1", "2", "3", "4"}),
@ -520,12 +559,14 @@ class Lister(unittest.TestCase):
l, rc, i, events = self.build() l, rc, i, events = self.build()
l.refresh() l.refresh()
l.connected() l.connected()
self.assertEqual(events, [("rc.tx_list",), self.assertEqual(events, [
("rc.tx_list", ),
]) ])
events[:] = [] events[:] = []
l.lost() l.lost()
l.connected() l.connected()
self.assertEqual(events, [("rc.tx_list",), self.assertEqual(events, [
("rc.tx_list", ),
]) ])
def test_refresh_first(self): def test_refresh_first(self):
@ -533,10 +574,12 @@ class Lister(unittest.TestCase):
l.refresh() l.refresh()
self.assertEqual(events, []) self.assertEqual(events, [])
l.connected() l.connected()
self.assertEqual(events, [("rc.tx_list",), self.assertEqual(events, [
("rc.tx_list", ),
]) ])
l.rx_nameplates({"1", "2", "3"}) l.rx_nameplates({"1", "2", "3"})
self.assertEqual(events, [("rc.tx_list",), self.assertEqual(events, [
("rc.tx_list", ),
("i.got_nameplates", {"1", "2", "3"}), ("i.got_nameplates", {"1", "2", "3"}),
]) ])
@ -547,9 +590,11 @@ class Lister(unittest.TestCase):
l.connected() l.connected()
self.assertEqual(events, []) self.assertEqual(events, [])
l.rx_nameplates({"1", "2", "3"}) l.rx_nameplates({"1", "2", "3"})
self.assertEqual(events, [("i.got_nameplates", {"1", "2", "3"}), self.assertEqual(events, [
("i.got_nameplates", {"1", "2", "3"}),
]) ])
class Allocator(unittest.TestCase): class Allocator(unittest.TestCase):
def build(self): def build(self):
events = [] events = []
@ -573,11 +618,13 @@ class Allocator(unittest.TestCase):
events[:] = [] events[:] = []
a.lost() a.lost()
a.connected() a.connected()
self.assertEqual(events, [("rc.tx_allocate",), self.assertEqual(events, [
("rc.tx_allocate", ),
]) ])
events[:] = [] events[:] = []
a.rx_allocated("1") a.rx_allocated("1")
self.assertEqual(events, [("c.allocated", "1", "1-word-word"), self.assertEqual(events, [
("c.allocated", "1", "1-word-word"),
]) ])
def test_connect_first(self): def test_connect_first(self):
@ -589,20 +636,24 @@ class Allocator(unittest.TestCase):
events[:] = [] events[:] = []
a.lost() a.lost()
a.connected() a.connected()
self.assertEqual(events, [("rc.tx_allocate",), self.assertEqual(events, [
("rc.tx_allocate", ),
]) ])
events[:] = [] events[:] = []
a.rx_allocated("1") a.rx_allocated("1")
self.assertEqual(events, [("c.allocated", "1", "1-word-word"), self.assertEqual(events, [
("c.allocated", "1", "1-word-word"),
]) ])
class Nameplate(unittest.TestCase): class Nameplate(unittest.TestCase):
def build(self): def build(self):
events = [] events = []
n = _nameplate.Nameplate() n = _nameplate.Nameplate()
m = Dummy("m", events, IMailbox, "got_mailbox") m = Dummy("m", events, IMailbox, "got_mailbox")
i = Dummy("i", events, IInput, "got_wordlist") i = Dummy("i", events, IInput, "got_wordlist")
rc = Dummy("rc", events, IRendezvousConnector, "tx_claim", "tx_release") rc = Dummy("rc", events, IRendezvousConnector, "tx_claim",
"tx_release")
t = Dummy("t", events, ITerminator, "nameplate_done") t = Dummy("t", events, ITerminator, "nameplate_done")
n.wire(m, i, rc, t) n.wire(m, i, rc, t)
return n, m, i, rc, t, events return n, m, i, rc, t, events
@ -611,11 +662,13 @@ class Nameplate(unittest.TestCase):
n, m, i, rc, t, events = self.build() n, m, i, rc, t, events = self.build()
with self.assertRaises(errors.KeyFormatError) as e: with self.assertRaises(errors.KeyFormatError) as e:
n.set_nameplate(" 1") n.set_nameplate(" 1")
self.assertEqual(str(e.exception), self.assertEqual(
str(e.exception),
"Nameplate ' 1' must be numeric, with no spaces.") "Nameplate ' 1' must be numeric, with no spaces.")
with self.assertRaises(errors.KeyFormatError) as e: with self.assertRaises(errors.KeyFormatError) as e:
n.set_nameplate("one") n.set_nameplate("one")
self.assertEqual(str(e.exception), self.assertEqual(
str(e.exception),
"Nameplate 'one' must be numeric, with no spaces.") "Nameplate 'one' must be numeric, with no spaces.")
# wormhole should still be usable # wormhole should still be usable
@ -636,7 +689,8 @@ class Nameplate(unittest.TestCase):
wl = object() wl = object()
with mock.patch("wormhole._nameplate.PGPWordList", return_value=wl): with mock.patch("wormhole._nameplate.PGPWordList", return_value=wl):
n.rx_claimed("mbox1") n.rx_claimed("mbox1")
self.assertEqual(events, [("i.got_wordlist", wl), self.assertEqual(events, [
("i.got_wordlist", wl),
("m.got_mailbox", "mbox1"), ("m.got_mailbox", "mbox1"),
]) ])
events[:] = [] events[:] = []
@ -661,7 +715,8 @@ class Nameplate(unittest.TestCase):
wl = object() wl = object()
with mock.patch("wormhole._nameplate.PGPWordList", return_value=wl): with mock.patch("wormhole._nameplate.PGPWordList", return_value=wl):
n.rx_claimed("mbox1") n.rx_claimed("mbox1")
self.assertEqual(events, [("i.got_wordlist", wl), self.assertEqual(events, [
("i.got_wordlist", wl),
("m.got_mailbox", "mbox1"), ("m.got_mailbox", "mbox1"),
]) ])
events[:] = [] events[:] = []
@ -700,7 +755,8 @@ class Nameplate(unittest.TestCase):
wl = object() wl = object()
with mock.patch("wormhole._nameplate.PGPWordList", return_value=wl): with mock.patch("wormhole._nameplate.PGPWordList", return_value=wl):
n.rx_claimed("mbox1") n.rx_claimed("mbox1")
self.assertEqual(events, [("i.got_wordlist", wl), self.assertEqual(events, [
("i.got_wordlist", wl),
("m.got_mailbox", "mbox1"), ("m.got_mailbox", "mbox1"),
]) ])
events[:] = [] events[:] = []
@ -722,7 +778,8 @@ class Nameplate(unittest.TestCase):
wl = object() wl = object()
with mock.patch("wormhole._nameplate.PGPWordList", return_value=wl): with mock.patch("wormhole._nameplate.PGPWordList", return_value=wl):
n.rx_claimed("mbox1") n.rx_claimed("mbox1")
self.assertEqual(events, [("i.got_wordlist", wl), self.assertEqual(events, [
("i.got_wordlist", wl),
("m.got_mailbox", "mbox1"), ("m.got_mailbox", "mbox1"),
]) ])
events[:] = [] events[:] = []
@ -748,7 +805,8 @@ class Nameplate(unittest.TestCase):
wl = object() wl = object()
with mock.patch("wormhole._nameplate.PGPWordList", return_value=wl): with mock.patch("wormhole._nameplate.PGPWordList", return_value=wl):
n.rx_claimed("mbox1") n.rx_claimed("mbox1")
self.assertEqual(events, [("i.got_wordlist", wl), self.assertEqual(events, [
("i.got_wordlist", wl),
("m.got_mailbox", "mbox1"), ("m.got_mailbox", "mbox1"),
]) ])
events[:] = [] events[:] = []
@ -828,7 +886,8 @@ class Nameplate(unittest.TestCase):
wl = object() wl = object()
with mock.patch("wormhole._nameplate.PGPWordList", return_value=wl): with mock.patch("wormhole._nameplate.PGPWordList", return_value=wl):
n.rx_claimed("mbox1") n.rx_claimed("mbox1")
self.assertEqual(events, [("i.got_wordlist", wl), self.assertEqual(events, [
("i.got_wordlist", wl),
("m.got_mailbox", "mbox1"), ("m.got_mailbox", "mbox1"),
]) ])
events[:] = [] events[:] = []
@ -852,7 +911,8 @@ class Nameplate(unittest.TestCase):
wl = object() wl = object()
with mock.patch("wormhole._nameplate.PGPWordList", return_value=wl): with mock.patch("wormhole._nameplate.PGPWordList", return_value=wl):
n.rx_claimed("mbox1") n.rx_claimed("mbox1")
self.assertEqual(events, [("i.got_wordlist", wl), self.assertEqual(events, [
("i.got_wordlist", wl),
("m.got_mailbox", "mbox1"), ("m.got_mailbox", "mbox1"),
]) ])
events[:] = [] events[:] = []
@ -878,7 +938,8 @@ class Nameplate(unittest.TestCase):
wl = object() wl = object()
with mock.patch("wormhole._nameplate.PGPWordList", return_value=wl): with mock.patch("wormhole._nameplate.PGPWordList", return_value=wl):
n.rx_claimed("mbox1") n.rx_claimed("mbox1")
self.assertEqual(events, [("i.got_wordlist", wl), self.assertEqual(events, [
("i.got_wordlist", wl),
("m.got_mailbox", "mbox1"), ("m.got_mailbox", "mbox1"),
]) ])
events[:] = [] events[:] = []
@ -903,7 +964,8 @@ class Nameplate(unittest.TestCase):
wl = object() wl = object()
with mock.patch("wormhole._nameplate.PGPWordList", return_value=wl): with mock.patch("wormhole._nameplate.PGPWordList", return_value=wl):
n.rx_claimed("mbox1") n.rx_claimed("mbox1")
self.assertEqual(events, [("i.got_wordlist", wl), self.assertEqual(events, [
("i.got_wordlist", wl),
("m.got_mailbox", "mbox1"), ("m.got_mailbox", "mbox1"),
]) ])
events[:] = [] events[:] = []
@ -937,7 +999,8 @@ class Nameplate(unittest.TestCase):
wl = object() wl = object()
with mock.patch("wormhole._nameplate.PGPWordList", return_value=wl): with mock.patch("wormhole._nameplate.PGPWordList", return_value=wl):
n.rx_claimed("mbox1") n.rx_claimed("mbox1")
self.assertEqual(events, [("i.got_wordlist", wl), self.assertEqual(events, [
("i.got_wordlist", wl),
("m.got_mailbox", "mbox1"), ("m.got_mailbox", "mbox1"),
]) ])
events[:] = [] events[:] = []
@ -966,7 +1029,8 @@ class Nameplate(unittest.TestCase):
wl = object() wl = object()
with mock.patch("wormhole._nameplate.PGPWordList", return_value=wl): with mock.patch("wormhole._nameplate.PGPWordList", return_value=wl):
n.rx_claimed("mbox1") n.rx_claimed("mbox1")
self.assertEqual(events, [("i.got_wordlist", wl), self.assertEqual(events, [
("i.got_wordlist", wl),
("m.got_mailbox", "mbox1"), ("m.got_mailbox", "mbox1"),
]) ])
events[:] = [] events[:] = []
@ -983,13 +1047,14 @@ class Nameplate(unittest.TestCase):
n.close() # NOP n.close() # NOP
self.assertEqual(events, []) self.assertEqual(events, [])
class Mailbox(unittest.TestCase): class Mailbox(unittest.TestCase):
def build(self): def build(self):
events = [] events = []
m = _mailbox.Mailbox("side1") m = _mailbox.Mailbox("side1")
n = Dummy("n", events, INameplate, "release") n = Dummy("n", events, INameplate, "release")
rc = Dummy("rc", events, IRendezvousConnector, rc = Dummy("rc", events, IRendezvousConnector, "tx_add", "tx_open",
"tx_add", "tx_open", "tx_close") "tx_close")
o = Dummy("o", events, IOrder, "got_message") o = Dummy("o", events, IOrder, "got_message")
t = Dummy("t", events, ITerminator, "mailbox_done") t = Dummy("t", events, ITerminator, "mailbox_done")
m.wire(n, rc, o, t) m.wire(n, rc, o, t)
@ -998,8 +1063,9 @@ class Mailbox(unittest.TestCase):
# TODO: test moods # TODO: test moods
def assert_events(self, events, initial_events, tx_add_events): def assert_events(self, events, initial_events, tx_add_events):
self.assertEqual(len(events), len(initial_events)+len(tx_add_events), self.assertEqual(
events) len(events),
len(initial_events) + len(tx_add_events), events)
self.assertEqual(events[:len(initial_events)], initial_events) self.assertEqual(events[:len(initial_events)], initial_events)
self.assertEqual(set(events[len(initial_events):]), tx_add_events) self.assertEqual(set(events[len(initial_events):]), tx_add_events)
@ -1029,20 +1095,22 @@ class Mailbox(unittest.TestCase):
m.connected() m.connected()
# the other messages are allowed to be sent in any order # the other messages are allowed to be sent in any order
self.assert_events(events, [("rc.tx_open", "mbox1")], self.assert_events(
{ ("rc.tx_add", "phase1", b"msg1"), events, [("rc.tx_open", "mbox1")], {
("rc.tx_add", "phase1", b"msg1"),
("rc.tx_add", "phase2", b"msg2"), ("rc.tx_add", "phase2", b"msg2"),
("rc.tx_add", "phase3", b"msg3"), ("rc.tx_add", "phase3", b"msg3"),
}) })
events[:] = [] events[:] = []
m.rx_message("side1", "phase1", b"msg1") # echo of our message, dequeue m.rx_message("side1", "phase1",
b"msg1") # echo of our message, dequeue
self.assertEqual(events, []) self.assertEqual(events, [])
m.lost() m.lost()
m.connected() m.connected()
self.assert_events(events, [("rc.tx_open", "mbox1")], self.assert_events(events, [("rc.tx_open", "mbox1")], {
{("rc.tx_add", "phase2", b"msg2"), ("rc.tx_add", "phase2", b"msg2"),
("rc.tx_add", "phase3", b"msg3"), ("rc.tx_add", "phase3", b"msg3"),
}) })
events[:] = [] events[:] = []
@ -1051,7 +1119,8 @@ class Mailbox(unittest.TestCase):
# released since the message proves that our peer opened the Mailbox # released since the message proves that our peer opened the Mailbox
# and therefore no longer needs the Nameplate # and therefore no longer needs the Nameplate
m.rx_message("side2", "phase1", b"msg1them") # new message from peer m.rx_message("side2", "phase1", b"msg1them") # new message from peer
self.assertEqual(events, [("n.release",), self.assertEqual(events, [
("n.release", ),
("o.got_message", "side2", "phase1", b"msg1them"), ("o.got_message", "side2", "phase1", b"msg1them"),
]) ])
events[:] = [] events[:] = []
@ -1059,7 +1128,8 @@ class Mailbox(unittest.TestCase):
# we de-duplicate peer messages, but still re-release the nameplate # we de-duplicate peer messages, but still re-release the nameplate
# since Nameplate is smart enough to ignore that # since Nameplate is smart enough to ignore that
m.rx_message("side2", "phase1", b"msg1them") m.rx_message("side2", "phase1", b"msg1them")
self.assertEqual(events, [("n.release",), self.assertEqual(events, [
("n.release", ),
]) ])
events[:] = [] events[:] = []
@ -1103,8 +1173,8 @@ class Mailbox(unittest.TestCase):
m.connected() m.connected()
self.assert_events(events, [("rc.tx_open", "mbox1")], self.assert_events(events, [("rc.tx_open", "mbox1")], {
{ ("rc.tx_add", "phase1", b"msg1"), ("rc.tx_add", "phase1", b"msg1"),
("rc.tx_add", "phase2", b"msg2"), ("rc.tx_add", "phase2", b"msg2"),
}) })
@ -1146,6 +1216,7 @@ class Mailbox(unittest.TestCase):
self.assertEqual(events, [("t.mailbox_done", )]) self.assertEqual(events, [("t.mailbox_done", )])
events[:] = [] events[:] = []
class Terminator(unittest.TestCase): class Terminator(unittest.TestCase):
def build(self): def build(self):
events = [] events = []
@ -1160,11 +1231,13 @@ class Terminator(unittest.TestCase):
# there are three events, and we need to test all orderings of them # there are three events, and we need to test all orderings of them
def _do_test(self, ev1, ev2, ev3): def _do_test(self, ev1, ev2, ev3):
t, b, rc, n, m, events = self.build() t, b, rc, n, m, events = self.build()
input_events = {"mailbox": lambda: t.mailbox_done(), input_events = {
"mailbox": lambda: t.mailbox_done(),
"nameplate": lambda: t.nameplate_done(), "nameplate": lambda: t.nameplate_done(),
"close": lambda: t.close("happy"), "close": lambda: t.close("happy"),
} }
close_events = [("n.close",), close_events = [
("n.close", ),
("m.close", "happy"), ("m.close", "happy"),
] ]
@ -1203,18 +1276,19 @@ class Terminator(unittest.TestCase):
# TODO: test moods # TODO: test moods
class MockBoss(_boss.Boss): class MockBoss(_boss.Boss):
def __attrs_post_init__(self): def __attrs_post_init__(self):
# self._build_workers() # self._build_workers()
self._init_other_state() self._init_other_state()
class Boss(unittest.TestCase): class Boss(unittest.TestCase):
def build(self): def build(self):
events = [] events = []
wormhole = Dummy("w", events, None, wormhole = Dummy("w", events, None, "got_welcome", "got_code",
"got_welcome", "got_key", "got_verifier", "got_versions", "received",
"got_code", "got_key", "got_verifier", "got_versions", "closed")
"received", "closed")
versions = {"app": "version1"} versions = {"app": "version1"}
reactor = None reactor = None
journal = ImmediateJournal() journal = ImmediateJournal()
@ -1226,8 +1300,8 @@ class Boss(unittest.TestCase):
b._T = Dummy("t", events, ITerminator, "close") b._T = Dummy("t", events, ITerminator, "close")
b._S = Dummy("s", events, ISend, "send") b._S = Dummy("s", events, ISend, "send")
b._RC = Dummy("rc", events, IRendezvousConnector, "start") b._RC = Dummy("rc", events, IRendezvousConnector, "start")
b._C = Dummy("c", events, ICode, b._C = Dummy("c", events, ICode, "allocate_code", "input_code",
"allocate_code", "input_code", "set_code") "set_code")
return b, events return b, events
def test_basic(self): def test_basic(self):
@ -1242,7 +1316,8 @@ class Boss(unittest.TestCase):
welcome = {"howdy": "how are ya"} welcome = {"howdy": "how are ya"}
b.rx_welcome(welcome) b.rx_welcome(welcome)
self.assertEqual(events, [("w.got_welcome", welcome), self.assertEqual(events, [
("w.got_welcome", welcome),
]) ])
events[:] = [] events[:] = []
@ -1252,7 +1327,8 @@ class Boss(unittest.TestCase):
b.got_verifier(b"verifier") b.got_verifier(b"verifier")
b.got_message("version", b"{}") b.got_message("version", b"{}")
b.got_message("0", b"msg1") b.got_message("0", b"msg1")
self.assertEqual(events, [("w.got_key", b"key"), self.assertEqual(events, [
("w.got_key", b"key"),
("w.got_verifier", b"verifier"), ("w.got_verifier", b"verifier"),
("w.got_versions", {}), ("w.got_versions", {}),
("w.received", b"msg1"), ("w.received", b"msg1"),
@ -1451,11 +1527,9 @@ class Rendezvous(unittest.TestCase):
journal = ImmediateJournal() journal = ImmediateJournal()
tor_manager = None tor_manager = None
client_version = ("python", __version__) client_version = ("python", __version__)
rc = _rendezvous.RendezvousConnector("ws://host:4000/v1", "appid", rc = _rendezvous.RendezvousConnector(
"side", reactor, "ws://host:4000/v1", "appid", "side", reactor, journal,
journal, tor_manager, tor_manager, timing.DebugTiming(), client_version)
timing.DebugTiming(),
client_version)
b = Dummy("b", events, IBoss, "error") b = Dummy("b", events, IBoss, "error")
n = Dummy("n", events, INameplate, "connected", "lost") n = Dummy("n", events, INameplate, "connected", "lost")
m = Dummy("m", events, IMailbox, "connected", "lost") m = Dummy("m", events, IMailbox, "connected", "lost")
@ -1488,36 +1562,45 @@ class Rendezvous(unittest.TestCase):
rc, events = self.build() rc, events = self.build()
ws = mock.Mock() ws = mock.Mock()
def notrandom(length): def notrandom(length):
return b"\x00" * length return b"\x00" * length
with mock.patch("os.urandom", notrandom): with mock.patch("os.urandom", notrandom):
rc.ws_open(ws) rc.ws_open(ws)
self.assertEqual(events, [("n.connected", ), self.assertEqual(events, [
("n.connected", ),
("m.connected", ), ("m.connected", ),
("l.connected", ), ("l.connected", ),
("a.connected", ), ("a.connected", ),
]) ])
events[:] = [] events[:] = []
def sent_messages(ws): def sent_messages(ws):
for c in ws.mock_calls: for c in ws.mock_calls:
self.assertEqual(c[0], "sendMessage", ws.mock_calls) self.assertEqual(c[0], "sendMessage", ws.mock_calls)
self.assertEqual(c[1][1], False, ws.mock_calls) self.assertEqual(c[1][1], False, ws.mock_calls)
yield bytes_to_dict(c[1][0]) yield bytes_to_dict(c[1][0])
self.assertEqual(list(sent_messages(ws)),
[dict(appid="appid", side="side", self.assertEqual(
list(sent_messages(ws)), [
dict(
appid="appid",
side="side",
client_version=["python", __version__], client_version=["python", __version__],
id="0000", type="bind"), id="0000",
type="bind"),
]) ])
rc.ws_close(True, None, None) rc.ws_close(True, None, None)
self.assertEqual(events, [("n.lost", ), self.assertEqual(events, [
("n.lost", ),
("m.lost", ), ("m.lost", ),
("l.lost", ), ("l.lost", ),
("a.lost", ), ("a.lost", ),
]) ])
# TODO # TODO
# #Send # #Send
# #Mailbox # #Mailbox

View File

@ -1,9 +1,11 @@
from twisted.trial import unittest
from twisted.internet.task import Clock from twisted.internet.task import Clock
from twisted.python.failure import Failure from twisted.python.failure import Failure
from twisted.trial import unittest
from ..eventual import EventualQueue from ..eventual import EventualQueue
from ..observer import OneShotObserver, SequenceObserver from ..observer import OneShotObserver, SequenceObserver
class OneShot(unittest.TestCase): class OneShot(unittest.TestCase):
def test_fire(self): def test_fire(self):
c = Clock() c = Clock()
@ -119,4 +121,3 @@ class Sequence(unittest.TestCase):
d2 = o.when_next_event() d2 = o.when_next_event()
eq.flush_sync() eq.flush_sync()
self.assertIdentical(self.failureResultOf(d2), f) self.assertIdentical(self.failureResultOf(d2), f)

View File

@ -1,39 +1,44 @@
from __future__ import print_function, absolute_import, unicode_literals from __future__ import absolute_import, print_function, unicode_literals
import mock
from itertools import count from itertools import count
from twisted.trial import unittest
from twisted.internet import reactor from twisted.internet import reactor
from twisted.internet.defer import inlineCallbacks from twisted.internet.defer import inlineCallbacks
from twisted.internet.threads import deferToThread from twisted.internet.threads import deferToThread
from .._rlcompleter import (input_with_completion, from twisted.trial import unittest
_input_code_with_completion,
CodeInputter, warn_readline) import mock
from ..errors import KeyFormatError, AlreadyInputNameplateError
from .._rlcompleter import (CodeInputter, _input_code_with_completion,
input_with_completion, warn_readline)
from ..errors import AlreadyInputNameplateError, KeyFormatError
APPID = "appid" APPID = "appid"
class Input(unittest.TestCase): class Input(unittest.TestCase):
@inlineCallbacks @inlineCallbacks
def test_wrapper(self): def test_wrapper(self):
helper = object() helper = object()
trueish = object() trueish = object()
with mock.patch("wormhole._rlcompleter._input_code_with_completion", with mock.patch(
"wormhole._rlcompleter._input_code_with_completion",
return_value=trueish) as m: return_value=trueish) as m:
used_completion = yield input_with_completion("prompt:", helper, used_completion = yield input_with_completion(
reactor) "prompt:", helper, reactor)
self.assertIs(used_completion, trueish) self.assertIs(used_completion, trueish)
self.assertEqual(m.mock_calls, self.assertEqual(m.mock_calls, [mock.call("prompt:", helper, reactor)])
[mock.call("prompt:", helper, reactor)])
# note: if this test fails, the warn_readline() message will probably # note: if this test fails, the warn_readline() message will probably
# get written to stderr # get written to stderr
class Sync(unittest.TestCase): class Sync(unittest.TestCase):
# exercise _input_code_with_completion, which uses the blocking builtin # exercise _input_code_with_completion, which uses the blocking builtin
# "input()" function, hence _input_code_with_completion is usually in a # "input()" function, hence _input_code_with_completion is usually in a
# thread with deferToThread # thread with deferToThread
@mock.patch("wormhole._rlcompleter.CodeInputter") @mock.patch("wormhole._rlcompleter.CodeInputter")
@mock.patch("wormhole._rlcompleter.readline", @mock.patch("wormhole._rlcompleter.readline", __doc__="I am GNU readline")
__doc__="I am GNU readline")
@mock.patch("wormhole._rlcompleter.input", return_value="code") @mock.patch("wormhole._rlcompleter.input", return_value="code")
def test_readline(self, input, readline, ci): def test_readline(self, input, readline, ci):
c = mock.Mock(name="inhibit parenting") c = mock.Mock(name="inhibit parenting")
@ -49,8 +54,8 @@ class Sync(unittest.TestCase):
self.assertEqual(ci.mock_calls, [mock.call(input_helper, reactor)]) self.assertEqual(ci.mock_calls, [mock.call(input_helper, reactor)])
self.assertEqual(c.mock_calls, [mock.call.finish("code")]) self.assertEqual(c.mock_calls, [mock.call.finish("code")])
self.assertEqual(input.mock_calls, [mock.call(prompt)]) self.assertEqual(input.mock_calls, [mock.call(prompt)])
self.assertEqual(readline.mock_calls, self.assertEqual(readline.mock_calls, [
[mock.call.parse_and_bind("tab: complete"), mock.call.parse_and_bind("tab: complete"),
mock.call.set_completer(c.completer), mock.call.set_completer(c.completer),
mock.call.set_completer_delims(""), mock.call.set_completer_delims(""),
]) ])
@ -73,15 +78,14 @@ class Sync(unittest.TestCase):
self.assertEqual(ci.mock_calls, [mock.call(input_helper, reactor)]) self.assertEqual(ci.mock_calls, [mock.call(input_helper, reactor)])
self.assertEqual(c.mock_calls, [mock.call.finish("code")]) self.assertEqual(c.mock_calls, [mock.call.finish("code")])
self.assertEqual(input.mock_calls, [mock.call(prompt)]) self.assertEqual(input.mock_calls, [mock.call(prompt)])
self.assertEqual(readline.mock_calls, self.assertEqual(readline.mock_calls, [
[mock.call.parse_and_bind("tab: complete"), mock.call.parse_and_bind("tab: complete"),
mock.call.set_completer(c.completer), mock.call.set_completer(c.completer),
mock.call.set_completer_delims(""), mock.call.set_completer_delims(""),
]) ])
@mock.patch("wormhole._rlcompleter.CodeInputter") @mock.patch("wormhole._rlcompleter.CodeInputter")
@mock.patch("wormhole._rlcompleter.readline", @mock.patch("wormhole._rlcompleter.readline", __doc__="I am libedit")
__doc__="I am libedit")
@mock.patch("wormhole._rlcompleter.input", return_value="code") @mock.patch("wormhole._rlcompleter.input", return_value="code")
def test_libedit(self, input, readline, ci): def test_libedit(self, input, readline, ci):
c = mock.Mock(name="inhibit parenting") c = mock.Mock(name="inhibit parenting")
@ -97,8 +101,8 @@ class Sync(unittest.TestCase):
self.assertEqual(ci.mock_calls, [mock.call(input_helper, reactor)]) self.assertEqual(ci.mock_calls, [mock.call(input_helper, reactor)])
self.assertEqual(c.mock_calls, [mock.call.finish("code")]) self.assertEqual(c.mock_calls, [mock.call.finish("code")])
self.assertEqual(input.mock_calls, [mock.call(prompt)]) self.assertEqual(input.mock_calls, [mock.call(prompt)])
self.assertEqual(readline.mock_calls, self.assertEqual(readline.mock_calls, [
[mock.call.parse_and_bind("bind ^I rl_complete"), mock.call.parse_and_bind("bind ^I rl_complete"),
mock.call.set_completer(c.completer), mock.call.set_completer(c.completer),
mock.call.set_completer_delims(""), mock.call.set_completer_delims(""),
]) ])
@ -139,6 +143,7 @@ class Sync(unittest.TestCase):
self.assertEqual(c.mock_calls, [mock.call.finish(u"code")]) self.assertEqual(c.mock_calls, [mock.call.finish(u"code")])
self.assertEqual(input.mock_calls, [mock.call(prompt)]) self.assertEqual(input.mock_calls, [mock.call(prompt)])
def get_completions(c, prefix): def get_completions(c, prefix):
completions = [] completions = []
for state in count(0): for state in count(0):
@ -147,9 +152,11 @@ def get_completions(c, prefix):
return completions return completions
completions.append(text) completions.append(text)
def fake_blockingCallFromThread(f, *a, **kw): def fake_blockingCallFromThread(f, *a, **kw):
return f(*a, **kw) return f(*a, **kw)
class Completion(unittest.TestCase): class Completion(unittest.TestCase):
def test_simple(self): def test_simple(self):
# no actual completion # no actual completion
@ -158,11 +165,13 @@ class Completion(unittest.TestCase):
c.bcft = fake_blockingCallFromThread c.bcft = fake_blockingCallFromThread
c.finish("1-code-ghost") c.finish("1-code-ghost")
self.assertFalse(c.used_completion) self.assertFalse(c.used_completion)
self.assertEqual(helper.mock_calls, self.assertEqual(helper.mock_calls, [
[mock.call.choose_nameplate("1"), mock.call.choose_nameplate("1"),
mock.call.choose_words("code-ghost")]) mock.call.choose_words("code-ghost")
])
@mock.patch("wormhole._rlcompleter.readline", @mock.patch(
"wormhole._rlcompleter.readline",
get_completion_type=mock.Mock(return_value=0)) get_completion_type=mock.Mock(return_value=0))
def test_call(self, readline): def test_call(self, readline):
# check that it calls _commit_and_build_completions correctly # check that it calls _commit_and_build_completions correctly
@ -188,14 +197,14 @@ class Completion(unittest.TestCase):
# now we have three "a" words: "and", "ark", "aaah!zombies!!" # now we have three "a" words: "and", "ark", "aaah!zombies!!"
cabc.reset_mock() cabc.reset_mock()
cabc.configure_mock(return_value=["aargh", "ark", "aaah!zombies!!"]) cabc.configure_mock(return_value=["aargh", "ark", "aaah!zombies!!"])
self.assertEqual(get_completions(c, "12-a"), self.assertEqual(
["aargh", "ark", "aaah!zombies!!"]) get_completions(c, "12-a"), ["aargh", "ark", "aaah!zombies!!"])
self.assertEqual(cabc.mock_calls, [mock.call("12-a")]) self.assertEqual(cabc.mock_calls, [mock.call("12-a")])
cabc.reset_mock() cabc.reset_mock()
cabc.configure_mock(return_value=["aargh", "aaah!zombies!!"]) cabc.configure_mock(return_value=["aargh", "aaah!zombies!!"])
self.assertEqual(get_completions(c, "12-aa"), self.assertEqual(
["aargh", "aaah!zombies!!"]) get_completions(c, "12-aa"), ["aargh", "aaah!zombies!!"])
self.assertEqual(cabc.mock_calls, [mock.call("12-aa")]) self.assertEqual(cabc.mock_calls, [mock.call("12-aa")])
cabc.reset_mock() cabc.reset_mock()
@ -227,7 +236,8 @@ class Completion(unittest.TestCase):
cn = mock.Mock() # choose_nameplate cn = mock.Mock() # choose_nameplate
gwc = mock.Mock() # get_word_completions gwc = mock.Mock() # get_word_completions
cw = mock.Mock() # choose_words cw = mock.Mock() # choose_words
helper = mock.Mock(refresh_nameplates=rn, helper = mock.Mock(
refresh_nameplates=rn,
get_nameplate_completions=gnc, get_nameplate_completions=gnc,
choose_nameplate=cn, choose_nameplate=cn,
get_word_completions=gwc, get_word_completions=gwc,
@ -328,16 +338,17 @@ class Completion(unittest.TestCase):
c = CodeInputter(helper, reactor) c = CodeInputter(helper, reactor)
cabc = c._commit_and_build_completions cabc = c._commit_and_build_completions
matches = yield deferToThread(cabc, "1-co") # this commits us to 1- matches = yield deferToThread(cabc, "1-co") # this commits us to 1-
self.assertEqual(helper.mock_calls, self.assertEqual(helper.mock_calls, [
[mock.call.choose_nameplate("1"), mock.call.choose_nameplate("1"),
mock.call.when_wordlist_is_available(), mock.call.when_wordlist_is_available(),
mock.call.get_word_completions("co")]) mock.call.get_word_completions("co")
])
self.assertEqual(matches, ["1-code", "1-court"]) self.assertEqual(matches, ["1-code", "1-court"])
helper.reset_mock() helper.reset_mock()
with self.assertRaises(AlreadyInputNameplateError) as e: with self.assertRaises(AlreadyInputNameplateError) as e:
yield deferToThread(cabc, "2-co") yield deferToThread(cabc, "2-co")
self.assertEqual(str(e.exception), self.assertEqual(
"nameplate (1-) already entered, cannot go back") str(e.exception), "nameplate (1-) already entered, cannot go back")
self.assertEqual(helper.mock_calls, []) self.assertEqual(helper.mock_calls, [])
@inlineCallbacks @inlineCallbacks
@ -348,16 +359,17 @@ class Completion(unittest.TestCase):
c = CodeInputter(helper, reactor) c = CodeInputter(helper, reactor)
cabc = c._commit_and_build_completions cabc = c._commit_and_build_completions
matches = yield deferToThread(cabc, "1-co") # this commits us to 1- matches = yield deferToThread(cabc, "1-co") # this commits us to 1-
self.assertEqual(helper.mock_calls, self.assertEqual(helper.mock_calls, [
[mock.call.choose_nameplate("1"), mock.call.choose_nameplate("1"),
mock.call.when_wordlist_is_available(), mock.call.when_wordlist_is_available(),
mock.call.get_word_completions("co")]) mock.call.get_word_completions("co")
])
self.assertEqual(matches, ["1-code", "1-court"]) self.assertEqual(matches, ["1-code", "1-court"])
helper.reset_mock() helper.reset_mock()
with self.assertRaises(AlreadyInputNameplateError) as e: with self.assertRaises(AlreadyInputNameplateError) as e:
yield deferToThread(c.finish, "2-code") yield deferToThread(c.finish, "2-code")
self.assertEqual(str(e.exception), self.assertEqual(
"nameplate (1-) already entered, cannot go back") str(e.exception), "nameplate (1-) already entered, cannot go back")
self.assertEqual(helper.mock_calls, []) self.assertEqual(helper.mock_calls, [])
@mock.patch("wormhole._rlcompleter.stderr") @mock.patch("wormhole._rlcompleter.stderr")
@ -367,5 +379,6 @@ class Completion(unittest.TestCase):
# trigger", but let's at least make sure it's invocable # trigger", but let's at least make sure it's invocable
warn_readline() warn_readline()
expected = "\nCommand interrupted: please press Return to quit" expected = "\nCommand interrupted: please press Return to quit"
self.assertEqual(stderr.mock_calls, [mock.call.write(expected), self.assertEqual(stderr.mock_calls,
[mock.call.write(expected),
mock.call.write("\n")]) mock.call.write("\n")])

View File

@ -1,10 +1,15 @@
import os, io import io
import mock import os
from twisted.trial import unittest from twisted.trial import unittest
import mock
from ..cli import cmd_ssh from ..cli import cmd_ssh
OTHERS = ["config", "config~", "known_hosts", "known_hosts~"] OTHERS = ["config", "config~", "known_hosts", "known_hosts~"]
class FindPubkey(unittest.TestCase): class FindPubkey(unittest.TestCase):
def test_find_one(self): def test_find_one(self):
files = OTHERS + ["id_rsa.pub", "id_rsa"] files = OTHERS + ["id_rsa.pub", "id_rsa"]
@ -12,8 +17,8 @@ class FindPubkey(unittest.TestCase):
pubkey_file = io.StringIO(pubkey_data) pubkey_file = io.StringIO(pubkey_data)
with mock.patch("wormhole.cli.cmd_ssh.exists", return_value=True): with mock.patch("wormhole.cli.cmd_ssh.exists", return_value=True):
with mock.patch("os.listdir", return_value=files) as ld: with mock.patch("os.listdir", return_value=files) as ld:
with mock.patch("wormhole.cli.cmd_ssh.open", with mock.patch(
return_value=pubkey_file): "wormhole.cli.cmd_ssh.open", return_value=pubkey_file):
res = cmd_ssh.find_public_key() res = cmd_ssh.find_public_key()
self.assertEqual(ld.mock_calls, self.assertEqual(ld.mock_calls,
[mock.call(os.path.expanduser("~/.ssh/"))]) [mock.call(os.path.expanduser("~/.ssh/"))])
@ -34,12 +39,12 @@ class FindPubkey(unittest.TestCase):
def test_bad_hint(self): def test_bad_hint(self):
with mock.patch("wormhole.cli.cmd_ssh.exists", return_value=False): with mock.patch("wormhole.cli.cmd_ssh.exists", return_value=False):
e = self.assertRaises(cmd_ssh.PubkeyError, e = self.assertRaises(
cmd_ssh.PubkeyError,
cmd_ssh.find_public_key, cmd_ssh.find_public_key,
hint="bogus/path") hint="bogus/path")
self.assertEqual(str(e), "Can't find 'bogus/path'") self.assertEqual(str(e), "Can't find 'bogus/path'")
def test_find_multiple(self): def test_find_multiple(self):
files = OTHERS + ["id_rsa.pub", "id_rsa", "id_dsa.pub", "id_dsa"] files = OTHERS + ["id_rsa.pub", "id_rsa", "id_dsa.pub", "id_dsa"]
pubkey_data = u"ssh-rsa AAAAkeystuff email@host\n" pubkey_data = u"ssh-rsa AAAAkeystuff email@host\n"
@ -47,9 +52,10 @@ class FindPubkey(unittest.TestCase):
with mock.patch("wormhole.cli.cmd_ssh.exists", return_value=True): with mock.patch("wormhole.cli.cmd_ssh.exists", return_value=True):
with mock.patch("os.listdir", return_value=files): with mock.patch("os.listdir", return_value=files):
responses = iter(["frog", "NaN", "-1", "0"]) responses = iter(["frog", "NaN", "-1", "0"])
with mock.patch("click.prompt", with mock.patch(
side_effect=lambda p: next(responses)): "click.prompt", side_effect=lambda p: next(responses)):
with mock.patch("wormhole.cli.cmd_ssh.open", with mock.patch(
"wormhole.cli.cmd_ssh.open",
return_value=pubkey_file): return_value=pubkey_file):
res = cmd_ssh.find_public_key() res = cmd_ssh.find_public_key()
self.assertEqual(len(res), 3, res) self.assertEqual(len(res), 3, res)

View File

@ -1,33 +1,41 @@
from __future__ import print_function, unicode_literals from __future__ import print_function, unicode_literals
import mock, io
from twisted.trial import unittest import io
from twisted.internet import defer from twisted.internet import defer
from twisted.internet.error import ConnectError from twisted.internet.error import ConnectError
from twisted.trial import unittest
import mock
from ..tor_manager import get_tor, SocksOnlyTor
from ..errors import NoTorError
from .._interfaces import ITorManager from .._interfaces import ITorManager
from ..errors import NoTorError
from ..tor_manager import SocksOnlyTor, get_tor
class X(): class X():
pass pass
class Tor(unittest.TestCase): class Tor(unittest.TestCase):
def test_no_txtorcon(self): def test_no_txtorcon(self):
with mock.patch("wormhole.tor_manager.txtorcon", None): with mock.patch("wormhole.tor_manager.txtorcon", None):
self.failureResultOf(get_tor(None), NoTorError) self.failureResultOf(get_tor(None), NoTorError)
def test_bad_args(self): def test_bad_args(self):
f = self.failureResultOf(get_tor(None, launch_tor="not boolean"), f = self.failureResultOf(
TypeError) get_tor(None, launch_tor="not boolean"), TypeError)
self.assertEqual(str(f.value), "launch_tor= must be boolean") self.assertEqual(str(f.value), "launch_tor= must be boolean")
f = self.failureResultOf(get_tor(None, tor_control_port=1234), f = self.failureResultOf(
TypeError) get_tor(None, tor_control_port=1234), TypeError)
self.assertEqual(str(f.value), "tor_control_port= must be str or None") self.assertEqual(str(f.value), "tor_control_port= must be str or None")
f = self.failureResultOf(get_tor(None, launch_tor=True, f = self.failureResultOf(
tor_control_port="tcp:127.0.0.1:1234"), get_tor(
None, launch_tor=True, tor_control_port="tcp:127.0.0.1:1234"),
ValueError) ValueError)
self.assertEqual(str(f.value), self.assertEqual(
str(f.value),
"cannot combine --launch-tor and --tor-control-port=") "cannot combine --launch-tor and --tor-control-port=")
def test_launch(self): def test_launch(self):
@ -35,7 +43,8 @@ class Tor(unittest.TestCase):
my_tor = X() # object() didn't like providedBy() my_tor = X() # object() didn't like providedBy()
launch_d = defer.Deferred() launch_d = defer.Deferred()
stderr = io.StringIO() stderr = io.StringIO()
with mock.patch("wormhole.tor_manager.txtorcon.launch", with mock.patch(
"wormhole.tor_manager.txtorcon.launch",
side_effect=launch_d) as launch: side_effect=launch_d) as launch:
d = get_tor(reactor, launch_tor=True, stderr=stderr) d = get_tor(reactor, launch_tor=True, stderr=stderr)
self.assertNoResult(d) self.assertNoResult(d)
@ -44,7 +53,8 @@ class Tor(unittest.TestCase):
tor = self.successResultOf(d) tor = self.successResultOf(d)
self.assertIs(tor, my_tor) self.assertIs(tor, my_tor)
self.assert_(ITorManager.providedBy(tor)) self.assert_(ITorManager.providedBy(tor))
self.assertEqual(stderr.getvalue(), self.assertEqual(
stderr.getvalue(),
" launching a new Tor process, this may take a while..\n") " launching a new Tor process, this may take a while..\n")
def test_connect(self): def test_connect(self):
@ -52,9 +62,11 @@ class Tor(unittest.TestCase):
my_tor = X() # object() didn't like providedBy() my_tor = X() # object() didn't like providedBy()
connect_d = defer.Deferred() connect_d = defer.Deferred()
stderr = io.StringIO() stderr = io.StringIO()
with mock.patch("wormhole.tor_manager.txtorcon.connect", with mock.patch(
"wormhole.tor_manager.txtorcon.connect",
side_effect=connect_d) as connect: side_effect=connect_d) as connect:
with mock.patch("wormhole.tor_manager.clientFromString", with mock.patch(
"wormhole.tor_manager.clientFromString",
side_effect=["foo"]) as sfs: side_effect=["foo"]) as sfs:
d = get_tor(reactor, stderr=stderr) d = get_tor(reactor, stderr=stderr)
self.assertEqual(sfs.mock_calls, []) self.assertEqual(sfs.mock_calls, [])
@ -71,9 +83,11 @@ class Tor(unittest.TestCase):
reactor = object() reactor = object()
connect_d = defer.Deferred() connect_d = defer.Deferred()
stderr = io.StringIO() stderr = io.StringIO()
with mock.patch("wormhole.tor_manager.txtorcon.connect", with mock.patch(
"wormhole.tor_manager.txtorcon.connect",
side_effect=connect_d) as connect: side_effect=connect_d) as connect:
with mock.patch("wormhole.tor_manager.clientFromString", with mock.patch(
"wormhole.tor_manager.clientFromString",
side_effect=["foo"]) as sfs: side_effect=["foo"]) as sfs:
d = get_tor(reactor, stderr=stderr) d = get_tor(reactor, stderr=stderr)
self.assertEqual(sfs.mock_calls, []) self.assertEqual(sfs.mock_calls, [])
@ -85,7 +99,8 @@ class Tor(unittest.TestCase):
self.assertIsInstance(tor, SocksOnlyTor) self.assertIsInstance(tor, SocksOnlyTor)
self.assert_(ITorManager.providedBy(tor)) self.assert_(ITorManager.providedBy(tor))
self.assertEqual(tor._reactor, reactor) self.assertEqual(tor._reactor, reactor)
self.assertEqual(stderr.getvalue(), self.assertEqual(
stderr.getvalue(),
" unable to find default Tor control port, using SOCKS\n") " unable to find default Tor control port, using SOCKS\n")
def test_connect_custom_control_port(self): def test_connect_custom_control_port(self):
@ -95,9 +110,11 @@ class Tor(unittest.TestCase):
ep = object() ep = object()
connect_d = defer.Deferred() connect_d = defer.Deferred()
stderr = io.StringIO() stderr = io.StringIO()
with mock.patch("wormhole.tor_manager.txtorcon.connect", with mock.patch(
"wormhole.tor_manager.txtorcon.connect",
side_effect=connect_d) as connect: side_effect=connect_d) as connect:
with mock.patch("wormhole.tor_manager.clientFromString", with mock.patch(
"wormhole.tor_manager.clientFromString",
side_effect=[ep]) as sfs: side_effect=[ep]) as sfs:
d = get_tor(reactor, tor_control_port=tcp, stderr=stderr) d = get_tor(reactor, tor_control_port=tcp, stderr=stderr)
self.assertEqual(sfs.mock_calls, [mock.call(reactor, tcp)]) self.assertEqual(sfs.mock_calls, [mock.call(reactor, tcp)])
@ -116,9 +133,11 @@ class Tor(unittest.TestCase):
ep = object() ep = object()
connect_d = defer.Deferred() connect_d = defer.Deferred()
stderr = io.StringIO() stderr = io.StringIO()
with mock.patch("wormhole.tor_manager.txtorcon.connect", with mock.patch(
"wormhole.tor_manager.txtorcon.connect",
side_effect=connect_d) as connect: side_effect=connect_d) as connect:
with mock.patch("wormhole.tor_manager.clientFromString", with mock.patch(
"wormhole.tor_manager.clientFromString",
side_effect=[ep]) as sfs: side_effect=[ep]) as sfs:
d = get_tor(reactor, tor_control_port=tcp, stderr=stderr) d = get_tor(reactor, tor_control_port=tcp, stderr=stderr)
self.assertEqual(sfs.mock_calls, [mock.call(reactor, tcp)]) self.assertEqual(sfs.mock_calls, [mock.call(reactor, tcp)])
@ -129,18 +148,22 @@ class Tor(unittest.TestCase):
self.failureResultOf(d, ConnectError) self.failureResultOf(d, ConnectError)
self.assertEqual(stderr.getvalue(), "") self.assertEqual(stderr.getvalue(), "")
class SocksOnly(unittest.TestCase): class SocksOnly(unittest.TestCase):
def test_tor(self): def test_tor(self):
reactor = object() reactor = object()
sot = SocksOnlyTor(reactor) sot = SocksOnlyTor(reactor)
fake_ep = object() fake_ep = object()
with mock.patch("wormhole.tor_manager.txtorcon.TorClientEndpoint", with mock.patch(
"wormhole.tor_manager.txtorcon.TorClientEndpoint",
return_value=fake_ep) as tce: return_value=fake_ep) as tce:
ep = sot.stream_via("host", "port") ep = sot.stream_via("host", "port")
self.assertIs(ep, fake_ep) self.assertIs(ep, fake_ep)
self.assertEqual(tce.mock_calls, [mock.call("host", "port", self.assertEqual(tce.mock_calls, [
mock.call(
"host",
"port",
socks_endpoint=None, socks_endpoint=None,
tls=False, tls=False,
reactor=reactor)]) reactor=reactor)
])

View File

@ -1,27 +1,33 @@
from __future__ import print_function, unicode_literals from __future__ import print_function, unicode_literals
import six
import io
import gc import gc
import mock import io
from binascii import hexlify, unhexlify from binascii import hexlify, unhexlify
from collections import namedtuple from collections import namedtuple
from twisted.trial import unittest
from twisted.internet import defer, task, endpoints, protocol, address, error import six
from nacl.exceptions import CryptoError
from nacl.secret import SecretBox
from twisted.internet import address, defer, endpoints, error, protocol, task
from twisted.internet.defer import gatherResults, inlineCallbacks from twisted.internet.defer import gatherResults, inlineCallbacks
from twisted.python import log from twisted.python import log
from twisted.test import proto_helpers from twisted.test import proto_helpers
from twisted.trial import unittest
import mock
from wormhole_transit_relay import transit_server from wormhole_transit_relay import transit_server
from ..errors import InternalError
from .. import transit from .. import transit
from ..errors import InternalError
from .common import ServerBase from .common import ServerBase
from nacl.secret import SecretBox
from nacl.exceptions import CryptoError
class Highlander(unittest.TestCase): class Highlander(unittest.TestCase):
def test_one_winner(self): def test_one_winner(self):
cancelled = set() cancelled = set()
contenders = [defer.Deferred(lambda d, i=i: cancelled.add(i)) contenders = [
for i in range(5)] defer.Deferred(lambda d, i=i: cancelled.add(i)) for i in range(5)
]
d = transit.there_can_be_only_one(contenders) d = transit.there_can_be_only_one(contenders)
self.assertNoResult(d) self.assertNoResult(d)
contenders[0].errback(ValueError()) contenders[0].errback(ValueError())
@ -34,8 +40,9 @@ class Highlander(unittest.TestCase):
def test_there_might_also_be_none(self): def test_there_might_also_be_none(self):
cancelled = set() cancelled = set()
contenders = [defer.Deferred(lambda d, i=i: cancelled.add(i)) contenders = [
for i in range(4)] defer.Deferred(lambda d, i=i: cancelled.add(i)) for i in range(4)
]
d = transit.there_can_be_only_one(contenders) d = transit.there_can_be_only_one(contenders)
self.assertNoResult(d) self.assertNoResult(d)
contenders[0].errback(ValueError()) contenders[0].errback(ValueError())
@ -50,8 +57,9 @@ class Highlander(unittest.TestCase):
def test_cancel_early(self): def test_cancel_early(self):
cancelled = set() cancelled = set()
contenders = [defer.Deferred(lambda d, i=i: cancelled.add(i)) contenders = [
for i in range(4)] defer.Deferred(lambda d, i=i: cancelled.add(i)) for i in range(4)
]
d = transit.there_can_be_only_one(contenders) d = transit.there_can_be_only_one(contenders)
self.assertNoResult(d) self.assertNoResult(d)
self.assertEqual(cancelled, set()) self.assertEqual(cancelled, set())
@ -61,8 +69,9 @@ class Highlander(unittest.TestCase):
def test_cancel_after_one_failure(self): def test_cancel_after_one_failure(self):
cancelled = set() cancelled = set()
contenders = [defer.Deferred(lambda d, i=i: cancelled.add(i)) contenders = [
for i in range(4)] defer.Deferred(lambda d, i=i: cancelled.add(i)) for i in range(4)
]
d = transit.there_can_be_only_one(contenders) d = transit.there_can_be_only_one(contenders)
self.assertNoResult(d) self.assertNoResult(d)
self.assertEqual(cancelled, set()) self.assertEqual(cancelled, set())
@ -71,6 +80,7 @@ class Highlander(unittest.TestCase):
self.failureResultOf(d, ValueError) self.failureResultOf(d, ValueError)
self.assertEqual(cancelled, set([1, 2, 3])) self.assertEqual(cancelled, set([1, 2, 3]))
class Forever(unittest.TestCase): class Forever(unittest.TestCase):
def _forever_setup(self): def _forever_setup(self):
clock = task.Clock() clock = task.Clock()
@ -116,6 +126,7 @@ class Forever(unittest.TestCase):
self.failureResultOf(d, defer.CancelledError) self.failureResultOf(d, defer.CancelledError)
self.assertNot(clock.getDelayedCalls()) self.assertNot(clock.getDelayedCalls())
class Misc(unittest.TestCase): class Misc(unittest.TestCase):
def test_allocate_port(self): def test_allocate_port(self):
portno = transit.allocate_tcp_port() portno = transit.allocate_tcp_port()
@ -128,29 +139,36 @@ class Misc(unittest.TestCase):
portno = transit.allocate_tcp_port() portno = transit.allocate_tcp_port()
self.assertIsInstance(portno, int) self.assertIsInstance(portno, int)
UnknownHint = namedtuple("UnknownHint", ["stuff"]) UnknownHint = namedtuple("UnknownHint", ["stuff"])
class Hints(unittest.TestCase): class Hints(unittest.TestCase):
def test_endpoint_from_hint_obj(self): def test_endpoint_from_hint_obj(self):
c = transit.Common("") c = transit.Common("")
efho = c._endpoint_from_hint_obj efho = c._endpoint_from_hint_obj
self.assertIsInstance(efho(transit.DirectTCPV1Hint("host", 1234, 0.0)), self.assertIsInstance(
efho(transit.DirectTCPV1Hint("host", 1234, 0.0)),
endpoints.HostnameEndpoint) endpoints.HostnameEndpoint)
self.assertEqual(efho("unknown:stuff:yowza:pivlor"), None) self.assertEqual(efho("unknown:stuff:yowza:pivlor"), None)
# c._tor is currently None # c._tor is currently None
self.assertEqual(efho(transit.TorTCPV1Hint("host", "port", 0)), None) self.assertEqual(efho(transit.TorTCPV1Hint("host", "port", 0)), None)
c._tor = mock.Mock() c._tor = mock.Mock()
def tor_ep(hostname, port): def tor_ep(hostname, port):
if hostname == "non-public": if hostname == "non-public":
return None return None
return ("tor_ep", hostname, port) return ("tor_ep", hostname, port)
c._tor.stream_via = mock.Mock(side_effect=tor_ep) c._tor.stream_via = mock.Mock(side_effect=tor_ep)
self.assertEqual(efho(transit.DirectTCPV1Hint("host", 1234, 0.0)), self.assertEqual(
efho(transit.DirectTCPV1Hint("host", 1234, 0.0)),
("tor_ep", "host", 1234)) ("tor_ep", "host", 1234))
self.assertEqual(efho(transit.TorTCPV1Hint("host2.onion", 1234, 0.0)), self.assertEqual(
efho(transit.TorTCPV1Hint("host2.onion", 1234, 0.0)),
("tor_ep", "host2.onion", 1234)) ("tor_ep", "host2.onion", 1234))
self.assertEqual(efho(transit.DirectTCPV1Hint("non-public", 1234, 0.0)), self.assertEqual(
None) efho(transit.DirectTCPV1Hint("non-public", 1234, 0.0)), None)
self.assertEqual(efho(UnknownHint("foo")), None) self.assertEqual(efho(UnknownHint("foo")), None)
def test_comparable(self): def test_comparable(self):
@ -170,29 +188,47 @@ class Hints(unittest.TestCase):
self.assertEqual(p({"type": "unknown"}), None) self.assertEqual(p({"type": "unknown"}), None)
h = p({"type": "direct-tcp-v1", "hostname": "foo", "port": 1234}) h = p({"type": "direct-tcp-v1", "hostname": "foo", "port": 1234})
self.assertEqual(h, transit.DirectTCPV1Hint("foo", 1234, 0.0)) self.assertEqual(h, transit.DirectTCPV1Hint("foo", 1234, 0.0))
h = p({"type": "direct-tcp-v1", "hostname": "foo", "port": 1234, h = p({
"priority": 2.5}) "type": "direct-tcp-v1",
"hostname": "foo",
"port": 1234,
"priority": 2.5
})
self.assertEqual(h, transit.DirectTCPV1Hint("foo", 1234, 2.5)) self.assertEqual(h, transit.DirectTCPV1Hint("foo", 1234, 2.5))
h = p({"type": "tor-tcp-v1", "hostname": "foo", "port": 1234}) h = p({"type": "tor-tcp-v1", "hostname": "foo", "port": 1234})
self.assertEqual(h, transit.TorTCPV1Hint("foo", 1234, 0.0)) self.assertEqual(h, transit.TorTCPV1Hint("foo", 1234, 0.0))
h = p({"type": "tor-tcp-v1", "hostname": "foo", "port": 1234, h = p({
"priority": 2.5}) "type": "tor-tcp-v1",
"hostname": "foo",
"port": 1234,
"priority": 2.5
})
self.assertEqual(h, transit.TorTCPV1Hint("foo", 1234, 2.5)) self.assertEqual(h, transit.TorTCPV1Hint("foo", 1234, 2.5))
self.assertEqual(p({"type": "direct-tcp-v1"}), self.assertEqual(p({
None) # missing hostname "type": "direct-tcp-v1"
self.assertEqual(p({"type": "direct-tcp-v1", "hostname": 12}), }), None) # missing hostname
None) # invalid hostname self.assertEqual(p({
self.assertEqual(p({"type": "direct-tcp-v1", "hostname": "foo"}), "type": "direct-tcp-v1",
None) # missing port "hostname": 12
self.assertEqual(p({"type": "direct-tcp-v1", "hostname": "foo", }), None) # invalid hostname
"port": "not a number"}), self.assertEqual(
None) # invalid port p({
"type": "direct-tcp-v1",
"hostname": "foo"
}), None) # missing port
self.assertEqual(
p({
"type": "direct-tcp-v1",
"hostname": "foo",
"port": "not a number"
}), None) # invalid port
def test_parse_hint_argv(self): def test_parse_hint_argv(self):
def p(hint): def p(hint):
stderr = io.StringIO() stderr = io.StringIO()
value = transit.parse_hint_argv(hint, stderr=stderr) value = transit.parse_hint_argv(hint, stderr=stderr)
return value, stderr.getvalue() return value, stderr.getvalue()
h, stderr = p("tcp:host:1234") h, stderr = p("tcp:host:1234")
self.assertEqual(h, transit.DirectTCPV1Hint("host", 1234, 0.0)) self.assertEqual(h, transit.DirectTCPV1Hint("host", 1234, 0.0))
self.assertEqual(stderr, "") self.assertEqual(stderr, "")
@ -216,7 +252,8 @@ class Hints(unittest.TestCase):
h, stderr = p("tcp:just-a-hostname") h, stderr = p("tcp:just-a-hostname")
self.assertEqual(h, None) self.assertEqual(h, None)
self.assertEqual(stderr, self.assertEqual(
stderr,
"unparseable TCP hint (need more colons) 'tcp:just-a-hostname'\n") "unparseable TCP hint (need more colons) 'tcp:just-a-hostname'\n")
h, stderr = p("tcp:host:number") h, stderr = p("tcp:host:number")
@ -226,17 +263,19 @@ class Hints(unittest.TestCase):
h, stderr = p("tcp:host:1234:priority=bad") h, stderr = p("tcp:host:1234:priority=bad")
self.assertEqual(h, None) self.assertEqual(h, None)
self.assertEqual(stderr, self.assertEqual(
stderr,
"non-float priority= in TCP hint 'tcp:host:1234:priority=bad'\n") "non-float priority= in TCP hint 'tcp:host:1234:priority=bad'\n")
def test_describe_hint_obj(self): def test_describe_hint_obj(self):
d = transit.describe_hint_obj d = transit.describe_hint_obj
self.assertEqual(d(transit.DirectTCPV1Hint("host", 1234, 0.0)), self.assertEqual(
"tcp:host:1234") d(transit.DirectTCPV1Hint("host", 1234, 0.0)), "tcp:host:1234")
self.assertEqual(d(transit.TorTCPV1Hint("host", 1234, 0.0)), self.assertEqual(
"tor:host:1234") d(transit.TorTCPV1Hint("host", 1234, 0.0)), "tor:host:1234")
self.assertEqual(d(UnknownHint("stuff")), str(UnknownHint("stuff"))) self.assertEqual(d(UnknownHint("stuff")), str(UnknownHint("stuff")))
# ipaddrs.py currently uses native strings: bytes on py2, unicode on # ipaddrs.py currently uses native strings: bytes on py2, unicode on
# py3 # py3
if six.PY2: if six.PY2:
@ -246,17 +285,22 @@ else:
LOOPADDR = "127.0.0.1" # unicode_literals LOOPADDR = "127.0.0.1" # unicode_literals
OTHERADDR = "1.2.3.4" OTHERADDR = "1.2.3.4"
class Basic(unittest.TestCase): class Basic(unittest.TestCase):
@inlineCallbacks @inlineCallbacks
def test_relay_hints(self): def test_relay_hints(self):
URL = "tcp:host:1234" URL = "tcp:host:1234"
c = transit.Common(URL, no_listen=True) c = transit.Common(URL, no_listen=True)
hints = yield c.get_connection_hints() hints = yield c.get_connection_hints()
self.assertEqual(hints, [{"type": "relay-v1", self.assertEqual(hints, [{
"hints": [{"type": "direct-tcp-v1", "type":
"relay-v1",
"hints": [{
"type": "direct-tcp-v1",
"hostname": "host", "hostname": "host",
"port": 1234, "port": 1234,
"priority": 0.0}], "priority": 0.0
}],
}]) }])
self.assertRaises(InternalError, transit.Common, 123) self.assertRaises(InternalError, transit.Common, 123)
@ -269,8 +313,12 @@ class Basic(unittest.TestCase):
def test_ignore_bad_hints(self): def test_ignore_bad_hints(self):
c = transit.Common("") c = transit.Common("")
c.add_connection_hints([{"type": "unknown"}]) c.add_connection_hints([{"type": "unknown"}])
c.add_connection_hints([{"type": "relay-v1", c.add_connection_hints([{
"hints": [{"type": "unknown"}]}]) "type": "relay-v1",
"hints": [{
"type": "unknown"
}]
}])
self.assertEqual(c._their_direct_hints, []) self.assertEqual(c._their_direct_hints, [])
self.assertEqual(c._our_relay_hints, set()) self.assertEqual(c._our_relay_hints, set())
@ -290,7 +338,8 @@ class Basic(unittest.TestCase):
def test_ignore_localhost_hint(self): def test_ignore_localhost_hint(self):
# this actually starts the listener # this actually starts the listener
c = transit.TransitSender("") c = transit.TransitSender("")
with mock.patch("wormhole.ipaddrs.find_addresses", with mock.patch(
"wormhole.ipaddrs.find_addresses",
return_value=[LOOPADDR, OTHERADDR]): return_value=[LOOPADDR, OTHERADDR]):
hints = self.successResultOf(c.get_connection_hints()) hints = self.successResultOf(c.get_connection_hints())
c._stop_listening() c._stop_listening()
@ -302,8 +351,8 @@ class Basic(unittest.TestCase):
def test_keep_only_localhost_hint(self): def test_keep_only_localhost_hint(self):
# this actually starts the listener # this actually starts the listener
c = transit.TransitSender("") c = transit.TransitSender("")
with mock.patch("wormhole.ipaddrs.find_addresses", with mock.patch(
return_value=[LOOPADDR]): "wormhole.ipaddrs.find_addresses", return_value=[LOOPADDR]):
hints = self.successResultOf(c.get_connection_hints()) hints = self.successResultOf(c.get_connection_hints())
c._stop_listening() c._stop_listening()
# If the only hint is localhost, it should stay. # If the only hint is localhost, it should stay.
@ -313,8 +362,13 @@ class Basic(unittest.TestCase):
def test_abilities(self): def test_abilities(self):
c = transit.Common(None, no_listen=True) c = transit.Common(None, no_listen=True)
abilities = c.get_connection_abilities() abilities = c.get_connection_abilities()
self.assertEqual(abilities, [{"type": "direct-tcp-v1"}, self.assertEqual(abilities, [
{"type": "relay-v1"}, {
"type": "direct-tcp-v1"
},
{
"type": "relay-v1"
},
]) ])
def test_transit_key_wait(self): def test_transit_key_wait(self):
@ -339,19 +393,31 @@ class Basic(unittest.TestCase):
r = transit.TransitReceiver("") r = transit.TransitReceiver("")
r.set_transit_key(KEY) r.set_transit_key(KEY)
self.assertEqual(s._send_this(), b"transit sender 559bdeae4b49fa6a23378d2b68f4c7e69378615d4af049c371c6a26e82391089 ready\n\n") self.assertEqual(s._send_this(), (
b"transit sender "
b"559bdeae4b49fa6a23378d2b68f4c7e69378615d4af049c371c6a26e82391089"
b" ready\n\n"))
self.assertEqual(s._send_this(), r._expect_this()) self.assertEqual(s._send_this(), r._expect_this())
self.assertEqual(r._send_this(), b"transit receiver ed447528194bac4c00d0c854b12a97ce51413d89aa74d6304475f516fdc23a1b ready\n\n") self.assertEqual(r._send_this(), (
b"transit receiver "
b"ed447528194bac4c00d0c854b12a97ce51413d89aa74d6304475f516fdc23a1b"
b" ready\n\n"))
self.assertEqual(r._send_this(), s._expect_this()) self.assertEqual(r._send_this(), s._expect_this())
self.assertEqual(hexlify(s._sender_record_key()), b"5a2fba3a9e524ab2e2823ff53b05f946896f6e4ce4e282ffd8e3ac0e5e9e0cda") self.assertEqual(
self.assertEqual(hexlify(s._sender_record_key()), hexlify(s._sender_record_key()),
hexlify(r._receiver_record_key())) b"5a2fba3a9e524ab2e2823ff53b05f946896f6e4ce4e282ffd8e3ac0e5e9e0cda"
)
self.assertEqual(
hexlify(s._sender_record_key()), hexlify(r._receiver_record_key()))
self.assertEqual(hexlify(r._sender_record_key()), b"eedb143117249f45b39da324decf6bd9aae33b7ccd58487436de611a3c6b871d") self.assertEqual(
self.assertEqual(hexlify(r._sender_record_key()), hexlify(r._sender_record_key()),
hexlify(s._receiver_record_key())) b"eedb143117249f45b39da324decf6bd9aae33b7ccd58487436de611a3c6b871d"
)
self.assertEqual(
hexlify(r._sender_record_key()), hexlify(s._receiver_record_key()))
def test_connection_ready(self): def test_connection_ready(self):
s = transit.TransitSender("") s = transit.TransitSender("")
@ -422,19 +488,24 @@ class DummyProtocol(protocol.Protocol):
if self._d2: if self._d2:
self._d2.callback(None) self._d2.callback(None)
class FakeTransport: class FakeTransport:
signalConnectionLost = True signalConnectionLost = True
def __init__(self, p, peeraddr): def __init__(self, p, peeraddr):
self.protocol = p self.protocol = p
self._peeraddr = peeraddr self._peeraddr = peeraddr
self._buf = b"" self._buf = b""
self._connected = True self._connected = True
def write(self, data): def write(self, data):
self._buf += data self._buf += data
def loseConnection(self): def loseConnection(self):
self._connected = False self._connected = False
if self.signalConnectionLost: if self.signalConnectionLost:
self.protocol.connectionLost() self.protocol.connectionLost()
def getPeer(self): def getPeer(self):
return self._peeraddr return self._peeraddr
@ -443,17 +514,21 @@ class FakeTransport:
self._buf = b"" self._buf = b""
return b return b
class RandomError(Exception): class RandomError(Exception):
pass pass
class MockConnection: class MockConnection:
def __init__(self, owner, relay_handshake, start, description): def __init__(self, owner, relay_handshake, start, description):
self.owner = owner self.owner = owner
self.relay_handshake = relay_handshake self.relay_handshake = relay_handshake
self.start = start self.start = start
self._description = description self._description = description
def cancel(d): def cancel(d):
self._cancelled = True self._cancelled = True
self._d = defer.Deferred(cancel) self._d = defer.Deferred(cancel)
self._start_negotiation_called = False self._start_negotiation_called = False
self._cancelled = False self._cancelled = False
@ -462,6 +537,7 @@ class MockConnection:
self._start_negotiation_called = True self._start_negotiation_called = True
return self._d return self._d
class InboundConnectionFactory(unittest.TestCase): class InboundConnectionFactory(unittest.TestCase):
def test_describe(self): def test_describe(self):
f = transit.InboundConnectionFactory(None) f = transit.InboundConnectionFactory(None)
@ -472,8 +548,8 @@ class InboundConnectionFactory(unittest.TestCase):
addr6 = address.IPv6Address("TCP", "::1", 1234) addr6 = address.IPv6Address("TCP", "::1", 1234)
self.assertEqual(f._describePeer(addr6), "<-::1:1234") self.assertEqual(f._describePeer(addr6), "<-::1:1234")
addrU = address.UNIXAddress("/dev/unlikely") addrU = address.UNIXAddress("/dev/unlikely")
self.assertEqual(f._describePeer(addrU), self.assertEqual(
"<-UNIXAddress('/dev/unlikely')") f._describePeer(addrU), "<-UNIXAddress('/dev/unlikely')")
def test_success(self): def test_success(self):
f = transit.InboundConnectionFactory("owner") f = transit.InboundConnectionFactory("owner")
@ -586,8 +662,10 @@ class InboundConnectionFactory(unittest.TestCase):
self.assertEqual(p1._cancelled, True) self.assertEqual(p1._cancelled, True)
self.assertEqual(p2._cancelled, True) self.assertEqual(p2._cancelled, True)
# XXX check descriptions # XXX check descriptions
class OutboundConnectionFactory(unittest.TestCase): class OutboundConnectionFactory(unittest.TestCase):
def test_success(self): def test_success(self):
f = transit.OutboundConnectionFactory("owner", "relay_handshake", f = transit.OutboundConnectionFactory("owner", "relay_handshake",
@ -609,25 +687,33 @@ class OutboundConnectionFactory(unittest.TestCase):
class MockOwner: class MockOwner:
_connection_ready_called = False _connection_ready_called = False
def connection_ready(self, connection): def connection_ready(self, connection):
self._connection_ready_called = True self._connection_ready_called = True
self._connection = connection self._connection = connection
return self._state return self._state
def _send_this(self): def _send_this(self):
return b"send_this" return b"send_this"
def _expect_this(self): def _expect_this(self):
return b"expect_this" return b"expect_this"
def _sender_record_key(self): def _sender_record_key(self):
return b"s" * 32 return b"s" * 32
def _receiver_record_key(self): def _receiver_record_key(self):
return b"r" * 32 return b"r" * 32
class MockFactory: class MockFactory:
_connectionWasMade_called = False _connectionWasMade_called = False
def connectionWasMade(self, p): def connectionWasMade(self, p):
self._connectionWasMade_called = True self._connectionWasMade_called = True
self._p = p self._p = p
class Connection(unittest.TestCase): class Connection(unittest.TestCase):
# exercise the Connection protocol class # exercise the Connection protocol class
@ -640,8 +726,8 @@ class Connection(unittest.TestCase):
c.buf = b"unexpected" c.buf = b"unexpected"
e = self.assertRaises(transit.BadHandshake, c._check_and_remove, EXP) e = self.assertRaises(transit.BadHandshake, c._check_and_remove, EXP)
self.assertEqual(str(e), self.assertEqual(
"got %r want %r" % (b'unexpected', b'expectation')) str(e), "got %r want %r" % (b'unexpected', b'expectation'))
self.assertEqual(c.buf, b"unexpected") self.assertEqual(c.buf, b"unexpected")
c.buf = b"expect" c.buf = b"expect"
@ -813,8 +899,8 @@ class Connection(unittest.TestCase):
self.assertEqual(c.state, "hung up") self.assertEqual(c.state, "hung up")
f = self.failureResultOf(d, transit.BadHandshake) f = self.failureResultOf(d, transit.BadHandshake)
self.assertEqual(str(f.value), self.assertEqual(
"got %r want %r" % (b"not ok\n", b"ok\n")) str(f.value), "got %r want %r" % (b"not ok\n", b"ok\n"))
def test_receiver_accepted(self): def test_receiver_accepted(self):
# we're on the receiving side, so we wait for the sender to decide # we're on the receiving side, so we wait for the sender to decide
@ -870,8 +956,8 @@ class Connection(unittest.TestCase):
self.assertEqual(t._connected, False) self.assertEqual(t._connected, False)
self.assertEqual(c.state, "hung up") self.assertEqual(c.state, "hung up")
f = self.failureResultOf(d, transit.BadHandshake) f = self.failureResultOf(d, transit.BadHandshake)
self.assertEqual(str(f.value), self.assertEqual(
"got %r want %r" % (b"nevermind\n", b"go\n")) str(f.value), "got %r want %r" % (b"nevermind\n", b"go\n"))
def test_receiver_rejected_rudely(self): def test_receiver_rejected_rudely(self):
# we're on the receiving side, so we wait for the sender to decide # we're on the receiving side, so we wait for the sender to decide
@ -901,7 +987,6 @@ class Connection(unittest.TestCase):
f = self.failureResultOf(d, transit.BadHandshake) f = self.failureResultOf(d, transit.BadHandshake)
self.assertEqual(str(f.value), "connection lost") self.assertEqual(str(f.value), "connection lost")
def test_cancel(self): def test_cancel(self):
owner = MockOwner() owner = MockOwner()
factory = MockFactory() factory = MockFactory()
@ -926,8 +1011,10 @@ class Connection(unittest.TestCase):
factory = MockFactory() factory = MockFactory()
addr = address.HostnameAddress("example.com", 1234) addr = address.HostnameAddress("example.com", 1234)
c = transit.Connection(owner, None, None, "description") c = transit.Connection(owner, None, None, "description")
def _callLater(period, func): def _callLater(period, func):
clock.callLater(period, func) clock.callLater(period, func)
c.callLater = _callLater c.callLater = _callLater
self.assertEqual(c.state, "too-early") self.assertEqual(c.state, "too-early")
t = c.transport = FakeTransport(c, addr) t = c.transport = FakeTransport(c, addr)
@ -1273,6 +1360,7 @@ class Connection(unittest.TestCase):
c.unregisterProducer() c.unregisterProducer()
self.assertEqual(c.transport.producer, None) self.assertEqual(c.transport.producer, None)
class FileConsumer(unittest.TestCase): class FileConsumer(unittest.TestCase):
def test_basic(self): def test_basic(self):
f = io.BytesIO() f = io.BytesIO()
@ -1305,22 +1393,43 @@ class FileConsumer(unittest.TestCase):
self.assertEqual(hashee, [b"." * 99, b"!"]) self.assertEqual(hashee, [b"." * 99, b"!"])
DIRECT_HINT_JSON = {"type": "direct-tcp-v1", DIRECT_HINT_JSON = {
"hostname": "direct", "port": 1234} "type": "direct-tcp-v1",
RELAY_HINT_JSON = {"type": "relay-v1", "hostname": "direct",
"hints": [{"type": "direct-tcp-v1", "port": 1234
"hostname": "relay", "port": 1234}]} }
UNRECOGNIZED_DIRECT_HINT_JSON = {"type": "direct-tcp-v1", RELAY_HINT_JSON = {
"hostname": ["cannot", "parse", "list"]} "type": "relay-v1",
"hints": [{
"type": "direct-tcp-v1",
"hostname": "relay",
"port": 1234
}]
}
UNRECOGNIZED_DIRECT_HINT_JSON = {
"type": "direct-tcp-v1",
"hostname": ["cannot", "parse", "list"]
}
UNRECOGNIZED_HINT_JSON = {"type": "unknown"} UNRECOGNIZED_HINT_JSON = {"type": "unknown"}
UNAVAILABLE_HINT_JSON = {"type": "direct-tcp-v1", # e.g. Tor without txtorcon UNAVAILABLE_HINT_JSON = {
"hostname": "unavailable", "port": 1234} "type": "direct-tcp-v1", # e.g. Tor without txtorcon
RELAY_HINT2_JSON = {"type": "relay-v1", "hostname": "unavailable",
"hints": [{"type": "direct-tcp-v1", "port": 1234
"hostname": "relay", "port": 1234}, }
UNRECOGNIZED_HINT_JSON]} RELAY_HINT2_JSON = {
UNAVAILABLE_RELAY_HINT_JSON = {"type": "relay-v1", "type":
"hints": [UNAVAILABLE_HINT_JSON]} "relay-v1",
"hints": [{
"type": "direct-tcp-v1",
"hostname": "relay",
"port": 1234
}, UNRECOGNIZED_HINT_JSON]
}
UNAVAILABLE_RELAY_HINT_JSON = {
"type": "relay-v1",
"hints": [UNAVAILABLE_HINT_JSON]
}
class Transit(unittest.TestCase): class Transit(unittest.TestCase):
def setUp(self): def setUp(self):
@ -1342,9 +1451,10 @@ class Transit(unittest.TestCase):
s.set_transit_key(b"key") s.set_transit_key(b"key")
hints = yield s.get_connection_hints() # start the listener hints = yield s.get_connection_hints() # start the listener
del hints del hints
s.add_connection_hints([DIRECT_HINT_JSON, s.add_connection_hints([
UNRECOGNIZED_DIRECT_HINT_JSON, DIRECT_HINT_JSON, UNRECOGNIZED_DIRECT_HINT_JSON,
UNRECOGNIZED_HINT_JSON]) UNRECOGNIZED_HINT_JSON
])
s._start_connector = self._start_connector s._start_connector = self._start_connector
d = s.connect() d = s.connect()
@ -1411,9 +1521,8 @@ class Transit(unittest.TestCase):
s.set_transit_key(b"key") s.set_transit_key(b"key")
hints = yield s.get_connection_hints() hints = yield s.get_connection_hints()
del hints del hints
s.add_connection_hints([DIRECT_HINT_JSON, s.add_connection_hints(
UNRECOGNIZED_HINT_JSON, [DIRECT_HINT_JSON, UNRECOGNIZED_HINT_JSON, RELAY_HINT_JSON])
RELAY_HINT_JSON])
s._endpoint_from_hint_obj = self._endpoint_from_hint_obj s._endpoint_from_hint_obj = self._endpoint_from_hint_obj
s._start_connector = self._start_connector s._start_connector = self._start_connector
@ -1437,19 +1546,45 @@ class Transit(unittest.TestCase):
hints = yield s.get_connection_hints() hints = yield s.get_connection_hints()
del hints del hints
s.add_connection_hints([ s.add_connection_hints([
{"type": "relay-v1", {
"hints": [{"type": "direct-tcp-v1", "type":
"hostname": "relay", "port": 1234}]}, "relay-v1",
{"type": "direct-tcp-v1", "hints": [{
"hostname": "direct", "port": 1234}, "type": "direct-tcp-v1",
{"type": "relay-v1", "hostname": "relay",
"hints": [{"type": "direct-tcp-v1", "priority": 2.0, "port": 1234
"hostname": "relay2", "port": 1234}, }]
{"type": "direct-tcp-v1", "priority": 3.0, },
"hostname": "relay3", "port": 1234}]}, {
{"type": "relay-v1", "type": "direct-tcp-v1",
"hints": [{"type": "direct-tcp-v1", "priority": 2.0, "hostname": "direct",
"hostname": "relay4", "port": 1234}]}, "port": 1234
},
{
"type":
"relay-v1",
"hints": [{
"type": "direct-tcp-v1",
"priority": 2.0,
"hostname": "relay2",
"port": 1234
}, {
"type": "direct-tcp-v1",
"priority": 3.0,
"hostname": "relay3",
"port": 1234
}]
},
{
"type":
"relay-v1",
"hints": [{
"type": "direct-tcp-v1",
"priority": 2.0,
"hostname": "relay4",
"port": 1234
}]
},
]) ])
s._endpoint_from_hint_obj = self._endpoint_from_hint_obj s._endpoint_from_hint_obj = self._endpoint_from_hint_obj
s._start_connector = self._start_connector s._start_connector = self._start_connector
@ -1485,10 +1620,10 @@ class Transit(unittest.TestCase):
hints = yield s.get_connection_hints() # start the listener hints = yield s.get_connection_hints() # start the listener
del hints del hints
# include hints that can't be turned into an endpoint at runtime # include hints that can't be turned into an endpoint at runtime
s.add_connection_hints([UNRECOGNIZED_HINT_JSON, s.add_connection_hints([
UNAVAILABLE_HINT_JSON, UNRECOGNIZED_HINT_JSON, UNAVAILABLE_HINT_JSON, RELAY_HINT2_JSON,
RELAY_HINT2_JSON, UNAVAILABLE_RELAY_HINT_JSON
UNAVAILABLE_RELAY_HINT_JSON]) ])
s._endpoint_from_hint_obj = self._endpoint_from_hint_obj s._endpoint_from_hint_obj = self._endpoint_from_hint_obj
s._start_connector = self._start_connector s._start_connector = self._start_connector
@ -1519,6 +1654,7 @@ class Transit(unittest.TestCase):
f = self.failureResultOf(d, transit.TransitError) f = self.failureResultOf(d, transit.TransitError)
self.assertEqual(str(f.value), "No contenders for connection") self.assertEqual(str(f.value), "No contenders for connection")
class RelayHandshake(unittest.TestCase): class RelayHandshake(unittest.TestCase):
def old_build_relay_handshake(self, key): def old_build_relay_handshake(self, key):
token = transit.HKDF(key, 32, CTXinfo=b"transit_relay_token") token = transit.HKDF(key, 32, CTXinfo=b"transit_relay_token")
@ -1548,7 +1684,8 @@ class RelayHandshake(unittest.TestCase):
tc.dataReceived(new_handshake[:-1]) tc.dataReceived(new_handshake[:-1])
self.assertEqual(tc.factory.connection_got_token.mock_calls, []) self.assertEqual(tc.factory.connection_got_token.mock_calls, [])
tc.dataReceived(new_handshake[-1:]) tc.dataReceived(new_handshake[-1:])
self.assertEqual(tc.factory.connection_got_token.mock_calls, self.assertEqual(
tc.factory.connection_got_token.mock_calls,
[mock.call(hexlify(token), c._side.encode("ascii"), tc)]) [mock.call(hexlify(token), c._side.encode("ascii"), tc)])

View File

@ -1,10 +1,15 @@
from __future__ import unicode_literals from __future__ import unicode_literals
import six
import mock
import unicodedata import unicodedata
import six
from twisted.trial import unittest from twisted.trial import unittest
import mock
from .. import util from .. import util
class Utils(unittest.TestCase): class Utils(unittest.TestCase):
def test_to_bytes(self): def test_to_bytes(self):
b = util.to_bytes("abc") b = util.to_bytes("abc")
@ -41,11 +46,12 @@ class Utils(unittest.TestCase):
self.assertIsInstance(d, dict) self.assertIsInstance(d, dict)
self.assertEqual(d, {"a": "b", "c": 2}) self.assertEqual(d, {"a": "b", "c": 2})
class Space(unittest.TestCase): class Space(unittest.TestCase):
def test_free_space(self): def test_free_space(self):
free = util.estimate_free_space(".") free = util.estimate_free_space(".")
self.assert_(isinstance(free, six.integer_types + (type(None),)), self.assert_(
repr(free)) isinstance(free, six.integer_types + (type(None), )), repr(free))
# some platforms (I think the VMs used by travis are in this # some platforms (I think the VMs used by travis are in this
# category) return 0, and windows will return None, so don't assert # category) return 0, and windows will return None, so don't assert
# anything more specific about the return value # anything more specific about the return value

View File

@ -1,8 +1,12 @@
from __future__ import print_function, unicode_literals from __future__ import print_function, unicode_literals
import mock
from twisted.trial import unittest from twisted.trial import unittest
import mock
from .._wordlist import PGPWordList from .._wordlist import PGPWordList
class Completions(unittest.TestCase): class Completions(unittest.TestCase):
def test_completions(self): def test_completions(self):
wl = PGPWordList() wl = PGPWordList()
@ -14,16 +18,21 @@ class Completions(unittest.TestCase):
self.assertEqual(len(lots), 256, lots) self.assertEqual(len(lots), 256, lots)
first = list(lots)[0] first = list(lots)[0]
self.assert_(first.startswith("armistice-"), first) self.assert_(first.startswith("armistice-"), first)
self.assertEqual(gc("armistice-ba", 2), self.assertEqual(
{"armistice-baboon", "armistice-backfield", gc("armistice-ba", 2), {
"armistice-backward", "armistice-banjo"}) "armistice-baboon", "armistice-backfield",
self.assertEqual(gc("armistice-ba", 3), "armistice-backward", "armistice-banjo"
{"armistice-baboon-", "armistice-backfield-", })
"armistice-backward-", "armistice-banjo-"}) self.assertEqual(
gc("armistice-ba", 3), {
"armistice-baboon-", "armistice-backfield-",
"armistice-backward-", "armistice-banjo-"
})
self.assertEqual(gc("armistice-baboon", 2), {"armistice-baboon"}) self.assertEqual(gc("armistice-baboon", 2), {"armistice-baboon"})
self.assertEqual(gc("armistice-baboon", 3), {"armistice-baboon-"}) self.assertEqual(gc("armistice-baboon", 3), {"armistice-baboon-"})
self.assertEqual(gc("armistice-baboon", 4), {"armistice-baboon-"}) self.assertEqual(gc("armistice-baboon", 4), {"armistice-baboon-"})
class Choose(unittest.TestCase): class Choose(unittest.TestCase):
def test_choose_words(self): def test_choose_words(self):
wl = PGPWordList() wl = PGPWordList()

View File

@ -1,17 +1,22 @@
from __future__ import print_function, unicode_literals from __future__ import print_function, unicode_literals
import io, re
import mock import io
from twisted.trial import unittest import re
from twisted.internet import reactor from twisted.internet import reactor
from twisted.internet.defer import gatherResults, inlineCallbacks, returnValue from twisted.internet.defer import gatherResults, inlineCallbacks, returnValue
from twisted.internet.error import ConnectionRefusedError from twisted.internet.error import ConnectionRefusedError
from .common import ServerBase, poll_until from twisted.trial import unittest
from .. import wormhole, _rendezvous
from ..errors import (WrongPasswordError, ServerConnectionError, import mock
KeyFormatError, WormholeClosed, LonelyError,
NoKeyError, OnlyOneCodeError) from .. import _rendezvous, wormhole
from ..transit import allocate_tcp_port from ..errors import (KeyFormatError, LonelyError, NoKeyError,
OnlyOneCodeError, ServerConnectionError, WormholeClosed,
WrongPasswordError)
from ..eventual import EventualQueue from ..eventual import EventualQueue
from ..transit import allocate_tcp_port
from .common import ServerBase, poll_until
APPID = "appid" APPID = "appid"
@ -26,6 +31,7 @@ APPID = "appid"
# * set_code, then connected # * set_code, then connected
# * connected, receive_pake, send_phase, set_code # * connected, receive_pake, send_phase, set_code
class Delegate: class Delegate:
def __init__(self): def __init__(self):
self.welcome = None self.welcome = None
@ -35,23 +41,30 @@ class Delegate:
self.versions = None self.versions = None
self.messages = [] self.messages = []
self.closed = None self.closed = None
def wormhole_got_welcome(self, welcome): def wormhole_got_welcome(self, welcome):
self.welcome = welcome self.welcome = welcome
def wormhole_got_code(self, code): def wormhole_got_code(self, code):
self.code = code self.code = code
def wormhole_got_unverified_key(self, key): def wormhole_got_unverified_key(self, key):
self.key = key self.key = key
def wormhole_got_verifier(self, verifier): def wormhole_got_verifier(self, verifier):
self.verifier = verifier self.verifier = verifier
def wormhole_got_versions(self, versions): def wormhole_got_versions(self, versions):
self.versions = versions self.versions = versions
def wormhole_got_message(self, data): def wormhole_got_message(self, data):
self.messages.append(data) self.messages.append(data)
def wormhole_closed(self, result): def wormhole_closed(self, result):
self.closed = result self.closed = result
class Delegated(ServerBase, unittest.TestCase):
class Delegated(ServerBase, unittest.TestCase):
@inlineCallbacks @inlineCallbacks
def test_delegated(self): def test_delegated(self):
dg = Delegate() dg = Delegate()
@ -103,6 +116,7 @@ class Delegated(ServerBase, unittest.TestCase):
yield poll_until(lambda: dg.code is not None) yield poll_until(lambda: dg.code is not None)
w1.close() w1.close()
class Wormholes(ServerBase, unittest.TestCase): class Wormholes(ServerBase, unittest.TestCase):
# integration test, with a real server # integration test, with a real server
@ -302,7 +316,6 @@ class Wormholes(ServerBase, unittest.TestCase):
yield w1.close() yield w1.close()
yield w2.close() yield w2.close()
@inlineCallbacks @inlineCallbacks
def test_multiple_messages(self): def test_multiple_messages(self):
w1 = wormhole.create(APPID, self.relayurl, reactor) w1 = wormhole.create(APPID, self.relayurl, reactor)
@ -322,7 +335,6 @@ class Wormholes(ServerBase, unittest.TestCase):
yield w1.close() yield w1.close()
yield w2.close() yield w2.close()
@inlineCallbacks @inlineCallbacks
def test_closed(self): def test_closed(self):
eq = EventualQueue(reactor) eq = EventualQueue(reactor)
@ -525,10 +537,10 @@ class Wormholes(ServerBase, unittest.TestCase):
@inlineCallbacks @inlineCallbacks
def test_versions(self): def test_versions(self):
# there's no API for this yet, but make sure the internals work # there's no API for this yet, but make sure the internals work
w1 = wormhole.create(APPID, self.relayurl, reactor, w1 = wormhole.create(
versions={"w1": 123}) APPID, self.relayurl, reactor, versions={"w1": 123})
w2 = wormhole.create(APPID, self.relayurl, reactor, w2 = wormhole.create(
versions={"w2": 456}) APPID, self.relayurl, reactor, versions={"w2": 456})
w1.allocate_code() w1.allocate_code()
code = yield w1.get_code() code = yield w1.get_code()
w2.set_code(code) w2.set_code(code)
@ -564,6 +576,7 @@ class Wormholes(ServerBase, unittest.TestCase):
yield w1.close() yield w1.close()
yield w2.close() yield w2.close()
class MessageDoubler(_rendezvous.RendezvousConnector): class MessageDoubler(_rendezvous.RendezvousConnector):
# we could double messages on the sending side, but a future server will # we could double messages on the sending side, but a future server will
# strip those duplicates, so to really exercise the receiver, we must # strip those duplicates, so to really exercise the receiver, we must
@ -575,6 +588,7 @@ class MessageDoubler(_rendezvous.RendezvousConnector):
_rendezvous.RendezvousConnector._response_handle_message(self, msg) _rendezvous.RendezvousConnector._response_handle_message(self, msg)
_rendezvous.RendezvousConnector._response_handle_message(self, msg) _rendezvous.RendezvousConnector._response_handle_message(self, msg)
class Errors(ServerBase, unittest.TestCase): class Errors(ServerBase, unittest.TestCase):
@inlineCallbacks @inlineCallbacks
def test_derive_key_early(self): def test_derive_key_early(self):
@ -602,6 +616,7 @@ class Errors(ServerBase, unittest.TestCase):
w.set_code("123-nope") w.set_code("123-nope")
yield self.assertFailure(w.close(), LonelyError) yield self.assertFailure(w.close(), LonelyError)
class Reconnection(ServerBase, unittest.TestCase): class Reconnection(ServerBase, unittest.TestCase):
@inlineCallbacks @inlineCallbacks
def test_basic(self): def test_basic(self):
@ -619,6 +634,7 @@ class Reconnection(ServerBase, unittest.TestCase):
if m["type"] == "message" and m["phase"] == "pake": if m["type"] == "message" and m["phase"] == "pake":
return True return True
return False return False
yield poll_until(seen_our_pake) yield poll_until(seen_our_pake)
w1_in[:] = [] w1_in[:] = []
@ -649,6 +665,7 @@ class Reconnection(ServerBase, unittest.TestCase):
c2 = yield w2.close() c2 = yield w2.close()
self.assertEqual(c2, "happy") self.assertEqual(c2, "happy")
class InitialFailure(unittest.TestCase): class InitialFailure(unittest.TestCase):
@inlineCallbacks @inlineCallbacks
def assertSCEFailure(self, eq, d, innerType): def assertSCEFailure(self, eq, d, innerType):
@ -662,8 +679,8 @@ class InitialFailure(unittest.TestCase):
def test_bad_dns(self): def test_bad_dns(self):
eq = EventualQueue(reactor) eq = EventualQueue(reactor)
# point at a URL that will never connect # point at a URL that will never connect
w = wormhole.create(APPID, "ws://%%%.example.org:4000/v1", w = wormhole.create(
reactor, _eventual_queue=eq) APPID, "ws://%%%.example.org:4000/v1", reactor, _eventual_queue=eq)
# that should have already received an error, when it tried to # that should have already received an error, when it tried to
# resolve the bogus DNS name. All API calls will return an error. # resolve the bogus DNS name. All API calls will return an error.
@ -717,6 +734,7 @@ class InitialFailure(unittest.TestCase):
yield self.assertSCE(d4, ConnectionRefusedError) yield self.assertSCE(d4, ConnectionRefusedError)
yield self.assertSCE(d5, ConnectionRefusedError) yield self.assertSCE(d5, ConnectionRefusedError)
class Trace(unittest.TestCase): class Trace(unittest.TestCase):
def test_basic(self): def test_basic(self):
w1 = wormhole.create(APPID, "ws://localhost:1", reactor) w1 = wormhole.create(APPID, "ws://localhost:1", reactor)
@ -734,13 +752,11 @@ class Trace(unittest.TestCase):
["C1.M1[OLD].IN -> [NEW]"]) ["C1.M1[OLD].IN -> [NEW]"])
out("OUT1") out("OUT1")
self.assertEqual(stderr.getvalue().splitlines(), self.assertEqual(stderr.getvalue().splitlines(),
["C1.M1[OLD].IN -> [NEW]", ["C1.M1[OLD].IN -> [NEW]", " C1.M1.OUT1()"])
" C1.M1.OUT1()"])
w1._boss._print_trace("", "R.connected", "", "C1", "RC1", stderr) w1._boss._print_trace("", "R.connected", "", "C1", "RC1", stderr)
self.assertEqual(stderr.getvalue().splitlines(), self.assertEqual(
["C1.M1[OLD].IN -> [NEW]", stderr.getvalue().splitlines(),
" C1.M1.OUT1()", ["C1.M1[OLD].IN -> [NEW]", " C1.M1.OUT1()", "C1.RC1.R.connected"])
"C1.RC1.R.connected"])
def test_delegated(self): def test_delegated(self):
dg = Delegate() dg = Delegate()

View File

@ -1,11 +1,13 @@
from twisted.trial import unittest from twisted.internet import defer, reactor
from twisted.internet import reactor, defer
from twisted.internet.defer import inlineCallbacks from twisted.internet.defer import inlineCallbacks
from twisted.trial import unittest
from .. import xfer_util from .. import xfer_util
from .common import ServerBase from .common import ServerBase
APPID = u"appid" APPID = u"appid"
class Xfer(ServerBase, unittest.TestCase): class Xfer(ServerBase, unittest.TestCase):
@inlineCallbacks @inlineCallbacks
def test_xfer(self): def test_xfer(self):
@ -24,10 +26,15 @@ class Xfer(ServerBase, unittest.TestCase):
data = u"data" data = u"data"
send_code = [] send_code = []
receive_code = [] receive_code = []
d1 = xfer_util.send(reactor, APPID, self.relayurl, data, code, d1 = xfer_util.send(
reactor,
APPID,
self.relayurl,
data,
code,
on_code=send_code.append) on_code=send_code.append)
d2 = xfer_util.receive(reactor, APPID, self.relayurl, code, d2 = xfer_util.receive(
on_code=receive_code.append) reactor, APPID, self.relayurl, code, on_code=receive_code.append)
send_result = yield d1 send_result = yield d1
receive_result = yield d2 receive_result = yield d2
self.assertEqual(send_code, [code]) self.assertEqual(send_code, [code])
@ -39,7 +46,12 @@ class Xfer(ServerBase, unittest.TestCase):
def test_make_code(self): def test_make_code(self):
data = u"data" data = u"data"
got_code = defer.Deferred() got_code = defer.Deferred()
d1 = xfer_util.send(reactor, APPID, self.relayurl, data, code=None, d1 = xfer_util.send(
reactor,
APPID,
self.relayurl,
data,
code=None,
on_code=got_code.callback) on_code=got_code.callback)
code = yield got_code code = yield got_code
d2 = xfer_util.receive(reactor, APPID, self.relayurl, code) d2 = xfer_util.receive(reactor, APPID, self.relayurl, code)

View File

@ -1,8 +1,13 @@
from __future__ import print_function, absolute_import, unicode_literals from __future__ import absolute_import, print_function, unicode_literals
import json, time
import json
import time
from zope.interface import implementer from zope.interface import implementer
from ._interfaces import ITiming from ._interfaces import ITiming
class Event: class Event:
def __init__(self, name, when, **details): def __init__(self, name, when, **details):
# data fields that will be dumped to JSON later # data fields that will be dumped to JSON later
@ -35,6 +40,7 @@ class Event:
else: else:
self.finish() self.finish()
@implementer(ITiming) @implementer(ITiming)
class DebugTiming: class DebugTiming:
def __init__(self): def __init__(self):
@ -47,11 +53,14 @@ class DebugTiming:
def write(self, fn, stderr): def write(self, fn, stderr):
with open(fn, "wt") as f: with open(fn, "wt") as f:
data = [ dict(name=e._name, data = [
start=e._start, stop=e._stop, dict(
name=e._name,
start=e._start,
stop=e._stop,
details=e._details, details=e._details,
) ) for e in self._events
for e in self._events ] ]
json.dump(data, f, indent=1) json.dump(data, f, indent=1)
f.write("\n") f.write("\n")
print("Timing data written to %s" % fn, file=stderr) print("Timing data written to %s" % fn, file=stderr)

View File

@ -1,15 +1,20 @@
from __future__ import print_function, unicode_literals from __future__ import print_function, unicode_literals
import sys import sys
from attr import attrs, attrib
from zope.interface.declarations import directlyProvides from attr import attrib, attrs
from twisted.internet.defer import inlineCallbacks, returnValue from twisted.internet.defer import inlineCallbacks, returnValue
from twisted.internet.endpoints import clientFromString from twisted.internet.endpoints import clientFromString
from zope.interface.declarations import directlyProvides
from . import _interfaces, errors
from .timing import DebugTiming
try: try:
import txtorcon import txtorcon
except ImportError: except ImportError:
txtorcon = None txtorcon = None
from . import _interfaces, errors
from .timing import DebugTiming
@attrs @attrs
class SocksOnlyTor(object): class SocksOnlyTor(object):
@ -17,15 +22,20 @@ class SocksOnlyTor(object):
def stream_via(self, host, port, tls=False): def stream_via(self, host, port, tls=False):
return txtorcon.TorClientEndpoint( return txtorcon.TorClientEndpoint(
host, port, host,
port,
socks_endpoint=None, # tries localhost:9050 and 9150 socks_endpoint=None, # tries localhost:9050 and 9150
tls=tls, tls=tls,
reactor=self._reactor, reactor=self._reactor,
) )
@inlineCallbacks @inlineCallbacks
def get_tor(reactor, launch_tor=False, tor_control_port=None, def get_tor(reactor,
timing=None, stderr=sys.stderr): launch_tor=False,
tor_control_port=None,
timing=None,
stderr=sys.stderr):
""" """
If launch_tor=True, I will try to launch a new Tor process, ask it If launch_tor=True, I will try to launch a new Tor process, ask it
for its SOCKS and control ports, and use those for outbound for its SOCKS and control ports, and use those for outbound
@ -74,7 +84,8 @@ def get_tor(reactor, launch_tor=False, tor_control_port=None,
# need the control port. # need the control port.
if launch_tor: if launch_tor:
print(" launching a new Tor process, this may take a while..", print(
" launching a new Tor process, this may take a while..",
file=stderr) file=stderr)
with timing.add("launch tor"): with timing.add("launch tor"):
tor = yield txtorcon.launch(reactor, tor = yield txtorcon.launch(reactor,
@ -85,7 +96,8 @@ def get_tor(reactor, launch_tor=False, tor_control_port=None,
with timing.add("find tor"): with timing.add("find tor"):
control_ep = clientFromString(reactor, tor_control_port) control_ep = clientFromString(reactor, tor_control_port)
tor = yield txtorcon.connect(reactor, control_ep) # might raise tor = yield txtorcon.connect(reactor, control_ep) # might raise
print(" using Tor via control port at %s" % tor_control_port, print(
" using Tor via control port at %s" % tor_control_port,
file=stderr) file=stderr)
else: else:
# Let txtorcon look through a list of usual places. If that fails, # Let txtorcon look through a list of usual places. If that fails,
@ -98,7 +110,8 @@ def get_tor(reactor, launch_tor=False, tor_control_port=None,
# TODO: make this more specific. I think connect() is # TODO: make this more specific. I think connect() is
# likely to throw a reactor.connectTCP -type error, like # likely to throw a reactor.connectTCP -type error, like
# ConnectionFailed or ConnectionRefused or something # ConnectionFailed or ConnectionRefused or something
print(" unable to find default Tor control port, using SOCKS", print(
" unable to find default Tor control port, using SOCKS",
file=stderr) file=stderr)
tor = SocksOnlyTor(reactor) tor = SocksOnlyTor(reactor)
directlyProvides(tor, _interfaces.ITorManager) directlyProvides(tor, _interfaces.ITorManager)

View File

@ -1,38 +1,51 @@
# no unicode_literals, revisit after twisted patch # no unicode_literals, revisit after twisted patch
from __future__ import print_function, absolute_import from __future__ import absolute_import, print_function
import os, re, sys, time, socket
from collections import namedtuple, deque import os
import re
import socket
import sys
import time
from binascii import hexlify, unhexlify from binascii import hexlify, unhexlify
from collections import deque, namedtuple
import six import six
from zope.interface import implementer from hkdf import Hkdf
from twisted.python import log from nacl.secret import SecretBox
from twisted.python.runtime import platformType from twisted.internet import (address, defer, endpoints, error, interfaces,
from twisted.internet import (reactor, interfaces, defer, protocol, protocol, reactor, task)
endpoints, task, address, error)
from twisted.internet.defer import inlineCallbacks, returnValue from twisted.internet.defer import inlineCallbacks, returnValue
from twisted.protocols import policies from twisted.protocols import policies
from nacl.secret import SecretBox from twisted.python import log
from hkdf import Hkdf from twisted.python.runtime import platformType
from zope.interface import implementer
from . import ipaddrs
from .errors import InternalError from .errors import InternalError
from .timing import DebugTiming from .timing import DebugTiming
from .util import bytes_to_hexstr from .util import bytes_to_hexstr
from . import ipaddrs
def HKDF(skm, outlen, salt=None, CTXinfo=b""): def HKDF(skm, outlen, salt=None, CTXinfo=b""):
return Hkdf(salt, skm).expand(CTXinfo, outlen) return Hkdf(salt, skm).expand(CTXinfo, outlen)
class TransitError(Exception): class TransitError(Exception):
pass pass
class BadHandshake(Exception): class BadHandshake(Exception):
pass pass
class TransitClosed(TransitError): class TransitClosed(TransitError):
pass pass
class BadNonce(TransitError): class BadNonce(TransitError):
pass pass
# The beginning of each TCP connection consists of the following handshake # The beginning of each TCP connection consists of the following handshake
# messages. The sender transmits the same text regardless of whether it is on # messages. The sender transmits the same text regardless of whether it is on
# the initiating/connecting end of the TCP connection, or on the # the initiating/connecting end of the TCP connection, or on the
@ -63,19 +76,23 @@ class BadNonce(TransitError):
# RXID_HEX ready\n\n" and then makes a first/not-first decision about sending # RXID_HEX ready\n\n" and then makes a first/not-first decision about sending
# "go\n" or "nevermind\n"+close(). # "go\n" or "nevermind\n"+close().
def build_receiver_handshake(key): def build_receiver_handshake(key):
hexid = HKDF(key, 32, CTXinfo=b"transit_receiver") hexid = HKDF(key, 32, CTXinfo=b"transit_receiver")
return b"transit receiver " + hexlify(hexid) + b" ready\n\n" return b"transit receiver " + hexlify(hexid) + b" ready\n\n"
def build_sender_handshake(key): def build_sender_handshake(key):
hexid = HKDF(key, 32, CTXinfo=b"transit_sender") hexid = HKDF(key, 32, CTXinfo=b"transit_sender")
return b"transit sender " + hexlify(hexid) + b" ready\n\n" return b"transit sender " + hexlify(hexid) + b" ready\n\n"
def build_sided_relay_handshake(key, side): def build_sided_relay_handshake(key, side):
assert isinstance(side, type(u"")) assert isinstance(side, type(u""))
assert len(side) == 8 * 2 assert len(side) == 8 * 2
token = HKDF(key, 32, CTXinfo=b"transit_relay_token") token = HKDF(key, 32, CTXinfo=b"transit_relay_token")
return b"please relay "+hexlify(token)+b" for side "+side.encode("ascii")+b"\n" return b"please relay " + hexlify(token) + b" for side " + side.encode(
"ascii") + b"\n"
# These namedtuples are "hint objects". The JSON-serializable dictionaries # These namedtuples are "hint objects". The JSON-serializable dictionaries
@ -87,7 +104,8 @@ def build_sided_relay_handshake(key, side):
# * expect to see the receiver/sender handshake bytes from the other side # * expect to see the receiver/sender handshake bytes from the other side
# * the sender writes "go\n", the receiver waits for "go\n" # * the sender writes "go\n", the receiver waits for "go\n"
# * the rest of the connection contains transit data # * the rest of the connection contains transit data
DirectTCPV1Hint = namedtuple("DirectTCPV1Hint", ["hostname", "port", "priority"]) DirectTCPV1Hint = namedtuple("DirectTCPV1Hint",
["hostname", "port", "priority"])
TorTCPV1Hint = namedtuple("TorTCPV1Hint", ["hostname", "port", "priority"]) TorTCPV1Hint = namedtuple("TorTCPV1Hint", ["hostname", "port", "priority"])
# RelayV1Hint contains a tuple of DirectTCPV1Hint and TorTCPV1Hint hints (we # RelayV1Hint contains a tuple of DirectTCPV1Hint and TorTCPV1Hint hints (we
# use a tuple rather than a list so they'll be hashable into a set). For each # use a tuple rather than a list so they'll be hashable into a set). For each
@ -95,6 +113,7 @@ TorTCPV1Hint = namedtuple("TorTCPV1Hint", ["hostname", "port", "priority"])
# rest of the V1 protocol. Only one hint per relay is useful. # rest of the V1 protocol. Only one hint per relay is useful.
RelayV1Hint = namedtuple("RelayV1Hint", ["hints"]) RelayV1Hint = namedtuple("RelayV1Hint", ["hints"])
def describe_hint_obj(hint): def describe_hint_obj(hint):
if isinstance(hint, DirectTCPV1Hint): if isinstance(hint, DirectTCPV1Hint):
return u"tcp:%s:%d" % (hint.hostname, hint.port) return u"tcp:%s:%d" % (hint.hostname, hint.port)
@ -103,6 +122,7 @@ def describe_hint_obj(hint):
else: else:
return str(hint) return str(hint)
def parse_hint_argv(hint, stderr=sys.stderr): def parse_hint_argv(hint, stderr=sys.stderr):
assert isinstance(hint, type(u"")) assert isinstance(hint, type(u""))
# return tuple or None for an unparseable hint # return tuple or None for an unparseable hint
@ -113,12 +133,14 @@ def parse_hint_argv(hint, stderr=sys.stderr):
return None return None
hint_type = mo.group(1) hint_type = mo.group(1)
if hint_type != "tcp": if hint_type != "tcp":
print("unknown hint type '%s' in '%s'" % (hint_type, hint), file=stderr) print(
"unknown hint type '%s' in '%s'" % (hint_type, hint), file=stderr)
return None return None
hint_value = mo.group(2) hint_value = mo.group(2)
pieces = hint_value.split(":") pieces = hint_value.split(":")
if len(pieces) < 2: if len(pieces) < 2:
print("unparseable TCP hint (need more colons) '%s'" % (hint,), print(
"unparseable TCP hint (need more colons) '%s'" % (hint, ),
file=stderr) file=stderr)
return None return None
mo = re.search(r'^(\d+)$', pieces[1]) mo = re.search(r'^(\d+)$', pieces[1])
@ -133,13 +155,16 @@ def parse_hint_argv(hint, stderr=sys.stderr):
try: try:
priority = float(more_pieces[1]) priority = float(more_pieces[1])
except ValueError: except ValueError:
print("non-float priority= in TCP hint '%s'" % (hint,), print(
"non-float priority= in TCP hint '%s'" % (hint, ),
file=stderr) file=stderr)
return None return None
return DirectTCPV1Hint(hint_host, hint_port, priority) return DirectTCPV1Hint(hint_host, hint_port, priority)
TIMEOUT = 60 # seconds TIMEOUT = 60 # seconds
@implementer(interfaces.IProducer, interfaces.IConsumer) @implementer(interfaces.IProducer, interfaces.IConsumer)
class Connection(protocol.Protocol, policies.TimeoutMixin): class Connection(protocol.Protocol, policies.TimeoutMixin):
def __init__(self, owner, relay_handshake, start, description): def __init__(self, owner, relay_handshake, start, description):
@ -181,7 +206,6 @@ class Connection(protocol.Protocol, policies.TimeoutMixin):
if self._negotiation_d: if self._negotiation_d:
self._negotiation_d = None self._negotiation_d = None
def dataReceived(self, data): def dataReceived(self, data):
try: try:
self._dataReceived(data) self._dataReceived(data)
@ -277,8 +301,9 @@ class Connection(protocol.Protocol, policies.TimeoutMixin):
nonce_buf = encrypted[:SecretBox.NONCE_SIZE] # assume it's prepended nonce_buf = encrypted[:SecretBox.NONCE_SIZE] # assume it's prepended
nonce = int(hexlify(nonce_buf), 16) nonce = int(hexlify(nonce_buf), 16)
if nonce != self.next_receive_nonce: if nonce != self.next_receive_nonce:
raise BadNonce("received out-of-order record: got %d, expected %d" raise BadNonce(
% (nonce, self.next_receive_nonce)) "received out-of-order record: got %d, expected %d" %
(nonce, self.next_receive_nonce))
self.next_receive_nonce += 1 self.next_receive_nonce += 1
record = self.receive_box.decrypt(encrypted) record = self.receive_box.decrypt(encrypted)
return record return record
@ -287,7 +312,8 @@ class Connection(protocol.Protocol, policies.TimeoutMixin):
return self._description return self._description
def send_record(self, record): def send_record(self, record):
if not isinstance(record, type(b"")): raise InternalError if not isinstance(record, type(b"")):
raise InternalError
assert SecretBox.NONCE_SIZE == 24 assert SecretBox.NONCE_SIZE == 24
assert self.send_nonce < 2**(8 * 24) assert self.send_nonce < 2**(8 * 24)
assert len(record) < 2**(8 * 4) assert len(record) < 2**(8 * 4)
@ -353,8 +379,10 @@ class Connection(protocol.Protocol, policies.TimeoutMixin):
def registerProducer(self, producer, streaming): def registerProducer(self, producer, streaming):
assert interfaces.IConsumer.providedBy(self.transport) assert interfaces.IConsumer.providedBy(self.transport)
self.transport.registerProducer(producer, streaming) self.transport.registerProducer(producer, streaming)
def unregisterProducer(self): def unregisterProducer(self):
self.transport.unregisterProducer() self.transport.unregisterProducer()
def write(self, data): def write(self, data):
self.send_record(data) self.send_record(data)
@ -362,8 +390,10 @@ class Connection(protocol.Protocol, policies.TimeoutMixin):
# the transport. # the transport.
def stopProducing(self): def stopProducing(self):
self.transport.stopProducing() self.transport.stopProducing()
def pauseProducing(self): def pauseProducing(self):
self.transport.pauseProducing() self.transport.pauseProducing()
def resumeProducing(self): def resumeProducing(self):
self.transport.resumeProducing() self.transport.resumeProducing()
@ -384,8 +414,8 @@ class Connection(protocol.Protocol, policies.TimeoutMixin):
Deferred, and you must call disconnectConsumer() when you are done.""" Deferred, and you must call disconnectConsumer() when you are done."""
if self._consumer: if self._consumer:
raise RuntimeError("A consumer is already attached: %r" % raise RuntimeError(
self._consumer) "A consumer is already attached: %r" % self._consumer)
# be aware of an ordering hazard: when we call the consumer's # be aware of an ordering hazard: when we call the consumer's
# .registerProducer method, they are likely to immediately call # .registerProducer method, they are likely to immediately call
@ -440,6 +470,7 @@ class Connection(protocol.Protocol, policies.TimeoutMixin):
fc = FileConsumer(f, progress, hasher) fc = FileConsumer(f, progress, hasher)
return self.connectConsumer(fc, expected) return self.connectConsumer(fc, expected)
class OutboundConnectionFactory(protocol.ClientFactory): class OutboundConnectionFactory(protocol.ClientFactory):
protocol = Connection protocol = Connection
@ -511,6 +542,7 @@ class InboundConnectionFactory(protocol.ClientFactory):
# ignore these two, let Twisted log everything else # ignore these two, let Twisted log everything else
f.trap(BadHandshake, defer.CancelledError) f.trap(BadHandshake, defer.CancelledError)
def allocate_tcp_port(): def allocate_tcp_port():
"""Return an (integer) available TCP port on localhost. This briefly """Return an (integer) available TCP port on localhost. This briefly
listens on the port in question, then closes it right away.""" listens on the port in question, then closes it right away."""
@ -527,6 +559,7 @@ def allocate_tcp_port():
s.close() s.close()
return port return port
class _ThereCanBeOnlyOne: class _ThereCanBeOnlyOne:
"""Accept a list of contender Deferreds, and return a summary Deferred. """Accept a list of contender Deferreds, and return a summary Deferred.
When the first contender fires successfully, cancel the rest and fire the When the first contender fires successfully, cancel the rest and fire the
@ -535,6 +568,7 @@ class _ThereCanBeOnlyOne:
status_cb=? status_cb=?
""" """
def __init__(self, contenders): def __init__(self, contenders):
self._remaining = set(contenders) self._remaining = set(contenders)
self._winner_d = defer.Deferred(self._cancel) self._winner_d = defer.Deferred(self._cancel)
@ -581,15 +615,21 @@ class _ThereCanBeOnlyOne:
else: else:
self._winner_d.errback(self._first_failure) self._winner_d.errback(self._first_failure)
def there_can_be_only_one(contenders): def there_can_be_only_one(contenders):
return _ThereCanBeOnlyOne(contenders).run() return _ThereCanBeOnlyOne(contenders).run()
class Common: class Common:
RELAY_DELAY = 2.0 RELAY_DELAY = 2.0
TRANSIT_KEY_LENGTH = SecretBox.KEY_SIZE TRANSIT_KEY_LENGTH = SecretBox.KEY_SIZE
def __init__(self, transit_relay, no_listen=False, tor=None, def __init__(self,
reactor=reactor, timing=None): transit_relay,
no_listen=False,
tor=None,
reactor=reactor,
timing=None):
self._side = bytes_to_hexstr(os.urandom(8)) # unicode self._side = bytes_to_hexstr(os.urandom(8)) # unicode
if transit_relay: if transit_relay:
if not isinstance(transit_relay, type(u"")): if not isinstance(transit_relay, type(u"")):
@ -622,14 +662,20 @@ class Common:
# some test hosts, including the appveyor VMs, *only* have # some test hosts, including the appveyor VMs, *only* have
# 127.0.0.1, and the tests will hang badly if we remove it. # 127.0.0.1, and the tests will hang badly if we remove it.
addresses = non_loopback_addresses addresses = non_loopback_addresses
direct_hints = [DirectTCPV1Hint(six.u(addr), portnum, 0.0) direct_hints = [
for addr in addresses] DirectTCPV1Hint(six.u(addr), portnum, 0.0) for addr in addresses
]
ep = endpoints.serverFromString(reactor, "tcp:%d" % portnum) ep = endpoints.serverFromString(reactor, "tcp:%d" % portnum)
return direct_hints, ep return direct_hints, ep
def get_connection_abilities(self): def get_connection_abilities(self):
return [{u"type": u"direct-tcp-v1"}, return [
{u"type": u"relay-v1"}, {
u"type": u"direct-tcp-v1"
},
{
u"type": u"relay-v1"
},
] ]
@inlineCallbacks @inlineCallbacks
@ -637,7 +683,8 @@ class Common:
hints = [] hints = []
direct_hints = yield self._get_direct_hints() direct_hints = yield self._get_direct_hints()
for dh in direct_hints: for dh in direct_hints:
hints.append({u"type": u"direct-tcp-v1", hints.append({
u"type": u"direct-tcp-v1",
u"priority": dh.priority, u"priority": dh.priority,
u"hostname": dh.hostname, u"hostname": dh.hostname,
u"port": dh.port, # integer u"port": dh.port, # integer
@ -645,10 +692,12 @@ class Common:
for relay in self._transit_relays: for relay in self._transit_relays:
rhint = {u"type": u"relay-v1", u"hints": []} rhint = {u"type": u"relay-v1", u"hints": []}
for rh in relay.hints: for rh in relay.hints:
rhint[u"hints"].append({u"type": u"direct-tcp-v1", rhint[u"hints"].append({
u"type": u"direct-tcp-v1",
u"priority": rh.priority, u"priority": rh.priority,
u"hostname": rh.hostname, u"hostname": rh.hostname,
u"port": rh.port}) u"port": rh.port
})
hints.append(rhint) hints.append(rhint)
returnValue(hints) returnValue(hints)
@ -675,14 +724,17 @@ class Common:
self._listener_f = f # for tests # XX move to __init__ ? self._listener_f = f # for tests # XX move to __init__ ?
self._listener_d = f.whenDone() self._listener_d = f.whenDone()
d = self._listener.listen(f) d = self._listener.listen(f)
def _listening(lp): def _listening(lp):
# lp is an IListeningPort # lp is an IListeningPort
# self._listener_port = lp # for tests # self._listener_port = lp # for tests
def _stop_listening(res): def _stop_listening(res):
lp.stopListening() lp.stopListening()
return res return res
self._listener_d.addBoth(_stop_listening) self._listener_d.addBoth(_stop_listening)
return self._my_direct_hints return self._my_direct_hints
d.addCallback(_listening) d.addCallback(_listening)
return d return d
@ -699,12 +751,12 @@ class Common:
if hint_type not in [u"direct-tcp-v1", u"tor-tcp-v1"]: if hint_type not in [u"direct-tcp-v1", u"tor-tcp-v1"]:
log.msg("unknown hint type: %r" % (hint, )) log.msg("unknown hint type: %r" % (hint, ))
return None return None
if not(u"hostname" in hint if not (u"hostname" in hint and
and isinstance(hint[u"hostname"], type(u""))): isinstance(hint[u"hostname"], type(u""))):
log.msg("invalid hostname in hint: %r" % (hint, )) log.msg("invalid hostname in hint: %r" % (hint, ))
return None return None
if not(u"port" in hint if not (u"port" in hint and
and isinstance(hint[u"port"], six.integer_types)): isinstance(hint[u"port"], six.integer_types)):
log.msg("invalid port in hint: %r" % (hint, )) log.msg("invalid port in hint: %r" % (hint, ))
return None return None
priority = hint.get(u"priority", 0.0) priority = hint.get(u"priority", 0.0)
@ -753,19 +805,27 @@ class Common:
def _sender_record_key(self): def _sender_record_key(self):
assert self._transit_key assert self._transit_key
if self.is_sender: if self.is_sender:
return HKDF(self._transit_key, SecretBox.KEY_SIZE, return HKDF(
self._transit_key,
SecretBox.KEY_SIZE,
CTXinfo=b"transit_record_sender_key") CTXinfo=b"transit_record_sender_key")
else: else:
return HKDF(self._transit_key, SecretBox.KEY_SIZE, return HKDF(
self._transit_key,
SecretBox.KEY_SIZE,
CTXinfo=b"transit_record_receiver_key") CTXinfo=b"transit_record_receiver_key")
def _receiver_record_key(self): def _receiver_record_key(self):
assert self._transit_key assert self._transit_key
if self.is_sender: if self.is_sender:
return HKDF(self._transit_key, SecretBox.KEY_SIZE, return HKDF(
self._transit_key,
SecretBox.KEY_SIZE,
CTXinfo=b"transit_record_receiver_key") CTXinfo=b"transit_record_receiver_key")
else: else:
return HKDF(self._transit_key, SecretBox.KEY_SIZE, return HKDF(
self._transit_key,
SecretBox.KEY_SIZE,
CTXinfo=b"transit_record_sender_key") CTXinfo=b"transit_record_sender_key")
def set_transit_key(self, key): def set_transit_key(self, key):
@ -848,8 +908,12 @@ class Common:
description = "->relay:%s" % describe_hint_obj(hint_obj) description = "->relay:%s" % describe_hint_obj(hint_obj)
if self._tor: if self._tor:
description = "tor" + description description = "tor" + description
d = task.deferLater(self._reactor, relay_delay, d = task.deferLater(
self._start_connector, ep, description, self._reactor,
relay_delay,
self._start_connector,
ep,
description,
is_relay=True) is_relay=True)
contenders.append(d) contenders.append(d)
relay_delay += self.RELAY_DELAY relay_delay += self.RELAY_DELAY
@ -864,10 +928,12 @@ class Common:
"""If the timer fires first, cancel the deferred. If the deferred fires """If the timer fires first, cancel the deferred. If the deferred fires
first, cancel the timer.""" first, cancel the timer."""
t = self._reactor.callLater(timeout, d.cancel) t = self._reactor.callLater(timeout, d.cancel)
def _done(res): def _done(res):
if t.active(): if t.active():
t.cancel() t.cancel()
return res return res
d.addBoth(_done) d.addBoth(_done)
return d return d
@ -896,8 +962,8 @@ class Common:
return None return None
return None return None
if isinstance(hint, DirectTCPV1Hint): if isinstance(hint, DirectTCPV1Hint):
return endpoints.HostnameEndpoint(self._reactor, return endpoints.HostnameEndpoint(self._reactor, hint.hostname,
hint.hostname, hint.port) hint.port)
return None return None
def connection_ready(self, p): def connection_ready(self, p):
@ -915,9 +981,11 @@ class Common:
self._winner = p self._winner = p
return "go" return "go"
class TransitSender(Common): class TransitSender(Common):
is_sender = True is_sender = True
class TransitReceiver(Common): class TransitReceiver(Common):
is_sender = False is_sender = False
@ -926,6 +994,7 @@ class TransitReceiver(Common):
# when done, and add a progress function that gets called with the length of # when done, and add a progress function that gets called with the length of
# each write, and a hasher function that gets called with the data. # each write, and a hasher function that gets called with the data.
@implementer(interfaces.IConsumer) @implementer(interfaces.IConsumer)
class FileConsumer: class FileConsumer:
def __init__(self, f, progress=None, hasher=None): def __init__(self, f, progress=None, hasher=None):
@ -950,6 +1019,7 @@ class FileConsumer:
assert self._producer assert self._producer
self._producer = None self._producer = None
# the TransitSender/Receiver.connect() yields a Connection, on which you can # the TransitSender/Receiver.connect() yields a Connection, on which you can
# do send_record(), but what should the receive API be? set a callback for # do send_record(), but what should the receive API be? set a callback for
# inbound records? get a Deferred for the next record? The producer/consumer # inbound records? get a Deferred for the next record? The producer/consumer

View File

@ -1,30 +1,42 @@
# No unicode_literals # No unicode_literals
import os, json, unicodedata import json
import os
import unicodedata
from binascii import hexlify, unhexlify from binascii import hexlify, unhexlify
def to_bytes(u): def to_bytes(u):
return unicodedata.normalize("NFC", u).encode("utf-8") return unicodedata.normalize("NFC", u).encode("utf-8")
def bytes_to_hexstr(b): def bytes_to_hexstr(b):
assert isinstance(b, type(b"")) assert isinstance(b, type(b""))
hexstr = hexlify(b).decode("ascii") hexstr = hexlify(b).decode("ascii")
assert isinstance(hexstr, type(u"")) assert isinstance(hexstr, type(u""))
return hexstr return hexstr
def hexstr_to_bytes(hexstr): def hexstr_to_bytes(hexstr):
assert isinstance(hexstr, type(u"")) assert isinstance(hexstr, type(u""))
b = unhexlify(hexstr.encode("ascii")) b = unhexlify(hexstr.encode("ascii"))
assert isinstance(b, type(b"")) assert isinstance(b, type(b""))
return b return b
def dict_to_bytes(d): def dict_to_bytes(d):
assert isinstance(d, dict) assert isinstance(d, dict)
b = json.dumps(d).encode("utf-8") b = json.dumps(d).encode("utf-8")
assert isinstance(b, type(b"")) assert isinstance(b, type(b""))
return b return b
def bytes_to_dict(b): def bytes_to_dict(b):
assert isinstance(b, type(b"")) assert isinstance(b, type(b""))
d = json.loads(b.decode("utf-8")) d = json.loads(b.decode("utf-8"))
assert isinstance(d, dict) assert isinstance(d, dict)
return d return d
def estimate_free_space(target): def estimate_free_space(target):
# f_bfree is the blocks available to a root user. It might be more # f_bfree is the blocks available to a root user. It might be more
# accurate to use f_bavail (blocks available to non-root user), but we # accurate to use f_bavail (blocks available to non-root user), but we

View File

@ -1,19 +1,25 @@
from __future__ import print_function, absolute_import, unicode_literals from __future__ import absolute_import, print_function, unicode_literals
import os, sys
from attr import attrs, attrib import os
from zope.interface import implementer import sys
from attr import attrib, attrs
from twisted.python import failure from twisted.python import failure
from . import __version__ from zope.interface import implementer
from ._interfaces import IWormhole, IDeferredWormhole
from .util import bytes_to_hexstr
from .eventual import EventualQueue
from .observer import OneShotObserver, SequenceObserver
from .timing import DebugTiming
from .journal import ImmediateJournal
from ._boss import Boss from ._boss import Boss
from ._interfaces import IDeferredWormhole, IWormhole
from ._key import derive_key from ._key import derive_key
from .errors import NoKeyError, WormholeClosed from .errors import NoKeyError, WormholeClosed
from .util import to_bytes from .eventual import EventualQueue
from .journal import ImmediateJournal
from .observer import OneShotObserver, SequenceObserver
from .timing import DebugTiming
from .util import bytes_to_hexstr, to_bytes
from ._version import get_versions
__version__ = get_versions()['version']
del get_versions
# We can provide different APIs to different apps: # We can provide different APIs to different apps:
# * Deferreds # * Deferreds
@ -36,6 +42,7 @@ from .util import to_bytes
# wormhole(delegate=app, delegate_prefix="wormhole_", # wormhole(delegate=app, delegate_prefix="wormhole_",
# delegate_args=(args, kwargs)) # delegate_args=(args, kwargs))
@attrs @attrs
@implementer(IWormhole) @implementer(IWormhole)
class _DelegatedWormhole(object): class _DelegatedWormhole(object):
@ -51,16 +58,18 @@ class _DelegatedWormhole(object):
def allocate_code(self, code_length=2): def allocate_code(self, code_length=2):
self._boss.allocate_code(code_length) self._boss.allocate_code(code_length)
def input_code(self): def input_code(self):
return self._boss.input_code() return self._boss.input_code()
def set_code(self, code): def set_code(self, code):
self._boss.set_code(code) self._boss.set_code(code)
## def serialize(self): # def serialize(self):
## s = {"serialized_wormhole_version": 1, # s = {"serialized_wormhole_version": 1,
## "boss": self._boss.serialize(), # "boss": self._boss.serialize(),
## } # }
## return s # return s
def send_message(self, plaintext): def send_message(self, plaintext):
self._boss.send(plaintext) self._boss.send(plaintext)
@ -72,34 +81,45 @@ class _DelegatedWormhole(object):
cannot be called until when_verifier() has fired, nor after close() cannot be called until when_verifier() has fired, nor after close()
was called. was called.
""" """
if not isinstance(purpose, type("")): raise TypeError(type(purpose)) if not isinstance(purpose, type("")):
if not self._key: raise NoKeyError() raise TypeError(type(purpose))
if not self._key:
raise NoKeyError()
return derive_key(self._key, to_bytes(purpose), length) return derive_key(self._key, to_bytes(purpose), length)
def close(self): def close(self):
self._boss.close() self._boss.close()
def debug_set_trace(self, client_name, which="B N M S O K SK R RC L C T", def debug_set_trace(self,
client_name,
which="B N M S O K SK R RC L C T",
file=sys.stderr): file=sys.stderr):
self._boss._set_trace(client_name, which, file) self._boss._set_trace(client_name, which, file)
# from below # from below
def got_welcome(self, welcome): def got_welcome(self, welcome):
self._delegate.wormhole_got_welcome(welcome) self._delegate.wormhole_got_welcome(welcome)
def got_code(self, code): def got_code(self, code):
self._delegate.wormhole_got_code(code) self._delegate.wormhole_got_code(code)
def got_key(self, key): def got_key(self, key):
self._delegate.wormhole_got_unverified_key(key) self._delegate.wormhole_got_unverified_key(key)
self._key = key # for derive_key() self._key = key # for derive_key()
def got_verifier(self, verifier): def got_verifier(self, verifier):
self._delegate.wormhole_got_verifier(verifier) self._delegate.wormhole_got_verifier(verifier)
def got_versions(self, versions): def got_versions(self, versions):
self._delegate.wormhole_got_versions(versions) self._delegate.wormhole_got_versions(versions)
def received(self, plaintext): def received(self, plaintext):
self._delegate.wormhole_got_message(plaintext) self._delegate.wormhole_got_message(plaintext)
def closed(self, result): def closed(self, result):
self._delegate.wormhole_closed(result) self._delegate.wormhole_closed(result)
@implementer(IWormhole, IDeferredWormhole) @implementer(IWormhole, IDeferredWormhole)
class _DeferredWormhole(object): class _DeferredWormhole(object):
def __init__(self, eq): def __init__(self, eq):
@ -142,8 +162,10 @@ class _DeferredWormhole(object):
def allocate_code(self, code_length=2): def allocate_code(self, code_length=2):
self._boss.allocate_code(code_length) self._boss.allocate_code(code_length)
def input_code(self): def input_code(self):
return self._boss.input_code() return self._boss.input_code()
def set_code(self, code): def set_code(self, code):
self._boss.set_code(code) self._boss.set_code(code)
@ -159,8 +181,10 @@ class _DeferredWormhole(object):
cannot be called until when_verified() has fired, nor after close() cannot be called until when_verified() has fired, nor after close()
was called. was called.
""" """
if not isinstance(purpose, type("")): raise TypeError(type(purpose)) if not isinstance(purpose, type("")):
if not self._key: raise NoKeyError() raise TypeError(type(purpose))
if not self._key:
raise NoKeyError()
return derive_key(self._key, to_bytes(purpose), length) return derive_key(self._key, to_bytes(purpose), length)
def close(self): def close(self):
@ -172,7 +196,8 @@ class _DeferredWormhole(object):
self._boss.close() # only need to close if it wasn't already self._boss.close() # only need to close if it wasn't already
return d return d
def debug_set_trace(self, client_name, def debug_set_trace(self,
client_name,
which="B N M S O K SK R RC L A I C T", which="B N M S O K SK R RC L A I C T",
file=sys.stderr): file=sys.stderr):
self._boss._set_trace(client_name, which, file) self._boss._set_trace(client_name, which, file)
@ -180,14 +205,17 @@ class _DeferredWormhole(object):
# from below # from below
def got_welcome(self, welcome): def got_welcome(self, welcome):
self._welcome_observer.fire_if_not_fired(welcome) self._welcome_observer.fire_if_not_fired(welcome)
def got_code(self, code): def got_code(self, code):
self._code_observer.fire_if_not_fired(code) self._code_observer.fire_if_not_fired(code)
def got_key(self, key): def got_key(self, key):
self._key = key # for derive_key() self._key = key # for derive_key()
self._key_observer.fire_if_not_fired(key) self._key_observer.fire_if_not_fired(key)
def got_verifier(self, verifier): def got_verifier(self, verifier):
self._verifier_observer.fire_if_not_fired(verifier) self._verifier_observer.fire_if_not_fired(verifier)
def got_versions(self, versions): def got_versions(self, versions):
self._version_observer.fire_if_not_fired(versions) self._version_observer.fire_if_not_fired(versions)
@ -215,9 +243,14 @@ class _DeferredWormhole(object):
self._received_observer.fire(f) self._received_observer.fire(f)
def create(appid, relay_url, reactor, # use keyword args for everything else def create(
appid,
relay_url,
reactor, # use keyword args for everything else
versions={}, versions={},
delegate=None, journal=None, tor=None, delegate=None,
journal=None,
tor=None,
timing=None, timing=None,
stderr=sys.stderr, stderr=sys.stderr,
_eventual_queue=None): _eventual_queue=None):
@ -241,15 +274,16 @@ def create(appid, relay_url, reactor, # use keyword args for everything else
b.start() b.start()
return w return w
## def from_serialized(serialized, reactor, delegate,
## journal=None, tor=None, # def from_serialized(serialized, reactor, delegate,
## timing=None, stderr=sys.stderr): # journal=None, tor=None,
## assert serialized["serialized_wormhole_version"] == 1 # timing=None, stderr=sys.stderr):
## timing = timing or DebugTiming() # assert serialized["serialized_wormhole_version"] == 1
## w = _DelegatedWormhole(delegate) # timing = timing or DebugTiming()
## # now unpack state machines, including the SPAKE2 in Key # w = _DelegatedWormhole(delegate)
## b = Boss.from_serialized(w, serialized["boss"], reactor, journal, timing) # # now unpack state machines, including the SPAKE2 in Key
## w._set_boss(b) # b = Boss.from_serialized(w, serialized["boss"], reactor, journal, timing)
## b.start() # ?? # w._set_boss(b)
## raise NotImplemented # b.start() # ??
## # should the new Wormhole call got_code? only if it wasn't called before. # raise NotImplemented
# # should the new Wormhole call got_code? only if it wasn't called before.

View File

@ -1,12 +1,19 @@
import json import json
from twisted.internet.defer import inlineCallbacks, returnValue from twisted.internet.defer import inlineCallbacks, returnValue
from . import wormhole from . import wormhole
from .tor_manager import get_tor from .tor_manager import get_tor
@inlineCallbacks @inlineCallbacks
def receive(reactor, appid, relay_url, code, def receive(reactor,
use_tor=False, launch_tor=False, tor_control_port=None, appid,
relay_url,
code,
use_tor=False,
launch_tor=False,
tor_control_port=None,
on_code=None): on_code=None):
""" """
This is a convenience API which returns a Deferred that callbacks This is a convenience API which returns a Deferred that callbacks
@ -21,9 +28,11 @@ def receive(reactor, appid, relay_url, code,
:param unicode code: a pre-existing code to use, or None :param unicode code: a pre-existing code to use, or None
:param bool use_tor: True if we should use Tor, False to not use it (None for default) :param bool use_tor: True if we should use Tor, False to not use it (None
for default)
:param on_code: if not None, this is called when we have a code (even if you passed in one explicitly) :param on_code: if not None, this is called when we have a code (even if
you passed in one explicitly)
:type on_code: single-argument callable :type on_code: single-argument callable
""" """
tor = None tor = None
@ -48,27 +57,33 @@ def receive(reactor, appid, relay_url, code,
data = json.loads(data.decode("utf-8")) data = json.loads(data.decode("utf-8"))
offer = data.get('offer', None) offer = data.get('offer', None)
if not offer: if not offer:
raise Exception( raise Exception("Do not understand response: {}".format(data))
"Do not understand response: {}".format(data)
)
msg = None msg = None
if 'message' in offer: if 'message' in offer:
msg = offer['message'] msg = offer['message']
wh.send_message(json.dumps({"answer": wh.send_message(
{"message_ack": "ok"}}).encode("utf-8")) json.dumps({
"answer": {
"message_ack": "ok"
}
}).encode("utf-8"))
else: else:
raise Exception( raise Exception("Unknown offer type: {}".format(offer.keys()))
"Unknown offer type: {}".format(offer.keys())
)
yield wh.close() yield wh.close()
returnValue(msg) returnValue(msg)
@inlineCallbacks @inlineCallbacks
def send(reactor, appid, relay_url, data, code, def send(reactor,
use_tor=False, launch_tor=False, tor_control_port=None, appid,
relay_url,
data,
code,
use_tor=False,
launch_tor=False,
tor_control_port=None,
on_code=None): on_code=None):
""" """
This is a convenience API which returns a Deferred that callbacks This is a convenience API which returns a Deferred that callbacks
@ -83,9 +98,12 @@ def send(reactor, appid, relay_url, data, code,
:param unicode code: a pre-existing code to use, or None :param unicode code: a pre-existing code to use, or None
:param bool use_tor: True if we should use Tor, False to not use it (None for default) :param bool use_tor: True if we should use Tor, False to not use it (None
for default)
:param on_code: if not None, this is called when we have a code (even if
you passed in one explicitly)
:param on_code: if not None, this is called when we have a code (even if you passed in one explicitly)
:type on_code: single-argument callable :type on_code: single-argument callable
""" """
tor = None tor = None
@ -104,13 +122,7 @@ def send(reactor, appid, relay_url, data, code,
if on_code: if on_code:
on_code(code) on_code(code)
wh.send_message( wh.send_message(json.dumps({"offer": {"message": data}}).encode("utf-8"))
json.dumps({
"offer": {
"message": data
}
}).encode("utf-8")
)
data = yield wh.get_message() data = yield wh.get_message()
data = json.loads(data.decode("utf-8")) data = json.loads(data.decode("utf-8"))
answer = data.get('answer', None) answer = data.get('answer', None)
@ -118,6 +130,4 @@ def send(reactor, appid, relay_url, data, code,
if answer: if answer:
returnValue(None) returnValue(None)
else: else:
raise Exception( raise Exception("Unknown answer: {}".format(data))
"Unknown answer: {}".format(data)
)

11
tox.ini
View File

@ -4,7 +4,7 @@
# and then run "tox" from this directory. # and then run "tox" from this directory.
[tox] [tox]
envlist = {py27,py34,py35,py36,pypy} envlist = {py27,py34,py35,py36,pypy,flake8}
skip_missing_interpreters = True skip_missing_interpreters = True
minversion = 2.4.0 minversion = 2.4.0
@ -36,3 +36,12 @@ commands =
wormhole --version wormhole --version
coverage run --branch -m wormhole.test.run_trial {posargs:wormhole} coverage run --branch -m wormhole.test.run_trial {posargs:wormhole}
coverage xml coverage xml
[testenv:flake8]
deps = flake8
commands = flake8 src/wormhole
[flake8]
ignore = E741
exclude = .git,__pycache__,docs/source/conf.py,old,build,dist
max-complexity = 40