Skip to content

Commit 97c952b

Browse files
committed
fix: kind of working
1 parent a3b8cb6 commit 97c952b

File tree

2 files changed

+45
-27
lines changed

2 files changed

+45
-27
lines changed

deps/ReactantExtra/make-bindings.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ for file in [
4242
"MPI.jl",
4343
"MemRef.jl",
4444
"SparseTensor.jl",
45-
"TritonExt.jl"
45+
"TritonExt.jl",
4646
]
4747
build_file(joinpath(src_dir, "mlir", "Dialects", file))
4848
end

src/Compiler.jl

Lines changed: 44 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -702,6 +702,7 @@ function optimization_passes(
702702
lower_comms::Bool=true,
703703
max_constant_threshold::Int=1024,
704704
backend::String="gpu",
705+
enable_triton_passes::Bool=false,
705706
)
706707
transform_passes_list = [
707708
"patterns=compare_op_canon<16>",
@@ -1300,7 +1301,7 @@ function optimization_passes(
13001301
push!(passes, "remove-duplicate-func-def")
13011302
end
13021303
push!(passes, func_passes)
1303-
if backend == "cuda"
1304+
if enable_triton_passes && backend == "cuda"
13041305
push!(passes, triton_optimization_passes())
13051306
end
13061307
return join(passes, ',')
@@ -1375,12 +1376,11 @@ function triton_optimization_passes()
13751376
"allocate-shared-memory",
13761377
"triton-tensor-memory-allocation",
13771378
"tritongpu-global-scratch-memory-allocation",
1378-
# TODO: register the commented out passes
1379-
# "convert-triton-gpu-to-llvm",
1379+
"convert-triton-gpu-to-llvm",
13801380
"canonicalize",
13811381
"cse",
1382-
# "convert-nv-gpu-to-llvm",
1383-
# "convert-warp-specialize-to-llvm",
1382+
"convert-nv-gpu-to-llvm",
1383+
"convert-warp-specialize-to-llvm",
13841384
"reconcile-unrealized-casts",
13851385
"canonicalize",
13861386
"cse",
@@ -1781,10 +1781,28 @@ function compile_mlir!(
17811781
end
17821782

17831783
opt_passes = optimization_passes(
1784-
compile_options; sroa=true, recognize_comms, lower_comms, backend
1784+
compile_options;
1785+
sroa=true,
1786+
recognize_comms,
1787+
lower_comms,
1788+
backend,
1789+
enable_triton_passes=false,
17851790
)
17861791
opt_passes2 = optimization_passes(
1787-
compile_options; sroa=false, recognize_comms, lower_comms, backend
1792+
compile_options;
1793+
sroa=false,
1794+
recognize_comms,
1795+
lower_comms,
1796+
backend,
1797+
enable_triton_passes=false,
1798+
)
1799+
opt_passes3 = optimization_passes(
1800+
compile_options;
1801+
sroa=false,
1802+
recognize_comms,
1803+
lower_comms,
1804+
backend,
1805+
enable_triton_passes=true,
17881806
)
17891807

17901808
raise_passes = if raise isa String
@@ -1799,15 +1817,15 @@ function compile_mlir!(
17991817
opt_passes2
18001818

18011819
if DUS_TO_CONCAT[]
1802-
opt_passes3 = optimization_passes(
1820+
opt_passes_dus_to_concat = optimization_passes(
18031821
compile_options;
18041822
sroa=false,
18051823
dus_to_concat=true,
18061824
recognize_comms,
18071825
lower_comms,
18081826
backend,
18091827
)
1810-
result = result * "," * opt_passes3
1828+
result = result * "," * opt_passes_dus_to_concat
18111829
end
18121830
result
18131831
else
@@ -1838,12 +1856,12 @@ function compile_mlir!(
18381856
"enzyme-batch",
18391857
opt_passes2,
18401858
enzyme_pass,
1841-
opt_passes2,
1859+
opt_passes3,
18421860
"canonicalize",
18431861
"remove-unnecessary-enzyme-ops",
18441862
"enzyme-simplify-math",
18451863
legalize_chlo_to_stablehlo...,
1846-
opt_passes2,
1864+
opt_passes3,
18471865
lower_enzymexla_linalg_pass,
18481866
jit,
18491867
]
@@ -1854,12 +1872,12 @@ function compile_mlir!(
18541872
"enzyme-batch",
18551873
opt_passes2,
18561874
enzyme_pass,
1857-
opt_passes2,
1875+
opt_passes3,
18581876
"canonicalize",
18591877
"remove-unnecessary-enzyme-ops",
18601878
"enzyme-simplify-math",
18611879
legalize_chlo_to_stablehlo...,
1862-
opt_passes2,
1880+
opt_passes3,
18631881
kern,
18641882
raise_passes,
18651883
lower_enzymexla_linalg_pass,
@@ -1883,12 +1901,12 @@ function compile_mlir!(
18831901
"enzyme-batch",
18841902
opt_passes2,
18851903
enzyme_pass,
1886-
opt_passes2,
1904+
opt_passes3,
18871905
"canonicalize",
18881906
"remove-unnecessary-enzyme-ops",
18891907
"enzyme-simplify-math",
18901908
legalize_chlo_to_stablehlo...,
1891-
opt_passes2,
1909+
opt_passes3,
18921910
]
18931911
end,
18941912
',',
@@ -1908,12 +1926,12 @@ function compile_mlir!(
19081926
"enzyme-batch",
19091927
opt_passes2,
19101928
enzyme_pass,
1911-
opt_passes2,
1929+
opt_passes3,
19121930
"canonicalize",
19131931
"remove-unnecessary-enzyme-ops",
19141932
"enzyme-simplify-math",
19151933
legalize_chlo_to_stablehlo...,
1916-
opt_passes2,
1934+
opt_passes3,
19171935
]
19181936
else
19191937
[
@@ -1922,12 +1940,12 @@ function compile_mlir!(
19221940
"enzyme-batch",
19231941
opt_passes2,
19241942
enzyme_pass,
1925-
opt_passes2,
1943+
opt_passes3,
19261944
"canonicalize",
19271945
"remove-unnecessary-enzyme-ops",
19281946
"enzyme-simplify-math",
19291947
legalize_chlo_to_stablehlo...,
1930-
opt_passes2,
1948+
opt_passes3,
19311949
kern,
19321950
raise_passes,
19331951
]
@@ -1949,12 +1967,12 @@ function compile_mlir!(
19491967
"enzyme-batch",
19501968
opt_passes2,
19511969
enzyme_pass,
1952-
opt_passes2,
1970+
opt_passes3,
19531971
"canonicalize",
19541972
"remove-unnecessary-enzyme-ops",
19551973
"enzyme-simplify-math",
19561974
legalize_chlo_to_stablehlo...,
1957-
opt_passes2,
1975+
opt_passes3,
19581976
kern,
19591977
]
19601978
end,
@@ -1972,12 +1990,12 @@ function compile_mlir!(
19721990
"enzyme-batch",
19731991
opt_passes2,
19741992
enzyme_pass,
1975-
opt_passes2,
1993+
opt_passes3,
19761994
"canonicalize",
19771995
"remove-unnecessary-enzyme-ops",
19781996
"enzyme-simplify-math",
19791997
legalize_chlo_to_stablehlo...,
1980-
opt_passes2,
1998+
opt_passes3,
19811999
],
19822000
',',
19832001
),
@@ -2014,7 +2032,7 @@ function compile_mlir!(
20142032
"remove-unnecessary-enzyme-ops",
20152033
"enzyme-simplify-math",
20162034
legalize_chlo_to_stablehlo...,
2017-
opt_passes2,
2035+
opt_passes3,
20182036
lower_enzymexla_linalg_pass,
20192037
jit,
20202038
]
@@ -2027,7 +2045,7 @@ function compile_mlir!(
20272045
"remove-unnecessary-enzyme-ops",
20282046
"enzyme-simplify-math",
20292047
legalize_chlo_to_stablehlo...,
2030-
opt_passes2,
2048+
opt_passes3,
20312049
kern,
20322050
raise_passes,
20332051
lower_enzymexla_linalg_pass,
@@ -2238,7 +2256,7 @@ function compile_mlir!(
22382256
run_pass_pipeline!(
22392257
mod,
22402258
join(
2241-
[opt_passes, "canonicalize", "cse", "canonicalize", opt_passes2],
2259+
[opt_passes, "canonicalize", "cse", "canonicalize", opt_passes3],
22422260
",",
22432261
),
22442262
"mid_pad_opts",

0 commit comments

Comments
 (0)