@@ -75,37 +75,23 @@ def chunk(input, chunks, dim=0):
# gather
def gather(input, dim, index):
    """Gather values of `input` along dimension `dim` at positions `index`.

    On Orange Pi hardware the native ``gather_d`` kernel is not available,
    so an advanced-indexing emulation is used instead.
    """
    if not ON_ORANGE_PI:
        return execute("gather_d", input, dim, index)
    return gather_with_index_select(input, dim, index)

def gather_with_index_select(x, dim, index):
    """Emulate ``torch.gather`` with advanced indexing.

    Builds a coordinate grid covering every position of ``index``,
    substitutes ``index`` for the coordinates of dimension ``dim``, and
    reads all elements from ``x`` in a single advanced-indexing operation.

    Args:
        x: source tensor to gather from.
        dim: dimension along which to gather; negative values count from
            the last dimension, matching ``torch.gather`` (the previous
            ``torch_gather`` implementation supported this and the rewrite
            dropped it — restored here).
        index: integer tensor of gather positions; the result has its shape.

    Returns:
        Tensor of shape ``index.shape`` whose element at position ``p`` is
        ``x`` indexed by ``p`` everywhere except at ``dim``, where
        ``index[p]`` is used.
    """
    # Normalize negative dims so the positional comparison below matches;
    # without this, dim=-1 never equals any enumerate() position and the
    # target dimension is silently left un-substituted.
    if dim < 0:
        dim += len(x.shape)

    # One coordinate grid per dimension, each shaped like `index`.
    idx = mindtorch.meshgrid(*[mindtorch.arange(s) for s in index.shape], indexing='ij')

    # Replace the target dimension's coordinates with the caller's index.
    new_idx = ()
    for ix, i in enumerate(idx):
        if ix == dim:
            new_idx += (index,)
        else:
            new_idx += (i,)

    # Advanced indexing extracts every element in one operation.
    return x[new_idx]
10995
def gather_nd(input, indices):
    """N-dimensional gather: delegate to the backend "gather_nd" kernel."""
    return execute("gather_nd", input, indices)
@@ -1135,6 +1121,7 @@ def strided_slice_update(x, begin, end, strides, updates,
11351121
11361122 # Step 2: 计算目标切片 shape(考虑 shrink_axis_mask)
11371123 target_shape = []
1124+
11381125 for d , (b , e , s ) in enumerate (zip (full_begin , full_end , full_strides )):
11391126 if (shrink_axis_mask >> d ) & 1 :
11401127 continue
0 commit comments