Skip to content

Commit 99cb908

Browse files
test & correct last import index
Signed-off-by: ali <mohammed18200118@gmail.com>
1 parent e60608e commit 99cb908

File tree

2 files changed

+233
-2
lines changed

2 files changed

+233
-2
lines changed

codeflash/code_utils/code_extractor.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,7 @@ def leave_Assign(self, original_node: cst.Assign, updated_node: cst.Assign) -> c
121121
return updated_node
122122

123123
def _find_insertion_index(self, updated_node: cst.Module) -> int:
124+
"""Find the position of the last import statement in the top-level of the module."""
124125
insert_index = 0
125126
for i, stmt in enumerate(updated_node.body):
126127
is_top_level_import = isinstance(stmt, cst.SimpleStatementLine) and any(
@@ -135,9 +136,14 @@ def _find_insertion_index(self, updated_node: cst.Module) -> int:
135136

136137
if is_top_level_import or is_conditional_import:
137138
insert_index = i + 1
138-
else:
139-
# stop when we find the first non-import statement
139+
140+
# Stop scanning once we reach a class or function definition.
141+
# Imports are supposed to be at the top of the file, but they can technically appear anywhere, even at the bottom of the file.
142+
# Without this check, a stray import later in the file
143+
# would incorrectly shift our insertion index below actual code definitions.
144+
if isinstance(stmt, (cst.ClassDef, cst.FunctionDef)):
140145
break
146+
141147
return insert_index
142148

143149
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module:

tests/test_code_replacement.py

Lines changed: 225 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3228,3 +3228,228 @@ def _map_tool_definition(f: ToolDefinition) -> ChatCompletionInputTool:
32283228
assert not re.search(r"^import aiohttp as aiohttp_\b", new_code, re.MULTILINE) # conditional alias import: import <name> as <alias>
32293229
assert not re.search(r"^from math import pi as PI, sin as sine\b", new_code, re.MULTILINE) # conditional multiple aliases imports
32303230
assert "from huggingface_hub import AsyncInferenceClient, ChatCompletionInputTool" not in new_code # conditional from import
3231+
3232+
def test_top_level_global_assignments() -> None:
    """Check where new imports and module-level assignments are inserted.

    The optimized candidate adds ``import re`` and a module-level
    ``_INTENTION_CLEANUP_RE`` assignment. After replacement, both are expected
    to sit directly after the module's existing top-level imports and before
    the pre-existing global assignments (``LOG``, ``prompt_engine``), with the
    target function body swapped for the optimized one.
    """
    root_dir = Path(__file__).parent.parent.resolve()
    main_file = Path(root_dir / "code_to_optimize/temp_main.py").resolve()

    original_code = '''"""
Module for generating GeneratedWorkflowParameters schema from workflow run input_text actions.
"""

from typing import Any, Dict, List, Tuple

import structlog
from pydantic import BaseModel

from skyvern.forge import app
from skyvern.forge.sdk.prompting import PromptEngine
from skyvern.webeye.actions.actions import ActionType

LOG = structlog.get_logger(__name__)

# Initialize prompt engine
prompt_engine = PromptEngine("skyvern")


def hydrate_input_text_actions_with_field_names(
    actions_by_task: Dict[str, List[Dict[str, Any]]], field_mappings: Dict[str, str]
) -> Dict[str, List[Dict[str, Any]]]:
    """
    Add field_name to input_text actions based on generated mappings.

    Args:
        actions_by_task: Dictionary mapping task IDs to lists of action dictionaries
        field_mappings: Dictionary mapping "task_id:action_id" to field names

    Returns:
        Updated actions_by_task with field_name added to input_text actions
    """
    updated_actions_by_task = {}

    for task_id, actions in actions_by_task.items():
        updated_actions = []

        for action in actions:
            action_copy = action.copy()

            if action.get("action_type") == ActionType.INPUT_TEXT:
                action_id = action.get("action_id", "")
                mapping_key = f"{task_id}:{action_id}"

                if mapping_key in field_mappings:
                    action_copy["field_name"] = field_mappings[mapping_key]
                else:
                    # Fallback field name if mapping not found
                    intention = action.get("intention", "")
                    if intention:
                        # Simple field name generation from intention
                        field_name = intention.lower().replace(" ", "_").replace("?", "").replace("'", "")
                        field_name = "".join(c for c in field_name if c.isalnum() or c == "_")
                        action_copy["field_name"] = field_name or "unknown_field"
                    else:
                        action_copy["field_name"] = "unknown_field"

            updated_actions.append(action_copy)

        updated_actions_by_task[task_id] = updated_actions

    return updated_actions_by_task
'''
    # Building the candidate only formats the path; it does not touch the file system.
    optim_code = f'''```python:{main_file.relative_to(root_dir)}
from skyvern.webeye.actions.actions import ActionType
from typing import Any, Dict, List
import re

# Precompiled regex for efficiently generating simple field_name from intention
_INTENTION_CLEANUP_RE = re.compile(r"[^a-zA-Z0-9_]+")

def hydrate_input_text_actions_with_field_names(
    actions_by_task: Dict[str, List[Dict[str, Any]]], field_mappings: Dict[str, str]
) -> Dict[str, List[Dict[str, Any]]]:
    """
    Add field_name to input_text actions based on generated mappings.

    Args:
        actions_by_task: Dictionary mapping task IDs to lists of action dictionaries
        field_mappings: Dictionary mapping "task_id:action_id" to field names

    Returns:
        Updated actions_by_task with field_name added to input_text actions
    """
    updated_actions_by_task = {{}}

    input_text_type = ActionType.INPUT_TEXT # local variable for faster access
    intention_cleanup = _INTENTION_CLEANUP_RE

    for task_id, actions in actions_by_task.items():
        updated_actions = []

        for action in actions:
            action_copy = action.copy()

            if action.get("action_type") == input_text_type:
                action_id = action.get("action_id", "")
                mapping_key = f"{{task_id}}:{{action_id}}"

                if mapping_key in field_mappings:
                    action_copy["field_name"] = field_mappings[mapping_key]
                else:
                    # Fallback field name if mapping not found
                    intention = action.get("intention", "")
                    if intention:
                        # Simple field name generation from intention
                        field_name = intention.lower().replace(" ", "_").replace("?", "").replace("'", "")
                        # Use compiled regex instead of "".join(c for ...)
                        field_name = intention_cleanup.sub("", field_name)
                        action_copy["field_name"] = field_name or "unknown_field"
                    else:
                        action_copy["field_name"] = "unknown_field"

            updated_actions.append(action_copy)

        updated_actions_by_task[task_id] = updated_actions

    return updated_actions_by_task
```
'''
    expected = '''"""
Module for generating GeneratedWorkflowParameters schema from workflow run input_text actions.
"""

from typing import Any, Dict, List, Tuple

import structlog
from pydantic import BaseModel

from skyvern.forge import app
from skyvern.forge.sdk.prompting import PromptEngine
from skyvern.webeye.actions.actions import ActionType
import re

_INTENTION_CLEANUP_RE = re.compile(r"[^a-zA-Z0-9_]+")

LOG = structlog.get_logger(__name__)

# Initialize prompt engine
prompt_engine = PromptEngine("skyvern")


def hydrate_input_text_actions_with_field_names(
    actions_by_task: Dict[str, List[Dict[str, Any]]], field_mappings: Dict[str, str]
) -> Dict[str, List[Dict[str, Any]]]:
    """
    Add field_name to input_text actions based on generated mappings.

    Args:
        actions_by_task: Dictionary mapping task IDs to lists of action dictionaries
        field_mappings: Dictionary mapping "task_id:action_id" to field names

    Returns:
        Updated actions_by_task with field_name added to input_text actions
    """
    updated_actions_by_task = {}

    input_text_type = ActionType.INPUT_TEXT # local variable for faster access
    intention_cleanup = _INTENTION_CLEANUP_RE

    for task_id, actions in actions_by_task.items():
        updated_actions = []

        for action in actions:
            action_copy = action.copy()

            if action.get("action_type") == input_text_type:
                action_id = action.get("action_id", "")
                mapping_key = f"{task_id}:{action_id}"

                if mapping_key in field_mappings:
                    action_copy["field_name"] = field_mappings[mapping_key]
                else:
                    # Fallback field name if mapping not found
                    intention = action.get("intention", "")
                    if intention:
                        # Simple field name generation from intention
                        field_name = intention.lower().replace(" ", "_").replace("?", "").replace("'", "")
                        # Use compiled regex instead of "".join(c for ...)
                        field_name = intention_cleanup.sub("", field_name)
                        action_copy["field_name"] = field_name or "unknown_field"
                    else:
                        action_copy["field_name"] = "unknown_field"

            updated_actions.append(action_copy)

        updated_actions_by_task[task_id] = updated_actions

    return updated_actions_by_task
'''

    main_file.write_text(original_code, encoding="utf-8")
    try:
        func = FunctionToOptimize(
            function_name="hydrate_input_text_actions_with_field_names", parents=[], file_path=main_file
        )
        test_config = TestConfig(
            tests_root=root_dir / "tests/pytest",
            tests_project_rootdir=root_dir,
            project_root_path=root_dir,
            test_framework="pytest",
            pytest_cmd="pytest",
        )
        func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
        code_context: CodeOptimizationContext = func_optimizer.get_code_optimization_context().unwrap()

        # Snapshot every helper file once, keyed by path, before anything is rewritten.
        original_helper_code: dict[Path, str] = {
            helper_path: helper_path.read_text(encoding="utf8")
            for helper_path in {hf.file_path for hf in code_context.helper_functions}
        }

        func_optimizer.args = Args()
        func_optimizer.replace_function_and_helpers_with_optimized_code(
            code_context=code_context,
            optimized_code=CodeStringsMarkdown.parse_markdown_code(optim_code),
            original_helper_code=original_helper_code,
        )

        new_code = main_file.read_text(encoding="utf-8")
    finally:
        # Always remove the scratch module, even when the optimizer raises,
        # so a failing run does not leave temp_main.py behind in code_to_optimize/.
        main_file.unlink(missing_ok=True)

    assert new_code == expected

0 commit comments

Comments
 (0)