Skip to content

Commit 3a86699

Browse files
authored
fix g class on OrangePi (#2219)
1 parent 52e9ebb commit 3a86699

File tree

2 files changed

+63
-14
lines changed

2 files changed

+63
-14
lines changed

mindtorch/_apis/npu.py

Lines changed: 60 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import mindspore
22
import mindtorch
3+
import numpy as np
34
from mindspore._c_expression import _empty_instance
45
from ..configs import use_pyboost, ON_A1, ON_ORANGE_PI
56
from .._op_prim.ascend import legacy, pyboost
@@ -52,6 +53,8 @@ def select_ext_view(input, dim, index):
5253
Tensor: The selected slice.
5354
"""
5455
if use_pyboost():
56+
if ON_ORANGE_PI:
57+
input = clone(input)
5558
return pyboost.select_ext_view_op(input, dim, index)
5659
else:
5760
return legacy.select_view(input, index, dim)
@@ -83,17 +86,63 @@ def slice(input, dim, start, end, step):
8386
Returns:
8487
Tensor: The sliced tensor.
8588
"""
86-
if use_pyboost():
87-
return pyboost.slice_ext_op(input, dim, start, end, step)
89+
if use_pyboost() and not ON_ORANGE_PI:
90+
return pyboost.slice_ext_view_op(input, dim, start, end, step)
91+
else:
92+
if step == 1:
93+
return pyboost.slice_ext_view_op(input, dim, start, end, step)
94+
# ndim = input.ndim
95+
# begins = [0] * ndim
96+
# ends = [i for i in input.shape]
97+
# strides = [1] * ndim
98+
# begins[dim] = start
99+
# ends[dim] = end
100+
# strides[dim] = step
101+
# print(input.shape)
102+
# print(tuple(begins), tuple(ends), tuple(strides))
103+
# print(legacy.strided_slice(input, tuple(begins), tuple(ends), tuple(strides), 0, 0, 0, 0, 0))
104+
# return legacy.strided_slice(input, tuple(begins), tuple(ends), tuple(strides), 0, 0, 0, 0, 0)
105+
if end < 0:
106+
end = input.shape[dim] + end
107+
108+
# 2. 计算新视图的大小(size)
109+
new_size = list(input.size())
110+
# 新维度上的长度计算考虑了步长
111+
new_size[dim] = (end - start + step - 1) // step # 向上取整计算元素个数
112+
if new_size[dim] <= 0:
113+
raise RuntimeError(f"Calculated size for dimension {dim} is non-positive after slicing.")
114+
115+
# 3. 计算新的步长(stride)和存储偏移量(storage_offset)
116+
old_strides = input.stride()
117+
new_strides = list(old_strides)
118+
# 在目标维度上,新步长 = 原步长 * 步长(step)
119+
new_strides[dim] = old_strides[dim] * step
120+
# 新的存储偏移量 = 原偏移量 + 起始索引 * 目标维度的原步长
121+
new_storage_offset = input.storage_offset() + start * old_strides[dim]
122+
123+
# 4. 使用 as_strided 创建新视图
124+
# 关键:as_strided 通过直接定义新张量的尺寸、步长和存储偏移量来创建一个视图,而不复制数据。
125+
sliced_tensor = as_strided_manual(input, size=tuple(new_size), stride=tuple(new_strides), storage_offset=new_storage_offset)
126+
127+
return sliced_tensor
128+
129+
def as_strided_manual(self, size, stride, storage_offset=None):
130+
if len(size) != len(stride):
131+
raise RuntimeError("mismatch in length of strides and shape.")
132+
index = np.arange(0, size[0]*stride[0], stride[0])
133+
for i in np.arange(1, len(size)):
134+
tmp = np.arange(0, size[i]*stride[i], stride[i])
135+
index = np.expand_dims(index, -1)
136+
index = index + tmp
137+
if storage_offset is not None:
138+
index = index + storage_offset
139+
140+
if index.size == 0:
141+
input_indices = mindspore.Tensor(Tensor_(index.shape, dtype=mindspore.int32))
88142
else:
89-
ndim = input.ndim
90-
begins = [0] * ndim
91-
ends = [i for i in input.shape]
92-
strides = [1] * ndim
93-
begins[dim] = start
94-
ends[dim] = end
95-
strides[dim] = step
96-
return legacy.strided_slice(input, tuple(begins), tuple(ends), tuple(strides), 0, 0, 0, 0, 0)
143+
input_indices = mindspore.tensor(index.astype(np.int32))
144+
out = gather(reshape(self, (-1,)), input_indices, 0, 0)
145+
return out
97146

98147

99148
def embedding(input, weight, padding_idx, max_norm, norm_type, scale_grad_by_freq):
@@ -306,7 +355,7 @@ def divmod(input, other, rounding_mode):
306355
Returns:
307356
Tuple[Tensor, Tensor]: The quotient and the remainder.
308357
"""
309-
if use_pyboost():
358+
if use_pyboost() and not ON_ORANGE_PI:
310359
return pyboost.divmod_op(input, other, rounding_mode)
311360
if rounding_mode == 'floor':
312361
return legacy.floor_div(input, other)

mindtorch/ops/array.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -608,9 +608,9 @@ def _process_multi_dim_index(self, indexes, remain_indexes, indexed_dims):
608608
self_viewed = self
609609
self_viewed_shape = list(self.shape)
610610
dim = 0
611-
if ON_ORANGE_PI:
612-
if all([isinstance(index, slice) for index in indexes]):
613-
return getitem(self_viewed, tuple(indexes)), remain_indexes
611+
# if ON_ORANGE_PI:
612+
# if all([isinstance(index, slice) for index in indexes]):
613+
# return getitem(self_viewed, tuple(indexes)), remain_indexes
614614
for i, index in enumerate(indexes):
615615
if isinstance(index, (list, tuple, np.ndarray)):
616616
index_np = np.array(index) if isinstance(index, (list, tuple)) else index

0 commit comments

Comments
 (0)