vllm_ascend/ops/triton/fla/l2norm.py (new file: 144 additions, 0 deletions)

@@ -0,0 +1,144 @@
from typing import Optional

import torch
import triton
import triton.language as tl

from vllm_ascend.ops.triton.triton_utils import get_vectorcore_num

# When False (the default here), dispatch to the Ascend-tuned chunked loop
# kernel l2norm_fwd_kernel2_loop instead of the default FLA kernels.
USE_DEFAULT_FLA_NORM = False

Collaborator: add op test.

@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps) for num_warps in [1, 2, 4, 8, 16, 32]
    ],
    key=["D"],
)
@triton.jit
def l2norm_fwd_kernel1(
    x,
    y,
    D,
    BD: tl.constexpr,
    eps,
):
    i_t = tl.program_id(0)
    x += i_t * D
    y += i_t * D
    # Compute the sum of squares for this row
    cols = tl.arange(0, BD)
    mask = cols < D
    b_x = tl.load(x + cols, mask=mask, other=0.0).to(tl.float32)
    b_var = tl.sum(b_x * b_x, axis=0)
    b_rstd = 1 / tl.sqrt(b_var + eps)
Contributor (severity: high): For better performance, use tl.rsqrt, which is generally faster than 1 / tl.sqrt for computing the reciprocal square root.

Suggested change:
-    b_rstd = 1 / tl.sqrt(b_var + eps)
+    b_rstd = tl.rsqrt(b_var + eps)

Author: In fact, in the Qwen3-Next model we only use the kernel l2norm_fwd_kernel2_loop, so the other kernels remain consistent with those in vLLM.

    # tl.store(Rstd + i_t, rstd)
    # Normalize the row
    b_y = b_x * b_rstd
    tl.store(y + cols, b_y, mask=mask)


@triton.autotune(
    configs=[
        triton.Config({"BT": BT}, num_warps=num_warps)
        for num_warps in [1, 2, 4, 8, 16]
        for BT in [8, 16, 32, 64, 128]
    ],
    key=["D"],
)
@triton.jit(do_not_specialize=["NB"])
def l2norm_fwd_kernel(
    x,
    y,
    eps,
    NB,
    T,
    D: tl.constexpr,
    BT: tl.constexpr,
    BD: tl.constexpr,
):
    i_t = tl.program_id(0)
    p_x = tl.make_block_ptr(x, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0))
    b_x = tl.load(p_x, boundary_check=(0, 1)).to(tl.float32)
    b_var = tl.sum(b_x * b_x, axis=1)
    b_y = b_x / tl.sqrt(b_var + eps)[:, None]
Contributor (severity: high): For better performance, you can use tl.rsqrt and multiplication instead of division by tl.sqrt; tl.rsqrt is often faster.

Suggested change:
-    b_y = b_x / tl.sqrt(b_var + eps)[:, None]
+    b_y = b_x * tl.rsqrt(b_var + eps)[:, None]

Author: In fact, in the Qwen3-Next model we only use the kernel l2norm_fwd_kernel2_loop, so the other kernels remain consistent with those in vLLM.

    p_y = tl.make_block_ptr(y, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0))
    tl.store(p_y, b_y.to(p_y.dtype.element_ty), boundary_check=(0, 1))


@triton.jit
def l2norm_fwd_kernel2_loop(X, Y, eps, M, N: tl.constexpr,
                            MBLOCK: tl.constexpr, NUM_CHUNKS: tl.constexpr):
    # Each program normalizes NUM_CHUNKS chunks of MBLOCK rows; rows past M
    # are masked off. Full rows of N elements are loaded with no column mask.
    base_row = tl.program_id(0) * (NUM_CHUNKS * MBLOCK)
    rindex = tl.arange(0, N)[None, :]

    for chunk in range(NUM_CHUNKS):
        row_idx = base_row + chunk * MBLOCK + tl.arange(0, MBLOCK)[:, None]
        xmask = row_idx < M

        xs = tl.load(X + (rindex + N * row_idx), mask=xmask, other=0.0).to(tl.float32)
        square = xs * xs
        square_sum = tl.sum(square, 1)[:, None]
        rsqrt = tl.rsqrt(square_sum + eps)

        tl.store(Y + (rindex + N * row_idx), xs * rsqrt, mask=xmask)

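# A worked example of the launch arithmetic used in l2norm_fwd below
# (illustrative numbers, not from the source): with T = 10000 rows and
# get_vectorcore_num() = 40 cores, each core gets main_bs = cdiv(10000, 40)
# = 250 rows, processed as NUM_CHUNKS = cdiv(250, 69) = 4 chunks of
# MBLOCK = 69 rows; the overhang past row 10000 is masked by xmask.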
def l2norm_fwd(
    x: torch.Tensor, eps: float = 1e-6, output_dtype: Optional[torch.dtype] = None
):
    x_shape_og = x.shape
    x = x.reshape(-1, x.shape[-1])
    # allocate output
    if output_dtype is None:
        y = torch.empty_like(x)
    else:
        y = torch.empty_like(x, dtype=output_dtype)
    assert y.stride(-1) == 1
    T, D = x.shape[0], x.shape[-1]
    # rstd = torch.empty((T,), dtype=torch.float32, device=x.device)
    # Less than 64KB per row: enqueue the fused kernel
    MAX_FUSED_SIZE = 65536 // x.element_size()
    BD = min(MAX_FUSED_SIZE, triton.next_power_of_2(D))
    if D > BD:
        raise RuntimeError("This layer doesn't support feature dim >= 64KB.")

    if not USE_DEFAULT_FLA_NORM:
        # Rows per chunk per core; 69 appears to be an empirically tuned value.
        MBLOCK = 69

Collaborator: Can this branch be deleted?

        # Spread the T rows evenly over the vector cores; each core loops over
        # its share in MBLOCK-row chunks inside the kernel.
        num_core = get_vectorcore_num()
        main_bs = triton.cdiv(T, num_core)
        num_sub_blocks = triton.cdiv(main_bs, MBLOCK)
        grid = (num_core, )
        l2norm_fwd_kernel2_loop[grid](
            X=x,
            Y=y,
            eps=eps,
            M=T,
            N=D,
            MBLOCK=MBLOCK,
            NUM_CHUNKS=num_sub_blocks,
        )
    else:
        if D <= 512:
            NB = triton.cdiv(T, 2048)

            def grid(meta):
                return (triton.cdiv(T, meta["BT"]),)

            l2norm_fwd_kernel[grid](
                x,
                y,
                eps,
                NB=NB,
                T=T,
                D=D,
                BD=BD,
            )
        else:
            l2norm_fwd_kernel1[(T,)](
                x,
                y,
                eps=eps,
                D=D,
                BD=BD,
            )

    return y.view(x_shape_og)
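
Addendum: a minimal sketch of the op test the collaborator requested above. It assumes an Ascend device visible to PyTorch as "npu" (the device name is an assumption) and checks l2norm_fwd against a plain PyTorch implementation of the same math; the test name is hypothetical.

import torch

from vllm_ascend.ops.triton.fla.l2norm import l2norm_fwd


def test_l2norm_fwd(device: str = "npu") -> None:
    torch.manual_seed(0)
    eps = 1e-6
    # D = 128 keeps the row small and power-of-two, within what the kernels expect.
    x = torch.randn(4, 37, 128, dtype=torch.float16, device=device)
    # Reference: x * rsqrt(sum(x^2) + eps), computed in fp32 like the kernels do.
    x_f32 = x.float()
    ref = x_f32 * torch.rsqrt((x_f32 * x_f32).sum(-1, keepdim=True) + eps)
    out = l2norm_fwd(x, eps=eps)
    assert out.shape == x.shape
    torch.testing.assert_close(out.float(), ref, rtol=1e-3, atol=1e-3)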