Skip to content

Commit 43fc9d6

Browse files
authored
Update to Derive and SparseArraysBase v0.2 (#3)
1 parent d44f51c commit 43fc9d6

File tree

7 files changed

+175
-65
lines changed

7 files changed

+175
-65
lines changed

Project.toml

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,19 @@
11
name = "DiagonalArrays"
22
uuid = "74fd4be6-21e2-4f6f-823a-4360d37c7a77"
33
authors = ["ITensor developers <support@itensor.org> and contributors"]
4-
version = "0.1.0"
4+
version = "0.2.0"
55

66
[deps]
77
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
88
BroadcastMapConversion = "4a4adec5-520f-4750-bb37-d5e66b4ddeb2"
9-
NestedPermutedDimsArrays = "2c2a8ec4-3cfc-4276-aa3e-1307b4294e58"
9+
Derive = "a07dfc7f-7d04-4eb5-84cc-a97f051f655a"
1010
SparseArraysBase = "0d5efcca-f356-4864-8770-e1ed8d78f208"
1111
TypeParameterAccessors = "7e5a90cf-f82e-492e-a09b-e3e26432c138"
1212

1313
[compat]
1414
ArrayLayouts = "1.10.4"
1515
BroadcastMapConversion = "0.1"
16-
NestedPermutedDimsArrays = "0.1"
17-
SparseArraysBase = "0.1"
18-
TypeParameterAccessors = "0.1"
16+
Derive = "0.3.6"
17+
SparseArraysBase = "0.2.1"
18+
TypeParameterAccessors = "0.2"
1919
julia = "1.10"

examples/Project.toml

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,4 @@
11
[deps]
2-
BroadcastMapConversion = "4a4adec5-520f-4750-bb37-d5e66b4ddeb2"
2+
Derive = "a07dfc7f-7d04-4eb5-84cc-a97f051f655a"
33
DiagonalArrays = "74fd4be6-21e2-4f6f-823a-4360d37c7a77"
4-
NestedPermutedDimsArrays = "2c2a8ec4-3cfc-4276-aa3e-1307b4294e58"
54
SparseArraysBase = "0d5efcca-f356-4864-8770-e1ed8d78f208"
6-
TypeParameterAccessors = "7e5a90cf-f82e-492e-a09b-e3e26432c138"
Lines changed: 86 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,95 @@
1-
using SparseArraysBase: SparseArraysBase, StorageIndex, StorageIndices
1+
# TODO: Turn these into `@interface ::AbstractDiagonalArrayInterface` functions.
22

3-
SparseArraysBase.StorageIndex(i::DiagIndex) = StorageIndex(index(i))
3+
diagview(a::AbstractDiagonalArray) = throw(MethodError(diagview, Tuple{typeof(a)}))
44

5-
function Base.getindex(a::AbstractDiagonalArray, i::DiagIndex)
6-
return a[StorageIndex(i)]
7-
end
5+
using Derive: Derive, @interface
6+
using SparseArraysBase:
7+
SparseArraysBase, AbstractSparseArrayInterface, AbstractSparseArrayStyle ## , StorageIndex, StorageIndices
88

9-
function Base.setindex!(a::AbstractDiagonalArray, value, i::DiagIndex)
10-
a[StorageIndex(i)] = value
11-
return a
12-
end
9+
abstract type AbstractDiagonalArrayInterface <: AbstractSparseArrayInterface end
10+
11+
struct DiagonalArrayInterface <: AbstractDiagonalArrayInterface end
12+
13+
Derive.arraytype(::AbstractDiagonalArrayInterface, elt::Type) = DiagonalArray{elt}
14+
Derive.interface(::Type{<:AbstractDiagonalArray}) = DiagonalArrayInterface()
15+
16+
abstract type AbstractDiagonalArrayStyle{N} <: AbstractSparseArrayStyle{N} end
1317

14-
SparseArraysBase.StorageIndices(i::DiagIndices) = StorageIndices(indices(i))
18+
Derive.interface(::Type{<:AbstractDiagonalArrayStyle}) = DiagonalArrayInterface()
1519

16-
function Base.getindex(a::AbstractDiagonalArray, i::DiagIndices)
17-
return a[StorageIndices(i)]
20+
struct DiagonalArrayStyle{N} <: AbstractDiagonalArrayStyle{N} end
21+
22+
DiagonalArrayStyle{M}(::Val{N}) where {M,N} = DiagonalArrayStyle{N}()
23+
24+
@interface ::AbstractDiagonalArrayInterface function Broadcast.BroadcastStyle(type::Type)
25+
return DiagonalArrayStyle{ndims(type)}()
1826
end
1927

20-
function Base.setindex!(a::AbstractDiagonalArray, value, i::DiagIndices)
21-
a[StorageIndices(i)] = value
28+
function SparseArraysBase.isstored(
29+
a::AbstractDiagonalArray{<:Any,N}, I::Vararg{Int,N}
30+
) where {N}
31+
return allequal(I)
32+
end
33+
function SparseArraysBase.getstoredindex(
34+
a::AbstractDiagonalArray{<:Any,N}, I::Vararg{Int,N}
35+
) where {N}
36+
# TODO: Make this check optional, define `checkstored` like `checkbounds`
37+
# in SparseArraysBase.jl.
38+
# allequal(I) || error("Not a diagonal index.")
39+
return getdiagindex(a, first(I))
40+
end
41+
function SparseArraysBase.setstoredindex!(
42+
a::AbstractDiagonalArray{<:Any,N}, value, I::Vararg{Int,N}
43+
) where {N}
44+
# TODO: Make this check optional, define `checkstored` like `checkbounds`
45+
# in SparseArraysBase.jl.
46+
# allequal(I) || error("Not a diagonal index.")
47+
setdiagindex!(a, value, first(I))
2248
return a
2349
end
50+
function SparseArraysBase.eachstoredindex(a::AbstractDiagonalArray)
51+
return diagindices(a)
52+
end
53+
54+
# Fix ambiguity error with SparseArraysBase.
55+
function Base.getindex(a::AbstractDiagonalArray, I::DiagIndices)
56+
# TODO: Use `@interface` rather than `invoke`.
57+
return invoke(getindex, Tuple{AbstractArray,DiagIndices}, a, I)
58+
end
59+
# Fix ambiguity error with SparseArraysBase.
60+
function Base.getindex(a::AbstractDiagonalArray, I::DiagIndex)
61+
# TODO: Use `@interface` rather than `invoke`.
62+
return invoke(getindex, Tuple{AbstractArray,DiagIndex}, a, I)
63+
end
64+
# Fix ambiguity error with SparseArraysBase.
65+
function Base.setindex!(a::AbstractDiagonalArray, value, I::DiagIndices)
66+
# TODO: Use `@interface` rather than `invoke`.
67+
return invoke(setindex!, Tuple{AbstractArray,Any,DiagIndices}, a, value, I)
68+
end
69+
# Fix ambiguity error with SparseArraysBase.
70+
function Base.setindex!(a::AbstractDiagonalArray, value, I::DiagIndex)
71+
# TODO: Use `@interface` rather than `invoke`.
72+
return invoke(setindex!, Tuple{AbstractArray,Any,DiagIndex}, a, value, I)
73+
end
74+
75+
## SparseArraysBase.StorageIndex(i::DiagIndex) = StorageIndex(index(i))
76+
77+
## function Base.getindex(a::AbstractDiagonalArray, i::DiagIndex)
78+
## return a[StorageIndex(i)]
79+
## end
80+
81+
## function Base.setindex!(a::AbstractDiagonalArray, value, i::DiagIndex)
82+
## a[StorageIndex(i)] = value
83+
## return a
84+
## end
85+
86+
## SparseArraysBase.StorageIndices(i::DiagIndices) = StorageIndices(indices(i))
87+
88+
## function Base.getindex(a::AbstractDiagonalArray, i::DiagIndices)
89+
## return a[StorageIndices(i)]
90+
## end
91+
92+
## function Base.setindex!(a::AbstractDiagonalArray, value, i::DiagIndices)
93+
## a[StorageIndices(i)] = value
94+
## return a
95+
## end

src/abstractdiagonalarray/sparsearrayinterface.jl

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,14 @@
1-
using SparseArraysBase: SparseArraysBase
2-
3-
# `SparseArraysBase` interface
4-
function SparseArraysBase.index_to_storage_index(
5-
a::AbstractDiagonalArray{<:Any,N}, I::CartesianIndex{N}
6-
) where {N}
7-
!allequal(Tuple(I)) && return nothing
8-
return first(Tuple(I))
9-
end
10-
11-
function SparseArraysBase.storage_index_to_index(a::AbstractDiagonalArray, I)
12-
return CartesianIndex(ntuple(Returns(I), ndims(a)))
13-
end
1+
## # `SparseArraysBase` interface
2+
## function SparseArraysBase.index_to_storage_index(
3+
## a::AbstractDiagonalArray{<:Any,N}, I::CartesianIndex{N}
4+
## ) where {N}
5+
## !allequal(Tuple(I)) && return nothing
6+
## return first(Tuple(I))
7+
## end
8+
##
9+
## function SparseArraysBase.storage_index_to_index(a::AbstractDiagonalArray, I)
10+
## return CartesianIndex(ntuple(Returns(I), ndims(a)))
11+
## end
1412

1513
## # 1-dimensional case can be `AbstractDiagonalArray`.
1614
## function SparseArraysBase.sparse_similar(

src/diaginterface/diaginterface.jl

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
# TODO: Turn these into `@interface ::AbstractDiagonalArrayInterface` functions.
2+
13
diaglength(a::AbstractArray{<:Any,0}) = 1
24

35
function diaglength(a::AbstractArray)
@@ -19,10 +21,43 @@ function diagstride(a::AbstractArray)
1921
return s
2022
end
2123

24+
# Iterator over the diagonal cartesian indices.
25+
# For an AbstractArray `a`, `DiagCartesianIndices(a)` is equivalent
26+
# to `@view CartesianIndices(a)[diagindices(a)]` but should be
27+
# faster because it avoids conversions from linear to cartesian indices.
28+
struct DiagCartesianIndices{N} <: AbstractVector{CartesianIndex{N}}
29+
diaglength::Int
30+
end
31+
function DiagCartesianIndices(axes::Tuple{Vararg{AbstractUnitRange}})
32+
# Check the ranges are one-based.
33+
@assert all(isone, first.(axes))
34+
return DiagCartesianIndices{length(axes)}(minimum(length.(axes)))
35+
end
36+
function DiagCartesianIndices(dims::Tuple{Vararg{Int}})
37+
return DiagCartesianIndices(Base.OneTo.(dims))
38+
end
39+
function DiagCartesianIndices(a::AbstractArray)
40+
return DiagCartesianIndices(axes(a))
41+
end
42+
Base.size(I::DiagCartesianIndices) = (I.diaglength,)
43+
function Base.getindex(I::DiagCartesianIndices{N}, i::Int) where {N}
44+
return CartesianIndex(ntuple(Returns(i), N))
45+
end
46+
2247
function diagindices(a::AbstractArray)
48+
return diagindices(IndexStyle(a), a)
49+
end
50+
function diagindices(::IndexLinear, a::AbstractArray)
2351
maxdiag = LinearIndices(a)[CartesianIndex(ntuple(Returns(diaglength(a)), ndims(a)))]
2452
return 1:diagstride(a):maxdiag
2553
end
54+
function diagindices(::IndexCartesian, a::AbstractArray)
55+
return DiagCartesianIndices(a)
56+
# TODO: Define a special iterator for this, i.e. `DiagCartesianIndices`?
57+
return Iterators.map(
58+
i -> CartesianIndex(ntuple(Returns(i), ndims(a))), Base.OneTo(diaglength(a))
59+
)
60+
end
2661

2762
function diagindices(a::AbstractArray{<:Any,0})
2863
return Base.OneTo(1)

src/diagonalarray/diagonalarray.jl

Lines changed: 33 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,33 @@
1-
using SparseArraysBase: SparseArraysBase, SparseArrayDOK, Zero, getindex_zero_function
1+
using SparseArraysBase: SparseArraysBase, SparseArrayDOK, default_getunstoredindex ## , Zero, getindex_zero_function
22

3-
struct DiagonalArray{T,N,Diag<:AbstractVector{T},Zero} <: AbstractDiagonalArray{T,N}
3+
struct DiagonalArray{T,N,Diag<:AbstractVector{T},F} <: AbstractDiagonalArray{T,N}
44
diag::Diag
55
dims::NTuple{N,Int}
6-
zero::Zero
6+
getunstoredindex::F
77
end
88

99
function DiagonalArray{T,N}(
10-
diag::AbstractVector{T}, d::Tuple{Vararg{Int,N}}, zero=Zero()
10+
diag::AbstractVector{T},
11+
d::Tuple{Vararg{Int,N}},
12+
getunstoredindex=default_getunstoredindex,
1113
) where {T,N}
12-
return DiagonalArray{T,N,typeof(diag),typeof(zero)}(diag, d, zero)
14+
return DiagonalArray{T,N,typeof(diag),typeof(getunstoredindex)}(diag, d, getunstoredindex)
1315
end
1416

1517
function DiagonalArray{T,N}(
16-
diag::AbstractVector, d::Tuple{Vararg{Int,N}}, zero=Zero()
18+
diag::AbstractVector, d::Tuple{Vararg{Int,N}}, getunstoredindex=default_getunstoredindex
1719
) where {T,N}
18-
return DiagonalArray{T,N}(T.(diag), d, zero)
20+
return DiagonalArray{T,N}(T.(diag), d, getunstoredindex)
1921
end
2022

2123
function DiagonalArray{T,N}(diag::AbstractVector, d::Vararg{Int,N}) where {T,N}
2224
return DiagonalArray{T,N}(diag, d)
2325
end
2426

2527
function DiagonalArray{T}(
26-
diag::AbstractVector, d::Tuple{Vararg{Int,N}}, zero=Zero()
28+
diag::AbstractVector, d::Tuple{Vararg{Int,N}}, getunstoredindex=default_getunstoredindex
2729
) where {T,N}
28-
return DiagonalArray{T,N}(diag, d, zero)
30+
return DiagonalArray{T,N}(diag, d, getunstoredindex)
2931
end
3032

3133
function DiagonalArray{T}(diag::AbstractVector, d::Vararg{Int,N}) where {T,N}
@@ -51,27 +53,29 @@ end
5153

5254
# undef
5355
function DiagonalArray{T,N}(
54-
::UndefInitializer, d::Tuple{Vararg{Int,N}}, zero=Zero()
56+
::UndefInitializer, d::Tuple{Vararg{Int,N}}, getunstoredindex=default_getunstoredindex
5557
) where {T,N}
56-
return DiagonalArray{T,N}(Vector{T}(undef, minimum(d)), d, zero)
58+
return DiagonalArray{T,N}(Vector{T}(undef, minimum(d)), d, getunstoredindex)
5759
end
5860

5961
function DiagonalArray{T,N}(::UndefInitializer, d::Vararg{Int,N}) where {T,N}
6062
return DiagonalArray{T,N}(undef, d)
6163
end
6264

6365
function DiagonalArray{T}(
64-
::UndefInitializer, d::Tuple{Vararg{Int,N}}, zero=Zero()
66+
::UndefInitializer, d::Tuple{Vararg{Int,N}}, getunstoredindex=default_getunstoredindex
6567
) where {T,N}
66-
return DiagonalArray{T,N}(undef, d, zero)
68+
return DiagonalArray{T,N}(undef, d, getunstoredindex)
6769
end
6870

6971
# Axes version
7072
function DiagonalArray{T}(
71-
::UndefInitializer, axes::Tuple{Vararg{AbstractUnitRange,N}}, zero=Zero()
73+
::UndefInitializer,
74+
axes::Tuple{Vararg{AbstractUnitRange,N}},
75+
getunstoredindex=default_getunstoredindex,
7276
) where {T,N}
7377
@assert all(isone, first.(axes))
74-
return DiagonalArray{T,N}(undef, length.(axes), zero)
78+
return DiagonalArray{T,N}(undef, length.(axes), getunstoredindex)
7579
end
7680

7781
function DiagonalArray{T}(::UndefInitializer, d::Vararg{Int,N}) where {T,N}
@@ -83,23 +87,26 @@ Base.size(a::DiagonalArray) = a.dims
8387

8488
function Base.similar(a::DiagonalArray, elt::Type, dims::Tuple{Vararg{Int}})
8589
# TODO: Preserve zero element function.
86-
return DiagonalArray{elt}(undef, dims, getindex_zero_function(a))
90+
return DiagonalArray{elt}(undef, dims, a.getunstoredindex)
8791
end
8892

93+
# DiagonalArrays interface.
94+
diagview(a::DiagonalArray) = a.diag
95+
8996
# Minimal `SparseArraysBase` interface
90-
SparseArraysBase.sparse_storage(a::DiagonalArray) = a.diag
97+
## SparseArraysBase.sparse_storage(a::DiagonalArray) = a.diag
9198

9299
# `SparseArraysBase`
93100
# Defines similar when the output can't be `DiagonalArray`,
94101
# such as in `reshape`.
95102
# TODO: Put into `DiagonalArraysSparseArraysBaseExt`?
96103
# TODO: Special case 2D to output `SparseMatrixCSC`?
97-
function SparseArraysBase.sparse_similar(
98-
a::DiagonalArray, elt::Type, dims::Tuple{Vararg{Int}}
99-
)
100-
return SparseArrayDOK{elt}(undef, dims, getindex_zero_function(a))
101-
end
102-
103-
function SparseArraysBase.getindex_zero_function(a::DiagonalArray)
104-
return a.zero
105-
end
104+
## function SparseArraysBase.sparse_similar(
105+
## a::DiagonalArray, elt::Type, dims::Tuple{Vararg{Int}}
106+
## )
107+
## return SparseArrayDOK{elt}(undef, dims, getindex_zero_function(a))
108+
## end
109+
110+
## function SparseArraysBase.getindex_zero_function(a::DiagonalArray)
111+
## return a.zero
112+
## end

test/test_basics.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
using Test: @test, @testset, @test_broken
22
using DiagonalArrays: DiagonalArrays, DiagonalArray, DiagonalMatrix, diaglength
3-
using SparseArraysBase: SparseArrayDOK, stored_length
3+
using SparseArraysBase: SparseArrayDOK, storedlength
44
@testset "Test DiagonalArrays" begin
55
@testset "DiagonalArray (eltype=$elt)" for elt in (
66
Float32, Float64, Complex{Float32}, Complex{Float64}
@@ -23,15 +23,15 @@ using SparseArraysBase: SparseArrayDOK, stored_length
2323
# TODO: Use `densearray` to make generic to GPU.
2424
@test Array(a_dest) Array(a1) * Array(a2)
2525
# TODO: Make this work with `ArrayLayouts`.
26-
@test stored_length(a_dest) == 2
26+
@test storedlength(a_dest) == 2
2727
@test a_dest isa DiagonalMatrix{elt}
2828

2929
# TODO: Make generic to GPU, use `allocate_randn`?
3030
a2 = randn(elt, (3, 4))
3131
a_dest = a1 * a2
3232
# TODO: Use `densearray` to make generic to GPU.
3333
@test Array(a_dest) Array(a1) * Array(a2)
34-
@test stored_length(a_dest) == 8
34+
@test storedlength(a_dest) == 8
3535
@test a_dest isa Matrix{elt}
3636

3737
a2 = SparseArrayDOK{elt}(3, 4)
@@ -43,7 +43,7 @@ using SparseArraysBase: SparseArrayDOK, stored_length
4343
@test Array(a_dest) Array(a1) * Array(a2)
4444
# TODO: Define `SparseMatrixDOK`.
4545
# TODO: Make this work with `ArrayLayouts`.
46-
@test stored_length(a_dest) == 2
46+
@test storedlength(a_dest) == 2
4747
@test a_dest isa SparseArrayDOK{elt,2}
4848
end
4949
end

0 commit comments

Comments
 (0)