11import dataclasses
22import typing
3+ import inflection
34
45from misc .codegen .loaders import schemaloader
56from . import qlgen
@@ -15,19 +16,7 @@ class Param:
1516@dataclasses .dataclass
1617class Function :
1718 name : str
18- generic_params : list [Param ]
19- params : list [Param ]
20- return_type : str
21-
22- def __post_init__ (self ):
23- if self .generic_params :
24- self .generic_params [0 ].first = True
25- if self .params :
26- self .params [0 ].first = True
27-
28- @property
29- def has_generic_params (self ) -> bool :
30- return bool (self .generic_params )
19+ signature : str
3120
3221
3322@dataclasses .dataclass
@@ -48,27 +37,28 @@ def generate(opts, renderer):
4837 for cls in schema .classes .values ():
4938 if (qlgen .should_skip_qltest (cls , schema .classes ) or
5039 "rust_skip_test_from_doc" in cls .pragmas or
51- not cls .doc
52- ):
40+ not cls .doc ):
5341 continue
54- fn = cls .rust_doc_test_function
55- if fn :
56- generic_params = [Param (k , v ) for k , v in fn .params .items () if k [0 ].isupper () or k [0 ] == "'" ]
57- params = [Param (k , v ) for k , v in fn .params .items () if k [0 ].islower ()]
58- fn = Function (fn .name , generic_params , params , fn .return_type )
5942 code = []
6043 adding_code = False
44+ has_code = False
6145 for line in cls .doc :
6246 match line , adding_code :
6347 case "```" , _:
6448 adding_code = not adding_code
49+ has_code = True
6550 case _, False :
6651 code .append (f"// { line } " )
6752 case _, True :
6853 code .append (line )
54+ if not has_code :
55+ continue
56+ test_name = inflection .underscore (cls .name )
57+ signature = cls .rust_doc_test_function
58+ fn = signature and Function (f"test_{ test_name } " , signature )
6959 if fn :
7060 indent = 4 * " "
7161 code = [indent + l for l in code ]
7262 test_with = schema .classes [cls .test_with ] if cls .test_with else cls
73- test = opts .ql_test_output / test_with .group / test_with .name / f"gen_{ cls . name . lower () } .rs"
63+ test = opts .ql_test_output / test_with .group / test_with .name / f"gen_{ test_name } .rs"
7464 renderer .render (TestCode (code = "\n " .join (code ), function = fn ), test )
0 commit comments