rewrite connection handling, not sure it's a good idea

This commit is contained in:
Brian Warner 2017-11-04 12:23:32 -07:00
parent d36e0c44bd
commit d7800f6337

View File

@ -28,7 +28,6 @@ class TransitConnection(protocol.Protocol):
self._token_buffer = b"" self._token_buffer = b""
self._sent_ok = False self._sent_ok = False
self._buddy = None self._buddy = None
self._had_buddy = False
self._total_sent = 0 self._total_sent = 0
def describeToken(self): def describeToken(self):
@ -131,7 +130,7 @@ class TransitConnection(protocol.Protocol):
def _got_handshake(self, token, side): def _got_handshake(self, token, side):
self._got_token = token self._got_token = token
self._got_side = side self._got_side = side
self.factory.connection_got_token(token, side, self) self.factory.transitGotToken(token, side, self)
def buddy_connected(self, them): def buddy_connected(self, them):
self._buddy = them self._buddy = them
@ -153,39 +152,57 @@ class TransitConnection(protocol.Protocol):
def connectionLost(self, reason): def connectionLost(self, reason):
if self._buddy: if self._buddy:
self._buddy.buddy_disconnected() self._buddy.buddy_disconnected() # hang up on the buddy
self.factory.transitFinished(self, self._got_token, self._got_side, self.factory.transitFinished(self, self._got_token, self._got_side,
self.describeToken()) self.describeToken())
# Record usage. There are five cases:
# * 1: we connected, never had a buddy
# * 2: we connected first, we disconnect before the buddy
# * 3: we connected first, buddy disconnects first
# * 4: buddy connected first, we disconnect before buddy
# * 5: buddy connected first, buddy disconnects first
# whoever disconnects first gets to write the usage record (1,2,4)
finished = time.time()
if not self._had_buddy: # 1
total_time = finished - self._started
self.factory.recordUsage(self._started, "lonely", 0,
total_time, None)
if self._had_buddy and self._buddy: # 2,4
total_bytes = self._total_sent + self._buddy._total_sent
starts = [self._started, self._buddy._started]
total_time = finished - min(starts)
waiting_time = max(starts) - min(starts)
self.factory.recordUsage(self._started, "happy", total_bytes,
total_time, waiting_time)
def disconnect(self): def disconnect(self):
# called when we hang up on a connection because they violated the
# protocol, or we abandon a losing connection because a different one
# from that side won
self.transport.loseConnection() self.transport.loseConnection()
self.factory.transitFailed(self) self.factory.transitFailed(self)
finished = time.time()
total_time = finished - self._started PENDING, OPEN, LINGERING, EMPTY = range(4)
self.factory.recordUsage(self._started, "errory", 0,
total_time, None) class Channel(object):
def __init__(self, factory):
    """Create an empty channel that reports state changes to `factory`.

    The channel starts out PENDING and holds no connections.
    """
    # every known connection for this channel, stored as (side, tc) tuples
    self._connections = set()
    self._state = PENDING
    self._factory = factory
def gotConnection(self, side, tc):
    """Offer a connection that has just completed its handshake.

    `side` is the side value from the handshake (may be None) and `tc` is
    the connection object. While the channel is PENDING, `tc` is paired
    with the first waiting connection whose side is compatible: either
    side is None, or the two sides differ. On a match the channel becomes
    OPEN, the factory is told via channelOpen(), every other waiting
    connection is disconnected, and the two ends are introduced to each
    other with buddy_connected(). Without a match, `tc` is added to the
    waiting set.
    """
    if self._state != PENDING:
        # NOTE(review): assumes a connection arriving on a channel that is
        # no longer PENDING should simply be ignored -- confirm.
        return
    # The debug flag lives on the factory; a Channel has none of its own.
    debug_log = getattr(self._factory, "_debug_log", False)
    for old in self._connections:
        (old_side, old_tc) = old
        if ((old_side is None)
            or (side is None)
            or (old_side != side)):
            # we found a match
            if debug_log:
                log.msg("transit relay 2: %s" % tc.describeToken())
            self._state = OPEN
            self._factory.channelOpen(self)
            # drop and stop tracking the rest: from here on the channel
            # only remembers the two ends that were matched
            self._connections.remove(old)
            leftovers = self._connections
            self._connections = set([old, (side, tc)])
            for (_, leftover_tc) in leftovers:
                # TODO: not "errory"? the ones we drop are the parallel
                # connections from the first client ('side' was the
                # same), so it's not really an error. More of a "you
                # lost, one of your other connections won, sorry"
                leftover_tc.disconnect()
            # glue the two ends together
            tc.buddy_connected(old_tc)
            old_tc.buddy_connected(tc)
            return
    if debug_log:
        log.msg("transit relay 1: %s" % tc.describeToken())
    self._connections.add((side, tc))
    # TODO: timer
