Skip to content

Commit f574b3d

Browse files
committed
fix tests
1 parent d63fd75 commit f574b3d

File tree

9 files changed

+107
-84
lines changed

9 files changed

+107
-84
lines changed

src/Trajectories.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
module Trajectories
22

33
include("samplers.jl")
4+
include("controlers.jl")
45
include("traces.jl")
56
include("episodes.jl")
67
include("trajectory.jl")
7-
include("async_trajectory.jl")
88
include("rendering.jl")
99
include("common/common.jl")
1010

src/common/CircularArraySARTTraces.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ function Random.rand(s::BatchSampler, t::CircularArraySARTTraces)
4141
terminal=t[:terminal][inds],
4242
next_state=t[:state][inds′],
4343
next_action=t[:state][inds′]
44-
)
44+
) |> s.transformer
4545
end
4646

4747
function Base.push!(t::CircularArraySARTTraces, x::NamedTuple{SA})

src/common/CircularArraySLARTTraces.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ function sample(s::BatchSampler, t::CircularArraySLARTTraces)
4747
next_state=t[:state][inds′],
4848
next_legal_actions_mask=t[:legal_actions_mask][inds′],
4949
next_action=t[:state][inds′]
50-
)
50+
) |> s.transformer
5151
end
5252

5353
function Base.push!(t::CircularArraySLARTTraces, x::NamedTuple{SLA})

src/controlers.jl

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
export InsertSampleRatioControler, AsyncInsertSampleRatioControler
2+
3+
mutable struct InsertSampleRatioControler
4+
ratio::Float64
5+
threshold::Int
6+
n_inserted::Int
7+
n_sampled::Int
8+
end
9+
10+
"""
11+
InsertSampleRatioControler(ratio, threshold)
12+
13+
Used in [`Trajectory`](@ref). The `threshold` means the minimal number of
14+
insertings before sampling. The `ratio` balances the number of insertings and
15+
the number of samplings.
16+
"""
17+
InsertSampleRatioControler(ratio, threshold) = InsertSampleRatioControler(ratio, threshold, 0, 0)
18+
19+
function on_insert!(c::InsertSampleRatioControler, n::Int)
20+
if n > 0
21+
c.n_inserted += n
22+
end
23+
end
24+
25+
function on_sample!(c::InsertSampleRatioControler)
26+
if c.n_inserted >= c.threshold
27+
if c.n_sampled <= (c.n_inserted - c.threshold) * c.ratio
28+
c.n_sampled += 1
29+
true
30+
end
31+
end
32+
end
33+
34+
#####
35+
36+
mutable struct AsyncInsertSampleRatioControler
37+
ratio::Float64
38+
threshold::Int
39+
n_inserted::Int
40+
n_sampled::Int
41+
ch_in::Channel
42+
ch_out::Channel
43+
end
44+
45+
function AsyncInsertSampleRatioControler(
46+
ratio,
47+
threshold,
48+
; ch_in_sz=1,
49+
ch_out_sz=1,
50+
n_inserted=0,
51+
n_sampled=0
52+
)
53+
AsyncInsertSampleRatioControler(
54+
ratio,
55+
threshold,
56+
n_inserted,
57+
n_sampled,
58+
Channel(ch_in_sz),
59+
Channel(ch_out_sz)
60+
)
61+
end

src/episodes.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,5 +103,5 @@ end
103103

104104
function sample(s::BatchSampler, e::Episodes)
105105
inds = rand(s.rng, 1:length(t), s.batch_size)
106-
batch([@view(s.episodes[e.inds[i][1]][e.inds[i][2]]) for i in inds])
106+
batch([@view(s.episodes[e.inds[i][1]][e.inds[i][2]]) for i in inds]) |> s.transformer
107107
end

src/samplers.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,14 @@ using Random
55
struct BatchSampler
66
batch_size::Int
77
rng::Random.AbstractRNG
8+
transformer::Any
89
end
910

1011
"""
11-
BatchSampler(batch_size; rng=Random.GLOBAL_RNG)
12+
BatchSampler(batch_size; rng=Random.GLOBAL_RNG, transformer=identity)
1213
1314
Uniformly sample a batch of examples for each trace.
1415
1516
See also [`sample`](@ref).
1617
"""
17-
BatchSampler(batch_size; rng=Random.GLOBAL_RNG) = BatchSampler(batch_size, rng)
18+
BatchSampler(batch_size; rng=Random.GLOBAL_RNG, transformer=identity) = BatchSampler(batch_size, rng, identity)

src/traces.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ Base.empty!(t::Trace) = empty!(t.x)
3636

3737
function sample(s::BatchSampler, t::Trace)
3838
inds = rand(s.rng, 1:length(t), s.batch_size)
39-
t[inds]
39+
t[inds] |> s.transformer
4040
end
4141

4242
#####
@@ -84,5 +84,5 @@ function sample(s::BatchSampler, t::Traces)
8484
inds = rand(s.rng, 1:length(t), s.batch_size)
8585
map(t.traces) do x
8686
x[inds]
87-
end
87+
end |> s.transformer
8888
end

src/trajectory.jl

Lines changed: 10 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -1,72 +1,8 @@
1-
export Trajectory, InsertSampleRatioControler
1+
export Trajectory
22

33
using Base.Threads
44

55

