Skip to content

Commit 4a6eaab

Browse files
authored
Merge pull request #895 from codeflash-ai/cf-773
Run Formatter on generated tests
2 parents 3a0f372 + cf26123 commit 4a6eaab

File tree

3 files changed

+635
-5
lines changed

3 files changed

+635
-5
lines changed

codeflash/code_utils/formatter.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,19 @@ def is_diff_line(line: str) -> bool:
9696
return len(diff_lines)
9797

9898

99+
def format_generated_code(generated_test_source: str, formatter_cmds: list[str]) -> str:
100+
with tempfile.TemporaryDirectory() as test_dir_str:
101+
# try running formatter, if nothing changes (could be due to formatting failing or no actual formatting needed) return code with 2 or more newlines substituted with 2 newlines
102+
original_temp = Path(test_dir_str) / "original_temp.py"
103+
original_temp.write_text(generated_test_source, encoding="utf8")
104+
_, formatted_code, changed = apply_formatter_cmds(
105+
formatter_cmds, original_temp, test_dir_str, print_status=False, exit_on_failure=False
106+
)
107+
if not changed:
108+
return re.sub(r"\n{2,}", "\n\n", formatted_code)
109+
return formatted_code
110+
111+
99112
def format_code(
100113
formatter_cmds: list[str],
101114
path: Union[str, Path],
@@ -120,7 +133,7 @@ def format_code(
120133
original_code_lines = len(original_code.split("\n"))
121134

122135
if check_diff and original_code_lines > 50:
123-
# we dont' count the formatting diff for the optimized function as it should be well-formatted
136+
# we don't count the formatting diff for the optimized function as it should be well-formatted
124137
original_code_without_opfunc = original_code.replace(optimized_code, "")
125138

126139
original_temp = Path(test_dir_str) / "original_temp.py"

codeflash/optimization/function_optimizer.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@
5555
remove_functions_from_generated_tests,
5656
)
5757
from codeflash.code_utils.env_utils import get_pr_number
58-
from codeflash.code_utils.formatter import format_code, sort_imports
58+
from codeflash.code_utils.formatter import format_code, format_generated_code, sort_imports
5959
from codeflash.code_utils.git_utils import git_root_dir
6060
from codeflash.code_utils.instrument_existing_tests import inject_profiling_into_existing_test
6161
from codeflash.code_utils.line_profile_utils import add_decorator_imports
@@ -1413,11 +1413,15 @@ def process_review(
14131413

14141414
generated_tests_str = ""
14151415
for test in generated_tests.generated_tests:
1416-
generated_tests_str += f"```python\n{test.generated_original_test_source}\n```"
1416+
formatted_generated_test = format_generated_code(
1417+
test.generated_original_test_source, self.args.formatter_cmds
1418+
)
1419+
generated_tests_str += f"```python\n{formatted_generated_test}\n```"
14171420
generated_tests_str += "\n\n"
14181421

14191422
if concolic_test_str:
1420-
generated_tests_str += f"```python\n{concolic_test_str}\n```\n\n"
1423+
formatted_generated_test = format_generated_code(concolic_test_str, self.args.formatter_cmds)
1424+
generated_tests_str += f"```python\n{formatted_generated_test}\n```\n\n"
14211425

14221426
existing_tests, replay_tests, concolic_tests = existing_tests_source_for(
14231427
self.function_to_optimize.qualified_name_with_modules_from_root(self.project_root),

0 commit comments

Comments
 (0)