class Transit(protocol.ServerFactory): class Transit(protocol.ServerFactory):
# I manage pairs of simultaneous connections to a secondary TCP port, # I manage pairs of simultaneous connections to a secondary TCP port,
@ -220,119 +237,135 @@ class Transit(protocol.ServerFactory):
MAXTIME = 60*SECONDS MAXTIME = 60*SECONDS
protocol = TransitConnection protocol = TransitConnection
def __init__(self, blur_usage, usage_logfile, stats_file): def __init__(self, blur_usage, log_stdout, usage_db):
self._blur_usage = blur_usage self._blur_usage = blur_usage
self._log_requests = blur_usage is None self._debug_log = False
self._usage_logfile = open(usage_logfile, "a") if usage_logfile else None self._log_stdout = log_stdout
self._stats_file = stats_file self._db = None
self._pending_requests = {} # token -> set((side, TransitConnection)) if usage_db:
self._active_connections = set() # TransitConnection self._db = get_db(usage_db)
self._counts = {"lonely": 0, "happy": 0, "errory": 0} # we don't track TransitConnections until they submit a token
self._count_bytes = 0
def connection_got_token(self, token, new_side, new_tc): # Channels are indexed by token, and are either pending, open, or
if token not in self._pending_requests: # lingering
self._pending_requests[token] = set() self._channels = {} # token -> Channel
potentials = self._pending_requests[token] self._pending_channels = set()
for old in potentials: self._open_channels = set()
(old_side, old_tc) = old self._lingering_channels = set()
if ((old_side is None)
or (new_side is None)
or (old_side != new_side)):
# we found a match
if self._log_requests:
log.msg("transit relay 2: %s" % new_tc.describeToken())
# drop and stop tracking the rest def transitGotToken(self, token, new_side, new_tc):
potentials.remove(old) if token not in self._channels:
for (_, leftover_tc) in potentials: self._channels[token] = Channel(self)
leftover_tc.disconnect() # TODO: not "errory"? self._channels[token].gotConnection(new_side, new_tc)
self._pending_requests.pop(token)
# glue the two ends together def channelOpen(self, c):
self._active_connections.add(new_tc) self._pending_channels.remove(c)
self._active_connections.add(old_tc) self._open_channels.add(c)
new_tc.buddy_connected(old_tc) def channelClosed(self, c):
old_tc.buddy_connected(new_tc) self._open_channels.remove(c)
return self._lingering_channels.add(c)
if self._log_requests: def channelEmpty(self, c):
log.msg("transit relay 1: %s" % new_tc.describeToken()) self._lingering_channels.remove(c)
potentials.add((new_side, new_tc))
# TODO: timer def transitFinished(self, tc, token, side, description):
# we're called each time a TransitConnection shuts down
if token in self._pending_tokens:
side_tc = (side, tc)
if side_tc in self._pending_tokens[token]:
self._pending_tokens[token].remove(side_tc)
if not self._pending_tokens[token]: # set is now empty
del self._pending_tokens[token]
if self._debug_log:
log.msg("transitFinished %s" % (description,))
self._active_connections.discard(tc)
# Record usage. There are five cases:
# * 1: we connected, never had a buddy
# * 2: we connected first, we disconnect before the buddy
# * 3: we connected first, buddy disconnects first
# * 4: buddy connected first, we disconnect before buddy
# * 5: buddy connected first, buddy disconnects first
# whoever disconnects first gets to write the usage record (1,2,4)
finished = time.time()
if self._had_buddy:
if self._buddy: # 2,4: we disconnected first
total_bytes = self._total_sent + self._buddy._total_sent
starts = [self._started, self._buddy._started]
total_time = finished - min(starts)
waiting_time = max(starts) - min(starts)
self.factory.buddyIsLingering(self._buddy)
self.factory.recordUsage(self._started, "happy", total_bytes,
total_time, waiting_time)
else: # 3, 5: we disconnected last
self.factory.doneLingering(self)
else: # 1: we were the only one
total_time = finished - self._started
self.factory.recordUsage(self._started, "lonely", 0,
total_time, None)
def transitFailed(self, p):
    """Write an "errory" usage record for a connection we hung up on.

    `p` is the connection object that reported the failure (the
    connection's disconnect() passes itself). Its `_started` timestamp
    supplies both the record's start time and the elapsed time; zero
    bytes and no waiting time are recorded.
    """
    if self._debug_log:
        log.msg("transitFailed %r" % p)
    finished = time.time()
    # the start time belongs to the connection, not to this factory
    total_time = finished - p._started
    self.recordUsage(p._started, "errory", 0, total_time, None)
