From 227b964af256cb5a3b414c200d3f696aae8627c6 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 9 Nov 2025 19:42:18 -0500 Subject: [PATCH 01/10] fix: finite diff gradient accidental type promotion --- src/TestUtils.jl | 7 ++++--- test/autodiff.jl | 6 ++++++ 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/src/TestUtils.jl b/src/TestUtils.jl index 179aba69fd..95cd1263c6 100644 --- a/src/TestUtils.jl +++ b/src/TestUtils.jl @@ -21,12 +21,13 @@ function construct_test_array(::Type{T}, dims::Int...) where {T} end function finite_difference_gradient( - f, x::AbstractArray{T}; epsilon=eps(T)^(3 / 4) + f, x::AbstractArray{T}; epsilon=eps(T)^(T(3 / 4)) ) where {T} onehot_matrix = Reactant.promote_to( - TracedRArray{Reactant.unwrapped_eltype(T),2}, LinearAlgebra.I(length(x)) + TracedRArray{Reactant.unwrapped_eltype(T),2}, + LinearAlgebra.Diagonal(fill(epsilon, length(x))), ) - perturbation = reshape(onehot_matrix .* epsilon, size(x)..., length(x)) + perturbation = reshape(onehot_matrix, size(x)..., length(x)) f_input = cat(x .+ perturbation, x .- perturbation; dims=ndims(x) + 1) f_evaluated = mapslices(f, f_input; dims=ntuple(identity, ndims(x))) diff --git a/test/autodiff.jl b/test/autodiff.jl index 4629cb5d62..18571de97a 100644 --- a/test/autodiff.jl +++ b/test/autodiff.jl @@ -366,3 +366,9 @@ end @test @jit(jvp_vjp_cubic(v_r, x_r, lambdas_r)) ≈ fill(6, (3, 2)) end + +@testset "Finite Difference Gradient" begin + x = Reactant.to_rarray(Reactant.TestUtils.construct_test_array(Float16, 2, 2)) + res = @jit Reactant.TestUtils.finite_difference_gradient(sum, x) + @test res isa Reactant.ConcreteRArray{Float16,2} +end From 75e89898734f344985720c1851c1836a279a5d91 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 12 Nov 2025 13:38:02 -0500 Subject: [PATCH 02/10] fix: use better epsilon --- src/Compiler.jl | 3 ++- src/TestUtils.jl | 15 ++++++++++++++- src/TracedRNumber.jl | 1 + test/nn/nnlib.jl | 4 ++-- 4 files changed, 19 insertions(+), 4 deletions(-) diff --git a/src/Compiler.jl b/src/Compiler.jl index 5c37842a03..1366e87a1c 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -702,9 +702,10 @@ function optimization_passes( dus_to_concat::Bool=false, recognize_comms::Bool=true, lower_comms::Bool=true, - max_constant_threshold::Int=1024, backend::String="gpu", ) + (; max_constant_threshold) = compile_options + transform_passes_list = [ "patterns=compare_op_canon<16>", "transpose_transpose<16>", diff --git a/src/TestUtils.jl b/src/TestUtils.jl index 95cd1263c6..7ce1ca7313 100644 --- a/src/TestUtils.jl +++ b/src/TestUtils.jl @@ -20,8 +20,21 @@ function construct_test_array(::Type{T}, dims::Int...) where {T} return reshape(collect(T, 1:prod(dims)), dims...) 
end +# https://github.com/JuliaDiff/FiniteDiff.jl/blob/3a8c3d8d87e59de78e2831787a3f54b12b7c2075/src/epsilons.jl#L133 +function default_epslion(::Val{fdtype}, ::Type{T}) where {fdtype,T} + if fdtype == :forward + return sqrt(eps(real(T))) + elseif fdtype == :central + return cbrt(eps(real(T))) + elseif fdtype == :hcentral + return eps(T)^(T(1 / 4)) + else + return one(real(T)) + end +end + function finite_difference_gradient( - f, x::AbstractArray{T}; epsilon=eps(T)^(T(3 / 4)) + f, x::AbstractArray{T}; epsilon=default_epslion(Val(:central), T) ) where {T} onehot_matrix = Reactant.promote_to( TracedRArray{Reactant.unwrapped_eltype(T),2}, diff --git a/src/TracedRNumber.jl b/src/TracedRNumber.jl index 672c01a44b..8a5bb307ea 100644 --- a/src/TracedRNumber.jl +++ b/src/TracedRNumber.jl @@ -491,6 +491,7 @@ for (jlop, hloop) in ( (:(Base.log), :log), (:(Base.log1p), :log_plus_one), (:(Base.sqrt), :sqrt), + (:(Base.cbrt), :cbrt), (:(Base.acos), :acos), (:(Base.acosh), :acosh), (:(Base.asin), :asin), diff --git a/test/nn/nnlib.jl b/test/nn/nnlib.jl index ba1dcfd2f5..630c09ad55 100644 --- a/test/nn/nnlib.jl +++ b/test/nn/nnlib.jl @@ -3,7 +3,7 @@ using NNlib, Reactant, Enzyme, Statistics, Test @testset "Activation Functions" begin sumabs2(f, x) = sum(abs2, f.(x)) - x_act = Reactant.TestUtils.construct_test_array(Float32, 10, 10) + x_act = Reactant.TestUtils.construct_test_array(Float32, 10, 10) .- 0.5f0 x_act_ca = Reactant.to_rarray(x_act) @testset "Activation: $act" for act in ( @@ -18,7 +18,7 @@ using NNlib, Reactant, Enzyme, Statistics, Test ) @test y_simple ≈ y_compile - @test ∂x_compile ≈ ∂x_compile_fd + @test ∂x_compile ≈ ∂x_compile_fd atol=1e-3 rtol=1e-3 end end From 25e05ca13893c70be25367f970875ffc14aab061 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 14 Nov 2025 17:21:26 -0500 Subject: [PATCH 03/10] feat: support multiple args for finitediff --- src/TestUtils.jl | 125 ++++++++++++++++++++++++++++++++++++++++++----- test/autodiff.jl | 24 +++++++++ test/nn/nnlib.jl | 2 +- 3 files changed, 137 insertions(+), 14 deletions(-) diff --git a/src/TestUtils.jl b/src/TestUtils.jl index 7ce1ca7313..f85dd34004 100644 --- a/src/TestUtils.jl +++ b/src/TestUtils.jl @@ -1,6 +1,7 @@ module TestUtils -using ..Reactant: Reactant, TracedRArray +using ..Reactant: Reactant, TracedRArray, TracedRNumber, TracedUtils +using Reactant.Ops: @opcall using ReactantCore: ReactantCore using LinearAlgebra: LinearAlgebra @@ -33,23 +34,121 @@ function default_epslion(::Val{fdtype}, ::Type{T}) where {fdtype,T} end end -function finite_difference_gradient( - f, x::AbstractArray{T}; epsilon=default_epslion(Val(:central), T) -) where {T} +function generate_purturbed_array(x::AbstractArray{T}, epsilon) where {T} onehot_matrix = Reactant.promote_to( TracedRArray{Reactant.unwrapped_eltype(T),2}, LinearAlgebra.Diagonal(fill(epsilon, length(x))), ) - perturbation = reshape(onehot_matrix, size(x)..., length(x)) - f_input = cat(x .+ perturbation, x .- perturbation; dims=ndims(x) + 1) - - f_evaluated = mapslices(f, f_input; dims=ntuple(identity, ndims(x))) - return ReactantCore.materialize_traced_array( - reshape( - (f_evaluated[1:length(x)] - f_evaluated[(length(x) + 1):end]) ./ (2 * epsilon), - size(x), - ), + perturbation = permutedims( + reshape(onehot_matrix, size(x)..., length(x)), (ndims(x) + 1, 1:(ndims(x))...) ) + return cat( + reshape(x, 1, size(x)...) .+ perturbation, + reshape(x, 1, size(x)...) .- perturbation; + dims=1, + ) +end + +function finite_difference_gradient(f::F, args...) 
where {F} + argprefix = gensym("finitediffarg") + resprefix = gensym("finitediffresult") + resargprefix = gensym("finitediffresarg") + + # TODO: can we detect and prevent using functions that mutate their arguments? + mlir_fn_res = TracedUtils.make_mlir_fn( + f, + args, + (), + "finite_difference_gradient_fn", + false; + args_in_result=:none, + argprefix, + resprefix, + resargprefix, + ) + + seenargs = Reactant.OrderedIdDict() + Reactant.make_tracer(seenargs, f, (argprefix,), Reactant.TracedSetPath) + for (i, arg) in enumerate(args) + Reactant.make_tracer(seenargs, arg, (argprefix, i), Reactant.TracedSetPath) + end + + linear_args = Reactant.TracedType[] + for (k, v) in seenargs + v isa Reactant.TracedType || continue + push!(linear_args, v) + end + + if ( + length(mlir_fn_res.linear_results) != 1 || + !(mlir_fn_res.linear_results[1] isa TracedRNumber) + ) + error( + "`finite_difference_gradient` only supports functions with a single scalar output", + ) + end + + gradient_results = TracedRArray[] + for i in 1:length(linear_args) + arg = linear_args[i] + if arg isa TracedRArray && TracedUtils.has_idx(arg, argprefix) + path = TracedUtils.get_idx(arg, argprefix) + if mlir_fn_res.fnwrapped && length(path) > 1 && path[2] == 1 + continue + end + + # We need the gradient wrt this argument + # we will naively insert the args here, cse will take care of the rest + new_arguments = TracedRArray[] + epsilon = default_epslion(Val(:central), Reactant.unwrapped_eltype(arg)) + pertubed_arg = generate_purturbed_array(arg, epsilon) + bsize = size(pertubed_arg, 1) + for j in 1:length(linear_args) + if i == j + new_arg = pertubed_arg + elseif linear_args[j] isa TracedRNumber + new_arg = @opcall broadcast_in_dim( + linear_args[j], Int64[], Int64[bsize] + ) + else + new_arg = @opcall broadcast_in_dim( + linear_args[j], + collect(Int64, 2:(ndims(linear_args[j]) + 1)), + Int64[bsize, size(linear_args[j])...], + ) + end + new_arg = @opcall transpose(new_arg, Int64[1, ((ndims(new_arg)):-1:2)...];) + push!(new_arguments, new_arg) + end + + batched_res = @opcall batch( + new_arguments, + [ + Reactant.MLIR.IR.TensorType( + Int64[bsize], + Reactant.MLIR.IR.Type( + Reactant.unwrapped_eltype(mlir_fn_res.linear_results[1]) + ), + ), + ], + Int64[bsize]; + fn=mlir_fn_res.f, + ) + batched_res = only(batched_res) + push!( + gradient_results, + ReactantCore.materialize_traced_array( + reshape( + (batched_res[1:(bsize ÷ 2)] - batched_res[((bsize ÷ 2) + 1):end]) ./ + (2 * epsilon), + size(arg), + ), + ), + ) + end + end + + return Tuple(gradient_results) end end diff --git a/test/autodiff.jl b/test/autodiff.jl index 18571de97a..23ded4d639 100644 --- a/test/autodiff.jl +++ b/test/autodiff.jl @@ -372,3 +372,27 @@ end res = @jit Reactant.TestUtils.finite_difference_gradient(sum, x) @test res isa Reactant.ConcreteRArray{Float16,2} end + +using Reactant + +function fdiff_multiple_args(f, nt, x) + return sum(abs2, f(nt.y .+ x .- nt.x)) +end + +struct WrapperFunc{T} + x::T +end + +(f::WrapperFunc)(x) = x .^ 3 .+ f.x + +nt = (; x=rand(3, 4), y=rand(3, 4)) +fn = WrapperFunc(rand(3, 4)) +x = rand(3, 4) + +nt_ra = Reactant.to_rarray(nt) +fn_ra = Reactant.to_rarray(fn) +x_ra = Reactant.to_rarray(x) + +@code_hlo Reactant.TestUtils.finite_difference_gradient( + fdiff_multiple_args, fn_ra, nt_ra, x_ra +) diff --git a/test/nn/nnlib.jl b/test/nn/nnlib.jl index 630c09ad55..754df1e142 100644 --- a/test/nn/nnlib.jl +++ b/test/nn/nnlib.jl @@ -18,7 +18,7 @@ using NNlib, Reactant, Enzyme, Statistics, Test ) @test y_simple ≈ y_compile - @test 
∂x_compile ≈ ∂x_compile_fd atol=1e-3 rtol=1e-3 + @test ∂x_compile ≈ ∂x_compile_fd atol = 1e-3 rtol = 1e-3 end end From 6bf8c37ef0234667af3d5638f276cbc291894416 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 14 Nov 2025 21:03:42 -0500 Subject: [PATCH 04/10] feat: preserve the correct return type --- src/TestUtils.jl | 58 ++++++++++++++++++++++++------------ src/TracedRArray.jl | 11 ++++++- src/stdlibs/LinearAlgebra.jl | 4 +-- 3 files changed, 51 insertions(+), 22 deletions(-) diff --git a/src/TestUtils.jl b/src/TestUtils.jl index f85dd34004..5c6604c610 100644 --- a/src/TestUtils.jl +++ b/src/TestUtils.jl @@ -34,22 +34,31 @@ function default_epslion(::Val{fdtype}, ::Type{T}) where {fdtype,T} end end -function generate_purturbed_array(x::AbstractArray{T}, epsilon) where {T} +function get_perturbation(x::AbstractArray{T}, epsilon) where {T} onehot_matrix = Reactant.promote_to( TracedRArray{Reactant.unwrapped_eltype(T),2}, - LinearAlgebra.Diagonal(fill(epsilon, length(x))), + LinearAlgebra.Diagonal(fill(epsilon, length(x))); ) - perturbation = permutedims( + return permutedims( reshape(onehot_matrix, size(x)..., length(x)), (ndims(x) + 1, 1:(ndims(x))...) ) - return cat( - reshape(x, 1, size(x)...) .+ perturbation, - reshape(x, 1, size(x)...) .- perturbation; - dims=1, - ) end -function finite_difference_gradient(f::F, args...) where {F} +function generate_perturbed_array(::Val{:central}, x::AbstractArray{T}, epsilon) where {T} + perturbation = get_perturbation(x, epsilon) + x_ = reshape(x, 1, size(x)...) + return cat(x_ .+ perturbation, x_ .- perturbation; dims=1) +end + +function generate_perturbed_array(::Val{:forward}, x::AbstractArray{T}, epsilon) where {T} + perturbation = get_perturbation(x, epsilon) + x_ = reshape(x, 1, size(x)...) + return cat(x_ .+ perturbation, x_; dims=1) +end + +function finite_difference_gradient( + f::F, args...; method::Union{Val{:central},Val{:forward}}=Val(:central) +) where {F} argprefix = gensym("finitediffarg") resprefix = gensym("finitediffresult") resargprefix = gensym("finitediffresarg") @@ -89,6 +98,7 @@ function finite_difference_gradient(f::F, args...) where {F} end gradient_results = TracedRArray[] + gradient_result_map_path = [] for i in 1:length(linear_args) arg = linear_args[i] if arg isa TracedRArray && TracedUtils.has_idx(arg, argprefix) @@ -100,8 +110,10 @@ function finite_difference_gradient(f::F, args...) where {F} # We need the gradient wrt this argument # we will naively insert the args here, cse will take care of the rest new_arguments = TracedRArray[] - epsilon = default_epslion(Val(:central), Reactant.unwrapped_eltype(arg)) - pertubed_arg = generate_purturbed_array(arg, epsilon) + + epsilon = default_epslion(method, Reactant.unwrapped_eltype(arg)) + pertubed_arg = generate_perturbed_array(method, arg, epsilon) + bsize = size(pertubed_arg, 1) for j in 1:length(linear_args) if i == j @@ -135,20 +147,28 @@ function finite_difference_gradient(f::F, args...) 
where {F} fn=mlir_fn_res.f, ) batched_res = only(batched_res) + + if method isa Val{:central} + diff = batched_res[1:(bsize ÷ 2)] - batched_res[((bsize ÷ 2) + 1):end] + grad_res = diff ./ (2 * epsilon) + elseif method isa Val{:forward} + diff = batched_res[1:(end - 1)] .- batched_res[end:end] + grad_res = diff ./ epsilon + end + + push!(gradient_result_map_path, TracedUtils.get_idx(arg, argprefix)) push!( gradient_results, - ReactantCore.materialize_traced_array( - reshape( - (batched_res[1:(bsize ÷ 2)] - batched_res[((bsize ÷ 2) + 1):end]) ./ - (2 * epsilon), - size(arg), - ), - ), + ReactantCore.materialize_traced_array(reshape(grad_res, size(arg))), ) end end - return Tuple(gradient_results) + results = deepcopy(args) + for (path, grad_res) in zip(gradient_result_map_path, gradient_results) + TracedUtils.set!(results, path[2:end], grad_res.mlir_data) + end + return results end end diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index 5ffc60c165..73dacadfb2 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -33,6 +33,15 @@ Base.convert(T::Type{<:TracedRArray}, x::AbstractArray) = Reactant.promote_to(T, Base.complex(x::TracedRArray{<:Real}) = complex.(x) Base.complex(x::TracedRArray{<:Complex}) = x +function Base.deepcopy_internal(x::TracedRArray, stackdict::IdDict) + if haskey(stackdict, x) + return stackdict[x]::typeof(x) + end + y = copy(x) + stackdict[x] = y + return y +end + TracedRArray{T,N}(x::AbstractArray) where {T,N} = convert(TracedRArray{T,N}, x) function maybe_assert_scalar_setindexing( @@ -1109,7 +1118,7 @@ function Base.accumulate_pairwise!(op, A::AnyTracedRVector, B::AnyTracedRVector) return accumulate!(op, A, B; dims=1) end -if isdefined(Base, :_accumulate_promote_op) +@static if isdefined(Base, :_accumulate_promote_op) function Base._accumulate_promote_op(op, A::AnyTracedRArray{T}; init=nothing) where {T} if init !== nothing init isa TracedRNumber && (init = zero(unwrapped_eltype(init))) diff --git a/src/stdlibs/LinearAlgebra.jl b/src/stdlibs/LinearAlgebra.jl index 23e8cb4390..fa8285e630 100644 --- a/src/stdlibs/LinearAlgebra.jl +++ b/src/stdlibs/LinearAlgebra.jl @@ -273,7 +273,7 @@ function overloaded_mul!( return C end -if isdefined(LinearAlgebra, :_triu) +@static if isdefined(LinearAlgebra, :_triu) function LinearAlgebra._triu(A::AnyTracedRArray{T,2}, ::Val{true}, k::Integer) where {T} return overloaded_triu(materialize_traced_array(A), k) end @@ -284,7 +284,7 @@ if isdefined(LinearAlgebra, :_triu) end end -if isdefined(LinearAlgebra, :_tril) +@static if isdefined(LinearAlgebra, :_tril) function LinearAlgebra._tril(A::AnyTracedRArray{T,2}, ::Val{true}, k::Integer) where {T} return overloaded_tril(materialize_traced_array(A), k) end From 963302fe26c96d4db1fec0d28c5f273d7c1a98a1 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 14 Nov 2025 21:08:50 -0500 Subject: [PATCH 05/10] test: against enzyme --- test/autodiff.jl | 33 ++++++++++++++++++++++----------- 1 file changed, 22 insertions(+), 11 deletions(-) diff --git a/test/autodiff.jl b/test/autodiff.jl index 23ded4d639..2c3f8a0fac 100644 --- a/test/autodiff.jl +++ b/test/autodiff.jl @@ -373,8 +373,6 @@ end @test res isa Reactant.ConcreteRArray{Float16,2} end -using Reactant - function fdiff_multiple_args(f, nt, x) return sum(abs2, f(nt.y .+ x .- nt.x)) end @@ -385,14 +383,27 @@ end (f::WrapperFunc)(x) = x .^ 3 .+ f.x -nt = (; x=rand(3, 4), y=rand(3, 4)) -fn = WrapperFunc(rand(3, 4)) -x = rand(3, 4) +@testset "Finite Difference Gradient (non vector inputs)" begin + nt = (; + 
x=Reactant.TestUtils.construct_test_array(Float64, 3, 4), + y=Reactant.TestUtils.construct_test_array(Float64, 3, 4), + ) + fn = WrapperFunc(Reactant.TestUtils.construct_test_array(Float64, 3, 4)) + x = Reactant.TestUtils.construct_test_array(Float64, 3, 4) -nt_ra = Reactant.to_rarray(nt) -fn_ra = Reactant.to_rarray(fn) -x_ra = Reactant.to_rarray(x) + nt_ra = Reactant.to_rarray(nt) + fn_ra = Reactant.to_rarray(fn) + x_ra = Reactant.to_rarray(x) -@code_hlo Reactant.TestUtils.finite_difference_gradient( - fdiff_multiple_args, fn_ra, nt_ra, x_ra -) + results_fd = @jit Reactant.TestUtils.finite_difference_gradient( + fdiff_multiple_args, fn_ra, nt_ra, x_ra + ) + @test results_fd isa typeof((fn_ra, nt_ra, x_ra)) + + results_enz = @jit Enzyme.gradient(Reverse, fdiff_multiple_args, fn_ra, nt_ra, x_ra) + + @test results_fd[1].x ≈ results_enz[1].x + @test results_fd[2].x ≈ results_enz[2].x + @test results_fd[2].y ≈ results_enz[2].y + @test results_fd[3] ≈ results_enz[3] +end From bf01c95d00764d234ff877ed0c26f8d3da01fef2 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 14 Nov 2025 21:20:04 -0500 Subject: [PATCH 06/10] test: incorrect usage --- src/TestUtils.jl | 3 ++- test/nn/luxlib.jl | 41 +++++++++++------------------------------ 2 files changed, 13 insertions(+), 31 deletions(-) diff --git a/src/TestUtils.jl b/src/TestUtils.jl index 5c6604c610..00c3409b1b 100644 --- a/src/TestUtils.jl +++ b/src/TestUtils.jl @@ -93,7 +93,8 @@ function finite_difference_gradient( !(mlir_fn_res.linear_results[1] isa TracedRNumber) ) error( - "`finite_difference_gradient` only supports functions with a single scalar output", + "`finite_difference_gradient` only supports functions with a single scalar \ + output. Received : $(mlir_fn_res.linear_results)", ) end diff --git a/test/nn/luxlib.jl b/test/nn/luxlib.jl index 135a83f8f9..224c2d257d 100644 --- a/test/nn/luxlib.jl +++ b/test/nn/luxlib.jl @@ -9,20 +9,9 @@ using LuxLib, Reactant, Enzyme, NNlib, Test end function ∇fuseddense_fd(act, weight, x, bias) - dw = Reactant.TestUtils.finite_difference_gradient( - w -> sumabs2fuseddense(act, w, x, bias), weight + return Reactant.TestUtils.finite_difference_gradient( + (w, x, b) -> sumabs2fuseddense(act, w, x, b), weight, x, bias ) - dx = Reactant.TestUtils.finite_difference_gradient( - x -> sumabs2fuseddense(act, weight, x, bias), x - ) - db = if bias === nothing - nothing - else - Reactant.TestUtils.finite_difference_gradient( - b -> sumabs2fuseddense(act, weight, x, b), bias - ) - end - return dw, dx, db end @testset for act in (identity, relu, sigmoid, tanh, gelu), has_bias in (true, false) @@ -45,7 +34,7 @@ using LuxLib, Reactant, Enzyme, NNlib, Test act, weight_ra, x_ra, bias_ra ) - dw_fd, dx_fd, db_fd = @jit ∇fuseddense_fd(act, weight, x, bias) + dw_fd, dx_fd, db_fd = @jit ∇fuseddense_fd(act, weight_ra, x_ra, bias_ra) @test dw_fd ≈ dw_compile atol = 1e-5 rtol = 1e-2 @test dx_fd ≈ dx_compile atol = 1e-5 rtol = 1e-2 @@ -65,26 +54,18 @@ end end function ∇biasact_fd(act, x, b) - dx = Reactant.TestUtils.finite_difference_gradient( - x -> sumabs2biasact(act, x, b), x + return Reactant.TestUtils.finite_difference_gradient( + (x, b) -> sumabs2biasact(act, x, b), x, b ) - db = Reactant.TestUtils.finite_difference_gradient( - b -> sumabs2biasact(act, x, b), b - ) - return dx, db end function ∇biasact!!(act, x, b) return Enzyme.gradient(Reverse, sumabs2biasact!!, Const(act), x, b)[2:end] end function ∇biasact!!_fd(act, x, b) - dx = Reactant.TestUtils.finite_difference_gradient( - x -> sumabs2biasact!!(act, x, b), x - ) - db = 
Reactant.TestUtils.finite_difference_gradient( - b -> sumabs2biasact!!(act, x, b), b + return Reactant.TestUtils.finite_difference_gradient( + (x, b) -> sumabs2biasact!!(act, x, b), x, b ) - return dx, db end @testset for act in (identity, relu, sigmoid, tanh, gelu) @@ -104,7 +85,7 @@ end @test y_simple!! ≈ y_compile!! atol = 1e-5 rtol = 1e-2 @testset "Enzyme: bias_activation" begin - ∂x_enz, ∂b_enz = @jit ∇biasact_fd(act, x, b) + ∂x_enz, ∂b_enz = @jit ∇biasact_fd(act, x_ra, b_ra) ∂x_compile, ∂b_compile = @jit ∇biasact(act, x_ra, b_ra) @test ∂x_enz ≈ ∂x_compile atol = 1e-5 rtol = 1e-2 @@ -112,7 +93,7 @@ end end @testset "Enzyme: bias_activation!!" begin - ∂x_enz!!, ∂b_enz!! = @jit ∇biasact!!_fd(act, x, b) + ∂x_enz!!, ∂b_enz!! = @jit ∇biasact!!_fd(act, x_ra, b_ra) ∂x_compile!!, ∂b_compile!! = @jit ∇biasact!!(act, x_ra, b_ra) @test ∂x_enz!! ≈ ∂x_compile!! atol = 1e-5 rtol = 1e-2 @@ -146,9 +127,9 @@ end @test y_simple ≈ y_compile atol = 1e-5 rtol = 1e-2 @test y_simple!! ≈ y_compile!! atol = 1e-5 rtol = 1e-2 - ∂x_enz = @jit ∇sumabs2_fd(act, x_act) + ∂x_enz = @jit ∇sumabs2_fd(act, x_act_ca) ∂x_compile = @jit ∇sumabs2(act, x_act_ca) - ∂x_enz!! = @jit ∇sumabs2!!_fd(act, x_act) + ∂x_enz!! = @jit ∇sumabs2!!_fd(act, x_act_ca) ∂x_compile!! = @jit ∇sumabs2!!(act, x_act_ca) @test ∂x_enz ≈ ∂x_compile atol = 1e-5 rtol = 1e-2 From 816ad7992ac02708f4250d19f5b93613ee137fd2 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 14 Nov 2025 21:21:13 -0500 Subject: [PATCH 07/10] Update src/TestUtils.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/TestUtils.jl | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/TestUtils.jl b/src/TestUtils.jl index 00c3409b1b..247eb0eaf9 100644 --- a/src/TestUtils.jl +++ b/src/TestUtils.jl @@ -92,10 +92,8 @@ function finite_difference_gradient( length(mlir_fn_res.linear_results) != 1 || !(mlir_fn_res.linear_results[1] isa TracedRNumber) ) - error( - "`finite_difference_gradient` only supports functions with a single scalar \ - output. Received : $(mlir_fn_res.linear_results)", - ) + error("`finite_difference_gradient` only supports functions with a single scalar \ + output. 
Received : $(mlir_fn_res.linear_results)") end gradient_results = TracedRArray[] From a12b60a46fd399881079265f0ec01e58d79b8080 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 14 Nov 2025 21:49:06 -0500 Subject: [PATCH 08/10] fix: single arg return type --- src/TestUtils.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/TestUtils.jl b/src/TestUtils.jl index 247eb0eaf9..7da5b68b1c 100644 --- a/src/TestUtils.jl +++ b/src/TestUtils.jl @@ -167,6 +167,7 @@ function finite_difference_gradient( for (path, grad_res) in zip(gradient_result_map_path, gradient_results) TracedUtils.set!(results, path[2:end], grad_res.mlir_data) end + length(args) == 1 && return results[1] return results end From fa7804349a0cdd3d29a44a0614e8b6ab52319de8 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 15 Nov 2025 12:40:13 -0500 Subject: [PATCH 09/10] test: use analytic gradients --- test/nn/luxlib.jl | 113 +++++++++++++++++++++++++++++----------------- test/nn/nnlib.jl | 23 +++++++--- 2 files changed, 89 insertions(+), 47 deletions(-) diff --git a/test/nn/luxlib.jl b/test/nn/luxlib.jl index 224c2d257d..cf23b087ba 100644 --- a/test/nn/luxlib.jl +++ b/test/nn/luxlib.jl @@ -8,13 +8,34 @@ using LuxLib, Reactant, Enzyme, NNlib, Test return Enzyme.gradient(Reverse, sumabs2fuseddense, Const(act), weight, x, bias)[2:end] end - function ∇fuseddense_fd(act, weight, x, bias) - return Reactant.TestUtils.finite_difference_gradient( - (w, x, b) -> sumabs2fuseddense(act, w, x, b), weight, x, bias - ) + function ∇fuseddense(actgradfn, act, weight, x, bias) + z = weight * x + if bias !== nothing + z .+= bias + end + Ω = act.(z) + + δ = 2 .* actgradfn.(Ω, z) .* Ω + + ∂weight = δ * x' + if bias !== nothing + ∂bias = vec(sum(δ; dims=2)) + else + ∂bias = nothing + end + ∂x = weight' * δ + return ∂weight, ∂x, ∂bias end - @testset for act in (identity, relu, sigmoid, tanh, gelu), has_bias in (true, false) + @testset "Activation: $act | bias=$has_bias" for (act, gradfn) in ( + (identity, (Ω, x) -> one(Ω)), + (relu, (Ω, x) -> (Ω > 0)), + (sigmoid, (Ω, x) -> conj((1 - Ω) * Ω)), + (tanh, (Ω, x) -> conj(1 - Ω^2)), + (gelu, (Ω, x) -> NNlib.deriv_gelu_tanh(x)), + ), + has_bias in (true, false) + weight = Reactant.TestUtils.construct_test_array(Float32, 9, 10) x = Reactant.TestUtils.construct_test_array(Float32, 10, 12) bias = has_bias ? 
Reactant.TestUtils.construct_test_array(Float32, 9) : nothing @@ -24,9 +45,7 @@ using LuxLib, Reactant, Enzyme, NNlib, Test bias_ra = Reactant.to_rarray(bias) y_compile = @jit fused_dense_bias_activation(act, weight_ra, x_ra, bias_ra) - y_res = fused_dense_bias_activation(act, weight, x, bias) - @test y_res ≈ y_compile atol = 1e-5 rtol = 1e-2 @testset "Enzyme: fused_dense_bias_activation" begin @@ -34,11 +53,11 @@ using LuxLib, Reactant, Enzyme, NNlib, Test act, weight_ra, x_ra, bias_ra ) - dw_fd, dx_fd, db_fd = @jit ∇fuseddense_fd(act, weight_ra, x_ra, bias_ra) + dw_gt, dx_gt, db_gt = @jit ∇fuseddense(gradfn, act, weight_ra, x_ra, bias_ra) - @test dw_fd ≈ dw_compile atol = 1e-5 rtol = 1e-2 - @test dx_fd ≈ dx_compile atol = 1e-5 rtol = 1e-2 - has_bias && @test db_fd ≈ db_compile atol = 1e-5 rtol = 1e-2 + @test dw_gt ≈ dw_compile atol = 1e-5 rtol = 1e-2 + @test dx_gt ≈ dx_compile atol = 1e-5 rtol = 1e-2 + has_bias && @test db_gt ≈ db_compile atol = 1e-5 rtol = 1e-2 end end end @@ -53,22 +72,25 @@ end return Enzyme.gradient(Reverse, sumabs2biasact, Const(act), x, b)[2:end] end - function ∇biasact_fd(act, x, b) - return Reactant.TestUtils.finite_difference_gradient( - (x, b) -> sumabs2biasact(act, x, b), x, b - ) - end - function ∇biasact!!(act, x, b) return Enzyme.gradient(Reverse, sumabs2biasact!!, Const(act), x, b)[2:end] end - function ∇biasact!!_fd(act, x, b) - return Reactant.TestUtils.finite_difference_gradient( - (x, b) -> sumabs2biasact!!(act, x, b), x, b - ) + + function ∇biasact(gradfn, act, x, b) + xb = x .+ b + Ω = act.(xb) + ∂x = 2 .* gradfn.(Ω, xb) .* Ω + ∂b = vec(sum(∂x; dims=2)) + return ∂x, ∂b end - @testset for act in (identity, relu, sigmoid, tanh, gelu) + @testset "Activation: $act" for (act, gradfn) in ( + (identity, (Ω, x) -> one(Ω)), + (relu, (Ω, x) -> (Ω > 0)), + (sigmoid, (Ω, x) -> conj((1 - Ω) * Ω)), + (tanh, (Ω, x) -> conj(1 - Ω^2)), + (gelu, (Ω, x) -> NNlib.deriv_gelu_tanh(x)), + ) x = Reactant.TestUtils.construct_test_array(Float32, 10, 10) b = Reactant.TestUtils.construct_test_array(Float32, 10) @@ -84,20 +106,20 @@ end @test y_simple ≈ y_compile atol = 1e-5 rtol = 1e-2 @test y_simple!! ≈ y_compile!! atol = 1e-5 rtol = 1e-2 + ∂x_gt, ∂b_gt = @jit ∇biasact(gradfn, act, x_ra, b_ra) + @testset "Enzyme: bias_activation" begin - ∂x_enz, ∂b_enz = @jit ∇biasact_fd(act, x_ra, b_ra) - ∂x_compile, ∂b_compile = @jit ∇biasact(act, x_ra, b_ra) + ∂x_enz, ∂b_enz = @jit ∇biasact(act, x_ra, b_ra) - @test ∂x_enz ≈ ∂x_compile atol = 1e-5 rtol = 1e-2 - @test ∂b_enz ≈ ∂b_compile atol = 1e-5 rtol = 1e-2 + @test ∂x_enz ≈ ∂x_gt atol = 1e-5 rtol = 1e-2 + @test ∂b_enz ≈ ∂b_gt atol = 1e-5 rtol = 1e-2 end @testset "Enzyme: bias_activation!!" begin - ∂x_enz!!, ∂b_enz!! = @jit ∇biasact!!_fd(act, x_ra, b_ra) - ∂x_compile!!, ∂b_compile!! = @jit ∇biasact!!(act, x_ra, b_ra) + ∂x_enz!!, ∂b_enz!! = @jit ∇biasact!!(act, x_ra, b_ra) - @test ∂x_enz!! ≈ ∂x_compile!! atol = 1e-5 rtol = 1e-2 - @test ∂b_enz!! ≈ ∂b_compile!! atol = 1e-5 rtol = 1e-2 + @test ∂x_enz!! ≈ ∂x_gt atol = 1e-5 rtol = 1e-2 + @test ∂b_enz!! 
≈ ∂b_gt atol = 1e-5 rtol = 1e-2 end end end @@ -108,16 +130,26 @@ end sumabs2!!(f, x) = sum(abs2, fast_activation!!(f, copy(x))) ∇sumabs2(f, x) = Enzyme.gradient(Reverse, sumabs2, Const(f), x)[2] - ∇sumabs2_fd(f, x) = Reactant.TestUtils.finite_difference_gradient(x -> sumabs2(f, x), x) ∇sumabs2!!(f, x) = Enzyme.gradient(Reverse, sumabs2!!, Const(f), x)[2] - ∇sumabs2!!_fd(f, x) = - Reactant.TestUtils.finite_difference_gradient(x -> sumabs2!!(f, x), x) + + function ∇sumabs2(gradfn, f, x) + Ω = f.(x) + return 2 .* gradfn.(Ω, x) .* Ω + end x_act = Reactant.TestUtils.construct_test_array(Float32, 10, 10) x_act_ca = Reactant.to_rarray(x_act) - @testset "Activation: $act" for act in ( - identity, relu, sigmoid, tanh, tanh_fast, sigmoid_fast, gelu, abs2 + @testset "Activation: $act" for (act, gradfn) in ( + (identity, (Ω, x) -> one(Ω)), + (relu, (Ω, x) -> (Ω > 0)), + (sigmoid, (Ω, x) -> conj((1 - Ω) * Ω)), + (tanh, (Ω, x) -> conj(1 - Ω^2)), + (tanh_fast, (Ω, x) -> conj(1 - Ω^2)), + (sigmoid_fast, (Ω, x) -> conj((1 - Ω) * Ω)), + (gelu, (Ω, x) -> NNlib.deriv_gelu_tanh(x)), + (abs2, (Ω, x) -> (2 * x)), + (relu6, (Ω, x) -> (Ω > 0) & (Ω < 6)), ) y_simple = sumabs2(act, x_act) y_simple!! = sumabs2!!(act, x_act) @@ -127,13 +159,12 @@ end @test y_simple ≈ y_compile atol = 1e-5 rtol = 1e-2 @test y_simple!! ≈ y_compile!! atol = 1e-5 rtol = 1e-2 - ∂x_enz = @jit ∇sumabs2_fd(act, x_act_ca) - ∂x_compile = @jit ∇sumabs2(act, x_act_ca) - ∂x_enz!! = @jit ∇sumabs2!!_fd(act, x_act_ca) - ∂x_compile!! = @jit ∇sumabs2!!(act, x_act_ca) + ∂x_enz = @jit ∇sumabs2(act, x_act_ca) + ∂x_enz!! = @jit ∇sumabs2!!(act, x_act_ca) + ∂x_gt = @jit ∇sumabs2(gradfn, act, x_act_ca) - @test ∂x_enz ≈ ∂x_compile atol = 1e-5 rtol = 1e-2 - @test ∂x_enz!! ≈ ∂x_compile!! atol = 1e-5 rtol = 1e-2 + @test ∂x_enz ≈ ∂x_gt atol = 1e-5 rtol = 1e-2 + @test ∂x_enz!! 
≈ ∂x_gt atol = 1e-5 rtol = 1e-2 end end diff --git a/test/nn/nnlib.jl b/test/nn/nnlib.jl index 754df1e142..17fb16a7e7 100644 --- a/test/nn/nnlib.jl +++ b/test/nn/nnlib.jl @@ -3,22 +3,33 @@ using NNlib, Reactant, Enzyme, Statistics, Test @testset "Activation Functions" begin sumabs2(f, x) = sum(abs2, f.(x)) + function ∇sumabs2(gradfn, f, x) + Ω = f.(x) + return 2 .* gradfn.(Ω, x) .* Ω + end + x_act = Reactant.TestUtils.construct_test_array(Float32, 10, 10) .- 0.5f0 x_act_ca = Reactant.to_rarray(x_act) - @testset "Activation: $act" for act in ( - identity, relu, sigmoid, tanh, tanh_fast, sigmoid_fast, gelu, abs2, relu6 + @testset "Activation: $act" for (act, gradfn) in ( + (identity, (Ω, x) -> one(Ω)), + (relu, (Ω, x) -> (Ω > 0)), + (sigmoid, (Ω, x) -> conj((1 - Ω) * Ω)), + (tanh, (Ω, x) -> conj(1 - Ω^2)), + (tanh_fast, (Ω, x) -> conj(1 - Ω^2)), + (sigmoid_fast, (Ω, x) -> conj((1 - Ω) * Ω)), + (gelu, (Ω, x) -> NNlib.deriv_gelu_tanh(x)), + (abs2, (Ω, x) -> (2 * x)), + (relu6, (Ω, x) -> (Ω > 0) & (Ω < 6)), ) y_simple = sumabs2(act, x_act) y_compile = @jit sumabs2(act, x_act_ca) ∂x_compile = @jit(Enzyme.gradient(Reverse, sumabs2, Const(act), x_act_ca))[2] - ∂x_compile_fd = @jit Reactant.TestUtils.finite_difference_gradient( - Base.Fix1(sumabs2, act), x_act_ca - ) + ∂x_compile_gt = @jit ∇sumabs2(gradfn, act, x_act_ca) @test y_simple ≈ y_compile - @test ∂x_compile ≈ ∂x_compile_fd atol = 1e-3 rtol = 1e-3 + @test ∂x_compile ≈ ∂x_compile_gt atol = 1e-3 rtol = 1e-3 end end From 871dd896afff0681eef574ddf7dcd970bd394356 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 15 Nov 2025 12:41:34 -0500 Subject: [PATCH 10/10] Update test/nn/luxlib.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- test/nn/luxlib.jl | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/test/nn/luxlib.jl b/test/nn/luxlib.jl index cf23b087ba..c862b6497e 100644 --- a/test/nn/luxlib.jl +++ b/test/nn/luxlib.jl @@ -85,12 +85,12 @@ end end @testset "Activation: $act" for (act, gradfn) in ( - (identity, (Ω, x) -> one(Ω)), - (relu, (Ω, x) -> (Ω > 0)), - (sigmoid, (Ω, x) -> conj((1 - Ω) * Ω)), - (tanh, (Ω, x) -> conj(1 - Ω^2)), - (gelu, (Ω, x) -> NNlib.deriv_gelu_tanh(x)), - ) + (identity, (Ω, x) -> one(Ω)), + (relu, (Ω, x) -> (Ω > 0)), + (sigmoid, (Ω, x) -> conj((1 - Ω) * Ω)), + (tanh, (Ω, x) -> conj(1 - Ω^2)), + (gelu, (Ω, x) -> NNlib.deriv_gelu_tanh(x)), + ) x = Reactant.TestUtils.construct_test_array(Float32, 10, 10) b = Reactant.TestUtils.construct_test_array(Float32, 10)
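
# Illustrative usage of the `Reactant.TestUtils.finite_difference_gradient` helper
# introduced by this patch series. This is a hedged sketch based on the tests added
# above, not part of the diffs themselves; the two-argument `loss` function and the
# variable names `w`/`x` below are made-up examples.
using Reactant

w = Reactant.to_rarray(Reactant.TestUtils.construct_test_array(Float32, 3, 4))
x = Reactant.to_rarray(Reactant.TestUtils.construct_test_array(Float32, 3, 4))

loss(w, x) = sum(abs2, w .* x)  # the traced function must return a single scalar

# Central differences are the default; `method=Val(:forward)` selects forward differences.
# With multiple array arguments the result mirrors the argument structure.
dw, dx = @jit Reactant.TestUtils.finite_difference_gradient(loss, w, x)

# With a single argument the gradient is returned directly rather than as a 1-tuple.
dx_only = @jit Reactant.TestUtils.finite_difference_gradient(sum, x)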