Skip to content

Commit a02c055

Browse files
committed
changed bayes_risk and expected_info_gain to use model domain
1 parent 297b3e5 commit a02c055

File tree

1 file changed

+8
-6
lines changed

1 file changed

+8
-6
lines changed

src/qinfer/smc.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -556,8 +556,9 @@ def bayes_risk(self, expparams):
556556
has shape ``(expparams.size,)``
557557
"""
558558

559-
# number of outcomes for the first experiment
560-
n_out = self.model.n_outcomes(np.atleast_1d(expparams)[0])
559+
# outcomes for the first experiment
560+
os = self.model.domain(None).values
561+
n_out = os.size
561562

562563
# for models whose outcome number changes with experiment, we
563564
# take the easy way out and for-loop over experiments
@@ -572,7 +573,7 @@ def bayes_risk(self, expparams):
572573
# every possible outcome and expparam
573574
# the likelihood over outcomes should sum to 1, so don't compute for last outcome
574575
w_hyp, L, N = self.hypothetical_update(
575-
np.arange(n_out-1),
576+
os[:-1],
576577
expparams,
577578
return_normalization=True,
578579
return_likelihood=True
@@ -618,22 +619,23 @@ def expected_information_gain(self, expparams):
618619
# in which the other distribution is guaranteed to share support.
619620

620621
# number of outcomes for the first experiment
621-
n_out = self.model.n_outcomes(np.atleast_1d(expparams)[0])
622+
os = self.model.domain(None).values
623+
n_out = os.size
622624

623625
# for models whose outcome number changes with experiment, we
624626
# take the easy way out and for-loop over experiments
625627
n_eps = expparams.size
626628
if n_eps > 1 and not np.array_equal(self.model.n_outcomes(expparams), n_out):
627629
risk = np.empty(n_eps)
628630
for idx in range(n_eps):
629-
risk[idx] = self.bayes_risk(expparams[idx, np.newaxis])
631+
risk[idx] = self.expected_information_gain(expparams[idx, np.newaxis])
630632
return risk
631633

632634
# compute the hypothetical weights, likelihoods and normalizations for
633635
# every possible outcome and expparam
634636
# the likelihood over outcomes should sum to 1, so don't compute for last outcome
635637
w_hyp, L, N = self.hypothetical_update(
636-
np.arange(n_out-1),
638+
os[:-1],
637639
expparams,
638640
return_normalization=True,
639641
return_likelihood=True

0 commit comments

Comments
 (0)