
Commit abe8339

[Benchmark] Add low mem dropout example (#641)
1 parent 9c64883 commit abe8339

File tree

4 files changed: +234 −0 lines changed

benchmarks/run.py · examples/low_mem_dropout.py · test/test_examples.expected · test/test_examples.py

benchmarks/run.py

Lines changed: 13 additions & 0 deletions
@@ -280,6 +280,11 @@ class RunResult:
         "examples.jagged_sum",
         "jagged_sum_tritonbench",
     ),
+    "low_mem_dropout": (
+        "tritonbench.operators.low_mem_dropout.operator",
+        "examples.low_mem_dropout",
+        "low_mem_dropout_tritonbench",
+    ),
 }

@@ -538,6 +543,14 @@ class RunResult:
         "helion_fp8_gemm_tritonbench-speedup": "helion_speedup",
         "helion_fp8_gemm_tritonbench-accuracy": "helion_accuracy",
     },
+    "low_mem_dropout": {
+        "seeded_dropout-accuracy": "triton_accuracy",
+        "seeded_dropout-speedup": "triton_speedup",
+        "torch_compile_dropout-accuracy": "torch_compile_accuracy",
+        "torch_compile_dropout-speedup": "torch_compile_speedup",
+        "helion_low_mem_dropout_tritonbench-accuracy": "helion_accuracy",
+        "helion_low_mem_dropout_tritonbench-speedup": "helion_speedup",
+    },
 }

examples/low_mem_dropout.py

Lines changed: 136 additions & 0 deletions
@@ -0,0 +1,136 @@ (new file)
"""
Low Mem Dropout Example
=======================

This example demonstrates how to implement a low-memory (seeded) dropout using Helion.
"""

# %%
# Imports
# -------
from __future__ import annotations

from typing import Callable

import torch

import helion
import helion.language as hl


# %%
# Low mem dropout forward implementation
# -------------------
@helion.kernel()
def low_mem_dropout(p: float, x: torch.Tensor, seed: int) -> torch.Tensor:
    """
    Applies dropout on x using p
    Args:
        p (float): dropout probability
        x (torch.Tensor): input tensor
    Returns:
        Output tensor
    """
    scale = 1.0 / (1.0 - p)
    # flatten to 1D so we can use tile
    n = x.numel()
    x_flat = x.view(-1)
    out_flat = torch.empty_like(x_flat)
    for tidx in hl.tile(n):
        xi = x_flat[tidx].to(torch.float32)
        r = hl.rand([tidx], seed=seed)
        keep = r > p
        yscaled = xi * scale
        yi = torch.where(keep, yscaled, 0.0)
        out_flat[tidx] = yi.to(x.dtype)
    return out_flat.view_as(x)

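For context (not part of the file above): the "low memory" aspect is that the keep-mask is never materialized and saved between passes; it is regenerated on the fly from the seed and each element's flat index. A rough eager-PyTorch analogue of the same idea is sketched below. It is illustrative only, reuses the file's torch import, relies on PyTorch's generator rather than the counter-based RNG behind hl.rand, and therefore will not reproduce the kernel's exact mask.

# Illustrative eager-mode analogue: the mask is a pure function of the seed,
# so nothing has to be saved for the backward pass -- just rerun the RNG.
def seeded_dropout_reference(p: float, x: torch.Tensor, seed: int) -> torch.Tensor:
    gen = torch.Generator(device=x.device)
    gen.manual_seed(seed)
    keep = torch.rand(x.shape, generator=gen, device=x.device) > p
    return torch.where(keep, x / (1.0 - p), torch.zeros_like(x))
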
# %%
# Low mem dropout backward implementation
# -------------------
@helion.kernel()
def low_mem_dropout_bwd(p: float, grad_y: torch.Tensor, seed: int) -> torch.Tensor:
    """
    For low-memory dropout the same seeded randomness is applied in both the
    forward and backward passes, so the backward computation is effectively
    identical to the forward pass.
    Args:
        p (float): Dropout probability
        grad_y (torch.Tensor): Gradient tensor
    Returns:
        Output tensor
    """
    scale = 1.0 / (1.0 - p)
    n = grad_y.numel()
    grad_y_flat = grad_y.view(-1)
    out_flat = torch.empty_like(grad_y_flat)
    for tidx in hl.tile(n):
        gi = grad_y_flat[tidx].to(torch.float32)
        r = hl.rand([tidx], seed=seed)
        keep = r > p
        g_scaled = gi * scale
        gxi = torch.where(keep, g_scaled, 0.0)
        out_flat[tidx] = gxi.to(grad_y.dtype)
    return out_flat.view_as(grad_y)

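Also not part of the commit: one natural way to wire the two kernels into autograd is a torch.autograd.Function that saves only the dropout probability and the integer seed for backward, never a mask tensor. A minimal sketch, assuming the two Helion kernels defined above are in scope:

# Hypothetical autograd glue (illustrative): only scalars are carried to backward.
class _LowMemDropoutFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: torch.Tensor, p: float, seed: int) -> torch.Tensor:
        ctx.p = p
        ctx.seed = seed  # the keep-mask itself is never stored
        return low_mem_dropout(p, x, seed)

    @staticmethod
    def backward(ctx, grad_y: torch.Tensor):
        # Re-derive the same keep-mask from the saved seed.
        grad_x = low_mem_dropout_bwd(ctx.p, grad_y, ctx.seed)
        return grad_x, None, None  # no gradients for p or seed

# Usage (hypothetical): y = _LowMemDropoutFunction.apply(x, 0.25, 123)
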
# %%
# TritonBench Wrapper
# -------------------
def low_mem_dropout_tritonbench(tb_op: object, p: float, x: torch.Tensor) -> Callable:
    """
    Wrapper for TritonBench compatibility.

    Args:
        tb_op: TritonBench operator instance
        p (float): dropout probability
        x (torch.Tensor): Input tensor

    Returns:
        Callable: A function that performs the low_mem_dropout.
    """

    def _inner() -> torch.Tensor:
        return low_mem_dropout(p, x, seed=123)

    return _inner

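A quick way to exercise the wrapper outside of TritonBench (illustrative only; the wrapper never touches tb_op, so None is enough here, and a CUDA device is assumed):

x = torch.randn(8192, device="cuda")
fn = low_mem_dropout_tritonbench(None, 0.25, x)
y = fn()  # TritonBench would call this closure repeatedly while timing it
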
# %%
# Verification Function
# -------------------
def check(p: float, size: int) -> None:
    """
    Verify that the forward and backward low mem dropout kernels drop the same
    elements when run with the same seed.

    Args:
        p (float): dropout probability
        size (int): input tensor size
    """
    x = torch.randn(size=(size,)).cuda()
    seed = 123

    out = low_mem_dropout(p, x, seed)
    grad_y = torch.ones_like(x)
    grad_x = low_mem_dropout_bwd(p, grad_y, seed)
    mask_fwd = out != 0
    mask_bwd = grad_x != 0
    assert torch.equal(mask_fwd, mask_bwd)


# %%
# Main Function
# -----------
def main() -> None:
    """
    Main entry point that runs the low mem dropout kernel verification with different tensor sizes.
    Tests with two configurations:
    - p=0.25, size=8192
    - p=0.25, size=32768
    """
    check(0.25, 8192)
    check(0.25, 32768)


if __name__ == "__main__":
    main()

test/test_examples.expected

Lines changed: 40 additions & 0 deletions
@@ -2766,6 +2766,46 @@ def layer_norm_fwd(x: torch.Tensor, normalized_shape: list[int], weight: torch.T
     _launcher(_helion_layer_norm_fwd, (triton.cdiv(m, _BLOCK_SIZE_0),), x, weight, out, mean, rstd, mean.stride(0), out.stride(0), out.stride(1), rstd.stride(0), weight.stride(0), x.stride(0), x.stride(1), m, n, eps, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=3)
     return (out, mean, rstd)
 
+--- assertExpectedJournal(TestExamples.test_low_mem_dropout)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _helion_low_mem_dropout(x_flat, out_flat, out_flat_stride_0, x_flat_stride_0, n, seed, p, scale, _BLOCK_SIZE_0: tl.constexpr):
+    pid_0 = tl.program_id(0)
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    mask_0 = indices_0 < n
+    xi = tl.load(x_flat + indices_0 * x_flat_stride_0, mask_0, other=0)
+    rand = tl.rand(seed, indices_0)
+    v_0 = rand > p
+    v_1 = xi * scale
+    v_2 = 0.0
+    v_3 = v_2[None]
+    v_4 = tl.where(v_0, v_1, v_3)
+    tl.store(out_flat + indices_0 * out_flat_stride_0, v_4, mask_0)
+
+def low_mem_dropout(p: float, x: torch.Tensor, seed: int, *, _launcher=_default_launcher):
+    """
+    Applies dropout on x using p
+    Args:
+        p (float): dropout probability
+        x (torch.Tensor): input tensor
+    Returns:
+        Output tensor
+    """
+    scale = 1.0 / (1.0 - p)
+    n = x.numel()
+    x_flat = x.view(-1)
+    out_flat = torch.empty_like(x_flat)
+    _BLOCK_SIZE_0 = 1024
+    _launcher(_helion_low_mem_dropout, (triton.cdiv(n, _BLOCK_SIZE_0),), x_flat, out_flat, out_flat.stride(0), x_flat.stride(0), n, seed, p, scale, _BLOCK_SIZE_0, num_warps=4, num_stages=3)
+    return out_flat.view_as(x)
+
 --- assertExpectedJournal(TestExamples.test_matmul)
 from __future__ import annotations
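
The generated kernel above shows why no mask buffer is needed: tl.rand(seed, indices_0) is counter-based, so the random value for an element depends only on the seed and the element's flat index, and rerunning the kernel with the same seed reproduces the identical keep-mask. A small check one could run (not part of the commit; assumes examples/low_mem_dropout.py is importable and a CUDA device is available):

import torch
from examples.low_mem_dropout import low_mem_dropout

x = torch.randn(8192, device="cuda")
a = low_mem_dropout(0.25, x, 123)
b = low_mem_dropout(0.25, x, 123)
assert torch.equal(a, b)  # same seed, same input -> bitwise-identical output
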

test/test_examples.py

Lines changed: 45 additions & 0 deletions
@@ -310,6 +310,51 @@ def test_welford(self):
             )
         )
 
+    def test_low_mem_dropout(self):
+        from examples.low_mem_dropout import low_mem_dropout
+        from examples.low_mem_dropout import low_mem_dropout_bwd
+
+        from helion._testing import code_and_output
+
+        p = 0.25
+        size = 8192
+        seed = 123
+        seed2 = 456
+        x = torch.randn(size=(size,)).cuda()
+
+        _, out_fwd = code_and_output(
+            low_mem_dropout,
+            (p, x, seed),
+        )
+
+        grad_y = torch.ones_like(x)
+        _, grad_x = code_and_output(
+            low_mem_dropout_bwd,
+            (p, grad_y, seed),
+        )
+
+        _, grad_x2 = code_and_output(
+            low_mem_dropout_bwd,
+            (p, grad_y, seed2),
+        )
+
+        mask_fwd = out_fwd != 0
+        mask_bwd = grad_x != 0
+        self.assertTrue(
+            torch.equal(mask_fwd, mask_bwd),
+            "Same elements should be dropped in fwd and bwd with the same seed",
+        )
+
+        mask_bwd2 = grad_x2 != 0
+        self.assertFalse(
+            torch.equal(mask_bwd, mask_bwd2),
+            "Different elements should be dropped when using a different seed",
+        )
+
+        self.assertExpectedJournal(
+            check_example("low_mem_dropout", (p, grad_y, seed), grad_x),
+        )
+
     def test_rms_norm_fwd(self):
         args = (
             torch.randn([128, 256], device=DEVICE, dtype=torch.float16),
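
Assuming the repository's tests are run with pytest, the new test can be exercised in isolation with something like `python -m pytest test/test_examples.py -k test_low_mem_dropout`; it requires a CUDA device, since the inputs are created with `.cuda()`.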
