diff --git a/pybind11_stubgen/__init__.py b/pybind11_stubgen/__init__.py index 89f7a9c9..06207820 100644 --- a/pybind11_stubgen/__init__.py +++ b/pybind11_stubgen/__init__.py @@ -77,6 +77,7 @@ class CLIArgs(Namespace): exit_code: bool dry_run: bool stub_extension: str + sort_by: str module_name: str @@ -216,6 +217,14 @@ def regex_colon_path(regex_path: str) -> tuple[re.Pattern, str]: "Must be 'pyi' (default) or 'py'", ) + parser.add_argument( + "--sort-by", + type=str, + default="definition", + choices=["definition", "topological"], + help="Sort classes by 'definition' order (default) or 'topological' order.", + ) + parser.add_argument( "module_name", metavar="MODULE_NAME", @@ -309,7 +318,10 @@ def main(argv: Sequence[str] | None = None) -> None: args = arg_parser().parse_args(argv, namespace=CLIArgs()) parser = stub_parser_from_args(args) - printer = Printer(invalid_expr_as_ellipses=not args.print_invalid_expressions_as_is) + printer = Printer( + invalid_expr_as_ellipses=not args.print_invalid_expressions_as_is, + sort_by=args.sort_by, + ) out_dir, sub_dir = to_output_and_subdir( output_dir=args.output_dir, diff --git a/pybind11_stubgen/parser/mixins/parse.py b/pybind11_stubgen/parser/mixins/parse.py index f8867603..c83a1850 100644 --- a/pybind11_stubgen/parser/mixins/parse.py +++ b/pybind11_stubgen/parser/mixins/parse.py @@ -86,7 +86,7 @@ def handle_module( self, path: QualifiedName, module: types.ModuleType ) -> Module | None: result = Module(name=path[-1]) - for name, member in inspect.getmembers(module): + for name, member in module.__dict__.items(): obj = self.handle_module_member( QualifiedName([*path, Identifier(name)]), module, member ) diff --git a/pybind11_stubgen/printer.py b/pybind11_stubgen/printer.py index 8ef4af20..d6eb72a6 100644 --- a/pybind11_stubgen/printer.py +++ b/pybind11_stubgen/printer.py @@ -1,7 +1,11 @@ from __future__ import annotations import dataclasses +import logging import sys +from collections import defaultdict + +log = logging.getLogger("pybind11_stubgen") from pybind11_stubgen.structs import ( Alias, @@ -30,8 +34,44 @@ def indent_lines(lines: list[str], by=4) -> list[str]: class Printer: - def __init__(self, invalid_expr_as_ellipses: bool): + def __init__(self, invalid_expr_as_ellipses: bool, sort_by: str = "definition"): self.invalid_expr_as_ellipses = invalid_expr_as_ellipses + self.sort_by = sort_by + + def _toposort_classes(self, classes: list[Class]) -> list[Class]: + in_degree = {c.name: 0 for c in classes} + graph = defaultdict(list) + class_map = {c.name: c for c in classes} + + for c in classes: + for base in c.bases: + base_name = base[-1] + if base_name in class_map: + graph[base_name].append(c.name) + in_degree[c.name] += 1 + + queue = sorted([name for name, degree in in_degree.items() if degree == 0]) + + sorted_classes = [] + while queue: + name = queue.pop(0) + sorted_classes.append(class_map[name]) + for neighbor in sorted(graph[name]): + in_degree[neighbor] -= 1 + if in_degree[neighbor] == 0: + queue.append(neighbor) + + if len(sorted_classes) == len(classes): + return sorted_classes + else: + # Cycle detected, fallback to alphabetical sort + remaining = [c for c in classes if c not in sorted_classes] + log.warning( + "Cycle detected in class inheritance involving: %s. " + "Falling back to alphabetical sort for these classes.", + [c.name for c in remaining], + ) + return sorted_classes + sorted(remaining, key=lambda c: c.name) def print_alias(self, alias: Alias) -> list[str]: return [f"{alias.name} = {alias.origin}"] @@ -90,7 +130,11 @@ def print_class_body(self, class_: Class) -> list[str]: if class_.doc is not None: result.extend(self.print_docstring(class_.doc)) - for sub_class in sorted(class_.classes, key=lambda c: c.name): + classes_to_print = class_.classes + if self.sort_by == "topological": + classes_to_print = self._toposort_classes(class_.classes) + + for sub_class in classes_to_print: result.extend(self.print_class(sub_class)) modifier_order: dict[Modifier, int] = { @@ -225,7 +269,11 @@ def print_module(self, module: Module) -> list[str]: for type_var in sorted(module.type_vars, key=lambda t: t.name): result.extend(self.print_type_var(type_var)) - for class_ in sorted(module.classes, key=lambda c: c.name): + classes_to_print = module.classes + if self.sort_by == "topological": + classes_to_print = self._toposort_classes(module.classes) + + for class_ in classes_to_print: result.extend(self.print_class(class_)) for func in sorted(module.functions, key=lambda f: f.name): diff --git a/tests/demo-lib/include/demo/Inheritance.h b/tests/demo-lib/include/demo/Inheritance.h index 1b577cbe..f99ffa4a 100644 --- a/tests/demo-lib/include/demo/Inheritance.h +++ b/tests/demo-lib/include/demo/Inheritance.h @@ -1,15 +1,17 @@ #pragma once #include -namespace demo{ +namespace demo +{ + // note: class stubs must not be sorted + // https://github.com/sizmailov/pybind11-stubgen/issues/231 -struct Base { - struct Inner{}; - std::string name; -}; - -struct Derived : Base { - int count; -}; + struct MyBase { + struct Inner{}; + std::string name; + }; + struct Derived : MyBase { + int count; + }; } diff --git a/tests/py-demo/bindings/src/modules/classes.cpp b/tests/py-demo/bindings/src/modules/classes.cpp index 347d9c11..f6689d10 100644 --- a/tests/py-demo/bindings/src/modules/classes.cpp +++ b/tests/py-demo/bindings/src/modules/classes.cpp @@ -19,13 +19,13 @@ void bind_classes_module(py::module&&m) { } { - py::class_ pyBase(m, "Base"); + py::class_ pyMyBase(m, "MyBase"); - pyBase.def_readwrite("name", &demo::Base::name); + pyMyBase.def_readwrite("name", &demo::MyBase::name); - py::class_(pyBase, "Inner"); + py::class_(pyMyBase, "Inner"); - py::class_(m, "Derived") + py::class_(m, "Derived") .def_readwrite("count", &demo::Derived::count); }