Skip to content

Commit d296f3e

Browse files
committed
Updated remaining R and Python functions
1 parent f55e349 commit d296f3e

12 files changed

+156
-160
lines changed

R/posterior_transformation.R

Lines changed: 59 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -465,8 +465,8 @@ compute_contrast_bart_model <- function(
465465
#' Sample from the posterior predictive distribution for outcomes modeled by BCF
466466
#'
467467
#' @param model_object A fitted BCF model object of class `bcfmodel`.
468-
#' @param covariates A matrix or data frame of covariates.
469-
#' @param treatment A vector or matrix of treatment assignments.
468+
#' @param X A matrix or data frame of covariates.
469+
#' @param Z A vector or matrix of treatment assignments.
470470
#' @param propensity (Optional) A vector or matrix of propensity scores. Required if the underlying model depends on user-provided propensities.
471471
#' @param rfx_group_ids (Optional) A vector of group IDs for random effects model. Required if the BCF model includes random effects.
472472
#' @param rfx_basis (Optional) A matrix of bases for random effects model. Required if the BCF model includes random effects.
@@ -484,13 +484,13 @@ compute_contrast_bart_model <- function(
484484
#' y <- 2 * X[,2] + 0.5 * X[,2] * Z + rnorm(n)
485485
#' bcf_model <- bcf(X_train = X, Z_train = Z, y_train = y, propensity_train = pi_X)
486486
#' ppd_samples <- sample_bcf_posterior_predictive(
487-
#' model_object = bcf_model, covariates = X,
488-
#' treatment = Z, propensity = pi_X
487+
#' model_object = bcf_model, X = X,
488+
#' Z = Z, propensity = pi_X
489489
#' )
490490
sample_bcf_posterior_predictive <- function(
491491
model_object,
492-
covariates = NULL,
493-
treatment = NULL,
492+
X = NULL,
493+
Z = NULL,
494494
propensity = NULL,
495495
rfx_group_ids = NULL,
496496
rfx_basis = NULL,
@@ -505,33 +505,33 @@ sample_bcf_posterior_predictive <- function(
505505
# Check that all the necessary inputs were provided for posterior predictive sampling
506506
needs_covariates <- TRUE
507507
if (needs_covariates) {
508-
if (is.null(covariates)) {
508+
if (is.null(X)) {
509509
stop(
510-
"'covariates' must be provided in order to compute the requested intervals"
510+
"'X' must be provided in order to compute the requested intervals"
511511
)
512512
}
513-
if (!is.matrix(covariates) && !is.data.frame(covariates)) {
514-
stop("'covariates' must be a matrix or data frame")
513+
if (!is.matrix(X) && !is.data.frame(X)) {
514+
stop("'X' must be a matrix or data frame")
515515
}
516516
}
517517
needs_treatment <- needs_covariates
518518
if (needs_treatment) {
519-
if (is.null(treatment)) {
519+
if (is.null(Z)) {
520520
stop(
521-
"'treatment' must be provided in order to compute the requested intervals"
521+
"'Z' must be provided in order to compute the requested intervals"
522522
)
523523
}
524-
if (!is.matrix(treatment) && !is.numeric(treatment)) {
525-
stop("'treatment' must be a numeric vector or matrix")
524+
if (!is.matrix(Z) && !is.numeric(Z)) {
525+
stop("'Z' must be a numeric vector or matrix")
526526
}
527-
if (is.matrix(treatment)) {
528-
if (nrow(treatment) != nrow(covariates)) {
529-
stop("'treatment' must have the same number of rows as 'covariates'")
527+
if (is.matrix(Z)) {
528+
if (nrow(Z) != nrow(X)) {
529+
stop("'Z' must have the same number of rows as 'X'")
530530
}
531531
} else {
532-
if (length(treatment) != nrow(covariates)) {
532+
if (length(Z) != nrow(X)) {
533533
stop(
534-
"'treatment' must have the same number of elements as 'covariates'"
534+
"'Z' must have the same number of elements as 'X'"
535535
)
536536
}
537537
}
@@ -551,13 +551,13 @@ sample_bcf_posterior_predictive <- function(
551551
stop("'propensity' must be a numeric vector or matrix")
552552
}
553553
if (is.matrix(propensity)) {
554-
if (nrow(propensity) != nrow(covariates)) {
555-
stop("'propensity' must have the same number of rows as 'covariates'")
554+
if (nrow(propensity) != nrow(X)) {
555+
stop("'propensity' must have the same number of rows as 'X'")
556556
}
557557
} else {
558-
if (length(propensity) != nrow(covariates)) {
558+
if (length(propensity) != nrow(X)) {
559559
stop(
560-
"'propensity' must have the same number of elements as 'covariates'"
560+
"'propensity' must have the same number of elements as 'X'"
561561
)
562562
}
563563
}
@@ -569,9 +569,9 @@ sample_bcf_posterior_predictive <- function(
569569
"'rfx_group_ids' must be provided in order to compute the requested intervals"
570570
)
571571
}
572-
if (length(rfx_group_ids) != nrow(covariates)) {
572+
if (length(rfx_group_ids) != nrow(X)) {
573573
stop(
574-
"'rfx_group_ids' must have the same length as the number of rows in 'covariates'"
574+
"'rfx_group_ids' must have the same length as the number of rows in 'X'"
575575
)
576576
}
577577
if (is.null(rfx_basis)) {
@@ -582,16 +582,16 @@ sample_bcf_posterior_predictive <- function(
582582
if (!is.matrix(rfx_basis)) {
583583
stop("'rfx_basis' must be a matrix")
584584
}
585-
if (nrow(rfx_basis) != nrow(covariates)) {
586-
stop("'rfx_basis' must have the same number of rows as 'covariates'")
585+
if (nrow(rfx_basis) != nrow(X)) {
586+
stop("'rfx_basis' must have the same number of rows as 'X'")
587587
}
588588
}
589589

590590
# Compute posterior samples
591591
bcf_preds <- predict(
592592
model_object,
593-
X = covariates,
594-
Z = treatment,
593+
X = X,
594+
Z = Z,
595595
propensity = propensity,
596596
rfx_group_ids = rfx_group_ids,
597597
rfx_basis = rfx_basis,
@@ -605,7 +605,7 @@ sample_bcf_posterior_predictive <- function(
605605
has_variance_forest <- model_object$model_params$include_variance_forest
606606
samples_global_variance <- model_object$model_params$sample_sigma2_global
607607
num_posterior_draws <- model_object$model_params$num_samples
608-
num_observations <- nrow(covariates)
608+
num_observations <- nrow(X)
609609
ppd_mean <- bcf_preds$y_hat
610610
if (has_variance_forest) {
611611
ppd_variance <- bcf_preds$variance_forest_predictions
@@ -840,8 +840,8 @@ posterior_predictive_heuristic_multiplier <- function(
840840
#' @param terms A character string specifying the model term(s) for which to compute intervals. Options for BCF models are `"prognostic_function"`, `"mu"`, `"cate"`, `"tau"`, `"variance_forest"`, `"rfx"`, or `"y_hat"`. Note that `"mu"` is only different from `"prognostic_function"` if random effects are included with a model spec of `"intercept_only"` or `"intercept_plus_treatment"` and `"tau"` is only different from `"cate"` if random effects are included with a model spec of `"intercept_plus_treatment"`.
841841
#' @param level A numeric value between 0 and 1 specifying the credible interval level (default is 0.95 for a 95% credible interval).
842842
#' @param scale (Optional) Scale of mean function predictions. Options are "linear", which returns predictions on the original scale of the mean forest / RFX terms, and "probability", which transforms predictions into a probability of observing `y == 1`. "probability" is only valid for models fit with a probit outcome model. Default: "linear".
843-
#' @param covariates (Optional) A matrix or data frame of covariates at which to compute the intervals. Required if the requested term depends on covariates (e.g., prognostic forest, CATE forest, variance forest, or overall predictions).
844-
#' @param treatment (Optional) A vector or matrix of treatment assignments. Required if the requested term is `"y_hat"` (overall predictions).
843+
#' @param X (Optional) A matrix or data frame of covariates at which to compute the intervals. Required if the requested term depends on covariates (e.g., prognostic forest, CATE forest, variance forest, or overall predictions).
844+
#' @param Z (Optional) A vector or matrix of treatment assignments. Required whenever `X` is required, i.e. whenever the requested term depends on covariates (including `"y_hat"`, overall predictions).
845845
#' @param propensity (Optional) A vector or matrix of propensity scores. Required if the underlying model depends on user-provided propensities.
846846
#' @param rfx_group_ids An optional vector of group IDs for random effects. Required if the requested term includes random effects.
847847
#' @param rfx_basis An optional matrix of basis function evaluations for random effects. Required if the requested term includes random effects.
@@ -863,8 +863,8 @@ posterior_predictive_heuristic_multiplier <- function(
863863
#' intervals <- compute_bcf_posterior_interval(
864864
#' model_object = bcf_model,
865865
#' terms = c("prognostic_function", "cate"),
866-
#' covariates = X,
867-
#' treatment = Z,
866+
#' X = X,
867+
#' Z = Z,
868868
#' propensity = pi_X,
869869
#' level = 0.90
870870
#' )
@@ -873,8 +873,8 @@ compute_bcf_posterior_interval <- function(
873873
terms,
874874
level = 0.95,
875875
scale = "linear",
876-
covariates = NULL,
877-
treatment = NULL,
876+
X = NULL,
877+
Z = NULL,
878878
propensity = NULL,
879879
rfx_group_ids = NULL,
880880
rfx_basis = NULL
@@ -930,33 +930,33 @@ compute_bcf_posterior_interval <- function(
930930
("variance_forest" %in% terms) ||
931931
(needs_covariates_intermediate))
932932
if (needs_covariates) {
933-
if (is.null(covariates)) {
933+
if (is.null(X)) {
934934
stop(
935-
"'covariates' must be provided in order to compute the requested intervals"
935+
"'X' must be provided in order to compute the requested intervals"
936936
)
937937
}
938-
if (!is.matrix(covariates) && !is.data.frame(covariates)) {
939-
stop("'covariates' must be a matrix or data frame")
938+
if (!is.matrix(X) && !is.data.frame(X)) {
939+
stop("'X' must be a matrix or data frame")
940940
}
941941
}
942942
needs_treatment <- needs_covariates
943943
if (needs_treatment) {
944-
if (is.null(treatment)) {
944+
if (is.null(Z)) {
945945
stop(
946-
"'treatment' must be provided in order to compute the requested intervals"
946+
"'Z' must be provided in order to compute the requested intervals"
947947
)
948948
}
949-
if (!is.matrix(treatment) && !is.numeric(treatment)) {
950-
stop("'treatment' must be a numeric vector or matrix")
949+
if (!is.matrix(Z) && !is.numeric(Z)) {
950+
stop("'Z' must be a numeric vector or matrix")
951951
}
952-
if (is.matrix(treatment)) {
953-
if (nrow(treatment) != nrow(covariates)) {
954-
stop("'treatment' must have the same number of rows as 'covariates'")
952+
if (is.matrix(Z)) {
953+
if (nrow(Z) != nrow(X)) {
954+
stop("'Z' must have the same number of rows as 'X'")
955955
}
956956
} else {
957-
if (length(treatment) != nrow(covariates)) {
957+
if (length(Z) != nrow(X)) {
958958
stop(
959-
"'treatment' must have the same number of elements as 'covariates'"
959+
"'Z' must have the same number of elements as 'X'"
960960
)
961961
}
962962
}
@@ -976,13 +976,13 @@ compute_bcf_posterior_interval <- function(
976976
stop("'propensity' must be a numeric vector or matrix")
977977
}
978978
if (is.matrix(propensity)) {
979-
if (nrow(propensity) != nrow(covariates)) {
980-
stop("'propensity' must have the same number of rows as 'covariates'")
979+
if (nrow(propensity) != nrow(X)) {
980+
stop("'propensity' must have the same number of rows as 'X'")
981981
}
982982
} else {
983-
if (length(propensity) != nrow(covariates)) {
983+
if (length(propensity) != nrow(X)) {
984984
stop(
985-
"'propensity' must have the same number of elements as 'covariates'"
985+
"'propensity' must have the same number of elements as 'X'"
986986
)
987987
}
988988
}
@@ -998,9 +998,9 @@ compute_bcf_posterior_interval <- function(
998998
"'rfx_group_ids' must be provided in order to compute the requested intervals"
999999
)
10001000
}
1001-
if (length(rfx_group_ids) != nrow(covariates)) {
1001+
if (length(rfx_group_ids) != nrow(X)) {
10021002
stop(
1003-
"'rfx_group_ids' must have the same length as the number of rows in 'covariates'"
1003+
"'rfx_group_ids' must have the same length as the number of rows in 'X'"
10041004
)
10051005
}
10061006

@@ -1016,17 +1016,17 @@ compute_bcf_posterior_interval <- function(
10161016
if (!is.matrix(rfx_basis)) {
10171017
stop("'rfx_basis' must be a matrix")
10181018
}
1019-
if (nrow(rfx_basis) != nrow(covariates)) {
1020-
stop("'rfx_basis' must have the same number of rows as 'covariates'")
1019+
if (nrow(rfx_basis) != nrow(X)) {
1020+
stop("'rfx_basis' must have the same number of rows as 'X'")
10211021
}
10221022
}
10231023
}
10241024

10251025
# Compute posterior matrices for the requested model terms
10261026
predictions <- predict(
10271027
model_object,
1028-
X = covariates,
1029-
Z = treatment,
1028+
X = X,
1029+
Z = Z,
10301030
propensity = propensity,
10311031
rfx_group_ids = rfx_group_ids,
10321032
rfx_basis = rfx_basis,

demo/debug/bart_predict_debug.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@
6363

6464
# Compute posterior interval
6565
intervals = bart_model.compute_posterior_interval(
66-
terms="all", scale="linear", level=0.95, covariates=X_test
66+
terms="all", scale="linear", level=0.95, X=X_test
6767
)
6868

6969
# Check coverage
@@ -75,7 +75,7 @@
7575

7676
# Sample from the posterior predictive distribution
7777
bart_ppd_samples = bart_model.sample_posterior_predictive(
78-
covariates=X_test, num_draws_per_sample=10
78+
X=X_test, num_draws_per_sample=10
7979
)
8080

8181
# Plot PPD mean vs actual

demo/debug/bcf_predict_debug.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -90,8 +90,8 @@
9090
terms="all",
9191
scale="linear",
9292
level=0.95,
93-
covariates=X_test,
94-
treatment=Z_test,
93+
X=X_test,
94+
Z=Z_test,
9595
propensity=pi_test,
9696
)
9797

@@ -118,7 +118,7 @@
118118

119119
# Sample from the posterior predictive distribution
120120
bcf_ppd_samples = bcf_model.sample_posterior_predictive(
121-
covariates=X_test, treatment=Z_test, propensity=pi_test, num_draws_per_sample=10
121+
X=X_test, Z=Z_test, propensity=pi_test, num_draws_per_sample=10
122122
)
123123

124124
# Plot PPD mean vs actual
@@ -229,8 +229,8 @@
229229
terms="all",
230230
scale="linear",
231231
level=0.95,
232-
covariates=X_test,
233-
treatment=Z_test,
232+
X=X_test,
233+
Z=Z_test,
234234
propensity=pi_test,
235235
rfx_group_ids=rfx_group_ids_test,
236236
)
@@ -240,8 +240,8 @@
240240
terms="prognostic_function",
241241
scale="linear",
242242
level=0.95,
243-
covariates=X_test,
244-
treatment=Z_test,
243+
X=X_test,
244+
Z=Z_test,
245245
propensity=pi_test,
246246
rfx_group_ids=rfx_group_ids_test
247247
)
@@ -251,8 +251,8 @@
251251
terms="cate",
252252
scale="linear",
253253
level=0.95,
254-
covariates=X_test,
255-
treatment=Z_test,
254+
X=X_test,
255+
Z=Z_test,
256256
propensity=pi_test,
257257
rfx_group_ids=rfx_group_ids_test
258258
)
@@ -284,17 +284,17 @@
284284
terms="mu",
285285
scale="linear",
286286
level=0.95,
287-
covariates=X_test,
288-
treatment=Z_test,
287+
X=X_test,
288+
Z=Z_test,
289289
propensity=pi_test,
290290
rfx_group_ids=rfx_group_ids_test
291291
)
292292
tau_intervals_test = bcf_model.compute_posterior_interval(
293293
terms="tau",
294294
scale="linear",
295295
level=0.95,
296-
covariates=X_test,
297-
treatment=Z_test,
296+
X=X_test,
297+
Z=Z_test,
298298
propensity=pi_test,
299299
rfx_group_ids=rfx_group_ids_test
300300
)

demo/debug/causal_inference_binary_outcome.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
# Load necessary libraries
22
import numpy as np
33
import pandas as pd
4-
import seaborn as sns
54
import matplotlib.pyplot as plt
65
from stochtree import BCFModel
76
from sklearn.model_selection import train_test_split

demo/debug/multi_chain.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
# Load necessary libraries
44
import matplotlib.pyplot as plt
55
import numpy as np
6-
import pandas as pd
76
import arviz as az
87
from sklearn.model_selection import train_test_split
98

demo/debug/multiple_initializations.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -118,14 +118,14 @@ def outcome_mean(X, W):
118118
)
119119

120120
# Inspect the model outputs
121-
bart_preds_2 = bart_model_2.predict(X=X_test, basis_test)
121+
bart_preds_2 = bart_model_2.predict(X=X_test, leaf_basis=basis_test)
122122
y_hat_mcmc_2 = bart_preds_2['y_hat']
123123
y_avg_mcmc_2 = np.squeeze(y_hat_mcmc_2).mean(axis=1, keepdims=True)
124124
y_avg_mcmc_2 = np.squeeze(y_hat_mcmc_2).mean(axis=1, keepdims=True)
125-
bart_preds_3 = bart_model_3.predict(X=X_test, basis_test)
125+
bart_preds_3 = bart_model_3.predict(X=X_test, leaf_basis=basis_test)
126126
y_hat_mcmc_3 = bart_preds_3['y_hat']
127127
y_avg_mcmc_3 = np.squeeze(y_hat_mcmc_3).mean(axis=1, keepdims=True)
128-
bart_preds_4 = bart_model_4.predict(X=X_test, basis_test)
128+
bart_preds_4 = bart_model_4.predict(X=X_test, leaf_basis=basis_test)
129129
y_hat_mcmc_4 = bart_preds_4['y_hat']
130130
y_avg_mcmc_4 = np.squeeze(y_hat_mcmc_4).mean(axis=1, keepdims=True)
131131
y_df = pd.DataFrame(

0 commit comments

Comments
 (0)