@@ -7,6 +7,23 @@ function GPUArrays.Adapt.adapt_structure(to, x::ComponentArray)
77 return ComponentArray (data, getaxes (x))
88end
99
10+ GPUArrays. Adapt. adapt_storage (:: Type{ComponentArray{T,N,A,Ax}} , xs:: AT ) where {T,N,A,Ax,AT<: AbstractArray } =
11+ GPUArrays. Adapt. adapt_storage (A, xs)
12+
13+ function Base. fill! (A:: GPUComponentArray{T} , x) where {T}
14+ length (A) == 0 && return A
15+ GPUArrays. gpu_call (A, convert (T, x)) do ctx, a, val
16+ idx = GPUArrays. @linearidx (a)
17+ @inbounds a[idx] = val
18+ return
19+ end
20+ A
21+ end
22+
23+ LinearAlgebra. dot (x:: GPUComponentArray , y:: GPUComponentArray ) = dot (getdata (x), getdata (y))
24+ LinearAlgebra. norm (ca:: GPUComponentArray , p:: Real ) = norm (getdata (ca), p)
25+ LinearAlgebra. rmul! (ca:: GPUComponentArray , b:: Number ) = GPUArrays. generic_rmul! (ca, b)
26+
1027function Base. map (f, x:: GPUComponentArray , args... )
1128 data = map (f, getdata (x), getdata .(args)... )
1229 return ComponentArray (data, getaxes (x))
@@ -46,4 +63,4 @@ for (fname, op) in [(:sum, :(Base.add_sum)), (:prod, :(Base.mul_prod)),
4663 Base.$ (fname!)(f:: Function , r:: GPUComponentArray , A:: GPUComponentArray{T} ) where T =
4764 GPUArrays. mapreducedim! (f, $ (op), getdata (r), getdata (A); init= neutral_element ($ (op), T))
4865 end
49- end
66+ end
0 commit comments