Skip to content

Commit 1efd123

Browse files
authored
feat: enable auto-batching passes (#1799)
* feat: enable auto-batching passes * chore: bump reactant version
1 parent b2461ef commit 1efd123

File tree

3 files changed

+58
-1
lines changed

3 files changed

+58
-1
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "Reactant"
22
uuid = "3c362404-f566-11ee-1572-e11a4b42c853"
33
authors = ["William Moses <wmoses@mit.edu>", "Valentin Churavy <vchuravy@mit.edu>", "Sergio Sánchez Ramírez <sergio.sanchez.ramirez@bsc.es>", "Paul Berg <paul@plutojl.org>", "Avik Pal <avikpal@mit.edu>", "Mosè Giordano <mose@gnu.org>"]
4-
version = "0.2.172"
4+
version = "0.2.173"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

src/Compiler.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -937,6 +937,7 @@ function optimization_passes(
937937
"broadcastindim_slice_to_batch",
938938
"reducewindow_slice_to_batch",
939939
"elementwise_slice_to_batch",
940+
"greedy_while_loop_batch_fission",
940941
],
941942
)
942943
end

test/batching.jl

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,3 +29,59 @@ end
2929

3030
@test @jit(f3(A_ra, 1)) A .+ 1
3131
end
32+
33+
# Auto-Batching
34+
function run_auto_batching_tests(f::F, args...) where {F}
35+
@testset "$(nameof(F))" begin
36+
@testset "Correctness" begin
37+
res1 = @jit f(args...)
38+
res2 = @jit compile_options = CompileOptions(;
39+
disable_auto_batching_passes=true
40+
) f(args...)
41+
@test res1 res2
42+
end
43+
44+
@testset "No while loops" begin
45+
hlo = repr(
46+
@code_hlo compile_options = CompileOptions(;
47+
disable_auto_batching_passes=true
48+
) f(args...)
49+
)
50+
@test occursin("stablehlo.while", hlo)
51+
52+
hlo = repr(@code_hlo f(args...))
53+
@test !occursin("stablehlo.while", hlo)
54+
end
55+
end
56+
end
57+
58+
function looped_reduction(y, x)
59+
z = copy(y)
60+
@trace for i in 1:size(x, 2)
61+
z[:, i, :] = dropdims(sum(abs2, x[:, i, :, :]; dims=3); dims=3)
62+
end
63+
return z
64+
end
65+
66+
@testset "Loop of Reduces => Single Reduction" begin
67+
x = Reactant.to_rarray(rand(Float32, 3, 256, 5, 7))
68+
y = Reactant.to_rarray(rand(Float32, 3, 260, 5))
69+
70+
run_auto_batching_tests(looped_reduction, y, x)
71+
end
72+
73+
function naive_batched_matmul(x, y)
74+
@assert size(x, 3) == size(y, 3)
75+
z = similar(x, size(x, 1), size(y, 2), size(x, 3))
76+
@trace for i in 1:size(x, 3)
77+
z[:, :, i] = x[:, :, i] * y[:, :, i]
78+
end
79+
return z
80+
end
81+
82+
@testset "Naive Batched Matmul => Single Dot General" begin
83+
x = Reactant.to_rarray(rand(Float32, 3, 256, 5))
84+
y = Reactant.to_rarray(rand(Float32, 256, 7, 5))
85+
86+
run_auto_batching_tests(naive_batched_matmul, x, y)
87+
end

0 commit comments

Comments
 (0)