DB: small cleanups, improve test coverage
This commit is contained in:
		
							parent
							
								
									9008d4339a
								
							
						
					
					
						commit
						c270ad6e0b
					
				| 
						 | 
				
			
			@ -13,10 +13,10 @@ def get_schema(version):
 | 
			
		|||
                                   "db-schemas/v%d.sql" % version)
 | 
			
		||||
    return schema_bytes.decode("utf-8")
 | 
			
		||||
 | 
			
		||||
def get_upgrader(new_version):
 | 
			
		||||
    schema_bytes = resource_string("wormhole_transit_relay",
 | 
			
		||||
                                   "db-schemas/upgrade-to-v%d.sql" % new_version)
 | 
			
		||||
    return schema_bytes.decode("utf-8")
 | 
			
		||||
## def get_upgrader(new_version):
 | 
			
		||||
##     schema_bytes = resource_string("wormhole_transit_relay",
 | 
			
		||||
##                                    "db-schemas/upgrade-to-v%d.sql" % new_version)
 | 
			
		||||
##     return schema_bytes.decode("utf-8")
 | 
			
		||||
 | 
			
		||||
TARGET_VERSION = 1
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -51,9 +51,11 @@ def _open_db_connection(dbfile):
 | 
			
		|||
    """
 | 
			
		||||
    try:
 | 
			
		||||
        db = sqlite3.connect(dbfile)
 | 
			
		||||
    except (EnvironmentError, sqlite3.OperationalError) as e:
 | 
			
		||||
        _initialize_db_connection(db)
 | 
			
		||||
    except (EnvironmentError, sqlite3.OperationalError, sqlite3.DatabaseError) as e:
 | 
			
		||||
        # this indicates that the file is not a compatible database format.
 | 
			
		||||
        # Perhaps it was created with an old version, or it might be junk.
 | 
			
		||||
        raise DBError("Unable to create/open db file %s: %s" % (dbfile, e))
 | 
			
		||||
    _initialize_db_connection(db)
 | 
			
		||||
    return db
 | 
			
		||||
 | 
			
		||||
def _get_temporary_dbfile(dbfile):
 | 
			
		||||
| 
						 | 
				
			
			@ -91,25 +93,20 @@ def get_db(dbfile, target_version=TARGET_VERSION):
 | 
			
		|||
    else:
 | 
			
		||||
        db = _atomic_create_and_initialize_db(dbfile, target_version)
 | 
			
		||||
 | 
			
		||||
    try:
 | 
			
		||||
        version = db.execute("SELECT version FROM version").fetchone()["version"]
 | 
			
		||||
    except sqlite3.DatabaseError as e:
 | 
			
		||||
        # this indicates that the file is not a compatible database format.
 | 
			
		||||
        # Perhaps it was created with an old version, or it might be junk.
 | 
			
		||||
        raise DBError("db file is unusable: %s" % e)
 | 
			
		||||
    version = db.execute("SELECT version FROM version").fetchone()["version"]
 | 
			
		||||
 | 
			
		||||
    while version < target_version:
 | 
			
		||||
        log.msg(" need to upgrade from %s to %s" % (version, target_version))
 | 
			
		||||
        try:
 | 
			
		||||
            upgrader = get_upgrader(version+1)
 | 
			
		||||
        except ValueError: # ResourceError??
 | 
			
		||||
            log.msg(" unable to upgrade %s to %s" % (version, version+1))
 | 
			
		||||
            raise DBError("Unable to upgrade %s to version %s, left at %s"
 | 
			
		||||
                          % (dbfile, version+1, version))
 | 
			
		||||
        log.msg(" executing upgrader v%s->v%s" % (version, version+1))
 | 
			
		||||
        db.executescript(upgrader)
 | 
			
		||||
        db.commit()
 | 
			
		||||
        version = version+1
 | 
			
		||||
    ## while version < target_version:
 | 
			
		||||
    ##     log.msg(" need to upgrade from %s to %s" % (version, target_version))
 | 
			
		||||
    ##     try:
 | 
			
		||||
    ##         upgrader = get_upgrader(version+1)
 | 
			
		||||
    ##     except ValueError: # ResourceError??
 | 
			
		||||
    ##         log.msg(" unable to upgrade %s to %s" % (version, version+1))
 | 
			
		||||
    ##         raise DBError("Unable to upgrade %s to version %s, left at %s"
 | 
			
		||||
    ##                       % (dbfile, version+1, version))
 | 
			
		||||
    ##     log.msg(" executing upgrader v%s->v%s" % (version, version+1))
 | 
			
		||||
    ##     db.executescript(upgrader)
 | 
			
		||||
    ##     db.commit()
 | 
			
		||||
    ##     version = version+1
 | 
			
		||||
 | 
			
		||||
    if version != target_version:
 | 
			
		||||
        raise DBError("Unable to handle db version %s" % version)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -3,7 +3,7 @@ import os
 | 
			
		|||
from twisted.python import filepath
 | 
			
		||||
from twisted.trial import unittest
 | 
			
		||||
from .. import database
 | 
			
		||||
from ..database import get_db, TARGET_VERSION, dump_db
 | 
			
		||||
from ..database import get_db, TARGET_VERSION, dump_db, DBError
 | 
			
		||||
 | 
			
		||||
class Get(unittest.TestCase):
 | 
			
		||||
    def test_create_default(self):
 | 
			
		||||
| 
						 | 
				
			
			@ -13,6 +13,41 @@ class Get(unittest.TestCase):
 | 
			
		|||
        self.assertEqual(len(rows), 1)
 | 
			
		||||
        self.assertEqual(rows[0]["version"], TARGET_VERSION)
 | 
			
		||||
 | 
			
		||||
    def test_open_existing_file(self):
 | 
			
		||||
        basedir = self.mktemp()
 | 
			
		||||
        os.mkdir(basedir)
 | 
			
		||||
        fn = os.path.join(basedir, "normal.db")
 | 
			
		||||
        db = get_db(fn)
 | 
			
		||||
        rows = db.execute("SELECT * FROM version").fetchall()
 | 
			
		||||
        self.assertEqual(len(rows), 1)
 | 
			
		||||
        self.assertEqual(rows[0]["version"], TARGET_VERSION)
 | 
			
		||||
        db2 = get_db(fn)
 | 
			
		||||
        rows = db2.execute("SELECT * FROM version").fetchall()
 | 
			
		||||
        self.assertEqual(len(rows), 1)
 | 
			
		||||
        self.assertEqual(rows[0]["version"], TARGET_VERSION)
 | 
			
		||||
 | 
			
		||||
    def test_open_bad_version(self):
 | 
			
		||||
        basedir = self.mktemp()
 | 
			
		||||
        os.mkdir(basedir)
 | 
			
		||||
        fn = os.path.join(basedir, "old.db")
 | 
			
		||||
        db = get_db(fn)
 | 
			
		||||
        db.execute("UPDATE version SET version=999")
 | 
			
		||||
        db.commit()
 | 
			
		||||
 | 
			
		||||
        with self.assertRaises(DBError) as e:
 | 
			
		||||
            get_db(fn)
 | 
			
		||||
        self.assertIn("Unable to handle db version 999", str(e.exception))
 | 
			
		||||
 | 
			
		||||
    def test_open_corrupt(self):
 | 
			
		||||
        basedir = self.mktemp()
 | 
			
		||||
        os.mkdir(basedir)
 | 
			
		||||
        fn = os.path.join(basedir, "corrupt.db")
 | 
			
		||||
        with open(fn, "wb") as f:
 | 
			
		||||
            f.write(b"I am not a database")
 | 
			
		||||
        with self.assertRaises(DBError) as e:
 | 
			
		||||
            get_db(fn)
 | 
			
		||||
        self.assertIn("not a database", str(e.exception))
 | 
			
		||||
 | 
			
		||||
    def test_failed_create_allows_subsequent_create(self):
 | 
			
		||||
        patch = self.patch(database, "get_schema", lambda version: b"this is a broken schema")
 | 
			
		||||
        dbfile = filepath.FilePath(self.mktemp())
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue
	
	Block a user