66import subprocess
77import sys
88from enum import Enum , auto
9+ from functools import lru_cache
910from pathlib import Path
10- from typing import TYPE_CHECKING , Any , Union , cast
11+ from typing import TYPE_CHECKING , Any , Optional , Union , cast
1112
1213import click
1314import git
@@ -214,17 +215,15 @@ def __init__(self) -> None:
214215 self .Checkbox .unselected_icon = "⬜"
215216
216217
217- def collect_setup_info () -> SetupInfo :
218- curdir = Path .cwd ()
219- # Check if the cwd is writable
220- if not os .access (curdir , os .W_OK ):
221- click .echo (f"❌ The current directory isn't writable, please check your folder permissions and try again.{ LF } " )
222- click .echo ("It's likely you don't have write permissions for this folder." )
223- sys .exit (1 )
218+ # common sections between normal mode and lsp mode
219+ class CommonSections (Enum ):
220+ module_root = "module_root"
221+ tests_root = "tests_root"
222+ test_framework = "test_framework"
224223
225- # Check for the existence of pyproject.toml or setup.py
226- project_name = check_for_toml_or_setup_file ()
227224
225+ @lru_cache (maxsize = 1 )
226+ def get_valid_subdirs () -> list [str ]:
228227 ignore_subdirs = [
229228 "venv" ,
230229 "node_modules" ,
@@ -237,11 +236,36 @@ def collect_setup_info() -> SetupInfo:
237236 "tmp" ,
238237 "__pycache__" ,
239238 ]
240- valid_subdirs = [
239+ return [
241240 d for d in next (os .walk ("." ))[1 ] if not d .startswith ("." ) and not d .startswith ("__" ) and d not in ignore_subdirs
242241 ]
243242
244- valid_module_subdirs = [d for d in valid_subdirs if d != "tests" ]
243+
244+ def get_suggestions (section : str ) -> tuple (list [str ], Optional [str ]):
245+ valid_subdirs = get_valid_subdirs ()
246+ if section == CommonSections .module_root :
247+ return [d for d in valid_subdirs if d != "tests" ], None
248+ if section == CommonSections .tests_root :
249+ default = "tests" if "tests" in valid_subdirs else None
250+ return valid_subdirs , default
251+ if section == CommonSections .test_framework :
252+ auto_detected = detect_test_framework_from_config_files (Path .cwd ())
253+ return ["pytest" , "unittest" ], auto_detected
254+ msg = f"Unknown section: { section } "
255+ raise ValueError (msg )
256+
257+
258+ def collect_setup_info () -> SetupInfo :
259+ curdir = Path .cwd ()
260+ # Check if the cwd is writable
261+ if not os .access (curdir , os .W_OK ):
262+ click .echo (f"❌ The current directory isn't writable, please check your folder permissions and try again.{ LF } " )
263+ click .echo ("It's likely you don't have write permissions for this folder." )
264+ sys .exit (1 )
265+
266+ # Check for the existence of pyproject.toml or setup.py
267+ project_name = check_for_toml_or_setup_file ()
268+ valid_module_subdirs , _ = get_suggestions (CommonSections .module_root )
245269
246270 curdir_option = f"current directory ({ curdir } )"
247271 custom_dir_option = "enter a custom directory…"
@@ -305,10 +329,10 @@ def collect_setup_info() -> SetupInfo:
305329 ph ("cli-project-root-provided" )
306330
307331 # Discover test directory
308- default_tests_subdir = "tests"
309332 create_for_me_option = f"🆕 Create a new tests{ os .pathsep } directory for me!"
310- test_subdir_options = [sub_dir for sub_dir in valid_subdirs if sub_dir != module_root ]
311- if "tests" not in valid_subdirs :
333+ tests_suggestions , default_tests_subdir = get_suggestions (CommonSections .tests_root )
334+ test_subdir_options = [sub_dir for sub_dir in tests_suggestions if sub_dir != module_root ]
335+ if "tests" not in tests_suggestions :
312336 test_subdir_options .append (create_for_me_option )
313337 custom_dir_option = "📁 Enter a custom directory…"
314338 test_subdir_options .append (custom_dir_option )
@@ -331,7 +355,7 @@ def collect_setup_info() -> SetupInfo:
331355 "tests_root" ,
332356 message = "Where are your tests located?" ,
333357 choices = test_subdir_options ,
334- default = (default_tests_subdir if default_tests_subdir in test_subdir_options else test_subdir_options [0 ]),
358+ default = (default_tests_subdir or test_subdir_options [0 ]),
335359 carousel = True ,
336360 )
337361 ]
@@ -382,7 +406,8 @@ def collect_setup_info() -> SetupInfo:
382406
383407 ph ("cli-tests-root-provided" )
384408
385- autodetected_test_framework = detect_test_framework (curdir , tests_root )
409+ test_framework_choices , detected_framework = get_suggestions (CommonSections .test_framework )
410+ autodetected_test_framework = detected_framework or detect_test_framework_from_test_files (tests_root )
386411
387412 framework_message = "⚗️ Let's configure your test framework.\n \n "
388413 if autodetected_test_framework :
@@ -393,11 +418,19 @@ def collect_setup_info() -> SetupInfo:
393418 console .print (framework_panel )
394419 console .print ()
395420
421+ framework_choices = []
422+ # add icons based on the detected framework
423+ for choice in test_framework_choices :
424+ if choice == "pytest" :
425+ framework_choices .append (("🧪 pytest" , "pytest" ))
426+ elif choice == "unittest" :
427+ framework_choices .append (("🐍 unittest" , "unittest" ))
428+
396429 framework_questions = [
397430 inquirer .List (
398431 "test_framework" ,
399432 message = "Which test framework do you use?" ,
400- choices = [( "🧪 pytest" , "pytest" ), ( "🐍 unittest" , "unittest" )] ,
433+ choices = framework_choices ,
401434 default = autodetected_test_framework or "pytest" ,
402435 carousel = True ,
403436 )
@@ -511,7 +544,7 @@ def collect_setup_info() -> SetupInfo:
511544 )
512545
513546
514- def detect_test_framework (curdir : Path , tests_root : Path ) -> str | None :
547+ def detect_test_framework_from_config_files (curdir : Path ) -> Optional [ str ] :
515548 test_framework = None
516549 pytest_files = ["pytest.ini" , "pyproject.toml" , "tox.ini" , "setup.cfg" ]
517550 pytest_config_patterns = {
@@ -529,27 +562,31 @@ def detect_test_framework(curdir: Path, tests_root: Path) -> str | None:
529562 test_framework = "pytest"
530563 break
531564 test_framework = "pytest"
532- else :
533- # Check if any python files contain a class that inherits from unittest.TestCase
534- for filename in tests_root .iterdir ():
535- if filename .suffix == ".py" :
536- with filename .open (encoding = "utf8" ) as file :
537- contents = file .read ()
538- try :
539- node = ast .parse (contents )
540- except SyntaxError :
541- continue
542- if any (
543- isinstance (item , ast .ClassDef )
544- and any (
545- (isinstance (base , ast .Attribute ) and base .attr == "TestCase" )
546- or (isinstance (base , ast .Name ) and base .id == "TestCase" )
547- for base in item .bases
548- )
549- for item in node .body
550- ):
551- test_framework = "unittest"
552- break
565+ return test_framework
566+
567+
568+ def detect_test_framework_from_test_files (tests_root : Path ) -> Optional [str ]:
569+ test_framework = None
570+ # Check if any python files contain a class that inherits from unittest.TestCase
571+ for filename in tests_root .iterdir ():
572+ if filename .suffix == ".py" :
573+ with filename .open (encoding = "utf8" ) as file :
574+ contents = file .read ()
575+ try :
576+ node = ast .parse (contents )
577+ except SyntaxError :
578+ continue
579+ if any (
580+ isinstance (item , ast .ClassDef )
581+ and any (
582+ (isinstance (base , ast .Attribute ) and base .attr == "TestCase" )
583+ or (isinstance (base , ast .Name ) and base .id == "TestCase" )
584+ for base in item .bases
585+ )
586+ for item in node .body
587+ ):
588+ test_framework = "unittest"
589+ break
553590 return test_framework
554591
555592
0 commit comments