Avoid corrupting state if creating a new db fails

This commit is contained in:
Jean-Paul Calderone 2017-06-27 10:41:22 -04:00 committed by Brian Warner
parent 2ecdd02d24
commit efb77443bf
2 changed files with 66 additions and 15 deletions

View File

@ -1,6 +1,7 @@
from __future__ import unicode_literals from __future__ import unicode_literals
import os import os
import sqlite3 import sqlite3
import tempfile
from pkg_resources import resource_string from pkg_resources import resource_string
from twisted.python import log from twisted.python import log
@ -25,23 +26,9 @@ def dict_factory(cursor, row):
d[col[0]] = row[idx] d[col[0]] = row[idx]
return d return d
def get_db(dbfile, target_version=TARGET_VERSION): def _initialize_db_schema(db, target_version):
"""Open or create the given db file. The parent directory must exist. """Creates the application schema in the given database.
Returns the db connection object, or raises DBError.
""" """
must_create = (dbfile == ":memory:") or not os.path.exists(dbfile)
try:
db = sqlite3.connect(dbfile)
except (EnvironmentError, sqlite3.OperationalError) as e:
raise DBError("Unable to create/open db file %s: %s" % (dbfile, e))
db.row_factory = dict_factory
db.execute("PRAGMA foreign_keys = ON")
problems = db.execute("PRAGMA foreign_key_check").fetchall()
if problems:
raise DBError("failed foreign key check: %s" % (problems,))
if must_create:
log.msg("populating new database with schema v%s" % target_version) log.msg("populating new database with schema v%s" % target_version)
schema = get_schema(target_version) schema = get_schema(target_version)
db.executescript(schema) db.executescript(schema)
@ -49,6 +36,61 @@ def get_db(dbfile, target_version=TARGET_VERSION):
(target_version,)) (target_version,))
db.commit() db.commit()
def _initialize_db_connection(db):
"""Sets up the db connection object with a row factory and with necessary
foreign key settings.
"""
db.row_factory = dict_factory
db.execute("PRAGMA foreign_keys = ON")
problems = db.execute("PRAGMA foreign_key_check").fetchall()
if problems:
raise DBError("failed foreign key check: %s" % (problems,))
def _open_db_connection(dbfile):
"""Open a new connection to the SQLite3 database at the given path.
"""
try:
db = sqlite3.connect(dbfile)
except (EnvironmentError, sqlite3.OperationalError) as e:
raise DBError("Unable to create/open db file %s: %s" % (dbfile, e))
_initialize_db_connection(db)
return db
def _get_temporary_dbfile(dbfile):
"""Get a temporary filename near the given path.
"""
fd, name = tempfile.mkstemp(
prefix=os.path.basename(dbfile) + ".",
dir=os.path.dirname(dbfile)
)
os.close(fd)
return name
def _atomic_create_and_initialize_db(dbfile, target_version):
"""Create and return a new database, initialized with the application
schema.
If anything goes wrong, nothing is left at the ``dbfile`` path.
"""
temp_dbfile = _get_temporary_dbfile(dbfile)
db = _open_db_connection(temp_dbfile)
_initialize_db_schema(db, target_version)
db.close()
os.rename(temp_dbfile, dbfile)
return _open_db_connection(dbfile)
def get_db(dbfile, target_version=TARGET_VERSION):
"""Open or create the given db file. The parent directory must exist.
Returns the db connection object, or raises DBError.
"""
if dbfile == ":memory:":
db = _open_db_connection(dbfile)
_initialize_db_schema(db, target_version)
elif os.path.exists(dbfile):
db = _open_db_connection(dbfile)
else:
db = _atomic_create_and_initialize_db(dbfile, target_version)
try: try:
version = db.execute("SELECT version FROM version").fetchone()["version"] version = db.execute("SELECT version FROM version").fetchone()["version"]
except sqlite3.DatabaseError as e: except sqlite3.DatabaseError as e:

View File

@ -1,6 +1,8 @@
from __future__ import print_function, unicode_literals from __future__ import print_function, unicode_literals
import os import os
from twisted.python import filepath
from twisted.trial import unittest from twisted.trial import unittest
from ..server import database
from ..server.database import get_db, TARGET_VERSION, dump_db from ..server.database import get_db, TARGET_VERSION, dump_db
class DB(unittest.TestCase): class DB(unittest.TestCase):
@ -11,6 +13,13 @@ class DB(unittest.TestCase):
self.assertEqual(len(rows), 1) self.assertEqual(len(rows), 1)
self.assertEqual(rows[0]["version"], TARGET_VERSION) self.assertEqual(rows[0]["version"], TARGET_VERSION)
def test_failed_create_allows_subsequent_create(self):
patch = self.patch(database, "get_schema", lambda version: b"this is a broken schema")
dbfile = filepath.FilePath(self.mktemp())
self.assertRaises(Exception, lambda: get_db(dbfile.path))
patch.restore()
get_db(dbfile.path)
def test_upgrade(self): def test_upgrade(self):
basedir = self.mktemp() basedir = self.mktemp()
os.mkdir(basedir) os.mkdir(basedir)