twisted.transit: implement Deferred-based receive_record()
This commit is contained in:
parent
fb1461fa8c
commit
763d72f582
|
@ -1,7 +1,7 @@
|
|||
from __future__ import print_function
|
||||
from binascii import hexlify, unhexlify
|
||||
from twisted.trial import unittest
|
||||
from twisted.internet import defer, task, endpoints, protocol, address
|
||||
from twisted.internet import defer, task, endpoints, protocol, address, error
|
||||
from twisted.internet.defer import gatherResults, inlineCallbacks
|
||||
from twisted.python import log, failure
|
||||
from ..twisted import transit
|
||||
|
@ -280,6 +280,7 @@ class DummyProtocol(protocol.Protocol):
|
|||
self._d2.callback(None)
|
||||
|
||||
class FakeTransport:
|
||||
signalConnectionLost = True
|
||||
def __init__(self, p, peeraddr):
|
||||
self.protocol = p
|
||||
self._peeraddr = peeraddr
|
||||
|
@ -289,7 +290,8 @@ class FakeTransport:
|
|||
self._buf += data
|
||||
def loseConnection(self):
|
||||
self._connected = False
|
||||
self.protocol.connectionLost()
|
||||
if self.signalConnectionLost:
|
||||
self.protocol.connectionLost()
|
||||
def getPeer(self):
|
||||
return self._peeraddr
|
||||
|
||||
|
@ -950,6 +952,46 @@ class Connection(unittest.TestCase):
|
|||
# happens? We currently get a type-check assertion from HKDF because
|
||||
# the key is None.
|
||||
|
||||
def test_receive_queue(self):
|
||||
c = transit.Connection(None, None, None)
|
||||
c.transport = FakeTransport(c, None)
|
||||
c.transport.signalConnectionLost = False
|
||||
results = [[] for i in range(5)]
|
||||
c.recordReceived(b"0")
|
||||
c.recordReceived(b"1")
|
||||
c.recordReceived(b"2")
|
||||
c.receive_record().addBoth(results[0].append)
|
||||
self.assertEqual(results[0], [b"0"])
|
||||
d1 = c.receive_record()
|
||||
d2 = c.receive_record()
|
||||
# they must fire in order of receipt, not order of addCallback
|
||||
d2.addBoth(results[2].append)
|
||||
self.assertEqual(results[2], [b"2"])
|
||||
d1.addBoth(results[1].append)
|
||||
self.assertEqual(results[1], [b"1"])
|
||||
|
||||
c.receive_record().addBoth(results[3].append)
|
||||
c.receive_record().addBoth(results[4].append)
|
||||
self.assertEqual(results[3], [])
|
||||
self.assertEqual(results[4], [])
|
||||
|
||||
c.recordReceived(b"3")
|
||||
self.assertEqual(results[3], [b"3"])
|
||||
self.assertEqual(results[4], [])
|
||||
|
||||
c.recordReceived(b"4")
|
||||
self.assertEqual(results[3], [b"3"])
|
||||
self.assertEqual(results[4], [b"4"])
|
||||
|
||||
closed = []
|
||||
c.receive_record().addBoth(closed.append)
|
||||
c.close()
|
||||
self.assertEqual(len(closed), 1)
|
||||
f = closed[0]
|
||||
self.assertIsInstance(f, failure.Failure)
|
||||
self.assertIsInstance(f.value, error.ConnectionClosed)
|
||||
|
||||
|
||||
DIRECT_HINT = u"tcp:direct:1234"
|
||||
RELAY_HINT = u"tcp:relay:1234"
|
||||
UNUSABLE_HINT = u"unusable:foo:bar"
|
||||
|
@ -1098,8 +1140,7 @@ class Full(unittest.TestCase):
|
|||
self.assertIsInstance(x, transit.Connection)
|
||||
self.assertIsInstance(y, transit.Connection)
|
||||
|
||||
d = defer.Deferred()
|
||||
y.recordReceived = d.callback
|
||||
d = y.receive_record()
|
||||
|
||||
x.send_record(b"record1")
|
||||
r = yield d
|
||||
|
|
|
@ -1,10 +1,10 @@
|
|||
from __future__ import print_function
|
||||
import sys, time, socket
|
||||
import sys, time, socket, collections
|
||||
from binascii import hexlify, unhexlify
|
||||
from zope.interface import implementer
|
||||
from twisted.python.runtime import platformType
|
||||
from twisted.internet import (reactor, interfaces, defer, protocol,
|
||||
endpoints, task, address)
|
||||
endpoints, task, address, error)
|
||||
from twisted.protocols import policies
|
||||
from nacl.secret import SecretBox
|
||||
from ..util import ipaddrs
|
||||
|
@ -34,6 +34,8 @@ class Connection(protocol.Protocol, policies.TimeoutMixin):
|
|||
self.start = start
|
||||
self._negotiation_d = defer.Deferred(self._cancel)
|
||||
self._error = None
|
||||
self._inbound_records = collections.deque()
|
||||
self._waiting_reads = collections.deque()
|
||||
|
||||
def connectionMade(self):
|
||||
debug("handle %r" % (self.transport,))
|
||||
|
@ -171,10 +173,31 @@ class Connection(protocol.Protocol, policies.TimeoutMixin):
|
|||
self.transport.write(encrypted)
|
||||
|
||||
def recordReceived(self, record):
|
||||
pass
|
||||
self._inbound_records.append(record)
|
||||
self._deliverRecords()
|
||||
self._checkPause()
|
||||
|
||||
def _checkPause(self):
|
||||
return # TODO
|
||||
|
||||
def receive_record(self):
|
||||
d = defer.Deferred()
|
||||
self._waiting_reads.append(d)
|
||||
self._deliverRecords()
|
||||
return d
|
||||
|
||||
def _deliverRecords(self):
|
||||
while self._inbound_records and self._waiting_reads:
|
||||
r = self._inbound_records.popleft()
|
||||
d = self._waiting_reads.popleft()
|
||||
d.callback(r)
|
||||
self._checkPause()
|
||||
|
||||
def close(self):
|
||||
self.transport.loseConnection()
|
||||
while self._waiting_reads:
|
||||
d = self._waiting_reads.popleft()
|
||||
d.errback(error.ConnectionClosed())
|
||||
|
||||
def timeoutConnection(self):
|
||||
self._error = BadHandshake("timeout")
|
||||
|
|
Loading…
Reference in New Issue
Block a user