Skip to content

Commit ac38357

Browse files
committed
fix: kind of working
1 parent 4876110 commit ac38357

File tree

3 files changed

+46
-28
lines changed

3 files changed

+46
-28
lines changed

deps/ReactantExtra/WORKSPACE

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ NSYNC_COMMIT = "82b118aa7ace3132e517e2c467f8732978cf4023"
44

55
NSYNC_SHA256 = ""
66

7-
ENZYMEXLA_COMMIT = "0d94adbc3a182ea6dbdc9d4103022beb7f1d20b9"
7+
ENZYMEXLA_COMMIT = "e408511ec376befe19bd48a0e725732b322fce3b"
88

99
ENZYMEXLA_SHA256 = ""
1010

deps/ReactantExtra/make-bindings.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ for file in [
4242
"MPI.jl",
4343
"MemRef.jl",
4444
"SparseTensor.jl",
45-
"TritonExt.jl"
45+
"TritonExt.jl",
4646
]
4747
build_file(joinpath(src_dir, "mlir", "Dialects", file))
4848
end

src/Compiler.jl

Lines changed: 44 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -702,6 +702,7 @@ function optimization_passes(
702702
lower_comms::Bool=true,
703703
max_constant_threshold::Int=1024,
704704
backend::String="gpu",
705+
enable_triton_passes::Bool=false,
705706
)
706707
transform_passes_list = [
707708
"patterns=compare_op_canon<16>",
@@ -1291,7 +1292,7 @@ function optimization_passes(
12911292
push!(passes, "remove-duplicate-func-def")
12921293
end
12931294
push!(passes, func_passes)
1294-
if backend == "cuda"
1295+
if enable_triton_passes && backend == "cuda"
12951296
push!(passes, triton_optimization_passes())
12961297
end
12971298
return join(passes, ',')
@@ -1366,12 +1367,11 @@ function triton_optimization_passes()
13661367
"allocate-shared-memory",
13671368
"triton-tensor-memory-allocation",
13681369
"tritongpu-global-scratch-memory-allocation",
1369-
# TODO: register the commented out passes
1370-
# "convert-triton-gpu-to-llvm",
1370+
"convert-triton-gpu-to-llvm",
13711371
"canonicalize",
13721372
"cse",
1373-
# "convert-nv-gpu-to-llvm",
1374-
# "convert-warp-specialize-to-llvm",
1373+
"convert-nv-gpu-to-llvm",
1374+
"convert-warp-specialize-to-llvm",
13751375
"reconcile-unrealized-casts",
13761376
"canonicalize",
13771377
"cse",
@@ -1766,10 +1766,28 @@ function compile_mlir!(
17661766
end
17671767

17681768
opt_passes = optimization_passes(
1769-
compile_options; sroa=true, recognize_comms, lower_comms, backend
1769+
compile_options;
1770+
sroa=true,
1771+
recognize_comms,
1772+
lower_comms,
1773+
backend,
1774+
enable_triton_passes=false,
17701775
)
17711776
opt_passes2 = optimization_passes(
1772-
compile_options; sroa=false, recognize_comms, lower_comms, backend
1777+
compile_options;
1778+
sroa=false,
1779+
recognize_comms,
1780+
lower_comms,
1781+
backend,
1782+
enable_triton_passes=false,
1783+
)
1784+
opt_passes3 = optimization_passes(
1785+
compile_options;
1786+
sroa=false,
1787+
recognize_comms,
1788+
lower_comms,
1789+
backend,
1790+
enable_triton_passes=true,
17731791
)
17741792

17751793
raise_passes = if raise isa String
@@ -1784,15 +1802,15 @@ function compile_mlir!(
17841802
opt_passes2
17851803

17861804
if DUS_TO_CONCAT[]
1787-
opt_passes3 = optimization_passes(
1805+
opt_passes_dus_to_concat = optimization_passes(
17881806
compile_options;
17891807
sroa=false,
17901808
dus_to_concat=true,
17911809
recognize_comms,
17921810
lower_comms,
17931811
backend,
17941812
)
1795-
result = result * "," * opt_passes3
1813+
result = result * "," * opt_passes_dus_to_concat
17961814
end
17971815
result
17981816
else
@@ -1823,12 +1841,12 @@ function compile_mlir!(
18231841
"enzyme-batch",
18241842
opt_passes2,
18251843
enzyme_pass,
1826-
opt_passes2,
1844+
opt_passes3,
18271845
"canonicalize",
18281846
"remove-unnecessary-enzyme-ops",
18291847
"enzyme-simplify-math",
18301848
legalize_chlo_to_stablehlo...,
1831-
opt_passes2,
1849+
opt_passes3,
18321850
lower_enzymexla_linalg_pass,
18331851
jit,
18341852
]
@@ -1839,12 +1857,12 @@ function compile_mlir!(
18391857
"enzyme-batch",
18401858
opt_passes2,
18411859
enzyme_pass,
1842-
opt_passes2,
1860+
opt_passes3,
18431861
"canonicalize",
18441862
"remove-unnecessary-enzyme-ops",
18451863
"enzyme-simplify-math",
18461864
legalize_chlo_to_stablehlo...,
1847-
opt_passes2,
1865+
opt_passes3,
18481866
kern,
18491867
raise_passes,
18501868
lower_enzymexla_linalg_pass,
@@ -1868,12 +1886,12 @@ function compile_mlir!(
18681886
"enzyme-batch",
18691887
opt_passes2,
18701888
enzyme_pass,
1871-
opt_passes2,
1889+
opt_passes3,
18721890
"canonicalize",
18731891
"remove-unnecessary-enzyme-ops",
18741892
"enzyme-simplify-math",
18751893
legalize_chlo_to_stablehlo...,
1876-
opt_passes2,
1894+
opt_passes3,
18771895
]
18781896
end,
18791897
',',
@@ -1893,12 +1911,12 @@ function compile_mlir!(
18931911
"enzyme-batch",
18941912
opt_passes2,
18951913
enzyme_pass,
1896-
opt_passes2,
1914+
opt_passes3,
18971915
"canonicalize",
18981916
"remove-unnecessary-enzyme-ops",
18991917
"enzyme-simplify-math",
19001918
legalize_chlo_to_stablehlo...,
1901-
opt_passes2,
1919+
opt_passes3,
19021920
]
19031921
else
19041922
[
@@ -1907,12 +1925,12 @@ function compile_mlir!(
19071925
"enzyme-batch",
19081926
opt_passes2,
19091927
enzyme_pass,
1910-
opt_passes2,
1928+
opt_passes3,
19111929
"canonicalize",
19121930
"remove-unnecessary-enzyme-ops",
19131931
"enzyme-simplify-math",
19141932
legalize_chlo_to_stablehlo...,
1915-
opt_passes2,
1933+
opt_passes3,
19161934
kern,
19171935
raise_passes,
19181936
]
@@ -1934,12 +1952,12 @@ function compile_mlir!(
19341952
"enzyme-batch",
19351953
opt_passes2,
19361954
enzyme_pass,
1937-
opt_passes2,
1955+
opt_passes3,
19381956
"canonicalize",
19391957
"remove-unnecessary-enzyme-ops",
19401958
"enzyme-simplify-math",
19411959
legalize_chlo_to_stablehlo...,
1942-
opt_passes2,
1960+
opt_passes3,
19431961
kern,
19441962
]
19451963
end,
@@ -1957,12 +1975,12 @@ function compile_mlir!(
19571975
"enzyme-batch",
19581976
opt_passes2,
19591977
enzyme_pass,
1960-
opt_passes2,
1978+
opt_passes3,
19611979
"canonicalize",
19621980
"remove-unnecessary-enzyme-ops",
19631981
"enzyme-simplify-math",
19641982
legalize_chlo_to_stablehlo...,
1965-
opt_passes2,
1983+
opt_passes3,
19661984
],
19671985
',',
19681986
),
@@ -1999,7 +2017,7 @@ function compile_mlir!(
19992017
"remove-unnecessary-enzyme-ops",
20002018
"enzyme-simplify-math",
20012019
legalize_chlo_to_stablehlo...,
2002-
opt_passes2,
2020+
opt_passes3,
20032021
lower_enzymexla_linalg_pass,
20042022
jit,
20052023
]
@@ -2012,7 +2030,7 @@ function compile_mlir!(
20122030
"remove-unnecessary-enzyme-ops",
20132031
"enzyme-simplify-math",
20142032
legalize_chlo_to_stablehlo...,
2015-
opt_passes2,
2033+
opt_passes3,
20162034
kern,
20172035
raise_passes,
20182036
lower_enzymexla_linalg_pass,
@@ -2223,7 +2241,7 @@ function compile_mlir!(
22232241
run_pass_pipeline!(
22242242
mod,
22252243
join(
2226-
[opt_passes, "canonicalize", "cse", "canonicalize", opt_passes2],
2244+
[opt_passes, "canonicalize", "cse", "canonicalize", opt_passes3],
22272245
",",
22282246
),
22292247
"mid_pad_opts",

0 commit comments

Comments
 (0)