@@ -36,12 +36,13 @@ const DualAbstractLinearProblem = Union{
3636LinearSolve. @concrete mutable struct DualLinearCache
3737 linear_cache
3838 dual_type
39- dual_u0
4039 partials_A
4140 partials_b
4241end
4342
4443function linearsolve_forwarddiff_solve (cache:: DualLinearCache , alg, args... ; kwargs... )
44+ # Solve the primal problem
45+ dual_u0 = copy (cache. linear_cache. u)
4546 sol = solve! (cache. linear_cache, alg, args... ; kwargs... )
4647 primal_b = copy (cache. linear_cache. b)
4748 uu = sol. u
@@ -51,7 +52,6 @@ function linearsolve_forwarddiff_solve(cache::DualLinearCache, alg, args...; kwa
5152 # Solves Dual partials separately
5253 ∂_A = cache. partials_A
5354 ∂_b = cache. partials_b
54- dual_u0 = ! isnothing (cache. dual_u0) ? only (partials_to_list (cache. dual_u0)) : cache. linear_cache. u
5555
5656 rhs_list = xp_linsolve_rhs (uu, ∂_A, ∂_b)
5757
@@ -137,14 +137,12 @@ function SciMLBase.init(
137137 kwargs... )
138138
139139 (; A, b, u0, p) = prob
140-
141140 new_A = nodual_value (A)
142141 new_b = nodual_value (b)
143142 new_u0 = nodual_value (u0)
144143
145144 ∂_A = partial_vals (A)
146145 ∂_b = partial_vals (b)
147- dual_u0 = partial_vals (u0)
148146
149147 primal_prob = LinearProblem (new_A, new_b, u0 = new_u0)
150148 # remake(prob; A = new_A, b = new_b, u0 = new_u0)
@@ -159,7 +157,7 @@ function SciMLBase.init(
159157 primal_prob, alg, args... ; alias = alias, abstol = abstol, reltol = reltol,
160158 maxiters = maxiters, verbose = verbose, Pl = Pl, Pr = Pr, assumptions = assumptions,
161159 sensealg = sensealg, u0 = new_u0, kwargs... )
162- return DualLinearCache (non_partial_cache, dual_type, dual_u0, ∂_A, ∂_b)
160+ return DualLinearCache (non_partial_cache, dual_type, ∂_A, ∂_b)
163161end
164162
165163function SciMLBase. solve! (cache:: DualLinearCache , args... ; kwargs... )
@@ -168,9 +166,8 @@ function SciMLBase.solve!(cache::DualLinearCache, args...; kwargs...)
168166 cache:: DualLinearCache , cache. alg, args... ; kwargs... )
169167
170168 dual_sol = linearsolve_dual_solution (sol. u, partials, cache. dual_type)
171-
172169 return SciMLBase. build_linear_solution (
173- cache. alg, dual_sol, sol. resid, sol . cache; sol. retcode, sol. iters, sol. stats
170+ cache. alg, dual_sol, sol. resid, cache; sol. retcode, sol. iters, sol. stats
174171 )
175172end
176173
0 commit comments