rewrite connection handling, not sure it's a good idea

This commit is contained in:
Brian Warner 2017-11-04 12:23:32 -07:00
parent d36e0c44bd
commit d7800f6337

View File

@ -28,7 +28,6 @@ class TransitConnection(protocol.Protocol):
self._token_buffer = b""
self._sent_ok = False
self._buddy = None
self._had_buddy = False
self._total_sent = 0
def describeToken(self):
@ -131,7 +130,7 @@ class TransitConnection(protocol.Protocol):
def _got_handshake(self, token, side):
self._got_token = token
self._got_side = side
self.factory.connection_got_token(token, side, self)
self.factory.transitGotToken(token, side, self)
def buddy_connected(self, them):
self._buddy = them
@ -153,39 +152,57 @@ class TransitConnection(protocol.Protocol):
def connectionLost(self, reason):
if self._buddy:
self._buddy.buddy_disconnected()
self._buddy.buddy_disconnected() # hang up on the buddy
self.factory.transitFinished(self, self._got_token, self._got_side,
self.describeToken())
# Record usage. There are five cases:
# * 1: we connected, never had a buddy
# * 2: we connected first, we disconnect before the buddy
# * 3: we connected first, buddy disconnects first
# * 4: buddy connected first, we disconnect before buddy
# * 5: buddy connected first, buddy disconnects first
# whoever disconnects first gets to write the usage record (1,2,4)
finished = time.time()
if not self._had_buddy: # 1
total_time = finished - self._started
self.factory.recordUsage(self._started, "lonely", 0,
total_time, None)
if self._had_buddy and self._buddy: # 2,4
total_bytes = self._total_sent + self._buddy._total_sent
starts = [self._started, self._buddy._started]
total_time = finished - min(starts)
waiting_time = max(starts) - min(starts)
self.factory.recordUsage(self._started, "happy", total_bytes,
total_time, waiting_time)
def disconnect(self):
# called when we hang up on a connection because they violated the
# protocol, or we abandon a losing connection because a different one
# from that side won
self.transport.loseConnection()
self.factory.transitFailed(self)
finished = time.time()
total_time = finished - self._started
self.factory.recordUsage(self._started, "errory", 0,
total_time, None)
PENDING, OPEN, LINGERING, EMPTY = range(4)


class Channel(object):
    """Pair up the TransitConnections that present one particular token.

    A Channel starts out PENDING and collects ``(side, tc)`` entries until a
    new connection arrives that may be paired with one of them. Two
    connections may be paired when they declare different sides, or when at
    least one of them declared no side (``None``).
    """

    def __init__(self, factory):
        self._factory = factory
        self._state = PENDING
        self._connections = set()  # (side, tc)

    def gotConnection(self, side, tc):
        """Offer a newly-handshaken connection to this channel.

        side: the side the connection declared, or None.
        tc: the connection object; must provide buddy_connected(),
            disconnect() and describeToken().

        If no waiting connection can be paired with ``tc``, it is added to
        the waiting set. Otherwise the channel becomes OPEN, the factory is
        told via channelOpen(), every other waiting connection is
        disconnected, and the two ends are introduced to each other.
        """
        if self._state != PENDING:
            # NOTE(review): the original only handled the PENDING state;
            # connections arriving later are ignored here -- confirm the
            # intended behavior for OPEN/LINGERING channels.
            return

        factory = self._factory
        match = next((old for old in self._connections
                      if old[0] is None or side is None or old[0] != side),
                     None)

        if match is None:
            # debug flag lives on the factory, not on the Channel
            if factory._debug_log:
                log.msg("transit relay 1: %s" % tc.describeToken())
            self._connections.add((side, tc))
            # TODO: timer
            return

        # we found a match
        (_, old_tc) = match
        if factory._debug_log:
            log.msg("transit relay 2: %s" % tc.describeToken())
        self._state = OPEN
        factory.channelOpen(self)

        # drop and stop tracking the rest
        self._connections.remove(match)
        for (_, leftover_tc) in list(self._connections):
            # TODO: not "errory"? the ones we drop are the parallel
            # connections from the first client ('side' was the
            # same), so it's not really an error. More of a "you
            # lost, one of your other connections won, sorry"
            leftover_tc.disconnect()
        self._connections.clear()

        # glue the two ends together; the set of active connections is
        # owned by the factory
        factory._active_connections.add(tc)
        factory._active_connections.add(old_tc)
        tc.buddy_connected(old_tc)
        old_tc.buddy_connected(tc)
