23 changes: 15 additions & 8 deletions .github/workflows/llvm-build.yml
@@ -9,6 +9,8 @@ on:
pull_request:
paths:
- .github/workflows/llvm-build.yml
- .github/workflows/llvm-build/almalinux.Dockerfile
- .github/workflows/llvm-build/centos.Dockerfile
workflow_dispatch:

env:
@@ -135,6 +137,7 @@ jobs:
-DLLVM_INSTALL_UTILS=ON
-DLLVM_TARGETS_TO_BUILD="host;NVPTX;AMDGPU"
-DLLVM_ENABLE_TERMINFO=OFF
-DLLVM_ENABLE_ZSTD=OFF
llvm-project/llvm

ninja -C llvm-project/build check-mlir install
@@ -237,7 +240,11 @@ jobs:
run: |
# if this step crashes, it can leave behind a stale docker container
docker container prune -f
docker rmi -f $(docker images -q)

images=$(docker images -q)
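# "docker rmi" errors out when given an empty image list, so only remove images if any exist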
if [ -n "$images" ]; then
docker rmi -f $images
fi

docker build --tag llvm-build --build-arg llvm_dir=llvm-project \
-f llvm-build/.github/workflows/llvm-build/almalinux.Dockerfile .
@@ -264,16 +271,16 @@ jobs:
path: |
${{ github.workspace }}/llvm-*-${{ matrix.config.target-os }}-${{ matrix.config.arch }}.tar.gz

- name: Azure Login
if: ${{ (github.repository == 'triton-lang/triton') }}
- name: Azure login
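# authenticate (and upload below) only from the canonical repo when building the llvm-head branch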
if: ${{ (github.repository == 'triton-lang/triton') && github.ref_name == 'llvm-head' }}
uses: azure/login@v2
with:
client-id: ${{ secrets.AZURE_CLIENT_ID }}
tenant-id: ${{ secrets.AZURE_TENANT_ID }}
subscription-id: ${{ secrets.AZURE_SUBSCRIPTION_ID }}
client-id: ${{ secrets.AZURE_CLIENT_ID_LLVM }}
tenant-id: ${{ secrets.AZURE_TENANT_ID_LLVM }}
subscription-id: ${{ secrets.AZURE_SUBSCRIPTION_ID_LLVM }}

- name: Upload LLVM Artifacts to Azure
if: ${{ (github.repository == 'triton-lang/triton') }}
if: ${{ (github.repository == 'triton-lang/triton') && github.ref_name == 'llvm-head' }}
shell: bash -el {0}
run: |
az storage blob upload --account-name oaitriton --auth-mode login --container-name public --file "${{ env.llvm_install_dir }}.tar.gz" --name "llvm-builds/${{ env.llvm_install_dir }}.tar.gz" --overwrite
@@ -282,7 +289,7 @@ jobs:
echo "Blob URL: ${URL}"

- name: Azure Logout
if: ${{ (github.repository == 'triton-lang/triton') }}
if: ${{ (github.repository == 'triton-lang/triton') && github.ref_name == 'llvm-head' }}
run: |
az logout
az cache purge
7 changes: 6 additions & 1 deletion .github/workflows/llvm-build/almalinux.Dockerfile
@@ -1,4 +1,5 @@
FROM almalinux:8
# https://github.com/AlmaLinux/container-images/blob/9f9b3c8c8cf4a57fd42f362570ff47c75788031f/default/amd64/Dockerfile
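# pinned to a dated tag so rebuilds of this image stay reproducible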
FROM almalinux:8.10-20250411
ARG llvm_dir=llvm-project
# Add the cache artifacts and the LLVM source tree to the container
ADD sccache /sccache
@@ -8,6 +9,7 @@ ENV SCCACHE_CACHE_SIZE="2G"

RUN dnf install --assumeyes llvm-toolset
RUN dnf install --assumeyes python38-pip python38-devel git
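# make python3.8 the default python3; the same interpreter is passed to CMake below via Python3_EXECUTABLE/Python_EXECUTABLE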
RUN alternatives --set python3 /usr/bin/python3.8

RUN python3 -m pip install --upgrade pip
RUN python3 -m pip install --upgrade cmake ninja sccache lit
@@ -26,6 +28,8 @@ RUN cmake -GNinja -Bbuild \
-DCMAKE_CXX_FLAGS="-Wno-everything" \
-DCMAKE_LINKER=lld \
-DCMAKE_INSTALL_PREFIX="/install" \
-DPython3_EXECUTABLE="/usr/bin/python3.8" \
-DPython_EXECUTABLE="/usr/bin/python3.8" \
-DLLVM_BUILD_UTILS=ON \
-DLLVM_BUILD_TOOLS=ON \
-DLLVM_ENABLE_ASSERTIONS=ON \
@@ -34,6 +38,7 @@ RUN cmake -GNinja -Bbuild \
-DLLVM_ENABLE_TERMINFO=OFF \
-DLLVM_INSTALL_UTILS=ON \
-DLLVM_TARGETS_TO_BUILD="host;NVPTX;AMDGPU" \
-DLLVM_ENABLE_ZSTD=OFF \
/source/llvm-project/llvm

RUN ninja -C build install
3 changes: 2 additions & 1 deletion python/test/gluon/test_core.py
@@ -542,8 +542,9 @@ def kernel(a_ptr, b_ptr, use_buffer_load: ttgl.constexpr):
cdna4_async_copy.buffer_load_to_shared(smem, a_ptr, offsets)
else:
cdna4_async_copy.global_load_to_shared(smem, a_ptr + offsets)
cdna4_async_copy.commit_group()

cdna4_async_copy.async_wait(0)
cdna4_async_copy.wait_group(0)
a = cdna4_async_copy.load_shared_relaxed(smem, blocked)

ttgl.store(b_ptr + offsets, a)
27 changes: 23 additions & 4 deletions python/test/gluon/test_frontend.py
@@ -1950,17 +1950,36 @@ def test_infer_layout_for_amd_wmma(target):


@gluon.jit
def amd_async_wait():
cdna4_async_copy.async_wait(0)
def amd_commit_group():
cdna4_async_copy.commit_group()


@pytest.mark.parametrize("target", [HIP_TARGET_CDNA4])
def test_amd_commit_group(target):
mod = run_parser(amd_commit_group, target=target)
expecttest.assert_expected_inline(
anonymize_ir(mod.str_nodebug()), """\
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 64 : i32} {
tt.func public @amd_commit_group() attributes {noinline = false} {
%0 = ttg.async_commit_group
tt.return
}
}
""")


@gluon.jit
def amd_wait_group():
cdna4_async_copy.wait_group(0)


@pytest.mark.parametrize("target", [HIP_TARGET_CDNA4])
def test_amd_async_wait(target):
mod = run_parser(amd_async_wait, target=target)
mod = run_parser(amd_wait_group, target=target)
expecttest.assert_expected_inline(
anonymize_ir(mod.str_nodebug()), """\
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 64 : i32} {
tt.func public @amd_async_wait() attributes {noinline = false} {
tt.func public @amd_wait_group() attributes {noinline = false} {
%0 = ttg.async_wait {num = 0 : i32}
tt.return
}
35 changes: 25 additions & 10 deletions python/triton/experimental/gluon/language/amd/cdna4/async_copy.py
@@ -6,7 +6,8 @@
__all__ = [
"global_load_to_shared",
"buffer_load_to_shared",
"async_wait",
"commit_group",
"wait_group",
"load_shared_relaxed",
]

@@ -17,7 +18,10 @@ def global_load_to_shared(dest, ptr, mask=None, other=None, cache_modifier="", _
AMD global load to shared operation. This operation loads data directly
from global memory to shared memory without going through registers. It
happens asynchronously and requires a subsequent `async_wait` to ensure the
data is available in shared memory.
data is available in shared memory. Note that on CDNA4 this operation still
completes in order with regular `ttgl` loads/stores and buffer loads/stores, so
interleaving with them will hurt performance.

Compared to `buffer_load_to_shared`, it requires a tensor pointer which
supports 64-bit indexing range for each thread in a block, which gives more
flexibility, but at the cost of higher register pressure and no hardware
@@ -72,7 +76,10 @@ def buffer_load_to_shared(dest, ptr, offsets, mask=None, other=None, cache_modif
32-bit offsets instead of a tensor of pointers. This operation loads data
directly from global memory to shared memory without going through
registers. It happens asynchronously and requires a subsequent `async_wait`
to ensure the data is available in shared memory.
to ensure the data is available in shared memory. Note that on CDNA4 this
operation still completes in order with regular `ttgl` loads/stores and buffer
loads/stores, so interleaving with them will hurt performance.

Compared to `global_load_to_shared`, it has better performance and also
supports hardware out-of-bound masking. But it strictly requires a
32-bit offset instead of a 64-bit tensor pointer.
@@ -118,16 +125,24 @@ def buffer_load_to_shared(dest, ptr, offsets, mask=None, other=None, cache_modif


@builtin
def async_wait(num_outstanding=0, _semantic=None):
def commit_group(_semantic=None):
"""
Commit outstanding async operations.

This finalizes a set of async copy operations which can be waited upon via `wait_group`.
"""
_semantic.builder.create_async_commit_group()


@builtin
def wait_group(num_outstanding=0, _semantic=None):
"""
Wait for outstanding memory operations, this includes normal load like
`load` and `buffer_load`, as well as direct load to shared memory
like `global_load_to_shared` and `buffer_load_to_shared`.
It will block until the number of outstanding memory operations is less than
or equal to `num_outstanding`.
Wait for outstanding commit groups. It will block until the number of
outstanding commit groups is less than or equal to `num_outstanding`. Note that uncommitted
async operations will be waited upon even if `num_outstanding` is 0.

Args:
num_outstanding (int): The number of outstanding operations to wait for. Defaults to 0.
num_outstanding (int): The number of outstanding commit groups to wait for. Defaults to 0.
"""
num_outstanding = _unwrap_if_constexpr(num_outstanding)
_semantic.builder.create_async_wait_group(num_outstanding)
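
For reference, a minimal usage sketch of the renamed API, mirroring the updated kernel in python/test/gluon/test_core.py above. It assumes a @gluon.jit kernel in which `smem`, `offsets`, `blocked`, and `use_buffer_load` are already set up as in that test; only the copy/commit/wait sequence is shown.

```python
# Sketch only: body of a @gluon.jit kernel; `smem` (shared memory descriptor),
# `offsets`, `blocked` (register layout) and `use_buffer_load` are assumed to
# be defined as in test_core.py.
if use_buffer_load:
    # base pointer + 32-bit offsets, with hardware out-of-bounds masking
    cdna4_async_copy.buffer_load_to_shared(smem, a_ptr, offsets)
else:
    # tensor of pointers, 64-bit indexing per thread
    cdna4_async_copy.global_load_to_shared(smem, a_ptr + offsets)

# close the group of async copies issued so far
cdna4_async_copy.commit_group()
# block until at most 0 committed groups remain outstanding
cdna4_async_copy.wait_group(0)

a = cdna4_async_copy.load_shared_relaxed(smem, blocked)
ttgl.store(b_ptr + offsets, a)
```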
44 changes: 18 additions & 26 deletions python/triton_kernels/tests/test_matmul.py
@@ -303,27 +303,27 @@ class Case:
],
)
@pytest.mark.parametrize("block_m", [16, 128])
@pytest.mark.parametrize("do_gather, do_scatter, fused_scatter, inner_expt_opt", [
(False, False, False, None),
(True, False, False, None),
(False, True, False, None),
(False, True, True, None),
(True, True, False, None),
(True, True, True, None),
(False, False, False, "pad_w"),
(False, False, False, "pad_x"),
@pytest.mark.parametrize("do_gather, do_scatter, inner_expt_opt", [
(False, False, None),
(True, False, None),
(False, True, None),
(False, True, None),
(True, True, None),
(True, True, None),
(False, False, "pad_w"),
(False, False, "pad_x"),
])
@pytest.mark.parametrize("has_y_gammas", [False, True])
@pytest.mark.parametrize("is_persistent", [False, True])
def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, inner_expt_opt, has_y_gammas, is_persistent, n_expts_tot,
def test_op(m, n, k, split_k, do_gather, do_scatter, inner_expt_opt, has_y_gammas, is_persistent, n_expts_tot,
n_expts_act, mode, act_dtype_str, weight_dtype_str, block_m, hbm_swizzling, colmajor_mxfp_weight, epilogue_subtile,
x_transpose, w_transpose, y_transpose,
device, opt_flags_scope):
# We catch and re-invoke pytest.skip(), because otherwise pytest may hold a reference to
# the frame that called pytest.skip, including all the tensors, leading to OOM.
skip_message = None
try:
_test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, inner_expt_opt, has_y_gammas, is_persistent, n_expts_tot,
_test_op(m, n, k, split_k, do_gather, do_scatter, inner_expt_opt, has_y_gammas, is_persistent, n_expts_tot,
n_expts_act, mode, act_dtype_str, weight_dtype_str, block_m, hbm_swizzling, colmajor_mxfp_weight, epilogue_subtile,
x_transpose, w_transpose, y_transpose,
device, opt_flags_scope)
@@ -333,7 +333,7 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, inner_expt_o
if skip_message is not None:
pytest.skip(skip_message)

def _test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, inner_expt_opt, has_y_gammas, is_persistent, n_expts_tot,
def _test_op(m, n, k, split_k, do_gather, do_scatter, inner_expt_opt, has_y_gammas, is_persistent, n_expts_tot,
n_expts_act, mode, act_dtype_str, weight_dtype_str, block_m, hbm_swizzling, colmajor_mxfp_weight, epilogue_subtile,
x_transpose, w_transpose, y_transpose,
device, opt_flags_scope):
@@ -362,9 +362,6 @@ def _test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, inner_expt_
if "float8_e4m3fnuz" in (weight_dtype_str, act_dtype_str) and not is_hip_cdna3():
pytest.xfail("float8_e4m3fnuz only tested on AMD CDNA3 Platform")

if fused_scatter and split_k is not None and split_k > 1:
pytest.xfail("fused scatter scratchpad not supported with split_k")

if hbm_swizzling:
if is_hip():
if not is_hip_cdna4():
@@ -414,7 +411,6 @@ def _test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, inner_expt_
"block_m": block_m,
"block_k": block_k,
"split_k": split_k,
"fused_scatter": fused_scatter,
"is_persistent": is_persistent,
"epilogue_subtile": epilogue_subtile,
}
@@ -727,12 +723,11 @@ def test_set_idle_sms():
(800, 800, 400, "batched"),
])
@pytest.mark.parametrize("split_k", [1, 2])
@pytest.mark.parametrize("do_gather, do_scatter, fused_scatter", [
(False, False, False),
(True, False, False),
(False, True, False),
(True, True, False),
(True, True, True),
@pytest.mark.parametrize("do_gather, do_scatter", [
(False, False),
(True, False),
(False, True),
(True, True),
])
@pytest.mark.parametrize("is_persistent, epilogue_subtile", [
(False, None),
@@ -744,16 +739,13 @@ def test_set_idle_sms():
(1.0, 1.2),
(0.7, 1.0),
])
def test_fused_act(m, n, k, mode, split_k, do_gather, do_scatter, fused_scatter, is_persistent, epilogue_subtile,
def test_fused_act(m, n, k, mode, split_k, do_gather, do_scatter, is_persistent, epilogue_subtile,
swiglu_alpha, swiglu_limit, device, opt_flags_scope):
if fused_scatter and split_k > 1:
pytest.xfail("fused scatter scratchpad not supported with split_k")
torch.manual_seed(0)
constraints = {
"is_persistent": is_persistent,
"epilogue_subtile": epilogue_subtile,
"split_k": split_k,
"fused_scatter": fused_scatter,
}
n_expts_tot, n_expts_act = 1, 1
opt_flags.update_opt_flags_constraints(constraints)
(changes to an additional file; path not shown in this view)
@@ -185,7 +185,7 @@ def get_flags(split_k, max_mn):
k,
None,
False,
False,
True,
False,
0,
False,
15 changes: 8 additions & 7 deletions python/triton_kernels/triton_kernels/matmul_ogs.py
@@ -419,9 +419,10 @@ def matmul_ogs(x, w, bias,
has_gather_tma = has_gather and target_info.has_tma_gather()
# hopper w/ mxfp4 doesn't support TMA
can_use_tma = can_use_tma and is_cuda() and (torch.cuda.get_device_capability()[0] > 9 or bitwidth(w.dtype) != 4)
can_use_split_k = scatter_indx is None and not x_has_mx and not w_has_mx
opt_flags = make_opt_flags(out_dtype, x.dtype, w.dtype, precision_config,
batch_size, M, N, w.shape[-2], routing_data,
can_use_tma, scatter_indx is not None, epilogue.effective_itemsize,
can_use_tma, can_use_split_k, epilogue.effective_itemsize,
x_transpose, y_acc_in is not None,
inner_routing_data.block_k if inner_routing_data is not None else None,
)
@@ -618,21 +619,21 @@ def matmul_ogs(x, w, bias,
**fused_comm_kwargs,
**opt_flags.target_kernel_kwargs)

assert not (opt_flags.split_k > 1 and scatter_indx is not None)
out_final_mx_scale = None
if opt_flags.split_k > 1:
assert not out_matmul_has_mx
has_scatter = scatter_indx is not None
postprocess_fn1 = ReducePostprocessFn(specs=reduce_fused_activation.specs, fn_args=reduce_fused_activation.fn_args)
postprocess_fn2 = None if has_scatter else ReducePostprocessFn(specs=epilogue.specs, fn_args=epilogue.fn_arg_values_finalize)
y, y_mx_scale = reduce(
x = out_matmul.view(out_matmul.shape[0], -1, out_matmul.shape[-1]),
dim = 0,
# output data/metadata
y = None if has_scatter else memory["output"].view(-1, memory["output"].shape[-1]),
y_dtype = out_matmul.dtype if has_scatter else memory["output"].dtype,
y_flex = OutFlexData() if has_scatter else precision_config.flex_ctx.out_data,
y_flex_saturate_inf = None if has_scatter else precision_config.flexpoint_saturate_inf,
y_has_mx = scatter_indx is None and precision_config.out_scale is not None,
y = memory["output"].view(-1, memory["output"].shape[-1]),
y_dtype = memory["output"].dtype,
y_flex = precision_config.flex_ctx.out_data,
y_flex_saturate_inf = precision_config.flexpoint_saturate_inf,
y_has_mx = precision_config.out_scale is not None,
# fused functions
postprocess_fn1 = postprocess_fn1,
postprocess_fn2 = postprocess_fn2,
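
As a closing note on the matmul_ogs.py change: a minimal standalone sketch of the new split-k gate (an illustration of the diff, not the library's exact code; `can_use_split_k` is a local variable in `matmul_ogs`, written here as a function). Split-k is only requested when the output needs no scatter and neither operand carries mx scales, which is why the split-k reduction above can now always write straight into memory["output"].

```python
def can_use_split_k(scatter_indx, x_has_mx: bool, w_has_mx: bool) -> bool:
    # Mirrors the condition added to matmul_ogs(): the split-k reduction writes
    # directly into the final output buffer, so it is only enabled when there
    # is no scatter index and neither input uses mx (microscaling) scales.
    return scatter_indx is None and not x_has_mx and not w_has_mx

# dense output, no mx scales -> split-k may be used
assert can_use_split_k(None, False, False)
# scattered output -> split-k is disabled (matches the new assert in matmul_ogs)
assert not can_use_split_k(object(), False, False)
```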