|
15 | 15 |
|
16 | 16 | if TYPE_CHECKING: |
17 | 17 | from codeflash.discovery.functions_to_optimize import FunctionToOptimize |
18 | | - from codeflash.models.models import CodeOptimizationContext, FunctionParent, FunctionSource |
| 18 | + from codeflash.models.models import CodeOptimizationContext, FunctionSource |
19 | 19 |
|
20 | 20 |
|
21 | 21 | @dataclass |
@@ -615,23 +615,29 @@ def _analyze_imports_in_optimized_code( |
615 | 615 | def find_target_node( |
616 | 616 | root: ast.AST, function_to_optimize: FunctionToOptimize |
617 | 617 | ) -> Optional[ast.FunctionDef | ast.AsyncFunctionDef]: |
618 | | - def _find(node: ast.AST, parents: list[FunctionParent]) -> Optional[ast.FunctionDef | ast.AsyncFunctionDef]: |
619 | | - if not parents: |
620 | | - for child in getattr(node, "body", []): |
621 | | - if ( |
622 | | - isinstance(child, (ast.FunctionDef, ast.AsyncFunctionDef)) |
623 | | - and child.name == function_to_optimize.function_name |
624 | | - ): |
625 | | - return child |
| 618 | + parents = function_to_optimize.parents |
| 619 | + node = root |
| 620 | + for parent in parents: |
| 621 | + # Fast loop: directly look for the matching ClassDef in node.body |
| 622 | + body = getattr(node, "body", None) |
| 623 | + if not body: |
626 | 624 | return None |
627 | | - |
628 | | - parent = parents[0] |
629 | | - for child in getattr(node, "body", []): |
| 625 | + for child in body: |
630 | 626 | if isinstance(child, ast.ClassDef) and child.name == parent.name: |
631 | | - return _find(child, parents[1:]) |
632 | | - return None |
| 627 | + node = child |
| 628 | + break |
| 629 | + else: |
| 630 | + return None |
633 | 631 |
|
634 | | - return _find(root, function_to_optimize.parents) |
| 632 | + # Now node is either the root or the target parent class; look for function |
| 633 | + body = getattr(node, "body", None) |
| 634 | + if not body: |
| 635 | + return None |
| 636 | + target_name = function_to_optimize.function_name |
| 637 | + for child in body: |
| 638 | + if isinstance(child, (ast.FunctionDef, ast.AsyncFunctionDef)) and child.name == target_name: |
| 639 | + return child |
| 640 | + return None |
635 | 641 |
|
636 | 642 |
|
637 | 643 | def detect_unused_helper_functions( |
|
0 commit comments