Skip to content

Commit 35378a3

Browse files
committed
Merge remote-tracking branch 'origin/main' into cf-842
2 parents d69fc2a + 0168944 commit 35378a3

23 files changed

+1631
-883
lines changed

codeflash/cli_cmds/cli.py

Lines changed: 23 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -249,25 +249,29 @@ def project_root_from_module_root(module_root: Path, pyproject_file_path: Path)
249249

250250

251251
def handle_optimize_all_arg_parsing(args: Namespace) -> Namespace:
252-
if hasattr(args, "all"):
253-
import git
254-
255-
from codeflash.code_utils.git_utils import check_and_push_branch, get_repo_owner_and_name
256-
from codeflash.code_utils.github_utils import require_github_app_or_exit
257-
258-
# Ensure that the user can actually open PRs on the repo.
259-
try:
260-
git_repo = git.Repo(search_parent_directories=True)
261-
except git.exc.InvalidGitRepositoryError:
262-
logger.exception(
263-
"I couldn't find a git repository in the current directory. "
264-
"I need a git repository to run --all and open PRs for optimizations. Exiting..."
265-
)
266-
apologize_and_exit()
267-
if not args.no_pr and not check_and_push_branch(git_repo, git_remote=args.git_remote):
268-
exit_with_message("Branch is not pushed...", error_on_exit=True)
269-
owner, repo = get_repo_owner_and_name(git_repo)
270-
if not args.no_pr:
252+
if hasattr(args, "all") or (hasattr(args, "file") and args.file):
253+
no_pr = getattr(args, "no_pr", False)
254+
255+
if not no_pr:
256+
import git
257+
258+
from codeflash.code_utils.git_utils import check_and_push_branch, get_repo_owner_and_name
259+
from codeflash.code_utils.github_utils import require_github_app_or_exit
260+
261+
# Ensure that the user can actually open PRs on the repo.
262+
try:
263+
git_repo = git.Repo(search_parent_directories=True)
264+
except git.exc.InvalidGitRepositoryError:
265+
mode = "--all" if hasattr(args, "all") else "--file"
266+
logger.exception(
267+
f"I couldn't find a git repository in the current directory. "
268+
f"I need a git repository to run {mode} and open PRs for optimizations. Exiting..."
269+
)
270+
apologize_and_exit()
271+
git_remote = getattr(args, "git_remote", None)
272+
if not check_and_push_branch(git_repo, git_remote=git_remote):
273+
exit_with_message("Branch is not pushed...", error_on_exit=True)
274+
owner, repo = get_repo_owner_and_name(git_repo)
271275
require_github_app_or_exit(owner, repo)
272276
if not hasattr(args, "all"):
273277
args.all = None

codeflash/cli_cmds/cmd_init.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from rich.table import Table
2323
from rich.text import Text
2424

25-
from codeflash.api.cfapi import is_github_app_installed_on_repo
25+
from codeflash.api.cfapi import get_user_id, is_github_app_installed_on_repo
2626
from codeflash.cli_cmds.cli_common import apologize_and_exit
2727
from codeflash.cli_cmds.console import console, logger
2828
from codeflash.cli_cmds.extension import install_vscode_extension
@@ -1216,6 +1216,7 @@ def enter_api_key_and_save_to_rc() -> None:
12161216
# On Windows, create a batch file in the user's home directory (not auto-run, just used to store api key)
12171217
shell_rc_path.touch()
12181218
click.echo(f"✅ Created {shell_rc_path}")
1219+
get_user_id(api_key=api_key) # Used to verify whether the API key is valid.
12191220
result = save_api_key_to_rc(api_key)
12201221
if is_successful(result):
12211222
click.echo(result.unwrap())
@@ -1373,7 +1374,7 @@ def ask_for_telemetry() -> bool:
13731374
from rich.prompt import Confirm
13741375

13751376
return Confirm.ask(
1376-
"⚡️ Would you like to enable telemetry to help us improve the Codeflash experience?",
1377+
"⚡️ Help us improve Codeflash by sharing anonymous usage data (e.g. errors encountered)?",
13771378
default=True,
13781379
show_default=True,
13791380
)

