@@ -9,211 +9,143 @@ function MatrixAlgebraKit.default_qr_algorithm(
99 end
1010end
1111
12- function similar_output (
13- :: typeof (qr_compact!), A, R_axis, alg:: MatrixAlgebraKit.AbstractAlgorithm
14- )
15- Q = similar (A, axes (A, 1 ), R_axis)
16- R = similar (A, R_axis, axes (A, 2 ))
17- return Q, R
12+ function output_type (
13+ f:: Union{typeof(qr_compact!),typeof(qr_full!)} , A:: Type{<:AbstractMatrix{T}}
14+ ) where {T}
15+ QR = Base. promote_op (f, A)
16+ return isconcretetype (QR) ? QR : Tuple{AbstractMatrix{T},AbstractMatrix{T}}
1817end
1918
20- function similar_output (
21- :: typeof (qr_full !), A, R_axis, alg :: MatrixAlgebraKit.AbstractAlgorithm
19+ function MatrixAlgebraKit . initialize_output (
20+ :: typeof (qr_compact !), :: AbstractBlockSparseMatrix , :: BlockPermutedDiagonalAlgorithm
2221)
23- Q = similar (A, axes (A, 1 ), R_axis)
24- R = similar (A, R_axis, axes (A, 2 ))
25- return Q, R
22+ return nothing
2623end
27-
2824function MatrixAlgebraKit. initialize_output (
29- :: typeof (qr_compact!), A:: AbstractBlockSparseMatrix , alg:: BlockPermutedDiagonalAlgorithm
25+ :: typeof (qr_compact!), A:: AbstractBlockSparseMatrix , alg:: BlockDiagonalAlgorithm
3026)
31- bm, bn = blocksize (A)
32- bmn = min (bm, bn)
33-
3427 brows = eachblockaxis (axes (A, 1 ))
3528 bcols = eachblockaxis (axes (A, 2 ))
36- r_axes = similar (brows, bmn)
37-
38- # fill in values for blocks that are present
39- bIs = collect (eachblockstoredindex (A))
40- browIs = Int .(first .(Tuple .(bIs)))
41- bcolIs = Int .(last .(Tuple .(bIs)))
42- for bI in eachblockstoredindex (A)
43- row, col = Int .(Tuple (bI))
44- len = minimum (length, (brows[row], bcols[col]))
45- r_axes[col] = brows[row][Base. OneTo (len)]
46- end
47-
48- # fill in values for blocks that aren't present, pairing them in order of occurence
49- # this is a convention, which at least gives the expected results for blockdiagonal
50- emptyrows = setdiff (1 : bm, browIs)
51- emptycols = setdiff (1 : bn, bcolIs)
52- for (row, col) in zip (emptyrows, emptycols)
53- len = minimum (length, (brows[row], bcols[col]))
54- r_axes[col] = brows[row][Base. OneTo (len)]
55- end
56-
29+ # using the property that zip stops as soon as one of the iterators is exhausted
30+ r_axes = map (splat (infimum), zip (brows, bcols))
5731 r_axis = mortar_axis (r_axes)
58- Q, R = similar_output (qr_compact!, A, r_axis, alg)
59-
60- # allocate output
61- for bI in eachblockstoredindex (A)
62- brow, bcol = Tuple (bI)
63- block = @view! (A[bI])
64- block_alg = block_algorithm (alg, block)
65- Q[brow, bcol], R[bcol, bcol] = MatrixAlgebraKit. initialize_output (
66- qr_compact!, block, block_alg
67- )
68- end
6932
70- # allocate output for blocks that aren't present -- do we also fill identities here?
71- for (row, col) in zip (emptyrows, emptycols)
72- @view! (Q[Block (row, col)])
73- end
33+ BQ, BR = fieldtypes (output_type (qr_compact!, blocktype (A)))
34+ Q = similar (A, BlockType (BQ), (axes (A, 1 ), r_axis))
35+ R = similar (A, BlockType (BR), (r_axis, axes (A, 2 )))
7436
7537 return Q, R
7638end
7739
7840function MatrixAlgebraKit. initialize_output (
79- :: typeof (qr_full!), A :: AbstractBlockSparseMatrix , alg :: BlockPermutedDiagonalAlgorithm
41+ :: typeof (qr_full!), :: AbstractBlockSparseMatrix , :: BlockPermutedDiagonalAlgorithm
8042)
81- bm, bn = blocksize (A)
82-
83- brows = eachblockaxis (axes (A, 1 ))
84- r_axes = copy (brows)
85-
86- # fill in values for blocks that are present
87- bIs = collect (eachblockstoredindex (A))
88- browIs = Int .(first .(Tuple .(bIs)))
89- bcolIs = Int .(last .(Tuple .(bIs)))
90- for bI in eachblockstoredindex (A)
91- row, col = Int .(Tuple (bI))
92- r_axes[col] = brows[row]
93- end
94-
95- # fill in values for blocks that aren't present, pairing them in order of occurence
96- # this is a convention, which at least gives the expected results for blockdiagonal
97- emptyrows = setdiff (1 : bm, browIs)
98- emptycols = setdiff (1 : bn, bcolIs)
99- for (row, col) in zip (emptyrows, emptycols)
100- r_axes[col] = brows[row]
101- end
102- for (i, k) in enumerate ((length (emptycols) + 1 ): length (emptyrows))
103- r_axes[bn + i] = brows[emptyrows[k]]
104- end
105-
106- r_axis = mortar_axis (r_axes)
107- Q, R = similar_output (qr_full!, A, r_axis, alg)
108-
109- # allocate output
110- for bI in eachblockstoredindex (A)
111- brow, bcol = Tuple (bI)
112- block = @view! (A[bI])
113- block_alg = block_algorithm (alg, block)
114- Q[brow, bcol], R[bcol, bcol] = MatrixAlgebraKit. initialize_output (
115- qr_full!, block, block_alg
116- )
117- end
118-
119- # allocate output for blocks that aren't present -- do we also fill identities here?
120- for (row, col) in zip (emptyrows, emptycols)
121- @view! (Q[Block (row, col)])
122- end
123- # also handle extra rows/cols
124- for (i, k) in enumerate ((length (emptycols) + 1 ): length (emptyrows))
125- @view! (Q[Block (emptyrows[k], bn + i)])
126- end
127-
43+ return nothing
44+ end
45+ function MatrixAlgebraKit. initialize_output (
46+ :: typeof (qr_full!), A:: AbstractBlockSparseMatrix , alg:: BlockDiagonalAlgorithm
47+ )
48+ BQ, BR = fieldtypes (output_type (qr_compact!, blocktype (A)))
49+ Q = similar (A, BlockType (BQ), (axes (A, 1 ), axes (A, 1 )))
50+ R = similar (A, BlockType (BR), (axes (A, 1 ), axes (A, 2 )))
12851 return Q, R
12952end
13053
13154function MatrixAlgebraKit. check_input (
132- :: typeof (qr_compact!), A:: AbstractBlockSparseMatrix , QR
55+ :: typeof (qr_compact!), A:: AbstractBlockSparseMatrix , QR, :: BlockPermutedDiagonalAlgorithm
56+ )
57+ @assert isblockpermuteddiagonal (A)
58+ return nothing
59+ end
60+ function MatrixAlgebraKit. check_input (
61+ :: typeof (qr_compact!), A:: AbstractBlockSparseMatrix , (Q, R), :: BlockDiagonalAlgorithm
13362)
134- Q, R = QR
13563 @assert isa (Q, AbstractBlockSparseMatrix) && isa (R, AbstractBlockSparseMatrix)
13664 @assert eltype (A) == eltype (Q) == eltype (R)
13765 @assert axes (A, 1 ) == axes (Q, 1 ) && axes (A, 2 ) == axes (R, 2 )
13866 @assert axes (Q, 2 ) == axes (R, 1 )
139-
67+ @assert isblockdiagonal (A)
14068 return nothing
14169end
14270
143- function MatrixAlgebraKit. check_input (:: typeof (qr_full!), A:: AbstractBlockSparseMatrix , QR)
144- Q, R = QR
71+ function MatrixAlgebraKit. check_input (
72+ :: typeof (qr_full!), A:: AbstractBlockSparseMatrix , QR, :: BlockPermutedDiagonalAlgorithm
73+ )
74+ @assert isblockpermuteddiagonal (A)
75+ return nothing
76+ end
77+ function MatrixAlgebraKit. check_input (
78+ :: typeof (qr_full!), A:: AbstractBlockSparseMatrix , (Q, R), :: BlockDiagonalAlgorithm
79+ )
14580 @assert isa (Q, AbstractBlockSparseMatrix) && isa (R, AbstractBlockSparseMatrix)
14681 @assert eltype (A) == eltype (Q) == eltype (R)
14782 @assert axes (A, 1 ) == axes (Q, 1 ) && axes (A, 2 ) == axes (R, 2 )
14883 @assert axes (Q, 2 ) == axes (R, 1 )
149-
84+ @assert isblockdiagonal (A)
15085 return nothing
15186end
15287
15388function MatrixAlgebraKit. qr_compact! (
15489 A:: AbstractBlockSparseMatrix , QR, alg:: BlockPermutedDiagonalAlgorithm
15590)
156- MatrixAlgebraKit. check_input (qr_compact!, A, QR)
157- Q, R = QR
91+ check_input (qr_compact!, A, QR, alg)
92+ Ad, transform_rows, transform_cols = blockdiagonalize (A)
93+ Qd, Rd = qr_compact! (Ad, BlockDiagonalAlgorithm (alg))
94+ Q = transform_rows (Qd)
95+ R = transform_cols (Rd)
96+ return Q, R
97+ end
15898
159- # do decomposition on each block
160- for bI in eachblockstoredindex (A)
161- brow, bcol = Tuple (bI)
162- qr = (@view! (Q[brow, bcol]), @view! (R[bcol, bcol]))
163- block = @view! (A[bI])
164- block_alg = block_algorithm (alg, block)
165- qr′ = qr_compact! (block, qr, block_alg)
166- @assert qr === qr′ " qr_compact! might not be in-place"
167- end
99+ function MatrixAlgebraKit. qr_compact! (
100+ A:: AbstractBlockSparseMatrix , (Q, R), alg:: BlockDiagonalAlgorithm
101+ )
102+ MatrixAlgebraKit. check_input (qr_compact!, A, (Q, R), alg)
168103
169- # fill in identities for blocks that aren't present
170- bIs = collect (eachblockstoredindex (A))
171- browIs = Int .(first .(Tuple .(bIs)))
172- bcolIs = Int .(last .(Tuple .(bIs)))
173- emptyrows = setdiff (1 : blocksize (A, 1 ), browIs)
174- emptycols = setdiff (1 : blocksize (A, 2 ), bcolIs)
175- # needs copyto! instead because size(::LinearAlgebra.I) doesn't work
176- # Q[Block(row, col)] = LinearAlgebra.I
177- for (row, col) in zip (emptyrows, emptycols)
178- copyto! (@view! (Q[Block (row, col)]), LinearAlgebra. I)
104+ # do decomposition on each block
105+ for I in 1 : min (blocksize (A)... )
106+ bI = Block (I, I)
107+ if isstored (blocks (A), CartesianIndex (I, I)) # TODO : isblockstored
108+ block = @view! (A[bI])
109+ block_alg = block_algorithm (alg, block)
110+ bQ, bR = qr_compact! (block, block_alg)
111+ Q[bI] = bQ
112+ R[bI] = bR
113+ else
114+ copyto! (@view! (Q[bI]), LinearAlgebra. I)
115+ end
179116 end
180117
181- return QR
118+ return Q, R
182119end
183120
184121function MatrixAlgebraKit. qr_full! (
185122 A:: AbstractBlockSparseMatrix , QR, alg:: BlockPermutedDiagonalAlgorithm
186123)
187- MatrixAlgebraKit. check_input (qr_full!, A, QR)
188- Q, R = QR
189-
190- # do decomposition on each block
191- for bI in eachblockstoredindex (A)
192- brow, bcol = Tuple (bI)
193- qr = (@view! (Q[brow, bcol]), @view! (R[bcol, bcol]))
194- block = @view! (A[bI])
195- block_alg = block_algorithm (alg, block)
196- qr′ = qr_full! (block, qr, block_alg)
197- @assert qr === qr′ " qr_full! might not be in-place"
198- end
124+ check_input (qr_full!, A, QR, alg)
125+ Ad, transform_rows, transform_cols = blockdiagonalize (A)
126+ Qd, Rd = qr_full! (Ad, BlockDiagonalAlgorithm (alg))
127+ Q = transform_rows (Qd)
128+ R = transform_cols (Rd)
129+ return Q, R
130+ end
199131
200- # fill in identities for blocks that aren't present
201- bIs = collect (eachblockstoredindex (A))
202- browIs = Int .(first .(Tuple .(bIs)))
203- bcolIs = Int .(last .(Tuple .(bIs)))
204- emptyrows = setdiff (1 : blocksize (A, 1 ), browIs)
205- emptycols = setdiff (1 : blocksize (A, 2 ), bcolIs)
206- # needs copyto! instead because size(::LinearAlgebra.I) doesn't work
207- # Q[Block(row, col)] = LinearAlgebra.I
208- for (row, col) in zip (emptyrows, emptycols)
209- copyto! (@view! (Q[Block (row, col)]), LinearAlgebra. I)
210- end
132+ function MatrixAlgebraKit. qr_full! (
133+ A:: AbstractBlockSparseMatrix , (Q, R), alg:: BlockDiagonalAlgorithm
134+ )
135+ MatrixAlgebraKit. check_input (qr_full!, A, (Q, R), alg)
211136
212- # also handle extra rows/cols
213- bn = blocksize (A, 2 )
214- for (i, k) in enumerate ((length (emptycols) + 1 ): length (emptyrows))
215- copyto! (@view! (Q[Block (emptyrows[k], bn + i)]), LinearAlgebra. I)
137+ for I in 1 : min (blocksize (A)... )
138+ bI = Block (I, I)
139+ if isstored (blocks (A), CartesianIndex (I, I)) # TODO : isblockstored
140+ block = @view! (A[bI])
141+ block_alg = block_algorithm (alg, block)
142+ bQ, bR = qr_full! (block, block_alg)
143+ Q[bI] = bQ
144+ R[bI] = bR
145+ else
146+ copyto! (@view! (Q[bI]), LinearAlgebra. I)
147+ end
216148 end
217149
218- return QR
150+ return Q, R
219151end
0 commit comments