Commit 95b2255

use matmul directly in bwd for performance
1 parent dee9f57 commit 95b2255

1 file changed: examples/matmul.py

Lines changed: 39 additions & 131 deletions
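The change relies on the standard matmul gradient identities: for C = A @ B with upstream gradient grad_C, grad_A = grad_C @ B.T and grad_B = A.T @ grad_C, so the existing forward matmul kernel can produce both gradients when fed transposed views. A minimal eager-mode sketch of that identity check, where torch.matmul stands in for the Helion kernel and the shapes are arbitrary:

import torch

# Hedged sketch: torch.matmul stands in for the Helion matmul kernel,
# and the shapes below are arbitrary.
m, k, n = 64, 32, 48
a = torch.randn(m, k, dtype=torch.float64, requires_grad=True)
b = torch.randn(k, n, dtype=torch.float64, requires_grad=True)

c = torch.matmul(a, b)
grad_c = torch.randn_like(c)
c.backward(grad_c)

with torch.no_grad():
    grad_a_manual = torch.matmul(grad_c, b.T)  # grad_A = grad_C @ B.T
    grad_b_manual = torch.matmul(a.T, grad_c)  # grad_B = A.T @ grad_C

torch.testing.assert_close(a.grad, grad_a_manual)
torch.testing.assert_close(b.grad, grad_b_manual)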
@@ -64,130 +64,6 @@ def matmul(
     return out
 
 
-@helion.kernel
-def matmul_bwd(
-    grad_out: Tensor,  # [m, n] gradient w.r.t output
-    mat1: Tensor,  # [m, k] first matrix
-    mat2: Tensor,  # [k, n] second matrix
-) -> tuple[Tensor, Tensor]:
-    """
-    Backward pass for matrix multiplication following Triton reference pattern.
-
-    For C = A @ B, given grad_C, computes:
-    - grad_A = grad_C @ B.T
-    - grad_B = A.T @ grad_C
-
-    Args:
-        grad_out: Gradient w.r.t output [m, n]
-        mat1: First matrix [m, k]
-        mat2: Second matrix [k, n]
-
-    Returns:
-        tuple[Tensor, Tensor]: (grad_mat1, grad_mat2)
-    """
-    # Get all dimensions first
-    m, n = grad_out.size()
-    m2, k = mat1.size()
-    k2, n2 = mat2.size()
-
-    # All assertions at the top
-    assert m == m2 and n == n2 and k == k2, "Size mismatch in matmul backward"
-
-    # Declare ALL output tensors at the top before any loops
-    grad_mat1 = torch.empty_like(mat1)
-    grad_mat2 = torch.empty_like(mat2)
-
-    # First loop block: compute grad_mat1 = grad_out @ mat2.T
-    for tile_m1, tile_k1 in hl.tile([m, k]):
-        acc1 = hl.zeros([tile_m1, tile_k1], dtype=torch.float32)
-        for tile_n1 in hl.tile(n):
-            # Need mat2.T: mat2 is [k, n], so mat2[tile_k, tile_n].T gives [tile_n, tile_k]
-            acc1 = torch.addmm(
-                acc1, grad_out[tile_m1, tile_n1], mat2[tile_k1, tile_n1].T
-            )
-        grad_mat1[tile_m1, tile_k1] = acc1.to(mat1.dtype)
-
-    # Second loop block: compute grad_mat2 = mat1.T @ grad_out
-    for tile_k2, tile_n2 in hl.tile([k, n]):
-        acc2 = hl.zeros([tile_k2, tile_n2], dtype=torch.float32)
-        for tile_m2 in hl.tile(m):
-            # Need mat1.T: mat1 is [m, k], so mat1[tile_m, tile_k].T gives [tile_k, tile_m]
-            acc2 = torch.addmm(
-                acc2, mat1[tile_m2, tile_k2].T, grad_out[tile_m2, tile_n2]
-            )
-        grad_mat2[tile_k2, tile_n2] = acc2.to(mat2.dtype)
-
-    return grad_mat1, grad_mat2
-
-
-@helion.kernel
-def addmm_bwd(
-    grad_out: Tensor,  # [m, n] gradient w.r.t output
-    bias: Tensor,  # [m, n] or broadcastable bias tensor
-    mat1: Tensor,  # [m, k] first matrix
-    mat2: Tensor,  # [k, n] second matrix
-    alpha: float = 1.0,  # scalar multiplier for matmul
-    beta: float = 1.0,  # scalar multiplier for bias
-) -> tuple[Tensor, Tensor, Tensor]:
-    """
-    Backward pass for addmm operation following Triton reference pattern.
-
-    Forward: output = beta * bias + alpha * (mat1 @ mat2)
-
-    Based on the Triton kernel analysis:
-    - grad_input = beta * grad_out (with proper reduction for broadcasting)
-    - grad_mat1 = alpha * (grad_out @ mat2.T)
-    - grad_mat2 = alpha * (mat1.T @ grad_out)
-
-    Args:
-        grad_out: Gradient w.r.t output [m, n]
-        bias: Bias tensor [m, n] (or broadcastable)
-        mat1: First matrix [m, k]
-        mat2: Second matrix [k, n]
-        alpha: Scalar multiplier for matmul
-        beta: Scalar multiplier for bias
-
-    Returns:
-        tuple[Tensor, Tensor, Tensor]: (grad_input, grad_mat1, grad_mat2)
-    """
-    # Get all dimensions first
-    m, n = grad_out.size()
-    m2, k = mat1.size()
-    k2, n2 = mat2.size()
-
-    # All assertions at the top
-    assert m == m2 and n == n2 and k == k2, "Size mismatch in addmm backward"
-
-    # Declare ALL output tensors at the top before any loops
-    grad_input = torch.empty_like(bias)
-    grad_mat1 = torch.empty_like(mat1)
-    grad_mat2 = torch.empty_like(mat2)
-
-    # Handle grad_input = beta * grad_out (assuming same shape for now)
-    for tile_m3, tile_n3 in hl.tile([m, n]):
-        grad_input[tile_m3, tile_n3] = beta * grad_out[tile_m3, tile_n3]
-
-    # First loop block: compute grad_mat1 = alpha * (grad_out @ mat2.T)
-    for tile_m1, tile_k1 in hl.tile([m, k]):
-        acc1 = hl.zeros([tile_m1, tile_k1], dtype=torch.float32)
-        for tile_n1 in hl.tile(n):
-            acc1 = torch.addmm(
-                acc1, grad_out[tile_m1, tile_n1], mat2[tile_k1, tile_n1].T
-            )
-        grad_mat1[tile_m1, tile_k1] = (alpha * acc1).to(mat1.dtype)
-
-    # Second loop block: compute grad_mat2 = alpha * (mat1.T @ grad_out)
-    for tile_k2, tile_n2 in hl.tile([k, n]):
-        acc2 = hl.zeros([tile_k2, tile_n2], dtype=torch.float32)
-        for tile_m2 in hl.tile(m):
-            acc2 = torch.addmm(
-                acc2, mat1[tile_m2, tile_k2].T, grad_out[tile_m2, tile_n2]
-            )
-        grad_mat2[tile_k2, tile_n2] = (alpha * acc2).to(mat2.dtype)
-
-    return grad_input, grad_mat1, grad_mat2
-
-
 # %%
 class MatMulFunction(torch.autograd.Function):
     @staticmethod
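For reference, the deleted kernels follow the usual blocked reduction: accumulate partial products in float32 over tiles of the reduction dimension, then cast back to the input dtype. A rough eager-mode sketch of what matmul_bwd computed (plain slicing stands in for hl.tile; block is an arbitrary illustrative tile size, not the Helion config):

import torch

def matmul_bwd_reference(grad_out, mat1, mat2, block=32):
    # Eager-mode sketch of the removed matmul_bwd kernel (not the Helion code):
    # grad_mat1 = grad_out @ mat2.T accumulated in float32 over tiles of n,
    # grad_mat2 = mat1.T @ grad_out accumulated in float32 over tiles of m.
    m, n = grad_out.shape
    k = mat1.shape[1]

    acc1 = torch.zeros(m, k, dtype=torch.float32, device=grad_out.device)
    for n0 in range(0, n, block):
        tile = slice(n0, min(n0 + block, n))
        acc1 += grad_out[:, tile].float() @ mat2[:, tile].float().T
    grad_mat1 = acc1.to(mat1.dtype)

    acc2 = torch.zeros(k, n, dtype=torch.float32, device=grad_out.device)
    for m0 in range(0, m, block):
        tile = slice(m0, min(m0 + block, m))
        acc2 += mat1[tile, :].float().T @ grad_out[tile, :].float()
    grad_mat2 = acc2.to(mat2.dtype)

    return grad_mat1, grad_mat2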
@@ -206,10 +82,24 @@ def backward(
         ctx: Any,  # noqa: ANN401
         *grad_outputs: Tensor,
     ) -> tuple[Tensor | None, Tensor | None]:
-        """Backward pass for matrix multiplication."""
+        """
+        Backward pass for matrix multiplication.
+
+        For C = A @ B, given grad_C:
+        - grad_A = grad_C @ B.T
+        - grad_B = A.T @ grad_C
+
+        We reuse the forward matmul kernel for both computations.
+        """
         grad_out = grad_outputs[0]
         mat1, mat2 = ctx.saved_tensors
-        grad_mat1, grad_mat2 = matmul_bwd(grad_out, mat1, mat2)
+
+        # grad_mat1 = grad_out @ mat2.T
+        grad_mat1 = matmul(grad_out, mat2.T)
+
+        # grad_mat2 = mat1.T @ grad_out
+        grad_mat2 = matmul(mat1.T, grad_out)
+
         return grad_mat1, grad_mat2
 
 
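A hedged usage sketch of the new backward path: it assumes a CUDA device, that this example module is importable as examples.matmul, and that MatMulFunction.forward takes (mat1, mat2). It checks the kernel-based gradients against eager PyTorch autograd:

import torch
from examples.matmul import MatMulFunction  # assumed import path for this example file

a = torch.randn(128, 64, device="cuda", dtype=torch.float16, requires_grad=True)
b = torch.randn(64, 256, device="cuda", dtype=torch.float16, requires_grad=True)

out = MatMulFunction.apply(a, b)
out.backward(torch.ones_like(out))

# Reference gradients from eager PyTorch on detached copies.
a_ref = a.detach().clone().requires_grad_(True)
b_ref = b.detach().clone().requires_grad_(True)
(a_ref @ b_ref).backward(torch.ones_like(out))

torch.testing.assert_close(a.grad, a_ref.grad, rtol=1e-2, atol=1e-2)
torch.testing.assert_close(b.grad, b_ref.grad, rtol=1e-2, atol=1e-2)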
@@ -248,15 +138,33 @@ def backward(
         ctx: Any,  # noqa: ANN401
         *grad_outputs: Tensor,
     ) -> tuple[Tensor | None, Tensor | None, Tensor | None, None, None]:
-        """Backward pass for addmm operation."""
+        """
+        Backward pass for addmm operation.
+
+        Forward: output = beta * bias + alpha * (mat1 @ mat2)
+
+        Given grad_out:
+        - grad_bias = beta * grad_out
+        - grad_mat1 = alpha * (grad_out @ mat2.T)
+        - grad_mat2 = alpha * (mat1.T @ grad_out)
+
+        We reuse the forward matmul kernel for both matrix gradient computations.
+        """
         grad_out = grad_outputs[0]
         bias, mat1, mat2 = ctx.saved_tensors
         alpha = ctx.alpha
         beta = ctx.beta
-        grad_input, grad_mat1, grad_mat2 = addmm_bwd(
-            grad_out, bias, mat1, mat2, alpha, beta
-        )
-        return grad_input, grad_mat1, grad_mat2, None, None
+
+        # grad_bias = beta * grad_out
+        grad_bias = beta * grad_out
+
+        # grad_mat1 = alpha * (grad_out @ mat2.T)
+        grad_mat1 = alpha * matmul(grad_out, mat2.T)
+
+        # grad_mat2 = alpha * (mat1.T @ grad_out)
+        grad_mat2 = alpha * matmul(mat1.T, grad_out)
+
+        return grad_bias, grad_mat1, grad_mat2, None, None
 
 
 def addmm_autograd(
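The addmm path applies the same idea plus the bias term. A small eager-mode check of the three gradient formulas against torch.addmm's autograd (float64 on CPU, so no GPU kernel is involved; alpha, beta, and the shapes are arbitrary):

import torch

m, k, n = 16, 8, 12
alpha, beta = 2.0, 0.5
bias = torch.randn(m, n, dtype=torch.float64, requires_grad=True)
mat1 = torch.randn(m, k, dtype=torch.float64, requires_grad=True)
mat2 = torch.randn(k, n, dtype=torch.float64, requires_grad=True)

out = torch.addmm(bias, mat1, mat2, alpha=alpha, beta=beta)
grad_out = torch.randn_like(out)
out.backward(grad_out)

with torch.no_grad():
    grad_bias = beta * grad_out              # grad_bias = beta * grad_out
    grad_mat1 = alpha * (grad_out @ mat2.T)  # grad_mat1 = alpha * (grad_out @ mat2.T)
    grad_mat2 = alpha * (mat1.T @ grad_out)  # grad_mat2 = alpha * (mat1.T @ grad_out)

torch.testing.assert_close(bias.grad, grad_bias)
torch.testing.assert_close(mat1.grad, grad_mat1)
torch.testing.assert_close(mat2.grad, grad_mat2)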
