Avoid corrupting state if creating a new db fails
This commit is contained in:
parent
2ecdd02d24
commit
efb77443bf
|
@ -1,6 +1,7 @@
|
||||||
from __future__ import unicode_literals
|
from __future__ import unicode_literals
|
||||||
import os
|
import os
|
||||||
import sqlite3
|
import sqlite3
|
||||||
|
import tempfile
|
||||||
from pkg_resources import resource_string
|
from pkg_resources import resource_string
|
||||||
from twisted.python import log
|
from twisted.python import log
|
||||||
|
|
||||||
|
@ -25,29 +26,70 @@ def dict_factory(cursor, row):
|
||||||
d[col[0]] = row[idx]
|
d[col[0]] = row[idx]
|
||||||
return d
|
return d
|
||||||
|
|
||||||
def get_db(dbfile, target_version=TARGET_VERSION):
|
def _initialize_db_schema(db, target_version):
|
||||||
"""Open or create the given db file. The parent directory must exist.
|
"""Creates the application schema in the given database.
|
||||||
Returns the db connection object, or raises DBError.
|
|
||||||
"""
|
"""
|
||||||
|
log.msg("populating new database with schema v%s" % target_version)
|
||||||
|
schema = get_schema(target_version)
|
||||||
|
db.executescript(schema)
|
||||||
|
db.execute("INSERT INTO version (version) VALUES (?)",
|
||||||
|
(target_version,))
|
||||||
|
db.commit()
|
||||||
|
|
||||||
must_create = (dbfile == ":memory:") or not os.path.exists(dbfile)
|
def _initialize_db_connection(db):
|
||||||
try:
|
"""Sets up the db connection object with a row factory and with necessary
|
||||||
db = sqlite3.connect(dbfile)
|
foreign key settings.
|
||||||
except (EnvironmentError, sqlite3.OperationalError) as e:
|
"""
|
||||||
raise DBError("Unable to create/open db file %s: %s" % (dbfile, e))
|
|
||||||
db.row_factory = dict_factory
|
db.row_factory = dict_factory
|
||||||
db.execute("PRAGMA foreign_keys = ON")
|
db.execute("PRAGMA foreign_keys = ON")
|
||||||
problems = db.execute("PRAGMA foreign_key_check").fetchall()
|
problems = db.execute("PRAGMA foreign_key_check").fetchall()
|
||||||
if problems:
|
if problems:
|
||||||
raise DBError("failed foreign key check: %s" % (problems,))
|
raise DBError("failed foreign key check: %s" % (problems,))
|
||||||
|
|
||||||
if must_create:
|
def _open_db_connection(dbfile):
|
||||||
log.msg("populating new database with schema v%s" % target_version)
|
"""Open a new connection to the SQLite3 database at the given path.
|
||||||
schema = get_schema(target_version)
|
"""
|
||||||
db.executescript(schema)
|
try:
|
||||||
db.execute("INSERT INTO version (version) VALUES (?)",
|
db = sqlite3.connect(dbfile)
|
||||||
(target_version,))
|
except (EnvironmentError, sqlite3.OperationalError) as e:
|
||||||
db.commit()
|
raise DBError("Unable to create/open db file %s: %s" % (dbfile, e))
|
||||||
|
_initialize_db_connection(db)
|
||||||
|
return db
|
||||||
|
|
||||||
|
def _get_temporary_dbfile(dbfile):
|
||||||
|
"""Get a temporary filename near the given path.
|
||||||
|
"""
|
||||||
|
fd, name = tempfile.mkstemp(
|
||||||
|
prefix=os.path.basename(dbfile) + ".",
|
||||||
|
dir=os.path.dirname(dbfile)
|
||||||
|
)
|
||||||
|
os.close(fd)
|
||||||
|
return name
|
||||||
|
|
||||||
|
def _atomic_create_and_initialize_db(dbfile, target_version):
|
||||||
|
"""Create and return a new database, initialized with the application
|
||||||
|
schema.
|
||||||
|
|
||||||
|
If anything goes wrong, nothing is left at the ``dbfile`` path.
|
||||||
|
"""
|
||||||
|
temp_dbfile = _get_temporary_dbfile(dbfile)
|
||||||
|
db = _open_db_connection(temp_dbfile)
|
||||||
|
_initialize_db_schema(db, target_version)
|
||||||
|
db.close()
|
||||||
|
os.rename(temp_dbfile, dbfile)
|
||||||
|
return _open_db_connection(dbfile)
|
||||||
|
|
||||||
|
def get_db(dbfile, target_version=TARGET_VERSION):
|
||||||
|
"""Open or create the given db file. The parent directory must exist.
|
||||||
|
Returns the db connection object, or raises DBError.
|
||||||
|
"""
|
||||||
|
if dbfile == ":memory:":
|
||||||
|
db = _open_db_connection(dbfile)
|
||||||
|
_initialize_db_schema(db, target_version)
|
||||||
|
elif os.path.exists(dbfile):
|
||||||
|
db = _open_db_connection(dbfile)
|
||||||
|
else:
|
||||||
|
db = _atomic_create_and_initialize_db(dbfile, target_version)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
version = db.execute("SELECT version FROM version").fetchone()["version"]
|
version = db.execute("SELECT version FROM version").fetchone()["version"]
|
||||||
|
|
|
@ -1,6 +1,8 @@
|
||||||
from __future__ import print_function, unicode_literals
|
from __future__ import print_function, unicode_literals
|
||||||
import os
|
import os
|
||||||
|
from twisted.python import filepath
|
||||||
from twisted.trial import unittest
|
from twisted.trial import unittest
|
||||||
|
from ..server import database
|
||||||
from ..server.database import get_db, TARGET_VERSION, dump_db
|
from ..server.database import get_db, TARGET_VERSION, dump_db
|
||||||
|
|
||||||
class DB(unittest.TestCase):
|
class DB(unittest.TestCase):
|
||||||
|
@ -11,6 +13,13 @@ class DB(unittest.TestCase):
|
||||||
self.assertEqual(len(rows), 1)
|
self.assertEqual(len(rows), 1)
|
||||||
self.assertEqual(rows[0]["version"], TARGET_VERSION)
|
self.assertEqual(rows[0]["version"], TARGET_VERSION)
|
||||||
|
|
||||||
|
def test_failed_create_allows_subsequent_create(self):
|
||||||
|
patch = self.patch(database, "get_schema", lambda version: b"this is a broken schema")
|
||||||
|
dbfile = filepath.FilePath(self.mktemp())
|
||||||
|
self.assertRaises(Exception, lambda: get_db(dbfile.path))
|
||||||
|
patch.restore()
|
||||||
|
get_db(dbfile.path)
|
||||||
|
|
||||||
def test_upgrade(self):
|
def test_upgrade(self):
|
||||||
basedir = self.mktemp()
|
basedir = self.mktemp()
|
||||||
os.mkdir(basedir)
|
os.mkdir(basedir)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user