Skip to content

Commit d5fa1ef

Browse files
committed
tests cache
1 parent 443cb4d commit d5fa1ef

File tree

2 files changed

+62
-18
lines changed

2 files changed

+62
-18
lines changed

codeflash/code_utils/code_utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -238,7 +238,9 @@ def get_all_function_names(code: str) -> tuple[bool, list[str]]:
238238
return True, function_names
239239

240240

241-
def get_run_tmp_file(file_path: Path) -> Path:
241+
def get_run_tmp_file(file_path: Path | str) -> Path:
242+
if isinstance(file_path, str):
243+
file_path = Path(file_path)
242244
if not hasattr(get_run_tmp_file, "tmpdir"):
243245
get_run_tmp_file.tmpdir = TemporaryDirectory(prefix="codeflash_")
244246
return Path(get_run_tmp_file.tmpdir.name) / file_path

codeflash/discovery/discover_unit_tests.py

Lines changed: 59 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,6 @@ class TestsCache:
5454
def __init__(self) -> None:
5555
self.connection = sqlite3.connect(codeflash_cache_db)
5656
self.cur = self.connection.cursor()
57-
5857
self.cur.execute(
5958
"""
6059
CREATE TABLE IF NOT EXISTS discovered_tests(
@@ -76,7 +75,9 @@ def __init__(self) -> None:
7675
ON discovered_tests (file_path, file_hash)
7776
"""
7877
)
78+
7979
self._memory_cache = {}
80+
self._hash_cache = {}
8081

8182
def insert_test(
8283
self,
@@ -107,25 +108,30 @@ def insert_test(
107108
)
108109
self.connection.commit()
109110

110-
def get_tests_for_file(self, file_path: str, file_hash: str) -> list[FunctionCalledInTest]:
111+
def get_tests_for_file(self, file_path: str, file_hash: str) -> list[FunctionCalledInTest] | None:
111112
cache_key = (file_path, file_hash)
112113
if cache_key in self._memory_cache:
113114
return self._memory_cache[cache_key]
115+
114116
self.cur.execute("SELECT * FROM discovered_tests WHERE file_path = ? AND file_hash = ?", (file_path, file_hash))
117+
rows = self.cur.fetchall()
118+
if not rows:
119+
return None
120+
115121
result = [
116122
FunctionCalledInTest(
117123
tests_in_file=TestsInFile(
118124
test_file=Path(row[0]), test_class=row[4], test_function=row[5], test_type=TestType(int(row[6]))
119125
),
120126
position=CodePosition(line_no=row[7], col_no=row[8]),
121127
)
122-
for row in self.cur.fetchall()
128+
for row in rows
123129
]
124130
self._memory_cache[cache_key] = result
125131
return result
126132

127133
@staticmethod
128-
def compute_file_hash(path: str) -> str:
134+
def compute_file_hash(path: str | Path) -> str:
129135
h = hashlib.sha256(usedforsecurity=False)
130136
with Path(path).open("rb") as f:
131137
while True:
@@ -521,7 +527,7 @@ def process_test_files(
521527
file_to_test_map: dict[Path, list[TestsInFile]],
522528
cfg: TestConfig,
523529
functions_to_optimize: list[FunctionToOptimize] | None = None,
524-
) -> tuple[dict[str, set[FunctionCalledInTest]], int]:
530+
) -> tuple[dict[str, set[FunctionCalledInTest]], int, int]:
525531
import jedi
526532

527533
project_root_path = cfg.project_root_path
@@ -536,29 +542,51 @@ def process_test_files(
536542
num_discovered_replay_tests = 0
537543
jedi_project = jedi.Project(path=project_root_path)
538544

545+
tests_cache = TestsCache()
546+
539547
with test_files_progress_bar(total=len(file_to_test_map), description="Processing test files") as (
540548
progress,
541549
task_id,
542550
):
543551
for test_file, functions in file_to_test_map.items():
552+
file_hash = TestsCache.compute_file_hash(test_file)
553+
554+
cached_tests = tests_cache.get_tests_for_file(str(test_file), file_hash)
555+
556+
if cached_tests:
557+
# Rebuild function_to_test_map from cached data
558+
tests_cache.cur.execute(
559+
"SELECT * FROM discovered_tests WHERE file_path = ? AND file_hash = ?", (str(test_file), file_hash)
560+
)
561+
for row in tests_cache.cur.fetchall():
562+
qualified_name_with_modules_from_root = row[2]
563+
test_type = TestType(int(row[6]))
564+
565+
function_called_in_test = FunctionCalledInTest(
566+
tests_in_file=TestsInFile(
567+
test_file=test_file, test_class=row[4], test_function=row[5], test_type=test_type
568+
),
569+
position=CodePosition(line_no=row[7], col_no=row[8]),
570+
)
571+
572+
function_to_test_map[qualified_name_with_modules_from_root].add(function_called_in_test)
573+
if test_type == TestType.REPLAY_TEST:
574+
num_discovered_replay_tests += 1
575+
num_discovered_tests += 1
576+
577+
progress.advance(task_id)
578+
continue
544579
try:
545580
script = jedi.Script(path=test_file, project=jedi_project)
546581
test_functions = set()
547582

548-
# Single call to get all names with references and definitions
549-
all_names = script.get_names(all_scopes=True, references=True, definitions=True)
583+
all_names = script.get_names(all_scopes=True, references=True)
584+
all_defs = script.get_names(all_scopes=True, definitions=True)
585+
all_names_top = script.get_names(all_scopes=True)
550586

551-
# Filter once and create lookup dictionaries
552-
top_level_functions = {}
553-
top_level_classes = {}
554-
all_defs = []
587+
top_level_functions = {name.name: name for name in all_names_top if name.type == "function"}
588+
top_level_classes = {name.name: name for name in all_names_top if name.type == "class"}
555589

556-
for name in all_names:
557-
if name.type == "function":
558-
top_level_functions[name.name] = name
559-
all_defs.append(name)
560-
elif name.type == "class":
561-
top_level_classes[name.name] = name
562590
except Exception as e:
563591
logger.debug(f"Failed to get jedi script for {test_file}: {e}")
564592
progress.advance(task_id)
@@ -680,6 +708,18 @@ def process_test_files(
680708
position=CodePosition(line_no=name.line, col_no=name.column),
681709
)
682710
)
711+
tests_cache.insert_test(
712+
file_path=str(test_file),
713+
file_hash=file_hash,
714+
qualified_name_with_modules_from_root=qualified_name_with_modules_from_root,
715+
function_name=scope,
716+
test_class=test_func.test_class or "",
717+
test_function=scope_test_function,
718+
test_type=test_func.test_type,
719+
line_number=name.line,
720+
col_number=name.column,
721+
)
722+
683723
if test_func.test_type == TestType.REPLAY_TEST:
684724
num_discovered_replay_tests += 1
685725

@@ -690,4 +730,6 @@ def process_test_files(
690730

691731
progress.advance(task_id)
692732

733+
tests_cache.close()
734+
693735
return dict(function_to_test_map), num_discovered_tests, num_discovered_replay_tests

0 commit comments

Comments
 (0)