44import os
55import re
66from pathlib import Path
7- from typing import TYPE_CHECKING
7+ from typing import TYPE_CHECKING , Optional
88
99import libcst as cst
1010from libcst import MetadataWrapper
@@ -149,15 +149,16 @@ def leave_SimpleStatementSuite(
149149 return updated_node
150150
151151
152- def unique_inv_id (inv_id_runtimes : dict [InvocationId , list [int ]]) -> dict [str , int ]:
152+ def unique_inv_id (inv_id_runtimes : dict [InvocationId , list [int ]], tests_project_rootdir : Path ) -> dict [str , int ]:
153153 unique_inv_ids : dict [str , int ] = {}
154154 for inv_id , runtimes in inv_id_runtimes .items ():
155155 test_qualified_name = (
156156 inv_id .test_class_name + "." + inv_id .test_function_name # type: ignore[operator]
157157 if inv_id .test_class_name
158158 else inv_id .test_function_name
159159 )
160- abs_path = str (Path (inv_id .test_module_path .replace ("." , os .sep )).with_suffix (".py" ).resolve ().with_suffix ("" ))
160+ abs_path = tests_project_rootdir / Path (inv_id .test_module_path .replace ("." , os .sep )).with_suffix (".py" )
161+ abs_path = str (abs_path .resolve ().with_suffix ("" ))
161162 if "__unit_test_" not in abs_path :
162163 continue
163164 key = test_qualified_name + "#" + abs_path # type: ignore[operator]
@@ -174,10 +175,11 @@ def add_runtime_comments_to_generated_tests(
174175 generated_tests : GeneratedTestsList ,
175176 original_runtimes : dict [InvocationId , list [int ]],
176177 optimized_runtimes : dict [InvocationId , list [int ]],
178+ tests_project_rootdir : Optional [Path ] = None ,
177179) -> GeneratedTestsList :
178180 """Add runtime performance comments to function calls in generated tests."""
179- original_runtimes_dict = unique_inv_id (original_runtimes )
180- optimized_runtimes_dict = unique_inv_id (optimized_runtimes )
181+ original_runtimes_dict = unique_inv_id (original_runtimes , tests_project_rootdir or Path () )
182+ optimized_runtimes_dict = unique_inv_id (optimized_runtimes , tests_project_rootdir or Path () )
181183 # Process each generated test
182184 modified_tests = []
183185 for test in generated_tests .generated_tests :
0 commit comments