diff --git a/src/wormhole/blocking/transcribe.py b/src/wormhole/blocking/transcribe.py index 63ef206..01fed57 100644 --- a/src/wormhole/blocking/transcribe.py +++ b/src/wormhole/blocking/transcribe.py @@ -214,6 +214,7 @@ class Wormhole: self.verifier = None self._sent_data = set() # phases self._got_data = set() + self._got_confirmation = False self._closed = False def __enter__(self): @@ -355,10 +356,20 @@ class Wormhole: if self._channel is None: raise UsageError self._got_data.add(phase) self._get_key() - data_key = self.derive_key(u"wormhole:phase:%s" % phase) - inbound_encrypted = self._channel.get(phase) + phases = [] + if not self._got_confirmation: + phases.append(u"_confirm") + phases.append(phase) + (got_phase, body) = self._channel.get_first_of(phases) + if got_phase == u"_confirm": + if body != self.derive_key(u"wormhole:confirmation"): + raise WrongPasswordError + self._got_confirmation = True + (got_phase, body) = self._channel.get_first_of([phase]) + assert got_phase == phase try: - inbound_data = self._decrypt_data(data_key, inbound_encrypted) + data_key = self.derive_key(u"wormhole:phase:%s" % phase) + inbound_data = self._decrypt_data(data_key, body) return inbound_data except CryptoError: raise WrongPasswordError diff --git a/src/wormhole/test/test_blocking.py b/src/wormhole/test/test_blocking.py index 5801e99..8556df5 100644 --- a/src/wormhole/test/test_blocking.py +++ b/src/wormhole/test/test_blocking.py @@ -3,7 +3,8 @@ import json from twisted.trial import unittest from twisted.internet.defer import gatherResults, succeed from twisted.internet.threads import deferToThread -from ..blocking.transcribe import Wormhole, UsageError, ChannelManager +from ..blocking.transcribe import (Wormhole, UsageError, ChannelManager, + WrongPasswordError) from ..blocking.eventsource import EventSourceFollower from .common import ServerBase @@ -246,6 +247,35 @@ class Blocking(ServerBase, unittest.TestCase): d.addCallback(_got_2) return d + def test_wrong_password(self): + w1 = Wormhole(APPID, self.relayurl) + w2 = Wormhole(APPID, self.relayurl) + + # make sure we can detect WrongPasswordError even if one side only + # does get_data() and not send_data(), like "wormhole receive" does + d = deferToThread(w1.get_code) + d.addCallback(lambda code: w2.set_code(code+"not")) + + # w2 can't throw WrongPasswordError until it sees a CONFIRM message, + # and w1 won't send CONFIRM until it sees a PAKE message, which w2 + # won't send until we call get_data. So we need both sides to be + # running at the same time for this test. + def _w1_sends(): + w1.send_data(b"data1") + def _w2_gets(): + self.assertRaises(WrongPasswordError, w2.get_data) + d.addCallback(lambda _: self.doBoth([_w1_sends], [_w2_gets])) + + # and now w1 should have enough information to throw too + d.addCallback(lambda _: deferToThread(self.assertRaises, + WrongPasswordError, w1.get_data)) + def _done(_): + # both sides are closed automatically upon error, but it's still + # legal to call .close(), and should be idempotent + return self.doBoth([w1.close], [w2.close]) + d.addCallback(_done) + return d + def test_verifier(self): w1 = Wormhole(APPID, self.relayurl) w2 = Wormhole(APPID, self.relayurl) diff --git a/src/wormhole/test/test_twisted.py b/src/wormhole/test/test_twisted.py index fd20498..f522b06 100644 --- a/src/wormhole/test/test_twisted.py +++ b/src/wormhole/test/test_twisted.py @@ -2,7 +2,8 @@ from __future__ import print_function import sys, json from twisted.trial import unittest from twisted.internet.defer import gatherResults, succeed -from ..twisted.transcribe import Wormhole, UsageError, ChannelManager +from ..twisted.transcribe import (Wormhole, UsageError, ChannelManager, + WrongPasswordError) from ..twisted.eventsource_twisted import EventSourceParser from .common import ServerBase @@ -229,6 +230,32 @@ class Basic(ServerBase, unittest.TestCase): d.addCallback(_got_2) return d + def test_wrong_password(self): + w1 = Wormhole(APPID, self.relayurl) + w2 = Wormhole(APPID, self.relayurl) + d = w1.get_code() + d.addCallback(lambda code: w2.set_code(code+"not")) + + # w2 can't throw WrongPasswordError until it sees a CONFIRM message, + # and w1 won't send CONFIRM until it sees a PAKE message, which w2 + # won't send until we call get_data. So we need both sides to be + # running at the same time for this test. + def _w1_sends(): + return w1.send_data(b"data1") + def _w2_gets(): + return self.assertFailure(w2.get_data(), WrongPasswordError) + d.addCallback(lambda _: self.doBoth(_w1_sends(), _w2_gets())) + + # and now w1 should have enough information to throw too + d.addCallback(lambda _: self.assertFailure(w1.get_data(), + WrongPasswordError)) + def _done(_): + # both sides are closed automatically upon error, but it's still + # legal to call .close(), and should be idempotent + return self.doBoth(w1.close(), w2.close()) + d.addCallback(_done) + return d + def test_verifier(self): w1 = Wormhole(APPID, self.relayurl) w2 = Wormhole(APPID, self.relayurl) diff --git a/src/wormhole/twisted/transcribe.py b/src/wormhole/twisted/transcribe.py index 5a36d90..d92123b 100644 --- a/src/wormhole/twisted/transcribe.py +++ b/src/wormhole/twisted/transcribe.py @@ -206,6 +206,7 @@ class Wormhole: self._started_get_code = False self._sent_data = set() # phases self._got_data = set() + self._got_confirmation = False def _set_side(self, side): self._side = side @@ -375,16 +376,30 @@ class Wormhole: self._got_data.add(phase) d = self._get_key() def _get(key): - data_key = self.derive_key(u"wormhole:phase:%s" % phase) - d1 = self._channel.get(phase) - def _decrypt(inbound_encrypted): + phases = [] + if not self._got_confirmation: + phases.append(u"_confirm") + phases.append(phase) + d1 = self._channel.get_first_of(phases) + def _maybe_got_confirm(phase_and_body): + (got_phase, body) = phase_and_body + if got_phase == u"_confirm": + if body != self.derive_key(u"wormhole:confirmation"): + raise WrongPasswordError + self._got_confirmation = True + return self._channel.get_first_of([phase]) + return phase_and_body + d1.addCallback(_maybe_got_confirm) + def _got(phase_and_body): + (got_phase, body) = phase_and_body + assert got_phase == phase try: - inbound_data = self._decrypt_data(data_key, - inbound_encrypted) + data_key = self.derive_key(u"wormhole:phase:%s" % phase) + inbound_data = self._decrypt_data(data_key, body) return inbound_data except CryptoError: raise WrongPasswordError - d1.addCallback(_decrypt) + d1.addCallback(_got) return d1 d.addCallback(_get) return d