Skip to content

Commit 15f9bf9

Browse files
committed
Merge remote-tracking branch 'origin/main' into rewrite-candidate-loop-clean
2 parents e73cb53 + b77f50e commit 15f9bf9

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

43 files changed

+1021
-513
lines changed

README.md

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,7 @@
33
<a href="https://github.com/codeflash-ai/codeflash">
44
<img src="https://img.shields.io/github/commit-activity/m/codeflash-ai/codeflash" alt="GitHub commit activity">
55
</a>
6-
<a href="https://pypi.org/project/codeflash/">
7-
<img src="https://img.shields.io/pypi/dm/codeflash" alt="PyPI Downloads">
8-
</a>
6+
<a href="https://pypi.org/project/codeflash/"><img src="https://static.pepy.tech/badge/codeflash" alt="PyPI Downloads"></a>
97
<a href="https://pypi.org/project/codeflash/">
108
<img src="https://img.shields.io/pypi/v/codeflash?label=PyPI%20version" alt="PyPI Downloads">
119
</a>
@@ -83,4 +81,4 @@ Join our community for support and discussions. If you have any questions, feel
8381

8482
## License
8583

86-
Codeflash is licensed under the BSL-1.1 License. See the LICENSE file for details.
84+
Codeflash is licensed under the BSL-1.1 License. See the [LICENSE](https://github.com/codeflash-ai/codeflash/blob/main/codeflash/LICENSE) file for details.

codeflash/LICENSE

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ Business Source License 1.1
33
Parameters
44

55
Licensor: CodeFlash Inc.
6-
Licensed Work: Codeflash Client version 0.15.x
6+
Licensed Work: Codeflash Client version 0.16.x
77
The Licensed Work is (c) 2024 CodeFlash Inc.
88

99
Additional Use Grant: None. Production use of the Licensed Work is only permitted
@@ -13,7 +13,7 @@ Additional Use Grant: None. Production use of the Licensed Work is only permitte
1313
Platform. Please visit codeflash.ai for further
1414
information.
1515

16-
Change Date: 2029-07-03
16+
Change Date: 2029-08-14
1717

1818
Change License: MIT
1919

codeflash/api/aiservice.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,9 @@
1010
from pydantic.json import pydantic_encoder
1111

1212
from codeflash.cli_cmds.console import console, logger
13-
from codeflash.code_utils.env_utils import get_codeflash_api_key, is_LSP_enabled
13+
from codeflash.code_utils.env_utils import get_codeflash_api_key
1414
from codeflash.code_utils.git_utils import get_last_commit_author_if_pr_exists, get_repo_owner_and_name
15+
from codeflash.lsp.helpers import is_LSP_enabled
1516
from codeflash.models.ExperimentMetadata import ExperimentMetadata
1617
from codeflash.models.models import AIServiceRefinerRequest, CodeStringsMarkdown, OptimizedCandidate
1718
from codeflash.telemetry.posthog_cf import ph

codeflash/api/cfapi.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,9 @@
1414

1515
from codeflash.cli_cmds.console import console, logger
1616
from codeflash.code_utils.env_utils import ensure_codeflash_api_key, get_codeflash_api_key, get_pr_number
17-
from codeflash.code_utils.git_utils import get_current_branch, get_repo_owner_and_name, git_root_dir
17+
from codeflash.code_utils.git_utils import get_current_branch, get_repo_owner_and_name
1818
from codeflash.github.PrComment import FileDiffContent, PrComment
19+
from codeflash.lsp.helpers import is_LSP_enabled
1920
from codeflash.version import __version__
2021

2122
if TYPE_CHECKING:
@@ -101,7 +102,7 @@ def get_user_id() -> Optional[str]:
101102
if min_version and version.parse(min_version) > version.parse(__version__):
102103
msg = "Your Codeflash CLI version is outdated. Please update to the latest version using `pip install --upgrade codeflash`."
103104
console.print(f"[bold red]{msg}[/bold red]")
104-
if console.quiet: # lsp
105+
if is_LSP_enabled():
105106
logger.debug(msg)
106107
return f"Error: {msg}"
107108
sys.exit(1)
@@ -203,8 +204,9 @@ def create_staging(
203204
generated_original_test_source: str,
204205
function_trace_id: str,
205206
coverage_message: str,
206-
replay_tests: str = "",
207-
concolic_tests: str = "",
207+
replay_tests: str,
208+
concolic_tests: str,
209+
root_dir: Path,
208210
) -> Response:
209211
"""Create a staging pull request, targeting the specified branch. (usually 'staging').
210212
@@ -217,12 +219,10 @@ def create_staging(
217219
:param coverage_message: Coverage report or summary.
218220
:return: The response object from the backend.
219221
"""
220-
relative_path = explanation.file_path.relative_to(git_root_dir()).as_posix()
222+
relative_path = explanation.file_path.relative_to(root_dir).as_posix()
221223

222224
build_file_changes = {
223-
Path(p).relative_to(git_root_dir()).as_posix(): FileDiffContent(
224-
oldContent=original_code[p], newContent=new_code[p]
225-
)
225+
Path(p).relative_to(root_dir).as_posix(): FileDiffContent(oldContent=original_code[p], newContent=new_code[p])
226226
for p in original_code
227227
}
228228

codeflash/cli_cmds/cli.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from codeflash.code_utils import env_utils
1111
from codeflash.code_utils.code_utils import exit_with_message
1212
from codeflash.code_utils.config_parser import parse_config_file
13+
from codeflash.lsp.helpers import is_LSP_enabled
1314
from codeflash.version import __version__ as version
1415

1516

@@ -94,6 +95,7 @@ def parse_args() -> Namespace:
9495
help="Path to the directory of the project, where all the pytest-benchmark tests are located.",
9596
)
9697
parser.add_argument("--no-draft", default=False, action="store_true", help="Skip optimization for draft PRs")
98+
parser.add_argument("--worktree", default=False, action="store_true", help="Use worktree for optimization")
9799

