From 9ff1b0f0ec2064289b903157e4c7d9436c36962a Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 4 Nov 2025 07:30:55 -0600 Subject: [PATCH 1/4] feat: enable new licm + loop unroll passes + inlining --- Project.toml | 2 +- src/Compiler.jl | 5 +++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 233dc22db8..16a71a180c 100644 --- a/Project.toml +++ b/Project.toml @@ -105,7 +105,7 @@ PythonCall = "0.9.25" Random = "1.10" Random123 = "1.7" ReactantCore = "0.1.16" -Reactant_jll = "0.0.256" +Reactant_jll = "0.0.257" ScopedValues = "1.3.0" Scratch = "1.2" Sockets = "1.10" diff --git a/src/Compiler.jl b/src/Compiler.jl index bdbbf230f2..8e3b90bb5b 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -693,6 +693,7 @@ const AGGRESSIVE_SUM_TO_CONV = Ref(false) const AGGRESSIVE_PROPAGATION = Ref(false) const DUS_SLICE_SIMPLIFY = Ref(true) const CONCATS_TO_DUS = Ref(false) +const WHILE_UNROLL_THRESHOLD = Ref(5) # Optimization passes via transform dialect function optimization_passes( @@ -912,6 +913,7 @@ function optimization_passes( "while_is_copy_simplify", "split_variadic_scatter_op", "dynamic_slice_simplify", + "enzyme_hlo_unroll($(WHILE_UNROLL_THRESHOLD[]))", ] if !compile_options.disable_auto_batching_passes @@ -955,6 +957,9 @@ function optimization_passes( "transpose_licm(0)", "broadcastindim_licm(0)", "reshape_licm(0)", + "dot_general_licm(0)", + "reduce_licm(0)", + "reduce_window_licm(0)", ], ) end From 585a37b5048d4420428e9cfc318289ebb8e3ad7c Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 4 Nov 2025 09:16:26 -0600 Subject: [PATCH 2/4] test: increase batch size to > unroll limit --- test/batching.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/batching.jl b/test/batching.jl index 895bb02eb8..13298a5394 100644 --- a/test/batching.jl +++ b/test/batching.jl @@ -80,8 +80,8 @@ function naive_batched_matmul(x, y) end @testset "Naive Batched Matmul => Single Dot General" begin - x = Reactant.to_rarray(Reactant.TestUtils.construct_test_array(Float32, 3, 256, 5)) - y = Reactant.to_rarray(Reactant.TestUtils.construct_test_array(Float32, 256, 7, 5)) + x = Reactant.to_rarray(Reactant.TestUtils.construct_test_array(Float32, 3, 256, 8)) + y = Reactant.to_rarray(Reactant.TestUtils.construct_test_array(Float32, 256, 7, 8)) run_auto_batching_tests(naive_batched_matmul, x, y) end From 260e1cb6c8f1f203d0ebae172b6b0f4b3588bef1 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 3 Nov 2025 15:52:49 -0600 Subject: [PATCH 3/4] fix: correct type for ad annotation --- src/Enzyme.jl | 10 ++++++++++ src/Overlay.jl | 12 ++++++++++++ test/autodiff.jl | 12 ++++++++++++ 3 files changed, 34 insertions(+) diff --git a/src/Enzyme.jl b/src/Enzyme.jl index 7cf978ff9a..79c32361c7 100644 --- a/src/Enzyme.jl +++ b/src/Enzyme.jl @@ -295,6 +295,16 @@ function act_attr(val) return MLIR.IR.Attribute(val) end +function overload_autodiff( + mode::CMode, f::FA, args::Vararg{Annotation,Nargs} +) where {CMode<:Mode,FA<:Annotation,Nargs} + # need to guess the correct activity here. Execute the function, we will DCE it + res = call_with_reactant(f.val, [x.val for x in args]...) + return overload_autodiff( + mode, f, Enzyme.guess_activity(Core.Typeof(res), mode), args... + ) +end + function overload_autodiff( ::CMode, f::FA, ::Type{A}, args::Vararg{Annotation,Nargs} ) where {CMode<:Mode,FA<:Annotation,A<:Annotation,Nargs} diff --git a/src/Overlay.jl b/src/Overlay.jl index 143e8a4a45..1abbed7b7f 100644 --- a/src/Overlay.jl +++ b/src/Overlay.jl @@ -9,6 +9,18 @@ end # Enzyme.jl overlays +@reactant_overlay @noinline function Enzyme.autodiff_deferred( + rmode::Enzyme.Mode, f::FA, args::Vararg{Annotation,Nargs} +) where {FA<:Annotation,Nargs} + return overload_autodiff(rmode, f, args...) +end + +@reactant_overlay @noinline function Enzyme.autodiff( + rmode::Enzyme.Mode, f::FA, args::Vararg{Annotation,Nargs} +) where {FA<:Annotation,Nargs} + return overload_autodiff(rmode, f, args...) +end + @reactant_overlay @noinline function Enzyme.autodiff_deferred( rmode::Enzyme.Mode, f::FA, rt::Type{A}, args::Vararg{Annotation,Nargs} ) where {FA<:Annotation,A<:Annotation,Nargs} diff --git a/test/autodiff.jl b/test/autodiff.jl index 4629cb5d62..ea2725ceef 100644 --- a/test/autodiff.jl +++ b/test/autodiff.jl @@ -2,6 +2,8 @@ using Enzyme, Reactant, Test, Random square(x) = x * 2 +sum_without_activity(x) = sum(abs2, x; dims=(2, 3)) + fwd(Mode, RT, x, y) = Enzyme.autodiff(Mode, square, RT, Duplicated(x, y)) @testset "Activity" begin @@ -39,6 +41,16 @@ fwd(Mode, RT, x, y) = Enzyme.autodiff(Mode, square, RT, Duplicated(x, y)) Enzyme.Duplicated end +@testset "Correct Activity Guess" begin + x = Reactant.to_rarray(Reactant.TestUtils.construct_test_array(Float32, 3, 4, 5, 6)) + bx = Reactant.to_rarray(Reactant.TestUtils.construct_test_array(Float32, 3, 4, 5, 6)) + + res = only(@jit(Enzyme.autodiff(Forward, sum_without_activity, Duplicated(x, bx)))) + @test res isa Reactant.ConcreteRArray{Float32,4} + @test size(res) == (3, 1, 1, 6) + @test 2 .* sum_without_activity(bx) ≈ res +end + @testset "Basic Forward Mode" begin res1 = @jit( fwd( From 5ab6391dd024c1c88110d82907b1680eef68db8e Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 3 Nov 2025 15:56:05 -0600 Subject: [PATCH 4/4] fix: deepcopy :P --- src/Enzyme.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/Enzyme.jl b/src/Enzyme.jl index 79c32361c7..c992cc9bcd 100644 --- a/src/Enzyme.jl +++ b/src/Enzyme.jl @@ -299,7 +299,8 @@ function overload_autodiff( mode::CMode, f::FA, args::Vararg{Annotation,Nargs} ) where {CMode<:Mode,FA<:Annotation,Nargs} # need to guess the correct activity here. Execute the function, we will DCE it - res = call_with_reactant(f.val, [x.val for x in args]...) + # XXX: DONT MERGE + res = call_with_reactant(deepcopy(f.val), [deepcopy(x.val) for x in args]...) return overload_autodiff( mode, f, Enzyme.guess_activity(Core.Typeof(res), mode), args... )