1212import numpy as np
1313from joblib import Parallel
1414from sklearn .base import BaseEstimator , is_classifier , is_regressor
15- from sklearn .tree import (
16- BaseDecisionTree ,
17- DecisionTreeClassifier ,
18- DecisionTreeRegressor ,
19- ExtraTreeClassifier ,
20- )
15+ from sklearn .tree import BaseDecisionTree , DecisionTreeClassifier , DecisionTreeRegressor
2116from sklearn .utils import check_random_state
2217from sklearn .utils .fixes import delayed
2318from sklearn .utils .multiclass import check_classification_targets
@@ -39,12 +34,10 @@ class BaseIntervalForest(BaseTimeSeriesEstimator, metaclass=ABCMeta):
3934 Allows the implementation of classifiers and regressors along the lines of [1][2][3]
4035 which extract intervals and create an ensemble from the subsequent features.
4136
42- #skipping predict todo
43-
4437 Parameters
4538 ----------
4639 base_estimator : BaseEstimator or None, default=None
47- scikit-learn BaseEstimator used to build the interval ensemble. If None, uses a
40+ scikit-learn BaseEstimator used to build the interval ensemble. If None, use a
4841 simple decision tree.
4942 n_estimators : int, default=200
5043 Number of estimators to build for the ensemble.
@@ -65,8 +58,8 @@ class BaseIntervalForest(BaseTimeSeriesEstimator, metaclass=ABCMeta):
6558 input will return a function of the series length (may differ per
6659 series_transformers output) to extract that number of intervals.
6760 Valid str inputs are:
68- - "sqrt" : square root of the series length.
69- - "sqrt-div" : sqrt of series length divided by the number
61+ - "sqrt": square root of the series length.
62+ - "sqrt-div": sqrt of series length divided by the number
7063 of series_transformers.
7164
7265 A list or tuple of ints and/or strs will extract the number of intervals using
@@ -78,7 +71,9 @@ class BaseIntervalForest(BaseTimeSeriesEstimator, metaclass=ABCMeta):
7871 another list or tuple must be the same length as the number of
7972 series_transformers.
8073
81- %todo random vs supervised
74+ While random interval extraction will extract n_intervals intervals in total
75+ (removing duplicates), supervised interval extraction will run the supervised
76+ extraction process n_intervals times, returning more intervals than specified.
8277 min_interval_length : int, float, list, or tuple, default=3
8378 Minimum length of intervals to extract from series. float inputs take a
8479 proportion of the series length to use as the minimum interval length.
@@ -97,10 +92,10 @@ class BaseIntervalForest(BaseTimeSeriesEstimator, metaclass=ABCMeta):
9792 Ignored for supervised interval_selection_method inputs.
9893 interval_features : TransformerMixin, callable, list, tuple, or None, default=None
9994 The features to extract from the intervals using transformers or callable
100- functions. If None, uses the mean, standard deviation, and slope of the series.
95+ functions. If None, use the mean, standard deviation, and slope of the series.
10196
102- Both transformers and functions should be able to take a 2d np.ndarray input.
103- Functions should output a 1d array (the feature for each series) and
97+ Both transformers and functions should be able to take a 2D np.ndarray input.
98+ Functions should output a 1D array (the feature for each series), and
10499 transformers should output a 2D array where rows are the features for each
105100 series. A list or tuple of transformers and/or functions will extract all
106101 features and concatenate the output.
@@ -109,14 +104,29 @@ class BaseIntervalForest(BaseTimeSeriesEstimator, metaclass=ABCMeta):
109104 nested list or tuple. Any list or tuple input containing another list or tuple
110105 must be the same length as the number of series_transformers.
111106 series_transformers : TransformerMixin, list, tuple, or None, default=None
112-
113- att_subsample_size : int or None, default=None
114- Number of catch22 or summary statistic attributes to subsample per tree.
115- replace_nan :
116-
107+ The transformers to apply to the series before extracting intervals. If None,
108+ use the series as is.
109+
110+ Both transformers and functions should be able to take a 3D np.ndarray input.
111+ A list or tuple of transformers and/or functions will extract intervals from
112+ all transformations and concatenate the output. Including None in the list or tuple
113+ will use the series as is for interval extraction.
114+ att_subsample_size : int, float, list, tuple or None, default=None
115+ The number of attributes to subsample for each estimator. If None, use all attributes.
116+
117+ If int, use that number of attributes for all estimators. If float, use that
118+ proportion of attributes for all estimators.
119+
120+ Different subsample sizes for each series_transformers series can be specified
121+ using a list or tuple. Any list or tuple input must be the same length as the
122+ number of series_transformers.
123+ replace_nan : "nan", int, float or None, default=None
124+ The value to replace NaNs and infinite values with before fitting the base
125+ estimator. int or float input will replace with the specified value, while
126+ "nan" will replace infinite values with NaNs. If None, do not replace NaNs.
117127 time_limit_in_minutes : int, default=0
118128 Time contract to limit build time in minutes, overriding n_estimators.
119- Default of 0 means n_estimators is used.
129+ Default of 0 means n_estimators are used.
120130 contract_max_n_estimators : int, default=500
121131 Max number of estimators when time_limit_in_minutes is set.
122132 save_transformed_data : bool, default=False
@@ -139,7 +149,7 @@ class BaseIntervalForest(BaseTimeSeriesEstimator, metaclass=ABCMeta):
139149 ----------
140150 n_instances_ : int
141151 The number of train cases.
142- n_dims_ : int
152+ n_channels_ : int
143153 The number of channels per case.
144154 series_length_ : int
145155 The length of each series.
@@ -149,9 +159,6 @@ class BaseIntervalForest(BaseTimeSeriesEstimator, metaclass=ABCMeta):
149159 The collections of estimators trained in fit.
150160 intervals_ : list of shape (n_estimators) of ndarray with shape (total_intervals,2)
151161 Stores indexes of each interval's start and end points for all classifiers.
152- atts_ : list of shape (n_estimators) of array with shape (att_subsample_size)
153- Attribute indexes of the subsampled catch22 or summary statistic for all
154- classifiers.
155162 transformed_data_ : list of shape (n_estimators) of ndarray with shape
156163 (n_instances,total_intervals * att_subsample_size)
157164 The transformed dataset for all classifiers. Only saved when
@@ -221,7 +228,7 @@ def fit(self, X, y):
221228
222229 rng = check_random_state (self .random_state )
223230
224- self .n_instances_ , self .n_dims_ , self .series_length_ = X .shape
231+ self .n_instances_ , self .n_channels_ , self .series_length_ = X .shape
225232 if is_classifier (self ):
226233 check_classification_targets (y )
227234
@@ -236,16 +243,7 @@ def fit(self, X, y):
236243
237244 self ._base_estimator = self .base_estimator
238245 if self .base_estimator is None :
239- from tsml .interval_based import RSTSFClassifier
240-
241- # default base_estimators for classification and regression
242- if isinstance (self , RSTSFClassifier ):
243- self ._base_estimator = ExtraTreeClassifier (
244- criterion = "entropy" ,
245- class_weight = "balanced" ,
246- max_features = "sqrt" ,
247- )
248- elif is_classifier (self ):
246+ if is_classifier (self ):
249247 self ._base_estimator = DecisionTreeClassifier (criterion = "entropy" )
250248 elif is_regressor (self ):
251249 self ._base_estimator = DecisionTreeRegressor (criterion = "absolute_error" )
@@ -285,10 +283,14 @@ def fit(self, X, y):
285283 Xt .append (t .fit_transform (X , y ))
286284 self ._series_transformers .append (t )
287285 else :
288- raise ValueError () # todo error for invalid self.series_transformers
286+ raise ValueError (
287+ f"Invalid series_transformers list input. Found { transformer } "
288+ )
289289 # other inputs are invalid
290290 else :
291- raise ValueError () # todo error for invalid self.series_transformers
291+ raise ValueError (
292+ f"Invalid series_transformers input. Found { self .series_transformers } "
293+ )
292294
293295 # if only a single n_intervals value is passed it must be an int or str
294296 if isinstance (self .n_intervals , (int , str )):
@@ -533,14 +535,18 @@ def fit(self, X, y):
533535 # att_subsample_size must be at least one if it is an int
534536 if isinstance (self .att_subsample_size , int ):
535537 if self .att_subsample_size < 1 :
536- raise ValueError () # todo error for invalid invalid self.att_subsample_size
538+ raise ValueError (
539+ "att_subsample_size must be at least one if it is an int."
540+ )
537541
538542 self ._att_subsample_size = [self .att_subsample_size ] * len (Xt )
539543 # att_subsample_size must be greater than zero and at most one if it is a float
540544 # (proportion of total attributes to subsample)
541545 elif isinstance (self .att_subsample_size , float ):
542- if self .att_subsample_size > 1 :
543- raise ValueError () # todo error for invalid invalid self.att_subsample_size
546+ if self .att_subsample_size > 1 or self .att_subsample_size <= 0 :
547+ raise ValueError (
548+ "att_subsample_size must be between 0 and 1 if it is a float."
549+ )
544550
545551 self ._att_subsample_size = [self .att_subsample_size ] * len (Xt )
546552 # default is no attribute subsampling with None
@@ -552,27 +558,42 @@ def fit(self, X, y):
552558 # performed
553559 elif isinstance (self .att_subsample_size , (list , tuple )):
554560 if len (self .att_subsample_size ) != len (Xt ):
555- raise ValueError () # todo error for invalid self.att_subsample_size
561+ raise ValueError (
562+ "att_subsample_size as a list or tuple must be the same length as "
563+ "series_transformers."
564+ )
556565
557566 self ._att_subsample_size = []
558567 for ssize in self .att_subsample_size :
559568 if isinstance (ssize , int ):
560569 if ssize < 1 :
561- raise ValueError () # todo error for invalid invalid self.att_subsample_size
570+ raise ValueError (
571+ "att_subsample_size in list must be at least one if it is "
572+ "an int."
573+ )
562574
563575 self ._att_subsample_size .append (ssize )
564576 elif isinstance (ssize , float ):
565577 if ssize > 1 :
566- raise ValueError () # todo error for invalid invalid self.att_subsample_size
578+ raise ValueError (
579+ "att_subsample_size in list must be between 0 and 1 if it "
580+ "is a "
581+ "float."
582+ )
567583
568584 self ._att_subsample_size .append (ssize )
569585 elif ssize is None :
570586 self ._att_subsample_size .append (ssize )
571587 else :
572- raise ValueError () # todo error for invalid self.att_subsample_size
588+ raise ValueError (
589+ "Invalid interval_features input in list. Found "
590+ f"{ self .att_subsample_size } "
591+ )
573592 # other inputs are invalid
574593 else :
575- raise ValueError () # todo error for invalid invalid self.att_subsample_size
594+ raise ValueError (
595+ f"Invalid interval_features input. Found { self .att_subsample_size } "
596+ )
576597
577598 # if we are subsampling attributes for a series_transformer and it uses a
578599 # BaseTransformer, we must ensure it has the required parameters and
@@ -596,7 +617,11 @@ def fit(self, X, y):
596617 break
597618
598619 if not has_params :
599- raise ValueError () # todo error for invalid invalid self.att_subsample_size
620+ raise ValueError (
621+ "All transformers in interval_features must have a "
622+ "parameter named in transformer_feature_selection to "
623+ "be used in attribute subsampling."
624+ )
600625
601626 # the transformer must have an attribute with one of the
602627 # names listed in transformer_feature_names as a list or tuple
@@ -611,7 +636,12 @@ def fit(self, X, y):
611636 break
612637
613638 if not has_feature_names :
614- raise ValueError () # todo error for invalid invalid self.att_subsample_size
639+ raise ValueError (
640+ "All transformers in interval_features must have an "
641+ "attribute or propertynamed in "
642+ "transformer_feature_names to be used in attribute "
643+ "subsampling."
644+ )
615645
616646 # verify the interval_selection_method is a valid string
617647 if isinstance (self .interval_selection_method , str ):
@@ -652,13 +682,14 @@ def fit(self, X, y):
652682 and not isinstance (self .replace_nan , (int , float ))
653683 and self .replace_nan is not None
654684 ):
655- raise ValueError () # todo error for invalid self.replace_nan
685+ raise ValueError (f"Invalid replace_nan input. Found { self .replace_nan } " )
656686
657687 self ._n_jobs = check_n_jobs (self .n_jobs )
658688
659- self ._efficient_predictions = True # todo
689+ # flags for testing. not used in the actual algorithm
690+ self ._efficient_predictions = True
660691 if not hasattr (self , "_test_flag" ):
661- self ._test_flag = False # todo
692+ self ._test_flag = False
662693
663694 if self .time_limit_in_minutes is not None and self .time_limit_in_minutes > 0 :
664695 time_limit = self .time_limit_in_minutes * 60
@@ -869,7 +900,9 @@ def _fit_estimator(self, Xt, y, seed):
869900 features .append (all_function_features [atts [count + i ] - length ])
870901 else :
871902 warnings .warn (
872- f"Attribute subsample size { att_subsample_size } is larger than or equal to the number of attributes { num_features } for series { self ._series_transformers [r ]} "
903+ f"Attribute subsample size { att_subsample_size } is larger than "
904+ f"or equal to the number of attributes { num_features } for "
905+ f"series { self ._series_transformers [r ]} "
873906 )
874907 for feature in self ._interval_features [r ]:
875908 if is_transformer (feature ):
@@ -910,8 +943,6 @@ def _fit_estimator(self, Xt, y, seed):
910943 randomised_split_point = True ,
911944 random_state = seed ,
912945 )
913- else :
914- raise ValueError () # todo error for invalid self.interval_selection_method, should not get here
915946
916947 # fit the interval selector, transform the current series using it and save
917948 # the transformer
@@ -988,9 +1019,9 @@ def _predict_setup(self, X):
9881019 X = self ._validate_data (X = X , reset = False )
9891020 X = self ._convert_X (X )
9901021
991- n_instances , n_dims , series_length = X .shape
1022+ n_instances , n_channels , series_length = X .shape
9921023
993- if n_dims != self .n_dims_ :
1024+ if n_channels != self .n_channels_ :
9941025 raise ValueError (
9951026 "The number of channels in the train data does not match the number "
9961027 "of channels in the test data"
0 commit comments