@@ -690,6 +690,38 @@ def binary_param_assert_against_refimpl(
690690 )
691691
692692
693+ def _convert_scalars_helper (x1 , x2 ):
694+ """Convert python scalar to arrays, record the shapes/dtypes of arrays.
695+
696+ For inputs being scalars or arrays, return the dtypes and shapes of array arguments,
697+ and all arguments converted to arrays.
698+
699+ dtypes are separate to help distinguishing between
700+ `py_scalar + f32_array -> f32_array` and `f64_array + f32_array -> f64_array`
701+ """
702+ if dh .is_scalar (x1 ):
703+ in_dtypes = [x2 .dtype ]
704+ in_shapes = [x2 .shape ]
705+ x1a , x2a = xp .asarray (x1 ), x2
706+ elif dh .is_scalar (x2 ):
707+ in_dtypes = [x1 .dtype ]
708+ in_shapes = [x1 .shape ]
709+ x1a , x2a = x1 , xp .asarray (x2 )
710+ else :
711+ in_dtypes = [x1 .dtype , x2 .dtype ]
712+ in_shapes = [x1 .shape , x2 .shape ]
713+ x1a , x2a = x1 , x2
714+
715+ return in_dtypes , in_shapes , (x1a , x2a )
716+
717+
718+ def _assert_correctness_binary (name , func , in_dtypes , in_shapes , in_arrs , out , ** kwargs ):
719+ x1a , x2a = in_arrs
720+ ph .assert_dtype (name , in_dtype = in_dtypes , out_dtype = out .dtype )
721+ ph .assert_result_shape (name , in_shapes = in_shapes , out_shape = out .shape )
722+ binary_assert_against_refimpl (name , x1a , x2a , out , func , ** kwargs )
723+
724+
693725@pytest .mark .parametrize ("ctx" , make_unary_params ("abs" , dh .numeric_dtypes ))
694726@given (data = st .data ())
695727def test_abs (ctx , data ):
@@ -789,10 +821,14 @@ def test_atan(x):
789821@given (* hh .two_mutual_arrays (dh .real_float_dtypes ))
790822def test_atan2 (x1 , x2 ):
791823 out = xp .atan2 (x1 , x2 )
792- ph .assert_dtype ("atan2" , in_dtype = [x1 .dtype , x2 .dtype ], out_dtype = out .dtype )
793- ph .assert_result_shape ("atan2" , in_shapes = [x1 .shape , x2 .shape ], out_shape = out .shape )
794- refimpl = cmath .atan2 if x1 .dtype in dh .complex_dtypes else math .atan2
795- binary_assert_against_refimpl ("atan2" , x1 , x2 , out , refimpl )
824+ _assert_correctness_binary (
825+ "atan" ,
826+ cmath .atan2 if x1 .dtype in dh .complex_dtypes else math .atan2 ,
827+ in_dtypes = [x1 .dtype , x2 .dtype ],
828+ in_shapes = [x1 .shape , x2 .shape ],
829+ in_arrs = [x1 , x2 ],
830+ out = out ,
831+ )
796832
797833
798834@given (hh .arrays (dtype = hh .all_floating_dtypes (), shape = hh .shapes ()))
@@ -1258,10 +1294,14 @@ def test_greater_equal(ctx, data):
12581294@given (* hh .two_mutual_arrays (dh .real_float_dtypes ))
12591295def test_hypot (x1 , x2 ):
12601296 out = xp .hypot (x1 , x2 )
1261- ph .assert_dtype ("hypot" , in_dtype = [x1 .dtype , x2 .dtype ], out_dtype = out .dtype )
1262- ph .assert_result_shape ("hypot" , in_shapes = [x1 .shape , x2 .shape ], out_shape = out .shape )
1263- binary_assert_against_refimpl ("hypot" , x1 , x2 , out , math .hypot )
1264-
1297+ _assert_correctness_binary (
1298+ "hypot" ,
1299+ math .hypot ,
1300+ in_dtypes = [x1 .dtype , x2 .dtype ],
1301+ in_shapes = [x1 .shape , x2 .shape ],
1302+ in_arrs = [x1 , x2 ],
1303+ out = out
1304+ )
12651305
12661306
12671307@pytest .mark .min_version ("2022.12" )
@@ -1411,21 +1451,17 @@ def logaddexp_refimpl(l: float, r: float) -> float:
14111451 raise OverflowError
14121452
14131453
1454+ @pytest .mark .min_version ("2023.12" )
14141455@given (* hh .two_mutual_arrays (dh .real_float_dtypes ))
14151456def test_logaddexp (x1 , x2 ):
14161457 out = xp .logaddexp (x1 , x2 )
1417- ph .assert_dtype ("logaddexp" , in_dtype = [x1 .dtype , x2 .dtype ], out_dtype = out .dtype )
1418- ph .assert_result_shape ("logaddexp" , in_shapes = [x1 .shape , x2 .shape ], out_shape = out .shape )
1419- binary_assert_against_refimpl ("logaddexp" , x1 , x2 , out , logaddexp_refimpl )
1420-
1421-
1422- @given (* hh .two_mutual_arrays ([xp .bool ]))
1423- def test_logical_and (x1 , x2 ):
1424- out = xp .logical_and (x1 , x2 )
1425- ph .assert_dtype ("logical_and" , in_dtype = [x1 .dtype , x2 .dtype ], out_dtype = out .dtype )
1426- ph .assert_result_shape ("logical_and" , in_shapes = [x1 .shape , x2 .shape ], out_shape = out .shape )
1427- binary_assert_against_refimpl (
1428- "logical_and" , x1 , x2 , out , operator .and_ , expr_template = "({} and {})={}"
1458+ _assert_correctness_binary (
1459+ "logaddexp" ,
1460+ logaddexp_refimpl ,
1461+ in_dtypes = [x1 .dtype , x2 .dtype ],
1462+ in_shapes = [x1 .shape , x2 .shape ],
1463+ in_arrs = [x1 , x2 ],
1464+ out = out
14291465 )
14301466
14311467
@@ -1439,42 +1475,64 @@ def test_logical_not(x):
14391475 )
14401476
14411477
1478+ @given (* hh .two_mutual_arrays ([xp .bool ]))
1479+ def test_logical_and (x1 , x2 ):
1480+ out = xp .logical_and (x1 , x2 )
1481+ _assert_correctness_binary (
1482+ "logical_and" ,
1483+ operator .and_ ,
1484+ in_dtypes = [x1 .dtype , x2 .dtype ],
1485+ in_shapes = [x1 .shape , x2 .shape ],
1486+ in_arrs = [x1 , x2 ],
1487+ out = out ,
1488+ expr_template = "({} and {})={}"
1489+ )
1490+
1491+
14421492@given (* hh .two_mutual_arrays ([xp .bool ]))
14431493def test_logical_or (x1 , x2 ):
14441494 out = xp .logical_or (x1 , x2 )
1445- ph .assert_dtype ("logical_or" , in_dtype = [x1 .dtype , x2 .dtype ], out_dtype = out .dtype )
1446- ph .assert_result_shape ("logical_or" , in_shapes = [x1 .shape , x2 .shape ], out_shape = out .shape )
1447- binary_assert_against_refimpl (
1448- "logical_or" , x1 , x2 , out , operator .or_ , expr_template = "({} or {})={}"
1495+ _assert_correctness_binary (
1496+ "logical_or" ,
1497+ operator .or_ ,
1498+ in_dtypes = [x1 .dtype , x2 .dtype ],
1499+ in_shapes = [x1 .shape , x2 .shape ],
1500+ in_arrs = [x1 , x2 ],
1501+ out = out ,
1502+ expr_template = "({} or {})={}"
14491503 )
14501504
14511505
14521506@given (* hh .two_mutual_arrays ([xp .bool ]))
14531507def test_logical_xor (x1 , x2 ):
14541508 out = xp .logical_xor (x1 , x2 )
1455- ph .assert_dtype ("logical_xor" , in_dtype = [x1 .dtype , x2 .dtype ], out_dtype = out .dtype )
1456- ph .assert_result_shape ("logical_xor" , in_shapes = [x1 .shape , x2 .shape ], out_shape = out .shape )
1457- binary_assert_against_refimpl (
1458- "logical_xor" , x1 , x2 , out , operator .xor , expr_template = "({} ^ {})={}"
1509+ _assert_correctness_binary (
1510+ "logical_xor" ,
1511+ operator .xor ,
1512+ in_dtypes = [x1 .dtype , x2 .dtype ],
1513+ in_shapes = [x1 .shape , x2 .shape ],
1514+ in_arrs = [x1 , x2 ],
1515+ out = out ,
1516+ expr_template = "({} ^ {})={}"
14591517 )
14601518
14611519
14621520@pytest .mark .min_version ("2023.12" )
14631521@given (* hh .two_mutual_arrays (dh .real_float_dtypes ))
14641522def test_maximum (x1 , x2 ):
14651523 out = xp .maximum (x1 , x2 )
1466- ph . assert_dtype ( "maximum" , in_dtype = [ x1 . dtype , x2 . dtype ], out_dtype = out . dtype )
1467- ph . assert_result_shape ( "maximum" , in_shapes = [x1 .shape , x2 .shape ], out_shape = out .shape )
1468- binary_assert_against_refimpl ( "maximum" , x1 , x2 , out , max , strict_check = True )
1524+ _assert_correctness_binary (
1525+ "maximum" , max , [x1 .dtype , x2 .dtype ], [ x1 .shape , x2 . shape ], ( x1 , x2 ), out , strict_check = True
1526+ )
14691527
14701528
14711529@pytest .mark .min_version ("2023.12" )
14721530@given (* hh .two_mutual_arrays (dh .real_float_dtypes ))
14731531def test_minimum (x1 , x2 ):
14741532 out = xp .minimum (x1 , x2 )
1475- ph . assert_dtype ( "minimum" , in_dtype = [ x1 . dtype , x2 . dtype ], out_dtype = out . dtype )
1476- ph . assert_result_shape ( "minimum" , in_shapes = [x1 .shape , x2 .shape ], out_shape = out .shape )
1477- binary_assert_against_refimpl ( "minimum" , x1 , x2 , out , min , strict_check = True )
1533+ _assert_correctness_binary (
1534+ "minimum" , min , [x1 .dtype , x2 .dtype ], [ x1 .shape , x2 . shape ], ( x1 , x2 ), out , strict_check = True
1535+ )
14781536
14791537
14801538@pytest .mark .parametrize ("ctx" , make_binary_params ("multiply" , dh .numeric_dtypes ))
@@ -1719,3 +1777,45 @@ def test_trunc(x):
17191777 ph .assert_dtype ("trunc" , in_dtype = x .dtype , out_dtype = out .dtype )
17201778 ph .assert_shape ("trunc" , out_shape = out .shape , expected = x .shape )
17211779 unary_assert_against_refimpl ("trunc" , x , out , math .trunc , strict_check = True )
1780+
1781+
1782+ def _check_binary_with_scalars (func_data , x1x2 ):
1783+ x1 , x2 = x1x2
1784+ func , name , refimpl , kwds = func_data
1785+ out = func (x1 , x2 )
1786+ in_dtypes , in_shapes , (x1a , x2a ) = _convert_scalars_helper (x1 , x2 )
1787+ _assert_correctness_binary (
1788+ name , refimpl , in_dtypes , in_shapes , (x1a , x2a ), out , ** kwds
1789+ )
1790+
1791+
1792+ @pytest .mark .min_version ("2024.12" )
1793+ @pytest .mark .parametrize ('func_data' ,
1794+ # xp_func, name, refimpl, kwargs
1795+ [
1796+ (xp .atan2 , "atan2" , math .atan2 , {}),
1797+ (xp .hypot , "hypot" , math .hypot , {}),
1798+ (xp .logaddexp , "logaddexp" , logaddexp_refimpl , {}),
1799+ (xp .maximum , "maximum" , max , {'strict_check' : True }),
1800+ (xp .minimum , "minimum" , min , {'strict_check' : True }),
1801+ ],
1802+ ids = lambda func_data : func_data [1 ] # use names for test IDs
1803+ )
1804+ @given (x1x2 = hh .array_and_py_scalar (dh .real_float_dtypes ))
1805+ def test_binary_with_scalars_real (func_data , x1x2 ):
1806+ _check_binary_with_scalars (func_data , x1x2 )
1807+
1808+
1809+ @pytest .mark .min_version ("2024.12" )
1810+ @pytest .mark .parametrize ('func_data' ,
1811+ # xp_func, name, refimpl, kwargs
1812+ [
1813+ (xp .logical_and , "logical_and" , operator .and_ , {"expr_template" : "({} or {})={}" }),
1814+ (xp .logical_or , "logical_or" , operator .or_ , {"expr_template" : "({} or {})={}" }),
1815+ (xp .logical_xor , "logical_xor" , operator .xor , {"expr_template" : "({} or {})={}" }),
1816+ ],
1817+ ids = lambda func_data : func_data [1 ] # use names for test IDs
1818+ )
1819+ @given (x1x2 = hh .array_and_py_scalar ([xp .bool ]))
1820+ def test_binary_with_scalars_bool (func_data , x1x2 ):
1821+ _check_binary_with_scalars (func_data , x1x2 )
0 commit comments