@@ -545,7 +545,7 @@ class _NumpyArrayAnnotation:
545545 )
546546 __DIM_VARS = ["n" , "m" ]
547547 __DIM_STRING_PATTERN = re .compile (r'"\[(.*?)\]"' )
548-
548+
549549 def __init__ (
550550 self ,
551551 array_type : ResolvedType ,
@@ -557,7 +557,7 @@ def __init__(
557557 self .scalar_type = scalar_type
558558 self .dimensions = dimensions
559559 self .flags = flags
560-
560+
561561 def to_type_hint (
562562 self , parser : IParser , on_dynamic_dim : Optional [callable [[str ], None ]] = None
563563 ) -> tuple [ResolvedType , ResolvedType ]:
@@ -571,7 +571,7 @@ def to_type_hint(
571571 name = QualifiedName .from_str ("numpy.dtype" ),
572572 parameters = [ResolvedType (name = scalar_type_name )],
573573 )
574-
574+
575575 shape = parser .parse_annotation_str ("Any" )
576576 if self .dimensions :
577577 shape = parser .parse_annotation_str ("Tuple" )
@@ -590,7 +590,7 @@ def to_type_hint(
590590 ResolvedType (name = QualifiedName .from_str (dim .upper ()))
591591 )
592592 return shape , dtype
593-
593+
594594 @classmethod
595595 def from_annotation (
596596 cls , resolved_type : ResolvedType
@@ -601,23 +601,23 @@ def from_annotation(
601601 return cls ._from_old_style (resolved_type )
602602 else :
603603 return None
604-
604+
605605 @classmethod
606606 def _from_old_style (
607607 cls , resolved_type : ResolvedType
608608 ) -> Optional [_NumpyArrayAnnotation ]:
609609 if resolved_type .parameters is None or len (resolved_type .parameters ) == 0 :
610610 return None
611-
611+
612612 scalar_with_dims = resolved_type .parameters [0 ]
613613 flags = resolved_type .parameters [1 :]
614-
614+
615615 if (
616616 not isinstance (scalar_with_dims , ResolvedType )
617617 or scalar_with_dims .name not in cls .numpy_primitive_types
618618 ):
619619 return None
620-
620+
621621 array_type = ResolvedType (name = resolved_type .name )
622622 scalar_type = ResolvedType (name = scalar_with_dims .name )
623623 dimensions : Optional [list [str | int ]] = None
@@ -626,9 +626,11 @@ def _from_old_style(
626626 and len (scalar_with_dims .parameters ) > 0
627627 ):
628628 dimensions = cls ._to_dims (scalar_with_dims .parameters )
629-
630- return cls (array_type , scalar_type , dimensions , flags )
631-
629+
630+ cls ._fix_flags (flags )
631+
632+ return _NumpyArrayAnnotation (array_type , scalar_type , dimensions , flags )
633+
632634 @classmethod
633635 def _from_new_style (
634636 cls , resolved_type : ResolvedType
@@ -705,7 +707,7 @@ def _from_new_style(
705707 flags = dims_and_flags [1 :]
706708
707709 return cls (array_type , scalar_type , dims , flags )
708-
710+
709711 @classmethod
710712 def _to_dims (
711713 cls , dimensions : Sequence [ResolvedType | Value | InvalidExpression ]
@@ -719,13 +721,24 @@ def _to_dims(
719721 return None
720722 elif isinstance (dim_param , ResolvedType ):
721723 dim = str (dim_param )
722- if dim not in cls .__DIM_VARS :
724+ if len ( dim ) > 1 and dim not in cls .__DIM_VARS :
723725 return None
724726 else :
725727 return None
726728 result .append (dim )
727729 return result
728-
730+
731+ @staticmethod
732+ def _fix_flags (flags : list [ResolvedType | Value | InvalidExpression ]):
733+ __flags : set [QualifiedName ] = {
734+ QualifiedName .from_str ("flags.writeable" ),
735+ QualifiedName .from_str ("flags.c_contiguous" ),
736+ QualifiedName .from_str ("flags.f_contiguous" ),
737+ }
738+ for flag in flags :
739+ if isinstance (flag , ResolvedType ) and flag .name in __flags :
740+ flag .name = QualifiedName .from_str (f"numpy.ndarray.{ flag .name } " )
741+
729742 @staticmethod
730743 def _to_dims_from_strings (dimensions : Sequence [str ]) -> list [int | str ] | None :
731744 result : list [int | str ] = []
@@ -736,29 +749,27 @@ def _to_dims_from_strings(dimensions: Sequence[str]) -> list[int | str] | None:
736749 dim = dim_str
737750 result .append (dim )
738751 return result
739-
740-
752+
753+
741754class FixNumpyArrayDimAnnotation (IParser ):
742755 # NB: Not using full name due to ambiguity `typing.Annotated` vs
743756 # `typing_extension.Annotated` in different python versions
744757 # Rely on later fix by `FixTypingTypeNames`
745758 __annotated_name = QualifiedName .from_str ("Annotated" )
746759 __DIM_VARS = _NumpyArrayAnnotation ._NumpyArrayAnnotation__DIM_VARS
747-
760+
748761 def parse_annotation_str (
749762 self , annotation_str : str
750763 ) -> ResolvedType | InvalidExpression | Value :
751764 # Affects types of the following pattern:
752765 # ARRAY_T[PRIMITIVE_TYPE[*DIMS], *FLAGS]
753766 # Replace with:
754767 # Annotated[ARRAY_T, PRIMITIVE_TYPE, FixedSize/DynamicSize[*DIMS], *FLAGS]
755-
768+
756769 result = super ().parse_annotation_str (annotation_str )
757- # if 'scipy.sparse' in str(result):
758- # __import__('ipdb').set_trace()
759770 if not isinstance (result , ResolvedType ):
760771 return result
761-
772+
762773 numpy_array = _NumpyArrayAnnotation .from_annotation (result )
763774 if numpy_array is None :
764775 return result
@@ -791,7 +802,9 @@ def parse_annotation_str(
791802 "flags.f_contiguous" ,
792803 ):
793804 params .append (
794- ResolvedType (name = QualifiedName .from_str (f"numpy.ndarray.{ flag_str } " ))
805+ ResolvedType (
806+ name = QualifiedName .from_str (f"numpy.ndarray.{ flag_str } " )
807+ )
795808 )
796809 else :
797810 params .append (flag )
@@ -805,13 +818,18 @@ def __wrap_with_size_helper(self, dims: list[int | str]) -> FixedSize | DynamicS
805818 return_t = FixedSize
806819 else :
807820 return_t = DynamicSize
808-
821+
809822 # TRICK: Use `self.handle_type` to make `FixedSize`/`DynamicSize`
810823 # properly added to the list of imports
811824 self .handle_type (return_t )
812825 return return_t (* dims ) # type: ignore[arg-type]
813-
826+
814827 def report_error (self , error : ParserError ) -> None :
828+ __flags : set [QualifiedName ] = {
829+ QualifiedName .from_str ("flags.writeable" ),
830+ QualifiedName .from_str ("flags.c_contiguous" ),
831+ QualifiedName .from_str ("flags.f_contiguous" ),
832+ }
815833 if (
816834 isinstance (error , NameResolutionError )
817835 and len (error .name ) == 1
@@ -820,36 +838,38 @@ def report_error(self, error: ParserError) -> None:
820838 ):
821839 # Ignores all unknown 'm' and 'n' regardless of the context
822840 return
841+ if isinstance (error , NameResolutionError ) and error .name in __flags :
842+ return
823843 super ().report_error (error )
824844
825845
826846class FixNumpyArrayDimTypeVar (IParser ):
827847 __DIM_VARS : set [str ] = set ()
828-
848+
829849 def handle_module (
830850 self , path : QualifiedName , module : types .ModuleType
831851 ) -> Module | None :
832852 result = super ().handle_module (path , module )
833853 if result is None :
834854 return None
835-
855+
836856 if self .__DIM_VARS :
837857 # the TypeVar_'s generated code will reference `typing`
838858 result .imports .add (
839859 Import (name = None , origin = QualifiedName .from_str ("typing" ))
840860 )
841-
861+
842862 for name in self .__DIM_VARS :
843863 result .type_vars .append (
844864 TypeVar_ (
845865 name = Identifier (name ),
846866 bound = self .parse_annotation_str ("int" ),
847867 ),
848868 )
849-
869+
850870 self .__DIM_VARS .clear ()
851871 return result
852-
872+
853873 def parse_annotation_str (
854874 self , annotation_str : str
855875 ) -> ResolvedType | InvalidExpression | Value :
@@ -867,32 +887,39 @@ def parse_annotation_str(
867887 if numpy_array is None :
868888 return result
869889 # __import__('ipdb').set_trace()
870-
890+
871891 # scipy.sparse arrays/matrices are not currently generic and do not accept type
872892 # arguments
873893 if numpy_array .array_type .name [:2 ] == ("scipy" , "sparse" ):
874894 return result
875-
895+
876896 def on_dynamic_dim (dim : str ) -> None :
877897 if len (dim ) == 1 : # Assuming single letter dims are type vars
878898 self .__DIM_VARS .add (dim .upper ())
879-
899+
880900 shape , dtype = numpy_array .to_type_hint (self , on_dynamic_dim )
881901 return ResolvedType (
882902 name = QualifiedName .from_str ("numpy.ndarray" ), parameters = [shape , dtype ]
883903 )
884-
904+
885905 def report_error (self , error : ParserError ) -> None :
906+ __flags : set [QualifiedName ] = {
907+ QualifiedName .from_str ("flags.writeable" ),
908+ QualifiedName .from_str ("flags.c_contiguous" ),
909+ QualifiedName .from_str ("flags.f_contiguous" ),
910+ }
886911 if (
887912 isinstance (error , NameResolutionError )
888913 and len (error .name ) == 1
889914 and error .name [0 ] in self .__DIM_VARS
890915 ):
891916 # allow type variables, which are manually resolved in `handle_module`
892917 return
918+ if isinstance (error , NameResolutionError ) and error .name in __flags :
919+ return
893920 super ().report_error (error )
894-
895-
921+
922+
896923class FixNumpyArrayRemoveParameters (IParser ):
897924 def parse_annotation_str (
898925 self , annotation_str : str
0 commit comments