@@ -1880,3 +1880,33 @@ def test_binary_with_scalars_bitwise(func_data, x1x2):
18801880 refimpl_ = lambda l , r : mock_int_dtype (refimpl (l , r ), xp .int32 )
18811881 _check_binary_with_scalars ((func_name , refimpl_ , kwargs , expected ), x1x2 )
18821882
1883+
1884+ @pytest .mark .unvectorized
1885+ @given (
1886+ x1x2 = hh .array_and_py_scalar ([xp .int32 ]),
1887+ data = st .data ()
1888+ )
1889+ def test_where_with_scalars (x1x2 , data ):
1890+ x1 , x2 = x1x2
1891+
1892+ if dh .is_scalar (x1 ):
1893+ dtype , shape = x2 .dtype , x2 .shape
1894+ x1_arr , x2_arr = xp .broadcast_to (xp .asarray (x1 ), shape ), x2
1895+ else :
1896+ dtype , shape = x1 .dtype , x1 .shape
1897+ x1_arr , x2_arr = x1 , xp .broadcast_to (xp .asarray (x2 ), shape )
1898+
1899+ condition = data .draw (hh .arrays (shape = shape , dtype = xp .bool ))
1900+
1901+ out = xp .where (condition , x1 , x2 )
1902+
1903+ assert out .dtype == dtype , f"where: got { out .dtype = } for { dtype = } , { x1 = } and { x2 = } "
1904+ assert out .shape == shape , f"where: got { out .shape = } for { shape = } , { x1 = } and { x2 = } "
1905+
1906+ # value test
1907+ for idx in sh .ndindex (shape ):
1908+ if condition [idx ]:
1909+ assert out [idx ] == x1_arr [idx ]
1910+ else :
1911+ assert out [idx ] == x2_arr [idx ]
1912+
0 commit comments