|
9 | 9 | import numpy as np |
10 | 10 |
|
11 | 11 | from sklearn.datasets import load_iris |
| 12 | +from sklearn.datasets import fetch_openml |
12 | 13 |
|
13 | 14 | from imblearn.datasets import make_imbalance |
14 | 15 |
|
@@ -52,3 +53,22 @@ def test_make_imbalance_dict(iris, sampling_strategy, expected_counts): |
52 | 53 | X, y = iris |
53 | 54 | _, y_ = make_imbalance(X, y, sampling_strategy=sampling_strategy) |
54 | 55 | assert Counter(y_) == expected_counts |
| 56 | + |
| 57 | + |
@pytest.mark.parametrize("as_frame", [True, False], ids=["dataframe", "array"])
@pytest.mark.parametrize(
    "sampling_strategy, expected_counts",
    [
        (
            {"Iris-setosa": 10, "Iris-versicolor": 20, "Iris-virginica": 30},
            {"Iris-setosa": 10, "Iris-versicolor": 20, "Iris-virginica": 30},
        ),
        (
            {"Iris-setosa": 10, "Iris-versicolor": 20},
            {"Iris-setosa": 10, "Iris-versicolor": 20, "Iris-virginica": 50},
        ),
    ],
)
def test_make_imbalanced_iris(as_frame, sampling_strategy, expected_counts):
    """Check ``make_imbalance`` on the OpenML iris data with string labels.

    The test runs once with ``as_frame=True`` and once with
    ``as_frame=False``. In both modes the per-class counts of the returned
    target must equal ``expected_counts``; the second parametrized case
    expects the class left out of ``sampling_strategy`` to stay at 50.
    """
    # The whole test is skipped when pandas is not importable, including
    # the ``as_frame=False`` variant.
    pytest.importorskip("pandas")

    data, target = fetch_openml(
        "iris", version=1, return_X_y=True, as_frame=as_frame
    )
    X_resampled, y_resampled = make_imbalance(
        data, target, sampling_strategy=sampling_strategy
    )

    # When a frame was requested, the resampled features must still expose
    # ``.loc``; for array input this check is vacuously true.
    assert not as_frame or hasattr(X_resampled, "loc")
    assert Counter(y_resampled) == expected_counts
0 commit comments