make it all basically work, improve test coverage

This commit is contained in:
Brian Warner 2017-11-07 19:29:26 -06:00
parent a898a65b09
commit 83e1c8acfe
7 changed files with 142 additions and 76 deletions

View File

@@ -9,12 +9,12 @@ class DBError(Exception):
pass pass
def get_schema(version): def get_schema(version):
schema_bytes = resource_string("wormhole.server", schema_bytes = resource_string("wormhole_transit_relay",
"db-schemas/v%d.sql" % version) "db-schemas/v%d.sql" % version)
return schema_bytes.decode("utf-8") return schema_bytes.decode("utf-8")
def get_upgrader(new_version): def get_upgrader(new_version):
schema_bytes = resource_string("wormhole.server", schema_bytes = resource_string("wormhole_transit_relay",
"db-schemas/upgrade-to-v%d.sql" % new_version) "db-schemas/upgrade-to-v%d.sql" % new_version)
return schema_bytes.decode("utf-8") return schema_bytes.decode("utf-8")

View File

@@ -7,8 +7,8 @@ CREATE TABLE `version` -- contains one row
CREATE TABLE `current` -- contains one row CREATE TABLE `current` -- contains one row
( (
`reboot` INTEGER, -- seconds since epoch of most recent reboot `rebooted` INTEGER, -- seconds since epoch of most recent reboot
`last_update` INTEGER, -- when `current` was last updated `updated` INTEGER, -- when `current` was last updated
`connected` INTEGER, -- number of current paired connections `connected` INTEGER, -- number of current paired connections
`waiting` INTEGER, -- number of not-yet-paired connections `waiting` INTEGER, -- number of not-yet-paired connections
`incomplete_bytes` INTEGER -- bytes sent through not-yet-complete connections `incomplete_bytes` INTEGER -- bytes sent through not-yet-complete connections
@@ -26,5 +26,5 @@ CREATE TABLE `usage`
-- "lonely": good handshake, but the other side never showed up -- "lonely": good handshake, but the other side never showed up
-- "happy": both sides gave correct handshake -- "happy": both sides gave correct handshake
); );
CREATE INDEX `transit_usage_idx` ON `transit_usage` (`started`); CREATE INDEX `usage_started_index` ON `usage` (`started`);
CREATE INDEX `transit_usage_result_idx` ON `transit_usage` (`result`); CREATE INDEX `usage_result_index` ON `usage` (`result`);

View File

@@ -1,6 +1,8 @@
import sys
from . import transit_server from . import transit_server
from twisted.internet import reactor from twisted.internet import reactor
from twisted.python import usage from twisted.python import usage
from twisted.application.service import MultiService
from twisted.application.internet import (TimerService, from twisted.application.internet import (TimerService,
StreamServerEndpointService) StreamServerEndpointService)
from twisted.internet import endpoints from twisted.internet import endpoints
@@ -92,10 +94,11 @@ class Options(usage.Options):
def makeService(config, reactor=reactor): def makeService(config, reactor=reactor):
ep = endpoints.serverFromString(reactor, config["port"]) # to listen ep = endpoints.serverFromString(reactor, config["port"]) # to listen
log_file = sys.stdout if config["log-stdout"] else None
f = transit_server.Transit(blur_usage=config["blur-usage"], f = transit_server.Transit(blur_usage=config["blur-usage"],
log_stdout=config["log-stdout"], log_file=log_file,
usage_db=config["usage-db"]) usage_db=config["usage-db"])
parent = service.MultiService() parent = MultiService()
StreamServerEndpointService(ep, f).setServiceParent(parent) StreamServerEndpointService(ep, f).setServiceParent(parent)
TimerService(5.0, f.timerUpdateStats).setServiceParent(parent) TimerService(5*60.0, f.timerUpdateStats).setServiceParent(parent)
return parent return parent

View File

