Skip to content

Commit 025a6f3

Browse files
committed
refactor!: Remove invalid invalid_expr_as_ellipses from printer
I don't like the fact I'm modifying the module import in printer
1 parent 671e114 commit 025a6f3

File tree

3 files changed

+45
-38
lines changed

3 files changed

+45
-38
lines changed

pybind11_stubgen/__init__.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -249,11 +249,6 @@ def stub_parser_from_args(args: CLIArgs) -> IParser:
249249
),
250250
]
251251

252-
if args.print_invalid_expressions_as_is:
253-
wrap_invalid_expressions = []
254-
else:
255-
wrap_invalid_expressions = [WrapInvalidExpressions]
256-
257252
class Parser(
258253
*error_handlers_top, # type: ignore[misc]
259254
FixMissing__future__AnnotationsImport,
@@ -282,7 +277,7 @@ class Parser(
282277
FixRedundantMethodsFromBuiltinObject,
283278
RemoveSelfAnnotation,
284279
FixPybind11EnumStrDoc,
285-
*wrap_invalid_expressions,
280+
WrapInvalidExpressions,
286281
ExtractSignaturesFromPybind11Docstrings,
287282
ParserDispatchMixin,
288283
BaseParser,
@@ -313,7 +308,7 @@ def main():
313308
args = arg_parser().parse_args(namespace=CLIArgs())
314309

315310
parser = stub_parser_from_args(args)
316-
printer = Printer(invalid_expr_as_ellipses=not args.print_invalid_expressions_as_is)
311+
printer = Printer()
317312

318313
out_dir, sub_dir = to_output_and_subdir(
319314
output_dir=args.output_dir,

pybind11_stubgen/printer.py

Lines changed: 36 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
Modifier,
2020
Module,
2121
Property,
22+
QualifiedName,
2223
ResolvedType,
2324
TypeVar_,
2425
Value,
@@ -30,8 +31,8 @@ def indent_lines(lines: list[str], by=4) -> list[str]:
3031

3132

3233
class Printer:
33-
def __init__(self, invalid_expr_as_ellipses: bool):
34-
self.invalid_expr_as_ellipses = invalid_expr_as_ellipses
34+
def __init__(self):
35+
self._need_typing_ext = False
3536

3637
def print_alias(self, alias: Alias) -> list[str]:
3738
return [f"{alias.name} = {alias.origin}"]
@@ -43,13 +44,8 @@ def print_attribute(self, attr: Attribute) -> list[str]:
4344
if attr.annotation is not None:
4445
parts.append(f": {self.print_annotation(attr.annotation)}")
4546

46-
if attr.value is not None and attr.value.is_print_safe:
47+
if attr.value is not None:
4748
parts.append(f" = {self.print_value(attr.value)}")
48-
else:
49-
if attr.annotation is None:
50-
parts.append(" = ...")
51-
if attr.value is not None:
52-
parts.append(f" # value = {self.print_value(attr.value)}")
5349

5450
return ["".join(parts)]
5551

@@ -202,40 +198,51 @@ def print_method(self, method: Method) -> list[str]:
202198
return result
203199

204200
def print_module(self, module: Module) -> list[str]:
205-
result = []
206-
207-
if module.doc is not None:
208-
result.extend(self.print_docstring(module.doc))
209-
210-
for import_ in sorted(module.imports, key=lambda x: x.origin):
211-
result.extend(self.print_import(import_))
201+
result_bottom = []
202+
tmp = self._need_typing_ext
212203

213204
for sub_module in module.sub_modules:
214-
result.extend(self.print_submodule_import(sub_module.name))
205+
result_bottom.extend(self.print_submodule_import(sub_module.name))
215206

216207
# Place __all__ above everything
217208
for attr in sorted(module.attributes, key=lambda a: a.name):
218209
if attr.name == "__all__":
219-
result.extend(self.print_attribute(attr))
210+
result_bottom.extend(self.print_attribute(attr))
220211
break
221212

222213
for type_var in sorted(module.type_vars, key=lambda t: t.name):
223-
result.extend(self.print_type_var(type_var))
214+
result_bottom.extend(self.print_type_var(type_var))
224215

225216
for class_ in sorted(module.classes, key=lambda c: c.name):
226-
result.extend(self.print_class(class_))
217+
result_bottom.extend(self.print_class(class_))
227218

228219
for func in sorted(module.functions, key=lambda f: f.name):
229-
result.extend(self.print_function(func))
220+
result_bottom.extend(self.print_function(func))
230221

231222
for attr in sorted(module.attributes, key=lambda a: a.name):
232223
if attr.name != "__all__":
233-
result.extend(self.print_attribute(attr))
224+
result_bottom.extend(self.print_attribute(attr))
234225

235226
for alias in module.aliases:
236-
result.extend(self.print_alias(alias))
227+
result_bottom.extend(self.print_alias(alias))
228+
229+
if self._need_typing_ext:
230+
module.imports.add(
231+
Import(
232+
name=None,
233+
origin=QualifiedName.from_str("pybind11_stubgen.typing_ext"),
234+
)
235+
)
237236

238-
return result
237+
result_top = []
238+
if module.doc is not None:
239+
result_top.extend(self.print_docstring(module.doc))
240+
241+
for import_ in sorted(module.imports, key=lambda x: x.origin):
242+
result_top.extend(self.print_import(import_))
243+
244+
self._need_typing_ext = tmp
245+
return result_top + result_bottom
239246

240247
def print_property(self, prop: Property) -> list[str]:
241248
if not prop.getter:
@@ -276,11 +283,10 @@ def print_property(self, prop: Property) -> list[str]:
276283
return result
277284

278285
def print_value(self, value: Value) -> str:
279-
split = value.repr.split("\n", 1)
280-
if len(split) == 1:
281-
return split[0]
282-
else:
283-
return split[0] + "..."
286+
if value.is_print_safe:
287+
return value.repr
288+
self._need_typing_ext = True
289+
return f"pybind11_stubgen.typing_ext.ValueExpr({repr(value.repr)})"
284290

285291
def print_type(self, type_: ResolvedType) -> str:
286292
if (
@@ -312,6 +318,5 @@ def print_annotation(self, annotation: Annotation) -> str:
312318
raise AssertionError()
313319

314320
def print_invalid_exp(self, invalid_expr: InvalidExpression) -> str:
315-
if self.invalid_expr_as_ellipses:
316-
return "..."
317-
return invalid_expr.text
321+
self._need_typing_ext = True
322+
return f"pybind11_stubgen.typing_ext.InvalidExpr({repr(invalid_expr.text)})"

pybind11_stubgen/typing_ext.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,3 +32,10 @@ def InvalidExpr(expr: str) -> Any:
3232
"The method exists only for annotation purposes in stub files. "
3333
"Should never not be used at runtime"
3434
)
35+
36+
37+
def ValueExpr(expr: str) -> Any:
38+
raise RuntimeError(
39+
"The method exists only for annotation purposes in stub files. "
40+
"Should never not be used at runtime"
41+
)

0 commit comments

Comments
 (0)