magic-wormhole-transit-relay/src/wormhole_transit_relay/transit_server.py
2021-04-14 15:05:44 -06:00

287 lines
11 KiB
Python

from __future__ import print_function, unicode_literals
import re, time, json
from collections import defaultdict
from twisted.python import log
from twisted.internet import protocol
from twisted.protocols.basic import LineReceiver
# Time units, expressed in (floating-point) seconds.
SECONDS = 1.0
MINUTE = 60 * SECONDS
HOUR = 60 * MINUTE
DAY = 24 * HOUR

# Decimal megabyte (10**6 bytes), kept as an integer.
MB = 1000 * 1000
from wormhole_transit_relay.server_state import (
TransitServerState,
PendingRequests,
ActiveConnections,
UsageTracker,
DatabaseUsageRecorder,
LogFileUsageRecorder,
ITransitClient,
)
from zope.interface import implementer
@implementer(ITransitClient)
class TransitConnection(LineReceiver):
    """
    Protocol for a single client connection to the transit relay.

    The connection starts in line mode and expects exactly one
    handshake line: ``please relay TOKEN`` (legacy) or ``please relay
    TOKEN for side SIDE``. A valid handshake is forwarded to this
    connection's TransitServerState and the protocol switches to raw
    mode, after which every received chunk is handed to the state
    object. Anything else on the first line is reported to the state
    object as a bad token.
    """
    delimiter = b'\n'
    # maximum length of a line we will accept before the handshake is complete.
    # This must be >= to the longest possible handshake message.
    MAX_LENGTH = 1024
    started_time = None

    # Class-level defaults so that methods which read these attributes
    # (describeToken, disconnect_partner) work even before a handshake
    # or a partner has been seen.
    _buddy = None
    _got_token = None
    _got_side = None

    # Matches both handshake forms; the "for side" part is optional:
    #   old: "please relay {64}\n"
    #   new: "please relay {64} for side {16}\n"
    _HANDSHAKE_RE = re.compile(
        br"^please relay (\w{64})(?: for side (\w{16}))?$"
    )

    def send(self, data):
        """
        ITransitClient API

        Write `data` to this connection's transport.
        """
        self.transport.write(data)

    def disconnect(self):
        """
        ITransitClient API

        Close this connection's transport.
        """
        self.transport.loseConnection()

    def connect_partner(self, other):
        """
        ITransitClient API

        Remember `other` as our partner. `other` must expose a
        `_client` attribute with a `transport` (see disconnect_partner).
        """
        self._buddy = other

    def disconnect_partner(self):
        """
        ITransitClient API

        Close the partner's transport and forget the partner. Does
        nothing if there is no partner (never connected, or already
        disconnected).
        """
        buddy = self._buddy
        if buddy is None:
            return
        buddy._client.transport.loseConnection()
        self._buddy = None

    def describeToken(self):
        """
        Return a short text description of this connection's handshake.

        "-" if no token has been received; otherwise the first 16
        characters of the token followed by "-SIDE", or "-<unsided>"
        when the client used the legacy (side-less) handshake.
        """
        if not self._got_token:
            return "-"
        token = self._got_token[:16].decode("ascii")
        if self._got_side:
            side = self._got_side.decode("ascii")
        else:
            side = "<unsided>"
        return token + "-" + side

    def connectionMade(self):
        """
        Record the start time and create this connection's state object.
        """
        # ideally more like self._reactor.seconds() ... but Twisted
        # doesn't have a good way to get the reactor for a protocol
        # (besides "use the global one")
        self.started_time = time.time()
        self._state = TransitServerState(
            self.factory.pending_requests,
            self.factory.usage,
        )
        self._state.connection_made(self)
        try:
            self.transport.setTcpKeepAlive(True)
        except AttributeError:
            # not every transport offers setTcpKeepAlive(); keep-alive
            # is best-effort
            pass

    def lineReceived(self, line):
        """
        Handle the (single) handshake line.

        A well-formed handshake is passed to _got_handshake(), with
        side=None for the legacy form. We should have been switched to
        "raw data" mode on the first line received (after which
        rawDataReceived() is called for all bytes), so any line that
        does not match means a bad handshake.
        """
        handshake = self._HANDSHAKE_RE.match(line)
        if handshake is None:
            return self._state.bad_token()
        token, side = handshake.groups()
        return self._got_handshake(token, side)

    def rawDataReceived(self, data):
        """
        Hand post-handshake bytes to the state object.
        """
        # We are an IPushProducer to our buddy's IConsumer, so they'll
        # throttle us (by calling pauseProducing()) when their outbound
        # buffer is full (e.g. when their downstream pipe is full). In
        # practice, this buffers about 10MB per connection, after which
        # point the sender will only transmit data as fast as the
        # receiver can handle it.
        self._state.got_bytes(data)

    def _got_handshake(self, token, side):
        """
        Record a valid handshake, notify the state object and switch to
        raw mode.

        `side` is None for the legacy handshake.
        """
        # remembered so describeToken() can report them
        self._got_token = token
        self._got_side = side
        self._state.please_relay_for_side(token, side)
        self.setRawMode()

    def __buddy_connected(self, them):
        """
        Pair with `them`, send "ok" and register as their producer.

        NOTE(review): the name is mangled by the leading double
        underscore and nothing in this class calls it -- looks like
        retained legacy code; confirm before relying on it.
        """
        self._buddy = them
        self._mood = "happy"
        self.sendLine(b"ok")
        self._sent_ok = True
        # Connect the two as a producer/consumer pair. We use streaming=True,
        # so this expects the IPushProducer interface, and uses
        # pauseProducing() to throttle, and resumeProducing() to unthrottle.
        self._buddy.transport.registerProducer(self.transport, True)
        # The Transit object calls buddy_connected() on both protocols, so
        # there will be two producer/consumer pairs.

    def __buddy_disconnected(self):
        """
        Forget the partner and close our own transport.

        NOTE(review): name-mangled and not called from this class; see
        __buddy_connected.
        """
        self._buddy = None
        self._mood = "jilted"
        self.transport.loseConnection()

    def disconnect_error(self):
        """
        Close the connection after a failed handshake.
        """
        # we haven't finished the handshake, so there are no tokens tracking
        # us
        self._mood = "errory"
        self.transport.loseConnection()
        if self.factory._debug_log:
            log.msg("transitFailed %r" % self)

    def disconnect_redundant(self):
        """
        Close the connection because another connection was chosen.
        """
        # this is called if a buddy connected and we were found unnecessary.
        # Any token-tracking cleanup will have been done before we're called.
        self._mood = "redundant"
        self.transport.loseConnection()

    def connectionLost(self, reason):
        """
        Tell the state object that this connection has gone away.
        """
        self._state.connection_lost()
        # XXX FIXME record usage
        #
        # Usage is not recorded here yet. The intended design has eight
        # cases, keyed on the connection's "mood":
        # * n0: we haven't gotten a full handshake yet (empty)
        # * n1: the handshake failed, not a real client (errory)
        # * n2: real client disconnected before any buddy appeared (lonely)
        # * n3: real client closed as redundant after buddy appears (redundant)
        # * n4: real client connected first, buddy closes first (jilted)
        # * n5: real client connected first, buddy close last (happy)
        # * n6: real client connected last, buddy closes first (jilted)
        # * n7: real client connected last, buddy closes last (happy)
        # * non-connected clients (0,1,2,3) always write a usage record
        #   (zero bytes, total time, no waiting time)
        # * for connected clients, whoever disconnects first gets to write the
        #   usage record (5, 7), covering both sides' bytes, the total time
        #   since the earlier start and the wait between the two starts.
        #   The last disconnect doesn't write a record.
