diff --git a/src/wormhole_transit_relay/database.py b/src/wormhole_transit_relay/database.py index 1f1f023..7bb09d8 100644 --- a/src/wormhole_transit_relay/database.py +++ b/src/wormhole_transit_relay/database.py @@ -116,6 +116,15 @@ def get_db(dbfile, target_version=TARGET_VERSION): return db +class DBDoesntExist(Exception): + pass + +def open_existing_db(dbfile): + assert dbfile != ":memory:" + if not os.path.exists(dbfile): + raise DBDoesntExist() + return _open_db_connection(dbfile) + class DBAlreadyExists(Exception): pass diff --git a/src/wormhole_transit_relay/test/test_database.py b/src/wormhole_transit_relay/test/test_database.py index 4319eec..0b5c918 100644 --- a/src/wormhole_transit_relay/test/test_database.py +++ b/src/wormhole_transit_relay/test/test_database.py @@ -83,3 +83,22 @@ class Create(unittest.TestCase): latest_text = dump_db(db) self.assertIn("CREATE TABLE", latest_text) +class Open(unittest.TestCase): + def test_open(self): + basedir = self.mktemp() + os.mkdir(basedir) + fn = os.path.join(basedir, "created.db") + db1 = database.create_db(fn) + latest_text = dump_db(db1) + self.assertIn("CREATE TABLE", latest_text) + db2 = database.open_existing_db(fn) + self.assertIn("CREATE TABLE", dump_db(db2)) + + def test_doesnt_exist(self): + basedir = self.mktemp() + os.mkdir(basedir) + fn = os.path.join(basedir, "created.db") + with self.assertRaises(database.DBDoesntExist): + database.open_existing_db(fn) + +