@@ -180,26 +180,35 @@ LinearAlgebra.lmul!(a::Number, B::AbstractGPUArray) = generic_lmul!(a, B)
180180
181181
182182# # permutedims
183+ LinearAlgebra. permutedims! (dest:: AbstractGPUArray , src:: AbstractGPUArray , perm) =
184+ permutedims! (dest, src, Tuple (perm))
183185
184186function LinearAlgebra. permutedims! (dest:: AbstractGPUArray , src:: AbstractGPUArray ,
185- perm:: NTuple )
187+ perm:: NTuple{N} ) where N
186188 Base. checkdims_perm (dest, src, perm)
187- function permutedims_kernel (ctx, dest, src, :: Val{perm} ) where {perm}
188- I = @cartesianidx src
189- @inbounds begin
190- J = CartesianIndex (map (i-> I[i], perm))
191- dest[J] = src[I]
192- end
189+
190+ # get the new strides of destination tensor
191+ dest_strides = ntuple (k-> k== 1 ? 1 : prod (i-> size (dest, i), 1 : k- 1 ), N)
192+ dest_strides_perm = ntuple (i-> dest_strides[findfirst (== (i), perm)], N)
193+
194+ function permutedims_kernel (ctx, dest, src, dest_strides_perm)
195+ # find the cartesian index in source tensor
196+ LI = @linearidx src
197+ I = @inbounds CartesianIndices (src)[LI]
198+
199+ # the corresponding linear index in the destination tensor
200+ dest_index = map_index (I. I, dest_strides_perm)
201+ @inbounds dest[dest_index] = src[LI]
193202 return
194203 end
195- gpu_call (permutedims_kernel, dest, src, Val (perm) )
204+ gpu_call (permutedims_kernel, dest, src, dest_strides_perm )
196205 return dest
197206end
198207
199- # TODO : implementation without the memory copy
200- LinearAlgebra . permutedims! (dest :: AbstractGPUArray , src :: AbstractGPUArray , perm) =
201- permutedims! (dest, src, Tuple (perm) )
202-
208+ # get linear index from cartesian indices and strides.
209+ @inline @generated function map_index (I :: NTuple{N} , dest_strides :: NTuple{N,T} ) where {N,T}
210+ Expr ( :call , : + , one (T), [:( @inbounds (I[ $ i] - 1 ) * dest_strides[ $ i]) for i in 1 : N] . .. )
211+ end
203212
204213# # norm
205214
0 commit comments