codeflash/code_utils/config_parser.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ def parse_config_file(
105105
if lsp_mode:
106106
# don't fail in lsp mode if codeflash config is not found.
107107
return {}, config_file_path
108-
msg = f"Could not find the 'codeflash' block in the config file {config_file_path}. Please run 'codeflash init' to create the config file."
108+
msg = f"Could not find the 'codeflash' block in the config file {config_file_path}. Please run 'codeflash init' to add Codeflash config in the pyproject.toml config file."
109109
raise ValueError(msg) from e
110110
assert isinstance(config, dict)
111111

codeflash/code_utils/formatter.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,9 @@ def is_diff_line(line: str) -> bool:
9797

9898

9999
def format_generated_code(generated_test_source: str, formatter_cmds: list[str]) -> str:
100+
formatter_name = formatter_cmds[0].lower() if formatter_cmds else "disabled"
101+
if formatter_name == "disabled": # nothing to do if no formatter provided
102+
return re.sub(r"\n{2,}", "\n\n", generated_test_source)
100103
with tempfile.TemporaryDirectory() as test_dir_str:
101104
# try running formatter, if nothing changes (could be due to formatting failing or no actual formatting needed) return code with 2 or more newlines substituted with 2 newlines
102105
original_temp = Path(test_dir_str) / "original_temp.py"

codeflash/code_utils/git_worktree_utils.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import configparser
34
import subprocess
45
import tempfile
56
import time
@@ -18,14 +19,36 @@
1819

1920
def create_worktree_snapshot_commit(worktree_dir: Path, commit_message: str) -> None:
2021
repository = git.Repo(worktree_dir, search_parent_directories=True)
21-
with repository.config_writer() as cw:
22+
username = None
23+
no_username = False
24+
email = None
25+
no_email = False
26+
with repository.config_reader(config_level="repository") as cr:
27+
try:
28+
username = cr.get("user", "name")
29+
except (configparser.NoSectionError, configparser.NoOptionError):
30+
no_username = True
31+
try:
32+
email = cr.get("user", "email")
33+
except (configparser.NoSectionError, configparser.NoOptionError):
34+
no_email = True
35+
with repository.config_writer(config_level="repository") as cw:
2236
if not cw.has_option("user", "name"):
2337
cw.set_value("user", "name", "Codeflash Bot")
2438
if not cw.has_option("user", "email"):
2539
cw.set_value("user", "email", "bot@codeflash.ai")
2640

2741
repository.git.add(".")
2842
repository.git.commit("-m", commit_message, "--no-verify")
43+
with repository.config_writer(config_level="repository") as cw:
44+
if username:
45+
cw.set_value("user", "name", username)
46+
elif no_username:
47+
cw.remove_option("user", "name")
48+
if email:
49+
cw.set_value("user", "email", email)
50+
elif no_email:
51+
cw.remove_option("user", "email")
2952

3053

3154
def create_detached_worktree(module_root: Path) -> Optional[Path]:

codeflash/code_utils/instrument_existing_tests.py

Lines changed: 19 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -684,27 +684,6 @@ def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
684684
)
685685

686686

