Skip to content

Commit 542cceb

Browse files
authored
Add db_journal_mode trait (#61)
1 parent 99c1830 commit 542cceb

File tree

2 files changed

+62
-7
lines changed

2 files changed

+62
-7
lines changed

jupyter_server_fileid/manager.py

Lines changed: 36 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from typing import Any, Callable, Dict, Optional
99

1010
from jupyter_core.paths import jupyter_data_dir
11-
from traitlets import TraitError, Unicode, validate
11+
from traitlets import TraitError, Unicode, default, validate
1212
from traitlets.config.configurable import LoggingConfigurable
1313

1414

@@ -75,6 +75,23 @@ def _validate_db_path(self, proposal):
7575
)
7676
return proposal["value"]
7777

78+
JOURNAL_MODES = ["DELETE", "TRUNCATE", "PERSIST", "MEMORY", "WAL", "OFF"]
79+
db_journal_mode = Unicode(
80+
help=(
81+
f"The journal mode setting for the SQLite database. Must be one of {JOURNAL_MODES}."
82+
),
83+
config=True,
84+
)
85+
86+
@validate("db_journal_mode")
87+
def _validate_db_journal_mode(self, proposal):
88+
candidate_value = proposal["value"]
89+
if candidate_value is None or candidate_value.upper() not in self.JOURNAL_MODES:
90+
raise TraitError(
91+
f"db_journal_mode ('{candidate_value}') must be one of {self.JOURNAL_MODES}."
92+
)
93+
return candidate_value.upper()
94+
7895
@staticmethod
7996
def _uuid() -> str:
8097
return str(uuid.uuid4())
@@ -229,6 +246,10 @@ def _validate_root_dir(self, proposal):
229246
normalized_content_root = self._normalize_separators(proposal["value"])
230247
return normalized_content_root
231248

249+
@default("db_journal_mode")
250+
def _default_db_journal_mode(self):
251+
return "DELETE"
252+
232253
def __init__(self, *args, **kwargs):
233254
# pass args and kwargs to parent Configurable
234255
super().__init__(*args, **kwargs)
@@ -239,9 +260,11 @@ def __init__(self, *args, **kwargs):
239260
self.log.info(f"ArbitraryFileIdManager : Configured database path: {self.db_path}")
240261
self.con = sqlite3.connect(self.db_path)
241262
self.log.info("ArbitraryFileIdManager : Successfully connected to database file.")
242-
self.log.info("ArbitraryFileIdManager : Creating File ID tables and indices.")
243-
# do not allow reads to block writes. required when using multiple processes
244-
self.con.execute("PRAGMA journal_mode = WAL")
263+
self.log.info(
264+
f"ArbitraryFileIdManager : Creating File ID tables and indices with "
265+
f"journal_mode = {self.db_journal_mode}"
266+
)
267+
self.con.execute(f"PRAGMA journal_mode = {self.db_journal_mode}")
245268
self.con.execute(
246269
"CREATE TABLE IF NOT EXISTS Files("
247270
"id TEXT PRIMARY KEY NOT NULL, "
@@ -394,6 +417,10 @@ def _validate_root_dir(self, proposal):
394417
)
395418
return proposal["value"]
396419

420+
@default("db_journal_mode")
421+
def _default_db_journal_mode(self):
422+
return "WAL"
423+
397424
def __init__(self, *args, **kwargs):
398425
# pass args and kwargs to parent Configurable
399426
super().__init__(*args, **kwargs)
@@ -405,9 +432,11 @@ def __init__(self, *args, **kwargs):
405432
self.log.info(f"LocalFileIdManager : Configured database path: {self.db_path}")
406433
self.con = sqlite3.connect(self.db_path)
407434
self.log.info("LocalFileIdManager : Successfully connected to database file.")
408-
self.log.info("LocalFileIdManager : Creating File ID tables and indices.")
409-
# do not allow reads to block writes. required when using multiple processes
410-
self.con.execute("PRAGMA journal_mode = WAL")
435+
self.log.info(
436+
f"LocalFileIdManager : Creating File ID tables and indices with "
437+
f"journal_mode = {self.db_journal_mode}"
438+
)
439+
self.con.execute(f"PRAGMA journal_mode = {self.db_journal_mode}")
411440
self.con.execute(
412441
"CREATE TABLE IF NOT EXISTS Files("
413442
"id TEXT PRIMARY KEY NOT NULL, "

tests/test_manager.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -565,3 +565,29 @@ def test_save(any_fid_manager, test_path, fs_helpers):
565565
any_fid_manager.save(test_path)
566566

567567
assert any_fid_manager.get_id(test_path) == id
568+
569+
570+
@pytest.mark.parametrize(
571+
"db_journal_mode", ["invalid", None, "DELETE", "TRUNCATE", "PERSIST", "MEMORY", "WAL", "OFF"]
572+
)
573+
def test_db_journal_mode(any_fid_manager_class, fid_db_path, jp_root_dir, db_journal_mode):
574+
if db_journal_mode == "invalid": # test invalid
575+
with pytest.raises(TraitError, match=" must be one of "):
576+
any_fid_manager_class(
577+
db_path=fid_db_path, root_dir=str(jp_root_dir), db_journal_mode=db_journal_mode
578+
)
579+
else:
580+
if not db_journal_mode: # test correct defaults
581+
expected_journal_mode = (
582+
"WAL" if any_fid_manager_class.__name__ == "LocalFileIdManager" else "DELETE"
583+
)
584+
fid_manager = any_fid_manager_class(db_path=fid_db_path, root_dir=str(jp_root_dir))
585+
else: # test any valid value
586+
expected_journal_mode = db_journal_mode
587+
fid_manager = any_fid_manager_class(
588+
db_path=fid_db_path, root_dir=str(jp_root_dir), db_journal_mode=db_journal_mode
589+
)
590+
591+
cursor = fid_manager.con.execute("PRAGMA journal_mode")
592+
actual_journal_mode = cursor.fetchone()
593+
assert actual_journal_mode[0].upper() == expected_journal_mode

0 commit comments

Comments
 (0)