Skip to content

Commit ef69713

Browse files
authored
refactor how async decorators are applied at source site (#897)
* 5% * better reporting for throughput * first pass * improve reporting for async optimizations * remove some deduplication * simplify usage * refactor add_async_decorator_to_function * doesn't even work * fix mypy complaints * handle libcst exception * fix type checking * revert config
1 parent 50487d6 commit ef69713

File tree

10 files changed

+1384
-838
lines changed

10 files changed

+1384
-838
lines changed

codeflash/code_utils/instrument_existing_tests.py

Lines changed: 19 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -684,27 +684,6 @@ def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
684684
)
685685

686686

687-
def instrument_source_module_with_async_decorators(
688-
source_path: Path, function_to_optimize: FunctionToOptimize, mode: TestingMode = TestingMode.BEHAVIOR
689-
) -> tuple[bool, str | None]:
690-
if not function_to_optimize.is_async:
691-
return False, None
692-
693-
try:
694-
with source_path.open(encoding="utf8") as f:
695-
source_code = f.read()
696-
697-
modified_code, decorator_added = add_async_decorator_to_function(source_code, function_to_optimize, mode)
698-
699-
if decorator_added:
700-
return True, modified_code
701-
702-
except Exception:
703-
return False, None
704-
else:
705-
return False, None
706-
707-
708687
def inject_async_profiling_into_existing_test(
709688
test_path: Path,
710689
call_positions: list[CodePosition],
@@ -1288,25 +1267,29 @@ def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> c
12881267

12891268

12901269
def add_async_decorator_to_function(
1291-
source_code: str, function: FunctionToOptimize, mode: TestingMode = TestingMode.BEHAVIOR
1292-
) -> tuple[str, bool]:
1293-
"""Add async decorator to an async function definition.
1270+
source_path: Path, function: FunctionToOptimize, mode: TestingMode = TestingMode.BEHAVIOR
1271+
) -> bool:
1272+
"""Add async decorator to an async function definition and write back to file.
12941273
12951274
Args:
12961275
----
1297-
source_code: The source code to modify.
1276+
source_path: Path to the source file to modify in-place.
12981277
function: The FunctionToOptimize object representing the target async function.
12991278
mode: The testing mode to determine which decorator to apply.
13001279
13011280
Returns:
13021281
-------
1303-
Tuple of (modified_source_code, was_decorator_added).
1282+
Boolean indicating whether the decorator was successfully added.
13041283
13051284
"""
13061285
if not function.is_async:
1307-
return source_code, False
1286+
return False
13081287

13091288
try:
1289+
# Read source code
1290+
with source_path.open(encoding="utf8") as f:
1291+
source_code = f.read()
1292+
13101293
module = cst.parse_module(source_code)
13111294

13121295
# Add the decorator to the function
@@ -1318,10 +1301,17 @@ def add_async_decorator_to_function(
13181301
import_transformer = AsyncDecoratorImportAdder(mode)
13191302
module = module.visit(import_transformer)
13201303

1321-
return sort_imports(code=module.code, float_to_top=True), decorator_transformer.added_decorator
1304+
modified_code = sort_imports(code=module.code, float_to_top=True)
13221305
except Exception as e:
13231306
logger.exception(f"Error adding async decorator to function {function.qualified_name}: {e}")
1324-
return source_code, False
1307+
return False
1308+
else:
1309+
if decorator_transformer.added_decorator:
1310+
with source_path.open("w", encoding="utf8") as f:
1311+
f.write(modified_code)
1312+
logger.debug(f"Applied async {mode.value} instrumentation to {source_path}")
1313+
return True
1314+
return False
13251315

13261316

13271317
def create_instrumented_source_module_path(source_path: Path, temp_dir: Path) -> Path:

codeflash/context/unused_definition_remover.py

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -469,22 +469,32 @@ def remove_unused_definitions_by_function_names(code: str, qualified_function_na
469469
qualified_function_names: Set of function names to keep. For methods, use format 'classname.methodname'
470470
471471
"""
472-
module = cst.parse_module(code)
473-
# Collect all definitions (top level classes, variables or function)
474-
definitions = collect_top_level_definitions(module)
472+
try:
473+
module = cst.parse_module(code)
474+
except Exception as e:
475+
logger.debug(f"Failed to parse code with libcst: {type(e).__name__}: {e}")
476+
return code
475477

476-
# Collect dependencies between definitions using the visitor pattern
477-
dependency_collector = DependencyCollector(definitions)
478-
module.visit(dependency_collector)
478+
try:
479+
# Collect all definitions (top level classes, variables or function)
480+
definitions = collect_top_level_definitions(module)
479481

480-
# Mark definitions used by specified functions, and their dependencies recursively
481-
usage_marker = QualifiedFunctionUsageMarker(definitions, qualified_function_names)
482-
usage_marker.mark_used_definitions()
482+
# Collect dependencies between definitions using the visitor pattern
483+
dependency_collector = DependencyCollector(definitions)
484+
module.visit(dependency_collector)
483485

484-
# Apply the recursive removal transformation
485-
modified_module, _ = remove_unused_definitions_recursively(module, definitions)
486+
# Mark definitions used by specified functions, and their dependencies recursively
487+
usage_marker = QualifiedFunctionUsageMarker(definitions, qualified_function_names)
488+
usage_marker.mark_used_definitions()
486489

487-
return modified_module.code if modified_module else ""
490+
# Apply the recursive removal transformation
491+
modified_module, _ = remove_unused_definitions_recursively(module, definitions)
492+
493+
return modified_module.code if modified_module else "" # noqa: TRY300
494+
except Exception as e:
495+
# If any other error occurs during processing, return the original code
496+
logger.debug(f"Error processing code to remove unused definitions: {type(e).__name__}: {e}")
497+
return code
488498

489499

490500
def print_definitions(definitions: dict[str, UsageInfo]) -> None:

codeflash/github/PrComment.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,15 +21,17 @@ class PrComment:
2121
winning_behavior_test_results: TestResults
2222
winning_benchmarking_test_results: TestResults
2323
benchmark_details: Optional[list[BenchmarkDetail]] = None
24+
original_async_throughput: Optional[int] = None
25+
best_async_throughput: Optional[int] = None
2426

25-
def to_json(self) -> dict[str, Union[dict[str, dict[str, int]], int, str, Optional[list[BenchmarkDetail]]]]:
27+
def to_json(self) -> dict[str, Union[str, int, dict[str, dict[str, int]], list[BenchmarkDetail], None]]:
2628
report_table = {
2729
test_type.to_name(): result
2830
for test_type, result in self.winning_behavior_test_results.get_test_pass_fail_report_by_type().items()
2931
if test_type.to_name()
3032
}
3133

32-
return {
34+
result: dict[str, Union[str, int, dict[str, dict[str, int]], list[BenchmarkDetail], None]] = {
3335
"optimization_explanation": self.optimization_explanation,
3436
"best_runtime": humanize_runtime(self.best_runtime),
3537
"original_runtime": humanize_runtime(self.original_runtime),
@@ -42,6 +44,12 @@ def to_json(self) -> dict[str, Union[dict[str, dict[str, int]], int, str, Option
4244
"benchmark_details": self.benchmark_details if self.benchmark_details else None,
4345
}
4446

47+
if self.original_async_throughput is not None and self.best_async_throughput is not None:
48+
result["original_async_throughput"] = str(self.original_async_throughput)
49+
result["best_async_throughput"] = str(self.best_async_throughput)
50+
51+
return result
52+
4553

4654
class FileDiffContent(BaseModel):
4755
oldContent: str # noqa: N815

codeflash/optimization/function_optimizer.py

Lines changed: 51 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -607,26 +607,32 @@ def determine_best_candidate(
607607
original_async_throughput=original_code_baseline.async_throughput,
608608
best_throughput_until_now=None,
609609
) and quantity_of_tests_critic(candidate_result):
610-
tree.add("This candidate is faster than the original code. 🚀") # TODO: Change this description
611-
tree.add(f"Original summed runtime: {humanize_runtime(original_code_baseline.runtime)}")
612-
tree.add(
613-
f"Best summed runtime: {humanize_runtime(candidate_result.best_test_runtime)} "
614-
f"(measured over {candidate_result.max_loop_count} "
615-
f"loop{'s' if candidate_result.max_loop_count > 1 else ''})"
616-
)
617-
tree.add(f"Speedup percentage: {perf_gain * 100:.1f}%")
618-
tree.add(f"Speedup ratio: {perf_gain + 1:.3f}X")
619-
if (
610+
# For async functions, prioritize throughput metrics over runtime
611+
is_async = (
620612
original_code_baseline.async_throughput is not None
621613
and candidate_result.async_throughput is not None
622-
):
614+
)
615+
616+
if is_async:
623617
throughput_gain_value = throughput_gain(
624618
original_throughput=original_code_baseline.async_throughput,
625619
optimized_throughput=candidate_result.async_throughput,
626620
)
621+
tree.add("This candidate has better async throughput than the original code. 🚀")
627622
tree.add(f"Original async throughput: {original_code_baseline.async_throughput} executions")
628623
tree.add(f"Optimized async throughput: {candidate_result.async_throughput} executions")
629624
tree.add(f"Throughput improvement: {throughput_gain_value * 100:.1f}%")
625+
tree.add(f"Throughput ratio: {throughput_gain_value + 1:.3f}X")
626+
else:
627+
tree.add("This candidate is faster than the original code. 🚀")
628+
tree.add(f"Original summed runtime: {humanize_runtime(original_code_baseline.runtime)}")
629+
tree.add(
630+
f"Best summed runtime: {humanize_runtime(candidate_result.best_test_runtime)} "
631+
f"(measured over {candidate_result.max_loop_count} "
632+
f"loop{'s' if candidate_result.max_loop_count > 1 else ''})"
633+
)
634+
tree.add(f"Speedup percentage: {perf_gain * 100:.1f}%")
635+
tree.add(f"Speedup ratio: {perf_gain + 1:.3f}X")
630636
line_profile_test_results = self.line_profiler_step(
631637
code_context=code_context,
632638
original_helper_code=original_helper_code,
@@ -681,22 +687,31 @@ def determine_best_candidate(
681687
)
682688
)
683689
else:
684-
tree.add(
685-
f"Summed runtime: {humanize_runtime(best_test_runtime)} "
686-
f"(measured over {candidate_result.max_loop_count} "
687-
f"loop{'s' if candidate_result.max_loop_count > 1 else ''})"
688-
)
689-
tree.add(f"Speedup percentage: {perf_gain * 100:.1f}%")
690-
tree.add(f"Speedup ratio: {perf_gain + 1:.3f}X")
691-
if (
690+
# For async functions, prioritize throughput metrics over runtime even for slow candidates
691+
is_async = (
692692
original_code_baseline.async_throughput is not None
693693
and candidate_result.async_throughput is not None
694-
):
694+
)
695+
696+
if is_async:
695697
throughput_gain_value = throughput_gain(
696698
original_throughput=original_code_baseline.async_throughput,
697699
optimized_throughput=candidate_result.async_throughput,
698700
)
699-
tree.add(f"Throughput gain: {throughput_gain_value * 100:.1f}%")
701+
tree.add(f"Async throughput: {candidate_result.async_throughput} executions")
702+
tree.add(f"Throughput change: {throughput_gain_value * 100:.1f}%")
703+
tree.add(
704+
f"(Runtime for reference: {humanize_runtime(best_test_runtime)} over "
705+
f"{candidate_result.max_loop_count} loop{'s' if candidate_result.max_loop_count > 1 else ''})"
706+
)
707+
else:
708+
tree.add(
709+
f"Summed runtime: {humanize_runtime(best_test_runtime)} "
710+
f"(measured over {candidate_result.max_loop_count} "
711+
f"loop{'s' if candidate_result.max_loop_count > 1 else ''})"
712+
)
713+
tree.add(f"Speedup percentage: {perf_gain * 100:.1f}%")
714+
tree.add(f"Speedup ratio: {perf_gain + 1:.3f}X")
700715

701716
if is_LSP_enabled():
702717
lsp_log(LspMarkdownMessage(markdown=tree_to_markdown(tree)))
@@ -1502,16 +1517,21 @@ def process_review(
15021517
raise_pr = not self.args.no_pr
15031518
staging_review = self.args.staging_review
15041519
opt_review_response = ""
1505-
if raise_pr or staging_review:
1520+
# Skip optimization review for async functions for now
1521+
if (raise_pr or staging_review) and not self.function_to_optimize.is_async:
15061522
data["root_dir"] = git_root_dir()
15071523
try:
15081524
opt_review_response = self.aiservice_client.get_optimization_review(
15091525
**data, calling_fn_details=function_references
15101526
)
15111527
except Exception as e:
15121528
logger.debug(f"optimization review response failed, investigate {e}")
1513-
data["optimization_review"] = opt_review_response
1529+
# Always set optimization_review in data (empty string for async functions)
1530+
data["optimization_review"] = opt_review_response
15141531
if raise_pr and not staging_review and opt_review_response != "low":
1532+
# Ensure root_dir is set for PR creation (needed for async functions that skip opt_review)
1533+
if "root_dir" not in data:
1534+
data["root_dir"] = git_root_dir()
15151535
data["git_remote"] = self.args.git_remote
15161536
check_create_pr(**data)
15171537
elif staging_review:
@@ -1579,15 +1599,11 @@ def establish_original_code_baseline(
15791599
test_env = self.get_test_env(codeflash_loop_index=0, codeflash_test_iteration=0, codeflash_tracer_disable=1)
15801600

15811601
if self.function_to_optimize.is_async:
1582-
from codeflash.code_utils.instrument_existing_tests import instrument_source_module_with_async_decorators
1602+
from codeflash.code_utils.instrument_existing_tests import add_async_decorator_to_function
15831603

1584-
success, instrumented_source = instrument_source_module_with_async_decorators(
1604+
success = add_async_decorator_to_function(
15851605
self.function_to_optimize.file_path, self.function_to_optimize, TestingMode.BEHAVIOR
15861606
)
1587-
if success and instrumented_source:
1588-
with self.function_to_optimize.file_path.open("w", encoding="utf8") as f:
1589-
f.write(instrumented_source)
1590-
logger.debug(f"Applied async instrumentation to {self.function_to_optimize.file_path}")
15911607

15921608
# Instrument codeflash capture
15931609
with progress_bar("Running tests to establish original code behavior..."):
@@ -1632,19 +1648,11 @@ def establish_original_code_baseline(
16321648
console.rule()
16331649
with progress_bar("Running performance benchmarks..."):
16341650
if self.function_to_optimize.is_async:
1635-
from codeflash.code_utils.instrument_existing_tests import (
1636-
instrument_source_module_with_async_decorators,
1637-
)
1651+
from codeflash.code_utils.instrument_existing_tests import add_async_decorator_to_function
16381652

1639-
success, instrumented_source = instrument_source_module_with_async_decorators(
1653+
add_async_decorator_to_function(
16401654
self.function_to_optimize.file_path, self.function_to_optimize, TestingMode.PERFORMANCE
16411655
)
1642-
if success and instrumented_source:
1643-
with self.function_to_optimize.file_path.open("w", encoding="utf8") as f:
1644-
f.write(instrumented_source)
1645-
logger.debug(
1646-
f"Applied async performance instrumentation to {self.function_to_optimize.file_path}"
1647-
)
16481656

16491657
try:
16501658
benchmarking_results, _ = self.run_and_parse_tests(
@@ -1767,19 +1775,11 @@ def run_optimized_candidate(
17671775
for module_abspath in original_helper_code:
17681776
candidate_helper_code[module_abspath] = Path(module_abspath).read_text("utf-8")
17691777
if self.function_to_optimize.is_async:
1770-
from codeflash.code_utils.instrument_existing_tests import (
1771-
instrument_source_module_with_async_decorators,
1772-
)
1778+
from codeflash.code_utils.instrument_existing_tests import add_async_decorator_to_function
17731779

1774-
success, instrumented_source = instrument_source_module_with_async_decorators(
1780+
add_async_decorator_to_function(
17751781
self.function_to_optimize.file_path, self.function_to_optimize, TestingMode.BEHAVIOR
17761782
)
1777-
if success and instrumented_source:
1778-
with self.function_to_optimize.file_path.open("w", encoding="utf8") as f:
1779-
f.write(instrumented_source)
1780-
logger.debug(
1781-
f"Applied async behavioral instrumentation to {self.function_to_optimize.file_path} for candidate {optimization_candidate_index}"
1782-
)
17831783

17841784
try:
17851785
instrument_codeflash_capture(
@@ -1820,19 +1820,11 @@ def run_optimized_candidate(
18201820
if test_framework == "pytest":
18211821
# For async functions, instrument at definition site for performance benchmarking
18221822
if self.function_to_optimize.is_async:
1823-
from codeflash.code_utils.instrument_existing_tests import (
1824-
instrument_source_module_with_async_decorators,
1825-
)
1823+
from codeflash.code_utils.instrument_existing_tests import add_async_decorator_to_function
18261824

1827-
success, instrumented_source = instrument_source_module_with_async_decorators(
1825+
add_async_decorator_to_function(
18281826
self.function_to_optimize.file_path, self.function_to_optimize, TestingMode.PERFORMANCE
18291827
)
1830-
if success and instrumented_source:
1831-
with self.function_to_optimize.file_path.open("w", encoding="utf8") as f:
1832-
f.write(instrumented_source)
1833-
logger.debug(
1834-
f"Applied async performance instrumentation to {self.function_to_optimize.file_path} for candidate {optimization_candidate_index}"
1835-
)
18361828

18371829
try:
18381830
candidate_benchmarking_results, _ = self.run_and_parse_tests(

codeflash/result/create_pr.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,8 @@ def check_create_pr(
220220
winning_behavior_test_results=explanation.winning_behavior_test_results,
221221
winning_benchmarking_test_results=explanation.winning_benchmarking_test_results,
222222
benchmark_details=explanation.benchmark_details,
223+
original_async_throughput=explanation.original_async_throughput,
224+
best_async_throughput=explanation.best_async_throughput,
223225
),
224226
existing_tests=existing_tests_source,
225227
generated_tests=generated_original_test_source,
@@ -270,6 +272,8 @@ def check_create_pr(
270272
winning_behavior_test_results=explanation.winning_behavior_test_results,
271273
winning_benchmarking_test_results=explanation.winning_benchmarking_test_results,
272274
benchmark_details=explanation.benchmark_details,
275+
original_async_throughput=explanation.original_async_throughput,
276+
best_async_throughput=explanation.best_async_throughput,
273277
),
274278
existing_tests=existing_tests_source,
275279
generated_tests=generated_original_test_source,

0 commit comments

Comments
 (0)