Skip to content

Commit e9fb65c

Browse files
authored
Merge pull request #2360 from FluxML/bc/amd-rng-fix
Use stable API for AMDGPU RNG conversion
2 parents f4b4761 + 2587e05 commit e9fb65c

File tree

2 files changed

+8
-2
lines changed

2 files changed

+8
-2
lines changed

ext/FluxAMDGPUExt/functor.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,7 @@ end
4040
adapt_storage(::FluxAMDGPUAdaptor, x::Zygote.FillArrays.AbstractFill) =
4141
ROCArray(collect(x))
4242
adapt_storage(::FluxAMDGPUAdaptor, x::Zygote.OneElement) = ROCArray(collect(x))
43-
adapt_storage(::FluxAMDGPUAdaptor, x::Random.TaskLocalRNG) =
44-
AMDGPU.rocRAND.default_rng()
43+
adapt_storage(::FluxAMDGPUAdaptor, x::Random.TaskLocalRNG) = AMDGPU.rocrand_rng()
4544
adapt_storage(::FluxAMDGPUAdaptor, x::AMDGPU.rocRAND.RNG) = x
4645
adapt_storage(::FluxAMDGPUAdaptor, x::AbstractRNG) = error("""
4746
Cannot map RNG of type $(typeof(x)) to AMDGPU.

test/ext_amdgpu/basic.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,13 @@ end
8686
@test parent(Flux.gpu(g3)) isa ROCMatrix{Float32}
8787
end
8888

89+
@testset "cpu and gpu on RNGs" begin
90+
crng = Random.default_rng()
91+
grng = gpu(crng)
92+
@test grng isa AMDGPU.rocRAND.RNG
93+
@test cpu(grng) === crng
94+
end
95+
8996
@testset "Flux.onecold gpu" begin
9097
y = Flux.onehotbatch(ones(3), 1:10) |> Flux.gpu
9198
l = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j']

0 commit comments

Comments
 (0)