1717from sklearn .base import clone
1818from sklearn .ensemble import RandomForestClassifier
1919from sklearn .ensemble ._base import _set_random_states
20+ from sklearn .ensemble ._forest import _get_n_samples_bootstrap
2021from sklearn .ensemble ._forest import _parallel_build_trees
2122from sklearn .exceptions import DataConversionWarning
2223from sklearn .tree import DecisionTreeClassifier
@@ -44,6 +45,7 @@ def _local_parallel_build_trees(
4445 n_trees ,
4546 verbose = 0 ,
4647 class_weight = None ,
48+ n_samples_bootstrap = None
4749):
4850 # resample before to fit the tree
4951 X_resampled , y_resampled = sampler .fit_resample (X , y )
@@ -59,7 +61,7 @@ def _local_parallel_build_trees(
5961 n_trees ,
6062 verbose = verbose ,
6163 class_weight = class_weight ,
62- n_samples_bootstrap = X_resampled . shape [ 0 ] ,
64+ n_samples_bootstrap = n_samples_bootstrap ,
6365 )
6466 return sampler , tree
6567
@@ -195,6 +197,27 @@ class BalancedRandomForestClassifier(RandomForestClassifier):
195197 Note that these weights will be multiplied with sample_weight (passed
196198 through the fit method) if sample_weight is specified.
197199
200+
201+ ccp_alpha : non-negative float, optional (default=0.0)
202+ Complexity parameter used for Minimal Cost-Complexity Pruning. The
203+ subtree with the largest cost complexity that is smaller than
204+ ``ccp_alpha`` will be chosen. By default, no pruning is performed. See
205+ :ref:`minimal_cost_complexity_pruning` for details.
206+
207+ .. versionadded:: 0.22
208+ Added in `scikit-learn` in 0.22
209+
210+ max_samples : int or float, default=None
211+ If bootstrap is True, the number of samples to draw from X
212+ to train each base estimator.
213+ - If None (default), then draw `X.shape[0]` samples.
214+ - If int, then draw `max_samples` samples.
215+ - If float, then draw `max_samples * X.shape[0]` samples. Thus,
216+ `max_samples` should be in the interval `(0, 1)`.
217+
218+ .. versionadded:: 0.22
219+ Added in `scikit-learn` in 0.22
220+
198221 Attributes
199222 ----------
200223 estimators_ : list of DecisionTreeClassifier
@@ -281,6 +304,8 @@ def __init__(
281304 verbose = 0 ,
282305 warm_start = False ,
283306 class_weight = None ,
307+ ccp_alpha = 0.0 ,
308+ max_samples = None ,
284309 ):
285310 super ().__init__ (
286311 criterion = criterion ,
@@ -299,6 +324,8 @@ def __init__(
299324 max_features = max_features ,
300325 max_leaf_nodes = max_leaf_nodes ,
301326 min_impurity_decrease = min_impurity_decrease ,
327+ ccp_alpha = ccp_alpha ,
328+ max_samples = max_samples ,
302329 )
303330
304331 self .sampling_strategy = sampling_strategy
@@ -414,6 +441,12 @@ def fit(self, X, y, sample_weight=None):
414441 else :
415442 sample_weight = expanded_class_weight
416443
444+ # Get bootstrap sample size
445+ n_samples_bootstrap = _get_n_samples_bootstrap (
446+ n_samples = X .shape [0 ],
447+ max_samples = self .max_samples
448+ )
449+
417450 # Check parameters
418451 self ._validate_estimator ()
419452
@@ -479,6 +512,7 @@ def fit(self, X, y, sample_weight=None):
479512 len (trees ),
480513 verbose = self .verbose ,
481514 class_weight = self .class_weight ,
515+ n_samples_bootstrap = n_samples_bootstrap ,
482516 )
483517 for i , (s , t ) in enumerate (zip (samplers , trees ))
484518 )