diff --git a/src/wormhole/db-schemas/v1.sql b/src/wormhole/db-schemas/v1.sql index 7dba360..f3eb82f 100644 --- a/src/wormhole/db-schemas/v1.sql +++ b/src/wormhole/db-schemas/v1.sql @@ -21,7 +21,7 @@ CREATE INDEX `messages_idx` ON `messages` (`appid`, `channelid`); CREATE TABLE `usage` ( - `type` VARCHAR, -- "rendezvous" + `type` VARCHAR, -- "rendezvous" or "transit" `started` INTEGER, -- seconds since epoch, rounded to one day `result` VARCHAR, -- happy, scary, lonely, errory, pruney -- rendezvous moods: @@ -31,7 +31,11 @@ CREATE TABLE `usage` -- "errory": any side closes with mood=errory (other errors) -- "pruney": channels which get pruned for inactivity -- "crowded": three or more sides were involved - `total_bytes` INTEGER, -- not yet used + -- transit moods: + -- "errory": this side have the wrong handshake + -- "lonely": good handshake, but the other side never showed up + -- "happy": both sides gave correct handshake + `total_bytes` INTEGER, -- for transit, total bytes relayed (both directions) `total_time` INTEGER, -- seconds from start to closed, or None `waiting_time` INTEGER -- seconds from start to 2nd side appearing, or None ); diff --git a/src/wormhole/servers/server.py b/src/wormhole/servers/server.py index 3354745..8e4118a 100644 --- a/src/wormhole/servers/server.py +++ b/src/wormhole/servers/server.py @@ -38,7 +38,7 @@ class RelayServer(service.MultiService): self.relay = Relay(self.db, welcome) # accessible from tests self.root.putChild(b"wormhole-relay", self.relay) if transitport: - self.transit = Transit() + self.transit = Transit(self.db) self.transit.setServiceParent(self) # for the timer t = endpoints.serverFromString(reactor, transitport) self.transport_service = ServerEndpointService(t, self.transit) diff --git a/src/wormhole/servers/transit_server.py b/src/wormhole/servers/transit_server.py index 5071ad0..a082a2e 100644 --- a/src/wormhole/servers/transit_server.py +++ b/src/wormhole/servers/transit_server.py @@ -1,5 +1,5 @@ from __future__ import print_function -import re +import re, time from twisted.python import log from twisted.internet import protocol from twisted.application import service @@ -16,8 +16,12 @@ class TransitConnection(protocol.Protocol): self._token_buffer = b"" self._sent_ok = False self._buddy = None + self._had_buddy = False self._total_sent = 0 + def connectionMade(self): + self._started = time.time() + def dataReceived(self, data): if self._sent_ok: # We are an IPushProducer to our buddy's IConsumer, so they'll @@ -29,10 +33,12 @@ class TransitConnection(protocol.Protocol): self._total_sent += len(data) self._buddy.transport.write(data) return + if self._got_token: # but not yet sent_ok self.transport.write(b"impatient\n") log.msg("transit impatience failure") return self.disconnect() # impatience yields failure + # else this should be (part of) the token self._token_buffer += data buf = self._token_buffer @@ -59,6 +65,7 @@ class TransitConnection(protocol.Protocol): def buddy_connected(self, them): self._buddy = them + self._had_buddy = True self.transport.write(b"ok\n") self._sent_ok = True # Connect the two as a producer/consumer pair. We use streaming=True, @@ -77,11 +84,37 @@ class TransitConnection(protocol.Protocol): log.msg("connectionLost %r %s" % (self, reason)) if self._buddy: self._buddy.buddy_disconnected() - self.factory.transitFinished(self, self._total_sent) + self.factory.transitFinished(self) + + # Record usage. There are four cases: + # * 1: we connected, never had a buddy + # * 2: we connected first, we disconnect before the buddy + # * 3: we connected first, buddy disconnects first + # * 4: buddy connected first, we disconnect before buddy + # * 5: buddy connected first, buddy disconnects first + + # whoever disconnects first gets to write the usage record (1,2,4) + + finished = time.time() + if not self._had_buddy: # 1 + total_time = finished - self._started + self.factory.recordUsage(self._started, u"lonely", 0, + total_time, None) + if self._had_buddy and self._buddy: # 2,4 + total_bytes = self._total_sent + self._buddy._total_sent + starts = [self._started, self._buddy._started] + total_time = finished - min(starts) + waiting_time = max(starts) - min(starts) + self.factory.recordUsage(self._started, u"happy", total_bytes, + total_time, waiting_time) def disconnect(self): self.transport.loseConnection() self.factory.transitFailed(self) + finished = time.time() + total_time = finished - self._started + self.factory.recordUsage(self._started, u"errory", 0, + total_time, None) class Transit(protocol.ServerFactory, service.MultiService): # I manage pairs of simultaneous connections to a secondary TCP port, @@ -110,8 +143,9 @@ class Transit(protocol.ServerFactory, service.MultiService): MAXTIME = 60*SECONDS protocol = TransitConnection - def __init__(self): + def __init__(self, db): service.MultiService.__init__(self) + self._db = db self._pending_requests = {} # token -> TransitConnection self._active_connections = set() # TransitConnection @@ -127,8 +161,20 @@ class Transit(protocol.ServerFactory, service.MultiService): self._pending_requests[token] = p log.msg("transit relay 1: %r" % token) # TODO: timer - def transitFinished(self, p, total_sent): - log.msg("transitFinished (%dB) %r" % (total_sent, p)) + + def recordUsage(self, started, result, total_bytes, + total_time, waiting_time): + log.msg("Transit.recordUsage (%dB)" % total_bytes) + self._db.execute("INSERT INTO `usage`" + " (`type`, `started`, `result`, `total_bytes`," + " `total_time`, `waiting_time`)" + " VALUES (?,?,?,?, ?,?)", + (u"transit", started, result, total_bytes, + total_time, waiting_time)) + self._db.commit() + + def transitFinished(self, p): + log.msg("transitFinished %r" % (p,)) for token,tc in self._pending_requests.items(): if tc is p: del self._pending_requests[token] diff --git a/src/wormhole/test/common.py b/src/wormhole/test/common.py index c61531d..e456e05 100644 --- a/src/wormhole/test/common.py +++ b/src/wormhole/test/common.py @@ -15,6 +15,7 @@ class ServerBase: __version__) s.setServiceParent(self.sp) self._relay_server = s.relay + self._transit_server = s.transit self.relayurl = u"http://127.0.0.1:%d/wormhole-relay/" % relayport self.transit = u"tcp:127.0.0.1:%d" % transitport d.addCallback(_got_ports) diff --git a/src/wormhole/test/test_blocking.py b/src/wormhole/test/test_blocking.py index 22199d7..60dc6f7 100644 --- a/src/wormhole/test/test_blocking.py +++ b/src/wormhole/test/test_blocking.py @@ -6,7 +6,9 @@ from twisted.internet.threads import deferToThread from ..blocking.transcribe import (Wormhole, UsageError, ChannelManager, WrongPasswordError) from ..blocking.eventsource import EventSourceFollower -from ..blocking.transit import TransitSender, TransitReceiver +from ..blocking.transit import (TransitSender, TransitReceiver, + build_sender_handshake, + build_receiver_handshake) from .common import ServerBase APPID = u"appid" @@ -461,6 +463,8 @@ class Transit(_DoBothMixin, ServerBase, unittest.TestCase): r.add_their_direct_hints([]) r.add_their_relay_hints([]) + # it'd be nice to factor this chunk out with 'yield from', but that + # didn't appear until python-3.3, and isn't in py2 at all. (sp, rp) = yield self.doBoth([s.connect], [r.connect]) yield deferToThread(sp.send_record, b"01234") rec = yield deferToThread(rp.receive_record) @@ -508,3 +512,28 @@ class Transit(_DoBothMixin, ServerBase, unittest.TestCase): self.assertEqual(rec, b"01234") yield deferToThread(sp.close) yield deferToThread(rp.close) + # TODO: this may be racy if we don't poll the server to make sure + # it's witnessed the first connection closing before querying the DB + #import time + #yield deferToThread(time.sleep, 1) + + # check the transit relay's DB, make sure it counted the bytes + db = self._transit_server._db + c = db.execute("SELECT * FROM `usage` WHERE `type`=?", (u"transit",)) + rows = c.fetchall() + self.assertEqual(len(rows), 1) + row = rows[0] + self.assertEqual(row["result"], u"happy") + # Sender first writes relay_handshake and waits for OK, but that's + # not counted by the transit server. Then sender writes + # sender_handshake and waits for receiver_handshake. Then sender + # writes GO and the body. Body is length-prefixed SecretBox, so + # includes 4-byte length, 24-byte nonce, and 16-byte MAC. + sender_count = (len(build_sender_handshake(b""))+ + len(b"go\n")+ + 4+24+len(b"01234")+16) + # Receiver first writes relay_handshake and waits for OK, but that's + # not counted. Then receiver writes receiver_handshake and waits for + # sender_handshake+GO. + receiver_count = len(build_receiver_handshake(b"")) + self.assertEqual(row["total_bytes"], sender_count+receiver_count)