reject invalid codes with KeyFormatError

refs #212
This commit is contained in:
Brian Warner 2017-07-04 10:50:21 -07:00
parent de0af837cc
commit 8b0a245e19
8 changed files with 131 additions and 20 deletions

View File

@ -248,6 +248,11 @@ it doesn't matter which one goes first, and both use the same Wormhole
constructor function. However if `w.allocate_code()` is used, only one side constructor function. However if `w.allocate_code()` is used, only one side
should use it. should use it.
Providing an invalid nameplate (which is easily caused by cut-and-paste
errors that include an extra space at the beginning, or which copy the words
but not the number) will raise a `KeyFormatError`, either in
`w.set_code(code)` or in `h.choose_nameplate()`.
## Offline Codes ## Offline Codes
In most situations, the "sending" or "initiating" side will call In most situations, the "sending" or "initiating" side will call

View File

@ -17,12 +17,11 @@ from ._rendezvous import RendezvousConnector
from ._lister import Lister from ._lister import Lister
from ._allocator import Allocator from ._allocator import Allocator
from ._input import Input from ._input import Input
from ._code import Code from ._code import Code, validate_code
from ._terminator import Terminator from ._terminator import Terminator
from ._wordlist import PGPWordList from ._wordlist import PGPWordList
from .errors import (ServerError, LonelyError, WrongPasswordError, from .errors import (ServerError, LonelyError, WrongPasswordError,
KeyFormatError, OnlyOneCodeError, _UnknownPhaseError, OnlyOneCodeError, _UnknownPhaseError, WelcomeError)
WelcomeError)
from .util import bytes_to_dict from .util import bytes_to_dict
@attrs @attrs
@ -159,8 +158,7 @@ class Boss(object):
wl = PGPWordList() wl = PGPWordList()
self._C.allocate_code(code_length, wl) self._C.allocate_code(code_length, wl)
def set_code(self, code): def set_code(self, code):
if ' ' in code: validate_code(code) # can raise KeyFormatError
raise KeyFormatError("code (%s) contains spaces." % code)
if self._did_start_code: if self._did_start_code:
raise OnlyOneCodeError() raise OnlyOneCodeError()
self._did_start_code = True self._did_start_code = True

View File

