Skip to content

Commit 4444838

Browse files
committed
Fix broken handling of v2.13
1 parent 6be49ef commit 4444838

File tree

1 file changed

+62
-35
lines changed
  • pybind11_stubgen/parser/mixins

1 file changed

+62
-35
lines changed

pybind11_stubgen/parser/mixins/fix.py

Lines changed: 62 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -545,7 +545,7 @@ class _NumpyArrayAnnotation:
545545
)
546546
__DIM_VARS = ["n", "m"]
547547
__DIM_STRING_PATTERN = re.compile(r'"\[(.*?)\]"')
548-
548+
549549
def __init__(
550550
self,
551551
array_type: ResolvedType,
@@ -557,7 +557,7 @@ def __init__(
557557
self.scalar_type = scalar_type
558558
self.dimensions = dimensions
559559
self.flags = flags
560-
560+
561561
def to_type_hint(
562562
self, parser: IParser, on_dynamic_dim: Optional[callable[[str], None]] = None
563563
) -> tuple[ResolvedType, ResolvedType]:
@@ -571,7 +571,7 @@ def to_type_hint(
571571
name=QualifiedName.from_str("numpy.dtype"),
572572
parameters=[ResolvedType(name=scalar_type_name)],
573573
)
574-
574+
575575
shape = parser.parse_annotation_str("Any")
576576
if self.dimensions:
577577
shape = parser.parse_annotation_str("Tuple")
@@ -590,7 +590,7 @@ def to_type_hint(
590590
ResolvedType(name=QualifiedName.from_str(dim.upper()))
591591
)
592592
return shape, dtype
593-
593+
594594
@classmethod
595595
def from_annotation(
596596
cls, resolved_type: ResolvedType
@@ -601,23 +601,23 @@ def from_annotation(
601601
return cls._from_old_style(resolved_type)
602602
else:
603603
return None
604-
604+
605605
@classmethod
606606
def _from_old_style(
607607
cls, resolved_type: ResolvedType
608608
) -> Optional[_NumpyArrayAnnotation]:
609609
if resolved_type.parameters is None or len(resolved_type.parameters) == 0:
610610
return None
611-
611+
612612
scalar_with_dims = resolved_type.parameters[0]
613613
flags = resolved_type.parameters[1:]
614-
614+
615615
if (
616616
not isinstance(scalar_with_dims, ResolvedType)
617617
or scalar_with_dims.name not in cls.numpy_primitive_types
618618
):
619619
return None
620-
620+
621621
array_type = ResolvedType(name=resolved_type.name)
622622
scalar_type = ResolvedType(name=scalar_with_dims.name)
623623
dimensions: Optional[list[str | int]] = None
@@ -626,9 +626,11 @@ def _from_old_style(
626626
and len(scalar_with_dims.parameters) > 0
627627
):
628628
dimensions = cls._to_dims(scalar_with_dims.parameters)
629-
630-
return cls(array_type, scalar_type, dimensions, flags)
631-
629+
630+
cls._fix_flags(flags)
631+
632+
return _NumpyArrayAnnotation(array_type, scalar_type, dimensions, flags)
633+
632634
@classmethod
633635
def _from_new_style(
634636
cls, resolved_type: ResolvedType
@@ -705,7 +707,7 @@ def _from_new_style(
705707
flags = dims_and_flags[1:]
706708

707709
return cls(array_type, scalar_type, dims, flags)
708-
710+
709711
@classmethod
710712
def _to_dims(
711713
cls, dimensions: Sequence[ResolvedType | Value | InvalidExpression]
@@ -719,13 +721,24 @@ def _to_dims(
719721
return None
720722
elif isinstance(dim_param, ResolvedType):
721723
dim = str(dim_param)
722-
if dim not in cls.__DIM_VARS:
724+
if len(dim) > 1 and dim not in cls.__DIM_VARS:
723725
return None
724726
else:
725727
return None
726728
result.append(dim)
727729
return result
728-
730+
731+
@staticmethod
732+
def _fix_flags(flags: list[ResolvedType | Value | InvalidExpression]):
733+
__flags: set[QualifiedName] = {
734+
QualifiedName.from_str("flags.writeable"),
735+
QualifiedName.from_str("flags.c_contiguous"),
736+
QualifiedName.from_str("flags.f_contiguous"),
737+
}
738+
for flag in flags:
739+
if isinstance(flag, ResolvedType) and flag.name in __flags:
740+
flag.name = QualifiedName.from_str(f"numpy.ndarray.{flag.name}")
741+
729742
@staticmethod
730743
def _to_dims_from_strings(dimensions: Sequence[str]) -> list[int | str] | None:
731744
result: list[int | str] = []
@@ -736,29 +749,27 @@ def _to_dims_from_strings(dimensions: Sequence[str]) -> list[int | str] | None:
736749
dim = dim_str
737750
result.append(dim)
738751
return result
739-
740-
752+
753+
741754
class FixNumpyArrayDimAnnotation(IParser):
742755
# NB: Not using full name due to ambiguity `typing.Annotated` vs
743756
# `typing_extension.Annotated` in different python versions
744757
# Rely on later fix by `FixTypingTypeNames`
745758
__annotated_name = QualifiedName.from_str("Annotated")
746759
__DIM_VARS = _NumpyArrayAnnotation._NumpyArrayAnnotation__DIM_VARS
747-
760+
748761
def parse_annotation_str(
749762
self, annotation_str: str
750763
) -> ResolvedType | InvalidExpression | Value:
751764
# Affects types of the following pattern:
752765
# ARRAY_T[PRIMITIVE_TYPE[*DIMS], *FLAGS]
753766
# Replace with:
754767
# Annotated[ARRAY_T, PRIMITIVE_TYPE, FixedSize/DynamicSize[*DIMS], *FLAGS]
755-
768+
756769
result = super().parse_annotation_str(annotation_str)
757-
# if 'scipy.sparse' in str(result):
758-
# __import__('ipdb').set_trace()
759770
if not isinstance(result, ResolvedType):
760771
return result
761-
772+
762773
numpy_array = _NumpyArrayAnnotation.from_annotation(result)
763774
if numpy_array is None:
764775
return result
@@ -791,7 +802,9 @@ def parse_annotation_str(
791802
"flags.f_contiguous",
792803
):
793804
params.append(
794-
ResolvedType(name=QualifiedName.from_str(f"numpy.ndarray.{flag_str}"))
805+
ResolvedType(
806+
name=QualifiedName.from_str(f"numpy.ndarray.{flag_str}")
807+
)
795808
)
796809
else:
797810
params.append(flag)
@@ -805,13 +818,18 @@ def __wrap_with_size_helper(self, dims: list[int | str]) -> FixedSize | DynamicS
805818
return_t = FixedSize
806819
else:
807820
return_t = DynamicSize
808-
821+
809822
# TRICK: Use `self.handle_type` to make `FixedSize`/`DynamicSize`
810823
# properly added to the list of imports
811824
self.handle_type(return_t)
812825
return return_t(*dims) # type: ignore[arg-type]
813-
826+
814827
def report_error(self, error: ParserError) -> None:
828+
__flags: set[QualifiedName] = {
829+
QualifiedName.from_str("flags.writeable"),
830+
QualifiedName.from_str("flags.c_contiguous"),
831+
QualifiedName.from_str("flags.f_contiguous"),
832+
}
815833
if (
816834
isinstance(error, NameResolutionError)
817835
and len(error.name) == 1
@@ -820,36 +838,38 @@ def report_error(self, error: ParserError) -> None:
820838
):
821839
# Ignores all unknown 'm' and 'n' regardless of the context
822840
return
841+
if isinstance(error, NameResolutionError) and error.name in __flags:
842+
return
823843
super().report_error(error)
824844

825845

826846
class FixNumpyArrayDimTypeVar(IParser):
827847
__DIM_VARS: set[str] = set()
828-
848+
829849
def handle_module(
830850
self, path: QualifiedName, module: types.ModuleType
831851
) -> Module | None:
832852
result = super().handle_module(path, module)
833853
if result is None:
834854
return None
835-
855+
836856
if self.__DIM_VARS:
837857
# the TypeVar_'s generated code will reference `typing`
838858
result.imports.add(
839859
Import(name=None, origin=QualifiedName.from_str("typing"))
840860
)
841-
861+
842862
for name in self.__DIM_VARS:
843863
result.type_vars.append(
844864
TypeVar_(
845865
name=Identifier(name),
846866
bound=self.parse_annotation_str("int"),
847867
),
848868
)
849-
869+
850870
self.__DIM_VARS.clear()
851871
return result
852-
872+
853873
def parse_annotation_str(
854874
self, annotation_str: str
855875
) -> ResolvedType | InvalidExpression | Value:
@@ -867,32 +887,39 @@ def parse_annotation_str(
867887
if numpy_array is None:
868888
return result
869889
# __import__('ipdb').set_trace()
870-
890+
871891
# scipy.sparse arrays/matrices are not currently generic and do not accept type
872892
# arguments
873893
if numpy_array.array_type.name[:2] == ("scipy", "sparse"):
874894
return result
875-
895+
876896
def on_dynamic_dim(dim: str) -> None:
877897
if len(dim) == 1: # Assuming single letter dims are type vars
878898
self.__DIM_VARS.add(dim.upper())
879-
899+
880900
shape, dtype = numpy_array.to_type_hint(self, on_dynamic_dim)
881901
return ResolvedType(
882902
name=QualifiedName.from_str("numpy.ndarray"), parameters=[shape, dtype]
883903
)
884-
904+
885905
def report_error(self, error: ParserError) -> None:
906+
__flags: set[QualifiedName] = {
907+
QualifiedName.from_str("flags.writeable"),
908+
QualifiedName.from_str("flags.c_contiguous"),
909+
QualifiedName.from_str("flags.f_contiguous"),
910+
}
886911
if (
887912
isinstance(error, NameResolutionError)
888913
and len(error.name) == 1
889914
and error.name[0] in self.__DIM_VARS
890915
):
891916
# allow type variables, which are manually resolved in `handle_module`
892917
return
918+
if isinstance(error, NameResolutionError) and error.name in __flags:
919+
return
893920
super().report_error(error)
894-
895-
921+
922+
896923
class FixNumpyArrayRemoveParameters(IParser):
897924
def parse_annotation_str(
898925
self, annotation_str: str

0 commit comments

Comments
 (0)