Skip to content

Commit 95247e6

Browse files
committed
ENH make RandomUnderSampler accept dask array
1 parent edd7522 commit 95247e6

File tree

9 files changed

+252
-22
lines changed

9 files changed

+252
-22
lines changed

imblearn/dask/__init__.py

Whitespace-only changes.

imblearn/dask/_support.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
_REGISTERED_DASK_CONTAINER = []
2+
3+
try:
4+
from dask import array, dataframe
5+
_REGISTERED_DASK_CONTAINER += [
6+
array.Array, dataframe.Series, dataframe.DataFrame,
7+
]
8+
except ImportError:
9+
pass
10+
11+
12+
def is_dask_container(container):
    """Tell whether ``container`` is one of the registered Dask collections.

    Parameters
    ----------
    container : object
        The object to inspect.

    Returns
    -------
    bool
        ``True`` if ``container`` is an instance of a type listed in
        ``_REGISTERED_DASK_CONTAINER``, ``False`` otherwise (including when
        no type is registered).
    """
    dask_types = tuple(_REGISTERED_DASK_CONTAINER)
    return isinstance(container, dask_types)

imblearn/dask/tests/__init__.py

Whitespace-only changes.

imblearn/dask/tests/test_utils.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
import numpy as np
2+
import pytest
3+
from dask import array
4+
from dask_ml.datasets import make_classification
5+
6+
from imblearn.dask.utils import is_multilabel
7+
from imblearn.dask.utils import type_of_target
8+
9+
10+
def test_type_of_target_error():
    """A plain NumPy target must be rejected by the Dask ``type_of_target``."""
    numpy_target = np.arange(10)

    with pytest.raises(
        ValueError, match="Expected a Dask array, series or dataframe."
    ):
        type_of_target(numpy_target)
16+
17+
18+
@pytest.mark.parametrize(
    ("y", "expected_result"),
    [
        # 1d target
        (array.from_array(np.array([0, 1, 0, 1])), False),
        # 2d target with several columns and at most two distinct values
        (array.from_array(np.array([[1, 0], [0, 0]])), True),
        # single-column 2d target
        (array.from_array(np.array([[1], [0], [0]])), False),
        # single-row 2d target with several columns
        (array.from_array(np.array([[1, 0, 0]])), True),
    ],
)
def test_is_multilabel(y, expected_result):
    """Check ``is_multilabel`` on Dask arrays of various shapes."""
    result = is_multilabel(y)
    assert result is expected_result
29+
30+
31+
@pytest.mark.parametrize(
    ("y", "expected_type_of_target"),
    [
        # multilabel indicator matrices
        (array.from_array(np.array([[1, 0], [0, 0]])), "multilabel-indicator"),
        (array.from_array(np.array([[1, 0, 0]])), "multilabel-indicator"),
        # unsupported shapes
        (array.from_array(np.array([[[1, 2]]])), "unknown"),
        (array.from_array(np.array([[]])), "unknown"),
        # floating point targets holding non-integer values
        (array.from_array(np.array([.1, .2, 3])), "continuous"),
        (array.from_array(np.array([[.1, .2, 3]])), "continuous-multioutput"),
        (array.from_array(np.array([[1., .2]])), "continuous-multioutput"),
        # two distinct labels
        (array.from_array(np.array([1, 2])), "binary"),
        (array.from_array(np.array(["a", "b"])), "binary"),
    ],
)
def test_type_of_target(y, expected_type_of_target):
    """Check the target type inferred by ``type_of_target`` on Dask arrays."""
    assert type_of_target(y) == expected_type_of_target

imblearn/dask/utils.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
import warnings
2+
3+
from dask import dataframe
4+
from dask import array
5+
from sklearn.exceptions import DataConversionWarning
6+
from sklearn.utils.multiclass import _is_integral_float
7+
8+
9+
def is_multilabel(y):
    """Check whether ``y`` looks like a multilabel indicator matrix.

    Parameters
    ----------
    y : dask collection
        The target to inspect; it must expose ``ndim``, ``shape`` and
        ``dtype``.

    Returns
    -------
    bool
        ``True`` when ``y`` is 2d with more than one column, holds fewer than
        three distinct values, and those values are booleans, integers or
        integral floats.
    """
    has_several_columns = y.ndim == 2 and y.shape[1] > 1
    if not has_several_columns:
        return False

    # The distinct values are materialised here (one compute).
    labels = array.unique(y).compute()
    if len(labels) >= 3:
        return False

    return y.dtype.kind in 'biu' or _is_integral_float(labels)
