Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
148 changes: 148 additions & 0 deletions examples/jagged_dense_bmm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
from __future__ import annotations

from typing import Optional, Tuple

import helion
import helion.language as hl

import torch
from helion._testing import run_example

"""
---jagged_dense_bmm---
seq_offsets : [B + 1] # B is batch size
jagged : [L, D] # L is sum of sequence lengths, D is embedding dimension
dense : [B, D, K] # K is output dimension
bias : [B, K] # optional bias
"""


@helion.kernel()
def jagged_dense_bmm(
    seq_offsets: torch.Tensor,
    jagged: torch.Tensor,
    dense: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Batched matmul of a jagged tensor against per-batch dense matrices.

    For each batch b, multiplies the rows
    jagged[seq_offsets[b]:seq_offsets[b+1]] (shape [seq_len_b, D]) by
    dense[b] (shape [D, K]), optionally adding bias[b], and writes the
    product into the matching rows of the [L, K] output.

    Args:
        seq_offsets: [B + 1] cumulative row offsets; L == seq_offsets[-1].
        jagged: [L, D] concatenated sequences.
        dense: [B, D, K] per-batch weight matrices.
        bias: optional [B, K] per-batch bias.

    Returns:
        [L, K] tensor in the promoted dtype of jagged and dense.
    """
    L, D = jagged.shape
    B, D, K = dense.shape
    # Accumulate in the promoted dtype so mixed-precision inputs are safe.
    dtype = torch.promote_types(jagged.dtype, dense.dtype)
    device = jagged.device

    # Flatten so rows can be gathered/scattered with computed 1-D indices.
    jagged = jagged.view(-1)  # flattening to [L * D]
    # Allocate output tensor and flatten to 1D
    output = torch.empty((L, K), dtype=dtype, device=device).view(-1)
    for tile_b in hl.tile(B):
        # Per-sequence start/end offsets for the batches in this tile.
        starts = seq_offsets[tile_b]
        ends = seq_offsets[tile_b.index + 1]
        seq_len = ends - starts
        # Iterate to the longest sequence in the tile; shorter ones are masked.
        max_seq_len = seq_len.amax()

        for tile_len in hl.tile(0, max_seq_len):
            # mask[b, t] is True while position t is inside sequence b.
            mask = tile_len.index[None, :] < seq_len[:, None]
            # Row index into the original [L, D] layout for each (b, t).
            jagged_indices = starts[:, None] + tile_len.index[None, :]

            for tile_k in hl.tile(0, K):
                acc = hl.zeros([tile_b, tile_len, tile_k], dtype=dtype, device=device)
                for tile_d in hl.tile(0, D):
                    # Gather jagged values through flattened row*D + d indices;
                    # the extra mask zeroes out-of-sequence and out-of-D lanes.
                    jagged_data = hl.load(
                        jagged,
                        [jagged_indices[:, :, None] * D + tile_d.index[None, None, :]],
                        extra_mask=mask[:, :, None] & (tile_d.index < D)[None, None, :],
                    )  # [tile_b, tile_len, tile_d]
                    dense_data = dense[tile_b, tile_d, tile_k]

                    acc = acc + torch.matmul(
                        jagged_data, dense_data
                    )  # [tile_b, tile_len, tile_k]

                if bias is not None:
                    bias_data = bias[tile_b, tile_k]  # [tile_b, tile_k]
                    # [tile_b, tile_len, tile_k] + [tile_b, 1, tile_k] -> [tile_b, tile_len, tile_k]
                    acc = acc + bias_data.unsqueeze(1)

                # Scatter into the flattened [L * K] output; masked lanes
                # (past the end of a sequence) are not written.
                hl.store(
                    output,
                    [jagged_indices[:, :, None] * K + tile_k.index[None, None, :]],
                    acc,
                    extra_mask=mask[:, :, None],
                )
    return output.reshape(L, K)


def jagged_dense_bmm_reference(
    seq_offsets: torch.Tensor,
    jagged: torch.Tensor,
    dense: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Eager PyTorch reference for jagged_dense_bmm.

    For every batch b, computes jagged[start:end] @ dense[b] (plus bias[b]
    if given) and writes it into the matching rows of a [L, K] output.

    Args:
        seq_offsets: [B + 1] cumulative row offsets; L == seq_offsets[-1].
        jagged: [L, D] concatenated sequences.
        dense: [B, D, K] per-batch weight matrices.
        bias: optional [B, K] per-batch bias.

    Returns:
        [L, K] tensor in jagged's dtype on jagged's device.
    """
    total_rows, _ = jagged.shape
    _, _, out_dim = dense.shape

    ref_output = torch.empty(
        (total_rows, out_dim), dtype=jagged.dtype, device=jagged.device
    )

    bounds = seq_offsets.tolist()
    for batch, (start, end) in enumerate(zip(bounds[:-1], bounds[1:])):
        if start >= end:
            # Empty sequence: contributes no output rows.
            continue
        # [seq_len, D] @ [D, K] -> [seq_len, K]
        block = torch.matmul(jagged[start:end], dense[batch])
        if bias is not None:
            # [seq_len, K] + [K] broadcasts over rows.
            block = block + bias[batch]
        ref_output[start:end] = block
    return ref_output


def random_input(
    D: int = 4,
    K: int = 5,
    batch_size: int = 3,
    max_seq_len: int = 3,
    dtype: torch.dtype = torch.float32,
    device: torch.device | str = "cuda",
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Generate random inputs for jagged_dense_bmm.

    Sequence lengths are drawn uniformly from [0, max_seq_len], so empty
    sequences are possible and exercised.

    Args:
        D: embedding dimension of the jagged tensor.
        K: output dimension of the dense matrices.
        batch_size: number of sequences B.
        max_seq_len: maximum (inclusive) length of each sequence.
        dtype: floating dtype of jagged/dense/bias.
        device: target device; defaults to "cuda" as before, but is now
            configurable so the helper also works on CPU-only machines.

    Returns:
        (seq_offsets [B + 1], jagged [L, D], dense [B, D, K], bias [B, K]);
        jagged, dense, and bias require grad.
    """
    device = torch.device(device)
    # Lengths in [0, max_seq_len] inclusive (randint's high bound is exclusive).
    lengths = torch.randint(max_seq_len + 1, size=(batch_size,), device=device)
    seq_offsets = torch.zeros((batch_size + 1,), dtype=torch.int64, device=device)
    # Cumulative sum turns per-sequence lengths into [0, l0, l0+l1, ...] offsets.
    seq_offsets[1:] = torch.cumsum(lengths, dim=0)
    jagged_size = int(seq_offsets[-1].item())
    jagged = (
        torch.empty((jagged_size, D), dtype=dtype, device=device)
        .uniform_(-1.0, 1.0)
        .requires_grad_()
    )
    dense = (
        torch.empty((batch_size, D, K), dtype=dtype, device=device)
        .uniform_(-1.0, 1.0)
        .requires_grad_()
    )
    bias = (
        torch.empty((batch_size, K), dtype=dtype, device=device)
        .uniform_(-1.0, 1.0)
        .requires_grad_()
    )
    return seq_offsets, jagged, dense, bias


def main() -> None:
    """Run the Helion kernel against the eager reference on random inputs."""
    inputs = random_input(D=34, K=24, batch_size=23, max_seq_len=37, dtype=torch.float32)
    seq_offsets, jagged, dense, bias = inputs
    run_example(
        jagged_dense_bmm,
        jagged_dense_bmm_reference,
        (seq_offsets, jagged, dense, bias),
    )


# Allow running the example directly: `python jagged_dense_bmm.py`.
if __name__ == "__main__":
    main()
Loading