refactor to use IOPump: one test passes

This commit is contained in:
meejah 2021-04-02 15:28:08 -06:00
parent 4f818bb7e0
commit b9c2bbc524
2 changed files with 52 additions and 11 deletions

View File

@ -1,28 +1,67 @@
from twisted.test import proto_helpers
from ..transit_server import Transit
from twisted.internet.protocol import (
ServerFactory,
ClientFactory,
Protocol,
)
from twisted.test import iosim
from ..transit_server import (
Transit,
TransitConnection,
)
class ServerBase:
log_requests = False
def setUp(self):
self._pumps = []
self._lp = None
if self.log_requests:
blur_usage = None
else:
blur_usage = 60.0
self._setup_relay(blur_usage=blur_usage)
self._transit_server._debug_log = self.log_requests
def flush(self):
for pump in self._pumps:
pump.flush()
def _setup_relay(self, blur_usage=None, log_file=None, usage_db=None):
self._transit_server = Transit(blur_usage=blur_usage,
log_file=log_file, usage_db=usage_db)
self._transit_server = Transit(
blur_usage=blur_usage,
log_file=log_file,
usage_db=usage_db,
)
self._transit_server._debug_log = self.log_requests
def new_protocol(self):
protocol = self._transit_server.buildProtocol(('127.0.0.1', 0))
transport = proto_helpers.StringTransportWithDisconnection()
protocol.makeConnection(transport)
transport.protocol = protocol
return protocol
server_protocol = self._transit_server.buildProtocol(('127.0.0.1', 0))
# XXX interface?
class TransitClientProtocolTcp(Protocol):
"""
Speak the transit client protocol used by the tests over TCP
"""
def send(self, data):
self.transport.write(data)
def disconnect(self):
self.transport.loseConnection()
client_factory = ClientFactory()
client_factory.protocol = TransitClientProtocolTcp
client_protocol = client_factory.buildProtocol(('127.0.0.1', 31337))
pump = iosim.connect(
server_protocol,
iosim.makeFakeServer(server_protocol),
client_protocol,
iosim.makeFakeClient(client_protocol),
)
pump.flush()
self._pumps.append(pump)
return client_protocol
def tearDown(self):
if self._lp:

View File

@ -41,10 +41,12 @@ class _Transit:
token1 = b"\x00"*32
side1 = b"\x01"*8
p1.dataReceived(handshake(token1, side1))
p1.send(handshake(token1, side1))
self.flush()
self.assertEqual(self.count(), 1)
p1.transport.loseConnection()
p1.disconnect()
self.flush()
self.assertEqual(self.count(), 0)
# the token should be removed too