commit
						db48e91531
					
				
							
								
								
									
										208
									
								
								src/wormhole_transit_relay/test/test_backpressure.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										208
									
								
								src/wormhole_transit_relay/test/test_backpressure.py
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,208 @@
 | 
			
		|||
from io import (
 | 
			
		||||
    StringIO,
 | 
			
		||||
)
 | 
			
		||||
import sys
 | 
			
		||||
import shutil
 | 
			
		||||
 | 
			
		||||
from twisted.trial import unittest
 | 
			
		||||
from twisted.internet.interfaces import (
 | 
			
		||||
    IPullProducer,
 | 
			
		||||
)
 | 
			
		||||
from twisted.internet.protocol import (
 | 
			
		||||
    ProcessProtocol,
 | 
			
		||||
)
 | 
			
		||||
from twisted.internet.defer import (
 | 
			
		||||
    inlineCallbacks,
 | 
			
		||||
    Deferred,
 | 
			
		||||
)
 | 
			
		||||
from autobahn.twisted.websocket import (
 | 
			
		||||
    WebSocketClientProtocol,
 | 
			
		||||
    create_client_agent,
 | 
			
		||||
)
 | 
			
		||||
from zope.interface import implementer
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class _CollectOutputProtocol(ProcessProtocol):
 | 
			
		||||
    """
 | 
			
		||||
    Internal helper. Collects all output (stdout + stderr) into
 | 
			
		||||
    self.output, and callback's on done with all of it after the
 | 
			
		||||
    process exits (for any reason).
 | 
			
		||||
    """
 | 
			
		||||
    def __init__(self):
 | 
			
		||||
        self.done = Deferred()
 | 
			
		||||
        self.running = Deferred()
 | 
			
		||||
        self.output = StringIO()
 | 
			
		||||
 | 
			
		||||
    def processEnded(self, reason):
 | 
			
		||||
        if not self.done.called:
 | 
			
		||||
            self.done.callback(self.output.getvalue())
 | 
			
		||||
 | 
			
		||||
    def outReceived(self, data):
 | 
			
		||||
        print(data.decode(), end="", flush=True)
 | 
			
		||||
        self.output.write(data.decode(sys.getfilesystemencoding()))
 | 
			
		||||
        if not self.running.called:
 | 
			
		||||
            if "on 8088" in self.output.getvalue():
 | 
			
		||||
                self.running.callback(None)
 | 
			
		||||
 | 
			
		||||
    def errReceived(self, data):
 | 
			
		||||
        print("ERR: {}".format(data.decode(sys.getfilesystemencoding())))
 | 
			
		||||
        self.output.write(data.decode(sys.getfilesystemencoding()))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def run_transit(reactor, proto, tcp_port=None, websocket_port=None):
 | 
			
		||||
    exe = shutil.which("twistd")
 | 
			
		||||
    args = [
 | 
			
		||||
        exe, "-n", "transitrelay",
 | 
			
		||||
    ]
 | 
			
		||||
    if tcp_port is not None:
 | 
			
		||||
        args.append("--port")
 | 
			
		||||
        args.append(tcp_port)
 | 
			
		||||
    if websocket_port is not None:
 | 
			
		||||
        args.append("--websocket")
 | 
			
		||||
        args.append(websocket_port)
 | 
			
		||||
    proc = reactor.spawnProcess(proto, exe, args)
 | 
			
		||||
    return proc
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Sender(WebSocketClientProtocol):
 | 
			
		||||
    """
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def __init__(self, *args, **kw):
 | 
			
		||||
        WebSocketClientProtocol.__init__(self, *args, **kw)
 | 
			
		||||
        self.done = Deferred()
 | 
			
		||||
        self.got_ok = Deferred()
 | 
			
		||||
 | 
			
		||||
    def onMessage(self, payload, is_binary):
 | 
			
		||||
        print("onMessage")
 | 
			
		||||
        if not self.got_ok.called:
 | 
			
		||||
            if payload == b"ok\n":
 | 
			
		||||
                self.got_ok.callback(None)
 | 
			
		||||
        print("send: {}".format(payload.decode("utf8")))
 | 
			
		||||
 | 
			
		||||
    def onClose(self, clean, code, reason):
 | 
			
		||||
        print(f"close: {clean} {code} {reason}")
 | 
			
		||||
        self.done.callback(None)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Receiver(WebSocketClientProtocol):
 | 
			
		||||
    """
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def __init__(self, *args, **kw):
 | 
			
		||||
        WebSocketClientProtocol.__init__(self, *args, **kw)
 | 
			
		||||
        self.done = Deferred()
 | 
			
		||||
        self.first_message = Deferred()
 | 
			
		||||
        self.received = 0
 | 
			
		||||
 | 
			
		||||
    def onMessage(self, payload, is_binary):
 | 
			
		||||
        print("recv: {}".format(len(payload)))
 | 
			
		||||
        self.received += len(payload)
 | 
			
		||||
        if not self.first_message.called:
 | 
			
		||||
            self.first_message.callback(None)
 | 
			
		||||
 | 
			
		||||
    def onClose(self, clean, code, reason):
 | 
			
		||||
        print(f"close: {clean} {code} {reason}")
 | 
			
		||||
        self.done.callback(None)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TransitWebSockets(unittest.TestCase):
 | 
			
		||||
    """
 | 
			
		||||
    Integration-style tests of the transit WebSocket relay, using the
 | 
			
		||||
    real reactor (and running transit as a subprocess).
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    @inlineCallbacks
 | 
			
		||||
    def test_buffer_fills(self):
 | 
			
		||||
        """
 | 
			
		||||
        A running transit relay stops accepting incoming data at a
 | 
			
		||||
        reasonable amount if the peer isn't reading. This test defines
 | 
			
		||||
        that as 'less than 100MiB' although in practice Twisted seems
 | 
			
		||||
        to stop before 10MiB.
 | 
			
		||||
        """
 | 
			
		||||
        from twisted.internet import reactor
 | 
			
		||||
        transit_proto = _CollectOutputProtocol()
 | 
			
		||||
        transit_proc = run_transit(reactor, transit_proto, websocket_port="tcp:8088")
 | 
			
		||||
 | 
			
		||||
        def cleanup_process():
 | 
			
		||||
            transit_proc.signalProcess("HUP")
 | 
			
		||||
            return transit_proto.done
 | 
			
		||||
        self.addCleanup(cleanup_process)
 | 
			
		||||
 | 
			
		||||
        yield transit_proto.running
 | 
			
		||||
        print("Transit running")
 | 
			
		||||
 | 
			
		||||
        agent = create_client_agent(reactor)
 | 
			
		||||
        side_a = yield agent.open("ws://localhost:8088", {}, lambda: Sender())
 | 
			
		||||
        side_b = yield agent.open("ws://localhost:8088", {}, lambda: Receiver())
 | 
			
		||||
 | 
			
		||||
        side_a.sendMessage(b"please relay aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa for side aaaaaaaaaaaaaaaa", True)
 | 
			
		||||
        side_b.sendMessage(b"please relay aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa for side bbbbbbbbbbbbbbbb", True)
 | 
			
		||||
 | 
			
		||||
        yield side_a.got_ok
 | 
			
		||||
        yield side_b.first_message
 | 
			
		||||
 | 
			
		||||
        # remove side_b's filedescriptor from the reactor .. this
 | 
			
		||||
        # means it will not read any more data
 | 
			
		||||
        reactor.removeReader(side_b.transport)
 | 
			
		||||
 | 
			
		||||
        # attempt to send up to 100MiB through side_a .. we should get
 | 
			
		||||
        # backpressure before that works which only manifests itself
 | 
			
		||||
        # as this producer not being asked to produce more
 | 
			
		||||
        max_data = 1024*1024*100 # 100MiB
 | 
			
		||||
 | 
			
		||||
        @implementer(IPullProducer)
 | 
			
		||||
        class ProduceMessages:
 | 
			
		||||
            def __init__(self, ws, on_produce):
 | 
			
		||||
                self._ws = ws
 | 
			
		||||
                self._sent = 0
 | 
			
		||||
                self._max = max_data
 | 
			
		||||
                self._on_produce = on_produce
 | 
			
		||||
 | 
			
		||||
            def resumeProducing(self):
 | 
			
		||||
                self._on_produce()
 | 
			
		||||
                if self._sent >= self._max:
 | 
			
		||||
                    self._ws.sendClose()
 | 
			
		||||
                    return
 | 
			
		||||
                data = b"a" * 1024*1024
 | 
			
		||||
                self._ws.sendMessage(data, True)
 | 
			
		||||
                self._sent += len(data)
 | 
			
		||||
                print("sent {}, total {}".format(len(data), self._sent))
 | 
			
		||||
 | 
			
		||||
        # our only signal is, "did our producer get asked to produce
 | 
			
		||||
        # more data" which it should do periodically. We want to stop
 | 
			
		||||
        # if we haven't seen a new data request for a while -- defined
 | 
			
		||||
        # as "more than 5 seconds".
 | 
			
		||||
 | 
			
		||||
        done = Deferred()
 | 
			
		||||
        last_produce = None
 | 
			
		||||
        timeout = 2  # seconds
 | 
			
		||||
 | 
			
		||||
        def asked_for_data():
 | 
			
		||||
            nonlocal last_produce
 | 
			
		||||
            last_produce = reactor.seconds()
 | 
			
		||||
 | 
			
		||||
        data = ProduceMessages(side_a, asked_for_data)
 | 
			
		||||
        side_a.transport.registerProducer(data, False)
 | 
			
		||||
        data.resumeProducing()
 | 
			
		||||
 | 
			
		||||
        def check_if_done():
 | 
			
		||||
            if last_produce is not None:
 | 
			
		||||
                if reactor.seconds() - last_produce > timeout:
 | 
			
		||||
                    done.callback(None)
 | 
			
		||||
                    return
 | 
			
		||||
            # recursive call to ourselves to check again soon
 | 
			
		||||
            reactor.callLater(.1, check_if_done)
 | 
			
		||||
        check_if_done()
 | 
			
		||||
 | 
			
		||||
        yield done
 | 
			
		||||
 | 
			
		||||
        mib = 1024*1024.0
 | 
			
		||||
        print("Sent {}MiB of {}MiB before backpressure".format(data._sent / mib, max_data / mib))
 | 
			
		||||
        self.assertTrue(data._sent < max_data, "Too much data sent")
 | 
			
		||||
 | 
			
		||||
        side_a.sendClose()
 | 
			
		||||
        side_b.sendClose()
 | 
			
		||||
        yield side_a.done
 | 
			
		||||
        yield side_b.done
 | 
			
		||||
| 
						 | 
				
			
			@ -47,6 +47,7 @@ class TransitConnection(LineReceiver):
 | 
			
		|||
        ITransitClient API
 | 
			
		||||
        """
 | 
			
		||||
        self._buddy = other
 | 
			
		||||
        self._buddy._client.transport.registerProducer(self.transport, True)
 | 
			
		||||
 | 
			
		||||
    def disconnect_partner(self):
 | 
			
		||||
        """
 | 
			
		||||
| 
						 | 
				
			
			@ -198,6 +199,7 @@ class WebSocketTransitConnection(WebSocketServerProtocol):
 | 
			
		|||
        ITransitClient API
 | 
			
		||||
        """
 | 
			
		||||
        self._buddy = other
 | 
			
		||||
        self._buddy._client.transport.registerProducer(self.transport, True)
 | 
			
		||||
 | 
			
		||||
    def disconnect_partner(self):
 | 
			
		||||
        """
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue
	
	Block a user