Merge pull request #23 from meejah/websocket-support-on-iosim-tests-master

WebSocket support
This commit is contained in:
meejah 2021-05-09 23:49:48 -06:00 committed by GitHub
commit 80e02d4a77
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 1541 additions and 398 deletions

54
client.py Normal file
View File

@ -0,0 +1,54 @@
"""
This is a test-client for the transit-relay that uses TCP. It
doesn't send any data, only prints out data that is received. Uses a
fixed token of 64 'a' characters. Always connects on localhost:4001
"""
from twisted.internet import endpoints
from twisted.internet.defer import (
Deferred,
)
from twisted.internet.task import react
from twisted.internet.error import (
ConnectionDone,
)
from twisted.internet.protocol import (
Protocol,
Factory,
)
class RelayEchoClient(Protocol):
"""
Speaks the version1 magic wormhole transit relay protocol (as a client)
"""
def connectionMade(self):
print(">CONNECT")
self.data = b""
self.transport.write(u"please relay {}\n".format(self.factory.token).encode("ascii"))
def dataReceived(self, data):
print(">RECV {} bytes".format(len(data)))
print(data.decode("ascii"))
self.data += data
if data == "ok\n":
self.transport.write("ding\n")
def connectionLost(self, reason):
if isinstance(reason.value, ConnectionDone):
self.factory.done.callback(None)
else:
print(">DISCONNCT: {}".format(reason))
self.factory.done.callback(reason)
@react
def main(reactor):
ep = endpoints.clientFromString(reactor, "tcp:localhost:4001")
f = Factory.forProtocol(RelayEchoClient)
f.token = "a" * 64
f.done = Deferred()
ep.connect(f)
return f.done

View File

@ -50,6 +50,15 @@ The relevant arguments are:
* ``--usage-db=``: maintains a SQLite database with current and historical usage data * ``--usage-db=``: maintains a SQLite database with current and historical usage data
* ``--blur-usage=``: round logged timestamps and data sizes * ``--blur-usage=``: round logged timestamps and data sizes
For WebSockets support, two additional arguments:
* ``--websocket``: the endpoint to listen for websocket connections
on, like ``tcp:4002``
* ``--websocket-url``: the URL of the WebSocket connection. This may
be different from the listening endpoint because of port-forwarding
and so forth. By default it will be ``ws://localhost:<port>`` if not
provided
When you use ``twist``, the relay runs in the foreground, so it will When you use ``twist``, the relay runs in the foreground, so it will
generally exit as soon as the controlling terminal exits. For persistent generally exit as soon as the controlling terminal exits. For persistent
environments, you should daemonize the server. environments, you should daemonize the server.

View File

