@@ -26,15 +26,6 @@ Return the maximum of `a` and `b` if `x1 > x0`, otherwise return the minimum.
2626"""
2727__max_tdir (a, b, x0, x1) = ifelse (x1 > x0, max (a, b), min (a, b))
2828
29- __cvt_real (:: Type{T} , :: Nothing ) where {T} = nothing
30- __cvt_real (:: Type{T} , x) where {T} = real (T (x))
31-
32- _get_tolerance (η, :: Type{T} ) where {T} = __cvt_real (T, η)
33- function _get_tolerance (:: Nothing , :: Type{T} ) where {T}
34- η = real (oneunit (T)) * (eps (real (one (T))))^ (4 // 5 )
35- return _get_tolerance (η, T)
36- end
37-
3829__standard_tag (:: Nothing , x) = ForwardDiff. Tag (SimpleNonlinearSolveTag (), eltype (x))
3930__standard_tag (tag:: ForwardDiff.Tag , _) = tag
4031__standard_tag (tag, x) = ForwardDiff. Tag (tag, eltype (x))
@@ -60,6 +51,12 @@ function __get_jacobian_config(ad::AutoForwardDiff{CS}, f!, y, x) where {CS}
6051 return ForwardDiff. JacobianConfig (f!, y, x, ck, tag)
6152end
6253
54+ function __get_jacobian_config (ad:: AutoPolyesterForwardDiff{CS} , args... ) where {CS}
55+ x = last (args)
56+ return (CS === nothing || CS ≤ 0 ) ? __pick_forwarddiff_chunk (x) :
57+ ForwardDiff. Chunk {CS} ()
58+ end
59+
6360"""
6461 value_and_jacobian(ad, f, y, x, p, cache; J = nothing)
6562
@@ -81,6 +78,9 @@ function value_and_jacobian(ad, f::F, y, x::X, p, cache; J = nothing) where {F,
8178 FiniteDiff. finite_difference_jacobian! (J, _f, x, cache)
8279 _f (y, x)
8380 return y, J
81+ elseif ad isa AutoPolyesterForwardDiff
82+ __polyester_forwarddiff_jacobian! (_f, y, J, x, cache)
83+ return y, J
8484 else
8585 throw (ArgumentError (" Unsupported AD method: $(ad) " ))
8686 end
@@ -100,12 +100,18 @@ function value_and_jacobian(ad, f::F, y, x::X, p, cache; J = nothing) where {F,
100100 elseif ad isa AutoFiniteDiff
101101 J_fd = FiniteDiff. finite_difference_jacobian (_f, x, cache)
102102 return _f (x), J_fd
103+ elseif ad isa AutoPolyesterForwardDiff
104+ __polyester_forwarddiff_jacobian! (_f, J, x, cache)
105+ return _f (x), J
103106 else
104107 throw (ArgumentError (" Unsupported AD method: $(ad) " ))
105108 end
106109 end
107110end
108111
112+ # Declare functions
113+ function __polyester_forwarddiff_jacobian! end
114+
109115function value_and_jacobian (ad, f:: F , y, x:: Number , p, cache; J = nothing ) where {F}
110116 if DiffEqBase. has_jac (f)
111117 return f (x, p), f. jac (x, p)
@@ -132,7 +138,7 @@ function jacobian_cache(ad, f::F, y, x::X, p) where {F, X <: AbstractArray}
132138 J = similar (y, length (y), length (x))
133139 if DiffEqBase. has_jac (f)
134140 return J, nothing
135- elseif ad isa AutoForwardDiff
141+ elseif ad isa AutoForwardDiff || ad isa AutoPolyesterForwardDiff
136142 return J, __get_jacobian_config (ad, _f, y, x)
137143 elseif ad isa AutoFiniteDiff
138144 return J, FiniteDiff. JacobianCache (copy (x), copy (y), copy (y), ad. fdtype)
@@ -146,6 +152,10 @@ function jacobian_cache(ad, f::F, y, x::X, p) where {F, X <: AbstractArray}
146152 elseif ad isa AutoForwardDiff
147153 J = ArrayInterface. can_setindex (x) ? similar (y, length (y), length (x)) : nothing
148154 return J, __get_jacobian_config (ad, _f, x)
155+ elseif ad isa AutoPolyesterForwardDiff
156+ @assert ArrayInterface. can_setindex (x) " PolyesterForwardDiff requires mutable inputs."
157+ J = similar (y, length (y), length (x))
158+ return J, __get_jacobian_config (ad, _f, x)
149159 elseif ad isa AutoFiniteDiff
150160 return nothing , FiniteDiff. JacobianCache (copy (x), copy (y), copy (y), ad. fdtype)
151161 else
350360 (alias || ! ArrayInterface. can_setindex (typeof (x))) && return x
351361 return deepcopy (x)
352362end
363+
364+ # Decide which AD backend to use
365+ @inline __get_concrete_autodiff (prob, ad:: ADTypes.AbstractADType ) = ad
366+ @inline function __get_concrete_autodiff (prob, :: Nothing )
367+ if ForwardDiff. can_dual (eltype (prob. u0))
368+ if __is_extension_loaded (Val (:PolyesterForwardDiff )) && ! (prob. u0 isa Number) &&
369+ ArrayInterface. can_setindex (prob. u0)
370+ return AutoPolyesterForwardDiff ()
371+ else
372+ return AutoForwardDiff ()
373+ end
374+ else
375+ return AutoFiniteDiff ()
376+ end
377+ end
0 commit comments