Skip to content

Commit 9893fca

Browse files
authored
MNT synchronize forest with scikit-learn (#622)
* MNT add max_samples parameter * MNT add ccp_alpha parameter * DOC whats new
1 parent 65c8079 commit 9893fca

File tree

3 files changed

+82
-2
lines changed

3 files changed

+82
-2
lines changed

doc/whats_new/v0.6.rst

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,11 @@ Maintenance
3232
:pr:`617` by :user:`Guillaume Lemaitre <glemaitre>`.
3333

3434
- Synchronize :mod:`imblearn.pipeline` with :mod:`sklearn.pipeline`.
35-
:pr:`617` by :user:`Guillaume Lemaitre <glemaitre>`.
35+
:pr:`620` by :user:`Guillaume Lemaitre <glemaitre>`.
36+
37+
- Synchronize :class:`imblearn.ensemble.BalancedRandomForestClassifier` and add
38+
parameters `max_samples` and `ccp_alpha`.
39+
:pr:`621` by :user:`Guillaume Lemaitre <glemaitre>`.
3640

3741
Deprecation
3842
...........

imblearn/ensemble/_forest.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from sklearn.base import clone
1818
from sklearn.ensemble import RandomForestClassifier
1919
from sklearn.ensemble._base import _set_random_states
20+
from sklearn.ensemble._forest import _get_n_samples_bootstrap
2021
from sklearn.ensemble._forest import _parallel_build_trees
2122
from sklearn.exceptions import DataConversionWarning
2223
from sklearn.tree import DecisionTreeClassifier
@@ -44,6 +45,7 @@ def _local_parallel_build_trees(
4445
n_trees,
4546
verbose=0,
4647
class_weight=None,
48+
n_samples_bootstrap=None
4749
):
4850
# resample before to fit the tree
4951
X_resampled, y_resampled = sampler.fit_resample(X, y)
@@ -59,7 +61,7 @@ def _local_parallel_build_trees(
5961
n_trees,
6062
verbose=verbose,
6163
class_weight=class_weight,
62-
n_samples_bootstrap=X_resampled.shape[0],
64+
n_samples_bootstrap=n_samples_bootstrap,
6365
)
6466
return sampler, tree
6567

@@ -195,6 +197,27 @@ class BalancedRandomForestClassifier(RandomForestClassifier):
195197
Note that these weights will be multiplied with sample_weight (passed
196198
through the fit method) if sample_weight is specified.
197199
200+
201+
ccp_alpha : non-negative float, optional (default=0.0)
202+
Complexity parameter used for Minimal Cost-Complexity Pruning. The
203+
subtree with the largest cost complexity that is smaller than
204+
``ccp_alpha`` will be chosen. By default, no pruning is performed. See
205+
:ref:`minimal_cost_complexity_pruning` for details.
206+
207+
.. versionadded:: 0.22
208+
Added in `scikit-learn` in 0.22
209+
210+
max_samples : int or float, default=None
211+
If bootstrap is True, the number of samples to draw from X
212+
to train each base estimator.
213+
- If None (default), then draw `X.shape[0]` samples.
214+
- If int, then draw `max_samples` samples.
215+
- If float, then draw `max_samples * X.shape[0]` samples. Thus,
216+
`max_samples` should be in the interval `(0, 1)`.
217+
218+
.. versionadded:: 0.22
219+
Added in `scikit-learn` in 0.22
220+
198221
Attributes
199222
----------
200223
estimators_ : list of DecisionTreeClassifier
@@ -281,6 +304,8 @@ def __init__(
281304
verbose=0,
282305
warm_start=False,
283306
class_weight=None,
307+
ccp_alpha=0.0,
308+
max_samples=None,
284309
):
285310
super().__init__(
286311
criterion=criterion,
@@ -299,6 +324,8 @@ def __init__(
299324
max_features=max_features,
300325
max_leaf_nodes=max_leaf_nodes,
301326
min_impurity_decrease=min_impurity_decrease,
327+
ccp_alpha=ccp_alpha,
328+
max_samples=max_samples,
302329
)
303330

304331
self.sampling_strategy = sampling_strategy
@@ -414,6 +441,12 @@ def fit(self, X, y, sample_weight=None):
414441
else:
415442
sample_weight = expanded_class_weight
416443

444+
# Get bootstrap sample size
445+
n_samples_bootstrap = _get_n_samples_bootstrap(
446+
n_samples=X.shape[0],
447+
max_samples=self.max_samples
448+
)
449+
417450
# Check parameters
418451
self._validate_estimator()
419452

@@ -479,6 +512,7 @@ def fit(self, X, y, sample_weight=None):
479512
len(trees),
480513
verbose=self.verbose,
481514
class_weight=self.class_weight,
515+
n_samples_bootstrap=n_samples_bootstrap,
482516
)
483517
for i, (s, t) in enumerate(zip(samplers, trees))
484518
)

imblearn/ensemble/tests/test_forest.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,3 +134,45 @@ def test_balanced_random_forest_grid_search(imbalanced_dataset):
134134
brf, {"n_estimators": (1, 2), "max_depth": (1, 2)}, cv=3
135135
)
136136
grid.fit(*imbalanced_dataset)
137+
138+
139+
def test_little_tree_with_small_max_samples():
140+
rng = np.random.RandomState(1)
141+
142+
X = rng.randn(10000, 2)
143+
y = rng.randn(10000) > 0
144+
145+
# First fit with no restriction on max samples
146+
est1 = BalancedRandomForestClassifier(
147+
n_estimators=1,
148+
random_state=rng,
149+
max_samples=None,
150+
)
151+
152+
# Second fit with max samples restricted to just 2
153+
est2 = BalancedRandomForestClassifier(
154+
n_estimators=1,
155+
random_state=rng,
156+
max_samples=2,
157+
)
158+
159+
est1.fit(X, y)
160+
est2.fit(X, y)
161+
162+
tree1 = est1.estimators_[0].tree_
163+
tree2 = est2.estimators_[0].tree_
164+
165+
msg = "Tree without `max_samples` restriction should have more nodes"
166+
assert tree1.node_count > tree2.node_count, msg
167+
168+
169+
def test_balanced_random_forest_pruning(imbalanced_dataset):
170+
brf = BalancedRandomForestClassifier()
171+
brf.fit(*imbalanced_dataset)
172+
n_nodes_no_pruning = brf.estimators_[0].tree_.node_count
173+
174+
brf_pruned = BalancedRandomForestClassifier(ccp_alpha=0.015)
175+
brf_pruned.fit(*imbalanced_dataset)
176+
n_nodes_pruning = brf_pruned.estimators_[0].tree_.node_count
177+
178+
assert n_nodes_no_pruning > n_nodes_pruning

0 commit comments

Comments
 (0)