1919 Modifier ,
2020 Module ,
2121 Property ,
22+ QualifiedName ,
2223 ResolvedType ,
2324 TypeVar_ ,
2425 Value ,
@@ -30,8 +31,8 @@ def indent_lines(lines: list[str], by=4) -> list[str]:
3031
3132
3233class Printer :
33- def __init__ (self , invalid_expr_as_ellipses : bool ):
34- self .invalid_expr_as_ellipses = invalid_expr_as_ellipses
34+ def __init__ (self ):
35+ self ._need_typing_ext = False
3536
3637 def print_alias (self , alias : Alias ) -> list [str ]:
3738 return [f"{ alias .name } = { alias .origin } " ]
@@ -43,13 +44,8 @@ def print_attribute(self, attr: Attribute) -> list[str]:
4344 if attr .annotation is not None :
4445 parts .append (f": { self .print_annotation (attr .annotation )} " )
4546
46- if attr .value is not None and attr . value . is_print_safe :
47+ if attr .value is not None :
4748 parts .append (f" = { self .print_value (attr .value )} " )
48- else :
49- if attr .annotation is None :
50- parts .append (" = ..." )
51- if attr .value is not None :
52- parts .append (f" # value = { self .print_value (attr .value )} " )
5349
5450 return ["" .join (parts )]
5551
@@ -202,40 +198,51 @@ def print_method(self, method: Method) -> list[str]:
202198 return result
203199
204200 def print_module (self , module : Module ) -> list [str ]:
205- result = []
206-
207- if module .doc is not None :
208- result .extend (self .print_docstring (module .doc ))
209-
210- for import_ in sorted (module .imports , key = lambda x : x .origin ):
211- result .extend (self .print_import (import_ ))
201+ result_bottom = []
202+ tmp = self ._need_typing_ext
212203
213204 for sub_module in module .sub_modules :
214- result .extend (self .print_submodule_import (sub_module .name ))
205+ result_bottom .extend (self .print_submodule_import (sub_module .name ))
215206
216207 # Place __all__ above everything
217208 for attr in sorted (module .attributes , key = lambda a : a .name ):
218209 if attr .name == "__all__" :
219- result .extend (self .print_attribute (attr ))
210+ result_bottom .extend (self .print_attribute (attr ))
220211 break
221212
222213 for type_var in sorted (module .type_vars , key = lambda t : t .name ):
223- result .extend (self .print_type_var (type_var ))
214+ result_bottom .extend (self .print_type_var (type_var ))
224215
225216 for class_ in sorted (module .classes , key = lambda c : c .name ):
226- result .extend (self .print_class (class_ ))
217+ result_bottom .extend (self .print_class (class_ ))
227218
228219 for func in sorted (module .functions , key = lambda f : f .name ):
229- result .extend (self .print_function (func ))
220+ result_bottom .extend (self .print_function (func ))
230221
231222 for attr in sorted (module .attributes , key = lambda a : a .name ):
232223 if attr .name != "__all__" :
233- result .extend (self .print_attribute (attr ))
224+ result_bottom .extend (self .print_attribute (attr ))
234225
235226 for alias in module .aliases :
236- result .extend (self .print_alias (alias ))
227+ result_bottom .extend (self .print_alias (alias ))
228+
229+ if self ._need_typing_ext :
230+ module .imports .add (
231+ Import (
232+ name = None ,
233+ origin = QualifiedName .from_str ("pybind11_stubgen.typing_ext" ),
234+ )
235+ )
237236
238- return result
237+ result_top = []
238+ if module .doc is not None :
239+ result_top .extend (self .print_docstring (module .doc ))
240+
241+ for import_ in sorted (module .imports , key = lambda x : x .origin ):
242+ result_top .extend (self .print_import (import_ ))
243+
244+ self ._need_typing_ext = tmp
245+ return result_top + result_bottom
239246
240247 def print_property (self , prop : Property ) -> list [str ]:
241248 if not prop .getter :
@@ -276,11 +283,10 @@ def print_property(self, prop: Property) -> list[str]:
276283 return result
277284
278285 def print_value (self , value : Value ) -> str :
279- split = value .repr .split ("\n " , 1 )
280- if len (split ) == 1 :
281- return split [0 ]
282- else :
283- return split [0 ] + "..."
286+ if value .is_print_safe :
287+ return value .repr
288+ self ._need_typing_ext = True
289+ return f"pybind11_stubgen.typing_ext.ValueExpr({ repr (value .repr )} )"
284290
285291 def print_type (self , type_ : ResolvedType ) -> str :
286292 if (
@@ -312,6 +318,5 @@ def print_annotation(self, annotation: Annotation) -> str:
312318 raise AssertionError ()
313319
314320 def print_invalid_exp (self , invalid_expr : InvalidExpression ) -> str :
315- if self .invalid_expr_as_ellipses :
316- return "..."
317- return invalid_expr .text
321+ self ._need_typing_ext = True
322+ return f"pybind11_stubgen.typing_ext.InvalidExpr({ repr (invalid_expr .text )} )"
0 commit comments