@@ -4,16 +4,16 @@ from twisted.internet.defer import inlineCallbacks
from ..transit_server import Transit from ..transit_server import Transit
class ServerBase: class ServerBase:
@inlineCallbacks
def setUp(self): def setUp(self):
self._lp = None self._lp = None
self._setup_relay() yield self._setup_relay()
@inlineCallbacks @inlineCallbacks
def _setup_relay(self, blur_usage=None, usage_logfile=None, stats_file=None): def _setup_relay(self, blur_usage=None, log_file=None, usage_db=None):
ep = endpoints.TCP4ServerEndpoint(reactor, 0, interface="127.0.0.1") ep = endpoints.TCP4ServerEndpoint(reactor, 0, interface="127.0.0.1")
self._transit_server = Transit(blur_usage=blur_usage, self._transit_server = Transit(blur_usage=blur_usage,
usage_logfile=usage_logfile, log_file=log_file, usage_db=usage_db)
stats_file=stats_file)
self._lp = yield ep.listen(self._transit_server) self._lp = yield ep.listen(self._transit_server)
addr = self._lp.getHost() addr = self._lp.getHost()
# ws://127.0.0.1:%d/wormhole-relay/ws # ws://127.0.0.1:%d/wormhole-relay/ws

View File

@@ -1,59 +1,89 @@
from __future__ import print_function, unicode_literals from __future__ import print_function, unicode_literals
import os, json import os, io, json, sqlite3
import mock import mock
from twisted.trial import unittest from twisted.trial import unittest
from ..transit_server import Transit from ..transit_server import Transit
from .. import database
class UsageLog(unittest.TestCase): class DB(unittest.TestCase):
def test_log(self): def open_db(self, dbfile):
db = sqlite3.connect(dbfile)
database._initialize_db_connection(db)
return db
def test_db(self):
d = self.mktemp() d = self.mktemp()
os.mkdir(d) os.mkdir(d)
usage_logfile = os.path.join(d, "usage.log") usage_db = os.path.join(d, "usage.sqlite")
def read(): with mock.patch("time.time", return_value=456):
with open(usage_logfile, "r") as f: t = Transit(blur_usage=None, log_file=None, usage_db=usage_db)
return [json.loads(line) for line in f.readlines()] db = self.open_db(usage_db)
t = Transit(None, usage_logfile, None)
with mock.patch("time.time", return_value=457):
t.recordUsage(started=123, result="happy", total_bytes=100, t.recordUsage(started=123, result="happy", total_bytes=100,
total_time=10, waiting_time=2) total_time=10, waiting_time=2)
self.assertEqual(read(), [dict(started=123, mood="happy", self.assertEqual(db.execute("SELECT * FROM `usage`").fetchall(),
total_time=10, waiting_time=2, [dict(result="happy", started=123,
total_bytes=100)]) total_bytes=100, total_time=10, waiting_time=2),
])
self.assertEqual(db.execute("SELECT * FROM `current`").fetchone(),
dict(rebooted=456, updated=457,
incomplete_bytes=0,
waiting=0, connected=0))
with mock.patch("time.time", return_value=458):
t.recordUsage(started=150, result="errory", total_bytes=200, t.recordUsage(started=150, result="errory", total_bytes=200,
total_time=11, waiting_time=3) total_time=11, waiting_time=3)
self.assertEqual(read(), [dict(started=123, mood="happy", self.assertEqual(db.execute("SELECT * FROM `usage`").fetchall(),
total_time=10, waiting_time=2, [dict(result="happy", started=123,
total_bytes=100), total_bytes=100, total_time=10, waiting_time=2),
dict(started=150, mood="errory", dict(result="errory", started=150,
total_time=11, waiting_time=3, total_bytes=200, total_time=11, waiting_time=3),
total_bytes=200),
]) ])
self.assertEqual(db.execute("SELECT * FROM `current`").fetchone(),
dict(rebooted=456, updated=458,
incomplete_bytes=0,
waiting=0, connected=0))
if False: with mock.patch("time.time", return_value=459):
# the current design opens the logfile exactly once, at process t.timerUpdateStats()
# start, in the faint hopes of surviving an exhaustion of available self.assertEqual(db.execute("SELECT * FROM `current`").fetchone(),
# file descriptors. This should be rethought. dict(rebooted=456, updated=459,
os.unlink(usage_logfile) incomplete_bytes=0,
waiting=0, connected=0))
t.recordUsage(started=200, result="lonely", total_bytes=300, def test_no_db(self):
total_time=12, waiting_time=4) t = Transit(blur_usage=None, log_file=None, usage_db=None)
self.assertEqual(read(), [dict(started=200, mood="lonely",
total_time=12, waiting_time=4,
total_bytes=300)])
class StandardLogfile(unittest.TestCase):
def test_log(self):
# the default, when _blur_usage is None, will log to twistd.log
t = Transit(blur_usage=None, usage_logfile=None, stats_file=None)
with mock.patch("twisted.python.log.msg") as m:
t.recordUsage(started=123, result="happy", total_bytes=100, t.recordUsage(started=123, result="happy", total_bytes=100,
total_time=10, waiting_time=2) total_time=10, waiting_time=2)
self.assertEqual(m.mock_calls, [mock.call(format="Transit.recordUsage {bytes}B", bytes=100)]) t.timerUpdateStats()
class LogToStdout(unittest.TestCase):
def test_log(self):
# emit lines of JSON to log_file, if set
log_file = io.StringIO()
t = Transit(blur_usage=None, log_file=log_file, usage_db=None)
t.recordUsage(started=123, result="happy", total_bytes=100,
total_time=10, waiting_time=2)
self.assertEqual(json.loads(log_file.getvalue()),
{"started": 123, "total_time": 10,
"waiting_time": 2, "total_bytes": 100,
"mood": "happy"})
def test_log_blurred(self):
# if blurring is enabled, timestamps should be rounded to the
# requested amount, and sizes should be rounded up too
log_file = io.StringIO()
t = Transit(blur_usage=60, log_file=log_file, usage_db=None)
t.recordUsage(started=123, result="happy", total_bytes=11999,
total_time=10, waiting_time=2)
self.assertEqual(json.loads(log_file.getvalue()),
{"started": 120, "total_time": 10,
"waiting_time": 2, "total_bytes": 20000,
"mood": "happy"})
def test_do_not_log(self): def test_do_not_log(self):
# the default, when _blur_usage is None, will log to twistd.log t = Transit(blur_usage=60, log_file=None, usage_db=None)
t = Transit(blur_usage=60, usage_logfile=None, stats_file=None) t.recordUsage(started=123, result="happy", total_bytes=11999,
with mock.patch("twisted.python.log.msg") as m:
t.recordUsage(started=123, result="happy", total_bytes=100,
total_time=10, waiting_time=2) total_time=10, waiting_time=2)
self.assertEqual(m.mock_calls, [])

