Merge pull request #20 from meejah/iosim-based-tests

iosim.IOPump based tests
This commit is contained in:
meejah 2021-04-12 21:00:50 -06:00 committed by GitHub
commit de8e0f0399
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 275 additions and 121 deletions

View File

@ -1,28 +1,129 @@
from twisted.test import proto_helpers
from ..transit_server import Transit
from twisted.internet.protocol import (
ClientFactory,
Protocol,
)
from twisted.test import iosim
from zope.interface import (
Interface,
Attribute,
implementer,
)
from ..transit_server import (
Transit,
)
class IRelayTestClient(Interface):
"""
The client interface used by tests.
"""
connected = Attribute("True if we are currently connected else False")
def send(data):
"""
Send some bytes.
:param bytes data: the data to send
"""
def disconnect():
"""
Terminate the connection.
"""
def get_received_data():
"""
:returns: all the bytes received from the server on this
connection.
"""
def reset_data():
"""
Erase any received data to this point.
"""
class ServerBase:
log_requests = False
def setUp(self):
self._pumps = []
self._lp = None
if self.log_requests:
blur_usage = None
else:
blur_usage = 60.0
self._setup_relay(blur_usage=blur_usage)
self._transit_server._debug_log = self.log_requests
def flush(self):
did_work = False
for pump in self._pumps:
did_work = pump.flush() or did_work
if did_work:
self.flush()
def _setup_relay(self, blur_usage=None, log_file=None, usage_db=None):
self._transit_server = Transit(blur_usage=blur_usage,
log_file=log_file, usage_db=usage_db)
self._transit_server = Transit(
blur_usage=blur_usage,
log_file=log_file,
usage_db=usage_db,
)
self._transit_server._debug_log = self.log_requests
def new_protocol(self):
protocol = self._transit_server.buildProtocol(('127.0.0.1', 0))
transport = proto_helpers.StringTransportWithDisconnection()
protocol.makeConnection(transport)
transport.protocol = protocol
return protocol
"""
Create a new client protocol connected to the server.
:returns: a IRelayTestClient implementation
"""
server_protocol = self._transit_server.buildProtocol(('127.0.0.1', 0))
@implementer(IRelayTestClient)
class TransitClientProtocolTcp(Protocol):
"""
Speak the transit client protocol used by the tests over TCP
"""
_received = b""
connected = False
# override Protocol callbacks
def connectionMade(self):
self.connected = True
return Protocol.connectionMade(self)
def connectionLost(self, reason):
self.connected = False
return Protocol.connectionLost(self, reason)
def dataReceived(self, data):
self._received = self._received + data
# IRelayTestClient
def send(self, data):
self.transport.write(data)
def disconnect(self):
self.transport.loseConnection()
def reset_received_data(self):
self._received = b""
def get_received_data(self):
return self._received
client_factory = ClientFactory()
client_factory.protocol = TransitClientProtocolTcp
client_protocol = client_factory.buildProtocol(('127.0.0.1', 31337))
pump = iosim.connect(
server_protocol,
iosim.makeFakeServer(server_protocol),
client_protocol,
iosim.makeFakeClient(client_protocol),
)
pump.flush()
self._pumps.append(pump)
return client_protocol
def tearDown(self):
if self._lp:

View File

@ -1,5 +1,8 @@
from __future__ import print_function, unicode_literals
import mock
try:
from unittest import mock
except ImportError:
import mock
from twisted.trial import unittest
from ..increase_rlimits import increase_rlimits

View File

@ -1,6 +1,9 @@
from __future__ import unicode_literals, print_function
from twisted.trial import unittest
import mock
try:
from unittest import mock
except ImportError:
import mock
from twisted.application.service import MultiService
from .. import server_tap

View File

@ -1,6 +1,9 @@
from __future__ import print_function, unicode_literals
import os, io, json, sqlite3
import mock
try:
from unittest import mock
except ImportError:
import mock
from twisted.trial import unittest
from ..transit_server import Transit
from .. import database

View File

