Skip to content

Commit 91885a3

Browse files
FIX: Fix error when calling forest classifiers on single-class data (#2723)
* fix random forest on single-class data * fix tests * fix wrong method * missing import * rename test * remove unreachable branch * further simplifications * use 'get_namespace' * missing line * remove unused import * use more specialized function when possible * Revert "use more specialized function when possible" This reverts commit 81cd037.
1 parent 4e56401 commit 91885a3

File tree

5 files changed

+61
-10
lines changed

5 files changed

+61
-10
lines changed

doc/sources/algorithms.rst

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -51,14 +51,14 @@ Classification
5151
- ``warm_start`` = `True`
5252
- ``ccp_alpha`` != `0`
5353
- ``criterion`` != `'gini'`
54-
- Multi-output and sparse data are not supported
54+
- Multi-output and sparse data are not supported. Number of classes must be at least 2.
5555
* - :obj:`sklearn.ensemble.ExtraTreesClassifier`
5656
- All parameters are supported except:
5757

5858
- ``warm_start`` = `True`
5959
- ``ccp_alpha`` != `0`
6060
- ``criterion`` != `'gini'`
61-
- Multi-output and sparse data are not supported
61+
- Multi-output and sparse data are not supported. Number of classes must be at least 2.
6262
* - :obj:`sklearn.neighbors.KNeighborsClassifier`
6363
-
6464
- For ``algorithm`` == `'kd_tree'`:
@@ -293,7 +293,7 @@ Classification
293293
- ``criterion`` != `'gini'`
294294
- ``oob_score`` = `True`
295295
- ``sample_weight`` != `None`
296-
- Multi-output and sparse data are not supported
296+
- Multi-output and sparse data are not supported. Number of classes must be at least 2.
297297
* - :obj:`sklearn.ensemble.ExtraTreesClassifier`
298298
- All parameters are supported except:
299299

@@ -302,7 +302,7 @@ Classification
302302
- ``criterion`` != `'gini'`
303303
- ``oob_score`` = `True`
304304
- ``sample_weight`` != `None`
305-
- Multi-output and sparse data are not supported
305+
- Multi-output and sparse data are not supported. Number of classes must be at least 2.
306306
* - :obj:`sklearn.neighbors.KNeighborsClassifier`
307307
- All parameters are supported except:
308308

@@ -488,7 +488,7 @@ Classification
488488
- ``criterion`` != `'gini'`
489489
- ``oob_score`` = `True`
490490
- ``sample_weight`` != `None`
491-
- Multi-output and sparse data are not supported
491+
- Multi-output and sparse data are not supported. Number of classes must be at least 2.
492492
* - :obj:`sklearn.ensemble.ExtraTreesClassifier`
493493
- All parameters are supported except:
494494

@@ -497,7 +497,7 @@ Classification
497497
- ``criterion`` != `'gini'`
498498
- ``oob_score`` = `True`
499499
- ``sample_weight`` != `None`
500-
- Multi-output and sparse data are not supported
500+
- Multi-output and sparse data are not supported. Number of classes must be at least 2.
501501
* - :obj:`sklearn.neighbors.KNeighborsClassifier`
502502
- All parameters are supported except:
503503

onedal/tests/utils/_dataframes_support.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ def _as_numpy(obj, *args, **kwargs):
134134
if dpctl_available and isinstance(obj, dpt.usm_ndarray):
135135
return dpt.to_numpy(obj, *args, **kwargs)
136136
if isinstance(obj, pd.DataFrame) or isinstance(obj, pd.Series):
137-
return obj.to_array(*args, **kwargs)
137+
return obj.to_numpy(*args, **kwargs)
138138
if sp.issparse(obj):
139139
return obj.toarray(*args, **kwargs)
140140
return np.asarray(obj, *args, **kwargs)

sklearnex/ensemble/_forest.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -592,6 +592,19 @@ def _onedal_fit_ready(self, patching_status, X, y, sample_weight):
592592
)
593593
# TODO: Fix to support integers as input
594594

595+
if self.n_outputs_ == 1:
596+
xp, is_array_api_compliant = get_namespace(y)
597+
sety = xp.unique_values(y) if is_array_api_compliant else np.unique(y)
598+
num_classes = sety.shape[0]
599+
patching_status.and_conditions(
600+
[
601+
(
602+
num_classes >= 2,
603+
"Number of classes must be at least 2.",
604+
),
605+
]
606+
)
607+
595608
_get_n_samples_bootstrap(n_samples=X.shape[0], max_samples=self.max_samples)
596609

597610
if not self.bootstrap and self.max_samples is not None:

sklearnex/ensemble/tests/test_forest.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
# limitations under the License.
1515
# ===============================================================================
1616

17+
import numpy as np
18+
import pandas as pd
1719
import pytest
1820
from numpy.testing import assert_allclose
1921
from sklearn.datasets import make_classification, make_regression
@@ -153,3 +155,42 @@ def test_sklearnex_import_et_regression(dataframe, queue):
153155
# Check that the trees aren't just empty nodes predicting the mean
154156
for estimator in rf.estimators_:
155157
assert estimator.tree_.children_left.shape[0] > 1
158+
159+
160+
@pytest.mark.allow_sklearn_fallback
161+
@pytest.mark.parametrize("dataframe, queue", get_dataframes_and_queues())
162+
def test_classifiers_work_on_single_class(dataframe, queue):
163+
from sklearnex.ensemble import ExtraTreesClassifier, RandomForestClassifier
164+
165+
rng = np.random.default_rng(seed=123)
166+
X = rng.standard_normal(size=(20, 10))
167+
y = np.zeros(X.shape[0])
168+
X = _convert_to_dataframe(X, sycl_queue=queue, target_df=dataframe)
169+
y = _convert_to_dataframe(y, sycl_queue=queue, target_df=dataframe)
170+
171+
np.testing.assert_array_equal(
172+
_as_numpy(RandomForestClassifier(n_estimators=1).fit(X, y).predict(X)),
173+
_as_numpy(y),
174+
)
175+
np.testing.assert_array_equal(
176+
_as_numpy(ExtraTreesClassifier(n_estimators=1).fit(X, y).predict(X)),
177+
_as_numpy(y),
178+
)
179+
180+
181+
@pytest.mark.allow_sklearn_fallback
182+
def test_classifiers_work_on_single_class_non_numeric():
183+
from sklearnex.ensemble import ExtraTreesClassifier, RandomForestClassifier
184+
185+
rng = np.random.default_rng(seed=123)
186+
X = rng.standard_normal(size=(20, 10))
187+
y = pd.Series(np.repeat("qwerty", X.shape[0]))
188+
189+
np.testing.assert_array_equal(
190+
RandomForestClassifier(n_estimators=1).fit(X, y).predict(X),
191+
y,
192+
)
193+
np.testing.assert_array_equal(
194+
ExtraTreesClassifier(n_estimators=1).fit(X, y).predict(X),
195+
y,
196+
)

sklearnex/utils/class_weight.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,9 +44,6 @@ def _compute_class_weight(class_weight, *, classes, y, sample_weight=None):
4444
return compute_class_weight(class_weight, classes, y, sample_weight=sample_weight)
4545

4646
sety = xp.unique_values(y)
47-
setclasses = xp.unique_values(classes)
48-
if sety.shape[0] != xp.unique_values(xp.concat((sety, setclasses))).shape[0]:
49-
raise ValueError("classes should include all valid labels that can be in y")
5047
if class_weight is None or len(class_weight) == 0:
5148
# uniform class weights
5249
weight = xp.ones((classes.shape[0],), dtype=xp.float64, device=classes.device)

0 commit comments

Comments
 (0)