@@ -556,8 +556,9 @@ def bayes_risk(self, expparams):
556556 has shape ``(expparams.size,)``
557557 """
558558
559- # number of outcomes for the first experiment
560- n_out = self .model .n_outcomes (np .atleast_1d (expparams )[0 ])
559+ # outcomes for the first experiment
560+ os = self .model .domain (None ).values
561+ n_out = os .size
561562
562563 # for models whose outcome number changes with experiment, we
563564 # take the easy way out and for-loop over experiments
@@ -572,7 +573,7 @@ def bayes_risk(self, expparams):
572573 # every possible outcome and expparam
573574 # the likelihood over outcomes should sum to 1, so don't compute for last outcome
574575 w_hyp , L , N = self .hypothetical_update (
575- np . arange ( n_out - 1 ) ,
576+ os [: - 1 ] ,
576577 expparams ,
577578 return_normalization = True ,
578579 return_likelihood = True
@@ -618,22 +619,23 @@ def expected_information_gain(self, expparams):
618619 # in which the other distribution is guaranteed to share support.
619620
620621 # number of outcomes for the first experiment
621- n_out = self .model .n_outcomes (np .atleast_1d (expparams )[0 ])
622+ os = self .model .domain (None ).values
623+ n_out = os .size
622624
623625 # for models whose outcome number changes with experiment, we
624626 # take the easy way out and for-loop over experiments
625627 n_eps = expparams .size
626628 if n_eps > 1 and not np .array_equal (self .model .n_outcomes (expparams ), n_out ):
627629 risk = np .empty (n_eps )
628630 for idx in range (n_eps ):
629- risk [idx ] = self .bayes_risk (expparams [idx , np .newaxis ])
631+ risk [idx ] = self .expected_information_gain (expparams [idx , np .newaxis ])
630632 return risk
631633
632634 # compute the hypothetical weights, likelihoods and normalizations for
633635 # every possible outcome and expparam
634636 # the likelihood over outcomes should sum to 1, so don't compute for last outcome
635637 w_hyp , L , N = self .hypothetical_update (
636- np . arange ( n_out - 1 ) ,
638+ os [: - 1 ] ,
637639 expparams ,
638640 return_normalization = True ,
639641 return_likelihood = True
0 commit comments