3 changes: 2 additions & 1 deletion src/Compiler.jl
@@ -702,9 +702,10 @@ function optimization_passes(
dus_to_concat::Bool=false,
recognize_comms::Bool=true,
lower_comms::Bool=true,
max_constant_threshold::Int=1024,
backend::String="gpu",
)
(; max_constant_threshold) = compile_options

transform_passes_list = [
"patterns=compare_op_canon<16>",
"transpose_transpose<16>",
161 changes: 147 additions & 14 deletions src/TestUtils.jl
@@ -1,6 +1,7 @@
module TestUtils

using ..Reactant: Reactant, TracedRArray
using ..Reactant: Reactant, TracedRArray, TracedRNumber, TracedUtils
using Reactant.Ops: @opcall
using ReactantCore: ReactantCore
using LinearAlgebra: LinearAlgebra

@@ -20,22 +21,154 @@ function construct_test_array(::Type{T}, dims::Int...) where {T}
return reshape(collect(T, 1:prod(dims)), dims...)
end

function finite_difference_gradient(
f, x::AbstractArray{T}; epsilon=eps(T)^(3 / 4)
) where {T}
# https://github.com/JuliaDiff/FiniteDiff.jl/blob/3a8c3d8d87e59de78e2831787a3f54b12b7c2075/src/epsilons.jl#L133
function default_epslion(::Val{fdtype}, ::Type{T}) where {fdtype,T}
if fdtype == :forward
return sqrt(eps(real(T)))
elseif fdtype == :central
return cbrt(eps(real(T)))
elseif fdtype == :hcentral
return eps(T)^(T(1 / 4))
else
return one(real(T))
end
end

function get_perturbation(x::AbstractArray{T}, epsilon) where {T}
onehot_matrix = Reactant.promote_to(
TracedRArray{Reactant.unwrapped_eltype(T),2}, LinearAlgebra.I(length(x))
TracedRArray{Reactant.unwrapped_eltype(T),2},
LinearAlgebra.Diagonal(fill(epsilon, length(x)));
)
return permutedims(
reshape(onehot_matrix, size(x)..., length(x)), (ndims(x) + 1, 1:(ndims(x))...)
)
end

function generate_perturbed_array(::Val{:central}, x::AbstractArray{T}, epsilon) where {T}
perturbation = get_perturbation(x, epsilon)
x_ = reshape(x, 1, size(x)...)
return cat(x_ .+ perturbation, x_ .- perturbation; dims=1)
end

function generate_perturbed_array(::Val{:forward}, x::AbstractArray{T}, epsilon) where {T}
perturbation = get_perturbation(x, epsilon)
x_ = reshape(x, 1, size(x)...)
return cat(x_ .+ perturbation, x_; dims=1)
end

function finite_difference_gradient(
f::F, args...; method::Union{Val{:central},Val{:forward}}=Val(:central)
) where {F}
argprefix = gensym("finitediffarg")
resprefix = gensym("finitediffresult")
resargprefix = gensym("finitediffresarg")

# TODO: can we detect and prevent using functions that mutate their arguments?
mlir_fn_res = TracedUtils.make_mlir_fn(
[Review thread on the make_mlir_fn call]
Member: do we need to do the full make_mlir_fn here, or can we reuse the equivalent of traced call or traced functions to achieve the same effect?
Collaborator (author): I am using make_mlir_fn mostly to linearize the arguments; not sure how to use traced_call here, though.
Member: ah, so it was linearization, not efficiency.

f,
args,
(),
"finite_difference_gradient_fn",
false;
args_in_result=:none,
argprefix,
resprefix,
resargprefix,
)
perturbation = reshape(onehot_matrix .* epsilon, size(x)..., length(x))
f_input = cat(x .+ perturbation, x .- perturbation; dims=ndims(x) + 1)

f_evaluated = mapslices(f, f_input; dims=ntuple(identity, ndims(x)))
return ReactantCore.materialize_traced_array(
reshape(
(f_evaluated[1:length(x)] - f_evaluated[(length(x) + 1):end]) ./ (2 * epsilon),
size(x),
),

seenargs = Reactant.OrderedIdDict()
Reactant.make_tracer(seenargs, f, (argprefix,), Reactant.TracedSetPath)
for (i, arg) in enumerate(args)
Reactant.make_tracer(seenargs, arg, (argprefix, i), Reactant.TracedSetPath)
end

linear_args = Reactant.TracedType[]
for (k, v) in seenargs
v isa Reactant.TracedType || continue
push!(linear_args, v)
end

if (
length(mlir_fn_res.linear_results) != 1 ||
!(mlir_fn_res.linear_results[1] isa TracedRNumber)
)
error("`finite_difference_gradient` only supports functions with a single scalar \
output. Received : $(mlir_fn_res.linear_results)")
end

gradient_results = TracedRArray[]
gradient_result_map_path = []
for i in 1:length(linear_args)
arg = linear_args[i]
if arg isa TracedRArray && TracedUtils.has_idx(arg, argprefix)
path = TracedUtils.get_idx(arg, argprefix)
if mlir_fn_res.fnwrapped && length(path) > 1 && path[2] == 1
continue
end

# We need the gradient wrt this argument
# we will naively insert the args here, cse will take care of the rest
new_arguments = TracedRArray[]

epsilon = default_epslion(method, Reactant.unwrapped_eltype(arg))
pertubed_arg = generate_perturbed_array(method, arg, epsilon)

bsize = size(pertubed_arg, 1)
for j in 1:length(linear_args)
if i == j
new_arg = pertubed_arg
elseif linear_args[j] isa TracedRNumber
new_arg = @opcall broadcast_in_dim(
linear_args[j], Int64[], Int64[bsize]
)
else
new_arg = @opcall broadcast_in_dim(
linear_args[j],
collect(Int64, 2:(ndims(linear_args[j]) + 1)),
Int64[bsize, size(linear_args[j])...],
)
end
new_arg = @opcall transpose(new_arg, Int64[1, ((ndims(new_arg)):-1:2)...];)
push!(new_arguments, new_arg)
end

batched_res = @opcall batch(
new_arguments,
[
Reactant.MLIR.IR.TensorType(
Int64[bsize],
Reactant.MLIR.IR.Type(
Reactant.unwrapped_eltype(mlir_fn_res.linear_results[1])
),
),
],
Int64[bsize];
fn=mlir_fn_res.f,
)
batched_res = only(batched_res)

if method isa Val{:central}
diff = batched_res[1:(bsize ÷ 2)] - batched_res[((bsize ÷ 2) + 1):end]
grad_res = diff ./ (2 * epsilon)
elseif method isa Val{:forward}
diff = batched_res[1:(end - 1)] .- batched_res[end:end]
grad_res = diff ./ epsilon
end

push!(gradient_result_map_path, TracedUtils.get_idx(arg, argprefix))
push!(
gradient_results,
ReactantCore.materialize_traced_array(reshape(grad_res, size(arg))),
)
end
end

results = deepcopy(args)
for (path, grad_res) in zip(gradient_result_map_path, gradient_results)
TracedUtils.set!(results, path[2:end], grad_res.mlir_data)
end
length(args) == 1 && return results[1]
return results
end

end
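
For reference, the helper added above implements the standard one-hot finite-difference scheme: each entry of every array argument is perturbed by ±epsilon (central) or +epsilon (forward), the perturbed copies are evaluated, and the differences of the scalar outputs are divided by the step size. Below is a minimal plain-Julia sketch of the central variant, with no Reactant tracing or batching; the function and variable names are illustrative only and are not part of this PR, and the default step mirrors the :central branch of default_epslion.

# Reference central finite-difference gradient for a scalar-valued f(x::AbstractArray).
# Illustrative sketch only; epsilon = cbrt(eps(T)) matches the :central case above.
function reference_central_fd_gradient(f, x::AbstractArray{T}; epsilon=cbrt(eps(T))) where {T}
    grad = similar(x, float(T))
    for i in eachindex(x)
        xp = copy(x); xp[i] += epsilon   # x + epsilon * e_i
        xm = copy(x); xm[i] -= epsilon   # x - epsilon * e_i
        grad[i] = (f(xp) - f(xm)) / (2 * epsilon)
    end
    return grad
end

# Example: reference_central_fd_gradient(sum, ones(2, 2)) ≈ ones(2, 2)

The traced helper above arrives at the same values without an explicit loop: it stacks all perturbed inputs along a new leading dimension and evaluates them through a single @opcall batch of the MLIR function produced by make_mlir_fn.
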
11 changes: 10 additions & 1 deletion src/TracedRArray.jl
@@ -33,6 +33,15 @@ Base.convert(T::Type{<:TracedRArray}, x::AbstractArray) = Reactant.promote_to(T, x)
Base.complex(x::TracedRArray{<:Real}) = complex.(x)
Base.complex(x::TracedRArray{<:Complex}) = x

function Base.deepcopy_internal(x::TracedRArray, stackdict::IdDict)
if haskey(stackdict, x)
return stackdict[x]::typeof(x)
end
y = copy(x)
stackdict[x] = y
return y
end

TracedRArray{T,N}(x::AbstractArray) where {T,N} = convert(TracedRArray{T,N}, x)

function maybe_assert_scalar_setindexing(
@@ -1109,7 +1118,7 @@ function Base.accumulate_pairwise!(op, A::AnyTracedRVector, B::AnyTracedRVector)
return accumulate!(op, A, B; dims=1)
end

if isdefined(Base, :_accumulate_promote_op)
@static if isdefined(Base, :_accumulate_promote_op)
function Base._accumulate_promote_op(op, A::AnyTracedRArray{T}; init=nothing) where {T}
if init !== nothing
init isa TracedRNumber && (init = zero(unwrapped_eltype(init)))
1 change: 1 addition & 0 deletions src/TracedRNumber.jl
@@ -491,6 +491,7 @@ for (jlop, hloop) in (
(:(Base.log), :log),
(:(Base.log1p), :log_plus_one),
(:(Base.sqrt), :sqrt),
(:(Base.cbrt), :cbrt),
(:(Base.acos), :acos),
(:(Base.acosh), :acosh),
(:(Base.asin), :asin),
4 changes: 2 additions & 2 deletions src/stdlibs/LinearAlgebra.jl
@@ -273,7 +273,7 @@ function overloaded_mul!(
return C
end

if isdefined(LinearAlgebra, :_triu)
@static if isdefined(LinearAlgebra, :_triu)
function LinearAlgebra._triu(A::AnyTracedRArray{T,2}, ::Val{true}, k::Integer) where {T}
return overloaded_triu(materialize_traced_array(A), k)
end
@@ -284,7 +284,7 @@ if isdefined(LinearAlgebra, :_triu)
end
end

if isdefined(LinearAlgebra, :_tril)
@static if isdefined(LinearAlgebra, :_tril)
function LinearAlgebra._tril(A::AnyTracedRArray{T,2}, ::Val{true}, k::Integer) where {T}
return overloaded_tril(materialize_traced_array(A), k)
end
41 changes: 41 additions & 0 deletions test/autodiff.jl
@@ -366,3 +366,44 @@ end

@test @jit(jvp_vjp_cubic(v_r, x_r, lambdas_r)) ≈ fill(6, (3, 2))
end

@testset "Finite Difference Gradient" begin
x = Reactant.to_rarray(Reactant.TestUtils.construct_test_array(Float16, 2, 2))
res = @jit Reactant.TestUtils.finite_difference_gradient(sum, x)
@test res isa Reactant.ConcreteRArray{Float16,2}
end

function fdiff_multiple_args(f, nt, x)
return sum(abs2, f(nt.y .+ x .- nt.x))
end

struct WrapperFunc{T}
x::T
end

(f::WrapperFunc)(x) = x .^ 3 .+ f.x

@testset "Finite Difference Gradient (non vector inputs)" begin
nt = (;
x=Reactant.TestUtils.construct_test_array(Float64, 3, 4),
y=Reactant.TestUtils.construct_test_array(Float64, 3, 4),
)
fn = WrapperFunc(Reactant.TestUtils.construct_test_array(Float64, 3, 4))
x = Reactant.TestUtils.construct_test_array(Float64, 3, 4)

nt_ra = Reactant.to_rarray(nt)
fn_ra = Reactant.to_rarray(fn)
x_ra = Reactant.to_rarray(x)

results_fd = @jit Reactant.TestUtils.finite_difference_gradient(
fdiff_multiple_args, fn_ra, nt_ra, x_ra
)
@test results_fd isa typeof((fn_ra, nt_ra, x_ra))

results_enz = @jit Enzyme.gradient(Reverse, fdiff_multiple_args, fn_ra, nt_ra, x_ra)

@test results_fd[1].x ≈ results_enz[1].x
@test results_fd[2].x ≈ results_enz[2].x
@test results_fd[2].y ≈ results_enz[2].y
@test results_fd[3] ≈ results_enz[3]
end
41 changes: 11 additions & 30 deletions test/nn/luxlib.jl
@@ -9,20 +9,9 @@ using LuxLib, Reactant, Enzyme, NNlib, Test
end

function ∇fuseddense_fd(act, weight, x, bias)
dw = Reactant.TestUtils.finite_difference_gradient(
w -> sumabs2fuseddense(act, w, x, bias), weight
return Reactant.TestUtils.finite_difference_gradient(
(w, x, b) -> sumabs2fuseddense(act, w, x, b), weight, x, bias
)
dx = Reactant.TestUtils.finite_difference_gradient(
x -> sumabs2fuseddense(act, weight, x, bias), x
)
db = if bias === nothing
nothing
else
Reactant.TestUtils.finite_difference_gradient(
b -> sumabs2fuseddense(act, weight, x, b), bias
)
end
return dw, dx, db
end

@testset for act in (identity, relu, sigmoid, tanh, gelu), has_bias in (true, false)
@@ -45,7 +34,7 @@
act, weight_ra, x_ra, bias_ra
)

dw_fd, dx_fd, db_fd = @jit ∇fuseddense_fd(act, weight, x, bias)
dw_fd, dx_fd, db_fd = @jit ∇fuseddense_fd(act, weight_ra, x_ra, bias_ra)

@test dw_fd ≈ dw_compile atol = 1e-5 rtol = 1e-2
@test dx_fd ≈ dx_compile atol = 1e-5 rtol = 1e-2
@@ -65,26 +54,18 @@ end
end

function ∇biasact_fd(act, x, b)
dx = Reactant.TestUtils.finite_difference_gradient(
x -> sumabs2biasact(act, x, b), x
return Reactant.TestUtils.finite_difference_gradient(
(x, b) -> sumabs2biasact(act, x, b), x, b
)
db = Reactant.TestUtils.finite_difference_gradient(
b -> sumabs2biasact(act, x, b), b
)
return dx, db
end

function ∇biasact!!(act, x, b)
return Enzyme.gradient(Reverse, sumabs2biasact!!, Const(act), x, b)[2:end]
end
function ∇biasact!!_fd(act, x, b)
dx = Reactant.TestUtils.finite_difference_gradient(
x -> sumabs2biasact!!(act, x, b), x
)
db = Reactant.TestUtils.finite_difference_gradient(
b -> sumabs2biasact!!(act, x, b), b
return Reactant.TestUtils.finite_difference_gradient(
(x, b) -> sumabs2biasact!!(act, x, b), x, b
)
return dx, db
end

@testset for act in (identity, relu, sigmoid, tanh, gelu)
@@ -104,15 +85,15 @@ end
@test y_simple!! ≈ y_compile!! atol = 1e-5 rtol = 1e-2

@testset "Enzyme: bias_activation" begin
∂x_enz, ∂b_enz = @jit ∇biasact_fd(act, x, b)
∂x_enz, ∂b_enz = @jit ∇biasact_fd(act, x_ra, b_ra)
∂x_compile, ∂b_compile = @jit ∇biasact(act, x_ra, b_ra)

@test ∂x_enz ≈ ∂x_compile atol = 1e-5 rtol = 1e-2
@test ∂b_enz ≈ ∂b_compile atol = 1e-5 rtol = 1e-2
end

@testset "Enzyme: bias_activation!!" begin
∂x_enz!!, ∂b_enz!! = @jit ∇biasact!!_fd(act, x, b)
∂x_enz!!, ∂b_enz!! = @jit ∇biasact!!_fd(act, x_ra, b_ra)
∂x_compile!!, ∂b_compile!! = @jit ∇biasact!!(act, x_ra, b_ra)

@test ∂x_enz!! ≈ ∂x_compile!! atol = 1e-5 rtol = 1e-2
@@ -146,9 +127,9 @@
@test y_simple ≈ y_compile atol = 1e-5 rtol = 1e-2
@test y_simple!! ≈ y_compile!! atol = 1e-5 rtol = 1e-2

∂x_enz = @jit ∇sumabs2_fd(act, x_act)
∂x_enz = @jit ∇sumabs2_fd(act, x_act_ca)
∂x_compile = @jit ∇sumabs2(act, x_act_ca)
∂x_enz!! = @jit ∇sumabs2!!_fd(act, x_act)
∂x_enz!! = @jit ∇sumabs2!!_fd(act, x_act_ca)
∂x_compile!! = @jit ∇sumabs2!!(act, x_act_ca)

@test ∂x_enz ≈ ∂x_compile atol = 1e-5 rtol = 1e-2
4 changes: 2 additions & 2 deletions test/nn/nnlib.jl
@@ -3,7 +3,7 @@ using NNlib, Reactant, Enzyme, Statistics, Test
@testset "Activation Functions" begin
sumabs2(f, x) = sum(abs2, f.(x))

x_act = Reactant.TestUtils.construct_test_array(Float32, 10, 10)
x_act = Reactant.TestUtils.construct_test_array(Float32, 10, 10) .- 0.5f0
x_act_ca = Reactant.to_rarray(x_act)

@testset "Activation: $act" for act in (
@@ -18,7 +18,7 @@
)

@test y_simple ≈ y_compile
@test ∂x_compile ≈ ∂x_compile_fd
@test ∂x_compile ≈ ∂x_compile_fd atol = 1e-3 rtol = 1e-3
end
end
