Skip to content

Commit c6dba24

Browse files
committed
add episodessampler
1 parent 937f1c6 commit c6dba24

File tree

2 files changed

+259
-185
lines changed

2 files changed

+259
-185
lines changed

src/samplers.jl

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using Random
2+
export EpisodesSampler, Episode, BatchSampler, NStepBatchSampler, MetaSampler, MultiBatchSampler, DummySampler
23

34
struct SampleGenerator{S,T}
45
sampler::S
@@ -233,3 +234,38 @@ function StatsBase.sample(s::NStepBatchSampler{names}, e::EpisodesBuffer{<:Any,
233234
StatsBase.sample(s, t.traces, Val(names), inds)
234235
)
235236
end
237+
238+
"""
239+
EpisodesSampler()
240+
241+
A sampler that samples all Episodes present in the Trajectory and divides them into
242+
Episode containers. Truncated Episodes (e.g. due to the buffer capacity) are sampled as well.
243+
There will be at most one truncated episode and it will always be the first one.
244+
"""
245+
struct EpisodesSampler{names}
246+
end
247+
248+
EpisodesSampler() = EpisodesSampler{nothing}()
249+
#EpisodesSampler{names}() = new{names}()
250+
251+
252+
struct Episode{names, N <: NamedTuple{names}}
253+
nt::N
254+
end
255+
256+
@forward Episode.nt Base.keys, Base.haskey, Base.getindex
257+
258+
StatsBase.sample(s::EpisodesSampler{nothing}, t::EpisodesBuffer) = StatsBase.sample(s,t,keys(t))
259+
StatsBase.sample(s::EpisodesSampler{names}, t::EpisodesBuffer) where names = StatsBase.sample(s,t,names)
260+
261+
function StatsBase.sample(::EpisodesSampler, t::EpisodesBuffer, names)
262+
ranges = UnitRange{Int}[]
263+
idx = 1
264+
while idx < length(t)
265+
last_state_idx = idx + t.episodes_lengths[idx] - t.step_numbers[idx] + 1
266+
push!(ranges,idx:last_state_idx)
267+
idx = last_state_idx + 1
268+
end
269+
270+
return [Episode(NamedTuple{names}(map(x -> collect(t[Val(x)][r]), names))) for r in ranges]
271+
end

0 commit comments

Comments
 (0)