Skip to content

Commit dee9490

Browse files
committed
feat: lowering triton now works
1 parent 3104809 commit dee9490

File tree

3 files changed

+41
-6
lines changed

3 files changed

+41
-6
lines changed

src/CompileOptions.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,7 @@ function CompileOptions(;
229229
:canonicalize,
230230
:just_batch,
231231
:none,
232+
:no_triton,
232233
]
233234
end
234235

src/Compiler.jl

Lines changed: 39 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1851,12 +1851,14 @@ function compile_mlir!(
18511851
[
18521852
"mark-func-memory-effects",
18531853
opt_passes,
1854+
opt_passes_with_triton,
1855+
"lower-triton",
18541856
kern,
18551857
raise_passes,
18561858
"enzyme-batch",
18571859
opt_passes2,
18581860
enzyme_pass,
1859-
opt_passes_with_triton,
1861+
opt_passes2,
18601862
"canonicalize",
18611863
"remove-unnecessary-enzyme-ops",
18621864
"enzyme-simplify-math",
@@ -1878,6 +1880,7 @@ function compile_mlir!(
18781880
"enzyme-simplify-math",
18791881
legalize_chlo_to_stablehlo...,
18801882
opt_passes2,
1883+
"lower-triton",
18811884
kern,
18821885
raise_passes,
18831886
lower_enzymexla_linalg_pass,
@@ -1888,6 +1891,31 @@ function compile_mlir!(
18881891
),
18891892
"all",
18901893
)
1894+
elseif compile_options.optimization_passes === :no_triton
1895+
run_pass_pipeline!(
1896+
mod,
1897+
join(
1898+
if compile_options.raise_first
1899+
["mark-func-memory-effects", opt_passes]
1900+
else
1901+
[
1902+
"mark-func-memory-effects",
1903+
opt_passes,
1904+
"enzyme-batch",
1905+
opt_passes2,
1906+
enzyme_pass,
1907+
opt_passes2,
1908+
"canonicalize",
1909+
"remove-unnecessary-enzyme-ops",
1910+
"enzyme-simplify-math",
1911+
legalize_chlo_to_stablehlo...,
1912+
opt_passes2,
1913+
]
1914+
end,
1915+
',',
1916+
),
1917+
"before_kernel",
1918+
)
18911919
elseif compile_options.optimization_passes === :before_kernel
18921920
run_pass_pipeline!(
18931921
mod,
@@ -1920,13 +1948,14 @@ function compile_mlir!(
19201948
if compile_options.raise_first
19211949
[
19221950
"mark-func-memory-effects",
1923-
opt_passes,
1951+
opt_passes_with_triton,
1952+
"lower-triton",
19241953
kern,
19251954
raise_passes,
19261955
"enzyme-batch",
19271956
opt_passes2,
19281957
enzyme_pass,
1929-
opt_passes_with_triton,
1958+
opt_passes2,
19301959
"canonicalize",
19311960
"remove-unnecessary-enzyme-ops",
19321961
"enzyme-simplify-math",
@@ -1946,6 +1975,7 @@ function compile_mlir!(
19461975
"enzyme-simplify-math",
19471976
legalize_chlo_to_stablehlo...,
19481977
opt_passes2,
1978+
"lower-triton",
19491979
kern,
19501980
raise_passes,
19511981
]
@@ -1973,6 +2003,7 @@ function compile_mlir!(
19732003
"enzyme-simplify-math",
19742004
legalize_chlo_to_stablehlo...,
19752005
opt_passes2,
2006+
"lower-triton",
19762007
kern,
19772008
]
19782009
end,
@@ -2046,6 +2077,7 @@ function compile_mlir!(
20462077
"enzyme-simplify-math",
20472078
legalize_chlo_to_stablehlo...,
20482079
opt_passes_with_triton,
2080+
"lower-triton",
20492081
kern,
20502082
raise_passes,
20512083
lower_enzymexla_linalg_pass,
@@ -2063,7 +2095,8 @@ function compile_mlir!(
20632095
if compile_options.raise_first
20642096
[
20652097
"mark-func-memory-effects",
2066-
opt_passes,
2098+
opt_passes_with_triton,
2099+
"lower-triton",
20672100
kern,
20682101
raise_passes,
20692102
"enzyme-batch",
@@ -2078,9 +2111,10 @@ function compile_mlir!(
20782111
"mark-func-memory-effects",
20792112
opt_passes,
20802113
"enzyme-batch",
2081-
opt_passes2,
2114+
opt_passes_with_triton,
20822115
enzyme_pass,
20832116
"canonicalize,remove-unnecessary-enzyme-ops,enzyme-simplify-math",
2117+
"lower-triton",
20842118
kern,
20852119
raise_passes,
20862120
lower_enzymexla_linalg_pass,

src/Ops.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1864,7 +1864,7 @@ function triton_call(
18641864
output_operand_aliases,
18651865
MLIR.IR.Attribute(
18661866
MLIR.API.stablehloOutputOperandAliasGet(
1867-
MLIR.IR.context(), 0, C_NULL, Int64(i - 1), 0, C_NULL
1867+
MLIR.IR.context(), 1, Int64[i - 1], Int64(i - 1), 0, C_NULL
18681868
),
18691869
),
18701870
)

0 commit comments

Comments
 (0)