@@ -45,69 +45,11 @@ function MatrixAlgebraKit.findtruncated(
4545 return indexmask
4646end
4747
48- function similar_truncate (
49- :: typeof (svd_trunc!),
50- (U, S, Vᴴ):: TBlockUSV ᴴ,
51- strategy:: BlockPermutedDiagonalTruncationStrategy ,
52- indexmask= MatrixAlgebraKit. findtruncated (diagview (S), strategy),
53- )
54- ax = axes (S, 1 )
55- counter = Base. Fix1 (count, Base. Fix1 (getindex, indexmask))
56- s_lengths = filter! (> (0 ), map (counter, blocks (ax)))
57- s_axis = blockedrange (s_lengths)
58- Ũ = similar (U, axes (U, 1 ), s_axis)
59- S̃ = similar (S, s_axis, s_axis)
60- Ṽᴴ = similar (Vᴴ, s_axis, axes (Vᴴ, 2 ))
61- return Ũ, S̃, Ṽᴴ
62- end
63-
6448function MatrixAlgebraKit. truncate! (
6549 :: typeof (svd_trunc!),
6650 (U, S, Vᴴ):: TBlockUSV ᴴ,
6751 strategy:: BlockPermutedDiagonalTruncationStrategy ,
6852)
69- indexmask = MatrixAlgebraKit. findtruncated (diagview (S), strategy)
70-
71- # first determine the block structure of the output to avoid having assumptions on the
72- # data structures
73- Ũ, S̃, Ṽᴴ = similar_truncate (svd_trunc!, (U, S, Vᴴ), strategy, indexmask)
74-
75- # then loop over the blocks and assign the data
76- # TODO : figure out if we can presort and loop over the blocks -
77- # for now this has issues with missing blocks
78- bI_Us = collect (eachblockstoredindex (U))
79- bI_Ss = collect (eachblockstoredindex (S))
80- bI_Vᴴs = collect (eachblockstoredindex (Vᴴ))
81-
82- I′ = 0 # number of skipped blocks that got fully truncated
83- ax = axes (S, 1 )
84- for I in 1 : blocksize (ax, 1 )
85- b = ax[Block (I)]
86- mask = indexmask[b]
87-
88- if ! any (mask)
89- I′ += 1
90- continue
91- end
92-
93- bU_id = @something findfirst (x -> last (Tuple (x)) == Block (I), bI_Us) error (
94- " No U-block found for $I "
95- )
96- bU = Tuple (bI_Us[bU_id])
97- Ũ[bU[1 ], bU[2 ] - Block (I′)] = view (U, bU... )[:, mask]
98-
99- bVᴴ_id = @something findfirst (x -> first (Tuple (x)) == Block (I), bI_Vᴴs) error (
100- " No Vᴴ-block found for $I "
101- )
102- bVᴴ = Tuple (bI_Vᴴs[bVᴴ_id])
103- Ṽᴴ[bVᴴ[1 ] - Block (I′), bVᴴ[2 ]] = view (Vᴴ, bVᴴ... )[mask, :]
104-
105- bS_id = findfirst (x -> last (Tuple (x)) == Block (I), bI_Ss)
106- if ! isnothing (bS_id)
107- bS = Tuple (bI_Ss[bS_id])
108- S̃[(bS .- Block (I′)). .. ] = Diagonal (diagview (view (S, bS... ))[mask])
109- end
110- end
111-
112- return Ũ, S̃, Ṽᴴ
53+ I = MatrixAlgebraKit. findtruncated (diagview (S), strategy)
54+ return (U[:, I], S[I, I], Vᴴ[I, :])
11355end
0 commit comments