@@ -11,13 +11,12 @@ function resample(
1111 rng:: AbstractRNG ,
1212 resampler:: AbstractResampler ,
1313 states:: ParticleState{PT,WT} ,
14- filter:: AbstractFilter ,
14+ filter:: AbstractFilter ;
15+ weights:: AbstractVector{WT} = StatsBase. weights (states)
1516) where {PT,WT}
16- weights = StatsBase. weights (states)
1717 idxs = sample_ancestors (rng, resampler, weights)
18-
19- new_state = ParticleState (deepcopy (states. particles[idxs]), zeros (WT, length (states)))
20-
18+ new_state = ParticleState (deepcopy (states. particles[idxs]), zeros (WT, length (states)))
19+ reset_weights! (new_state, idxs, filter)
2120 return new_state, idxs
2221end
2322
@@ -26,8 +25,9 @@ function resample(
2625 rng:: AbstractRNG ,
2726 resampler:: AbstractResampler ,
2827 states:: RaoBlackwellisedParticleState{T,M,ZT} ,
28+ :: AbstractFilter ;
29+ weights= StatsBase. weights (states)
2930) where {T,M,ZT}
30- weights = StatsBase. weights (states)
3131 idxs = sample_ancestors (rng, resampler, weights)
3232
3333 new_state = RaoBlackwellisedParticleState (
@@ -39,23 +39,6 @@ function resample(
3939 return new_state, idxs
4040end
4141
42- # TODO : combine this with above definition
43- function resample (
44- rng:: AbstractRNG ,
45- resampler:: AbstractResampler ,
46- states:: RaoBlackwellisedParticleState{T,M,ZT} ,
47- ) where {T,M,ZT}
48- weights = StatsBase. weights (states)
49- idxs = sample_ancestors (rng, resampler, weights)
50-
51- new_state = RaoBlackwellisedParticleState (
52- deepcopy (states. x_particles[:, idxs]),
53- deepcopy (states. z_particles[idxs]),
54- CUDA. zeros (T, length (states)),
55- )
56- return reset_weights! (state, idxs, filter)
57- end
58-
5942# # CONDITIONAL RESAMPLING ##################################################################
6043
6144abstract type AbstractConditionalResampler <: AbstractResampler end
@@ -69,7 +52,7 @@ struct ESSResampler <: AbstractConditionalResampler
6952end
7053
7154function resample (
72- rng:: AbstractRNG , cond_resampler:: ESSResampler , state:: ParticleState{PT,WT}
55+ rng:: AbstractRNG , cond_resampler:: ESSResampler , state:: ParticleState{PT,WT} , filter :: AbstractFilter
7356) where {PT,WT}
7457 n = length (state)
7558 # TODO : computing weights twice. Should create a wrapper to avoid this
@@ -78,7 +61,7 @@ function resample(
7861 @debug " ESS: $ess "
7962
8063 if cond_resampler. threshold * n ≥ ess
81- return resample (rng, cond_resampler. resampler, state)
64+ return resample (rng, cond_resampler. resampler, state, filter; weights = weights )
8265 else
8366 return deepcopy (state), collect (1 : n)
8467 end
0 commit comments