Skip to content

Commit 614a86e

Browse files
authored
fix t-z class on OrangePi (#2228)
1 parent d669913 commit 614a86e

File tree

1 file changed

+11
-5
lines changed

1 file changed

+11
-5
lines changed

mindtorch/_apis/npu.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -893,6 +893,10 @@ def bmm(input, other):
893893
return pyboost.bmm_ext_op(input, other)
894894
return legacy.batch_mat_mul(input, other, False, False)
895895

896+
def topk_legacy(input, k, sorted):
897+
out = legacy.top_k(input, k, sorted)
898+
return out[0], cast(out[1], mindspore.int64)
899+
896900
def topk(input, k, dim, largest, sorted):
897901
if use_pyboost() and not ON_ORANGE_PI:
898902
return pyboost.topk_ext_op(input, k, dim, largest, sorted)
@@ -901,12 +905,12 @@ def topk(input, k, dim, largest, sorted):
901905
input = -input
902906
if dim is None or dim == input.ndim - 1:
903907
if not largest:
904-
res = legacy.top_k(input, k, sorted)
908+
res = topk_legacy(input, k, sorted)
905909
values, indices = -res[0], res[1]
906910
return values, indices
907-
return legacy.top_k(input, k, sorted)
911+
return topk_legacy(input, k, sorted)
908912
input = transpose_view(input, dim, input.ndim - 1)
909-
output = legacy.top_k(input, k, sorted)
913+
output = topk_legacy(input, k, sorted)
910914
values = transpose_view(output[0], dim, input.ndim - 1)
911915
indices = transpose_view(output[1], dim, input.ndim - 1)
912916
if not largest:
@@ -1541,9 +1545,11 @@ def diag(input, diagonal):
15411545
return legacy.diag(input, diagonal)
15421546

15431547
def logsigmoid(input):
1544-
if use_pyboost():
1548+
if use_pyboost() and not ON_ORANGE_PI:
15451549
return pyboost.logsigmoid_op(input)
1546-
return legacy.logsigmoid(input)
1550+
output = sigmoid(input)
1551+
ret = log(output)
1552+
return ret
15471553

15481554
def one_hot(tensor, num_classes):
15491555
if use_pyboost():

0 commit comments

Comments
 (0)