11import sys
2- from typing import Generic , TypeVar
3- from typing_extensions import Annotated
2+ from typing import Generic , TypeVar , Any
43
4+ from typing_extensions import Annotated
5+ from packaging import version
56from pydantic import Field
67from pydantic .fields import ModelField
78import numpy as np
89
910Cuid = Annotated [str , Field (min_length = 25 , max_length = 25 )]
10- DType = TypeVar ('DType' )
1111
12+ DType = TypeVar ('DType' )
13+ DShape = TypeVar ('DShape' )
1214
13- class TypedArray (Generic [DType ]):
1415
15- def __new__ (cls , * args , ** kwargs ):
16- return np .ndarray (* args , ** kwargs )
16+ class _TypedArray (np .ndarray , Generic [DType , DShape ]):
1717
1818 @classmethod
1919 def __get_validators__ (cls ):
@@ -25,12 +25,19 @@ def validate(cls, val, field: ModelField):
2525 raise TypeError (f"Expected numpy array. Found { type (val )} " )
2626
2727 if sys .version_info .minor > 6 :
28- actual_dtype = field .sub_fields [0 ].type_ .__args__ [0 ]
28+ actual_dtype = field .sub_fields [1 ].type_ .__args__ [0 ]
2929 else :
30- actual_dtype = field .sub_fields [0 ].type_ .__values__ [0 ]
30+ actual_dtype = field .sub_fields [1 ].type_ .__values__ [0 ]
3131
3232 if val .dtype != actual_dtype :
3333 raise TypeError (
3434 f"Expected numpy array have type { actual_dtype } . Found { val .dtype } "
3535 )
3636 return val
37+
38+
39+ if version .parse (np .__version__ ) >= version .parse ('1.22.0' ):
40+ from numpy .typing import _GenericAlias
41+ TypedArray = _GenericAlias (_TypedArray , (Any , DType ))
42+ else :
43+ TypedArray = _TypedArray [Any , DType ]
0 commit comments