6-
#####
7-
8-
mutable struct InsertSampleRatioControler
9-
ratio::Float64
10-
threshold::Int
11-
n_inserted::Int
12-
n_sampled::Int
13-
end
14-
15-
"""
16-
InsertSampleRatioControler(ratio, threshold)
17-
18-
Used in [`Trajectory`](@ref). The `threshold` means the minimal number of
19-
insertings before sampling. The `ratio` balances the number of insertings and
20-
the number of samplings.
21-
"""
22-
InsertSampleRatioControler(ratio, threshold) = InsertSampleRatioControler(ratio, threshold, 0, 0)
23-
24-
function on_insert!(c::InsertSampleRatioControler, n::Int)
25-
if n > 0
26-
c.n_inserted += n
27-
end
28-
end
29-
30-
function on_sample!(c::InsertSampleRatioControler)
31-
if c.n_inserted >= c.threshold
32-
if c.n_sampled < (c.n_inserted - n.threshold) * c.ratio
33-
c.n_sampled += 1
34-
true
35-
end
36-
end
37-
end
38-
39-
#####
40-
41-
mutable struct AsyncInsertSampleRatioControler
42-
ratio::Float64
43-
threshold::Int
44-
n_inserted::Int
45-
n_sampled::Int
46-
ch_in::Channel
47-
ch_out::Channel
48-
end
49-
50-
function AsyncInsertSampleRatioControler(
51-
ratio,
52-
threshold,
53-
; ch_in_sz=1,
54-
ch_out_sz=1,
55-
n_inserted=0,
56-
n_sampled=0
57-
)
58-
AsyncInsertSampleRatioControler(
59-
ratio,
60-
threshold,
61-
n_inserted,
62-
n_sampled,
63-
Channel(ch_in_sz),
64-
Channel(ch_out_sz)
65-
)
66-
end
67-
68-
#####
69-
706
"""
717
Trajectory(container, sampler, controler)
728
@@ -91,18 +27,18 @@ Base.@kwdef struct Trajectory{C,S,T}
9127

9228
function Trajectory(container::C, sampler::S, controler::T) where {C,S,T<:AsyncInsertSampleRatioControler}
9329
t = Threads.@spawn while true
94-
for msg in controler.in
30+
for msg in controler.ch_in
9531
if msg.f === Base.push! || msg.f === Base.append!
96-
n_pre = length(trajectory)
97-
msg.f(trajectory, msg.args...; msg.kw...)
98-
n_post = length(trajectory)
32+
n_pre = length(container)
33+
msg.f(container, msg.args...; msg.kw...)
34+
n_post = length(container)
9935
controler.n_inserted += n_post - n_pre
10036
else
101-
msg.f(trajectory, msg.args...; msg.kw...)
37+
msg.f(container, msg.args...; msg.kw...)
10238
end
10339

10440
if controler.n_inserted >= controler.threshold
105-
if controler.n_sampled < (controler.n_inserted - controler.threshold) * controler.ratio
41+
if controler.n_sampled <= (controler.n_inserted - controler.threshold) * controler.ratio
10642
batch = sample(sampler, container)
10743
put!(controler.ch_out, batch)
10844
controler.n_sampled += 1
@@ -111,8 +47,8 @@ Base.@kwdef struct Trajectory{C,S,T}
11147
end
11248
end
11349

114-
bind(controler.in, t)
115-
bind(controler.out, t)
50+
bind(controler.ch_in, t)
51+
bind(controler.ch_out, t)
11652
new{C,S,T}(container, sampler, controler)
11753
end
11854
end
@@ -134,7 +70,7 @@ struct CallMsg
13470
end
13571

13672
Base.push!(t::Trajectory{<:Any,<:Any,<:AsyncInsertSampleRatioControler}, args...; kw...) = put!(t.controler.ch_in, CallMsg(Base.push!, args, kw))
137-
Base.append!(t::Trajectory{<:Any,<:Any,<:AsyncInsertSampleRatioControler}, args...; kw...) = append!(t.controler.ch_in, CallMsg(Base.push!, args, kw))
73+
Base.append!(t::Trajectory{<:Any,<:Any,<:AsyncInsertSampleRatioControler}, args...; kw...) = put!(t.controler.ch_in, CallMsg(Base.append!, args, kw))
13874

13975
Base.append!(t::Trajectory; kw...) = append!(t, values(kw))
14076

test/trajectories.jl

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,15 +30,15 @@
3030
push!(batches, batch)
3131
end
3232

33-
@test length(batches) == 1 # 4 inserted, ratio is 0.25
33+
@test length(batches) == 1 # 4 inserted, threshold is 4, ratio is 0.25
3434

3535
append!(t; a=[5, 6, 7], b=[true, true, true])
3636

3737
for batch in t
3838
push!(batches, batch)
3939
end
4040

41-
@test length(batches) == 2 # 7 inserted, ratio is 0.25
41+
@test length(batches) == 1 # 7 inserted, threshold is 4, ratio is 0.25
4242

4343
push!(t; a=8, b=true)
4444

@@ -58,4 +58,29 @@
5858
s += 1
5959
end
6060
@test s == n
61+
end
62+
63+
@testset "async trajectories" begin
64+
threshould = 100
65+
ratio = 1 / 4
66+
t = Trajectory(
67+
container=Traces(
68+
a=Int[],
69+
b=Bool[]
70+
),
71+
sampler=BatchSampler(3),
72+
controler=AsyncInsertSampleRatioControler(ratio, threshould)
73+
)
74+
75+
n = 100
76+
insert_task = @async for i in 1:n
77+
append!(t; a=[i, i, i, i], b=[false, true, false, true])
78+
end
79+
80+
s = 0
81+
sample_task = @async for _ in t
82+
s += 1
83+
end
84+
sleep(1)
85+
@test s == (n - threshould * ratio) + 1
6186
end

0 commit comments

Comments
 (0)