@@ -1901,3 +1901,33 @@ def test_binary_with_scalars_bitwise(func_data, x1x2):
19011901 refimpl_ = lambda l , r : mock_int_dtype (refimpl (l , r ), xp .int32 )
19021902 _check_binary_with_scalars ((func_name , refimpl_ , kwargs , expected ), x1x2 )
19031903
1904+
1905+ @pytest .mark .unvectorized
1906+ @given (
1907+ x1x2 = hh .array_and_py_scalar ([xp .int32 ]),
1908+ data = st .data ()
1909+ )
1910+ def test_where_with_scalars (x1x2 , data ):
1911+ x1 , x2 = x1x2
1912+
1913+ if dh .is_scalar (x1 ):
1914+ dtype , shape = x2 .dtype , x2 .shape
1915+ x1_arr , x2_arr = xp .broadcast_to (xp .asarray (x1 ), shape ), x2
1916+ else :
1917+ dtype , shape = x1 .dtype , x1 .shape
1918+ x1_arr , x2_arr = x1 , xp .broadcast_to (xp .asarray (x2 ), shape )
1919+
1920+ condition = data .draw (hh .arrays (shape = shape , dtype = xp .bool ))
1921+
1922+ out = xp .where (condition , x1 , x2 )
1923+
1924+ assert out .dtype == dtype , f"where: got { out .dtype = } for { dtype = } , { x1 = } and { x2 = } "
1925+ assert out .shape == shape , f"where: got { out .shape = } for { shape = } , { x1 = } and { x2 = } "
1926+
1927+ # value test
1928+ for idx in sh .ndindex (shape ):
1929+ if condition [idx ]:
1930+ assert out [idx ] == x1_arr [idx ]
1931+ else :
1932+ assert out [idx ] == x2_arr [idx ]
1933+
0 commit comments