Skip to content

Commit d67078b

Browse files
committed
Added checks that skip variables already treated as categorical, and added explicit flagging of binary variables
1 parent ae21853 commit d67078b

File tree

2 files changed

+109
-46
lines changed

2 files changed

+109
-46
lines changed

R/bart.R

Lines changed: 54 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -423,37 +423,56 @@ bart <- function(
423423
floor(num_values / cutpoint_grid_size),
424424
1
425425
)
426+
x_is_df <- is.data.frame(X_train)
426427
covs_warning_1 <- NULL
427428
covs_warning_2 <- NULL
428429
covs_warning_3 <- NULL
430+
covs_warning_4 <- NULL
429431
for (i in 1:num_cov_orig) {
430-
# Determine the number of unique values
431-
num_unique_values <- length(unique(X_train[, i]))
432-
433-
# Determine a "name" for the covariate
434-
cov_name <- ifelse(
435-
is.null(colnames(X_train)),
436-
paste0("X", i),
437-
colnames(X_train)[i]
438-
)
439-
440-
# Check for a small relative number of unique values
441-
unique_full_ratio <- num_unique_values / num_values
442-
if (unique_full_ratio < 0.2) {
443-
covs_warning_1 <- c(covs_warning_1, cov_name)
432+
# Skip check for variables that are treated as categorical
433+
x_numeric <- T
434+
if (x_is_df) {
435+
if (is.factor(X_train[, i])) {
436+
x_numeric <- F
437+
}
444438
}
439+
if (x_numeric) {
440+
# Determine the number of unique values
441+
num_unique_values <- length(unique(X_train[, i]))
442+
443+
# Determine a "name" for the covariate
444+
cov_name <- ifelse(
445+
is.null(colnames(X_train)),
446+
paste0("X", i),
447+
colnames(X_train)[i]
448+
)
445449

446-
# Check for a small absolute number of unique values
447-
if (num_values > 100) {
448-
if (num_unique_values < 20) {
449-
covs_warning_2 <- c(covs_warning_2, cov_name)
450+
# Check for a small relative number of unique values
451+
unique_full_ratio <- num_unique_values / num_values
452+
if (unique_full_ratio < 0.2) {
453+
covs_warning_1 <- c(covs_warning_1, cov_name)
454+
}
455+
456+
# Check for a small absolute number of unique values
457+
if (num_values > 100) {
458+
if (num_unique_values < 20) {
459+
covs_warning_2 <- c(covs_warning_2, cov_name)
460+
}
461+
}
462+
463+
# Check for a large number of duplicates of any individual value
464+
x_j_hist <- table(X_train[, i])
465+
if (any(x_j_hist > 2 * max_grid_size)) {
466+
covs_warning_3 <- c(covs_warning_3, cov_name)
450467
}
451-
}
452468

453-
# Check for a large number of duplicates of any individual value
454-
x_j_hist <- table(X_train[, i])
455-
if (any(x_j_hist > 2 * max_grid_size)) {
456-
covs_warning_3 <- c(covs_warning_3, cov_name)
469+
# Check for binary variables
470+
if (num_unique_values == 2) {
471+
already_flagged <- (num_values > 100) && (num_unique_values < 20)
472+
if (!already_flagged) {
473+
covs_warning_4 <- c(covs_warning_4, cov_name)
474+
}
475+
}
457476
}
458477
}
459478

@@ -494,6 +513,18 @@ bart <- function(
494513
)
495514
)
496515
}
516+
517+
if (!is.null(covs_warning_4)) {
518+
warning(
519+
paste0(
520+
"Covariates ",
521+
paste(covs_warning_4, collapse = ", "),
522+
" appear to be binary but are currently treated by stochtree as continuous. ",
523+
"This might present some issues with the grow-from-root (GFR) algorithm. ",
524+
"Consider converting binary variables to ordered factor (i.e. `factor(..., ordered = T)`."
525+
)
526+
)
527+
}
497528
}
498529

