Skip to content

Commit 77fc0ec

Browse files
committed
Adjust code so that push! etc. don't trigger scalar indexing. Add tests for CuArrays.
1 parent 76be2e8 commit 77fc0ec

File tree

3 files changed

+196
-7
lines changed

3 files changed

+196
-7
lines changed

Project.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ julia = "1"
1111

1212
[extras]
1313
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
14+
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
1415

1516
[targets]
16-
test = ["Test"]
17+
test = ["CUDA", "Test"]

src/CircularArrayBuffers.jl

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ using Adapt
55
export CircularArrayBuffer, CircularVectorBuffer, capacity, isfull
66

77
"""
8-
CircularArrayBuffer{T}(sz::Integer...) -> CircularArrayBuffer{T, N}
8+
CircularArrayBuffer{T}(sz::Integer...) -> CircularArrayBuffer{T, N, Array{T, N}}
99
1010
`CircularArrayBuffer` uses a `N`-dimension `Array` of size `sz` to serve as a buffer for
1111
`N-1`-dimension `Array`s of the same size.
@@ -33,12 +33,22 @@ end
3333
Adapt.adapt_structure(to, cb::CircularArrayBuffer) =
3434
CircularArrayBuffer(adapt(to, cb.buffer), cb.first, cb.nframes, cb.step_size)
3535

36+
function Base.show(io::IO, ::MIME"text/plain", cb::CircularArrayBuffer{T}) where T
37+
print(io, ndims(cb) == 1 ? "CircularVectorBuffer(" : "CircularArrayBuffer(")
38+
Base.showarg(io, cb.buffer, false)
39+
print(io, ") with eltype $T:\n")
40+
Base.print_array(io, adapt(Array, cb))
41+
return nothing
42+
end
43+
3644
Base.IndexStyle(::CircularArrayBuffer) = IndexLinear()
3745

3846
Base.size(cb::CircularArrayBuffer{T,N}, i::Integer) where {T,N} = i == N ? cb.nframes : size(cb.buffer, i)
3947
Base.size(cb::CircularArrayBuffer{T,N}) where {T,N} = ntuple(i -> size(cb, i), N)
4048
Base.getindex(cb::CircularArrayBuffer{T,N}, i::Int) where {T,N} = getindex(cb.buffer, _buffer_index(cb, i))
49+
Base.getindex(cb::CircularArrayBuffer{T,N}, I...) where {T,N} = getindex(cb.buffer, Base.front(I)..., _buffer_frame(cb, Base.last(I)))
4150
Base.setindex!(cb::CircularArrayBuffer{T,N}, v, i::Int) where {T,N} = setindex!(cb.buffer, v, _buffer_index(cb, i))
51+
Base.setindex!(cb::CircularArrayBuffer{T,N}, v, I...) where {T,N} = setindex!(cb.buffer, v, Base.front(I)..., _buffer_frame(cb, Base.last(I)))
4252

4353
capacity(cb::CircularArrayBuffer{T,N}) where {T,N} = size(cb.buffer, N)
4454
isfull(cb::CircularArrayBuffer) = cb.nframes == capacity(cb)
@@ -52,6 +62,7 @@ Base.isempty(cb::CircularArrayBuffer) = cb.nframes == 0
5262
ind
5363
end
5464
end
65+
@inline _buffer_index(cb::CircularArrayBuffer, I::AbstractVector{<:Integer}) = map(Base.Fix1(_buffer_index, cb), I)
5566

5667
@inline function _buffer_frame(cb::CircularArrayBuffer, i::Int)
5768
n = capacity(cb)
@@ -63,7 +74,7 @@ end
6374
end
6475
end
6576

66-
_buffer_frame(cb::CircularArrayBuffer, I::Vector{Int}) = map(i -> _buffer_frame(cb, i), I)
77+
_buffer_frame(cb::CircularArrayBuffer, I::AbstractVector{<:Integer}) = map(i -> _buffer_frame(cb, i), I)
6778

6879
function Base.empty!(cb::CircularArrayBuffer)
6980
cb.nframes = 0
@@ -77,13 +88,14 @@ function Base.push!(cb::CircularArrayBuffer{T,N}, data) where {T,N}
7788
cb.nframes += 1
7889
end
7990
if N == 1
91+
i = _buffer_frame(cb, cb.nframes)
8092
if ndims(data) == 0
81-
cb[cb.nframes] = data[]
93+
cb.buffer[i:i] .= data[]
8294
else
83-
cb[cb.nframes] = data
95+
cb.buffer[i:i] .= data
8496
end
8597
else
86-
cb[ntuple(_ -> (:), N - 1)..., cb.nframes] .= data
98+
cb.buffer[ntuple(_ -> (:), N - 1)..., _buffer_frame(cb, cb.nframes)] .= data
8799
end
88100
cb
89101
end

test/runtests.jl

Lines changed: 177 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
using CircularArrayBuffers
22
using Test
3+
using Adapt
4+
using CUDA
5+
CUDA.allowscalar(false)
36

4-
@testset "CircularArrayBuffers.jl" begin
7+
@testset "CircularArrayBuffers (Array)" begin
58
A = ones(2, 2)
69
C = ones(Float32, 2, 2)
710

@@ -165,3 +168,176 @@ using Test
165168
]
166169
end
167170
end
171+
172+
@testset "CircularArrayBuffers (CuArray)" begin
173+
A = CUDA.ones(2, 2)
174+
Ac = adapt(Array, A)
175+
C = CUDA.ones(Float32, 2, 2)
176+
177+
@testset "Adapt" begin
178+
X = CircularArrayBuffer(rand(2, 3))
179+
Xc = adapt(CuArray, X)
180+
@test Xc isa CircularArrayBuffer{Float64, 2, <:CuArray}
181+
@test adapt(Array, Xc) == X
182+
end
183+
184+
# https://github.com/JuliaReinforcementLearning/ReinforcementLearning.jl/issues/551
185+
@testset "1D with 0d data" begin
186+
b = adapt(CuArray, CircularArrayBuffer{Int}(3))
187+
CUDA.@allowscalar push!(b, CUDA.zeros(Int, ()))
188+
@test length(b) == 1
189+
@test CUDA.@allowscalar b[1] == 0
190+
end
191+
192+
@testset "1D Int" begin
193+
b = adapt(CuArray, CircularArrayBuffer{Int}(3))
194+
195+
@test eltype(b) == Int
196+
@test capacity(b) == 3
197+
@test isfull(b) == false
198+
@test isempty(b) == true
199+
@test length(b) == 0
200+
@test size(b) == (0,)
201+
# element must has the exact same length with the element of buffer
202+
@test_throws Exception push!(b, [1, 2])
203+
204+
for x in 1:3
205+
push!(b, x)
206+
end
207+
208+
@test capacity(b) == 3
209+
@test isfull(b) == true
210+
@test length(b) == 3
211+
@test size(b) == (3,)
212+
# scalar indexing is not allowed
213+
@test_throws ErrorException b[1]
214+
@test_throws ErrorException b[end]
215+
@test CUDA.@allowscalar b[1:end] == cu([1, 2, 3])
216+
217+
for x in 4:5
218+
push!(b, x)
219+
end
220+
221+
@test capacity(b) == 3
222+
@test length(b) == 3
223+
@test size(b) == (3,)
224+
@test CUDA.@allowscalar b[1:end] == [3, 4, 5]
225+
226+
empty!(b)
227+
@test isfull(b) == false
228+
@test isempty(b) == true
229+
@test length(b) == 0
230+
@test size(b) == (0,)
231+
232+
push!(b, 6)
233+
@test isfull(b) == false
234+
@test isempty(b) == false
235+
@test length(b) == 1
236+
@test size(b) == (1,)
237+
@test CUDA.@allowscalar b[1] == 6
238+
239+
push!(b, 7)
240+
push!(b, 8)
241+
@test isfull(b) == true
242+
@test isempty(b) == false
243+
@test length(b) == 3
244+
@test size(b) == (3,)
245+
@test CUDA.@allowscalar b[1:3] == cu([6, 7, 8])
246+
247+
push!(b, 9)
248+
@test isfull(b) == true
249+
@test isempty(b) == false
250+
@test length(b) == 3
251+
@test size(b) == (3,)
252+
@test CUDA.@allowscalar b[1:3] == cu([7, 8, 9])
253+
254+
x = CUDA.@allowscalar pop!(b)
255+
@test x == 9
256+
@test length(b) == 2
257+
@test CUDA.@allowscalar b[1:2] == cu([7, 8])
258+
259+
x = CUDA.@allowscalar popfirst!(b)
260+
@test x == 7
261+
@test length(b) == 1
262+
@test CUDA.@allowscalar b[1] == 8
263+
264+
x = CUDA.@allowscalar pop!(b)
265+
@test x == 8
266+
@test length(b) == 0
267+
268+
@test_throws ArgumentError pop!(b)
269+
@test_throws ArgumentError popfirst!(b)
270+
end
271+
272+
@testset "2D Float64" begin
273+
b = adapt(CuArray, CircularArrayBuffer{Float64}(2, 2, 3))
274+
275+
@test eltype(b) == Float64
276+
@test capacity(b) == 3
277+
@test isfull(b) == false
278+
@test length(b) == 0
279+
@test size(b) == (2, 2, 0)
280+
281+
for x in 1:3
282+
push!(b, x * A)
283+
end
284+
285+
@test capacity(b) == 3
286+
@test isfull(b) == true
287+
@test length(b) == 2 * 2 * 3
288+
@test size(b) == (2, 2, 3)
289+
for i in 1:3
290+
@test b[:, :, i] == i * A
291+
end
292+
@test b[:, :, end] == 3 * A
293+
294+
for x in 4:5
295+
push!(b, x * CUDA.ones(Float64, 2, 2))
296+
end
297+
298+
@test capacity(b) == 3
299+
@test length(b) == 2 * 2 * 3
300+
@test size(b) == (2, 2, 3)
301+
@test b[:, :, 1] == 3 * A
302+
@test b[:, :, end] == 5 * A
303+
304+
# doing b == ... triggers scalar indexing
305+
@test CUDA.@allowscalar b == cu(reshape([c for x in 3:5 for c in x * Ac], 2, 2, 3))
306+
307+
push!(b, 6 * CUDA.ones(Float32, 2, 2))
308+
push!(b, 7 * CUDA.ones(Int, 2, 2))
309+
@test CUDA.@allowscalar b == cu(reshape([c for x in 5:7 for c in x * Ac], 2, 2, 3))
310+
311+
x = pop!(b)
312+
@test x == 7 * CUDA.ones(Float64, 2, 2)
313+
@test CUDA.@allowscalar b == cu(reshape([c for x in 5:6 for c in x * Ac], 2, 2, 2))
314+
end
315+
316+
@testset "append!" begin
317+
b = adapt(CuArray, CircularArrayBuffer{Int}(2, 3))
318+
append!(b, CUDA.zeros(2))
319+
append!(b, 1:4)
320+
@test CUDA.@allowscalar b == cu([
321+
0 1 3
322+
0 2 4
323+
])
324+
325+
326+
b = adapt(CuArray, CircularArrayBuffer{Int}(2, 3))
327+
for i in 1:5
328+
push!(b, CUDA.fill(i, 2))
329+
end
330+
empty!(b)
331+
append!(b, 1:4)
332+
@test CUDA.@allowscalar b == cu([
333+
1 3
334+
2 4
335+
])
336+
337+
append!(b, 5:8)
338+
@test CUDA.@allowscalar b == cu([
339+
3 5 7
340+
4 6 8
341+
])
342+
end
343+
end

0 commit comments

Comments
 (0)