diff --git a/client.py b/client.py new file mode 100644 index 0000000..5c7d235 --- /dev/null +++ b/client.py @@ -0,0 +1,54 @@ +""" +This is a test-client for the transit-relay that uses TCP. It replies +"ding" to the relay's "ok" and otherwise only prints received data. Uses a +fixed token of 64 'a' characters. Always connects on localhost:4001 +""" + + +from twisted.internet import endpoints +from twisted.internet.defer import ( + Deferred, +) +from twisted.internet.task import react +from twisted.internet.error import ( + ConnectionDone, +) +from twisted.internet.protocol import ( + Protocol, + Factory, +) + + +class RelayEchoClient(Protocol): + """ + Speaks the version1 magic wormhole transit relay protocol (as a client) + """ + + def connectionMade(self): + print(">CONNECT") + self.data = b"" + self.transport.write(u"please relay {}\n".format(self.factory.token).encode("ascii")) + + def dataReceived(self, data): + print(">RECV {} bytes".format(len(data))) + print(data.decode("ascii")) + self.data += data + if data == b"ok\n": # data is bytes; a str literal would never compare equal + self.transport.write(b"ding\n") # transports only accept bytes + + def connectionLost(self, reason): + if isinstance(reason.value, ConnectionDone): + self.factory.done.callback(None) + else: + print(">DISCONNCT: {}".format(reason)) + self.factory.done.callback(reason) + + +@react +def main(reactor): + ep = endpoints.clientFromString(reactor, "tcp:localhost:4001") + f = Factory.forProtocol(RelayEchoClient) + f.token = "a" * 64 + f.done = Deferred() + ep.connect(f) + return f.done diff --git a/docs/running.md b/docs/running.md index 5ea0601..5908584 100644 --- a/docs/running.md +++ b/docs/running.md @@ -50,6 +50,15 @@ The relevant arguments are: * ``--usage-db=``: maintains a SQLite database with current and historical usage data * ``--blur-usage=``: round logged timestamps and data sizes +For WebSockets support, two additional arguments: + +* ``--websocket``: the endpoint to listen for websocket connections + on, like ``tcp:4002`` +* ``--websocket-url``: the URL of the WebSocket
connection. This may + be different from the listening endpoint because of port-forwarding + and so forth. By default it will be ``ws://localhost:<port>/`` if not + provided + When you use ``twist``, the relay runs in the foreground, so it will generally exit as soon as the controlling terminal exits. For persistent environments, you should daemonize the server. diff --git a/setup.py b/setup.py index 92c87c9..7119506 100644 --- a/setup.py +++ b/setup.py @@ -18,7 +18,8 @@ setup(name="magic-wormhole-transit-relay", ], package_data={"wormhole_transit_relay": ["db-schemas/*.sql"]}, install_requires=[ - "twisted >= 17.5.0", + "twisted >= 21.2.0", + "autobahn >= 21.3.1", ], extras_require={ ':sys_platform=="win32"': ["pypiwin32"], diff --git a/src/wormhole_transit_relay/server_state.py b/src/wormhole_transit_relay/server_state.py new file mode 100644 index 0000000..6018252 --- /dev/null +++ b/src/wormhole_transit_relay/server_state.py @@ -0,0 +1,477 @@ +from collections import defaultdict + +import automat +from twisted.python import log +from zope.interface import ( + Interface, + Attribute, +) + + +class ITransitClient(Interface): + """ + Represents the client side of a connection to this transit + relay. This is used by TransitServerState instances. + """ + + started_time = Attribute("timestamp when the connection was established") + + def send(data): + """ + Send some bytes to the client + """ + + def disconnect(): + """ + Disconnect the client transport + """ + + def connect_partner(other): + """ + Hook up to our partner. + :param ITransitClient other: our partner + """ + + def disconnect_partner(): + """ + Disconnect our partner's transport + """ + + +class ActiveConnections(object): + """ + Tracks active connections. + + A connection is 'active' when both sides have shown up and they + are glued together (and thus could be passing data back and forth + if any is flowing).
+ """ + def __init__(self): + self._connections = set() + + def register(self, side0, side1): + """ + A connection has become active so register both its sides + + :param TransitConnection side0: one side of the connection + :param TransitConnection side1: one side of the connection + """ + self._connections.add(side0) + self._connections.add(side1) + + def unregister(self, side): + """ + One side of a connection has become inactive. + + :param TransitConnection side: an inactive side of a connection + """ + self._connections.discard(side) + + +class PendingRequests(object): + """ + Tracks outstanding (non-"active") requests. + + We register client connections against the tokens we have + received. When the other side shows up we can thus match it to the + correct partner connection. At this point, the connection becomes + "active" is and is thus no longer "pending" and so will no longer + be in this collection. + """ + + def __init__(self, active_connections): + """ + :param active_connections: an instance of ActiveConnections where + connections are put when both sides arrive. + """ + self._requests = defaultdict(set) # token -> set((side, TransitConnection)) + self._active = active_connections + + def unregister(self, token, side, tc): + """ + We no longer care about a particular client (e.g. it has + disconnected). + """ + if token in self._requests: + self._requests[token].discard((side, tc)) + if not self._requests[token]: + # no more sides; token is dead + del self._requests[token] + self._active.unregister(tc) + + def register(self, token, new_side, new_tc): + """ + A client has connected and successfully offered a token (and + optional 'side' token). If this is the first one for this + token, we merely remember it. If it is the second side for + this token we connect them together. + + :param bytes token: the token for this connection. 
+ + :param bytes new_side: None or the side token for this connection + + :param TransitServerState new_tc: the state-machine of the connection + + :returns bool: True if we are the first side to register this + token + """ + potentials = self._requests[token] + for old in potentials: + (old_side, old_tc) = old + if ((old_side is None) + or (new_side is None) + or (old_side != new_side)): + # we found a match + + # drop and stop tracking the rest + potentials.remove(old) + for (_, leftover_tc) in potentials.copy(): + # Don't record this as errory. It's just a spare connection + # from the same side as a connection that got used. This + # can happen if the connection hint contains multiple + # addresses (we don't currently support those, but it'd + # probably be useful in the future). + leftover_tc.partner_connection_lost() + self._requests.pop(token, None) + + # glue the two ends together + self._active.register(new_tc, old_tc) + new_tc.got_partner(old_tc) + old_tc.got_partner(new_tc) + return False + + potentials.add((new_side, new_tc)) + return True + # TODO: timer + + +class TransitServerState(object): + """ + Encapsulates the state-machine of the server side of a transit + relay connection. + + Once the protocol has been told to relay (or to relay for a side) + it starts passing all received bytes to the other side until it + closes. + """ + + _machine = automat.MethodicalMachine() + _client = None + _buddy = None + _token = None + _side = None + _first = None + _mood = "empty" + _total_sent = 0 + + def __init__(self, pending_requests, usage_recorder): + self._pending_requests = pending_requests + self._usage = usage_recorder + + def get_token(self): + """ + :returns str: a string describing our token. This will be "-" if + we have no token yet, or "{16 chars}-" if we have + just a token or "{16 chars}-{16 chars}" if we have a token and + a side. 
+ """ + d = "-" + if self._token is not None: + d = self._token[:16].decode("ascii") + + if self._side is not None: + d += "-" + self._side.decode("ascii") + else: + d += "-" + return d + + @_machine.input() + def connection_made(self, client): + """ + A client has connected. May only be called once. + + :param ITransitClient client: our client. + """ + # NB: the "only called once" is enforced by the state-machine; + # this input is only valid for the "listening" state, to which + # we never return. + + @_machine.input() + def please_relay(self, token): + """ + A 'please relay X' message has been received (the original version + of the protocol). + """ + + @_machine.input() + def please_relay_for_side(self, token, side): + """ + A 'please relay X for side Y' message has been received (the + second version of the protocol). + """ + + @_machine.input() + def bad_token(self): + """ + A bad token / relay line was received (e.g. couldn't be parsed) + """ + + @_machine.input() + def got_partner(self, client): + """ + The partner for this relay session has been found + """ + + @_machine.input() + def connection_lost(self): + """ + Our transport has failed. + """ + + @_machine.input() + def partner_connection_lost(self): + """ + Our partner's transport has failed. + """ + + @_machine.input() + def got_bytes(self, data): + """ + Some bytes have arrived (that aren't part of the handshake) + """ + + @_machine.output() + def _remember_client(self, client): + self._client = client + + # note that there is no corresponding "_forget_client" because we + # may still want to access it after it is gone .. 
for example, to + # get the .started_time for logging purposes + + @_machine.output() + def _register_token(self, token): + return self._real_register_token_for_side(token, None) + + @_machine.output() + def _register_token_for_side(self, token, side): + return self._real_register_token_for_side(token, side) + + @_machine.output() + def _unregister(self): + """ + remove us from the thing that remembers tokens and sides + """ + return self._pending_requests.unregister(self._token, self._side, self) + + @_machine.output() + def _send_bad(self): + self._mood = "errory" + self._client.send(b"bad handshake\n") + if self._client.factory.log_requests: + log.msg("transit handshake failure") + + @_machine.output() + def _send_ok(self): + self._client.send(b"ok\n") + + @_machine.output() + def _send_impatient(self): + self._client.send(b"impatient\n") + if self._client.factory.log_requests: + log.msg("transit impatience failure") + + @_machine.output() + def _count_bytes(self, data): + self._total_sent += len(data) + + @_machine.output() + def _send_to_partner(self, data): + self._buddy._client.send(data) + + @_machine.output() + def _connect_partner(self, client): + self._buddy = client + self._client.connect_partner(client) + + @_machine.output() + def _disconnect(self): + self._client.disconnect() + + @_machine.output() + def _disconnect_partner(self): + self._client.disconnect_partner() + + # some outputs to record "usage" information .. + @_machine.output() + def _record_usage(self): + if self._mood == "jilted": + if self._buddy and self._buddy._mood == "happy": + return + self._usage.record( + started=self._client.started_time, + buddy_started=self._buddy._client.started_time if self._buddy is not None else None, + result=self._mood, + bytes_sent=self._total_sent, + buddy_bytes=self._buddy._total_sent if self._buddy is not None else None + ) + + # some outputs to record the "mood" .. 
+ @_machine.output() + def _mood_happy(self): + self._mood = "happy" + + @_machine.output() + def _mood_lonely(self): + self._mood = "lonely" + + @_machine.output() + def _mood_redundant(self): + self._mood = "redundant" + + @_machine.output() + def _mood_impatient(self): + self._mood = "impatient" + + @_machine.output() + def _mood_errory(self): + self._mood = "errory" + + @_machine.output() + def _mood_happy_if_first(self): + """ + We disconnected first so we're only happy if we also connected + first. + """ + if self._first: + self._mood = "happy" + else: + self._mood = "jilted" + + def _real_register_token_for_side(self, token, side): + """ + A client has connected and sent a valid version 1 or version 2 + handshake. If the former, `side` will be None. + + In either case, we remember the tokens and register + ourselves. This might result in 'got_partner' notifications to + two state-machines if this is the second side for a given token. + + :param bytes token: the token + :param bytes side: The side token (or None) + """ + self._token = token + self._side = side + self._first = self._pending_requests.register(token, side, self) + + @_machine.state(initial=True) + def listening(self): + """ + Initial state, awaiting connection. 
+ """ + + @_machine.state() + def wait_relay(self): + """ + Waiting for a 'relay' message + """ + + @_machine.state() + def wait_partner(self): + """ + Waiting for our partner to connect + """ + + @_machine.state() + def relaying(self): + """ + Relaying bytes to our partner + """ + + @_machine.state() + def done(self): + """ + Terminal state + """ + + listening.upon( + connection_made, + enter=wait_relay, + outputs=[_remember_client], + ) + listening.upon( + connection_lost, + enter=done, + outputs=[_mood_errory], + ) + + wait_relay.upon( + please_relay, + enter=wait_partner, + outputs=[_mood_lonely, _register_token], + ) + wait_relay.upon( + please_relay_for_side, + enter=wait_partner, + outputs=[_mood_lonely, _register_token_for_side], + ) + wait_relay.upon( + bad_token, + enter=done, + outputs=[_mood_errory, _send_bad, _disconnect, _record_usage], + ) + wait_relay.upon( + got_bytes, + enter=done, + outputs=[_count_bytes, _mood_errory, _disconnect, _record_usage], + ) + wait_relay.upon( + connection_lost, + enter=done, + outputs=[_disconnect, _record_usage], + ) + + wait_partner.upon( + got_partner, + enter=relaying, + outputs=[_mood_happy, _send_ok, _connect_partner], + ) + wait_partner.upon( + connection_lost, + enter=done, + outputs=[_mood_lonely, _unregister, _record_usage], + ) + wait_partner.upon( + got_bytes, + enter=done, + outputs=[_mood_impatient, _send_impatient, _disconnect, _unregister, _record_usage], + ) + wait_partner.upon( + partner_connection_lost, + enter=done, + outputs=[_mood_redundant, _disconnect, _record_usage], + ) + + relaying.upon( + got_bytes, + enter=relaying, + outputs=[_count_bytes, _send_to_partner], + ) + relaying.upon( + connection_lost, + enter=done, + outputs=[_mood_happy_if_first, _disconnect_partner, _unregister, _record_usage], + ) + + done.upon( + connection_lost, + enter=done, + outputs=[], + ) + done.upon( + partner_connection_lost, + enter=done, + outputs=[], + ) + + # uncomment to turn on state-machine tracing + # 
set_trace_function = _machine._setTrace diff --git a/src/wormhole_transit_relay/server_tap.py b/src/wormhole_transit_relay/server_tap.py index 8fbfde2..0db3ef6 100644 --- a/src/wormhole_transit_relay/server_tap.py +++ b/src/wormhole_transit_relay/server_tap.py @@ -5,8 +5,14 @@ from twisted.application.service import MultiService from twisted.application.internet import (TimerService, StreamServerEndpointService) from twisted.internet import endpoints +from twisted.internet import protocol + +from autobahn.twisted.websocket import WebSocketServerFactory + from . import transit_server +from .usage import create_usage_tracker from .increase_rlimits import increase_rlimits +from .database import get_db LONGDESC = """\ This plugin sets up a 'Transit Relay' server for magic-wormhole. This service @@ -20,6 +26,8 @@ class Options(usage.Options): optParameters = [ ("port", "p", "tcp:4001:interface=\:\:", "endpoint to listen on"), + ("websocket", "w", None, "endpoint to listen for WebSocket connections"), + ("websocket-url", "u", None, "WebSocket URL (derived from endpoint if not provided)"), ("blur-usage", None, None, "blur timestamps and data sizes in logs"), ("log-fd", None, None, "write JSON usage logs to this file descriptor"), ("usage-db", None, None, "record usage data (SQLite)"), @@ -31,14 +39,45 @@ class Options(usage.Options): def makeService(config, reactor=reactor): increase_rlimits() - ep = endpoints.serverFromString(reactor, config["port"]) # to listen - log_file = (os.fdopen(int(config["log-fd"]), "w") - if config["log-fd"] is not None - else None) - f = transit_server.Transit(blur_usage=config["blur-usage"], - log_file=log_file, - usage_db=config["usage-db"]) + tcp_ep = endpoints.serverFromString(reactor, config["port"]) # to listen + ws_ep = ( + endpoints.serverFromString(reactor, config["websocket"]) + if config["websocket"] is not None + else None + ) + log_file = ( + os.fdopen(int(config["log-fd"]), "w") + if config["log-fd"] is not None + else None + ) + 
db = None if config["usage-db"] is None else get_db(config["usage-db"]) + usage = create_usage_tracker( + blur_usage=config["blur-usage"], + log_file=log_file, + usage_db=db, + ) + transit = transit_server.Transit(usage, reactor.seconds) + tcp_factory = protocol.ServerFactory() + tcp_factory.protocol = transit_server.TransitConnection + tcp_factory.log_requests = False + + if ws_ep is not None: + ws_url = config["websocket-url"] + if ws_url is None: + # we're using a "private" attribute here but I don't see + # any useful alternative unless we also want to parse + # Twisted endpoint-strings. + ws_url = "ws://localhost:{}/".format(ws_ep._port) + print("Using WebSocket URL '{}'".format(ws_url)) + ws_factory = WebSocketServerFactory(ws_url) + ws_factory.protocol = transit_server.WebSocketTransitConnection + ws_factory.transit = transit + ws_factory.log_requests = False + + tcp_factory.transit = transit parent = MultiService() - StreamServerEndpointService(ep, f).setServiceParent(parent) - TimerService(5*60.0, f.timerUpdateStats).setServiceParent(parent) + StreamServerEndpointService(tcp_ep, tcp_factory).setServiceParent(parent) + if ws_ep is not None: + StreamServerEndpointService(ws_ep, ws_factory).setServiceParent(parent) + TimerService(5*60.0, transit.update_stats).setServiceParent(parent) return parent diff --git a/src/wormhole_transit_relay/test/common.py b/src/wormhole_transit_relay/test/common.py index 8073ee0..4b2469f 100644 --- a/src/wormhole_transit_relay/test/common.py +++ b/src/wormhole_transit_relay/test/common.py @@ -10,7 +10,10 @@ from zope.interface import ( ) from ..transit_server import ( Transit, + TransitConnection, ) +from twisted.internet.protocol import ServerFactory +from ..usage import create_usage_tracker class IRelayTestClient(Interface): @@ -42,6 +45,7 @@ class IRelayTestClient(Interface): Erase any received data to this point. 
""" + class ServerBase: log_requests = False @@ -62,19 +66,30 @@ class ServerBase: self.flush() def _setup_relay(self, blur_usage=None, log_file=None, usage_db=None): - self._transit_server = Transit( + usage = create_usage_tracker( blur_usage=blur_usage, log_file=log_file, usage_db=usage_db, ) - self._transit_server._debug_log = self.log_requests + self._transit_server = Transit(usage, lambda: 123456789.0) def new_protocol(self): + """ + This should be overridden by derived test-case classes to decide + if they want a TCP or WebSockets protocol. + """ + raise NotImplementedError() + + def new_protocol_tcp(self): """ Create a new client protocol connected to the server. :returns: a IRelayTestClient implementation """ - server_protocol = self._transit_server.buildProtocol(('127.0.0.1', 0)) + server_factory = ServerFactory() + server_factory.protocol = TransitConnection + server_factory.transit = self._transit_server + server_factory.log_requests = self.log_requests + server_protocol = server_factory.buildProtocol(('127.0.0.1', 0)) @implementer(IRelayTestClient) class TransitClientProtocolTcp(Protocol): diff --git a/src/wormhole_transit_relay/test/test_config.py b/src/wormhole_transit_relay/test/test_config.py index b27ffd5..b2bb7e8 100644 --- a/src/wormhole_transit_relay/test/test_config.py +++ b/src/wormhole_transit_relay/test/test_config.py @@ -8,12 +8,29 @@ class Config(unittest.TestCase): o = server_tap.Options() o.parseOptions([]) self.assertEqual(o, {"blur-usage": None, "log-fd": None, - "usage-db": None, "port": PORT}) + "usage-db": None, "port": PORT, + "websocket": None, "websocket-url": None}) def test_blur(self): o = server_tap.Options() o.parseOptions(["--blur-usage=60"]) self.assertEqual(o, {"blur-usage": 60, "log-fd": None, - "usage-db": None, "port": PORT}) + "usage-db": None, "port": PORT, + "websocket": None, "websocket-url": None}) + + def test_websocket(self): + o = server_tap.Options() + o.parseOptions(["--websocket=tcp:4004"]) + 
self.assertEqual(o, {"blur-usage": None, "log-fd": None, + "usage-db": None, "port": PORT, + "websocket": "tcp:4004", "websocket-url": None}) + + def test_websocket_url(self): + o = server_tap.Options() + o.parseOptions(["--websocket=tcp:4004", "--websocket-url=ws://example.com/"]) + self.assertEqual(o, {"blur-usage": None, "log-fd": None, + "usage-db": None, "port": PORT, + "websocket": "tcp:4004", + "websocket-url": "ws://example.com/"}) def test_string(self): o = server_tap.Options() diff --git a/src/wormhole_transit_relay/test/test_service.py b/src/wormhole_transit_relay/test/test_service.py index 1532f56..9ab30c8 100644 --- a/src/wormhole_transit_relay/test/test_service.py +++ b/src/wormhole_transit_relay/test/test_service.py @@ -1,13 +1,14 @@ from twisted.trial import unittest from unittest import mock from twisted.application.service import MultiService +from autobahn.twisted.websocket import WebSocketServerFactory from .. import server_tap class Service(unittest.TestCase): def test_defaults(self): o = server_tap.Options() o.parseOptions([]) - with mock.patch("wormhole_transit_relay.server_tap.transit_server.Transit") as t: + with mock.patch("wormhole_transit_relay.server_tap.create_usage_tracker") as t: s = server_tap.makeService(o) self.assertEqual(t.mock_calls, [mock.call(blur_usage=None, @@ -17,7 +18,7 @@ class Service(unittest.TestCase): def test_blur(self): o = server_tap.Options() o.parseOptions(["--blur-usage=60"]) - with mock.patch("wormhole_transit_relay.server_tap.transit_server.Transit") as t: + with mock.patch("wormhole_transit_relay.server_tap.create_usage_tracker") as t: server_tap.makeService(o) self.assertEqual(t.mock_calls, [mock.call(blur_usage=60, @@ -27,7 +28,7 @@ class Service(unittest.TestCase): o = server_tap.Options() o.parseOptions(["--log-fd=99"]) fd = object() - with mock.patch("wormhole_transit_relay.server_tap.transit_server.Transit") as t: + with mock.patch("wormhole_transit_relay.server_tap.create_usage_tracker") as t: with 
mock.patch("wormhole_transit_relay.server_tap.os.fdopen", return_value=fd) as f: server_tap.makeService(o) @@ -36,3 +37,34 @@ class Service(unittest.TestCase): [mock.call(blur_usage=None, log_file=fd, usage_db=None)]) + def test_websocket(self): + """ + A websocket factory is created when passing --websocket + """ + o = server_tap.Options() + o.parseOptions(["--websocket=tcp:4004"]) + services = server_tap.makeService(o) + self.assertTrue( + any( + isinstance(s.factory, WebSocketServerFactory) + for s in services.services + ) + ) + + def test_websocket_explicit_url(self): + """ + A websocket factory is created with --websocket and + --websocket-url + """ + o = server_tap.Options() + o.parseOptions([ + "--websocket=tcp:4004", + "--websocket-url=ws://example.com:4004", + ]) + services = server_tap.makeService(o) + self.assertTrue( + any( + isinstance(s.factory, WebSocketServerFactory) + for s in services.services + ) + ) diff --git a/src/wormhole_transit_relay/test/test_stats.py b/src/wormhole_transit_relay/test/test_stats.py index be17d91..3f85071 100644 --- a/src/wormhole_transit_relay/test/test_stats.py +++ b/src/wormhole_transit_relay/test/test_stats.py @@ -1,27 +1,38 @@ -import os, io, json, sqlite3 +import os, io, json from unittest import mock from twisted.trial import unittest from ..transit_server import Transit +from ..usage import create_usage_tracker from .. 
import database class DB(unittest.TestCase): - def open_db(self, dbfile): - db = sqlite3.connect(dbfile) - database._initialize_db_connection(db) - return db def test_db(self): + T = 1519075308.0 + + class Timer: + t = T + def __call__(self): + return self.t + get_time = Timer() + d = self.mktemp() os.mkdir(d) usage_db = os.path.join(d, "usage.sqlite") - with mock.patch("time.time", return_value=T+0): - t = Transit(blur_usage=None, log_file=None, usage_db=usage_db) - db = self.open_db(usage_db) + db = database.get_db(usage_db) + t = Transit( + create_usage_tracker(blur_usage=None, log_file=None, usage_db=db), + get_time, + ) + self.assertEqual(len(t.usage._backends), 1) + usage = list(t.usage._backends)[0] + + get_time.t = T + 1 + usage.record_usage(started=123, mood="happy", total_bytes=100, + total_time=10, waiting_time=2) + t.update_stats() - with mock.patch("time.time", return_value=T+1): - t.recordUsage(started=123, result="happy", total_bytes=100, - total_time=10, waiting_time=2) self.assertEqual(db.execute("SELECT * FROM `usage`").fetchall(), [dict(result="happy", started=123, total_bytes=100, total_time=10, waiting_time=2), @@ -31,9 +42,10 @@ class DB(unittest.TestCase): incomplete_bytes=0, waiting=0, connected=0)) - with mock.patch("time.time", return_value=T+2): - t.recordUsage(started=150, result="errory", total_bytes=200, - total_time=11, waiting_time=3) + get_time.t = T + 2 + usage.record_usage(started=150, mood="errory", total_bytes=200, + total_time=11, waiting_time=3) + t.update_stats() self.assertEqual(db.execute("SELECT * FROM `usage`").fetchall(), [dict(result="happy", started=123, total_bytes=100, total_time=10, waiting_time=2), @@ -45,27 +57,37 @@ class DB(unittest.TestCase): incomplete_bytes=0, waiting=0, connected=0)) - with mock.patch("time.time", return_value=T+3): - t.timerUpdateStats() + get_time.t = T + 3 + t.update_stats() self.assertEqual(db.execute("SELECT * FROM `current`").fetchone(), dict(rebooted=T+0, updated=T+3, 
incomplete_bytes=0, waiting=0, connected=0)) def test_no_db(self): - t = Transit(blur_usage=None, log_file=None, usage_db=None) + t = Transit( + create_usage_tracker(blur_usage=None, log_file=None, usage_db=None), + lambda: 0, + ) + self.assertEqual(0, len(t.usage._backends)) - t.recordUsage(started=123, result="happy", total_bytes=100, - total_time=10, waiting_time=2) - t.timerUpdateStats() class LogToStdout(unittest.TestCase): def test_log(self): # emit lines of JSON to log_file, if set log_file = io.StringIO() - t = Transit(blur_usage=None, log_file=log_file, usage_db=None) - t.recordUsage(started=123, result="happy", total_bytes=100, - total_time=10, waiting_time=2) + t = Transit( + create_usage_tracker(blur_usage=None, log_file=log_file, usage_db=None), + lambda: 0, + ) + with mock.patch("time.time", return_value=133): + t.usage.record( + started=123, + buddy_started=125, + result="happy", + bytes_sent=100, + buddy_bytes=0, + ) self.assertEqual(json.loads(log_file.getvalue()), {"started": 123, "total_time": 10, "waiting_time": 2, "total_bytes": 100, @@ -75,15 +97,34 @@ class LogToStdout(unittest.TestCase): # if blurring is enabled, timestamps should be rounded to the # requested amount, and sizes should be rounded up too log_file = io.StringIO() - t = Transit(blur_usage=60, log_file=log_file, usage_db=None) - t.recordUsage(started=123, result="happy", total_bytes=11999, - total_time=10, waiting_time=2) + t = Transit( + create_usage_tracker(blur_usage=60, log_file=log_file, usage_db=None), + lambda: 0, + ) + + with mock.patch("time.time", return_value=123 + 10): + t.usage.record( + started=123, + buddy_started=125, + result="happy", + bytes_sent=11999, + buddy_bytes=0, + ) + print(log_file.getvalue()) self.assertEqual(json.loads(log_file.getvalue()), {"started": 120, "total_time": 10, "waiting_time": 2, "total_bytes": 20000, "mood": "happy"}) def test_do_not_log(self): - t = Transit(blur_usage=60, log_file=None, usage_db=None) - t.recordUsage(started=123, 
result="happy", total_bytes=11999, - total_time=10, waiting_time=2) + t = Transit( + create_usage_tracker(blur_usage=60, log_file=None, usage_db=None), + lambda: 0, + ) + t.usage.record( + started=123, + buddy_started=124, + result="happy", + bytes_sent=11999, + buddy_bytes=12, + ) diff --git a/src/wormhole_transit_relay/test/test_transit_server.py b/src/wormhole_transit_relay/test/test_transit_server.py index fddbe14..c5370de 100644 --- a/src/wormhole_transit_relay/test/test_transit_server.py +++ b/src/wormhole_transit_relay/test/test_transit_server.py @@ -1,7 +1,30 @@ from binascii import hexlify from twisted.trial import unittest -from .common import ServerBase -from .. import transit_server +from twisted.test import iosim +from autobahn.twisted.websocket import ( + WebSocketServerFactory, + WebSocketClientFactory, + WebSocketClientProtocol, +) +from autobahn.twisted.testing import ( + create_pumper, + MemoryReactorClockResolver, +) +from autobahn.exception import Disconnected +from zope.interface import implementer +from .common import ( + ServerBase, + IRelayTestClient, +) +from ..usage import ( + MemoryUsageRecorder, + blur_size, +) +from ..transit_server import ( + WebSocketTransitConnection, + TransitServerState, +) + def handshake(token, side=None): hs = b"please relay " + hexlify(token) @@ -12,27 +35,28 @@ def handshake(token, side=None): class _Transit: def count(self): - return sum([len(potentials) - for potentials - in self._transit_server._pending_requests.values()]) + return sum([ + len(potentials) + for potentials + in self._transit_server.pending_requests._requests.values() + ]) def test_blur_size(self): - blur = transit_server.blur_size - self.failUnlessEqual(blur(0), 0) - self.failUnlessEqual(blur(1), 10e3) - self.failUnlessEqual(blur(10e3), 10e3) - self.failUnlessEqual(blur(10e3+1), 20e3) - self.failUnlessEqual(blur(15e3), 20e3) - self.failUnlessEqual(blur(20e3), 20e3) - self.failUnlessEqual(blur(1e6), 1e6) - self.failUnlessEqual(blur(1e6+1), 
2e6) - self.failUnlessEqual(blur(1.5e6), 2e6) - self.failUnlessEqual(blur(2e6), 2e6) - self.failUnlessEqual(blur(900e6), 900e6) - self.failUnlessEqual(blur(1000e6), 1000e6) - self.failUnlessEqual(blur(1050e6), 1100e6) - self.failUnlessEqual(blur(1100e6), 1100e6) - self.failUnlessEqual(blur(1150e6), 1200e6) + self.failUnlessEqual(blur_size(0), 0) + self.failUnlessEqual(blur_size(1), 10e3) + self.failUnlessEqual(blur_size(10e3), 10e3) + self.failUnlessEqual(blur_size(10e3+1), 20e3) + self.failUnlessEqual(blur_size(15e3), 20e3) + self.failUnlessEqual(blur_size(20e3), 20e3) + self.failUnlessEqual(blur_size(1e6), 1e6) + self.failUnlessEqual(blur_size(1e6+1), 2e6) + self.failUnlessEqual(blur_size(1.5e6), 2e6) + self.failUnlessEqual(blur_size(2e6), 2e6) + self.failUnlessEqual(blur_size(900e6), 900e6) + self.failUnlessEqual(blur_size(1000e6), 1000e6) + self.failUnlessEqual(blur_size(1050e6), 1100e6) + self.failUnlessEqual(blur_size(1100e6), 1100e6) + self.failUnlessEqual(blur_size(1150e6), 1200e6) def test_register(self): p1 = self.new_protocol() @@ -49,7 +73,7 @@ class _Transit: self.assertEqual(self.count(), 0) # the token should be removed too - self.assertEqual(len(self._transit_server._pending_requests), 0) + self.assertEqual(len(self._transit_server.pending_requests._requests), 0) def test_both_unsided(self): p1 = self.new_protocol() @@ -75,7 +99,6 @@ class _Transit: self.assertEqual(p2.get_received_data(), s1) p1.disconnect() - p2.disconnect() self.flush() def test_sided_unsided(self): @@ -104,7 +127,6 @@ class _Transit: self.assertEqual(p2.get_received_data(), s1) p1.disconnect() - p2.disconnect() self.flush() def test_unsided_sided(self): @@ -177,6 +199,7 @@ class _Transit: p2.send(handshake(token1, side=side1)) self.flush() + self.flush() self.assertEqual(self.count(), 2) # same-side connections don't match # when the second side arrives, the spare first connection should be @@ -185,8 +208,8 @@ class _Transit: p3.send(handshake(token1, side=side2)) self.flush() 
self.assertEqual(self.count(), 0) - self.assertEqual(len(self._transit_server._pending_requests), 0) - self.assertEqual(len(self._transit_server._active_connections), 2) + self.assertEqual(len(self._transit_server.pending_requests._requests), 0) + self.assertEqual(len(self._transit_server.active_connections._connections), 2) # That will trigger a disconnect on exactly one of (p1 or p2). # The other connection should still be connected self.assertEqual(sum([int(t.connected) for t in [p1, p2]]), 1) @@ -266,7 +289,8 @@ class _Transit: token1 = b"\x00"*32 # sending too many bytes is impatience. - p1.send(b"please relay " + hexlify(token1) + b"\nNOWNOWNOW") + p1.send(b"please relay " + hexlify(token1)) + p1.send(b"\nNOWNOWNOW") self.flush() exp = b"impatient\n" @@ -281,7 +305,8 @@ class _Transit: side1 = b"\x01"*8 # sending too many bytes is impatience. p1.send(b"please relay " + hexlify(token1) + - b" for side " + hexlify(side1) + b"\nNOWNOWNOW") + b" for side " + hexlify(side1)) + p1.send(b"\nNOWNOWNOW") self.flush() exp = b"impatient\n" @@ -327,22 +352,163 @@ class _Transit: # hang up before sending anything p1.disconnect() + class TransitWithLogs(_Transit, ServerBase, unittest.TestCase): log_requests = True + def new_protocol(self): + return self.new_protocol_tcp() + + class TransitWithoutLogs(_Transit, ServerBase, unittest.TestCase): log_requests = False + def new_protocol(self): + return self.new_protocol_tcp() + + +def _new_protocol_ws(transit_server, log_requests): + """ + Internal helper for test-suites that need to provide WebSocket + client/server pairs. 
+ + :returns: a 2-tuple: (iosim.IOPump, protocol) + """ + ws_factory = WebSocketServerFactory("ws://localhost:4002") + ws_factory.protocol = WebSocketTransitConnection + ws_factory.transit = transit_server + ws_factory.log_requests = log_requests + ws_protocol = ws_factory.buildProtocol(('127.0.0.1', 0)) + + @implementer(IRelayTestClient) + class TransitWebSocketClientProtocol(WebSocketClientProtocol): + _received = b"" + connected = False + + def connectionMade(self): + self.connected = True + return super(TransitWebSocketClientProtocol, self).connectionMade() + + def connectionLost(self, reason): + self.connected = False + return super(TransitWebSocketClientProtocol, self).connectionLost(reason) + + def onMessage(self, data, isBinary): + self._received = self._received + data + + def send(self, data): + self.sendMessage(data, True) + + def get_received_data(self): + return self._received + + def reset_received_data(self): + self._received = b"" + + def disconnect(self): + self.sendClose(1000, True) + + client_factory = WebSocketClientFactory() + client_factory.protocol = TransitWebSocketClientProtocol + client_protocol = client_factory.buildProtocol(('127.0.0.1', 31337)) + client_protocol.disconnect = client_protocol.dropConnection + + pump = iosim.connect( + ws_protocol, + iosim.makeFakeServer(ws_protocol), + client_protocol, + iosim.makeFakeClient(client_protocol), + ) + return pump, client_protocol + + + +class TransitWebSockets(_Transit, ServerBase, unittest.TestCase): + + def new_protocol(self): + return self.new_protocol_ws() + + def new_protocol_ws(self): + pump, proto = _new_protocol_ws(self._transit_server, self.log_requests) + self._pumps.append(pump) + return proto + + def test_websocket_to_tcp(self): + """ + One client is WebSocket and one is TCP + """ + p1 = self.new_protocol_ws() + p2 = self.new_protocol_tcp() + + token1 = b"\x00"*32 + side1 = b"\x01"*8 + side2 = b"\x02"*8 + p1.send(handshake(token1, side=side1)) + self.flush() + 
p2.send(handshake(token1, side=side2)) + self.flush() + + # a correct handshake yields an ack, after which we can send + exp = b"ok\n" + self.assertEqual(p1.get_received_data(), exp) + self.assertEqual(p2.get_received_data(), exp) + + p1.reset_received_data() + p2.reset_received_data() + + # all data they sent after the handshake should be given to us + s1 = b"data1" + p1.send(s1) + self.flush() + self.assertEqual(p2.get_received_data(), s1) + + p1.disconnect() + p2.disconnect() + self.flush() + + def test_bad_handshake_old_slow(self): + """ + This test only makes sense for TCP + """ + + def test_send_closed_partner(self): + """ + Sending data to a closed partner causes an error that propagates + to the sender. + """ + p1 = self.new_protocol() + p2 = self.new_protocol() + + # set up a successful connection + token = b"a" * 32 + p1.send(handshake(token)) + p2.send(handshake(token)) + self.flush() + + # p2 loses connection, then p1 sends a message + p2.transport.loseConnection() + self.flush() + + # at this point, p1 learns that p2 is disconnected (because it + # tried to relay "a message" but failed) + + # try to send more (our partner p2 is gone now though so it + # should be an immediate error) + with self.assertRaises(Disconnected): + p1.send(b"more message") + self.flush() + + class Usage(ServerBase, unittest.TestCase): log_requests = True def setUp(self): super(Usage, self).setUp() - self._usage = [] - def record(started, result, total_bytes, total_time, waiting_time): - self._usage.append((started, result, total_bytes, - total_time, waiting_time)) - self._transit_server.recordUsage = record + self._usage = MemoryUsageRecorder() + self._transit_server.usage.add_backend(self._usage) + + def new_protocol(self): + return self.new_protocol_tcp() def test_empty(self): p1 = self.new_protocol() @@ -351,11 +517,14 @@ class Usage(ServerBase, unittest.TestCase): self.flush() # that will log the "empty" usage event - self.assertEqual(len(self._usage), 1, self._usage) -
(started, result, total_bytes, total_time, waiting_time) = self._usage[0] - self.assertEqual(result, "empty", self._usage) + self.assertEqual(len(self._usage.events), 1, self._usage) + self.assertEqual(self._usage.events[0]["mood"], "empty", self._usage) def test_short(self): + # Note: this test only runs on TCP clients because WebSockets + # already does framing (so it's either "a bad handshake" or + # there's no handshake at all yet .. you can't have a "short" + # one). p1 = self.new_protocol() # hang up before sending a complete handshake p1.send(b"short") @@ -363,9 +532,8 @@ class Usage(ServerBase, unittest.TestCase): self.flush() # that will log the "empty" usage event - self.assertEqual(len(self._usage), 1, self._usage) - (started, result, total_bytes, total_time, waiting_time) = self._usage[0] - self.assertEqual(result, "empty", self._usage) + self.assertEqual(len(self._usage.events), 1, self._usage) + self.assertEqual("empty", self._usage.events[0]["mood"]) def test_errory(self): p1 = self.new_protocol() @@ -374,9 +542,8 @@ class Usage(ServerBase, unittest.TestCase): self.flush() # that will log the "errory" usage event, then drop the connection p1.disconnect() - self.assertEqual(len(self._usage), 1, self._usage) - (started, result, total_bytes, total_time, waiting_time) = self._usage[0] - self.assertEqual(result, "errory", self._usage) + self.assertEqual(len(self._usage.events), 1, self._usage) + self.assertEqual(self._usage.events[0]["mood"], "errory", self._usage) def test_lonely(self): p1 = self.new_protocol() @@ -389,10 +556,9 @@ class Usage(ServerBase, unittest.TestCase): p1.disconnect() self.flush() - self.assertEqual(len(self._usage), 1, self._usage) - (started, result, total_bytes, total_time, waiting_time) = self._usage[0] - self.assertEqual(result, "lonely", self._usage) - self.assertIdentical(waiting_time, None) + self.assertEqual(len(self._usage.events), 1, self._usage) + self.assertEqual(self._usage.events[0]["mood"], "lonely", self._usage) + 
self.assertIdentical(self._usage.events[0]["waiting_time"], None) def test_one_happy_one_jilted(self): p1 = self.new_protocol() @@ -406,7 +572,7 @@ class Usage(ServerBase, unittest.TestCase): p2.send(handshake(token1, side=side2)) self.flush() - self.assertEqual(self._usage, []) # no events yet + self.assertEqual(self._usage.events, []) # no events yet p1.send(b"\x00" * 13) self.flush() @@ -416,11 +582,10 @@ class Usage(ServerBase, unittest.TestCase): p1.disconnect() self.flush() - self.assertEqual(len(self._usage), 1, self._usage) - (started, result, total_bytes, total_time, waiting_time) = self._usage[0] - self.assertEqual(result, "happy", self._usage) - self.assertEqual(total_bytes, 20) - self.assertNotIdentical(waiting_time, None) + self.assertEqual(len(self._usage.events), 1, self._usage) + self.assertEqual(self._usage.events[0]["mood"], "happy", self._usage) + self.assertEqual(self._usage.events[0]["total_bytes"], 20) + self.assertNotIdentical(self._usage.events[0]["waiting_time"], None) def test_redundant(self): p1a = self.new_protocol() @@ -443,21 +608,80 @@ class Usage(ServerBase, unittest.TestCase): p1c.disconnect() self.flush() - self.assertEqual(len(self._usage), 1, self._usage) - (started, result, total_bytes, total_time, waiting_time) = self._usage[0] - self.assertEqual(result, "lonely", self._usage) + self.assertEqual(len(self._usage.events), 1, self._usage) + self.assertEqual(self._usage.events[0]["mood"], "lonely") p2.send(handshake(token1, side=side2)) self.flush() - self.assertEqual(len(self._transit_server._pending_requests), 0) - self.assertEqual(len(self._usage), 2, self._usage) - (started, result, total_bytes, total_time, waiting_time) = self._usage[1] - self.assertEqual(result, "redundant", self._usage) + self.assertEqual(len(self._transit_server.pending_requests._requests), 0) + self.assertEqual(len(self._usage.events), 2, self._usage) + self.assertEqual(self._usage.events[1]["mood"], "redundant") # one of the these is unecessary, but 
probably harmless p1a.disconnect() p1b.disconnect() self.flush() - self.assertEqual(len(self._usage), 3, self._usage) - (started, result, total_bytes, total_time, waiting_time) = self._usage[2] - self.assertEqual(result, "happy", self._usage) + self.assertEqual(len(self._usage.events), 3, self._usage) + self.assertEqual(self._usage.events[2]["mood"], "happy") + + +class UsageWebSockets(Usage): + """ + All the tests of 'Usage' except with a WebSocket (instead of TCP) + transport. + + This overrides ServerBase.new_protocol to achieve this. It might + be nicer to parametrize these tests in a way that doesn't use + inheritance .. but all the support etc classes are set up that way + already. + """ + + def setUp(self): + super(UsageWebSockets, self).setUp() + self._pump = create_pumper() + self._reactor = MemoryReactorClockResolver() + return self._pump.start() + + def tearDown(self): + return self._pump.stop() + + def new_protocol(self): + return self.new_protocol_ws() + + def new_protocol_ws(self): + pump, proto = _new_protocol_ws(self._transit_server, self.log_requests) + self._pumps.append(pump) + return proto + + def test_short(self): + """ + This test essentially just tests the framing of the line-oriented + TCP protocol; it doesn't make sense for the WebSockets case + because WS handles framing: you either sent a 'bad handshake' + because it is semantically invalid or no handshake (yet).
+ """ + + def test_send_non_binary_message(self): + """ + A non-binary WebSocket message is an error + """ + ws_factory = WebSocketServerFactory("ws://localhost:4002") + ws_factory.protocol = WebSocketTransitConnection + ws_protocol = ws_factory.buildProtocol(('127.0.0.1', 0)) + with self.assertRaises(ValueError): + ws_protocol.onMessage(u"foo", isBinary=False) + + +class State(unittest.TestCase): + """ + Tests related to server_state.TransitServerState + """ + + def setUp(self): + self.state = TransitServerState(None, None) + + def test_empty_token(self): + self.assertEqual( + "-", + self.state.get_token(), + ) diff --git a/src/wormhole_transit_relay/transit_server.py b/src/wormhole_transit_relay/transit_server.py index 8b4fbf9..4b7b0b5 100644 --- a/src/wormhole_transit_relay/transit_server.py +++ b/src/wormhole_transit_relay/transit_server.py @@ -1,9 +1,9 @@ -import re, time, json -from collections import defaultdict +import re +import time from twisted.python import log -from twisted.internet import protocol from twisted.protocols.basic import LineReceiver -from .database import get_db +from autobahn.twisted.websocket import WebSocketServerProtocol + SECONDS = 1.0 MINUTE = 60*SECONDS @@ -11,340 +11,254 @@ HOUR = 60*MINUTE DAY = 24*HOUR MB = 1000*1000 -def round_to(size, coarseness): - return int(coarseness*(1+int((size-1)/coarseness))) -def blur_size(size): - if size == 0: - return 0 - if size < 1e6: - return round_to(size, 10e3) - if size < 1e9: - return round_to(size, 1e6) - return round_to(size, 100e6) +from wormhole_transit_relay.server_state import ( + TransitServerState, + PendingRequests, + ActiveConnections, + ITransitClient, +) +from zope.interface import implementer + +@implementer(ITransitClient) class TransitConnection(LineReceiver): delimiter = b'\n' # maximum length of a line we will accept before the handshake is complete. # This must be >= to the longest possible handshake message. 
MAX_LENGTH = 1024 + started_time = None - def __init__(self): - self._got_token = False - self._got_side = False - self._sent_ok = False - self._mood = "empty" + def send(self, data): + """ + ITransitClient API + """ + self.transport.write(data) + + def disconnect(self): + """ + ITransitClient API + """ + self.transport.loseConnection() + + def connect_partner(self, other): + """ + ITransitClient API + """ + self._buddy = other + + def disconnect_partner(self): + """ + ITransitClient API + """ + assert self._buddy is not None, "internal error: no buddy" + if self.factory.log_requests: + log.msg("buddy_disconnected {}".format(self._buddy.get_token())) + self._buddy._client.disconnect() self._buddy = None - self._total_sent = 0 - - def describeToken(self): - d = "-" - if self._got_token: - d = self._got_token[:16].decode("ascii") - if self._got_side: - d += "-" + self._got_side.decode("ascii") - else: - d += "-" - return d def connectionMade(self): - self._started = time.time() - self._log_requests = self.factory._log_requests + # ideally more like self._reactor.seconds() ... 
but Twisted + # doesn't have a good way to get the reactor for a protocol + # (besides "use the global one") + self.started_time = time.time() + self._state = TransitServerState( + self.factory.transit.pending_requests, + self.factory.transit.usage, + ) + self._state.connection_made(self) self.transport.setTcpKeepAlive(True) + # uncomment to turn on state-machine tracing + # def tracer(oldstate, theinput, newstate): + # print("TRACE: {}: {} --{}--> {}".format(id(self), oldstate, theinput, newstate)) + # self._state.set_trace_function(tracer) + def lineReceived(self, line): + """ + LineReceiver API + """ # old: "please relay {64}\n" + token = None old = re.search(br"^please relay (\w{64})$", line) if old: token = old.group(1) - return self._got_handshake(token, None) + self._state.please_relay(token) # new: "please relay {64} for side {16}\n" new = re.search(br"^please relay (\w{64}) for side (\w{16})$", line) if new: token = new.group(1) side = new.group(2) - return self._got_handshake(token, side) + self._state.please_relay_for_side(token, side) - self.sendLine(b"bad handshake") - if self._log_requests: - log.msg("transit handshake failure") - return self.disconnect_error() + if token is None: + self._state.bad_token() + else: + self.setRawMode() def rawDataReceived(self, data): + """ + LineReceiver API + """ # We are an IPushProducer to our buddy's IConsumer, so they'll # throttle us (by calling pauseProducing()) when their outbound # buffer is full (e.g. when their downstream pipe is full). In # practice, this buffers about 10MB per connection, after which # point the sender will only transmit data as fast as the # receiver can handle it. - if self._sent_ok: - # if self._buddy is None then our buddy disconnected - # (we're "jilted"), so we hung up too, but our incoming - # data hasn't stopped yet (it will in a moment, after our - # disconnect makes a roundtrip through the kernel). 
This - # probably means the file receiver hung up, and this - # connection is the file sender. In may-2020 this happened - # 11 times in 40 days. - if self._buddy: - self._total_sent += len(data) - self._buddy.transport.write(data) - return - - # handshake is complete but not yet sent_ok - self.sendLine(b"impatient") - if self._log_requests: - log.msg("transit impatience failure") - return self.disconnect_error() # impatience yields failure - - def _got_handshake(self, token, side): - self._got_token = token - self._got_side = side - self._mood = "lonely" # until buddy connects - self.setRawMode() - self.factory.connection_got_token(token, side, self) - - def buddy_connected(self, them): - self._buddy = them - self._mood = "happy" - self.sendLine(b"ok") - self._sent_ok = True - # Connect the two as a producer/consumer pair. We use streaming=True, - # so this expects the IPushProducer interface, and uses - # pauseProducing() to throttle, and resumeProducing() to unthrottle. - self._buddy.transport.registerProducer(self.transport, True) - # The Transit object calls buddy_connected() on both protocols, so - # there will be two producer/consumer pairs. - - def buddy_disconnected(self): - if self._log_requests: - log.msg("buddy_disconnected %s" % self.describeToken()) - self._buddy = None - self._mood = "jilted" - self.transport.loseConnection() - - def disconnect_error(self): - # we haven't finished the handshake, so there are no tokens tracking - # us - self._mood = "errory" - self.transport.loseConnection() - if self.factory._debug_log: - log.msg("transitFailed %r" % self) - - def disconnect_redundant(self): - # this is called if a buddy connected and we were found unnecessary. - # Any token-tracking cleanup will have been done before we're called. 
- self._mood = "redundant" - self.transport.loseConnection() + self._state.got_bytes(data) def connectionLost(self, reason): - finished = time.time() - total_time = finished - self._started + self._state.connection_lost() - # Record usage. There are eight cases: - # * n0: we haven't gotten a full handshake yet (empty) - # * n1: the handshake failed, not a real client (errory) - # * n2: real client disconnected before any buddy appeared (lonely) - # * n3: real client closed as redundant after buddy appears (redundant) - # * n4: real client connected first, buddy closes first (jilted) - # * n5: real client connected first, buddy close last (happy) - # * n6: real client connected last, buddy closes first (jilted) - # * n7: real client connected last, buddy closes last (happy) - # * non-connected clients (0,1,2,3) always write a usage record - # * for connected clients, whoever disconnects first gets to write the - # usage record (5, 7). The last disconnect doesn't write a record. +class Transit(object): + """ + I manage pairs of simultaneous connections to a secondary TCP port, + both forwarded to the other. Clients must begin each connection with + "please relay TOKEN for SIDE\n" (or a legacy form without the "for + SIDE"). Two connections match if they use the same TOKEN and have + different SIDEs (the redundant connections are dropped when a match is + made). Legacy connections match any with the same TOKEN, ignoring SIDE + (so two legacy connections will match each other). - if self._mood == "empty": # 0 - assert not self._buddy - self.factory.recordUsage(self._started, "empty", 0, - total_time, None) - elif self._mood == "errory": # 1 - assert not self._buddy - self.factory.recordUsage(self._started, "errory", 0, - total_time, None) - elif self._mood == "redundant": # 3 - assert not self._buddy - self.factory.recordUsage(self._started, "redundant", 0, - total_time, None) - elif self._mood == "jilted": # 4 or 6 - # we were connected, but our buddy hung up on us. 
They record the - # usage event, we do not - pass - elif self._mood == "lonely": # 2 - assert not self._buddy - self.factory.recordUsage(self._started, "lonely", 0, - total_time, None) - else: # 5 or 7 - # we were connected, we hung up first. We record the event. - assert self._mood == "happy", self._mood - assert self._buddy - starts = [self._started, self._buddy._started] - total_time = finished - min(starts) - waiting_time = max(starts) - min(starts) - total_bytes = self._total_sent + self._buddy._total_sent - self.factory.recordUsage(self._started, "happy", total_bytes, - total_time, waiting_time) + I will send "ok\n" when the matching connection is established, or + disconnect if no matching connection is made within MAX_WAIT_TIME + seconds. I will disconnect if you send data before the "ok\n". All data + you get after the "ok\n" will be from the other side. You will not + receive "ok\n" until the other side has also connected and submitted a + matching token (and differing SIDE). - if self._buddy: - self._buddy.buddy_disconnected() - self.factory.transitFinished(self, self._got_token, self._got_side, - self.describeToken()) + In addition, the connections will be dropped after MAXLENGTH bytes have + been sent by either side, or MAXTIME seconds have elapsed after the + matching connections were established. A future API will reveal these + limits to clients instead of causing mysterious spontaneous failures. -class Transit(protocol.ServerFactory): - # I manage pairs of simultaneous connections to a secondary TCP port, - # both forwarded to the other. Clients must begin each connection with - # "please relay TOKEN for SIDE\n" (or a legacy form without the "for - # SIDE"). Two connections match if they use the same TOKEN and have - # different SIDEs (the redundant connections are dropped when a match is - # made). Legacy connections match any with the same TOKEN, ignoring SIDE - # (so two legacy connections will match each other). 
- - # I will send "ok\n" when the matching connection is established, or - # disconnect if no matching connection is made within MAX_WAIT_TIME - # seconds. I will disconnect if you send data before the "ok\n". All data - # you get after the "ok\n" will be from the other side. You will not - # receive "ok\n" until the other side has also connected and submitted a - # matching token (and differing SIDE). - - # In addition, the connections will be dropped after MAXLENGTH bytes have - # been sent by either side, or MAXTIME seconds have elapsed after the - # matching connections were established. A future API will reveal these - # limits to clients instead of causing mysterious spontaneous failures. - - # These relay connections are not half-closeable (unlike full TCP - # connections, applications will not receive any data after half-closing - # their outgoing side). Applications must negotiate shutdown with their - # peer and not close the connection until all data has finished - # transferring in both directions. Applications which only need to send - # data in one direction can use close() as usual. + These relay connections are not half-closeable (unlike full TCP + connections, applications will not receive any data after half-closing + their outgoing side). Applications must negotiate shutdown with their + peer and not close the connection until all data has finished + transferring in both directions. Applications which only need to send + data in one direction can use close() as usual. 
+ """ + # TODO: unused MAX_WAIT_TIME = 30*SECONDS + # TODO: unused MAXLENGTH = 10*MB + # TODO: unused MAXTIME = 60*SECONDS - protocol = TransitConnection - def __init__(self, blur_usage, log_file, usage_db): - self._blur_usage = blur_usage - self._log_requests = blur_usage is None - if self._blur_usage: - log.msg("blurring access times to %d seconds" % self._blur_usage) - log.msg("not logging Transit connections to Twisted log") - else: - log.msg("not blurring access times") - self._debug_log = False - self._log_file = log_file - self._db = None - if usage_db: - self._db = get_db(usage_db) - self._rebooted = time.time() - # we don't track TransitConnections until they submit a token - self._pending_requests = defaultdict(set) # token -> set((side, TransitConnection)) - self._active_connections = set() # TransitConnection + def __init__(self, usage, get_timestamp): + self.active_connections = ActiveConnections() + self.pending_requests = PendingRequests(self.active_connections) + self.usage = usage + self._timestamp = get_timestamp + self._rebooted = self._timestamp() - def connection_got_token(self, token, new_side, new_tc): - potentials = self._pending_requests[token] - for old in potentials: - (old_side, old_tc) = old - if ((old_side is None) - or (new_side is None) - or (old_side != new_side)): - # we found a match - if self._debug_log: - log.msg("transit relay 2: %s" % new_tc.describeToken()) - - # drop and stop tracking the rest - potentials.remove(old) - for (_, leftover_tc) in potentials.copy(): - # Don't record this as errory. It's just a spare connection - # from the same side as a connection that got used. This - # can happen if the connection hint contains multiple - # addresses (we don't currently support those, but it'd - # probably be useful in the future). 
- leftover_tc.disconnect_redundant() - self._pending_requests.pop(token, None) - - # glue the two ends together - self._active_connections.add(new_tc) - self._active_connections.add(old_tc) - new_tc.buddy_connected(old_tc) - old_tc.buddy_connected(new_tc) - return - if self._debug_log: - log.msg("transit relay 1: %s" % new_tc.describeToken()) - potentials.add((new_side, new_tc)) - # TODO: timer - - def transitFinished(self, tc, token, side, description): - if token in self._pending_requests: - side_tc = (side, tc) - self._pending_requests[token].discard(side_tc) - if not self._pending_requests[token]: # set is now empty - del self._pending_requests[token] - if self._debug_log: - log.msg("transitFinished %s" % (description,)) - self._active_connections.discard(tc) - # we could update the usage database "current" row immediately, or wait - # until the 5-minute timer updates it. If we update it now, just after - # losing a connection, we should probably also update it just after - # establishing one (at the end of connection_got_token). For now I'm - # going to omit these, but maybe someday we'll turn them both on. The - # consequence is that a manual execution of the munin scripts ("munin - # run wormhole_transit_active") will give the wrong value just after a - # connect/disconnect event. Actual munin graphs should accurately - # report connections that last longer than the 5-minute sampling - # window, which is what we actually care about. 
- #self.timerUpdateStats() - - def recordUsage(self, started, result, total_bytes, - total_time, waiting_time): - if self._debug_log: - log.msg(format="Transit.recordUsage {bytes}B", bytes=total_bytes) - if self._blur_usage: - started = self._blur_usage * (started // self._blur_usage) - total_bytes = blur_size(total_bytes) - if self._log_file is not None: - data = {"started": started, - "total_time": total_time, - "waiting_time": waiting_time, - "total_bytes": total_bytes, - "mood": result, - } - self._log_file.write(json.dumps(data)+"\n") - self._log_file.flush() - if self._db: - self._db.execute("INSERT INTO `usage`" - " (`started`, `total_time`, `waiting_time`," - " `total_bytes`, `result`)" - " VALUES (?,?,?, ?,?)", - (started, total_time, waiting_time, - total_bytes, result)) - self._update_stats() - self._db.commit() - - def timerUpdateStats(self): - if self._db: - self._update_stats() - self._db.commit() - - def _update_stats(self): - # current status: should be zero when idle - rebooted = self._rebooted - updated = time.time() - connected = len(self._active_connections) / 2 + def update_stats(self): # TODO: when a connection is half-closed, len(active) will be odd. a # moment later (hopefully) the other side will disconnect, but # _update_stats isn't updated until later. 
- waiting = len(self._pending_requests) + # "waiting" doesn't count multiple parallel connections from the same # side - incomplete_bytes = sum(tc._total_sent - for tc in self._active_connections) - self._db.execute("DELETE FROM `current`") - self._db.execute("INSERT INTO `current`" - " (`rebooted`, `updated`, `connected`, `waiting`," - " `incomplete_bytes`)" - " VALUES (?, ?, ?, ?, ?)", - (rebooted, updated, connected, waiting, - incomplete_bytes)) + self.usage.update_stats( + rebooted=self._rebooted, + updated=self._timestamp(), + connected=len(self.active_connections._connections), + waiting=len(self.pending_requests._requests), + incomplete_bytes=sum( + tc._total_sent + for tc in self.active_connections._connections + ), + ) + + +@implementer(ITransitClient) +class WebSocketTransitConnection(WebSocketServerProtocol): + started_time = None + + def send(self, data): + """ + ITransitClient API + """ + self.sendMessage(data, isBinary=True) + + def disconnect(self): + """ + ITransitClient API + """ + self.sendClose(1000, None) + + def connect_partner(self, other): + """ + ITransitClient API + """ + self._buddy = other + + def disconnect_partner(self): + """ + ITransitClient API + """ + assert self._buddy is not None, "internal error: no buddy" + if self.factory.log_requests: + log.msg("buddy_disconnected {}".format(self._buddy.get_token())) + self._buddy._client.disconnect() + self._buddy = None + + def connectionMade(self): + """ + IProtocol API + """ + super(WebSocketTransitConnection, self).connectionMade() + self.started_time = time.time() + self._first_message = True + self._state = TransitServerState( + self.factory.transit.pending_requests, + self.factory.transit.usage, + ) + + # uncomment to turn on state-machine tracing + # def tracer(oldstate, theinput, newstate): + # print("WSTRACE: {}: {} --{}--> {}".format(id(self), oldstate, theinput, newstate)) + # self._state.set_trace_function(tracer) + + def onOpen(self): + self._state.connection_made(self) + + def 
onMessage(self, payload, isBinary): + """ + We may have a 'handshake' on our hands or we may just have some bytes to relay + """ + if not isBinary: + raise ValueError( + "All messages must be binary" + ) + if self._first_message: + self._first_message = False + token = None + old = re.search(br"^please relay (\w{64})$", payload) + if old: + token = old.group(1) + self._state.please_relay(token) + + # new: "please relay {64} for side {16}\n" + new = re.search(br"^please relay (\w{64}) for side (\w{16})$", payload) + if new: + token = new.group(1) + side = new.group(2) + self._state.please_relay_for_side(token, side) + + if token is None: + self._state.bad_token() + else: + self._state.got_bytes(payload) + + def onClose(self, wasClean, code, reason): + """ + IWebSocketChannel API + """ + self._state.connection_lost() diff --git a/src/wormhole_transit_relay/usage.py b/src/wormhole_transit_relay/usage.py new file mode 100644 index 0000000..92f8e35 --- /dev/null +++ b/src/wormhole_transit_relay/usage.py @@ -0,0 +1,238 @@ +import time +import json + +from twisted.python import log +from zope.interface import ( + implementer, + Interface, +) + + +def create_usage_tracker(blur_usage, log_file, usage_db): + """ + :param int blur_usage: see UsageTracker + + :param log_file: None or a file-like object to write JSON-encoded + lines of usage information to. + + :param usage_db: None or an sqlite3 database connection + + :returns: a new UsageTracker instance configured with backends. 
+ """ + tracker = UsageTracker(blur_usage) + if usage_db: + tracker.add_backend(DatabaseUsageRecorder(usage_db)) + if log_file: + tracker.add_backend(LogFileUsageRecorder(log_file)) + return tracker + + +class IUsageWriter(Interface): + """ + Records actual usage statistics in some way + """ + + def record_usage(started=None, total_time=None, waiting_time=None, total_bytes=None, mood=None): + """ + :param int started: timestamp when this connection began + + :param float total_time: total seconds this connection lasted + + :param float waiting_time: None or the total seconds one side + waited for the other + + :param int total_bytes: the total bytes sent. In case the + connection was concluded successfully, only one side will + record the total bytes (but count both). + + :param str mood: the 'mood' of the connection + """ + + +@implementer(IUsageWriter) +class MemoryUsageRecorder: + """ + Remembers usage records in memory. + """ + + def __init__(self): + self.events = [] + + def record_usage(self, started=None, total_time=None, waiting_time=None, total_bytes=None, mood=None): + """ + IUsageWriter. + """ + data = { + "started": started, + "total_time": total_time, + "waiting_time": waiting_time, + "total_bytes": total_bytes, + "mood": mood, + } + self.events.append(data) + + +@implementer(IUsageWriter) +class LogFileUsageRecorder: + """ + Writes usage records to a file. The records are written in JSON, + one record per line. + """ + + def __init__(self, writable_file): + self._file = writable_file + + def record_usage(self, started=None, total_time=None, waiting_time=None, total_bytes=None, mood=None): + """ + IUsageWriter.
+ """ + data = { + "started": started, + "total_time": total_time, + "waiting_time": waiting_time, + "total_bytes": total_bytes, + "mood": mood, + } + self._file.write(json.dumps(data) + "\n") + self._file.flush() + + +@implementer(IUsageWriter) +class DatabaseUsageRecorder: + """ + Write usage records into a database + """ + + def __init__(self, db): + self._db = db + + def record_usage(self, started=None, total_time=None, waiting_time=None, total_bytes=None, mood=None): + """ + IUsageWriter. + """ + self._db.execute( + "INSERT INTO `usage`" + " (`started`, `total_time`, `waiting_time`," + " `total_bytes`, `result`)" + " VALUES (?,?,?,?,?)", + (started, total_time, waiting_time, total_bytes, mood) + ) + # original code did "self._update_stats()" here, thus causing + # "global" stats update on every connection update .. should + # we repeat this behavior, or really only record every + # 60-seconds with the timer? + self._db.commit() + + +class UsageTracker(object): + """ + Tracks usage statistics of connections + """ + + def __init__(self, blur_usage): + """ + :param int blur_usage: None or the number of seconds to use as a + window around which to blur time statistics (e.g. "60" means times + will be rounded to 1 minute intervals). When blur_usage is + non-zero, sizes will also be rounded into buckets of "one + megabyte", "one gigabyte" or "lots" + """ + self._backends = set() + self._blur_usage = blur_usage + if blur_usage: + log.msg("blurring access times to %d seconds" % self._blur_usage) + else: + log.msg("not blurring access times") + + def add_backend(self, backend): + """ + Add a new backend. 
class UsageTracker(object):
    """
    Tracks usage statistics of connections.

    Every finished connection becomes one usage record (see ``record``)
    which is handed to each backend registered via ``add_backend``.
    """

    def __init__(self, blur_usage):
        """
        :param int blur_usage: None or the number of seconds to use as a
            window around which to blur time statistics (e.g. "60" means
            start times are rounded down to a multiple of 60 seconds).
            When blur_usage is non-zero, byte counts are also rounded up
            by ``blur_size``.
        """
        self._backends = set()
        self._blur_usage = blur_usage
        if blur_usage:
            log.msg("blurring access times to %d seconds" % self._blur_usage)
        else:
            log.msg("not blurring access times")

    def add_backend(self, backend):
        """
        Add a new backend.

        :param IUsageWriter backend: the backend to add
        """
        self._backends.add(backend)

    def record(self, started, buddy_started, result, bytes_sent, buddy_bytes):
        """
        Build the usage record for one finished connection and pass it
        to every backend.

        The end of the connection is taken to be "now" (``time.time()``).

        :param int started: timestamp when our connection started

        :param int buddy_started: None, or the timestamp when our
            partner's connection started (will be None if we don't yet
            have a partner).

        :param str result: a label for the result of the connection
            (one of the "moods").

        :param int bytes_sent: number of bytes we sent

        :param int buddy_bytes: number of bytes our partner sent
        """
        # ideally self._reactor.seconds() or similar, but ..
        finished = time.time()
        if buddy_started is not None:
            # Two sides: measure from the earlier start; the gap between
            # the two starts is how long the first side waited.
            starts = [started, buddy_started]
            total_time = finished - min(starts)
            waiting_time = max(starts) - min(starts)
            total_bytes = bytes_sent + buddy_bytes
        else:
            total_time = finished - started
            waiting_time = None
            total_bytes = bytes_sent
            # note that "bytes_sent" should always be 0 here, but
            # we're recording what the state-machine remembered in any
            # case

        if self._blur_usage:
            # Only the start time and the byte count are blurred; the
            # durations computed above are recorded as measured.
            started = self._blur_usage * (started // self._blur_usage)
            total_bytes = blur_size(total_bytes)

        # This is "a dict" instead of "kwargs" because we have to make
        # it into a dict for the log use-case and in-memory/testing
        # use-case anyway so this means fewer repeats of the names.
        self._notify_backends({
            "started": started,
            "total_time": total_time,
            "waiting_time": waiting_time,
            "total_bytes": total_bytes,
            "mood": result,
        })

    def update_stats(self, rebooted, updated, connected, waiting,
                     incomplete_bytes):
        """
        Update general statistics.

        Replaces the contents of the ``current`` table with a single row
        and commits it. Only ``DatabaseUsageRecorder`` backends are
        touched; other backends ignore these statistics.

        :param rebooted: stored in the ``rebooted`` column after ``int()``

        :param updated: stored in the ``updated`` column after ``int()``

        :param connected: stored as-is in the ``connected`` column

        :param waiting: stored as-is in the ``waiting`` column

        :param incomplete_bytes: stored as-is in the
            ``incomplete_bytes`` column
        """
        # in original code, this is only recorded in the database
        # .. perhaps a better way to do this, but ..
        for backend in self._backends:
            if isinstance(backend, DatabaseUsageRecorder):
                backend._db.execute("DELETE FROM `current`")
                backend._db.execute(
                    "INSERT INTO `current`"
                    " (`rebooted`, `updated`, `connected`, `waiting`,"
                    " `incomplete_bytes`)"
                    " VALUES (?, ?, ?, ?, ?)",
                    (int(rebooted), int(updated), connected, waiting,
                     incomplete_bytes)
                )
                # Commit here, like DatabaseUsageRecorder.record_usage
                # does on the same handle; otherwise the refreshed row
                # stays in an open transaction until some later usage
                # record happens to commit.
                backend._db.commit()

    def _notify_backends(self, data):
        """
        Internal helper. Tell every backend we have about a new usage record.

        :param dict data: keyword arguments for ``record_usage``
        """
        for backend in self._backends:
            backend.record_usage(**data)


def round_to(size, coarseness):
    """
    Round ``size`` up to a multiple of ``coarseness``.

    For ``size >= 1`` the result is the smallest multiple of
    ``coarseness`` that is not below ``size``. A ``size`` of 0 yields
    ``coarseness`` rather than 0, which is why ``blur_size`` handles 0
    itself. The result is always an ``int``.
    """
    return int(coarseness*(1+int((size-1)/coarseness)))


def blur_size(size):
    """
    Coarsen a byte count so exact transfer sizes are not recorded.

    0 stays 0. Sizes below 1e6 are rounded up to a multiple of 10e3,
    sizes below 1e9 to a multiple of 1e6, and anything larger to a
    multiple of 100e6.
    """
    if size == 0:
        return 0
    if size < 1e6:
        return round_to(size, 10e3)
    if size < 1e9:
        return round_to(size, 1e6)
    return round_to(size, 100e6)
"""
This is a test-client for the transit-relay that uses WebSockets.

If an additional command-line argument (anything) is added, it will
send 5 messages upon connection. Otherwise, it just prints out what is
received. Uses a fixed token of 64 'a' characters. Always connects on
localhost:4002
"""

import sys

from twisted.internet import endpoints
from twisted.internet.defer import (
    Deferred,
    inlineCallbacks,
)
from twisted.internet.task import react, deferLater

from autobahn.twisted.websocket import (
    WebSocketClientProtocol,
    WebSocketClientFactory,
)


class RelayEchoClient(WebSocketClientProtocol):
    """
    WebSocket client that sends the relay handshake when the connection
    opens and collects every message other than ``b"ok\\n"`` in
    ``self._received``.

    Reads ``token``, ``side``, ``ready`` and ``done`` from
    ``self.factory``; ``main`` sets all four before connecting.
    """

    def onOpen(self):
        """
        Reset the receive buffer and send the handshake line
        ``please relay <token> for side <side>``, ASCII-encoded.
        """
        self._received = b""
        self.sendMessage(
            u"please relay {} for side {}".format(
                self.factory.token,
                self.factory.side,
            ).encode("ascii"),
            # second positional argument of sendMessage (isBinary in
            # autobahn's API)
            True,
        )

    def onMessage(self, data, isBinary):
        """
        Print the message; ``b"ok\\n"`` fires ``factory.ready``, any other
        payload is appended to ``self._received``.
        """
        print(">onMessage: {} bytes".format(len(data)))
        print(data, isBinary)
        if data == b"ok\n":
            self.factory.ready.callback(None)
        else:
            self._received += data
        # Manual debugging toggle: this branch never runs as written.
        if False:
            # test abrupt hangup from receiving side
            self.transport.loseConnection()

    def onClose(self, wasClean, code, reason):
        """
        Print the close details and fire ``factory.done`` with ``reason``.

        If ``factory.ready`` has not fired yet it is errbacked, so the
        ``yield f.ready`` in ``main`` does not wait forever.
        """
        print(">onClose", wasClean, code, reason)
        self.factory.done.callback(reason)
        if not self.factory.ready.called:
            self.factory.ready.errback(RuntimeError(reason))


# Used as a decorator, react() is called with main as soon as this
# definition is executed, so running the module starts the client.
@react
@inlineCallbacks
def main(reactor):
    """
    Connect to the relay on localhost:4002, wait for ``f.ready``,
    optionally send 5 messages, then wait for ``f.done`` and print
    everything that was received.

    :param reactor: the reactor passed in by ``react``
    """
    # any extra command-line argument switches this client to "sender"
    will_send_message = len(sys.argv) > 1
    ep = endpoints.clientFromString(reactor, "tcp:localhost:4002")
    f = WebSocketClientFactory("ws://127.0.0.1:4002/")
    f.reactor = reactor
    f.protocol = RelayEchoClient
    f.token = "a" * 64
    # sender and receiver use different side strings
    f.side = "0" * 16 if will_send_message else "1" * 16
    f.done = Deferred()
    f.ready = Deferred()

    proto = yield ep.connect(f)
    print("proto", proto)
    # fired by RelayEchoClient.onMessage on b"ok\n" (errbacked by onClose)
    yield f.ready

    print("ready")
    if will_send_message:
        for _ in range(5):
            print("sending message")
            proto.sendMessage(b"it's a message", True)
            # 0.2 second delay between messages
            yield deferLater(reactor, 0.2)
        # NOTE(review): the nesting of the next two statements under
        # `if will_send_message:` was reconstructed from
        # whitespace-collapsed source -- confirm against the real file.
        yield proto.sendClose()
        print("closing")
    # fired by RelayEchoClient.onClose
    yield f.done
    print("relayed {} bytes:".format(len(proto._received)))
    print(proto._received.decode("utf8"))