Commit 65c8079

MNT synchronize imblearn.pipeline with sklearn.pipeline (#620)

2 parents cb45ad0 + 1d2d73c

File tree: 4 files changed, +188 -32 lines


doc/whats_new/v0.6.rst

Lines changed: 3 additions & 0 deletions
@@ -31,6 +31,9 @@ Maintenance
    :class:`sklearn.utils._testing.SkipTest`.
    :pr:`617` by :user:`Guillaume Lemaitre <glemaitre>`.

+- Synchronize :mod:`imblearn.pipeline` with :mod:`sklearn.pipeline`.
+  :pr:`617` by :user:`Guillaume Lemaitre <glemaitre>`.
+
 Deprecation
 ...........


imblearn/pipeline.py

Lines changed: 99 additions & 26 deletions
@@ -15,6 +15,7 @@

 from sklearn import pipeline
 from sklearn.base import clone
+from sklearn.utils import Bunch, _print_elapsed_time
 from sklearn.utils.metaestimators import if_delegate_has_method
 from sklearn.utils.validation import check_memory

@@ -57,10 +58,13 @@ class Pipeline(pipeline.Pipeline):
         inspect estimators within the pipeline. Caching the
         transformers is advantageous when fitting is time consuming.

+    verbose : boolean, optional (default=False)
+        If True, the time elapsed while fitting each step will be printed as it
+        is completed.
+
     Attributes
     ----------
-    named_steps : dict
+    named_steps : bunch object, a dictionary with attribute access
         Read-only attribute to access any step parameter by user given name.
         Keys are step names and values are steps parameters.
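Note: taken together, these two docstring updates track behaviour inherited from scikit-learn: `named_steps` is a `Bunch` (attribute access in addition to key access), and `verbose=True` prints per-step timing. A minimal sketch of both, with illustrative step names and toy data (the timing lines follow the format asserted by the new tests further down):

from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression

from imblearn.pipeline import Pipeline
from imblearn.under_sampling import RandomUnderSampler

X, y = make_classification(weights=[0.9, 0.1], random_state=0)
pipe = Pipeline(
    [("rus", RandomUnderSampler(random_state=0)),
     ("clf", LogisticRegression())],
    verbose=True,
)
pipe.fit(X, y)
# [Pipeline] ....... (step 1 of 2) Processing rus, total=   0.0s
# [Pipeline] ....... (step 2 of 2) Processing clf, total=   0.0s

# named_steps is a Bunch: attribute and key access return the same step
assert pipe.named_steps.rus is pipe.named_steps["rus"]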
@@ -178,9 +182,23 @@ def _fit(self, X, y=None, **fit_params):
             name: {} for name, step in self.steps if step is not None
         }
         for pname, pval in fit_params.items():
+            if '__' not in pname:
+                raise ValueError(
+                    "Pipeline.fit does not accept the {} parameter. "
+                    "You can pass parameters to specific steps of your "
+                    "pipeline using the stepname__parameter format, e.g. "
+                    "`Pipeline.fit(X, y, logisticregression__sample_weight"
+                    "=sample_weight)`.".format(pname))
             step, param = pname.split("__", 1)
             fit_params_steps[step][param] = pval
-        for step_idx, name, transformer in self._iter(with_final=False):
+        for (step_idx,
+             name,
+             transformer) in self._iter(with_final=False,
+                                        filter_passthrough=False):
+            if (transformer is None or transformer == 'passthrough'):
+                with _print_elapsed_time('Pipeline',
+                                         self._log_message(step_idx)):
+                    continue
             if hasattr(memory, "location"):
                 # joblib >= 0.12
                 if memory.location is None:
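Note: the new guard replaces the bare unpacking error that `pname.split("__", 1)` used to raise when a fit parameter was not addressed to a step. A quick sketch of the failure and the corrected call (toy data; the single step is auto-named `logisticregression` by `make_pipeline`):

from sklearn.linear_model import LogisticRegression

from imblearn.pipeline import make_pipeline

clf = make_pipeline(LogisticRegression())
try:
    # unrouted parameter: now rejected with an explicit message
    clf.fit([[0], [1]], [0, 1], sample_weight=[1, 1])
except ValueError as exc:
    print(exc)  # "Pipeline.fit does not accept the sample_weight parameter..."

# route the parameter to its step with the stepname__parameter syntax
clf.fit([[0], [1]], [0, 1], logisticregression__sample_weight=[1, 1])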
@@ -202,11 +220,17 @@ def _fit(self, X, y=None, **fit_params):
                 cloned_transformer, "fit_transform"
             ):
                 X, fitted_transformer = fit_transform_one_cached(
-                    cloned_transformer, None, X, y, **fit_params_steps[name]
+                    cloned_transformer, X, y, None,
+                    message_clsname='Pipeline',
+                    message=self._log_message(step_idx),
+                    **fit_params_steps[name]
                 )
             elif hasattr(cloned_transformer, "fit_resample"):
                 X, y, fitted_transformer = fit_resample_one_cached(
-                    cloned_transformer, X, y, **fit_params_steps[name]
+                    cloned_transformer, X, y,
+                    message_clsname='Pipeline',
+                    message=self._log_message(step_idx),
+                    **fit_params_steps[name]
                 )
             # Replace the transformer of the step with the fitted
             # transformer. This is necessary when loading the transformer
@@ -245,8 +269,10 @@ def fit(self, X, y=None, **fit_params):

         """
         Xt, yt, fit_params = self._fit(X, y, **fit_params)
-        if self._final_estimator != "passthrough":
-            self._final_estimator.fit(Xt, yt, **fit_params)
+        with _print_elapsed_time('Pipeline',
+                                 self._log_message(len(self.steps) - 1)):
+            if self._final_estimator != "passthrough":
+                self._final_estimator.fit(Xt, yt, **fit_params)
         return self

     def fit_transform(self, X, y=None, **fit_params):
@@ -279,12 +305,14 @@ def fit_transform(self, X, y=None, **fit_params):
         """
         last_step = self._final_estimator
         Xt, yt, fit_params = self._fit(X, y, **fit_params)
-        if last_step == "passthrough":
-            return Xt
-        elif hasattr(last_step, "fit_transform"):
-            return last_step.fit_transform(Xt, yt, **fit_params)
-        else:
-            return last_step.fit(Xt, yt, **fit_params).transform(Xt)
+        with _print_elapsed_time('Pipeline',
+                                 self._log_message(len(self.steps) - 1)):
+            if last_step == "passthrough":
+                return Xt
+            elif hasattr(last_step, "fit_transform"):
+                return last_step.fit_transform(Xt, yt, **fit_params)
+            else:
+                return last_step.fit(Xt, yt, **fit_params).transform(Xt)

     def fit_resample(self, X, y=None, **fit_params):
         """Fit the model and sample with the final estimator
@@ -319,10 +347,12 @@
         """
         last_step = self._final_estimator
         Xt, yt, fit_params = self._fit(X, y, **fit_params)
-        if last_step == "passthrough":
-            return Xt
-        elif hasattr(last_step, "fit_resample"):
-            return last_step.fit_resample(Xt, yt, **fit_params)
+        with _print_elapsed_time('Pipeline',
+                                 self._log_message(len(self.steps) - 1)):
+            if last_step == "passthrough":
+                return Xt
+            elif hasattr(last_step, "fit_resample"):
+                return last_step.fit_resample(Xt, yt, **fit_params)

     @if_delegate_has_method(delegate="_final_estimator")
     def predict(self, X, **predict_params):
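Note: the return value of `fit_resample` is unchanged; only the timing context is new. A small sketch for a pipeline whose final step is a sampler (toy data, illustrative):

from collections import Counter

from sklearn.datasets import make_classification

from imblearn.pipeline import make_pipeline
from imblearn.under_sampling import RandomUnderSampler

X, y = make_classification(weights=[0.9, 0.1], random_state=0)
pipe = make_pipeline(RandomUnderSampler(random_state=0))
X_res, y_res = pipe.fit_resample(X, y)
print(Counter(y_res))  # both classes reduced to the minority count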
@@ -384,7 +414,10 @@ def fit_predict(self, X, y=None, **fit_params):
         y_pred : array-like
         """
         Xt, yt, fit_params = self._fit(X, y, **fit_params)
-        return self.steps[-1][-1].fit_predict(Xt, yt, **fit_params)
+        with _print_elapsed_time('Pipeline',
+                                 self._log_message(len(self.steps) - 1)):
+            y_pred = self.steps[-1][-1].fit_predict(Xt, yt, **fit_params)
+        return y_pred

     @if_delegate_has_method(delegate="_final_estimator")
     def predict_proba(self, X):
@@ -575,22 +608,55 @@ def score(self, X, y=None, sample_weight=None):
             score_params["sample_weight"] = sample_weight
         return self.steps[-1][-1].score(Xt, y, **score_params)

+    @if_delegate_has_method(delegate='_final_estimator')
+    def score_samples(self, X):
+        """Apply transforms, and score_samples of the final estimator.
+
+        Parameters
+        ----------
+        X : iterable
+            Data to predict on. Must fulfill input requirements of first step
+            of the pipeline.
+
+        Returns
+        -------
+        y_score : ndarray, shape (n_samples,)
+        """
+        Xt = X
+        for _, _, transformer in self._iter(with_final=False):
+            if hasattr(transformer, "fit_resample"):
+                pass
+            else:
+                Xt = transformer.transform(Xt)
+        return self.steps[-1][-1].score_samples(Xt)


-def _fit_transform_one(transformer, weight, X, y, **fit_params):
-    if hasattr(transformer, "fit_transform"):
-        res = transformer.fit_transform(X, y, **fit_params)
-    else:
-        res = transformer.fit(X, y, **fit_params).transform(X)
+def _fit_transform_one(transformer,
+                       X,
+                       y,
+                       weight,
+                       message_clsname='',
+                       message=None,
+                       **fit_params):
+    with _print_elapsed_time(message_clsname, message):
+        if hasattr(transformer, "fit_transform"):
+            res = transformer.fit_transform(X, y, **fit_params)
+        else:
+            res = transformer.fit(X, y, **fit_params).transform(X)
     # if we have a weight for this transformer, multiply output
     if weight is None:
         return res, transformer
     return res * weight, transformer


-def _fit_resample_one(sampler, X, y, **fit_params):
-    X_res, y_res = sampler.fit_resample(X, y, **fit_params)
+def _fit_resample_one(sampler,
+                      X,
+                      y,
+                      message_clsname='',
+                      message=None,
+                      **fit_params):
+    with _print_elapsed_time(message_clsname, message):
+        X_res, y_res = sampler.fit_resample(X, y, **fit_params)

-    return X_res, y_res, sampler
+        return X_res, y_res, sampler


 def make_pipeline(*steps, **kwargs):
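Note: the new `score_samples` transforms `X` with every non-sampler step and delegates to the final estimator; samplers are skipped because resampling only applies to the training set, never at prediction time. A sketch mirroring the new test below (estimator choices illustrative):

from sklearn.datasets import load_iris
from sklearn.decomposition import PCA
from sklearn.neighbors import LocalOutlierFactor

from imblearn.pipeline import Pipeline
from imblearn.under_sampling import RandomUnderSampler

X, y = load_iris(return_X_y=True)
pipe = Pipeline([
    ("rus", RandomUnderSampler(random_state=0)),
    ("pca", PCA(n_components=2)),
    ("lof", LocalOutlierFactor(novelty=True)),
])
pipe.fit(X, y)  # the sampler resamples here, during fit only
scores = pipe.score_samples(X)  # the sampler is bypassed here
assert scores.shape == (X.shape[0],)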
@@ -614,6 +680,10 @@ def make_pipeline(*steps, **kwargs):
         inspect estimators within the pipeline. Caching the
         transformers is advantageous when fitting is time consuming.

+    verbose : boolean, optional (default=False)
+        If True, the time elapsed while fitting each step will be printed as it
+        is completed.
+
     Returns
     -------
     p : Pipeline
@@ -637,8 +707,11 @@ def make_pipeline(*steps, **kwargs):
             verbose=False)
     """
     memory = kwargs.pop("memory", None)
+    verbose = kwargs.pop('verbose', False)
     if kwargs:
         raise TypeError(
             'Unknown keyword arguments: "{}"'.format(list(kwargs.keys())[0])
         )
-    return Pipeline(pipeline._name_estimators(steps), memory=memory)
+    return Pipeline(
+        pipeline._name_estimators(steps), memory=memory, verbose=verbose
+    )
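Note: because `verbose` is popped from `kwargs` before the unknown-keyword check, `make_pipeline` forwards it without weakening the existing `TypeError` for stray arguments. For example:

from sklearn.linear_model import LogisticRegression

from imblearn.pipeline import make_pipeline
from imblearn.under_sampling import RandomUnderSampler

# steps are auto-named "randomundersampler" and "logisticregression"
pipe = make_pipeline(RandomUnderSampler(), LogisticRegression(), verbose=True)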

imblearn/tests/test_pipeline.py

Lines changed: 86 additions & 5 deletions
@@ -5,9 +5,11 @@
 # Christos Aridas
 # License: MIT

-from tempfile import mkdtemp
+import itertools
+import re
 import shutil
 import time
+from tempfile import mkdtemp

 import numpy as np
 import pytest
@@ -29,12 +31,12 @@
 from sklearn.feature_selection import SelectKBest, f_classif
 from sklearn.datasets import load_iris, make_classification
 from sklearn.preprocessing import StandardScaler
+from sklearn.pipeline import FeatureUnion

+from imblearn.datasets import make_imbalance
 from imblearn.pipeline import Pipeline, make_pipeline
-from imblearn.under_sampling import (
-    RandomUnderSampler,
-    EditedNearestNeighbours as ENN,
-)
+from imblearn.under_sampling import RandomUnderSampler
+from imblearn.under_sampling import EditedNearestNeighbours as ENN


 JUNK_FOOD_DOCS = (
@@ -1261,3 +1263,82 @@ def test_score_samples_on_pipeline_without_score_samples():
         "'score_samples'",
     ):
         pipe.score_samples(X)
+
+
+def test_pipeline_param_error():
+    clf = make_pipeline(LogisticRegression())
+    with pytest.raises(ValueError, match="Pipeline.fit does not accept "
+                                         "the sample_weight parameter"):
+        clf.fit([[0], [0]], [0, 1], sample_weight=[1, 1])
+
+
+parameter_grid_test_verbose = ((est, pattern, method) for
+                               (est, pattern), method in itertools.product(
+    [
+        (Pipeline([('transf', Transf()), ('clf', FitParamT())]),
+         r'\[Pipeline\].*\(step 1 of 2\) Processing transf.* total=.*\n'
+         r'\[Pipeline\].*\(step 2 of 2\) Processing clf.* total=.*\n$'),
+        (Pipeline([('transf', Transf()), ('noop', None),
+                   ('clf', FitParamT())]),
+         r'\[Pipeline\].*\(step 1 of 3\) Processing transf.* total=.*\n'
+         r'\[Pipeline\].*\(step 2 of 3\) Processing noop.* total=.*\n'
+         r'\[Pipeline\].*\(step 3 of 3\) Processing clf.* total=.*\n$'),
+        (Pipeline([('transf', Transf()), ('noop', 'passthrough'),
+                   ('clf', FitParamT())]),
+         r'\[Pipeline\].*\(step 1 of 3\) Processing transf.* total=.*\n'
+         r'\[Pipeline\].*\(step 2 of 3\) Processing noop.* total=.*\n'
+         r'\[Pipeline\].*\(step 3 of 3\) Processing clf.* total=.*\n$'),
+        (Pipeline([('transf', Transf()), ('clf', None)]),
+         r'\[Pipeline\].*\(step 1 of 2\) Processing transf.* total=.*\n'
+         r'\[Pipeline\].*\(step 2 of 2\) Processing clf.* total=.*\n$'),
+        (Pipeline([('transf', None), ('mult', Mult())]),
+         r'\[Pipeline\].*\(step 1 of 2\) Processing transf.* total=.*\n'
+         r'\[Pipeline\].*\(step 2 of 2\) Processing mult.* total=.*\n$'),
+        (Pipeline([('transf', 'passthrough'), ('mult', Mult())]),
+         r'\[Pipeline\].*\(step 1 of 2\) Processing transf.* total=.*\n'
+         r'\[Pipeline\].*\(step 2 of 2\) Processing mult.* total=.*\n$'),
+        (FeatureUnion([('mult1', Mult()), ('mult2', Mult())]),
+         r'\[FeatureUnion\].*\(step 1 of 2\) Processing mult1.* total=.*\n'
+         r'\[FeatureUnion\].*\(step 2 of 2\) Processing mult2.* total=.*\n$'),
+        (FeatureUnion([('mult1', 'drop'), ('mult2', Mult()),
+                       ('mult3', 'drop')]),
+         r'\[FeatureUnion\].*\(step 1 of 1\) Processing mult2.* total=.*\n$')
+    ], ['fit', 'fit_transform', 'fit_predict'])
+    if hasattr(est, method) and not (
+        method == 'fit_transform' and hasattr(est, 'steps') and
+        isinstance(est.steps[-1][1], FitParamT))
+)
+
+
+@pytest.mark.parametrize('est, pattern, method', parameter_grid_test_verbose)
+def test_verbose(est, method, pattern, capsys):
+    func = getattr(est, method)
+
+    X = [[1, 2, 3], [4, 5, 6]]
+    y = [[7], [8]]
+
+    est.set_params(verbose=False)
+    func(X, y)
+    assert not capsys.readouterr().out, 'Got output for verbose=False'
+
+    est.set_params(verbose=True)
+    func(X, y)
+    assert re.match(pattern, capsys.readouterr().out)
+
+
+def test_pipeline_score_samples_pca_lof():
+    X, y = load_iris(return_X_y=True)
+    sampling_strategy = {0: 50, 1: 30, 2: 20}
+    X, y = make_imbalance(X, y, sampling_strategy=sampling_strategy)
+    # Test that the score_samples method is implemented on a pipeline.
+    # Test that the score_samples method on pipeline yields same results as
+    # applying transform and score_samples steps separately.
+    rus = RandomUnderSampler()
+    pca = PCA(svd_solver='full', n_components='mle', whiten=True)
+    lof = LocalOutlierFactor(novelty=True)
+    pipe = Pipeline([('rus', rus), ('pca', pca), ('lof', lof)])
+    pipe.fit(X, y)
+    # Check the shapes
+    assert pipe.score_samples(X).shape == (X.shape[0],)
+    # Check the values
+    lof.fit(pca.fit_transform(X))
+    assert_allclose(pipe.score_samples(X), lof.score_samples(pca.transform(X)))
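Note: outside pytest, the behaviour checked by `test_verbose` can be reproduced with `redirect_stdout` in place of the `capsys` fixture; a standalone sketch against the same output format:

import re
from contextlib import redirect_stdout
from io import StringIO

from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression

from imblearn.pipeline import make_pipeline
from imblearn.under_sampling import RandomUnderSampler

X, y = make_classification(weights=[0.8, 0.2], random_state=0)
pipe = make_pipeline(RandomUnderSampler(), LogisticRegression(), verbose=True)

buf = StringIO()
with redirect_stdout(buf):
    pipe.fit(X, y)
assert re.match(
    r"\[Pipeline\].*\(step 1 of 2\) Processing randomundersampler.* total=.*\n"
    r"\[Pipeline\].*\(step 2 of 2\) Processing logisticregression.* total=.*\n$",
    buf.getvalue(),
)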

setup.cfg

Lines changed: 0 additions & 1 deletion
@@ -28,7 +28,6 @@ addopts =
     --ignore examples
     --ignore maint_tools
     --doctest-modules
-    --disable-pytest-warnings
     -rs

 filterwarnings =
