Skip to content

Commit e869448

Browse files
committed
fix: kind of working
1 parent 357f1c0 commit e869448

File tree

2 files changed

+45
-27
lines changed

2 files changed

+45
-27
lines changed

deps/ReactantExtra/make-bindings.jl

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -42,7 +42,7 @@ for file in [
4242
"MPI.jl",
4343
"MemRef.jl",
4444
"SparseTensor.jl",
45-
"TritonExt.jl"
45+
"TritonExt.jl",
4646
]
4747
build_file(joinpath(src_dir, "mlir", "Dialects", file))
4848
end

src/Compiler.jl

Lines changed: 44 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -702,6 +702,7 @@ function optimization_passes(
702702
lower_comms::Bool=true,
703703
max_constant_threshold::Int=1024,
704704
backend::String="gpu",
705+
enable_triton_passes::Bool=false,
705706
)
706707
transform_passes_list = [
707708
"patterns=compare_op_canon<16>",
@@ -1298,7 +1299,7 @@ function optimization_passes(
12981299
push!(passes, "remove-duplicate-func-def")
12991300
end
13001301
push!(passes, func_passes)
1301-
if backend == "cuda"
1302+
if enable_triton_passes && backend == "cuda"
13021303
push!(passes, triton_optimization_passes())
13031304
end
13041305
return join(passes, ',')
@@ -1373,12 +1374,11 @@ function triton_optimization_passes()
13731374
"allocate-shared-memory",
13741375
"triton-tensor-memory-allocation",
13751376
"tritongpu-global-scratch-memory-allocation",
1376-
# TODO: register the commented out passes
1377-
# "convert-triton-gpu-to-llvm",
1377+
"convert-triton-gpu-to-llvm",
13781378
"canonicalize",
13791379
"cse",
1380-
# "convert-nv-gpu-to-llvm",
1381-
# "convert-warp-specialize-to-llvm",
1380+
"convert-nv-gpu-to-llvm",
1381+
"convert-warp-specialize-to-llvm",
13821382
"reconcile-unrealized-casts",
13831383
"canonicalize",
13841384
"cse",
@@ -1774,10 +1774,28 @@ function compile_mlir!(
17741774
end
17751775

17761776
opt_passes = optimization_passes(
1777-
compile_options; sroa=true, recognize_comms, lower_comms, backend
1777+
compile_options;
1778+
sroa=true,
1779+
recognize_comms,
1780+
lower_comms,
1781+
backend,
1782+
enable_triton_passes=false,
17781783
)
17791784
opt_passes2 = optimization_passes(
1780-
compile_options; sroa=false, recognize_comms, lower_comms, backend
1785+
compile_options;
1786+
sroa=false,
1787+
recognize_comms,
1788+
lower_comms,
1789+
backend,
1790+
enable_triton_passes=false,
1791+
)
1792+
opt_passes3 = optimization_passes(
1793+
compile_options;
1794+
sroa=false,
1795+
recognize_comms,
1796+
lower_comms,
1797+
backend,
1798+
enable_triton_passes=true,
17811799
)
17821800

17831801
raise_passes = if raise isa String
@@ -1792,15 +1810,15 @@ function compile_mlir!(
17921810
opt_passes2
17931811

17941812
if DUS_TO_CONCAT[]
1795-
opt_passes3 = optimization_passes(
1813+
opt_passes_dus_to_concat = optimization_passes(
17961814
compile_options;
17971815
sroa=false,
17981816
dus_to_concat=true,
17991817
recognize_comms,
18001818
lower_comms,
18011819
backend,
18021820
)
1803-
result = result * "," * opt_passes3
1821+
result = result * "," * opt_passes_dus_to_concat
18041822
end
18051823
result
18061824
else
@@ -1831,12 +1849,12 @@ function compile_mlir!(
18311849
"enzyme-batch",
18321850
opt_passes2,
18331851
enzyme_pass,
1834-
opt_passes2,
1852+
opt_passes3,
18351853
"canonicalize",
18361854
"remove-unnecessary-enzyme-ops",
18371855
"enzyme-simplify-math",
18381856
legalize_chlo_to_stablehlo...,
1839-
opt_passes2,
1857+
opt_passes3,
18401858
lower_enzymexla_linalg_pass,
18411859
jit,
18421860
]
@@ -1847,12 +1865,12 @@ function compile_mlir!(
18471865
"enzyme-batch",
18481866
opt_passes2,
18491867
enzyme_pass,
1850-
opt_passes2,
1868+
opt_passes3,
18511869
"canonicalize",
18521870
"remove-unnecessary-enzyme-ops",
18531871
"enzyme-simplify-math",
18541872
legalize_chlo_to_stablehlo...,
1855-
opt_passes2,
1873+
opt_passes3,
18561874
kern,
18571875
raise_passes,
18581876
lower_enzymexla_linalg_pass,
@@ -1876,12 +1894,12 @@ function compile_mlir!(
18761894
"enzyme-batch",
18771895
opt_passes2,
18781896
enzyme_pass,
1879-
opt_passes2,
1897+
opt_passes3,
18801898
"canonicalize",
18811899
"remove-unnecessary-enzyme-ops",
18821900
"enzyme-simplify-math",
18831901
legalize_chlo_to_stablehlo...,
1884-
opt_passes2,
1902+
opt_passes3,
18851903
]
18861904
end,
18871905
',',
@@ -1901,12 +1919,12 @@ function compile_mlir!(
19011919
"enzyme-batch",
19021920
opt_passes2,
19031921
enzyme_pass,
1904-
opt_passes2,
1922+
opt_passes3,
19051923
"canonicalize",
19061924
"remove-unnecessary-enzyme-ops",
19071925
"enzyme-simplify-math",
19081926
legalize_chlo_to_stablehlo...,
1909-
opt_passes2,
1927+
opt_passes3,
19101928
]
19111929
else
19121930
[
@@ -1915,12 +1933,12 @@ function compile_mlir!(
19151933
"enzyme-batch",
19161934
opt_passes2,
19171935
enzyme_pass,
1918-
opt_passes2,
1936+
opt_passes3,
19191937
"canonicalize",
19201938
"remove-unnecessary-enzyme-ops",
19211939
"enzyme-simplify-math",
19221940
legalize_chlo_to_stablehlo...,
1923-
opt_passes2,
1941+
opt_passes3,
19241942
kern,
19251943
raise_passes,
19261944
]
@@ -1942,12 +1960,12 @@ function compile_mlir!(
19421960
"enzyme-batch",
19431961
opt_passes2,
19441962
enzyme_pass,
1945-
opt_passes2,
1963+
opt_passes3,
19461964
"canonicalize",
19471965
"remove-unnecessary-enzyme-ops",
19481966
"enzyme-simplify-math",
19491967
legalize_chlo_to_stablehlo...,
1950-
opt_passes2,
1968+
opt_passes3,
19511969
kern,
19521970
]
19531971
end,
@@ -1965,12 +1983,12 @@ function compile_mlir!(
19651983
"enzyme-batch",
19661984
opt_passes2,
19671985
enzyme_pass,
1968-
opt_passes2,
1986+
opt_passes3,
19691987
"canonicalize",
19701988
"remove-unnecessary-enzyme-ops",
19711989
"enzyme-simplify-math",
19721990
legalize_chlo_to_stablehlo...,
1973-
opt_passes2,
1991+
opt_passes3,
19741992
],
19751993
',',
19761994
),
@@ -2007,7 +2025,7 @@ function compile_mlir!(
20072025
"remove-unnecessary-enzyme-ops",
20082026
"enzyme-simplify-math",
20092027
legalize_chlo_to_stablehlo...,
2010-
opt_passes2,
2028+
opt_passes3,
20112029
lower_enzymexla_linalg_pass,
20122030
jit,
20132031
]
@@ -2020,7 +2038,7 @@ function compile_mlir!(
20202038
"remove-unnecessary-enzyme-ops",
20212039
"enzyme-simplify-math",
20222040
legalize_chlo_to_stablehlo...,
2023-
opt_passes2,
2041+
opt_passes3,
20242042
kern,
20252043
raise_passes,
20262044
lower_enzymexla_linalg_pass,
@@ -2231,7 +2249,7 @@ function compile_mlir!(
22312249
run_pass_pipeline!(
22322250
mod,
22332251
join(
2234-
[opt_passes, "canonicalize", "cse", "canonicalize", opt_passes2],
2252+
[opt_passes, "canonicalize", "cse", "canonicalize", opt_passes3],
22352253
",",
22362254
),
22372255
"mid_pad_opts",

0 commit comments

Comments
 (0)