Skip to content

Commit adec344

Browse files
committed
feat: lowering triton now works
1 parent 7e213c3 commit adec344

File tree

3 files changed

+41
-6
lines changed

3 files changed

+41
-6
lines changed

src/CompileOptions.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,7 @@ function CompileOptions(;
229229
:canonicalize,
230230
:just_batch,
231231
:none,
232+
:no_triton,
232233
]
233234
end
234235

src/Compiler.jl

Lines changed: 39 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1846,12 +1846,14 @@ function compile_mlir!(
18461846
[
18471847
"mark-func-memory-effects",
18481848
opt_passes,
1849+
opt_passes_with_triton,
1850+
"lower-triton",
18491851
kern,
18501852
raise_passes,
18511853
"enzyme-batch",
18521854
opt_passes2,
18531855
enzyme_pass,
1854-
opt_passes_with_triton,
1856+
opt_passes2,
18551857
"canonicalize",
18561858
"remove-unnecessary-enzyme-ops",
18571859
"enzyme-simplify-math",
@@ -1873,6 +1875,7 @@ function compile_mlir!(
18731875
"enzyme-simplify-math",
18741876
legalize_chlo_to_stablehlo...,
18751877
opt_passes2,
1878+
"lower-triton",
18761879
kern,
18771880
raise_passes,
18781881
lower_enzymexla_linalg_pass,
@@ -1883,6 +1886,31 @@ function compile_mlir!(
18831886
),
18841887
"all",
18851888
)
1889+
elseif compile_options.optimization_passes === :no_triton
1890+
run_pass_pipeline!(
1891+
mod,
1892+
join(
1893+
if compile_options.raise_first
1894+
["mark-func-memory-effects", opt_passes]
1895+
else
1896+
[
1897+
"mark-func-memory-effects",
1898+
opt_passes,
1899+
"enzyme-batch",
1900+
opt_passes2,
1901+
enzyme_pass,
1902+
opt_passes2,
1903+
"canonicalize",
1904+
"remove-unnecessary-enzyme-ops",
1905+
"enzyme-simplify-math",
1906+
legalize_chlo_to_stablehlo...,
1907+
opt_passes2,
1908+
]
1909+
end,
1910+
',',
1911+
),
1912+
"before_kernel",
1913+
)
18861914
elseif compile_options.optimization_passes === :before_kernel
18871915
run_pass_pipeline!(
18881916
mod,
@@ -1915,13 +1943,14 @@ function compile_mlir!(
19151943
if compile_options.raise_first
19161944
[
19171945
"mark-func-memory-effects",
1918-
opt_passes,
1946+
opt_passes_with_triton,
1947+
"lower-triton",
19191948
kern,
19201949
raise_passes,
19211950
"enzyme-batch",
19221951
opt_passes2,
19231952
enzyme_pass,
1924-
opt_passes_with_triton,
1953+
opt_passes2,
19251954
"canonicalize",
19261955
"remove-unnecessary-enzyme-ops",
19271956
"enzyme-simplify-math",
@@ -1941,6 +1970,7 @@ function compile_mlir!(
19411970
"enzyme-simplify-math",
19421971
legalize_chlo_to_stablehlo...,
19431972
opt_passes2,
1973+
"lower-triton",
19441974
kern,
19451975
raise_passes,
19461976
]
@@ -1968,6 +1998,7 @@ function compile_mlir!(
19681998
"enzyme-simplify-math",
19691999
legalize_chlo_to_stablehlo...,
19702000
opt_passes2,
2001+
"lower-triton",
19712002
kern,
19722003
]
19732004
end,
@@ -2041,6 +2072,7 @@ function compile_mlir!(
20412072
"enzyme-simplify-math",
20422073
legalize_chlo_to_stablehlo...,
20432074
opt_passes_with_triton,
2075+
"lower-triton",
20442076
kern,
20452077
raise_passes,
20462078
lower_enzymexla_linalg_pass,
@@ -2058,7 +2090,8 @@ function compile_mlir!(
20582090
if compile_options.raise_first
20592091
[
20602092
"mark-func-memory-effects",
2061-
opt_passes,
2093+
opt_passes_with_triton,
2094+
"lower-triton",
20622095
kern,
20632096
raise_passes,
20642097
"enzyme-batch",
@@ -2073,9 +2106,10 @@ function compile_mlir!(
20732106
"mark-func-memory-effects",
20742107
opt_passes,
20752108
"enzyme-batch",
2076-
opt_passes2,
2109+
opt_passes_with_triton,
20772110
enzyme_pass,
20782111
"canonicalize,remove-unnecessary-enzyme-ops,enzyme-simplify-math",
2112+
"lower-triton",
20792113
kern,
20802114
raise_passes,
20812115
lower_enzymexla_linalg_pass,

src/Ops.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1864,7 +1864,7 @@ function triton_call(
18641864
output_operand_aliases,
18651865
MLIR.IR.Attribute(
18661866
MLIR.API.stablehloOutputOperandAliasGet(
1867-
MLIR.IR.context(), 0, C_NULL, Int64(i - 1), 0, C_NULL
1867+
MLIR.IR.context(), 1, Int64[i - 1], Int64(i - 1), 0, C_NULL
18681868
),
18691869
),
18701870
)

0 commit comments

Comments
 (0)