diff --git a/misc/munin/wormhole_transit b/misc/munin/wormhole_transit_bytes similarity index 100% rename from misc/munin/wormhole_transit rename to misc/munin/wormhole_transit_bytes diff --git a/misc/munin/wormhole_transit_alltime b/misc/munin/wormhole_transit_bytes_alltime similarity index 100% rename from misc/munin/wormhole_transit_alltime rename to misc/munin/wormhole_transit_bytes_alltime diff --git a/misc/munin/wormhole_transit_errors b/misc/munin/wormhole_transit_events similarity index 65% rename from misc/munin/wormhole_transit_errors rename to misc/munin/wormhole_transit_events index d5dc16a..2690cdf 100755 --- a/misc/munin/wormhole_transit_errors +++ b/misc/munin/wormhole_transit_events @@ -11,15 +11,21 @@ from __future__ import print_function import os, sys, time, sqlite3 CONFIG = """\ -graph_title Magic-Wormhole Transit Server Errors (since reboot) +graph_title Magic-Wormhole Transit Server Events (since reboot) graph_vlabel Events Since Reboot graph_category network +happy.label Happy +happy.draw LINE1 +happy.type GAUGE errory.label Errory errory.draw LINE1 errory.type GAUGE lonely.label Lonely lonely.draw LINE1 lonely.type GAUGE +redundant.label Redundant +redundant.draw LINE1 +redundant.type GAUGE """ if len(sys.argv) > 1 and sys.argv[1] == "config": @@ -35,6 +41,13 @@ rebooted,updated = db.execute("SELECT `rebooted`, `updated` FROM `current`").fet if time.time() > updated + 5*MINUTE: sys.exit(1) # expired +count = db.execute("SELECT COUNT() FROM `usage`" + " WHERE" + " `started` > ? AND" + " `result` = 'happy'", + (rebooted,)).fetchone()[0] +print("happy.value", count) + count = db.execute("SELECT COUNT() FROM `usage`" " WHERE" " `started` > ? AND" @@ -48,3 +61,10 @@ count = db.execute("SELECT COUNT() FROM `usage`" " `result` = 'lonely'", (rebooted,)).fetchone()[0] print("lonely.value", count) + +count = db.execute("SELECT COUNT() FROM `usage`" + " WHERE" + " `started` > ? AND" + " `result` = 'redundant'", + (rebooted,)).fetchone()[0] +print("redundant.value", count) diff --git a/src/wormhole_transit_relay/db-schemas/v1.sql b/src/wormhole_transit_relay/db-schemas/v1.sql index f68d742..7c5e245 100644 --- a/src/wormhole_transit_relay/db-schemas/v1.sql +++ b/src/wormhole_transit_relay/db-schemas/v1.sql @@ -24,6 +24,7 @@ CREATE TABLE `usage` -- transit moods: -- "errory": one side gave the wrong handshake -- "lonely": good handshake, but the other side never showed up + -- "redundant": good handshake, abandoned in favor of different connection -- "happy": both sides gave correct handshake ); CREATE INDEX `usage_started_index` ON `usage` (`started`); diff --git a/src/wormhole_transit_relay/test/test_transit_server.py b/src/wormhole_transit_relay/test/test_transit_server.py index 10755ca..f07b275 100644 --- a/src/wormhole_transit_relay/test/test_transit_server.py +++ b/src/wormhole_transit_relay/test/test_transit_server.py @@ -31,6 +31,11 @@ class Accumulator(protocol.Protocol): self._wait.errback(RuntimeError("closed")) self._disconnect.callback(None) +def wait(): + d = defer.Deferred() + reactor.callLater(0.001, d.callback, None) + return d + class _Transit: def test_blur_size(self): blur = transit_server.blur_size @@ -62,14 +67,14 @@ class _Transit: # let that arrive while self.count() == 0: - yield self.wait() + yield wait() self.assertEqual(self.count(), 1) a1.transport.loseConnection() # let that get removed while self.count() > 0: - yield self.wait() + yield wait() self.assertEqual(self.count(), 0) # the token should be removed too @@ -203,16 +208,16 @@ class _Transit: return sum([len(potentials) for potentials in self._transit_server._pending_requests.values()]) - def wait(self): - d = defer.Deferred() - reactor.callLater(0.001, d.callback, None) - return d @defer.inlineCallbacks def test_ignore_same_side(self): ep = clientFromString(reactor, self.transit) a1 = yield connectProtocol(ep, Accumulator()) a2 = yield connectProtocol(ep, Accumulator()) + a3 = yield connectProtocol(ep, Accumulator()) + disconnects = [] + a1._disconnect.addCallback(disconnects.append) + a2._disconnect.addCallback(disconnects.append) token1 = b"\x00"*32 side1 = b"\x01"*8 @@ -220,16 +225,34 @@ class _Transit: b" for side " + hexlify(side1) + b"\n") # let that arrive while self.count() == 0: - yield self.wait() + yield wait() a2.transport.write(b"please relay " + hexlify(token1) + b" for side " + hexlify(side1) + b"\n") # let that arrive while self.count() == 1: - yield self.wait() + yield wait() self.assertEqual(self.count(), 2) # same-side connections don't match + # when the second side arrives, the spare first connection should be + # closed + side2 = b"\x02"*8 + a3.transport.write(b"please relay " + hexlify(token1) + + b" for side " + hexlify(side2) + b"\n") + # let that arrive + while self.count() != 0: + yield wait() + self.assertEqual(len(self._transit_server._pending_requests), 0) + self.assertEqual(len(self._transit_server._active_connections), 2) + # That will trigger a disconnect on exactly one of (a1 or a2). Wait + # until our client notices it. + while not disconnects: + yield wait() + # the other connection should still be connected + self.assertEqual(sum([int(t.transport.connected) for t in [a1, a2]]), 1) + a1.transport.loseConnection() a2.transport.loseConnection() + a3.transport.loseConnection() @defer.inlineCallbacks def test_bad_handshake_old(self): @@ -378,3 +401,143 @@ class TransitWithLogs(_Transit, ServerBase, unittest.TestCase): class TransitWithoutLogs(_Transit, ServerBase, unittest.TestCase): log_requests = False + +class Usage(ServerBase, unittest.TestCase): + @defer.inlineCallbacks + def setUp(self): + yield super(Usage, self).setUp() + self._usage = [] + def record(started, result, total_bytes, total_time, waiting_time): + self._usage.append((started, result, total_bytes, + total_time, waiting_time)) + self._transit_server.recordUsage = record + + @defer.inlineCallbacks + def test_errory(self): + ep = clientFromString(reactor, self.transit) + a1 = yield connectProtocol(ep, Accumulator()) + + a1.transport.write(b"this is a very bad handshake\n") + # that will log the "errory" usage event, then drop the connection + yield a1._disconnect + self.assertEqual(len(self._usage), 1, self._usage) + (started, result, total_bytes, total_time, waiting_time) = self._usage[0] + self.assertEqual(result, "errory", self._usage) + + @defer.inlineCallbacks + def test_lonely(self): + ep = clientFromString(reactor, self.transit) + a1 = yield connectProtocol(ep, Accumulator()) + + token1 = b"\x00"*32 + side1 = b"\x01"*8 + a1.transport.write(b"please relay " + hexlify(token1) + + b" for side " + hexlify(side1) + b"\n") + while not self._transit_server._pending_requests: + yield wait() # wait for the server to see the connection + # now we disconnect before the peer connects + a1.transport.loseConnection() + yield a1._disconnect + while self._transit_server._pending_requests: + yield wait() # wait for the server to see the disconnect too + + self.assertEqual(len(self._usage), 1, self._usage) + (started, result, total_bytes, total_time, waiting_time) = self._usage[0] + self.assertEqual(result, "lonely", self._usage) + self.assertIdentical(waiting_time, None) + + @defer.inlineCallbacks + def test_one_happy_one_jilted(self): + ep = clientFromString(reactor, self.transit) + a1 = yield connectProtocol(ep, Accumulator()) + a2 = yield connectProtocol(ep, Accumulator()) + + token1 = b"\x00"*32 + side1 = b"\x01"*8 + side2 = b"\x02"*8 + a1.transport.write(b"please relay " + hexlify(token1) + + b" for side " + hexlify(side1) + b"\n") + while not self._transit_server._pending_requests: + yield wait() # make sure a1 connects first + a2.transport.write(b"please relay " + hexlify(token1) + + b" for side " + hexlify(side2) + b"\n") + while not self._transit_server._active_connections: + yield wait() # wait for the server to see the connection + self.assertEqual(len(self._transit_server._pending_requests), 0) + self.assertEqual(self._usage, []) # no events yet + a1.transport.write(b"\x00" * 13) + yield a2.waitForBytes(13) + a2.transport.write(b"\xff" * 7) + yield a1.waitForBytes(7) + + a1.transport.loseConnection() + yield a1._disconnect + while self._transit_server._active_connections: + yield wait() + yield a2._disconnect + self.assertEqual(len(self._usage), 1, self._usage) + (started, result, total_bytes, total_time, waiting_time) = self._usage[0] + self.assertEqual(result, "happy", self._usage) + self.assertEqual(total_bytes, 20) + self.assertNotIdentical(waiting_time, None) + + @defer.inlineCallbacks + def test_redundant(self): + ep = clientFromString(reactor, self.transit) + a1a = yield connectProtocol(ep, Accumulator()) + a1b = yield connectProtocol(ep, Accumulator()) + a1c = yield connectProtocol(ep, Accumulator()) + a2 = yield connectProtocol(ep, Accumulator()) + + token1 = b"\x00"*32 + side1 = b"\x01"*8 + side2 = b"\x02"*8 + a1a.transport.write(b"please relay " + hexlify(token1) + + b" for side " + hexlify(side1) + b"\n") + def count_requests(): + return sum([len(v) + for v in self._transit_server._pending_requests.values()]) + while count_requests() < 1: + yield wait() + a1b.transport.write(b"please relay " + hexlify(token1) + + b" for side " + hexlify(side1) + b"\n") + while count_requests() < 2: + yield wait() + + # connect and disconnect a third client (for side1) to exercise the + # code that removes a pending connection without removing the entire + # token + a1c.transport.write(b"please relay " + hexlify(token1) + + b" for side " + hexlify(side1) + b"\n") + while count_requests() < 3: + yield wait() + a1c.transport.loseConnection() + yield a1c._disconnect + while count_requests() > 2: + yield wait() + self.assertEqual(len(self._usage), 1, self._usage) + (started, result, total_bytes, total_time, waiting_time) = self._usage[0] + self.assertEqual(result, "lonely", self._usage) + + a2.transport.write(b"please relay " + hexlify(token1) + + b" for side " + hexlify(side2) + b"\n") + # this will claim one of (a1a, a1b), and close the other as redundant + while not self._transit_server._active_connections: + yield wait() # wait for the server to see the connection + self.assertEqual(count_requests(), 0) + self.assertEqual(len(self._usage), 2, self._usage) + (started, result, total_bytes, total_time, waiting_time) = self._usage[1] + self.assertEqual(result, "redundant", self._usage) + + # one of the these is unecessary, but probably harmless + a1a.transport.loseConnection() + a1b.transport.loseConnection() + yield a1a._disconnect + yield a1b._disconnect + while self._transit_server._active_connections: + yield wait() + yield a2._disconnect + self.assertEqual(len(self._usage), 3, self._usage) + (started, result, total_bytes, total_time, waiting_time) = self._usage[2] + self.assertEqual(result, "happy", self._usage) + diff --git a/src/wormhole_transit_relay/transit_server.py b/src/wormhole_transit_relay/transit_server.py index 21b1636..26d0274 100644 --- a/src/wormhole_transit_relay/transit_server.py +++ b/src/wormhole_transit_relay/transit_server.py @@ -1,5 +1,6 @@ from __future__ import print_function, unicode_literals import re, time, json +from collections import defaultdict from twisted.python import log from twisted.internet import protocol from .database import get_db @@ -28,8 +29,8 @@ class TransitConnection(protocol.Protocol): self._got_side = False self._token_buffer = b"" self._sent_ok = False + self._mood = None self._buddy = None - self._had_buddy = False self._total_sent = 0 def describeToken(self): @@ -62,7 +63,7 @@ class TransitConnection(protocol.Protocol): self.transport.write(b"impatient\n") if self._log_requests: log.msg("transit impatience failure") - return self.disconnect() # impatience yields failure + return self.disconnect_error() # impatience yields failure # else this should be (part of) the token self._token_buffer += data @@ -79,7 +80,7 @@ class TransitConnection(protocol.Protocol): self.transport.write(b"impatient\n") if self._log_requests: log.msg("transit impatience failure") - return self.disconnect() # impatience yields failure + return self.disconnect_error() # impatience yields failure return self._got_handshake(token, None) (new, handshake_len, token, side) = self._check_new_handshake(buf) assert new in ("yes", "waiting", "no") @@ -88,13 +89,13 @@ class TransitConnection(protocol.Protocol): self.transport.write(b"impatient\n") if self._log_requests: log.msg("transit impatience failure") - return self.disconnect() # impatience yields failure + return self.disconnect_error() # impatience yields failure return self._got_handshake(token, side) if (old == "no" and new == "no"): self.transport.write(b"bad handshake\n") if self._log_requests: log.msg("transit handshake failure") - return self.disconnect() # incorrectness yields failure + return self.disconnect_error() # incorrectness yields failure # else we'll keep waiting def _check_old_handshake(self, buf): @@ -132,11 +133,12 @@ class TransitConnection(protocol.Protocol): def _got_handshake(self, token, side): self._got_token = token self._got_side = side + self._mood = "lonely" # until buddy connects self.factory.connection_got_token(token, side, self) def buddy_connected(self, them): self._buddy = them - self._had_buddy = True + self._mood = "happy" self.transport.write(b"ok\n") self._sent_ok = True # Connect the two as a producer/consumer pair. We use streaming=True, @@ -150,44 +152,72 @@ class TransitConnection(protocol.Protocol): if self._log_requests: log.msg("buddy_disconnected %s" % self.describeToken()) self._buddy = None + self._mood = "jilted" + self.transport.loseConnection() + + def disconnect_error(self): + # we haven't finished the handshake, so there are no tokens tracking + # us + self._mood = "errory" + self.transport.loseConnection() + if self.factory._debug_log: + log.msg("transitFailed %r" % self) + + def disconnect_redundant(self): + # this is called if a buddy connected and we were found unnecessary. + # Any token-tracking cleanup will have been done before we're called. + self._mood = "redundant" self.transport.loseConnection() def connectionLost(self, reason): + finished = time.time() + total_time = finished - self._started + + # Record usage. There are seven cases: + # * n1: the handshake failed, not a real client (errory) + # * n2: real client disconnected before any buddy appeared (lonely) + # * n3: real client closed as redundant after buddy appears (redundant) + # * n4: real client connected first, buddy closes first (jilted) + # * n5: real client connected first, buddy close last (happy) + # * n6: real client connected last, buddy closes first (jilted) + # * n7: real client connected last, buddy closes last (happy) + + # * non-connected clients (1,2,3) always write a usage record + # * for connected clients, whoever disconnects first gets to write the + # usage record (5, 7). The last disconnect doesn't write a record. + + if self._mood == "errory": # 1 + assert not self._buddy + self.factory.recordUsage(self._started, "errory", 0, + total_time, None) + elif self._mood == "redundant": # 3 + assert not self._buddy + self.factory.recordUsage(self._started, "redundant", 0, + total_time, None) + elif self._mood == "jilted": # 4 or 6 + # we were connected, but our buddy hung up on us. They record the + # usage event, we do not + pass + elif self._mood == "lonely": # 2 + assert not self._buddy + self.factory.recordUsage(self._started, "lonely", 0, + total_time, None) + else: # 5 or 7 + # we were connected, we hung up first. We record the event. + assert self._mood == "happy", self._mood + assert self._buddy + starts = [self._started, self._buddy._started] + total_time = finished - min(starts) + waiting_time = max(starts) - min(starts) + total_bytes = self._total_sent + self._buddy._total_sent + self.factory.recordUsage(self._started, "happy", total_bytes, + total_time, waiting_time) + if self._buddy: self._buddy.buddy_disconnected() self.factory.transitFinished(self, self._got_token, self._got_side, self.describeToken()) - # Record usage. There are four cases: - # * 1: we connected, never had a buddy - # * 2: we connected first, we disconnect before the buddy - # * 3: we connected first, buddy disconnects first - # * 4: buddy connected first, we disconnect before buddy - # * 5: buddy connected first, buddy disconnects first - - # whoever disconnects first gets to write the usage record (1,2,4) - - finished = time.time() - if not self._had_buddy: # 1 - total_time = finished - self._started - self.factory.recordUsage(self._started, "lonely", 0, - total_time, None) - if self._had_buddy and self._buddy: # 2,4 - total_bytes = self._total_sent + self._buddy._total_sent - starts = [self._started, self._buddy._started] - total_time = finished - min(starts) - waiting_time = max(starts) - min(starts) - self.factory.recordUsage(self._started, "happy", total_bytes, - total_time, waiting_time) - - def disconnect(self): - self.transport.loseConnection() - self.factory.transitFailed(self) - finished = time.time() - total_time = finished - self._started - self.factory.recordUsage(self._started, "errory", 0, - total_time, None) - class Transit(protocol.ServerFactory): # I manage pairs of simultaneous connections to a secondary TCP port, # both forwarded to the other. Clients must begin each connection with @@ -236,12 +266,10 @@ class Transit(protocol.ServerFactory): self._db = get_db(usage_db) self._rebooted = time.time() # we don't track TransitConnections until they submit a token - self._pending_requests = {} # token -> set((side, TransitConnection)) + self._pending_requests = defaultdict(set) # token -> set((side, TransitConnection)) self._active_connections = set() # TransitConnection def connection_got_token(self, token, new_side, new_tc): - if token not in self._pending_requests: - self._pending_requests[token] = set() potentials = self._pending_requests[token] for old in potentials: (old_side, old_tc) = old @@ -255,7 +283,12 @@ class Transit(protocol.ServerFactory): # drop and stop tracking the rest potentials.remove(old) for (_, leftover_tc) in potentials: - leftover_tc.disconnect() # TODO: not "errory"? + # Don't record this as errory. It's just a spare connection + # from the same side as a connection that got used. This + # can happen if the connection hint contains multiple + # addresses (we don't currently support those, but it'd + # probably be useful in the future). + leftover_tc.disconnect_redundant() self._pending_requests.pop(token) # glue the two ends together @@ -272,17 +305,23 @@ class Transit(protocol.ServerFactory): def transitFinished(self, tc, token, side, description): if token in self._pending_requests: side_tc = (side, tc) - if side_tc in self._pending_requests[token]: - self._pending_requests[token].remove(side_tc) + self._pending_requests[token].discard(side_tc) if not self._pending_requests[token]: # set is now empty del self._pending_requests[token] if self._debug_log: log.msg("transitFinished %s" % (description,)) self._active_connections.discard(tc) - - def transitFailed(self, p): - if self._debug_log: - log.msg("transitFailed %r" % p) + # we could update the usage database "current" row immediately, or wait + # until the 5-minute timer updates it. If we update it now, just after + # losing a connection, we should probably also update it just after + # establishing one (at the end of connection_got_token). For now I'm + # going to omit these, but maybe someday we'll turn them both on. The + # consequence is that a manual execution of the munin scripts ("munin + # run wormhole_transit_active") will give the wrong value just after a + # connect/disconnect event. Actual munin graphs should accurately + # report connections that last longer than the 5-minute sampling + # window, which is what we actually care about. + #self.timerUpdateStats() def recordUsage(self, started, result, total_bytes, total_time, waiting_time):