147 changes: 147 additions & 0 deletions vllm_ascend/ops/triton/fla/l2norm.py
@@ -0,0 +1,147 @@
from typing import Optional
import torch
import triton
import triton.language as tl
import triton.runtime.driver as driver

USE_DEFAULT_FLA_NORM = False

def get_npu_properties():
    device = torch.npu.current_device()
    return driver.active.utils.get_device_properties(device)

@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps) for num_warps in [1, 2, 4, 8, 16, 32]
    ],
    key=["D"],
)
@triton.jit
def l2norm_fwd_kernel1(
Collaborator:
add op test.
(A sketch of such a test is included after the end of the diff.)

    x,
    y,
    D,
    BD: tl.constexpr,
    eps,
):
    i_t = tl.program_id(0)
    x += i_t * D
    y += i_t * D
    # Compute mean and variance
    cols = tl.arange(0, BD)
    mask = cols < D
    b_x = tl.load(x + cols, mask=mask, other=0.0).to(tl.float32)
    b_var = tl.sum(b_x * b_x, axis=0)
    b_rstd = 1 / tl.sqrt(b_var + eps)
Contributor (severity: high):
For better performance, use tl.rsqrt, which is generally faster than 1 / tl.sqrt for computing the reciprocal square root.

Suggested change:
-    b_rstd = 1 / tl.sqrt(b_var + eps)
+    b_rstd = tl.rsqrt(b_var + eps)

Author:
In fact, in the Qwen3-Next model we only use the kernel l2norm_fwd_kernel2_loop, so the other kernels remain consistent with those in vLLM.

    # tl.store(Rstd + i_t, rstd)
    # Normalize and apply linear transformation
    b_y = b_x * b_rstd
    tl.store(y + cols, b_y, mask=mask)


@triton.autotune(
    configs=[
        triton.Config({"BT": BT}, num_warps=num_warps)
        for num_warps in [1, 2, 4, 8, 16]
        for BT in [8, 16, 32, 64, 128]
    ],
    key=["D"],
)
@triton.jit(do_not_specialize=["NB"])
def l2norm_fwd_kernel(
    x,
    y,
    eps,
    NB,
    T,
    D: tl.constexpr,
    BT: tl.constexpr,
    BD: tl.constexpr,
):
    i_t = tl.program_id(0)
    p_x = tl.make_block_ptr(x, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0))
    b_x = tl.load(p_x, boundary_check=(0, 1)).to(tl.float32)
    b_var = tl.sum(b_x * b_x, axis=1)
    b_y = b_x / tl.sqrt(b_var + eps)[:, None]
Contributor (severity: high):
For better performance, you can use tl.rsqrt and multiplication instead of division by tl.sqrt. tl.rsqrt is often faster.

Suggested change:
-    b_y = b_x / tl.sqrt(b_var + eps)[:, None]
+    b_y = b_x * tl.rsqrt(b_var + eps)[:, None]

Author:
In fact, in the Qwen3-Next model we only use the kernel l2norm_fwd_kernel2_loop, so the other kernels remain consistent with those in vLLM.

    p_y = tl.make_block_ptr(y, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0))
    tl.store(p_y, b_y.to(p_y.dtype.element_ty), boundary_check=(0, 1))


@triton.jit
def l2norm_fwd_kernel2_loop(X, Y, eps, M, N: tl.constexpr,
                            MBLOCK: tl.constexpr, NUM_CHUNKS: tl.constexpr):
    base_row = tl.program_id(0) * (NUM_CHUNKS * MBLOCK)
    rindex = tl.arange(0, N)[None, :]

    for chunk in range(NUM_CHUNKS):
        row_idx = base_row + chunk * MBLOCK + tl.arange(0, MBLOCK)[:, None]
        xmask = row_idx < M

        xs = tl.load(X + (rindex + N * row_idx), xmask).to(tl.float32)
        square = xs * xs
        square_sum = tl.sum(tl.where(xmask, square, 0), 1)[:, None]
Contributor (severity: high):
You can simplify this and potentially improve performance by using the other parameter in tl.load to handle masked-out values. This avoids the need for tl.where before the summation.

Suggested change:
-        xs = tl.load(X + (rindex + N * row_idx), xmask).to(tl.float32)
-        square = xs * xs
-        square_sum = tl.sum(tl.where(xmask, square, 0), 1)[:, None]
+        xs = tl.load(X + (rindex + N * row_idx), mask=xmask, other=0.0).to(tl.float32)
+        square = xs * xs
+        square_sum = tl.sum(square, 1)[:, None]

Author:
Accepted.

        rsqrt = tl.rsqrt(square_sum + eps)

        tl.store(Y + (rindex + N * row_idx), xs * rsqrt, xmask)
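
The loop kernel above relies on l2norm_fwd (further down in this diff) to size the launch so that num_core programs, each walking NUM_CHUNKS blocks of MBLOCK rows, together cover all M rows; rows past M are masked off via xmask. A quick sanity check of that arithmetic, with an illustrative 40-core device and 10000 rows (both numbers are assumptions, not taken from this PR):

    import math

    T, MBLOCK, num_core = 10000, 69, 40           # num_core is a hypothetical vector-core count
    main_bs = math.ceil(T / num_core)             # rows assigned per program: 250
    num_sub_blocks = math.ceil(main_bs / MBLOCK)  # NUM_CHUNKS per program: 4
    rows_covered = num_core * num_sub_blocks * MBLOCK
    assert rows_covered >= T                      # 40 * 4 * 69 = 11040 >= 10000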

def l2norm_fwd(
    x: torch.Tensor, eps: float = 1e-6, output_dtype: torch.dtype | None = None
):
    x_shape_og = x.shape
    x = x.reshape(-1, x.shape[-1])
    # allocate output
    if output_dtype is None:
        y = torch.empty_like(x)
    else:
        y = torch.empty_like(x, dtype=output_dtype)
    assert y.stride(-1) == 1
    T, D = x.shape[0], x.shape[-1]
    # rstd = torch.empty((T,), dtype=torch.float32, device=x.device)
    # Less than 64KB per feature: enqueue fused kernel
    MAX_FUSED_SIZE = 65536 // x.element_size()
    BD = min(MAX_FUSED_SIZE, triton.next_power_of_2(D))
    if D > BD:
        raise RuntimeError("This layer doesn't support feature dim >= 64KB.")

    if not USE_DEFAULT_FLA_NORM:
        MBLOCK = 69
Collaborator:
Can this branch be deleted?

        # M, N = x.shape
        num_core = get_npu_properties()["num_vectorcore"]
        main_bs = triton.cdiv(T, num_core)
        num_sub_blocks = triton.cdiv(main_bs, MBLOCK)
        grid = (num_core, )
        l2norm_fwd_kernel2_loop[grid](
            X=x,
            Y=y,
            eps=eps,
            M=T,
            N=D,
            MBLOCK=MBLOCK,
            NUM_CHUNKS=num_sub_blocks,
        )
    else:
        if D <= 512:
            NB = triton.cdiv(T, 2048)

            def grid(meta):
                return (triton.cdiv(T, meta["BT"]),)

            l2norm_fwd_kernel[grid](
                x,
                y,
                eps,
                NB=NB,
                T=T,
                D=D,
                BD=BD,
            )
        else:
            l2norm_fwd_kernel1[(T,)](
                x,
                y,
                eps=eps,
                D=D,
                BD=BD,
            )

    return y.view(x_shape_og)
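
In response to the "add op test." comment above, here is a minimal sketch of what such an op test could look like. The test name, shapes, and tolerances are hypothetical and the repo's actual test conventions may differ; it assumes torch_npu is installed so the "npu" device is available, and it builds an fp32 reference with the same rsqrt(sum(x^2) + eps) formula the kernel uses:

    import pytest
    import torch

    from vllm_ascend.ops.triton.fla.l2norm import l2norm_fwd


    @pytest.mark.parametrize("shape", [(37, 128), (4, 1024, 128)])
    @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32])
    def test_l2norm_fwd_matches_reference(shape, dtype):
        torch.manual_seed(0)
        x = torch.randn(shape, dtype=dtype, device="npu")
        eps = 1e-6

        # Reference computed in fp32, mirroring the kernel's math.
        x_f32 = x.float()
        ref = (x_f32 * torch.rsqrt(x_f32.pow(2).sum(-1, keepdim=True) + eps)).to(x.dtype)

        out = l2norm_fwd(x, eps=eps)
        assert out.shape == x.shape
        torch.testing.assert_close(out, ref, rtol=2e-2, atol=2e-2)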