-
Notifications
You must be signed in to change notification settings - Fork 621
[task] add l2norm triton kernel #4595
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
[task] add l2norm triton kernel #4595
Conversation
Signed-off-by: Ascendyh <hw7osiris@outlook.com>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code Review
This pull request introduces a new L2 normalization kernel implemented in Triton, specifically for Ascend NPUs. The implementation includes multiple kernel versions for different scenarios. My review focuses on performance optimizations within the Triton kernels. I've identified a few areas where using more efficient Triton intrinsics like tl.rsqrt and optimizing data loading can improve performance. Overall, the changes look good and add valuable functionality.
| mask = cols < D | ||
| b_x = tl.load(x + cols, mask=mask, other=0.0).to(tl.float32) | ||
| b_var = tl.sum(b_x * b_x, axis=0) | ||
| b_rstd = 1 / tl.sqrt(b_var + eps) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In fact, in the QWen3-Next model, we only use the kernel l2norm_fwd_kernel2_loop, so the other kernels remain consistent with those in VLLM.
| p_x = tl.make_block_ptr(x, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0)) | ||
| b_x = tl.load(p_x, boundary_check=(0, 1)).to(tl.float32) | ||
| b_var = tl.sum(b_x * b_x, axis=1) | ||
| b_y = b_x / tl.sqrt(b_var + eps)[:, None] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In fact, in the QWen3-Next model, we only use the kernel l2norm_fwd_kernel2_loop, so the other kernels remain consistent with those in VLLM.
vllm_ascend/ops/triton/fla/l2norm.py
Outdated
| xs = tl.load(X + (rindex + N * row_idx), xmask).to(tl.float32) | ||
| square = xs * xs | ||
| square_sum = tl.sum(tl.where(xmask, square, 0), 1)[:, None] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You can simplify this and potentially improve performance by using the other parameter in tl.load to handle masked-out values. This avoids the need for tl.where before the summation.
| xs = tl.load(X + (rindex + N * row_idx), xmask).to(tl.float32) | |
| square = xs * xs | |
| square_sum = tl.sum(tl.where(xmask, square, 0), 1)[:, None] | |
| xs = tl.load(X + (rindex + N * row_idx), mask=xmask, other=0.0).to(tl.float32) | |
| square = xs * xs | |
| square_sum = tl.sum(square, 1)[:, None] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Accepted
|
👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:
If CI fails, you can run linting and testing checks locally according Contributing and Testing. |
Signed-off-by: Ascendyh <hw7osiris@outlook.com>
Signed-off-by: Ascendyh <hw7osiris@outlook.com>
| key=["D"], | ||
| ) | ||
| @triton.jit | ||
| def l2norm_fwd_kernel1( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
add op test.
| raise RuntimeError("This layer doesn't support feature dim >= 64KB.") | ||
|
|
||
| if not USE_DEFAULT_FLA_NORM: | ||
| MBLOCK = 69 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can this branch be deleted?
|
need golden test |
What this PR does / why we need it?
Does this PR introduce any user-facing change?
How was this patch tested?