66from mindspore ._c_expression import _empty_instance
77from mindspore .ops .auto_generate .gen_ops_prim import Empty
88import mindtorch
9- from .._op_prim .cpu import legacy
9+ from .._op_prim .cpu import legacy , pyboost
1010
# Module-level Empty primitive pinned to the CPU device; constructed once and
# reused by empty() below to avoid re-instantiating the op per call.
empty_op = Empty().set_device('CPU')
1212def empty (size , dtype ):
@@ -124,22 +124,7 @@ def transpose_view(input, dim0, dim1):
124124 return legacy .transpose (input , tuple (ranks ))
125125
def matmul(self, other):
    """Matrix-multiply ``self`` by ``other``.

    NOTE: despite the name of the first parameter this is a free function;
    ``self`` is simply the left-hand operand. Delegates entirely to the
    fused ``pyboost.matmul_ext_op`` kernel, which replaced the previous
    hand-rolled batch/reshape dispatch — broadcasting and batching
    semantics are whatever that kernel implements (not visible from here).

    Args:
        self: left-hand operand tensor.
        other: right-hand operand tensor.

    Returns:
        The matrix product produced by ``pyboost.matmul_ext_op``.
    """
    return pyboost.matmul_ext_op(self, other)
143128
def div(input, other):
    """Element-wise division of ``input`` by ``other``.

    Thin wrapper that forwards both operands to the legacy CPU kernel.
    """
    quotient = legacy.div(input, other)
    return quotient
@@ -592,7 +577,20 @@ def batch_norm(input, weight, bias, running_mean=None, runnning_var=None, traini
def tanh(input):
    """Hyperbolic tangent, applied element-wise.

    Delegates directly to the legacy CPU implementation.
    """
    result = legacy.tanh(input)
    return result
594579
def dropout(input, p, training=True):
    """Apply element-wise dropout to ``input``.

    Args:
        input (Tensor): The input tensor.
        p (float): Probability of zeroing an element; must lie in ``[0, 1]``.
        training (bool): When ``False`` dropout is disabled and ``input`` is
            returned unchanged. Default ``True``.

    Returns:
        Tensor: ``input`` itself when not training or ``p == 0``; otherwise
        the output of the legacy dropout kernel.

    Raises:
        ValueError: If ``p`` is outside ``[0, 1]``.
    """
    if p < 0.0 or p > 1.0:
        raise ValueError(f"dropout probability has to be between 0 and 1, but got {p}")
    # Fast path: dropout is the identity outside of training or at p == 0.
    if not training or p == 0:
        return input
    # legacy.dropout takes keep_prob (1 - p); the trailing 0, 0 are the
    # seed/offset arguments the old signature exposed, now fixed.
    return legacy.dropout(input, 1 - p, 0, 0)
597595
598596def split_tensor (input , split_size_or_sections , dim ):
@@ -1259,3 +1257,65 @@ def lerp(input, end, weight):
12591257
def smooth_l1_loss(input, target, beta=1.0, reduction='none'):
    """Smooth L1 (Huber-style) loss between ``input`` and ``target``.

    Forwards all arguments unchanged to the legacy CPU kernel; ``beta``
    and ``reduction`` keep their caller-facing defaults.
    """
    return legacy.smooth_l1_loss(
        input,
        target,
        beta,
        reduction,
    )
1260+
def index_select(input, dim, index):
    """Gather slices of ``input`` along ``dim`` at the positions in ``index``.

    Implemented as a legacy gather along axis ``dim`` with batch_dims=0.
    """
    axis = dim
    batch_dims = 0
    return legacy.gather(input, index, axis, batch_dims)
1263+
def custom_circular_pad(x, pad):
    """Circularly (wrap-around) pad ``x``.

    Args:
        x (Tensor): input tensor.
        pad (tuple[int]): torch-style pad spec ``(left_last, right_last,
            left_2ndlast, right_2ndlast, ...)`` — (left, right) pairs ordered
            from the last dimension backwards.

    Returns:
        Tensor: the circularly padded tensor.
    """
    ndim = x.ndim
    n_pad_dims = len(pad) // 2
    assert n_pad_dims <= ndim, "填充参数超过了张量的维度"

    # Process padded dimensions from the last one backwards.
    for dim in range(ndim - 1, ndim - 1 - n_pad_dims, -1):
        idx = 2 * (ndim - 1 - dim)  # start of this dim's (left, right) pair
        left_pad = pad[idx]
        right_pad = pad[idx + 1]

        if left_pad == 0 and right_pad == 0:
            continue  # this dimension is untouched

        size = x.shape[dim]  # original length of this dimension
        new_size = left_pad + size + right_pad

        # Output position j must take input element (j - left_pad) mod size.
        # Offset by `size - left_pad` — NOT `new_size - left_pad`: new_size is
        # generally not a multiple of `size`, so that offset would shift the
        # residue to (j + right_pad) mod size and select the wrong elements.
        index = fmod_scalar(add(arange(0, new_size, 1, mindspore.int64), size - left_pad), size)
        # Normalize any negative residues fmod may produce into [0, size).
        index = (index + x.shape[dim]) % x.shape[dim]
        x = index_select(x, dim, index)

    return x
1289+
def pad(input, pad, mode='constant', value=None):
    """Pad ``input`` following the torch ``F.pad`` convention.

    Args:
        input (Tensor): tensor to pad.
        pad (tuple): per-dimension (left, right) amounts, last dimension
            first; entries may be 0-d tensors (``.item()`` is taken) and may
            be negative — a negative amount trims that edge via ``narrow``.
        mode (str): 'constant', 'reflect', 'replicate' or 'circular'.
        value (scalar, optional): fill value for 'constant' mode; defaults
            to 0 when ``None``.

    Returns:
        Tensor: the padded (or trimmed) tensor.
    """
    if isinstance(pad, tuple):
        # Normalize any 0-d tensor entries to plain ints.
        pad = tuple(p if isinstance(p, int) else p.item() for p in pad)

    new_pad = ()
    for idx, pad_v in enumerate(pad):
        if not isinstance(pad_v, int):
            pad_v = pad_v.item()
        if pad_v < 0:
            # Negative padding removes elements from the corresponding edge;
            # the remaining pad amount for that edge becomes 0.
            dim = input.ndim - 1 - idx // 2
            input = narrow(input, dim, 0, input.shape[dim] + pad_v)
            pad_v = 0
        new_pad += (pad_v,)
    if sum(new_pad) == 0:
        # All requested padding was non-positive; trimming above is complete.
        return input
    if mode == 'circular':
        # NOTE(review): passes the ORIGINAL `pad` (which may still contain
        # negative entries already applied via narrow above), not `new_pad`
        # — confirm this is intended.
        return custom_circular_pad(input, pad)
    elif mode == 'reflect':
        return pad_v3(input, new_pad, mode)
    if value is None:
        value = 0
    if mode == "replicate":
        # pad_v3 names replicate mode "edge"; that mode takes no fill value.
        mode = "edge"
        return pad_v3(input, new_pad, mode)
    # Constant mode: coerce the fill value to match the input dtype.
    if input.dtype.is_floating_point:
        value = float(value)
    elif input.dtype == mindtorch.bool:
        value = bool(value)
    elif input.dtype in [mindtorch.int32, mindtorch.int64]:
        value = int(value)

    return pad_v3(input, new_pad, mode, value)
0 commit comments