track transit usage in DB

This commit is contained in:
Brian Warner 2015-12-03 19:45:34 -08:00
parent a3656c162b
commit 909cdfa3dc
5 changed files with 89 additions and 9 deletions

View File

@ -21,7 +21,7 @@ CREATE INDEX `messages_idx` ON `messages` (`appid`, `channelid`);
CREATE TABLE `usage` CREATE TABLE `usage`
( (
`type` VARCHAR, -- "rendezvous" `type` VARCHAR, -- "rendezvous" or "transit"
`started` INTEGER, -- seconds since epoch, rounded to one day `started` INTEGER, -- seconds since epoch, rounded to one day
`result` VARCHAR, -- happy, scary, lonely, errory, pruney `result` VARCHAR, -- happy, scary, lonely, errory, pruney
-- rendezvous moods: -- rendezvous moods:
@ -31,7 +31,11 @@ CREATE TABLE `usage`
-- "errory": any side closes with mood=errory (other errors) -- "errory": any side closes with mood=errory (other errors)
-- "pruney": channels which get pruned for inactivity -- "pruney": channels which get pruned for inactivity
-- "crowded": three or more sides were involved -- "crowded": three or more sides were involved
`total_bytes` INTEGER, -- not yet used -- transit moods:
-- "errory": this side have the wrong handshake
-- "lonely": good handshake, but the other side never showed up
-- "happy": both sides gave correct handshake
`total_bytes` INTEGER, -- for transit, total bytes relayed (both directions)
`total_time` INTEGER, -- seconds from start to closed, or None `total_time` INTEGER, -- seconds from start to closed, or None
`waiting_time` INTEGER -- seconds from start to 2nd side appearing, or None `waiting_time` INTEGER -- seconds from start to 2nd side appearing, or None
); );

View File

@ -38,7 +38,7 @@ class RelayServer(service.MultiService):
self.relay = Relay(self.db, welcome) # accessible from tests self.relay = Relay(self.db, welcome) # accessible from tests
self.root.putChild(b"wormhole-relay", self.relay) self.root.putChild(b"wormhole-relay", self.relay)
if transitport: if transitport:
self.transit = Transit() self.transit = Transit(self.db)
self.transit.setServiceParent(self) # for the timer self.transit.setServiceParent(self) # for the timer
t = endpoints.serverFromString(reactor, transitport) t = endpoints.serverFromString(reactor, transitport)
self.transport_service = ServerEndpointService(t, self.transit) self.transport_service = ServerEndpointService(t, self.transit)

View File

@ -1,5 +1,5 @@
from __future__ import print_function from __future__ import print_function
import re import re, time
from twisted.python import log from twisted.python import log
from twisted.internet import protocol from twisted.internet import protocol
from twisted.application import service from twisted.application import service
@ -16,8 +16,12 @@ class TransitConnection(protocol.Protocol):
self._token_buffer = b"" self._token_buffer = b""
self._sent_ok = False self._sent_ok = False
self._buddy = None self._buddy = None
self._had_buddy = False
self._total_sent = 0 self._total_sent = 0
def connectionMade(self):
self._started = time.time()
def dataReceived(self, data): def dataReceived(self, data):
if self._sent_ok: if self._sent_ok:
# We are an IPushProducer to our buddy's IConsumer, so they'll # We are an IPushProducer to our buddy's IConsumer, so they'll
@ -29,10 +33,12 @@ class TransitConnection(protocol.Protocol):
self._total_sent += len(data) self._total_sent += len(data)
self._buddy.transport.write(data) self._buddy.transport.write(data)
return return
if self._got_token: # but not yet sent_ok if self._got_token: # but not yet sent_ok
self.transport.write(b"impatient\n") self.transport.write(b"impatient\n")
log.msg("transit impatience failure") log.msg("transit impatience failure")
return self.disconnect() # impatience yields failure return self.disconnect() # impatience yields failure
# else this should be (part of) the token # else this should be (part of) the token
self._token_buffer += data self._token_buffer += data
buf = self._token_buffer buf = self._token_buffer
@ -59,6 +65,7 @@ class TransitConnection(protocol.Protocol):
def buddy_connected(self, them): def buddy_connected(self, them):
self._buddy = them self._buddy = them
self._had_buddy = True
self.transport.write(b"ok\n") self.transport.write(b"ok\n")
self._sent_ok = True self._sent_ok = True
# Connect the two as a producer/consumer pair. We use streaming=True, # Connect the two as a producer/consumer pair. We use streaming=True,
@ -77,11 +84,37 @@ class TransitConnection(protocol.Protocol):
log.msg("connectionLost %r %s" % (self, reason)) log.msg("connectionLost %r %s" % (self, reason))
if self._buddy: if self._buddy:
self._buddy.buddy_disconnected() self._buddy.buddy_disconnected()
self.factory.transitFinished(self, self._total_sent) self.factory.transitFinished(self)
# Record usage. There are four cases:
# * 1: we connected, never had a buddy
# * 2: we connected first, we disconnect before the buddy
# * 3: we connected first, buddy disconnects first
# * 4: buddy connected first, we disconnect before buddy
# * 5: buddy connected first, buddy disconnects first
# whoever disconnects first gets to write the usage record (1,2,4)
finished = time.time()
if not self._had_buddy: # 1
total_time = finished - self._started
self.factory.recordUsage(self._started, u"lonely", 0,
total_time, None)
if self._had_buddy and self._buddy: # 2,4
total_bytes = self._total_sent + self._buddy._total_sent
starts = [self._started, self._buddy._started]
total_time = finished - min(starts)
waiting_time = max(starts) - min(starts)
self.factory.recordUsage(self._started, u"happy", total_bytes,
total_time, waiting_time)
def disconnect(self): def disconnect(self):
self.transport.loseConnection() self.transport.loseConnection()
self.factory.transitFailed(self) self.factory.transitFailed(self)
finished = time.time()
total_time = finished - self._started
self.factory.recordUsage(self._started, u"errory", 0,
total_time, None)
class Transit(protocol.ServerFactory, service.MultiService): class Transit(protocol.ServerFactory, service.MultiService):
# I manage pairs of simultaneous connections to a secondary TCP port, # I manage pairs of simultaneous connections to a secondary TCP port,
@ -110,8 +143,9 @@ class Transit(protocol.ServerFactory, service.MultiService):
MAXTIME = 60*SECONDS MAXTIME = 60*SECONDS
protocol = TransitConnection protocol = TransitConnection
def __init__(self): def __init__(self, db):
service.MultiService.__init__(self) service.MultiService.__init__(self)
self._db = db
self._pending_requests = {} # token -> TransitConnection self._pending_requests = {} # token -> TransitConnection
self._active_connections = set() # TransitConnection self._active_connections = set() # TransitConnection
@ -127,8 +161,20 @@ class Transit(protocol.ServerFactory, service.MultiService):
self._pending_requests[token] = p self._pending_requests[token] = p
log.msg("transit relay 1: %r" % token) log.msg("transit relay 1: %r" % token)
# TODO: timer # TODO: timer
def transitFinished(self, p, total_sent):
log.msg("transitFinished (%dB) %r" % (total_sent, p)) def recordUsage(self, started, result, total_bytes,
total_time, waiting_time):
log.msg("Transit.recordUsage (%dB)" % total_bytes)
self._db.execute("INSERT INTO `usage`"
" (`type`, `started`, `result`, `total_bytes`,"
" `total_time`, `waiting_time`)"
" VALUES (?,?,?,?, ?,?)",
(u"transit", started, result, total_bytes,
total_time, waiting_time))
self._db.commit()
def transitFinished(self, p):
log.msg("transitFinished %r" % (p,))
for token,tc in self._pending_requests.items(): for token,tc in self._pending_requests.items():
if tc is p: if tc is p:
del self._pending_requests[token] del self._pending_requests[token]

