Skip to content

Commit 9afcb9a

Browse files
committed
fix: finite diff gradient accidental type promotion
1 parent 06778ec commit 9afcb9a

File tree

2 files changed

+10
-3
lines changed

2 files changed

+10
-3
lines changed

src/TestUtils.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,13 @@ function construct_test_array(::Type{T}, dims::Int...) where {T}
2121
end
2222

2323
function finite_difference_gradient(
24-
f, x::AbstractArray{T}; epsilon=eps(T)^(3 / 4)
24+
f, x::AbstractArray{T}; epsilon=eps(T)^(T(3 / 4))
2525
) where {T}
2626
onehot_matrix = Reactant.promote_to(
27-
TracedRArray{Reactant.unwrapped_eltype(T),2}, LinearAlgebra.I(length(x))
27+
TracedRArray{Reactant.unwrapped_eltype(T),2},
28+
LinearAlgebra.Diagonal(fill(epsilon, length(x))),
2829
)
29-
perturbation = reshape(onehot_matrix .* epsilon, size(x)..., length(x))
30+
perturbation = reshape(onehot_matrix, size(x)..., length(x))
3031
f_input = cat(x .+ perturbation, x .- perturbation; dims=ndims(x) + 1)
3132

3233
f_evaluated = mapslices(f, f_input; dims=ntuple(identity, ndims(x)))

test/autodiff.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -366,3 +366,9 @@ end
366366

367367
@test @jit(jvp_vjp_cubic(v_r, x_r, lambdas_r)) fill(6, (3, 2))
368368
end
369+
370+
@testset "Finite Difference Gradient" begin
371+
x = Reactant.to_rarray(Reactant.TestUtils.construct_test_array(Float16, 2, 2))
372+
res = @jit Reactant.TestUtils.finite_difference_gradient(sum, x)
373+
@test res isa Reactant.ConcreteRArray{Float16,2}
374+
end

0 commit comments

Comments
 (0)