Skip to content

Commit a3ca7c1

Browse files
committed
Make GPU ComponentArray boradcasting work. Fixes #155
1 parent 692e4f8 commit a3ca7c1

File tree

5 files changed

+79
-0
lines changed

5 files changed

+79
-0
lines changed

src/ComponentArrays.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ function __init__()
5959
@require RecursiveArrayTools="731186ca-8d62-57ce-b412-fbd966d074cd" required("recursivearraytools.jl")
6060
@require StaticArrays="90137ffa-7385-5640-81b9-e52037218182" required("staticarrays.jl")
6161
@require ReverseDiff="37e2e3b7-166d-5795-8a7a-e32c996b4267" required("reversediff.jl")
62+
@require GPUArrays="0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" required("gpuarrays.jl")
6263
end
6364

6465
end

src/compat/gpuarrays.jl

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
const GPUComponentArray = ComponentArray{T,N,<:GPUArrays.AbstractGPUArray,Ax} where {T,N,Ax}
2+
3+
GPUArrays.backend(x::ComponentArray) = GPUArrays.backend(getdata(x))
4+
5+
function GPUArrays.Adapt.adapt_structure(to, x::ComponentArray)
6+
data = GPUArrays.Adapt.adapt_structure(to, getdata(x))
7+
return ComponentArray(data, getaxes(x))
8+
end
9+
10+
function Base.map(f, x::GPUComponentArray, args...)
11+
data = map(f, getdata(x), getdata.(args)...)
12+
return ComponentArray(data, getaxes(x))
13+
end
14+
function Base.map(f, x::GPUComponentArray, args::Vararg{Union{Base.AbstractBroadcasted, AbstractArray}})
15+
data = map(f, getdata(x), map(getdata, args)...)
16+
return ComponentArray(data, getaxes(x))
17+
end
18+
19+
# We need all of these to avoid method ambiguities
20+
function Base.mapreduce(f, op, x::GPUComponentArray; kwargs...)
21+
return mapreduce(f, op, getdata(x); kwargs...)
22+
end
23+
function Base.mapreduce(f, op, x::GPUComponentArray, args...; kwargs...)
24+
return mapreduce(f, op, getdata(x), map(getdata, args)...; kwargs...)
25+
end
26+
function Base.mapreduce(f, op, x::GPUComponentArray, args::Vararg{Union{Base.AbstractBroadcasted, AbstractArray}}; kwargs...)
27+
return mapreduce(f, op, getdata(x), map(getdata, args)...; kwargs...)
28+
end
29+
30+
# These are all stolen from GPUArrays.j;
31+
Base.any(A::GPUComponentArray{Bool}) = mapreduce(identity, |, getdata(A))
32+
Base.all(A::GPUComponentArray{Bool}) = mapreduce(identity, &, getdata(A))
33+
34+
Base.any(f::Function, A::GPUComponentArray) = mapreduce(f, |, getdata(A))
35+
Base.all(f::Function, A::GPUComponentArray) = mapreduce(f, &, getdata(A))
36+
37+
Base.count(pred::Function, A::GPUComponentArray; dims=:, init=0) =
38+
mapreduce(pred, Base.add_sum, getdata(A); init=init, dims=dims)
39+
40+
# avoid calling into `initarray!`
41+
for (fname, op) in [(:sum, :(Base.add_sum)), (:prod, :(Base.mul_prod)),
42+
(:maximum, :(Base.max)), (:minimum, :(Base.min)),
43+
(:all, :&), (:any, :|)]
44+
fname! = Symbol(fname, '!')
45+
@eval begin
46+
Base.$(fname!)(f::Function, r::GPUComponentArray, A::GPUComponentArray{T}) where T =
47+
GPUArrays.mapreducedim!(f, $(op), getdata(r), getdata(A); init=neutral_element($(op), T))
48+
end
49+
end

test/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
44
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
55
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
66
InvertedIndices = "41ab1584-1d38-5bbf-9106-f11c6c58b48f"
7+
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
78
LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800"
89
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
910
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"

test/gpu_tests.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
using JLArrays
2+
3+
JLArrays.allowscalar(false)
4+
5+
jla = jl(collect(1:4))
6+
jlca = ComponentArray(jla, Axis(a=1:2, b=3:4))
7+
8+
@testset "Broadcasting" begin
9+
@test identity.(jlca + jla) ./ 2 == jlca
10+
11+
@test getdata(map(identity, jlca)) isa JLArray
12+
@test all(==(0), map(-, jlca, jla))
13+
@test all(map(-, jlca, jlca) .== 0)
14+
@test all(==(0), map(-, jla, jlca))
15+
16+
@test any(==(1), jlca)
17+
@test count(>(2), jlca) == 2
18+
19+
# Make sure mapreducing multiple arrays works
20+
@test mapreduce(==, +, jlca, jla) == 4
21+
@test mapreduce(abs2, +, jlca) == 30
22+
23+
@test all(map(sin, jlca) .== sin.(jlca) .== sin.(jla) .≈ sin.(1:4))
24+
end

test/runtests.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -608,4 +608,8 @@ end
608608

609609
@testset "Autodiff" begin
610610
include("autodiff_tests.jl")
611+
end
612+
613+
@testset "GPU" begin
614+
include("gpu_tests.jl")
611615
end

0 commit comments

Comments
 (0)