@@ -202,9 +202,9 @@ def test_speculators_model_integration(


 @pytest.mark.parametrize(
-    ["model_setup", "mm_enabled"],
+    ["model_setup", "mm_enabled", "chunked_prefill_enabled"],
     [
-        (("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1), False),
+        (("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1), False, False),
         pytest.param(
             (
                 "eagle3",
@@ -213,19 +213,22 @@ def test_speculators_model_integration(
                 1,
             ),
             False,
+            False,
             marks=pytest.mark.skip(
                 reason="Skipping due to its head_dim not being a a multiple of 32"
             ),
         ),
-        (
+        pytest.param(
             (
                 "eagle",
                 "meta-llama/Llama-3.1-8B-Instruct",
                 "yuhuili/EAGLE-LLaMA3.1-Instruct-8B",
                 1,
             ),
             False,
-        ),
+            True,
+            marks=large_gpu_mark(min_gb=40),
+        ),  # works on 4x H100
         (
             (
                 "eagle3",
@@ -234,6 +237,7 @@ def test_speculators_model_integration(
                 1,
             ),
             False,
+            False,
         ),
         pytest.param(
             (
@@ -243,6 +247,7 @@ def test_speculators_model_integration(
                 4,
             ),
             False,
+            False,
             marks=large_gpu_mark(min_gb=80),
         ),  # works on 4x H100
         pytest.param(
@@ -253,6 +258,7 @@ def test_speculators_model_integration(
                 4,
             ),
             True,
+            True,
             marks=large_gpu_mark(min_gb=80),
         ),  # works on 4x H100
         (
@@ -263,6 +269,7 @@ def test_speculators_model_integration(
                 1,
             ),
             False,
+            False,
         ),
     ],
     ids=[
@@ -281,6 +288,7 @@ def test_eagle_correctness(
     sampling_config: SamplingParams,
     model_setup: tuple[str, str, str, int],
     mm_enabled: bool,
+    chunked_prefill_enabled: bool,
     attn_backend: str,
 ):
     if attn_backend == "TREE_ATTN":
@@ -317,9 +325,13 @@ def test_eagle_correctness(
         m.setenv("VLLM_ROCM_USE_AITER", "1")

     method, model_name, spec_model_name, tp_size = model_setup
+    max_model_len = 2048
+    max_num_batched_tokens = max_model_len
+    if chunked_prefill_enabled:
+        max_num_batched_tokens = 128

     ref_llm = LLM(
-        model=model_name, max_model_len=2048, tensor_parallel_size=tp_size
+        model=model_name, max_model_len=max_model_len, tensor_parallel_size=tp_size
     )
     ref_outputs = ref_llm.chat(test_prompts, sampling_config)
     del ref_llm
@@ -334,9 +346,11 @@ def test_eagle_correctness(
             "method": method,
             "model": spec_model_name,
             "num_speculative_tokens": 3,
-            "max_model_len": 2048,
+            "max_model_len": max_model_len,
         },
-        max_model_len=2048,
+        max_model_len=max_model_len,
+        max_num_batched_tokens=max_num_batched_tokens,
+        enable_chunked_prefill=chunked_prefill_enabled,
     )
     spec_outputs = spec_llm.chat(test_prompts, sampling_config)
     matches = 0
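
For orientation, below is a minimal sketch (not part of the PR) of the spec_llm configuration that the new chunked_prefill_enabled=True cases exercise: chunked prefill is switched on and max_num_batched_tokens is capped at 128, well below max_model_len, so prompts longer than the per-step token budget are prefilled across several scheduler steps. The model names and speculative_config fields are taken from the diff above; the prompt, sampling settings, and script structure are illustrative assumptions, and running it needs a GPU in the range the test gates with large_gpu_mark(min_gb=40).

# Illustrative sketch only: mirrors the spec_llm setup from the test with
# chunked prefill enabled. Model names and speculative_config fields come
# from the diff; the prompt and sampling settings here are assumptions.
from vllm import LLM, SamplingParams

max_model_len = 2048
max_num_batched_tokens = 128  # force prompts longer than 128 tokens to be chunked

spec_llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",
    speculative_config={
        "method": "eagle",
        "model": "yuhuili/EAGLE-LLaMA3.1-Instruct-8B",
        "num_speculative_tokens": 3,
        "max_model_len": max_model_len,
    },
    max_model_len=max_model_len,
    max_num_batched_tokens=max_num_batched_tokens,
    enable_chunked_prefill=True,
)

# One short conversation; temperature=0 keeps the comparison with a
# non-speculative reference deterministic, as the test does.
outputs = spec_llm.chat(
    [[{"role": "user", "content": "Explain speculative decoding in one paragraph."}]],
    SamplingParams(temperature=0, max_tokens=128),
)
print(outputs[0].outputs[0].text)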