diff --git a/src/wormhole/twisted/transit.py b/src/wormhole/twisted/transit.py index fc3fea5..e5c684f 100644 --- a/src/wormhole/twisted/transit.py +++ b/src/wormhole/twisted/transit.py @@ -478,18 +478,21 @@ class Common: self._their_relay_hints = set(hints) def _send_this(self): + assert self._transit_key if self.is_sender: return build_sender_handshake(self._transit_key) else: return build_receiver_handshake(self._transit_key) def _expect_this(self): + assert self._transit_key if self.is_sender: return build_receiver_handshake(self._transit_key) else: return build_sender_handshake(self._transit_key)# + b"go\n" def _sender_record_key(self): + assert self._transit_key if self.is_sender: return HKDF(self._transit_key, SecretBox.KEY_SIZE, CTXinfo=b"transit_record_sender_key") @@ -498,6 +501,7 @@ class Common: CTXinfo=b"transit_record_receiver_key") def _receiver_record_key(self): + assert self._transit_key if self.is_sender: return HKDF(self._transit_key, SecretBox.KEY_SIZE, CTXinfo=b"transit_record_receiver_key") @@ -506,6 +510,7 @@ class Common: CTXinfo=b"transit_record_sender_key") def set_transit_key(self, key): + assert isinstance(key, type(b"")), type(key) # We use pubsub to protect against the race where the sender knows # the hints and the key, and connects to the receiver's transit # socket before the receiver gets the relay message (and thus the @@ -589,6 +594,7 @@ class Common: def _start_connector(self, ep, description, is_relay=False): relay_handshake = None if is_relay: + assert self._transit_key relay_handshake = build_relay_handshake(self._transit_key) f = OutboundConnectionFactory(self, relay_handshake) d = ep.connect(f) diff --git a/src/wormhole/util/hkdf.py b/src/wormhole/util/hkdf.py index 3f6f9ad..c3a477d 100644 --- a/src/wormhole/util/hkdf.py +++ b/src/wormhole/util/hkdf.py @@ -4,9 +4,9 @@ import six def HKDF(SKM, dkLen, XTS=None, CTXinfo=b"", digest=sha256, _test_expected_PRK=None): - assert isinstance(SKM, six.binary_type) - assert isinstance(XTS, (six.binary_type,type(None))) - assert isinstance(CTXinfo, six.binary_type) + assert isinstance(SKM, six.binary_type), type(SKM) + assert isinstance(XTS, (six.binary_type,type(None))), type(XTS) + assert isinstance(CTXinfo, six.binary_type), type(CTXinfo) hlen = len(digest(b"").digest()) assert dkLen <= hlen*255 if XTS is None: