File tree Expand file tree Collapse file tree 1 file changed +10
-3
lines changed Expand file tree Collapse file tree 1 file changed +10
-3
lines changed Original file line number Diff line number Diff line change @@ -103,9 +103,13 @@ def __repr__(self):
103103all_int_dtypes = uint_dtypes + int_dtypes
104104real_dtypes = all_int_dtypes + float_dtypes
105105complex_dtypes = tuple (getattr (xp , name ) for name in _complex_names )
106- numeric_dtypes = real_dtypes + complex_dtypes
106+ numeric_dtypes = real_dtypes
107+ if api_version > "2021.12" :
108+ numeric_dtypes += complex_dtypes
107109all_dtypes = (xp .bool ,) + numeric_dtypes
108- all_float_dtypes = float_dtypes + complex_dtypes
110+ all_float_dtypes = float_dtypes
111+ if api_version > "2021.12" :
112+ all_float_dtypes += complex_dtypes
109113bool_and_all_int_dtypes = (xp .bool ,) + all_int_dtypes
110114
111115
@@ -132,7 +136,10 @@ def is_float_dtype(dtype):
132136 # See https://github.com/numpy/numpy/issues/18434
133137 if dtype is None :
134138 return False
135- return dtype in float_dtypes
139+ valid_dtypes = float_dtypes
140+ if api_version > "2021.12" :
141+ valid_dtypes += complex_dtypes
142+ return dtype in valid_dtypes
136143
137144
138145def get_scalar_type (dtype : DataType ) -> ScalarType :
You can’t perform that action at this time.
0 commit comments