Skip to content

Commit 486b913

Browse files
More type tweaks
1 parent 34fe7c1 commit 486b913

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

src/samplers.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ StatsBase.sample(s::BatchSampler{names}, t::AbstractTraces) where {names} = Stat
4949

5050
function StatsBase.sample(s::BatchSampler, t::AbstractTraces, names, weights = StatsBase.UnitWeights{Int}(length(t)))
5151
inds = StatsBase.sample(s.rng, 1:length(t), weights, s.batch_size)
52-
NamedTuple{names}(map(x -> collect(t[x][inds]), names))
52+
NamedTuple{names}(map(x -> collect(t[Val(x)][inds]), names))
5353
end
5454

5555
function StatsBase.sample(s::BatchSampler, t::EpisodesBuffer, names)
@@ -74,12 +74,12 @@ function StatsBase.sample(s::BatchSampler, e::EpisodesBuffer{<:Any, <:Any, <:Cir
7474
st = deepcopy(t.priorities)
7575
st .*= e.sampleable_inds[1:end-1] #temporary sumtree that puts 0 priority to non sampleable indices.
7676
inds, priorities = rand(s.rng, st, s.batch_size)
77-
NamedTuple{(:key, :priority, names...)}((t.keys[inds], priorities, map(x -> collect(t.traces[x][inds]), names)...))
77+
NamedTuple{(:key, :priority, names...)}((t.keys[inds], priorities, map(x -> collect(t.traces[Val(x)][inds]), names)...))
7878
end
7979

8080
function StatsBase.sample(s::BatchSampler, t::CircularPrioritizedTraces, names)
8181
inds, priorities = rand(s.rng, t.priorities, s.batch_size)
82-
NamedTuple{(:key, :priority, names...)}((t.keys[inds], priorities, map(x -> collect(t.traces[x][inds]), names)...))
82+
NamedTuple{(:key, :priority, names...)}((t.keys[inds], priorities, map(x -> collect(t.traces[Val(x)][inds]), names)...))
8383
end
8484

8585
#####

0 commit comments

Comments
 (0)