-
Notifications
You must be signed in to change notification settings - Fork 623
[task] add l2norm triton kernel #4595
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,144 @@ | ||
| from typing import Optional | ||
| import torch | ||
| import triton | ||
| import triton.language as tl | ||
|
|
||
| from vllm_ascend.ops.triton.triton_utils import get_vectorcore_num | ||
|
|
||
| USE_DEFAULT_FLA_NORM = False | ||
|
|
||
| @triton.autotune( | ||
| configs=[ | ||
| triton.Config({}, num_warps=num_warps) for num_warps in [1, 2, 4, 8, 16, 32] | ||
| ], | ||
| key=["D"], | ||
| ) | ||
| @triton.jit | ||
| def l2norm_fwd_kernel1( | ||
| x, | ||
| y, | ||
| D, | ||
| BD: tl.constexpr, | ||
| eps, | ||
| ): | ||
| i_t = tl.program_id(0) | ||
| x += i_t * D | ||
| y += i_t * D | ||
| # Compute mean and variance | ||
| cols = tl.arange(0, BD) | ||
| mask = cols < D | ||
| b_x = tl.load(x + cols, mask=mask, other=0.0).to(tl.float32) | ||
| b_var = tl.sum(b_x * b_x, axis=0) | ||
| b_rstd = 1 / tl.sqrt(b_var + eps) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In fact, in the QWen3-Next model, we only use the kernel |
||
| # tl.store(Rstd + i_t, rstd) | ||
| # Normalize and apply linear transformation | ||
| b_y = b_x * b_rstd | ||
| tl.store(y + cols, b_y, mask=mask) | ||
|
|
||
|
|
||
| @triton.autotune( | ||
| configs=[ | ||
| triton.Config({"BT": BT}, num_warps=num_warps) | ||
| for num_warps in [1, 2, 4, 8, 16] | ||
| for BT in [8, 16, 32, 64, 128] | ||
| ], | ||
| key=["D"], | ||
| ) | ||
| @triton.jit(do_not_specialize=["NB"]) | ||
| def l2norm_fwd_kernel( | ||
| x, | ||
| y, | ||
| eps, | ||
| NB, | ||
| T, | ||
| D: tl.constexpr, | ||
| BT: tl.constexpr, | ||
| BD: tl.constexpr, | ||
| ): | ||
| i_t = tl.program_id(0) | ||
| p_x = tl.make_block_ptr(x, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0)) | ||
| b_x = tl.load(p_x, boundary_check=(0, 1)).to(tl.float32) | ||
| b_var = tl.sum(b_x * b_x, axis=1) | ||
| b_y = b_x / tl.sqrt(b_var + eps)[:, None] | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In fact, in the QWen3-Next model, we only use the kernel |
||
| p_y = tl.make_block_ptr(y, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0)) | ||
| tl.store(p_y, b_y.to(p_y.dtype.element_ty), boundary_check=(0, 1)) | ||
|
|
||
|
|
||
| @triton.jit | ||
| def l2norm_fwd_kernel2_loop(X, Y, eps, M, N: tl.constexpr, | ||
| MBLOCK: tl.constexpr, NUM_CHUNKS: tl.constexpr): | ||
| base_row = tl.program_id(0) * (NUM_CHUNKS * MBLOCK) | ||
| rindex = tl.arange(0, N)[None, :] | ||
|
|
||
| for chunk in range(NUM_CHUNKS): | ||
| row_idx = base_row + chunk * MBLOCK + tl.arange(0, MBLOCK)[:, None] | ||
| xmask = row_idx < M | ||
|
|
||
| xs = tl.load(X + (rindex + N * row_idx), mask=xmask, other=0.0).to(tl.float32) | ||
| square = xs * xs | ||
| square_sum = tl.sum(square, 1)[:, None] | ||
| rsqrt = tl.rsqrt(square_sum + eps) | ||
|
|
||
| tl.store(Y + (rindex + N * row_idx), xs * rsqrt, xmask) | ||
|
|
||
| def l2norm_fwd( | ||
| x: torch.Tensor, eps: float = 1e-6, output_dtype: torch.dtype | None = None | ||
| ): | ||
| x_shape_og = x.shape | ||
| x = x.reshape(-1, x.shape[-1]) | ||
| # allocate output | ||
| if output_dtype is None: | ||
| y = torch.empty_like(x) | ||
| else: | ||
| y = torch.empty_like(x, dtype=output_dtype) | ||
| assert y.stride(-1) == 1 | ||
| T, D = x.shape[0], x.shape[-1] | ||
| # rstd = torch.empty((T,), dtype=torch.float32, device=x.device) | ||
| # Less than 64KB per feature: enqueue fused kernel | ||
| MAX_FUSED_SIZE = 65536 // x.element_size() | ||
| BD = min(MAX_FUSED_SIZE, triton.next_power_of_2(D)) | ||
| if D > BD: | ||
| raise RuntimeError("This layer doesn't support feature dim >= 64KB.") | ||
|
|
||
| if not USE_DEFAULT_FLA_NORM: | ||
| MBLOCK = 69 | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can this branch be deleted? |
||
| # M, N = x.shape | ||
| num_core = get_vectorcore_num() | ||
| main_bs = triton.cdiv(T, num_core) | ||
| num_sub_blocks = triton.cdiv(main_bs, MBLOCK) | ||
| grid = (num_core, ) | ||
| l2norm_fwd_kernel2_loop[grid]( | ||
| X=x, | ||
| Y=y, | ||
| eps=eps, | ||
| M=T, | ||
| N=D, | ||
| MBLOCK=MBLOCK, | ||
| NUM_CHUNKS=num_sub_blocks, | ||
| ) | ||
| else: | ||
| if D <= 512: | ||
| NB = triton.cdiv(T, 2048) | ||
|
|
||
| def grid(meta): | ||
| return (triton.cdiv(T, meta["BT"]),) | ||
|
|
||
| l2norm_fwd_kernel[grid]( | ||
| x, | ||
| y, | ||
| eps, | ||
| NB=NB, | ||
| T=T, | ||
| D=D, | ||
| BD=BD, | ||
| ) | ||
| else: | ||
| l2norm_fwd_kernel1[(T,)]( | ||
| x, | ||
| y, | ||
| eps=eps, | ||
| D=D, | ||
| BD=BD, | ||
| ) | ||
|
|
||
| return y.view(x_shape_og) | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
add op test.