@@ -465,8 +465,8 @@ compute_contrast_bart_model <- function(
465465# ' Sample from the posterior predictive distribution for outcomes modeled by BCF
466466# '
467467# ' @param model_object A fitted BCF model object of class `bcfmodel`.
468- # ' @param covariates A matrix or data frame of covariates.
469- # ' @param treatment A vector or matrix of treatment assignments.
468+ # ' @param X A matrix or data frame of covariates.
469+ # ' @param Z A vector or matrix of treatment assignments.
470470# ' @param propensity (Optional) A vector or matrix of propensity scores. Required if the underlying model depends on user-provided propensities.
471471# ' @param rfx_group_ids (Optional) A vector of group IDs for random effects model. Required if the BCF model includes random effects.
472472# ' @param rfx_basis (Optional) A matrix of bases for random effects model. Required if the BCF model includes random effects.
@@ -484,13 +484,13 @@ compute_contrast_bart_model <- function(
484484# ' y <- 2 * X[,2] + 0.5 * X[,2] * Z + rnorm(n)
485485# ' bcf_model <- bcf(X_train = X, Z_train = Z, y_train = y, propensity_train = pi_X)
486486# ' ppd_samples <- sample_bcf_posterior_predictive(
487- # ' model_object = bcf_model, covariates = X,
488- # ' treatment = Z, propensity = pi_X
487+ # ' model_object = bcf_model, X = X,
488+ # ' Z = Z, propensity = pi_X
489489# ' )
490490sample_bcf_posterior_predictive <- function (
491491 model_object ,
492- covariates = NULL ,
493- treatment = NULL ,
492+ X = NULL ,
493+ Z = NULL ,
494494 propensity = NULL ,
495495 rfx_group_ids = NULL ,
496496 rfx_basis = NULL ,
@@ -505,33 +505,33 @@ sample_bcf_posterior_predictive <- function(
505505 # Check that all the necessary inputs were provided for interval computation
506506 needs_covariates <- TRUE
507507 if (needs_covariates ) {
508- if (is.null(covariates )) {
508+ if (is.null(X )) {
509509 stop(
510- " 'covariates ' must be provided in order to compute the requested intervals"
510+ " 'X ' must be provided in order to compute the requested intervals"
511511 )
512512 }
513- if (! is.matrix(covariates ) && ! is.data.frame(covariates )) {
514- stop(" 'covariates ' must be a matrix or data frame" )
513+ if (! is.matrix(X ) && ! is.data.frame(X )) {
514+ stop(" 'X ' must be a matrix or data frame" )
515515 }
516516 }
517517 needs_treatment <- needs_covariates
518518 if (needs_treatment ) {
519- if (is.null(treatment )) {
519+ if (is.null(Z )) {
520520 stop(
521- " 'treatment ' must be provided in order to compute the requested intervals"
521+ " 'Z ' must be provided in order to compute the requested intervals"
522522 )
523523 }
524- if (! is.matrix(treatment ) && ! is.numeric(treatment )) {
525- stop(" 'treatment ' must be a numeric vector or matrix" )
524+ if (! is.matrix(Z ) && ! is.numeric(Z )) {
525+ stop(" 'Z ' must be a numeric vector or matrix" )
526526 }
527- if (is.matrix(treatment )) {
528- if (nrow(treatment ) != nrow(covariates )) {
529- stop(" 'treatment ' must have the same number of rows as 'covariates '" )
527+ if (is.matrix(Z )) {
528+ if (nrow(Z ) != nrow(X )) {
529+ stop(" 'Z ' must have the same number of rows as 'X '" )
530530 }
531531 } else {
532- if (length(treatment ) != nrow(covariates )) {
532+ if (length(Z ) != nrow(X )) {
533533 stop(
534- " 'treatment ' must have the same number of elements as 'covariates '"
534+ " 'Z ' must have the same number of elements as 'X '"
535535 )
536536 }
537537 }
@@ -551,13 +551,13 @@ sample_bcf_posterior_predictive <- function(
551551 stop(" 'propensity' must be a numeric vector or matrix" )
552552 }
553553 if (is.matrix(propensity )) {
554- if (nrow(propensity ) != nrow(covariates )) {
555- stop(" 'propensity' must have the same number of rows as 'covariates '" )
554+ if (nrow(propensity ) != nrow(X )) {
555+ stop(" 'propensity' must have the same number of rows as 'X '" )
556556 }
557557 } else {
558- if (length(propensity ) != nrow(covariates )) {
558+ if (length(propensity ) != nrow(X )) {
559559 stop(
560- " 'propensity' must have the same number of elements as 'covariates '"
560+ " 'propensity' must have the same number of elements as 'X '"
561561 )
562562 }
563563 }
@@ -569,9 +569,9 @@ sample_bcf_posterior_predictive <- function(
569569 " 'rfx_group_ids' must be provided in order to compute the requested intervals"
570570 )
571571 }
572- if (length(rfx_group_ids ) != nrow(covariates )) {
572+ if (length(rfx_group_ids ) != nrow(X )) {
573573 stop(
574- " 'rfx_group_ids' must have the same length as the number of rows in 'covariates '"
574+ " 'rfx_group_ids' must have the same length as the number of rows in 'X '"
575575 )
576576 }
577577 if (is.null(rfx_basis )) {
@@ -582,16 +582,16 @@ sample_bcf_posterior_predictive <- function(
582582 if (! is.matrix(rfx_basis )) {
583583 stop(" 'rfx_basis' must be a matrix" )
584584 }
585- if (nrow(rfx_basis ) != nrow(covariates )) {
586- stop(" 'rfx_basis' must have the same number of rows as 'covariates '" )
585+ if (nrow(rfx_basis ) != nrow(X )) {
586+ stop(" 'rfx_basis' must have the same number of rows as 'X '" )
587587 }
588588 }
589589
590590 # Compute posterior samples
591591 bcf_preds <- predict(
592592 model_object ,
593- X = covariates ,
594- Z = treatment ,
593+ X = X ,
594+ Z = Z ,
595595 propensity = propensity ,
596596 rfx_group_ids = rfx_group_ids ,
597597 rfx_basis = rfx_basis ,
@@ -605,7 +605,7 @@ sample_bcf_posterior_predictive <- function(
605605 has_variance_forest <- model_object $ model_params $ include_variance_forest
606606 samples_global_variance <- model_object $ model_params $ sample_sigma2_global
607607 num_posterior_draws <- model_object $ model_params $ num_samples
608- num_observations <- nrow(covariates )
608+ num_observations <- nrow(X )
609609 ppd_mean <- bcf_preds $ y_hat
610610 if (has_variance_forest ) {
611611 ppd_variance <- bcf_preds $ variance_forest_predictions
@@ -840,8 +840,8 @@ posterior_predictive_heuristic_multiplier <- function(
840840# ' @param terms A character string specifying the model term(s) for which to compute intervals. Options for BCF models are `"prognostic_function"`, `"mu"`, `"cate"`, `"tau"`, `"variance_forest"`, `"rfx"`, or `"y_hat"`. Note that `"mu"` is only different from `"prognostic_function"` if random effects are included with a model spec of `"intercept_only"` or `"intercept_plus_treatment"` and `"tau"` is only different from `"cate"` if random effects are included with a model spec of `"intercept_plus_treatment"`.
841841# ' @param level A numeric value between 0 and 1 specifying the credible interval level (default is 0.95 for a 95% credible interval).
842842# ' @param scale (Optional) Scale of mean function predictions. Options are "linear", which returns predictions on the original scale of the mean forest / RFX terms, and "probability", which transforms predictions into a probability of observing `y == 1`. "probability" is only valid for models fit with a probit outcome model. Default: "linear".
843- # ' @param covariates (Optional) A matrix or data frame of covariates at which to compute the intervals. Required if the requested term depends on covariates (e.g., prognostic forest, CATE forest, variance forest, or overall predictions).
844- # ' @param treatment (Optional) A vector or matrix of treatment assignments. Required if the requested term is `"y_hat"` (overall predictions).
843+ # ' @param X (Optional) A matrix or data frame of covariates at which to compute the intervals. Required if the requested term depends on covariates (e.g., prognostic forest, CATE forest, variance forest, or overall predictions).
844+ # ' @param Z (Optional) A vector or matrix of treatment assignments. Required if the requested term is `"y_hat"` (overall predictions).
845845# ' @param propensity (Optional) A vector or matrix of propensity scores. Required if the underlying model depends on user-provided propensities.
846846# ' @param rfx_group_ids An optional vector of group IDs for random effects. Required if the requested term includes random effects.
847847# ' @param rfx_basis An optional matrix of basis function evaluations for random effects. Required if the requested term includes random effects.
@@ -863,8 +863,8 @@ posterior_predictive_heuristic_multiplier <- function(
863863# ' intervals <- compute_bcf_posterior_interval(
864864# ' model_object = bcf_model,
865865# ' terms = c("prognostic_function", "cate"),
866- # ' covariates = X,
867- # ' treatment = Z,
866+ # ' X = X,
867+ # ' Z = Z,
868868# ' propensity = pi_X,
869869# ' level = 0.90
870870# ' )
@@ -873,8 +873,8 @@ compute_bcf_posterior_interval <- function(
873873 terms ,
874874 level = 0.95 ,
875875 scale = " linear" ,
876- covariates = NULL ,
877- treatment = NULL ,
876+ X = NULL ,
877+ Z = NULL ,
878878 propensity = NULL ,
879879 rfx_group_ids = NULL ,
880880 rfx_basis = NULL
@@ -930,33 +930,33 @@ compute_bcf_posterior_interval <- function(
930930 (" variance_forest" %in% terms ) ||
931931 (needs_covariates_intermediate ))
932932 if (needs_covariates ) {
933- if (is.null(covariates )) {
933+ if (is.null(X )) {
934934 stop(
935- " 'covariates ' must be provided in order to compute the requested intervals"
935+ " 'X ' must be provided in order to compute the requested intervals"
936936 )
937937 }
938- if (! is.matrix(covariates ) && ! is.data.frame(covariates )) {
939- stop(" 'covariates ' must be a matrix or data frame" )
938+ if (! is.matrix(X ) && ! is.data.frame(X )) {
939+ stop(" 'X ' must be a matrix or data frame" )
940940 }
941941 }
942942 needs_treatment <- needs_covariates
943943 if (needs_treatment ) {
944- if (is.null(treatment )) {
944+ if (is.null(Z )) {
945945 stop(
946- " 'treatment ' must be provided in order to compute the requested intervals"
946+ " 'Z ' must be provided in order to compute the requested intervals"
947947 )
948948 }
949- if (! is.matrix(treatment ) && ! is.numeric(treatment )) {
950- stop(" 'treatment ' must be a numeric vector or matrix" )
949+ if (! is.matrix(Z ) && ! is.numeric(Z )) {
950+ stop(" 'Z ' must be a numeric vector or matrix" )
951951 }
952- if (is.matrix(treatment )) {
953- if (nrow(treatment ) != nrow(covariates )) {
954- stop(" 'treatment ' must have the same number of rows as 'covariates '" )
952+ if (is.matrix(Z )) {
953+ if (nrow(Z ) != nrow(X )) {
954+ stop(" 'Z ' must have the same number of rows as 'X '" )
955955 }
956956 } else {
957- if (length(treatment ) != nrow(covariates )) {
957+ if (length(Z ) != nrow(X )) {
958958 stop(
959- " 'treatment ' must have the same number of elements as 'covariates '"
959+ " 'Z ' must have the same number of elements as 'X '"
960960 )
961961 }
962962 }
@@ -976,13 +976,13 @@ compute_bcf_posterior_interval <- function(
976976 stop(" 'propensity' must be a numeric vector or matrix" )
977977 }
978978 if (is.matrix(propensity )) {
979- if (nrow(propensity ) != nrow(covariates )) {
980- stop(" 'propensity' must have the same number of rows as 'covariates '" )
979+ if (nrow(propensity ) != nrow(X )) {
980+ stop(" 'propensity' must have the same number of rows as 'X '" )
981981 }
982982 } else {
983- if (length(propensity ) != nrow(covariates )) {
983+ if (length(propensity ) != nrow(X )) {
984984 stop(
985- " 'propensity' must have the same number of elements as 'covariates '"
985+ " 'propensity' must have the same number of elements as 'X '"
986986 )
987987 }
988988 }
@@ -998,9 +998,9 @@ compute_bcf_posterior_interval <- function(
998998 " 'rfx_group_ids' must be provided in order to compute the requested intervals"
999999 )
10001000 }
1001- if (length(rfx_group_ids ) != nrow(covariates )) {
1001+ if (length(rfx_group_ids ) != nrow(X )) {
10021002 stop(
1003- " 'rfx_group_ids' must have the same length as the number of rows in 'covariates '"
1003+ " 'rfx_group_ids' must have the same length as the number of rows in 'X '"
10041004 )
10051005 }
10061006
@@ -1016,17 +1016,17 @@ compute_bcf_posterior_interval <- function(
10161016 if (! is.matrix(rfx_basis )) {
10171017 stop(" 'rfx_basis' must be a matrix" )
10181018 }
1019- if (nrow(rfx_basis ) != nrow(covariates )) {
1020- stop(" 'rfx_basis' must have the same number of rows as 'covariates '" )
1019+ if (nrow(rfx_basis ) != nrow(X )) {
1020+ stop(" 'rfx_basis' must have the same number of rows as 'X '" )
10211021 }
10221022 }
10231023 }
10241024
10251025 # Compute posterior matrices for the requested model terms
10261026 predictions <- predict(
10271027 model_object ,
1028- X = covariates ,
1029- Z = treatment ,
1028+ X = X ,
1029+ Z = Z ,
10301030 propensity = propensity ,
10311031 rfx_group_ids = rfx_group_ids ,
10321032 rfx_basis = rfx_basis ,
0 commit comments