Skip to content

Commit d5cf24b

Browse files
author
Codeflash Bot
committed
Merge branch 'test_cache_revival' of github.com:codeflash-ai/codeflash into test_cache_revival
2 parents bdc062a + 1b58dd1 commit d5cf24b

File tree

8 files changed

+727
-57
lines changed

8 files changed

+727
-57
lines changed

codeflash/code_utils/code_extractor.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -528,16 +528,29 @@ def add_needed_imports_from_module(
528528

529529
try:
530530
for mod in gatherer.module_imports:
531+
# Skip __future__ imports as they cannot be imported directly
532+
# __future__ imports should only be imported with specific objects i.e from __future__ import annotations
533+
if mod == "__future__":
534+
continue
531535
if mod not in dotted_import_collector.imports:
532536
AddImportsVisitor.add_needed_import(dst_context, mod)
533537
RemoveImportsVisitor.remove_unused_import(dst_context, mod)
538+
aliased_objects = set()
539+
for mod, alias_pairs in gatherer.alias_mapping.items():
540+
for alias_pair in alias_pairs:
541+
if alias_pair[0] and alias_pair[1]: # Both name and alias exist
542+
aliased_objects.add(f"{mod}.{alias_pair[0]}")
543+
534544
for mod, obj_seq in gatherer.object_mapping.items():
535545
for obj in obj_seq:
536546
if (
537547
f"{mod}.{obj}" in helper_functions_fqn or dst_context.full_module_name == mod # avoid circular deps
538548
):
539549
continue # Skip adding imports for helper functions already in the context
540550

551+
if f"{mod}.{obj}" in aliased_objects:
552+
continue
553+
541554
# Handle star imports by resolving them to actual symbol names
542555
if obj == "*":
543556
resolved_symbols = resolve_star_import(mod, project_root)
@@ -559,6 +572,8 @@ def add_needed_imports_from_module(
559572
return dst_module_code
560573

561574
for mod, asname in gatherer.module_aliases.items():
575+
if not asname:
576+
continue
562577
if f"{mod}.{asname}" not in dotted_import_collector.imports:
563578
AddImportsVisitor.add_needed_import(dst_context, mod, asname=asname)
564579
RemoveImportsVisitor.remove_unused_import(dst_context, mod, asname=asname)
@@ -568,12 +583,16 @@ def add_needed_imports_from_module(
568583
if f"{mod}.{alias_pair[0]}" in helper_functions_fqn:
569584
continue
570585

586+
if not alias_pair[0] or not alias_pair[1]:
587+
continue
588+
571589
if f"{mod}.{alias_pair[1]}" not in dotted_import_collector.imports:
572590
AddImportsVisitor.add_needed_import(dst_context, mod, alias_pair[0], asname=alias_pair[1])
573591
RemoveImportsVisitor.remove_unused_import(dst_context, mod, alias_pair[0], asname=alias_pair[1])
574592

575593
try:
576-
transformed_module = AddImportsVisitor(dst_context).transform_module(parsed_dst_module)
594+
add_imports_visitor = AddImportsVisitor(dst_context)
595+
transformed_module = add_imports_visitor.transform_module(parsed_dst_module)
577596
transformed_module = RemoveImportsVisitor(dst_context).transform_module(transformed_module)
578597
return transformed_module.code.lstrip("\n")
579598
except Exception as e:

codeflash/code_utils/code_utils.py

Lines changed: 95 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import ast
4+
import configparser
45
import difflib
56
import os
67
import re
@@ -15,10 +16,12 @@
1516
import tomlkit
1617

1718
from codeflash.cli_cmds.console import logger, paneled_text
18-
from codeflash.code_utils.config_parser import find_pyproject_toml
19+
from codeflash.code_utils.config_parser import find_pyproject_toml, get_all_closest_config_files
1920

2021
ImportErrorPattern = re.compile(r"ModuleNotFoundError.*$", re.MULTILINE)
2122

23+
BLACKLIST_ADDOPTS = ("--benchmark", "--sugar", "--codespeed", "--cov", "--profile", "--junitxml", "-n")
24+
2225

2326
def unified_diff_strings(code1: str, code2: str, fromfile: str = "original", tofile: str = "modified") -> str:
2427
"""Return the unified diff between two code strings as a single string.
@@ -81,42 +84,105 @@ def create_rank_dictionary_compact(int_array: list[int]) -> dict[int, int]:
8184
return {original_index: rank for rank, original_index in enumerate(sorted_indices)}
8285

8386

84-
@contextmanager
85-
def custom_addopts() -> None:
86-
pyproject_file = find_pyproject_toml()
87-
original_content = None
88-
non_blacklist_plugin_args = ""
89-
87+
def filter_args(addopts_args: list[str]) -> list[str]:
88+
# Convert BLACKLIST_ADDOPTS to a set for faster lookup of simple matches
89+
# But keep tuple for startswith
90+
blacklist = BLACKLIST_ADDOPTS
91+
# Precompute the length for re-use
92+
n = len(addopts_args)
93+
filtered_args = []
94+
i = 0
95+
while i < n:
96+
current_arg = addopts_args[i]
97+
if current_arg.startswith(blacklist):
98+
i += 1
99+
if i < n and not addopts_args[i].startswith("-"):
100+
i += 1
101+
else:
102+
filtered_args.append(current_arg)
103+
i += 1
104+
return filtered_args
105+
106+
107+
def modify_addopts(config_file: Path) -> tuple[str, bool]: # noqa : PLR0911
108+
file_type = config_file.suffix.lower()
109+
filename = config_file.name
110+
config = None
111+
if file_type not in {".toml", ".ini", ".cfg"} or not config_file.exists():
112+
return "", False
113+
# Read original file
114+
with Path.open(config_file, encoding="utf-8") as f:
115+
content = f.read()
90116
try:
91-
# Read original file
92-
if pyproject_file.exists():
93-
with Path.open(pyproject_file, encoding="utf-8") as f:
94-
original_content = f.read()
95-
data = tomlkit.parse(original_content)
96-
# Backup original addopts
117+
if filename == "pyproject.toml":
118+
# use tomlkit
119+
data = tomlkit.parse(content)
97120
original_addopts = data.get("tool", {}).get("pytest", {}).get("ini_options", {}).get("addopts", "")
98121
# nothing to do if no addopts present
99-
if original_addopts != "" and isinstance(original_addopts, list):
100-
original_addopts = [x.strip() for x in original_addopts]
101-
non_blacklist_plugin_args = re.sub(r"-n(?: +|=)\S+", "", " ".join(original_addopts)).split(" ")
102-
non_blacklist_plugin_args = [x for x in non_blacklist_plugin_args if x != ""]
103-
if non_blacklist_plugin_args != original_addopts:
104-
data["tool"]["pytest"]["ini_options"]["addopts"] = non_blacklist_plugin_args
105-
# Write modified file
106-
with Path.open(pyproject_file, "w", encoding="utf-8") as f:
107-
f.write(tomlkit.dumps(data))
122+
if original_addopts == "":
123+
return content, False
124+
if isinstance(original_addopts, list):
125+
original_addopts = " ".join(original_addopts)
126+
original_addopts = original_addopts.replace("=", " ")
127+
addopts_args = (
128+
original_addopts.split()
129+
) # any number of space characters as delimiter, doesn't look at = which is fine
130+
else:
131+
# use configparser
132+
config = configparser.ConfigParser()
133+
config.read_string(content)
134+
data = {section: dict(config[section]) for section in config.sections()}
135+
if config_file.name in {"pytest.ini", ".pytest.ini", "tox.ini"}:
136+
original_addopts = data.get("pytest", {}).get("addopts", "") # should only be a string
137+
else:
138+
original_addopts = data.get("tool:pytest", {}).get("addopts", "") # should only be a string
139+
original_addopts = original_addopts.replace("=", " ")
140+
addopts_args = original_addopts.split()
141+
new_addopts_args = filter_args(addopts_args)
142+
if new_addopts_args == addopts_args:
143+
return content, False
144+
# change addopts now
145+
if file_type == ".toml":
146+
data["tool"]["pytest"]["ini_options"]["addopts"] = " ".join(new_addopts_args)
147+
# Write modified file
148+
with Path.open(config_file, "w", encoding="utf-8") as f:
149+
f.write(tomlkit.dumps(data))
150+
return content, True
151+
elif config_file.name in {"pytest.ini", ".pytest.ini", "tox.ini"}:
152+
config.set("pytest", "addopts", " ".join(new_addopts_args))
153+
# Write modified file
154+
with Path.open(config_file, "w", encoding="utf-8") as f:
155+
config.write(f)
156+
return content, True
157+
else:
158+
config.set("tool:pytest", "addopts", " ".join(new_addopts_args))
159+
# Write modified file
160+
with Path.open(config_file, "w", encoding="utf-8") as f:
161+
config.write(f)
162+
return content, True
163+
164+
except Exception:
165+
logger.debug("Trouble parsing")
166+
return content, False # not modified
167+
168+
169+
@contextmanager
170+
def custom_addopts() -> None:
171+
closest_config_files = get_all_closest_config_files()
172+
173+
original_content = {}
108174

175+
try:
176+
for config_file in closest_config_files:
177+
original_content[config_file] = modify_addopts(config_file)
109178
yield
110179

111180
finally:
112181
# Restore original file
113-
if (
114-
original_content
115-
and pyproject_file.exists()
116-
and tuple(original_addopts) not in {(), tuple(non_blacklist_plugin_args)}
117-
):
118-
with Path.open(pyproject_file, "w", encoding="utf-8") as f:
119-
f.write(original_content)
182+
for file, (content, was_modified) in original_content.items():
183+
if was_modified:
184+
with Path.open(file, "w", encoding="utf-8") as f:
185+
f.write(content)
120186

121187

122188
@contextmanager

codeflash/code_utils/config_parser.py

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,9 @@
55

66
import tomlkit
77

8+
PYPROJECT_TOML_CACHE = {}
9+
ALL_CONFIG_FILES = {} # map path to closest config file
10+
811

912
def find_pyproject_toml(config_file: Path | None = None) -> Path:
1013
# Find the pyproject.toml file on the root of the project
@@ -19,10 +22,15 @@ def find_pyproject_toml(config_file: Path | None = None) -> Path:
1922
raise ValueError(msg)
2023
return config_file
2124
dir_path = Path.cwd()
22-
25+
cur_path = dir_path
26+
# see if it was encountered before in search
27+
if cur_path in PYPROJECT_TOML_CACHE:
28+
return PYPROJECT_TOML_CACHE[cur_path]
29+
# map current path to closest file
2330
while dir_path != dir_path.parent:
2431
config_file = dir_path / "pyproject.toml"
2532
if config_file.exists():
33+
PYPROJECT_TOML_CACHE[cur_path] = config_file
2634
return config_file
2735
# Search for pyproject.toml in the parent directories
2836
dir_path = dir_path.parent
@@ -31,6 +39,33 @@ def find_pyproject_toml(config_file: Path | None = None) -> Path:
3139
raise ValueError(msg)
3240

3341

42+
def get_all_closest_config_files() -> list[Path]:
43+
all_closest_config_files = []
44+
for file_type in ["pyproject.toml", "pytest.ini", ".pytest.ini", "tox.ini", "setup.cfg"]:
45+
closest_config_file = find_closest_config_file(file_type)
46+
if closest_config_file:
47+
all_closest_config_files.append(closest_config_file)
48+
return all_closest_config_files
49+
50+
51+
def find_closest_config_file(file_type: str) -> Path | None:
52+
# Find the closest pyproject.toml, pytest.ini, tox.ini, or setup.cfg file on the root of the project
53+
dir_path = Path.cwd()
54+
cur_path = dir_path
55+
if cur_path in ALL_CONFIG_FILES and file_type in ALL_CONFIG_FILES[cur_path]:
56+
return ALL_CONFIG_FILES[cur_path][file_type]
57+
while dir_path != dir_path.parent:
58+
config_file = dir_path / file_type
59+
if config_file.exists():
60+
if cur_path not in ALL_CONFIG_FILES:
61+
ALL_CONFIG_FILES[cur_path] = {}
62+
ALL_CONFIG_FILES[cur_path][file_type] = config_file
63+
return config_file
64+
# Search for pyproject.toml in the parent directories
65+
dir_path = dir_path.parent
66+
return None
67+
68+
3469
def find_conftest_files(test_paths: list[Path]) -> list[Path]:
3570
list_of_conftest_files = set()
3671
for test_path in test_paths:

codeflash/discovery/discover_unit_tests.py

Lines changed: 38 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -67,10 +67,35 @@ class TestFunction:
6767

6868

6969
class TestsCache:
70+
SCHEMA_VERSION = 1 # Increment this when schema changes
71+
7072
def __init__(self, project_root_path: str | Path) -> None:
7173
self.project_root_path = Path(project_root_path).resolve().as_posix()
7274
self.connection = sqlite3.connect(codeflash_cache_db)
7375
self.cur = self.connection.cursor()
76+
77+
self.cur.execute(
78+
"""
79+
CREATE TABLE IF NOT EXISTS schema_version(
80+
version INTEGER PRIMARY KEY
81+
)
82+
"""
83+
)
84+
85+
self.cur.execute("SELECT version FROM schema_version")
86+
result = self.cur.fetchone()
87+
current_version = result[0] if result else None
88+
89+
if current_version != self.SCHEMA_VERSION:
90+
logger.debug(
91+
f"Schema version mismatch (current: {current_version}, expected: {self.SCHEMA_VERSION}). Recreating tables."
92+
)
93+
self.cur.execute("DROP TABLE IF EXISTS discovered_tests")
94+
self.cur.execute("DROP INDEX IF EXISTS idx_discovered_tests_project_file_path_hash")
95+
self.cur.execute("DELETE FROM schema_version")
96+
self.cur.execute("INSERT INTO schema_version (version) VALUES (?)", (self.SCHEMA_VERSION,))
97+
self.connection.commit()
98+
7499
self.cur.execute(
75100
"""
76101
CREATE TABLE IF NOT EXISTS discovered_tests(
@@ -158,14 +183,16 @@ def get_function_to_test_map_for_file(
158183
return result
159184

160185
@staticmethod
161-
def compute_file_hash(path: str | Path) -> str:
186+
def compute_file_hash(path: Path) -> str:
162187
h = hashlib.sha256(usedforsecurity=False)
163-
with Path(path).open("rb") as f:
188+
with path.open("rb", buffering=0) as f:
189+
buf = bytearray(8192)
190+
mv = memoryview(buf)
164191
while True:
165-
chunk = f.read(8192)
166-
if not chunk:
192+
n = f.readinto(mv)
193+
if n == 0:
167194
break
168-
h.update(chunk)
195+
h.update(mv[:n])
169196
return h.hexdigest()
170197

171198
def close(self) -> None:
@@ -488,13 +515,13 @@ def discover_tests_pytest(
488515

489516
def discover_tests_unittest(
490517
cfg: TestConfig,
491-
discover_only_these_tests: list[str] | None = None,
518+
discover_only_these_tests: list[Path] | None = None,
492519
functions_to_optimize: list[FunctionToOptimize] | None = None,
493520
) -> tuple[dict[str, set[FunctionCalledInTest]], int, int]:
494521
tests_root: Path = cfg.tests_root
495522
loader: unittest.TestLoader = unittest.TestLoader()
496523
tests: unittest.TestSuite = loader.discover(str(tests_root))
497-
file_to_test_map: defaultdict[str, list[TestsInFile]] = defaultdict(list)
524+
file_to_test_map: defaultdict[Path, list[TestsInFile]] = defaultdict(list)
498525

499526
def get_test_details(_test: unittest.TestCase) -> TestsInFile | None:
500527
_test_function, _test_module, _test_suite_name = (
@@ -506,7 +533,7 @@ def get_test_details(_test: unittest.TestCase) -> TestsInFile | None:
506533
_test_module_path = Path(_test_module.replace(".", os.sep)).with_suffix(".py")
507534
_test_module_path = tests_root / _test_module_path
508535
if not _test_module_path.exists() or (
509-
discover_only_these_tests and str(_test_module_path) not in discover_only_these_tests
536+
discover_only_these_tests and _test_module_path not in discover_only_these_tests
510537
):
511538
return None
512539
if "__replay_test" in str(_test_module_path):
@@ -516,10 +543,7 @@ def get_test_details(_test: unittest.TestCase) -> TestsInFile | None:
516543
else:
517544
test_type = TestType.EXISTING_UNIT_TEST
518545
return TestsInFile(
519-
test_file=str(_test_module_path),
520-
test_function=_test_function,
521-
test_type=test_type,
522-
test_class=_test_suite_name,
546+
test_file=_test_module_path, test_function=_test_function, test_type=test_type, test_class=_test_suite_name
523547
)
524548

525549
for _test_suite in tests._tests:
@@ -537,11 +561,11 @@ def get_test_details(_test: unittest.TestCase) -> TestsInFile | None:
537561
continue
538562
details = get_test_details(test_2)
539563
if details is not None:
540-
file_to_test_map[str(details.test_file)].append(details)
564+
file_to_test_map[details.test_file].append(details)
541565
else:
542566
details = get_test_details(test)
543567
if details is not None:
544-
file_to_test_map[str(details.test_file)].append(details)
568+
file_to_test_map[details.test_file].append(details)
545569
return process_test_files(file_to_test_map, cfg, functions_to_optimize)
546570

547571

0 commit comments

Comments
 (0)