@@ -715,9 +715,11 @@ def _convert_scalars_helper(x1, x2):
715715 return in_dtypes , in_shapes , (x1a , x2a )
716716
717717
718- def _assert_correctness_binary (name , func , in_dtypes , in_shapes , in_arrs , out , ** kwargs ):
718+ def _assert_correctness_binary (
719+ name , func , in_dtypes , in_shapes , in_arrs , out , expected_dtype = None , ** kwargs
720+ ):
719721 x1a , x2a = in_arrs
720- ph .assert_dtype (name , in_dtype = in_dtypes , out_dtype = out .dtype )
722+ ph .assert_dtype (name , in_dtype = in_dtypes , out_dtype = out .dtype , expected = expected_dtype )
721723 ph .assert_result_shape (name , in_shapes = in_shapes , out_shape = out .shape )
722724 binary_assert_against_refimpl (name , x1a , x2a , out , func , ** kwargs )
723725
@@ -1781,23 +1783,35 @@ def test_trunc(x):
17811783
17821784def _check_binary_with_scalars (func_data , x1x2 ):
17831785 x1 , x2 = x1x2
1784- func , name , refimpl , kwds = func_data
1786+ func , name , refimpl , kwds , expected_dtype = func_data
17851787 out = func (x1 , x2 )
17861788 in_dtypes , in_shapes , (x1a , x2a ) = _convert_scalars_helper (x1 , x2 )
17871789 _assert_correctness_binary (
1788- name , refimpl , in_dtypes , in_shapes , (x1a , x2a ), out , ** kwds
1790+ name , refimpl , in_dtypes , in_shapes , (x1a , x2a ), out , expected_dtype , ** kwds
17891791 )
17901792
17911793
17921794@pytest .mark .min_version ("2024.12" )
17931795@pytest .mark .parametrize ('func_data' ,
1794- # xp_func, name, refimpl, kwargs
1796+ # xp_func, name, refimpl, kwargs, expected_dtype
17951797 [
1796- (xp .atan2 , "atan2" , math .atan2 , {}),
1797- (xp .hypot , "hypot" , math .hypot , {}),
1798- (xp .logaddexp , "logaddexp" , logaddexp_refimpl , {}),
1799- (xp .maximum , "maximum" , max , {'strict_check' : True }),
1800- (xp .minimum , "minimum" , min , {'strict_check' : True }),
1798+ (xp .add , "add" , operator .add , {}, None ),
1799+ (xp .atan2 , "atan2" , math .atan2 , {}, None ),
1800+ (xp .copysign , "copysign" , math .copysign , {}, None ),
1801+ (xp .divide , "divide" , operator .truediv , {"filter_" : lambda s : s != 0 }, None ),
1802+ (xp .hypot , "hypot" , math .hypot , {}, None ),
1803+ (xp .logaddexp , "logaddexp" , logaddexp_refimpl , {}, None ),
1804+ (xp .maximum , "maximum" , max , {'strict_check' : True }, None ),
1805+ (xp .minimum , "minimum" , min , {'strict_check' : True }, None ),
1806+ (xp .multiply , "mul" , operator .mul , {}, None ),
1807+ (xp .subtract , "sub" , operator .sub , {}, None ),
1808+
1809+ (xp .equal , "equal" , operator .eq , {}, xp .bool ),
1810+ (xp .not_equal , "neq" , operator .ne , {}, xp .bool ),
1811+ (xp .less , "less" , operator .lt , {}, xp .bool ),
1812+ (xp .less_equal , "les_equal" , operator .le , {}, xp .bool ),
1813+ (xp .greater , "greater" , operator .gt , {}, xp .bool ),
1814+ (xp .greater_equal , "greater_equal" , operator .ge , {}, xp .bool ),
18011815 ],
18021816 ids = lambda func_data : func_data [1 ] # use names for test IDs
18031817)
@@ -1808,14 +1822,15 @@ def test_binary_with_scalars_real(func_data, x1x2):
18081822
18091823@pytest .mark .min_version ("2024.12" )
18101824@pytest .mark .parametrize ('func_data' ,
1811- # xp_func, name, refimpl, kwargs
1825+ # xp_func, name, refimpl, kwargs, expected_dtype
18121826 [
1813- (xp .logical_and , "logical_and" , operator .and_ , {"expr_template" : "({} or {})={}" }),
1814- (xp .logical_or , "logical_or" , operator .or_ , {"expr_template" : "({} or {})={}" }),
1815- (xp .logical_xor , "logical_xor" , operator .xor , {"expr_template" : "({} or {})={}" }),
1827+ (xp .logical_and , "logical_and" , operator .and_ , {"expr_template" : "({} or {})={}" }, None ),
1828+ (xp .logical_or , "logical_or" , operator .or_ , {"expr_template" : "({} or {})={}" }, None ),
1829+ (xp .logical_xor , "logical_xor" , operator .xor , {"expr_template" : "({} or {})={}" }, None ),
18161830 ],
18171831 ids = lambda func_data : func_data [1 ] # use names for test IDs
18181832)
18191833@given (x1x2 = hh .array_and_py_scalar ([xp .bool ]))
18201834def test_binary_with_scalars_bool (func_data , x1x2 ):
18211835 _check_binary_with_scalars (func_data , x1x2 )
1836+
0 commit comments