98100
args, unknown_args = parser.parse_known_args()
99101
sys.argv[:] = [sys.argv[0], *unknown_args]
@@ -210,6 +212,9 @@ def process_pyproject_config(args: Namespace) -> Namespace:
210212
if args.benchmarks_root:
211213
args.benchmarks_root = Path(args.benchmarks_root).resolve()
212214
args.test_project_root = project_root_from_module_root(args.tests_root, pyproject_file_path)
215+
if is_LSP_enabled():
216+
args.all = None
217+
return args
213218
return handle_optimize_all_arg_parsing(args)
214219

215220

codeflash/cli_cmds/cmd_init.py

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,22 @@ def ask_run_end_to_end_test(args: Namespace) -> None:
155155
run_end_to_end_test(args, bubble_sort_path, bubble_sort_test_path)
156156

157157

158+
def is_valid_pyproject_toml(pyproject_toml_path: Path) -> dict[str, Any] | None:
159+
if not pyproject_toml_path.exists():
160+
return None
161+
try:
162+
config, _ = parse_config_file(pyproject_toml_path)
163+
except Exception:
164+
return None
165+
166+
if "module_root" not in config or config["module_root"] is None or not Path(config["module_root"]).is_dir():
167+
return None
168+
if "tests_root" not in config or config["tests_root"] is None or not Path(config["tests_root"]).is_dir():
169+
return None
170+
171+
return config
172+
173+
158174
def should_modify_pyproject_toml() -> tuple[bool, dict[str, Any] | None]:
159175
"""Check if the current directory contains a valid pyproject.toml file with codeflash config.
160176
@@ -163,16 +179,9 @@ def should_modify_pyproject_toml() -> tuple[bool, dict[str, Any] | None]:
163179
from rich.prompt import Confirm
164180

165181
pyproject_toml_path = Path.cwd() / "pyproject.toml"
166-
if not pyproject_toml_path.exists():
167-
return True, None
168-
try:
169-
config, config_file_path = parse_config_file(pyproject_toml_path)
170-
except Exception:
171-
return True, None
172182

173-
if "module_root" not in config or config["module_root"] is None or not Path(config["module_root"]).is_dir():
174-
return True, None
175-
if "tests_root" not in config or config["tests_root"] is None or not Path(config["tests_root"]).is_dir():
183+
config = is_valid_pyproject_toml(pyproject_toml_path)
184+
if config is None:
176185
return True, None
177186

178187
return Confirm.ask(
@@ -968,6 +977,11 @@ def install_github_app(git_remote: str) -> None:
968977
except git.InvalidGitRepositoryError:
969978
click.echo("Skipping GitHub app installation because you're not in a git repository.")
970979
return
980+
981+
if git_remote not in get_git_remotes(git_repo):
982+
click.echo(f"Skipping GitHub app installation, remote ({git_remote}) does not exist in this repository.")
983+
return
984+
971985
owner, repo = get_repo_owner_and_name(git_repo, git_remote)
972986

973987
if is_github_app_installed_on_repo(owner, repo, suppress_errors=True):

codeflash/cli_cmds/console.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import logging
4+
import os
45
from contextlib import contextmanager
56
from itertools import cycle
67
from typing import TYPE_CHECKING
@@ -28,6 +29,10 @@
2829
DEBUG_MODE = logging.getLogger().getEffectiveLevel() == logging.DEBUG
2930

3031
console = Console()
32+
33+
if os.getenv("CODEFLASH_LSP"):
34+
console.quiet = True
35+
3136
logging.basicConfig(
3237
level=logging.INFO,
3338
handlers=[RichHandler(rich_tracebacks=True, markup=False, console=console, show_path=False, show_time=False)],

codeflash/code_utils/code_extractor.py

Lines changed: 31 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -195,8 +195,19 @@ def visit_SimpleStatementLine(self, node: cst.SimpleStatementLine) -> None:
195195
self.last_import_line = self.current_line
196196

197197

198-
class ConditionalImportCollector(cst.CSTVisitor):
199-
"""Collect imports inside top-level conditionals (e.g., if TYPE_CHECKING, try/except)."""
198+
class DottedImportCollector(cst.CSTVisitor):
199+
"""Collects all top-level imports from a Python module in normalized dotted format, including top-level conditional imports like `if TYPE_CHECKING:`.
200+
201+
Examples
202+
--------
203+
import os ==> "os"
204+
import dbt.adapters.factory ==> "dbt.adapters.factory"
205+
from pathlib import Path ==> "pathlib.Path"
206+
from recce.adapter.base import BaseAdapter ==> "recce.adapter.base.BaseAdapter"
207+
from typing import Any, List, Optional ==> "typing.Any", "typing.List", "typing.Optional"
208+
from recce.util.lineage import ( build_column_key, filter_dependency_maps) ==> "recce.util.lineage.build_column_key", "recce.util.lineage.filter_dependency_maps"
209+
210+
"""
200211

201212
def __init__(self) -> None:
202213
self.imports: set[str] = set()
@@ -217,7 +228,10 @@ def _collect_imports_from_block(self, block: cst.IndentedBlock) -> None:
217228
for alias in child.names:
218229
module = self.get_full_dotted_name(alias.name)
219230
asname = alias.asname.name.value if alias.asname else alias.name.value
220-
self.imports.add(module if module == asname else f"{module}.{asname}")
231+
if isinstance(asname, cst.Attribute):
232+
self.imports.add(module)
233+
else:
234+
self.imports.add(module if module == asname else f"{module}.{asname}")
221235

222236
elif isinstance(child, cst.ImportFrom):
223237
if child.module is None:
@@ -231,6 +245,7 @@ def _collect_imports_from_block(self, block: cst.IndentedBlock) -> None:
231245

232246
def visit_Module(self, node: cst.Module) -> None:
233247
self.depth = 0
248+
self._collect_imports_from_block(node)
234249

235250
def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
236251
self.depth += 1
@@ -388,45 +403,44 @@ def add_needed_imports_from_module(
388403
logger.error(f"Error parsing source module code: {e}")
389404
return dst_module_code
390405

391-
cond_import_collector = ConditionalImportCollector()
406+
dotted_import_collector = DottedImportCollector()
392407
try:
393408
parsed_dst_module = cst.parse_module(dst_module_code)
394-
parsed_dst_module.visit(cond_import_collector)
409+
parsed_dst_module.visit(dotted_import_collector)
395410
except cst.ParserSyntaxError as e:
396411
logger.exception(f"Syntax error in destination module code: {e}")
397412
return dst_module_code # Return the original code if there's a syntax error
398413

399414
try:
400415
for mod in gatherer.module_imports:
401-
if mod in cond_import_collector.imports:
402-
continue
403-
AddImportsVisitor.add_needed_import(dst_context, mod)
416+
if mod not in dotted_import_collector.imports:
417+
AddImportsVisitor.add_needed_import(dst_context, mod)
404418
RemoveImportsVisitor.remove_unused_import(dst_context, mod)
405419
for mod, obj_seq in gatherer.object_mapping.items():
406420
for obj in obj_seq:
407421
if (
408422
f"{mod}.{obj}" in helper_functions_fqn or dst_context.full_module_name == mod # avoid circular deps
409423
):
410424
continue # Skip adding imports for helper functions already in the context
411-
if f"{mod}.{obj}" in cond_import_collector.imports:
412-
continue
413-
AddImportsVisitor.add_needed_import(dst_context, mod, obj)
425+
if f"{mod}.{obj}" not in dotted_import_collector.imports:
426+
AddImportsVisitor.add_needed_import(dst_context, mod, obj)
414427
RemoveImportsVisitor.remove_unused_import(dst_context, mod, obj)
415428
except Exception as e:
416429
logger.exception(f"Error adding imports to destination module code: {e}")
417430
return dst_module_code
431+
418432
for mod, asname in gatherer.module_aliases.items():
419-
if f"{mod}.{asname}" in cond_import_collector.imports:
420-
continue
421-
AddImportsVisitor.add_needed_import(dst_context, mod, asname=asname)
433+
if f"{mod}.{asname}" not in dotted_import_collector.imports:
434+
AddImportsVisitor.add_needed_import(dst_context, mod, asname=asname)
422435
RemoveImportsVisitor.remove_unused_import(dst_context, mod, asname=asname)
436+
423437
for mod, alias_pairs in gatherer.alias_mapping.items():
424438
for alias_pair in alias_pairs:
425439
if f"{mod}.{alias_pair[0]}" in helper_functions_fqn:
426440
continue
427-
if f"{mod}.{alias_pair[1]}" in cond_import_collector.imports:
428-
continue
429-
AddImportsVisitor.add_needed_import(dst_context, mod, alias_pair[0], asname=alias_pair[1])
441+
442+
if f"{mod}.{alias_pair[1]}" not in dotted_import_collector.imports:
443+
AddImportsVisitor.add_needed_import(dst_context, mod, alias_pair[0], asname=alias_pair[1])
430444
RemoveImportsVisitor.remove_unused_import(dst_context, mod, alias_pair[0], asname=alias_pair[1])
431445

432446
try:

codeflash/code_utils/coverage_utils.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,16 +39,20 @@ def build_fully_qualified_name(function_name: str, code_context: CodeOptimizatio
3939
return full_name
4040

4141

42-
def generate_candidates(source_code_path: Path) -> list[str]:
42+
def generate_candidates(source_code_path: Path) -> set[str]:
4343
"""Generate all the possible candidates for coverage data based on the source code path."""
44-
candidates = [source_code_path.name]
44+
candidates = set()
45+
candidates.add(source_code_path.name)
4546
current_path = source_code_path.parent
4647

48+
last_added = source_code_path.name
4749
while current_path != current_path.parent:
48-
candidate_path = str(Path(current_path.name) / candidates[-1])
49-
candidates.append(candidate_path)
50+
candidate_path = str(Path(current_path.name) / last_added)
51+
candidates.add(candidate_path)
52+
last_added = candidate_path
5053
current_path = current_path.parent
5154

55+
candidates.add(str(source_code_path))
5256
return candidates
5357

5458

codeflash/code_utils/env_utils.py

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,11 @@
77
from pathlib import Path
88
from typing import Any, Optional
99

10-
from codeflash.cli_cmds.console import console, logger
10+
from codeflash.cli_cmds.console import logger
1111
from codeflash.code_utils.code_utils import exit_with_message
1212
from codeflash.code_utils.formatter import format_code
1313
from codeflash.code_utils.shell_utils import read_api_key_from_shell_config
14+
from codeflash.lsp.helpers import is_LSP_enabled
1415

1516

1617
def check_formatter_installed(formatter_cmds: list[str], exit_on_failure: bool = True) -> bool: # noqa
@@ -34,11 +35,12 @@ def check_formatter_installed(formatter_cmds: list[str], exit_on_failure: bool =
3435

3536
@lru_cache(maxsize=1)
3637
def get_codeflash_api_key() -> str:
37-
if console.quiet: # lsp
38-
# prefer shell config over env var in lsp mode
39-
api_key = read_api_key_from_shell_config()
40-
else:
41-
api_key = os.environ.get("CODEFLASH_API_KEY") or read_api_key_from_shell_config()
38+
# prefer shell config over env var in lsp mode
39+
api_key = (
40+
read_api_key_from_shell_config()
41+
if is_LSP_enabled()
42+
else os.environ.get("CODEFLASH_API_KEY") or read_api_key_from_shell_config()
43+
)
4244

4345
api_secret_docs_message = "For more information, refer to the documentation at [https://docs.codeflash.ai/getting-started/codeflash-github-actions#add-your-api-key-to-your-repository-secrets]." # noqa
4446
if not api_key:
@@ -125,11 +127,6 @@ def is_ci() -> bool:
125127
return bool(os.environ.get("CI") or os.environ.get("GITHUB_ACTIONS"))
126128

127129

128-
@lru_cache(maxsize=1)
129-
def is_LSP_enabled() -> bool:
130-
return console.quiet
131-
132-
133130
def is_pr_draft() -> bool:
134131
"""Check if the PR is draft. in the github action context."""
135132
event = get_cached_gh_event_data()

0 commit comments

Comments
 (0)