1- using MatrixAlgebraKit:
2- MatrixAlgebraKit, default_qr_algorithm, lq_compact!, lq_full!, qr_compact!, qr_full!
1+ using MatrixAlgebraKit: MatrixAlgebraKit, default_qr_algorithm, qr_compact!, qr_full!
32
43function MatrixAlgebraKit. default_qr_algorithm (
54 :: Type{<:AbstractBlockSparseMatrix} ; kwargs...
@@ -9,211 +8,143 @@ function MatrixAlgebraKit.default_qr_algorithm(
98 end
109end
1110
12- function similar_output (
13- :: typeof (qr_compact!), A, R_axis, alg:: MatrixAlgebraKit.AbstractAlgorithm
14- )
15- Q = similar (A, axes (A, 1 ), R_axis)
16- R = similar (A, R_axis, axes (A, 2 ))
17- return Q, R
11+ function output_type (
12+ f:: Union{typeof(qr_compact!),typeof(qr_full!)} , A:: Type{<:AbstractMatrix{T}}
13+ ) where {T}
14+ QR = Base. promote_op (f, A)
15+ return isconcretetype (QR) ? QR : Tuple{AbstractMatrix{T},AbstractMatrix{T}}
1816end
1917
20- function similar_output (
21- :: typeof (qr_full !), A, R_axis, alg :: MatrixAlgebraKit.AbstractAlgorithm
18+ function MatrixAlgebraKit . initialize_output (
19+ :: typeof (qr_compact !), :: AbstractBlockSparseMatrix , :: BlockPermutedDiagonalAlgorithm
2220)
23- Q = similar (A, axes (A, 1 ), R_axis)
24- R = similar (A, R_axis, axes (A, 2 ))
25- return Q, R
21+ return nothing
2622end
27-
2823function MatrixAlgebraKit. initialize_output (
29- :: typeof (qr_compact!), A:: AbstractBlockSparseMatrix , alg:: BlockPermutedDiagonalAlgorithm
24+ :: typeof (qr_compact!), A:: AbstractBlockSparseMatrix , alg:: BlockDiagonalAlgorithm
3025)
31- bm, bn = blocksize (A)
32- bmn = min (bm, bn)
33-
3426 brows = eachblockaxis (axes (A, 1 ))
3527 bcols = eachblockaxis (axes (A, 2 ))
36- r_axes = similar (brows, bmn)
37-
38- # fill in values for blocks that are present
39- bIs = collect (eachblockstoredindex (A))
40- browIs = Int .(first .(Tuple .(bIs)))
41- bcolIs = Int .(last .(Tuple .(bIs)))
42- for bI in eachblockstoredindex (A)
43- row, col = Int .(Tuple (bI))
44- len = minimum (length, (brows[row], bcols[col]))
45- r_axes[col] = brows[row][Base. OneTo (len)]
46- end
47-
48- # fill in values for blocks that aren't present, pairing them in order of occurence
49- # this is a convention, which at least gives the expected results for blockdiagonal
50- emptyrows = setdiff (1 : bm, browIs)
51- emptycols = setdiff (1 : bn, bcolIs)
52- for (row, col) in zip (emptyrows, emptycols)
53- len = minimum (length, (brows[row], bcols[col]))
54- r_axes[col] = brows[row][Base. OneTo (len)]
55- end
56-
28+ # using the property that zip stops as soon as one of the iterators is exhausted
29+ r_axes = map (splat (infimum), zip (brows, bcols))
5730 r_axis = mortar_axis (r_axes)
58- Q, R = similar_output (qr_compact!, A, r_axis, alg)
59-
60- # allocate output
61- for bI in eachblockstoredindex (A)
62- brow, bcol = Tuple (bI)
63- block = @view! (A[bI])
64- block_alg = block_algorithm (alg, block)
65- Q[brow, bcol], R[bcol, bcol] = MatrixAlgebraKit. initialize_output (
66- qr_compact!, block, block_alg
67- )
68- end
6931
70- # allocate output for blocks that aren't present -- do we also fill identities here?
71- for (row, col) in zip (emptyrows, emptycols)
72- @view! (Q[Block (row, col)])
73- end
32+ BQ, BR = fieldtypes (output_type (qr_compact!, blocktype (A)))
33+ Q = similar (A, BlockType (BQ), (axes (A, 1 ), r_axis))
34+ R = similar (A, BlockType (BR), (r_axis, axes (A, 2 )))
7435
7536 return Q, R
7637end
7738
7839function MatrixAlgebraKit. initialize_output (
79- :: typeof (qr_full!), A :: AbstractBlockSparseMatrix , alg :: BlockPermutedDiagonalAlgorithm
40+ :: typeof (qr_full!), :: AbstractBlockSparseMatrix , :: BlockPermutedDiagonalAlgorithm
8041)
81- bm, bn = blocksize (A)
82-
83- brows = eachblockaxis (axes (A, 1 ))
84- r_axes = copy (brows)
85-
86- # fill in values for blocks that are present
87- bIs = collect (eachblockstoredindex (A))
88- browIs = Int .(first .(Tuple .(bIs)))
89- bcolIs = Int .(last .(Tuple .(bIs)))
90- for bI in eachblockstoredindex (A)
91- row, col = Int .(Tuple (bI))
92- r_axes[col] = brows[row]
93- end
94-
95- # fill in values for blocks that aren't present, pairing them in order of occurence
96- # this is a convention, which at least gives the expected results for blockdiagonal
97- emptyrows = setdiff (1 : bm, browIs)
98- emptycols = setdiff (1 : bn, bcolIs)
99- for (row, col) in zip (emptyrows, emptycols)
100- r_axes[col] = brows[row]
101- end
102- for (i, k) in enumerate ((length (emptycols) + 1 ): length (emptyrows))
103- r_axes[bn + i] = brows[emptyrows[k]]
104- end
105-
106- r_axis = mortar_axis (r_axes)
107- Q, R = similar_output (qr_full!, A, r_axis, alg)
108-
109- # allocate output
110- for bI in eachblockstoredindex (A)
111- brow, bcol = Tuple (bI)
112- block = @view! (A[bI])
113- block_alg = block_algorithm (alg, block)
114- Q[brow, bcol], R[bcol, bcol] = MatrixAlgebraKit. initialize_output (
115- qr_full!, block, block_alg
116- )
117- end
118-
119- # allocate output for blocks that aren't present -- do we also fill identities here?
120- for (row, col) in zip (emptyrows, emptycols)
121- @view! (Q[Block (row, col)])
122- end
123- # also handle extra rows/cols
124- for (i, k) in enumerate ((length (emptycols) + 1 ): length (emptyrows))
125- @view! (Q[Block (emptyrows[k], bn + i)])
126- end
127-
42+ return nothing
43+ end
44+ function MatrixAlgebraKit. initialize_output (
45+ :: typeof (qr_full!), A:: AbstractBlockSparseMatrix , alg:: BlockDiagonalAlgorithm
46+ )
47+ BQ, BR = fieldtypes (output_type (qr_full!, blocktype (A)))
48+ Q = similar (A, BlockType (BQ), (axes (A, 1 ), axes (A, 1 )))
49+ R = similar (A, BlockType (BR), (axes (A, 1 ), axes (A, 2 )))
12850 return Q, R
12951end
13052
13153function MatrixAlgebraKit. check_input (
132- :: typeof (qr_compact!), A:: AbstractBlockSparseMatrix , QR
54+ :: typeof (qr_compact!), A:: AbstractBlockSparseMatrix , QR, :: BlockPermutedDiagonalAlgorithm
55+ )
56+ @assert isblockpermuteddiagonal (A)
57+ return nothing
58+ end
59+ function MatrixAlgebraKit. check_input (
60+ :: typeof (qr_compact!), A:: AbstractBlockSparseMatrix , (Q, R), :: BlockDiagonalAlgorithm
13361)
134- Q, R = QR
13562 @assert isa (Q, AbstractBlockSparseMatrix) && isa (R, AbstractBlockSparseMatrix)
13663 @assert eltype (A) == eltype (Q) == eltype (R)
13764 @assert axes (A, 1 ) == axes (Q, 1 ) && axes (A, 2 ) == axes (R, 2 )
13865 @assert axes (Q, 2 ) == axes (R, 1 )
139-
66+ @assert isblockdiagonal (A)
14067 return nothing
14168end
14269
143- function MatrixAlgebraKit. check_input (:: typeof (qr_full!), A:: AbstractBlockSparseMatrix , QR)
144- Q, R = QR
70+ function MatrixAlgebraKit. check_input (
71+ :: typeof (qr_full!), A:: AbstractBlockSparseMatrix , QR, :: BlockPermutedDiagonalAlgorithm
72+ )
73+ @assert isblockpermuteddiagonal (A)
74+ return nothing
75+ end
76+ function MatrixAlgebraKit. check_input (
77+ :: typeof (qr_full!), A:: AbstractBlockSparseMatrix , (Q, R), :: BlockDiagonalAlgorithm
78+ )
14579 @assert isa (Q, AbstractBlockSparseMatrix) && isa (R, AbstractBlockSparseMatrix)
14680 @assert eltype (A) == eltype (Q) == eltype (R)
14781 @assert axes (A, 1 ) == axes (Q, 1 ) && axes (A, 2 ) == axes (R, 2 )
14882 @assert axes (Q, 2 ) == axes (R, 1 )
149-
83+ @assert isblockdiagonal (A)
15084 return nothing
15185end
15286
15387function MatrixAlgebraKit. qr_compact! (
15488 A:: AbstractBlockSparseMatrix , QR, alg:: BlockPermutedDiagonalAlgorithm
15589)
156- MatrixAlgebraKit. check_input (qr_compact!, A, QR)
157- Q, R = QR
90+ check_input (qr_compact!, A, QR, alg)
91+ Ad, transform_rows, transform_cols = blockdiagonalize (A)
92+ Qd, Rd = qr_compact! (Ad, BlockDiagonalAlgorithm (alg))
93+ Q = transform_rows (Qd)
94+ R = transform_cols (Rd)
95+ return Q, R
96+ end
15897
159- # do decomposition on each block
160- for bI in eachblockstoredindex (A)
161- brow, bcol = Tuple (bI)
162- qr = (@view! (Q[brow, bcol]), @view! (R[bcol, bcol]))
163- block = @view! (A[bI])
164- block_alg = block_algorithm (alg, block)
165- qr′ = qr_compact! (block, qr, block_alg)
166- @assert qr === qr′ " qr_compact! might not be in-place"
167- end
98+ function MatrixAlgebraKit. qr_compact! (
99+ A:: AbstractBlockSparseMatrix , (Q, R), alg:: BlockDiagonalAlgorithm
100+ )
101+ MatrixAlgebraKit. check_input (qr_compact!, A, (Q, R), alg)
168102
169- # fill in identities for blocks that aren't present
170- bIs = collect (eachblockstoredindex (A))
171- browIs = Int .(first .(Tuple .(bIs)))
172- bcolIs = Int .(last .(Tuple .(bIs)))
173- emptyrows = setdiff (1 : blocksize (A, 1 ), browIs)
174- emptycols = setdiff (1 : blocksize (A, 2 ), bcolIs)
175- # needs copyto! instead because size(::LinearAlgebra.I) doesn't work
176- # Q[Block(row, col)] = LinearAlgebra.I
177- for (row, col) in zip (emptyrows, emptycols)
178- copyto! (@view! (Q[Block (row, col)]), LinearAlgebra. I)
103+ # do decomposition on each block
104+ for I in 1 : min (blocksize (A)... )
105+ bI = Block (I, I)
106+ if isstored (blocks (A), CartesianIndex (I, I)) # TODO : isblockstored
107+ block = @view! (A[bI])
108+ block_alg = block_algorithm (alg, block)
109+ bQ, bR = qr_compact! (block, block_alg)
110+ Q[bI] = bQ
111+ R[bI] = bR
112+ else
113+ copyto! (@view! (Q[bI]), LinearAlgebra. I)
114+ end
179115 end
180116
181- return QR
117+ return Q, R
182118end
183119
184120function MatrixAlgebraKit. qr_full! (
185121 A:: AbstractBlockSparseMatrix , QR, alg:: BlockPermutedDiagonalAlgorithm
186122)
187- MatrixAlgebraKit. check_input (qr_full!, A, QR)
188- Q, R = QR
189-
190- # do decomposition on each block
191- for bI in eachblockstoredindex (A)
192- brow, bcol = Tuple (bI)
193- qr = (@view! (Q[brow, bcol]), @view! (R[bcol, bcol]))
194- block = @view! (A[bI])
195- block_alg = block_algorithm (alg, block)
196- qr′ = qr_full! (block, qr, block_alg)
197- @assert qr === qr′ " qr_full! might not be in-place"
198- end
123+ check_input (qr_full!, A, QR, alg)
124+ Ad, transform_rows, transform_cols = blockdiagonalize (A)
125+ Qd, Rd = qr_full! (Ad, BlockDiagonalAlgorithm (alg))
126+ Q = transform_rows (Qd)
127+ R = transform_cols (Rd)
128+ return Q, R
129+ end
199130
200- # fill in identities for blocks that aren't present
201- bIs = collect (eachblockstoredindex (A))
202- browIs = Int .(first .(Tuple .(bIs)))
203- bcolIs = Int .(last .(Tuple .(bIs)))
204- emptyrows = setdiff (1 : blocksize (A, 1 ), browIs)
205- emptycols = setdiff (1 : blocksize (A, 2 ), bcolIs)
206- # needs copyto! instead because size(::LinearAlgebra.I) doesn't work
207- # Q[Block(row, col)] = LinearAlgebra.I
208- for (row, col) in zip (emptyrows, emptycols)
209- copyto! (@view! (Q[Block (row, col)]), LinearAlgebra. I)
210- end
131+ function MatrixAlgebraKit. qr_full! (
132+ A:: AbstractBlockSparseMatrix , (Q, R), alg:: BlockDiagonalAlgorithm
133+ )
134+ MatrixAlgebraKit. check_input (qr_full!, A, (Q, R), alg)
211135
212- # also handle extra rows/cols
213- bn = blocksize (A, 2 )
214- for (i, k) in enumerate ((length (emptycols) + 1 ): length (emptyrows))
215- copyto! (@view! (Q[Block (emptyrows[k], bn + i)]), LinearAlgebra. I)
136+ for I in 1 : min (blocksize (A)... )
137+ bI = Block (I, I)
138+ if isstored (blocks (A), CartesianIndex (I, I)) # TODO : isblockstored
139+ block = @view! (A[bI])
140+ block_alg = block_algorithm (alg, block)
141+ bQ, bR = qr_full! (block, block_alg)
142+ Q[bI] = bQ
143+ R[bI] = bR
144+ else
145+ copyto! (@view! (Q[bI]), LinearAlgebra. I)
146+ end
216147 end
217148
218- return QR
149+ return Q, R
219150end
0 commit comments