class Transit(protocol.ServerFactory):
# I manage pairs of simultaneous connections to a secondary TCP port,
@ -220,119 +237,135 @@ class Transit(protocol.ServerFactory):
MAXTIME = 60*SECONDS
protocol = TransitConnection
def __init__(self, blur_usage, usage_logfile, stats_file):
def __init__(self, blur_usage, log_stdout, usage_db):
self._blur_usage = blur_usage
self._log_requests = blur_usage is None
self._usage_logfile = open(usage_logfile, "a") if usage_logfile else None
self._stats_file = stats_file
self._pending_requests = {} # token -> set((side, TransitConnection))
self._active_connections = set() # TransitConnection
self._counts = {"lonely": 0, "happy": 0, "errory": 0}
self._count_bytes = 0
self._debug_log = False
self._log_stdout = log_stdout
self._db = None
if usage_db:
self._db = get_db(usage_db)
# we don't track TransitConnections until they submit a token
def connection_got_token(self, token, new_side, new_tc):
if token not in self._pending_requests:
self._pending_requests[token] = set()
potentials = self._pending_requests[token]
for old in potentials:
(old_side, old_tc) = old
if ((old_side is None)
or (new_side is None)
or (old_side != new_side)):
# we found a match
if self._log_requests:
log.msg("transit relay 2: %s" % new_tc.describeToken())
# Channels are indexed by token, and are either pending, open, or
# lingering
self._channels = {} # token -> Channel
self._pending_channels = set()
self._open_channels = set()
self._lingering_channels = set()
# drop and stop tracking the rest
potentials.remove(old)
for (_, leftover_tc) in potentials:
leftover_tc.disconnect() # TODO: not "errory"?
self._pending_requests.pop(token)
def transitGotToken(self, token, new_side, new_tc):
    """Hand a connection that has presented its token to the right Channel.

    token: the token from the connection's handshake; indexes self._channels.
    new_side: the side the connection declared, or None.
    new_tc: the connection itself.

    The first connection for a token creates the Channel.
    """
    channel = self._channels.get(token)
    if channel is None:
        channel = self._channels[token] = Channel(self)
        # channelOpen() removes the channel from _pending_channels, so a
        # brand-new channel has to be registered there first; otherwise
        # the first successful pairing raises KeyError.
        self._pending_channels.add(channel)
    channel.gotConnection(new_side, new_tc)
# glue the two ends together
self._active_connections.add(new_tc)
self._active_connections.add(old_tc)
new_tc.buddy_connected(old_tc)
old_tc.buddy_connected(new_tc)
return
if self._log_requests:
log.msg("transit relay 1: %s" % new_tc.describeToken())
potentials.add((new_side, new_tc))
# TODO: timer
def channelOpen(self, c):
    """Move channel *c* from the pending set to the open set.

    Raises KeyError if *c* is not in the pending set.
    """
    pending = self._pending_channels
    pending.remove(c)
    self._open_channels.add(c)
def channelClosed(self, c):
    """Move channel *c* from the open set to the lingering set.

    Raises KeyError if *c* is not in the open set.
    """
    currently_open = self._open_channels
    currently_open.remove(c)
    self._lingering_channels.add(c)
def channelEmpty(self, c):
    """Stop tracking lingering channel *c*.

    Raises KeyError if *c* is not in the lingering set.
    """
    lingering = self._lingering_channels
    lingering.remove(c)
def transitFinished(self, tc, token, side, description):
# Bookkeeping for a TransitConnection `tc` that has shut down: forget its
# (side, tc) entry under `token` if one is still waiting, drop it from the
# active set, then record usage. `description` is only used for the debug
# log line.
# NOTE(review): nothing visible in this file assigns self._pending_tokens
# (__init__ creates self._pending_requests) -- confirm which name is meant.
# we're called each time a TransitConnection shuts down
if token in self._pending_tokens:
side_tc = (side, tc)
if side_tc in self._pending_tokens[token]:
self._pending_tokens[token].remove(side_tc)
if not self._pending_tokens[token]: # set is now empty
del self._pending_tokens[token]
if self._debug_log:
log.msg("transitFinished %s" % (description,))
# discard() rather than remove(): tc may never have been paired
self._active_connections.discard(tc)
# NOTE(review): everything below reads self._had_buddy, self._buddy,
# self._total_sent, self._started and self.factory. Elsewhere in this file
# those are TransitConnection attributes, and this method receives the
# connection as `tc` -- confirm whether these should be `tc.*`, and whether
# the self.factory.* calls should be plain self.* calls.
# Record usage. There are five cases:
# * 1: we connected, never had a buddy
# * 2: we connected first, we disconnect before the buddy
# * 3: we connected first, buddy disconnects first
# * 4: buddy connected first, we disconnect before buddy
# * 5: buddy connected first, buddy disconnects first
# whoever disconnects first gets to write the usage record (1,2,4)
finished = time.time()
if self._had_buddy:
if self._buddy: # 2,4: we disconnected first
total_bytes = self._total_sent + self._buddy._total_sent
starts = [self._started, self._buddy._started]
# total_time runs from the earlier of the two starts; waiting_time is
# the gap between the two starts
total_time = finished - min(starts)
waiting_time = max(starts) - min(starts)
self.factory.buddyIsLingering(self._buddy)
self.factory.recordUsage(self._started, "happy", total_bytes,
total_time, waiting_time)
else: # 3, 5: we disconnected last
self.factory.doneLingering(self)
else: # 1: we were the only one
total_time = finished - self._started
self.factory.recordUsage(self._started, "lonely", 0,
total_time, None)
def transitFailed(self, p):
    """Record a connection that was hung up on as an "errory" usage entry.

    p: the connection being dropped; its ``_started`` timestamp is read to
       compute how long it was connected.

    NOTE(review): confirm the connection's own disconnect() no longer
    records an "errory" entry itself, otherwise failures are counted twice.
    """
    if self._debug_log:
        log.msg("transitFailed %r" % p)
    # The start time belongs to the connection `p`, and recordUsage() is a
    # method of this object, so neither is reached through self.factory.
    total_time = time.time() - p._started
    self.recordUsage(p._started, "errory", 0, total_time, None)
