Skip to content

Commit 24c21dc

Browse files
Tests pass!
1 parent 486b913 commit 24c21dc

File tree

2 files changed

+26
-17
lines changed

2 files changed

+26
-17
lines changed

src/episodes.jl

Lines changed: 24 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -90,31 +90,40 @@ function pad!(trace::Trace)
9090
return nothing
9191
end
9292

93-
pad!(buf::CircularVectorBuffer) = pad!(buf.buffer)
93+
pad!(buf::CircularVectorBuffer{T}) where {T} = push!(buf, zero(T))
9494
pad!(vect::Vector{T}) where {T} = push!(vect, zero(T))
9595

9696
#push a duplicate of last element as a dummy element for all 'trace' objects, ignores multiplex traces, should never be sampled.
9797
@generated function fill_multiplex(trace_tuple::Traces{names,Trs,N,E}) where {names,Trs,N,E}
98-
i = 1
98+
traces_signature = Trs
9999
ex = :()
100-
for tr in Trs.parameters
101-
if !(tr <: MultiplexTraces)
102-
#push a duplicate of last element as a dummy element, should never be sampled.
103-
ex = :($ex; pad!(trace_tuple.traces[$i]))
100+
i = 1
101+
102+
if traces_signature <: NamedTuple
103+
# Handle 'simple' (non-multiplexed) Traces
104+
for tr in traces_signature.parameters[1]
105+
ex = :($ex; pad!(trace_tuple.traces[$i])) # pad everything
106+
i += 1
104107
end
105-
i += 1
108+
elseif traces_signature <: Tuple
109+
traces_signature = traces_signature.parameters
110+
111+
112+
for tr in traces_signature
113+
if !(tr <: MultiplexTraces)
114+
#push a duplicate of last element as a dummy element, should never be sampled.
115+
ex = :($ex; pad!(trace_tuple.traces[$i]))
116+
end
117+
i += 1
118+
end
119+
else
120+
error("Traces store is neither a tuple nor a named tuple!")
106121
end
122+
107123
return :($ex)
108124
end
109125

110-
# This function is currently unoptimized, could be optimized by using a generated function.
111-
function fill_multiplex(es::EpisodesBuffer)
112-
for trace in es.traces.traces
113-
if !(trace isa MultiplexTraces)
114-
push!(trace, last(trace)) #push a duplicate of last element as a dummy element, should never be sampled.
115-
end
116-
end
117-
end
126+
fill_multiplex(es::EpisodesBuffer) = fill_multiplex(es.traces)
118127

119128
fill_multiplex(es::EpisodesBuffer{<:Any,<:Any,<:CircularPrioritizedTraces}) = fill_multiplex(es.traces.traces)
120129

test/episodes.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ using Test
2828
@test eb.episodes_lengths[end] == 0
2929
@test eb.step_numbers[end] == 1
3030
@test eb.sampleable_inds == [1,1,1,1,1,0,0]
31-
@test eb[6][:reward] == 5 #6 is not a valid index, the reward there is dummy duplicate of previous (5)
31+
@test eb[6][:reward] == 0 #6 is not a valid index, the reward there is filled as zero
3232
ep2_len = 0
3333
for (j,i) = enumerate(8:11)
3434
ep2_len += 1
@@ -113,7 +113,7 @@ using Test
113113
@test eb.episodes_lengths[end] == 0
114114
@test eb.step_numbers[end] == 1
115115
@test eb.sampleable_inds == [1,1,1,1,1,0,0]
116-
@test eb[6][:reward] == 5 #6 is not a valid index, the reward there is dummy duplicate of previous (5)
116+
@test eb[6][:reward] == 0 #6 is not a valid index, the reward there is dummy, filled as zero
117117
ep2_len = 0
118118
for (j,i) = enumerate(8:11)
119119
ep2_len += 1

0 commit comments

Comments
 (0)