diff --git a/vllm_ascend/ops/triton/fla/l2norm.py b/vllm_ascend/ops/triton/fla/l2norm.py
new file mode 100644
index 00000000000..5d91e0c4068
--- /dev/null
+++ b/vllm_ascend/ops/triton/fla/l2norm.py
@@ -0,0 +1,144 @@
+from typing import Optional
+
+import torch
+import triton
+import triton.language as tl
+
+from vllm_ascend.ops.triton.triton_utils import get_vectorcore_num
+
+# False: use the chunked loop kernel (l2norm_fwd_kernel2_loop) sized to the
+# number of vector cores; True: fall back to the default FLA kernels.
+USE_DEFAULT_FLA_NORM = False
+
+
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=num_warps) for num_warps in [1, 2, 4, 8, 16, 32]
+    ],
+    key=["D"],
+)
+@triton.jit
+def l2norm_fwd_kernel1(
+    x,
+    y,
+    D,
+    BD: tl.constexpr,
+    eps,
+):
+    # One program normalizes one row of length D.
+    i_t = tl.program_id(0)
+    x += i_t * D
+    y += i_t * D
+    # Compute the squared L2 norm of the row
+    cols = tl.arange(0, BD)
+    mask = cols < D
+    b_x = tl.load(x + cols, mask=mask, other=0.0).to(tl.float32)
+    b_var = tl.sum(b_x * b_x, axis=0)
+    b_rstd = 1 / tl.sqrt(b_var + eps)
+    # tl.store(Rstd + i_t, rstd)
+    # Normalize
+    b_y = b_x * b_rstd
+    tl.store(y + cols, b_y, mask=mask)
+
+
+@triton.autotune(
+    configs=[
+        triton.Config({"BT": BT}, num_warps=num_warps)
+        for num_warps in [1, 2, 4, 8, 16]
+        for BT in [8, 16, 32, 64, 128]
+    ],
+    key=["D"],
+)
+@triton.jit(do_not_specialize=["NB"])
+def l2norm_fwd_kernel(
+    x,
+    y,
+    eps,
+    NB,
+    T,
+    D: tl.constexpr,
+    BT: tl.constexpr,
+    BD: tl.constexpr,
+):
+    # One program normalizes a tile of BT rows.
+    i_t = tl.program_id(0)
+    p_x = tl.make_block_ptr(x, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0))
+    b_x = tl.load(p_x, boundary_check=(0, 1)).to(tl.float32)
+    b_var = tl.sum(b_x * b_x, axis=1)
+    b_y = b_x / tl.sqrt(b_var + eps)[:, None]
+    p_y = tl.make_block_ptr(y, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0))
+    tl.store(p_y, b_y.to(p_y.dtype.element_ty), boundary_check=(0, 1))
+
+
+@triton.jit
+def l2norm_fwd_kernel2_loop(X, Y, eps, M, N: tl.constexpr,
+                            MBLOCK: tl.constexpr, NUM_CHUNKS: tl.constexpr):
+    # One program normalizes NUM_CHUNKS consecutive blocks of MBLOCK rows.
+    base_row = tl.program_id(0) * (NUM_CHUNKS * MBLOCK)
+    rindex = tl.arange(0, N)[None, :]
+
+    for chunk in range(NUM_CHUNKS):
+        row_idx = base_row + chunk * MBLOCK + tl.arange(0, MBLOCK)[:, None]
+        xmask = row_idx < M
+
+        xs = tl.load(X + (rindex + N * row_idx), mask=xmask, other=0.0).to(tl.float32)
+        square = xs * xs
+        square_sum = tl.sum(square, 1)[:, None]
+        rsqrt = tl.rsqrt(square_sum + eps)
+
+        tl.store(Y + (rindex + N * row_idx), xs * rsqrt, xmask)
+
+
+def l2norm_fwd(
+    x: torch.Tensor, eps: float = 1e-6, output_dtype: Optional[torch.dtype] = None
+):
+    """L2-normalize `x` along its last dimension."""
+    x_shape_og = x.shape
+    x = x.reshape(-1, x.shape[-1])
+    # allocate output
+    if output_dtype is None:
+        y = torch.empty_like(x)
+    else:
+        y = torch.empty_like(x, dtype=output_dtype)
+    assert y.stride(-1) == 1
+    T, D = x.shape[0], x.shape[-1]
+    # rstd = torch.empty((T,), dtype=torch.float32, device=x.device)
+    # Less than 64KB per feature: enqueue fused kernel
+    MAX_FUSED_SIZE = 65536 // x.element_size()
+    BD = min(MAX_FUSED_SIZE, triton.next_power_of_2(D))
+    if D > BD:
+        raise RuntimeError("This layer doesn't support feature dim >= 64KB.")
+
+    if not USE_DEFAULT_FLA_NORM:
+        MBLOCK = 69
+        # M, N = x.shape
+        num_core = get_vectorcore_num()
+        main_bs = triton.cdiv(T, num_core)
+        num_sub_blocks = triton.cdiv(main_bs, MBLOCK)
+        grid = (num_core, )
+        l2norm_fwd_kernel2_loop[grid](
+            X=x,
+            Y=y,
+            eps=eps,
+            M=T,
+            N=D,
+            MBLOCK=MBLOCK,
+            NUM_CHUNKS=num_sub_blocks,
+        )
+    else:
+        if D <= 512:
+            NB = triton.cdiv(T, 2048)
+
+            def grid(meta):
+                return (triton.cdiv(T, meta["BT"]),)
+
+            l2norm_fwd_kernel[grid](
+                x,
+                y,
+                eps,
+                NB=NB,
+                T=T,
+                D=D,
+                BD=BD,
+            )
+        else:
+            l2norm_fwd_kernel1[(T,)](
+                x,
+                y,
+                eps=eps,
+                D=D,
+                BD=BD,
+            )
+
+    return y.view(x_shape_og)
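
A minimal usage sketch of the new l2norm_fwd helper. The "npu" device string, tensor shape, and dtypes below are illustrative assumptions about a torch_npu / Triton-on-Ascend environment, not something this patch defines.

# Hypothetical call site; assumes the "npu" device is registered (torch_npu)
# and that the Triton kernels in this file can launch on it.
import torch

from vllm_ascend.ops.triton.fla.l2norm import l2norm_fwd

x = torch.randn(4, 128, 256, dtype=torch.bfloat16, device="npu")
y = l2norm_fwd(x, eps=1e-6, output_dtype=torch.float32)

# Vectors along the last dim now have (approximately) unit L2 norm,
# and the original shape is preserved.
assert y.shape == x.shape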