Merge PR49: emit pacifier message when key-verification is slow

Closes #29
Closes #49
This commit is contained in:
Brian Warner 2016-12-16 01:47:56 -08:00
commit ec3599ff09
5 changed files with 122 additions and 11 deletions

View File

@ -12,6 +12,7 @@ from ..util import (dict_to_bytes, bytes_to_dict, bytes_to_hexstr,
estimate_free_space) estimate_free_space)
APPID = u"lothar.com/wormhole/text-or-file-xfer" APPID = u"lothar.com/wormhole/text-or-file-xfer"
VERIFY_TIMER = 1
class RespondError(Exception): class RespondError(Exception):
def __init__(self, response): def __init__(self, response):
@ -75,7 +76,16 @@ class TwistedReceiver:
@inlineCallbacks @inlineCallbacks
def _go(self, w): def _go(self, w):
yield self._handle_code(w) yield self._handle_code(w)
verifier = yield w.verify() yield w.establish_key()
def on_slow_connection():
print(u"Key established, waiting for confirmation...",
file=self.args.stderr)
notify = self._reactor.callLater(VERIFY_TIMER, on_slow_connection)
try:
verifier = yield w.verify()
finally:
if not notify.called:
notify.cancel()
self._show_verifier(verifier) self._show_verifier(verifier)
want_offer = True want_offer = True

View File

@ -12,6 +12,7 @@ from ..transit import TransitSender
from ..util import dict_to_bytes, bytes_to_dict, bytes_to_hexstr from ..util import dict_to_bytes, bytes_to_dict, bytes_to_hexstr
APPID = u"lothar.com/wormhole/text-or-file-xfer" APPID = u"lothar.com/wormhole/text-or-file-xfer"
VERIFY_TIMER = 1
def send(args, reactor=reactor): def send(args, reactor=reactor):
"""I implement 'wormhole send'. I return a Deferred that fires with None """I implement 'wormhole send'. I return a Deferred that fires with None
@ -85,8 +86,19 @@ class Sender:
print(u"Wormhole code is: %s" % code, file=args.stdout) print(u"Wormhole code is: %s" % code, file=args.stdout)
print(u"", file=args.stdout) print(u"", file=args.stdout)
yield w.establish_key()
def on_slow_connection():
print(u"Key established, waiting for confirmation...",
file=args.stderr)
notify = self._reactor.callLater(VERIFY_TIMER, on_slow_connection)
# TODO: don't stall on w.verify() unless they want it # TODO: don't stall on w.verify() unless they want it
verifier_bytes = yield w.verify() # this may raise WrongPasswordError try:
verifier_bytes = yield w.verify() # this may raise WrongPasswordError
finally:
if not notify.called:
notify.cancel()
if args.verify: if args.verify:
verifier = bytes_to_hexstr(verifier_bytes) verifier = bytes_to_hexstr(verifier_bytes)
while True: while True:

View File

