Skip to content

Commit 17ff1ed

Browse files
committed
minor updates when adapting it in RL.jl
1 parent f2239bd commit 17ff1ed

File tree

4 files changed

+18
-21
lines changed

4 files changed

+18
-21
lines changed

Project.toml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,13 @@ version = "0.1.0"
44

55
[deps]
66
CircularArrayBuffers = "9de3a189-e0c0-4e15-ba3b-b14b9fb0aec1"
7-
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
87
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
98
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
109
StackViews = "cae243ae-269e-4f55-b966-ac2d0dc13c15"
1110

1211
[compat]
1312
CircularArrayBuffers = "0.1"
1413
MacroTools = "0.5"
15-
MLUtils = "0.2"
1614
StackViews = "0.1"
1715
julia = "1.6"
1816

src/patch.jl

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,3 @@
1-
import MLUtils
2-
3-
MLUtils.batch(x::AbstractArray{<:Number}) = x
4-
51
#####
62

73
import StackViews: StackView

src/samplers.jl

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,31 @@
11
export BatchSampler, MetaSampler, MultiBatchSampler
22

3-
using MLUtils: batch
4-
53
using Random
64

75
abstract type AbstractSampler end
86

9-
struct BatchSampler <: AbstractSampler
7+
struct BatchSampler{names} <: AbstractSampler
108
batch_size::Int
119
rng::Random.AbstractRNG
1210
transformer::Any
1311
end
1412

1513
"""
16-
BatchSampler(batch_size; rng=Random.GLOBAL_RNG, transformer=identity)
14+
BatchSampler{names}(;batch_size, rng=Random.GLOBAL_RNG, transformer=identity)
1715
18-
Uniformly sample a batch of examples for each trace.
16+
Uniformly sample a batch of examples for each trace specified in `names`. By default, all the traces will be sampled.
1917
2018
See also [`sample`](@ref).
2119
"""
22-
BatchSampler(batch_size; rng=Random.GLOBAL_RNG, transformer=batch) = BatchSampler(batch_size, rng, transformer)
20+
BatchSampler(; kw...) = BatchSampler{nothing}(; kw...)
21+
BatchSampler{names}(; batch_size, rng=Random.GLOBAL_RNG, transformer=identity) where {names} = BatchSampler{names}(batch_size, rng, transformer)
22+
23+
sample(s::BatchSampler{nothing}, t::AbstractTraces) = sample(s, t, keys(t))
24+
sample(s::BatchSampler{names}, t::AbstractTraces) where {names} = sample(s, t, names)
2325

24-
function sample(s::BatchSampler, t::AbstractTraces)
26+
function sample(s::BatchSampler, t::AbstractTraces, names)
2527
inds = rand(s.rng, 1:length(t), s.batch_size)
26-
map(s.transformer, t[inds])
28+
NamedTuple{names}(s.transformer(t[x][inds]) for x in names)
2729
end
2830

2931
"""

src/trajectory.jl

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -65,11 +65,11 @@ function Base.bind(t::Trajectory{<:Any,<:Any,<:AsyncInsertSampleRatioController}
6565
bind(t.controler.ch_out, task)
6666
end
6767

68+
# !!! by default we assume `x` is a complete example which contains all the traces
69+
# When doing partial inserting, the result of undefined
6870
function Base.push!(t::Trajectory, x)
69-
n_pre = length(t.container)
7071
push!(t.container, x)
71-
n_post = length(t.container)
72-
on_insert!(t.controller, n_post - n_pre)
72+
on_insert!(t.controller, 1)
7373
end
7474

7575
struct CallMsg
@@ -81,13 +81,14 @@ end
8181
Base.push!(t::Trajectory{<:Any,<:Any,<:AsyncInsertSampleRatioController}, args...; kw...) = put!(t.controller.ch_in, CallMsg(Base.push!, args, kw))
8282
Base.append!(t::Trajectory{<:Any,<:Any,<:AsyncInsertSampleRatioController}, args...; kw...) = put!(t.controller.ch_in, CallMsg(Base.append!, args, kw))
8383

84-
function Base.append!(t::Trajectory, x)
85-
n_pre = length(t.container)
84+
function Base.append!(t::Trajectory, x::AbstractVector)
8685
append!(t.container, x)
87-
n_post = length(t.container)
88-
on_insert!(t.controller, n_post - n_pre)
86+
on_insert!(t.controller, length(x))
8987
end
9088

89+
# !!! bypass the controller
90+
sample(t::Trajectory) = sample(t.sampler, t.container)
91+
9192
function Base.take!(t::Trajectory)
9293
res = on_sample!(t.controller)
9394
if isnothing(res)

0 commit comments

Comments
 (0)