diff --git a/Project.toml b/Project.toml index c790519..6c99f47 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "KroneckerArrays" uuid = "05d0b138-81bc-4ff7-84be-08becefb1ccc" authors = ["ITensor developers and contributors"] -version = "0.1.14" +version = "0.1.15" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" @@ -10,6 +10,7 @@ DiagonalArrays = "74fd4be6-21e2-4f6f-823a-4360d37c7a77" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +MapBroadcast = "ebd9b9da-f48d-417c-9660-449667d60261" MatrixAlgebraKit = "6c742aac-3347-4629-af66-fc926824e5e4" [weakdeps] @@ -28,5 +29,6 @@ DiagonalArrays = "0.3.5" FillArrays = "1.13.0" GPUArraysCore = "0.2.0" LinearAlgebra = "1.10" +MapBroadcast = "0.1.9" MatrixAlgebraKit = "0.2.0" julia = "1.10" diff --git a/src/cartesianproduct.jl b/src/cartesianproduct.jl index c734253..c77a2f5 100644 --- a/src/cartesianproduct.jl +++ b/src/cartesianproduct.jl @@ -75,3 +75,12 @@ for f in (:+, :-) end end end + +using Base.Broadcast: axistype +function Base.Broadcast.axistype( + r1::CartesianProductUnitRange, r2::CartesianProductUnitRange +) + prod = axistype(arg1(r1), arg1(r2)) × axistype(arg2(r1), arg2(r2)) + range = axistype(unproduct(r1), unproduct(r2)) + return cartesianrange(prod, range) +end diff --git a/src/fillarrays/kroneckerarray.jl b/src/fillarrays/kroneckerarray.jl index 689f1e4..0a39106 100644 --- a/src/fillarrays/kroneckerarray.jl +++ b/src/fillarrays/kroneckerarray.jl @@ -1,3 +1,13 @@ +using FillArrays: FillArrays, Zeros +function FillArrays.fillsimilar( + a::Zeros{T}, + ax::Tuple{ + CartesianProductUnitRange{<:Integer},Vararg{CartesianProductUnitRange{<:Integer}} + }, +) where {T} + return Zeros{T}(arg1.(ax)) ⊗ Zeros{T}(arg2.(ax)) +end + using FillArrays: RectDiagonal, OnesVector const RectEye{T,V<:OnesVector{T},Axes} = RectDiagonal{T,V,Axes} @@ -208,3 +218,17 @@ end function Base.map!(f::Base.Fix2{typeof(*),<:Number}, dest::EyeEye, a::EyeEye) return error("Can't write in-place.") end + +using Base.Broadcast: + AbstractArrayStyle, AbstractArrayStyle, BroadcastStyle, Broadcasted, broadcasted + +struct EyeStyle <: AbstractArrayStyle{2} end +EyeStyle(::Val{2}) = EyeStyle() +function _BroadcastStyle(::Type{<:Eye}) + return EyeStyle() +end +Base.BroadcastStyle(style1::EyeStyle, style2::EyeStyle) = EyeStyle() + +function Base.similar(bc::Broadcasted{EyeStyle}, elt::Type) + return Eye{elt}(axes(bc)) +end diff --git a/src/kroneckerarray.jl b/src/kroneckerarray.jl index 5bbae8c..05f9b26 100644 --- a/src/kroneckerarray.jl +++ b/src/kroneckerarray.jl @@ -226,6 +226,8 @@ end for op in (:+, :-) @eval begin function Base.$op(a::KroneckerArray, b::KroneckerArray) + iszero(a) && return $op(b) + iszero(b) && return a if a.b == b.b return $op(a.a, b.a) ⊗ a.b elseif a.a == b.a @@ -241,8 +243,15 @@ for op in (:+, :-) end end -using Base.Broadcast: AbstractArrayStyle, BroadcastStyle, Broadcasted +# Allows for customizations for FillArrays. +_BroadcastStyle(x) = BroadcastStyle(x) + +using Base.Broadcast: Broadcast, AbstractArrayStyle, BroadcastStyle, Broadcasted struct KroneckerStyle{N,A,B} <: AbstractArrayStyle{N} end +arg1(::Type{<:KroneckerStyle{<:Any,A}}) where {A} = A +arg1(style::KroneckerStyle) = arg1(typeof(style)) +arg2(::Type{<:KroneckerStyle{<:Any,B}}) where {B} = B +arg2(style::KroneckerStyle) = arg2(typeof(style)) function KroneckerStyle{N}(a::BroadcastStyle, b::BroadcastStyle) where {N} return KroneckerStyle{N,a,b}() end @@ -253,30 +262,69 @@ function KroneckerStyle{N,A,B}(v::Val{M}) where {N,A,B,M} return KroneckerStyle{M,typeof(A)(v),typeof(B)(v)}() end function Base.BroadcastStyle(::Type{<:KroneckerArray{<:Any,N,A,B}}) where {N,A,B} - return KroneckerStyle{N}(BroadcastStyle(A), BroadcastStyle(B)) + return KroneckerStyle{N}(_BroadcastStyle(A), _BroadcastStyle(B)) end function Base.BroadcastStyle(style1::KroneckerStyle{N}, style2::KroneckerStyle{N}) where {N} - return KroneckerStyle{N}( - BroadcastStyle(style1.a, style2.a), BroadcastStyle(style1.b, style2.b) - ) + style_a = BroadcastStyle(arg1(style1), arg1(style2)) + (style_a isa Broadcast.Unknown) && return Broadcast.Unknown() + style_b = BroadcastStyle(arg2(style1), arg2(style2)) + (style_b isa Broadcast.Unknown) && return Broadcast.Unknown() + return KroneckerStyle{N}(style_a, style_b) end function Base.similar(bc::Broadcasted{<:KroneckerStyle{N,A,B}}, elt::Type) where {N,A,B} - ax_a = map(ax -> ax.product.a, axes(bc)) - ax_b = map(ax -> ax.product.b, axes(bc)) + ax_a = arg1.(axes(bc)) + ax_b = arg2.(axes(bc)) bc_a = Broadcasted(A, nothing, (), ax_a) bc_b = Broadcasted(B, nothing, (), ax_b) a = similar(bc_a, elt) b = similar(bc_b, elt) return a ⊗ b end +# Fallback definition of broadcasting falls back to `map` but assumes +# inputs have been canonicalized to a map-compatible expression already, +# for example by absorbing scalar arguments into the function. function Base.copyto!(dest::AbstractArray, bc::Broadcasted{<:KroneckerStyle}) - return throw( - ArgumentError( - "Arbitrary broadcasting is not supported for KroneckerArrays since they might not preserve the Kronecker structure.", - ), - ) + allequal(axes, bc.args) || throw(ArgumentError("Broadcasted axes must be equal.")) + map!(bc.f, dest, bc.args...) + return dest end +# Broadcast rewrite rules. Canonicalize inputs to absorb scalar inputs into the +# function. +function Base.broadcasted(style::KroneckerStyle, ::typeof(*), a::Number, b::KroneckerArray) + return broadcasted(style, Base.Fix1(*, a), b) +end +function Base.broadcasted(style::KroneckerStyle, ::typeof(*), a::KroneckerArray, b::Number) + return broadcasted(style, Base.Fix2(*, b), a) +end +function Base.broadcasted(style::KroneckerStyle, ::typeof(/), a::KroneckerArray, b::Number) + return broadcasted(style, Base.Fix2(/, b), a) +end +using MapBroadcast: MapBroadcast, MapFunction +function Base.broadcasted( + style::KroneckerStyle, + f::MapFunction{typeof(*),<:Tuple{<:Number,MapBroadcast.Arg}}, + a::KroneckerArray, +) + return broadcasted(style, Base.Fix1(*, f.args[1]), a) +end +function Base.broadcasted( + style::KroneckerStyle, + f::MapFunction{typeof(*),<:Tuple{MapBroadcast.Arg,<:Number}}, + a::KroneckerArray, +) + return broadcasted(style, Base.Fix2(*, f.args[2]), a) +end +function Base.broadcasted( + style::KroneckerStyle, + f::MapFunction{typeof(/),<:Tuple{MapBroadcast.Arg,<:Number}}, + a::KroneckerArray, +) + return broadcasted(style, Base.Fix2(/, f.args[2]), a) +end + +# TODO: Define by converting to a broadcast expession (with MapBroadcast.jl) +# and then constructing the output with `similar`. function Base.map(f, a1::KroneckerArray, a_rest::KroneckerArray...) return throw( ArgumentError( @@ -312,6 +360,8 @@ for f in [:+, :-] function Base.map!( ::typeof($f), dest::KroneckerArray, a::KroneckerArray, b::KroneckerArray ) + iszero(b) && return map!(identity, dest, a) + iszero(a) && return map!($f, dest, b) if a.b == b.b map!($f, dest.a, a.a, b.a) map!(identity, dest.b, a.b) @@ -350,6 +400,15 @@ for op in [:*, :/] end end end +for f in [:+, :-] + @eval begin + function Base.map!(::typeof($f), dest::KroneckerArray, src::KroneckerArray) + map!($f, dest.a, src.a) + map!(identity, dest.b, src.b) + return dest + end + end +end using DiagonalArrays: DiagonalArrays, diagonal function DiagonalArrays.diagonal(a::KroneckerArray) diff --git a/src/linearalgebra.jl b/src/linearalgebra.jl index 2218363..95bed45 100644 --- a/src/linearalgebra.jl +++ b/src/linearalgebra.jl @@ -15,6 +15,15 @@ using LinearAlgebra: svdvals, tr +using LinearAlgebra: LinearAlgebra +function KroneckerArray(J::LinearAlgebra.UniformScaling, ax::Tuple) + return Eye{eltype(J)}(arg1.(ax)) ⊗ Eye{eltype(J)}(arg2.(ax)) +end +function Base.copyto!(a::KroneckerArray, J::LinearAlgebra.UniformScaling) + copyto!(a, KroneckerArray(J, axes(a))) + return a +end + using LinearAlgebra: LinearAlgebra, pinv function LinearAlgebra.pinv(a::KroneckerArray; kwargs...) return pinv(a.a; kwargs...) ⊗ pinv(a.b; kwargs...) diff --git a/test/test_basics.jl b/test/test_basics.jl index 45f6403..47460d3 100644 --- a/test/test_basics.jl +++ b/test/test_basics.jl @@ -87,14 +87,15 @@ elts = (Float32, Float64, ComplexF32, ComplexF64) a′ = similar(a) @test_throws "not supported" a′ .= sin.(a) a′ = similar(a) - @test_broken a′ .= 2 .* a + a′ .= 2 .* a + @test collect(a′) ≈ 2 * collect(a) bc = broadcasted(+, a, a) @test bc.style === style @test similar(bc, elt) isa KroneckerArray{elt,2,typeof(a.a),typeof(a.b)} - @test_broken copy(bc) + @test collect(copy(bc)) ≈ 2 * collect(a) bc = broadcasted(*, 2, a) @test bc.style === style - @test_broken copy(bc) + @test collect(copy(bc)) ≈ 2 * collect(a) # Mapping a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) diff --git a/test/test_blocksparsearrays.jl b/test/test_blocksparsearrays.jl index 0f2ea3a..178e6bd 100644 --- a/test/test_blocksparsearrays.jl +++ b/test/test_blocksparsearrays.jl @@ -64,12 +64,11 @@ arrayts = (Array, JLArray) @test_broken inv(a) end - if (VERSION ≤ v"1.11-" && arrayt === Array && elt <: Complex) || - (arrayt === Array && elt <: Real) + if arrayt === Array u, s, v = svd_compact(a) @test Array(u * s * v) ≈ Array(a) else - # Broken on GPU and for complex, investigate. + # Broken on GPU. @test_broken svd_compact(a) end @@ -135,14 +134,17 @@ end @test_broken exp(a) end - if VERSION < v"1.11-" && elt <: Complex - # Broken because of type stability issue in Julia v1.10. - @test_broken svd_compact(a) - elseif arrayt === Array + ## if VERSION < v"1.11-" && elt <: Complex + ## # Broken because of type stability issue in Julia v1.10. + ## @test_broken svd_compact(a) + if arrayt === Array u, s, v = svd_compact(a) @test u * s * v ≈ a - @test blocktype(u) === blocktype(a) - @test blocktype(v) === blocktype(a) + @test blocktype(u) >: blocktype(u) + @test eltype(u) === eltype(a) + @test blocktype(v) >: blocktype(a) + @test eltype(v) === eltype(a) + @test eltype(s) === real(eltype(a)) else @test_broken svd_compact(a) end @@ -150,4 +152,16 @@ end # Broken operations @test_broken inv(a) @test_broken a[Block.(1:2), Block(2)] + + @testset "Block deficient" begin + d = Dict(Block(1, 1) => Eye{elt}(2, 2) ⊗ dev(randn(elt, 2, 2))) + a = @constinferred dev(blocksparse(d, r, r)) + + d = Dict(Block(2, 2) => Eye{elt}(3, 3) ⊗ dev(randn(elt, 3, 3))) + b = @constinferred dev(blocksparse(d, r, r)) + + @test_broken a + b + # @test Array(a + b) ≈ Array(a) + Array(b) + # @test Array(2a) ≈ 2Array(a) + end end