Skip to content

Commit 7aee1c1

Browse files
author
Codeflash Bot
committed
helper for init choices
1 parent 491e616 commit 7aee1c1

File tree

1 file changed

+77
-40
lines changed

1 file changed

+77
-40
lines changed

codeflash/cli_cmds/cmd_init.py

Lines changed: 77 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,9 @@
66
import subprocess
77
import sys
88
from enum import Enum, auto
9+
from functools import lru_cache
910
from pathlib import Path
10-
from typing import TYPE_CHECKING, Any, Union, cast
11+
from typing import TYPE_CHECKING, Any, Optional, Union, cast
1112

1213
import click
1314
import git
@@ -214,17 +215,15 @@ def __init__(self) -> None:
214215
self.Checkbox.unselected_icon = "⬜"
215216

216217

217-
def collect_setup_info() -> SetupInfo:
218-
curdir = Path.cwd()
219-
# Check if the cwd is writable
220-
if not os.access(curdir, os.W_OK):
221-
click.echo(f"❌ The current directory isn't writable, please check your folder permissions and try again.{LF}")
222-
click.echo("It's likely you don't have write permissions for this folder.")
223-
sys.exit(1)
218+
# common sections between normal mode and lsp mode
219+
class CommonSections(Enum):
220+
module_root = "module_root"
221+
tests_root = "tests_root"
222+
test_framework = "test_framework"
224223

225-
# Check for the existence of pyproject.toml or setup.py
226-
project_name = check_for_toml_or_setup_file()
227224

225+
@lru_cache(maxsize=1)
226+
def get_valid_subdirs() -> list[str]:
228227
ignore_subdirs = [
229228
"venv",
230229
"node_modules",
@@ -237,11 +236,36 @@ def collect_setup_info() -> SetupInfo:
237236
"tmp",
238237
"__pycache__",
239238
]
240-
valid_subdirs = [
239+
return [
241240
d for d in next(os.walk("."))[1] if not d.startswith(".") and not d.startswith("__") and d not in ignore_subdirs
242241
]
243242

244-
valid_module_subdirs = [d for d in valid_subdirs if d != "tests"]
243+
244+
def get_suggestions(section: str) -> tuple(list[str], Optional[str]):
245+
valid_subdirs = get_valid_subdirs()
246+
if section == CommonSections.module_root:
247+
return [d for d in valid_subdirs if d != "tests"], None
248+
if section == CommonSections.tests_root:
249+
default = "tests" if "tests" in valid_subdirs else None
250+
return valid_subdirs, default
251+
if section == CommonSections.test_framework:
252+
auto_detected = detect_test_framework_from_config_files(Path.cwd())
253+
return ["pytest", "unittest"], auto_detected
254+
msg = f"Unknown section: {section}"
255+
raise ValueError(msg)
256+
257+
258+
def collect_setup_info() -> SetupInfo:
259+
curdir = Path.cwd()
260+
# Check if the cwd is writable
261+
if not os.access(curdir, os.W_OK):
262+
click.echo(f"❌ The current directory isn't writable, please check your folder permissions and try again.{LF}")
263+
click.echo("It's likely you don't have write permissions for this folder.")
264+
sys.exit(1)
265+
266+
# Check for the existence of pyproject.toml or setup.py
267+
project_name = check_for_toml_or_setup_file()
268+
valid_module_subdirs, _ = get_suggestions(CommonSections.module_root)
245269

246270
curdir_option = f"current directory ({curdir})"
247271
custom_dir_option = "enter a custom directory…"
@@ -305,10 +329,10 @@ def collect_setup_info() -> SetupInfo:
305329
ph("cli-project-root-provided")
306330

307331
# Discover test directory
308-
default_tests_subdir = "tests"
309332
create_for_me_option = f"🆕 Create a new tests{os.pathsep} directory for me!"
310-
test_subdir_options = [sub_dir for sub_dir in valid_subdirs if sub_dir != module_root]
311-
if "tests" not in valid_subdirs:
333+
tests_suggestions, default_tests_subdir = get_suggestions(CommonSections.tests_root)
334+
test_subdir_options = [sub_dir for sub_dir in tests_suggestions if sub_dir != module_root]
335+
if "tests" not in tests_suggestions:
312336
test_subdir_options.append(create_for_me_option)
313337
custom_dir_option = "📁 Enter a custom directory…"
314338
test_subdir_options.append(custom_dir_option)
@@ -331,7 +355,7 @@ def collect_setup_info() -> SetupInfo:
331355
"tests_root",
332356
message="Where are your tests located?",
333357
choices=test_subdir_options,
334-
default=(default_tests_subdir if default_tests_subdir in test_subdir_options else test_subdir_options[0]),
358+
default=(default_tests_subdir or test_subdir_options[0]),
335359
carousel=True,
336360
)
337361
]
@@ -382,7 +406,8 @@ def collect_setup_info() -> SetupInfo:
382406

