|
1 | 1 | import Base: +, -, *, /, \ |
2 | 2 |
|
3 | | -# TODO: more operators, like AbstractArray |
| 3 | +#-------------------------------------------------- |
| 4 | +# Vector space algebra |
4 | 5 |
|
5 | 6 | # Unary ops |
6 | 7 | @inline -(a::StaticArray) = map(-, a) |
@@ -30,10 +31,7 @@ import Base: +, -, *, /, \ |
30 | 31 | @inline -(a::UniformScaling, b::StaticMatrix) = _plus_uniform(Size(b), -b, a.λ) |
31 | 32 |
|
32 | 33 | @generated function _plus_uniform(::Size{S}, a::StaticMatrix, λ) where {S} |
33 | | - if S[1] != S[2] |
34 | | - throw(DimensionMismatch("matrix is not square: dimensions are $S")) |
35 | | - end |
36 | | - n = S[1] |
| 34 | + n = checksquare(a) |
37 | 35 | exprs = [i == j ? :(a[$(LinearIndices(S)[i, j])] + λ) : :(a[$(LinearIndices(S)[i, j])]) for i = 1:n, j = 1:n] |
38 | 36 | return quote |
39 | 37 | $(Expr(:meta, :inline)) |
|
46 | 44 | @inline \(a::UniformScaling, b::Union{StaticMatrix,StaticVector}) = a.λ \ b |
47 | 45 | @inline /(a::StaticMatrix, b::UniformScaling) = a / b.λ |
48 | 46 |
|
| 47 | +#-------------------------------------------------- |
| 48 | +# Matrix algebra |
49 | 49 |
|
50 | 50 | # Transpose, conjugate, etc |
51 | 51 | @inline conj(a::StaticArray) = map(conj, a) |
|
85 | 85 | @inline Base.zero(a::SA) where {SA <: StaticArray} = zeros(SA) |
86 | 86 | @inline Base.zero(a::Type{SA}) where {SA <: StaticArray} = zeros(SA) |
87 | 87 |
|
88 | | -@inline one(::SM) where {SM <: StaticMatrix} = _one(Size(SM), SM) |
89 | | -@inline one(::Type{SM}) where {SM <: StaticMatrix} = _one(Size(SM), SM) |
90 | | -@generated function _one(::Size{S}, ::Type{SM}) where {S, SM <: StaticArray} |
91 | | - if (length(S) != 2) || (S[1] != S[2]) |
92 | | - error("multiplicative identity defined only for square matrices") |
93 | | - end |
94 | | - T = eltype(SM) # should be "hyperpure" |
95 | | - if T == Any |
96 | | - T = Float64 |
97 | | - end |
98 | | - exprs = [i == j ? :(one($T)) : :(zero($T)) for i ∈ 1:S[1], j ∈ 1:S[2]] |
99 | | - return quote |
100 | | - $(Expr(:meta, :inline)) |
101 | | - SM(tuple($(exprs...))) |
| 88 | +@inline one(m::StaticMatrixLike) = _one(Size(m), m) |
| 89 | +@inline one(::Type{SM}) where {SM<:StaticMatrixLike}= _one(Size(SM), SM) |
| 90 | +function _one(s::Size, m_or_SM) |
| 91 | + if (length(s) != 2) || (s[1] != s[2]) |
| 92 | + throw(DimensionMismatch("multiplicative identity defined only for square matrices")) |
102 | 93 | end |
| 94 | + _scalar_matrix(s, m_or_SM, one(_eltype_or(m_or_SM, Float64))) |
103 | 95 | end |
104 | 96 |
|
105 | | -# StaticMatrix(I::UniformScaling) methods to replace eye |
106 | | -(::Type{SM})(I::UniformScaling) where {N,M,SM<:StaticMatrix{N,M}} = _eye(Size(SM), SM, I) |
107 | | - |
108 | | -@generated function _eye(::Size{S}, ::Type{SM}, I::UniformScaling{T}) where {S, SM <: StaticArray, T} |
109 | | - exprs = [i == j ? :(I.λ) : :(zero($T)) for i ∈ 1:S[1], j ∈ 1:S[2]] |
| 97 | +# StaticMatrix(I::UniformScaling) |
| 98 | +(::Type{SM})(I::UniformScaling) where {SM<:StaticMatrix} = _scalar_matrix(Size(SM), SM, I.λ) |
| 99 | +# The following oddity is needed if we want `SArray{Tuple{2,3}}(I)` to work |
| 100 | +# because we do not have `SArray{Tuple{2,3}} <: StaticMatrix`. |
| 101 | +(::Type{SM})(I::UniformScaling) where {SM<:(StaticArray{Tuple{N,M}} where {N,M})} = |
| 102 | + _scalar_matrix(Size(SM), SM, I.λ) |
| 103 | + |
| 104 | +# Construct a matrix with the scalar λ on the diagonal and zeros off the |
| 105 | +# diagonal. The matrix can be non-square. |
| 106 | +@generated function _scalar_matrix(s::Size{S}, m_or_SM, λ) where {S} |
| 107 | + elements = Symbol[i == j ? :λ : :λzero for i in 1:S[1], j in 1:S[2]] |
110 | 108 | return quote |
111 | 109 | $(Expr(:meta, :inline)) |
112 | | - SM(tuple($(exprs...))) |
| 110 | + λzero = zero(λ) |
| 111 | + _construct_similar(m_or_SM, s, tuple($(elements...))) |
113 | 112 | end |
114 | 113 | end |
115 | 114 |
|
|
145 | 144 | end |
146 | 145 | end |
147 | 146 |
|
| 147 | +#-------------------------------------------------- |
| 148 | +# Vector products |
148 | 149 | @inline cross(a::StaticVector, b::StaticVector) = _cross(same_size(a, b), a, b) |
149 | 150 | _cross(::Size{S}, a::StaticVector, b::StaticVector) where {S} = error("Cross product not defined for $(S[1])-vectors") |
150 | 151 | @inline function _cross(::Size{(2,)}, a::StaticVector, b::StaticVector) |
|
179 | 180 | return ret |
180 | 181 | end |
181 | 182 |
|
| 183 | +#-------------------------------------------------- |
| 184 | +# Norms |
182 | 185 | @inline LinearAlgebra.norm_sqr(v::StaticVector) = mapreduce(abs2, +, v; init=zero(real(eltype(v)))) |
183 | 186 |
|
184 | 187 | @inline norm(a::StaticArray) = _norm(Size(a), a) |
|
240 | 243 |
|
241 | 244 | @inline tr(a::StaticMatrix) = _tr(Size(a), a) |
242 | 245 | @generated function _tr(::Size{S}, a::StaticMatrix) where {S} |
243 | | - if S[1] != S[2] |
244 | | - throw(DimensionMismatch("matrix is not square")) |
245 | | - end |
| 246 | + checksquare(a) |
246 | 247 |
|
247 | 248 | if S[1] == 0 |
248 | 249 | return :(zero(eltype(a))) |
|
257 | 258 | end |
258 | 259 | end |
259 | 260 |
|
| 261 | + |
| 262 | +#-------------------------------------------------- |
| 263 | +# Outer products |
| 264 | + |
260 | 265 | const _length_limit = Length(200) |
261 | 266 |
|
262 | 267 | @inline kron(a::StaticMatrix, b::StaticMatrix) = _kron(_length_limit, Size(a), Size(b), a, b) |
|
414 | 419 | end |
415 | 420 | end |
416 | 421 |
|
417 | | -# some micro-optimizations (TODO check these make sense for v0.6+) |
418 | | -@inline LinearAlgebra.checksquare(::SM) where {SM<:StaticMatrix} = _checksquare(Size(SM)) |
419 | | -@inline LinearAlgebra.checksquare(::Type{SM}) where {SM<:StaticMatrix} = _checksquare(Size(SM)) |
420 | 422 |
|
421 | | -@pure _checksquare(::Size{S}) where {S} = (S[1] == S[2] || throw(DimensionMismatch("matrix is not square: dimensions are $S")); S[1]) |
| 423 | +#-------------------------------------------------- |
| 424 | +# Some shimming for special linear algebra matrix types |
| 425 | +@inline LinearAlgebra.Symmetric(A::StaticMatrix, uplo::Char='U') = (checksquare(A); Symmetric{eltype(A),typeof(A)}(A, uplo)) |
| 426 | +@inline LinearAlgebra.Hermitian(A::StaticMatrix, uplo::Char='U') = (checksquare(A); Hermitian{eltype(A),typeof(A)}(A, uplo)) |
422 | 427 |
|
423 | | -@inline LinearAlgebra.Symmetric(A::StaticMatrix, uplo::Char='U') = (LinearAlgebra.checksquare(A);Symmetric{eltype(A),typeof(A)}(A, uplo)) |
424 | | -@inline LinearAlgebra.Hermitian(A::StaticMatrix, uplo::Char='U') = (LinearAlgebra.checksquare(A);Hermitian{eltype(A),typeof(A)}(A, uplo)) |
|
0 commit comments