@@ -702,6 +702,7 @@ function optimization_passes(
702702 lower_comms:: Bool = true ,
703703 max_constant_threshold:: Int = 1024 ,
704704 backend:: String = " gpu" ,
705+ enable_triton_passes:: Bool = false ,
705706)
706707 transform_passes_list = [
707708 " patterns=compare_op_canon<16>" ,
@@ -1291,7 +1292,7 @@ function optimization_passes(
12911292 push! (passes, " remove-duplicate-func-def" )
12921293 end
12931294 push! (passes, func_passes)
1294- if backend == " cuda"
1295+ if enable_triton_passes && backend == " cuda"
12951296 push! (passes, triton_optimization_passes ())
12961297 end
12971298 return join (passes, ' ,' )
@@ -1366,12 +1367,11 @@ function triton_optimization_passes()
13661367 " allocate-shared-memory" ,
13671368 " triton-tensor-memory-allocation" ,
13681369 " tritongpu-global-scratch-memory-allocation" ,
1369- # TODO : register the commented out passes
1370- # "convert-triton-gpu-to-llvm",
1370+ " convert-triton-gpu-to-llvm" ,
13711371 " canonicalize" ,
13721372 " cse" ,
1373- # "convert-nv-gpu-to-llvm",
1374- # "convert-warp-specialize-to-llvm",
1373+ " convert-nv-gpu-to-llvm" ,
1374+ " convert-warp-specialize-to-llvm" ,
13751375 " reconcile-unrealized-casts" ,
13761376 " canonicalize" ,
13771377 " cse" ,
@@ -1766,10 +1766,28 @@ function compile_mlir!(
17661766 end
17671767
17681768 opt_passes = optimization_passes (
1769- compile_options; sroa= true , recognize_comms, lower_comms, backend
1769+ compile_options;
1770+ sroa= true ,
1771+ recognize_comms,
1772+ lower_comms,
1773+ backend,
1774+ enable_triton_passes= false ,
17701775 )
17711776 opt_passes2 = optimization_passes (
1772- compile_options; sroa= false , recognize_comms, lower_comms, backend
1777+ compile_options;
1778+ sroa= false ,
1779+ recognize_comms,
1780+ lower_comms,
1781+ backend,
1782+ enable_triton_passes= false ,
1783+ )
1784+ opt_passes3 = optimization_passes (
1785+ compile_options;
1786+ sroa= false ,
1787+ recognize_comms,
1788+ lower_comms,
1789+ backend,
1790+ enable_triton_passes= true ,
17731791 )
17741792
17751793 raise_passes = if raise isa String
@@ -1784,15 +1802,15 @@ function compile_mlir!(
17841802 opt_passes2
17851803
17861804 if DUS_TO_CONCAT[]
1787- opt_passes3 = optimization_passes (
1805+ opt_passes_dus_to_concat = optimization_passes (
17881806 compile_options;
17891807 sroa= false ,
17901808 dus_to_concat= true ,
17911809 recognize_comms,
17921810 lower_comms,
17931811 backend,
17941812 )
1795- result = result * " ," * opt_passes3
1813+ result = result * " ," * opt_passes_dus_to_concat
17961814 end
17971815 result
17981816 else
@@ -1823,12 +1841,12 @@ function compile_mlir!(
18231841 " enzyme-batch" ,
18241842 opt_passes2,
18251843 enzyme_pass,
1826- opt_passes2 ,
1844+ opt_passes3 ,
18271845 " canonicalize" ,
18281846 " remove-unnecessary-enzyme-ops" ,
18291847 " enzyme-simplify-math" ,
18301848 legalize_chlo_to_stablehlo... ,
1831- opt_passes2 ,
1849+ opt_passes3 ,
18321850 lower_enzymexla_linalg_pass,
18331851 jit,
18341852 ]
@@ -1839,12 +1857,12 @@ function compile_mlir!(
18391857 " enzyme-batch" ,
18401858 opt_passes2,
18411859 enzyme_pass,
1842- opt_passes2 ,
1860+ opt_passes3 ,
18431861 " canonicalize" ,
18441862 " remove-unnecessary-enzyme-ops" ,
18451863 " enzyme-simplify-math" ,
18461864 legalize_chlo_to_stablehlo... ,
1847- opt_passes2 ,
1865+ opt_passes3 ,
18481866 kern,
18491867 raise_passes,
18501868 lower_enzymexla_linalg_pass,
@@ -1868,12 +1886,12 @@ function compile_mlir!(
18681886 " enzyme-batch" ,
18691887 opt_passes2,
18701888 enzyme_pass,
1871- opt_passes2 ,
1889+ opt_passes3 ,
18721890 " canonicalize" ,
18731891 " remove-unnecessary-enzyme-ops" ,
18741892 " enzyme-simplify-math" ,
18751893 legalize_chlo_to_stablehlo... ,
1876- opt_passes2 ,
1894+ opt_passes3 ,
18771895 ]
18781896 end ,
18791897 ' ,' ,
@@ -1893,12 +1911,12 @@ function compile_mlir!(
18931911 " enzyme-batch" ,
18941912 opt_passes2,
18951913 enzyme_pass,
1896- opt_passes2 ,
1914+ opt_passes3 ,
18971915 " canonicalize" ,
18981916 " remove-unnecessary-enzyme-ops" ,
18991917 " enzyme-simplify-math" ,
19001918 legalize_chlo_to_stablehlo... ,
1901- opt_passes2 ,
1919+ opt_passes3 ,
19021920 ]
19031921 else
19041922 [
@@ -1907,12 +1925,12 @@ function compile_mlir!(
19071925 " enzyme-batch" ,
19081926 opt_passes2,
19091927 enzyme_pass,
1910- opt_passes2 ,
1928+ opt_passes3 ,
19111929 " canonicalize" ,
19121930 " remove-unnecessary-enzyme-ops" ,
19131931 " enzyme-simplify-math" ,
19141932 legalize_chlo_to_stablehlo... ,
1915- opt_passes2 ,
1933+ opt_passes3 ,
19161934 kern,
19171935 raise_passes,
19181936 ]
@@ -1934,12 +1952,12 @@ function compile_mlir!(
19341952 " enzyme-batch" ,
19351953 opt_passes2,
19361954 enzyme_pass,
1937- opt_passes2 ,
1955+ opt_passes3 ,
19381956 " canonicalize" ,
19391957 " remove-unnecessary-enzyme-ops" ,
19401958 " enzyme-simplify-math" ,
19411959 legalize_chlo_to_stablehlo... ,
1942- opt_passes2 ,
1960+ opt_passes3 ,
19431961 kern,
19441962 ]
19451963 end ,
@@ -1957,12 +1975,12 @@ function compile_mlir!(
19571975 " enzyme-batch" ,
19581976 opt_passes2,
19591977 enzyme_pass,
1960- opt_passes2 ,
1978+ opt_passes3 ,
19611979 " canonicalize" ,
19621980 " remove-unnecessary-enzyme-ops" ,
19631981 " enzyme-simplify-math" ,
19641982 legalize_chlo_to_stablehlo... ,
1965- opt_passes2 ,
1983+ opt_passes3 ,
19661984 ],
19671985 ' ,' ,
19681986 ),
@@ -1999,7 +2017,7 @@ function compile_mlir!(
19992017 " remove-unnecessary-enzyme-ops" ,
20002018 " enzyme-simplify-math" ,
20012019 legalize_chlo_to_stablehlo... ,
2002- opt_passes2 ,
2020+ opt_passes3 ,
20032021 lower_enzymexla_linalg_pass,
20042022 jit,
20052023 ]
@@ -2012,7 +2030,7 @@ function compile_mlir!(
20122030 " remove-unnecessary-enzyme-ops" ,
20132031 " enzyme-simplify-math" ,
20142032 legalize_chlo_to_stablehlo... ,
2015- opt_passes2 ,
2033+ opt_passes3 ,
20162034 kern,
20172035 raise_passes,
20182036 lower_enzymexla_linalg_pass,
@@ -2223,7 +2241,7 @@ function compile_mlir!(
22232241 run_pass_pipeline! (
22242242 mod,
22252243 join (
2226- [opt_passes, " canonicalize" , " cse" , " canonicalize" , opt_passes2 ],
2244+ [opt_passes, " canonicalize" , " cse" , " canonicalize" , opt_passes3 ],
22272245 " ," ,
22282246 ),
22292247 " mid_pad_opts" ,
0 commit comments