Skip to content

Commit 7d0e1f5

Browse files
authored
feat(reindex_missing_columns): add ReindexMissingColumns and ReportNonFinite (#27)
* feat(reindex_missing_columns): add `ReindexMissingColumns` * feat: add `ReportNonFinite` * feat: export modules * docs(readme): update README.md * fix: import annotations from __future__
1 parent a34d784 commit 7d0e1f5

File tree

9 files changed

+756
-463
lines changed

9 files changed

+756
-463
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,3 +140,4 @@ dmypy.json
140140
cython_debug/
141141
.VSCodeCounter/
142142
catboost_info/
143+
sklearn_utilities_info/

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,8 @@ See [Docs](https://sklearn-utilities.readthedocs.io/en/latest/sklearn_utilities.
5454
- `DropMissingColumns`: drops columns with missing values above a threshold.
5555
- `DropMissingRowsY`: drops rows with missing values in y. Use `feature_engine.DropMissingData` for X.
5656
- `IntersectXY`: drops rows where the index of X and y do not intersect. Use with `feature_engine.DropMissingData`.
57+
- `ReindexMissingColumns`: reindexes columns of X in `transform()` to match the columns of X in `fit()`.
58+
- `ReportNonFinite`: reports non-finite values in X and/or y.
5759
- `IdTransformer`: a transformer that does nothing.
5860
- `RecursiveFitSubtractRegressor`: a regressor that recursively fits a regressor and subtracts the prediction from the target.
5961
- `SmartMultioutputEstimator`: a `MultiOutputEstimator` that supports tuple of arrays in `predict()` and supports pandas `Series` and `DataFrame`.

poetry.lock

Lines changed: 419 additions & 463 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,10 @@ torch = "==2.0.0||>2.1.0"
5252
[tool.poetry.group.catboost.dependencies]
5353
catboost = "^1.2.2"
5454

55+
56+
[tool.poetry.group.seaborn.dependencies]
57+
seaborn = "^0.13.0"
58+
5559
[tool.semantic_release]
5660
branch = "main"
5761
version_toml = ["pyproject.toml:tool.poetry.version"]

src/sklearn_utilities/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
TransformedTargetEstimatorVar,
2525
)
2626
from .recursive_fit_subtract_regressor import RecursiveFitSubtractRegressor
27+
from .reindex_missing_columns import ReindexMissingColumns
28+
from .report_non_finite import ReportNonFinite
2729

2830
__all__ = [
2931
"DataFrameWrapper",
@@ -49,4 +51,6 @@
4951
"PipelineVar",
5052
"StandardScalerVar",
5153
"EvalSetWrapper",
54+
"ReportNonFinite",
55+
"ReindexMissingColumns",
5256
]
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
from __future__ import annotations
2+
3+
import warnings
4+
from typing import Any, Callable, Literal
5+
6+
from pandas import DataFrame, Index
7+
from sklearn.base import BaseEstimator, TransformerMixin
8+
from typing_extensions import Self
9+
10+
from .types import TXPandas
11+
12+
13+
class ReindexMissingColumns(BaseEstimator, TransformerMixin):
    """Reindex X to match the columns of the training data to avoid errors.

    ``fit()`` stores the columns of ``X`` in ``feature_names_in_``;
    ``transform()`` reindexes its input to exactly those columns, so columns
    absent from the input are created (filled according to
    ``reindex_kwargs``) and columns unknown to ``fit()`` are dropped.
    """

    def __init__(
        self,
        *,
        if_missing: Literal["warn", "raise"]
        | Callable[[Index[Any], Index[Any]], None] = "warn",
        reindex_kwargs: dict[
            Literal["method", "copy", "level", "fill_value", "limit", "tolerance"], Any
        ]
        | None = None,
    ) -> None:
        """Reindex X to match the columns of the training data to avoid errors.

        Parameters
        ----------
        if_missing : Literal['warn', 'raise'] | Callable[[Index[Any], Index[Any]], None], optional
            What to do when the columns seen in ``transform()`` differ from
            those seen in ``fit()``, by default 'warn'.
            'warn' / 'raise' emit a warning / raise ``ValueError`` only when
            at least one fitted column is actually missing. If callable, it
            is called on any column mismatch; the first argument is the
            expected columns and the second argument is the actual columns.
        reindex_kwargs : dict[Literal['method', 'copy', 'level', 'fill_value',
            'limit', 'tolerance'], Any] | None, optional
            Keyword arguments to pass to reindex, by default None
            (no extra keyword arguments; equivalent to ``{}``).
        """
        # Parameters are stored untouched; None (not a shared ``{}``) is the
        # default so that no mutable object is shared between instances.
        self.if_missing = if_missing
        self.reindex_kwargs = reindex_kwargs

    def fit(self, X: DataFrame, y: Any = None, **fit_params: Any) -> Self:
        """Remember the columns of ``X`` as the expected columns.

        Parameters
        ----------
        X : DataFrame
            Training data whose columns define the expected layout.
        y : Any, optional
            Ignored.

        Returns
        -------
        Self
            The fitted transformer.
        """
        self.feature_names_in_ = X.columns
        return self

    def transform(self, X: TXPandas, y: Any = None, **fit_params: Any) -> TXPandas:
        """Reindex ``X`` to the columns stored by ``fit()``.

        Parameters
        ----------
        X : TXPandas
            Input data.
        y : Any, optional
            Ignored.

        Returns
        -------
        TXPandas
            ``X`` reindexed to ``feature_names_in_``.

        Raises
        ------
        ValueError
            If ``if_missing`` is 'raise' and fitted columns are missing, or
            if ``if_missing`` is neither 'warn', 'raise' nor a callable and
            the columns differ from the fitted ones.
        """
        expected_columns = self.feature_names_in_
        actual_columns = X.columns

        if not expected_columns.equals(actual_columns):
            if callable(self.if_missing):
                self.if_missing(expected_columns, actual_columns)
            elif self.if_missing not in ("warn", "raise"):
                raise ValueError(f"Invalid value for if_missing: {self.if_missing}")
            else:
                missing_columns = expected_columns.difference(actual_columns)
                # The columns may differ only by order or by extra columns;
                # reindexing handles both, so only report when something is
                # really missing instead of emitting "Missing columns: []".
                if len(missing_columns) > 0:
                    message = f"Missing columns: {missing_columns}"
                    if self.if_missing == "raise":
                        raise ValueError(message)
                    warnings.warn(message)

        return X.reindex(columns=expected_columns, **(self.reindex_kwargs or {}))
Lines changed: 212 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,212 @@
1+
from __future__ import annotations
2+
3+
from logging import getLogger
4+
from pathlib import Path
5+
from typing import Any, Callable
6+
7+
import matplotlib.pyplot as plt
8+
import numpy as np
9+
from matplotlib.figure import Figure
10+
from pandas import DataFrame, Series, Timestamp
11+
from sklearn.base import BaseEstimator, TransformerMixin
12+
from typing_extensions import Self
13+
14+
from .types import TXPandas
15+
16+
LOG = getLogger(__name__)
17+
18+
19+
def _save_figure_to_default_dir(fig: Figure) -> None:
    """Save ``fig`` as a timestamped PNG in ``sklearn_utilities_info/ReportNonFinite``.

    Defined as a module-level function (rather than a lambda) so that an
    estimator holding it as a parameter can be pickled by reference.
    """
    directory = Path("sklearn_utilities_info/ReportNonFinite")
    directory.mkdir(parents=True, exist_ok=True)
    # ':' is replaced in the ISO timestamp, presumably to keep the file name
    # valid on file systems that reject colons.
    fig.savefig(directory / f"{Timestamp.now().isoformat().replace(':', '-')}.png")


class ReportNonFinite(BaseEstimator, TransformerMixin):
    """Report non-finite values in X or y.

    The transformer never modifies the data: ``transform()`` returns ``X``
    unchanged. Any exception raised while building the report is logged and
    swallowed so that reporting cannot break the surrounding pipeline.
    """

    def __init__(
        self,
        *,
        on_fit: bool = False,
        on_fit_y: bool = False,
        on_transform: bool = True,
        plot: bool = True,
        calc_corr: bool = False,
        callback: Callable[[dict[str, DataFrame | Series]], None] | None = None,
        callback_figure: Callable[[Figure], None]
        | None = _save_figure_to_default_dir,
    ) -> None:
        """Report non-finite values in X or y.

        Parameters
        ----------
        on_fit : bool, optional
            Whether to report non-finite values in X during fit, by default False
        on_fit_y : bool, optional
            Whether to report non-finite values in y during fit, by default False
        on_transform : bool, optional
            Whether to report non-finite values in X during transform, by default True
        plot : bool, optional
            Whether to plot the report result, by default True
        calc_corr : bool, optional
            Whether to calculate the correlation of non-finite values, by default False.
            Plotting the correlations requires ``seaborn``.
        callback : Callable[[dict[str, DataFrame | Series]], None] | None, optional
            The callback function, called with the computed rates, by default None
        callback_figure : Callable[[Figure], None] | None, optional
            The callback function for figure. By default the figure is saved to
            ``sklearn_utilities_info/ReportNonFinite/<timestamp>.png``, where
            ``<timestamp>`` is ``Timestamp.now().isoformat()`` with ``':'``
            replaced by ``'-'``. The figure is closed after the callback
            returns; pass None to leave the figure open.
        """
        self.on_fit = on_fit
        self.on_fit_y = on_fit_y
        self.on_transform = on_transform
        self.plot = plot
        self.calc_corr = calc_corr
        self.callback = callback
        self.callback_figure = callback_figure

    def fit(self, X: DataFrame, y: Any = None, **fit_params: Any) -> Self:
        """Optionally report non-finite values in X and/or y; learn nothing.

        Parameters
        ----------
        X : DataFrame
            Training data, reported when ``on_fit`` is True.
        y : Any, optional
            Target, converted with ``DataFrame(y)`` and reported when
            ``on_fit_y`` is True.

        Returns
        -------
        Self
            This transformer.
        """
        if self.on_fit:
            self._try_report(X, "fit", "X during fit")

        if self.on_fit_y:
            # Convert once; a conversion failure is logged a single time and
            # the report for y is skipped.
            try:
                y_frame = DataFrame(y)
            except Exception as e:
                LOG.warning(f"Failed to convert y to DataFrame during fit: {e}")
                LOG.exception(e)
            else:
                self._try_report(y_frame, "fit_y", "y during fit")
        return self

    def transform(self, X: TXPandas, y: Any = None, **fit_params: Any) -> TXPandas:
        """Optionally report non-finite values in X and return X unchanged.

        Parameters
        ----------
        X : TXPandas
            Input data, reported when ``on_transform`` is True.
        y : Any, optional
            Ignored.

        Returns
        -------
        TXPandas
            Input data.
        """
        if self.on_transform:
            self._try_report(X, "transform", "X during transform")
        return X

    def _try_report(self, X: TXPandas, caller: str, subject: str) -> None:
        """Run ``_report`` and log (instead of propagating) any failure.

        ``subject`` completes the log message
        "Failed to report non-finite values in <subject>".
        """
        try:
            self._report(X, caller)
        except Exception as e:
            LOG.warning(f"Failed to report non-finite values in {subject}: {e}")
            LOG.exception(e)

    def _report(self, X: TXPandas, caller: str = "") -> TXPandas:
        """Report non-finite values in X.

        Computes NaN / inf / non-finite rates by column and by row
        (plus their column correlations when ``calc_corr`` is True), logs
        them, optionally plots them and finally passes them to ``callback``.

        Parameters
        ----------
        X : TXPandas
            Input data.
        caller : str, optional
            The caller name used in the log message, by default "".

        Returns
        -------
        TXPandas
            Input data.
        """
        is_na = X.isna()
        is_inf = X.isin([np.inf, -np.inf])

        d: dict[str, DataFrame | Series] = {
            "nan_rate_by_column": is_na.mean(),
            "inf_rate_by_column": is_inf.mean(),
            "nan_rate_by_row": is_na.mean(axis=1),
            "inf_rate_by_row": is_inf.mean(axis=1),
        }
        d = d | {
            "non_finite_rate_by_column": d["nan_rate_by_column"]
            + d["inf_rate_by_column"],
            "non_finite_rate_by_row": d["nan_rate_by_row"] + d["inf_rate_by_row"],
        }

        if self.calc_corr:
            is_non_finite = is_na | is_inf
            d["nan_rate_corr_by_column"] = is_na.corr()
            d["inf_rate_corr_by_column"] = is_inf.corr()
            d["non_finite_rate_corr_by_column"] = is_non_finite.corr()

        LOG.info(f"Non-finite values in X during {caller}: {d}")

        if self.plot:
            self._plot(d, caller)

        if self.callback is not None:
            self.callback(d)
        return X

    def _plot(self, d: dict[str, DataFrame | Series], caller: str) -> None:
        """Draw the rates in ``d`` on a 3-row grid and hand it to ``callback_figure``."""
        if self.calc_corr:
            # seaborn is only needed for the correlation heatmaps, so the
            # default configuration works without it being installed.
            import seaborn as sns

        fig, axes = plt.subplots(3, 3 if self.calc_corr else 2, figsize=(20, 10))
        keep_open = False
        try:
            fig.suptitle(f"Non-finite values in X during {caller}")
            # One grid row per kind of value: (key prefix in ``d``, label).
            kinds = (("nan", "NaN"), ("inf", "Inf"), ("non_finite", "Non-finite"))
            for row, (key, label) in enumerate(kinds):
                d[f"{key}_rate_by_column"].plot(
                    ax=axes[row, 0],
                    kind="bar",
                    title=f"{label} rate By column",
                    xlabel="column name",
                    ylabel=f"{label} rate",
                )
                d[f"{key}_rate_by_row"].plot(
                    ax=axes[row, 1],
                    kind="line",
                    title=f"{label} rate By row",
                    xlabel="row index",
                    ylabel=f"{label} rate",
                )
                if self.calc_corr:
                    sns.heatmap(
                        d[f"{key}_rate_corr_by_column"],
                        ax=axes[row, 2],
                        vmin=-1,
                        vmax=1,
                    )
                    axes[row, 2].set_title(f"{label} rate Corr By column")

            fig.tight_layout()

            if self.callback_figure is None:
                # Nobody consumed the figure; leave it open for the user.
                keep_open = True
            else:
                self.callback_figure(fig)
        finally:
            # pyplot keeps every figure alive until it is closed, so close it
            # once the callback is done (or plotting failed) to avoid
            # accumulating figures across repeated fit/transform calls.
            if not keep_open:
                plt.close(fig)
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
import numpy as np
2+
import pandas as pd
3+
import pytest
4+
from pandas.testing import assert_frame_equal
5+
6+
from sklearn_utilities.reindex_missing_columns import ReindexMissingColumns
7+
8+
9+
def test_reindex_missing_columns() -> None:
    """A column seen in fit but absent in transform warns and comes back as NaN."""
    full = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})
    without_a = full.drop(columns=["a"], inplace=False)
    expected = pd.DataFrame({"a": [np.nan] * 3, "b": [4, 5, 6]})

    reindexer = ReindexMissingColumns()
    fitted = reindexer.fit(full)

    # The default if_missing="warn" must emit a UserWarning for column "a".
    with pytest.warns(UserWarning):
        result = fitted.transform(without_a)

    assert_frame_equal(result, expected)

