@@ -39,6 +39,43 @@ static int __builtin_ctz(unsigned x) {
3939
4040namespace {
4141
42+ static Value skipCasts (Value v) {
43+ Operation *def = v.getDefiningOp ();
44+ if (def &&
45+ isa<LLVM::TruncOp, LLVM::SExtOp, LLVM::ZExtOp, LLVM::BitcastOp>(def))
46+ return def->getOperand (0 );
47+ return v;
48+ }
49+
50+ static Value tryFoldOp (Value v) {
51+ Operation *def = v.getDefiningOp ();
52+ if (def) {
53+ SmallVector<OpFoldResult> results;
54+ if (succeeded (def->fold (results)) && results.size () == 1 ) {
55+ if (auto val = dyn_cast_or_null<Value>(results[0 ]))
56+ return val;
57+ }
58+ }
59+ return v;
60+ }
61+
62+ static std::optional<int64_t > tryConstEval (Value v, int depth = 16 ) {
63+ for (int i = 0 ; i < depth; ++i) {
64+ if (auto res = getConstantIntValue (v))
65+ return res;
66+
67+ Value newV = skipCasts (v);
68+ newV = tryFoldOp (newV);
69+
70+ if (newV == v)
71+ break ;
72+
73+ v = newV;
74+ }
75+
76+ return std::nullopt ;
77+ }
78+
4279Value maybeAnd (RewriterBase &rewriter, Location loc, Value a, Value b) {
4380 auto tb = TritonLLVMOpBuilder (loc, rewriter);
4481 if (a && b) {
@@ -1590,23 +1627,19 @@ struct LoadOpToBlockIOConversion
15901627 std::swap (baseWidth, baseHeight);
15911628 }
15921629 // HW requires the pitch to be at least 64 bytes.
1593- std::function<Value (Value)> skipTrunc = [&](Value v) {
1594- if (dyn_cast_or_null<LLVM::TruncOp>(v.getDefiningOp ()))
1595- return skipTrunc (v.getDefiningOp ()->getOperand (0 ));
1596- return v;
1597- };
1598- if (Operation *op = skipTrunc (pitch).getDefiningOp ()) {
1599- std::optional<int64_t > pitchConst =
1600- mlir::triton::intel::getFoldedConstantValue (op);
1601- if (pitchConst.has_value ()) {
1602- if ((*pitchConst * elemSizeInBits / 8 ) < 64 )
1603- return failure ();
1604- }
1630+ if (auto pitchConst = tryConstEval (pitch)) {
1631+ if ((*pitchConst * elemSizeInBits / 8 ) < 64 )
1632+ return failure ();
16051633 }
16061634
16071635 baseWidth = b.trunc (i32_ty, baseWidth);
16081636 baseHeight = b.trunc (i32_ty, baseHeight);
16091637
1638+ if (auto widthConst = tryConstEval (baseWidth)) {
1639+ if ((*widthConst * elemSizeInBits / 8 ) < 64 )
1640+ return failure ();
1641+ }
1642+
16101643 const unsigned originalElemBits = elemSizeInBits;
16111644 if (isTransposeRequired) {
16121645 // adjust the block io parameter to align HW's limitations on
0 commit comments