diff --git a/whitebox/analytics/models/pipelines.py b/whitebox/analytics/models/pipelines.py index 02d63ba..ceea0ab 100644 --- a/whitebox/analytics/models/pipelines.py +++ b/whitebox/analytics/models/pipelines.py @@ -95,7 +95,7 @@ def create_multiclass_classification_training_model_pipeline( """ X_train, X_test, y_train, y_test = train_test_split( - X, Y, test_size=0.3, random_state=0 + X, Y, test_size=0.3, random_state=0, stratify = Y ) d_train = lgb.Dataset(X_train, label=y_train) """