11import sys
2- import logging
3- from typing import Generic , TypeVar
4- from typing_extensions import Annotated
2+ from typing import Generic , TypeVar , Any
53
4+ from typing_extensions import Annotated
5+ from packaging import version
66from pydantic import Field
77from pydantic .fields import ModelField
88import numpy as np
99
1010Cuid = Annotated [str , Field (min_length = 25 , max_length = 25 )]
1111
1212DType = TypeVar ('DType' )
13-
14- logger = logging .getLogger (__name__ )
13+ DShape = TypeVar ('DShape' )
1514
1615
17- class TypedArray (np .ndarray , Generic [DType ]):
16+ class _TypedArray (np .ndarray , Generic [DType , DShape ]):
1817
1918 @classmethod
2019 def __get_validators__ (cls ):
@@ -26,12 +25,19 @@ def validate(cls, val, field: ModelField):
2625 raise TypeError (f"Expected numpy array. Found { type (val )} " )
2726
2827 if sys .version_info .minor > 6 :
29- actual_dtype = field .sub_fields [0 ].type_ .__args__ [0 ]
28+ actual_dtype = field .sub_fields [- 1 ].type_ .__args__ [0 ]
3029 else :
31- actual_dtype = field .sub_fields [0 ].type_ .__values__ [0 ]
30+ actual_dtype = field .sub_fields [- 1 ].type_ .__values__ [0 ]
3231
3332 if val .dtype != actual_dtype :
3433 raise TypeError (
3534 f"Expected numpy array have type { actual_dtype } . Found { val .dtype } "
3635 )
3736 return val
37+
38+
39+ if version .parse (np .__version__ ) >= version .parse ('1.22.0' ):
40+ from numpy .typing import _GenericAlias
41+ TypedArray = _GenericAlias (_TypedArray , (Any , DType ))
42+ else :
43+ TypedArray = _TypedArray [Any , DType ]
0 commit comments