
Commit cec7c28

[Bugfix] Padded Eagle Specdec with Chunked Prefill (vllm-project#26263)
Signed-off-by: Rémi Delacourt <remi@mistral.ai>
Signed-off-by: Rémi Delacourt <54138269+Flechman@users.noreply.github.com>
Signed-off-by: remi <remi@mistral.ai>
Co-authored-by: Benjamin Chislett <bchislett@nvidia.com>
1 parent 18961c5 commit cec7c28
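
This commit threads a new `chunked_prefill_enabled` flag through the Eagle correctness test so the padded speculative-decoding path is also exercised with chunked prefill: when the flag is set, `max_num_batched_tokens` is forced down to 128, well below `max_model_len`, so long prompts must be prefilled in chunks. As a reading aid, here is a minimal sketch of the engine configuration the new test case ends up building — the model names and parameter values are taken from the diff below, while the prompt and sampling settings are hypothetical placeholders:

```python
from vllm import LLM, SamplingParams

# Sketch of the configuration exercised by the new test case:
# Eagle speculative decoding combined with forced chunked prefill.
spec_llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",
    speculative_config={
        "method": "eagle",
        "model": "yuhuili/EAGLE-LLaMA3.1-Instruct-8B",
        "num_speculative_tokens": 3,
        "max_model_len": 2048,
    },
    max_model_len=2048,
    # 128 is far below max_model_len, so any long prompt is prefilled
    # in several chunks -- the path this bugfix covers.
    max_num_batched_tokens=128,
    enable_chunked_prefill=True,
)
outputs = spec_llm.chat(
    # Placeholder prompt; the test uses its own prompt set.
    [[{"role": "user", "content": "Explain speculative decoding."}]],
    SamplingParams(temperature=0, max_tokens=64),
)
```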


tests/v1/e2e/test_spec_decode.py

Lines changed: 21 additions & 7 deletions
```diff
@@ -202,9 +202,9 @@ def test_speculators_model_integration(
 
 
 @pytest.mark.parametrize(
-    ["model_setup", "mm_enabled"],
+    ["model_setup", "mm_enabled", "chunked_prefill_enabled"],
     [
-        (("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1), False),
+        (("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1), False, False),
         pytest.param(
             (
                 "eagle3",
@@ -213,19 +213,22 @@ def test_speculators_model_integration(
                 1,
             ),
             False,
+            False,
             marks=pytest.mark.skip(
                 reason="Skipping due to its head_dim not being a a multiple of 32"
             ),
         ),
-        (
+        pytest.param(
             (
                 "eagle",
                 "meta-llama/Llama-3.1-8B-Instruct",
                 "yuhuili/EAGLE-LLaMA3.1-Instruct-8B",
                 1,
             ),
             False,
-        ),
+            True,
+            marks=large_gpu_mark(min_gb=40),
+        ),  # works on 4x H100
         (
             (
                 "eagle3",
@@ -234,6 +237,7 @@ def test_speculators_model_integration(
                 1,
             ),
             False,
+            False,
         ),
         pytest.param(
             (
@@ -243,6 +247,7 @@ def test_speculators_model_integration(
                 4,
             ),
             False,
+            False,
             marks=large_gpu_mark(min_gb=80),
         ),  # works on 4x H100
         pytest.param(
@@ -253,6 +258,7 @@ def test_speculators_model_integration(
                 4,
             ),
             True,
+            True,
             marks=large_gpu_mark(min_gb=80),
         ),  # works on 4x H100
         (
@@ -263,6 +269,7 @@ def test_speculators_model_integration(
                 1,
             ),
             False,
+            False,
         ),
     ],
     ids=[
@@ -281,6 +288,7 @@ def test_eagle_correctness(
     sampling_config: SamplingParams,
     model_setup: tuple[str, str, str, int],
     mm_enabled: bool,
+    chunked_prefill_enabled: bool,
     attn_backend: str,
 ):
     if attn_backend == "TREE_ATTN":
@@ -317,9 +325,13 @@ def test_eagle_correctness(
         m.setenv("VLLM_ROCM_USE_AITER", "1")
 
     method, model_name, spec_model_name, tp_size = model_setup
+    max_model_len = 2048
+    max_num_batched_tokens = max_model_len
+    if chunked_prefill_enabled:
+        max_num_batched_tokens = 128
 
     ref_llm = LLM(
-        model=model_name, max_model_len=2048, tensor_parallel_size=tp_size
+        model=model_name, max_model_len=max_model_len, tensor_parallel_size=tp_size
     )
     ref_outputs = ref_llm.chat(test_prompts, sampling_config)
     del ref_llm
@@ -334,9 +346,11 @@ def test_eagle_correctness(
             "method": method,
             "model": spec_model_name,
             "num_speculative_tokens": 3,
-            "max_model_len": 2048,
+            "max_model_len": max_model_len,
         },
-        max_model_len=2048,
+        max_model_len=max_model_len,
+        max_num_batched_tokens=max_num_batched_tokens,
+        enable_chunked_prefill=chunked_prefill_enabled,
     )
     spec_outputs = spec_llm.chat(test_prompts, sampling_config)
     matches = 0
```
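
With `max_num_batched_tokens=128` and `max_model_len=2048`, any prompt longer than 128 tokens is guaranteed to span multiple prefill steps, so the padded Eagle path is hit repeatedly. A back-of-the-envelope helper illustrating the chunk count — a sketch, not part of the test, assuming a single request with the whole token budget to itself:

```python
import math

def prefill_steps(prompt_len: int, max_num_batched_tokens: int = 128) -> int:
    # Chunked prefill splits a prompt into max_num_batched_tokens-sized
    # chunks, one per scheduler step.
    return math.ceil(prompt_len / max_num_batched_tokens)

assert prefill_steps(2048) == 16  # a max-length prompt prefills in 16 chunks
assert prefill_steps(100) == 1    # short prompts still prefill in one step
```

The `matches` counter initialized at the end of the hunk feeds the comparison that follows (outside this diff), which presumably tallies how many spec-decoded outputs exactly reproduce the reference outputs; forcing chunked prefill must not degrade that match rate.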
