Skip to content

Commit 6afe40e

Browse files
committed
Fix FixNumpyArrayDimTypeVar for pybind v3.0.0
1 parent 0a566ba commit 6afe40e

File tree

19 files changed

+234
-108
lines changed

19 files changed

+234
-108
lines changed

pybind11_stubgen/parser/mixins/fix.py

Lines changed: 122 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,6 @@
4040

4141

4242
class RemoveSelfAnnotation(IParser):
43-
4443
__any_t_name = QualifiedName.from_str("Any")
4544
__typing_any_t_name = QualifiedName.from_str("typing.Any")
4645

@@ -632,10 +631,19 @@ def report_error(self, error: ParserError) -> None:
632631

633632

634633
class FixNumpyArrayDimTypeVar(IParser):
635-
__array_names: set[QualifiedName] = {QualifiedName.from_str("numpy.ndarray")}
634+
__array_names: set[QualifiedName] = {
635+
QualifiedName.from_str("numpy.ndarray"),
636+
QualifiedName.from_str("numpy.typing.ArrayLike"),
637+
QualifiedName.from_str("numpy.typing.NDArray"),
638+
}
639+
__typing_annotated_names = {
640+
QualifiedName.from_str("typing.Annotated"),
641+
QualifiedName.from_str("typing_extensions.Annotated"),
642+
}
636643
numpy_primitive_types = FixNumpyArrayDimAnnotation.numpy_primitive_types
637644

638645
__DIM_VARS: set[str] = set()
646+
__DIM_STRING_PATTERN = re.compile(r'"\[(.*?)\]"')
639647

