Skip to content

Commit 533034b

Browse files
authored
[BlockSparseArrays] Sparse and block sparse matrix multiplication (#1278)
1 parent 4e75a3f commit 533034b

File tree

4 files changed

+67
-8
lines changed

4 files changed

+67
-8
lines changed

src/DiagonalArrays.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@ include("diaginterface/diagindices.jl")
66
include("abstractdiagonalarray/abstractdiagonalarray.jl")
77
include("abstractdiagonalarray/sparsearrayinterface.jl")
88
include("abstractdiagonalarray/diagonalarraydiaginterface.jl")
9+
include("abstractdiagonalarray/arraylayouts.jl")
910
include("diagonalarray/diagonalarray.jl")
1011
include("diagonalarray/diagonalmatrix.jl")
1112
include("diagonalarray/diagonalvector.jl")
13+
include("diagonalarray/arraylayouts.jl")
1214
end
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
using ArrayLayouts: ArrayLayouts
2+
using ..SparseArrayInterface: AbstractSparseLayout
3+
4+
abstract type AbstractDiagonalLayout <: AbstractSparseLayout end
5+
struct DiagonalLayout <: AbstractDiagonalLayout end
6+
7+
ArrayLayouts.MemoryLayout(::Type{<:AbstractDiagonalArray}) = DiagonalLayout()

src/diagonalarray/arraylayouts.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
using ArrayLayouts: MulAdd
2+
3+
# Default sparse array type for `AbstractDiagonalLayout`.
4+
default_diagonalarraytype(elt::Type) = DiagonalArray{elt}
5+
6+
# TODO: Preserve GPU memory! Implement `CuSparseArrayLayout`, `MtlSparseLayout`?
7+
function Base.similar(
8+
::MulAdd{<:AbstractDiagonalLayout,<:AbstractDiagonalLayout}, elt::Type, axes
9+
)
10+
return similar(default_diagonalarraytype(elt), axes)
11+
end

test/runtests.jl

Lines changed: 47 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
@eval module $(gensym())
2-
using Test: @test, @testset
3-
using NDTensors.DiagonalArrays: DiagonalArrays
2+
using Test: @test, @testset, @test_broken
3+
using NDTensors.DiagonalArrays: DiagonalArrays, DiagonalArray, DiagonalMatrix, diaglength
4+
using NDTensors.SparseArrayDOKs: SparseArrayDOK
5+
using NDTensors.SparseArrayInterface: nstored
46
@testset "Test NDTensors.DiagonalArrays" begin
57
@testset "README" begin
68
@test include(
@@ -9,12 +11,49 @@ using NDTensors.DiagonalArrays: DiagonalArrays
911
),
1012
) isa Any
1113
end
12-
@testset "Basics" begin
13-
using NDTensors.DiagonalArrays: diaglength
14-
a = fill(1.0, 2, 3)
15-
@test diaglength(a) == 2
16-
a = fill(1.0)
17-
@test diaglength(a) == 1
14+
@testset "DiagonalArray (eltype=$elt)" for elt in
15+
(Float32, Float64, ComplexF32, ComplexF64)
16+
@testset "Basics" begin
17+
a = fill(one(elt), 2, 3)
18+
@test diaglength(a) == 2
19+
a = fill(one(elt))
20+
@test diaglength(a) == 1
21+
end
22+
@testset "Matrix multiplication" begin
23+
a1 = DiagonalArray{elt}(undef, (2, 3))
24+
a1[1, 1] = 11
25+
a1[2, 2] = 22
26+
a2 = DiagonalArray{elt}(undef, (3, 4))
27+
a2[1, 1] = 11
28+
a2[2, 2] = 22
29+
a2[3, 3] = 33
30+
a_dest = a1 * a2
31+
# TODO: Use `densearray` to make generic to GPU.
32+
@test Array(a_dest) Array(a1) * Array(a2)
33+
# TODO: Make this work with `ArrayLayouts`.
34+
@test nstored(a_dest) == 2
35+
@test a_dest isa DiagonalMatrix{elt}
36+
37+
# TODO: Make generic to GPU, use `allocate_randn`?
38+
a2 = randn(elt, (3, 4))
39+
a_dest = a1 * a2
40+
# TODO: Use `densearray` to make generic to GPU.
41+
@test Array(a_dest) Array(a1) * Array(a2)
42+
@test nstored(a_dest) == 8
43+
@test a_dest isa Matrix{elt}
44+
45+
a2 = SparseArrayDOK{elt}(3, 4)
46+
a2[1, 1] = 11
47+
a2[2, 2] = 22
48+
a2[3, 3] = 33
49+
a_dest = a1 * a2
50+
# TODO: Use `densearray` to make generic to GPU.
51+
@test Array(a_dest) Array(a1) * Array(a2)
52+
# TODO: Define `SparseMatrixDOK`.
53+
# TODO: Make this work with `ArrayLayouts`.
54+
@test nstored(a_dest) == 2
55+
@test a_dest isa SparseArrayDOK{elt,2}
56+
end
1857
end
1958
end
2059
end

0 commit comments

Comments
 (0)