rework mood tracking, full tests of usage events

This commit is contained in:
Brian Warner 2018-02-12 13:35:29 -08:00
parent 547ead75ba
commit 4e9b6c53a9
2 changed files with 234 additions and 43 deletions

View File

@ -214,6 +214,10 @@ class _Transit:
ep = clientFromString(reactor, self.transit) ep = clientFromString(reactor, self.transit)
a1 = yield connectProtocol(ep, Accumulator()) a1 = yield connectProtocol(ep, Accumulator())
a2 = yield connectProtocol(ep, Accumulator()) a2 = yield connectProtocol(ep, Accumulator())
a3 = yield connectProtocol(ep, Accumulator())
disconnects = []
a1._disconnect.addCallback(disconnects.append)
a2._disconnect.addCallback(disconnects.append)
token1 = b"\x00"*32 token1 = b"\x00"*32
side1 = b"\x01"*8 side1 = b"\x01"*8
@ -229,8 +233,26 @@ class _Transit:
yield wait() yield wait()
self.assertEqual(self.count(), 2) # same-side connections don't match self.assertEqual(self.count(), 2) # same-side connections don't match
# when the second side arrives, the spare first connection should be
# closed
side2 = b"\x02"*8
a3.transport.write(b"please relay " + hexlify(token1) +
b" for side " + hexlify(side2) + b"\n")
# let that arrive
while self.count() != 0:
yield wait()
self.assertEqual(len(self._transit_server._pending_requests), 0)
self.assertEqual(len(self._transit_server._active_connections), 2)
# That will trigger a disconnect on exactly one of (a1 or a2). Wait
# until our client notices it.
while not disconnects:
yield wait()
# the other connection should still be connected
self.assertEqual(sum([int(t.transport.connected) for t in [a1, a2]]), 1)
a1.transport.loseConnection() a1.transport.loseConnection()
a2.transport.loseConnection() a2.transport.loseConnection()
a3.transport.loseConnection()
@defer.inlineCallbacks @defer.inlineCallbacks
def test_bad_handshake_old(self): def test_bad_handshake_old(self):
@ -379,3 +401,143 @@ class TransitWithLogs(_Transit, ServerBase, unittest.TestCase):
class TransitWithoutLogs(_Transit, ServerBase, unittest.TestCase): class TransitWithoutLogs(_Transit, ServerBase, unittest.TestCase):
log_requests = False log_requests = False
class Usage(ServerBase, unittest.TestCase):
@defer.inlineCallbacks
def setUp(self):
yield super(Usage, self).setUp()
self._usage = []
def record(started, result, total_bytes, total_time, waiting_time):
self._usage.append((started, result, total_bytes,
total_time, waiting_time))
self._transit_server.recordUsage = record
@defer.inlineCallbacks
def test_errory(self):
ep = clientFromString(reactor, self.transit)
a1 = yield connectProtocol(ep, Accumulator())
a1.transport.write(b"this is a very bad handshake\n")
# that will log the "errory" usage event, then drop the connection
yield a1._disconnect
self.assertEqual(len(self._usage), 1, self._usage)
(started, result, total_bytes, total_time, waiting_time) = self._usage[0]
self.assertEqual(result, "errory", self._usage)
@defer.inlineCallbacks
def test_lonely(self):
ep = clientFromString(reactor, self.transit)
a1 = yield connectProtocol(ep, Accumulator())
token1 = b"\x00"*32
side1 = b"\x01"*8
a1.transport.write(b"please relay " + hexlify(token1) +
b" for side " + hexlify(side1) + b"\n")
while not self._transit_server._pending_requests:
yield wait() # wait for the server to see the connection
# now we disconnect before the peer connects
a1.transport.loseConnection()
yield a1._disconnect
while self._transit_server._pending_requests:
yield wait() # wait for the server to see the disconnect too
self.assertEqual(len(self._usage), 1, self._usage)
(started, result, total_bytes, total_time, waiting_time) = self._usage[0]
self.assertEqual(result, "lonely", self._usage)
self.assertIdentical(waiting_time, None)
@defer.inlineCallbacks
def test_one_happy_one_jilted(self):
ep = clientFromString(reactor, self.transit)
a1 = yield connectProtocol(ep, Accumulator())
a2 = yield connectProtocol(ep, Accumulator())
token1 = b"\x00"*32
side1 = b"\x01"*8
side2 = b"\x02"*8
a1.transport.write(b"please relay " + hexlify(token1) +
b" for side " + hexlify(side1) + b"\n")
while not self._transit_server._pending_requests:
yield wait() # make sure a1 connects first
a2.transport.write(b"please relay " + hexlify(token1) +
b" for side " + hexlify(side2) + b"\n")
while not self._transit_server._active_connections:
yield wait() # wait for the server to see the connection
self.assertEqual(len(self._transit_server._pending_requests), 0)
self.assertEqual(self._usage, []) # no events yet
a1.transport.write(b"\x00" * 13)
yield a2.waitForBytes(13)
a2.transport.write(b"\xff" * 7)
yield a1.waitForBytes(7)
a1.transport.loseConnection()
yield a1._disconnect
while self._transit_server._active_connections:
yield wait()
yield a2._disconnect
self.assertEqual(len(self._usage), 1, self._usage)
(started, result, total_bytes, total_time, waiting_time) = self._usage[0]
self.assertEqual(result, "happy", self._usage)
self.assertEqual(total_bytes, 20)
self.assertNotIdentical(waiting_time, None)
@defer.inlineCallbacks
def test_redundant(self):
ep = clientFromString(reactor, self.transit)
a1a = yield connectProtocol(ep, Accumulator())
a1b = yield connectProtocol(ep, Accumulator())
a1c = yield connectProtocol(ep, Accumulator())
a2 = yield connectProtocol(ep, Accumulator())
token1 = b"\x00"*32
side1 = b"\x01"*8
side2 = b"\x02"*8
a1a.transport.write(b"please relay " + hexlify(token1) +
b" for side " + hexlify(side1) + b"\n")
def count_requests():
return sum([len(v)
for v in self._transit_server._pending_requests.values()])
while count_requests() < 1:
yield wait()
a1b.transport.write(b"please relay " + hexlify(token1) +
b" for side " + hexlify(side1) + b"\n")
while count_requests() < 2:
yield wait()
# connect and disconnect a third client (for side1) to exercise the
# code that removes a pending connection without removing the entire
# token
a1c.transport.write(b"please relay " + hexlify(token1) +
b" for side " + hexlify(side1) + b"\n")
while count_requests() < 3:
yield wait()
a1c.transport.loseConnection()
yield a1c._disconnect
while count_requests() > 2:
yield wait()
self.assertEqual(len(self._usage), 1, self._usage)
(started, result, total_bytes, total_time, waiting_time) = self._usage[0]
self.assertEqual(result, "lonely", self._usage)
a2.transport.write(b"please relay " + hexlify(token1) +
b" for side " + hexlify(side2) + b"\n")
# this will claim one of (a1a, a1b), and close the other as redundant
while not self._transit_server._active_connections:
yield wait() # wait for the server to see the connection
self.assertEqual(count_requests(), 0)
self.assertEqual(len(self._usage), 2, self._usage)
(started, result, total_bytes, total_time, waiting_time) = self._usage[1]
self.assertEqual(result, "redundant", self._usage)
# one of the these is unecessary, but probably harmless
a1a.transport.loseConnection()
a1b.transport.loseConnection()
yield a1a._disconnect
yield a1b._disconnect
while self._transit_server._active_connections:
yield wait()
yield a2._disconnect
self.assertEqual(len(self._usage), 3, self._usage)
(started, result, total_bytes, total_time, waiting_time) = self._usage[2]
self.assertEqual(result, "happy", self._usage)

