@@ -702,6 +702,7 @@ function optimization_passes(
702702 lower_comms:: Bool = true ,
703703 max_constant_threshold:: Int = 1024 ,
704704 backend:: String = " gpu" ,
705+ enable_triton_passes:: Bool = false ,
705706)
706707 transform_passes_list = [
707708 " patterns=compare_op_canon<16>" ,
@@ -1298,7 +1299,7 @@ function optimization_passes(
12981299 push! (passes, " remove-duplicate-func-def" )
12991300 end
13001301 push! (passes, func_passes)
1301- if backend == " cuda"
1302+ if enable_triton_passes && backend == " cuda"
13021303 push! (passes, triton_optimization_passes ())
13031304 end
13041305 return join (passes, ' ,' )
@@ -1373,12 +1374,11 @@ function triton_optimization_passes()
13731374 " allocate-shared-memory" ,
13741375 " triton-tensor-memory-allocation" ,
13751376 " tritongpu-global-scratch-memory-allocation" ,
1376- # TODO : register the commented out passes
1377- # "convert-triton-gpu-to-llvm",
1377+ " convert-triton-gpu-to-llvm" ,
13781378 " canonicalize" ,
13791379 " cse" ,
1380- # "convert-nv-gpu-to-llvm",
1381- # "convert-warp-specialize-to-llvm",
1380+ " convert-nv-gpu-to-llvm" ,
1381+ " convert-warp-specialize-to-llvm" ,
13821382 " reconcile-unrealized-casts" ,
13831383 " canonicalize" ,
13841384 " cse" ,
@@ -1774,10 +1774,28 @@ function compile_mlir!(
17741774 end
17751775
17761776 opt_passes = optimization_passes (
1777- compile_options; sroa= true , recognize_comms, lower_comms, backend
1777+ compile_options;
1778+ sroa= true ,
1779+ recognize_comms,
1780+ lower_comms,
1781+ backend,
1782+ enable_triton_passes= false ,
17781783 )
17791784 opt_passes2 = optimization_passes (
1780- compile_options; sroa= false , recognize_comms, lower_comms, backend
1785+ compile_options;
1786+ sroa= false ,
1787+ recognize_comms,
1788+ lower_comms,
1789+ backend,
1790+ enable_triton_passes= false ,
1791+ )
1792+ opt_passes3 = optimization_passes (
1793+ compile_options;
1794+ sroa= false ,
1795+ recognize_comms,
1796+ lower_comms,
1797+ backend,
1798+ enable_triton_passes= true ,
17811799 )
17821800
17831801 raise_passes = if raise isa String
@@ -1792,15 +1810,15 @@ function compile_mlir!(
17921810 opt_passes2
17931811
17941812 if DUS_TO_CONCAT[]
1795- opt_passes3 = optimization_passes (
1813+ opt_passes_dus_to_concat = optimization_passes (
17961814 compile_options;
17971815 sroa= false ,
17981816 dus_to_concat= true ,
17991817 recognize_comms,
18001818 lower_comms,
18011819 backend,
18021820 )
1803- result = result * " ," * opt_passes3
1821+ result = result * " ," * opt_passes_dus_to_concat
18041822 end
18051823 result
18061824 else
@@ -1831,12 +1849,12 @@ function compile_mlir!(
18311849 " enzyme-batch" ,
18321850 opt_passes2,
18331851 enzyme_pass,
1834- opt_passes2 ,
1852+ opt_passes3 ,
18351853 " canonicalize" ,
18361854 " remove-unnecessary-enzyme-ops" ,
18371855 " enzyme-simplify-math" ,
18381856 legalize_chlo_to_stablehlo... ,
1839- opt_passes2 ,
1857+ opt_passes3 ,
18401858 lower_enzymexla_linalg_pass,
18411859 jit,
18421860 ]
@@ -1847,12 +1865,12 @@ function compile_mlir!(
18471865 " enzyme-batch" ,
18481866 opt_passes2,
18491867 enzyme_pass,
1850- opt_passes2 ,
1868+ opt_passes3 ,
18511869 " canonicalize" ,
18521870 " remove-unnecessary-enzyme-ops" ,
18531871 " enzyme-simplify-math" ,
18541872 legalize_chlo_to_stablehlo... ,
1855- opt_passes2 ,
1873+ opt_passes3 ,
18561874 kern,
18571875 raise_passes,
18581876 lower_enzymexla_linalg_pass,
@@ -1876,12 +1894,12 @@ function compile_mlir!(
18761894 " enzyme-batch" ,
18771895 opt_passes2,
18781896 enzyme_pass,
1879- opt_passes2 ,
1897+ opt_passes3 ,
18801898 " canonicalize" ,
18811899 " remove-unnecessary-enzyme-ops" ,
18821900 " enzyme-simplify-math" ,
18831901 legalize_chlo_to_stablehlo... ,
1884- opt_passes2 ,
1902+ opt_passes3 ,
18851903 ]
18861904 end ,
18871905 ' ,' ,
@@ -1901,12 +1919,12 @@ function compile_mlir!(
19011919 " enzyme-batch" ,
19021920 opt_passes2,
19031921 enzyme_pass,
1904- opt_passes2 ,
1922+ opt_passes3 ,
19051923 " canonicalize" ,
19061924 " remove-unnecessary-enzyme-ops" ,
19071925 " enzyme-simplify-math" ,
19081926 legalize_chlo_to_stablehlo... ,
1909- opt_passes2 ,
1927+ opt_passes3 ,
19101928 ]
19111929 else
19121930 [
@@ -1915,12 +1933,12 @@ function compile_mlir!(
19151933 " enzyme-batch" ,
19161934 opt_passes2,
19171935 enzyme_pass,
1918- opt_passes2 ,
1936+ opt_passes3 ,
19191937 " canonicalize" ,
19201938 " remove-unnecessary-enzyme-ops" ,
19211939 " enzyme-simplify-math" ,
19221940 legalize_chlo_to_stablehlo... ,
1923- opt_passes2 ,
1941+ opt_passes3 ,
19241942 kern,
19251943 raise_passes,
19261944 ]
@@ -1942,12 +1960,12 @@ function compile_mlir!(
19421960 " enzyme-batch" ,
19431961 opt_passes2,
19441962 enzyme_pass,
1945- opt_passes2 ,
1963+ opt_passes3 ,
19461964 " canonicalize" ,
19471965 " remove-unnecessary-enzyme-ops" ,
19481966 " enzyme-simplify-math" ,
19491967 legalize_chlo_to_stablehlo... ,
1950- opt_passes2 ,
1968+ opt_passes3 ,
19511969 kern,
19521970 ]
19531971 end ,
@@ -1965,12 +1983,12 @@ function compile_mlir!(
19651983 " enzyme-batch" ,
19661984 opt_passes2,
19671985 enzyme_pass,
1968- opt_passes2 ,
1986+ opt_passes3 ,
19691987 " canonicalize" ,
19701988 " remove-unnecessary-enzyme-ops" ,
19711989 " enzyme-simplify-math" ,
19721990 legalize_chlo_to_stablehlo... ,
1973- opt_passes2 ,
1991+ opt_passes3 ,
19741992 ],
19751993 ' ,' ,
19761994 ),
@@ -2007,7 +2025,7 @@ function compile_mlir!(
20072025 " remove-unnecessary-enzyme-ops" ,
20082026 " enzyme-simplify-math" ,
20092027 legalize_chlo_to_stablehlo... ,
2010- opt_passes2 ,
2028+ opt_passes3 ,
20112029 lower_enzymexla_linalg_pass,
20122030 jit,
20132031 ]
@@ -2020,7 +2038,7 @@ function compile_mlir!(
20202038 " remove-unnecessary-enzyme-ops" ,
20212039 " enzyme-simplify-math" ,
20222040 legalize_chlo_to_stablehlo... ,
2023- opt_passes2 ,
2041+ opt_passes3 ,
20242042 kern,
20252043 raise_passes,
20262044 lower_enzymexla_linalg_pass,
@@ -2231,7 +2249,7 @@ function compile_mlir!(
22312249 run_pass_pipeline! (
22322250 mod,
22332251 join (
2234- [opt_passes, " canonicalize" , " cse" , " canonicalize" , opt_passes2 ],
2252+ [opt_passes, " canonicalize" , " cse" , " canonicalize" , opt_passes3 ],
22352253 " ," ,
22362254 ),
22372255 " mid_pad_opts" ,
0 commit comments