@@ -1291,9 +1291,66 @@ function optimization_passes(
         push!(passes, "remove-duplicate-func-def")
     end
     push!(passes, func_passes)
+    if backend == "cuda"
+        push!(passes, triton_optimization_passes())
+    end
     return join(passes, ',')
 end
 
+# https://github.com/triton-lang/triton/blob/8ee584014e9570ba608809c42dc2060fdd214a98/python/src/passes.cc
+function triton_optimization_passes()
+    # TODO: check that all triton passes are included here
+    return join(
+        [
+            # convert passes
+            "convert-scf-to-cf",
+            "convert-cf-to-llvm",
+            "convert-index-to-llvm",
+            "convert-arith-to-llvm",
+            "convert-nvvm-to-llvm",
+            # common passes
+            "canonicalize",
+            # # ttir passes
+            # "triton-combine",
+            # "triton-reorder-broadcast",
+            # "triton-rewrite-tensor-pointer",
+            # "triton-rewrite-tensor-descriptor-to-pointer",
+            # "triton-loop-unroll",
+            # "triton-licm",
+            # "triton-loop-aware-cse",
+            # # TODO: should num-warps and num-ctas be set for each kernel?
+            # "convert-triton-to-tritongpu{target=cuda:$(cubinChip[][4:end]) num-warps=1 threads-per-warp=$(cuWarpSize[]) num-ctas=1}",
+            # # ttgir passes
+            # "tritongpu-coalesce",
+            # "tritongpu-optimize-thread-locality",
+            # "tritongpu-hoist-tmem-alloc",
+            # "tritongpu-assign-latencies",
+            # "tritongpu-pipeline",
+            # "tritongpu-schedule-loops",
+            # "tritongpu-automatic-warp-specialization",
+            # "tritongpu-prefetch",
+            # "tritongpu-accelerate-matmul",
+            # "tritongpu-reorder-instructions",
+            # "tritongpu-F32DotTC",
+            # "tritongpu-optimize-dot-operands",
+            # "tritongpu-remove-layout-conversions",
+            # "tritongpu-reduce-data-duplication",
+            # "tritongpu-hoist-tmem-alloc",
+            # "tritongpu-fuse-nested-loops",
+            # "tritongpu-rewrite-partition-dependencies",
+            # "tritongpu-partition-loops",
+            # "tritongpu-combine-tensor-select-and-if",
+            # # ttgir to llvm passes
+            # "tritongpu-allocate-warp-groups",
+            # "allocate-shared-memory",
+            # "tritongpu-global-scratch-memory-allocation",
+            # "tritongpu-optimize-accumulator-init",
+            # "tritongpu-coalesce-async-copy",
+        ],
+        ",",
+    )
+end
+
 # TODO we want to be able to run the more advanced passes via transform dialect as an enzyme intermediate
 # However, this errs as we cannot attach the transform to the funcop itself [as we run a functionpass].
 const enzyme_pass::String = "enzyme{postpasses=\"arith-raise{stablehlo=true},canonicalize,cse,canonicalize,remove-unnecessary-enzyme-ops,enzyme-simplify-math,canonicalize,cse,canonicalize\"}"
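For reference, with only the uncommented passes above, the returned pipeline string and the interpolation inside the commented-out convert-triton-to-tritongpu option (using the defaults introduced later in this diff, cubinChip[] == "sm_60" and cuWarpSize[] == 32) would expand roughly as follows; this is an illustrative REPL sketch, not part of the change:

julia> triton_optimization_passes()
"convert-scf-to-cf,convert-cf-to-llvm,convert-index-to-llvm,convert-arith-to-llvm,convert-nvvm-to-llvm,canonicalize"

julia> "convert-triton-to-tritongpu{target=cuda:$(cubinChip[][4:end]) num-warps=1 threads-per-warp=$(cuWarpSize[]) num-ctas=1}"
"convert-triton-to-tritongpu{target=cuda:60 num-warps=1 threads-per-warp=32 num-ctas=1}"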
@@ -1425,6 +1482,7 @@ const cubinChip = Ref{String}("sm_60")
 const cubinFormat = Ref{String}("bin")
 const cuindexBitWidth = Ref{Int}(32)
 const cuOptLevel = Ref{Int}(2)
+const cuWarpSize = Ref{Int}(32)
 # Whatever the relevant highest version from our LLVM is within NVPTX.td
 # Or more specifically looking at clang/lib/Driver/ToolChains/Cuda.cpp:684
 # We see relevant ptx version is CUDA 12.6 -> 85
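These compilation options are module-level Refs rather than plain constants, so they are read with [] at the point where the pass pipeline string is built and can be overridden at runtime before compilation. A minimal sketch, assuming the defaults shown in this diff (the "sm_80" value below is purely an illustrative override, not something this change sets):

cubinChip[]    # "sm_60" (default)
cuWarpSize[]   # 32 (default added in this change)

cubinChip[] = "sm_80"   # hypothetical override for a newer GPU, applied before compile_mlir! runs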
@@ -2245,7 +2303,8 @@ function compile_mlir!(
         end
     end
 
-    run_pass_pipeline!(mod, "mark-func-memory-effects", "mark-func-memory-effects")
+    # XXX: re-enable this pass
+    # run_pass_pipeline!(mod, "mark-func-memory-effects", "mark-func-memory-effects")
 
     func_op = MLIR.API.mlirSymbolTableLookup(
         MLIR.IR.SymbolTable(MLIR.IR.Operation(mod)), fnname