@@ -804,7 +804,7 @@ def index(input, index):
804804def scatter (input , dim , index , src ):
805805 if use_pyboost () and not ON_ORANGE_PI :
806806 return pyboost .scatter_op (input , dim , index , src )
807- return legacy .tensor_scatter_elements (input , index , src , dim , "none" )
807+ return legacy .tensor_scatter_elements (input , index , cast ( src , input . dtype ) , dim , "none" )
808808
809809def tril (input , diagonal = 0 ):
810810 if use_pyboost ():
@@ -858,7 +858,8 @@ def isinf(input):
858858def sort (input , dim , descending , stable ):
859859 if use_pyboost () and not ON_ORANGE_PI :
860860 return pyboost .sort_ext_op (input , dim , descending , stable )
861- return legacy .sort (input , dim , descending )
861+ out = legacy .sort (input , dim , descending )
862+ return out [0 ], cast (out [1 ], mindspore .int64 )
862863
863864def prod (input , axis , keepdims , dtype ):
864865 if use_pyboost ():
@@ -1612,9 +1613,15 @@ def inplace_add(input, other, alpha):
16121613 return legacy .inplace_add (input , other )
16131614
16141615def logsumexp (input , dim , keepdim ):
1615- if use_pyboost ():
1616+ if use_pyboost () and not ON_ORANGE_PI :
16161617 return pyboost .logsumexp_op (input , dim , keepdim )
1617- return legacy .logsumexp (input , dim , keepdim )
1618+ input_max = legacy .reduce_max (input , dim , True )
1619+ input_exp = exp (sub (input , input_max ))
1620+ input_sumexp = sum (input_exp , dim , keepdim , None )
1621+ input_logsumexp = log (input_sumexp )
1622+ if not keepdim :
1623+ input_max = squeeze (input_max , dim )
1624+ return add (input_logsumexp , input_max )
16181625
16191626def ctc_loss (log_probs , targets , input_lengths , target_lengths , blank , reduction , zero_infinity ):
16201627 loss , log_alpha = legacy .ctc_loss_v2 (log_probs , targets , input_lengths , target_lengths , blank , 'none' , zero_infinity )
@@ -1922,9 +1929,11 @@ def linalg_qr(input_x, mode):
19221929
19231930def bernoulli (input , generator ):
19241931 seed , offset = generator ._step (12 )
1925- if use_pyboost ():
1932+ if use_pyboost () and not ON_ORANGE_PI :
19261933 return pyboost .bernoulli_ext_op (input , seed , offset )
1927- return legacy .bernoulli (input , seed , offset )
1934+ uniform = rand_like (input , generator , input .dtype )
1935+ result = cast (less (uniform , input ), input .dtype )
1936+ return result
19281937
19291938def multinomial (input , num_samples , replacement , generator ):
19301939 seed , offset = generator ._step (12 ) # pylint: disable=protected-access
@@ -1998,4 +2007,7 @@ def replication_pad_1d(input, padding):
19982007 return pyboost .reflection_pad_1d_op (input , padding )
19992008
20002009def hardtanh (input , min_val , max_val ):
2001- return pyboost .hardtanh_op (input , min_val , max_val )
2010+ return pyboost .hardtanh_op (input , min_val , max_val )
2011+
2012+ def smooth_l1_loss (input , target , beta = 1.0 , reduction = 'none' ):
2013+ return pyboost .smooth_l1_loss_impl (input , target , beta , reduction )
0 commit comments