Skip to content

Commit 74fe322

Browse files
authored
Merge branch 'main' into pyproject-search-improvement
2 parents a95ad4f + 7797c9f commit 74fe322

File tree

9 files changed

+683
-44
lines changed

9 files changed

+683
-44
lines changed

.github/workflows/e2e-init-optimization.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ jobs:
1919
COLUMNS: 110
2020
MAX_RETRIES: 3
2121
RETRY_DELAY: 5
22-
EXPECTED_IMPROVEMENT_PCT: 30
22+
EXPECTED_IMPROVEMENT_PCT: 10
2323
CODEFLASH_END_TO_END: 1
2424
steps:
2525
- name: 🛎️ Checkout

codeflash/code_utils/code_extractor.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -528,16 +528,29 @@ def add_needed_imports_from_module(
528528

529529
try:
530530
for mod in gatherer.module_imports:
531+
# Skip __future__ imports as they cannot be imported directly
532+
# __future__ imports should only be imported with specific objects i.e from __future__ import annotations
533+
if mod == "__future__":
534+
continue
531535
if mod not in dotted_import_collector.imports:
532536
AddImportsVisitor.add_needed_import(dst_context, mod)
533537
RemoveImportsVisitor.remove_unused_import(dst_context, mod)
538+
aliased_objects = set()
539+
for mod, alias_pairs in gatherer.alias_mapping.items():
540+
for alias_pair in alias_pairs:
541+
if alias_pair[0] and alias_pair[1]: # Both name and alias exist
542+
aliased_objects.add(f"{mod}.{alias_pair[0]}")
543+
534544
for mod, obj_seq in gatherer.object_mapping.items():
535545
for obj in obj_seq:
536546
if (
537547
f"{mod}.{obj}" in helper_functions_fqn or dst_context.full_module_name == mod # avoid circular deps
538548
):
539549
continue # Skip adding imports for helper functions already in the context
540550

551+
if f"{mod}.{obj}" in aliased_objects:
552+
continue
553+
541554
# Handle star imports by resolving them to actual symbol names
542555
if obj == "*":
543556
resolved_symbols = resolve_star_import(mod, project_root)
@@ -559,6 +572,8 @@ def add_needed_imports_from_module(
559572
return dst_module_code
560573

561574
for mod, asname in gatherer.module_aliases.items():
575+
if not asname:
576+
continue
562577
if f"{mod}.{asname}" not in dotted_import_collector.imports:
563578
AddImportsVisitor.add_needed_import(dst_context, mod, asname=asname)
564579
RemoveImportsVisitor.remove_unused_import(dst_context, mod, asname=asname)
@@ -568,12 +583,16 @@ def add_needed_imports_from_module(
568583
if f"{mod}.{alias_pair[0]}" in helper_functions_fqn:
569584
continue
570585

586+
if not alias_pair[0] or not alias_pair[1]:
587+
continue
588+
571589
if f"{mod}.{alias_pair[1]}" not in dotted_import_collector.imports:
572590
AddImportsVisitor.add_needed_import(dst_context, mod, alias_pair[0], asname=alias_pair[1])
573591
RemoveImportsVisitor.remove_unused_import(dst_context, mod, alias_pair[0], asname=alias_pair[1])
574592

575593
try:
576-
transformed_module = AddImportsVisitor(dst_context).transform_module(parsed_dst_module)
594+
add_imports_visitor = AddImportsVisitor(dst_context)
595+
transformed_module = add_imports_visitor.transform_module(parsed_dst_module)
577596
transformed_module = RemoveImportsVisitor(dst_context).transform_module(transformed_module)
578597
return transformed_module.code.lstrip("\n")
579598
except Exception as e:

codeflash/code_utils/code_utils.py

Lines changed: 95 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import ast
4+
import configparser
45
import difflib
56
import os
67
import re
@@ -15,10 +16,12 @@
1516
import tomlkit
1617

1718
from codeflash.cli_cmds.console import logger, paneled_text
18-
from codeflash.code_utils.config_parser import find_pyproject_toml
19+
from codeflash.code_utils.config_parser import find_pyproject_toml, get_all_closest_config_files
1920

2021
ImportErrorPattern = re.compile(r"ModuleNotFoundError.*$", re.MULTILINE)
2122

23+
# Option *prefixes* that filter_args() removes from pytest ``addopts``. Matching is done with
# str.startswith, so an entry such as "--cov" also covers longer options like "--cov-report".
# NOTE(review): presumably these are plugin options that interfere with codeflash's own test
# runs -- confirm the rationale before adding or removing entries.
BLACKLIST_ADDOPTS = ("--benchmark", "--sugar", "--codespeed", "--cov", "--profile", "--junitxml", "-n")
24+
2225

2326
def unified_diff_strings(code1: str, code2: str, fromfile: str = "original", tofile: str = "modified") -> str:
2427
"""Return the unified diff between two code strings as a single string.
@@ -81,42 +84,105 @@ def create_rank_dictionary_compact(int_array: list[int]) -> dict[int, int]:
8184
return {original_index: rank for rank, original_index in enumerate(sorted_indices)}
8285

8386

84-
@contextmanager
85-
def custom_addopts() -> None:
86-
pyproject_file = find_pyproject_toml()
87-
original_content = None
88-
non_blacklist_plugin_args = ""
89-
87+
def filter_args(addopts_args: list[str], blacklist: tuple[str, ...] | None = None) -> list[str]:
    """Drop blacklisted pytest options (and their values) from a tokenized ``addopts`` list.

    A token is dropped when it starts with any prefix in ``blacklist``. The token
    immediately following a dropped option is treated as that option's value and is
    dropped as well, unless it starts with ``-`` (i.e. it looks like another option).

    Args:
        addopts_args: ``addopts`` split into individual tokens. ``--opt=value`` pairs are
            expected to have been split into ``--opt`` and ``value`` by the caller.
        blacklist: Option prefixes to remove. Defaults to ``BLACKLIST_ADDOPTS``.

    Returns:
        A new list containing the surviving tokens in their original order.
    """
    # str.startswith only accepts a str or a tuple of str, so normalise to a tuple.
    prefixes = BLACKLIST_ADDOPTS if blacklist is None else tuple(blacklist)
    filtered_args: list[str] = []
    previous_was_dropped_option = False
    for arg in addopts_args:
        if previous_was_dropped_option:
            previous_was_dropped_option = False
            if not arg.startswith("-"):
                # Positional token right after a dropped option: assume it is that option's value.
                continue
        if arg.startswith(prefixes):
            previous_was_dropped_option = True
            continue
        filtered_args.append(arg)
    return filtered_args
105+
106+
107+
def modify_addopts(config_file: Path) -> tuple[str, bool]:  # noqa: PLR0911
    """Rewrite ``config_file`` in place with blacklisted pytest ``addopts`` removed.

    ``pyproject.toml`` is handled with tomlkit (``[tool.pytest.ini_options]``); ``pytest.ini``,
    ``.pytest.ini`` and ``tox.ini`` are read with configparser from the ``[pytest]`` section, and
    any other ``.ini``/``.cfg`` file (e.g. ``setup.cfg``) from ``[tool:pytest]``.

    ``=`` in ``addopts`` is replaced by a space before tokenizing, so a rewritten file carries
    the surviving options in space-separated form (``--tb short`` rather than ``--tb=short``).

    Args:
        config_file: Path of the configuration file to process.

    Returns:
        ``(original_content, was_modified)``. ``original_content`` is the text read from the file
        before any change (``""`` when the file does not exist or has an unsupported suffix) so
        that the caller can write it back later; ``was_modified`` is True only when the file was
        rewritten. Parsing or writing problems are logged at debug level and reported as
        "not modified".
    """
    file_type = config_file.suffix.lower()
    filename = config_file.name
    if file_type not in {".toml", ".ini", ".cfg"} or not config_file.exists():
        return "", False
    # Read original file
    with config_file.open(encoding="utf-8") as f:
        content = f.read()
    try:
        if filename == "pyproject.toml":
            data = tomlkit.parse(content)
            original_addopts = data.get("tool", {}).get("pytest", {}).get("ini_options", {}).get("addopts", "")
            # nothing to do if no addopts present
            if original_addopts == "":
                return content, False
            if isinstance(original_addopts, list):
                original_addopts = " ".join(original_addopts)
            # split() uses any run of whitespace as the delimiter
            addopts_args = original_addopts.replace("=", " ").split()
            new_addopts_args = filter_args(addopts_args)
            if new_addopts_args == addopts_args:
                return content, False
            data["tool"]["pytest"]["ini_options"]["addopts"] = " ".join(new_addopts_args)
            with config_file.open("w", encoding="utf-8") as f:
                f.write(tomlkit.dumps(data))
            return content, True

        if file_type == ".toml":
            # Only pyproject.toml carries pytest settings; never rewrite other TOML files.
            return content, False

        # interpolation=None: ini values such as ``log_format = %(asctime)s`` are literal text for
        # pytest; the default BasicInterpolation would raise while reading them and the file
        # would silently be left unfiltered.
        config = configparser.ConfigParser(interpolation=None)
        config.read_string(content)
        section = "pytest" if filename in {"pytest.ini", ".pytest.ini", "tox.ini"} else "tool:pytest"
        original_addopts = config.get(section, "addopts", fallback="")  # should only be a string
        addopts_args = original_addopts.replace("=", " ").split()
        new_addopts_args = filter_args(addopts_args)
        if new_addopts_args == addopts_args:
            return content, False
        config.set(section, "addopts", " ".join(new_addopts_args))
        # configparser does not keep comments or layout; the returned original text is what
        # allows the caller to restore the file afterwards.
        with config_file.open("w", encoding="utf-8") as f:
            config.write(f)
        return content, True
    except Exception as e:
        logger.debug(f"Trouble parsing {config_file}: {e}")
        return content, False  # not modified
167+
168+
169+
@contextmanager
def custom_addopts() -> None:
    """Temporarily strip blacklisted pytest ``addopts`` from the closest config files.

    On entry, every file returned by ``get_all_closest_config_files`` is passed through
    ``modify_addopts``. On exit -- including when the managed block raises -- each file that
    was actually rewritten gets its original text written back.
    """
    config_paths = get_all_closest_config_files()
    backups = {}

    try:
        for path in config_paths:
            backups[path] = modify_addopts(path)
        yield

    finally:
        # Restore only the files that modify_addopts reported as changed
        for path, (saved_text, changed) in backups.items():
            if not changed:
                continue
            with path.open("w", encoding="utf-8") as handle:
                handle.write(saved_text)
120186

121187

122188
@contextmanager

codeflash/code_utils/config_parser.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import tomlkit
77

88
PYPROJECT_TOML_CACHE = {}
9+
# Cache used by find_closest_config_file(), keyed by the working directory the search started from:
# {cwd: {config file name: closest file with that name found while walking up from cwd}}
ALL_CONFIG_FILES = {} # map path to closest config file
910

1011

1112
def find_pyproject_toml(config_file: Path | None = None) -> Path:
@@ -38,6 +39,33 @@ def find_pyproject_toml(config_file: Path | None = None) -> Path:
3839
raise ValueError(msg)
3940

4041

42+
def get_all_closest_config_files() -> list[Path]:
    """Return the closest config file of each supported kind, skipping kinds that are absent.

    The kinds are looked up in a fixed order (pyproject.toml, pytest.ini, .pytest.ini, tox.ini,
    setup.cfg) via ``find_closest_config_file``, and the result keeps that order.
    """
    config_names = ("pyproject.toml", "pytest.ini", ".pytest.ini", "tox.ini", "setup.cfg")
    located = (find_closest_config_file(name) for name in config_names)
    return [path for path in located if path]
49+
50+
51+
def find_closest_config_file(file_type: str) -> Path | None:
    """Return the closest file named ``file_type`` at or above the current working directory.

    The search starts in ``Path.cwd()`` and walks up through every parent directory, including
    the filesystem root. Successful lookups are cached in ``ALL_CONFIG_FILES`` per working
    directory and file name; misses are not cached, so a file created later is still found.

    Args:
        file_type: Name of the config file to look for (e.g. ``"pyproject.toml"``,
            ``"pytest.ini"``, ``"tox.ini"``, ``"setup.cfg"``).

    Returns:
        Path of the closest matching file, or ``None`` when no directory on the way up has one.
    """
    cur_path = Path.cwd()
    cached = ALL_CONFIG_FILES.get(cur_path, {})
    if file_type in cached:
        return cached[file_type]
    # ``parents`` ends with the filesystem root, so the root directory itself is checked too
    # (it would be skipped by a ``while dir_path != dir_path.parent`` walk).
    for dir_path in (cur_path, *cur_path.parents):
        config_file = dir_path / file_type
        if config_file.exists():
            ALL_CONFIG_FILES.setdefault(cur_path, {})[file_type] = config_file
            return config_file
    return None
67+
68+
4169
def find_conftest_files(test_paths: list[Path]) -> list[Path]:
4270
list_of_conftest_files = set()
4371
for test_path in test_paths:

codeflash/optimization/function_optimizer.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1461,14 +1461,14 @@ def process_review(
14611461

14621462
if raise_pr or staging_review:
14631463
data["root_dir"] = git_root_dir()
1464-
try:
1465-
# modify argument of staging vs pr based on the impact
1466-
opt_impact_response = self.aiservice_client.get_optimization_impact(**data)
1467-
if opt_impact_response == "low":
1468-
raise_pr = False
1469-
staging_review = True
1470-
except Exception as e:
1471-
logger.debug(f"optimization impact response failed, investigate {e}")
1464+
# try:
1465+
# # modify argument of staging vs pr based on the impact
1466+
# opt_impact_response = self.aiservice_client.get_optimization_impact(**data)
1467+
# if opt_impact_response == "low":
1468+
# raise_pr = False
1469+
# staging_review = True
1470+
# except Exception as e:
1471+
# logger.debug(f"optimization impact response failed, investigate {e}")
14721472
if raise_pr and not staging_review:
14731473
data["git_remote"] = self.args.git_remote
14741474
check_create_pr(**data)

codeflash/verification/parse_test_output.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ def calculate_function_throughput_from_test_results(test_results: TestResults, f
6767
def parse_test_return_values_bin(file_location: Path, test_files: TestFiles, test_config: TestConfig) -> TestResults:
6868
test_results = TestResults()
6969
if not file_location.exists():
70-
logger.warning(f"No test results for {file_location} found.")
70+
logger.debug(f"No test results for {file_location} found.")
7171
console.rule()
7272
return test_results
7373

@@ -237,6 +237,11 @@ def parse_test_xml(
237237

238238
test_class_path = testcase.classname
239239
try:
240+
if testcase.name is None:
241+
logger.debug(
242+
f"testcase.name is None for testcase {testcase!r} in file {test_xml_file_path}, skipping"
243+
)
244+
continue
240245
test_function = testcase.name.split("[", 1)[0] if "[" in testcase.name else testcase.name
241246
except (AttributeError, TypeError) as e:
242247
msg = (
@@ -273,16 +278,16 @@ def parse_test_xml(
273278

274279
timed_out = False
275280
if test_config.test_framework == "pytest":
276-
loop_index = int(testcase.name.split("[ ")[-1][:-2]) if "[" in testcase.name else 1
281+
loop_index = int(testcase.name.split("[ ")[-1][:-2]) if testcase.name and "[" in testcase.name else 1
277282
if len(testcase.result) > 1:
278-
logger.warning(f"!!!!!Multiple results for {testcase.name} in {test_xml_file_path}!!!")
283+
logger.debug(f"!!!!!Multiple results for {testcase.name or '<None>'} in {test_xml_file_path}!!!")
279284
if len(testcase.result) == 1:
280285
message = testcase.result[0].message.lower()
281286
if "failed: timeout >" in message:
282287
timed_out = True
283288
else:
284289
if len(testcase.result) > 1:
285-
logger.warning(f"!!!!!Multiple results for {testcase.name} in {test_xml_file_path}!!!")
290+
logger.debug(f"!!!!!Multiple results for {testcase.name or '<None>'} in {test_xml_file_path}!!!")
286291
if len(testcase.result) == 1:
287292
message = testcase.result[0].message.lower()
288293
if "timed out" in message:

0 commit comments

Comments
 (0)