Skip to content

Commit b50872a

Browse files
authored
[Gluon] Disable constant CSE before auto layout propagation (#8323)
Fixes #8229 Background is that we run the gluon inliner prior to auto layout propagation to enable returning auto layout from a function and having different calls of the function resolve to different layouts. However, the inliner calls gluon canonicalize and the `GreedyPatternRewriter` defaults to CSEing constants. This means that two distinct constants which could otherwise resolve to different layouts may be CSEd into a single constant and create a new conflict. I fix this by changing the inliner to do even less canonicalization, and only simplify control flow operations. I then add a canoncalization pass after auto layout resolution to make up for this.
1 parent f452d9d commit b50872a

File tree

8 files changed

+89
-3
lines changed

8 files changed

+89
-3
lines changed

include/triton/Dialect/Gluon/Transforms/Passes.td

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,4 +35,14 @@ def GluonInline: Pass<"gluon-inline"> {
3535
let dependentDialects = [];
3636
}
3737

38+
def GluonSimplifyControlFlow: Pass<"gluon-slimplify-control-flow"> {
39+
let summary = "simplications for control flow ops";
40+
41+
let description = [{
42+
The `gluon-inline` pass applies a reduced set of simplification
43+
and canonicalization patterns to the module.
44+
}];
45+
let dependentDialects = [];
46+
}
47+
3848
#endif

lib/Dialect/Gluon/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ add_triton_library(GluonTransforms
22
Canonicalize.cpp
33
Inline.cpp
44
ResolveAutoEncodings.cpp
5+
SimplifyControlFlow.cpp
56

67
DEPENDS
78
GluonTransformsIncGen

lib/Dialect/Gluon/Transforms/Inline.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ struct Inline : public gluon::impl::GluonInlineBase<Inline> {
2222
void Inline::runOnOperation() {
2323
mlir::PassManager pm(&getContext());
2424
pm.addPass(createInlinerPass(/*opPipelines=*/{}, [](OpPassManager &pm) {
25-
pm.addPass(gluon::createGluonCanonicalize());
25+
pm.addPass(gluon::createGluonSimplifyControlFlow());
2626
}));
2727
if (failed(pm.run(getOperation())))
2828
return signalPassFailure();
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
#include "mlir/IR/OperationSupport.h"
2+
#include "triton/Dialect/Gluon/Transforms/Passes.h"
3+
4+
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
5+
6+
#include "mlir/Dialect/Arith/IR/Arith.h"
7+
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
8+
#include "mlir/Dialect/SCF/IR/SCF.h"
9+
#include "mlir/Pass/Pass.h"
10+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
11+
12+
using namespace mlir;
13+
using namespace triton;
14+
15+
namespace mlir::triton::gluon {
16+
#define GEN_PASS_DEF_GLUONSIMPLIFYCONTROLFLOW
17+
#include "triton/Dialect/Gluon/Transforms/Passes.h.inc"
18+
} // namespace mlir::triton::gluon
19+
20+
namespace {
21+
struct SimplifyControlFlow
22+
: public gluon::impl::GluonSimplifyControlFlowBase<SimplifyControlFlow> {
23+
void runOnOperation() override;
24+
};
25+
} // namespace
26+
27+
void SimplifyControlFlow::runOnOperation() {
28+
MLIRContext *ctx = &getContext();
29+
RewritePatternSet patterns(&getContext());
30+
31+
// Populate `scf` and `cf` canonicalizers.
32+
ctx->getLoadedDialect<scf::SCFDialect>()->getCanonicalizationPatterns(
33+
patterns);
34+
ctx->getLoadedDialect<cf::ControlFlowDialect>()->getCanonicalizationPatterns(
35+
patterns);
36+
for (mlir::RegisteredOperationName op : ctx->getRegisteredOperationsByDialect(
37+
scf::SCFDialect::getDialectNamespace()))
38+
op.getCanonicalizationPatterns(patterns, ctx);
39+
for (mlir::RegisteredOperationName op : ctx->getRegisteredOperationsByDialect(
40+
cf::ControlFlowDialect::getDialectNamespace()))
41+
op.getCanonicalizationPatterns(patterns, ctx);
42+
populateForOpDeadArgumentElimination(patterns);
43+
44+
GreedyRewriteConfig config;
45+
// This is intended to run before AutoLayouts are resolved, in which case
46+
// CSEing constants can lead to additional layout conflicts.
47+
config.enableConstantCSE(false);
48+
(void)applyPatternsGreedily(getOperation(), std::move(patterns), config);
49+
}

python/test/gluon/test_core.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1114,3 +1114,28 @@ def kernel(a_ptr, b_ptr, c_ptr, out_ptr):
11141114
out = torch.empty((B, B), dtype=torch.float32, device="cuda")
11151115
kernel[(1, )](a, b, c, out)
11161116
torch.testing.assert_close(out, torch.addmm(c, a, b), atol=1e-2, rtol=1e-2)
1117+
1118+
1119+
@gluon.jit
1120+
def kernel_auto_layout_constant(threads_per_warp: ttgl.constexpr):
1121+
BLOCK: ttgl.constexpr = 16
1122+
SIZE: ttgl.constexpr = 10
1123+
1124+
mask = ttgl.full(
1125+
(BLOCK, BLOCK),
1126+
True,
1127+
ttgl.int1,
1128+
ttgl.BlockedLayout(
1129+
size_per_thread=[1, 1],
1130+
threads_per_warp=[1, threads_per_warp],
1131+
warps_per_cta=[1, 4],
1132+
order=[1, 0],
1133+
),
1134+
)
1135+
1136+
mask &= (ttgl.arange(0, BLOCK, ttgl.AutoLayout()) < SIZE).expand_dims(0)
1137+
mask &= (ttgl.arange(0, BLOCK, ttgl.AutoLayout()) < SIZE).expand_dims(1)
1138+
1139+
1140+
def test_auto_layout_constant():
1141+
kernel_auto_layout_constant.warmup(THREADS_PER_WARP, grid=(1, ))

python/triton/experimental/gluon/language/nvidia/blackwell/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,10 @@
2323
"fence_async_shared",
2424
"get_tmem_32x32b_reg_layout",
2525
"mbarrier",
26+
"mma_v2",
2627
"tensor_memory_descriptor",
2728
"TensorMemoryLayout",
2829
"tma",
29-
"mma_v2",
3030
]
3131

3232

python/triton/experimental/gluon/language/nvidia/hopper/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
if TYPE_CHECKING:
99
from triton._C.libtriton import ir
1010

11-
__all__ = ["async_copy", "fence_async_shared", "mbarrier", "tma", "warpgroup_mma", "warpgroup_mma_wait", "mma_v2"]
11+
__all__ = ["async_copy", "fence_async_shared", "mbarrier", "mma_v2", "tma", "warpgroup_mma", "warpgroup_mma_wait"]
1212

1313

1414
@_core.builtin

third_party/nvidia/backend/compiler.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -329,6 +329,7 @@ def gluon_to_ttgir(self, src, metadata, options, capability):
329329

330330
passes.gluon.add_inliner(pm)
331331
passes.gluon.add_resolve_auto_encodings(pm)
332+
passes.gluon.add_canonicalizer(pm)
332333
passes.common.add_sccp(pm)
333334
passes.ttir.add_loop_aware_cse(pm)
334335
passes.gluon.add_canonicalizer(pm)

0 commit comments

Comments
 (0)