55from typing import Any , Dict , NamedTuple , Sequence , Tuple , Union
66from warnings import warn
77
8+ from . import api_version
89from . import _array_module as xp
910from ._array_module import _UndefinedStub
1011from .stubs import name_to_func
1516 "uint_dtypes" ,
1617 "all_int_dtypes" ,
1718 "float_dtypes" ,
19+ "real_dtypes" ,
1820 "numeric_dtypes" ,
1921 "all_dtypes" ,
20- "dtype_to_name " ,
22+ "all_float_dtypes " ,
2123 "bool_and_all_int_dtypes" ,
24+ "dtype_to_name" ,
2225 "dtype_to_scalars" ,
2326 "is_int_dtype" ,
2427 "is_float_dtype" ,
2730 "default_int" ,
2831 "default_uint" ,
2932 "default_float" ,
33+ "default_complex" ,
3034 "promotion_table" ,
3135 "dtype_nbits" ,
3236 "dtype_signed" ,
37+ "dtype_components" ,
3338 "func_in_dtypes" ,
3439 "func_returns_bool" ,
3540 "binary_op_to_symbol" ,
@@ -86,15 +91,25 @@ def __repr__(self):
8691_uint_names = ("uint8" , "uint16" , "uint32" , "uint64" )
8792_int_names = ("int8" , "int16" , "int32" , "int64" )
8893_float_names = ("float32" , "float64" )
89- _dtype_names = ("bool" ,) + _uint_names + _int_names + _float_names
94+ _real_names = _uint_names + _int_names + _float_names
95+ _complex_names = ("complex64" , "complex128" )
96+ _numeric_names = _real_names + _complex_names
97+ _dtype_names = ("bool" ,) + _numeric_names
9098
9199
92100uint_dtypes = tuple (getattr (xp , name ) for name in _uint_names )
93101int_dtypes = tuple (getattr (xp , name ) for name in _int_names )
94102float_dtypes = tuple (getattr (xp , name ) for name in _float_names )
95103all_int_dtypes = uint_dtypes + int_dtypes
96- numeric_dtypes = all_int_dtypes + float_dtypes
104+ real_dtypes = all_int_dtypes + float_dtypes
105+ complex_dtypes = tuple (getattr (xp , name ) for name in _complex_names )
106+ numeric_dtypes = real_dtypes
107+ if api_version > "2021.12" :
108+ numeric_dtypes += complex_dtypes
97109all_dtypes = (xp .bool ,) + numeric_dtypes
110+ all_float_dtypes = float_dtypes
111+ if api_version > "2021.12" :
112+ all_float_dtypes += complex_dtypes
98113bool_and_all_int_dtypes = (xp .bool ,) + all_int_dtypes
99114
100115
@@ -121,14 +136,19 @@ def is_float_dtype(dtype):
121136 # See https://github.com/numpy/numpy/issues/18434
122137 if dtype is None :
123138 return False
124- return dtype in float_dtypes
139+ valid_dtypes = float_dtypes
140+ if api_version > "2021.12" :
141+ valid_dtypes += complex_dtypes
142+ return dtype in valid_dtypes
125143
126144
127145def get_scalar_type (dtype : DataType ) -> ScalarType :
128146 if is_int_dtype (dtype ):
129147 return int
130148 elif is_float_dtype (dtype ):
131149 return float
150+ elif dtype in complex_dtypes :
151+ return complex
132152 else :
133153 return bool
134154
@@ -157,7 +177,8 @@ class MinMax(NamedTuple):
157177 [(d , 8 ) for d in [xp .int8 , xp .uint8 ]]
158178 + [(d , 16 ) for d in [xp .int16 , xp .uint16 ]]
159179 + [(d , 32 ) for d in [xp .int32 , xp .uint32 , xp .float32 ]]
160- + [(d , 64 ) for d in [xp .int64 , xp .uint64 , xp .float64 ]]
180+ + [(d , 64 ) for d in [xp .int64 , xp .uint64 , xp .float64 , xp .complex64 ]]
181+ + [(xp .complex128 , 128 )]
161182)
162183
163184
@@ -166,6 +187,11 @@ class MinMax(NamedTuple):
166187)
167188
168189
190+ dtype_components = EqualityMapping (
191+ [(xp .complex64 , xp .float32 ), (xp .complex128 , xp .float64 )]
192+ )
193+
194+
169195if isinstance (xp .asarray , _UndefinedStub ):
170196 default_int = xp .int32
171197 default_float = xp .float32
@@ -180,6 +206,15 @@ class MinMax(NamedTuple):
180206 default_float = xp .asarray (float ()).dtype
181207 if default_float not in float_dtypes :
182208 warn (f"inferred default float is { default_float !r} , which is not a float" )
209+ if api_version > "2021.12" :
210+ default_complex = xp .asarray (complex ()).dtype
211+ if default_complex not in complex_dtypes :
212+ warn (
213+ f"inferred default complex is { default_complex !r} , "
214+ "which is not a complex"
215+ )
216+ else :
217+ default_complex = None
183218if dtype_nbits [default_int ] == 32 :
184219 default_uint = xp .uint32
185220else :
@@ -226,6 +261,11 @@ class MinMax(NamedTuple):
226261 ((xp .float32 , xp .float32 ), xp .float32 ),
227262 ((xp .float32 , xp .float64 ), xp .float64 ),
228263 ((xp .float64 , xp .float64 ), xp .float64 ),
264+ # complex
265+ ((xp .complex64 , xp .complex64 ), xp .complex64 ),
266+ ((xp .complex64 , xp .complex128 ), xp .complex128 ),
267+ ((xp .complex128 , xp .complex128 ), xp .complex128 ),
268+
229269]
230270_numeric_promotions += [((d2 , d1 ), res ) for (d1 , d2 ), res in _numeric_promotions ]
231271_promotion_table = list (set (_numeric_promotions ))
0 commit comments