From fae03284350697e13585f6fcb2ffe6059646d306 Mon Sep 17 00:00:00 2001
From: Markus Hoehnerbach
Date: Wed, 12 Nov 2025 10:00:20 -0800
Subject: [PATCH] example: gdn_fwd_h

---
 benchmarks/run.py     |  12 +++
 examples/gdn_fwd_h.py | 237 ++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 249 insertions(+)
 create mode 100644 examples/gdn_fwd_h.py

diff --git a/benchmarks/run.py b/benchmarks/run.py
index 750833422..2bddf0fe2 100644
--- a/benchmarks/run.py
+++ b/benchmarks/run.py
@@ -336,6 +336,11 @@ class RunResult:
         "examples.mamba2_chunk_state",
         "helion_mamba2_chunk_state_kernel",
     ),
+    "gdn_fwd_h": (
+        "tritonbench.operators.gdn_fwd_h.operator",
+        "examples.gdn_fwd_h",
+        "helion_gdn_fwd_h_tb",
+    ),
 }
 
 
@@ -652,6 +657,13 @@ class RunResult:
         "helion_mamba2_chunk_state_kernel_speedup": "helion_speedup",
         "helion_mamba2_chunk_state_kernel_accuracy": "helion_accuracy",
     },
+    "gdn_fwd_h": {
+        "eager": "baseline",
+        "compile_speedup": "torch_compile_speedup",
+        "compile_accuracy": "torch_compile_accuracy",
+        "helion_gdn_fwd_h_speedup": "helion_speedup",
+        "helion_gdn_fwd_h_accuracy": "helion_accuracy",
+    },
 }
 
 
diff --git a/examples/gdn_fwd_h.py b/examples/gdn_fwd_h.py
new file mode 100644
index 000000000..89572ea42
--- /dev/null
+++ b/examples/gdn_fwd_h.py
@@ -0,0 +1,237 @@
+"""
+Gated Delta Net Fwd H Kernel
+============================
+
+This code implements the chunked forward hidden-state (fwd_h) kernel of Gated DeltaNet.
+"""
+
+# %%
+# Imports
+# -------
+from __future__ import annotations
+
+import math
+from typing import Callable
+
+import torch
+
+import helion
+from helion._testing import DEVICE
+from helion._testing import run_example
+import helion.language as hl
+
+
+# %%
+# Helion Kernel Implementation
+# ----------------------------
+@helion.kernel()
+def helion_gdn_fwd_h_kernel(
+    k_c: torch.Tensor, w_c: torch.Tensor, u_c: torch.Tensor, g_c: torch.Tensor
+) -> torch.Tensor:
+    """
+    Arguments:
+        k_c: (batch, nchunks, chunk_size, nheads, dhead)
+        w_c: (batch, nchunks, chunk_size, nheads, dhead)
+        u_c: (batch, nchunks, chunk_size, nheads, expand_v*dhead)
+        g_c: (batch, nchunks, chunk_size, nheads)
+    Returns:
+        h: (batch, nchunks, nheads, dhead, expand_v*dhead)
+    """
+
+    batch, nchunks, chunk_size, nheads, dhead = k_c.shape
+    dhead = hl.specialize(dhead)
+    chunk_size = hl.specialize(chunk_size)
+    dstate = u_c.shape[-1]
+
+    acc_dtype = torch.float32
+    dtype = k_c.dtype
+
+    h = torch.empty(
+        batch, nchunks, nheads, dhead, dstate, dtype=dtype, device=k_c.device
+    )
+    block_v = hl.register_block_size(dstate)
+    seqlen = chunk_size * nchunks
+
+    for tile_b, tile_h, tile_v in hl.tile(
+        [batch, nheads, dstate], block_size=[1, 1, block_v]
+    ):
+        b_h = hl.zeros([dhead, tile_v], dtype=acc_dtype)
+        for i_t in range(nchunks):
+            h[tile_b.begin, i_t, tile_h.begin, :, tile_v] = b_h.to(dtype)
+            b_w = w_c[tile_b.begin, i_t, :, tile_h.begin, :]
+            c_h = b_h.to(dtype)
+            b_v = hl.dot(b_w, c_h, out_dtype=acc_dtype)
+            p_v = u_c[tile_b.begin, i_t, :, tile_h.begin, tile_v].to(acc_dtype)
+            b_v = p_v - b_v
+            m_t = (i_t * chunk_size + hl.arange(0, chunk_size)) < seqlen
+            b_g_last = g_c[tile_b.begin, i_t, chunk_size - 1, tile_h.begin].to(
+                acc_dtype
+            )
+            b_g = g_c[tile_b.begin, i_t, :, tile_h.begin].to(acc_dtype)
+            b_v *= torch.where(m_t, torch.exp(b_g_last - b_g), 0)[:, None]
+            b_g_last = torch.exp(b_g_last)
+            b_h *= b_g_last
+            b_v = b_v.to(dtype)
+            p_k = k_c[tile_b.begin, i_t, :, tile_h.begin, :]
+            b_h = hl.dot(p_k.T, b_v, acc=b_h)
+    return h
+
+
+def helion_gdn_fwd_h(
+    k: torch.Tensor, w: torch.Tensor, u: torch.Tensor, g: torch.Tensor, chunk_size: int
+) -> torch.Tensor:
+    """
+    Arguments:
+        k: (batch, seqlen, nheads, dhead)
+        w: (batch, seqlen, nheads, dhead)
+        u: (batch, seqlen, nheads, expand_v*dhead)
+        g: (batch, seqlen, nheads)
+        chunk_size: int
+    Returns:
+        h: (batch, nchunks, nheads, dhead, expand_v*dhead)
+    """
+
+    batch, seqlen, nheads, dhead = k.shape
+    dstate = u.shape[-1]
+    nchunks = (seqlen + chunk_size - 1) // chunk_size
+
+    k_c = k.reshape(batch, nchunks, chunk_size, nheads, dhead)
+    w_c = w.reshape(batch, nchunks, chunk_size, nheads, dhead)
+    u_c = u.reshape(batch, nchunks, chunk_size, nheads, dstate)
+    g_c = g.reshape(batch, nchunks, chunk_size, nheads)
+    return helion_gdn_fwd_h_kernel(k_c, w_c, u_c, g_c)
+
+
+def helion_gdn_fwd_h_tb(
+    tb_obj: object,
+    k: torch.Tensor,
+    w: torch.Tensor,
+    u: torch.Tensor,
+    g: torch.Tensor,
+    chunk_size: int,
+) -> Callable[[], torch.Tensor]:
+    """
+    Arguments:
+        k: (batch, seqlen, nheads, dhead)
+        w: (batch, seqlen, nheads, dhead)
+        u: (batch, seqlen, nheads, expand_v*dhead)
+        g: (batch, seqlen, nheads)
+        chunk_size: int
+    Returns:
+        a callable returning h: (batch, nchunks, nheads, dhead, expand_v*dhead)
+    """
+    return lambda: helion_gdn_fwd_h(k, w, u, g, chunk_size)
+
+
+# %%
+# Reference Function
+# ------------------
+def ref_gdn_fwd_h(
+    k: torch.Tensor, w: torch.Tensor, u: torch.Tensor, g: torch.Tensor, chunk_size: int
+) -> torch.Tensor:
+    """
+    Arguments:
+        k: (batch, seqlen, nheads, dhead)
+        w: (batch, seqlen, nheads, dhead)
+        u: (batch, seqlen, nheads, expand_v*dhead)
+        g: (batch, seqlen, nheads)
+        chunk_size: int
+    Returns:
+        h: (batch, nchunks, nheads, dhead, expand_v*dhead)
+    """
+
+    batch, seqlen, nheads, dhead = k.shape
+    expand_v = u.shape[-1] // dhead
+    nchunks = (seqlen + chunk_size - 1) // chunk_size
+
+    acc_dtype = torch.float32
+    dtype = k.dtype
+
+    h = torch.empty(
+        batch, nchunks, nheads, dhead, expand_v * dhead, dtype=k.dtype, device=k.device
+    )
+    b_h = torch.zeros(
+        batch, nheads, dhead, expand_v * dhead, dtype=acc_dtype, device=k.device
+    )
+
+    k_c = k.reshape(batch, nchunks, chunk_size, nheads, dhead)
+    w_c = w.reshape(batch, nchunks, chunk_size, nheads, dhead)
+    u_c = u.reshape(batch, nchunks, chunk_size, nheads, expand_v * dhead)
+    g_c = g.reshape(batch, nchunks, chunk_size, nheads)
+    for i_t in range(nchunks):
+        h[:, i_t, :, :, :] = b_h.to(dtype)
+        b_w = w_c[:, i_t, :, :, :].to(acc_dtype)
+        c_h = b_h.to(dtype).to(acc_dtype)
+        b_v = torch.einsum("bchk,bhkv->bchv", b_w, c_h)
+        p_v = u_c[:, i_t, :, :, :].to(acc_dtype)
+        b_v = p_v - b_v
+        last_idx = min((i_t + 1) * chunk_size, seqlen) - 1
+        m_t = (i_t * chunk_size + torch.arange(0, chunk_size, device=k.device)) < seqlen
+        b_g_last = g[:, last_idx, :].to(acc_dtype)
+        b_g = g_c[:, i_t, :, :].to(acc_dtype)  # batch, chunk, nheads
+        b_v *= torch.where(
+            m_t.unsqueeze(0).unsqueeze(-1), torch.exp(b_g_last.unsqueeze(1) - b_g), 0
+        ).unsqueeze(-1)
+        b_g_last = torch.exp(b_g_last)
+        b_h *= b_g_last.unsqueeze(-1).unsqueeze(-1)
+        b_v = b_v.to(dtype).to(acc_dtype)
+        p_k = k_c[:, i_t, :, :, :].to(acc_dtype)
+        b_h += torch.einsum("bchk,bchv->bhkv", p_k, b_v)
+    return h
+
+
+# %%
+# Testing Function
+# ----------------
+def test(
+    batch: int,
+    nheads: int,
+    seqlen: int,
+    chunk_size: int,
+    dhead: int,
+    dstate: int,
+    dtype: torch.dtype = torch.float16,
+) -> None:
+    k = torch.randn(batch, seqlen, nheads, dhead, dtype=torch.bfloat16, device=DEVICE)
+    k = torch.nn.functional.rms_norm(k, (dhead,))
+    w = torch.randn(
+        batch,
+        seqlen // chunk_size,
+        chunk_size,
+        nheads,
+        dhead,
+        dtype=torch.float32,
+        device=DEVICE,
+    )
+    # w = torch.nn.functional.rms_norm(w.to(torch.bfloat16), (dhead,))
+    wu, ws, wv = torch.linalg.svd(w.permute(0, 1, 3, 2, 4), full_matrices=False)
+    w = torch.einsum("bnhik,bnhkj->bnhij", wu, wv)
+    w = (
+        w.permute(0, 1, 3, 2, 4)
+        .reshape(batch, seqlen, nheads, dhead)
+        .to(torch.bfloat16)
+    )
+    u = torch.randn(batch, seqlen, nheads, dstate, dtype=torch.bfloat16, device=DEVICE)
+    u = torch.nn.functional.rms_norm(u, (dstate,))
+    g = torch.cumsum(
+        0.5
+        * math.log(1 / dhead)
+        * torch.rand(batch, seqlen, nheads, dtype=torch.float32, device=DEVICE),
+        dim=1,
+    )
+    args = (k, w, u, g, chunk_size)
+    run_example(helion_gdn_fwd_h, ref_gdn_fwd_h, args)
+
+
+# %%
+# Main Function
+# -------------
+def main() -> None:
+    """
+    Main entry point that runs the gdn_fwd_h kernel test with specific parameters.
+    """
+    test(8, 80, 4096, 256, 64, 128)
+
+
+if __name__ == "__main__":
+    main()
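As a quick, hedged illustration of how the new example can be exercised outside the tritonbench harness, the sketch below compares helion_gdn_fwd_h against ref_gdn_fwd_h on small shapes. It is not part of the patch: it assumes a CUDA-capable device, assumes the repository root is on PYTHONPATH so the new file imports as examples.gdn_fwd_h, and the shape values are arbitrary small numbers rather than the benchmark configuration used in main().

# Illustrative smoke test (a sketch, not part of the patch above).
import torch

from examples.gdn_fwd_h import helion_gdn_fwd_h, ref_gdn_fwd_h  # assumes repo root on PYTHONPATH

batch, seqlen, nheads, dhead, dstate, chunk_size = 2, 512, 4, 64, 128, 256
device = "cuda"  # assumes a CUDA-capable GPU

# Small random inputs; rms_norm keeps magnitudes tame, mirroring the test() setup.
k = torch.randn(batch, seqlen, nheads, dhead, dtype=torch.bfloat16, device=device)
w = torch.randn(batch, seqlen, nheads, dhead, dtype=torch.bfloat16, device=device)
u = torch.randn(batch, seqlen, nheads, dstate, dtype=torch.bfloat16, device=device)
k = torch.nn.functional.rms_norm(k, (dhead,))
w = torch.nn.functional.rms_norm(w, (dhead,))
u = torch.nn.functional.rms_norm(u, (dstate,))
# Cumulative log-gates: negative and decreasing along the sequence dimension.
g = torch.cumsum(
    -torch.rand(batch, seqlen, nheads, dtype=torch.float32, device=device), dim=1
)

h_helion = helion_gdn_fwd_h(k, w, u, g, chunk_size)  # first call may trigger autotuning
h_ref = ref_gdn_fwd_h(k, w, u, g, chunk_size)
print("max abs diff:", (h_helion.float() - h_ref.float()).abs().max().item())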