
Commit fc79ea7

mengluy0125 authored and meta-codesync[bot] committed
Add squeeze_and_excitation_net kernel (#870)
Summary:
Pull Request resolved: #870
We add a forward kernel calculating torch.sigmoid(torch.relu((x @ a)) @ b) in Helion.

Differential Revision: D84310853
1 parent 2572047 commit fc79ea7
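For orientation, the operation the new kernels implement reduces to a few lines of eager PyTorch. The sketch below (the helper name se_net_reference is illustrative, not part of the commit) restates the reference implementation that appears in the diff, including the element-wise scaling by x that the one-line summary omits.

import torch

def se_net_reference(x: torch.Tensor, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Gate: sigmoid(relu(x @ a) @ b), then scale the input element-wise.
    return x * torch.sigmoid(torch.relu(x @ a) @ b)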

File tree

3 files changed: +683 −0 lines changed

Lines changed: 268 additions & 0 deletions
@@ -0,0 +1,268 @@
"""
Helion Squeeze-and-Excitation Net Example
=========================================
This example demonstrates a Helion kernel implementation of a squeeze-and-excitation
network like those used in https://arxiv.org/abs/1709.01507.
"""

# %%
from __future__ import annotations

import torch
from torch import Tensor

import helion
from helion._testing import run_example
import helion.language as hl


# %%
@helion.kernel(
    # static_shapes=True gives a performance boost for matmuls
    static_shapes=True,
)
def squeeze_and_excitation_net_fwd(
    x: Tensor, a: Tensor, b: Tensor
) -> tuple[Tensor, Tensor, Tensor]:
    """
    Performs torch.mul(x, torch.sigmoid(torch.relu(x @ a) @ b)).
    Args:
        x: 2D tensor of shape [m, n].
        a: 2D tensor of shape [n, k].
        b: 2D tensor of shape [k, n].
    Returns:
        out: Resulting matrix of shape [m, n].
        c = torch.relu(x @ a) of shape [m, k].
        d = torch.sigmoid(c @ b) of shape [m, n].
    """
    m, n = x.size()
    k = a.size(1)

    out = torch.empty([m, n], dtype=x.dtype, device=x.device)
    c = torch.empty([m, k], dtype=x.dtype, device=x.device)
    d = torch.empty([m, n], dtype=x.dtype, device=x.device)

    for tile_m in hl.tile(m):
        # Compute c = relu(x @ a) for this tile_m
        for tile_k in hl.tile(k):
            partial_xa = x[tile_m, :] @ a[:, tile_k]
            c[tile_m, tile_k] = torch.relu(partial_xa)

        # Compute d = sigmoid(c @ b) and out = x * d for this tile_m
        for tile_n in hl.tile(n):
            acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
            for tile_k in hl.tile(k):
                acc += c[tile_m, tile_k] @ b[tile_k, tile_n]
            d[tile_m, tile_n] = torch.sigmoid(acc)
            out[tile_m, tile_n] = x[tile_m, tile_n] * d[tile_m, tile_n]

    return out, c, d


# %%
@helion.kernel(static_shapes=True)
def squeeze_and_excitation_net_bwd_dx(
    grad_out: Tensor, x: Tensor, a: Tensor, b: Tensor, c: Tensor, d: Tensor
) -> Tensor:
    """
    Compute grad_x for the squeeze and excitation network.
    grad_x = grad_out * d + ((grad_out * x * d * (1 - d)) @ b.T * (c > 0)) @ a.T

    The computation is structured to properly accumulate over the k dimension:
    1. First term: grad_out * d (element-wise, no reduction)
    2. Second term: chain rule through the d -> c -> x path
       - For each output position (m, n), accumulate over the k dimension
       - grad_c[m, k] = (grad_out * x * d * (1 - d))[m, :] @ b[k, :].T * (c[m, k] > 0)
       - grad_x[m, n] += grad_c[m, k] @ a[n, k].T
    """
    m, n = x.size()
    k = a.size(1)

    grad_x = torch.empty([m, n], dtype=x.dtype, device=x.device)

    # Compute grad_x: grad_out * d + second_term, where second_term accumulates over k
    for tile_m, tile_n in hl.tile([m, n]):
        # First term: grad_out * d (element-wise)
        acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
        acc += grad_out[tile_m, tile_n] * d[tile_m, tile_n]

        # Second term: accumulate the gradient chain over the k dimension
        for tile_k in hl.tile(k):
            # Compute grad_to_d for the full row: shape [tile_m, n]
            grad_to_d = (
                grad_out[tile_m, :] * x[tile_m, :] * d[tile_m, :] * (1.0 - d[tile_m, :])
            )

            # Backprop through (c @ b): grad_c = grad_to_d @ b.T
            # [tile_m, n] @ [n, tile_k] = [tile_m, tile_k]
            grad_to_c = grad_to_d @ b[tile_k, :].T

            # Apply the ReLU mask: shape [tile_m, tile_k]
            grad_c_masked = grad_to_c * (c[tile_m, tile_k] > 0)

            # Backprop through (x @ a): grad_x contribution = grad_c_masked @ a.T
            # [tile_m, tile_k] @ [tile_k, tile_n] = [tile_m, tile_n]
            acc += grad_c_masked @ a[tile_n, tile_k].T

        grad_x[tile_m, tile_n] = acc

    return grad_x


# %%
@helion.kernel(static_shapes=True)
def squeeze_and_excitation_net_bwd_da(
    grad_out: Tensor, x: Tensor, b: Tensor, c: Tensor, d: Tensor
) -> Tensor:
    """
    Compute grad_a for the squeeze and excitation network.
    grad_a = x.T @ ((grad_out * x * d * (1 - d)) @ b.T * (c > 0))
    """
    m, n = x.size()
    k = c.size(1)

    grad_a = torch.empty([n, k], dtype=x.dtype, device=x.device)

    # Compute grad_a: x.T @ grad_c
    for tile_n, tile_k in hl.tile([n, k]):
        acc_a = hl.zeros([tile_n, tile_k], dtype=torch.float32)
        for tile_m in hl.tile(m):
            # Backprop through sigmoid: need the full row for the matmul with b.T
            grad_to_d = grad_out[tile_m, :] * x[tile_m, :]
            grad_to_cb = grad_to_d * d[tile_m, :] * (1.0 - d[tile_m, :])
            # Backprop through c @ b: [tile_m, n] @ [n, tile_k] = [tile_m, tile_k]
            grad_to_c = grad_to_cb @ b[tile_k, :].T
            # Backprop through relu
            grad_through_relu = grad_to_c * (c[tile_m, tile_k] > 0)
            # Accumulate x.T @ grad_c: [tile_n, tile_m] @ [tile_m, tile_k] = [tile_n, tile_k]
            acc_a += x[tile_m, tile_n].T @ grad_through_relu
        grad_a[tile_n, tile_k] = acc_a

    return grad_a


# %%
@helion.kernel(static_shapes=True)
def squeeze_and_excitation_net_bwd_db(
    grad_out: Tensor, x: Tensor, d: Tensor, c: Tensor
) -> Tensor:
    """
    Compute grad_b by fusing the grad_d computation inline.
    grad_b = c.T @ (grad_out * x * d * (1 - d))
    """
    m, n = grad_out.size()
    k = c.size(1)
    grad_b = torch.empty([k, n], dtype=grad_out.dtype, device=grad_out.device)

    for tile_k, tile_n in hl.tile([k, n]):
        acc = hl.zeros([tile_k, tile_n], dtype=torch.float32)
        for tile_m in hl.tile(m):
            grad_d = (
                grad_out[tile_m, tile_n]
                * x[tile_m, tile_n]
                * d[tile_m, tile_n]
                * (1.0 - d[tile_m, tile_n])
            )
            acc += c[tile_m, tile_k].T @ grad_d
        grad_b[tile_k, tile_n] = acc

    return grad_b


# %%
# Reference Implementation
# ------------------------
def squeeze_and_excitation_net_pytorch(
    x: torch.Tensor, a: torch.Tensor, b: torch.Tensor
) -> torch.Tensor:
    """
    PyTorch reference implementation of squeeze_and_excitation_net.

    Args:
        x, a, b: Input tensors

    Returns:
        Tensor equal to torch.mul(x, torch.sigmoid(torch.relu(x @ a) @ b))
    """
    return torch.mul(x, torch.sigmoid(torch.relu(x @ a) @ b))


# %%
# Autograd Function
# ------------------
class SqueezeAndExcitationNetFunction(torch.autograd.Function):
    @staticmethod
    def forward(  # type: ignore[override]
        ctx: object,
        x: torch.Tensor,
        a: torch.Tensor,
        b: torch.Tensor,
    ) -> torch.Tensor:
        """Forward pass for the squeeze and excitation network."""
        out, c, d = squeeze_and_excitation_net_fwd(x, a, b)
        ctx.save_for_backward(x, a, b, c, d)  # type: ignore[attr-defined]
        return out

    @staticmethod
    def backward(  # type: ignore[override]
        ctx: object,
        grad_out: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Backward pass for the squeeze and excitation network."""
        x, a, b, c, d = ctx.saved_tensors  # type: ignore[attr-defined]

        grad_x = squeeze_and_excitation_net_bwd_dx(grad_out, x, a, b, c, d)
        grad_a = squeeze_and_excitation_net_bwd_da(grad_out, x, b, c, d)
        grad_b = squeeze_and_excitation_net_bwd_db(grad_out, x, d, c)
        return grad_x, grad_a, grad_b


def squeeze_and_excitation_net(
    x: torch.Tensor, a: torch.Tensor, b: torch.Tensor
) -> torch.Tensor:
    """
    Squeeze and excitation network with autograd support.

    Args:
        x: Input tensor [m, n]
        a: Weight matrix [n, k]
        b: Weight matrix [k, n]

    Returns:
        Output tensor [m, n]
    """
    return SqueezeAndExcitationNetFunction.apply(x, a, b)  # type: ignore[no-any-return]


def check(m: int, k: int, n: int) -> None:
    """
    Checks correctness against PyTorch.
    Args:
        m (int): Number of rows of x.
        k (int): Inner (bottleneck) dimension: columns of a and rows of b.
        n (int): Number of columns of x and of the output.
    """
    x = torch.randn([m, n], device="cuda", dtype=torch.float16, requires_grad=True)
    a = torch.randn([n, k], device="cuda", dtype=torch.float16, requires_grad=True)
    b = torch.randn([k, n], device="cuda", dtype=torch.float16, requires_grad=True)
    for bwd in [True, False]:
        run_example(
            squeeze_and_excitation_net,
            squeeze_and_excitation_net_pytorch,
            (x, a, b),
            bwd=bwd,
        )


# %%
def main() -> None:
    """
    Main function to run autotuning (commented out) and correctness checks.
    """
    # autotune(1024, 1024, 1024)
    check(1024, 1024, 1024)


# %%
if __name__ == "__main__":
    main()
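As a rough sanity check of the hand-written backward kernels, one could compare them against PyTorch autograd applied to the eager reference. The sketch below is a minimal, hypothetical example, not part of the commit: the helper name check_grads, the small float32 shapes, and the CUDA device are assumptions, and it presumes the Helion kernels above are importable if one uncomments the comparison.

import torch

def check_grads(m: int = 64, k: int = 32, n: int = 64) -> None:
    # Small float32 tensors keep comparison tolerances tight.
    x = torch.randn(m, n, device="cuda", dtype=torch.float32, requires_grad=True)
    a = torch.randn(n, k, device="cuda", dtype=torch.float32, requires_grad=True)
    b = torch.randn(k, n, device="cuda", dtype=torch.float32, requires_grad=True)

    # Eager reference: out = x * d with c = relu(x @ a), d = sigmoid(c @ b).
    c = torch.relu(x @ a)
    d = torch.sigmoid(c @ b)
    out = x * d

    grad_out = torch.randn_like(out)
    gx, ga, gb = torch.autograd.grad(out, (x, a, b), grad_out)

    # The Helion kernels should reproduce these gradients, e.g.:
    #   torch.testing.assert_close(
    #       squeeze_and_excitation_net_bwd_db(grad_out, x, d, c), gb
    #   )
    print(gx.shape, ga.shape, gb.shape)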
