
Conversation

@mgoin mgoin commented Oct 17, 2025

Purpose

FIX #27057

The original issue was discovered when Qwen3-VL models completely lost accuracy (1% vs. 86% on GSM8K) on B200 GPUs with the default FULL_AND_PIECEWISE cudagraph_mode. The issue did not occur on Hopper at all, nor when using PIECEWISE mode only, the FlashAttention backend, or explicitly disabling TRTLLM attention.

The root cause is that TRTLLM attention is selected dynamically based on runtime conditions (num_tokens, max_seq_len, kv_cache_dtype). During FULL cudagraph capture, max_seq_len is taken from the full model context; when it exceeds 128K, FlashInfer is selected. During actual inference, which does not use the full context length, the same conditions select TRTLLM instead. This creates a graph/runtime mismatch where the captured graphs reference FlashInfer kernels but the runtime attempts to execute TRTLLM kernels, producing incorrect results. I was able to reproduce this behavior on any model with a default max_model_len > 128K.
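
Roughly, the selection behaves like the following Python sketch (the function name, signature, and threshold are illustrative assumptions for this description, not vLLM's actual dispatch code):

# Illustrative sketch of dynamic attention-kernel selection; names and
# thresholds are assumptions, not vLLM's real implementation.
TRTLLM_MAX_SEQ_LEN = 131072  # assumed upper bound for the TRTLLM path

def select_attention_kernel(num_tokens: int, max_seq_len: int, kv_cache_dtype: str) -> str:
    # Pick a kernel based on runtime conditions (num_tokens is shown for
    # completeness but ignored in this simplified condition).
    if max_seq_len < TRTLLM_MAX_SEQ_LEN and kv_cache_dtype in ("auto", "fp8"):
        return "trtllm"
    return "flashinfer"

# During FULL cudagraph capture, max_seq_len comes from max_model_len
# (e.g. 1,048,576), so the captured graph bakes in the FlashInfer path...
captured = select_attention_kernel(8192, 1_048_576, "auto")

# ...but at inference time the observed sequence length is much shorter,
# so the same conditions pick TRTLLM, mismatching the captured graph.
runtime = select_attention_kernel(8192, 4096, "auto")

assert captured == "flashinfer" and runtime == "trtllm"  # graph/runtime mismatch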

By enforcing PIECEWISE mode in this PR, which disables cudagraph capture of attention, we avoid this dynamism entirely. In the future we should see whether TRTLLM can be made to support larger context lengths so that FULL graphs work again.
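
For reference, a minimal sketch of the kind of guard added here (function and attribute names are simplified assumptions, not the exact VllmConfig code, and strings stand in for vLLM's CUDAGraphMode enum):

# Illustrative sketch of the cudagraph_mode override; not the exact vLLM code.
import logging
from types import SimpleNamespace

logger = logging.getLogger(__name__)

TRTLLM_MAX_MODEL_LEN = 131072  # assumed TRTLLM attention limit on Blackwell

def maybe_override_cudagraph_mode(compilation_config, model_config, is_blackwell: bool) -> None:
    # Force PIECEWISE cudagraphs when TRTLLM attention cannot cover the full context.
    if (is_blackwell
            and model_config.max_model_len >= TRTLLM_MAX_MODEL_LEN
            and compilation_config.cudagraph_mode == "FULL_AND_PIECEWISE"):
        logger.warning(
            "TRTLLM attention cannot support max_model_len >= %d (found %d); "
            "overriding cudagraph_mode to PIECEWISE.",
            TRTLLM_MAX_MODEL_LEN, model_config.max_model_len)
        compilation_config.cudagraph_mode = "PIECEWISE"

# Usage with stand-in config objects:
compilation_config = SimpleNamespace(cudagraph_mode="FULL_AND_PIECEWISE")
model_config = SimpleNamespace(max_model_len=1_048_576)
maybe_override_cudagraph_mode(compilation_config, model_config, is_blackwell=True)
assert compilation_config.cudagraph_mode == "PIECEWISE"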

Test Plan

Test Result

Reproduction on B200 on main:

vllm serve gradientai/Llama-3-8B-Instruct-Gradient-1048k
python tests/evals/gsm8k/gsm8k_eval.py

Running GSM8K evaluation: 1319 questions, 5-shot
Evaluating: 100%|████████████████████████████████████████| 1319/1319 [01:04<00:00, 20.56it/s]

Results:
Accuracy: 0.006
Invalid responses: 0.775
Total latency: 64.173 s
Questions per second: 20.554
vllm serve gradientai/Llama-3-8B-Instruct-Gradient-1048k --max-model-len=100K
python tests/evals/gsm8k/gsm8k_eval.py

Running GSM8K evaluation: 1319 questions, 5-shot
Evaluating: 100%|███████████████████████████████████████| 1319/1319 [00:11<00:00, 118.01it/s]

Results:
Accuracy: 0.579
Invalid responses: 0.004
Total latency: 11.190 s
Questions per second: 117.872

Running on this PR:

vllm serve gradientai/Llama-3-8B-Instruct-Gradient-1048k
(APIServer pid=2588191) WARNING 10-17 13:50:11 [vllm.py:385] NVIDIA Blackwell TRTLLM attention cannot support max_model_len >= 131072 (found 1048576), causing dynamic dispatching that breaks full cudagraphs. Overriding cudagraph_mode to PIECEWISE.
python tests/evals/gsm8k/gsm8k_eval.py

Running GSM8K evaluation: 1319 questions, 5-shot
Evaluating: 100%|███████████████████████████████████████| 1319/1319 [00:12<00:00, 107.95it/s]

Results:
Accuracy: 0.578
Invalid responses: 0.004
Total latency: 12.233 s
Questions per second: 107.825


@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces a bug fix that falls back to PIECEWISE cudagraphs on the Blackwell architecture when max_model_len is 131072 or greater. The change modifies the VllmConfig class to check for this condition and override cudagraph_mode accordingly, logging a warning when the override is applied.

mgoin added 2 commits October 17, 2025 13:34
@mgoin mgoin added this to the v0.11.1 milestone Oct 17, 2025
@mgoin mgoin added the bug Something isn't working label Oct 17, 2025
@mgoin mgoin added the ready ONLY add when PR is ready to merge/full CI is needed label Oct 17, 2025
@ywang96 ywang96 left a comment

Amazing - thank you!

@yewentao256 yewentao256 left a comment

LGTM, thanks for the work!

@yewentao256 yewentao256 enabled auto-merge (squash) October 17, 2025 19:33
@yewentao256 yewentao256 merged commit 950cf9e into vllm-project:main Oct 17, 2025
46 checks passed
@mgoin mgoin deleted the disable-trtllm-full-cudagraph-long-context branch October 17, 2025 20:23
lywa1998 pushed a commit to lywa1998/vllm that referenced this pull request Oct 20, 2025
adabeyta pushed a commit to adabeyta/vllm that referenced this pull request Oct 20, 2025
@xinli-sw

Added this issue to FlashInfer to track the long-term fix for FULL CG support of the TRTLLM backend:

flashinfer-ai/flashinfer#1968

albertoperdomo2 pushed a commit to albertoperdomo2/vllm that referenced this pull request Oct 23, 2025
alhridoy pushed a commit to alhridoy/vllm that referenced this pull request Oct 24, 2025
xuebwang-amd pushed a commit to xuebwang-amd/vllm that referenced this pull request Oct 24, 2025
0xrushi pushed a commit to 0xrushi/vllm that referenced this pull request Oct 26, 2025
ilmarkov pushed a commit to neuralmagic/vllm that referenced this pull request Nov 7, 2025
rtourgeman pushed a commit to rtourgeman/vllm that referenced this pull request Nov 10, 2025
Zhathw pushed a commit to Zhathw/vllm that referenced this pull request Nov 12, 2025
@benchislett

benchislett commented Nov 14, 2025

FYI, it's possible that #28479 caused a regression that breaks this fix. One path forward is #28755; another is simply to patch the check added here so that it also considers max_model_len.

devpatelio pushed a commit to SumanthRH/vllm that referenced this pull request Nov 29, 2025

Labels

bug Something isn't working ready ONLY add when PR is ready to merge/full CI is needed


Successfully merging this pull request may close these issues.

[Bug]: Qwen3-VL broken on Blackwell with PIECEWISE_AND_FULL
