Skip to content

Commit ff56092

Browse files
authored
Add support for StaticArrays >= 1.7 (#703)
1 parent 1907196 commit ff56092

File tree

5 files changed

+80
-18
lines changed

5 files changed

+80
-18
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ LogExpFunctions = "0.3"
3232
NaNMath = "1"
3333
Preferences = "1"
3434
SpecialFunctions = "1, 2"
35-
StaticArrays = "1.5 - 1.6"
35+
StaticArrays = "1.5"
3636
julia = "1.6"
3737

3838
[extras]

ext/ForwardDiffStaticArraysExt.jl

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ using ForwardDiff: Dual, partials, GradientConfig, JacobianConfig, HessianConfig
77
gradient, hessian, jacobian, gradient!, hessian!, jacobian!,
88
extract_gradient!, extract_jacobian!, extract_value!,
99
vector_mode_gradient, vector_mode_gradient!,
10-
vector_mode_jacobian, vector_mode_jacobian!, valtype, value, _lyap_div!
10+
vector_mode_jacobian, vector_mode_jacobian!, valtype, value
1111
using DiffResults: DiffResult, ImmutableDiffResult, MutableDiffResult
1212

1313
@generated function dualize(::Type{T}, x::StaticArray) where T
@@ -23,19 +23,17 @@ end
2323

2424
@inline static_dual_eval(::Type{T}, f, x::StaticArray) where T = f(dualize(T, x))
2525

26+
# To fix method ambiguity issues:
2627
function LinearAlgebra.eigvals(A::Symmetric{<:Dual{Tg,T,N}, <:StaticArrays.StaticMatrix}) where {Tg,T<:Real,N}
27-
λ,Q = eigen(Symmetric(value.(parent(A))))
28-
parts = ntuple(j -> diag(Q' * getindex.(partials.(A), j) * Q), N)
29-
Dual{Tg}.(λ, tuple.(parts...))
28+
return ForwardDiff._eigvals(A)
3029
end
31-
3230
function LinearAlgebra.eigen(A::Symmetric{<:Dual{Tg,T,N}, <:StaticArrays.StaticMatrix}) where {Tg,T<:Real,N}
33-
λ = eigvals(A)
34-
_,Q = eigen(Symmetric(value.(parent(A))))
35-
parts = ntuple(j -> Q*_lyap_div!(Q' * getindex.(partials.(A), j) * Q - Diagonal(getindex.(partials.(λ), j)), value.(λ)), N)
36-
Eigen(λ,Dual{Tg}.(Q, tuple.(parts...)))
31+
return ForwardDiff._eigen(A)
3732
end
3833

34+
# For `MMatrix` we can use the in-place method
35+
ForwardDiff._lyap_div!!(A::StaticArrays.MMatrix, λ::AbstractVector) = ForwardDiff._lyap_div!(A, λ)
36+
3937
# Gradient
4038
@inline ForwardDiff.gradient(f, x::StaticArray) = vector_mode_gradient(f, x)
4139
@inline ForwardDiff.gradient(f, x::StaticArray, cfg::GradientConfig) = gradient(f, x)

src/dual.jl

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -719,7 +719,11 @@ end
719719
# Symmetric eigvals #
720720
#-------------------#
721721

722-
function LinearAlgebra.eigvals(A::Symmetric{<:Dual{Tg,T,N}}) where {Tg,T<:Real,N}
722+
# To be able to reuse this default definition in the StaticArrays extension
723+
# (has to be re-defined to avoid method ambiguity issues)
724+
# we forward the call to an internal method that can be shared and reused
725+
LinearAlgebra.eigvals(A::Symmetric{<:Dual{Tg,T,N}}) where {Tg,T<:Real,N} = _eigvals(A)
726+
function _eigvals(A::Symmetric{<:Dual{Tg,T,N}}) where {Tg,T<:Real,N}
723727
λ,Q = eigen(Symmetric(value.(parent(A))))
724728
parts = ntuple(j -> diag(Q' * getindex.(partials.(A), j) * Q), N)
725729
Dual{Tg}.(λ, tuple.(parts...))
@@ -737,8 +741,19 @@ function LinearAlgebra.eigvals(A::SymTridiagonal{<:Dual{Tg,T,N}}) where {Tg,T<:R
737741
Dual{Tg}.(λ, tuple.(parts...))
738742
end
739743

740-
# A ./ (λ - λ') but with diag special cased
741-
function _lyap_div!(A, λ)
744+
# A ./ (λ' .- λ) but with diag special cased
745+
# Default out-of-place method
746+
function _lyap_div!!(A::AbstractMatrix, λ::AbstractVector)
747+
return map(
748+
(a, b, idx) -> a / (idx[1] == idx[2] ? oneunit(b) : b),
749+
A,
750+
λ' .- λ,
751+
CartesianIndices(A),
752+
)
753+
end
754+
# For `Matrix` (and e.g. `StaticArrays.MMatrix`) we can use an in-place method
755+
_lyap_div!!(A::Matrix, λ::AbstractVector) = _lyap_div!(A, λ)
756+
function _lyap_div!(A::AbstractMatrix, λ::AbstractVector)
742757
for (j,μ) in enumerate(λ), (k,λ) in enumerate(λ)
743758
if k j
744759
A[k,j] /= μ - λ
@@ -747,17 +762,21 @@ function _lyap_div!(A, λ)
747762
A
748763
end
749764

750-
function LinearAlgebra.eigen(A::Symmetric{<:Dual{Tg,T,N}}) where {Tg,T<:Real,N}
765+
# To be able to reuse this default definition in the StaticArrays extension
766+
# (has to be re-defined to avoid method ambiguity issues)
767+
# we forward the call to an internal method that can be shared and reused
768+
LinearAlgebra.eigen(A::Symmetric{<:Dual{Tg,T,N}}) where {Tg,T<:Real,N} = _eigen(A)
769+
function _eigen(A::Symmetric{<:Dual{Tg,T,N}}) where {Tg,T<:Real,N}
751770
λ = eigvals(A)
752771
_,Q = eigen(Symmetric(value.(parent(A))))
753-
parts = ntuple(j -> Q*_lyap_div!(Q' * getindex.(partials.(A), j) * Q - Diagonal(getindex.(partials.(λ), j)), value.(λ)), N)
772+
parts = ntuple(j -> Q*_lyap_div!!(Q' * getindex.(partials.(A), j) * Q - Diagonal(getindex.(partials.(λ), j)), value.(λ)), N)
754773
Eigen(λ,Dual{Tg}.(Q, tuple.(parts...)))
755774
end
756775

757776
function LinearAlgebra.eigen(A::SymTridiagonal{<:Dual{Tg,T,N}}) where {Tg,T<:Real,N}
758777
λ = eigvals(A)
759778
_,Q = eigen(SymTridiagonal(value.(parent(A))))
760-
parts = ntuple(j -> Q*_lyap_div!(Q' * getindex.(partials.(A), j) * Q - Diagonal(getindex.(partials.(λ), j)), value.(λ)), N)
779+
parts = ntuple(j -> Q*_lyap_div!!(Q' * getindex.(partials.(A), j) * Q - Diagonal(getindex.(partials.(λ), j)), value.(λ)), N)
761780
Eigen(λ,Dual{Tg}.(Q, tuple.(parts...)))
762781
end
763782

test/JacobianTest.jl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -243,8 +243,14 @@ end
243243
@test ForwardDiff.jacobian(ev1, x0) Calculus.finite_difference_jacobian(ev1, x0)
244244
ev2(x) = eigen(SymTridiagonal(x, x[1:end-1])).vectors[:,1]
245245
@test ForwardDiff.jacobian(ev2, x0) Calculus.finite_difference_jacobian(ev2, x0)
246-
x0_static = SVector{2}(x0)
247-
@test ForwardDiff.jacobian(ev1, x0_static) Calculus.finite_difference_jacobian(ev1, x0)
246+
247+
x0_svector = SVector{2}(x0)
248+
@test ForwardDiff.jacobian(ev1, x0_svector) isa SMatrix{2, 2}
249+
@test ForwardDiff.jacobian(ev1, x0_svector) Calculus.finite_difference_jacobian(ev1, x0)
250+
251+
x0_mvector = MVector{2}(x0)
252+
@test ForwardDiff.jacobian(ev1, x0_mvector) isa MMatrix{2, 2}
253+
@test ForwardDiff.jacobian(ev1, x0_mvector) Calculus.finite_difference_jacobian(ev1, x0)
248254
end
249255

250256
@testset "type stability" begin

test/MiscTest.jl

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@ using Test
66
using ForwardDiff
77
using DiffTests
88
using SparseArrays: sparse
9+
using StaticArrays
910
using IrrationalConstants
11+
using LinearAlgebra
1012

1113
include(joinpath(dirname(@__FILE__), "utils.jl"))
1214

@@ -180,4 +182,41 @@ end
180182
# example from https://github.com/JuliaDiff/DiffRules.jl/pull/98#issuecomment-1574420052
181183
@test only(ForwardDiff.hessian(t -> abs(t[1])^2, [0.0])) == 2
182184

185+
@testset "_lyap_div!!" begin
186+
# In-place version for `Matrix`
187+
A = rand(3, 3)
188+
Acopy = copy(A)
189+
λ = rand(3)
190+
B = @inferred(ForwardDiff._lyap_div!!(A, λ))
191+
@test B === A
192+
@test B[diagind(B)] == Acopy[diagind(Acopy)]
193+
no_diag(X) = [X[i] for i in eachindex(X) if !(i in diagind(X))]
194+
@test no_diag(B) == no_diag(Acopy ./' .- λ))
195+
196+
# Immutable static arrays
197+
A_smatrix = SMatrix{3,3}(Acopy)
198+
λ_svector = SVector{3}(λ)
199+
B_smatrix = @inferred(ForwardDiff._lyap_div!!(A_smatrix, λ_svector))
200+
@test B_smatrix !== A_smatrix
201+
@test B_smatrix isa SMatrix{3,3}
202+
@test B_smatrix == B
203+
λ_mvector = MVector{3}(λ)
204+
B_smatrix = @inferred(ForwardDiff._lyap_div!!(A_smatrix, λ_mvector))
205+
@test B_smatrix !== A_smatrix
206+
@test B_smatrix isa SMatrix{3,3}
207+
@test B_smatrix == B
208+
209+
# Mutable static arrays
210+
A_mmatrix = MMatrix{3,3}(Acopy)
211+
λ_svector = SVector{3}(λ)
212+
B_mmatrix = @inferred(ForwardDiff._lyap_div!!(A_mmatrix, λ_svector))
213+
@test B_mmatrix === A_mmatrix
214+
@test B_mmatrix == B
215+
A_mmatrix = MMatrix{3,3}(Acopy)
216+
λ_mvector = MVector{3}(λ)
217+
B_mmatrix = @inferred(ForwardDiff._lyap_div!!(A_mmatrix, λ_mvector))
218+
@test B_mmatrix === A_mmatrix
219+
@test B_mmatrix == B
220+
end
221+
183222
end # module

0 commit comments

Comments
 (0)