class Transit(protocol.ServerFactory):
    """
    I manage pairs of simultaneous connections to a secondary TCP port,
    both forwarded to the other. Clients must begin each connection with
    "please relay TOKEN for SIDE\n" (or a legacy form without the "for
    SIDE"). Two connections match if they use the same TOKEN and have
    different SIDEs (the redundant connections are dropped when a match is
    made). Legacy connections match any with the same TOKEN, ignoring SIDE
    (so two legacy connections will match each other).

    I will send "ok\n" when the matching connection is established, or
    disconnect if no matching connection is made within MAX_WAIT_TIME
    seconds. I will disconnect if you send data before the "ok\n". All data
    you get after the "ok\n" will be from the other side. You will not
    receive "ok\n" until the other side has also connected and submitted a
    matching token (and differing SIDE).

    In addition, the connections will be dropped after MAXLENGTH bytes have
    been sent by either side, or MAXTIME seconds have elapsed after the
    matching connections were established. A future API will reveal these
    limits to clients instead of causing mysterious spontaneous failures.

    These relay connections are not half-closeable (unlike full TCP
    connections, applications will not receive any data after half-closing
    their outgoing side). Applications must negotiate shutdown with their
    peer and not close the connection until all data has finished
    transferring in both directions. Applications which only need to send
    data in one direction can use close() as usual.

    NOTE(review): MAX_WAIT_TIME, MAXLENGTH and MAXTIME are marked as
    unused below, so the timeout and size limits described above may not
    currently be enforced.
    """
    # TODO: unused
    MAX_WAIT_TIME = 30*SECONDS
    # TODO: unused
    MAXLENGTH = 10*MB
    # TODO: unused
    MAXTIME = 60*SECONDS
    protocol = TransitConnection

    def __init__(self, usage):
        """
        Create the shared bookkeeping objects for all connections.

        `usage` is stored as `self.usage` and handed to each
        connection's TransitServerState.
        """
        self.active_connections = ActiveConnections()
        self.pending_requests = PendingRequests(self.active_connections)
        self.usage = usage
        # TODO: log whether access times are being blurred (and by how
        # many seconds). These log messages should be made by the
        # usage-tracker .. or in the "tap" setup?
        self._debug_log = False
        self._rebooted = time.time()

    # XXX TODO self._rebooted and the below could be in a separate
    # object? or in the DatabaseUsageRecorder .. but not here
    def _update_stats(self):
        """
        Replace the contents of the `current` table with one fresh row
        of (rebooted, updated, connected, waiting, incomplete_bytes).

        NOTE(review): this reads self._active_connections,
        self._pending_requests and self._db, none of which are set by
        __init__ (which sets active_connections / pending_requests
        without the underscore and no database handle) -- confirm where
        these are meant to come from before calling this.
        """
        # current status: should be zero when idle
        rebooted = self._rebooted
        updated = time.time()
        # number of connected pairs; floor division keeps this an
        # integer on both Python 2 and Python 3
        connected = len(self._active_connections) // 2
        # TODO: when a connection is half-closed, len(active) will be odd. a
        # moment later (hopefully) the other side will disconnect, but
        # _update_stats isn't updated until later.
        waiting = len(self._pending_requests)
        # "waiting" doesn't count multiple parallel connections from the same
        # side
        incomplete_bytes = sum(tc._total_sent
                               for tc in self._active_connections)
        self._db.execute("DELETE FROM `current`")
        self._db.execute("INSERT INTO `current`"
                         " (`rebooted`, `updated`, `connected`, `waiting`,"
                         " `incomplete_bytes`)"
                         " VALUES (?, ?, ?, ?, ?)",
                         (rebooted, updated, connected, waiting,
                          incomplete_bytes))