Skip to content

Commit ff28266

Browse files
authored
[ENH] Add Feature load_model to Deep Regressor Ensembles (#3130)
* fix: docs, fix: Correct misleading load_model docstring & resolve NameError in return type * feat: implement load_model functionality to InceptionTimeRegressor and LiteTimeRegressor (Closes #2770) * test: Add unit tests for load_model functionality (Issue #2770) * style(docs): Resolve D205 ruff error by adding required blank line to all modified load_model docstrings. * style(base): Final cleanup of base regressor docstring after merge. * resolve return type in BaseDeepClassifier and BaseDeepRegressor * revert return type of BaseDeepRegressor and BaseDeepClassifier to None * Automatic `pre-commit` fixes --------- Co-authored-by: rwtarpit <180088298+rwtarpit@users.noreply.github.com>
1 parent a54d7e2 commit ff28266

File tree

8 files changed

+224
-26
lines changed

8 files changed

+224
-26
lines changed

aeon/classification/deep_learning/_inception_time.py

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
"""InceptionTime and Inception classifiers."""
22

3+
from __future__ import annotations
4+
35
__maintainer__ = ["hadifawaz1999"]
46
__all__ = ["InceptionTimeClassifier"]
57

@@ -251,7 +253,7 @@ def __init__(
251253

252254
super().__init__()
253255

254-
def _fit(self, X, y):
256+
def _fit(self, X: np.ndarray, y: np.ndarray) -> InceptionTimeClassifier:
255257
"""Fit the ensemble of IndividualInceptionClassifier models.
256258
257259
Parameters
@@ -307,7 +309,7 @@ def _fit(self, X, y):
307309

308310
return self
309311

310-
def _predict(self, X) -> np.ndarray:
312+
def _predict(self, X: np.ndarray) -> np.ndarray:
311313
"""Predict the labels of the test set using InceptionTime.
312314
313315
Parameters
@@ -328,7 +330,7 @@ def _predict(self, X) -> np.ndarray:
328330
]
329331
)
330332

331-
def _predict_proba(self, X) -> np.ndarray:
333+
def _predict_proba(self, X: np.ndarray) -> np.ndarray:
332334
"""Predict the proba of labels of the test set using InceptionTime.
333335
334336
Parameters
@@ -351,24 +353,30 @@ def _predict_proba(self, X) -> np.ndarray:
351353
return probs
352354

353355
@classmethod
354-
def load_model(self, model_path, classes):
355-
"""Load pre-trained classifiers instead of fitting.
356+
def load_model(
357+
self, model_path: list[str], classes: np.ndarray
358+
) -> InceptionTimeClassifier:
359+
"""Load pre-trained keras models from disk instead of fitting.
356360
361+
Pretrained models should be saved using "save_best_model"
362+
or "save_last_model" boolean parameter.
357363
When calling this function, all functionalities can be used
358-
such as predict, predict_proba, etc. with the loaded models.
364+
such as predict, predict_proba etc. with the loaded model.
359365
360366
Parameters
361367
----------
362368
model_path : list of str (list of paths including the model names and extension)
363-
The directory where the models will be saved including the model
364-
names with a ".keras" extension.
365-
classes : np.ndarray
369+
The complete path (including file name and '.keras' extension)
370+
from which the pre-trained model's weights and configuration
371+
are loaded.
372+
classes : np.ndarray
366373
The set of unique classes the pre-trained loaded model is trained
367374
to predict during the classification task.
375+
Example: model_path="path/to/file/best_model.keras"
368376
369377
Returns
370378
-------
371-
None
379+
InceptionTimeClassifier
372380
"""
373381
assert (
374382
type(model_path) is list

aeon/classification/deep_learning/_lite_time.py

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
"""LITETime and LITE classifiers."""
22

3+
from __future__ import annotations
4+
35
__maintainer__ = ["hadifawaz1999"]
46
__all__ = ["LITETimeClassifier"]
57

@@ -193,7 +195,7 @@ def __init__(
193195

194196
super().__init__()
195197

196-
def _fit(self, X, y):
198+
def _fit(self, X: np.ndarray, y: np.ndarray) -> LITETimeClassifier:
197199
"""Fit the ensemble of IndividualLITEClassifier models.
198200
199201
Parameters
@@ -239,7 +241,7 @@ def _fit(self, X, y):
239241

240242
return self
241243

242-
def _predict(self, X) -> np.ndarray:
244+
def _predict(self, X: np.ndarray) -> np.ndarray:
243245
"""Predict the labels of the test set using LITETime.
244246
245247
Parameters
@@ -260,7 +262,7 @@ def _predict(self, X) -> np.ndarray:
260262
]
261263
)
262264

263-
def _predict_proba(self, X) -> np.ndarray:
265+
def _predict_proba(self, X: np.ndarray) -> np.ndarray:
264266
"""Predict the proba of labels of the test set using LITETime.
265267
266268
Parameters
@@ -283,24 +285,30 @@ def _predict_proba(self, X) -> np.ndarray:
283285
return probs
284286

285287
@classmethod
286-
def load_model(self, model_path, classes):
287-
"""Load pre-trained classifiers instead of fitting.
288+
def load_model(
289+
self, model_path: list[str], classes: np.ndarray
290+
) -> LITETimeClassifier:
291+
"""Load pre-trained keras models from disk instead of fitting.
288292
293+
Pretrained models should be saved using "save_best_model"
294+
or "save_last_model" boolean parameter.
289295
When calling this function, all functionalities can be used
290-
such as predict, predict_proba, etc. with the loaded models.
296+
such as predict, predict_proba etc. with the loaded model.
291297
292298
Parameters
293299
----------
294300
model_path : list of str (list of paths including the model names and extension)
295-
The director where the models will be saved including the model
296-
names with a ".keras" extension.
297-
classes : np.ndarray
301+
The complete path (including file name and '.keras' extension)
302+
from which the pre-trained model's weights and configuration
303+
are loaded.
304+
classes : np.ndarray
298305
The set of unique classes the pre-trained loaded model is trained
299306
to predict during the classification task.
307+
Example: model_path="path/to/file/best_model.keras"
300308
301309
Returns
302310
-------
303-
None
311+
LITETimeClassifier
304312
"""
305313
assert (
306314
type(model_path) is list

aeon/classification/deep_learning/base.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ class name: BaseDeepClassifier
2222
because we can generalise tags, _predict and _predict_proba
2323
"""
2424

25+
from __future__ import annotations
26+
2527
__maintainer__ = ["hadifawaz1999"]
2628
__all__ = ["BaseDeepClassifier"]
2729

@@ -178,17 +180,20 @@ def save_last_model_to_file(self, file_path="./"):
178180
"""
179181
self.model_.save(file_path + self.last_file_name + ".keras")
180182

181-
def load_model(self, model_path, classes):
183+
def load_model(self, model_path: str, classes: np.ndarray) -> None:
182184
"""Load a pre-trained keras model instead of fitting.
183185
186+
Pretrained model should be saved using "save_last_model" or
187+
"save_best_model" boolean parameter.
184188
When calling this function, all functionalities can be used
185189
such as predict, predict_proba etc. with the loaded model.
186190
187191
Parameters
188192
----------
189193
model_path : str (path including model name and extension)
190-
The directory where the model will be saved including the model
191-
name with a ".keras" extension.
194+
The complete path (including file name and '.keras' extension)
195+
from which the pre-trained model's weights and configuration
196+
are loaded.
192197
Example: model_path="path/to/file/best_model.keras"
193198
classes : np.ndarray
194199
The set of unique classes the pre-trained loaded model is trained

aeon/regression/deep_learning/_inception_time.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -343,6 +343,43 @@ def _predict(self, X: np.ndarray) -> np.ndarray:
343343

344344
return ypreds
345345

346+
@classmethod
347+
def load_model(self, model_path: list[str]) -> InceptionTimeRegressor:
348+
"""Load pre-trained keras models from disk instead of fitting.
349+
350+
Pretrained models should be saved using "save_best_model"
351+
or "save_last_model" boolean parameter.
352+
When calling this function, all functionalities can be used
353+
such as predict, etc. with the loaded model.
354+
355+
Parameters
356+
----------
357+
model_path : list of str (list of paths including the model names and extension)
358+
The complete path (including file name and '.keras' extension)
359+
from which the pre-trained model's weights and configuration
360+
are loaded.
361+
Example: model_path="path/to/file/best_model.keras"
362+
363+
Returns
364+
-------
365+
InceptionTimeRegressor
366+
"""
367+
assert (
368+
type(model_path) is list
369+
), "model_path should be a list of paths to the models"
370+
371+
regressor = self()
372+
regressor.regressors_ = []
373+
374+
for i in range(len(model_path)):
375+
reg = IndividualInceptionRegressor()
376+
reg.load_model(model_path[i])
377+
regressor.regressors_.append(reg)
378+
379+
regressor.n_regressors = len(regressor.regressors_)
380+
regressor.is_fitted = True
381+
return regressor
382+
346383
@classmethod
347384
def _get_test_params(
348385
cls, parameter_set: str = "default"

aeon/regression/deep_learning/_lite_time.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -269,6 +269,43 @@ def _predict(self, X: np.ndarray) -> np.ndarray:
269269

270270
return vals
271271

272+
@classmethod
273+
def load_model(self, model_path: list[str]) -> LITETimeRegressor:
274+
"""Load pre-trained keras models from disk instead of fitting.
275+
276+
Pretrained models should be saved using "save_best_model"
277+
or "save_last_model" boolean parameter.
278+
When calling this function, all functionalities can be used
279+
such as predict, etc. with the loaded model.
280+
281+
Parameters
282+
----------
283+
model_path : list of str (list of paths including the model names and extension)
284+
The complete path (including file name and '.keras' extension)
285+
from which the pre-trained model's weights and configuration
286+
are loaded.
287+
Example: model_path="path/to/file/best_model.keras"
288+
289+
Returns
290+
-------
291+
LITETimeRegressor
292+
"""
293+
assert (
294+
type(model_path) is list
295+
), "model_path should be a list of paths to the models"
296+
297+
regressor = self()
298+
regressor.regressors_ = []
299+
300+
for i in range(len(model_path)):
301+
reg = IndividualLITERegressor()
302+
reg.load_model(model_path[i])
303+
regressor.regressors_.append(reg)
304+
305+
regressor.n_regressors = len(regressor.regressors_)
306+
regressor.is_fitted = True
307+
return regressor
308+
272309
@classmethod
273310
def _get_test_params(cls, parameter_set: str = "default") -> dict | list[dict]:
274311
"""Return testing parameter settings for the estimator.

aeon/regression/deep_learning/base.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -120,14 +120,17 @@ def save_last_model_to_file(self, file_path: str = "./") -> None:
120120
def load_model(self, model_path: str) -> None:
121121
"""Load a pre-trained keras model instead of fitting.
122122
123+
Pretrained model should be saved using "save_last_model"
124+
or "save_best_model" boolean parameter.
123125
When calling this function, all functionalities can be used
124-
such as predict etc. with the loaded model.
126+
such as predict, etc. with the loaded model.
125127
126128
Parameters
127129
----------
128130
model_path : str (path including model name and extension)
129-
The directory where the model will be saved including the model
130-
name with a ".keras" extension.
131+
The complete path (including file name and '.keras' extension)
132+
from which the pre-trained model's weights and configuration
133+
are loaded.
131134
Example: model_path="path/to/file/best_model.keras"
132135
133136
Returns
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
"""Tests for save/load functionality of InceptionTimeRegressor."""
2+
3+
import glob
4+
import os
5+
import tempfile
6+
7+
import numpy as np
8+
import pytest
9+
10+
from aeon.regression.deep_learning import InceptionTimeRegressor
11+
from aeon.testing.data_generation import make_example_3d_numpy
12+
from aeon.utils.validation._dependencies import _check_soft_dependencies
13+
14+
15+
@pytest.mark.skipif(
16+
not _check_soft_dependencies("tensorflow", severity="none"),
17+
reason="skip test if required soft dependency not available",
18+
)
19+
def test_save_load_inceptiontime():
20+
"""Test saving and loading for InceptionTimeRegressor."""
21+
with tempfile.TemporaryDirectory() as temp:
22+
temp_dir = os.path.join(temp, "")
23+
24+
X, y = make_example_3d_numpy(
25+
n_cases=10,
26+
n_channels=1,
27+
n_timepoints=12,
28+
return_y=True,
29+
regression_target=True,
30+
)
31+
32+
model = InceptionTimeRegressor(
33+
n_epochs=1, random_state=42, save_best_model=True, file_path=temp_dir
34+
)
35+
model.fit(X, y)
36+
37+
y_pred_orig = model.predict(X)
38+
39+
model_file = glob.glob(os.path.join(temp_dir, f"{model.best_file_name}*.keras"))
40+
41+
loaded_model = InceptionTimeRegressor.load_model(model_path=model_file)
42+
43+
assert isinstance(loaded_model, InceptionTimeRegressor)
44+
45+
preds = loaded_model.predict(X)
46+
assert isinstance(preds, np.ndarray)
47+
48+
assert len(preds) == len(y)
49+
np.testing.assert_array_equal(preds, y_pred_orig)
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
"""Tests for save/load functionality of LiteTimeRegressor."""
2+
3+
import glob
4+
import os
5+
import tempfile
6+
7+
import numpy as np
8+
import pytest
9+
10+
from aeon.regression.deep_learning import LITETimeRegressor
11+
from aeon.testing.data_generation import make_example_3d_numpy
12+
from aeon.utils.validation._dependencies import _check_soft_dependencies
13+
14+
15+
@pytest.mark.skipif(
16+
not _check_soft_dependencies("tensorflow", severity="none"),
17+
reason="skip test if required soft dependency not available",
18+
)
19+
def test_save_load_litetim():
20+
"""Test saving and loading for LiteTimeRegressor."""
21+
with tempfile.TemporaryDirectory() as temp:
22+
temp_dir = os.path.join(temp, "")
23+
24+
X, y = make_example_3d_numpy(
25+
n_cases=10,
26+
n_channels=1,
27+
n_timepoints=12,
28+
return_y=True,
29+
regression_target=True,
30+
)
31+
32+
model = LITETimeRegressor(
33+
n_epochs=1, random_state=42, save_best_model=True, file_path=temp_dir
34+
)
35+
model.fit(X, y)
36+
37+
y_pred_orig = model.predict(X)
38+
39+
model_files = glob.glob(
40+
os.path.join(temp_dir, f"{model.best_file_name}*.keras")
41+
)
42+
43+
loaded_model = LITETimeRegressor.load_model(model_path=model_files)
44+
45+
assert isinstance(loaded_model, LITETimeRegressor)
46+
47+
preds = loaded_model.predict(X)
48+
assert isinstance(preds, np.ndarray)
49+
50+
assert len(preds) == len(y)
51+
np.testing.assert_array_equal(preds, y_pred_orig)

0 commit comments

Comments
 (0)