Skip to content

Commit 459b5cc

Browse files
author
Codeflash Bot
committed
unit tests
1 parent f626277 commit 459b5cc

File tree

2 files changed

+201
-7
lines changed

2 files changed

+201
-7
lines changed

codeflash/cli_cmds/cmd_init.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@
5353

5454

5555
@dataclass(frozen=True)
56-
class SetupInfo:
56+
class CLISetupInfo:
5757
module_root: str
5858
tests_root: str
5959
benchmarks_root: Union[str, None]
@@ -101,7 +101,7 @@ def init_codeflash() -> None:
101101
git_remote = config.get("git_remote", "origin") if config else "origin"
102102

103103
if should_modify:
104-
setup_info: SetupInfo = collect_setup_info()
104+
setup_info: CLISetupInfo = collect_setup_info()
105105
git_remote = setup_info.git_remote
106106
configured = configure_pyproject_toml(setup_info)
107107
if not configured:
@@ -240,7 +240,7 @@ def get_toml_key(self) -> str:
240240

241241

242242
@lru_cache(maxsize=1)
243-
def get_valid_subdirs() -> list[str]:
243+
def get_valid_subdirs(current_dir: Optional[Path] = None) -> list[str]:
244244
ignore_subdirs = [
245245
"venv",
246246
"node_modules",
@@ -253,8 +253,11 @@ def get_valid_subdirs() -> list[str]:
253253
"tmp",
254254
"__pycache__",
255255
]
256+
path_str = str(current_dir) if current_dir else "."
256257
return [
257-
d for d in next(os.walk("."))[1] if not d.startswith(".") and not d.startswith("__") and d not in ignore_subdirs
258+
d
259+
for d in next(os.walk(path_str))[1]
260+
if not d.startswith(".") and not d.startswith("__") and d not in ignore_subdirs
258261
]
259262

260263

@@ -274,7 +277,7 @@ def get_suggestions(section: str) -> tuple(list[str], Optional[str]):
274277
raise ValueError(msg)
275278

276279

277-
def collect_setup_info() -> SetupInfo:
280+
def collect_setup_info() -> CLISetupInfo:
278281
curdir = Path.cwd()
279282
# Check if the cwd is writable
280283
if not os.access(curdir, os.W_OK):
@@ -554,7 +557,7 @@ def collect_setup_info() -> SetupInfo:
554557
enable_telemetry = ask_for_telemetry()
555558

