
Conversation

@Victor49152 (Contributor) commented Nov 26, 2025

Purpose

Per #28763, vLLM's flash attention now supports all head sizes for the ViT module, so this PR removes the upstream flash-attn checks, which are no longer necessary.

Test Plan

Use one of the impacted models, Qwen3-VL-235B, as an example to start the server:
vllm serve RedHatAI/Qwen3-VL-235B-A22B-Instruct-NVFP4 -tp 4 -dp 1 --mm-encoder-tp-mode data --enable-expert-parallel --async-scheduling --max-num-seqs 1024

Test Result

Server started successfully


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

Signed-off-by: mingyuanm <mingyuanm@nvidia.com>
@mergify mergify bot added the qwen Related to Qwen models label Nov 26, 2025
@chatgpt-codex-connector (bot) left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.


Comment on lines 70 to 71
else:
return AttentionBackendEnum.TORCH_SDPA, None


P1: Restore CUDA path in ViT flash-attn selection

maybe_get_vit_flash_attn_backend now returns TORCH_SDPA for any platform that is neither ROCm nor XPU. CUDA is caught by this else branch, so even when get_vit_attn_backend selects FLASH_ATTN the function forces Torch SDPA and never loads the flash attention kernel. This effectively disables flash attention for all vision models on CUDA, degrading the intended fast path everywhere and ignoring user overrides.
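The fix the review asks for can be sketched as follows. This is a minimal illustration with hypothetical names (`select_vit_attn_backend` and the platform flags are stand-ins), not the actual vLLM code: only ROCm and XPU get platform-specific fallbacks, while every other platform keeps the backend chosen upstream.

```python
from enum import Enum, auto

class AttentionBackendEnum(Enum):
    FLASH_ATTN = auto()
    ROCM_AITER_FA = auto()
    TORCH_SDPA = auto()

def select_vit_attn_backend(requested, is_rocm=False, is_xpu=False,
                            aiter_fa_available=False):
    # Only ROCm and XPU are rerouted to fallbacks; any other platform
    # (CUDA included) keeps the upstream-selected backend instead of
    # hitting a catch-all else that forces TORCH_SDPA.
    if is_rocm:
        if aiter_fa_available:
            return AttentionBackendEnum.ROCM_AITER_FA
        return AttentionBackendEnum.TORCH_SDPA
    if is_xpu:
        return AttentionBackendEnum.TORCH_SDPA
    return requested

# On CUDA the FLASH_ATTN choice survives:
print(select_vit_attn_backend(AttentionBackendEnum.FLASH_ATTN).name)
```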


@gemini-code-assist (bot) left a comment


Code Review

This pull request aims to remove upstream flash-attention checks. While most of the changes correctly remove the use_upstream_fa flag and related logic, I've identified a few critical issues. There's a regression in vllm/attention/layer.py that would disable FlashAttention for ViT models on CUDA. Additionally, there are remaining usages and imports of a removed function (check_upstream_fa_availability) in vllm/model_executor/models/paddleocr_vl.py and vllm/model_executor/models/qwen3_vl.py, which will cause runtime errors. These issues need to be addressed to ensure the correctness and performance of the codebase.

Signed-off-by: mingyuanm <mingyuanm@nvidia.com>
Signed-off-by: mingyuanm <mingyuanm@nvidia.com>
@mgoin (Member) commented Nov 26, 2025

Need to check with AMD folks if there is a need from their side @gshtras

@Victor49152 (Contributor, Author)

@ywang96 Please also take a look at the logic in maybe_get_vit_flash_attn_backend, thanks!

Signed-off-by: mingyuanm <mingyuanm@nvidia.com>
):
attn_backend = AttentionBackendEnum.FLASH_ATTN
use_upstream_fa = True
elif attn_backend_override is None \
Collaborator


We need to add back the on_gfx9() condition here to differentiate between Radeon and Instinct GPUs.

On Radeon, only TORCH_SDPA is supported.
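The requested gating can be sketched like this. The `on_gfx9` helper here is a hypothetical stand-in for the real check, relying on the fact that Instinct (CDNA) parts report gfx9xx architecture strings (e.g. gfx90a, gfx942) while Radeon (RDNA) parts report gfx10xx and newer:

```python
def on_gfx9(gcn_arch: str) -> bool:
    """Hypothetical stand-in for the on_gfx9() check: Instinct (CDNA)
    GPUs report gfx9xx arch strings, Radeon (RDNA) report gfx10xx+."""
    return gcn_arch.startswith("gfx9")

def pick_rocm_vit_backend(gcn_arch: str) -> str:
    # Per the review: flash attention only on Instinct (gfx9) GPUs;
    # Radeon falls back to TORCH_SDPA.
    return "FLASH_ATTN" if on_gfx9(gcn_arch) else "TORCH_SDPA"

print(pick_rocm_vit_backend("gfx942"))   # Instinct MI300-class
print(pick_rocm_vit_backend("gfx1100"))  # Radeon RX 7900-class
```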

Contributor Author


Just pushed the changes, thanks! Please comment if there is anything else you notice.

Signed-off-by: mingyuanm <mingyuanm@nvidia.com>
from flash_attn import flash_attn_varlen_func
else:
from vllm.attention.utils.fa_utils import flash_attn_varlen_func
from vllm.attention.utils.fa_utils import flash_attn_varlen_func
@tjtanaa (Collaborator) commented Nov 26, 2025


vllm/attention/utils/fa_utils.py does not have the logic for ROCm; flash_attn_varlen_func will be a None object if imported this way.

We can keep the import statement from flash_attn import flash_attn_varlen_func for now. Otherwise, we have to add that from flash_attn import flash_attn_varlen_func import statement to vllm/attention/utils/fa_utils.py when the platform is ROCm.

Contributor Author


I added this import to fa_utils, as that looks like the simplest way to do it. The except message tells the user to install upstream flash-attn when an ImportError is raised. Please check whether that works, thanks!
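A minimal sketch of that approach (illustrative names and control flow, not the exact fa_utils wiring): on ROCm, pull the op from the upstream flash-attn package, and turn a failed import into an actionable error instead of a silent None.

```python
def load_flash_attn_varlen_func(is_rocm: bool):
    """Sketch of the guarded-import strategy described above: on ROCm,
    import from the upstream flash-attn package; on failure, tell the
    user what to install rather than leaving the symbol as None."""
    try:
        if is_rocm:
            from flash_attn import flash_attn_varlen_func  # upstream package
        else:
            from vllm.vllm_flash_attn import flash_attn_varlen_func
        return flash_attn_varlen_func
    except ImportError as e:
        raise ImportError(
            "flash_attn_varlen_func is unavailable; on ROCm, install the "
            "upstream flash-attn package."
        ) from e
```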

from flash_attn import flash_attn_varlen_func
else:
from vllm.attention.utils.fa_utils import flash_attn_varlen_func
from vllm.attention.utils.fa_utils import flash_attn_varlen_func
@tjtanaa (Collaborator) commented Nov 26, 2025


Likewise, vllm/attention/utils/fa_utils.py does not have the logic for ROCm; flash_attn_varlen_func will be a None object if imported this way.

We can keep the import statement from flash_attn import flash_attn_varlen_func for now. Otherwise, we have to add that from flash_attn import flash_attn_varlen_func import statement to vllm/attention/utils/fa_utils.py when the platform is ROCm.

Victor49152 and others added 4 commits November 25, 2025 21:13
Signed-off-by: mingyuanm <mingyuanm@nvidia.com>
Signed-off-by: mingyuanm <mingyuanm@nvidia.com>
Signed-off-by: Roger Wang <hey@rogerw.io>
@ywang96 (Member) left a comment


I fixed the precommit error but otherwise LGTM

cc @tjtanaa for final check on the changes for resolving FA import on ROCM platform.

@github-project-automation github-project-automation bot moved this to In review in NVIDIA Nov 28, 2025
@ywang96 ywang96 added the ready ONLY add when PR is ready to merge/full CI is needed label Nov 28, 2025
Signed-off-by: Roger Wang <hey@rogerw.io>
Signed-off-by: Roger Wang <hey@rogerw.io>
Signed-off-by: Roger Wang <hey@rogerw.io>
@tjtanaa (Collaborator) commented Nov 28, 2025

@ywang96 Thanks. LGTM.

It is using both flash attention and AITER flash attention, and the code path on ROCm is working. The ChartQA scores of Qwen/Qwen3-VL-8B-Instruct with both backends are:

================================================================================
Metrics:
{
    "explicit_prompt_relaxed_correctness": 0.7948,
    "anywhere_in_answer_relaxed_correctness": 0.7988
}
================================================================================
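For context, ChartQA's relaxed correctness metric counts a numeric answer as correct if it is within a 5% relative tolerance of the target, and otherwise requires an exact (case-insensitive) match. A rough sketch, not the exact evaluation harness used above:

```python
def relaxed_correctness(pred: str, target: str, tol: float = 0.05) -> bool:
    """ChartQA-style relaxed accuracy: numeric answers are correct within
    a 5% relative tolerance; non-numeric answers need an exact match."""
    try:
        p, t = float(pred), float(target)
    except ValueError:
        # Non-numeric answer: fall back to case-insensitive string match.
        return pred.strip().lower() == target.strip().lower()
    if t == 0:
        return p == 0
    return abs(p - t) / abs(t) <= tol

print(relaxed_correctness("102", "100"))  # within the 5% tolerance
print(relaxed_correctness("110", "100"))  # 10% off, counted incorrect
```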

@vllm-bot vllm-bot merged commit 460d8bb into vllm-project:main Nov 28, 2025
51 of 53 checks passed
@github-project-automation github-project-automation bot moved this from In review to Done in NVIDIA Nov 28, 2025

Labels

nvidia, qwen (Related to Qwen models), ready (ONLY add when PR is ready to merge/full CI is needed)

Projects

Status: Done
