@@ -208,7 +208,53 @@ def test_isdtype(dtype, kind):
208208 assert out == expected , f"{ out = } , but should be { expected } [isdtype()]"
209209
210210
211- @given (hh .mutually_promotable_dtypes (None ))
212- def test_result_type (dtypes ):
213- out = xp .result_type (* dtypes )
214- ph .assert_dtype ("result_type" , in_dtype = dtypes , out_dtype = out , repr_name = "out" )
211+ @pytest .mark .min_version ("2024.12" )
212+ class TestResultType :
213+ @given (dtypes = hh .mutually_promotable_dtypes (None ))
214+ def test_result_type (self , dtypes ):
215+ out = xp .result_type (* dtypes )
216+ ph .assert_dtype ("result_type" , in_dtype = dtypes , out_dtype = out , repr_name = "out" )
217+
218+ @given (pair = hh .pair_of_mutually_promotable_dtypes (None ))
219+ def test_shuffled (self , pair ):
220+ """Test that result_type is insensitive to the order of arguments."""
221+ s1 , s2 = pair
222+ out1 = xp .result_type (* s1 )
223+ out2 = xp .result_type (* s2 )
224+ assert out1 == out2
225+
226+ @given (pair = hh .pair_of_mutually_promotable_dtypes (2 ), data = st .data ())
227+ def test_arrays_and_dtypes (self , pair , data ):
228+ s1 , s2 = pair
229+ a2 = tuple (xp .empty (1 , dtype = dt ) for dt in s2 )
230+ a_and_dt = data .draw (st .permutations (s1 + a2 ))
231+ out = xp .result_type (* a_and_dt )
232+ ph .assert_dtype ("result_type" , in_dtype = s1 + s2 , out_dtype = out , repr_name = "out" )
233+
234+ @given (dtypes = hh .mutually_promotable_dtypes (2 ), data = st .data ())
235+ def test_with_scalars (self , dtypes , data ):
236+ out = xp .result_type (* dtypes )
237+
238+ if out == xp .bool :
239+ scalars = [True ]
240+ elif out in dh .all_int_dtypes :
241+ scalars = [1 ]
242+ elif out in dh .real_dtypes :
243+ scalars = [1 , 1.0 ]
244+ elif out in dh .numeric_dtypes :
245+ scalars = [1 , 1.0 , 1j ] # numeric_types - real_types == complex_types
246+ else :
247+ raise ValueError (f"unknown dtype { out = } ." )
248+
249+ scalar = data .draw (st .sampled_from (scalars ))
250+ inputs = data .draw (st .permutations (dtypes + (scalar ,)))
251+
252+ out_scalar = xp .result_type (* inputs )
253+ assert out_scalar == out
254+
255+ # retry with arrays
256+ arrays = tuple (xp .empty (1 , dtype = dt ) for dt in dtypes )
257+ inputs = data .draw (st .permutations (arrays + (scalar ,)))
258+ out_scalar = xp .result_type (* inputs )
259+ assert out_scalar == out
260+
0 commit comments