Skip to content

Commit ae21853

Browse files
authored
Merge pull request #241 from StochasticTree/python-bcf-no-propensity-hotfix
Update Python BCF to work without propensities
2 parents 2878064 + 0a4d8c8 commit ae21853

File tree

2 files changed

+26
-22
lines changed

2 files changed

+26
-22
lines changed

stochtree/bcf.py

Lines changed: 13 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -932,16 +932,15 @@ def sample(
932932
if sample_sigma2_leaf_tau is not None:
933933
if not isinstance(sample_sigma2_leaf_tau, bool):
934934
raise ValueError("sample_sigma2_leaf_tau must be a bool")
935-
if propensity_covariate is not None:
936-
if propensity_covariate not in [
937-
"prognostic",
938-
"treatment_effect",
939-
"both",
940-
"none",
941-
]:
942-
raise ValueError(
943-
"propensity_covariate must be one of 'prognostic', 'treatment_effect', 'both', or 'none'"
944-
)
935+
if propensity_covariate not in [
936+
"prognostic",
937+
"treatment_effect",
938+
"both",
939+
"none",
940+
]:
941+
raise ValueError(
942+
"propensity_covariate must be one of 'prognostic', 'treatment_effect', 'both', or 'none'"
943+
)
945944
if b_0 is not None:
946945
b_0 = check_scalar(
947946
x=b_0,
@@ -1663,15 +1662,6 @@ def sample(
16631662
] = 0
16641663

16651664
# Update covariates to include propensities if requested
1666-
if propensity_covariate not in [
1667-
"none",
1668-
"prognostic",
1669-
"treatment_effect",
1670-
"both",
1671-
]:
1672-
raise ValueError(
1673-
"propensity_covariate must equal one of 'none', 'prognostic', 'treatment_effect', or 'both'"
1674-
)
16751665
if propensity_covariate != "none":
16761666
feature_types = np.append(
16771667
feature_types, np.repeat(0, propensity_train.shape[1])
@@ -1700,9 +1690,10 @@ def sample(
17001690
variable_weights_tau = np.append(
17011691
variable_weights_tau, np.repeat(1 / num_cov_orig, propensity_train.shape[1])
17021692
)
1703-
variable_weights_variance = np.append(
1704-
variable_weights_variance, np.repeat(0.0, propensity_train.shape[1])
1705-
)
1693+
# For now, propensities are not included in the variance forest
1694+
variable_weights_variance = np.append(
1695+
variable_weights_variance, np.repeat(0.0, propensity_train.shape[1])
1696+
)
17061697

17071698
# Renormalize variable weights
17081699
variable_weights_mu = variable_weights_mu / np.sum(variable_weights_mu)

test/python/test_bcf.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,19 @@ def test_binary_bcf(self):
194194
# Check treatment effect prediction method
195195
tau_hat = bcf_model.predict(X=X_test, Z=Z_test, terms="cate")
196196

197+
# Check that we can run BCF without propensities
198+
bcf_model = BCFModel()
199+
general_params = {"propensity_covariate": "none"}
200+
bcf_model.sample(
201+
X_train=X_train,
202+
Z_train=Z_train,
203+
y_train=y_train,
204+
num_gfr=num_gfr,
205+
num_burnin=num_burnin,
206+
num_mcmc=num_mcmc,
207+
general_params=general_params,
208+
)
209+
197210
def test_continuous_univariate_bcf(self):
198211
# RNG
199212
random_seed = 101

0 commit comments

Comments
 (0)