|
| 1 | +using MatrixAlgebraKit: |
| 2 | + MatrixAlgebraKit, |
| 3 | + PolarViaSVD, |
| 4 | + check_input, |
| 5 | + default_algorithm, |
| 6 | + left_polar!, |
| 7 | + right_polar!, |
| 8 | + svd_compact! |
| 9 | + |
| 10 | +function MatrixAlgebraKit.check_input(::typeof(left_polar!), A::AbstractBlockSparseMatrix) |
| 11 | + @views for I in eachblockstoredindex(A) |
| 12 | + m, n = size(A[I]) |
| 13 | + m >= n || |
| 14 | + throw(ArgumentError("each input matrix block needs at least as many rows as columns")) |
| 15 | + end |
| 16 | + return nothing |
| 17 | +end |
| 18 | +function MatrixAlgebraKit.check_input(::typeof(right_polar!), A::AbstractBlockSparseMatrix) |
| 19 | + @views for I in eachblockstoredindex(A) |
| 20 | + m, n = size(A[I]) |
| 21 | + m <= n || |
| 22 | + throw(ArgumentError("each input matrix block needs at least as many columns as rows")) |
| 23 | + end |
| 24 | + return nothing |
| 25 | +end |
| 26 | + |
| 27 | +function MatrixAlgebraKit.left_polar!(A::AbstractBlockSparseMatrix, alg::PolarViaSVD) |
| 28 | + check_input(left_polar!, A) |
| 29 | + # TODO: Use more in-place operations here, avoid `copy`. |
| 30 | + U, S, Vᴴ = svd_compact!(A, alg.svdalg) |
| 31 | + W = U * Vᴴ |
| 32 | + # TODO: `copy` is required for now because of: |
| 33 | + # https://github.com/ITensor/BlockSparseArrays.jl/issues/24 |
| 34 | + # Remove when that is fixed. |
| 35 | + P = copy(Vᴴ') * S * Vᴴ |
| 36 | + return (W, P) |
| 37 | +end |
| 38 | +function MatrixAlgebraKit.right_polar!(A::AbstractBlockSparseMatrix, alg::PolarViaSVD) |
| 39 | + check_input(right_polar!, A) |
| 40 | + # TODO: Use more in-place operations here, avoid `copy`. |
| 41 | + U, S, Vᴴ = svd_compact!(A, alg.svdalg) |
| 42 | + Wᴴ = U * Vᴴ |
| 43 | + # TODO: `copy` is required for now because of: |
| 44 | + # https://github.com/ITensor/BlockSparseArrays.jl/issues/24 |
| 45 | + # Remove when that is fixed. |
| 46 | + P = U * S * copy(U') |
| 47 | + return (P, Wᴴ) |
| 48 | +end |
| 49 | + |
| 50 | +function MatrixAlgebraKit.default_algorithm( |
| 51 | + ::typeof(left_polar!), a::AbstractBlockSparseMatrix; kwargs... |
| 52 | +) |
| 53 | + return PolarViaSVD(default_algorithm(svd_compact!, a; kwargs...)) |
| 54 | +end |
| 55 | +function MatrixAlgebraKit.default_algorithm( |
| 56 | + ::typeof(right_polar!), a::AbstractBlockSparseMatrix; kwargs... |
| 57 | +) |
| 58 | + return PolarViaSVD(default_algorithm(svd_compact!, a; kwargs...)) |
| 59 | +end |
0 commit comments