|
1 | | -using Base.Broadcast: BroadcastStyle, AbstractArrayStyle, DefaultArrayStyle, Broadcasted |
| 1 | +using Base.Broadcast: |
| 2 | + Broadcast, BroadcastStyle, AbstractArrayStyle, DefaultArrayStyle, Broadcasted |
2 | 3 | using GPUArraysCore: @allowscalar |
3 | 4 | using MapBroadcast: Mapped |
4 | 5 | using DerivableInterfaces: DerivableInterfaces, @interface |
5 | 6 |
|
6 | | -abstract type AbstractBlockSparseArrayStyle{N} <: AbstractArrayStyle{N} end |
| 7 | +abstract type AbstractBlockSparseArrayStyle{N,B<:AbstractArrayStyle{N}} <: |
| 8 | + AbstractArrayStyle{N} end |
7 | 9 |
|
8 | | -function DerivableInterfaces.interface(::Type{<:AbstractBlockSparseArrayStyle}) |
9 | | - return BlockSparseArrayInterface() |
| 10 | +blockstyle(::AbstractBlockSparseArrayStyle{N,B}) where {N,B<:AbstractArrayStyle{N}} = B() |
| 11 | + |
| 12 | +function Broadcast.BroadcastStyle( |
| 13 | + style1::AbstractBlockSparseArrayStyle, style2::AbstractBlockSparseArrayStyle |
| 14 | +) |
| 15 | + style = Broadcast.result_style(blockstyle(style1), blockstyle(style2)) |
| 16 | + return BlockSparseArrayStyle(style) |
10 | 17 | end |
11 | 18 |
|
12 | | -struct BlockSparseArrayStyle{N} <: AbstractBlockSparseArrayStyle{N} end |
| 19 | +function DerivableInterfaces.interface( |
| 20 | + ::Type{<:AbstractBlockSparseArrayStyle{N,B}} |
| 21 | +) where {N,B<:AbstractArrayStyle{N}} |
| 22 | + return BlockSparseArrayInterface(interface(B)) |
| 23 | +end |
13 | 24 |
|
14 | | -# Define for new sparse array types. |
15 | | -# function Broadcast.BroadcastStyle(arraytype::Type{<:MyBlockSparseArray}) |
16 | | -# return BlockSparseArrayStyle{ndims(arraytype)}() |
17 | | -# end |
| 25 | +struct BlockSparseArrayStyle{N,B<:AbstractArrayStyle{N}} <: |
| 26 | + AbstractBlockSparseArrayStyle{N,B} |
| 27 | + blockstyle::B |
| 28 | +end |
| 29 | +function BlockSparseArrayStyle{N}(blockstyle::AbstractArrayStyle{N}) where {N} |
| 30 | + return BlockSparseArrayStyle{N,typeof(blockstyle)}(blockstyle) |
| 31 | +end |
18 | 32 |
|
| 33 | +function BlockSparseArrayStyle{N,B}() where {N,B<:AbstractArrayStyle{N}} |
| 34 | + return BlockSparseArrayStyle{N,B}(B()) |
| 35 | +end |
| 36 | +BlockSparseArrayStyle{N}() where {N} = BlockSparseArrayStyle{N}(DefaultArrayStyle{N}()) |
19 | 37 | BlockSparseArrayStyle(::Val{N}) where {N} = BlockSparseArrayStyle{N}() |
20 | 38 | BlockSparseArrayStyle{M}(::Val{N}) where {M,N} = BlockSparseArrayStyle{N}() |
| 39 | +function BlockSparseArrayStyle{M,B}(::Val{N}) where {M,B<:AbstractArrayStyle{M},N} |
| 40 | + return BlockSparseArrayStyle{N}(B(Val(N))) |
| 41 | +end |
21 | 42 |
|
22 | 43 | Broadcast.BroadcastStyle(a::BlockSparseArrayStyle, ::DefaultArrayStyle{0}) = a |
23 | 44 | function Broadcast.BroadcastStyle( |
|
0 commit comments