From 09fec1ef85baac9894a6dd237bb27fcd4a6e8981 Mon Sep 17 00:00:00 2001 From: Alex Baden Date: Thu, 29 May 2025 17:28:33 +0000 Subject: [PATCH 01/20] Add pass to convert block load to subgroup 2d block encoding types --- .../optimize-block-io-encoding.mlir | 65 ++++ third_party/intel/backend/compiler.py | 1 + .../TritonIntelGPU/Transforms/Passes.td | 11 + .../TritonIntelGPUTransforms/CMakeLists.txt | 1 + .../OptimizeBlockIOEncoding.cpp | 319 ++++++++++++++++++ third_party/intel/triton_xpu.cc | 3 + 6 files changed, 400 insertions(+) create mode 100644 test/TritonIntelGPU/optimize-block-io-encoding.mlir create mode 100644 third_party/intel/lib/TritonIntelGPUTransforms/OptimizeBlockIOEncoding.cpp diff --git a/test/TritonIntelGPU/optimize-block-io-encoding.mlir b/test/TritonIntelGPU/optimize-block-io-encoding.mlir new file mode 100644 index 0000000000..68174d5d90 --- /dev/null +++ b/test/TritonIntelGPU/optimize-block-io-encoding.mlir @@ -0,0 +1,65 @@ +// RUN: triton-opt %s -split-input-file --allocate-shared-memory --tritonintelgpu-optimize-block-io-encoding | FileCheck %s + +#blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 16], warpsPerCTA = [8, 4], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 4], warpsPerCTA = [32, 1], order = [1, 0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 16], warpsPerCTA = [16, 2], order = [1, 0]}> +// CHECK: #mma = #ttig.subgroup_2d_block<{warpsPerCTA = [8, 4], instrShape = [8, 16], numBlocks=2, order=[1, 0], kWidth=1, threadsPerWarp=16}> +// CHECK: #mma1 = #ttig.subgroup_2d_block<{warpsPerCTA = [8, 4], instrShape = [16, 16], numBlocks=2, order=[0, 1], kWidth=2, threadsPerWarp=16}> +// CHECK: #mma2 = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}> +#mma = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, ttg.target = "xpu", "ttg.threads-per-warp" = 16 : i32, ttig.min_sg_size = 16 : i32, ttig.support_bf16_conversion, ttig.support_dpas, ttig.support_sg_2d_block, ttig.target_arch = "spir64"} { + tt.func public @matmul_kernel_with_block_pointers(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: !tt.ptr) attributes {noinline = false} { + %c4_i32 = arith.constant 4 : i32 + %c256_i32 = arith.constant 256 : i32 + %c1024_i64 = arith.constant 1024 : i64 + %c5120_i64 = arith.constant 5120 : i64 + %c1_i64 = arith.constant 1 : i64 + %c0_i32 = arith.constant 0 : i32 + %c4096_i64 = arith.constant 4096 : i64 + %c32_i32 = arith.constant 32 : i32 + %c64_i32 = arith.constant 64 : i32 + %c5120_i32 = arith.constant 5120 : i32 + %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #blocked> + %0 = tt.get_program_id x : i32 + %1 = arith.divsi %0, %c64_i32 : i32 + %2 = arith.muli %1, %c4_i32 : i32 + %3 = arith.subi %c4_i32, %2 : i32 + %4 = arith.minsi %3, %c4_i32 : i32 + %5 = arith.remsi %0, %4 : i32 + %6 = arith.addi %2, %5 : i32 + %7 = arith.remsi %0, %c64_i32 : i32 + %8 = arith.divsi %7, %4 : i32 + %9 = arith.muli %6, %c256_i32 : i32 + // CHECK: tt.make_tensor_ptr {{.*}} : > + %10 = tt.make_tensor_ptr %arg0, [%c1024_i64, %c5120_i64], [%c5120_i64, %c1_i64], [%9, %c0_i32] {order = array} : > + %11 = arith.muli %8, %c256_i32 : i32 + // CHECK: 
tt.make_tensor_ptr {{.*}} : > + %12 = tt.make_tensor_ptr %arg1, [%c5120_i64, %c4096_i64], [%c4096_i64, %c1_i64], [%c0_i32, %11] {order = array} : > + %13:3 = scf.for %arg3 = %c0_i32 to %c5120_i32 step %c32_i32 iter_args(%arg4 = %cst, %arg5 = %10, %arg6 = %12) -> (tensor<256x256xf32, #blocked>, !tt.ptr>, !tt.ptr>) : i32 { + %17 = tt.load %arg5 {boundaryCheck = array, ttig.block_io = "row_major"} : !tt.ptr> + // CHECK: %[[A_LOAD:.*]] = tt.load %arg5 {boundaryCheck = array, ttig.block_io = "row_major"} : !tt.ptr> + // CHECK: {{.*}} = ttg.convert_layout %[[A_LOAD]] : tensor<256x32xf16, #mma> -> tensor<256x32xf16, #blocked1> + %18 = tt.load %arg6 {boundaryCheck = array, ttig.block_io = "row_major"} : !tt.ptr> + // CHECK: %[[B_LOAD:.*]] = tt.load %arg6 {boundaryCheck = array, ttig.block_io = "row_major"} : !tt.ptr> + // CHECK: {{.*}} = ttg.convert_layout %[[B_LOAD]] : tensor<32x256xf16, #mma1> -> tensor<32x256xf16, #blocked2> + %19 = ttg.convert_layout %17 : tensor<256x32xf16, #blocked1> -> tensor<256x32xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> + %20 = ttg.convert_layout %18 : tensor<32x256xf16, #blocked2> -> tensor<32x256xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> + %21 = ttg.convert_layout %arg4 : tensor<256x256xf32, #blocked> -> tensor<256x256xf32, #mma> + %22 = ttg.convert_layout %19 : tensor<256x32xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> -> tensor<256x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> + %23 = ttg.convert_layout %20 : tensor<32x256xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<32x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + // CHECK: tt.dot {{.*}} : tensor<256x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma2, kWidth = 1}>> * tensor<32x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma2, kWidth = 2}>> -> tensor<256x256xf32, #mma2> + %24 = tt.dot %22, %23, %21, inputPrecision = tf32 : tensor<256x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<256x256xf32, #mma> + %25 = ttg.convert_layout %24 : tensor<256x256xf32, #mma> -> tensor<256x256xf32, #blocked> + // CHECK: tt.advance {{.*}} : > + %26 = tt.advance %arg5, [%c0_i32, %c32_i32] : > + // CHECK: tt.advance {{.*}} : > + %27 = tt.advance %arg6, [%c32_i32, %c0_i32] : > + scf.yield %25, %26, %27 : tensor<256x256xf32, #blocked>, !tt.ptr>, !tt.ptr> + } + %14 = tt.make_tensor_ptr %arg2, [%c1024_i64, %c4096_i64], [%c4096_i64, %c1_i64], [%9, %11] {order = array} : > + %15 = arith.truncf %13#0 : tensor<256x256xf32, #blocked> to tensor<256x256xf16, #blocked> + %16 = ttg.convert_layout %15 : tensor<256x256xf16, #blocked> -> tensor<256x256xf16, #blocked2> + tt.store %14, %16 {boundaryCheck = array} : !tt.ptr> + tt.return + } +} diff --git a/third_party/intel/backend/compiler.py b/third_party/intel/backend/compiler.py index d225494f75..2c69390029 100644 --- a/third_party/intel/backend/compiler.py +++ b/third_party/intel/backend/compiler.py @@ -280,6 +280,7 @@ def make_ttgir(mod, metadata, opt, properties): intel.passes.ttgpuir.add_accelerate_matmul(pm) intel.passes.ttgpuir.add_materialize_block_pointer(pm) + intel.passes.ttgpuir.add_optimize_block_load_encoding(pm) intel.passes.ttgpuir.add_remove_layout_conversions(pm) intel.passes.ttgpuir.add_pipeline(pm, opt.num_stages, XPUBackend.get_split_barrier_scope(opt)) diff --git a/third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Passes.td b/third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Passes.td index 
c20224aaee..91625ba0be 100644 --- a/third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Passes.td +++ b/third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Passes.td @@ -409,4 +409,15 @@ def TritonIntelGPUReduceVariableLiveness "mlir::scf::SCFDialect", "mlir::arith::ArithDialect"]; } + +def TritonIntelGPUOptimizeBlockIOEncodingPass : Pass<"tritonintelgpu-optimize-block-io-encoding", "mlir::ModuleOp"> { + let summary = "Set encodings on candidates for Subgroup 2D Block IO ops"; + + let description = [{ + Set the Subgroup2DBlock encoding on tensor ptr types that are candidates for Subgroup 2D Block IO lowering. The goal is to change the tensor ptr type to use the new encoding so the LoadOp will use the new encoding, allowing the encoding to be an anchor layout during RemoveLayoutConversions. To avoid duplicating work in RemoveLayoutConversions, a ConvertLayout op to the existing encoding replaces the result of the LoadOp. + }]; + + let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect", "mlir::triton::gpu::intel::TritonIntelGPUDialect", "mlir::triton::TritonDialect"]; +} + #endif // TRITON_INTEL_GPU_PASSES diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/CMakeLists.txt b/third_party/intel/lib/TritonIntelGPUTransforms/CMakeLists.txt index b8cb96cfa0..bb32041127 100644 --- a/third_party/intel/lib/TritonIntelGPUTransforms/CMakeLists.txt +++ b/third_party/intel/lib/TritonIntelGPUTransforms/CMakeLists.txt @@ -5,6 +5,7 @@ add_triton_library(TritonIntelGPUTransforms DistributeToWarps.cpp MatchTargetSize.cpp MaterializeBlockPointer.cpp + OptimizeBlockIOEncoding.cpp OptimizeDotOperands.cpp OptimizeReductionLocality.cpp Pipeliner/MatmulLoopPipeline.cpp diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeBlockIOEncoding.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeBlockIOEncoding.cpp new file mode 100644 index 0000000000..14ca36f31b --- /dev/null +++ b/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeBlockIOEncoding.cpp @@ -0,0 +1,319 @@ +#include "intel/include/Dialect/TritonIntelGPU/IR/Dialect.h" +#include "intel/include/Dialect/TritonIntelGPU/Transforms/Passes.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "llvm/ADT/PriorityWorklist.h" + +namespace ttg = mlir::triton::gpu; +namespace ttgi = mlir::triton::gpu::intel; + +namespace mlir { +namespace triton { +namespace gpu::intel { + +#define DEBUG_TYPE "tritongpu-optimize-block-encoding" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +namespace { + +SmallVector getTiedArgs(Operation *op, int resultIdx) { + if (auto forOp = dyn_cast(op)) { + auto iterArg = forOp.getRegionIterArg(resultIdx); + auto result = forOp.getResult(resultIdx); + auto yieldVal = forOp.getBody()->getTerminator()->getOperand(resultIdx); + auto initVal = forOp.getInitArgs()[resultIdx]; + return {iterArg, result, yieldVal, initVal}; + } else if (auto whileOp = dyn_cast(op)) { + auto iterArg = whileOp.getBeforeArguments()[resultIdx]; + auto result = whileOp.getResults()[resultIdx]; + auto yieldVal = + whileOp.getBeforeBody()->getTerminator()->getOperand(resultIdx); + auto initVal = whileOp.getOperands()[resultIdx]; + return {iterArg, result, iterArg, initVal}; + } else if (auto ifOp = dyn_cast(op)) { + SmallVector values; + for (auto &block : ifOp.getThenRegion().getBlocks()) { + auto terminator = block.getTerminator(); + if (isa(terminator)) + values.push_back(terminator->getOperands()[resultIdx]); + } + for (auto &block : 
ifOp.getElseRegion().getBlocks()) { + auto terminator = block.getTerminator(); + if (isa(terminator)) + values.push_back(terminator->getOperands()[resultIdx]); + } + values.push_back(ifOp->getResults()[resultIdx]); + return values; + } + return {}; +} + +Type getNewType(Type type, Attribute encoding) { + RankedTensorType tensorType = cast(type); + return RankedTensorType::get(tensorType.getShape(), + tensorType.getElementType(), encoding); +} + +Type getNewPointerType(Type type, Attribute encoding) { + assert(isa(type) && "expected a ptr type!"); + auto oldPointerType = cast(type); + return PointerType::get(getNewType(oldPointerType.getPointeeType(), encoding), + oldPointerType.getAddressSpace()); +} + +struct EncodingInfo { + Attribute desiredEncoding; + bool requiresConvert = false; + + bool operator==(const EncodingInfo &other) const { + return desiredEncoding == other.desiredEncoding && + requiresConvert == other.requiresConvert; + } +}; + +/** + * The algorithm here takes inspiration from + * TritonNVIDIAGPU::OptimizeDescriptorEncoding. The idea is to iterate the + * def-use chain in both directions starting from the Load Op. We store the + * values that need to be updated along with the new encoding in the + * `valueToEncodingInfo` MapVector. After all value/encoding pairs have been + * determined, we update the encoding for each value, adding aa conversion to + * the existing Load Op result layout for users of the load. + */ +void rewriteTensorLayoutsForOp(Attribute encoding, Operation *op) { + auto loadOp = cast(op); + auto loadPtrType = cast(loadOp->getOperand(0).getType()); + auto addressSpace = loadPtrType.getAddressSpace(); + + llvm::MapVector, EncodingInfo> valueToEncodingInfo; + llvm::PriorityWorklist> worklist; + + auto updateEncoding = [&](ArrayRef ptrValues, EncodingInfo info) { + for (auto value : ptrValues) { + bool requiresConvert = llvm::any_of( + value.getUsers(), [](auto user) { return isa(user); }); + info.requiresConvert = requiresConvert; + + auto typedVal = cast>(value); + auto itr = valueToEncodingInfo.find(typedVal); + if (itr == valueToEncodingInfo.end()) { + LLVM_DEBUG(DBGS() << "Add encoding " << info.desiredEncoding + << " for value " << typedVal << "\n"); + valueToEncodingInfo[typedVal] = info; + worklist.insert(typedVal); + } else { + LLVM_DEBUG(DBGS() << "Found existing encoding info " + << itr->second.desiredEncoding << " for value " + << typedVal << ". Ensure new encoding " + << info.desiredEncoding << " matches.\n"); + assert(itr->second == info && "already visited encoding info for " + "value, expected them to be equal!"); + continue; + } + } + }; + + worklist.insert(cast>(loadOp->getOperand(0))); + + // 1. Starting from the Load Op, propagate encoding info up and down the + // def-use chain. 
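  // Illustrative sketch (values, shapes, and the #blocked1 encoding are taken
  // from the lit test added in this patch; they are examples, not something
  // the pass requires). For a loop-carried block pointer the worklist ties
  // together:
  //   %10   = tt.make_tensor_ptr ...        : <tensor<256x32xf16, #blocked1>>
  //   %arg5 = the matching scf.for iter_arg (init %10, yielded value %26)
  //   %26   = tt.advance %arg5, ...         : <tensor<256x32xf16, #blocked1>>
  //   %13#1 = the matching scf.for result
  // so that all of them can be retyped to the new Subgroup2DBlock encoding at
  // once, and a single ttg.convert_layout back to #blocked1 is inserted after
  // the tt.load that consumes %arg5.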
+ while (!worklist.empty()) { + auto crtValue = worklist.pop_back_val(); + + // Propagate to users + for (OpOperand &use : crtValue.getUses()) { + auto op = use.getOwner(); + if (isa(op)) { + auto offset = 3 * isa(op); + auto vals = getTiedArgs(op, use.getOperandNumber() - offset); + updateEncoding(vals, EncodingInfo{encoding}); + } else if (isa(op)) { + auto vals = getTiedArgs(op->getParentOp(), use.getOperandNumber()); + updateEncoding(vals, EncodingInfo{encoding}); + } + } + + // Propagate to defining ops + if (auto opResult = dyn_cast(crtValue)) { + auto definingOp = opResult.getOwner(); + if (isa(definingOp)) { + auto vals = getTiedArgs(definingOp, opResult.getResultNumber()); + updateEncoding(vals, EncodingInfo{encoding}); + } + } else if (auto blockArg = dyn_cast(crtValue)) { + auto parentOp = blockArg.getOwner()->getParentOp(); + if (isa(parentOp)) { + auto offset = isa(parentOp); + auto vals = getTiedArgs(parentOp, blockArg.getArgNumber() - offset); + updateEncoding(vals, EncodingInfo{encoding}); + } + } + } + + // 2. Update the type for each value in-place. Add a ConvertLayout Op after + // any loads which require conversion to the existing layout for the loaded + // value. + for (auto &[val, einfo] : valueToEncodingInfo) { + Attribute newEncoding = einfo.desiredEncoding; + LLVM_DEBUG(DBGS() << "Rewrite encoding to " << newEncoding << " for value " + << val << "\n"); + + PointerType oldType = val.getType(); + auto oldTensorTy = cast(oldType.getPointeeType()); + auto newTensorTy = RankedTensorType::get( + oldTensorTy.getShape(), oldTensorTy.getElementType(), newEncoding); + + val.setType(PointerType::get(newTensorTy, oldType.getAddressSpace())); + if (einfo.requiresConvert) { + for (auto user : val.getUsers()) { + if (auto loadOp = dyn_cast(user)) { + + OpBuilder builder(loadOp); + auto oldLoadType = loadOp.getType(); + Value result = loadOp.getResult(); + + builder.setInsertionPointAfter(loadOp); + auto cvt = builder.create(loadOp.getLoc(), + result.getType(), result); + LLVM_DEBUG(DBGS() << "Added convert Op:\n" + << cvt << " after Load Op:\n" + << loadOp << "\n"); + result.setType(newTensorTy); + + result.replaceAllUsesExcept(cvt.getResult(), cvt.getOperation()); + } + } + } + } +} + +} // namespace + +#define GEN_PASS_DEF_TRITONINTELGPUOPTIMIZEBLOCKIOENCODINGPASS +#include "intel/include/Dialect/TritonIntelGPU/Transforms/Passes.h.inc" + +class TritonIntelGPUOptimizeBlockIOEncodingPass + : public impl::TritonIntelGPUOptimizeBlockIOEncodingPassBase< + TritonIntelGPUOptimizeBlockIOEncodingPass> { + + void getSubgroup2DBlockLayoutForOperand( + Value operand, DpasEncodingAttr dpasLayout, + llvm::MapVector &layoutMap) { + auto isCandidateLoad = [](Value v) -> LoadOp { + // Peel out the original cvt dot_op<..., #blocked> + // and any other potential cvt/trans ops + while (true) { + if (auto cvtOp = v.getDefiningOp()) { + v = cvtOp.getSrc(); + continue; + } + if (auto transOp = v.getDefiningOp()) { + v = transOp.getSrc(); + continue; + } + break; + } + return isa(v.getDefiningOp()) ? 
cast(v.getDefiningOp()) + : nullptr; + }; + + LoadOp loadOp = isCandidateLoad(operand); + if (!loadOp) + return; + + auto dotOperandType = cast(operand.getType()); + auto dotOperandEncoding = + cast(dotOperandType.getEncoding()); + // layout width is determined by the DPAS operand encoding width + const int kWidth = dotOperandEncoding.getKWidth(); + + Attribute blockIOAttr = + loadOp->getAttr(TritonIntelGPUDialect::getBlockIOAttrName()); + if (!blockIOAttr) + return; + + // get the MakeTensorPtr Op for the load + Value ptr = loadOp.getPtr(); + if (!isTensorPointerType(ptr.getType())) { + // TODO: support tensor of pointer loads + LLVM_DEBUG(DBGS() << "Ptr\n" + << ptr << " for Load Op:\n" + << loadOp + << "\nincompatible with Subgroup 2D Block Layout.\n"); + return; + } + MakeTensorPtrOp makeTensorPtrOp = getMakeTensorPtrOp(ptr); + assert(makeTensorPtrOp && + "expecting a tensor pointer parent to block io load " + "with tensor pointer type"); + + auto oldTensorPtrType = cast(makeTensorPtrOp.getType()); + auto oldTensorType = + cast(oldTensorPtrType.getPointeeType()); + // Note: we need the old layout to get the order for the load, but it is not + // clear the layout will always be Blocked. Is there a better way to get + // this info? + auto oldLayout = cast(oldTensorType.getEncoding()); + + auto CTALayout = getCTALayout(dpasLayout); + const unsigned elemSizeInBits = + oldTensorType.getElementType().getIntOrFloatBitWidth(); + + auto tileParams = Subgroup2DBlockEncodingAttr::getInstrShapeForLayout( + cast(dotOperandEncoding), + oldTensorType.getShape(), + blockIOAttr == StringAttr::get(&getContext(), "row_major"), + elemSizeInBits / 8, &getContext()); + SmallVector instrShape{tileParams[0], tileParams[1]}; + const unsigned vBlocks = tileParams[2]; + + auto subgroup2DBlockEncoding = Subgroup2DBlockEncodingAttr::get( + &getContext(), dpasLayout.getWarpsPerCTA(), CTALayout, instrShape, + tileParams[2], + getOrderForDotOperand(dotOperandEncoding.getOpIdx(), /*rank*/ 2, + /*kContig*/ true), + kWidth, dpasLayout.getThreadsPerWarp()); + + LLVM_DEBUG(DBGS() << "Generated new encoding: " << subgroup2DBlockEncoding + << " for op : " << loadOp << "\n"); + + layoutMap[loadOp] = subgroup2DBlockEncoding; + } + +public: + void runOnOperation() override { + ModuleOp m = getOperation(); + + // Step 1. Find all loads which are candidates for conversion to Subgroup 2D + // Block Encoding. To be a candidate load, a load must be consumed by a Dot + // Op and the load operand must be a block ptr (produced by a MakeTensorPtr + // Op). Currently we look for loads with the "block_io" attribute but we + // could consider moving that logic to this pass later. We place the load + // and the candidate encoding into the layout map for propagation in step 2 + llvm::MapVector layoutMap; + m.walk([&](DotOp dotOp) { + auto dotOpType = cast(dotOp.getResult().getType()); + auto dpasLayout = dyn_cast(dotOpType.getEncoding()); + if (!dpasLayout) + return; + + getSubgroup2DBlockLayoutForOperand(dotOp.getA(), dpasLayout, layoutMap); + getSubgroup2DBlockLayoutForOperand(dotOp.getB(), dpasLayout, layoutMap); + }); + + // Step 2. Rewrite MakeTensorPtr to use the new layout and propagate the + // change through the def-use chain, terminating at the Load Op. We add a + // ConvertLayout Op after the Load Op to convert back to the original + // layout. 
Subgroup2DBlockEncoding layouts will be chosen as anchor layouts + // in RemoveLayoutConversions, and a subsequent run of + // RemoveLayoutConversions after this pass cleans up intermediate layout + // conversions and removes the original Load Op encoding. + for (auto &kv : layoutMap) { + rewriteTensorLayoutsForOp(kv.second, kv.first); + } + } +}; + +} // namespace gpu::intel +} // namespace triton +} // namespace mlir diff --git a/third_party/intel/triton_xpu.cc b/third_party/intel/triton_xpu.cc index ae485a2c7b..1abab23024 100644 --- a/third_party/intel/triton_xpu.cc +++ b/third_party/intel/triton_xpu.cc @@ -115,6 +115,9 @@ void init_triton_intel_passes_ttgpuir(py::module &&m) { gpu::intel::createTritonIntelGPUReduceDataDuplication); ADD_PASS_WRAPPER_0("add_materialize_block_pointer", gpu::intel::createTritonIntelGPUMaterializeBlockPointer); + ADD_PASS_WRAPPER_0( + "add_optimize_block_load_encoding", + gpu::intel::createTritonIntelGPUOptimizeBlockIOEncodingPass); ADD_PASS_WRAPPER_0("add_optimize_reduction_locality", gpu::intel::createTritonIntelGPUOptimizeReductionLocality); ADD_PASS_WRAPPER_0("add_reduce_variable_liveness", From f7e81ce0df2eb958fd6c180029a26477c9b03018 Mon Sep 17 00:00:00 2001 From: Alex Baden Date: Fri, 6 Jun 2025 19:41:33 +0000 Subject: [PATCH 02/20] Mark subgroup 2d block loads as expenisve loads --- .../lib/TritonIntelGPUTransforms/Utility.cpp | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/Utility.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/Utility.cpp index 785920770a..ed5c0e57f1 100644 --- a/third_party/intel/lib/TritonIntelGPUTransforms/Utility.cpp +++ b/third_party/intel/lib/TritonIntelGPUTransforms/Utility.cpp @@ -120,9 +120,22 @@ bool isExpensiveLoadOrStore(Operation *op) { if (isSingleValue(base)) return false; - // Loads that use a block pointer are expensive if they cannot be lowered to - // 2D block read operations. Temporarily leverage the - // "ttig.block_io" attribute to filter out inexpensive loads. + if (auto loadOp = dyn_cast(op)) { + // Subgroup2DBlockEncodingAttr loads are expensive, but loads without this + // encoding may still be expensive so we only return true if the encodng + // exists + if (auto tensorTy = dyn_cast(loadOp.getType())) + if (isa(tensorTy.getEncoding())) + return true; + } + + // The block ptr attribute identifies loads that are candidates for subgroup + // 2d block io operations. Loads with these attributes (and without the new + // subgroup 2d block encoding above) should have their layouts replaced with + // the layout from the expensive op (usually a dot op with DPAS encoding). The + // load result is convert to the expensive op layout during LLVM lowering. + // Note: the long term plan is to replace this path with the above subgroup 2d + // block encoding layout. 
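  // For illustration (tensor types adapted from the lit test added earlier in
  // this series; the exact shapes and encodings are examples only): the
  // encoding check above fires for a load such as
  //   %a = tt.load %p {ttig.block_io = "row_major"}
  //          : !tt.ptr<tensor<256x32xf16, #ttig.subgroup_2d_block<...>>>
  // while the attribute check below still covers loads whose result keeps a
  // blocked layout, e.g.
  //   %b = tt.load %q {ttig.block_io = "row_major"}
  //          : !tt.ptr<tensor<32x256xf16, #blocked>>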
Attribute blockIOAttr = op->getAttr(TritonIntelGPUDialect::getBlockIOAttrName()); if (blockIOAttr) From e301b432e65ff2fc33ee19bfba7b2936fccda6bf Mon Sep 17 00:00:00 2001 From: Alex Baden Date: Sun, 8 Jun 2025 17:00:43 +0000 Subject: [PATCH 03/20] WIP: use the subgroup 2d block layout in LoadStoreOpToLLVM add missing definition WIP: use new encoding in load store op to llvm --- .../TritonIntelGPU/Transforms/Utility.h | 3 ++ .../LoadStoreOpToLLVM.cpp | 50 +++++++++++++++---- .../ReduceDataDuplication.cpp | 2 + .../lib/TritonIntelGPUTransforms/Utility.cpp | 7 +++ 4 files changed, 53 insertions(+), 9 deletions(-) diff --git a/third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Utility.h b/third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Utility.h index 9ab7baaa71..580a9a6895 100644 --- a/third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Utility.h +++ b/third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Utility.h @@ -33,6 +33,9 @@ Attribute inferSrcEncoding(Operation *op, Attribute encoding); // Retuns true if the operation is an expensive load or store operation. bool isExpensiveLoadOrStore(Operation *op); +// Returns true if the tensor type has a subgroup 2d block io encoding +bool hasSubgroup2DBlockEncoding(RankedTensorType tensorType); + // Returns true if the tensor type has a dot dpas encoding. bool hasDotDpasEncoding(RankedTensorType tensorType); diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp index 32be8bd222..5e2dd9bfc6 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -302,7 +302,8 @@ struct BlockIOConversionBase : public LoadStoreConversionBase { // Only lower loadOp with dpas layout encoding. auto tensorTy = cast(op.getType()); - return hasDpasEncoding(tensorTy) || hasDotDpasEncoding(tensorTy); + return hasDpasEncoding(tensorTy) || hasDotDpasEncoding(tensorTy) || + hasSubgroup2DBlockEncoding(tensorTy); } template < @@ -1416,12 +1417,31 @@ struct LoadOpConversion auto tensorType = cast(resultType); const bool memoryRowMajor = isMemoryRowMajor(op); - DpasEncodingAttr::OpIdx opIdx = getOpIdx(tensorType); + + auto getDpasTypeFromCVTOp = [&](Value opResult) -> RankedTensorType { + for (OpOperand user : opResult.getUsers()) { + if (auto cvt = dyn_cast(user.getOwner())) { + return cast(cvt.getResult().getType()); + // return getDpasLayout(cvt.getResult().getType()); + } + } + llvm_unreachable("expected to find a cvt op with dpas layout"); + }; + + auto dpasTensorType = hasSubgroup2DBlockEncoding(tensorType) + ? getDpasTypeFromCVTOp(op.getResult()) + : tensorType; + llvm::errs() << "using dpas tensor type: " << dpasTensorType << "\n"; + DpasEncodingAttr dpasLayout = getDpasLayout(dpasTensorType); + + DpasEncodingAttr::OpIdx opIdx = getOpIdx(dpasTensorType); LLVM_DEBUG(llvm::dbgs() << "Tensor type for op " << int(opIdx) << ": " << tensorType << "\n"); Attribute encoding = tensorType.getEncoding(); + // TODO: this gives us the linear layour corresponding + // to the subgroup 2d block encoding, not the dpas encoding... 
std::optional llEncoding = cast(encoding).toLinearLayout( tensorType.getShape()); @@ -1440,14 +1460,21 @@ struct LoadOpConversion Type eltTy = tensorType.getElementType(); unsigned elemSizeInBits = eltTy.getIntOrFloatBitWidth(); - auto tileParams = Subgroup2DBlockEncodingAttr::getInstrShapeForLayout( - cast(encoding), tensorType.getShape(), - memoryRowMajor, elemSizeInBits / 8, rewriter.getContext()); - unsigned tileHeight = tileParams[0]; - const unsigned tileWidth = tileParams[1]; - const unsigned vBlocks = tileParams[2]; + auto getTileParams = [&]() -> std::tuple { + if (hasSubgroup2DBlockEncoding(tensorType)) { + auto encoding = + cast(tensorType.getEncoding()); + auto shape = encoding.getInstrShape(); + return std::make_tuple(shape[0], shape[1], encoding.getNumBlocks()); + } else { + auto tileParams = Subgroup2DBlockEncodingAttr::getInstrShapeForLayout( + cast(encoding), tensorType.getShape(), + memoryRowMajor, elemSizeInBits / 8, rewriter.getContext()); + return std::make_tuple(tileParams[0], tileParams[1], tileParams[2]); + } + }; + auto [tileHeight, tileWidth, vBlocks] = getTileParams(); - DpasEncodingAttr dpasLayout = getDpasLayout(tensorType); const ArrayRef tensorShape = tensorType.getShape(); unsigned numElems = getTotalElemsPerThread(resultType); SmallVector numReps = @@ -1617,6 +1644,7 @@ struct LoadOpConversion // input operands to DPAS. // TODO: add support for int4 and int2. unsigned opsPerChannel = dpasLayout.getOpsPerChannel(); + llvm::errs() << "opsPerChannel = " << opsPerChannel << "\n"; if ((opsPerChannel == 4 && elemSizeInBits == 8) || (opsPerChannel == 2 && elemSizeInBits == 16) || (opsPerChannel == 1 && elemSizeInBits == 32)) { @@ -1840,6 +1868,8 @@ struct LoadOpConversion unsigned numValuesPerLoad = packedElemsPerLanePerDPASInst * numOperandsOuterDimPerLoad * numOperandsInnerDimPerLoad; + llvm::errs() << "num values per load = " << numValuesPerLoad << "\n"; + llvm::errs() << "loadResultElemType = " << loadResultElemType << "\n"; Type load2DGenXType = LLVM::getVectorType(loadResultElemType, numValuesPerLoad); @@ -2187,6 +2217,8 @@ struct LoadOpConversion } Type llvmResultStructTy = typeConverter->convertType(op.getType()); + llvm::errs() << "op.getType() " << op.getType() << "\n"; + llvm::errs() << "llvmResultStructTy: " << llvmResultStructTy << "\n"; Value resultStruct = packLLElements(loc, typeConverter, unpackedLoadedVals, rewriter, llvmResultStructTy); rewriter.replaceOp(op, {resultStruct}); diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/ReduceDataDuplication.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/ReduceDataDuplication.cpp index ee218d76b1..bcc8cbc0b2 100644 --- a/third_party/intel/lib/TritonIntelGPUTransforms/ReduceDataDuplication.cpp +++ b/third_party/intel/lib/TritonIntelGPUTransforms/ReduceDataDuplication.cpp @@ -30,6 +30,8 @@ class TritonIntelGPUReduceDataDuplicationPass auto srcEncoding = srcType.getEncoding(); if (isa(srcEncoding)) return; + if (isa(srcEncoding)) + return; auto dstDotOp = dyn_cast(dstType.getEncoding()); if (!dstDotOp) diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/Utility.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/Utility.cpp index ed5c0e57f1..a96df14d14 100644 --- a/third_party/intel/lib/TritonIntelGPUTransforms/Utility.cpp +++ b/third_party/intel/lib/TritonIntelGPUTransforms/Utility.cpp @@ -153,6 +153,13 @@ bool isExpensiveLoadOrStore(Operation *op) { return false; } +bool hasSubgroup2DBlockEncoding(RankedTensorType tensorType) { + if (!tensorType.getEncoding()) + return false; + + 
return isa(tensorType.getEncoding()); +} + bool hasDotDpasEncoding(RankedTensorType tensorType) { if (!tensorType.getEncoding()) return false; From 3e149ae112c45d87e48a95a458cd9a3ab1ebb400 Mon Sep 17 00:00:00 2001 From: Alex Baden Date: Thu, 12 Jun 2025 21:03:27 +0000 Subject: [PATCH 04/20] remove convert layout op which converts subgroup2d block to dpas this is being handled by loadstoreoptollvm currently --- .../ConvertLayoutOpToLLVM.cpp | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp index 0feb202b49..95097398d2 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -3,6 +3,7 @@ #include "llvm/ADT/TypeSwitch.h" #include "intel/include/Analysis/Utility.h" +#include "intel/include/Dialect/TritonIntelGPU/Transforms/Utility.h" namespace mlir::triton::gpu { namespace { @@ -27,6 +28,21 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion auto srcTy = op.getSrc().getType(); auto dstTy = op.getType(); + if (auto srcTensorTy = cast(srcTy)) { + if (auto dstTensorTy = cast(dstTy)) { + // TODO: replace this with proper conversion once conversion is removed + // from LoadStoreOpToLLVM. + if (intel::hasSubgroup2DBlockEncoding(srcTensorTy) && + intel::hasDotDpasEncoding(dstTensorTy)) { + // need to delete the op and do nothing + llvm::errs() << "need to delete op " << op << "\n"; + // what if we just delete it + rewriter.replaceOp(op, op.getSrc()); + return success(); + } + } + } + LinearLayout conversion = minimalCvtLayout(srcTy, dstTy); LinearLayout srcLayout = toLinearLayout(srcTy.getShape(), srcTy.getEncoding()); From 22b79fbbcfd1b04b60d5ed203b1cce0573db9a5d Mon Sep 17 00:00:00 2001 From: Alex Baden Date: Fri, 13 Jun 2025 01:51:14 +0000 Subject: [PATCH 05/20] remove debug code --- .../lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp | 3 --- .../intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp | 6 ------ 2 files changed, 9 deletions(-) diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp index 95097398d2..4c7b208a28 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -34,9 +34,6 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion // from LoadStoreOpToLLVM. if (intel::hasSubgroup2DBlockEncoding(srcTensorTy) && intel::hasDotDpasEncoding(dstTensorTy)) { - // need to delete the op and do nothing - llvm::errs() << "need to delete op " << op << "\n"; - // what if we just delete it rewriter.replaceOp(op, op.getSrc()); return success(); } diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp index 5e2dd9bfc6..536576789c 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -1431,7 +1431,6 @@ struct LoadOpConversion auto dpasTensorType = hasSubgroup2DBlockEncoding(tensorType) ? 
getDpasTypeFromCVTOp(op.getResult()) : tensorType; - llvm::errs() << "using dpas tensor type: " << dpasTensorType << "\n"; DpasEncodingAttr dpasLayout = getDpasLayout(dpasTensorType); DpasEncodingAttr::OpIdx opIdx = getOpIdx(dpasTensorType); @@ -1644,7 +1643,6 @@ struct LoadOpConversion // input operands to DPAS. // TODO: add support for int4 and int2. unsigned opsPerChannel = dpasLayout.getOpsPerChannel(); - llvm::errs() << "opsPerChannel = " << opsPerChannel << "\n"; if ((opsPerChannel == 4 && elemSizeInBits == 8) || (opsPerChannel == 2 && elemSizeInBits == 16) || (opsPerChannel == 1 && elemSizeInBits == 32)) { @@ -1868,8 +1866,6 @@ struct LoadOpConversion unsigned numValuesPerLoad = packedElemsPerLanePerDPASInst * numOperandsOuterDimPerLoad * numOperandsInnerDimPerLoad; - llvm::errs() << "num values per load = " << numValuesPerLoad << "\n"; - llvm::errs() << "loadResultElemType = " << loadResultElemType << "\n"; Type load2DGenXType = LLVM::getVectorType(loadResultElemType, numValuesPerLoad); @@ -2217,8 +2213,6 @@ struct LoadOpConversion } Type llvmResultStructTy = typeConverter->convertType(op.getType()); - llvm::errs() << "op.getType() " << op.getType() << "\n"; - llvm::errs() << "llvmResultStructTy: " << llvmResultStructTy << "\n"; Value resultStruct = packLLElements(loc, typeConverter, unpackedLoadedVals, rewriter, llvmResultStructTy); rewriter.replaceOp(op, {resultStruct}); From b84cbf90e293d255d8907ef03d1f29781e72c034 Mon Sep 17 00:00:00 2001 From: Alex Baden Date: Fri, 13 Jun 2025 16:09:21 +0000 Subject: [PATCH 06/20] do not add barrier op for subgroup 2d block -> dpas conversion --- .../Dialect/TritonIntelGPU/Transforms/Utility.h | 6 ++++++ third_party/intel/lib/Analysis/Allocation.cpp | 4 ++++ .../TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp | 13 +++++-------- .../intel/lib/TritonIntelGPUTransforms/Utility.cpp | 5 +++++ 4 files changed, 20 insertions(+), 8 deletions(-) diff --git a/third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Utility.h b/third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Utility.h index 580a9a6895..c356f07f20 100644 --- a/third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Utility.h +++ b/third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Utility.h @@ -33,6 +33,12 @@ Attribute inferSrcEncoding(Operation *op, Attribute encoding); // Retuns true if the operation is an expensive load or store operation. bool isExpensiveLoadOrStore(Operation *op); +// Returns true if the conversion between tensor types should be a no-op. 
Will +// be removed once layout conversion for BlockIO types is lifted from +// LoadStoreOpToLLVM.cpp +bool isBlockIONoOpConversion(RankedTensorType srcType, + RankedTensorType dstType); + // Returns true if the tensor type has a subgroup 2d block io encoding bool hasSubgroup2DBlockEncoding(RankedTensorType tensorType); diff --git a/third_party/intel/lib/Analysis/Allocation.cpp b/third_party/intel/lib/Analysis/Allocation.cpp index 8c9cfe5147..e6cd3df5d2 100644 --- a/third_party/intel/lib/Analysis/Allocation.cpp +++ b/third_party/intel/lib/Analysis/Allocation.cpp @@ -1,5 +1,6 @@ #include "intel/include/Analysis/Allocation.h" #include "intel/include/Analysis/Utility.h" +#include "intel/include/Dialect/TritonIntelGPU/Transforms/Utility.h" // isBlockIONoOpConversion #include "triton/Dialect/Triton/IR/Utility.h" #include "llvm/ADT/TypeSwitch.h" @@ -11,6 +12,9 @@ constexpr unsigned invalidSize = -1; unsigned allocationAnalysisScratchSizeFn(gpu::ConvertLayoutOp convertLayout) { RankedTensorType srcTy = convertLayout.getSrc().getType(); RankedTensorType dstTy = convertLayout.getResult().getType(); + + if (gpu::intel::isBlockIONoOpConversion(srcTy, dstTy)) + return 0; if (gpu::intel::cvtIsSubGroupShuffle(srcTy, dstTy)) return 0; if (gpu::intel::cvtIsSubGroupTranspose(srcTy, dstTy)) { diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp index 4c7b208a28..173cbdb29a 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -25,18 +25,15 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion ConversionPatternRewriter &rewriter) const override { MLIRContext *ctx = op.getContext(); - auto srcTy = op.getSrc().getType(); + RankedTensorType srcTy = op.getSrc().getType(); auto dstTy = op.getType(); - if (auto srcTensorTy = cast(srcTy)) { - if (auto dstTensorTy = cast(dstTy)) { + if (auto dstTensorTy = cast(dstTy)) { + if (intel::isBlockIONoOpConversion(srcTy, dstTensorTy)) { // TODO: replace this with proper conversion once conversion is removed // from LoadStoreOpToLLVM. 
- if (intel::hasSubgroup2DBlockEncoding(srcTensorTy) && - intel::hasDotDpasEncoding(dstTensorTy)) { - rewriter.replaceOp(op, op.getSrc()); - return success(); - } + rewriter.replaceOp(op, op.getSrc()); + return success(); } } diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/Utility.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/Utility.cpp index a96df14d14..f2aa2e6d7a 100644 --- a/third_party/intel/lib/TritonIntelGPUTransforms/Utility.cpp +++ b/third_party/intel/lib/TritonIntelGPUTransforms/Utility.cpp @@ -153,6 +153,11 @@ bool isExpensiveLoadOrStore(Operation *op) { return false; } +bool isBlockIONoOpConversion(RankedTensorType srcType, + RankedTensorType dstType) { + return hasSubgroup2DBlockEncoding(srcType) && hasDotDpasEncoding(dstType); +} + bool hasSubgroup2DBlockEncoding(RankedTensorType tensorType) { if (!tensorType.getEncoding()) return false; From f83bed86abf054006c5a71ceeb1829a851f1801f Mon Sep 17 00:00:00 2001 From: Alex Baden Date: Mon, 16 Jun 2025 19:09:02 +0000 Subject: [PATCH 07/20] fixup handling of tensor ptrs when lowering to gather load (with subgroup 2d block layout ) --- .../LoadStoreOpToLLVM.cpp | 37 +++++++++++-------- 1 file changed, 22 insertions(+), 15 deletions(-) diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp index 536576789c..d9fb23dfbf 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -343,6 +343,15 @@ struct BlockIOConversionBase : public LoadStoreConversionBase { : getDotEncoding(tensorTy).value().getParent()); } + static RankedTensorType getDpasTypeFromCVTOp(Value opResult) { + for (OpOperand user : opResult.getUsers()) { + if (auto cvt = dyn_cast(user.getOwner())) { + return cast(cvt.getResult().getType()); + } + } + llvm_unreachable("expected to find a cvt op with dpas layout"); + } + // Returns the pitch (stride in bytes) of \p ptr. Value getPitch(ConversionPatternRewriter &rewriter, Value ptr, const std::map, Value> &ptrs, @@ -1418,16 +1427,6 @@ struct LoadOpConversion const bool memoryRowMajor = isMemoryRowMajor(op); - auto getDpasTypeFromCVTOp = [&](Value opResult) -> RankedTensorType { - for (OpOperand user : opResult.getUsers()) { - if (auto cvt = dyn_cast(user.getOwner())) { - return cast(cvt.getResult().getType()); - // return getDpasLayout(cvt.getResult().getType()); - } - } - llvm_unreachable("expected to find a cvt op with dpas layout"); - }; - auto dpasTensorType = hasSubgroup2DBlockEncoding(tensorType) ? getDpasTypeFromCVTOp(op.getResult()) : tensorType; @@ -2213,6 +2212,8 @@ struct LoadOpConversion } Type llvmResultStructTy = typeConverter->convertType(op.getType()); + LLVM_DEBUG(llvm::dbgs() << "Packing load result in struct " + << llvmResultStructTy << "\n"); Value resultStruct = packLLElements(loc, typeConverter, unpackedLoadedVals, rewriter, llvmResultStructTy); rewriter.replaceOp(op, {resultStruct}); @@ -2235,10 +2236,16 @@ struct LoadOpConversion Value mask = op.getMask(); Value llMask = adaptor.getMask(); + auto opType = op.getType(); + // TODO: Override the OpType since conversion is still happening during Load + // lowering. Once we materialize ConvertLayoutOp this can be removed. 
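    // Sketch of the IR pattern this override relies on (assuming
    // RemoveLayoutConversions has already rewritten the convert to the DPAS
    // dot-operand layout; names and shapes here are illustrative):
    //   %v = tt.load %p ... : !tt.ptr<tensor<256x32xf16, #ttig.subgroup_2d_block<...>>>
    //   %c = ttg.convert_layout %v : ...
    //          -> tensor<256x32xf16, #ttg.dot_op<{opIdx = 0, parent = #dpas, kWidth = 1}>>
    // getDpasTypeFromCVTOp(op.getResult()) returns %c's type, so the gather
    // fallback computes its element count and vectorization from the DPAS
    // layout instead of the block-io layout.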
+ if (auto tensorTy = dyn_cast(opType); + hasSubgroup2DBlockEncoding(tensorTy)) + opType = getDpasTypeFromCVTOp(op.getResult()); + // Determine the vectorization size - Type valueElemTy = - typeConverter->convertType(getElementTypeOrSelf(op.getType())); - unsigned numElems = getTotalElemsPerThread(op.getType()); + Type valueElemTy = typeConverter->convertType(getElementTypeOrSelf(opType)); + unsigned numElems = getTotalElemsPerThread(opType); unsigned vec = getVectorSize(ptr); if (llMask) vec = std::min(vec, getMaskAlignment(mask)); @@ -2249,7 +2256,7 @@ struct LoadOpConversion if (isTensorPointerType(ptr.getType())) { // fallback to gather load. - auto tensorType = cast(op.getType()); + auto tensorType = cast(opType); std::tie(ptrElems, maskElems, otherElems) = convertBlockPtrToTensorOfPtr( loc, adaptor.getPtr(), tensorType, valueElemTy, rewriter, op.getBoundaryCheck(), op.getPadding()); @@ -2396,7 +2403,7 @@ struct LoadOpConversion } } // end vec - Type llvmResultStructTy = typeConverter->convertType(op.getType()); + Type llvmResultStructTy = typeConverter->convertType(opType); Value resultStruct = packLLElements(loc, typeConverter, loadedVals, rewriter, llvmResultStructTy); rewriter.replaceOp(op, {resultStruct}); From 2b354d23a378454935ec4f1a080bd9dd3c378f41 Mon Sep 17 00:00:00 2001 From: Alex Baden Date: Mon, 16 Jun 2025 19:24:37 +0000 Subject: [PATCH 08/20] use dpas tensor type for packed load type (essentially "post-conversion") --- .../intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp index d9fb23dfbf..8b31616c67 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -2211,7 +2211,7 @@ struct LoadOpConversion } } - Type llvmResultStructTy = typeConverter->convertType(op.getType()); + Type llvmResultStructTy = typeConverter->convertType(dpasTensorType); LLVM_DEBUG(llvm::dbgs() << "Packing load result in struct " << llvmResultStructTy << "\n"); Value resultStruct = packLLElements(loc, typeConverter, unpackedLoadedVals, From f06f90ba4e0dfff4ed9bad13c86f934158a5b472 Mon Sep 17 00:00:00 2001 From: Alex Baden Date: Mon, 16 Jun 2025 20:00:22 +0000 Subject: [PATCH 09/20] fixup op type override --- .../lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp index 8b31616c67..cb0b5a1f29 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -2239,8 +2239,8 @@ struct LoadOpConversion auto opType = op.getType(); // TODO: Override the OpType since conversion is still happening during Load // lowering. Once we materialize ConvertLayoutOp this can be removed. - if (auto tensorTy = dyn_cast(opType); - hasSubgroup2DBlockEncoding(tensorTy)) + auto tensorTy = dyn_cast(opType); + if (tensorTy && hasSubgroup2DBlockEncoding(tensorTy)) opType = getDpasTypeFromCVTOp(op.getResult()); // Determine the vectorization size @@ -2256,9 +2256,11 @@ struct LoadOpConversion if (isTensorPointerType(ptr.getType())) { // fallback to gather load. 
- auto tensorType = cast(opType); + // make sure we use the modified opType from above, "seeing through" any + // post-subgroup 2d block encoding CVT. + auto blockPtrTensorType = cast(opType); std::tie(ptrElems, maskElems, otherElems) = convertBlockPtrToTensorOfPtr( - loc, adaptor.getPtr(), tensorType, valueElemTy, rewriter, + loc, adaptor.getPtr(), blockPtrTensorType, valueElemTy, rewriter, op.getBoundaryCheck(), op.getPadding()); } else { Value other = op.getOther(); From 7e5b52b114e6407c67b888384bf3ea2e8d95f129 Mon Sep 17 00:00:00 2001 From: Alex Baden Date: Tue, 17 Jun 2025 18:15:46 +0000 Subject: [PATCH 10/20] do not apply subgroup 2d block encoding to A transpose --- .../optimize-block-io-encoding.mlir | 64 +++++++++++++++++++ .../LoadStoreOpToLLVM.cpp | 5 ++ .../OptimizeBlockIOEncoding.cpp | 33 ++++++++-- 3 files changed, 98 insertions(+), 4 deletions(-) diff --git a/test/TritonIntelGPU/optimize-block-io-encoding.mlir b/test/TritonIntelGPU/optimize-block-io-encoding.mlir index 68174d5d90..8a9b6f184a 100644 --- a/test/TritonIntelGPU/optimize-block-io-encoding.mlir +++ b/test/TritonIntelGPU/optimize-block-io-encoding.mlir @@ -63,3 +63,67 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, ttg.tar tt.return } } + +// ----- + +// COM: Dot operand A transpose currently not supported by subgroup 2d block io encoding +#blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 16], warpsPerCTA = [8, 4], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [16, 1], warpsPerCTA = [2, 16], order = [0, 1]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 16], warpsPerCTA = [16, 2], order = [1, 0]}> +// CHECK: #mma = #ttig.subgroup_2d_block<{warpsPerCTA = [8, 4], instrShape = [16, 16], numBlocks=2, order=[0, 1], kWidth=2, threadsPerWarp=16}> +// CHECK: #mma1 = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}> +// CHECK-NOT: #mma2 +#mma = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, ttg.target = "xpu", "ttg.threads-per-warp" = 16 : i32, ttig.min_sg_size = 16 : i32, ttig.support_bf16_conversion, ttig.support_dpas, ttig.support_sg_2d_block, ttig.target_arch = "spir64"} { + tt.func public @matmul_kernel_with_block_pointers(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %c4_i32 = arith.constant 4 : i32 + %c256_i32 = arith.constant 256 : i32 + %c1024_i64 = arith.constant 1024 : i64 + %c5120_i64 = arith.constant 5120 : i64 + %c1_i64 = arith.constant 1 : i64 + %c0_i32 = arith.constant 0 : i32 + %c4096_i64 = arith.constant 4096 : i64 + %c32_i32 = arith.constant 32 : i32 + %c64_i32 = arith.constant 64 : i32 + %c5120_i32 = arith.constant 5120 : i32 + %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #blocked> + %0 = tt.get_program_id x : i32 + %1 = arith.divsi %0, %c64_i32 : i32 + %2 = arith.muli %1, %c4_i32 : i32 + %3 = arith.subi %c4_i32, %2 : i32 + %4 = arith.minsi %3, %c4_i32 : i32 + %5 = arith.remsi %0, %4 : i32 + %6 = arith.addi %2, %5 : i32 + %7 = arith.remsi %0, %c64_i32 : i32 + %8 = arith.divsi %7, %4 : i32 + %9 = 
arith.muli %6, %c256_i32 : i32 + %10 = tt.make_tensor_ptr %arg0, [%c1024_i64, %c5120_i64], [%c1_i64, %c1024_i64], [%9, %c0_i32] {order = array} : > + %11 = arith.muli %8, %c256_i32 : i32 + %12 = tt.make_tensor_ptr %arg1, [%c5120_i64, %c4096_i64], [%c4096_i64, %c1_i64], [%c0_i32, %11] {order = array} : > + %13:3 = scf.for %arg3 = %c0_i32 to %c5120_i32 step %c32_i32 iter_args(%arg4 = %cst, %arg5 = %10, %arg6 = %12) -> (tensor<256x256xf32, #blocked>, !tt.ptr>, !tt.ptr>) : i32 { + // CHECK: {{.*}} = tt.load %arg5 {boundaryCheck = array, ttig.block_io = "column_major"} : !tt.ptr> + %17 = tt.load %arg5 {boundaryCheck = array, ttig.block_io = "column_major"} : !tt.ptr> + // CHECK: %[[B_LOAD:.*]] = tt.load %arg6 {boundaryCheck = array, ttig.block_io = "row_major"} : !tt.ptr> + // CHECK: {{.*}} = ttg.convert_layout %[[B_LOAD]] : tensor<32x256xf16, #mma> -> tensor<32x256xf16, #blocked2> + %18 = tt.load %arg6 {boundaryCheck = array, ttig.block_io = "row_major"} : !tt.ptr> + %19 = ttg.convert_layout %17 : tensor<256x32xf16, #blocked1> -> tensor<256x32xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> + %20 = ttg.convert_layout %18 : tensor<32x256xf16, #blocked2> -> tensor<32x256xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> + %21 = ttg.convert_layout %arg4 : tensor<256x256xf32, #blocked> -> tensor<256x256xf32, #mma> + %22 = ttg.convert_layout %19 : tensor<256x32xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> -> tensor<256x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> + %23 = ttg.convert_layout %20 : tensor<32x256xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<32x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + // CHECK: tt.dot {{.*}} : tensor<256x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 1}>> * tensor<32x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma1, kWidth = 2}>> -> tensor<256x256xf32, #mma1> + %24 = tt.dot %22, %23, %21, inputPrecision = tf32 : tensor<256x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<256x256xf32, #mma> + %25 = ttg.convert_layout %24 : tensor<256x256xf32, #mma> -> tensor<256x256xf32, #blocked> + // CHECK: tt.advance {{.*}} : > + %26 = tt.advance %arg5, [%c0_i32, %c32_i32] : > + // CHECK: tt.advance {{.*}} : > + %27 = tt.advance %arg6, [%c32_i32, %c0_i32] : > + scf.yield %25, %26, %27 : tensor<256x256xf32, #blocked>, !tt.ptr>, !tt.ptr> + } + %14 = tt.make_tensor_ptr %arg2, [%c1024_i64, %c4096_i64], [%c4096_i64, %c1_i64], [%9, %11] {order = array} : > + %15 = arith.truncf %13#0 : tensor<256x256xf32, #blocked> to tensor<256x256xf16, #blocked> + %16 = ttg.convert_layout %15 : tensor<256x256xf16, #blocked> -> tensor<256x256xf16, #blocked2> + tt.store %14, %16 {boundaryCheck = array} : !tt.ptr> + tt.return + } +} diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp index cb0b5a1f29..7fcc5eb253 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -2247,6 +2247,11 @@ struct LoadOpConversion Type valueElemTy = typeConverter->convertType(getElementTypeOrSelf(opType)); unsigned numElems = getTotalElemsPerThread(opType); unsigned vec = getVectorSize(ptr); + LLVM_DEBUG({ + llvm::dbgs() << "Vectorization for gather load:\n"; + llvm::dbgs() << "\t" << valueElemTy << " [" << numElems << "]\n"; + llvm::dbgs() << "\tvector size = " << vec << " for " 
<< ptr << "\n"; + }); if (llMask) vec = std::min(vec, getMaskAlignment(mask)); diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeBlockIOEncoding.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeBlockIOEncoding.cpp index 14ca36f31b..db912b34ba 100644 --- a/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeBlockIOEncoding.cpp +++ b/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeBlockIOEncoding.cpp @@ -222,6 +222,15 @@ class TritonIntelGPUOptimizeBlockIOEncodingPass return; auto dotOperandType = cast(operand.getType()); + auto layout = ttg::toLinearEncoding(dotOperandType); + auto order = layout.getThreadOrder(); + auto rank = order.size(); + if (rank != 2) { + loadOp.emitWarning( + "Subgroup 2D Block Encoding layouts only support rank 2 operands."); + return; + } + auto dotOperandEncoding = cast(dotOperandType.getEncoding()); // layout width is determined by the DPAS operand encoding width @@ -232,6 +241,23 @@ class TritonIntelGPUOptimizeBlockIOEncodingPass if (!blockIOAttr) return; + const bool valueRowMajor = + getOrderForDotOperand(0, rank, /*kContig=*/true) == order; + const bool memoryRowMajor = + blockIOAttr == StringAttr::get(&getContext(), "row_major"); + const bool isTransposeRequired = valueRowMajor ^ memoryRowMajor; + LLVM_DEBUG({ + DBGS() << "Original layout: " << dotOperandEncoding << "\n"; + DBGS() << "\tvalueRowMajor = " << valueRowMajor << "\n"; + DBGS() << "\tmemoryRowMajor = " << memoryRowMajor << "\n"; + DBGS() << "\tisTransposeRequired = " << isTransposeRequired << "\n"; + }); + if (dotOperandEncoding.getOpIdx() == 0 && isTransposeRequired) { + LLVM_DEBUG(DBGS() << "Transposed 'A' operand does not yet support " + "Subgroup 2D Block Encoding layout.\n"); + return; + } + // get the MakeTensorPtr Op for the load Value ptr = loadOp.getPtr(); if (!isTensorPointerType(ptr.getType())) { @@ -261,16 +287,15 @@ class TritonIntelGPUOptimizeBlockIOEncodingPass auto tileParams = Subgroup2DBlockEncodingAttr::getInstrShapeForLayout( cast(dotOperandEncoding), - oldTensorType.getShape(), - blockIOAttr == StringAttr::get(&getContext(), "row_major"), - elemSizeInBits / 8, &getContext()); + oldTensorType.getShape(), memoryRowMajor, elemSizeInBits / 8, + &getContext()); SmallVector instrShape{tileParams[0], tileParams[1]}; const unsigned vBlocks = tileParams[2]; auto subgroup2DBlockEncoding = Subgroup2DBlockEncodingAttr::get( &getContext(), dpasLayout.getWarpsPerCTA(), CTALayout, instrShape, tileParams[2], - getOrderForDotOperand(dotOperandEncoding.getOpIdx(), /*rank*/ 2, + getOrderForDotOperand(dotOperandEncoding.getOpIdx(), /*rank*/ rank, /*kContig*/ true), kWidth, dpasLayout.getThreadsPerWarp()); From 234897feaf459afe6080afa16d4bbf55d78fdecd Mon Sep 17 00:00:00 2001 From: Alex Baden Date: Fri, 20 Jun 2025 15:05:48 +0000 Subject: [PATCH 11/20] Store transpose attribute in Subgroup2DBlockIO layouts --- .../optimize-block-io-encoding.mlir | 74 ++++++++++++++++++- .../IR/TritonIntelGPUAttrDefs.td | 4 +- .../lib/Dialect/TritonIntelGPU/IR/Dialect.cpp | 46 ++++++++---- .../IR/LinearLayoutConversions.cpp | 6 +- .../LoadStoreOpToLLVM.cpp | 3 +- .../OptimizeBlockIOEncoding.cpp | 6 +- .../LinearLayoutConversionsTest.cpp | 53 ++++++++++--- 7 files changed, 156 insertions(+), 36 deletions(-) diff --git a/test/TritonIntelGPU/optimize-block-io-encoding.mlir b/test/TritonIntelGPU/optimize-block-io-encoding.mlir index 8a9b6f184a..0e135a956e 100644 --- a/test/TritonIntelGPU/optimize-block-io-encoding.mlir +++ b/test/TritonIntelGPU/optimize-block-io-encoding.mlir @@ 
-3,8 +3,8 @@ #blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 16], warpsPerCTA = [8, 4], order = [1, 0]}> #blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 4], warpsPerCTA = [32, 1], order = [1, 0]}> #blocked2 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 16], warpsPerCTA = [16, 2], order = [1, 0]}> -// CHECK: #mma = #ttig.subgroup_2d_block<{warpsPerCTA = [8, 4], instrShape = [8, 16], numBlocks=2, order=[1, 0], kWidth=1, threadsPerWarp=16}> -// CHECK: #mma1 = #ttig.subgroup_2d_block<{warpsPerCTA = [8, 4], instrShape = [16, 16], numBlocks=2, order=[0, 1], kWidth=2, threadsPerWarp=16}> +// CHECK: #mma = #ttig.subgroup_2d_block<{warpsPerCTA = [8, 4], instrShape = [8, 16], numBlocks = 2, isTransposed = false, order = [1, 0], kWidth = 1, threadsPerWarp = 16} +// CHECK: #mma1 = #ttig.subgroup_2d_block<{warpsPerCTA = [8, 4], instrShape = [16, 16], numBlocks = 2, isTransposed = false, order = [0, 1], kWidth = 2, threadsPerWarp = 16} // CHECK: #mma2 = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}> #mma = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}> module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, ttg.target = "xpu", "ttg.threads-per-warp" = 16 : i32, ttig.min_sg_size = 16 : i32, ttig.support_bf16_conversion, ttig.support_dpas, ttig.support_sg_2d_block, ttig.target_arch = "spir64"} { @@ -66,11 +66,79 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, ttg.tar // ----- +// COM: Dot Operand B transpose is supported +#blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 16], warpsPerCTA = [8, 4], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 4], warpsPerCTA = [32, 1], order = [1, 0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 4], warpsPerCTA = [1, 32], order = [0, 1]}> +#blocked3 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 16], warpsPerCTA = [16, 2], order = [1, 0]}> +// CHECK: #mma = #ttig.subgroup_2d_block<{warpsPerCTA = [8, 4], instrShape = [8, 16], numBlocks = 2, isTransposed = false, order = [1, 0], kWidth = 1, threadsPerWarp = 16}> +// CHECK: #mma1 = #ttig.subgroup_2d_block<{warpsPerCTA = [8, 4], instrShape = [16, 8], numBlocks = 1, isTransposed = true, order = [0, 1], kWidth = 2, threadsPerWarp = 16}> +// CHECK: #mma2 = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}> +#mma = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, ttg.target = "xpu", "ttg.threads-per-warp" = 16 : i32, ttig.min_sg_size = 16 : i32, ttig.support_bf16_conversion, ttig.support_dpas, ttig.support_sg_2d_block, ttig.target_arch = "spir64"} { + tt.func public @matmul_kernel_with_block_pointers(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %c4_i32 = arith.constant 4 
: i32 + %c256_i32 = arith.constant 256 : i32 + %c1024_i64 = arith.constant 1024 : i64 + %c5120_i64 = arith.constant 5120 : i64 + %c1_i64 = arith.constant 1 : i64 + %c0_i32 = arith.constant 0 : i32 + %c4096_i64 = arith.constant 4096 : i64 + %c32_i32 = arith.constant 32 : i32 + %c64_i32 = arith.constant 64 : i32 + %c5120_i32 = arith.constant 5120 : i32 + %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #blocked> + %0 = tt.get_program_id x : i32 + %1 = arith.divsi %0, %c64_i32 : i32 + %2 = arith.muli %1, %c4_i32 : i32 + %3 = arith.subi %c4_i32, %2 : i32 + %4 = arith.minsi %3, %c4_i32 : i32 + %5 = arith.remsi %0, %4 : i32 + %6 = arith.addi %2, %5 : i32 + %7 = arith.remsi %0, %c64_i32 : i32 + %8 = arith.divsi %7, %4 : i32 + %9 = arith.muli %6, %c256_i32 : i32 + // CHECK: tt.make_tensor_ptr {{.*}} : > + %10 = tt.make_tensor_ptr %arg0, [%c1024_i64, %c5120_i64], [%c5120_i64, %c1_i64], [%9, %c0_i32] {order = array} : > + %11 = arith.muli %8, %c256_i32 : i32 + // CHECK: tt.make_tensor_ptr {{.*}} : > + %12 = tt.make_tensor_ptr %arg1, [%c5120_i64, %c4096_i64], [%c1_i64, %c5120_i64], [%c0_i32, %11] {order = array} : > + %13:3 = scf.for %arg3 = %c0_i32 to %c5120_i32 step %c32_i32 iter_args(%arg4 = %cst, %arg5 = %10, %arg6 = %12) -> (tensor<256x256xf32, #blocked>, !tt.ptr>, !tt.ptr>) : i32 { + // CHECK: %[[A_LOAD:.*]] = tt.load %arg5 {boundaryCheck = array, ttig.block_io = "row_major"} : !tt.ptr> + // CHECK: {{.*}} = ttg.convert_layout %[[A_LOAD]] : tensor<256x32xf16, #mma> -> tensor<256x32xf16, #blocked1> + %17 = tt.load %arg5 {boundaryCheck = array, ttig.block_io = "row_major"} : !tt.ptr> + // CHECK: %[[B_LOAD:.*]] = tt.load %arg6 {boundaryCheck = array, ttig.block_io = "column_major"} : !tt.ptr> + // CHECK: {{.*}} = ttg.convert_layout %[[B_LOAD]] : tensor<32x256xf16, #mma1> -> tensor<32x256xf16, #blocked2> + %18 = tt.load %arg6 {boundaryCheck = array, ttig.block_io = "column_major"} : !tt.ptr> + %19 = ttg.convert_layout %17 : tensor<256x32xf16, #blocked1> -> tensor<256x32xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> + %20 = ttg.convert_layout %18 : tensor<32x256xf16, #blocked2> -> tensor<32x256xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> + %21 = ttg.convert_layout %arg4 : tensor<256x256xf32, #blocked> -> tensor<256x256xf32, #mma> + %22 = ttg.convert_layout %19 : tensor<256x32xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> -> tensor<256x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> + %23 = ttg.convert_layout %20 : tensor<32x256xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<32x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + // CHECK: tt.dot {{.*}} : tensor<256x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma2, kWidth = 1}>> * tensor<32x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma2, kWidth = 2}>> -> tensor<256x256xf32, #mma2> + %24 = tt.dot %22, %23, %21, inputPrecision = tf32 : tensor<256x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<256x256xf32, #mma> + %25 = ttg.convert_layout %24 : tensor<256x256xf32, #mma> -> tensor<256x256xf32, #blocked> + // CHECK: tt.advance {{.*}} : > + %26 = tt.advance %arg5, [%c0_i32, %c32_i32] : > + // CHECK: tt.advance {{.*}} : > + %27 = tt.advance %arg6, [%c32_i32, %c0_i32] : > + scf.yield %25, %26, %27 : tensor<256x256xf32, #blocked>, !tt.ptr>, !tt.ptr> + } + %14 = tt.make_tensor_ptr %arg2, [%c1024_i64, %c4096_i64], [%c4096_i64, %c1_i64], [%9, %11] {order = array} : > + %15 = arith.truncf %13#0 : 
tensor<256x256xf32, #blocked> to tensor<256x256xf16, #blocked>
+    %16 = ttg.convert_layout %15 : tensor<256x256xf16, #blocked> -> tensor<256x256xf16, #blocked3>
+    tt.store %14, %16 {boundaryCheck = array} : !tt.ptr>
+    tt.return
+  }
+}
+
+// -----
+
 // COM: Dot operand A transpose currently not supported by subgroup 2d block io encoding
 #blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 16], warpsPerCTA = [8, 4], order = [1, 0]}>
 #blocked1 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [16, 1], warpsPerCTA = [2, 16], order = [0, 1]}>
 #blocked2 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 16], warpsPerCTA = [16, 2], order = [1, 0]}>
-// CHECK: #mma = #ttig.subgroup_2d_block<{warpsPerCTA = [8, 4], instrShape = [16, 16], numBlocks=2, order=[0, 1], kWidth=2, threadsPerWarp=16}>
+// CHECK: #mma = #ttig.subgroup_2d_block<{warpsPerCTA = [8, 4], instrShape = [16, 16], numBlocks = 2, isTransposed = false, order = [0, 1], kWidth = 2, threadsPerWarp = 16}>
 // CHECK: #mma1 = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}>
 // CHECK-NOT: #mma2
 #mma = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}>
diff --git a/third_party/intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUAttrDefs.td b/third_party/intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUAttrDefs.td
index 254ea42b47..bb3198c3ce 100644
--- a/third_party/intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUAttrDefs.td
+++ b/third_party/intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUAttrDefs.td
@@ -297,6 +297,7 @@ def Subgroup2DBlockEncodingAttr : DistributedEncoding<"Subgroup2DBlockEncoding",
     For the layout, the following parameters are required:
       - `instrShape` : contains the (height, width) block parameters for the block io operation
       - `numBlocks` : the block count parameter allows a single load to load multiple blocks in row-major order (useful for increasing cache line utilization)
+      - `isTransposed` : indicates whether the data should be transposed post-load. The `instrShape` describes the shape of the data to load pre-transpose, i.e. if this is true then the output from the instruction (load + transpose) will be the transposed `instrShape`.
       - `threadsPerWarp` : currently a scalar, this parameter allows us to support different subgroup / warp configurations. Because the 2d block io operation is a subgroup operation, the size of the subgroup is important in determining the ordering of the loaded tensor.
       - `warpsPerCTA` : the number of warps per block / subgroups per workgroup and their distribution
       - `order` : The order within the block, used to determine along which dimension to broadcast. 
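To make the pre-transpose convention concrete: with `instrShape = [32, 16]` and `isTransposed = true` (the shape used by the transposed-B unit test later in this series), the tile seen by consumers is 16x32, the same swap that `subgroup2DBlockToLinearLayout` applies to the load tile size. The following is a minimal standalone sketch of that relationship only; the helper name and values are invented for illustration and this is not the attribute implementation.

#include <array>
#include <cassert>
#include <utility>

// Sketch only: instrShape is the pre-transpose load shape, so the effective
// (consumed) tile is the swapped shape when isTransposed is set.
static std::array<unsigned, 2> effectiveTileShape(std::array<unsigned, 2> instrShape,
                                                  bool isTransposed) {
  if (isTransposed)
    std::swap(instrShape[0], instrShape[1]);
  return instrShape; // {height, width} as consumed after load (+ transpose)
}

int main() {
  auto tile = effectiveTileShape({32, 16}, /*isTransposed=*/true);
  assert(tile[0] == 16 && tile[1] == 32);
  return 0;
}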
@@ -310,6 +311,7 @@ def Subgroup2DBlockEncodingAttr : DistributedEncoding<"Subgroup2DBlockEncoding", "CTALayoutAttr":$CTALayout, ArrayRefParameter<"unsigned">:$instrShape, "unsigned":$numBlocks, + "bool":$isTransposed, ArrayRefParameter<"unsigned">:$order, "unsigned":$kWidth, "unsigned":$threadsPerWarp @@ -317,7 +319,7 @@ def Subgroup2DBlockEncodingAttr : DistributedEncoding<"Subgroup2DBlockEncoding", let extraClassDeclaration = extraDistributedDeclaration # [{ SmallVector getRepOrderForOperand(int opIdx) const; - static SmallVector getInstrShapeForLayout(DistributedEncodingTrait layout, ArrayRef shape, bool memoryRowMajor, unsigned kWidth, MLIRContext* context); + static SmallVector getInstrShapeForLayout(DistributedEncodingTrait layout, ArrayRef shape, bool memoryRowMajor, bool isTransposed, unsigned kWidth, MLIRContext* context); }]; let hasCustomAssemblyFormat = 1; diff --git a/third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp b/third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp index 138fccf6c0..a0d5996398 100644 --- a/third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp +++ b/third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp @@ -59,6 +59,17 @@ static LogicalResult parseIntAttrValue(AsmParser &parser, Attribute attr, return success(); } +static LogicalResult parseBoolAttrValue(AsmParser &parser, Attribute attr, + bool &value, StringRef desc) { + auto boolAttr = mlir::dyn_cast(attr); + if (!boolAttr) { + parser.emitError(parser.getNameLoc(), "expected a bool type in ") << desc; + return failure(); + } + value = boolAttr.getValue(); + return success(); +} + // parse an array of integers static LogicalResult parseIntArrayAttr(AsmParser &parser, const NamedAttribute &attr, @@ -83,6 +94,11 @@ static LogicalResult parseUInt(AsmParser &parser, const NamedAttribute &attr, return parseIntAttrValue(parser, attr.getValue(), value, desc); }; +static LogicalResult parseBool(AsmParser &parser, const NamedAttribute &attr, + bool &value, StringRef desc) { + return parseBoolAttrValue(parser, attr.getValue(), value, desc); +}; + //===----------------------------------------------------------------------===// // Attribute methods //===----------------------------------------------------------------------===// @@ -531,8 +547,8 @@ void maybePrintCTALayout(mlir::MLIRContext *context, mlir::AsmPrinter &printer, LogicalResult Subgroup2DBlockEncodingAttr::verify( function_ref emitError, ArrayRef warpsPerCTA, CTALayoutAttr CTALayout, - ArrayRef instrShape, unsigned numBlocks, ArrayRef order, - unsigned kWidth, unsigned threadsPerWarp) { + ArrayRef instrShape, unsigned numBlocks, bool isTransposed, + ArrayRef order, unsigned kWidth, unsigned threadsPerWarp) { if (instrShape.size() != 2) { return emitError() << "instrShape must be rank 2 but was: " << instrShape.size(); @@ -569,6 +585,7 @@ Attribute Subgroup2DBlockEncodingAttr::parse(AsmParser &parser, Type type) { std::optional> CTAOrder; SmallVector instrShape; unsigned numBlocks = 0; + bool isTransposed = false; SmallVector order; unsigned kWidth = 0; unsigned threadsPerWarp = 0; @@ -601,6 +618,10 @@ Attribute Subgroup2DBlockEncodingAttr::parse(AsmParser &parser, Type type) { if (parseUInt(parser, attr, numBlocks, "numBlocks").failed()) return {}; } + if (attr.getName() == "isTransposed") { + if (parseBool(parser, attr, isTransposed, "isTransposed").failed()) + return {}; + } if (attr.getName() == "order") { if (parseIntArrayAttr(parser, attr, order, "order").failed()) return {}; @@ -622,7 +643,7 @@ Attribute 
Subgroup2DBlockEncodingAttr::parse(AsmParser &parser, Type type) { return parser.getChecked( parser.getContext(), warpsPerCTA, *CTALayout, instrShape, numBlocks, - order, kWidth, threadsPerWarp); + isTransposed, order, kWidth, threadsPerWarp); } SmallVector Subgroup2DBlockEncodingAttr::getRepOrder() const { @@ -652,9 +673,10 @@ void Subgroup2DBlockEncodingAttr::print(AsmPrinter &printer) const { maybePrintCTALayout(getContext(), printer, getCTALayout(), getRank()); printer << ", instrShape = [" << getInstrShape() - << "], numBlocks=" << getNumBlocks() << ", order=[" << getOrder() - << "], kWidth=" << getKWidth() - << ", threadsPerWarp=" << getThreadsPerWarp() << "}>"; + << "], numBlocks = " << getNumBlocks() + << ", isTransposed = " << getIsTransposed() << ", order = [" + << getOrder() << "], kWidth = " << getKWidth() + << ", threadsPerWarp = " << getThreadsPerWarp() << "}>"; } LinearLayout @@ -664,7 +686,8 @@ Subgroup2DBlockEncodingAttr::toLinearLayout(ArrayRef shape) const { SmallVector Subgroup2DBlockEncodingAttr::getInstrShapeForLayout( DistributedEncodingTrait layout, ArrayRef tensorShape, - bool memoryRowMajor, unsigned kWidth, MLIRContext *context) { + bool memoryRowMajor, bool isTransposed, unsigned kWidth, + MLIRContext *context) { const auto rank = tensorShape.size(); std::optional llEncoding = layout.toLinearLayout(tensorShape); @@ -672,13 +695,6 @@ SmallVector Subgroup2DBlockEncodingAttr::getInstrShapeForLayout( LinearEncodingAttr llAttr = LinearEncodingAttr::get(context, *llEncoding); SmallVector threadOrder = llAttr.getThreadOrder(); - const bool valueRowMajor = - (threadOrder[rank - 2] == 1 && threadOrder[rank - 1] == 0); - assert((valueRowMajor || - (threadOrder[rank - 2] == 0 && threadOrder[rank - 1] == 1)) && - "Only row_major or column_major is allowed"); - const bool isTransposeRequired = valueRowMajor ^ memoryRowMajor; - auto dotEncodingAttr = dyn_cast(layout); const unsigned opIdx = dotEncodingAttr ? dotEncodingAttr.getOpIdx() : 2; @@ -725,7 +741,7 @@ SmallVector Subgroup2DBlockEncodingAttr::getInstrShapeForLayout( unsigned dpasOperandsPerTileY = isOperandA ? 
numReps[2] : repCluster[dimOuter]; - if (isTransposeRequired) { + if (isTransposed) { std::swap(tileWidth, tileHeight); const unsigned threadsPerWarp = dpasLayout.getThreadsPerWarp(); diff --git a/third_party/intel/lib/Dialect/TritonIntelGPU/IR/LinearLayoutConversions.cpp b/third_party/intel/lib/Dialect/TritonIntelGPU/IR/LinearLayoutConversions.cpp index 64cc423629..ddf46f5f0a 100644 --- a/third_party/intel/lib/Dialect/TritonIntelGPU/IR/LinearLayoutConversions.cpp +++ b/third_party/intel/lib/Dialect/TritonIntelGPU/IR/LinearLayoutConversions.cpp @@ -602,11 +602,15 @@ subgroup2DBlockToLinearLayout(ArrayRef blockShape, assert(rank == layout.getRank() && "unexpected block shape rank, layout rank " "and block shape rank must be equal"); auto dimNames = standardOutDimNames(ctx, rank); - auto loadTileSize = layout.getInstrShape(); + auto loadTileSize = SmallVector(layout.getInstrShape()); + assert(loadTileSize.size() == 2); StringAttr kRegister = S("register"); StringAttr kLane = S("lane"); StringAttr kWarp = S("warp"); + if (layout.getIsTransposed()) + std::swap(loadTileSize[0], loadTileSize[1]); + // Start by creating register/lane bases corresponding to the desired load // tile size auto [regBases, laneBases] = createRegisterLaneBases( diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp index 7fcc5eb253..7ca9a2822c 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -1467,7 +1467,8 @@ struct LoadOpConversion } else { auto tileParams = Subgroup2DBlockEncodingAttr::getInstrShapeForLayout( cast(encoding), tensorType.getShape(), - memoryRowMajor, elemSizeInBits / 8, rewriter.getContext()); + memoryRowMajor, isTransposeRequired, elemSizeInBits / 8, + rewriter.getContext()); return std::make_tuple(tileParams[0], tileParams[1], tileParams[2]); } }; diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeBlockIOEncoding.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeBlockIOEncoding.cpp index db912b34ba..effd387df6 100644 --- a/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeBlockIOEncoding.cpp +++ b/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeBlockIOEncoding.cpp @@ -287,14 +287,14 @@ class TritonIntelGPUOptimizeBlockIOEncodingPass auto tileParams = Subgroup2DBlockEncodingAttr::getInstrShapeForLayout( cast(dotOperandEncoding), - oldTensorType.getShape(), memoryRowMajor, elemSizeInBits / 8, - &getContext()); + oldTensorType.getShape(), memoryRowMajor, isTransposeRequired, + elemSizeInBits / 8, &getContext()); SmallVector instrShape{tileParams[0], tileParams[1]}; const unsigned vBlocks = tileParams[2]; auto subgroup2DBlockEncoding = Subgroup2DBlockEncodingAttr::get( &getContext(), dpasLayout.getWarpsPerCTA(), CTALayout, instrShape, - tileParams[2], + tileParams[2], isTransposeRequired, getOrderForDotOperand(dotOperandEncoding.getOpIdx(), /*rank*/ rank, /*kContig*/ true), kWidth, dpasLayout.getThreadsPerWarp()); diff --git a/third_party/intel/unittest/Dialect/TritonIntelGPU/LinearLayoutConversionsTest.cpp b/third_party/intel/unittest/Dialect/TritonIntelGPU/LinearLayoutConversionsTest.cpp index 2e1eee2adf..70ecacf335 100644 --- a/third_party/intel/unittest/Dialect/TritonIntelGPU/LinearLayoutConversionsTest.cpp +++ b/third_party/intel/unittest/Dialect/TritonIntelGPU/LinearLayoutConversionsTest.cpp @@ -21,9 +21,10 @@ class LinearLayoutConversionsTest : public ::testing::Test 
{ // Create a Subgroup2DBlockEncoding layout based on a DPAS layout Subgroup2DBlockEncodingAttr - sdb(ArrayRef instrShape, unsigned numBlocks, unsigned kWidth, - ArrayRef warpsPerCTA, ArrayRef repCluster, - ArrayRef blockShape, unsigned opsPerChannel, unsigned opIdx) { + sdb(ArrayRef instrShape, unsigned numBlocks, bool isTransposed, + unsigned kWidth, ArrayRef warpsPerCTA, + ArrayRef repCluster, ArrayRef blockShape, + unsigned opsPerChannel, unsigned opIdx) { auto dpasLayout = DpasEncodingAttr::get( &ctx, /*repeatCount=*/8, /*systolicDepth=*/8, /*executionSize=*/16, opsPerChannel, warpsPerCTA, repCluster, @@ -35,7 +36,7 @@ class LinearLayoutConversionsTest : public ::testing::Test { CTALayoutAttr::get( &ctx, dpasLayout.getCTAsPerCGA(), // TODO: add to DpasLayout? dpasLayout.getCTASplitNum(), dpasLayout.getCTAOrder()), - instrShape, numBlocks, + instrShape, numBlocks, isTransposed, getOrderForDotOperand(opIdx, /*rank*/ 2, /*kContig*/ true), kWidth, dpasLayout.getThreadsPerWarp()); return layout; @@ -51,7 +52,8 @@ TEST_F(LinearLayoutConversionsTest, FP32_32x8x2_M256_N128_K32_A) { EXPECT_EQ( subgroup2DBlockToLinearLayout( /*blockShape*/ {256, 32}, - sdb(/*instrShape*/ {32, 8}, /*numBlocks*/ 2, /*kWidth*/ 4, + sdb(/*instrShape*/ {32, 8}, /*numBlocks*/ 2, /*isTransposed*/ false, + /*kWidth*/ 4, /*warpsPerCTA*/ {8, 4}, /*repCluster*/ {4, 1}, /*blockShape*/ {256, 32}, /*opsPerChannel*/ 1, /*opIdx*/ 0), /*kWidth*/ 4), @@ -67,7 +69,8 @@ TEST_F(LinearLayoutConversionsTest, FP32_32x16x1_M256_N128_K32_B) { EXPECT_EQ( subgroup2DBlockToLinearLayout( /*blockShape*/ {32, 128}, - sdb(/*instrShape*/ {32, 16}, /*numBlocks*/ 1, /*kWidth*/ 4, + sdb(/*instrShape*/ {32, 16}, /*numBlocks*/ 1, /*isTransposed*/ false, + /*kWidth*/ 4, /*warpsPerCTA*/ {8, 4}, /*repCluster*/ {4, 1}, /*blockShape*/ {32, 128}, /*opsPerChannel*/ 1, /*opIdx*/ 1), /*kWidth*/ 4), @@ -83,7 +86,8 @@ TEST_F(LinearLayoutConversionsTest, FP16_32x32x1_M256_N32_K32_A) { EXPECT_EQ( subgroup2DBlockToLinearLayout( /*blockShape*/ {256, 32}, - sdb(/*instrShape*/ {32, 32}, /*numBlocks*/ 1, /*kWidth*/ 2, + sdb(/*instrShape*/ {32, 32}, /*numBlocks*/ 1, /*isTransposed*/ false, + /*kWidth*/ 2, /*warpsPerCTA*/ {8, 4}, /*repCluster*/ {4, 2}, /*blockShape*/ {256, 32}, /*opsPerChannel*/ 2, /*opIdx*/ 0), /*kWidth*/ 2), @@ -99,7 +103,8 @@ TEST_F(LinearLayoutConversionsTest, FP16_32x16x2_M256_N32_K32_A) { EXPECT_EQ( subgroup2DBlockToLinearLayout( /*blockShape*/ {256, 32}, - sdb(/*instrShape*/ {32, 16}, /*numBlocks*/ 2, /*kWidth*/ 2, + sdb(/*instrShape*/ {32, 16}, /*numBlocks*/ 2, /*isTransposed*/ false, + /*kWidth*/ 2, /*warpsPerCTA*/ {8, 4}, /*repCluster*/ {4, 2}, /*blockShape*/ {256, 32}, /*opsPerChannel*/ 2, /*opIdx*/ 0), /*kWidth*/ 2), @@ -114,7 +119,8 @@ TEST_F(LinearLayoutConversionsTest, FP16_32x16x2_M256_N32_K32_A) { TEST_F(LinearLayoutConversionsTest, FP16_32x16x2_M256_N32_K32_B) { EXPECT_EQ(subgroup2DBlockToLinearLayout( /*shape*/ {32, 256}, - sdb(/*instrShape*/ {32, 16}, /*numBlocks*/ 2, /*kWidth*/ 2, + sdb(/*instrShape*/ {32, 16}, /*numBlocks*/ 2, + /*isTransposed*/ false, /*kWidth*/ 2, /*warpsPerCTA*/ {8, 4}, /*repCluster*/ {4, 2}, /*blockShape*/ {32, 256}, /*opsPerChannel*/ 2, /*opIdx*/ 1), @@ -131,7 +137,8 @@ TEST_F(LinearLayoutConversionsTest, FP16_32x16x2_M256_N32_K32_B) { TEST_F(LinearLayoutConversionsTest, FP16_16x16x2_M256_N32_K32_B) { EXPECT_EQ(subgroup2DBlockToLinearLayout( /*shape*/ {32, 256}, - sdb(/*instrShape*/ {16, 16}, /*numBlocks*/ 2, /*kWidth*/ 2, + sdb(/*instrShape*/ {16, 16}, /*numBlocks*/ 2, + /*isTransposed*/ false, /*kWidth*/ 2, 
/*warpsPerCTA*/ {8, 4}, /*repCluster*/ {4, 2}, /*blockShape*/ {32, 256}, /*opsPerChannel*/ 2, /*opIdx*/ 1), @@ -145,11 +152,32 @@ TEST_F(LinearLayoutConversionsTest, FP16_16x16x2_M256_N32_K32_B) { {S("dim0"), S("dim1")})); } +TEST_F(LinearLayoutConversionsTest, FP16_32x16x1_M256_N32_K32_TRANSPOSE_B) { + // Note that the instrShape is pre-transpose + EXPECT_EQ( + subgroup2DBlockToLinearLayout( + /*shape*/ {32, 256}, + sdb(/*instrShape*/ {32, 16}, /*numBlocks*/ 1, /*isTransposed*/ true, + /*kWidth*/ 2, + /*warpsPerCTA*/ {8, 4}, /*repCluster*/ {4, 2}, + /*blockShape*/ {256, 32}, /*opsPerChannel*/ 2, + /*opIdx*/ 1), + /*kWidth*/ 2), + LinearLayout( + {{S("register"), + {{0, 1}, {1, 0}, {2, 0}, {4, 0}, {8, 0}, {16, 0}, {0, 128}}}, + {S("lane"), {{0, 2}, {0, 4}, {0, 8}, {0, 16}}}, + {S("warp"), {{0, 32}, {0, 64}, {0, 0}, {0, 0}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); +} + TEST_F(LinearLayoutConversionsTest, I8_16x32x1_M64_N128_K32_A) { EXPECT_EQ( subgroup2DBlockToLinearLayout( /*shape*/ {64, 32}, - sdb(/*instrShape*/ {16, 32}, /*numBlocks*/ 1, /*kWidth*/ 1, + sdb(/*instrShape*/ {16, 32}, /*numBlocks*/ 1, /*isTransposed*/ false, + /*kWidth*/ 1, /*warpsPerCTA*/ {4, 8}, /*repCluster*/ {2, 1}, /*blockShape*/ {64, 32}, /*opsPerChannel*/ 4, /*opIdx*/ 0), @@ -165,7 +193,8 @@ TEST_F(LinearLayoutConversionsTest, I8_32x32x1_M64_N128_K32_B) { EXPECT_EQ( subgroup2DBlockToLinearLayout( /*shape*/ {32, 128}, - sdb(/*instrShape*/ {32, 16}, /*numBlocks*/ 1, /*kWidth*/ 1, + sdb(/*instrShape*/ {32, 16}, /*numBlocks*/ 1, /*isTransposed*/ false, + /*kWidth*/ 1, /*warpsPerCTA*/ {4, 8}, /*repCluster*/ {2, 1}, /*blockShape*/ {32, 128}, /*opsPerChannel*/ 4, /*opIdx*/ 1), From 465b99c7b303dc9f16d0ea0765765fec8088d1ea Mon Sep 17 00:00:00 2001 From: Alex Baden Date: Wed, 18 Jun 2025 21:16:46 +0000 Subject: [PATCH 12/20] Compute final load shape in Subgroup2DBlockIO layout (except for oneMatrixPerBT) --- .../optimize-block-io-encoding.mlir | 6 +-- .../lib/Dialect/TritonIntelGPU/IR/Dialect.cpp | 5 +++ .../LoadStoreOpToLLVM.cpp | 39 ++++++++++++------- 3 files changed, 34 insertions(+), 16 deletions(-) diff --git a/test/TritonIntelGPU/optimize-block-io-encoding.mlir b/test/TritonIntelGPU/optimize-block-io-encoding.mlir index 0e135a956e..01f88d4f57 100644 --- a/test/TritonIntelGPU/optimize-block-io-encoding.mlir +++ b/test/TritonIntelGPU/optimize-block-io-encoding.mlir @@ -3,8 +3,8 @@ #blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 16], warpsPerCTA = [8, 4], order = [1, 0]}> #blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 4], warpsPerCTA = [32, 1], order = [1, 0]}> #blocked2 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 16], warpsPerCTA = [16, 2], order = [1, 0]}> -// CHECK: #mma = #ttig.subgroup_2d_block<{warpsPerCTA = [8, 4], instrShape = [8, 16], numBlocks = 2, isTransposed = false, order = [1, 0], kWidth = 1, threadsPerWarp = 16} -// CHECK: #mma1 = #ttig.subgroup_2d_block<{warpsPerCTA = [8, 4], instrShape = [16, 16], numBlocks = 2, isTransposed = false, order = [0, 1], kWidth = 2, threadsPerWarp = 16} +// CHECK: #mma = #ttig.subgroup_2d_block<{warpsPerCTA = [8, 4], instrShape = [32, 16], numBlocks = 2, isTransposed = false, order = [1, 0], kWidth = 1, threadsPerWarp = 16} +// CHECK: #mma1 = #ttig.subgroup_2d_block<{warpsPerCTA = [8, 4], instrShape = [32, 16], numBlocks = 2, isTransposed = false, order = [0, 1], kWidth = 2, threadsPerWarp = 16} // CHECK: #mma2 = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, 
threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}> #mma = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}> module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, ttg.target = "xpu", "ttg.threads-per-warp" = 16 : i32, ttig.min_sg_size = 16 : i32, ttig.support_bf16_conversion, ttig.support_dpas, ttig.support_sg_2d_block, ttig.target_arch = "spir64"} { @@ -138,7 +138,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, ttg.tar #blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 16], warpsPerCTA = [8, 4], order = [1, 0]}> #blocked1 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [16, 1], warpsPerCTA = [2, 16], order = [0, 1]}> #blocked2 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 16], warpsPerCTA = [16, 2], order = [1, 0]}> -// CHECK: #mma = #ttig.subgroup_2d_block<{warpsPerCTA = [8, 4], instrShape = [16, 16], numBlocks = 2, isTransposed = false, order = [0, 1], kWidth = 2, threadsPerWarp = 16}> +// CHECK: #mma = #ttig.subgroup_2d_block<{warpsPerCTA = [8, 4], instrShape = [32, 16], numBlocks = 2, isTransposed = false, order = [0, 1], kWidth = 2, threadsPerWarp = 16}> // CHECK: #mma1 = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}> // CHECK-NOT: #mma2 #mma = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}> diff --git a/third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp b/third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp index a0d5996398..8fe6d88427 100644 --- a/third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp +++ b/third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp @@ -754,6 +754,11 @@ SmallVector Subgroup2DBlockEncodingAttr::getInstrShapeForLayout( dpasOperandsPerTileY = 1; } + // PVC 2D load supports 32 rows at most. Load multiple dot operands in by + // enlarging the tileHeight. + dpasOperandsPerTileX = std::min(dpasOperandsPerTileX, 32 / tileHeight); + tileHeight = tileHeight * dpasOperandsPerTileX; + // PVC 2D load supports 64 bytes per row at most. Load multiple dot operands // by enlarging the number of blocks. const unsigned totalBytesPerRowPerDPASOp = tileWidth * kWidth; diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp index 7ca9a2822c..40d8872734 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -1437,17 +1437,25 @@ struct LoadOpConversion LLVM_DEBUG(llvm::dbgs() << "Tensor type for op " << int(opIdx) << ": " << tensorType << "\n"); - Attribute encoding = tensorType.getEncoding(); - // TODO: this gives us the linear layour corresponding - // to the subgroup 2d block encoding, not the dpas encoding... 
- std::optional llEncoding = - cast(encoding).toLinearLayout( - tensorType.getShape()); - assert(llEncoding.has_value() && "invalid dot layout to linear layout"); + auto encoding = cast(tensorType.getEncoding()); + LinearLayout llEncoding = encoding.toLinearLayout(tensorType.getShape()); LinearEncodingAttr llAttr = - LinearEncodingAttr::get(rewriter.getContext(), *llEncoding); + LinearEncodingAttr::get(rewriter.getContext(), llEncoding); SmallVector threadOrder = llAttr.getThreadOrder(); size_t rank = threadOrder.size(); + + SmallVector sizePerThread = llAttr.getSizePerThread(); + llvm::errs() << "sizePerThread:\n"; + for (auto i : sizePerThread) { + llvm::errs() << i << "\n"; + } + + SmallVector shapePerCTATile = llAttr.getShapePerCTATile(); + llvm::errs() << "shapePerCTATile:\n"; + for (auto i : shapePerCTATile) { + llvm::errs() << i << "\n"; + } + const bool valueRowMajor = (threadOrder[rank - 2] == 1 && threadOrder[rank - 1] == 0); assert((valueRowMajor || @@ -1473,6 +1481,11 @@ struct LoadOpConversion } }; auto [tileHeight, tileWidth, vBlocks] = getTileParams(); + LLVM_DEBUG({ + llvm::dbgs() << "tileHeight = " << tileHeight << "\n"; + llvm::dbgs() << "tileWidth = " << tileWidth << "\n"; + llvm::dbgs() << "vBlocks = " << vBlocks << "\n"; + }); const ArrayRef tensorShape = tensorType.getShape(); unsigned numElems = getTotalElemsPerThread(resultType); @@ -1625,6 +1638,8 @@ struct LoadOpConversion Type unpackedDPASOperandType = LLVM::getVectorType( typeConverter->convertType(eltTy), elemsPerLanePerDPASInst); + const unsigned origTileHeight = elemsPerDPASInst[threadOrder[rank - 1]]; + // By default, use the unpacked type for the 2D load result type. Type loadResultElemType = typeConverter->convertType(eltTy); bool usePackedType = false; @@ -1803,11 +1818,9 @@ struct LoadOpConversion numOperandsPer2DloadN = 1; } - // TODO: move this logic to the instr shape computation - // PVC 2D load supports 32 rows at most. Load multiple dot operands in by - // enlarging the tileHeight. - numOperandsPer2DLoadM = std::min(numOperandsPer2DLoadM, 32 / tileHeight); - tileHeight = tileHeight * numOperandsPer2DLoadM; + numOperandsPer2DLoadM = + std::min(numOperandsPer2DLoadM, 32 / origTileHeight); + // tileHeight = tileHeight * numOperandsPer2DLoadM; // PVC 2D load supports 64 bytes per row at most. Load multiple dot operands // by enlarging the vBlocks. 
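The height clamp this patch moves into `getInstrShapeForLayout` is easy to check by hand. Below is a short, self-contained C++ sketch of just that arithmetic with assumed values for the tile height and per-tile operand count (it is not the pass code): the operand count is capped at what fits within the 32-row limit of the PVC 2D block load, and the tile height is enlarged by that factor.

#include <algorithm>
#include <cstdio>

int main() {
  // Assumed inputs: an 8-row per-operand tile repeated 8 times along the
  // outer dimension.
  unsigned tileHeight = 8;
  unsigned dpasOperandsPerTileX = 8;

  // PVC 2D block loads fetch at most 32 rows, so cap the stacked operands
  // and enlarge the tile height accordingly (mirrors the hunk above).
  dpasOperandsPerTileX = std::min(dpasOperandsPerTileX, 32u / tileHeight); // 4
  tileHeight *= dpasOperandsPerTileX;                                      // 32

  std::printf("dpasOperandsPerTileX = %u, tileHeight = %u\n",
              dpasOperandsPerTileX, tileHeight);
  return 0;
}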
From e85d57781cd6df136de0870d934a23408c6b2e37 Mon Sep 17 00:00:00 2001 From: Alex Baden Date: Thu, 19 Jun 2025 00:23:06 +0000 Subject: [PATCH 13/20] remove debug code --- .../lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp index 40d8872734..64121fd572 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -1444,18 +1444,6 @@ struct LoadOpConversion SmallVector threadOrder = llAttr.getThreadOrder(); size_t rank = threadOrder.size(); - SmallVector sizePerThread = llAttr.getSizePerThread(); - llvm::errs() << "sizePerThread:\n"; - for (auto i : sizePerThread) { - llvm::errs() << i << "\n"; - } - - SmallVector shapePerCTATile = llAttr.getShapePerCTATile(); - llvm::errs() << "shapePerCTATile:\n"; - for (auto i : shapePerCTATile) { - llvm::errs() << i << "\n"; - } - const bool valueRowMajor = (threadOrder[rank - 2] == 1 && threadOrder[rank - 1] == 0); assert((valueRowMajor || From 9985b768e42f7cf8955792c75d9412a627344927 Mon Sep 17 00:00:00 2001 From: Alex Baden Date: Thu, 19 Jun 2025 13:25:45 +0000 Subject: [PATCH 14/20] Remove legacy tile layout code --- .../LoadStoreOpToLLVM.cpp | 208 +----------------- 1 file changed, 8 insertions(+), 200 deletions(-) diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp index 64121fd572..d2ebcfa686 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -1711,67 +1711,6 @@ struct LoadOpConversion LLVM_DEBUG(llvm::dbgs() << "dpasTileToPackedIndicesRatio = " << dpasTileToPackedIndicesRatio << "\n"); - // Create the linear layout for the load. - // First, we create a tile layout corresponding to a single invocation of - // the DPAS instruction across all threads/work-items in a sub-group. The - // layout will later be expanded to cover multiple DPAS invocations - // (iteration) and multiple loads (load). - StringAttr kOffset = S("offset"); - StringAttr kIteration = S("iteration"); - StringAttr kLoad = S("load"); - - auto createTileLayout = [&](const SmallVectorImpl &threadOrder, - SmallVector tileShape) { - auto outDimNames = standardOutDimNames(ctx, tensorShape.size()); - LinearLayout layout = LinearLayout::empty(); - SmallVector kOffsetDims; - unsigned totalOffsets = 1; - assert(tileShape.size() == 2); // only support 2D layouts for now - - if (isTransposeRequired && opIdx == DpasEncodingAttr::OpIdx::OperandB) { - const unsigned widthDim = threadOrder[rank - 2]; - const unsigned origTileWidth = tileShape[widthDim]; - tileShape[widthDim] = origTileWidth / (32 / elemSizeInBits); - } - - for (int i = 0; i < tileShape.size(); i++) { - int dim = threadOrder[i]; - StringAttr kOffset = S("offset" + std::to_string(dim)); - - kOffsetDims.push_back(kOffset); - - assert(llvm::isPowerOf2_32(tileShape[dim])); - // reduce the offset dimension size by the number of elements packed in - // a single slot for the row wise dimension - const unsigned offsetDimSize = - (!isTransposeRequired && dim == 0) - ? 
tileShape[dim] / dpasTileToPackedIndicesRatio - : tileShape[dim]; - layout *= - LinearLayout::identity1D(offsetDimSize, kOffset, outDimNames[dim]); - totalOffsets *= offsetDimSize; - } - SmallVector newDims; - newDims.append(kOffsetDims.begin(), kOffsetDims.end()); - auto ret = layout.transposeIns(newDims); - ret = ret.transposeOuts(outDimNames); - return ret.reshapeIns({{kOffset, totalOffsets}}); - }; - auto tileLayout = createTileLayout(threadOrder, elemsPerDPASInst); - - LLVM_DEBUG({ - llvm::dbgs() << "Block load tile layout: " << tileLayout << "\n"; - for (size_t i = 0; i < tileLayout.getOutDimSize(dimOuterStr) * - tileLayout.getOutDimSize(dimInnerStr); - i += tileLayout.getOutDimSize(S("dim1"))) { - auto tensorVals = tileLayout.apply({{kOffset, i}}); - assert(tensorVals.size() == 2); - llvm::dbgs() << i << " : " << tensorVals[0].second << ", " - << tensorVals[1].second << "\n"; - } - llvm::dbgs() << "tile layout done\n"; - }); - unsigned numOperandsOuterDimPerLoad = 1; unsigned numOperandsInnerDimPerLoad = 1; @@ -1829,33 +1768,6 @@ struct LoadOpConversion llvm::dbgs() << "vBlocks = " << vBlocks << "\n"; }); - tileLayout *= LinearLayout::identity1D(numOperandsOuterDimPerLoad, - kIteration, dimOuterStr); - tileLayout *= - LinearLayout::identity1D(isTransposeRequired && oneMatrixPerLoadForBT - ? 1 - : numOperandsInnerDimPerLoad, - kIteration, dimInnerStr); - - LLVM_DEBUG({ - llvm::dbgs() << "Block load tile layout after adding iterations: " - << tileLayout << "\n"; - - for (size_t itr = 0; itr < tileLayout.getInDimSize(kIteration); itr++) { - auto printTileLayoutVals = [&](const size_t offset) { - auto tensorVals = - tileLayout.apply({{kOffset, offset}, {kIteration, itr}}); - assert(tensorVals.size() == 2); - llvm::dbgs() << itr << ", " << offset << " : " << tensorVals[0].second - << ", " << tensorVals[1].second << "\n"; - }; - - printTileLayoutVals(0); - printTileLayoutVals(tileLayout.getInDimSize(kOffset) - 1); - } - llvm::dbgs() << "\n"; - }); - if (isTransposeRequired) std::swap(numOperandsOuterDimPerLoad, numOperandsInnerDimPerLoad); @@ -1892,87 +1804,6 @@ struct LoadOpConversion llvm::dbgs() << "numRepInner = " << numRepInner << "\n"; }); - // For the kLoad dimension we create the basis vector directly, which allows - // us to control the stride between loads and create a non-surjective - // layout. 
- auto bases = tileLayout.getBases(); - std::vector> newLoadBases; - - SmallVector> outDims; - for (auto [name, size] : - llvm::zip(tileLayout.getOutDimNames(), tileLayout.getOutDimSizes())) { - outDims.push_back(std::make_pair(name, size)); - } - assert(outDims[0].first == S("dim0")); - assert(outDims[1].first == S("dim1")); - - for (size_t i = 0; - i < llvm::Log2_32(numRepInner / numOperandsInnerDimPerLoad); i++) { - newLoadBases.push_back({0, static_cast((1 << i) * repKStride * - numOperandsInnerDimPerLoad)}); - outDims[1].second *= repKStride * numOperandsInnerDimPerLoad; - } - for (size_t i = 0; i < llvm::Log2_32(numLoadPerOutRepCluster); i++) { - newLoadBases.push_back({static_cast((1 << i) * repStride), 0}); - outDims[0].second *= repStride; - } - for (size_t i = 0; i < llvm::Log2_32(numRepOuter); i++) { - newLoadBases.push_back({static_cast((1 << i) * repOuterStride), 0}); - outDims[0].second *= repOuterStride; - } - - LLVM_DEBUG({ - llvm::dbgs() << "Created Load Bases:\n"; - for (auto &base : newLoadBases) { - assert(base.size() == 2); - llvm::dbgs() << base[0] << ", " << base[1] << "\n"; - } - }); - - LLVM_DEBUG({ - llvm::dbgs() << "New tile layout dimensions after adding load bases:\n"; - for (size_t i = 0; i < outDims.size(); i++) { - llvm::dbgs() << outDims[i].first << " = " << outDims[i].second << "\n"; - } - }); - - // Disable building the load layout if we are not going to use it. Building - // the layout manually can cause an error which would abort the pass - // pipeline and block us from getting debug info. - if (useTileLoadLinearLayout) { - // add the bases to the map and replace the tile layout with the new - // layout - bases[kLoad] = newLoadBases; - tileLayout = LinearLayout(bases, outDims, - /*requiredSurjective=*/false); - } else { - // when linear layouts are disabled generate a single load, so we can have - // some reference for linear layout output without generating a layout - // that could abort the pass pipeline - tileLayout *= LinearLayout::identity1D(1, kLoad, dimOuterStr); - } - - LLVM_DEBUG({ - llvm::dbgs() << "Block load tile layout after adding loads: " - << tileLayout << "\n"; - for (size_t load = 0; load < tileLayout.getInDimSize(kLoad); load++) { - for (size_t itr = 0; itr < tileLayout.getInDimSize(kIteration); itr++) { - auto printTileLayoutVals = [&](const size_t offset) { - auto tensorVals = tileLayout.apply( - {{kOffset, offset}, {kIteration, itr}, {kLoad, load}}); - assert(tensorVals.size() == 2); - llvm::dbgs() << load << ", " << itr << ", " << offset << " : " - << tensorVals[0].second << ", " << tensorVals[1].second - << "\n"; - }; - - printTileLayoutVals(0); - printTileLayoutVals(tileLayout.getInDimSize(kOffset) - 1); - } - llvm::dbgs() << "\n"; - } - }); - Value pitch; if (memoryRowMajor) { pitch = b.trunc(i32_ty, rowStride); @@ -2022,17 +1853,6 @@ struct LoadOpConversion k / numOperandsInnerDimPerLoad; LLVM_DEBUG(llvm::dbgs() << "loadIdx: " << loadIdx << "\n"); - const auto offset = tileLayout.apply( - {{kOffset, 0}, {kIteration, 0}, {kLoad, loadIdx}}); - assert(offset.size() == 2); - - const auto layoutOffsetX = offset[dimInner].second; - const auto layoutOffsetY = offset[dimOuter].second; - LLVM_DEBUG({ - llvm::dbgs() << "x offset ll: " << layoutOffsetX << "\n"; - llvm::dbgs() << "y offset ll: " << layoutOffsetY << "\n"; - }); - Value offsetX, offsetY; switch (opIdx) { case DpasEncodingAttr::OpIdx::OperandA: { @@ -2041,16 +1861,10 @@ struct LoadOpConversion llvm::dbgs() << "y offset: " << outer * repOuterStride + rep * repStride << "\n"; 
}); - if (useTileLoadLinearLayout) { - offsetY = b.add(b.mul(outerDimWarpId, b.i32_val(warpOuterStride)), - b.i32_val(layoutOffsetY)); - offsetX = b.i32_val(layoutOffsetX); - } else { - offsetY = - b.add(b.mul(outerDimWarpId, b.i32_val(warpOuterStride)), - b.i32_val(outer * repOuterStride + rep * repStride)); - offsetX = b.i32_val(k * repKStride); - } + offsetY = + b.add(b.mul(outerDimWarpId, b.i32_val(warpOuterStride)), + b.i32_val(outer * repOuterStride + rep * repStride)); + offsetX = b.i32_val(k * repKStride); } break; case DpasEncodingAttr::OpIdx::OperandB: { LLVM_DEBUG({ @@ -2058,16 +1872,10 @@ struct LoadOpConversion << outer * repOuterStride + rep * repStride << "\n"; llvm::dbgs() << "y offset: " << k * repKStride << "\n"; }); - if (useTileLoadLinearLayout) { - offsetX = b.add(b.mul(outerDimWarpId, b.i32_val(warpOuterStride)), - b.i32_val(layoutOffsetX)); - offsetY = b.i32_val(layoutOffsetY); - } else { - offsetX = - b.add(b.mul(outerDimWarpId, b.i32_val(warpOuterStride)), - b.i32_val(outer * repOuterStride + rep * repStride)); - offsetY = b.i32_val(k * repKStride); - } + offsetX = + b.add(b.mul(outerDimWarpId, b.i32_val(warpOuterStride)), + b.i32_val(outer * repOuterStride + rep * repStride)); + offsetY = b.i32_val(k * repKStride); } break; case DpasEncodingAttr::OpIdx::OperandC: { llvm_unreachable("unexpected OpIdx::OperandC"); From 6652cacac1c485d982857d1e4545e9bd08b1aefc Mon Sep 17 00:00:00 2001 From: Alex Baden Date: Fri, 20 Jun 2025 15:49:00 +0000 Subject: [PATCH 15/20] Separate subgroup 2d block encoding lowering from dpas block io lowering 1/? --- .../LoadStoreOpToLLVM.cpp | 238 +++++++++++++++++- 1 file changed, 236 insertions(+), 2 deletions(-) diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp index d2ebcfa686..2c6d2796d7 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -1408,6 +1408,233 @@ struct LoadOpConversion oneMatrixPerLoadForBT(oneMatrixPerLoadForBT), useTileLoadLinearLayout(useTileLoadLinearLayout) {} +LogicalResult rewriteSubgroup2DBlockEncodingLoad( + triton::LoadOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + Value ptr = op.getPtr(); + assert(isTensorPointerType(ptr.getType()) && + "Expecting tensor pointer type"); + + Type resultType = op.getType(); + auto tensorType = cast(resultType); + assert(hasSubgroup2DBlockEncoding(tensorType) && + "load op passed to subgroup 2d block encoding load codegen must " + "have subgroup 2d block encoding"); + + Location loc = op.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + Value mask = op.getMask(); + Value other = op.getOther(); + + const bool memoryRowMajor = isMemoryRowMajor(op); + + auto encoding = cast(tensorType.getEncoding()); + LinearLayout loadLayout = encoding.toLinearLayout(tensorType.getShape()); + LinearEncodingAttr llAttr = + LinearEncodingAttr::get(rewriter.getContext(), loadLayout); + SmallVector threadOrder = llAttr.getThreadOrder(); + size_t rank = threadOrder.size(); + + if (rank != 2) { + op.emitWarning( + "Subgroup 2D Block Encoding only supports rank 2 tensors."); + } + + const bool valueRowMajor = + (threadOrder[rank - 2] == 1 && threadOrder[rank - 1] == 0); + assert((valueRowMajor || + (threadOrder[rank - 2] == 0 && threadOrder[rank - 1] == 1)) && + "Only row_major or column_major is allowed"); + const bool isTransposeRequired = valueRowMajor ^ memoryRowMajor; + + if 
(isTransposeRequired) { + // hrmm... + loadLayout = loadLayout.transposeOuts(llvm::to_vector(llvm::reverse(loadLayout.getOutDimNames()))); + } + + auto instrShape = encoding.getInstrShape(); + const unsigned tileHeight = instrShape[0]; + const unsigned tileWidth = instrShape[1]; + const unsigned numBlocks = encoding.getNumBlocks(); + LLVM_DEBUG({ + llvm::dbgs() << "tileHeight = " << tileHeight << "\n"; + llvm::dbgs() << "tileWidth = " << tileWidth << "\n"; + llvm::dbgs() << "numBlocks = " << numBlocks << "\n"; + }); + + const ArrayRef tensorShape = tensorType.getShape(); + + auto warpsPerCTA = encoding.getWarpsPerCTA(); + SmallVector dpasWarpsOrder = + getMatrixOrder(warpsPerCTA.size(), /*rowMajor*/ true); + + Value warpId = rewriter.create( + loc, i32_ty, + rewriter.create(loc, /*upperBound=*/nullptr)); + SmallVector multiDimWarpId = + delinearize(rewriter, loc, warpId, warpsPerCTA, dpasWarpsOrder); + + Type eltTy = tensorType.getElementType(); + + LLVMTypeConverter *typeConverter = getTypeConverter(); + Type loadResultElemType = typeConverter->convertType(eltTy); + LLVM_DEBUG(llvm::dbgs() + << "loadResultElemType: " << loadResultElemType << "\n"); + + auto printVector = [](auto vector, auto name) { + llvm::errs() << name << "\n"; + for (auto i : vector) { + llvm::errs() << i << "\n"; + } + }; + + printVector(threadOrder, "threadOrder"); + + auto warpOrder = llAttr.getWarpOrder(); + printVector(warpOrder, "warpOrder"); + + auto shapePerCTA = getShapePerCTA(tensorType); + printVector(shapePerCTA, "shapePerCTA"); + + unsigned innerBlockSize = shapePerCTA.back(); + llvm::errs() << "innerBlockSize = " << innerBlockSize << "\n"; + unsigned contigDimSize = tileWidth * numBlocks; // true? + llvm::errs() << "contigDimSize = " << contigDimSize << "\n"; + + unsigned numMessagesPerRow = ceil(innerBlockSize, contigDimSize); + llvm::errs() << "numMessagesPerRow = " << numMessagesPerRow << "\n"; + + auto ctaSplitNum = getCTASplitNum(llAttr); + printVector(ctaSplitNum, "ctaSplitNum"); + + auto ctasPerCGA = getCTAsPerCGA(llAttr); + printVector(ctasPerCGA, "ctasPerCGA"); + + // legacy, don't use! 
+ // auto ctaShape = llAttr.getShapePerCTATile(); + // printVector(ctaShape, "shape per CTA tile"); + + MLIRContext *ctx = rewriter.getContext(); + auto kRegister = StringAttr::get(ctx, "register"); + auto kLane = StringAttr::get(ctx, "lane"); + auto kWarp = StringAttr::get(ctx, "warp"); + auto kBlock = StringAttr::get(ctx, "block"); + + auto basesPerReg = llAttr.basesPerDim(kRegister, /*skipBroadcast*/ true); + printVector(basesPerReg, "bases per reg"); + + auto basesPerRegNoBroadcast = + llAttr.basesPerDim(kRegister, /*skipBroadcast*/ false); + printVector(basesPerRegNoBroadcast, "bases per reg no broadcast"); + + auto basesPerLane = llAttr.basesPerDim(kLane); // threads per warp + printVector(basesPerLane, "bases per lane"); + + auto basesPerWarp = + llAttr.basesPerDim(kWarp, /*skipBrodcast*/ false); // warps per cta + printVector(basesPerWarp, "bases per warp"); + + auto threadsPerWarp = llAttr.getThreadsPerWarp(); + printVector(threadsPerWarp, "threads per warp"); + + // auto warpsPerCTA = llAttr.getWarpsPerCTA(); + printVector(warpsPerCTA, "warps per cta"); + + auto contigPerThread = llAttr.getContigPerThread(); + printVector(contigPerThread, "contigPerThread"); + + auto contigPerWarp = llAttr.getContigPerWarp(); + printVector(contigPerWarp, "contigPerWarp"); + + // unsigned vec = getVectorSize(ptr); + // auto vecTensor = getVectorSize(tensorType); + // llvm::errs() << "vec tensor: " << vecTensor << "\n"; + unsigned elemsPerThread = getTotalElemsPerThread(ptr.getType()); + llvm::errs() << "elemsPerThread: " << elemsPerThread << "\n"; + unsigned elemsPerThreadTensorTy = getTotalElemsPerThread(tensorType); + llvm::errs() << "elemsPerThread tensor type: " << elemsPerThreadTensorTy + << "\n"; + + unsigned vec = tileHeight * numBlocks; + llvm::errs() << "vec: " << vec << "\n"; + unsigned numValuesPerLoad = vec / encoding.getKWidth(); + if (encoding.getKWidth() == 2) + loadResultElemType = i32_ty; // HACK + + Type load2DGenXType = + LLVM::getVectorType(loadResultElemType, numValuesPerLoad); + // Note: these end up being a mix of float/int vector types... + llvm::errs() << "load2DGenXType = " << load2DGenXType << "\n"; + + auto [base, baseWidth, baseHeight, rowStride, colStride, offsetBaseX, + offsetBaseY] = + getValuesFromBlockPointerStruct(adaptor.getPtr(), rewriter); + + Value pitch; + if (memoryRowMajor) { + pitch = b.trunc(i32_ty, rowStride); + } else { + // Column major memory. We need to swap the width and height because HW + // only support row major memory layout. + pitch = b.trunc(i32_ty, colStride); + std::swap(baseWidth, baseHeight); + } + baseWidth = b.trunc(i32_ty, baseWidth); + baseHeight = b.trunc(i32_ty, baseHeight); + + unsigned elemSizeInBits = eltTy.getIntOrFloatBitWidth(); + Value elemSizeInBytes = b.i32_val(elemSizeInBits / 8); + + // Dispatch the load instructions from the perspective of a single lane. + unsigned numElems = elemsPerThreadTensorTy; + for (size_t vecStart = 0; vecStart < numElems; vecStart += vec) { + llvm::errs() << "dispatch load " << vecStart << "\n"; + + auto offset = loadLayout.apply({{kRegister, vecStart}, {kLane, 0}, {kWarp, 0}, {kBlock, 0}}); + assert(offset.size() == 2); + LLVM_DEBUG({ + llvm::dbgs() << "offset: " << offset[0].second << ", " << offset[1].second << "\n"; + }); + + // Lane ID doesn't matter. 
+ Value zero = b.i32_val(0); + Value regIdVal = b.i32_val(vecStart); + auto offsetValues = applyLinearLayout(loc, rewriter, loadLayout, {{kRegister, regIdVal}, {kLane, zero}, {kWarp, warpId}, {kBlock, zero}}); + assert(offsetValues.size() == 2); + + Value offsetX = b.add(offsetValues[0].second, offsetBaseX); + Value offsetY = b.add(offsetValues[1].second, offsetBaseY); + +#if 0 + auto load2dOp = rewriter.create( + loc, load2DGenXType, + /*ptr*/ base, + /*base_width*/ b.mul(baseWidth, elemSizeInBytes), + /*base_height*/ baseHeight, + /*base_pitch*/ b.mul(pitch, elemSizeInBytes), + /*x*/ b.trunc(i32_ty, offsetX), + /*y*/ b.trunc(i32_ty, offsetY), + /*elem_size_in_bits*/ elemSizeInBits, + /*tile_width*/ tileWidth, + /*tile_height*/ tileHeight, + /*v_blocks*/ vBlocks, + /*transpose*/ isTransposeRequired, + /*vnni_transform*/ + (usePackedType && !isOperandA && !isTransposeRequired && + originalElemBits != 32)); + if (failed(load2dOp.verify())) { + // delete the op so that the verifier will not abort the pass + // pipeline later, as we can fail this path and try a different + // approach. + rewriter.eraseOp(load2dOp); + return failure(); + } +#endif + } + + return failure(); + } + LogicalResult rewriteTensorPointerLoad(triton::LoadOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { @@ -1418,12 +1645,19 @@ struct LoadOpConversion if (!isLoadCandidate(op)) return failure(); + Type resultType = op.getType(); + auto tensorType = cast(resultType); + if (hasSubgroup2DBlockEncoding(tensorType)) { + auto ret = rewriteSubgroup2DBlockEncodingLoad(op, adaptor, rewriter); + // little hack to make incremental progress + if (!ret.failed()) + return ret; + } + Location loc = op.getLoc(); auto b = TritonLLVMOpBuilder(loc, rewriter); Value mask = op.getMask(); Value other = op.getOther(); - Type resultType = op.getType(); - auto tensorType = cast(resultType); const bool memoryRowMajor = isMemoryRowMajor(op); From 13789340b0c810fc024aee9d72e0ac62e11aded0 Mon Sep 17 00:00:00 2001 From: Alex Baden Date: Wed, 2 Jul 2025 20:12:30 +0000 Subject: [PATCH 16/20] checkpoint: a matrix working --- .../LoadStoreOpToLLVM.cpp | 176 ++++++++++++++---- 1 file changed, 144 insertions(+), 32 deletions(-) diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp index 2c6d2796d7..7f9d41de58 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -23,6 +23,20 @@ using namespace mlir::triton::gpu::intel; namespace { +Value llPrintf(StringRef msg, ValueRange args, ArrayRef isSigned, + ConversionPatternRewriter &rewriter, const mlir::triton::intel::TargetInfo& targetInfo) { + assert(!msg.empty() && "printf with empty string not supported"); + llvm::SmallString<64> msgNewline(msg); + msgNewline.push_back('\n'); + msgNewline.push_back('\0'); + Value msgValue = targetInfo.getGlobalStringStart( + rewriter.getUnknownLoc(), rewriter, "printfFormat_", msgNewline, + /*addressSpace=*/TritonGEN::kUniformConstant); + targetInfo.printf(rewriter, msgValue, msgNewline.size_in_bytes(), args, + isSigned); + return msgValue; +} + Value maybeAnd(RewriterBase &rewriter, Location loc, Value a, Value b) { auto tb = TritonLLVMOpBuilder(loc, rewriter); if (a && b) { @@ -1421,6 +1435,8 @@ LogicalResult rewriteSubgroup2DBlockEncodingLoad( "load op passed to subgroup 2d block encoding load codegen must " "have subgroup 2d block encoding"); + 
LLVM_DEBUG(llvm::dbgs() << "Lowering load op with Subgroup 2D Block Encoding: " << op << "\n"); + Location loc = op.getLoc(); auto b = TritonLLVMOpBuilder(loc, rewriter); Value mask = op.getMask(); @@ -1430,6 +1446,7 @@ LogicalResult rewriteSubgroup2DBlockEncodingLoad( auto encoding = cast(tensorType.getEncoding()); LinearLayout loadLayout = encoding.toLinearLayout(tensorType.getShape()); + llvm::errs() << "loadLayout: " << loadLayout << "\n"; LinearEncodingAttr llAttr = LinearEncodingAttr::get(rewriter.getContext(), loadLayout); SmallVector threadOrder = llAttr.getThreadOrder(); @@ -1438,18 +1455,13 @@ LogicalResult rewriteSubgroup2DBlockEncodingLoad( if (rank != 2) { op.emitWarning( "Subgroup 2D Block Encoding only supports rank 2 tensors."); + return failure(); } - const bool valueRowMajor = - (threadOrder[rank - 2] == 1 && threadOrder[rank - 1] == 0); - assert((valueRowMajor || - (threadOrder[rank - 2] == 0 && threadOrder[rank - 1] == 1)) && - "Only row_major or column_major is allowed"); - const bool isTransposeRequired = valueRowMajor ^ memoryRowMajor; - + const bool isTransposeRequired = encoding.getIsTransposed(); if (isTransposeRequired) { // hrmm... - loadLayout = loadLayout.transposeOuts(llvm::to_vector(llvm::reverse(loadLayout.getOutDimNames()))); + // loadLayout = loadLayout.transposeOuts(llvm::to_vector(llvm::reverse(loadLayout.getOutDimNames()))); } auto instrShape = encoding.getInstrShape(); @@ -1465,21 +1477,52 @@ LogicalResult rewriteSubgroup2DBlockEncodingLoad( const ArrayRef tensorShape = tensorType.getShape(); auto warpsPerCTA = encoding.getWarpsPerCTA(); + LLVM_DEBUG({ + llvm::dbgs() << "warpsPerCTA: " << warpsPerCTA[0] << ", " + << warpsPerCTA[1] << "\n"; + }); SmallVector dpasWarpsOrder = getMatrixOrder(warpsPerCTA.size(), /*rowMajor*/ true); + const unsigned threadsPerWarp = encoding.getThreadsPerWarp(); + LLVM_DEBUG(llvm::dbgs() << "threadsPerWarp = " << threadsPerWarp << "\n"); +#if 0 + Type offsetType = i32_ty; // getTypeConverter()->getIndexType(); + llvm::errs() << "offsetType = " << offsetType << "\n"; + Value subGroupId = getValueOrCreateCastToIndexLike( + rewriter, loc, offsetType, + rewriter.create( + loc, /*upper_bound=*/IntegerAttr{})); + Value sgStride = rewriter.create( + loc, offsetType, threadsPerWarp); + Value outerDimWarpId = subGroupId; // b.mul(sgStride, subGroupId); +#else + unsigned dimOuter = 0; // TODO + Value warpId = rewriter.create( loc, i32_ty, rewriter.create(loc, /*upperBound=*/nullptr)); SmallVector multiDimWarpId = delinearize(rewriter, loc, warpId, warpsPerCTA, dpasWarpsOrder); + // unsigned outerDimRequiredWarpNum = mlir::ceil( + // tensorShape[dimOuter], warpShape[dimOuter]); // ceil of ratio + unsigned outerDimWarpNum = warpsPerCTA[dimOuter]; //, outerDimRequiredWarpNum); + LLVM_DEBUG(llvm::dbgs() + << "outerDimWarpNum = " << outerDimWarpNum << "\n"); +#if 1 + Value outerDimWarpId = + b.urem(multiDimWarpId[dimOuter], b.i32_val(outerDimWarpNum)); +#else + Value outerDimWarpId = multiDimWarpId[dimOuter]; +#endif +#endif - Type eltTy = tensorType.getElementType(); + Type eltTy = tensorType.getElementType(); LLVMTypeConverter *typeConverter = getTypeConverter(); - Type loadResultElemType = typeConverter->convertType(eltTy); + Type valueElemTy = typeConverter->convertType(eltTy); LLVM_DEBUG(llvm::dbgs() - << "loadResultElemType: " << loadResultElemType << "\n"); + << "valueElemTy: " << valueElemTy << "\n"); auto printVector = [](auto vector, auto name) { llvm::errs() << name << "\n"; @@ -1534,8 +1577,8 @@ LogicalResult 
rewriteSubgroup2DBlockEncodingLoad( llAttr.basesPerDim(kWarp, /*skipBrodcast*/ false); // warps per cta printVector(basesPerWarp, "bases per warp"); - auto threadsPerWarp = llAttr.getThreadsPerWarp(); - printVector(threadsPerWarp, "threads per warp"); + auto threadsPerWarp2 = llAttr.getThreadsPerWarp(); + printVector(threadsPerWarp2, "threads per warp"); // auto warpsPerCTA = llAttr.getWarpsPerCTA(); printVector(warpsPerCTA, "warps per cta"); @@ -1555,14 +1598,25 @@ LogicalResult rewriteSubgroup2DBlockEncodingLoad( llvm::errs() << "elemsPerThread tensor type: " << elemsPerThreadTensorTy << "\n"; + auto width = encoding.getKWidth(); + unsigned vec = tileHeight * numBlocks; llvm::errs() << "vec: " << vec << "\n"; - unsigned numValuesPerLoad = vec / encoding.getKWidth(); - if (encoding.getKWidth() == 2) - loadResultElemType = i32_ty; // HACK + unsigned numValuesPerLoad = vec / width; + + unsigned elemSizeInBits = eltTy.getIntOrFloatBitWidth(); + Value elemSizeInBytes = b.i32_val(elemSizeInBits / 8); + llvm::errs() << "elemSizeInBits: " << elemSizeInBits << "\n"; + llvm::errs() << "elemSizeInBytes: " << elemSizeInBytes << "\n"; + + Type packedElemTy = valueElemTy; + if (width == 2) { + packedElemTy = i32_ty; // HACK + // elemSizeInBits = 32; + } Type load2DGenXType = - LLVM::getVectorType(loadResultElemType, numValuesPerLoad); + LLVM::getVectorType(packedElemTy, numValuesPerLoad); // Note: these end up being a mix of float/int vector types... llvm::errs() << "load2DGenXType = " << load2DGenXType << "\n"; @@ -1582,11 +1636,22 @@ LogicalResult rewriteSubgroup2DBlockEncodingLoad( baseWidth = b.trunc(i32_ty, baseWidth); baseHeight = b.trunc(i32_ty, baseHeight); - unsigned elemSizeInBits = eltTy.getIntOrFloatBitWidth(); - Value elemSizeInBytes = b.i32_val(elemSizeInBits / 8); - // Dispatch the load instructions from the perspective of a single lane. unsigned numElems = elemsPerThreadTensorTy; + llvm::errs() << "numElems = " << numElems << "\n"; + + Value zero = b.i32_val(0); + auto baseOffsetForWarp = applyLinearLayout(loc, rewriter, loadLayout, {{kRegister, zero}, {kLane, zero}, {kWarp, warpId}, {kBlock, zero}}); + assert(baseOffsetForWarp.size() == 2); + + // probably always 0? just for fun look at warp 1 + auto baseOffset = loadLayout.apply({{kRegister, 0}, {kLane, 0}, {kWarp, 0}, {kBlock, 0}}); + assert(baseOffset.size() == 2); + LLVM_DEBUG({ + llvm::dbgs() << "base offset: " << baseOffset[0].second << ", " << baseOffset[1].second << "\n"; + }); + + SmallVector loadedVals; for (size_t vecStart = 0; vecStart < numElems; vecStart += vec) { llvm::errs() << "dispatch load " << vecStart << "\n"; @@ -1594,18 +1659,35 @@ LogicalResult rewriteSubgroup2DBlockEncodingLoad( assert(offset.size() == 2); LLVM_DEBUG({ llvm::dbgs() << "offset: " << offset[0].second << ", " << offset[1].second << "\n"; + llvm::dbgs() << "offset - base offset: " << offset[0].second - baseOffset[0].second << ", " << offset[1].second - baseOffset[1].second << "\n"; }); // Lane ID doesn't matter. 
- Value zero = b.i32_val(0); Value regIdVal = b.i32_val(vecStart); auto offsetValues = applyLinearLayout(loc, rewriter, loadLayout, {{kRegister, regIdVal}, {kLane, zero}, {kWarp, warpId}, {kBlock, zero}}); assert(offsetValues.size() == 2); - Value offsetX = b.add(offsetValues[0].second, offsetBaseX); - Value offsetY = b.add(offsetValues[1].second, offsetBaseY); - #if 0 + Value offsetX = b.add(b.mul(outerDimWarpId, b.i32_val(/*warpOuterStride*/4)), offsetValues[0].second); + Value offsetY = offsetValues[1].second; +#else + Value offsetX = b.add(b.mul(outerDimWarpId, b.i32_val(/*warpOuterStride*/32)), b.sub(offsetValues[0].second, baseOffsetForWarp[0].second)); + Value offsetY = b.sub(offsetValues[1].second, baseOffsetForWarp[1].second); +#endif + + // llPrintf("warp: %d, x: %ld (base %ld), y: %ld (base %ld)\n", {outerDimWarpId, offsetX, offsetBaseX, offsetY, offsetBaseY}, {true, true, true, true, true}, rewriter, targetInfo); +#if 1 + std::swap(offsetX, offsetY); +#else + if (warpOrder[0]) { + llvm::errs() << "swap x, y\n"; + std::swap(offsetX, offsetY); + } +#endif + + offsetX = b.add(offsetX, offsetBaseX); + offsetY = b.add(offsetY, offsetBaseY); + auto load2dOp = rewriter.create( loc, load2DGenXType, /*ptr*/ base, @@ -1617,22 +1699,45 @@ LogicalResult rewriteSubgroup2DBlockEncodingLoad( /*elem_size_in_bits*/ elemSizeInBits, /*tile_width*/ tileWidth, /*tile_height*/ tileHeight, - /*v_blocks*/ vBlocks, + /*v_blocks*/ numBlocks, /*transpose*/ isTransposeRequired, - /*vnni_transform*/ - (usePackedType && !isOperandA && !isTransposeRequired && - originalElemBits != 32)); + /*vnni_transform*/width > 1 && !isTransposeRequired); + llvm::errs() << "Generated load2dOp: " << load2dOp << "\n"; if (failed(load2dOp.verify())) { // delete the op so that the verifier will not abort the pass // pipeline later, as we can fail this path and try a different // approach. 
+ assert(false); rewriter.eraseOp(load2dOp); return failure(); } -#endif - } + + + llvm::errs() << "llvm type: " << load2DGenXType << "\n"; + Type llvmResultStructTy = typeConverter->convertType(op.getType()); + + // Extract and store return values + Value load2dVec = b.bitcast(load2dOp, LLVM::getVectorType(valueElemTy, vec)); + llvm::errs() << "bitcasted load vec: " << load2dVec << "\n"; + llvm::errs() << "vec size: " << vec << "\n"; + for (size_t i = 0; i < vec; i++) { + Value vecIdx = createIndexAttrConstant( + rewriter, loc, typeConverter->getIndexType(), i); + Value loaded = b.extract_element(valueElemTy, load2dVec, vecIdx); + loadedVals.push_back(loaded); + } + - return failure(); + } // end vec + + llvm::errs() << "opType: " << op.getType() << "\n"; + Type llvmResultStructTy = typeConverter->convertType(op.getType()); + llvm::errs() << "result struct type: " << llvmResultStructTy << "\n"; + llvm::errs() << "number of load vals: " << loadedVals.size() << "\n"; + Value resultStruct = packLLElements(loc, typeConverter, loadedVals, + rewriter, llvmResultStructTy); + rewriter.replaceOp(op, {resultStruct}); + return success(); } LogicalResult @@ -1647,12 +1752,19 @@ LogicalResult rewriteSubgroup2DBlockEncodingLoad( Type resultType = op.getType(); auto tensorType = cast(resultType); +#if 1 if (hasSubgroup2DBlockEncoding(tensorType)) { auto ret = rewriteSubgroup2DBlockEncodingLoad(op, adaptor, rewriter); +#if 1 + assert(!ret.failed()); + return ret; +#else // little hack to make incremental progress if (!ret.failed()) return ret; +#endif } +#endif Location loc = op.getLoc(); auto b = TritonLLVMOpBuilder(loc, rewriter); @@ -1921,7 +2033,7 @@ LogicalResult rewriteSubgroup2DBlockEncodingLoad( unsigned outerDimWarpNum = std::min(warpsPerCTA[dimOuter], outerDimRequiredWarpNum); LLVM_DEBUG(llvm::dbgs() - << "outerDimWarpNum = " << outerDimRequiredWarpNum << "\n"); + << "outerDimWarpNum = " << outerDimWarpNum << "\n"); Value outerDimWarpId = b.urem(multiDimWarpId[dimOuter], b.i32_val(outerDimWarpNum)); From 6a3ec802e6feaa4a470d50dd9a4caeb18ff37639 Mon Sep 17 00:00:00 2001 From: Alex Baden Date: Wed, 2 Jul 2025 20:41:09 +0000 Subject: [PATCH 17/20] b working --- .../LoadStoreOpToLLVM.cpp | 25 +++++++++++-------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp index 7f9d41de58..2b6d236904 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -1486,6 +1486,8 @@ LogicalResult rewriteSubgroup2DBlockEncodingLoad( const unsigned threadsPerWarp = encoding.getThreadsPerWarp(); LLVM_DEBUG(llvm::dbgs() << "threadsPerWarp = " << threadsPerWarp << "\n"); + auto warpOrder = llAttr.getWarpOrder(); + #if 0 Type offsetType = i32_ty; // getTypeConverter()->getIndexType(); llvm::errs() << "offsetType = " << offsetType << "\n"; @@ -1497,7 +1499,8 @@ LogicalResult rewriteSubgroup2DBlockEncodingLoad( loc, offsetType, threadsPerWarp); Value outerDimWarpId = subGroupId; // b.mul(sgStride, subGroupId); #else - unsigned dimOuter = 0; // TODO + + unsigned dimOuter = warpOrder[0]; // TODO Value warpId = rewriter.create( loc, i32_ty, @@ -1533,7 +1536,6 @@ LogicalResult rewriteSubgroup2DBlockEncodingLoad( printVector(threadOrder, "threadOrder"); - auto warpOrder = llAttr.getWarpOrder(); printVector(warpOrder, "warpOrder"); auto shapePerCTA = getShapePerCTA(tensorType); @@ -1671,20 
+1673,21 @@ LogicalResult rewriteSubgroup2DBlockEncodingLoad( Value offsetX = b.add(b.mul(outerDimWarpId, b.i32_val(/*warpOuterStride*/4)), offsetValues[0].second); Value offsetY = offsetValues[1].second; #else - Value offsetX = b.add(b.mul(outerDimWarpId, b.i32_val(/*warpOuterStride*/32)), b.sub(offsetValues[0].second, baseOffsetForWarp[0].second)); + Value offsetX = b.sub(offsetValues[0].second, baseOffsetForWarp[0].second); Value offsetY = b.sub(offsetValues[1].second, baseOffsetForWarp[1].second); -#endif - - // llPrintf("warp: %d, x: %ld (base %ld), y: %ld (base %ld)\n", {outerDimWarpId, offsetX, offsetBaseX, offsetY, offsetBaseY}, {true, true, true, true, true}, rewriter, targetInfo); -#if 1 - std::swap(offsetX, offsetY); -#else if (warpOrder[0]) { - llvm::errs() << "swap x, y\n"; - std::swap(offsetX, offsetY); + llvm::errs() << "adding to Y offset\n"; + // b matrix + offsetY = b.add(b.mul(outerDimWarpId, b.i32_val(/*warpOuterStride*/32)), offsetY); + } else { + llvm::errs() << "adding to X offset\n"; + // a matrix + offsetX = b.add(b.mul(outerDimWarpId, b.i32_val(/*warpOuterStride*/32)), offsetX); } #endif + std::swap(offsetX, offsetY); // TODO: remove? + offsetX = b.add(offsetX, offsetBaseX); offsetY = b.add(offsetY, offsetBaseY); From b129f14cfaeec83244b45c9c6348cfdfc99f6f33 Mon Sep 17 00:00:00 2001 From: Alex Baden Date: Wed, 2 Jul 2025 20:41:41 +0000 Subject: [PATCH 18/20] remove the convert operator delete --- .../TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp index 173cbdb29a..fe4a04f172 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -28,14 +28,6 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion RankedTensorType srcTy = op.getSrc().getType(); auto dstTy = op.getType(); - if (auto dstTensorTy = cast(dstTy)) { - if (intel::isBlockIONoOpConversion(srcTy, dstTensorTy)) { - // TODO: replace this with proper conversion once conversion is removed - // from LoadStoreOpToLLVM. - rewriter.replaceOp(op, op.getSrc()); - return success(); - } - } LinearLayout conversion = minimalCvtLayout(srcTy, dstTy); LinearLayout srcLayout = @@ -48,6 +40,9 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion assert(to_vector(conversion.getInDimNames()) == to_vector(conversion.getOutDimNames())); auto dims = conversion.getInDimNames(); + llvm::errs() << "dims for conversion: \n"; + for (auto dim : dims) + llvm::errs() << dim << "\n"; if (llvm::is_contained(dims, kLane)) { // If the operation is a supported sub-group shuffle, perform via shuffle // operations. 
From e436f30a68bd8976dcc0e3bc96c6ce7a57053f6a Mon Sep 17 00:00:00 2001 From: Alex Baden Date: Wed, 2 Jul 2025 20:42:25 +0000 Subject: [PATCH 19/20] add a basic unit test --- python/test/unit/intel/test_block_load.py | 65 +++++++++++++++++++++++ 1 file changed, 65 insertions(+) diff --git a/python/test/unit/intel/test_block_load.py b/python/test/unit/intel/test_block_load.py index 45ce9e5d8c..40e53f7a9b 100644 --- a/python/test/unit/intel/test_block_load.py +++ b/python/test/unit/intel/test_block_load.py @@ -7,6 +7,71 @@ import triton.language as tl from triton._internal_testing import is_xpu +def test_block_load_subgroup_layout(device, tmp_path: pathlib.Path): + M = 256 + N = 32 + A_width = 1 + B_width = 2 + transpose = False + ty = "f16" + block_io = "row_major" + dtype_str = "float16" + + layouts = f""" + #dpas = #ttig.dpas<{{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2]}}> + #mma = #ttig.subgroup_2d_block<{{warpsPerCTA = [8, 4], instrShape = [32, 16], numBlocks = 2, isTransposed = false, order = [1, 0], kWidth = 1, threadsPerWarp = 16}}> + #mma1 = #ttig.subgroup_2d_block<{{warpsPerCTA = [8, 4], instrShape = [32, 16], numBlocks = 2, isTransposed = false, order = [0, 1], kWidth = 2, threadsPerWarp = 16}}> + """ + + ir = layouts + f""" + module attributes {{ttig.min_sg_size = 16 : i32, ttig.support_bf16_conversion, ttig.support_dpas, ttig.support_sg_2d_block, ttig.target_arch = "spir64", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, ttg.target = "xpu", "ttg.threads-per-warp" = 16 : i32}} {{ + tt.func public @block_load_dpas_layout(%arg0: !tt.ptr<{ty}> {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr<{ty}> {{tt.divisibility = 16 : i32}}, %arg2: !tt.ptr<{ty}> {{tt.divisibility = 16: i32}}, %arg3: !tt.ptr<{ty}> {{tt.divisibility = 16: i32}}) attributes {{noinline = false}} {{ + %0 = tt.get_program_id x : i32 + %M_i64 = arith.constant {M} : i64 + %N_i64 = arith.constant {N} : i64 + %c1_i64 = arith.constant 1 : i64 + %c0_i32 = arith.constant 0 : i32 + + // A matrix + %1 = tt.make_tensor_ptr %arg0, [%M_i64, %N_i64], [%N_i64, %c1_i64], [%0, %c0_i32] {{order = array}} : > + %2 = tt.load %1 {{boundaryCheck = array, ttig.block_io = "row_major"}} : !tt.ptr> + %20 = ttg.convert_layout %2 : tensor<{M}x{N}x{ty}, #mma> -> tensor<{M}x{N}x{ty}, #ttg.dot_op<{{opIdx = 0, parent = #dpas, kWidth = {A_width}}}>> + %3 = tt.make_tensor_ptr %arg1, [%M_i64, %N_i64], [%N_i64, %c1_i64], [%0, %c0_i32] {{order = array}} : >> + tt.store %3, %20 {{boundaryCheck = array}} : !tt.ptr>> + + // B matrix + %4 = tt.make_tensor_ptr %arg2, [%N_i64, %M_i64], {"[%c1_i64, %N_i64]" if transpose else "[%M_i64, %c1_i64]"}, [%c0_i32, %0] {{order = array}} : > + %5 = tt.load %4 {{boundaryCheck = array, ttig.block_io = "{block_io}" }} : !tt.ptr> + %50 = ttg.convert_layout %5 : tensor<{N}x{M}x{ty}, #mma1> -> tensor<{N}x{M}x{ty}, #ttg.dot_op<{{opIdx = 1, parent = #dpas, kWidth = {B_width}}}>> + %6 = tt.make_tensor_ptr %arg3, [%N_i64, %M_i64], {"[%c1_i64, %N_i64]" if transpose else "[%M_i64, %c1_i64]"}, [%c0_i32, %0] {{order = array}} : >> + tt.store %6, %50 {{boundaryCheck = array}} : !tt.ptr>> + + tt.return + }} + }} + """ + + torch_dtype = getattr(torch, dtype_str) + if torch_dtype.is_floating_point: + a = torch.arange(0, M * N, dtype=torch_dtype, device=device).reshape((M, N)) + b = torch.arange(0, M * N, dtype=torch_dtype, device=device).reshape((N, M)) + else: + a = torch.randint(low=-127, high=128, size=(M, N), dtype=torch_dtype, 
device=device) + b = torch.randint(low=-127, high=128, size=(N, M), dtype=torch_dtype, device=device) + + x = torch.empty_like(a) + y = torch.empty_like(b.T if transpose else b) + + temp_file = tmp_path / "test_block_load_dpas_layout.ttgir" + temp_file.write_text(ir) + kernel = triton.compile(str(temp_file)) + + kernel[(1, 1, 1)](a, x, b, y) + + print(a.int()) + print(x.int()) + assert torch.equal(a, x) + assert torch.equal(b.T if transpose else b, y) @pytest.mark.parametrize("M, N", [[256, 64], [256, 32], [128, 32], [128, 16], [128, 8], [64, 64], [64, 32], [32, 32], [16, 64]]) From 41198c960fdda358003e09509dd157882dc1ab48 Mon Sep 17 00:00:00 2001 From: Alex Baden Date: Wed, 2 Jul 2025 20:51:43 +0000 Subject: [PATCH 20/20] remove debug code and format --- python/test/unit/intel/test_block_load.py | 14 +- .../ConvertLayoutOpToLLVM.cpp | 1 - .../LoadStoreOpToLLVM.cpp | 176 ++++++++---------- 3 files changed, 88 insertions(+), 103 deletions(-) diff --git a/python/test/unit/intel/test_block_load.py b/python/test/unit/intel/test_block_load.py index 40e53f7a9b..59e7ecf7ff 100644 --- a/python/test/unit/intel/test_block_load.py +++ b/python/test/unit/intel/test_block_load.py @@ -7,6 +7,7 @@ import triton.language as tl from triton._internal_testing import is_xpu + def test_block_load_subgroup_layout(device, tmp_path: pathlib.Path): M = 256 N = 32 @@ -17,10 +18,10 @@ def test_block_load_subgroup_layout(device, tmp_path: pathlib.Path): block_io = "row_major" dtype_str = "float16" - layouts = f""" - #dpas = #ttig.dpas<{{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2]}}> - #mma = #ttig.subgroup_2d_block<{{warpsPerCTA = [8, 4], instrShape = [32, 16], numBlocks = 2, isTransposed = false, order = [1, 0], kWidth = 1, threadsPerWarp = 16}}> - #mma1 = #ttig.subgroup_2d_block<{{warpsPerCTA = [8, 4], instrShape = [32, 16], numBlocks = 2, isTransposed = false, order = [0, 1], kWidth = 2, threadsPerWarp = 16}}> + layouts = """ + #dpas = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2]}> + #mma = #ttig.subgroup_2d_block<{warpsPerCTA = [8, 4], instrShape = [32, 16], numBlocks = 2, isTransposed = false, order = [1, 0], kWidth = 1, threadsPerWarp = 16}> + #mma1 = #ttig.subgroup_2d_block<{warpsPerCTA = [8, 4], instrShape = [32, 16], numBlocks = 2, isTransposed = false, order = [0, 1], kWidth = 2, threadsPerWarp = 16}> """ ir = layouts + f""" @@ -67,12 +68,13 @@ def test_block_load_subgroup_layout(device, tmp_path: pathlib.Path): kernel = triton.compile(str(temp_file)) kernel[(1, 1, 1)](a, x, b, y) - + print(a.int()) print(x.int()) - assert torch.equal(a, x) + assert torch.equal(a, x) assert torch.equal(b.T if transpose else b, y) + @pytest.mark.parametrize("M, N", [[256, 64], [256, 32], [128, 32], [128, 16], [128, 8], [64, 64], [64, 32], [32, 32], [16, 64]]) @pytest.mark.parametrize("dtype_str", ["float32", "float16", "int8"]) diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp index fe4a04f172..5e37c47c07 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -28,7 +28,6 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion RankedTensorType srcTy = op.getSrc().getType(); auto dstTy = op.getType(); - LinearLayout conversion = 
minimalCvtLayout(srcTy, dstTy); LinearLayout srcLayout = toLinearLayout(srcTy.getShape(), srcTy.getEncoding()); diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp index 2b6d236904..b055530f18 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -24,7 +24,8 @@ using namespace mlir::triton::gpu::intel; namespace { Value llPrintf(StringRef msg, ValueRange args, ArrayRef isSigned, - ConversionPatternRewriter &rewriter, const mlir::triton::intel::TargetInfo& targetInfo) { + ConversionPatternRewriter &rewriter, + const mlir::triton::intel::TargetInfo &targetInfo) { assert(!msg.empty() && "printf with empty string not supported"); llvm::SmallString<64> msgNewline(msg); msgNewline.push_back('\n'); @@ -1422,7 +1423,7 @@ struct LoadOpConversion oneMatrixPerLoadForBT(oneMatrixPerLoadForBT), useTileLoadLinearLayout(useTileLoadLinearLayout) {} -LogicalResult rewriteSubgroup2DBlockEncodingLoad( + LogicalResult rewriteSubgroup2DBlockEncodingLoad( triton::LoadOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Value ptr = op.getPtr(); @@ -1435,7 +1436,9 @@ LogicalResult rewriteSubgroup2DBlockEncodingLoad( "load op passed to subgroup 2d block encoding load codegen must " "have subgroup 2d block encoding"); - LLVM_DEBUG(llvm::dbgs() << "Lowering load op with Subgroup 2D Block Encoding: " << op << "\n"); + LLVM_DEBUG(llvm::dbgs() + << "Lowering load op with Subgroup 2D Block Encoding: " << op + << "\n"); Location loc = op.getLoc(); auto b = TritonLLVMOpBuilder(loc, rewriter); @@ -1461,7 +1464,8 @@ LogicalResult rewriteSubgroup2DBlockEncodingLoad( const bool isTransposeRequired = encoding.getIsTransposed(); if (isTransposeRequired) { // hrmm... 
- // loadLayout = loadLayout.transposeOuts(llvm::to_vector(llvm::reverse(loadLayout.getOutDimNames()))); + // loadLayout = + // loadLayout.transposeOuts(llvm::to_vector(llvm::reverse(loadLayout.getOutDimNames()))); } auto instrShape = encoding.getInstrShape(); @@ -1477,10 +1481,10 @@ LogicalResult rewriteSubgroup2DBlockEncodingLoad( const ArrayRef tensorShape = tensorType.getShape(); auto warpsPerCTA = encoding.getWarpsPerCTA(); - LLVM_DEBUG({ + LLVM_DEBUG({ llvm::dbgs() << "warpsPerCTA: " << warpsPerCTA[0] << ", " << warpsPerCTA[1] << "\n"; - }); + }); SmallVector dpasWarpsOrder = getMatrixOrder(warpsPerCTA.size(), /*rowMajor*/ true); @@ -1488,44 +1492,24 @@ LogicalResult rewriteSubgroup2DBlockEncodingLoad( LLVM_DEBUG(llvm::dbgs() << "threadsPerWarp = " << threadsPerWarp << "\n"); auto warpOrder = llAttr.getWarpOrder(); -#if 0 - Type offsetType = i32_ty; // getTypeConverter()->getIndexType(); - llvm::errs() << "offsetType = " << offsetType << "\n"; - Value subGroupId = getValueOrCreateCastToIndexLike( - rewriter, loc, offsetType, - rewriter.create( - loc, /*upper_bound=*/IntegerAttr{})); - Value sgStride = rewriter.create( - loc, offsetType, threadsPerWarp); - Value outerDimWarpId = subGroupId; // b.mul(sgStride, subGroupId); -#else - - unsigned dimOuter = warpOrder[0]; // TODO + unsigned dimOuter = warpOrder[0]; // TODO Value warpId = rewriter.create( loc, i32_ty, rewriter.create(loc, /*upperBound=*/nullptr)); SmallVector multiDimWarpId = delinearize(rewriter, loc, warpId, warpsPerCTA, dpasWarpsOrder); - // unsigned outerDimRequiredWarpNum = mlir::ceil( - // tensorShape[dimOuter], warpShape[dimOuter]); // ceil of ratio - unsigned outerDimWarpNum = warpsPerCTA[dimOuter]; //, outerDimRequiredWarpNum); - LLVM_DEBUG(llvm::dbgs() - << "outerDimWarpNum = " << outerDimWarpNum << "\n"); -#if 1 + unsigned outerDimWarpNum = + warpsPerCTA[dimOuter]; //, outerDimRequiredWarpNum); + LLVM_DEBUG(llvm::dbgs() << "outerDimWarpNum = " << outerDimWarpNum << "\n"); Value outerDimWarpId = b.urem(multiDimWarpId[dimOuter], b.i32_val(outerDimWarpNum)); -#else - Value outerDimWarpId = multiDimWarpId[dimOuter]; -#endif -#endif Type eltTy = tensorType.getElementType(); LLVMTypeConverter *typeConverter = getTypeConverter(); Type valueElemTy = typeConverter->convertType(eltTy); - LLVM_DEBUG(llvm::dbgs() - << "valueElemTy: " << valueElemTy << "\n"); + LLVM_DEBUG(llvm::dbgs() << "valueElemTy: " << valueElemTy << "\n"); auto printVector = [](auto vector, auto name) { llvm::errs() << name << "\n"; @@ -1617,9 +1601,8 @@ LogicalResult rewriteSubgroup2DBlockEncodingLoad( // elemSizeInBits = 32; } - Type load2DGenXType = - LLVM::getVectorType(packedElemTy, numValuesPerLoad); - // Note: these end up being a mix of float/int vector types... + Type load2DGenXType = LLVM::getVectorType(packedElemTy, numValuesPerLoad); + // Note: these end up being a mix of float/int vector types... llvm::errs() << "load2DGenXType = " << load2DGenXType << "\n"; auto [base, baseWidth, baseHeight, rowStride, colStride, offsetBaseX, @@ -1641,104 +1624,114 @@ LogicalResult rewriteSubgroup2DBlockEncodingLoad( // Dispatch the load instructions from the perspective of a single lane. 
unsigned numElems = elemsPerThreadTensorTy; llvm::errs() << "numElems = " << numElems << "\n"; - + Value zero = b.i32_val(0); - auto baseOffsetForWarp = applyLinearLayout(loc, rewriter, loadLayout, {{kRegister, zero}, {kLane, zero}, {kWarp, warpId}, {kBlock, zero}}); + auto baseOffsetForWarp = applyLinearLayout( + loc, rewriter, loadLayout, + {{kRegister, zero}, {kLane, zero}, {kWarp, warpId}, {kBlock, zero}}); assert(baseOffsetForWarp.size() == 2); // probably always 0? just for fun look at warp 1 - auto baseOffset = loadLayout.apply({{kRegister, 0}, {kLane, 0}, {kWarp, 0}, {kBlock, 0}}); + auto baseOffset = + loadLayout.apply({{kRegister, 0}, {kLane, 0}, {kWarp, 0}, {kBlock, 0}}); assert(baseOffset.size() == 2); LLVM_DEBUG({ - llvm::dbgs() << "base offset: " << baseOffset[0].second << ", " << baseOffset[1].second << "\n"; + llvm::dbgs() << "base offset: " << baseOffset[0].second << ", " + << baseOffset[1].second << "\n"; }); SmallVector loadedVals; for (size_t vecStart = 0; vecStart < numElems; vecStart += vec) { llvm::errs() << "dispatch load " << vecStart << "\n"; - auto offset = loadLayout.apply({{kRegister, vecStart}, {kLane, 0}, {kWarp, 0}, {kBlock, 0}}); + auto offset = loadLayout.apply( + {{kRegister, vecStart}, {kLane, 0}, {kWarp, 0}, {kBlock, 0}}); assert(offset.size() == 2); LLVM_DEBUG({ - llvm::dbgs() << "offset: " << offset[0].second << ", " << offset[1].second << "\n"; - llvm::dbgs() << "offset - base offset: " << offset[0].second - baseOffset[0].second << ", " << offset[1].second - baseOffset[1].second << "\n"; + llvm::dbgs() << "offset: " << offset[0].second << ", " + << offset[1].second << "\n"; + llvm::dbgs() << "offset - base offset: " + << offset[0].second - baseOffset[0].second << ", " + << offset[1].second - baseOffset[1].second << "\n"; }); // Lane ID doesn't matter. Value regIdVal = b.i32_val(vecStart); - auto offsetValues = applyLinearLayout(loc, rewriter, loadLayout, {{kRegister, regIdVal}, {kLane, zero}, {kWarp, warpId}, {kBlock, zero}}); + auto offsetValues = applyLinearLayout(loc, rewriter, loadLayout, + {{kRegister, regIdVal}, + {kLane, zero}, + {kWarp, warpId}, + {kBlock, zero}}); assert(offsetValues.size() == 2); -#if 0 - Value offsetX = b.add(b.mul(outerDimWarpId, b.i32_val(/*warpOuterStride*/4)), offsetValues[0].second); - Value offsetY = offsetValues[1].second; -#else - Value offsetX = b.sub(offsetValues[0].second, baseOffsetForWarp[0].second); - Value offsetY = b.sub(offsetValues[1].second, baseOffsetForWarp[1].second); + Value offsetX = + b.sub(offsetValues[0].second, baseOffsetForWarp[0].second); + Value offsetY = + b.sub(offsetValues[1].second, baseOffsetForWarp[1].second); if (warpOrder[0]) { llvm::errs() << "adding to Y offset\n"; - // b matrix - offsetY = b.add(b.mul(outerDimWarpId, b.i32_val(/*warpOuterStride*/32)), offsetY); + // b matrix + offsetY = b.add( + b.mul(outerDimWarpId, b.i32_val(/*warpOuterStride*/ 32)), offsetY); } else { llvm::errs() << "adding to X offset\n"; - // a matrix - offsetX = b.add(b.mul(outerDimWarpId, b.i32_val(/*warpOuterStride*/32)), offsetX); + // a matrix + offsetX = b.add( + b.mul(outerDimWarpId, b.i32_val(/*warpOuterStride*/ 32)), offsetX); } -#endif - std::swap(offsetX, offsetY); // TODO: remove? + std::swap(offsetX, offsetY); // TODO: remove? 
- offsetX = b.add(offsetX, offsetBaseX); + offsetX = b.add(offsetX, offsetBaseX); offsetY = b.add(offsetY, offsetBaseY); auto load2dOp = rewriter.create( - loc, load2DGenXType, - /*ptr*/ base, - /*base_width*/ b.mul(baseWidth, elemSizeInBytes), - /*base_height*/ baseHeight, - /*base_pitch*/ b.mul(pitch, elemSizeInBytes), - /*x*/ b.trunc(i32_ty, offsetX), - /*y*/ b.trunc(i32_ty, offsetY), - /*elem_size_in_bits*/ elemSizeInBits, - /*tile_width*/ tileWidth, - /*tile_height*/ tileHeight, - /*v_blocks*/ numBlocks, - /*transpose*/ isTransposeRequired, - /*vnni_transform*/width > 1 && !isTransposeRequired); - llvm::errs() << "Generated load2dOp: " << load2dOp << "\n"; - if (failed(load2dOp.verify())) { - // delete the op so that the verifier will not abort the pass - // pipeline later, as we can fail this path and try a different - // approach. - assert(false); - rewriter.eraseOp(load2dOp); - return failure(); - } - - + loc, load2DGenXType, + /*ptr*/ base, + /*base_width*/ b.mul(baseWidth, elemSizeInBytes), + /*base_height*/ baseHeight, + /*base_pitch*/ b.mul(pitch, elemSizeInBytes), + /*x*/ b.trunc(i32_ty, offsetX), + /*y*/ b.trunc(i32_ty, offsetY), + /*elem_size_in_bits*/ elemSizeInBits, + /*tile_width*/ tileWidth, + /*tile_height*/ tileHeight, + /*v_blocks*/ numBlocks, + /*transpose*/ isTransposeRequired, + /*vnni_transform*/ width > 1 && !isTransposeRequired); + llvm::errs() << "Generated load2dOp: " << load2dOp << "\n"; + if (failed(load2dOp.verify())) { + // delete the op so that the verifier will not abort the pass + // pipeline later, as we can fail this path and try a different + // approach. + assert(false); + rewriter.eraseOp(load2dOp); + return failure(); + } + llvm::errs() << "llvm type: " << load2DGenXType << "\n"; - Type llvmResultStructTy = typeConverter->convertType(op.getType()); - + Type llvmResultStructTy = typeConverter->convertType(op.getType()); + // Extract and store return values - Value load2dVec = b.bitcast(load2dOp, LLVM::getVectorType(valueElemTy, vec)); + Value load2dVec = + b.bitcast(load2dOp, LLVM::getVectorType(valueElemTy, vec)); llvm::errs() << "bitcasted load vec: " << load2dVec << "\n"; llvm::errs() << "vec size: " << vec << "\n"; for (size_t i = 0; i < vec; i++) { - Value vecIdx = createIndexAttrConstant( + Value vecIdx = createIndexAttrConstant( rewriter, loc, typeConverter->getIndexType(), i); Value loaded = b.extract_element(valueElemTy, load2dVec, vecIdx); loadedVals.push_back(loaded); } - } // end vec - + llvm::errs() << "opType: " << op.getType() << "\n"; Type llvmResultStructTy = typeConverter->convertType(op.getType()); llvm::errs() << "result struct type: " << llvmResultStructTy << "\n"; llvm::errs() << "number of load vals: " << loadedVals.size() << "\n"; Value resultStruct = packLLElements(loc, typeConverter, loadedVals, - rewriter, llvmResultStructTy); + rewriter, llvmResultStructTy); rewriter.replaceOp(op, {resultStruct}); return success(); } @@ -1755,19 +1748,11 @@ LogicalResult rewriteSubgroup2DBlockEncodingLoad( Type resultType = op.getType(); auto tensorType = cast(resultType); -#if 1 if (hasSubgroup2DBlockEncoding(tensorType)) { auto ret = rewriteSubgroup2DBlockEncodingLoad(op, adaptor, rewriter); -#if 1 assert(!ret.failed()); return ret; -#else - // little hack to make incremental progress - if (!ret.failed()) - return ret; -#endif } -#endif Location loc = op.getLoc(); auto b = TritonLLVMOpBuilder(loc, rewriter); @@ -2035,8 +2020,7 @@ LogicalResult rewriteSubgroup2DBlockEncodingLoad( << outerDimRequiredWarpNum << "\n"); unsigned 
outerDimWarpNum = std::min(warpsPerCTA[dimOuter], outerDimRequiredWarpNum); - LLVM_DEBUG(llvm::dbgs() - << "outerDimWarpNum = " << outerDimWarpNum << "\n"); + LLVM_DEBUG(llvm::dbgs() << "outerDimWarpNum = " << outerDimWarpNum << "\n"); Value outerDimWarpId = b.urem(multiDimWarpId[dimOuter], b.i32_val(outerDimWarpNum));