Merge pull request #20 from meejah/iosim-based-tests

iosim.IOPump based tests
This commit is contained in:
meejah 2021-04-12 21:00:50 -06:00 committed by GitHub
commit de8e0f0399
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 275 additions and 121 deletions

View File

@ -1,28 +1,129 @@
from twisted.test import proto_helpers from twisted.internet.protocol import (
from ..transit_server import Transit ClientFactory,
Protocol,
)
from twisted.test import iosim
from zope.interface import (
Interface,
Attribute,
implementer,
)
from ..transit_server import (
Transit,
)
class IRelayTestClient(Interface):
"""
The client interface used by tests.
"""
connected = Attribute("True if we are currently connected else False")
def send(data):
"""
Send some bytes.
:param bytes data: the data to send
"""
def disconnect():
"""
Terminate the connection.
"""
def get_received_data():
"""
:returns: all the bytes received from the server on this
connection.
"""
def reset_data():
"""
Erase any received data to this point.
"""
class ServerBase: class ServerBase:
log_requests = False log_requests = False
def setUp(self): def setUp(self):
self._pumps = []
self._lp = None self._lp = None
if self.log_requests: if self.log_requests:
blur_usage = None blur_usage = None
else: else:
blur_usage = 60.0 blur_usage = 60.0
self._setup_relay(blur_usage=blur_usage) self._setup_relay(blur_usage=blur_usage)
self._transit_server._debug_log = self.log_requests
def flush(self):
did_work = False
for pump in self._pumps:
did_work = pump.flush() or did_work
if did_work:
self.flush()
def _setup_relay(self, blur_usage=None, log_file=None, usage_db=None): def _setup_relay(self, blur_usage=None, log_file=None, usage_db=None):
self._transit_server = Transit(blur_usage=blur_usage, self._transit_server = Transit(
log_file=log_file, usage_db=usage_db) blur_usage=blur_usage,
log_file=log_file,
usage_db=usage_db,
)
self._transit_server._debug_log = self.log_requests
def new_protocol(self): def new_protocol(self):
protocol = self._transit_server.buildProtocol(('127.0.0.1', 0)) """
transport = proto_helpers.StringTransportWithDisconnection() Create a new client protocol connected to the server.
protocol.makeConnection(transport) :returns: a IRelayTestClient implementation
transport.protocol = protocol """
return protocol server_protocol = self._transit_server.buildProtocol(('127.0.0.1', 0))
@implementer(IRelayTestClient)
class TransitClientProtocolTcp(Protocol):
"""
Speak the transit client protocol used by the tests over TCP
"""
_received = b""
connected = False
# override Protocol callbacks
def connectionMade(self):
self.connected = True
return Protocol.connectionMade(self)
def connectionLost(self, reason):
self.connected = False
return Protocol.connectionLost(self, reason)
def dataReceived(self, data):
self._received = self._received + data
# IRelayTestClient
def send(self, data):
self.transport.write(data)
def disconnect(self):
self.transport.loseConnection()
def reset_received_data(self):
self._received = b""
def get_received_data(self):
return self._received
client_factory = ClientFactory()
client_factory.protocol = TransitClientProtocolTcp
client_protocol = client_factory.buildProtocol(('127.0.0.1', 31337))
pump = iosim.connect(
server_protocol,
iosim.makeFakeServer(server_protocol),
client_protocol,
iosim.makeFakeClient(client_protocol),
)
pump.flush()
self._pumps.append(pump)
return client_protocol
def tearDown(self): def tearDown(self):
if self._lp: if self._lp:

View File

@ -1,4 +1,7 @@
from __future__ import print_function, unicode_literals from __future__ import print_function, unicode_literals
try:
from unittest import mock
except ImportError:
import mock import mock
from twisted.trial import unittest from twisted.trial import unittest
from ..increase_rlimits import increase_rlimits from ..increase_rlimits import increase_rlimits

View File

