diff --git a/pytorch_forecasting/layers/_decomposition/__init__.py b/pytorch_forecasting/layers/_decomposition/__init__.py
new file mode 100644
index 000000000..92b93fe3e
--- /dev/null
+++ b/pytorch_forecasting/layers/_decomposition/__init__.py
@@ -0,0 +1,9 @@
+"""
+Decomposition layers for PyTorch Forecasting.
+"""
+
+from pytorch_forecasting.layers._decomposition._series_decomp import SeriesDecomposition
+
+__all__ = [
+    "SeriesDecomposition",
+]
diff --git a/pytorch_forecasting/layers/_decomposition/_series_decomp.py b/pytorch_forecasting/layers/_decomposition/_series_decomp.py
new file mode 100644
index 000000000..30c6b8b38
--- /dev/null
+++ b/pytorch_forecasting/layers/_decomposition/_series_decomp.py
@@ -0,0 +1,43 @@
+"""
+Series Decomposition Block for time series forecasting models.
+"""
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from pytorch_forecasting.layers._filter._moving_avg_filter import MovingAvg
+
+
+class SeriesDecomposition(nn.Module):
+    """
+    Series decomposition block from Autoformer.
+
+    Decomposes a time series into trend and seasonal components using
+    moving average filtering.
+
+    Args:
+        kernel_size (int):
+            Size of the moving average kernel for trend extraction.
+    """
+
+    def __init__(self, kernel_size):
+        super().__init__()
+        self.moving_avg = MovingAvg(kernel_size, stride=1)
+
+    def forward(self, x):
+        """
+        Forward pass for series decomposition.
+
+        Args:
+            x (torch.Tensor):
+                Input time series tensor of shape (batch_size, seq_len, features).
+
+        Returns:
+            tuple:
+                - seasonal (torch.Tensor): Seasonal component of the time series.
+                - trend (torch.Tensor): Trend component of the time series.
+        """
+        trend = self.moving_avg(x)
+        seasonal = x - trend
+        return seasonal, trend
diff --git a/pytorch_forecasting/layers/_filter/__init__.py b/pytorch_forecasting/layers/_filter/__init__.py
new file mode 100644
index 000000000..d82df7845
--- /dev/null
+++ b/pytorch_forecasting/layers/_filter/__init__.py
@@ -0,0 +1,9 @@
+"""
+Filtering layers for time series forecasting models.
+"""
+
+from pytorch_forecasting.layers._filter._moving_avg_filter import MovingAvg
+
+__all__ = [
+    "MovingAvg",
+]
diff --git a/pytorch_forecasting/layers/_filter/_moving_avg_filter.py b/pytorch_forecasting/layers/_filter/_moving_avg_filter.py
new file mode 100644
index 000000000..85f26c5f4
--- /dev/null
+++ b/pytorch_forecasting/layers/_filter/_moving_avg_filter.py
@@ -0,0 +1,48 @@
+"""
+Moving Average Filter Block
+"""
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class MovingAvg(nn.Module):
+    """
+    Moving Average block for smoothing and trend extraction from time series data.
+
+    A moving average is a smoothing technique that creates a series of averages
+    from successive subsets of a time series.
+
+    For example: given a time series ``x = [x_1, x_2, ..., x_n]``, a moving average
+    with kernel size ``k`` computes the average of each window of ``k`` consecutive
+    elements, resulting in a new series of averages.
+
+    Args:
+        kernel_size (int):
+            Size of the moving average kernel.
+        stride (int):
+            Stride for the moving average operation, typically set to 1.
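+
+    Example:
+        A minimal usage sketch (assumes the batch-first input layout
+        ``(batch_size, seq_len, features)`` that ``forward`` expects)::
+
+            import torch
+
+            avg = MovingAvg(kernel_size=5, stride=1)
+            x = torch.randn(2, 16, 3)
+            smoothed = avg(x)  # same shape: edge replication keeps seq_len
+            assert smoothed.shape == x.shape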
+ """ + + def __init__(self, kernel_size, stride): + super().__init__() + self.kernel_size = kernel_size + self.avg = nn.AvgPool1d(kernel_size, stride=stride, padding=0) + + def forward(self, x): + if self.kernel_size % 2 == 0: + self.padding_left = self.kernel_size // 2 - 1 + self.padding_right = self.kernel_size // 2 + else: + self.padding_left = self.kernel_size // 2 + self.padding_right = self.kernel_size // 2 + + front = x[:, 0:1, :].repeat(1, self.padding_left, 1) + end = x[:, -1:, :].repeat(1, self.padding_right, 1) + + x_padded = torch.cat([front, x, end], dim=1) + x_transposed = x_padded.permute(0, 2, 1) + x_smoothed = self.avg(x_transposed) + x_out = x_smoothed.permute(0, 2, 1) + return x_out diff --git a/pytorch_forecasting/models/dlinear/__init__.py b/pytorch_forecasting/models/dlinear/__init__.py new file mode 100644 index 000000000..be4ff8e26 --- /dev/null +++ b/pytorch_forecasting/models/dlinear/__init__.py @@ -0,0 +1,10 @@ +""" +Decomposition-Linear model for time series forecasting. +""" + +from pytorch_forecasting.models.dlinear._dlinear_pkg_v2 import DLinear_pkg_v2 +from pytorch_forecasting.models.dlinear._dlinear_v2 import DLinear + +__all__ = [ + "DLinear" "DLinear_pkg_v2", +] diff --git a/pytorch_forecasting/models/dlinear/_dlinear_pkg_v2.py b/pytorch_forecasting/models/dlinear/_dlinear_pkg_v2.py new file mode 100644 index 000000000..a2490d4c3 --- /dev/null +++ b/pytorch_forecasting/models/dlinear/_dlinear_pkg_v2.py @@ -0,0 +1,127 @@ +""" +Packages container for DLinear model. +""" + +from pytorch_forecasting.models.base._base_object import _BasePtForecasterV2 + + +class DLinear_pkg_v2(_BasePtForecasterV2): + """DLinear package container.""" + + _tags = { + "info:name": "DLinear", + "info:compute": 2, + "authors": ["PranavBhatP"], + "capability:exogenous": True, + "capability:multivariate": True, + "capability:pred_int": True, + "capability:flexible_history_length": True, + "capability:cold_start": False, + } + + @classmethod + def get_model_cls(cls): + """Get model class.""" + from pytorch_forecasting.models.dlinear._dlinear_v2 import DLinear + + return DLinear + + @classmethod + def _get_test_datamodule_from(cls, trainer_kwargs): + """Create test dataloaders from trainer_kwargs - following v1 pattern.""" + from pytorch_forecasting.data._tslib_data_module import TslibDataModule + from pytorch_forecasting.tests._data_scenarios import ( + data_with_covariates_v2, + make_datasets_v2, + ) + + data_with_covariates = data_with_covariates_v2() + data_loader_default_kwargs = dict( + target="target", + group_ids=["agency_encoded", "sku_encoded"], + add_relative_time_idx=True, + ) + + data_loader_kwargs = trainer_kwargs.get("data_loader_kwargs", {}) + data_loader_default_kwargs.update(data_loader_kwargs) + + datasets_info = make_datasets_v2( + data_with_covariates, **data_loader_default_kwargs + ) + + training_dataset = datasets_info["training_dataset"] + validation_dataset = datasets_info["validation_dataset"] + + context_length = data_loader_kwargs.get("context_length", 8) + prediction_length = data_loader_kwargs.get("prediction_length", 2) + + batch_size = data_loader_kwargs.get("batch_size", 2) + + train_datamodule = TslibDataModule( + time_series_dataset=training_dataset, + context_length=context_length, + prediction_length=prediction_length, + add_relative_time_idx=data_loader_kwargs.get("add_relative_time_idx", True), + batch_size=batch_size, + train_val_test_split=(0.8, 0.2, 0.0), + ) + + val_datamodule = TslibDataModule( + time_series_dataset=validation_dataset, + 
+            context_length=context_length,
+            prediction_length=prediction_length,
+            add_relative_time_idx=data_loader_kwargs.get("add_relative_time_idx", True),
+            batch_size=batch_size,
+            train_val_test_split=(0.0, 1.0, 0.0),
+        )
+
+        test_datamodule = TslibDataModule(
+            time_series_dataset=validation_dataset,
+            context_length=context_length,
+            prediction_length=prediction_length,
+            add_relative_time_idx=data_loader_kwargs.get("add_relative_time_idx", True),
+            batch_size=batch_size,
+            train_val_test_split=(0.0, 0.0, 1.0),
+        )
+
+        train_datamodule.setup("fit")
+        val_datamodule.setup("fit")
+        test_datamodule.setup("test")
+
+        train_dataloader = train_datamodule.train_dataloader()
+        val_dataloader = val_datamodule.val_dataloader()
+        test_dataloader = test_datamodule.test_dataloader()
+
+        return {
+            "train": train_dataloader,
+            "val": val_dataloader,
+            "test": test_dataloader,
+            "data_module": train_datamodule,
+        }
+
+    @classmethod
+    def get_test_train_params(cls):
+        """
+        Return testing parameter settings for the trainer.
+
+        Returns
+        -------
+        params : dict or list of dict
+            Parameters to create testing instances of the class.
+        """
+        from pytorch_forecasting.metrics import SMAPE
+
+        return [
+            {},
+            dict(moving_avg=25, individual=False, logging_metrics=[SMAPE()]),
+            dict(
+                moving_avg=4,
+                individual=True,
+            ),
+            dict(
+                moving_avg=5,
+                individual=False,
+                logging_metrics=[SMAPE()],
+            ),
+        ]
diff --git a/pytorch_forecasting/models/dlinear/_dlinear_v2.py b/pytorch_forecasting/models/dlinear/_dlinear_v2.py
new file mode 100644
index 000000000..d8a0ca4e0
--- /dev/null
+++ b/pytorch_forecasting/models/dlinear/_dlinear_v2.py
@@ -0,0 +1,305 @@
+"""
+LTSF-DLinear model for PyTorch Forecasting.
+-------------------------------------------
+"""
+
+#################################################
+# NOTE: This is an experimental implementation  #
+# of the LTSF-DLinear model for PTF v2.         #
+# It is an unstable API and subject to change.  #
+#################################################
+
+from typing import Any, Optional, Union
+import warnings
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.optim import Optimizer
+
+from pytorch_forecasting.layers._decomposition import SeriesDecomposition
+from pytorch_forecasting.metrics import QuantileLoss
+from pytorch_forecasting.models.base._tslib_base_model_v2 import TslibBaseModel
+
+
+class DLinear(TslibBaseModel):
+    """
+    DLinear: Decomposition Linear Model for Long-Term Time Series Forecasting.
+
+    DLinear decomposes the time series into trend and seasonal components and
+    applies separate linear layers to each component. The final prediction is
+    the sum of both components.
+
+    Parameters
+    ----------
+    loss: nn.Module
+        Loss function for the training step.
+    moving_avg: int, default=25
+        Kernel size for moving average decomposition.
+    individual: bool, default=False
+        Whether to use individual linear layers for each variate (True) or
+        shared layers across all variates (False).
+    logging_metrics: Optional[list[nn.Module]], default=None
+        List of metrics to log during training, validation, and testing.
+    optimizer: Optional[Union[Optimizer, str]], default="adam"
+        Optimizer to use for training.
+    optimizer_params: Optional[dict], default=None
+        Parameters for the optimizer.
+    lr_scheduler: Optional[str], default=None
+        Learning rate scheduler to use.
+    lr_scheduler_params: Optional[dict], default=None
+        Parameters for the learning rate scheduler.
+    metadata: Optional[dict], default=None
+        Metadata for the model from TslibDataModule.
+
+    References
+    ----------
+    [1] https://arxiv.org/pdf/2205.13504
+    [2] https://github.com/thuml/Time-Series-Library/blob/main/models/DLinear.py
+
+    Notes
+    -----
+    [1] This implementation supports only continuous features. Categorical
+        variables will be accommodated in future versions.
+    """
+
+    def __init__(
+        self,
+        loss: nn.Module,
+        moving_avg: int = 25,
+        individual: bool = False,
+        logging_metrics: Optional[list[nn.Module]] = None,
+        optimizer: Optional[Union[Optimizer, str]] = "adam",
+        optimizer_params: Optional[dict] = None,
+        lr_scheduler: Optional[str] = None,
+        lr_scheduler_params: Optional[dict] = None,
+        metadata: Optional[dict] = None,
+        **kwargs: Any,
+    ):
+        super().__init__(
+            loss=loss,
+            logging_metrics=logging_metrics,
+            optimizer=optimizer,
+            optimizer_params=optimizer_params,
+            lr_scheduler=lr_scheduler,
+            lr_scheduler_params=lr_scheduler_params,
+            metadata=metadata,
+        )
+
+        warnings.warn(
+            "DLinear is an experimental model implemented on TslibBaseModelV2. "
+            "It is an unstable version and may be subject to unannounced changes. "
+            "Please use with caution."
+        )
+        self.moving_avg = moving_avg
+        self.individual = individual
+
+        self.save_hyperparameters(ignore=["loss", "logging_metrics", "metadata"])
+
+        self._init_network()
+
+        self.apply(self._weight_init)
+
+    def _weight_init(self, m: nn.Module):
+        # initialize each linear layer to compute a simple average over the
+        # context window
+        if isinstance(m, nn.Linear):
+            nn.init.constant_(m.weight.data, 1.0 / self.context_length)
+            if m.bias is not None:
+                nn.init.constant_(m.bias.data, 0.0)
+
+    def _init_network(self):
+        """
+        Initialise the DLinear model network layer components.
+        """
+        self.enc_in = self.cont_dim + self.target_dim
+
+        self.decomposition = SeriesDecomposition(self.moving_avg)
+
+        self.n_quantiles = None
+
+        if isinstance(self.loss, QuantileLoss):
+            self.n_quantiles = len(self.loss.quantiles)
+
+        output_dim = self.prediction_length
+
+        if self.n_quantiles is not None:
+            output_dim = self.prediction_length * self.n_quantiles
+
+        if self.individual:
+            self.linear_seasonal = nn.ModuleList()
+            self.linear_trend = nn.ModuleList()
+
+            for i in range(self.enc_in):
+                seasonal_layer = nn.Linear(self.context_length, output_dim)
+                trend_layer = nn.Linear(self.context_length, output_dim)
+
+                self.linear_seasonal.append(seasonal_layer)
+                self.linear_trend.append(trend_layer)
+        else:
+            self.linear_seasonal = nn.Linear(self.context_length, output_dim)
+            self.linear_trend = nn.Linear(self.context_length, output_dim)
+
+    def _encoder(self, x: torch.Tensor, target_indices: torch.Tensor) -> torch.Tensor:
+        """
+        Encode the input time series through decomposition and linear layers.
+
+        Parameters
+        ----------
+        x: torch.Tensor
+            Input data fed into the encoder.
+        target_indices: torch.Tensor
+            Indices of target features to be extracted from the output. If None, all features are returned.
+
+        Returns
+        -------
+        output: torch.Tensor
+            Encoded output tensor of shape (batch_size, prediction_length, n_features)
+        """  # noqa: E501
+        seasonal_init, trend_init = self.decomposition(x)
+        seasonal_init = seasonal_init.permute(0, 2, 1)
+        trend_init = trend_init.permute(0, 2, 1)
+
+        if self.individual:
+            seasonal_output, trend_output = self._process_individual_features(
+                seasonal_init, trend_init
+            )  # noqa: E501
+        else:
+            seasonal_output = self.linear_seasonal(seasonal_init)
+            trend_output = self.linear_trend(trend_init)
+
+        output = seasonal_output + trend_output
+
+        if target_indices is not None:
+            output = output[:, target_indices, :]
+
+        output = self._reshape_output(output)
+
+        return output
+
+    def _process_individual_features(
+        self, seasonal_init: torch.Tensor, trend_init: torch.Tensor
+    ):  # noqa: E501
+        """
+        Process features individually when self.individual=True.
+
+        Parameters
+        ----------
+        seasonal_init: Seasonal component tensor
+        trend_init: Trend component tensor
+
+        Returns
+        -------
+        tuple: (seasonal_output, trend_output)
+        """
+        # Determine output dimension
+        if self.n_quantiles is not None:
+            output_dim = self.prediction_length * self.n_quantiles
+        else:
+            output_dim = self.prediction_length
+
+        # Initialize output tensors
+        # same batch_size and n_features for both seasonal and trend
+        batch_size, n_features, _ = seasonal_init.shape
+        seasonal_output = torch.zeros(
+            (batch_size, n_features, output_dim),
+            dtype=seasonal_init.dtype,
+            device=seasonal_init.device,
+        )
+        trend_output = torch.zeros(
+            (batch_size, n_features, output_dim),
+            dtype=trend_init.dtype,
+            device=trend_init.device,
+        )
+
+        # Apply individual linear layers
+        for i in range(self.enc_in):
+            seasonal_output[:, i, :] = self.linear_seasonal[i](seasonal_init[:, i, :])
+            trend_output[:, i, :] = self.linear_trend[i](trend_init[:, i, :])
+
+        return seasonal_output, trend_output
+
+    def _reshape_output(self, output: torch.Tensor) -> torch.Tensor:
+        """
+        Reshape output tensor for quantile predictions.
+
+        Parameters
+        ----------
+        output: torch.Tensor
+            Output tensor from the encoder, expected to be of shape
+            (batch_size, n_features, prediction_length) or
+            (batch_size, n_features, prediction_length, n_quantiles).
+
+        Returns
+        -------
+        output: torch.Tensor
+            Reshaped tensor (batch_size, prediction_length, n_features, n_quantiles)
+            or (batch_size, prediction_length, n_features) if n_quantiles is None.
+        """
+        if self.n_quantiles is not None:
+            batch_size, n_features = output.shape[0], output.shape[1]
+            output = output.reshape(
+                batch_size, n_features, self.prediction_length, self.n_quantiles
+            )
+            output = output.permute(0, 2, 1, 3)  # (batch, time, features, quantiles)
+        else:
+            output = output.permute(0, 2, 1)  # (batch, time, features)
+
+        # univariate forecasting
+        if self.target_dim == 1 and output.shape[-1] == 1:
+            output = output.squeeze(-1)
+
+        return output
+
+    def forward(self, x: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
+        """
+        Forward pass of the DLinear model.
+
+        Parameters
+        ----------
+        x: dict[str, torch.Tensor]
+            Dictionary containing input tensors.
+
+        Returns
+        -------
+        dict[str, torch.Tensor]
+            Dictionary containing output tensors.
+            Currently this contains a single entry:
+
+            - "prediction": prediction output of shape
+              (batch_size, prediction_length, target_dim).
+        """  # noqa: E501
+        input_data, target_indices = self._prepare_input_data(x)
+
+        prediction = self._encoder(input_data, target_indices)
+
+        if "target_scale" in x and hasattr(self, "transform_output"):
+            prediction = self.transform_output(prediction, x["target_scale"])
+
+        return {"prediction": prediction}
+
+    def _prepare_input_data(self, x: dict[str, torch.Tensor]):
+        """Prepare input data and target indices for model input."""
+        available_features = []
+        target_indices = []
+        current_idx = 0
+
+        if "history_cont" in x and x["history_cont"].size(-1) > 0:
+            available_features.append(x["history_cont"])
+            current_idx += x["history_cont"].size(-1)
+
+        if "history_target" in x and x["history_target"].size(-1) > 0:
+            n_targets = x["history_target"].size(-1)
+            target_indices = list(range(current_idx, current_idx + n_targets))
+            available_features.append(x["history_target"])
+
+        if not available_features:
+            raise ValueError("No valid input features found in the input dictionary.")
+
+        input_data = torch.cat(available_features, dim=-1)
+
+        target_indices = (
+            torch.tensor(target_indices, dtype=torch.long, device=input_data.device)
+            if target_indices
+            else None
+        )
+
+        return input_data, target_indices
diff --git a/tests/test_models/test_dlinear_v2.py b/tests/test_models/test_dlinear_v2.py
new file mode 100644
index 000000000..efe2a0552
--- /dev/null
+++ b/tests/test_models/test_dlinear_v2.py
@@ -0,0 +1,168 @@
+import numpy as np
+import pandas as pd
+import pytest
+import torch
+
+from pytorch_forecasting.data import TimeSeries
+from pytorch_forecasting.data._tslib_data_module import TslibDataModule
+from pytorch_forecasting.metrics import MAE, SMAPE, QuantileLoss
+from pytorch_forecasting.models.dlinear._dlinear_v2 import DLinear
+
+
+@pytest.fixture
+def sample_dataset():
+    """Create a sample dataset for testing using v2."""
+    n_samples = 100
+    n_series = 3
+
+    time_idx = np.arange(n_samples)
+
+    series_data = []
+    for i in range(n_series):
+        trend = 0.1 * time_idx
+        seasonality = 10 * np.sin(2 * np.pi * time_idx / 20)
+        noise = np.random.normal(0, 1, n_samples)
+        values = trend + seasonality + noise
+
+        series = pd.DataFrame(
+            {
+                "time_idx": time_idx,
+                "series_id": i,
+                "value": values,
+                "feat1": np.random.normal(0, 1, n_samples),
+                "feat2": np.random.normal(0, 1, n_samples),
+            }
+        )
+        series_data.append(series)
+
+    data = pd.concat(series_data).reset_index(drop=True)
+
+    ts = TimeSeries(
+        data,
+        time="time_idx",
+        group=["series_id"],
+        target=["value"],
+        num=["feat1", "feat2"],
+        cat=[],
+        known=["time_idx"],
+        unknown=["value", "feat1", "feat2"],
+    )
+
+    dm = TslibDataModule(ts, context_length=16, prediction_length=4, batch_size=4)
+
+    dm.setup()
+
+    return {"data_module": dm, "time_series": ts}
+
+
+@pytest.mark.parametrize(
+    "moving_average, individual",
+    [
+        (5, False),
+        (25, True),
+    ],
+)
+def test_dlinear_init(moving_average, individual, sample_dataset):
+    """Test DLinear initialization."""
+    dm = sample_dataset["data_module"]
+
+    metadata = dm.metadata
+    loss = MAE()
+    model = DLinear(
+        loss=loss, moving_avg=moving_average, individual=individual, metadata=metadata
+    )
+
+    assert model.moving_avg == moving_average
+    assert model.individual == individual
+    assert model.n_quantiles is None
+
+
+def test_dlinear_forward(sample_dataset):
+    """Test forward pass of DLinear."""
+    dm = sample_dataset["data_module"]
+
+    train_dataloader = dm.train_dataloader()
+    batch = next(iter(train_dataloader))[0]
+
+    metadata = dm.metadata
+
+    model = DLinear(loss=MAE(), moving_avg=5, individual=True, metadata=metadata)
+
+    with torch.no_grad():
+        output = model(batch)
+
+    assert "prediction" in output
+    assert output["prediction"].shape[0] == dm.batch_size
+    assert output["prediction"].shape[1] == metadata["prediction_length"]
+
+
+def test_quantile_loss_output(sample_dataset):
+    """Test DLinear output shape with quantile loss."""
+    dm = sample_dataset["data_module"]
+
+    train_dataloader = dm.train_dataloader()
+    batch = next(iter(train_dataloader))[0]
+
+    metadata = dm.metadata
+
+    quantiles = [0.1, 0.5, 0.9]
+
+    model = DLinear(
+        loss=QuantileLoss(quantiles=quantiles),
+        moving_avg=5,
+        individual=True,
+        logging_metrics=[SMAPE(), MAE()],
+        metadata=metadata,
+    )
+
+    with torch.no_grad():
+        output = model(batch)
+
+    assert "prediction" in output
+    pred = output["prediction"]
+    assert pred.ndim == 4
+    assert pred.shape[-1] == len(quantiles)
+    assert pred.shape[1] == metadata["prediction_length"]
+
+
+def test_univariate_forecast():
+    """Test univariate forecasting with DLinear."""
+    n_samples = 100
+    time_idx = np.arange(n_samples)
+    values = np.sin(2 * np.pi * time_idx / 20) + np.random.normal(0, 0.1, n_samples)
+
+    series = pd.DataFrame({"time_idx": time_idx, "series_id": 0, "value": values})
+
+    ts = TimeSeries(
+        series,
+        time="time_idx",
+        group=["series_id"],
+        target=["value"],
+        num=[],
+        cat=[],
+        known=["time_idx"],
+        unknown=["value"],
+    )
+
+    dm = TslibDataModule(ts, context_length=16, prediction_length=4, batch_size=4)
+
+    dm.setup()
+
+    metadata = dm.metadata
+
+    model = DLinear(loss=MAE(), moving_avg=5, individual=False, metadata=metadata)
+
+    train_dataloader = dm.train_dataloader()
+    batch = next(iter(train_dataloader))[0]
+
+    with torch.no_grad():
+        output = model(batch)
+
+    assert "prediction" in output
+    assert output["prediction"].shape[0] == dm.batch_size
+    assert output["prediction"].shape[1] == metadata["prediction_length"]
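+
+
+# Added sketch: a direct unit test of the decomposition layer introduced in
+# this PR. By construction seasonal = x - trend, so the two components must
+# sum back to the input; the assertions below rely only on that identity.
+def test_series_decomposition_reconstruction():
+    """Seasonal and trend components should sum back to the input series."""
+    from pytorch_forecasting.layers._decomposition import SeriesDecomposition
+
+    x = torch.randn(2, 16, 3)  # (batch_size, seq_len, features)
+    decomp = SeriesDecomposition(kernel_size=5)
+
+    seasonal, trend = decomp(x)
+
+    # moving-average padding preserves the sequence length
+    assert seasonal.shape == x.shape
+    assert trend.shape == x.shape
+    torch.testing.assert_close(seasonal + trend, x)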