Merge branch 'pr34'

Closes #12.
This commit is contained in:
Brian Warner 2016-06-02 16:13:48 -07:00
commit 7b0ca28589
4 changed files with 38 additions and 5 deletions

View File

@ -4,7 +4,8 @@ start = time.time()
import os, sys, textwrap import os, sys, textwrap
from twisted.internet.defer import maybeDeferred from twisted.internet.defer import maybeDeferred
from twisted.internet.task import react from twisted.internet.task import react
from ..errors import TransferError, WrongPasswordError, WelcomeError, Timeout from ..errors import (TransferError, WrongPasswordError, WelcomeError, Timeout,
KeyFormatError)
from ..timing import DebugTiming from ..timing import DebugTiming
from .cli_args import parser from .cli_args import parser
top_import_finish = time.time() top_import_finish = time.time()
@ -49,7 +50,8 @@ def run(reactor, argv, cwd, stdout, stderr, executable=None):
d.addBoth(_maybe_dump_timing) d.addBoth(_maybe_dump_timing)
def _explain_error(f): def _explain_error(f):
# these errors don't print a traceback, just an explanation # these errors don't print a traceback, just an explanation
f.trap(TransferError, WrongPasswordError, WelcomeError, Timeout) f.trap(TransferError, WrongPasswordError, WelcomeError, Timeout,
KeyFormatError)
if f.check(WrongPasswordError): if f.check(WrongPasswordError):
msg = textwrap.fill("ERROR: " + textwrap.dedent(f.value.__doc__)) msg = textwrap.fill("ERROR: " + textwrap.dedent(f.value.__doc__))
print(msg, file=stderr) print(msg, file=stderr)
@ -58,6 +60,9 @@ def run(reactor, argv, cwd, stdout, stderr, executable=None):
print(msg, file=stderr) print(msg, file=stderr)
print(file=stderr) print(file=stderr)
print(str(f.value), file=stderr) print(str(f.value), file=stderr)
elif f.check(KeyFormatError):
msg = textwrap.fill("ERROR: " + textwrap.dedent(f.value.__doc__))
print(msg, file=stderr)
else: else:
print("ERROR:", f.value, file=stderr) print("ERROR:", f.value, file=stderr)
raise SystemExit(1) raise SystemExit(1)

View File

@ -36,6 +36,13 @@ class WrongPasswordError(Exception):
# or the data blob was corrupted, and that's why decrypt failed # or the data blob was corrupted, and that's why decrypt failed
pass pass
class KeyFormatError(Exception):
"""
The key you entered contains spaces. Magic-wormhole expects keys to be
separated by dashes. Please reenter the key you were given separating the
words with dashes.
"""
class ReflectionAttack(Exception): class ReflectionAttack(Exception):
"""An attacker (or bug) reflected our outgoing message back to us.""" """An attacker (or bug) reflected our outgoing message back to us."""

View File

@ -7,7 +7,8 @@ from twisted.internet import reactor
from twisted.internet.defer import Deferred, gatherResults, inlineCallbacks from twisted.internet.defer import Deferred, gatherResults, inlineCallbacks
from .common import ServerBase from .common import ServerBase
from .. import wormhole from .. import wormhole
from ..errors import WrongPasswordError, WelcomeError, UsageError from ..errors import (WrongPasswordError, WelcomeError, UsageError,
KeyFormatError)
from spake2 import SPAKE2_Symmetric from spake2 import SPAKE2_Symmetric
from ..timing import DebugTiming from ..timing import DebugTiming
from ..util import (bytes_to_dict, dict_to_bytes, from ..util import (bytes_to_dict, dict_to_bytes,
@ -818,6 +819,23 @@ class Wormholes(ServerBase, unittest.TestCase):
yield w2.close() yield w2.close()
self.flushLoggedErrors(WrongPasswordError) self.flushLoggedErrors(WrongPasswordError)
@inlineCallbacks
def test_wrong_password_with_spaces(self):
w1 = wormhole.wormhole(APPID, self.relayurl, reactor)
w2 = wormhole.wormhole(APPID, self.relayurl, reactor)
code = yield w1.get_code()
code_no_dashes = code.replace('-', ' ')
with self.assertRaises(KeyFormatError) as ex:
w2.set_code(code_no_dashes)
expected_msg = "code (%s) contains spaces." % (code_no_dashes,)
self.assertEqual(expected_msg, str(ex.exception))
yield w1.close()
yield w2.close()
self.flushLoggedErrors(KeyFormatError)
@inlineCallbacks @inlineCallbacks
def test_verifier(self): def test_verifier(self):
w1 = wormhole.wormhole(APPID, self.relayurl, reactor) w1 = wormhole.wormhole(APPID, self.relayurl, reactor)
@ -875,4 +893,3 @@ class Errors(ServerBase, unittest.TestCase):
yield self.assertFailure(w.get_code(), UsageError) yield self.assertFailure(w.get_code(), UsageError)
yield self.assertFailure(w.input_code(), UsageError) yield self.assertFailure(w.input_code(), UsageError)
yield w.close() yield w.close()

View File

@ -15,7 +15,7 @@ from . import __version__
from . import codes from . import codes
#from .errors import ServerError, Timeout #from .errors import ServerError, Timeout
from .errors import (WrongPasswordError, UsageError, WelcomeError, from .errors import (WrongPasswordError, UsageError, WelcomeError,
WormholeClosedError) WormholeClosedError, KeyFormatError)
from .timing import DebugTiming from .timing import DebugTiming
from .util import (to_bytes, bytes_to_hexstr, hexstr_to_bytes, from .util import (to_bytes, bytes_to_hexstr, hexstr_to_bytes,
dict_to_bytes, bytes_to_dict) dict_to_bytes, bytes_to_dict)
@ -476,6 +476,10 @@ class _Wormhole:
def _event_learned_code(self, code): def _event_learned_code(self, code):
self._timing.add("code established") self._timing.add("code established")
# bail out early if the password contains spaces...
# this should raise a useful error
if ' ' in code:
raise KeyFormatError("code (%s) contains spaces." % code)
self._code = code self._code = code
mo = re.search(r'^(\d+)-', code) mo = re.search(r'^(\d+)-', code)
if not mo: if not mo: