Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
f2002fc
butchering
MatthewMiddlehurst Jun 9, 2025
54ee287
merge
MatthewMiddlehurst Jun 30, 2025
df062f4
fixes and mixins
MatthewMiddlehurst Jun 30, 2025
d4cf31b
naive and testing
MatthewMiddlehurst Jun 30, 2025
f1a3da1
test errors
MatthewMiddlehurst Jun 30, 2025
de68f61
tests
MatthewMiddlehurst Jul 3, 2025
61c6d5f
tests
MatthewMiddlehurst Jul 3, 2025
d4340e0
notebook
MatthewMiddlehurst Jul 3, 2025
4a1b0c7
merge
MatthewMiddlehurst Jul 12, 2025
350b2f7
Merge remote-tracking branch 'origin/main' into mm/ets
MatthewMiddlehurst Aug 6, 2025
a0de418
extra tests and notebook revert
MatthewMiddlehurst Aug 6, 2025
b1c0084
series to series
MatthewMiddlehurst Aug 6, 2025
ac81c4b
Merge remote-tracking branch 'origin/main' into mm/ets
MatthewMiddlehurst Aug 6, 2025
a2680ec
ets fit empty
MatthewMiddlehurst Aug 6, 2025
6659e0f
Merge remote-tracking branch 'origin/main' into mm/ets
MatthewMiddlehurst Aug 6, 2025
f566f0d
fixes
MatthewMiddlehurst Aug 6, 2025
8d2908d
Merge remote-tracking branch 'origin/main' into mm/ets
MatthewMiddlehurst Aug 10, 2025
d0d7a27
update
MatthewMiddlehurst Aug 10, 2025
5ca532b
this notebook thing is the big downside of this change
MatthewMiddlehurst Aug 10, 2025
cf13eb5
Merge remote-tracking branch 'origin/main' into mm/ets
MatthewMiddlehurst Aug 11, 2025
c3646c1
ets
MatthewMiddlehurst Aug 11, 2025
8d56ec7
Merge branch 'main' into mm/ets
TonyBagnall Aug 12, 2025
ef9cd6e
Merge remote-tracking branch 'origin/main' into mm/ets
MatthewMiddlehurst Oct 26, 2025
9e25bbb
Merge remote-tracking branch 'origin/mm/ets' into mm/ets
MatthewMiddlehurst Oct 26, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 14 additions & 57 deletions aeon/forecasting/stats/_ets.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,10 @@
from aeon.forecasting.utils._extract_paras import _extract_ets_params
from aeon.forecasting.utils._loss_functions import (
_ets_fit,
_ets_initialise,
_ets_predict_value,
)
from aeon.forecasting.utils._nelder_mead import nelder_mead

ADDITIVE = "additive"
MULTIPLICATIVE = "multiplicative"


class ETS(BaseForecaster, IterativeForecastingMixin):
"""Exponential Smoothing (ETS) forecaster.
Expand All @@ -43,14 +39,8 @@ class ETS(BaseForecaster, IterativeForecastingMixin):
Type of seasonal component: None (0), `additive' (1) or 'multiplicative' (2)
seasonal_period : int, default=1
Number of time points in a seasonal cycle.
alpha : float, default=0.1
Level smoothing parameter.
beta : float, default=0.01
Trend smoothing parameter.
gamma : float, default=0.01
Seasonal smoothing parameter.
phi : float, default=0.99
Trend damping parameter (used only for damped trend models).
iterations : int, default=200
Number of iterations for the Nelder-Mead optimisation algorithm used to fit.

Attributes
----------
Expand Down Expand Up @@ -96,6 +86,8 @@ class ETS(BaseForecaster, IterativeForecastingMixin):

_tags = {
"capability:horizon": False,
"fit_is_empty": True,
"predict_updates_state": True,
}

def __init__(
Expand Down Expand Up @@ -131,7 +123,7 @@ def __init__(
self.forecast_ = 0
super().__init__(horizon=1, axis=1)

def _fit(self, y, exog=None):
def _predict(self, y, exog=None):
"""Fit Exponential Smoothing forecaster to series y.

Fit a forecaster to predict self.horizon steps ahead using y.
Expand All @@ -145,8 +137,8 @@ def _fit(self, y, exog=None):

