@@ -54,7 +54,6 @@ class TestsCache:
5454 def __init__ (self ) -> None :
5555 self .connection = sqlite3 .connect (codeflash_cache_db )
5656 self .cur = self .connection .cursor ()
57-
5857 self .cur .execute (
5958 """
6059 CREATE TABLE IF NOT EXISTS discovered_tests(
@@ -76,7 +75,9 @@ def __init__(self) -> None:
7675 ON discovered_tests (file_path, file_hash)
7776 """
7877 )
78+
7979 self ._memory_cache = {}
80+ self ._hash_cache = {}
8081
8182 def insert_test (
8283 self ,
@@ -107,25 +108,30 @@ def insert_test(
107108 )
108109 self .connection .commit ()
109110
110- def get_tests_for_file (self , file_path : str , file_hash : str ) -> list [FunctionCalledInTest ]:
111+ def get_tests_for_file (self , file_path : str , file_hash : str ) -> list [FunctionCalledInTest ] | None :
111112 cache_key = (file_path , file_hash )
112113 if cache_key in self ._memory_cache :
113114 return self ._memory_cache [cache_key ]
115+
114116 self .cur .execute ("SELECT * FROM discovered_tests WHERE file_path = ? AND file_hash = ?" , (file_path , file_hash ))
117+ rows = self .cur .fetchall ()
118+ if not rows :
119+ return None
120+
115121 result = [
116122 FunctionCalledInTest (
117123 tests_in_file = TestsInFile (
118124 test_file = Path (row [0 ]), test_class = row [4 ], test_function = row [5 ], test_type = TestType (int (row [6 ]))
119125 ),
120126 position = CodePosition (line_no = row [7 ], col_no = row [8 ]),
121127 )
122- for row in self . cur . fetchall ()
128+ for row in rows
123129 ]
124130 self ._memory_cache [cache_key ] = result
125131 return result
126132
127133 @staticmethod
128- def compute_file_hash (path : str ) -> str :
134+ def compute_file_hash (path : str | Path ) -> str :
129135 h = hashlib .sha256 (usedforsecurity = False )
130136 with Path (path ).open ("rb" ) as f :
131137 while True :
@@ -521,7 +527,7 @@ def process_test_files(
521527 file_to_test_map : dict [Path , list [TestsInFile ]],
522528 cfg : TestConfig ,
523529 functions_to_optimize : list [FunctionToOptimize ] | None = None ,
524- ) -> tuple [dict [str , set [FunctionCalledInTest ]], int ]:
530+ ) -> tuple [dict [str , set [FunctionCalledInTest ]], int , int ]:
525531 import jedi
526532
527533 project_root_path = cfg .project_root_path
@@ -536,29 +542,51 @@ def process_test_files(
536542 num_discovered_replay_tests = 0
537543 jedi_project = jedi .Project (path = project_root_path )
538544
545+ tests_cache = TestsCache ()
546+
539547 with test_files_progress_bar (total = len (file_to_test_map ), description = "Processing test files" ) as (
540548 progress ,
541549 task_id ,
542550 ):
543551 for test_file , functions in file_to_test_map .items ():
552+ file_hash = TestsCache .compute_file_hash (test_file )
553+
554+ cached_tests = tests_cache .get_tests_for_file (str (test_file ), file_hash )
555+
556+ if cached_tests :
557+ # Rebuild function_to_test_map from cached data
558+ tests_cache .cur .execute (
559+ "SELECT * FROM discovered_tests WHERE file_path = ? AND file_hash = ?" , (str (test_file ), file_hash )
560+ )
561+ for row in tests_cache .cur .fetchall ():
562+ qualified_name_with_modules_from_root = row [2 ]
563+ test_type = TestType (int (row [6 ]))
564+
565+ function_called_in_test = FunctionCalledInTest (
566+ tests_in_file = TestsInFile (
567+ test_file = test_file , test_class = row [4 ], test_function = row [5 ], test_type = test_type
568+ ),
569+ position = CodePosition (line_no = row [7 ], col_no = row [8 ]),
570+ )
571+
572+ function_to_test_map [qualified_name_with_modules_from_root ].add (function_called_in_test )
573+ if test_type == TestType .REPLAY_TEST :
574+ num_discovered_replay_tests += 1
575+ num_discovered_tests += 1
576+
577+ progress .advance (task_id )
578+ continue
544579 try :
545580 script = jedi .Script (path = test_file , project = jedi_project )
546581 test_functions = set ()
547582
548- # Single call to get all names with references and definitions
549- all_names = script .get_names (all_scopes = True , references = True , definitions = True )
583+ all_names = script .get_names (all_scopes = True , references = True )
584+ all_defs = script .get_names (all_scopes = True , definitions = True )
585+ all_names_top = script .get_names (all_scopes = True )
550586
551- # Filter once and create lookup dictionaries
552- top_level_functions = {}
553- top_level_classes = {}
554- all_defs = []
587+ top_level_functions = {name .name : name for name in all_names_top if name .type == "function" }
588+ top_level_classes = {name .name : name for name in all_names_top if name .type == "class" }
555589
556- for name in all_names :
557- if name .type == "function" :
558- top_level_functions [name .name ] = name
559- all_defs .append (name )
560- elif name .type == "class" :
561- top_level_classes [name .name ] = name
562590 except Exception as e :
563591 logger .debug (f"Failed to get jedi script for { test_file } : { e } " )
564592 progress .advance (task_id )
@@ -680,6 +708,18 @@ def process_test_files(
680708 position = CodePosition (line_no = name .line , col_no = name .column ),
681709 )
682710 )
711+ tests_cache .insert_test (
712+ file_path = str (test_file ),
713+ file_hash = file_hash ,
714+ qualified_name_with_modules_from_root = qualified_name_with_modules_from_root ,
715+ function_name = scope ,
716+ test_class = test_func .test_class or "" ,
717+ test_function = scope_test_function ,
718+ test_type = test_func .test_type ,
719+ line_number = name .line ,
720+ col_number = name .column ,
721+ )
722+
683723 if test_func .test_type == TestType .REPLAY_TEST :
684724 num_discovered_replay_tests += 1
685725
@@ -690,4 +730,6 @@ def process_test_files(
690730
691731 progress .advance (task_id )
692732
733+ tests_cache .close ()
734+
693735 return dict (function_to_test_map ), num_discovered_tests , num_discovered_replay_tests
0 commit comments