@@ -81,9 +81,7 @@ def make_output(scheduler):
8181 req .request_id : i
8282 for i , req in enumerate (scheduler .running )
8383 }
84- sampled_token_ids = [
85- np .array ([1000 ], dtype = np .int64 ) for _ in scheduler .running
86- ]
84+ sampled_token_ids = [[1000 ]] * len (scheduler .running )
8785
8886 logprobs = None
8987
@@ -372,8 +370,7 @@ def test_stop_via_update_from_output(self):
372370 req .request_id : i
373371 for i , req in enumerate (requests )
374372 },
375- sampled_token_ids = [np .array ([EOS_TOKEN_ID ]),
376- np .array ([10 , 11 ])
373+ sampled_token_ids = [[EOS_TOKEN_ID ], [10 , 11 ]
377374 ], # First request hits EOS, second continues
378375 logprobs = None ,
379376 prompt_logprobs_dict = {},
@@ -424,9 +421,8 @@ def test_stop_via_update_from_output(self):
424421 req .request_id : i
425422 for i , req in enumerate (requests )
426423 },
427- sampled_token_ids = [np .array ([10 , 42 , 12 ]),
428- np .array ([13 , 14 ])
429- ], # First request hits stop token
424+ sampled_token_ids = [[10 , 42 , 12 ],
425+ [13 , 14 ]], # First request hits stop token
430426 logprobs = None ,
431427 prompt_logprobs_dict = {},
432428 pooler_output = [])
@@ -475,9 +471,8 @@ def test_stop_via_update_from_output(self):
475471 req .request_id : i
476472 for i , req in enumerate (requests )
477473 },
478- sampled_token_ids = [np .array ([10 , 11 , 12 ]),
479- np .array ([13 ])
480- ], # First request exceeds max_tokens
474+ sampled_token_ids = [[10 , 11 , 12 ],
475+ [13 ]], # First request exceeds max_tokens
481476 logprobs = None ,
482477 prompt_logprobs_dict = {},
483478 pooler_output = [])
@@ -516,7 +511,7 @@ def test_stop_via_update_from_output(self):
516511 model_output = ModelRunnerOutput (
517512 req_ids = [requests [0 ].request_id ],
518513 req_id_to_index = {requests [0 ].request_id : 0 },
519- sampled_token_ids = [np . array ( [EOS_TOKEN_ID , 10 , 11 ]) ],
514+ sampled_token_ids = [[EOS_TOKEN_ID , 10 , 11 ]],
520515 logprobs = None ,
521516 prompt_logprobs_dict = {},
522517 pooler_output = [])
@@ -573,7 +568,7 @@ def test_schedule_concurrent_batches(self):
573568 model_runner_output = ModelRunnerOutput (
574569 req_ids = [requests [0 ].request_id ],
575570 req_id_to_index = {requests [0 ].request_id : 0 },
576- sampled_token_ids = [np . array ( [0 ], dtype = np . int64 ) ],
571+ sampled_token_ids = [[0 ]],
577572 logprobs = None ,
578573 prompt_logprobs_dict = {},
579574 pooler_output = [])
@@ -589,7 +584,7 @@ def test_schedule_concurrent_batches(self):
589584 model_runner_output = ModelRunnerOutput (
590585 req_ids = [requests [1 ].request_id ],
591586 req_id_to_index = {requests [1 ].request_id : 0 },
592- sampled_token_ids = [np . array ( [0 ], dtype = np . int64 ) ],
587+ sampled_token_ids = [[0 ]],
593588 logprobs = None ,
594589 prompt_logprobs_dict = {},
595590 pooler_output = [])
@@ -607,12 +602,10 @@ def test_schedule_spec_decoding_stats(self):
607602 spec_tokens_list : List [List [List [int ]]] = [[[1 , 2 , 3 ]], [[1 , 2 , 3 ]],
608603 [[1 , 2 ], [3 ]], [[1 ]], [[]],
609604 [[1 , 2 , 3 ], [4 , 5 , 6 ]]]
610- output_tokens_list : List [List [List [int ]]] = [
611- [np .array ([1 , 2 , 3 , 4 ])], [np .array ([1 , 5 ])],
612- [np .array ([1 , 2 , 5 ]), np .array ([3 , 4 ])], [np .array ([1 , 2 ])],
613- [np .array ([5 ])], [np .array ([1 , 2 , 7 ]),
614- np .array ([4 , 8 ])]
615- ]
605+ output_tokens_list : List [List [List [int ]]] = [[[1 , 2 , 3 , 4 ]], [[1 , 5 ]],
606+ [[1 , 2 , 5 ], [3 , 4 ]],
607+ [[1 , 2 ]], [[5 ]],
608+ [[1 , 2 , 7 ], [4 , 8 ]]]
616609 expected_list : List [Tuple [int , int ,
617610 int , List [int ]]] = [(1 , 3 , 3 , [1 , 1 , 1 ]),
618611 (1 , 3 , 1 , [1 , 0 , 0 ]),
@@ -650,9 +643,7 @@ def test_schedule_spec_decoding_stats(self):
650643 model_runner_output = ModelRunnerOutput (
651644 req_ids = req_ids ,
652645 req_id_to_index = req_to_index ,
653- sampled_token_ids = [
654- np .array ([0 ]) for _ in range (len (requests ))
655- ],
646+ sampled_token_ids = [[0 ] for _ in range (len (requests ))],
656647 logprobs = None ,
657648 prompt_logprobs_dict = {},
658649 pooler_output = [])
@@ -892,11 +883,13 @@ def create_scheduler(self, mock_compute_encoder_budget):
892883 torch .float32 , False ))
893884 ],
894885 )
886+ kv_cache_config .hash_block_size = block_size
895887 cache_config .num_gpu_blocks = 10000
896888
897889 scheduler = SchedulerDynamicBatch (
898890 vllm_config = vllm_config ,
899891 kv_cache_config = kv_cache_config ,
892+ block_size = block_size ,
900893 log_stats = True ,
901894 structured_output_manager = MagicMock (spec = StructuredOutputManager ),
902895 )
@@ -1064,8 +1057,7 @@ def test_stop_via_update_from_output(self):
10641057 req .request_id : i
10651058 for i , req in enumerate (requests )
10661059 },
1067- sampled_token_ids = [np .array ([EOS_TOKEN_ID ]),
1068- np .array ([10 , 11 ])
1060+ sampled_token_ids = [[EOS_TOKEN_ID ], [10 , 11 ]
10691061 ], # First request hits EOS, second continues
10701062 logprobs = None ,
10711063 prompt_logprobs_dict = {},
@@ -1116,9 +1108,8 @@ def test_stop_via_update_from_output(self):
11161108 req .request_id : i
11171109 for i , req in enumerate (requests )
11181110 },
1119- sampled_token_ids = [np .array ([10 , 42 , 12 ]),
1120- np .array ([13 , 14 ])
1121- ], # First request hits stop token
1111+ sampled_token_ids = [[10 , 42 , 12 ],
1112+ [13 , 14 ]], # First request hits stop token
11221113 logprobs = None ,
11231114 prompt_logprobs_dict = {},
11241115 pooler_output = [])
@@ -1167,9 +1158,8 @@ def test_stop_via_update_from_output(self):
11671158 req .request_id : i
11681159 for i , req in enumerate (requests )
11691160 },
1170- sampled_token_ids = [np .array ([10 , 11 , 12 ]),
1171- np .array ([13 ])
1172- ], # First request exceeds max_tokens
1161+ sampled_token_ids = [[10 , 11 , 12 ],
1162+ [13 ]], # First request exceeds max_tokens
11731163 logprobs = None ,
11741164 prompt_logprobs_dict = {},
11751165 pooler_output = [])
@@ -1208,7 +1198,7 @@ def test_stop_via_update_from_output(self):
12081198 model_output = ModelRunnerOutput (
12091199 req_ids = [requests [0 ].request_id ],
12101200 req_id_to_index = {requests [0 ].request_id : 0 },
1211- sampled_token_ids = [np . array ( [EOS_TOKEN_ID , 10 , 11 ]) ],
1201+ sampled_token_ids = [[EOS_TOKEN_ID , 10 , 11 ]],
12121202 logprobs = None ,
12131203 prompt_logprobs_dict = {},
12141204 pooler_output = [])
@@ -1265,7 +1255,7 @@ def test_schedule_concurrent_batches(self):
12651255 model_runner_output = ModelRunnerOutput (
12661256 req_ids = [requests [0 ].request_id ],
12671257 req_id_to_index = {requests [0 ].request_id : 0 },
1268- sampled_token_ids = [np . array ( [0 ]) ],
1258+ sampled_token_ids = [[0 ]],
12691259 logprobs = None ,
12701260 prompt_logprobs_dict = {},
12711261 pooler_output = [])
@@ -1281,7 +1271,7 @@ def test_schedule_concurrent_batches(self):
12811271 model_runner_output = ModelRunnerOutput (
12821272 req_ids = [requests [1 ].request_id ],
12831273 req_id_to_index = {requests [1 ].request_id : 0 },
1284- sampled_token_ids = [np . array ( [0 ]) ],
1274+ sampled_token_ids = [[0 ]],
12851275 logprobs = None ,
12861276 prompt_logprobs_dict = {},
12871277 pooler_output = [])
@@ -1299,12 +1289,10 @@ def test_schedule_spec_decoding_stats(self):
12991289 spec_tokens_list : List [List [List [int ]]] = [[[1 , 2 , 3 ]], [[1 , 2 , 3 ]],
13001290 [[1 , 2 ], [3 ]], [[1 ]], [[]],
13011291 [[1 , 2 , 3 ], [4 , 5 , 6 ]]]
1302- output_tokens_list : List [List [List [int ]]] = [
1303- [np .array ([1 , 2 , 3 , 4 ])], [np .array ([1 , 5 ])],
1304- [np .array ([1 , 2 , 5 ]), np .array ([3 , 4 ])], [np .array ([1 , 2 ])],
1305- [np .array ([5 ])], [np .array ([1 , 2 , 7 ]),
1306- np .array ([4 , 8 ])]
1307- ]
1292+ output_tokens_list : List [List [List [int ]]] = [[[1 , 2 , 3 , 4 ]], [[1 , 5 ]],
1293+ [[1 , 2 , 5 ], [3 , 4 ]],
1294+ [[1 , 2 ]], [[5 ]],
1295+ [[1 , 2 , 7 ], [4 , 8 ]]]
13081296 expected_list : List [Tuple [int , int ,
13091297 int , List [int ]]] = [(1 , 3 , 3 , [1 , 1 , 1 ]),
13101298 (1 , 3 , 1 , [1 , 0 , 0 ]),
@@ -1342,9 +1330,7 @@ def test_schedule_spec_decoding_stats(self):
13421330 model_runner_output = ModelRunnerOutput (
13431331 req_ids = req_ids ,
13441332 req_id_to_index = req_to_index ,
1345- sampled_token_ids = [
1346- np .array ([0 ]) for _ in range (len (requests ))
1347- ],
1333+ sampled_token_ids = [[0 ] for _ in range (len (requests ))],
13481334 logprobs = None ,
13491335 prompt_logprobs_dict = {},
13501336 pooler_output = [])
0 commit comments