From fbb91929b0076510100134ee0db86d20cd8405d6 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Mon, 9 Jun 2025 16:58:44 -0400 Subject: [PATCH 01/11] Add ndims type parameter to ArrayInterface --- Project.toml | 2 +- src/abstractarrayinterface.jl | 11 ++++++---- src/concatenate.jl | 35 ++++++++++++++++++++---------- src/defaultarrayinterface.jl | 24 +++++++++++++++----- test/SparseArrayDOKs.jl | 12 +++++++--- test/test_defaultarrayinterface.jl | 3 ++- 6 files changed, 62 insertions(+), 25 deletions(-) diff --git a/Project.toml b/Project.toml index aa5072a..bb5dc63 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "DerivableInterfaces" uuid = "6c5e35bf-e59e-4898-b73c-732dcc4ba65f" authors = ["ITensor developers and contributors"] -version = "0.4.5" +version = "0.4.6" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/src/abstractarrayinterface.jl b/src/abstractarrayinterface.jl index 865787d..3f6f2eb 100644 --- a/src/abstractarrayinterface.jl +++ b/src/abstractarrayinterface.jl @@ -1,19 +1,22 @@ # TODO: Add `ndims` type parameter. -abstract type AbstractArrayInterface <: AbstractInterface end +abstract type AbstractArrayInterface{N} <: AbstractInterface end +function interface(::Type{<:Broadcast.AbstractArrayStyle{N}}) where {N} + return DefaultArrayInterface{N}() +end function interface(::Type{<:Broadcast.AbstractArrayStyle}) return DefaultArrayInterface() end -function interface(::Type{<:Broadcast.Broadcasted{Nothing}}) - return DefaultArrayInterface() +function interface(BC::Type{<:Broadcast.Broadcasted{Nothing}}) + return DefaultArrayInterface{ndims(BC)}() end function interface(::Type{<:Broadcast.Broadcasted{<:Style}}) where {Style} return interface(Style) end -# TODO: Define as `Array{T}`. +# TODO: Define as `Array{T,N}`. arraytype(::AbstractArrayInterface, T::Type) = error("Not implemented.") using ArrayLayouts: ArrayLayouts diff --git a/src/concatenate.jl b/src/concatenate.jl index fcefee8..3020645 100644 --- a/src/concatenate.jl +++ b/src/concatenate.jl @@ -29,11 +29,16 @@ export concatenate using Base: promote_eltypeof using ..DerivableInterfaces: - DerivableInterfaces, AbstractInterface, interface, zero!, arraytype + DerivableInterfaces, AbstractArrayInterface, interface, zero!, arraytype unval(x) = x unval(::Val{x}) where {x} = x +set_interface_dims(::Type{Nothing}, ::Val{N}) where {N} = nothing +function set_interface_dims(Interface::Type{<:AbstractArrayInterface}, ::Val{N}) where {N} + return Interface(Val(N)) +end + function _Concatenated end """ @@ -42,25 +47,32 @@ function _Concatenated end Lazy representation of the concatenation of various `Args` along `Dims`, in order to provide hooks to customize the implementation. """ -struct Concatenated{Interface,Dims,Args<:Tuple} +struct Concatenated{Interface<:Union{AbstractArrayInterface,Nothing},Dims,Args<:Tuple,N} interface::Interface dims::Val{Dims} args::Args global @inline function _Concatenated( - interface::Interface, dims::Val{Dims}, args::Args - ) where {Interface,Dims,Args<:Tuple} - return new{Interface,Dims,Args}(interface, dims, args) + interface::Interface, dims::Val{Dims}, args::Args, ndims::Val{N} + ) where {Interface<:Union{AbstractArrayInterface,Nothing},Dims,Args<:Tuple,N} + return new{Interface,Dims,Args,N}(interface, dims, args) end end -function Concatenated(interface::Union{Nothing,AbstractInterface}, dims::Val, args::Tuple) - return _Concatenated(interface, dims, args) +function Concatenated(interface::Nothing, dims::Val, args::Tuple) + N = cat_ndims(dims, args...) + return _Concatenated(interface, dims, args, Val(N)) +end +function Concatenated(interface::AbstractArrayInterface, dims::Val, args::Tuple) + N = cat_ndims(dims, args...) + return _Concatenated(typeof(interface)(Val(N)), dims, args, Val(N)) end function Concatenated(dims::Val, args::Tuple) - return Concatenated(interface(args...), dims, args) + N = cat_ndims(dims, args...) + return _Concatenated(typeof(interface(args...))(Val(N)), dims, args, Val(N)) end function Concatenated{Interface}(dims::Val, args::Tuple) where {Interface} - return Concatenated(Interface(), dims, args) + N = cat_ndims(dims, args...) + return _Concatenated(set_interface_dims(Interface, Val(N)), dims, args, Val(N)) end dims(::Concatenated{<:Any,D}) where {D} = D @@ -82,7 +94,7 @@ end Base.similar(concat::Concatenated) = similar(concat, eltype(concat)) Base.similar(concat::Concatenated, ::Type{T}) where {T} = similar(concat, T, axes(concat)) function Base.similar(concat::Concatenated, ::Type{T}, ax) where {T} - return similar(arraytype(interface(concat), T), ax) + return similar(arraytype(typeof(interface(concat))(Val(ndims(concat))), T), ax) end function cat_axis( @@ -111,7 +123,8 @@ end Base.eltype(concat::Concatenated) = promote_eltypeof(concat.args...) Base.axes(concat::Concatenated) = cat_axes(dims(concat), concat.args...) Base.size(concat::Concatenated) = length.(axes(concat)) -Base.ndims(concat::Concatenated) = length(axes(concat)) +Base.ndims(concat::Concatenated) = ndims(typeof(concat)) +Base.ndims(::Type{<:Concatenated{<:Any,<:Any,<:Any,N}}) where {N} = N # Main logic # ---------- diff --git a/src/defaultarrayinterface.jl b/src/defaultarrayinterface.jl index 538b8e7..f56dcd3 100644 --- a/src/defaultarrayinterface.jl +++ b/src/defaultarrayinterface.jl @@ -1,12 +1,26 @@ -# TODO: Add `ndims` type parameter. -struct DefaultArrayInterface <: AbstractArrayInterface end +struct DefaultArrayInterface{N} <: AbstractArrayInterface{N} end + +DefaultArrayInterface() = DefaultArrayInterface{Any}() +DefaultArrayInterface(::Val{N}) where {N} = DefaultArrayInterface{N}() +DefaultArrayInterface{M}(::Val{N}) where {M,N} = DefaultArrayInterface{N}() using TypeParameterAccessors: parenttype function interface(a::Type{<:AbstractArray}) - parenttype(a) === a && return DefaultArrayInterface() + parenttype(a) === a && return DefaultArrayInterface{ndims(a)}() return interface(parenttype(a)) end +function combine_interface_rule( + interface1::DefaultArrayInterface{N}, interface2::DefaultArrayInterface{N} +) where {N} + return DefaultArrayInterface{N}() +end +function combine_interface_rule( + interface1::DefaultArrayInterface, interface2::DefaultArrayInterface +) + return DefaultArrayInterface{Any}() +end + @interface ::DefaultArrayInterface function Base.getindex( a::AbstractArray{<:Any,N}, I::Vararg{Int,N} ) where {N} @@ -31,6 +45,6 @@ end return Base.mapreduce(f, op, as...; kwargs...) end -function arraytype(::DefaultArrayInterface, T::Type) - return Array{T} +function arraytype(::DefaultArrayInterface{N}, T::Type) where {N} + return Array{T,N} end diff --git a/test/SparseArrayDOKs.jl b/test/SparseArrayDOKs.jl index 0f5c68c..9b34929 100644 --- a/test/SparseArrayDOKs.jl +++ b/test/SparseArrayDOKs.jl @@ -40,7 +40,9 @@ using DerivableInterfaces: using LinearAlgebra: LinearAlgebra # Define an interface. -struct SparseArrayInterface <: AbstractArrayInterface end +struct SparseArrayInterface{N} <: AbstractArrayInterface{N} end +SparseArrayInterface(::Val{N}) where {N} = SparseArrayInterface{N}() +SparseArrayInterface{M}(::Val{N}) where {M,N} = SparseArrayInterface{N}() # Define interface functions. @interface ::SparseArrayInterface function Base.getindex( @@ -66,7 +68,9 @@ end struct SparseArrayStyle{N} <: Broadcast.AbstractArrayStyle{N} end SparseArrayStyle{M}(::Val{N}) where {M,N} = SparseArrayStyle{N}() -DerivableInterfaces.interface(::Type{<:SparseArrayStyle}) = SparseArrayInterface() +function DerivableInterfaces.interface(::Type{<:SparseArrayStyle{N}}) where {N} + return SparseArrayInterface{N}() +end @derive SparseArrayStyle AbstractArrayStyleOps @@ -260,7 +264,9 @@ function DerivableInterfaces.zero!(a::SparseArrayDOK) end # Specify the interface the type adheres to. -DerivableInterfaces.interface(::Type{<:SparseArrayDOK}) = SparseArrayInterface() +function DerivableInterfaces.interface(arrayt::Type{<:SparseArrayDOK}) + SparseArrayInterface{ndims(arrayt)}() +end # Define aliases like `SparseMatrixDOK`, `AnySparseArrayDOK`, etc. @array_aliases SparseArrayDOK diff --git a/test/test_defaultarrayinterface.jl b/test/test_defaultarrayinterface.jl index d12bade..d16f152 100644 --- a/test/test_defaultarrayinterface.jl +++ b/test/test_defaultarrayinterface.jl @@ -33,6 +33,7 @@ end @testset "Broadcast.DefaultArrayStyle" begin @test interface(Broadcast.DefaultArrayStyle) == DefaultArrayInterface() + @test interface(Broadcast.DefaultArrayStyle{2}) == DefaultArrayInterface{2}() @test interface(Broadcast.Broadcasted(nothing, +, (randn(2), randn(2)))) == - DefaultArrayInterface() + DefaultArrayInterface{1}() end From 3011bfcac37d84018c0cf78745753d5bbd9db94a Mon Sep 17 00:00:00 2001 From: mtfishman Date: Mon, 9 Jun 2025 17:13:43 -0400 Subject: [PATCH 02/11] Remove ndims type parameter from Concatenated --- src/concatenate.jl | 22 ++++++++++------------ 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/src/concatenate.jl b/src/concatenate.jl index 3020645..bf3539d 100644 --- a/src/concatenate.jl +++ b/src/concatenate.jl @@ -47,32 +47,31 @@ function _Concatenated end Lazy representation of the concatenation of various `Args` along `Dims`, in order to provide hooks to customize the implementation. """ -struct Concatenated{Interface<:Union{AbstractArrayInterface,Nothing},Dims,Args<:Tuple,N} +struct Concatenated{Interface<:Union{AbstractArrayInterface,Nothing},Dims,Args<:Tuple} interface::Interface dims::Val{Dims} args::Args global @inline function _Concatenated( - interface::Interface, dims::Val{Dims}, args::Args, ndims::Val{N} - ) where {Interface<:Union{AbstractArrayInterface,Nothing},Dims,Args<:Tuple,N} - return new{Interface,Dims,Args,N}(interface, dims, args) + interface::Interface, dims::Val{Dims}, args::Args + ) where {Interface<:Union{AbstractArrayInterface,Nothing},Dims,Args<:Tuple} + return new{Interface,Dims,Args}(interface, dims, args) end end function Concatenated(interface::Nothing, dims::Val, args::Tuple) - N = cat_ndims(dims, args...) - return _Concatenated(interface, dims, args, Val(N)) + return _Concatenated(interface, dims, args) end function Concatenated(interface::AbstractArrayInterface, dims::Val, args::Tuple) N = cat_ndims(dims, args...) - return _Concatenated(typeof(interface)(Val(N)), dims, args, Val(N)) + return _Concatenated(typeof(interface)(Val(N)), dims, args) end function Concatenated(dims::Val, args::Tuple) N = cat_ndims(dims, args...) - return _Concatenated(typeof(interface(args...))(Val(N)), dims, args, Val(N)) + return _Concatenated(typeof(interface(args...))(Val(N)), dims, args) end function Concatenated{Interface}(dims::Val, args::Tuple) where {Interface} N = cat_ndims(dims, args...) - return _Concatenated(set_interface_dims(Interface, Val(N)), dims, args, Val(N)) + return _Concatenated(set_interface_dims(Interface, Val(N)), dims, args) end dims(::Concatenated{<:Any,D}) where {D} = D @@ -94,7 +93,7 @@ end Base.similar(concat::Concatenated) = similar(concat, eltype(concat)) Base.similar(concat::Concatenated, ::Type{T}) where {T} = similar(concat, T, axes(concat)) function Base.similar(concat::Concatenated, ::Type{T}, ax) where {T} - return similar(arraytype(typeof(interface(concat))(Val(ndims(concat))), T), ax) + return similar(arraytype(interface(concat), T), ax) end function cat_axis( @@ -123,8 +122,7 @@ end Base.eltype(concat::Concatenated) = promote_eltypeof(concat.args...) Base.axes(concat::Concatenated) = cat_axes(dims(concat), concat.args...) Base.size(concat::Concatenated) = length.(axes(concat)) -Base.ndims(concat::Concatenated) = ndims(typeof(concat)) -Base.ndims(::Type{<:Concatenated{<:Any,<:Any,<:Any,N}}) where {N} = N +Base.ndims(concat::Concatenated) = length(axes(concat)) # Main logic # ---------- From 741c75cd7dc93eae402af3d7f37f3f6f00838603 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Mon, 9 Jun 2025 17:40:25 -0400 Subject: [PATCH 03/11] Mark as breaking --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index bb5dc63..ac8a237 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "DerivableInterfaces" uuid = "6c5e35bf-e59e-4898-b73c-732dcc4ba65f" authors = ["ITensor developers and contributors"] -version = "0.4.6" +version = "0.5.0" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" From 70f82e10908822c123973afddca24b19bc3d60ee Mon Sep 17 00:00:00 2001 From: mtfishman Date: Mon, 9 Jun 2025 17:44:23 -0400 Subject: [PATCH 04/11] Slight name improvement --- src/concatenate.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/concatenate.jl b/src/concatenate.jl index bf3539d..7e1ed99 100644 --- a/src/concatenate.jl +++ b/src/concatenate.jl @@ -34,8 +34,8 @@ using ..DerivableInterfaces: unval(x) = x unval(::Val{x}) where {x} = x -set_interface_dims(::Type{Nothing}, ::Val{N}) where {N} = nothing -function set_interface_dims(Interface::Type{<:AbstractArrayInterface}, ::Val{N}) where {N} +set_interface_ndims(::Type{Nothing}, ::Val{N}) where {N} = nothing +function set_interface_ndims(Interface::Type{<:AbstractArrayInterface}, ::Val{N}) where {N} return Interface(Val(N)) end @@ -71,7 +71,7 @@ function Concatenated(dims::Val, args::Tuple) end function Concatenated{Interface}(dims::Val, args::Tuple) where {Interface} N = cat_ndims(dims, args...) - return _Concatenated(set_interface_dims(Interface, Val(N)), dims, args) + return _Concatenated(set_interface_ndims(Interface, Val(N)), dims, args) end dims(::Concatenated{<:Any,D}) where {D} = D From ecd8379879cf69d943cd78c0c2f97dca25700916 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Mon, 9 Jun 2025 17:45:45 -0400 Subject: [PATCH 05/11] Bump subdir packages --- docs/Project.toml | 2 +- examples/Project.toml | 2 +- test/Project.toml | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/Project.toml b/docs/Project.toml index 8b0bf74..6e4dd3d 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -4,6 +4,6 @@ Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306" [compat] -DerivableInterfaces = "0.4" +DerivableInterfaces = "0.5" Documenter = "1" Literate = "2" diff --git a/examples/Project.toml b/examples/Project.toml index cf1b8a6..ebb1aef 100644 --- a/examples/Project.toml +++ b/examples/Project.toml @@ -4,4 +4,4 @@ DerivableInterfaces = "6c5e35bf-e59e-4898-b73c-732dcc4ba65f" [compat] ArrayLayouts = "1" -DerivableInterfaces = "0.4" +DerivableInterfaces = "0.5" diff --git a/test/Project.toml b/test/Project.toml index 87525d0..c7cacf1 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -10,7 +10,7 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [compat] Aqua = "0.8" ArrayLayouts = "1" -DerivableInterfaces = "0.4" +DerivableInterfaces = "0.5" SafeTestsets = "0.1" Suppressor = "0.2" LinearAlgebra = "1" From d9faf8c861cf754948058cbe71eae4543b7790b7 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Mon, 9 Jun 2025 17:58:24 -0400 Subject: [PATCH 06/11] Test and generalize arraytype --- src/defaultarrayinterface.jl | 3 +++ test/test_defaultarrayinterface.jl | 14 +++++++++++++- 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/src/defaultarrayinterface.jl b/src/defaultarrayinterface.jl index f56dcd3..ba9ab08 100644 --- a/src/defaultarrayinterface.jl +++ b/src/defaultarrayinterface.jl @@ -48,3 +48,6 @@ end function arraytype(::DefaultArrayInterface{N}, T::Type) where {N} return Array{T,N} end +function arraytype(::DefaultArrayInterface{Any}, T::Type) + return Array{T} +end diff --git a/test/test_defaultarrayinterface.jl b/test/test_defaultarrayinterface.jl index d16f152..4924159 100644 --- a/test/test_defaultarrayinterface.jl +++ b/test/test_defaultarrayinterface.jl @@ -1,5 +1,5 @@ using Test: @inferred, @testset, @test -using DerivableInterfaces: @interface, DefaultArrayInterface, interface +using DerivableInterfaces: @interface, DefaultArrayInterface, arraytype, interface # function wrappers to test type-stability _getindex(A, i...) = @interface DefaultArrayInterface() A[i...] @@ -31,6 +31,18 @@ end @test a == mapreduce(Returns(2), +, A) end +@testset "DefaultArrayInterface" begin + @test DefaultArrayInterface() === DefaultArrayInterface{Any}() + @test DefaultArrayInterface(Val(2)) === DefaultArrayInterface{2}() + @test DefaultArrayInterface{Any}(Val(2)) === DefaultArrayInterface{2}() + @test DefaultArrayInterface{3}(Val(2)) === DefaultArrayInterface{2}() +end + +@testset "arraytype" begin + @test arraytype(DefaultArrayInterface{2}(), Float32) == Matrix{Float32} + @test arraytype(DefaultArrayInterface(), Float32) == Array{Float32} +end + @testset "Broadcast.DefaultArrayStyle" begin @test interface(Broadcast.DefaultArrayStyle) == DefaultArrayInterface() @test interface(Broadcast.DefaultArrayStyle{2}) == DefaultArrayInterface{2}() From 1942bddaa7d11fe8aaea7fdc17053219abd37149 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Tue, 10 Jun 2025 13:28:15 -0400 Subject: [PATCH 07/11] Remove arraytype, use similar instead --- src/abstractarrayinterface.jl | 11 ++++++----- src/concatenate.jl | 14 ++++++++------ src/defaultarrayinterface.jl | 13 +++++++------ test/SparseArrayDOKs.jl | 4 +++- test/test_defaultarrayinterface.jl | 28 +++++++++++++++++++--------- 5 files changed, 43 insertions(+), 27 deletions(-) diff --git a/src/abstractarrayinterface.jl b/src/abstractarrayinterface.jl index 3f6f2eb..6eba847 100644 --- a/src/abstractarrayinterface.jl +++ b/src/abstractarrayinterface.jl @@ -16,8 +16,10 @@ function interface(::Type{<:Broadcast.Broadcasted{<:Style}}) where {Style} return interface(Style) end -# TODO: Define as `Array{T,N}`. -arraytype(::AbstractArrayInterface, T::Type) = error("Not implemented.") +# TODO: Define as `similar(Array{T}, ax)`. +function Base.similar(interface::AbstractArrayInterface, T::Type, ax::Tuple) + return error("Not implemented.") +end using ArrayLayouts: ArrayLayouts @@ -88,7 +90,7 @@ end @interface interface::AbstractArrayInterface function Base.similar( a::AbstractArray, T::Type, size::Tuple{Vararg{Int}} ) - return similar(arraytype(interface, T), size) + return similar(interface, T, size) end @interface ::AbstractArrayInterface function Base.copy(a::AbstractArray) @@ -108,8 +110,7 @@ end @interface interface::AbstractArrayInterface function Base.similar( bc::Broadcast.Broadcasted, T::Type, axes::Tuple ) - # `arraytype(::AbstractInterface)` determines the default array type associated with the interface. - return similar(arraytype(interface, T), axes) + return similar(interface, T, axes) end using MapBroadcast: Mapped diff --git a/src/concatenate.jl b/src/concatenate.jl index 7e1ed99..8e2cfe9 100644 --- a/src/concatenate.jl +++ b/src/concatenate.jl @@ -28,8 +28,7 @@ export concatenate @compat public Concatenated, cat, cat!, concatenated using Base: promote_eltypeof -using ..DerivableInterfaces: - DerivableInterfaces, AbstractArrayInterface, interface, zero!, arraytype +using ..DerivableInterfaces: DerivableInterfaces, AbstractArrayInterface, interface, zero! unval(x) = x unval(::Val{x}) where {x} = x @@ -47,13 +46,13 @@ function _Concatenated end Lazy representation of the concatenation of various `Args` along `Dims`, in order to provide hooks to customize the implementation. """ -struct Concatenated{Interface<:Union{AbstractArrayInterface,Nothing},Dims,Args<:Tuple} +struct Concatenated{Interface,Dims,Args<:Tuple} interface::Interface dims::Val{Dims} args::Args global @inline function _Concatenated( interface::Interface, dims::Val{Dims}, args::Args - ) where {Interface<:Union{AbstractArrayInterface,Nothing},Dims,Args<:Tuple} + ) where {Interface,Dims,Args<:Tuple} return new{Interface,Dims,Args}(interface, dims, args) end end @@ -92,8 +91,11 @@ end # ------------------------------------ Base.similar(concat::Concatenated) = similar(concat, eltype(concat)) Base.similar(concat::Concatenated, ::Type{T}) where {T} = similar(concat, T, axes(concat)) -function Base.similar(concat::Concatenated, ::Type{T}, ax) where {T} - return similar(arraytype(interface(concat), T), ax) +function Base.similar(concat::Concatenated, ax::Tuple) + return similar(interface(concat), eltype(concat), ax) +end +function Base.similar(concat::Concatenated, ::Type{T}, ax::Tuple) where {T} + return similar(interface(concat), T, ax) end function cat_axis( diff --git a/src/defaultarrayinterface.jl b/src/defaultarrayinterface.jl index ba9ab08..3861a3e 100644 --- a/src/defaultarrayinterface.jl +++ b/src/defaultarrayinterface.jl @@ -6,7 +6,11 @@ DefaultArrayInterface{M}(::Val{N}) where {M,N} = DefaultArrayInterface{N}() using TypeParameterAccessors: parenttype function interface(a::Type{<:AbstractArray}) - parenttype(a) === a && return DefaultArrayInterface{ndims(a)}() + parenttype(a) === a && return DefaultArrayInterface() + return interface(parenttype(a)) +end +function interface(a::Type{<:AbstractArray{<:Any,N}}) where {N} + parenttype(a) === a && return DefaultArrayInterface{N}() return interface(parenttype(a)) end @@ -45,9 +49,6 @@ end return Base.mapreduce(f, op, as...; kwargs...) end -function arraytype(::DefaultArrayInterface{N}, T::Type) where {N} - return Array{T,N} -end -function arraytype(::DefaultArrayInterface{Any}, T::Type) - return Array{T} +function Base.similar(::DefaultArrayInterface, T::Type, ax::Tuple) + return similar(Array{T}, ax) end diff --git a/test/SparseArrayDOKs.jl b/test/SparseArrayDOKs.jl index 9b34929..9f80c3f 100644 --- a/test/SparseArrayDOKs.jl +++ b/test/SparseArrayDOKs.jl @@ -74,7 +74,9 @@ end @derive SparseArrayStyle AbstractArrayStyleOps -DerivableInterfaces.arraytype(::SparseArrayInterface, T::Type) = SparseArrayDOK{T} +function Base.similar(::SparseArrayInterface, T::Type, ax::Tuple) + return similar(SparseArrayDOK{T}, ax) +end # Interface functions. @interface ::SparseArrayInterface function Broadcast.BroadcastStyle(type::Type) diff --git a/test/test_defaultarrayinterface.jl b/test/test_defaultarrayinterface.jl index 4924159..6d2cb23 100644 --- a/test/test_defaultarrayinterface.jl +++ b/test/test_defaultarrayinterface.jl @@ -1,5 +1,6 @@ -using Test: @inferred, @testset, @test -using DerivableInterfaces: @interface, DefaultArrayInterface, arraytype, interface +using DerivableInterfaces: @interface, DefaultArrayInterface, interface +using Test: @testset, @test +using TestExtras: @constinferred # function wrappers to test type-stability _getindex(A, i...) = @interface DefaultArrayInterface() A[i...] @@ -11,36 +12,45 @@ end @testset "indexing" begin for (A, i) in ((zeros(2), 2), (zeros(2, 2), (2, 1)), (zeros(1, 2, 3), (1, 2, 3))) - a = @inferred _getindex(A, i...) + a = @constinferred _getindex(A, i...) @test a == A[i...] v = 1.1 - A′ = @inferred _setindex!(A, v, i...) + A′ = @constinferred _setindex!(A, v, i...) @test A′ == (A[i...] = v) end end @testset "map!" begin A = zeros(3) - a = @inferred _map!(Returns(2), copy(A), A) + a = @constinferred _map!(Returns(2), copy(A), A) @test a == map!(Returns(2), copy(A), A) end @testset "mapreduce" begin A = zeros(3) - a = @inferred _mapreduce(Returns(2), +, A) + a = @constinferred _mapreduce(Returns(2), +, A) @test a == mapreduce(Returns(2), +, A) end @testset "DefaultArrayInterface" begin + @test interface(Array) === DefaultArrayInterface{Any}() + @test interface(Array{Float32}) === DefaultArrayInterface{Any}() + @test interface(Matrix) === DefaultArrayInterface{2}() + @test interface(Matrix{Float32}) === DefaultArrayInterface{2}() @test DefaultArrayInterface() === DefaultArrayInterface{Any}() @test DefaultArrayInterface(Val(2)) === DefaultArrayInterface{2}() @test DefaultArrayInterface{Any}(Val(2)) === DefaultArrayInterface{2}() @test DefaultArrayInterface{3}(Val(2)) === DefaultArrayInterface{2}() end -@testset "arraytype" begin - @test arraytype(DefaultArrayInterface{2}(), Float32) == Matrix{Float32} - @test arraytype(DefaultArrayInterface(), Float32) == Array{Float32} +@testset "similar(::DefaultArrayInterface, ...)" begin + a = @constinferred similar(DefaultArrayInterface(), Float32, (2, 2)) + @test typeof(a) === Matrix{Float32} + @test size(a) == (2, 2) + + a = @constinferred similar(DefaultArrayInterface{1}(), Float32, (2, 2)) + @test typeof(a) === Matrix{Float32} + @test size(a) == (2, 2) end @testset "Broadcast.DefaultArrayStyle" begin From 51fbbb1db024d842ab3b001681a844c00eabbf61 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Tue, 10 Jun 2025 13:30:37 -0400 Subject: [PATCH 08/11] Add missing test dep --- test/Project.toml | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/test/Project.toml b/test/Project.toml index c7cacf1..2c88278 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -6,12 +6,14 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a" [compat] Aqua = "0.8" ArrayLayouts = "1" DerivableInterfaces = "0.5" +LinearAlgebra = "1" SafeTestsets = "0.1" Suppressor = "0.2" -LinearAlgebra = "1" Test = "1" +TestExtras = "0.3" From a2ab28d649f98b01dfa6cfe23f8ed2241cd31b8b Mon Sep 17 00:00:00 2001 From: mtfishman Date: Tue, 10 Jun 2025 14:44:56 -0400 Subject: [PATCH 09/11] Improve Concatenated constructors --- src/concatenate.jl | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/src/concatenate.jl b/src/concatenate.jl index 8e2cfe9..a542151 100644 --- a/src/concatenate.jl +++ b/src/concatenate.jl @@ -65,12 +65,18 @@ function Concatenated(interface::AbstractArrayInterface, dims::Val, args::Tuple) return _Concatenated(typeof(interface)(Val(N)), dims, args) end function Concatenated(dims::Val, args::Tuple) - N = cat_ndims(dims, args...) - return _Concatenated(typeof(interface(args...))(Val(N)), dims, args) + return Concatenated(interface(args...), dims, args) end -function Concatenated{Interface}(dims::Val, args::Tuple) where {Interface} - N = cat_ndims(dims, args...) - return _Concatenated(set_interface_ndims(Interface, Val(N)), dims, args) +function Concatenated{Nothing}(dims::Val, args::Tuple) + return _Concatenated(nothing, dims, args) +end +function Concatenated{Interface}( + dims::Val, args::Tuple +) where {N,Interface<:AbstractArrayInterface{N}} + (N === Any) || + (N == cat_ndims(dims, args...)) || + throw(ArgumentError("Input interface type has incorrect `ndims`.")) + return _Concatenated(Interface(), dims, args) end dims(::Concatenated{<:Any,D}) where {D} = D @@ -124,7 +130,7 @@ end Base.eltype(concat::Concatenated) = promote_eltypeof(concat.args...) Base.axes(concat::Concatenated) = cat_axes(dims(concat), concat.args...) Base.size(concat::Concatenated) = length.(axes(concat)) -Base.ndims(concat::Concatenated) = length(axes(concat)) +Base.ndims(concat::Concatenated) = cat_ndims(dims(concat), concat.args...) # Main logic # ---------- From 8c3705b734cb2b1b8de7398d1053d482a4917b9c Mon Sep 17 00:00:00 2001 From: mtfishman Date: Tue, 10 Jun 2025 14:46:07 -0400 Subject: [PATCH 10/11] Cleanup --- src/concatenate.jl | 5 ----- 1 file changed, 5 deletions(-) diff --git a/src/concatenate.jl b/src/concatenate.jl index a542151..3f93015 100644 --- a/src/concatenate.jl +++ b/src/concatenate.jl @@ -33,11 +33,6 @@ using ..DerivableInterfaces: DerivableInterfaces, AbstractArrayInterface, interf unval(x) = x unval(::Val{x}) where {x} = x -set_interface_ndims(::Type{Nothing}, ::Val{N}) where {N} = nothing -function set_interface_ndims(Interface::Type{<:AbstractArrayInterface}, ::Val{N}) where {N} - return Interface(Val(N)) -end - function _Concatenated end """ From ce0ee300c5f9b3420ec6ebc5187ad72ca6a1b381 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Tue, 10 Jun 2025 14:58:28 -0400 Subject: [PATCH 11/11] Refactor constructors more --- src/concatenate.jl | 25 +++++++++++-------------- 1 file changed, 11 insertions(+), 14 deletions(-) diff --git a/src/concatenate.jl b/src/concatenate.jl index 3f93015..5f7b019 100644 --- a/src/concatenate.jl +++ b/src/concatenate.jl @@ -52,26 +52,18 @@ struct Concatenated{Interface,Dims,Args<:Tuple} end end -function Concatenated(interface::Nothing, dims::Val, args::Tuple) +function Concatenated( + interface::Union{AbstractArrayInterface,Nothing}, dims::Val, args::Tuple +) return _Concatenated(interface, dims, args) end -function Concatenated(interface::AbstractArrayInterface, dims::Val, args::Tuple) - N = cat_ndims(dims, args...) - return _Concatenated(typeof(interface)(Val(N)), dims, args) -end function Concatenated(dims::Val, args::Tuple) - return Concatenated(interface(args...), dims, args) -end -function Concatenated{Nothing}(dims::Val, args::Tuple) - return _Concatenated(nothing, dims, args) + return Concatenated(cat_interface(dims, args...), dims, args) end function Concatenated{Interface}( dims::Val, args::Tuple -) where {N,Interface<:AbstractArrayInterface{N}} - (N === Any) || - (N == cat_ndims(dims, args...)) || - throw(ArgumentError("Input interface type has incorrect `ndims`.")) - return _Concatenated(Interface(), dims, args) +) where {Interface<:Union{AbstractArrayInterface,Nothing}} + return Concatenated(Interface(), dims, args) end dims(::Concatenated{<:Any,D}) where {D} = D @@ -122,6 +114,11 @@ function cat_axes(dims::Val, as::AbstractArray...) return cat_axes(unval(dims), as...) end +function cat_interface(dims, as::AbstractArray...) + N = cat_ndims(dims, as...) + return typeof(interface(as...))(Val(N)) +end + Base.eltype(concat::Concatenated) = promote_eltypeof(concat.args...) Base.axes(concat::Concatenated) = cat_axes(dims(concat), concat.args...) Base.size(concat::Concatenated) = length.(axes(concat))