
Conversation

@Angazenn Angazenn commented Nov 29, 2025

What this PR does / why we need it?

Does this PR introduce any user-facing change?

How was this patch tested?

Signed-off-by: Angazenn <supperccell@163.com>
@github-actions

👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:

  • A PR should do only one thing; smaller PRs enable faster reviews.
  • Every PR should include unit tests and end-to-end tests to ensure it works and is not broken by future PRs.
  • Write the commit message by filling in the PR description, so that reviewers and future developers can understand the change.

If CI fails, you can run the linting and testing checks locally according to Contributing and Testing.

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a new Triton kernel that fuses QKV splitting, RMSNorm, and RoPE. The implementation has a few critical and high-severity issues: a critical bug in the handling of bias parameters that can lead to a crash; device properties fetched at module import time, which is unsafe and can produce incorrect grid configurations; and significant code duplication within the kernel that harms maintainability. I've left specific comments on these points.

n_cols = kv_hidden_size // KV_BLOCK_SIZE
assert num_vectorcore % n_cols == 0
n_rows = num_vectorcore // n_cols
BIAS = q_bias is not None

critical

The BIAS flag is determined solely by whether q_bias is None. This same flag is then used inside the kernel to conditionally access both q_bias_ptr and k_bias_ptr. If q_bias is provided but k_bias is None, the kernel will attempt to dereference a null pointer for k_bias_ptr, leading to a crash. The bias for Q and K should be handled independently.

To fix this, you should introduce separate flags for Q and K bias (e.g., Q_BIAS and K_BIAS) and pass them to the kernel, updating the kernel logic to use the respective flags.
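
As a minimal sketch of that split, using a toy kernel rather than the PR's actual one (names like Q_BIAS, K_BIAS, and _biased_copy_kernel are illustrative):

import triton
import triton.language as tl

@triton.jit
def _biased_copy_kernel(
    q_ptr, k_ptr, q_bias_ptr, k_bias_ptr,
    BLOCK_SIZE: tl.constexpr,
    Q_BIAS: tl.constexpr,  # host passes: q_bias is not None
    K_BIAS: tl.constexpr,  # host passes: k_bias is not None
):
    offs = tl.arange(0, BLOCK_SIZE)
    q = tl.load(q_ptr + offs)
    if Q_BIAS:
        # compile-time branch: q_bias_ptr is only dereferenced when valid
        q += tl.load(q_bias_ptr + offs)
    k = tl.load(k_ptr + offs)
    if K_BIAS:
        # independent flag: a missing k_bias no longer crashes the kernel
        k += tl.load(k_bias_ptr + offs)
    tl.store(q_ptr + offs, q)
    tl.store(k_ptr + offs, k)

On the host side, the single BIAS = q_bias is not None would become two assignments, Q_BIAS = q_bias is not None and K_BIAS = k_bias is not None, each passed to the launch. Because both flags are tl.constexpr, the dead branch is eliminated at compile time and the unused pointer is never touched.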

Comment on lines +12 to +33
def split_qkv_rmsnorm_rope_kernel(
    input_ptr,
    sin_ptr,
    cos_ptr,
    q_ptr,
    k_ptr,
    v_ptr,
    q_weight_ptr,
    q_bias_ptr,
    k_weight_ptr,
    k_bias_ptr,
    batch_size,
    q_hidden_size: tl.constexpr,
    kv_hidden_size: tl.constexpr,
    total_hidden_size: tl.constexpr,
    eps: tl.constexpr,
    Q_BLOCK_SIZE: tl.constexpr,
    KV_BLOCK_SIZE: tl.constexpr,
    BIAS: tl.constexpr,
    HEAD_DIM: tl.constexpr,
    HALF_HEAD_DIM: tl.constexpr,
):

high

There is significant code duplication between the processing loops for Q and K projections (lines 37-106 for Q, and 108-176 for K). This makes the code harder to read, maintain, and debug. A bug fix or optimization in one path might be missed in the other.

Consider refactoring the common RMSNorm and RoPE logic into a separate @triton.jit helper function. This function would be called from both the Q and K loops with the appropriate parameters (e.g., pointers, sizes, offsets).
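
As a rough sketch of what such a helper could look like, assuming the RMSNorm math implied by the signature above (the body is illustrative, not the PR's exact computation):

import triton
import triton.language as tl

@triton.jit
def _rmsnorm_block(x, weight, eps: tl.constexpr, BLOCK_SIZE: tl.constexpr):
    # x and weight are 1-D blocks already loaded by the caller.
    var = tl.sum(x * x, axis=0) / BLOCK_SIZE
    return x / tl.sqrt(var + eps) * weight

Both the Q and K loops would then load their block and weight, call _rmsnorm_block, and apply RoPE to the result, so a fix or optimization in the shared path automatically applies to both projections.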

Comment on lines +194 to +198
def get_npu_properties():
    device = torch.npu.current_device()
    return driver.active.utils.get_device_properties(device)

num_vectorcore = get_npu_properties()["num_vectorcore"]

high

The num_vectorcore property is fetched at module import time using torch.npu.current_device(). This can lead to incorrect behavior if the active device changes after import or in multi-device scenarios. The device properties should be fetched dynamically. get_npu_properties should be modified to accept a device, and the global num_vectorcore should be removed.

Suggested change

-def get_npu_properties():
-    device = torch.npu.current_device()
-    return driver.active.utils.get_device_properties(device)
-
-num_vectorcore = get_npu_properties()["num_vectorcore"]
+def get_npu_properties(device):
+    return driver.active.utils.get_device_properties(device)
)
n_cols = kv_hidden_size // KV_BLOCK_SIZE
assert num_vectorcore % n_cols == 0
n_rows = num_vectorcore // n_cols

high

To complete the fix for fetching device properties dynamically, num_vectorcore should be calculated here using the input tensor's device before it is used to calculate n_rows.

Suggested change

-n_rows = num_vectorcore // n_cols
+num_vectorcore = get_npu_properties(input.device)["num_vectorcore"]
+n_rows = num_vectorcore // n_cols
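
Taken together, the two suggestions compose roughly as follows (a sketch assuming driver comes from triton.runtime, as in common Triton host code; grid_shape is a hypothetical wrapper around the PR's grid computation):

from triton.runtime import driver  # assumed import path

def get_npu_properties(device):
    # Query properties for the tensor's own device rather than the
    # process-global current device.
    return driver.active.utils.get_device_properties(device)

def grid_shape(input, kv_hidden_size, KV_BLOCK_SIZE):
    n_cols = kv_hidden_size // KV_BLOCK_SIZE
    # Fetched per call, so multi-device setups get the right core count.
    num_vectorcore = get_npu_properties(input.device)["num_vectorcore"]
    assert num_vectorcore % n_cols == 0
    n_rows = num_vectorcore // n_cols
    return n_rows, n_cols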
