1- export Trajectory, InsertSampleRatioControler
1+ export Trajectory
22
33using Base. Threads
44
55
6- # ####
7-
8- mutable struct InsertSampleRatioControler
9- ratio:: Float64
10- threshold:: Int
11- n_inserted:: Int
12- n_sampled:: Int
13- end
14-
15- """
16- InsertSampleRatioControler(ratio, threshold)
17-
18- Used in [`Trajectory`](@ref). The `threshold` means the minimal number of
19- insertings before sampling. The `ratio` balances the number of insertings and
20- the number of samplings.
21- """
22- InsertSampleRatioControler (ratio, threshold) = InsertSampleRatioControler (ratio, threshold, 0 , 0 )
23-
24- function on_insert! (c:: InsertSampleRatioControler , n:: Int )
25- if n > 0
26- c. n_inserted += n
27- end
28- end
29-
30- function on_sample! (c:: InsertSampleRatioControler )
31- if c. n_inserted >= c. threshold
32- if c. n_sampled < (c. n_inserted - n. threshold) * c. ratio
33- c. n_sampled += 1
34- true
35- end
36- end
37- end
38-
39- # ####
40-
41- mutable struct AsyncInsertSampleRatioControler
42- ratio:: Float64
43- threshold:: Int
44- n_inserted:: Int
45- n_sampled:: Int
46- ch_in:: Channel
47- ch_out:: Channel
48- end
49-
50- function AsyncInsertSampleRatioControler (
51- ratio,
52- threshold,
53- ; ch_in_sz= 1 ,
54- ch_out_sz= 1 ,
55- n_inserted= 0 ,
56- n_sampled= 0
57- )
58- AsyncInsertSampleRatioControler (
59- ratio,
60- threshold,
61- n_inserted,
62- n_sampled,
63- Channel (ch_in_sz),
64- Channel (ch_out_sz)
65- )
66- end
67-
68- # ####
69-
706"""
717 Trajectory(container, sampler, controler)
728
@@ -91,18 +27,18 @@ Base.@kwdef struct Trajectory{C,S,T}
9127
9228 function Trajectory (container:: C , sampler:: S , controler:: T ) where {C,S,T<: AsyncInsertSampleRatioControler }
9329 t = Threads. @spawn while true
94- for msg in controler. in
30+ for msg in controler. ch_in
9531 if msg. f === Base. push! || msg. f === Base. append!
96- n_pre = length (trajectory )
97- msg. f (trajectory , msg. args... ; msg. kw... )
98- n_post = length (trajectory )
32+ n_pre = length (container )
33+ msg. f (container , msg. args... ; msg. kw... )
34+ n_post = length (container )
9935 controler. n_inserted += n_post - n_pre
10036 else
101- msg. f (trajectory , msg. args... ; msg. kw... )
37+ msg. f (container , msg. args... ; msg. kw... )
10238 end
10339
10440 if controler. n_inserted >= controler. threshold
105- if controler. n_sampled < (controler. n_inserted - controler. threshold) * controler. ratio
41+ if controler. n_sampled <= (controler. n_inserted - controler. threshold) * controler. ratio
10642 batch = sample (sampler, container)
10743 put! (controler. ch_out, batch)
10844 controler. n_sampled += 1
@@ -111,8 +47,8 @@ Base.@kwdef struct Trajectory{C,S,T}
11147 end
11248 end
11349
114- bind (controler. in , t)
115- bind (controler. out , t)
50+ bind (controler. ch_in , t)
51+ bind (controler. ch_out , t)
11652 new {C,S,T} (container, sampler, controler)
11753 end
11854end
@@ -134,7 +70,7 @@ struct CallMsg
13470end
13571
13672Base. push! (t:: Trajectory{<:Any,<:Any,<:AsyncInsertSampleRatioControler} , args... ; kw... ) = put! (t. controler. ch_in, CallMsg (Base. push!, args, kw))
137- Base. append! (t:: Trajectory{<:Any,<:Any,<:AsyncInsertSampleRatioControler} , args... ; kw... ) = append ! (t. controler. ch_in, CallMsg (Base. push !, args, kw))
73+ Base. append! (t:: Trajectory{<:Any,<:Any,<:AsyncInsertSampleRatioControler} , args... ; kw... ) = put ! (t. controler. ch_in, CallMsg (Base. append !, args, kw))
13874
13975Base. append! (t:: Trajectory ; kw... ) = append! (t, values (kw))
14076
0 commit comments