687-
def instrument_source_module_with_async_decorators(
688-
source_path: Path, function_to_optimize: FunctionToOptimize, mode: TestingMode = TestingMode.BEHAVIOR
689-
) -> tuple[bool, str | None]:
690-
if not function_to_optimize.is_async:
691-
return False, None
692-
693-
try:
694-
with source_path.open(encoding="utf8") as f:
695-
source_code = f.read()
696-
697-
modified_code, decorator_added = add_async_decorator_to_function(source_code, function_to_optimize, mode)
698-
699-
if decorator_added:
700-
return True, modified_code
701-
702-
except Exception:
703-
return False, None
704-
else:
705-
return False, None
706-
707-
708687
def inject_async_profiling_into_existing_test(
709688
test_path: Path,
710689
call_positions: list[CodePosition],
@@ -1288,25 +1267,29 @@ def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> c
12881267

12891268

12901269
def add_async_decorator_to_function(
1291-
source_code: str, function: FunctionToOptimize, mode: TestingMode = TestingMode.BEHAVIOR
1292-
) -> tuple[str, bool]:
1293-
"""Add async decorator to an async function definition.
1270+
source_path: Path, function: FunctionToOptimize, mode: TestingMode = TestingMode.BEHAVIOR
1271+
) -> bool:
1272+
"""Add async decorator to an async function definition and write back to file.
12941273
12951274
Args:
12961275
----
1297-
source_code: The source code to modify.
1276+
source_path: Path to the source file to modify in-place.
12981277
function: The FunctionToOptimize object representing the target async function.
12991278
mode: The testing mode to determine which decorator to apply.
13001279
13011280
Returns:
13021281
-------
1303-
Tuple of (modified_source_code, was_decorator_added).
1282+
Boolean indicating whether the decorator was successfully added.
13041283
13051284
"""
13061285
if not function.is_async:
1307-
return source_code, False
1286+
return False
13081287

13091288
try:
1289+
# Read source code
1290+
with source_path.open(encoding="utf8") as f:
1291+
source_code = f.read()
1292+
13101293
module = cst.parse_module(source_code)
13111294

13121295
# Add the decorator to the function
@@ -1318,10 +1301,17 @@ def add_async_decorator_to_function(
13181301
import_transformer = AsyncDecoratorImportAdder(mode)
13191302
module = module.visit(import_transformer)
13201303

1321-
return sort_imports(code=module.code, float_to_top=True), decorator_transformer.added_decorator
1304+
modified_code = sort_imports(code=module.code, float_to_top=True)
13221305
except Exception as e:
13231306
logger.exception(f"Error adding async decorator to function {function.qualified_name}: {e}")
1324-
return source_code, False
1307+
return False
1308+
else:
1309+
if decorator_transformer.added_decorator:
1310+
with source_path.open("w", encoding="utf8") as f:
1311+
f.write(modified_code)
1312+
logger.debug(f"Applied async {mode.value} instrumentation to {source_path}")
1313+
return True
1314+
return False
13251315

13261316

13271317
def create_instrumented_source_module_path(source_path: Path, temp_dir: Path) -> Path:

codeflash/context/unused_definition_remover.py

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -469,22 +469,32 @@ def remove_unused_definitions_by_function_names(code: str, qualified_function_na
469469
qualified_function_names: Set of function names to keep. For methods, use format 'classname.methodname'
470470
471471
"""
472-
module = cst.parse_module(code)
473-
# Collect all definitions (top level classes, variables or function)
474-
definitions = collect_top_level_definitions(module)
472+
try:
473+
module = cst.parse_module(code)
474+
except Exception as e:
475+
logger.debug(f"Failed to parse code with libcst: {type(e).__name__}: {e}")
476+
return code
475477

476-
# Collect dependencies between definitions using the visitor pattern
477-
dependency_collector = DependencyCollector(definitions)
478-
module.visit(dependency_collector)
478+
try:
479+
# Collect all definitions (top level classes, variables or function)
480+
definitions = collect_top_level_definitions(module)
479481

480-
# Mark definitions used by specified functions, and their dependencies recursively
481-
usage_marker = QualifiedFunctionUsageMarker(definitions, qualified_function_names)
482-
usage_marker.mark_used_definitions()
482+
# Collect dependencies between definitions using the visitor pattern
483+
dependency_collector = DependencyCollector(definitions)
484+
module.visit(dependency_collector)
483485

484-
# Apply the recursive removal transformation
485-
modified_module, _ = remove_unused_definitions_recursively(module, definitions)
486+
# Mark definitions used by specified functions, and their dependencies recursively
487+
usage_marker = QualifiedFunctionUsageMarker(definitions, qualified_function_names)
488+
usage_marker.mark_used_definitions()
486489

487-
return modified_module.code if modified_module else ""
490+
# Apply the recursive removal transformation
491+
modified_module, _ = remove_unused_definitions_recursively(module, definitions)
492+
493+
return modified_module.code if modified_module else "" # noqa: TRY300
494+
except Exception as e:
495+
# If any other error occurs during processing, return the original code
496+
logger.debug(f"Error processing code to remove unused definitions: {type(e).__name__}: {e}")
497+
return code
488498

489499

490500
def print_definitions(definitions: dict[str, UsageInfo]) -> None:

codeflash/discovery/functions_to_optimize.py

Lines changed: 84 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,7 @@ def get_functions_to_optimize(
201201
elif file is not None:
202202
logger.info("!lsp|Finding all functions in the file '%s'…", file)
203203
console.rule()
204-
functions = find_all_functions_in_file(file)
204+
functions: dict[Path, list[FunctionToOptimize]] = find_all_functions_in_file(file)
205205
if only_get_this_function is not None:
206206
split_function = only_get_this_function.split(".")
207207
if len(split_function) > 2:
@@ -224,8 +224,16 @@ def get_functions_to_optimize(
224224
if found_function is None:
225225
if is_lsp:
226226
return functions, 0, None
227+
found = closest_matching_file_function_name(only_get_this_function, functions)
228+
if found is not None:
229+
file, found_function = found
230+
exit_with_message(
231+
f"Function {only_get_this_function} not found in file {file}\nor the function does not have a 'return' statement or is a property.\n"
232+
f"Did you mean {found_function.qualified_name} instead?"
233+
)
234+
227235
exit_with_message(
228-
f"Function {only_function_name} not found in file {file}\nor the function does not have a 'return' statement or is a property"
236+
f"Function {only_get_this_function} not found in file {file}\nor the function does not have a 'return' statement or is a property"
229237
)
230238
functions[file] = [found_function]
231239
else:
@@ -259,6 +267,76 @@ def get_functions_within_git_diff(uncommitted_changes: bool) -> dict[str, list[F
259267
return get_functions_within_lines(modified_lines)
260268

261269

270+
def closest_matching_file_function_name(
    qualified_fn_to_find: str, found_fns: dict[Path, list[FunctionToOptimize]], *, max_distance: int = 3
) -> tuple[Path, FunctionToOptimize] | None:
    """Find the function whose qualified name is closest to the requested one.

    Names are compared case-insensitively using Levenshtein distance on each
    candidate's ``qualified_name``. When several candidates are equally close,
    the first one encountered while iterating ``found_fns`` is returned.

    Args:
        qualified_fn_to_find: Function name to find in format "Class.function" or "function"
        found_fns: Dictionary of file paths to list of functions
        max_distance: Largest edit distance that still counts as a match.

    Returns:
        Tuple of (file_path, function) for closest match, or None if no candidate
        is within ``max_distance`` edits.

    """
    target = qualified_fn_to_find.lower()

    # Exclusive upper bound on an acceptable distance; it shrinks as better matches are found.
    best_distance = max_distance + 1
    best_match: tuple[Path, FunctionToOptimize] | None = None

    for file_path, functions in found_fns.items():
        for function in functions:
            candidate = function.qualified_name.lower()
            # The edit distance is at least the length difference, so a candidate whose
            # length differs by best_distance or more can never beat the current best.
            if abs(len(target) - len(candidate)) >= best_distance:
                continue
            distance = levenshtein_distance(target, candidate)
            if distance < best_distance:
                best_distance = distance
                best_match = (file_path, function)

    return best_match
309+
310+
311+
def levenshtein_distance(s1: str, s2: str) -> int:
    """Return the Levenshtein (edit) distance between ``s1`` and ``s2``.

    The distance is the minimum number of single-character insertions,
    deletions and substitutions needed to turn one string into the other.
    Only two rows of the dynamic-programming table are kept, sized by the
    shorter of the two strings.

    Args:
        s1: First string.
        s2: Second string.

    Returns:
        The edit distance as a non-negative integer (0 when the strings are equal).

    """
    if s1 == s2:
        return 0
    # Keep s1 as the shorter string so the two row buffers are as small as possible.
    if len(s1) > len(s2):
        s1, s2 = s2, s1
    if not s1:
        # Turning an empty string into s2 takes one insertion per character.
        return len(s2)

    previous = list(range(len(s1) + 1))
    current = [0] * (len(s1) + 1)

    for row, char2 in enumerate(s2, start=1):
        current[0] = row
        for col, char1 in enumerate(s1, start=1):
            if char1 == char2:
                current[col] = previous[col - 1]
            else:
                # 1 + cheapest of substitution, deletion, insertion.
                current[col] = 1 + min(previous[col - 1], previous[col], current[col - 1])
        # Swap the row buffers instead of copying; `current` is fully overwritten next pass.
        previous, current = current, previous

    return previous[-1]
338+
339+
262340
def get_functions_inside_a_commit(commit_hash: str) -> dict[str, list[FunctionToOptimize]]:
263341
modified_lines: dict[str, list[int]] = get_git_diff(only_this_commit=commit_hash)
264342
return get_functions_within_lines(modified_lines)
@@ -405,7 +483,10 @@ def is_git_repo(file_path: str) -> bool:
405483
def ignored_submodule_paths(module_root: str) -> list[str]:
406484
if is_git_repo(module_root):
407485
git_repo = git.Repo(module_root, search_parent_directories=True)
408-
return [Path(git_repo.working_tree_dir, submodule.path).resolve() for submodule in git_repo.submodules]
486+
try:
487+
return [Path(git_repo.working_tree_dir, submodule.path).resolve() for submodule in git_repo.submodules]
488+
except Exception as e:
489+
logger.warning(f"Error getting submodule paths: {e}")
409490
return []
410491

411492

0 commit comments

Comments
 (0)