4040
4141
4242class RemoveSelfAnnotation (IParser ):
43-
4443 __any_t_name = QualifiedName .from_str ("Any" )
4544 __typing_any_t_name = QualifiedName .from_str ("typing.Any" )
4645
@@ -632,10 +631,19 @@ def report_error(self, error: ParserError) -> None:
632631
633632
634633class FixNumpyArrayDimTypeVar (IParser ):
635- __array_names : set [QualifiedName ] = {QualifiedName .from_str ("numpy.ndarray" )}
634+ __array_names : set [QualifiedName ] = {
635+ QualifiedName .from_str ("numpy.ndarray" ),
636+ QualifiedName .from_str ("numpy.typing.ArrayLike" ),
637+ QualifiedName .from_str ("numpy.typing.NDArray" ),
638+ }
639+ __typing_annotated_names = {
640+ QualifiedName .from_str ("typing.Annotated" ),
641+ QualifiedName .from_str ("typing_extensions.Annotated" ),
642+ }
636643 numpy_primitive_types = FixNumpyArrayDimAnnotation .numpy_primitive_types
637644
638645 __DIM_VARS : set [str ] = set ()
646+ __DIM_STRING_PATTERN = re .compile (r'"\[(.*?)\]"' )
639647
640648 def handle_module (
641649 self , path : QualifiedName , module : types .ModuleType
@@ -659,17 +667,11 @@ def handle_module(
659667 )
660668
661669 self .__DIM_VARS .clear ()
662-
663670 return result
664671
665672 def parse_annotation_str (
666673 self , annotation_str : str
667674 ) -> ResolvedType | InvalidExpression | Value :
668- # Affects types of the following pattern:
669- # numpy.ndarray[PRIMITIVE_TYPE[*DIMS], *FLAGS]
670- # Replace with:
671- # numpy.ndarray[tuple[M, Literal[1]], numpy.dtype[numpy.float32]]
672-
673675 result = super ().parse_annotation_str (annotation_str )
674676
675677 if not isinstance (result , ResolvedType ):
@@ -679,9 +681,103 @@ def parse_annotation_str(
679681 if len (result .name ) == 1 and len (result .name [0 ]) == 1 :
680682 result .name = QualifiedName .from_str (result .name [0 ].upper ())
681683 self .__DIM_VARS .add (result .name [0 ])
684+ return result
685+
686+ if result .name in self .__typing_annotated_names :
687+ return self ._handle_annotated_numpy_array (result )
688+ elif result .name in self .__array_names :
689+ return self ._handle_old_style_numpy_array (result )
690+
691+ return result
682692
683- if result .name not in self .__array_names :
693+ def _process_numpy_array_type (
694+ self , scalar_type_name : QualifiedName , dimensions : list [int | str ] | None
695+ ) -> tuple [ResolvedType , ResolvedType ]:
696+ # Pybind annotates a bool Python type, which cannot be used with
697+ # numpy.dtype because it does not inherit from numpy.generic.
698+ # Only numpy.bool_ works reliably with both NumPy 1.x and 2.x.
699+ if str (scalar_type_name ) == "bool" :
700+ scalar_type_name = QualifiedName .from_str ("numpy.bool_" )
701+ dtype = ResolvedType (
702+ name = QualifiedName .from_str ("numpy.dtype" ),
703+ parameters = [ResolvedType (name = scalar_type_name )],
704+ )
705+
706+ shape = self .parse_annotation_str ("Any" )
707+ if dimensions is not None and len (dimensions ) > 0 :
708+ shape = self .parse_annotation_str ("Tuple" )
709+ assert isinstance (shape , ResolvedType )
710+ shape .parameters = []
711+ for dim in dimensions :
712+ if isinstance (dim , int ):
713+ literal_dim = self .parse_annotation_str ("Literal" )
714+ assert isinstance (literal_dim , ResolvedType )
715+ literal_dim .parameters = [Value (repr = str (dim ))]
716+ shape .parameters .append (literal_dim )
717+ else :
718+ shape .parameters .append (
719+ ResolvedType (name = QualifiedName .from_str (dim .upper ()))
720+ )
721+ return shape , dtype
722+
723+ def _handle_annotated_numpy_array (self , result : ResolvedType ) -> ResolvedType :
724+ # Annotated[numpy.typing.ArrayLike, numpy.float32, "[m, n]"]
725+ # Annotated[numpy.typing.NDArray[numpy.float32], "[m, n]"]
726+ # Annotated[numpy.typing.NDArray[numpy.float32], "[m, n]", "flags.writeable", "flags.c_contiguous"]
727+ if result .parameters is None or len (result .parameters ) < 2 :
728+ return result
729+
730+ array_type , * parameters = result .parameters
731+ if (
732+ not isinstance (array_type , ResolvedType )
733+ or array_type .name not in self .__array_names
734+ ):
735+ return result
736+
737+ dims_and_flags : Sequence [ResolvedType | Value | InvalidExpression ]
738+ if array_type .name == QualifiedName .from_str ("numpy.typing.ArrayLike" ):
739+ scalar_type , * dims_and_flags = parameters
740+ scalar_type_name = scalar_type .name
741+ elif array_type .name == QualifiedName .from_str ("numpy.typing.NDArray" ):
742+ if array_type .parameters is None or len (array_type .parameters ) < 2 :
743+ return result
744+ _ , dtype_param = array_type .parameters
745+ if not (
746+ isinstance (dtype_param , ResolvedType )
747+ and dtype_param .name == QualifiedName .from_str ("numpy.dtype" )
748+ and dtype_param .parameters
749+ ):
750+ return result
751+ scalar_type_name = dtype_param .parameters [0 ].name
752+ dims_and_flags = parameters
753+ else :
684754 return result
755+ if scalar_type_name not in self .numpy_primitive_types :
756+ return result
757+
758+ dims : list [int | str ] | None = None
759+ if dims_and_flags :
760+ dims_str , * flags = dims_and_flags
761+ del flags # Unused.
762+ if isinstance (dims_str , Value ):
763+ match = self .__DIM_STRING_PATTERN .search (dims_str .repr )
764+ if match :
765+ dims_str_content = match .group (1 )
766+ dims_list = [d .strip () for d in dims_str_content .split ("," ) if d .strip ()]
767+ if dims_list :
768+ dims = self .__to_dims_from_strings (dims_list )
769+
770+ shape , dtype = self ._process_numpy_array_type (scalar_type_name , dims )
771+ return ResolvedType (
772+ name = QualifiedName .from_str ("numpy.ndarray" ),
773+ parameters = [shape , dtype ],
774+ )
775+
776+ def _handle_old_style_numpy_array (self , result : ResolvedType ) -> ResolvedType :
777+ # Affects types of the following pattern:
778+ # numpy.ndarray[PRIMITIVE_TYPE[*DIMS], *FLAGS]
779+ # Replace with:
780+ # numpy.ndarray[tuple[M, Literal[1]], numpy.dtype[numpy.float32]]
685781
686782 # ndarray is generic and should have 2 type arguments
687783 if result .parameters is None or len (result .parameters ) == 0 :
@@ -702,39 +798,14 @@ def parse_annotation_str(
702798 ):
703799 return result
704800
705- name = scalar_with_dims .name
706- # Pybind annotates a bool Python type, which cannot be used with
707- # numpy.dtype because it does not inherit from numpy.generic.
708- # Only numpy.bool_ works reliably with both NumPy 1.x and 2.x.
709- if str (name ) == "bool" :
710- name = QualifiedName .from_str ("numpy.bool_" )
711- dtype = ResolvedType (
712- name = QualifiedName .from_str ("numpy.dtype" ),
713- parameters = [ResolvedType (name = name )],
714- )
715-
716- shape = self .parse_annotation_str ("Any" )
801+ dims : list [int | str ] | None = None
717802 if (
718803 scalar_with_dims .parameters is not None
719804 and len (scalar_with_dims .parameters ) > 0
720805 ):
721806 dims = self .__to_dims (scalar_with_dims .parameters )
722- if dims is not None :
723- shape = self .parse_annotation_str ("Tuple" )
724- assert isinstance (shape , ResolvedType )
725- shape .parameters = []
726- for dim in dims :
727- if isinstance (dim , int ):
728- # self.parse_annotation_str will qualify Literal with either
729- # typing or typing_extensions and add the import to the module
730- literal_dim = self .parse_annotation_str ("Literal" )
731- assert isinstance (literal_dim , ResolvedType )
732- literal_dim .parameters = [Value (repr = str (dim ))]
733- shape .parameters .append (literal_dim )
734- else :
735- shape .parameters .append (
736- ResolvedType (name = QualifiedName .from_str (dim ))
737- )
807+
808+ shape , dtype = self ._process_numpy_array_type (scalar_with_dims .name , dims )
738809
739810 result .parameters = [shape , dtype ]
740811 return result
@@ -756,6 +827,20 @@ def __to_dims(
756827 result .append (dim )
757828 return result
758829
830+ def __to_dims_from_strings (
831+ self , dimensions : Sequence [str ]
832+ ) -> list [int | str ] | None :
833+ result : list [int | str ] = []
834+ for dim_str in dimensions :
835+ try :
836+ dim = int (dim_str )
837+ except ValueError :
838+ dim = dim_str
839+ if len (dim ) == 1 : # Assuming single letter dims are type vars
840+ self .__DIM_VARS .add (dim .upper ()) # Add uppercase to TypeVar set
841+ result .append (dim )
842+ return result
843+
759844 def report_error (self , error : ParserError ) -> None :
760845 if (
761846 isinstance (error , NameResolutionError )
0 commit comments