@@ -418,7 +418,11 @@ bart <- function(
418418 # Raise a warning if the data have ties and only GFR is being run
419419 if ((num_gfr > 0 ) && (num_mcmc == 0 ) && (num_burnin == 0 )) {
420420 num_values <- nrow(X_train )
421- max_grid_size <- floor(num_values / cutpoint_grid_size )
421+ max_grid_size <- ifelse(
422+ num_values > cutpoint_grid_size ,
423+ floor(num_values / cutpoint_grid_size ),
424+ 1
425+ )
422426 covs_warning_1 <- NULL
423427 covs_warning_2 <- NULL
424428 covs_warning_3 <- NULL
@@ -1924,7 +1928,7 @@ bart <- function(
19241928# ' Predict from a sampled BART model on new data
19251929# '
19261930# ' @param object Object of type `bart` containing draws of a regression forest and associated sampling outputs.
1927- # ' @param covariates Covariates used to determine tree leaf predictions for each observation. Must be passed as a matrix or dataframe.
1931+ # ' @param X Covariates used to determine tree leaf predictions for each observation. Must be passed as a matrix or dataframe.
19281932# ' @param leaf_basis (Optional) Bases used for prediction (by e.g. dot product with leaf values). Default: `NULL`.
19291933# ' @param rfx_group_ids (Optional) Test set group labels used for an additive random effects model.
19301934# ' We do not currently support (but plan to in the near future), test set evaluation for group labels
@@ -1961,10 +1965,10 @@ bart <- function(
19611965# ' y_train <- y[train_inds]
19621966# ' bart_model <- bart(X_train = X_train, y_train = y_train,
19631967# ' num_gfr = 10, num_burnin = 0, num_mcmc = 10)
1964- # ' y_hat_test <- predict(bart_model, X_test)$y_hat
1968+ # ' y_hat_test <- predict(bart_model, X= X_test)$y_hat
19651969predict.bartmodel <- function (
19661970 object ,
1967- covariates ,
1971+ X ,
19681972 leaf_basis = NULL ,
19691973 rfx_group_ids = NULL ,
19701974 rfx_basis = NULL ,
@@ -2047,8 +2051,8 @@ predict.bartmodel <- function(
20472051 }
20482052
20492053 # Check that covariates are matrix or data frame
2050- if ((! is.data.frame(covariates )) && (! is.matrix(covariates ))) {
2051- stop(" covariates must be a matrix or dataframe" )
2054+ if ((! is.data.frame(X )) && (! is.matrix(X ))) {
2055+ stop(" X must be a matrix or dataframe" )
20522056 }
20532057
20542058 # Convert all input data to matrices if not already converted
@@ -2063,12 +2067,12 @@ predict.bartmodel <- function(
20632067 if ((object $ model_params $ requires_basis ) && (is.null(leaf_basis ))) {
20642068 stop(" Basis (leaf_basis) must be provided for this model" )
20652069 }
2066- if ((! is.null(leaf_basis )) && (nrow(covariates ) != nrow(leaf_basis ))) {
2067- stop(" covariates and leaf_basis must have the same number of rows" )
2070+ if ((! is.null(leaf_basis )) && (nrow(X ) != nrow(leaf_basis ))) {
2071+ stop(" X and leaf_basis must have the same number of rows" )
20682072 }
2069- if (object $ model_params $ num_covariates != ncol(covariates )) {
2073+ if (object $ model_params $ num_covariates != ncol(X )) {
20702074 stop(
2071- " covariates must contain the same number of columns as the BART model's training dataset"
2075+ " X must contain the same number of columns as the BART model's training dataset"
20722076 )
20732077 }
20742078 if ((predict_rfx ) && (is.null(rfx_group_ids ))) {
@@ -2089,7 +2093,7 @@ predict.bartmodel <- function(
20892093
20902094 # Preprocess covariates
20912095 train_set_metadata <- object $ train_set_metadata
2092- covariates <- preprocessPredictionData(covariates , train_set_metadata )
2096+ X <- preprocessPredictionData(X , train_set_metadata )
20932097
20942098 # Recode group IDs to integer vector (if passed as, for example, a vector of county names, etc...)
20952099 has_rfx <- FALSE
@@ -2119,8 +2123,8 @@ predict.bartmodel <- function(
21192123 # Only construct a basis if user-provided basis missing
21202124 if (is.null(rfx_basis )) {
21212125 rfx_basis <- matrix (
2122- rep(1 , nrow(covariates )),
2123- nrow = nrow(covariates ),
2126+ rep(1 , nrow(X )),
2127+ nrow = nrow(X ),
21242128 ncol = 1
21252129 )
21262130 }
@@ -2129,9 +2133,9 @@ predict.bartmodel <- function(
21292133
21302134 # Create prediction dataset
21312135 if (! is.null(leaf_basis )) {
2132- prediction_dataset <- createForestDataset(covariates , leaf_basis )
2136+ prediction_dataset <- createForestDataset(X , leaf_basis )
21332137 } else {
2134- prediction_dataset <- createForestDataset(covariates )
2138+ prediction_dataset <- createForestDataset(X )
21352139 }
21362140
21372141 # Compute variance forest predictions
@@ -2843,7 +2847,7 @@ createBARTModelFromJsonFile <- function(json_filename) {
28432847# ' num_gfr = 10, num_burnin = 0, num_mcmc = 10)
28442848# ' bart_json <- saveBARTModelToJsonString(bart_model)
28452849# ' bart_model_roundtrip <- createBARTModelFromJsonString(bart_json)
2846- # ' y_hat_mean_roundtrip <- rowMeans(predict(bart_model_roundtrip, X_train)$y_hat)
2850+ # ' y_hat_mean_roundtrip <- rowMeans(predict(bart_model_roundtrip, X= X_train)$y_hat)
28472851createBARTModelFromJsonString <- function (json_string ) {
28482852 # Load a `CppJson` object from string
28492853 bart_json <- createCppJsonString(json_string )
0 commit comments