Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion pybind11_stubgen/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ class CLIArgs(Namespace):
exit_code: bool
dry_run: bool
stub_extension: str
sort_by: str
module_name: str


Expand Down Expand Up @@ -216,6 +217,14 @@ def regex_colon_path(regex_path: str) -> tuple[re.Pattern, str]:
"Must be 'pyi' (default) or 'py'",
)

parser.add_argument(
"--sort-by",
type=str,
default="definition",
choices=["definition", "topological"],
help="Sort classes by 'definition' order (default) or 'topological' order.",
)

parser.add_argument(
"module_name",
metavar="MODULE_NAME",
Expand Down Expand Up @@ -309,7 +318,10 @@ def main(argv: Sequence[str] | None = None) -> None:
args = arg_parser().parse_args(argv, namespace=CLIArgs())

parser = stub_parser_from_args(args)
printer = Printer(invalid_expr_as_ellipses=not args.print_invalid_expressions_as_is)
printer = Printer(
invalid_expr_as_ellipses=not args.print_invalid_expressions_as_is,
sort_by=args.sort_by,
)

out_dir, sub_dir = to_output_and_subdir(
output_dir=args.output_dir,
Expand Down
2 changes: 1 addition & 1 deletion pybind11_stubgen/parser/mixins/parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def handle_module(
self, path: QualifiedName, module: types.ModuleType
) -> Module | None:
result = Module(name=path[-1])
for name, member in inspect.getmembers(module):
for name, member in module.__dict__.items():
obj = self.handle_module_member(
QualifiedName([*path, Identifier(name)]), module, member
)
Expand Down
54 changes: 51 additions & 3 deletions pybind11_stubgen/printer.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
from __future__ import annotations

import dataclasses
import logging
import sys
from collections import defaultdict

log = logging.getLogger("pybind11_stubgen")

from pybind11_stubgen.structs import (
Alias,
Expand Down Expand Up @@ -30,8 +34,44 @@ def indent_lines(lines: list[str], by=4) -> list[str]:


class Printer:
def __init__(self, invalid_expr_as_ellipses: bool):
def __init__(self, invalid_expr_as_ellipses: bool, sort_by: str = "definition"):
self.invalid_expr_as_ellipses = invalid_expr_as_ellipses
self.sort_by = sort_by

def _toposort_classes(self, classes: list[Class]) -> list[Class]:
in_degree = {c.name: 0 for c in classes}
graph = defaultdict(list)
class_map = {c.name: c for c in classes}

for c in classes:
for base in c.bases:
base_name = base[-1]
if base_name in class_map:
graph[base_name].append(c.name)
in_degree[c.name] += 1

queue = sorted([name for name, degree in in_degree.items() if degree == 0])

sorted_classes = []
while queue:
name = queue.pop(0)
sorted_classes.append(class_map[name])
for neighbor in sorted(graph[name]):
in_degree[neighbor] -= 1
if in_degree[neighbor] == 0:
queue.append(neighbor)

if len(sorted_classes) == len(classes):
return sorted_classes
else:
# Cycle detected, fallback to alphabetical sort
remaining = [c for c in classes if c not in sorted_classes]
log.warning(
"Cycle detected in class inheritance involving: %s. "
"Falling back to alphabetical sort for these classes.",
[c.name for c in remaining],
)
return sorted_classes + sorted(remaining, key=lambda c: c.name)

def print_alias(self, alias: Alias) -> list[str]:
return [f"{alias.name} = {alias.origin}"]
Expand Down Expand Up @@ -90,7 +130,11 @@ def print_class_body(self, class_: Class) -> list[str]:
if class_.doc is not None:
result.extend(self.print_docstring(class_.doc))

for sub_class in sorted(class_.classes, key=lambda c: c.name):
classes_to_print = class_.classes
if self.sort_by == "topological":
classes_to_print = self._toposort_classes(class_.classes)

for sub_class in classes_to_print:
result.extend(self.print_class(sub_class))

modifier_order: dict[Modifier, int] = {
Expand Down Expand Up @@ -225,7 +269,11 @@ def print_module(self, module: Module) -> list[str]:
for type_var in sorted(module.type_vars, key=lambda t: t.name):
result.extend(self.print_type_var(type_var))

for class_ in sorted(module.classes, key=lambda c: c.name):
classes_to_print = module.classes
if self.sort_by == "topological":
classes_to_print = self._toposort_classes(module.classes)

for class_ in classes_to_print:
result.extend(self.print_class(class_))

for func in sorted(module.functions, key=lambda f: f.name):
Expand Down
20 changes: 11 additions & 9 deletions tests/demo-lib/include/demo/Inheritance.h
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
#pragma once
#include <string>

namespace demo{
namespace demo
{
// note: class stubs must not be sorted
// https://github.com/sizmailov/pybind11-stubgen/issues/231

struct Base {
struct Inner{};
std::string name;
};

struct Derived : Base {
int count;
};
struct MyBase {
struct Inner{};
std::string name;
};

struct Derived : MyBase {
int count;
};
}
8 changes: 4 additions & 4 deletions tests/py-demo/bindings/src/modules/classes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,13 @@ void bind_classes_module(py::module&&m) {
}

{
py::class_<demo::Base> pyBase(m, "Base");
py::class_<demo::MyBase> pyMyBase(m, "MyBase");

pyBase.def_readwrite("name", &demo::Base::name);
pyMyBase.def_readwrite("name", &demo::MyBase::name);

py::class_<demo::Base::Inner>(pyBase, "Inner");
py::class_<demo::MyBase::Inner>(pyMyBase, "Inner");

py::class_<demo::Derived, demo::Base>(m, "Derived")
py::class_<demo::Derived, demo::MyBase>(m, "Derived")
.def_readwrite("count", &demo::Derived::count);

}
Expand Down