@@ -702,6 +702,7 @@ function optimization_passes(
702702 lower_comms:: Bool = true ,
703703 max_constant_threshold:: Int = 1024 ,
704704 backend:: String = " gpu" ,
705+ enable_triton_passes:: Bool = false ,
705706)
706707 transform_passes_list = [
707708 " patterns=compare_op_canon<16>" ,
@@ -1300,7 +1301,7 @@ function optimization_passes(
13001301 push! (passes, " remove-duplicate-func-def" )
13011302 end
13021303 push! (passes, func_passes)
1303- if backend == " cuda"
1304+ if enable_triton_passes && backend == " cuda"
13041305 push! (passes, triton_optimization_passes ())
13051306 end
13061307 return join (passes, ' ,' )
@@ -1375,12 +1376,11 @@ function triton_optimization_passes()
13751376 " allocate-shared-memory" ,
13761377 " triton-tensor-memory-allocation" ,
13771378 " tritongpu-global-scratch-memory-allocation" ,
1378- # TODO : register the commented out passes
1379- # "convert-triton-gpu-to-llvm",
1379+ " convert-triton-gpu-to-llvm" ,
13801380 " canonicalize" ,
13811381 " cse" ,
1382- # "convert-nv-gpu-to-llvm",
1383- # "convert-warp-specialize-to-llvm",
1382+ " convert-nv-gpu-to-llvm" ,
1383+ " convert-warp-specialize-to-llvm" ,
13841384 " reconcile-unrealized-casts" ,
13851385 " canonicalize" ,
13861386 " cse" ,
@@ -1781,10 +1781,28 @@ function compile_mlir!(
17811781 end
17821782
17831783 opt_passes = optimization_passes (
1784- compile_options; sroa= true , recognize_comms, lower_comms, backend
1784+ compile_options;
1785+ sroa= true ,
1786+ recognize_comms,
1787+ lower_comms,
1788+ backend,
1789+ enable_triton_passes= false ,
17851790 )
17861791 opt_passes2 = optimization_passes (
1787- compile_options; sroa= false , recognize_comms, lower_comms, backend
1792+ compile_options;
1793+ sroa= false ,
1794+ recognize_comms,
1795+ lower_comms,
1796+ backend,
1797+ enable_triton_passes= false ,
1798+ )
1799+ opt_passes3 = optimization_passes (
1800+ compile_options;
1801+ sroa= false ,
1802+ recognize_comms,
1803+ lower_comms,
1804+ backend,
1805+ enable_triton_passes= true ,
17881806 )
17891807
17901808 raise_passes = if raise isa String
@@ -1799,15 +1817,15 @@ function compile_mlir!(
17991817 opt_passes2
18001818
18011819 if DUS_TO_CONCAT[]
1802- opt_passes3 = optimization_passes (
1820+ opt_passes_dus_to_concat = optimization_passes (
18031821 compile_options;
18041822 sroa= false ,
18051823 dus_to_concat= true ,
18061824 recognize_comms,
18071825 lower_comms,
18081826 backend,
18091827 )
1810- result = result * " ," * opt_passes3
1828+ result = result * " ," * opt_passes_dus_to_concat
18111829 end
18121830 result
18131831 else
@@ -1838,12 +1856,12 @@ function compile_mlir!(
18381856 " enzyme-batch" ,
18391857 opt_passes2,
18401858 enzyme_pass,
1841- opt_passes2 ,
1859+ opt_passes3 ,
18421860 " canonicalize" ,
18431861 " remove-unnecessary-enzyme-ops" ,
18441862 " enzyme-simplify-math" ,
18451863 legalize_chlo_to_stablehlo... ,
1846- opt_passes2 ,
1864+ opt_passes3 ,
18471865 lower_enzymexla_linalg_pass,
18481866 jit,
18491867 ]
@@ -1854,12 +1872,12 @@ function compile_mlir!(
18541872 " enzyme-batch" ,
18551873 opt_passes2,
18561874 enzyme_pass,
1857- opt_passes2 ,
1875+ opt_passes3 ,
18581876 " canonicalize" ,
18591877 " remove-unnecessary-enzyme-ops" ,
18601878 " enzyme-simplify-math" ,
18611879 legalize_chlo_to_stablehlo... ,
1862- opt_passes2 ,
1880+ opt_passes3 ,
18631881 kern,
18641882 raise_passes,
18651883 lower_enzymexla_linalg_pass,
@@ -1883,12 +1901,12 @@ function compile_mlir!(
18831901 " enzyme-batch" ,
18841902 opt_passes2,
18851903 enzyme_pass,
1886- opt_passes2 ,
1904+ opt_passes3 ,
18871905 " canonicalize" ,
18881906 " remove-unnecessary-enzyme-ops" ,
18891907 " enzyme-simplify-math" ,
18901908 legalize_chlo_to_stablehlo... ,
1891- opt_passes2 ,
1909+ opt_passes3 ,
18921910 ]
18931911 end ,
18941912 ' ,' ,
@@ -1908,12 +1926,12 @@ function compile_mlir!(
19081926 " enzyme-batch" ,
19091927 opt_passes2,
19101928 enzyme_pass,
1911- opt_passes2 ,
1929+ opt_passes3 ,
19121930 " canonicalize" ,
19131931 " remove-unnecessary-enzyme-ops" ,
19141932 " enzyme-simplify-math" ,
19151933 legalize_chlo_to_stablehlo... ,
1916- opt_passes2 ,
1934+ opt_passes3 ,
19171935 ]
19181936 else
19191937 [
@@ -1922,12 +1940,12 @@ function compile_mlir!(
19221940 " enzyme-batch" ,
19231941 opt_passes2,
19241942 enzyme_pass,
1925- opt_passes2 ,
1943+ opt_passes3 ,
19261944 " canonicalize" ,
19271945 " remove-unnecessary-enzyme-ops" ,
19281946 " enzyme-simplify-math" ,
19291947 legalize_chlo_to_stablehlo... ,
1930- opt_passes2 ,
1948+ opt_passes3 ,
19311949 kern,
19321950 raise_passes,
19331951 ]
@@ -1949,12 +1967,12 @@ function compile_mlir!(
19491967 " enzyme-batch" ,
19501968 opt_passes2,
19511969 enzyme_pass,
1952- opt_passes2 ,
1970+ opt_passes3 ,
19531971 " canonicalize" ,
19541972 " remove-unnecessary-enzyme-ops" ,
19551973 " enzyme-simplify-math" ,
19561974 legalize_chlo_to_stablehlo... ,
1957- opt_passes2 ,
1975+ opt_passes3 ,
19581976 kern,
19591977 ]
19601978 end ,
@@ -1972,12 +1990,12 @@ function compile_mlir!(
19721990 " enzyme-batch" ,
19731991 opt_passes2,
19741992 enzyme_pass,
1975- opt_passes2 ,
1993+ opt_passes3 ,
19761994 " canonicalize" ,
19771995 " remove-unnecessary-enzyme-ops" ,
19781996 " enzyme-simplify-math" ,
19791997 legalize_chlo_to_stablehlo... ,
1980- opt_passes2 ,
1998+ opt_passes3 ,
19811999 ],
19822000 ' ,' ,
19832001 ),
@@ -2014,7 +2032,7 @@ function compile_mlir!(
20142032 " remove-unnecessary-enzyme-ops" ,
20152033 " enzyme-simplify-math" ,
20162034 legalize_chlo_to_stablehlo... ,
2017- opt_passes2 ,
2035+ opt_passes3 ,
20182036 lower_enzymexla_linalg_pass,
20192037 jit,
20202038 ]
@@ -2027,7 +2045,7 @@ function compile_mlir!(
20272045 " remove-unnecessary-enzyme-ops" ,
20282046 " enzyme-simplify-math" ,
20292047 legalize_chlo_to_stablehlo... ,
2030- opt_passes2 ,
2048+ opt_passes3 ,
20312049 kern,
20322050 raise_passes,
20332051 lower_enzymexla_linalg_pass,
@@ -2238,7 +2256,7 @@ function compile_mlir!(
22382256 run_pass_pipeline! (
22392257 mod,
22402258 join (
2241- [opt_passes, " canonicalize" , " cse" , " canonicalize" , opt_passes2 ],
2259+ [opt_passes, " canonicalize" , " cse" , " canonicalize" , opt_passes3 ],
22422260 " ," ,
22432261 ),
22442262 " mid_pad_opts" ,
0 commit comments