3333from ..utils import Substitution
3434from ..utils ._docstring import _n_jobs_docstring
3535from ..utils ._docstring import _random_state_docstring
36+ from ..utils ._validation import check_sampling_strategy
3637
3738MAX_INT = np .iinfo (np .int32 ).max
3839
@@ -364,7 +365,7 @@ def _validate_estimator(self, default=DecisionTreeClassifier()):
364365 self .base_estimator_ = clone (default )
365366
366367 self .base_sampler_ = RandomUnderSampler (
367- sampling_strategy = self .sampling_strategy ,
368+ sampling_strategy = self ._sampling_strategy ,
368369 replacement = self .replacement ,
369370 )
370371
@@ -447,10 +448,20 @@ def fit(self, X, y, sample_weight=None):
447448
448449 self .n_outputs_ = y .shape [1 ]
449450
450- y , expanded_class_weight = self ._validate_y_class_weight (y )
451+ y_encoded , expanded_class_weight = self ._validate_y_class_weight (y )
451452
452453 if getattr (y , "dtype" , None ) != DOUBLE or not y .flags .contiguous :
453- y = np .ascontiguousarray (y , dtype = DOUBLE )
454+ y_encoded = np .ascontiguousarray (y_encoded , dtype = DOUBLE )
455+
456+ if isinstance (self .sampling_strategy , dict ):
457+ self ._sampling_strategy = {
458+ np .where (self .classes_ [0 ] == key )[0 ][0 ]: value
459+ for key , value in check_sampling_strategy (
460+ self .sampling_strategy , y , 'under-sampling' ,
461+ ).items ()
462+ }
463+ else :
464+ self ._sampling_strategy = self .sampling_strategy
454465
455466 if expanded_class_weight is not None :
456467 if sample_weight is not None :
@@ -523,7 +534,7 @@ def fit(self, X, y, sample_weight=None):
523534 t ,
524535 self ,
525536 X ,
526- y ,
537+ y_encoded ,
527538 sample_weight ,
528539 i ,
529540 len (trees ),
@@ -548,7 +559,7 @@ def fit(self, X, y, sample_weight=None):
548559 )
549560
550561 if self .oob_score :
551- self ._set_oob_score (X , y )
562+ self ._set_oob_score (X , y_encoded )
552563
553564 # Decapsulate classes_ attributes
554565 if hasattr (self , "classes_" ) and self .n_outputs_ == 1 :
0 commit comments