@@ -197,7 +197,7 @@ def isdtype(
197197 else :
198198 raise TypeError (f"'kind' must be a dtype, str, or tuple of dtypes and strs, not { type (kind ).__name__ } " )
199199
200- def result_type (* arrays_and_dtypes : Union [Array , Dtype ]) -> Dtype :
200+ def result_type (* arrays_and_dtypes : Union [Array , Dtype , int , float , complex , bool ]) -> Dtype :
201201 """
202202 Array API compatible wrapper for :py:func:`np.result_type <numpy.result_type>`.
203203
@@ -208,19 +208,40 @@ def result_type(*arrays_and_dtypes: Union[Array, Dtype]) -> Dtype:
208208 # too many extra type promotions like int64 + uint64 -> float64, and does
209209 # value-based casting on scalar arrays.
210210 A = []
211+ scalars = []
211212 for a in arrays_and_dtypes :
212213 if isinstance (a , Array ):
213214 a = a .dtype
215+ elif isinstance (a , (bool , int , float , complex )):
216+ scalars .append (a )
214217 elif isinstance (a , np .ndarray ) or a not in _all_dtypes :
215218 raise TypeError ("result_type() inputs must be array_api arrays or dtypes" )
216219 A .append (a )
217220
221+ # remove python scalars
222+ A = [a for a in A if not isinstance (a , (bool , int , float , complex ))]
223+
218224 if len (A ) == 0 :
219225 raise ValueError ("at least one array or dtype is required" )
220226 elif len (A ) == 1 :
221- return A [0 ]
227+ result = A [0 ]
222228 else :
223229 t = A [0 ]
224230 for t2 in A [1 :]:
225231 t = _result_type (t , t2 )
226- return t
232+ result = t
233+
234+ if len (scalars ) == 0 :
235+ return result
236+
237+ if get_array_api_strict_flags ()['api_version' ] <= '2023.12' :
238+ raise TypeError ("result_type() inputs must be array_api arrays or dtypes" )
239+
240+ # promote python scalars given the result_type for all arrays/dtypes
241+ from ._creation_functions import empty
242+ arr = empty (1 , dtype = result )
243+ for s in scalars :
244+ x = arr ._promote_scalar (s )
245+ result = _result_type (x .dtype , result )
246+
247+ return result
0 commit comments