Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ repos:
hooks:
- id: isort
- repo: https://github.com/psf/black
rev: 23.12.1
rev: 24.1.0
hooks:
- id: black
- repo: https://github.com/codespell-project/codespell
Expand Down
36 changes: 19 additions & 17 deletions src/sklearn_utilities/eval_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,23 +131,25 @@ def __init__(
self,
estimator: TEstimator,
*,
tqdm_cls: Literal[
"auto",
"autonotebook",
"std",
"notebook",
"asyncio",
"keras",
"dask",
"tk",
"gui",
"rich",
"contrib.slack",
"contrib.discord",
"contrib.telegram",
"contrib.bells",
]
| type[tqdm.std.tqdm] = "auto",
tqdm_cls: (
Literal[
"auto",
"autonotebook",
"std",
"notebook",
"asyncio",
"keras",
"dask",
"tk",
"gui",
"rich",
"contrib.slack",
"contrib.discord",
"contrib.telegram",
"contrib.bells",
]
| type[tqdm.std.tqdm]
) = "auto",
tqdm_kwargs: dict[str, Any] | None = None,
verbose: bool = True,
) -> None:
Expand Down
22 changes: 13 additions & 9 deletions src/sklearn_utilities/pandas/dataframe_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,20 +25,24 @@ def to_frame_or_series(
return Series(
array,
index=base_index if array.shape[0] == len(base_index) else None,
name=base_columns_or_name
if not isinstance(base_columns_or_name, Index)
else None,
name=(
base_columns_or_name
if not isinstance(base_columns_or_name, Index)
else None
),
)
if array.ndim == 2:
return DataFrame(
array,
index=base_index if array.shape[0] == len(base_index) else None,
columns=base_columns_or_name
if (
isinstance(base_columns_or_name, Index)
and array.shape[1] == len(base_columns_or_name)
)
else None,
columns=(
base_columns_or_name
if (
isinstance(base_columns_or_name, Index)
and array.shape[1] == len(base_columns_or_name)
)
else None
),
)
except Exception as e:
warnings.warn(f"Could not convert {array} to DataFrame or Series: {e}")
Expand Down
9 changes: 6 additions & 3 deletions src/sklearn_utilities/pandas/multioutput.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,9 +115,12 @@ def predict(
check_is_fitted(self)
X = X[self.feature_names_in_]
preds = [est.predict(X, **predict_params) for est in self.estimators_]
preds_: DataFrame | Series | NDArray[Any] | tuple[
DataFrame | Series | NDArray[Any], ...
]
preds_: (
DataFrame
| Series
| NDArray[Any]
| tuple[DataFrame | Series | NDArray[Any], ...]
)
if any(isinstance(pred, tuple) for pred in preds):
# list of tuples of arrays to tuples of arrays
preds_ = tuple(np.array(pred).T for pred in zip(*preds))
Expand Down
6 changes: 2 additions & 4 deletions src/sklearn_utilities/proba/compose_var.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,12 @@ def fit(self, X: TX, y: TY, **fit_params: Any) -> Self:
@overload
def predict(
self, X: TX, return_std: Literal[False] = ..., **predict_params: Any
) -> TY:
...
) -> TY: ...

@overload
def predict(
self, X: TX, return_std: Literal[True], **predict_params: Any
) -> tuple[TY, TY]:
...
) -> tuple[TY, TY]: ...

def predict(
self, X: TX, return_std: bool = False, **predict_params: Any
Expand Down
5 changes: 3 additions & 2 deletions src/sklearn_utilities/reindex_missing_columns.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@ class ReindexMissingColumns(BaseEstimator, TransformerMixin):
def __init__(
self,
*,
if_missing: Literal["warn", "raise"]
| Callable[[Index[Any], Index[Any]], None] = "warn",
if_missing: (
Literal["warn", "raise"] | Callable[[Index[Any], Index[Any]], None]
) = "warn",
reindex_kwargs: dict[
Literal["method", "copy", "level", "fill_value", "limit", "tolerance"], Any
] = {},
Expand Down
5 changes: 3 additions & 2 deletions src/sklearn_utilities/report_non_finite.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,9 @@ def __init__(
plot: bool = True,
calc_corr: bool = False,
callback: Callable[[dict[str, DataFrame | Series]], None] | None = None,
callback_figure: Callable[[Figure], None]
| None = lambda fig: Path("sklearn_utilities_info/ReportNonFinite").mkdir( # type: ignore
callback_figure: Callable[[Figure], None] | None = lambda fig: Path(
"sklearn_utilities_info/ReportNonFinite"
).mkdir( # type: ignore
parents=True, exist_ok=True
)
or fig.savefig(
Expand Down
6 changes: 3 additions & 3 deletions src/sklearn_utilities/torch/pca.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,9 +121,9 @@ def __init__(
*,
qr: bool = False,
svd_flip: bool | None = None,
device: torch.device | int | str = "cuda"
if torch.cuda.is_available()
else "cpu",
device: torch.device | int | str = (
"cuda" if torch.cuda.is_available() else "cpu"
),
dtype: torch.dtype = torch.float32,
**kwargs: Any,
) -> None:
Expand Down
34 changes: 18 additions & 16 deletions src/sklearn_utilities/torch/skorch/proba.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,22 +297,24 @@ def predict(
X: TX,
*,
return_std: bool = False,
type_: Literal[
"mean",
"median",
"nanmean",
"nanmedian",
"var",
"std",
"ptp",
"nanvar",
"nanstd",
]
| tuple[
Literal["mean", "median", "nanmean", "nanmedian"],
Literal["var", "std", "ptp", "nanvar", "nanstd"],
]
| None = None,
type_: (
Literal[
"mean",
"median",
"nanmean",
"nanmedian",
"var",
"std",
"ptp",
"nanvar",
"nanstd",
]
| tuple[
Literal["mean", "median", "nanmean", "nanmedian"],
Literal["var", "std", "ptp", "nanvar", "nanstd"],
]
| None
) = None,
**predict_params: Any,
) -> TY | tuple[TY, TY]:
ts_axis_ = self.estimator.criterion.ts_axis_
Expand Down