View File

@ -15,6 +15,7 @@ class ServerBase:
__version__) __version__)
s.setServiceParent(self.sp) s.setServiceParent(self.sp)
self._relay_server = s.relay self._relay_server = s.relay
self._transit_server = s.transit
self.relayurl = u"http://127.0.0.1:%d/wormhole-relay/" % relayport self.relayurl = u"http://127.0.0.1:%d/wormhole-relay/" % relayport
self.transit = u"tcp:127.0.0.1:%d" % transitport self.transit = u"tcp:127.0.0.1:%d" % transitport
d.addCallback(_got_ports) d.addCallback(_got_ports)

View File

@ -6,7 +6,9 @@ from twisted.internet.threads import deferToThread
from ..blocking.transcribe import (Wormhole, UsageError, ChannelManager, from ..blocking.transcribe import (Wormhole, UsageError, ChannelManager,
WrongPasswordError) WrongPasswordError)
from ..blocking.eventsource import EventSourceFollower from ..blocking.eventsource import EventSourceFollower
from ..blocking.transit import TransitSender, TransitReceiver from ..blocking.transit import (TransitSender, TransitReceiver,
build_sender_handshake,
build_receiver_handshake)
from .common import ServerBase from .common import ServerBase
APPID = u"appid" APPID = u"appid"
@ -461,6 +463,8 @@ class Transit(_DoBothMixin, ServerBase, unittest.TestCase):
r.add_their_direct_hints([]) r.add_their_direct_hints([])
r.add_their_relay_hints([]) r.add_their_relay_hints([])
# it'd be nice to factor this chunk out with 'yield from', but that
# didn't appear until python-3.3, and isn't in py2 at all.
(sp, rp) = yield self.doBoth([s.connect], [r.connect]) (sp, rp) = yield self.doBoth([s.connect], [r.connect])
yield deferToThread(sp.send_record, b"01234") yield deferToThread(sp.send_record, b"01234")
rec = yield deferToThread(rp.receive_record) rec = yield deferToThread(rp.receive_record)
@ -508,3 +512,28 @@ class Transit(_DoBothMixin, ServerBase, unittest.TestCase):
self.assertEqual(rec, b"01234") self.assertEqual(rec, b"01234")
yield deferToThread(sp.close) yield deferToThread(sp.close)
yield deferToThread(rp.close) yield deferToThread(rp.close)
# TODO: this may be racy if we don't poll the server to make sure
# it's witnessed the first connection closing before querying the DB
#import time
#yield deferToThread(time.sleep, 1)
# check the transit relay's DB, make sure it counted the bytes
db = self._transit_server._db
c = db.execute("SELECT * FROM `usage` WHERE `type`=?", (u"transit",))
rows = c.fetchall()
self.assertEqual(len(rows), 1)
row = rows[0]
self.assertEqual(row["result"], u"happy")
# Sender first writes relay_handshake and waits for OK, but that's
# not counted by the transit server. Then sender writes
# sender_handshake and waits for receiver_handshake. Then sender
# writes GO and the body. Body is length-prefixed SecretBox, so
# includes 4-byte length, 24-byte nonce, and 16-byte MAC.
sender_count = (len(build_sender_handshake(b""))+
len(b"go\n")+
4+24+len(b"01234")+16)
# Receiver first writes relay_handshake and waits for OK, but that's
# not counted. Then receiver writes receiver_handshake and waits for
# sender_handshake+GO.
receiver_count = len(build_receiver_handshake(b""))
self.assertEqual(row["total_bytes"], sender_count+receiver_count)