@@ -1300,9 +1300,66 @@ function optimization_passes(
         push!(passes, "remove-duplicate-func-def")
     end
     push!(passes, func_passes)
+    if backend == "cuda"
+        push!(passes, triton_optimization_passes())
+    end
     return join(passes, ',')
 end
 
+# https://github.com/triton-lang/triton/blob/8ee584014e9570ba608809c42dc2060fdd214a98/python/src/passes.cc
+function triton_optimization_passes()
+    # TODO: check that all triton passes are included here
+    return join(
+        [
+            # convert passes
+            "convert-scf-to-cf",
+            "convert-cf-to-llvm",
+            "convert-index-to-llvm",
+            "convert-arith-to-llvm",
+            "convert-nvvm-to-llvm",
+            # common passes
+            "canonicalize",
+            ## ttir passes
+            # "triton-combine",
+            # "triton-reorder-broadcast",
+            # "triton-rewrite-tensor-pointer",
+            # "triton-rewrite-tensor-descriptor-to-pointer",
+            # "triton-loop-unroll",
+            # "triton-licm",
+            # "triton-loop-aware-cse",
+            ## TODO: should num-warps and num-ctas be set for each kernel?
+            # "convert-triton-to-tritongpu{target=cuda:$(cubinChip[][4:end]) num-warps=1 threads-per-warp=$(cuWarpSize[]) num-ctas=1}",
+            ## ttgir passes
+            # "tritongpu-coalesce",
+            # "tritongpu-optimize-thread-locality",
+            # "tritongpu-hoist-tmem-alloc",
+            # "tritongpu-assign-latencies",
+            # "tritongpu-pipeline",
+            # "tritongpu-schedule-loops",
+            # "tritongpu-automatic-warp-specialization",
+            # "tritongpu-prefetch",
+            # "tritongpu-accelerate-matmul",
+            # "tritongpu-reorder-instructions",
+            # "tritongpu-F32DotTC",
+            # "tritongpu-optimize-dot-operands",
+            # "tritongpu-remove-layout-conversions",
+            # "tritongpu-reduce-data-duplication",
+            # "tritongpu-hoist-tmem-alloc",
+            # "tritongpu-fuse-nested-loops",
+            # "tritongpu-rewrite-partition-dependencies",
+            # "tritongpu-partition-loops",
+            # "tritongpu-combine-tensor-select-and-if",
+            ## ttgir to llvm passes
+            # "tritongpu-allocate-warp-groups",
+            # "allocate-shared-memory",
+            # "tritongpu-global-scratch-memory-allocation",
+            # "tritongpu-optimize-accumulator-init",
+            # "tritongpu-coalesce-async-copy",
+        ],
+        ",",
+    )
+end
+
 # TODO we want to be able to run the more advanced passes via transform dialect as an enzyme intermediate
 # However, this errs as we cannot attach the transform to the funcop itself [as we run a functionpass].
 const enzyme_pass::String = "enzyme{postpasses=\"arith-raise{stablehlo=true},canonicalize,cse,canonicalize,remove-unnecessary-enzyme-ops,enzyme-simplify-math,canonicalize,cse,canonicalize\"}"
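For reference, `triton_optimization_passes()` only builds a single comma-separated MLIR pipeline string, which `optimization_passes` splices into the overall pass list when the backend is "cuda"; this diff does not show where the combined string is ultimately run. A minimal sketch, assuming the `run_pass_pipeline!(mod, passes, name)` pattern used later in this file; the pipeline label "triton-lowering" is invented for illustration:

    # With only the uncommented entries, the helper currently returns:
    #   "convert-scf-to-cf,convert-cf-to-llvm,convert-index-to-llvm,convert-arith-to-llvm,convert-nvvm-to-llvm,canonicalize"
    # Hypothetical call site (not part of this commit), mirroring run_pass_pipeline! usage below:
    run_pass_pipeline!(mod, triton_optimization_passes(), "triton-lowering")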
@@ -2256,7 +2313,8 @@ function compile_mlir!(
         end
     end
 
-    run_pass_pipeline!(mod, "mark-func-memory-effects", "mark-func-memory-effects")
+    # XXX: re-enable this pass
+    # run_pass_pipeline!(mod, "mark-func-memory-effects", "mark-func-memory-effects")
 
     func_op = MLIR.API.mlirSymbolTableLookup(
         MLIR.IR.SymbolTable(MLIR.IR.Operation(mod)), fnname