33import dataclasses
44from unittest .mock import Mock
55
6+ import numpy as np
67import pytest
78import torch
89
@@ -169,7 +170,7 @@ def test_schedule_partial_requests():
169170 req_id_to_index = req_to_index ,
170171 # Only the first request has a sampled token id because
171172 # the rest requests are still being prefilled.
172- sampled_token_ids = [[0 ], [], [] ],
173+ sampled_token_ids = [np . array ( [0 ]), np . array ([]), np . array ([]) ],
173174 logprobs = None ,
174175 prompt_logprobs_dict = {},
175176 pooler_output = [],
@@ -216,7 +217,7 @@ def test_no_mm_input_chunking():
216217 model_runner_output = ModelRunnerOutput (
217218 req_ids = [request .request_id for request in requests ],
218219 req_id_to_index = req_to_index ,
219- sampled_token_ids = [[] for _ in range (len (requests ))],
220+ sampled_token_ids = [np . array ([]) for _ in range (len (requests ))],
220221 logprobs = None ,
221222 prompt_logprobs_dict = {},
222223 pooler_output = [],
@@ -276,7 +277,7 @@ def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool):
276277 model_runner_output = ModelRunnerOutput (
277278 req_ids = [request .request_id for request in requests ],
278279 req_id_to_index = req_to_index ,
279- sampled_token_ids = [[] for _ in range (len (requests ))],
280+ sampled_token_ids = [np . array ([]) for _ in range (len (requests ))],
280281 logprobs = None ,
281282 prompt_logprobs_dict = {},
282283 pooler_output = [],
@@ -300,7 +301,8 @@ def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool):
300301 model_runner_output = ModelRunnerOutput (
301302 req_ids = [request .request_id for request in requests ],
302303 req_id_to_index = req_to_index ,
303- sampled_token_ids = [[0 ], [0 ]] + [[] for _ in range (len (requests ) - 2 )],
304+ sampled_token_ids = [np .array ([0 ]), np .array ([0 ])]
305+ + [np .array ([]) for _ in range (len (requests ) - 2 )],
304306 logprobs = None ,
305307 prompt_logprobs_dict = {},
306308 pooler_output = [],
@@ -347,8 +349,8 @@ def test_stop_via_update_from_output():
347349 req_ids = [req .request_id for req in requests ],
348350 req_id_to_index = {req .request_id : i for i , req in enumerate (requests )},
349351 sampled_token_ids = [
350- [EOS_TOKEN_ID ],
351- [10 , 11 ],
352+ np . array ( [EOS_TOKEN_ID ]) ,
353+ np . array ( [10 , 11 ]) ,
352354 ], # First request hits EOS, second continues
353355 logprobs = None ,
354356 prompt_logprobs_dict = {},
@@ -392,7 +394,10 @@ def test_stop_via_update_from_output():
392394 model_output = ModelRunnerOutput (
393395 req_ids = [req .request_id for req in requests ],
394396 req_id_to_index = {req .request_id : i for i , req in enumerate (requests )},
395- sampled_token_ids = [[10 , 42 , 12 ], [13 , 14 ]], # First request hits stop token
397+ sampled_token_ids = [
398+ np .array ([10 , 42 , 12 ]),
399+ np .array ([13 , 14 ]),
400+ ], # First request hits stop token
396401 logprobs = None ,
397402 prompt_logprobs_dict = {},
398403 pooler_output = [],
@@ -436,7 +441,10 @@ def test_stop_via_update_from_output():
436441 model_output = ModelRunnerOutput (
437442 req_ids = [req .request_id for req in requests ],
438443 req_id_to_index = {req .request_id : i for i , req in enumerate (requests )},
439- sampled_token_ids = [[10 , 11 , 12 ], [13 ]], # First request exceeds max_tokens
444+ sampled_token_ids = [
445+ np .array ([10 , 11 , 12 ]),
446+ np .array ([13 ]),
447+ ], # First request exceeds max_tokens
440448 logprobs = None ,
441449 prompt_logprobs_dict = {},
442450 pooler_output = [],
@@ -475,7 +483,7 @@ def test_stop_via_update_from_output():
475483 model_output = ModelRunnerOutput (
476484 req_ids = [requests [0 ].request_id ],
477485 req_id_to_index = {requests [0 ].request_id : 0 },
478- sampled_token_ids = [[EOS_TOKEN_ID , 10 , 11 ]],
486+ sampled_token_ids = [np . array ( [EOS_TOKEN_ID , 10 , 11 ]) ],
479487 logprobs = None ,
480488 prompt_logprobs_dict = {},
481489 pooler_output = [],
@@ -616,7 +624,7 @@ def test_schedule_concurrent_batches(
616624 model_runner_output = ModelRunnerOutput (
617625 req_ids = [requests [0 ].request_id ],
618626 req_id_to_index = {requests [0 ].request_id : 0 },
619- sampled_token_ids = [[0 ]],
627+ sampled_token_ids = [np . array ( [0 ]) ],
620628 logprobs = None ,
621629 prompt_logprobs_dict = {},
622630 pooler_output = [],
@@ -633,7 +641,7 @@ def test_schedule_concurrent_batches(
633641 model_runner_output = ModelRunnerOutput (
634642 req_ids = [requests [1 ].request_id ],
635643 req_id_to_index = {requests [1 ].request_id : 0 },
636- sampled_token_ids = [[0 ]],
644+ sampled_token_ids = [np . array ( [0 ]) ],
637645 logprobs = None ,
638646 prompt_logprobs_dict = {},
639647 pooler_output = [],
@@ -670,7 +678,7 @@ def test_preempt_during_execution():
670678 model_runner_output0 = ModelRunnerOutput (
671679 req_ids = [requests [0 ].request_id ],
672680 req_id_to_index = {requests [0 ].request_id : 0 },
673- sampled_token_ids = [[0 ]],
681+ sampled_token_ids = [np . array ( [0 ]) ],
674682 logprobs = None ,
675683 prompt_logprobs_dict = {},
676684 pooler_output = [],
@@ -687,7 +695,7 @@ def test_preempt_during_execution():
687695 model_runner_output1 = ModelRunnerOutput (
688696 req_ids = [requests [1 ].request_id ],
689697 req_id_to_index = {requests [1 ].request_id : 0 },
690- sampled_token_ids = [[42 ]],
698+ sampled_token_ids = [np . array ( [42 ]) ],
691699 logprobs = None ,
692700 prompt_logprobs_dict = {},
693701 pooler_output = [],
@@ -704,14 +712,18 @@ def test_preempt_during_execution():
704712@pytest .mark .parametrize (
705713 "spec_tokens,output_tokens,expected" ,
706714 [
707- ([[1 , 2 , 3 ]], [[1 , 2 , 3 , 4 ]], (1 , 3 , 3 , [1 , 1 , 1 ])), # perfect match
708- ([[1 , 2 , 3 ]], [[1 , 5 ]], (1 , 3 , 1 , [1 , 0 , 0 ])), # early mismatch
709- ([[1 , 2 ], [3 ]], [[1 , 2 , 5 ], [3 , 4 ]], (2 , 3 , 3 , [2 , 1 ])), # multiple sequences
710- ([[1 ]], [[1 , 2 ]], (1 , 1 , 1 , [1 ])), # single token sequence
711- ([[]], [[5 ]], (0 , 0 , 0 , [0 ])), # empty sequence
715+ ([[1 , 2 , 3 ]], [np .array ([1 , 2 , 3 , 4 ])], (1 , 3 , 3 , [1 , 1 , 1 ])), # perfect match
716+ ([[1 , 2 , 3 ]], [np .array ([1 , 5 ])], (1 , 3 , 1 , [1 , 0 , 0 ])), # early mismatch
717+ (
718+ [[1 , 2 ], [3 ]],
719+ [np .array ([1 , 2 , 5 ]), np .array ([3 , 4 ])],
720+ (2 , 3 , 3 , [2 , 1 ]),
721+ ), # multiple sequences
722+ ([[1 ]], [np .array ([1 , 2 ])], (1 , 1 , 1 , [1 ])), # single token sequence
723+ ([[]], [np .array ([5 ])], (0 , 0 , 0 , [0 ])), # empty sequence
712724 (
713725 [[1 , 2 , 3 ], [4 , 5 , 6 ]],
714- [[1 , 2 , 7 ], [4 , 8 ]],
726+ [np . array ( [1 , 2 , 7 ]), np . array ( [4 , 8 ]) ],
715727 (2 , 6 , 3 , [2 , 1 , 0 ]),
716728 ), # multiple mismatches
717729 ],
@@ -745,7 +757,7 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected):
745757 model_runner_output = ModelRunnerOutput (
746758 req_ids = req_ids ,
747759 req_id_to_index = req_to_index ,
748- sampled_token_ids = [[0 ] for _ in range (len (requests ))],
760+ sampled_token_ids = [np . array ( [0 ]) for _ in range (len (requests ))],
749761 logprobs = None ,
750762 prompt_logprobs_dict = {},
751763 pooler_output = [],
@@ -972,7 +984,7 @@ def test_kv_connector_basic(is_async: bool):
972984 MODEL_RUNNER_OUTPUT = ModelRunnerOutput (
973985 req_ids = req_ids ,
974986 req_id_to_index = req_to_index ,
975- sampled_token_ids = [[1000 ]] * len (req_ids ),
987+ sampled_token_ids = [np . array ( [1000 ]) ] * len (req_ids ),
976988 logprobs = None ,
977989 prompt_logprobs_dict = {},
978990 pooler_output = [],
@@ -1025,7 +1037,7 @@ def test_kv_connector_basic(is_async: bool):
10251037 MODEL_RUNNER_OUTPUT = ModelRunnerOutput (
10261038 req_ids = req_ids ,
10271039 req_id_to_index = req_to_index ,
1028- sampled_token_ids = [[1000 ]] * len (req_ids ),
1040+ sampled_token_ids = [np . array ( [1000 ]) ] * len (req_ids ),
10291041 logprobs = None ,
10301042 prompt_logprobs_dict = {},
10311043 pooler_output = [],
@@ -1088,7 +1100,7 @@ def test_external_prefix_cache_metrics():
10881100 MODEL_RUNNER_OUTPUT = ModelRunnerOutput (
10891101 req_ids = [r .request_id for r in requests ],
10901102 req_id_to_index = {r .request_id : i for i , r in enumerate (requests )},
1091- sampled_token_ids = [[1000 ]] * NUM_REQUESTS ,
1103+ sampled_token_ids = [np . array ( [1000 ]) ] * NUM_REQUESTS ,
10921104 logprobs = None ,
10931105 prompt_logprobs_dict = {},
10941106 pooler_output = [],
@@ -1154,7 +1166,7 @@ def test_kv_connector_unable_to_allocate(use_ec_connector, ec_role):
11541166 MODEL_RUNNER_OUTPUT = ModelRunnerOutput (
11551167 req_ids = req_ids ,
11561168 req_id_to_index = req_to_index ,
1157- sampled_token_ids = [[1000 ]] * len (req_ids ),
1169+ sampled_token_ids = [np . array ( [1000 ]) ] * len (req_ids ),
11581170 logprobs = None ,
11591171 prompt_logprobs_dict = {},
11601172 pooler_output = [],
@@ -1239,7 +1251,7 @@ def test_kv_connector_handles_preemption(use_ec_connector, ec_role):
12391251 MODEL_RUNNER_OUTPUT = ModelRunnerOutput (
12401252 req_ids = req_ids ,
12411253 req_id_to_index = req_to_index ,
1242- sampled_token_ids = [[1000 ]] * len (req_ids ),
1254+ sampled_token_ids = [np . array ( [1000 ]) ] * len (req_ids ),
12431255 logprobs = None ,
12441256 prompt_logprobs_dict = {},
12451257 pooler_output = [],
@@ -1332,7 +1344,7 @@ def make_output(scheduler: Scheduler):
13321344 return ModelRunnerOutput (
13331345 req_ids = [req .request_id for req in scheduler .running ],
13341346 req_id_to_index = {req .request_id : i for i , req in enumerate (scheduler .running )},
1335- sampled_token_ids = [[1000 ]] * len (scheduler .running ),
1347+ sampled_token_ids = [np . array ( [1000 ]) ] * len (scheduler .running ),
13361348 logprobs = None ,
13371349 prompt_logprobs_dict = {},
13381350 pooler_output = [],
@@ -1749,7 +1761,7 @@ def test_priority_scheduling_preemption():
17491761 req_id_to_index = {
17501762 req .request_id : i for i , req in enumerate (low_priority_requests )
17511763 },
1752- sampled_token_ids = [[100 ] for _ in low_priority_requests ],
1764+ sampled_token_ids = [np . array ( [100 ]) for _ in low_priority_requests ],
17531765 logprobs = None ,
17541766 prompt_logprobs_dict = {},
17551767 pooler_output = [],
@@ -1818,7 +1830,7 @@ def test_priority_scheduling_no_preemption_when_space_available():
18181830 req_id_to_index = {
18191831 req .request_id : i for i , req in enumerate (low_priority_requests )
18201832 },
1821- sampled_token_ids = [[100 ] for _ in low_priority_requests ],
1833+ sampled_token_ids = [np . array ( [100 ]) for _ in low_priority_requests ],
18221834 logprobs = None ,
18231835 prompt_logprobs_dict = {},
18241836 pooler_output = [],
@@ -2064,7 +2076,7 @@ def test_priority_scheduling_heap_property():
20642076 model_output = ModelRunnerOutput (
20652077 req_ids = [req .req_id ],
20662078 req_id_to_index = {req .req_id : 0 },
2067- sampled_token_ids = [[100 ]],
2079+ sampled_token_ids = [np . array ( [100 ]) ],
20682080 logprobs = None ,
20692081 prompt_logprobs_dict = {},
20702082 pooler_output = [],
@@ -2150,7 +2162,7 @@ def test_priority_scheduling_preemption_and_resumption_when_out_of_kv(
21502162 model_output = ModelRunnerOutput (
21512163 req_ids = [request_low .request_id ],
21522164 req_id_to_index = {request_low .request_id : 0 },
2153- sampled_token_ids = [[100 ]],
2165+ sampled_token_ids = [np . array ( [100 ]) ],
21542166 # spec_token_ids=None,
21552167 logprobs = None ,
21562168 prompt_logprobs_dict = {},
@@ -2181,7 +2193,7 @@ def test_priority_scheduling_preemption_and_resumption_when_out_of_kv(
21812193 model_output = ModelRunnerOutput (
21822194 req_ids = [req .request_id for req in requests ],
21832195 req_id_to_index = {req .request_id : i for i , req in enumerate (requests )},
2184- sampled_token_ids = [[100 ] for _ in requests ],
2196+ sampled_token_ids = [np . array ( [100 ]) for _ in requests ],
21852197 # spec_token_ids=None,
21862198 logprobs = None ,
21872199 prompt_logprobs_dict = {},
@@ -2207,7 +2219,7 @@ def test_priority_scheduling_preemption_and_resumption_when_out_of_kv(
22072219 model_output = ModelRunnerOutput (
22082220 req_ids = [req .request_id for req in requests ],
22092221 req_id_to_index = {req .request_id : i for i , req in enumerate (requests )},
2210- sampled_token_ids = [[], [100 ]],
2222+ sampled_token_ids = [np . array ([]), np . array ( [100 ]) ],
22112223 # spec_token_ids=None,
22122224 logprobs = None ,
22132225 prompt_logprobs_dict = {},
@@ -2624,7 +2636,7 @@ def test_ec_connector_with_partial_cache_hit_multi_round(use_kv_connector):
26242636 model_output = ModelRunnerOutput (
26252637 req_ids = [request1 .request_id ],
26262638 req_id_to_index = {request1 .request_id : 0 },
2627- sampled_token_ids = [[100 ]],
2639+ sampled_token_ids = [np . array ( [100 ]) ],
26282640 # spec_token_ids=None,
26292641 logprobs = None ,
26302642 prompt_logprobs_dict = {},
@@ -2830,7 +2842,7 @@ def test_ec_connector_unable_to_allocate(use_kv_connector):
28302842 MODEL_RUNNER_OUTPUT = ModelRunnerOutput (
28312843 req_ids = req_ids ,
28322844 req_id_to_index = req_to_index ,
2833- sampled_token_ids = [[1000 ]] * len (req_ids ),
2845+ sampled_token_ids = [np . array ( [1000 ]) ] * len (req_ids ),
28342846 logprobs = None ,
28352847 prompt_logprobs_dict = {},
28362848 pooler_output = [],
@@ -2943,7 +2955,7 @@ def test_priority_scheduling_ec_connector_preemption_and_resumption(
29432955 model_output = ModelRunnerOutput (
29442956 req_ids = [request_low .request_id ],
29452957 req_id_to_index = {request_low .request_id : 0 },
2946- sampled_token_ids = [[100 ]],
2958+ sampled_token_ids = [np . array ( [100 ]) ],
29472959 # spec_token_ids=None,
29482960 logprobs = None ,
29492961 prompt_logprobs_dict = {},
@@ -2994,7 +3006,7 @@ def test_priority_scheduling_ec_connector_preemption_and_resumption(
29943006 model_output = ModelRunnerOutput (
29953007 req_ids = [req .request_id for req in requests ],
29963008 req_id_to_index = {req .request_id : i for i , req in enumerate (requests )},
2997- sampled_token_ids = [[100 ] for _ in requests ],
3009+ sampled_token_ids = [np . array ( [100 ]) for _ in requests ],
29983010 # spec_token_ids=None,
29993011 logprobs = None ,
30003012 prompt_logprobs_dict = {},
@@ -3029,7 +3041,7 @@ def test_priority_scheduling_ec_connector_preemption_and_resumption(
30293041 model_output = ModelRunnerOutput (
30303042 req_ids = [req .request_id for req in requests ],
30313043 req_id_to_index = {req .request_id : i for i , req in enumerate (requests )},
3032- sampled_token_ids = [[100 ], [100 , 200 ]],
3044+ sampled_token_ids = [np . array ( [100 ]), np . array ( [100 , 200 ]) ],
30333045 # spec_token_ids=None,
30343046 logprobs = None ,
30353047 prompt_logprobs_dict = {},
@@ -3215,7 +3227,7 @@ def test_ec_connector_allocate_encoder_tokens_with_external_load(use_kv_connecto
32153227 model_output = ModelRunnerOutput (
32163228 req_ids = [request1 .request_id , request2 .request_id ],
32173229 req_id_to_index = {request1 .request_id : 0 , request2 .request_id : 1 },
3218- sampled_token_ids = [[100 ], [121 ]],
3230+ sampled_token_ids = [np . array ( [100 ]), np . array ( [121 ]) ],
32193231 # spec_token_ids=None,
32203232 logprobs = None ,
32213233 prompt_logprobs_dict = {},
0 commit comments