diff --git a/src/wormhole/wormhole.py b/src/wormhole/wormhole.py index 95575bc..c56148d 100644 --- a/src/wormhole/wormhole.py +++ b/src/wormhole/wormhole.py @@ -534,7 +534,7 @@ class _Wormhole: self._event_established_key() def _derive_confirmation_key(self): - return self._derive_key(u"wormhole:confirmation") + return self._derive_key(b"wormhole:confirmation") def _event_established_key(self): self._timing.add("key established") @@ -545,7 +545,7 @@ class _Wormhole: confmsg = make_confmsg(confkey, nonce) self._msg_send(u"confirm", confmsg) - verifier = self._derive_key(u"wormhole:verifier") + verifier = self._derive_key(b"wormhole:verifier") self._event_computed_verifier(verifier) self._maybe_send_phase_messages() @@ -594,7 +594,9 @@ class _Wormhole: #def _derive_phase_key(self, side, phase): def _derive_phase_key(self, phase): - return self._derive_key(u"wormhole:phase:%s" % phase) + assert isinstance(phase, type(b"")), type(phase) + purpose = b"wormhole:phase:" + phase + return self._derive_key(purpose) def _maybe_send_phase_messages(self): # TODO: deal with reentrant call @@ -607,7 +609,8 @@ class _Wormhole: for pm in plaintexts: (phase, plaintext) = pm assert isinstance(phase, int), type(phase) - data_key = self._derive_phase_key(u"%d" % phase) + phase_bytes = (u"%d" % phase).encode("ascii") + data_key = self._derive_phase_key(phase_bytes) encrypted = self._encrypt_data(data_key, plaintext) self._msg_send(u"%d" % phase, encrypted) @@ -644,13 +647,14 @@ class _Wormhole: def _API_derive_key(self, purpose, length): if self._error: raise self._error - return self._derive_key(purpose, length) + if not isinstance(purpose, type(u"")): raise TypeError(type(purpose)) + return self._derive_key(to_bytes(purpose), length) def _derive_key(self, purpose, length=SecretBox.KEY_SIZE): - if not isinstance(purpose, type(u"")): raise TypeError(type(purpose)) + if not isinstance(purpose, type(b"")): raise TypeError(type(purpose)) if self._key is None: raise UsageError # call derive_key after get_verifier() or get() - return HKDF(self._key, length, CTXinfo=to_bytes(purpose)) + return HKDF(self._key, length, CTXinfo=purpose) def _response_handle_message(self, msg): side = msg["side"] @@ -678,8 +682,9 @@ class _Wormhole: # It's a phase message, aimed at the application above us. Decrypt # and deliver upstairs, notifying anyone waiting on it + phase_bytes = phase.encode("ascii") try: - data_key = self._derive_phase_key(phase) + data_key = self._derive_phase_key(phase_bytes) plaintext = self._decrypt_data(data_key, body) except CryptoError: e = WrongPasswordError()