@@ -932,16 +932,15 @@ def sample(
932932 if sample_sigma2_leaf_tau is not None :
933933 if not isinstance (sample_sigma2_leaf_tau , bool ):
934934 raise ValueError ("sample_sigma2_leaf_tau must be a bool" )
935- if propensity_covariate is not None :
936- if propensity_covariate not in [
937- "prognostic" ,
938- "treatment_effect" ,
939- "both" ,
940- "none" ,
941- ]:
942- raise ValueError (
943- "propensity_covariate must be one of 'prognostic', 'treatment_effect', 'both', or 'none'"
944- )
935+ if propensity_covariate not in [
936+ "prognostic" ,
937+ "treatment_effect" ,
938+ "both" ,
939+ "none" ,
940+ ]:
941+ raise ValueError (
942+ "propensity_covariate must be one of 'prognostic', 'treatment_effect', 'both', or 'none'"
943+ )
945944 if b_0 is not None :
946945 b_0 = check_scalar (
947946 x = b_0 ,
@@ -1663,15 +1662,6 @@ def sample(
16631662 ] = 0
16641663
16651664 # Update covariates to include propensities if requested
1666- if propensity_covariate not in [
1667- "none" ,
1668- "prognostic" ,
1669- "treatment_effect" ,
1670- "both" ,
1671- ]:
1672- raise ValueError (
1673- "propensity_covariate must equal one of 'none', 'prognostic', 'treatment_effect', or 'both'"
1674- )
16751665 if propensity_covariate != "none" :
16761666 feature_types = np .append (
16771667 feature_types , np .repeat (0 , propensity_train .shape [1 ])
@@ -1700,9 +1690,10 @@ def sample(
17001690 variable_weights_tau = np .append (
17011691 variable_weights_tau , np .repeat (1 / num_cov_orig , propensity_train .shape [1 ])
17021692 )
1703- variable_weights_variance = np .append (
1704- variable_weights_variance , np .repeat (0.0 , propensity_train .shape [1 ])
1705- )
1693+ # For now, propensities are not included in the variance forest
1694+ variable_weights_variance = np .append (
1695+ variable_weights_variance , np .repeat (0.0 , propensity_train .shape [1 ])
1696+ )
17061697
17071698 # Renormalize variable weights
17081699 variable_weights_mu = variable_weights_mu / np .sum (variable_weights_mu )
0 commit comments