Patch notes: this unified diff was restored to one record per line. The code added in
src/Compiler.jl, src/Enzyme.jl, src/Overlay.jl and test/autodiff.jl was revised, so the
post-image blob ids on those files' `index` lines are stale (they are not needed to apply
the patch). Context-line indentation was reconstructed and should be checked on apply.

diff --git a/Project.toml b/Project.toml
index 233dc22db8..16a71a180c 100644
--- a/Project.toml
+++ b/Project.toml
@@ -105,7 +105,7 @@ PythonCall = "0.9.25"
 Random = "1.10"
 Random123 = "1.7"
 ReactantCore = "0.1.16"
-Reactant_jll = "0.0.256"
+Reactant_jll = "0.0.257"
 ScopedValues = "1.3.0"
 Scratch = "1.2"
 Sockets = "1.10"
diff --git a/src/Compiler.jl b/src/Compiler.jl
index bdbbf230f2..8e3b90bb5b 100644
--- a/src/Compiler.jl
+++ b/src/Compiler.jl
@@ -693,6 +693,8 @@ const AGGRESSIVE_SUM_TO_CONV = Ref(false)
 const AGGRESSIVE_PROPAGATION = Ref(false)
 const DUS_SLICE_SIMPLIFY = Ref(true)
 const CONCATS_TO_DUS = Ref(false)
+# Argument for the `enzyme_hlo_unroll` pass (presumably a trip-count limit; confirm).
+const WHILE_UNROLL_THRESHOLD = Ref(5)
 
 # Optimization passes via transform dialect
 function optimization_passes(
@@ -912,6 +914,7 @@ function optimization_passes(
         "while_is_copy_simplify",
         "split_variadic_scatter_op",
         "dynamic_slice_simplify",
+        "enzyme_hlo_unroll($(WHILE_UNROLL_THRESHOLD[]))",
     ]
 
     if !compile_options.disable_auto_batching_passes
@@ -955,6 +958,9 @@ function optimization_passes(
                 "transpose_licm(0)",
                 "broadcastindim_licm(0)",
                 "reshape_licm(0)",
+                "dot_general_licm(0)",
+                "reduce_licm(0)",
+                "reduce_window_licm(0)",
             ],
         )
     end
diff --git a/src/Enzyme.jl b/src/Enzyme.jl
index 7cf978ff9a..c992cc9bcd 100644
--- a/src/Enzyme.jl
+++ b/src/Enzyme.jl
@@ -295,6 +295,31 @@ function act_attr(val)
     return MLIR.IR.Attribute(val)
 end
 
+"""
+    overload_autodiff(mode, f, args...)
+
+Differentiate `f` when no return-activity annotation was supplied.
+
+The function is first called through `call_with_reactant` on deep copies of `f.val` and
+of each `x.val`; the type of that result is passed to `Enzyme.guess_activity` and the
+call is then forwarded to the `overload_autodiff` method that takes the activity.
+"""
+function overload_autodiff(
+    mode::CMode, f::FA, args::Vararg{Annotation,Nargs}
+) where {CMode<:Mode,FA<:Annotation,Nargs}
+    # need to guess the correct activity here. Execute the function, we will DCE it
+    # XXX: DONT MERGE
+    # NOTE(review): the marker above is the author's; presumably this probing strategy
+    # is provisional -- confirm before landing.
+    probe_f = deepcopy(f.val)
+    # `map` over the `args` tuple returns a tuple, so the splat below keeps a statically
+    # known arity and concrete element types (a comprehension would build a `Vector`).
+    probe_args = map(x -> deepcopy(x.val), args)
+    res = call_with_reactant(probe_f, probe_args...)
+    activity = Enzyme.guess_activity(Core.Typeof(res), mode)
+    return overload_autodiff(mode, f, activity, args...)
+end
+
 function overload_autodiff(
     ::CMode, f::FA, ::Type{A}, args::Vararg{Annotation,Nargs}
 ) where {CMode<:Mode,FA<:Annotation,A<:Annotation,Nargs}
