Skip to content

Commit 1ff0d8f

Browse files
committed
add tests
1 parent 68b7cd0 commit 1ff0d8f

File tree

4 files changed

+107
-39
lines changed

4 files changed

+107
-39
lines changed

src/common/common.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ const AA = (:action, :next_action)
88
const RT = (:reward, :terminal)
99
const SSART = (SS..., :action, RT...)
1010
const SSAART = (SS..., AA..., RT...)
11+
const SSLART = (SS..., :legal_actions_mask, :action, RT...)
1112
const SSLLAART = (SS..., LL..., AA..., RT...)
1213

1314
include("sum_tree.jl")

src/normalization.jl

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ end
1515

1616
#Treats last dim as batch dim
1717
function OnlineStats.fit!(n::Normalizer, data::AbstractArray)
18-
for d in eachslice(data, dims = ndims(data))
18+
for d in eachslice(data, dims=ndims(data))
1919
fit!(n.os, vec(d))
2020
end
2121
n
@@ -72,17 +72,17 @@ function normalize(os::Group{<:AbstractVector{<:Moments}}, x::AbstractVector)
7272
return (x .- m) ./ s
7373
end
7474

75-
function normalize(os::Group{<:AbstractVector{<:Moments}}, x::AbstractArray)
75+
function normalize(os::Group{<:AbstractVector{<:Moments}}, x::AbstractArray)
7676
xn = similar(x)
77-
for (i, slice) in enumerate(eachslice(x, dims = ndims(x)))
78-
xn[repeat([:], ndims(x)-1)..., i] .= reshape(normalize(os, vec(slice)), size(x)[1:end-1]...)
77+
for (i, slice) in enumerate(eachslice(x, dims=ndims(x)))
78+
xn[repeat([:], ndims(x) - 1)..., i] .= reshape(normalize(os, vec(slice)), size(x)[1:end-1]...)
7979
end
8080
return xn
8181
end
8282

