Skip to content

Commit f626277

Browse files
Merge branch 'main' into lsp/init-flow
2 parents 4efbbd0 + fe82617 commit f626277

19 files changed

+215
-91
lines changed

codeflash/api/aiservice.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -577,6 +577,13 @@ def get_optimization_impact(
577577
]
578578
)
579579
code_diff = f"```diff\n{diff_str}\n```"
580+
# TODO get complexity metrics and fn call heuristics -> constructing a complete static call graph can be expensive for really large repos
581+
# grep function name in codebase -> ast parser to get no of calls and no of calls in loop -> radon lib to get complexity metrics -> send as additional context to the AI service
582+
# metric 1 -> call count - how many times the function is called in the codebase
583+
# metric 2 -> loop call count - how many times the function is called in a loop in the codebase
584+
# metric 3 -> presence of decorators like @profile, @cache -> this means the owner of the repo cares about the performance of this function
585+
# metric 4 -> cyclomatic complexity (https://en.wikipedia.org/wiki/Cyclomatic_complexity)
586+
# metric 5 (for future) -> halstead complexity (https://en.wikipedia.org/wiki/Halstead_complexity_measures)
580587
logger.info("!lsp|Computing Optimization Impact…")
581588
payload = {
582589
"code_diff": code_diff,

codeflash/api/cfapi.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,7 @@ def suggest_changes(
130130
coverage_message: str,
131131
replay_tests: str = "",
132132
concolic_tests: str = "",
133+
optimization_impact: str = "",
133134
) -> Response:
134135
"""Suggest changes to a pull request.
135136
@@ -155,6 +156,7 @@ def suggest_changes(
155156
"coverage_message": coverage_message,
156157
"replayTests": replay_tests,
157158
"concolicTests": concolic_tests,
159+
"optimizationImpact": optimization_impact,
158160
}
159161
return make_cfapi_request(endpoint="/suggest-pr-changes", method="POST", payload=payload)
160162

codeflash/benchmarking/instrument_codeflash_trace.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,10 @@
22

33
from typing import TYPE_CHECKING, Optional, Union
44

5-
import isort
65
import libcst as cst
76

7+
from codeflash.code_utils.formatter import sort_imports
8+
89
if TYPE_CHECKING:
910
from pathlib import Path
1011

@@ -107,7 +108,7 @@ def instrument_codeflash_trace_decorator(file_to_funcs_to_optimize: dict[Path, l
107108
original_code = file_path.read_text(encoding="utf-8")
108109
new_code = add_codeflash_decorator_to_code(original_code, functions_to_optimize)
109110
# Modify the code
110-
modified_code = isort.code(code=new_code, float_to_top=True)
111+
modified_code = sort_imports(code=new_code, float_to_top=True)
111112

112113
# Write the modified code back to the file
113114
file_path.write_text(modified_code, encoding="utf-8")

codeflash/benchmarking/replay_test.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,8 @@
66
from pathlib import Path
77
from typing import TYPE_CHECKING, Any
88

9-
import isort
10-
119
from codeflash.cli_cmds.console import logger
10+
from codeflash.code_utils.formatter import sort_imports
1211
from codeflash.discovery.functions_to_optimize import inspect_top_level_functions_or_methods
1312
from codeflash.verification.verification_utils import get_test_file_path
1413

@@ -299,7 +298,7 @@ def generate_replay_test(
299298
test_framework=test_framework,
300299
max_run_count=max_run_count,
301300
)
302-
test_code = isort.code(test_code)
301+
test_code = sort_imports(code=test_code)
303302
output_file = get_test_file_path(
304303
test_dir=Path(output_dir), function_name=benchmark_module_path, test_type="replay"
305304
)

codeflash/code_utils/code_replacer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,13 @@
55
from functools import lru_cache
66
from typing import TYPE_CHECKING, Optional, TypeVar
77

8-
import isort
98
import libcst as cst
109
from libcst.metadata import PositionProvider
1110

1211
from codeflash.cli_cmds.console import logger
1312
from codeflash.code_utils.code_extractor import add_global_assignments, add_needed_imports_from_module
1413
from codeflash.code_utils.config_parser import find_conftest_files
14+
from codeflash.code_utils.formatter import sort_imports
1515
from codeflash.code_utils.line_profile_utils import ImportAdder
1616
from codeflash.models.models import FunctionParent
1717

@@ -226,7 +226,7 @@ def add_custom_marker_to_all_tests(test_paths: list[Path]) -> None:
226226
module = cst.parse_module(file_content)
227227
importadder = ImportAdder("import pytest")
228228
modified_module = module.visit(importadder)
229-
modified_module = cst.parse_module(isort.code(modified_module.code, float_to_top=True))
229+
modified_module = cst.parse_module(sort_imports(code=modified_module.code, float_to_top=True))
230230
pytest_mark_adder = PytestMarkAdder("codeflash_no_autouse")
231231
modified_module = modified_module.visit(pytest_mark_adder)
232232
test_path.write_text(modified_module.code, encoding="utf-8")

codeflash/code_utils/code_utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -304,7 +304,9 @@ def get_all_function_names(code: str) -> tuple[bool, list[str]]:
304304
return True, function_names
305305

306306

307-
def get_run_tmp_file(file_path: Path) -> Path:
307+
def get_run_tmp_file(file_path: Path | str) -> Path:
308+
if isinstance(file_path, str):
309+
file_path = Path(file_path)
308310
if not hasattr(get_run_tmp_file, "tmpdir"):
309311
get_run_tmp_file.tmpdir = TemporaryDirectory(prefix="codeflash_")
310312
return Path(get_run_tmp_file.tmpdir.name) / file_path

codeflash/code_utils/formatter.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -166,11 +166,11 @@ def format_code(
166166
return formatted_code
167167

168168

169-
def sort_imports(code: str) -> str:
169+
def sort_imports(code: str, *, float_to_top: bool = False) -> str:
170170
try:
171171
# Deduplicate and sort imports, modify the code in memory, not on disk
172-
sorted_code = isort.code(code)
173-
except Exception:
172+
sorted_code = isort.code(code=code, float_to_top=float_to_top)
173+
except Exception: # this will also catch the FileSkipComment exception, use this fn everywhere
174174
logger.exception("Failed to sort imports with isort.")
175175
return code # Fall back to original code if isort fails
176176

codeflash/code_utils/instrument_existing_tests.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
from codeflash.cli_cmds.console import logger
1212
from codeflash.code_utils.code_utils import get_run_tmp_file, module_name_from_file_path
13+
from codeflash.code_utils.formatter import sort_imports
1314
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
1415
from codeflash.models.models import FunctionParent, TestingMode, VerificationType
1516

@@ -1129,7 +1130,7 @@ def add_async_decorator_to_function(
11291130
import_transformer = AsyncDecoratorImportAdder(mode)
11301131
module = module.visit(import_transformer)
11311132

1132-
return isort.code(module.code, float_to_top=True), decorator_transformer.added_decorator
1133+
return sort_imports(code=module.code, float_to_top=True), decorator_transformer.added_decorator
11331134
except Exception as e:
11341135
logger.exception(f"Error adding async decorator to function {function.qualified_name}: {e}")
11351136
return source_code, False

codeflash/code_utils/line_profile_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,10 @@
66
from pathlib import Path
77
from typing import TYPE_CHECKING, Union
88

9-
import isort
109
import libcst as cst
1110

1211
from codeflash.code_utils.code_utils import get_run_tmp_file
12+
from codeflash.code_utils.formatter import sort_imports
1313

1414
if TYPE_CHECKING:
1515
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
@@ -213,7 +213,7 @@ def add_decorator_imports(function_to_optimize: FunctionToOptimize, code_context
213213
transformer = ImportAdder("from line_profiler import profile as codeflash_line_profile")
214214
# Apply the transformer to add the import
215215
module_node = module_node.visit(transformer)
216-
modified_code = isort.code(module_node.code, float_to_top=True)
216+
modified_code = sort_imports(code=module_node.code, float_to_top=True)
217217
# write to file
218218
with file_path.open("w", encoding="utf-8") as file:
219219
file.write(modified_code)

0 commit comments

Comments
 (0)