def buddyIsLingering(self, buddy_tc):
    """Move `buddy_tc` from the active set to the lingering set.

    Raises KeyError if `buddy_tc` is not currently in the active set.
    """
    # take it out of the active set first; a failure here leaves the
    # lingering set untouched
    self._active_connections.remove(buddy_tc)
    lingering = self._lingering_connections
    lingering.add(buddy_tc)
def doneLingering(self, old_tc):
    """Stop tracking `old_tc` as a lingering connection.

    Raises KeyError if `old_tc` is not in the lingering set.
    """
    self._lingering_connections.remove(old_tc)
def recordUsage(self, started, result, total_bytes, def recordUsage(self, started, result, total_bytes,
total_time, waiting_time): total_time, waiting_time):
self._counts[result] += 1 if self._debug_log:
self._count_bytes += total_bytes
if self._log_requests:
log.msg(format="Transit.recordUsage {bytes}B", bytes=total_bytes) log.msg(format="Transit.recordUsage {bytes}B", bytes=total_bytes)
if self._blur_usage: if self._blur_usage:
started = self._blur_usage * (started // self._blur_usage) started = self._blur_usage * (started // self._blur_usage)
total_bytes = blur_size(total_bytes) total_bytes = blur_size(total_bytes)
if self._usage_logfile: if self._log_stdout:
data = {"started": started, data = {"started": started,
"total_time": total_time, "total_time": total_time,
"waiting_time": waiting_time, "waiting_time": waiting_time,
"total_bytes": total_bytes, "total_bytes": total_bytes,
"mood": result, "mood": result,
} }
self._usage_logfile.write(json.dumps(data)) sys.stdout.write(json.dumps(data))
self._usage_logfile.write("\n") sys.stdout.write("\n")
self._usage_logfile.flush() sys.stdout.flush()
if self._stats_file: if self._db:
self._update_stats(total_bytes, result) self._db.execute("INSERT INTO `usage`"
" (`started`, `total_time`, `waiting_time`,"
" `total_bytes`, `result`)"
" VALUES (?,?,?, ?,?)",
(started, total_time, waiting_time,
total_bytes, result))
self._update_stats()
self._db.commit()
def transitFinished(self, tc, token, side, description): def _update_stats(self):
if token in self._pending_requests: # current status: should be zero when idle
side_tc = (side, tc) reboot = self._reboot
if side_tc in self._pending_requests[token]: last_update = time.time()
self._pending_requests[token].remove(side_tc) connected = len(self._active_connections) / 2
if not self._pending_requests[token]: # set is now empty # TODO: when a connection is half-closed, len(active) will be odd. a
del self._pending_requests[token] # moment later (hopefully) the other side will disconnect, but
if self._log_requests: # _update_stats isn't updated until later.
log.msg("transitFinished %s" % (description,)) waiting = len(self._pending_tokens)
self._active_connections.discard(tc) # "waiting" doesn't count multiple parallel connections from the same
# side
def transitFailed(self, p): incomplete_bytes = sum(tc._total_sent
if self._log_requests: for tc in self._active_connections)
log.msg("transitFailed %r" % p) self._db.execute("DELETE FROM `current`")
pass self._db.execute("INSERT INTO `current`"
" (`reboot`, `last_update`, `connected`, `waiting`,"
def _update_stats(self, total_bytes, mood): " `incomplete_bytes`)"
try: " VALUES (?, ?, ?, ?, ?)",
with open(self._stats_file, "r") as f: (reboot, last_update, connected, waiting,
stats = json.load(f) incomplete_bytes))
except (EnvironmentError, ValueError):
stats = {}
# current status: expected to be zero most of the time
stats["active"] = {"connected": len(self._active_connections) / 2,
"waiting": len(self._pending_requests),
}
# usage since last reboot
rb = stats["since_reboot"] = {}
rb["bytes"] = self._count_bytes
rb["total"] = sum(self._counts.values(), 0)
rbm = rb["moods"] = {}
for result, count in self._counts.items():
rbm[result] = count
# historical usage (all-time)
if "all_time" not in stats:
stats["all_time"] = {}
u = stats["all_time"]
u["total"] = u.get("total", 0) + 1
u["bytes"] = u.get("bytes", 0) + total_bytes
if "moods" not in u:
u["moods"] = {}
um = u["moods"]
for m in "happy", "lonely", "errory":
if m not in um:
um[m] = 0
um[mood] += 1
tmpfile = self._stats_file + ".tmp"
with open(tmpfile, "w") as f:
f.write(json.dumps(stats))
f.write("\n")
os.rename(tmpfile, self._stats_file)