66
77from codeflash .discovery .functions_to_optimize import FunctionToOptimize
88from codeflash .either import is_successful
9- from codeflash .models .models import FunctionParent
9+ from codeflash .models .models import FunctionParent , get_code_block_splitter
1010from codeflash .optimization .function_optimizer import FunctionOptimizer
1111from codeflash .optimization .optimizer import Optimizer
1212from codeflash .verification .verification_utils import TestConfig
@@ -242,8 +242,9 @@ def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R:
242242 code_context = ctx_result .unwrap ()
243243 assert code_context .helper_functions [0 ].qualified_name == "AbstractCacheBackend.get_cache_or_call"
244244 assert (
245- code_context .testgen_context_code
246- == f'''_P = ParamSpec("_P")
245+ code_context .testgen_context .flat
246+ == f'''# file: { file_path .relative_to (project_root_path )}
247+ _P = ParamSpec("_P")
247248_KEY_T = TypeVar("_KEY_T")
248249_STORE_T = TypeVar("_STORE_T")
249250class AbstractCacheBackend(CacheBackend, Protocol[_KEY_T, _STORE_T]):
@@ -395,10 +396,11 @@ def test_bubble_sort_deps() -> None:
395396 function_to_optimize = FunctionToOptimize (
396397 function_name = "sorter_deps" , file_path = file_path , parents = [], starting_line = None , ending_line = None
397398 )
399+ project_root = file_path .parent .parent .resolve ()
398400 test_config = TestConfig (
399401 tests_root = str (file_path .parent / "tests" ),
400402 tests_project_rootdir = file_path .parent .resolve (),
401- project_root_path = file_path . parent . parent . resolve () ,
403+ project_root_path = project_root ,
402404 test_framework = "pytest" ,
403405 pytest_cmd = "pytest" ,
404406 )
@@ -410,19 +412,20 @@ def test_bubble_sort_deps() -> None:
410412 pytest .fail ()
411413 code_context = ctx_result .unwrap ()
412414 assert (
413- code_context .testgen_context_code
414- == """from code_to_optimize.bubble_sort_dep1_helper import dep1_comparer
415- from code_to_optimize.bubble_sort_dep2_swap import dep2_swap
416-
415+ code_context .testgen_context .flat
416+ == f"""{ get_code_block_splitter (Path ("code_to_optimize/bubble_sort_dep1_helper.py" ))}
417417def dep1_comparer(arr, j: int) -> bool:
418418 return arr[j] > arr[j + 1]
419419
420+ { get_code_block_splitter (Path ("code_to_optimize/bubble_sort_dep2_swap.py" ))}
420421def dep2_swap(arr, j):
421422 temp = arr[j]
422423 arr[j] = arr[j + 1]
423424 arr[j + 1] = temp
424425
425-
426+ { get_code_block_splitter (Path ("code_to_optimize/bubble_sort_deps.py" ))}
427+ from code_to_optimize.bubble_sort_dep1_helper import dep1_comparer
428+ from code_to_optimize.bubble_sort_dep2_swap import dep2_swap
426429
427430def sorter_deps(arr):
428431 for i in range(len(arr)):
0 commit comments