add more assertions around transit_key

This commit is contained in:
Brian Warner 2016-02-17 17:21:52 -08:00
parent 3ffceff9d5
commit 7ceffd783a
2 changed files with 9 additions and 3 deletions

View File

@ -478,18 +478,21 @@ class Common:
self._their_relay_hints = set(hints) self._their_relay_hints = set(hints)
def _send_this(self): def _send_this(self):
assert self._transit_key
if self.is_sender: if self.is_sender:
return build_sender_handshake(self._transit_key) return build_sender_handshake(self._transit_key)
else: else:
return build_receiver_handshake(self._transit_key) return build_receiver_handshake(self._transit_key)
def _expect_this(self): def _expect_this(self):
assert self._transit_key
if self.is_sender: if self.is_sender:
return build_receiver_handshake(self._transit_key) return build_receiver_handshake(self._transit_key)
else: else:
return build_sender_handshake(self._transit_key)# + b"go\n" return build_sender_handshake(self._transit_key)# + b"go\n"
def _sender_record_key(self): def _sender_record_key(self):
assert self._transit_key
if self.is_sender: if self.is_sender:
return HKDF(self._transit_key, SecretBox.KEY_SIZE, return HKDF(self._transit_key, SecretBox.KEY_SIZE,
CTXinfo=b"transit_record_sender_key") CTXinfo=b"transit_record_sender_key")
@ -498,6 +501,7 @@ class Common:
CTXinfo=b"transit_record_receiver_key") CTXinfo=b"transit_record_receiver_key")
def _receiver_record_key(self): def _receiver_record_key(self):
assert self._transit_key
if self.is_sender: if self.is_sender:
return HKDF(self._transit_key, SecretBox.KEY_SIZE, return HKDF(self._transit_key, SecretBox.KEY_SIZE,
CTXinfo=b"transit_record_receiver_key") CTXinfo=b"transit_record_receiver_key")
@ -506,6 +510,7 @@ class Common:
CTXinfo=b"transit_record_sender_key") CTXinfo=b"transit_record_sender_key")
def set_transit_key(self, key): def set_transit_key(self, key):
assert isinstance(key, type(b"")), type(key)
# We use pubsub to protect against the race where the sender knows # We use pubsub to protect against the race where the sender knows
# the hints and the key, and connects to the receiver's transit # the hints and the key, and connects to the receiver's transit
# socket before the receiver gets the relay message (and thus the # socket before the receiver gets the relay message (and thus the
@ -589,6 +594,7 @@ class Common:
def _start_connector(self, ep, description, is_relay=False): def _start_connector(self, ep, description, is_relay=False):
relay_handshake = None relay_handshake = None
if is_relay: if is_relay:
assert self._transit_key
relay_handshake = build_relay_handshake(self._transit_key) relay_handshake = build_relay_handshake(self._transit_key)
f = OutboundConnectionFactory(self, relay_handshake) f = OutboundConnectionFactory(self, relay_handshake)
d = ep.connect(f) d = ep.connect(f)

View File

@ -4,9 +4,9 @@ import six
def HKDF(SKM, dkLen, XTS=None, CTXinfo=b"", digest=sha256, def HKDF(SKM, dkLen, XTS=None, CTXinfo=b"", digest=sha256,
_test_expected_PRK=None): _test_expected_PRK=None):
assert isinstance(SKM, six.binary_type) assert isinstance(SKM, six.binary_type), type(SKM)
assert isinstance(XTS, (six.binary_type,type(None))) assert isinstance(XTS, (six.binary_type,type(None))), type(XTS)
assert isinstance(CTXinfo, six.binary_type) assert isinstance(CTXinfo, six.binary_type), type(CTXinfo)
hlen = len(digest(b"").digest()) hlen = len(digest(b"").digest())
assert dkLen <= hlen*255 assert dkLen <= hlen*255
if XTS is None: if XTS is None: