Skip to content

Commit 76d47c5

Browse files
authored
Merge pull request #232 from StochasticTree/predict-yhat-hotfix
Fixed yhat bug in adaptive coding BCF
2 parents 4f1c3e8 + 9eb1f3f commit 76d47c5

File tree

4 files changed

+166
-2
lines changed

4 files changed

+166
-2
lines changed

R/bcf.R

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2474,6 +2474,8 @@ bcf <- function(
24742474
)
24752475
tau_hat_train <- t(t(tau_hat_train_raw) * (b_1_samples - b_0_samples)) *
24762476
y_std_train
2477+
control_adj_train <- t(t(tau_hat_train_raw) * b_0_samples) * y_std_train
2478+
mu_hat_train <- mu_hat_train + control_adj_train
24772479
} else {
24782480
tau_hat_train <- forest_samples_tau$predict_raw(forest_dataset_train) *
24792481
y_std_train
@@ -2508,6 +2510,8 @@ bcf <- function(
25082510
t(tau_hat_test_raw) * (b_1_samples - b_0_samples)
25092511
) *
25102512
y_std_train
2513+
control_adj_test <- t(t(tau_hat_test_raw) * b_0_samples) * y_std_train
2514+
mu_hat_test <- mu_hat_test + control_adj_test
25112515
} else {
25122516
tau_hat_test <- forest_samples_tau$predict_raw(
25132517
forest_dataset_test
@@ -2849,10 +2853,11 @@ predict.bcfmodel <- function(
28492853
"all"
28502854
))
28512855
) {
2852-
stop(paste0(
2856+
warning(paste0(
28532857
"Term '",
28542858
term,
2855-
"' was requested. Valid terms are 'y_hat', 'prognostic_function', 'mu', 'cate', 'tau', 'rfx', 'variance_forest', and 'all'."
2859+
"' was requested. Valid terms are 'y_hat', 'prognostic_function', 'mu', 'cate', 'tau', 'rfx', 'variance_forest', and 'all'.",
2860+
" This term will be ignored and prediction will only proceed if other requested terms are available in the model."
28562861
))
28572862
}
28582863
}
@@ -3056,6 +3061,8 @@ predict.bcfmodel <- function(
30563061
t(tau_hat_raw) * (object$b_1_samples - object$b_0_samples)
30573062
) *
30583063
y_std
3064+
control_adj <- t(t(tau_hat_raw) * object$b_0_samples) * y_std
3065+
mu_hat_forest <- mu_hat_forest + control_adj
30593066
} else {
30603067
tau_hat_forest <- object$forests_tau$predict_raw(forest_dataset_pred) *
30613068
y_std

demo/debug/bcf_pred_rmse.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
# Debug simulation: for repeated synthetic datasets, fit a BCF model and compare
# the RMSE of the test-set predictions cached during sampling against the RMSE
# of predictions recomputed by predict() on the same test set.
from stochtree import BCFModel
import numpy as np
from sklearn.model_selection import train_test_split
from scipy.stats import norm


def _rmse(estimate, truth):
    """Return the root mean squared error between `estimate` and `truth`."""
    return np.sqrt(np.mean(np.power(estimate - truth, 2.0)))


# Simulation parameters
n = 250
p = 50
n_sim = 100
test_set_pct = 0.2
snr = 2.0
rng = np.random.default_rng()

# Simulation containers (one RMSE per simulated dataset)
rmses_cached = np.empty(n_sim)
rmses_pred = np.empty(n_sim)

# Run the simulation
for sim in range(n_sim):
    # Generate covariates, prognostic / treatment functions and propensities
    X = rng.normal(loc=0.0, scale=1.0, size=(n, p))
    mu_X = X[:, 0]
    tau_X = 0.25 * X[:, 1]
    # NOTE(review): the propensity depends on X[:, 1] (the tau covariate) here,
    # while the R counterpart of this script uses its first covariate -- confirm
    # which is intended.
    pi_X = norm.cdf(0.5 * X[:, 1])
    Z = rng.binomial(n=1, p=pi_X, size=(n,))

    # Conditional mean and noisy outcome at the requested signal-to-noise ratio
    E_XZ = mu_X + tau_X * Z
    noise_sd = np.std(E_XZ) / snr
    y = E_XZ + rng.normal(loc=0.0, scale=noise_sd, size=(n,))

    # Train-test split on row indices
    train_inds, test_inds = train_test_split(np.arange(n), test_size=test_set_pct)
    X_train, X_test = X[train_inds, :], X[test_inds, :]
    Z_train, Z_test = Z[train_inds], Z[test_inds]
    pi_train, pi_test = pi_X[train_inds], pi_X[test_inds]
    tau_train, tau_test = tau_X[train_inds], tau_X[test_inds]
    mu_train, mu_test = mu_X[train_inds], mu_X[test_inds]
    y_train, y_test = y[train_inds], y[test_inds]
    E_XZ_train, E_XZ_test = E_XZ[train_inds], E_XZ[test_inds]

    # Fit simple BCF model, caching test-set predictions during sampling
    bcf_model = BCFModel()
    bcf_model.sample(
        X_train=X_train,
        Z_train=Z_train,
        pi_train=pi_train,
        y_train=y_train,
        X_test=X_test,
        Z_test=Z_test,
        pi_test=pi_test,
    )

    # Predict out of sample through the public predict() path
    y_hat_test = bcf_model.predict(
        X=X_test, Z=Z_test, propensity=pi_test, type="mean", terms="y_hat"
    )

    # Compute RMSE using both cached predictions and those returned by predict()
    cached_mean = np.mean(bcf_model.y_hat_test, axis=1)
    rmses_cached[sim] = _rmse(cached_mean, E_XZ_test)
    rmses_pred[sim] = _rmse(y_hat_test, E_XZ_test)

print(f"Average RMSE, cached: {np.mean(rmses_cached):.4f}, out-of-sample pred: {np.mean(rmses_pred):.4f}")

stochtree/bcf.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2267,7 +2267,12 @@ def sample(
22672267
adaptive_coding_weights = np.expand_dims(
22682268
self.b1_samples - self.b0_samples, axis=(0, 2)
22692269
)
2270+
b0_weights = np.expand_dims(
2271+
self.b0_samples, axis=(0, 2)
2272+
)
2273+
control_adj_train = self.tau_hat_train * b0_weights * self.y_std
22702274
self.tau_hat_train = self.tau_hat_train * adaptive_coding_weights
2275+
self.mu_hat_train = self.mu_hat_train + np.squeeze(control_adj_train)
22712276
self.tau_hat_train = np.squeeze(self.tau_hat_train * self.y_std)
22722277
if self.multivariate_treatment:
22732278
treatment_term_train = np.multiply(
@@ -2289,7 +2294,12 @@ def sample(
22892294
adaptive_coding_weights_test = np.expand_dims(
22902295
self.b1_samples - self.b0_samples, axis=(0, 2)
22912296
)
2297+
b0_weights = np.expand_dims(
2298+
self.b0_samples, axis=(0, 2)
2299+
)
2300+
control_adj_test = self.tau_hat_test * b0_weights * self.y_std
22922301
self.tau_hat_test = self.tau_hat_test * adaptive_coding_weights_test
2302+
self.mu_hat_test = self.mu_hat_test + np.squeeze(control_adj_test)
22932303
self.tau_hat_test = np.squeeze(self.tau_hat_test * self.y_std)
22942304
if self.multivariate_treatment:
22952305
treatment_term_test = np.multiply(
@@ -2594,7 +2604,12 @@ def predict(
25942604
adaptive_coding_weights = np.expand_dims(
25952605
self.b1_samples - self.b0_samples, axis=(0, 2)
25962606
)
2607+
b0_weights = np.expand_dims(
2608+
self.b0_samples, axis=(0, 2)
2609+
)
2610+
control_adj = tau_raw * b0_weights * self.y_std
25972611
tau_raw = tau_raw * adaptive_coding_weights
2612+
mu_x_forest = mu_x_forest + np.squeeze(control_adj)
25982613
tau_x_forest = np.squeeze(tau_raw * self.y_std)
25992614
if Z.shape[1] > 1:
26002615
treatment_term = np.multiply(

tools/simulations/bcf-pred-rmse.R

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
# Simulation study: for repeated synthetic datasets, fit a BCF model and
# compare the RMSE of the test-set predictions cached at fit time
# (`bcf_model$y_hat_test`) with the RMSE of predictions recomputed by
# `predict()` on the same test set.

# Load library
library(stochtree)

# Simulation parameters
n <- 250
p <- 50
n_sim <- 100
test_set_pct <- 0.2
snr <- 2

# Simulation containers (one RMSE per simulated dataset)
rmses_cached <- rep(NA_real_, n_sim)
rmses_pred <- rep(NA_real_, n_sim)

# Run the simulation
for (i in seq_len(n_sim)) {
  # Generate data
  X <- matrix(rnorm(n * p), ncol = p)
  mu_x <- X[, 1]
  tau_x <- 0.25 * X[, 2]
  pi_x <- pnorm(0.5 * X[, 1])
  Z <- rbinom(n, 1, pi_x)
  E_XZ <- mu_x + Z * tau_x
  y <- E_XZ + rnorm(n, 0, 1) * (sd(E_XZ) / snr)

  # Train-test split
  n_test <- round(test_set_pct * n)
  n_train <- n - n_test
  test_inds <- sort(sample(seq_len(n), n_test, replace = FALSE))
  train_inds <- setdiff(seq_len(n), test_inds)
  # drop = FALSE keeps the covariates a matrix even for a single-row split
  X_test <- X[test_inds, , drop = FALSE]
  X_train <- X[train_inds, , drop = FALSE]
  pi_test <- pi_x[test_inds]
  pi_train <- pi_x[train_inds]
  Z_test <- Z[test_inds]
  Z_train <- Z[train_inds]
  y_test <- y[test_inds]
  y_train <- y[train_inds]
  mu_test <- mu_x[test_inds]
  mu_train <- mu_x[train_inds]
  tau_test <- tau_x[test_inds]
  tau_train <- tau_x[train_inds]
  E_XZ_test <- E_XZ[test_inds]
  E_XZ_train <- E_XZ[train_inds]

  # Fit a simple BCF model, caching test-set predictions during sampling
  bcf_model <- bcf(
    X_train = X_train,
    Z_train = Z_train,
    y_train = y_train,
    propensity_train = pi_train,
    X_test = X_test,
    Z_test = Z_test,
    propensity_test = pi_test
  )

  # Predict out of sample through the public predict() method
  y_hat_test <- predict(
    bcf_model,
    X = X_test,
    Z = Z_test,
    propensity = pi_test,
    type = "mean",
    terms = "y_hat"
  )

  # Compute RMSE using both cached predictions and those returned by predict()
  rmses_cached[i] <- sqrt(mean((rowMeans(bcf_model$y_hat_test) - E_XZ_test)^2))
  rmses_pred[i] <- sqrt(mean((y_hat_test - E_XZ_test)^2))
}

# Report results explicitly so they are shown even when run via source()
cat(sprintf(
  "Average RMSE, cached: %.4f, out-of-sample pred: %.4f\n",
  mean(rmses_cached),
  mean(rmses_pred)
))

0 commit comments

Comments
 (0)