diff --git a/src/Compiler.jl b/src/Compiler.jl
index 5c37842a03..1366e87a1c 100644
--- a/src/Compiler.jl
+++ b/src/Compiler.jl
@@ -702,9 +702,10 @@ function optimization_passes(
     dus_to_concat::Bool=false,
     recognize_comms::Bool=true,
     lower_comms::Bool=true,
-    max_constant_threshold::Int=1024,
     backend::String="gpu",
 )
+    (; max_constant_threshold) = compile_options
+
     transform_passes_list = [
         "patterns=compare_op_canon<16>",
         "transpose_transpose<16>",
diff --git a/src/TestUtils.jl b/src/TestUtils.jl
index 179aba69fd..7da5b68b1c 100644
--- a/src/TestUtils.jl
+++ b/src/TestUtils.jl
@@ -1,6 +1,7 @@
 module TestUtils
 
-using ..Reactant: Reactant, TracedRArray
+using ..Reactant: Reactant, TracedRArray, TracedRNumber, TracedUtils
+using Reactant.Ops: @opcall
 using ReactantCore: ReactantCore
 using LinearAlgebra: LinearAlgebra
 
@@ -20,22 +21,154 @@ function construct_test_array(::Type{T}, dims::Int...) where {T}
     return reshape(collect(T, 1:prod(dims)), dims...)
 end
 
-function finite_difference_gradient(
-    f, x::AbstractArray{T}; epsilon=eps(T)^(3 / 4)
-) where {T}
+# https://github.com/JuliaDiff/FiniteDiff.jl/blob/3a8c3d8d87e59de78e2831787a3f54b12b7c2075/src/epsilons.jl#L133
+function default_epsilon(::Val{fdtype}, ::Type{T}) where {fdtype,T}
+    if fdtype == :forward
+        return sqrt(eps(real(T)))
+    elseif fdtype == :central
+        return cbrt(eps(real(T)))
+    elseif fdtype == :hcentral
+        return eps(T)^(T(1 / 4))
+    else
+        return one(real(T))
+    end
+end
+
+function get_perturbation(x::AbstractArray{T}, epsilon) where {T}
     onehot_matrix = Reactant.promote_to(
-        TracedRArray{Reactant.unwrapped_eltype(T),2}, LinearAlgebra.I(length(x))
+        TracedRArray{Reactant.unwrapped_eltype(T),2},
+        LinearAlgebra.Diagonal(fill(epsilon, length(x)));
+    )
+    return permutedims(
+        reshape(onehot_matrix, size(x)..., length(x)), (ndims(x) + 1, 1:(ndims(x))...)
+    )
+end
+
+function generate_perturbed_array(::Val{:central}, x::AbstractArray{T}, epsilon) where {T}
+    perturbation = get_perturbation(x, epsilon)
+    x_ = reshape(x, 1, size(x)...)
+    return cat(x_ .+ perturbation, x_ .- perturbation; dims=1)
+end
+
+function generate_perturbed_array(::Val{:forward}, x::AbstractArray{T}, epsilon) where {T}
+    perturbation = get_perturbation(x, epsilon)
+    x_ = reshape(x, 1, size(x)...)
+    return cat(x_ .+ perturbation, x_; dims=1)
+end
+
+function finite_difference_gradient(
+    f::F, args...; method::Union{Val{:central},Val{:forward}}=Val(:central)
+) where {F}
+    argprefix = gensym("finitediffarg")
+    resprefix = gensym("finitediffresult")
+    resargprefix = gensym("finitediffresarg")
+
+    # TODO: can we detect and prevent using functions that mutate their arguments?
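+    # Trace `f` once into an MLIR function. Each array argument is then perturbed along
+    # every coordinate, the traced function is evaluated over all perturbations in a
+    # single `@opcall batch` call, and the finite differences are assembled from that.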
+    mlir_fn_res = TracedUtils.make_mlir_fn(
+        f,
+        args,
+        (),
+        "finite_difference_gradient_fn",
+        false;
+        args_in_result=:none,
+        argprefix,
+        resprefix,
+        resargprefix,
     )
-    perturbation = reshape(onehot_matrix .* epsilon, size(x)..., length(x))
-    f_input = cat(x .+ perturbation, x .- perturbation; dims=ndims(x) + 1)
-
-    f_evaluated = mapslices(f, f_input; dims=ntuple(identity, ndims(x)))
-    return ReactantCore.materialize_traced_array(
-        reshape(
-            (f_evaluated[1:length(x)] - f_evaluated[(length(x) + 1):end]) ./ (2 * epsilon),
-            size(x),
-        ),
+
+    seenargs = Reactant.OrderedIdDict()
+    Reactant.make_tracer(seenargs, f, (argprefix,), Reactant.TracedSetPath)
+    for (i, arg) in enumerate(args)
+        Reactant.make_tracer(seenargs, arg, (argprefix, i), Reactant.TracedSetPath)
+    end
+
+    linear_args = Reactant.TracedType[]
+    for (k, v) in seenargs
+        v isa Reactant.TracedType || continue
+        push!(linear_args, v)
+    end
+
+    if (
+        length(mlir_fn_res.linear_results) != 1 ||
+        !(mlir_fn_res.linear_results[1] isa TracedRNumber)
     )
+        error("`finite_difference_gradient` only supports functions with a single " *
+              "scalar output. Received: $(mlir_fn_res.linear_results)")
+    end
+
+    gradient_results = TracedRArray[]
+    gradient_result_map_path = []
+    for i in 1:length(linear_args)
+        arg = linear_args[i]
+        if arg isa TracedRArray && TracedUtils.has_idx(arg, argprefix)
+            path = TracedUtils.get_idx(arg, argprefix)
+            if mlir_fn_res.fnwrapped && length(path) > 1 && path[2] == 1
+                continue
+            end
+
+            # We need the gradient wrt this argument.
+            # We naively insert the args here; CSE will take care of the rest.
+            new_arguments = TracedRArray[]
+
+            epsilon = default_epsilon(method, Reactant.unwrapped_eltype(arg))
+            perturbed_arg = generate_perturbed_array(method, arg, epsilon)
+
+            bsize = size(perturbed_arg, 1)
+            for j in 1:length(linear_args)
+                if i == j
+                    new_arg = perturbed_arg
+                elseif linear_args[j] isa TracedRNumber
+                    new_arg = @opcall broadcast_in_dim(
+                        linear_args[j], Int64[], Int64[bsize]
+                    )
+                else
+                    new_arg = @opcall broadcast_in_dim(
+                        linear_args[j],
+                        collect(Int64, 2:(ndims(linear_args[j]) + 1)),
+                        Int64[bsize, size(linear_args[j])...],
+                    )
+                end
+                new_arg = @opcall transpose(new_arg, Int64[1, ((ndims(new_arg)):-1:2)...];)
+                push!(new_arguments, new_arg)
+            end
+
+            batched_res = @opcall batch(
+                new_arguments,
+                [
+                    Reactant.MLIR.IR.TensorType(
+                        Int64[bsize],
+                        Reactant.MLIR.IR.Type(
+                            Reactant.unwrapped_eltype(mlir_fn_res.linear_results[1])
+                        ),
+                    ),
+                ],
+                Int64[bsize];
+                fn=mlir_fn_res.f,
+            )
+            batched_res = only(batched_res)
+
+            if method isa Val{:central}
+                diff = batched_res[1:(bsize ÷ 2)] - batched_res[((bsize ÷ 2) + 1):end]
+                grad_res = diff ./ (2 * epsilon)
+            elseif method isa Val{:forward}
+                diff = batched_res[1:(end - 1)] .- batched_res[end:end]
+                grad_res = diff ./ epsilon
+            end
+
+            push!(gradient_result_map_path, TracedUtils.get_idx(arg, argprefix))
+            push!(
+                gradient_results,
+                ReactantCore.materialize_traced_array(reshape(grad_res, size(arg))),
+            )
+        end
+    end
+
+    results = deepcopy(args)
+    for (path, grad_res) in zip(gradient_result_map_path, gradient_results)
+        TracedUtils.set!(results, path[2:end], grad_res.mlir_data)
+    end
+    length(args) == 1 && return results[1]
+    return results
 end
 
 end
diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl
index 5ffc60c165..73dacadfb2 100644
--- a/src/TracedRArray.jl
+++ b/src/TracedRArray.jl
@@ -33,6 +33,15 @@ Base.convert(T::Type{<:TracedRArray}, x::AbstractArray) = Reactant.promote_to(T,
 Base.complex(x::TracedRArray{<:Real}) = complex.(x)
Base.complex(x::TracedRArray{<:Complex}) = x +function Base.deepcopy_internal(x::TracedRArray, stackdict::IdDict) + if haskey(stackdict, x) + return stackdict[x]::typeof(x) + end + y = copy(x) + stackdict[x] = y + return y +end + TracedRArray{T,N}(x::AbstractArray) where {T,N} = convert(TracedRArray{T,N}, x) function maybe_assert_scalar_setindexing( @@ -1109,7 +1118,7 @@ function Base.accumulate_pairwise!(op, A::AnyTracedRVector, B::AnyTracedRVector) return accumulate!(op, A, B; dims=1) end -if isdefined(Base, :_accumulate_promote_op) +@static if isdefined(Base, :_accumulate_promote_op) function Base._accumulate_promote_op(op, A::AnyTracedRArray{T}; init=nothing) where {T} if init !== nothing init isa TracedRNumber && (init = zero(unwrapped_eltype(init))) diff --git a/src/TracedRNumber.jl b/src/TracedRNumber.jl index 672c01a44b..8a5bb307ea 100644 --- a/src/TracedRNumber.jl +++ b/src/TracedRNumber.jl @@ -491,6 +491,7 @@ for (jlop, hloop) in ( (:(Base.log), :log), (:(Base.log1p), :log_plus_one), (:(Base.sqrt), :sqrt), + (:(Base.cbrt), :cbrt), (:(Base.acos), :acos), (:(Base.acosh), :acosh), (:(Base.asin), :asin), diff --git a/src/stdlibs/LinearAlgebra.jl b/src/stdlibs/LinearAlgebra.jl index 23e8cb4390..fa8285e630 100644 --- a/src/stdlibs/LinearAlgebra.jl +++ b/src/stdlibs/LinearAlgebra.jl @@ -273,7 +273,7 @@ function overloaded_mul!( return C end -if isdefined(LinearAlgebra, :_triu) +@static if isdefined(LinearAlgebra, :_triu) function LinearAlgebra._triu(A::AnyTracedRArray{T,2}, ::Val{true}, k::Integer) where {T} return overloaded_triu(materialize_traced_array(A), k) end @@ -284,7 +284,7 @@ if isdefined(LinearAlgebra, :_triu) end end -if isdefined(LinearAlgebra, :_tril) +@static if isdefined(LinearAlgebra, :_tril) function LinearAlgebra._tril(A::AnyTracedRArray{T,2}, ::Val{true}, k::Integer) where {T} return overloaded_tril(materialize_traced_array(A), k) end diff --git a/test/autodiff.jl b/test/autodiff.jl index 4629cb5d62..2c3f8a0fac 100644 --- a/test/autodiff.jl +++ b/test/autodiff.jl @@ -366,3 +366,44 @@ end @test @jit(jvp_vjp_cubic(v_r, x_r, lambdas_r)) ≈ fill(6, (3, 2)) end + +@testset "Finite Difference Gradient" begin + x = Reactant.to_rarray(Reactant.TestUtils.construct_test_array(Float16, 2, 2)) + res = @jit Reactant.TestUtils.finite_difference_gradient(sum, x) + @test res isa Reactant.ConcreteRArray{Float16,2} +end + +function fdiff_multiple_args(f, nt, x) + return sum(abs2, f(nt.y .+ x .- nt.x)) +end + +struct WrapperFunc{T} + x::T +end + +(f::WrapperFunc)(x) = x .^ 3 .+ f.x + +@testset "Finite Difference Gradient (non vector inputs)" begin + nt = (; + x=Reactant.TestUtils.construct_test_array(Float64, 3, 4), + y=Reactant.TestUtils.construct_test_array(Float64, 3, 4), + ) + fn = WrapperFunc(Reactant.TestUtils.construct_test_array(Float64, 3, 4)) + x = Reactant.TestUtils.construct_test_array(Float64, 3, 4) + + nt_ra = Reactant.to_rarray(nt) + fn_ra = Reactant.to_rarray(fn) + x_ra = Reactant.to_rarray(x) + + results_fd = @jit Reactant.TestUtils.finite_difference_gradient( + fdiff_multiple_args, fn_ra, nt_ra, x_ra + ) + @test results_fd isa typeof((fn_ra, nt_ra, x_ra)) + + results_enz = @jit Enzyme.gradient(Reverse, fdiff_multiple_args, fn_ra, nt_ra, x_ra) + + @test results_fd[1].x ≈ results_enz[1].x + @test results_fd[2].x ≈ results_enz[2].x + @test results_fd[2].y ≈ results_enz[2].y + @test results_fd[3] ≈ results_enz[3] +end diff --git a/test/nn/luxlib.jl b/test/nn/luxlib.jl index 135a83f8f9..c862b6497e 100644 --- a/test/nn/luxlib.jl +++ b/test/nn/luxlib.jl @@ -8,24 +8,34 @@ 
using LuxLib, Reactant, Enzyme, NNlib, Test return Enzyme.gradient(Reverse, sumabs2fuseddense, Const(act), weight, x, bias)[2:end] end - function ∇fuseddense_fd(act, weight, x, bias) - dw = Reactant.TestUtils.finite_difference_gradient( - w -> sumabs2fuseddense(act, w, x, bias), weight - ) - dx = Reactant.TestUtils.finite_difference_gradient( - x -> sumabs2fuseddense(act, weight, x, bias), x - ) - db = if bias === nothing - nothing + function ∇fuseddense(actgradfn, act, weight, x, bias) + z = weight * x + if bias !== nothing + z .+= bias + end + Ω = act.(z) + + δ = 2 .* actgradfn.(Ω, z) .* Ω + + ∂weight = δ * x' + if bias !== nothing + ∂bias = vec(sum(δ; dims=2)) else - Reactant.TestUtils.finite_difference_gradient( - b -> sumabs2fuseddense(act, weight, x, b), bias - ) + ∂bias = nothing end - return dw, dx, db + ∂x = weight' * δ + return ∂weight, ∂x, ∂bias end - @testset for act in (identity, relu, sigmoid, tanh, gelu), has_bias in (true, false) + @testset "Activation: $act | bias=$has_bias" for (act, gradfn) in ( + (identity, (Ω, x) -> one(Ω)), + (relu, (Ω, x) -> (Ω > 0)), + (sigmoid, (Ω, x) -> conj((1 - Ω) * Ω)), + (tanh, (Ω, x) -> conj(1 - Ω^2)), + (gelu, (Ω, x) -> NNlib.deriv_gelu_tanh(x)), + ), + has_bias in (true, false) + weight = Reactant.TestUtils.construct_test_array(Float32, 9, 10) x = Reactant.TestUtils.construct_test_array(Float32, 10, 12) bias = has_bias ? Reactant.TestUtils.construct_test_array(Float32, 9) : nothing @@ -35,9 +45,7 @@ using LuxLib, Reactant, Enzyme, NNlib, Test bias_ra = Reactant.to_rarray(bias) y_compile = @jit fused_dense_bias_activation(act, weight_ra, x_ra, bias_ra) - y_res = fused_dense_bias_activation(act, weight, x, bias) - @test y_res ≈ y_compile atol = 1e-5 rtol = 1e-2 @testset "Enzyme: fused_dense_bias_activation" begin @@ -45,11 +53,11 @@ using LuxLib, Reactant, Enzyme, NNlib, Test act, weight_ra, x_ra, bias_ra ) - dw_fd, dx_fd, db_fd = @jit ∇fuseddense_fd(act, weight, x, bias) + dw_gt, dx_gt, db_gt = @jit ∇fuseddense(gradfn, act, weight_ra, x_ra, bias_ra) - @test dw_fd ≈ dw_compile atol = 1e-5 rtol = 1e-2 - @test dx_fd ≈ dx_compile atol = 1e-5 rtol = 1e-2 - has_bias && @test db_fd ≈ db_compile atol = 1e-5 rtol = 1e-2 + @test dw_gt ≈ dw_compile atol = 1e-5 rtol = 1e-2 + @test dx_gt ≈ dx_compile atol = 1e-5 rtol = 1e-2 + has_bias && @test db_gt ≈ db_compile atol = 1e-5 rtol = 1e-2 end end end @@ -64,30 +72,25 @@ end return Enzyme.gradient(Reverse, sumabs2biasact, Const(act), x, b)[2:end] end - function ∇biasact_fd(act, x, b) - dx = Reactant.TestUtils.finite_difference_gradient( - x -> sumabs2biasact(act, x, b), x - ) - db = Reactant.TestUtils.finite_difference_gradient( - b -> sumabs2biasact(act, x, b), b - ) - return dx, db - end - function ∇biasact!!(act, x, b) return Enzyme.gradient(Reverse, sumabs2biasact!!, Const(act), x, b)[2:end] end - function ∇biasact!!_fd(act, x, b) - dx = Reactant.TestUtils.finite_difference_gradient( - x -> sumabs2biasact!!(act, x, b), x - ) - db = Reactant.TestUtils.finite_difference_gradient( - b -> sumabs2biasact!!(act, x, b), b - ) - return dx, db + + function ∇biasact(gradfn, act, x, b) + xb = x .+ b + Ω = act.(xb) + ∂x = 2 .* gradfn.(Ω, xb) .* Ω + ∂b = vec(sum(∂x; dims=2)) + return ∂x, ∂b end - @testset for act in (identity, relu, sigmoid, tanh, gelu) + @testset "Activation: $act" for (act, gradfn) in ( + (identity, (Ω, x) -> one(Ω)), + (relu, (Ω, x) -> (Ω > 0)), + (sigmoid, (Ω, x) -> conj((1 - Ω) * Ω)), + (tanh, (Ω, x) -> conj(1 - Ω^2)), + (gelu, (Ω, x) -> NNlib.deriv_gelu_tanh(x)), + ) x = 
Reactant.TestUtils.construct_test_array(Float32, 10, 10) b = Reactant.TestUtils.construct_test_array(Float32, 10) @@ -103,20 +106,20 @@ end @test y_simple ≈ y_compile atol = 1e-5 rtol = 1e-2 @test y_simple!! ≈ y_compile!! atol = 1e-5 rtol = 1e-2 + ∂x_gt, ∂b_gt = @jit ∇biasact(gradfn, act, x_ra, b_ra) + @testset "Enzyme: bias_activation" begin - ∂x_enz, ∂b_enz = @jit ∇biasact_fd(act, x, b) - ∂x_compile, ∂b_compile = @jit ∇biasact(act, x_ra, b_ra) + ∂x_enz, ∂b_enz = @jit ∇biasact(act, x_ra, b_ra) - @test ∂x_enz ≈ ∂x_compile atol = 1e-5 rtol = 1e-2 - @test ∂b_enz ≈ ∂b_compile atol = 1e-5 rtol = 1e-2 + @test ∂x_enz ≈ ∂x_gt atol = 1e-5 rtol = 1e-2 + @test ∂b_enz ≈ ∂b_gt atol = 1e-5 rtol = 1e-2 end @testset "Enzyme: bias_activation!!" begin - ∂x_enz!!, ∂b_enz!! = @jit ∇biasact!!_fd(act, x, b) - ∂x_compile!!, ∂b_compile!! = @jit ∇biasact!!(act, x_ra, b_ra) + ∂x_enz!!, ∂b_enz!! = @jit ∇biasact!!(act, x_ra, b_ra) - @test ∂x_enz!! ≈ ∂x_compile!! atol = 1e-5 rtol = 1e-2 - @test ∂b_enz!! ≈ ∂b_compile!! atol = 1e-5 rtol = 1e-2 + @test ∂x_enz!! ≈ ∂x_gt atol = 1e-5 rtol = 1e-2 + @test ∂b_enz!! ≈ ∂b_gt atol = 1e-5 rtol = 1e-2 end end end @@ -127,16 +130,26 @@ end sumabs2!!(f, x) = sum(abs2, fast_activation!!(f, copy(x))) ∇sumabs2(f, x) = Enzyme.gradient(Reverse, sumabs2, Const(f), x)[2] - ∇sumabs2_fd(f, x) = Reactant.TestUtils.finite_difference_gradient(x -> sumabs2(f, x), x) ∇sumabs2!!(f, x) = Enzyme.gradient(Reverse, sumabs2!!, Const(f), x)[2] - ∇sumabs2!!_fd(f, x) = - Reactant.TestUtils.finite_difference_gradient(x -> sumabs2!!(f, x), x) + + function ∇sumabs2(gradfn, f, x) + Ω = f.(x) + return 2 .* gradfn.(Ω, x) .* Ω + end x_act = Reactant.TestUtils.construct_test_array(Float32, 10, 10) x_act_ca = Reactant.to_rarray(x_act) - @testset "Activation: $act" for act in ( - identity, relu, sigmoid, tanh, tanh_fast, sigmoid_fast, gelu, abs2 + @testset "Activation: $act" for (act, gradfn) in ( + (identity, (Ω, x) -> one(Ω)), + (relu, (Ω, x) -> (Ω > 0)), + (sigmoid, (Ω, x) -> conj((1 - Ω) * Ω)), + (tanh, (Ω, x) -> conj(1 - Ω^2)), + (tanh_fast, (Ω, x) -> conj(1 - Ω^2)), + (sigmoid_fast, (Ω, x) -> conj((1 - Ω) * Ω)), + (gelu, (Ω, x) -> NNlib.deriv_gelu_tanh(x)), + (abs2, (Ω, x) -> (2 * x)), + (relu6, (Ω, x) -> (Ω > 0) & (Ω < 6)), ) y_simple = sumabs2(act, x_act) y_simple!! = sumabs2!!(act, x_act) @@ -146,13 +159,12 @@ end @test y_simple ≈ y_compile atol = 1e-5 rtol = 1e-2 @test y_simple!! ≈ y_compile!! atol = 1e-5 rtol = 1e-2 - ∂x_enz = @jit ∇sumabs2_fd(act, x_act) - ∂x_compile = @jit ∇sumabs2(act, x_act_ca) - ∂x_enz!! = @jit ∇sumabs2!!_fd(act, x_act) - ∂x_compile!! = @jit ∇sumabs2!!(act, x_act_ca) + ∂x_enz = @jit ∇sumabs2(act, x_act_ca) + ∂x_enz!! = @jit ∇sumabs2!!(act, x_act_ca) + ∂x_gt = @jit ∇sumabs2(gradfn, act, x_act_ca) - @test ∂x_enz ≈ ∂x_compile atol = 1e-5 rtol = 1e-2 - @test ∂x_enz!! ≈ ∂x_compile!! atol = 1e-5 rtol = 1e-2 + @test ∂x_enz ≈ ∂x_gt atol = 1e-5 rtol = 1e-2 + @test ∂x_enz!! 
≈ ∂x_gt atol = 1e-5 rtol = 1e-2 end end diff --git a/test/nn/nnlib.jl b/test/nn/nnlib.jl index ba1dcfd2f5..17fb16a7e7 100644 --- a/test/nn/nnlib.jl +++ b/test/nn/nnlib.jl @@ -3,22 +3,33 @@ using NNlib, Reactant, Enzyme, Statistics, Test @testset "Activation Functions" begin sumabs2(f, x) = sum(abs2, f.(x)) - x_act = Reactant.TestUtils.construct_test_array(Float32, 10, 10) + function ∇sumabs2(gradfn, f, x) + Ω = f.(x) + return 2 .* gradfn.(Ω, x) .* Ω + end + + x_act = Reactant.TestUtils.construct_test_array(Float32, 10, 10) .- 0.5f0 x_act_ca = Reactant.to_rarray(x_act) - @testset "Activation: $act" for act in ( - identity, relu, sigmoid, tanh, tanh_fast, sigmoid_fast, gelu, abs2, relu6 + @testset "Activation: $act" for (act, gradfn) in ( + (identity, (Ω, x) -> one(Ω)), + (relu, (Ω, x) -> (Ω > 0)), + (sigmoid, (Ω, x) -> conj((1 - Ω) * Ω)), + (tanh, (Ω, x) -> conj(1 - Ω^2)), + (tanh_fast, (Ω, x) -> conj(1 - Ω^2)), + (sigmoid_fast, (Ω, x) -> conj((1 - Ω) * Ω)), + (gelu, (Ω, x) -> NNlib.deriv_gelu_tanh(x)), + (abs2, (Ω, x) -> (2 * x)), + (relu6, (Ω, x) -> (Ω > 0) & (Ω < 6)), ) y_simple = sumabs2(act, x_act) y_compile = @jit sumabs2(act, x_act_ca) ∂x_compile = @jit(Enzyme.gradient(Reverse, sumabs2, Const(act), x_act_ca))[2] - ∂x_compile_fd = @jit Reactant.TestUtils.finite_difference_gradient( - Base.Fix1(sumabs2, act), x_act_ca - ) + ∂x_compile_gt = @jit ∇sumabs2(gradfn, act, x_act_ca) @test y_simple ≈ y_compile - @test ∂x_compile ≈ ∂x_compile_fd + @test ∂x_compile ≈ ∂x_compile_gt atol = 1e-3 rtol = 1e-3 end end