@@ -556,17 +556,17 @@ def bayes_risk(self, expparams):
556556 has shape ``(expparams.size,)``
557557 """
558558
559+ # number of outcomes for the first experiment
560+ n_out = self .model .n_outcomes (np .atleast_1d (expparams )[0 ])
561+
559562 # for models whose outcome number changes with experiment, we
560563 # take the easy way out and for-loop over experiments
561564 n_eps = expparams .size
562- if n_eps > 1 and not self .model .is_n_outcomes_constant :
565+ if n_eps > 1 and not np . array_equal ( self .model .n_outcomes ( expparams ), n_out ) :
563566 risk = np .empty (n_eps )
564567 for idx in range (n_eps ):
565568 risk [idx ] = self .bayes_risk (expparams [idx , np .newaxis ])
566569 return risk
567-
568- # but if we make it here, the following should be a single number
569- n_out = self .model .n_outcomes (expparams )
570570
571571 # compute the hypothetical weights, likelihoods and normalizations for
572572 # every possible outcome and expparam
@@ -617,17 +617,17 @@ def expected_information_gain(self, expparams):
617617 # This is a special case of the KL divergence estimator (see below),
618618 # in which the other distribution is guaranteed to share support.
619619
620+ # number of outcomes for the first experiment
621+ n_out = self .model .n_outcomes (np .atleast_1d (expparams )[0 ])
622+
620623 # for models whose outcome number changes with experiment, we
621624 # take the easy way out and for-loop over experiments
622625 n_eps = expparams .size
623- if n_eps > 1 and not self .model .is_n_outcomes_constant :
624- ig = np .empty (n_eps )
626+ if n_eps > 1 and not np . array_equal ( self .model .n_outcomes ( expparams ), n_out ) :
627+ risk = np .empty (n_eps )
625628 for idx in range (n_eps ):
626- ig [idx ] = self .expected_information_gain (expparams [idx , np .newaxis ])
627- return ig
628-
629- # but if we make it here, the following should be a single number
630- n_out = self .model .n_outcomes (expparams )
629+ risk [idx ] = self .bayes_risk (expparams [idx , np .newaxis ])
630+ return risk
631631
632632 # compute the hypothetical weights, likelihoods and normalizations for
633633 # every possible outcome and expparam
0 commit comments