Skip to content

Commit 3102d07

Browse files
committed
[intel] improve pitch and width constexpr folding
1 parent 6f1525f commit 3102d07

File tree

2 files changed

+50
-14
lines changed

2 files changed

+50
-14
lines changed

python/test/unit/language/test_block_pointer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,8 +69,8 @@ def test_block_copy(dtypes_str, n, padding_option, boundary_check, device):
6969
def matmul_no_scf_with_advance_kernel( #
7070
a_ptr, b_ptr, c_ptr, #
7171
M, N, K, #
72-
stride_am, stride_ak, #
73-
stride_bk, stride_bn, #
72+
stride_am: tl.constexpr, stride_ak: tl.constexpr, #
73+
stride_bk: tl.constexpr, stride_bn: tl.constexpr, #
7474
stride_cm, stride_cn, #
7575
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr #
7676
):

third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 48 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include "intel/include/Dialect/TritonIntelGPU/Transforms/Utility.h"
1717
#include "intel/include/Utils/Utility.h"
1818
#include "triton/Tools/LinearLayout.h"
19+
1920
#include <optional>
2021
#include <triton/Tools/Sys/GetEnv.hpp>
2122

@@ -39,6 +40,43 @@ static int __builtin_ctz(unsigned x) {
3940

4041
namespace {
4142

43+
static Value skipCasts(Value v) {
44+
Operation *def = v.getDefiningOp();
45+
if (def &&
46+
isa<LLVM::TruncOp, LLVM::SExtOp, LLVM::ZExtOp, LLVM::BitcastOp>(def))
47+
return def->getOperand(0);
48+
return v;
49+
}
50+
51+
static Value tryFoldOp(Value v) {
52+
Operation *def = v.getDefiningOp();
53+
if (def) {
54+
SmallVector<OpFoldResult> results;
55+
if (succeeded(def->fold(results)) && results.size() == 1) {
56+
if (auto val = dyn_cast_or_null<Value>(results[0]))
57+
return val;
58+
}
59+
}
60+
return v;
61+
}
62+
63+
static std::optional<int64_t> tryConstEval(Value v, int depth = 16) {
64+
for (int i = 0; i < depth; ++i) {
65+
if (auto res = getConstantIntValue(v))
66+
return res;
67+
68+
Value newV = skipCasts(v);
69+
newV = tryFoldOp(newV);
70+
71+
if (newV == v)
72+
break;
73+
74+
v = newV;
75+
}
76+
77+
return std::nullopt;
78+
}
79+
4280
Value maybeAnd(RewriterBase &rewriter, Location loc, Value a, Value b) {
4381
auto tb = TritonLLVMOpBuilder(loc, rewriter);
4482
if (a && b) {
@@ -1031,6 +1069,7 @@ struct LoadOpToBlockIOConversion
10311069
LogicalResult
10321070
rewriteTensorPointerLoad(triton::LoadOp op, OpAdaptor adaptor,
10331071
ConversionPatternRewriter &rewriter) const {
1072+
10341073
// FIXME: Remove once IGC can split large 2D block loads.
10351074
std::optional<bool> oneMatrixPerLoadForBT =
10361075
mlir::triton::tools::isEnvValueBool(mlir::triton::tools::getStrEnv(
@@ -1589,24 +1628,21 @@ struct LoadOpToBlockIOConversion
15891628
pitch = b.trunc(i32_ty, colStride);
15901629
std::swap(baseWidth, baseHeight);
15911630
}
1631+
15921632
// HW requires the pitch to be at least 64 bytes.
1593-
std::function<Value(Value)> skipTrunc = [&](Value v) {
1594-
if (dyn_cast_or_null<LLVM::TruncOp>(v.getDefiningOp()))
1595-
return skipTrunc(v.getDefiningOp()->getOperand(0));
1596-
return v;
1597-
};
1598-
if (Operation *op = skipTrunc(pitch).getDefiningOp()) {
1599-
std::optional<int64_t> pitchConst =
1600-
mlir::triton::intel::getFoldedConstantValue(op);
1601-
if (pitchConst.has_value()) {
1602-
if ((*pitchConst * elemSizeInBits / 8) < 64)
1603-
return failure();
1604-
}
1633+
if (auto pitchConst = tryConstEval(pitch)) {
1634+
if ((*pitchConst * elemSizeInBits / 8) < 64)
1635+
return failure();
16051636
}
16061637

16071638
baseWidth = b.trunc(i32_ty, baseWidth);
16081639
baseHeight = b.trunc(i32_ty, baseHeight);
16091640

1641+
if (auto widthConst = tryConstEval(baseWidth)) {
1642+
if ((*widthConst * elemSizeInBits / 8) < 64)
1643+
return failure();
1644+
}
1645+
16101646
const unsigned originalElemBits = elemSizeInBits;
16111647
if (isTransposeRequired) {
16121648
// adjust the block io parameter to align HW's limitations on

0 commit comments

Comments
 (0)