8383
function normalize(os::Group{<:AbstractVector{<:Moments}}, x::AbstractVector{<:AbstractArray})
8484
xn = similar(x)
85-
for (i,el) in enumerate(x)
85+
for (i, el) in enumerate(x)
8686
xn[i] = normalize(os, vec(el))
8787
end
8888
return xn
@@ -96,7 +96,7 @@ have equal weights in the computation of the moments.
9696
See the [OnlineStats documentation](https://joshday.github.io/OnlineStats.jl/stable/weights/)
9797
to use variants such as exponential weights to favor the most recent observations.
9898
"""
99-
scalar_normalizer(; weight::Weight = EqualWeight()) = Normalizer(Moments(weight = weight))
99+
scalar_normalizer(; weight::Weight=EqualWeight()) = Normalizer(Moments(weight=weight))
100100

101101
"""
102102
array_normalizer(size::Tuple{Int}; weights = OnlineStats.EqualWeight())
@@ -108,7 +108,7 @@ By default, all samples have equal weights in the computation of the moments.
108108
See the [OnlineStats documentation](https://joshday.github.io/OnlineStats.jl/stable/weights/)
109109
to use variants such as exponential weights to favor the most recent observations.
110110
"""
111-
array_normalizer(size::NTuple{N,Int}; weight::Weight = EqualWeight()) where N = Normalizer(Group([Moments(weight = weight) for _ in 1:prod(size)]))
111+
array_normalizer(size::NTuple{N,Int}; weight::Weight=EqualWeight()) where {N} = Normalizer(Group([Moments(weight=weight) for _ in 1:prod(size)]))
112112

113113
"""
114114
NormalizedTraces(traces::AbstractTraces, normalizers::NamedTuple)
@@ -142,12 +142,12 @@ traj = Trajectory(
142142
)
143143
```
144144
"""
145-
struct NormalizedTraces{names, TT, T <: AbstractTraces{names, TT}, normnames, N} <: AbstractTraces{names, TT}
145+
struct NormalizedTraces{names,TT,T<:AbstractTraces{names,TT},normnames,N} <: AbstractTraces{names,TT}
146146
traces::T
147-
normalizers::NamedTuple{normnames, N}
148-
end
147+
normalizers::NamedTuple{normnames,N}
148+
end
149149

150-
function NormalizedTraces(traces::AbstractTraces{names, TT}; trace_normalizer_pairs...) where names where TT
150+
function NormalizedTraces(traces::AbstractTraces{names,TT}; trace_normalizer_pairs...) where {names} where {TT}
151151
for key in keys(trace_normalizer_pairs)
152152
@assert key in keys(traces) "Traces do not have key $key, valid keys are $(keys(traces))."
153153
end
@@ -160,11 +160,11 @@ function NormalizedTraces(traces::AbstractTraces{names, TT}; trace_normalizer_pa
160160
else #if not then one is missing
161161
present_key = only(intersect(keys(trace), keys(trace_normalizer_pairs)))
162162
absent_key = only(setdiff(keys(trace), keys(trace_normalizer_pairs)))
163-
nt = merge(nt, (;(absent_key => nt[present_key],)...)) #assign the same normalizer
163+
nt = merge(nt, (; (absent_key => nt[present_key],)...)) #assign the same normalizer
164164
end
165165
end
166166
end
167-
NormalizedTraces{names, TT, typeof(traces), keys(nt), typeof(values(nt))}(traces, nt)
167+
NormalizedTraces{names,TT,typeof(traces),keys(nt),typeof(values(nt))}(traces, nt)
168168
end
169169

170170
function Base.show(io::IO, ::MIME"text/plain", t::NormalizedTraces{names,T}) where {names,T}
@@ -193,6 +193,6 @@ end
193193

194194
function sample(s::BatchSampler, nt::NormalizedTraces, names)
195195
inds = rand(s.rng, 1:length(nt), s.batch_size)
196-
maybe_normalize(data, key) = key in keys(nt.normalizers) ? normalize(nt.normalizers[key], data) : data
197-
NamedTuple{names}(s.transformer(maybe_normalize(nt[x][inds], x) for x in names))
196+
maybe_normalize(data, key) = key in keys(nt.normalizers) ? normalize(nt.normalizers[key], data) : data
197+
NamedTuple{names}(maybe_normalize(nt[x][inds], x) for x in names)
198198
end

src/samplers.jl

Lines changed: 23 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -91,49 +91,49 @@ sample(m::MultiBatchSampler, t) = [sample(m.sampler, t) for _ in 1:m.n]
9191

9292
export NStepBatchSampler
9393

94-
Base.@kwdef mutable struct NStepBatchSampler{traces}
94+
mutable struct NStepBatchSampler{traces}
9595
n::Int # !!! n starts from 1
9696
γ::Float32
97-
batch_size::Int = 32
98-
stack_size::Union{Nothing,Int} = nothing
99-
rng::Any = Random.GLOBAL_RNG
97+
batch_size::Int
98+
stack_size::Union{Nothing,Int}
99+
rng::Any
100100
end
101101

102-
select_last_dim(xs::AbstractArray{T,N}, inds) where {T,N} = @views xs[ntuple(_ -> (:), Val(N - 1))..., inds]
103-
select_last_frame(xs::AbstractArray{T,N}) where {T,N} = select_last_dim(xs, size(xs, N))
104-
105-
consecutive_view(cb, inds; n_stack=nothing, n_horizon=nothing) = consecutive_view(cb, inds, n_stack, n_horizon)
106-
consecutive_view(cb, inds, ::Nothing, ::Nothing) = select_last_dim(cb, inds)
107-
consecutive_view(cb, inds, n_stack::Int, ::Nothing) = select_last_dim(cb, [x + i for i in -n_stack+1:0, x in inds])
108-
consecutive_view(cb, inds, ::Nothing, n_horizon::Int) = select_last_dim(cb, [x + j for j in 0:n_horizon-1, x in inds])
109-
consecutive_view(cb, inds, n_stack::Int, n_horizon::Int) = select_last_dim(cb, [x + i + j for i in -n_stack+1:0, j in 0:n_horizon-1, x in inds])
102+
NStepBatchSampler(; kw...) = NStepBatchSampler{SSART}(; kw...)
103+
NStepBatchSampler{names}(; n, γ, batch_size=32, stack_size=nothing, rng=Random.GLOBAL_RNG) where {names} = NStepBatchSampler{names}(n, γ, batch_size, stack_size, rng)
110104

111105
function sample(s::NStepBatchSampler{names}, ts) where {names}
112106
valid_range = isnothing(s.stack_size) ? (1:(length(ts)-s.n+1)) : (s.stack_size:(length(ts)-s.n+1))# think about the exteme case where s.stack_size == 1 and s.n == 1
113107
inds = rand(s.rng, valid_range, s.batch_size)
114108
sample(s, ts, Val(names), inds)
115109
end
116110

117-
function sample(s::NStepBatchSampler, ts, ::Val{SSART}, inds)
118-
s = consecutive_view(ts[:state], inds; n_stack=s.stack_size)
119-
s′ = consecutive_view(ts[:next_state], inds .+ (s.n - 1); n_stack=s.stack_size)
120-
a = consecutive_view(ts[:action], inds)
121-
t_horizon = consecutive_view(ts[:terminal], inds; n_horizon=s.n)
122-
r_horizon = consecutive_view(ts[:reward], inds; n_horizon=s.n)
111+
function sample(nbs::NStepBatchSampler, ts, ::Val{SSART}, inds)
112+
if isnothing(nbs.stack_size)
113+
s = ts[:state][inds]
114+
s′ = ts[:next_state][inds.+(nbs.n-1)]
115+
else
116+
s = ts[:state][[x + i for i in -nbs.stack_size+1:0, x in inds]]
117+
s′ = ts[:next_state][[x + nbs.n - 1 + i for i in -nbs.stack_size+1:0, x in inds]]
118+
end
119+
120+
a = ts[:action][inds]
121+
t_horizon = ts[:terminal][[x + j for j in 0:nbs.n-1, x in inds]]
122+
r_horizon = ts[:reward][[x + j for j in 0:nbs.n-1, x in inds]]
123123

124124
@assert ndims(t_horizon) == 2
125-
t = any(t_horizon, dims=1)
125+
t = any(t_horizon, dims=1) |> vec
126126

127127
@assert ndims(r_horizon) == 2
128128
r = map(eachcol(r_horizon), eachcol(t_horizon)) do r⃗, t⃗
129-
foldr((init, (rr, tt)) -> rr + f.γ * init * (1 - tt), zip(r⃗, t⃗); init=0.0f0)
129+
foldr(((rr, tt), init) -> rr + nbs.γ * init * (1 - tt), zip(r⃗, t⃗); init=0.0f0)
130130
end
131131

132-
NamedTuple{names}(s, s′, a, r, t)
132+
NamedTuple{SSART}((s, s′, a, r, t))
133133
end
134134

135135
function sample(s::NStepBatchSampler, ts, ::Val{SSLART}, inds)
136-
s, s′, a, r, t = sample(s, ts, Val(SSART), inds),
136+
s, s′, a, r, t = sample(s, ts, Val(SSART), inds)
137137
l = consecutive_view(ts[:legal_actions_mask], inds)
138-
NamedTuple{SSLART}(s, s′, l, a, r, t)
138+
NamedTuple{SSLART}((s, s′, l, a, r, t))
139139
end

test/samplers.jl

Lines changed: 68 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,4 +58,71 @@ end
5858
@test length(batches[1][:policy][:a]) == 3
5959
@test length(batches[1][:critic]) == 2 # we sampled 2 batches for critic
6060
@test length(batches[1][:critic][1][:b]) == 5 #each batch is 5 samples
61-
end
61+
end
62+
63+
#! format: off
64+
@testset "NStepSampler" begin
65+
γ = 0.9
66+
n_stack = 2
67+
n_horizon = 3
68+
batch_size = 4
69+
70+
t1 = MultiplexTraces{(:state, :next_state)}(1:10) +
71+
MultiplexTraces{(:action, :next_action)}(iseven.(1:10)) +
72+
Traces(
73+
reward=1:9,
74+
terminal=Bool[0, 0, 0, 1, 0, 0, 0, 0, 1],
75+
)
76+
77+
s1 = NStepBatchSampler(n=n_horizon, γ=γ, stack_size=n_stack, batch_size=batch_size)
78+
79+
xs = RLTrajectories.sample(s1, t1)
80+
81+
@test size(xs.state) == (n_stack, batch_size)
82+
@test size(xs.next_state) == (n_stack, batch_size)
83+
@test size(xs.action) == (batch_size,)
84+
@test size(xs.reward) == (batch_size,)
85+
@test size(xs.terminal) == (batch_size,)
86+
87+
88+
state_size = (2,3)
89+
n_state = reduce(*, state_size)
90+
total_length = 10
91+
t2 = MultiplexTraces{(:state, :next_state)}(
92+
reshape(1:n_state * total_length, state_size..., total_length)
93+
) +
94+
MultiplexTraces{(:action, :next_action)}(iseven.(1:total_length)) +
95+
Traces(
96+
reward=1:total_length-1,
97+
terminal=Bool[0, 0, 0, 1, 0, 0, 0, 0, 1],
98+
)
99+
100+
xs2 = RLTrajectories.sample(s1, t2)
101+
102+
@test size(xs2.state) == (state_size..., n_stack, batch_size)
103+
@test size(xs2.next_state) == (state_size..., n_stack, batch_size)
104+
@test size(xs2.action) == (batch_size,)
105+
@test size(xs2.reward) == (batch_size,)
106+
@test size(xs2.terminal) == (batch_size,)
107+
108+
inds = [3, 5, 7]
109+
xs3 = RLTrajectories.sample(s1, t2, Val(SSART), inds)
110+
111+
@test xs3.state == cat(
112+
(
113+
reshape(n_state * (i-n_stack)+1: n_state * i, state_size..., n_stack)
114+
for i in inds
115+
)...
116+
;dims=length(state_size) + 2
117+
)
118+
119+
@test xs3.next_state == xs3.state .+ (n_state * n_horizon)
120+
@test xs3.action == iseven.(inds)
121+
@test xs3.terminal == [any(t2[:terminal][i: i+n_horizon-1]) for i in inds]
122+
123+
# manual calculation
124+
@test xs3.reward[1] 3 + γ * 4 # terminated at step 4
125+
@test xs3.reward[2] 5 + γ * (6 + γ * 7)
126+
@test xs3.reward[3] 7 + γ * (8 + γ * 9)
127+
end
128+
#! format: on

0 commit comments

Comments
 (0)