Skip to content

Commit 74f08fe

Browse files
committed
fix: use better epsilon
1 parent e1ae169 commit 74f08fe

File tree

3 files changed

+17
-2
lines changed

3 files changed

+17
-2
lines changed

src/Compiler.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -702,9 +702,10 @@ function optimization_passes(
702702
dus_to_concat::Bool=false,
703703
recognize_comms::Bool=true,
704704
lower_comms::Bool=true,
705-
max_constant_threshold::Int=1024,
706705
backend::String="gpu",
707706
)
707+
(; max_constant_threshold) = compile_options
708+
708709
transform_passes_list = [
709710
"patterns=compare_op_canon<16>",
710711
"transpose_transpose<16>",

src/TestUtils.jl

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,21 @@ function construct_test_array(::Type{T}, dims::Int...) where {T}
2020
return reshape(collect(T, 1:prod(dims)), dims...)
2121
end
2222

23+
# https://github.com/JuliaDiff/FiniteDiff.jl/blob/3a8c3d8d87e59de78e2831787a3f54b12b7c2075/src/epsilons.jl#L133
24+
function default_epslion(::Val{fdtype}, ::Type{T}) where {fdtype,T}
25+
if fdtype == :forward
26+
return sqrt(eps(real(T)))
27+
elseif fdtype == :central
28+
return cbrt(eps(real(T)))
29+
elseif fdtype == :hcentral
30+
return eps(T)^(T(1 / 4))
31+
else
32+
return one(real(T))
33+
end
34+
end
35+
2336
function finite_difference_gradient(
24-
f, x::AbstractArray{T}; epsilon=eps(T)^(T(3 / 4))
37+
f, x::AbstractArray{T}; epsilon=default_epslion(Val(:central), T)
2538
) where {T}
2639
onehot_matrix = Reactant.promote_to(
2740
TracedRArray{Reactant.unwrapped_eltype(T),2},

src/TracedRNumber.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -491,6 +491,7 @@ for (jlop, hloop) in (
491491
(:(Base.log), :log),
492492
(:(Base.log1p), :log_plus_one),
493493
(:(Base.sqrt), :sqrt),
494+
(:(Base.cbrt), :cbrt),
494495
(:(Base.acos), :acos),
495496
(:(Base.acosh), :acosh),
496497
(:(Base.asin), :asin),

0 commit comments

Comments
 (0)