Skip to content

Commit 3e8877d

Browse files
committed
Fix flags also when numpy annotations are not converted
1 parent 59e373d commit 3e8877d

File tree

2 files changed

+39
-42
lines changed

2 files changed

+39
-42
lines changed

pybind11_stubgen/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
FixMissingNoneHashFieldAnnotation,
3737
FixNumpyArrayDimAnnotation,
3838
FixNumpyArrayDimTypeVar,
39+
FixNumpyArrayFlags,
3940
FixNumpyArrayRemoveParameters,
4041
FixNumpyDtype,
4142
FixPEP585CollectionNames,
@@ -264,6 +265,7 @@ class Parser(
264265
OverridePrintSafeValues,
265266
*numpy_fixes, # type: ignore[misc]
266267
FixNumpyDtype,
268+
FixNumpyArrayFlags,
267269
FixCurrentModulePrefixInTypeNames,
268270
FixBuiltinTypes,
269271
RewritePybind11EnumValueRepr,

pybind11_stubgen/parser/mixins/fix.py

Lines changed: 37 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -543,11 +543,6 @@ class _NumpyArrayAnnotation:
543543
),
544544
)
545545
)
546-
numpy_flags: set[QualifiedName] = {
547-
QualifiedName.from_str("flags.writeable"),
548-
QualifiedName.from_str("flags.c_contiguous"),
549-
QualifiedName.from_str("flags.f_contiguous"),
550-
}
551546
dim_vars: set[str] = {"n", "m"}
552547
__dim_string_pattern = re.compile(r'"\[(.*?)\]"')
553548

@@ -632,8 +627,6 @@ def _from_old_style(
632627
):
633628
dimensions = cls._to_dims(scalar_with_dims.parameters)
634629

635-
cls._fix_flags(flags)
636-
637630
return _NumpyArrayAnnotation(array_type, scalar_type, dimensions, flags)
638631

639632
@classmethod
@@ -711,6 +704,14 @@ def _from_new_style(
711704
dims = cls._to_dims_from_strings(dims_list)
712705
flags = dims_and_flags[1:]
713706

707+
for i, flag in enumerate(flags):
708+
if isinstance(flag, Value):
709+
flag_str = flag.repr.strip('"')
710+
if flag_str.startswith("flags."):
711+
flags[i] = ResolvedType(
712+
name=QualifiedName.from_str(f"numpy.ndarray.{flag_str}")
713+
)
714+
714715
return cls(array_type, scalar_type, dims, flags)
715716

716717
@classmethod
@@ -733,12 +734,6 @@ def _to_dims(
733734
result.append(dim)
734735
return result
735736

736-
@classmethod
737-
def _fix_flags(cls, flags: list[ResolvedType | Value | InvalidExpression]):
738-
for flag in flags:
739-
if isinstance(flag, ResolvedType) and flag.name in cls.numpy_flags:
740-
flag.name = QualifiedName.from_str(f"numpy.ndarray.{flag.name}")
741-
742737
@staticmethod
743738
def _to_dims_from_strings(dimensions: Sequence[str]) -> list[int | str] | None:
744739
result: list[int | str] = []
@@ -791,24 +786,7 @@ def parse_annotation_str(
791786
params.append(
792787
self.handle_value(self.__wrap_with_size_helper(numpy_array.dimensions))
793788
)
794-
795-
for flag in numpy_array.flags:
796-
if isinstance(flag, Value):
797-
flag_str = flag.repr.strip('"')
798-
if flag_str in (
799-
"flags.writeable",
800-
"flags.c_contiguous",
801-
"flags.f_contiguous",
802-
):
803-
params.append(
804-
ResolvedType(
805-
name=QualifiedName.from_str(f"numpy.ndarray.{flag_str}")
806-
)
807-
)
808-
else:
809-
params.append(flag)
810-
else:
811-
params.append(flag)
789+
params.extend(numpy_array.flags)
812790

813791
return ResolvedType(name=self.__annotated_name, parameters=params)
814792

@@ -832,11 +810,6 @@ def report_error(self, error: ParserError) -> None:
832810
):
833811
# Ignores all unknown 'm' and 'n' regardless of the context
834812
return
835-
if (
836-
isinstance(error, NameResolutionError)
837-
and error.name in _NumpyArrayAnnotation.numpy_flags
838-
):
839-
return
840813
super().report_error(error)
841814

842815

@@ -883,7 +856,6 @@ def parse_annotation_str(
883856
numpy_array = _NumpyArrayAnnotation.from_annotation(result)
884857
if numpy_array is None:
885858
return result
886-
# __import__('ipdb').set_trace()
887859

888860
# scipy.sparse arrays/matrices are not currently generic and do not accept type
889861
# arguments
@@ -907,11 +879,6 @@ def report_error(self, error: ParserError) -> None:
907879
):
908880
# allow type variables, which are manually resolved in `handle_module`
909881
return
910-
if (
911-
isinstance(error, NameResolutionError)
912-
and error.name in _NumpyArrayAnnotation.numpy_flags
913-
):
914-
return
915882
super().report_error(error)
916883

917884

@@ -964,6 +931,34 @@ def parse_annotation_str(
964931
return result
965932

966933

934+
class FixNumpyArrayFlags(IParser):
935+
__ndarray_name = QualifiedName.from_str("numpy.ndarray")
936+
__flags: set[QualifiedName] = {
937+
QualifiedName.from_str("flags.writeable"),
938+
QualifiedName.from_str("flags.c_contiguous"),
939+
QualifiedName.from_str("flags.f_contiguous"),
940+
}
941+
942+
def parse_annotation_str(
943+
self, annotation_str: str
944+
) -> ResolvedType | InvalidExpression | Value:
945+
result = super().parse_annotation_str(annotation_str)
946+
if isinstance(result, ResolvedType) and result.name == self.__ndarray_name:
947+
if result.parameters is not None:
948+
for param in result.parameters:
949+
if isinstance(param, ResolvedType) and param.name in self.__flags:
950+
param.name = QualifiedName.from_str(
951+
f"numpy.ndarray.{param.name}"
952+
)
953+
954+
return result
955+
956+
def report_error(self, error: ParserError) -> None:
957+
if isinstance(error, NameResolutionError) and error.name in self.__flags:
958+
return
959+
super().report_error(error)
960+
961+
967962
class FixRedundantMethodsFromBuiltinObject(IParser):
968963
def handle_method(self, path: QualifiedName, method: Any) -> list[Method]:
969964
result = super().handle_method(path, method)

0 commit comments

Comments
 (0)