Skip to content

Commit 459dd78

Browse files
authored
Merge pull request #243 from StochasticTree/variable-check-update
Add checks for variables already being treated as categorical and explicitly flag binary variables
2 parents ae21853 + 3f01307 commit 459dd78

File tree

4 files changed

+290
-161
lines changed

4 files changed

+290
-161
lines changed

R/bart.R

Lines changed: 51 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -423,37 +423,53 @@ bart <- function(
423423
floor(num_values / cutpoint_grid_size),
424424
1
425425
)
426+
x_is_df <- is.data.frame(X_train)
426427
covs_warning_1 <- NULL
427428
covs_warning_2 <- NULL
428429
covs_warning_3 <- NULL
430+
covs_warning_4 <- NULL
429431
for (i in 1:num_cov_orig) {
430-
# Determine the number of unique values
431-
num_unique_values <- length(unique(X_train[, i]))
432-
433-
# Determine a "name" for the covariate
434-
cov_name <- ifelse(
435-
is.null(colnames(X_train)),
436-
paste0("X", i),
437-
colnames(X_train)[i]
438-
)
439-
440-
# Check for a small relative number of unique values
441-
unique_full_ratio <- num_unique_values / num_values
442-
if (unique_full_ratio < 0.2) {
443-
covs_warning_1 <- c(covs_warning_1, cov_name)
432+
# Skip check for variables that are treated as categorical
433+
x_numeric <- T
434+
if (x_is_df) {
435+
if (is.factor(X_train[, i])) {
436+
x_numeric <- F
437+
}
444438
}
439+
if (x_numeric) {
440+
# Determine the number of unique values
441+
num_unique_values <- length(unique(X_train[, i]))
442+
443+
# Determine a "name" for the covariate
444+
cov_name <- ifelse(
445+
is.null(colnames(X_train)),
446+
paste0("X", i),
447+
colnames(X_train)[i]
448+
)
445449

446-
# Check for a small absolute number of unique values
447-
if (num_values > 100) {
448-
if (num_unique_values < 20) {
449-
covs_warning_2 <- c(covs_warning_2, cov_name)
450+
# Check for a small relative number of unique values
451+
unique_full_ratio <- num_unique_values / num_values
452+
if (unique_full_ratio < 0.2) {
453+
covs_warning_1 <- c(covs_warning_1, cov_name)
454+
}
455+
456+
# Check for a small absolute number of unique values
457+
if (num_values > 100) {
458+
if (num_unique_values < 20) {
459+
covs_warning_2 <- c(covs_warning_2, cov_name)
460+
}
461+
}
462+
463+
# Check for a large number of duplicates of any individual value
464+
x_j_hist <- table(X_train[, i])
465+
if (any(x_j_hist > 2 * max_grid_size)) {
466+
covs_warning_3 <- c(covs_warning_3, cov_name)
450467
}
451-
}
452468

453-
# Check for a large number of duplicates of any individual value
454-
x_j_hist <- table(X_train[, i])
455-
if (any(x_j_hist > 2 * max_grid_size)) {
456-
covs_warning_3 <- c(covs_warning_3, cov_name)
469+
# Check for binary variables
470+
if (num_unique_values == 2) {
471+
covs_warning_4 <- c(covs_warning_4, cov_name)
472+
}
457473
}
458474
}
459475

@@ -494,6 +510,18 @@ bart <- function(
494510
)
495511
)
496512
}
513+
514+
if (!is.null(covs_warning_4)) {
515+
warning(
516+
paste0(
517+
"Covariates ",
518+
paste(covs_warning_4, collapse = ", "),
519+
" appear to be binary but are currently treated by stochtree as continuous. ",
520+
"This might present some issues with the grow-from-root (GFR) algorithm. ",
521+
"Consider converting binary variables to ordered factor (i.e. `factor(..., ordered = T)`."
522+
)
523+
)
524+
}
497525
}
498526

499527
# Standardize the keep variable lists to numeric indices

R/bcf.R

