diff --git a/src/wormhole_transit_relay/test/test_transit_server.py b/src/wormhole_transit_relay/test/test_transit_server.py index 8ce4a93..f07b275 100644 --- a/src/wormhole_transit_relay/test/test_transit_server.py +++ b/src/wormhole_transit_relay/test/test_transit_server.py @@ -214,6 +214,10 @@ class _Transit: ep = clientFromString(reactor, self.transit) a1 = yield connectProtocol(ep, Accumulator()) a2 = yield connectProtocol(ep, Accumulator()) + a3 = yield connectProtocol(ep, Accumulator()) + disconnects = [] + a1._disconnect.addCallback(disconnects.append) + a2._disconnect.addCallback(disconnects.append) token1 = b"\x00"*32 side1 = b"\x01"*8 @@ -229,8 +233,26 @@ class _Transit: yield wait() self.assertEqual(self.count(), 2) # same-side connections don't match + # when the second side arrives, the spare first connection should be + # closed + side2 = b"\x02"*8 + a3.transport.write(b"please relay " + hexlify(token1) + + b" for side " + hexlify(side2) + b"\n") + # let that arrive + while self.count() != 0: + yield wait() + self.assertEqual(len(self._transit_server._pending_requests), 0) + self.assertEqual(len(self._transit_server._active_connections), 2) + # That will trigger a disconnect on exactly one of (a1 or a2). Wait + # until our client notices it. + while not disconnects: + yield wait() + # the other connection should still be connected + self.assertEqual(sum([int(t.transport.connected) for t in [a1, a2]]), 1) + a1.transport.loseConnection() a2.transport.loseConnection() + a3.transport.loseConnection() @defer.inlineCallbacks def test_bad_handshake_old(self): @@ -379,3 +401,143 @@ class TransitWithLogs(_Transit, ServerBase, unittest.TestCase): class TransitWithoutLogs(_Transit, ServerBase, unittest.TestCase): log_requests = False + +class Usage(ServerBase, unittest.TestCase): + @defer.inlineCallbacks + def setUp(self): + yield super(Usage, self).setUp() + self._usage = [] + def record(started, result, total_bytes, total_time, waiting_time): + self._usage.append((started, result, total_bytes, + total_time, waiting_time)) + self._transit_server.recordUsage = record + + @defer.inlineCallbacks + def test_errory(self): + ep = clientFromString(reactor, self.transit) + a1 = yield connectProtocol(ep, Accumulator()) + + a1.transport.write(b"this is a very bad handshake\n") + # that will log the "errory" usage event, then drop the connection + yield a1._disconnect + self.assertEqual(len(self._usage), 1, self._usage) + (started, result, total_bytes, total_time, waiting_time) = self._usage[0] + self.assertEqual(result, "errory", self._usage) + + @defer.inlineCallbacks + def test_lonely(self): + ep = clientFromString(reactor, self.transit) + a1 = yield connectProtocol(ep, Accumulator()) + + token1 = b"\x00"*32 + side1 = b"\x01"*8 + a1.transport.write(b"please relay " + hexlify(token1) + + b" for side " + hexlify(side1) + b"\n") + while not self._transit_server._pending_requests: + yield wait() # wait for the server to see the connection + # now we disconnect before the peer connects + a1.transport.loseConnection() + yield a1._disconnect + while self._transit_server._pending_requests: + yield wait() # wait for the server to see the disconnect too + + self.assertEqual(len(self._usage), 1, self._usage) + (started, result, total_bytes, total_time, waiting_time) = self._usage[0] + self.assertEqual(result, "lonely", self._usage) + self.assertIdentical(waiting_time, None) + + @defer.inlineCallbacks + def test_one_happy_one_jilted(self): + ep = clientFromString(reactor, self.transit) + a1 = yield connectProtocol(ep, Accumulator()) + a2 = yield connectProtocol(ep, Accumulator()) + + token1 = b"\x00"*32 + side1 = b"\x01"*8 + side2 = b"\x02"*8 + a1.transport.write(b"please relay " + hexlify(token1) + + b" for side " + hexlify(side1) + b"\n") + while not self._transit_server._pending_requests: + yield wait() # make sure a1 connects first + a2.transport.write(b"please relay " + hexlify(token1) + + b" for side " + hexlify(side2) + b"\n") + while not self._transit_server._active_connections: + yield wait() # wait for the server to see the connection + self.assertEqual(len(self._transit_server._pending_requests), 0) + self.assertEqual(self._usage, []) # no events yet + a1.transport.write(b"\x00" * 13) + yield a2.waitForBytes(13) + a2.transport.write(b"\xff" * 7) + yield a1.waitForBytes(7) + + a1.transport.loseConnection() + yield a1._disconnect + while self._transit_server._active_connections: + yield wait() + yield a2._disconnect + self.assertEqual(len(self._usage), 1, self._usage) + (started, result, total_bytes, total_time, waiting_time) = self._usage[0] + self.assertEqual(result, "happy", self._usage) + self.assertEqual(total_bytes, 20) + self.assertNotIdentical(waiting_time, None) + + @defer.inlineCallbacks + def test_redundant(self): + ep = clientFromString(reactor, self.transit) + a1a = yield connectProtocol(ep, Accumulator()) + a1b = yield connectProtocol(ep, Accumulator()) + a1c = yield connectProtocol(ep, Accumulator()) + a2 = yield connectProtocol(ep, Accumulator()) + + token1 = b"\x00"*32 + side1 = b"\x01"*8 + side2 = b"\x02"*8 + a1a.transport.write(b"please relay " + hexlify(token1) + + b" for side " + hexlify(side1) + b"\n") + def count_requests(): + return sum([len(v) + for v in self._transit_server._pending_requests.values()]) + while count_requests() < 1: + yield wait() + a1b.transport.write(b"please relay " + hexlify(token1) + + b" for side " + hexlify(side1) + b"\n") + while count_requests() < 2: + yield wait() + + # connect and disconnect a third client (for side1) to exercise the + # code that removes a pending connection without removing the entire + # token + a1c.transport.write(b"please relay " + hexlify(token1) + + b" for side " + hexlify(side1) + b"\n") + while count_requests() < 3: + yield wait() + a1c.transport.loseConnection() + yield a1c._disconnect + while count_requests() > 2: + yield wait() + self.assertEqual(len(self._usage), 1, self._usage) + (started, result, total_bytes, total_time, waiting_time) = self._usage[0] + self.assertEqual(result, "lonely", self._usage) + + a2.transport.write(b"please relay " + hexlify(token1) + + b" for side " + hexlify(side2) + b"\n") + # this will claim one of (a1a, a1b), and close the other as redundant + while not self._transit_server._active_connections: + yield wait() # wait for the server to see the connection + self.assertEqual(count_requests(), 0) + self.assertEqual(len(self._usage), 2, self._usage) + (started, result, total_bytes, total_time, waiting_time) = self._usage[1] + self.assertEqual(result, "redundant", self._usage) + + # one of the these is unecessary, but probably harmless + a1a.transport.loseConnection() + a1b.transport.loseConnection() + yield a1a._disconnect + yield a1b._disconnect + while self._transit_server._active_connections: + yield wait() + yield a2._disconnect + self.assertEqual(len(self._usage), 3, self._usage) + (started, result, total_bytes, total_time, waiting_time) = self._usage[2] + self.assertEqual(result, "happy", self._usage) + diff --git a/src/wormhole_transit_relay/transit_server.py b/src/wormhole_transit_relay/transit_server.py index 21b1636..1b52045 100644 --- a/src/wormhole_transit_relay/transit_server.py +++ b/src/wormhole_transit_relay/transit_server.py @@ -28,8 +28,8 @@ class TransitConnection(protocol.Protocol): self._got_side = False self._token_buffer = b"" self._sent_ok = False + self._mood = None self._buddy = None - self._had_buddy = False self._total_sent = 0 def describeToken(self): @@ -62,7 +62,7 @@ class TransitConnection(protocol.Protocol): self.transport.write(b"impatient\n") if self._log_requests: log.msg("transit impatience failure") - return self.disconnect() # impatience yields failure + return self.disconnect_error() # impatience yields failure # else this should be (part of) the token self._token_buffer += data @@ -79,7 +79,7 @@ class TransitConnection(protocol.Protocol): self.transport.write(b"impatient\n") if self._log_requests: log.msg("transit impatience failure") - return self.disconnect() # impatience yields failure + return self.disconnect_error() # impatience yields failure return self._got_handshake(token, None) (new, handshake_len, token, side) = self._check_new_handshake(buf) assert new in ("yes", "waiting", "no") @@ -88,13 +88,13 @@ class TransitConnection(protocol.Protocol): self.transport.write(b"impatient\n") if self._log_requests: log.msg("transit impatience failure") - return self.disconnect() # impatience yields failure + return self.disconnect_error() # impatience yields failure return self._got_handshake(token, side) if (old == "no" and new == "no"): self.transport.write(b"bad handshake\n") if self._log_requests: log.msg("transit handshake failure") - return self.disconnect() # incorrectness yields failure + return self.disconnect_error() # incorrectness yields failure # else we'll keep waiting def _check_old_handshake(self, buf): @@ -132,11 +132,12 @@ class TransitConnection(protocol.Protocol): def _got_handshake(self, token, side): self._got_token = token self._got_side = side + self._mood = "lonely" # until buddy connects self.factory.connection_got_token(token, side, self) def buddy_connected(self, them): self._buddy = them - self._had_buddy = True + self._mood = "happy" self.transport.write(b"ok\n") self._sent_ok = True # Connect the two as a producer/consumer pair. We use streaming=True, @@ -150,44 +151,72 @@ class TransitConnection(protocol.Protocol): if self._log_requests: log.msg("buddy_disconnected %s" % self.describeToken()) self._buddy = None + self._mood = "jilted" + self.transport.loseConnection() + + def disconnect_error(self): + # we haven't finished the handshake, so there are no tokens tracking + # us + self._mood = "errory" + self.transport.loseConnection() + if self.factory._debug_log: + log.msg("transitFailed %r" % self) + + def disconnect_redundant(self): + # this is called if a buddy connected and we were found unnecessary. + # Any token-tracking cleanup will have been done before we're called. + self._mood = "redundant" self.transport.loseConnection() def connectionLost(self, reason): + finished = time.time() + total_time = finished - self._started + + # Record usage. There are seven cases: + # * n1: the handshake failed, not a real client (errory) + # * n2: real client disconnected before any buddy appeared (lonely) + # * n3: real client closed as redundant after buddy appears (redundant) + # * n4: real client connected first, buddy closes first (jilted) + # * n5: real client connected first, buddy close last (happy) + # * n6: real client connected last, buddy closes first (jilted) + # * n7: real client connected last, buddy closes last (happy) + + # * non-connected clients (1,2,3) always write a usage record + # * for connected clients, whoever disconnects first gets to write the + # usage record (5, 7). The last disconnect doesn't write a record. + + if self._mood == "errory": # 1 + assert not self._buddy + self.factory.recordUsage(self._started, "errory", 0, + total_time, None) + elif self._mood == "redundant": # 3 + assert not self._buddy + self.factory.recordUsage(self._started, "redundant", 0, + total_time, None) + elif self._mood == "jilted": # 4 or 6 + # we were connected, but our buddy hung up on us. They record the + # usage event, we do not + pass + elif self._mood == "lonely": # 2 + assert not self._buddy + self.factory.recordUsage(self._started, "lonely", 0, + total_time, None) + else: # 5 or 7 + # we were connected, we hung up first. We record the event. + assert self._mood == "happy", self._mood + assert self._buddy + starts = [self._started, self._buddy._started] + total_time = finished - min(starts) + waiting_time = max(starts) - min(starts) + total_bytes = self._total_sent + self._buddy._total_sent + self.factory.recordUsage(self._started, "happy", total_bytes, + total_time, waiting_time) + if self._buddy: self._buddy.buddy_disconnected() self.factory.transitFinished(self, self._got_token, self._got_side, self.describeToken()) - # Record usage. There are four cases: - # * 1: we connected, never had a buddy - # * 2: we connected first, we disconnect before the buddy - # * 3: we connected first, buddy disconnects first - # * 4: buddy connected first, we disconnect before buddy - # * 5: buddy connected first, buddy disconnects first - - # whoever disconnects first gets to write the usage record (1,2,4) - - finished = time.time() - if not self._had_buddy: # 1 - total_time = finished - self._started - self.factory.recordUsage(self._started, "lonely", 0, - total_time, None) - if self._had_buddy and self._buddy: # 2,4 - total_bytes = self._total_sent + self._buddy._total_sent - starts = [self._started, self._buddy._started] - total_time = finished - min(starts) - waiting_time = max(starts) - min(starts) - self.factory.recordUsage(self._started, "happy", total_bytes, - total_time, waiting_time) - - def disconnect(self): - self.transport.loseConnection() - self.factory.transitFailed(self) - finished = time.time() - total_time = finished - self._started - self.factory.recordUsage(self._started, "errory", 0, - total_time, None) - class Transit(protocol.ServerFactory): # I manage pairs of simultaneous connections to a secondary TCP port, # both forwarded to the other. Clients must begin each connection with @@ -255,7 +284,12 @@ class Transit(protocol.ServerFactory): # drop and stop tracking the rest potentials.remove(old) for (_, leftover_tc) in potentials: - leftover_tc.disconnect() # TODO: not "errory"? + # Don't record this as errory. It's just a spare connection + # from the same side as a connection that got used. This + # can happen if the connection hint contains multiple + # addresses (we don't currently support those, but it'd + # probably be useful in the future). + leftover_tc.disconnect_redundant() self._pending_requests.pop(token) # glue the two ends together @@ -272,18 +306,13 @@ class Transit(protocol.ServerFactory): def transitFinished(self, tc, token, side, description): if token in self._pending_requests: side_tc = (side, tc) - if side_tc in self._pending_requests[token]: - self._pending_requests[token].remove(side_tc) + self._pending_requests[token].discard(side_tc) if not self._pending_requests[token]: # set is now empty del self._pending_requests[token] if self._debug_log: log.msg("transitFinished %s" % (description,)) self._active_connections.discard(tc) - def transitFailed(self, p): - if self._debug_log: - log.msg("transitFailed %r" % p) - def recordUsage(self, started, result, total_bytes, total_time, waiting_time): if self._debug_log: