Merge pull request #23 from meejah/websocket-support-on-iosim-tests-master
WebSocket support
This commit is contained in:
		
						commit
						80e02d4a77
					
				
							
								
								
									
										54
									
								
								client.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										54
									
								
								client.py
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,54 @@
 | 
			
		|||
"""
 | 
			
		||||
This is a test-client for the transit-relay that uses TCP. It
 | 
			
		||||
doesn't send any data, only prints out data that is received. Uses a
 | 
			
		||||
fixed token of 64 'a' characters. Always connects on localhost:4001
 | 
			
		||||
"""
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
from twisted.internet import endpoints
 | 
			
		||||
from twisted.internet.defer import (
 | 
			
		||||
    Deferred,
 | 
			
		||||
)
 | 
			
		||||
from twisted.internet.task import react
 | 
			
		||||
from twisted.internet.error import (
 | 
			
		||||
    ConnectionDone,
 | 
			
		||||
)
 | 
			
		||||
from twisted.internet.protocol import (
 | 
			
		||||
    Protocol,
 | 
			
		||||
    Factory,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class RelayEchoClient(Protocol):
 | 
			
		||||
    """
 | 
			
		||||
    Speaks the version1 magic wormhole transit relay protocol (as a client)
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def connectionMade(self):
 | 
			
		||||
        print(">CONNECT")
 | 
			
		||||
        self.data = b""
 | 
			
		||||
        self.transport.write(u"please relay {}\n".format(self.factory.token).encode("ascii"))
 | 
			
		||||
 | 
			
		||||
    def dataReceived(self, data):
 | 
			
		||||
        print(">RECV {} bytes".format(len(data)))
 | 
			
		||||
        print(data.decode("ascii"))
 | 
			
		||||
        self.data += data
 | 
			
		||||
        if data == "ok\n":
 | 
			
		||||
            self.transport.write("ding\n")
 | 
			
		||||
 | 
			
		||||
    def connectionLost(self, reason):
 | 
			
		||||
        if isinstance(reason.value, ConnectionDone):
 | 
			
		||||
            self.factory.done.callback(None)
 | 
			
		||||
        else:
 | 
			
		||||
            print(">DISCONNCT: {}".format(reason))
 | 
			
		||||
            self.factory.done.callback(reason)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@react
 | 
			
		||||
def main(reactor):
 | 
			
		||||
    ep = endpoints.clientFromString(reactor, "tcp:localhost:4001")
 | 
			
		||||
    f = Factory.forProtocol(RelayEchoClient)
 | 
			
		||||
    f.token = "a" * 64
 | 
			
		||||
    f.done = Deferred()
 | 
			
		||||
    ep.connect(f)
 | 
			
		||||
    return f.done
 | 
			
		||||
| 
						 | 
				
			
			@ -50,6 +50,15 @@ The relevant arguments are:
 | 
			
		|||
* ``--usage-db=``: maintains a SQLite database with current and historical usage data
 | 
			
		||||
* ``--blur-usage=``: round logged timestamps and data sizes
 | 
			
		||||
 | 
			
		||||
For WebSockets support, two additional arguments:
 | 
			
		||||
 | 
			
		||||
* ``--websocket``: the endpoint to listen for websocket connections
 | 
			
		||||
  on, like ``tcp:4002``
 | 
			
		||||
* ``--websocket-url``: the URL of the WebSocket connection. This may
 | 
			
		||||
  be different from the listening endpoint because of port-forwarding
 | 
			
		||||
  and so forth. By default it will be ``ws://localhost:<port>`` if not
 | 
			
		||||
  provided
 | 
			
		||||
 | 
			
		||||
When you use ``twist``, the relay runs in the foreground, so it will
 | 
			
		||||
generally exit as soon as the controlling terminal exits. For persistent
 | 
			
		||||
environments, you should daemonize the server.
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
							
								
								
									
										3
									
								
								setup.py
									
									
									
									
									
								
							
							
						
						
									
										3
									
								
								setup.py
									
									
									
									
									
								
							| 
						 | 
				
			
			@ -18,7 +18,8 @@ setup(name="magic-wormhole-transit-relay",
 | 
			
		|||
                ],
 | 
			
		||||
      package_data={"wormhole_transit_relay": ["db-schemas/*.sql"]},
 | 
			
		||||
      install_requires=[
 | 
			
		||||
          "twisted >= 17.5.0",
 | 
			
		||||
          "twisted >= 21.2.0",
 | 
			
		||||
          "autobahn >= 21.3.1",
 | 
			
		||||
      ],
 | 
			
		||||
      extras_require={
 | 
			
		||||
          ':sys_platform=="win32"': ["pypiwin32"],
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
							
								
								
									
										477
									
								
								src/wormhole_transit_relay/server_state.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										477
									
								
								src/wormhole_transit_relay/server_state.py
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,477 @@
 | 
			
		|||
from collections import defaultdict
 | 
			
		||||
 | 
			
		||||
import automat
 | 
			
		||||
from twisted.python import log
 | 
			
		||||
from zope.interface import (
 | 
			
		||||
    Interface,
 | 
			
		||||
    Attribute,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ITransitClient(Interface):
 | 
			
		||||
    """
 | 
			
		||||
    Represents the client side of a connection to this transit
 | 
			
		||||
    relay. This is used by TransitServerState instances.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    started_time = Attribute("timestamp when the connection was established")
 | 
			
		||||
 | 
			
		||||
    def send(data):
 | 
			
		||||
        """
 | 
			
		||||
        Send some byets to the client
 | 
			
		||||
        """
 | 
			
		||||
 | 
			
		||||
    def disconnect():
 | 
			
		||||
        """
 | 
			
		||||
        Disconnect the client transport
 | 
			
		||||
        """
 | 
			
		||||
 | 
			
		||||
    def connect_partner(other):
 | 
			
		||||
        """
 | 
			
		||||
        Hook up to our partner.
 | 
			
		||||
        :param ITransitClient other: our partner
 | 
			
		||||
        """
 | 
			
		||||
 | 
			
		||||
    def disconnect_partner():
 | 
			
		||||
        """
 | 
			
		||||
        Disconnect our partner's transport
 | 
			
		||||
        """
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ActiveConnections(object):
 | 
			
		||||
    """
 | 
			
		||||
    Tracks active connections.
 | 
			
		||||
 | 
			
		||||
    A connection is 'active' when both sides have shown up and they
 | 
			
		||||
    are glued together (and thus could be passing data back and forth
 | 
			
		||||
    if any is flowing).
 | 
			
		||||
    """
 | 
			
		||||
    def __init__(self):
 | 
			
		||||
        self._connections = set()
 | 
			
		||||
 | 
			
		||||
    def register(self, side0, side1):
 | 
			
		||||
        """
 | 
			
		||||
        A connection has become active so register both its sides
 | 
			
		||||
 | 
			
		||||
        :param TransitConnection side0: one side of the connection
 | 
			
		||||
        :param TransitConnection side1: one side of the connection
 | 
			
		||||
        """
 | 
			
		||||
        self._connections.add(side0)
 | 
			
		||||
        self._connections.add(side1)
 | 
			
		||||
 | 
			
		||||
    def unregister(self, side):
 | 
			
		||||
        """
 | 
			
		||||
        One side of a connection has become inactive.
 | 
			
		||||
 | 
			
		||||
        :param TransitConnection side: an inactive side of a connection
 | 
			
		||||
        """
 | 
			
		||||
        self._connections.discard(side)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class PendingRequests(object):
 | 
			
		||||
    """
 | 
			
		||||
    Tracks outstanding (non-"active") requests.
 | 
			
		||||
 | 
			
		||||
    We register client connections against the tokens we have
 | 
			
		||||
    received. When the other side shows up we can thus match it to the
 | 
			
		||||
    correct partner connection. At this point, the connection becomes
 | 
			
		||||
    "active" is and is thus no longer "pending" and so will no longer
 | 
			
		||||
    be in this collection.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def __init__(self, active_connections):
 | 
			
		||||
        """
 | 
			
		||||
        :param active_connections: an instance of ActiveConnections where
 | 
			
		||||
            connections are put when both sides arrive.
 | 
			
		||||
        """
 | 
			
		||||
        self._requests = defaultdict(set) # token -> set((side, TransitConnection))
 | 
			
		||||
        self._active = active_connections
 | 
			
		||||
 | 
			
		||||
    def unregister(self, token, side, tc):
 | 
			
		||||
        """
 | 
			
		||||
        We no longer care about a particular client (e.g. it has
 | 
			
		||||
        disconnected).
 | 
			
		||||
        """
 | 
			
		||||
        if token in self._requests:
 | 
			
		||||
            self._requests[token].discard((side, tc))
 | 
			
		||||
            if not self._requests[token]:
 | 
			
		||||
                # no more sides; token is dead
 | 
			
		||||
                del self._requests[token]
 | 
			
		||||
        self._active.unregister(tc)
 | 
			
		||||
 | 
			
		||||
    def register(self, token, new_side, new_tc):
 | 
			
		||||
        """
 | 
			
		||||
        A client has connected and successfully offered a token (and
 | 
			
		||||
        optional 'side' token). If this is the first one for this
 | 
			
		||||
        token, we merely remember it. If it is the second side for
 | 
			
		||||
        this token we connect them together.
 | 
			
		||||
 | 
			
		||||
        :param bytes token: the token for this connection.
 | 
			
		||||
 | 
			
		||||
        :param bytes new_side: None or the side token for this connection
 | 
			
		||||
 | 
			
		||||
        :param TransitServerState new_tc: the state-machine of the connection
 | 
			
		||||
 | 
			
		||||
        :returns bool: True if we are the first side to register this
 | 
			
		||||
            token
 | 
			
		||||
        """
 | 
			
		||||
        potentials = self._requests[token]
 | 
			
		||||
        for old in potentials:
 | 
			
		||||
            (old_side, old_tc) = old
 | 
			
		||||
            if ((old_side is None)
 | 
			
		||||
                or (new_side is None)
 | 
			
		||||
                or (old_side != new_side)):
 | 
			
		||||
                # we found a match
 | 
			
		||||
 | 
			
		||||
                # drop and stop tracking the rest
 | 
			
		||||
                potentials.remove(old)
 | 
			
		||||
                for (_, leftover_tc) in potentials.copy():
 | 
			
		||||
                    # Don't record this as errory. It's just a spare connection
 | 
			
		||||
                    # from the same side as a connection that got used. This
 | 
			
		||||
                    # can happen if the connection hint contains multiple
 | 
			
		||||
                    # addresses (we don't currently support those, but it'd
 | 
			
		||||
                    # probably be useful in the future).
 | 
			
		||||
                    leftover_tc.partner_connection_lost()
 | 
			
		||||
                self._requests.pop(token, None)
 | 
			
		||||
 | 
			
		||||
                # glue the two ends together
 | 
			
		||||
                self._active.register(new_tc, old_tc)
 | 
			
		||||
                new_tc.got_partner(old_tc)
 | 
			
		||||
                old_tc.got_partner(new_tc)
 | 
			
		||||
                return False
 | 
			
		||||
 | 
			
		||||
        potentials.add((new_side, new_tc))
 | 
			
		||||
        return True
 | 
			
		||||
        # TODO: timer
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TransitServerState(object):
 | 
			
		||||
    """
 | 
			
		||||
    Encapsulates the state-machine of the server side of a transit
 | 
			
		||||
    relay connection.
 | 
			
		||||
 | 
			
		||||
    Once the protocol has been told to relay (or to relay for a side)
 | 
			
		||||
    it starts passing all received bytes to the other side until it
 | 
			
		||||
    closes.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    _machine = automat.MethodicalMachine()
 | 
			
		||||
    _client = None
 | 
			
		||||
    _buddy = None
 | 
			
		||||
    _token = None
 | 
			
		||||
    _side = None
 | 
			
		||||
    _first = None
 | 
			
		||||
    _mood = "empty"
 | 
			
		||||
    _total_sent = 0
 | 
			
		||||
 | 
			
		||||
    def __init__(self, pending_requests, usage_recorder):
 | 
			
		||||
        self._pending_requests = pending_requests
 | 
			
		||||
        self._usage = usage_recorder
 | 
			
		||||
 | 
			
		||||
    def get_token(self):
 | 
			
		||||
        """
 | 
			
		||||
        :returns str: a string describing our token. This will be "-" if
 | 
			
		||||
            we have no token yet, or "{16 chars}-<unsided>" if we have
 | 
			
		||||
            just a token or "{16 chars}-{16 chars}" if we have a token and
 | 
			
		||||
            a side.
 | 
			
		||||
        """
 | 
			
		||||
        d = "-"
 | 
			
		||||
        if self._token is not None:
 | 
			
		||||
            d = self._token[:16].decode("ascii")
 | 
			
		||||
 | 
			
		||||
            if self._side is not None:
 | 
			
		||||
                d += "-" + self._side.decode("ascii")
 | 
			
		||||
            else:
 | 
			
		||||
                d += "-<unsided>"
 | 
			
		||||
        return d
 | 
			
		||||
 | 
			
		||||
    @_machine.input()
 | 
			
		||||
    def connection_made(self, client):
 | 
			
		||||
        """
 | 
			
		||||
        A client has connected. May only be called once.
 | 
			
		||||
 | 
			
		||||
        :param ITransitClient client: our client.
 | 
			
		||||
        """
 | 
			
		||||
        # NB: the "only called once" is enforced by the state-machine;
 | 
			
		||||
        # this input is only valid for the "listening" state, to which
 | 
			
		||||
        # we never return.
 | 
			
		||||
 | 
			
		||||
    @_machine.input()
 | 
			
		||||
    def please_relay(self, token):
 | 
			
		||||
        """
 | 
			
		||||
        A 'please relay X' message has been received (the original version
 | 
			
		||||
        of the protocol).
 | 
			
		||||
        """
 | 
			
		||||
 | 
			
		||||
    @_machine.input()
 | 
			
		||||
    def please_relay_for_side(self, token, side):
 | 
			
		||||
        """
 | 
			
		||||
        A 'please relay X for side Y' message has been received (the
 | 
			
		||||
        second version of the protocol).
 | 
			
		||||
        """
 | 
			
		||||
 | 
			
		||||
    @_machine.input()
 | 
			
		||||
    def bad_token(self):
 | 
			
		||||
        """
 | 
			
		||||
        A bad token / relay line was received (e.g. couldn't be parsed)
 | 
			
		||||
        """
 | 
			
		||||
 | 
			
		||||
    @_machine.input()
 | 
			
		||||
    def got_partner(self, client):
 | 
			
		||||
        """
 | 
			
		||||
        The partner for this relay session has been found
 | 
			
		||||
        """
 | 
			
		||||
 | 
			
		||||
    @_machine.input()
 | 
			
		||||
    def connection_lost(self):
 | 
			
		||||
        """
 | 
			
		||||
        Our transport has failed.
 | 
			
		||||
        """
 | 
			
		||||
 | 
			
		||||
    @_machine.input()
 | 
			
		||||
    def partner_connection_lost(self):
 | 
			
		||||
        """
 | 
			
		||||
        Our partner's transport has failed.
 | 
			
		||||
        """
 | 
			
		||||
 | 
			
		||||
    @_machine.input()
 | 
			
		||||
    def got_bytes(self, data):
 | 
			
		||||
        """
 | 
			
		||||
        Some bytes have arrived (that aren't part of the handshake)
 | 
			
		||||
        """
 | 
			
		||||
 | 
			
		||||
    @_machine.output()
 | 
			
		||||
    def _remember_client(self, client):
 | 
			
		||||
        self._client = client
 | 
			
		||||
 | 
			
		||||
    # note that there is no corresponding "_forget_client" because we
 | 
			
		||||
    # may still want to access it after it is gone .. for example, to
 | 
			
		||||
    # get the .started_time for logging purposes
 | 
			
		||||
 | 
			
		||||
    @_machine.output()
 | 
			
		||||
    def _register_token(self, token):
 | 
			
		||||
        return self._real_register_token_for_side(token, None)
 | 
			
		||||
 | 
			
		||||
    @_machine.output()
 | 
			
		||||
    def _register_token_for_side(self, token, side):
 | 
			
		||||
        return self._real_register_token_for_side(token, side)
 | 
			
		||||
 | 
			
		||||
    @_machine.output()
 | 
			
		||||
    def _unregister(self):
 | 
			
		||||
        """
 | 
			
		||||
        remove us from the thing that remembers tokens and sides
 | 
			
		||||
        """
 | 
			
		||||
        return self._pending_requests.unregister(self._token, self._side, self)
 | 
			
		||||
 | 
			
		||||
    @_machine.output()
 | 
			
		||||
    def _send_bad(self):
 | 
			
		||||
        self._mood = "errory"
 | 
			
		||||
        self._client.send(b"bad handshake\n")
 | 
			
		||||
        if self._client.factory.log_requests:
 | 
			
		||||
            log.msg("transit handshake failure")
 | 
			
		||||
 | 
			
		||||
    @_machine.output()
 | 
			
		||||
    def _send_ok(self):
 | 
			
		||||
        self._client.send(b"ok\n")
 | 
			
		||||
 | 
			
		||||
    @_machine.output()
 | 
			
		||||
    def _send_impatient(self):
 | 
			
		||||
        self._client.send(b"impatient\n")
 | 
			
		||||
        if self._client.factory.log_requests:
 | 
			
		||||
            log.msg("transit impatience failure")
 | 
			
		||||
 | 
			
		||||
    @_machine.output()
 | 
			
		||||
    def _count_bytes(self, data):
 | 
			
		||||
        self._total_sent += len(data)
 | 
			
		||||
 | 
			
		||||
    @_machine.output()
 | 
			
		||||
    def _send_to_partner(self, data):
 | 
			
		||||
        self._buddy._client.send(data)
 | 
			
		||||
 | 
			
		||||
    @_machine.output()
 | 
			
		||||
    def _connect_partner(self, client):
 | 
			
		||||
        self._buddy = client
 | 
			
		||||
        self._client.connect_partner(client)
 | 
			
		||||
 | 
			
		||||
    @_machine.output()
 | 
			
		||||
    def _disconnect(self):
 | 
			
		||||
        self._client.disconnect()
 | 
			
		||||
 | 
			
		||||
    @_machine.output()
 | 
			
		||||
    def _disconnect_partner(self):
 | 
			
		||||
        self._client.disconnect_partner()
 | 
			
		||||
 | 
			
		||||
    # some outputs to record "usage" information ..
 | 
			
		||||
    @_machine.output()
 | 
			
		||||
    def _record_usage(self):
 | 
			
		||||
        if self._mood == "jilted":
 | 
			
		||||
            if self._buddy and self._buddy._mood == "happy":
 | 
			
		||||
                return
 | 
			
		||||
        self._usage.record(
 | 
			
		||||
            started=self._client.started_time,
 | 
			
		||||
            buddy_started=self._buddy._client.started_time if self._buddy is not None else None,
 | 
			
		||||
            result=self._mood,
 | 
			
		||||
            bytes_sent=self._total_sent,
 | 
			
		||||
            buddy_bytes=self._buddy._total_sent if self._buddy is not None else None
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    # some outputs to record the "mood" ..
 | 
			
		||||
    @_machine.output()
 | 
			
		||||
    def _mood_happy(self):
 | 
			
		||||
        self._mood = "happy"
 | 
			
		||||
 | 
			
		||||
    @_machine.output()
 | 
			
		||||
    def _mood_lonely(self):
 | 
			
		||||
        self._mood = "lonely"
 | 
			
		||||
 | 
			
		||||
    @_machine.output()
 | 
			
		||||
    def _mood_redundant(self):
 | 
			
		||||
        self._mood = "redundant"
 | 
			
		||||
 | 
			
		||||
    @_machine.output()
 | 
			
		||||
    def _mood_impatient(self):
 | 
			
		||||
        self._mood = "impatient"
 | 
			
		||||
 | 
			
		||||
    @_machine.output()
 | 
			
		||||
    def _mood_errory(self):
 | 
			
		||||
        self._mood = "errory"
 | 
			
		||||
 | 
			
		||||
    @_machine.output()
 | 
			
		||||
    def _mood_happy_if_first(self):
 | 
			
		||||
        """
 | 
			
		||||
        We disconnected first so we're only happy if we also connected
 | 
			
		||||
        first.
 | 
			
		||||
        """
 | 
			
		||||
        if self._first:
 | 
			
		||||
            self._mood = "happy"
 | 
			
		||||
        else:
 | 
			
		||||
            self._mood = "jilted"
 | 
			
		||||
 | 
			
		||||
    def _real_register_token_for_side(self, token, side):
 | 
			
		||||
        """
 | 
			
		||||
        A client has connected and sent a valid version 1 or version 2
 | 
			
		||||
        handshake. If the former, `side` will be None.
 | 
			
		||||
 | 
			
		||||
        In either case, we remember the tokens and register
 | 
			
		||||
        ourselves. This might result in 'got_partner' notifications to
 | 
			
		||||
        two state-machines if this is the second side for a given token.
 | 
			
		||||
 | 
			
		||||
        :param bytes token: the token
 | 
			
		||||
        :param bytes side: The side token (or None)
 | 
			
		||||
        """
 | 
			
		||||
        self._token = token
 | 
			
		||||
        self._side = side
 | 
			
		||||
        self._first = self._pending_requests.register(token, side, self)
 | 
			
		||||
 | 
			
		||||
    @_machine.state(initial=True)
 | 
			
		||||
    def listening(self):
 | 
			
		||||
        """
 | 
			
		||||
        Initial state, awaiting connection.
 | 
			
		||||
        """
 | 
			
		||||
 | 
			
		||||
    @_machine.state()
 | 
			
		||||
    def wait_relay(self):
 | 
			
		||||
        """
 | 
			
		||||
        Waiting for a 'relay' message
 | 
			
		||||
        """
 | 
			
		||||
 | 
			
		||||
    @_machine.state()
 | 
			
		||||
    def wait_partner(self):
 | 
			
		||||
        """
 | 
			
		||||
        Waiting for our partner to connect
 | 
			
		||||
        """
 | 
			
		||||
 | 
			
		||||
    @_machine.state()
 | 
			
		||||
    def relaying(self):
 | 
			
		||||
        """
 | 
			
		||||
        Relaying bytes to our partner
 | 
			
		||||
        """
 | 
			
		||||
 | 
			
		||||
    @_machine.state()
 | 
			
		||||
    def done(self):
 | 
			
		||||
        """
 | 
			
		||||
        Terminal state
 | 
			
		||||
        """
 | 
			
		||||
 | 
			
		||||
    listening.upon(
 | 
			
		||||
        connection_made,
 | 
			
		||||
        enter=wait_relay,
 | 
			
		||||
        outputs=[_remember_client],
 | 
			
		||||
    )
 | 
			
		||||
    listening.upon(
 | 
			
		||||
        connection_lost,
 | 
			
		||||
        enter=done,
 | 
			
		||||
        outputs=[_mood_errory],
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    wait_relay.upon(
 | 
			
		||||
        please_relay,
 | 
			
		||||
        enter=wait_partner,
 | 
			
		||||
        outputs=[_mood_lonely, _register_token],
 | 
			
		||||
    )
 | 
			
		||||
    wait_relay.upon(
 | 
			
		||||
        please_relay_for_side,
 | 
			
		||||
        enter=wait_partner,
 | 
			
		||||
        outputs=[_mood_lonely, _register_token_for_side],
 | 
			
		||||
    )
 | 
			
		||||
    wait_relay.upon(
 | 
			
		||||
        bad_token,
 | 
			
		||||
        enter=done,
 | 
			
		||||
        outputs=[_mood_errory, _send_bad, _disconnect, _record_usage],
 | 
			
		||||
    )
 | 
			
		||||
    wait_relay.upon(
 | 
			
		||||
        got_bytes,
 | 
			
		||||
        enter=done,
 | 
			
		||||
        outputs=[_count_bytes, _mood_errory, _disconnect, _record_usage],
 | 
			
		||||
    )
 | 
			
		||||
    wait_relay.upon(
 | 
			
		||||
        connection_lost,
 | 
			
		||||
        enter=done,
 | 
			
		||||
        outputs=[_disconnect, _record_usage],
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    wait_partner.upon(
 | 
			
		||||
        got_partner,
 | 
			
		||||
        enter=relaying,
 | 
			
		||||
        outputs=[_mood_happy, _send_ok, _connect_partner],
 | 
			
		||||
    )
 | 
			
		||||
    wait_partner.upon(
 | 
			
		||||
        connection_lost,
 | 
			
		||||
        enter=done,
 | 
			
		||||
        outputs=[_mood_lonely, _unregister, _record_usage],
 | 
			
		||||
    )
 | 
			
		||||
    wait_partner.upon(
 | 
			
		||||
        got_bytes,
 | 
			
		||||
        enter=done,
 | 
			
		||||
        outputs=[_mood_impatient, _send_impatient, _disconnect, _unregister, _record_usage],
 | 
			
		||||
    )
 | 
			
		||||
    wait_partner.upon(
 | 
			
		||||
        partner_connection_lost,
 | 
			
		||||
        enter=done,
 | 
			
		||||
        outputs=[_mood_redundant, _disconnect, _record_usage],
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    relaying.upon(
 | 
			
		||||
        got_bytes,
 | 
			
		||||
        enter=relaying,
 | 
			
		||||
        outputs=[_count_bytes, _send_to_partner],
 | 
			
		||||
    )
 | 
			
		||||
    relaying.upon(
 | 
			
		||||
        connection_lost,
 | 
			
		||||
        enter=done,
 | 
			
		||||
        outputs=[_mood_happy_if_first, _disconnect_partner, _unregister, _record_usage],
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    done.upon(
 | 
			
		||||
        connection_lost,
 | 
			
		||||
        enter=done,
 | 
			
		||||
        outputs=[],
 | 
			
		||||
    )
 | 
			
		||||
    done.upon(
 | 
			
		||||
        partner_connection_lost,
 | 
			
		||||
        enter=done,
 | 
			
		||||
        outputs=[],
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    # uncomment to turn on state-machine tracing
 | 
			
		||||
    # set_trace_function = _machine._setTrace
 | 
			
		||||
| 
						 | 
				
			
			@ -5,8 +5,14 @@ from twisted.application.service import MultiService
 | 
			
		|||
from twisted.application.internet import (TimerService,
 | 
			
		||||
                                          StreamServerEndpointService)
 | 
			
		||||
from twisted.internet import endpoints
 | 
			
		||||
from twisted.internet import protocol
 | 
			
		||||
 | 
			
		||||
from autobahn.twisted.websocket import WebSocketServerFactory
 | 
			
		||||
 | 
			
		||||
from . import transit_server
 | 
			
		||||
from .usage import create_usage_tracker
 | 
			
		||||
from .increase_rlimits import increase_rlimits
 | 
			
		||||
from .database import get_db
 | 
			
		||||
 | 
			
		||||
LONGDESC = """\
 | 
			
		||||
This plugin sets up a 'Transit Relay' server for magic-wormhole. This service
 | 
			
		||||
| 
						 | 
				
			
			@ -20,6 +26,8 @@ class Options(usage.Options):
 | 
			
		|||
 | 
			
		||||
    optParameters = [
 | 
			
		||||
        ("port", "p", "tcp:4001:interface=\:\:", "endpoint to listen on"),
 | 
			
		||||
        ("websocket", "w", None, "endpoint to listen for WebSocket connections"),
 | 
			
		||||
        ("websocket-url", "u", None, "WebSocket URL (derived from endpoint if not provided)"),
 | 
			
		||||
        ("blur-usage", None, None, "blur timestamps and data sizes in logs"),
 | 
			
		||||
        ("log-fd", None, None, "write JSON usage logs to this file descriptor"),
 | 
			
		||||
        ("usage-db", None, None, "record usage data (SQLite)"),
 | 
			
		||||
| 
						 | 
				
			
			@ -31,14 +39,45 @@ class Options(usage.Options):
 | 
			
		|||
 | 
			
		||||
def makeService(config, reactor=reactor):
 | 
			
		||||
    increase_rlimits()
 | 
			
		||||
    ep = endpoints.serverFromString(reactor, config["port"]) # to listen
 | 
			
		||||
    log_file = (os.fdopen(int(config["log-fd"]), "w")
 | 
			
		||||
                if config["log-fd"] is not None
 | 
			
		||||
                else None)
 | 
			
		||||
    f = transit_server.Transit(blur_usage=config["blur-usage"],
 | 
			
		||||
                               log_file=log_file,
 | 
			
		||||
                               usage_db=config["usage-db"])
 | 
			
		||||
    tcp_ep = endpoints.serverFromString(reactor, config["port"]) # to listen
 | 
			
		||||
    ws_ep = (
 | 
			
		||||
        endpoints.serverFromString(reactor, config["websocket"])
 | 
			
		||||
        if config["websocket"] is not None
 | 
			
		||||
        else None
 | 
			
		||||
    )
 | 
			
		||||
    log_file = (
 | 
			
		||||
        os.fdopen(int(config["log-fd"]), "w")
 | 
			
		||||
        if config["log-fd"] is not None
 | 
			
		||||
        else None
 | 
			
		||||
    )
 | 
			
		||||
    db = None if config["usage-db"] is None else get_db(config["usage-db"])
 | 
			
		||||
    usage = create_usage_tracker(
 | 
			
		||||
        blur_usage=config["blur-usage"],
 | 
			
		||||
        log_file=log_file,
 | 
			
		||||
        usage_db=db,
 | 
			
		||||
    )
 | 
			
		||||
    transit = transit_server.Transit(usage, reactor.seconds)
 | 
			
		||||
    tcp_factory = protocol.ServerFactory()
 | 
			
		||||
    tcp_factory.protocol = transit_server.TransitConnection
 | 
			
		||||
    tcp_factory.log_requests = False
 | 
			
		||||
 | 
			
		||||
    if ws_ep is not None:
 | 
			
		||||
        ws_url = config["websocket-url"]
 | 
			
		||||
        if ws_url is None:
 | 
			
		||||
            # we're using a "private" attribute here but I don't see
 | 
			
		||||
            # any useful alternative unless we also want to parse
 | 
			
		||||
            # Twisted endpoint-strings.
 | 
			
		||||
            ws_url = "ws://localhost:{}/".format(ws_ep._port)
 | 
			
		||||
            print("Using WebSocket URL '{}'".format(ws_url))
 | 
			
		||||
        ws_factory = WebSocketServerFactory(ws_url)
 | 
			
		||||
        ws_factory.protocol = transit_server.WebSocketTransitConnection
 | 
			
		||||
        ws_factory.transit = transit
 | 
			
		||||
        ws_factory.log_requests = False
 | 
			
		||||
 | 
			
		||||
    tcp_factory.transit = transit
 | 
			
		||||
    parent = MultiService()
 | 
			
		||||
    StreamServerEndpointService(ep, f).setServiceParent(parent)
 | 
			
		||||
    TimerService(5*60.0, f.timerUpdateStats).setServiceParent(parent)
 | 
			
		||||
    StreamServerEndpointService(tcp_ep, tcp_factory).setServiceParent(parent)
 | 
			
		||||
    if ws_ep is not None:
 | 
			
		||||
        StreamServerEndpointService(ws_ep, ws_factory).setServiceParent(parent)
 | 
			
		||||
    TimerService(5*60.0, transit.update_stats).setServiceParent(parent)
 | 
			
		||||
    return parent
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -10,7 +10,10 @@ from zope.interface import (
 | 
			
		|||
)
 | 
			
		||||
from ..transit_server import (
 | 
			
		||||
    Transit,
 | 
			
		||||
    TransitConnection,
 | 
			
		||||
)
 | 
			
		||||
from twisted.internet.protocol import ServerFactory
 | 
			
		||||
from ..usage import create_usage_tracker
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class IRelayTestClient(Interface):
 | 
			
		||||
| 
						 | 
				
			
			@ -42,6 +45,7 @@ class IRelayTestClient(Interface):
 | 
			
		|||
        Erase any received data to this point.
 | 
			
		||||
        """
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ServerBase:
 | 
			
		||||
    log_requests = False
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -62,19 +66,30 @@ class ServerBase:
 | 
			
		|||
            self.flush()
 | 
			
		||||
 | 
			
		||||
    def _setup_relay(self, blur_usage=None, log_file=None, usage_db=None):
 | 
			
		||||
        self._transit_server = Transit(
 | 
			
		||||
        usage = create_usage_tracker(
 | 
			
		||||
            blur_usage=blur_usage,
 | 
			
		||||
            log_file=log_file,
 | 
			
		||||
            usage_db=usage_db,
 | 
			
		||||
        )
 | 
			
		||||
        self._transit_server._debug_log = self.log_requests
 | 
			
		||||
        self._transit_server = Transit(usage, lambda: 123456789.0)
 | 
			
		||||
 | 
			
		||||
    def new_protocol(self):
 | 
			
		||||
        """
 | 
			
		||||
        This should be overridden by derived test-case classes to decide
 | 
			
		||||
        if they want a TCP or WebSockets protocol.
 | 
			
		||||
        """
 | 
			
		||||
        raise NotImplementedError()
 | 
			
		||||
 | 
			
		||||
    def new_protocol_tcp(self):
 | 
			
		||||
        """
 | 
			
		||||
        Create a new client protocol connected to the server.
 | 
			
		||||
        :returns: a IRelayTestClient implementation
 | 
			
		||||
        """
 | 
			
		||||
        server_protocol = self._transit_server.buildProtocol(('127.0.0.1', 0))
 | 
			
		||||
        server_factory = ServerFactory()
 | 
			
		||||
        server_factory.protocol = TransitConnection
 | 
			
		||||
        server_factory.transit = self._transit_server
 | 
			
		||||
        server_factory.log_requests = self.log_requests
 | 
			
		||||
        server_protocol = server_factory.buildProtocol(('127.0.0.1', 0))
 | 
			
		||||
 | 
			
		||||
        @implementer(IRelayTestClient)
 | 
			
		||||
        class TransitClientProtocolTcp(Protocol):
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -8,12 +8,29 @@ class Config(unittest.TestCase):
 | 
			
		|||
        o = server_tap.Options()
 | 
			
		||||
        o.parseOptions([])
 | 
			
		||||
        self.assertEqual(o, {"blur-usage": None, "log-fd": None,
 | 
			
		||||
                             "usage-db": None, "port": PORT})
 | 
			
		||||
                             "usage-db": None, "port": PORT,
 | 
			
		||||
                             "websocket": None, "websocket-url": None})
 | 
			
		||||
    def test_blur(self):
 | 
			
		||||
        o = server_tap.Options()
 | 
			
		||||
        o.parseOptions(["--blur-usage=60"])
 | 
			
		||||
        self.assertEqual(o, {"blur-usage": 60, "log-fd": None,
 | 
			
		||||
                             "usage-db": None, "port": PORT})
 | 
			
		||||
                             "usage-db": None, "port": PORT,
 | 
			
		||||
                             "websocket": None, "websocket-url": None})
 | 
			
		||||
 | 
			
		||||
    def test_websocket(self):
 | 
			
		||||
        o = server_tap.Options()
 | 
			
		||||
        o.parseOptions(["--websocket=tcp:4004"])
 | 
			
		||||
        self.assertEqual(o, {"blur-usage": None, "log-fd": None,
 | 
			
		||||
                             "usage-db": None, "port": PORT,
 | 
			
		||||
                             "websocket": "tcp:4004", "websocket-url": None})
 | 
			
		||||
 | 
			
		||||
    def test_websocket_url(self):
 | 
			
		||||
        o = server_tap.Options()
 | 
			
		||||
        o.parseOptions(["--websocket=tcp:4004", "--websocket-url=ws://example.com/"])
 | 
			
		||||
        self.assertEqual(o, {"blur-usage": None, "log-fd": None,
 | 
			
		||||
                             "usage-db": None, "port": PORT,
 | 
			
		||||
                             "websocket": "tcp:4004",
 | 
			
		||||
                             "websocket-url": "ws://example.com/"})
 | 
			
		||||
 | 
			
		||||
    def test_string(self):
 | 
			
		||||
        o = server_tap.Options()
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -1,13 +1,14 @@
 | 
			
		|||
from twisted.trial import unittest
 | 
			
		||||
from unittest import mock
 | 
			
		||||
from twisted.application.service import MultiService
 | 
			
		||||
from autobahn.twisted.websocket import WebSocketServerFactory
 | 
			
		||||
from .. import server_tap
 | 
			
		||||
 | 
			
		||||
class Service(unittest.TestCase):
 | 
			
		||||
    def test_defaults(self):
 | 
			
		||||
        o = server_tap.Options()
 | 
			
		||||
        o.parseOptions([])
 | 
			
		||||
        with mock.patch("wormhole_transit_relay.server_tap.transit_server.Transit") as t:
 | 
			
		||||
        with mock.patch("wormhole_transit_relay.server_tap.create_usage_tracker") as t:
 | 
			
		||||
            s = server_tap.makeService(o)
 | 
			
		||||
        self.assertEqual(t.mock_calls,
 | 
			
		||||
                         [mock.call(blur_usage=None,
 | 
			
		||||
| 
						 | 
				
			
			@ -17,7 +18,7 @@ class Service(unittest.TestCase):
 | 
			
		|||
    def test_blur(self):
 | 
			
		||||
        o = server_tap.Options()
 | 
			
		||||
        o.parseOptions(["--blur-usage=60"])
 | 
			
		||||
        with mock.patch("wormhole_transit_relay.server_tap.transit_server.Transit") as t:
 | 
			
		||||
        with mock.patch("wormhole_transit_relay.server_tap.create_usage_tracker") as t:
 | 
			
		||||
            server_tap.makeService(o)
 | 
			
		||||
        self.assertEqual(t.mock_calls,
 | 
			
		||||
                         [mock.call(blur_usage=60,
 | 
			
		||||
| 
						 | 
				
			
			@ -27,7 +28,7 @@ class Service(unittest.TestCase):
 | 
			
		|||
        o = server_tap.Options()
 | 
			
		||||
        o.parseOptions(["--log-fd=99"])
 | 
			
		||||
        fd = object()
 | 
			
		||||
        with mock.patch("wormhole_transit_relay.server_tap.transit_server.Transit") as t:
 | 
			
		||||
        with mock.patch("wormhole_transit_relay.server_tap.create_usage_tracker") as t:
 | 
			
		||||
            with mock.patch("wormhole_transit_relay.server_tap.os.fdopen",
 | 
			
		||||
                            return_value=fd) as f:
 | 
			
		||||
                server_tap.makeService(o)
 | 
			
		||||
| 
						 | 
				
			
			@ -36,3 +37,34 @@ class Service(unittest.TestCase):
 | 
			
		|||
                         [mock.call(blur_usage=None,
 | 
			
		||||
                                    log_file=fd, usage_db=None)])
 | 
			
		||||
 | 
			
		||||
    def test_websocket(self):
 | 
			
		||||
        """
 | 
			
		||||
        A websocket factory is created when passing --websocket
 | 
			
		||||
        """
 | 
			
		||||
        o = server_tap.Options()
 | 
			
		||||
        o.parseOptions(["--websocket=tcp:4004"])
 | 
			
		||||
        services = server_tap.makeService(o)
 | 
			
		||||
        self.assertTrue(
 | 
			
		||||
            any(
 | 
			
		||||
                isinstance(s.factory, WebSocketServerFactory)
 | 
			
		||||
                for s in services.services
 | 
			
		||||
            )
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def test_websocket_explicit_url(self):
 | 
			
		||||
        """
 | 
			
		||||
        A websocket factory is created with --websocket and
 | 
			
		||||
        --websocket-url
 | 
			
		||||
        """
 | 
			
		||||
        o = server_tap.Options()
 | 
			
		||||
        o.parseOptions([
 | 
			
		||||
            "--websocket=tcp:4004",
 | 
			
		||||
            "--websocket-url=ws://example.com:4004",
 | 
			
		||||
        ])
 | 
			
		||||
        services = server_tap.makeService(o)
 | 
			
		||||
        self.assertTrue(
 | 
			
		||||
            any(
 | 
			
		||||
                isinstance(s.factory, WebSocketServerFactory)
 | 
			
		||||
                for s in services.services
 | 
			
		||||
            )
 | 
			
		||||
        )
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -1,27 +1,38 @@
 | 
			
		|||
import os, io, json, sqlite3
 | 
			
		||||
import os, io, json
 | 
			
		||||
from unittest import mock
 | 
			
		||||
from twisted.trial import unittest
 | 
			
		||||
from ..transit_server import Transit
 | 
			
		||||
from ..usage import create_usage_tracker
 | 
			
		||||
from .. import database
 | 
			
		||||
 | 
			
		||||
class DB(unittest.TestCase):
 | 
			
		||||
    def open_db(self, dbfile):
 | 
			
		||||
        db = sqlite3.connect(dbfile)
 | 
			
		||||
        database._initialize_db_connection(db)
 | 
			
		||||
        return db
 | 
			
		||||
 | 
			
		||||
    def test_db(self):
 | 
			
		||||
 | 
			
		||||
        T = 1519075308.0
 | 
			
		||||
 | 
			
		||||
        class Timer:
 | 
			
		||||
            t = T
 | 
			
		||||
            def __call__(self):
 | 
			
		||||
                return self.t
 | 
			
		||||
        get_time = Timer()
 | 
			
		||||
 | 
			
		||||
        d = self.mktemp()
 | 
			
		||||
        os.mkdir(d)
 | 
			
		||||
        usage_db = os.path.join(d, "usage.sqlite")
 | 
			
		||||
        with mock.patch("time.time", return_value=T+0):
 | 
			
		||||
            t = Transit(blur_usage=None, log_file=None, usage_db=usage_db)
 | 
			
		||||
        db = self.open_db(usage_db)
 | 
			
		||||
        db = database.get_db(usage_db)
 | 
			
		||||
        t = Transit(
 | 
			
		||||
            create_usage_tracker(blur_usage=None, log_file=None, usage_db=db),
 | 
			
		||||
            get_time,
 | 
			
		||||
        )
 | 
			
		||||
        self.assertEqual(len(t.usage._backends), 1)
 | 
			
		||||
        usage = list(t.usage._backends)[0]
 | 
			
		||||
 | 
			
		||||
        get_time.t = T + 1
 | 
			
		||||
        usage.record_usage(started=123, mood="happy", total_bytes=100,
 | 
			
		||||
                           total_time=10, waiting_time=2)
 | 
			
		||||
        t.update_stats()
 | 
			
		||||
 | 
			
		||||
        with mock.patch("time.time", return_value=T+1):
 | 
			
		||||
            t.recordUsage(started=123, result="happy", total_bytes=100,
 | 
			
		||||
                          total_time=10, waiting_time=2)
 | 
			
		||||
        self.assertEqual(db.execute("SELECT * FROM `usage`").fetchall(),
 | 
			
		||||
                         [dict(result="happy", started=123,
 | 
			
		||||
                               total_bytes=100, total_time=10, waiting_time=2),
 | 
			
		||||
| 
						 | 
				
			
			@ -31,9 +42,10 @@ class DB(unittest.TestCase):
 | 
			
		|||
                              incomplete_bytes=0,
 | 
			
		||||
                              waiting=0, connected=0))
 | 
			
		||||
 | 
			
		||||
        with mock.patch("time.time", return_value=T+2):
 | 
			
		||||
            t.recordUsage(started=150, result="errory", total_bytes=200,
 | 
			
		||||
                          total_time=11, waiting_time=3)
 | 
			
		||||
        get_time.t = T + 2
 | 
			
		||||
        usage.record_usage(started=150, mood="errory", total_bytes=200,
 | 
			
		||||
                           total_time=11, waiting_time=3)
 | 
			
		||||
        t.update_stats()
 | 
			
		||||
        self.assertEqual(db.execute("SELECT * FROM `usage`").fetchall(),
 | 
			
		||||
                         [dict(result="happy", started=123,
 | 
			
		||||
                               total_bytes=100, total_time=10, waiting_time=2),
 | 
			
		||||
| 
						 | 
				
			
			@ -45,27 +57,37 @@ class DB(unittest.TestCase):
 | 
			
		|||
                              incomplete_bytes=0,
 | 
			
		||||
                              waiting=0, connected=0))
 | 
			
		||||
 | 
			
		||||
        with mock.patch("time.time", return_value=T+3):
 | 
			
		||||
            t.timerUpdateStats()
 | 
			
		||||
        get_time.t = T + 3
 | 
			
		||||
        t.update_stats()
 | 
			
		||||
        self.assertEqual(db.execute("SELECT * FROM `current`").fetchone(),
 | 
			
		||||
                         dict(rebooted=T+0, updated=T+3,
 | 
			
		||||
                              incomplete_bytes=0,
 | 
			
		||||
                              waiting=0, connected=0))
 | 
			
		||||
 | 
			
		||||
    def test_no_db(self):
 | 
			
		||||
        t = Transit(blur_usage=None, log_file=None, usage_db=None)
 | 
			
		||||
        t = Transit(
 | 
			
		||||
            create_usage_tracker(blur_usage=None, log_file=None, usage_db=None),
 | 
			
		||||
            lambda: 0,
 | 
			
		||||
        )
 | 
			
		||||
        self.assertEqual(0, len(t.usage._backends))
 | 
			
		||||
 | 
			
		||||
        t.recordUsage(started=123, result="happy", total_bytes=100,
 | 
			
		||||
                      total_time=10, waiting_time=2)
 | 
			
		||||
        t.timerUpdateStats()
 | 
			
		||||
 | 
			
		||||
class LogToStdout(unittest.TestCase):
 | 
			
		||||
    def test_log(self):
 | 
			
		||||
        # emit lines of JSON to log_file, if set
 | 
			
		||||
        log_file = io.StringIO()
 | 
			
		||||
        t = Transit(blur_usage=None, log_file=log_file, usage_db=None)
 | 
			
		||||
        t.recordUsage(started=123, result="happy", total_bytes=100,
 | 
			
		||||
                      total_time=10, waiting_time=2)
 | 
			
		||||
        t = Transit(
 | 
			
		||||
            create_usage_tracker(blur_usage=None, log_file=log_file, usage_db=None),
 | 
			
		||||
            lambda: 0,
 | 
			
		||||
        )
 | 
			
		||||
        with mock.patch("time.time", return_value=133):
 | 
			
		||||
            t.usage.record(
 | 
			
		||||
                started=123,
 | 
			
		||||
                buddy_started=125,
 | 
			
		||||
                result="happy",
 | 
			
		||||
                bytes_sent=100,
 | 
			
		||||
                buddy_bytes=0,
 | 
			
		||||
            )
 | 
			
		||||
        self.assertEqual(json.loads(log_file.getvalue()),
 | 
			
		||||
                         {"started": 123, "total_time": 10,
 | 
			
		||||
                          "waiting_time": 2, "total_bytes": 100,
 | 
			
		||||
| 
						 | 
				
			
			@ -75,15 +97,34 @@ class LogToStdout(unittest.TestCase):
 | 
			
		|||
        # if blurring is enabled, timestamps should be rounded to the
 | 
			
		||||
        # requested amount, and sizes should be rounded up too
 | 
			
		||||
        log_file = io.StringIO()
 | 
			
		||||
        t = Transit(blur_usage=60, log_file=log_file, usage_db=None)
 | 
			
		||||
        t.recordUsage(started=123, result="happy", total_bytes=11999,
 | 
			
		||||
                      total_time=10, waiting_time=2)
 | 
			
		||||
        t = Transit(
 | 
			
		||||
            create_usage_tracker(blur_usage=60, log_file=log_file, usage_db=None),
 | 
			
		||||
            lambda: 0,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        with mock.patch("time.time", return_value=123 + 10):
 | 
			
		||||
            t.usage.record(
 | 
			
		||||
                started=123,
 | 
			
		||||
                buddy_started=125,
 | 
			
		||||
                result="happy",
 | 
			
		||||
                bytes_sent=11999,
 | 
			
		||||
                buddy_bytes=0,
 | 
			
		||||
            )
 | 
			
		||||
        print(log_file.getvalue())
 | 
			
		||||
        self.assertEqual(json.loads(log_file.getvalue()),
 | 
			
		||||
                         {"started": 120, "total_time": 10,
 | 
			
		||||
                          "waiting_time": 2, "total_bytes": 20000,
 | 
			
		||||
                          "mood": "happy"})
 | 
			
		||||
 | 
			
		||||
    def test_do_not_log(self):
 | 
			
		||||
        t = Transit(blur_usage=60, log_file=None, usage_db=None)
 | 
			
		||||
        t.recordUsage(started=123, result="happy", total_bytes=11999,
 | 
			
		||||
                      total_time=10, waiting_time=2)
 | 
			
		||||
        t = Transit(
 | 
			
		||||
            create_usage_tracker(blur_usage=60, log_file=None, usage_db=None),
 | 
			
		||||
            lambda: 0,
 | 
			
		||||
        )
 | 
			
		||||
        t.usage.record(
 | 
			
		||||
            started=123,
 | 
			
		||||
            buddy_started=124,
 | 
			
		||||
            result="happy",
 | 
			
		||||
            bytes_sent=11999,
 | 
			
		||||
            buddy_bytes=12,
 | 
			
		||||
        )
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -1,7 +1,30 @@
 | 
			
		|||
from binascii import hexlify
 | 
			
		||||
from twisted.trial import unittest
 | 
			
		||||
from .common import ServerBase
 | 
			
		||||
from .. import transit_server
 | 
			
		||||
from twisted.test import iosim
 | 
			
		||||
from autobahn.twisted.websocket import (
 | 
			
		||||
    WebSocketServerFactory,
 | 
			
		||||
    WebSocketClientFactory,
 | 
			
		||||
    WebSocketClientProtocol,
 | 
			
		||||
)
 | 
			
		||||
from autobahn.twisted.testing import (
 | 
			
		||||
    create_pumper,
 | 
			
		||||
    MemoryReactorClockResolver,
 | 
			
		||||
)
 | 
			
		||||
from autobahn.exception import Disconnected
 | 
			
		||||
from zope.interface import implementer
 | 
			
		||||
from .common import (
 | 
			
		||||
    ServerBase,
 | 
			
		||||
    IRelayTestClient,
 | 
			
		||||
)
 | 
			
		||||
from ..usage import (
 | 
			
		||||
    MemoryUsageRecorder,
 | 
			
		||||
    blur_size,
 | 
			
		||||
)
 | 
			
		||||
from ..transit_server import (
 | 
			
		||||
    WebSocketTransitConnection,
 | 
			
		||||
    TransitServerState,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def handshake(token, side=None):
 | 
			
		||||
    hs = b"please relay " + hexlify(token)
 | 
			
		||||
| 
						 | 
				
			
			@ -12,27 +35,28 @@ def handshake(token, side=None):
 | 
			
		|||
 | 
			
		||||
class _Transit:
 | 
			
		||||
    def count(self):
 | 
			
		||||
        return sum([len(potentials)
 | 
			
		||||
                    for potentials
 | 
			
		||||
                    in self._transit_server._pending_requests.values()])
 | 
			
		||||
        return sum([
 | 
			
		||||
            len(potentials)
 | 
			
		||||
            for potentials
 | 
			
		||||
            in self._transit_server.pending_requests._requests.values()
 | 
			
		||||
        ])
 | 
			
		||||
 | 
			
		||||
    def test_blur_size(self):
 | 
			
		||||
        blur = transit_server.blur_size
 | 
			
		||||
        self.failUnlessEqual(blur(0), 0)
 | 
			
		||||
        self.failUnlessEqual(blur(1), 10e3)
 | 
			
		||||
        self.failUnlessEqual(blur(10e3), 10e3)
 | 
			
		||||
        self.failUnlessEqual(blur(10e3+1), 20e3)
 | 
			
		||||
        self.failUnlessEqual(blur(15e3), 20e3)
 | 
			
		||||
        self.failUnlessEqual(blur(20e3), 20e3)
 | 
			
		||||
        self.failUnlessEqual(blur(1e6), 1e6)
 | 
			
		||||
        self.failUnlessEqual(blur(1e6+1), 2e6)
 | 
			
		||||
        self.failUnlessEqual(blur(1.5e6), 2e6)
 | 
			
		||||
        self.failUnlessEqual(blur(2e6), 2e6)
 | 
			
		||||
        self.failUnlessEqual(blur(900e6), 900e6)
 | 
			
		||||
        self.failUnlessEqual(blur(1000e6), 1000e6)
 | 
			
		||||
        self.failUnlessEqual(blur(1050e6), 1100e6)
 | 
			
		||||
        self.failUnlessEqual(blur(1100e6), 1100e6)
 | 
			
		||||
        self.failUnlessEqual(blur(1150e6), 1200e6)
 | 
			
		||||
        self.failUnlessEqual(blur_size(0), 0)
 | 
			
		||||
        self.failUnlessEqual(blur_size(1), 10e3)
 | 
			
		||||
        self.failUnlessEqual(blur_size(10e3), 10e3)
 | 
			
		||||
        self.failUnlessEqual(blur_size(10e3+1), 20e3)
 | 
			
		||||
        self.failUnlessEqual(blur_size(15e3), 20e3)
 | 
			
		||||
        self.failUnlessEqual(blur_size(20e3), 20e3)
 | 
			
		||||
        self.failUnlessEqual(blur_size(1e6), 1e6)
 | 
			
		||||
        self.failUnlessEqual(blur_size(1e6+1), 2e6)
 | 
			
		||||
        self.failUnlessEqual(blur_size(1.5e6), 2e6)
 | 
			
		||||
        self.failUnlessEqual(blur_size(2e6), 2e6)
 | 
			
		||||
        self.failUnlessEqual(blur_size(900e6), 900e6)
 | 
			
		||||
        self.failUnlessEqual(blur_size(1000e6), 1000e6)
 | 
			
		||||
        self.failUnlessEqual(blur_size(1050e6), 1100e6)
 | 
			
		||||
        self.failUnlessEqual(blur_size(1100e6), 1100e6)
 | 
			
		||||
        self.failUnlessEqual(blur_size(1150e6), 1200e6)
 | 
			
		||||
 | 
			
		||||
    def test_register(self):
 | 
			
		||||
        p1 = self.new_protocol()
 | 
			
		||||
| 
						 | 
				
			
			@ -49,7 +73,7 @@ class _Transit:
 | 
			
		|||
        self.assertEqual(self.count(), 0)
 | 
			
		||||
 | 
			
		||||
        # the token should be removed too
 | 
			
		||||
        self.assertEqual(len(self._transit_server._pending_requests), 0)
 | 
			
		||||
        self.assertEqual(len(self._transit_server.pending_requests._requests), 0)
 | 
			
		||||
 | 
			
		||||
    def test_both_unsided(self):
 | 
			
		||||
        p1 = self.new_protocol()
 | 
			
		||||
| 
						 | 
				
			
			@ -75,7 +99,6 @@ class _Transit:
 | 
			
		|||
        self.assertEqual(p2.get_received_data(), s1)
 | 
			
		||||
 | 
			
		||||
        p1.disconnect()
 | 
			
		||||
        p2.disconnect()
 | 
			
		||||
        self.flush()
 | 
			
		||||
 | 
			
		||||
    def test_sided_unsided(self):
 | 
			
		||||
| 
						 | 
				
			
			@ -104,7 +127,6 @@ class _Transit:
 | 
			
		|||
        self.assertEqual(p2.get_received_data(), s1)
 | 
			
		||||
 | 
			
		||||
        p1.disconnect()
 | 
			
		||||
        p2.disconnect()
 | 
			
		||||
        self.flush()
 | 
			
		||||
 | 
			
		||||
    def test_unsided_sided(self):
 | 
			
		||||
| 
						 | 
				
			
			@ -177,6 +199,7 @@ class _Transit:
 | 
			
		|||
 | 
			
		||||
        p2.send(handshake(token1, side=side1))
 | 
			
		||||
        self.flush()
 | 
			
		||||
        self.flush()
 | 
			
		||||
        self.assertEqual(self.count(), 2) # same-side connections don't match
 | 
			
		||||
 | 
			
		||||
        # when the second side arrives, the spare first connection should be
 | 
			
		||||
| 
						 | 
				
			
			@ -185,8 +208,8 @@ class _Transit:
 | 
			
		|||
        p3.send(handshake(token1, side=side2))
 | 
			
		||||
        self.flush()
 | 
			
		||||
        self.assertEqual(self.count(), 0)
 | 
			
		||||
        self.assertEqual(len(self._transit_server._pending_requests), 0)
 | 
			
		||||
        self.assertEqual(len(self._transit_server._active_connections), 2)
 | 
			
		||||
        self.assertEqual(len(self._transit_server.pending_requests._requests), 0)
 | 
			
		||||
        self.assertEqual(len(self._transit_server.active_connections._connections), 2)
 | 
			
		||||
        # That will trigger a disconnect on exactly one of (p1 or p2).
 | 
			
		||||
        # The other connection should still be connected
 | 
			
		||||
        self.assertEqual(sum([int(t.connected) for t in [p1, p2]]), 1)
 | 
			
		||||
| 
						 | 
				
			
			@ -266,7 +289,8 @@ class _Transit:
 | 
			
		|||
 | 
			
		||||
        token1 = b"\x00"*32
 | 
			
		||||
        # sending too many bytes is impatience.
 | 
			
		||||
        p1.send(b"please relay " + hexlify(token1) + b"\nNOWNOWNOW")
 | 
			
		||||
        p1.send(b"please relay " + hexlify(token1))
 | 
			
		||||
        p1.send(b"\nNOWNOWNOW")
 | 
			
		||||
        self.flush()
 | 
			
		||||
 | 
			
		||||
        exp = b"impatient\n"
 | 
			
		||||
| 
						 | 
				
			
			@ -281,7 +305,8 @@ class _Transit:
 | 
			
		|||
        side1 = b"\x01"*8
 | 
			
		||||
        # sending too many bytes is impatience.
 | 
			
		||||
        p1.send(b"please relay " + hexlify(token1) +
 | 
			
		||||
                b" for side " + hexlify(side1) + b"\nNOWNOWNOW")
 | 
			
		||||
                b" for side " + hexlify(side1))
 | 
			
		||||
        p1.send(b"\nNOWNOWNOW")
 | 
			
		||||
        self.flush()
 | 
			
		||||
 | 
			
		||||
        exp = b"impatient\n"
 | 
			
		||||
| 
						 | 
				
			
			@ -327,22 +352,163 @@ class _Transit:
 | 
			
		|||
        # hang up before sending anything
 | 
			
		||||
        p1.disconnect()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TransitWithLogs(_Transit, ServerBase, unittest.TestCase):
 | 
			
		||||
    log_requests = True
 | 
			
		||||
 | 
			
		||||
    def new_protocol(self):
 | 
			
		||||
        return self.new_protocol_tcp()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TransitWithoutLogs(_Transit, ServerBase, unittest.TestCase):
 | 
			
		||||
    log_requests = False
 | 
			
		||||
 | 
			
		||||
    def new_protocol(self):
 | 
			
		||||
        return self.new_protocol_tcp()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _new_protocol_ws(transit_server, log_requests):
 | 
			
		||||
    """
 | 
			
		||||
    Internal helper for test-suites that need to provide WebSocket
 | 
			
		||||
    client/server pairs.
 | 
			
		||||
 | 
			
		||||
    :returns: a 2-tuple: (iosim.IOPump, protocol)
 | 
			
		||||
    """
 | 
			
		||||
    ws_factory = WebSocketServerFactory("ws://localhost:4002")
 | 
			
		||||
    ws_factory.protocol = WebSocketTransitConnection
 | 
			
		||||
    ws_factory.transit = transit_server
 | 
			
		||||
    ws_factory.log_requests = log_requests
 | 
			
		||||
    ws_protocol = ws_factory.buildProtocol(('127.0.0.1', 0))
 | 
			
		||||
 | 
			
		||||
    @implementer(IRelayTestClient)
 | 
			
		||||
    class TransitWebSocketClientProtocol(WebSocketClientProtocol):
 | 
			
		||||
        _received = b""
 | 
			
		||||
        connected = False
 | 
			
		||||
 | 
			
		||||
        def connectionMade(self):
 | 
			
		||||
            self.connected = True
 | 
			
		||||
            return super(TransitWebSocketClientProtocol, self).connectionMade()
 | 
			
		||||
 | 
			
		||||
        def connectionLost(self, reason):
 | 
			
		||||
            self.connected = False
 | 
			
		||||
            return super(TransitWebSocketClientProtocol, self).connectionLost(reason)
 | 
			
		||||
 | 
			
		||||
        def onMessage(self, data, isBinary):
 | 
			
		||||
            self._received = self._received + data
 | 
			
		||||
 | 
			
		||||
        def send(self, data):
 | 
			
		||||
            self.sendMessage(data, True)
 | 
			
		||||
 | 
			
		||||
        def get_received_data(self):
 | 
			
		||||
            return self._received
 | 
			
		||||
 | 
			
		||||
        def reset_received_data(self):
 | 
			
		||||
            self._received = b""
 | 
			
		||||
 | 
			
		||||
        def disconnect(self):
 | 
			
		||||
            self.sendClose(1000, True)
 | 
			
		||||
 | 
			
		||||
    client_factory = WebSocketClientFactory()
 | 
			
		||||
    client_factory.protocol = TransitWebSocketClientProtocol
 | 
			
		||||
    client_protocol = client_factory.buildProtocol(('127.0.0.1', 31337))
 | 
			
		||||
    client_protocol.disconnect = client_protocol.dropConnection
 | 
			
		||||
 | 
			
		||||
    pump = iosim.connect(
 | 
			
		||||
        ws_protocol,
 | 
			
		||||
        iosim.makeFakeServer(ws_protocol),
 | 
			
		||||
        client_protocol,
 | 
			
		||||
        iosim.makeFakeClient(client_protocol),
 | 
			
		||||
    )
 | 
			
		||||
    return pump, client_protocol
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TransitWebSockets(_Transit, ServerBase, unittest.TestCase):
 | 
			
		||||
 | 
			
		||||
    def new_protocol(self):
 | 
			
		||||
        return self.new_protocol_ws()
 | 
			
		||||
 | 
			
		||||
    def new_protocol_ws(self):
 | 
			
		||||
        pump, proto = _new_protocol_ws(self._transit_server, self.log_requests)
 | 
			
		||||
        self._pumps.append(pump)
 | 
			
		||||
        return proto
 | 
			
		||||
 | 
			
		||||
    def test_websocket_to_tcp(self):
 | 
			
		||||
        """
 | 
			
		||||
        One client is WebSocket and one is TCP
 | 
			
		||||
        """
 | 
			
		||||
        p1 = self.new_protocol_ws()
 | 
			
		||||
        p2 = self.new_protocol_tcp()
 | 
			
		||||
 | 
			
		||||
        token1 = b"\x00"*32
 | 
			
		||||
        side1 = b"\x01"*8
 | 
			
		||||
        side2 = b"\x02"*8
 | 
			
		||||
        p1.send(handshake(token1, side=side1))
 | 
			
		||||
        self.flush()
 | 
			
		||||
        p2.send(handshake(token1, side=side2))
 | 
			
		||||
        self.flush()
 | 
			
		||||
 | 
			
		||||
        # a correct handshake yields an ack, after which we can send
 | 
			
		||||
        exp = b"ok\n"
 | 
			
		||||
        self.assertEqual(p1.get_received_data(), exp)
 | 
			
		||||
        self.assertEqual(p2.get_received_data(), exp)
 | 
			
		||||
 | 
			
		||||
        p1.reset_received_data()
 | 
			
		||||
        p2.reset_received_data()
 | 
			
		||||
 | 
			
		||||
        # all data they sent after the handshake should be given to us
 | 
			
		||||
        s1 = b"data1"
 | 
			
		||||
        p1.send(s1)
 | 
			
		||||
        self.flush()
 | 
			
		||||
        self.assertEqual(p2.get_received_data(), s1)
 | 
			
		||||
 | 
			
		||||
        p1.disconnect()
 | 
			
		||||
        p2.disconnect()
 | 
			
		||||
        self.flush()
 | 
			
		||||
 | 
			
		||||
    def test_bad_handshake_old_slow(self):
 | 
			
		||||
        """
 | 
			
		||||
        This test only makes sense for TCP
 | 
			
		||||
        """
 | 
			
		||||
 | 
			
		||||
    def test_send_closed_partner(self):
 | 
			
		||||
        """
 | 
			
		||||
        Sending data to a closed partner causes an error that propogates
 | 
			
		||||
        to the sender.
 | 
			
		||||
        """
 | 
			
		||||
        p1 = self.new_protocol()
 | 
			
		||||
        p2 = self.new_protocol()
 | 
			
		||||
 | 
			
		||||
        # set up a successful connection
 | 
			
		||||
        token = b"a" * 32
 | 
			
		||||
        p1.send(handshake(token))
 | 
			
		||||
        p2.send(handshake(token))
 | 
			
		||||
        self.flush()
 | 
			
		||||
 | 
			
		||||
        # p2 loses connection, then p1 sends a message
 | 
			
		||||
        p2.transport.loseConnection()
 | 
			
		||||
        self.flush()
 | 
			
		||||
 | 
			
		||||
        # at this point, p1 learns that p2 is disconnected (because it
 | 
			
		||||
        # tried to relay "a message" but failed)
 | 
			
		||||
 | 
			
		||||
        # try to send more (our partner p2 is gone now though so it
 | 
			
		||||
        # should be an immediate error)
 | 
			
		||||
        with self.assertRaises(Disconnected):
 | 
			
		||||
            p1.send(b"more message")
 | 
			
		||||
            self.flush()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Usage(ServerBase, unittest.TestCase):
 | 
			
		||||
    log_requests = True
 | 
			
		||||
 | 
			
		||||
    def setUp(self):
 | 
			
		||||
        super(Usage, self).setUp()
 | 
			
		||||
        self._usage = []
 | 
			
		||||
        def record(started, result, total_bytes, total_time, waiting_time):
 | 
			
		||||
            self._usage.append((started, result, total_bytes,
 | 
			
		||||
                                total_time, waiting_time))
 | 
			
		||||
        self._transit_server.recordUsage = record
 | 
			
		||||
        self._usage = MemoryUsageRecorder()
 | 
			
		||||
        self._transit_server.usage.add_backend(self._usage)
 | 
			
		||||
 | 
			
		||||
    def new_protocol(self):
 | 
			
		||||
        return self.new_protocol_tcp()
 | 
			
		||||
 | 
			
		||||
    def test_empty(self):
 | 
			
		||||
        p1 = self.new_protocol()
 | 
			
		||||
| 
						 | 
				
			
			@ -351,11 +517,14 @@ class Usage(ServerBase, unittest.TestCase):
 | 
			
		|||
        self.flush()
 | 
			
		||||
 | 
			
		||||
        # that will log the "empty" usage event
 | 
			
		||||
        self.assertEqual(len(self._usage), 1, self._usage)
 | 
			
		||||
        (started, result, total_bytes, total_time, waiting_time) = self._usage[0]
 | 
			
		||||
        self.assertEqual(result, "empty", self._usage)
 | 
			
		||||
        self.assertEqual(len(self._usage.events), 1, self._usage)
 | 
			
		||||
        self.assertEqual(self._usage.events[0]["mood"], "empty", self._usage)
 | 
			
		||||
 | 
			
		||||
    def test_short(self):
 | 
			
		||||
        # Note: this test only runs on TCP clients because WebSockets
 | 
			
		||||
        # already does framing (so it's either "a bad handshake" or
 | 
			
		||||
        # there's no handshake at all yet .. you can't have a "short"
 | 
			
		||||
        # one).
 | 
			
		||||
        p1 = self.new_protocol()
 | 
			
		||||
        # hang up before sending a complete handshake
 | 
			
		||||
        p1.send(b"short")
 | 
			
		||||
| 
						 | 
				
			
			@ -363,9 +532,8 @@ class Usage(ServerBase, unittest.TestCase):
 | 
			
		|||
        self.flush()
 | 
			
		||||
 | 
			
		||||
        # that will log the "empty" usage event
 | 
			
		||||
        self.assertEqual(len(self._usage), 1, self._usage)
 | 
			
		||||
        (started, result, total_bytes, total_time, waiting_time) = self._usage[0]
 | 
			
		||||
        self.assertEqual(result, "empty", self._usage)
 | 
			
		||||
        self.assertEqual(len(self._usage.events), 1, self._usage)
 | 
			
		||||
        self.assertEqual("empty", self._usage.events[0]["mood"])
 | 
			
		||||
 | 
			
		||||
    def test_errory(self):
 | 
			
		||||
        p1 = self.new_protocol()
 | 
			
		||||
| 
						 | 
				
			
			@ -374,9 +542,8 @@ class Usage(ServerBase, unittest.TestCase):
 | 
			
		|||
        self.flush()
 | 
			
		||||
        # that will log the "errory" usage event, then drop the connection
 | 
			
		||||
        p1.disconnect()
 | 
			
		||||
        self.assertEqual(len(self._usage), 1, self._usage)
 | 
			
		||||
        (started, result, total_bytes, total_time, waiting_time) = self._usage[0]
 | 
			
		||||
        self.assertEqual(result, "errory", self._usage)
 | 
			
		||||
        self.assertEqual(len(self._usage.events), 1, self._usage)
 | 
			
		||||
        self.assertEqual(self._usage.events[0]["mood"], "errory", self._usage)
 | 
			
		||||
 | 
			
		||||
    def test_lonely(self):
 | 
			
		||||
        p1 = self.new_protocol()
 | 
			
		||||
| 
						 | 
				
			
			@ -389,10 +556,9 @@ class Usage(ServerBase, unittest.TestCase):
 | 
			
		|||
        p1.disconnect()
 | 
			
		||||
        self.flush()
 | 
			
		||||
 | 
			
		||||
        self.assertEqual(len(self._usage), 1, self._usage)
 | 
			
		||||
        (started, result, total_bytes, total_time, waiting_time) = self._usage[0]
 | 
			
		||||
        self.assertEqual(result, "lonely", self._usage)
 | 
			
		||||
        self.assertIdentical(waiting_time, None)
 | 
			
		||||
        self.assertEqual(len(self._usage.events), 1, self._usage)
 | 
			
		||||
        self.assertEqual(self._usage.events[0]["mood"], "lonely", self._usage)
 | 
			
		||||
        self.assertIdentical(self._usage.events[0]["waiting_time"], None)
 | 
			
		||||
 | 
			
		||||
    def test_one_happy_one_jilted(self):
 | 
			
		||||
        p1 = self.new_protocol()
 | 
			
		||||
| 
						 | 
				
			
			@ -406,7 +572,7 @@ class Usage(ServerBase, unittest.TestCase):
 | 
			
		|||
        p2.send(handshake(token1, side=side2))
 | 
			
		||||
        self.flush()
 | 
			
		||||
 | 
			
		||||
        self.assertEqual(self._usage, []) # no events yet
 | 
			
		||||
        self.assertEqual(self._usage.events, []) # no events yet
 | 
			
		||||
 | 
			
		||||
        p1.send(b"\x00" * 13)
 | 
			
		||||
        self.flush()
 | 
			
		||||
| 
						 | 
				
			
			@ -416,11 +582,10 @@ class Usage(ServerBase, unittest.TestCase):
 | 
			
		|||
        p1.disconnect()
 | 
			
		||||
        self.flush()
 | 
			
		||||
 | 
			
		||||
        self.assertEqual(len(self._usage), 1, self._usage)
 | 
			
		||||
        (started, result, total_bytes, total_time, waiting_time) = self._usage[0]
 | 
			
		||||
        self.assertEqual(result, "happy", self._usage)
 | 
			
		||||
        self.assertEqual(total_bytes, 20)
 | 
			
		||||
        self.assertNotIdentical(waiting_time, None)
 | 
			
		||||
        self.assertEqual(len(self._usage.events), 1, self._usage)
 | 
			
		||||
        self.assertEqual(self._usage.events[0]["mood"], "happy", self._usage)
 | 
			
		||||
        self.assertEqual(self._usage.events[0]["total_bytes"], 20)
 | 
			
		||||
        self.assertNotIdentical(self._usage.events[0]["waiting_time"], None)
 | 
			
		||||
 | 
			
		||||
    def test_redundant(self):
 | 
			
		||||
        p1a = self.new_protocol()
 | 
			
		||||
| 
						 | 
				
			
			@ -443,21 +608,80 @@ class Usage(ServerBase, unittest.TestCase):
 | 
			
		|||
        p1c.disconnect()
 | 
			
		||||
        self.flush()
 | 
			
		||||
 | 
			
		||||
        self.assertEqual(len(self._usage), 1, self._usage)
 | 
			
		||||
        (started, result, total_bytes, total_time, waiting_time) = self._usage[0]
 | 
			
		||||
        self.assertEqual(result, "lonely", self._usage)
 | 
			
		||||
        self.assertEqual(len(self._usage.events), 1, self._usage)
 | 
			
		||||
        self.assertEqual(self._usage.events[0]["mood"], "lonely")
 | 
			
		||||
 | 
			
		||||
        p2.send(handshake(token1, side=side2))
 | 
			
		||||
        self.flush()
 | 
			
		||||
        self.assertEqual(len(self._transit_server._pending_requests), 0)
 | 
			
		||||
        self.assertEqual(len(self._usage), 2, self._usage)
 | 
			
		||||
        (started, result, total_bytes, total_time, waiting_time) = self._usage[1]
 | 
			
		||||
        self.assertEqual(result, "redundant", self._usage)
 | 
			
		||||
        self.assertEqual(len(self._transit_server.pending_requests._requests), 0)
 | 
			
		||||
        self.assertEqual(len(self._usage.events), 2, self._usage)
 | 
			
		||||
        self.assertEqual(self._usage.events[1]["mood"], "redundant")
 | 
			
		||||
 | 
			
		||||
        # one of the these is unecessary, but probably harmless
 | 
			
		||||
        p1a.disconnect()
 | 
			
		||||
        p1b.disconnect()
 | 
			
		||||
        self.flush()
 | 
			
		||||
        self.assertEqual(len(self._usage), 3, self._usage)
 | 
			
		||||
        (started, result, total_bytes, total_time, waiting_time) = self._usage[2]
 | 
			
		||||
        self.assertEqual(result, "happy", self._usage)
 | 
			
		||||
        self.assertEqual(len(self._usage.events), 3, self._usage)
 | 
			
		||||
        self.assertEqual(self._usage.events[2]["mood"], "happy")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class UsageWebSockets(Usage):
 | 
			
		||||
    """
 | 
			
		||||
    All the tests of 'Usage' except with a WebSocket (instead of TCP)
 | 
			
		||||
    transport.
 | 
			
		||||
 | 
			
		||||
    This overrides ServerBase.new_protocol to achieve this. It might
 | 
			
		||||
    be nicer to parametrize these tests in a way that doesn't use
 | 
			
		||||
    inheritance .. but all the support etc classes are set up that way
 | 
			
		||||
    already.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def setUp(self):
 | 
			
		||||
        super(UsageWebSockets, self).setUp()
 | 
			
		||||
        self._pump = create_pumper()
 | 
			
		||||
        self._reactor = MemoryReactorClockResolver()
 | 
			
		||||
        return self._pump.start()
 | 
			
		||||
 | 
			
		||||
    def tearDown(self):
 | 
			
		||||
        return self._pump.stop()
 | 
			
		||||
 | 
			
		||||
    def new_protocol(self):
 | 
			
		||||
        return self.new_protocol_ws()
 | 
			
		||||
 | 
			
		||||
    def new_protocol_ws(self):
 | 
			
		||||
        pump, proto = _new_protocol_ws(self._transit_server, self.log_requests)
 | 
			
		||||
        self._pumps.append(pump)
 | 
			
		||||
        return proto
 | 
			
		||||
 | 
			
		||||
    def test_short(self):
 | 
			
		||||
        """
 | 
			
		||||
        This test essentially just tests the framing of the line-oriented
 | 
			
		||||
        TCP protocol; it doesnt' make sense for the WebSockets case
 | 
			
		||||
        because WS handles frameing: you either sent a 'bad handshake'
 | 
			
		||||
        because it is semantically invalid or no handshake (yet).
 | 
			
		||||
        """
 | 
			
		||||
 | 
			
		||||
    def test_send_non_binary_message(self):
 | 
			
		||||
        """
 | 
			
		||||
        A non-binary WebSocket message is an error
 | 
			
		||||
        """
 | 
			
		||||
        ws_factory = WebSocketServerFactory("ws://localhost:4002")
 | 
			
		||||
        ws_factory.protocol = WebSocketTransitConnection
 | 
			
		||||
        ws_protocol = ws_factory.buildProtocol(('127.0.0.1', 0))
 | 
			
		||||
        with self.assertRaises(ValueError):
 | 
			
		||||
            ws_protocol.onMessage(u"foo", isBinary=False)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class State(unittest.TestCase):
 | 
			
		||||
    """
 | 
			
		||||
    Tests related to server_state.TransitServerState
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def setUp(self):
 | 
			
		||||
        self.state = TransitServerState(None, None)
 | 
			
		||||
 | 
			
		||||
    def test_empty_token(self):
 | 
			
		||||
        self.assertEqual(
 | 
			
		||||
            "-",
 | 
			
		||||
            self.state.get_token(),
 | 
			
		||||
        )
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -1,9 +1,9 @@
 | 
			
		|||
import re, time, json
 | 
			
		||||
from collections import defaultdict
 | 
			
		||||
import re
 | 
			
		||||
import time
 | 
			
		||||
from twisted.python import log
 | 
			
		||||
from twisted.internet import protocol
 | 
			
		||||
from twisted.protocols.basic import LineReceiver
 | 
			
		||||
from .database import get_db
 | 
			
		||||
from autobahn.twisted.websocket import WebSocketServerProtocol
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
SECONDS = 1.0
 | 
			
		||||
MINUTE = 60*SECONDS
 | 
			
		||||
| 
						 | 
				
			
			@ -11,340 +11,254 @@ HOUR = 60*MINUTE
 | 
			
		|||
DAY = 24*HOUR
 | 
			
		||||
MB = 1000*1000
 | 
			
		||||
 | 
			
		||||
def round_to(size, coarseness):
 | 
			
		||||
    return int(coarseness*(1+int((size-1)/coarseness)))
 | 
			
		||||
 | 
			
		||||
def blur_size(size):
 | 
			
		||||
    if size == 0:
 | 
			
		||||
        return 0
 | 
			
		||||
    if size < 1e6:
 | 
			
		||||
        return round_to(size, 10e3)
 | 
			
		||||
    if size < 1e9:
 | 
			
		||||
        return round_to(size, 1e6)
 | 
			
		||||
    return round_to(size, 100e6)
 | 
			
		||||
from wormhole_transit_relay.server_state import (
 | 
			
		||||
    TransitServerState,
 | 
			
		||||
    PendingRequests,
 | 
			
		||||
    ActiveConnections,
 | 
			
		||||
    ITransitClient,
 | 
			
		||||
)
 | 
			
		||||
from zope.interface import implementer
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@implementer(ITransitClient)
 | 
			
		||||
class TransitConnection(LineReceiver):
 | 
			
		||||
    delimiter = b'\n'
 | 
			
		||||
    # maximum length of a line we will accept before the handshake is complete.
 | 
			
		||||
    # This must be >= to the longest possible handshake message.
 | 
			
		||||
 | 
			
		||||
    MAX_LENGTH = 1024
 | 
			
		||||
    started_time = None
 | 
			
		||||
 | 
			
		||||
    def __init__(self):
 | 
			
		||||
        self._got_token = False
 | 
			
		||||
        self._got_side = False
 | 
			
		||||
        self._sent_ok = False
 | 
			
		||||
        self._mood = "empty"
 | 
			
		||||
    def send(self, data):
 | 
			
		||||
        """
 | 
			
		||||
        ITransitClient API
 | 
			
		||||
        """
 | 
			
		||||
        self.transport.write(data)
 | 
			
		||||
 | 
			
		||||
    def disconnect(self):
 | 
			
		||||
        """
 | 
			
		||||
        ITransitClient API
 | 
			
		||||
        """
 | 
			
		||||
        self.transport.loseConnection()
 | 
			
		||||
 | 
			
		||||
    def connect_partner(self, other):
 | 
			
		||||
        """
 | 
			
		||||
        ITransitClient API
 | 
			
		||||
        """
 | 
			
		||||
        self._buddy = other
 | 
			
		||||
 | 
			
		||||
    def disconnect_partner(self):
 | 
			
		||||
        """
 | 
			
		||||
        ITransitClient API
 | 
			
		||||
        """
 | 
			
		||||
        assert self._buddy is not None, "internal error: no buddy"
 | 
			
		||||
        if self.factory.log_requests:
 | 
			
		||||
            log.msg("buddy_disconnected {}".format(self._buddy.get_token()))
 | 
			
		||||
        self._buddy._client.disconnect()
 | 
			
		||||
        self._buddy = None
 | 
			
		||||
        self._total_sent = 0
 | 
			
		||||
 | 
			
		||||
    def describeToken(self):
 | 
			
		||||
        d = "-"
 | 
			
		||||
        if self._got_token:
 | 
			
		||||
            d = self._got_token[:16].decode("ascii")
 | 
			
		||||
        if self._got_side:
 | 
			
		||||
            d += "-" + self._got_side.decode("ascii")
 | 
			
		||||
        else:
 | 
			
		||||
            d += "-<unsided>"
 | 
			
		||||
        return d
 | 
			
		||||
 | 
			
		||||
    def connectionMade(self):
 | 
			
		||||
        self._started = time.time()
 | 
			
		||||
        self._log_requests = self.factory._log_requests
 | 
			
		||||
        # ideally more like self._reactor.seconds() ... but Twisted
 | 
			
		||||
        # doesn't have a good way to get the reactor for a protocol
 | 
			
		||||
        # (besides "use the global one")
 | 
			
		||||
        self.started_time = time.time()
 | 
			
		||||
        self._state = TransitServerState(
 | 
			
		||||
            self.factory.transit.pending_requests,
 | 
			
		||||
            self.factory.transit.usage,
 | 
			
		||||
        )
 | 
			
		||||
        self._state.connection_made(self)
 | 
			
		||||
        self.transport.setTcpKeepAlive(True)
 | 
			
		||||
 | 
			
		||||
        # uncomment to turn on state-machine tracing
 | 
			
		||||
        # def tracer(oldstate, theinput, newstate):
 | 
			
		||||
        #     print("TRACE: {}: {} --{}--> {}".format(id(self), oldstate, theinput, newstate))
 | 
			
		||||
        # self._state.set_trace_function(tracer)
 | 
			
		||||
 | 
			
		||||
    def lineReceived(self, line):
 | 
			
		||||
        """
 | 
			
		||||
        LineReceiver API
 | 
			
		||||
        """
 | 
			
		||||
        # old: "please relay {64}\n"
 | 
			
		||||
        token = None
 | 
			
		||||
        old = re.search(br"^please relay (\w{64})$", line)
 | 
			
		||||
        if old:
 | 
			
		||||
            token = old.group(1)
 | 
			
		||||
            return self._got_handshake(token, None)
 | 
			
		||||
            self._state.please_relay(token)
 | 
			
		||||
 | 
			
		||||
        # new: "please relay {64} for side {16}\n"
 | 
			
		||||
        new = re.search(br"^please relay (\w{64}) for side (\w{16})$", line)
 | 
			
		||||
        if new:
 | 
			
		||||
            token = new.group(1)
 | 
			
		||||
            side = new.group(2)
 | 
			
		||||
            return self._got_handshake(token, side)
 | 
			
		||||
            self._state.please_relay_for_side(token, side)
 | 
			
		||||
 | 
			
		||||
        self.sendLine(b"bad handshake")
 | 
			
		||||
        if self._log_requests:
 | 
			
		||||
            log.msg("transit handshake failure")
 | 
			
		||||
        return self.disconnect_error()
 | 
			
		||||
        if token is None:
 | 
			
		||||
            self._state.bad_token()
 | 
			
		||||
        else:
 | 
			
		||||
            self.setRawMode()
 | 
			
		||||
 | 
			
		||||
    def rawDataReceived(self, data):
 | 
			
		||||
        """
 | 
			
		||||
        LineReceiver API
 | 
			
		||||
        """
 | 
			
		||||
        # We are an IPushProducer to our buddy's IConsumer, so they'll
 | 
			
		||||
        # throttle us (by calling pauseProducing()) when their outbound
 | 
			
		||||
        # buffer is full (e.g. when their downstream pipe is full). In
 | 
			
		||||
        # practice, this buffers about 10MB per connection, after which
 | 
			
		||||
        # point the sender will only transmit data as fast as the
 | 
			
		||||
        # receiver can handle it.
 | 
			
		||||
        if self._sent_ok:
 | 
			
		||||
            # if self._buddy is None then our buddy disconnected
 | 
			
		||||
            # (we're "jilted"), so we hung up too, but our incoming
 | 
			
		||||
            # data hasn't stopped yet (it will in a moment, after our
 | 
			
		||||
            # disconnect makes a roundtrip through the kernel). This
 | 
			
		||||
            # probably means the file receiver hung up, and this
 | 
			
		||||
            # connection is the file sender. In may-2020 this happened
 | 
			
		||||
            # 11 times in 40 days.
 | 
			
		||||
            if self._buddy:
 | 
			
		||||
                self._total_sent += len(data)
 | 
			
		||||
                self._buddy.transport.write(data)
 | 
			
		||||
            return
 | 
			
		||||
 | 
			
		||||
        # handshake is complete but not yet sent_ok
 | 
			
		||||
        self.sendLine(b"impatient")
 | 
			
		||||
        if self._log_requests:
 | 
			
		||||
            log.msg("transit impatience failure")
 | 
			
		||||
        return self.disconnect_error() # impatience yields failure
 | 
			
		||||
 | 
			
		||||
    def _got_handshake(self, token, side):
 | 
			
		||||
        self._got_token = token
 | 
			
		||||
        self._got_side = side
 | 
			
		||||
        self._mood = "lonely" # until buddy connects
 | 
			
		||||
        self.setRawMode()
 | 
			
		||||
        self.factory.connection_got_token(token, side, self)
 | 
			
		||||
 | 
			
		||||
    def buddy_connected(self, them):
 | 
			
		||||
        self._buddy = them
 | 
			
		||||
        self._mood = "happy"
 | 
			
		||||
        self.sendLine(b"ok")
 | 
			
		||||
        self._sent_ok = True
 | 
			
		||||
        # Connect the two as a producer/consumer pair. We use streaming=True,
 | 
			
		||||
        # so this expects the IPushProducer interface, and uses
 | 
			
		||||
        # pauseProducing() to throttle, and resumeProducing() to unthrottle.
 | 
			
		||||
        self._buddy.transport.registerProducer(self.transport, True)
 | 
			
		||||
        # The Transit object calls buddy_connected() on both protocols, so
 | 
			
		||||
        # there will be two producer/consumer pairs.
 | 
			
		||||
 | 
			
		||||
    def buddy_disconnected(self):
 | 
			
		||||
        if self._log_requests:
 | 
			
		||||
            log.msg("buddy_disconnected %s" % self.describeToken())
 | 
			
		||||
        self._buddy = None
 | 
			
		||||
        self._mood = "jilted"
 | 
			
		||||
        self.transport.loseConnection()
 | 
			
		||||
 | 
			
		||||
    def disconnect_error(self):
 | 
			
		||||
        # we haven't finished the handshake, so there are no tokens tracking
 | 
			
		||||
        # us
 | 
			
		||||
        self._mood = "errory"
 | 
			
		||||
        self.transport.loseConnection()
 | 
			
		||||
        if self.factory._debug_log:
 | 
			
		||||
            log.msg("transitFailed %r" % self)
 | 
			
		||||
 | 
			
		||||
    def disconnect_redundant(self):
 | 
			
		||||
        # this is called if a buddy connected and we were found unnecessary.
 | 
			
		||||
        # Any token-tracking cleanup will have been done before we're called.
 | 
			
		||||
        self._mood = "redundant"
 | 
			
		||||
        self.transport.loseConnection()
 | 
			
		||||
        self._state.got_bytes(data)
 | 
			
		||||
 | 
			
		||||
    def connectionLost(self, reason):
 | 
			
		||||
        finished = time.time()
 | 
			
		||||
        total_time = finished - self._started
 | 
			
		||||
        self._state.connection_lost()
 | 
			
		||||
 | 
			
		||||
        # Record usage. There are eight cases:
 | 
			
		||||
        # * n0: we haven't gotten a full handshake yet (empty)
 | 
			
		||||
        # * n1: the handshake failed, not a real client (errory)
 | 
			
		||||
        # * n2: real client disconnected before any buddy appeared (lonely)
 | 
			
		||||
        # * n3: real client closed as redundant after buddy appears (redundant)
 | 
			
		||||
        # * n4: real client connected first, buddy closes first (jilted)
 | 
			
		||||
        # * n5: real client connected first, buddy close last (happy)
 | 
			
		||||
        # * n6: real client connected last, buddy closes first (jilted)
 | 
			
		||||
        # * n7: real client connected last, buddy closes last (happy)
 | 
			
		||||
 | 
			
		||||
        # * non-connected clients (0,1,2,3) always write a usage record
 | 
			
		||||
        # * for connected clients, whoever disconnects first gets to write the
 | 
			
		||||
        #   usage record (5, 7). The last disconnect doesn't write a record.
 | 
			
		||||
class Transit(object):
 | 
			
		||||
    """
 | 
			
		||||
    I manage pairs of simultaneous connections to a secondary TCP port,
 | 
			
		||||
    both forwarded to the other. Clients must begin each connection with
 | 
			
		||||
    "please relay TOKEN for SIDE\n" (or a legacy form without the "for
 | 
			
		||||
    SIDE"). Two connections match if they use the same TOKEN and have
 | 
			
		||||
    different SIDEs (the redundant connections are dropped when a match is
 | 
			
		||||
    made). Legacy connections match any with the same TOKEN, ignoring SIDE
 | 
			
		||||
    (so two legacy connections will match each other).
 | 
			
		||||
 | 
			
		||||
        if self._mood == "empty": # 0
 | 
			
		||||
            assert not self._buddy
 | 
			
		||||
            self.factory.recordUsage(self._started, "empty", 0,
 | 
			
		||||
                                     total_time, None)
 | 
			
		||||
        elif self._mood == "errory": # 1
 | 
			
		||||
            assert not self._buddy
 | 
			
		||||
            self.factory.recordUsage(self._started, "errory", 0,
 | 
			
		||||
                                     total_time, None)
 | 
			
		||||
        elif self._mood == "redundant": # 3
 | 
			
		||||
            assert not self._buddy
 | 
			
		||||
            self.factory.recordUsage(self._started, "redundant", 0,
 | 
			
		||||
                                     total_time, None)
 | 
			
		||||
        elif self._mood == "jilted": # 4 or 6
 | 
			
		||||
            # we were connected, but our buddy hung up on us. They record the
 | 
			
		||||
            # usage event, we do not
 | 
			
		||||
            pass
 | 
			
		||||
        elif self._mood == "lonely": # 2
 | 
			
		||||
            assert not self._buddy
 | 
			
		||||
            self.factory.recordUsage(self._started, "lonely", 0,
 | 
			
		||||
                                     total_time, None)
 | 
			
		||||
        else: # 5 or 7
 | 
			
		||||
            # we were connected, we hung up first. We record the event.
 | 
			
		||||
            assert self._mood == "happy", self._mood
 | 
			
		||||
            assert self._buddy
 | 
			
		||||
            starts = [self._started, self._buddy._started]
 | 
			
		||||
            total_time = finished - min(starts)
 | 
			
		||||
            waiting_time = max(starts) - min(starts)
 | 
			
		||||
            total_bytes = self._total_sent + self._buddy._total_sent
 | 
			
		||||
            self.factory.recordUsage(self._started, "happy", total_bytes,
 | 
			
		||||
                                     total_time, waiting_time)
 | 
			
		||||
    I will send "ok\n" when the matching connection is established, or
 | 
			
		||||
    disconnect if no matching connection is made within MAX_WAIT_TIME
 | 
			
		||||
    seconds. I will disconnect if you send data before the "ok\n". All data
 | 
			
		||||
    you get after the "ok\n" will be from the other side. You will not
 | 
			
		||||
    receive "ok\n" until the other side has also connected and submitted a
 | 
			
		||||
    matching token (and differing SIDE).
 | 
			
		||||
 | 
			
		||||
        if self._buddy:
 | 
			
		||||
            self._buddy.buddy_disconnected()
 | 
			
		||||
        self.factory.transitFinished(self, self._got_token, self._got_side,
 | 
			
		||||
                                     self.describeToken())
 | 
			
		||||
    In addition, the connections will be dropped after MAXLENGTH bytes have
 | 
			
		||||
    been sent by either side, or MAXTIME seconds have elapsed after the
 | 
			
		||||
    matching connections were established. A future API will reveal these
 | 
			
		||||
    limits to clients instead of causing mysterious spontaneous failures.
 | 
			
		||||
 | 
			
		||||
class Transit(protocol.ServerFactory):
 | 
			
		||||
    # I manage pairs of simultaneous connections to a secondary TCP port,
 | 
			
		||||
    # both forwarded to the other. Clients must begin each connection with
 | 
			
		||||
    # "please relay TOKEN for SIDE\n" (or a legacy form without the "for
 | 
			
		||||
    # SIDE"). Two connections match if they use the same TOKEN and have
 | 
			
		||||
    # different SIDEs (the redundant connections are dropped when a match is
 | 
			
		||||
    # made). Legacy connections match any with the same TOKEN, ignoring SIDE
 | 
			
		||||
    # (so two legacy connections will match each other).
 | 
			
		||||
 | 
			
		||||
    # I will send "ok\n" when the matching connection is established, or
 | 
			
		||||
    # disconnect if no matching connection is made within MAX_WAIT_TIME
 | 
			
		||||
    # seconds. I will disconnect if you send data before the "ok\n". All data
 | 
			
		||||
    # you get after the "ok\n" will be from the other side. You will not
 | 
			
		||||
    # receive "ok\n" until the other side has also connected and submitted a
 | 
			
		||||
    # matching token (and differing SIDE).
 | 
			
		||||
 | 
			
		||||
    # In addition, the connections will be dropped after MAXLENGTH bytes have
 | 
			
		||||
    # been sent by either side, or MAXTIME seconds have elapsed after the
 | 
			
		||||
    # matching connections were established. A future API will reveal these
 | 
			
		||||
    # limits to clients instead of causing mysterious spontaneous failures.
 | 
			
		||||
 | 
			
		||||
    # These relay connections are not half-closeable (unlike full TCP
 | 
			
		||||
    # connections, applications will not receive any data after half-closing
 | 
			
		||||
    # their outgoing side). Applications must negotiate shutdown with their
 | 
			
		||||
    # peer and not close the connection until all data has finished
 | 
			
		||||
    # transferring in both directions. Applications which only need to send
 | 
			
		||||
    # data in one direction can use close() as usual.
 | 
			
		||||
    These relay connections are not half-closeable (unlike full TCP
 | 
			
		||||
    connections, applications will not receive any data after half-closing
 | 
			
		||||
    their outgoing side). Applications must negotiate shutdown with their
 | 
			
		||||
    peer and not close the connection until all data has finished
 | 
			
		||||
    transferring in both directions. Applications which only need to send
 | 
			
		||||
    data in one direction can use close() as usual.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    # TODO: unused
 | 
			
		||||
    MAX_WAIT_TIME = 30*SECONDS
 | 
			
		||||
    # TODO: unused
 | 
			
		||||
    MAXLENGTH = 10*MB
 | 
			
		||||
    # TODO: unused
 | 
			
		||||
    MAXTIME = 60*SECONDS
 | 
			
		||||
    protocol = TransitConnection
 | 
			
		||||
 | 
			
		||||
    def __init__(self, blur_usage, log_file, usage_db):
 | 
			
		||||
        self._blur_usage = blur_usage
 | 
			
		||||
        self._log_requests = blur_usage is None
 | 
			
		||||
        if self._blur_usage:
 | 
			
		||||
            log.msg("blurring access times to %d seconds" % self._blur_usage)
 | 
			
		||||
            log.msg("not logging Transit connections to Twisted log")
 | 
			
		||||
        else:
 | 
			
		||||
            log.msg("not blurring access times")
 | 
			
		||||
        self._debug_log = False
 | 
			
		||||
        self._log_file = log_file
 | 
			
		||||
        self._db = None
 | 
			
		||||
        if usage_db:
 | 
			
		||||
            self._db = get_db(usage_db)
 | 
			
		||||
        self._rebooted = time.time()
 | 
			
		||||
        # we don't track TransitConnections until they submit a token
 | 
			
		||||
        self._pending_requests = defaultdict(set) # token -> set((side, TransitConnection))
 | 
			
		||||
        self._active_connections = set() # TransitConnection
 | 
			
		||||
    def __init__(self, usage, get_timestamp):
 | 
			
		||||
        self.active_connections = ActiveConnections()
 | 
			
		||||
        self.pending_requests = PendingRequests(self.active_connections)
 | 
			
		||||
        self.usage = usage
 | 
			
		||||
        self._timestamp = get_timestamp
 | 
			
		||||
        self._rebooted = self._timestamp()
 | 
			
		||||
 | 
			
		||||
    def connection_got_token(self, token, new_side, new_tc):
 | 
			
		||||
        potentials = self._pending_requests[token]
 | 
			
		||||
        for old in potentials:
 | 
			
		||||
            (old_side, old_tc) = old
 | 
			
		||||
            if ((old_side is None)
 | 
			
		||||
                or (new_side is None)
 | 
			
		||||
                or (old_side != new_side)):
 | 
			
		||||
                # we found a match
 | 
			
		||||
                if self._debug_log:
 | 
			
		||||
                    log.msg("transit relay 2: %s" % new_tc.describeToken())
 | 
			
		||||
 | 
			
		||||
                # drop and stop tracking the rest
 | 
			
		||||
                potentials.remove(old)
 | 
			
		||||
                for (_, leftover_tc) in potentials.copy():
 | 
			
		||||
                    # Don't record this as errory. It's just a spare connection
 | 
			
		||||
                    # from the same side as a connection that got used. This
 | 
			
		||||
                    # can happen if the connection hint contains multiple
 | 
			
		||||
                    # addresses (we don't currently support those, but it'd
 | 
			
		||||
                    # probably be useful in the future).
 | 
			
		||||
                    leftover_tc.disconnect_redundant()
 | 
			
		||||
                self._pending_requests.pop(token, None)
 | 
			
		||||
 | 
			
		||||
                # glue the two ends together
 | 
			
		||||
                self._active_connections.add(new_tc)
 | 
			
		||||
                self._active_connections.add(old_tc)
 | 
			
		||||
                new_tc.buddy_connected(old_tc)
 | 
			
		||||
                old_tc.buddy_connected(new_tc)
 | 
			
		||||
                return
 | 
			
		||||
        if self._debug_log:
 | 
			
		||||
            log.msg("transit relay 1: %s" % new_tc.describeToken())
 | 
			
		||||
        potentials.add((new_side, new_tc))
 | 
			
		||||
        # TODO: timer
 | 
			
		||||
 | 
			
		||||
    def transitFinished(self, tc, token, side, description):
 | 
			
		||||
        if token in self._pending_requests:
 | 
			
		||||
            side_tc = (side, tc)
 | 
			
		||||
            self._pending_requests[token].discard(side_tc)
 | 
			
		||||
            if not self._pending_requests[token]: # set is now empty
 | 
			
		||||
                del self._pending_requests[token]
 | 
			
		||||
        if self._debug_log:
 | 
			
		||||
            log.msg("transitFinished %s" % (description,))
 | 
			
		||||
        self._active_connections.discard(tc)
 | 
			
		||||
        # we could update the usage database "current" row immediately, or wait
 | 
			
		||||
        # until the 5-minute timer updates it. If we update it now, just after
 | 
			
		||||
        # losing a connection, we should probably also update it just after
 | 
			
		||||
        # establishing one (at the end of connection_got_token). For now I'm
 | 
			
		||||
        # going to omit these, but maybe someday we'll turn them both on. The
 | 
			
		||||
        # consequence is that a manual execution of the munin scripts ("munin
 | 
			
		||||
        # run wormhole_transit_active") will give the wrong value just after a
 | 
			
		||||
        # connect/disconnect event. Actual munin graphs should accurately
 | 
			
		||||
        # report connections that last longer than the 5-minute sampling
 | 
			
		||||
        # window, which is what we actually care about.
 | 
			
		||||
        #self.timerUpdateStats()
 | 
			
		||||
 | 
			
		||||
    def recordUsage(self, started, result, total_bytes,
 | 
			
		||||
                    total_time, waiting_time):
 | 
			
		||||
        if self._debug_log:
 | 
			
		||||
            log.msg(format="Transit.recordUsage {bytes}B", bytes=total_bytes)
 | 
			
		||||
        if self._blur_usage:
 | 
			
		||||
            started = self._blur_usage * (started // self._blur_usage)
 | 
			
		||||
            total_bytes = blur_size(total_bytes)
 | 
			
		||||
        if self._log_file is not None:
 | 
			
		||||
            data = {"started": started,
 | 
			
		||||
                    "total_time": total_time,
 | 
			
		||||
                    "waiting_time": waiting_time,
 | 
			
		||||
                    "total_bytes": total_bytes,
 | 
			
		||||
                    "mood": result,
 | 
			
		||||
                    }
 | 
			
		||||
            self._log_file.write(json.dumps(data)+"\n")
 | 
			
		||||
            self._log_file.flush()
 | 
			
		||||
        if self._db:
 | 
			
		||||
            self._db.execute("INSERT INTO `usage`"
 | 
			
		||||
                             " (`started`, `total_time`, `waiting_time`,"
 | 
			
		||||
                             "  `total_bytes`, `result`)"
 | 
			
		||||
                             " VALUES (?,?,?, ?,?)",
 | 
			
		||||
                             (started, total_time, waiting_time,
 | 
			
		||||
                              total_bytes, result))
 | 
			
		||||
            self._update_stats()
 | 
			
		||||
            self._db.commit()
 | 
			
		||||
 | 
			
		||||
    def timerUpdateStats(self):
 | 
			
		||||
        if self._db:
 | 
			
		||||
            self._update_stats()
 | 
			
		||||
            self._db.commit()
 | 
			
		||||
 | 
			
		||||
    def _update_stats(self):
 | 
			
		||||
        # current status: should be zero when idle
 | 
			
		||||
        rebooted = self._rebooted
 | 
			
		||||
        updated = time.time()
 | 
			
		||||
        connected = len(self._active_connections) / 2
 | 
			
		||||
    def update_stats(self):
 | 
			
		||||
        # TODO: when a connection is half-closed, len(active) will be odd. a
 | 
			
		||||
        # moment later (hopefully) the other side will disconnect, but
 | 
			
		||||
        # _update_stats isn't updated until later.
 | 
			
		||||
        waiting = len(self._pending_requests)
 | 
			
		||||
 | 
			
		||||
        # "waiting" doesn't count multiple parallel connections from the same
 | 
			
		||||
        # side
 | 
			
		||||
        incomplete_bytes = sum(tc._total_sent
 | 
			
		||||
                               for tc in self._active_connections)
 | 
			
		||||
        self._db.execute("DELETE FROM `current`")
 | 
			
		||||
        self._db.execute("INSERT INTO `current`"
 | 
			
		||||
                         " (`rebooted`, `updated`, `connected`, `waiting`,"
 | 
			
		||||
                         "  `incomplete_bytes`)"
 | 
			
		||||
                         " VALUES (?, ?, ?, ?, ?)",
 | 
			
		||||
                         (rebooted, updated, connected, waiting,
 | 
			
		||||
                          incomplete_bytes))
 | 
			
		||||
        self.usage.update_stats(
 | 
			
		||||
            rebooted=self._rebooted,
 | 
			
		||||
            updated=self._timestamp(),
 | 
			
		||||
            connected=len(self.active_connections._connections),
 | 
			
		||||
            waiting=len(self.pending_requests._requests),
 | 
			
		||||
            incomplete_bytes=sum(
 | 
			
		||||
                tc._total_sent
 | 
			
		||||
                for tc in self.active_connections._connections
 | 
			
		||||
            ),
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@implementer(ITransitClient)
 | 
			
		||||
class WebSocketTransitConnection(WebSocketServerProtocol):
 | 
			
		||||
    started_time = None
 | 
			
		||||
 | 
			
		||||
    def send(self, data):
 | 
			
		||||
        """
 | 
			
		||||
        ITransitClient API
 | 
			
		||||
        """
 | 
			
		||||
        self.sendMessage(data, isBinary=True)
 | 
			
		||||
 | 
			
		||||
    def disconnect(self):
 | 
			
		||||
        """
 | 
			
		||||
        ITransitClient API
 | 
			
		||||
        """
 | 
			
		||||
        self.sendClose(1000, None)
 | 
			
		||||
 | 
			
		||||
    def connect_partner(self, other):
 | 
			
		||||
        """
 | 
			
		||||
        ITransitClient API
 | 
			
		||||
        """
 | 
			
		||||
        self._buddy = other
 | 
			
		||||
 | 
			
		||||
    def disconnect_partner(self):
 | 
			
		||||
        """
 | 
			
		||||
        ITransitClient API
 | 
			
		||||
        """
 | 
			
		||||
        assert self._buddy is not None, "internal error: no buddy"
 | 
			
		||||
        if self.factory.log_requests:
 | 
			
		||||
            log.msg("buddy_disconnected {}".format(self._buddy.get_token()))
 | 
			
		||||
        self._buddy._client.disconnect()
 | 
			
		||||
        self._buddy = None
 | 
			
		||||
 | 
			
		||||
    def connectionMade(self):
 | 
			
		||||
        """
 | 
			
		||||
        IProtocol API
 | 
			
		||||
        """
 | 
			
		||||
        super(WebSocketTransitConnection, self).connectionMade()
 | 
			
		||||
        self.started_time = time.time()
 | 
			
		||||
        self._first_message = True
 | 
			
		||||
        self._state = TransitServerState(
 | 
			
		||||
            self.factory.transit.pending_requests,
 | 
			
		||||
            self.factory.transit.usage,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        # uncomment to turn on state-machine tracing
 | 
			
		||||
        # def tracer(oldstate, theinput, newstate):
 | 
			
		||||
        #    print("WSTRACE: {}: {} --{}--> {}".format(id(self), oldstate, theinput, newstate))
 | 
			
		||||
        # self._state.set_trace_function(tracer)
 | 
			
		||||
 | 
			
		||||
    def onOpen(self):
 | 
			
		||||
        self._state.connection_made(self)
 | 
			
		||||
 | 
			
		||||
    def onMessage(self, payload, isBinary):
 | 
			
		||||
        """
 | 
			
		||||
        We may have a 'handshake' on our hands or we may just have some bytes to relay
 | 
			
		||||
        """
 | 
			
		||||
        if not isBinary:
 | 
			
		||||
            raise ValueError(
 | 
			
		||||
                "All messages must be binary"
 | 
			
		||||
            )
 | 
			
		||||
        if self._first_message:
 | 
			
		||||
            self._first_message = False
 | 
			
		||||
            token = None
 | 
			
		||||
            old = re.search(br"^please relay (\w{64})$", payload)
 | 
			
		||||
            if old:
 | 
			
		||||
                token = old.group(1)
 | 
			
		||||
                self._state.please_relay(token)
 | 
			
		||||
 | 
			
		||||
            # new: "please relay {64} for side {16}\n"
 | 
			
		||||
            new = re.search(br"^please relay (\w{64}) for side (\w{16})$", payload)
 | 
			
		||||
            if new:
 | 
			
		||||
                token = new.group(1)
 | 
			
		||||
                side = new.group(2)
 | 
			
		||||
                self._state.please_relay_for_side(token, side)
 | 
			
		||||
 | 
			
		||||
            if token is None:
 | 
			
		||||
                self._state.bad_token()
 | 
			
		||||
        else:
 | 
			
		||||
            self._state.got_bytes(payload)
 | 
			
		||||
 | 
			
		||||
    def onClose(self, wasClean, code, reason):
 | 
			
		||||
        """
 | 
			
		||||
        IWebSocketChannel API
 | 
			
		||||
        """
 | 
			
		||||
        self._state.connection_lost()
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
							
								
								
									
										238
									
								
								src/wormhole_transit_relay/usage.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										238
									
								
								src/wormhole_transit_relay/usage.py
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,238 @@
 | 
			
		|||
import time
 | 
			
		||||
import json
 | 
			
		||||
 | 
			
		||||
from twisted.python import log
 | 
			
		||||
from zope.interface import (
 | 
			
		||||
    implementer,
 | 
			
		||||
    Interface,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def create_usage_tracker(blur_usage, log_file, usage_db):
 | 
			
		||||
    """
 | 
			
		||||
    :param int blur_usage: see UsageTracker
 | 
			
		||||
 | 
			
		||||
    :param log_file: None or a file-like object to write JSON-encoded
 | 
			
		||||
        lines of usage information to.
 | 
			
		||||
 | 
			
		||||
    :param usage_db: None or an sqlite3 database connection
 | 
			
		||||
 | 
			
		||||
    :returns: a new UsageTracker instance configured with backends.
 | 
			
		||||
    """
 | 
			
		||||
    tracker = UsageTracker(blur_usage)
 | 
			
		||||
    if usage_db:
 | 
			
		||||
        tracker.add_backend(DatabaseUsageRecorder(usage_db))
 | 
			
		||||
    if log_file:
 | 
			
		||||
        tracker.add_backend(LogFileUsageRecorder(log_file))
 | 
			
		||||
    return tracker
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class IUsageWriter(Interface):
 | 
			
		||||
    """
 | 
			
		||||
    Records actual usage statistics in some way
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def record_usage(started=None, total_time=None, waiting_time=None, total_bytes=None, mood=None):
 | 
			
		||||
        """
 | 
			
		||||
        :param int started: timestemp when this connection began
 | 
			
		||||
 | 
			
		||||
        :param float total_time: total seconds this connection lasted
 | 
			
		||||
 | 
			
		||||
        :param float waiting_time: None or the total seconds one side
 | 
			
		||||
            waited for the other
 | 
			
		||||
 | 
			
		||||
        :param int total_bytes: the total bytes sent. In case the
 | 
			
		||||
            connection was concluded successfully, only one side will
 | 
			
		||||
            record the total bytes (but count both).
 | 
			
		||||
 | 
			
		||||
        :param str mood: the 'mood' of the connection
 | 
			
		||||
        """
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@implementer(IUsageWriter)
 | 
			
		||||
class MemoryUsageRecorder:
 | 
			
		||||
    """
 | 
			
		||||
    Remebers usage records in memory.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def __init__(self):
 | 
			
		||||
        self.events = []
 | 
			
		||||
 | 
			
		||||
    def record_usage(self, started=None, total_time=None, waiting_time=None, total_bytes=None, mood=None):
 | 
			
		||||
        """
 | 
			
		||||
        IUsageWriter.
 | 
			
		||||
        """
 | 
			
		||||
        data = {
 | 
			
		||||
            "started": started,
 | 
			
		||||
            "total_time": total_time,
 | 
			
		||||
            "waiting_time": waiting_time,
 | 
			
		||||
            "total_bytes": total_bytes,
 | 
			
		||||
            "mood": mood,
 | 
			
		||||
        }
 | 
			
		||||
        self.events.append(data)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@implementer(IUsageWriter)
 | 
			
		||||
class LogFileUsageRecorder:
 | 
			
		||||
    """
 | 
			
		||||
    Writes usage records to a file. The records are written in JSON,
 | 
			
		||||
    one record per line.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def __init__(self, writable_file):
 | 
			
		||||
        self._file = writable_file
 | 
			
		||||
 | 
			
		||||
    def record_usage(self, started=None, total_time=None, waiting_time=None, total_bytes=None, mood=None):
 | 
			
		||||
        """
 | 
			
		||||
        IUsageWriter.
 | 
			
		||||
        """
 | 
			
		||||
        data = {
 | 
			
		||||
            "started": started,
 | 
			
		||||
            "total_time": total_time,
 | 
			
		||||
            "waiting_time": waiting_time,
 | 
			
		||||
            "total_bytes": total_bytes,
 | 
			
		||||
            "mood": mood,
 | 
			
		||||
        }
 | 
			
		||||
        self._file.write(json.dumps(data) + "\n")
 | 
			
		||||
        self._file.flush()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@implementer(IUsageWriter)
 | 
			
		||||
class DatabaseUsageRecorder:
 | 
			
		||||
    """
 | 
			
		||||
    Write usage records into a database
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def __init__(self, db):
 | 
			
		||||
        self._db = db
 | 
			
		||||
 | 
			
		||||
    def record_usage(self, started=None, total_time=None, waiting_time=None, total_bytes=None, mood=None):
 | 
			
		||||
        """
 | 
			
		||||
        IUsageWriter.
 | 
			
		||||
        """
 | 
			
		||||
        self._db.execute(
 | 
			
		||||
            "INSERT INTO `usage`"
 | 
			
		||||
            " (`started`, `total_time`, `waiting_time`,"
 | 
			
		||||
            "  `total_bytes`, `result`)"
 | 
			
		||||
            " VALUES (?,?,?,?,?)",
 | 
			
		||||
            (started, total_time, waiting_time, total_bytes, mood)
 | 
			
		||||
        )
 | 
			
		||||
        # original code did "self._update_stats()" here, thus causing
 | 
			
		||||
        # "global" stats update on every connection update .. should
 | 
			
		||||
        # we repeat this behavior, or really only record every
 | 
			
		||||
        # 60-seconds with the timer?
 | 
			
		||||
        self._db.commit()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class UsageTracker(object):
 | 
			
		||||
    """
 | 
			
		||||
    Tracks usage statistics of connections
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def __init__(self, blur_usage):
 | 
			
		||||
        """
 | 
			
		||||
        :param int blur_usage: None or the number of seconds to use as a
 | 
			
		||||
            window around which to blur time statistics (e.g. "60" means times
 | 
			
		||||
            will be rounded to 1 minute intervals). When blur_usage is
 | 
			
		||||
            non-zero, sizes will also be rounded into buckets of "one
 | 
			
		||||
            megabyte", "one gigabyte" or "lots"
 | 
			
		||||
        """
 | 
			
		||||
        self._backends = set()
 | 
			
		||||
        self._blur_usage = blur_usage
 | 
			
		||||
        if blur_usage:
 | 
			
		||||
            log.msg("blurring access times to %d seconds" % self._blur_usage)
 | 
			
		||||
        else:
 | 
			
		||||
            log.msg("not blurring access times")
 | 
			
		||||
 | 
			
		||||
    def add_backend(self, backend):
 | 
			
		||||
        """
 | 
			
		||||
        Add a new backend.
 | 
			
		||||
 | 
			
		||||
        :param IUsageWriter backend: the backend to add
 | 
			
		||||
        """
 | 
			
		||||
        self._backends.add(backend)
 | 
			
		||||
 | 
			
		||||
    def record(self, started, buddy_started, result, bytes_sent, buddy_bytes):
 | 
			
		||||
        """
 | 
			
		||||
        :param int started: timestamp when our connection started
 | 
			
		||||
 | 
			
		||||
        :param int buddy_started: None, or the timestamp when our
 | 
			
		||||
            partner's connection started (will be None if we don't yet
 | 
			
		||||
            have a partner).
 | 
			
		||||
 | 
			
		||||
        :param str result: a label for the result of the connection
 | 
			
		||||
            (one of the "moods").
 | 
			
		||||
 | 
			
		||||
        :param int bytes_sent: number of bytes we sent
 | 
			
		||||
 | 
			
		||||
        :param int buddy_bytes: number of bytes our partner sent
 | 
			
		||||
        """
 | 
			
		||||
        # ideally self._reactor.seconds() or similar, but ..
 | 
			
		||||
        finished = time.time()
 | 
			
		||||
        if buddy_started is not None:
 | 
			
		||||
            starts = [started, buddy_started]
 | 
			
		||||
            total_time = finished - min(starts)
 | 
			
		||||
            waiting_time = max(starts) - min(starts)
 | 
			
		||||
            total_bytes = bytes_sent + buddy_bytes
 | 
			
		||||
        else:
 | 
			
		||||
            total_time = finished - started
 | 
			
		||||
            waiting_time = None
 | 
			
		||||
            total_bytes = bytes_sent
 | 
			
		||||
            # note that "bytes_sent" should always be 0 here, but
 | 
			
		||||
            # we're recording what the state-machine remembered in any
 | 
			
		||||
            # case
 | 
			
		||||
 | 
			
		||||
        if self._blur_usage:
 | 
			
		||||
            started = self._blur_usage * (started // self._blur_usage)
 | 
			
		||||
            total_bytes = blur_size(total_bytes)
 | 
			
		||||
 | 
			
		||||
        # This is "a dict" instead of "kwargs" because we have to make
 | 
			
		||||
        # it into a dict for the log use-case and in-memory/testing
 | 
			
		||||
        # use-case anyway so this is less repeats of the names.
 | 
			
		||||
        self._notify_backends({
 | 
			
		||||
            "started": started,
 | 
			
		||||
            "total_time": total_time,
 | 
			
		||||
            "waiting_time": waiting_time,
 | 
			
		||||
            "total_bytes": total_bytes,
 | 
			
		||||
            "mood": result,
 | 
			
		||||
        })
 | 
			
		||||
 | 
			
		||||
    def update_stats(self, rebooted, updated, connected, waiting,
 | 
			
		||||
                     incomplete_bytes):
 | 
			
		||||
        """
 | 
			
		||||
        Update general statistics.
 | 
			
		||||
        """
 | 
			
		||||
        # in original code, this is only recorded in the database
 | 
			
		||||
        # .. perhaps a better way to do this, but ..
 | 
			
		||||
        for backend in self._backends:
 | 
			
		||||
            if isinstance(backend, DatabaseUsageRecorder):
 | 
			
		||||
                backend._db.execute("DELETE FROM `current`")
 | 
			
		||||
                backend._db.execute(
 | 
			
		||||
                    "INSERT INTO `current`"
 | 
			
		||||
                    " (`rebooted`, `updated`, `connected`, `waiting`,"
 | 
			
		||||
                    "  `incomplete_bytes`)"
 | 
			
		||||
                    " VALUES (?, ?, ?, ?, ?)",
 | 
			
		||||
                    (int(rebooted), int(updated), connected, waiting,
 | 
			
		||||
                     incomplete_bytes)
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
    def _notify_backends(self, data):
 | 
			
		||||
        """
 | 
			
		||||
        Internal helper. Tell every backend we have about a new usage record.
 | 
			
		||||
        """
 | 
			
		||||
        for backend in self._backends:
 | 
			
		||||
            backend.record_usage(**data)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def round_to(size, coarseness):
 | 
			
		||||
    return int(coarseness*(1+int((size-1)/coarseness)))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def blur_size(size):
 | 
			
		||||
    if size == 0:
 | 
			
		||||
        return 0
 | 
			
		||||
    if size < 1e6:
 | 
			
		||||
        return round_to(size, 10e3)
 | 
			
		||||
    if size < 1e9:
 | 
			
		||||
        return round_to(size, 1e6)
 | 
			
		||||
    return round_to(size, 100e6)
 | 
			
		||||
							
								
								
									
										82
									
								
								ws_client.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										82
									
								
								ws_client.py
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,82 @@
 | 
			
		|||
"""
 | 
			
		||||
This is a test-client for the transit-relay that uses WebSockets.
 | 
			
		||||
 | 
			
		||||
If an additional command-line argument (anything) is added, it will
 | 
			
		||||
send 5 messages upon connection. Otherwise, it just prints out what is
 | 
			
		||||
received. Uses a fixed token of 64 'a' characters. Always connects on
 | 
			
		||||
localhost:4002
 | 
			
		||||
"""
 | 
			
		||||
 | 
			
		||||
import sys
 | 
			
		||||
 | 
			
		||||
from twisted.internet import endpoints
 | 
			
		||||
from twisted.internet.defer import (
 | 
			
		||||
    Deferred,
 | 
			
		||||
    inlineCallbacks,
 | 
			
		||||
)
 | 
			
		||||
from twisted.internet.task import react, deferLater
 | 
			
		||||
 | 
			
		||||
from autobahn.twisted.websocket import (
 | 
			
		||||
    WebSocketClientProtocol,
 | 
			
		||||
    WebSocketClientFactory,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class RelayEchoClient(WebSocketClientProtocol):
 | 
			
		||||
 | 
			
		||||
    def onOpen(self):
 | 
			
		||||
        self._received = b""
 | 
			
		||||
        self.sendMessage(
 | 
			
		||||
            u"please relay {} for side {}".format(
 | 
			
		||||
                self.factory.token,
 | 
			
		||||
                self.factory.side,
 | 
			
		||||
            ).encode("ascii"),
 | 
			
		||||
            True,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def onMessage(self, data, isBinary):
 | 
			
		||||
        print(">onMessage: {} bytes".format(len(data)))
 | 
			
		||||
        print(data, isBinary)
 | 
			
		||||
        if data == b"ok\n":
 | 
			
		||||
            self.factory.ready.callback(None)
 | 
			
		||||
        else:
 | 
			
		||||
            self._received += data
 | 
			
		||||
            if False:
 | 
			
		||||
                # test abrupt hangup from receiving side
 | 
			
		||||
                self.transport.loseConnection()
 | 
			
		||||
 | 
			
		||||
    def onClose(self, wasClean, code, reason):
 | 
			
		||||
        print(">onClose", wasClean, code, reason)
 | 
			
		||||
        self.factory.done.callback(reason)
 | 
			
		||||
        if not self.factory.ready.called:
 | 
			
		||||
            self.factory.ready.errback(RuntimeError(reason))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@react
 | 
			
		||||
@inlineCallbacks
 | 
			
		||||
def main(reactor):
 | 
			
		||||
    will_send_message = len(sys.argv) > 1
 | 
			
		||||
    ep = endpoints.clientFromString(reactor, "tcp:localhost:4002")
 | 
			
		||||
    f = WebSocketClientFactory("ws://127.0.0.1:4002/")
 | 
			
		||||
    f.reactor = reactor
 | 
			
		||||
    f.protocol = RelayEchoClient
 | 
			
		||||
    f.token = "a" * 64
 | 
			
		||||
    f.side = "0" * 16 if will_send_message else "1" * 16
 | 
			
		||||
    f.done = Deferred()
 | 
			
		||||
    f.ready = Deferred()
 | 
			
		||||
 | 
			
		||||
    proto = yield ep.connect(f)
 | 
			
		||||
    print("proto", proto)
 | 
			
		||||
    yield f.ready
 | 
			
		||||
 | 
			
		||||
    print("ready")
 | 
			
		||||
    if will_send_message:
 | 
			
		||||
        for _ in range(5):
 | 
			
		||||
            print("sending message")
 | 
			
		||||
            proto.sendMessage(b"it's a message", True)
 | 
			
		||||
            yield deferLater(reactor, 0.2)
 | 
			
		||||
        yield proto.sendClose()
 | 
			
		||||
        print("closing")
 | 
			
		||||
    yield f.done
 | 
			
		||||
    print("relayed {} bytes:".format(len(proto._received)))
 | 
			
		||||
    print(proto._received.decode("utf8"))
 | 
			
		||||
		Loading…
	
		Reference in New Issue
	
	Block a user