@@ -456,32 +456,45 @@ def sample(
456456 if (num_gfr > 0 ) and (num_burnin == 0 ) and (num_mcmc == 0 ):
457457 num_values , num_cov_orig = X_train .shape
458458 max_grid_size = floor (num_values / cutpoint_grid_size )
459+ x_is_df = isinstance (X_train , pd .DataFrame )
459460 covs_warning_1 = []
460461 covs_warning_2 = []
461462 covs_warning_3 = []
463+ covs_warning_4 = []
462464 for i in range (num_cov_orig ):
463- # Determine the number of unique covariate values and a name for the covariate
464- if isinstance (X_train , np .ndarray ):
465- x_j_hist = np .unique_counts (X_train [:, i ]).counts
466- cov_name = f"X{ i + 1 } "
467- else :
468- x_j_hist = (X_train .iloc [:, i ]).value_counts ()
469- cov_name = X_train .columns [i ]
465+ # Skip check for variables that are treated as categorical
466+ x_numeric = True
467+ if x_is_df :
468+ if isinstance (X_train .iloc [:,i ].dtype , pd .CategoricalDtype ):
469+ x_numeric = False
470+
471+ if x_numeric :
472+ # Determine the number of unique covariate values and a name for the covariate
473+ if isinstance (X_train , np .ndarray ):
474+ x_j_hist = np .unique_counts (X_train [:, i ]).counts
475+ cov_name = f"X{ i + 1 } "
476+ else :
477+ x_j_hist = (X_train .iloc [:, i ]).value_counts ()
478+ cov_name = X_train .columns [i ]
470479
471- # Check for a small relative number of unique values
472- num_unique_values = len (x_j_hist )
473- unique_full_ratio = num_unique_values / num_values
474- if unique_full_ratio < 0.2 :
475- covs_warning_1 .append (cov_name )
480+ # Check for a small relative number of unique values
481+ num_unique_values = len (x_j_hist )
482+ unique_full_ratio = num_unique_values / num_values
483+ if unique_full_ratio < 0.2 :
484+ covs_warning_1 .append (cov_name )
476485
477- # Check for a small absolute number of unique values
478- if num_values > 100 :
479- if num_unique_values < 20 :
480- covs_warning_2 .append (cov_name )
486+ # Check for a small absolute number of unique values
487+ if num_values > 100 :
488+ if num_unique_values < 20 :
489+ covs_warning_2 .append (cov_name )
481490
482- # Check for a large number of duplicates of any individual value
483- if np .any (x_j_hist > 2 * max_grid_size ):
484- covs_warning_3 .append (cov_name )
491+ # Check for a large number of duplicates of any individual value
492+ if np .any (x_j_hist > 2 * max_grid_size ):
493+ covs_warning_3 .append (cov_name )
494+
495+ # Check for binary variables
496+ if num_unique_values == 2 :
497+ covs_warning_4 .append (cov_name )
485498
486499 if covs_warning_1 :
487500 warnings .warn (
@@ -505,6 +518,13 @@ def sample(
505518 "Consider running with `num_mcmc > 0` and `num_burnin > 0` to improve your model's performance."
506519 )
507520
521+ if covs_warning_4 :
522+ warnings .warn (
523+ f"Covariates { ', ' .join (covs_warning_4 )} appear to be binary but are currently treated by stochtree as continuous. "
524+ "This might present some issues with the grow-from-root (GFR) algorithm. "
525+ "Consider converting binary variables to ordered categorical (i.e. `pd.Categorical(..., ordered = True)`)."
526+ )
527+
508528 # Variable weight preprocessing (and initialization if necessary)
509529 p = X_train .shape [1 ]
510530 if variable_weights is None :
0 commit comments