
Commit 660c0c9

fix ut
Signed-off-by: wangli <wangli858794774@gmail.com>
1 parent f5d3349 commit 660c0c9

File tree: 1 file changed (+36, -52 lines)


tests/ut/attention/test_mla_v1.py

Lines changed: 36 additions & 52 deletions
@@ -1,7 +1,7 @@
 from unittest.mock import MagicMock, patch

 import torch
-from vllm.config import VllmConfig
+from vllm.config import CacheConfig, ModelConfig, SchedulerConfig, VllmConfig
 from vllm.distributed.parallel_state import GroupCoordinator
 from vllm.model_executor.layers.linear import LinearBase

@@ -184,19 +184,15 @@ class TestAscendMLAMetadataBuilder(TestBase):
            return_value=1)
     def test_ascend_mla_metadata_builder_default(self, mock_get_dcp_size,
                                                  mock_dcp, mock_get_dcp_group):
-        mock_model_config = MagicMock()
-        mock_model_config.max_model_len = 1024
-        mock_model_config.get_head_size.return_value = 64
-        mock_model_config.dtype = torch.float16
-
         mock_vllm_config = MagicMock()
-        mock_vllm_config.model_config = mock_model_config
-        mock_vllm_config.cache_config = MagicMock(block_size=16)
-        mock_vllm_config.scheduler_config = MagicMock(
-            max_num_seqs=4, enable_chunked_prefill=False)
-        mock_vllm_config.speculative_config = None
-
-        mock_device = torch.device('cpu')
+        mock_vllm_config.model_config.max_model_len = 1024
+        mock_vllm_config.model_config.get_head_size.return_value = 64
+        mock_vllm_config.model_config.dtype = torch.float16
+        mock_vllm_config.cache_config.block_size = 16
+        mock_vllm_config.scheduler_config.max_num_seqs = 4
+        mock_vllm_config.scheduler_config.decode_max_num_seqs = 4
+        mock_vllm_config.scheduler_config.enable_chunked_prefill = False
+        mock_device = 'cpu'

         mock_dcp.world_size = 1
         dcp_group = MagicMock(spec=GroupCoordinator)
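Why the intermediate mock_model_config object can go away: attribute access on a plain MagicMock lazily creates child mocks, so the nested config attributes can be assigned directly on mock_vllm_config. A minimal standalone sketch of that unittest.mock behavior (not code from the patch):

from unittest.mock import MagicMock

mock_vllm_config = MagicMock()
# Accessing .model_config auto-creates a child MagicMock, so nested attributes
# can be set in one line without building a separate mock object first.
mock_vllm_config.model_config.max_model_len = 1024
mock_vllm_config.scheduler_config.enable_chunked_prefill = False

assert isinstance(mock_vllm_config.model_config, MagicMock)
assert mock_vllm_config.model_config.max_model_len == 1024
assert mock_vllm_config.scheduler_config.enable_chunked_prefill is False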
@@ -205,6 +201,8 @@ def test_ascend_mla_metadata_builder_default(self, mock_get_dcp_size,
         dcp_group.device_group = MagicMock()
         mock_get_dcp_group.return_value = dcp_group

+        mock_vllm_config.speculative_config = None
+
         ascend_config = MagicMock()
         with patch("vllm_ascend.attention.mla_v1.get_ascend_config",
                    return_value=ascend_config):
@@ -225,19 +223,15 @@ def test_ascend_mla_metadata_builder_default(self, mock_get_dcp_size,
     def test_ascend_mla_metadata_builder_spec_decode(self, mock_get_dcp_size,
                                                      mock_dcp,
                                                      mock_get_dcp_group):
-        mock_model_config = MagicMock()
-        mock_model_config.max_model_len = 1024
-        mock_model_config.get_head_size.return_value = 64
-        mock_model_config.dtype = torch.float16
-
         mock_vllm_config = MagicMock()
-        mock_vllm_config.model_config = mock_model_config
-        mock_vllm_config.cache_config = MagicMock(block_size=16)
-        mock_vllm_config.scheduler_config = MagicMock(
-            max_num_seqs=4, enable_chunked_prefill=False)
-        mock_vllm_config.speculative_config = None
-
-        mock_device = torch.device('cpu')
+        mock_vllm_config.model_config.max_model_len = 1024
+        mock_vllm_config.model_config.get_head_size.return_value = 64
+        mock_vllm_config.model_config.dtype = torch.float16
+        mock_vllm_config.cache_config.block_size = 16
+        mock_vllm_config.scheduler_config.max_num_seqs = 4
+        mock_vllm_config.scheduler_config.decode_max_num_seqs = 4
+        mock_vllm_config.scheduler_config.enable_chunked_prefill = False
+        mock_device = 'cpu'

         mock_dcp.world_size = 1
         dcp_group = MagicMock(spec=GroupCoordinator)
@@ -260,7 +254,7 @@ def test_ascend_mla_metadata_builder_spec_decode(self, mock_get_dcp_size,
                          mock_vllm_config.cache_config.block_size)
         self.assertEqual(
             builder.chunked_prefill_enabled,
-            mock_vllm_config.scheduler_config.chunked_prefill_enabled)
+            mock_vllm_config.scheduler_config.enable_chunked_prefill)

     @patch('vllm.distributed.parallel_state.get_dcp_group')
     @patch('vllm.distributed.parallel_state._DCP',
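The one-line assertion fix above keeps the attribute name consistent with what the mock actually sets: on a plain MagicMock, reading a name that was never assigned does not raise, it just returns another auto-created mock, so comparing against it never exercises the configured False value. A minimal standalone sketch of that pitfall (plain unittest.mock, not code from the patch):

from unittest.mock import MagicMock

scheduler_config = MagicMock()
scheduler_config.enable_chunked_prefill = False  # the attribute the test really sets

# The assigned name returns the configured value ...
assert scheduler_config.enable_chunked_prefill is False
# ... while a name the test never assigned silently yields a fresh MagicMock
# instead of the configured boolean, so an assertEqual against it checks the
# wrong thing.
assert isinstance(scheduler_config.chunked_prefill_enabled, MagicMock)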
@@ -322,19 +316,13 @@ def test_reorder_batch(self, mock_get_dcp_size, mock_dcp,
                            mock_get_dcp_group):
         ascend_config = MagicMock()

-        mock_model_config = MagicMock()
-        mock_model_config.max_model_len = 1024
-        mock_model_config.get_head_size.return_value = 64
-        mock_model_config.dtype = torch.float16
-
         mock_vllm_config = MagicMock()
-        mock_vllm_config.model_config = mock_model_config
-        mock_vllm_config.cache_config = MagicMock(block_size=16)
-        mock_vllm_config.scheduler_config = MagicMock(
-            max_num_seqs=4, enable_chunked_prefill=False)
-        mock_vllm_config.speculative_config = None
-
-        mock_device = torch.device('cpu')
+        mock_vllm_config.model_config.max_model_len = 1024
+        mock_vllm_config.cache_config.block_size = 16
+        mock_vllm_config.scheduler_config.max_num_seqs = 4
+        mock_vllm_config.scheduler_config.decode_max_num_seqs = 4
+        mock_vllm_config.scheduler_config.enable_chunked_prefill = False
+        mock_device = 'cpu'

         mock_dcp.world_size = 1
         dcp_group = MagicMock(spec=GroupCoordinator)
@@ -343,6 +331,8 @@ def test_reorder_batch(self, mock_get_dcp_size, mock_dcp,
         dcp_group.device_group = MagicMock()
         mock_get_dcp_group.return_value = dcp_group

+        mock_vllm_config.speculative_config = None
+
         with patch("vllm_ascend.attention.mla_v1.get_ascend_config",
                    return_value=ascend_config):
             builder = AscendMLAMetadataBuilder(None, None, mock_vllm_config,
@@ -447,21 +437,15 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):

     def setUp(self):
         self.mock_vllm_config = MagicMock(spec=VllmConfig)
-        # NOTE: Do not init the ModelConfig from constructor
-        # Which will try to download a model
-        mock_model_config = MagicMock()
-        mock_model_config.max_model_len = 1024
-        mock_model_config.get_head_size.return_value = 64
-        mock_model_config.dtype = torch.float16
-
-        from vllm.config.scheduler import SchedulerConfig
-        self.mock_vllm_config.scheduler_config = SchedulerConfig()
-
-        self.mock_vllm_config.model_config = mock_model_config
-        self.mock_vllm_config.cache_config = MagicMock(block_size=16)
+        self.mock_vllm_config.model_config = ModelConfig(max_model_len=2048)
+        self.mock_vllm_config.model_config.hf_text_config.qk_rope_head_dim = 32
+        self.mock_vllm_config.cache_config = CacheConfig(block_size=32)
+        mock_scheduler_config = MagicMock(spec=SchedulerConfig)
+        mock_scheduler_config.max_num_seqs = 8  # set to an integer, not a MagicMock
+        mock_scheduler_config.chunked_prefill_enabled = True
+        self.mock_vllm_config.scheduler_config = mock_scheduler_config
         self.mock_vllm_config.speculative_config = None
-
-        self.mock_device = torch.device('cpu')
+        self.mock_device = torch.device("cpu")

         self.kv_cache_spec = MagicMock()
         self.kv_cache_spec.num_layers = 32
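The rewritten setUp uses spec=SchedulerConfig so the scheduler mock only answers for attributes that exist on the real class instead of inventing them on access. A minimal standalone sketch of what spec= buys, using a hypothetical local DummySchedulerConfig class as a stand-in for the real vllm SchedulerConfig:

from unittest.mock import MagicMock

class DummySchedulerConfig:  # hypothetical stand-in, not the real vllm class
    max_num_seqs: int = 0
    chunked_prefill_enabled: bool = False

mock_scheduler_config = MagicMock(spec=DummySchedulerConfig)
mock_scheduler_config.max_num_seqs = 8  # real attribute: assignment and read-back work
assert mock_scheduler_config.max_num_seqs == 8

try:
    mock_scheduler_config.not_a_real_field  # unknown names raise instead of auto-mocking
except AttributeError:
    pass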
