Skip to content

Commit d543875

Browse files
committed
feat: lowering triton now works
1 parent e8a3e1d commit d543875

File tree

3 files changed

+41
-6
lines changed

3 files changed

+41
-6
lines changed

src/CompileOptions.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,7 @@ function CompileOptions(;
229229
:canonicalize,
230230
:just_batch,
231231
:none,
232+
:no_triton,
232233
]
233234
end
234235

src/Compiler.jl

Lines changed: 39 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1844,12 +1844,14 @@ function compile_mlir!(
18441844
[
18451845
"mark-func-memory-effects",
18461846
opt_passes,
1847+
opt_passes_with_triton,
1848+
"lower-triton",
18471849
kern,
18481850
raise_passes,
18491851
"enzyme-batch",
18501852
opt_passes2,
18511853
enzyme_pass,
1852-
opt_passes_with_triton,
1854+
opt_passes2,
18531855
"canonicalize",
18541856
"remove-unnecessary-enzyme-ops",
18551857
"enzyme-simplify-math",
@@ -1871,6 +1873,7 @@ function compile_mlir!(
18711873
"enzyme-simplify-math",
18721874
legalize_chlo_to_stablehlo...,
18731875
opt_passes2,
1876+
"lower-triton",
18741877
kern,
18751878
raise_passes,
18761879
lower_enzymexla_linalg_pass,
@@ -1881,6 +1884,31 @@ function compile_mlir!(
18811884
),
18821885
"all",
18831886
)
1887+
elseif compile_options.optimization_passes === :no_triton
1888+
run_pass_pipeline!(
1889+
mod,
1890+
join(
1891+
if compile_options.raise_first
1892+
["mark-func-memory-effects", opt_passes]
1893+
else
1894+
[
1895+
"mark-func-memory-effects",
1896+
opt_passes,
1897+
"enzyme-batch",
1898+
opt_passes2,
1899+
enzyme_pass,
1900+
opt_passes2,
1901+
"canonicalize",
1902+
"remove-unnecessary-enzyme-ops",
1903+
"enzyme-simplify-math",
1904+
legalize_chlo_to_stablehlo...,
1905+
opt_passes2,
1906+
]
1907+
end,
1908+
',',
1909+
),
1910+
"before_kernel",
1911+
)
18841912
elseif compile_options.optimization_passes === :before_kernel
18851913
run_pass_pipeline!(
18861914
mod,
@@ -1913,13 +1941,14 @@ function compile_mlir!(
19131941
if compile_options.raise_first
19141942
[
19151943
"mark-func-memory-effects",
1916-
opt_passes,
1944+
opt_passes_with_triton,
1945+
"lower-triton",
19171946
kern,
19181947
raise_passes,
19191948
"enzyme-batch",
19201949
opt_passes2,
19211950
enzyme_pass,
1922-
opt_passes_with_triton,
1951+
opt_passes2,
19231952
"canonicalize",
19241953
"remove-unnecessary-enzyme-ops",
19251954
"enzyme-simplify-math",
@@ -1939,6 +1968,7 @@ function compile_mlir!(
19391968
"enzyme-simplify-math",
19401969
legalize_chlo_to_stablehlo...,
19411970
opt_passes2,
1971+
"lower-triton",
19421972
kern,
19431973
raise_passes,
19441974
]
@@ -1966,6 +1996,7 @@ function compile_mlir!(
19661996
"enzyme-simplify-math",
19671997
legalize_chlo_to_stablehlo...,
19681998
opt_passes2,
1999+
"lower-triton",
19692000
kern,
19702001
]
19712002
end,
@@ -2039,6 +2070,7 @@ function compile_mlir!(
20392070
"enzyme-simplify-math",
20402071
legalize_chlo_to_stablehlo...,
20412072
opt_passes_with_triton,
2073+
"lower-triton",
20422074
kern,
20432075
raise_passes,
20442076
lower_enzymexla_linalg_pass,
@@ -2056,7 +2088,8 @@ function compile_mlir!(
20562088
if compile_options.raise_first
20572089
[
20582090
"mark-func-memory-effects",
2059-
opt_passes,
2091+
opt_passes_with_triton,
2092+
"lower-triton",
20602093
kern,
20612094
raise_passes,
20622095
"enzyme-batch",
@@ -2071,9 +2104,10 @@ function compile_mlir!(
20712104
"mark-func-memory-effects",
20722105
opt_passes,
20732106
"enzyme-batch",
2074-
opt_passes2,
2107+
opt_passes_with_triton,
20752108
enzyme_pass,
20762109
"canonicalize,remove-unnecessary-enzyme-ops,enzyme-simplify-math",
2110+
"lower-triton",
20772111
kern,
20782112
raise_passes,
20792113
lower_enzymexla_linalg_pass,

src/Ops.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1864,7 +1864,7 @@ function triton_call(
18641864
output_operand_aliases,
18651865
MLIR.IR.Attribute(
18661866
MLIR.API.stablehloOutputOperandAliasGet(
1867-
MLIR.IR.context(), 0, C_NULL, Int64(i - 1), 0, C_NULL
1867+
MLIR.IR.context(), 1, Int64[i - 1], Int64(i - 1), 0, C_NULL
18681868
),
18691869
),
18701870
)

0 commit comments

Comments
 (0)