Skip to content

Commit ee2cfa1

Browse files
committed
Finally?
1 parent 3d98fea commit ee2cfa1

File tree

3 files changed

+8
-6
lines changed

3 files changed

+8
-6
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ version = "11.2.3"
66
AcceleratedKernels = "6a4ca0a5-0e36-4168-a932-d9be78d558f1"
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
88
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
9+
GPUToolbox = "096a3bc2-3ced-46d0-87f4-dd12716f4bfc"
910
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
1011
LLVM = "929cbde3-209d-540e-8aea-75f648917ca0"
1112
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
@@ -26,6 +27,7 @@ JLD2Ext = "JLD2"
2627
AcceleratedKernels = "0.4"
2728
Adapt = "4.0"
2829
GPUArraysCore = "= 0.2.0"
30+
GPUToolbox = "0.2, 0.3"
2931
JLD2 = "0.4, 0.5"
3032
KernelAbstractions = "0.9.28"
3133
LLVM = "3.9, 4, 5, 6, 7, 8, 9"

src/GPUArrays.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
module GPUArrays
22

3+
using GPUToolbox
34
using KernelAbstractions
45
using Serialization
56
using Random

src/host/reverse.jl

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ function _reverse(input::AnyGPUArray{T, N}, output::AnyGPUArray{T, N};
1616
nd_idx = CartesianIndices(input)
1717

1818
## COV_EXCL_START
19-
@kernel unsafe_indices=true function kernel(input::AbstractArray{T, N}, output::AbstractArray{T, N}) where {T, N}
19+
@kernel unsafe_indices=true function kernel(input, output)
2020
offset_in = @groupsize()[1] * (@index(Group, Linear) - 1i32)
2121
index_in = offset_in + @index(Local, Linear)
2222

@@ -32,7 +32,7 @@ function _reverse(input::AnyGPUArray{T, N}, output::AnyGPUArray{T, N};
3232
nthreads = 256
3333
nblocks = cld(length(input), nthreads)
3434

35-
kernel(get_backend(input), nblocks)(input, output; ndrange=length(nblocks))
35+
kernel(get_backend(input), nblocks)(input, output; ndrange=length(input))
3636
end
3737

3838
# in-place version, swapping elements on half the number of threads
@@ -52,10 +52,9 @@ function _reverse!(data::AnyGPUArray{T, N}; dims=1:ndims(data)) where {T, N}
5252
nd_idx = CartesianIndices(reduced_size)
5353

5454
## COV_EXCL_START
55-
@kernel unsafe_indices=true function kernel(data::AbstractArray{T, N}) where {T, N}
55+
@kernel unsafe_indices=true function kernel(data)
5656
offset_in = @groupsize()[1] * (@index(Group, Linear) - 1i32)
57-
58-
index_in = offset_in + threadIdx().x
57+
index_in = offset_in + @index(Local, Linear)
5958

6059
@inbounds if index_in <= reduced_length
6160
idx = Tuple(nd_idx[index_in])
@@ -80,7 +79,7 @@ function _reverse!(data::AnyGPUArray{T, N}; dims=1:ndims(data)) where {T, N}
8079
nthreads = 256
8180
nblocks = cld(prod(reduced_size), nthreads)
8281

83-
kernel(get_backend(input), nblocks)(input, output; ndrange=length(nblocks))
82+
kernel(get_backend(data), nblocks)(data; ndrange=length(data))
8483
end
8584

8685

0 commit comments

Comments
 (0)