Skip to content

Commit 0806403

Browse files
committed
[intel] improve pitch and width constexpr folding
1 parent 6f1525f commit 0806403

File tree

2 files changed

+47
-14
lines changed

2 files changed

+47
-14
lines changed

python/test/unit/language/test_block_pointer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,8 +69,8 @@ def test_block_copy(dtypes_str, n, padding_option, boundary_check, device):
6969
def matmul_no_scf_with_advance_kernel( #
7070
a_ptr, b_ptr, c_ptr, #
7171
M, N, K, #
72-
stride_am, stride_ak, #
73-
stride_bk, stride_bn, #
72+
stride_am: tl.constexpr, stride_ak: tl.constexpr, #
73+
stride_bk: tl.constexpr, stride_bn: tl.constexpr, #
7474
stride_cm, stride_cn, #
7575
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr #
7676
):

third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 45 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,43 @@ static int __builtin_ctz(unsigned x) {
3939

4040
namespace {
4141

42+
static Value skipCasts(Value v) {
43+
Operation *def = v.getDefiningOp();
44+
if (def &&
45+
isa<LLVM::TruncOp, LLVM::SExtOp, LLVM::ZExtOp, LLVM::BitcastOp>(def))
46+
return def->getOperand(0);
47+
return v;
48+
}
49+
50+
static Value tryFoldOp(Value v) {
51+
Operation *def = v.getDefiningOp();
52+
if (def) {
53+
SmallVector<OpFoldResult> results;
54+
if (succeeded(def->fold(results)) && results.size() == 1) {
55+
if (auto val = dyn_cast_or_null<Value>(results[0]))
56+
return val;
57+
}
58+
}
59+
return v;
60+
}
61+
62+
static std::optional<int64_t> tryConstEval(Value v, int depth = 16) {
63+
for (int i = 0; i < depth; ++i) {
64+
if (auto res = getConstantIntValue(v))
65+
return res;
66+
67+
Value newV = skipCasts(v);
68+
newV = tryFoldOp(newV);
69+
70+
if (newV == v)
71+
break;
72+
73+
v = newV;
74+
}
75+
76+
return std::nullopt;
77+
}
78+
4279
Value maybeAnd(RewriterBase &rewriter, Location loc, Value a, Value b) {
4380
auto tb = TritonLLVMOpBuilder(loc, rewriter);
4481
if (a && b) {
@@ -1590,23 +1627,19 @@ struct LoadOpToBlockIOConversion
15901627
std::swap(baseWidth, baseHeight);
15911628
}
15921629
// HW requires the pitch to be at least 64 bytes.
1593-
std::function<Value(Value)> skipTrunc = [&](Value v) {
1594-
if (dyn_cast_or_null<LLVM::TruncOp>(v.getDefiningOp()))
1595-
return skipTrunc(v.getDefiningOp()->getOperand(0));
1596-
return v;
1597-
};
1598-
if (Operation *op = skipTrunc(pitch).getDefiningOp()) {
1599-
std::optional<int64_t> pitchConst =
1600-
mlir::triton::intel::getFoldedConstantValue(op);
1601-
if (pitchConst.has_value()) {
1602-
if ((*pitchConst * elemSizeInBits / 8) < 64)
1603-
return failure();
1604-
}
1630+
if (auto pitchConst = tryConstEval(pitch)) {
1631+
if ((*pitchConst * elemSizeInBits / 8) < 64)
1632+
return failure();
16051633
}
16061634

16071635
baseWidth = b.trunc(i32_ty, baseWidth);
16081636
baseHeight = b.trunc(i32_ty, baseHeight);
16091637

1638+
if (auto widthConst = tryConstEval(baseWidth)) {
1639+
if ((*widthConst * elemSizeInBits / 8) < 64)
1640+
return failure();
1641+
}
1642+
16101643
const unsigned originalElemBits = elemSizeInBits;
16111644
if (isTransposeRequired) {
16121645
// adjust the block io parameter to align HW's limitations on

0 commit comments

Comments
 (0)