Skip to content

Commit a5b0bbd

Browse files
Fix indexing for gpu rand (#641)
* fix: indexing for `gpu_rand`
* chore: fix runic formatting
* test: see which testset fails
* test: try using approx
* Fix workgroup size determination
* Mark broken tests broken
---------
Co-authored-by: Avik Pal <avikpal@mit.edu>
Co-authored-by: Avik Pal <avik.pal.2017@gmail.com>
1 parent 3d9bb2a commit a5b0bbd

File tree

4 files changed

+18
-8
lines changed

4 files changed

+18
-8
lines changed

lib/JLArrays/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "JLArrays"
22
uuid = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
33
authors = ["Tim Besard <tim.besard@gmail.com>"]
4-
version = "0.3.0"
4+
version = "0.3.1"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

lib/JLArrays/src/JLArrays.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -600,7 +600,7 @@ KernelAbstractions.allocate(::JLBackend, ::Type{T}, dims::Tuple) where T = JLArr
600600
end
601601

602602
if KernelAbstractions.workgroupsize(kernel) <: DynamicSize && workgroupsize === nothing
603-
workgroupsize = (1024,) # Vectorization, 4x unrolling, minimal grain size
603+
workgroupsize = (MAXTHREADS,) # Vectorization, 4x unrolling, minimal grain size
604604
end
605605
iterspace, dynamic = partition(kernel, ndrange, workgroupsize)
606606
# partition checked that the ndrange's agreed
@@ -626,6 +626,7 @@ else
626626
end
627627

628628
function (obj::Kernel{JLBackend})(args...; ndrange=nothing, workgroupsize=nothing)
629+
ndrange, workgroupsize, _, _ = launch_config(obj, ndrange, workgroupsize)
629630
device_args = jlconvert.(args)
630631
new_obj = convert_to_cpu(obj)
631632
new_obj(device_args...; ndrange, workgroupsize)

src/host/random.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,9 +96,10 @@ function Random.randn!(rng::RNG, A::AnyGPUArray{T}) where T <: Number
9696
threads = (length(A) - 1) ÷ 2 + 1
9797
@kernel function randn!(a, randstates)
9898
i = @index(Global, Linear)
99+
threadidx = @index(Local, Linear)
99100
idx = 2*(i - 1) + 1
100-
U1 = gpu_rand(T, i, randstates)
101-
U2 = gpu_rand(T, i, randstates)
101+
U1 = gpu_rand(T, threadidx, randstates)
102+
U2 = gpu_rand(T, threadidx, randstates)
102103
Z0 = sqrt(T(-2.0)*log(U1))*cos(T(2pi)*U2)
103104
Z1 = sqrt(T(-2.0)*log(U1))*sin(T(2pi)*U2)
104105
@inbounds a[idx] = Z0

test/testsuite/random.jl

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,20 +6,25 @@
66
end
77
cpu_rng = Random.default_rng()
88

9+
SEEDING_BROKEN = (rng != cpu_rng) && !contains(string(AT), "JLArray")
10+
911
@testset "rand" begin # uniform
10-
for T in eltypes, d in (10, (10,10))
12+
@testset "$d $T" for T in eltypes, d in (10, (10, 10), (1024, 1024))
1113
A = AT{T}(undef, d)
1214
B = copy(A)
1315
rand!(rng, A)
1416
rand!(rng, B)
1517
@test Array(A) != Array(B)
1618

19+
A = AT(rand(T, d))
20+
B = AT(rand(T, d))
21+
1722
Random.seed!(rng)
1823
Random.seed!(rng, 1)
1924
rand!(rng, A)
2025
Random.seed!(rng, 1)
2126
rand!(rng, B)
22-
@test all(Array(A) .== Array(B))
27+
@test Array(A) == Array(B) broken=SEEDING_BROKEN && (prod(d) > length(rng.state))
2328

2429
if rng != cpu_rng
2530
rand!(cpu_rng, A)
@@ -44,19 +49,22 @@
4449
@testset "randn" begin # normally-distributed
4550
# XXX: randn calls sqrt, and Base's sqrt(::Complex) performs
4651
# checked type conversions that throw boxed numbers.
47-
for T in filter(isrealfloattype, eltypes), d in (2, (2,2))
52+
@testset "$d $T" for T in filter(isrealfloattype, eltypes), d in (2, (2, 2), (1024, 1024))
4853
A = AT{T}(undef, d)
4954
B = copy(A)
5055
randn!(rng, A)
5156
randn!(rng, B)
5257
@test Array(A) != Array(B)
5358

59+
A = AT(rand(T, d))
60+
B = AT(rand(T, d))
61+
5462
Random.seed!(rng)
5563
Random.seed!(rng, 1)
5664
randn!(rng, A)
5765
Random.seed!(rng, 1)
5866
randn!(rng, B)
59-
@test Array(A) == Array(B)
67+
@test Array(A) == Array(B) broken=SEEDING_BROKEN && (prod(d) > (2 * length(rng.state)))
6068

6169
if rng != cpu_rng
6270
randn!(cpu_rng, A)

0 commit comments

Comments
 (0)