From 83a8c2d88d52f3d366ccfd114f4a24a1b43af20f Mon Sep 17 00:00:00 2001
From: Tuan Trieu
Date: Tue, 18 Nov 2025 23:05:50 -0800
Subject: [PATCH] jagged_dense_bmm (#1126)

Summary: Add an example of jagged dense bmm.

Differential Revision: D84082652
---
 examples/jagged_dense_bmm.py | 148 +++++++++++++++++++++++++++++++++++
 test/test_examples.py        |  14 ++++
 2 files changed, 162 insertions(+)
 create mode 100644 examples/jagged_dense_bmm.py

diff --git a/examples/jagged_dense_bmm.py b/examples/jagged_dense_bmm.py
new file mode 100644
index 000000000..5d37f447a
--- /dev/null
+++ b/examples/jagged_dense_bmm.py
@@ -0,0 +1,148 @@
+from __future__ import annotations
+
+from typing import Optional, Tuple
+
+import helion
+import helion.language as hl
+
+import torch
+from helion._testing import run_example
+
+"""
+---jagged_dense_bmm---
+seq_offsets : [B + 1]    # B is batch size
+jagged      : [L, D]     # L is sum of sequence lengths, D is embedding dimension
+dense       : [B, D, K]  # K is output dimension
+bias        : [B, K]     # optional bias
+"""
+
+
+@helion.kernel()
+def jagged_dense_bmm(
+    seq_offsets: torch.Tensor,
+    jagged: torch.Tensor,
+    dense: torch.Tensor,
+    bias: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    L, D = jagged.shape
+    B, D, K = dense.shape
+    dtype = torch.promote_types(jagged.dtype, dense.dtype)
+    device = jagged.device
+
+    jagged = jagged.view(-1)  # flatten to [L * D]; element (i, d) is at index i * D + d
+    # Allocate the output tensor and flatten it to 1-D as well
+    output = torch.empty((L, K), dtype=dtype, device=device).view(-1)
+    for tile_b in hl.tile(B):
+        starts = seq_offsets[tile_b]
+        ends = seq_offsets[tile_b.index + 1]
+        seq_len = ends - starts
+        max_seq_len = seq_len.amax()
+
+        for tile_len in hl.tile(0, max_seq_len):
+            mask = tile_len.index[None, :] < seq_len[:, None]
+            jagged_indices = starts[:, None] + tile_len.index[None, :]
+
+            for tile_k in hl.tile(0, K):
+                acc = hl.zeros([tile_b, tile_len, tile_k], dtype=dtype, device=device)
+                for tile_d in hl.tile(0, D):
+                    jagged_data = hl.load(
+                        jagged,
+                        [jagged_indices[:, :, None] * D + tile_d.index[None, None, :]],
+                        extra_mask=mask[:, :, None] & (tile_d.index < D)[None, None, :],
+                    )  # [tile_b, tile_len, tile_d]
+                    dense_data = dense[tile_b, tile_d, tile_k]
+
+                    acc = acc + torch.matmul(
+                        jagged_data, dense_data
+                    )  # [tile_b, tile_len, tile_k]
+
+                if bias is not None:
+                    bias_data = bias[tile_b, tile_k]  # [tile_b, tile_k]
+                    # [tile_b, tile_len, tile_k] + [tile_b, 1, tile_k] -> [tile_b, tile_len, tile_k]
+                    acc = acc + bias_data.unsqueeze(1)
+
+                hl.store(
+                    output,
+                    [jagged_indices[:, :, None] * K + tile_k.index[None, None, :]],
+                    acc,
+                    extra_mask=mask[:, :, None],
+                )
+    return output.reshape(L, K)
+
+
+def jagged_dense_bmm_reference(
+    seq_offsets: torch.Tensor,
+    jagged: torch.Tensor,
+    dense: torch.Tensor,
+    bias: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    L, D = jagged.shape
+    B, _, K = dense.shape
+
+    # Allocate output tensor
+    ref_output = torch.empty((L, K), dtype=jagged.dtype, device=jagged.device)
+
+    # Process each example in the batch
+    for i in range(B):
+        seq_start = seq_offsets[i].item()
+        seq_end = seq_offsets[i + 1].item()
+
+        if seq_start < seq_end:  # Non-empty sequence
+            seq_data = jagged[seq_start:seq_end]  # [seq_len, D]
+
+            # Matrix multiplication: [seq_len, D] @ [D, K] -> [seq_len, K]
+            result = torch.matmul(seq_data, dense[i])
+
+            # Add bias if provided
+            if bias is not None:
+                result = result + bias[i].unsqueeze(0)
+
+            # Store result
+            ref_output[seq_start:seq_end] = result
+    return ref_output
+
+
+def random_input(
+    D: int = 4,
+    K: int = 5,
+    batch_size: int = 3,
+    max_seq_len: int = 3,
+    dtype: torch.dtype = torch.float32,
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+    lengths = torch.randint(
+        max_seq_len + 1, size=(batch_size,), device=torch.device("cuda")
+    )
+    seq_offsets = torch.zeros(
+        (batch_size + 1,), dtype=torch.int64, device=torch.device("cuda")
+    )
+    seq_offsets[1:] = torch.cumsum(lengths, dim=0)
+    jagged_size = int(seq_offsets[-1].item())
+    jagged = (
+        torch.empty((jagged_size, D), dtype=dtype, device=torch.device("cuda"))
+        .uniform_(-1.0, 1.0)
+        .requires_grad_()
+    )
+    dense = (
+        torch.empty((batch_size, D, K), dtype=dtype, device=torch.device("cuda"))
+        .uniform_(-1.0, 1.0)
+        .requires_grad_()
+    )
+    bias = (
+        torch.empty((batch_size, K), dtype=dtype, device=torch.device("cuda"))
+        .uniform_(-1.0, 1.0)
+        .requires_grad_()
+    )
+    return seq_offsets, jagged, dense, bias
+
+
+def main() -> None:
+    seq_offsets, jagged, dense, bias = random_input(
+        D=34, K=24, batch_size=23, max_seq_len=37, dtype=torch.float32
+    )
+    run_example(
+        jagged_dense_bmm, jagged_dense_bmm_reference, (seq_offsets, jagged, dense, bias)
+    )
+
+
+if __name__ == "__main__":
+    main()
diff --git a/test/test_examples.py b/test/test_examples.py
index 31a796a5f..57eb2bbf4 100644
--- a/test/test_examples.py
+++ b/test/test_examples.py
@@ -701,6 +701,20 @@ def test_jagged_dense_add(self):
             )
         )
 
+    def test_jagged_dense_bmm(self):
+        mod = import_path(EXAMPLES_DIR / "jagged_dense_bmm.py")
+        seq_offsets, jagged, dense, bias = mod.random_input(
+            D=32, K=24, batch_size=16, max_seq_len=32, dtype=torch.float32
+        )
+        args = (seq_offsets, jagged, dense, bias)
+        self.assertExpectedJournal(
+            check_example(
+                "jagged_dense_bmm",
+                args,
+                mod.jagged_dense_bmm_reference(*args),
+            )
+        )
+
     @skipIfRefEager("Test has skip_accuracy=True and doesn't call assert_close")
     def test_moe_matmul_ogs(self):
         mod = import_path(EXAMPLES_DIR / "moe_matmul_ogs.py")
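
Note on the flattened-index addressing used in the kernel above: after
jagged.view(-1), element (i, d) of the [L, D] jagged tensor sits at flat index
i * D + d, so hl.load gathers with jagged_indices[:, :, None] * D +
tile_d.index[None, None, :] and hl.store scatters with the analogous * K
expression, while extra_mask keeps lanes past each sequence's length inert.
The snippet below is a minimal standalone sketch of that arithmetic in plain
PyTorch (no Helion; the helper name gather_jagged_block is hypothetical and
not part of the patch):

import torch


def gather_jagged_block(
    jagged_flat: torch.Tensor,  # [L * D], the jagged values after .view(-1)
    row_idx: torch.Tensor,  # [b, n] row indices into the original [L, D] tensor
    col_idx: torch.Tensor,  # [d] column indices
    D: int,
) -> torch.Tensor:
    # Same address arithmetic as the kernel: element (row, col) of the
    # [L, D] tensor lives at flat index row * D + col.
    flat_idx = row_idx[:, :, None] * D + col_idx[None, None, :]
    return jagged_flat[flat_idx]  # [b, n, d]


# Self-check against ordinary 2-D indexing.
L, D = 7, 4
jagged = torch.arange(L * D, dtype=torch.float32).reshape(L, D)
rows = torch.tensor([[0, 1], [3, 5]])  # two "sequences", two rows each
cols = torch.tensor([1, 3])
block = gather_jagged_block(jagged.view(-1), rows, cols, D)
assert torch.equal(block, jagged[rows][:, :, cols])

Note that random_input pins every tensor to cuda, so running the example's
main() and the new test_jagged_dense_bmm both require a GPU.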