@@ -182,7 +182,7 @@ def cast(input, dtype):
182182 """
183183 return legacy .cast (input , dtype )
184184
185- def sub (input , other , alpha ):
185+ def sub (input , other , alpha = 1.0 ):
186186 """
187187 Subtracts the other tensor from the input tensor.
188188
@@ -271,8 +271,10 @@ def matmul(input, other):
271271 Tensor: The result of the matrix multiplication.
272272 """
273273 if ON_ORANGE_PI :
274+ dtype = input .dtype
274275 input = cast (input , mindspore .float16 )
275276 other = cast (other , mindspore .float16 )
277+ return cast (pyboost .matmul_ext_op (input , other ), dtype )
276278 if use_pyboost ():
277279 return pyboost .matmul_ext_op (input , other )
278280 return legacy .mat_mul (input , other )
@@ -1144,9 +1146,9 @@ def neg(input):
11441146 return legacy .neg (input )
11451147
11461148def log1p (input ):
1147- if use_pyboost ():
1149+ if use_pyboost () and not ON_ORANGE_PI :
11481150 return pyboost .log1p_op (input )
1149- return legacy . log1p ( input )
1151+ return log ( add ( input , 1 ) )
11501152
11511153def pow_scalar_tensor (input , scalar ):
11521154 if use_pyboost ():
@@ -1506,19 +1508,24 @@ def var(input, dim=None, correction=1, keepdim=False):
15061508 return legacy .var (input , dim , correction , keepdim )
15071509
15081510def linspace (start , end , steps , dtype = None ):
1509- if use_pyboost ():
1511+ if use_pyboost () and not ON_ORANGE_PI :
15101512 return pyboost .lin_space_ext_op (start , end , steps , dtype )
1511- return legacy .lin_space (start , end , steps )
1513+ start = float (start )
1514+ end = float (end )
1515+ return legacy .lin_space (mindspore .Tensor (start ), mindspore .Tensor (end ), steps )
15121516
15131517def masked_select (input , mask ):
15141518 if use_pyboost ():
15151519 return pyboost .masked_select_op (input , mask )
15161520 return legacy .masked_select (input , mask )
15171521
15181522def glu (input , dim = - 1 ):
1519- if use_pyboost ():
1523+ if use_pyboost () and not ON_ORANGE_PI :
15201524 return pyboost .glu_impl (input , dim )
1521- return legacy .glu (input , dim )
1525+ a , b = chunk (input , 2 , dim )
1526+ gate = sigmoid (b )
1527+ return mul (a , gate )
1528+
15221529
15231530def scatter_value (input , dim , index , src , reduce = 'none' ):
15241531 if use_pyboost ():
@@ -1668,11 +1675,13 @@ def pixel_shuffle(input, upscale_factor):
16681675 return legacy .pixel_shuffle (input , upscale_factor )
16691676
16701677def view_as_complex (input ):
1678+ if ON_ORANGE_PI :
1679+ input = clone (input )
16711680 real_part , imag_part = chunk (input , 2 , - 1 )
16721681 return legacy .complex (squeeze (real_part , - 1 ), squeeze (imag_part , - 1 ))
16731682
16741683def rms_norm (input , weight , eps = 1e-5 ):
1675- if use_pyboost ():
1684+ if use_pyboost () and not ON_ORANGE_PI :
16761685 return pyboost .rms_norm_impl (input , weight , eps )[0 ]
16771686 input_dtype = input .dtype
16781687 input = cast (input , mindspore .float32 )
@@ -1904,4 +1913,11 @@ def tensor_scatter_update(input, indices, updates):
19041913 return legacy .tensor_scatter_update (input , indices , updates )
19051914
19061915def lerp (input , end , weight ):
1907- return legacy .lerp (input , end , weight )
1916+ return legacy .lerp (input , end , weight )
1917+
1918+ def logaddexp (input , other ):
1919+ m = maximum (input , other )
1920+ abs_val = abs (sub (input , other ))
1921+ exp_val = exp (neg (abs_val ))
1922+ y = add (m , log1p (exp_val ))
1923+ return y
0 commit comments