Skip to content

Commit 1728e0f

Browse files
Match the other PR in form
1 parent 7d0b12c commit 1728e0f

File tree

1 file changed

+23
-20
lines changed

1 file changed

+23
-20
lines changed

ext/ForwardDiffStaticArraysExt.jl

Lines changed: 23 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
module ForwardDiffStaticArraysExt
22

3-
using ForwardDiff
4-
using ForwardDiff.DiffResults: DiffResults, DiffResult, ImmutableDiffResult, MutableDiffResult
5-
using ForwardDiff.LinearAlgebra
6-
import ForwardDiff: Dual, Chunk, value, extract_jacobian!, extract_value!, extract_gradient!, extract_jacobian!,
7-
GradientConfig, JacobianConfig, HessianConfig, vector_mode_gradient, vector_mode_gradient!,
8-
Tag, valtype, partials, gradient, gradient!, jacobian, jacobian!, hessian, hessian!, vector_mode_jacobian,
9-
vector_mode_jacobian!
3+
using ForwardDiff, StaticArrays, LinearAlgebra, DiffResults
4+
using ForwardDiff: Dual, partials, GradientConfig, JacobianConfig, HessianConfig, Tag, Chunk,
5+
gradient, hessian, jacobian, gradient!, hessian!, jacobian!,
6+
extract_gradient!, extract_jacobian!, extract_value!,
7+
vector_mode_gradient, vector_mode_gradient!,
8+
vector_mode_jacobian, vector_mode_jacobian!, valtype, value, _lyap_div!
9+
using DiffResults: DiffResult, ImmutableDiffResult, MutableDiffResult
1010
using StaticArrays
1111

1212
@generated function dualize(::Type{T}, x::StaticArray) where T
@@ -22,6 +22,20 @@ end
2222

2323
@inline static_dual_eval(::Type{T}, f, x::StaticArray) where T = f(dualize(T, x))
2424

25+
function LinearAlgebra.eigvals(A::Symmetric{<:Dual{Tg,T,N}, <:StaticArrays.StaticMatrix}) where {Tg,T<:Real,N}
26+
λ,Q = eigen(Symmetric(value.(parent(A))))
27+
parts = ntuple(j -> diag(Q' * getindex.(partials.(A), j) * Q), N)
28+
Dual{Tg}.(λ, tuple.(parts...))
29+
end
30+
31+
function LinearAlgebra.eigen(A::Symmetric{<:Dual{Tg,T,N}, <:StaticArrays.StaticMatrix}) where {Tg,T<:Real,N}
32+
λ = eigvals(A)
33+
_,Q = eigen(Symmetric(value.(parent(A))))
34+
parts = ntuple(j -> Q*ForwardDiff._lyap_div!(Q' * getindex.(partials.(A), j) * Q - Diagonal(getindex.(partials.(λ), j)), value.(λ)), N)
35+
Eigen(λ,Dual{Tg}.(Q, tuple.(parts...)))
36+
end
37+
38+
# Gradient
2539
@inline ForwardDiff.gradient(f, x::StaticArray) = vector_mode_gradient(f, x)
2640
@inline ForwardDiff.gradient(f, x::StaticArray, cfg::GradientConfig) = gradient(f, x)
2741
@inline ForwardDiff.gradient(f, x::StaticArray, cfg::GradientConfig, ::Val) = gradient(f, x)
@@ -49,6 +63,7 @@ end
4963
return extract_gradient!(T, result, static_dual_eval(T, f, x))
5064
end
5165

66+
# Jacobian
5267
@inline ForwardDiff.jacobian(f, x::StaticArray) = vector_mode_jacobian(f, x)
5368
@inline ForwardDiff.jacobian(f, x::StaticArray, cfg::JacobianConfig) = jacobian(f, x)
5469
@inline ForwardDiff.jacobian(f, x::StaticArray, cfg::JacobianConfig, ::Val) = jacobian(f, x)
@@ -93,6 +108,7 @@ end
93108
return result
94109
end
95110

111+
# Hessian
96112
ForwardDiff.hessian(f, x::StaticArray) = jacobian(y -> gradient(f, y), x)
97113
ForwardDiff.hessian(f, x::StaticArray, cfg::HessianConfig) = hessian(f, x)
98114
ForwardDiff.hessian(f, x::StaticArray, cfg::HessianConfig, ::Val) = hessian(f, x)
@@ -118,17 +134,4 @@ function ForwardDiff.hessian!(result::ImmutableDiffResult, f, x::StaticArray)
118134
return result
119135
end
120136

121-
function LinearAlgebra.eigvals(A::Symmetric{<:Dual{Tg,T,N}, <:StaticArrays.StaticMatrix}) where {Tg,T<:Real,N}
122-
λ,Q = eigen(Symmetric(value.(parent(A))))
123-
parts = ntuple(j -> diag(Q' * getindex.(partials.(A), j) * Q), N)
124-
Dual{Tg}.(λ, tuple.(parts...))
125-
end
126-
127-
function LinearAlgebra.eigen(A::Symmetric{<:Dual{Tg,T,N}, <:StaticArrays.StaticMatrix}) where {Tg,T<:Real,N}
128-
λ = eigvals(A)
129-
_,Q = eigen(Symmetric(value.(parent(A))))
130-
parts = ntuple(j -> Q*ForwardDiff._lyap_div!(Q' * getindex.(partials.(A), j) * Q - Diagonal(getindex.(partials.(λ), j)), value.(λ)), N)
131-
Eigen(λ,Dual{Tg}.(Q, tuple.(parts...)))
132-
end
133-
134137
end

0 commit comments

Comments
 (0)