From 69d66dd4c1dddf43ae15934a91daad692c290cdc Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Tue, 7 Nov 2017 21:45:41 -0600 Subject: [PATCH] database: add create-only function, for migration tool --- src/wormhole_transit_relay/database.py | 17 +++++++++++++ .../test/test_database.py | 24 +++++++++++++++++++ 2 files changed, 41 insertions(+) diff --git a/src/wormhole_transit_relay/database.py b/src/wormhole_transit_relay/database.py index de8bc08..1f1f023 100644 --- a/src/wormhole_transit_relay/database.py +++ b/src/wormhole_transit_relay/database.py @@ -116,6 +116,23 @@ def get_db(dbfile, target_version=TARGET_VERSION): return db +class DBAlreadyExists(Exception): + pass + +def create_db(dbfile): + """Create the given db file. Refuse to touch a pre-existing file. + + This is meant for use by migration tools, to create the output target""" + + if dbfile == ":memory:": + db = _open_db_connection(dbfile) + _initialize_db_schema(db, TARGET_VERSION) + elif os.path.exists(dbfile): + raise DBAlreadyExists() + else: + db = _atomic_create_and_initialize_db(dbfile, TARGET_VERSION) + return db + def dump_db(db): # to let _iterdump work, we need to restore the original row factory orig = db.row_factory diff --git a/src/wormhole_transit_relay/test/test_database.py b/src/wormhole_transit_relay/test/test_database.py index 46dd570..4319eec 100644 --- a/src/wormhole_transit_relay/test/test_database.py +++ b/src/wormhole_transit_relay/test/test_database.py @@ -59,3 +59,27 @@ class Get(unittest.TestCase): with open("new.sql","w") as f: f.write(latest_text) # check with "diff -u _trial_temp/up.sql _trial_temp/new.sql" self.assertEqual(dbA_text, latest_text) + +class Create(unittest.TestCase): + def test_memory(self): + db = database.create_db(":memory:") + latest_text = dump_db(db) + self.assertIn("CREATE TABLE", latest_text) + + def test_preexisting(self): + basedir = self.mktemp() + os.mkdir(basedir) + fn = os.path.join(basedir, "preexisting.db") + with open(fn, "w"): + pass + with self.assertRaises(database.DBAlreadyExists): + database.create_db(fn) + + def test_create(self): + basedir = self.mktemp() + os.mkdir(basedir) + fn = os.path.join(basedir, "created.db") + db = database.create_db(fn) + latest_text = dump_db(db) + self.assertIn("CREATE TABLE", latest_text) +