Skip to content

Commit fc31168

Browse files
committed
fix mla
Signed-off-by: wangli <wangli858794774@gmail.com>
1 parent 3a773e2 commit fc31168

File tree

1 file changed

+70
-38
lines changed

1 file changed

+70
-38
lines changed

tests/ut/torchair/test_torchair_mla.py

Lines changed: 70 additions & 38 deletions
Original file line number | Diff line number | Diff line change
@@ -253,16 +253,20 @@ def test_reorder_batch_without_torchair_graph(self):
253253
ascend_config.torchair_graph_config = MagicMock()
254254
ascend_config.torchair_graph_config.enabled = False
255255

256-
mock_vllm_config = MagicMock()
257-
mock_vllm_config.model_config.max_model_len = 1024
258-
mock_vllm_config.get_head_size = lambda: 8
259-
mock_vllm_config.cache_config.block_size = 16
260-
mock_vllm_config.scheduler_config.max_num_seqs = 4
261-
mock_vllm_config.scheduler_config.enable_chunked_prefill = False
262-
mock_device = torch.device('cpu')
256+
mock_model_config = MagicMock()
257+
mock_model_config.max_model_len = 1024
258+
mock_model_config.get_head_size.return_value = 64
259+
mock_model_config.dtype = torch.float16
263260

261+
mock_vllm_config = MagicMock()
262+
mock_vllm_config.model_config = mock_model_config
263+
mock_vllm_config.cache_config = MagicMock(block_size=16)
264+
mock_vllm_config.scheduler_config = MagicMock(
265+
max_num_seqs=4, enable_chunked_prefill=False)
264266
mock_vllm_config.speculative_config = None
265267

