Skip to content

Commit 091be91

Browse files
authored
Merge pull request #119 from ihincks/resampler-default-nparticles
Feature: Lui-West default number of particles
2 parents 3cbf3b9 + 06a1018 commit 091be91

File tree

2 files changed

+15
-4
lines changed

2 files changed

+15
-4
lines changed

src/qinfer/resamplers.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,12 @@ class LiuWestResampler(Resampler):
182182
has zero norm.
183183
:param callable kernel: Callable function ``kernel(*shape)`` that returns samples
184184
from a resampling distribution with mean 0 and variance 1.
185+
:param int default_n_particles: The default number of particles to draw during
186+
a resampling action. If ``None``, the number of redrawn particles
187+
redrawn will be equal to the number of particles given.
188+
The value of ``default_n_particles`` can be overridden by any integer
189+
value of ``n_particles`` given to ``__call__``.
190+
185191
186192
.. warning::
187193
@@ -192,9 +198,11 @@ class LiuWestResampler(Resampler):
192198
"""
193199
def __init__(self,
194200
a=0.98, h=None, maxiter=1000, debug=False, postselect=True,
195-
zero_cov_comp=1e-10,
201+
zero_cov_comp=1e-10,
202+
default_n_particles=None,
196203
kernel=np.random.randn
197204
):
205+
self._default_n_particles = default_n_particles
198206
self.a = a # Implicitly calls the property setter below to set _h.
199207
if h is not None:
200208
self._override_h = True
@@ -244,7 +252,10 @@ def __call__(self, model, particle_weights, particle_locations,
244252
cov = precomputed_cov
245253

246254
if n_particles is None:
247-
n_particles = l.shape[0]
255+
if self._default_n_particles is None:
256+
n_particles = l.shape[0]
257+
else:
258+
n_particles = self._default_n_particles
248259

249260
# parameters in the Liu and West algorithm
250261
a, h = self._a, self._h

src/qinfer/smc.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ def __init__(self,
145145
self.resampler = qinfer.resamplers.LiuWestResampler(a=resample_a)
146146
else:
147147
if resampler is None:
148-
self.resampler = qinfer.resamplers.LiuWestResampler()
148+
self.resampler = qinfer.resamplers.LiuWestResampler(default_n_particles=n_particles)
149149
else:
150150
self.resampler = resampler
151151

@@ -1613,4 +1613,4 @@ def update(self, outcome, expparams,check_for_resample=True):
16131613

16141614
# We now can update as normal.
16151615
SMCUpdater.update(self, outcome, expparams,check_for_resample=check_for_resample)
1616-
1616+

0 commit comments

Comments
 (0)