Skip to content

Commit 8c5cc2e

Browse files
author
Paolo Tranquilli
committed
Rust: generate test code from schema docstrings
This generates test source files from code blocks in class docstrings. By default the test code is generated as is, but it can optionally: * be wrapped in a function providing an adequate context using `@rust.doc_test_function(name, *, lifetimes=(), return_type="()", **kwargs)`, with `kwargs` providing both generic and normal params depending on capitalization * be skipped altogether using `@rust.skip_doc_test` So for example an annotation like ```python @rust.doc_test_function("foo", lifetimes=("a",), T="Eq", x="&'a T", y="&'a T", return_type="&'a T") ``` will result in the following wrapper: ```rust fn foo<'a, T: Eq>(x: &'a T, y: &'a T) -> &'a T { // example code here } ```
1 parent 122e5a7 commit 8c5cc2e

File tree

19 files changed

+199
-35
lines changed

19 files changed

+199
-35
lines changed

misc/codegen/generators/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from . import dbschemegen, qlgen, trapgen, cppgen, rustgen
1+
from . import dbschemegen, trapgen, cppgen, rustgen, rusttestgen, qlgen
22

33

44
def generate(target, opts, renderer):

misc/codegen/generators/qlgen.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -287,7 +287,7 @@ def _is_under_qltest_collapsed_hierarchy(cls: schema.Class, lookup: typing.Dict[
287287
_is_in_qltest_collapsed_hierarchy(lookup[b], lookup) for b in cls.bases)
288288

289289

290-
def _should_skip_qltest(cls: schema.Class, lookup: typing.Dict[str, schema.Class]):
290+
def should_skip_qltest(cls: schema.Class, lookup: typing.Dict[str, schema.Class]):
291291
return "qltest_skip" in cls.pragmas or not (
292292
cls.final or "qltest_collapse_hierarchy" in cls.pragmas) or _is_under_qltest_collapsed_hierarchy(
293293
cls, lookup)
@@ -413,7 +413,7 @@ def generate(opts, renderer):
413413

414414
if test_out:
415415
for c in data.classes.values():
416-
if _should_skip_qltest(c, data.classes):
416+
if should_skip_qltest(c, data.classes):
417417
continue
418418
test_with = data.classes[c.test_with] if c.test_with else c
419419
test_dir = test_out / test_with.group / test_with.name

misc/codegen/generators/rustgen.py

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -86,20 +86,24 @@ def generate(opts, renderer):
8686
processor = Processor(schemaloader.load_file(opts.schema))
8787
out = opts.rust_output
8888
groups = set()
89-
for group, classes in processor.get_classes().items():
90-
group = group or "top"
91-
groups.add(group)
89+
with renderer.manage(generated=out.rglob("*.rs"),
90+
stubs=(),
91+
registry=opts.generated_registry,
92+
force=opts.force) as renderer:
93+
for group, classes in processor.get_classes().items():
94+
group = group or "top"
95+
groups.add(group)
96+
renderer.render(
97+
rust.ClassList(
98+
classes,
99+
opts.schema,
100+
),
101+
out / f"{group}.rs",
102+
)
92103
renderer.render(
93-
rust.ClassList(
94-
classes,
104+
rust.ModuleList(
105+
groups,
95106
opts.schema,
96107
),
97-
out / f"{group}.rs",
108+
out / f"mod.rs",
98109
)
99-
renderer.render(
100-
rust.ModuleList(
101-
groups,
102-
opts.schema,
103-
),
104-
out / f"mod.rs",
105-
)
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
import dataclasses
2+
import typing
3+
4+
from misc.codegen.loaders import schemaloader
5+
from . import qlgen
6+
7+
8+
@dataclasses.dataclass
9+
class Param:
10+
name: str
11+
type: str
12+
first: bool = False
13+
14+
15+
@dataclasses.dataclass
16+
class Function:
17+
name: str
18+
generic_params: list[Param]
19+
params: list[Param]
20+
return_type: str
21+
22+
def __post_init__(self):
23+
if self.generic_params:
24+
self.generic_params[0].first = True
25+
if self.params:
26+
self.params[0].first = True
27+
28+
29+
@dataclasses.dataclass
30+
class TestCode:
31+
template: typing.ClassVar[str] = "rust_test_code"
32+
33+
code: str
34+
function: Function | None = None
35+
36+
37+
def generate(opts, renderer):
38+
assert opts.ql_test_output
39+
schema = schemaloader.load_file(opts.schema)
40+
with renderer.manage(generated=opts.ql_test_output.rglob("gen_*.rs"),
41+
stubs=(),
42+
registry=opts.generated_registry,
43+
force=opts.force) as renderer:
44+
for cls in schema.classes.values():
45+
if (qlgen.should_skip_qltest(cls, schema.classes) or
46+
"rust_skip_test_from_doc" in cls.pragmas or
47+
not cls.doc
48+
):
49+
continue
50+
fn = cls.rust_doc_test_function
51+
if fn:
52+
generic_params = [Param(k, v) for k, v in fn.params.items() if k[0].isupper() or k[0] == "'"]
53+
params = [Param(k, v) for k, v in fn.params.items() if k[0].islower()]
54+
fn = Function(fn.name, generic_params, params, fn.return_type)
55+
code = []
56+
adding_code = False
57+
for line in cls.doc:
58+
match line, adding_code:
59+
case "```", _:
60+
adding_code = not adding_code
61+
case _, False:
62+
code.append(f"// {line}")
63+
case _, True:
64+
code.append(line)
65+
if fn:
66+
indent = 4 * " "
67+
code = [indent + l for l in code]
68+
test_with = schema.classes[cls.test_with] if cls.test_with else cls
69+
test = opts.ql_test_output / test_with.group / test_with.name / f"gen_{cls.name.lower()}.rs"
70+
renderer.render(TestCode(code="\n".join(code), function=fn), test)

misc/codegen/lib/schema.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,7 @@ class Class:
9494
default_doc_name: Optional[str] = None
9595
hideable: bool = False
9696
test_with: Optional[str] = None
97+
rust_doc_test_function: Optional["FunctionInfo"] = None # TODO: parametrized pragmas
9798

9899
@property
99100
def final(self):
@@ -202,3 +203,10 @@ def split_doc(doc):
202203
while trimmed and not trimmed[0]:
203204
trimmed.pop(0)
204205
return trimmed
206+
207+
208+
@dataclass
209+
class FunctionInfo:
210+
name: str
211+
params: dict[str, str]
212+
return_type: str

misc/codegen/lib/schemadefs.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ def modify(self, prop: _schema.Property):
5252
qltest = _Namespace()
5353
ql = _Namespace()
5454
cpp = _Namespace()
55+
rust = _Namespace()
5556
synth = _SynthModifier()
5657

5758

@@ -156,6 +157,14 @@ def f(cls: type) -> type:
156157

157158
_Pragma("cpp_skip")
158159

160+
_Pragma("rust_skip_doc_test")
161+
162+
rust.doc_test_function = lambda name, *, lifetimes=(), return_type="()", **kwargs: _annotate(
163+
rust_doc_test_function=_schema.FunctionInfo(name,
164+
params={f"'{lifetime}": "" for lifetime in lifetimes} | kwargs,
165+
return_type=return_type)
166+
)
167+
159168

160169
def group(name: str = "") -> _ClassDecorator:
161170
return _annotate(group=name)

misc/codegen/loaders/schemaloader.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ def _get_class(cls: type) -> schema.Class:
5656
],
5757
doc=schema.split_doc(cls.__doc__),
5858
default_doc_name=cls.__dict__.get("_doc_name"),
59+
rust_doc_test_function=cls.__dict__.get("_rust_doc_test_function")
5960
)
6061

6162

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
// generated by {{generator}}
2+
3+
{{#function}}
4+
fn {{name}}<{{#generic_params}}{{^first}}, {{/first}}{{name}}{{#type}}: {{.}}{{/type}}{{/generic_params}}>({{#params}}{{^first}}, {{/first}}{{name}}: {{type}}{{/params}}) -> {{return_type}} {
5+
{{/function}}
6+
{{code}}
7+
{{#function}}
8+
}
9+
{{/function}}

rust/.generated.list

Lines changed: 5 additions & 7 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

rust/.gitattributes

Lines changed: 3 additions & 5 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)