21 | 21 | # testing utilities |
22 | 22 | from triton_kernels.testing import assert_close, compute_actual_scale |
23 | 23 | # target-specific utilities |
24 | | -from triton_kernels.target_info import is_hip, is_hip_cdna3, is_cuda, is_hip_cdna4 |
25 | | - |
| 24 | +from triton_kernels.target_info import is_hip, is_xpu, is_hip_cdna3, is_cuda, is_hip_cdna4 |
| 25 | +from icecream import ic |
26 | 26 | # --------------- |
27 | 27 | # initialize data |
28 | 28 | # --------------- |
@@ -168,100 +168,12 @@ class Case: |
168 | 168 | ", ".join(f.name for f in fields(Case)), |
169 | 169 | [ |
170 | 170 | tuple(getattr(case, f.name) for f in fields(Case)) for case in [ |
171 | | - # Zero-sized args: |
172 | | - Case(0, 5, 7, "ragged", "float16", "float16"), |
173 | | - Case(5, 0, 7, "ragged", "float16", "float16"), |
174 | | - Case(5, 7, 0, "ragged", "float16", "float16"), |
175 | | - Case(0, 5, 7, "batched", "float16", "float16"), |
176 | | - Case(5, 0, 7, "batched", "float16", "float16"), |
177 | | - Case(5, 7, 0, "batched", "float16", "float16"), |
178 | | - # Non-mx types: |
179 | | - Case(16, 256, 256, "ragged", "float16", "float16", 128, 4), |
180 | | - Case(16, 256, 256, "ragged", "float16", "float16", 128, 4, n_expt_shards=2), |
181 | | - Case(16, 256, 256, "ragged", "float16", "float16", 128, 4, n_expt_shards=4), |
182 | | - Case(16, 256, 256, "ragged", "float16", "float16", 4, 1, n_expt_shards=2), |
183 | | - Case(16, 256, 256, "ragged", "float16", "float16", 128, 4, split_k=3), |
184 | | - Case(16, 256, 256, "ragged", "float16", "float16", 128, 4, split_k=3), |
185 | | - Case(300, 400, 400, "batched", "float8_e5m2", "float8_e5m2", 5, 1), |
186 | | - Case(16, 256, 256, "batched", "float16", "float16", 5, 1), |
187 | | - Case(16, 256, 256, "ragged", "float16", "float16", 3, 1), |
188 | | - Case(256, 256, 256, "ragged", "float16", "float16", 4, 1), |
189 | | - Case(256, 256, 256, "ragged", "float16", "float16", 4, 1, split_k=3), |
190 | | - Case(300, 400, 400, "batched", "float16", "float16", 5, 1), |
191 | | - Case(300, 400, 400, "ragged", "float16", "float16"), |
192 | | - Case(300, 400, 400, "ragged", "float8_e5m2", "float8_e5m2"), |
193 | | - Case(1000, 400, 400, "ragged", "float8_e5m2", "float8_e5m2", 3, 1), |
194 | | - Case(600, 400, 400, "ragged", "float8_e5m2", "float8_e5m2", 4, 2, epilogue_subtile=1), |
195 | | - Case(600, 400, 400, "ragged", "float8_e5m2", "float8_e5m2", 4, 2, epilogue_subtile=2), |
196 | | - Case(600, 400, 400, "ragged", "float8_e5m2", "float8_e5m2", 4, 2, epilogue_subtile=4), |
197 | | - Case(600, 400, 400, "ragged", "float8_e5m2", "float8_e5m2", 4, 2), |
198 | | - Case(600, 400, 400, "ragged", "float8_e5m2", "float8_e5m2", 4, 2, n_expt_shards=2), |
199 | | - Case(600, 400, 400, "ragged", "float8_e5m2", "float8_e5m2", 4, 1, n_expt_shards=2), |
200 | | - Case(600, 400, 400, "ragged", "float8_e5m2", "float8_e5m2", 4, 2, split_k=2), |
201 | | - Case(1000, 400, 400, "ragged", "float16", "float16", 3, 1), |
202 | | - Case(1000, 700, 700, "ragged", "float16", "float16", 8, 2), |
203 | | - Case(1000, 700, 700, "ragged", "float16", "float16", 8, 2, split_k=9), |
204 | | - # mx types: |
205 | | - Case(16, 256, 256, "plain", "bfloat16", "mxfloat4_e2m1", 1, 1), |
206 | | - Case(16, 256, 256, "plain", "bfloat16", "mxfloat4_e2m1", 1, 1, hbm_swizzling=True), |
207 | | - Case(16, 256, 256, "ragged", "bfloat16", "mxfloat4_e2m1", 1, 1), |
208 | | - Case(16, 256, 256, "ragged", "bfloat16", "mxfloat4_e2m1", 1, 1, hbm_swizzling=True), |
209 | | - Case(1000, 700, 700, "batched", "bfloat16", "mxfloat4_e2m1", 8, 2), |
210 | | - Case(1000, 700, 700, "batched", "bfloat16", "mxfloat4_e2m1", 8, 2, hbm_swizzling=True), |
211 | | - Case(1000, 700, 700, "ragged", "bfloat16", "mxfloat4_e2m1", 8, 2, split_k=9), |
212 | | - Case(1000, 512, 256, "ragged", "bfloat16", "mxfloat4_e2m1", 8, 2, split_k=9, hbm_swizzling=True), |
213 | | - Case(300, 400, 400, "ragged", "bfloat16", "mxfloat8_e4m3fn", 8, 4), |
214 | | - Case(300, 400, 400, "ragged", "bfloat16", "mxfloat8_e4m3fn", 8, 4, hbm_swizzling=True), |
215 | | - Case(300, 400, 400, "batched", "bfloat16", "mxfloat8_e5m2", 32, 4), |
216 | | - Case(1000, 700, 2, "batched", "bfloat16", "mxfloat4_e2m1", 8, 2), |
217 | | - Case(1, 2880, 2880, "ragged", "bfloat16", "mxfloat4_e2m1", 128, 4), |
218 | | - Case(16, 256, 256, "ragged", "float8_e5m2", "mxfloat4_e2m1", 128, 4, hbm_swizzling=True), |
219 | | - Case(1000, 704, 832, "batched", "float8_e5m2", "mxfloat4_e2m1", 3, 1, hbm_swizzling=True), |
220 | | - Case(1000, 704, 832, "batched", "float8_e5m2", "mxfloat4_e2m1", 3, 1, hbm_swizzling=True), |
221 | | - Case(1000, 704, 832, "batched", "float8_e5m2", "mxfloat4_e2m1", 3, 1), |
222 | | - Case(1000, 704, 800, "ragged", "float8_e5m2", "mxfloat4_e2m1", 8, 2, split_k=9), |
223 | | - Case(1000, 704, 800, "ragged", "float8_e5m2", "mxfloat4_e2m1", 8, 2, split_k=9, hbm_swizzling=True), |
224 | | - Case(1000, 704, 800, "ragged", "float8_e5m2", "mxfloat4_e2m1", 8, 2), |
225 | | - Case(1000, 704, 800, "ragged", "float8_e5m2", "mxfloat4_e2m1", 8, 2, hbm_swizzling=True), |
226 | | - Case(300, 400, 400, "ragged", "float8_e5m2", "mxfloat8_e4m3fn", 8, 4), |
227 | | - Case(300, 400, 400, "ragged", "float8_e5m2", "mxfloat8_e4m3fn", 8, 4, hbm_swizzling=True), |
228 | | - Case(300, 400, 832, "ragged", "float8_e5m2", "mxfloat4_e2m1", 8, 4), |
229 | | - Case(300, 400, 832, "ragged", "float8_e5m2", "mxfloat4_e2m1", 8, 4, hbm_swizzling=True), |
230 | | - Case(300, 400, 400, "batched", "float8_e5m2", "mxfloat8_e4m3fn", 32, 4), |
231 | | - Case(300, 400, 400, "batched", "float8_e5m2", "mxfloat8_e4m3fn", 32, 4, hbm_swizzling=True), |
232 | | - Case(256, 256, 256, "ragged", "float8_e5m2", "mxfloat4_e2m1", 128, 4, hbm_swizzling=True), |
233 | | - Case(256, 256, 256, "ragged", "float8_e5m2", "mxfloat4_e2m1", 128, 4, hbm_swizzling=False), |
234 | | - Case(16, 256, 256, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 128, 4, hbm_swizzling=True), |
235 | | - Case(1000, 704, 800, "batched", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 3, 1, hbm_swizzling=True), |
236 | | - Case(1000, 704, 800, "batched", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 2, 1), |
237 | 171 | Case(1000, 704, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 2, split_k=9), |
238 | | - Case(1000, 704, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 2, split_k=9, hbm_swizzling=True), |
239 | | - Case(1000, 704, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 2), |
240 | | - Case(1000, 704, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 2, hbm_swizzling=True), |
241 | | - Case(300, 400, 400, "ragged", "mxfloat8_e4m3fn", "mxfloat8_e4m3fn", 8, 4), |
242 | | - Case(300, 512, 512, "ragged", "mxfloat8_e4m3fn", "mxfloat8_e4m3fn", 8, 4), |
243 | | - Case(300, 400, 400, "ragged", "mxfloat8_e4m3fn", "mxfloat8_e4m3fn", 8, 4, hbm_swizzling=True), |
244 | | - Case(300, 400, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 4), |
245 | | - Case(300, 400, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 4, hbm_swizzling=True), |
246 | | - Case(300, 400, 400, "batched", "mxfloat8_e4m3fn", "mxfloat8_e4m3fn", 32, 4), |
247 | | - Case(300, 400, 400, "batched", "mxfloat8_e4m3fn", "mxfloat8_e4m3fn", 32, 4, hbm_swizzling=True), |
248 | | - # AMD |
249 | | - Case(300, 400, 400, "ragged", "float8_e4m3fnuz", "float8_e4m3fnuz"), |
250 | | - Case(1000, 400, 400, "ragged", "float8_e4m3fnuz", "float8_e4m3fnuz", 3, 1), |
251 | | - Case(600, 400, 400, "ragged", "float8_e4m3fnuz", "float8_e4m3fnuz", 4, 2), |
252 | | - Case(600, 400, 400, "ragged", "float8_e4m3fnuz", "float8_e4m3fnuz", 4, 2, n_expt_shards=2), |
253 | | - Case(600, 400, 400, "ragged", "float8_e4m3fnuz", "float8_e4m3fnuz", 4, 2, split_k=2), |
254 | | - Case(300, 400, 400, "ragged", "float8_e4m3fn", "float8_e4m3fn"), |
255 | | - Case(1000, 400, 400, "ragged", "float8_e4m3fn", "float8_e4m3fn", 3, 1), |
256 | | - Case(600, 400, 400, "ragged", "float8_e4m3fn", "float8_e4m3fn", 4, 2), |
257 | | - Case(600, 400, 400, "ragged", "float8_e4m3fn", "float8_e4m3fn", 4, 2, n_expt_shards=2), |
258 | | - ] + [ |
259 | | - Case(320, 400, 400, mode, dtype, dtype, x_transpose=x_transpose, w_transpose=w_transpose, y_transpose=y_transpose) |
260 | | - for mode in ("batched", "ragged") |
261 | | - for dtype in ("float16", "float8_e5m2") |
262 | | - for x_transpose in (False, True) |
263 | | - for w_transpose in (False, True) |
264 | | - for y_transpose in (False, True) |
| 172 | + #Case(1000, 704, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 2, split_k=9), |
| 173 | + Case(1000, 704, 800, "ragged", "bfloat16", "mxfloat4_e2m1", 8, 2, split_k=9), |
| 174 | + #Case(1000, 704, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 2, split_k=9, hbm_swizzling=True), |
| 175 | +
| 176 | + #Case(1111, 400, 400, "ragged", "float8_e4m3fn", "float8_e4m3fn", 3, 1), |
265 | 177 | ] |
266 | 178 | ], |
267 | 179 | ) |
@@ -355,6 +267,7 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, has_y_gammas |
355 | 267 | "block_k": 256 |
356 | 268 | }) |
357 | 269 |
| 270 | + ic(constraints) |
358 | 271 | opt_flags.update_opt_flags_constraints(constraints) |
359 | 272 |
360 | 273 | weight_mxfp = weight_dtype_str.startswith("mx") |
@@ -555,16 +468,6 @@ def _make_tensor(shape, dtype, trans): |
555 | 468 | ) |
556 | 469 |
557 | 470 |
558 | | -def test_set_idle_sms(): |
559 | | - if not is_cuda(): |
560 | | - pytest.skip("Only supported on CUDA") |
561 | | - from triton_kernels.matmul_ogs_details.opt_flags import make_opt_flags |
562 | | - num_idle_sms = 24 |
563 | | - matmul_ogs_set_idle_sms(num_idle_sms) |
564 | | - flags = make_opt_flags(torch.float32, torch.float32, torch.float32, PrecisionConfig(), \ |
565 | | - 1, 1024, 1024, 1024, None, True, False, 1, False) |
566 | | - assert flags.idle_sms == num_idle_sms |
567 | | - |
568 | 471 |
569 | 472 | @pytest.mark.parametrize("m, n, k, mode", [ |
570 | 473 | (1200, 704, 608, "ragged"), |