Skip to content

Commit b5e629b

Browse files
authored
optimize dispatch performance and add environment to disable dispatch (#2241)
1 parent 586fdfb commit b5e629b

39 files changed

+8673
-1934
lines changed

.github/workflows/ci_pipeline.yaml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ jobs:
3535
- name: Install dependencies
3636
run: |
3737
python -m pip install --upgrade pip==24.0
38+
pip install torch --index-url https://download.pytorch.org/whl/cpu
3839
pip install -r requirements/pylint_requirements.txt
3940
# - name: Install MindSpore
4041
# shell: bash
@@ -122,6 +123,7 @@ jobs:
122123
- name: Install dependencies
123124
run: |
124125
python -m pip install --upgrade pip==24.0
126+
pip install torch --index-url https://download.pytorch.org/whl/cpu
125127
pip install -r requirements/requirements.txt
126128
- name: Install MindSpore
127129
shell: bash
@@ -132,7 +134,7 @@ jobs:
132134
pip install mindspore
133135
- name: Test with pytest
134136
run: |
135-
pip install transformers==4.56.2
137+
pip install transformers==4.57.1
136138
cd tests
137139
git clone -b v4.56.2 https://github.com/huggingface/transformers
138140
cd ..

mindnlp/quant/smooth_quant/quant.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def infer_dtype(self, x_dtype, v_dtype, bias_dtype=None):
3030

3131

3232
def quantize_mat(mat: Tensor) -> Tuple[Tensor, Tensor]:
33-
max_val = (ops.max(ops.abs(mat), dim=-1)[0] / 127.0).to(dtype=mat.dtype)
33+
max_val = (ops.max(ops.abs(mat), -1)[0] / 127.0).to(dtype=mat.dtype)
3434
mat = (mat / max_val[..., None]).to(dtype=mindspore.int8)
3535
return mat, max_val
3636

@@ -53,7 +53,7 @@ def decomposition(mat: Tensor, unq_idx: Tensor, t: Tensor):
5353

5454

5555
def get_unq_idx_topk(mat: Tensor, k: int = 64):
56-
idx = ops.topk(ops.max(mat.view(-1, mat.shape[-1]).abs(), dim=-2)[0], k, dim=-1)[1]
56+
idx = ops.topk(ops.max(mat.view(-1, mat.shape[-1]).abs(), -2)[0], k, dim=-1)[1]
5757
t = ops.ones((mat.shape[-1]), dtype=mat.dtype)
5858
t = t.copy()
5959
if ON_ORANGE_PI:
@@ -64,7 +64,7 @@ def get_unq_idx_topk(mat: Tensor, k: int = 64):
6464

6565

6666
def get_unq_idx_thres(mat: Tensor, threshold: float = 6.0):
67-
k = ops.max(mat.view(-1, mat.shape[-1]).abs(), dim=-2)[0] >= threshold
67+
k = ops.max(mat.view(-1, mat.shape[-1]).abs(), -2)[0] >= threshold
6868
return ops.nonzero(k).view(-1), k
6969

7070

@@ -113,7 +113,7 @@ def __init__(
113113
self.scales = None
114114
if act_max is not None:
115115
self.scales = (
116-
(act_max.pow(alpha) / ops.max(ori_w.abs(), dim=0)[0].pow(1 - alpha))
116+
(act_max.pow(alpha) / ops.max(ori_w.abs(), 0)[0].pow(1 - alpha))
117117
.clamp(min=1e-5)
118118
.to(dtype=ori_w.dtype)
119119
)

mindnlp/transformers/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from .masking_utils import create_causal_mask, create_sliding_window_causal_mask, create_masks_for_generate
1111
from .modeling_utils import construct_pipeline_parallel_model, _load_pretrained_model_wrapper, \
1212
_get_resolved_checkpoint_files_wrapper
13+
from .cache_utils import dynamic_layer_update
1314
from .tokenization_utils import apply_chat_template_wrapper
1415
from .trainer import training_step
1516
from ..utils.decorators import dtype_wrapper, patch_dtype_wrapper, patch_wrappers
@@ -68,5 +69,6 @@ def empty_fn(*args, **kwargs):
6869

6970
transformers.trainer.Trainer.training_step = training_step
7071

72+
transformers.cache_utils.DynamicLayer.update = dynamic_layer_update
7173
# add mindnlp.transformers modules/attrs to lazymodule
7274
# setattr(sys.modules[__name__], 'test_ms_model', test_ms_model)
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
from typing import Any, Optional
2+
import mindtorch
3+
4+
def dynamic_layer_update(
5+
self,
6+
key_states: mindtorch.Tensor,
7+
value_states: mindtorch.Tensor,
8+
cache_kwargs: Optional[dict[str, Any]] = None,
9+
) -> tuple[mindtorch.Tensor, mindtorch.Tensor]:
10+
"""
11+
Update the key and value caches in-place, and return the necessary keys and value states.
12+
13+
Args:
14+
key_states (`mindtorch.Tensor`): The new key states to cache.
15+
value_states (`mindtorch.Tensor`): The new value states to cache.
16+
cache_kwargs (`dict[str, Any]`, *optional*): Additional arguments for the cache.
17+
18+
Returns:
19+
tuple[`mindtorch.Tensor`, `mindtorch.Tensor`]: The key and value states.
20+
"""
21+
# Lazy initialization
22+
if not self.is_initialized:
23+
self.lazy_initialization(key_states)
24+
self.keys = key_states
25+
self.values = value_states
26+
else:
27+
self.keys = mindtorch.cat([self.keys, key_states], dim=-2)
28+
self.values = mindtorch.cat([self.values, value_states], dim=-2)
29+
return self.keys, self.values

mindtorch/_C/__init__.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -91,11 +91,6 @@ def __eq__(self, __value):
9191
def __hash__(self):
9292
return hash(self.type) ^ hash(self.index)
9393

94-
def __gt__(self, other):
95-
if self.type == 'cpu':
96-
return False
97-
return True
98-
9994
def __enter__(self):
10095
# self.prev_idx = torch.cuda._exchange_device(self.idx)
10196
mindtorch._bind.set_device_in_context(self)
@@ -201,8 +196,6 @@ def _step(self, step):
201196
Current seed and offset.
202197
"""
203198
outs = self._generator(STEP, (self._seed, self._offset, step,))[:2]
204-
for o in outs:
205-
o._device = self.device
206199
return outs
207200

208201
default_generator = Generator()

mindtorch/__init__.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -169,11 +169,10 @@ def _running_with_deploy():
169169

170170
from .amp import autocast, GradScaler
171171
from .func import vmap
172-
from .configs import set_pyboost
173172
from .storage import UntypedStorage, Storage, TypedStorage
174173

175-
from . import _dynamo
176-
from . import profiler, cuda, amp, compiler, jit, version, __future__, overrides, \
174+
from . import _dynamo, library
175+
from . import profiler, cuda, npu, amp, compiler, jit, version, __future__, overrides, \
177176
return_types, linalg, fx, backends, nn, fft, _jit_internal, utils, optim, testing, _ops
178177
from ._lowrank import svd_lowrank
179178
from .random import get_rng_state, initial_seed, manual_seed, seed, set_rng_state

mindtorch/_apis/cpu.py

Lines changed: 78 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from mindspore._c_expression import _empty_instance
77
from mindspore.ops.auto_generate.gen_ops_prim import Empty
88
import mindtorch
9-
from .._op_prim.cpu import legacy
9+
from .._op_prim.cpu import legacy, pyboost
1010

1111
empty_op = Empty().set_device('CPU')
1212
def empty(size, dtype):
@@ -124,22 +124,7 @@ def transpose_view(input, dim0, dim1):
124124
return legacy.transpose(input, tuple(ranks))
125125

126126
def matmul(self, other):
127-
if self.ndim > 2:
128-
if self.ndim == other.ndim:
129-
return legacy.batch_mat_mul(self, other, False, False)
130-
else:
131-
self_shape = self.shape
132-
other_shape = other.shape
133-
if other.ndim == 2:
134-
self = reshape(self, (-1, self_shape[-1]))
135-
out = legacy.mat_mul(self, other, False, False)
136-
return reshape(out, (*self_shape[:-1], out.shape[-1]))
137-
if self.ndim == 2:
138-
other = reshape(other, (-1, other_shape[-1]))
139-
out = legacy.mat_mul(self, other, False, False)
140-
return reshape(out, (*other_shape[:-1], out.shape[-1]))
141-
142-
return legacy.mat_mul(self, other, False, False)
127+
return pyboost.matmul_ext_op(self, other)
143128

144129
def div(input, other):
145130
return legacy.div(input, other)
@@ -592,7 +577,20 @@ def batch_norm(input, weight, bias, running_mean=None, runnning_var=None, traini
592577
def tanh(input):
593578
return legacy.tanh(input)
594579

595-
def dropout(input, p, seed, offset):
580+
def dropout(input, p, training=True):
581+
"""
582+
Returns a tensor with dropout applied element-wise.
583+
584+
Args:
585+
input (Tensor): The input tensor.
586+
p (float): The dropout probability.
587+
seed (int): The random seed.
588+
589+
Returns:
590+
Tensor: The tensor with dropout applied.
591+
"""
592+
if not training or p==0:
593+
return input
596594
return legacy.dropout(input, 1-p, 0, 0)
597595

598596
def split_tensor(input, split_size_or_sections, dim):
@@ -1259,3 +1257,65 @@ def lerp(input, end, weight):
12591257

12601258
def smooth_l1_loss(input, target, beta=1.0, reduction='none'):
12611259
return legacy.smooth_l1_loss(input, target, beta, reduction)
1260+
1261+
def index_select(input, dim, index):
1262+
return legacy.gather(input, index, dim, 0)
1263+
1264+
def custom_circular_pad(x, pad):
1265+
1266+
ndim = x.ndim
1267+
n_pad_dims = len(pad) // 2
1268+
assert n_pad_dims <= ndim, "填充参数超过了张量的维度"
1269+
1270+
# 按从最后维度向前处理填充
1271+
for dim in range(ndim-1, ndim-1-n_pad_dims, -1):
1272+
# 当前维度的左右填充量
1273+
idx = 2 * (ndim - 1 - dim) # 在pad元组中的起始位置
1274+
left_pad = pad[idx]
1275+
right_pad = pad[idx + 1]
1276+
1277+
if left_pad == 0 and right_pad == 0:
1278+
continue # 跳过该维度
1279+
1280+
size = x.shape[dim] # 当前维度的原始长度
1281+
new_size = left_pad + size + right_pad
1282+
1283+
# 生成循环索引: (index - left_pad) mod size
1284+
index = fmod_scalar(add(arange(0, new_size, 1, mindspore.int64), new_size - left_pad), size)
1285+
index = (index + x.shape[dim]) % x.shape[dim]
1286+
x = index_select(x, dim, index)
1287+
1288+
return x
1289+
1290+
def pad(input, pad, mode='constant', value=None):
1291+
if isinstance(pad, tuple):
1292+
pad = tuple(p if isinstance(p, int) else p.item() for p in pad)
1293+
1294+
new_pad = ()
1295+
for idx, pad_v in enumerate(pad):
1296+
if not isinstance(pad_v, int):
1297+
pad_v = pad_v.item()
1298+
if pad_v < 0:
1299+
dim = input.ndim - 1 - idx // 2
1300+
input = narrow(input, dim, 0, input.shape[dim] + pad_v)
1301+
pad_v = 0
1302+
new_pad += (pad_v,)
1303+
if sum(new_pad) == 0:
1304+
return input
1305+
if mode == 'circular':
1306+
return custom_circular_pad(input, pad)
1307+
elif mode == 'reflect':
1308+
return pad_v3(input, new_pad, mode)
1309+
if value is None:
1310+
value = 0
1311+
if mode == "replicate":
1312+
mode = "edge"
1313+
return pad_v3(input, new_pad, mode)
1314+
if input.dtype.is_floating_point:
1315+
value = float(value)
1316+
elif input.dtype == mindtorch.bool:
1317+
value = bool(value)
1318+
elif input.dtype in [mindtorch.int32, mindtorch.int64]:
1319+
value = int(value)
1320+
1321+
return pad_v3(input, new_pad, mode, value)

mindtorch/_apis/gpu.py

Lines changed: 77 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -532,7 +532,20 @@ def batch_norm(input, weight, bias, running_mean=None, runnning_var=None, traini
532532
def tanh(input):
533533
return legacy.tanh(input)
534534

535-
def dropout(input, p, seed, offset):
535+
def dropout(input, p, training=True):
536+
"""
537+
Returns a tensor with dropout applied element-wise.
538+
539+
Args:
540+
input (Tensor): The input tensor.
541+
p (float): The dropout probability.
542+
seed (int): The random seed.
543+
544+
Returns:
545+
Tensor: The tensor with dropout applied.
546+
"""
547+
if not training or p==0:
548+
return input
536549
return legacy.dropout(input, 1-p, 0, 0)
537550

538551
def split_tensor(input, split_size_or_sections, dim):
@@ -1256,3 +1269,66 @@ def cumprod(input, dim, dtype):
12561269

12571270
def lerp(input, end, weight):
12581271
return legacy.lerp(input, end, weight)
1272+
1273+
def custom_circular_pad(x, pad):
1274+
1275+
ndim = x.ndim
1276+
n_pad_dims = len(pad) // 2
1277+
assert n_pad_dims <= ndim, "填充参数超过了张量的维度"
1278+
1279+
# 按从最后维度向前处理填充
1280+
for dim in range(ndim-1, ndim-1-n_pad_dims, -1):
1281+
# 当前维度的左右填充量
1282+
idx = 2 * (ndim - 1 - dim) # 在pad元组中的起始位置
1283+
left_pad = pad[idx]
1284+
right_pad = pad[idx + 1]
1285+
1286+
if left_pad == 0 and right_pad == 0:
1287+
continue # 跳过该维度
1288+
1289+
size = x.shape[dim] # 当前维度的原始长度
1290+
new_size = left_pad + size + right_pad
1291+
1292+
# 生成循环索引: (index - left_pad) mod size
1293+
index = fmod_scalar(add(arange(0, new_size, 1, mindspore.int64), new_size - left_pad), size)
1294+
index = (index + x.shape[dim]) % x.shape[dim]
1295+
x = index_select(x, dim, index)
1296+
1297+
return x
1298+
1299+
def pad(input, pad, mode='constant', value=None):
1300+
if isinstance(pad, tuple):
1301+
pad = tuple(p if isinstance(p, int) else p.item() for p in pad)
1302+
1303+
new_pad = ()
1304+
for idx, pad_v in enumerate(pad):
1305+
if not isinstance(pad_v, int):
1306+
pad_v = pad_v.item()
1307+
if pad_v < 0:
1308+
dim = input.ndim - 1 - idx // 2
1309+
input = narrow(input, dim, 0, input.shape[dim] + pad_v)
1310+
pad_v = 0
1311+
new_pad += (pad_v,)
1312+
if sum(new_pad) == 0:
1313+
return input
1314+
if mode == 'circular':
1315+
return custom_circular_pad(input, pad)
1316+
elif mode == 'reflect':
1317+
return pad_v3(input, new_pad, mode)
1318+
if value is None:
1319+
value = 0
1320+
if mode == "replicate":
1321+
mode = "edge"
1322+
return pad_v3(input, new_pad, mode)
1323+
if input.dtype.is_floating_point:
1324+
value = float(value)
1325+
elif input.dtype == mindtorch.bool:
1326+
value = bool(value)
1327+
elif input.dtype in [mindtorch.int32, mindtorch.int64]:
1328+
value = int(value)
1329+
if mode == 'constant' and value == 0 and len(new_pad) > 6:
1330+
paddings = ()
1331+
for i in range(input.ndim-1, -1, -1):
1332+
paddings += ((new_pad[2*i], new_pad[2*i+1]),)
1333+
return pad(input, paddings)
1334+
return pad_v3(input, new_pad, mode, value)

0 commit comments

Comments
 (0)