Skip to content

Commit 3f01307

Browse files
committed
Reflected this change through the Python interface as well
1 parent d67078b commit 3f01307

File tree

4 files changed

+189
-123
lines changed

4 files changed

+189
-123
lines changed

R/bart.R

Lines changed: 1 addition & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -468,10 +468,7 @@ bart <- function(
468468

469469
# Check for binary variables
470470
if (num_unique_values == 2) {
471-
already_flagged <- (num_values > 100) && (num_unique_values < 20)
472-
if (!already_flagged) {
473-
covs_warning_4 <- c(covs_warning_4, cov_name)
474-
}
471+
covs_warning_4 <- c(covs_warning_4, cov_name)
475472
}
476473
}
477474
}

R/bcf.R

Lines changed: 1 addition & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -573,10 +573,7 @@ bcf <- function(
573573

574574
# Check for binary variables
575575
if (num_unique_values == 2) {
576-
already_flagged <- (num_values > 100) && (num_unique_values < 20)
577-
if (!already_flagged) {
578-
covs_warning_4 <- c(covs_warning_4, cov_name)
579-
}
576+
covs_warning_4 <- c(covs_warning_4, cov_name)
580577
}
581578
}
582579
}

stochtree/bart.py

Lines changed: 39 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -456,32 +456,45 @@ def sample(
456456
if (num_gfr > 0) and (num_burnin == 0) and (num_mcmc == 0):
457457
num_values, num_cov_orig = X_train.shape
458458
max_grid_size = floor(num_values / cutpoint_grid_size)
459+
x_is_df = isinstance(X_train, pd.DataFrame)
459460
covs_warning_1 = []
460461
covs_warning_2 = []
461462
covs_warning_3 = []
463+
covs_warning_4 = []
462464
for i in range(num_cov_orig):
463-
# Determine the number of unique covariate values and a name for the covariate
464-
if isinstance(X_train, np.ndarray):
465-
x_j_hist = np.unique_counts(X_train[:, i]).counts
466-
cov_name = f"X{i + 1}"
467-
else:
468-
x_j_hist = (X_train.iloc[:, i]).value_counts()
469-
cov_name = X_train.columns[i]
465+
# Skip check for variables that are treated as categorical
466+
x_numeric = True
467+
if x_is_df:
468+
if isinstance(X_train.iloc[:,i].dtype, pd.CategoricalDtype):
469+
x_numeric = False
470+
471+
if x_numeric:
472+
# Determine the number of unique covariate values and a name for the covariate
473+
if isinstance(X_train, np.ndarray):
474+
x_j_hist = np.unique_counts(X_train[:, i]).counts
475+
cov_name = f"X{i + 1}"
476+
else:
477+
x_j_hist = (X_train.iloc[:, i]).value_counts()
478+
cov_name = X_train.columns[i]
470479

471-
# Check for a small relative number of unique values
472-
num_unique_values = len(x_j_hist)
473-
unique_full_ratio = num_unique_values / num_values
474-
if unique_full_ratio < 0.2:
475-
covs_warning_1.append(cov_name)
480+
# Check for a small relative number of unique values
481+
num_unique_values = len(x_j_hist)
482+
unique_full_ratio = num_unique_values / num_values
483+
if unique_full_ratio < 0.2:
484+
covs_warning_1.append(cov_name)
476485

477-
# Check for a small absolute number of unique values
478-
if num_values > 100:
479-
if num_unique_values < 20:
480-
covs_warning_2.append(cov_name)
486+
# Check for a small absolute number of unique values
487+
if num_values > 100:
488+
if num_unique_values < 20:
489+
covs_warning_2.append(cov_name)
481490

482-
# Check for a large number of duplicates of any individual value
483-
if np.any(x_j_hist > 2 * max_grid_size):
484-
covs_warning_3.append(cov_name)
491+
# Check for a large number of duplicates of any individual value
492+
if np.any(x_j_hist > 2 * max_grid_size):
493+
covs_warning_3.append(cov_name)
494+
495+
# Check for binary variables
496+
if num_unique_values == 2:
497+
covs_warning_4.append(cov_name)
485498

486499
if covs_warning_1:
487500
warnings.warn(
@@ -505,6 +518,13 @@ def sample(
505518
"Consider running with `num_mcmc > 0` and `num_burnin > 0` to improve your model's performance."
506519
)
507520

521+
if covs_warning_4:
522+
warnings.warn(
523+
f"Covariates {', '.join(covs_warning_4)} appear to be binary but are currently treated by stochtree as continuous. "
524+
"This might present some issues with the grow-from-root (GFR) algorithm. "
525+
"Consider converting binary variables to ordered categorical (i.e. `pd.Categorical(..., ordered = True)`."
526+
)
527+
508528
# Variable weight preprocessing (and initialization if necessary)
509529
p = X_train.shape[1]
510530
if variable_weights is None:

0 commit comments

Comments
 (0)