@ -1,5 +1,8 @@
from __future__ import unicode_literals, print_function from __future__ import unicode_literals, print_function
from twisted.trial import unittest from twisted.trial import unittest
try:
from unittest import mock
except ImportError:
import mock import mock
from twisted.application.service import MultiService from twisted.application.service import MultiService
from .. import server_tap from .. import server_tap

View File

@ -1,5 +1,8 @@
from __future__ import print_function, unicode_literals from __future__ import print_function, unicode_literals
import os, io, json, sqlite3 import os, io, json, sqlite3
try:
from unittest import mock
except ImportError:
import mock import mock
from twisted.trial import unittest from twisted.trial import unittest
from ..transit_server import Transit from ..transit_server import Transit

View File

@ -41,10 +41,12 @@ class _Transit:
token1 = b"\x00"*32 token1 = b"\x00"*32
side1 = b"\x01"*8 side1 = b"\x01"*8
p1.dataReceived(handshake(token1, side1)) p1.send(handshake(token1, side1))
self.flush()
self.assertEqual(self.count(), 1) self.assertEqual(self.count(), 1)
p1.transport.loseConnection() p1.disconnect()
self.flush()
self.assertEqual(self.count(), 0) self.assertEqual(self.count(), 0)
# the token should be removed too # the token should be removed too
@ -55,23 +57,27 @@ class _Transit:
p2 = self.new_protocol() p2 = self.new_protocol()
token1 = b"\x00"*32 token1 = b"\x00"*32
p1.dataReceived(handshake(token1, side=None)) p1.send(handshake(token1, side=None))
p2.dataReceived(handshake(token1, side=None)) self.flush()
p2.send(handshake(token1, side=None))
self.flush()
# a correct handshake yields an ack, after which we can send # a correct handshake yields an ack, after which we can send
exp = b"ok\n" exp = b"ok\n"
self.assertEqual(p1.transport.value(), exp) self.assertEqual(p1.get_received_data(), exp)
self.assertEqual(p2.transport.value(), exp) self.assertEqual(p2.get_received_data(), exp)
p1.transport.clear() p1.reset_received_data()
p2.transport.clear() p2.reset_received_data()
s1 = b"data1" s1 = b"data1"
p1.dataReceived(s1) p1.send(s1)
self.assertEqual(p2.transport.value(), s1) self.flush()
self.assertEqual(p2.get_received_data(), s1)
p1.transport.loseConnection() p1.disconnect()
p2.transport.loseConnection() p2.disconnect()
self.flush()
def test_sided_unsided(self): def test_sided_unsided(self):
p1 = self.new_protocol() p1 = self.new_protocol()
@ -79,24 +85,28 @@ class _Transit:
token1 = b"\x00"*32 token1 = b"\x00"*32
side1 = b"\x01"*8 side1 = b"\x01"*8
p1.dataReceived(handshake(token1, side=side1)) p1.send(handshake(token1, side=side1))
p2.dataReceived(handshake(token1, side=None)) self.flush()
p2.send(handshake(token1, side=None))
self.flush()
# a correct handshake yields an ack, after which we can send # a correct handshake yields an ack, after which we can send
exp = b"ok\n" exp = b"ok\n"
self.assertEqual(p1.transport.value(), exp) self.assertEqual(p1.get_received_data(), exp)
self.assertEqual(p2.transport.value(), exp) self.assertEqual(p2.get_received_data(), exp)
p1.transport.clear() p1.reset_received_data()
p2.transport.clear() p2.reset_received_data()
# all data they sent after the handshake should be given to us # all data they sent after the handshake should be given to us
s1 = b"data1" s1 = b"data1"
p1.dataReceived(s1) p1.send(s1)
self.assertEqual(p2.transport.value(), s1) self.flush()
self.assertEqual(p2.get_received_data(), s1)
p1.transport.loseConnection() p1.disconnect()
p2.transport.loseConnection() p2.disconnect()
self.flush()
def test_unsided_sided(self): def test_unsided_sided(self):
p1 = self.new_protocol() p1 = self.new_protocol()
@ -104,24 +114,26 @@ class _Transit:
token1 = b"\x00"*32 token1 = b"\x00"*32
side1 = b"\x01"*8 side1 = b"\x01"*8
p1.dataReceived(handshake(token1, side=None)) p1.send(handshake(token1, side=None))
p2.dataReceived(handshake(token1, side=side1)) p2.send(handshake(token1, side=side1))
self.flush()
# a correct handshake yields an ack, after which we can send # a correct handshake yields an ack, after which we can send
exp = b"ok\n" exp = b"ok\n"
self.assertEqual(p1.transport.value(), exp) self.assertEqual(p1.get_received_data(), exp)
self.assertEqual(p2.transport.value(), exp) self.assertEqual(p2.get_received_data(), exp)
p1.transport.clear() p1.reset_received_data()
p2.transport.clear() p2.reset_received_data()
# all data they sent after the handshake should be given to us # all data they sent after the handshake should be given to us
s1 = b"data1" s1 = b"data1"
p1.dataReceived(s1) p1.send(s1)
self.assertEqual(p2.transport.value(), s1) self.flush()
self.assertEqual(p2.get_received_data(), s1)
p1.transport.loseConnection() p1.disconnect()
p2.transport.loseConnection() p2.disconnect()
def test_both_sided(self): def test_both_sided(self):
p1 = self.new_protocol() p1 = self.new_protocol()
@ -130,24 +142,27 @@ class _Transit:
token1 = b"\x00"*32 token1 = b"\x00"*32
side1 = b"\x01"*8 side1 = b"\x01"*8
side2 = b"\x02"*8 side2 = b"\x02"*8
p1.dataReceived(handshake(token1, side=side1)) p1.send(handshake(token1, side=side1))
p2.dataReceived(handshake(token1, side=side2)) self.flush()
p2.send(handshake(token1, side=side2))
self.flush()
# a correct handshake yields an ack, after which we can send # a correct handshake yields an ack, after which we can send
exp = b"ok\n" exp = b"ok\n"
self.assertEqual(p1.transport.value(), exp) self.assertEqual(p1.get_received_data(), exp)
self.assertEqual(p2.transport.value(), exp) self.assertEqual(p2.get_received_data(), exp)
p1.transport.clear() p1.reset_received_data()
p2.transport.clear() p2.reset_received_data()
# all data they sent after the handshake should be given to us # all data they sent after the handshake should be given to us
s1 = b"data1" s1 = b"data1"
p1.dataReceived(s1) p1.send(s1)
self.assertEqual(p2.transport.value(), s1) self.flush()
self.assertEqual(p2.get_received_data(), s1)
p1.transport.loseConnection() p1.disconnect()
p2.transport.loseConnection() p2.disconnect()
def test_ignore_same_side(self): def test_ignore_same_side(self):
p1 = self.new_protocol() p1 = self.new_protocol()
@ -157,41 +172,46 @@ class _Transit:
token1 = b"\x00"*32 token1 = b"\x00"*32
side1 = b"\x01"*8 side1 = b"\x01"*8
p1.dataReceived(handshake(token1, side=side1)) p1.send(handshake(token1, side=side1))
self.flush()
self.assertEqual(self.count(), 1) self.assertEqual(self.count(), 1)
p2.dataReceived(handshake(token1, side=side1)) p2.send(handshake(token1, side=side1))
self.flush()
self.assertEqual(self.count(), 2) # same-side connections don't match self.assertEqual(self.count(), 2) # same-side connections don't match
# when the second side arrives, the spare first connection should be # when the second side arrives, the spare first connection should be
# closed # closed
side2 = b"\x02"*8 side2 = b"\x02"*8
p3.dataReceived(handshake(token1, side=side2)) p3.send(handshake(token1, side=side2))
self.flush()
self.assertEqual(self.count(), 0) self.assertEqual(self.count(), 0)
self.assertEqual(len(self._transit_server._pending_requests), 0) self.assertEqual(len(self._transit_server._pending_requests), 0)
self.assertEqual(len(self._transit_server._active_connections), 2) self.assertEqual(len(self._transit_server._active_connections), 2)
# That will trigger a disconnect on exactly one of (p1 or p2). # That will trigger a disconnect on exactly one of (p1 or p2).
# The other connection should still be connected # The other connection should still be connected
self.assertEqual(sum([int(t.transport.connected) for t in [p1, p2]]), 1) self.assertEqual(sum([int(t.connected) for t in [p1, p2]]), 1)
p1.transport.loseConnection() p1.disconnect()
p2.transport.loseConnection() p2.disconnect()
p3.transport.loseConnection() p3.disconnect()
def test_bad_handshake_old(self): def test_bad_handshake_old(self):
p1 = self.new_protocol() p1 = self.new_protocol()
token1 = b"\x00"*32 token1 = b"\x00"*32
p1.dataReceived(b"please DELAY " + hexlify(token1) + b"\n") p1.send(b"please DELAY " + hexlify(token1) + b"\n")
self.flush()
exp = b"bad handshake\n" exp = b"bad handshake\n"
self.assertEqual(p1.transport.value(), exp) self.assertEqual(p1.get_received_data(), exp)
p1.transport.loseConnection() p1.disconnect()
def test_bad_handshake_old_slow(self): def test_bad_handshake_old_slow(self):
p1 = self.new_protocol() p1 = self.new_protocol()
p1.dataReceived(b"please DELAY ") p1.send(b"please DELAY ")
self.flush()
# As in test_impatience_new_slow, the current state machine has code # As in test_impatience_new_slow, the current state machine has code
# that can only be reached if we insert a stall here, so dataReceived # that can only be reached if we insert a stall here, so dataReceived
# gets called twice. Hopefully we can delete this test once # gets called twice. Hopefully we can delete this test once
@ -200,12 +220,13 @@ class _Transit:
token1 = b"\x00"*32 token1 = b"\x00"*32
# the server waits for the exact number of bytes in the expected # the server waits for the exact number of bytes in the expected
# handshake message. to trigger "bad handshake", we must match. # handshake message. to trigger "bad handshake", we must match.
p1.dataReceived(hexlify(token1) + b"\n") p1.send(hexlify(token1) + b"\n")
self.flush()
exp = b"bad handshake\n" exp = b"bad handshake\n"
self.assertEqual(p1.transport.value(), exp) self.assertEqual(p1.get_received_data(), exp)
p1.transport.loseConnection() p1.disconnect()
def test_bad_handshake_new(self): def test_bad_handshake_new(self):
p1 = self.new_protocol() p1 = self.new_protocol()
@ -214,13 +235,14 @@ class _Transit:
side1 = b"\x01"*8 side1 = b"\x01"*8
# the server waits for the exact number of bytes in the expected # the server waits for the exact number of bytes in the expected
# handshake message. to trigger "bad handshake", we must match. # handshake message. to trigger "bad handshake", we must match.
p1.dataReceived(b"please DELAY " + hexlify(token1) + p1.send(b"please DELAY " + hexlify(token1) +
b" for side " + hexlify(side1) + b"\n") b" for side " + hexlify(side1) + b"\n")
self.flush()
exp = b"bad handshake\n" exp = b"bad handshake\n"
self.assertEqual(p1.transport.value(), exp) self.assertEqual(p1.get_received_data(), exp)
p1.transport.loseConnection() p1.disconnect()
def test_binary_handshake(self): def test_binary_handshake(self):
p1 = self.new_protocol() p1 = self.new_protocol()
@ -232,24 +254,26 @@ class _Transit:
# UnicodeDecodeError when it tried to coerce the incoming handshake # UnicodeDecodeError when it tried to coerce the incoming handshake
# to unicode, due to the ("\n" in buf) check. This was fixed to use # to unicode, due to the ("\n" in buf) check. This was fixed to use
# (b"\n" in buf). This exercises the old failure. # (b"\n" in buf). This exercises the old failure.
p1.dataReceived(binary_bad_handshake) p1.send(binary_bad_handshake)
self.flush()
exp = b"bad handshake\n" exp = b"bad handshake\n"
self.assertEqual(p1.transport.value(), exp) self.assertEqual(p1.get_received_data(), exp)
p1.transport.loseConnection() p1.disconnect()
def test_impatience_old(self): def test_impatience_old(self):
p1 = self.new_protocol() p1 = self.new_protocol()
token1 = b"\x00"*32 token1 = b"\x00"*32
# sending too many bytes is impatience. # sending too many bytes is impatience.
p1.dataReceived(b"please relay " + hexlify(token1) + b"\nNOWNOWNOW") p1.send(b"please relay " + hexlify(token1) + b"\nNOWNOWNOW")
self.flush()
exp = b"impatient\n" exp = b"impatient\n"
self.assertEqual(p1.transport.value(), exp) self.assertEqual(p1.get_received_data(), exp)
p1.transport.loseConnection() p1.disconnect()
def test_impatience_new(self): def test_impatience_new(self):
p1 = self.new_protocol() p1 = self.new_protocol()
@ -257,13 +281,14 @@ class _Transit:
token1 = b"\x00"*32 token1 = b"\x00"*32
side1 = b"\x01"*8 side1 = b"\x01"*8
# sending too many bytes is impatience. # sending too many bytes is impatience.
p1.dataReceived(b"please relay " + hexlify(token1) + p1.send(b"please relay " + hexlify(token1) +
b" for side " + hexlify(side1) + b"\nNOWNOWNOW") b" for side " + hexlify(side1) + b"\nNOWNOWNOW")
self.flush()
exp = b"impatient\n" exp = b"impatient\n"
self.assertEqual(p1.transport.value(), exp) self.assertEqual(p1.get_received_data(), exp)
p1.transport.loseConnection() p1.disconnect()
def test_impatience_new_slow(self): def test_impatience_new_slow(self):
p1 = self.new_protocol() p1 = self.new_protocol()
@ -279,27 +304,29 @@ class _Transit:
token1 = b"\x00"*32 token1 = b"\x00"*32
side1 = b"\x01"*8 side1 = b"\x01"*8
# sending too many bytes is impatience. # sending too many bytes is impatience.
p1.dataReceived(b"please relay " + hexlify(token1) + p1.send(b"please relay " + hexlify(token1) +
b" for side " + hexlify(side1) + b"\n") b" for side " + hexlify(side1) + b"\n")
self.flush()
p1.send(b"NOWNOWNOW")
p1.dataReceived(b"NOWNOWNOW") self.flush()
exp = b"impatient\n" exp = b"impatient\n"
self.assertEqual(p1.transport.value(), exp) self.assertEqual(p1.get_received_data(), exp)
p1.transport.loseConnection() p1.disconnect()
def test_short_handshake(self): def test_short_handshake(self):
p1 = self.new_protocol() p1 = self.new_protocol()
# hang up before sending a complete handshake # hang up before sending a complete handshake
p1.dataReceived(b"short") p1.send(b"short")
p1.transport.loseConnection() self.flush()
p1.disconnect()
def test_empty_handshake(self): def test_empty_handshake(self):
p1 = self.new_protocol() p1 = self.new_protocol()
# hang up before sending anything # hang up before sending anything
p1.transport.loseConnection() p1.disconnect()
class TransitWithLogs(_Transit, ServerBase, unittest.TestCase): class TransitWithLogs(_Transit, ServerBase, unittest.TestCase):
log_requests = True log_requests = True
@ -308,6 +335,8 @@ class TransitWithoutLogs(_Transit, ServerBase, unittest.TestCase):
log_requests = False log_requests = False
class Usage(ServerBase, unittest.TestCase): class Usage(ServerBase, unittest.TestCase):
log_requests = True
def setUp(self): def setUp(self):
super(Usage, self).setUp() super(Usage, self).setUp()
self._usage = [] self._usage = []
@ -319,7 +348,8 @@ class Usage(ServerBase, unittest.TestCase):
def test_empty(self): def test_empty(self):
p1 = self.new_protocol() p1 = self.new_protocol()
# hang up before sending anything # hang up before sending anything
p1.transport.loseConnection() p1.disconnect()
self.flush()
# that will log the "empty" usage event # that will log the "empty" usage event
self.assertEqual(len(self._usage), 1, self._usage) self.assertEqual(len(self._usage), 1, self._usage)
@ -329,8 +359,9 @@ class Usage(ServerBase, unittest.TestCase):
def test_short(self): def test_short(self):
p1 = self.new_protocol() p1 = self.new_protocol()
# hang up before sending a complete handshake # hang up before sending a complete handshake
p1.transport.write(b"short") p1.send(b"short")
p1.transport.loseConnection() p1.disconnect()
self.flush()
# that will log the "empty" usage event # that will log the "empty" usage event
self.assertEqual(len(self._usage), 1, self._usage) self.assertEqual(len(self._usage), 1, self._usage)
@ -340,9 +371,10 @@ class Usage(ServerBase, unittest.TestCase):
def test_errory(self): def test_errory(self):
p1 = self.new_protocol() p1 = self.new_protocol()
p1.dataReceived(b"this is a very bad handshake\n") p1.send(b"this is a very bad handshake\n")
self.flush()
# that will log the "errory" usage event, then drop the connection # that will log the "errory" usage event, then drop the connection
p1.transport.loseConnection() p1.disconnect()
self.assertEqual(len(self._usage), 1, self._usage) self.assertEqual(len(self._usage), 1, self._usage)
(started, result, total_bytes, total_time, waiting_time) = self._usage[0] (started, result, total_bytes, total_time, waiting_time) = self._usage[0]
self.assertEqual(result, "errory", self._usage) self.assertEqual(result, "errory", self._usage)
@ -352,9 +384,11 @@ class Usage(ServerBase, unittest.TestCase):
token1 = b"\x00"*32 token1 = b"\x00"*32
side1 = b"\x01"*8 side1 = b"\x01"*8
p1.dataReceived(handshake(token1, side=side1)) p1.send(handshake(token1, side=side1))
self.flush()
# now we disconnect before the peer connects # now we disconnect before the peer connects
p1.transport.loseConnection() p1.disconnect()
self.flush()
self.assertEqual(len(self._usage), 1, self._usage) self.assertEqual(len(self._usage), 1, self._usage)
(started, result, total_bytes, total_time, waiting_time) = self._usage[0] (started, result, total_bytes, total_time, waiting_time) = self._usage[0]
@ -368,15 +402,20 @@ class Usage(ServerBase, unittest.TestCase):
token1 = b"\x00"*32 token1 = b"\x00"*32
side1 = b"\x01"*8 side1 = b"\x01"*8
side2 = b"\x02"*8 side2 = b"\x02"*8
p1.dataReceived(handshake(token1, side=side1)) p1.send(handshake(token1, side=side1))
p2.dataReceived(handshake(token1, side=side2)) self.flush()
p2.send(handshake(token1, side=side2))
self.flush()
self.assertEqual(self._usage, []) # no events yet self.assertEqual(self._usage, []) # no events yet
p1.dataReceived(b"\x00" * 13) p1.send(b"\x00" * 13)
p2.dataReceived(b"\xff" * 7) self.flush()
p2.send(b"\xff" * 7)
self.flush()
p1.transport.loseConnection() p1.disconnect()
self.flush()
self.assertEqual(len(self._usage), 1, self._usage) self.assertEqual(len(self._usage), 1, self._usage)
(started, result, total_bytes, total_time, waiting_time) = self._usage[0] (started, result, total_bytes, total_time, waiting_time) = self._usage[0]
@ -393,28 +432,33 @@ class Usage(ServerBase, unittest.TestCase):
token1 = b"\x00"*32 token1 = b"\x00"*32
side1 = b"\x01"*8 side1 = b"\x01"*8
side2 = b"\x02"*8 side2 = b"\x02"*8
p1a.dataReceived(handshake(token1, side=side1)) p1a.send(handshake(token1, side=side1))
p1b.dataReceived(handshake(token1, side=side1)) self.flush()
p1b.send(handshake(token1, side=side1))
self.flush()
# connect and disconnect a third client (for side1) to exercise the # connect and disconnect a third client (for side1) to exercise the
# code that removes a pending connection without removing the entire # code that removes a pending connection without removing the entire
# token # token
p1c.dataReceived(handshake(token1, side=side1)) p1c.send(handshake(token1, side=side1))
p1c.transport.loseConnection() p1c.disconnect()
self.flush()
self.assertEqual(len(self._usage), 1, self._usage) self.assertEqual(len(self._usage), 1, self._usage)
(started, result, total_bytes, total_time, waiting_time) = self._usage[0] (started, result, total_bytes, total_time, waiting_time) = self._usage[0]
self.assertEqual(result, "lonely", self._usage) self.assertEqual(result, "lonely", self._usage)
p2.dataReceived(handshake(token1, side=side2)) p2.send(handshake(token1, side=side2))
self.flush()
self.assertEqual(len(self._transit_server._pending_requests), 0) self.assertEqual(len(self._transit_server._pending_requests), 0)
self.assertEqual(len(self._usage), 2, self._usage) self.assertEqual(len(self._usage), 2, self._usage)
(started, result, total_bytes, total_time, waiting_time) = self._usage[1] (started, result, total_bytes, total_time, waiting_time) = self._usage[1]
self.assertEqual(result, "redundant", self._usage) self.assertEqual(result, "redundant", self._usage)
# one of the these is unecessary, but probably harmless # one of the these is unecessary, but probably harmless
p1a.transport.loseConnection() p1a.disconnect()
p1b.transport.loseConnection() p1b.disconnect()
self.flush()
self.assertEqual(len(self._usage), 3, self._usage) self.assertEqual(len(self._usage), 3, self._usage)
(started, result, total_bytes, total_time, waiting_time) = self._usage[2] (started, result, total_bytes, total_time, waiting_time) = self._usage[2]
self.assertEqual(result, "happy", self._usage) self.assertEqual(result, "happy", self._usage)

View File

@ -84,14 +84,14 @@ class TransitConnection(LineReceiver):
# point the sender will only transmit data as fast as the # point the sender will only transmit data as fast as the
# receiver can handle it. # receiver can handle it.
if self._sent_ok: if self._sent_ok:
if not self._buddy: # if self._buddy is None then our buddy disconnected
# Our buddy disconnected (we're "jilted"), so we hung up too, # (we're "jilted"), so we hung up too, but our incoming
# but our incoming data hasn't stopped yet (it will in a # data hasn't stopped yet (it will in a moment, after our
# moment, after our disconnect makes a roundtrip through the # disconnect makes a roundtrip through the kernel). This
# kernel). This probably means the file receiver hung up, and # probably means the file receiver hung up, and this
# this connection is the file sender. In may-2020 this # connection is the file sender. In may-2020 this happened
# happened 11 times in 40 days. # 11 times in 40 days.
return if self._buddy:
self._total_sent += len(data) self._total_sent += len(data)
self._buddy.transport.write(data) self._buddy.transport.write(data)
return return