@ -225,7 +225,7 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase):
@inlineCallbacks @inlineCallbacks
def _do_test(self, as_subprocess=False, def _do_test(self, as_subprocess=False,
mode="text", addslash=False, override_filename=False): mode="text", addslash=False, override_filename=False):
assert mode in ("text", "file", "directory") assert mode in ("text", "file", "directory", "slow-text")
send_cfg = config("send") send_cfg = config("send")
recv_cfg = config("receive") recv_cfg = config("receive")
message = "blah blah blah ponies" message = "blah blah blah ponies"
@ -244,7 +244,7 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase):
receive_dir = self.mktemp() receive_dir = self.mktemp()
os.mkdir(receive_dir) os.mkdir(receive_dir)
if mode == "text": if mode == "text" or mode == "slow-text":
send_cfg.text = message send_cfg.text = message
elif mode == "file": elif mode == "file":
@ -348,8 +348,14 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase):
# The sender might fail, leaving the receiver hanging, or vice # The sender might fail, leaving the receiver hanging, or vice
# versa. Make sure we don't wait on one side exclusively # versa. Make sure we don't wait on one side exclusively
if mode == "slow-text":
with mock.patch.object(cmd_send, "VERIFY_TIMER", 0), \
mock.patch.object(cmd_receive, "VERIFY_TIMER", 0):
yield gatherResults([send_d, receive_d], True)
else:
yield gatherResults([send_d, receive_d], True)
yield gatherResults([send_d, receive_d], True)
send_stdout = send_cfg.stdout.getvalue() send_stdout = send_cfg.stdout.getvalue()
send_stderr = send_cfg.stderr.getvalue() send_stderr = send_cfg.stderr.getvalue()
receive_stdout = recv_cfg.stdout.getvalue() receive_stdout = recv_cfg.stdout.getvalue()
@ -361,13 +367,21 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase):
self.maxDiff = None # show full output for assertion failures self.maxDiff = None # show full output for assertion failures
self.failUnlessEqual(send_stderr, "", if mode != "slow-text":
(send_stdout, send_stderr)) self.failUnlessEqual(send_stderr, "",
self.failUnlessEqual(receive_stderr, "", (send_stdout, send_stderr))
(receive_stdout, receive_stderr)) self.failUnlessEqual(receive_stderr, "",
(receive_stdout, receive_stderr))
else:
self.assertEqual(send_stderr,
"Key established, waiting for confirmation...\n",
(send_stdout, send_stderr))
self.assertEqual(receive_stderr,
"Key established, waiting for confirmation...\n",
(receive_stdout, receive_stderr))
# check sender # check sender
if mode == "text": if mode == "text" or mode == "slow-text":
expected = ("Sending text message ({bytes:d} Bytes){NL}" expected = ("Sending text message ({bytes:d} Bytes){NL}"
"On the other computer, please run: " "On the other computer, please run: "
"wormhole receive{NL}" "wormhole receive{NL}"
@ -401,7 +415,7 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase):
.format(NL=NL), send_stdout) .format(NL=NL), send_stdout)
# check receiver # check receiver
if mode == "text": if mode == "text" or mode == "slow-text":
self.failUnlessEqual(receive_stdout, message+NL) self.failUnlessEqual(receive_stdout, message+NL)
elif mode == "file": elif mode == "file":
self.failUnlessIn("Receiving file ({size:s}) into: {name}" self.failUnlessIn("Receiving file ({size:s}) into: {name}"
@ -448,6 +462,9 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase):
def test_directory_override(self): def test_directory_override(self):
return self._do_test(mode="directory", override_filename=True) return self._do_test(mode="directory", override_filename=True)
def test_slow_text(self):
return self._do_test(mode="slow-text")
@inlineCallbacks @inlineCallbacks
def _do_test_fail(self, mode, failmode): def _do_test_fail(self, mode, failmode):
assert mode in ("file", "directory") assert mode in ("file", "directory")

View File

@ -522,6 +522,45 @@ class Basic(unittest.TestCase):
self.assertEqual(len(pieces), 3) # nameplate plus two words self.assertEqual(len(pieces), 3) # nameplate plus two words
self.assert_(re.search(r'^\d+-\w+-\w+$', code), code) self.assert_(re.search(r'^\d+-\w+-\w+$', code), code)
def _test_establish_key_hook(self, established, before):
timing = DebugTiming()
w = wormhole._Wormhole(APPID, "relay_url", reactor, None, timing)
if before:
d = w.establish_key()
if established is True:
w._key = b"key"
elif established is False:
w._key = None
else:
w._key = b"key"
w._error = WelcomeError()
if not before:
d = w.establish_key()
else:
w._maybe_notify_key()
if w._key is not None and established is True:
self.successResultOf(d)
elif established is False:
self.assertNot(d.called)
else:
self.failureResultOf(d)
def test_establish_key_hook(self):
for established in (True, False, "error"):
for before in (True, False):
self._test_establish_key_hook(established, before)
def test_establish_key_twice(self):
timing = DebugTiming()
w = wormhole._Wormhole(APPID, "relay_url", reactor, None, timing)
d = w.establish_key()
self.assertRaises(InternalError, w.establish_key)
del d
# make sure verify() can be called both before and after the verifier is # make sure verify() can be called both before and after the verifier is
# computed # computed

View File

@ -237,6 +237,8 @@ class _Wormhole:
self._flag_need_to_see_mailbox_used = True self._flag_need_to_see_mailbox_used = True
self._flag_need_to_build_msg1 = True self._flag_need_to_build_msg1 = True
self._flag_need_to_send_PAKE = True self._flag_need_to_send_PAKE = True
self._establish_key_called = False
self._key_waiter = None
self._key = None self._key = None
self._version_message = None self._version_message = None
@ -283,6 +285,14 @@ class _Wormhole:
# todo: restore-saved-state entry points # todo: restore-saved-state entry points
def establish_key(self):
"""
returns a Deferred that fires when we've established the shared key.
When successful, the Deferred fires with a simple `True`, otherwise
it fails.
"""
return self._API_establish_key()
def verify(self): def verify(self):
"""Returns a Deferred that fires when we've heard back from the other """Returns a Deferred that fires when we've heard back from the other
side, and have confirmed that they used the right wormhole code. When side, and have confirmed that they used the right wormhole code. When
@ -559,6 +569,7 @@ class _Wormhole:
def _event_established_key(self): def _event_established_key(self):
self._timing.add("key established") self._timing.add("key established")
self._maybe_notify_key()
# both sides send different (random) version messages # both sides send different (random) version messages
self._send_version_message() self._send_version_message()
@ -569,6 +580,25 @@ class _Wormhole:
self._maybe_check_version() self._maybe_check_version()
self._maybe_send_phase_messages() self._maybe_send_phase_messages()
def _API_establish_key(self):
if self._error: return defer.fail(self._error)
if self._establish_key_called: raise InternalError
self._establish_key_called = True
if self._key is not None:
return defer.succeed(True)
self._key_waiter = defer.Deferred()
return self._key_waiter
def _maybe_notify_key(self):
if self._key is None:
return
if self._error:
result = failure.Failure(self._error)
else:
result = True
if self._key_waiter and not self._key_waiter.called:
self._key_waiter.callback(result)
def _send_version_message(self): def _send_version_message(self):
# this is encrypted like a normal phase message, and includes a # this is encrypted like a normal phase message, and includes a
# dictionary of version flags to let the other Wormhole know what # dictionary of version flags to let the other Wormhole know what
@ -824,6 +854,9 @@ class _Wormhole:
if self._verifier_waiter and not self._verifier_waiter.called: if self._verifier_waiter and not self._verifier_waiter.called:
if self.DEBUG: print("EB VW") if self.DEBUG: print("EB VW")
self._verifier_waiter.errback(error) self._verifier_waiter.errback(error)
if self._key_waiter and not self._key_waiter.called:
if self.DEBUG: print("EB KW")
self._key_waiter.errback(error)
for d in self._receive_waiters.values(): for d in self._receive_waiters.values():
if self.DEBUG: print("EB RW") if self.DEBUG: print("EB RW")
d.errback(error) d.errback(error)