
Commit bb1b9a6

other updates
1 parent 1fcea5f commit bb1b9a6

File tree

imblearn/pipeline.py
imblearn/tests/test_pipeline.py
setup.cfg

3 files changed: 147 additions & 27 deletions

imblearn/pipeline.py

Lines changed: 90 additions & 25 deletions
@@ -58,6 +58,9 @@ class Pipeline(pipeline.Pipeline):
         inspect estimators within the pipeline. Caching the
         transformers is advantageous when fitting is time consuming.
 
+    verbose : boolean, optional (default=False)
+        If True, the time elapsed while fitting each step will be printed as it
+        is completed.
+
     Attributes
     ----------
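For context, a minimal sketch (not part of this commit) of the new flag in use;
the estimator choices are illustrative and the dot padding and timings in the
printed output will vary:

    from sklearn.datasets import make_classification
    from sklearn.linear_model import LogisticRegression
    from imblearn.pipeline import Pipeline
    from imblearn.under_sampling import RandomUnderSampler

    X, y = make_classification(weights=[0.9, 0.1], random_state=0)
    pipe = Pipeline([('sampler', RandomUnderSampler(random_state=0)),
                     ('clf', LogisticRegression(solver='lbfgs'))],
                    verbose=True)
    pipe.fit(X, y)
    # [Pipeline] ........... (step 1 of 2) Processing sampler, total=   0.0s
    # [Pipeline] ............... (step 2 of 2) Processing clf, total=   0.0s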
@@ -188,7 +191,14 @@ def _fit(self, X, y=None, **fit_params):
                     "=sample_weight)`.".format(pname))
             step, param = pname.split("__", 1)
             fit_params_steps[step][param] = pval
-        for step_idx, name, transformer in self._iter(with_final=False):
+        for (step_idx,
+             name,
+             transformer) in self._iter(with_final=False,
+                                        filter_passthrough=False):
+            if (transformer is None or transformer == 'passthrough'):
+                with _print_elapsed_time('Pipeline',
+                                         self._log_message(step_idx)):
+                    continue
             if hasattr(memory, "location"):
                 # joblib >= 0.12
                 if memory.location is None:
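The timing wrapper used throughout is sklearn's private helper
_print_elapsed_time (presumably imported from sklearn.utils). As a rough sketch
of its contract only, not the actual implementation: a message of None, which is
what _log_message returns when verbose=False, silences it entirely.

    import time
    from contextlib import contextmanager

    @contextmanager
    def _print_elapsed_time(source, message=None):
        # message is None when verbose=False: time nothing, print nothing.
        if message is None:
            yield
        else:
            start = time.time()
            yield
            print("[%s] %s, total=%5.1fs"
                  % (source, message, time.time() - start))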
@@ -210,11 +220,17 @@ def _fit(self, X, y=None, **fit_params):
                 cloned_transformer, "fit_transform"
             ):
                 X, fitted_transformer = fit_transform_one_cached(
-                    cloned_transformer, None, X, y, **fit_params_steps[name]
+                    cloned_transformer, X, y, None,
+                    message_clsname='Pipeline',
+                    message=self._log_message(step_idx),
+                    **fit_params_steps[name]
                 )
             elif hasattr(cloned_transformer, "fit_resample"):
                 X, y, fitted_transformer = fit_resample_one_cached(
-                    cloned_transformer, X, y, **fit_params_steps[name]
+                    cloned_transformer, X, y,
+                    message_clsname='Pipeline',
+                    message=self._log_message(step_idx),
+                    **fit_params_steps[name]
                 )
             # Replace the transformer of the step with the fitted
             # transformer. This is necessary when loading the transformer
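A side effect worth noting, assuming joblib's documented caching behavior: the
timing call now lives inside the memoized _fit_transform_one /
_fit_resample_one, so when a step's fit is served from the cache the wrapped
body never runs and no timing line is printed for that step. A hypothetical
setup to observe this:

    from tempfile import mkdtemp
    from joblib import Memory
    from sklearn.preprocessing import StandardScaler
    from sklearn.linear_model import LogisticRegression
    from imblearn.pipeline import Pipeline

    pipe = Pipeline([('scaler', StandardScaler()),
                     ('clf', LogisticRegression(solver='lbfgs'))],
                    memory=Memory(mkdtemp(), verbose=0), verbose=True)
    X, y = [[0.], [1.], [2.], [3.]], [0, 0, 1, 1]
    pipe.fit(X, y)   # prints a timing line for both steps
    pipe.fit(X, y)   # scaler's line may be absent: its fit comes from the cache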
@@ -253,8 +269,10 @@ def fit(self, X, y=None, **fit_params):
 
         """
         Xt, yt, fit_params = self._fit(X, y, **fit_params)
-        if self._final_estimator != "passthrough":
-            self._final_estimator.fit(Xt, yt, **fit_params)
+        with _print_elapsed_time('Pipeline',
+                                 self._log_message(len(self.steps) - 1)):
+            if self._final_estimator != "passthrough":
+                self._final_estimator.fit(Xt, yt, **fit_params)
         return self
 
     def fit_transform(self, X, y=None, **fit_params):
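fit, fit_transform, fit_resample, and fit_predict all time the final step via
self._log_message(len(self.steps) - 1). _log_message is not added in this diff;
it comes from the sklearn.pipeline.Pipeline base class. Roughly (a simplified
sketch, not the exact sklearn source), it builds the "(step i of n) Processing
name" fragment, or returns None when verbose is off:

    def _log_message(self, step_idx):
        # None when not verbose, which keeps _print_elapsed_time silent.
        if not self.verbose:
            return None
        name, _ = self.steps[step_idx]
        return '(step %d of %d) Processing %s' % (step_idx + 1,
                                                  len(self.steps), name)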
@@ -287,12 +305,14 @@ def fit_transform(self, X, y=None, **fit_params):
         """
         last_step = self._final_estimator
         Xt, yt, fit_params = self._fit(X, y, **fit_params)
-        if last_step == "passthrough":
-            return Xt
-        elif hasattr(last_step, "fit_transform"):
-            return last_step.fit_transform(Xt, yt, **fit_params)
-        else:
-            return last_step.fit(Xt, yt, **fit_params).transform(Xt)
+        with _print_elapsed_time('Pipeline',
+                                 self._log_message(len(self.steps) - 1)):
+            if last_step == "passthrough":
+                return Xt
+            elif hasattr(last_step, "fit_transform"):
+                return last_step.fit_transform(Xt, yt, **fit_params)
+            else:
+                return last_step.fit(Xt, yt, **fit_params).transform(Xt)
 
     def fit_resample(self, X, y=None, **fit_params):
         """Fit the model and sample with the final estimator
@@ -327,10 +347,12 @@
         """
         last_step = self._final_estimator
         Xt, yt, fit_params = self._fit(X, y, **fit_params)
-        if last_step == "passthrough":
-            return Xt
-        elif hasattr(last_step, "fit_resample"):
-            return last_step.fit_resample(Xt, yt, **fit_params)
+        with _print_elapsed_time('Pipeline',
+                                 self._log_message(len(self.steps) - 1)):
+            if last_step == "passthrough":
+                return Xt
+            elif hasattr(last_step, "fit_resample"):
+                return last_step.fit_resample(Xt, yt, **fit_params)
 
     @if_delegate_has_method(delegate="_final_estimator")
     def predict(self, X, **predict_params):
@@ -392,7 +414,10 @@ def fit_predict(self, X, y=None, **fit_params):
         y_pred : array-like
         """
         Xt, yt, fit_params = self._fit(X, y, **fit_params)
-        return self.steps[-1][-1].fit_predict(Xt, yt, **fit_params)
+        with _print_elapsed_time('Pipeline',
+                                 self._log_message(len(self.steps) - 1)):
+            y_pred = self.steps[-1][-1].fit_predict(Xt, yt, **fit_params)
+        return y_pred
 
     @if_delegate_has_method(delegate="_final_estimator")
     def predict_proba(self, X):
@@ -583,22 +608,55 @@ def score(self, X, y=None, sample_weight=None):
             score_params["sample_weight"] = sample_weight
         return self.steps[-1][-1].score(Xt, y, **score_params)
 
+    @if_delegate_has_method(delegate='_final_estimator')
+    def score_samples(self, X):
+        """Apply transforms, and score_samples of the final estimator.
+        Parameters
+        ----------
+        X : iterable
+            Data to predict on. Must fulfill input requirements of first step
+            of the pipeline.
+        Returns
+        -------
+        y_score : ndarray, shape (n_samples,)
+        """
+        Xt = X
+        for _, _, transformer in self._iter(with_final=False):
+            if hasattr(transformer, "fit_resample"):
+                pass
+            else:
+                Xt = transformer.transform(Xt)
+        return self.steps[-1][-1].score_samples(Xt)
+
 
-def _fit_transform_one(transformer, weight, X, y, **fit_params):
-    if hasattr(transformer, "fit_transform"):
-        res = transformer.fit_transform(X, y, **fit_params)
-    else:
-        res = transformer.fit(X, y, **fit_params).transform(X)
+def _fit_transform_one(transformer,
+                       X,
+                       y,
+                       weight,
+                       message_clsname='',
+                       message=None,
+                       **fit_params):
+    with _print_elapsed_time(message_clsname, message):
+        if hasattr(transformer, "fit_transform"):
+            res = transformer.fit_transform(X, y, **fit_params)
+        else:
+            res = transformer.fit(X, y, **fit_params).transform(X)
     # if we have a weight for this transformer, multiply output
     if weight is None:
        return res, transformer
     return res * weight, transformer
 
 
-def _fit_resample_one(sampler, X, y, **fit_params):
-    X_res, y_res = sampler.fit_resample(X, y, **fit_params)
+def _fit_resample_one(sampler,
+                      X,
+                      y,
+                      message_clsname='',
+                      message=None,
+                      **fit_params):
+    with _print_elapsed_time(message_clsname, message):
+        X_res, y_res = sampler.fit_resample(X, y, **fit_params)
 
-    return X_res, y_res, sampler
+    return X_res, y_res, sampler
 
 
 def make_pipeline(*steps, **kwargs):
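A hypothetical use of the newly delegated score_samples, assuming a final
estimator that itself exposes score_samples (e.g. sklearn's IsolationForest);
as the loop above shows, samplers are skipped at scoring time and only
transformers are applied:

    import numpy as np
    from sklearn.ensemble import IsolationForest
    from sklearn.preprocessing import StandardScaler
    from imblearn.pipeline import Pipeline

    rng = np.random.RandomState(0)
    X = rng.randn(100, 2)
    pipe = Pipeline([('scaler', StandardScaler()),
                     ('iforest', IsolationForest(random_state=0))])
    pipe.fit(X)
    scores = pipe.score_samples(X)   # one anomaly score per sample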
@@ -622,6 +680,10 @@ def make_pipeline(*steps, **kwargs):
         inspect estimators within the pipeline. Caching the
         transformers is advantageous when fitting is time consuming.
 
+    verbose : boolean, optional (default=False)
+        If True, the time elapsed while fitting each step will be printed as it
+        is completed.
+
     Returns
     -------
     p : Pipeline
@@ -645,8 +707,11 @@ def make_pipeline(*steps, **kwargs):
              verbose=False)
     """
     memory = kwargs.pop("memory", None)
+    verbose = kwargs.pop('verbose', False)
     if kwargs:
         raise TypeError(
             'Unknown keyword arguments: "{}"'.format(list(kwargs.keys())[0])
         )
-    return Pipeline(pipeline._name_estimators(steps), memory=memory)
+    return Pipeline(
+        pipeline._name_estimators(steps), memory=memory, verbose=verbose
+    )
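Since make_pipeline consumes **kwargs by popping known names, the flag is
forwarded the same way as memory (a sketch; the sampler and estimator choices
are illustrative):

    from sklearn.linear_model import LogisticRegression
    from imblearn.pipeline import make_pipeline
    from imblearn.under_sampling import RandomUnderSampler

    pipe = make_pipeline(RandomUnderSampler(random_state=0),
                         LogisticRegression(solver='lbfgs'),
                         verbose=True)
    # Any other keyword still raises:
    # make_pipeline(LogisticRegression(), compose=True)  -> TypeError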

imblearn/tests/test_pipeline.py

Lines changed: 57 additions & 1 deletion
@@ -5,9 +5,11 @@
 # Christos Aridas
 # License: MIT
 
-from tempfile import mkdtemp
+import itertools
+import re
 import shutil
 import time
+from tempfile import mkdtemp
 
 import numpy as np
 import pytest
@@ -29,6 +31,7 @@
 from sklearn.feature_selection import SelectKBest, f_classif
 from sklearn.datasets import load_iris, make_classification
 from sklearn.preprocessing import StandardScaler
+from sklearn.pipeline import FeatureUnion
 
 from imblearn.pipeline import Pipeline, make_pipeline
 from imblearn.under_sampling import (
@@ -1268,3 +1271,56 @@ def test_pipeline_param_error():
     with pytest.raises(ValueError, match="Pipeline.fit does not accept "
                                          "the sample_weight parameter"):
         clf.fit([[0], [0]], [0, 1], sample_weight=[1, 1])
+
+
+parameter_grid_test_verbose = ((est, pattern, method) for
+                               (est, pattern), method in itertools.product(
+    [
+     (Pipeline([('transf', Transf()), ('clf', FitParamT())]),
+      r'\[Pipeline\].*\(step 1 of 2\) Processing transf.* total=.*\n'
+      r'\[Pipeline\].*\(step 2 of 2\) Processing clf.* total=.*\n$'),
+     (Pipeline([('transf', Transf()), ('noop', None),
+                ('clf', FitParamT())]),
+      r'\[Pipeline\].*\(step 1 of 3\) Processing transf.* total=.*\n'
+      r'\[Pipeline\].*\(step 2 of 3\) Processing noop.* total=.*\n'
+      r'\[Pipeline\].*\(step 3 of 3\) Processing clf.* total=.*\n$'),
+     (Pipeline([('transf', Transf()), ('noop', 'passthrough'),
+                ('clf', FitParamT())]),
+      r'\[Pipeline\].*\(step 1 of 3\) Processing transf.* total=.*\n'
+      r'\[Pipeline\].*\(step 2 of 3\) Processing noop.* total=.*\n'
+      r'\[Pipeline\].*\(step 3 of 3\) Processing clf.* total=.*\n$'),
+     (Pipeline([('transf', Transf()), ('clf', None)]),
+      r'\[Pipeline\].*\(step 1 of 2\) Processing transf.* total=.*\n'
+      r'\[Pipeline\].*\(step 2 of 2\) Processing clf.* total=.*\n$'),
+     (Pipeline([('transf', None), ('mult', Mult())]),
+      r'\[Pipeline\].*\(step 1 of 2\) Processing transf.* total=.*\n'
+      r'\[Pipeline\].*\(step 2 of 2\) Processing mult.* total=.*\n$'),
+     (Pipeline([('transf', 'passthrough'), ('mult', Mult())]),
+      r'\[Pipeline\].*\(step 1 of 2\) Processing transf.* total=.*\n'
+      r'\[Pipeline\].*\(step 2 of 2\) Processing mult.* total=.*\n$'),
+     (FeatureUnion([('mult1', Mult()), ('mult2', Mult())]),
+      r'\[FeatureUnion\].*\(step 1 of 2\) Processing mult1.* total=.*\n'
+      r'\[FeatureUnion\].*\(step 2 of 2\) Processing mult2.* total=.*\n$'),
+     (FeatureUnion([('mult1', 'drop'), ('mult2', Mult()), ('mult3', 'drop')]),
+      r'\[FeatureUnion\].*\(step 1 of 1\) Processing mult2.* total=.*\n$')
+    ], ['fit', 'fit_transform', 'fit_predict'])
+    if hasattr(est, method) and not (
+        method == 'fit_transform' and hasattr(est, 'steps') and
+        isinstance(est.steps[-1][1], FitParamT))
+)
+
+
+@pytest.mark.parametrize('est, pattern, method', parameter_grid_test_verbose)
+def test_verbose(est, method, pattern, capsys):
+    func = getattr(est, method)
+
+    X = [[1, 2, 3], [4, 5, 6]]
+    y = [[7], [8]]
+
+    est.set_params(verbose=False)
+    func(X, y)
+    assert not capsys.readouterr().out, 'Got output for verbose=False'
+
+    est.set_params(verbose=True)
+    func(X, y)
+    assert re.match(pattern, capsys.readouterr().out)
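To make the asserted output format concrete, a standalone check of one of the
patterns above (the sample output is hand-written for illustration, not
captured from a real run):

    import re

    sample = ("[Pipeline] ... (step 1 of 2) Processing transf, total=   0.0s\n"
              "[Pipeline] ...... (step 2 of 2) Processing clf, total=   0.0s\n")
    pattern = (r'\[Pipeline\].*\(step 1 of 2\) Processing transf.* total=.*\n'
               r'\[Pipeline\].*\(step 2 of 2\) Processing clf.* total=.*\n$')
    assert re.match(pattern, sample)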

setup.cfg

Lines changed: 0 additions & 1 deletion
@@ -28,7 +28,6 @@ addopts =
     --ignore examples
     --ignore maint_tools
     --doctest-modules
-    --disable-pytest-warnings
     -rs
 
 filterwarnings =
