diff --git a/test/TritonIntelGPU/blockptr_load.mlir b/test/TritonIntelGPU/blockptr_load.mlir index b625c66211..e2e4ac14f9 100644 --- a/test/TritonIntelGPU/blockptr_load.mlir +++ b/test/TritonIntelGPU/blockptr_load.mlir @@ -136,7 +136,7 @@ module attributes {"ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 16 : i32, // CHECK: %[[OFFSET_0:.*]] = llvm.extractvalue %[[BLOCK_POINTER]][0] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)> // CHECK: %[[OFFSET_1:.*]] = llvm.extractvalue %[[BLOCK_POINTER]][1] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)> // CHECK: %[[WIDTH_i64:.*]] = llvm.extractvalue %[[BLOCK_POINTER]][2] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)> - // CHECK: %[[HEIGHT_i64:.*]] = llvm.extractvalue %[[BLOCK_POINTER]][3] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)> + // CHECK: %[[HEIGHT_i64:.*]] = llvm.extractvalue %[[VAL_11]][3] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)> // CHECK: %[[ROW_STRIDE_i64:.*]] = llvm.extractvalue %[[VAL_12]][4] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)> // CHECK: %[[COL_STRIDE_i64:.*]] = llvm.extractvalue %[[BLOCK_POINTER]][5] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)> // CHECK: %[[BASE:.*]] = llvm.extractvalue %[[BLOCK_POINTER]][6] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)> @@ -199,7 +199,7 @@ module attributes {"ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 16 : i32, // CHECK: %[[OFFSET_0:.*]] = llvm.extractvalue %[[BLOCK_POINTER]][0] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)> // CHECK: %[[OFFSET_1:.*]] = llvm.extractvalue %[[BLOCK_POINTER]][1] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)> // CHECK: %[[WIDTH_i64:.*]] = llvm.extractvalue %[[BLOCK_POINTER]][2] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)> - // CHECK: %[[HEIGHT_i64:.*]] = llvm.extractvalue %[[BLOCK_POINTER]][3] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)> + // CHECK: %[[HEIGHT_i64:.*]] = llvm.extractvalue %[[VAL_10]][3] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)> // CHECK: %[[ROW_STRIDE_i64:.*]] = llvm.extractvalue %[[VAL_11]][4] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)> // CHECK: %[[COL_STRIDE_i64:.*]] = llvm.extractvalue %[[BLOCK_POINTER]][5] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)> // CHECK: %[[BASE:.*]] = llvm.extractvalue %[[BLOCK_POINTER]][6] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)> diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp index feab50f3fa..b094460599 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -39,6 +39,42 @@ static int __builtin_ctz(unsigned x) { namespace { +static Value skipCasts(Value v) { + Operation *def = v.getDefiningOp(); + if (def && + isa(def)) + return def->getOperand(0); + return v; +} + +static Value tryFoldOp(Value v) { + if (Operation *def = v.getDefiningOp()) { + SmallVector results; + if (succeeded(def->fold(results)) && results.size() == 1) { + if (auto val = dyn_cast_or_null(results[0])) + return val; + } + } + return v; +} + +static std::optional tryConstEval(Value v, int depth = 16) { + for (int i = 0; i < depth; ++i) { + if (auto res = getConstantIntValue(v)) + return res; + + Value newV = skipCasts(v); + newV = tryFoldOp(newV); + + if (newV == v) + break; + + v = newV; + } + + return std::nullopt; +} + Value maybeAnd(RewriterBase &rewriter, Location loc, Value a, Value b) { auto tb = TritonLLVMOpBuilder(loc, rewriter); if (a && b) { @@ -1590,23 +1626,19 @@ struct LoadOpToBlockIOConversion std::swap(baseWidth, baseHeight); } // HW requires the pitch to be at least 64 bytes. - std::function skipTrunc = [&](Value v) { - if (dyn_cast_or_null(v.getDefiningOp())) - return skipTrunc(v.getDefiningOp()->getOperand(0)); - return v; - }; - if (Operation *op = skipTrunc(pitch).getDefiningOp()) { - std::optional pitchConst = - mlir::triton::intel::getFoldedConstantValue(op); - if (pitchConst.has_value()) { - if ((*pitchConst * elemSizeInBits / 8) < 64) - return failure(); - } + if (auto pitchConst = tryConstEval(pitch)) { + if ((*pitchConst * elemSizeInBits / 8) < 64) + return failure(); } baseWidth = b.trunc(i32_ty, baseWidth); baseHeight = b.trunc(i32_ty, baseHeight); + if (auto widthConst = tryConstEval(baseWidth)) { + if ((*widthConst * elemSizeInBits / 8) < 64) + return failure(); + } + const unsigned originalElemBits = elemSizeInBits; if (isTransposeRequired) { // adjust the block io parameter to align HW's limitations on