@@ -29,59 +29,34 @@ function similar_output(
2929end
3030
3131function MatrixAlgebraKit. initialize_output (
32- :: typeof (svd_compact!), A:: AbstractBlockSparseMatrix , alg:: BlockPermutedDiagonalAlgorithm
32+ :: typeof (svd_compact!), :: AbstractBlockSparseMatrix , :: BlockPermutedDiagonalAlgorithm
33+ )
34+ return nothing
35+ end
36+ function MatrixAlgebraKit. initialize_output (
37+ :: typeof (svd_compact!), A:: AbstractBlockSparseMatrix , alg:: BlockDiagonalAlgorithm
3338)
34- bm, bn = blocksize (A)
35- bmn = min (bm, bn)
36-
3739 brows = eachblockaxis (axes (A, 1 ))
3840 bcols = eachblockaxis (axes (A, 2 ))
39- u_axes = similar (brows, bmn)
40- v_axes = similar (brows, bmn)
41+ # using the property that zip stops as soon as one of the iterators is exhausted
42+ s_axes = map (splat (infimum), zip (brows, bcols))
43+ s_axis = mortar_axis (s_axes)
44+ S_axes = (s_axis, s_axis)
45+ U, S, Vᴴ = similar_output (svd_compact!, A, S_axes, alg)
4146
42- # fill in values for blocks that are present
43- bIs = collect (eachblockstoredindex (A))
44- browIs = Int .(first .(Tuple .(bIs)))
45- bcolIs = Int .(last .(Tuple .(bIs)))
4647 for bI in eachblockstoredindex (A)
47- row, col = Int .(Tuple (bI))
48- u_axes[col] = infimum (brows[row], bcols[col])
49- v_axes[col] = infimum (bcols[col], brows[row])
50- end
51-
52- # fill in values for blocks that aren't present, pairing them in order of occurence
53- # this is a convention, which at least gives the expected results for blockdiagonal
54- emptyrows = setdiff (1 : bm, browIs)
55- emptycols = setdiff (1 : bn, bcolIs)
56- for (row, col) in zip (emptyrows, emptycols)
57- u_axes[col] = infimum (brows[row], bcols[col])
58- v_axes[col] = infimum (bcols[col], brows[row])
59- end
60-
61- u_axis = mortar_axis (u_axes)
62- v_axis = mortar_axis (v_axes)
63- S_axes = (u_axis, v_axis)
64- U, S, Vt = similar_output (svd_compact!, A, S_axes, alg)
65-
66- # allocate output
67- for bI in eachblockstoredindex (A)
68- brow, bcol = Tuple (bI)
6948 block = @view! (A[bI])
7049 block_alg = block_algorithm (alg, block)
71- U[brow, bcol], S[bcol, bcol], Vt[bcol, bcol] = MatrixAlgebraKit. initialize_output (
50+ I = first (Tuple (bI)) # == last(Tuple(bI))
51+ U[I, I], S[I, I], Vᴴ[I, I] = MatrixAlgebraKit. initialize_output (
7252 svd_compact!, block, block_alg
7353 )
7454 end
7555
76- # allocate output for blocks that aren't present -- do we also fill identities here?
77- for (row, col) in zip (emptyrows, emptycols)
78- @view! (U[Block (row, col)])
79- @view! (Vt[Block (col, col)])
80- end
81-
82- return U, S, Vt
56+ return U, S, Vᴴ
8357end
8458
59+
8560function similar_output (
8661 :: typeof (svd_full!), A, S_axes, alg:: MatrixAlgebraKit.AbstractAlgorithm
8762)
@@ -93,65 +68,39 @@ function similar_output(
9368end
9469
9570function MatrixAlgebraKit. initialize_output (
96- :: typeof (svd_full!), A :: AbstractBlockSparseMatrix , alg :: BlockPermutedDiagonalAlgorithm
71+ :: typeof (svd_full!), :: AbstractBlockSparseMatrix , :: BlockPermutedDiagonalAlgorithm
9772)
98- bm, bn = blocksize (A)
99-
100- brows = eachblockaxis (axes (A, 1 ))
101- u_axes = similar (brows)
102-
103- # fill in values for blocks that are present
104- bIs = collect (eachblockstoredindex (A))
105- browIs = Int .(first .(Tuple .(bIs)))
106- bcolIs = Int .(last .(Tuple .(bIs)))
107- for bI in eachblockstoredindex (A)
108- row, col = Int .(Tuple (bI))
109- u_axes[col] = brows[row]
110- end
111-
112- # fill in values for blocks that aren't present, pairing them in order of occurence
113- # this is a convention, which at least gives the expected results for blockdiagonal
114- emptyrows = setdiff (1 : bm, browIs)
115- emptycols = setdiff (1 : bn, bcolIs)
116- for (row, col) in zip (emptyrows, emptycols)
117- u_axes[col] = brows[row]
118- end
119- for (i, k) in enumerate ((length (emptycols) + 1 ): length (emptyrows))
120- u_axes[bn + i] = brows[emptyrows[k]]
121- end
73+ return nothing
74+ end
12275
123- u_axis = mortar_axis (u_axes)
124- S_axes = (u_axis, axes (A, 2 ))
125- U, S, Vt = similar_output (svd_full!, A, S_axes, alg)
76+ function MatrixAlgebraKit. initialize_output (
77+ :: typeof (svd_full!), A:: AbstractBlockSparseMatrix , alg:: BlockDiagonalAlgorithm
78+ )
79+ U, S, Vᴴ = similar_output (svd_full!, A, axes (A), alg)
12680
127- # allocate output
12881 for bI in eachblockstoredindex (A)
129- brow, bcol = Tuple (bI)
13082 block = @view! (A[bI])
13183 block_alg = block_algorithm (alg, block)
132- U[brow, bcol], S[bcol, bcol], Vt[bcol, bcol] = MatrixAlgebraKit. initialize_output (
84+ I = first (Tuple (bI)) # == last(Tuple(bI))
85+ U[I, I], S[I, I], Vᴴ[I, I] = MatrixAlgebraKit. initialize_output (
13386 svd_full!, block, block_alg
13487 )
13588 end
13689
137- # allocate output for blocks that aren't present -- do we also fill identities here?
138- for (row, col) in zip (emptyrows, emptycols)
139- @view! (U[Block (row, col)])
140- @view! (Vt[Block (col, col)])
141- end
142- # also handle extra rows/cols
143- for i in (length (emptyrows) + 1 ): length (emptycols)
144- @view! (Vt[Block (emptycols[i], emptycols[i])])
145- end
146- for (i, k) in enumerate ((length (emptycols) + 1 ): length (emptyrows))
147- @view! (U[Block (emptyrows[k], bn + i)])
148- end
90+ return U, S, Vᴴ
91+ end
14992
150- return U, S, Vt
93+ function MatrixAlgebraKit. check_input (
94+ :: typeof (svd_compact!),
95+ A:: AbstractBlockSparseMatrix ,
96+ USVᴴ,
97+ :: BlockPermutedDiagonalAlgorithm ,
98+ )
99+ @assert isblockpermuteddiagonal (A)
151100end
152101
153102function MatrixAlgebraKit. check_input (
154- :: typeof (svd_compact!), A:: AbstractBlockSparseMatrix , (U, S, Vᴴ)
103+ :: typeof (svd_compact!), A:: AbstractBlockSparseMatrix , (U, S, Vᴴ), :: BlockDiagonalAlgorithm
155104)
156105 @assert isa (U, AbstractBlockSparseMatrix) &&
157106 isa (S, AbstractBlockSparseMatrix) &&
@@ -160,11 +109,19 @@ function MatrixAlgebraKit.check_input(
160109 @assert real (eltype (A)) == eltype (S)
161110 @assert axes (A, 1 ) == axes (U, 1 ) && axes (A, 2 ) == axes (Vᴴ, 2 )
162111 @assert axes (S, 1 ) == axes (S, 2 )
112+ @assert isblockdiagonal (A)
163113 return nothing
164114end
165115
166116function MatrixAlgebraKit. check_input (
167- :: typeof (svd_full!), A:: AbstractBlockSparseMatrix , (U, S, Vᴴ)
117+ :: typeof (svd_full!), A:: AbstractBlockSparseMatrix , USVᴴ, :: BlockPermutedDiagonalAlgorithm
118+ )
119+ @assert isblockpermuteddiagonal (A)
120+ return nothing
121+ end
122+
123+ function MatrixAlgebraKit. check_input (
124+ :: typeof (svd_full!), A:: AbstractBlockSparseMatrix , (U, S, Vᴴ), :: BlockDiagonalAlgorithm
168125)
169126 @assert isa (U, AbstractBlockSparseMatrix) &&
170127 isa (S, AbstractBlockSparseMatrix) &&
@@ -173,78 +130,92 @@ function MatrixAlgebraKit.check_input(
173130 @assert real (eltype (A)) == eltype (S)
174131 @assert axes (A, 1 ) == axes (U, 1 ) && axes (A, 2 ) == axes (Vᴴ, 1 ) == axes (Vᴴ, 2 )
175132 @assert axes (S, 2 ) == axes (A, 2 )
133+ @assert isblockdiagonal (A)
176134 return nothing
177135end
178136
179137function MatrixAlgebraKit. svd_compact! (
180- A:: AbstractBlockSparseMatrix , (U, S, Vᴴ) , alg:: BlockPermutedDiagonalAlgorithm
138+ A:: AbstractBlockSparseMatrix , USVᴴ , alg:: BlockPermutedDiagonalAlgorithm
181139)
182- check_input (svd_compact!, A, (U, S, Vᴴ) )
140+ check_input (svd_compact!, A, USVᴴ, alg )
183141
184- # do decomposition on each block
185- for bI in eachblockstoredindex (A)
186- brow, bcol = Tuple (bI)
187- usvᴴ = (@view! (U[brow, bcol]), @view! (S[bcol, bcol]), @view! (Vᴴ[bcol, bcol]))
188- block = @view! (A[bI])
189- block_alg = block_algorithm (alg, block)
190- usvᴴ′ = svd_compact! (block, usvᴴ, block_alg)
191- @assert usvᴴ === usvᴴ′ " svd_compact! might not be in-place"
192- end
142+ Ad, rowperm, colperm = blockdiagonalize (A)
143+ Ud, S, Vᴴd = svd_compact! (Ad, BlockDiagonalAlgorithm (alg))
144+
145+ inv_rowperm = Block .(invperm (Int .(rowperm)))
146+ U = Ud[inv_rowperm, :]
147+
148+ inv_colperm = Block .(invperm (Int .(colperm)))
149+ Vᴴ = Vᴴd[:, inv_colperm]
150+
151+ return U, S, Vᴴ
152+ end
193153
194- # fill in identities for blocks that aren't present
195- bIs = collect (eachblockstoredindex (A))
196- browIs = Int .(first .(Tuple .(bIs)))
197- bcolIs = Int .(last .(Tuple .(bIs)))
198- emptyrows = setdiff (1 : blocksize (A, 1 ), browIs)
199- emptycols = setdiff (1 : blocksize (A, 2 ), bcolIs)
200- # needs copyto! instead because size(::LinearAlgebra.I) doesn't work
201- # U[Block(row, col)] = LinearAlgebra.I
202- # Vᴴ[Block(col, col)] = LinearAlgebra.I
203- for (row, col) in zip (emptyrows, emptycols)
204- copyto! (@view! (U[Block (row, col)]), LinearAlgebra. I)
205- copyto! (@view! (Vᴴ[Block (col, col)]), LinearAlgebra. I)
154+ function MatrixAlgebraKit. svd_compact! (
155+ A:: AbstractBlockSparseMatrix , (U, S, Vᴴ), alg:: BlockDiagonalAlgorithm
156+ )
157+ check_input (svd_compact!, A, (U, S, Vᴴ), alg)
158+
159+ for I in 1 : min (blocksize (A)... )
160+ bI = Block (I, I)
161+ if isstored (blocks (A), CartesianIndex (I, I)) # TODO : isblockstored
162+ usvᴴ = (@view! (U[bI]), @view! (S[bI]), @view! (Vᴴ[bI]))
163+ block = @view! (A[bI])
164+ block_alg = block_algorithm (alg, block)
165+ usvᴴ′ = svd_compact! (block, usvᴴ, block_alg)
166+ @assert usvᴴ === usvᴴ′ " svd_compact! might not be in-place"
167+ else
168+ copyto! (@view! (U[bI]), LinearAlgebra. I)
169+ copyto! (@view! (Vᴴ[bI]), LinearAlgebra. I)
170+ end
206171 end
207172
208- return ( U, S, Vᴴ)
173+ return U, S, Vᴴ
209174end
210175
211176function MatrixAlgebraKit. svd_full! (
212- A:: AbstractBlockSparseMatrix , (U, S, Vᴴ) , alg:: BlockPermutedDiagonalAlgorithm
177+ A:: AbstractBlockSparseMatrix , USVᴴ , alg:: BlockPermutedDiagonalAlgorithm
213178)
214- check_input (svd_full!, A, (U, S, Vᴴ) )
179+ check_input (svd_full!, A, USVᴴ, alg )
215180
216- # do decomposition on each block
217- for bI in eachblockstoredindex (A)
218- brow, bcol = Tuple (bI)
219- usvᴴ = (@view! (U[brow, bcol]), @view! (S[bcol, bcol]), @view! (Vᴴ[bcol, bcol]))
220- block = @view! (A[bI])
221- block_alg = block_algorithm (alg, block)
222- usvᴴ′ = svd_full! (block, usvᴴ, block_alg)
223- @assert usvᴴ === usvᴴ′ " svd_full! might not be in-place"
224- end
181+ Ad, rowperm, colperm = blockdiagonalize (A)
182+ Ud, S, Vᴴd = svd_full! (Ad, BlockDiagonalAlgorithm (alg))
183+
184+ inv_rowperm = Block .(invperm (Int .(rowperm)))
185+ U = Ud[inv_rowperm, :]
225186
226- # fill in identities for blocks that aren't present
227- bIs = collect (eachblockstoredindex (A))
228- browIs = Int .(first .(Tuple .(bIs)))
229- bcolIs = Int .(last .(Tuple .(bIs)))
230- emptyrows = setdiff (1 : blocksize (A, 1 ), browIs)
231- emptycols = setdiff (1 : blocksize (A, 2 ), bcolIs)
232- # needs copyto! instead because size(::LinearAlgebra.I) doesn't work
233- # U[Block(row, col)] = LinearAlgebra.I
234- # Vt[Block(col, col)] = LinearAlgebra.I
235- for (row, col) in zip (emptyrows, emptycols)
236- copyto! (@view! (U[Block (row, col)]), LinearAlgebra. I)
237- copyto! (@view! (Vᴴ[Block (col, col)]), LinearAlgebra. I)
187+ inv_colperm = Block .(invperm (Int .(colperm)))
188+ Vᴴ = Vᴴd[:, inv_colperm]
189+
190+ return U, S, Vᴴ
191+ end
192+
193+ function MatrixAlgebraKit. svd_full! (
194+ A:: AbstractBlockSparseMatrix , (U, S, Vᴴ), alg:: BlockDiagonalAlgorithm
195+ )
196+ check_input (svd_full!, A, (U, S, Vᴴ), alg)
197+
198+ for I in 1 : min (blocksize (A)... )
199+ bI = Block (I, I)
200+ if isstored (blocks (A), CartesianIndex (I, I)) # TODO : isblockstored
201+ usvᴴ = (@view! (U[bI]), @view! (S[bI]), @view! (Vᴴ[bI]))
202+ block = @view! (A[bI])
203+ block_alg = block_algorithm (alg, block)
204+ usvᴴ′ = svd_full! (block, usvᴴ, block_alg)
205+ @assert usvᴴ === usvᴴ′ " svd_compact! might not be in-place"
206+ else
207+ copyto! (@view! (U[bI]), LinearAlgebra. I)
208+ copyto! (@view! (Vᴴ[bI]), LinearAlgebra. I)
209+ end
238210 end
239211
240- # also handle extra rows/cols
241- for i in ( length (emptyrows) + 1 ) : length (emptycols )
242- copyto! (@view! (Vᴴ [Block (emptycols[i], emptycols[i] )]), LinearAlgebra. I)
212+ # Complete the unitaries for rectangular inputs
213+ for I in blocksize (A, 2 ) + 1 : blocksize (A, 1 )
214+ copyto! (@view! (U [Block (I, I )]), LinearAlgebra. I)
243215 end
244- bn = blocksize (A, 2 )
245- for (i, k) in enumerate ((length (emptycols) + 1 ): length (emptyrows))
246- copyto! (@view! (U[Block (emptyrows[k], bn + i)]), LinearAlgebra. I)
216+ for I in blocksize (A, 1 )+ 1 : blocksize (A, 2 )
217+ copyto! (@view! (Vᴴ[Block (I, I)]), LinearAlgebra. I)
247218 end
248219
249- return ( U, S, Vᴴ)
220+ return U, S, Vᴴ
250221end
0 commit comments