Skip to content

Commit 805e979

Browse files
rdst bug and interval docs
1 parent c760153 commit 805e979

File tree

4 files changed

+121
-116
lines changed

4 files changed

+121
-116
lines changed

tsml/interval_based/_base.py

Lines changed: 87 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,7 @@
1212
import numpy as np
1313
from joblib import Parallel
1414
from sklearn.base import BaseEstimator, is_classifier, is_regressor
15-
from sklearn.tree import (
16-
BaseDecisionTree,
17-
DecisionTreeClassifier,
18-
DecisionTreeRegressor,
19-
ExtraTreeClassifier,
20-
)
15+
from sklearn.tree import BaseDecisionTree, DecisionTreeClassifier, DecisionTreeRegressor
2116
from sklearn.utils import check_random_state
2217
from sklearn.utils.fixes import delayed
2318
from sklearn.utils.multiclass import check_classification_targets
@@ -39,12 +34,10 @@ class BaseIntervalForest(BaseTimeSeriesEstimator, metaclass=ABCMeta):
3934
Allows the implementation of classifiers and regressors along the lines of [1][2][3]
4035
which extract intervals and create an ensemble from the subsequent features.
4136
42-
#skipping predict todo
43-
4437
Parameters
4538
----------
4639
base_estimator : BaseEstimator or None, default=None
47-
scikit-learn BaseEstimator used to build the interval ensemble. If None, uses a
40+
scikit-learn BaseEstimator used to build the interval ensemble. If None, use a
4841
simple decision tree.
4942
n_estimators : int, default=200
5043
Number of estimators to build for the ensemble.
@@ -65,8 +58,8 @@ class BaseIntervalForest(BaseTimeSeriesEstimator, metaclass=ABCMeta):
6558
input will return a function of the series length (may differ per
6659
series_transformers output) to extract that number of intervals.
6760
Valid str inputs are:
68-
- "sqrt" : square root of the series length.
69-
- "sqrt-div" : sqrt of series length divided by the number
61+
- "sqrt": square root of the series length.
62+
- "sqrt-div": sqrt of series length divided by the number
7063
of series_transformers.
7164
7265
A list or tuple of ints and/or strs will extract the number of intervals using
@@ -78,7 +71,9 @@ class BaseIntervalForest(BaseTimeSeriesEstimator, metaclass=ABCMeta):
7871
another list or tuple must be the same length as the number of
7972
series_transformers.
8073
81-
%todo random vs supervised
74+
While random interval extraction will extract the n_intervals intervals total
75+
(removing duplicates), supervised intervals will run the supervised extraction
76+
process n_intervals times, returning more intervals than specified.
8277
min_interval_length : int, float, list, or tuple, default=3
8378
Minimum length of intervals to extract from series. float inputs take a
8479
proportion of the series length to use as the minimum interval length.
@@ -97,10 +92,10 @@ class BaseIntervalForest(BaseTimeSeriesEstimator, metaclass=ABCMeta):
9792
Ignored for supervised interval_selection_method inputs.
9893
interval_features : TransformerMixin, callable, list, tuple, or None, default=None
9994
The features to extract from the intervals using transformers or callable
100-
functions. If None, uses the mean, standard deviation, and slope of the series.
95+
functions. If None, use the mean, standard deviation, and slope of the series.
10196
102-
Both transformers and functions should be able to take a 2d np.ndarray input.
103-
Functions should output a 1d array (the feature for each series) and
97+
Both transformers and functions should be able to take a 2D np.ndarray input.
98+
Functions should output a 1d array (the feature for each series), and
10499
transformers should output a 2d array where rows are the features for each
105100
series. A list or tuple of transformers and/or functions will extract all
106101
features and concatenate the output.
@@ -109,14 +104,29 @@ class BaseIntervalForest(BaseTimeSeriesEstimator, metaclass=ABCMeta):
109104
nested list or tuple. Any list or tuple input containing another list or tuple
110105
must be the same length as the number of series_transformers.
111106
series_transformers : TransformerMixin, list, tuple, or None, default=None
112-
113-
att_subsample_size : int or None, default=None
114-
Number of catch22 or summary statistic attributes to subsample per tree.
115-
replace_nan :
116-
107+
The transformers to apply to the series before extracting intervals. If None,
108+
use the series as is.
109+
110+
Both transformers and functions should be able to take a 3D np.ndarray input.
111+
A list or tuple of transformers and/or functions will extract intervals from
112+
all transformations concatenate the output. Including None in the list or tuple
113+
will use the series as is for interval extraction.
114+
att_subsample_size : int, float, list, tuple or None, default=None
115+
The number of attributes to subsample for each estimator. If None, use all
116+
117+
If int, use that number of attributes for all estimators. If float, use that
118+
proportion of attributes for all estimators.
119+
120+
Different subsample sizes for each series_transformers series can be specified
121+
using a list or tuple. Any list or tuple input must be the same length as the
122+
number of series_transformers.
123+
replace_nan : "nan", int, float or None, default=None
124+
The value to replace NaNs and infinite values with before fitting the base
125+
estimator. int or float input will replace with the specified value, while
126+
"nan" will replace infinite values with NaNs. If None, do not replace NaNs.
117127
time_limit_in_minutes : int, default=0
118128
Time contract to limit build time in minutes, overriding n_estimators.
119-
Default of 0 means n_estimators is used.
129+
Default of 0 means n_estimators are used.
120130
contract_max_n_estimators : int, default=500
121131
Max number of estimators when time_limit_in_minutes is set.
122132
save_transformed_data : bool, default=False
@@ -139,7 +149,7 @@ class BaseIntervalForest(BaseTimeSeriesEstimator, metaclass=ABCMeta):
139149
----------
140150
n_instances_ : int
141151
The number of train cases.
142-
n_dims_ : int
152+
n_channels_ : int
143153
The number of channels per case.
144154
series_length_ : int
145155
The length of each series.
@@ -149,9 +159,6 @@ class BaseIntervalForest(BaseTimeSeriesEstimator, metaclass=ABCMeta):
149159
The collections of estimators trained in fit.
150160
intervals_ : list of shape (n_estimators) of ndarray with shape (total_intervals,2)
151161
Stores indexes of each intervals start and end points for all classifiers.
152-
atts_ : list of shape (n_estimators) of array with shape (att_subsample_size)
153-
Attribute indexes of the subsampled catch22 or summary statistic for all
154-
classifiers.
155162
transformed_data_ : list of shape (n_estimators) of ndarray with shape
156163
(n_instances,total_intervals * att_subsample_size)
157164
The transformed dataset for all classifiers. Only saved when
@@ -221,7 +228,7 @@ def fit(self, X, y):
221228

