     ϕdϕ
     method
     alpha
-    grad_op
+    deriv_op
     u_cache
     fu_cache
     stats::NLStats
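Context for the rename: with the merit function ϕ(α) = ‖F(u + α du)‖² / 2, the line-search derivative is dϕ/dα = ⟨F, J du⟩ = ⟨Jᵀ F, du⟩. The old grad_op materialized the full gradient Jᵀ F and every caller then took dot(g₀, du); the new deriv_op receives du and returns the scalar directly, so it can be backed by either a JVP (forward mode) or a VJP (reverse mode). A minimal sketch checking that identity (the residual F, u, and du below are hypothetical, not part of the package):

    using ForwardDiff, LinearAlgebra

    F(u) = [u[1]^2 + u[2] - 3, u[1] - u[2]^3]    # hypothetical residual
    u, du = [1.0, 2.0], [0.5, -0.25]
    J = ForwardDiff.jacobian(F, u)
    ϕ(α) = sum(abs2, F(u .+ α .* du)) / 2        # merit function ‖F‖²/2
    # forward form ⟨F, J du⟩ and reverse form ⟨Jᵀ F, du⟩ agree ...
    @assert dot(F(u), J * du) ≈ dot(J' * F(u), du)
    # ... and both equal dϕ/dα at α = 0
    @assert ForwardDiff.derivative(ϕ, 0.0) ≈ dot(F(u), J * du)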
@@ -110,25 +110,59 @@ function __internal_init(
             @warn "Scalar AD is supported only for AutoForwardDiff and AutoFiniteDiff. \
                    Detected $(autodiff). Falling back to AutoFiniteDiff."
         end
-        grad_op = @closure (u, fu, p) -> last(__value_derivative(
-            autodiff, Base.Fix2(f, p), u)) * fu
+        deriv_op = @closure (du, u, fu, p) -> last(__value_derivative(
+            autodiff, Base.Fix2(f, p), u)) *
+            fu *
+            du
     else
-        if SciMLBase.has_jvp(f)
+        # Both forward and reverse AD can be used for the line search. We prefer
+        # forward AD for performance, but reverse AD is supported on request:
+        #   1. if a jvp is provided, use forward AD;
+        #   2. if a vjp is provided, use reverse AD;
+        #   3. if a reverse-mode AD backend is requested, use reverse AD;
+        #   4. otherwise, fall back to forward AD.
+        if alg.autodiff isa AutoFiniteDiff
+            deriv_op = nothing
+        elseif SciMLBase.has_jvp(f)
             if isinplace(prob)
-                g_cache = __similar(u)
-                grad_op = @closure (u, fu, p) -> f.vjp(g_cache, fu, u, p)
+                jvp_cache = __similar(fu)
+                deriv_op = @closure (du, u, fu, p) -> begin
+                    f.jvp(jvp_cache, du, u, p)
+                    dot(fu, jvp_cache)
+                end
             else
-                grad_op = @closure (u, fu, p) -> f.vjp(fu, u, p)
+                deriv_op = @closure (du, u, fu, p) -> dot(fu, f.jvp(du, u, p))
             end
-        else
+        elseif SciMLBase.has_vjp(f)
+            if isinplace(prob)
+                vjp_cache = __similar(u)
+                deriv_op = @closure (du, u, fu, p) -> begin
+                    f.vjp(vjp_cache, fu, u, p)
+                    dot(du, vjp_cache)
+                end
+            else
+                deriv_op = @closure (du, u, fu, p) -> dot(du, f.vjp(fu, u, p))
+            end
+        elseif alg.autodiff !== nothing &&
+               ADTypes.mode(alg.autodiff) isa ADTypes.ReverseMode
             autodiff = get_concrete_reverse_ad(
                 alg.autodiff, prob; check_reverse_mode = true)
             vjp_op = VecJacOperator(prob, fu, u; autodiff)
             if isinplace(prob)
-                g_cache = __similar(u)
-                grad_op = @closure (u, fu, p) -> vjp_op(g_cache, fu, u, p)
+                vjp_cache = __similar(u)
+                deriv_op = @closure (du, u, fu, p) -> dot(du, vjp_op(vjp_cache, fu, u, p))
+            else
+                deriv_op = @closure (du, u, fu, p) -> dot(du, vjp_op(fu, u, p))
+            end
+        else
+            autodiff = get_concrete_forward_ad(
+                alg.autodiff, prob; check_forward_mode = true)
+            jvp_op = JacVecOperator(prob, fu, u; autodiff)
+            if isinplace(prob)
+                jvp_cache = __similar(fu)
+                deriv_op = @closure (du, u, fu, p) -> dot(fu, jvp_op(jvp_cache, du, u, p))
             else
-                grad_op = @closure (u, fu, p) -> vjp_op(fu, u, p)
+                deriv_op = @closure (du, u, fu, p) -> dot(fu, jvp_op(du, u, p))
             end
         end
     end
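For reference, the has_jvp / has_vjp branches above are taken when the user attaches those operators to the problem's NonlinearFunction. A minimal out-of-place sketch, assuming the SciMLBase keyword interface (the residual and analytical operators below are hypothetical; the signatures mirror the out-of-place calls in the diff, jvp(du, u, p) = J du and vjp(fu, u, p) = Jᵀ fu):

    using SciMLBase

    f(u, p) = u .^ 2 .- p              # hypothetical residual, J = Diagonal(2u)
    jvp(v, u, p) = 2 .* u .* v         # J * v
    vjp(v, u, p) = 2 .* u .* v         # J' * v (J is diagonal, hence symmetric)
    nf = NonlinearFunction(f; jvp, vjp)
    prob = NonlinearProblem(nf, [1.0, 2.0], 2.0)

When both are supplied, the jvp branch wins per the priority order in the comment above; with AutoFiniteDiff, deriv_op is left as nothing and __internal_solve! below falls back to finite differences of ϕ itself.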
@@ -143,33 +177,37 @@ function __internal_init(
         return @fastmath internalnorm(fu_cache)^2 / 2
     end
 
-    dϕ = @closure (f, p, u, du, α, u_cache, fu_cache, grad_op) -> begin
+    dϕ = @closure (f, p, u, du, α, u_cache, fu_cache, deriv_op) -> begin
         @bb @. u_cache = u + α * du
         fu_cache = evaluate_f!!(f, fu_cache, u_cache, p)
         stats.nf += 1
-        g₀ = grad_op(u_cache, fu_cache, p)
-        return dot(g₀, du)
+        return deriv_op(du, u_cache, fu_cache, p)
     end
 
-    ϕdϕ = @closure (f, p, u, du, α, u_cache, fu_cache, grad_op) -> begin
+    ϕdϕ = @closure (f, p, u, du, α, u_cache, fu_cache, deriv_op) -> begin
         @bb @. u_cache = u + α * du
         fu_cache = evaluate_f!!(f, fu_cache, u_cache, p)
         stats.nf += 1
-        g₀ = grad_op(u_cache, fu_cache, p)
+        deriv = deriv_op(du, u_cache, fu_cache, p)
         obj = @fastmath internalnorm(fu_cache)^2 / 2
-        return obj, dot(g₀, du)
+        return obj, deriv
     end
 
     return LineSearchesJLCache(f, p, ϕ, dϕ, ϕdϕ, alg.method, T(alg.initial_alpha),
-        grad_op, u_cache, fu_cache, stats)
+        deriv_op, u_cache, fu_cache, stats)
 end
 
 function __internal_solve!(cache::LineSearchesJLCache, u, du; kwargs...)
     ϕ = @closure α -> cache.ϕ(cache.f, cache.p, u, du, α, cache.u_cache, cache.fu_cache)
-    dϕ = @closure α -> cache.dϕ(
-        cache.f, cache.p, u, du, α, cache.u_cache, cache.fu_cache, cache.grad_op)
-    ϕdϕ = @closure α -> cache.ϕdϕ(
-        cache.f, cache.p, u, du, α, cache.u_cache, cache.fu_cache, cache.grad_op)
+    if cache.deriv_op !== nothing
+        dϕ = @closure α -> cache.dϕ(
+            cache.f, cache.p, u, du, α, cache.u_cache, cache.fu_cache, cache.deriv_op)
+        ϕdϕ = @closure α -> cache.ϕdϕ(
+            cache.f, cache.p, u, du, α, cache.u_cache, cache.fu_cache, cache.deriv_op)
+    else
+        dϕ = @closure α -> FiniteDiff.finite_difference_derivative(ϕ, α)
+        ϕdϕ = @closure α -> (ϕ(α), FiniteDiff.finite_difference_derivative(ϕ, α))
+    end
 
     ϕ₀, dϕ₀ = ϕdϕ(zero(eltype(u)))
 
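A note on the new fallback path at the end: when cache.deriv_op === nothing (the AutoFiniteDiff case), dϕ is approximated by finite-differencing the scalar ϕ directly rather than propagating AD through F. A minimal sketch of that fallback (the quadratic ϕ below is a hypothetical stand-in for the cache's merit function):

    using FiniteDiff

    ϕ(α) = (1 - α)^2 / 2               # stand-in merit function
    dϕ = α -> FiniteDiff.finite_difference_derivative(ϕ, α)
    ϕdϕ = α -> (ϕ(α), FiniteDiff.finite_difference_derivative(ϕ, α))
    ϕ₀, dϕ₀ = ϕdϕ(0.0)                 # ≈ (0.5, -1.0)

This costs extra ϕ evaluations per derivative query, which is why the AD-backed deriv_op path is preferred whenever one of the earlier branches applies.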