@ -18,7 +18,8 @@ setup(name="magic-wormhole-transit-relay",
], ],
package_data={"wormhole_transit_relay": ["db-schemas/*.sql"]}, package_data={"wormhole_transit_relay": ["db-schemas/*.sql"]},
install_requires=[ install_requires=[
"twisted >= 17.5.0", "twisted >= 21.2.0",
"autobahn >= 21.3.1",
], ],
extras_require={ extras_require={
':sys_platform=="win32"': ["pypiwin32"], ':sys_platform=="win32"': ["pypiwin32"],

View File

@ -0,0 +1,477 @@
from collections import defaultdict
import automat
from twisted.python import log
from zope.interface import (
Interface,
Attribute,
)
class ITransitClient(Interface):
"""
Represents the client side of a connection to this transit
relay. This is used by TransitServerState instances.
"""
started_time = Attribute("timestamp when the connection was established")
def send(data):
"""
Send some byets to the client
"""
def disconnect():
"""
Disconnect the client transport
"""
def connect_partner(other):
"""
Hook up to our partner.
:param ITransitClient other: our partner
"""
def disconnect_partner():
"""
Disconnect our partner's transport
"""
class ActiveConnections(object):
"""
Tracks active connections.
A connection is 'active' when both sides have shown up and they
are glued together (and thus could be passing data back and forth
if any is flowing).
"""
def __init__(self):
self._connections = set()
def register(self, side0, side1):
"""
A connection has become active so register both its sides
:param TransitConnection side0: one side of the connection
:param TransitConnection side1: one side of the connection
"""
self._connections.add(side0)
self._connections.add(side1)
def unregister(self, side):
"""
One side of a connection has become inactive.
:param TransitConnection side: an inactive side of a connection
"""
self._connections.discard(side)
class PendingRequests(object):
"""
Tracks outstanding (non-"active") requests.
We register client connections against the tokens we have
received. When the other side shows up we can thus match it to the
correct partner connection. At this point, the connection becomes
"active" is and is thus no longer "pending" and so will no longer
be in this collection.
"""
def __init__(self, active_connections):
"""
:param active_connections: an instance of ActiveConnections where
connections are put when both sides arrive.
"""
self._requests = defaultdict(set) # token -> set((side, TransitConnection))
self._active = active_connections
def unregister(self, token, side, tc):
"""
We no longer care about a particular client (e.g. it has
disconnected).
"""
if token in self._requests:
self._requests[token].discard((side, tc))
if not self._requests[token]:
# no more sides; token is dead
del self._requests[token]
self._active.unregister(tc)
def register(self, token, new_side, new_tc):
"""
A client has connected and successfully offered a token (and
optional 'side' token). If this is the first one for this
token, we merely remember it. If it is the second side for
this token we connect them together.
:param bytes token: the token for this connection.
:param bytes new_side: None or the side token for this connection
:param TransitServerState new_tc: the state-machine of the connection
:returns bool: True if we are the first side to register this
token
"""
potentials = self._requests[token]
for old in potentials:
(old_side, old_tc) = old
if ((old_side is None)
or (new_side is None)
or (old_side != new_side)):
# we found a match
# drop and stop tracking the rest
potentials.remove(old)
for (_, leftover_tc) in potentials.copy():
# Don't record this as errory. It's just a spare connection
# from the same side as a connection that got used. This
# can happen if the connection hint contains multiple
# addresses (we don't currently support those, but it'd
# probably be useful in the future).
leftover_tc.partner_connection_lost()
self._requests.pop(token, None)
# glue the two ends together
self._active.register(new_tc, old_tc)
new_tc.got_partner(old_tc)
old_tc.got_partner(new_tc)
return False
potentials.add((new_side, new_tc))
return True
# TODO: timer
class TransitServerState(object):
"""
Encapsulates the state-machine of the server side of a transit
relay connection.
Once the protocol has been told to relay (or to relay for a side)
it starts passing all received bytes to the other side until it
closes.
"""
_machine = automat.MethodicalMachine()
_client = None
_buddy = None
_token = None
_side = None
_first = None
_mood = "empty"
_total_sent = 0
def __init__(self, pending_requests, usage_recorder):
self._pending_requests = pending_requests
self._usage = usage_recorder
def get_token(self):
"""
:returns str: a string describing our token. This will be "-" if
we have no token yet, or "{16 chars}-<unsided>" if we have
just a token or "{16 chars}-{16 chars}" if we have a token and
a side.
"""
d = "-"
if self._token is not None:
d = self._token[:16].decode("ascii")
if self._side is not None:
d += "-" + self._side.decode("ascii")
else:
d += "-<unsided>"
return d
@_machine.input()
def connection_made(self, client):
"""
A client has connected. May only be called once.
:param ITransitClient client: our client.
"""
# NB: the "only called once" is enforced by the state-machine;
# this input is only valid for the "listening" state, to which
# we never return.
@_machine.input()
def please_relay(self, token):
"""
A 'please relay X' message has been received (the original version
of the protocol).
"""
@_machine.input()
def please_relay_for_side(self, token, side):
"""
A 'please relay X for side Y' message has been received (the
second version of the protocol).
"""
@_machine.input()
def bad_token(self):
"""
A bad token / relay line was received (e.g. couldn't be parsed)
"""
@_machine.input()
def got_partner(self, client):
"""
The partner for this relay session has been found
"""
@_machine.input()
def connection_lost(self):
"""
Our transport has failed.
"""
@_machine.input()
def partner_connection_lost(self):
"""
Our partner's transport has failed.
"""
@_machine.input()
def got_bytes(self, data):
"""
Some bytes have arrived (that aren't part of the handshake)
"""
@_machine.output()
def _remember_client(self, client):
self._client = client
# note that there is no corresponding "_forget_client" because we
# may still want to access it after it is gone .. for example, to
# get the .started_time for logging purposes
@_machine.output()
def _register_token(self, token):
return self._real_register_token_for_side(token, None)
@_machine.output()
def _register_token_for_side(self, token, side):
return self._real_register_token_for_side(token, side)
@_machine.output()
def _unregister(self):
"""
remove us from the thing that remembers tokens and sides
"""
return self._pending_requests.unregister(self._token, self._side, self)
@_machine.output()
def _send_bad(self):
self._mood = "errory"
self._client.send(b"bad handshake\n")
if self._client.factory.log_requests:
log.msg("transit handshake failure")
@_machine.output()
def _send_ok(self):
self._client.send(b"ok\n")
@_machine.output()
def _send_impatient(self):
self._client.send(b"impatient\n")
if self._client.factory.log_requests:
log.msg("transit impatience failure")
@_machine.output()
def _count_bytes(self, data):
self._total_sent += len(data)
@_machine.output()
def _send_to_partner(self, data):
self._buddy._client.send(data)
@_machine.output()
def _connect_partner(self, client):
self._buddy = client
self._client.connect_partner(client)
@_machine.output()
def _disconnect(self):
self._client.disconnect()
@_machine.output()
def _disconnect_partner(self):
self._client.disconnect_partner()
# some outputs to record "usage" information ..
@_machine.output()
def _record_usage(self):
if self._mood == "jilted":
if self._buddy and self._buddy._mood == "happy":
return
self._usage.record(
started=self._client.started_time,
buddy_started=self._buddy._client.started_time if self._buddy is not None else None,
result=self._mood,
bytes_sent=self._total_sent,
buddy_bytes=self._buddy._total_sent if self._buddy is not None else None
)
# some outputs to record the "mood" ..
@_machine.output()
def _mood_happy(self):
self._mood = "happy"
@_machine.output()
def _mood_lonely(self):
self._mood = "lonely"
@_machine.output()
def _mood_redundant(self):
self._mood = "redundant"
@_machine.output()
def _mood_impatient(self):
self._mood = "impatient"
@_machine.output()
def _mood_errory(self):
self._mood = "errory"
@_machine.output()
def _mood_happy_if_first(self):
"""
We disconnected first so we're only happy if we also connected
first.
"""
if self._first:
self._mood = "happy"
else:
self._mood = "jilted"
def _real_register_token_for_side(self, token, side):
"""
A client has connected and sent a valid version 1 or version 2
handshake. If the former, `side` will be None.
In either case, we remember the tokens and register
ourselves. This might result in 'got_partner' notifications to
two state-machines if this is the second side for a given token.
:param bytes token: the token
:param bytes side: The side token (or None)
"""
self._token = token
self._side = side
self._first = self._pending_requests.register(token, side, self)
@_machine.state(initial=True)
def listening(self):
"""
Initial state, awaiting connection.
"""
@_machine.state()
def wait_relay(self):
"""
Waiting for a 'relay' message
"""
@_machine.state()
def wait_partner(self):
"""
Waiting for our partner to connect
"""
@_machine.state()
def relaying(self):
"""
Relaying bytes to our partner
"""
@_machine.state()
def done(self):
"""
Terminal state
"""
listening.upon(
connection_made,
enter=wait_relay,
outputs=[_remember_client],
)
listening.upon(
connection_lost,
enter=done,
outputs=[_mood_errory],
)
wait_relay.upon(
please_relay,
enter=wait_partner,
outputs=[_mood_lonely, _register_token],
)
wait_relay.upon(
please_relay_for_side,
enter=wait_partner,
outputs=[_mood_lonely, _register_token_for_side],
)
wait_relay.upon(
bad_token,
enter=done,
outputs=[_mood_errory, _send_bad, _disconnect, _record_usage],
)
wait_relay.upon(
got_bytes,
enter=done,
outputs=[_count_bytes, _mood_errory, _disconnect, _record_usage],
)
wait_relay.upon(
connection_lost,
enter=done,
outputs=[_disconnect, _record_usage],
)
wait_partner.upon(
got_partner,
enter=relaying,
outputs=[_mood_happy, _send_ok, _connect_partner],
)
wait_partner.upon(
connection_lost,
enter=done,
outputs=[_mood_lonely, _unregister, _record_usage],
)
wait_partner.upon(
got_bytes,
enter=done,
outputs=[_mood_impatient, _send_impatient, _disconnect, _unregister, _record_usage],
)
wait_partner.upon(
partner_connection_lost,
enter=done,
outputs=[_mood_redundant, _disconnect, _record_usage],
)
relaying.upon(
got_bytes,
enter=relaying,
outputs=[_count_bytes, _send_to_partner],
)
relaying.upon(
connection_lost,
enter=done,
outputs=[_mood_happy_if_first, _disconnect_partner, _unregister, _record_usage],
)
done.upon(
connection_lost,
enter=done,
outputs=[],
)
done.upon(
partner_connection_lost,
enter=done,
outputs=[],
)
# uncomment to turn on state-machine tracing
# set_trace_function = _machine._setTrace

View File

@ -5,8 +5,14 @@ from twisted.application.service import MultiService
from twisted.application.internet import (TimerService, from twisted.application.internet import (TimerService,
StreamServerEndpointService) StreamServerEndpointService)
from twisted.internet import endpoints from twisted.internet import endpoints
from twisted.internet import protocol
from autobahn.twisted.websocket import WebSocketServerFactory
from . import transit_server from . import transit_server
from .usage import create_usage_tracker
from .increase_rlimits import increase_rlimits from .increase_rlimits import increase_rlimits
from .database import get_db
LONGDESC = """\ LONGDESC = """\
This plugin sets up a 'Transit Relay' server for magic-wormhole. This service This plugin sets up a 'Transit Relay' server for magic-wormhole. This service
@ -20,6 +26,8 @@ class Options(usage.Options):
optParameters = [ optParameters = [
("port", "p", "tcp:4001:interface=\:\:", "endpoint to listen on"), ("port", "p", "tcp:4001:interface=\:\:", "endpoint to listen on"),
("websocket", "w", None, "endpoint to listen for WebSocket connections"),
("websocket-url", "u", None, "WebSocket URL (derived from endpoint if not provided)"),
("blur-usage", None, None, "blur timestamps and data sizes in logs"), ("blur-usage", None, None, "blur timestamps and data sizes in logs"),
("log-fd", None, None, "write JSON usage logs to this file descriptor"), ("log-fd", None, None, "write JSON usage logs to this file descriptor"),
("usage-db", None, None, "record usage data (SQLite)"), ("usage-db", None, None, "record usage data (SQLite)"),
@ -31,14 +39,45 @@ class Options(usage.Options):
def makeService(config, reactor=reactor): def makeService(config, reactor=reactor):
increase_rlimits() increase_rlimits()
ep = endpoints.serverFromString(reactor, config["port"]) # to listen tcp_ep = endpoints.serverFromString(reactor, config["port"]) # to listen
log_file = (os.fdopen(int(config["log-fd"]), "w") ws_ep = (
if config["log-fd"] is not None endpoints.serverFromString(reactor, config["websocket"])
else None) if config["websocket"] is not None
f = transit_server.Transit(blur_usage=config["blur-usage"], else None
log_file=log_file, )
usage_db=config["usage-db"]) log_file = (
os.fdopen(int(config["log-fd"]), "w")
if config["log-fd"] is not None
else None
)
db = None if config["usage-db"] is None else get_db(config["usage-db"])
usage = create_usage_tracker(
blur_usage=config["blur-usage"],
log_file=log_file,
usage_db=db,
)
transit = transit_server.Transit(usage, reactor.seconds)
tcp_factory = protocol.ServerFactory()
tcp_factory.protocol = transit_server.TransitConnection
tcp_factory.log_requests = False
if ws_ep is not None:
ws_url = config["websocket-url"]
if ws_url is None:
# we're using a "private" attribute here but I don't see
# any useful alternative unless we also want to parse
# Twisted endpoint-strings.
ws_url = "ws://localhost:{}/".format(ws_ep._port)
print("Using WebSocket URL '{}'".format(ws_url))
ws_factory = WebSocketServerFactory(ws_url)
ws_factory.protocol = transit_server.WebSocketTransitConnection
ws_factory.transit = transit
ws_factory.log_requests = False
tcp_factory.transit = transit
parent = MultiService() parent = MultiService()
StreamServerEndpointService(ep, f).setServiceParent(parent) StreamServerEndpointService(tcp_ep, tcp_factory).setServiceParent(parent)
TimerService(5*60.0, f.timerUpdateStats).setServiceParent(parent) if ws_ep is not None:
StreamServerEndpointService(ws_ep, ws_factory).setServiceParent(parent)
TimerService(5*60.0, transit.update_stats).setServiceParent(parent)
return parent return parent

View File

@ -10,7 +10,10 @@ from zope.interface import (
) )
from ..transit_server import ( from ..transit_server import (
Transit, Transit,
TransitConnection,
) )
from twisted.internet.protocol import ServerFactory
from ..usage import create_usage_tracker
class IRelayTestClient(Interface): class IRelayTestClient(Interface):
@ -42,6 +45,7 @@ class IRelayTestClient(Interface):
Erase any received data to this point. Erase any received data to this point.
""" """
class ServerBase: class ServerBase:
log_requests = False log_requests = False
@ -62,19 +66,30 @@ class ServerBase:
self.flush() self.flush()
def _setup_relay(self, blur_usage=None, log_file=None, usage_db=None): def _setup_relay(self, blur_usage=None, log_file=None, usage_db=None):
self._transit_server = Transit( usage = create_usage_tracker(
blur_usage=blur_usage, blur_usage=blur_usage,
log_file=log_file, log_file=log_file,
usage_db=usage_db, usage_db=usage_db,
) )
self._transit_server._debug_log = self.log_requests self._transit_server = Transit(usage, lambda: 123456789.0)
def new_protocol(self): def new_protocol(self):
"""
This should be overridden by derived test-case classes to decide
if they want a TCP or WebSockets protocol.
"""
raise NotImplementedError()
def new_protocol_tcp(self):
""" """
Create a new client protocol connected to the server. Create a new client protocol connected to the server.
:returns: a IRelayTestClient implementation :returns: a IRelayTestClient implementation
""" """
server_protocol = self._transit_server.buildProtocol(('127.0.0.1', 0)) server_factory = ServerFactory()
server_factory.protocol = TransitConnection
server_factory.transit = self._transit_server
server_factory.log_requests = self.log_requests
server_protocol = server_factory.buildProtocol(('127.0.0.1', 0))
@implementer(IRelayTestClient) @implementer(IRelayTestClient)
class TransitClientProtocolTcp(Protocol): class TransitClientProtocolTcp(Protocol):

View File

@ -8,12 +8,29 @@ class Config(unittest.TestCase):
o = server_tap.Options() o = server_tap.Options()
o.parseOptions([]) o.parseOptions([])
self.assertEqual(o, {"blur-usage": None, "log-fd": None, self.assertEqual(o, {"blur-usage": None, "log-fd": None,
"usage-db": None, "port": PORT}) "usage-db": None, "port": PORT,
"websocket": None, "websocket-url": None})
def test_blur(self): def test_blur(self):
o = server_tap.Options() o = server_tap.Options()
o.parseOptions(["--blur-usage=60"]) o.parseOptions(["--blur-usage=60"])
self.assertEqual(o, {"blur-usage": 60, "log-fd": None, self.assertEqual(o, {"blur-usage": 60, "log-fd": None,
"usage-db": None, "port": PORT}) "usage-db": None, "port": PORT,
"websocket": None, "websocket-url": None})
def test_websocket(self):
o = server_tap.Options()
o.parseOptions(["--websocket=tcp:4004"])
self.assertEqual(o, {"blur-usage": None, "log-fd": None,
"usage-db": None, "port": PORT,
"websocket": "tcp:4004", "websocket-url": None})
def test_websocket_url(self):
o = server_tap.Options()
o.parseOptions(["--websocket=tcp:4004", "--websocket-url=ws://example.com/"])
self.assertEqual(o, {"blur-usage": None, "log-fd": None,
"usage-db": None, "port": PORT,
"websocket": "tcp:4004",
"websocket-url": "ws://example.com/"})
def test_string(self): def test_string(self):
o = server_tap.Options() o = server_tap.Options()

View File

@ -1,13 +1,14 @@
from twisted.trial import unittest from twisted.trial import unittest
from unittest import mock from unittest import mock
from twisted.application.service import MultiService from twisted.application.service import MultiService
from autobahn.twisted.websocket import WebSocketServerFactory
from .. import server_tap from .. import server_tap
class Service(unittest.TestCase): class Service(unittest.TestCase):
def test_defaults(self): def test_defaults(self):
o = server_tap.Options() o = server_tap.Options()
o.parseOptions([]) o.parseOptions([])
with mock.patch("wormhole_transit_relay.server_tap.transit_server.Transit") as t: with mock.patch("wormhole_transit_relay.server_tap.create_usage_tracker") as t:
s = server_tap.makeService(o) s = server_tap.makeService(o)
self.assertEqual(t.mock_calls, self.assertEqual(t.mock_calls,
[mock.call(blur_usage=None, [mock.call(blur_usage=None,
@ -17,7 +18,7 @@ class Service(unittest.TestCase):
def test_blur(self): def test_blur(self):
o = server_tap.Options() o = server_tap.Options()
o.parseOptions(["--blur-usage=60"]) o.parseOptions(["--blur-usage=60"])
with mock.patch("wormhole_transit_relay.server_tap.transit_server.Transit") as t: with mock.patch("wormhole_transit_relay.server_tap.create_usage_tracker") as t:
server_tap.makeService(o) server_tap.makeService(o)
self.assertEqual(t.mock_calls, self.assertEqual(t.mock_calls,
[mock.call(blur_usage=60, [mock.call(blur_usage=60,
@ -27,7 +28,7 @@ class Service(unittest.TestCase):
o = server_tap.Options() o = server_tap.Options()
o.parseOptions(["--log-fd=99"]) o.parseOptions(["--log-fd=99"])
fd = object() fd = object()
with mock.patch("wormhole_transit_relay.server_tap.transit_server.Transit") as t: with mock.patch("wormhole_transit_relay.server_tap.create_usage_tracker") as t:
with mock.patch("wormhole_transit_relay.server_tap.os.fdopen", with mock.patch("wormhole_transit_relay.server_tap.os.fdopen",
return_value=fd) as f: return_value=fd) as f:
server_tap.makeService(o) server_tap.makeService(o)
@ -36,3 +37,34 @@ class Service(unittest.TestCase):
[mock.call(blur_usage=None, [mock.call(blur_usage=None,
log_file=fd, usage_db=None)]) log_file=fd, usage_db=None)])
def test_websocket(self):
"""
A websocket factory is created when passing --websocket
"""
o = server_tap.Options()
o.parseOptions(["--websocket=tcp:4004"])
services = server_tap.makeService(o)
self.assertTrue(
any(
isinstance(s.factory, WebSocketServerFactory)
for s in services.services
)
)
def test_websocket_explicit_url(self):
"""
A websocket factory is created with --websocket and
--websocket-url
"""
o = server_tap.Options()
o.parseOptions([
"--websocket=tcp:4004",
"--websocket-url=ws://example.com:4004",
])
services = server_tap.makeService(o)
self.assertTrue(
any(
isinstance(s.factory, WebSocketServerFactory)
for s in services.services
)
)

View File

@ -1,27 +1,38 @@
import os, io, json, sqlite3 import os, io, json
from unittest import mock from unittest import mock
from twisted.trial import unittest from twisted.trial import unittest
from ..transit_server import Transit from ..transit_server import Transit
from ..usage import create_usage_tracker
from .. import database from .. import database
class DB(unittest.TestCase): class DB(unittest.TestCase):
def open_db(self, dbfile):
db = sqlite3.connect(dbfile)
database._initialize_db_connection(db)
return db
def test_db(self): def test_db(self):
T = 1519075308.0 T = 1519075308.0
class Timer:
t = T
def __call__(self):
return self.t
get_time = Timer()
d = self.mktemp() d = self.mktemp()
os.mkdir(d) os.mkdir(d)
usage_db = os.path.join(d, "usage.sqlite") usage_db = os.path.join(d, "usage.sqlite")
with mock.patch("time.time", return_value=T+0): db = database.get_db(usage_db)
t = Transit(blur_usage=None, log_file=None, usage_db=usage_db) t = Transit(
db = self.open_db(usage_db) create_usage_tracker(blur_usage=None, log_file=None, usage_db=db),
get_time,
)
self.assertEqual(len(t.usage._backends), 1)
usage = list(t.usage._backends)[0]
get_time.t = T + 1
usage.record_usage(started=123, mood="happy", total_bytes=100,
total_time=10, waiting_time=2)
t.update_stats()
with mock.patch("time.time", return_value=T+1):
t.recordUsage(started=123, result="happy", total_bytes=100,
total_time=10, waiting_time=2)
self.assertEqual(db.execute("SELECT * FROM `usage`").fetchall(), self.assertEqual(db.execute("SELECT * FROM `usage`").fetchall(),
[dict(result="happy", started=123, [dict(result="happy", started=123,
total_bytes=100, total_time=10, waiting_time=2), total_bytes=100, total_time=10, waiting_time=2),
@ -31,9 +42,10 @@ class DB(unittest.TestCase):
incomplete_bytes=0, incomplete_bytes=0,
waiting=0, connected=0)) waiting=0, connected=0))
with mock.patch("time.time", return_value=T+2): get_time.t = T + 2
t.recordUsage(started=150, result="errory", total_bytes=200, usage.record_usage(started=150, mood="errory", total_bytes=200,
total_time=11, waiting_time=3) total_time=11, waiting_time=3)
t.update_stats()
self.assertEqual(db.execute("SELECT * FROM `usage`").fetchall(), self.assertEqual(db.execute("SELECT * FROM `usage`").fetchall(),
[dict(result="happy", started=123, [dict(result="happy", started=123,
total_bytes=100, total_time=10, waiting_time=2), total_bytes=100, total_time=10, waiting_time=2),
@ -45,27 +57,37 @@ class DB(unittest.TestCase):
incomplete_bytes=0, incomplete_bytes=0,
waiting=0, connected=0)) waiting=0, connected=0))
with mock.patch("time.time", return_value=T+3): get_time.t = T + 3
t.timerUpdateStats() t.update_stats()
self.assertEqual(db.execute("SELECT * FROM `current`").fetchone(), self.assertEqual(db.execute("SELECT * FROM `current`").fetchone(),
dict(rebooted=T+0, updated=T+3, dict(rebooted=T+0, updated=T+3,
incomplete_bytes=0, incomplete_bytes=0,
waiting=0, connected=0)) waiting=0, connected=0))
def test_no_db(self): def test_no_db(self):
t = Transit(blur_usage=None, log_file=None, usage_db=None) t = Transit(
create_usage_tracker(blur_usage=None, log_file=None, usage_db=None),
lambda: 0,
)
self.assertEqual(0, len(t.usage._backends))
t.recordUsage(started=123, result="happy", total_bytes=100,
total_time=10, waiting_time=2)
t.timerUpdateStats()
class LogToStdout(unittest.TestCase): class LogToStdout(unittest.TestCase):
def test_log(self): def test_log(self):
# emit lines of JSON to log_file, if set # emit lines of JSON to log_file, if set
log_file = io.StringIO() log_file = io.StringIO()
t = Transit(blur_usage=None, log_file=log_file, usage_db=None) t = Transit(
t.recordUsage(started=123, result="happy", total_bytes=100, create_usage_tracker(blur_usage=None, log_file=log_file, usage_db=None),
total_time=10, waiting_time=2) lambda: 0,
)
with mock.patch("time.time", return_value=133):
t.usage.record(
started=123,
buddy_started=125,
result="happy",
bytes_sent=100,
buddy_bytes=0,
)
self.assertEqual(json.loads(log_file.getvalue()), self.assertEqual(json.loads(log_file.getvalue()),
{"started": 123, "total_time": 10, {"started": 123, "total_time": 10,
"waiting_time": 2, "total_bytes": 100, "waiting_time": 2, "total_bytes": 100,
@ -75,15 +97,34 @@ class LogToStdout(unittest.TestCase):
# if blurring is enabled, timestamps should be rounded to the # if blurring is enabled, timestamps should be rounded to the
# requested amount, and sizes should be rounded up too # requested amount, and sizes should be rounded up too
log_file = io.StringIO() log_file = io.StringIO()
t = Transit(blur_usage=60, log_file=log_file, usage_db=None) t = Transit(
t.recordUsage(started=123, result="happy", total_bytes=11999, create_usage_tracker(blur_usage=60, log_file=log_file, usage_db=None),
total_time=10, waiting_time=2) lambda: 0,
)
with mock.patch("time.time", return_value=123 + 10):
t.usage.record(
started=123,
buddy_started=125,
result="happy",
bytes_sent=11999,
buddy_bytes=0,
)
print(log_file.getvalue())
self.assertEqual(json.loads(log_file.getvalue()), self.assertEqual(json.loads(log_file.getvalue()),
{"started": 120, "total_time": 10, {"started": 120, "total_time": 10,
"waiting_time": 2, "total_bytes": 20000, "waiting_time": 2, "total_bytes": 20000,
"mood": "happy"}) "mood": "happy"})
def test_do_not_log(self): def test_do_not_log(self):
t = Transit(blur_usage=60, log_file=None, usage_db=None) t = Transit(
t.recordUsage(started=123, result="happy", total_bytes=11999, create_usage_tracker(blur_usage=60, log_file=None, usage_db=None),
total_time=10, waiting_time=2) lambda: 0,
)
t.usage.record(
started=123,
buddy_started=124,
result="happy",
bytes_sent=11999,
buddy_bytes=12,
)

View File

@ -1,7 +1,30 @@
from binascii import hexlify from binascii import hexlify
from twisted.trial import unittest from twisted.trial import unittest
from .common import ServerBase from twisted.test import iosim
from .. import transit_server from autobahn.twisted.websocket import (
WebSocketServerFactory,
WebSocketClientFactory,
WebSocketClientProtocol,
)
from autobahn.twisted.testing import (
create_pumper,
MemoryReactorClockResolver,
)
from autobahn.exception import Disconnected
from zope.interface import implementer
from .common import (
ServerBase,
IRelayTestClient,
)
from ..usage import (
MemoryUsageRecorder,
blur_size,
)
from ..transit_server import (
WebSocketTransitConnection,
TransitServerState,
)
def handshake(token, side=None): def handshake(token, side=None):
hs = b"please relay " + hexlify(token) hs = b"please relay " + hexlify(token)
@ -12,27 +35,28 @@ def handshake(token, side=None):
class _Transit: class _Transit:
def count(self): def count(self):
return sum([len(potentials) return sum([
for potentials len(potentials)
in self._transit_server._pending_requests.values()]) for potentials
in self._transit_server.pending_requests._requests.values()
])
def test_blur_size(self): def test_blur_size(self):
blur = transit_server.blur_size self.failUnlessEqual(blur_size(0), 0)
self.failUnlessEqual(blur(0), 0) self.failUnlessEqual(blur_size(1), 10e3)
self.failUnlessEqual(blur(1), 10e3) self.failUnlessEqual(blur_size(10e3), 10e3)
self.failUnlessEqual(blur(10e3), 10e3) self.failUnlessEqual(blur_size(10e3+1), 20e3)
self.failUnlessEqual(blur(10e3+1), 20e3) self.failUnlessEqual(blur_size(15e3), 20e3)
self.failUnlessEqual(blur(15e3), 20e3) self.failUnlessEqual(blur_size(20e3), 20e3)
self.failUnlessEqual(blur(20e3), 20e3) self.failUnlessEqual(blur_size(1e6), 1e6)
self.failUnlessEqual(blur(1e6), 1e6) self.failUnlessEqual(blur_size(1e6+1), 2e6)
self.failUnlessEqual(blur(1e6+1), 2e6) self.failUnlessEqual(blur_size(1.5e6), 2e6)
self.failUnlessEqual(blur(1.5e6), 2e6) self.failUnlessEqual(blur_size(2e6), 2e6)
self.failUnlessEqual(blur(2e6), 2e6) self.failUnlessEqual(blur_size(900e6), 900e6)
self.failUnlessEqual(blur(900e6), 900e6) self.failUnlessEqual(blur_size(1000e6), 1000e6)
self.failUnlessEqual(blur(1000e6), 1000e6) self.failUnlessEqual(blur_size(1050e6), 1100e6)
self.failUnlessEqual(blur(1050e6), 1100e6) self.failUnlessEqual(blur_size(1100e6), 1100e6)
self.failUnlessEqual(blur(1100e6), 1100e6) self.failUnlessEqual(blur_size(1150e6), 1200e6)
self.failUnlessEqual(blur(1150e6), 1200e6)
def test_register(self): def test_register(self):
p1 = self.new_protocol() p1 = self.new_protocol()
@ -49,7 +73,7 @@ class _Transit:
self.assertEqual(self.count(), 0) self.assertEqual(self.count(), 0)
# the token should be removed too # the token should be removed too
self.assertEqual(len(self._transit_server._pending_requests), 0) self.assertEqual(len(self._transit_server.pending_requests._requests), 0)
def test_both_unsided(self): def test_both_unsided(self):
p1 = self.new_protocol() p1 = self.new_protocol()
@ -75,7 +99,6 @@ class _Transit:
self.assertEqual(p2.get_received_data(), s1) self.assertEqual(p2.get_received_data(), s1)
p1.disconnect() p1.disconnect()
p2.disconnect()
self.flush() self.flush()
def test_sided_unsided(self): def test_sided_unsided(self):
@ -104,7 +127,6 @@ class _Transit:
self.assertEqual(p2.get_received_data(), s1) self.assertEqual(p2.get_received_data(), s1)
p1.disconnect() p1.disconnect()
p2.disconnect()
self.flush() self.flush()
def test_unsided_sided(self): def test_unsided_sided(self):
@ -177,6 +199,7 @@ class _Transit:
p2.send(handshake(token1, side=side1)) p2.send(handshake(token1, side=side1))
self.flush() self.flush()
self.flush()
self.assertEqual(self.count(), 2) # same-side connections don't match self.assertEqual(self.count(), 2) # same-side connections don't match
# when the second side arrives, the spare first connection should be # when the second side arrives, the spare first connection should be
@ -185,8 +208,8 @@ class _Transit:
p3.send(handshake(token1, side=side2)) p3.send(handshake(token1, side=side2))
self.flush() self.flush()
self.assertEqual(self.count(), 0) self.assertEqual(self.count(), 0)
self.assertEqual(len(self._transit_server._pending_requests), 0) self.assertEqual(len(self._transit_server.pending_requests._requests), 0)
self.assertEqual(len(self._transit_server._active_connections), 2) self.assertEqual(len(self._transit_server.active_connections._connections), 2)
# That will trigger a disconnect on exactly one of (p1 or p2). # That will trigger a disconnect on exactly one of (p1 or p2).
# The other connection should still be connected # The other connection should still be connected
self.assertEqual(sum([int(t.connected) for t in [p1, p2]]), 1) self.assertEqual(sum([int(t.connected) for t in [p1, p2]]), 1)
@ -266,7 +289,8 @@ class _Transit:
token1 = b"\x00"*32 token1 = b"\x00"*32
# sending too many bytes is impatience. # sending too many bytes is impatience.
p1.send(b"please relay " + hexlify(token1) + b"\nNOWNOWNOW") p1.send(b"please relay " + hexlify(token1))
p1.send(b"\nNOWNOWNOW")
self.flush() self.flush()
exp = b"impatient\n" exp = b"impatient\n"
@ -281,7 +305,8 @@ class _Transit:
side1 = b"\x01"*8 side1 = b"\x01"*8
# sending too many bytes is impatience. # sending too many bytes is impatience.
p1.send(b"please relay " + hexlify(token1) + p1.send(b"please relay " + hexlify(token1) +
b" for side " + hexlify(side1) + b"\nNOWNOWNOW") b" for side " + hexlify(side1))
p1.send(b"\nNOWNOWNOW")
self.flush() self.flush()
exp = b"impatient\n" exp = b"impatient\n"
@ -327,22 +352,163 @@ class _Transit:
# hang up before sending anything # hang up before sending anything
p1.disconnect() p1.disconnect()
class TransitWithLogs(_Transit, ServerBase, unittest.TestCase): class TransitWithLogs(_Transit, ServerBase, unittest.TestCase):
log_requests = True log_requests = True
def new_protocol(self):
return self.new_protocol_tcp()
class TransitWithoutLogs(_Transit, ServerBase, unittest.TestCase): class TransitWithoutLogs(_Transit, ServerBase, unittest.TestCase):
log_requests = False log_requests = False
def new_protocol(self):
return self.new_protocol_tcp()
def _new_protocol_ws(transit_server, log_requests):
"""
Internal helper for test-suites that need to provide WebSocket
client/server pairs.
:returns: a 2-tuple: (iosim.IOPump, protocol)
"""
ws_factory = WebSocketServerFactory("ws://localhost:4002")
ws_factory.protocol = WebSocketTransitConnection
ws_factory.transit = transit_server
ws_factory.log_requests = log_requests
ws_protocol = ws_factory.buildProtocol(('127.0.0.1', 0))
@implementer(IRelayTestClient)
class TransitWebSocketClientProtocol(WebSocketClientProtocol):
_received = b""
connected = False
def connectionMade(self):
self.connected = True
return super(TransitWebSocketClientProtocol, self).connectionMade()
def connectionLost(self, reason):
self.connected = False
return super(TransitWebSocketClientProtocol, self).connectionLost(reason)
def onMessage(self, data, isBinary):
self._received = self._received + data
def send(self, data):
self.sendMessage(data, True)
def get_received_data(self):
return self._received
def reset_received_data(self):
self._received = b""
def disconnect(self):
self.sendClose(1000, True)
client_factory = WebSocketClientFactory()
client_factory.protocol = TransitWebSocketClientProtocol
client_protocol = client_factory.buildProtocol(('127.0.0.1', 31337))
client_protocol.disconnect = client_protocol.dropConnection
pump = iosim.connect(
ws_protocol,
iosim.makeFakeServer(ws_protocol),
client_protocol,
iosim.makeFakeClient(client_protocol),
)
return pump, client_protocol
class TransitWebSockets(_Transit, ServerBase, unittest.TestCase):
def new_protocol(self):
return self.new_protocol_ws()
def new_protocol_ws(self):
pump, proto = _new_protocol_ws(self._transit_server, self.log_requests)
self._pumps.append(pump)
return proto
def test_websocket_to_tcp(self):
"""
One client is WebSocket and one is TCP
"""
p1 = self.new_protocol_ws()
p2 = self.new_protocol_tcp()
token1 = b"\x00"*32
side1 = b"\x01"*8
side2 = b"\x02"*8
p1.send(handshake(token1, side=side1))
self.flush()
p2.send(handshake(token1, side=side2))
self.flush()
# a correct handshake yields an ack, after which we can send
exp = b"ok\n"
self.assertEqual(p1.get_received_data(), exp)
self.assertEqual(p2.get_received_data(), exp)
p1.reset_received_data()
p2.reset_received_data()
# all data they sent after the handshake should be given to us
s1 = b"data1"
p1.send(s1)
self.flush()
self.assertEqual(p2.get_received_data(), s1)
p1.disconnect()
p2.disconnect()
self.flush()
def test_bad_handshake_old_slow(self):
"""
This test only makes sense for TCP
"""
def test_send_closed_partner(self):
"""
Sending data to a closed partner causes an error that propogates
to the sender.
"""
p1 = self.new_protocol()
p2 = self.new_protocol()
# set up a successful connection
token = b"a" * 32
p1.send(handshake(token))
p2.send(handshake(token))
self.flush()
# p2 loses connection, then p1 sends a message
p2.transport.loseConnection()
self.flush()
# at this point, p1 learns that p2 is disconnected (because it
# tried to relay "a message" but failed)
# try to send more (our partner p2 is gone now though so it
# should be an immediate error)
with self.assertRaises(Disconnected):
p1.send(b"more message")
self.flush()
class Usage(ServerBase, unittest.TestCase): class Usage(ServerBase, unittest.TestCase):
log_requests = True log_requests = True
def setUp(self): def setUp(self):
super(Usage, self).setUp() super(Usage, self).setUp()
self._usage = [] self._usage = MemoryUsageRecorder()
def record(started, result, total_bytes, total_time, waiting_time): self._transit_server.usage.add_backend(self._usage)
self._usage.append((started, result, total_bytes,
total_time, waiting_time)) def new_protocol(self):
self._transit_server.recordUsage = record return self.new_protocol_tcp()
def test_empty(self): def test_empty(self):
p1 = self.new_protocol() p1 = self.new_protocol()
@ -351,11 +517,14 @@ class Usage(ServerBase, unittest.TestCase):
self.flush() self.flush()
# that will log the "empty" usage event # that will log the "empty" usage event
self.assertEqual(len(self._usage), 1, self._usage) self.assertEqual(len(self._usage.events), 1, self._usage)
(started, result, total_bytes, total_time, waiting_time) = self._usage[0] self.assertEqual(self._usage.events[0]["mood"], "empty", self._usage)
self.assertEqual(result, "empty", self._usage)
def test_short(self): def test_short(self):
# Note: this test only runs on TCP clients because WebSockets
# already does framing (so it's either "a bad handshake" or
# there's no handshake at all yet .. you can't have a "short"
# one).
p1 = self.new_protocol() p1 = self.new_protocol()
# hang up before sending a complete handshake # hang up before sending a complete handshake
p1.send(b"short") p1.send(b"short")
@ -363,9 +532,8 @@ class Usage(ServerBase, unittest.TestCase):
self.flush() self.flush()
# that will log the "empty" usage event # that will log the "empty" usage event
self.assertEqual(len(self._usage), 1, self._usage) self.assertEqual(len(self._usage.events), 1, self._usage)
(started, result, total_bytes, total_time, waiting_time) = self._usage[0] self.assertEqual("empty", self._usage.events[0]["mood"])
self.assertEqual(result, "empty", self._usage)
def test_errory(self): def test_errory(self):
p1 = self.new_protocol() p1 = self.new_protocol()
@ -374,9 +542,8 @@ class Usage(ServerBase, unittest.TestCase):
self.flush() self.flush()
# that will log the "errory" usage event, then drop the connection # that will log the "errory" usage event, then drop the connection
p1.disconnect() p1.disconnect()
self.assertEqual(len(self._usage), 1, self._usage) self.assertEqual(len(self._usage.events), 1, self._usage)
(started, result, total_bytes, total_time, waiting_time) = self._usage[0] self.assertEqual(self._usage.events[0]["mood"], "errory", self._usage)
self.assertEqual(result, "errory", self._usage)
def test_lonely(self): def test_lonely(self):
p1 = self.new_protocol() p1 = self.new_protocol()
@ -389,10 +556,9 @@ class Usage(ServerBase, unittest.TestCase):
p1.disconnect() p1.disconnect()
self.flush() self.flush()
self.assertEqual(len(self._usage), 1, self._usage) self.assertEqual(len(self._usage.events), 1, self._usage)
(started, result, total_bytes, total_time, waiting_time) = self._usage[0] self.assertEqual(self._usage.events[0]["mood"], "lonely", self._usage)
self.assertEqual(result, "lonely", self._usage) self.assertIdentical(self._usage.events[0]["waiting_time"], None)
self.assertIdentical(waiting_time, None)
def test_one_happy_one_jilted(self): def test_one_happy_one_jilted(self):
p1 = self.new_protocol() p1 = self.new_protocol()
@ -406,7 +572,7 @@ class Usage(ServerBase, unittest.TestCase):
p2.send(handshake(token1, side=side2)) p2.send(handshake(token1, side=side2))
self.flush() self.flush()
self.assertEqual(self._usage, []) # no events yet self.assertEqual(self._usage.events, []) # no events yet
p1.send(b"\x00" * 13) p1.send(b"\x00" * 13)
self.flush() self.flush()
@ -416,11 +582,10 @@ class Usage(ServerBase, unittest.TestCase):
p1.disconnect() p1.disconnect()
self.flush() self.flush()
self.assertEqual(len(self._usage), 1, self._usage) self.assertEqual(len(self._usage.events), 1, self._usage)
(started, result, total_bytes, total_time, waiting_time) = self._usage[0] self.assertEqual(self._usage.events[0]["mood"], "happy", self._usage)
self.assertEqual(result, "happy", self._usage) self.assertEqual(self._usage.events[0]["total_bytes"], 20)
self.assertEqual(total_bytes, 20) self.assertNotIdentical(self._usage.events[0]["waiting_time"], None)
self.assertNotIdentical(waiting_time, None)
def test_redundant(self): def test_redundant(self):
p1a = self.new_protocol() p1a = self.new_protocol()
@ -443,21 +608,80 @@ class Usage(ServerBase, unittest.TestCase):
p1c.disconnect() p1c.disconnect()
self.flush() self.flush()
self.assertEqual(len(self._usage), 1, self._usage) self.assertEqual(len(self._usage.events), 1, self._usage)
(started, result, total_bytes, total_time, waiting_time) = self._usage[0] self.assertEqual(self._usage.events[0]["mood"], "lonely")
self.assertEqual(result, "lonely", self._usage)
p2.send(handshake(token1, side=side2)) p2.send(handshake(token1, side=side2))
self.flush() self.flush()
self.assertEqual(len(self._transit_server._pending_requests), 0) self.assertEqual(len(self._transit_server.pending_requests._requests), 0)
self.assertEqual(len(self._usage), 2, self._usage) self.assertEqual(len(self._usage.events), 2, self._usage)
(started, result, total_bytes, total_time, waiting_time) = self._usage[1] self.assertEqual(self._usage.events[1]["mood"], "redundant")
self.assertEqual(result, "redundant", self._usage)
# one of the these is unecessary, but probably harmless # one of the these is unecessary, but probably harmless
p1a.disconnect() p1a.disconnect()
p1b.disconnect() p1b.disconnect()
self.flush() self.flush()
self.assertEqual(len(self._usage), 3, self._usage) self.assertEqual(len(self._usage.events), 3, self._usage)
(started, result, total_bytes, total_time, waiting_time) = self._usage[2] self.assertEqual(self._usage.events[2]["mood"], "happy")
self.assertEqual(result, "happy", self._usage)
class UsageWebSockets(Usage):
"""
All the tests of 'Usage' except with a WebSocket (instead of TCP)
transport.
This overrides ServerBase.new_protocol to achieve this. It might
be nicer to parametrize these tests in a way that doesn't use
inheritance .. but all the support etc classes are set up that way
already.
"""
def setUp(self):
super(UsageWebSockets, self).setUp()
self._pump = create_pumper()
self._reactor = MemoryReactorClockResolver()
return self._pump.start()
def tearDown(self):
return self._pump.stop()
def new_protocol(self):
return self.new_protocol_ws()
def new_protocol_ws(self):
pump, proto = _new_protocol_ws(self._transit_server, self.log_requests)
self._pumps.append(pump)
return proto
def test_short(self):
"""
This test essentially just tests the framing of the line-oriented
TCP protocol; it doesnt' make sense for the WebSockets case
because WS handles frameing: you either sent a 'bad handshake'
because it is semantically invalid or no handshake (yet).
"""
def test_send_non_binary_message(self):
"""
A non-binary WebSocket message is an error
"""
ws_factory = WebSocketServerFactory("ws://localhost:4002")
ws_factory.protocol = WebSocketTransitConnection
ws_protocol = ws_factory.buildProtocol(('127.0.0.1', 0))
with self.assertRaises(ValueError):
ws_protocol.onMessage(u"foo", isBinary=False)
class State(unittest.TestCase):
"""
Tests related to server_state.TransitServerState
"""
def setUp(self):
self.state = TransitServerState(None, None)
def test_empty_token(self):
self.assertEqual(
"-",
self.state.get_token(),
)

View File

@ -1,9 +1,9 @@
import re, time, json import re
from collections import defaultdict import time
from twisted.python import log from twisted.python import log
from twisted.internet import protocol
from twisted.protocols.basic import LineReceiver from twisted.protocols.basic import LineReceiver
from .database import get_db from autobahn.twisted.websocket import WebSocketServerProtocol
SECONDS = 1.0 SECONDS = 1.0
MINUTE = 60*SECONDS MINUTE = 60*SECONDS
@ -11,340 +11,254 @@ HOUR = 60*MINUTE
DAY = 24*HOUR DAY = 24*HOUR
MB = 1000*1000 MB = 1000*1000
def round_to(size, coarseness):
return int(coarseness*(1+int((size-1)/coarseness)))
def blur_size(size): from wormhole_transit_relay.server_state import (
if size == 0: TransitServerState,
return 0 PendingRequests,
if size < 1e6: ActiveConnections,
return round_to(size, 10e3) ITransitClient,
if size < 1e9: )
return round_to(size, 1e6) from zope.interface import implementer
return round_to(size, 100e6)
@implementer(ITransitClient)
class TransitConnection(LineReceiver): class TransitConnection(LineReceiver):
delimiter = b'\n' delimiter = b'\n'
# maximum length of a line we will accept before the handshake is complete. # maximum length of a line we will accept before the handshake is complete.
# This must be >= to the longest possible handshake message. # This must be >= to the longest possible handshake message.
MAX_LENGTH = 1024 MAX_LENGTH = 1024
started_time = None
def __init__(self): def send(self, data):
self._got_token = False """
self._got_side = False ITransitClient API
self._sent_ok = False """
self._mood = "empty" self.transport.write(data)
def disconnect(self):
"""
ITransitClient API
"""
self.transport.loseConnection()
def connect_partner(self, other):
"""
ITransitClient API
"""
self._buddy = other
def disconnect_partner(self):
"""
ITransitClient API
"""
assert self._buddy is not None, "internal error: no buddy"
if self.factory.log_requests:
log.msg("buddy_disconnected {}".format(self._buddy.get_token()))
self._buddy._client.disconnect()
self._buddy = None self._buddy = None
self._total_sent = 0
def describeToken(self):
d = "-"
if self._got_token:
d = self._got_token[:16].decode("ascii")
if self._got_side:
d += "-" + self._got_side.decode("ascii")
else:
d += "-<unsided>"
return d
def connectionMade(self): def connectionMade(self):
self._started = time.time() # ideally more like self._reactor.seconds() ... but Twisted
self._log_requests = self.factory._log_requests # doesn't have a good way to get the reactor for a protocol
# (besides "use the global one")
self.started_time = time.time()
self._state = TransitServerState(
self.factory.transit.pending_requests,
self.factory.transit.usage,
)
self._state.connection_made(self)
self.transport.setTcpKeepAlive(True) self.transport.setTcpKeepAlive(True)
# uncomment to turn on state-machine tracing
# def tracer(oldstate, theinput, newstate):
# print("TRACE: {}: {} --{}--> {}".format(id(self), oldstate, theinput, newstate))
# self._state.set_trace_function(tracer)
def lineReceived(self, line): def lineReceived(self, line):
"""
LineReceiver API
"""
# old: "please relay {64}\n" # old: "please relay {64}\n"
token = None
old = re.search(br"^please relay (\w{64})$", line) old = re.search(br"^please relay (\w{64})$", line)
if old: if old:
token = old.group(1) token = old.group(1)
return self._got_handshake(token, None) self._state.please_relay(token)
# new: "please relay {64} for side {16}\n" # new: "please relay {64} for side {16}\n"
new = re.search(br"^please relay (\w{64}) for side (\w{16})$", line) new = re.search(br"^please relay (\w{64}) for side (\w{16})$", line)
if new: if new:
token = new.group(1) token = new.group(1)
side = new.group(2) side = new.group(2)
return self._got_handshake(token, side) self._state.please_relay_for_side(token, side)
self.sendLine(b"bad handshake") if token is None:
if self._log_requests: self._state.bad_token()
log.msg("transit handshake failure") else:
return self.disconnect_error() self.setRawMode()
def rawDataReceived(self, data): def rawDataReceived(self, data):
"""
LineReceiver API
"""
# We are an IPushProducer to our buddy's IConsumer, so they'll # We are an IPushProducer to our buddy's IConsumer, so they'll
# throttle us (by calling pauseProducing()) when their outbound # throttle us (by calling pauseProducing()) when their outbound
# buffer is full (e.g. when their downstream pipe is full). In # buffer is full (e.g. when their downstream pipe is full). In
# practice, this buffers about 10MB per connection, after which # practice, this buffers about 10MB per connection, after which
# point the sender will only transmit data as fast as the # point the sender will only transmit data as fast as the
# receiver can handle it. # receiver can handle it.
if self._sent_ok: self._state.got_bytes(data)
# if self._buddy is None then our buddy disconnected
# (we're "jilted"), so we hung up too, but our incoming
# data hasn't stopped yet (it will in a moment, after our
# disconnect makes a roundtrip through the kernel). This
# probably means the file receiver hung up, and this
# connection is the file sender. In may-2020 this happened
# 11 times in 40 days.
if self._buddy:
self._total_sent += len(data)
self._buddy.transport.write(data)
return
# handshake is complete but not yet sent_ok
self.sendLine(b"impatient")
if self._log_requests:
log.msg("transit impatience failure")
return self.disconnect_error() # impatience yields failure
def _got_handshake(self, token, side):
self._got_token = token
self._got_side = side
self._mood = "lonely" # until buddy connects
self.setRawMode()
self.factory.connection_got_token(token, side, self)
def buddy_connected(self, them):
self._buddy = them
self._mood = "happy"
self.sendLine(b"ok")
self._sent_ok = True
# Connect the two as a producer/consumer pair. We use streaming=True,
# so this expects the IPushProducer interface, and uses
# pauseProducing() to throttle, and resumeProducing() to unthrottle.
self._buddy.transport.registerProducer(self.transport, True)
# The Transit object calls buddy_connected() on both protocols, so
# there will be two producer/consumer pairs.
def buddy_disconnected(self):
if self._log_requests:
log.msg("buddy_disconnected %s" % self.describeToken())
self._buddy = None
self._mood = "jilted"
self.transport.loseConnection()
def disconnect_error(self):
# we haven't finished the handshake, so there are no tokens tracking
# us
self._mood = "errory"
self.transport.loseConnection()
if self.factory._debug_log:
log.msg("transitFailed %r" % self)
def disconnect_redundant(self):
# this is called if a buddy connected and we were found unnecessary.
# Any token-tracking cleanup will have been done before we're called.
self._mood = "redundant"
self.transport.loseConnection()
def connectionLost(self, reason): def connectionLost(self, reason):
finished = time.time() self._state.connection_lost()
total_time = finished - self._started
# Record usage. There are eight cases:
# * n0: we haven't gotten a full handshake yet (empty)
# * n1: the handshake failed, not a real client (errory)
# * n2: real client disconnected before any buddy appeared (lonely)
# * n3: real client closed as redundant after buddy appears (redundant)
# * n4: real client connected first, buddy closes first (jilted)
# * n5: real client connected first, buddy close last (happy)
# * n6: real client connected last, buddy closes first (jilted)
# * n7: real client connected last, buddy closes last (happy)
# * non-connected clients (0,1,2,3) always write a usage record class Transit(object):
# * for connected clients, whoever disconnects first gets to write the """
# usage record (5, 7). The last disconnect doesn't write a record. I manage pairs of simultaneous connections to a secondary TCP port,
both forwarded to the other. Clients must begin each connection with
"please relay TOKEN for SIDE\n" (or a legacy form without the "for
SIDE"). Two connections match if they use the same TOKEN and have
different SIDEs (the redundant connections are dropped when a match is
made). Legacy connections match any with the same TOKEN, ignoring SIDE
(so two legacy connections will match each other).
if self._mood == "empty": # 0 I will send "ok\n" when the matching connection is established, or
assert not self._buddy disconnect if no matching connection is made within MAX_WAIT_TIME
self.factory.recordUsage(self._started, "empty", 0, seconds. I will disconnect if you send data before the "ok\n". All data
total_time, None) you get after the "ok\n" will be from the other side. You will not
elif self._mood == "errory": # 1 receive "ok\n" until the other side has also connected and submitted a
assert not self._buddy matching token (and differing SIDE).
self.factory.recordUsage(self._started, "errory", 0,
total_time, None)
elif self._mood == "redundant": # 3
assert not self._buddy
self.factory.recordUsage(self._started, "redundant", 0,
total_time, None)
elif self._mood == "jilted": # 4 or 6
# we were connected, but our buddy hung up on us. They record the
# usage event, we do not
pass
elif self._mood == "lonely": # 2
assert not self._buddy
self.factory.recordUsage(self._started, "lonely", 0,
total_time, None)
else: # 5 or 7
# we were connected, we hung up first. We record the event.
assert self._mood == "happy", self._mood
assert self._buddy
starts = [self._started, self._buddy._started]
total_time = finished - min(starts)
waiting_time = max(starts) - min(starts)
total_bytes = self._total_sent + self._buddy._total_sent
self.factory.recordUsage(self._started, "happy", total_bytes,
total_time, waiting_time)
if self._buddy: In addition, the connections will be dropped after MAXLENGTH bytes have
self._buddy.buddy_disconnected() been sent by either side, or MAXTIME seconds have elapsed after the
self.factory.transitFinished(self, self._got_token, self._got_side, matching connections were established. A future API will reveal these
self.describeToken()) limits to clients instead of causing mysterious spontaneous failures.
class Transit(protocol.ServerFactory): These relay connections are not half-closeable (unlike full TCP
# I manage pairs of simultaneous connections to a secondary TCP port, connections, applications will not receive any data after half-closing
# both forwarded to the other. Clients must begin each connection with their outgoing side). Applications must negotiate shutdown with their
# "please relay TOKEN for SIDE\n" (or a legacy form without the "for peer and not close the connection until all data has finished
# SIDE"). Two connections match if they use the same TOKEN and have transferring in both directions. Applications which only need to send
# different SIDEs (the redundant connections are dropped when a match is data in one direction can use close() as usual.
# made). Legacy connections match any with the same TOKEN, ignoring SIDE """
# (so two legacy connections will match each other).
# I will send "ok\n" when the matching connection is established, or
# disconnect if no matching connection is made within MAX_WAIT_TIME
# seconds. I will disconnect if you send data before the "ok\n". All data
# you get after the "ok\n" will be from the other side. You will not
# receive "ok\n" until the other side has also connected and submitted a
# matching token (and differing SIDE).
# In addition, the connections will be dropped after MAXLENGTH bytes have
# been sent by either side, or MAXTIME seconds have elapsed after the
# matching connections were established. A future API will reveal these
# limits to clients instead of causing mysterious spontaneous failures.
# These relay connections are not half-closeable (unlike full TCP
# connections, applications will not receive any data after half-closing
# their outgoing side). Applications must negotiate shutdown with their
# peer and not close the connection until all data has finished
# transferring in both directions. Applications which only need to send
# data in one direction can use close() as usual.
# TODO: unused
MAX_WAIT_TIME = 30*SECONDS MAX_WAIT_TIME = 30*SECONDS
# TODO: unused
MAXLENGTH = 10*MB MAXLENGTH = 10*MB
# TODO: unused
MAXTIME = 60*SECONDS MAXTIME = 60*SECONDS
protocol = TransitConnection
def __init__(self, blur_usage, log_file, usage_db): def __init__(self, usage, get_timestamp):
self._blur_usage = blur_usage self.active_connections = ActiveConnections()
self._log_requests = blur_usage is None self.pending_requests = PendingRequests(self.active_connections)
if self._blur_usage: self.usage = usage
log.msg("blurring access times to %d seconds" % self._blur_usage) self._timestamp = get_timestamp
log.msg("not logging Transit connections to Twisted log") self._rebooted = self._timestamp()
else:
log.msg("not blurring access times")
self._debug_log = False
self._log_file = log_file
self._db = None
if usage_db:
self._db = get_db(usage_db)
self._rebooted = time.time()
# we don't track TransitConnections until they submit a token
self._pending_requests = defaultdict(set) # token -> set((side, TransitConnection))
self._active_connections = set() # TransitConnection
def connection_got_token(self, token, new_side, new_tc): def update_stats(self):
potentials = self._pending_requests[token]
for old in potentials:
(old_side, old_tc) = old
if ((old_side is None)
or (new_side is None)
or (old_side != new_side)):
# we found a match
if self._debug_log:
log.msg("transit relay 2: %s" % new_tc.describeToken())
# drop and stop tracking the rest
potentials.remove(old)
for (_, leftover_tc) in potentials.copy():
# Don't record this as errory. It's just a spare connection
# from the same side as a connection that got used. This
# can happen if the connection hint contains multiple
# addresses (we don't currently support those, but it'd
# probably be useful in the future).
leftover_tc.disconnect_redundant()
self._pending_requests.pop(token, None)
# glue the two ends together
self._active_connections.add(new_tc)
self._active_connections.add(old_tc)
new_tc.buddy_connected(old_tc)
old_tc.buddy_connected(new_tc)
return
if self._debug_log:
log.msg("transit relay 1: %s" % new_tc.describeToken())
potentials.add((new_side, new_tc))
# TODO: timer
def transitFinished(self, tc, token, side, description):
if token in self._pending_requests:
side_tc = (side, tc)
self._pending_requests[token].discard(side_tc)
if not self._pending_requests[token]: # set is now empty
del self._pending_requests[token]
if self._debug_log:
log.msg("transitFinished %s" % (description,))
self._active_connections.discard(tc)
# we could update the usage database "current" row immediately, or wait
# until the 5-minute timer updates it. If we update it now, just after
# losing a connection, we should probably also update it just after
# establishing one (at the end of connection_got_token). For now I'm
# going to omit these, but maybe someday we'll turn them both on. The
# consequence is that a manual execution of the munin scripts ("munin
# run wormhole_transit_active") will give the wrong value just after a
# connect/disconnect event. Actual munin graphs should accurately
# report connections that last longer than the 5-minute sampling
# window, which is what we actually care about.
#self.timerUpdateStats()
def recordUsage(self, started, result, total_bytes,
total_time, waiting_time):
if self._debug_log:
log.msg(format="Transit.recordUsage {bytes}B", bytes=total_bytes)
if self._blur_usage:
started = self._blur_usage * (started // self._blur_usage)
total_bytes = blur_size(total_bytes)
if self._log_file is not None:
data = {"started": started,
"total_time": total_time,
"waiting_time": waiting_time,
"total_bytes": total_bytes,
"mood": result,
}
self._log_file.write(json.dumps(data)+"\n")
self._log_file.flush()
if self._db:
self._db.execute("INSERT INTO `usage`"
" (`started`, `total_time`, `waiting_time`,"
" `total_bytes`, `result`)"
" VALUES (?,?,?, ?,?)",
(started, total_time, waiting_time,
total_bytes, result))
self._update_stats()
self._db.commit()
def timerUpdateStats(self):
if self._db:
self._update_stats()
self._db.commit()
def _update_stats(self):
# current status: should be zero when idle
rebooted = self._rebooted
updated = time.time()
connected = len(self._active_connections) / 2
# TODO: when a connection is half-closed, len(active) will be odd. a # TODO: when a connection is half-closed, len(active) will be odd. a
# moment later (hopefully) the other side will disconnect, but # moment later (hopefully) the other side will disconnect, but
# _update_stats isn't updated until later. # _update_stats isn't updated until later.
waiting = len(self._pending_requests)
# "waiting" doesn't count multiple parallel connections from the same # "waiting" doesn't count multiple parallel connections from the same
# side # side
incomplete_bytes = sum(tc._total_sent self.usage.update_stats(
for tc in self._active_connections) rebooted=self._rebooted,
self._db.execute("DELETE FROM `current`") updated=self._timestamp(),
self._db.execute("INSERT INTO `current`" connected=len(self.active_connections._connections),
" (`rebooted`, `updated`, `connected`, `waiting`," waiting=len(self.pending_requests._requests),
" `incomplete_bytes`)" incomplete_bytes=sum(
" VALUES (?, ?, ?, ?, ?)", tc._total_sent
(rebooted, updated, connected, waiting, for tc in self.active_connections._connections
incomplete_bytes)) ),
)
@implementer(ITransitClient)
class WebSocketTransitConnection(WebSocketServerProtocol):
started_time = None
def send(self, data):
"""
ITransitClient API
"""
self.sendMessage(data, isBinary=True)
def disconnect(self):
"""
ITransitClient API
"""
self.sendClose(1000, None)
def connect_partner(self, other):
"""
ITransitClient API
"""
self._buddy = other
def disconnect_partner(self):
"""
ITransitClient API
"""
assert self._buddy is not None, "internal error: no buddy"
if self.factory.log_requests:
log.msg("buddy_disconnected {}".format(self._buddy.get_token()))
self._buddy._client.disconnect()
self._buddy = None
def connectionMade(self):
"""
IProtocol API
"""
super(WebSocketTransitConnection, self).connectionMade()
self.started_time = time.time()
self._first_message = True
self._state = TransitServerState(
self.factory.transit.pending_requests,
self.factory.transit.usage,
)
# uncomment to turn on state-machine tracing
# def tracer(oldstate, theinput, newstate):
# print("WSTRACE: {}: {} --{}--> {}".format(id(self), oldstate, theinput, newstate))
# self._state.set_trace_function(tracer)
def onOpen(self):
self._state.connection_made(self)
def onMessage(self, payload, isBinary):
"""
We may have a 'handshake' on our hands or we may just have some bytes to relay
"""
if not isBinary:
raise ValueError(
"All messages must be binary"
)
if self._first_message:
self._first_message = False
token = None
old = re.search(br"^please relay (\w{64})$", payload)
if old:
token = old.group(1)
self._state.please_relay(token)
# new: "please relay {64} for side {16}\n"
new = re.search(br"^please relay (\w{64}) for side (\w{16})$", payload)
if new:
token = new.group(1)
side = new.group(2)
self._state.please_relay_for_side(token, side)
if token is None:
self._state.bad_token()
else:
self._state.got_bytes(payload)
def onClose(self, wasClean, code, reason):
"""
IWebSocketChannel API
"""
self._state.connection_lost()

View File

@ -0,0 +1,238 @@
import time
import json
from twisted.python import log
from zope.interface import (
implementer,
Interface,
)
def create_usage_tracker(blur_usage, log_file, usage_db):
"""
:param int blur_usage: see UsageTracker
:param log_file: None or a file-like object to write JSON-encoded
lines of usage information to.
:param usage_db: None or an sqlite3 database connection
:returns: a new UsageTracker instance configured with backends.
"""
tracker = UsageTracker(blur_usage)
if usage_db:
tracker.add_backend(DatabaseUsageRecorder(usage_db))
if log_file:
tracker.add_backend(LogFileUsageRecorder(log_file))
return tracker
class IUsageWriter(Interface):
"""
Records actual usage statistics in some way
"""
def record_usage(started=None, total_time=None, waiting_time=None, total_bytes=None, mood=None):
"""
:param int started: timestemp when this connection began
:param float total_time: total seconds this connection lasted
:param float waiting_time: None or the total seconds one side
waited for the other
:param int total_bytes: the total bytes sent. In case the
connection was concluded successfully, only one side will
record the total bytes (but count both).
:param str mood: the 'mood' of the connection
"""
@implementer(IUsageWriter)
class MemoryUsageRecorder:
"""
Remebers usage records in memory.
"""
def __init__(self):
self.events = []
def record_usage(self, started=None, total_time=None, waiting_time=None, total_bytes=None, mood=None):
"""
IUsageWriter.
"""
data = {
"started": started,
"total_time": total_time,
"waiting_time": waiting_time,
"total_bytes": total_bytes,
"mood": mood,
}
self.events.append(data)
@implementer(IUsageWriter)
class LogFileUsageRecorder:
"""
Writes usage records to a file. The records are written in JSON,
one record per line.
"""
def __init__(self, writable_file):
self._file = writable_file
def record_usage(self, started=None, total_time=None, waiting_time=None, total_bytes=None, mood=None):
"""
IUsageWriter.
"""
data = {
"started": started,
"total_time": total_time,
"waiting_time": waiting_time,
"total_bytes": total_bytes,
"mood": mood,
}
self._file.write(json.dumps(data) + "\n")
self._file.flush()
@implementer(IUsageWriter)
class DatabaseUsageRecorder:
"""
Write usage records into a database
"""
def __init__(self, db):
self._db = db
def record_usage(self, started=None, total_time=None, waiting_time=None, total_bytes=None, mood=None):
"""
IUsageWriter.
"""
self._db.execute(
"INSERT INTO `usage`"
" (`started`, `total_time`, `waiting_time`,"
" `total_bytes`, `result`)"
" VALUES (?,?,?,?,?)",
(started, total_time, waiting_time, total_bytes, mood)
)
# original code did "self._update_stats()" here, thus causing
# "global" stats update on every connection update .. should
# we repeat this behavior, or really only record every
# 60-seconds with the timer?
self._db.commit()
class UsageTracker(object):
"""
Tracks usage statistics of connections
"""
def __init__(self, blur_usage):
"""
:param int blur_usage: None or the number of seconds to use as a
window around which to blur time statistics (e.g. "60" means times
will be rounded to 1 minute intervals). When blur_usage is
non-zero, sizes will also be rounded into buckets of "one
megabyte", "one gigabyte" or "lots"
"""
self._backends = set()
self._blur_usage = blur_usage
if blur_usage:
log.msg("blurring access times to %d seconds" % self._blur_usage)
else:
log.msg("not blurring access times")
def add_backend(self, backend):
"""
Add a new backend.
:param IUsageWriter backend: the backend to add
"""
self._backends.add(backend)
def record(self, started, buddy_started, result, bytes_sent, buddy_bytes):
"""
:param int started: timestamp when our connection started
:param int buddy_started: None, or the timestamp when our
partner's connection started (will be None if we don't yet
have a partner).
:param str result: a label for the result of the connection
(one of the "moods").
:param int bytes_sent: number of bytes we sent
:param int buddy_bytes: number of bytes our partner sent
"""
# ideally self._reactor.seconds() or similar, but ..
finished = time.time()
if buddy_started is not None:
starts = [started, buddy_started]
total_time = finished - min(starts)
waiting_time = max(starts) - min(starts)
total_bytes = bytes_sent + buddy_bytes
else:
total_time = finished - started
waiting_time = None
total_bytes = bytes_sent
# note that "bytes_sent" should always be 0 here, but
# we're recording what the state-machine remembered in any
# case
if self._blur_usage:
started = self._blur_usage * (started // self._blur_usage)
total_bytes = blur_size(total_bytes)
# This is "a dict" instead of "kwargs" because we have to make
# it into a dict for the log use-case and in-memory/testing
# use-case anyway so this is less repeats of the names.
self._notify_backends({
"started": started,
"total_time": total_time,
"waiting_time": waiting_time,
"total_bytes": total_bytes,
"mood": result,
})
def update_stats(self, rebooted, updated, connected, waiting,
incomplete_bytes):
"""
Update general statistics.
"""
# in original code, this is only recorded in the database
# .. perhaps a better way to do this, but ..
for backend in self._backends:
if isinstance(backend, DatabaseUsageRecorder):
backend._db.execute("DELETE FROM `current`")
backend._db.execute(
"INSERT INTO `current`"
" (`rebooted`, `updated`, `connected`, `waiting`,"
" `incomplete_bytes`)"
" VALUES (?, ?, ?, ?, ?)",
(int(rebooted), int(updated), connected, waiting,
incomplete_bytes)
)
def _notify_backends(self, data):
"""
Internal helper. Tell every backend we have about a new usage record.
"""
for backend in self._backends:
backend.record_usage(**data)
def round_to(size, coarseness):
return int(coarseness*(1+int((size-1)/coarseness)))
def blur_size(size):
if size == 0:
return 0
if size < 1e6:
return round_to(size, 10e3)
if size < 1e9:
return round_to(size, 1e6)
return round_to(size, 100e6)

82
ws_client.py Normal file
View File

@ -0,0 +1,82 @@
"""
This is a test-client for the transit-relay that uses WebSockets.
If an additional command-line argument (anything) is added, it will
send 5 messages upon connection. Otherwise, it just prints out what is
received. Uses a fixed token of 64 'a' characters. Always connects on
localhost:4002
"""
import sys
from twisted.internet import endpoints
from twisted.internet.defer import (
Deferred,
inlineCallbacks,
)
from twisted.internet.task import react, deferLater
from autobahn.twisted.websocket import (
WebSocketClientProtocol,
WebSocketClientFactory,
)
class RelayEchoClient(WebSocketClientProtocol):
def onOpen(self):
self._received = b""
self.sendMessage(
u"please relay {} for side {}".format(
self.factory.token,
self.factory.side,
).encode("ascii"),
True,
)
def onMessage(self, data, isBinary):
print(">onMessage: {} bytes".format(len(data)))
print(data, isBinary)
if data == b"ok\n":
self.factory.ready.callback(None)
else:
self._received += data
if False:
# test abrupt hangup from receiving side
self.transport.loseConnection()
def onClose(self, wasClean, code, reason):
print(">onClose", wasClean, code, reason)
self.factory.done.callback(reason)
if not self.factory.ready.called:
self.factory.ready.errback(RuntimeError(reason))
@react
@inlineCallbacks
def main(reactor):
will_send_message = len(sys.argv) > 1
ep = endpoints.clientFromString(reactor, "tcp:localhost:4002")
f = WebSocketClientFactory("ws://127.0.0.1:4002/")
f.reactor = reactor
f.protocol = RelayEchoClient
f.token = "a" * 64
f.side = "0" * 16 if will_send_message else "1" * 16
f.done = Deferred()
f.ready = Deferred()
proto = yield ep.connect(f)
print("proto", proto)
yield f.ready
print("ready")
if will_send_message:
for _ in range(5):
print("sending message")
proto.sendMessage(b"it's a message", True)
yield deferLater(reactor, 0.2)
yield proto.sendClose()
print("closing")
yield f.done
print("relayed {} bytes:".format(len(proto._received)))
print(proto._received.decode("utf8"))