18 changes: 18 additions & 0 deletions benchmarks/run.py
@@ -534,6 +534,15 @@ class RunResult:
"helion_addmm_tritonbench-speedup": "helion_speedup",
"helion_addmm_tritonbench-accuracy": "helion_accuracy",
},
"addmm-bwd": {
"aten_addmm": "baseline",
"triton_addmm-speedup": "triton_speedup",
"triton_addmm-accuracy": "triton_accuracy",
"pt2_addmm_maxautotune-speedup": "torch_compile_speedup",
"pt2_addmm_maxautotune-accuracy": "torch_compile_accuracy",
"helion_addmm_tritonbench-speedup": "helion_speedup",
"helion_addmm_tritonbench-accuracy": "helion_accuracy",
},
# "ragged_attention": {
# "triton_ragged_attention-speedup": "triton_speedup",
# "triton_ragged_attention-accuracy": "triton_accuracy",
@@ -603,6 +612,15 @@ class RunResult:
"helion_matmul_tritonbench-speedup": "helion_speedup",
"helion_matmul_tritonbench-accuracy": "helion_accuracy",
},
"gemm-bwd": {
"aten_matmul": "baseline",
"triton_tutorial_matmul-speedup": "triton_speedup",
"triton_tutorial_matmul-accuracy": "triton_accuracy",
"pt2_triton_matmul-speedup": "torch_compile_speedup",
"pt2_triton_matmul-accuracy": "torch_compile_accuracy",
"helion_matmul_tritonbench-speedup": "helion_speedup",
"helion_matmul_tritonbench-accuracy": "helion_accuracy",
},
"fp8_gemm": {
"torch_fp8_gemm": "baseline",
f"{'blackwell_persistent_tma' if IS_B200 else 'triton_tma_persistent'}_fp8_gemm-speedup": "triton_speedup",
188 changes: 110 additions & 78 deletions examples/matmul.py
@@ -64,6 +64,116 @@ def matmul(
return out


# %%
class MatMulFunction(torch.autograd.Function):
@staticmethod
def forward(
ctx: Any, # noqa: ANN401
mat1: Tensor,
mat2: Tensor,
) -> Tensor:
"""Forward pass for matrix multiplication."""
result = matmul(mat1, mat2)
ctx.save_for_backward(mat1, mat2)
return result

@staticmethod
def backward(
ctx: Any, # noqa: ANN401
*grad_outputs: Tensor,
) -> tuple[Tensor | None, Tensor | None]:
"""
Backward pass for matrix multiplication.

For C = A @ B, given grad_C:
- grad_A = grad_C @ B.T
- grad_B = A.T @ grad_C

We reuse the forward matmul kernel for both computations.
"""
grad_out = grad_outputs[0]
mat1, mat2 = ctx.saved_tensors

# grad_mat1 = grad_out @ mat2.T
grad_mat1 = matmul(grad_out, mat2.T)

# grad_mat2 = mat1.T @ grad_out
grad_mat2 = matmul(mat1.T, grad_out)
Comment on lines +97 to +101
Contributor:
You only need to compute these if requires_grad is set on the inputs.

return grad_mat1, grad_mat2
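
A minimal sketch of what the reviewer's suggestion could look like, guarding each gradient with ctx.needs_input_grad (indices follow the forward argument order). It assumes the surrounding file's imports (Any, Tensor) and the matmul kernel defined above; it is an illustration, not part of the diff.

    # Sketch only (not part of the diff): skip a gradient kernel when the
    # corresponding input does not require grad, per the review comment.
    @staticmethod
    def backward(
        ctx: Any,  # noqa: ANN401
        *grad_outputs: Tensor,
    ) -> tuple[Tensor | None, Tensor | None]:
        grad_out = grad_outputs[0]
        mat1, mat2 = ctx.saved_tensors
        grad_mat1 = grad_mat2 = None
        if ctx.needs_input_grad[0]:
            # grad_mat1 = grad_out @ mat2.T
            grad_mat1 = matmul(grad_out, mat2.T)
        if ctx.needs_input_grad[1]:
            # grad_mat2 = mat1.T @ grad_out
            grad_mat2 = matmul(mat1.T, grad_out)
        return grad_mat1, grad_mat2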


def matmul_autograd(mat1: Tensor, mat2: Tensor) -> Tensor:
"""Matrix multiplication with forward + backward support."""
return MatMulFunction.apply(mat1, mat2) # type: ignore[no-any-return]


class AddMMFunction(torch.autograd.Function):
@staticmethod
def forward(
ctx: Any, # noqa: ANN401
bias: Tensor,
mat1: Tensor,
mat2: Tensor,
alpha: float = 1.0,
beta: float = 1.0,
) -> Tensor:
"""Forward pass for addmm operation using helion matmul with epilogue."""
m, k = mat1.size()
k2, n = mat2.size()
input_broadcasted = torch.broadcast_to(bias, [m, n])

# Define epilogue that adds bias: alpha * acc + beta * bias
def addmm_epilogue(acc: Tensor, tile: tuple[Tensor, ...]) -> Tensor:
return alpha * acc + beta * input_broadcasted[tile[0], tile[1]]

result = matmul(mat1, mat2, addmm_epilogue)
ctx.save_for_backward(bias, mat1, mat2)
ctx.alpha = alpha
ctx.beta = beta
return result

@staticmethod
def backward(
ctx: Any, # noqa: ANN401
*grad_outputs: Tensor,
) -> tuple[Tensor | None, Tensor | None, Tensor | None, None, None]:
"""
Backward pass for addmm operation.

Forward: output = beta * bias + alpha * (mat1 @ mat2)

Given grad_out:
- grad_bias = beta * grad_out
- grad_mat1 = alpha * (grad_out @ mat2.T)
- grad_mat2 = alpha * (mat1.T @ grad_out)

We reuse the forward matmul kernel for both matrix gradient computations.
"""
grad_out = grad_outputs[0]
bias, mat1, mat2 = ctx.saved_tensors
alpha = ctx.alpha
beta = ctx.beta

# grad_bias = beta * grad_out
grad_bias = beta * grad_out

# grad_mat1 = alpha * (grad_out @ mat2.T)
grad_mat1 = alpha * matmul(grad_out, mat2.T)

# grad_mat2 = alpha * (mat1.T @ grad_out)
grad_mat2 = alpha * matmul(mat1.T, grad_out)
Comment on lines +158 to +165
Contributor:
This results in extra kernels; you should define an epilogue function to put the scaling into the matmul kernel.
Also same issue as above.

return grad_bias, grad_mat1, grad_mat2, None, None
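
A rough sketch of what the reviewer's two suggestions could look like for this backward: fold the alpha scaling into the matmul kernel through an epilogue and skip gradients that are not required. The (acc, tile) epilogue contract is assumed to match addmm_epilogue in the forward pass of this diff; the snippet is illustrative, not part of the change.

    # Sketch only (not part of the diff): scale inside the matmul epilogue and
    # skip gradients that are not required, per the review comments.
    def scale_epilogue(acc: Tensor, tile: tuple[Tensor, ...]) -> Tensor:
        # Assumes the same (acc, tile) epilogue contract used by addmm_epilogue.
        return alpha * acc

    grad_bias = grad_mat1 = grad_mat2 = None
    if ctx.needs_input_grad[0]:
        # grad_bias = beta * grad_out (elementwise; stays outside the matmul kernels)
        grad_bias = beta * grad_out
    if ctx.needs_input_grad[1]:
        # grad_mat1 = alpha * (grad_out @ mat2.T), with the scale fused into the kernel
        grad_mat1 = matmul(grad_out, mat2.T, scale_epilogue)
    if ctx.needs_input_grad[2]:
        # grad_mat2 = alpha * (mat1.T @ grad_out), with the scale fused into the kernel
        grad_mat2 = matmul(mat1.T, grad_out, scale_epilogue)
    return grad_bias, grad_mat1, grad_mat2, None, None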


def addmm_autograd(
bias: Tensor, mat1: Tensor, mat2: Tensor, alpha: float = 1.0, beta: float = 1.0
) -> Tensor:
"""AddMM operation with forward + backward support."""
return AddMMFunction.apply(bias, mat1, mat2, alpha, beta) # type: ignore[no-any-return]
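
For context, a short usage sketch of the two wrappers added in this diff. It assumes a CUDA device (the Helion kernels run on GPU) and a two-dimensional bias so the returned gradient shapes match; it is not part of the change, and tolerances may need adjusting.

    # Sketch only (not part of the diff): exercise the autograd wrappers.
    import torch

    m, k, n = 64, 32, 48
    mat1 = torch.randn(m, k, device="cuda", requires_grad=True)
    mat2 = torch.randn(k, n, device="cuda", requires_grad=True)
    bias = torch.randn(m, n, device="cuda", requires_grad=True)

    out = addmm_autograd(bias, mat1, mat2, alpha=2.0, beta=0.5)
    out.sum().backward()  # gradients flow through the Helion matmul kernel

    # Forward result should match eager PyTorch.
    ref = torch.addmm(bias, mat1, mat2, beta=0.5, alpha=2.0)
    torch.testing.assert_close(out, ref)

    out2 = matmul_autograd(mat1, mat2)
    out2.sum().backward()  # grads accumulate into mat1.grad / mat2.grad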


@helion.kernel
def matmul_bwd(
grad_out: Tensor, # [m, n] gradient w.r.t output
@@ -188,84 +298,6 @@ def addmm_bwd(
return grad_input, grad_mat1, grad_mat2


# %%
class MatMulFunction(torch.autograd.Function):
@staticmethod
def forward(
ctx: Any, # noqa: ANN401
mat1: Tensor,
mat2: Tensor,
) -> Tensor:
"""Forward pass for matrix multiplication."""
result = matmul(mat1, mat2)
ctx.save_for_backward(mat1, mat2)
return result

@staticmethod
def backward(
ctx: Any, # noqa: ANN401
*grad_outputs: Tensor,
) -> tuple[Tensor | None, Tensor | None]:
"""Backward pass for matrix multiplication."""
grad_out = grad_outputs[0]
mat1, mat2 = ctx.saved_tensors
grad_mat1, grad_mat2 = matmul_bwd(grad_out, mat1, mat2)
return grad_mat1, grad_mat2


def matmul_autograd(mat1: Tensor, mat2: Tensor) -> Tensor:
"""Matrix multiplication with forward + backward support."""
return MatMulFunction.apply(mat1, mat2) # type: ignore[no-any-return]


class AddMMFunction(torch.autograd.Function):
@staticmethod
def forward(
ctx: Any, # noqa: ANN401
bias: Tensor,
mat1: Tensor,
mat2: Tensor,
alpha: float = 1.0,
beta: float = 1.0,
) -> Tensor:
"""Forward pass for addmm operation using helion matmul with epilogue."""
m, k = mat1.size()
k2, n = mat2.size()
input_broadcasted = torch.broadcast_to(bias, [m, n])

# Define epilogue that adds bias: alpha * acc + beta * bias
def addmm_epilogue(acc: Tensor, tile: tuple[Tensor, ...]) -> Tensor:
return alpha * acc + beta * input_broadcasted[tile[0], tile[1]]

result = matmul(mat1, mat2, addmm_epilogue)
ctx.save_for_backward(bias, mat1, mat2)
ctx.alpha = alpha
ctx.beta = beta
return result

@staticmethod
def backward(
ctx: Any, # noqa: ANN401
*grad_outputs: Tensor,
) -> tuple[Tensor | None, Tensor | None, Tensor | None, None, None]:
"""Backward pass for addmm operation."""
grad_out = grad_outputs[0]
bias, mat1, mat2 = ctx.saved_tensors
alpha = ctx.alpha
beta = ctx.beta
grad_input, grad_mat1, grad_mat2 = addmm_bwd(
grad_out, bias, mat1, mat2, alpha, beta
)
return grad_input, grad_mat1, grad_mat2, None, None


def addmm_autograd(
bias: Tensor, mat1: Tensor, mat2: Tensor, alpha: float = 1.0, beta: float = 1.0
) -> Tensor:
"""AddMM operation with forward + backward support."""
return AddMMFunction.apply(bias, mat1, mat2, alpha, beta) # type: ignore[no-any-return]


# %%
def autotune(m: int, k: int, n: int) -> None:
"""