1+ const GPUComponentArray = ComponentArray{T,N,<: GPUArrays.AbstractGPUArray ,Ax} where {T,N,Ax}
2+
3+ GPUArrays. backend (x:: ComponentArray ) = GPUArrays. backend (getdata (x))
4+
5+ function GPUArrays. Adapt. adapt_structure (to, x:: ComponentArray )
6+ data = GPUArrays. Adapt. adapt_structure (to, getdata (x))
7+ return ComponentArray (data, getaxes (x))
8+ end
9+
10+ function Base. map (f, x:: GPUComponentArray , args... )
11+ data = map (f, getdata (x), getdata .(args)... )
12+ return ComponentArray (data, getaxes (x))
13+ end
14+ function Base. map (f, x:: GPUComponentArray , args:: Vararg{Union{Base.AbstractBroadcasted, AbstractArray}} )
15+ data = map (f, getdata (x), map (getdata, args)... )
16+ return ComponentArray (data, getaxes (x))
17+ end
18+
19+ # We need all of these to avoid method ambiguities
20+ function Base. mapreduce (f, op, x:: GPUComponentArray ; kwargs... )
21+ return mapreduce (f, op, getdata (x); kwargs... )
22+ end
23+ function Base. mapreduce (f, op, x:: GPUComponentArray , args... ; kwargs... )
24+ return mapreduce (f, op, getdata (x), map (getdata, args)... ; kwargs... )
25+ end
26+ function Base. mapreduce (f, op, x:: GPUComponentArray , args:: Vararg{Union{Base.AbstractBroadcasted, AbstractArray}} ; kwargs... )
27+ return mapreduce (f, op, getdata (x), map (getdata, args)... ; kwargs... )
28+ end
29+
30+ # These are all stolen from GPUArrays.j;
31+ Base. any (A:: GPUComponentArray{Bool} ) = mapreduce (identity, | , getdata (A))
32+ Base. all (A:: GPUComponentArray{Bool} ) = mapreduce (identity, & , getdata (A))
33+
34+ Base. any (f:: Function , A:: GPUComponentArray ) = mapreduce (f, | , getdata (A))
35+ Base. all (f:: Function , A:: GPUComponentArray ) = mapreduce (f, & , getdata (A))
36+
37+ Base. count (pred:: Function , A:: GPUComponentArray ; dims= :, init= 0 ) =
38+ mapreduce (pred, Base. add_sum, getdata (A); init= init, dims= dims)
39+
40+ # avoid calling into `initarray!`
41+ for (fname, op) in [(:sum , :(Base. add_sum)), (:prod , :(Base. mul_prod)),
42+ (:maximum , :(Base. max)), (:minimum , :(Base. min)),
43+ (:all , :& ), (:any , :| )]
44+ fname! = Symbol (fname, ' !' )
45+ @eval begin
46+ Base.$ (fname!)(f:: Function , r:: GPUComponentArray , A:: GPUComponentArray{T} ) where T =
47+ GPUArrays. mapreducedim! (f, $ (op), getdata (r), getdata (A); init= neutral_element ($ (op), T))
48+ end
49+ end
0 commit comments