twisted.transit: implement Deferred-based receive_record()

This commit is contained in:
Brian Warner 2016-02-15 11:40:21 -08:00
parent fb1461fa8c
commit 763d72f582
2 changed files with 71 additions and 7 deletions

View File

@ -1,7 +1,7 @@
from __future__ import print_function from __future__ import print_function
from binascii import hexlify, unhexlify from binascii import hexlify, unhexlify
from twisted.trial import unittest from twisted.trial import unittest
from twisted.internet import defer, task, endpoints, protocol, address from twisted.internet import defer, task, endpoints, protocol, address, error
from twisted.internet.defer import gatherResults, inlineCallbacks from twisted.internet.defer import gatherResults, inlineCallbacks
from twisted.python import log, failure from twisted.python import log, failure
from ..twisted import transit from ..twisted import transit
@ -280,6 +280,7 @@ class DummyProtocol(protocol.Protocol):
self._d2.callback(None) self._d2.callback(None)
class FakeTransport: class FakeTransport:
signalConnectionLost = True
def __init__(self, p, peeraddr): def __init__(self, p, peeraddr):
self.protocol = p self.protocol = p
self._peeraddr = peeraddr self._peeraddr = peeraddr
@ -289,6 +290,7 @@ class FakeTransport:
self._buf += data self._buf += data
def loseConnection(self): def loseConnection(self):
self._connected = False self._connected = False
if self.signalConnectionLost:
self.protocol.connectionLost() self.protocol.connectionLost()
def getPeer(self): def getPeer(self):
return self._peeraddr return self._peeraddr
@ -950,6 +952,46 @@ class Connection(unittest.TestCase):
# happens? We currently get a type-check assertion from HKDF because # happens? We currently get a type-check assertion from HKDF because
# the key is None. # the key is None.
def test_receive_queue(self):
c = transit.Connection(None, None, None)
c.transport = FakeTransport(c, None)
c.transport.signalConnectionLost = False
results = [[] for i in range(5)]
c.recordReceived(b"0")
c.recordReceived(b"1")
c.recordReceived(b"2")
c.receive_record().addBoth(results[0].append)
self.assertEqual(results[0], [b"0"])
d1 = c.receive_record()
d2 = c.receive_record()
# they must fire in order of receipt, not order of addCallback
d2.addBoth(results[2].append)
self.assertEqual(results[2], [b"2"])
d1.addBoth(results[1].append)
self.assertEqual(results[1], [b"1"])
c.receive_record().addBoth(results[3].append)
c.receive_record().addBoth(results[4].append)
self.assertEqual(results[3], [])
self.assertEqual(results[4], [])
c.recordReceived(b"3")
self.assertEqual(results[3], [b"3"])
self.assertEqual(results[4], [])
c.recordReceived(b"4")
self.assertEqual(results[3], [b"3"])
self.assertEqual(results[4], [b"4"])
closed = []
c.receive_record().addBoth(closed.append)
c.close()
self.assertEqual(len(closed), 1)
f = closed[0]
self.assertIsInstance(f, failure.Failure)
self.assertIsInstance(f.value, error.ConnectionClosed)
DIRECT_HINT = u"tcp:direct:1234" DIRECT_HINT = u"tcp:direct:1234"
RELAY_HINT = u"tcp:relay:1234" RELAY_HINT = u"tcp:relay:1234"
UNUSABLE_HINT = u"unusable:foo:bar" UNUSABLE_HINT = u"unusable:foo:bar"
@ -1098,8 +1140,7 @@ class Full(unittest.TestCase):
self.assertIsInstance(x, transit.Connection) self.assertIsInstance(x, transit.Connection)
self.assertIsInstance(y, transit.Connection) self.assertIsInstance(y, transit.Connection)
d = defer.Deferred() d = y.receive_record()
y.recordReceived = d.callback
x.send_record(b"record1") x.send_record(b"record1")
r = yield d r = yield d

View File

@ -1,10 +1,10 @@
from __future__ import print_function from __future__ import print_function
import sys, time, socket import sys, time, socket, collections
from binascii import hexlify, unhexlify from binascii import hexlify, unhexlify
from zope.interface import implementer from zope.interface import implementer
from twisted.python.runtime import platformType from twisted.python.runtime import platformType
from twisted.internet import (reactor, interfaces, defer, protocol, from twisted.internet import (reactor, interfaces, defer, protocol,
endpoints, task, address) endpoints, task, address, error)
from twisted.protocols import policies from twisted.protocols import policies
from nacl.secret import SecretBox from nacl.secret import SecretBox
from ..util import ipaddrs from ..util import ipaddrs
@ -34,6 +34,8 @@ class Connection(protocol.Protocol, policies.TimeoutMixin):
self.start = start self.start = start
self._negotiation_d = defer.Deferred(self._cancel) self._negotiation_d = defer.Deferred(self._cancel)
self._error = None self._error = None
self._inbound_records = collections.deque()
self._waiting_reads = collections.deque()
def connectionMade(self): def connectionMade(self):
debug("handle %r" % (self.transport,)) debug("handle %r" % (self.transport,))
@ -171,10 +173,31 @@ class Connection(protocol.Protocol, policies.TimeoutMixin):
self.transport.write(encrypted) self.transport.write(encrypted)
def recordReceived(self, record): def recordReceived(self, record):
pass self._inbound_records.append(record)
self._deliverRecords()
self._checkPause()
def _checkPause(self):
return # TODO
def receive_record(self):
d = defer.Deferred()
self._waiting_reads.append(d)
self._deliverRecords()
return d
def _deliverRecords(self):
while self._inbound_records and self._waiting_reads:
r = self._inbound_records.popleft()
d = self._waiting_reads.popleft()
d.callback(r)
self._checkPause()
def close(self): def close(self):
self.transport.loseConnection() self.transport.loseConnection()
while self._waiting_reads:
d = self._waiting_reads.popleft()
d.errback(error.ConnectionClosed())
def timeoutConnection(self): def timeoutConnection(self):
self._error = BadHandshake("timeout") self._error = BadHandshake("timeout")