222229
rng = check_random_state(self.random_state)
223230

224-
self.n_instances_, self.n_dims_, self.series_length_ = X.shape
231+
self.n_instances_, self.n_channels_, self.series_length_ = X.shape
225232
if is_classifier(self):
226233
check_classification_targets(y)
227234

@@ -236,16 +243,7 @@ def fit(self, X, y):
236243

237244
self._base_estimator = self.base_estimator
238245
if self.base_estimator is None:
239-
from tsml.interval_based import RSTSFClassifier
240-
241-
# default base_estimators for classification and regression
242-
if isinstance(self, RSTSFClassifier):
243-
self._base_estimator = ExtraTreeClassifier(
244-
criterion="entropy",
245-
class_weight="balanced",
246-
max_features="sqrt",
247-
)
248-
elif is_classifier(self):
246+
if is_classifier(self):
249247
self._base_estimator = DecisionTreeClassifier(criterion="entropy")
250248
elif is_regressor(self):
251249
self._base_estimator = DecisionTreeRegressor(criterion="absolute_error")
@@ -285,10 +283,14 @@ def fit(self, X, y):
285283
Xt.append(t.fit_transform(X, y))
286284
self._series_transformers.append(t)
287285
else:
288-
raise ValueError() # todo error for invalid self.series_transformers
286+
raise ValueError(
287+
f"Invalid series_transformers list input. Found {transformer}"
288+
)
289289
# other inputs are invalid
290290
else:
291-
raise ValueError() # todo error for invalid self.series_transformers
291+
raise ValueError(
292+
f"Invalid series_transformers input. Found {self.series_transformers}"
293+
)
292294

