Skip to content

Commit b210613

Browse files
authored
fix d class models on OrangePi (#2214)
1 parent 568e398 commit b210613

File tree

4 files changed

+33
-13
lines changed

4 files changed

+33
-13
lines changed

mindtorch/_apis/npu.py

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,7 @@ def cast(input, dtype):
182182
"""
183183
return legacy.cast(input, dtype)
184184

185-
def sub(input, other, alpha):
185+
def sub(input, other, alpha=1.0):
186186
"""
187187
Subtracts the other tensor from the input tensor.
188188
@@ -271,8 +271,10 @@ def matmul(input, other):
271271
Tensor: The result of the matrix multiplication.
272272
"""
273273
if ON_ORANGE_PI:
274+
dtype = input.dtype
274275
input = cast(input, mindspore.float16)
275276
other = cast(other, mindspore.float16)
277+
return cast(pyboost.matmul_ext_op(input, other), dtype)
276278
if use_pyboost():
277279
return pyboost.matmul_ext_op(input, other)
278280
return legacy.mat_mul(input, other)
@@ -1144,9 +1146,9 @@ def neg(input):
11441146
return legacy.neg(input)
11451147

11461148
def log1p(input):
1147-
if use_pyboost():
1149+
if use_pyboost() and not ON_ORANGE_PI:
11481150
return pyboost.log1p_op(input)
1149-
return legacy.log1p(input)
1151+
return log(add(input, 1))
11501152

11511153
def pow_scalar_tensor(input, scalar):
11521154
if use_pyboost():
@@ -1506,19 +1508,24 @@ def var(input, dim=None, correction=1, keepdim=False):
15061508
return legacy.var(input, dim, correction, keepdim)
15071509

15081510
def linspace(start, end, steps, dtype=None):
1509-
if use_pyboost():
1511+
if use_pyboost() and not ON_ORANGE_PI:
15101512
return pyboost.lin_space_ext_op(start, end, steps, dtype)
1511-
return legacy.lin_space(start, end, steps)
1513+
start = float(start)
1514+
end = float(end)
1515+
return legacy.lin_space(mindspore.Tensor(start), mindspore.Tensor(end), steps)
15121516

15131517
def masked_select(input, mask):
15141518
if use_pyboost():
15151519
return pyboost.masked_select_op(input, mask)
15161520
return legacy.masked_select(input, mask)
15171521

15181522
def glu(input, dim=-1):
1519-
if use_pyboost():
1523+
if use_pyboost() and not ON_ORANGE_PI:
15201524
return pyboost.glu_impl(input, dim)
1521-
return legacy.glu(input, dim)
1525+
a, b = chunk(input, 2, dim)
1526+
gate = sigmoid(b)
1527+
return mul(a, gate)
1528+
15221529

15231530
def scatter_value(input, dim, index, src, reduce='none'):
15241531
if use_pyboost():
@@ -1668,11 +1675,13 @@ def pixel_shuffle(input, upscale_factor):
16681675
return legacy.pixel_shuffle(input, upscale_factor)
16691676

16701677
def view_as_complex(input):
1678+
if ON_ORANGE_PI:
1679+
input = clone(input)
16711680
real_part, imag_part = chunk(input, 2, -1)
16721681
return legacy.complex(squeeze(real_part, -1), squeeze(imag_part, -1))
16731682

16741683
def rms_norm(input, weight, eps=1e-5):
1675-
if use_pyboost():
1684+
if use_pyboost() and not ON_ORANGE_PI:
16761685
return pyboost.rms_norm_impl(input, weight, eps)[0]
16771686
input_dtype = input.dtype
16781687
input = cast(input, mindspore.float32)
@@ -1904,4 +1913,11 @@ def tensor_scatter_update(input, indices, updates):
19041913
return legacy.tensor_scatter_update(input, indices, updates)
19051914

19061915
def lerp(input, end, weight):
1907-
return legacy.lerp(input, end, weight)
1916+
return legacy.lerp(input, end, weight)
1917+
1918+
def logaddexp(input, other):
1919+
m = maximum(input, other)
1920+
abs_val = abs(sub(input, other))
1921+
exp_val = exp(neg(abs_val))
1922+
y = add(m, log1p(exp_val))
1923+
return y

mindtorch/ops/creation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ def ones_like(input, *, dtype=None, layout=None, device=None, requires_grad=Fals
103103
# arange
104104
def arange(start=0, end=None, step=1, *, out=None, dtype=None, layout=None, device=None, requires_grad=False):
105105
if end is None:
106-
start, end = 0, start
106+
start, end = 0, int(start)
107107
if dtype is None:
108108
dtype = mindtorch.py2dtype[type(start)]
109109
if device is None:

mindtorch/ops/other.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ def bucketize(input, boundaries, *, out_int32=False, right=False, out=None):
7979
def cdist(x1, x2, p=2.0, compute_mode="use_mm_for_euclid_dist_if_necessary"):
8080
return execute('cdist', x1, x2, p)
8181

82+
8283
# clone
8384
def clone(input, *, memory_format=mindtorch.preserve_format):
8485
return execute('clone', input)

mindtorch/ops/pointwise.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -391,9 +391,12 @@ def mul(input, other):
391391
other = other.to(device)
392392
input = input.to(device)
393393
if not isinstance(other, numbers.Number) and other.dtype != input.dtype:
394-
dtype = min([input.dtype, other.dtype])
395-
other = other.to(dtype)
396-
input = input.to(dtype)
394+
if other.dtype == mindtorch.bool:
395+
other = other.to(input.dtype)
396+
else:
397+
dtype = min([input.dtype, other.dtype])
398+
other = other.to(dtype)
399+
input = input.to(dtype)
397400
# and isinstance(input, torch.Tensor):
398401
# return execute("muls", input, other)
399402
return execute("mul", input, other)

0 commit comments

Comments
 (0)