Skip to content

Commit 9528846

Browse files
authored
feat: use the generalized passes from upcoming enzymejax (#1850)
* feat: use the generalized passes from upcoming enzymejax [skip ci] * chore: bump jll version * test: update * feat: new passes
1 parent 481a91c commit 9528846

File tree

4 files changed

+15
-33
lines changed

4 files changed

+15
-33
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "Reactant"
22
uuid = "3c362404-f566-11ee-1572-e11a4b42c853"
33
authors = ["William Moses <wmoses@mit.edu>", "Valentin Churavy <vchuravy@mit.edu>", "Sergio Sánchez Ramírez <sergio.sanchez.ramirez@bsc.es>", "Paul Berg <paul@plutojl.org>", "Avik Pal <avikpal@mit.edu>", "Mosè Giordano <mose@gnu.org>"]
4-
version = "0.2.177"
4+
version = "0.2.178"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
@@ -105,7 +105,7 @@ PythonCall = "0.9.25"
105105
Random = "1.10"
106106
Random123 = "1.7"
107107
ReactantCore = "0.1.16"
108-
Reactant_jll = "0.0.259"
108+
Reactant_jll = "0.0.261"
109109
ScopedValues = "1.3.0"
110110
Scratch = "1.2"
111111
Sockets = "1.10"

src/CompileOptions.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,11 @@ communication.
1212
wrap_comm::Int = 0
1313
extend_comm::Int = 0
1414
dus_to_pad_manual_comp_comm::Int = 0 # 2
15-
dus_to_pad_comm::Int = 1
15+
dus_to_pad_comm::Int = 0
1616
concat_two_operands_comm::Int = 0
1717
concat_to_pad_comm::Int = 1
18-
extend_to_pad_comm::Int = 1
18+
extend_to_pad_comm::Int = 0
19+
extend_to_pad_comm2::Int = 1
1920
wrap_to_pad_comm::Int = 1
2021
end
2122

src/Compiler.jl

Lines changed: 8 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -759,6 +759,9 @@ function optimization_passes(
759759
"noop_slice<16>",
760760
"noop_reverse<16>",
761761
"slice_slice<16>",
762+
"dynamic_slice_slice<16>",
763+
"slice_dynamic_slice<16>",
764+
"dynamic_slice_dynamic_slice<16>",
762765
"shift_right_logical_simplify<16>",
763766
"slice_simplify<16>",
764767
"convert_simplify<16>",
@@ -816,20 +819,7 @@ function optimization_passes(
816819
"dus_dus",
817820
"dus_dus_concat",
818821
"abs_positive_simplify",
819-
"transpose_unary_transpose_abs",
820-
"transpose_unary_transpose_neg",
821-
"transpose_unary_transpose_sqrt",
822-
"transpose_unary_transpose_rsqrt",
823-
"transpose_unary_transpose_ceil",
824-
"transpose_unary_transpose_convert",
825-
"transpose_unary_transpose_cosine",
826-
"transpose_unary_transpose_exp",
827-
"transpose_unary_transpose_expm1",
828-
"transpose_unary_transpose_log",
829-
"transpose_unary_transpose_log1p",
830-
"transpose_unary_transpose_sign",
831-
"transpose_unary_transpose_sine",
832-
"transpose_unary_transpose_tanh",
822+
"transpose_elementwise_transpose",
833823
"select_comp_iota_const_simplify<1>",
834824
"sign_abs_simplify<1>",
835825
"broadcastindim_is_reshape",
@@ -1147,6 +1137,9 @@ function optimization_passes(
11471137
"concat_appending_reshape",
11481138
"slice_reshape",
11491139
"slice_reshape_slice<1>",
1140+
"dynamic_slice_reshape_slice<1>",
1141+
"slice_reshape_dynamic_slice<1>",
1142+
"dynamic_slice_reshape_dynamic_slice<1>",
11501143
"slice_reshape_concat<1>",
11511144
"slice_reshape_elementwise<1>",
11521145
"slice_reshape_dot_general<1>",
@@ -1193,17 +1186,7 @@ function optimization_passes(
11931186
transform_passes_list,
11941187
[
11951188
"reorder_elementwise_and_shape_op<16>",
1196-
"binary_op_transpose_simplify_add",
1197-
"binary_op_transpose_simplify_sub",
1198-
"binary_op_transpose_simplify_mul",
1199-
"binary_op_transpose_simplify_div",
1200-
"binary_op_transpose_simplify_min",
1201-
"binary_op_transpose_simplify_max",
1202-
"binary_op_transpose_simplify_pow",
1203-
"binary_op_transpose_simplify_rem",
1204-
"binary_op_transpose_simplify_or",
1205-
"binary_op_transpose_simplify_and",
1206-
"binary_op_transpose_simplify_xor",
1189+
"elementwise_all_transpose_operands_simplify",
12071190
"slice_transpose",
12081191
"einsum_transpose<1>",
12091192
"slice_reshape_transpose<1>",

test/optimize_comm.jl

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -100,10 +100,8 @@ if length(addressable_devices) ≥ 8
100100

101101
hlo = repr(@code_xla shardy_passes = :to_mhlo_shardings dus2(rx, ry))
102102
@test !contains(hlo, "all-to-all")
103-
@test !contains(hlo, "all-gather") broken =
104-
Reactant.XLA.REACTANT_XLA_RUNTIME == "PJRT"
105-
@test contains(hlo, "collective-permute") broken =
106-
Reactant.XLA.REACTANT_XLA_RUNTIME == "PJRT"
103+
@test !contains(hlo, "all-gather")
104+
@test contains(hlo, "collective-permute")
107105

108106
dus2(x, y)
109107
@jit shardy_passes = :to_mhlo_shardings dus2(rx, ry)

0 commit comments

Comments
 (0)