View File

@@ -298,3 +298,34 @@ class Transit(ServerBase, unittest.TestCase):
self.assertEqual(a1.data, exp) self.assertEqual(a1.data, exp)
a1.transport.loseConnection() a1.transport.loseConnection()
@defer.inlineCallbacks
def test_impatience_new2(self):
ep = clientFromString(reactor, self.transit)
a1 = yield connectProtocol(ep, Accumulator())
# For full coverage, we need dataReceived to see a particular framing
# of these two pieces of data, and ITCPTransport doesn't have flush()
# (which probably wouldn't work anyways). For now, force a 100ms
# stall between the two writes. I tried setTcpNoDelay(True) but it
# didn't seem to help without the stall. The long-term fix is to
# rewrite dataReceived() to remove the multiple "impatient"
# codepaths, deleting the particular clause that this test exercises,
# then remove this test.
token1 = b"\x00"*32
side1 = b"\x01"*8
# sending too many bytes is impatience.
a1.transport.write(b"please relay " + hexlify(token1) +
b" for side " + hexlify(side1) + b"\n")
d = defer.Deferred()
reactor.callLater(0.1, d.callback, None)
yield d
a1.transport.write(b"NOWNOWNOW")
exp = b"impatient\n"
yield a1.waitForBytes(len(exp))
self.assertEqual(a1.data, exp)
a1.transport.loseConnection()

View File

