|
29 | 29 |
|
30 | 30 | @test @jit(f3(A_ra, 1)) ≈ A .+ 1 |
31 | 31 | end |
| 32 | + |
| 33 | +# Auto-Batching |
| 34 | +function run_auto_batching_tests(f::F, args...) where {F} |
| 35 | + @testset "$(nameof(F))" begin |
| 36 | + @testset "Correctness" begin |
| 37 | + res1 = @jit f(args...) |
| 38 | + res2 = @jit compile_options = CompileOptions(; |
| 39 | + disable_auto_batching_passes=true |
| 40 | + ) f(args...) |
| 41 | + @test res1 ≈ res2 |
| 42 | + end |
| 43 | + |
| 44 | + @testset "No while loops" begin |
| 45 | + hlo = repr( |
| 46 | + @code_hlo compile_options = CompileOptions(; |
| 47 | + disable_auto_batching_passes=true |
| 48 | + ) f(args...) |
| 49 | + ) |
| 50 | + @test occursin("stablehlo.while", hlo) |
| 51 | + |
| 52 | + hlo = repr(@code_hlo f(args...)) |
| 53 | + @test !occursin("stablehlo.while", hlo) |
| 54 | + end |
| 55 | + end |
| 56 | +end |
| 57 | + |
| 58 | +function looped_reduction(y, x) |
| 59 | + z = copy(y) |
| 60 | + @trace for i in 1:size(x, 2) |
| 61 | + z[:, i, :] = dropdims(sum(abs2, x[:, i, :, :]; dims=3); dims=3) |
| 62 | + end |
| 63 | + return z |
| 64 | +end |
| 65 | + |
| 66 | +@testset "Loop of Reduces => Single Reduction" begin |
| 67 | + x = Reactant.to_rarray(rand(Float32, 3, 256, 5, 7)) |
| 68 | + y = Reactant.to_rarray(rand(Float32, 3, 260, 5)) |
| 69 | + |
| 70 | + run_auto_batching_tests(looped_reduction, y, x) |
| 71 | +end |
| 72 | + |
| 73 | +function naive_batched_matmul(x, y) |
| 74 | + @assert size(x, 3) == size(y, 3) |
| 75 | + z = similar(x, size(x, 1), size(y, 2), size(x, 3)) |
| 76 | + @trace for i in 1:size(x, 3) |
| 77 | + z[:, :, i] = x[:, :, i] * y[:, :, i] |
| 78 | + end |
| 79 | + return z |
| 80 | +end |
| 81 | + |
| 82 | +@testset "Naive Batched Matmul => Single Dot General" begin |
| 83 | + x = Reactant.to_rarray(rand(Float32, 3, 256, 5)) |
| 84 | + y = Reactant.to_rarray(rand(Float32, 256, 7, 5)) |
| 85 | + |
| 86 | + run_auto_batching_tests(naive_batched_matmul, x, y) |
| 87 | +end |
0 commit comments