18+
19+
20+
def type_of_target(y):
    """Determine the type of data indicated by a Dask target.

    Parameters
    ----------
    y : dask array, series or dataframe
        The target to inspect.

    Returns
    -------
    target_type : str
        One of ``'multilabel-indicator'``, ``'unknown'``, ``'continuous'``,
        ``'continuous-multioutput'``, ``'multiclass'``,
        ``'multiclass-multioutput'`` or ``'binary'``.

    Raises
    ------
    ValueError
        If ``y`` is not a Dask array, series or dataframe.
    """
    # Reject anything that is not a Dask collection up front instead of
    # failing later (or silently succeeding) on a non-dask input.
    if not isinstance(
        y, (array.Array, dataframe.Series, dataframe.DataFrame)
    ):
        raise ValueError("Expected a Dask array, series or dataframe.")

    if is_multilabel(y):
        return 'multilabel-indicator'

    if y.ndim > 2:
        return 'unknown'

    if y.ndim == 2 and y.shape[1] == 0:
        return 'unknown'  # [[]]

    if y.ndim == 2 and y.shape[1] > 1:
        # [[1, 2], [1, 2]]
        suffix = "-multioutput"
    else:
        # [1, 2, 3] or [[1], [2], [3]]
        suffix = ""

    # check float and contains non-integer float values
    # NOTE(review): the truth test on the lazy result presumably forces a
    # compute here -- confirm against the dask version in use.
    if y.dtype.kind == 'f' and array.any(y != y.astype(int)):
        # [.1, .2, 3] or [[.1, .2, 3]] or [[1., .2]] and not [1., 2., 3.]
        # NOTE: we don't check for infinite values
        return 'continuous' + suffix

    labels = array.unique(y).compute()
    if (len(labels) > 2) or (y.ndim >= 2 and len(y[0]) > 1):
        # [1, 2, 3] or [[1., 2., 3]] or [[1, 2]]
        return 'multiclass' + suffix
    # [1, 2] or [["a"], ["b"]]
    return 'binary'
49+
50+
51+
def column_or_1d(y, *, warn=False):
    """Ravel a 1d target or a single-column 2d target.

    Parameters
    ----------
    y : array-like exposing ``shape`` and ``ravel``
        The target to flatten.
    warn : bool, default=False
        Whether to emit a ``DataConversionWarning`` when a column vector is
        flattened.

    Returns
    -------
    The result of ``y.ravel()``.

    Raises
    ------
    ValueError
        If ``y`` is neither 1d nor a 2d column vector.
    """
    shape = y.shape
    n_dims = len(shape)

    if n_dims == 1:
        return y.ravel()

    is_column_vector = n_dims == 2 and shape[1] == 1
    if not is_column_vector:
        raise ValueError(
            f"y should be a 1d array. Got an array of shape {shape} instead."
        )

    if warn:
        warnings.warn(
            "A column-vector y was passed when a 1d array was expected. "
            "Please change the shape of y to (n_samples, ), for example "
            "using ravel().", DataConversionWarning, stacklevel=2
        )
    return y.ravel()

imblearn/under_sampling/_prototype_selection/_random_under_sampler.py

Lines changed: 40 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from sklearn.utils import _safe_indexing
1111

