
Commit 3939aea

mengluy0125 authored and facebook-github-bot committed
Add simplified se_block kernel (#989)
Summary: We add a Helion kernel to compute 2 * x * sigmoid(x @ w).

Differential Revision: D84968671
1 parent d182808 commit 3939aea
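As a rough usage sketch (not part of this commit; the module path and CUDA device below are assumptions), the autograd-enabled wrapper added here can be called like any other PyTorch op:

import torch

# Assumed import path; adjust to wherever examples/se_block.py lives in your checkout.
from examples.se_block import se_block

x = torch.randn(1024, 1024, device="cuda", dtype=torch.float16, requires_grad=True)
w = torch.randn(1024, 1024, device="cuda", dtype=torch.float16, requires_grad=True)

out = se_block(x, w)   # 2 * x * sigmoid(x @ w), forward via the Helion kernel
out.sum().backward()   # gradients come from se_block_bwd_dx / se_block_bwd_dw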

File tree

3 files changed, +501 -0 lines changed


examples/se_block.py

Lines changed: 213 additions & 0 deletions
@@ -0,0 +1,213 @@
"""
Helion SE Block Example
============================
This example demonstrates a Helion kernel implementation of SE Block.
"""

# %%
from __future__ import annotations

import torch
from torch import Tensor

import helion
from helion._testing import DEVICE
from helion._testing import run_example
import helion.language as hl


# %%
@helion.kernel(
    # static_shapes=True gives a performance boost for matmuls
    static_shapes=True,
)
def se_block_fwd(x: Tensor, w: Tensor) -> tuple[Tensor, Tensor]:
    """
    Performs 2 * x * sigmoid(x @ w)
    Args:
        x: 2D tensor of shape [m, n].
        w: 2D tensor of shape [n, n].
    Returns:
        out: Resulting matrix of shape [m, n].
        s: sigmoid(x @ w) of shape [m, n].
    """
    m, n = x.size()

    out = torch.empty([m, n], dtype=x.dtype, device=x.device)
    s = torch.empty([m, n], dtype=x.dtype, device=x.device)

    for tile_m in hl.tile(m):
        for tile_n in hl.tile(n):
            # Compute sigmoid in float32
            sigmoid_result = torch.sigmoid(x[tile_m, :] @ w[:, tile_n])
            s[tile_m, tile_n] = sigmoid_result
            # Compute output: 2 * x * sigmoid, cast to input dtype
            acc = 2.0 * x[tile_m, tile_n].to(torch.float32) * sigmoid_result
            out[tile_m, tile_n] = acc.to(x.dtype)

    return out, s


# %%
@helion.kernel(static_shapes=True)
def se_block_bwd_dx(grad_out: Tensor, x: Tensor, w: Tensor, s: Tensor) -> Tensor:
    """
    Compute gradient for x.
    grad_x = 2 * grad_out * s + (2 * grad_out * x * s * (1 - s)) @ w.T

    Args:
        grad_out: Gradient w.r.t output [m, n]
        x: Input tensor [m, n]
        w: Weight matrix [n, n]
        s: sigmoid(x @ w) from forward pass [m, n]

    Returns:
        grad_x: Gradient w.r.t x [m, n]
    """
    m, n = x.size()

    grad_x = torch.empty([m, n], dtype=torch.float32, device=x.device)

    for tile_m, tile_n in hl.tile([m, n]):
        # 2 * grad_out * s
        acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
        acc += 2.0 * grad_out[tile_m, tile_n] * s[tile_m, tile_n]

        for tile_k in hl.tile(n):
            # 2 * grad_out * x * s * (1-s) for tile_k
            grad_to_w = (
                2.0
                * grad_out[tile_m, tile_k].to(torch.float32)
                * x[tile_m, tile_k].to(torch.float32)
                * s[tile_m, tile_k].to(torch.float32)
                * (1.0 - s[tile_m, tile_k].to(torch.float32))
            )
            # grad_to_w @ w.T[tile_k, tile_n] = grad_to_w @ w[tile_n, tile_k].T
            acc += grad_to_w @ w[tile_n, tile_k].to(torch.float32).T

        grad_x[tile_m, tile_n] = acc.to(x.dtype)

    return grad_x


# %%
@helion.kernel(static_shapes=True)
def se_block_bwd_dw(grad_out: Tensor, x: Tensor, s: Tensor) -> Tensor:
    """
    Compute gradient for w.
    grad_w = x.T @ (2 * grad_out * x * s * (1 - s))

    Args:
        grad_out: Gradient w.r.t output [m, n]
        x: Input tensor [m, n]
        s: sigmoid(x @ w) from forward pass [m, n]

    Returns:
        grad_w: Gradient w.r.t w [n, n]
    """
    m, n = x.size()

    grad_w = torch.zeros([n, n], dtype=torch.float32, device=x.device)

    for tile_n1, tile_n2 in hl.tile([n, n]):
        acc_w = hl.zeros([tile_n1, tile_n2], dtype=torch.float32)
        for tile_m in hl.tile(m):
            # 2 * grad_out * x * s * (1-s)
            grad_to_w = (
                2.0
                * grad_out[tile_m, tile_n2].to(torch.float32)
                * x[tile_m, tile_n2].to(torch.float32)
                * s[tile_m, tile_n2].to(torch.float32)
                * (1.0 - s[tile_m, tile_n2].to(torch.float32))
            )
            # x[tile_m, tile_n1].T @ grad_to_w[tile_m, tile_n2]
            acc_w += x[tile_m, tile_n1].to(torch.float32).T @ grad_to_w

        grad_w[tile_n1, tile_n2] = acc_w.to(x.dtype)

    return grad_w


# %%
# Reference Implementation
# --------------------
def se_block_pytorch(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    """
    PyTorch reference implementation of se_block.

    Args:
        x, w: Input tensors

    Returns:
        tensor of 2 * x * sigmoid(x @ w)
    """
    return 2 * x * torch.sigmoid(x @ w)


# %%
# Autograd Function
# ------------------
class SEBlockFunction(torch.autograd.Function):
    @staticmethod
    def forward(  # type: ignore[override]
        ctx: object,
        x: torch.Tensor,
        w: torch.Tensor,
    ) -> torch.Tensor:
        """Forward pass for se block."""
        out, s = se_block_fwd(x, w)
        ctx.save_for_backward(x, w, s)  # type: ignore[attr-defined]
        return out

    @staticmethod
    def backward(  # type: ignore[override]
        ctx: object,
        grad_out: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Backward pass for se block."""
        x, w, s = ctx.saved_tensors  # type: ignore[attr-defined]

        grad_x = se_block_bwd_dx(grad_out, x, w, s)
        grad_w = se_block_bwd_dw(grad_out, x, s)

        return grad_x, grad_w


def se_block(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    """
    SE Block with autograd support.

    Args:
        x: Input tensor [m, n]
        w: Weight matrix [n, n]

    Returns:
        Output tensor [m, n]
    """
    return SEBlockFunction.apply(x, w)  # type: ignore[no-any-return]


def check(m: int, n: int) -> None:
    """
    Checks the correctness against PyTorch.
    Args:
        m (int): Number of rows in matrix x.
        n (int): Number of columns in matrix x.
    """
    x = torch.randn([m, n], device=DEVICE, dtype=torch.float16, requires_grad=True)
    w = torch.randn([n, n], device=DEVICE, dtype=torch.float16, requires_grad=True)
    for bwd in [True, False]:
        run_example(se_block, se_block_pytorch, (x, w), bwd=bwd)


# %%
def main() -> None:
    """
    Main function to run correctness checks.
    """
    check(1024, 1024)


# %%
if __name__ == "__main__":
    main()

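As a side note (not part of the diff), the gradient formulas hard-coded in se_block_bwd_dx and se_block_bwd_dw follow from out = 2 * x * sigmoid(x @ w): writing s = sigmoid(x @ w) and inner = 2 * grad_out * x * s * (1 - s) for the gradient flowing into x @ w gives grad_x = 2 * grad_out * s + inner @ w.T and grad_w = x.T @ inner. A minimal standalone sketch that checks these against autograd on a small float64 problem (sizes and seed are arbitrary):

import torch

torch.manual_seed(0)
m, n = 8, 16
x = torch.randn(m, n, dtype=torch.float64, requires_grad=True)
w = torch.randn(n, n, dtype=torch.float64, requires_grad=True)

# Reference forward pass and autograd gradients.
out = 2 * x * torch.sigmoid(x @ w)
grad_out = torch.randn_like(out)
out.backward(grad_out)

# Hand-derived gradients, matching the kernel docstrings.
with torch.no_grad():
    s = torch.sigmoid(x @ w)
    inner = 2 * grad_out * x * s * (1 - s)   # gradient w.r.t. (x @ w)
    grad_x = 2 * grad_out * s + inner @ w.T
    grad_w = x.T @ inner

print(torch.allclose(grad_x, x.grad), torch.allclose(grad_w, w.grad))  # True True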
0 commit comments
