twisted.transit: handle multiple records in one chunk

I made the classic dataReceived() mistake, and exited the function after
delivering the first record. Keep at it until there are no complete
records left.
This commit is contained in:
Brian Warner 2016-03-02 00:36:00 -08:00
parent 84749dd8b3
commit 7234e25897
2 changed files with 25 additions and 10 deletions

View File

@ -867,9 +867,9 @@ class Connection(unittest.TestCase):
# and that we can receive records properly # and that we can receive records properly
inbound_records = [] inbound_records = []
c.recordReceived = inbound_records.append c.recordReceived = inbound_records.append
send_box = SecretBox(owner._receiver_record_key())
RECORD3 = b"record3" RECORD3 = b"record3"
send_box = SecretBox(owner._receiver_record_key())
nonce_buf = unhexlify("%048x" % 0) # first nonce must be 0 nonce_buf = unhexlify("%048x" % 0) # first nonce must be 0
encrypted = send_box.encrypt(RECORD3, nonce_buf) encrypted = send_box.encrypt(RECORD3, nonce_buf)
length = unhexlify("%08x" % len(encrypted)) # always 4 bytes long length = unhexlify("%08x" % len(encrypted)) # always 4 bytes long
@ -881,7 +881,6 @@ class Connection(unittest.TestCase):
self.assertEqual(inbound_records, [RECORD3]) self.assertEqual(inbound_records, [RECORD3])
RECORD4 = b"record4" RECORD4 = b"record4"
send_box = SecretBox(owner._receiver_record_key())
nonce_buf = unhexlify("%048x" % 1) # nonces increment nonce_buf = unhexlify("%048x" % 1) # nonces increment
encrypted = send_box.encrypt(RECORD4, nonce_buf) encrypted = send_box.encrypt(RECORD4, nonce_buf)
length = unhexlify("%08x" % len(encrypted)) # always 4 bytes long length = unhexlify("%08x" % len(encrypted)) # always 4 bytes long
@ -892,6 +891,21 @@ class Connection(unittest.TestCase):
c.dataReceived(encrypted[-2:]) c.dataReceived(encrypted[-2:])
self.assertEqual(inbound_records, [RECORD3, RECORD4]) self.assertEqual(inbound_records, [RECORD3, RECORD4])
# receiving two records at the same time: deliver both
inbound_records[:] = []
RECORD5 = b"record5"
nonce_buf = unhexlify("%048x" % 2) # nonces increment
encrypted = send_box.encrypt(RECORD5, nonce_buf)
length = unhexlify("%08x" % len(encrypted)) # always 4 bytes long
r5 = length+encrypted
RECORD6 = b"record6"
nonce_buf = unhexlify("%048x" % 3) # nonces increment
encrypted = send_box.encrypt(RECORD6, nonce_buf)
length = unhexlify("%08x" % len(encrypted)) # always 4 bytes long
r6 = length+encrypted
c.dataReceived(r5+r6)
self.assertEqual(inbound_records, [RECORD5, RECORD6])
def corrupt(self, orig): def corrupt(self, orig):
last_byte = orig[-1:] last_byte = orig[-1:]
num = int(hexlify(last_byte).decode("ascii"), 16) num = int(hexlify(last_byte).decode("ascii"), 16)

View File

@ -143,15 +143,16 @@ class Connection(protocol.Protocol, policies.TimeoutMixin):
d.callback(self) d.callback(self)
def dataReceivedRECORDS(self): def dataReceivedRECORDS(self):
if len(self.buf) < 4: while True:
return if len(self.buf) < 4:
length = int(hexlify(self.buf[:4]), 16) return
if len(self.buf) < 4+length: length = int(hexlify(self.buf[:4]), 16)
return if len(self.buf) < 4+length:
encrypted, self.buf = self.buf[4:4+length], self.buf[4+length:] return
encrypted, self.buf = self.buf[4:4+length], self.buf[4+length:]
record = self._decrypt_record(encrypted) record = self._decrypt_record(encrypted)
self.recordReceived(record) self.recordReceived(record)
def _decrypt_record(self, encrypted): def _decrypt_record(self, encrypted):
nonce_buf = encrypted[:SecretBox.NONCE_SIZE] # assume it's prepended nonce_buf = encrypted[:SecretBox.NONCE_SIZE] # assume it's prepended