499530
# Standardize the keep variable lists to numeric indices

R/bcf.R

Lines changed: 55 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -527,37 +527,57 @@ bcf <- function(
527527
floor(num_values / cutpoint_grid_size),
528528
1
529529
)
530+
x_is_df <- is.data.frame(X_train)
530531
covs_warning_1 <- NULL
531532
covs_warning_2 <- NULL
532533
covs_warning_3 <- NULL
534+
covs_warning_4 <- NULL
533535
for (i in 1:num_cov_orig) {
534-
# Determine the number of unique values
535-
num_unique_values <- length(unique(X_train[, i]))
536-
537-
# Determine a "name" for the covariate
538-
cov_name <- ifelse(
539-
is.null(colnames(X_train)),
540-
paste0("X", i),
541-
colnames(X_train)[i]
542-
)
543-
544-
# Check for a small relative number of unique values
545-
unique_full_ratio <- num_unique_values / num_values
546-
if (unique_full_ratio < 0.2) {
547-
covs_warning_1 <- c(covs_warning_1, cov_name)
536+
# Skip check for variables that are treated as categorical
537+
x_numeric <- T
538+
if (x_is_df) {
539+
if (is.factor(X_train[, i])) {
540+
x_numeric <- F
541+
}
548542
}
549543

550-
# Check for a small absolute number of unique values
551-
if (num_values > 100) {
552-
if (num_unique_values < 20) {
553-
covs_warning_2 <- c(covs_warning_2, cov_name)
544+
if (x_numeric) {
545+
# Determine the number of unique values
546+
num_unique_values <- length(unique(X_train[, i]))
547+
548+
# Determine a "name" for the covariate
549+
cov_name <- ifelse(
550+
is.null(colnames(X_train)),
551+
paste0("X", i),
552+
colnames(X_train)[i]
553+
)
554+
555+
# Check for a small relative number of unique values
556+
unique_full_ratio <- num_unique_values / num_values
557+
if (unique_full_ratio < 0.2) {
558+
covs_warning_1 <- c(covs_warning_1, cov_name)
559+
}
560+
561+
# Check for a small absolute number of unique values
562+
if (num_values > 100) {
563+
if (num_unique_values < 20) {
564+
covs_warning_2 <- c(covs_warning_2, cov_name)
565+
}
566+
}
567+
568+
# Check for a large number of duplicates of any individual value
569+
x_j_hist <- table(X_train[, i])
570+
if (any(x_j_hist > 2 * max_grid_size)) {
571+
covs_warning_3 <- c(covs_warning_3, cov_name)
554572
}
555-
}
556573

557-
# Check for a large number of duplicates of any individual value
558-
x_j_hist <- table(X_train[, i])
559-
if (any(x_j_hist > 2 * max_grid_size)) {
560-
covs_warning_3 <- c(covs_warning_3, cov_name)
574+
# Check for binary variables
575+
if (num_unique_values == 2) {
576+
already_flagged <- (num_values > 100) && (num_unique_values < 20)
577+
if (!already_flagged) {
578+
covs_warning_4 <- c(covs_warning_4, cov_name)
579+
}
580+
}
561581
}
562582
}
563583

@@ -598,6 +618,18 @@ bcf <- function(
598618
)
599619
)
600620
}
621+
622+
if (!is.null(covs_warning_4)) {
623+
warning(
624+
paste0(
625+
"Covariates ",
626+
paste(covs_warning_4, collapse = ", "),
627+
" appear to be binary but are currently treated by stochtree as continuous. ",
628+
"This might present some issues with the grow-from-root (GFR) algorithm. ",
629+
"Consider converting binary variables to ordered factor (i.e. `factor(..., ordered = T)`."
630+
)
631+
)
632+
}
601633
}
602634

603635
# Check delta_max is valid

0 commit comments

Comments
 (0)