def buddyIsLingering(self, buddy_tc):
    """Reclassify *buddy_tc* from the active set to the lingering set.

    Raises KeyError if *buddy_tc* is not in the active set.
    """
    active = self._active_connections
    active.remove(buddy_tc)
    self._lingering_connections.add(buddy_tc)
def doneLingering(self, old_tc):
    """Stop tracking *old_tc* as a lingering connection.

    Raises KeyError if *old_tc* is not in the lingering set.
    """
    # the parameter is old_tc; the previous body referenced an undefined
    # name (buddy_tc) and raised NameError on every call
    self._lingering_connections.remove(old_tc)
def recordUsage(self, started, result, total_bytes,
total_time, waiting_time):
self._counts[result] += 1
self._count_bytes += total_bytes
if self._log_requests:
if self._debug_log:
log.msg(format="Transit.recordUsage {bytes}B", bytes=total_bytes)
if self._blur_usage:
started = self._blur_usage * (started // self._blur_usage)
total_bytes = blur_size(total_bytes)
if self._usage_logfile:
if self._log_stdout:
data = {"started": started,
"total_time": total_time,
"waiting_time": waiting_time,
"total_bytes": total_bytes,
"mood": result,
}
self._usage_logfile.write(json.dumps(data))
self._usage_logfile.write("\n")
self._usage_logfile.flush()
if self._stats_file:
self._update_stats(total_bytes, result)
sys.stdout.write(json.dumps(data))
sys.stdout.write("\n")
sys.stdout.flush()
if self._db:
self._db.execute("INSERT INTO `usage`"
" (`started`, `total_time`, `waiting_time`,"
" `total_bytes`, `result`)"
" VALUES (?,?,?, ?,?)",
(started, total_time, waiting_time,
total_bytes, result))
self._update_stats()
self._db.commit()
def transitFinished(self, tc, token, side, description):
if token in self._pending_requests:
side_tc = (side, tc)
if side_tc in self._pending_requests[token]:
self._pending_requests[token].remove(side_tc)
if not self._pending_requests[token]: # set is now empty
del self._pending_requests[token]
if self._log_requests:
log.msg("transitFinished %s" % (description,))
self._active_connections.discard(tc)
def transitFailed(self, p):
if self._log_requests:
log.msg("transitFailed %r" % p)
pass
def _update_stats(self, total_bytes, mood):
try:
with open(self._stats_file, "r") as f:
stats = json.load(f)
except (EnvironmentError, ValueError):
stats = {}
# current status: expected to be zero most of the time
stats["active"] = {"connected": len(self._active_connections) / 2,
"waiting": len(self._pending_requests),
}
# usage since last reboot
rb = stats["since_reboot"] = {}
rb["bytes"] = self._count_bytes
rb["total"] = sum(self._counts.values(), 0)
rbm = rb["moods"] = {}
for result, count in self._counts.items():
rbm[result] = count
# historical usage (all-time)
if "all_time" not in stats:
stats["all_time"] = {}
u = stats["all_time"]
u["total"] = u.get("total", 0) + 1
u["bytes"] = u.get("bytes", 0) + total_bytes
if "moods" not in u:
u["moods"] = {}
um = u["moods"]
for m in "happy", "lonely", "errory":
if m not in um:
um[m] = 0
um[mood] += 1
tmpfile = self._stats_file + ".tmp"
with open(tmpfile, "w") as f:
f.write(json.dumps(stats))
f.write("\n")
os.rename(tmpfile, self._stats_file)
def _update_stats(self):
# current status: should be zero when idle
reboot = self._reboot
last_update = time.time()
connected = len(self._active_connections) / 2
# TODO: when a connection is half-closed, len(active) will be odd. a
# moment later (hopefully) the other side will disconnect, but
# _update_stats isn't updated until later.
waiting = len(self._pending_tokens)
# "waiting" doesn't count multiple parallel connections from the same
# side
incomplete_bytes = sum(tc._total_sent
for tc in self._active_connections)
self._db.execute("DELETE FROM `current`")
self._db.execute("INSERT INTO `current`"
" (`reboot`, `last_update`, `connected`, `waiting`,"
" `incomplete_bytes`)"
" VALUES (?, ?, ?, ?, ?)",
(reboot, last_update, connected, waiting,
incomplete_bytes))