Skip to content

Commit b504f18

Browse files
authored
Extend eviction policy tests to all indexing types (#833)
1 parent 6581aac commit b504f18

File tree

2 files changed

+267
-6
lines changed

2 files changed

+267
-6
lines changed

test/test_eviction_policy.expected

Lines changed: 248 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,79 @@ import triton
99
import triton.language as tl
1010
from helion.runtime import default_launcher as _default_launcher
1111

12+
@triton.jit
13+
def _helion_kernel_with_eviction(x, y, out, x_size_0, out_stride_0, x_stride_0, y_stride_0, _BLOCK_SIZE_0: tl.constexpr):
14+
pid_0 = tl.program_id(0)
15+
offset_0 = pid_0 * _BLOCK_SIZE_0
16+
indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
17+
mask_0 = indices_0 < x_size_0
18+
val_x = tl.load(x + indices_0 * x_stride_0, mask_0, other=0)
19+
val_y = tl.load(y + indices_0 * y_stride_0, mask_0, other=0, eviction_policy='evict_last')
20+
v_0 = val_x + val_y
21+
tl.store(out + indices_0 * out_stride_0, v_0, mask_0)
22+
23+
def kernel_with_eviction(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launcher):
24+
out = torch.empty_like(x)
25+
_BLOCK_SIZE_0 = 16
26+
_launcher(_helion_kernel_with_eviction, (triton.cdiv(x.size(0), _BLOCK_SIZE_0),), x, y, out, x.size(0), out.stride(0), x.stride(0), y.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)
27+
return out
28+
29+
--- assertExpectedJournal(TestEvictionPolicy.test_eviction_policy_in_generated_code_indexing_block_ptr)
30+
from __future__ import annotations
31+
32+
import torch
33+
import triton
34+
import triton.language as tl
35+
from helion.runtime import default_launcher as _default_launcher
36+
37+
@triton.jit
38+
def _helion_kernel_with_eviction(x, y, out, out_size_0, x_size_0, y_size_0, out_stride_0, x_stride_0, y_stride_0, _BLOCK_SIZE_0: tl.constexpr):
39+
pid_0 = tl.program_id(0)
40+
offset_0 = pid_0 * _BLOCK_SIZE_0
41+
val_x = tl.load(tl.make_block_ptr(x, [x_size_0], [x_stride_0], [offset_0], [_BLOCK_SIZE_0], [0]), boundary_check=[0], padding_option='zero')
42+
val_y = tl.load(tl.make_block_ptr(y, [y_size_0], [y_stride_0], [offset_0], [_BLOCK_SIZE_0], [0]), boundary_check=[0], padding_option='zero', eviction_policy='evict_last')
43+
v_0 = val_x + val_y
44+
tl.store(tl.make_block_ptr(out, [out_size_0], [out_stride_0], [offset_0], [_BLOCK_SIZE_0], [0]), v_0, boundary_check=[0])
45+
46+
def kernel_with_eviction(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launcher):
47+
out = torch.empty_like(x)
48+
_BLOCK_SIZE_0 = 16
49+
_launcher(_helion_kernel_with_eviction, (triton.cdiv(x.size(0), _BLOCK_SIZE_0),), x, y, out, out.size(0), x.size(0), y.size(0), out.stride(0), x.stride(0), y.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)
50+
return out
51+
52+
--- assertExpectedJournal(TestEvictionPolicy.test_eviction_policy_in_generated_code_indexing_pointer)
53+
from __future__ import annotations
54+
55+
import torch
56+
import triton
57+
import triton.language as tl
58+
from helion.runtime import default_launcher as _default_launcher
59+
60+
@triton.jit
61+
def _helion_kernel_with_eviction(x, y, out, x_size_0, out_stride_0, x_stride_0, y_stride_0, _BLOCK_SIZE_0: tl.constexpr):
62+
pid_0 = tl.program_id(0)
63+
offset_0 = pid_0 * _BLOCK_SIZE_0
64+
indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
65+
mask_0 = indices_0 < x_size_0
66+
val_x = tl.load(x + indices_0 * x_stride_0, mask_0, other=0)
67+
val_y = tl.load(y + indices_0 * y_stride_0, mask_0, other=0, eviction_policy='evict_last')
68+
v_0 = val_x + val_y
69+
tl.store(out + indices_0 * out_stride_0, v_0, mask_0)
70+
71+
def kernel_with_eviction(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launcher):
72+
out = torch.empty_like(x)
73+
_BLOCK_SIZE_0 = 16
74+
_launcher(_helion_kernel_with_eviction, (triton.cdiv(x.size(0), _BLOCK_SIZE_0),), x, y, out, x.size(0), out.stride(0), x.stride(0), y.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)
75+
return out
76+
77+
--- assertExpectedJournal(TestEvictionPolicy.test_eviction_policy_in_generated_code_indexing_tensor_descriptor)
78+
from __future__ import annotations
79+
80+
import torch
81+
import triton
82+
import triton.language as tl
83+
from helion.runtime import default_launcher as _default_launcher
84+
1285
@triton.jit
1386
def _helion_kernel_with_eviction(x, y, out, x_size_0, out_stride_0, x_stride_0, y_stride_0, _BLOCK_SIZE_0: tl.constexpr):
1487
pid_0 = tl.program_id(0)
@@ -34,6 +107,79 @@ import triton
34107
import triton.language as tl
35108
from helion.runtime import default_launcher as _default_launcher
36109

110+
@triton.jit
111+
def _helion_kernel_with_override(x, y, out, x_size_0, out_stride_0, x_stride_0, y_stride_0, _BLOCK_SIZE_0: tl.constexpr):
112+
pid_0 = tl.program_id(0)
113+
offset_0 = pid_0 * _BLOCK_SIZE_0
114+
indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
115+
mask_0 = indices_0 < x_size_0
116+
val_x = tl.load(x + indices_0 * x_stride_0, mask_0, other=0, eviction_policy='evict_last')
117+
val_y = tl.load(y + indices_0 * y_stride_0, mask_0, other=0, eviction_policy='evict_first')
118+
v_0 = val_x + val_y
119+
tl.store(out + indices_0 * out_stride_0, v_0, mask_0)
120+
121+
def kernel_with_override(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launcher):
122+
out = torch.empty_like(x)
123+
_BLOCK_SIZE_0 = 16
124+
_launcher(_helion_kernel_with_override, (triton.cdiv(x.size(0), _BLOCK_SIZE_0),), x, y, out, x.size(0), out.stride(0), x.stride(0), y.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)
125+
return out
126+
127+
--- assertExpectedJournal(TestEvictionPolicy.test_explicit_eviction_policy_overrides_tunable_indexing_block_ptr)
128+
from __future__ import annotations
129+
130+
import torch
131+
import triton
132+
import triton.language as tl
133+
from helion.runtime import default_launcher as _default_launcher
134+
135+
@triton.jit
136+
def _helion_kernel_with_override(x, y, out, out_size_0, x_size_0, y_size_0, out_stride_0, x_stride_0, y_stride_0, _BLOCK_SIZE_0: tl.constexpr):
137+
pid_0 = tl.program_id(0)
138+
offset_0 = pid_0 * _BLOCK_SIZE_0
139+
val_x = tl.load(tl.make_block_ptr(x, [x_size_0], [x_stride_0], [offset_0], [_BLOCK_SIZE_0], [0]), boundary_check=[0], padding_option='zero', eviction_policy='evict_last')
140+
val_y = tl.load(tl.make_block_ptr(y, [y_size_0], [y_stride_0], [offset_0], [_BLOCK_SIZE_0], [0]), boundary_check=[0], padding_option='zero', eviction_policy='evict_first')
141+
v_0 = val_x + val_y
142+
tl.store(tl.make_block_ptr(out, [out_size_0], [out_stride_0], [offset_0], [_BLOCK_SIZE_0], [0]), v_0, boundary_check=[0])
143+
144+
def kernel_with_override(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launcher):
145+
out = torch.empty_like(x)
146+
_BLOCK_SIZE_0 = 16
147+
_launcher(_helion_kernel_with_override, (triton.cdiv(x.size(0), _BLOCK_SIZE_0),), x, y, out, out.size(0), x.size(0), y.size(0), out.stride(0), x.stride(0), y.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)
148+
return out
149+
150+
--- assertExpectedJournal(TestEvictionPolicy.test_explicit_eviction_policy_overrides_tunable_indexing_pointer)
151+
from __future__ import annotations
152+
153+
import torch
154+
import triton
155+
import triton.language as tl
156+
from helion.runtime import default_launcher as _default_launcher
157+
158+
@triton.jit
159+
def _helion_kernel_with_override(x, y, out, x_size_0, out_stride_0, x_stride_0, y_stride_0, _BLOCK_SIZE_0: tl.constexpr):
160+
pid_0 = tl.program_id(0)
161+
offset_0 = pid_0 * _BLOCK_SIZE_0
162+
indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
163+
mask_0 = indices_0 < x_size_0
164+
val_x = tl.load(x + indices_0 * x_stride_0, mask_0, other=0, eviction_policy='evict_last')
165+
val_y = tl.load(y + indices_0 * y_stride_0, mask_0, other=0, eviction_policy='evict_first')
166+
v_0 = val_x + val_y
167+
tl.store(out + indices_0 * out_stride_0, v_0, mask_0)
168+
169+
def kernel_with_override(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launcher):
170+
out = torch.empty_like(x)
171+
_BLOCK_SIZE_0 = 16
172+
_launcher(_helion_kernel_with_override, (triton.cdiv(x.size(0), _BLOCK_SIZE_0),), x, y, out, x.size(0), out.stride(0), x.stride(0), y.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)
173+
return out
174+
175+
--- assertExpectedJournal(TestEvictionPolicy.test_explicit_eviction_policy_overrides_tunable_indexing_tensor_descriptor)
176+
from __future__ import annotations
177+
178+
import torch
179+
import triton
180+
import triton.language as tl
181+
from helion.runtime import default_launcher as _default_launcher
182+
37183
@triton.jit
38184
def _helion_kernel_with_override(x, y, out, x_size_0, out_stride_0, x_stride_0, y_stride_0, _BLOCK_SIZE_0: tl.constexpr):
39185
pid_0 = tl.program_id(0)
@@ -103,6 +249,29 @@ import triton
103249
import triton.language as tl
104250
from helion.runtime import default_launcher as _default_launcher
105251

252+
@triton.jit
253+
def _helion_copy_with_eviction(x, out, x_size_0, out_stride_0, x_stride_0, _BLOCK_SIZE_0: tl.constexpr):
254+
pid_0 = tl.program_id(0)
255+
offset_0 = pid_0 * _BLOCK_SIZE_0
256+
indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
257+
mask_0 = indices_0 < x_size_0
258+
val = tl.load(x + indices_0 * x_stride_0, mask_0, other=0, eviction_policy='evict_last')
259+
tl.store(out + indices_0 * out_stride_0, val, mask_0)
260+
261+
def copy_with_eviction(x: torch.Tensor, *, _launcher=_default_launcher):
262+
out = torch.empty_like(x)
263+
_BLOCK_SIZE_0 = 16
264+
_launcher(_helion_copy_with_eviction, (triton.cdiv(x.size(0), _BLOCK_SIZE_0),), x, out, x.size(0), out.stride(0), x.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)
265+
return out
266+
267+
--- assertExpectedJournal(TestEvictionPolicy.test_hl_load_eviction_policy_emitted_indexing_tensor_descriptor)
268+
from __future__ import annotations
269+
270+
import torch
271+
import triton
272+
import triton.language as tl
273+
from helion.runtime import default_launcher as _default_launcher
274+
106275
@triton.jit
107276
def _helion_copy_with_eviction(x, out, x_size_0, out_stride_0, x_stride_0, _BLOCK_SIZE_0: tl.constexpr):
108277
pid_0 = tl.program_id(0)
@@ -144,3 +313,82 @@ def kernel_multiple_loads(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor, *,
144313
_BLOCK_SIZE_0 = 16
145314
_launcher(_helion_kernel_multiple_loads, (triton.cdiv(x.size(0), _BLOCK_SIZE_0),), x, y, z, out, x.size(0), out.stride(0), x.stride(0), y.stride(0), z.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)
146315
return out
316+
317+
--- assertExpectedJournal(TestEvictionPolicy.test_multiple_loads_different_policies_indexing_block_ptr)
318+
from __future__ import annotations
319+
320+
import torch
321+
import triton
322+
import triton.language as tl
323+
from helion.runtime import default_launcher as _default_launcher
324+
325+
@triton.jit
326+
def _helion_kernel_multiple_loads(x, y, z, out, out_size_0, x_size_0, y_size_0, z_size_0, out_stride_0, x_stride_0, y_stride_0, z_stride_0, _BLOCK_SIZE_0: tl.constexpr):
327+
pid_0 = tl.program_id(0)
328+
offset_0 = pid_0 * _BLOCK_SIZE_0
329+
val_x = tl.load(tl.make_block_ptr(x, [x_size_0], [x_stride_0], [offset_0], [_BLOCK_SIZE_0], [0]), boundary_check=[0], padding_option='zero', eviction_policy='evict_first')
330+
val_y = tl.load(tl.make_block_ptr(y, [y_size_0], [y_stride_0], [offset_0], [_BLOCK_SIZE_0], [0]), boundary_check=[0], padding_option='zero', eviction_policy='evict_last')
331+
val_z = tl.load(tl.make_block_ptr(z, [z_size_0], [z_stride_0], [offset_0], [_BLOCK_SIZE_0], [0]), boundary_check=[0], padding_option='zero')
332+
v_0 = val_x + val_y
333+
v_1 = v_0 + val_z
334+
tl.store(tl.make_block_ptr(out, [out_size_0], [out_stride_0], [offset_0], [_BLOCK_SIZE_0], [0]), v_1, boundary_check=[0])
335+
336+
def kernel_multiple_loads(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor, *, _launcher=_default_launcher):
337+
out = torch.empty_like(x)
338+
_BLOCK_SIZE_0 = 16
339+
_launcher(_helion_kernel_multiple_loads, (triton.cdiv(x.size(0), _BLOCK_SIZE_0),), x, y, z, out, out.size(0), x.size(0), y.size(0), z.size(0), out.stride(0), x.stride(0), y.stride(0), z.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)
340+
return out
341+
342+
--- assertExpectedJournal(TestEvictionPolicy.test_multiple_loads_different_policies_indexing_pointer)
343+
from __future__ import annotations
344+
345+
import torch
346+
import triton
347+
import triton.language as tl
348+
from helion.runtime import default_launcher as _default_launcher
349+
350+
@triton.jit
351+
def _helion_kernel_multiple_loads(x, y, z, out, x_size_0, out_stride_0, x_stride_0, y_stride_0, z_stride_0, _BLOCK_SIZE_0: tl.constexpr):
352+
pid_0 = tl.program_id(0)
353+
offset_0 = pid_0 * _BLOCK_SIZE_0
354+
indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
355+
mask_0 = indices_0 < x_size_0
356+
val_x = tl.load(x + indices_0 * x_stride_0, mask_0, other=0, eviction_policy='evict_first')
357+
val_y = tl.load(y + indices_0 * y_stride_0, mask_0, other=0, eviction_policy='evict_last')
358+
val_z = tl.load(z + indices_0 * z_stride_0, mask_0, other=0)
359+
v_0 = val_x + val_y
360+
v_1 = v_0 + val_z
361+
tl.store(out + indices_0 * out_stride_0, v_1, mask_0)
362+
363+
def kernel_multiple_loads(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor, *, _launcher=_default_launcher):
364+
out = torch.empty_like(x)
365+
_BLOCK_SIZE_0 = 16
366+
_launcher(_helion_kernel_multiple_loads, (triton.cdiv(x.size(0), _BLOCK_SIZE_0),), x, y, z, out, x.size(0), out.stride(0), x.stride(0), y.stride(0), z.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)
367+
return out
368+
369+
--- assertExpectedJournal(TestEvictionPolicy.test_multiple_loads_different_policies_indexing_tensor_descriptor)
370+
from __future__ import annotations
371+
372+
import torch
373+
import triton
374+
import triton.language as tl
375+
from helion.runtime import default_launcher as _default_launcher
376+
377+
@triton.jit
378+
def _helion_kernel_multiple_loads(x, y, z, out, x_size_0, out_stride_0, x_stride_0, y_stride_0, z_stride_0, _BLOCK_SIZE_0: tl.constexpr):
379+
pid_0 = tl.program_id(0)
380+
offset_0 = pid_0 * _BLOCK_SIZE_0
381+
indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
382+
mask_0 = indices_0 < x_size_0
383+
val_x = tl.load(x + indices_0 * x_stride_0, mask_0, other=0, eviction_policy='evict_first')
384+
val_y = tl.load(y + indices_0 * y_stride_0, mask_0, other=0, eviction_policy='evict_last')
385+
val_z = tl.load(z + indices_0 * z_stride_0, mask_0, other=0)
386+
v_0 = val_x + val_y
387+
v_1 = v_0 + val_z
388+
tl.store(out + indices_0 * out_stride_0, v_1, mask_0)
389+
390+
def kernel_multiple_loads(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor, *, _launcher=_default_launcher):
391+
out = torch.empty_like(x)
392+
_BLOCK_SIZE_0 = 16
393+
_launcher(_helion_kernel_multiple_loads, (triton.cdiv(x.size(0), _BLOCK_SIZE_0),), x, y, z, out, x.size(0), out.stride(0), x.stride(0), y.stride(0), z.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)
394+
return out

test/test_eviction_policy.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,7 @@ def copy_with_eviction(x: torch.Tensor) -> torch.Tensor:
3333
x = torch.randn([128], device=DEVICE, dtype=torch.float32)
3434
code, result = code_and_output(copy_with_eviction, (x,))
3535
torch.testing.assert_close(result, x)
36-
if indexing != "tensor_descriptor":
37-
# TODO(oulgen): Update this on a machine that supports tensor_descriptor
38-
self.assertExpectedJournal(code)
36+
self.assertExpectedJournal(code)
3937
self.assertIn("eviction_policy", code)
4038
self.assertIn("evict_last", code)
4139

@@ -69,13 +67,18 @@ def kernel_with_loads(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
6967
self.assertIn("first", fragment.inner.choices)
7068
self.assertIn("last", fragment.inner.choices)
7169

72-
def test_eviction_policy_in_generated_code(self):
70+
@parametrize("indexing", ("pointer", "block_ptr", "tensor_descriptor"))
71+
def test_eviction_policy_in_generated_code(self, indexing: str):
7372
"""Test that eviction policies appear in generated code when configured."""
7473

74+
if indexing == "tensor_descriptor" and not supports_tensor_descriptor():
75+
self.skipTest("Tensor descriptor support is required")
76+
7577
@helion.kernel(
7678
config={
7779
"block_size": 16,
7880
"load_eviction_policies": ["", "last"],
81+
"indexing": indexing,
7982
}
8083
)
8184
def kernel_with_eviction(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
@@ -96,11 +99,16 @@ def kernel_with_eviction(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
9699
self.assertIn("evict_last", code)
97100
self.assertExpectedJournal(code)
98101

99-
def test_explicit_eviction_policy_overrides_tunable(self):
102+
@parametrize("indexing", ("pointer", "block_ptr", "tensor_descriptor"))
103+
def test_explicit_eviction_policy_overrides_tunable(self, indexing: str):
104+
if indexing == "tensor_descriptor" and not supports_tensor_descriptor():
105+
self.skipTest("Tensor descriptor support is required")
106+
100107
@helion.kernel(
101108
config={
102109
"block_size": 16,
103110
"load_eviction_policies": ["first", "first"],
111+
"indexing": indexing,
104112
}
105113
)
106114
def kernel_with_override(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
@@ -121,11 +129,16 @@ def kernel_with_override(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
121129
self.assertIn("evict_last", code)
122130
self.assertExpectedJournal(code)
123131

124-
def test_multiple_loads_different_policies(self):
132+
@parametrize("indexing", ("pointer", "block_ptr", "tensor_descriptor"))
133+
def test_multiple_loads_different_policies(self, indexing: str):
134+
if indexing == "tensor_descriptor" and not supports_tensor_descriptor():
135+
self.skipTest("Tensor descriptor support is required")
136+
125137
@helion.kernel(
126138
config={
127139
"block_size": 16,
128140
"load_eviction_policies": ["first", "last", ""],
141+
"indexing": indexing,
129142
}
130143
)
131144
def kernel_multiple_loads(

0 commit comments

Comments
 (0)