-
Notifications
You must be signed in to change notification settings - Fork 68
Use matmul fwd direclty in autograd for performance #1045
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
tianrengao
wants to merge
5
commits into
main
Choose a base branch
from
tianren/addmm_bwd_fix_impl
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
5 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -64,6 +64,116 @@ def matmul( | |
| return out | ||
|
|
||
|
|
||
| # %% | ||
| class MatMulFunction(torch.autograd.Function): | ||
| @staticmethod | ||
| def forward( | ||
| ctx: Any, # noqa: ANN401 | ||
| mat1: Tensor, | ||
| mat2: Tensor, | ||
| ) -> Tensor: | ||
| """Forward pass for matrix multiplication.""" | ||
| result = matmul(mat1, mat2) | ||
| ctx.save_for_backward(mat1, mat2) | ||
| return result | ||
|
|
||
| @staticmethod | ||
| def backward( | ||
| ctx: Any, # noqa: ANN401 | ||
| *grad_outputs: Tensor, | ||
| ) -> tuple[Tensor | None, Tensor | None]: | ||
| """ | ||
| Backward pass for matrix multiplication. | ||
|
|
||
| For C = A @ B, given grad_C: | ||
| - grad_A = grad_C @ B.T | ||
| - grad_B = A.T @ grad_C | ||
|
|
||
| We reuse the forward matmul kernel for both computations. | ||
| """ | ||
| grad_out = grad_outputs[0] | ||
| mat1, mat2 = ctx.saved_tensors | ||
|
|
||
| # grad_mat1 = grad_out @ mat2.T | ||
| grad_mat1 = matmul(grad_out, mat2.T) | ||
|
|
||
| # grad_mat2 = mat1.T @ grad_out | ||
| grad_mat2 = matmul(mat1.T, grad_out) | ||
|
|
||
| return grad_mat1, grad_mat2 | ||
|
|
||
|
|
||
| def matmul_autograd(mat1: Tensor, mat2: Tensor) -> Tensor: | ||
| """Matrix multiplication with forward + backward support.""" | ||
| return MatMulFunction.apply(mat1, mat2) # type: ignore[no-any-return] | ||
|
|
||
|
|
||
| class AddMMFunction(torch.autograd.Function): | ||
| @staticmethod | ||
| def forward( | ||
| ctx: Any, # noqa: ANN401 | ||
| bias: Tensor, | ||
| mat1: Tensor, | ||
| mat2: Tensor, | ||
| alpha: float = 1.0, | ||
| beta: float = 1.0, | ||
| ) -> Tensor: | ||
| """Forward pass for addmm operation using helion matmul with epilogue.""" | ||
| m, k = mat1.size() | ||
| k2, n = mat2.size() | ||
| input_broadcasted = torch.broadcast_to(bias, [m, n]) | ||
|
|
||
| # Define epilogue that adds bias: alpha * acc + beta * bias | ||
| def addmm_epilogue(acc: Tensor, tile: tuple[Tensor, ...]) -> Tensor: | ||
| return alpha * acc + beta * input_broadcasted[tile[0], tile[1]] | ||
|
|
||
| result = matmul(mat1, mat2, addmm_epilogue) | ||
| ctx.save_for_backward(bias, mat1, mat2) | ||
| ctx.alpha = alpha | ||
| ctx.beta = beta | ||
| return result | ||
|
|
||
| @staticmethod | ||
| def backward( | ||
| ctx: Any, # noqa: ANN401 | ||
| *grad_outputs: Tensor, | ||
| ) -> tuple[Tensor | None, Tensor | None, Tensor | None, None, None]: | ||
| """ | ||
| Backward pass for addmm operation. | ||
|
|
||
| Forward: output = beta * bias + alpha * (mat1 @ mat2) | ||
|
|
||
| Given grad_out: | ||
| - grad_bias = beta * grad_out | ||
| - grad_mat1 = alpha * (grad_out @ mat2.T) | ||
| - grad_mat2 = alpha * (mat1.T @ grad_out) | ||
|
|
||
| We reuse the forward matmul kernel for both matrix gradient computations. | ||
| """ | ||
| grad_out = grad_outputs[0] | ||
| bias, mat1, mat2 = ctx.saved_tensors | ||
| alpha = ctx.alpha | ||
| beta = ctx.beta | ||
|
|
||
| # grad_bias = beta * grad_out | ||
| grad_bias = beta * grad_out | ||
|
|
||
| # grad_mat1 = alpha * (grad_out @ mat2.T) | ||
| grad_mat1 = alpha * matmul(grad_out, mat2.T) | ||
|
|
||
| # grad_mat2 = alpha * (mat1.T @ grad_out) | ||
| grad_mat2 = alpha * matmul(mat1.T, grad_out) | ||
|
Comment on lines
+158
to
+165
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This results in extra kernels, you should define an epilogue function to put the scaling into the matmul kernel. Also same issue as above. |
||
|
|
||
| return grad_bias, grad_mat1, grad_mat2, None, None | ||
|
|
||
|
|
||
| def addmm_autograd( | ||
| bias: Tensor, mat1: Tensor, mat2: Tensor, alpha: float = 1.0, beta: float = 1.0 | ||
| ) -> Tensor: | ||
| """AddMM operation with forward + backward support.""" | ||
| return AddMMFunction.apply(bias, mat1, mat2, alpha, beta) # type: ignore[no-any-return] | ||
|
|
||
|
|
||
| @helion.kernel | ||
| def matmul_bwd( | ||
| grad_out: Tensor, # [m, n] gradient w.r.t output | ||
|
|
@@ -188,84 +298,6 @@ def addmm_bwd( | |
| return grad_input, grad_mat1, grad_mat2 | ||
|
|
||
|
|
||
| # %% | ||
| class MatMulFunction(torch.autograd.Function): | ||
| @staticmethod | ||
| def forward( | ||
| ctx: Any, # noqa: ANN401 | ||
| mat1: Tensor, | ||
| mat2: Tensor, | ||
| ) -> Tensor: | ||
| """Forward pass for matrix multiplication.""" | ||
| result = matmul(mat1, mat2) | ||
| ctx.save_for_backward(mat1, mat2) | ||
| return result | ||
|
|
||
| @staticmethod | ||
| def backward( | ||
| ctx: Any, # noqa: ANN401 | ||
| *grad_outputs: Tensor, | ||
| ) -> tuple[Tensor | None, Tensor | None]: | ||
| """Backward pass for matrix multiplication.""" | ||
| grad_out = grad_outputs[0] | ||
| mat1, mat2 = ctx.saved_tensors | ||
| grad_mat1, grad_mat2 = matmul_bwd(grad_out, mat1, mat2) | ||
| return grad_mat1, grad_mat2 | ||
|
|
||
|
|
||
| def matmul_autograd(mat1: Tensor, mat2: Tensor) -> Tensor: | ||
| """Matrix multiplication with forward + backward support.""" | ||
| return MatMulFunction.apply(mat1, mat2) # type: ignore[no-any-return] | ||
|
|
||
|
|
||
| class AddMMFunction(torch.autograd.Function): | ||
| @staticmethod | ||
| def forward( | ||
| ctx: Any, # noqa: ANN401 | ||
| bias: Tensor, | ||
| mat1: Tensor, | ||
| mat2: Tensor, | ||
| alpha: float = 1.0, | ||
| beta: float = 1.0, | ||
| ) -> Tensor: | ||
| """Forward pass for addmm operation using helion matmul with epilogue.""" | ||
| m, k = mat1.size() | ||
| k2, n = mat2.size() | ||
| input_broadcasted = torch.broadcast_to(bias, [m, n]) | ||
|
|
||
| # Define epilogue that adds bias: alpha * acc + beta * bias | ||
| def addmm_epilogue(acc: Tensor, tile: tuple[Tensor, ...]) -> Tensor: | ||
| return alpha * acc + beta * input_broadcasted[tile[0], tile[1]] | ||
|
|
||
| result = matmul(mat1, mat2, addmm_epilogue) | ||
| ctx.save_for_backward(bias, mat1, mat2) | ||
| ctx.alpha = alpha | ||
| ctx.beta = beta | ||
| return result | ||
|
|
||
| @staticmethod | ||
| def backward( | ||
| ctx: Any, # noqa: ANN401 | ||
| *grad_outputs: Tensor, | ||
| ) -> tuple[Tensor | None, Tensor | None, Tensor | None, None, None]: | ||
| """Backward pass for addmm operation.""" | ||
| grad_out = grad_outputs[0] | ||
| bias, mat1, mat2 = ctx.saved_tensors | ||
| alpha = ctx.alpha | ||
| beta = ctx.beta | ||
| grad_input, grad_mat1, grad_mat2 = addmm_bwd( | ||
| grad_out, bias, mat1, mat2, alpha, beta | ||
| ) | ||
| return grad_input, grad_mat1, grad_mat2, None, None | ||
|
|
||
|
|
||
| def addmm_autograd( | ||
| bias: Tensor, mat1: Tensor, mat2: Tensor, alpha: float = 1.0, beta: float = 1.0 | ||
| ) -> Tensor: | ||
| """AddMM operation with forward + backward support.""" | ||
| return AddMMFunction.apply(bias, mat1, mat2, alpha, beta) # type: ignore[no-any-return] | ||
|
|
||
|
|
||
| # %% | ||
| def autotune(m: int, k: int, n: int) -> None: | ||
| """ | ||
|
|
||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You only need to compute these if requires_grad is set on the inputs.