
Commit 46e8f03

Port reverse from CUDA
1 parent 9d9f432 commit 46e8f03

2 files changed: +200 −0 lines changed

src/host/reverse.jl

Lines changed: 155 additions & 0 deletions
@@ -0,0 +1,155 @@
# reversing

# the kernel works by treating the array as 1d. after reversing by dimension x an element at
# pos [i1, i2, i3, ... , i{x}, ..., i{n}] will be at
# pos [i1, i2, i3, ... , d{x} - i{x} + 1, ..., i{n}] where d{x} is the size of dimension x

# out-of-place version, copying a single value per thread from input to output
function _reverse(input::AnyGPUArray{T, N}, output::AnyGPUArray{T, N};
                  dims=1:ndims(input)) where {T, N}
    @assert size(input) == size(output)
    rev_dims = ntuple((d)-> d in dims && size(input, d) > 1, N)
    ref = size(input) .+ 1
    # converts an ND-index in the data array to the linear index
    lin_idx = LinearIndices(input)
    # converts a linear index in the data array to an ND-index
    nd_idx = CartesianIndices(input)

    ## COV_EXCL_START
    @kernel cpu=false unsafe_indices=true function kernel(input::AbstractArray{T, N}, output::AbstractArray{T, N}) where {T, N}
        offset_in = @groupsize()[1] * (@index(Group, Linear) - Int32(1))
        index_in = offset_in + @index(Local, Linear)

        @inbounds if index_in <= length(input)
            idx = Tuple(nd_idx[index_in])
            idx = ifelse.(rev_dims, ref .- idx, idx)
            index_out = lin_idx[idx...]
            output[index_out] = input[index_in]
        end

        return
    end
    ## COV_EXCL_STOP

    nthreads = 256
    nblocks = cld(length(input), nthreads)

    kernel(get_backend(input), nthreads)(input, output; ndrange=nthreads*nblocks)
end

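For intuition, the mapping described in the file-header comment can be reproduced on the CPU. This is an illustration only, not part of the committed file; the names are arbitrary:

A = reshape(collect(1:12), 3, 4)
dims = (1,)
rev_dims = ntuple(d -> d in dims && size(A, d) > 1, ndims(A))
ref = size(A) .+ 1            # d{x} + 1, so ref[x] - i{x} == d{x} - i{x} + 1
B = similar(A)
for I in CartesianIndices(A)
    idx = Tuple(I)
    idx = ifelse.(rev_dims, ref .- idx, idx)   # mirror only the reversed dimensions
    B[idx...] = A[I]
end
B == reverse(A; dims=1)       # true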
# in-place version, swapping elements on half the number of threads
function _reverse!(data::AnyGPUArray{T, N}; dims=1:ndims(data)) where {T, N}
    rev_dims = ntuple((d)-> d in dims && size(data, d) > 1, N)
    half_dim = findlast(rev_dims)
    if isnothing(half_dim)
        # no reverse operation needed at all in this case.
        return
    end
    ref = size(data) .+ 1
    # converts an ND-index in the data array to the linear index
    lin_idx = LinearIndices(data)
    reduced_size = ntuple((d)->ifelse(d==half_dim, cld(size(data,d),2), size(data,d)), N)
    reduced_length = prod(reduced_size)
    # converts a linear index in a reduced array to an ND-index, but using the reduced size
    nd_idx = CartesianIndices(reduced_size)

    ## COV_EXCL_START
    @kernel cpu=false unsafe_indices=true function kernel(data::AbstractArray{T, N}) where {T, N}
        offset_in = @groupsize()[1] * (@index(Group, Linear) - Int32(1))

        index_in = offset_in + @index(Local, Linear)

        @inbounds if index_in <= reduced_length
            idx = Tuple(nd_idx[index_in])
            index_in = lin_idx[idx...]
            idx = ifelse.(rev_dims, ref .- idx, idx)
            index_out = lin_idx[idx...]

            if index_in < index_out
                temp = data[index_out]
                data[index_out] = data[index_in]
                data[index_in] = temp
            end
        end

        return
    end
    ## COV_EXCL_STOP

    # NOTE: we launch slightly more than half the number of elements in the array as threads.
    # The last non-singleton dimension along which to reverse defines how the array is split.
    # Only the middle row of an odd-sized dimension could cause trouble, but this is prevented
    # by ignoring threads that cross the mid-point (the index_in < index_out check above).

    nthreads = 256
    nblocks = cld(prod(reduced_size), nthreads)

    kernel(get_backend(data), nthreads)(data; ndrange=nthreads*nblocks)
end

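The NOTE above can be checked with a small CPU model of the swap logic (illustration only, not part of the committed file): only indices that lie strictly below their mirrored counterpart swap, so the middle element of an odd-sized dimension is skipped and no pair is swapped twice.

data = collect(1:5)
ref = length(data) + 1
for i in 1:cld(length(data), 2)      # the "reduced" half: 3 indices for 5 elements
    j = ref - i                      # mirrored index
    if i < j                         # skip the mid-point (i == j) and avoid double swaps
        data[i], data[j] = data[j], data[i]
    end
end
data == [5, 4, 3, 2, 1]              # true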

# n-dimensional API

function Base.reverse!(data::AnyGPUArray{T, N}; dims=:) where {T, N}
    if isa(dims, Colon)
        dims = 1:ndims(data)
    end
    if !applicable(iterate, dims)
        throw(ArgumentError("dimension $dims is not an iterable"))
    end
    if !all(1 .≤ dims .≤ ndims(data))
        throw(ArgumentError("dimension $dims is not 1 ≤ $dims ≤ $(ndims(data))"))
    end

    _reverse!(data; dims=dims)

    return data
end

# out-of-place
function Base.reverse(input::AnyGPUArray{T, N}; dims=:) where {T, N}
    if isa(dims, Colon)
        dims = 1:ndims(input)
    end
    if !applicable(iterate, dims)
        throw(ArgumentError("dimension $dims is not an iterable"))
    end
    if !all(1 .≤ dims .≤ ndims(input))
        throw(ArgumentError("dimension $dims is not 1 ≤ $dims ≤ $(ndims(input))"))
    end

    if all(size(input)[[dims...]] .== 1)
        # no reverse operation needed at all in this case.
        return copy(input)
    else
        output = similar(input)
        _reverse(input, output; dims=dims)
        return output
    end
end

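Typical use of these methods mirrors Base; a quick sketch, where AT stands for whichever GPUArrays-backed array type is under test (as in the testsuite below):

A  = rand(Float32, 4, 3)
dA = AT(A)
Array(reverse(dA; dims=2)) == reverse(A; dims=2)             # single dimension
Array(reverse(dA; dims=(1, 2))) == reverse(A; dims=(1, 2))   # multiple dimensions
reverse!(dA; dims=1)                                         # in-place, returns dA
Array(dA) == reverse(A; dims=1)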

# 1-dimensional API

# in-place
Base.@propagate_inbounds function Base.reverse!(data::AnyGPUArray{T}, start::Integer,
                                                stop::Integer=length(data)) where {T}
    _reverse!(view(data, start:stop))
    return data
end

Base.reverse!(data::AnyGPUArray{T}) where {T} = @inbounds reverse!(data, 1, length(data))

# out-of-place
Base.@propagate_inbounds function Base.reverse(input::AnyGPUArray{T}, start::Integer,
                                               stop::Integer=length(input)) where {T}
    output = similar(input)

    start > 1 && copyto!(output, 1, input, 1, start-1)
    _reverse(view(input, start:stop), view(output, start:stop))
    stop < length(input) && copyto!(output, stop+1, input, stop+1)

    return output
end

Base.reverse(data::AnyGPUArray{T}) where {T} = @inbounds reverse(data, 1, length(data))
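The start/stop methods only reverse data[start:stop], copying the prefix and suffix through unchanged; the CPU equivalent (illustration only):

x = [1, 2, 3, 4, 5, 6]
reverse(x, 2, 5) == [1, 5, 4, 3, 2, 6]   # only elements 2:5 are reversed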

test/testsuite/base.jl

Lines changed: 45 additions & 0 deletions
@@ -381,6 +381,51 @@ end
     gA = reshape(AT(A),4)
 end

+@testset "reverse" begin
+    # 1-d out-of-place
+    @test compare(x->reverse(x), AT, rand(Float32, 1000))
+    @test compare(x->reverse(x, 10), AT, rand(Float32, 1000))
+    @test compare(x->reverse(x, 10, 90), AT, rand(Float32, 1000))
+
+    # 1-d in-place
+    @test compare(x->reverse!(x), AT, rand(Float32, 1000))
+    @test compare(x->reverse!(x, 10), AT, rand(Float32, 1000))
+    @test compare(x->reverse!(x, 10, 90), AT, rand(Float32, 1000))
+
+    # n-d out-of-place
+    for shape in ([1, 2, 4, 3], [4, 2], [5], [2^5, 2^5, 2^5]),
+        dim in 1:length(shape)
+        @test compare(x->reverse(x; dims=dim), AT, rand(Float32, shape...))
+
+        cpu = rand(Float32, shape...)
+        gpu = AT(cpu)
+        reverse!(gpu; dims=dim)
+        @test Array(gpu) == reverse(cpu; dims=dim)
+    end
+
+    # supports multidimensional reverse
+    for shape in ([1, 2, 4, 3], [2^5, 2^5, 2^5]),
+        dim in ((1,2),(2,3),(1,3),:)
+        @test compare(x->reverse(x; dims=dim), AT, rand(Float32, shape...))
+
+        cpu = rand(Float32, shape...)
+        gpu = AT(cpu)
+        reverse!(gpu; dims=dim)
+        @test Array(gpu) == reverse(cpu; dims=dim)
+    end
+
+    # wrapped array
+    @test compare(x->reverse(x), AT, reshape(rand(Float32, 2,2), 4))
+
+    # error throwing
+    cpu = rand(Float32, 1,2,3,4)
+    gpu = AT(cpu)
+    @test_throws ArgumentError reverse!(gpu, dims=5)
+    @test_throws ArgumentError reverse!(gpu, dims=0)
+    @test_throws ArgumentError reverse(gpu, dims=5)
+    @test_throws ArgumentError reverse(gpu, dims=0)
+end
+
 @testset "reinterpret" begin
     A = Int32[-1,-2,-3]
     dA = AT(A)