tests/test_report_non_finite.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
from typing import Any
2+
3+
from pandas import DataFrame
4+
from pandas.testing import assert_frame_equal
5+
6+
from sklearn_utilities.report_non_finite import ReportNonFinite
7+
8+
9+
def test_fit(caplog: Any) -> None:
    """Fitting with on_fit=True logs the non-finite report for X."""
    caplog.set_level("DEBUG")
    columns = {
        "clean": [1, 2, 3],
        "inf": [4, float("inf"), 6],
        "nan": [7, 8, float("nan")],
        "both": [float("inf"), float("nan"), 9],
    }
    frame = DataFrame(columns)

    reporter = ReportNonFinite(on_fit=True, calc_corr=True)
    reporter.fit(frame)

    # The report is emitted through the module logger during fit.
    expected_message = "Non-finite values in X during fit"
    assert expected_message in caplog.text
23+
24+
25+
def test_transform(caplog: Any) -> None:
    """Transforming logs the non-finite report and returns X unchanged."""
    caplog.set_level("DEBUG")
    columns = {
        "clean": [1, 2, 3],
        "inf": [4, float("inf"), 6],
        "nan": [7, 8, float("nan")],
        "both": [float("inf"), float("nan"), 9],
    }
    frame = DataFrame(columns)

    reporter = ReportNonFinite(on_transform=True, calc_corr=True)
    result = reporter.transform(frame)

    # The report is emitted through the module logger during transform.
    expected_message = "Non-finite values in X during transform"
    assert expected_message in caplog.text
    # The transformer is a pass-through: the data must come back untouched.
    assert_frame_equal(result, frame)

0 commit comments

Comments
 (0)