cleanup, remove dead code

meejah 2021-02-12 16:35:20 -07:00
parent b7bcdfdca3
commit 3ae3bb7443
6 changed files with 131 additions and 126 deletions


@@ -8,6 +8,7 @@ from zope.interface import (
     Attribute,
     implementer,
 )
+from .database import get_db
 
 class ITransitClient(Interface):
@@ -167,12 +168,41 @@ def blur_size(size):
     return round_to(size, 100e6)
 
+
+def create_usage_tracker(blur_usage, log_file, usage_db):
+    """
+    :param int blur_usage: see UsageTracker
+    :param log_file: None or a file-like object to write JSON-encoded
+        lines of usage information to.
+    :param usage_db: None or an sqlite3 database connection
+    :returns: a new UsageTracker instance configured with backends.
+    """
+    tracker = UsageTracker(blur_usage)
+    if usage_db:
+        db = get_db(usage_db)
+        tracker.add_backend(DatabaseUsageRecorder(db))
+    if log_file:
+        tracker.add_backend(LogFileUsageRecorder(log_file))
+    return tracker
+
+
 class UsageTracker(object):
     """
     Tracks usage statistics of connections
     """
 
     def __init__(self, blur_usage):
+        """
+        :param int blur_usage: None or the number of seconds to use as a
+            window around which to blur time statistics (e.g. "60" means times
+            will be rounded to 1 minute intervals). When blur_usage is
+            non-zero, sizes will also be rounded into buckets of "one
+            megabyte", "one gigabyte" or "lots".
+        """
         self._backends = set()
         self._blur_usage = blur_usage
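The new blur_usage docstring describes two distinct roundings: times are rounded down to the window boundary, sizes are rounded up into coarse buckets. A small self-contained sketch of that behavior; the round_to helper below mirrors the one this module already uses in blur_size(), and the example numbers are arbitrary:

    def round_to(size, coarseness):
        # round up to the next multiple of "coarseness"
        return int(coarseness * (1 + int((size - 1) // coarseness)))

    blur_usage = 60
    started = 1613172920.0
    # times are rounded *down* to the start of the window, as in record():
    assert blur_usage * (started // blur_usage) == 1613172900.0
    # sizes are rounded *up* into coarse buckets, as in blur_size():
    assert round_to(123_456, 10e3) == 130000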
@@ -223,6 +253,9 @@ class UsageTracker(object):
             started = self._blur_usage * (started // self._blur_usage)
             total_bytes = blur_size(total_bytes)
 
+        # This is "a dict" instead of "kwargs" because we have to make
+        # it into a dict for the log use-case and the in-memory/testing
+        # use-case anyway, so this avoids repeating the names.
         self._notify_backends({
             "started": started,
             "total_time": total_time,
@@ -233,7 +266,7 @@
     def _notify_backends(self, data):
         """
-        Internal helper. Tell every backend we have about a new usage.
+        Internal helper. Tell every backend we have about a new usage record.
         """
         for backend in self._backends:
             backend.record_usage(**data)
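Because _notify_backends() fans the dict out as keyword arguments, a backend only needs a record_usage() method accepting those keys. A minimal in-memory backend as a sketch; this class is illustrative, not part of the commit, and assumes the module's UsageTracker is importable as shown:

    from wormhole_transit_relay.server_state import UsageTracker

    class MemoryUsageRecorder:
        """Collects usage records in a list; handy for tests."""

        def __init__(self):
            self.events = []

        def record_usage(self, started=None, total_time=None,
                         waiting_time=None, total_bytes=None, mood=None):
            # same keyword names as the dict built in record() above
            self.events.append({
                "started": started,
                "total_time": total_time,
                "waiting_time": waiting_time,
                "total_bytes": total_bytes,
                "mood": mood,
            })

    tracker = UsageTracker(blur_usage=None)
    tracker.add_backend(MemoryUsageRecorder())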
@@ -241,8 +274,11 @@
 class ActiveConnections(object):
     """
-    Tracks active connections. A connection is 'active' when both
-    sides have shown up and they are glued together.
+    Tracks active connections.
+
+    A connection is 'active' when both sides have shown up and they
+    are glued together (and thus could be passing data back and forth
+    if any is flowing).
     """
     def __init__(self):
         self._connections = set()
@@ -268,12 +304,20 @@ class ActiveConnections(object):
 class PendingRequests(object):
     """
-    Tracks the tokens we have received from client connections and
-    maps them to their partner connections for requests that haven't
-    yet been 'glued together' (that is, one side hasn't yet shown up).
+    Tracks outstanding (non-"active") requests.
+
+    We register client connections against the tokens we have
+    received. When the other side shows up we can match it to the
+    correct partner connection. At that point the connection becomes
+    "active" and is thus no longer "pending", so it will no longer
+    be in this collection.
     """
     def __init__(self, active_connections):
+        """
+        :param active_connections: an instance of ActiveConnections where
+            connections are put when both sides arrive.
+        """
         self._requests = defaultdict(set) # token -> set((side, TransitConnection))
         self._active = active_connections
@@ -285,16 +329,23 @@ class PendingRequests(object):
         if token in self._requests:
             self._requests[token].discard((side, tc))
             if not self._requests[token]:
+                # no more sides; token is dead
                 del self._requests[token]
         self._active.unregister(tc)
 
-    def register_token(self, token, new_side, new_tc):
+    def register(self, token, new_side, new_tc):
         """
         A client has connected and successfully offered a token (and
         optional 'side' token). If this is the first one for this
         token, we merely remember it. If it is the second side for
         this token we connect them together.
 
+        :param bytes token: the token for this connection.
+        :param bytes new_side: None or the side token for this connection.
+        :param TransitServerState new_tc: the state-machine of the connection.
+
         :returns bool: True if we are the first side to register this
             token
         """
@@ -562,7 +613,7 @@ class TransitServerState(object):
         """
         self._token = token
         self._side = side
-        self._first = self._pending_requests.register_token(token, side, self)
+        self._first = self._pending_requests.register(token, side, self)
 
     @_machine.state(initial=True)
     def listening(self):
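Taken together, this file's changes move backend wiring out of the Transit factory and into create_usage_tracker(), so call sites build the tracker first and hand it in. A sketch of what a call site looks like after this commit; the sqlite filename is illustrative:

    import io
    from wormhole_transit_relay.server_state import create_usage_tracker

    log_file = io.StringIO()  # any file-like object works
    tracker = create_usage_tracker(
        blur_usage=60,              # round times to 1-minute windows
        log_file=log_file,          # JSON-lines backend
        usage_db="usage.sqlite",    # path handed to get_db()
    )
    assert len(tracker._backends) == 2  # the tests below poke this attribute too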


@@ -6,6 +6,7 @@ from twisted.application.internet import (TimerService,
                                            StreamServerEndpointService)
 from twisted.internet import endpoints
 from . import transit_server
+from .server_state import create_usage_tracker
 from .increase_rlimits import increase_rlimits
 
 LONGDESC = """\
@@ -32,13 +33,18 @@ class Options(usage.Options):
 
 def makeService(config, reactor=reactor):
     increase_rlimits()
     ep = endpoints.serverFromString(reactor, config["port"]) # to listen
-    log_file = (os.fdopen(int(config["log-fd"]), "w")
-                if config["log-fd"] is not None
-                else None)
-    f = transit_server.Transit(blur_usage=config["blur-usage"],
-                               log_file=log_file,
-                               usage_db=config["usage-db"])
+    log_file = (
+        os.fdopen(int(config["log-fd"]), "w")
+        if config["log-fd"] is not None
+        else None
+    )
+    usage = create_usage_tracker(
+        blur_usage=config["blur-usage"],
+        log_file=log_file,
+        usage_db=config["usage-db"],
+    )
+    factory = transit_server.Transit(usage)
     parent = MultiService()
-    StreamServerEndpointService(ep, f).setServiceParent(parent)
-    TimerService(5*60.0, f.timerUpdateStats).setServiceParent(parent)
+    StreamServerEndpointService(ep, factory).setServiceParent(parent)
+    ### FIXME TODO TimerService(5*60.0, factory.timerUpdateStats).setServiceParent(parent)
     return parent
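The tests below exercise exactly this wiring. Outside of twistd, the same path can be driven by hand; a sketch using the option defaults, as test_defaults does:

    from wormhole_transit_relay import server_tap

    options = server_tap.Options()
    options.parseOptions([])       # all defaults, as in test_defaults below
    service = server_tap.makeService(options)
    service.startService()         # normally twistd drives this part
    # ... followed by reactor.run() to actually start serving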


@@ -11,6 +11,8 @@ from zope.interface import (
 from ..transit_server import (
     Transit,
 )
+from ..transit_server import Transit
+from ..server_state import create_usage_tracker
 
 class IRelayTestClient(Interface):
@@ -42,6 +44,7 @@ class IRelayTestClient(Interface):
         Erase any received data to this point.
         """
 
+
 class ServerBase:
     log_requests = False
@@ -62,11 +65,12 @@ class ServerBase:
         self.flush()
 
     def _setup_relay(self, blur_usage=None, log_file=None, usage_db=None):
-        self._transit_server = Transit(
+        usage = create_usage_tracker(
             blur_usage=blur_usage,
             log_file=log_file,
             usage_db=usage_db,
         )
+        self._transit_server = Transit(usage)
         self._transit_server._debug_log = self.log_requests
 
     def new_protocol(self):


@@ -11,7 +11,7 @@ class Service(unittest.TestCase):
     def test_defaults(self):
         o = server_tap.Options()
         o.parseOptions([])
-        with mock.patch("wormhole_transit_relay.server_tap.transit_server.Transit") as t:
+        with mock.patch("wormhole_transit_relay.server_tap.create_usage_tracker") as t:
             s = server_tap.makeService(o)
         self.assertEqual(t.mock_calls,
                          [mock.call(blur_usage=None,
@@ -21,7 +21,7 @@ class Service(unittest.TestCase):
     def test_blur(self):
         o = server_tap.Options()
         o.parseOptions(["--blur-usage=60"])
-        with mock.patch("wormhole_transit_relay.server_tap.transit_server.Transit") as t:
+        with mock.patch("wormhole_transit_relay.server_tap.create_usage_tracker") as t:
             server_tap.makeService(o)
         self.assertEqual(t.mock_calls,
                          [mock.call(blur_usage=60,
@@ -31,7 +31,7 @@
         o = server_tap.Options()
         o.parseOptions(["--log-fd=99"])
         fd = object()
-        with mock.patch("wormhole_transit_relay.server_tap.transit_server.Transit") as t:
+        with mock.patch("wormhole_transit_relay.server_tap.create_usage_tracker") as t:
             with mock.patch("wormhole_transit_relay.server_tap.os.fdopen",
                             return_value=fd) as f:
                 server_tap.makeService(o)


@@ -6,6 +6,7 @@ except ImportError:
     import mock
 from twisted.trial import unittest
 from ..transit_server import Transit
+from ..server_state import create_usage_tracker
 from .. import database
 
 class DB(unittest.TestCase):
@@ -20,7 +21,7 @@ class DB(unittest.TestCase):
         os.mkdir(d)
         usage_db = os.path.join(d, "usage.sqlite")
         with mock.patch("time.time", return_value=T+0):
-            t = Transit(blur_usage=None, log_file=None, usage_db=usage_db)
+            t = Transit(create_usage_tracker(blur_usage=None, log_file=None, usage_db=usage_db))
         db = self.open_db(usage_db)
 
         usage = list(t.usage._backends)[0]
@@ -58,7 +59,7 @@ class DB(unittest.TestCase):
                              waiting=0, connected=0))
 
     def test_no_db(self):
-        t = Transit(blur_usage=None, log_file=None, usage_db=None)
+        t = Transit(create_usage_tracker(blur_usage=None, log_file=None, usage_db=None))
         self.assertEqual(0, len(t.usage._backends))
@@ -66,7 +67,7 @@ class LogToStdout(unittest.TestCase):
     def test_log(self):
         # emit lines of JSON to log_file, if set
         log_file = io.StringIO()
-        t = Transit(blur_usage=None, log_file=log_file, usage_db=None)
+        t = Transit(create_usage_tracker(blur_usage=None, log_file=log_file, usage_db=None))
         with mock.patch("time.time", return_value=133):
             t.usage.record(
                 started=123,
@@ -84,7 +85,7 @@ class LogToStdout(unittest.TestCase):
         # if blurring is enabled, timestamps should be rounded to the
         # requested amount, and sizes should be rounded up too
         log_file = io.StringIO()
-        t = Transit(blur_usage=60, log_file=log_file, usage_db=None)
+        t = Transit(create_usage_tracker(blur_usage=60, log_file=log_file, usage_db=None))
 
         with mock.patch("time.time", return_value=123 + 10):
             t.usage.record(
@@ -101,7 +102,7 @@ class LogToStdout(unittest.TestCase):
                          "mood": "happy"})
 
     def test_do_not_log(self):
-        t = Transit(blur_usage=60, log_file=None, usage_db=None)
+        t = Transit(create_usage_tracker(blur_usage=60, log_file=None, usage_db=None))
         t.usage.record(
             started=123,
             buddy_started=124,


@@ -4,7 +4,6 @@ from collections import defaultdict
 from twisted.python import log
 from twisted.internet import protocol
 from twisted.protocols.basic import LineReceiver
-from .database import get_db
 
 SECONDS = 1.0
 MINUTE = 60*SECONDS
@@ -79,7 +78,7 @@ class TransitConnection(LineReceiver):
             self.factory.usage,
         )
         self._state.connection_made(self)
-        self._log_requests = self.factory._log_requests
+        ## self._log_requests = self.factory._log_requests
         try:
             self.transport.setTcpKeepAlive(True)
         except AttributeError:
@@ -131,8 +130,8 @@ class TransitConnection(LineReceiver):
     # there will be two producer/consumer pairs.
 
     def __buddy_disconnected(self):
-        if self._log_requests:
-            log.msg("buddy_disconnected %s" % self.describeToken())
+        ## if self._log_requests:
+        ##     log.msg("buddy_disconnected %s" % self.describeToken())
         self._buddy = None
         self._mood = "jilted"
         self.transport.loseConnection()
@@ -210,117 +209,61 @@ class TransitConnection(LineReceiver):
 class Transit(protocol.ServerFactory):
-    # I manage pairs of simultaneous connections to a secondary TCP port,
-    # both forwarded to the other. Clients must begin each connection with
-    # "please relay TOKEN for SIDE\n" (or a legacy form without the "for
-    # SIDE"). Two connections match if they use the same TOKEN and have
-    # different SIDEs (the redundant connections are dropped when a match is
-    # made). Legacy connections match any with the same TOKEN, ignoring SIDE
-    # (so two legacy connections will match each other).
-
-    # I will send "ok\n" when the matching connection is established, or
-    # disconnect if no matching connection is made within MAX_WAIT_TIME
-    # seconds. I will disconnect if you send data before the "ok\n". All data
-    # you get after the "ok\n" will be from the other side. You will not
-    # receive "ok\n" until the other side has also connected and submitted a
-    # matching token (and differing SIDE).
-
-    # In addition, the connections will be dropped after MAXLENGTH bytes have
-    # been sent by either side, or MAXTIME seconds have elapsed after the
-    # matching connections were established. A future API will reveal these
-    # limits to clients instead of causing mysterious spontaneous failures.
-
-    # These relay connections are not half-closeable (unlike full TCP
-    # connections, applications will not receive any data after half-closing
-    # their outgoing side). Applications must negotiate shutdown with their
-    # peer and not close the connection until all data has finished
-    # transferring in both directions. Applications which only need to send
-    # data in one direction can use close() as usual.
+    """
+    I manage pairs of simultaneous connections to a secondary TCP port,
+    both forwarded to the other. Clients must begin each connection with
+    "please relay TOKEN for SIDE\n" (or a legacy form without the "for
+    SIDE"). Two connections match if they use the same TOKEN and have
+    different SIDEs (the redundant connections are dropped when a match is
+    made). Legacy connections match any with the same TOKEN, ignoring SIDE
+    (so two legacy connections will match each other).
+
+    I will send "ok\n" when the matching connection is established, or
+    disconnect if no matching connection is made within MAX_WAIT_TIME
+    seconds. I will disconnect if you send data before the "ok\n". All data
+    you get after the "ok\n" will be from the other side. You will not
+    receive "ok\n" until the other side has also connected and submitted a
+    matching token (and differing SIDE).
+
+    In addition, the connections will be dropped after MAXLENGTH bytes have
+    been sent by either side, or MAXTIME seconds have elapsed after the
+    matching connections were established. A future API will reveal these
+    limits to clients instead of causing mysterious spontaneous failures.
+
+    These relay connections are not half-closeable (unlike full TCP
+    connections, applications will not receive any data after half-closing
+    their outgoing side). Applications must negotiate shutdown with their
+    peer and not close the connection until all data has finished
+    transferring in both directions. Applications which only need to send
+    data in one direction can use close() as usual.
+    """
 
+    # TODO: unused
     MAX_WAIT_TIME = 30*SECONDS
+    # TODO: unused
     MAXLENGTH = 10*MB
+    # TODO: unused
     MAXTIME = 60*SECONDS
     protocol = TransitConnection
 
-    def __init__(self, blur_usage, log_file, usage_db):
+    def __init__(self, usage):
         self.active_connections = ActiveConnections()
         self.pending_requests = PendingRequests(self.active_connections)
-        self.usage = UsageTracker(blur_usage)
-        self._blur_usage = blur_usage
-        self._log_requests = blur_usage is None
-        if self._blur_usage:
-            log.msg("blurring access times to %d seconds" % self._blur_usage)
-            log.msg("not logging Transit connections to Twisted log")
-        else:
-            log.msg("not blurring access times")
+        self.usage = usage
+        if False:
+            # these log messages should be made by the usage-tracker
+            # .. or in the "tap" setup?
+            if blur_usage:
+                log.msg("blurring access times to %d seconds" % self._blur_usage)
+                log.msg("not logging Transit connections to Twisted log")
+            else:
+                log.msg("not blurring access times")
         self._debug_log = False
-        ## self._log_file = log_file
-        self._db = None
-        if usage_db:
-            self._db = get_db(usage_db)
-            self.usage.add_backend(DatabaseUsageRecorder(self._db))
-        if log_file:
-            self.usage.add_backend(LogFileUsageRecorder(log_file))
         self._rebooted = time.time()
-        # we don't track TransitConnections until they submit a token
-        ## self._pending_requests = defaultdict(set) # token -> set((side, TransitConnection))
-        ## self._active_connections = set() # TransitConnection
-
-    def transitFinished(self, tc, token, side, description):
-        if token in self._pending_requests:
-            side_tc = (side, tc)
-            self._pending_requests[token].discard(side_tc)
-            if not self._pending_requests[token]: # set is now empty
-                del self._pending_requests[token]
-        if self._debug_log:
-            log.msg("transitFinished %s" % (description,))
-        self._active_connections.discard(tc)
-        # we could update the usage database "current" row immediately, or wait
-        # until the 5-minute timer updates it. If we update it now, just after
-        # losing a connection, we should probably also update it just after
-        # establishing one (at the end of connection_got_token). For now I'm
-        # going to omit these, but maybe someday we'll turn them both on. The
-        # consequence is that a manual execution of the munin scripts ("munin
-        # run wormhole_transit_active") will give the wrong value just after a
-        # connect/disconnect event. Actual munin graphs should accurately
-        # report connections that last longer than the 5-minute sampling
-        # window, which is what we actually care about.
-        #self.timerUpdateStats()
-
-    def recordUsage(self, started, result, total_bytes,
-                    total_time, waiting_time):
-        if self._debug_log:
-            log.msg(format="Transit.recordUsage {bytes}B", bytes=total_bytes)
-        if self._blur_usage:
-            started = self._blur_usage * (started // self._blur_usage)
-            total_bytes = blur_size(total_bytes)
-        if self._log_file is not None:
-            data = {"started": started,
-                    "total_time": total_time,
-                    "waiting_time": waiting_time,
-                    "total_bytes": total_bytes,
-                    "mood": result,
-                    }
-            self._log_file.write(json.dumps(data)+"\n")
-            self._log_file.flush()
-        if self._db:
-            self._db.execute("INSERT INTO `usage`"
-                             " (`started`, `total_time`, `waiting_time`,"
-                             " `total_bytes`, `result`)"
-                             " VALUES (?,?,?, ?,?)",
-                             (started, total_time, waiting_time,
-                              total_bytes, result))
-            # XXXX aaaaaAA! okay, so just this one type of usage also
-            # does some other random stats-stuff; need more
-            # refactorizing
-            self._update_stats()
-            self._db.commit()
-
-    def timerUpdateStats(self):
-        if self._db:
-            self._update_stats()
-            self._db.commit()
 
+    # XXX TODO self._rebooted and the below could be in a separate
+    # object? or in the DatabaseUsageRecorder .. but not here
     def _update_stats(self):
         # current status: should be zero when idle
         rebooted = self._rebooted
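For reference, the handshake specified by the new Transit docstring, taken literally, looks like this from a client's point of view. Host, port, and token values are placeholders, and real clients derive the token from the transit key via the magic-wormhole library; note this sketch blocks until a partner with the same token and a different side connects:

    import socket

    HOST, PORT = "relay.example.invalid", 4001
    token = "f" * 64   # same on both sides
    side = "0" * 16    # must differ between the two sides

    sock = socket.create_connection((HOST, PORT))
    # the line below follows the docstring's "please relay TOKEN for SIDE\n"
    sock.sendall("please relay {} for {}\n".format(token, side).encode("ascii"))
    # send nothing more until the relay answers; data sent before "ok\n"
    # gets the connection dropped
    assert sock.recv(3) == b"ok\n"
    # every byte written from now on is forwarded to the other side
    sock.sendall(b"hello from one side")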