Conversation


@OsirisDuan OsirisDuan commented Dec 1, 2025

What this PR does / why we need it?

Does this PR introduce any user-facing change?

How was this patch tested?

Signed-off-by: Ascendyh <hw7osiris@outlook.com>
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a new L2 normalization kernel implemented in Triton, specifically for Ascend NPUs. The implementation includes multiple kernel versions for different scenarios. My review focuses on performance optimizations within the Triton kernels. I've identified a few areas where using more efficient Triton intrinsics like tl.rsqrt and optimizing data loading can improve performance. Overall, the changes look good and add valuable functionality.

mask = cols < D
b_x = tl.load(x + cols, mask=mask, other=0.0).to(tl.float32)
b_var = tl.sum(b_x * b_x, axis=0)
b_rstd = 1 / tl.sqrt(b_var + eps)
Contributor

high

For better performance, use tl.rsqrt which is generally faster than 1 / tl.sqrt for computing the reciprocal square root.

Suggested change
b_rstd = 1 / tl.sqrt(b_var + eps)
b_rstd = tl.rsqrt(b_var + eps)
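As a sanity check on this suggestion, the two forms are numerically equivalent; a minimal NumPy sketch (outside Triton, with `(x) ** -0.5` standing in for `tl.rsqrt`) illustrating the identity:

```python
import numpy as np

# Reciprocal square root two ways. In Triton, tl.rsqrt maps to a single
# hardware intrinsic, whereas 1 / tl.sqrt issues a sqrt followed by a division.
eps = 1e-6
b_var = np.array([0.5, 2.0, 10.0], dtype=np.float32)  # illustrative values

rstd_div = 1.0 / np.sqrt(b_var + eps)   # original form
rstd_rsqrt = (b_var + eps) ** -0.5      # rsqrt form

assert np.allclose(rstd_div, rstd_rsqrt, rtol=1e-6)
```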

Author

In fact, in the Qwen3-Next model we only use the kernel l2norm_fwd_kernel2_loop, so the other kernels remain consistent with those in vLLM.

p_x = tl.make_block_ptr(x, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0))
b_x = tl.load(p_x, boundary_check=(0, 1)).to(tl.float32)
b_var = tl.sum(b_x * b_x, axis=1)
b_y = b_x / tl.sqrt(b_var + eps)[:, None]
Contributor

high

For better performance, you can use tl.rsqrt and multiplication instead of division by tl.sqrt. tl.rsqrt is often faster.

Suggested change
b_y = b_x / tl.sqrt(b_var + eps)[:, None]
b_y = b_x * tl.rsqrt(b_var + eps)[:, None]
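The same transformation in NumPy terms, showing that multiplying by the reciprocal square root matches division by the square root and yields unit-norm rows (a sketch of the math only, not the Triton kernel; `(x) ** -0.5` stands in for `tl.rsqrt`):

```python
import numpy as np

eps = 1e-6
b_x = np.random.default_rng(0).standard_normal((4, 8)).astype(np.float32)

# Row-wise sum of squares, matching tl.sum(b_x * b_x, axis=1).
b_var = np.sum(b_x * b_x, axis=1)

# Division form vs multiply-by-rsqrt form; both normalize each row.
y_div = b_x / np.sqrt(b_var + eps)[:, None]
y_mul = b_x * ((b_var + eps) ** -0.5)[:, None]

assert np.allclose(y_div, y_mul, rtol=1e-5)
assert np.allclose(np.linalg.norm(y_mul, axis=1), 1.0, atol=1e-3)
```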

Author

In fact, in the Qwen3-Next model we only use the kernel l2norm_fwd_kernel2_loop, so the other kernels remain consistent with those in vLLM.

Comment on lines 80 to 82
xs = tl.load(X + (rindex + N * row_idx), xmask).to(tl.float32)
square = xs * xs
square_sum = tl.sum(tl.where(xmask, square, 0), 1)[:, None]
Contributor

high

You can simplify this and potentially improve performance by using the other parameter in tl.load to handle masked-out values. This avoids the need for tl.where before the summation.

Suggested change
xs = tl.load(X + (rindex + N * row_idx), xmask).to(tl.float32)
square = xs * xs
square_sum = tl.sum(tl.where(xmask, square, 0), 1)[:, None]
xs = tl.load(X + (rindex + N * row_idx), mask=xmask, other=0.0).to(tl.float32)
square = xs * xs
square_sum = tl.sum(square, 1)[:, None]
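The equivalence behind this suggestion: zeros supplied by `other=0.0` on masked-out lanes contribute nothing to the sum of squares, so the explicit `tl.where` is redundant. A NumPy sketch under those assumptions (the garbage value and block sizes are illustrative):

```python
import numpy as np

N, BLOCK = 5, 8                      # row has 5 valid elements, block covers 8
rindex = np.arange(BLOCK)
xmask = rindex < N                   # tail lanes are masked out

row = np.arange(1.0, N + 1, dtype=np.float32)

# Simulate tl.load(..., mask=xmask) without `other`: masked lanes hold garbage.
xs_garbage = np.full(BLOCK, 99.0, dtype=np.float32)
xs_garbage[:N] = row

# Original: square everything, then zero masked lanes before summing.
square_sum_where = np.sum(np.where(xmask, xs_garbage * xs_garbage, 0.0))

# Suggested: load with other=0.0 so masked lanes are already zero.
xs_zeroed = np.where(xmask, xs_garbage, 0.0)   # stands in for other=0.0
square_sum_direct = np.sum(xs_zeroed * xs_zeroed)

assert square_sum_where == square_sum_direct == np.sum(row * row)
```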

Author

Accepted


github-actions bot commented Dec 1, 2025

👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:

  • A PR should do only one thing, smaller PRs enable faster reviews.
  • Every PR should include unit tests and end-to-end tests to ensure it works and is not broken by future PRs.
  • Write the commit message by filling out the PR description to help reviewers and future developers understand the change.

If CI fails, you can run linting and testing checks locally according to Contributing and Testing.

Signed-off-by: Ascendyh <hw7osiris@outlook.com>
@OsirisDuan OsirisDuan changed the title [task] add l2norm kernel [task] add l2norm triton kernel Dec 1, 2025
key=["D"],
)
@triton.jit
def l2norm_fwd_kernel1(
Collaborator

add op test.

raise RuntimeError("This layer doesn't support feature dim >= 64KB.")

if not USE_DEFAULT_FLA_NORM:
MBLOCK = 69
Collaborator

Can this branch be deleted?

Collaborator

weijinqian0 commented Dec 1, 2025

need golden test