640648
def handle_module(
641649
self, path: QualifiedName, module: types.ModuleType
@@ -659,17 +667,11 @@ def handle_module(
659667
)
660668

661669
self.__DIM_VARS.clear()
662-
663670
return result
664671

665672
def parse_annotation_str(
666673
self, annotation_str: str
667674
) -> ResolvedType | InvalidExpression | Value:
668-
# Affects types of the following pattern:
669-
# numpy.ndarray[PRIMITIVE_TYPE[*DIMS], *FLAGS]
670-
# Replace with:
671-
# numpy.ndarray[tuple[M, Literal[1]], numpy.dtype[numpy.float32]]
672-
673675
result = super().parse_annotation_str(annotation_str)
674676

675677
if not isinstance(result, ResolvedType):
@@ -679,9 +681,103 @@ def parse_annotation_str(
679681
if len(result.name) == 1 and len(result.name[0]) == 1:
680682
result.name = QualifiedName.from_str(result.name[0].upper())
681683
self.__DIM_VARS.add(result.name[0])
684+
return result
685+
686+
if result.name in self.__typing_annotated_names:
687+
return self._handle_annotated_numpy_array(result)
688+
elif result.name in self.__array_names:
689+
return self._handle_old_style_numpy_array(result)
690+
691+
return result
682692

683-
if result.name not in self.__array_names:
693+
def _process_numpy_array_type(
694+
self, scalar_type_name: QualifiedName, dimensions: list[int | str] | None
695+
) -> tuple[ResolvedType, ResolvedType]:
696+
# Pybind annotates a bool Python type, which cannot be used with
697+
# numpy.dtype because it does not inherit from numpy.generic.
698+
# Only numpy.bool_ works reliably with both NumPy 1.x and 2.x.
699+
if str(scalar_type_name) == "bool":
700+
scalar_type_name = QualifiedName.from_str("numpy.bool_")
701+
dtype = ResolvedType(
702+
name=QualifiedName.from_str("numpy.dtype"),
703+
parameters=[ResolvedType(name=scalar_type_name)],
704+
)
705+
706+
shape = self.parse_annotation_str("Any")
707+
if dimensions is not None and len(dimensions) > 0:
708+
shape = self.parse_annotation_str("Tuple")
709+
assert isinstance(shape, ResolvedType)
710+
shape.parameters = []
711+
for dim in dimensions:
712+
if isinstance(dim, int):
713+
literal_dim = self.parse_annotation_str("Literal")
714+
assert isinstance(literal_dim, ResolvedType)
715+
literal_dim.parameters = [Value(repr=str(dim))]
716+
shape.parameters.append(literal_dim)
717+
else:
718+
shape.parameters.append(
719+
ResolvedType(name=QualifiedName.from_str(dim.upper()))
720+
)
721+
return shape, dtype
722+
723+
def _handle_annotated_numpy_array(self, result: ResolvedType) -> ResolvedType:
724+
# Annotated[numpy.typing.ArrayLike, numpy.float32, "[m, n]"]
725+
# Annotated[numpy.typing.NDArray[numpy.float32], "[m, n]"]
726+
# Annotated[numpy.typing.NDArray[numpy.float32], "[m, n]", "flags.writeable", "flags.c_contiguous"]
727+
if result.parameters is None or len(result.parameters) < 2:
728+
return result
729+
730+
array_type, *parameters = result.parameters
731+
if (
732+
not isinstance(array_type, ResolvedType)
733+
or array_type.name not in self.__array_names
734+
):
735+
return result
736+
737+
dims_and_flags: Sequence[ResolvedType | Value | InvalidExpression]
738+
if array_type.name == QualifiedName.from_str("numpy.typing.ArrayLike"):
739+
scalar_type, *dims_and_flags = parameters
740+
scalar_type_name = scalar_type.name
741+
elif array_type.name == QualifiedName.from_str("numpy.typing.NDArray"):
742+
if array_type.parameters is None or len(array_type.parameters) < 2:
743+
return result
744+
_, dtype_param = array_type.parameters
745+
if not (
746+
isinstance(dtype_param, ResolvedType)
747+
and dtype_param.name == QualifiedName.from_str("numpy.dtype")
748+
and dtype_param.parameters
749+
):
750+
return result
751+
scalar_type_name = dtype_param.parameters[0].name
752+
dims_and_flags = parameters
753+
else:
684754
return result
755+
if scalar_type_name not in self.numpy_primitive_types:
756+
return result
757+
758+
dims: list[int | str] | None = None
759+
if dims_and_flags:
760+
dims_str, *flags = dims_and_flags
761+
del flags # Unused.
762+
if isinstance(dims_str, Value):
763+
match = self.__DIM_STRING_PATTERN.search(dims_str.repr)
764+
if match:
765+
dims_str_content = match.group(1)
766+
dims_list = [d.strip() for d in dims_str_content.split(",") if d.strip()]
767+
if dims_list:
768+
dims = self.__to_dims_from_strings(dims_list)
769+
770+
shape, dtype = self._process_numpy_array_type(scalar_type_name, dims)
771+
return ResolvedType(
772+
name=QualifiedName.from_str("numpy.ndarray"),
773+
parameters=[shape, dtype],
774+
)
775+
776+
def _handle_old_style_numpy_array(self, result: ResolvedType) -> ResolvedType:
777+
# Affects types of the following pattern:
778+
# numpy.ndarray[PRIMITIVE_TYPE[*DIMS], *FLAGS]
779+
# Replace with:
780+
# numpy.ndarray[tuple[M, Literal[1]], numpy.dtype[numpy.float32]]
685781

686782
# ndarray is generic and should have 2 type arguments
687783
if result.parameters is None or len(result.parameters) == 0:
@@ -702,39 +798,14 @@ def parse_annotation_str(
702798
):
703799
return result
704800

705-
name = scalar_with_dims.name
706-
# Pybind annotates a bool Python type, which cannot be used with
707-
# numpy.dtype because it does not inherit from numpy.generic.
708-
# Only numpy.bool_ works reliably with both NumPy 1.x and 2.x.
709-
if str(name) == "bool":
710-
name = QualifiedName.from_str("numpy.bool_")
711-
dtype = ResolvedType(
712-
name=QualifiedName.from_str("numpy.dtype"),
713-
parameters=[ResolvedType(name=name)],
714-
)
715-
716-
shape = self.parse_annotation_str("Any")
801+
dims: list[int | str] | None = None
717802
if (
718803
scalar_with_dims.parameters is not None
719804
and len(scalar_with_dims.parameters) > 0
720805
):
721806
dims = self.__to_dims(scalar_with_dims.parameters)
722-
if dims is not None:
723-
shape = self.parse_annotation_str("Tuple")
724-
assert isinstance(shape, ResolvedType)
725-
shape.parameters = []
726-
for dim in dims:
727-
if isinstance(dim, int):
728-
# self.parse_annotation_str will qualify Literal with either
729-
# typing or typing_extensions and add the import to the module
730-
literal_dim = self.parse_annotation_str("Literal")
731-
assert isinstance(literal_dim, ResolvedType)
732-
literal_dim.parameters = [Value(repr=str(dim))]
733-
shape.parameters.append(literal_dim)
734-
else:
735-
shape.parameters.append(
736-
ResolvedType(name=QualifiedName.from_str(dim))
737-
)
807+
808+
shape, dtype = self._process_numpy_array_type(scalar_with_dims.name, dims)
738809

739810
result.parameters = [shape, dtype]
740811
return result
@@ -756,6 +827,20 @@ def __to_dims(
756827
result.append(dim)
757828
return result
758829

830+
def __to_dims_from_strings(
831+
self, dimensions: Sequence[str]
832+
) -> list[int | str] | None:
833+
result: list[int | str] = []
834+
for dim_str in dimensions:
835+
try:
836+
dim = int(dim_str)
837+
except ValueError:
838+
dim = dim_str
839+
if len(dim) == 1: # Assuming single letter dims are type vars
840+
self.__DIM_VARS.add(dim.upper()) # Add uppercase to TypeVar set
841+
result.append(dim)
842+
return result
843+
759844
def report_error(self, error: ParserError) -> None:
760845
if (
761846
isinstance(error, NameResolutionError)

tests/check-demo-stubs-generation.sh

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#!/bin/bash
22

3-
set -e
3+
set -ex
44

55
function parse_args() {
66

@@ -30,8 +30,8 @@ function parse_args() {
3030
if [ -z "$STUBS_SUB_DIR" ]; then usage "STUBS_SUB_DIR is not set"; fi;
3131
if [ -z "$NUMPY_FORMAT" ]; then usage "NUMPY_FORMAT is not set"; fi;
3232

33-
TESTS_ROOT="$(readlink -m "$(dirname "$0")")"
34-
STUBS_DIR=$(readlink -m "${TESTS_ROOT}/${STUBS_SUB_DIR}")
33+
TESTS_ROOT="$(greadlink -m "$(dirname "$0")")"
34+
STUBS_DIR=$(greadlink -m "${TESTS_ROOT}/${STUBS_SUB_DIR}")
3535
}
3636

3737
remove_stubs() {

tests/demo-lib/include/demo/Foo.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ namespace demo{
55

66

77
class CppException : public std::runtime_error {
8-
using std::runtime_error::runtime_error;
8+
//using std::runtime_error;
99
};
1010

1111
struct Foo {

tests/install-demo-module.sh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ function parse_args() {
2727
# verify params
2828
if [ -z "$PYBIND11_BRANCH" ]; then usage "PYBIND11_BRANCH is not set"; fi;
2929

30-
TESTS_ROOT="$(readlink -m "$(dirname "$0")")"
30+
TESTS_ROOT="$(greadlink -m "$(dirname "$0")")"
3131
PROJECT_ROOT="${TESTS_ROOT}/.."
3232
TEMP_DIR="${PROJECT_ROOT}/tmp/pybind11-${PYBIND11_BRANCH}"
3333
INSTALL_PREFIX="${TEMP_DIR}/install"
@@ -67,7 +67,7 @@ install_demo() {
6767

6868
install_pydemo() {
6969
(
70-
export CMAKE_PREFIX_PATH="$(readlink -m "${INSTALL_PREFIX}"):$(cmeel cmake)";
70+
export CMAKE_PREFIX_PATH="$(greadlink -m "${INSTALL_PREFIX}"):$(cmeel cmake)";
7171
export CMAKE_ARGS="-DCMAKE_CXX_STANDARD=17";
7272
pip install --force-reinstall "${TESTS_ROOT}/py-demo"
7373
)
Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#include "modules.h"
22

3-
namespace {
3+
namespace mymodules {
44
struct Dummy {
55
int regular_method(int x) { return x + 1; }
66
static int static_method(int x) { return x + 1; }
@@ -9,8 +9,8 @@ struct Dummy {
99
} // namespace
1010

1111
void bind_methods_module(py::module&& m) {
12-
auto &&pyDummy = py::class_<Dummy>(m, "Dummy");
12+
auto &&pyDummy = py::class_<mymodules::Dummy>(m, "Dummy");
1313

14-
pyDummy.def_static("static_method", &Dummy::static_method);
15-
pyDummy.def("regular_method", &Dummy::regular_method);
14+
pyDummy.def_static("static_method", &mymodules::Dummy::static_method);
15+
pyDummy.def("regular_method", &mymodules::Dummy::regular_method);
1616
}

tests/py-demo/bindings/src/modules/values.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,15 @@
44

55
#include <chrono>
66

7-
namespace {
7+
namespace myvalues {
88
class Dummy {};
99
class Foo {};
1010
} // namespace
1111

1212
void bind_values_module(py::module &&m) {
1313
{
1414
// python module as value
15-
auto &&pyDummy = py::class_<Dummy>(m, "Dummy");
15+
auto &&pyDummy = py::class_<myvalues::Dummy>(m, "Dummy");
1616

1717
pyDummy.def_property_readonly_static(
1818
"linalg", [](py::object &) { return py::module::import("numpy.linalg"); });
@@ -27,12 +27,12 @@ void bind_values_module(py::module &&m) {
2727
m.attr("list_with_none") = li;
2828
}
2929
{
30-
auto pyFoo = py::class_<Foo>(m, "Foo");
31-
m.attr("foovar") = Foo();
30+
auto pyFoo = py::class_<myvalues::Foo>(m, "Foo");
31+
m.attr("foovar") = myvalues::Foo();
3232

3333
py::list foolist;
34-
foolist.append(Foo());
35-
foolist.append(Foo());
34+
foolist.append(myvalues::Foo());
35+
foolist.append(myvalues::Foo());
3636

3737
m.attr("foolist") = foolist;
3838
m.attr("none") = py::none();

tests/stubs/python-3.12/pybind11-v3.0.0/numpy-array-use-type-var/demo/_bindings/aliases/__init__.pyi

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ class Dummy:
4747
def foreign_enum_default(
4848
color: typing.Any = demo._bindings.enum.ConsoleForegroundColor.Blue,
4949
) -> None: ...
50-
def func(arg0: int) -> int: ...
50+
def func(arg0: typing.SupportsInt) -> int: ...
5151

5252
local_func_alias = func
5353
local_type_alias = Color

tests/stubs/python-3.12/pybind11-v3.0.0/numpy-array-use-type-var/demo/_bindings/classes.pyi

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,10 @@ class CppException(Exception):
1313
pass
1414

1515
class Derived(Base):
16-
count: int
16+
@property
17+
def count(self) -> int: ...
18+
@count.setter
19+
def count(self, arg0: typing.SupportsInt) -> None: ...
1720

1821
class Foo:
1922
class FooChild:
@@ -43,11 +46,11 @@ class Outer:
4346
def __getstate__(self) -> int: ...
4447
def __hash__(self) -> int: ...
4548
def __index__(self) -> int: ...
46-
def __init__(self, value: int) -> None: ...
49+
def __init__(self, value: typing.SupportsInt) -> None: ...
4750
def __int__(self) -> int: ...
4851
def __ne__(self, other: typing.Any) -> bool: ...
4952
def __repr__(self) -> str: ...
50-
def __setstate__(self, state: int) -> None: ...
53+
def __setstate__(self, state: typing.SupportsInt) -> None: ...
5154
def __str__(self) -> str: ...
5255
@property
5356
def name(self) -> str: ...

tests/stubs/python-3.12/pybind11-v3.0.0/numpy-array-use-type-var/demo/_bindings/eigen.pyi

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ from __future__ import annotations
33
import typing
44

55
import numpy
6+
import numpy.typing
67
import scipy.sparse
78

89
__all__: list[str] = [

tests/stubs/python-3.12/pybind11-v3.0.0/numpy-array-use-type-var/demo/_bindings/enum.pyi

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,11 +49,11 @@ class ConsoleForegroundColor:
4949
def __getstate__(self) -> int: ...
5050
def __hash__(self) -> int: ...
5151
def __index__(self) -> int: ...
52-
def __init__(self, value: int) -> None: ...
52+
def __init__(self, value: typing.SupportsInt) -> None: ...
5353
def __int__(self) -> int: ...
5454
def __ne__(self, other: typing.Any) -> bool: ...
5555
def __repr__(self) -> str: ...
56-
def __setstate__(self, state: int) -> None: ...
56+
def __setstate__(self, state: typing.SupportsInt) -> None: ...
5757
def __str__(self) -> str: ...
5858
@property
5959
def name(self) -> str: ...

0 commit comments

Comments
 (0)