diff --git a/src/wormhole_transit_relay/database.py b/src/wormhole_transit_relay/database.py index 7bb09d8..7fe6408 100644 --- a/src/wormhole_transit_relay/database.py +++ b/src/wormhole_transit_relay/database.py @@ -13,10 +13,10 @@ def get_schema(version): "db-schemas/v%d.sql" % version) return schema_bytes.decode("utf-8") -def get_upgrader(new_version): - schema_bytes = resource_string("wormhole_transit_relay", - "db-schemas/upgrade-to-v%d.sql" % new_version) - return schema_bytes.decode("utf-8") +## def get_upgrader(new_version): +## schema_bytes = resource_string("wormhole_transit_relay", +## "db-schemas/upgrade-to-v%d.sql" % new_version) +## return schema_bytes.decode("utf-8") TARGET_VERSION = 1 @@ -51,9 +51,11 @@ def _open_db_connection(dbfile): """ try: db = sqlite3.connect(dbfile) - except (EnvironmentError, sqlite3.OperationalError) as e: + _initialize_db_connection(db) + except (EnvironmentError, sqlite3.OperationalError, sqlite3.DatabaseError) as e: + # this indicates that the file is not a compatible database format. + # Perhaps it was created with an old version, or it might be junk. raise DBError("Unable to create/open db file %s: %s" % (dbfile, e)) - _initialize_db_connection(db) return db def _get_temporary_dbfile(dbfile): @@ -91,25 +93,20 @@ def get_db(dbfile, target_version=TARGET_VERSION): else: db = _atomic_create_and_initialize_db(dbfile, target_version) - try: - version = db.execute("SELECT version FROM version").fetchone()["version"] - except sqlite3.DatabaseError as e: - # this indicates that the file is not a compatible database format. - # Perhaps it was created with an old version, or it might be junk. - raise DBError("db file is unusable: %s" % e) + version = db.execute("SELECT version FROM version").fetchone()["version"] - while version < target_version: - log.msg(" need to upgrade from %s to %s" % (version, target_version)) - try: - upgrader = get_upgrader(version+1) - except ValueError: # ResourceError?? - log.msg(" unable to upgrade %s to %s" % (version, version+1)) - raise DBError("Unable to upgrade %s to version %s, left at %s" - % (dbfile, version+1, version)) - log.msg(" executing upgrader v%s->v%s" % (version, version+1)) - db.executescript(upgrader) - db.commit() - version = version+1 + ## while version < target_version: + ## log.msg(" need to upgrade from %s to %s" % (version, target_version)) + ## try: + ## upgrader = get_upgrader(version+1) + ## except ValueError: # ResourceError?? + ## log.msg(" unable to upgrade %s to %s" % (version, version+1)) + ## raise DBError("Unable to upgrade %s to version %s, left at %s" + ## % (dbfile, version+1, version)) + ## log.msg(" executing upgrader v%s->v%s" % (version, version+1)) + ## db.executescript(upgrader) + ## db.commit() + ## version = version+1 if version != target_version: raise DBError("Unable to handle db version %s" % version) diff --git a/src/wormhole_transit_relay/test/test_database.py b/src/wormhole_transit_relay/test/test_database.py index 0b5c918..fb3d4ff 100644 --- a/src/wormhole_transit_relay/test/test_database.py +++ b/src/wormhole_transit_relay/test/test_database.py @@ -3,7 +3,7 @@ import os from twisted.python import filepath from twisted.trial import unittest from .. import database -from ..database import get_db, TARGET_VERSION, dump_db +from ..database import get_db, TARGET_VERSION, dump_db, DBError class Get(unittest.TestCase): def test_create_default(self): @@ -13,6 +13,41 @@ class Get(unittest.TestCase): self.assertEqual(len(rows), 1) self.assertEqual(rows[0]["version"], TARGET_VERSION) + def test_open_existing_file(self): + basedir = self.mktemp() + os.mkdir(basedir) + fn = os.path.join(basedir, "normal.db") + db = get_db(fn) + rows = db.execute("SELECT * FROM version").fetchall() + self.assertEqual(len(rows), 1) + self.assertEqual(rows[0]["version"], TARGET_VERSION) + db2 = get_db(fn) + rows = db2.execute("SELECT * FROM version").fetchall() + self.assertEqual(len(rows), 1) + self.assertEqual(rows[0]["version"], TARGET_VERSION) + + def test_open_bad_version(self): + basedir = self.mktemp() + os.mkdir(basedir) + fn = os.path.join(basedir, "old.db") + db = get_db(fn) + db.execute("UPDATE version SET version=999") + db.commit() + + with self.assertRaises(DBError) as e: + get_db(fn) + self.assertIn("Unable to handle db version 999", str(e.exception)) + + def test_open_corrupt(self): + basedir = self.mktemp() + os.mkdir(basedir) + fn = os.path.join(basedir, "corrupt.db") + with open(fn, "wb") as f: + f.write(b"I am not a database") + with self.assertRaises(DBError) as e: + get_db(fn) + self.assertIn("not a database", str(e.exception)) + def test_failed_create_allows_subsequent_create(self): patch = self.patch(database, "get_schema", lambda version: b"this is a broken schema") dbfile = filepath.FilePath(self.mktemp())