Skip to content

Commit d64dfbc

Browse files
committed
Corner case of first non-sampleable idx
1 parent 0c14476 commit d64dfbc

File tree

2 files changed

+15
-3
lines changed

2 files changed

+15
-3
lines changed

src/samplers.jl

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -262,9 +262,13 @@ function StatsBase.sample(::EpisodesSampler, t::EpisodesBuffer, names)
262262
ranges = UnitRange{Int}[]
263263
idx = 1
264264
while idx < length(t)
265-
last_state_idx = idx + t.episodes_lengths[idx] - t.step_numbers[idx] + 1
266-
push!(ranges,idx:last_state_idx)
267-
idx = last_state_idx + 1
265+
if t.sampleable_inds[idx] == 1
266+
last_state_idx = idx + t.episodes_lengths[idx] - t.step_numbers[idx] + 1
267+
push!(ranges,idx:last_state_idx)
268+
idx = last_state_idx + 1
269+
else
270+
idx += 1
271+
end
268272
end
269273

270274
return [Episode(NamedTuple{names}(map(x -> collect(t[Val(x)][r]), names))) for r in ranges]

test/samplers.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,14 @@
218218
@test length(b) == 2
219219
@test length(b[1][:state]) == 5
220220
@test length(b[2][:state]) == 6
221+
222+
for (j,i) = enumerate(2:5)
223+
push!(eb, (state = i, action =i, reward = i-1, terminal = false))
224+
end
225+
#only the last state of the first episode is still buffered. Should not be sampled.
226+
b = sample(s, eb)
227+
@test length(b) == 1
228+
221229

222230
#with specified traces
223231
s = EpisodesSampler{(:state,)}()

0 commit comments

Comments
 (0)