DB: small cleanups, improve test coverage
This commit is contained in:
parent
9008d4339a
commit
c270ad6e0b
|
@ -13,10 +13,10 @@ def get_schema(version):
|
||||||
"db-schemas/v%d.sql" % version)
|
"db-schemas/v%d.sql" % version)
|
||||||
return schema_bytes.decode("utf-8")
|
return schema_bytes.decode("utf-8")
|
||||||
|
|
||||||
def get_upgrader(new_version):
|
## def get_upgrader(new_version):
|
||||||
schema_bytes = resource_string("wormhole_transit_relay",
|
## schema_bytes = resource_string("wormhole_transit_relay",
|
||||||
"db-schemas/upgrade-to-v%d.sql" % new_version)
|
## "db-schemas/upgrade-to-v%d.sql" % new_version)
|
||||||
return schema_bytes.decode("utf-8")
|
## return schema_bytes.decode("utf-8")
|
||||||
|
|
||||||
TARGET_VERSION = 1
|
TARGET_VERSION = 1
|
||||||
|
|
||||||
|
@ -51,9 +51,11 @@ def _open_db_connection(dbfile):
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
db = sqlite3.connect(dbfile)
|
db = sqlite3.connect(dbfile)
|
||||||
except (EnvironmentError, sqlite3.OperationalError) as e:
|
_initialize_db_connection(db)
|
||||||
|
except (EnvironmentError, sqlite3.OperationalError, sqlite3.DatabaseError) as e:
|
||||||
|
# this indicates that the file is not a compatible database format.
|
||||||
|
# Perhaps it was created with an old version, or it might be junk.
|
||||||
raise DBError("Unable to create/open db file %s: %s" % (dbfile, e))
|
raise DBError("Unable to create/open db file %s: %s" % (dbfile, e))
|
||||||
_initialize_db_connection(db)
|
|
||||||
return db
|
return db
|
||||||
|
|
||||||
def _get_temporary_dbfile(dbfile):
|
def _get_temporary_dbfile(dbfile):
|
||||||
|
@ -91,25 +93,20 @@ def get_db(dbfile, target_version=TARGET_VERSION):
|
||||||
else:
|
else:
|
||||||
db = _atomic_create_and_initialize_db(dbfile, target_version)
|
db = _atomic_create_and_initialize_db(dbfile, target_version)
|
||||||
|
|
||||||
try:
|
version = db.execute("SELECT version FROM version").fetchone()["version"]
|
||||||
version = db.execute("SELECT version FROM version").fetchone()["version"]
|
|
||||||
except sqlite3.DatabaseError as e:
|
|
||||||
# this indicates that the file is not a compatible database format.
|
|
||||||
# Perhaps it was created with an old version, or it might be junk.
|
|
||||||
raise DBError("db file is unusable: %s" % e)
|
|
||||||
|
|
||||||
while version < target_version:
|
## while version < target_version:
|
||||||
log.msg(" need to upgrade from %s to %s" % (version, target_version))
|
## log.msg(" need to upgrade from %s to %s" % (version, target_version))
|
||||||
try:
|
## try:
|
||||||
upgrader = get_upgrader(version+1)
|
## upgrader = get_upgrader(version+1)
|
||||||
except ValueError: # ResourceError??
|
## except ValueError: # ResourceError??
|
||||||
log.msg(" unable to upgrade %s to %s" % (version, version+1))
|
## log.msg(" unable to upgrade %s to %s" % (version, version+1))
|
||||||
raise DBError("Unable to upgrade %s to version %s, left at %s"
|
## raise DBError("Unable to upgrade %s to version %s, left at %s"
|
||||||
% (dbfile, version+1, version))
|
## % (dbfile, version+1, version))
|
||||||
log.msg(" executing upgrader v%s->v%s" % (version, version+1))
|
## log.msg(" executing upgrader v%s->v%s" % (version, version+1))
|
||||||
db.executescript(upgrader)
|
## db.executescript(upgrader)
|
||||||
db.commit()
|
## db.commit()
|
||||||
version = version+1
|
## version = version+1
|
||||||
|
|
||||||
if version != target_version:
|
if version != target_version:
|
||||||
raise DBError("Unable to handle db version %s" % version)
|
raise DBError("Unable to handle db version %s" % version)
|
||||||
|
|
|
@ -3,7 +3,7 @@ import os
|
||||||
from twisted.python import filepath
|
from twisted.python import filepath
|
||||||
from twisted.trial import unittest
|
from twisted.trial import unittest
|
||||||
from .. import database
|
from .. import database
|
||||||
from ..database import get_db, TARGET_VERSION, dump_db
|
from ..database import get_db, TARGET_VERSION, dump_db, DBError
|
||||||
|
|
||||||
class Get(unittest.TestCase):
|
class Get(unittest.TestCase):
|
||||||
def test_create_default(self):
|
def test_create_default(self):
|
||||||
|
@ -13,6 +13,41 @@ class Get(unittest.TestCase):
|
||||||
self.assertEqual(len(rows), 1)
|
self.assertEqual(len(rows), 1)
|
||||||
self.assertEqual(rows[0]["version"], TARGET_VERSION)
|
self.assertEqual(rows[0]["version"], TARGET_VERSION)
|
||||||
|
|
||||||
|
def test_open_existing_file(self):
|
||||||
|
basedir = self.mktemp()
|
||||||
|
os.mkdir(basedir)
|
||||||
|
fn = os.path.join(basedir, "normal.db")
|
||||||
|
db = get_db(fn)
|
||||||
|
rows = db.execute("SELECT * FROM version").fetchall()
|
||||||
|
self.assertEqual(len(rows), 1)
|
||||||
|
self.assertEqual(rows[0]["version"], TARGET_VERSION)
|
||||||
|
db2 = get_db(fn)
|
||||||
|
rows = db2.execute("SELECT * FROM version").fetchall()
|
||||||
|
self.assertEqual(len(rows), 1)
|
||||||
|
self.assertEqual(rows[0]["version"], TARGET_VERSION)
|
||||||
|
|
||||||
|
def test_open_bad_version(self):
|
||||||
|
basedir = self.mktemp()
|
||||||
|
os.mkdir(basedir)
|
||||||
|
fn = os.path.join(basedir, "old.db")
|
||||||
|
db = get_db(fn)
|
||||||
|
db.execute("UPDATE version SET version=999")
|
||||||
|
db.commit()
|
||||||
|
|
||||||
|
with self.assertRaises(DBError) as e:
|
||||||
|
get_db(fn)
|
||||||
|
self.assertIn("Unable to handle db version 999", str(e.exception))
|
||||||
|
|
||||||
|
def test_open_corrupt(self):
|
||||||
|
basedir = self.mktemp()
|
||||||
|
os.mkdir(basedir)
|
||||||
|
fn = os.path.join(basedir, "corrupt.db")
|
||||||
|
with open(fn, "wb") as f:
|
||||||
|
f.write(b"I am not a database")
|
||||||
|
with self.assertRaises(DBError) as e:
|
||||||
|
get_db(fn)
|
||||||
|
self.assertIn("not a database", str(e.exception))
|
||||||
|
|
||||||
def test_failed_create_allows_subsequent_create(self):
|
def test_failed_create_allows_subsequent_create(self):
|
||||||
patch = self.patch(database, "get_schema", lambda version: b"this is a broken schema")
|
patch = self.patch(database, "get_schema", lambda version: b"this is a broken schema")
|
||||||
dbfile = filepath.FilePath(self.mktemp())
|
dbfile = filepath.FilePath(self.mktemp())
|
||||||
|
|
Loading…
Reference in New Issue
Block a user