Skip to content

Commit 29b72ea

Browse files
committed
small bug fix
1 parent 7fcc279 commit 29b72ea

File tree

1 file changed

+11
-11
lines changed

1 file changed

+11
-11
lines changed

src/qinfer/smc.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -556,17 +556,17 @@ def bayes_risk(self, expparams):
556556
has shape ``(expparams.size,)``
557557
"""
558558

559+
# number of outcomes for the first experiment
560+
n_out = self.model.n_outcomes(np.atleast_1d(expparams)[0])
561+
559562
# for models whose outcome number changes with experiment, we
560563
# take the easy way out and for-loop over experiments
561564
n_eps = expparams.size
562-
if n_eps > 1 and not self.model.is_n_outcomes_constant:
565+
if n_eps > 1 and not np.array_equal(self.model.n_outcomes(expparams), n_out):
563566
risk = np.empty(n_eps)
564567
for idx in range(n_eps):
565568
risk[idx] = self.bayes_risk(expparams[idx, np.newaxis])
566569
return risk
567-
568-
# but if we make it here, the following should be a single number
569-
n_out = self.model.n_outcomes(expparams)
570570

571571
# compute the hypothetical weights, likelihoods and normalizations for
572572
# every possible outcome and expparam
@@ -617,17 +617,17 @@ def expected_information_gain(self, expparams):
617617
# This is a special case of the KL divergence estimator (see below),
618618
# in which the other distribution is guaranteed to share support.
619619

620+
# number of outcomes for the first experiment
621+
n_out = self.model.n_outcomes(np.atleast_1d(expparams)[0])
622+
620623
# for models whose outcome number changes with experiment, we
621624
# take the easy way out and for-loop over experiments
622625
n_eps = expparams.size
623-
if n_eps > 1 and not self.model.is_n_outcomes_constant:
624-
ig = np.empty(n_eps)
626+
if n_eps > 1 and not np.array_equal(self.model.n_outcomes(expparams), n_out):
627+
risk = np.empty(n_eps)
625628
for idx in range(n_eps):
626-
ig[idx] = self.expected_information_gain(expparams[idx, np.newaxis])
627-
return ig
628-
629-
# but if we make it here, the following should be a single number
630-
n_out = self.model.n_outcomes(expparams)
629+
risk[idx] = self.bayes_risk(expparams[idx, np.newaxis])
630+
return risk
631631

632632
# compute the hypothetical weights, likelihoods and normalizations for
633633
# every possible outcome and expparam

0 commit comments

Comments
 (0)