11using DiagonalArrays: diagonaltype
22using MatrixAlgebraKit:
3- MatrixAlgebraKit, check_input, default_svd_algorithm, svd_compact!, svd_full!
3+ MatrixAlgebraKit, check_input, default_svd_algorithm, svd_compact!, svd_full!, svd_vals!
44using TypeParameterAccessors: realtype
55
66function MatrixAlgebraKit. default_svd_algorithm (
@@ -15,7 +15,15 @@ function output_type(
1515 f:: Union{typeof(svd_compact!),typeof(svd_full!)} , A:: Type{<:AbstractMatrix{T}}
1616) where {T}
1717 USVᴴ = Base. promote_op (f, A)
18- return isconcretetype (USVᴴ) ? USVᴴ : Tuple{AbstractMatrix{T},AbstractMatrix{realtype (T)},AbstractMatrix{T}}
18+ return if isconcretetype (USVᴴ)
19+ USVᴴ
20+ else
21+ Tuple{AbstractMatrix{T},AbstractMatrix{realtype (T)},AbstractMatrix{T}}
22+ end
23+ end
24+ function output_type (:: typeof (svd_vals!), A:: Type{<:AbstractMatrix{T}} ) where {T}
25+ S = Base. promote_op (svd_vals!, A)
26+ return isconcretetype (S) ? S : AbstractVector{real (T)}
1927end
2028
2129function MatrixAlgebraKit. initialize_output (
@@ -46,7 +54,6 @@ function MatrixAlgebraKit.initialize_output(
4654)
4755 return nothing
4856end
49-
5057function MatrixAlgebraKit. initialize_output (
5158 :: typeof (svd_full!), A:: AbstractBlockSparseMatrix , alg:: BlockDiagonalAlgorithm
5259)
@@ -58,6 +65,24 @@ function MatrixAlgebraKit.initialize_output(
5865 return U, S, Vᴴ
5966end
6067
68+ function MatrixAlgebraKit. initialize_output (
69+ :: typeof (svd_vals!), :: AbstractBlockSparseMatrix , :: BlockDiagonalAlgorithm
70+ )
71+ return nothing
72+ end
73+ function MatrixAlgebraKit. initialize_output (
74+ :: typeof (svd_vals!), A:: AbstractBlockSparseMatrix , alg:: BlockDiagonalAlgorithm
75+ )
76+ brows = eachblockaxis (axes (A, 1 ))
77+ bcols = eachblockaxis (axes (A, 2 ))
78+ # using the property that zip stops as soon as one of the iterators is exhausted
79+ s_axes = map (splat (infimum), zip (brows, bcols))
80+ s_axis = mortar_axis (s_axes)
81+
82+ BS = output_type (svd_vals!, blocktype (A))
83+ return similar (A, BlockType (BS), S_axes)
84+ end
85+
6186function MatrixAlgebraKit. check_input (
6287 :: typeof (svd_compact!),
6388 A:: AbstractBlockSparseMatrix ,
@@ -66,7 +91,6 @@ function MatrixAlgebraKit.check_input(
6691)
6792 @assert isblockpermuteddiagonal (A)
6893end
69-
7094function MatrixAlgebraKit. check_input (
7195 :: typeof (svd_compact!), A:: AbstractBlockSparseMatrix , (U, S, Vᴴ), :: BlockDiagonalAlgorithm
7296)
@@ -87,7 +111,6 @@ function MatrixAlgebraKit.check_input(
87111 @assert isblockpermuteddiagonal (A)
88112 return nothing
89113end
90-
91114function MatrixAlgebraKit. check_input (
92115 :: typeof (svd_full!), A:: AbstractBlockSparseMatrix , (U, S, Vᴴ), :: BlockDiagonalAlgorithm
93116)
@@ -102,15 +125,30 @@ function MatrixAlgebraKit.check_input(
102125 return nothing
103126end
104127
128+ function MatrixAlgebraKit. check_input (
129+ :: typeof (svd_vals!), A:: AbstractBlockSparseMatrix , S, :: BlockPermutedDiagonalAlgorithm
130+ )
131+ @assert isblockpermuteddiagonal (A)
132+ return nothing
133+ end
134+ function MatrixAlgebraKit. check_input (
135+ :: typeof (svd_vals!), A:: AbstractBlockSparseMatrix , S, :: BlockDiagonalAlgorithm
136+ )
137+ @assert isa (S, AbstractBlockSparseVector)
138+ @assert real (eltype (A)) == eltype (S)
139+ @assert isblockdiagonal (A)
140+ return nothing
141+ end
142+
105143function MatrixAlgebraKit. svd_compact! (
106144 A:: AbstractBlockSparseMatrix , USVᴴ, alg:: BlockPermutedDiagonalAlgorithm
107145)
108146 check_input (svd_compact!, A, USVᴴ, alg)
109147
110- Ad, transform_rows, transform_cols = blockdiagonalize (A)
148+ Ad, (invrowperm, invcolperm) = blockdiagonalize (A)
111149 Ud, S, Vᴴd = svd_compact! (Ad, BlockDiagonalAlgorithm (alg))
112- U = transform_rows (Ud)
113- Vᴴ = transform_cols (Vᴴd)
150+ U = transform_rows (Ud, invrowperm )
151+ Vᴴ = transform_cols (Vᴴd, invcolperm )
114152
115153 return U, S, Vᴴ
116154end
@@ -143,10 +181,10 @@ function MatrixAlgebraKit.svd_full!(
143181)
144182 check_input (svd_full!, A, USVᴴ, alg)
145183
146- Ad, transform_rows, transform_cols = blockdiagonalize (A)
184+ Ad, (invrowperm, invcolperm) = blockdiagonalize (A)
147185 Ud, S, Vᴴd = svd_full! (Ad, BlockDiagonalAlgorithm (alg))
148- U = transform_rows (Ud)
149- Vᴴ = transform_cols (Vᴴd)
186+ U = transform_rows (Ud, invrowperm )
187+ Vᴴ = transform_cols (Vᴴd, invcolperm )
150188
151189 return U, S, Vᴴ
152190end
@@ -181,3 +219,21 @@ function MatrixAlgebraKit.svd_full!(
181219
182220 return U, S, Vᴴ
183221end
222+
223+ function MatrixAlgebraKit. svd_vals! (
224+ A:: AbstractBlockSparseMatrix , S, alg:: BlockPermutedDiagonalAlgorithm
225+ )
226+ MatrixAlgebraKit. check_input (svd_vals!, A, S, alg)
227+ Ad, _ = blockdiagonalize (A)
228+ return svd_vals! (Ad, BlockDiagonalAlgorithm (alg))
229+ end
230+ function MatrixAlgebraKit. svd_vals! (
231+ A:: AbstractBlockSparseMatrix , S, alg:: BlockDiagonalAlgorithm
232+ )
233+ MatrixAlgebraKit. check_input (svd_vals!, A, S, alg)
234+ for I in eachblockstoredindex (A)
235+ block = @view! (A[I])
236+ S[Tuple (I)[1 ]] = $ f (block, block_algorithm (alg, block))
237+ end
238+ return S
239+ end
0 commit comments