Returns
-------
self
Fitted ETS.
float
single prediction self.horizon steps ahead of y.
"""
_validate_parameter(self.error_type, False)
_validate_parameter(self.seasonality_type, True)
Expand All @@ -156,9 +148,9 @@ def _fit(self, y, exog=None):
def _get_int(x):
if x is None:
return 0
if x == ADDITIVE:
if x == "additive":
return 1
if x == MULTIPLICATIVE:
if x == "multiplicative":
return 2
return x

Expand Down Expand Up @@ -200,7 +192,8 @@ def _get_int(x):
self.liklihood_,
self.k_,
) = _ets_fit(self.parameters_, data, self._model)
self.forecast_ = _numba_predict(

return _numba_predict(
self._trend_type,
self._seasonality_type,
self.level_,
Expand All @@ -212,50 +205,14 @@ def _get_int(x):
self._seasonal_period,
)

return self

def _predict(self, y, exog=None):
"""
Predict the next horizon steps ahead.

Parameters
----------
y : np.ndarray, default = None
A time series to predict the next horizon value for. If None,
predict the next horizon value after series seen in fit.
exog : np.ndarray, default =None
Optional exogenous time series data assumed to be aligned with y

Returns
-------
float
single prediction self.horizon steps ahead of y.
"""
return self.forecast_

def _initialise(self, data):
"""
Initialize level, trend, and seasonality values for the ETS model.

Parameters
----------
data : array-like
The time series data
(should contain at least two full seasons if seasonality is specified)
"""
self.level_, self.trend_, self.seasonality_ = _ets_initialise(
self._trend_type, self._seasonality_type, self._seasonal_period, data
)

def iterative_forecast(self, y, prediction_horizon):
"""Forecast with ETS specific iterative method.

Overrides the base class iterative_forecast to avoid refitting on each step.
This simply rolls the ETS model forward
"""
self.fit(y)
preds = np.zeros(prediction_horizon)
preds[0] = self.forecast_
preds[0] = self.predict(y)
for i in range(1, prediction_horizon):
preds[i] = _numba_predict(
self._trend_type,
Expand Down Expand Up @@ -301,10 +258,10 @@ def _numba_predict(


def _validate_parameter(var, can_be_none):
valid_str = (ADDITIVE, MULTIPLICATIVE)
valid_str = ("additive", "multiplicative")
valid_int = (1, 2)
if can_be_none:
valid_str = (None, ADDITIVE, MULTIPLICATIVE)
valid_str = (None, "additive", "multiplicative")
valid_int = (0, 1, 2)
valid = True
if isinstance(var, str) or var is None:
Expand Down
13 changes: 4 additions & 9 deletions aeon/forecasting/stats/tests/test_ets.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def test_ets_raises_on_horizon_greater_than_one():
forecaster.horizon = 2
data = np.array([3, 10, 12, 13, 12, 10, 12, 3, 10, 12, 13, 12, 10, 12])
with pytest.raises(ValueError, match="Horizon is set >1, but"):
forecaster.fit(data)
forecaster.predict(data)


def test_ets_iterative_forecast():
Expand All @@ -97,11 +97,6 @@ def test_ets_iterative_forecast():
assert preds.shape == (h,), f"Expected output shape {(h,)}, got {preds.shape}"
assert np.all(np.isfinite(preds)), "All forecast values should be finite"

# Optional: check that the first prediction equals forecast_ from .fit()
forecaster.fit(y)
assert np.isclose(
preds[0], forecaster.forecast_, atol=1e-6
), "First forecast should match forecast_"
forecaster = ETS(trend_type=None)
forecaster._fit(y)
assert forecaster._trend_type == 0
# Optional: check that the first prediction equals .predict()
p = forecaster.predict(y)
assert np.isclose(preds[0], p, atol=1e-6), "First forecast should match predict"
1 change: 0 additions & 1 deletion aeon/forecasting/utils/_loss_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,6 @@ def _ets_update_states(
@njit(fastmath=True, cache=True)
def _ets_predict_value(trend_type, seasonality_type, level, trend, seasonality, phi):
"""

Generate various useful values, including the next fitted value.

Parameters
Expand Down
2 changes: 1 addition & 1 deletion examples/forecasting/iterative.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@
"ets_forecasts = ets.iterative_forecast(y_train, 20)\n",
"sm_model = statsmodels_ets.fit()\n",
"statsmodels_forecasts = sm_model.forecast(steps=20)\n",
"print(f\"Alpha: {ets.alpha_}, Beta: {ets.beta_}, Gamma: {ets.gamma_}, Phi: {ets.phi_}\")\n",
"# print(f\"Alpha: {ets.alpha_}, Beta: {ets.beta_}, Gamma: {ets.gamma_}, Phi: {ets.phi_}\")\n",
"print(\n",
" f\"Alpha: {sm_model.alpha}, Beta: {sm_model.beta if sm_model.has_trend else 'N/A'}, \\\n",
" Gamma: {sm_model.gamma if sm_model.has_seasonal else 'N/A'}, \\\n",
Expand Down
Loading