[BugFix]add rope_fusion #4564
Conversation
Signed-off-by: Angazenn <supperccell@163.com>
👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:
If CI fails, you can run linting and testing checks locally according to Contributing and Testing.
Code Review
This pull request introduces a new Triton kernel for a fused operation of splitting QKV, applying RMSNorm, and RoPE. The implementation has a few critical and high-severity issues. There is a critical bug in the handling of bias parameters that can lead to a crash. Additionally, device properties are fetched at module import time, which is not safe and can lead to incorrect grid configurations. Finally, there is significant code duplication within the kernel that harms maintainability. I've left specific comments on these points.
```python
n_cols = kv_hidden_size // KV_BLOCK_SIZE
assert num_vectorcore % n_cols == 0
n_rows = num_vectorcore // n_cols
BIAS = q_bias is not None
```
The BIAS flag is determined solely by whether q_bias is None. This same flag is then used inside the kernel to conditionally access both q_bias_ptr and k_bias_ptr. If q_bias is provided but k_bias is None, the kernel will attempt to dereference a null pointer for k_bias_ptr, leading to a crash. The bias for Q and K should be handled independently.
To fix this, you should introduce separate flags for Q and K bias (e.g., Q_BIAS and K_BIAS) and pass them to the kernel, updating the kernel logic to use the respective flags.
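A minimal sketch of the suggested fix, in plain Python for illustration (the helper name `make_bias_flags` is hypothetical, not from the PR):

```python
# Hypothetical sketch: derive independent constexpr flags for Q and K bias
# instead of the single BIAS flag. Each flag depends only on its own
# tensor, so a present q_bias with a missing k_bias no longer causes the
# kernel to dereference a null k_bias_ptr (and vice versa).

def make_bias_flags(q_bias, k_bias):
    return q_bias is not None, k_bias is not None

# Inside the launch wrapper these would be passed as separate constexpr
# kernel arguments, e.g.:
#   Q_BIAS, K_BIAS = make_bias_flags(q_bias, k_bias)
#   split_qkv_rmsnorm_rope_kernel[grid](..., Q_BIAS=Q_BIAS, K_BIAS=K_BIAS)
```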
```python
def split_qkv_rmsnorm_rope_kernel(
    input_ptr,
    sin_ptr,
    cos_ptr,
    q_ptr,
    k_ptr,
    v_ptr,
    q_weight_ptr,
    q_bias_ptr,
    k_weight_ptr,
    k_bias_ptr,
    batch_size,
    q_hidden_size: tl.constexpr,
    kv_hidden_size: tl.constexpr,
    total_hidden_size: tl.constexpr,
    eps: tl.constexpr,
    Q_BLOCK_SIZE: tl.constexpr,
    KV_BLOCK_SIZE: tl.constexpr,
    BIAS: tl.constexpr,
    HEAD_DIM: tl.constexpr,
    HALF_HEAD_DIM: tl.constexpr,
):
```
There is significant code duplication between the processing loops for Q and K projections (lines 37-106 for Q, and 108-176 for K). This makes the code harder to read, maintain, and debug. A bug fix or optimization in one path might be missed in the other.
Consider refactoring the common RMSNorm and RoPE logic into a separate @triton.jit helper function. This function would be called from both the Q and K loops with the appropriate parameters (e.g., pointers, sizes, offsets).
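For reference, the shared math that such a helper would encapsulate (RMSNorm over the hidden dimension followed by rotate-half RoPE) can be sketched in NumPy; this is an illustrative reference implementation under assumed conventions (rotate-half rotation, normalization over the last axis), not the PR's kernel code:

```python
import numpy as np

def rmsnorm_rope_ref(x, weight, sin, cos, eps=1e-6):
    # RMSNorm: scale by the reciprocal root-mean-square of the last axis
    rms = np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)
    y = x / rms * weight
    # Rotate-half RoPE on the head dimension: split in two halves and
    # apply the rotation with the precomputed sin/cos tables
    half = y.shape[-1] // 2
    y1, y2 = y[..., :half], y[..., half:]
    return np.concatenate([y1 * cos - y2 * sin, y2 * cos + y1 * sin],
                          axis=-1)
```

A `@triton.jit` helper with the same structure, parameterized by pointers, strides, and block sizes, could then be called from both the Q and K loops.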
```python
def get_npu_properties():
    device = torch.npu.current_device()
    return driver.active.utils.get_device_properties(device)


num_vectorcore = get_npu_properties()["num_vectorcore"]
```
The num_vectorcore property is fetched at module import time using torch.npu.current_device(). This can lead to incorrect behavior if the active device changes after import or in multi-device scenarios. The device properties should be fetched dynamically. get_npu_properties should be modified to accept a device, and the global num_vectorcore should be removed.
```diff
-def get_npu_properties():
-    device = torch.npu.current_device()
-    return driver.active.utils.get_device_properties(device)
-num_vectorcore = get_npu_properties()["num_vectorcore"]
+def get_npu_properties(device):
+    return driver.active.utils.get_device_properties(device)
```
```python
)
n_cols = kv_hidden_size // KV_BLOCK_SIZE
assert num_vectorcore % n_cols == 0
n_rows = num_vectorcore // n_cols
```
To complete the fix for fetching device properties dynamically, num_vectorcore should be calculated here using the input tensor's device before it is used to calculate n_rows.
```diff
-n_rows = num_vectorcore // n_cols
+num_vectorcore = get_npu_properties(input.device)["num_vectorcore"]
+n_rows = num_vectorcore // n_cols
```
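With `num_vectorcore` fetched per call, the grid computation itself is unchanged; a small standalone sketch of that math (function name `compute_grid` is illustrative, not from the PR):

```python
def compute_grid(num_vectorcore, kv_hidden_size, kv_block_size):
    # n_cols tiles the KV hidden size into KV_BLOCK_SIZE chunks;
    # the remaining vector cores are split across rows of the batch.
    n_cols = kv_hidden_size // kv_block_size
    assert num_vectorcore % n_cols == 0, \
        "number of vector cores must be divisible by n_cols"
    n_rows = num_vectorcore // n_cols
    return n_rows, n_cols
```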
What this PR does / why we need it?
Does this PR introduce any user-facing change?
How was this patch tested?