@@ -1394,6 +1394,35 @@ def sample_inputs_scatter_src(op_info, device, dtype, requires_grad, **kwargs):
13941394 src_tensor = make_arg (src_shape )
13951395 yield opinfo_core .SampleInput (self_tensor , args = (dim , index_tensor , src_tensor ))
13961396
1397+ # Additional test cases for scalar and single-element tensor combinations with dim=0
1398+ # Test case: scalar index, scalar src (dim_size=5)
1399+ dim_size = 5
1400+ data_1d = make_arg ((dim_size ,))
1401+ valid_index = torch .randint (0 , dim_size , (), device = device , dtype = torch .long )
1402+ scalar_src = make_arg (())
1403+ yield opinfo_core .SampleInput (data_1d , args = (0 , valid_index , scalar_src ))
1404+
1405+ # Test case: single-element tensor index, scalar src (dim_size=7)
1406+ dim_size = 7
1407+ data_1d = make_arg ((dim_size ,))
1408+ valid_index_1d = torch .randint (0 , dim_size , (1 ,), device = device , dtype = torch .long )
1409+ scalar_src = make_arg (())
1410+ yield opinfo_core .SampleInput (data_1d , args = (0 , valid_index_1d , scalar_src ))
1411+
1412+ # Test case: scalar index, single-element tensor src (dim_size=3)
1413+ dim_size = 3
1414+ data_1d = make_arg ((dim_size ,))
1415+ valid_index = torch .randint (0 , dim_size , (), device = device , dtype = torch .long )
1416+ src_1d = make_arg ((1 ,))
1417+ yield opinfo_core .SampleInput (data_1d , args = (0 , valid_index , src_1d ))
1418+
1419+ # Test case: single-element tensor index, single-element tensor src (dim_size=10)
1420+ dim_size = 10
1421+ data_1d = make_arg ((dim_size ,))
1422+ valid_index_1d = torch .randint (0 , dim_size , (1 ,), device = device , dtype = torch .long )
1423+ src_1d = make_arg ((1 ,))
1424+ yield opinfo_core .SampleInput (data_1d , args = (0 , valid_index_1d , src_1d ))
1425+
13971426
13981427def sample_inputs_scatter_value (op_info , device , dtype , requires_grad , ** kwargs ):
13991428 del op_info
@@ -1423,6 +1452,21 @@ def sample_inputs_scatter_value(op_info, device, dtype, requires_grad, **kwargs)
14231452 ]
14241453 yield opinfo_core .SampleInput (self_tensor , args = (dim , index_tensor , value ))
14251454
1455+ # Additional test cases for scalar and single-element tensor combinations with dim=0
1456+ # Test case: scalar index with scalar value (dim_size=6, value_type=torch.long)
1457+ dim_size = 6
1458+ data_1d = make_arg ((dim_size ,))
1459+ valid_index = torch .randint (0 , dim_size , (), device = device , dtype = torch .long )
1460+ random_value = torch .randint (0 , 10 , (), device = device , dtype = torch .long ).item ()
1461+ yield opinfo_core .SampleInput (data_1d , args = (0 , valid_index , random_value ))
1462+
1463+ # Test case: single-element tensor index with scalar value (dim_size=8, value_type=torch.float)
1464+ dim_size = 8
1465+ data_1d = make_arg ((dim_size ,))
1466+ valid_index_1d = torch .randint (0 , dim_size , (1 ,), device = device , dtype = torch .long )
1467+ random_value = torch .rand ((), device = device , dtype = torch .float ).item ()
1468+ yield opinfo_core .SampleInput (data_1d , args = (0 , valid_index_1d , random_value ))
1469+
14261470
14271471def sample_inputs__scaled_dot_product_flash_attention (
14281472 op_info , device , dtype , requires_grad , ** kwargs
0 commit comments