12 changes: 12 additions & 0 deletions benchmarks/run.py
@@ -336,6 +336,11 @@ class RunResult:
"examples.mamba2_chunk_state",
"helion_mamba2_chunk_state_kernel",
),
"gdn_fwd_h": (
"tritonbench.operators.gdn_fwd_h.operator",
"examples.gdn_fwd_h",
"helion_gdn_fwd_h_tb",
),
}


@@ -652,6 +657,13 @@ class RunResult:
"helion_mamba2_chunk_state_kernel_speedup": "helion_speedup",
"helion_mamba2_chunk_state_kernel_accuracy": "helion_accuracy",
},
"gdn_fwd_h": {
"eager": "baseline",
"compile_speedup": "torch_compile_speedup",
"compile_accuracy": "torch_compile_accuracy",
"helion_gdn_fwd_h_speedup": "helion_speedup",
"helion_gdn_fwd_h_accuracy": "helion_accuracy",
},
}


237 changes: 237 additions & 0 deletions examples/gdn_fwd_h.py
@@ -0,0 +1,237 @@
"""
Gated Delta Net Fwd H Kernel
============================

This example implements the fwd_h (chunked hidden-state forward) kernel used in Gated Delta Net.
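
Roughly, for each chunk ``t`` the kernel first writes the incoming per-(batch, head)
state ``h_t`` to the output and then updates it as::

    v_t = (u_t - w_t @ h_t) * exp(g_last - g_t)   # per-position decay within the chunk
    h_{t+1} = exp(g_last) * h_t + k_t^T @ v_t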
"""

# %%
# Imports
# -------
from __future__ import annotations

import math
from typing import Callable

import torch

import helion
from helion._testing import DEVICE
from helion._testing import run_example
import helion.language as hl


# %%
# Helion Kernel Implementation
# ----------------------------
@helion.kernel()
def helion_gdn_fwd_h_kernel(
k_c: torch.Tensor, w_c: torch.Tensor, u_c: torch.Tensor, g_c: torch.Tensor
) -> torch.Tensor:
"""
Argument:
k_c: (batch, nchunks, chunk_size, nheads, dhead)
w_c: (batch, nchunks, chunk_size, nheads, dhead)
u_c: (batch, nchunks, chunk_size, nheads, expand_v*dhead)
g_c: (batch, nchunks, chunk_size, nheads)
Return:
h: (batch, nchunks, nheads, dhead, expand_v*dhead)
"""

batch, nchunks, chunk_size, nheads, dhead = k_c.shape
dhead = hl.specialize(dhead)
chunk_size = hl.specialize(chunk_size)
dstate = u_c.shape[-1]

acc_dtype = torch.float32
dtype = k_c.dtype

h = torch.empty(
batch, nchunks, nheads, dhead, dstate, dtype=dtype, device=k_c.device
)
block_v = hl.register_block_size(dstate)
seqlen = chunk_size * nchunks

for tile_b, tile_h, tile_v in hl.tile(
[batch, nheads, dstate], block_size=[1, 1, block_v]
):
b_h = hl.zeros([dhead, tile_v], dtype=acc_dtype)
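        # Sequential scan over chunks: the state entering chunk i_t is written to h first, then updated.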
for i_t in range(nchunks):
h[tile_b.begin, i_t, tile_h.begin, :, tile_v] = b_h.to(dtype)
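            # Correction term for this chunk: b_v = u - w @ h_prev, computed in fp32.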
b_w = w_c[tile_b.begin, i_t, :, tile_h.begin, :]
c_h = b_h.to(dtype)
b_v = hl.dot(b_w, c_h, out_dtype=acc_dtype)
p_v = u_c[tile_b.begin, i_t, :, tile_h.begin, tile_v].to(acc_dtype)
b_v = p_v - b_v
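            # Per-position decay weights exp(g_last - g_i); positions past seqlen are masked to zero.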
m_t = (i_t * chunk_size + hl.arange(0, chunk_size)) < seqlen
b_g_last = g_c[tile_b.begin, i_t, chunk_size - 1, tile_h.begin].to(
acc_dtype
)
b_g = g_c[tile_b.begin, i_t, :, tile_h.begin].to(acc_dtype)
b_v *= torch.where(m_t, torch.exp(b_g_last - b_g), 0)[:, None]
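            # Decay the running state by exp(g_last) before adding this chunk's contribution.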
b_g_last = torch.exp(b_g_last)
b_h *= b_g_last
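            # Accumulate the chunk contribution: b_h += k^T @ b_v.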
b_v = b_v.to(dtype)
p_k = k_c[tile_b.begin, i_t, :, tile_h.begin, :]
b_h = hl.dot(p_k.T, b_v, acc=b_h)
return h


def helion_gdn_fwd_h(
k: torch.Tensor, w: torch.Tensor, u: torch.Tensor, g: torch.Tensor, chunk_size: int
) -> torch.Tensor:
"""
Argument:
k: (batch, seqlen, nheads, dhead)
w: (batch, seqlen, nheads, dhead)
u: (batch, seqlen, nheads, expand_v*dhead)
g: (batch, seqlen, nheads)
chunk_size: int
Return:
h: (batch, nchunks, nheads, dhead, expand_v*dhead)
"""

batch, seqlen, nheads, dhead = k.shape
dstate = u.shape[-1]
nchunks = (seqlen + chunk_size - 1) // chunk_size
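    # Note: the reshapes below require seqlen to be an exact multiple of chunk_size.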

k_c = k.reshape(batch, nchunks, chunk_size, nheads, dhead)
w_c = w.reshape(batch, nchunks, chunk_size, nheads, dhead)
u_c = u.reshape(batch, nchunks, chunk_size, nheads, dstate)
g_c = g.reshape(batch, nchunks, chunk_size, nheads)
return helion_gdn_fwd_h_kernel(k_c, w_c, u_c, g_c)
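
# A minimal usage sketch (shapes are illustrative and assume a CUDA device):
#
#     k = torch.randn(1, 512, 4, 64, dtype=torch.bfloat16, device="cuda")
#     w = torch.randn_like(k)
#     u = torch.randn(1, 512, 4, 128, dtype=torch.bfloat16, device="cuda")
#     g = (-torch.rand(1, 512, 4, dtype=torch.float32, device="cuda")).cumsum(dim=1)
#     h = helion_gdn_fwd_h(k, w, u, g, chunk_size=128)  # -> (1, 4, 4, 64, 128)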


def helion_gdn_fwd_h_tb(
tb_obj: object,
k: torch.Tensor,
w: torch.Tensor,
u: torch.Tensor,
g: torch.Tensor,
chunk_size: int,
) -> Callable[[], torch.Tensor]:
"""
Argument:
k: (batch, seqlen, nheads, dhead)
w: (batch, seqlen, nheads, dhead)
u: (batch, seqlen, nheads, expand_v*dhead)
g: (batch, seqlen, nheads)
chunk_size: int
Return:
h: (batch, nchunks, nheads, dhead, expand_v*dhead)
"""
return lambda: helion_gdn_fwd_h(k, w, u, g, chunk_size)


# %%
# Reference Function
# ------------------
def ref_gdn_fwd_h(
k: torch.Tensor, w: torch.Tensor, u: torch.Tensor, g: torch.Tensor, chunk_size: int
) -> torch.Tensor:
"""
Argument:
k: (batch, seqlen, nheads, dhead)
w: (batch, seqlen, nheads, dhead)
u: (batch, seqlen, nheads, expand_v*dhead)
g: (batch, seqlen, nheads)
chunk_size: int
Return:
h: (batch, nchunks, nheads, dhead, expand_v*dhead)
"""

batch, seqlen, nheads, dhead = k.shape
expand_v = u.shape[-1] // dhead
nchunks = (seqlen + chunk_size - 1) // chunk_size

acc_dtype = torch.float32
dtype = k.dtype

h = torch.empty(
batch, nchunks, nheads, dhead, expand_v * dhead, dtype=k.dtype, device=k.device
)
b_h = torch.zeros(
batch, nheads, dhead, expand_v * dhead, dtype=acc_dtype, device=k.device
)

k_c = k.reshape(batch, nchunks, chunk_size, nheads, dhead)
w_c = w.reshape(batch, nchunks, chunk_size, nheads, dhead)
u_c = u.reshape(batch, nchunks, chunk_size, nheads, expand_v * dhead)
g_c = g.reshape(batch, nchunks, chunk_size, nheads)
for i_t in range(nchunks):
h[:, i_t, :, :, :] = b_h.to(dtype)
b_w = w_c[:, i_t, :, :, :].to(acc_dtype)
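        # Round-trip b_h through the storage dtype so the reference matches the kernel's precision when it reads the state back; the same round-trip is applied to b_v below.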
c_h = b_h.to(dtype).to(acc_dtype)
b_v = torch.einsum("bchk,bhkv->bchv", b_w, c_h)
p_v = u_c[:, i_t, :, :, :].to(acc_dtype)
b_v = p_v - b_v
last_idx = min((i_t + 1) * chunk_size, seqlen) - 1
m_t = (i_t * chunk_size + torch.arange(0, chunk_size, device=k.device)) < seqlen
b_g_last = g[:, last_idx, :].to(acc_dtype)
b_g = g_c[:, i_t, :, :].to(acc_dtype) # batch, chunk, nheads
b_v *= torch.where(
m_t.unsqueeze(0).unsqueeze(-1), torch.exp(b_g_last.unsqueeze(1) - b_g), 0
).unsqueeze(-1)
b_g_last = torch.exp(b_g_last)
b_h *= b_g_last.unsqueeze(-1).unsqueeze(-1)
b_v = b_v.to(dtype).to(acc_dtype)
p_k = k_c[:, i_t, :, :, :].to(acc_dtype)
b_h += torch.einsum("bchk,bchv->bhkv", p_k, b_v)
return h


# %%
# Testing Function
# ----------------
def test(
batch: int,
nheads: int,
seqlen: int,
chunk_size: int,
dhead: int,
dstate: int,
dtype: torch.dtype = torch.float16,
) -> None:
k = torch.randn(batch, seqlen, nheads, dhead, dtype=torch.bfloat16, device=DEVICE)
k = torch.nn.functional.rms_norm(k, (dhead,))
w = torch.randn(
batch,
seqlen // chunk_size,
chunk_size,
nheads,
dhead,
dtype=torch.float32,
device=DEVICE,
)
# w = torch.nn.functional.rms_norm(w.to(torch.bfloat16), (dhead,))
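    # Replace each per-(batch, chunk, head) block of w with U @ V^T from its SVD, giving orthonormal columns.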
wu, ws, wv = torch.linalg.svd(w.permute(0, 1, 3, 2, 4), full_matrices=False)
w = torch.einsum("bnhik,bnhkj->bnhij", wu, wv)
w = (
w.permute(0, 1, 3, 2, 4)
.reshape(batch, seqlen, nheads, dhead)
.to(torch.bfloat16)
)
u = torch.randn(batch, seqlen, nheads, dstate, dtype=torch.bfloat16, device=DEVICE)
u = torch.nn.functional.rms_norm(u, (dstate,))
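    # g accumulates non-positive log-decay increments along the sequence, so exp(g_last - g_i) stays in (0, 1].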
g = torch.cumsum(
0.5
* math.log(1 / dhead)
* torch.rand(batch, seqlen, nheads, dtype=torch.float32, device=DEVICE),
dim=1,
)
args = (k, w, u, g, chunk_size)
run_example(helion_gdn_fwd_h, ref_gdn_fwd_h, args)


# %%
# Main Function
# -------------
def main() -> None:
"""
    Main entry point that runs the gated delta net fwd_h kernel test with specific parameters.
"""
test(8, 80, 4096, 256, 64, 128)


if __name__ == "__main__":
main()