
Commit 545d69c

Commit message: fix mla ut
Signed-off-by: wangli <wangli858794774@gmail.com>
1 parent 2c10120 commit 545d69c

File tree: 1 file changed (+15, -10 lines)


tests/ut/attention/test_mla_v1.py

Lines changed: 15 additions & 10 deletions
@@ -1,7 +1,7 @@
 from unittest.mock import MagicMock, patch
 
 import torch
-from vllm.config import CacheConfig, ModelConfig, SchedulerConfig, VllmConfig
+from vllm.config import VllmConfig
 from vllm.distributed.parallel_state import GroupCoordinator
 from vllm.model_executor.layers.linear import LinearBase
 
@@ -215,7 +215,7 @@ def test_ascend_mla_metadata_builder_default(self, mock_get_dcp_size,
             mock_vllm_config.cache_config.block_size)
         self.assertEqual(
             builder.chunked_prefill_enabled,
-            mock_vllm_config.scheduler_config.chunked_prefill_enabled)
+            mock_vllm_config.scheduler_config.enable_chunked_prefill)
 
     @patch('vllm.distributed.parallel_state.get_dcp_group')
     @patch('vllm.distributed.parallel_state._DCP',
@@ -447,15 +447,20 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
 
     def setUp(self):
         self.mock_vllm_config = MagicMock(spec=VllmConfig)
-        self.mock_vllm_config.model_config = ModelConfig(max_model_len=2048)
-        self.mock_vllm_config.model_config.hf_text_config.qk_rope_head_dim = 32
-        self.mock_vllm_config.cache_config = CacheConfig(block_size=32)
-        mock_scheduler_config = MagicMock(spec=SchedulerConfig)
-        mock_scheduler_config.max_num_seqs = 8  # set to an integer, not a MagicMock
-        mock_scheduler_config.chunked_prefill_enabled = True
-        self.mock_vllm_config.scheduler_config = mock_scheduler_config
+        # NOTE: Do not init the ModelConfig from the constructor,
+        # which will try to download a model
+        mock_model_config = MagicMock()
+        mock_model_config.max_model_len = 1024
+        mock_model_config.get_head_size.return_value = 64
+        mock_model_config.dtype = torch.float16
+
+        self.mock_vllm_config.model_config = mock_model_config
+        self.mock_vllm_config.cache_config = MagicMock(block_size=16)
+        self.mock_vllm_config.scheduler_config = MagicMock(
+            max_num_seqs=4, enable_chunked_prefill=False)
         self.mock_vllm_config.speculative_config = None
-        self.mock_device = torch.device("cpu")
+
+        self.mock_device = torch.device('cpu')
 
         self.kv_cache_spec = MagicMock()
         self.kv_cache_spec.num_layers = 32
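For reference, here is a minimal, self-contained sketch (not part of the commit) of the mocking pattern the patched setUp relies on: plain MagicMock objects stand in for the model, cache, and scheduler configs, so no real ModelConfig is constructed and no model download is triggered. The attribute names below mirror the ones used in the test; only unittest.mock and torch are needed to run it.

from unittest.mock import MagicMock

import torch

# Stand-in for the model config; only the attributes the test touches are set.
mock_model_config = MagicMock()
mock_model_config.max_model_len = 1024
mock_model_config.get_head_size.return_value = 64
mock_model_config.dtype = torch.float16

# Assemble a fake VllmConfig-like object out of mocks.
mock_vllm_config = MagicMock()
mock_vllm_config.model_config = mock_model_config
mock_vllm_config.cache_config = MagicMock(block_size=16)
mock_vllm_config.scheduler_config = MagicMock(max_num_seqs=4,
                                              enable_chunked_prefill=False)

# Attributes set explicitly read back exactly as given; anything left
# untouched stays a MagicMock, which is enough for the builder under test.
assert mock_vllm_config.scheduler_config.enable_chunked_prefill is False
assert mock_vllm_config.model_config.get_head_size() == 64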