383407
ph("cli-tests-root-provided")
384408

385-
autodetected_test_framework = detect_test_framework(curdir, tests_root)
409+
test_framework_choices, detected_framework = get_suggestions(CommonSections.test_framework)
410+
autodetected_test_framework = detected_framework or detect_test_framework_from_test_files(tests_root)
386411

387412
framework_message = "⚗️ Let's configure your test framework.\n\n"
388413
if autodetected_test_framework:
@@ -393,11 +418,19 @@ def collect_setup_info() -> SetupInfo:
393418
console.print(framework_panel)
394419
console.print()
395420

421+
framework_choices = []
422+
# add icons based on the detected framework
423+
for choice in test_framework_choices:
424+
if choice == "pytest":
425+
framework_choices.append(("🧪 pytest", "pytest"))
426+
elif choice == "unittest":
427+
framework_choices.append(("🐍 unittest", "unittest"))
428+
396429
framework_questions = [
397430
inquirer.List(
398431
"test_framework",
399432
message="Which test framework do you use?",
400-
choices=[("🧪 pytest", "pytest"), ("🐍 unittest", "unittest")],
433+
choices=framework_choices,
401434
default=autodetected_test_framework or "pytest",
402435
carousel=True,
403436
)
@@ -511,7 +544,7 @@ def collect_setup_info() -> SetupInfo:
511544
)
512545

513546

514-
def detect_test_framework(curdir: Path, tests_root: Path) -> str | None:
547+
def detect_test_framework_from_config_files(curdir: Path) -> Optional[str]:
515548
test_framework = None
516549
pytest_files = ["pytest.ini", "pyproject.toml", "tox.ini", "setup.cfg"]
517550
pytest_config_patterns = {
@@ -529,27 +562,31 @@ def detect_test_framework(curdir: Path, tests_root: Path) -> str | None:
529562
test_framework = "pytest"
530563
break
531564
test_framework = "pytest"
532-
else:
533-
# Check if any python files contain a class that inherits from unittest.TestCase
534-
for filename in tests_root.iterdir():
535-
if filename.suffix == ".py":
536-
with filename.open(encoding="utf8") as file:
537-
contents = file.read()
538-
try:
539-
node = ast.parse(contents)
540-
except SyntaxError:
541-
continue
542-
if any(
543-
isinstance(item, ast.ClassDef)
544-
and any(
545-
(isinstance(base, ast.Attribute) and base.attr == "TestCase")
546-
or (isinstance(base, ast.Name) and base.id == "TestCase")
547-
for base in item.bases
548-
)
549-
for item in node.body
550-
):
551-
test_framework = "unittest"
552-
break
565+
return test_framework
566+
567+
568+
def detect_test_framework_from_test_files(tests_root: Path) -> Optional[str]:
569+
test_framework = None
570+
# Check if any python files contain a class that inherits from unittest.TestCase
571+
for filename in tests_root.iterdir():
572+
if filename.suffix == ".py":
573+
with filename.open(encoding="utf8") as file:
574+
contents = file.read()
575+
try:
576+
node = ast.parse(contents)
577+
except SyntaxError:
578+
continue
579+
if any(
580+
isinstance(item, ast.ClassDef)
581+
and any(
582+
(isinstance(base, ast.Attribute) and base.attr == "TestCase")
583+
or (isinstance(base, ast.Name) and base.id == "TestCase")
584+
for base in item.bases
585+
)
586+
for item in node.body
587+
):
588+
test_framework = "unittest"
589+
break
553590
return test_framework
554591

555592

0 commit comments

Comments
 (0)