|
1 | | -#= # This script has been copy-pasted from https://github.com/mhauru/TensorFactorizations.jl |
| 1 | +# This script has been copy-pasted from https://github.com/mhauru/TensorFactorizations.jl |
2 | 2 |
|
3 | | -""" |
4 | | - tensorsplit(A, a, b; kwargs...) |
5 | | -
|
6 | | -Calls tensorsvd with the arguments given to it to decompose the given tensor |
7 | | -A with indices a on one side and indices b on the other. It then splits |
8 | | -the diagonal matrix of singular values into two with a square root and |
9 | | -multiplies these weights into the isometric tensors. Thus tensorsplit ends |
10 | | -up splitting A into two parts, which are then returned, possibly together |
11 | | -with auxiliary data such as a truncation error. If the keyword argument |
12 | | -hermitian=true, an eigenvalue decomposition is used in stead of an SVD. All |
13 | | -the keyword arguments are passed to either tensorsvd or tensoreig. |
14 | | -
|
15 | | -See tensorsvd and tensoreig for further documentation. |
16 | | -""" |
17 | | -function tensorsplit(args...; kwargs...) |
18 | | - # Find the keyword argument hermitian. |
19 | | - # TODO This is awful, why do I have to do this? |
20 | | - hermitian = false |
21 | | - for (key, value) in kwargs |
22 | | - key == :hermitian && (hermitian = value) |
23 | | - end |
24 | 3 |
|
25 | | - if hermitian |
26 | | - res = tensoreig(args...; kwargs...) |
27 | | - S, U = res[1:2] |
28 | | - Vt_perm = [ndims(U), (1:(ndims(U) - 1))...] |
29 | | - Vt = conj!(tensorcopy(U, collect(1:ndims(U)), Vt_perm)) |
30 | | - S = Diagonal(S) |
31 | | - if !isposdef(S) |
32 | | - S = complex.(S) |
33 | | - end |
34 | | - auxdata = res[3:end] |
35 | | - else |
36 | | - res = tensorsvd(args...; kwargs...) |
37 | | - U, S, Vt = res[1:3] |
38 | | - S = Diagonal(S) |
39 | | - auxdata = res[4:end] |
40 | | - end |
41 | | - S_sqrt = sqrt.(S) |
42 | | - A1 = tensorcontract(U, (1:(ndims(U) - 1)..., :a), S_sqrt, (:a, :b)) |
43 | | - A2 = tensorcontract(S_sqrt, (:b, :a), Vt, (:a, 1:(ndims(Vt) - 1)...)) |
44 | | - return A1, A2, auxdata... |
45 | | -end =# |
46 | 4 |
|
47 | 5 | """ |
48 | 6 | tensoreig(A, a, b; chis=nothing, eps=0, |
@@ -120,90 +78,7 @@ function tensoreig( |
120 | 78 | return retval |
121 | 79 | end |
122 | 80 |
|
123 | | -#= """ |
124 | | - tensorsvd(A, a, b; |
125 | | - chis=nothing, eps=0, |
126 | | - return_error=false, print_error=false, |
127 | | - break_degenerate=false, degeneracy_eps=1e-6, |
128 | | - norm_type=:frobenius) |
129 | | -
|
130 | | -Singular valued decomposes a tensor A. The indices of A are |
131 | | -permuted so that the indices listed in the Array/Tuple a are on the "left" |
132 | | -side and indices listed in b are on the "right". The resulting tensor is |
133 | | -then reshaped to a matrix, and this matrix is SVDed into U*diagm(S)*Vt. |
134 | | -Finally, the unitary matrices U and Vt are reshaped to tensors so that |
135 | | -they have a new index coming from the SVD, for U as the last index and for |
136 | | -Vt as the first, and U has indices a as its first indices and V has |
137 | | -indices b as its last indices. |
138 | | -
|
139 | | -If eps>0 then the SVD may be truncated if the relative error can be kept |
140 | | -below eps. For this purpose different dimensions to truncate to can be tried, |
141 | | -and these dimensions should be listed in chis. If chis is nothing (the |
142 | | -default) then the full range of possible dimensions is tried. If |
143 | | -break_degenerate=false (the default) then the truncation never cuts between |
144 | | -degenerate singular values. degeneracy_eps controls how close the values need |
145 | | -to be to be considered degenerate. |
146 | | -
|
147 | | -norm_type specifies the norm used to measure the error. This defaults to |
148 | | -:frobenius, which means that the error measured is the Frobenius norm of the |
149 | | -difference between A and the decomposition, divided by the Frobenius norm of |
150 | | -A. This is the same thing as the 2-norm of the singular values that are |
151 | | -truncated out, divided by the 2-norm of all the singular values. The other |
152 | | -option is :trace, in which case a 1-norm is used instead. |
153 | | -
|
154 | | -If print_error=true the truncation error is printed. The default is false. |
155 | | -
|
156 | | -If return_error=true then the truncation error is also returned. |
157 | | -The default is false. |
158 | 81 |
|
159 | | -Note that no iterative techniques are used, which means choosing to truncate |
160 | | -provides no performance benefits: The full SVD is computed in any case. |
161 | | -
|
162 | | -Output is U, S, Vt, and possibly error. Here S is a vector of |
163 | | -singular values and U and Vt are isometric tensors (unitary if the matrix |
164 | | -that is SVDed is square and there is no truncation) such that U*diag(S)*Vt = |
165 | | -A, up to truncation errors. |
166 | | -""" |
167 | | -function tensorsvd( |
168 | | - A, |
169 | | - a, |
170 | | - b; |
171 | | - chis=nothing, |
172 | | - eps=0, |
173 | | - return_error=false, |
174 | | - print_error=false, |
175 | | - break_degenerate=false, |
176 | | - degeneracy_eps=1e-6, |
177 | | - norm_type=:frobenius, |
178 | | -) |
179 | | - # Create the matrix and SVD it. |
180 | | - A, shp_a, shp_b = to_matrix(A, a, b; return_tensor_shape=true) |
181 | | - fact = svd(A) |
182 | | - U, S, Vt = fact.U, fact.S, fact.Vt |
183 | | -
|
184 | | - # Find the dimensions to truncate to and the error caused in doing so. |
185 | | - chi, error = find_trunc_dim(S, chis, eps, break_degenerate, degeneracy_eps, norm_type) |
186 | | - # Truncate |
187 | | - S = S[1:chi] |
188 | | - U = U[:, 1:chi] |
189 | | - Vt = Vt[1:chi, :] |
190 | | -
|
191 | | - if print_error |
192 | | - println("Relative truncation error ($norm_type norm) in SVD: $error") |
193 | | - end |
194 | | -
|
195 | | - # Reshape U and V to tensors with shapes matching the shape of A and |
196 | | - # return. |
197 | | - dim = size(S)[1] |
198 | | - U_tens = reshape(U, shp_a..., dim) |
199 | | - Vt_tens = reshape(Vt, dim, shp_b...) |
200 | | - retval = (U_tens, S, Vt_tens) |
201 | | - if return_error |
202 | | - retval = (retval..., error) |
203 | | - end |
204 | | - return retval |
205 | | -end |
206 | | - =# |
207 | 82 | """ |
208 | 83 | Format the bond dimensions listed in chis to a standard format. |
209 | 84 | """ |
|
0 commit comments