Skip to content

Commit b4dc832

Browse files
committed
fix: kind of working
1 parent 8ae9ebe commit b4dc832

File tree

2 files changed

+45
-27
lines changed

2 files changed

+45
-27
lines changed

deps/ReactantExtra/make-bindings.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ for file in [
4242
"MPI.jl",
4343
"MemRef.jl",
4444
"SparseTensor.jl",
45-
"TritonExt.jl"
45+
"TritonExt.jl",
4646
]
4747
build_file(joinpath(src_dir, "mlir", "Dialects", file))
4848
end

src/Compiler.jl

Lines changed: 44 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -702,6 +702,7 @@ function optimization_passes(
702702
lower_comms::Bool=true,
703703
max_constant_threshold::Int=1024,
704704
backend::String="gpu",
705+
enable_triton_passes::Bool=false,
705706
)
706707
transform_passes_list = [
707708
"patterns=compare_op_canon<16>",
@@ -1300,7 +1301,7 @@ function optimization_passes(
13001301
push!(passes, "remove-duplicate-func-def")
13011302
end
13021303
push!(passes, func_passes)
1303-
if backend == "cuda"
1304+
if enable_triton_passes && backend == "cuda"
13041305
push!(passes, triton_optimization_passes())
13051306
end
13061307
return join(passes, ',')
@@ -1375,12 +1376,11 @@ function triton_optimization_passes()
13751376
"allocate-shared-memory",
13761377
"triton-tensor-memory-allocation",
13771378
"tritongpu-global-scratch-memory-allocation",
1378-
# TODO: register the commented out passes
1379-
# "convert-triton-gpu-to-llvm",
1379+
"convert-triton-gpu-to-llvm",
13801380
"canonicalize",
13811381
"cse",
1382-
# "convert-nv-gpu-to-llvm",
1383-
# "convert-warp-specialize-to-llvm",
1382+
"convert-nv-gpu-to-llvm",
1383+
"convert-warp-specialize-to-llvm",
13841384
"reconcile-unrealized-casts",
13851385
"canonicalize",
13861386
"cse",
@@ -1776,10 +1776,28 @@ function compile_mlir!(
17761776
end
17771777

17781778
opt_passes = optimization_passes(
1779-
compile_options; sroa=true, recognize_comms, lower_comms, backend
1779+
compile_options;
1780+
sroa=true,
1781+
recognize_comms,
1782+
lower_comms,
1783+
backend,
1784+
enable_triton_passes=false,
17801785
)
17811786
opt_passes2 = optimization_passes(
1782-
compile_options; sroa=false, recognize_comms, lower_comms, backend
1787+
compile_options;
1788+
sroa=false,
1789+
recognize_comms,
1790+
lower_comms,
1791+
backend,
1792+
enable_triton_passes=false,
1793+
)
1794+
opt_passes3 = optimization_passes(
1795+
compile_options;
1796+
sroa=false,
1797+
recognize_comms,
1798+
lower_comms,
1799+
backend,
1800+
enable_triton_passes=true,
17831801
)
17841802

17851803
raise_passes = if raise isa String
@@ -1794,15 +1812,15 @@ function compile_mlir!(
17941812
opt_passes2
17951813

17961814
if DUS_TO_CONCAT[]
1797-
opt_passes3 = optimization_passes(
1815+
opt_passes_dus_to_concat = optimization_passes(
17981816
compile_options;
17991817
sroa=false,
18001818
dus_to_concat=true,
18011819
recognize_comms,
18021820
lower_comms,
18031821
backend,
18041822
)
1805-
result = result * "," * opt_passes3
1823+
result = result * "," * opt_passes_dus_to_concat
18061824
end
18071825
result
18081826
else
@@ -1833,12 +1851,12 @@ function compile_mlir!(
18331851
"enzyme-batch",
18341852
opt_passes2,
18351853
enzyme_pass,
1836-
opt_passes2,
1854+
opt_passes3,
18371855
"canonicalize",
18381856
"remove-unnecessary-enzyme-ops",
18391857
"enzyme-simplify-math",
18401858
legalize_chlo_to_stablehlo...,
1841-
opt_passes2,
1859+
opt_passes3,
18421860
lower_enzymexla_linalg_pass,
18431861
jit,
18441862
]
@@ -1849,12 +1867,12 @@ function compile_mlir!(
18491867
"enzyme-batch",
18501868
opt_passes2,
18511869
enzyme_pass,
1852-
opt_passes2,
1870+
opt_passes3,
18531871
"canonicalize",
18541872
"remove-unnecessary-enzyme-ops",
18551873
"enzyme-simplify-math",
18561874
legalize_chlo_to_stablehlo...,
1857-
opt_passes2,
1875+
opt_passes3,
18581876
kern,
18591877
raise_passes,
18601878
lower_enzymexla_linalg_pass,
@@ -1878,12 +1896,12 @@ function compile_mlir!(
18781896
"enzyme-batch",
18791897
opt_passes2,
18801898
enzyme_pass,
1881-
opt_passes2,
1899+
opt_passes3,
18821900
"canonicalize",
18831901
"remove-unnecessary-enzyme-ops",
18841902
"enzyme-simplify-math",
18851903
legalize_chlo_to_stablehlo...,
1886-
opt_passes2,
1904+
opt_passes3,
18871905
]
18881906
end,
18891907
',',
@@ -1903,12 +1921,12 @@ function compile_mlir!(
19031921
"enzyme-batch",
19041922
opt_passes2,
19051923
enzyme_pass,
1906-
opt_passes2,
1924+
opt_passes3,
19071925
"canonicalize",
19081926
"remove-unnecessary-enzyme-ops",
19091927
"enzyme-simplify-math",
19101928
legalize_chlo_to_stablehlo...,
1911-
opt_passes2,
1929+
opt_passes3,
19121930
]
19131931
else
19141932
[
@@ -1917,12 +1935,12 @@ function compile_mlir!(
19171935
"enzyme-batch",
19181936
opt_passes2,
19191937
enzyme_pass,
1920-
opt_passes2,
1938+
opt_passes3,
19211939
"canonicalize",
19221940
"remove-unnecessary-enzyme-ops",
19231941
"enzyme-simplify-math",
19241942
legalize_chlo_to_stablehlo...,
1925-
opt_passes2,
1943+
opt_passes3,
19261944
kern,
19271945
raise_passes,
19281946
]
@@ -1944,12 +1962,12 @@ function compile_mlir!(
19441962
"enzyme-batch",
19451963
opt_passes2,
19461964
enzyme_pass,
1947-
opt_passes2,
1965+
opt_passes3,
19481966
"canonicalize",
19491967
"remove-unnecessary-enzyme-ops",
19501968
"enzyme-simplify-math",
19511969
legalize_chlo_to_stablehlo...,
1952-
opt_passes2,
1970+
opt_passes3,
19531971
kern,
19541972
]
19551973
end,
@@ -1967,12 +1985,12 @@ function compile_mlir!(
19671985
"enzyme-batch",
19681986
opt_passes2,
19691987
enzyme_pass,
1970-
opt_passes2,
1988+
opt_passes3,
19711989
"canonicalize",
19721990
"remove-unnecessary-enzyme-ops",
19731991
"enzyme-simplify-math",
19741992
legalize_chlo_to_stablehlo...,
1975-
opt_passes2,
1993+
opt_passes3,
19761994
],
19771995
',',
19781996
),
@@ -2009,7 +2027,7 @@ function compile_mlir!(
20092027
"remove-unnecessary-enzyme-ops",
20102028
"enzyme-simplify-math",
20112029
legalize_chlo_to_stablehlo...,
2012-
opt_passes2,
2030+
opt_passes3,
20132031
lower_enzymexla_linalg_pass,
20142032
jit,
20152033
]
@@ -2022,7 +2040,7 @@ function compile_mlir!(
20222040
"remove-unnecessary-enzyme-ops",
20232041
"enzyme-simplify-math",
20242042
legalize_chlo_to_stablehlo...,
2025-
opt_passes2,
2043+
opt_passes3,
20262044
kern,
20272045
raise_passes,
20282046
lower_enzymexla_linalg_pass,
@@ -2233,7 +2251,7 @@ function compile_mlir!(
22332251
run_pass_pipeline!(
22342252
mod,
22352253
join(
2236-
[opt_passes, "canonicalize", "cse", "canonicalize", opt_passes2],
2254+
[opt_passes, "canonicalize", "cse", "canonicalize", opt_passes3],
22372255
",",
22382256
),
22392257
"mid_pad_opts",

0 commit comments

Comments
 (0)