@@ -1589,16 +1589,20 @@ class ScanPathResult:
15891589 has_run_registered_test_suites Whether or not the path contained at
15901590 least one call to
15911591 ztest_run_registered_test_suites.
1592+ has_test_main Whether or not the path contains a
1593+ definition of test_main(void)
15921594 """
15931595 def __init__ (self ,
15941596 matches : List [str ] = None ,
15951597 warnings : str = None ,
15961598 has_registered_test_suites : bool = False ,
1597- has_run_registered_test_suites : bool = False ):
1599+ has_run_registered_test_suites : bool = False ,
1600+ has_test_main : bool = False ):
15981601 self .matches = matches
15991602 self .warnings = warnings
16001603 self .has_registered_test_suites = has_registered_test_suites
16011604 self .has_run_registered_test_suites = has_run_registered_test_suites
1605+ self .has_test_main = has_test_main
16021606
16031607 def __eq__ (self , other ):
16041608 if not isinstance (other , ScanPathResult ):
@@ -1608,7 +1612,8 @@ def __eq__(self, other):
16081612 (self .has_registered_test_suites ==
16091613 other .has_registered_test_suites ) and
16101614 (self .has_run_registered_test_suites ==
1611- other .has_run_registered_test_suites ))
1615+ other .has_run_registered_test_suites ) and
1616+ self .has_test_main == other .has_test_main )
16121617
16131618
16141619class TestCase (DisablePyTestCollectionMixin ):
@@ -1701,6 +1706,15 @@ def scan_file(inf_name):
17011706 br"^\s*ztest_register_test_suite"
17021707 br"\(\s*(?P<suite_name>[a-zA-Z0-9_]+)\s*," ,
17031708 re .MULTILINE )
1709+ # Checks if the file contains a definition of "void test_main(void)"
1710+ # Since ztest provides a plain test_main implementation it is OK to:
1711+ # 1. register test suites and not call the run function iff the test
1712+ # doesn't have a custom test_main.
1713+ # 2. register test suites and a custom test_main definition iff the test
1714+ # also calls ztest_run_registered_test_suites.
1715+ test_main_regex = re .compile (
1716+ br"^\s*void\s+test_main\(void\)" ,
1717+ re .MULTILINE )
17041718 stc_regex = re .compile (
17051719 br"""^\s* # empy space at the beginning is ok
17061720 # catch the case where it is declared in the same sentence, e.g:
@@ -1733,6 +1747,7 @@ def scan_file(inf_name):
17331747 warnings = None
17341748 has_registered_test_suites = False
17351749 has_run_registered_test_suites = False
1750+ has_test_main = False
17361751
17371752 with open (inf_name ) as inf :
17381753 if os .name == 'nt' :
@@ -1750,6 +1765,8 @@ def scan_file(inf_name):
17501765 has_registered_test_suites = True
17511766 if registered_suite_run_regex .search (main_c ):
17521767 has_run_registered_test_suites = True
1768+ if test_main_regex .search (main_c ):
1769+ has_test_main = True
17531770
17541771 if not suite_regex_match and not has_registered_test_suites :
17551772 # can't find ztest_test_suite, maybe a client, because
@@ -1758,7 +1775,8 @@ def scan_file(inf_name):
17581775 matches = None ,
17591776 warnings = None ,
17601777 has_registered_test_suites = has_registered_test_suites ,
1761- has_run_registered_test_suites = has_run_registered_test_suites )
1778+ has_run_registered_test_suites = has_run_registered_test_suites ,
1779+ has_test_main = has_test_main )
17621780
17631781 suite_run_match = suite_run_regex .search (main_c )
17641782 if suite_regex_match and not suite_run_match :
@@ -1792,12 +1810,14 @@ def scan_file(inf_name):
17921810 matches = matches ,
17931811 warnings = warnings ,
17941812 has_registered_test_suites = has_registered_test_suites ,
1795- has_run_registered_test_suites = has_run_registered_test_suites )
1813+ has_run_registered_test_suites = has_run_registered_test_suites ,
1814+ has_test_main = has_test_main )
17961815
17971816 def scan_path (self , path ):
17981817 subcases = []
17991818 has_registered_test_suites = False
18001819 has_run_registered_test_suites = False
1820+ has_test_main = False
18011821 for filename in glob .glob (os .path .join (path , "src" , "*.c*" )):
18021822 try :
18031823 result : ScanPathResult = self .scan_file (filename )
@@ -1811,6 +1831,8 @@ def scan_path(self, path):
18111831 has_registered_test_suites = True
18121832 if result .has_run_registered_test_suites :
18131833 has_run_registered_test_suites = True
1834+ if result .has_test_main :
1835+ has_test_main = True
18141836 except ValueError as e :
18151837 logger .error ("%s: can't find: %s" % (filename , e ))
18161838
@@ -1824,7 +1846,8 @@ def scan_path(self, path):
18241846 except ValueError as e :
18251847 logger .error ("%s: can't find: %s" % (filename , e ))
18261848
1827- if has_registered_test_suites and not has_run_registered_test_suites :
1849+ if (has_registered_test_suites and has_test_main and
1850+ not has_run_registered_test_suites ):
18281851 warning = \
18291852 "Found call to 'ztest_register_test_suite()' but no " \
18301853 "call to 'ztest_run_registered_test_suites()'"
0 commit comments