11# -*- coding: utf-8 -*-
2+ """Base classes for estimators."""
23
34__author__ = ["MatthewMiddlehurst" ]
45__all__ = [
56 "BaseTimeSeriesEstimator" ,
6- "clone_estimator " ,
7+ "_clone_estimator " ,
78]
89
9- from typing import Union
10+ from typing import Tuple , Union
1011
12+ import numpy as np
1113from numpy .random import RandomState
1214from sklearn .base import BaseEstimator , clone
1315from sklearn .ensemble ._base import _set_random_states
1820
1921
2022class BaseTimeSeriesEstimator (BaseEstimator ):
23+ """Base class for time series estimators in tsml."""
24+
2125 def _validate_data (
2226 self ,
23- X = "no_validation" ,
24- y = "no_validation" ,
25- reset = True ,
27+ X : object = "no_validation" ,
28+ y : object = "no_validation" ,
29+ reset : bool = True ,
2630 ** check_params ,
27- ):
31+ ) -> Union [
32+ Tuple [np .ndarray , object ],
33+ Tuple [list [np .ndarray ], object ],
34+ np .ndarray ,
35+ list [np .ndarray ],
36+ ]:
2837 """Validate input data and set or check the `n_features_in_` attribute.
2938
3039 Uses the `scikit-learn` 1.2.1 `_validate_data` function as a base.
3140
3241 Parameters
3342 ----------
34- X : {array-like, sparse matrix, dataframe} of shape \
35- (n_samples, n_features), default='no validation'
36- The input samples.
43+ X : ndarray or list of ndarrays of shape (n_samples, n_dimensions, \
44+ series_length), array-like, or 'no_validation', default='no_validation'
45+ The input samples. Ideally a 3D numpy array or a list of 2D numpy
46+ arrays.
3747 If `'no_validation'`, no validation is performed on `X`. This is
3848 useful for meta-estimator which can delegate input validation to
3949 their underlying estimator(s). In that case `y` must be passed and
40- the only accepted `check_params` are `multi_output` and
41- `y_numeric`.
42-
43- y : array-like of shape (n_samples,), default='no_validation'
44- The targets.
50+ the only accepted `check_params` entry is `y_numeric`.
51+ y : array-like of shape (n_samples,), 'no_validation' or None, \
52+ default='no_validation'
53+ The target labels.
4554
46- - If `None`, `check_array ` is called on `X`. If the estimator's
55+ - If `None`, `check_X` is called on `X`. If the estimator's
4756 requires_y tag is True, then an error will be raised.
48- - If `'no_validation'`, `check_array ` is called on `X` and the
57+ - If `'no_validation'`, `check_X` is called on `X` and the
4958 estimator's requires_y tag is ignored. This is a default
5059 placeholder and is never meant to be explicitly set. In that case
5160 `X` must be passed.
5261 - Otherwise, only `y` with `_check_y` or both `X` and `y` are
53- checked with either `check_array` or `check_X_y` depending on
54- `validate_separately`.
55-
62+ checked with `check_X_y`.
5663 reset : bool, default=True
5764 Whether to reset the `n_features_in_` attribute.
5865 If False, the input will be checked for consistency with data
5966 provided when reset was last True.
6067 .. note::
61- It is recommended to call reset=True in `fit` and in the first
62- call to `partial_fit`. All other methods that validate `X`
63- should set `reset=False`.
64-
65- validate_separately : False or tuple of dicts, default=False
66- Only used if y is not None.
67- If False, call validate_X_y(). Else, it must be a tuple of kwargs
68- to be used for calling check_array() on X and y respectively.
69-
70- `estimator=self` is automatically added to these dicts to generate
71- more informative error message in case of invalid input data.
72-
68+ It is recommended to call reset=True in `fit`. All other methods that
69+ validate `X` should set `reset=False`.
7370 **check_params : kwargs
74- Parameters passed to :func:`sklearn .utils.check_array` or
75- :func: `sklearn.utils.check_X_y`. Ignored if validate_separately
76- is not False .
71+ Parameters passed to :func:`tsml.utils.validation.check_X`,
72+ `sklearn.utils.validation._check_y` or
73+ :func:`tsml.utils.validation.check_X_y`.
7774
7875 `estimator=self` is automatically added to these params to generate
7976 more informative error message in case of invalid input data.
8077
8178 Returns
8279 -------
83- out : { ndarray, sparse matrix} or tuple of these
80+ out : np.ndarray, list of np.ndarray or tuple of these
8481 The validated input. A tuple is returned if both `X` and `y` are
8582 validated.
8683 """
@@ -90,7 +87,7 @@ def _validate_data(
9087 "requires y to be passed, but the target y is None."
9188 )
9289
93- no_val_X = X is None or ( isinstance (X , str ) and X == "no_validation" )
90+ no_val_X = isinstance (X , str ) and X == "no_validation"
9491 no_val_y = y is None or (isinstance (y , str ) and y == "no_validation" )
9592
9693 default_check_params = {"estimator" : self }
@@ -112,24 +109,26 @@ def _validate_data(
112109
113110 return out
114111
115- def _check_n_features (self , X , reset ):
112+ def _check_n_features (self , X : Union [ np . ndarray , list [ np . ndarray ]], reset : bool ):
116113 """Set the `n_features_in_` attribute, or check against it.
117114
118115 Uses the `scikit-learn` 1.2.1 `_check_n_features` function as a base.
119116
120117 Parameters
121118 ----------
122- X : {ndarray, sparse matrix} of shape (n_samples, n_features)
123- The input samples.
119+ X : ndarray or list of ndarrays of shape \
120+ (n_samples, n_dimensions, series_length)
121+ The input samples. Should be a 3D numpy array or a list of 2D numpy
122+ arrays.
124123 reset : bool
125- If True, the `n_features_in_` attribute is set to `X.shape[1]`.
124+ If True, the `n_features_in_` attribute is set to
125+ `(n_dimensions, min_series_length, max_series_length)`.
126126 If False and the attribute exists, then check that it is equal to
127- `X.shape[1]`. If False and the attribute does *not* exist, then
128- the check is skipped.
127+ `(n_dimensions, min_series_length, max_series_length)`.
128+ If False and the attribute does *not* exist, then the check is skipped.
129129 .. note::
130- It is recommended to call reset=True in `fit` and in the first
131- call to `partial_fit`. All other methods that validate `X`
132- should set `reset=False`.
130+ It is recommended to call reset=True in `fit`. All other methods that
131+ validate `X` should set `reset=False`.
133132 """
134133 try :
135134 n_features = _num_features (X )
@@ -167,11 +166,13 @@ def _check_n_features(self, X, reset):
167166 f"is expecting { self .n_features_in_ [1 ]} series length as input."
168167 )
169168
170- def _more_tags (self ):
169+ def _more_tags (self ) -> dict :
171170 return _DEFAULT_TAGS
172171
173172 @classmethod
174- def get_test_params (cls , parameter_set = None ):
173+ def get_test_params (
174+ cls , parameter_set : Union [str , None ] = None
175+ ) -> Union [dict , list [dict ]]:
175176 """Return unit test parameter settings for the estimator.
176177
177178 Parameters
@@ -182,11 +183,8 @@ def get_test_params(cls, parameter_set=None):
182183
183184 Returns
184185 -------
185- params : dict or list of dict, default = {}
186- Parameters to create testing instances of the class
187- Each dict are parameters to construct an "interesting" test instance, i.e.,
188- `MyClass(**params)` or `MyClass(**params[i])` creates a valid test instance.
189- `create_test_instance` uses the first (or only) dictionary in `params`
186+ params : dict or list of dict
187+ Parameters to create testing instances of the class.
190188 """
191189 if parameter_set is None :
192190 # default parameters = empty dict
@@ -197,9 +195,10 @@ def get_test_params(cls, parameter_set=None):
197195 )
198196
199197
200- def clone_estimator (
198+ def _clone_estimator (
201199 base_estimator : BaseEstimator , random_state : Union [None , int , RandomState ] = None
202200) -> BaseEstimator :
201+ """Clone an estimator and set the random state if available."""
203202 estimator = clone (base_estimator )
204203
205204 if random_state is not None :
0 commit comments