293295
# if only a single n_intervals value is passed it must be an int or str
294296
if isinstance(self.n_intervals, (int, str)):
@@ -533,14 +535,18 @@ def fit(self, X, y):
533535
# att_subsample_size must be at least one if it is an int
534536
if isinstance(self.att_subsample_size, int):
535537
if self.att_subsample_size < 1:
536-
raise ValueError() # todo error for invalid invalid self.att_subsample_size
538+
raise ValueError(
539+
"att_subsample_size must be at least one if it is an int."
540+
)
537541

538542
self._att_subsample_size = [self.att_subsample_size] * len(Xt)
539543
# att_subsample_size must be at less than one if it is a float (proportion of
540544
# total attributed to subsample)
541545
elif isinstance(self.att_subsample_size, float):
542-
if self.att_subsample_size > 1:
543-
raise ValueError() # todo error for invalid invalid self.att_subsample_size
546+
if self.att_subsample_size > 1 or self.att_subsample_size <= 0:
547+
raise ValueError(
548+
"att_subsample_size must be between 0 and 1 if it is a float."
549+
)
544550

545551
self._att_subsample_size = [self.att_subsample_size] * len(Xt)
546552
# default is no attribute subsampling with None
@@ -552,27 +558,42 @@ def fit(self, X, y):
552558
# performed
553559
elif isinstance(self.att_subsample_size, (list, tuple)):
554560
if len(self.att_subsample_size) != len(Xt):
555-
raise ValueError() # todo error for invalid self.att_subsample_size
561+
raise ValueError(
562+
"att_subsample_size as a list or tuple must be the same length as "
563+
"series_transformers."
564+
)
556565

557566
self._att_subsample_size = []
558567
for ssize in self.att_subsample_size:
559568
if isinstance(ssize, int):
560569
if ssize < 1:
561-
raise ValueError() # todo error for invalid invalid self.att_subsample_size
570+
raise ValueError(
571+
"att_subsample_size in list must be at least one if it is "
572+
"an int."
573+
)
562574

563575
self._att_subsample_size.append(ssize)
564576
elif isinstance(ssize, float):
565577
if ssize > 1:
566-
raise ValueError() # todo error for invalid invalid self.att_subsample_size
578+
raise ValueError(
579+
"att_subsample_size in list must be between 0 and 1 if it "
580+
"is a "
581+
"float."
582+
)
567583

568584
self._att_subsample_size.append(ssize)
569585
elif ssize is None:
570586
self._att_subsample_size.append(ssize)
571587
else:
572-
raise ValueError() # todo error for invalid self.att_subsample_size
588+
raise ValueError(
589+
"Invalid interval_features input in list. Found "
590+
f"{self.att_subsample_size}"
591+
)
573592
# other inputs are invalid
574593
else:
575-
raise ValueError() # todo error for invalid invalid self.att_subsample_size
594+
raise ValueError(
595+
f"Invalid interval_features input. Found {self.att_subsample_size}"
596+
)
576597

577598
# if we are subsampling attributes for a series_transformer and it uses a
578599
# BaseTransformer, we must ensure it has the required parameters and
@@ -596,7 +617,11 @@ def fit(self, X, y):
596617
break
597618

598619
if not has_params:
599-
raise ValueError() # todo error for invalid invalid self.att_subsample_size
620+
raise ValueError(
621+
"All transformers in interval_features must have a "
622+
"parameter named in transformer_feature_selection to "
623+
"be used in attribute subsampling."
624+
)
600625

601626
# the transformer must have an attribute with one of the
602627
# names listed in transformer_feature_names as a list or tuple
@@ -611,7 +636,12 @@ def fit(self, X, y):
611636
break
612637

613638
if not has_feature_names:
614-
raise ValueError() # todo error for invalid invalid self.att_subsample_size
639+
raise ValueError(
640+
"All transformers in interval_features must have an "
641+
"attribute or propertynamed in "
642+
"transformer_feature_names to be used in attribute "
643+
"subsampling."
644+
)
615645

