diff --git a/aeon/forecasting/stats/__init__.py b/aeon/forecasting/stats/__init__.py index a86d5346a0..bb4b54cde3 100644 --- a/aeon/forecasting/stats/__init__.py +++ b/aeon/forecasting/stats/__init__.py @@ -1,6 +1,7 @@ """Stats based forecasters.""" __all__ = [ + "AutoETS", "ARIMA", "AutoARIMA", "AutoTAR", @@ -11,7 +12,7 @@ ] from aeon.forecasting.stats._arima import ARIMA, AutoARIMA -from aeon.forecasting.stats._ets import ETS +from aeon.forecasting.stats._ets import ETS, AutoETS from aeon.forecasting.stats._tar import TAR, AutoTAR from aeon.forecasting.stats._theta import Theta from aeon.forecasting.stats._tvp import TVP diff --git a/aeon/forecasting/stats/_ets.py b/aeon/forecasting/stats/_ets.py index 542ecd40bb..ff9d2fa723 100644 --- a/aeon/forecasting/stats/_ets.py +++ b/aeon/forecasting/stats/_ets.py @@ -1,4 +1,4 @@ -"""ETS class. +"""ETS and AutoETS class. An implementation of the exponential smoothing statistics forecasting algorithm. Implements additive and multiplicative error models. We recommend using the AutoETS @@ -6,7 +6,7 @@ """ __maintainer__ = [] -__all__ = ["ETS"] +__all__ = ["ETS", "AutoETS"] import numpy as np @@ -20,6 +20,7 @@ _ets_predict_value, ) from aeon.forecasting.utils._nelder_mead import nelder_mead +from aeon.forecasting.utils._seasonality import calc_seasonal_period ADDITIVE = "additive" MULTIPLICATIVE = "multiplicative" @@ -271,6 +272,139 @@ def iterative_forecast(self, y, prediction_horizon): return preds +class AutoETS(BaseForecaster): + """Automatic Exponential Smoothing forecaster. + + An implementation of the exponential smoothing statistics forecasting algorithm. + Chooses betweek additive and multiplicative error models, + None, additive and multiplicative (including damped) trend and + None, additive and multiplicative seasonality[1]_. + + Attempts to make this forecaster stable: + - Issues with Zero division Errors: + - If any data points are non-positive, multiplicative options are excluded. + - With numba fastmath=true, the compiler moves the operations around. This means + zero division guards are ineffective. As such I have tried putting the guards + in a separate fastmath=false function, with inline=true. This slows it down a + lot though. + - Need to make sure initialisation function never assigns slices of the data + array + - fixed bug where the first few values were being changed in the seasonality + calculation array + - Issues with the nelder mead not finding good parameters, + usually ending up with alpha approx 1: + - Tested updating the initial conditions (initial level, trend, + seasonality array) to be heuristically calculated across the whole + array at the start. + This seemed to fix some issues, but cause others. + - Tested optimising over the initial conditions in the nelder-mead array. + This didn't really help, although is how statsmodels does it. + - Added guards to the nelder-mead algorithm to reflect points back in when they + go above or below (0,1) to ensure parameters stay valid. + - Added sigmoid function to output of nelder-mead to ensure parameters stay + in (0,1). + - Initialised the simplex array with reasonable starting values + (a=0.4, b=0.25, phi=0.95, g=0.35) + - The algorithms sometimes produce really extreme forecasts of the order of 1e100 + larger than the data. I assume this is due to the guards on the zero division, + but haven't been able to work out how to fix it. + + Parameters + ---------- + horizon : int, default = 1 + The horizon to forecast to. + + References + ---------- + .. [1] R. J. Hyndman and G. Athanasopoulos, + Forecasting: Principles and Practice. Melbourne, Australia: OTexts, 2014. + + Examples + -------- + >>> from aeon.forecasting.stats import AutoETS + >>> from aeon.datasets import load_airline + >>> y = load_airline() + >>> forecaster = AutoETS() + >>> forecaster.forecast(y) + 435.9312382780535 + """ + + _tags = { + "capability:horizon": False, + } + + def __init__(self): + self.error_type_ = 0 + self.trend_type_ = 0 + self.seasonality_type_ = 0 + self.seasonal_period_ = 0 + self.wrapped_model_ = None + super().__init__(horizon=1, axis=1) + + def _fit(self, y, exog=None): + """Fit Auto Exponential Smoothing forecaster to series y. + + Fit a forecaster to predict self.horizon steps ahead using y. + + Parameters + ---------- + y : np.ndarray + A time series on which to learn a forecaster to predict horizon ahead + exog : np.ndarray, default =None + Optional exogenous time series data assumed to be aligned with y + + Returns + ------- + self + Fitted AutoETS. + """ + data = y.squeeze() + best_model = auto_ets(data) + self.error_type_ = int(best_model[0]) + self.trend_type_ = int(best_model[1]) + self.seasonality_type_ = int(best_model[2]) + self.seasonal_period_ = int(best_model[3]) + self.wrapped_model_ = ETS( + self.error_type_, + self.trend_type_, + self.seasonality_type_, + self.seasonal_period_, + ) + self.wrapped_model_.fit(y, exog) + return self + + def _predict(self, y=None, exog=None): + """ + Predict the next horizon steps ahead. + + Parameters + ---------- + y : np.ndarray, default = None + A time series to predict the next horizon value for. If None, + predict the next horizon value after series seen in fit. + exog : np.ndarray, default =None + Optional exogenous time series data assumed to be aligned with y + + Returns + ------- + float + single prediction self.horizon steps ahead of y. + """ + return self.wrapped_model_.predict(y, exog) + + def _forecast(self, y, exog=None, axis=1): + self.fit(y, exog=exog) + return float(self.wrapped_model_.forecast_) + + def iterative_forecast(self, y, prediction_horizon): + """Forecast with ETS specific iterative method. + + Overrides the base class iterative_forecast to avoid refitting on each step. + This simply rolls the ETS model forward + """ + return self.wrapped_model_.iterative_forecast(y, prediction_horizon) + + @njit(fastmath=True, cache=True) def _numba_predict( trend_type, @@ -320,3 +454,56 @@ def _validate_parameter(var, can_be_none): f"variable must be either string or integer with values" f" {valid_str} or {valid_int} but saw {var}" ) + + +@njit(fastmath=True, cache=True) +def auto_ets(data): + """Calculate model parameters based on the internal nelder-mead implementation.""" + seasonal_period = calc_seasonal_period(data) + seasonal_enabled = seasonal_period > 1 + s_max = 3 if seasonal_enabled else 1 + all_pos = True + for i in range(data.size): + if data[i] <= 0.0: + all_pos = False + break + x0_k1 = np.array((0.2, 0), dtype=np.float64) + x0_k2 = np.array((0.2, 0.5), dtype=np.float64) + x0_k3 = np.array((0.2, 0.05, 0.99), dtype=np.float64) + x0_k4 = np.array((0.2, 0.05, 0.05, 0.99), dtype=np.float64) + model = np.empty(4, dtype=np.int32) + best_model = np.empty(4, dtype=np.int32) + best_aic = np.inf + for error_type in range(1, 3): + if error_type == 2 and not all_pos: + continue + for trend_type in range(0, 3): + if trend_type == 2 and not all_pos: + continue + k_base = 1 + (2 if (trend_type != 0) else 0) + for seasonality_type in range(0, s_max): + if seasonality_type == 2 and not all_pos: + continue + model[0] = error_type + model[1] = trend_type + model[2] = seasonality_type + model[3] = seasonal_period if (seasonality_type != 0) else 1 + k = k_base + (1 if seasonality_type != 0 else 0) + x0 = ( + x0_k1 + if k == 1 + else (x0_k2 if k == 2 else (x0_k3 if k == 3 else x0_k4)) + ) + best_params, aic = nelder_mead(1, k, data, model, x0=x0) + if aic < best_aic: + best_aic = aic + best_model[:] = model + if k == 1: + x0_k1[:-1] = best_params + if k == 2: + x0_k2[:] = best_params + elif k == 3: + x0_k3[:] = best_params + else: + x0_k4[:] = best_params + return best_model diff --git a/aeon/forecasting/stats/tests/test_ets.py b/aeon/forecasting/stats/tests/test_ets.py index ab0f078e17..184a12b28e 100644 --- a/aeon/forecasting/stats/tests/test_ets.py +++ b/aeon/forecasting/stats/tests/test_ets.py @@ -3,7 +3,7 @@ import numpy as np import pytest -from aeon.forecasting.stats._ets import ETS, _validate_parameter +from aeon.forecasting.stats._ets import ETS, AutoETS, _validate_parameter @pytest.mark.parametrize( @@ -105,3 +105,115 @@ def test_ets_iterative_forecast(): forecaster = ETS(trend_type=None) forecaster._fit(y) assert forecaster._trend_type == 0 + + +# small seasonal-ish series (same as in ETS tests) +Y_SEASONAL = np.array( + [3, 10, 12, 13, 12, 10, 12, 3, 10, 12, 13, 12, 10, 12], dtype=float +) +# another shortish series for basic sanity checks +Y_SHORT = np.array([10, 12, 14, 13, 15, 16, 18, 19, 20, 21, 22, 23], dtype=float) + + +def test_autoets_fit_sets_attributes_and_wraps(): + """Fit should set type/period attributes and wrap an ETS instance.""" + forecaster = AutoETS() + forecaster.fit(Y_SEASONAL) + + # wrapped model exists and is ETS + assert forecaster.wrapped_model_ is not None + assert isinstance(forecaster.wrapped_model_, ETS) + + # discovered structure attributes should exist and be integers >= 0 + for attr in ("error_type_", "trend_type_", "seasonality_type_", "seasonal_period_"): + val = getattr(forecaster, attr) + assert isinstance(val, (int, np.integer)) + assert val >= 0 + + # wrapped model should have been fitted and expose a finite forecast_ + assert hasattr(forecaster.wrapped_model_, "forecast_") + assert np.isfinite(forecaster.wrapped_model_.forecast_) + + +def test_autoets_predict_returns_finite_float(): + """_predict should return a finite float once fitted.""" + forecaster = AutoETS() + forecaster.fit(Y_SHORT) + pred = forecaster._predict(Y_SHORT) + assert isinstance(pred, float) + assert np.isfinite(pred) + + +def test_autoets_forecast_sets_wrapped_and_returns_forecast_float(): + """_forecast should fit internally, set wrapped forecast_, and return that value.""" + forecaster = AutoETS() + f = forecaster._forecast(Y_SEASONAL) + assert isinstance(f, float) + assert np.isfinite(f) + assert forecaster.wrapped_model_ is not None + assert hasattr(forecaster.wrapped_model_, "forecast_") + assert np.isclose(f, float(forecaster.wrapped_model_.forecast_)) + + +def test_autoets_iterative_forecast_shape_and_validity(): + """iterative_forecast should delegate to wrapped ETS and return valid outputs.""" + h = 5 + forecaster = AutoETS() + forecaster.fit(Y_SHORT) + preds = forecaster.iterative_forecast(Y_SHORT, prediction_horizon=h) + + assert isinstance(preds, np.ndarray) + assert preds.shape == (h,) + assert np.all(np.isfinite(preds)) + + # Optional: first iterative step should match one-step-ahead forecast after fit + assert np.isclose(preds[0], forecaster.wrapped_model_.forecast_, atol=1e-6) + + +def test_autoets_horizon_greater_than_one_raises(): + """ + AutoETS.fit should raise ValueError. + + when horizon > 1 (ETS only supports 1-step fit). + """ + forecaster = AutoETS() + forecaster.horizon = 2 + with pytest.raises(ValueError, match="Horizon is set >1"): + forecaster.fit(Y_SEASONAL) + + +def test_autoets_predict_matches_wrapped_predict(): + """_predict should match the wrapped ETS model's predict.""" + forecaster = AutoETS() + forecaster.fit(Y_SEASONAL) + a = forecaster._predict(Y_SEASONAL) + b = forecaster.wrapped_model_.predict(Y_SEASONAL) + assert isinstance(a, float) and isinstance(b, float) + assert np.isfinite(a) and np.isfinite(b) + assert np.isclose(a, b) + + +def test_autoets_forecast_is_consistent_with_wrapped(): + """_forecast should equal the wrapped model's forecast after internal fit.""" + forecaster = AutoETS() + val = forecaster._forecast(Y_SHORT) + assert np.isclose(val, float(forecaster.wrapped_model_.forecast_)) + + +def test_autoets_exog_raises(): + """AutoETS.fit should raise ValueError when exog passed.""" + forecaster = AutoETS() + exog = np.arange(len(Y_SEASONAL), dtype=float) # simple aligned exogenous regressor + with pytest.raises( + ValueError, + match="AutoETS cannot handle exogenous variables", + ): + forecaster.fit(Y_SEASONAL, exog=exog) + + +def test_autoets_repeatability_on_same_input(): + """Forecasting twice on the same series should be deterministic.""" + forecaster = AutoETS() + f1 = forecaster._forecast(Y_SEASONAL) + f2 = forecaster._forecast(Y_SEASONAL) + assert np.isclose(f1, f2) diff --git a/aeon/forecasting/utils/_loss_functions.py b/aeon/forecasting/utils/_loss_functions.py index 768e8a36c0..972eae6842 100644 --- a/aeon/forecasting/utils/_loss_functions.py +++ b/aeon/forecasting/utils/_loss_functions.py @@ -1,5 +1,7 @@ """Loss functions for optimiser.""" +import math + import numpy as np from numba import njit @@ -9,6 +11,7 @@ ) LOG_2PI = 1.8378770664093453 +EPS = np.float64(1e-8) @njit(cache=True, fastmath=True) @@ -44,23 +47,56 @@ def _arima_fit(params, data, model): @njit(fastmath=True, cache=True) -def _ets_fit(params, data, model, return_all_states=False): +def _ets_fit(params, data, model): alpha, beta, gamma, phi = _extract_ets_params(params, model) error_type = model[0] trend_type = model[1] seasonality_type = model[2] seasonal_period = model[3] n_timepoints = len(data) - seasonal_period - level, trend, seasonality = _ets_initialise( - trend_type, seasonality_type, seasonal_period, data - ) + sum1 = 0.0 + sum2 = 0.0 + for i in range(seasonal_period): + sum1 += data[i] + sum2 += data[i + seasonal_period] + level = sum1 / seasonal_period + level2 = sum2 / seasonal_period + # Initial Trend + if trend_type == 1: + # Average difference between corresponding points in the first two seasons + trend = level2 - level + elif trend_type == 2: + # Average ratio between corresponding points in the first two seasons + trend = level2 / level + else: + # No trend + trend = 0 + # Initial Seasonality + seasonality = np.empty(seasonal_period, dtype=np.float64) + if seasonality_type == 1: + # Seasonal component is the difference + # from the initial level for each point in the first season + for i in range(seasonal_period): + seasonality[i] = data[i] - level + elif seasonality_type == 2: + # Seasonal component is the ratio of each point in the first season + # to the initial level + if level == 0: + for i in range(seasonal_period): + seasonality[i] = data[i] + else: + for i in range(seasonal_period): + seasonality[i] = data[i] / level + else: + # No seasonality + seasonality = np.zeros(1, dtype=np.float64) avg_mean_sq_err_ = 0 liklihood_ = 0 residuals_ = np.zeros(n_timepoints) # 1 Less residual than data points fitted_values_ = np.zeros(n_timepoints) + s_index = 0 for t in range(n_timepoints): index = t + seasonal_period - s_index = t % seasonal_period time_point = data[index] @@ -78,22 +114,28 @@ def _ets_fit(params, data, model, return_all_states=False): gamma, phi, ) + s_index += 1 + if s_index == seasonal_period: + s_index = 0 residuals_[t] = error fitted_values_[t] = fitted_value avg_mean_sq_err_ += (time_point - fitted_value) ** 2 liklihood_error = error if error_type == 2: # Multiplicative liklihood_error *= fitted_value + if liklihood_error > 1e4 or liklihood_ > 1e8: + liklihood_ = 1e8 + break liklihood_ += liklihood_error**2 avg_mean_sq_err_ /= n_timepoints - liklihood_ = n_timepoints * np.log(liklihood_) + liklihood_ = n_timepoints * math.log(liklihood_) k_ = ( seasonal_period * (seasonality_type != 0) + 2 * (trend_type != 0) + 2 + 1 * (phi != 1) ) - aic_ = liklihood_ + 2 * k_ - n_timepoints * np.log(n_timepoints) + aic_ = liklihood_ + 2 * k_ - n_timepoints * math.log(n_timepoints) return ( aic_, level, @@ -108,6 +150,106 @@ def _ets_fit(params, data, model, return_all_states=False): ) +@njit(inline="always", cache=True) +def safe_div(num, den): + if den < EPS: + return num / EPS + else: + return num / den + + +@njit(fastmath=True, cache=True) +def _ets_aic(params, data, model): + alpha, beta, gamma, phi = _extract_ets_params(params, model) + error_type = model[0] + trend_type = model[1] + seasonality_type = model[2] + seasonal_period = model[3] + n_timepoints = len(data) - seasonal_period + level, trend, seasonality = _ets_initialise( + trend_type, seasonality_type, seasonal_period, data + ) + liklihood_ = 0 + s_index = 0 + for t in range(n_timepoints): + index = t + seasonal_period + # Calculate level, trend, and seasonal components + # Retrieve the current state values + curr_level = level + curr_seasonality = seasonality[s_index] + if trend_type == 2: # Multiplicative + if trend < 0: + damped_trend = -((-trend) ** phi) + else: + damped_trend = trend**phi + trend_level_combination = level * damped_trend + else: # Additive trend, if no trend, then trend = 0 + damped_trend = trend * phi + trend_level_combination = level + damped_trend + + # Calculate forecast (fitted value) based on the current components + if seasonality_type == 2: # Multiplicative + fitted_value = trend_level_combination * seasonality[s_index] + else: # Additive seasonality, if no seasonality, then seasonality = 0 + fitted_value = trend_level_combination + seasonality[s_index] + # Calculate the error term (observed value - fitted value) + if error_type == 2: + error = safe_div(data[index], fitted_value) - 1 # Multiplicative error + else: + error = data[index] - fitted_value # Additive error + # Update level + if error_type == 2: + level = trend_level_combination * (1 + alpha * error) + trend = damped_trend * (1 + beta * error) + seasonality[s_index] = curr_seasonality * (1 + gamma * error) + if seasonality_type == 1: + level += alpha * error * curr_seasonality # Add seasonality correction + seasonality[s_index] += gamma * error * trend_level_combination + if trend_type == 1: + trend += (curr_level + curr_seasonality) * beta * error + else: + trend += safe_div(curr_seasonality, curr_level) * beta * error + elif trend_type == 1: + trend += curr_level * beta * error + else: + level_correction = 1 + trend_correction = 1 + seasonality_correction = 1 + if seasonality_type == 2: + # Add seasonality correction + level_correction *= curr_seasonality + trend_correction *= curr_seasonality + seasonality_correction *= trend_level_combination + if trend_type == 2: + trend_correction *= curr_level + level = trend_level_combination + alpha * safe_div(error, level_correction) + trend = damped_trend + beta * safe_div(error, trend_correction) + seasonality[s_index] = curr_seasonality + gamma * safe_div( + error, seasonality_correction + ) + s_index += 1 + if s_index == seasonal_period: + s_index = 0 + if error_type == 2: # Multiplicative + error *= fitted_value + if error > 1e4 or liklihood_ > 1e8: + liklihood_ = 1e8 + break + liklihood_ += error * error + k_ = ( + seasonal_period + if (seasonality_type != 0) + else 0 + 2 if (trend_type != 0) else 0 + 2 + 1 if (phi != 1) else 0 + ) + aic_ = ( + n_timepoints * math.log(liklihood_) + + 2 * k_ + - n_timepoints * math.log(n_timepoints) + ) + return aic_ + + +# Fastmath deliberately set to False to avoid issues with numerical stability (ZDEs) @njit(fastmath=True, cache=True) def _ets_initialise(trend_type, seasonality_type, seasonal_period, data): """ @@ -120,36 +262,46 @@ def _ets_initialise(trend_type, seasonality_type, seasonal_period, data): (should contain at least two full seasons if seasonality is specified) """ # Initial Level: Mean of the first season - level = np.mean(data[:seasonal_period]) + sum1 = 0.0 + sum2 = 0.0 + for i in range(seasonal_period): + sum1 += data[i] + sum2 += data[i + seasonal_period] + level = sum1 / seasonal_period + level2 = sum2 / seasonal_period # Initial Trend if trend_type == 1: # Average difference between corresponding points in the first two seasons - trend = np.mean( - data[seasonal_period : 2 * seasonal_period] - data[:seasonal_period] - ) + trend = level2 - level elif trend_type == 2: # Average ratio between corresponding points in the first two seasons - trend = np.mean( - data[seasonal_period : 2 * seasonal_period] / data[:seasonal_period] - ) + trend = level2 / level else: # No trend trend = 0 # Initial Seasonality + seasonality = np.empty(seasonal_period, dtype=np.float64) if seasonality_type == 1: # Seasonal component is the difference # from the initial level for each point in the first season - seasonality = data[:seasonal_period] - level + for i in range(seasonal_period): + seasonality[i] = data[i] - level elif seasonality_type == 2: # Seasonal component is the ratio of each point in the first season # to the initial level - seasonality = data[:seasonal_period] / level + if level == 0: + for i in range(seasonal_period): + seasonality[i] = data[i] + else: + for i in range(seasonal_period): + seasonality[i] = data[i] / level else: # No seasonality seasonality = np.zeros(1, dtype=np.float64) return level, trend, seasonality +# Fastmath deliberately set to False to avoid issues with numerical stability (ZDEs) @njit(fastmath=True, cache=True) def _ets_update_states( error_type, @@ -179,11 +331,25 @@ def _ets_update_states( # Retrieve the current state values curr_level = level curr_seasonality = seasonality - fitted_value, damped_trend, trend_level_combination = _ets_predict_value( - trend_type, seasonality_type, level, trend, seasonality, phi - ) + if trend_type == 2: # Multiplicative + if trend < 0: + damped_trend = -((-trend) ** phi) + else: + damped_trend = trend**phi + trend_level_combination = level * damped_trend + else: # Additive trend, if no trend, then trend = 0 + damped_trend = trend * phi + trend_level_combination = level + damped_trend + + # Calculate forecast (fitted value) based on the current components + if seasonality_type == 2: # Multiplicative + fitted_value = trend_level_combination * seasonality + else: # Additive seasonality, if no seasonality, then seasonality = 0 + fitted_value = trend_level_combination + seasonality # Calculate the error term (observed value - fitted value) if error_type == 2: + if fitted_value < EPS: + fitted_value = EPS error = data_item / fitted_value - 1 # Multiplicative error else: error = data_item - fitted_value # Additive error @@ -198,7 +364,7 @@ def _ets_update_states( if trend_type == 1: trend += (curr_level + curr_seasonality) * beta * error else: - trend += curr_seasonality / curr_level * beta * error + trend += curr_seasonality / max(curr_level, EPS) * beta * error elif trend_type == 1: trend += curr_level * beta * error else: @@ -212,13 +378,16 @@ def _ets_update_states( seasonality_correction *= trend_level_combination if trend_type == 2: trend_correction *= curr_level - level = trend_level_combination + alpha * error / level_correction - trend = damped_trend + beta * error / trend_correction - seasonality = curr_seasonality + gamma * error / seasonality_correction + level = trend_level_combination + alpha * error / max(level_correction, EPS) + trend = damped_trend + beta * error / max(trend_correction, EPS) + seasonality = curr_seasonality + gamma * error / max( + seasonality_correction, EPS + ) return (fitted_value, error, level, trend, seasonality) -@njit(fastmath=True, cache=True) +# Fastmath deliberately set to False to avoid issues with numerical stability (ZDEs) +@njit(inline="always", fastmath=True, cache=True) def _ets_predict_value(trend_type, seasonality_type, level, trend, seasonality, phi): """ @@ -247,7 +416,10 @@ def _ets_predict_value(trend_type, seasonality_type, level, trend, seasonality, # Apply damping parameter and # calculate commonly used combination of trend and level components if trend_type == 2: # Multiplicative - damped_trend = trend**phi + if trend < 0: + damped_trend = -((-trend) ** phi) + else: + damped_trend = trend**phi trend_level_combination = level * damped_trend else: # Additive trend, if no trend, then trend = 0 damped_trend = trend * phi diff --git a/aeon/forecasting/utils/_nelder_mead.py b/aeon/forecasting/utils/_nelder_mead.py index 7ddbb1240a..29cddfb796 100644 --- a/aeon/forecasting/utils/_nelder_mead.py +++ b/aeon/forecasting/utils/_nelder_mead.py @@ -3,7 +3,7 @@ import numpy as np from numba import njit -from aeon.forecasting.utils._loss_functions import _arima_fit, _ets_fit +from aeon.forecasting.utils._loss_functions import _arima_fit, _ets_aic @njit(cache=True, fastmath=True) @@ -11,20 +11,14 @@ def dispatch_loss(fn_id, params, data, model): if fn_id == 0: return _arima_fit(params, data, model) if fn_id == 1: - return _ets_fit(params, data, model)[0] + return _ets_aic(params, data, model) else: raise ValueError("Unknown loss function ID") @njit(cache=True, fastmath=True) def nelder_mead( - loss_id, - num_params, - data, - model, - tol=1e-6, - max_iter=500, - simplex_init=0.5, + loss_id, num_params, data, model, tol=1e-6, max_iter=500, simplex_init=0.5, x0=None ): """ Perform optimisation using the Nelder–Mead simplex algorithm. @@ -79,12 +73,22 @@ def nelder_mead( The Computer Journal, 7(4), 308–313. https://doi.org/10.1093/comjnl/7.4.308 """ - points = np.full((num_params + 1, num_params), simplex_init) - for i in range(num_params): - points[i + 1][i] = simplex_init * 1.2 + points = np.empty((num_params + 1, num_params), dtype=np.float64) + if x0 is None: + for i in range(num_params + 1): + points[i][:] = simplex_init + for i in range(num_params): + points[i + 1][i] = simplex_init * 1.2 + else: + for i in range(num_params + 1): + points[i][:] = x0[:num_params] + for i in range(num_params): + p = x0[:num_params].copy() + p[i] = min(0.999, max(1e-6, x0[i] * 1.2)) + points[i + 1, :] = p values = np.empty(len(points), dtype=np.float64) for i in range(len(points)): - values[i] = dispatch_loss(loss_id, points[i].copy(), data, model) + values[i] = dispatch_loss(loss_id, points[i], data, model) for _ in range(max_iter): # Order simplex by function values order = np.argsort(values) diff --git a/aeon/forecasting/utils/_seasonality.py b/aeon/forecasting/utils/_seasonality.py index 356b1a40d2..6b1a3411b6 100644 --- a/aeon/forecasting/utils/_seasonality.py +++ b/aeon/forecasting/utils/_seasonality.py @@ -88,7 +88,7 @@ def calc_seasonal_period(data): The estimated seasonal period (lag) of the series. Returns 1 if no significant peak is detected in the autocorrelation. """ - lags = acf(data, 24) + lags = acf(data, min(24, len(data) - 1)) lags = np.concatenate((np.array([1.0]), lags)) peaks = [] mean_lags = np.mean(lags)