1 change: 1 addition & 0 deletions .github/workflows/e2e-bubblesort-unittest.yaml
@@ -61,6 +61,7 @@ jobs:
- name: Install dependencies (CLI)
run: |
uv sync
uv add timeout_decorator
Contributor: add to toml and update lock file?

Contributor: oh my bad, we're removing this from the dependencies, do we need the workflow file then?


- name: Run Codeflash to optimize code
id: optimize_code
54 changes: 0 additions & 54 deletions codeflash/code_utils/instrument_existing_tests.py
@@ -1,7 +1,6 @@
from __future__ import annotations

import ast
import platform
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING
@@ -329,17 +328,6 @@ def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef:
def visit_FunctionDef(self, node: ast.FunctionDef, test_class_name: str | None = None) -> ast.FunctionDef:
if node.name.startswith("test_"):
did_update = False
if self.test_framework == "unittest" and platform.system() != "Windows":
# Only add timeout decorator on non-Windows platforms
# Windows doesn't support SIGALRM signal required by timeout_decorator

node.decorator_list.append(
ast.Call(
func=ast.Name(id="timeout_decorator.timeout", ctx=ast.Load()),
args=[ast.Constant(value=15)],
keywords=[],
)
)
i = len(node.body) - 1
while i >= 0:
line_node = node.body[i]
@@ -505,25 +493,6 @@ def __init__(
self.class_name = function.top_level_parent_name

def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef:
# Add timeout decorator for unittest test classes if needed
if self.test_framework == "unittest":
timeout_decorator = ast.Call(
func=ast.Name(id="timeout_decorator.timeout", ctx=ast.Load()),
args=[ast.Constant(value=15)],
keywords=[],
)
for item in node.body:
if (
isinstance(item, ast.FunctionDef)
and item.name.startswith("test_")
and not any(
isinstance(d, ast.Call)
and isinstance(d.func, ast.Name)
and d.func.id == "timeout_decorator.timeout"
for d in item.decorator_list
)
):
item.decorator_list.append(timeout_decorator)
return self.generic_visit(node)

def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> ast.AsyncFunctionDef:
@@ -542,25 +511,6 @@ def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.FunctionDef:
def _process_test_function(
self, node: ast.AsyncFunctionDef | ast.FunctionDef
) -> ast.AsyncFunctionDef | ast.FunctionDef:
# Optimize the search for decorator presence
if self.test_framework == "unittest":
found_timeout = False
for d in node.decorator_list:
# Avoid isinstance(d.func, ast.Name) if d is not ast.Call
if isinstance(d, ast.Call):
f = d.func
# Avoid attribute lookup if f is not ast.Name
if isinstance(f, ast.Name) and f.id == "timeout_decorator.timeout":
found_timeout = True
break
if not found_timeout:
timeout_decorator = ast.Call(
func=ast.Name(id="timeout_decorator.timeout", ctx=ast.Load()),
args=[ast.Constant(value=15)],
keywords=[],
)
node.decorator_list.append(timeout_decorator)

# Initialize counter for this test function
if node.name not in self.async_call_counter:
self.async_call_counter[node.name] = 0
@@ -715,8 +665,6 @@ def inject_async_profiling_into_existing_test(

# Add necessary imports
new_imports = [ast.Import(names=[ast.alias(name="os")])]
if test_framework == "unittest":
new_imports.append(ast.Import(names=[ast.alias(name="timeout_decorator")]))

tree.body = [*new_imports, *tree.body]
return True, sort_imports(ast.unparse(tree), float_to_top=True)
@@ -762,8 +710,6 @@ def inject_profiling_into_existing_test(
ast.Import(names=[ast.alias(name="dill", asname="pickle")]),
]
)
if test_framework == "unittest" and platform.system() != "Windows":
new_imports.append(ast.Import(names=[ast.alias(name="timeout_decorator")]))
additional_functions = [create_wrapper_function(mode)]

tree.body = [*new_imports, *additional_functions, *tree.body]
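The deletions above remove the unittest-only timeout instrumentation: an ast.NodeTransformer pass that appended `@timeout_decorator.timeout(15)` to every `test_`-prefixed function (skipped on Windows, where the SIGALRM signal the decorator relies on is unavailable), plus the matching `import timeout_decorator` added to instrumented test modules. A minimal, self-contained sketch of that removed mechanism follows; `TimeoutInjector` and the sample source are illustrative stand-ins, not the project's code.

```python
import ast


class TimeoutInjector(ast.NodeTransformer):
    """Append @timeout_decorator.timeout(15) to unittest-style test functions."""

    def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.FunctionDef:
        if node.name.startswith("test_"):
            node.decorator_list.append(
                ast.Call(
                    func=ast.Name(id="timeout_decorator.timeout", ctx=ast.Load()),
                    args=[ast.Constant(value=15)],
                    keywords=[],
                )
            )
        return node


sample = """
import unittest

class TestBubbleSort(unittest.TestCase):
    def test_sorts(self):
        self.assertEqual(sorted([3, 1, 2]), [1, 2, 3])
"""

tree = TimeoutInjector().visit(ast.parse(sample))
ast.fix_missing_locations(tree)  # new nodes need locations if the tree is compiled later
print(ast.unparse(tree))  # test_sorts now carries @timeout_decorator.timeout(15)
```

With this PR the decorator is no longer injected, so the `timeout_decorator` import and the platform check go away with it.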
153 changes: 49 additions & 104 deletions codeflash/optimization/function_optimizer.py
@@ -6,7 +6,6 @@
import queue
import random
import subprocess
import time
import uuid
from collections import defaultdict
from pathlib import Path
@@ -1641,57 +1640,34 @@ def establish_original_code_baseline(
f"Test coverage is {coverage_results.coverage}%, which is below the required threshold of {COVERAGE_THRESHOLD}%."
)

if test_framework == "pytest":
with progress_bar("Running line profiler to identify performance bottlenecks..."):
line_profile_results = self.line_profiler_step(
code_context=code_context, original_helper_code=original_helper_code, candidate_index=0
)
console.rule()
with progress_bar("Running performance benchmarks..."):
if self.function_to_optimize.is_async:
from codeflash.code_utils.instrument_existing_tests import add_async_decorator_to_function
with progress_bar("Running line profiler to identify performance bottlenecks..."):
line_profile_results = self.line_profiler_step(
code_context=code_context, original_helper_code=original_helper_code, candidate_index=0
)
console.rule()
with progress_bar("Running performance benchmarks..."):
if self.function_to_optimize.is_async:
from codeflash.code_utils.instrument_existing_tests import add_async_decorator_to_function

add_async_decorator_to_function(
self.function_to_optimize.file_path, self.function_to_optimize, TestingMode.PERFORMANCE
)
add_async_decorator_to_function(
self.function_to_optimize.file_path, self.function_to_optimize, TestingMode.PERFORMANCE
)

try:
benchmarking_results, _ = self.run_and_parse_tests(
testing_type=TestingMode.PERFORMANCE,
test_env=test_env,
test_files=self.test_files,
optimization_iteration=0,
testing_time=total_looping_time,
enable_coverage=False,
code_context=code_context,
)
finally:
if self.function_to_optimize.is_async:
self.write_code_and_helpers(
self.function_to_optimize_source_code,
original_helper_code,
self.function_to_optimize.file_path,
)
else:
benchmarking_results = TestResults()
start_time: float = time.time()
for i in range(100):
if i >= 5 and time.time() - start_time >= total_looping_time * 1.5:
# * 1.5 to give unittest a bit more time to run
break
test_env["CODEFLASH_LOOP_INDEX"] = str(i + 1)
with progress_bar("Running performance benchmarks..."):
unittest_loop_results, _ = self.run_and_parse_tests(
testing_type=TestingMode.PERFORMANCE,
test_env=test_env,
test_files=self.test_files,
optimization_iteration=0,
testing_time=total_looping_time,
enable_coverage=False,
code_context=code_context,
unittest_loop_index=i + 1,
try:
benchmarking_results, _ = self.run_and_parse_tests(
testing_type=TestingMode.PERFORMANCE,
test_env=test_env,
test_files=self.test_files,
optimization_iteration=0,
testing_time=total_looping_time,
enable_coverage=False,
code_context=code_context,
)
finally:
if self.function_to_optimize.is_async:
self.write_code_and_helpers(
self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path
)
benchmarking_results.merge(unittest_loop_results)

console.print(
TestResults.report_to_tree(
@@ -1760,8 +1736,6 @@ def run_optimized_candidate(
original_helper_code: dict[Path, str],
file_path_to_helper_classes: dict[Path, set[str]],
) -> Result[OptimizedCandidateResult, str]:
assert (test_framework := self.args.test_framework) in {"pytest", "unittest"} # noqa: RUF018

with progress_bar("Testing optimization candidate"):
test_env = self.get_test_env(
codeflash_loop_index=0,
@@ -1818,59 +1792,34 @@

logger.info(f"loading|Running performance tests for candidate {optimization_candidate_index}...")

if test_framework == "pytest":
# For async functions, instrument at definition site for performance benchmarking
if self.function_to_optimize.is_async:
from codeflash.code_utils.instrument_existing_tests import add_async_decorator_to_function

add_async_decorator_to_function(
self.function_to_optimize.file_path, self.function_to_optimize, TestingMode.PERFORMANCE
)
# For async functions, instrument at definition site for performance benchmarking
if self.function_to_optimize.is_async:
from codeflash.code_utils.instrument_existing_tests import add_async_decorator_to_function

try:
candidate_benchmarking_results, _ = self.run_and_parse_tests(
testing_type=TestingMode.PERFORMANCE,
test_env=test_env,
test_files=self.test_files,
optimization_iteration=optimization_candidate_index,
testing_time=total_looping_time,
enable_coverage=False,
)
finally:
# Restore original source if we instrumented it
if self.function_to_optimize.is_async:
self.write_code_and_helpers(
candidate_fto_code, candidate_helper_code, self.function_to_optimize.file_path
)
loop_count = (
max(all_loop_indices)
if (
all_loop_indices := {
result.loop_index for result in candidate_benchmarking_results.test_results
}
)
else 0
add_async_decorator_to_function(
self.function_to_optimize.file_path, self.function_to_optimize, TestingMode.PERFORMANCE
)

else:
candidate_benchmarking_results = TestResults()
start_time: float = time.time()
loop_count = 0
for i in range(100):
if i >= 5 and time.time() - start_time >= TOTAL_LOOPING_TIME_EFFECTIVE * 1.5:
# * 1.5 to give unittest a bit more time to run
break
test_env["CODEFLASH_LOOP_INDEX"] = str(i + 1)
unittest_loop_results, _cov = self.run_and_parse_tests(
testing_type=TestingMode.PERFORMANCE,
test_env=test_env,
test_files=self.test_files,
optimization_iteration=optimization_candidate_index,
testing_time=TOTAL_LOOPING_TIME_EFFECTIVE,
unittest_loop_index=i + 1,
try:
candidate_benchmarking_results, _ = self.run_and_parse_tests(
testing_type=TestingMode.PERFORMANCE,
test_env=test_env,
test_files=self.test_files,
optimization_iteration=optimization_candidate_index,
testing_time=total_looping_time,
enable_coverage=False,
)
finally:
# Restore original source if we instrumented it
if self.function_to_optimize.is_async:
self.write_code_and_helpers(
candidate_fto_code, candidate_helper_code, self.function_to_optimize.file_path
)
loop_count = i + 1
candidate_benchmarking_results.merge(unittest_loop_results)
loop_count = (
max(all_loop_indices)
if (all_loop_indices := {result.loop_index for result in candidate_benchmarking_results.test_results})
else 0
)

if (total_candidate_timing := candidate_benchmarking_results.total_passed_runtime()) == 0:
logger.warning("The overall test runtime of the optimized function is 0, couldn't run tests.")
@@ -1920,7 +1869,6 @@ def run_and_parse_tests(
pytest_min_loops: int = 5,
pytest_max_loops: int = 250,
code_context: CodeOptimizationContext | None = None,
unittest_loop_index: int | None = None,
line_profiler_output_file: Path | None = None,
) -> tuple[TestResults | dict, CoverageData | None]:
coverage_database_file = None
@@ -1933,7 +1881,6 @@
cwd=self.project_root,
test_env=test_env,
pytest_timeout=INDIVIDUAL_TESTCASE_TIMEOUT,
verbose=True,
enable_coverage=enable_coverage,
)
elif testing_type == TestingMode.LINE_PROFILE:
@@ -1947,7 +1894,6 @@
pytest_min_loops=1,
pytest_max_loops=1,
test_framework=self.test_cfg.test_framework,
line_profiler_output_file=line_profiler_output_file,
)
elif testing_type == TestingMode.PERFORMANCE:
result_file_path, run_result = run_benchmarking_tests(
@@ -1996,7 +1942,6 @@
test_config=self.test_cfg,
optimization_iteration=optimization_iteration,
run_result=run_result,
unittest_loop_index=unittest_loop_index,
function_name=self.function_to_optimize.function_name,
source_file=self.function_to_optimize.file_path,
code_context=code_context,
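For context on the branches deleted throughout this file: under unittest, `establish_original_code_baseline` and `run_optimized_candidate` previously drove the benchmarking loop themselves, re-running the suite with an incrementing `CODEFLASH_LOOP_INDEX`, running at least five iterations before checking a 1.5x time budget, then merging the per-loop results. A rough, self-contained sketch of that removed pattern; `run_one_loop` is a hypothetical stand-in for `self.run_and_parse_tests`, and plain lists stand in for `TestResults`:

```python
import time


def loop_benchmarks(run_one_loop, budget_s: float):
    """Sketch of the removed unittest benchmarking loop (illustrative only)."""
    merged = []  # the real code merged TestResults objects instead
    loop_count = 0
    start = time.time()
    for i in range(100):
        # After 5 iterations, stop once ~1.5x the time budget has elapsed
        # ("to give unittest a bit more time to run", per the old comment).
        if i >= 5 and time.time() - start >= budget_s * 1.5:
            break
        merged.extend(run_one_loop(i + 1))  # loop index was exported as CODEFLASH_LOOP_INDEX
        loop_count = i + 1
    return merged, loop_count


# Toy usage: each "loop" pretends to produce one timing sample.
results, loops = loop_benchmarks(lambda idx: [(f"loop_{idx}", 0.001)], budget_s=0.01)
print(loops, len(results))
```

With that branch gone, both code paths use the single performance-benchmarking call shown in the diff, and the `unittest_loop_index` plumbing in `run_and_parse_tests` is dropped along with it.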