Skip to content

Commit ebc5e29

Browse files
chore(pre-commit.ci): auto fixes
1 parent a523199 commit ebc5e29

File tree

8 files changed

+67
-56
lines changed

8 files changed

+67
-56
lines changed

src/sklearn_utilities/eval_set.py

Lines changed: 19 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -131,23 +131,25 @@ def __init__(
131131
self,
132132
estimator: TEstimator,
133133
*,
134-
tqdm_cls: Literal[
135-
"auto",
136-
"autonotebook",
137-
"std",
138-
"notebook",
139-
"asyncio",
140-
"keras",
141-
"dask",
142-
"tk",
143-
"gui",
144-
"rich",
145-
"contrib.slack",
146-
"contrib.discord",
147-
"contrib.telegram",
148-
"contrib.bells",
149-
]
150-
| type[tqdm.std.tqdm] = "auto",
134+
tqdm_cls: (
135+
Literal[
136+
"auto",
137+
"autonotebook",
138+
"std",
139+
"notebook",
140+
"asyncio",
141+
"keras",
142+
"dask",
143+
"tk",
144+
"gui",
145+
"rich",
146+
"contrib.slack",
147+
"contrib.discord",
148+
"contrib.telegram",
149+
"contrib.bells",
150+
]
151+
| type[tqdm.std.tqdm]
152+
) = "auto",
151153
tqdm_kwargs: dict[str, Any] | None = None,
152154
verbose: bool = True,
153155
) -> None:

src/sklearn_utilities/pandas/dataframe_wrapper.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -25,20 +25,24 @@ def to_frame_or_series(
2525
return Series(
2626
array,
2727
index=base_index if array.shape[0] == len(base_index) else None,
28-
name=base_columns_or_name
29-
if not isinstance(base_columns_or_name, Index)
30-
else None,
28+
name=(
29+
base_columns_or_name
30+
if not isinstance(base_columns_or_name, Index)
31+
else None
32+
),
3133
)
3234
if array.ndim == 2:
3335
return DataFrame(
3436
array,
3537
index=base_index if array.shape[0] == len(base_index) else None,
36-
columns=base_columns_or_name
37-
if (
38-
isinstance(base_columns_or_name, Index)
39-
and array.shape[1] == len(base_columns_or_name)
40-
)
41-
else None,
38+
columns=(
39+
base_columns_or_name
40+
if (
41+
isinstance(base_columns_or_name, Index)
42+
and array.shape[1] == len(base_columns_or_name)
43+
)
44+
else None
45+
),
4246
)
4347
except Exception as e:
4448
warnings.warn(f"Could not convert {array} to DataFrame or Series: {e}")

src/sklearn_utilities/pandas/multioutput.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -115,9 +115,12 @@ def predict(
115115
check_is_fitted(self)
116116
X = X[self.feature_names_in_]
117117
preds = [est.predict(X, **predict_params) for est in self.estimators_]
118-
preds_: DataFrame | Series | NDArray[Any] | tuple[
119-
DataFrame | Series | NDArray[Any], ...
120-
]
118+
preds_: (
119+
DataFrame
120+
| Series
121+
| NDArray[Any]
122+
| tuple[DataFrame | Series | NDArray[Any], ...]
123+
)
121124
if any(isinstance(pred, tuple) for pred in preds):
122125
# list of tuples of arrays to tuples of arrays
123126
preds_ = tuple(np.array(pred).T for pred in zip(*preds))

src/sklearn_utilities/proba/compose_var.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,14 +37,12 @@ def fit(self, X: TX, y: TY, **fit_params: Any) -> Self:
3737
@overload
3838
def predict(
3939
self, X: TX, return_std: Literal[False] = ..., **predict_params: Any
40-
) -> TY:
41-
...
40+
) -> TY: ...
4241

4342
@overload
4443
def predict(
4544
self, X: TX, return_std: Literal[True], **predict_params: Any
46-
) -> tuple[TY, TY]:
47-
...
45+
) -> tuple[TY, TY]: ...
4846

4947
def predict(
5048
self, X: TX, return_std: bool = False, **predict_params: Any

src/sklearn_utilities/reindex_missing_columns.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,9 @@ class ReindexMissingColumns(BaseEstimator, TransformerMixin):
1616
def __init__(
1717
self,
1818
*,
19-
if_missing: Literal["warn", "raise"]
20-
| Callable[[Index[Any], Index[Any]], None] = "warn",
19+
if_missing: (
20+
Literal["warn", "raise"] | Callable[[Index[Any], Index[Any]], None]
21+
) = "warn",
2122
reindex_kwargs: dict[
2223
Literal["method", "copy", "level", "fill_value", "limit", "tolerance"], Any
2324
] = {},

src/sklearn_utilities/report_non_finite.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,9 @@ def __init__(
2828
plot: bool = True,
2929
calc_corr: bool = False,
3030
callback: Callable[[dict[str, DataFrame | Series]], None] | None = None,
31-
callback_figure: Callable[[Figure], None]
32-
| None = lambda fig: Path("sklearn_utilities_info/ReportNonFinite").mkdir( # type: ignore
31+
callback_figure: Callable[[Figure], None] | None = lambda fig: Path(
32+
"sklearn_utilities_info/ReportNonFinite"
33+
).mkdir( # type: ignore
3334
parents=True, exist_ok=True
3435
)
3536
or fig.savefig(

src/sklearn_utilities/torch/pca.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -121,9 +121,9 @@ def __init__(
121121
*,
122122
qr: bool = False,
123123
svd_flip: bool | None = None,
124-
device: torch.device | int | str = "cuda"
125-
if torch.cuda.is_available()
126-
else "cpu",
124+
device: torch.device | int | str = (
125+
"cuda" if torch.cuda.is_available() else "cpu"
126+
),
127127
dtype: torch.dtype = torch.float32,
128128
**kwargs: Any,
129129
) -> None:

src/sklearn_utilities/torch/skorch/proba.py

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -297,22 +297,24 @@ def predict(
297297
X: TX,
298298
*,
299299
return_std: bool = False,
300-
type_: Literal[
301-
"mean",
302-
"median",
303-
"nanmean",
304-
"nanmedian",
305-
"var",
306-
"std",
307-
"ptp",
308-
"nanvar",
309-
"nanstd",
310-
]
311-
| tuple[
312-
Literal["mean", "median", "nanmean", "nanmedian"],
313-
Literal["var", "std", "ptp", "nanvar", "nanstd"],
314-
]
315-
| None = None,
300+
type_: (
301+
Literal[
302+
"mean",
303+
"median",
304+
"nanmean",
305+
"nanmedian",
306+
"var",
307+
"std",
308+
"ptp",
309+
"nanvar",
310+
"nanstd",
311+
]
312+
| tuple[
313+
Literal["mean", "median", "nanmean", "nanmedian"],
314+
Literal["var", "std", "ptp", "nanvar", "nanstd"],
315+
]
316+
| None
317+
) = None,
316318
**predict_params: Any,
317319
) -> TY | tuple[TY, TY]:
318320
ts_axis_ = self.estimator.criterion.ts_axis_

0 commit comments

Comments
 (0)