diff --git a/src/Overlay.jl b/src/Overlay.jl
index 143e8a4a45..1abbed7b7f 100644
--- a/src/Overlay.jl
+++ b/src/Overlay.jl
@@ -9,6 +9,19 @@ end
 
 # Enzyme.jl overlays
 
+# Variants without a return-activity type: both forward to `overload_autodiff(rmode, f, args...)`.
+@reactant_overlay @noinline function Enzyme.autodiff_deferred(
+    rmode::Enzyme.Mode, f::FA, args::Vararg{Annotation,Nargs}
+) where {FA<:Annotation,Nargs}
+    return overload_autodiff(rmode, f, args...)
+end
+
+@reactant_overlay @noinline function Enzyme.autodiff(
+    rmode::Enzyme.Mode, f::FA, args::Vararg{Annotation,Nargs}
+) where {FA<:Annotation,Nargs}
+    return overload_autodiff(rmode, f, args...)
+end
+
 @reactant_overlay @noinline function Enzyme.autodiff_deferred(
     rmode::Enzyme.Mode, f::FA, rt::Type{A}, args::Vararg{Annotation,Nargs}
 ) where {FA<:Annotation,A<:Annotation,Nargs}
diff --git a/test/autodiff.jl b/test/autodiff.jl
index 4629cb5d62..ea2725ceef 100644
--- a/test/autodiff.jl
+++ b/test/autodiff.jl
@@ -2,6 +2,8 @@ using Enzyme, Reactant, Test, Random
 
 square(x) = x * 2
 
+sum_without_activity(x) = sum(abs2, x; dims=(2, 3))
+
 fwd(Mode, RT, x, y) = Enzyme.autodiff(Mode, square, RT, Duplicated(x, y))
 
 @testset "Activity" begin
@@ -39,6 +41,20 @@ fwd(Mode, RT, x, y) = Enzyme.autodiff(Mode, square, RT, Duplicated(x, y))
         Enzyme.Duplicated
 end
 
+@testset "Correct Activity Guess" begin
+    x_host = Reactant.TestUtils.construct_test_array(Float32, 3, 4, 5, 6)
+    bx_host = Reactant.TestUtils.construct_test_array(Float32, 3, 4, 5, 6)
+    x = Reactant.to_rarray(x_host)
+    bx = Reactant.to_rarray(bx_host)
+
+    res = only(@jit(Enzyme.autodiff(Forward, sum_without_activity, Duplicated(x, bx))))
+    @test res isa Reactant.ConcreteRArray{Float32,4}
+    @test size(res) == (3, 1, 1, 6)
+    # Forward-mode tangent of `sum(abs2, x; dims=(2, 3))` in direction `bx`; written with
+    # both `x` and `bx` so the check does not depend on the two inputs being equal.
+    @test Array(res) ≈ sum(2 .* x_host .* bx_host; dims=(2, 3))
+end
+
 @testset "Basic Forward Mode" begin
     res1 = @jit(
         fwd(
diff --git a/test/batching.jl b/test/batching.jl
index 895bb02eb8..13298a5394 100644
--- a/test/batching.jl
+++ b/test/batching.jl
@@ -80,8 +80,8 @@ function naive_batched_matmul(x, y)
 end
 
 @testset "Naive Batched Matmul => Single Dot General" begin
-    x = Reactant.to_rarray(Reactant.TestUtils.construct_test_array(Float32, 3, 256, 5))
-    y = Reactant.to_rarray(Reactant.TestUtils.construct_test_array(Float32, 256, 7, 5))
+    x = Reactant.to_rarray(Reactant.TestUtils.construct_test_array(Float32, 3, 256, 8))
+    y = Reactant.to_rarray(Reactant.TestUtils.construct_test_array(Float32, 256, 7, 8))
 
     run_auto_batching_tests(naive_batched_matmul, x, y)
 end