Skip to content

Commit 19f2864

Browse files
committed
added default n_particles to LuiWest resampler
1 parent 12b2612 commit 19f2864

File tree

2 files changed

+14
-4
lines changed

2 files changed

+14
-4
lines changed

src/qinfer/resamplers.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,12 @@ class LiuWestResampler(Resampler):
182182
has zero norm.
183183
:param callable kernel: Callable function ``kernel(*shape)`` that returns samples
184184
from a resampling distribution with mean 0 and variance 1.
185+
:param int n_particles: The default number of particles to draw during
186+
a resampling action. If ``None``, the number of redrawn particles
187+
redrawn will be equal to the number of particle given.
188+
This value of ``n_particles`` can be overridden by any integer
189+
value of ``n_particles`` given to ``__call__``.
190+
185191
186192
.. warning::
187193
@@ -192,9 +198,10 @@ class LiuWestResampler(Resampler):
192198
"""
193199
def __init__(self,
194200
a=0.98, h=None, maxiter=1000, debug=False, postselect=True,
195-
zero_cov_comp=1e-10,
201+
zero_cov_comp=1e-10, n_particles=None,
196202
kernel=np.random.randn
197203
):
204+
self._default_n_particles = n_particles
198205
self.a = a # Implicitly calls the property setter below to set _h.
199206
if h is not None:
200207
self._override_h = True
@@ -244,7 +251,10 @@ def __call__(self, model, particle_weights, particle_locations,
244251
cov = precomputed_cov
245252

246253
if n_particles is None:
247-
n_particles = l.shape[0]
254+
if self._default_n_particles is None:
255+
n_particles = l.shape[0]
256+
else:
257+
n_particles = self._default_n_particles
248258

249259
# parameters in the Liu and West algorithm
250260
a, h = self._a, self._h

src/qinfer/smc.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ def __init__(self,
145145
self.resampler = qinfer.resamplers.LiuWestResampler(a=resample_a)
146146
else:
147147
if resampler is None:
148-
self.resampler = qinfer.resamplers.LiuWestResampler()
148+
self.resampler = qinfer.resamplers.LiuWestResampler(n_particles=n_particles)
149149
else:
150150
self.resampler = resampler
151151

@@ -1613,4 +1613,4 @@ def update(self, outcome, expparams,check_for_resample=True):
16131613

16141614
# We now can update as normal.
16151615
SMCUpdater.update(self, outcome, expparams,check_for_resample=check_for_resample)
1616-
1616+

0 commit comments

Comments
 (0)