@@ -8,18 +8,33 @@ jacobian_f(::Number, p::Number) = 1 / (2 * √p)
88jacobian_f (u, p:: Number ) = one .(u) .* (1 / (2 * √ p))
99jacobian_f (u, p:: AbstractArray ) = diagm (vec (@. 1 / (2 * √ p)))
1010
11- function solve_with (:: Val{iip } , u, alg) where {iip }
12- f = if iip
11+ function solve_with (:: Val{mode } , u, alg) where {mode }
12+ f = if mode === : iip
1313 solve_iip (p) = solve (NonlinearProblem (test_f!, u, p), alg). u
14- else
14+ elseif mode === :iip_cache
15+ function solve_iip_init (p)
16+ cache = SciMLBase. init (NonlinearProblem (test_f!, u, p), alg)
17+ return SciMLBase. solve! (cache). u
18+ end
19+ elseif mode === :oop
1520 solve_oop (p) = solve (NonlinearProblem (test_f, u, p), alg). u
21+ elseif mode === :oop_cache
22+ function solve_oop_init (p)
23+ cache = SciMLBase. init (NonlinearProblem (test_f, u, p), alg)
24+ return SciMLBase. solve! (cache). u
25+ end
1626 end
1727 return f
1828end
1929
20- __can_inplace (:: Number ) = false
21- __can_inplace (:: AbstractArray ) = true
22- __can_inplace (:: StaticArray ) = false
30+ __compatible (:: Any , :: Val{:oop} ) = true
31+ __compatible (:: Any , :: Val{:oop_cache} ) = true
32+ __compatible (:: Number , :: Val{:iip} ) = false
33+ __compatible (:: AbstractArray , :: Val{:iip} ) = true
34+ __compatible (:: StaticArray , :: Val{:iip} ) = false
35+ __compatible (:: Number , :: Val{:iip_cache} ) = false
36+ __compatible (:: AbstractArray , :: Val{:iip_cache} ) = true
37+ __compatible (:: StaticArray , :: Val{:iip_cache} ) = false
2338
2439__compatible (:: Any , :: Number ) = true
2540__compatible (:: Number , :: AbstractArray ) = false
@@ -32,37 +47,49 @@ __compatible(u::StaticArray, ::SciMLBase.AbstractNonlinearAlgorithm) = true
3247__compatible (u:: StaticArray , :: Union{CMINPACK, NLsolveJL} ) = false
3348__compatible (u, :: Nothing ) = true
3449
50+ __compatible (:: Any , :: Any ) = true
51+ __compatible (:: CMINPACK , :: Val{:iip_cache} ) = false
52+ __compatible (:: CMINPACK , :: Val{:oop_cache} ) = false
53+ __compatible (:: NLsolveJL , :: Val{:iip_cache} ) = false
54+ __compatible (:: NLsolveJL , :: Val{:oop_cache} ) = false
55+
3556@testset " ForwardDiff.jl Integration: $(alg) " for alg in (NewtonRaphson (), TrustRegion (),
3657 LevenbergMarquardt (), PseudoTransient (; alpha_initial = 10.0 ), Broyden (), Klement (),
3758 DFSane (), nothing , NLsolveJL (), CMINPACK ())
3859 us = (2.0 , @SVector [1.0 , 1.0 ], [1.0 , 1.0 ], ones (2 , 2 ), @SArray ones (2 , 2 ))
3960
4061 @testset " Scalar AD" begin
41- for p in 1.0 : 0.1 : 100.0
42- for u0 in us
43- __compatible (u0, alg) || continue
44- sol = solve (NonlinearProblem (test_f, u0, p), alg)
45- if SciMLBase. successful_retcode (sol)
46- gs = abs .(ForwardDiff. derivative (solve_with (Val {false} (), u0, alg), p))
47- gs_true = abs .(jacobian_f (u0, p))
48- if ! (isapprox (gs, gs_true, atol = 1e-5 ))
49- @show sol. retcode, sol. u
50- @error " ForwardDiff Failed for u0=$(u0) and p=$(p) with $(alg) " forwardiff_gradient= gs true_gradient= gs_true
51- else
52- @test abs .(gs)≈ abs .(gs_true) atol= 1e-5
53- end
62+ for p in 1.0 : 0.1 : 100.0 , u0 in us, mode in (:iip , :oop , :iip_cache , :oop_cache )
63+ __compatible (u0, alg) || continue
64+ __compatible (u0, Val (mode)) || continue
65+ __compatible (alg, Val (mode)) || continue
66+
67+ sol = solve (NonlinearProblem (test_f, u0, p), alg)
68+ if SciMLBase. successful_retcode (sol)
69+ gs = abs .(ForwardDiff. derivative (solve_with (Val {mode} (), u0, alg), p))
70+ gs_true = abs .(jacobian_f (u0, p))
71+ if ! (isapprox (gs, gs_true, atol = 1e-5 ))
72+ @show sol. retcode, sol. u
73+ @error " ForwardDiff Failed for u0=$(u0) and p=$(p) with $(alg) " forwardiff_gradient= gs true_gradient= gs_true
74+ else
75+ @test abs .(gs)≈ abs .(gs_true) atol= 1e-5
5476 end
5577 end
5678 end
5779 end
5880
5981 @testset " Jacobian" begin
60- for u0 in us, p in ([2.0 , 1.0 ], [2.0 1.0 ; 3.0 4.0 ])
82+ for u0 in us, p in ([2.0 , 1.0 ], [2.0 1.0 ; 3.0 4.0 ]),
83+ mode in (:iip , :oop , :iip_cache , :oop_cache )
84+
6185 __compatible (u0, p) || continue
6286 __compatible (u0, alg) || continue
87+ __compatible (u0, Val (mode)) || continue
88+ __compatible (alg, Val (mode)) || continue
89+
6390 sol = solve (NonlinearProblem (test_f, u0, p), alg)
6491 if SciMLBase. successful_retcode (sol)
65- gs = abs .(ForwardDiff. jacobian (solve_with (Val {false } (), u0, alg), p))
92+ gs = abs .(ForwardDiff. jacobian (solve_with (Val {mode } (), u0, alg), p))
6693 gs_true = abs .(jacobian_f (u0, p))
6794 if ! (isapprox (gs, gs_true, atol = 1e-5 ))
6895 @show sol. retcode, sol. u
0 commit comments