@@ -64,130 +64,6 @@ def matmul(
     return out


-@helion.kernel
-def matmul_bwd(
-    grad_out: Tensor,  # [m, n] gradient w.r.t output
-    mat1: Tensor,  # [m, k] first matrix
-    mat2: Tensor,  # [k, n] second matrix
-) -> tuple[Tensor, Tensor]:
-    """
-    Backward pass for matrix multiplication following Triton reference pattern.
-
-    For C = A @ B, given grad_C, computes:
-    - grad_A = grad_C @ B.T
-    - grad_B = A.T @ grad_C
-
-    Args:
-        grad_out: Gradient w.r.t output [m, n]
-        mat1: First matrix [m, k]
-        mat2: Second matrix [k, n]
-
-    Returns:
-        tuple[Tensor, Tensor]: (grad_mat1, grad_mat2)
-    """
-    # Get all dimensions first
-    m, n = grad_out.size()
-    m2, k = mat1.size()
-    k2, n2 = mat2.size()
-
-    # All assertions at the top
-    assert m == m2 and n == n2 and k == k2, "Size mismatch in matmul backward"
-
-    # Declare ALL output tensors at the top before any loops
-    grad_mat1 = torch.empty_like(mat1)
-    grad_mat2 = torch.empty_like(mat2)
-
-    # First loop block: compute grad_mat1 = grad_out @ mat2.T
-    for tile_m1, tile_k1 in hl.tile([m, k]):
-        acc1 = hl.zeros([tile_m1, tile_k1], dtype=torch.float32)
-        for tile_n1 in hl.tile(n):
-            # Need mat2.T: mat2 is [k, n], so mat2[tile_k, tile_n].T gives [tile_n, tile_k]
-            acc1 = torch.addmm(
-                acc1, grad_out[tile_m1, tile_n1], mat2[tile_k1, tile_n1].T
-            )
-        grad_mat1[tile_m1, tile_k1] = acc1.to(mat1.dtype)
-
-    # Second loop block: compute grad_mat2 = mat1.T @ grad_out
-    for tile_k2, tile_n2 in hl.tile([k, n]):
-        acc2 = hl.zeros([tile_k2, tile_n2], dtype=torch.float32)
-        for tile_m2 in hl.tile(m):
-            # Need mat1.T: mat1 is [m, k], so mat1[tile_m, tile_k].T gives [tile_k, tile_m]
-            acc2 = torch.addmm(
-                acc2, mat1[tile_m2, tile_k2].T, grad_out[tile_m2, tile_n2]
-            )
-        grad_mat2[tile_k2, tile_n2] = acc2.to(mat2.dtype)
-
-    return grad_mat1, grad_mat2
-
-
-@helion.kernel
-def addmm_bwd(
-    grad_out: Tensor,  # [m, n] gradient w.r.t output
-    bias: Tensor,  # [m, n] or broadcastable bias tensor
-    mat1: Tensor,  # [m, k] first matrix
-    mat2: Tensor,  # [k, n] second matrix
-    alpha: float = 1.0,  # scalar multiplier for matmul
-    beta: float = 1.0,  # scalar multiplier for bias
-) -> tuple[Tensor, Tensor, Tensor]:
-    """
-    Backward pass for addmm operation following Triton reference pattern.
-
-    Forward: output = beta * bias + alpha * (mat1 @ mat2)
-
-    Based on the Triton kernel analysis:
-    - grad_input = beta * grad_out (with proper reduction for broadcasting)
-    - grad_mat1 = alpha * (grad_out @ mat2.T)
-    - grad_mat2 = alpha * (mat1.T @ grad_out)
-
-    Args:
-        grad_out: Gradient w.r.t output [m, n]
-        bias: Bias tensor [m, n] (or broadcastable)
-        mat1: First matrix [m, k]
-        mat2: Second matrix [k, n]
-        alpha: Scalar multiplier for matmul
-        beta: Scalar multiplier for bias
-
-    Returns:
-        tuple[Tensor, Tensor, Tensor]: (grad_input, grad_mat1, grad_mat2)
-    """
-    # Get all dimensions first
-    m, n = grad_out.size()
-    m2, k = mat1.size()
-    k2, n2 = mat2.size()
-
-    # All assertions at the top
-    assert m == m2 and n == n2 and k == k2, "Size mismatch in addmm backward"
-
-    # Declare ALL output tensors at the top before any loops
-    grad_input = torch.empty_like(bias)
-    grad_mat1 = torch.empty_like(mat1)
-    grad_mat2 = torch.empty_like(mat2)
-
-    # Handle grad_input = beta * grad_out (assuming same shape for now)
-    for tile_m3, tile_n3 in hl.tile([m, n]):
-        grad_input[tile_m3, tile_n3] = beta * grad_out[tile_m3, tile_n3]
-
-    # First loop block: compute grad_mat1 = alpha * (grad_out @ mat2.T)
-    for tile_m1, tile_k1 in hl.tile([m, k]):
-        acc1 = hl.zeros([tile_m1, tile_k1], dtype=torch.float32)
-        for tile_n1 in hl.tile(n):
-            acc1 = torch.addmm(
-                acc1, grad_out[tile_m1, tile_n1], mat2[tile_k1, tile_n1].T
-            )
-        grad_mat1[tile_m1, tile_k1] = (alpha * acc1).to(mat1.dtype)
-
-    # Second loop block: compute grad_mat2 = alpha * (mat1.T @ grad_out)
-    for tile_k2, tile_n2 in hl.tile([k, n]):
-        acc2 = hl.zeros([tile_k2, tile_n2], dtype=torch.float32)
-        for tile_m2 in hl.tile(m):
-            acc2 = torch.addmm(
-                acc2, mat1[tile_m2, tile_k2].T, grad_out[tile_m2, tile_n2]
-            )
-        grad_mat2[tile_k2, tile_n2] = (alpha * acc2).to(mat2.dtype)
-
-    return grad_input, grad_mat1, grad_mat2
-
-
 # %%
 class MatMulFunction(torch.autograd.Function):
     @staticmethod
@@ -206,10 +82,24 @@ def backward(
         ctx: Any,  # noqa: ANN401
         *grad_outputs: Tensor,
     ) -> tuple[Tensor | None, Tensor | None]:
-        """Backward pass for matrix multiplication."""
+        """
+        Backward pass for matrix multiplication.
+
+        For C = A @ B, given grad_C:
+        - grad_A = grad_C @ B.T
+        - grad_B = A.T @ grad_C
+
+        We reuse the forward matmul kernel for both computations.
+        """
         grad_out = grad_outputs[0]
         mat1, mat2 = ctx.saved_tensors
-        grad_mat1, grad_mat2 = matmul_bwd(grad_out, mat1, mat2)
+
+        # grad_mat1 = grad_out @ mat2.T
+        grad_mat1 = matmul(grad_out, mat2.T)
+
+        # grad_mat2 = mat1.T @ grad_out
+        grad_mat2 = matmul(mat1.T, grad_out)
+
         return grad_mat1, grad_mat2


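A quick way to sanity-check this reused-kernel backward (not part of this commit) is to compare it against PyTorch's built-in autograd. The sketch below assumes `MatMulFunction.apply(mat1, mat2)` runs the Helion `matmul` forward shown above and that a CUDA device is available; the helper name `check_matmul_backward` is hypothetical.

# Hedged sketch: MatMulFunction.apply(mat1, mat2) is assumed to wrap the Helion matmul kernel.
import torch

def check_matmul_backward() -> None:
    m, k, n = 64, 32, 48
    a = torch.randn(m, k, device="cuda", requires_grad=True)
    b = torch.randn(k, n, device="cuda", requires_grad=True)

    # Reference gradients from PyTorch's own matmul backward.
    (a @ b).sum().backward()
    ref_grad_a, ref_grad_b = a.grad.clone(), b.grad.clone()
    a.grad, b.grad = None, None

    # Gradients from the custom autograd.Function that reuses the forward kernel.
    MatMulFunction.apply(a, b).sum().backward()

    torch.testing.assert_close(a.grad, ref_grad_a, rtol=1e-2, atol=1e-2)
    torch.testing.assert_close(b.grad, ref_grad_b, rtol=1e-2, atol=1e-2)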
@@ -248,15 +138,33 @@ def backward(
         ctx: Any,  # noqa: ANN401
         *grad_outputs: Tensor,
     ) -> tuple[Tensor | None, Tensor | None, Tensor | None, None, None]:
-        """Backward pass for addmm operation."""
+        """
+        Backward pass for addmm operation.
+
+        Forward: output = beta * bias + alpha * (mat1 @ mat2)
+
+        Given grad_out:
+        - grad_bias = beta * grad_out
+        - grad_mat1 = alpha * (grad_out @ mat2.T)
+        - grad_mat2 = alpha * (mat1.T @ grad_out)
+
+        We reuse the forward matmul kernel for both matrix gradient computations.
+        """
         grad_out = grad_outputs[0]
         bias, mat1, mat2 = ctx.saved_tensors
         alpha = ctx.alpha
         beta = ctx.beta
-        grad_input, grad_mat1, grad_mat2 = addmm_bwd(
-            grad_out, bias, mat1, mat2, alpha, beta
-        )
-        return grad_input, grad_mat1, grad_mat2, None, None
+
+        # grad_bias = beta * grad_out
+        grad_bias = beta * grad_out
+
+        # grad_mat1 = alpha * (grad_out @ mat2.T)
+        grad_mat1 = alpha * matmul(grad_out, mat2.T)
+
+        # grad_mat2 = alpha * (mat1.T @ grad_out)
+        grad_mat2 = alpha * matmul(mat1.T, grad_out)
+
+        return grad_bias, grad_mat1, grad_mat2, None, None


 def addmm_autograd(
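The addmm path can be checked the same way against `torch.addmm`. The sketch below assumes `addmm_autograd(bias, mat1, mat2, alpha, beta)` is the wrapper around the autograd.Function above; its exact argument order is an assumption, not taken from this diff.

# Hedged sketch: addmm_autograd's signature is assumed to mirror torch.addmm plus alpha/beta.
import torch

def check_addmm_backward(alpha: float = 2.0, beta: float = 0.5) -> None:
    m, k, n = 64, 32, 48
    bias = torch.randn(m, n, device="cuda", requires_grad=True)
    mat1 = torch.randn(m, k, device="cuda", requires_grad=True)
    mat2 = torch.randn(k, n, device="cuda", requires_grad=True)

    # Reference gradients from torch.addmm.
    torch.addmm(bias, mat1, mat2, beta=beta, alpha=alpha).sum().backward()
    ref = [t.grad.clone() for t in (bias, mat1, mat2)]
    for t in (bias, mat1, mat2):
        t.grad = None

    # Gradients from the Helion-backed autograd path.
    addmm_autograd(bias, mat1, mat2, alpha, beta).sum().backward()
    for t, r in zip((bias, mat1, mat2), ref):
        torch.testing.assert_close(t.grad, r, rtol=1e-2, atol=1e-2)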