twisted.transit: implement Deferred-based receive_record()
This commit is contained in:
parent
fb1461fa8c
commit
763d72f582
|
@ -1,7 +1,7 @@
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
from binascii import hexlify, unhexlify
|
from binascii import hexlify, unhexlify
|
||||||
from twisted.trial import unittest
|
from twisted.trial import unittest
|
||||||
from twisted.internet import defer, task, endpoints, protocol, address
|
from twisted.internet import defer, task, endpoints, protocol, address, error
|
||||||
from twisted.internet.defer import gatherResults, inlineCallbacks
|
from twisted.internet.defer import gatherResults, inlineCallbacks
|
||||||
from twisted.python import log, failure
|
from twisted.python import log, failure
|
||||||
from ..twisted import transit
|
from ..twisted import transit
|
||||||
|
@ -280,6 +280,7 @@ class DummyProtocol(protocol.Protocol):
|
||||||
self._d2.callback(None)
|
self._d2.callback(None)
|
||||||
|
|
||||||
class FakeTransport:
|
class FakeTransport:
|
||||||
|
signalConnectionLost = True
|
||||||
def __init__(self, p, peeraddr):
|
def __init__(self, p, peeraddr):
|
||||||
self.protocol = p
|
self.protocol = p
|
||||||
self._peeraddr = peeraddr
|
self._peeraddr = peeraddr
|
||||||
|
@ -289,7 +290,8 @@ class FakeTransport:
|
||||||
self._buf += data
|
self._buf += data
|
||||||
def loseConnection(self):
|
def loseConnection(self):
|
||||||
self._connected = False
|
self._connected = False
|
||||||
self.protocol.connectionLost()
|
if self.signalConnectionLost:
|
||||||
|
self.protocol.connectionLost()
|
||||||
def getPeer(self):
|
def getPeer(self):
|
||||||
return self._peeraddr
|
return self._peeraddr
|
||||||
|
|
||||||
|
@ -950,6 +952,46 @@ class Connection(unittest.TestCase):
|
||||||
# happens? We currently get a type-check assertion from HKDF because
|
# happens? We currently get a type-check assertion from HKDF because
|
||||||
# the key is None.
|
# the key is None.
|
||||||
|
|
||||||
|
def test_receive_queue(self):
|
||||||
|
c = transit.Connection(None, None, None)
|
||||||
|
c.transport = FakeTransport(c, None)
|
||||||
|
c.transport.signalConnectionLost = False
|
||||||
|
results = [[] for i in range(5)]
|
||||||
|
c.recordReceived(b"0")
|
||||||
|
c.recordReceived(b"1")
|
||||||
|
c.recordReceived(b"2")
|
||||||
|
c.receive_record().addBoth(results[0].append)
|
||||||
|
self.assertEqual(results[0], [b"0"])
|
||||||
|
d1 = c.receive_record()
|
||||||
|
d2 = c.receive_record()
|
||||||
|
# they must fire in order of receipt, not order of addCallback
|
||||||
|
d2.addBoth(results[2].append)
|
||||||
|
self.assertEqual(results[2], [b"2"])
|
||||||
|
d1.addBoth(results[1].append)
|
||||||
|
self.assertEqual(results[1], [b"1"])
|
||||||
|
|
||||||
|
c.receive_record().addBoth(results[3].append)
|
||||||
|
c.receive_record().addBoth(results[4].append)
|
||||||
|
self.assertEqual(results[3], [])
|
||||||
|
self.assertEqual(results[4], [])
|
||||||
|
|
||||||
|
c.recordReceived(b"3")
|
||||||
|
self.assertEqual(results[3], [b"3"])
|
||||||
|
self.assertEqual(results[4], [])
|
||||||
|
|
||||||
|
c.recordReceived(b"4")
|
||||||
|
self.assertEqual(results[3], [b"3"])
|
||||||
|
self.assertEqual(results[4], [b"4"])
|
||||||
|
|
||||||
|
closed = []
|
||||||
|
c.receive_record().addBoth(closed.append)
|
||||||
|
c.close()
|
||||||
|
self.assertEqual(len(closed), 1)
|
||||||
|
f = closed[0]
|
||||||
|
self.assertIsInstance(f, failure.Failure)
|
||||||
|
self.assertIsInstance(f.value, error.ConnectionClosed)
|
||||||
|
|
||||||
|
|
||||||
DIRECT_HINT = u"tcp:direct:1234"
|
DIRECT_HINT = u"tcp:direct:1234"
|
||||||
RELAY_HINT = u"tcp:relay:1234"
|
RELAY_HINT = u"tcp:relay:1234"
|
||||||
UNUSABLE_HINT = u"unusable:foo:bar"
|
UNUSABLE_HINT = u"unusable:foo:bar"
|
||||||
|
@ -1098,8 +1140,7 @@ class Full(unittest.TestCase):
|
||||||
self.assertIsInstance(x, transit.Connection)
|
self.assertIsInstance(x, transit.Connection)
|
||||||
self.assertIsInstance(y, transit.Connection)
|
self.assertIsInstance(y, transit.Connection)
|
||||||
|
|
||||||
d = defer.Deferred()
|
d = y.receive_record()
|
||||||
y.recordReceived = d.callback
|
|
||||||
|
|
||||||
x.send_record(b"record1")
|
x.send_record(b"record1")
|
||||||
r = yield d
|
r = yield d
|
||||||
|
|
|
@ -1,10 +1,10 @@
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
import sys, time, socket
|
import sys, time, socket, collections
|
||||||
from binascii import hexlify, unhexlify
|
from binascii import hexlify, unhexlify
|
||||||
from zope.interface import implementer
|
from zope.interface import implementer
|
||||||
from twisted.python.runtime import platformType
|
from twisted.python.runtime import platformType
|
||||||
from twisted.internet import (reactor, interfaces, defer, protocol,
|
from twisted.internet import (reactor, interfaces, defer, protocol,
|
||||||
endpoints, task, address)
|
endpoints, task, address, error)
|
||||||
from twisted.protocols import policies
|
from twisted.protocols import policies
|
||||||
from nacl.secret import SecretBox
|
from nacl.secret import SecretBox
|
||||||
from ..util import ipaddrs
|
from ..util import ipaddrs
|
||||||
|
@ -34,6 +34,8 @@ class Connection(protocol.Protocol, policies.TimeoutMixin):
|
||||||
self.start = start
|
self.start = start
|
||||||
self._negotiation_d = defer.Deferred(self._cancel)
|
self._negotiation_d = defer.Deferred(self._cancel)
|
||||||
self._error = None
|
self._error = None
|
||||||
|
self._inbound_records = collections.deque()
|
||||||
|
self._waiting_reads = collections.deque()
|
||||||
|
|
||||||
def connectionMade(self):
|
def connectionMade(self):
|
||||||
debug("handle %r" % (self.transport,))
|
debug("handle %r" % (self.transport,))
|
||||||
|
@ -171,10 +173,31 @@ class Connection(protocol.Protocol, policies.TimeoutMixin):
|
||||||
self.transport.write(encrypted)
|
self.transport.write(encrypted)
|
||||||
|
|
||||||
def recordReceived(self, record):
|
def recordReceived(self, record):
|
||||||
pass
|
self._inbound_records.append(record)
|
||||||
|
self._deliverRecords()
|
||||||
|
self._checkPause()
|
||||||
|
|
||||||
|
def _checkPause(self):
|
||||||
|
return # TODO
|
||||||
|
|
||||||
|
def receive_record(self):
|
||||||
|
d = defer.Deferred()
|
||||||
|
self._waiting_reads.append(d)
|
||||||
|
self._deliverRecords()
|
||||||
|
return d
|
||||||
|
|
||||||
|
def _deliverRecords(self):
|
||||||
|
while self._inbound_records and self._waiting_reads:
|
||||||
|
r = self._inbound_records.popleft()
|
||||||
|
d = self._waiting_reads.popleft()
|
||||||
|
d.callback(r)
|
||||||
|
self._checkPause()
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
self.transport.loseConnection()
|
self.transport.loseConnection()
|
||||||
|
while self._waiting_reads:
|
||||||
|
d = self._waiting_reads.popleft()
|
||||||
|
d.errback(error.ConnectionClosed())
|
||||||
|
|
||||||
def timeoutConnection(self):
|
def timeoutConnection(self):
|
||||||
self._error = BadHandshake("timeout")
|
self._error = BadHandshake("timeout")
|
||||||
|
|
Loading…
Reference in New Issue
Block a user