
Commit 136ea9f

[refact] unified soc_version code (#4359)
### What this PR does / why we need it?

Currently there are two paths for judging the chip type in code: `get_ascend_soc_version` uses the `get_soc_version` API in torch_npu, while `is_310p` uses `_build_info.__soc_version__`, which is generated at install time. These two paths need to be unified, for the following reasons:

1. Chip-type judgment must be consistent between compile time and runtime.
2. At compile time we need the full chip type to compile the ops, but at runtime we only need the device type (910B/910_93/310P/910_95/etc.) to choose code branches.
3. At compile time torch_npu may not be installed yet, so we cannot rely on torch_npu APIs.

Based on the above points, this PR makes the following changes:

1. If the user sets the env `SOC_VERSION`, use it; otherwise query the SOC version via `npu-smi`.
2. At compile time, derive the device type from the SOC version and write `__device_type__` instead of `__soc_version__` into `_build_info.py`.
3. At runtime, use `__device_type__` for code-branch decisions.

### Does this PR introduce _any_ user-facing change?

When the env `SOC_VERSION` is not set, it no longer defaults to `ASCEND910B1`; the SOC version is queried via `npu-smi` instead. When `SOC_VERSION` is set, it must be one of the keys in the `soc_to_device` mapping in `setup.py`.

- vLLM version: v0.11.0
- vLLM main: vllm-project/vllm@2918c1b

Signed-off-by: zzzzwwjj <1183291235@qq.com>
1 parent a91e76c commit 136ea9f


42 files changed, +361 −243 lines
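To make the build-time / runtime split described in the commit message concrete, here is a minimal sketch of the runtime half. It is not part of this diff: the enum member names and the `get_ascend_device_type()` signature are inferred from how the tests below patch them, and the `UNKNOWN` fallback plus the `is_a3_cluster` helper are hypothetical; the real implementation lives in `vllm_ascend/utils.py` and may differ.

```python
# Illustrative sketch only -- not the code from this commit.
# Assumes: _build_info.py is generated by setup.py with __device_type__, and
# AscendDeviceType / get_ascend_device_type exist in vllm_ascend.utils (as the
# patched tests below suggest). The UNKNOWN fallback is an assumption.
from enum import Enum


class AscendDeviceType(Enum):
    _910B = "_910B"
    _910_93 = "_910_93"
    _310P = "_310P"
    UNKNOWN = "unknown"


def get_ascend_device_type() -> AscendDeviceType:
    """Read the device type baked into _build_info.py at build time."""
    try:
        from vllm_ascend._build_info import __device_type__  # generated by setup.py
    except ImportError:
        return AscendDeviceType.UNKNOWN
    return AscendDeviceType(__device_type__)


def is_a3_cluster() -> bool:
    """Hypothetical helper showing the branch pattern used at the updated call sites."""
    return get_ascend_device_type() == AscendDeviceType._910_93
```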

examples/disaggregated_prefill_v1/gen_ranktable.py

Lines changed: 4 additions & 5 deletions
```diff
@@ -4,7 +4,7 @@

 import torch.distributed as dist

-from vllm_ascend.utils import AscendSocVersion, init_ascend_soc_version, get_ascend_soc_version
+from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type

 parser = argparse.ArgumentParser(
     description="Arguments of rank table generator", )
@@ -42,8 +42,7 @@
 # and is different from WORLD_SIZE in gen_rank_table.sh.
 world_size = os.environ.get("WORLD_SIZE")

-init_ascend_soc_version()
-soc_info = get_ascend_soc_version()
+device_type = get_ascend_device_type()


 def get_cmd_stdout(cmd):
@@ -83,7 +82,7 @@ def get_cmd_stdout(cmd):
         device_id = local_device_ids[idx]
         chip_id = device_id % chips_per_card
         card_id = device_id // chips_per_card
-        if soc_info == AscendSocVersion.A3:
+        if device_type == AscendDeviceType._910_93:
             device_ip = get_cmd_stdout(
                 f"{hccn_tool_path} -i {device_id} -vnic -g | grep ipaddr"
             ).split(":")[1].strip()
@@ -103,7 +102,7 @@ def get_cmd_stdout(cmd):
             "device_id": str(device_id),
             "device_ip": str(device_ip),
         }
-        if soc_info == AscendSocVersion.A3:
+        if device_type == AscendDeviceType._910_93:
            device_info.update({
                "super_pod_id": str(super_pod_id),
                "super_device_id": str(super_device_id)
```
setup.py

Lines changed: 83 additions & 5 deletions
```diff
@@ -65,25 +65,103 @@ def check_or_set_default_env(cmake_args,
     return cmake_args


+def get_value_from_lines(lines: List[str], key: str) -> str:
+    for line in lines:
+        line = ' '.join(line.split())
+        if key in line:
+            return line.split(':')[-1].strip()
+    return ""
+
+
+def get_chip_info() -> str:
+    try:
+        npu_info_lines = subprocess.check_output(
+            ['npu-smi', 'info', '-l']).decode().strip().split('\n')
+        npu_id = int(get_value_from_lines(npu_info_lines, 'NPU ID'))
+        chip_info_lines = subprocess.check_output(
+            ['npu-smi', 'info', '-t', 'board', '-i',
+             str(npu_id), '-c', '0']).decode().strip().split('\n')
+        chip_name = get_value_from_lines(chip_info_lines, 'Chip Name')
+        chip_type = get_value_from_lines(chip_info_lines, 'Chip Type')
+        npu_name = get_value_from_lines(chip_info_lines, 'NPU Name')
+
+        if "310" in chip_name:
+            # 310P case
+            assert chip_type
+            return (chip_type + chip_name).lower()
+        elif "910" in chip_name:
+            if chip_type:
+                # A2 case
+                assert not npu_name
+                return (chip_type + chip_name).lower()
+            else:
+                # A3 case
+                assert npu_name
+                return (chip_name + '_' + npu_name).lower()
+        else:
+            # TODO(zzzzwwjj): Currently, A5's chip name has not determined yet.
+            raise ValueError(
+                f"Unable to recognize chip name: {chip_name}, please manually set env SOC_VERSION"
+            )
+    except subprocess.CalledProcessError as e:
+        raise RuntimeError(f"Get chip info failed: {e}")
+    except FileNotFoundError:
+        # cpu envir, release code case, return `ascend910b1` by default
+        return "ascend910b1"
+
+
 envs = load_module_from_path("envs",
                              os.path.join(ROOT_DIR, "vllm_ascend", "envs.py"))

+soc_version = get_chip_info()
+
+if not envs.SOC_VERSION:
+    envs.SOC_VERSION = soc_version
+else:
+    if envs.SOC_VERSION != soc_version:
+        logging.warning(
+            f"env SOC_VERSION: {envs.SOC_VERSION} is not equal to soc_version from npu-smi: {soc_version}"
+        )
+

 def gen_build_info():
     soc_version = envs.SOC_VERSION
-    if not soc_version:
-        raise ValueError(
-            "SOC version is not set. Please set SOC_VERSION environment variable."
-        )
     if "310" in soc_version and not envs.COMPILE_CUSTOM_KERNELS:
         raise ValueError(
             "SOC version 310 only supports custom kernels. Please set COMPILE_CUSTOM_KERNELS=1 to enable custom kernels."
         )

+    # TODO(zzzzwwjj): Add A5 case
+    soc_to_device = {
+        "ascend910b1": "_910B",
+        "ascend910b2": "_910B",
+        "ascend910b2c": "_910B",
+        "ascend910b3": "_910B",
+        "ascend910b4": "_910B",
+        "ascend910b4-1": "_910B",
+        "ascend910_9391": "_910_93",
+        "ascend910_9381": "_910_93",
+        "ascend910_9372": "_910_93",
+        "ascend910_9392": "_910_93",
+        "ascend910_9382": "_910_93",
+        "ascend910_9362": "_910_93",
+        "ascend310p1": "_310P",
+        "ascend310p3": "_310P",
+        "ascend310p5": "_310P",
+        "ascend310p7": "_310P",
+        "ascend310p3vir01": "_310P",
+        "ascend310p3vir02": "_310P",
+        "ascend310p3vir04": "_310P",
+        "ascend310p3vir08": "_310P",
+    }
+
+    assert soc_version in soc_to_device, f"Undefined soc_version: {soc_version}. Please file an issue to vllm-ascend."
+    device_type = soc_to_device[soc_version]
+
     package_dir = os.path.join(ROOT_DIR, "vllm_ascend", "_build_info.py")
     with open(package_dir, "w+") as f:
         f.write('# Auto-generated file\n')
-        f.write(f"__soc_version__ = '{soc_version}'\n")
+        f.write(f"__device_type__ = '{device_type}'\n")
         f.write(f"__sleep_mode_enabled__ = {envs.COMPILE_CUSTOM_KERNELS}\n")
     logging.info(f"Generated _build_info.py with SOC version: {soc_version}")
```
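As a quick illustration of the `npu-smi` parsing added above, the standalone snippet below runs the same `get_value_from_lines` logic against made-up board output. The sample lines and values are assumptions for demonstration only, not real `npu-smi` output.

```python
# Illustration of the get_value_from_lines() helper from the setup.py hunk above.
# The sample lines below are fabricated; real `npu-smi info -t board` output may differ.
from typing import List


def get_value_from_lines(lines: List[str], key: str) -> str:
    # Same logic as the helper added in setup.py: normalize whitespace,
    # then take the text after the last ':' on the first matching line.
    for line in lines:
        line = ' '.join(line.split())
        if key in line:
            return line.split(':')[-1].strip()
    return ""


sample = [
    "        NPU ID                         : 0",
    "        Chip Name                      : 910B3",
    "        Chip Type                      : Ascend",
]
print(get_value_from_lines(sample, "NPU ID"))     # -> "0"
print(get_value_from_lines(sample, "Chip Name"))  # -> "910B3"
# For an A2-style board this would yield soc_version "ascend910b3",
# which gen_build_info() then maps to device_type "_910B".
```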

tests/ut/attention/test_attention_v1.py

Lines changed: 24 additions & 16 deletions
```diff
@@ -9,6 +9,7 @@
                                                 AscendAttentionMetadataBuilder,
                                                 AscendAttentionState)
 from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
+from vllm_ascend.utils import AscendDeviceType


 class TestAscendAttentionBackend(TestBase):
@@ -24,14 +25,15 @@ def test_get_builder_cls(self):
         self.assertEqual(AscendAttentionBackend.get_builder_cls(),
                          AscendAttentionMetadataBuilder)

-    @patch('vllm_ascend.attention.attention_v1.is_310p')
-    def test_get_kv_cache_shape_310p(self, mock_is_310p):
-        mock_is_310p.return_value = True
+    @patch('vllm_ascend.attention.attention_v1.get_ascend_device_type',
+           return_value=AscendDeviceType._310P)
+    def test_get_kv_cache_shape_310p(self, mock_soc_version):
         result = AscendAttentionBackend.get_kv_cache_shape(10, 20, 30, 40)
         self.assertEqual(result, (2, 10, 30 * 40 // 16, 20, 16))

-    @patch('vllm_ascend.attention.attention_v1.is_310p', return_value=False)
-    def test_get_kv_cache_shape_not_310p(self, mock_is_310p):
+    @patch('vllm_ascend.utils.get_ascend_device_type',
+           return_value=AscendDeviceType._910_93)
+    def test_get_kv_cache_shape_not_310p(self, mock_soc_version):
         result = AscendAttentionBackend.get_kv_cache_shape(10, 20, 30, 40)
         self.assertEqual(result, (2, 10, 20, 30, 40))

@@ -96,8 +98,9 @@ def test_reorder_batch(self):
     @patch('vllm_ascend.attention.attention_v1.AscendMetadata')
     @patch('torch_npu.npu_format_cast')
     @patch('vllm_ascend.utils.nd_to_nz_2d')
-    @patch('vllm_ascend.attention.attention_v1.is_310p', return_value=True)
-    def test_build_prefill_no_cache(self, mock_is_310p, mock_nd_to_nz_2d,
+    @patch('vllm_ascend.utils.get_ascend_device_type',
+           return_value=AscendDeviceType._310P)
+    def test_build_prefill_no_cache(self, mock_soc_version, mock_nd_to_nz_2d,
                                     mock_npu_format_cast,
                                     mock_ascend_metadata):
         common_attn_metadata = AscendCommonAttentionMetadata(
@@ -128,10 +131,11 @@ def test_build_prefill_no_cache(self, mock_is_310p, mock_nd_to_nz_2d,
     @patch('vllm_ascend.attention.attention_v1.AscendMetadata')
     @patch('torch_npu.npu_format_cast')
     @patch('vllm_ascend.utils.nd_to_nz_spec')
-    @patch('vllm_ascend.attention.attention_v1.is_310p', return_value=True)
+    @patch('vllm_ascend.utils.get_ascend_device_type',
+           return_value=AscendDeviceType._310P)
     @patch('vllm_ascend.attention.attention_v1.AscendAttentionState')
     def test_build_chunked_prefill(self, mock_ascend_attention_state,
-                                   mock_is_310p, mock_nd_to_nz_spec,
+                                   mock_soc_version, mock_nd_to_nz_spec,
                                    mock_npu_format_cast, mock_ascend_metadata):
         common_attn_metadata = AscendCommonAttentionMetadata(
             query_start_loc=torch.tensor([0, 2, 5, 9]),
@@ -162,8 +166,9 @@ def test_build_chunked_prefill(self, mock_ascend_attention_state,
         self.builder.build(1, common_attn_metadata, mock_model)

     @patch('vllm_ascend.attention.attention_v1.AscendMetadata')
-    @patch('vllm_ascend.attention.attention_v1.is_310p', return_value=False)
-    def test_build_non_310p(self, mock_is_310p, mock_ascend_metadata):
+    @patch('vllm_ascend.utils.get_ascend_device_type',
+           return_value=AscendDeviceType._910_93)
+    def test_build_non_310p(self, mock_soc_version, mock_ascend_metadata):
         common_attn_metadata = AscendCommonAttentionMetadata(
             query_start_loc=torch.tensor([0, 2, 5, 9]),
             query_start_loc_cpu=torch.tensor([0, 2, 5, 9]),
@@ -450,12 +455,13 @@ def test_forward_decode_only_swa_seq_len_mismatch(
         assert output.shape == (10, 8 * 64)

     @patch('vllm_ascend.attention.attention_v1.get_forward_context')
-    @patch('vllm_ascend.attention.attention_v1.is_310p', return_value=False)
+    @patch('vllm_ascend.utils.get_ascend_device_type',
+           return_value=AscendDeviceType._910_93)
     @patch('torch_npu._npu_reshape_and_cache')
     @patch('vllm_ascend.attention.attention_v1.vanilla_chunked_prefill')
     def test_forward_head_size_192(self, mock_vanilla_prefill,
-                                   mock_npu_reshape_and_cache, mock_is_310p,
-                                   mock_get_forward_context):
+                                   mock_npu_reshape_and_cache,
+                                   mock_soc_version, mock_get_forward_context):
         """Test forward pass when head_size is 192"""

         self.impl.head_size = 192
@@ -522,9 +528,11 @@ def test_forward_normal_v1_situation(self, mock_npu_reshape_and_cache,
     @patch('torch_npu.npu_format_cast')
     @patch('torch_npu._npu_reshape_and_cache')
     @patch('torch_npu.npu_fused_infer_attention_score')
-    @patch('vllm_ascend.attention.attention_v1.is_310p', return_value=True)
+    @patch('vllm_ascend.utils.get_ascend_device_type',
+           return_value=AscendDeviceType._310P)
     @patch('vllm_ascend.attention.attention_v1.get_forward_context')
-    def test_forward_310p_device(self, mock_get_forward_context, mock_is_310p,
+    def test_forward_310p_device(self, mock_get_forward_context,
+                                 mock_soc_version,
                                  mock_npu_fused_infer_attention_score,
                                  mock_npu_reshape_and_cache,
                                  mock_npu_format_cast):
```

tests/ut/models/conftest.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -92,7 +92,7 @@ def mock_distributed():

     with patch("vllm_ascend.ops.fused_moe.fused_moe.get_current_vllm_config", return_value=mock_vllm_config), \
          patch("vllm_ascend.ops.fused_moe.token_dispatcher.torch.distributed.get_rank", return_value=0), \
-         patch("vllm_ascend.ops.fused_moe.token_dispatcher.get_ascend_soc_version", return_value=None), \
+         patch("vllm_ascend.ops.fused_moe.token_dispatcher.get_ascend_device_type", return_value=None), \
          patch.dict("vllm.distributed.parallel_state.__dict__", _TP=tp_group, _EP=ep_group, _DP=dp_group,
                     _PP=pp_group), \
          patch.dict("vllm_ascend.distributed.parallel_state.__dict__", _MC2=ep_group), \
```

tests/ut/ops/test_activation.py

Lines changed: 8 additions & 4 deletions
```diff
@@ -19,6 +19,8 @@
 import torch
 from vllm.model_executor.layers.activation import QuickGELU, SiluAndMul

+from vllm_ascend.utils import AscendDeviceType
+

 @pytest.fixture
 def dummy_tensor():
@@ -36,20 +38,22 @@ def test_QuickGELU_forward(mock_gelu, dummy_tensor):
     mock_gelu.assert_called_once()


-@pytest.mark.parametrize("is_310p_return", [True, False])
+@pytest.mark.parametrize("is_310p", [True, False])
 @patch("torch_npu.npu_swiglu", side_effect=lambda x: x + 1)
 @patch("torch.ops.vllm.maybe_wait_prefetch_done", side_effect=lambda x: None)
 @patch("torch.ops.vllm.maybe_prefetch_mlp_down_proj",
        side_effect=lambda x: None)
 def test_SiluAndMul_forward(mock_maybe_prefetch_mlp_down_proj,
                             mock_maybe_wait_prefetch_done, mock_swiglu,
-                            is_310p_return, dummy_tensor):
+                            is_310p, dummy_tensor):

-    with patch("vllm_ascend.utils.is_310p", return_value=is_310p_return):
+    with patch("vllm_ascend.utils.get_ascend_device_type",
+               return_value=AscendDeviceType._310P
+               if is_310p else AscendDeviceType._910_93):
         layer = SiluAndMul()
         out = layer.forward(dummy_tensor)

-    if is_310p_return:
+    if is_310p:
         expected_arg = dummy_tensor.to(torch.float32)
     else:
         expected_arg = dummy_tensor
```

tests/ut/ops/test_fused_moe.py

Lines changed: 11 additions & 16 deletions
```diff
@@ -29,7 +29,7 @@
                                              AscendFusedMoE, AscendUnquantizedFusedMoEMethod)
 from vllm_ascend.ops.fused_moe.moe_mlp import (cumsum_group_list,
                                                unified_apply_mlp)
-from vllm_ascend.utils import AscendSocVersion, adapt_patch
+from vllm_ascend.utils import AscendDeviceType, adapt_patch

 adapt_patch(True)

@@ -129,7 +129,7 @@ def mock_finalize(hidden_states, **kwargs):
              return_value=mock_forward_context_obj), \
         patch('vllm_ascend.ops.fused_moe.prepare_finalize.get_forward_context',
              return_value=mock_forward_context_obj), \
-        patch("vllm_ascend.utils.get_ascend_soc_version", return_value=AscendSocVersion.A3), \
+        patch("vllm_ascend.utils.get_ascend_device_type", return_value=AscendDeviceType._910_93), \
         patch('vllm_ascend.ops.fused_moe.moe_mlp.get_forward_context',
              return_value=mock_forward_context_obj), \
         patch('vllm_ascend.ops.fused_moe.moe_comm_method.MC2CommImpl._get_token_dispatcher',
@@ -323,22 +323,21 @@ def test_cumsum_group_list_with_type_2(self):
 class TestUnifiedApplyMLP(TestBase):

     @patch('vllm_ascend.ops.fused_moe.moe_mlp.get_forward_context')
-    @patch('vllm_ascend.ops.fused_moe.moe_mlp.is_310p')
+    @patch('vllm_ascend.utils.get_ascend_device_type',
+           return_value=AscendDeviceType._910_93)
     @patch('torch_npu.npu_grouped_matmul')
     @patch('torch_npu.npu_dynamic_quant')
     @patch('torch_npu.npu_dequant_swiglu_quant')
     def test_unified_apply_mlp_with_quantization_mc2(self, mock_npu_dequant,
                                                      mock_npu_dynamic_quant,
                                                      mock_npu_grouped_matmul,
-                                                     mock_is_310p,
+                                                     mock_soc_version,
                                                      mock_get_forward_context):

         mock_forward_context = MagicMock()
         mock_forward_context.moe_comm_type = MoECommType.MC2
         mock_get_forward_context.return_value = mock_forward_context

-        mock_is_310p.return_value = False
-
         mock_npu_dynamic_quant.return_value = (torch.randint(-128,
                                                              127, (10, 20),
                                                              dtype=torch.int8),
@@ -387,17 +386,16 @@ def test_unified_apply_mlp_with_quantization_mc2(self, mock_npu_dequant,

         self.assertEqual(result.dtype, torch.bfloat16)

-    @patch('vllm_ascend.ops.fused_moe.moe_mlp.is_310p')
+    @patch('vllm_ascend.utils.get_ascend_device_type',
+           return_value=AscendDeviceType._910_93)
     @patch('torch_npu.npu_grouped_matmul')
     @patch('torch_npu.npu_swiglu')
     @patch('torch_npu.npu_dynamic_quant')
     def test_unified_apply_mlp_without_quantization(self,
                                                     mock_npu_dynamic_quant,
                                                     mock_npu_swiglu,
                                                     mock_npu_grouped_matmul,
-                                                    mock_is_310p):
-        mock_is_310p.return_value = False
-
+                                                    mock_soc_version):
         mock_npu_grouped_matmul.side_effect = [[
             torch.randn(10, 40, dtype=torch.float16)
         ], [torch.randn(10, 20, dtype=torch.float16)]]
@@ -490,15 +488,14 @@ def test_unified_apply_mlp_with_quantization_and_dynamic_scale(
         self.assertEqual(result.shape, hidden_states_shape)
         self.assertEqual(result.dtype, torch.bfloat16)

-    @patch('vllm_ascend.ops.fused_moe.moe_mlp.is_310p')
+    @patch('vllm_ascend.utils.get_ascend_device_type',
+           return_value=AscendDeviceType._310P)
     @patch('torch_npu.npu_grouped_matmul')
     @patch('torch_npu.npu_swiglu')
     @patch('torch_npu.npu_dynamic_quant')
     def test_unified_apply_mlp_without_quantization_310p(
             self, mock_npu_dynamic_quant, mock_npu_swiglu,
-            mock_npu_grouped_matmul, mock_is_310p):
-        mock_is_310p.return_value = True
-
+            mock_npu_grouped_matmul, mock_soc_version):
         mock_gmm1_out = torch.randn(10, 40, dtype=torch.float16)
         mock_gmm2_out = torch.randn(10, 20, dtype=torch.float16)
         mock_npu_grouped_matmul.side_effect = [[mock_gmm1_out],
@@ -527,8 +524,6 @@ def test_unified_apply_mlp_without_quantization_310p(
                                       topk_scales=topk_scales,
                                       with_quant=False)

-        mock_is_310p.assert_called_once()
-
         self.assertEqual(mock_npu_grouped_matmul.call_count, 2)
         mock_npu_swiglu.assert_called_once()
```
