@@ -11,21 +11,11 @@ function MatrixAlgebraKit.default_svd_algorithm(
1111 end
1212end
1313
14- function output_type (:: typeof (svd_compact!), A:: Type{<:AbstractMatrix{T}} ) where {T}
15- USVᴴ = Base. promote_op (svd_compact!, A)
16- ! isconcretetype (USVᴴ) &&
17- return Tuple{AbstractMatrix{T},AbstractMatrix{realtype (T)},AbstractMatrix{T}}
18- return USVᴴ
19- end
20-
21- function similar_output (
22- :: typeof (svd_compact!), A, S_axes, alg:: MatrixAlgebraKit.AbstractAlgorithm
23- )
24- BU, BS, BVᴴ = fieldtypes (output_type (svd_compact!, blocktype (A)))
25- U = similar (A, BlockType (BU), (axes (A, 1 ), S_axes[1 ]))
26- S = similar (A, BlockType (BS), S_axes)
27- Vᴴ = similar (A, BlockType (BVᴴ), (S_axes[2 ], axes (A, 2 )))
28- return U, S, Vᴴ
14+ function output_type (
15+ f:: Union{typeof(svd_compact!),typeof(svd_full!)} , A:: Type{<:AbstractMatrix{T}}
16+ ) where {T}
17+ USVᴴ = Base. promote_op (f, A)
18+ return isconcretetype (USVᴴ) ? USVᴴ : Tuple{AbstractMatrix{T},AbstractMatrix{realtype (T)},AbstractMatrix{T}}
2919end
3020
3121function MatrixAlgebraKit. initialize_output (
@@ -42,28 +32,13 @@ function MatrixAlgebraKit.initialize_output(
4232 s_axes = map (splat (infimum), zip (brows, bcols))
4333 s_axis = mortar_axis (s_axes)
4434 S_axes = (s_axis, s_axis)
45- U, S, Vᴴ = similar_output (svd_compact!, A, S_axes, alg)
46-
47- for bI in eachblockstoredindex (A)
48- block = @view! (A[bI])
49- block_alg = block_algorithm (alg, block)
50- I = first (Tuple (bI)) # == last(Tuple(bI))
51- U[I, I], S[I, I], Vᴴ[I, I] = MatrixAlgebraKit. initialize_output (
52- svd_compact!, block, block_alg
53- )
54- end
5535
56- return U, S, Vᴴ
57- end
36+ BU, BS, BVᴴ = fieldtypes (output_type (svd_compact!, blocktype (A)))
37+ U = similar (A, BlockType (BU), (axes (A, 1 ), S_axes[1 ]))
38+ S = similar (A, BlockType (BS), S_axes)
39+ Vᴴ = similar (A, BlockType (BVᴴ), (S_axes[2 ], axes (A, 2 )))
5840
59- function similar_output (
60- :: typeof (svd_full!), A, S_axes, alg:: MatrixAlgebraKit.AbstractAlgorithm
61- )
62- U = similar (A, axes (A, 1 ), S_axes[1 ])
63- T = real (eltype (A))
64- S = similar (A, T, S_axes)
65- Vt = similar (A, S_axes[2 ], axes (A, 2 ))
66- return U, S, Vt
41+ return U, S, Vᴴ
6742end
6843
6944function MatrixAlgebraKit. initialize_output (
7550function MatrixAlgebraKit. initialize_output (
7651 :: typeof (svd_full!), A:: AbstractBlockSparseMatrix , alg:: BlockDiagonalAlgorithm
7752)
78- U, S, Vᴴ = similar_output (svd_full!, A, axes (A), alg)
79-
80- for bI in eachblockstoredindex (A)
81- block = @view! (A[bI])
82- block_alg = block_algorithm (alg, block)
83- I = first (Tuple (bI)) # == last(Tuple(bI))
84- U[I, I], S[I, I], Vᴴ[I, I] = MatrixAlgebraKit. initialize_output (
85- svd_full!, block, block_alg
86- )
87- end
53+ BU, BS, BVᴴ = fieldtypes (output_type (svd_full!, blocktype (A)))
54+ U = similar (A, BlockType (BU), (axes (A, 1 ), axes (A, 1 )))
55+ S = similar (A, BlockType (BS), axes (A))
56+ Vᴴ = similar (A, BlockType (BVᴴ), (axes (A, 2 ), axes (A, 2 )))
8857
8958 return U, S, Vᴴ
9059end
@@ -154,11 +123,12 @@ function MatrixAlgebraKit.svd_compact!(
154123 for I in 1 : min (blocksize (A)... )
155124 bI = Block (I, I)
156125 if isstored (blocks (A), CartesianIndex (I, I)) # TODO : isblockstored
157- usvᴴ = (@view! (U[bI]), @view! (S[bI]), @view! (Vᴴ[bI]))
158126 block = @view! (A[bI])
159127 block_alg = block_algorithm (alg, block)
160- usvᴴ′ = svd_compact! (block, usvᴴ, block_alg)
161- @assert usvᴴ === usvᴴ′ " svd_compact! might not be in-place"
128+ bU, bS, bVᴴ = svd_compact! (block, block_alg)
129+ U[bI] = bU
130+ S[bI] = bS
131+ Vᴴ[bI] = bVᴴ
162132 else
163133 copyto! (@view! (U[bI]), LinearAlgebra. I)
164134 copyto! (@view! (Vᴴ[bI]), LinearAlgebra. I)
@@ -189,11 +159,12 @@ function MatrixAlgebraKit.svd_full!(
189159 for I in 1 : min (blocksize (A)... )
190160 bI = Block (I, I)
191161 if isstored (blocks (A), CartesianIndex (I, I)) # TODO : isblockstored
192- usvᴴ = (@view! (U[bI]), @view! (S[bI]), @view! (Vᴴ[bI]))
193162 block = @view! (A[bI])
194163 block_alg = block_algorithm (alg, block)
195- usvᴴ′ = svd_full! (block, usvᴴ, block_alg)
196- @assert usvᴴ === usvᴴ′ " svd_compact! might not be in-place"
164+ bU, bS, bVᴴ = svd_full! (block, block_alg)
165+ U[bI] = bU
166+ S[bI] = bS
167+ Vᴴ[bI] = bVᴴ
197168 else
198169 copyto! (@view! (U[bI]), LinearAlgebra. I)
199170 copyto! (@view! (Vᴴ[bI]), LinearAlgebra. I)
0 commit comments