1- import sys
2- from typing import Annotated , Any , Generic , TypeVar
3-
1+ from typing import Generic , TypeVar
42import numpy as np
5- from packaging import version
6- from pydantic import ConfigDict , Field , StringConstraints
73from pydantic_core import core_schema
84
95DType = TypeVar ("DType" )
106DShape = TypeVar ("DShape" )
117
128
13- class _TypedArray (np .ndarray , Generic [DType , DShape ]):
9+ class TypedArray (np .ndarray , Generic [DType , DShape ]):
1410 @classmethod
1511 def __get_pydantic_core_schema__ (
1612 cls , _source_type : type , _model : type
@@ -22,23 +18,3 @@ def validate(cls, val):
2218 if not isinstance (val , np .ndarray ):
2319 raise TypeError (f"Expected numpy array. Found { type (val )} " )
2420 return val
25-
26-
27- if version .parse (np .__version__ ) >= version .parse ("1.25.0" ):
28- from typing import GenericAlias
29-
30- TypedArray = GenericAlias (_TypedArray , (Any , DType ))
31- elif version .parse (np .__version__ ) >= version .parse ("1.23.0" ):
32- from numpy ._typing import _GenericAlias
33-
34- TypedArray = _GenericAlias (_TypedArray , (Any , DType ))
35- elif (
36- version .parse ("1.22.0" )
37- <= version .parse (np .__version__ )
38- < version .parse ("1.23.0" )
39- ):
40- from numpy .typing import _GenericAlias
41-
42- TypedArray = _GenericAlias (_TypedArray , (Any , DType ))
43- else :
44- TypedArray = _TypedArray [Any , DType ]
0 commit comments