diff --git a/src/wormhole/server/database.py b/src/wormhole/server/database.py index e49b51b..eb188e1 100644 --- a/src/wormhole/server/database.py +++ b/src/wormhole/server/database.py @@ -1,6 +1,7 @@ from __future__ import unicode_literals import os import sqlite3 +import tempfile from pkg_resources import resource_string from twisted.python import log @@ -25,29 +26,70 @@ def dict_factory(cursor, row): d[col[0]] = row[idx] return d -def get_db(dbfile, target_version=TARGET_VERSION): - """Open or create the given db file. The parent directory must exist. - Returns the db connection object, or raises DBError. +def _initialize_db_schema(db, target_version): + """Creates the application schema in the given database. """ + log.msg("populating new database with schema v%s" % target_version) + schema = get_schema(target_version) + db.executescript(schema) + db.execute("INSERT INTO version (version) VALUES (?)", + (target_version,)) + db.commit() - must_create = (dbfile == ":memory:") or not os.path.exists(dbfile) - try: - db = sqlite3.connect(dbfile) - except (EnvironmentError, sqlite3.OperationalError) as e: - raise DBError("Unable to create/open db file %s: %s" % (dbfile, e)) +def _initialize_db_connection(db): + """Sets up the db connection object with a row factory and with necessary + foreign key settings. + """ db.row_factory = dict_factory db.execute("PRAGMA foreign_keys = ON") problems = db.execute("PRAGMA foreign_key_check").fetchall() if problems: raise DBError("failed foreign key check: %s" % (problems,)) - if must_create: - log.msg("populating new database with schema v%s" % target_version) - schema = get_schema(target_version) - db.executescript(schema) - db.execute("INSERT INTO version (version) VALUES (?)", - (target_version,)) - db.commit() +def _open_db_connection(dbfile): + """Open a new connection to the SQLite3 database at the given path. + """ + try: + db = sqlite3.connect(dbfile) + except (EnvironmentError, sqlite3.OperationalError) as e: + raise DBError("Unable to create/open db file %s: %s" % (dbfile, e)) + _initialize_db_connection(db) + return db + +def _get_temporary_dbfile(dbfile): + """Get a temporary filename near the given path. + """ + fd, name = tempfile.mkstemp( + prefix=os.path.basename(dbfile) + ".", + dir=os.path.dirname(dbfile) + ) + os.close(fd) + return name + +def _atomic_create_and_initialize_db(dbfile, target_version): + """Create and return a new database, initialized with the application + schema. + + If anything goes wrong, nothing is left at the ``dbfile`` path. + """ + temp_dbfile = _get_temporary_dbfile(dbfile) + db = _open_db_connection(temp_dbfile) + _initialize_db_schema(db, target_version) + db.close() + os.rename(temp_dbfile, dbfile) + return _open_db_connection(dbfile) + +def get_db(dbfile, target_version=TARGET_VERSION): + """Open or create the given db file. The parent directory must exist. + Returns the db connection object, or raises DBError. + """ + if dbfile == ":memory:": + db = _open_db_connection(dbfile) + _initialize_db_schema(db, target_version) + elif os.path.exists(dbfile): + db = _open_db_connection(dbfile) + else: + db = _atomic_create_and_initialize_db(dbfile, target_version) try: version = db.execute("SELECT version FROM version").fetchone()["version"] diff --git a/src/wormhole/test/test_database.py b/src/wormhole/test/test_database.py index 7a7b491..4ebc2cb 100644 --- a/src/wormhole/test/test_database.py +++ b/src/wormhole/test/test_database.py @@ -1,6 +1,8 @@ from __future__ import print_function, unicode_literals import os +from twisted.python import filepath from twisted.trial import unittest +from ..server import database from ..server.database import get_db, TARGET_VERSION, dump_db class DB(unittest.TestCase): @@ -11,6 +13,13 @@ class DB(unittest.TestCase): self.assertEqual(len(rows), 1) self.assertEqual(rows[0]["version"], TARGET_VERSION) + def test_failed_create_allows_subsequent_create(self): + patch = self.patch(database, "get_schema", lambda version: b"this is a broken schema") + dbfile = filepath.FilePath(self.mktemp()) + self.assertRaises(Exception, lambda: get_db(dbfile.path)) + patch.restore() + get_db(dbfile.path) + def test_upgrade(self): basedir = self.mktemp() os.mkdir(basedir)