Skip to content

Commit d303fcb

Browse files
committed
Updated posterior predictive method / function in R and Python
1 parent 8db1d8d commit d303fcb

File tree

4 files changed

+54
-54
lines changed

4 files changed

+54
-54
lines changed

R/posterior_transformation.R

Lines changed: 25 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -659,8 +659,8 @@ sample_bcf_posterior_predictive <- function(
659659
#' Sample from the posterior predictive distribution for outcomes modeled by BART
660660
#'
661661
#' @param model_object A fitted BART model object of class `bartmodel`.
662-
#' @param covariates A matrix or data frame of covariates. Required if the BART model depends on covariates (e.g., contains a mean or variance forest).
663-
#' @param basis A matrix of bases for mean forest models with regression defined in the leaves. Required for "leaf regression" models.
662+
#' @param X A matrix or data frame of covariates. Required if the BART model depends on covariates (e.g., contains a mean or variance forest).
663+
#' @param leaf_basis A matrix of bases for mean forest models with regression defined in the leaves. Required for "leaf regression" models.
664664
#' @param rfx_group_ids A vector of group IDs for random effects model. Required if the BART model includes random effects.
665665
#' @param rfx_basis A matrix of bases for random effects model. Required if the BART model includes random effects.
666666
#' @param num_draws_per_sample The number of posterior predictive samples to draw for each posterior sample. Defaults to a heuristic based on the number of samples in a BART model (i.e. if the BART model has >1000 draws, we use 1 draw from the likelihood per sample, otherwise we upsample to ensure intervals are based on at least 1000 posterior predictive draws).
@@ -675,12 +675,12 @@ sample_bcf_posterior_predictive <- function(
675675
#' y <- 2 * X[,1] + rnorm(n)
676676
#' bart_model <- bart(y_train = y, X_train = X)
677677
#' ppd_samples <- sample_bart_posterior_predictive(
678-
#' model_object = bart_model, covariates = X
678+
#' model_object = bart_model, X = X
679679
#' )
680680
sample_bart_posterior_predictive <- function(
681681
model_object,
682-
covariates = NULL,
683-
basis = NULL,
682+
X = NULL,
683+
leaf_basis = NULL,
684684
rfx_group_ids = NULL,
685685
rfx_basis = NULL,
686686
num_draws_per_sample = NULL
@@ -694,32 +694,32 @@ sample_bart_posterior_predictive <- function(
694694
# Check that all the necessary inputs were provided for interval computation
695695
needs_covariates <- model_object$model_params$include_mean_forest
696696
if (needs_covariates) {
697-
if (is.null(covariates)) {
697+
if (is.null(X)) {
698698
stop(
699-
"'covariates' must be provided in order to compute the requested intervals"
699+
"'X' must be provided in order to compute the requested intervals"
700700
)
701701
}
702-
if (!is.matrix(covariates) && !is.data.frame(covariates)) {
703-
stop("'covariates' must be a matrix or data frame")
702+
if (!is.matrix(X) && !is.data.frame(X)) {
703+
stop("'X' must be a matrix or data frame")
704704
}
705705
}
706706
needs_basis <- needs_covariates && model_object$model_params$has_basis
707707
if (needs_basis) {
708-
if (is.null(basis)) {
708+
if (is.null(leaf_basis)) {
709709
stop(
710-
"'basis' must be provided in order to compute the requested intervals"
710+
"'leaf_basis' must be provided in order to compute the requested intervals"
711711
)
712712
}
713-
if (!is.matrix(basis)) {
714-
stop("'basis' must be a matrix")
713+
if (!is.matrix(leaf_basis)) {
714+
stop("'leaf_basis' must be a matrix")
715715
}
716-
if (is.matrix(basis)) {
717-
if (nrow(basis) != nrow(covariates)) {
718-
stop("'basis' must have the same number of rows as 'covariates'")
716+
if (is.matrix(leaf_basis)) {
717+
if (nrow(leaf_basis) != nrow(X)) {
718+
stop("'leaf_basis' must have the same number of rows as 'X'")
719719
}
720720
} else {
721-
if (length(basis) != nrow(covariates)) {
722-
stop("'basis' must have the same number of elements as 'covariates'")
721+
if (length(leaf_basis) != nrow(X)) {
722+
stop("'leaf_basis' must have the same number of elements as 'X'")
723723
}
724724
}
725725
}
@@ -730,9 +730,9 @@ sample_bart_posterior_predictive <- function(
730730
"'rfx_group_ids' must be provided in order to compute the requested intervals"
731731
)
732732
}
733-
if (length(rfx_group_ids) != nrow(covariates)) {
733+
if (length(rfx_group_ids) != nrow(X)) {
734734
stop(
735-
"'rfx_group_ids' must have the same length as the number of rows in 'covariates'"
735+
"'rfx_group_ids' must have the same length as the number of rows in 'X'"
736736
)
737737
}
738738
if (is.null(rfx_basis)) {
@@ -743,16 +743,16 @@ sample_bart_posterior_predictive <- function(
743743
if (!is.matrix(rfx_basis)) {
744744
stop("'rfx_basis' must be a matrix")
745745
}
746-
if (nrow(rfx_basis) != nrow(covariates)) {
747-
stop("'rfx_basis' must have the same number of rows as 'covariates'")
746+
if (nrow(rfx_basis) != nrow(X)) {
747+
stop("'rfx_basis' must have the same number of rows as 'X'")
748748
}
749749
}
750750

751751
# Compute posterior samples
752752
bart_preds <- predict(
753753
model_object,
754-
X = covariates,
755-
leaf_basis = basis,
754+
X = X,
755+
leaf_basis = leaf_basis,
756756
rfx_group_ids = rfx_group_ids,
757757
rfx_basis = rfx_basis,
758758
type = "posterior",
@@ -766,7 +766,7 @@ sample_bart_posterior_predictive <- function(
766766
has_variance_forest <- model_object$model_params$include_variance_forest
767767
samples_global_variance <- model_object$model_params$sample_sigma2_global
768768
num_posterior_draws <- model_object$model_params$num_samples
769-
num_observations <- nrow(covariates)
769+
num_observations <- nrow(X)
770770
if (has_mean_term) {
771771
ppd_mean <- bart_preds$y_hat
772772
} else {

man/sample_bart_posterior_predictive.Rd

Lines changed: 5 additions & 5 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

stochtree/bart.py

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -2338,8 +2338,8 @@ def compute_posterior_interval(
23382338

23392339
def sample_posterior_predictive(
23402340
self,
2341-
covariates: np.array = None,
2342-
basis: np.array = None,
2341+
X: np.array = None,
2342+
leaf_basis: np.array = None,
23432343
rfx_group_ids: np.array = None,
23442344
rfx_basis: np.array = None,
23452345
num_draws_per_sample: int = None,
@@ -2349,9 +2349,9 @@ def sample_posterior_predictive(
23492349
23502350
Parameters
23512351
----------
2352-
covariates : np.array, optional
2352+
X : np.array, optional
23532353
An array or data frame of covariates at which to compute the intervals. Required if the BART model depends on covariates (e.g., contains a mean or variance forest).
2354-
basis : np.array, optional
2354+
leaf_basis : np.array, optional
23552355
An array of basis function evaluations for mean forest models with regression defined in the leaves. Required for "leaf regression" models.
23562356
rfx_group_ids : np.array, optional
23572357
An array of group IDs for random effects. Required if the BART model includes random effects.
@@ -2375,25 +2375,25 @@ def sample_posterior_predictive(
23752375
# Check that all the necessary inputs were provided for interval computation
23762376
needs_covariates = self.include_mean_forest
23772377
if needs_covariates:
2378-
if covariates is None:
2378+
if X is None:
23792379
raise ValueError(
2380-
"'covariates' must be provided in order to compute the requested intervals"
2380+
"'X' must be provided in order to compute the requested intervals"
23812381
)
2382-
if not isinstance(covariates, np.ndarray) and not isinstance(
2383-
covariates, pd.DataFrame
2382+
if not isinstance(X, np.ndarray) and not isinstance(
2383+
X, pd.DataFrame
23842384
):
2385-
raise ValueError("'covariates' must be a matrix or data frame")
2385+
raise ValueError("'X' must be a matrix or data frame")
23862386
needs_basis = needs_covariates and self.has_basis
23872387
if needs_basis:
2388-
if basis is None:
2388+
if leaf_basis is None:
23892389
raise ValueError(
2390-
"'basis' must be provided in order to compute the requested intervals"
2390+
"'leaf_basis' must be provided in order to compute the requested intervals"
23912391
)
2392-
if not isinstance(basis, np.ndarray):
2393-
raise ValueError("'basis' must be a numpy array")
2394-
if basis.shape[0] != covariates.shape[0]:
2392+
if not isinstance(leaf_basis, np.ndarray):
2393+
raise ValueError("'leaf_basis' must be a numpy array")
2394+
if leaf_basis.shape[0] != X.shape[0]:
23952395
raise ValueError(
2396-
"'basis' must have the same number of rows as 'covariates'"
2396+
"'leaf_basis' must have the same number of rows as 'X'"
23972397
)
23982398
needs_rfx_data = self.has_rfx
23992399
if needs_rfx_data:
@@ -2403,25 +2403,25 @@ def sample_posterior_predictive(
24032403
)
24042404
if not isinstance(rfx_group_ids, np.ndarray):
24052405
raise ValueError("'rfx_group_ids' must be a numpy array")
2406-
if rfx_group_ids.shape[0] != covariates.shape[0]:
2406+
if rfx_group_ids.shape[0] != X.shape[0]:
24072407
raise ValueError(
2408-
"'rfx_group_ids' must have the same length as the number of rows in 'covariates'"
2408+
"'rfx_group_ids' must have the same length as the number of rows in 'X'"
24092409
)
24102410
if rfx_basis is None:
24112411
raise ValueError(
24122412
"'rfx_basis' must be provided in order to compute the requested intervals"
24132413
)
24142414
if not isinstance(rfx_basis, np.ndarray):
24152415
raise ValueError("'rfx_basis' must be a numpy array")
2416-
if rfx_basis.shape[0] != covariates.shape[0]:
2416+
if rfx_basis.shape[0] != X.shape[0]:
24172417
raise ValueError(
2418-
"'rfx_basis' must have the same number of rows as 'covariates'"
2418+
"'rfx_basis' must have the same number of rows as 'X'"
24192419
)
24202420

24212421
# Compute posterior predictive samples
24222422
bart_preds = self.predict(
2423-
covariates=covariates,
2424-
basis=basis,
2423+
X=X,
2424+
leaf_basis=leaf_basis,
24252425
rfx_group_ids=rfx_group_ids,
24262426
rfx_basis=rfx_basis,
24272427
type="posterior",
@@ -2433,7 +2433,7 @@ def sample_posterior_predictive(
24332433
has_variance_forest = self.include_variance_forest
24342434
samples_global_variance = self.sample_sigma2_global
24352435
num_posterior_draws = self.num_samples
2436-
num_observations = covariates.shape[0]
2436+
num_observations = X.shape[0]
24372437
if has_mean_term:
24382438
ppd_mean = bart_preds["y_hat"]
24392439
else:

tools/debug/bart_predict_debug.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ y_hat_intervals <- compute_bart_posterior_interval(
6767

6868
pred_intervals <- sample_bart_posterior_predictive(
6969
model_object = bart_model,
70-
covariates = X_test,
70+
X = X_test,
7171
level = 0.95
7272
)
7373

@@ -169,7 +169,7 @@ lines(y_hat_prob_intervals$upper[sort_inds])
169169
# Draw from posterior predictive for covariates in the test set
170170
ppd_samples <- sample_bart_posterior_predictive(
171171
model_object = bart_model,
172-
covariates = X_test,
172+
X = X_test,
173173
num_draws = 10
174174
)
175175

0 commit comments

Comments
 (0)