6161
6262# # vectorized indexing
6363
64- function vectorized_getindex (src:: AbstractGPUArray , Is... )
65- shape = Base. index_shape (Is... )
66- dest = similar (src, shape)
64+ function vectorized_getindex! (dest:: AbstractGPUArray , src:: AbstractArray , Is... )
6765 any (isempty, Is) && return dest # indexing with empty array
6866 idims = map (length, Is)
6967
7068 # NOTE: we are pretty liberal here supporting non-GPU indices...
71- Is = map (x -> adapt (ToGPU (src), x ), Is)
69+ Is = map (adapt (ToGPU (dest) ), Is)
7270 @boundscheck checkbounds (src, Is... )
7371
7472 gpu_call (getindex_kernel, dest, src, idims, Is... )
7573 return dest
7674end
7775
76+ function vectorized_getindex (src:: AbstractGPUArray , Is... )
77+ shape = Base. index_shape (Is... )
78+ dest = similar (src, shape)
79+ return vectorized_getindex! (dest, src, Is... )
80+ end
81+
7882@generated function getindex_kernel (ctx:: AbstractKernelContext , dest, src, idims,
7983 Is:: Vararg{Any,N} ) where {N}
8084 quote
8791 end
8892end
8993
90- function vectorized_setindex! (dest:: AbstractGPUArray , src, Is... )
94+ function vectorized_setindex! (dest:: AbstractArray , src, Is... )
9195 isempty (Is) && return dest
9296 idims = length .(Is)
9397 len = prod (idims)
@@ -101,7 +105,7 @@ function vectorized_setindex!(dest::AbstractGPUArray, src, Is...)
101105 end
102106
103107 # NOTE: we are pretty liberal here supporting non-GPU indices...
104- Is = map (x -> adapt (ToGPU (dest), x ), Is)
108+ Is = map (adapt (ToGPU (dest)), Is)
105109 @boundscheck checkbounds (dest, Is... )
106110
107111 gpu_call (setindex_kernel, dest, adapt (ToGPU (dest), src), idims, len, Is... ;
144148 end )
145149end
146150
151+ # # Vectorized index overloading for `WrappedGPUArray`
152+ # We'd better not to overload `getindex`/`setindex!` directly as otherwise
153+ # the ambiguities from the default scalar fallback become a mess.
154+ # The default `getindex` for `AbstractArray` follows a `similar`-`copyto!` style.
155+ # Thus we only dispatch the `copyto!` part (`Base._unsafe_getindex!`) to our implement.
156+ function Base. _unsafe_getindex! (dest:: AbstractGPUArray , src:: AbstractArray , Is:: Vararg{Union{Real, AbstractArray}, N} ) where {N}
157+ return vectorized_getindex! (dest, src, Base. ensure_indexable (Is)... )
158+ end
159+ # Similar for `setindex!`, its default fallback is equivalent to `copyto!`.
160+ # We only dispatch the `copyto!` part (`Base._unsafe_setindex!`) to our implement.
161+ function Base. _unsafe_setindex! (:: IndexStyle , A:: WrappedGPUArray , x, Is:: Vararg{Union{Real,AbstractArray}, N} ) where N
162+ return vectorized_setindex! (A, x, Base. ensure_indexable (Is)... )
163+ end
164+ # And allow one more `ReshapedArray` wrapper to handle the `_maybe_reshape` optimization.
165+ function Base. _unsafe_setindex! (:: IndexStyle , A:: Base.ReshapedArray{<:Any, <:Any, <:WrappedGPUArray} , x, Is:: Vararg{Union{Real,AbstractArray}, N} ) where N
166+ return vectorized_setindex! (A, x, Base. ensure_indexable (Is)... )
167+ end
147168
148169# find*
149170
0 commit comments