explicit interface, different naming

This commit is contained in:
meejah 2021-04-02 20:10:01 -06:00
parent 0434296415
commit 45c09fdd05
2 changed files with 73 additions and 30 deletions

View File

@ -3,11 +3,45 @@ from twisted.internet.protocol import (
Protocol,
)
from twisted.test import iosim
from zope.interface import (
Interface,
Attribute,
implementer,
)
from ..transit_server import (
Transit,
)
class ITransitClient(Interface):
"""
The client interface used by tests.
"""
connected = Attribute("True if we are currently connected else False")
def send(data):
"""
Send some bytes.
:param bytes data: the data to send
"""
def disconnect():
"""
Terminate the connection.
"""
def get_received_data():
"""
:returns: all the bytes received from the server on this
connection.
"""
def reset_data():
"""
Erase any received data to this point.
"""
class ServerBase:
log_requests = False
@ -33,14 +67,18 @@ class ServerBase:
self._transit_server._debug_log = self.log_requests
def new_protocol(self):
"""
Create a new client protocol connected to the server.
:returns: a ITransitClient implementation
"""
server_protocol = self._transit_server.buildProtocol(('127.0.0.1', 0))
# XXX interface?
@implementer(ITransitClient)
class TransitClientProtocolTcp(Protocol):
"""
Speak the transit client protocol used by the tests over TCP
"""
received = b""
_received = b""
connected = False
def connectionMade(self):
@ -56,8 +94,13 @@ class ServerBase:
self.transport.loseConnection()
def dataReceived(self, data):
self.received = self.received + data
self._received = self._received + data
def reset_received_data(self):
self._received = b""
def get_received_data(self):
return self._received
client_factory = ClientFactory()
client_factory.protocol = TransitClientProtocolTcp

View File

@ -65,16 +65,16 @@ class _Transit:
# a correct handshake yields an ack, after which we can send
exp = b"ok\n"
self.assertEqual(p1.received, exp)
self.assertEqual(p2.received, exp)
self.assertEqual(p1.get_received_data(), exp)
self.assertEqual(p2.get_received_data(), exp)
p1.received = b""
p2.received = b""
p1.reset_received_data()
p2.reset_received_data()
s1 = b"data1"
p1.send(s1)
self.flush()
self.assertEqual(p2.received, s1)
self.assertEqual(p2.get_received_data(), s1)
p1.disconnect()
p2.disconnect()
@ -93,17 +93,17 @@ class _Transit:
# a correct handshake yields an ack, after which we can send
exp = b"ok\n"
self.assertEqual(p1.received, exp)
self.assertEqual(p2.received, exp)
self.assertEqual(p1.get_received_data(), exp)
self.assertEqual(p2.get_received_data(), exp)
p1.received = b""
p2.received = b""
p1.reset_received_data()
p2.reset_received_data()
# all data they sent after the handshake should be given to us
s1 = b"data1"
p1.send(s1)
self.flush()
self.assertEqual(p2.received, s1)
self.assertEqual(p2.get_received_data(), s1)
p1.disconnect()
p2.disconnect()
@ -121,17 +121,17 @@ class _Transit:
# a correct handshake yields an ack, after which we can send
exp = b"ok\n"
self.assertEqual(p1.received, exp)
self.assertEqual(p2.received, exp)
self.assertEqual(p1.get_received_data(), exp)
self.assertEqual(p2.get_received_data(), exp)
p1.received = b""
p2.received = b""
p1.reset_received_data()
p2.reset_received_data()
# all data they sent after the handshake should be given to us
s1 = b"data1"
p1.send(s1)
self.flush()
self.assertEqual(p2.received, s1)
self.assertEqual(p2.get_received_data(), s1)
p1.disconnect()
p2.disconnect()
@ -151,17 +151,17 @@ class _Transit:
# a correct handshake yields an ack, after which we can send
exp = b"ok\n"
self.assertEqual(p1.received, exp)
self.assertEqual(p2.received, exp)
self.assertEqual(p1.get_received_data(), exp)
self.assertEqual(p2.get_received_data(), exp)
p1.received = b""
p2.received = b""
p1.reset_received_data()
p2.reset_received_data()
# all data they sent after the handshake should be given to us
s1 = b"data1"
p1.send(s1)
self.flush()
self.assertEqual(p2.received, s1)
self.assertEqual(p2.get_received_data(), s1)
p1.disconnect()
p2.disconnect()
@ -207,7 +207,7 @@ class _Transit:
self.flush()
exp = b"bad handshake\n"
self.assertEqual(p1.received, exp)
self.assertEqual(p1.get_received_data(), exp)
p1.disconnect()
def test_bad_handshake_old_slow(self):
@ -227,7 +227,7 @@ class _Transit:
self.flush()
exp = b"bad handshake\n"
self.assertEqual(p1.received, exp)
self.assertEqual(p1.get_received_data(), exp)
p1.disconnect()
@ -243,7 +243,7 @@ class _Transit:
self.flush()
exp = b"bad handshake\n"
self.assertEqual(p1.received, exp)
self.assertEqual(p1.get_received_data(), exp)
p1.disconnect()
@ -261,7 +261,7 @@ class _Transit:
self.flush()
exp = b"bad handshake\n"
self.assertEqual(p1.received, exp)
self.assertEqual(p1.get_received_data(), exp)
p1.disconnect()
@ -274,7 +274,7 @@ class _Transit:
self.flush()
exp = b"impatient\n"
self.assertEqual(p1.received, exp)
self.assertEqual(p1.get_received_data(), exp)
p1.disconnect()
@ -289,7 +289,7 @@ class _Transit:
self.flush()
exp = b"impatient\n"
self.assertEqual(p1.received, exp)
self.assertEqual(p1.get_received_data(), exp)
p1.disconnect()
@ -315,7 +315,7 @@ class _Transit:
self.flush()
exp = b"impatient\n"
self.assertEqual(p1.received, exp)
self.assertEqual(p1.get_received_data(), exp)
p1.disconnect()