Skip to content

Commit ce19abf

Browse files
authored
Merge branch 'main' into small-fixes
2 parents dca0374 + b28521a commit ce19abf

File tree

12 files changed

+1409
-857
lines changed

12 files changed

+1409
-857
lines changed

codeflash/cli_cmds/cli.py

Lines changed: 23 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -249,25 +249,29 @@ def project_root_from_module_root(module_root: Path, pyproject_file_path: Path)
249249

250250

251251
def handle_optimize_all_arg_parsing(args: Namespace) -> Namespace:
252-
if hasattr(args, "all"):
253-
import git
254-
255-
from codeflash.code_utils.git_utils import check_and_push_branch, get_repo_owner_and_name
256-
from codeflash.code_utils.github_utils import require_github_app_or_exit
257-
258-
# Ensure that the user can actually open PRs on the repo.
259-
try:
260-
git_repo = git.Repo(search_parent_directories=True)
261-
except git.exc.InvalidGitRepositoryError:
262-
logger.exception(
263-
"I couldn't find a git repository in the current directory. "
264-
"I need a git repository to run --all and open PRs for optimizations. Exiting..."
265-
)
266-
apologize_and_exit()
267-
if not args.no_pr and not check_and_push_branch(git_repo, git_remote=args.git_remote):
268-
exit_with_message("Branch is not pushed...", error_on_exit=True)
269-
owner, repo = get_repo_owner_and_name(git_repo)
270-
if not args.no_pr:
252+
if hasattr(args, "all") or (hasattr(args, "file") and args.file):
253+
no_pr = getattr(args, "no_pr", False)
254+
255+
if not no_pr:
256+
import git
257+
258+
from codeflash.code_utils.git_utils import check_and_push_branch, get_repo_owner_and_name
259+
from codeflash.code_utils.github_utils import require_github_app_or_exit
260+
261+
# Ensure that the user can actually open PRs on the repo.
262+
try:
263+
git_repo = git.Repo(search_parent_directories=True)
264+
except git.exc.InvalidGitRepositoryError:
265+
mode = "--all" if hasattr(args, "all") else "--file"
266+
logger.exception(
267+
f"I couldn't find a git repository in the current directory. "
268+
f"I need a git repository to run {mode} and open PRs for optimizations. Exiting..."
269+
)
270+
apologize_and_exit()
271+
git_remote = getattr(args, "git_remote", None)
272+
if not check_and_push_branch(git_repo, git_remote=git_remote):
273+
exit_with_message("Branch is not pushed...", error_on_exit=True)
274+
owner, repo = get_repo_owner_and_name(git_repo)
271275
require_github_app_or_exit(owner, repo)
272276
if not hasattr(args, "all"):
273277
args.all = None

codeflash/code_utils/instrument_existing_tests.py

Lines changed: 19 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -684,27 +684,6 @@ def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
684684
)
685685

686686

