Skip to content

Commit 6aa98ad

Browse files
authored
fix: TracedRNG path for correct create_result (#1803)
1 parent 021a658 commit 6aa98ad

File tree

2 files changed

+12
-4
lines changed

2 files changed

+12
-4
lines changed

src/Tracing.jl

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1901,7 +1901,7 @@ Base.@nospecializeinfer function make_tracer(
19011901
return make_tracer(seen, prev.seed, path, mode; kwargs...)
19021902
end
19031903
return ReactantRNG(
1904-
make_tracer(seen, prev.seed, (path..., :seed), mode; kwargs...), prev.algorithm
1904+
make_tracer(seen, prev.seed, (path..., 1), mode; kwargs...), prev.algorithm
19051905
)
19061906
end
19071907

@@ -1911,9 +1911,7 @@ Base.@nospecializeinfer function make_tracer(
19111911
if mode == ArrayToConcrete
19121912
TracedRandom.should_warn_if_not_natively_supported(prev)
19131913
return ReactantRNG(
1914-
make_tracer(
1915-
seen, TracedRandom.make_seed(prev), (path..., :seed), mode; kwargs...
1916-
),
1914+
make_tracer(seen, TracedRandom.make_seed(prev), (path..., 1), mode; kwargs...),
19171915
TracedRandom.rng_algorithm(prev),
19181916
)
19191917
end

test/nn/lux.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,3 +102,13 @@ end
102102
res, ∂ps = @jit gradient_loss_function(model, x, ps, st)
103103
@test res isa Reactant.ConcreteRNumber
104104
end
105+
106+
@testset "RNG stored in state" begin
107+
model = Dropout(0.5f0)
108+
ps, st = Reactant.to_rarray(Lux.setup(Random.default_rng(), model))
109+
110+
x = Reactant.to_rarray(randn(Float32, 10, 10))
111+
112+
res, st_new = @jit model(x, ps, st)
113+
@test st_new.rng isa Reactant.ReactantRNG
114+
end

0 commit comments

Comments
 (0)