diff --git a/src/wormhole/cli/cmd_receive.py b/src/wormhole/cli/cmd_receive.py index e458de3..c23b44a 100644 --- a/src/wormhole/cli/cmd_receive.py +++ b/src/wormhole/cli/cmd_receive.py @@ -12,6 +12,7 @@ from ..util import (dict_to_bytes, bytes_to_dict, bytes_to_hexstr, estimate_free_space) APPID = u"lothar.com/wormhole/text-or-file-xfer" +VERIFY_TIMER = 1 class RespondError(Exception): def __init__(self, response): @@ -75,7 +76,16 @@ class TwistedReceiver: @inlineCallbacks def _go(self, w): yield self._handle_code(w) - verifier = yield w.verify() + yield w.establish_key() + def on_slow_connection(): + print(u"Key established, waiting for confirmation...", + file=self.args.stderr) + notify = self._reactor.callLater(VERIFY_TIMER, on_slow_connection) + try: + verifier = yield w.verify() + finally: + if not notify.called: + notify.cancel() self._show_verifier(verifier) want_offer = True diff --git a/src/wormhole/cli/cmd_send.py b/src/wormhole/cli/cmd_send.py index 6fa7446..2c46af8 100644 --- a/src/wormhole/cli/cmd_send.py +++ b/src/wormhole/cli/cmd_send.py @@ -12,6 +12,7 @@ from ..transit import TransitSender from ..util import dict_to_bytes, bytes_to_dict, bytes_to_hexstr APPID = u"lothar.com/wormhole/text-or-file-xfer" +VERIFY_TIMER = 1 def send(args, reactor=reactor): """I implement 'wormhole send'. I return a Deferred that fires with None @@ -85,8 +86,19 @@ class Sender: print(u"Wormhole code is: %s" % code, file=args.stdout) print(u"", file=args.stdout) + yield w.establish_key() + def on_slow_connection(): + print(u"Key established, waiting for confirmation...", + file=args.stderr) + notify = self._reactor.callLater(VERIFY_TIMER, on_slow_connection) + # TODO: don't stall on w.verify() unless they want it - verifier_bytes = yield w.verify() # this may raise WrongPasswordError + try: + verifier_bytes = yield w.verify() # this may raise WrongPasswordError + finally: + if not notify.called: + notify.cancel() + if args.verify: verifier = bytes_to_hexstr(verifier_bytes) while True: diff --git a/src/wormhole/test/test_scripts.py b/src/wormhole/test/test_scripts.py index e57f2fb..69f8a42 100644 --- a/src/wormhole/test/test_scripts.py +++ b/src/wormhole/test/test_scripts.py @@ -225,7 +225,7 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase): @inlineCallbacks def _do_test(self, as_subprocess=False, mode="text", addslash=False, override_filename=False): - assert mode in ("text", "file", "directory") + assert mode in ("text", "file", "directory", "slow-text") send_cfg = config("send") recv_cfg = config("receive") message = "blah blah blah ponies" @@ -244,7 +244,7 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase): receive_dir = self.mktemp() os.mkdir(receive_dir) - if mode == "text": + if mode == "text" or mode == "slow-text": send_cfg.text = message elif mode == "file": @@ -348,8 +348,14 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase): # The sender might fail, leaving the receiver hanging, or vice # versa. Make sure we don't wait on one side exclusively + if mode == "slow-text": + with mock.patch.object(cmd_send, "VERIFY_TIMER", 0), \ + mock.patch.object(cmd_receive, "VERIFY_TIMER", 0): + yield gatherResults([send_d, receive_d], True) + else: + yield gatherResults([send_d, receive_d], True) + - yield gatherResults([send_d, receive_d], True) send_stdout = send_cfg.stdout.getvalue() send_stderr = send_cfg.stderr.getvalue() receive_stdout = recv_cfg.stdout.getvalue() @@ -361,13 +367,21 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase): self.maxDiff = None # show full output for assertion failures - self.failUnlessEqual(send_stderr, "", - (send_stdout, send_stderr)) - self.failUnlessEqual(receive_stderr, "", - (receive_stdout, receive_stderr)) + if mode != "slow-text": + self.failUnlessEqual(send_stderr, "", + (send_stdout, send_stderr)) + self.failUnlessEqual(receive_stderr, "", + (receive_stdout, receive_stderr)) + else: + self.assertEqual(send_stderr, + "Key established, waiting for confirmation...\n", + (send_stdout, send_stderr)) + self.assertEqual(receive_stderr, + "Key established, waiting for confirmation...\n", + (receive_stdout, receive_stderr)) # check sender - if mode == "text": + if mode == "text" or mode == "slow-text": expected = ("Sending text message ({bytes:d} Bytes){NL}" "On the other computer, please run: " "wormhole receive{NL}" @@ -401,7 +415,7 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase): .format(NL=NL), send_stdout) # check receiver - if mode == "text": + if mode == "text" or mode == "slow-text": self.failUnlessEqual(receive_stdout, message+NL) elif mode == "file": self.failUnlessIn("Receiving file ({size:s}) into: {name}" @@ -448,6 +462,9 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase): def test_directory_override(self): return self._do_test(mode="directory", override_filename=True) + def test_slow_text(self): + return self._do_test(mode="slow-text") + @inlineCallbacks def _do_test_fail(self, mode, failmode): assert mode in ("file", "directory") diff --git a/src/wormhole/test/test_wormhole.py b/src/wormhole/test/test_wormhole.py index 67ad14f..a59280a 100644 --- a/src/wormhole/test/test_wormhole.py +++ b/src/wormhole/test/test_wormhole.py @@ -522,6 +522,45 @@ class Basic(unittest.TestCase): self.assertEqual(len(pieces), 3) # nameplate plus two words self.assert_(re.search(r'^\d+-\w+-\w+$', code), code) + def _test_establish_key_hook(self, established, before): + timing = DebugTiming() + w = wormhole._Wormhole(APPID, "relay_url", reactor, None, timing) + + if before: + d = w.establish_key() + + if established is True: + w._key = b"key" + elif established is False: + w._key = None + else: + w._key = b"key" + w._error = WelcomeError() + + if not before: + d = w.establish_key() + else: + w._maybe_notify_key() + + if w._key is not None and established is True: + self.successResultOf(d) + elif established is False: + self.assertNot(d.called) + else: + self.failureResultOf(d) + + def test_establish_key_hook(self): + for established in (True, False, "error"): + for before in (True, False): + self._test_establish_key_hook(established, before) + + def test_establish_key_twice(self): + timing = DebugTiming() + w = wormhole._Wormhole(APPID, "relay_url", reactor, None, timing) + d = w.establish_key() + self.assertRaises(InternalError, w.establish_key) + del d + # make sure verify() can be called both before and after the verifier is # computed diff --git a/src/wormhole/wormhole.py b/src/wormhole/wormhole.py index c3b0c6f..93d95b8 100644 --- a/src/wormhole/wormhole.py +++ b/src/wormhole/wormhole.py @@ -237,6 +237,8 @@ class _Wormhole: self._flag_need_to_see_mailbox_used = True self._flag_need_to_build_msg1 = True self._flag_need_to_send_PAKE = True + self._establish_key_called = False + self._key_waiter = None self._key = None self._version_message = None @@ -283,6 +285,14 @@ class _Wormhole: # todo: restore-saved-state entry points + def establish_key(self): + """ + returns a Deferred that fires when we've established the shared key. + When successful, the Deferred fires with a simple `True`, otherwise + it fails. + """ + return self._API_establish_key() + def verify(self): """Returns a Deferred that fires when we've heard back from the other side, and have confirmed that they used the right wormhole code. When @@ -559,6 +569,7 @@ class _Wormhole: def _event_established_key(self): self._timing.add("key established") + self._maybe_notify_key() # both sides send different (random) version messages self._send_version_message() @@ -569,6 +580,25 @@ class _Wormhole: self._maybe_check_version() self._maybe_send_phase_messages() + def _API_establish_key(self): + if self._error: return defer.fail(self._error) + if self._establish_key_called: raise InternalError + self._establish_key_called = True + if self._key is not None: + return defer.succeed(True) + self._key_waiter = defer.Deferred() + return self._key_waiter + + def _maybe_notify_key(self): + if self._key is None: + return + if self._error: + result = failure.Failure(self._error) + else: + result = True + if self._key_waiter and not self._key_waiter.called: + self._key_waiter.callback(result) + def _send_version_message(self): # this is encrypted like a normal phase message, and includes a # dictionary of version flags to let the other Wormhole know what @@ -824,6 +854,9 @@ class _Wormhole: if self._verifier_waiter and not self._verifier_waiter.called: if self.DEBUG: print("EB VW") self._verifier_waiter.errback(error) + if self._key_waiter and not self._key_waiter.called: + if self.DEBUG: print("EB KW") + self._key_waiter.errback(error) for d in self._receive_waiters.values(): if self.DEBUG: print("EB RW") d.errback(error)