Skip to content

Commit 8809765

Browse files
authored
fix e class on OrangePi (#2215)
1 parent b210613 commit 8809765

File tree

3 files changed

+17
-27
lines changed

3 files changed

+17
-27
lines changed

mindtorch/_apis/npu.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1921,3 +1921,6 @@ def logaddexp(input, other):
19211921
exp_val = exp(neg(abs_val))
19221922
y = add(m, log1p(exp_val))
19231923
return y
1924+
1925+
def reflection_pad_1d(input, padding):
    """Reflection-pad a 1-D signal.

    Thin dispatch wrapper: delegates directly to the pyboost primitive op.
    # NOTE(review): padding format presumably matches torch's
    # ``F.pad(..., mode='reflect')`` pair-per-dim convention — confirm.
    """
    padded = pyboost.reflection_pad_1d_op(input, padding)
    return padded

mindtorch/nn/functional.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def softplus(input, beta=1, threshold=20):
5858
def logsigmoid(input):
    """Compute log(sigmoid(input)).

    ``execute`` returns a tuple for this op; only the first element is the
    actual result tensor.
    """
    outputs = execute('logsigmoid', input)
    return outputs[0]
6060

61-
def leaky_relu(input, alpha=0.2, inplace=False):
    """LeakyReLU activation: ``max(0, x) + alpha * min(0, x)``.

    # NOTE(review): ``inplace`` is accepted for torch API compatibility but
    # is currently ignored — the op always produces a new tensor. Confirm
    # whether any caller relies on true in-place semantics.
    """
    result = execute('leaky_relu', input, alpha)
    return result
6363

6464
def prelu(input, weight):

mindtorch/ops/array.py

Lines changed: 13 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -75,37 +75,23 @@ def chunk(input, chunks, dim=0):
7575
# gather
7676
def gather(input, dim, index):
    """torch.gather equivalent.

    On OrangePi boards the native ``gather_d`` kernel is unavailable, so we
    fall back to an advanced-indexing emulation there.
    """
    if not ON_ORANGE_PI:
        return execute("gather_d", input, dim, index)
    return gather_with_index_select(input, dim, index)
8080

81-
def gather_with_index_select(x, dim, index):
    """Emulate ``torch.gather(x, dim, index)`` via advanced indexing.

    Used on OrangePi where the native ``gather_d`` kernel is unavailable.
    For every position ``p`` in ``index``, the output takes the element of
    ``x`` whose coordinate along ``dim`` is ``index[p]`` and whose other
    coordinates equal ``p``.

    Args:
        x: source tensor.
        dim: dimension to gather along; may be negative (torch-style).
        index: integer tensor of gather indices, same rank as ``x``.

    Returns:
        Tensor with the same shape as ``index``.
    """
    # Fix: normalize negative dims so the positional comparison below
    # matches (the previous torch_gather implementation did this; the
    # rewrite dropped it, silently mis-gathering for dim < 0).
    if dim < 0:
        dim += len(x.shape)

    # Coordinate grid over index.shape — one grid tensor per dimension.
    grids = mindtorch.meshgrid(
        *[mindtorch.arange(s) for s in index.shape], indexing='ij'
    )

    # Replace the coordinates of the gathered dimension with `index`;
    # all other dimensions keep their own positions.
    coords = tuple(
        index if axis == dim else grid for axis, grid in enumerate(grids)
    )

    # Advanced indexing extracts the gathered elements in one shot.
    return x[coords]
10995

11096
def gather_nd(input, indices):
    """N-dimensional gather: dispatch straight to the backend kernel."""
    gathered = execute("gather_nd", input, indices)
    return gathered
@@ -1135,6 +1121,7 @@ def strided_slice_update(x, begin, end, strides, updates,
11351121

11361122
# Step 2: 计算目标切片 shape(考虑 shrink_axis_mask)
11371123
target_shape = []
1124+
11381125
for d, (b, e, s) in enumerate(zip(full_begin, full_end, full_strides)):
11391126
if (shrink_axis_mask >> d) & 1:
11401127
continue

0 commit comments

Comments
 (0)