@@ -49,7 +49,7 @@ StatsBase.sample(s::BatchSampler{names}, t::AbstractTraces) where {names} = Stat
4949
5050function StatsBase. sample (s:: BatchSampler , t:: AbstractTraces , names, weights = StatsBase. UnitWeights {Int} (length (t)))
5151 inds = StatsBase. sample (s. rng, 1 : length (t), weights, s. batch_size)
52- NamedTuple {names} (map (x -> collect (t[x ][inds]), names))
52+ NamedTuple {names} (map (x -> collect (t[Val (x) ][inds]), names))
5353end
5454
5555function StatsBase. sample (s:: BatchSampler , t:: EpisodesBuffer , names)
@@ -74,12 +74,12 @@ function StatsBase.sample(s::BatchSampler, e::EpisodesBuffer{<:Any, <:Any, <:Cir
7474 st = deepcopy (t. priorities)
7575 st .*= e. sampleable_inds[1 : end - 1 ] # temporary sumtree that puts 0 priority to non sampleable indices.
7676 inds, priorities = rand (s. rng, st, s. batch_size)
77- NamedTuple {(:key, :priority, names...)} ((t. keys[inds], priorities, map (x -> collect (t. traces[x ][inds]), names)... ))
77+ NamedTuple {(:key, :priority, names...)} ((t. keys[inds], priorities, map (x -> collect (t. traces[Val (x) ][inds]), names)... ))
7878end
7979
8080function StatsBase. sample (s:: BatchSampler , t:: CircularPrioritizedTraces , names)
8181 inds, priorities = rand (s. rng, t. priorities, s. batch_size)
82- NamedTuple {(:key, :priority, names...)} ((t. keys[inds], priorities, map (x -> collect (t. traces[x ][inds]), names)... ))
82+ NamedTuple {(:key, :priority, names...)} ((t. keys[inds], priorities, map (x -> collect (t. traces[Val (x) ][inds]), names)... ))
8383end
8484
8585# ####
0 commit comments