@@ -1,5 +1,5 @@
from __future__ import print_function, unicode_literals from __future__ import print_function, unicode_literals
import os, re, time, json import re, time, json
from twisted.python import log from twisted.python import log
from twisted.internet import protocol from twisted.internet import protocol
from .database import get_db from .database import get_db
@@ -221,13 +221,15 @@ class Transit(protocol.ServerFactory):
MAXTIME = 60*SECONDS MAXTIME = 60*SECONDS
protocol = TransitConnection protocol = TransitConnection
def __init__(self, blur_usage, log_stdout, usage_db): def __init__(self, blur_usage, log_file, usage_db):
self._blur_usage = blur_usage self._blur_usage = blur_usage
self._log_requests = blur_usage is None
self._debug_log = False self._debug_log = False
self._log_stdout = log_stdout self._log_file = log_file
self._db = None self._db = None
if usage_db: if usage_db:
self._db = get_db(usage_db) self._db = get_db(usage_db)
self._rebooted = time.time()
# we don't track TransitConnections until they submit a token # we don't track TransitConnections until they submit a token
self._pending_requests = {} # token -> set((side, TransitConnection)) self._pending_requests = {} # token -> set((side, TransitConnection))
self._active_connections = set() # TransitConnection self._active_connections = set() # TransitConnection
@@ -285,16 +287,15 @@ class Transit(protocol.ServerFactory):
if self._blur_usage: if self._blur_usage:
started = self._blur_usage * (started // self._blur_usage) started = self._blur_usage * (started // self._blur_usage)
total_bytes = blur_size(total_bytes) total_bytes = blur_size(total_bytes)
if self._log_stdout: if self._log_file is not None:
data = {"started": started, data = {"started": started,
"total_time": total_time, "total_time": total_time,
"waiting_time": waiting_time, "waiting_time": waiting_time,
"total_bytes": total_bytes, "total_bytes": total_bytes,
"mood": result, "mood": result,
} }
sys.stdout.write(json.dumps(data)) self._log_file.write(json.dumps(data)+"\n")
sys.stdout.write("\n") self._log_file.flush()
sys.stdout.flush()
if self._db: if self._db:
self._db.execute("INSERT INTO `usage`" self._db.execute("INSERT INTO `usage`"
" (`started`, `total_time`, `waiting_time`," " (`started`, `total_time`, `waiting_time`,"
@@ -306,26 +307,27 @@ class Transit(protocol.ServerFactory):
self._db.commit() self._db.commit()
def timerUpdateStats(self): def timerUpdateStats(self):
if self._db:
self._update_stats() self._update_stats()
self._db.commit() self._db.commit()
def _update_stats(self): def _update_stats(self):
# current status: should be zero when idle # current status: should be zero when idle
reboot = self._reboot rebooted = self._rebooted
last_update = time.time() updated = time.time()
connected = len(self._active_connections) / 2 connected = len(self._active_connections) / 2
# TODO: when a connection is half-closed, len(active) will be odd. a # TODO: when a connection is half-closed, len(active) will be odd. a
# moment later (hopefully) the other side will disconnect, but # moment later (hopefully) the other side will disconnect, but
# _update_stats isn't updated until later. # _update_stats isn't updated until later.
waiting = len(self._pending_tokens) waiting = len(self._pending_requests)
# "waiting" doesn't count multiple parallel connections from the same # "waiting" doesn't count multiple parallel connections from the same
# side # side
incomplete_bytes = sum(tc._total_sent incomplete_bytes = sum(tc._total_sent
for tc in self._active_connections) for tc in self._active_connections)
self._db.execute("DELETE FROM `current`") self._db.execute("DELETE FROM `current`")
self._db.execute("INSERT INTO `current`" self._db.execute("INSERT INTO `current`"
" (`reboot`, `last_update`, `connected`, `waiting`," " (`rebooted`, `updated`, `connected`, `waiting`,"
" `incomplete_bytes`)" " `incomplete_bytes`)"
" VALUES (?, ?, ?, ?, ?)", " VALUES (?, ?, ?, ?, ?)",
(reboot, last_update, connected, waiting, (rebooted, updated, connected, waiting,
incomplete_bytes)) incomplete_bytes))