@ -4,6 +4,14 @@ from attr import attrs, attrib
from attr.validators import provides from attr.validators import provides
from automat import MethodicalMachine from automat import MethodicalMachine
from . import _interfaces from . import _interfaces
from ._nameplate import validate_nameplate
from .errors import KeyFormatError
def validate_code(code):
if ' ' in code:
raise KeyFormatError("Code '%s' contains spaces." % code)
nameplate = code.split("-", 2)[0]
validate_nameplate(nameplate) # can raise KeyFormatError
def first(outputs): def first(outputs):
return list(outputs)[0] return list(outputs)[0]
@ -38,8 +46,11 @@ class Code(object):
def allocate_code(self, length, wordlist): pass def allocate_code(self, length, wordlist): pass
@m.input() @m.input()
def input_code(self): pass def input_code(self): pass
def set_code(self, code):
validate_code(code) # can raise KeyFormatError
self._set_code(code)
@m.input() @m.input()
def set_code(self, code): pass def _set_code(self, code): pass
# from Allocator # from Allocator
@m.input() @m.input()
@ -79,7 +90,7 @@ class Code(object):
self._B.got_code(code) self._B.got_code(code)
self._K.got_code(code) self._K.got_code(code)
S0_idle.upon(set_code, enter=S4_known, outputs=[do_set_code]) S0_idle.upon(_set_code, enter=S4_known, outputs=[do_set_code])
S0_idle.upon(input_code, enter=S1_inputting_nameplate, S0_idle.upon(input_code, enter=S1_inputting_nameplate,
outputs=[do_start_input], collector=first) outputs=[do_start_input], collector=first)
S1_inputting_nameplate.upon(got_nameplate, enter=S2_inputting_words, S1_inputting_nameplate.upon(got_nameplate, enter=S2_inputting_words,

View File

@ -5,6 +5,7 @@ from attr.validators import provides
from twisted.internet import defer from twisted.internet import defer
from automat import MethodicalMachine from automat import MethodicalMachine
from . import _interfaces, errors from . import _interfaces, errors
from ._nameplate import validate_nameplate
def first(outputs): def first(outputs):
return list(outputs)[0] return list(outputs)[0]
@ -61,8 +62,11 @@ class Input(object):
def refresh_nameplates(self): pass def refresh_nameplates(self): pass
@m.input() @m.input()
def get_nameplate_completions(self, prefix): pass def get_nameplate_completions(self, prefix): pass
def choose_nameplate(self, nameplate):
validate_nameplate(nameplate) # can raise KeyFormatError
self._choose_nameplate(nameplate)
@m.input() @m.input()
def choose_nameplate(self, nameplate): pass def _choose_nameplate(self, nameplate): pass
@m.input() @m.input()
def get_word_completions(self, prefix): pass def get_word_completions(self, prefix): pass
@m.input() @m.input()
@ -158,7 +162,7 @@ class Input(object):
enter=S1_typing_nameplate, enter=S1_typing_nameplate,
outputs=[_get_nameplate_completions], outputs=[_get_nameplate_completions],
collector=first) collector=first)
S1_typing_nameplate.upon(choose_nameplate, enter=S2_typing_code_no_wordlist, S1_typing_nameplate.upon(_choose_nameplate, enter=S2_typing_code_no_wordlist,
outputs=[record_all_nameplates]) outputs=[record_all_nameplates])
S1_typing_nameplate.upon(get_word_completions, S1_typing_nameplate.upon(get_word_completions,
enter=S1_typing_nameplate, enter=S1_typing_nameplate,
@ -178,7 +182,7 @@ class Input(object):
S2_typing_code_no_wordlist.upon(get_nameplate_completions, S2_typing_code_no_wordlist.upon(get_nameplate_completions,
enter=S2_typing_code_no_wordlist, enter=S2_typing_code_no_wordlist,
outputs=[raise_already_chose_nameplate2]) outputs=[raise_already_chose_nameplate2])
S2_typing_code_no_wordlist.upon(choose_nameplate, S2_typing_code_no_wordlist.upon(_choose_nameplate,
enter=S2_typing_code_no_wordlist, enter=S2_typing_code_no_wordlist,
outputs=[raise_already_chose_nameplate3]) outputs=[raise_already_chose_nameplate3])
S2_typing_code_no_wordlist.upon(get_word_completions, S2_typing_code_no_wordlist.upon(get_word_completions,
@ -198,7 +202,7 @@ class Input(object):
S3_typing_code_yes_wordlist.upon(get_nameplate_completions, S3_typing_code_yes_wordlist.upon(get_nameplate_completions,
enter=S3_typing_code_yes_wordlist, enter=S3_typing_code_yes_wordlist,
outputs=[raise_already_chose_nameplate2]) outputs=[raise_already_chose_nameplate2])
S3_typing_code_yes_wordlist.upon(choose_nameplate, S3_typing_code_yes_wordlist.upon(_choose_nameplate,
enter=S3_typing_code_yes_wordlist, enter=S3_typing_code_yes_wordlist,
outputs=[raise_already_chose_nameplate3]) outputs=[raise_already_chose_nameplate3])
S3_typing_code_yes_wordlist.upon(get_word_completions, S3_typing_code_yes_wordlist.upon(get_word_completions,
@ -216,7 +220,7 @@ class Input(object):
S4_done.upon(get_nameplate_completions, S4_done.upon(get_nameplate_completions,
enter=S4_done, enter=S4_done,
outputs=[raise_already_chose_nameplate2]) outputs=[raise_already_chose_nameplate2])
S4_done.upon(choose_nameplate, enter=S4_done, S4_done.upon(_choose_nameplate, enter=S4_done,
outputs=[raise_already_chose_nameplate3]) outputs=[raise_already_chose_nameplate3])
S4_done.upon(get_word_completions, enter=S4_done, S4_done.upon(get_word_completions, enter=S4_done,
outputs=[raise_already_chose_words1]) outputs=[raise_already_chose_words1])

View File

@ -1,8 +1,16 @@
from __future__ import print_function, absolute_import, unicode_literals from __future__ import print_function, absolute_import, unicode_literals
import re
from zope.interface import implementer from zope.interface import implementer
from automat import MethodicalMachine from automat import MethodicalMachine
from . import _interfaces from . import _interfaces
from ._wordlist import PGPWordList from ._wordlist import PGPWordList
from .errors import KeyFormatError
def validate_nameplate(nameplate):
if not re.search(r'^\d+$', nameplate):
raise KeyFormatError("Nameplate '%s' must be numeric, with no spaces."
% nameplate)
@implementer(_interfaces.INameplate) @implementer(_interfaces.INameplate)
class Nameplate(object): class Nameplate(object):
@ -62,8 +70,11 @@ class Nameplate(object):
S5B = S5 S5B = S5
# from Boss # from Boss
def set_nameplate(self, nameplate):
validate_nameplate(nameplate) # can raise KeyFormatError
self._set_nameplate(nameplate)
@m.input() @m.input()
def set_nameplate(self, nameplate): pass def _set_nameplate(self, nameplate): pass
# from Mailbox # from Mailbox
@m.input() @m.input()
@ -84,12 +95,13 @@ class Nameplate(object):
@m.input() @m.input()
def rx_released(self): pass def rx_released(self): pass
@m.output() @m.output()
def record_nameplate(self, nameplate): def record_nameplate(self, nameplate):
validate_nameplate(nameplate)
self._nameplate = nameplate self._nameplate = nameplate
@m.output() @m.output()
def record_nameplate_and_RC_tx_claim(self, nameplate): def record_nameplate_and_RC_tx_claim(self, nameplate):
validate_nameplate(nameplate)
self._nameplate = nameplate self._nameplate = nameplate
self._RC.tx_claim(self._nameplate) self._RC.tx_claim(self._nameplate)
@m.output() @m.output()
@ -112,10 +124,10 @@ class Nameplate(object):
def T_nameplate_done(self): def T_nameplate_done(self):
self._T.nameplate_done() self._T.nameplate_done()
S0A.upon(set_nameplate, enter=S1A, outputs=[record_nameplate]) S0A.upon(_set_nameplate, enter=S1A, outputs=[record_nameplate])
S0A.upon(connected, enter=S0B, outputs=[]) S0A.upon(connected, enter=S0B, outputs=[])
S0A.upon(close, enter=S5A, outputs=[T_nameplate_done]) S0A.upon(close, enter=S5A, outputs=[T_nameplate_done])
S0B.upon(set_nameplate, enter=S2B, S0B.upon(_set_nameplate, enter=S2B,
outputs=[record_nameplate_and_RC_tx_claim]) outputs=[record_nameplate_and_RC_tx_claim])
S0B.upon(lost, enter=S0A, outputs=[]) S0B.upon(lost, enter=S0A, outputs=[])
S0B.upon(close, enter=S5A, outputs=[T_nameplate_done]) S0B.upon(close, enter=S5A, outputs=[T_nameplate_done])

View File

@ -106,11 +106,11 @@ def _dispatch_command(reactor, cfg, command):
try: try:
yield maybeDeferred(command) yield maybeDeferred(command)
except (WrongPasswordError, KeyFormatError, NoTorError) as e: except (WrongPasswordError, NoTorError) as e:
msg = fill("ERROR: " + dedent(e.__doc__)) msg = fill("ERROR: " + dedent(e.__doc__))
print(msg, file=cfg.stderr) print(msg, file=cfg.stderr)
raise SystemExit(1) raise SystemExit(1)
except (WelcomeError, UnsendableFileError) as e: except (WelcomeError, UnsendableFileError, KeyFormatError) as e:
msg = fill("ERROR: " + dedent(e.__doc__)) msg = fill("ERROR: " + dedent(e.__doc__))
print(msg, file=cfg.stderr) print(msg, file=cfg.stderr)
print(six.u(""), file=cfg.stderr) print(six.u(""), file=cfg.stderr)

View File

@ -313,6 +313,26 @@ class Code(unittest.TestCase):
("k.got_code", u"1-code"), ("k.got_code", u"1-code"),
]) ])
def test_set_code_invalid(self):
c, b, a, n, k, i, events = self.build()
with self.assertRaises(errors.KeyFormatError) as e:
c.set_code(u"1-code ")
self.assertEqual(str(e.exception), "Code '1-code ' contains spaces.")
with self.assertRaises(errors.KeyFormatError) as e:
c.set_code(u" 1-code")
self.assertEqual(str(e.exception), "Code ' 1-code' contains spaces.")
with self.assertRaises(errors.KeyFormatError) as e:
c.set_code(u"code-code")
self.assertEqual(str(e.exception),
"Nameplate 'code' must be numeric, with no spaces.")
# it should still be possible to use the wormhole at this point
c.set_code(u"1-code")
self.assertEqual(events, [("n.set_nameplate", u"1"),
("b.got_code", u"1-code"),
("k.got_code", u"1-code"),
])
def test_allocate_code(self): def test_allocate_code(self):
c, b, a, n, k, i, events = self.build() c, b, a, n, k, i, events = self.build()
wl = FakeWordList() wl = FakeWordList()
@ -366,6 +386,27 @@ class Input(unittest.TestCase):
helper.choose_words("word-word") helper.choose_words("word-word")
self.assertEqual(events, [("c.finished_input", "1-word-word")]) self.assertEqual(events, [("c.finished_input", "1-word-word")])
def test_bad_nameplate(self):
i, c, l, events = self.build()
helper = i.start()
self.assertIsInstance(helper, _input.Helper)
self.assertEqual(events, [("l.refresh",)])
events[:] = []
with self.assertRaises(errors.MustChooseNameplateFirstError):
helper.choose_words("word-word")
with self.assertRaises(errors.KeyFormatError):
helper.choose_nameplate(" 1")
# should still work afterwards
helper.choose_nameplate("1")
self.assertEqual(events, [("c.got_nameplate", "1")])
events[:] = []
with self.assertRaises(errors.AlreadyChoseNameplateError):
helper.choose_nameplate("2")
helper.choose_words("word-word")
with self.assertRaises(errors.AlreadyChoseWordsError):
helper.choose_words("word-word")
self.assertEqual(events, [("c.finished_input", "1-word-word")])
def test_with_completion(self): def test_with_completion(self):
i, c, l, events = self.build() i, c, l, events = self.build()
helper = i.start() helper = i.start()
@ -566,6 +607,23 @@ class Nameplate(unittest.TestCase):
n.wire(m, i, rc, t) n.wire(m, i, rc, t)
return n, m, i, rc, t, events return n, m, i, rc, t, events
def test_set_invalid(self):
n, m, i, rc, t, events = self.build()
with self.assertRaises(errors.KeyFormatError) as e:
n.set_nameplate(" 1")
self.assertEqual(str(e.exception),
"Nameplate ' 1' must be numeric, with no spaces.")
with self.assertRaises(errors.KeyFormatError) as e:
n.set_nameplate("one")
self.assertEqual(str(e.exception),
"Nameplate 'one' must be numeric, with no spaces.")
# wormhole should still be usable
n.set_nameplate("1")
self.assertEqual(events, [])
n.connected()
self.assertEqual(events, [("rc.tx_claim", "1")])
def test_set_first(self): def test_set_first(self):
# connection remains up throughout # connection remains up throughout
n, m, i, rc, t, events = self.build() n, m, i, rc, t, events = self.build()
@ -1356,8 +1414,11 @@ class Boss(unittest.TestCase):
b, events = self.build() b, events = self.build()
with self.assertRaises(errors.KeyFormatError): with self.assertRaises(errors.KeyFormatError):
b.set_code("1 code") b.set_code("1 code")
# wormhole should still be usable
b.set_code("1-code")
self.assertEqual(events, [("c.set_code", "1-code")])
def test_set_code_bad_twice(self): def test_set_code_twice(self):
b, events = self.build() b, events = self.build()
b.set_code("1-code") b.set_code("1-code")
with self.assertRaises(errors.OnlyOneCodeError): with self.assertRaises(errors.OnlyOneCodeError):

View File

@ -448,7 +448,27 @@ class Wormholes(ServerBase, unittest.TestCase):
badcode = "4 oops spaces" badcode = "4 oops spaces"
with self.assertRaises(KeyFormatError) as ex: with self.assertRaises(KeyFormatError) as ex:
w.set_code(badcode) w.set_code(badcode)
expected_msg = "code (%s) contains spaces." % (badcode,) expected_msg = "Code '%s' contains spaces." % (badcode,)
self.assertEqual(expected_msg, str(ex.exception))
yield self.assertFailure(w.close(), LonelyError)
@inlineCallbacks
def test_wrong_password_with_leading_space(self):
w = wormhole.create(APPID, self.relayurl, reactor)
badcode = " 4-oops-space"
with self.assertRaises(KeyFormatError) as ex:
w.set_code(badcode)
expected_msg = "Code '%s' contains spaces." % (badcode,)
self.assertEqual(expected_msg, str(ex.exception))
yield self.assertFailure(w.close(), LonelyError)
@inlineCallbacks
def test_wrong_password_with_non_numeric_nameplate(self):
w = wormhole.create(APPID, self.relayurl, reactor)
badcode = "four-oops-space"
with self.assertRaises(KeyFormatError) as ex:
w.set_code(badcode)
expected_msg = "Nameplate 'four' must be numeric, with no spaces."
self.assertEqual(expected_msg, str(ex.exception)) self.assertEqual(expected_msg, str(ex.exception))
yield self.assertFailure(w.close(), LonelyError) yield self.assertFailure(w.close(), LonelyError)