@@ -1298,9 +1298,66 @@ function optimization_passes(
         push!(passes, "remove-duplicate-func-def")
     end
     push!(passes, func_passes)
+    if backend == "cuda"
+        push!(passes, triton_optimization_passes())
+    end
     return join(passes, ',')
 end

+# https://github.com/triton-lang/triton/blob/8ee584014e9570ba608809c42dc2060fdd214a98/python/src/passes.cc
+function triton_optimization_passes()
+    # TODO: check that all triton passes are included here
+    return join(
+        [
+            # convert passes
+            "convert-scf-to-cf",
+            "convert-cf-to-llvm",
+            "convert-index-to-llvm",
+            "convert-arith-to-llvm",
+            "convert-nvvm-to-llvm",
+            # common passes
+            "canonicalize",
+            # # ttir passes
+            # "triton-combine",
+            # "triton-reorder-broadcast",
+            # "triton-rewrite-tensor-pointer",
+            # "triton-rewrite-tensor-descriptor-to-pointer",
+            # "triton-loop-unroll",
+            # "triton-licm",
+            # "triton-loop-aware-cse",
+            # # TODO: should num-warps and num-ctas be set for each kernel?
+            # "convert-triton-to-tritongpu{target=cuda:$(cubinChip[][4:end]) num-warps=1 threads-per-warp=$(cuWarpSize[]) num-ctas=1}",
+            # # ttgir passes
+            # "tritongpu-coalesce",
+            # "tritongpu-optimize-thread-locality",
+            # "tritongpu-hoist-tmem-alloc",
+            # "tritongpu-assign-latencies",
+            # "tritongpu-pipeline",
+            # "tritongpu-schedule-loops",
+            # "tritongpu-automatic-warp-specialization",
+            # "tritongpu-prefetch",
+            # "tritongpu-accelerate-matmul",
+            # "tritongpu-reorder-instructions",
+            # "tritongpu-F32DotTC",
+            # "tritongpu-optimize-dot-operands",
+            # "tritongpu-remove-layout-conversions",
+            # "tritongpu-reduce-data-duplication",
+            # "tritongpu-hoist-tmem-alloc",
+            # "tritongpu-fuse-nested-loops",
+            # "tritongpu-rewrite-partition-dependencies",
+            # "tritongpu-partition-loops",
+            # "tritongpu-combine-tensor-select-and-if",
+            # # ttgir to llvm passes
+            # "tritongpu-allocate-warp-groups",
+            # "allocate-shared-memory",
+            # "tritongpu-global-scratch-memory-allocation",
+            # "tritongpu-optimize-accumulator-init",
+            # "tritongpu-coalesce-async-copy",
+        ],
+        ",",
+    )
+end
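+
+# NOTE: illustrative usage sketch (exact call site and argument names are assumed,
+# not taken from this commit). The comma-separated string returned above is pushed
+# onto the pass list built by `optimization_passes` when `backend == "cuda"`, and
+# the combined pipeline is then executed with `run_pass_pipeline!`, roughly:
+#
+#     pipeline = optimization_passes(...)  # with backend set to "cuda"
+#     run_pass_pipeline!(mod, pipeline, "optimization_passes")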
+
 
 # TODO we want to be able to run the more advanced passes via transform dialect as an enzyme intermediate
 # However, this errs as we cannot attach the transform to the funcop itself [as we run a functionpass].
 const enzyme_pass::String = "enzyme{postpasses=\"arith-raise{stablehlo=true},canonicalize,cse,canonicalize,remove-unnecessary-enzyme-ops,enzyme-simplify-math,canonicalize,cse,canonicalize\"}"
@@ -2254,7 +2311,8 @@ function compile_mlir!(
         end
     end

-    run_pass_pipeline!(mod, "mark-func-memory-effects", "mark-func-memory-effects")
+    # XXX: re-enable this pass
+    # run_pass_pipeline!(mod, "mark-func-memory-effects", "mark-func-memory-effects")

     func_op = MLIR.API.mlirSymbolTableLookup(
         MLIR.IR.SymbolTable(MLIR.IR.Operation(mod)), fnname