Skip to content

Commit a981564

Browse files
Merge pull request #45 from JuliaReinforcementLearning/jpsl/type-stability-metaprogramming
Use metaprogramming to improve type stability...
2 parents c9f2ebe + 05475c1 commit a981564

File tree

5 files changed

+217
-47
lines changed

5 files changed

+217
-47
lines changed

src/episodes.jl

Lines changed: 37 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -85,21 +85,48 @@ ispartial_insert(traces::Traces, xs) = length(xs) < length(traces.traces) #this
8585
ispartial_insert(es::EpisodesBuffer, xs) = ispartial_insert(es.traces, xs)
8686
ispartial_insert(traces::CircularPrioritizedTraces, xs) = ispartial_insert(traces.traces, xs)
8787

88-
function fill_multiplex(es::EpisodesBuffer)
89-
for trace in es.traces.traces
90-
if !(trace isa MultiplexTraces)
91-
push!(trace, last(trace)) #push a duplicate of last element as a dummy element, should never be sampled.
92-
end
93-
end
88+
function pad!(trace::Trace)
89+
pad!(trace.parent)
90+
return nothing
9491
end
95-
function fill_multiplex(es::EpisodesBuffer{<:Any,<:Any,<:CircularPrioritizedTraces})
96-
for trace in es.traces.traces.traces
97-
if !(trace isa MultiplexTraces)
98-
push!(trace, last(trace)) #push a duplicate of last element as a dummy element, should never be sampled.
92+
93+
pad!(buf::CircularVectorBuffer{T}) where {T} = push!(buf, zero(T))
94+
pad!(vect::Vector{T}) where {T} = push!(vect, zero(T))
95+
96+
#push a zero-valued dummy element onto every trace that is not a MultiplexTraces (all traces when they are stored as a NamedTuple); the dummy should never be sampled.
97+
@generated function fill_multiplex(trace_tuple::Traces{names,Trs,N,E}) where {names,Trs,N,E}
98+
traces_signature = Trs
99+
ex = :()
100+
i = 1
101+
102+
if traces_signature <: NamedTuple
103+
# Handle 'simple' (non-multiplexed) Traces
104+
for tr in traces_signature.parameters[1]
105+
ex = :($ex; pad!(trace_tuple.traces[$i])) # pad everything
106+
i += 1
99107
end
108+
elseif traces_signature <: Tuple
109+
traces_signature = traces_signature.parameters
110+
111+
112+
for tr in traces_signature
113+
if !(tr <: MultiplexTraces)
114+
#push a zero-valued dummy element, should never be sampled.
115+
ex = :($ex; pad!(trace_tuple.traces[$i]))
116+
end
117+
i += 1
118+
end
119+
else
120+
error("Traces store is neither a tuple nor a named tuple!")
100121
end
122+
123+
return :($ex)
101124
end
102125

126+
fill_multiplex(es::EpisodesBuffer) = fill_multiplex(es.traces)
127+
128+
fill_multiplex(es::EpisodesBuffer{<:Any,<:Any,<:CircularPrioritizedTraces}) = fill_multiplex(es.traces.traces)
129+
103130
function Base.push!(eb::EpisodesBuffer, xs::NamedTuple)
104131
push!(eb.traces, xs)
105132
partial = ispartial_insert(eb, xs)

src/samplers.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ StatsBase.sample(s::BatchSampler{names}, t::AbstractTraces) where {names} = Stat
4949

5050
function StatsBase.sample(s::BatchSampler, t::AbstractTraces, names, weights = StatsBase.UnitWeights{Int}(length(t)))
5151
inds = StatsBase.sample(s.rng, 1:length(t), weights, s.batch_size)
52-
NamedTuple{names}(map(x -> collect(t[x][inds]), names))
52+
NamedTuple{names}(map(x -> collect(t[Val(x)][inds]), names))
5353
end
5454

