Skip to content

Commit 716a17c

Browse files
Include estimator class in results files and better testing for get_estimator methods (#331)
* better testing for get_estimator bugs * xgboost * mypy * mypy * mypy
1 parent a4d7b77 commit 716a17c

File tree

14 files changed

+125
-53
lines changed

14 files changed

+125
-53
lines changed

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,8 +152,9 @@ per-file-ignores = [
152152
[tool.mypy]
153153
mypy_path = "tsml_eval/"
154154
ignore_missing_imports = true
155+
follow_imports = "silent"
155156
exclude = [
156-
"_wip",
157+
"_wip/",
157158
# Ignore the publications symlinks and its contents
158159
"tsml_eval/publications/2023",
159160
]

tsml_eval/experiments/_get_clusterer.py

Lines changed: 9 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
distance_based_clusterers = [
3131
"kmeans-euclidean",
3232
"kmeans-squared",
33-
"kmeans-dtw",
33+
["kmeans-dtw", "timeserieskmeans"],
3434
"kmeans-ddtw",
3535
"kmeans-wdtw",
3636
"kmeans-wddtw",
@@ -43,7 +43,7 @@
4343
"kmeans-shape_dtw",
4444
"kmedoids-euclidean",
4545
"kmedoids-squared",
46-
"kmedoids-dtw",
46+
["kmedoids-dtw", "timeserieskmedoids"],
4747
"kmedoids-ddtw",
4848
"kmedoids-wdtw",
4949
"kmedoids-wddtw",
@@ -56,7 +56,7 @@
5656
"kmedoids-shape_dtw",
5757
"clarans-euclidean",
5858
"clarans-squared",
59-
"clarans-dtw",
59+
["clarans-dtw", "timeseriesclarans"],
6060
"clarans-ddtw",
6161
"clarans-wdtw",
6262
"clarans-wddtw",
@@ -69,7 +69,7 @@
6969
"clarans-shape_dtw",
7070
"clara-euclidean",
7171
"clara-squared",
72-
"clara-dtw",
72+
["clara-dtw", "timeseriesclara"],
7373
"clara-ddtw",
7474
"clara-wdtw",
7575
"clara-wddtw",
@@ -114,7 +114,7 @@
114114
"kmeans-ssg-ba-msm",
115115
"kmeans-ssg-ba-adtw",
116116
"kmeans-ssg-ba-shape_dtw",
117-
"som-dtw",
117+
["som-dtw", "elasticsom"],
118118
"som-ddtw",
119119
"som-wdtw",
120120
"som-wddtw",
@@ -126,16 +126,8 @@
126126
"som-adtw",
127127
"som-shape_dtw",
128128
"som-soft_dtw",
129-
"ksc",
130-
"kshape",
131-
"timeserieskmeans",
132-
"timeserieskmedoids",
133-
"timeseriesclarans",
134-
"timeseriesclara",
135-
"elasticsom",
136-
"kspectralcentroid",
137-
"timeserieskshape",
138-
"timeserieskernelkmeans",
129+
["kspectralcentroid", "ksc"],
130+
["timeserieskshape", "kshape"],
139131
]
140132
feature_based_clusterers = [
141133
["catch22", "catch22clusterer"],
@@ -423,7 +415,7 @@ def _set_clusterer_distance_based(
423415
random_state=random_state,
424416
**kwargs,
425417
)
426-
elif "kshape" in c or "timeserieskshape" in c:
418+
elif "kshape" in c:
427419
return TimeSeriesKShape(
428420
init=init_algorithm,
429421
max_iter=50,
@@ -432,7 +424,7 @@ def _set_clusterer_distance_based(
432424
random_state=random_state,
433425
**kwargs,
434426
)
435-
elif c == "timeserieskernelkmeans" or c == "kernelkmeans":
427+
elif "timeserieskernelkmeans" in c:
436428
return TimeSeriesKernelKMeans(
437429
max_iter=50,
438430
n_init=10,

tsml_eval/experiments/experiments.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,9 @@ def run_classification_experiment(
230230
dataset_name,
231231
results_path,
232232
full_path=False,
233+
first_line_classifier_name=(
234+
f"{classifier_name} ({type(classifier).__name__})"
235+
),
233236
split="TRAIN",
234237
resample_id=resample_id,
235238
time_unit="MILLISECONDS",
@@ -280,6 +283,9 @@ def run_classification_experiment(
280283
dataset_name,
281284
results_path,
282285
full_path=False,
286+
first_line_classifier_name=(
287+
f"{classifier_name} ({type(classifier).__name__})"
288+
),
283289
split="TEST",
284290
resample_id=resample_id,
285291
time_unit="MILLISECONDS",
@@ -552,6 +558,7 @@ def run_regression_experiment(
552558
dataset_name,
553559
results_path,
554560
full_path=False,
561+
first_line_regressor_name=f"{regressor_name} ({type(regressor).__name__})",
555562
split="TRAIN",
556563
resample_id=resample_id,
557564
time_unit="MILLISECONDS",
@@ -597,6 +604,7 @@ def run_regression_experiment(
597604
dataset_name,
598605
results_path,
599606
full_path=False,
607+
first_line_regressor_name=f"{regressor_name} ({type(regressor).__name__})",
600608
split="TEST",
601609
resample_id=resample_id,
602610
time_unit="MILLISECONDS",
@@ -916,6 +924,7 @@ def run_clustering_experiment(
916924
dataset_name,
917925
results_path,
918926
full_path=False,
927+
first_line_clusterer_name=f"{clusterer_name} ({type(clusterer).__name__})",
919928
split="TRAIN",
920929
resample_id=resample_id,
921930
time_unit="MILLISECONDS",
@@ -960,6 +969,7 @@ def run_clustering_experiment(
960969
dataset_name,
961970
results_path,
962971
full_path=False,
972+
first_line_clusterer_name=f"{clusterer_name} ({type(clusterer).__name__})",
963973
split="TEST",
964974
resample_id=resample_id,
965975
time_unit="MILLISECONDS",
@@ -1197,6 +1207,7 @@ def run_forecasting_experiment(
11971207
dataset_name,
11981208
results_path,
11991209
full_path=False,
1210+
first_line_forecaster_name=f"{forecaster_name} ({type(forecaster).__name__})",
12001211
split="TEST",
12011212
random_seed=random_seed,
12021213
time_unit="MILLISECONDS",

tsml_eval/experiments/tests/test_classification.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,7 @@ def test_run_classification_experiment_invalid_estimator():
171171

172172
def test_get_classifier_by_name():
173173
"""Test get_classifier_by_name method."""
174-
classifier_lists = [
174+
classifier_name_lists = [
175175
_get_classifier.convolution_based_classifiers,
176176
_get_classifier.deep_learning_classifiers,
177177
_get_classifier.dictionary_based_classifiers,
@@ -184,12 +184,14 @@ def test_get_classifier_by_name():
184184
_get_classifier.vector_classifiers,
185185
]
186186

187+
# filled by _check_set_method
188+
classifier_list = []
187189
classifier_dict = {}
188190
all_classifier_names = []
189-
190-
for classifier_list in classifier_lists:
191+
for classifier_name_list in classifier_name_lists:
191192
_check_set_method(
192193
get_classifier_by_name,
194+
classifier_name_list,
193195
classifier_list,
194196
classifier_dict,
195197
all_classifier_names,

tsml_eval/experiments/tests/test_clustering.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,7 @@ def test_run_clustering_experiment_invalid_estimator():
173173

174174
def test_get_clusterer_by_name():
175175
"""Test get_clusterer_by_name method."""
176-
clusterer_lists = [
176+
clusterer_name_lists = [
177177
_get_clusterer.deep_learning_clusterers,
178178
_get_clusterer.distance_based_clusterers,
179179
_get_clusterer.feature_based_clusterers,
@@ -187,29 +187,30 @@ def test_get_clusterer_by_name():
187187
"base_estimator",
188188
]
189189

190+
clusterer_list = []
190191
clusterer_dict = {}
191192
all_clusterer_names = []
192-
193-
for clusterer_list in clusterer_lists:
194-
estimatorrs = _check_set_method(
193+
for clusterer_name_list in clusterer_name_lists:
194+
_check_set_method(
195195
get_clusterer_by_name,
196+
clusterer_name_list,
196197
clusterer_list,
197198
clusterer_dict,
198199
all_clusterer_names,
199-
return_estimator=True,
200200
)
201201

202-
# Check that clusterers with estimator parameters which are likely to be
203-
# a sub-estimator are not None so n_clusters can be set
204-
for clusterer in estimatorrs:
202+
# Check that clusterers with parameters which are likely to be
203+
# a sub-estimator are not None so n_clusters can be set
204+
for clusterers in clusterer_list:
205+
for c in clusterers:
205206
for param_name in clusterer_non_default_params:
206-
params = clusterer.get_params()
207+
params = c.get_params()
207208
if param_name in params:
208209
assert params[param_name] is not None, (
209210
f"Clusterers which have an estimator parameter i.e. "
210211
f"pipelines and deep learners must not have None as the "
211212
f"estimator. Found None for {param_name} in "
212-
f"{clusterer.__class__.__name__}"
213+
f"{c.__class__.__name__}"
213214
)
214215

215216
_check_set_method_results(
@@ -230,6 +231,7 @@ def test_aeon_clusterers_available():
230231
"ClustererPipeline",
231232
"SklearnClustererWrapper",
232233
# just missing
234+
"TimeSeriesKernelKMeans",
233235
]
234236

235237
est = [e for e, _ in all_estimators(type_filter="clusterer")]

tsml_eval/experiments/tests/test_data_transform.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,15 @@
99

1010
def test_get_data_transform_by_name():
1111
"""Test get_data_transform_by_name method."""
12-
transform_lists = [_get_data_transform.transformers]
12+
transform_name_lists = [_get_data_transform.transformers]
1313

14+
transform_list = []
1415
transform_dict = {}
1516
all_transform_names = []
16-
17-
for transform_list in transform_lists:
17+
for transform_name_list in transform_name_lists:
1818
_check_set_method(
1919
get_data_transform_by_name,
20+
transform_name_list,
2021
transform_list,
2122
transform_dict,
2223
all_transform_names,

tsml_eval/experiments/tests/test_regression.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,7 @@ def test_run_regression_experiment_invalid_estimator():
170170

171171
def test_get_regressor_by_name():
172172
"""Test get_regressor_by_name method."""
173-
regressor_lists = [
173+
regressor_name_lists = [
174174
_get_regressor.convolution_based_regressors,
175175
_get_regressor.deep_learning_regressors,
176176
_get_regressor.distance_based_regressors,
@@ -182,12 +182,13 @@ def test_get_regressor_by_name():
182182
_get_regressor.vector_regressors,
183183
]
184184

185+
regressor_list = []
185186
regressor_dict = {}
186187
all_regressor_names = []
187-
188-
for regressor_list in regressor_lists:
188+
for regressor_name_list in regressor_name_lists:
189189
_check_set_method(
190190
get_regressor_by_name,
191+
regressor_name_list,
191192
regressor_list,
192193
regressor_dict,
193194
all_regressor_names,

tsml_eval/publications/y2023/distance_based_clustering/tests/test_set_distance_clusterer.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,13 @@
1111

1212
def test_set_distance_clusterer():
1313
"""Test set_distance_clusterer method."""
14+
clusterer_list = []
1415
clusterer_dict = {}
1516
all_clusterer_names = []
16-
1717
_check_set_method(
1818
_set_distance_clusterer,
1919
distance_based_clusterers,
20+
clusterer_list,
2021
clusterer_dict,
2122
all_clusterer_names,
2223
)

tsml_eval/publications/y2023/rist_pipeline/tests/test_set_estimator.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,13 @@
1313

1414
def test_set_rist_classifier():
1515
"""Test set_rist_classifier method."""
16+
classifier_list = []
1617
classifier_dict = {}
1718
all_classifier_names = []
18-
1919
_check_set_method(
2020
_set_rist_classifier,
2121
rist_classifiers,
22+
classifier_list,
2223
classifier_dict,
2324
all_classifier_names,
2425
)
@@ -38,12 +39,13 @@ def test_set_rist_classifier_invalid():
3839

3940
def test_set_rist_regressor():
4041
"""Test set_rist_regressors method."""
42+
regressor_list = []
4143
regressor_dict = {}
4244
all_regressor_names = []
43-
4445
_check_set_method(
4546
_set_rist_regressor,
4647
rist_regressors,
48+
regressor_list,
4749
regressor_dict,
4850
all_regressor_names,
4951
)

tsml_eval/publications/y2023/tsc_bakeoff/tests/test_set_classifier.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,13 @@
1111

1212
def test_set_bakeoff_classifier():
1313
"""Test set_bakeoff_classifier method."""
14+
classifier_list = []
1415
classifier_dict = {}
1516
all_classifier_names = []
16-
1717
_check_set_method(
1818
_set_bakeoff_classifier,
1919
bakeoff_classifiers,
20+
classifier_list,
2021
classifier_dict,
2122
all_classifier_names,
2223
)

0 commit comments

Comments
 (0)