616646
# verify the interval_selection_method is a valid string
617647
if isinstance(self.interval_selection_method, str):
@@ -652,13 +682,14 @@ def fit(self, X, y):
652682
and not isinstance(self.replace_nan, (int, float))
653683
and self.replace_nan is not None
654684
):
655-
raise ValueError() # todo error for invalid self.replace_nan
685+
raise ValueError(f"Invalid replace_nan input. Found {self.replace_nan}")
656686

657687
self._n_jobs = check_n_jobs(self.n_jobs)
658688

659-
self._efficient_predictions = True # todo
689+
# flags for testing. not used in the actual algorithm
690+
self._efficient_predictions = True
660691
if not hasattr(self, "_test_flag"):
661-
self._test_flag = False # todo
692+
self._test_flag = False
662693

663694
if self.time_limit_in_minutes is not None and self.time_limit_in_minutes > 0:
664695
time_limit = self.time_limit_in_minutes * 60
@@ -869,7 +900,9 @@ def _fit_estimator(self, Xt, y, seed):
869900
features.append(all_function_features[atts[count + i] - length])
870901
else:
871902
warnings.warn(
872-
f"Attribute subsample size {att_subsample_size} is larger than or equal to the number of attributes {num_features} for series {self._series_transformers[r]}"
903+
f"Attribute subsample size {att_subsample_size} is larger than "
904+
f"or equal to the number of attributes {num_features} for "
905+
f"series {self._series_transformers[r]}"
873906
)
874907
for feature in self._interval_features[r]:
875908
if is_transformer(feature):
@@ -910,8 +943,6 @@ def _fit_estimator(self, Xt, y, seed):
910943
randomised_split_point=True,
911944
random_state=seed,
912945
)
913-
else:
914-
raise ValueError() # todo error for invalid self.interval_selection_method, should not get here
915946

916947
# fit the interval selector, transform the current series using it and save
917948
# the transformer
@@ -988,9 +1019,9 @@ def _predict_setup(self, X):
9881019
X = self._validate_data(X=X, reset=False)
9891020
X = self._convert_X(X)
9901021

991-
n_instances, n_dims, series_length = X.shape
1022+
n_instances, n_channels, series_length = X.shape
9921023

993-
if n_dims != self.n_dims_:
1024+
if n_channels != self.n_channels_:
9941025
raise ValueError(
9951026
"The number of channels in the train data does not match the number "
9961027
"of channels in the test data"

tsml/interval_based/tests/test_interval_forest.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,7 @@ def test_interval_forest_n_intervals(n_intervals, n_intervals_len):
9494
n_intervals=n_intervals,
9595
series_transformers=[None, FunctionTransformer(np.log1p)],
9696
save_transformed_data=True,
97+
random_state=0,
9798
)
9899
est.fit(X, y)
99100
est.predict_proba(X)
@@ -140,6 +141,7 @@ def test_interval_forest_attribute_subsample(features, output_len):
140141
interval_features=features,
141142
replace_nan=0,
142143
save_transformed_data=True,
144+
random_state=0,
143145
)
144146
est.fit(X, y)
145147
est.predict_proba(X)
@@ -166,6 +168,7 @@ def test_interval_forest_invalid_attribute_subsample():
166168
@pytest.mark.parametrize(
167169
"series_transformer",
168170
[
171+
FunctionTransformer(np.log1p),
169172
[None, FunctionTransformer(np.log1p)],
170173
[FunctionTransformer(np.log1p), ARCoefficientTransformer()],
171174
],
@@ -179,9 +182,13 @@ def test_interval_forest_series_transformer(series_transformer):
179182
n_intervals=2,
180183
series_transformers=series_transformer,
181184
save_transformed_data=True,
185+
random_state=0,
182186
)
183187
est.fit(X, y)
184188
est.predict_proba(X)
185189

186190
data = est.transformed_data_
187-
assert data[0].shape[1] == 12
191+
expected = (
192+
len(series_transformer) * 6 if isinstance(series_transformer, list) else 6
193+
)
194+
assert data[0].shape[1] == expected

0 commit comments

Comments
 (0)