1212
from ..base import BaseUnderSampler
13+
from ...dask._support import is_dask_container
1314
from ...utils import check_target_type
1415
from ...utils import Substitution
1516
from ...utils._docstring import _random_state_docstring
@@ -80,44 +81,66 @@ def __init__(
8081
self.replacement = replacement
8182

8283
def _check_X_y(self, X, y):
83-
y, binarize_y = check_target_type(y, indicate_one_vs_all=True)
84-
X, y = self._validate_data(
85-
X, y, reset=True, accept_sparse=["csr", "csc"], dtype=None,
86-
force_all_finite=False,
84+
y, binarize_y, self._uniques = check_target_type(
85+
y,
86+
indicate_one_vs_all=True,
87+
return_unique=True,
8788
)
89+
if not any([is_dask_container(arr) for arr in (X, y)]):
90+
X, y = self._validate_data(
91+
X,
92+
y,
93+
reset=True,
94+
accept_sparse=["csr", "csc"],
95+
dtype=None,
96+
force_all_finite=False,
97+
)
8898
return X, y, binarize_y
8999

100+
@staticmethod
def _find_target_class_indices(y, target_class):
    """Return the positions in ``y`` whose value equals ``target_class``.

    When ``y`` is a Dask container the indices are computed before being
    returned; otherwise the result of ``np.flatnonzero`` is returned as is.
    """
    indices = np.flatnonzero(y == target_class)
    return indices.compute() if is_dask_container(y) else indices
106+
90107
def _fit_resample(self, X, y):
91108
random_state = check_random_state(self.random_state)
92109

93-
idx_under = np.empty((0,), dtype=int)
110+
idx_under = []
94111

95-
for target_class in np.unique(y):
112+
for target_class in self._uniques:
113+
target_class_indices = self._find_target_class_indices(
114+
y, target_class
115+
)
96116
if target_class in self.sampling_strategy_.keys():
97117
n_samples = self.sampling_strategy_[target_class]
98118
index_target_class = random_state.choice(
99-
range(np.count_nonzero(y == target_class)),
119+
target_class_indices.size,
100120
size=n_samples,
101121
replace=self.replacement,
102122
)
103123
else:
104124
index_target_class = slice(None)
105125

106-
idx_under = np.concatenate(
107-
(
108-
idx_under,
109-
np.flatnonzero(y == target_class)[index_target_class],
110-
),
111-
axis=0,
112-
)
126+
selected_indices = target_class_indices[index_target_class]
127+
idx_under.append(selected_indices)
113128

114-
self.sample_indices_ = idx_under
129+
self.sample_indices_ = np.hstack(idx_under)
130+
self.sample_indices_.sort()
115131

116-
return _safe_indexing(X, idx_under), _safe_indexing(y, idx_under)
132+
return (
133+
_safe_indexing(X, self.sample_indices_),
134+
_safe_indexing(y, self.sample_indices_)
135+
)
117136

118137
def _more_tags(self):
119138
return {
120-
"X_types": ["2darray", "string"],
139+
"X_types": [
140+
"2darray",
141+
"string",
142+
"dask-array",
143+
],
121144
"sample_indices": True,
122145
"allow_nan": True,
123146
}

imblearn/utils/_validation.py

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,13 @@
1414
from sklearn.base import clone
1515
from sklearn.neighbors._base import KNeighborsMixin
1616
from sklearn.neighbors import NearestNeighbors
17-
from sklearn.utils import column_or_1d
18-
from sklearn.utils.multiclass import type_of_target
1917

18+
from ..dask._support import is_dask_container
2019
from ..exceptions import raise_isinstance_error
20+
from .wrapper import _is_multiclass_encoded
21+
from .wrapper import column_or_1d
22+
from .wrapper import type_of_target
23+
from .wrapper import unique
2124

2225
SAMPLING_KIND = (
2326
"over-sampling",
@@ -99,10 +102,12 @@ def check_neighbors_object(nn_name, nn_object, additional_neighbor=0):
99102

100103
def _count_class_sample(y):
    """Map each distinct value of ``y`` to its number of occurrences.

    If ``np.unique`` hands back Dask containers, both the classes and their
    counts are computed before the dictionary is built.
    """
    classes, n_samples = np.unique(y, return_counts=True)
    if is_dask_container(classes):
        classes = classes.compute()
        n_samples = n_samples.compute()
    return dict(zip(classes, n_samples))
103108

104109

105-
def check_target_type(y, indicate_one_vs_all=False):
110+
def check_target_type(y, indicate_one_vs_all=False, return_unique=False):
106111
"""Check the target types to be conform to the current samplers.
107112
108113
The current samplers should be compatible with ``'binary'``,
@@ -116,18 +121,24 @@ def check_target_type(y, indicate_one_vs_all=False):
116121
indicate_one_vs_all : bool, default=False
117122
Either to indicate if the targets are encoded in a one-vs-all fashion.
118123
124+
return_unique : bool, default=False
125+
Either to return or not the unique values in y.
126+
119127
Returns
120128
-------
121129
y : ndarray
122130
The returned target.
123131
132+
y_unique : ndarray
133+
The unique values in `y`.
134+
124135
is_one_vs_all : bool, optional
125136
Indicate if the target was originally encoded in a one-vs-all fashion.
126137
Only returned if ``indicate_one_vs_all=True``.
127138
"""
128139
type_y = type_of_target(y)
129140
if type_y == "multilabel-indicator":
130-
if np.any(y.sum(axis=1) > 1):
141+
if not _is_multiclass_encoded(y):
131142
raise ValueError(
132143
"Imbalanced-learn currently supports binary, multiclass and "
133144
"binarized encoded multiclasss targets. Multilabel and "
@@ -137,7 +148,13 @@ def check_target_type(y, indicate_one_vs_all=False):
137148
else:
138149
y = column_or_1d(y)
139150

140-
return (y, type_y == "multilabel-indicator") if indicate_one_vs_all else y
151+
output = [y]
152+
if indicate_one_vs_all:
153+
output += [type_y == "multilabel-indicator"]
154+
if return_unique:
155+
output += [unique(y)]
156+
157+
return output
141158

142159

143160
def _sampling_strategy_all(y, sampling_type):

imblearn/utils/estimator_checks.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,13 +51,16 @@ def _set_checking_parameters(estimator):
5151

5252

5353
def _yield_sampler_checks(sampler):
54+
tags = sampler._get_tags()
5455
yield check_target_type
5556
yield check_samplers_one_label
5657
yield check_samplers_fit
5758
yield check_samplers_fit_resample
5859
yield check_samplers_sampling_strategy_fit_resample
5960
yield check_samplers_sparse
6061
yield check_samplers_pandas
62+
if "dask-array" in tags["X_types"]:
63+
yield check_samplers_dask_array
6164
yield check_samplers_list
6265
yield check_samplers_multiclass_ova
6366
yield check_samplers_preserve_dtype
@@ -290,6 +293,30 @@ def check_samplers_pandas(name, sampler):
290293
assert_allclose(y_res_s.to_numpy(), y_res)
291294

292295

296+
def check_samplers_dask_array(name, sampler):
    """Check that a sampler handles Dask arrays like NumPy arrays.

    The sampler is fitted once on Dask arrays and once on the NumPy arrays
    they were built from; the resampled outputs must be Dask arrays and hold
    the same values as the NumPy results. The check is skipped when
    ``dask.array`` cannot be imported.
    """
    # Import the submodule explicitly: importing only the top-level ``dask``
    # package does not guarantee that ``dask.array`` is available as an
    # attribute.
    dask_array = pytest.importorskip("dask.array")
    X, y = make_classification(
        n_samples=1000,
        n_classes=3,
        n_informative=4,
        weights=[0.2, 0.3, 0.5],
        random_state=0,
    )
    X_dask = dask_array.from_array(X, chunks=100)
    y_dask = dask_array.from_array(y, chunks=100)

    X_res_dask, y_res_dask = sampler.fit_resample(X_dask, y_dask)
    X_res, y_res = sampler.fit_resample(X, y)

    # check that dask arrays in give dask arrays out
    assert isinstance(X_res_dask, dask_array.Array)
    assert isinstance(y_res_dask, dask_array.Array)

    # and that the resampled values match the NumPy-based run
    assert_allclose(X_res_dask, X_res)
    assert_allclose(y_res_dask, y_res)
319+
293320
def check_samplers_list(name, sampler):
294321
# Check that the samplers can handle simple lists
295322
X, y = make_classification(

imblearn/utils/wrapper.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
import numpy as np
2+
3+
from sklearn.utils.multiclass import type_of_target as sklearn_type_of_target
4+
from sklearn.utils.validation import column_or_1d as sklearn_column_or_1d
5+
6+
from ..dask._support import is_dask_container
7+
8+
9+
def type_of_target(y):
    """Dispatch to the Dask or the scikit-learn ``type_of_target``.

    Dask containers are routed to the implementation in ``..dask.utils``;
    every other input goes to scikit-learn's ``type_of_target``.
    """
    if not is_dask_container(y):
        return sklearn_type_of_target(y)

    # Imported here rather than at module level, so the dask-specific module
    # is only loaded when a Dask container is actually passed in.
    from ..dask.utils import type_of_target as dask_type_of_target

    return dask_type_of_target(y)
15+
16+
17+
def _is_multiclass_encoded(y):
    """Check that every row of ``y`` sums to exactly one.

    For a Dask container the reduction is done with ``dask.array.all`` and
    computed; otherwise ``np.all`` is used.
    """
    rows_sum_to_one = y.sum(axis=1) == 1
    if not is_dask_container(y):
        return np.all(rows_sum_to_one)

    from dask import array

    return array.all(rows_sum_to_one).compute()
23+
24+
25+
def column_or_1d(y, *, warn=False):
    """Dispatch to the Dask or the scikit-learn ``column_or_1d``.

    Parameters
    ----------
    y : array-like or dask container
        The target to flatten.
    warn : bool, default=False
        Forwarded unchanged to the selected implementation.
    """
    if not is_dask_container(y):
        return sklearn_column_or_1d(y, warn=warn)

    # Imported here rather than at module level, so the dask-specific module
    # is only loaded when a Dask container is actually passed in.
    from ..dask.utils import column_or_1d as dask_column_or_1d

    return dask_column_or_1d(y, warn=warn)
31+
32+
33+
def unique(*args, **kwargs):
    """Call ``np.unique`` and materialise any Dask result.

    Parameters
    ----------
    *args, **kwargs
        Forwarded unchanged to ``np.unique``.

    Returns
    -------
    The output of ``np.unique``. Any Dask container in it is computed first.
    When ``np.unique`` returns a tuple (e.g. with ``return_counts=True``), a
    tuple of the same length is returned.
    """
    # Unpack the arguments: passing ``args``/``kwargs`` as two positional
    # values would make ``np.unique`` treat the tuple itself as the data.
    output = np.unique(*args, **kwargs)

    if isinstance(output, tuple):
        # Several arrays were requested; compute each lazy one individually.
        return tuple(
            arr.compute() if is_dask_container(arr) else arr
            for arr in output
        )
    if is_dask_container(output):
        return output.compute()
    return output

0 commit comments

Comments
 (0)