From 763d72f582e6ac748406d2be78e9d75b2928cece Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Mon, 15 Feb 2016 11:40:21 -0800 Subject: [PATCH] twisted.transit: implement Deferred-based receive_record() --- src/wormhole/test/test_transit_twisted.py | 49 +++++++++++++++++++++-- src/wormhole/twisted/transit.py | 29 ++++++++++++-- 2 files changed, 71 insertions(+), 7 deletions(-) diff --git a/src/wormhole/test/test_transit_twisted.py b/src/wormhole/test/test_transit_twisted.py index 3f1da6b..310f55f 100644 --- a/src/wormhole/test/test_transit_twisted.py +++ b/src/wormhole/test/test_transit_twisted.py @@ -1,7 +1,7 @@ from __future__ import print_function from binascii import hexlify, unhexlify from twisted.trial import unittest -from twisted.internet import defer, task, endpoints, protocol, address +from twisted.internet import defer, task, endpoints, protocol, address, error from twisted.internet.defer import gatherResults, inlineCallbacks from twisted.python import log, failure from ..twisted import transit @@ -280,6 +280,7 @@ class DummyProtocol(protocol.Protocol): self._d2.callback(None) class FakeTransport: + signalConnectionLost = True def __init__(self, p, peeraddr): self.protocol = p self._peeraddr = peeraddr @@ -289,7 +290,8 @@ class FakeTransport: self._buf += data def loseConnection(self): self._connected = False - self.protocol.connectionLost() + if self.signalConnectionLost: + self.protocol.connectionLost() def getPeer(self): return self._peeraddr @@ -950,6 +952,46 @@ class Connection(unittest.TestCase): # happens? We currently get a type-check assertion from HKDF because # the key is None. + def test_receive_queue(self): + c = transit.Connection(None, None, None) + c.transport = FakeTransport(c, None) + c.transport.signalConnectionLost = False + results = [[] for i in range(5)] + c.recordReceived(b"0") + c.recordReceived(b"1") + c.recordReceived(b"2") + c.receive_record().addBoth(results[0].append) + self.assertEqual(results[0], [b"0"]) + d1 = c.receive_record() + d2 = c.receive_record() + # they must fire in order of receipt, not order of addCallback + d2.addBoth(results[2].append) + self.assertEqual(results[2], [b"2"]) + d1.addBoth(results[1].append) + self.assertEqual(results[1], [b"1"]) + + c.receive_record().addBoth(results[3].append) + c.receive_record().addBoth(results[4].append) + self.assertEqual(results[3], []) + self.assertEqual(results[4], []) + + c.recordReceived(b"3") + self.assertEqual(results[3], [b"3"]) + self.assertEqual(results[4], []) + + c.recordReceived(b"4") + self.assertEqual(results[3], [b"3"]) + self.assertEqual(results[4], [b"4"]) + + closed = [] + c.receive_record().addBoth(closed.append) + c.close() + self.assertEqual(len(closed), 1) + f = closed[0] + self.assertIsInstance(f, failure.Failure) + self.assertIsInstance(f.value, error.ConnectionClosed) + + DIRECT_HINT = u"tcp:direct:1234" RELAY_HINT = u"tcp:relay:1234" UNUSABLE_HINT = u"unusable:foo:bar" @@ -1098,8 +1140,7 @@ class Full(unittest.TestCase): self.assertIsInstance(x, transit.Connection) self.assertIsInstance(y, transit.Connection) - d = defer.Deferred() - y.recordReceived = d.callback + d = y.receive_record() x.send_record(b"record1") r = yield d diff --git a/src/wormhole/twisted/transit.py b/src/wormhole/twisted/transit.py index fbb5d88..6544a71 100644 --- a/src/wormhole/twisted/transit.py +++ b/src/wormhole/twisted/transit.py @@ -1,10 +1,10 @@ from __future__ import print_function -import sys, time, socket +import sys, time, socket, collections from binascii import hexlify, unhexlify from zope.interface import implementer from twisted.python.runtime import platformType from twisted.internet import (reactor, interfaces, defer, protocol, - endpoints, task, address) + endpoints, task, address, error) from twisted.protocols import policies from nacl.secret import SecretBox from ..util import ipaddrs @@ -34,6 +34,8 @@ class Connection(protocol.Protocol, policies.TimeoutMixin): self.start = start self._negotiation_d = defer.Deferred(self._cancel) self._error = None + self._inbound_records = collections.deque() + self._waiting_reads = collections.deque() def connectionMade(self): debug("handle %r" % (self.transport,)) @@ -171,10 +173,31 @@ class Connection(protocol.Protocol, policies.TimeoutMixin): self.transport.write(encrypted) def recordReceived(self, record): - pass + self._inbound_records.append(record) + self._deliverRecords() + self._checkPause() + + def _checkPause(self): + return # TODO + + def receive_record(self): + d = defer.Deferred() + self._waiting_reads.append(d) + self._deliverRecords() + return d + + def _deliverRecords(self): + while self._inbound_records and self._waiting_reads: + r = self._inbound_records.popleft() + d = self._waiting_reads.popleft() + d.callback(r) + self._checkPause() def close(self): self.transport.loseConnection() + while self._waiting_reads: + d = self._waiting_reads.popleft() + d.errback(error.ConnectionClosed()) def timeoutConnection(self): self._error = BadHandshake("timeout")