
Commit 97f32cc

Merge OpenAI Triton commit 318fa9c (#5498)
This PR updates the Triton base from b3cf593 to 318fa9c (Nov 2). Pass rate: 94.95% -> 98.1%.

2 parents a3c2bc0 + ea925e1; commit 97f32cc

21 files changed: +1044 additions, -181 deletions

.github/workflows/llvm-build.yml

Lines changed: 15 additions & 8 deletions
@@ -9,6 +9,8 @@ on:
   pull_request:
     paths:
       - .github/workflows/llvm-build.yml
+      - .github/workflows/llvm-build/almalinux.Dockerfile
+      - .github/workflows/llvm-build/centos.Dockerfile
   workflow_dispatch:

 env:
@@ -135,6 +137,7 @@ jobs:
           -DLLVM_INSTALL_UTILS=ON
           -DLLVM_TARGETS_TO_BUILD="host;NVPTX;AMDGPU"
           -DLLVM_ENABLE_TERMINFO=OFF
+          -DLLVM_ENABLE_ZSTD=OFF
           llvm-project/llvm

         ninja -C llvm-project/build check-mlir install
@@ -237,7 +240,11 @@ jobs:
        run: |
          # if this step crashes, it can leave behind a stale docker container
          docker container prune -f
-         docker rmi -f $(docker images -q)
+
+         images=$(docker images -q)
+         if [ -n "$images" ]; then
+           docker rmi -f $images
+         fi

          docker build --tag llvm-build --build-arg llvm_dir=llvm-project \
            -f llvm-build/.github/workflows/llvm-build/almalinux.Dockerfile .
@@ -264,16 +271,16 @@ jobs:
           path: |
             ${{ github.workspace }}/llvm-*-${{ matrix.config.target-os }}-${{ matrix.config.arch }}.tar.gz

-      - name: Azure Login
-        if: ${{ (github.repository == 'triton-lang/triton') }}
+      - name: Azure login
+        if: ${{ (github.repository == 'triton-lang/triton') && github.ref_name == 'llvm-head' }}
         uses: azure/login@v2
         with:
-          client-id: ${{ secrets.AZURE_CLIENT_ID }}
-          tenant-id: ${{ secrets.AZURE_TENANT_ID }}
-          subscription-id: ${{ secrets.AZURE_SUBSCRIPTION_ID }}
+          client-id: ${{ secrets.AZURE_CLIENT_ID_LLVM }}
+          tenant-id: ${{ secrets.AZURE_TENANT_ID_LLVM }}
+          subscription-id: ${{ secrets.AZURE_SUBSCRIPTION_ID_LLVM }}

       - name: Upload LLVM Artifacts to Azure
-        if: ${{ (github.repository == 'triton-lang/triton') }}
+        if: ${{ (github.repository == 'triton-lang/triton') && github.ref_name == 'llvm-head' }}
         shell: bash -el {0}
         run: |
           az storage blob upload --account-name oaitriton --auth-mode login --container-name public --file "${{ env.llvm_install_dir }}.tar.gz" --name "llvm-builds/${{ env.llvm_install_dir }}.tar.gz" --overwrite
@@ -282,7 +289,7 @@ jobs:
           echo "Blob URL: ${URL}"

       - name: Azure Logout
-        if: ${{ (github.repository == 'triton-lang/triton') }}
+        if: ${{ (github.repository == 'triton-lang/triton') && github.ref_name == 'llvm-head' }}
         run: |
           az logout
           az cache purge
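
Note: the docker cleanup change above exists because a bare `docker rmi -f $(docker images -q)` fails when `docker images -q` prints nothing; the new guard only runs `docker rmi` when at least one image ID is returned. A minimal Python sketch of the same guard, assuming only that the docker CLI is on PATH (the helper name is made up for illustration, not part of the workflow):

    import subprocess

    def remove_all_docker_images() -> None:
        # List all image IDs; an empty result means there is nothing to remove.
        result = subprocess.run(["docker", "images", "-q"],
                                capture_output=True, text=True, check=True)
        image_ids = result.stdout.split()
        if not image_ids:
            return  # mirrors the `if [ -n "$images" ]` guard: skip `docker rmi` entirely
        subprocess.run(["docker", "rmi", "-f", *image_ids], check=True)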

.github/workflows/llvm-build/almalinux.Dockerfile

Lines changed: 6 additions & 1 deletion
@@ -1,4 +1,5 @@
-FROM almalinux:8
+# https://github.com/AlmaLinux/container-images/blob/9f9b3c8c8cf4a57fd42f362570ff47c75788031f/default/amd64/Dockerfile
+FROM almalinux:8.10-20250411
 ARG llvm_dir=llvm-project
 # Add the cache artifacts and the LLVM source tree to the container
 ADD sccache /sccache
@@ -8,6 +9,7 @@ ENV SCCACHE_CACHE_SIZE="2G"

 RUN dnf install --assumeyes llvm-toolset
 RUN dnf install --assumeyes python38-pip python38-devel git
+RUN alternatives --set python3 /usr/bin/python3.8

 RUN python3 -m pip install --upgrade pip
 RUN python3 -m pip install --upgrade cmake ninja sccache lit
@@ -26,6 +28,8 @@ RUN cmake -GNinja -Bbuild \
     -DCMAKE_CXX_FLAGS="-Wno-everything" \
     -DCMAKE_LINKER=lld \
     -DCMAKE_INSTALL_PREFIX="/install" \
+    -DPython3_EXECUTABLE="/usr/bin/python3.8" \
+    -DPython_EXECUTABLE="/usr/bin/python3.8" \
     -DLLVM_BUILD_UTILS=ON \
     -DLLVM_BUILD_TOOLS=ON \
     -DLLVM_ENABLE_ASSERTIONS=ON \
@@ -34,6 +38,7 @@ RUN cmake -GNinja -Bbuild \
     -DLLVM_ENABLE_TERMINFO=OFF \
     -DLLVM_INSTALL_UTILS=ON \
     -DLLVM_TARGETS_TO_BUILD="host;NVPTX;AMDGPU" \
+    -DLLVM_ENABLE_ZSTD=OFF \
     /source/llvm-project/llvm

 RUN ninja -C build install

python/test/gluon/test_core.py

Lines changed: 2 additions & 1 deletion
@@ -542,8 +542,9 @@ def kernel(a_ptr, b_ptr, use_buffer_load: ttgl.constexpr):
         cdna4_async_copy.buffer_load_to_shared(smem, a_ptr, offsets)
     else:
         cdna4_async_copy.global_load_to_shared(smem, a_ptr + offsets)
+    cdna4_async_copy.commit_group()

-    cdna4_async_copy.async_wait(0)
+    cdna4_async_copy.wait_group(0)
     a = cdna4_async_copy.load_shared_relaxed(smem, blocked)

     ttgl.store(b_ptr + offsets, a)

python/test/gluon/test_frontend.py

Lines changed: 23 additions & 4 deletions
@@ -1950,17 +1950,36 @@ def test_infer_layout_for_amd_wmma(target):


 @gluon.jit
-def amd_async_wait():
-    cdna4_async_copy.async_wait(0)
+def amd_commit_group():
+    cdna4_async_copy.commit_group()
+
+
+@pytest.mark.parametrize("target", [HIP_TARGET_CDNA4])
+def test_amd_commit_group(target):
+    mod = run_parser(amd_wait_group, target=target)
+    expecttest.assert_expected_inline(
+        anonymize_ir(mod.str_nodebug()), """\
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 64 : i32} {
+  tt.func public @amd_wait_group() attributes {noinline = false} {
+    %0 = ttg.async_wait {num = 0 : i32}
+    tt.return
+  }
+}
+""")
+
+
+@gluon.jit
+def amd_wait_group():
+    cdna4_async_copy.wait_group(0)


 @pytest.mark.parametrize("target", [HIP_TARGET_CDNA4])
 def test_amd_async_wait(target):
-    mod = run_parser(amd_async_wait, target=target)
+    mod = run_parser(amd_wait_group, target=target)
     expecttest.assert_expected_inline(
         anonymize_ir(mod.str_nodebug()), """\
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 64 : i32} {
-  tt.func public @amd_async_wait() attributes {noinline = false} {
+  tt.func public @amd_wait_group() attributes {noinline = false} {
     %0 = ttg.async_wait {num = 0 : i32}
     tt.return
   }

python/triton/experimental/gluon/language/amd/cdna4/async_copy.py

Lines changed: 25 additions & 10 deletions
@@ -6,7 +6,8 @@
 __all__ = [
     "global_load_to_shared",
     "buffer_load_to_shared",
-    "async_wait",
+    "commit_group",
+    "wait_group",
     "load_shared_relaxed",
 ]

@@ -17,7 +18,10 @@ def global_load_to_shared(dest, ptr, mask=None, other=None, cache_modifier="", _
     AMD global load to shared operation. This operation loads data directly
     from global memory to shared memory without going through registers. It
     happens asynchronously and requires a subsequent `async_wait` to ensure the
-    data is available in shared memory.
+    data is available in shared memory. Note that this operation does still
+    complete in order with ttgl.loads/stores or buffer_loads/stores on CDNA4,
+    so interleaving with them will hurt performance.
+
     Compared to `buffer_load_to_shared`, it requires a tensor pointer which
     supports 64-bit indexing range for each thread in a block, which gives more
     flexibility, but at the cost of higher register pressure and no hardware
@@ -72,7 +76,10 @@ def buffer_load_to_shared(dest, ptr, offsets, mask=None, other=None, cache_modif
     32-bit offsets instead of a tensor of pointers. This operation loads data
     directly from global memory to shared memory without going through
     registers. It happens asynchronously and requires a subsequent `async_wait`
-    to ensure the data is available in shared memory.
+    to ensure the data is available in shared memory. Note that this operation
+    does still complete in order with ttgl.loads/stores or buffer_loads/stores
+    on CDNA4, so interleaving with them will hurt performance.
+
     Compared to `global_load_to_shared`, it has better performance and also
     supports hardware out-of-bound masking. But it strictly requires a
     32-bit offset instead of a 64-bit tensor pointer.
@@ -118,16 +125,24 @@ def buffer_load_to_shared(dest, ptr, offsets, mask=None, other=None, cache_modif


 @builtin
-def async_wait(num_outstanding=0, _semantic=None):
+def commit_group(_semantic=None):
+    """
+    Commit outstanding async operations.
+
+    This finalizes a set of async copy operations which can be waited upon via `wait_group`.
+    """
+    _semantic.builder.create_async_commit_group()
+
+
+@builtin
+def wait_group(num_outstanding=0, _semantic=None):
     """
-    Wait for outstanding memory operations, this includes normal load like
-    `load` and `buffer_load`, as well as direct load to shared memory
-    like `global_load_to_shared` and `buffer_load_to_shared`.
-    It will block until the number of outstanding memory operations is less than
-    or equal to `num_outstanding`.
+    Wait for outstanding commit groups. It will block until the number of
+    outstanding commit groups is less than or equal to `num_outstanding`. Note that uncommitted
+    async operations will be waited upon even if `num_outstanding` is 0.

     Args:
-        num_outstanding (int): The number of outstanding operations to wait for. Defaults to 0.
+        num_outstanding (int): The number of outstanding commit groups to wait for. Defaults to 0.
     """
     num_outstanding = _unwrap_if_constexpr(num_outstanding)
     _semantic.builder.create_async_wait_group(num_outstanding)
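
Note: the old single `async_wait` entry point is replaced by an explicit commit/wait pair: issue the async copies, close them into a group with `commit_group`, then block with `wait_group`. A minimal kernel-body sketch of the intended call order (names such as `smem`, `a_ptr`, `b_ptr`, `offsets`, and `blocked` are assumed to come from the surrounding Gluon kernel, as in the test_core.py change above; this is an illustration, not a complete kernel):

    # Inside a @gluon.jit kernel on CDNA4 (sketch; kernel setup omitted):
    cdna4_async_copy.global_load_to_shared(smem, a_ptr + offsets)  # start the async copy
    cdna4_async_copy.commit_group()   # finalize the issued copies into one commit group
    cdna4_async_copy.wait_group(0)    # block until no commit groups are outstanding
    a = cdna4_async_copy.load_shared_relaxed(smem, blocked)        # shared memory is now safe to read
    ttgl.store(b_ptr + offsets, a)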

python/triton_kernels/tests/test_matmul.py

Lines changed: 18 additions & 26 deletions
@@ -303,27 +303,27 @@ class Case:
     ],
 )
 @pytest.mark.parametrize("block_m", [16, 128])
-@pytest.mark.parametrize("do_gather, do_scatter, fused_scatter, inner_expt_opt", [
-    (False, False, False, None),
-    (True, False, False, None),
-    (False, True, False, None),
-    (False, True, True, None),
-    (True, True, False, None),
-    (True, True, True, None),
-    (False, False, False, "pad_w"),
-    (False, False, False, "pad_x"),
+@pytest.mark.parametrize("do_gather, do_scatter, inner_expt_opt", [
+    (False, False, None),
+    (True, False, None),
+    (False, True, None),
+    (False, True, None),
+    (True, True, None),
+    (True, True, None),
+    (False, False, "pad_w"),
+    (False, False, "pad_x"),
 ])
 @pytest.mark.parametrize("has_y_gammas", [False, True])
 @pytest.mark.parametrize("is_persistent", [False, True])
-def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, inner_expt_opt, has_y_gammas, is_persistent, n_expts_tot,
+def test_op(m, n, k, split_k, do_gather, do_scatter, inner_expt_opt, has_y_gammas, is_persistent, n_expts_tot,
             n_expts_act, mode, act_dtype_str, weight_dtype_str, block_m, hbm_swizzling, colmajor_mxfp_weight, epilogue_subtile,
             x_transpose, w_transpose, y_transpose,
             device, opt_flags_scope):
     # We catch and re-invoke pytest.skip(), because otherwise pytest may hold a reference to
     # the frame that called pytest.skip, including all the tensors, leading to OOM.
     skip_message = None
     try:
-        _test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, inner_expt_opt, has_y_gammas, is_persistent, n_expts_tot,
+        _test_op(m, n, k, split_k, do_gather, do_scatter, inner_expt_opt, has_y_gammas, is_persistent, n_expts_tot,
                  n_expts_act, mode, act_dtype_str, weight_dtype_str, block_m, hbm_swizzling, colmajor_mxfp_weight, epilogue_subtile,
                  x_transpose, w_transpose, y_transpose,
                  device, opt_flags_scope)
@@ -333,7 +333,7 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, inner_expt_o
     if skip_message is not None:
         pytest.skip(skip_message)

-def _test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, inner_expt_opt, has_y_gammas, is_persistent, n_expts_tot,
+def _test_op(m, n, k, split_k, do_gather, do_scatter, inner_expt_opt, has_y_gammas, is_persistent, n_expts_tot,
             n_expts_act, mode, act_dtype_str, weight_dtype_str, block_m, hbm_swizzling, colmajor_mxfp_weight, epilogue_subtile,
             x_transpose, w_transpose, y_transpose,
             device, opt_flags_scope):
@@ -362,9 +362,6 @@ def _test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, inner_expt_
     if "float8_e4m3fnuz" in (weight_dtype_str, act_dtype_str) and not is_hip_cdna3():
         pytest.xfail("float8_e4m3fnuz only tested on AMD CDNA3 Platform")

-    if fused_scatter and split_k is not None and split_k > 1:
-        pytest.xfail("fused scatter scratchpad not supported with split_k")
-
     if hbm_swizzling:
         if is_hip():
             if not is_hip_cdna4():
@@ -414,7 +411,6 @@ def _test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, inner_expt_
         "block_m": block_m,
         "block_k": block_k,
         "split_k": split_k,
-        "fused_scatter": fused_scatter,
         "is_persistent": is_persistent,
         "epilogue_subtile": epilogue_subtile,
     }
@@ -727,12 +723,11 @@ def test_set_idle_sms():
     (800, 800, 400, "batched"),
 ])
 @pytest.mark.parametrize("split_k", [1, 2])
-@pytest.mark.parametrize("do_gather, do_scatter, fused_scatter", [
-    (False, False, False),
-    (True, False, False),
-    (False, True, False),
-    (True, True, False),
-    (True, True, True),
+@pytest.mark.parametrize("do_gather, do_scatter", [
+    (False, False),
+    (True, False),
+    (False, True),
+    (True, True),
 ])
 @pytest.mark.parametrize("is_persistent, epilogue_subtile", [
     (False, None),
@@ -744,16 +739,13 @@ def test_set_idle_sms():
     (1.0, 1.2),
     (0.7, 1.0),
 ])
-def test_fused_act(m, n, k, mode, split_k, do_gather, do_scatter, fused_scatter, is_persistent, epilogue_subtile,
+def test_fused_act(m, n, k, mode, split_k, do_gather, do_scatter, is_persistent, epilogue_subtile,
                    swiglu_alpha, swiglu_limit, device, opt_flags_scope):
-    if fused_scatter and split_k > 1:
-        pytest.xfail("fused scatter scratchpad not supported with split_k")
     torch.manual_seed(0)
     constraints = {
         "is_persistent": is_persistent,
         "epilogue_subtile": epilogue_subtile,
         "split_k": split_k,
-        "fused_scatter": fused_scatter,
     }
     n_expts_tot, n_expts_act = 1, 1
     opt_flags.update_opt_flags_constraints(constraints)

python/triton_kernels/tests/test_matmul_details/test_opt_flags_split_k.py

Lines changed: 1 addition & 1 deletion
@@ -185,7 +185,7 @@ def get_flags(split_k, max_mn):
         k,
         None,
         False,
-        False,
+        True,
         False,
         0,
         False,

python/triton_kernels/triton_kernels/matmul_ogs.py

Lines changed: 8 additions & 7 deletions
@@ -419,9 +419,10 @@ def matmul_ogs(x, w, bias,
     has_gather_tma = has_gather and target_info.has_tma_gather()
     # hopper w/ mxfp4 doesn't support TMA
     can_use_tma = can_use_tma and is_cuda() and (torch.cuda.get_device_capability()[0] > 9 or bitwidth(w.dtype) != 4)
+    can_use_split_k = scatter_indx is None and not x_has_mx and not w_has_mx
     opt_flags = make_opt_flags(out_dtype, x.dtype, w.dtype, precision_config,
                                batch_size, M, N, w.shape[-2], routing_data,
-                               can_use_tma, scatter_indx is not None, epilogue.effective_itemsize,
+                               can_use_tma, can_use_split_k, epilogue.effective_itemsize,
                                x_transpose, y_acc_in is not None,
                                inner_routing_data.block_k if inner_routing_data is not None else None,
                                )
@@ -618,21 +619,21 @@ def matmul_ogs(x, w, bias,
         **fused_comm_kwargs,
         **opt_flags.target_kernel_kwargs)

+    assert not (opt_flags.split_k > 1 and scatter_indx is not None)
     out_final_mx_scale = None
     if opt_flags.split_k > 1:
         assert not out_matmul_has_mx
-        has_scatter = scatter_indx is not None
         postprocess_fn1 = ReducePostprocessFn(specs=reduce_fused_activation.specs, fn_args=reduce_fused_activation.fn_args)
         postprocess_fn2 = None if has_scatter else ReducePostprocessFn(specs=epilogue.specs, fn_args=epilogue.fn_arg_values_finalize)
         y, y_mx_scale = reduce(
             x = out_matmul.view(out_matmul.shape[0], -1, out_matmul.shape[-1]),
             dim = 0,
             # output data/metadata
-            y = None if has_scatter else memory["output"].view(-1, memory["output"].shape[-1]),
-            y_dtype = out_matmul.dtype if has_scatter else memory["output"].dtype,
-            y_flex = OutFlexData() if has_scatter else precision_config.flex_ctx.out_data,
-            y_flex_saturate_inf = None if has_scatter else precision_config.flexpoint_saturate_inf,
-            y_has_mx = scatter_indx is None and precision_config.out_scale is not None,
+            y = memory["output"].view(-1, memory["output"].shape[-1]),
+            y_dtype = memory["output"].dtype,
+            y_flex = precision_config.flex_ctx.out_data,
+            y_flex_saturate_inf = precision_config.flexpoint_saturate_inf,
+            y_has_mx = precision_config.out_scale is not None,
             # fused functions
             postprocess_fn1 = postprocess_fn1,
             postprocess_fn2 = postprocess_fn2,
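
Note: the net effect of this change is that whether split-K is allowed is now decided up front (no scatter index and no mx-scaled operands) and passed to `make_opt_flags`, so the split-K reduction path no longer needs the per-branch scatter special cases. A small self-contained sketch of that gating predicate, using placeholder values rather than the real call sites:

    def can_use_split_k(scatter_indx, x_has_mx: bool, w_has_mx: bool) -> bool:
        # Split-K produces partial results that are reduced afterwards; the fused
        # scatter path and mx-scaled operands are excluded from that path, so the
        # decision is made once before the opt-flags heuristic runs.
        return scatter_indx is None and not x_has_mx and not w_has_mx

    # Placeholder examples: a scatter index or an mx-scaled operand disables split-K.
    assert can_use_split_k(None, False, False) is True
    assert can_use_split_k(object(), False, False) is False
    assert can_use_split_k(None, True, False) is False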
