From 340ca0da18b4cfa90b58c6666e674e323809fd4c Mon Sep 17 00:00:00 2001 From: Jiaxing Ding <61589029+Paran0idy@users.noreply.github.com> Date: Sat, 1 Nov 2025 07:01:11 +0800 Subject: [PATCH 1/7] [Gluon] fix gluon tutorial example (#8593) --- python/tutorials/gluon/05-wgmma.py | 2 +- setup.py | 12 ++++++++++-- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/python/tutorials/gluon/05-wgmma.py b/python/tutorials/gluon/05-wgmma.py index 3a4dcfea97..bc0f976818 100644 --- a/python/tutorials/gluon/05-wgmma.py +++ b/python/tutorials/gluon/05-wgmma.py @@ -459,7 +459,7 @@ def find_configs(occupancy, dtype, num_buffers=1): if acc_regs > regs: continue - instr_shape_n = get_instr_shape_n(BLOCK_M, BLOCK_N, num_warps).value + instr_shape_n = get_instr_shape_n(BLOCK_M, BLOCK_N, num_warps) configs.append((BLOCK_M, BLOCK_N, BLOCK_K, num_warps, instr_shape_n, occupancy)) def filter_configs(configs, instr_shape_n): diff --git a/setup.py b/setup.py index 78b3e71dce..795eae685b 100644 --- a/setup.py +++ b/setup.py @@ -307,7 +307,11 @@ def get_thirdparty_packages(packages: list): file.extractall(path=package_root_dir) else: with tarfile.open(fileobj=response, mode="r|*") as file: - file.extractall(path=package_root_dir, filter="data") + # Use extractall without filter for Python version < 3.12 compatibility + if hasattr(tarfile, 'data_filter'): + file.extractall(path=package_root_dir, filter="data") + else: + file.extractall(path=package_root_dir) # write version url to package_dir with open(os.path.join(package_dir, "version.txt"), "w") as f: f.write(p.url) @@ -350,7 +354,11 @@ def download_and_copy(name, src_func, dst_path, variable, version, url_func): if download: print(f'downloading and extracting {url} ...') with open_url(url) as url_file, tarfile.open(fileobj=url_file, mode="r|*") as tar_file: - tar_file.extractall(path=tmp_path, filter="data") + # Use extractall without filter for Python version < 3.12 compatibility + if hasattr(tarfile, 'data_filter'): + tar_file.extractall(path=tmp_path, filter="data") + else: + tar_file.extractall(path=tmp_path) os.makedirs(os.path.split(dst_path)[0], exist_ok=True) print(f'copy {src_path} to {dst_path} ...') if os.path.isdir(src_path): From 4c2175f310a180ad4724292201fcf269e8e52b44 Mon Sep 17 00:00:00 2001 From: Thomas Raoux Date: Fri, 31 Oct 2025 20:15:31 -0700 Subject: [PATCH 2/7] [LLVM build] Merge back changes from llvm-head (#8612) Just to keep the branches in sync --- .github/workflows/llvm-build.yml | 23 ++++++++++++------- .../workflows/llvm-build/almalinux.Dockerfile | 7 +++++- 2 files changed, 21 insertions(+), 9 deletions(-) diff --git a/.github/workflows/llvm-build.yml b/.github/workflows/llvm-build.yml index 8158acebde..3b55eb4424 100644 --- a/.github/workflows/llvm-build.yml +++ b/.github/workflows/llvm-build.yml @@ -9,6 +9,8 @@ on: pull_request: paths: - .github/workflows/llvm-build.yml + - .github/workflows/llvm-build/almalinux.Dockerfile + - .github/workflows/llvm-build/centos.Dockerfile workflow_dispatch: env: @@ -135,6 +137,7 @@ jobs: -DLLVM_INSTALL_UTILS=ON -DLLVM_TARGETS_TO_BUILD="host;NVPTX;AMDGPU" -DLLVM_ENABLE_TERMINFO=OFF + -DLLVM_ENABLE_ZSTD=OFF llvm-project/llvm ninja -C llvm-project/build check-mlir install @@ -237,7 +240,11 @@ jobs: run: | # if this step crashes, it can leave behind a stale docker container docker container prune -f - docker rmi -f $(docker images -q) + + images=$(docker images -q) + if [ -n "$images" ]; then + docker rmi -f $images + fi docker build --tag llvm-build --build-arg llvm_dir=llvm-project \ -f 
llvm-build/.github/workflows/llvm-build/almalinux.Dockerfile . @@ -264,16 +271,16 @@ jobs: path: | ${{ github.workspace }}/llvm-*-${{ matrix.config.target-os }}-${{ matrix.config.arch }}.tar.gz - - name: Azure Login - if: ${{ (github.repository == 'triton-lang/triton') }} + - name: Azure login + if: ${{ (github.repository == 'triton-lang/triton') && github.ref_name == 'llvm-head' }} uses: azure/login@v2 with: - client-id: ${{ secrets.AZURE_CLIENT_ID }} - tenant-id: ${{ secrets.AZURE_TENANT_ID }} - subscription-id: ${{ secrets.AZURE_SUBSCRIPTION_ID }} + client-id: ${{ secrets.AZURE_CLIENT_ID_LLVM }} + tenant-id: ${{ secrets.AZURE_TENANT_ID_LLVM }} + subscription-id: ${{ secrets.AZURE_SUBSCRIPTION_ID_LLVM }} - name: Upload LLVM Artifacts to Azure - if: ${{ (github.repository == 'triton-lang/triton') }} + if: ${{ (github.repository == 'triton-lang/triton') && github.ref_name == 'llvm-head' }} shell: bash -el {0} run: | az storage blob upload --account-name oaitriton --auth-mode login --container-name public --file "${{ env.llvm_install_dir }}.tar.gz" --name "llvm-builds/${{ env.llvm_install_dir }}.tar.gz" --overwrite @@ -282,7 +289,7 @@ jobs: echo "Blob URL: ${URL}" - name: Azure Logout - if: ${{ (github.repository == 'triton-lang/triton') }} + if: ${{ (github.repository == 'triton-lang/triton') && github.ref_name == 'llvm-head' }} run: | az logout az cache purge diff --git a/.github/workflows/llvm-build/almalinux.Dockerfile b/.github/workflows/llvm-build/almalinux.Dockerfile index 317470c00d..ad8e2f4438 100644 --- a/.github/workflows/llvm-build/almalinux.Dockerfile +++ b/.github/workflows/llvm-build/almalinux.Dockerfile @@ -1,4 +1,5 @@ -FROM almalinux:8 +# https://github.com/AlmaLinux/container-images/blob/9f9b3c8c8cf4a57fd42f362570ff47c75788031f/default/amd64/Dockerfile +FROM almalinux:8.10-20250411 ARG llvm_dir=llvm-project # Add the cache artifacts and the LLVM source tree to the container ADD sccache /sccache @@ -8,6 +9,7 @@ ENV SCCACHE_CACHE_SIZE="2G" RUN dnf install --assumeyes llvm-toolset RUN dnf install --assumeyes python38-pip python38-devel git +RUN alternatives --set python3 /usr/bin/python3.8 RUN python3 -m pip install --upgrade pip RUN python3 -m pip install --upgrade cmake ninja sccache lit @@ -26,6 +28,8 @@ RUN cmake -GNinja -Bbuild \ -DCMAKE_CXX_FLAGS="-Wno-everything" \ -DCMAKE_LINKER=lld \ -DCMAKE_INSTALL_PREFIX="/install" \ + -DPython3_EXECUTABLE="/usr/bin/python3.8" \ + -DPython_EXECUTABLE="/usr/bin/python3.8" \ -DLLVM_BUILD_UTILS=ON \ -DLLVM_BUILD_TOOLS=ON \ -DLLVM_ENABLE_ASSERTIONS=ON \ @@ -34,6 +38,7 @@ RUN cmake -GNinja -Bbuild \ -DLLVM_ENABLE_TERMINFO=OFF \ -DLLVM_INSTALL_UTILS=ON \ -DLLVM_TARGETS_TO_BUILD="host;NVPTX;AMDGPU" \ + -DLLVM_ENABLE_ZSTD=OFF \ /source/llvm-project/llvm RUN ninja -C build install From 7025305bd9b6532c6ab710bffb03aff333ae2380 Mon Sep 17 00:00:00 2001 From: Alexander Weinrauch Date: Sat, 1 Nov 2025 04:46:53 +0000 Subject: [PATCH 3/7] [AMD][GLUON] Wait outstanding async commit groups instead of instructions (#8605) Currently `async_wait` in Gluon on `CDNA4` requires the kernel writer to pass the number of outstanding hardware instructions/llvm intrinsic to `async_wait`. This count is very difficult to compute as it relies on layouts, sizes, contiguity... This PR changes the semantics of `async_wait` to represent the number of outstanding commit groups. This follows the semantics used for nvidia in Gluon. Therefore, Gluon kernels need to commit outstanding async operations via `commit_group` and then wait on them via `wait_group`. 
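For illustration, a minimal sketch of the new usage pattern. The commit/wait calls mirror the `test_core.py` change below; the kernel itself, the 1-D layouts and the shared-memory setup are illustrative assumptions for this example (not part of this PR), and fp32 inputs are assumed:

```python
import triton.experimental.gluon.language as ttgl
from triton.experimental import gluon
from triton.experimental.gluon.language.amd.cdna4 import async_copy as cdna4_async_copy


@gluon.jit
def copy_kernel(a_ptr, b_ptr, BLOCK: ttgl.constexpr):
    # Illustrative 1-D layouts for CDNA4 (64-wide waves, 4 warps); any valid
    # register/shared layout pair would do here.
    blocked: ttgl.constexpr = ttgl.BlockedLayout([1], [64], [4], [0])
    shared: ttgl.constexpr = ttgl.SwizzledSharedLayout(1, 1, 1, order=[0])
    offsets = ttgl.arange(0, BLOCK, layout=blocked)
    smem = ttgl.allocate_shared_memory(ttgl.float32, [BLOCK], shared)

    cdna4_async_copy.buffer_load_to_shared(smem, a_ptr, offsets)
    # New semantics: first commit the outstanding async copies ...
    cdna4_async_copy.commit_group()
    # ... then block until at most 0 commit groups are still outstanding.
    cdna4_async_copy.wait_group(0)

    a = cdna4_async_copy.load_shared_relaxed(smem, blocked)
    ttgl.store(b_ptr + offsets, a)
```

Note that `wait_group(0)` also waits on async copies that were never committed, as documented in the updated docstring.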
I also adapted the names so existing Gluon kernels using the old semantics error out. `UpdateAsyncWaitCount` is extended to compute the number of outstanding hardware instructions based on the number of outstanding commit groups. Previously, it only worked on `async_waits` carrying tokens of the commit groups, which are not available when compiling a Gluon kernel. This is done by walking the IR backwards, following *all* possible control flow paths and finding the smallest number of emitted instructions for N outstanding commit groups. --- python/test/gluon/test_core.py | 3 +- python/test/gluon/test_frontend.py | 27 +- .../gluon/language/amd/cdna4/async_copy.py | 35 +- ...update-async-wait-count-without-token.mlir | 519 ++++++++++++++++++ .../amd/amd-update-async-wait-count.mlir | 22 +- .../UpdateAsyncWaitCount.cpp | 326 ++++++++--- 6 files changed, 836 insertions(+), 96 deletions(-) create mode 100644 test/TritonGPU/amd/amd-update-async-wait-count-without-token.mlir diff --git a/python/test/gluon/test_core.py b/python/test/gluon/test_core.py index 6fea0f8a46..2d379b1504 100644 --- a/python/test/gluon/test_core.py +++ b/python/test/gluon/test_core.py @@ -541,8 +541,9 @@ def kernel(a_ptr, b_ptr, use_buffer_load: ttgl.constexpr): cdna4_async_copy.buffer_load_to_shared(smem, a_ptr, offsets) else: cdna4_async_copy.global_load_to_shared(smem, a_ptr + offsets) + cdna4_async_copy.commit_group() - cdna4_async_copy.async_wait(0) + cdna4_async_copy.wait_group(0) a = cdna4_async_copy.load_shared_relaxed(smem, blocked) ttgl.store(b_ptr + offsets, a) diff --git a/python/test/gluon/test_frontend.py b/python/test/gluon/test_frontend.py index 3751841dc9..9e9ab46902 100644 --- a/python/test/gluon/test_frontend.py +++ b/python/test/gluon/test_frontend.py @@ -1950,17 +1950,36 @@ def test_infer_layout_for_amd_wmma(target): @gluon.jit -def amd_async_wait(): - cdna4_async_copy.async_wait(0) +def amd_commit_group(): + cdna4_async_copy.commit_group() + + +@pytest.mark.parametrize("target", [HIP_TARGET_CDNA4]) +def test_amd_commit_group(target): + mod = run_parser(amd_wait_group, target=target) + expecttest.assert_expected_inline( + anonymize_ir(mod.str_nodebug()), """\ +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 64 : i32} { + tt.func public @amd_wait_group() attributes {noinline = false} { + %0 = ttg.async_wait {num = 0 : i32} + tt.return + } +} +""") + + +@gluon.jit +def amd_wait_group(): + cdna4_async_copy.wait_group(0) @pytest.mark.parametrize("target", [HIP_TARGET_CDNA4]) def test_amd_async_wait(target): - mod = run_parser(amd_async_wait, target=target) + mod = run_parser(amd_wait_group, target=target) expecttest.assert_expected_inline( anonymize_ir(mod.str_nodebug()), """\ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 64 : i32} { - tt.func public @amd_async_wait() attributes {noinline = false} { + tt.func public @amd_wait_group() attributes {noinline = false} { %0 = ttg.async_wait {num = 0 : i32} tt.return } diff --git a/python/triton/experimental/gluon/language/amd/cdna4/async_copy.py b/python/triton/experimental/gluon/language/amd/cdna4/async_copy.py index 422364794c..cc83ca2e09 100644 --- a/python/triton/experimental/gluon/language/amd/cdna4/async_copy.py +++ b/python/triton/experimental/gluon/language/amd/cdna4/async_copy.py @@ -6,7 +6,8 @@ __all__ = [ "global_load_to_shared", "buffer_load_to_shared", - "async_wait", + "commit_group", + "wait_group",
"load_shared_relaxed", ] @@ -17,7 +18,10 @@ def global_load_to_shared(dest, ptr, mask=None, other=None, cache_modifier="", _ AMD global load to shared operation. This operation loads data directly from global memory to shared memory without going through registers. It happens asynchronously and requires a subsequent `async_wait` to ensure the - data is available in shared memory. + data is available in shared memory. Note that this operation does still + complete in order with ttgl.loads/stores or buffer_loads/stores on CDNA4, + so interleaving with them will hurt performance. + Compared to `buffer_load_to_shared`, it requires a tensor pointer which supports 64-bit indexing range for each thread in a block, which gives more flexibility, but at the cost of higher register pressure and no hardware @@ -72,7 +76,10 @@ def buffer_load_to_shared(dest, ptr, offsets, mask=None, other=None, cache_modif 32-bit offsets instead of a tensor of pointers. This operation loads data directly from global memory to shared memory without going through registers. It happens asynchronously and requires a subsequent `async_wait` - to ensure the data is available in shared memory. + to ensure thedata is available in shared memory. Note that this operation + does still complete in order with ttgl.loads/stores or buffer_loads/stores + on CDNA4, so interleaving with them will hurt performance. + Compared to `global_load_to_shared`, it has better performance and also supports hardware out-of-bound masking. But it strictly requires a 32-bit offset instead of a 64-bit tensor pointer. @@ -118,16 +125,24 @@ def buffer_load_to_shared(dest, ptr, offsets, mask=None, other=None, cache_modif @builtin -def async_wait(num_outstanding=0, _semantic=None): +def commit_group(_semantic=None): + """ + Commit oustanding async operations. + + This finalizes a set of async copy operations which can be waited upon via `wait_group`. + """ + _semantic.builder.create_async_commit_group() + + +@builtin +def wait_group(num_outstanding=0, _semantic=None): """ - Wait for outstanding memory operations, this includes normal load like - `load` and `buffer_load`, as well as direct load to shared memory - like `global_load_to_shared` and `buffer_load_to_shared`. - It will block until the number of outstanding memory operations is less than - or equal to `num_outstanding`. + Wait for outstanding commit groups. It will block until the number of + outstanding commit groups is less than or equal to `num_outstanding`. Note that uncommited + async operations will be waited upon even if `num_outstanding` is 0. Args: - num_outstanding (int): The number of outstanding operations to wait for. Defaults to 0. + num_outstanding (int): The number of outstanding commit groups to wait for. Defaults to 0. """ num_outstanding = _unwrap_if_constexpr(num_outstanding) _semantic.builder.create_async_wait_group(num_outstanding) diff --git a/test/TritonGPU/amd/amd-update-async-wait-count-without-token.mlir b/test/TritonGPU/amd/amd-update-async-wait-count-without-token.mlir new file mode 100644 index 0000000000..0fc738ba43 --- /dev/null +++ b/test/TritonGPU/amd/amd-update-async-wait-count-without-token.mlir @@ -0,0 +1,519 @@ +// RUN: triton-opt %s -split-input-file --tritonamdgpu-update-async-wait-count=arch-generation-name=gfx950 | FileCheck %s + +// The number in SSA symbolic names represents the number of generated async load operation at assembly level a ttg.async_copy_global_to_local will generate, which is counted by this pass. 
+// For example `ttg.async_copy_global_to_local %ptr2Inst, %memDesc2Inst ..` will generate two global_load_async_to_lds_b128 assembly instruction + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [2, 32], warpsPerCTA = [4, 1], order = [0, 1]}> +#shared = #ttg.swizzled_shared<{vec = 8, perPhase = 1, maxPhase = 16, order = [0, 1]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} { + + // CHECK-LABEL: simple_waitcnt + tt.func public @simple_waitcnt( + %cond: i1, + %arg0: i32, + %memDesc2Inst: !ttg.memdesc<128x16xf16, #shared, #smem, mutable>, + %ptr2Inst: tensor<128x16x!tt.ptr, #blocked> {tt.divisibility = dense<[16, 16]> : tensor<2xi32>, tt.contiguity = dense<[16, 16]> : tensor<2xi32>}, + %memDesc1Inst: !ttg.memdesc<64x16xf16, #shared, #smem, mutable>, + %ptr1Inst: tensor<64x16x!tt.ptr, #blocked> {tt.divisibility = dense<[16, 16]> : tensor<2xi32>, tt.contiguity = dense<[16, 16]> : tensor<2xi32>}) { + // Emit 1 instruction + ttg.async_copy_global_to_local %ptr1Inst, %memDesc1Inst : tensor<64x16x!tt.ptr, #blocked> -> <64x16xf16, #shared, #smem, mutable> + ttg.async_commit_group + // Emits 2 instructions + ttg.async_copy_global_to_local %ptr2Inst, %memDesc2Inst : tensor<128x16x!tt.ptr, #blocked> -> <128x16xf16, #shared, #smem, mutable> + ttg.async_commit_group + + // CHECK: amdgpu.async_wait {num_inst = 0 + ttg.async_wait {num = 0 : i32} + // CHECK: amdgpu.async_wait {num_inst = 2 + ttg.async_wait {num = 1 : i32} + // Check we stop at function boundary + // CHECK: amdgpu.async_wait {num_inst = 3 + ttg.async_wait {num = 2 : i32} + // CHECK: amdgpu.async_wait {num_inst = 3 + ttg.async_wait {num = 3 : i32} + + tt.return + } + + // CHECK-LABEL: simple_waitcnt_non_committed_async_ops + tt.func public @simple_waitcnt_non_committed_async_ops( + %cond: i1, + %arg0: i32, + %memDesc2Inst: !ttg.memdesc<128x16xf16, #shared, #smem, mutable>, + %ptr2Inst: tensor<128x16x!tt.ptr, #blocked> {tt.divisibility = dense<[16, 16]> : tensor<2xi32>, tt.contiguity = dense<[16, 16]> : tensor<2xi32>}, + %memDesc1Inst: !ttg.memdesc<64x16xf16, #shared, #smem, mutable>, + %ptr1Inst: tensor<64x16x!tt.ptr, #blocked> {tt.divisibility = dense<[16, 16]> : tensor<2xi32>, tt.contiguity = dense<[16, 16]> : tensor<2xi32>}) { + // Emit 1 instruction + ttg.async_copy_global_to_local %ptr1Inst, %memDesc1Inst : tensor<64x16x!tt.ptr, #blocked> -> <64x16xf16, #shared, #smem, mutable> + + // We expect 1 because the async copy above has not been committed yet + // CHECK: amdgpu.async_wait {num_inst = 1 + ttg.async_wait {num = 0 : i32} + // -1 can be used to wait on all, even non committed async ops + // CHECK: amdgpu.async_wait {num_inst = 0 + ttg.async_wait {num = -1 : i32} + + tt.return + } + + // CHECK-LABEL: wait_if_without_else + tt.func public @wait_if_without_else( + %cond: i1, + %arg0: i32, + %memDesc2Inst: !ttg.memdesc<128x16xf16, #shared, #smem, mutable>, + %ptr2Inst: tensor<128x16x!tt.ptr, #blocked> {tt.divisibility = dense<[16, 16]> : tensor<2xi32>, tt.contiguity = dense<[16, 16]> : tensor<2xi32>}, + %memDesc1Inst: !ttg.memdesc<64x16xf16, #shared, #smem, mutable>, + %ptr1Inst: tensor<64x16x!tt.ptr, #blocked> {tt.divisibility = dense<[16, 16]> : tensor<2xi32>, tt.contiguity = dense<[16, 16]> : tensor<2xi32>}) { + // Ensure we look into then but also skip the if if no else is present + + ttg.async_copy_global_to_local %ptr1Inst, %memDesc1Inst : tensor<64x16x!tt.ptr, #blocked> -> 
<64x16xf16, #shared, #smem, mutable> + ttg.async_commit_group + scf.if %cond { + ttg.async_copy_global_to_local %ptr2Inst, %memDesc2Inst : tensor<128x16x!tt.ptr, #blocked> -> <128x16xf16, #shared, #smem, mutable> + ttg.async_commit_group + } + // CHECK: amdgpu.async_wait {num_inst = 1 + ttg.async_wait {num = 1: i32} + + ttg.async_copy_global_to_local %ptr2Inst, %memDesc2Inst : tensor<128x16x!tt.ptr, #blocked> -> <128x16xf16, #shared, #smem, mutable> + ttg.async_commit_group + scf.if %cond { + ttg.async_copy_global_to_local %ptr1Inst, %memDesc1Inst : tensor<64x16x!tt.ptr, #blocked> -> <64x16xf16, #shared, #smem, mutable> + ttg.async_commit_group + scf.yield + } + // CHECK: amdgpu.async_wait {num_inst = 1 + ttg.async_wait {num = 1: i32} + + // CHECK: amdgpu.async_wait {num_inst = 3 + ttg.async_wait {num = 2: i32} + + + tt.return + } + + // CHECK-LABEL wait_if_with_else + tt.func public @wait_if_with_else( + %cond: i1, + %arg0: i32, + %memDesc2Inst: !ttg.memdesc<128x16xf16, #shared, #smem, mutable>, + %ptr2Inst: tensor<128x16x!tt.ptr, #blocked> {tt.divisibility = dense<[16, 16]> : tensor<2xi32>, tt.contiguity = dense<[16, 16]> : tensor<2xi32>}, + %memDesc1Inst: !ttg.memdesc<64x16xf16, #shared, #smem, mutable>, + %ptr1Inst: tensor<64x16x!tt.ptr, #blocked> {tt.divisibility = dense<[16, 16]> : tensor<2xi32>, tt.contiguity = dense<[16, 16]> : tensor<2xi32>}) { + scf.if %cond { + ttg.async_copy_global_to_local %ptr1Inst, %memDesc1Inst : tensor<64x16x!tt.ptr, #blocked> -> <64x16xf16, #shared, #smem, mutable> + scf.yield + } else { + ttg.async_copy_global_to_local %ptr2Inst, %memDesc2Inst : tensor<128x16x!tt.ptr, #blocked> -> <128x16xf16, #shared, #smem, mutable> + scf.yield + } + ttg.async_commit_group + // Ensure we use the branch with less instructions (then) + // CHECK: amdgpu.async_wait {num_inst = 1 + ttg.async_wait {num = 1: i32} + // Check we do not loop in an if but instead continue upwards + // CHECK: amdgpu.async_wait {num_inst = 1 + ttg.async_wait {num = 2: i32} + + scf.if %cond { + ttg.async_copy_global_to_local %ptr2Inst, %memDesc2Inst : tensor<128x16x!tt.ptr, #blocked> -> <128x16xf16, #shared, #smem, mutable> + scf.yield + } else { + ttg.async_copy_global_to_local %ptr1Inst, %memDesc1Inst : tensor<64x16x!tt.ptr, #blocked> -> <64x16xf16, #shared, #smem, mutable> + scf.yield + } + ttg.async_commit_group + // Ensure we use the branch with less instructions (else) + // CHECK: amdgpu.async_wait {num_inst = 1 + ttg.async_wait {num = 1: i32} + + tt.return + } + + // CHECK-LABEL: check_wait_nested_ifs + tt.func public @check_wait_nested_ifs( + %cond: i1, + %arg0: i32, + %memDesc2Inst: !ttg.memdesc<128x16xf16, #shared, #smem, mutable>, + %ptr2Inst: tensor<128x16x!tt.ptr, #blocked> {tt.divisibility = dense<[16, 16]> : tensor<2xi32>, tt.contiguity = dense<[16, 16]> : tensor<2xi32>}, + %memDesc1Inst: !ttg.memdesc<64x16xf16, #shared, #smem, mutable>, + %ptr1Inst: tensor<64x16x!tt.ptr, #blocked> {tt.divisibility = dense<[16, 16]> : tensor<2xi32>, tt.contiguity = dense<[16, 16]> : tensor<2xi32>}) { + scf.if %cond { + ttg.async_copy_global_to_local %ptr1Inst, %memDesc1Inst : tensor<64x16x!tt.ptr, #blocked> -> <64x16xf16, #shared, #smem, mutable> + ttg.async_copy_global_to_local %ptr1Inst, %memDesc1Inst : tensor<64x16x!tt.ptr, #blocked> -> <64x16xf16, #shared, #smem, mutable> + scf.if %cond { + ttg.async_copy_global_to_local %ptr1Inst, %memDesc1Inst : tensor<64x16x!tt.ptr, #blocked> -> <64x16xf16, #shared, #smem, mutable> + ttg.async_copy_global_to_local %ptr1Inst, %memDesc1Inst : 
tensor<64x16x!tt.ptr, #blocked> -> <64x16xf16, #shared, #smem, mutable> + scf.yield + } else { + ttg.async_copy_global_to_local %ptr1Inst, %memDesc1Inst : tensor<64x16x!tt.ptr, #blocked> -> <64x16xf16, #shared, #smem, mutable> + scf.yield + } + ttg.async_commit_group + scf.yield + } else { + ttg.async_copy_global_to_local %ptr1Inst, %memDesc1Inst : tensor<64x16x!tt.ptr, #blocked> -> <64x16xf16, #shared, #smem, mutable> + scf.if %cond { + ttg.async_copy_global_to_local %ptr1Inst, %memDesc1Inst : tensor<64x16x!tt.ptr, #blocked> -> <64x16xf16, #shared, #smem, mutable> + scf.yield + } else { + ttg.async_copy_global_to_local %ptr1Inst, %memDesc1Inst : tensor<64x16x!tt.ptr, #blocked> -> <64x16xf16, #shared, #smem, mutable> + ttg.async_copy_global_to_local %ptr1Inst, %memDesc1Inst : tensor<64x16x!tt.ptr, #blocked> -> <64x16xf16, #shared, #smem, mutable> + scf.yield + } + ttg.async_commit_group + scf.yield + } + // The shortest path (else->then) contains 2 async ops -> instruction count 2 + // CHECK: amdgpu.async_wait {num_inst = 2 + ttg.async_wait {num = 1: i32} + + tt.return + } + + //CHECK-LABEL: for_without_async_ops + tt.func public @for_without_async_ops( + %cond: i1, + %arg0: i32, + %memDesc2Inst: !ttg.memdesc<128x16xf16, #shared, #smem, mutable>, + %ptr2Inst: tensor<128x16x!tt.ptr, #blocked> {tt.divisibility = dense<[16, 16]> : tensor<2xi32>, tt.contiguity = dense<[16, 16]> : tensor<2xi32>}, + %memDesc1Inst: !ttg.memdesc<64x16xf16, #shared, #smem, mutable>, + %ptr1Inst: tensor<64x16x!tt.ptr, #blocked> {tt.divisibility = dense<[16, 16]> : tensor<2xi32>, tt.contiguity = dense<[16, 16]> : tensor<2xi32>}) { + + %c0_i32 = arith.constant 0 : i32 + %c1_i32 = arith.constant 1 : i32 + + ttg.async_copy_global_to_local %ptr1Inst, %memDesc1Inst : tensor<64x16x!tt.ptr, #blocked> -> <64x16xf16, #shared, #smem, mutable> + ttg.async_commit_group + + scf.for %arg14 = %c0_i32 to %arg0 step %c1_i32 iter_args() -> () : i32 { + // CHECK: amdgpu.async_wait {num_inst = 1 + ttg.async_wait {num = 1: i32} + scf.yield + } + // CHECK: amdgpu.async_wait {num_inst = 1 + ttg.async_wait {num = 1: i32} + + tt.return + } + + //CHECK-LABEL: for_with_async_ops + tt.func public @for_with_async_ops( + %cond: i1, + %arg0: i32, + %memDesc2Inst: !ttg.memdesc<128x16xf16, #shared, #smem, mutable>, + %ptr2Inst: tensor<128x16x!tt.ptr, #blocked> {tt.divisibility = dense<[16, 16]> : tensor<2xi32>, tt.contiguity = dense<[16, 16]> : tensor<2xi32>}, + %memDesc1Inst: !ttg.memdesc<64x16xf16, #shared, #smem, mutable>, + %ptr1Inst: tensor<64x16x!tt.ptr, #blocked> {tt.divisibility = dense<[16, 16]> : tensor<2xi32>, tt.contiguity = dense<[16, 16]> : tensor<2xi32>}) { + + %c0_i32 = arith.constant 0 : i32 + %c1_i32 = arith.constant 1 : i32 + + ttg.async_copy_global_to_local %ptr2Inst, %memDesc2Inst : tensor<128x16x!tt.ptr, #blocked> -> <128x16xf16, #shared, #smem, mutable> + ttg.async_commit_group + ttg.async_copy_global_to_local %ptr2Inst, %memDesc2Inst : tensor<128x16x!tt.ptr, #blocked> -> <128x16xf16, #shared, #smem, mutable> + ttg.async_commit_group + ttg.async_copy_global_to_local %ptr2Inst, %memDesc2Inst : tensor<128x16x!tt.ptr, #blocked> -> <128x16xf16, #shared, #smem, mutable> + ttg.async_commit_group + // CHECK: amdgpu.async_wait {num_inst = 6 + ttg.async_wait {num = 3: i32} + + scf.for %arg14 = %c0_i32 to %arg0 step %c1_i32 : i32 { + // The minimum it waits are 3 loop iteration with 1 instructions per iteration. 
Note the prologue would lead to 6 + // CHECK: amdgpu.async_wait {num_inst = 3 + ttg.async_wait {num = 3: i32} + ttg.async_copy_global_to_local %ptr1Inst, %memDesc1Inst : tensor<64x16x!tt.ptr, #blocked> -> <64x16xf16, #shared, #smem, mutable> + ttg.async_commit_group + scf.yield + } + // The minimum it waits are 3 loop iteration with 1 instructions per iteration. Note the prologue would lead to 6 + // CHECK: amdgpu.async_wait {num_inst = 3 + ttg.async_wait {num = 3: i32} + + tt.return + } + + //CHECK-LABEL: for_nested_control_flow + tt.func public @for_nested_control_flow( + %cond: i1, + %arg0: i32, + %memDesc2Inst: !ttg.memdesc<128x16xf16, #shared, #smem, mutable>, + %ptr2Inst: tensor<128x16x!tt.ptr, #blocked> {tt.divisibility = dense<[16, 16]> : tensor<2xi32>, tt.contiguity = dense<[16, 16]> : tensor<2xi32>}, + %memDesc1Inst: !ttg.memdesc<64x16xf16, #shared, #smem, mutable>, + %ptr1Inst: tensor<64x16x!tt.ptr, #blocked> {tt.divisibility = dense<[16, 16]> : tensor<2xi32>, tt.contiguity = dense<[16, 16]> : tensor<2xi32>}) { + + %c0_i32 = arith.constant 0 : i32 + %c1_i32 = arith.constant 1 : i32 + + // Prologue: 2 instructions per commit group + ttg.async_copy_global_to_local %ptr2Inst, %memDesc2Inst : tensor<128x16x!tt.ptr, #blocked> -> <128x16xf16, #shared, #smem, mutable> + ttg.async_commit_group + ttg.async_copy_global_to_local %ptr2Inst, %memDesc2Inst : tensor<128x16x!tt.ptr, #blocked> -> <128x16xf16, #shared, #smem, mutable> + ttg.async_commit_group + ttg.async_copy_global_to_local %ptr2Inst, %memDesc2Inst : tensor<128x16x!tt.ptr, #blocked> -> <128x16xf16, #shared, #smem, mutable> + ttg.async_commit_group + ttg.async_copy_global_to_local %ptr2Inst, %memDesc2Inst : tensor<128x16x!tt.ptr, #blocked> -> <128x16xf16, #shared, #smem, mutable> + ttg.async_commit_group + ttg.async_copy_global_to_local %ptr2Inst, %memDesc2Inst : tensor<128x16x!tt.ptr, #blocked> -> <128x16xf16, #shared, #smem, mutable> + ttg.async_commit_group + ttg.async_copy_global_to_local %ptr2Inst, %memDesc2Inst : tensor<128x16x!tt.ptr, #blocked> -> <128x16xf16, #shared, #smem, mutable> + ttg.async_commit_group + + // The loop has 3 commits group which produce 2,1,1 (in program order) async instructions + scf.for %arg14 = %c0_i32 to %arg0 step %c1_i32 : i32 { + // 2 full loop iterations => 8 + // CHECK: amdgpu.async_wait {num_inst = 8 + ttg.async_wait {num = 6: i32} + + ttg.async_copy_global_to_local %ptr2Inst, %memDesc2Inst : tensor<128x16x!tt.ptr, #blocked> -> <128x16xf16, #shared, #smem, mutable> + ttg.async_commit_group + + // Wait on 1 full loop iteration (4) + the commit group above (2) + // CHECK: amdgpu.async_wait {num_inst = 6 + ttg.async_wait {num = 4: i32} + + scf.if %cond { + ttg.async_copy_global_to_local %ptr1Inst, %memDesc1Inst : tensor<64x16x!tt.ptr, #blocked> -> <64x16xf16, #shared, #smem, mutable> + } else { + ttg.async_copy_global_to_local %ptr2Inst, %memDesc2Inst : tensor<128x16x!tt.ptr, #blocked> -> <128x16xf16, #shared, #smem, mutable> + } + ttg.async_commit_group + + scf.if %cond { + ttg.async_copy_global_to_local %ptr2Inst, %memDesc2Inst : tensor<128x16x!tt.ptr, #blocked> -> <128x16xf16, #shared, #smem, mutable> + } else { + ttg.async_copy_global_to_local %ptr1Inst, %memDesc1Inst : tensor<64x16x!tt.ptr, #blocked> -> <64x16xf16, #shared, #smem, mutable> + } + ttg.async_commit_group + + // Wait on 1 full loop iteration (4) + the commit group above (1) + // CHECK: amdgpu.async_wait {num_inst = 5 + ttg.async_wait {num = 4: i32} + + scf.yield + } + // 2 Full loop iterations (2 * 4) + // CHECK: 
amdgpu.async_wait {num_inst = 8 + ttg.async_wait {num = 6: i32} + + tt.return + } + + // CHECK-LABEL: while_without_async_ops + tt.func public @while_without_async_ops( + %cond: i1, + %arg0: i32, + %memDesc2Inst: !ttg.memdesc<128x16xf16, #shared, #smem, mutable>, + %ptr2Inst: tensor<128x16x!tt.ptr, #blocked> {tt.divisibility = dense<[16, 16]> : tensor<2xi32>, tt.contiguity = dense<[16, 16]> : tensor<2xi32>}, + %memDesc1Inst: !ttg.memdesc<64x16xf16, #shared, #smem, mutable>, + %ptr1Inst: tensor<64x16x!tt.ptr, #blocked> {tt.divisibility = dense<[16, 16]> : tensor<2xi32>, tt.contiguity = dense<[16, 16]> : tensor<2xi32>}) { + + %c0_i32 = arith.constant 0 : i32 + %c1_i32 = arith.constant 1 : i32 + + // Check we are not getting stuck in loops with no async ops + ttg.async_copy_global_to_local %ptr2Inst, %memDesc2Inst : tensor<128x16x!tt.ptr, #blocked> -> <128x16xf16, #shared, #smem, mutable> + ttg.async_commit_group + %69 = scf.while (%arg10 = %cond) : (i1) -> (i1) { + // CHECK: amdgpu.async_wait {num_inst = 2 + ttg.async_wait {num = 1: i32} + scf.condition(%arg10) %arg10 : i1 + } do { + ^bb0(%arg12: i1): + // CHECK: amdgpu.async_wait {num_inst = 2 + ttg.async_wait {num = 1: i32} + scf.yield %arg12 : i1 + } + // CHECK: amdgpu.async_wait {num_inst = 2 + ttg.async_wait {num = 1: i32} + + tt.return + } + + // CHECK-LABEL: while_async_op_in_before_block + tt.func public @while_async_op_in_before_block( + %cond: i1, + %arg0: i32, + %memDesc2Inst: !ttg.memdesc<128x16xf16, #shared, #smem, mutable>, + %ptr2Inst: tensor<128x16x!tt.ptr, #blocked> {tt.divisibility = dense<[16, 16]> : tensor<2xi32>, tt.contiguity = dense<[16, 16]> : tensor<2xi32>}, + %memDesc1Inst: !ttg.memdesc<64x16xf16, #shared, #smem, mutable>, + %ptr1Inst: tensor<64x16x!tt.ptr, #blocked> {tt.divisibility = dense<[16, 16]> : tensor<2xi32>, tt.contiguity = dense<[16, 16]> : tensor<2xi32>}) { + + // Check we are following control flow and count inside the before block + ttg.async_copy_global_to_local %ptr2Inst, %memDesc2Inst : tensor<128x16x!tt.ptr, #blocked> -> <128x16xf16, #shared, #smem, mutable> + ttg.async_commit_group + ttg.async_copy_global_to_local %ptr2Inst, %memDesc2Inst : tensor<128x16x!tt.ptr, #blocked> -> <128x16xf16, #shared, #smem, mutable> + ttg.async_commit_group + ttg.async_copy_global_to_local %ptr2Inst, %memDesc2Inst : tensor<128x16x!tt.ptr, #blocked> -> <128x16xf16, #shared, #smem, mutable> + ttg.async_commit_group + // CHECK: amdgpu.async_wait {num_inst = 6 + ttg.async_wait {num = 3: i32} + + %70 = scf.while (%arg10 = %cond) : (i1) -> (i1) { + // Count before block 3 times + // CHECK: amdgpu.async_wait {num_inst = 3 + ttg.async_wait {num = 3: i32} + ttg.async_copy_global_to_local %ptr1Inst, %memDesc1Inst : tensor<64x16x!tt.ptr, #blocked> -> <64x16xf16, #shared, #smem, mutable> + ttg.async_commit_group + scf.condition(%arg10) %arg10 : i1 + } do { + ^bb0(%arg12: i1): + // Count before block 3 times + // CHECK: amdgpu.async_wait {num_inst = 3 + ttg.async_wait {num = 3: i32} + scf.yield %arg12 : i1 + } + // Count before block 3 times + // CHECK: amdgpu.async_wait {num_inst = 3 + ttg.async_wait {num = 3: i32} + + tt.return + } + + // CHECK-LABEL: while_async_op_in_after_block + tt.func public @while_async_op_in_after_block( + %cond: i1, + %arg0: i32, + %memDesc2Inst: !ttg.memdesc<128x16xf16, #shared, #smem, mutable>, + %ptr2Inst: tensor<128x16x!tt.ptr, #blocked> {tt.divisibility = dense<[16, 16]> : tensor<2xi32>, tt.contiguity = dense<[16, 16]> : tensor<2xi32>}, + %memDesc1Inst: !ttg.memdesc<64x16xf16, #shared, #smem, 
mutable>, + %ptr1Inst: tensor<64x16x!tt.ptr, #blocked> {tt.divisibility = dense<[16, 16]> : tensor<2xi32>, tt.contiguity = dense<[16, 16]> : tensor<2xi32>}) { + + // Check we are following control flow and count inside the after block + ttg.async_copy_global_to_local %ptr2Inst, %memDesc2Inst : tensor<128x16x!tt.ptr, #blocked> -> <128x16xf16, #shared, #smem, mutable> + ttg.async_commit_group + ttg.async_copy_global_to_local %ptr2Inst, %memDesc2Inst : tensor<128x16x!tt.ptr, #blocked> -> <128x16xf16, #shared, #smem, mutable> + ttg.async_commit_group + ttg.async_copy_global_to_local %ptr2Inst, %memDesc2Inst : tensor<128x16x!tt.ptr, #blocked> -> <128x16xf16, #shared, #smem, mutable> + ttg.async_commit_group + // CHECK: amdgpu.async_wait {num_inst = 6 + ttg.async_wait {num = 3: i32} + + %71 = scf.while (%arg10 = %cond) : (i1) -> (i1) { + // Count after block 3 times + // CHECK: amdgpu.async_wait {num_inst = 3 + ttg.async_wait {num = 3: i32} + scf.condition(%arg10) %arg10 : i1 + } do { + ^bb0(%arg12: i1): + ttg.async_copy_global_to_local %ptr1Inst, %memDesc1Inst : tensor<64x16x!tt.ptr, #blocked> -> <64x16xf16, #shared, #smem, mutable> + ttg.async_commit_group + // Count after block 4 times + // CHECK: amdgpu.async_wait {num_inst = 4 + ttg.async_wait {num = 4: i32} // 4 because we moved the wait after the next prefetch + scf.yield %arg12 : i1 + } + // Count after block 3 times + // CHECK: amdgpu.async_wait {num_inst = 3 + ttg.async_wait {num = 3: i32} + + tt.return + } + + //CHECK-LABEL: nested_loops_and_if + tt.func public @nested_loops_and_if( + %cond: i1, + %arg0: i32, + %memDesc2Inst: !ttg.memdesc<128x16xf16, #shared, #smem, mutable>, + %ptr2Inst: tensor<128x16x!tt.ptr, #blocked> {tt.divisibility = dense<[16, 16]> : tensor<2xi32>, tt.contiguity = dense<[16, 16]> : tensor<2xi32>}, + %memDesc1Inst: !ttg.memdesc<64x16xf16, #shared, #smem, mutable>, + %ptr1Inst: tensor<64x16x!tt.ptr, #blocked> {tt.divisibility = dense<[16, 16]> : tensor<2xi32>, tt.contiguity = dense<[16, 16]> : tensor<2xi32>}) { + + %c0_i32 = arith.constant 0 : i32 + %c1_i32 = arith.constant 1 : i32 + + ttg.async_copy_global_to_local %ptr1Inst, %memDesc1Inst : tensor<64x16x!tt.ptr, #blocked> -> <64x16xf16, #shared, #smem, mutable> + ttg.async_commit_group + ttg.async_copy_global_to_local %ptr1Inst, %memDesc1Inst : tensor<64x16x!tt.ptr, #blocked> -> <64x16xf16, #shared, #smem, mutable> + ttg.async_commit_group + ttg.async_copy_global_to_local %ptr1Inst, %memDesc1Inst : tensor<64x16x!tt.ptr, #blocked> -> <64x16xf16, #shared, #smem, mutable> + ttg.async_commit_group + ttg.async_copy_global_to_local %ptr1Inst, %memDesc1Inst : tensor<64x16x!tt.ptr, #blocked> -> <64x16xf16, #shared, #smem, mutable> + ttg.async_commit_group + ttg.async_copy_global_to_local %ptr1Inst, %memDesc1Inst : tensor<64x16x!tt.ptr, #blocked> -> <64x16xf16, #shared, #smem, mutable> + ttg.async_commit_group + ttg.async_copy_global_to_local %ptr1Inst, %memDesc1Inst : tensor<64x16x!tt.ptr, #blocked> -> <64x16xf16, #shared, #smem, mutable> + ttg.async_commit_group + // CHECK: amdgpu.async_wait {num_inst = 6 + ttg.async_wait {num = 6: i32} + + %70 = scf.while (%arg10 = %cond) : (i1) -> (i1) { + // Escape while and count prologue = 6 + // CHECK: amdgpu.async_wait {num_inst = 6 + ttg.async_wait {num = 6: i32} + ttg.async_copy_global_to_local %ptr2Inst, %memDesc2Inst : tensor<128x16x!tt.ptr, #blocked> -> <128x16xf16, #shared, #smem, mutable> + ttg.async_commit_group + // 2 Instructions + scf.condition(%arg10) %arg10 : i1 + } do { + ^bb0(%arg12: i1): + // 1 commit group in 
Before-block + 5 commits groups in prologue = 7 + // CHECK: amdgpu.async_wait {num_inst = 7 + ttg.async_wait {num = 6: i32} + ttg.async_copy_global_to_local %ptr2Inst, %memDesc2Inst : tensor<128x16x!tt.ptr, #blocked> -> <128x16xf16, #shared, #smem, mutable> + ttg.async_commit_group + // 2 Instructions + + scf.for %arg14 = %c0_i32 to %arg0 step %c1_i32 : i32 { + ttg.async_copy_global_to_local %ptr2Inst, %memDesc2Inst : tensor<128x16x!tt.ptr, #blocked> -> <128x16xf16, #shared, #smem, mutable> + // 2 Instructions + ttg.async_commit_group + // 1 commit group(2) to escape for, 1 commits group(2) in rest of while after block, 1 commit group (2) in while before block and 3 commits group in prologue = 9 + // CHECK: amdgpu.async_wait {num_inst = 9 + ttg.async_wait {num = 6: i32} + + scf.if %cond { + ttg.async_copy_global_to_local %ptr1Inst, %memDesc1Inst : tensor<64x16x!tt.ptr, #blocked> -> <64x16xf16, #shared, #smem, mutable> + ttg.async_copy_global_to_local %ptr2Inst, %memDesc2Inst : tensor<128x16x!tt.ptr, #blocked> -> <128x16xf16, #shared, #smem, mutable> + + // Same as above but we also have to count the 2 async_copies above = 9+3 + // CHECK: amdgpu.async_wait {num_inst = 12 + ttg.async_wait {num = 6: i32} + } else { + ttg.async_copy_global_to_local %ptr2Inst, %memDesc2Inst : tensor<128x16x!tt.ptr, #blocked> -> <128x16xf16, #shared, #smem, mutable> + } + // 2 Instructions (else) + ttg.async_commit_group + + scf.if %cond { + ttg.async_copy_global_to_local %ptr1Inst, %memDesc1Inst : tensor<64x16x!tt.ptr, #blocked> -> <64x16xf16, #shared, #smem, mutable> + ttg.async_copy_global_to_local %ptr2Inst, %memDesc2Inst : tensor<128x16x!tt.ptr, #blocked> -> <128x16xf16, #shared, #smem, mutable> + // 3 Instructions + ttg.async_commit_group + // 1 commit group (3) in this block, 2 commits group in the rest of the for body (2+2), 1 commits group(2) in rest of while after block, 1 commit group (2) in while before block, 1 commit group (1) in epilogue = 12 + // CHECK: amdgpu.async_wait {num_inst = 12 + ttg.async_wait {num = 6: i32} + } + // Same as above but skips the if (first commit group(3)) and instead counts one more in the prologue (1) = 10 + // CHECK: amdgpu.async_wait {num_inst = 10 + ttg.async_wait {num = 6: i32} + scf.for %arg15 = %c0_i32 to %arg0 step %c1_i32 : i32 { + ttg.async_copy_global_to_local %ptr1Inst, %memDesc1Inst : tensor<64x16x!tt.ptr, #blocked> -> <64x16xf16, #shared, #smem, mutable> + // 1 Instruction + ttg.async_commit_group + ttg.async_copy_global_to_local %ptr2Inst, %memDesc2Inst : tensor<128x16x!tt.ptr, #blocked> -> <128x16xf16, #shared, #smem, mutable> + // 2 Instructions + ttg.async_commit_group + // Just staying in the loop is the lowest path (3 per iteration and we do 3 iterations) + // CHECK: amdgpu.async_wait {num_inst = 9 + ttg.async_wait {num = 6: i32} + scf.yield + } + // Just stay in the inner loop for the lowest path + // CHECK: amdgpu.async_wait {num_inst = 9 + ttg.async_wait {num = 6: i32} + scf.yield + } + scf.yield %arg12 : i1 + } + // While before-body (2) + 5 prologue groups = 7 + // CHECK: amdgpu.async_wait {num_inst = 7 + ttg.async_wait {num = 6: i32} + + tt.return + } + +} diff --git a/test/TritonGPU/amd/amd-update-async-wait-count.mlir b/test/TritonGPU/amd/amd-update-async-wait-count.mlir index c032ca03eb..40771d388b 100644 --- a/test/TritonGPU/amd/amd-update-async-wait-count.mlir +++ b/test/TritonGPU/amd/amd-update-async-wait-count.mlir @@ -9,7 +9,7 @@ #smem = #ttg.shared_memory module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target 
= "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} { // CHECK-LABEL: simple_waitcnt - tt.func public @simple_waitcnt(%arg1: !ttg.memdesc<128x16xf16, #shared, #smem, mutable>, %arg2: !ttg.memdesc<16x256xf16, #shared1, #smem, mutable>, %arg3: tensor<128x16x!tt.ptr, #blocked>, %arg4: tensor<16x256x!tt.ptr, #blocked1>) { + tt.func public @simple_waitcnt(%arg1: !ttg.memdesc<128x16xf16, #shared, #smem, mutable>, %arg2: !ttg.memdesc<16x256xf16, #shared1, #smem, mutable>, %arg3: tensor<128x16x!tt.ptr, #blocked> {tt.divisibility = dense<[16, 16]> : tensor<2xi32>, tt.contiguity = dense<[16, 16]> : tensor<2xi32>}, %arg4: tensor<16x256x!tt.ptr, #blocked1> {tt.divisibility = dense<[16, 16]> : tensor<2xi32>, tt.contiguity = dense<[16, 16]> : tensor<2xi32>}) { // Emits 1 direct to lds instruction %0 = ttg.async_copy_global_to_local %arg3, %arg1 : tensor<128x16x!tt.ptr, #blocked> -> <128x16xf16, #shared, #smem, mutable> %1 = ttg.async_commit_group tokens %0 @@ -38,7 +38,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ #smem = #ttg.shared_memory module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} { // CHECK-LABEL: simple_waitcnt_reversed - tt.func public @simple_waitcnt_reversed(%arg1: !ttg.memdesc<128x16xf16, #shared, #smem, mutable>, %arg2: !ttg.memdesc<16x256xf16, #shared1, #smem, mutable>, %arg3: tensor<128x16x!tt.ptr, #blocked>, %arg4: tensor<16x256x!tt.ptr, #blocked1>) { + tt.func public @simple_waitcnt_reversed(%arg1: !ttg.memdesc<128x16xf16, #shared, #smem, mutable>, %arg2: !ttg.memdesc<16x256xf16, #shared1, #smem, mutable>, %arg3: tensor<128x16x!tt.ptr, #blocked> {tt.divisibility = dense<[16, 16]> : tensor<2xi32>, tt.contiguity = dense<[16, 16]> : tensor<2xi32>}, %arg4: tensor<16x256x!tt.ptr, #blocked1> {tt.divisibility = dense<[16, 16]> : tensor<2xi32>, tt.contiguity = dense<[16, 16]> : tensor<2xi32>}) { // Emits 1 direct to lds instruction %0 = ttg.async_copy_global_to_local %arg3, %arg1 : tensor<128x16x!tt.ptr, #blocked> -> <128x16xf16, #shared, #smem, mutable> %1 = ttg.async_commit_group tokens %0 @@ -67,7 +67,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ #smem = #ttg.shared_memory module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} { // CHECK-LABEL: simple_waitcnt_with_tt_load - tt.func public @simple_waitcnt_with_tt_load(%arg1: !ttg.memdesc<128x16xf16, #shared, #smem, mutable>, %arg2: !ttg.memdesc<16x256xf16, #shared1, #smem, mutable>, %arg3: tensor<128x16x!tt.ptr, #blocked>, %arg4: tensor<16x256x!tt.ptr, #blocked1>) { + tt.func public @simple_waitcnt_with_tt_load(%arg1: !ttg.memdesc<128x16xf16, #shared, #smem, mutable>, %arg2: !ttg.memdesc<16x256xf16, #shared1, #smem, mutable>, %arg3: tensor<128x16x!tt.ptr, #blocked> {tt.divisibility = dense<[16, 16]> : tensor<2xi32>, tt.contiguity = dense<[16, 16]> : tensor<2xi32>}, %arg4: tensor<16x256x!tt.ptr, #blocked1> {tt.divisibility = dense<[16, 16]> : tensor<2xi32>, tt.contiguity = dense<[16, 16]> : tensor<2xi32>}) { // Emits 1 direct to lds instruction %0 = ttg.async_copy_global_to_local %arg3, %arg1 : tensor<128x16x!tt.ptr, #blocked> -> <128x16xf16, #shared, #smem, mutable> %1 = ttg.async_commit_group tokens %0 @@ -96,7 +96,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ #smem = #ttg.shared_memory module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, 
ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} { // CHECK-LABEL wait_in_for_loop - tt.func public @wait_in_for_loop(%arg0: i32, %arg1: !ttg.memdesc<128x16xf16, #shared, #smem, mutable>, %arg2: !ttg.memdesc<16x256xf16, #shared1, #smem, mutable>, %arg3: tensor<128x16x!tt.ptr, #blocked>, %arg4: tensor<16x256x!tt.ptr, #blocked1>) { + tt.func public @wait_in_for_loop(%arg0: i32, %arg1: !ttg.memdesc<128x16xf16, #shared, #smem, mutable>, %arg2: !ttg.memdesc<16x256xf16, #shared1, #smem, mutable>, %arg3: tensor<128x16x!tt.ptr, #blocked> {tt.divisibility = dense<[16, 16]> : tensor<2xi32>, tt.contiguity = dense<[16, 16]> : tensor<2xi32>}, %arg4: tensor<16x256x!tt.ptr, #blocked1> {tt.divisibility = dense<[16, 16]> : tensor<2xi32>, tt.contiguity = dense<[16, 16]> : tensor<2xi32>}) { %c0_i32 = arith.constant 0 : i32 %c1_i32 = arith.constant 1 : i32 // Emits 1 direct to lds instruction @@ -131,7 +131,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ #smem = #ttg.shared_memory module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} { // CHECK-LABEL double_buffering - tt.func public @double_buffering(%arg0: i32, %arg1: !ttg.memdesc<128x16xf16, #shared, #smem, mutable>, %arg2: !ttg.memdesc<16x256xf16, #shared1, #smem, mutable>, %arg3: tensor<128x16x!tt.ptr, #blocked>, %arg4: tensor<16x256x!tt.ptr, #blocked1>) { + tt.func public @double_buffering(%arg0: i32, %arg1: !ttg.memdesc<128x16xf16, #shared, #smem, mutable>, %arg2: !ttg.memdesc<16x256xf16, #shared1, #smem, mutable>, %arg3: tensor<128x16x!tt.ptr, #blocked> {tt.divisibility = dense<[16, 16]> : tensor<2xi32>, tt.contiguity = dense<[16, 16]> : tensor<2xi32>}, %arg4: tensor<16x256x!tt.ptr, #blocked1> {tt.divisibility = dense<[16, 16]> : tensor<2xi32>, tt.contiguity = dense<[16, 16]> : tensor<2xi32>}) { %c0_i32 = arith.constant 0 : i32 %c1_i32 = arith.constant 1 : i32 // Emits 1 direct to lds instruction @@ -169,7 +169,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ #smem = #ttg.shared_memory module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} { // CHECK-LABEL: double_buffering_wait_in_if - tt.func public @double_buffering_wait_in_if(%cond: i1, %arg0: i32, %arg1: !ttg.memdesc<128x16xf16, #shared, #smem, mutable>, %arg2: !ttg.memdesc<16x256xf16, #shared1, #smem, mutable>, %arg3: tensor<128x16x!tt.ptr, #blocked>, %arg4: tensor<16x256x!tt.ptr, #blocked1>) { + tt.func public @double_buffering_wait_in_if(%cond: i1, %arg0: i32, %arg1: !ttg.memdesc<128x16xf16, #shared, #smem, mutable>, %arg2: !ttg.memdesc<16x256xf16, #shared1, #smem, mutable>, %arg3: tensor<128x16x!tt.ptr, #blocked> {tt.divisibility = dense<[16, 16]> : tensor<2xi32>, tt.contiguity = dense<[16, 16]> : tensor<2xi32>}, %arg4: tensor<16x256x!tt.ptr, #blocked1> {tt.divisibility = dense<[16, 16]> : tensor<2xi32>, tt.contiguity = dense<[16, 16]> : tensor<2xi32>}) { %c0_i32 = arith.constant 0 : i32 %c1_i32 = arith.constant 1 : i32 // Emits 1 direct to lds instruction @@ -217,7 +217,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ #smem = #ttg.shared_memory module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} { // CHECK-LABEL: doube_buffering_wait_loads_in_if - tt.func public @doube_buffering_wait_loads_in_if(%cond: i1, %arg0: i32, %arg1: 
!ttg.memdesc<128x16xf16, #shared, #smem, mutable>, %arg2: !ttg.memdesc<16x256xf16, #shared1, #smem, mutable>, %arg3: tensor<128x16x!tt.ptr, #blocked>, %arg4: tensor<16x256x!tt.ptr, #blocked1>) { + tt.func public @doube_buffering_wait_loads_in_if(%cond: i1, %arg0: i32, %arg1: !ttg.memdesc<128x16xf16, #shared, #smem, mutable>, %arg2: !ttg.memdesc<16x256xf16, #shared1, #smem, mutable>, %arg3: tensor<128x16x!tt.ptr, #blocked> {tt.divisibility = dense<[16, 16]> : tensor<2xi32>, tt.contiguity = dense<[16, 16]> : tensor<2xi32>}, %arg4: tensor<16x256x!tt.ptr, #blocked1> {tt.divisibility = dense<[16, 16]> : tensor<2xi32>, tt.contiguity = dense<[16, 16]> : tensor<2xi32>}) { %c0_i32 = arith.constant 0 : i32 %c1_i32 = arith.constant 1 : i32 // Emits 1 direct to lds instruction @@ -264,7 +264,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ #smem = #ttg.shared_memory module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} { // CHECK-LABEL: double_buffering_uneven_then_else - tt.func public @double_buffering_uneven_then_else(%cond: i1, %arg0: i32, %arg1: !ttg.memdesc<128x16xf16, #shared, #smem, mutable>, %arg2: !ttg.memdesc<16x256xf16, #shared1, #smem, mutable>, %arg3: tensor<128x16x!tt.ptr, #blocked>, %arg4: tensor<16x256x!tt.ptr, #blocked1>) { + tt.func public @double_buffering_uneven_then_else(%cond: i1, %arg0: i32, %arg1: !ttg.memdesc<128x16xf16, #shared, #smem, mutable>, %arg2: !ttg.memdesc<16x256xf16, #shared1, #smem, mutable>, %arg3: tensor<128x16x!tt.ptr, #blocked> {tt.divisibility = dense<[16, 16]> : tensor<2xi32>, tt.contiguity = dense<[16, 16]> : tensor<2xi32>}, %arg4: tensor<16x256x!tt.ptr, #blocked1> {tt.divisibility = dense<[16, 16]> : tensor<2xi32>, tt.contiguity = dense<[16, 16]> : tensor<2xi32>}) { %c0_i32 = arith.constant 0 : i32 %c1_i32 = arith.constant 1 : i32 // Emits 1 direct to lds instruction @@ -311,7 +311,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ #smem = #ttg.shared_memory module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} { // CHECK-LABEL: dynamic_loop_in_def_chain - tt.func public @dynamic_loop_in_def_chain(%arg0: i32, %arg1: !ttg.memdesc<128x16xf16, #shared, #smem, mutable>, %arg3: tensor<128x16x!tt.ptr, #blocked>) { + tt.func public @dynamic_loop_in_def_chain(%arg0: i32, %arg1: !ttg.memdesc<128x16xf16, #shared, #smem, mutable>, %arg3: tensor<128x16x!tt.ptr, #blocked> {tt.divisibility = dense<[16, 16]> : tensor<2xi32>, tt.contiguity = dense<[16, 16]> : tensor<2xi32>}) { %c0_i32 = arith.constant 0 : i32 %c1_i32 = arith.constant 1 : i32 %c4_i32 = arith.constant 4 : i32 @@ -345,7 +345,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ #smem = #ttg.shared_memory module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} { // CHECK-LABEL: constant_loop_in_def_chain - tt.func public @constant_loop_in_def_chain(%arg1: !ttg.memdesc<128x16xf16, #shared, #smem, mutable>, %arg3: tensor<128x16x!tt.ptr, #blocked>) { + tt.func public @constant_loop_in_def_chain(%arg1: !ttg.memdesc<128x16xf16, #shared, #smem, mutable>, %arg3: tensor<128x16x!tt.ptr, #blocked> {tt.divisibility = dense<[16, 16]> : tensor<2xi32>, tt.contiguity = dense<[16, 16]> : tensor<2xi32>}) { %c0_i32 = arith.constant 0 : i32 %c1_i32 = arith.constant 1 : i32 %c4_i32 = 
arith.constant 4 : i32 @@ -379,7 +379,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ #smem = #ttg.shared_memory module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} { // CHECK-LABEL: mix_async_copy_and_async_tdm_copy - tt.func public @mix_async_copy_and_async_tdm_copy(%memDesc: !ttg.memdesc<128x16xf16, #shared, #smem, mutable>, %tensorDesc: !tt.tensordesc>, %mask: i1, %ptr: tensor<128x16x!tt.ptr, #blocked> + tt.func public @mix_async_copy_and_async_tdm_copy(%memDesc: !ttg.memdesc<128x16xf16, #shared, #smem, mutable>, %tensorDesc: !tt.tensordesc>, %mask: i1, %ptr: tensor<128x16x!tt.ptr, #blocked> {tt.divisibility = dense<[16, 16]> : tensor<2xi32>, tt.contiguity = dense<[16, 16]> : tensor<2xi32>} ) { %c0_i32 = arith.constant 0 : i32 diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/UpdateAsyncWaitCount.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/UpdateAsyncWaitCount.cpp index fa39cb4dbd..e44b6c7168 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/UpdateAsyncWaitCount.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/UpdateAsyncWaitCount.cpp @@ -1,22 +1,44 @@ +#include "Dialect/TritonAMDGPU/IR/Dialect.h" #include "TritonAMDGPUTransforms/Passes.h" #include "amd/lib/TritonAMDGPUToLLVM/Utility.h" #include "amd/lib/TritonAMDGPUTransforms/Utility.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" +#include "triton/Analysis/AxisInfo.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include -// This pass updates the waitCount of `AsyncWait` Ops to represent the number of -// inflight async load operation between the async_wait and the definition of -// the AsyncToken, thus allowing to wait only on the dependent async loads -// allowing loads issued after to complete in the future. -// This also means we should never overestimate the value to ensure -// correctness; being conservative and underestimating is fine given that only -// affects performance -// For each async_wait we need to compute the minimum across all AsyncToken -// operands. -// For each token the minimum number of async transaction along it's -// def chain is deduced. A token can be copied when passing in as loop initial -// argument and yielded from a loop body in which case we need to take the -// minimum along both paths. -// We do not exit early if we encounter another async_wait along the def chain -// because the pipeliner will merge redundant waits for us already +// This pass computes, for each AsyncWait, the number of outstanding async +// intrinsics that must be waited on. An AsyncWait can specify its wait target +// either via AsyncToken operands or via an explicit count (num) of outstanding +// async operations, with tokens taking precedence. To preserve correctness, the +// pass must never overestimate the wait count; underestimation only impacts +// performance by waiting more conservatively. The wait count represents the +// number of hardware instructions/intrinsics corresponding to the outstanding +// async operations. For waits that carry async tokens, the pass walks the +// def-use chains of each token and sums the number of async intrinsics +// oustanding excluding the producer of the async token. Tokens may be copied +// across loop boundaries (e.g., passed as loop initial arguments and yielded +// from the loop body); in such cases, the pass takes the minimum count across +// the possible paths. 
The final wait count is the minimum over all tokens and +// their paths. For waits without tokens the count represent the number of +// outstanding ttg.async_commit_groups (inclusive). The pass scans the IR +// backward to find the specified num async commit groups and computes the +// number of outstanding async intrinsics from async operations. Note that we +// walk until we find n+1 commit groups to include all async ops of the n'th +// commit group. Again, when multiple paths are possible, the pass takes the +// minimum count across all paths needed to reach num async operations. For +// ttg.async_wait we count: +// - On GFX9 the number of direct-to-lds instructions. We ignore loads to +// registers since we do not control the vectorization (llvm can change it). +// Therefore interleaving direct-to-lds and loads to registers will produce +// conservative waits. +// - On GFX1250 the number of (multicast) async_load and async_stores. On +// GFX1250 those are out of order with register loads so we will not get +// conservative waits. +// For amdgpu.tdm_async_wait we only count TDM ops. Each tdm_load/store will +// produce exactly one instruction so it directly correlates with OP at TGGIR +// level. namespace tt = triton; namespace ttg = triton::gpu; @@ -30,10 +52,12 @@ namespace { // Returns the number of individual async load memory transactions when copy // data from the given |srcTy| in global memory to the given |dstTy| in shared -// memory. -int getNumberOfLoadInstructions(RankedTensorType srcTy, - ttg::MemDescType dstTy) { - LinearLayout srcLayout = tt::gpu::toLinearLayout(srcTy); +// memory. This takes into account the mask and ptrs alignment and contiguoutiy +// as well as the layouts mapping from global to shared memory addresses +int getNumberOfLoadInstructions(TypedValue ptrs, + ttg::MemDescType dstTy, Value mask, + ModuleAxisInfoAnalysis &axisInfo) { + LinearLayout srcLayout = tt::gpu::toLinearLayout(ptrs.getType()); LinearLayout sharedLayout; if (auto paddedEnc = dyn_cast( dstTy.getEncoding())) { @@ -47,71 +71,223 @@ int getNumberOfLoadInstructions(RankedTensorType srcTy, // need coalesced writes. So we can divide the number of registers by the // contiguity to get the number of load instructions. int contig = srcToSharedLayout.getNumConsecutiveInOut(); + + // Further restrict by contiguity information for ptr and mask + auto order = tt::gpu::getOrder(ptrs.getType()); + auto *ptrInfo = axisInfo.getAxisInfo(ptrs); + contig = std::min(contig, LLVM::AMD::getVectorSize(ptrs, axisInfo)); + if (mask) + contig = std::min(contig, axisInfo.getMaskAlignment(mask)); + int numberOfRegisters = srcToSharedLayout.getInDimSize( - StringAttr::get(srcTy.getContext(), "register")); + StringAttr::get(ptrs.getContext(), "register")); int loadInstructionCount = std::max(1, numberOfRegisters / contig); return loadInstructionCount; } -// The pipeliner always insert ops following an order of ttg.async_load -> -// [token] -> ttg.async_commit_group -> [token] -> ttg.async_wait. So here we -// scan the operands of ttg.async_commit_group to count the number of issued -// async load intrinsics. 
-int getNumOfAsyncLoadInstructionsForOp(Operation *op, +// Return the number of generated intrinsics for async ops; 0 otherwise +// If emitRemarkOnNonAsyncOp is set for any non async op having a side effect on +// GlobalMemory an performance remark will be emitted +int getOpNumberOfAsyncLoadInstructions(Operation *op, + AMD::TargetInfo targetInfo, + ModuleAxisInfoAnalysis &axisInfo, bool emitRemarkOnNonAsyncOp) { - if (isa(op)) { - int count = 0; - for (auto token : op->getOperands()) { - auto defOp = token.getDefiningOp(); - if (!defOp) - continue; - if (auto copyOp = llvm::dyn_cast(defOp)) { - count += getNumberOfLoadInstructions(copyOp.getSrc().getType(), - copyOp.getResult().getType()); - } else if (auto copyOp = - llvm::dyn_cast(defOp)) { - auto srcTy = cast(LLVM::AMD::getPointerTypeWithShape( - copyOp.getPtr(), copyOp.getOffsets())); - count += getNumberOfLoadInstructions(srcTy, copyOp.getDest().getType()); - } + if (auto copyOp = dyn_cast(op)) { + return getNumberOfLoadInstructions(copyOp.getSrc(), + copyOp.getResult().getType(), + copyOp.getMask(), axisInfo); + } else if (emitRemarkOnNonAsyncOp) { + SmallVector effects; + if (auto memEffectIface = dyn_cast(op)) + memEffectIface.getEffectsOnResource(triton::GlobalMemory::get(), effects); + if (!effects.empty()) { + op->emitRemark("Global memory operation between async wait and " + "async_loads. This will hinder the interleaving of memory " + "operations and might impact performance."); } - return count; - } - if (emitRemarkOnNonAsyncOp && - isa(op)) { - op->emitRemark("Global memory operation between async wait and " - "async_loads. This will hinder the interleaving of memory " - "operations and might impact performance."); } return 0; } -// LLVM cannot infer the dependency between direct to lds (async) loads and -// the local reads between warps in a workgroup. As a workaround we update the -// waitcnt to represent the number of hardware instructions we are -// interleaving with. This allows us to manually emit the waitcnt during -// lowering. +// Walks the IR backwards and accumulates countFunc(op) until we find +// numOustanding ops returning a non zero value. For control flow all possible +// paths are walked in a recursive DFS way and the minimum number found along +// all paths is returned. For unsupported ops with subregions it will return a +// conservative wait count to avoid incorrect waits. Parameters: +// - `cursor`: the operation we walk backwards from +// - `cameFrom`: tracks the operation we most recently stepped from as we +// walk backwards, so we can disambiguate how to traverse multi-block ops +// - `numOutstanding`: remaining countFunc(op) > 0 to visit before acc stops +// - `pathSum`: accumulated result along the current path +// - `bestPath`: current found minimum when reaching numOutstanding or start of +// the kernel +// - `branchStateCache`: memoization cache to stop walking multi blocks +// ops already visited with the same number of outstanding ops. This +// prevents infinite recursion depths for loops without ops contributing +// - `countFunc`: called on ops to determine if they contribute to the pathSum +// TODO: walk static loops correctly to avoid conservative loops. 
(static loops +// from Gluon are unrolled right now) +using MemoCache = llvm::DenseSet>; +int computeMinCountBackward(Operation *cursor, Operation *cameFrom, + int numOutstanding, int pathSum, int bestPath, + MemoCache &branchStateCache, + llvm::function_ref countFunc) { + assert(cameFrom != nullptr); + // Step to the previous op within the current block; if none, step to + // the parent op. Stop at the module since it asserts on ->getPrevNode(). + auto getPredecessor = [&cameFrom](Operation *op) { + auto prevOp = op->getPrevNode(); + if (!prevOp) { + prevOp = op->getParentOp(); + if (isa(prevOp)) { + prevOp = nullptr; + } + } + + return prevOp; + }; + + // Continues the walk and updates bestPath to stop exploration early for paths + // leading to a higher sum; repeated calls will return monotonically + // decreasing values + auto continueWalkFrom = [&](Operation *newCursor) { + auto pathResult = + computeMinCountBackward(newCursor, cursor, numOutstanding, pathSum, + bestPath, branchStateCache, countFunc); + bestPath = std::min(bestPath, pathResult); + return pathResult; + }; + + // Walk backwards through the IR + while (cursor) { + // numOutstanding is inclusive so we have to walk until < 0 to include the + // async ops from the last outstanding commit group. Also prune path if the + // current path cannot beat the known minimum. + if (numOutstanding < 0 || pathSum >= bestPath) { + return std::min(bestPath, pathSum); + } + + // Handle operations with subregions. + if (auto ifOp = dyn_cast(cursor)) { + // Traversal depends on where we came from: + // If cameFrom is the successor of the ifOp, we walk the then and else + // blocks. If there is no else block we continue upwards instead since we + // could skip the if in case the condition is false. + // If cameFrom is from then/else regions continue upwards + bool cameFromThenOrElse = cameFrom->getParentOp() == ifOp; + if (cameFromThenOrElse) { + continueWalkFrom(getPredecessor(ifOp)); + } else { + continueWalkFrom(ifOp.getThenRegion().front().getTerminator()); + if (!ifOp.getElseRegion().empty()) { + continueWalkFrom(ifOp.getElseRegion().front().getTerminator()); + } else { + continueWalkFrom(getPredecessor(ifOp)); + } + } + return bestPath; + } else if (auto forOp = dyn_cast(cursor)) { + // We walk upwards (skip/escape for body) and walk the body + continueWalkFrom(getPredecessor(forOp)); + + // If we came from the body only walk it again if it's not in the cache + auto cameFromBody = cameFrom->getBlock() == forOp.getBody(); + auto cacheKey = std::make_tuple(cursor, numOutstanding, pathSum); + if (!cameFromBody || branchStateCache.insert(cacheKey).second) { + continueWalkFrom(forOp.getBody()->getTerminator()); + } + return bestPath; + } else if (auto whileOp = dyn_cast(cursor)) { + // Traversal depends on which region we came from: + // - Came from successor -> before-body + // - Came from before-body -> after-body and upwards + // - Came from after-body -> before-body. 
+ Block *lastBlock = cameFrom->getBlock(); + bool cameFromBefore = lastBlock == whileOp.getBeforeBody(); + bool cameFromAfter = lastBlock == whileOp.getAfterBody(); + bool cameFromSuccessor = !cameFromAfter && !cameFromBefore; + + if (cameFromAfter || cameFromSuccessor) { + // Walk before body + continueWalkFrom(whileOp.getBeforeBody()->getTerminator()); + } else if (cameFromBefore) { + // Walk upwards + continueWalkFrom(getPredecessor(whileOp)); + // Do not walk the after-block if we already visited it with a lower + // num outstanding because we already walked an identical path + auto cacheKey = std::make_tuple(cursor, numOutstanding, pathSum); + if (branchStateCache.insert(cacheKey).second) + continueWalkFrom(whileOp.getAfterBody()->getTerminator()); + } + return bestPath; + } else if (isa(cursor)) { + // Reached function boundary; return current sum (conservative) + return std::min(bestPath, pathSum); + } else if (cursor->getNumRegions() > 0 && !isa(cursor)) { + // For unhandled ops with subregions we conservatively bail out. + // We ignore triton.reduce because it cannot contain async ops + cursor->emitRemark( + "has subregions but is not analyzed when determining async " + "wait count; this yields conservative waits"); + return 0; + } + + // Non-control-flow ops: keep walking and accumulate via countFunc + pathSum += countFunc(cursor); + if (isa(cursor)) { + numOutstanding--; + } + + cameFrom = cursor; + cursor = getPredecessor(cursor); + } + // No more ops or parents to traverse; return the accumulated count. + return std::min(pathSum, bestPath); +} + +// Overload for ease of use with AsyncWait, see documentation above +int computeMinCountBackward(ttg::AsyncWaitOp waitOp, + llvm::function_ref countFunc) { + MemoCache memoCache; + return computeMinCountBackward(waitOp, waitOp, waitOp.getNum(), 0, + std::numeric_limits::max(), memoCache, + countFunc); +} + +// Follows the tokens of waitOp or walks the IR backwards from waitOp and +// modifies the waitCnt in place based on the accumulated result of +// computeCountForOp on interleaved instructions. See the file header for more +// details. template void updateWaitCount(WaitType waitOp, llvm::function_ref computeCountForOp, RewriterBase &rewriter) { int waitCnt = std::numeric_limits::max(); - // AsyncWait can await multiple tokens so we get the minimum from all - // tokens - for (auto token : waitOp.getOperands()) { - // Traverse def chain from waitOp to the producer of the token and count - // the minumum number of vmcnt instructions - auto tokenWaitCnt = - deduceMinCountOnDefChain(token, waitOp, computeCountForOp); - waitCnt = std::min(waitCnt, tokenWaitCnt); + if (waitOp.getNumOperands() > 0) { + // AsyncWait can await multiple tokens so we get the minimum from all + // tokens + for (auto token : waitOp.getOperands()) { + // Traverse def chain from waitOp to the producer of the token and count + // the minumum number of vmcnt instructions + auto tokenWaitCnt = + deduceMinCountOnDefChain(token, waitOp, computeCountForOp); + waitCnt = std::min(waitCnt, tokenWaitCnt); + } + } else { + // For AsyncWait we have to count the actual intrinsics instead of + // ttgir ops. For TDM wait this is not required as each tdm load will emit + // exactly one tensor load so we can keep the count. 
+ if constexpr (std::is_same_v) { + waitCnt = computeMinCountBackward(waitOp, computeCountForOp); + } else { + waitCnt = waitOp.getNum(); + } } if (waitCnt == std::numeric_limits::max()) { - // TODO(alex): set to conservative waitcnt=0 after gluon refactoring - waitCnt = waitOp.getNum(); + // Could not determine wait count, emit conservative waitCnt=0 + waitCnt = 0; } if (std::is_same_v) { @@ -158,19 +334,29 @@ struct TritonAMDGPUUpdateAsyncWaitCountPass ModuleOp m = getOperation(); + // ttg.async_wait should only count async non-TDM loads: SmallVector waitOps; getOperation()->walk( [&](ttg::AsyncWaitOp waitOp) { waitOps.push_back(waitOp); }); + ModuleAxisInfoAnalysis axisInfo(m); + // Cache #intrinsics per async op to avoid expensive recomputations + DenseMap intrinsicCountCache; + auto countAsyncLoadInstructions = [&](Operation *op) { + auto found = intrinsicCountCache.find(op); + if (found != intrinsicCountCache.end()) { + return found->second; + } + auto v = getOpNumberOfAsyncLoadInstructions(op, targetInfo, axisInfo, + !supportsAsyncLoads); + intrinsicCountCache[op] = v; + return v; + }; + // Note: AsyncWaits should ignore TDM ops; different HW counter for (auto waitOp : waitOps) { IRRewriter builder(waitOp->getContext()); - updateWaitCount( - waitOp, - [&](Operation *op) { - return getNumOfAsyncLoadInstructionsForOp(op, !supportsAsyncLoads); - }, - builder); + updateWaitCount(waitOp, countAsyncLoadInstructions, builder); } } }; From de8e71503fea971dfb65308147798657e18f8568 Mon Sep 17 00:00:00 2001 From: Zeng Wu Date: Fri, 31 Oct 2025 23:32:39 -0700 Subject: [PATCH 4/7] [AMD] Optimize gfx9 wave id code generation (#8601) On GFX9, this PR lifts the computation of `wave_id` to the entry of the function and additionally emits `llvm.amdgcn.readfirstlane`. This gives us optimized code generation inside the loop.
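For illustration only (not part of the change itself), a simplified before/after sketch of the generated IR; value names and the explicit wave-size constant are illustrative, the exact sequence is what the updated lit test in this patch checks:

  // before: wave_id is derived from the per-lane thread id inside the loop,
  // so it stays in a vector register and is re-emitted every iteration
  %tid  = rocdl.workitem.id.x : i32
  %wave = llvm.udiv %tid, %c64 : i32

  // after: computed once in the function entry block and pinned to a scalar
  // register via readfirstlane, so loop bodies simply reuse %wave_s
  %wave_s = llvm.call_intrinsic "llvm.amdgcn.readfirstlane"(%wave) : (i32) -> i32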
--- .../amd/buffer_load_to_local_to_llvm.mlir | 45 ++++++++++++++++--- .../TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp | 43 ++++++++++++++++-- 2 files changed, 79 insertions(+), 9 deletions(-) diff --git a/test/Conversion/amd/buffer_load_to_local_to_llvm.mlir b/test/Conversion/amd/buffer_load_to_local_to_llvm.mlir index 3696bf6801..760501017e 100644 --- a/test/Conversion/amd/buffer_load_to_local_to_llvm.mlir +++ b/test/Conversion/amd/buffer_load_to_local_to_llvm.mlir @@ -187,11 +187,11 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.thr %arg2: !ttg.memdesc<64xf32, #shared, #smem, mutable>) { %0 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #blocked> // The first constant 0 skips the LDS offset which is also 0 - // COMMON: llvm.getelementptr - // COMMON: llvm.mlir.constant(0 : i32) : i32 - // COMMON: %[[aux_ca:.*]] = llvm.mlir.constant(0 : i32) : i32 - // COMMON: llvm.mlir.constant(0 : i32) : i32 - // COMMON: rocdl.raw.ptr.buffer.load.lds {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[aux_ca]] + // COMMON: %[[VOFFSET:.*]] = llvm.select + // COMMON-NEXT: %[[IMM0:.*]] = llvm.mlir.constant(0 : i32) : i32 + // COMMON-NEXT: %[[aux_ca:.*]] = llvm.mlir.constant(0 : i32) : i32 + // COMMON-NEXT: %[[IMM1:.*]] = llvm.mlir.constant(0 : i32) : i32 + // COMMON-NEXT: rocdl.raw.ptr.buffer.load.lds {{.*}}, {{.*}}, {{.*}}, %[[VOFFSET]], %[[IMM1]], %[[IMM0]], %[[aux_ca]] %1 = amdgpu.buffer_load_to_local %arg0[%0] cacheModifier = ca into %arg2: [tensor<64xi32, #blocked>] -> <64xf32, #shared, #smem, mutable> // COMMON: llvm.getelementptr // COMMON: %[[aux_cg:.*]] = llvm.mlir.constant(3 : i32) : i32 @@ -328,3 +328,38 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, ttg.sha tt.return } } + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [64], warpsPerCTA = [1], order = [0]}> +#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 64 : i32} { + // COMMON-LABEL: buffer_load_to_local_wave_id + tt.func public @buffer_load_to_local_wave_id(%arg0: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, + %arg2: !ttg.memdesc<64xf32, #shared, #smem, mutable>, %arg3: i32) { + // COMMON: %0 = rocdl.workitem.id.x : i32 + // COMMON-NEXT: %1 = llvm.mlir.constant(63 : i32) : i32 + // COMMON-NEXT: %2 = llvm.and %0, %1 : i32 + // COMMON-NEXT: %3 = llvm.mlir.constant(64 : i32) : i32 + // COMMON-NEXT: %4 = llvm.mlir.constant(0 : i32) : i32 + // COMMON-NEXT: %5 = llvm.call_intrinsic "llvm.amdgcn.readfirstlane"(%4) : (i32) -> i32 + // COMMON-NEXT: %6 = rocdl.workitem.id.x : i32 + // COMMON-NEXT: %7 = llvm.mlir.constant(63 : i32) : i32 + // COMMON-NEXT: %8 = llvm.and %6, %7 : i32 + // COMMON-NEXT: %9 = llvm.mlir.constant(64 : i32) : i32 + // COMMON-NEXT: %10 = llvm.mlir.constant(0 : i32) : i32 + // COMMON-NEXT: %11 = llvm.call_intrinsic "llvm.amdgcn.readfirstlane"(%10) : (i32) -> i32 + + %0 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #blocked> + %1 = amdgpu.buffer_load_to_local %arg0[%0] into %arg2: [tensor<64xi32, #blocked>] -> <64xf32, #shared, #smem, mutable> + %c0_i32 = arith.constant 0 : i32 + %cond = llvm.icmp "eq" %arg3, %c0_i32 : i32 + cf.cond_br %cond, ^bb1, ^bb2 + ^bb1: + amdgpu.buffer_load_to_local %arg0[%0] into %arg2: [tensor<64xi32, #blocked>] -> <64xf32, #shared, #smem, mutable> + cf.br ^bb1 + ^bb2: + tt.return + } +} diff 
--git a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp index b953188fe9..c6d3bbc4a0 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -482,7 +482,7 @@ struct DirectToLdsLoadConversionBase : public LoadStoreConversionBase { void lowerDirectToLDSLoad( RewriterBase &rewriter, Location loc, RankedTensorType srcTy, MemDescType dstTy, SmallVector loadVals, Value llDst, - Type resElemTy, unsigned vec, + Type resElemTy, unsigned vec, triton::AMD::ISAFamily isaFamily, std::function(RewriterBase &, Location, ArrayRef, Value, int, VectorType)> lowerInst) const { @@ -511,7 +511,40 @@ struct DirectToLdsLoadConversionBase : public LoadStoreConversionBase { LLVM::getSharedMemoryObjectFromStruct(loc, llDst, resElemTy, rewriter); auto affineOffset = smemObj.getShmemOffset(loc, rewriter, dstTy); auto maskSpanAffineOffset = SharedMemoryObject::getMaskSpanOffsets(dstTy); - auto [laneId, warpId] = getLaneAndWarpId(rewriter, loc); + + Value laneId, warpId; + if (ISAFamily::CDNA3 == isaFamily || ISAFamily::CDNA4 == isaFamily) { + // On GFX9, there is no dedicated hardware instruction to read `wave_id`. + // The value is instead computed from `workitem.id.x`. Per the GFX9 ABI, + // `workitem.id.x` is initialized in a vector register, and vector + // instructions are generated for IR operations that depend on `wave_id`. + // + // A `v_readfirstlane` instruction is inserted at the end of these vector + // sequences to transfer the value from a vector register to a scalar + // register, initializing `$m0`. + + // When this sequence occurs inside a loop, the MachineLICM pass does not + // hoist it because `v_readfirstlane` is convergent. Since both + // `workitem.id.x` and `wave_id` are constant at runtime, their + // computation can be safely hoisted to the function entry block. + auto insertPt = rewriter.saveInsertionPoint(); + Operation *parentOp = insertPt.getBlock()->getParentOp(); + while (!isa(parentOp)) { + parentOp = parentOp->getParentOp(); + } + + auto funcOp = cast(parentOp); + rewriter.setInsertionPointToStart(&funcOp.getBody().front()); + + std::tie(laneId, warpId) = getLaneAndWarpId(rewriter, loc); + auto call = LLVM::createLLVMIntrinsicCallOp( + rewriter, loc, "llvm.amdgcn.readfirstlane", {i32_ty}, {warpId}); + warpId = call.getResult(0); + rewriter.restoreInsertionPoint(insertPt); + } else { + std::tie(laneId, warpId) = getLaneAndWarpId(rewriter, loc); + } + auto calcPaddedOffset = [&](Value smemOffset) { TritonLLVMOpBuilder b(loc, rewriter); auto bitwidth = dstTy.getElementTypeBitWidth(); @@ -873,7 +906,8 @@ struct BufferLoadToLocalOpConversion }; lowerDirectToLDSLoad(rewriter, loc, ptrType, flatDstTy, loadVals, llDst, - resElemTy, vec, emitBufferLoadLds); + resElemTy, vec, targetInfo.getISAFamily(), + emitBufferLoadLds); // Drop the result token. Value zero = LLVM::ConstantOp::create(rewriter, op.getLoc(), @@ -999,7 +1033,8 @@ struct AsyncCopyGlobalToLocalOpConversion }; lowerDirectToLDSLoad(rewriter, loc, srcTy, flatDstTy, loadVals, llDst, - resElemTy, vec, emitGlobalLoadLds); + resElemTy, vec, targetInfo.getISAFamily(), + emitGlobalLoadLds); // Drop the result token. 
Value zero = LLVM::ConstantOp::create(rewriter, op.getLoc(), From 14fd9cbe332bdcc5a74d23eba8b400e4806bde6b Mon Sep 17 00:00:00 2001 From: Aleksandr Efimov <130555951+alefimov-amd@users.noreply.github.com> Date: Sun, 2 Nov 2025 22:02:11 +0100 Subject: [PATCH 5/7] [AMD] Add BufferOp interface (#8600) This PR introduces a common interface for buffer ops. --------- Co-authored-by: Alexander Efimov --- .../Dialect/TritonAMDGPU/IR/CMakeLists.txt | 5 +++ .../include/Dialect/TritonAMDGPU/IR/Dialect.h | 1 + .../IR/TritonAMDGPUOpInterfaces.td | 42 +++++++++++++++++++ .../TritonAMDGPU/IR/TritonAMDGPUOps.td | 6 +++ .../lib/Dialect/TritonAMDGPU/IR/Dialect.cpp | 1 + 5 files changed, 55 insertions(+) create mode 100644 third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOpInterfaces.td diff --git a/third_party/amd/include/Dialect/TritonAMDGPU/IR/CMakeLists.txt b/third_party/amd/include/Dialect/TritonAMDGPU/IR/CMakeLists.txt index 094ecfc7d4..63e9f4a086 100644 --- a/third_party/amd/include/Dialect/TritonAMDGPU/IR/CMakeLists.txt +++ b/third_party/amd/include/Dialect/TritonAMDGPU/IR/CMakeLists.txt @@ -15,4 +15,9 @@ mlir_tablegen(TritonAMDGPUEnums.h.inc -gen-enum-decls) mlir_tablegen(TritonAMDGPUEnums.cpp.inc -gen-enum-defs) mlir_tablegen(TritonAMDGPUAttrDefs.h.inc -gen-attrdef-decls) mlir_tablegen(TritonAMDGPUAttrDefs.cpp.inc -gen-attrdef-defs) + +set(LLVM_TARGET_DEFINITIONS TritonAMDGPUOpInterfaces.td) +mlir_tablegen(TritonAMDGPUOpInterfaces.h.inc -gen-op-interface-decls) +mlir_tablegen(TritonAMDGPUOpInterfaces.cpp.inc -gen-op-interface-defs) + add_public_tablegen_target(TritonAMDGPUAttrDefsIncGen) diff --git a/third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h b/third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h index 9d91da924a..ca4798f88d 100644 --- a/third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h +++ b/third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h @@ -41,6 +41,7 @@ #define GET_ATTRDEF_CLASSES #include "amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUAttrDefs.h.inc" +#include "amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOpInterfaces.h.inc" #define GET_OP_CLASSES #include "amd/include/Dialect/TritonAMDGPU/IR/Ops.h.inc" diff --git a/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOpInterfaces.td b/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOpInterfaces.td new file mode 100644 index 0000000000..b67cf14fc4 --- /dev/null +++ b/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOpInterfaces.td @@ -0,0 +1,42 @@ +#ifndef TRITON_AMDGPU_OP_INTERFACES +#define TRITON_AMDGPU_OP_INTERFACES + +include "mlir/IR/OpBase.td" + +def BufferOpInterface : OpInterface<"BufferOpInterface"> { + let description = [{ + This interface is implemented by buffer load/store operations. + It provides methods to access common properties such base pointer, offset, mask and others. 
+ }]; + + let cppNamespace = "::mlir::triton::amdgpu"; + + let methods = [ + InterfaceMethod< + /*desc=*/"Get operation base ptr.", + /*retType=*/"::mlir::TypedValue<::mlir::triton::PointerType>", + /*methodName=*/"getPtr">, + InterfaceMethod< + /*desc=*/"Get mutable operation base ptr.", + /*retType=*/"::mlir::OpOperand &", + /*methodName=*/"getPtrMutable">, + InterfaceMethod< + /*desc=*/"Get operation offset tensor.", + /*retType=*/"::mlir::TypedValue<::mlir::TensorType>", + /*methodName=*/"getOffsets">, + InterfaceMethod< + /*desc=*/"Get mutable operation offset tensor.", + /*retType=*/"::mlir::OpOperand &", + /*methodName=*/"getOffsetsMutable">, + InterfaceMethod< + /*desc=*/"Get operation stride.", + /*retType=*/"::mlir::TypedValue<::mlir::IntegerType>", + /*methodName=*/"getStride">, + InterfaceMethod< + /*desc=*/"Get mutable operation stride.", + /*retType=*/"::mlir::MutableOperandRange ", + /*methodName=*/"getStrideMutable"> + ]; +} + +#endif // TRITON_AMDGPU_OP_INTERFACES diff --git a/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td b/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td index 9c904fb803..a69b2ffc3c 100644 --- a/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td +++ b/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td @@ -41,6 +41,7 @@ include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" // Pure include "TritonAMDGPUDialect.td" include "TritonAMDGPUAttrDefs.td" +include "TritonAMDGPUOpInterfaces.td" class TT_AMDGPU_Op traits = []> : @@ -283,6 +284,7 @@ def CondBarrierOp : TT_AMDGPU_Op<"cond_barrier"> { def BufferLoadOp : TT_AMDGPU_Op<"buffer_load", [ SameLoadStoreOperandsAndResultEncoding, AttrSizedOperandSegments, + BufferOpInterface, TypesMatchWith<"result element type matches the pointed type of ptr", "result", "ptr", "getPointerTypeToElement($_self)">, TypesMatchWith<"result and offsets have the same shape", "result", "offsets", "getI32SameShape($_self)">, TypesMatchWith<"result and mask have the same shape", "result", "mask", "getI1SameShape($_self)", @@ -328,6 +330,7 @@ def BufferLoadOp : TT_AMDGPU_Op<"buffer_load", [ def BufferLoadToLocalOp : TT_AMDGPU_Op<"buffer_load_to_local", [ AttrSizedOperandSegments, + BufferOpInterface, TypesMatchWith<"dest element type matches pointee type of ptr", "dest", "ptr", "getPointerTypeToElement($_self)">, TypesMatchWith<"infer mask shape from offsets", "offsets", "mask", "getI1SameShape($_self)", @@ -364,6 +367,7 @@ def BufferLoadToLocalOp : TT_AMDGPU_Op<"buffer_load_to_local", [ def BufferAtomicRMWOp : TT_AMDGPU_Op<"buffer_atomic_rmw", [ AttrSizedOperandSegments, SameLoadStoreOperandsAndResultEncoding, + BufferOpInterface, TypesMatchWith<"result element type matches the value type", "result", "value", "$_self">, TypesMatchWith<"result element type matches the pointed type of ptr", "result", "ptr", "getPointerTypeToElement($_self)">, TypesMatchWith<"result and offsets have the same shape", "result", "offsets", "getI32SameShape($_self)">, @@ -410,6 +414,7 @@ def BufferAtomicRMWOp : TT_AMDGPU_Op<"buffer_atomic_rmw", [ //===----------------------------------------------------------------------===// def BufferAtomicCASOp : TT_AMDGPU_Op<"buffer_atomic_cas", [ SameLoadStoreOperandsAndResultEncoding, + BufferOpInterface, TypesMatchWith<"result element type matches the val type", "result", "val", "$_self">, TypesMatchWith<"result element type matches the cmp type", "result", "cmp", "$_self">, TypesMatchWith<"result element 
type matches the pointed type of ptr", "result", "ptr", "getPointerTypeToElement($_self)">, @@ -452,6 +457,7 @@ def BufferAtomicCASOp : TT_AMDGPU_Op<"buffer_atomic_cas", [ def BufferStoreOp : TT_AMDGPU_Op<"buffer_store", [ AttrSizedOperandSegments, SameLoadStoreOperandsEncoding, + BufferOpInterface, TypesMatchWith<"value element type matches the pointed type of ptr", "value", "ptr", "getPointerTypeToElement($_self)">, TypesMatchWith<"value and offsets have the same shape", "value", "offsets", "getI32SameShape($_self)">, TypesMatchWith<"value and mask have the same shape", "value", "mask", "getI1SameShape($_self)", diff --git a/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp b/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp index ef38511505..98ef35aed1 100644 --- a/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp +++ b/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp @@ -62,6 +62,7 @@ void mlir::triton::amdgpu::TritonAMDGPUDialect::initialize() { #define GET_OP_CLASSES #include "Dialect/TritonAMDGPU/IR/Ops.cpp.inc" +#include "Dialect/TritonAMDGPU/IR/TritonAMDGPUOpInterfaces.cpp.inc" namespace mlir::triton::amdgpu { From 318fa9c42fd9f0f7807d57b79640d3abb44f58bd Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Sun, 2 Nov 2025 13:39:18 -0800 Subject: [PATCH 6/7] [triton_kernels] forbid use of `split_k > 1` with fused scatter (#8618) API doesn't accept scale for the intermediate tensor produced between split_k and fused_scatter; this mode should therefore be disabled for now. Will be re-enabled after expert aggregation is moved out of the matmul_ogs API --- python/triton_kernels/tests/test_matmul.py | 44 ++++++++----------- .../test_opt_flags_split_k.py | 2 +- .../triton_kernels/matmul_ogs.py | 15 ++++--- .../matmul_ogs_details/opt_flags.py | 40 ++++++----------- 4 files changed, 40 insertions(+), 61 deletions(-) diff --git a/python/triton_kernels/tests/test_matmul.py b/python/triton_kernels/tests/test_matmul.py index 4fa4169741..3b065926f6 100644 --- a/python/triton_kernels/tests/test_matmul.py +++ b/python/triton_kernels/tests/test_matmul.py @@ -303,19 +303,19 @@ class Case: ], ) @pytest.mark.parametrize("block_m", [16, 128]) -@pytest.mark.parametrize("do_gather, do_scatter, fused_scatter, inner_expt_opt", [ - (False, False, False, None), - (True, False, False, None), - (False, True, False, None), - (False, True, True, None), - (True, True, False, None), - (True, True, True, None), - (False, False, False, "pad_w"), - (False, False, False, "pad_x"), +@pytest.mark.parametrize("do_gather, do_scatter, inner_expt_opt", [ + (False, False, None), + (True, False, None), + (False, True, None), + (False, True, None), + (True, True, None), + (True, True, None), + (False, False, "pad_w"), + (False, False, "pad_x"), ]) @pytest.mark.parametrize("has_y_gammas", [False, True]) @pytest.mark.parametrize("is_persistent", [False, True]) -def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, inner_expt_opt, has_y_gammas, is_persistent, n_expts_tot, +def test_op(m, n, k, split_k, do_gather, do_scatter, inner_expt_opt, has_y_gammas, is_persistent, n_expts_tot, n_expts_act, mode, act_dtype_str, weight_dtype_str, block_m, hbm_swizzling, colmajor_mxfp_weight, epilogue_subtile, x_transpose, w_transpose, y_transpose, device, opt_flags_scope): @@ -323,7 +323,7 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, inner_expt_o # the frame that called pytest.skip, including all the tensors, leading to OOM. 
skip_message = None try: - _test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, inner_expt_opt, has_y_gammas, is_persistent, n_expts_tot, + _test_op(m, n, k, split_k, do_gather, do_scatter, inner_expt_opt, has_y_gammas, is_persistent, n_expts_tot, n_expts_act, mode, act_dtype_str, weight_dtype_str, block_m, hbm_swizzling, colmajor_mxfp_weight, epilogue_subtile, x_transpose, w_transpose, y_transpose, device, opt_flags_scope) @@ -333,7 +333,7 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, inner_expt_o if skip_message is not None: pytest.skip(skip_message) -def _test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, inner_expt_opt, has_y_gammas, is_persistent, n_expts_tot, +def _test_op(m, n, k, split_k, do_gather, do_scatter, inner_expt_opt, has_y_gammas, is_persistent, n_expts_tot, n_expts_act, mode, act_dtype_str, weight_dtype_str, block_m, hbm_swizzling, colmajor_mxfp_weight, epilogue_subtile, x_transpose, w_transpose, y_transpose, device, opt_flags_scope): @@ -362,9 +362,6 @@ def _test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, inner_expt_ if "float8_e4m3fnuz" in (weight_dtype_str, act_dtype_str) and not is_hip_cdna3(): pytest.skip("float8_e4m3fnuz only tested on AMD CDNA3 Platform") - if fused_scatter and split_k is not None and split_k > 1: - pytest.skip("fused scatter scratchpad not supported with split_k") - if hbm_swizzling: if is_hip(): if not is_hip_cdna4(): @@ -413,7 +410,6 @@ def _test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, inner_expt_ "block_m": block_m, "block_k": block_k, "split_k": split_k, - "fused_scatter": fused_scatter, "is_persistent": is_persistent, "epilogue_subtile": epilogue_subtile, } @@ -726,12 +722,11 @@ def test_set_idle_sms(): (800, 800, 400, "batched"), ]) @pytest.mark.parametrize("split_k", [1, 2]) -@pytest.mark.parametrize("do_gather, do_scatter, fused_scatter", [ - (False, False, False), - (True, False, False), - (False, True, False), - (True, True, False), - (True, True, True), +@pytest.mark.parametrize("do_gather, do_scatter", [ + (False, False), + (True, False), + (False, True), + (True, True), ]) @pytest.mark.parametrize("is_persistent, epilogue_subtile", [ (False, None), @@ -743,16 +738,13 @@ def test_set_idle_sms(): (1.0, 1.2), (0.7, 1.0), ]) -def test_fused_act(m, n, k, mode, split_k, do_gather, do_scatter, fused_scatter, is_persistent, epilogue_subtile, +def test_fused_act(m, n, k, mode, split_k, do_gather, do_scatter, is_persistent, epilogue_subtile, swiglu_alpha, swiglu_limit, device, opt_flags_scope): - if fused_scatter and split_k > 1: - pytest.skip("fused scatter scratchpad not supported with split_k") torch.manual_seed(0) constraints = { "is_persistent": is_persistent, "epilogue_subtile": epilogue_subtile, "split_k": split_k, - "fused_scatter": fused_scatter, } n_expts_tot, n_expts_act = 1, 1 opt_flags.update_opt_flags_constraints(constraints) diff --git a/python/triton_kernels/tests/test_matmul_details/test_opt_flags_split_k.py b/python/triton_kernels/tests/test_matmul_details/test_opt_flags_split_k.py index d26a81ab36..3ce12f8e7d 100644 --- a/python/triton_kernels/tests/test_matmul_details/test_opt_flags_split_k.py +++ b/python/triton_kernels/tests/test_matmul_details/test_opt_flags_split_k.py @@ -185,7 +185,7 @@ def get_flags(split_k, max_mn): k, None, False, - False, + True, False, 0, False, diff --git a/python/triton_kernels/triton_kernels/matmul_ogs.py b/python/triton_kernels/triton_kernels/matmul_ogs.py index 85291c667a..7ca9e9788e 100644 --- 
a/python/triton_kernels/triton_kernels/matmul_ogs.py +++ b/python/triton_kernels/triton_kernels/matmul_ogs.py @@ -419,9 +419,10 @@ def matmul_ogs(x, w, bias, has_gather_tma = has_gather and target_info.has_tma_gather() # hopper w/ mxfp4 doesn't support TMA can_use_tma = can_use_tma and (torch.cuda.get_device_capability()[0] > 9 or bitwidth(w.dtype) != 4) + can_use_split_k = scatter_indx is None and not x_has_mx and not w_has_mx opt_flags = make_opt_flags(out_dtype, x.dtype, w.dtype, precision_config, batch_size, M, N, w.shape[-2], routing_data, - can_use_tma, scatter_indx is not None, epilogue.effective_itemsize, + can_use_tma, can_use_split_k, epilogue.effective_itemsize, x_transpose, y_acc_in is not None, inner_routing_data.block_k if inner_routing_data is not None else None, ) @@ -618,21 +619,21 @@ def matmul_ogs(x, w, bias, **fused_comm_kwargs, **opt_flags.target_kernel_kwargs) + assert not (opt_flags.split_k > 1 and scatter_indx is not None) out_final_mx_scale = None if opt_flags.split_k > 1: assert not out_matmul_has_mx - has_scatter = scatter_indx is not None postprocess_fn1 = ReducePostprocessFn(specs=reduce_fused_activation.specs, fn_args=reduce_fused_activation.fn_args) postprocess_fn2 = None if has_scatter else ReducePostprocessFn(specs=epilogue.specs, fn_args=epilogue.fn_arg_values_finalize) y, y_mx_scale = reduce( x = out_matmul.view(out_matmul.shape[0], -1, out_matmul.shape[-1]), dim = 0, # output data/metadata - y = None if has_scatter else memory["output"].view(-1, memory["output"].shape[-1]), - y_dtype = out_matmul.dtype if has_scatter else memory["output"].dtype, - y_flex = OutFlexData() if has_scatter else precision_config.flex_ctx.out_data, - y_flex_saturate_inf = None if has_scatter else precision_config.flexpoint_saturate_inf, - y_has_mx = scatter_indx is None and precision_config.out_scale is not None, + y = memory["output"].view(-1, memory["output"].shape[-1]), + y_dtype = memory["output"].dtype, + y_flex = precision_config.flex_ctx.out_data, + y_flex_saturate_inf = precision_config.flexpoint_saturate_inf, + y_has_mx = precision_config.out_scale is not None, # fused functions postprocess_fn1 = postprocess_fn1, postprocess_fn2 = postprocess_fn2, diff --git a/python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags.py b/python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags.py index 31f37ba444..53a79c3f5b 100644 --- a/python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags.py +++ b/python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags.py @@ -22,7 +22,6 @@ class OptFlags: w_cache_modifier: str split_k: int is_persistent: bool - fused_scatter: bool idle_sms: int epilogue_subtile: int | None arch: str @@ -56,14 +55,14 @@ def make_default_opt_flags_amd( k, routing_data, can_use_persistent_tma, - can_use_fused_scatter, + can_use_split_k, enforce_bitwise_invariance, epilogue_effective_itemsize, x_transpose, has_y_acc_in, constraints, ): - constraints_supported = ["block_m", "block_n", "block_k", "split_k", "fused_scatter", "is_persistent", "epilogue_subtile", "max_allowable_mn"] + constraints_supported = ["block_m", "block_n", "block_k", "split_k", "is_persistent", "epilogue_subtile", "max_allowable_mn"] assert not any([c not in constraints_supported for c in constraints]), constraints.keys() # tokens per expert if routing_data is None: @@ -102,13 +101,12 @@ def make_default_opt_flags_amd( ) is_persistent = constraints.get("is_persistent", False) # split_k: + split_k = 1 if constraints.get("max_allowable_mn", 0) > 0 and 
constraints.get("split_k") is not None: split_k = max_allowable_mn(constraints["max_allowable_mn"], m, n, constraints.get("split_k")) elif constraints.get("split_k", None) is not None: split_k = constraints["split_k"] - elif is_persistent or enforce_bitwise_invariance: - split_k = 1 - else: + elif can_use_split_k and not enforce_bitwise_invariance: grid_size = grid_m * ((n + block_n - 1) // block_n) n_cu = torch.cuda.get_device_properties(0).multi_processor_count split_k = max(1, n_cu // grid_size) @@ -156,7 +154,6 @@ def replace_with_valid_constraint(k: str, v): w_cache_modifier=w_cache_modifier, split_k=split_k, is_persistent=is_persistent, - fused_scatter=constraints.get('fused_scatter', False), idle_sms=0, epilogue_subtile=epilogue_subtile, arch=None, @@ -177,14 +174,14 @@ def make_default_opt_flags_nvidia( k, routing_data, can_use_persistent_tma, - can_use_fused_scatter, + can_use_split_k, enforce_bitwise_invariance, epilogue_effective_itemsize, x_transpose, has_y_acc_in, constraints, ): - constraints_supported = ["block_m", "block_k", "split_k", "is_persistent", "fused_scatter", "epilogue_subtile", "num_stages", "idle_sms", "max_allowable_mn"] + constraints_supported = ["block_m", "block_k", "split_k", "is_persistent", "epilogue_subtile", "num_stages", "idle_sms", "max_allowable_mn"] assert not any([c not in constraints_supported for c in constraints]), constraints.keys() # tokens per expert if routing_data is None or batch_size > 1: @@ -236,26 +233,21 @@ def make_default_opt_flags_nvidia( if constraints.get("block_k", None) is not None: block_k = constraints["block_k"] # split_k + split_k = 1 if constraints.get("max_allowable_mn", 0) > 0 and constraints.get("split_k") is not None: split_k = max_allowable_mn(constraints["max_allowable_mn"], m, n, constraints.get("split_k")) elif constraints.get("split_k", None) is not None: split_k = constraints["split_k"] - elif is_persistent or enforce_bitwise_invariance or precision_config.act_scale is not None or precision_config.out_scale is not None: - split_k = 1 - else: + elif can_use_split_k and not enforce_bitwise_invariance: estimated_actual_grid_size = opt_flags_nvidia.compute_grid_size(None, batch_size, m, n, block_m, block_n) split_k = opt_flags_nvidia.compute_split_k(block_k, k, estimated_actual_grid_size) - if split_k > 1: - # With split_k, results are written in f32. Use that for the following computations. 
- out_dtype = torch.float32 compute_num_stages_args = ( precision_config, is_persistent, - block_m, block_n, block_k, - out_dtype, + torch.float32 if split_k > 1 else out_dtype, lhs_dtype, rhs_dtype, x_transpose, @@ -276,11 +268,6 @@ def make_default_opt_flags_nvidia( if constraints.get("num_stages", None): num_stages = constraints["num_stages"] assert num_stages >= 1 - # fused scatter scratchpad - if constraints.get("fused_scatter", None) is not None: - fused_scatter = constraints["fused_scatter"] - else: - fused_scatter = can_use_fused_scatter and split_k == 1 # Handshake with the HBM swizzling num_warps = opt_flags_nvidia.compute_num_warps(block_m, block_n, is_persistent, precision_config) ret = OptFlags( @@ -289,7 +276,6 @@ def make_default_opt_flags_nvidia( block_k=block_k, num_warps=num_warps, num_stages=num_stages, - fused_scatter=fused_scatter, group_m=group_m, xcd_swizzle=xcd_swizzle, w_cache_modifier=None, @@ -343,7 +329,7 @@ def make_opt_flags( k, routing_data, can_use_persistent_tma, - can_use_fused_scatter, + can_use_split_k, epilogue_effective_itemsize, x_transpose, has_y_acc_in, @@ -351,8 +337,8 @@ def make_opt_flags( ): if _opt_flags_constraints.get("is_persistent", False) and not can_use_persistent_tma: raise InapplicableConstraint("cannot enforce `is_persistent=True` constraint") - if _opt_flags_constraints.get("fused_scatter", False) and not can_use_fused_scatter: - raise InapplicableConstraint("cannot enforce `fused_scatter=True` constraint") + if _opt_flags_constraints.get("split_k") is not None and _opt_flags_constraints.get("split_k") > 1 and not can_use_split_k: + raise InapplicableConstraint("cannot enforce `split_k=True` constraint") if _opt_flags_constraints.get("max_allowable_mn"): if not _opt_flags_constraints.get("split_k"): raise InapplicableConstraint("split_k also needs to be provided with max_allowable_mn") @@ -366,7 +352,7 @@ def make_opt_flags( opt_flags_constraints = opt_flags_constraints.copy() opt_flags_constraints.update(block_k=block_k, split_k=1) args = [out_dtype, lhs_dtype, rhs_dtype, precision_config, batch_size, m, n, k, - routing_data, can_use_persistent_tma, can_use_fused_scatter, + routing_data, can_use_persistent_tma, can_use_split_k, enforce_bitwise_invariance, epilogue_effective_itemsize, x_transpose, has_y_acc_in, opt_flags_constraints] backend = triton.runtime.driver.active.get_current_target().backend From ea925e13391bbe2e63489dfabe83c45d5c3d5043 Mon Sep 17 00:00:00 2001 From: Whitney Tsang Date: Tue, 18 Nov 2025 03:26:45 +0000 Subject: [PATCH 7/7] [TEST] Fix `triton_kernels` failures after `318fa9c` Signed-off-by: Whitney Tsang --- .../triton_kernels/matmul_ogs_details/opt_flags.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags.py b/python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags.py index 7838b28839..15adbf6e44 100644 --- a/python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags.py +++ b/python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags.py @@ -55,14 +55,14 @@ def make_default_opt_flags_intel( k, routing_data, can_use_persistent_tma, - can_use_fused_scatter, + can_use_split_k, enforce_bitwise_invariance, epilogue_effective_itemsize, x_transpose, has_y_acc_in, constraints, ): - constraints_supported = ["block_m", "block_k", "split_k", "is_persistent", "fused_scatter", "epilogue_subtile", "num_stages", "max_allowable_mn"] + constraints_supported = ["block_m", "block_k", "split_k", "is_persistent", 
"epilogue_subtile", "num_stages", "max_allowable_mn"] assert not any([c not in constraints_supported for c in constraints]), constraints.keys() # tokens per expert if routing_data is None: @@ -111,7 +111,6 @@ def make_default_opt_flags_intel( block_k=block_k, num_warps=opt_flags_intel.compute_num_warps(block_m, block_n), num_stages=constraints.get("num_stages", 2), - fused_scatter=constraints.get('fused_scatter', False), group_m=group_m, xcd_swizzle=xcd_swizzle, w_cache_modifier=None,