5555
function StatsBase.sample(s::BatchSampler, t::EpisodesBuffer, names)
@@ -74,12 +74,12 @@ function StatsBase.sample(s::BatchSampler, e::EpisodesBuffer{<:Any, <:Any, <:Cir
7474
st = deepcopy(t.priorities)
7575
st .*= e.sampleable_inds[1:end-1] #temporary sumtree that puts 0 priority to non sampleable indices.
7676
inds, priorities = rand(s.rng, st, s.batch_size)
77-
NamedTuple{(:key, :priority, names...)}((t.keys[inds], priorities, map(x -> collect(t.traces[x][inds]), names)...))
77+
NamedTuple{(:key, :priority, names...)}((t.keys[inds], priorities, map(x -> collect(t.traces[Val(x)][inds]), names)...))
7878
end
7979

8080
function StatsBase.sample(s::BatchSampler, t::CircularPrioritizedTraces, names)
8181
inds, priorities = rand(s.rng, t.priorities, s.batch_size)
82-
NamedTuple{(:key, :priority, names...)}((t.keys[inds], priorities, map(x -> collect(t.traces[x][inds]), names)...))
82+
NamedTuple{(:key, :priority, names...)}((t.keys[inds], priorities, map(x -> collect(t.traces[Val(x)][inds]), names)...))
8383
end
8484

8585
#####

src/traces.jl

Lines changed: 131 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -128,15 +128,18 @@ end
128128

129129
Adapt.adapt_structure(to, t::MultiplexTraces{names}) where {names} = MultiplexTraces{names}(Adapt.adapt_structure(to, t.trace))
130130

131-
function Base.getindex(t::MultiplexTraces{names}, k::Symbol) where {names}
132-
a, b = names
133-
if k == a
134-
RelativeTrace{0,-1}(convert(AbstractTrace, t.trace))
135-
elseif k == b
136-
RelativeTrace{1,0}(convert(AbstractTrace, t.trace))
131+
Base.getindex(t::MultiplexTraces{names}, k::Symbol) where {names} = _getindex(t, Val(k))
132+
133+
@generated function _getindex(t::MultiplexTraces{names}, ::Val{k}) where {names,k}
134+
ex = :()
135+
if QuoteNode(names[1]) == QuoteNode(k)
136+
ex = :(RelativeTrace{0,-1}(t.trace))
137+
elseif QuoteNode(names[2]) == QuoteNode(k)
138+
ex = :(RelativeTrace{1,0}(t.trace))
137139
else
138-
throw(ArgumentError("unknown trace name: $k"))
140+
ex = :(throw(ArgumentError("unknown trace name: $k")))
139141
end
142+
return :($ex)
140143
end
141144

142145
Base.getindex(t::MultiplexTraces{names}, I::Int) where {names} = NamedTuple{names}((t.trace[I], t.trace[I+1]))
@@ -163,83 +166,144 @@ end
163166

164167
struct Traces{names,T,N,E} <: AbstractTraces{names,E}
165168
traces::T
166-
inds::NamedTuple{names,NTuple{N,Int}}
167169
end
168170

169171
function Adapt.adapt_structure(to, t::Traces{names,T,N,E}) where {names,T,N,E}
170172
data = Adapt.adapt_structure(to, t.traces)
171173
# FIXME: `E` is not adapted here
172-
Traces{names,typeof(data),length(names),E}(data, t.inds)
174+
Traces{names,typeof(data),length(names),E}(data)
173175
end
174176

175177
function Traces(; kw...)
176178
data = map(x -> convert(AbstractTrace, x), values(kw))
177179
names = keys(data)
178-
inds = NamedTuple(k => i for (i, k) in enumerate(names))
179-
Traces{names,typeof(data),length(names),typeof(values(data))}(data, inds)
180+
Traces{names,typeof(data),length(names),typeof(values(data))}(data)
180181
end
181182

182183

183-
function Base.getindex(ts::Traces, s::Symbol)
184-
t = ts.traces[ts.inds[s]]
184+
Base.getindex(ts::Traces, s::Symbol) = Base.getindex(ts::Traces, Val(s))
185+
186+
function Base.getindex(ts::Traces, ::Val{s}) where {s}
187+
t = _gettrace(ts, Val(s))
185188
if t isa AbstractTrace
186189
t
190+
elseif t isa MultiplexTraces
191+
_getindex(t, Val(s))
187192
else
188-
t[s]
193+
throw(ArgumentError("unknown trace name: $s"))
189194
end
190195
end
191196

192-
Base.getindex(t::Traces{names}, i) where {names} = NamedTuple{names}(map(k -> t[k][i], names))
197+
@generated function _gettrace(ts::Traces{names,Trs,N,E}, ::Val{k}) where {names,Trs,N,E,k}
198+
index_ = build_trace_index(names, Trs)
199+
# Generate code, i.e. find the correct index for a given key
200+
ex = :()
201+
202+
for name in names
203+
if QuoteNode(name) == QuoteNode(k)
204+
index_element = index_[k]
205+
ex = :(ts.traces[$index_element])
206+
break
207+
end
208+
end
209+
210+
return :($ex)
211+
end
212+
213+
@generated function Base.getindex(t::Traces{names}, i) where {names}
214+
ex = :(NamedTuple{$(names)}($(Expr(:tuple))))
215+
for k in names
216+
push!(ex.args[2].args, :(t[Val($(QuoteNode(k)))][i]))
217+
end
218+
return ex
219+
end
193220

194221
function Base.:(+)(t1::AbstractTraces{k1,T1}, t2::AbstractTraces{k2,T2}) where {k1,k2,T1,T2}
195222
ks = (k1..., k2...)
196223
ts = (t1, t2)
197-
inds = (; (k => 1 for k in k1)..., (k => 2 for k in k2)...)
198-
Traces{ks,typeof(ts),length(ks),Tuple{T1.types...,T2.types...}}(ts, inds)
224+
Traces{ks,typeof(ts),length(ks),Tuple{T1.types...,T2.types...}}(ts)
199225
end
200226

201227
function Base.:(+)(t1::AbstractTraces{k1,T1}, t2::Traces{k2,T,N,T2}) where {k1,T1,k2,T,N,T2}
202228
ks = (k1..., k2...)
203229
ts = (t1, t2.traces...)
204-
inds = merge(NamedTuple(k => 1 for k in k1), map(v -> v + 1, t2.inds))
205-
Traces{ks,typeof(ts),length(ks),Tuple{T1.types...,T2.types...}}(ts, inds)
230+
Traces{ks,typeof(ts),length(ks),Tuple{T1.types...,T2.types...}}(ts)
206231
end
207232

208233

209234
function Base.:(+)(t1::Traces{k1,T,N,T1}, t2::AbstractTraces{k2,T2}) where {k1,T,N,T1,k2,T2}
210235
ks = (k1..., k2...)
211236
ts = (t1.traces..., t2)
212-
inds = merge(t1.inds, (; (k => length(ts) for k in k2)...))
213-
Traces{ks,typeof(ts),length(ks),Tuple{T1.types...,T2.types...}}(ts, inds)
237+
Traces{ks,typeof(ts),length(ks),Tuple{T1.types...,T2.types...}}(ts)
214238
end
215239

216240
function Base.:(+)(t1::Traces{k1,T1,N1,E1}, t2::Traces{k2,T2,N2,E2}) where {k1,T1,N1,E1,k2,T2,N2,E2}
217241
ks = (k1..., k2...)
218242
ts = (t1.traces..., t2.traces...)
219-
inds = merge(t1.inds, map(x -> x + length(t1.traces), t2.inds))
220-
Traces{ks,typeof(ts),length(ks),Tuple{E1.types...,E2.types...}}(ts, inds)
243+
Traces{ks,typeof(ts),length(ks),Tuple{E1.types...,E2.types...}}(ts)
221244
end
222245

223246
Base.size(t::Traces) = (mapreduce(length, min, t.traces),)
224-
capacity(t::Traces) = minimum(map(idx->capacity(t.traces[idx]),t.inds))
225247

226-
for f in (:push!, :pushfirst!)
227-
@eval function Base.$f(ts::Traces, xs::NamedTuple)
228-
for (k, v) in pairs(xs)
229-
$f(ts, Val(k), v)
248+
function capacity(t::Traces{names,Trs,N,E}) where {names,Trs,N,E}
249+
minimum(map(idx->capacity(t[idx]), names))
250+
end
251+
252+
@generated function Base.push!(ts::Traces, xs::NamedTuple{N,T}) where {N,T}
253+
ex = :()
254+
for n in N
255+
ex = :($ex; push!(ts, Val($(QuoteNode(n))), xs.$n))
256+
end
257+
return :($ex)
258+
end
259+
260+
@generated function Base.pushfirst!(ts::Traces, xs::NamedTuple{N,T}) where {N,T}
261+
ex = :()
262+
for n in N
263+
ex = :($ex; pushfirst!(ts, Val($(QuoteNode(n))), xs.$n))
264+
end
265+
return :($ex)
266+
end
267+
268+
@generated function Base.pushfirst!(ts::Traces{names,Trs,N,E}, ::Val{k}, v) where {names,Trs,N,E,k}
269+
index_ = build_trace_index(names, Trs)
270+
# Generate code, i.e. find the correct index for a given key
271+
ex = :()
272+
273+
for name in names
274+
if QuoteNode(name) == QuoteNode(k)
275+
index_element = index_[k]
276+
ex = :(pushfirst!(ts.traces[$index_element], Val($(QuoteNode(k))), v))
277+
break
230278
end
231279
end
232280

233-
@eval function Base.$f(ts::Traces, ::Val{k}, v) where {k}
234-
$f(ts.traces[ts.inds[k]], Val(k), v)
281+
return :($ex)
282+
end
283+
284+
@generated function Base.push!(ts::Traces{names,Trs,N,E}, ::Val{k}, v) where {names,Trs,N,E,k}
285+
index_ = build_trace_index(names, Trs)
286+
# Generate code, i.e. find the correct index for a given key
287+
ex = :()
288+
289+
for name in names
290+
if QuoteNode(name) == QuoteNode(k)
291+
index_element = index_[k]
292+
ex = :(push!(ts.traces[$index_element], Val($(QuoteNode(k))), v))
293+
break
294+
end
235295
end
236296

297+
return :($ex)
298+
end
299+
300+
for f in (:push!, :pushfirst!)
237301
@eval function Base.$f(t::AbstractTrace, ::Val{k}, v) where {k}
238302
$f(t, v)
239303
end
240304

241305
@eval function Base.$f(t::Trace, ::Val{k}, v) where {k}
242-
$f(t, v)
306+
$f(t.parent, v)
243307
end
244308

245309
@eval function Base.$f(ts::MultiplexTraces, ::Val{k}, v) where {k}
@@ -251,7 +315,7 @@ end
251315
for f in (:append!, :prepend!)
252316
@eval function Base.$f(ts::Traces, xs::Traces)
253317
for k in keys(xs)
254-
t = ts.traces[ts.inds[k]]
318+
t = _gettrace(ts, Val(k))
255319
$f(t, xs[k])
256320
end
257321
end
@@ -264,3 +328,38 @@ for f in (:pop!, :popfirst!, :empty!)
264328
end
265329
end
266330
end
331+
332+
333+
"""
334+
build_trace_index(names::NTuple, traces_signature::DataType)
335+
336+
Take type signature from `Traces` and build a mapping from trace name to trace index
337+
"""
338+
function build_trace_index(names::NTuple, traces_signature::DataType)
339+
# Build index
340+
index_ = Dict()
341+
342+
if traces_signature <: NamedTuple
343+
# Handle simple Traces
344+
index_ = Dict(name => i for (name, i) in zip(names, 1:length(names)))
345+
elseif traces_signature <: Tuple
346+
# Handle MultiplexTraces
347+
i = 1
348+
j = 1
349+
trace_list = traces_signature.parameters
350+
for tr in trace_list
351+
if tr <: MultiplexTraces
352+
index_[names[i]] = j
353+
i += 1
354+
index_[names[i]] = j
355+
else
356+
index_[names[i]] = j
357+
end
358+
i += 1
359+
j += 1
360+
end
361+
else
362+
error("Traces store is neither a tuple nor a named tuple!")
363+
end
364+
return index_
365+
end

test/episodes.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ using Test
2828
@test eb.episodes_lengths[end] == 0
2929
@test eb.step_numbers[end] == 1
3030
@test eb.sampleable_inds == [1,1,1,1,1,0,0]
31-
@test eb[6][:reward] == 5 #6 is not a valid index, the reward there is dummy duplicate of previous (5)
31+
@test eb[6][:reward] == 0 #6 is not a valid index, the reward there is filled as zero
3232
ep2_len = 0
3333
for (j,i) = enumerate(8:11)
3434
ep2_len += 1
@@ -113,7 +113,7 @@ using Test
113113
@test eb.episodes_lengths[end] == 0
114114
@test eb.step_numbers[end] == 1
115115
@test eb.sampleable_inds == [1,1,1,1,1,0,0]
116-
@test eb[6][:reward] == 5 #6 is not a valid index, the reward there is dummy duplicate of previous (5)
116+
@test eb[6][:reward] == 0 #6 is not a valid index, the reward there is dummy, filled as zero
117117
ep2_len = 0
118118
for (j,i) = enumerate(8:11)
119119
ep2_len += 1

test/traces.jl

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,3 +115,47 @@ end
115115
t8_view.a[1] = 0
116116
@test t8[:a][2] == 0
117117
end
118+
119+
using ReinforcementLearningTrajectories: build_trace_index
120+
121+
@testset "build_trace_index" begin
122+
t1 = CircularArraySARTSATraces(;
123+
capacity=3,
124+
state=Float32 => (2, 3),
125+
action=Float32 => (2,),
126+
reward=Float32 => (),
127+
terminal=Bool => ()
128+
)
129+
@test build_trace_index(typeof(t1).parameters[1], typeof(t1).parameters[2]) == Dict(:reward => 3,
130+
:next_state => 1,
131+
:state => 1,
132+
:action => 2,
133+
:next_action => 2,
134+
:terminal => 4)
135+
136+
t2 = Traces(; a=[2, 3], b=[false, true])
137+
build_trace_index(typeof(t2).parameters[1], typeof(t2).parameters[2])
138+
end
139+
140+
@testset "push!(ts::Traces{names,Trs,N,E}, ::Val{k}, v)" begin
141+
t1 = CircularArraySARTSATraces(;
142+
capacity=3,
143+
state=Float32 => (2, 3),
144+
action=Float32 => (2,),
145+
reward=Float32 => (),
146+
terminal=Bool => ()
147+
)
148+
push!(t1, Val(:reward), 5)
149+
@test t1[:reward][1] == 5
150+
151+
@test size(Base.getindex(t1, :reward)) == (1,)
152+
@test size(Base.getindex(t1, 1).state) == (2,3)
153+
154+
155+
t2 = Traces(; a=[2, 3], b=[false, true])
156+
push!(t2, Val(:a), 5)
157+
@test t2[:a][3] == 5
158+
159+
@test size(Base.getindex(t2, :a)) == (3,)
160+
@test Base.getindex(t2, 1) == (; a = 2, b= false)
161+
end

0 commit comments

Comments
 (0)