@@ -208,6 +208,7 @@ def test_isdtype(dtype, kind):
208208 assert out == expected , f"{ out = } , but should be { expected } [isdtype()]"
209209
210210
211+ @pytest .mark .min_version ("2024.12" )
211212class TestResultType :
212213 @given (dtypes = hh .mutually_promotable_dtypes (None ))
213214 def test_result_type (self , dtypes ):
@@ -230,3 +231,30 @@ def test_arrays_and_dtypes(self, pair, data):
230231 out = xp .result_type (* a_and_dt )
231232 ph .assert_dtype ("result_type" , in_dtype = s1 + s2 , out_dtype = out , repr_name = "out" )
232233
234+ @given (dtypes = hh .mutually_promotable_dtypes (2 ), data = st .data ())
235+ def test_with_scalars (self , dtypes , data ):
236+ out = xp .result_type (* dtypes )
237+
238+ if out == xp .bool :
239+ scalars = [True ]
240+ elif out in dh .all_int_dtypes :
241+ scalars = [1 ]
242+ elif out in dh .real_dtypes :
243+ scalars = [1 , 1.0 ]
244+ elif out in dh .numeric_dtypes :
245+ scalars = [1 , 1.0 , 1j ] # numeric_types - real_types == complex_types
246+ else :
247+ raise ValueError (f"unknown dtype { out = } ." )
248+
249+ scalar = data .draw (st .sampled_from (scalars ))
250+ inputs = data .draw (st .permutations (dtypes + (scalar ,)))
251+
252+ out_scalar = xp .result_type (* inputs )
253+ assert out_scalar == out
254+
255+ # retry with arrays
256+ arrays = tuple (xp .empty (1 , dtype = dt ) for dt in dtypes )
257+ inputs = data .draw (st .permutations (arrays + (scalar ,)))
258+ out_scalar = xp .result_type (* inputs )
259+ assert out_scalar == out
260+
0 commit comments