integration-style test of backpressure
This commit is contained in:
parent
01510de60e
commit
aeae7c2bdd
202
src/wormhole_transit_relay/test/test_backpressure.py
Normal file
202
src/wormhole_transit_relay/test/test_backpressure.py
Normal file
|
@ -0,0 +1,202 @@
|
||||||
|
from io import (
|
||||||
|
StringIO,
|
||||||
|
)
|
||||||
|
import sys
|
||||||
|
import shutil
|
||||||
|
|
||||||
|
from twisted.trial import unittest
|
||||||
|
from twisted.internet.interfaces import (
|
||||||
|
IPullProducer,
|
||||||
|
)
|
||||||
|
from twisted.internet.protocol import (
|
||||||
|
ProcessProtocol,
|
||||||
|
)
|
||||||
|
from twisted.internet.defer import (
|
||||||
|
inlineCallbacks,
|
||||||
|
Deferred,
|
||||||
|
)
|
||||||
|
from autobahn.twisted.websocket import (
|
||||||
|
WebSocketClientProtocol,
|
||||||
|
create_client_agent,
|
||||||
|
)
|
||||||
|
from zope.interface import implementer
|
||||||
|
|
||||||
|
|
||||||
|
class _CollectOutputProtocol(ProcessProtocol):
|
||||||
|
"""
|
||||||
|
Internal helper. Collects all output (stdout + stderr) into
|
||||||
|
self.output, and callback's on done with all of it after the
|
||||||
|
process exits (for any reason).
|
||||||
|
"""
|
||||||
|
def __init__(self):
|
||||||
|
self.done = Deferred()
|
||||||
|
self.running = Deferred()
|
||||||
|
self.output = StringIO()
|
||||||
|
|
||||||
|
def processEnded(self, reason):
|
||||||
|
if not self.done.called:
|
||||||
|
self.done.callback(self.output.getvalue())
|
||||||
|
|
||||||
|
def outReceived(self, data):
|
||||||
|
print(data.decode(), end="", flush=True)
|
||||||
|
self.output.write(data.decode(sys.getfilesystemencoding()))
|
||||||
|
if not self.running.called:
|
||||||
|
if "on 8088" in self.output.getvalue():
|
||||||
|
self.running.callback(None)
|
||||||
|
|
||||||
|
def errReceived(self, data):
|
||||||
|
print("ERR: {}".format(data.decode(sys.getfilesystemencoding())))
|
||||||
|
self.output.write(data.decode(sys.getfilesystemencoding()))
|
||||||
|
|
||||||
|
|
||||||
|
def run_transit(reactor, proto, tcp_port=None, websocket_port=None):
|
||||||
|
exe = shutil.which("twistd")
|
||||||
|
args = [
|
||||||
|
exe, "-n", "transitrelay",
|
||||||
|
]
|
||||||
|
if tcp_port is not None:
|
||||||
|
args.append("--port")
|
||||||
|
args.append(tcp_port)
|
||||||
|
if websocket_port is not None:
|
||||||
|
args.append("--websocket")
|
||||||
|
args.append(websocket_port)
|
||||||
|
proc = reactor.spawnProcess(proto, exe, args)
|
||||||
|
return proc
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class Sender(WebSocketClientProtocol):
|
||||||
|
"""
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, *args, **kw):
|
||||||
|
WebSocketClientProtocol.__init__(self, *args, **kw)
|
||||||
|
self.done = Deferred()
|
||||||
|
self.got_ok = Deferred()
|
||||||
|
|
||||||
|
def onMessage(self, payload, is_binary):
|
||||||
|
print("onMessage")
|
||||||
|
if not self.got_ok.called:
|
||||||
|
if payload == b"ok\n":
|
||||||
|
self.got_ok.callback(None)
|
||||||
|
print("send: {}".format(payload.decode("utf8")))
|
||||||
|
|
||||||
|
def onClose(self, clean, code, reason):
|
||||||
|
print(f"close: {clean} {code} {reason}")
|
||||||
|
self.done.callback(None)
|
||||||
|
|
||||||
|
|
||||||
|
class Receiver(WebSocketClientProtocol):
|
||||||
|
"""
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, *args, **kw):
|
||||||
|
WebSocketClientProtocol.__init__(self, *args, **kw)
|
||||||
|
self.done = Deferred()
|
||||||
|
self.first_message = Deferred()
|
||||||
|
self.received = 0
|
||||||
|
|
||||||
|
def onMessage(self, payload, is_binary):
|
||||||
|
print("recv: {}".format(len(payload)))
|
||||||
|
self.received += len(payload)
|
||||||
|
if not self.first_message.called:
|
||||||
|
self.first_message.callback(None)
|
||||||
|
|
||||||
|
def onClose(self, clean, code, reason):
|
||||||
|
print(f"close: {clean} {code} {reason}")
|
||||||
|
self.done.callback(None)
|
||||||
|
|
||||||
|
|
||||||
|
class TransitWebSockets(unittest.TestCase):
|
||||||
|
|
||||||
|
@inlineCallbacks
|
||||||
|
def test_buffer_fills(self):
|
||||||
|
"""
|
||||||
|
A running transit relay stops accepting incoming data if the peer
|
||||||
|
isn't reading.
|
||||||
|
"""
|
||||||
|
from twisted.internet import reactor
|
||||||
|
transit_proto = _CollectOutputProtocol()
|
||||||
|
transit_proc = run_transit(reactor, transit_proto, websocket_port="tcp:8088")
|
||||||
|
|
||||||
|
def cleanup_process():
|
||||||
|
transit_proc.signalProcess("HUP")
|
||||||
|
return transit_proto.done
|
||||||
|
self.addCleanup(cleanup_process)
|
||||||
|
|
||||||
|
yield transit_proto.running
|
||||||
|
print("Transit running")
|
||||||
|
|
||||||
|
agent = create_client_agent(reactor)
|
||||||
|
side_a = yield agent.open("ws://localhost:8088", {}, lambda: Sender())
|
||||||
|
side_b = yield agent.open("ws://localhost:8088", {}, lambda: Receiver())
|
||||||
|
|
||||||
|
side_a.sendMessage(b"please relay aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa for side aaaaaaaaaaaaaaaa", True)
|
||||||
|
side_b.sendMessage(b"please relay aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa for side bbbbbbbbbbbbbbbb", True)
|
||||||
|
|
||||||
|
yield side_a.got_ok
|
||||||
|
yield side_b.first_message
|
||||||
|
|
||||||
|
# remove side_b's filedescriptor from the reactor .. this
|
||||||
|
# means it will not read any more data
|
||||||
|
reactor.removeReader(side_b.transport)
|
||||||
|
|
||||||
|
# attempt to send up to 100MiB through side_a .. we should get
|
||||||
|
# backpressure before that works which only manifests itself
|
||||||
|
# as this producer not being asked to produce more
|
||||||
|
max_data = 1024*1024*100 # 100MiB
|
||||||
|
|
||||||
|
@implementer(IPullProducer)
|
||||||
|
class ProduceMessages:
|
||||||
|
def __init__(self, ws, on_produce):
|
||||||
|
self._ws = ws
|
||||||
|
self._sent = 0
|
||||||
|
self._max = max_data
|
||||||
|
self._on_produce = on_produce
|
||||||
|
|
||||||
|
def resumeProducing(self):
|
||||||
|
self._on_produce()
|
||||||
|
if self._sent >= self._max:
|
||||||
|
self._ws.sendClose()
|
||||||
|
return
|
||||||
|
data = b"a" * 1024*1024
|
||||||
|
self._ws.sendMessage(data, True)
|
||||||
|
self._sent += len(data)
|
||||||
|
print("sent {}, total {}".format(len(data), self._sent))
|
||||||
|
|
||||||
|
# our only signal is, "did our producer get asked to produce
|
||||||
|
# more data" which it should do periodically. We want to stop
|
||||||
|
# if we haven't seen a new data request for a while -- defined
|
||||||
|
# as "more than 5 seconds".
|
||||||
|
|
||||||
|
done = Deferred()
|
||||||
|
last_produce = None
|
||||||
|
timeout = 3 # seconds
|
||||||
|
|
||||||
|
def asked_for_data():
|
||||||
|
nonlocal last_produce
|
||||||
|
last_produce = reactor.seconds()
|
||||||
|
|
||||||
|
data = ProduceMessages(side_a, asked_for_data)
|
||||||
|
side_a.transport.registerProducer(data, False)
|
||||||
|
data.resumeProducing()
|
||||||
|
|
||||||
|
def check_if_done():
|
||||||
|
if last_produce is not None:
|
||||||
|
if reactor.seconds() - last_produce > timeout:
|
||||||
|
done.callback(None)
|
||||||
|
return
|
||||||
|
# recursive call to ourselves to check again in a second
|
||||||
|
reactor.callLater(1, check_if_done)
|
||||||
|
check_if_done()
|
||||||
|
|
||||||
|
yield done
|
||||||
|
|
||||||
|
mib = 1024*1024.0
|
||||||
|
print("Sent {}MiB of {}MiB before backpressure".format(data._sent / mib, max_data / mib))
|
||||||
|
self.assertTrue(data._sent < max_data, "Too much data sent")
|
||||||
|
|
||||||
|
side_a.sendClose()
|
||||||
|
side_b.sendClose()
|
||||||
|
yield side_a.done
|
||||||
|
yield side_b.done
|
Loading…
Reference in New Issue
Block a user