Skip to content

Commit 5f90ce2

Browse files
authored
Merge pull request #28 from findmyway/add_controller_and_sampler
Add ElasticArraySARTTraces
2 parents bcd02d7 + 418c4e3 commit 5f90ce2

File tree

7 files changed

+84
-10
lines changed

7 files changed

+84
-10
lines changed

Project.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
name = "ReinforcementLearningTrajectories"
22
uuid = "6486599b-a3cd-4e92-a99a-2cea90cc8c3c"
3-
version = "0.1.4"
3+
version = "0.1.5"
44

55
[deps]
66
CircularArrayBuffers = "9de3a189-e0c0-4e15-ba3b-b14b9fb0aec1"
7+
ElasticArrays = "fdbdab4c-e67f-52f5-8c3f-e7b388dad3d4"
78
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
89
OnlineStats = "a15396b6-48d5-5d58-9928-6d29437db91e"
910
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
export ElasticArraySARTTraces
2+
3+
using ElasticArrays: ElasticArray, resize_lastdim!
4+
5+
# Type alias for a `Traces` whose names parameter is `SS′AA′RT` and whose
# components are, in order:
#   1. a `MultiplexTraces{SS′}` wrapping an `ElasticArray`-backed `Trace`,
#   2. a `MultiplexTraces{AA′}` wrapping an `ElasticArray`-backed `Trace`,
#   3. an `ElasticArray`-backed `Trace`,
#   4. an `ElasticArray`-backed `Trace`.
# The keyword constructor in this file fills slots 3 and 4 with `reward` and
# `terminal` respectively.
const ElasticArraySARTTraces = Traces{
6+
SS′AA′RT,
7+
<:Tuple{
8+
<:MultiplexTraces{SS′,<:Trace{<:ElasticArray}},
9+
<:MultiplexTraces{AA′,<:Trace{<:ElasticArray}},
10+
<:Trace{<:ElasticArray},
11+
<:Trace{<:ElasticArray},
12+
}
13+
}
14+
15+
# Construct an empty `ElasticArraySARTTraces`.
#
# Keywords: `state`, `action`, `reward`, `terminal` — each an `eltype => size`
# pair. The defaults are scalar entries (size `()`) with element types `Int`,
# `Int`, `Float32` and `Bool` respectively.
#
# For every field an `ElasticArray{eltype}` of size `(size..., 0)` is
# allocated, i.e. the trailing dimension starts with length zero. State and
# action storage are wrapped in `MultiplexTraces{SS′}` / `MultiplexTraces{AA′}`,
# reward and terminal go into a plain `Traces`, and the three are combined
# with `+`.
function ElasticArraySARTTraces(;
16+
state=Int => (),
17+
action=Int => (),
18+
reward=Float32 => (),
19+
terminal=Bool => ()
20+
)
21+
state_eltype, state_size = state
22+
action_eltype, action_size = action
23+
reward_eltype, reward_size = reward
24+
terminal_eltype, terminal_size = terminal
25+
26+
MultiplexTraces{SS′}(ElasticArray{state_eltype}(undef, state_size..., 0)) +
27+
MultiplexTraces{AA′}(ElasticArray{action_eltype}(undef, action_size..., 0)) +
28+
Traces(
29+
reward=ElasticArray{reward_eltype}(undef, reward_size..., 0),
30+
terminal=ElasticArray{terminal_eltype}(undef, terminal_size..., 0),
31+
)
32+
end
33+
34+
#####
35+
# extensions for ElasticArrays
36+
#####
37+
38+
# `push!` on an `ElasticArray` simply delegates to `append!` with the same value.
# NOTE(review): this is type piracy — `push!` belongs to Base and `ElasticArray`
# is imported from ElasticArrays, so neither is owned by this package.
function Base.push!(a::ElasticArray, x)
    return append!(a, x)
end
39+
# One-dimensional case: wrap the single element in a vector before delegating
# to `append!`.
function Base.push!(a::ElasticArray{T,1}, x) where {T}
    return append!(a, [x])
end
40+
# Emptying an `ElasticArray` resizes its last dimension to zero.
function Base.empty!(a::ElasticArray)
    return resize_lastdim!(a, 0)
end

src/common/common.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,4 @@ include("sum_tree.jl")
1515
include("CircularArraySARTTraces.jl")
1616
include("CircularArraySLARTTraces.jl")
1717
include("CircularPrioritizedTraces.jl")
18+
include("ElasticArraySARTTraces.jl")

