twisted.transit: implement Deferred-based receive_record()

This commit is contained in:
Brian Warner 2016-02-15 11:40:21 -08:00
parent fb1461fa8c
commit 763d72f582
2 changed files with 71 additions and 7 deletions

View File

@ -1,7 +1,7 @@
from __future__ import print_function
from binascii import hexlify, unhexlify
from twisted.trial import unittest
from twisted.internet import defer, task, endpoints, protocol, address
from twisted.internet import defer, task, endpoints, protocol, address, error
from twisted.internet.defer import gatherResults, inlineCallbacks
from twisted.python import log, failure
from ..twisted import transit
@ -280,6 +280,7 @@ class DummyProtocol(protocol.Protocol):
self._d2.callback(None)
class FakeTransport:
signalConnectionLost = True
def __init__(self, p, peeraddr):
self.protocol = p
self._peeraddr = peeraddr
@ -289,6 +290,7 @@ class FakeTransport:
self._buf += data
def loseConnection(self):
self._connected = False
if self.signalConnectionLost:
self.protocol.connectionLost()
def getPeer(self):
return self._peeraddr
@ -950,6 +952,46 @@ class Connection(unittest.TestCase):
# happens? We currently get a type-check assertion from HKDF because
# the key is None.
def test_receive_queue(self):
c = transit.Connection(None, None, None)
c.transport = FakeTransport(c, None)
c.transport.signalConnectionLost = False
results = [[] for i in range(5)]
c.recordReceived(b"0")
c.recordReceived(b"1")
c.recordReceived(b"2")
c.receive_record().addBoth(results[0].append)
self.assertEqual(results[0], [b"0"])
d1 = c.receive_record()
d2 = c.receive_record()
# they must fire in order of receipt, not order of addCallback
d2.addBoth(results[2].append)
self.assertEqual(results[2], [b"2"])
d1.addBoth(results[1].append)
self.assertEqual(results[1], [b"1"])
c.receive_record().addBoth(results[3].append)
c.receive_record().addBoth(results[4].append)
self.assertEqual(results[3], [])
self.assertEqual(results[4], [])
c.recordReceived(b"3")
self.assertEqual(results[3], [b"3"])
self.assertEqual(results[4], [])
c.recordReceived(b"4")
self.assertEqual(results[3], [b"3"])
self.assertEqual(results[4], [b"4"])
closed = []
c.receive_record().addBoth(closed.append)
c.close()
self.assertEqual(len(closed), 1)
f = closed[0]
self.assertIsInstance(f, failure.Failure)
self.assertIsInstance(f.value, error.ConnectionClosed)
DIRECT_HINT = u"tcp:direct:1234"
RELAY_HINT = u"tcp:relay:1234"
UNUSABLE_HINT = u"unusable:foo:bar"
@ -1098,8 +1140,7 @@ class Full(unittest.TestCase):
self.assertIsInstance(x, transit.Connection)
self.assertIsInstance(y, transit.Connection)
d = defer.Deferred()
y.recordReceived = d.callback
d = y.receive_record()
x.send_record(b"record1")
r = yield d

View File

@ -1,10 +1,10 @@
from __future__ import print_function
import sys, time, socket
import sys, time, socket, collections
from binascii import hexlify, unhexlify
from zope.interface import implementer
from twisted.python.runtime import platformType
from twisted.internet import (reactor, interfaces, defer, protocol,
endpoints, task, address)
endpoints, task, address, error)
from twisted.protocols import policies
from nacl.secret import SecretBox
from ..util import ipaddrs
@ -34,6 +34,8 @@ class Connection(protocol.Protocol, policies.TimeoutMixin):
self.start = start
self._negotiation_d = defer.Deferred(self._cancel)
self._error = None
self._inbound_records = collections.deque()
self._waiting_reads = collections.deque()
def connectionMade(self):
debug("handle %r" % (self.transport,))
@ -171,10 +173,31 @@ class Connection(protocol.Protocol, policies.TimeoutMixin):
self.transport.write(encrypted)
def recordReceived(self, record):
pass
self._inbound_records.append(record)
self._deliverRecords()
self._checkPause()
def _checkPause(self):
return # TODO
def receive_record(self):
d = defer.Deferred()
self._waiting_reads.append(d)
self._deliverRecords()
return d
def _deliverRecords(self):
while self._inbound_records and self._waiting_reads:
r = self._inbound_records.popleft()
d = self._waiting_reads.popleft()
d.callback(r)
self._checkPause()
def close(self):
self.transport.loseConnection()
while self._waiting_reads:
d = self._waiting_reads.popleft()
d.errback(error.ConnectionClosed())
def timeoutConnection(self):
self._error = BadHandshake("timeout")