1616#include " intel/include/Dialect/TritonIntelGPU/Transforms/Utility.h"
1717#include " intel/include/Utils/Utility.h"
1818#include " triton/Tools/LinearLayout.h"
19+
1920#include < optional>
2021#include < triton/Tools/Sys/GetEnv.hpp>
2122
@@ -39,6 +40,43 @@ static int __builtin_ctz(unsigned x) {
3940
4041namespace {
4142
43+ static Value skipCasts (Value v) {
44+ Operation *def = v.getDefiningOp ();
45+ if (def &&
46+ isa<LLVM::TruncOp, LLVM::SExtOp, LLVM::ZExtOp, LLVM::BitcastOp>(def))
47+ return def->getOperand (0 );
48+ return v;
49+ }
50+
51+ static Value tryFoldOp (Value v) {
52+ Operation *def = v.getDefiningOp ();
53+ if (def) {
54+ SmallVector<OpFoldResult> results;
55+ if (succeeded (def->fold (results)) && results.size () == 1 ) {
56+ if (auto val = dyn_cast_or_null<Value>(results[0 ]))
57+ return val;
58+ }
59+ }
60+ return v;
61+ }
62+
63+ static std::optional<int64_t > tryConstEval (Value v, int depth = 16 ) {
64+ for (int i = 0 ; i < depth; ++i) {
65+ if (auto res = getConstantIntValue (v))
66+ return res;
67+
68+ Value newV = skipCasts (v);
69+ newV = tryFoldOp (newV);
70+
71+ if (newV == v)
72+ break ;
73+
74+ v = newV;
75+ }
76+
77+ return std::nullopt ;
78+ }
79+
4280Value maybeAnd (RewriterBase &rewriter, Location loc, Value a, Value b) {
4381 auto tb = TritonLLVMOpBuilder (loc, rewriter);
4482 if (a && b) {
@@ -1031,6 +1069,7 @@ struct LoadOpToBlockIOConversion
10311069 LogicalResult
10321070 rewriteTensorPointerLoad (triton::LoadOp op, OpAdaptor adaptor,
10331071 ConversionPatternRewriter &rewriter) const {
1072+
10341073 // FIXME: Remove once IGC can split large 2D block loads.
10351074 std::optional<bool > oneMatrixPerLoadForBT =
10361075 mlir::triton::tools::isEnvValueBool (mlir::triton::tools::getStrEnv (
@@ -1589,24 +1628,21 @@ struct LoadOpToBlockIOConversion
15891628 pitch = b.trunc (i32_ty, colStride);
15901629 std::swap (baseWidth, baseHeight);
15911630 }
1631+
15921632 // HW requires the pitch to be at least 64 bytes.
1593- std::function<Value (Value)> skipTrunc = [&](Value v) {
1594- if (dyn_cast_or_null<LLVM::TruncOp>(v.getDefiningOp ()))
1595- return skipTrunc (v.getDefiningOp ()->getOperand (0 ));
1596- return v;
1597- };
1598- if (Operation *op = skipTrunc (pitch).getDefiningOp ()) {
1599- std::optional<int64_t > pitchConst =
1600- mlir::triton::intel::getFoldedConstantValue (op);
1601- if (pitchConst.has_value ()) {
1602- if ((*pitchConst * elemSizeInBits / 8 ) < 64 )
1603- return failure ();
1604- }
1633+ if (auto pitchConst = tryConstEval (pitch)) {
1634+ if ((*pitchConst * elemSizeInBits / 8 ) < 64 )
1635+ return failure ();
16051636 }
16061637
16071638 baseWidth = b.trunc (i32_ty, baseWidth);
16081639 baseHeight = b.trunc (i32_ty, baseHeight);
16091640
1641+ if (auto widthConst = tryConstEval (baseWidth)) {
1642+ if ((*widthConst * elemSizeInBits / 8 ) < 64 )
1643+ return failure ();
1644+ }
1645+
16101646 const unsigned originalElemBits = elemSizeInBits;
16111647 if (isTransposeRequired) {
16121648 // adjust the block io parameter to align HW's limitations on
0 commit comments