@ -41,10 +41,12 @@ class _Transit:
token1 = b"\x00"*32
side1 = b"\x01"*8
p1.dataReceived(handshake(token1, side1))
p1.send(handshake(token1, side1))
self.flush()
self.assertEqual(self.count(), 1)
p1.transport.loseConnection()
p1.disconnect()
self.flush()
self.assertEqual(self.count(), 0)
# the token should be removed too
@ -55,23 +57,27 @@ class _Transit:
p2 = self.new_protocol()
token1 = b"\x00"*32
p1.dataReceived(handshake(token1, side=None))
p2.dataReceived(handshake(token1, side=None))
p1.send(handshake(token1, side=None))
self.flush()
p2.send(handshake(token1, side=None))
self.flush()
# a correct handshake yields an ack, after which we can send
exp = b"ok\n"
self.assertEqual(p1.transport.value(), exp)
self.assertEqual(p2.transport.value(), exp)
self.assertEqual(p1.get_received_data(), exp)
self.assertEqual(p2.get_received_data(), exp)
p1.transport.clear()
p2.transport.clear()
p1.reset_received_data()
p2.reset_received_data()
s1 = b"data1"
p1.dataReceived(s1)
self.assertEqual(p2.transport.value(), s1)
p1.send(s1)
self.flush()
self.assertEqual(p2.get_received_data(), s1)
p1.transport.loseConnection()
p2.transport.loseConnection()
p1.disconnect()
p2.disconnect()
self.flush()
def test_sided_unsided(self):
p1 = self.new_protocol()
@ -79,24 +85,28 @@ class _Transit:
token1 = b"\x00"*32
side1 = b"\x01"*8
p1.dataReceived(handshake(token1, side=side1))
p2.dataReceived(handshake(token1, side=None))
p1.send(handshake(token1, side=side1))
self.flush()
p2.send(handshake(token1, side=None))
self.flush()
# a correct handshake yields an ack, after which we can send
exp = b"ok\n"
self.assertEqual(p1.transport.value(), exp)
self.assertEqual(p2.transport.value(), exp)
self.assertEqual(p1.get_received_data(), exp)
self.assertEqual(p2.get_received_data(), exp)
p1.transport.clear()
p2.transport.clear()
p1.reset_received_data()
p2.reset_received_data()
# all data they sent after the handshake should be given to us
s1 = b"data1"
p1.dataReceived(s1)
self.assertEqual(p2.transport.value(), s1)
p1.send(s1)
self.flush()
self.assertEqual(p2.get_received_data(), s1)
p1.transport.loseConnection()
p2.transport.loseConnection()
p1.disconnect()
p2.disconnect()
self.flush()
def test_unsided_sided(self):
p1 = self.new_protocol()
@ -104,24 +114,26 @@ class _Transit:
token1 = b"\x00"*32
side1 = b"\x01"*8
p1.dataReceived(handshake(token1, side=None))
p2.dataReceived(handshake(token1, side=side1))
p1.send(handshake(token1, side=None))
p2.send(handshake(token1, side=side1))
self.flush()
# a correct handshake yields an ack, after which we can send
exp = b"ok\n"
self.assertEqual(p1.transport.value(), exp)
self.assertEqual(p2.transport.value(), exp)
self.assertEqual(p1.get_received_data(), exp)
self.assertEqual(p2.get_received_data(), exp)
p1.transport.clear()
p2.transport.clear()
p1.reset_received_data()
p2.reset_received_data()
# all data they sent after the handshake should be given to us
s1 = b"data1"
p1.dataReceived(s1)
self.assertEqual(p2.transport.value(), s1)
p1.send(s1)
self.flush()
self.assertEqual(p2.get_received_data(), s1)
p1.transport.loseConnection()
p2.transport.loseConnection()
p1.disconnect()
p2.disconnect()
def test_both_sided(self):
p1 = self.new_protocol()
@ -130,24 +142,27 @@ class _Transit:
token1 = b"\x00"*32
side1 = b"\x01"*8
side2 = b"\x02"*8
p1.dataReceived(handshake(token1, side=side1))
p2.dataReceived(handshake(token1, side=side2))
p1.send(handshake(token1, side=side1))
self.flush()
p2.send(handshake(token1, side=side2))
self.flush()
# a correct handshake yields an ack, after which we can send
exp = b"ok\n"
self.assertEqual(p1.transport.value(), exp)
self.assertEqual(p2.transport.value(), exp)
self.assertEqual(p1.get_received_data(), exp)
self.assertEqual(p2.get_received_data(), exp)
p1.transport.clear()
p2.transport.clear()
p1.reset_received_data()
p2.reset_received_data()
# all data they sent after the handshake should be given to us
s1 = b"data1"
p1.dataReceived(s1)
self.assertEqual(p2.transport.value(), s1)
p1.send(s1)
self.flush()
self.assertEqual(p2.get_received_data(), s1)
p1.transport.loseConnection()
p2.transport.loseConnection()
p1.disconnect()
p2.disconnect()
def test_ignore_same_side(self):
p1 = self.new_protocol()
@ -157,41 +172,46 @@ class _Transit:
token1 = b"\x00"*32
side1 = b"\x01"*8
p1.dataReceived(handshake(token1, side=side1))
p1.send(handshake(token1, side=side1))
self.flush()
self.assertEqual(self.count(), 1)
p2.dataReceived(handshake(token1, side=side1))
p2.send(handshake(token1, side=side1))
self.flush()
self.assertEqual(self.count(), 2) # same-side connections don't match
# when the second side arrives, the spare first connection should be
# closed
side2 = b"\x02"*8
p3.dataReceived(handshake(token1, side=side2))
p3.send(handshake(token1, side=side2))
self.flush()
self.assertEqual(self.count(), 0)
self.assertEqual(len(self._transit_server._pending_requests), 0)
self.assertEqual(len(self._transit_server._active_connections), 2)
# That will trigger a disconnect on exactly one of (p1 or p2).
# The other connection should still be connected
self.assertEqual(sum([int(t.transport.connected) for t in [p1, p2]]), 1)
self.assertEqual(sum([int(t.connected) for t in [p1, p2]]), 1)
p1.transport.loseConnection()
p2.transport.loseConnection()
p3.transport.loseConnection()
p1.disconnect()
p2.disconnect()
p3.disconnect()
def test_bad_handshake_old(self):
p1 = self.new_protocol()
token1 = b"\x00"*32
p1.dataReceived(b"please DELAY " + hexlify(token1) + b"\n")
p1.send(b"please DELAY " + hexlify(token1) + b"\n")
self.flush()
exp = b"bad handshake\n"
self.assertEqual(p1.transport.value(), exp)
p1.transport.loseConnection()
self.assertEqual(p1.get_received_data(), exp)
p1.disconnect()
def test_bad_handshake_old_slow(self):
p1 = self.new_protocol()
p1.dataReceived(b"please DELAY ")
p1.send(b"please DELAY ")
self.flush()
# As in test_impatience_new_slow, the current state machine has code
# that can only be reached if we insert a stall here, so dataReceived
# gets called twice. Hopefully we can delete this test once
@ -200,12 +220,13 @@ class _Transit:
token1 = b"\x00"*32
# the server waits for the exact number of bytes in the expected
# handshake message. to trigger "bad handshake", we must match.
p1.dataReceived(hexlify(token1) + b"\n")
p1.send(hexlify(token1) + b"\n")
self.flush()
exp = b"bad handshake\n"
self.assertEqual(p1.transport.value(), exp)
self.assertEqual(p1.get_received_data(), exp)
p1.transport.loseConnection()
p1.disconnect()
def test_bad_handshake_new(self):
p1 = self.new_protocol()
@ -214,13 +235,14 @@ class _Transit:
side1 = b"\x01"*8
# the server waits for the exact number of bytes in the expected
# handshake message. to trigger "bad handshake", we must match.
p1.dataReceived(b"please DELAY " + hexlify(token1) +
p1.send(b"please DELAY " + hexlify(token1) +
b" for side " + hexlify(side1) + b"\n")
self.flush()
exp = b"bad handshake\n"
self.assertEqual(p1.transport.value(), exp)
self.assertEqual(p1.get_received_data(), exp)
p1.transport.loseConnection()
p1.disconnect()
def test_binary_handshake(self):
p1 = self.new_protocol()
@ -232,24 +254,26 @@ class _Transit:
# UnicodeDecodeError when it tried to coerce the incoming handshake
# to unicode, due to the ("\n" in buf) check. This was fixed to use
# (b"\n" in buf). This exercises the old failure.
p1.dataReceived(binary_bad_handshake)
p1.send(binary_bad_handshake)
self.flush()
exp = b"bad handshake\n"
self.assertEqual(p1.transport.value(), exp)
self.assertEqual(p1.get_received_data(), exp)
p1.transport.loseConnection()
p1.disconnect()
def test_impatience_old(self):
p1 = self.new_protocol()
token1 = b"\x00"*32
# sending too many bytes is impatience.
p1.dataReceived(b"please relay " + hexlify(token1) + b"\nNOWNOWNOW")
p1.send(b"please relay " + hexlify(token1) + b"\nNOWNOWNOW")
self.flush()
exp = b"impatient\n"
self.assertEqual(p1.transport.value(), exp)
self.assertEqual(p1.get_received_data(), exp)
p1.transport.loseConnection()
p1.disconnect()
def test_impatience_new(self):
p1 = self.new_protocol()
@ -257,13 +281,14 @@ class _Transit:
token1 = b"\x00"*32
side1 = b"\x01"*8
# sending too many bytes is impatience.
p1.dataReceived(b"please relay " + hexlify(token1) +
p1.send(b"please relay " + hexlify(token1) +
b" for side " + hexlify(side1) + b"\nNOWNOWNOW")
self.flush()
exp = b"impatient\n"
self.assertEqual(p1.transport.value(), exp)
self.assertEqual(p1.get_received_data(), exp)
p1.transport.loseConnection()
p1.disconnect()
def test_impatience_new_slow(self):
p1 = self.new_protocol()
@ -279,27 +304,29 @@ class _Transit:
token1 = b"\x00"*32
side1 = b"\x01"*8
# sending too many bytes is impatience.
p1.dataReceived(b"please relay " + hexlify(token1) +
p1.send(b"please relay " + hexlify(token1) +
b" for side " + hexlify(side1) + b"\n")
self.flush()
p1.dataReceived(b"NOWNOWNOW")
p1.send(b"NOWNOWNOW")
self.flush()
exp = b"impatient\n"
self.assertEqual(p1.transport.value(), exp)
self.assertEqual(p1.get_received_data(), exp)
p1.transport.loseConnection()
p1.disconnect()
def test_short_handshake(self):
p1 = self.new_protocol()
# hang up before sending a complete handshake
p1.dataReceived(b"short")
p1.transport.loseConnection()
p1.send(b"short")
self.flush()
p1.disconnect()
def test_empty_handshake(self):
p1 = self.new_protocol()
# hang up before sending anything
p1.transport.loseConnection()
p1.disconnect()
class TransitWithLogs(_Transit, ServerBase, unittest.TestCase):
log_requests = True
@ -308,6 +335,8 @@ class TransitWithoutLogs(_Transit, ServerBase, unittest.TestCase):
log_requests = False
class Usage(ServerBase, unittest.TestCase):
log_requests = True
def setUp(self):
super(Usage, self).setUp()
self._usage = []
@ -319,7 +348,8 @@ class Usage(ServerBase, unittest.TestCase):
def test_empty(self):
p1 = self.new_protocol()
# hang up before sending anything
p1.transport.loseConnection()
p1.disconnect()
self.flush()
# that will log the "empty" usage event
self.assertEqual(len(self._usage), 1, self._usage)
@ -329,8 +359,9 @@ class Usage(ServerBase, unittest.TestCase):
def test_short(self):
p1 = self.new_protocol()
# hang up before sending a complete handshake
p1.transport.write(b"short")
p1.transport.loseConnection()
p1.send(b"short")
p1.disconnect()
self.flush()
# that will log the "empty" usage event
self.assertEqual(len(self._usage), 1, self._usage)
@ -340,9 +371,10 @@ class Usage(ServerBase, unittest.TestCase):
def test_errory(self):
p1 = self.new_protocol()
p1.dataReceived(b"this is a very bad handshake\n")
p1.send(b"this is a very bad handshake\n")
self.flush()
# that will log the "errory" usage event, then drop the connection
p1.transport.loseConnection()
p1.disconnect()
self.assertEqual(len(self._usage), 1, self._usage)
(started, result, total_bytes, total_time, waiting_time) = self._usage[0]
self.assertEqual(result, "errory", self._usage)
@ -352,9 +384,11 @@ class Usage(ServerBase, unittest.TestCase):
token1 = b"\x00"*32
side1 = b"\x01"*8
p1.dataReceived(handshake(token1, side=side1))
p1.send(handshake(token1, side=side1))
self.flush()
# now we disconnect before the peer connects
p1.transport.loseConnection()
p1.disconnect()
self.flush()
self.assertEqual(len(self._usage), 1, self._usage)
(started, result, total_bytes, total_time, waiting_time) = self._usage[0]
@ -368,15 +402,20 @@ class Usage(ServerBase, unittest.TestCase):
token1 = b"\x00"*32
side1 = b"\x01"*8
side2 = b"\x02"*8
p1.dataReceived(handshake(token1, side=side1))
p2.dataReceived(handshake(token1, side=side2))
p1.send(handshake(token1, side=side1))
self.flush()
p2.send(handshake(token1, side=side2))
self.flush()
self.assertEqual(self._usage, []) # no events yet
p1.dataReceived(b"\x00" * 13)
p2.dataReceived(b"\xff" * 7)
p1.send(b"\x00" * 13)
self.flush()
p2.send(b"\xff" * 7)
self.flush()
p1.transport.loseConnection()
p1.disconnect()
self.flush()
self.assertEqual(len(self._usage), 1, self._usage)
(started, result, total_bytes, total_time, waiting_time) = self._usage[0]
@ -393,28 +432,33 @@ class Usage(ServerBase, unittest.TestCase):
token1 = b"\x00"*32
side1 = b"\x01"*8
side2 = b"\x02"*8
p1a.dataReceived(handshake(token1, side=side1))
p1b.dataReceived(handshake(token1, side=side1))
p1a.send(handshake(token1, side=side1))
self.flush()
p1b.send(handshake(token1, side=side1))
self.flush()
# connect and disconnect a third client (for side1) to exercise the
# code that removes a pending connection without removing the entire
# token
p1c.dataReceived(handshake(token1, side=side1))
p1c.transport.loseConnection()
p1c.send(handshake(token1, side=side1))
p1c.disconnect()
self.flush()
self.assertEqual(len(self._usage), 1, self._usage)
(started, result, total_bytes, total_time, waiting_time) = self._usage[0]
self.assertEqual(result, "lonely", self._usage)
p2.dataReceived(handshake(token1, side=side2))
p2.send(handshake(token1, side=side2))
self.flush()
self.assertEqual(len(self._transit_server._pending_requests), 0)
self.assertEqual(len(self._usage), 2, self._usage)
(started, result, total_bytes, total_time, waiting_time) = self._usage[1]
self.assertEqual(result, "redundant", self._usage)
# one of the these is unecessary, but probably harmless
p1a.transport.loseConnection()
p1b.transport.loseConnection()
p1a.disconnect()
p1b.disconnect()
self.flush()
self.assertEqual(len(self._usage), 3, self._usage)
(started, result, total_bytes, total_time, waiting_time) = self._usage[2]
self.assertEqual(result, "happy", self._usage)

View File

@ -84,14 +84,14 @@ class TransitConnection(LineReceiver):
# point the sender will only transmit data as fast as the
# receiver can handle it.
if self._sent_ok:
if not self._buddy:
# Our buddy disconnected (we're "jilted"), so we hung up too,
# but our incoming data hasn't stopped yet (it will in a
# moment, after our disconnect makes a roundtrip through the
# kernel). This probably means the file receiver hung up, and
# this connection is the file sender. In may-2020 this
# happened 11 times in 40 days.
return
# if self._buddy is None then our buddy disconnected
# (we're "jilted"), so we hung up too, but our incoming
# data hasn't stopped yet (it will in a moment, after our
# disconnect makes a roundtrip through the kernel). This
# probably means the file receiver hung up, and this
# connection is the file sender. In may-2020 this happened
# 11 times in 40 days.
if self._buddy:
self._total_sent += len(data)
self._buddy.transport.write(data)
return