268+
mock_device = torch.device('cpu')
269+
266270
with patch("vllm_ascend.torchair.torchair_mla.get_ascend_config",
267271
return_value=ascend_config):
268272
builder = AscendMLATorchairMetadataBuilder(None, None,
@@ -293,14 +297,21 @@ def test_get_graph_runner_block_tables_normal(self, mock_ascend_config):
293297
ascend_config = MagicMock()
294298
mock_ascend_config.return_value = ascend_config
295299
ascend_config.torchair_graph_config.enabled = False
296-
mock_vllm_config = MagicMock()
297-
mock_vllm_config.model_config.max_model_len = 1024
298-
mock_vllm_config.cache_config.block_size = 16
299-
mock_vllm_config.scheduler_config.enable_chunked_prefill = False
300-
mock_device = torch.device('cpu')
301300

301+
mock_model_config = MagicMock()
302+
mock_model_config.max_model_len = 1024
303+
mock_model_config.get_head_size.return_value = 64
304+
mock_model_config.dtype = torch.float16
305+
306+
mock_vllm_config = MagicMock()
307+
mock_vllm_config.model_config = mock_model_config
308+
mock_vllm_config.cache_config = MagicMock(block_size=16)
309+
mock_vllm_config.scheduler_config = MagicMock(
310+
max_num_seqs=4, enable_chunked_prefill=False)
302311
mock_vllm_config.speculative_config = None
303312

313+
mock_device = torch.device('cpu')
314+
304315
builder = AscendMLATorchairMetadataBuilder(None, None,
305316
mock_vllm_config,
306317
mock_device)
@@ -316,14 +327,21 @@ def test_get_graph_runner_block_tables_truncated(self, mock_ascend_config):
316327
ascend_config = MagicMock()
317328
mock_ascend_config.return_value = ascend_config
318329
ascend_config.torchair_graph_config.enabled = False
319-
mock_vllm_config = MagicMock()
320-
mock_vllm_config.model_config.max_model_len = 64
321-
mock_vllm_config.cache_config.block_size = 16
322-
mock_vllm_config.scheduler_config.enable_chunked_prefill = False
323-
mock_device = torch.device('cpu')
324330

331+
mock_model_config = MagicMock()
332+
mock_model_config.max_model_len = 1024
333+
mock_model_config.get_head_size.return_value = 64
334+
mock_model_config.dtype = torch.float16
335+
336+
mock_vllm_config = MagicMock()
337+
mock_vllm_config.model_config = mock_model_config
338+
mock_vllm_config.cache_config = MagicMock(block_size=16)
339+
mock_vllm_config.scheduler_config = MagicMock(
340+
max_num_seqs=4, enable_chunked_prefill=False)
325341
mock_vllm_config.speculative_config = None
326342

343+
mock_device = torch.device('cpu')
344+
327345
builder = AscendMLATorchairMetadataBuilder(None, None,
328346
mock_vllm_config,
329347
mock_device)
@@ -340,16 +358,21 @@ def test_get_graph_runner_block_tables_from_numpy(self,
340358
ascend_config = MagicMock()
341359
mock_ascend_config.return_value = ascend_config
342360
ascend_config.torchair_graph_config.enabled = False
343-
mock_vllm_config = MagicMock()
344-
mock_vllm_config.model_config.max_model_len = 1024
345-
mock_vllm_config.cache_config.block_size = 16
346-
mock_vllm_config.get_head_size = lambda: 28
347-
mock_vllm_config.dtype = torch.bfloat16
348-
mock_vllm_config.scheduler_config.enable_chunked_prefill = False
349-
mock_device = torch.device('cpu')
350361

362+
mock_model_config = MagicMock()
363+
mock_model_config.max_model_len = 1024
364+
mock_model_config.get_head_size.return_value = 64
365+
mock_model_config.dtype = torch.float16
366+
367+
mock_vllm_config = MagicMock()
368+
mock_vllm_config.model_config = mock_model_config
369+
mock_vllm_config.cache_config = MagicMock(block_size=16)
370+
mock_vllm_config.scheduler_config = MagicMock(
371+
max_num_seqs=4, enable_chunked_prefill=False)
351372
mock_vllm_config.speculative_config = None
352373

374+
mock_device = torch.device('cpu')
375+
353376
builder = AscendMLATorchairMetadataBuilder(None, None,
354377
mock_vllm_config,
355378
mock_device)
@@ -368,16 +391,20 @@ def test_build_dummy(self, mock_ascend_config):
368391
mock_ascend_config.return_value = ascend_config
369392
ascend_config.torchair_graph_config.enabled = False
370393

371-
mock_vllm_config = MagicMock()
372-
mock_vllm_config.model_config.max_model_len = 1024
373-
mock_vllm_config.cache_config.block_size = 16
374-
mock_vllm_config.scheduler_config.enable_chunked_prefill = False
375-
mock_vllm_config.get_head_size.return_value = 64
376-
mock_vllm_config.model_config.dtype = torch.float16
377-
mock_device = torch.device('cpu')
394+
mock_model_config = MagicMock()
395+
mock_model_config.max_model_len = 1024
396+
mock_model_config.get_head_size.return_value = 64
397+
mock_model_config.dtype = torch.float16
378398

399+
mock_vllm_config = MagicMock()
400+
mock_vllm_config.model_config = mock_model_config
401+
mock_vllm_config.cache_config = MagicMock(block_size=16)
402+
mock_vllm_config.scheduler_config = MagicMock(
403+
max_num_seqs=4, enable_chunked_prefill=False)
379404
mock_vllm_config.speculative_config = None
380405

406+
mock_device = torch.device('cpu')
407+
381408
builder = AscendMLATorchairMetadataBuilder(
382409
None,
383410
None,
@@ -435,18 +462,23 @@ def test_build_decode(self, mock_ascend_config):
435462
mock_ascend_config.return_value = ascend_config
436463
ascend_config.torchair_graph_config.enabled = False
437464

465+
mock_model_config = MagicMock()
466+
mock_model_config.max_model_len = 1024
467+
mock_model_config.get_head_size.return_value = 64
468+
mock_model_config.dtype = torch.float16
469+
438470
mock_vllm_config = MagicMock()
439-
mock_vllm_config.model_config.max_model_len = 1024
440-
mock_vllm_config.cache_config.block_size = 16
441-
mock_vllm_config.scheduler_config.enable_chunked_prefill = False
442-
mock_vllm_config.get_head_size.return_value = 64
443-
mock_vllm_config.model_config.dtype = torch.float16
471+
mock_vllm_config.model_config = mock_model_config
472+
mock_vllm_config.cache_config = MagicMock(block_size=16)
473+
mock_vllm_config.scheduler_config = MagicMock(
474+
max_num_seqs=4, enable_chunked_prefill=False)
475+
mock_vllm_config.speculative_config = None
476+
444477
mock_device = torch.device('cpu')
478+
445479
model = MagicMock(spec=nn.Module)
446480
model.model = MagicMock(spec=nn.Module)
447481

448-
mock_vllm_config.speculative_config = None
449-
450482
builder = AscendMLATorchairMetadataBuilder(
451483
None,
452484
None,

0 commit comments

Comments
 (0)