256 changes: 256 additions & 0 deletions vllm_ascend/ops/triton/layernorm/split_qkv_rmsnorm_rope.py
@@ -0,0 +1,256 @@
import torch
import torch_npu
import triton
import triton.language as tl
import triton.runtime.driver as driver

import torch_npu._inductor


@triton.jit
def split_qkv_rmsnorm_rope_kernel(
input_ptr,
sin_ptr,
cos_ptr,
q_ptr,
k_ptr,
v_ptr,
q_weight_ptr,
q_bias_ptr,
k_weight_ptr,
k_bias_ptr,
batch_size,
q_hidden_size: tl.constexpr,
kv_hidden_size: tl.constexpr,
total_hidden_size: tl.constexpr,
eps: tl.constexpr,
Q_BLOCK_SIZE: tl.constexpr,
KV_BLOCK_SIZE: tl.constexpr,
BIAS: tl.constexpr,
HEAD_DIM: tl.constexpr,
HALF_HEAD_DIM: tl.constexpr,
):
Comment on lines +12 to +33

Contributor review comment (high):

There is significant code duplication between the processing loops for the Q and K projections (lines 37-106 for Q, lines 108-176 for K). This makes the code harder to read, maintain, and debug: a bug fix or optimization applied to one path can easily be missed in the other.

Consider refactoring the common RMSNorm and RoPE logic into a separate @triton.jit helper function, called from both the Q and K loops with the appropriate parameters (pointers, sizes, offsets). A rough sketch of such a helper is included after the kernel body below.

row_pid = tl.program_id(0)
col_pid = tl.program_id(1)
row_step = tl.num_programs(0)
# q
weight_values = tl.load(q_weight_ptr + tl.arange(0, HEAD_DIM))
if BIAS:
bias_values = tl.load(q_bias_ptr + tl.arange(0, HEAD_DIM))
input_offset = row_pid * total_hidden_size
output_offset = row_pid * q_hidden_size
input_offset_step = row_step * total_hidden_size
output_offset_step = row_step * q_hidden_size
for row_idx in tl.range(row_pid, batch_size, row_step):
col_indices = col_pid * Q_BLOCK_SIZE + tl.arange(0, Q_BLOCK_SIZE)
valid_mask = col_indices < q_hidden_size
input_values = (
tl.load(input_ptr + input_offset + col_indices, mask=valid_mask, other=0.0)
.to(tl.float32)
.reshape(Q_BLOCK_SIZE // HEAD_DIM, HEAD_DIM)
)
squares = input_values * input_values
variances = tl.sum(squares, axis=1) / HEAD_DIM
reciprocal_std = (1 / tl.sqrt(variances + eps)).reshape(
Q_BLOCK_SIZE // HEAD_DIM, 1
)
normalized_values = (
input_values * reciprocal_std
) # (Q_BLOCK_SIZE//HEAD_DIM, HEAD_DIM)
if BIAS:
normalized_values = (normalized_values * weight_values + bias_values).to(
tl.bfloat16
)
else:
normalized_values = (normalized_values * weight_values).to(tl.bfloat16)
# rope
sc_offsets = row_idx * HEAD_DIM + tl.arange(0, HEAD_DIM)
sin = (tl.load(sin_ptr + sc_offsets)).reshape(1, HEAD_DIM)
cos = (tl.load(cos_ptr + sc_offsets)).reshape(1, HEAD_DIM)
x1 = tl.extract_slice(
normalized_values,
offsets=(0, 0),
sizes=(Q_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM),
strides=(1, 1),
)
x2 = tl.extract_slice(
normalized_values,
offsets=(0, HALF_HEAD_DIM),
sizes=(Q_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM),
strides=(1, 1),
)
cat_x = tl.zeros((Q_BLOCK_SIZE // HEAD_DIM, HEAD_DIM), dtype=tl.bfloat16)
cat_x = tl.insert_slice(
cat_x,
-x2,
offsets=(0, 0),
sizes=(Q_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM),
strides=(1, 1),
)
cat_x = tl.insert_slice(
cat_x,
x1,
offsets=(0, HALF_HEAD_DIM),
sizes=(Q_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM),
strides=(1, 1),
)
roped_q = cat_x * sin + normalized_values * cos
# store
tl.store(
q_ptr + output_offset + col_indices,
roped_q.reshape(Q_BLOCK_SIZE).to(q_ptr.dtype.element_ty),
mask=valid_mask,
)
input_offset += input_offset_step
output_offset += output_offset_step

# k
weight_values = tl.load(k_weight_ptr + tl.arange(0, HEAD_DIM))
if BIAS:
bias_values = tl.load(k_bias_ptr + tl.arange(0, HEAD_DIM))
input_offset = row_pid * total_hidden_size + q_hidden_size
output_offset = row_pid * kv_hidden_size
output_offset_step = row_step * kv_hidden_size
for row_idx in tl.range(row_pid, batch_size, row_step):
col_indices = col_pid * KV_BLOCK_SIZE + tl.arange(0, KV_BLOCK_SIZE)
valid_mask = col_indices < kv_hidden_size
input_values = (
tl.load(input_ptr + input_offset + col_indices, mask=valid_mask, other=0.0)
.to(tl.float32)
.reshape(KV_BLOCK_SIZE // HEAD_DIM, HEAD_DIM)
)
squares = input_values * input_values
variances = tl.sum(squares, axis=1) / HEAD_DIM
reciprocal_std = (1 / tl.sqrt(variances + eps)).reshape(
KV_BLOCK_SIZE // HEAD_DIM, 1
)
normalized_values = (
input_values * reciprocal_std
        )  # (KV_BLOCK_SIZE//HEAD_DIM, HEAD_DIM)
if BIAS:
normalized_values = (normalized_values * weight_values + bias_values).to(
tl.bfloat16
)
else:
normalized_values = (normalized_values * weight_values).to(tl.bfloat16)
        # rope
sc_offsets = row_idx * HEAD_DIM + tl.arange(0, HEAD_DIM)
sin = (tl.load(sin_ptr + sc_offsets)).reshape(1, HEAD_DIM)
cos = (tl.load(cos_ptr + sc_offsets)).reshape(1, HEAD_DIM)
x1 = tl.extract_slice(
normalized_values,
offsets=(0, 0),
sizes=(KV_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM),
strides=(1, 1),
)
x2 = tl.extract_slice(
normalized_values,
offsets=(0, HALF_HEAD_DIM),
sizes=(KV_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM),
strides=(1, 1),
)
cat_x = tl.zeros((KV_BLOCK_SIZE // HEAD_DIM, HEAD_DIM), dtype=tl.bfloat16)
cat_x = tl.insert_slice(
cat_x,
-x2,
offsets=(0, 0),
sizes=(KV_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM),
strides=(1, 1),
)
cat_x = tl.insert_slice(
cat_x,
x1,
offsets=(0, HALF_HEAD_DIM),
sizes=(KV_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM),
strides=(1, 1),
)
roped_k = cat_x * sin + normalized_values * cos
# store
tl.store(
k_ptr + output_offset + col_indices,
roped_k.to(tl.bfloat16).reshape(KV_BLOCK_SIZE),
mask=valid_mask,
)
input_offset += input_offset_step
output_offset += output_offset_step

# v
input_offset = row_pid * total_hidden_size + q_hidden_size + kv_hidden_size
output_offset = row_pid * kv_hidden_size
for _ in tl.range(row_pid, batch_size, row_step):
col_indices = col_pid * KV_BLOCK_SIZE + tl.arange(0, KV_BLOCK_SIZE)
valid_mask = col_indices < kv_hidden_size
input_values = tl.load(
input_ptr + input_offset + col_indices, mask=valid_mask, other=0.0
)

tl.store(v_ptr + output_offset + col_indices, input_values, mask=valid_mask)
input_offset += input_offset_step
output_offset += output_offset_step
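
As suggested in the duplication comment above, the shared RMSNorm + RoPE body could be factored into a @triton.jit helper along the following lines. This is a rough sketch rather than a drop-in change: the helper name and exact argument split are hypothetical, it only reuses operations already present in this kernel (including the tl.extract_slice / tl.insert_slice calls), and it reloads the per-head weight on every call, whereas the current kernel hoists that load out of the row loop.

@triton.jit
def _rmsnorm_rope_rows(
    input_ptr,
    out_ptr,
    sin_ptr,
    cos_ptr,
    weight_ptr,
    bias_ptr,
    row_idx,
    input_offset,
    output_offset,
    col_indices,
    valid_mask,
    eps: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    BIAS: tl.constexpr,
    HEAD_DIM: tl.constexpr,
    HALF_HEAD_DIM: tl.constexpr,
):
    # Shared per-row RMSNorm + rotate-half RoPE; the Q and K loops would call
    # this with their own pointers, block size, and running offsets.
    weight_values = tl.load(weight_ptr + tl.arange(0, HEAD_DIM))
    values = (
        tl.load(input_ptr + input_offset + col_indices, mask=valid_mask, other=0.0)
        .to(tl.float32)
        .reshape(BLOCK_SIZE // HEAD_DIM, HEAD_DIM)
    )
    variances = tl.sum(values * values, axis=1) / HEAD_DIM
    reciprocal_std = (1 / tl.sqrt(variances + eps)).reshape(BLOCK_SIZE // HEAD_DIM, 1)
    normalized = values * reciprocal_std
    if BIAS:
        bias_values = tl.load(bias_ptr + tl.arange(0, HEAD_DIM))
        normalized = (normalized * weight_values + bias_values).to(tl.bfloat16)
    else:
        normalized = (normalized * weight_values).to(tl.bfloat16)
    # rope
    sc_offsets = row_idx * HEAD_DIM + tl.arange(0, HEAD_DIM)
    sin = tl.load(sin_ptr + sc_offsets).reshape(1, HEAD_DIM)
    cos = tl.load(cos_ptr + sc_offsets).reshape(1, HEAD_DIM)
    x1 = tl.extract_slice(
        normalized,
        offsets=(0, 0),
        sizes=(BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM),
        strides=(1, 1),
    )
    x2 = tl.extract_slice(
        normalized,
        offsets=(0, HALF_HEAD_DIM),
        sizes=(BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM),
        strides=(1, 1),
    )
    rotated = tl.zeros((BLOCK_SIZE // HEAD_DIM, HEAD_DIM), dtype=tl.bfloat16)
    rotated = tl.insert_slice(
        rotated,
        -x2,
        offsets=(0, 0),
        sizes=(BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM),
        strides=(1, 1),
    )
    rotated = tl.insert_slice(
        rotated,
        x1,
        offsets=(0, HALF_HEAD_DIM),
        sizes=(BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM),
        strides=(1, 1),
    )
    roped = rotated * sin + normalized * cos
    tl.store(
        out_ptr + output_offset + col_indices,
        roped.reshape(BLOCK_SIZE).to(out_ptr.dtype.element_ty),
        mask=valid_mask,
    )

With such a helper in place, the Q and K loops would each shrink to computing col_indices and valid_mask, calling the helper with their own pointers and block size, and advancing the running offsets.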


kernels = {}

def get_npu_properties():
device = torch.npu.current_device()
return driver.active.utils.get_device_properties(device)

num_vectorcore = get_npu_properties()["num_vectorcore"]
Comment on lines +194 to +198

Contributor review comment (high):

The num_vectorcore property is fetched at module import time using torch.npu.current_device(). This can lead to incorrect behavior if the active device changes after import, or in multi-device scenarios. The device properties should be fetched dynamically: get_npu_properties should be modified to accept a device, and the module-level num_vectorcore should be removed.

Suggested change
-def get_npu_properties():
-    device = torch.npu.current_device()
-    return driver.active.utils.get_device_properties(device)
-num_vectorcore = get_npu_properties()["num_vectorcore"]
+def get_npu_properties(device):
+    return driver.active.utils.get_device_properties(device)
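
If querying the driver on every call turns out to be a measurable overhead, the lookup can be cached per device while still being resolved dynamically. A minimal sketch, assuming the reviewer's proposed get_npu_properties(device) signature; the helper name and the use of functools.lru_cache are illustrative, not part of this PR:

from functools import lru_cache

@lru_cache(maxsize=None)
def _num_vectorcore(device_index: int) -> int:
    # Query the driver once per device index and memoize the result.
    return driver.active.utils.get_device_properties(device_index)["num_vectorcore"]

# e.g. inside split_qkv_rmsnorm_rope:
#     num_vectorcore = _num_vectorcore(input.get_device())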


def split_qkv_rmsnorm_rope(
input,
sin,
cos,
q_weight,
k_weight,
q_hidden_size,
kv_hidden_size,
head_dim,
eps,
q_bias,
k_bias,
):

KV_BLOCK_SIZE = triton.next_power_of_2(head_dim)
assert KV_BLOCK_SIZE == head_dim
assert q_hidden_size % kv_hidden_size == 0
Q_BLOCK_SIZE = q_hidden_size // kv_hidden_size * head_dim
batch_size = input.shape[0]
total_hidden_size = q_hidden_size + kv_hidden_size * 2
q_output = torch.empty(
batch_size, q_hidden_size, device=input.device, dtype=input.dtype
)
k_output = torch.empty(
batch_size, kv_hidden_size, device=input.device, dtype=input.dtype
)
v_output = torch.empty(
batch_size, kv_hidden_size, device=input.device, dtype=input.dtype
)
n_cols = kv_hidden_size // KV_BLOCK_SIZE
assert num_vectorcore % n_cols == 0
n_rows = num_vectorcore // n_cols
Contributor review comment (high):

To complete the fix for fetching device properties dynamically, num_vectorcore should be computed here from the input tensor's device before it is used to calculate n_rows.

Suggested change
-    n_rows = num_vectorcore // n_cols
+    num_vectorcore = get_npu_properties(input.device)["num_vectorcore"]
+    n_rows = num_vectorcore // n_cols

BIAS = q_bias is not None
Contributor review comment (critical):

The BIAS flag is determined solely by whether q_bias is None, but the same flag is used inside the kernel to guard access to both q_bias_ptr and k_bias_ptr. If q_bias is provided while k_bias is None, the kernel will dereference a null pointer for k_bias_ptr, leading to a crash. The biases for Q and K should be handled independently.

To fix this, introduce separate flags for the Q and K biases (e.g. Q_BIAS and K_BIAS), pass both to the kernel, and update the kernel logic to use the respective flag.
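    # One possible shape of that fix (a sketch, not part of this PR): track the
    # two biases independently and pass both constexpr flags to the kernel,
    #     Q_BIAS = q_bias is not None
    #     K_BIAS = k_bias is not None
    # then guard each bias load in the kernel with its own flag, e.g.
    #     if Q_BIAS:
    #         bias_values = tl.load(q_bias_ptr + tl.arange(0, HEAD_DIM))
    # for the Q loop and
    #     if K_BIAS:
    #         bias_values = tl.load(k_bias_ptr + tl.arange(0, HEAD_DIM))
    # for the K loop.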


split_qkv_rmsnorm_rope_kernel[(n_rows, n_cols, 1)](
input,
sin,
cos,
q_output,
k_output,
v_output,
q_weight,
q_bias,
k_weight,
k_bias,
batch_size,
q_hidden_size,
kv_hidden_size,
total_hidden_size,
eps,
Q_BLOCK_SIZE,
KV_BLOCK_SIZE,
BIAS,
head_dim,
head_dim // 2,
)
return q_output, k_output, v_output
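
For orientation, here is a hedged usage sketch of split_qkv_rmsnorm_rope alongside a plain PyTorch reference of what the fused kernel computes: the packed projection is split into Q/K/V, Q and K get per-head RMSNorm followed by rotate-half RoPE, and V is copied through unchanged. The tensor shapes, the eps value, and the idea of checking against an eager reference are illustrative assumptions, not part of this PR; it also assumes an Ascend NPU device whose num_vectorcore is divisible by kv_hidden_size // head_dim, as the wrapper asserts.

def reference_split_qkv_rmsnorm_rope(x, sin, cos, q_w, k_w,
                                     q_hidden, kv_hidden, head_dim, eps):
    # Eager-mode reference: split the fused projection, apply per-head RMSNorm
    # to Q and K, then rotate-half RoPE; V is copied through untouched.
    q, k, v = x.split([q_hidden, kv_hidden, kv_hidden], dim=-1)

    def rmsnorm(t, w):
        t = t.reshape(t.shape[0], -1, head_dim).to(torch.float32)
        rstd = torch.rsqrt(t.pow(2).mean(-1, keepdim=True) + eps)
        return (t * rstd * w).to(torch.bfloat16)

    def rope(t):
        x1, x2 = t.chunk(2, dim=-1)
        rotated = torch.cat((-x2, x1), dim=-1)
        return rotated * sin.unsqueeze(1) + t * cos.unsqueeze(1)

    q = rope(rmsnorm(q, q_w)).reshape(x.shape[0], q_hidden)
    k = rope(rmsnorm(k, k_w)).reshape(x.shape[0], kv_hidden)
    return q.to(x.dtype), k.to(x.dtype), v


if __name__ == "__main__":
    # Illustrative sizes only; any model with q_hidden % kv_hidden == 0 and a
    # power-of-two head_dim satisfies the wrapper's asserts.
    n_tokens, q_hidden, kv_hidden, head_dim, eps = 16, 4096, 512, 128, 1e-6
    dev = "npu"
    x = torch.randn(n_tokens, q_hidden + 2 * kv_hidden, dtype=torch.bfloat16, device=dev)
    sin = torch.randn(n_tokens, head_dim, dtype=torch.bfloat16, device=dev)
    cos = torch.randn(n_tokens, head_dim, dtype=torch.bfloat16, device=dev)
    q_w = torch.randn(head_dim, dtype=torch.bfloat16, device=dev)
    k_w = torch.randn(head_dim, dtype=torch.bfloat16, device=dev)

    # No biases here, so BIAS is False inside the kernel and the bias pointers
    # are never dereferenced.
    q, k, v = split_qkv_rmsnorm_rope(x, sin, cos, q_w, k_w,
                                     q_hidden, kv_hidden, head_dim, eps,
                                     None, None)
    q_ref, k_ref, v_ref = reference_split_qkv_rmsnorm_rope(
        x, sin, cos, q_w, k_w, q_hidden, kv_hidden, head_dim, eps)
    # Expect small bfloat16 rounding differences between the two paths.
    print(torch.max(torch.abs(q - q_ref)), torch.max(torch.abs(k - k_ref)))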