4040
4141
4242class RemoveSelfAnnotation (IParser ):
43-
4443 __any_t_name = QualifiedName .from_str ("Any" )
4544 __typing_any_t_name = QualifiedName .from_str ("typing.Any" )
4645
@@ -632,10 +631,19 @@ def report_error(self, error: ParserError) -> None:
632631
633632
634633class FixNumpyArrayDimTypeVar (IParser ):
635- __array_names : set [QualifiedName ] = {QualifiedName .from_str ("numpy.ndarray" )}
634+ __array_names : set [QualifiedName ] = {
635+ QualifiedName .from_str ("numpy.ndarray" ),
636+ QualifiedName .from_str ("numpy.typing.ArrayLike" ),
637+ QualifiedName .from_str ("numpy.typing.NDArray" ),
638+ }
639+ __typing_annotated_names = {
640+ QualifiedName .from_str ("typing.Annotated" ),
641+ QualifiedName .from_str ("typing_extensions.Annotated" ),
642+ }
636643 numpy_primitive_types = FixNumpyArrayDimAnnotation .numpy_primitive_types
637644
638645 __DIM_VARS : set [str ] = set ()
646+ __DIM_STRING_PATTERN = re .compile (r'"\[(.*?)\]"' )
639647
640648 def handle_module (
641649 self , path : QualifiedName , module : types .ModuleType
@@ -659,85 +667,155 @@ def handle_module(
659667 )
660668
661669 self .__DIM_VARS .clear ()
662-
663670 return result
664671
665672 def parse_annotation_str (
666673 self , annotation_str : str
667674 ) -> ResolvedType | InvalidExpression | Value :
668- # Affects types of the following pattern:
669- # numpy.ndarray[PRIMITIVE_TYPE[*DIMS], *FLAGS]
670- # Replace with:
671- # numpy.ndarray[tuple[M, Literal[1]], numpy.dtype[numpy.float32]]
672-
673675 result = super ().parse_annotation_str (annotation_str )
674-
675676 if not isinstance (result , ResolvedType ):
676677 return result
677678
678679 # handle unqualified, single-letter annotation as a TypeVar
679680 if len (result .name ) == 1 and len (result .name [0 ]) == 1 :
680681 result .name = QualifiedName .from_str (result .name [0 ].upper ())
681682 self .__DIM_VARS .add (result .name [0 ])
683+ return result
682684
683- if result .name not in self .__array_names :
685+ if result .name == QualifiedName .from_str ("numpy.ndarray" ):
686+ parameters = self ._handle_old_style_numpy_array (result .parameters )
687+ elif result .name in self .__array_names :
688+ parameters = self ._handle_new_style_numpy_array ([result ])
689+ elif result .name in self .__typing_annotated_names :
690+ parameters = self ._handle_new_style_numpy_array (result .parameters )
691+ else :
692+ parameters = None
693+ if parameters is None : # Failure.
684694 return result
695+ return ResolvedType (
696+ name = QualifiedName .from_str ("numpy.ndarray" ), parameters = parameters
697+ )
698+
699+ def _process_numpy_array_type (
700+ self , scalar_type_name : QualifiedName , dimensions : list [int | str ] | None
701+ ) -> tuple [ResolvedType , ResolvedType ]:
702+ # Pybind annotates a bool Python type, which cannot be used with
703+ # numpy.dtype because it does not inherit from numpy.generic.
704+ # Only numpy.bool_ works reliably with both NumPy 1.x and 2.x.
705+ if str (scalar_type_name ) == "bool" :
706+ scalar_type_name = QualifiedName .from_str ("numpy.bool_" )
707+ dtype = ResolvedType (
708+ name = QualifiedName .from_str ("numpy.dtype" ),
709+ parameters = [ResolvedType (name = scalar_type_name )],
710+ )
711+
712+ shape = self .parse_annotation_str ("Any" )
713+ if dimensions is not None and len (dimensions ) > 0 :
714+ shape = self .parse_annotation_str ("Tuple" )
715+ assert isinstance (shape , ResolvedType )
716+ shape .parameters = []
717+ for dim in dimensions :
718+ if isinstance (dim , int ):
719+ literal_dim = self .parse_annotation_str ("Literal" )
720+ assert isinstance (literal_dim , ResolvedType )
721+ literal_dim .parameters = [Value (repr = str (dim ))]
722+ shape .parameters .append (literal_dim )
723+ else :
724+ shape .parameters .append (
725+ ResolvedType (name = QualifiedName .from_str (dim .upper ()))
726+ )
727+ return shape , dtype
728+
729+ def _handle_new_style_numpy_array (
730+ self , parameters : list [ResolvedType | Value | InvalidExpression ] | None
731+ ) -> list [ResolvedType ] | None :
732+ # Annotated[numpy.typing.ArrayLike, numpy.float32, "[m, n]"]
733+ # Annotated[numpy.typing.NDArray[numpy.float32], "[m, n]"]
734+ # Annotated[numpy.typing.NDArray[numpy.float32], "[m, n]", "flags.writeable", "flags.c_contiguous"]
735+ if parameters is None or len (parameters ) == 0 :
736+ return
737+
738+ array_type , * parameters = parameters
739+ if (
740+ not isinstance (array_type , ResolvedType )
741+ or array_type .name not in self .__array_names
742+ ):
743+ return
744+
745+ dims_and_flags : Sequence [ResolvedType | Value | InvalidExpression ]
746+ if array_type .name == QualifiedName .from_str ("numpy.typing.ArrayLike" ):
747+ if not parameters :
748+ return
749+ scalar_type , * dims_and_flags = parameters
750+ elif array_type .name == QualifiedName .from_str ("numpy.typing.NDArray" ):
751+ if array_type .parameters is None or len (array_type .parameters ) == 0 :
752+ return
753+ [scalar_type ] = array_type .parameters
754+ dims_and_flags = parameters
755+ elif array_type .name == QualifiedName .from_str ("numpy.ndarray" ):
756+ _ , dtype_param = array_type .parameters
757+ if not (
758+ isinstance (dtype_param , ResolvedType )
759+ and dtype_param .name == QualifiedName .from_str ("numpy.dtype" )
760+ and dtype_param .parameters
761+ ):
762+ return
763+ [scalar_type ] = dtype_param .parameters
764+ dims_and_flags = parameters
765+ else :
766+ return
767+ scalar_type_name = scalar_type .name
768+ if scalar_type_name not in self .numpy_primitive_types :
769+ return
770+
771+ dims : list [int | str ] | None = None
772+ if dims_and_flags :
773+ dims_str , * flags = dims_and_flags
774+ del flags # Unused.
775+ if isinstance (dims_str , Value ):
776+ match = self .__DIM_STRING_PATTERN .search (dims_str .repr )
777+ if match :
778+ dims_str_content = match .group (1 )
779+ dims_list = [
780+ d .strip () for d in dims_str_content .split ("," ) if d .strip ()
781+ ]
782+ if dims_list :
783+ dims = self .__to_dims_from_strings (dims_list )
784+
785+ return self ._process_numpy_array_type (scalar_type_name , dims )
786+
787+ def _handle_old_style_numpy_array (
788+ self , parameters : list [ResolvedType | Value | InvalidExpression ] | None
789+ ) -> list [ResolvedType ] | None :
790+ # Affects types of the following pattern:
791+ # numpy.ndarray[PRIMITIVE_TYPE[*DIMS], *FLAGS]
792+ # Replace with:
793+ # numpy.ndarray[tuple[M, Literal[1]], numpy.dtype[numpy.float32]]
685794
686795 # ndarray is generic and should have 2 type arguments
687- if result . parameters is None or len (result . parameters ) == 0 :
688- result . parameters = [
796+ if parameters is None or len (parameters ) == 0 :
797+ return [
689798 self .parse_annotation_str ("Any" ),
690799 ResolvedType (
691800 name = QualifiedName .from_str ("numpy.dtype" ),
692801 parameters = [self .parse_annotation_str ("Any" )],
693802 ),
694803 ]
695- return result
696-
697- scalar_with_dims = result .parameters [0 ] # e.g. numpy.float64[32, 32]
698804
805+ scalar_with_dims = parameters [0 ] # e.g. numpy.float64[32, 32]
699806 if (
700807 not isinstance (scalar_with_dims , ResolvedType )
701808 or scalar_with_dims .name not in self .numpy_primitive_types
702809 ):
703- return result
704-
705- name = scalar_with_dims .name
706- # Pybind annotates a bool Python type, which cannot be used with
707- # numpy.dtype because it does not inherit from numpy.generic.
708- # Only numpy.bool_ works reliably with both NumPy 1.x and 2.x.
709- if str (name ) == "bool" :
710- name = QualifiedName .from_str ("numpy.bool_" )
711- dtype = ResolvedType (
712- name = QualifiedName .from_str ("numpy.dtype" ),
713- parameters = [ResolvedType (name = name )],
714- )
810+ return
715811
716- shape = self . parse_annotation_str ( "Any" )
812+ dims : list [ int | str ] | None = None
717813 if (
718814 scalar_with_dims .parameters is not None
719815 and len (scalar_with_dims .parameters ) > 0
720816 ):
721817 dims = self .__to_dims (scalar_with_dims .parameters )
722- if dims is not None :
723- shape = self .parse_annotation_str ("Tuple" )
724- assert isinstance (shape , ResolvedType )
725- shape .parameters = []
726- for dim in dims :
727- if isinstance (dim , int ):
728- # self.parse_annotation_str will qualify Literal with either
729- # typing or typing_extensions and add the import to the module
730- literal_dim = self .parse_annotation_str ("Literal" )
731- assert isinstance (literal_dim , ResolvedType )
732- literal_dim .parameters = [Value (repr = str (dim ))]
733- shape .parameters .append (literal_dim )
734- else :
735- shape .parameters .append (
736- ResolvedType (name = QualifiedName .from_str (dim ))
737- )
738-
739- result .parameters = [shape , dtype ]
740- return result
818+ return self ._process_numpy_array_type (scalar_with_dims .name , dims )
741819
742820 def __to_dims (
743821 self , dimensions : Sequence [ResolvedType | Value | InvalidExpression ]
@@ -756,6 +834,20 @@ def __to_dims(
756834 result .append (dim )
757835 return result
758836
837+ def __to_dims_from_strings (
838+ self , dimensions : Sequence [str ]
839+ ) -> list [int | str ] | None :
840+ result : list [int | str ] = []
841+ for dim_str in dimensions :
842+ try :
843+ dim = int (dim_str )
844+ except ValueError :
845+ dim = dim_str
846+ if len (dim ) == 1 : # Assuming single letter dims are type vars
847+ self .__DIM_VARS .add (dim .upper ()) # Add uppercase to TypeVar set
848+ result .append (dim )
849+ return result
850+
759851 def report_error (self , error : ParserError ) -> None :
760852 if (
761853 isinstance (error , NameResolutionError )
0 commit comments