
Commit 9d7da91

fix mla_v1 acl_graph scheduler ut test
Signed-off-by: hfadzxy <starmoon_zhang@163.com>
1 parent: d11944e

3 files changed: +59 −54 lines changed

tests/ut/attention/test_mla_v1.py

Lines changed: 27 additions & 8 deletions
@@ -440,8 +440,10 @@ def setUp(self):
         self.mock_vllm_config.model_config = ModelConfig(max_model_len=2048)
         self.mock_vllm_config.model_config.hf_text_config.qk_rope_head_dim = 32
         self.mock_vllm_config.cache_config = CacheConfig(block_size=32)
-        self.mock_vllm_config.scheduler_config = SchedulerConfig(
-            max_num_seqs=8, chunked_prefill_enabled=True)
+        mock_scheduler_config = MagicMock(spec=SchedulerConfig)
+        mock_scheduler_config.max_num_seqs = 8  # set to an integer, not a MagicMock
+        mock_scheduler_config.chunked_prefill_enabled = True
+        self.mock_vllm_config.scheduler_config = mock_scheduler_config
         self.mock_vllm_config.speculative_config = None
         self.mock_device = torch.device("cpu")
 
@@ -454,12 +456,20 @@ def setUp(self):
         "vllm_ascend.attention.mla_v1.get_decode_context_model_parallel_world_size"
     )
     @patch("vllm_ascend.attention.mla_v1.get_ascend_config")
-    def test_build_prefix_no_cache_metadata(self, mock_get_ascend_config,
+    @patch("vllm_ascend.attention.mla_v1.torch.zeros", wraps=torch.zeros)
+    @patch("torch.Tensor.npu", new=lambda self: self)
+    @patch("torch.npu.is_available")
+    def test_build_prefix_no_cache_metadata(self, mock_npu_available,
+                                            mock_zeros, mock_get_ascend_config,
                                             mock_dcp_world_size):
-        if not torch.npu.is_available():
-            self.skipTest("NPU not available, skipping NPU-dependent tests")
+        mock_npu_available.return_value = False
         mock_dcp_world_size.return_value = 1
 
+        def zeros_override(*args, **kwargs):
+            kwargs.pop('pin_memory', None)
+            return mock_zeros._mock_wraps(*args, **kwargs)
+
+        mock_zeros.side_effect = zeros_override
         common_attn_metadata = AscendCommonAttentionMetadata(
             query_start_loc=torch.tensor([0, 3, 7]),
             query_start_loc_cpu=torch.tensor([0, 3, 7]),
@@ -506,12 +516,21 @@ def test_build_prefix_no_cache_metadata(self, mock_get_ascend_config,
         "vllm_ascend.attention.mla_v1.get_decode_context_model_parallel_world_size"
     )
     @patch("vllm_ascend.attention.mla_v1.get_ascend_config")
-    def test_build_chunked_prefix_metadata(self, mock_get_ascend_config,
+    @patch("vllm_ascend.attention.mla_v1.torch.zeros", wraps=torch.zeros)
+    @patch("torch.Tensor.npu", new=lambda self: self)
+    @patch("torch.npu.is_available")
+    def test_build_chunked_prefix_metadata(self, mock_npu_available,
+                                           mock_zeros, mock_get_ascend_config,
                                            mock_dcp_world_size):
-        if not torch.npu.is_available():
-            self.skipTest("NPU not available, skipping NPU-dependent tests")
+        mock_npu_available.return_value = False
         mock_dcp_world_size.return_value = 1
 
+        def zeros_override(*args, **kwargs):
+            kwargs.pop('pin_memory', None)
+            return mock_zeros._mock_wraps(*args, **kwargs)
+
+        mock_zeros.side_effect = zeros_override
+
         common_attn_metadata = AscendCommonAttentionMetadata(
             query_start_loc=torch.tensor([0, 2, 5, 9]),
             query_start_loc_cpu=torch.tensor([0, 2, 5, 9]),
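
Note for reviewers: instead of skipping when no NPU is present, these tests now stub the device checks (`torch.npu.is_available`, `torch.Tensor.npu`) and wrap `torch.zeros` so the `pin_memory=True` kwarg is dropped, since pinning host memory fails without an accelerator runtime. Below is a minimal self-contained sketch of the same wraps-plus-side_effect pattern; `build_buffer` and `TestBuildBuffer` are hypothetical names, not part of the patched module.

import torch
from unittest import TestCase, mock


def build_buffer(n: int) -> torch.Tensor:
    # Stand-in for library code that pins host memory for fast H2D copies;
    # pin_memory=True raises on CPU-only builds of torch.
    return torch.zeros(n, dtype=torch.int32, pin_memory=True)


class TestBuildBuffer(TestCase):

    # wraps= keeps the real torch.zeros reachable through the mock.
    @mock.patch("torch.zeros", wraps=torch.zeros)
    def test_runs_without_accelerator(self, mock_zeros):
        def zeros_override(*args, **kwargs):
            kwargs.pop("pin_memory", None)  # strip the device-dependent kwarg
            # _mock_wraps holds the callable passed to wraps=, i.e. the real torch.zeros
            return mock_zeros._mock_wraps(*args, **kwargs)

        # side_effect takes precedence over wraps, so every call goes
        # through zeros_override but still produces a real tensor.
        mock_zeros.side_effect = zeros_override
        buf = build_buffer(8)
        self.assertEqual(buf.shape, (8,))
        mock_zeros.assert_called_once()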

tests/ut/compilation/test_acl_graph.py

Lines changed: 3 additions & 3 deletions
@@ -32,7 +32,7 @@ def test_aclgraph_entry_initialization(self):
         """Test ACLGraphEntry initialization with default values"""
         batch_descriptor = BatchDescriptor(
             num_tokens=30,
-            uniform_decode=False,
+            uniform=False,
         )
 
         entry = ACLGraphEntry(batch_descriptor=batch_descriptor)
@@ -46,7 +46,7 @@ def test_aclgraph_entry_with_values(self):
         """Test ACLGraphEntry initialization with specified values"""
         batch_descriptor = BatchDescriptor(
             num_tokens=30,
-            uniform_decode=False,
+            uniform=False,
         )
 
         mock_graph = MagicMock()
@@ -89,7 +89,7 @@ def setUp(self):
         # Mock BatchDescriptor
         self.mock_batch_descriptor = BatchDescriptor(
             num_tokens=30,
-            uniform_decode=False,
+            uniform=False,
        )
 
         # Mock ForwardContext
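
The only change in this file tracks an upstream rename of the `BatchDescriptor` keyword from `uniform_decode` to `uniform`. A stand-in sketch of the new keyword (the real class lives upstream in vLLM; this local NamedTuple is illustration only, not the actual definition):

from typing import NamedTuple


class BatchDescriptor(NamedTuple):
    # Simplified stand-in for vLLM's BatchDescriptor.
    num_tokens: int
    uniform: bool = False  # formerly `uniform_decode`


desc = BatchDescriptor(num_tokens=30, uniform=False)
assert desc.num_tokens == 30 and desc.uniform is False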

tests/ut/core/test_scheduler.py

Lines changed: 29 additions & 43 deletions
@@ -81,9 +81,7 @@ def make_output(scheduler):
         req.request_id: i
         for i, req in enumerate(scheduler.running)
     }
-    sampled_token_ids = [
-        np.array([1000], dtype=np.int64) for _ in scheduler.running
-    ]
+    sampled_token_ids = [[1000]] * len(scheduler.running)
 
     logprobs = None
 
@@ -372,8 +370,7 @@ def test_stop_via_update_from_output(self):
             req.request_id: i
             for i, req in enumerate(requests)
         },
-        sampled_token_ids=[np.array([EOS_TOKEN_ID]),
-                           np.array([10, 11])
+        sampled_token_ids=[[EOS_TOKEN_ID], [10, 11]
                            ],  # First request hits EOS, second continues
         logprobs=None,
         prompt_logprobs_dict={},
@@ -424,9 +421,8 @@ def test_stop_via_update_from_output(self):
             req.request_id: i
             for i, req in enumerate(requests)
         },
-        sampled_token_ids=[np.array([10, 42, 12]),
-                           np.array([13, 14])
-                           ],  # First request hits stop token
+        sampled_token_ids=[[10, 42, 12],
+                           [13, 14]],  # First request hits stop token
         logprobs=None,
         prompt_logprobs_dict={},
         pooler_output=[])
@@ -475,9 +471,8 @@ def test_stop_via_update_from_output(self):
             req.request_id: i
             for i, req in enumerate(requests)
         },
-        sampled_token_ids=[np.array([10, 11, 12]),
-                           np.array([13])
-                           ],  # First request exceeds max_tokens
+        sampled_token_ids=[[10, 11, 12],
+                           [13]],  # First request exceeds max_tokens
         logprobs=None,
         prompt_logprobs_dict={},
         pooler_output=[])
@@ -516,7 +511,7 @@ def test_stop_via_update_from_output(self):
     model_output = ModelRunnerOutput(
         req_ids=[requests[0].request_id],
         req_id_to_index={requests[0].request_id: 0},
-        sampled_token_ids=[np.array([EOS_TOKEN_ID, 10, 11])],
+        sampled_token_ids=[[EOS_TOKEN_ID, 10, 11]],
         logprobs=None,
         prompt_logprobs_dict={},
         pooler_output=[])
@@ -573,7 +568,7 @@ def test_schedule_concurrent_batches(self):
     model_runner_output = ModelRunnerOutput(
         req_ids=[requests[0].request_id],
         req_id_to_index={requests[0].request_id: 0},
-        sampled_token_ids=[np.array([0], dtype=np.int64)],
+        sampled_token_ids=[[0]],
         logprobs=None,
         prompt_logprobs_dict={},
         pooler_output=[])
@@ -589,7 +584,7 @@ def test_schedule_concurrent_batches(self):
     model_runner_output = ModelRunnerOutput(
         req_ids=[requests[1].request_id],
         req_id_to_index={requests[1].request_id: 0},
-        sampled_token_ids=[np.array([0], dtype=np.int64)],
+        sampled_token_ids=[[0]],
         logprobs=None,
         prompt_logprobs_dict={},
         pooler_output=[])
@@ -607,12 +602,10 @@ def test_schedule_spec_decoding_stats(self):
     spec_tokens_list: List[List[List[int]]] = [[[1, 2, 3]], [[1, 2, 3]],
                                                [[1, 2], [3]], [[1]], [[]],
                                                [[1, 2, 3], [4, 5, 6]]]
-    output_tokens_list: List[List[List[int]]] = [
-        [np.array([1, 2, 3, 4])], [np.array([1, 5])],
-        [np.array([1, 2, 5]), np.array([3, 4])], [np.array([1, 2])],
-        [np.array([5])], [np.array([1, 2, 7]),
-                          np.array([4, 8])]
-    ]
+    output_tokens_list: List[List[List[int]]] = [[[1, 2, 3, 4]], [[1, 5]],
+                                                 [[1, 2, 5], [3, 4]],
+                                                 [[1, 2]], [[5]],
+                                                 [[1, 2, 7], [4, 8]]]
     expected_list: List[Tuple[int, int,
                               int, List[int]]] = [(1, 3, 3, [1, 1, 1]),
                                                   (1, 3, 1, [1, 0, 0]),
@@ -650,9 +643,7 @@ def test_schedule_spec_decoding_stats(self):
     model_runner_output = ModelRunnerOutput(
         req_ids=req_ids,
         req_id_to_index=req_to_index,
-        sampled_token_ids=[
-            np.array([0]) for _ in range(len(requests))
-        ],
+        sampled_token_ids=[[0] for _ in range(len(requests))],
         logprobs=None,
         prompt_logprobs_dict={},
         pooler_output=[])
@@ -892,11 +883,13 @@ def create_scheduler(self, mock_compute_encoder_budget):
                                          torch.float32, False))
         ],
     )
+    kv_cache_config.hash_block_size = block_size
     cache_config.num_gpu_blocks = 10000
 
     scheduler = SchedulerDynamicBatch(
         vllm_config=vllm_config,
         kv_cache_config=kv_cache_config,
+        block_size=block_size,
         log_stats=True,
         structured_output_manager=MagicMock(spec=StructuredOutputManager),
     )
@@ -1064,8 +1057,7 @@ def test_stop_via_update_from_output(self):
             req.request_id: i
             for i, req in enumerate(requests)
         },
-        sampled_token_ids=[np.array([EOS_TOKEN_ID]),
-                           np.array([10, 11])
+        sampled_token_ids=[[EOS_TOKEN_ID], [10, 11]
                            ],  # First request hits EOS, second continues
         logprobs=None,
         prompt_logprobs_dict={},
@@ -1116,9 +1108,8 @@ def test_stop_via_update_from_output(self):
             req.request_id: i
             for i, req in enumerate(requests)
         },
-        sampled_token_ids=[np.array([10, 42, 12]),
-                           np.array([13, 14])
-                           ],  # First request hits stop token
+        sampled_token_ids=[[10, 42, 12],
+                           [13, 14]],  # First request hits stop token
         logprobs=None,
         prompt_logprobs_dict={},
         pooler_output=[])
@@ -1167,9 +1158,8 @@ def test_stop_via_update_from_output(self):
             req.request_id: i
             for i, req in enumerate(requests)
         },
-        sampled_token_ids=[np.array([10, 11, 12]),
-                           np.array([13])
-                           ],  # First request exceeds max_tokens
+        sampled_token_ids=[[10, 11, 12],
+                           [13]],  # First request exceeds max_tokens
         logprobs=None,
         prompt_logprobs_dict={},
         pooler_output=[])
@@ -1208,7 +1198,7 @@ def test_stop_via_update_from_output(self):
     model_output = ModelRunnerOutput(
         req_ids=[requests[0].request_id],
         req_id_to_index={requests[0].request_id: 0},
-        sampled_token_ids=[np.array([EOS_TOKEN_ID, 10, 11])],
+        sampled_token_ids=[[EOS_TOKEN_ID, 10, 11]],
        logprobs=None,
         prompt_logprobs_dict={},
         pooler_output=[])
@@ -1265,7 +1255,7 @@ def test_schedule_concurrent_batches(self):
     model_runner_output = ModelRunnerOutput(
         req_ids=[requests[0].request_id],
         req_id_to_index={requests[0].request_id: 0},
-        sampled_token_ids=[np.array([0])],
+        sampled_token_ids=[[0]],
         logprobs=None,
         prompt_logprobs_dict={},
         pooler_output=[])
@@ -1281,7 +1271,7 @@ def test_schedule_concurrent_batches(self):
     model_runner_output = ModelRunnerOutput(
         req_ids=[requests[1].request_id],
         req_id_to_index={requests[1].request_id: 0},
-        sampled_token_ids=[np.array([0])],
+        sampled_token_ids=[[0]],
         logprobs=None,
         prompt_logprobs_dict={},
         pooler_output=[])
@@ -1299,12 +1289,10 @@ def test_schedule_spec_decoding_stats(self):
     spec_tokens_list: List[List[List[int]]] = [[[1, 2, 3]], [[1, 2, 3]],
                                                [[1, 2], [3]], [[1]], [[]],
                                                [[1, 2, 3], [4, 5, 6]]]
-    output_tokens_list: List[List[List[int]]] = [
-        [np.array([1, 2, 3, 4])], [np.array([1, 5])],
-        [np.array([1, 2, 5]), np.array([3, 4])], [np.array([1, 2])],
-        [np.array([5])], [np.array([1, 2, 7]),
-                          np.array([4, 8])]
-    ]
+    output_tokens_list: List[List[List[int]]] = [[[1, 2, 3, 4]], [[1, 5]],
+                                                 [[1, 2, 5], [3, 4]],
+                                                 [[1, 2]], [[5]],
+                                                 [[1, 2, 7], [4, 8]]]
     expected_list: List[Tuple[int, int,
                               int, List[int]]] = [(1, 3, 3, [1, 1, 1]),
                                                   (1, 3, 1, [1, 0, 0]),
@@ -1342,9 +1330,7 @@ def test_schedule_spec_decoding_stats(self):
     model_runner_output = ModelRunnerOutput(
         req_ids=req_ids,
         req_id_to_index=req_to_index,
-        sampled_token_ids=[
-            np.array([0]) for _ in range(len(requests))
-        ],
+        sampled_token_ids=[[0] for _ in range(len(requests))],
         logprobs=None,
         prompt_logprobs_dict={},
         pooler_output=[])
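
The recurring change in this file swaps `sampled_token_ids` from lists of `np.ndarray` to plain `list[list[int]]`, one inner list per running request. A hedged sketch of the shape the tests now construct; this `ModelRunnerOutput` dataclass is a simplified stand-in for vLLM's class, and the `EOS_TOKEN_ID` value is hypothetical:

from dataclasses import dataclass, field
from typing import Dict, List, Optional

EOS_TOKEN_ID = 50256  # hypothetical EOS id, for illustration only


@dataclass
class ModelRunnerOutput:
    # Simplified stand-in mirroring the fields used by the tests above.
    req_ids: List[str]
    req_id_to_index: Dict[str, int]
    sampled_token_ids: List[List[int]]  # was List[np.ndarray]
    logprobs: Optional[object] = None
    prompt_logprobs_dict: Dict[str, object] = field(default_factory=dict)
    pooler_output: List[object] = field(default_factory=list)


# First request emits EOS and stops; second continues with two tokens.
out = ModelRunnerOutput(
    req_ids=["req-0", "req-1"],
    req_id_to_index={"req-0": 0, "req-1": 1},
    sampled_token_ids=[[EOS_TOKEN_ID], [10, 11]],
)
assert all(isinstance(t, int) for ids in out.sampled_token_ids for t in ids)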
