Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
b3644a6
test suite
fkiraly Feb 22, 2025
a1d64c6
Merge branch 'main' into test-suite
fkiraly Feb 22, 2025
4b2486e
skeleton
fkiraly Feb 22, 2025
02b0ce6
skeleton
fkiraly Feb 22, 2025
41cbf66
Update test_all_estimators.py
fkiraly Feb 23, 2025
cef62d3
Update _base_object.py
fkiraly Feb 23, 2025
bc2e93b
Update _lookup.py
fkiraly Feb 23, 2025
eee1c86
Update _lookup.py
fkiraly Feb 23, 2025
164fe0d
base metadatda
fkiraly Feb 23, 2025
20e88d0
registry
fkiraly Feb 23, 2025
318c1fb
fix private name
fkiraly Feb 23, 2025
012ab3d
Update _base_object.py
fkiraly Feb 23, 2025
86365a0
test failure
fkiraly Feb 23, 2025
f6dee46
Update test_all_estimators.py
fkiraly Feb 23, 2025
9b0e4ec
Update test_all_estimators.py
fkiraly Feb 23, 2025
7de5285
Update test_all_estimators.py
fkiraly Feb 23, 2025
57dfe3a
test folders
fkiraly Feb 23, 2025
c9f12db
Update test.yml
fkiraly Feb 23, 2025
fa8144e
test integration
fkiraly Feb 23, 2025
232a510
fixes
fkiraly Feb 23, 2025
1c8d4b5
Update _conftest.py
fkiraly Feb 23, 2025
f632e32
try scenarios
fkiraly Feb 23, 2025
ef37f55
Merge branch 'main' into test-suite
fkiraly May 1, 2025
a669134
Update _lookup.py
fkiraly May 4, 2025
d78bf5d
Update _lookup.py
fkiraly May 4, 2025
3d2dafc
[ENH] EXPERIMENTAL PR: D1 and D2 layer for v2 refactor (#1811)
phoeenniixx May 13, 2025
15ea3c3
[BUG] EXPERIMENTAL PR: Solve the bug in `data_module` (#1834)
phoeenniixx May 16, 2025
c04ebf3
[BUG] fix incorrect concatenation dimension in `concat_sequences` (#1…
cngmid May 16, 2025
524d05b
[ENH] EXPERIMENTAL PR: make the `data_module` dataclass-like (#1832)
phoeenniixx May 18, 2025
b82b42a
add initial version of tests for tide
PranavBhatP May 22, 2025
e46f9f6
refactor _integration to TiDE specific _integration function
PranavBhatP May 22, 2025
6dfe1a8
remove model-specific params from _integration in test_all_estimators
PranavBhatP May 22, 2025
4613d1b
Merge branch 'main' into test-tide-stack-1780
PranavBhatP May 23, 2025
f23d4d1
add metadata class for tide
PranavBhatP May 27, 2025
228c2f1
add TiDEModelMetadata to __init__.py
PranavBhatP May 27, 2025
9be3f11
fixed model-specific changes to provide test compatibility to TiDE
PranavBhatP May 27, 2025
52d763c
Merge branch 'main' into pr/1843
fkiraly May 28, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pytorch_forecasting/models/tide/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
"""Tide model."""

from pytorch_forecasting.models.tide._tide import TiDEModel
from pytorch_forecasting.models.tide._tide_metadata import TiDEModelMetadata
from pytorch_forecasting.models.tide.sub_modules import _TideModule

__all__ = [
"_TideModule",
"TiDEModel",
"TiDEModelMetadata",
]
63 changes: 63 additions & 0 deletions pytorch_forecasting/models/tide/_tide_metadata.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
"""TiDE metadata container."""

from pytorch_forecasting.models.base._base_object import _BasePtForecaster


class TiDEModelMetadata(_BasePtForecaster):
"""Metadata container for TiDE Model."""

_tags = {
"info:name": "TiDEModel",
"info:compute": 3,
"authors": ["Sohaib-Ahmed21"],
"capability:exogenous": True,
"capability:multivariate": True,
"capability:pred_int": True,
"capability:flexible_history_length": True,
"capability:cold_start": False,
}

@classmethod
def get_model_cls(cls):
"""Get model class."""
from pytorch_forecasting.models.tide import TiDEModel

return TiDEModel

@classmethod
def get_test_train_params(cls):
"""Return testing parameter settings for the trainer.

Returns
-------
params : dict or list of dict, default = {}
Parameters to create testing instances of the class.
"""

from pytorch_forecasting.data.encoders import GroupNormalizer
from pytorch_forecasting.metrics import SMAPE

return [
{
"data_loader_kwargs": dict(
add_relative_time_idx=False,
# must include this everytime since the data_loader_default_kwargs
# include this to be True.
)
},
{
"temporal_decoder_hidden": 16,
"data_loader_kwargs": dict(add_relative_time_idx=False),
},
{
"dropout": 0.2,
"use_layer_norm": True,
"loss": SMAPE(),
"data_loader_kwargs": dict(
target_normalizer=GroupNormalizer(
groups=["agency", "sku"], transformation="softplus"
),
add_relative_time_idx=False,
),
},
]
179 changes: 179 additions & 0 deletions tests/test_models/test_tide.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
import pickle
import shutil

import lightning.pytorch as pl
from lightning.pytorch.callbacks import EarlyStopping
from lightning.pytorch.loggers import TensorBoardLogger
import numpy as np
import pandas as pd
import pytest

from pytorch_forecasting.data.timeseries import TimeSeriesDataSet
from pytorch_forecasting.metrics import MAE, SMAPE, QuantileLoss
from pytorch_forecasting.models import TiDEModel
from pytorch_forecasting.tests.test_all_estimators import _integration
from pytorch_forecasting.utils._dependencies import _get_installed_packages


def _tide_integration(dataloaders, tmp_path, trainer_kwargs=None, **kwargs):
"""TiDE specific wrapper around the common integration test function.

Args:
dataloaders: Dictionary of dataloaders for train, val, and test.
tmp_path: Temporary path for saving the model.
trainer_kwargs: Additional arguments for the Trainer.
**kwargs: Additional arguments for the TiDEModel.

Returns:
Predictions from the trained model.
"""
from pytorch_forecasting.tests._data_scenarios import data_with_covariates

df = data_with_covariates()

tide_kwargs = {
"temporal_decoder_hidden": 8,
"temporal_width_future": 4,
"dropout": 0.1,
}

tide_kwargs.update(kwargs)
train_dataset = dataloaders["train"].dataset

data_loader_kwargs = {
"target": train_dataset.target,
"group_ids": train_dataset.group_ids,
"time_varying_known_reals": train_dataset.time_varying_known_reals,
"time_varying_unknown_reals": train_dataset.time_varying_unknown_reals,
"static_categoricals": train_dataset.static_categoricals,
"static_reals": train_dataset.static_reals,
"add_relative_time_idx": train_dataset.add_relative_time_idx,
}
return _integration(
TiDEModel,
df,
tmp_path,
data_loader_kwargs=data_loader_kwargs,
trainer_kwargs=trainer_kwargs,
**tide_kwargs,
)


@pytest.mark.parametrize(
"kwargs",
[
{},
{"loss": SMAPE()},
{"temporal_decoder_hidden": 16},
{"dropout": 0.2, "use_layer_norm": True},
],
)
def test_integration(dataloaders_with_covariates, tmp_path, kwargs):
_tide_integration(dataloaders_with_covariates, tmp_path, **kwargs)


@pytest.mark.parametrize(
"kwargs",
[
{},
],
)
def test_multi_target_integration(dataloaders_multi_target, tmp_path, kwargs):
_tide_integration(dataloaders_multi_target, tmp_path, **kwargs)


@pytest.fixture
def model(dataloaders_with_covariates):
dataset = dataloaders_with_covariates["train"].dataset
net = TiDEModel.from_dataset(
dataset,
hidden_size=16,
dropout=0.1,
temporal_width_future=4,
)
return net


def test_pickle(model):
pkl = pickle.dumps(model)
pickle.loads(pkl) # noqa: S301


@pytest.mark.skipif(
"matplotlib" not in _get_installed_packages(),
reason="skip test if required package matplotlib not installed",
)
def test_prediction_visualization(model, dataloaders_with_covariates):
raw_predictions = model.predict(
dataloaders_with_covariates["val"],
mode="raw",
return_x=True,
fast_dev_run=True,
)
model.plot_prediction(raw_predictions.x, raw_predictions.output, idx=0)


def test_prediction_with_kwargs(model, dataloaders_with_covariates):
# Tests prediction works with different keyword arguments
model.predict(
dataloaders_with_covariates["val"], return_index=True, fast_dev_run=True
)
model.predict(
dataloaders_with_covariates["val"],
return_x=True,
return_y=True,
fast_dev_run=True,
)


def test_no_exogenous_variable():
data = pd.DataFrame(
{
"target": np.ones(1600),
"group_id": np.repeat(np.arange(16), 100),
"time_idx": np.tile(np.arange(100), 16),
}
)
training_dataset = TimeSeriesDataSet(
data=data,
time_idx="time_idx",
target="target",
group_ids=["group_id"],
max_encoder_length=10,
max_prediction_length=5,
time_varying_unknown_reals=["target"],
time_varying_known_reals=[],
)
validation_dataset = TimeSeriesDataSet.from_dataset(
training_dataset, data, stop_randomization=True, predict=True
)
training_data_loader = training_dataset.to_dataloader(
train=True, batch_size=8, num_workers=0
)
validation_data_loader = validation_dataset.to_dataloader(
train=False, batch_size=8, num_workers=0
)
forecaster = TiDEModel.from_dataset(
training_dataset,
)
from lightning.pytorch import Trainer

trainer = Trainer(
max_epochs=2,
limit_train_batches=8,
limit_val_batches=8,
)
trainer.fit(
forecaster,
train_dataloaders=training_data_loader,
val_dataloaders=validation_data_loader,
)
best_model_path = trainer.checkpoint_callback.best_model_path
best_model = TiDEModel.load_from_checkpoint(best_model_path)
best_model.predict(
validation_data_loader,
fast_dev_run=True,
return_x=True,
return_y=True,
return_index=True,
)
Loading