src/controllers.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,10 @@ function on_sample!(c::InsertSampleRatioController)
2424
if c.n_inserted >= c.threshold
2525
if c.n_sampled <= (c.n_inserted - c.threshold) * c.ratio
2626
c.n_sampled += 1
27-
true
27+
return true
2828
end
2929
end
30+
return false
3031
end
3132

3233
#####
@@ -56,4 +57,4 @@ function AsyncInsertSampleRatioController(
5657
Channel(ch_in_sz),
5758
Channel(ch_out_sz)
5859
)
59-
end
60+
end

src/samplers.jl

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,16 @@ using Random
22

33
abstract type AbstractSampler end
44

5+
#####
6+
# DummySampler
7+
#####
8+
9+
export DummySampler
10+
11+
"""
    DummySampler()

Sampler that performs no sampling: its `sample` method returns the traces it
is given unchanged. It subtypes `AbstractSampler` so that it can be used
wherever this file's sampler hierarchy is dispatched on.
"""
struct DummySampler <: AbstractSampler end
12+
13+
# A `DummySampler` does no selection: the traces are handed back unchanged.
function sample(::DummySampler, traces::AbstractTraces)
    return traces
end
14+
515
#####
616
# BatchSampler
717
#####
@@ -160,4 +170,4 @@ function sample(s::NStepBatchSampler{names}, t::CircularPrioritizedTraces) where
160170
(key=t.keys[inds], priority=priorities),
161171
sample(s, t.traces, Val(names), inds)
162172
)
163-
end
173+
end

src/trajectory.jl

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ Supported methods are:
2222
"""
2323
Base.@kwdef struct Trajectory{C,S,T}
2424
container::C
25-
sampler::S
25+
sampler::S = DummySampler()
2626
controller::T = InsertSampleRatioController()
2727
transformer::Any = identity
2828

@@ -96,12 +96,13 @@ end
9696
# !!! bypass the controller
9797
sample(t::Trajectory) = sample(t.sampler, t.container)
9898

99+
# Forward the sampling hook to the trajectory's controller and return its result.
function on_sample!(trajectory::Trajectory)
    return on_sample!(trajectory.controller)
end
100+
99101
function Base.take!(t::Trajectory)
100-
res = on_sample!(t.controller)
101-
if isnothing(res)
102-
nothing
103-
else
102+
if on_sample!(t)
104103
sample(t.sampler, t.container) |> t.transformer
104+
else
105+
nothing
105106
end
106107
end
107108

@@ -120,4 +121,4 @@ Base.iterate(t::Trajectory{<:Any,<:Any,<:AsyncInsertSampleRatioController}, args
120121
Base.take!(t::Trajectory{<:Any,<:Any,<:AsyncInsertSampleRatioController}) = take!(t.controller.ch_out)
121122

122123
Base.IteratorSize(::Trajectory{<:Any,<:Any,<:AsyncInsertSampleRatioController}) = Base.IsInfinite()
123-
Base.IteratorSize(::Trajectory) = Base.SizeUnknown()
124+
Base.IteratorSize(::Trajectory) = Base.SizeUnknown()

test/common.jl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,26 @@ end
9191
@test batch.terminal == Bool[0, 0, 0]
9292
end
9393

94+
# Checks that the keyword constructor yields an `ElasticArraySARTTraces`, that
# the length is 1 after the two pushes (a state/action pair, then a full
# reward/terminal/state/action entry), and that `empty!` resets the length to 0.
@testset "ElasticArraySARTTraces" begin
95+
t = ElasticArraySARTTraces(;
96+
state=Float32 => (2, 3),
97+
action=Int => (),
98+
reward=Float32 => (),
99+
terminal=Bool => ()
100+
)
101+
102+
@test t isa ElasticArraySARTTraces
103+
104+
push!(t, (state=ones(Float32, 2, 3), action=1))
105+
push!(t, (reward=1.0f0, terminal=false, state=ones(Float32, 2, 3) * 2, action=2))
106+
107+
@test length(t) == 1
108+
109+
empty!(t)
110+
111+
@test length(t) == 0
112+
end
113+
94114
@testset "CircularArraySLARTTraces" begin
95115
t = CircularArraySLARTTraces(;
96116
capacity=3,

0 commit comments

Comments
 (0)