11module LinearSolveForwardDiffExt
22
33using LinearSolve
4+ using LinearSolve: SciMLLinearSolveAlgorithm
45using LinearAlgebra
56using ForwardDiff
67using ForwardDiff: Dual, Partials
@@ -36,8 +37,14 @@ const DualAbstractLinearProblem = Union{
# Cache that wraps a primal (Dual-free) linear cache together with the Dual
# information needed to rebuild a Dual-valued solution.
LinearSolve.@concrete mutable struct DualLinearCache
    # Cache returned by `init` on the problem with Dual numbers stripped out.
    linear_cache
    # Dual element type taken from `A` or `b` via `get_dual_type`.
    dual_type

    # Partials extracted from `A`, `b` and `u` with `partial_vals`
    # (`nothing` when the corresponding input carries no Duals).
    partials_A
    partials_b
    partials_u

    # The Dual-valued `A`, `b` and `u` as supplied or computed; these are what
    # `getproperty` hands back for `:A`, `:b` and `:u`.
    dual_A
    dual_b
    dual_u
end
4249
4350function linearsolve_forwarddiff_solve (cache:: DualLinearCache , alg, args... ; kwargs... )
@@ -55,16 +62,15 @@ function linearsolve_forwarddiff_solve(cache::DualLinearCache, alg, args...; kwa
5562
5663 rhs_list = xp_linsolve_rhs (uu, ∂_A, ∂_b)
5764
58- partial_cache = cache. linear_cache
59- partial_cache. u = dual_u0
60-
65+ cache. linear_cache. u = dual_u0
66+ # We can reuse the linear cache, because the same factorization will work for the partials.
6167 for i in eachindex (rhs_list)
62- partial_cache . b = rhs_list[i]
63- rhs_list[i] = copy (solve! (partial_cache , alg, args... ; kwargs... ). u)
68+ cache . linear_cache . b = rhs_list[i]
69+ rhs_list[i] = copy (solve! (cache . linear_cache , alg, args... ; kwargs... ). u)
6470 end
6571
66- # Reset to the original `b`, users will expect that `b` doesn't change if they don't tell it to
67- partial_cache . b = primal_b
72+ # Reset to the original `b` and `u` , users will expect that `b` doesn't change if they don't tell it to
73+ cache . linear_cache . b = primal_b
6874
6975 partial_sols = rhs_list
7076
@@ -96,35 +102,25 @@ function xp_linsolve_rhs(
96102 b_list
97103end
98104
99- #=
100- function SciMLBase.solve(prob::DualAbstractLinearProblem, args...; kwargs...)
101- return solve(prob, nothing, args...; kwargs...)
102- end
103-
104- function SciMLBase.solve(prob::DualAbstractLinearProblem, ::Nothing, args...;
105- assump = OperatorAssumptions(issquare(prob.A)), kwargs...)
106- return solve(prob, LinearSolve.defaultalg(prob.A, prob.b, assump), args...; kwargs...)
107- end
108-
109- function SciMLBase.solve(prob::DualAbstractLinearProblem,
110- alg::LinearSolve.SciMLLinearSolveAlgorithm, args...; kwargs...)
111- solve!(init(prob, alg, args...; kwargs...))
112- end
113- =#
114-
# Rebuild a Dual-valued solution from a primal solution `u` and the solved partials.
#
# Scalar case: `partials` is passed straight to the `dual_type` constructor.
function linearsolve_dual_solution(
        u::Number, partials, dual_type)
    return dual_type(u, partials)
end

# Scalar case with an explicit `Dual` type; same construction as the generic method.
function linearsolve_dual_solution(u::Number, partials,
        dual_type::Type{<:Dual{T, V, P}}) where {T, V, P}
    return dual_type(u, partials)
end

# Array case: `partials` is a list with one solution array per partial direction.
# Entry `i` of the result pairs `u[i]` with the tuple of the `i`-th component of
# every array in `partials`.
function linearsolve_dual_solution(u::AbstractArray, partials,
        dual_type::Type{<:Dual{T, V, P}}) where {T, V, P}
    # Index the list directly instead of going through a `VectorOfArray` and its
    # internal `.u` field; `eachindex` replaces the hard-coded `1:length(...)`.
    component_partials = (Partials(Tuple(p[i] for p in partials))
    for i in eachindex(first(partials)))
    return map(((uᵢ, pᵢ),) -> dual_type(uᵢ, pᵢ), zip(u, component_partials))
end
126123
127- #=
128124function SciMLBase. init (
129125 prob:: DualAbstractLinearProblem , alg:: LinearSolve.SciMLLinearSolveAlgorithm ,
130126 args... ;
@@ -138,7 +134,6 @@ function SciMLBase.init(
138134 assumptions = OperatorAssumptions (issquare (prob. A)),
139135 sensealg = LinearSolveAdjoint (),
140136 kwargs... )
141-
142137 (; A, b, u0, p) = prob
143138 new_A = nodual_value (A)
144139 new_b = nodual_value (b)
@@ -147,7 +142,6 @@ function SciMLBase.init(
147142 ∂_A = partial_vals (A)
148143 ∂_b = partial_vals (b)
149144
150- #primal_prob = LinearProblem(new_A, new_b, u0 = new_u0)
151145 primal_prob = remake (prob; A = new_A, b = new_b, u0 = new_u0)
152146
153147 if get_dual_type (prob. A) != = nothing
@@ -156,48 +150,71 @@ function SciMLBase.init(
156150 dual_type = get_dual_type (prob. b)
157151 end
158152
153+ alg isa LinearSolve. DefaultLinearSolver ? real_alg = LinearSolve. defaultalg (primal_prob. A, primal_prob. b) : real_alg = alg
154+
159155 non_partial_cache = init (
160- primal_prob, alg, args...; alias = alias, abstol = abstol, reltol = reltol,
156+ primal_prob, real_alg, assumptions, args... ;
157+ alias = alias, abstol = abstol, reltol = reltol,
161158 maxiters = maxiters, verbose = verbose, Pl = Pl, Pr = Pr, assumptions = assumptions,
162159 sensealg = sensealg, u0 = new_u0, kwargs... )
163- return DualLinearCache(non_partial_cache, dual_type, ∂_A, ∂_b)
160+ return DualLinearCache (non_partial_cache, dual_type, ∂_A, ∂_b, ! isnothing (∂_b) ? zero .(∂_b) : ∂_b, A, b, zeros (dual_type, length (b)) )
164161end
165162
# Solve using the algorithm stored on the cache.
function SciMLBase.solve!(cache::DualLinearCache, args...; kwargs...)
    return solve!(cache, cache.alg, args...; kwargs...)
end

# Solve the primal system and its partial systems, then assemble the Dual solution.
#
# NOTE(review): `alg` only takes part in dispatch here; the solve and the returned
# solution both use `cache.alg`. Confirm that ignoring the argument is intended.
function SciMLBase.solve!(
        cache::DualLinearCache, alg::SciMLLinearSolveAlgorithm, args...; kwargs...)
    sol, partials = linearsolve_forwarddiff_solve(cache, cache.alg, args...; kwargs...)
    dual_sol = linearsolve_dual_solution(sol.u, partials, cache.dual_type)

    # Keep the cached Dual solution in sync: write in place when it is an array,
    # otherwise replace the stored value.
    if cache.dual_u isa AbstractArray
        cache.dual_u[:] = dual_sol
    else
        cache.dual_u = dual_sol
    end

    return SciMLBase.build_linear_solution(
        cache.alg, dual_sol, sol.resid, cache; sol.retcode, sol.iters, sol.stats
    )
end
176- =#
177183
178184# If setting A or b for DualLinearCache, put the Dual-stripped versions in the LinearCache
179- # Also "forwards" setproperty so that
# Route property assignment on a `DualLinearCache`.
#
# - `:A`, `:b`, `:u`: the Dual-stripped value goes to the wrapped linear cache, and
#   the Dual value plus its partials are stored on the `DualLinearCache` itself.
# - any field of `DualLinearCache`: set directly.
# - any field of `LinearSolve.LinearCache`: forwarded to the wrapped cache.
# Symbols matching none of the above are left untouched.
function Base.setproperty!(dc::DualLinearCache, sym::Symbol, val)
    if sym === :A || sym === :b || sym === :u
        setproperty!(dc.linear_cache, sym, nodual_value(val))
    elseif hasfield(DualLinearCache, sym)
        setfield!(dc, sym, val)
    elseif hasfield(LinearSolve.LinearCache, sym)
        setproperty!(dc.linear_cache, sym, val)
    end

    # Record the Dual value and refresh its partials when A, b or u is replaced.
    if sym === :A
        setfield!(dc, :dual_A, val)
        setfield!(dc, :partials_A, partial_vals(val))
    elseif sym === :b
        setfield!(dc, :dual_b, val)
        setfield!(dc, :partials_b, partial_vals(val))
    elseif sym === :u
        setfield!(dc, :dual_u, val)
        setfield!(dc, :partials_u, partial_vals(val))
    end
end
197208
198209# "Forwards" getproperty to LinearCache if necessary
199210function Base. getproperty (dc:: DualLinearCache , sym:: Symbol )
200- if hasfield (LinearSolve. LinearCache, sym)
211+ if sym === :A
212+ dc. dual_A
213+ elseif sym === :b
214+ dc. dual_b
215+ elseif sym === :u
216+ dc. dual_u
217+ elseif hasfield (LinearSolve. LinearCache, sym)
201218 return getproperty (dc. linear_cache, sym)
202219 else
203220 return getfield (dc, sym)
@@ -206,31 +223,36 @@ end
206223
207224
208225
# Helper functions for Dual numbers.
#
# The scalar methods dispatch on `Dual` with no restriction on the value type, so
# nested Duals and Duals over any `Real` (not only `AbstractFloat`) are handled
# the same way instead of falling through to the generic fallbacks.

# Dual type carried by `x`, or `nothing` when `x` holds no Duals.
get_dual_type(x::Dual) = typeof(x)
get_dual_type(x::AbstractArray{<:Dual}) = eltype(x)
get_dual_type(x) = nothing

# Partials of `x` (outermost level for nested Duals), or `nothing` when there are none.
partial_vals(x::Dual) = ForwardDiff.partials(x)
partial_vals(x::AbstractArray{<:Dual}) = map(ForwardDiff.partials, x)
partial_vals(x) = nothing

# Strip one Dual layer from `x`; for a nested Dual the inner Dual is kept intact.
nodual_value(x) = x
nodual_value(x::Dual) = ForwardDiff.value(x)
nodual_value(x::AbstractArray{<:Dual}) = map(nodual_value, x)
222244
# Transpose a vector of partials into a list of vectors: entry `k` of the result
# collects the `k`-th partial of every element of `partial_matrix`.
# The index range is taken from the first element, so the input must be non-empty.
function partials_to_list(partial_matrix::AbstractVector)
    directions = eachindex(first(partial_matrix))
    return [[partial[k] for partial in partial_matrix] for k in directions]
end
227249
228250function partials_to_list (partial_matrix)
229251 p = length (first (partial_matrix))
230252 m, n = size (partial_matrix)
231- res_list = fill (zeros (m, n), p)
253+ res_list = fill (zeros (typeof (partial_matrix[ 1 , 1 ][ 1 ]), m, n), p)
232254 for k in 1 : p
233- res = zeros (m, n)
255+ res = zeros (typeof (partial_matrix[ 1 , 1 ][ 1 ]), m, n)
234256 for i in 1 : m
235257 for j in 1 : n
236258 res[i, j] = partial_matrix[i, j][k]
243265
244266
245267end
268+
0 commit comments