556559
ignore_paths: list[str] = []
557-
return SetupInfo(
560+
return CLISetupInfo(
558561
module_root=str(module_root),
559562
tests_root=str(tests_root),
560563
benchmarks_root=str(benchmarks_root) if benchmarks_root else None,
@@ -1003,7 +1006,9 @@ def get_formatter_cmds(formatter: str) -> list[str]:
10031006

10041007

10051008
# Create or update the pyproject.toml file with the Codeflash dependency & configuration
1006-
def configure_pyproject_toml(setup_info: Union[VsCodeSetupInfo, SetupInfo], config_file: Optional[Path] = None) -> bool:
1009+
def configure_pyproject_toml(
1010+
setup_info: Union[VsCodeSetupInfo, CLISetupInfo], config_file: Optional[Path] = None
1011+
) -> bool:
10071012
for_vscode = isinstance(setup_info, VsCodeSetupInfo)
10081013
toml_path = config_file or Path.cwd() / "pyproject.toml"
10091014
try:

tests/test_cmd_init.py

Lines changed: 189 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,189 @@
1+
import pytest
2+
import tempfile
3+
from pathlib import Path
4+
from codeflash.cli_cmds.cmd_init import (
5+
is_valid_pyproject_toml,
6+
configure_pyproject_toml,
7+
CLISetupInfo,
8+
get_formatter_cmds,
9+
VsCodeSetupInfo,
10+
get_valid_subdirs,
11+
)
12+
import os
13+
14+
15+
@pytest.fixture
16+
def temp_dir():
17+
with tempfile.TemporaryDirectory() as tmpdirname:
18+
yield Path(tmpdirname)
19+
20+
21+
def test_is_valid_pyproject_toml_with_empty_config(temp_dir: Path) -> None:
22+
with (temp_dir / "pyproject.toml").open(mode="w") as f:
23+
f.write(
24+
"""[tool.codeflash]
25+
"""
26+
)
27+
f.flush()
28+
valid, _, _message = is_valid_pyproject_toml(temp_dir / "pyproject.toml")
29+
assert not valid
30+
assert _message == "Missing required field: 'module_root'"
31+
32+
def test_is_valid_pyproject_toml_with_incorrect_module_root(temp_dir: Path) -> None:
33+
with (temp_dir / "pyproject.toml").open(mode="w") as f:
34+
wrong_module_root = temp_dir / "invalid_directory"
35+
f.write(
36+
f"""[tool.codeflash]
37+
module-root = "invalid_directory"
38+
"""
39+
)
40+
f.flush()
41+
valid, config, _message = is_valid_pyproject_toml(temp_dir / "pyproject.toml")
42+
assert not valid
43+
assert _message == f"Invalid 'module_root': directory does not exist at {wrong_module_root}"
44+
45+
46+
def test_is_valid_pyproject_toml_with_incorrect_tests_root(temp_dir: Path) -> None:
47+
with (temp_dir / "pyproject.toml").open(mode="w") as f:
48+
wrong_tests_root = temp_dir / "incorrect_tests_root"
49+
f.write(
50+
f"""[tool.codeflash]
51+
module-root = "."
52+
tests-root = "incorrect_tests_root"
53+
"""
54+
)
55+
f.flush()
56+
valid, config, _message = is_valid_pyproject_toml(temp_dir / "pyproject.toml")
57+
assert not valid
58+
assert _message == f"Invalid 'tests_root': directory does not exist at {wrong_tests_root}"
59+
60+
61+
def test_is_valid_pyproject_toml_with_valid_config(temp_dir: Path) -> None:
62+
with (temp_dir / "pyproject.toml").open(mode="w") as f:
63+
os.makedirs(temp_dir / "tests")
64+
f.write(
65+
"""[tool.codeflash]
66+
module-root = "."
67+
tests-root = "tests"
68+
test-framework = "pytest"
69+
"""
70+
)
71+
f.flush()
72+
valid, config, _message = is_valid_pyproject_toml(temp_dir / "pyproject.toml")
73+
assert valid
74+
75+
def test_get_formatter_cmd(temp_dir: Path) -> None:
76+
assert get_formatter_cmds("black") == ["black $file"]
77+
assert get_formatter_cmds("ruff") == ["ruff check --exit-zero --fix $file", "ruff format $file"]
78+
assert get_formatter_cmds("disabled") == ["disabled"]
79+
assert get_formatter_cmds("don't use a formatter") == ["disabled"]
80+
81+
def test_configure_pyproject_toml_for_cli(temp_dir: Path) -> None:
82+
83+
pyproject_path = temp_dir / "pyproject.toml"
84+
85+
with (pyproject_path).open(mode="w") as f:
86+
f.write("")
87+
f.flush()
88+
os.mkdir(temp_dir / "tests")
89+
config = CLISetupInfo(
90+
module_root=".",
91+
tests_root="tests",
92+
benchmarks_root=None,
93+
test_framework="pytest",
94+
ignore_paths=[],
95+
formatter="black",
96+
git_remote="origin",
97+
enable_telemetry=False,
98+
)
99+
100+
success = configure_pyproject_toml(config, pyproject_path)
101+
assert success
102+
103+
config_content = pyproject_path.read_text()
104+
assert """[tool.codeflash]
105+
# All paths are relative to this pyproject.toml's directory.
106+
module-root = "."
107+
tests-root = "tests"
108+
test-framework = "pytest"
109+
ignore-paths = []
110+
disable-telemetry = true
111+
formatter-cmds = ["black $file"]
112+
""" == config_content
113+
valid, _, _ = is_valid_pyproject_toml(pyproject_path)
114+
assert valid
115+
116+
def test_configure_pyproject_toml_for_vscode_with_empty_config(temp_dir: Path) -> None:
117+
118+
pyproject_path = temp_dir / "pyproject.toml"
119+
120+
with (pyproject_path).open(mode="w") as f:
121+
f.write("")
122+
f.flush()
123+
os.mkdir(temp_dir / "tests")
124+
config = VsCodeSetupInfo(
125+
module_root=".",
126+
tests_root="tests",
127+
test_framework="pytest",
128+
formatter="black",
129+
)
130+
131+
success = configure_pyproject_toml(config, pyproject_path)
132+
assert success
133+
134+
config_content = pyproject_path.read_text()
135+
assert """[tool.codeflash]
136+
module-root = "."
137+
tests-root = "tests"
138+
test-framework = "pytest"
139+
formatter-cmds = ["black $file"]
140+
""" == config_content
141+
valid, _, _ = is_valid_pyproject_toml(pyproject_path)
142+
assert valid
143+
144+
def test_configure_pyproject_toml_for_vscode_with_existing_config(temp_dir: Path) -> None:
145+
pyproject_path = temp_dir / "pyproject.toml"
146+
147+
with (pyproject_path).open(mode="w") as f:
148+
f.write("""[tool.codeflash]
149+
module-root = "codeflash"
150+
tests-root = "tests"
151+
benchmarks-root = "tests/benchmarks"
152+
test-framework = "pytest"
153+
formatter-cmds = ["disabled"]
154+
""")
155+
f.flush()
156+
os.mkdir(temp_dir / "tests")
157+
config = VsCodeSetupInfo(
158+
module_root=".",
159+
tests_root="tests",
160+
test_framework="pytest",
161+
formatter="disabled",
162+
)
163+
164+
success = configure_pyproject_toml(config, pyproject_path)
165+
assert success
166+
167+
config_content = pyproject_path.read_text()
168+
# the benchmarks-root shouldn't get overwritten
169+
assert """[tool.codeflash]
170+
module-root = "."
171+
tests-root = "tests"
172+
benchmarks-root = "tests/benchmarks"
173+
test-framework = "pytest"
174+
formatter-cmds = ["disabled"]
175+
""" == config_content
176+
valid, _, _ = is_valid_pyproject_toml(pyproject_path)
177+
assert valid
178+
179+
def test_get_valid_subdirs(temp_dir: Path) -> None:
180+
os.mkdir(temp_dir / "dir1")
181+
os.mkdir(temp_dir / "dir2")
182+
os.mkdir(temp_dir / "__pycache__")
183+
os.mkdir(temp_dir / ".git")
184+
os.mkdir(temp_dir / "tests")
185+
186+
dirs = get_valid_subdirs(temp_dir)
187+
assert "tests" in dirs
188+
assert "dir1" in dirs
189+
assert "dir2" in dirs

0 commit comments

Comments
 (0)