@@ -253,16 +253,20 @@ def test_reorder_batch_without_torchair_graph(self):
253253 ascend_config .torchair_graph_config = MagicMock ()
254254 ascend_config .torchair_graph_config .enabled = False
255255
256- mock_vllm_config = MagicMock ()
257- mock_vllm_config .model_config .max_model_len = 1024
258- mock_vllm_config .get_head_size = lambda : 8
259- mock_vllm_config .cache_config .block_size = 16
260- mock_vllm_config .scheduler_config .max_num_seqs = 4
261- mock_vllm_config .scheduler_config .enable_chunked_prefill = False
262- mock_device = torch .device ('cpu' )
256+ mock_model_config = MagicMock ()
257+ mock_model_config .max_model_len = 1024
258+ mock_model_config .get_head_size .return_value = 64
259+ mock_model_config .dtype = torch .float16
263260
261+ mock_vllm_config = MagicMock ()
262+ mock_vllm_config .model_config = mock_model_config
263+ mock_vllm_config .cache_config = MagicMock (block_size = 16 )
264+ mock_vllm_config .scheduler_config = MagicMock (
265+ max_num_seqs = 4 , enable_chunked_prefill = False )
264266 mock_vllm_config .speculative_config = None
265267
268+ mock_device = torch .device ('cpu' )
269+
266270 with patch ("vllm_ascend.torchair.torchair_mla.get_ascend_config" ,
267271 return_value = ascend_config ):
268272 builder = AscendMLATorchairMetadataBuilder (None , None ,
@@ -293,14 +297,21 @@ def test_get_graph_runner_block_tables_normal(self, mock_ascend_config):
293297 ascend_config = MagicMock ()
294298 mock_ascend_config .return_value = ascend_config
295299 ascend_config .torchair_graph_config .enabled = False
296- mock_vllm_config = MagicMock ()
297- mock_vllm_config .model_config .max_model_len = 1024
298- mock_vllm_config .cache_config .block_size = 16
299- mock_vllm_config .scheduler_config .enable_chunked_prefill = False
300- mock_device = torch .device ('cpu' )
301300
301+ mock_model_config = MagicMock ()
302+ mock_model_config .max_model_len = 1024
303+ mock_model_config .get_head_size .return_value = 64
304+ mock_model_config .dtype = torch .float16
305+
306+ mock_vllm_config = MagicMock ()
307+ mock_vllm_config .model_config = mock_model_config
308+ mock_vllm_config .cache_config = MagicMock (block_size = 16 )
309+ mock_vllm_config .scheduler_config = MagicMock (
310+ max_num_seqs = 4 , enable_chunked_prefill = False )
302311 mock_vllm_config .speculative_config = None
303312
313+ mock_device = torch .device ('cpu' )
314+
304315 builder = AscendMLATorchairMetadataBuilder (None , None ,
305316 mock_vllm_config ,
306317 mock_device )
@@ -316,14 +327,21 @@ def test_get_graph_runner_block_tables_truncated(self, mock_ascend_config):
316327 ascend_config = MagicMock ()
317328 mock_ascend_config .return_value = ascend_config
318329 ascend_config .torchair_graph_config .enabled = False
319- mock_vllm_config = MagicMock ()
320- mock_vllm_config .model_config .max_model_len = 64
321- mock_vllm_config .cache_config .block_size = 16
322- mock_vllm_config .scheduler_config .enable_chunked_prefill = False
323- mock_device = torch .device ('cpu' )
324330
331+ mock_model_config = MagicMock ()
332+ mock_model_config .max_model_len = 1024
333+ mock_model_config .get_head_size .return_value = 64
334+ mock_model_config .dtype = torch .float16
335+
336+ mock_vllm_config = MagicMock ()
337+ mock_vllm_config .model_config = mock_model_config
338+ mock_vllm_config .cache_config = MagicMock (block_size = 16 )
339+ mock_vllm_config .scheduler_config = MagicMock (
340+ max_num_seqs = 4 , enable_chunked_prefill = False )
325341 mock_vllm_config .speculative_config = None
326342
343+ mock_device = torch .device ('cpu' )
344+
327345 builder = AscendMLATorchairMetadataBuilder (None , None ,
328346 mock_vllm_config ,
329347 mock_device )
@@ -340,16 +358,21 @@ def test_get_graph_runner_block_tables_from_numpy(self,
340358 ascend_config = MagicMock ()
341359 mock_ascend_config .return_value = ascend_config
342360 ascend_config .torchair_graph_config .enabled = False
343- mock_vllm_config = MagicMock ()
344- mock_vllm_config .model_config .max_model_len = 1024
345- mock_vllm_config .cache_config .block_size = 16
346- mock_vllm_config .get_head_size = lambda : 28
347- mock_vllm_config .dtype = torch .bfloat16
348- mock_vllm_config .scheduler_config .enable_chunked_prefill = False
349- mock_device = torch .device ('cpu' )
350361
362+ mock_model_config = MagicMock ()
363+ mock_model_config .max_model_len = 1024
364+ mock_model_config .get_head_size .return_value = 64
365+ mock_model_config .dtype = torch .float16
366+
367+ mock_vllm_config = MagicMock ()
368+ mock_vllm_config .model_config = mock_model_config
369+ mock_vllm_config .cache_config = MagicMock (block_size = 16 )
370+ mock_vllm_config .scheduler_config = MagicMock (
371+ max_num_seqs = 4 , enable_chunked_prefill = False )
351372 mock_vllm_config .speculative_config = None
352373
374+ mock_device = torch .device ('cpu' )
375+
353376 builder = AscendMLATorchairMetadataBuilder (None , None ,
354377 mock_vllm_config ,
355378 mock_device )
@@ -368,16 +391,20 @@ def test_build_dummy(self, mock_ascend_config):
368391 mock_ascend_config .return_value = ascend_config
369392 ascend_config .torchair_graph_config .enabled = False
370393
371- mock_vllm_config = MagicMock ()
372- mock_vllm_config .model_config .max_model_len = 1024
373- mock_vllm_config .cache_config .block_size = 16
374- mock_vllm_config .scheduler_config .enable_chunked_prefill = False
375- mock_vllm_config .get_head_size .return_value = 64
376- mock_vllm_config .model_config .dtype = torch .float16
377- mock_device = torch .device ('cpu' )
394+ mock_model_config = MagicMock ()
395+ mock_model_config .max_model_len = 1024
396+ mock_model_config .get_head_size .return_value = 64
397+ mock_model_config .dtype = torch .float16
378398
399+ mock_vllm_config = MagicMock ()
400+ mock_vllm_config .model_config = mock_model_config
401+ mock_vllm_config .cache_config = MagicMock (block_size = 16 )
402+ mock_vllm_config .scheduler_config = MagicMock (
403+ max_num_seqs = 4 , enable_chunked_prefill = False )
379404 mock_vllm_config .speculative_config = None
380405
406+ mock_device = torch .device ('cpu' )
407+
381408 builder = AscendMLATorchairMetadataBuilder (
382409 None ,
383410 None ,
@@ -435,18 +462,23 @@ def test_build_decode(self, mock_ascend_config):
435462 mock_ascend_config .return_value = ascend_config
436463 ascend_config .torchair_graph_config .enabled = False
437464
465+ mock_model_config = MagicMock ()
466+ mock_model_config .max_model_len = 1024
467+ mock_model_config .get_head_size .return_value = 64
468+ mock_model_config .dtype = torch .float16
469+
438470 mock_vllm_config = MagicMock ()
439- mock_vllm_config .model_config .max_model_len = 1024
440- mock_vllm_config .cache_config .block_size = 16
441- mock_vllm_config .scheduler_config .enable_chunked_prefill = False
442- mock_vllm_config .get_head_size .return_value = 64
443- mock_vllm_config .model_config .dtype = torch .float16
471+ mock_vllm_config .model_config = mock_model_config
472+ mock_vllm_config .cache_config = MagicMock (block_size = 16 )
473+ mock_vllm_config .scheduler_config = MagicMock (
474+ max_num_seqs = 4 , enable_chunked_prefill = False )
475+ mock_vllm_config .speculative_config = None
476+
444477 mock_device = torch .device ('cpu' )
478+
445479 model = MagicMock (spec = nn .Module )
446480 model .model = MagicMock (spec = nn .Module )
447481
448- mock_vllm_config .speculative_config = None
449-
450482 builder = AscendMLATorchairMetadataBuilder (
451483 None ,
452484 None ,
0 commit comments