Skip to content

Commit 186352b

Browse files
authored
[Core] Performance: Use list[np.ndarray] instead of list[list[int]] for output tokens for GC optimization (#26368)
Signed-off-by: Jialin Ouyang <Jialin.Ouyang@gmail.com>
1 parent 58e61e5 commit 186352b

File tree

12 files changed

+102
-76
lines changed

12 files changed

+102
-76
lines changed

tests/v1/core/test_async_scheduler.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
33
from collections import deque
44

5+
import numpy as np
56
import pytest
67

78
from vllm.v1.core.sched.output import SchedulerOutput
@@ -21,7 +22,7 @@ def _make_model_runner_output(
2122
return ModelRunnerOutput(
2223
req_ids=req_ids,
2324
req_id_to_index={req_id: i for i, req_id in enumerate(req_ids)},
24-
sampled_token_ids=[[i] for i in range(len(req_ids))],
25+
sampled_token_ids=[np.array([i]) for i in range(len(req_ids))],
2526
logprobs=None,
2627
prompt_logprobs_dict={},
2728
pooler_output=[],

tests/v1/core/test_scheduler.py

Lines changed: 44 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import dataclasses
44
from unittest.mock import Mock
55

6+
import numpy as np
67
import pytest
78
import torch
89

@@ -169,7 +170,7 @@ def test_schedule_partial_requests():
169170
req_id_to_index=req_to_index,
170171
# Only the first request has a sampled token id because
171172
# the rest requests are still being prefilled.
172-
sampled_token_ids=[[0], [], []],
173+
sampled_token_ids=[np.array([0]), np.array([]), np.array([])],
173174
logprobs=None,
174175
prompt_logprobs_dict={},
175176
pooler_output=[],
@@ -216,7 +217,7 @@ def test_no_mm_input_chunking():
216217
model_runner_output = ModelRunnerOutput(
217218
req_ids=[request.request_id for request in requests],
218219
req_id_to_index=req_to_index,
219-
sampled_token_ids=[[] for _ in range(len(requests))],
220+
sampled_token_ids=[np.array([]) for _ in range(len(requests))],
220221
logprobs=None,
221222
prompt_logprobs_dict={},
222223
pooler_output=[],
@@ -276,7 +277,7 @@ def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool):
276277
model_runner_output = ModelRunnerOutput(
277278
req_ids=[request.request_id for request in requests],
278279
req_id_to_index=req_to_index,
279-
sampled_token_ids=[[] for _ in range(len(requests))],
280+
sampled_token_ids=[np.array([]) for _ in range(len(requests))],
280281
logprobs=None,
281282
prompt_logprobs_dict={},
282283
pooler_output=[],
@@ -300,7 +301,8 @@ def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool):
300301
model_runner_output = ModelRunnerOutput(
301302
req_ids=[request.request_id for request in requests],
302303
req_id_to_index=req_to_index,
303-
sampled_token_ids=[[0], [0]] + [[] for _ in range(len(requests) - 2)],
304+
sampled_token_ids=[np.array([0]), np.array([0])]
305+
+ [np.array([]) for _ in range(len(requests) - 2)],
304306
logprobs=None,
305307
prompt_logprobs_dict={},
306308
pooler_output=[],
@@ -347,8 +349,8 @@ def test_stop_via_update_from_output():
347349
req_ids=[req.request_id for req in requests],
348350
req_id_to_index={req.request_id: i for i, req in enumerate(requests)},
349351
sampled_token_ids=[
350-
[EOS_TOKEN_ID],
351-
[10, 11],
352+
np.array([EOS_TOKEN_ID]),
353+
np.array([10, 11]),
352354
], # First request hits EOS, second continues
353355
logprobs=None,
354356
prompt_logprobs_dict={},
@@ -392,7 +394,10 @@ def test_stop_via_update_from_output():
392394
model_output = ModelRunnerOutput(
393395
req_ids=[req.request_id for req in requests],
394396
req_id_to_index={req.request_id: i for i, req in enumerate(requests)},
395-
sampled_token_ids=[[10, 42, 12], [13, 14]], # First request hits stop token
397+
sampled_token_ids=[
398+
np.array([10, 42, 12]),
399+
np.array([13, 14]),
400+
], # First request hits stop token
396401
logprobs=None,
397402
prompt_logprobs_dict={},
398403
pooler_output=[],
@@ -436,7 +441,10 @@ def test_stop_via_update_from_output():
436441
model_output = ModelRunnerOutput(
437442
req_ids=[req.request_id for req in requests],
438443
req_id_to_index={req.request_id: i for i, req in enumerate(requests)},
439-
sampled_token_ids=[[10, 11, 12], [13]], # First request exceeds max_tokens
444+
sampled_token_ids=[
445+
np.array([10, 11, 12]),
446+
np.array([13]),
447+
], # First request exceeds max_tokens
440448
logprobs=None,
441449
prompt_logprobs_dict={},
442450
pooler_output=[],
@@ -475,7 +483,7 @@ def test_stop_via_update_from_output():
475483
model_output = ModelRunnerOutput(
476484
req_ids=[requests[0].request_id],
477485
req_id_to_index={requests[0].request_id: 0},
478-
sampled_token_ids=[[EOS_TOKEN_ID, 10, 11]],
486+
sampled_token_ids=[np.array([EOS_TOKEN_ID, 10, 11])],
479487
logprobs=None,
480488
prompt_logprobs_dict={},
481489
pooler_output=[],
@@ -616,7 +624,7 @@ def test_schedule_concurrent_batches(
616624
model_runner_output = ModelRunnerOutput(
617625
req_ids=[requests[0].request_id],
618626
req_id_to_index={requests[0].request_id: 0},
619-
sampled_token_ids=[[0]],
627+
sampled_token_ids=[np.array([0])],
620628
logprobs=None,
621629
prompt_logprobs_dict={},
622630
pooler_output=[],
@@ -633,7 +641,7 @@ def test_schedule_concurrent_batches(
633641
model_runner_output = ModelRunnerOutput(
634642
req_ids=[requests[1].request_id],
635643
req_id_to_index={requests[1].request_id: 0},
636-
sampled_token_ids=[[0]],
644+
sampled_token_ids=[np.array([0])],
637645
logprobs=None,
638646
prompt_logprobs_dict={},
639647
pooler_output=[],
@@ -670,7 +678,7 @@ def test_preempt_during_execution():
670678
model_runner_output0 = ModelRunnerOutput(
671679
req_ids=[requests[0].request_id],
672680
req_id_to_index={requests[0].request_id: 0},
673-
sampled_token_ids=[[0]],
681+
sampled_token_ids=[np.array([0])],
674682
logprobs=None,
675683
prompt_logprobs_dict={},
676684
pooler_output=[],
@@ -687,7 +695,7 @@ def test_preempt_during_execution():
687695
model_runner_output1 = ModelRunnerOutput(
688696
req_ids=[requests[1].request_id],
689697
req_id_to_index={requests[1].request_id: 0},
690-
sampled_token_ids=[[42]],
698+
sampled_token_ids=[np.array([42])],
691699
logprobs=None,
692700
prompt_logprobs_dict={},
693701
pooler_output=[],
@@ -704,14 +712,18 @@ def test_preempt_during_execution():
704712
@pytest.mark.parametrize(
705713
"spec_tokens,output_tokens,expected",
706714
[
707-
([[1, 2, 3]], [[1, 2, 3, 4]], (1, 3, 3, [1, 1, 1])), # perfect match
708-
([[1, 2, 3]], [[1, 5]], (1, 3, 1, [1, 0, 0])), # early mismatch
709-
([[1, 2], [3]], [[1, 2, 5], [3, 4]], (2, 3, 3, [2, 1])), # multiple sequences
710-
([[1]], [[1, 2]], (1, 1, 1, [1])), # single token sequence
711-
([[]], [[5]], (0, 0, 0, [0])), # empty sequence
715+
([[1, 2, 3]], [np.array([1, 2, 3, 4])], (1, 3, 3, [1, 1, 1])), # perfect match
716+
([[1, 2, 3]], [np.array([1, 5])], (1, 3, 1, [1, 0, 0])), # early mismatch
717+
(
718+
[[1, 2], [3]],
719+
[np.array([1, 2, 5]), np.array([3, 4])],
720+
(2, 3, 3, [2, 1]),
721+
), # multiple sequences
722+
([[1]], [np.array([1, 2])], (1, 1, 1, [1])), # single token sequence
723+
([[]], [np.array([5])], (0, 0, 0, [0])), # empty sequence
712724
(
713725
[[1, 2, 3], [4, 5, 6]],
714-
[[1, 2, 7], [4, 8]],
726+
[np.array([1, 2, 7]), np.array([4, 8])],
715727
(2, 6, 3, [2, 1, 0]),
716728
), # multiple mismatches
717729
],
@@ -745,7 +757,7 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected):
745757
model_runner_output = ModelRunnerOutput(
746758
req_ids=req_ids,
747759
req_id_to_index=req_to_index,
748-
sampled_token_ids=[[0] for _ in range(len(requests))],
760+
sampled_token_ids=[np.array([0]) for _ in range(len(requests))],
749761
logprobs=None,
750762
prompt_logprobs_dict={},
751763
pooler_output=[],
@@ -972,7 +984,7 @@ def test_kv_connector_basic(is_async: bool):
972984
MODEL_RUNNER_OUTPUT = ModelRunnerOutput(
973985
req_ids=req_ids,
974986
req_id_to_index=req_to_index,
975-
sampled_token_ids=[[1000]] * len(req_ids),
987+
sampled_token_ids=[np.array([1000])] * len(req_ids),
976988
logprobs=None,
977989
prompt_logprobs_dict={},
978990
pooler_output=[],
@@ -1025,7 +1037,7 @@ def test_kv_connector_basic(is_async: bool):
10251037
MODEL_RUNNER_OUTPUT = ModelRunnerOutput(
10261038
req_ids=req_ids,
10271039
req_id_to_index=req_to_index,
1028-
sampled_token_ids=[[1000]] * len(req_ids),
1040+
sampled_token_ids=[np.array([1000])] * len(req_ids),
10291041
logprobs=None,
10301042
prompt_logprobs_dict={},
10311043
pooler_output=[],
@@ -1088,7 +1100,7 @@ def test_external_prefix_cache_metrics():
10881100
MODEL_RUNNER_OUTPUT = ModelRunnerOutput(
10891101
req_ids=[r.request_id for r in requests],
10901102
req_id_to_index={r.request_id: i for i, r in enumerate(requests)},
1091-
sampled_token_ids=[[1000]] * NUM_REQUESTS,
1103+
sampled_token_ids=[np.array([1000])] * NUM_REQUESTS,
10921104
logprobs=None,
10931105
prompt_logprobs_dict={},
10941106
pooler_output=[],
@@ -1154,7 +1166,7 @@ def test_kv_connector_unable_to_allocate(use_ec_connector, ec_role):
11541166
MODEL_RUNNER_OUTPUT = ModelRunnerOutput(
11551167
req_ids=req_ids,
11561168
req_id_to_index=req_to_index,
1157-
sampled_token_ids=[[1000]] * len(req_ids),
1169+
sampled_token_ids=[np.array([1000])] * len(req_ids),
11581170
logprobs=None,
11591171
prompt_logprobs_dict={},
11601172
pooler_output=[],
@@ -1239,7 +1251,7 @@ def test_kv_connector_handles_preemption(use_ec_connector, ec_role):
12391251
MODEL_RUNNER_OUTPUT = ModelRunnerOutput(
12401252
req_ids=req_ids,
12411253
req_id_to_index=req_to_index,
1242-
sampled_token_ids=[[1000]] * len(req_ids),
1254+
sampled_token_ids=[np.array([1000])] * len(req_ids),
12431255
logprobs=None,
12441256
prompt_logprobs_dict={},
12451257
pooler_output=[],
@@ -1332,7 +1344,7 @@ def make_output(scheduler: Scheduler):
13321344
return ModelRunnerOutput(
13331345
req_ids=[req.request_id for req in scheduler.running],
13341346
req_id_to_index={req.request_id: i for i, req in enumerate(scheduler.running)},
1335-
sampled_token_ids=[[1000]] * len(scheduler.running),
1347+
sampled_token_ids=[np.array([1000])] * len(scheduler.running),
13361348
logprobs=None,
13371349
prompt_logprobs_dict={},
13381350
pooler_output=[],
@@ -1749,7 +1761,7 @@ def test_priority_scheduling_preemption():
17491761
req_id_to_index={
17501762
req.request_id: i for i, req in enumerate(low_priority_requests)
17511763
},
1752-
sampled_token_ids=[[100] for _ in low_priority_requests],
1764+
sampled_token_ids=[np.array([100]) for _ in low_priority_requests],
17531765
logprobs=None,
17541766
prompt_logprobs_dict={},
17551767
pooler_output=[],
@@ -1818,7 +1830,7 @@ def test_priority_scheduling_no_preemption_when_space_available():
18181830
req_id_to_index={
18191831
req.request_id: i for i, req in enumerate(low_priority_requests)
18201832
},
1821-
sampled_token_ids=[[100] for _ in low_priority_requests],
1833+
sampled_token_ids=[np.array([100]) for _ in low_priority_requests],
18221834
logprobs=None,
18231835
prompt_logprobs_dict={},
18241836
pooler_output=[],
@@ -2064,7 +2076,7 @@ def test_priority_scheduling_heap_property():
20642076
model_output = ModelRunnerOutput(
20652077
req_ids=[req.req_id],
20662078
req_id_to_index={req.req_id: 0},
2067-
sampled_token_ids=[[100]],
2079+
sampled_token_ids=[np.array([100])],
20682080
logprobs=None,
20692081
prompt_logprobs_dict={},
20702082
pooler_output=[],
@@ -2150,7 +2162,7 @@ def test_priority_scheduling_preemption_and_resumption_when_out_of_kv(
21502162
model_output = ModelRunnerOutput(
21512163
req_ids=[request_low.request_id],
21522164
req_id_to_index={request_low.request_id: 0},
2153-
sampled_token_ids=[[100]],
2165+
sampled_token_ids=[np.array([100])],
21542166
# spec_token_ids=None,
21552167
logprobs=None,
21562168
prompt_logprobs_dict={},
@@ -2181,7 +2193,7 @@ def test_priority_scheduling_preemption_and_resumption_when_out_of_kv(
21812193
model_output = ModelRunnerOutput(
21822194
req_ids=[req.request_id for req in requests],
21832195
req_id_to_index={req.request_id: i for i, req in enumerate(requests)},
2184-
sampled_token_ids=[[100] for _ in requests],
2196+
sampled_token_ids=[np.array([100]) for _ in requests],
21852197
# spec_token_ids=None,
21862198
logprobs=None,
21872199
prompt_logprobs_dict={},
@@ -2207,7 +2219,7 @@ def test_priority_scheduling_preemption_and_resumption_when_out_of_kv(
22072219
model_output = ModelRunnerOutput(
22082220
req_ids=[req.request_id for req in requests],
22092221
req_id_to_index={req.request_id: i for i, req in enumerate(requests)},
2210-
sampled_token_ids=[[], [100]],
2222+
sampled_token_ids=[np.array([]), np.array([100])],
22112223
# spec_token_ids=None,
22122224
logprobs=None,
22132225
prompt_logprobs_dict={},

tests/v1/kv_connector/unit/utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from itertools import chain, count
88
from typing import Any
99

10+
import numpy as np
1011
import torch
1112

1213
from vllm import SamplingParams
@@ -228,7 +229,7 @@ def create_model_runner_output(
228229

229230
# Make sampled tokens.
230231
sampled_token = EOS_TOKEN_ID if use_eos else token_id
231-
sampled_token_ids = [[sampled_token] for _ in req_ids]
232+
sampled_token_ids = [np.array([sampled_token]) for _ in req_ids]
232233

233234
kv_connector_output = (
234235
None

tests/v1/spec_decode/test_eagle.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
from unittest import mock
55

6+
import numpy as np
67
import pytest
78
import torch
89

@@ -112,7 +113,9 @@ def test_prepare_next_token_ids():
112113
sampled_token_ids_tensor = torch.tensor(
113114
sampled_token_ids, dtype=torch.int32, device=device
114115
)
115-
sampled_token_ids_cpu = [[i for i in seq if i != -1] for seq in sampled_token_ids]
116+
sampled_token_ids_cpu = [
117+
np.array([i for i in seq if i != -1]) for seq in sampled_token_ids
118+
]
116119

117120
expected_next_token_ids_cpu = [1, 4, 30, 40]
118121
expected_next_token_ids_tensor = torch.tensor(

0 commit comments

Comments
 (0)