make it all basically work, improve test coverage

This commit is contained in:
Brian Warner 2017-11-07 19:29:26 -06:00
parent a898a65b09
commit 83e1c8acfe
7 changed files with 142 additions and 76 deletions

View File

@ -9,12 +9,12 @@ class DBError(Exception):
pass
def get_schema(version):
schema_bytes = resource_string("wormhole.server",
schema_bytes = resource_string("wormhole_transit_relay",
"db-schemas/v%d.sql" % version)
return schema_bytes.decode("utf-8")
def get_upgrader(new_version):
schema_bytes = resource_string("wormhole.server",
schema_bytes = resource_string("wormhole_transit_relay",
"db-schemas/upgrade-to-v%d.sql" % new_version)
return schema_bytes.decode("utf-8")

View File

@ -7,8 +7,8 @@ CREATE TABLE `version` -- contains one row
CREATE TABLE `current` -- contains one row
(
`reboot` INTEGER, -- seconds since epoch of most recent reboot
`last_update` INTEGER, -- when `current` was last updated
`rebooted` INTEGER, -- seconds since epoch of most recent reboot
`updated` INTEGER, -- when `current` was last updated
`connected` INTEGER, -- number of current paired connections
`waiting` INTEGER, -- number of not-yet-paired connections
`incomplete_bytes` INTEGER -- bytes sent through not-yet-complete connections
@ -26,5 +26,5 @@ CREATE TABLE `usage`
-- "lonely": good handshake, but the other side never showed up
-- "happy": both sides gave correct handshake
);
CREATE INDEX `transit_usage_idx` ON `transit_usage` (`started`);
CREATE INDEX `transit_usage_result_idx` ON `transit_usage` (`result`);
CREATE INDEX `usage_started_index` ON `usage` (`started`);
CREATE INDEX `usage_result_index` ON `usage` (`result`);

View File

@ -1,6 +1,8 @@
import sys
from . import transit_server
from twisted.internet import reactor
from twisted.python import usage
from twisted.application.service import MultiService
from twisted.application.internet import (TimerService,
StreamServerEndpointService)
from twisted.internet import endpoints
@ -92,10 +94,11 @@ class Options(usage.Options):
def makeService(config, reactor=reactor):
ep = endpoints.serverFromString(reactor, config["port"]) # to listen
log_file = sys.stdout if config["log-stdout"] else None
f = transit_server.Transit(blur_usage=config["blur-usage"],
log_stdout=config["log-stdout"],
log_file=log_file,
usage_db=config["usage-db"])
parent = service.MultiService()
parent = MultiService()
StreamServerEndpointService(ep, f).setServiceParent(parent)
TimerService(5.0, f.timerUpdateStats).setServiceParent(parent)
TimerService(5*60.0, f.timerUpdateStats).setServiceParent(parent)
return parent

View File

@ -4,16 +4,16 @@ from twisted.internet.defer import inlineCallbacks
from ..transit_server import Transit
class ServerBase:
@inlineCallbacks
def setUp(self):
self._lp = None
self._setup_relay()
yield self._setup_relay()
@inlineCallbacks
def _setup_relay(self, blur_usage=None, usage_logfile=None, stats_file=None):
def _setup_relay(self, blur_usage=None, log_file=None, usage_db=None):
ep = endpoints.TCP4ServerEndpoint(reactor, 0, interface="127.0.0.1")
self._transit_server = Transit(blur_usage=blur_usage,
usage_logfile=usage_logfile,
stats_file=stats_file)
log_file=log_file, usage_db=usage_db)
self._lp = yield ep.listen(self._transit_server)
addr = self._lp.getHost()
# ws://127.0.0.1:%d/wormhole-relay/ws

View File

@ -1,59 +1,89 @@
from __future__ import print_function, unicode_literals
import os, json
import os, io, json, sqlite3
import mock
from twisted.trial import unittest
from ..transit_server import Transit
from .. import database
class UsageLog(unittest.TestCase):
def test_log(self):
class DB(unittest.TestCase):
def open_db(self, dbfile):
db = sqlite3.connect(dbfile)
database._initialize_db_connection(db)
return db
def test_db(self):
d = self.mktemp()
os.mkdir(d)
usage_logfile = os.path.join(d, "usage.log")
def read():
with open(usage_logfile, "r") as f:
return [json.loads(line) for line in f.readlines()]
t = Transit(None, usage_logfile, None)
t.recordUsage(started=123, result="happy", total_bytes=100,
total_time=10, waiting_time=2)
self.assertEqual(read(), [dict(started=123, mood="happy",
total_time=10, waiting_time=2,
total_bytes=100)])
usage_db = os.path.join(d, "usage.sqlite")
with mock.patch("time.time", return_value=456):
t = Transit(blur_usage=None, log_file=None, usage_db=usage_db)
db = self.open_db(usage_db)
t.recordUsage(started=150, result="errory", total_bytes=200,
total_time=11, waiting_time=3)
self.assertEqual(read(), [dict(started=123, mood="happy",
total_time=10, waiting_time=2,
total_bytes=100),
dict(started=150, mood="errory",
total_time=11, waiting_time=3,
total_bytes=200),
])
if False:
# the current design opens the logfile exactly once, at process
# start, in the faint hopes of surviving an exhaustion of available
# file descriptors. This should be rethought.
os.unlink(usage_logfile)
t.recordUsage(started=200, result="lonely", total_bytes=300,
total_time=12, waiting_time=4)
self.assertEqual(read(), [dict(started=200, mood="lonely",
total_time=12, waiting_time=4,
total_bytes=300)])
class StandardLogfile(unittest.TestCase):
def test_log(self):
# the default, when _blur_usage is None, will log to twistd.log
t = Transit(blur_usage=None, usage_logfile=None, stats_file=None)
with mock.patch("twisted.python.log.msg") as m:
t.recordUsage(started=123, result="happy", total_bytes=100,
total_time=10, waiting_time=2)
self.assertEqual(m.mock_calls, [mock.call(format="Transit.recordUsage {bytes}B", bytes=100)])
def test_do_not_log(self):
# the default, when _blur_usage is None, will log to twistd.log
t = Transit(blur_usage=60, usage_logfile=None, stats_file=None)
with mock.patch("twisted.python.log.msg") as m:
with mock.patch("time.time", return_value=457):
t.recordUsage(started=123, result="happy", total_bytes=100,
total_time=10, waiting_time=2)
self.assertEqual(m.mock_calls, [])
self.assertEqual(db.execute("SELECT * FROM `usage`").fetchall(),
[dict(result="happy", started=123,
total_bytes=100, total_time=10, waiting_time=2),
])
self.assertEqual(db.execute("SELECT * FROM `current`").fetchone(),
dict(rebooted=456, updated=457,
incomplete_bytes=0,
waiting=0, connected=0))
with mock.patch("time.time", return_value=458):
t.recordUsage(started=150, result="errory", total_bytes=200,
total_time=11, waiting_time=3)
self.assertEqual(db.execute("SELECT * FROM `usage`").fetchall(),
[dict(result="happy", started=123,
total_bytes=100, total_time=10, waiting_time=2),
dict(result="errory", started=150,
total_bytes=200, total_time=11, waiting_time=3),
])
self.assertEqual(db.execute("SELECT * FROM `current`").fetchone(),
dict(rebooted=456, updated=458,
incomplete_bytes=0,
waiting=0, connected=0))
with mock.patch("time.time", return_value=459):
t.timerUpdateStats()
self.assertEqual(db.execute("SELECT * FROM `current`").fetchone(),
dict(rebooted=456, updated=459,
incomplete_bytes=0,
waiting=0, connected=0))
def test_no_db(self):
t = Transit(blur_usage=None, log_file=None, usage_db=None)
t.recordUsage(started=123, result="happy", total_bytes=100,
total_time=10, waiting_time=2)
t.timerUpdateStats()
class LogToStdout(unittest.TestCase):
def test_log(self):
# emit lines of JSON to log_file, if set
log_file = io.StringIO()
t = Transit(blur_usage=None, log_file=log_file, usage_db=None)
t.recordUsage(started=123, result="happy", total_bytes=100,
total_time=10, waiting_time=2)
self.assertEqual(json.loads(log_file.getvalue()),
{"started": 123, "total_time": 10,
"waiting_time": 2, "total_bytes": 100,
"mood": "happy"})
def test_log_blurred(self):
# if blurring is enabled, timestamps should be rounded to the
# requested amount, and sizes should be rounded up too
log_file = io.StringIO()
t = Transit(blur_usage=60, log_file=log_file, usage_db=None)
t.recordUsage(started=123, result="happy", total_bytes=11999,
total_time=10, waiting_time=2)
self.assertEqual(json.loads(log_file.getvalue()),
{"started": 120, "total_time": 10,
"waiting_time": 2, "total_bytes": 20000,
"mood": "happy"})
def test_do_not_log(self):
t = Transit(blur_usage=60, log_file=None, usage_db=None)
t.recordUsage(started=123, result="happy", total_bytes=11999,
total_time=10, waiting_time=2)

View File

@ -298,3 +298,34 @@ class Transit(ServerBase, unittest.TestCase):
self.assertEqual(a1.data, exp)
a1.transport.loseConnection()
@defer.inlineCallbacks
def test_impatience_new2(self):
ep = clientFromString(reactor, self.transit)
a1 = yield connectProtocol(ep, Accumulator())
# For full coverage, we need dataReceived to see a particular framing
# of these two pieces of data, and ITCPTransport doesn't have flush()
# (which probably wouldn't work anyways). For now, force a 100ms
# stall between the two writes. I tried setTcpNoDelay(True) but it
# didn't seem to help without the stall. The long-term fix is to
# rewrite dataReceived() to remove the multiple "impatient"
# codepaths, deleting the particular clause that this test exercises,
# then remove this test.
token1 = b"\x00"*32
side1 = b"\x01"*8
# sending too many bytes is impatience.
a1.transport.write(b"please relay " + hexlify(token1) +
b" for side " + hexlify(side1) + b"\n")
d = defer.Deferred()
reactor.callLater(0.1, d.callback, None)
yield d
a1.transport.write(b"NOWNOWNOW")
exp = b"impatient\n"
yield a1.waitForBytes(len(exp))
self.assertEqual(a1.data, exp)
a1.transport.loseConnection()

View File

@ -1,5 +1,5 @@
from __future__ import print_function, unicode_literals
import os, re, time, json
import re, time, json
from twisted.python import log
from twisted.internet import protocol
from .database import get_db
@ -221,13 +221,15 @@ class Transit(protocol.ServerFactory):
MAXTIME = 60*SECONDS
protocol = TransitConnection
def __init__(self, blur_usage, log_stdout, usage_db):
def __init__(self, blur_usage, log_file, usage_db):
self._blur_usage = blur_usage
self._log_requests = blur_usage is None
self._debug_log = False
self._log_stdout = log_stdout
self._log_file = log_file
self._db = None
if usage_db:
self._db = get_db(usage_db)
self._rebooted = time.time()
# we don't track TransitConnections until they submit a token
self._pending_requests = {} # token -> set((side, TransitConnection))
self._active_connections = set() # TransitConnection
@ -285,16 +287,15 @@ class Transit(protocol.ServerFactory):
if self._blur_usage:
started = self._blur_usage * (started // self._blur_usage)
total_bytes = blur_size(total_bytes)
if self._log_stdout:
if self._log_file is not None:
data = {"started": started,
"total_time": total_time,
"waiting_time": waiting_time,
"total_bytes": total_bytes,
"mood": result,
}
sys.stdout.write(json.dumps(data))
sys.stdout.write("\n")
sys.stdout.flush()
self._log_file.write(json.dumps(data)+"\n")
self._log_file.flush()
if self._db:
self._db.execute("INSERT INTO `usage`"
" (`started`, `total_time`, `waiting_time`,"
@ -306,26 +307,27 @@ class Transit(protocol.ServerFactory):
self._db.commit()
def timerUpdateStats(self):
self._update_stats()
self._db.commit()
if self._db:
self._update_stats()
self._db.commit()
def _update_stats(self):
# current status: should be zero when idle
reboot = self._reboot
last_update = time.time()
rebooted = self._rebooted
updated = time.time()
connected = len(self._active_connections) / 2
# TODO: when a connection is half-closed, len(active) will be odd. a
# moment later (hopefully) the other side will disconnect, but
# _update_stats isn't updated until later.
waiting = len(self._pending_tokens)
waiting = len(self._pending_requests)
# "waiting" doesn't count multiple parallel connections from the same
# side
incomplete_bytes = sum(tc._total_sent
for tc in self._active_connections)
self._db.execute("DELETE FROM `current`")
self._db.execute("INSERT INTO `current`"
" (`reboot`, `last_update`, `connected`, `waiting`,"
" (`rebooted`, `updated`, `connected`, `waiting`,"
" `incomplete_bytes`)"
" VALUES (?, ?, ?, ?, ?)",
(reboot, last_update, connected, waiting,
(rebooted, updated, connected, waiting,
incomplete_bytes))