Skip to content

Commit 08d6273

Browse files
authored
Merge pull request #831 from codeflash-ai/fix/correct-resolve-test-paths-for-runtime-comments
[FIX] Correctly resolve test files paths when adding runtime comments
2 parents 226d10e + 8d160cc commit 08d6273

File tree

3 files changed

+14
-14
lines changed

3 files changed

+14
-14
lines changed

codeflash/code_utils/edit_generated_tests.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import os
55
import re
66
from pathlib import Path
7-
from typing import TYPE_CHECKING
7+
from typing import TYPE_CHECKING, Optional
88

99
import libcst as cst
1010
from libcst import MetadataWrapper
@@ -149,18 +149,19 @@ def leave_SimpleStatementSuite(
149149
return updated_node
150150

151151

152-
def unique_inv_id(inv_id_runtimes: dict[InvocationId, list[int]]) -> dict[str, int]:
152+
def unique_inv_id(inv_id_runtimes: dict[InvocationId, list[int]], tests_project_rootdir: Path) -> dict[str, int]:
153153
unique_inv_ids: dict[str, int] = {}
154154
for inv_id, runtimes in inv_id_runtimes.items():
155155
test_qualified_name = (
156156
inv_id.test_class_name + "." + inv_id.test_function_name # type: ignore[operator]
157157
if inv_id.test_class_name
158158
else inv_id.test_function_name
159159
)
160-
abs_path = str(Path(inv_id.test_module_path.replace(".", os.sep)).with_suffix(".py").resolve().with_suffix(""))
161-
if "__unit_test_" not in abs_path:
160+
abs_path = tests_project_rootdir / Path(inv_id.test_module_path.replace(".", os.sep)).with_suffix(".py")
161+
abs_path_str = str(abs_path.resolve().with_suffix(""))
162+
if "__unit_test_" not in abs_path_str or not test_qualified_name:
162163
continue
163-
key = test_qualified_name + "#" + abs_path # type: ignore[operator]
164+
key = test_qualified_name + "#" + abs_path_str
164165
parts = inv_id.iteration_id.split("_").__len__() # type: ignore[union-attr]
165166
cur_invid = inv_id.iteration_id.split("_")[0] if parts < 3 else "_".join(inv_id.iteration_id.split("_")[:-1]) # type: ignore[union-attr]
166167
match_key = key + "#" + cur_invid
@@ -174,10 +175,11 @@ def add_runtime_comments_to_generated_tests(
174175
generated_tests: GeneratedTestsList,
175176
original_runtimes: dict[InvocationId, list[int]],
176177
optimized_runtimes: dict[InvocationId, list[int]],
178+
tests_project_rootdir: Optional[Path] = None,
177179
) -> GeneratedTestsList:
178180
"""Add runtime performance comments to function calls in generated tests."""
179-
original_runtimes_dict = unique_inv_id(original_runtimes)
180-
optimized_runtimes_dict = unique_inv_id(optimized_runtimes)
181+
original_runtimes_dict = unique_inv_id(original_runtimes, tests_project_rootdir or Path())
182+
optimized_runtimes_dict = unique_inv_id(optimized_runtimes, tests_project_rootdir or Path())
181183
# Process each generated test
182184
modified_tests = []
183185
for test in generated_tests.generated_tests:

codeflash/lsp/beta.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -338,21 +338,19 @@ def initialize_function_optimization(
338338
) -> dict[str, str]:
339339
document_uri = params.textDocument.uri
340340
document = server.workspace.get_text_document(document_uri)
341+
file_path = Path(document.path)
341342

342343
server.show_message_log(f"Initializing optimization for function: {params.functionName} in {document_uri}", "Info")
343344

344345
if server.optimizer is None:
345346
_initialize_optimizer_if_api_key_is_valid(server)
346347

347-
server.optimizer.worktree_mode()
348-
349-
original_args, _ = server.optimizer.original_args_and_test_cfg
350-
348+
server.optimizer.args.file = file_path
351349
server.optimizer.args.function = params.functionName
352-
original_relative_file_path = Path(document.path).relative_to(original_args.project_root)
353-
server.optimizer.args.file = server.optimizer.current_worktree / original_relative_file_path
354350
server.optimizer.args.previous_checkpoint_functions = False
355351

352+
server.optimizer.worktree_mode()
353+
356354
server.show_message_log(
357355
f"Args set - function: {server.optimizer.args.function}, file: {server.optimizer.args.file}", "Info"
358356
)

codeflash/optimization/function_optimizer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1376,7 +1376,7 @@ def process_review(
13761376
)
13771377

13781378
generated_tests = add_runtime_comments_to_generated_tests(
1379-
generated_tests, original_runtime_by_test, optimized_runtime_by_test
1379+
generated_tests, original_runtime_by_test, optimized_runtime_by_test, self.test_cfg.tests_project_rootdir
13801380
)
13811381

13821382
generated_tests_str = "\n#------------------------------------------------\n".join(

0 commit comments

Comments
 (0)