Lines changed: 52 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -527,37 +527,54 @@ bcf <- function(
527527
floor(num_values / cutpoint_grid_size),
528528
1
529529
)
530+
x_is_df <- is.data.frame(X_train)
530531
covs_warning_1 <- NULL
531532
covs_warning_2 <- NULL
532533
covs_warning_3 <- NULL
534+
covs_warning_4 <- NULL
533535
for (i in 1:num_cov_orig) {
534-
# Determine the number of unique values
535-
num_unique_values <- length(unique(X_train[, i]))
536-
537-
# Determine a "name" for the covariate
538-
cov_name <- ifelse(
539-
is.null(colnames(X_train)),
540-
paste0("X", i),
541-
colnames(X_train)[i]
542-
)
543-
544-
# Check for a small relative number of unique values
545-
unique_full_ratio <- num_unique_values / num_values
546-
if (unique_full_ratio < 0.2) {
547-
covs_warning_1 <- c(covs_warning_1, cov_name)
536+
# Skip check for variables that are treated as categorical
537+
x_numeric <- T
538+
if (x_is_df) {
539+
if (is.factor(X_train[, i])) {
540+
x_numeric <- F
541+
}
548542
}
549543

550-
# Check for a small absolute number of unique values
551-
if (num_values > 100) {
552-
if (num_unique_values < 20) {
553-
covs_warning_2 <- c(covs_warning_2, cov_name)
544+
if (x_numeric) {
545+
# Determine the number of unique values
546+
num_unique_values <- length(unique(X_train[, i]))
547+
548+
# Determine a "name" for the covariate
549+
cov_name <- ifelse(
550+
is.null(colnames(X_train)),
551+
paste0("X", i),
552+
colnames(X_train)[i]
553+
)
554+
555+
# Check for a small relative number of unique values
556+
unique_full_ratio <- num_unique_values / num_values
557+
if (unique_full_ratio < 0.2) {
558+
covs_warning_1 <- c(covs_warning_1, cov_name)
554559
}
555-
}
556560

557-
# Check for a large number of duplicates of any individual value
558-
x_j_hist <- table(X_train[, i])
559-
if (any(x_j_hist > 2 * max_grid_size)) {
560-
covs_warning_3 <- c(covs_warning_3, cov_name)
561+
# Check for a small absolute number of unique values
562+
if (num_values > 100) {
563+
if (num_unique_values < 20) {
564+
covs_warning_2 <- c(covs_warning_2, cov_name)
565+
}
566+
}
567+
568+
# Check for a large number of duplicates of any individual value
569+
x_j_hist <- table(X_train[, i])
570+
if (any(x_j_hist > 2 * max_grid_size)) {
571+
covs_warning_3 <- c(covs_warning_3, cov_name)
572+
}
573+
574+
# Check for binary variables
575+
if (num_unique_values == 2) {
576+
covs_warning_4 <- c(covs_warning_4, cov_name)
577+
}
561578
}
562579
}
563580

@@ -598,6 +615,18 @@ bcf <- function(
598615
)
599616
)
600617
}
618+
619+
if (!is.null(covs_warning_4)) {
620+
warning(
621+
paste0(
622+
"Covariates ",
623+
paste(covs_warning_4, collapse = ", "),
624+
" appear to be binary but are currently treated by stochtree as continuous. ",
625+
"This might present some issues with the grow-from-root (GFR) algorithm. ",
626+
"Consider converting binary variables to ordered factor (i.e. `factor(..., ordered = T)`."
627+
)
628+
)
629+
}
601630
}
602631

603632
# Check delta_max is valid

stochtree/bart.py

Lines changed: 39 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -456,32 +456,45 @@ def sample(
456456
if (num_gfr > 0) and (num_burnin == 0) and (num_mcmc == 0):
457457
num_values, num_cov_orig = X_train.shape
458458
max_grid_size = floor(num_values / cutpoint_grid_size)
459+
x_is_df = isinstance(X_train, pd.DataFrame)
459460
covs_warning_1 = []
460461
covs_warning_2 = []
461462
covs_warning_3 = []
463+
covs_warning_4 = []
462464
for i in range(num_cov_orig):
463-
# Determine the number of unique covariate values and a name for the covariate
464-
if isinstance(X_train, np.ndarray):
465-
x_j_hist = np.unique_counts(X_train[:, i]).counts
466-
cov_name = f"X{i + 1}"
467-
else:
468-
x_j_hist = (X_train.iloc[:, i]).value_counts()
469-
cov_name = X_train.columns[i]
465+
# Skip check for variables that are treated as categorical
466+
x_numeric = True
467+
if x_is_df:
468+
if isinstance(X_train.iloc[:,i].dtype, pd.CategoricalDtype):
469+
x_numeric = False
470+
471+
if x_numeric:
472+
# Determine the number of unique covariate values and a name for the covariate
473+
if isinstance(X_train, np.ndarray):
474+
x_j_hist = np.unique_counts(X_train[:, i]).counts
475+
cov_name = f"X{i + 1}"
476+
else:
477+
x_j_hist = (X_train.iloc[:, i]).value_counts()
478+
cov_name = X_train.columns[i]
470479

471-
# Check for a small relative number of unique values
472-
num_unique_values = len(x_j_hist)
473-
unique_full_ratio = num_unique_values / num_values
474-
if unique_full_ratio < 0.2:
475-
covs_warning_1.append(cov_name)
480+
# Check for a small relative number of unique values
481+
num_unique_values = len(x_j_hist)
482+
unique_full_ratio = num_unique_values / num_values
483+
if unique_full_ratio < 0.2:
484+
covs_warning_1.append(cov_name)
476485

477-
# Check for a small absolute number of unique values
478-
if num_values > 100:
479-
if num_unique_values < 20:
480-
covs_warning_2.append(cov_name)
486+
# Check for a small absolute number of unique values
487+
if num_values > 100:
488+
if num_unique_values < 20:
489+
covs_warning_2.append(cov_name)
481490

482-
# Check for a large number of duplicates of any individual value
483-
if np.any(x_j_hist > 2 * max_grid_size):
484-
covs_warning_3.append(cov_name)
491+
# Check for a large number of duplicates of any individual value
492+
if np.any(x_j_hist > 2 * max_grid_size):
493+
covs_warning_3.append(cov_name)
494+
495+
# Check for binary variables
496+
if num_unique_values == 2:
497+
covs_warning_4.append(cov_name)
485498

486499
if covs_warning_1:
487500
warnings.warn(
@@ -505,6 +518,13 @@ def sample(
505518
"Consider running with `num_mcmc > 0` and `num_burnin > 0` to improve your model's performance."
506519
)
507520

521+
if covs_warning_4:
522+
warnings.warn(
523+
f"Covariates {', '.join(covs_warning_4)} appear to be binary but are currently treated by stochtree as continuous. "
524+
"This might present some issues with the grow-from-root (GFR) algorithm. "
525+
"Consider converting binary variables to ordered categorical (i.e. `pd.Categorical(..., ordered = True)`."
526+
)
527+
508528
# Variable weight preprocessing (and initialization if necessary)
509529
p = X_train.shape[1]
510530
if variable_weights is None:

0 commit comments

Comments
 (0)