From 30ab9400342088118b50034650a123afbf28626e Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Tue, 24 May 2016 13:47:15 -0700 Subject: [PATCH] INCOMPATIBLE: change derivation of phase keys to include side Previously the encryption key used for "phase messages" (anything sent from one side to the other, protected by the shared PAKE-generated session key) was derived just from the session key and the phase name. The two sides would use the same key for their first message (but with random, thus different, nonces). This uses the sending side's string (a random 5-byte/10-character hex string) in the derivation process too, so the two sides use different keys. This gives us an easy way to reject reflected messages. We already ignore messages that claim to use a "side" which matches our own (to ignore server echoes of our own outbound messages). With this change, an attacker (or the server) can't swap in the payload of an outbound message, change the "side" to make it look like a peer message, and then let us decrypt it correctly. It also changes the derivation function to combine the phase and side values safely. This didn't matter much when we only had one externally-provided string, but with two, there's an opportunity for format confusion if they were combined with a simple delimiter. Now we hash both values before concatenating them. This breaks interoperability with clients from before this change. They will always get WrongPasswordErrors. --- src/wormhole/test/test_wormhole.py | 11 ++++++----- src/wormhole/wormhole.py | 30 +++++++++++++++++------------- 2 files changed, 23 insertions(+), 18 deletions(-) diff --git a/src/wormhole/test/test_wormhole.py b/src/wormhole/test/test_wormhole.py index c9e9a0f..8948535 100644 --- a/src/wormhole/test/test_wormhole.py +++ b/src/wormhole/test/test_wormhole.py @@ -251,7 +251,7 @@ class Basic(unittest.TestCase): self.check_out(out[0], type=u"add", phase=u"0") # decrypt+check the outbound message p0_outbound = unhexlify(out[0][u"body"].encode("ascii")) - msgkey0 = w.derive_key(u"wormhole:phase:0", SecretBox.KEY_SIZE) + msgkey0 = w._derive_phase_key(w._side, u"0") p0_plaintext = w._decrypt_data(msgkey0, p0_outbound) self.assertEqual(p0_plaintext, b"phase0-outbound") @@ -260,7 +260,8 @@ class Basic(unittest.TestCase): self.assertNoResult(md) self.assertIn(u"0", w._receive_waiters) self.assertNotIn(u"0", w._received_messages) - p0_inbound = w._encrypt_data(msgkey0, b"phase0-inbound") + msgkey1 = w._derive_phase_key(side2, u"0") + p0_inbound = w._encrypt_data(msgkey1, b"phase0-inbound") p0_inbound_hex = hexlify(p0_inbound).decode("ascii") response(w, type=u"message", phase=u"0", body=p0_inbound_hex, side=side2) @@ -270,8 +271,8 @@ class Basic(unittest.TestCase): self.assertIn(u"0", w._received_messages) # receiving an inbound message will queue it until get() is called - msgkey1 = w.derive_key(u"wormhole:phase:1", SecretBox.KEY_SIZE) - p1_inbound = w._encrypt_data(msgkey1, b"phase1-inbound") + msgkey2 = w._derive_phase_key(side2, u"1") + p1_inbound = w._encrypt_data(msgkey2, b"phase1-inbound") p1_inbound_hex = hexlify(p1_inbound).decode("ascii") response(w, type=u"message", phase=u"1", body=p1_inbound_hex, side=side2) @@ -433,7 +434,7 @@ class Basic(unittest.TestCase): response(w, type=u"claimed", mailbox=u"mb456") w._key = b"" - msgkey = w.derive_key(u"wormhole:phase:misc", SecretBox.KEY_SIZE) + msgkey = w._derive_phase_key(u"side2", u"misc") p1_inbound = w._encrypt_data(msgkey, b"") p1_inbound_hex = hexlify(p1_inbound).decode("ascii") response(w, type=u"message", phase=u"misc", side=u"side2", diff --git a/src/wormhole/wormhole.py b/src/wormhole/wormhole.py index c56148d..d84af3d 100644 --- a/src/wormhole/wormhole.py +++ b/src/wormhole/wormhole.py @@ -11,6 +11,7 @@ from nacl.secret import SecretBox from nacl.exceptions import CryptoError from nacl import utils from spake2 import SPAKE2_Symmetric +from hashlib import sha256 from . import __version__ from . import codes #from .errors import ServerError, Timeout @@ -592,10 +593,14 @@ class _Wormhole: with self._timing.add("API send", phase=phase): self._maybe_send_phase_messages() - #def _derive_phase_key(self, side, phase): - def _derive_phase_key(self, phase): - assert isinstance(phase, type(b"")), type(phase) - purpose = b"wormhole:phase:" + phase + def _derive_phase_key(self, side, phase): + assert isinstance(side, type(u"")), type(side) + assert isinstance(phase, type(u"")), type(phase) + side_bytes = side.encode("ascii") + phase_bytes = phase.encode("ascii") + purpose = (b"wormhole:phase:" + + sha256(side_bytes).digest() + + sha256(phase_bytes).digest()) return self._derive_key(purpose) def _maybe_send_phase_messages(self): @@ -607,12 +612,12 @@ class _Wormhole: plaintexts = self._plaintext_to_send self._plaintext_to_send = [] for pm in plaintexts: - (phase, plaintext) = pm - assert isinstance(phase, int), type(phase) - phase_bytes = (u"%d" % phase).encode("ascii") - data_key = self._derive_phase_key(phase_bytes) + (phase_int, plaintext) = pm + assert isinstance(phase_int, int), type(phase_int) + phase = u"%d" % phase_int + data_key = self._derive_phase_key(self._side, phase) encrypted = self._encrypt_data(data_key, plaintext) - self._msg_send(u"%d" % phase, encrypted) + self._msg_send(phase, encrypted) def _encrypt_data(self, key, data): # Without predefined roles, we can't derive predictably unique keys @@ -663,9 +668,9 @@ class _Wormhole: body = unhexlify(msg["body"].encode("ascii")) if side == self._side: return - self._event_received_peer_message(phase, body) + self._event_received_peer_message(side, phase, body) - def _event_received_peer_message(self, phase, body): + def _event_received_peer_message(self, side, phase, body): # any message in the mailbox means we no longer need the nameplate self._event_mailbox_used() #if phase in self._received_messages: @@ -682,9 +687,8 @@ class _Wormhole: # It's a phase message, aimed at the application above us. Decrypt # and deliver upstairs, notifying anyone waiting on it - phase_bytes = phase.encode("ascii") try: - data_key = self._derive_phase_key(phase_bytes) + data_key = self._derive_phase_key(side, phase) plaintext = self._decrypt_data(data_key, body) except CryptoError: e = WrongPasswordError()