View File

@ -28,8 +28,8 @@ class TransitConnection(protocol.Protocol):
self._got_side = False self._got_side = False
self._token_buffer = b"" self._token_buffer = b""
self._sent_ok = False self._sent_ok = False
self._mood = None
self._buddy = None self._buddy = None
self._had_buddy = False
self._total_sent = 0 self._total_sent = 0
def describeToken(self): def describeToken(self):
@ -62,7 +62,7 @@ class TransitConnection(protocol.Protocol):
self.transport.write(b"impatient\n") self.transport.write(b"impatient\n")
if self._log_requests: if self._log_requests:
log.msg("transit impatience failure") log.msg("transit impatience failure")
return self.disconnect() # impatience yields failure return self.disconnect_error() # impatience yields failure
# else this should be (part of) the token # else this should be (part of) the token
self._token_buffer += data self._token_buffer += data
@ -79,7 +79,7 @@ class TransitConnection(protocol.Protocol):
self.transport.write(b"impatient\n") self.transport.write(b"impatient\n")
if self._log_requests: if self._log_requests:
log.msg("transit impatience failure") log.msg("transit impatience failure")
return self.disconnect() # impatience yields failure return self.disconnect_error() # impatience yields failure
return self._got_handshake(token, None) return self._got_handshake(token, None)
(new, handshake_len, token, side) = self._check_new_handshake(buf) (new, handshake_len, token, side) = self._check_new_handshake(buf)
assert new in ("yes", "waiting", "no") assert new in ("yes", "waiting", "no")
@ -88,13 +88,13 @@ class TransitConnection(protocol.Protocol):
self.transport.write(b"impatient\n") self.transport.write(b"impatient\n")
if self._log_requests: if self._log_requests:
log.msg("transit impatience failure") log.msg("transit impatience failure")
return self.disconnect() # impatience yields failure return self.disconnect_error() # impatience yields failure
return self._got_handshake(token, side) return self._got_handshake(token, side)
if (old == "no" and new == "no"): if (old == "no" and new == "no"):
self.transport.write(b"bad handshake\n") self.transport.write(b"bad handshake\n")
if self._log_requests: if self._log_requests:
log.msg("transit handshake failure") log.msg("transit handshake failure")
return self.disconnect() # incorrectness yields failure return self.disconnect_error() # incorrectness yields failure
# else we'll keep waiting # else we'll keep waiting
def _check_old_handshake(self, buf): def _check_old_handshake(self, buf):
@ -132,11 +132,12 @@ class TransitConnection(protocol.Protocol):
def _got_handshake(self, token, side): def _got_handshake(self, token, side):
self._got_token = token self._got_token = token
self._got_side = side self._got_side = side
self._mood = "lonely" # until buddy connects
self.factory.connection_got_token(token, side, self) self.factory.connection_got_token(token, side, self)
def buddy_connected(self, them): def buddy_connected(self, them):
self._buddy = them self._buddy = them
self._had_buddy = True self._mood = "happy"
self.transport.write(b"ok\n") self.transport.write(b"ok\n")
self._sent_ok = True self._sent_ok = True
# Connect the two as a producer/consumer pair. We use streaming=True, # Connect the two as a producer/consumer pair. We use streaming=True,
@ -150,44 +151,72 @@ class TransitConnection(protocol.Protocol):
if self._log_requests: if self._log_requests:
log.msg("buddy_disconnected %s" % self.describeToken()) log.msg("buddy_disconnected %s" % self.describeToken())
self._buddy = None self._buddy = None
self._mood = "jilted"
self.transport.loseConnection()
def disconnect_error(self):
# we haven't finished the handshake, so there are no tokens tracking
# us
self._mood = "errory"
self.transport.loseConnection()
if self.factory._debug_log:
log.msg("transitFailed %r" % self)
def disconnect_redundant(self):
# this is called if a buddy connected and we were found unnecessary.
# Any token-tracking cleanup will have been done before we're called.
self._mood = "redundant"
self.transport.loseConnection() self.transport.loseConnection()
def connectionLost(self, reason): def connectionLost(self, reason):
finished = time.time()
total_time = finished - self._started
# Record usage. There are seven cases:
# * n1: the handshake failed, not a real client (errory)
# * n2: real client disconnected before any buddy appeared (lonely)
# * n3: real client closed as redundant after buddy appears (redundant)
# * n4: real client connected first, buddy closes first (jilted)
# * n5: real client connected first, buddy close last (happy)
# * n6: real client connected last, buddy closes first (jilted)
# * n7: real client connected last, buddy closes last (happy)
# * non-connected clients (1,2,3) always write a usage record
# * for connected clients, whoever disconnects first gets to write the
# usage record (5, 7). The last disconnect doesn't write a record.
if self._mood == "errory": # 1
assert not self._buddy
self.factory.recordUsage(self._started, "errory", 0,
total_time, None)
elif self._mood == "redundant": # 3
assert not self._buddy
self.factory.recordUsage(self._started, "redundant", 0,
total_time, None)
elif self._mood == "jilted": # 4 or 6
# we were connected, but our buddy hung up on us. They record the
# usage event, we do not
pass
elif self._mood == "lonely": # 2
assert not self._buddy
self.factory.recordUsage(self._started, "lonely", 0,
total_time, None)
else: # 5 or 7
# we were connected, we hung up first. We record the event.
assert self._mood == "happy", self._mood
assert self._buddy
starts = [self._started, self._buddy._started]
total_time = finished - min(starts)
waiting_time = max(starts) - min(starts)
total_bytes = self._total_sent + self._buddy._total_sent
self.factory.recordUsage(self._started, "happy", total_bytes,
total_time, waiting_time)
if self._buddy: if self._buddy:
self._buddy.buddy_disconnected() self._buddy.buddy_disconnected()
self.factory.transitFinished(self, self._got_token, self._got_side, self.factory.transitFinished(self, self._got_token, self._got_side,
self.describeToken()) self.describeToken())
# Record usage. There are four cases:
# * 1: we connected, never had a buddy
# * 2: we connected first, we disconnect before the buddy
# * 3: we connected first, buddy disconnects first
# * 4: buddy connected first, we disconnect before buddy
# * 5: buddy connected first, buddy disconnects first
# whoever disconnects first gets to write the usage record (1,2,4)
finished = time.time()
if not self._had_buddy: # 1
total_time = finished - self._started
self.factory.recordUsage(self._started, "lonely", 0,
total_time, None)
if self._had_buddy and self._buddy: # 2,4
total_bytes = self._total_sent + self._buddy._total_sent
starts = [self._started, self._buddy._started]
total_time = finished - min(starts)
waiting_time = max(starts) - min(starts)
self.factory.recordUsage(self._started, "happy", total_bytes,
total_time, waiting_time)
def disconnect(self):
self.transport.loseConnection()
self.factory.transitFailed(self)
finished = time.time()
total_time = finished - self._started
self.factory.recordUsage(self._started, "errory", 0,
total_time, None)
class Transit(protocol.ServerFactory): class Transit(protocol.ServerFactory):
# I manage pairs of simultaneous connections to a secondary TCP port, # I manage pairs of simultaneous connections to a secondary TCP port,
# both forwarded to the other. Clients must begin each connection with # both forwarded to the other. Clients must begin each connection with
@ -255,7 +284,12 @@ class Transit(protocol.ServerFactory):
# drop and stop tracking the rest # drop and stop tracking the rest
potentials.remove(old) potentials.remove(old)
for (_, leftover_tc) in potentials: for (_, leftover_tc) in potentials:
leftover_tc.disconnect() # TODO: not "errory"? # Don't record this as errory. It's just a spare connection
# from the same side as a connection that got used. This
# can happen if the connection hint contains multiple
# addresses (we don't currently support those, but it'd
# probably be useful in the future).
leftover_tc.disconnect_redundant()
self._pending_requests.pop(token) self._pending_requests.pop(token)
# glue the two ends together # glue the two ends together
@ -272,18 +306,13 @@ class Transit(protocol.ServerFactory):
def transitFinished(self, tc, token, side, description): def transitFinished(self, tc, token, side, description):
if token in self._pending_requests: if token in self._pending_requests:
side_tc = (side, tc) side_tc = (side, tc)
if side_tc in self._pending_requests[token]: self._pending_requests[token].discard(side_tc)
self._pending_requests[token].remove(side_tc)
if not self._pending_requests[token]: # set is now empty if not self._pending_requests[token]: # set is now empty
del self._pending_requests[token] del self._pending_requests[token]
if self._debug_log: if self._debug_log:
log.msg("transitFinished %s" % (description,)) log.msg("transitFinished %s" % (description,))
self._active_connections.discard(tc) self._active_connections.discard(tc)
def transitFailed(self, p):
if self._debug_log:
log.msg("transitFailed %r" % p)
def recordUsage(self, started, result, total_bytes, def recordUsage(self, started, result, total_bytes,
total_time, waiting_time): total_time, waiting_time):
if self._debug_log: if self._debug_log: