diff --git a/boruta/boruta_py.py b/boruta/boruta_py.py index 4b463cc..d2a2d0a 100644 --- a/boruta/boruta_py.py +++ b/boruta/boruta_py.py @@ -12,11 +12,13 @@ import numpy as np import scipy as sp from sklearn.utils import check_random_state, check_X_y -from sklearn.base import TransformerMixin, BaseEstimator +from sklearn.base import BaseEstimator +from sklearn.feature_selection import SelectorMixin +from sklearn.utils.validation import check_is_fitted import warnings -class BorutaPy(BaseEstimator, TransformerMixin): +class BorutaPy(BaseEstimator, SelectorMixin): """ Improved Python implementation of the Boruta R package. @@ -287,11 +289,19 @@ def _fit(self, X, y): # check input params self._check_params(X, y) + feature_names = getattr(X, "columns", None) + if feature_names is not None: + self.feature_names_in_ = np.asarray(feature_names, dtype=object) + else: + self.feature_names_in_ = None + if not isinstance(X, np.ndarray): X = self._validate_pandas_input(X) if not isinstance(y, np.ndarray): y = self._validate_pandas_input(y) + self.n_features_in_ = X.shape[1] + self.random_state = check_random_state(self.random_state) early_stopping = False @@ -465,6 +475,10 @@ def _set_n_estimators(self, n_estimators): ) return self + def _get_support_mask(self): + check_is_fitted(self, 'support_') + return self.support_ + def _get_tree_num(self, n_feat): depth = None try: diff --git a/boruta/test/test_boruta.py b/boruta/test/test_boruta.py index 3ad05ec..64a3691 100644 --- a/boruta/test/test_boruta.py +++ b/boruta/test/test_boruta.py @@ -2,6 +2,7 @@ import pandas as pd import pytest from sklearn.ensemble import RandomForestClassifier +from sklearn.exceptions import NotFittedError from sklearn.tree import DecisionTreeClassifier, ExtraTreeClassifier from boruta import BorutaPy @@ -68,6 +69,62 @@ def test_dataframe_is_returned(Xy): assert isinstance(bt.transform(X_df, return_df=True), pd.DataFrame) +def test_selector_mixin_get_support_requires_fit(): + bt = BorutaPy(RandomForestClassifier()) + with pytest.raises(NotFittedError): + bt.get_support() + + +def test_selector_mixin_get_support_matches_mask(Xy): + X, y = Xy + bt = BorutaPy(RandomForestClassifier()) + bt.fit(X, y) + + assert np.array_equal(bt.get_support(), bt.support_) + assert np.array_equal(bt.get_support(indices=True), + np.where(bt.support_)[0]) + + +def test_selector_mixin_inverse_transform_restores_selected_features(Xy): + X, y = Xy + bt = BorutaPy(RandomForestClassifier()) + bt.fit(X, y) + + X_selected = bt.transform(X) + X_reconstructed = bt.inverse_transform(X_selected) + + assert X_reconstructed.shape == X.shape + assert np.allclose(X_reconstructed[:, bt.support_], X[:, bt.support_]) + + if (~bt.support_).any(): + assert np.allclose(X_reconstructed[:, ~bt.support_], 0) + + +def test_selector_mixin_get_feature_names_out_requires_fit(): + bt = BorutaPy(RandomForestClassifier()) + with pytest.raises(NotFittedError): + bt.get_feature_names_out() + + +def test_selector_mixin_get_feature_names_out_returns_selected_names(Xy): + X, y = Xy + bt = BorutaPy(RandomForestClassifier()) + bt.fit(X, y) + + expected_default = np.array([f"x{i}" for i in np.where(bt.support_)[0]]) + assert np.array_equal(bt.get_feature_names_out(), expected_default) + + custom_names = np.array([f"feature_{i}" for i in range(X.shape[1])]) + selected_names = bt.get_feature_names_out(custom_names) + assert np.array_equal(selected_names, custom_names[bt.support_]) + + columns = [f"col_{i}" for i in range(X.shape[1])] + X_df = pd.DataFrame(X, columns=columns) + bt_df = BorutaPy(RandomForestClassifier()) + bt_df.fit(X_df, y) + assert np.array_equal(bt_df.get_feature_names_out(), np.array(columns)[bt_df.support_]) + + @pytest.mark.parametrize("tree", [ExtraTreeClassifier(), DecisionTreeClassifier()]) def test_boruta_with_decision_trees(tree, Xy): msg = ( @@ -80,4 +137,4 @@ def test_boruta_with_decision_trees(tree, Xy): with pytest.raises(ValueError) as record: bt.fit(X, y) - assert str(record.value) == msg \ No newline at end of file + assert str(record.value) == msg