@@ -97,10 +97,20 @@ function JacobianCache(
9797 J, f, uf, fu, u, p, jac_cache, alg, 0 , autodiff, vjp_autodiff, jvp_autodiff)
9898end
9999
100- function JacobianCache (prob, alg, f:: F , :: Number , u:: Number , p; kwargs... ) where {F}
100+ function JacobianCache (
101+ prob, alg, f:: F , :: Number , u:: Number , p; autodiff = nothing , kwargs... ) where {F}
101102 uf = JacobianWrapper {false} (f, p)
103+ autodiff = get_concrete_forward_ad (autodiff, prob; check_reverse_mode = false )
104+ if ! (autodiff isa AutoForwardDiff ||
105+ autodiff isa AutoPolyesterForwardDiff ||
106+ autodiff isa AutoFiniteDiff)
107+ autodiff = AutoFiniteDiff ()
108+ # Other cases are not properly supported so we fallback to finite differencing
109+ @warn " Scalar AD is supported only for AutoForwardDiff and AutoFiniteDiff. \
110+ Detected $(autodiff) . Falling back to AutoFiniteDiff."
111+ end
102112 return JacobianCache {false} (
103- u, f, uf, u, u, p, nothing , alg, 0 , nothing , nothing , nothing )
113+ u, f, uf, u, u, p, nothing , alg, 0 , autodiff , nothing , nothing )
104114end
105115
106116@inline (cache:: JacobianCache )(u = cache. u) = cache (cache. J, u, cache. p)
@@ -115,7 +125,7 @@ function (cache::JacobianCache)(J::JacobianOperator, u, p = cache.p)
115125end
116126function (cache:: JacobianCache )(:: Number , u, p = cache. p) # Scalar
117127 cache. njacs += 1
118- J = last (__value_derivative (cache. uf, u))
128+ J = last (__value_derivative (cache. autodiff, cache . uf, u))
119129 return J
120130end
121131# Compute the Jacobian
@@ -181,12 +191,17 @@ end
181191 end
182192end
183193
184- @inline function __value_derivative (f:: F , x:: R ) where {F, R}
194+ @inline function __value_derivative (
195+ :: Union{AutoForwardDiff, AutoPolyesterForwardDiff} , f:: F , x:: R ) where {F, R}
185196 T = typeof (ForwardDiff. Tag (f, R))
186197 out = f (ForwardDiff. Dual {T} (x, one (x)))
187198 return ForwardDiff. value (out), ForwardDiff. extract_derivative (T, out)
188199end
189200
201+ @inline function __value_derivative (ad:: AutoFiniteDiff , f:: F , x:: R ) where {F, R}
202+ return f (x), FiniteDiff. finite_difference_derivative (f, x, ad. fdtype)
203+ end
204+
190205@inline function __scalar_jacvec (f:: F , x:: R , v:: V ) where {F, R, V}
191206 T = typeof (ForwardDiff. Tag (f, R))
192207 out = f (ForwardDiff. Dual {T} (x, v))
0 commit comments