@@ -690,6 +690,40 @@ def binary_param_assert_against_refimpl(
690690 )
691691
692692
693+ def _convert_scalars_helper (x1 , x2 ):
694+ """Convert python scalar to arrays, record the shapes/dtypes of arrays.
695+
696+ For inputs being scalars or arrays, return the dtypes and shapes of array arguments,
697+ and all arguments converted to arrays.
698+
699+ dtypes are separate to help distinguishing between
700+ `py_scalar + f32_array -> f32_array` and `f64_array + f32_array -> f64_array`
701+ """
702+ if dh .is_scalar (x1 ):
703+ in_dtypes = [x2 .dtype ]
704+ in_shapes = [x2 .shape ]
705+ x1a , x2a = xp .asarray (x1 ), x2
706+ elif dh .is_scalar (x2 ):
707+ in_dtypes = [x1 .dtype ]
708+ in_shapes = [x1 .shape ]
709+ x1a , x2a = x1 , xp .asarray (x2 )
710+ else :
711+ in_dtypes = [x1 .dtype , x2 .dtype ]
712+ in_shapes = [x1 .shape , x2 .shape ]
713+ x1a , x2a = x1 , x2
714+
715+ return in_dtypes , in_shapes , (x1a , x2a )
716+
717+
718+ def _assert_correctness_binary (
719+ name , func , in_dtypes , in_shapes , in_arrs , out , expected_dtype = None , ** kwargs
720+ ):
721+ x1a , x2a = in_arrs
722+ ph .assert_dtype (name , in_dtype = in_dtypes , out_dtype = out .dtype , expected = expected_dtype )
723+ ph .assert_result_shape (name , in_shapes = in_shapes , out_shape = out .shape )
724+ binary_assert_against_refimpl (name , x1a , x2a , out , func , ** kwargs )
725+
726+
693727@pytest .mark .parametrize ("ctx" , make_unary_params ("abs" , dh .numeric_dtypes ))
694728@given (data = st .data ())
695729def test_abs (ctx , data ):
@@ -789,10 +823,14 @@ def test_atan(x):
789823@given (* hh .two_mutual_arrays (dh .real_float_dtypes ))
790824def test_atan2 (x1 , x2 ):
791825 out = xp .atan2 (x1 , x2 )
792- ph .assert_dtype ("atan2" , in_dtype = [x1 .dtype , x2 .dtype ], out_dtype = out .dtype )
793- ph .assert_result_shape ("atan2" , in_shapes = [x1 .shape , x2 .shape ], out_shape = out .shape )
794- refimpl = cmath .atan2 if x1 .dtype in dh .complex_dtypes else math .atan2
795- binary_assert_against_refimpl ("atan2" , x1 , x2 , out , refimpl )
826+ _assert_correctness_binary (
827+ "atan" ,
828+ cmath .atan2 if x1 .dtype in dh .complex_dtypes else math .atan2 ,
829+ in_dtypes = [x1 .dtype , x2 .dtype ],
830+ in_shapes = [x1 .shape , x2 .shape ],
831+ in_arrs = [x1 , x2 ],
832+ out = out ,
833+ )
796834
797835
798836@given (hh .arrays (dtype = hh .all_floating_dtypes (), shape = hh .shapes ()))
@@ -1258,10 +1296,14 @@ def test_greater_equal(ctx, data):
12581296@given (* hh .two_mutual_arrays (dh .real_float_dtypes ))
12591297def test_hypot (x1 , x2 ):
12601298 out = xp .hypot (x1 , x2 )
1261- ph .assert_dtype ("hypot" , in_dtype = [x1 .dtype , x2 .dtype ], out_dtype = out .dtype )
1262- ph .assert_result_shape ("hypot" , in_shapes = [x1 .shape , x2 .shape ], out_shape = out .shape )
1263- binary_assert_against_refimpl ("hypot" , x1 , x2 , out , math .hypot )
1264-
1299+ _assert_correctness_binary (
1300+ "hypot" ,
1301+ math .hypot ,
1302+ in_dtypes = [x1 .dtype , x2 .dtype ],
1303+ in_shapes = [x1 .shape , x2 .shape ],
1304+ in_arrs = [x1 , x2 ],
1305+ out = out
1306+ )
12651307
12661308
12671309@pytest .mark .min_version ("2022.12" )
@@ -1411,21 +1453,17 @@ def logaddexp_refimpl(l: float, r: float) -> float:
14111453 raise OverflowError
14121454
14131455
1456+ @pytest .mark .min_version ("2023.12" )
14141457@given (* hh .two_mutual_arrays (dh .real_float_dtypes ))
14151458def test_logaddexp (x1 , x2 ):
14161459 out = xp .logaddexp (x1 , x2 )
1417- ph .assert_dtype ("logaddexp" , in_dtype = [x1 .dtype , x2 .dtype ], out_dtype = out .dtype )
1418- ph .assert_result_shape ("logaddexp" , in_shapes = [x1 .shape , x2 .shape ], out_shape = out .shape )
1419- binary_assert_against_refimpl ("logaddexp" , x1 , x2 , out , logaddexp_refimpl )
1420-
1421-
1422- @given (* hh .two_mutual_arrays ([xp .bool ]))
1423- def test_logical_and (x1 , x2 ):
1424- out = xp .logical_and (x1 , x2 )
1425- ph .assert_dtype ("logical_and" , in_dtype = [x1 .dtype , x2 .dtype ], out_dtype = out .dtype )
1426- ph .assert_result_shape ("logical_and" , in_shapes = [x1 .shape , x2 .shape ], out_shape = out .shape )
1427- binary_assert_against_refimpl (
1428- "logical_and" , x1 , x2 , out , operator .and_ , expr_template = "({} and {})={}"
1460+ _assert_correctness_binary (
1461+ "logaddexp" ,
1462+ logaddexp_refimpl ,
1463+ in_dtypes = [x1 .dtype , x2 .dtype ],
1464+ in_shapes = [x1 .shape , x2 .shape ],
1465+ in_arrs = [x1 , x2 ],
1466+ out = out
14291467 )
14301468
14311469
@@ -1439,42 +1477,64 @@ def test_logical_not(x):
14391477 )
14401478
14411479
1480+ @given (* hh .two_mutual_arrays ([xp .bool ]))
1481+ def test_logical_and (x1 , x2 ):
1482+ out = xp .logical_and (x1 , x2 )
1483+ _assert_correctness_binary (
1484+ "logical_and" ,
1485+ operator .and_ ,
1486+ in_dtypes = [x1 .dtype , x2 .dtype ],
1487+ in_shapes = [x1 .shape , x2 .shape ],
1488+ in_arrs = [x1 , x2 ],
1489+ out = out ,
1490+ expr_template = "({} and {})={}"
1491+ )
1492+
1493+
14421494@given (* hh .two_mutual_arrays ([xp .bool ]))
14431495def test_logical_or (x1 , x2 ):
14441496 out = xp .logical_or (x1 , x2 )
1445- ph .assert_dtype ("logical_or" , in_dtype = [x1 .dtype , x2 .dtype ], out_dtype = out .dtype )
1446- ph .assert_result_shape ("logical_or" , in_shapes = [x1 .shape , x2 .shape ], out_shape = out .shape )
1447- binary_assert_against_refimpl (
1448- "logical_or" , x1 , x2 , out , operator .or_ , expr_template = "({} or {})={}"
1497+ _assert_correctness_binary (
1498+ "logical_or" ,
1499+ operator .or_ ,
1500+ in_dtypes = [x1 .dtype , x2 .dtype ],
1501+ in_shapes = [x1 .shape , x2 .shape ],
1502+ in_arrs = [x1 , x2 ],
1503+ out = out ,
1504+ expr_template = "({} or {})={}"
14491505 )
14501506
14511507
14521508@given (* hh .two_mutual_arrays ([xp .bool ]))
14531509def test_logical_xor (x1 , x2 ):
14541510 out = xp .logical_xor (x1 , x2 )
1455- ph .assert_dtype ("logical_xor" , in_dtype = [x1 .dtype , x2 .dtype ], out_dtype = out .dtype )
1456- ph .assert_result_shape ("logical_xor" , in_shapes = [x1 .shape , x2 .shape ], out_shape = out .shape )
1457- binary_assert_against_refimpl (
1458- "logical_xor" , x1 , x2 , out , operator .xor , expr_template = "({} ^ {})={}"
1511+ _assert_correctness_binary (
1512+ "logical_xor" ,
1513+ operator .xor ,
1514+ in_dtypes = [x1 .dtype , x2 .dtype ],
1515+ in_shapes = [x1 .shape , x2 .shape ],
1516+ in_arrs = [x1 , x2 ],
1517+ out = out ,
1518+ expr_template = "({} ^ {})={}"
14591519 )
14601520
14611521
14621522@pytest .mark .min_version ("2023.12" )
14631523@given (* hh .two_mutual_arrays (dh .real_float_dtypes ))
14641524def test_maximum (x1 , x2 ):
14651525 out = xp .maximum (x1 , x2 )
1466- ph . assert_dtype ( "maximum" , in_dtype = [ x1 . dtype , x2 . dtype ], out_dtype = out . dtype )
1467- ph . assert_result_shape ( "maximum" , in_shapes = [x1 .shape , x2 .shape ], out_shape = out .shape )
1468- binary_assert_against_refimpl ( "maximum" , x1 , x2 , out , max , strict_check = True )
1526+ _assert_correctness_binary (
1527+ "maximum" , max , [x1 .dtype , x2 .dtype ], [ x1 .shape , x2 . shape ], ( x1 , x2 ), out , strict_check = True
1528+ )
14691529
14701530
14711531@pytest .mark .min_version ("2023.12" )
14721532@given (* hh .two_mutual_arrays (dh .real_float_dtypes ))
14731533def test_minimum (x1 , x2 ):
14741534 out = xp .minimum (x1 , x2 )
1475- ph . assert_dtype ( "minimum" , in_dtype = [ x1 . dtype , x2 . dtype ], out_dtype = out . dtype )
1476- ph . assert_result_shape ( "minimum" , in_shapes = [x1 .shape , x2 .shape ], out_shape = out .shape )
1477- binary_assert_against_refimpl ( "minimum" , x1 , x2 , out , min , strict_check = True )
1535+ _assert_correctness_binary (
1536+ "minimum" , min , [x1 .dtype , x2 .dtype ], [ x1 .shape , x2 . shape ], ( x1 , x2 ), out , strict_check = True
1537+ )
14781538
14791539
14801540@pytest .mark .parametrize ("ctx" , make_binary_params ("multiply" , dh .numeric_dtypes ))
@@ -1719,3 +1779,88 @@ def test_trunc(x):
17191779 ph .assert_dtype ("trunc" , in_dtype = x .dtype , out_dtype = out .dtype )
17201780 ph .assert_shape ("trunc" , out_shape = out .shape , expected = x .shape )
17211781 unary_assert_against_refimpl ("trunc" , x , out , math .trunc , strict_check = True )
1782+
1783+
1784+ def _check_binary_with_scalars (func_data , x1x2 ):
1785+ x1 , x2 = x1x2
1786+ func , name , refimpl , kwds , expected_dtype = func_data
1787+ out = func (x1 , x2 )
1788+ in_dtypes , in_shapes , (x1a , x2a ) = _convert_scalars_helper (x1 , x2 )
1789+ _assert_correctness_binary (
1790+ name , refimpl , in_dtypes , in_shapes , (x1a , x2a ), out , expected_dtype , ** kwds
1791+ )
1792+
1793+
1794+ def _filter_zero (x ):
1795+ return x != 0 if dh .is_scalar (x ) else (not xp .any (x == 0 ))
1796+
1797+
1798+ @pytest .mark .min_version ("2024.12" )
1799+ @pytest .mark .parametrize ('func_data' ,
1800+ # xp_func, name, refimpl, kwargs, expected_dtype
1801+ [
1802+ (xp .add , "add" , operator .add , {}, None ),
1803+ (xp .atan2 , "atan2" , math .atan2 , {}, None ),
1804+ (xp .copysign , "copysign" , math .copysign , {}, None ),
1805+ (xp .divide , "divide" , operator .truediv , {"filter_" : lambda s : s != 0 }, None ),
1806+ (xp .hypot , "hypot" , math .hypot , {}, None ),
1807+ (xp .logaddexp , "logaddexp" , logaddexp_refimpl , {}, None ),
1808+ (xp .maximum , "maximum" , max , {'strict_check' : True }, None ),
1809+ (xp .minimum , "minimum" , min , {'strict_check' : True }, None ),
1810+ (xp .multiply , "mul" , operator .mul , {}, None ),
1811+ (xp .subtract , "sub" , operator .sub , {}, None ),
1812+
1813+ (xp .equal , "equal" , operator .eq , {}, xp .bool ),
1814+ (xp .not_equal , "neq" , operator .ne , {}, xp .bool ),
1815+ (xp .less , "less" , operator .lt , {}, xp .bool ),
1816+ (xp .less_equal , "les_equal" , operator .le , {}, xp .bool ),
1817+ (xp .greater , "greater" , operator .gt , {}, xp .bool ),
1818+ (xp .greater_equal , "greater_equal" , operator .ge , {}, xp .bool ),
1819+ (xp .remainder , "remainder" , operator .mod , {}, None ),
1820+ (xp .floor_divide , "floor_divide" , operator .floordiv , {}, None ),
1821+ ],
1822+ ids = lambda func_data : func_data [1 ] # use names for test IDs
1823+ )
1824+ @given (x1x2 = hh .array_and_py_scalar (dh .real_float_dtypes ))
1825+ def test_binary_with_scalars_real (func_data , x1x2 ):
1826+
1827+ if func_data [1 ] == "remainder" :
1828+ assume (_filter_zero (x1x2 [1 ]))
1829+ if func_data [1 ] == "floor_divide" :
1830+ assume (_filter_zero (x1x2 [0 ]) and _filter_zero (x1x2 [1 ]))
1831+
1832+ _check_binary_with_scalars (func_data , x1x2 )
1833+
1834+
1835+ @pytest .mark .min_version ("2024.12" )
1836+ @pytest .mark .parametrize ('func_data' ,
1837+ # xp_func, name, refimpl, kwargs, expected_dtype
1838+ [
1839+ (xp .logical_and , "logical_and" , operator .and_ , {"expr_template" : "({} or {})={}" }, None ),
1840+ (xp .logical_or , "logical_or" , operator .or_ , {"expr_template" : "({} or {})={}" }, None ),
1841+ (xp .logical_xor , "logical_xor" , operator .xor , {"expr_template" : "({} or {})={}" }, None ),
1842+ ],
1843+ ids = lambda func_data : func_data [1 ] # use names for test IDs
1844+ )
1845+ @given (x1x2 = hh .array_and_py_scalar ([xp .bool ]))
1846+ def test_binary_with_scalars_bool (func_data , x1x2 ):
1847+ _check_binary_with_scalars (func_data , x1x2 )
1848+
1849+
1850+ @pytest .mark .min_version ("2024.12" )
1851+ @pytest .mark .parametrize ('func_data' ,
1852+ # xp_func, name, refimpl, kwargs, expected_dtype
1853+ [
1854+ (xp .bitwise_and , "bitwise_and" , operator .and_ , {}, None ),
1855+ (xp .bitwise_or , "bitwise_or" , operator .or_ , {}, None ),
1856+ (xp .bitwise_xor , "bitwise_xor" , operator .xor , {}, None ),
1857+ ],
1858+ ids = lambda func_data : func_data [1 ] # use names for test IDs
1859+ )
1860+ @given (x1x2 = hh .array_and_py_scalar ([xp .int32 ]))
1861+ def test_binary_with_scalars_bitwise (func_data , x1x2 ):
1862+ xp_func , name , refimpl , kwargs , expected = func_data
1863+ # repack the refimpl
1864+ refimpl_ = lambda l , r : mock_int_dtype (refimpl (l , r ), xp .int32 )
1865+ _check_binary_with_scalars ((xp_func , name , refimpl_ , kwargs ,expected ), x1x2 )
1866+
0 commit comments