687-
def instrument_source_module_with_async_decorators(
688-
source_path: Path, function_to_optimize: FunctionToOptimize, mode: TestingMode = TestingMode.BEHAVIOR
689-
) -> tuple[bool, str | None]:
690-
if not function_to_optimize.is_async:
691-
return False, None
692-
693-
try:
694-
with source_path.open(encoding="utf8") as f:
695-
source_code = f.read()
696-
697-
modified_code, decorator_added = add_async_decorator_to_function(source_code, function_to_optimize, mode)
698-
699-
if decorator_added:
700-
return True, modified_code
701-
702-
except Exception:
703-
return False, None
704-
else:
705-
return False, None
706-
707-
708687
def inject_async_profiling_into_existing_test(
709688
test_path: Path,
710689
call_positions: list[CodePosition],
@@ -1288,25 +1267,29 @@ def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> c
12881267

12891268

12901269
def add_async_decorator_to_function(
1291-
source_code: str, function: FunctionToOptimize, mode: TestingMode = TestingMode.BEHAVIOR
1292-
) -> tuple[str, bool]:
1293-
"""Add async decorator to an async function definition.
1270+
source_path: Path, function: FunctionToOptimize, mode: TestingMode = TestingMode.BEHAVIOR
1271+
) -> bool:
1272+
"""Add async decorator to an async function definition and write back to file.
12941273
12951274
Args:
12961275
----
1297-
source_code: The source code to modify.
1276+
source_path: Path to the source file to modify in-place.
12981277
function: The FunctionToOptimize object representing the target async function.
12991278
mode: The testing mode to determine which decorator to apply.
13001279
13011280
Returns:
13021281
-------
1303-
Tuple of (modified_source_code, was_decorator_added).
1282+
Boolean indicating whether the decorator was successfully added.
13041283
13051284
"""
13061285
if not function.is_async:
1307-
return source_code, False
1286+
return False
13081287

13091288
try:
1289+
# Read source code
1290+
with source_path.open(encoding="utf8") as f:
1291+
source_code = f.read()
1292+
13101293
module = cst.parse_module(source_code)
13111294

13121295
# Add the decorator to the function
@@ -1318,10 +1301,17 @@ def add_async_decorator_to_function(
13181301
import_transformer = AsyncDecoratorImportAdder(mode)
13191302
module = module.visit(import_transformer)
13201303

1321-
return sort_imports(code=module.code, float_to_top=True), decorator_transformer.added_decorator
1304+
modified_code = sort_imports(code=module.code, float_to_top=True)
13221305
except Exception as e:
13231306
logger.exception(f"Error adding async decorator to function {function.qualified_name}: {e}")
1324-
return source_code, False
1307+
return False
1308+
else:
1309+
if decorator_transformer.added_decorator:
1310+
with source_path.open("w", encoding="utf8") as f:
1311+
f.write(modified_code)
1312+
logger.debug(f"Applied async {mode.value} instrumentation to {source_path}")
1313+
return True
1314+
return False
13251315

13261316

13271317
def create_instrumented_source_module_path(source_path: Path, temp_dir: Path) -> Path:

codeflash/context/unused_definition_remover.py

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -469,22 +469,32 @@ def remove_unused_definitions_by_function_names(code: str, qualified_function_na
469469
qualified_function_names: Set of function names to keep. For methods, use format 'classname.methodname'
470470
471471
"""
472-
module = cst.parse_module(code)
473-
# Collect all definitions (top level classes, variables or function)
474-
definitions = collect_top_level_definitions(module)
472+
try:
473+
module = cst.parse_module(code)
474+
except Exception as e:
475+
logger.debug(f"Failed to parse code with libcst: {type(e).__name__}: {e}")
476+
return code
475477

476-
# Collect dependencies between definitions using the visitor pattern
477-
dependency_collector = DependencyCollector(definitions)
478-
module.visit(dependency_collector)
478+
try:
479+
# Collect all definitions (top level classes, variables or function)
480+
definitions = collect_top_level_definitions(module)
479481

480-
# Mark definitions used by specified functions, and their dependencies recursively
481-
usage_marker = QualifiedFunctionUsageMarker(definitions, qualified_function_names)
482-
usage_marker.mark_used_definitions()
482+
# Collect dependencies between definitions using the visitor pattern
483+
dependency_collector = DependencyCollector(definitions)
484+
module.visit(dependency_collector)
483485

484-
# Apply the recursive removal transformation
485-
modified_module, _ = remove_unused_definitions_recursively(module, definitions)
486+
# Mark definitions used by specified functions, and their dependencies recursively
487+
usage_marker = QualifiedFunctionUsageMarker(definitions, qualified_function_names)
488+
usage_marker.mark_used_definitions()
486489

487-
return modified_module.code if modified_module else ""
490+
# Apply the recursive removal transformation
491+
modified_module, _ = remove_unused_definitions_recursively(module, definitions)
492+
493+
return modified_module.code if modified_module else "" # noqa: TRY300
494+
except Exception as e:
495+
# If any other error occurs during processing, return the original code
496+
logger.debug(f"Error processing code to remove unused definitions: {type(e).__name__}: {e}")
497+
return code
488498

489499

490500
def print_definitions(definitions: dict[str, UsageInfo]) -> None:

codeflash/github/PrComment.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,15 +21,17 @@ class PrComment:
2121
winning_behavior_test_results: TestResults
2222
winning_benchmarking_test_results: TestResults
2323
benchmark_details: Optional[list[BenchmarkDetail]] = None
24+
original_async_throughput: Optional[int] = None
25+
best_async_throughput: Optional[int] = None
2426

25-
def to_json(self) -> dict[str, Union[dict[str, dict[str, int]], int, str, Optional[list[BenchmarkDetail]]]]:
27+
def to_json(self) -> dict[str, Union[str, int, dict[str, dict[str, int]], list[BenchmarkDetail], None]]:
2628
report_table = {
2729
test_type.to_name(): result
2830
for test_type, result in self.winning_behavior_test_results.get_test_pass_fail_report_by_type().items()
2931
if test_type.to_name()
3032
}
3133

32-
return {
34+
result: dict[str, Union[str, int, dict[str, dict[str, int]], list[BenchmarkDetail], None]] = {
3335
"optimization_explanation": self.optimization_explanation,
3436
"best_runtime": humanize_runtime(self.best_runtime),
3537
"original_runtime": humanize_runtime(self.original_runtime),
@@ -42,6 +44,12 @@ def to_json(self) -> dict[str, Union[dict[str, dict[str, int]], int, str, Option
4244
"benchmark_details": self.benchmark_details if self.benchmark_details else None,
4345
}
4446

47+
if self.original_async_throughput is not None and self.best_async_throughput is not None:
48+
result["original_async_throughput"] = str(self.original_async_throughput)
49+
result["best_async_throughput"] = str(self.best_async_throughput)
50+
51+
return result
52+
4553

4654
class FileDiffContent(BaseModel):
4755
oldContent: str # noqa: N815

0 commit comments

Comments
 (0)