Skip to content

Commit 5e05a79

Browse files
list to List
1 parent ca71fc0 commit 5e05a79

File tree

4 files changed

+13
-13
lines changed

4 files changed

+13
-13
lines changed

tsml/base.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
"_clone_estimator",
88
]
99

10-
from typing import Tuple, Union
10+
from typing import List, Tuple, Union
1111

1212
import numpy as np
1313
from numpy.random import RandomState
@@ -30,9 +30,9 @@ def _validate_data(
3030
**check_params,
3131
) -> Union[
3232
Tuple[np.ndarray, object],
33-
Tuple[list[np.ndarray], object],
33+
Tuple[List[np.ndarray], object],
3434
np.ndarray,
35-
list[np.ndarray],
35+
List[np.ndarray],
3636
]:
3737
"""Validate input data and set or check the `n_features_in_` attribute.
3838
@@ -109,7 +109,7 @@ def _validate_data(
109109

110110
return out
111111

112-
def _check_n_features(self, X: Union[np.ndarray, list[np.ndarray]], reset: bool):
112+
def _check_n_features(self, X: Union[np.ndarray, List[np.ndarray]], reset: bool):
113113
"""Set the `n_features_in_` attribute, or check against it.
114114
115115
Uses the `scikit-learn` 1.2.1 `_check_n_features` function as a base.
@@ -172,7 +172,7 @@ def _more_tags(self) -> dict:
172172
@classmethod
173173
def get_test_params(
174174
cls, parameter_set: Union[str, None] = None
175-
) -> Union[dict, list[dict]]:
175+
) -> Union[dict, List[dict]]:
176176
"""Return unit test parameter settings for the estimator.
177177
178178
Parameters

tsml/utils/discovery.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from importlib import import_module
1010
from operator import itemgetter
1111
from pathlib import Path
12-
from typing import Union
12+
from typing import List, Union
1313

1414
from sklearn.base import (
1515
BaseEstimator,
@@ -25,7 +25,7 @@
2525
}
2626

2727

28-
def all_estimators(type_filter: Union[str, list[str]] = None):
28+
def all_estimators(type_filter: Union[str, List[str]] = None):
2929
"""Get a list of all estimators from `tsml`.
3030
3131
This function crawls the module and gets all classes that inherit

tsml/utils/testing.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
]
1010

1111
from functools import partial
12-
from typing import Callable, Tuple, Union
12+
from typing import Callable, List, Tuple, Union
1313

1414
import numpy as np
1515
from sklearn.base import BaseEstimator
@@ -24,7 +24,7 @@
2424
from tsml.utils.discovery import all_estimators
2525

2626

27-
def generate_test_estimators() -> list[BaseEstimator]:
27+
def generate_test_estimators() -> List[BaseEstimator]:
2828
"""Generate a list of all estimators in tsml with test parameters.
2929
3030
Uses estimator parameters from `get_test_params` if available.
@@ -51,7 +51,7 @@ def generate_test_estimators() -> list[BaseEstimator]:
5151
return estimators
5252

5353

54-
def parametrize_with_checks(estimators: list[BaseEstimator]) -> Callable:
54+
def parametrize_with_checks(estimators: List[BaseEstimator]) -> Callable:
5555
"""Pytest specific decorator for parametrizing estimator checks.
5656
5757
If the estimator is a `BaseTimeSeriesEstimator` then the `tsml` checks are used,

tsml/utils/validation.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
import os
1414
import warnings
1515
from importlib import import_module
16-
from typing import Tuple, Union
16+
from typing import List, Tuple, Union
1717

1818
import numpy as np
1919
from packaging.requirements import InvalidRequirement, Requirement
@@ -97,7 +97,7 @@ def is_clusterer(estimator: BaseEstimator) -> bool:
9797
return getattr(estimator, "_estimator_type", None) == "clusterer"
9898

9999

100-
def _num_features(X: Union[np.ndarray, list[np.ndarray]]) -> tuple[int]:
100+
def _num_features(X: Union[np.ndarray, List[np.ndarray]]) -> tuple[int]:
101101
"""Return the number of features of a 3D numpy array or a list of 2D numpy arrays.
102102
103103
Returns
@@ -127,7 +127,7 @@ def check_X_y(
127127
ensure_min_series_length: int = 2,
128128
estimator: Union[str, BaseEstimator, None] = None,
129129
y_numeric: bool = False,
130-
) -> Union[Tuple[np.ndarray, np.ndarray], Tuple[list[np.ndarray], np.ndarray]]:
130+
) -> Union[Tuple[np.ndarray, np.ndarray], Tuple[List[np.ndarray], np.ndarray]]:
131131
"""Input validation for standard estimators.
132132
133133
Checks X and y for consistent length, enforces X to be 3D and y 1D. By default,

0 commit comments

Comments
 (0)