11from unittest .mock import MagicMock , patch
22
33import torch
4- from vllm .config import VllmConfig
4+ from vllm .config import CacheConfig , ModelConfig , SchedulerConfig , VllmConfig
55from vllm .distributed .parallel_state import GroupCoordinator
66from vllm .model_executor .layers .linear import LinearBase
77
@@ -184,19 +184,15 @@ class TestAscendMLAMetadataBuilder(TestBase):
184184 return_value = 1 )
185185 def test_ascend_mla_metadata_builder_default (self , mock_get_dcp_size ,
186186 mock_dcp , mock_get_dcp_group ):
187- mock_model_config = MagicMock ()
188- mock_model_config .max_model_len = 1024
189- mock_model_config .get_head_size .return_value = 64
190- mock_model_config .dtype = torch .float16
191-
192187 mock_vllm_config = MagicMock ()
193- mock_vllm_config .model_config = mock_model_config
194- mock_vllm_config .cache_config = MagicMock (block_size = 16 )
195- mock_vllm_config .scheduler_config = MagicMock (
196- max_num_seqs = 4 , enable_chunked_prefill = False )
197- mock_vllm_config .speculative_config = None
198-
199- mock_device = torch .device ('cpu' )
188+ mock_vllm_config .model_config .max_model_len = 1024
189+ mock_vllm_config .model_config .get_head_size .return_value = 64
190+ mock_vllm_config .model_config .dtype = torch .float16
191+ mock_vllm_config .cache_config .block_size = 16
192+ mock_vllm_config .scheduler_config .max_num_seqs = 4
193+ mock_vllm_config .scheduler_config .decode_max_num_seqs = 4
194+ mock_vllm_config .scheduler_config .enable_chunked_prefill = False
195+ mock_device = 'cpu'
200196
201197 mock_dcp .world_size = 1
202198 dcp_group = MagicMock (spec = GroupCoordinator )
@@ -205,6 +201,8 @@ def test_ascend_mla_metadata_builder_default(self, mock_get_dcp_size,
205201 dcp_group .device_group = MagicMock ()
206202 mock_get_dcp_group .return_value = dcp_group
207203
204+ mock_vllm_config .speculative_config = None
205+
208206 ascend_config = MagicMock ()
209207 with patch ("vllm_ascend.attention.mla_v1.get_ascend_config" ,
210208 return_value = ascend_config ):
@@ -225,19 +223,15 @@ def test_ascend_mla_metadata_builder_default(self, mock_get_dcp_size,
225223 def test_ascend_mla_metadata_builder_spec_decode (self , mock_get_dcp_size ,
226224 mock_dcp ,
227225 mock_get_dcp_group ):
228- mock_model_config = MagicMock ()
229- mock_model_config .max_model_len = 1024
230- mock_model_config .get_head_size .return_value = 64
231- mock_model_config .dtype = torch .float16
232-
233226 mock_vllm_config = MagicMock ()
234- mock_vllm_config .model_config = mock_model_config
235- mock_vllm_config .cache_config = MagicMock (block_size = 16 )
236- mock_vllm_config .scheduler_config = MagicMock (
237- max_num_seqs = 4 , enable_chunked_prefill = False )
238- mock_vllm_config .speculative_config = None
239-
240- mock_device = torch .device ('cpu' )
227+ mock_vllm_config .model_config .max_model_len = 1024
228+ mock_vllm_config .model_config .get_head_size .return_value = 64
229+ mock_vllm_config .model_config .dtype = torch .float16
230+ mock_vllm_config .cache_config .block_size = 16
231+ mock_vllm_config .scheduler_config .max_num_seqs = 4
232+ mock_vllm_config .scheduler_config .decode_max_num_seqs = 4
233+ mock_vllm_config .scheduler_config .enable_chunked_prefill = False
234+ mock_device = 'cpu'
241235
242236 mock_dcp .world_size = 1
243237 dcp_group = MagicMock (spec = GroupCoordinator )
@@ -260,7 +254,7 @@ def test_ascend_mla_metadata_builder_spec_decode(self, mock_get_dcp_size,
260254 mock_vllm_config .cache_config .block_size )
261255 self .assertEqual (
262256 builder .chunked_prefill_enabled ,
263- mock_vllm_config .scheduler_config .chunked_prefill_enabled )
257+ mock_vllm_config .scheduler_config .enable_chunked_prefill )
264258
265259 @patch ('vllm.distributed.parallel_state.get_dcp_group' )
266260 @patch ('vllm.distributed.parallel_state._DCP' ,
@@ -322,19 +316,13 @@ def test_reorder_batch(self, mock_get_dcp_size, mock_dcp,
322316 mock_get_dcp_group ):
323317 ascend_config = MagicMock ()
324318
325- mock_model_config = MagicMock ()
326- mock_model_config .max_model_len = 1024
327- mock_model_config .get_head_size .return_value = 64
328- mock_model_config .dtype = torch .float16
329-
330319 mock_vllm_config = MagicMock ()
331- mock_vllm_config .model_config = mock_model_config
332- mock_vllm_config .cache_config = MagicMock (block_size = 16 )
333- mock_vllm_config .scheduler_config = MagicMock (
334- max_num_seqs = 4 , enable_chunked_prefill = False )
335- mock_vllm_config .speculative_config = None
336-
337- mock_device = torch .device ('cpu' )
320+ mock_vllm_config .model_config .max_model_len = 1024
321+ mock_vllm_config .cache_config .block_size = 16
322+ mock_vllm_config .scheduler_config .max_num_seqs = 4
323+ mock_vllm_config .scheduler_config .decode_max_num_seqs = 4
324+ mock_vllm_config .scheduler_config .enable_chunked_prefill = False
325+ mock_device = 'cpu'
338326
339327 mock_dcp .world_size = 1
340328 dcp_group = MagicMock (spec = GroupCoordinator )
@@ -343,6 +331,8 @@ def test_reorder_batch(self, mock_get_dcp_size, mock_dcp,
343331 dcp_group .device_group = MagicMock ()
344332 mock_get_dcp_group .return_value = dcp_group
345333
334+ mock_vllm_config .speculative_config = None
335+
346336 with patch ("vllm_ascend.attention.mla_v1.get_ascend_config" ,
347337 return_value = ascend_config ):
348338 builder = AscendMLAMetadataBuilder (None , None , mock_vllm_config ,
@@ -447,21 +437,15 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
447437
448438 def setUp (self ):
449439 self .mock_vllm_config = MagicMock (spec = VllmConfig )
450- # NOTE: Do not init the ModelConfig from constructor
451- # Which will try to download a model
452- mock_model_config = MagicMock ()
453- mock_model_config .max_model_len = 1024
454- mock_model_config .get_head_size .return_value = 64
455- mock_model_config .dtype = torch .float16
456-
457- from vllm .config .scheduler import SchedulerConfig
458- self .mock_vllm_config .scheduler_config = SchedulerConfig ()
459-
460- self .mock_vllm_config .model_config = mock_model_config
461- self .mock_vllm_config .cache_config = MagicMock (block_size = 16 )
440+ self .mock_vllm_config .model_config = ModelConfig (max_model_len = 2048 )
441+ self .mock_vllm_config .model_config .hf_text_config .qk_rope_head_dim = 32
442+ self .mock_vllm_config .cache_config = CacheConfig (block_size = 32 )
443+ mock_scheduler_config = MagicMock (spec = SchedulerConfig )
444+ mock_scheduler_config .max_num_seqs = 8 # 设置为整数,不是 MagicMock
445+ mock_scheduler_config .chunked_prefill_enabled = True
446+ self .mock_vllm_config .scheduler_config = mock_scheduler_config
462447 self .mock_vllm_config .speculative_config = None
463-
464- self .mock_device = torch .device ('cpu' )
448+ self .mock_device = torch .device ("cpu" )
465449
466450 self .kv_cache_spec = MagicMock ()
467451 self .kv_cache_spec .num_layers = 32
0 commit comments