Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "KroneckerArrays"
uuid = "05d0b138-81bc-4ff7-84be-08becefb1ccc"
version = "0.3.1"
authors = ["ITensor developers <support@itensor.org> and contributors"]
version = "0.3.0"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand Down Expand Up @@ -35,7 +35,7 @@ FillArrays = "1.13"
GPUArraysCore = "0.2"
LinearAlgebra = "1.10"
MapBroadcast = "0.1.10"
MatrixAlgebraKit = "0.2, 0.3, 0.4, 0.5"
MatrixAlgebraKit = "0.6"
TensorAlgebra = "0.3.10, 0.4"
TensorProducts = "0.1.7"
TypeParameterAccessors = "0.4.2"
Expand Down
2 changes: 1 addition & 1 deletion src/KroneckerArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ kroneckerfactortypes(T::Type) = throw(MethodError(kroneckerfactortypes, (T,)))

Construct an object that represents the Kronecker product of the provided `args`.
""" (⊗)
function ⊗(a, b) end
function ⊗ end
const otimes = ⊗ # non-unicode alternative

# Includes
Expand Down
23 changes: 23 additions & 0 deletions src/kroneckerarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -415,6 +415,29 @@ function Base.reshape(
return reshape(a, kroneckerfactors.(ax, 1)) ⊗ reshape(b, kroneckerfactors.(ax, 2))
end

function Base.fill!(ab::AbstractKroneckerArray, v)
a, b = kroneckerfactors(ab)
fill!(a, √v)
fill!(b, √v)
return ab
end
function Base.fill!(ab::AbstractKroneckerMatrix, v)
a, b = kroneckerfactors(ab)
(!isactive(a) && isone(a)) && (fill!(b, v); return ab)
(!isactive(b) && isone(b)) && (fill!(a, v); return ab)
fill!(a, √v)
fill!(b, √v)
return ab
end
function Base.fill!(ab::AbstractKroneckerVector, v)
a, b = kroneckerfactors(ab)
(!isactive(a) && all(isone, a)) && (fill!(b, v); return ab)
(!isactive(b) && all(isone, b)) && (fill!(a, v); return ab)
fill!(a, √v)
fill!(b, √v)
return ab
end

using Base.Broadcast: Broadcast, AbstractArrayStyle, BroadcastStyle, Broadcasted

struct KroneckerStyle{N, A, B} <: BC.AbstractArrayStyle{N} end
Expand Down
111 changes: 93 additions & 18 deletions src/matrixalgebrakit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -80,34 +80,109 @@ end

for f in (:eig_vals, :eigh_vals, :svd_vals)
f! = Symbol(f, :!)
@eval MAK.initialize_output(::typeof($f!), a::AbstractMatrix, alg::KroneckerAlgorithm) = nothing
@eval function MAK.initialize_output(
::typeof($f!), a::AbstractMatrix, alg::KroneckerAlgorithm
)
return nothing
end
@eval function MAK.$f!(ab::AbstractKroneckerMatrix, F, alg::KroneckerAlgorithm)
a, b = kroneckerfactors(ab)
algA, algB = kroneckerfactors(alg)
return MAK.$f(a, algA) ⊗ MAK.$f(b, algB)
end
end

for f in (:left_orth, :right_orth)
f! = Symbol(f, :!)
@eval MAK.initialize_output(::typeof($f!), a::AbstractKroneckerMatrix) = nothing
@eval function MAK.$f!(ab::AbstractKroneckerMatrix, F; kwargs1 = (;), kwargs2 = (;), kwargs...)
a, b = kroneckerfactors(ab)
Fa = MAK.$f(a; kwargs..., kwargs1...)
Fb = MAK.$f(b; kwargs..., kwargs2...)
return Fa .⊗ Fb
# TODO: Delete this loop once https://github.com/QuantumKitHub/MatrixAlgebraKit.jl/pull/104
# is merged.
for kind in ("polar", "qr", "svd")
@eval begin
function MAK.initialize_output(
::typeof(left_orth!), a::AbstractKroneckerMatrix,
alg::MAK.LeftOrthAlgorithm{Symbol($kind)},
)
return nothing
end
function MAK.left_orth!(
ab::AbstractKroneckerMatrix, F, alg::MAK.LeftOrthAlgorithm{Symbol($kind)};
kwargs1 = (;), kwargs2 = (;), kwargs...,
)
a, b = kroneckerfactors(ab)
Fa = MAK.left_orth!(a; kwargs..., kwargs1...)
Fb = MAK.left_orth!(b; kwargs..., kwargs2...)
return Fa .⊗ Fb
end
end
end

for f in [:left_null, :right_null]
f! = Symbol(f, :!)
@eval MAK.initialize_output(::typeof($f!), a::AbstractKroneckerMatrix) =
nothing
@eval function MAK.$f!(ab::AbstractKroneckerMatrix, F; kwargs1 = (;), kwargs2 = (;), kwargs...)
a, b = kroneckerfactors(ab)
Na = MAK.$f(a; kwargs..., kwargs1...)
Nb = MAK.$f(b; kwargs..., kwargs2...)
return Na ⊗ Nb
# TODO: Delete this loop once https://github.com/QuantumKitHub/MatrixAlgebraKit.jl/pull/104
# is merged.
for kind in ("lq", "polar", "svd")
@eval begin
function MAK.initialize_output(
::typeof(right_orth!), a::AbstractKroneckerMatrix,
alg::MAK.RightOrthAlgorithm{Symbol($kind)},
)
return nothing
end
function MAK.right_orth!(
ab::AbstractKroneckerMatrix, F, alg::MAK.RightOrthAlgorithm{Symbol($kind)};
kwargs1 = (;), kwargs2 = (;), kwargs...,
)
a, b = kroneckerfactors(ab)
Fa = MAK.right_orth!(a; kwargs..., kwargs1...)
Fb = MAK.right_orth!(b; kwargs..., kwargs2...)
return Fa .⊗ Fb
end
end
end

# TODO: Delete this loop once https://github.com/QuantumKitHub/MatrixAlgebraKit.jl/pull/104
# is merged.
for Alg in (
:(MAK.LeftNullViaQR),
:(MAK.LeftNullViaSVD{<:MAK.TruncatedAlgorithm}),
:(MAK.LeftNullViaSVD{<:MAK.TruncatedAlgorithm{<:MAK.GPU_Randomized}}),
)
@eval begin
function MAK.initialize_output(
::typeof(left_null!), a::AbstractKroneckerMatrix, alg::$Alg
)
return nothing
end
function MAK.left_null!(
ab::AbstractKroneckerMatrix, F, alg::$Alg;
kwargs1 = (;), kwargs2 = (;), kwargs...,
)
a, b = kroneckerfactors(ab)
Na = MAK.left_null!(a; kwargs..., kwargs1...)
Nb = MAK.left_null!(b; kwargs..., kwargs2...)
return Na ⊗ Nb
end
end
end

# TODO: Delete this loop once https://github.com/QuantumKitHub/MatrixAlgebraKit.jl/pull/104
# is merged.
for Alg in (
:(MAK.RightNullViaLQ),
:(MAK.RightNullViaSVD{<:MAK.TruncatedAlgorithm}),
:(MAK.RightNullViaSVD{<:MAK.TruncatedAlgorithm{<:MAK.GPU_Randomized}}),
)
@eval begin
function MAK.initialize_output(
::typeof(right_null!), a::AbstractKroneckerMatrix, alg::$Alg
)
return nothing
end
function MAK.right_null!(
ab::AbstractKroneckerMatrix, F, alg::$Alg;
kwargs1 = (;), kwargs2 = (;), kwargs...,
)
a, b = kroneckerfactors(ab)
Na = MAK.right_null!(a; kwargs..., kwargs1...)
Nb = MAK.right_null!(b; kwargs..., kwargs2...)
return Na ⊗ Nb
end
end
end

Expand Down
2 changes: 1 addition & 1 deletion test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ GPUArraysCore = "0.2"
JLArrays = "0.2, 0.3"
KroneckerArrays = "0.3"
LinearAlgebra = "1.10"
MatrixAlgebraKit = "0.2, 0.3, 0.4, 0.5"
MatrixAlgebraKit = "0.6"
SafeTestsets = "0.1"
StableRNGs = "1.0"
Suppressor = "0.2"
Expand Down
Loading