@@ -193,16 +193,21 @@ LinearAlgebra.lmul!(a::Number, B::AbstractGPUArray) = generic_lmul!(a, B)
193193
194194# # permutedims
195195
196- function genperm (I:: CartesianIndex{N} , perm:: NTuple{N} ) where N
197- CartesianIndex (ntuple (d-> (@inbounds return I[perm[d]]), Val (N)))
198- end
199-
200- function LinearAlgebra. permutedims! (dest:: AbstractGPUArray , src:: AbstractGPUArray , perm) where N
201- perm isa Tuple || (perm = Tuple (perm))
202- gpu_call (dest, src, perm; name= " permutedims!" ) do ctx, dest, src, perm
196+ function LinearAlgebra. permutedims! (dest:: AbstractGPUArray , src:: AbstractGPUArray ,
197+ perm:: NTuple )
198+ Base. checkdims_perm (dest, src, perm)
199+ function permutedims_kernel (ctx, dest, src, :: Val{perm} ) where {perm}
203200 I = @cartesianidx src
204- @inbounds dest[genperm (I, perm)] = src[I]
201+ @inbounds begin
202+ J = CartesianIndex (map (i-> I[i], perm))
203+ dest[J] = src[I]
204+ end
205205 return
206206 end
207+ gpu_call (permutedims_kernel, dest, src, Val (perm))
207208 return dest
208209end
210+
211+ # TODO : implementation without the memory copy
212+ LinearAlgebra. permutedims! (dest:: AbstractGPUArray , src:: AbstractGPUArray , perm) =
213+ permutedims! (dest, src, Tuple (perm))
0 commit comments