From b3644a65b041a790b94756fb1d9bbf2797236d0a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Sat, 22 Feb 2025 23:18:51 +0100 Subject: [PATCH 01/33] test suite --- pyproject.toml | 1 + pytorch_forecasting/tests/__init__.py | 1 + 2 files changed, 2 insertions(+) create mode 100644 pytorch_forecasting/tests/__init__.py diff --git a/pyproject.toml b/pyproject.toml index f3d1e339c..8c661db1c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -102,6 +102,7 @@ dev = [ "pytest-dotenv>=0.5.2,<1.0.0", "tensorboard>=2.12.1,<3.0.0", "pandoc>=2.3,<3.0.0", + "scikit-base", ] # docs - dependencies for building the documentation diff --git a/pytorch_forecasting/tests/__init__.py b/pytorch_forecasting/tests/__init__.py new file mode 100644 index 000000000..6c2d26856 --- /dev/null +++ b/pytorch_forecasting/tests/__init__.py @@ -0,0 +1 @@ +"""PyTorch Forecasting test suite.""" From 4b2486e083ca93d8f4c1a29a6a25d882027815f3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Sat, 22 Feb 2025 23:33:29 +0100 Subject: [PATCH 02/33] skeleton --- pytorch_forecasting/_registry/__init__.py | 16 ++ pytorch_forecasting/_registry/_lookup.py | 214 ++++++++++++++++++ pytorch_forecasting/models/base/__init__.py | 2 + .../models/base/_base_object.py | 8 + pytorch_forecasting/tests/_config.py | 13 ++ .../tests/test_all_estimators.py | 118 ++++++++++ 6 files changed, 371 insertions(+) create mode 100644 pytorch_forecasting/_registry/__init__.py create mode 100644 pytorch_forecasting/_registry/_lookup.py create mode 100644 pytorch_forecasting/models/base/_base_object.py create mode 100644 pytorch_forecasting/tests/_config.py create mode 100644 pytorch_forecasting/tests/test_all_estimators.py diff --git a/pytorch_forecasting/_registry/__init__.py b/pytorch_forecasting/_registry/__init__.py new file mode 100644 index 000000000..bb0b88e61 --- /dev/null +++ b/pytorch_forecasting/_registry/__init__.py @@ -0,0 +1,16 @@ +"""PyTorch Forecasting registry.""" + +from pytorch_forecasting._registry._lookup import all_objects, all_tags +from pytorch_forecasting._registry._tags import ( + OBJECT_TAG_LIST, + OBJECT_TAG_REGISTER, + check_tag_is_valid, +) + +__all__ = [ + "OBJECT_TAG_LIST", + "OBJECT_TAG_REGISTER", + "all_objects", + "all_tags", + "check_tag_is_valid", +] diff --git a/pytorch_forecasting/_registry/_lookup.py b/pytorch_forecasting/_registry/_lookup.py new file mode 100644 index 000000000..ea3210cb0 --- /dev/null +++ b/pytorch_forecasting/_registry/_lookup.py @@ -0,0 +1,214 @@ +"""Registry lookup methods. + +This module exports the following methods for registry lookup: + +all_objects(object_types, filter_tags) + lookup and filtering of objects +""" +# copyright: skpro developers, BSD-3-Clause License (see LICENSE file) +# based on the sktime module of same name + +__author__ = ["fkiraly"] +# all_objects is based on the sklearn utility all_estimators + + +from copy import deepcopy +from operator import itemgetter +from pathlib import Path + +import pandas as pd +from skbase.lookup import all_objects as _all_objects + +from pytorch_forecasting.models.base import _BaseObject + + +def all_objects( + object_types=None, + filter_tags=None, + exclude_objects=None, + return_names=True, + as_dataframe=False, + return_tags=None, + suppress_import_stdout=True, +): + """Get a list of all objects from pytorch_forecasting. + + This function crawls the module and gets all classes that inherit + from skbase compatible base classes. 
+
+    Not included are: the base classes themselves, classes defined in test
+    modules.
+
+    Parameters
+    ----------
+    object_types: str, list of str, optional (default=None)
+        Which kind of objects should be returned.
+        if None, no filter is applied and all objects are returned.
+        if str or list of str, strings define scitypes specified in search,
+        only objects that are of (at least) one of the scitypes are returned.
+        possible str values are entries of registry.BASE_CLASS_REGISTER (first col),
+        for instance 'regressor_proba', 'distribution', 'metric'
+
+    return_names: bool, optional (default=True)
+
+        if True, estimator class name is included in the ``all_objects``
+        return in the order: name, estimator class, optional tags, either as
+        a tuple or as pandas.DataFrame columns
+
+        if False, estimator class name is removed from the ``all_objects`` return.
+
+    filter_tags: dict of (str or list of str), optional (default=None)
+        For a list of valid tag strings, use the registry.all_tags utility.
+
+        ``filter_tags`` subsets the returned estimators as follows:
+
+        * each key/value pair is a statement in "and"/conjunction
+        * key is tag name to sub-set on
+        * value str or list of string are tag values
+        * condition is "key must be equal to value, or in set(value)"
+
+    exclude_objects: str, list of str, optional (default=None)
+        Names of objects to exclude.
+
+    as_dataframe: bool, optional (default=False)
+
+        True: ``all_objects`` will return a pandas.DataFrame with named
+        columns for all of the attributes being returned.
+
+        False: ``all_objects`` will return a list (either a list of
+        estimators or a list of tuples, see Returns)
+
+    return_tags: str or list of str, optional (default=None)
+        Names of tags to fetch and return each estimator's value of.
+        For a list of valid tag strings, use the registry.all_tags utility.
+        if str or list of str,
+        the tag values named in return_tags will be fetched for each
+        estimator and will be appended as either columns or tuple entries.
+
+    suppress_import_stdout : bool, optional. Default=True
+        whether to suppress stdout printout upon import.
+
+    Returns
+    -------
+    all_objects will return one of the following:
+        1. list of objects, if return_names=False, and return_tags is None
+        2. list of tuples (optional object name, class, optional object
+        tags), if return_names=True or return_tags is not None.
+        3. pandas.DataFrame if as_dataframe = True
+        if list of objects:
+            entries are objects matching the query,
+            in alphabetical order of object name
+        if list of tuples:
+            list of (optional object name, object, optional object
+            tags) matching the query, in alphabetical order of object name,
+            where
+            ``name`` is the object name as string, and is an
+            optional return
+            ``object`` is the actual object
+            ``tags`` are the object's values for each tag in return_tags
+            and is an optional return.
+        if dataframe:
+            all_objects will return a pandas.DataFrame.
+            column names represent the attributes contained in each column.
+            "objects" will be the name of the column of objects, "names"
+            will be the name of the column of object class names and the string(s)
+            passed in return_tags will serve as column names for all columns of
+            tags that were optionally requested.
+ + Examples + -------- + >>> from skpro.registry import all_objects + >>> # return a complete list of objects as pd.Dataframe + >>> all_objects(as_dataframe=True) # doctest: +SKIP + >>> # return all probabilistic regressors by filtering for object type + >>> all_objects("regressor_proba", as_dataframe=True) # doctest: +SKIP + >>> # return all regressors which handle missing data in the input by tag filtering + >>> all_objects( + ... "regressor_proba", + ... filter_tags={"capability:missing": True}, + ... as_dataframe=True + ... ) # doctest: +SKIP + + References + ---------- + Adapted version of sktime's ``all_estimators``, + which is an evolution of scikit-learn's ``all_estimators`` + """ + MODULES_TO_IGNORE = ( + "tests", + "setup", + "contrib", + "utils", + "all", + ) + + result = [] + ROOT = str(Path(__file__).parent.parent) # skpro package root directory + + if isinstance(filter_tags, str): + filter_tags = {filter_tags: True} + filter_tags = filter_tags.copy() if filter_tags else None + + if object_types: + if filter_tags and "object_type" not in filter_tags.keys(): + object_tag_filter = {"object_type": object_types} + elif filter_tags: + filter_tags_filter = filter_tags.get("object_type", []) + if isinstance(object_types, str): + object_types = [object_types] + object_tag_update = {"object_type": object_types + filter_tags_filter} + filter_tags.update(object_tag_update) + else: + object_tag_filter = {"object_type": object_types} + if filter_tags: + filter_tags.update(object_tag_filter) + else: + filter_tags = object_tag_filter + + result = _all_objects( + object_types=[_BaseObject], + filter_tags=filter_tags, + exclude_objects=exclude_objects, + return_names=return_names, + as_dataframe=as_dataframe, + return_tags=return_tags, + suppress_import_stdout=suppress_import_stdout, + package_name="skpro", + path=ROOT, + modules_to_ignore=MODULES_TO_IGNORE, + ) + + return result + + +def _check_list_of_str_or_error(arg_to_check, arg_name): + """Check that certain arguments are str or list of str. + + Parameters + ---------- + arg_to_check: argument we are testing the type of + arg_name: str, + name of the argument we are testing, will be added to the error if + ``arg_to_check`` is not a str or a list of str + + Returns + ------- + arg_to_check: list of str, + if arg_to_check was originally a str it converts it into a list of str + so that it can be iterated over. + + Raises + ------ + TypeError if arg_to_check is not a str or list of str + """ + # check that return_tags has the right type: + if isinstance(arg_to_check, str): + arg_to_check = [arg_to_check] + if not isinstance(arg_to_check, list) or not all( + isinstance(value, str) for value in arg_to_check + ): + raise TypeError( + f"Error in all_objects! 
Argument {arg_name} must be either\ + a str or list of str" + ) + return arg_to_check diff --git a/pytorch_forecasting/models/base/__init__.py b/pytorch_forecasting/models/base/__init__.py index 4860e4838..474e7d564 100644 --- a/pytorch_forecasting/models/base/__init__.py +++ b/pytorch_forecasting/models/base/__init__.py @@ -7,8 +7,10 @@ BaseModelWithCovariates, Prediction, ) +from pytorch_forecasting.models.base._base_object import _BaseObject __all__ = [ + "_BaseObject", "AutoRegressiveBaseModel", "AutoRegressiveBaseModelWithCovariates", "BaseModel", diff --git a/pytorch_forecasting/models/base/_base_object.py b/pytorch_forecasting/models/base/_base_object.py new file mode 100644 index 000000000..7330867b1 --- /dev/null +++ b/pytorch_forecasting/models/base/_base_object.py @@ -0,0 +1,8 @@ +"""Base Classes for pytorch-forecasting models, skbase compatible for indexing.""" + +from skbase.base import BaseObject as _SkbaseBaseObject + + +class _BaseObject(_SkbaseBaseObject): + + pass diff --git a/pytorch_forecasting/tests/_config.py b/pytorch_forecasting/tests/_config.py new file mode 100644 index 000000000..dd9c2e889 --- /dev/null +++ b/pytorch_forecasting/tests/_config.py @@ -0,0 +1,13 @@ +"""Test configs.""" + +# list of str, names of estimators to exclude from testing +# WARNING: tests for these estimators will be skipped +EXCLUDE_ESTIMATORS = [ + "DummySkipped", + "ClassName", # exclude classes from extension templates +] + +# dictionary of lists of str, names of tests to exclude from testing +# keys are class names of estimators, values are lists of test names to exclude +# WARNING: tests with these names will be skipped +EXCLUDED_TESTS = {} diff --git a/pytorch_forecasting/tests/test_all_estimators.py b/pytorch_forecasting/tests/test_all_estimators.py new file mode 100644 index 000000000..c806691fa --- /dev/null +++ b/pytorch_forecasting/tests/test_all_estimators.py @@ -0,0 +1,118 @@ +"""Automated tests based on the skbase test suite template.""" +from inspect import isclass + +from skbase.testing import BaseFixtureGenerator as _BaseFixtureGenerator +from skbase.testing import TestAllObjects as _TestAllObjects + +from pytorch_forecasting._registry import all_objects +from pytorch_forecasting.tests._config import EXCLUDE_ESTIMATORS, EXCLUDED_TESTS + + +# whether to test only estimators from modules that are changed w.r.t. main +# default is False, can be set to True by pytest --only_changed_modules True flag +ONLY_CHANGED_MODULES = False + + +class PackageConfig: + """Contains package config variables for test classes.""" + + # class variables which can be overridden by descendants + # ------------------------------------------------------ + + # package to search for objects + # expected type: str, package/module name, relative to python environment root + package_name = "pytroch_forecasting" + + # list of object types (class names) to exclude + # expected type: list of str, str are class names + exclude_objects = EXCLUDE_ESTIMATORS + + # list of tests to exclude + # expected type: dict of lists, key:str, value: List[str] + # keys are class names of estimators, values are lists of test names to exclude + excluded_tests = EXCLUDED_TESTS + + +class BaseFixtureGenerator(_BaseFixtureGenerator): + """Fixture generator for base testing functionality in sktime. + + Test classes inheriting from this and not overriding pytest_generate_tests + will have estimator and scenario fixtures parametrized out of the box. 
+ + Descendants can override: + estimator_type_filter: str, class variable; None or scitype string + e.g., "forecaster", "transformer", "classifier", see BASE_CLASS_SCITYPE_LIST + which estimators are being retrieved and tested + fixture_sequence: list of str + sequence of fixture variable names in conditional fixture generation + _generate_[variable]: object methods, all (test_name: str, **kwargs) -> list + generating list of fixtures for fixture variable with name [variable] + to be used in test with name test_name + can optionally use values for fixtures earlier in fixture_sequence, + these must be input as kwargs in a call + is_excluded: static method (test_name: str, est: class) -> bool + whether test with name test_name should be excluded for estimator est + should be used only for encoding general rules, not individual skips + individual skips should go on the EXCLUDED_TESTS list in _config + requires _generate_object_class and _generate_object_instance as is + _excluded_scenario: static method (test_name: str, scenario) -> bool + whether scenario should be skipped in test with test_name test_name + requires _generate_estimator_scenario as is + + Fixtures parametrized + --------------------- + object_class: estimator inheriting from BaseObject + ranges over estimator classes not excluded by EXCLUDE_ESTIMATORS, EXCLUDED_TESTS + object_instance: instance of estimator inheriting from BaseObject + ranges over estimator classes not excluded by EXCLUDE_ESTIMATORS, EXCLUDED_TESTS + instances are generated by create_test_instance class method of object_class + """ + + # overrides object retrieval in scikit-base + def _all_objects(self): + """Retrieve list of all object classes of type self.object_type_filter. + + If self.object_type_filter is None, retrieve all objects. + If class, retrieve all classes inheriting from self.object_type_filter. + Otherwise (assumed str or list of str), retrieve all classes with tags + object_type in self.object_type_filter. 
+ """ + filter = getattr(self, "object_type_filter", None) + + if isclass(filter): + object_types = filter.get_class_tag("object_type", None) + else: + object_types = filter + + obj_list = all_objects( + object_types=object_types, + return_names=False, + exclude_objects=self.exclude_objects, + ) + + if isclass(filter): + obj_list = [obj for obj in obj_list if issubclass(obj, filter)] + + # run_test_for_class selects the estimators to run + # based on whether they have changed, and whether they have all dependencies + # internally, uses the ONLY_CHANGED_MODULES flag, + # and checks the python env against python_dependencies tag + # obj_list = [obj for obj in obj_list if run_test_for_class(obj)] + + return obj_list + + # which sequence the conditional fixtures are generated in + fixture_sequence = [ + "object_class", + "object_instance", + ] + + +class TestAllPtForecasters(PackageConfig, BaseFixtureGenerator, _TestAllObjects): + """Generic tests for all objects in the mini package.""" + + def test_doctest_examples(self, object_class): + """Runs doctests for estimator class.""" + import doctest + + doctest.run_docstring_examples(object_class, globals()) From 02b0ce6fa53443044fffce8cbbce54a0c6d6b947 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Sat, 22 Feb 2025 23:33:58 +0100 Subject: [PATCH 03/33] skeleton --- pytorch_forecasting/_registry/__init__.py | 15 ++------------- 1 file changed, 2 insertions(+), 13 deletions(-) diff --git a/pytorch_forecasting/_registry/__init__.py b/pytorch_forecasting/_registry/__init__.py index bb0b88e61..f71836bfe 100644 --- a/pytorch_forecasting/_registry/__init__.py +++ b/pytorch_forecasting/_registry/__init__.py @@ -1,16 +1,5 @@ """PyTorch Forecasting registry.""" -from pytorch_forecasting._registry._lookup import all_objects, all_tags -from pytorch_forecasting._registry._tags import ( - OBJECT_TAG_LIST, - OBJECT_TAG_REGISTER, - check_tag_is_valid, -) +from pytorch_forecasting._registry._lookup import all_objects -__all__ = [ - "OBJECT_TAG_LIST", - "OBJECT_TAG_REGISTER", - "all_objects", - "all_tags", - "check_tag_is_valid", -] +__all__ = ["all_objects"] From 41cbf667f9aea5848c3390778a53612338319504 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Sun, 23 Feb 2025 06:42:02 +0100 Subject: [PATCH 04/33] Update test_all_estimators.py --- pytorch_forecasting/tests/test_all_estimators.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/pytorch_forecasting/tests/test_all_estimators.py b/pytorch_forecasting/tests/test_all_estimators.py index c806691fa..704ddfc20 100644 --- a/pytorch_forecasting/tests/test_all_estimators.py +++ b/pytorch_forecasting/tests/test_all_estimators.py @@ -1,13 +1,15 @@ """Automated tests based on the skbase test suite template.""" + from inspect import isclass -from skbase.testing import BaseFixtureGenerator as _BaseFixtureGenerator -from skbase.testing import TestAllObjects as _TestAllObjects +from skbase.testing import ( + BaseFixtureGenerator as _BaseFixtureGenerator, + TestAllObjects as _TestAllObjects, +) from pytorch_forecasting._registry import all_objects from pytorch_forecasting.tests._config import EXCLUDE_ESTIMATORS, EXCLUDED_TESTS - # whether to test only estimators from modules that are changed w.r.t. 
main # default is False, can be set to True by pytest --only_changed_modules True flag ONLY_CHANGED_MODULES = False From cef62d36df5eceff5238a2d6c7fd829319028446 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Sun, 23 Feb 2025 06:47:24 +0100 Subject: [PATCH 05/33] Update _base_object.py --- pytorch_forecasting/models/base/_base_object.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pytorch_forecasting/models/base/_base_object.py b/pytorch_forecasting/models/base/_base_object.py index 7330867b1..a91b10525 100644 --- a/pytorch_forecasting/models/base/_base_object.py +++ b/pytorch_forecasting/models/base/_base_object.py @@ -1,6 +1,8 @@ """Base Classes for pytorch-forecasting models, skbase compatible for indexing.""" -from skbase.base import BaseObject as _SkbaseBaseObject +from pytorch_forecasting.utils._dependencies import _safe_import + +_SkbaseBaseObject = _safe_import("skbase._base_object._BaseObject") class _BaseObject(_SkbaseBaseObject): From bc2e93b606095440772f7236eeebb070109c649f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Sun, 23 Feb 2025 09:45:10 +0100 Subject: [PATCH 06/33] Update _lookup.py --- pytorch_forecasting/_registry/_lookup.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pytorch_forecasting/_registry/_lookup.py b/pytorch_forecasting/_registry/_lookup.py index ea3210cb0..517bfa9af 100644 --- a/pytorch_forecasting/_registry/_lookup.py +++ b/pytorch_forecasting/_registry/_lookup.py @@ -5,7 +5,6 @@ all_objects(object_types, filter_tags) lookup and filtering of objects """ -# copyright: skpro developers, BSD-3-Clause License (see LICENSE file) # based on the sktime module of same name __author__ = ["fkiraly"] From eee1c86859dc1d66d46eb85c7b39938639f8231e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Sun, 23 Feb 2025 10:37:22 +0100 Subject: [PATCH 07/33] Update _lookup.py --- pytorch_forecasting/_registry/_lookup.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pytorch_forecasting/_registry/_lookup.py b/pytorch_forecasting/_registry/_lookup.py index 517bfa9af..1ab4cfdb3 100644 --- a/pytorch_forecasting/_registry/_lookup.py +++ b/pytorch_forecasting/_registry/_lookup.py @@ -5,6 +5,7 @@ all_objects(object_types, filter_tags) lookup and filtering of objects """ + # based on the sktime module of same name __author__ = ["fkiraly"] From 164fe0d238ebe6b9f888c416f998a73948b365f7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Sun, 23 Feb 2025 13:05:54 +0100 Subject: [PATCH 08/33] base metadatda --- .../models/base/_base_object.py | 98 ++++++++++++++ .../models/deepar/_deepar_metadata.py | 128 ++++++++++++++++++ 2 files changed, 226 insertions(+) create mode 100644 pytorch_forecasting/models/deepar/_deepar_metadata.py diff --git a/pytorch_forecasting/models/base/_base_object.py b/pytorch_forecasting/models/base/_base_object.py index a91b10525..62fad456b 100644 --- a/pytorch_forecasting/models/base/_base_object.py +++ b/pytorch_forecasting/models/base/_base_object.py @@ -1,5 +1,7 @@ """Base Classes for pytorch-forecasting models, skbase compatible for indexing.""" +import inspect + from pytorch_forecasting.utils._dependencies import _safe_import _SkbaseBaseObject = _safe_import("skbase._base_object._BaseObject") @@ -8,3 +10,99 @@ class _BaseObject(_SkbaseBaseObject): pass + + +class _BasePtForecaster(_BaseObject): + """Base class for all PyTorch Forecasting forecaster metadata. + + This class points to model objects and contains metadata as tags. 
+ """ + + _tags = { + "object_type": "forecaster_pytorch", + } + + @classmethod + def get_model_cls(cls): + """Get model class.""" + raise NotImplementedError + + + @classmethod + def create_test_instance(cls, parameter_set="default"): + """Construct an instance of the class, using first test parameter set. + + Parameters + ---------- + parameter_set : str, default="default" + Name of the set of test parameters to return, for use in tests. If no + special parameters are defined for a value, will return `"default"` set. + + Returns + ------- + instance : instance of the class with default parameters + + """ + if "parameter_set" in inspect.getfullargspec(cls.get_test_params).args: + params = cls.get_test_params(parameter_set=parameter_set) + else: + params = cls.get_test_params() + + if isinstance(params, list) and isinstance(params[0], dict): + params = params[0] + elif isinstance(params, dict): + pass + else: + raise TypeError( + "get_test_params should either return a dict or list of dict." + ) + + return cls.get_model_cls()(**params) + + @classmethod + def create_test_instances_and_names(cls, parameter_set="default"): + """Create list of all test instances and a list of names for them. + + Parameters + ---------- + parameter_set : str, default="default" + Name of the set of test parameters to return, for use in tests. If no + special parameters are defined for a value, will return `"default"` set. + + Returns + ------- + objs : list of instances of cls + i-th instance is ``cls(**cls.get_test_params()[i])`` + names : list of str, same length as objs + i-th element is name of i-th instance of obj in tests. + The naming convention is ``{cls.__name__}-{i}`` if more than one instance, + otherwise ``{cls.__name__}`` + """ + if "parameter_set" in inspect.getfullargspec(cls.get_test_params).args: + param_list = cls.get_test_params(parameter_set=parameter_set) + else: + param_list = cls.get_test_params() + + objs = [] + if not isinstance(param_list, (dict, list)): + raise RuntimeError( + f"Error in {cls.__name__}.get_test_params, " + "return must be param dict for class, or list thereof" + ) + if isinstance(param_list, dict): + param_list = [param_list] + for params in param_list: + if not isinstance(params, dict): + raise RuntimeError( + f"Error in {cls.__name__}.get_test_params, " + "return must be param dict for class, or list thereof" + ) + objs += [cls.get_model_cls()(**params)] + + num_instances = len(param_list) + if num_instances > 1: + names = [cls.__name__ + "-" + str(i) for i in range(num_instances)] + else: + names = [cls.__name__] + + return objs, names diff --git a/pytorch_forecasting/models/deepar/_deepar_metadata.py b/pytorch_forecasting/models/deepar/_deepar_metadata.py new file mode 100644 index 000000000..330b86e80 --- /dev/null +++ b/pytorch_forecasting/models/deepar/_deepar_metadata.py @@ -0,0 +1,128 @@ +"""DeepAR metadata container.""" + +from pytorch_forecasting.models.base._base_object import _BasePtForecaster + + +class _DeepARMetadata(_BasePtForecaster): + """DeepAR metadata container.""" + + _tags = { + "capability:exogenous": True, + "capability:multivariate": True, + "capability:pred_int": True, + "capability:flexible_history_length": True, + "capability:cold_start": False, + "info:compute": 3, + } + + @classmethod + def get_model_cls(cls): + """Get model class.""" + from pytorch_forecasting.models import DeepAR + + return DeepAR + + @classmethod + def get_test_params(cls, parameter_set="default"): + """Return testing parameter settings for the skbase object. 
+ + ``get_test_params`` is a unified interface point to store + parameter settings for testing purposes. This function is also + used in ``create_test_instance`` and ``create_test_instances_and_names`` + to construct test instances. + + ``get_test_params`` should return a single ``dict``, or a ``list`` of ``dict``. + + Each ``dict`` is a parameter configuration for testing, + and can be used to construct an "interesting" test instance. + A call to ``cls(**params)`` should + be valid for all dictionaries ``params`` in the return of ``get_test_params``. + + The ``get_test_params`` need not return fixed lists of dictionaries, + it can also return dynamic or stochastic parameter settings. + + Parameters + ---------- + parameter_set : str, default="default" + Name of the set of test parameters to return, for use in tests. If no + special parameters are defined for a value, will return `"default"` set. + + Returns + ------- + params : dict or list of dict, default = {} + Parameters to create testing instances of the class + Each dict are parameters to construct an "interesting" test instance, i.e., + `MyClass(**params)` or `MyClass(**params[i])` creates a valid test instance. + `create_test_instance` uses the first (or only) dictionary in `params` + """ + from pytorch_forecasting.data.encoders import GroupNormalizer + from pytorch_forecasting.metrics import ( + BetaDistributionLoss, + ImplicitQuantileNetworkDistributionLoss, + LogNormalDistributionLoss, + MultivariateNormalDistributionLoss, + NegativeBinomialDistributionLoss, + ) + + return [ + {}, + {"cell_type": "GRU"}, + dict( + loss=LogNormalDistributionLoss(), + clip_target=True, + data_loader_kwargs=dict( + target_normalizer=GroupNormalizer( + groups=["agency", "sku"], transformation="log" + ) + ), + ), + dict( + loss=NegativeBinomialDistributionLoss(), + clip_target=False, + data_loader_kwargs=dict( + target_normalizer=GroupNormalizer( + groups=["agency", "sku"], center=False + ) + ), + ), + dict( + loss=BetaDistributionLoss(), + clip_target=True, + data_loader_kwargs=dict( + target_normalizer=GroupNormalizer( + groups=["agency", "sku"], transformation="logit" + ) + ), + ), + dict( + data_loader_kwargs=dict( + lags={"volume": [2, 5]}, + target="volume", + time_varying_unknown_reals=["volume"], + min_encoder_length=2, + ) + ), + dict( + data_loader_kwargs=dict( + time_varying_unknown_reals=["volume", "discount"], + target=["volume", "discount"], + lags={"volume": [2], "discount": [2]}, + ) + ), + dict( + loss=ImplicitQuantileNetworkDistributionLoss(hidden_size=8), + ), + dict( + loss=MultivariateNormalDistributionLoss(), + trainer_kwargs=dict(accelerator="cpu"), + ), + dict( + loss=MultivariateNormalDistributionLoss(), + data_loader_kwargs=dict( + target_normalizer=GroupNormalizer( + groups=["agency", "sku"], transformation="log1p" + ) + ), + trainer_kwargs=dict(accelerator="cpu"), + ), + ] From 20e88d09993f3fed62ab52f93d0a4678f1a0c068 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Sun, 23 Feb 2025 13:21:34 +0100 Subject: [PATCH 09/33] registry --- pytorch_forecasting/_registry/_lookup.py | 18 +++--------------- pytorch_forecasting/models/base/__init__.py | 6 +++++- .../models/base/_base_object.py | 2 +- .../utils/_dependencies/_safe_import.py | 3 +++ 4 files changed, 12 insertions(+), 17 deletions(-) diff --git a/pytorch_forecasting/_registry/_lookup.py b/pytorch_forecasting/_registry/_lookup.py index 1ab4cfdb3..828c448b1 100644 --- a/pytorch_forecasting/_registry/_lookup.py +++ 
b/pytorch_forecasting/_registry/_lookup.py @@ -11,12 +11,8 @@ __author__ = ["fkiraly"] # all_objects is based on the sklearn utility all_estimators - -from copy import deepcopy -from operator import itemgetter from pathlib import Path -import pandas as pd from skbase.lookup import all_objects as _all_objects from pytorch_forecasting.models.base import _BaseObject @@ -117,17 +113,9 @@ def all_objects( Examples -------- - >>> from skpro.registry import all_objects + >>> from pytorch_forecasting._registry import all_objects >>> # return a complete list of objects as pd.Dataframe >>> all_objects(as_dataframe=True) # doctest: +SKIP - >>> # return all probabilistic regressors by filtering for object type - >>> all_objects("regressor_proba", as_dataframe=True) # doctest: +SKIP - >>> # return all regressors which handle missing data in the input by tag filtering - >>> all_objects( - ... "regressor_proba", - ... filter_tags={"capability:missing": True}, - ... as_dataframe=True - ... ) # doctest: +SKIP References ---------- @@ -143,7 +131,7 @@ def all_objects( ) result = [] - ROOT = str(Path(__file__).parent.parent) # skpro package root directory + ROOT = str(Path(__file__).parent.parent) # package root directory if isinstance(filter_tags, str): filter_tags = {filter_tags: True} @@ -173,7 +161,7 @@ def all_objects( as_dataframe=as_dataframe, return_tags=return_tags, suppress_import_stdout=suppress_import_stdout, - package_name="skpro", + package_name="pytorch_forecasting", path=ROOT, modules_to_ignore=MODULES_TO_IGNORE, ) diff --git a/pytorch_forecasting/models/base/__init__.py b/pytorch_forecasting/models/base/__init__.py index 474e7d564..7b69ec246 100644 --- a/pytorch_forecasting/models/base/__init__.py +++ b/pytorch_forecasting/models/base/__init__.py @@ -7,10 +7,14 @@ BaseModelWithCovariates, Prediction, ) -from pytorch_forecasting.models.base._base_object import _BaseObject +from pytorch_forecasting.models.base._base_object import ( + _BaseObject, + _BasePtForecaster, +) __all__ = [ "_BaseObject", + "_BasePtForecaster", "AutoRegressiveBaseModel", "AutoRegressiveBaseModelWithCovariates", "BaseModel", diff --git a/pytorch_forecasting/models/base/_base_object.py b/pytorch_forecasting/models/base/_base_object.py index 62fad456b..8895b4c2c 100644 --- a/pytorch_forecasting/models/base/_base_object.py +++ b/pytorch_forecasting/models/base/_base_object.py @@ -4,7 +4,7 @@ from pytorch_forecasting.utils._dependencies import _safe_import -_SkbaseBaseObject = _safe_import("skbase._base_object._BaseObject") +_SkbaseBaseObject = _safe_import("skbase.base.BaseObject", pkg_name="scikit-base") class _BaseObject(_SkbaseBaseObject): diff --git a/pytorch_forecasting/utils/_dependencies/_safe_import.py b/pytorch_forecasting/utils/_dependencies/_safe_import.py index f4805f9c1..ffbde8b5d 100644 --- a/pytorch_forecasting/utils/_dependencies/_safe_import.py +++ b/pytorch_forecasting/utils/_dependencies/_safe_import.py @@ -70,6 +70,9 @@ def _safe_import(import_path, pkg_name=None): if pkg_name is None: path_list = import_path.split(".") pkg_name = path_list[0] + else: + path_list = import_path.split(".") + path_list = [pkg_name] + path_list[1:] if pkg_name in _get_installed_packages(): try: From 318c1fbdbfface24fdc67568cf9a01a6dde1650c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Sun, 23 Feb 2025 13:33:42 +0100 Subject: [PATCH 10/33] fix private name --- pytorch_forecasting/models/deepar/__init__.py | 3 ++- pytorch_forecasting/models/deepar/_deepar_metadata.py | 2 +- 2 files changed, 3 
insertions(+), 2 deletions(-) diff --git a/pytorch_forecasting/models/deepar/__init__.py b/pytorch_forecasting/models/deepar/__init__.py index 679f296f6..149e19e0d 100644 --- a/pytorch_forecasting/models/deepar/__init__.py +++ b/pytorch_forecasting/models/deepar/__init__.py @@ -1,5 +1,6 @@ """DeepAR: Probabilistic forecasting with autoregressive recurrent networks.""" from pytorch_forecasting.models.deepar._deepar import DeepAR +from pytorch_forecasting.models.deepar._deepar_metadata import DeepARMetadata -__all__ = ["DeepAR"] +__all__ = ["DeepAR", "DeepARMetadata"] diff --git a/pytorch_forecasting/models/deepar/_deepar_metadata.py b/pytorch_forecasting/models/deepar/_deepar_metadata.py index 330b86e80..89aefc1b0 100644 --- a/pytorch_forecasting/models/deepar/_deepar_metadata.py +++ b/pytorch_forecasting/models/deepar/_deepar_metadata.py @@ -3,7 +3,7 @@ from pytorch_forecasting.models.base._base_object import _BasePtForecaster -class _DeepARMetadata(_BasePtForecaster): +class DeepARMetadata(_BasePtForecaster): """DeepAR metadata container.""" _tags = { From 012ab3d78ed8a99e6920f3df704188834bbe1c14 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Sun, 23 Feb 2025 13:48:02 +0100 Subject: [PATCH 11/33] Update _base_object.py --- pytorch_forecasting/models/base/_base_object.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pytorch_forecasting/models/base/_base_object.py b/pytorch_forecasting/models/base/_base_object.py index 8895b4c2c..4fcb1bd22 100644 --- a/pytorch_forecasting/models/base/_base_object.py +++ b/pytorch_forecasting/models/base/_base_object.py @@ -27,7 +27,6 @@ def get_model_cls(cls): """Get model class.""" raise NotImplementedError - @classmethod def create_test_instance(cls, parameter_set="default"): """Construct an instance of the class, using first test parameter set. 
From 86365a00d88cda407674dbcac5c4d53bd26f3fce Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Sun, 23 Feb 2025 13:57:16 +0100 Subject: [PATCH 12/33] test failure --- pytorch_forecasting/tests/_conftest.py | 262 ++++++++++++++++++ .../tests/test_all_estimators.py | 101 +++++++ 2 files changed, 363 insertions(+) create mode 100644 pytorch_forecasting/tests/_conftest.py diff --git a/pytorch_forecasting/tests/_conftest.py b/pytorch_forecasting/tests/_conftest.py new file mode 100644 index 000000000..e276446a6 --- /dev/null +++ b/pytorch_forecasting/tests/_conftest.py @@ -0,0 +1,262 @@ +import numpy as np +import pytest +import torch + +from pytorch_forecasting import TimeSeriesDataSet +from pytorch_forecasting.data import EncoderNormalizer, GroupNormalizer, NaNLabelEncoder +from pytorch_forecasting.data.examples import generate_ar_data, get_stallion_data + +torch.manual_seed(23) + + +@pytest.fixture(scope="session") +def gpus(): + if torch.cuda.is_available(): + return [0] + else: + return 0 + + +@pytest.fixture(scope="session") +def data_with_covariates(): + data = get_stallion_data() + data["month"] = data.date.dt.month.astype(str) + data["log_volume"] = np.log1p(data.volume) + data["weight"] = 1 + np.sqrt(data.volume) + + data["time_idx"] = data["date"].dt.year * 12 + data["date"].dt.month + data["time_idx"] -= data["time_idx"].min() + + # convert special days into strings + special_days = [ + "easter_day", + "good_friday", + "new_year", + "christmas", + "labor_day", + "independence_day", + "revolution_day_memorial", + "regional_games", + "fifa_u_17_world_cup", + "football_gold_cup", + "beer_capital", + "music_fest", + ] + data[special_days] = ( + data[special_days].apply(lambda x: x.map({0: "", 1: x.name})).astype("category") + ) + data = data.astype(dict(industry_volume=float)) + + # select data subset + data = data[lambda x: x.sku.isin(data.sku.unique()[:2])][ + lambda x: x.agency.isin(data.agency.unique()[:2]) + ] + + # default target + data["target"] = data["volume"].clip(1e-3, 1.0) + + return data + + +def make_dataloaders(data_with_covariates, **kwargs): + training_cutoff = "2016-09-01" + max_encoder_length = 4 + max_prediction_length = 3 + + kwargs.setdefault("target", "volume") + kwargs.setdefault("group_ids", ["agency", "sku"]) + kwargs.setdefault("add_relative_time_idx", True) + kwargs.setdefault("time_varying_unknown_reals", ["volume"]) + + training = TimeSeriesDataSet( + data_with_covariates[lambda x: x.date < training_cutoff].copy(), + time_idx="time_idx", + max_encoder_length=max_encoder_length, + max_prediction_length=max_prediction_length, + **kwargs, # fixture parametrization + ) + + validation = TimeSeriesDataSet.from_dataset( + training, + data_with_covariates.copy(), + min_prediction_idx=training.index.time.max() + 1, + ) + train_dataloader = training.to_dataloader(train=True, batch_size=2, num_workers=0) + val_dataloader = validation.to_dataloader(train=False, batch_size=2, num_workers=0) + test_dataloader = validation.to_dataloader(train=False, batch_size=1, num_workers=0) + + return dict(train=train_dataloader, val=val_dataloader, test=test_dataloader) + + +@pytest.fixture( + params=[ + dict(), + dict( + static_categoricals=["agency", "sku"], + static_reals=["avg_population_2017", "avg_yearly_household_income_2017"], + time_varying_known_categoricals=["special_days", "month"], + variable_groups=dict( + special_days=[ + "easter_day", + "good_friday", + "new_year", + "christmas", + "labor_day", + "independence_day", + "revolution_day_memorial", + 
"regional_games", + "fifa_u_17_world_cup", + "football_gold_cup", + "beer_capital", + "music_fest", + ] + ), + time_varying_known_reals=[ + "time_idx", + "price_regular", + "price_actual", + "discount", + "discount_in_percent", + ], + time_varying_unknown_categoricals=[], + time_varying_unknown_reals=[ + "volume", + "log_volume", + "industry_volume", + "soda_volume", + "avg_max_temp", + ], + constant_fill_strategy={"volume": 0}, + categorical_encoders={"sku": NaNLabelEncoder(add_nan=True)}, + ), + dict(static_categoricals=["agency", "sku"]), + dict(randomize_length=True, min_encoder_length=2), + dict(target_normalizer=EncoderNormalizer(), min_encoder_length=2), + dict(target_normalizer=GroupNormalizer(transformation="log1p")), + dict( + target_normalizer=GroupNormalizer( + groups=["agency", "sku"], transformation="softplus", center=False + ) + ), + dict(target="agency"), + # test multiple targets + dict(target=["industry_volume", "volume"]), + dict(target=["agency", "volume"]), + dict( + target=["agency", "volume"], min_encoder_length=1, min_prediction_length=1 + ), + dict(target=["agency", "volume"], weight="volume"), + # test weights + dict(target="volume", weight="volume"), + ], + scope="session", +) +def multiple_dataloaders_with_covariates(data_with_covariates, request): + return make_dataloaders(data_with_covariates, **request.param) + + +@pytest.fixture(scope="session") +def dataloaders_with_different_encoder_decoder_length(data_with_covariates): + return make_dataloaders( + data_with_covariates.copy(), + target="target", + time_varying_known_categoricals=["special_days", "month"], + variable_groups=dict( + special_days=[ + "easter_day", + "good_friday", + "new_year", + "christmas", + "labor_day", + "independence_day", + "revolution_day_memorial", + "regional_games", + "fifa_u_17_world_cup", + "football_gold_cup", + "beer_capital", + "music_fest", + ] + ), + time_varying_known_reals=[ + "time_idx", + "price_regular", + "price_actual", + "discount", + "discount_in_percent", + ], + time_varying_unknown_categoricals=[], + time_varying_unknown_reals=[ + "target", + "volume", + "log_volume", + "industry_volume", + "soda_volume", + "avg_max_temp", + ], + static_categoricals=["agency"], + add_relative_time_idx=False, + target_normalizer=GroupNormalizer(groups=["agency", "sku"], center=False), + ) + + +@pytest.fixture(scope="session") +def dataloaders_with_covariates(data_with_covariates): + return make_dataloaders( + data_with_covariates.copy(), + target="target", + time_varying_known_reals=["discount"], + time_varying_unknown_reals=["target"], + static_categoricals=["agency"], + add_relative_time_idx=False, + target_normalizer=GroupNormalizer(groups=["agency", "sku"], center=False), + ) + + +@pytest.fixture(scope="session") +def dataloaders_multi_target(data_with_covariates): + return make_dataloaders( + data_with_covariates.copy(), + time_varying_unknown_reals=["target", "discount"], + target=["target", "discount"], + add_relative_time_idx=False, + ) + + +@pytest.fixture(scope="session") +def dataloaders_fixed_window_without_covariates(): + data = generate_ar_data(seasonality=10.0, timesteps=50, n_series=2) + validation = data.series.iloc[:2] + + max_encoder_length = 30 + max_prediction_length = 10 + + training = TimeSeriesDataSet( + data[lambda x: ~x.series.isin(validation)], + time_idx="time_idx", + target="value", + categorical_encoders={"series": NaNLabelEncoder().fit(data.series)}, + group_ids=["series"], + static_categoricals=[], + max_encoder_length=max_encoder_length, + 
max_prediction_length=max_prediction_length, + time_varying_unknown_reals=["value"], + target_normalizer=EncoderNormalizer(), + ) + + validation = TimeSeriesDataSet.from_dataset( + training, + data[lambda x: x.series.isin(validation)], + stop_randomization=True, + ) + batch_size = 2 + train_dataloader = training.to_dataloader( + train=True, batch_size=batch_size, num_workers=0 + ) + val_dataloader = validation.to_dataloader( + train=False, batch_size=batch_size, num_workers=0 + ) + test_dataloader = validation.to_dataloader( + train=False, batch_size=batch_size, num_workers=0 + ) + + return dict(train=train_dataloader, val=val_dataloader, test=test_dataloader) diff --git a/pytorch_forecasting/tests/test_all_estimators.py b/pytorch_forecasting/tests/test_all_estimators.py index 704ddfc20..609761e21 100644 --- a/pytorch_forecasting/tests/test_all_estimators.py +++ b/pytorch_forecasting/tests/test_all_estimators.py @@ -1,7 +1,11 @@ """Automated tests based on the skbase test suite template.""" from inspect import isclass +import shutil +import lightning.pytorch as pl +from lightning.pytorch.callbacks import EarlyStopping +from lightning.pytorch.loggers import TensorBoardLogger from skbase.testing import ( BaseFixtureGenerator as _BaseFixtureGenerator, TestAllObjects as _TestAllObjects, @@ -9,6 +13,7 @@ from pytorch_forecasting._registry import all_objects from pytorch_forecasting.tests._config import EXCLUDE_ESTIMATORS, EXCLUDED_TESTS +from pytorch_forecasting.tests._conftest import make_dataloaders # whether to test only estimators from modules that are changed w.r.t. main # default is False, can be set to True by pytest --only_changed_modules True flag @@ -110,6 +115,98 @@ def _all_objects(self): ] +def _integration( + data_with_covariates, + tmp_path, + cell_type="LSTM", + data_loader_kwargs={}, + clip_target: bool = False, + trainer_kwargs=None, + **kwargs, +): + data_with_covariates = data_with_covariates.copy() + if clip_target: + data_with_covariates["target"] = data_with_covariates["volume"].clip(1e-3, 1.0) + else: + data_with_covariates["target"] = data_with_covariates["volume"] + data_loader_default_kwargs = dict( + target="target", + time_varying_known_reals=["price_actual"], + time_varying_unknown_reals=["target"], + static_categoricals=["agency"], + add_relative_time_idx=True, + ) + data_loader_default_kwargs.update(data_loader_kwargs) + dataloaders_with_covariates = make_dataloaders( + data_with_covariates, **data_loader_default_kwargs + ) + + train_dataloader = dataloaders_with_covariates["train"] + val_dataloader = dataloaders_with_covariates["val"] + test_dataloader = dataloaders_with_covariates["test"] + + early_stop_callback = EarlyStopping( + monitor="val_loss", min_delta=1e-4, patience=1, verbose=False, mode="min" + ) + + logger = TensorBoardLogger(tmp_path) + if trainer_kwargs is None: + trainer_kwargs = {} + trainer = pl.Trainer( + max_epochs=3, + gradient_clip_val=0.1, + callbacks=[early_stop_callback], + enable_checkpointing=True, + default_root_dir=tmp_path, + limit_train_batches=2, + limit_val_batches=2, + limit_test_batches=2, + logger=logger, + **trainer_kwargs, + ) + + net = DeepAR.from_dataset( + train_dataloader.dataset, + hidden_size=5, + cell_type=cell_type, + learning_rate=0.01, + log_gradient_flow=True, + log_interval=1000, + n_plotting_samples=100, + **kwargs, + ) + net.size() + try: + trainer.fit( + net, + train_dataloaders=train_dataloader, + val_dataloaders=val_dataloader, + ) + test_outputs = trainer.test(net, dataloaders=test_dataloader) + assert 
len(test_outputs) > 0 + # check loading + net = DeepAR.load_from_checkpoint(trainer.checkpoint_callback.best_model_path) + + # check prediction + net.predict( + val_dataloader, + fast_dev_run=True, + return_index=True, + return_decoder_lengths=True, + trainer_kwargs=trainer_kwargs, + ) + finally: + shutil.rmtree(tmp_path, ignore_errors=True) + + net.predict( + val_dataloader, + fast_dev_run=True, + return_index=True, + return_decoder_lengths=True, + trainer_kwargs=trainer_kwargs, + ) + + class TestAllPtForecasters(PackageConfig, BaseFixtureGenerator, _TestAllObjects): """Generic tests for all objects in the mini package.""" @@ -118,3 +215,7 @@ def test_doctest_examples(self, object_class): import doctest doctest.run_docstring_examples(object_class, globals()) + + def certain_failure(self, object_class): + """Fails for certain, for testing.""" + assert False From f6dee46efaa6853afa299d5edca80d11a80367ce Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Sun, 23 Feb 2025 13:59:52 +0100 Subject: [PATCH 13/33] Update test_all_estimators.py --- pytorch_forecasting/tests/test_all_estimators.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pytorch_forecasting/tests/test_all_estimators.py b/pytorch_forecasting/tests/test_all_estimators.py index 609761e21..a5f2c3783 100644 --- a/pytorch_forecasting/tests/test_all_estimators.py +++ b/pytorch_forecasting/tests/test_all_estimators.py @@ -116,6 +116,7 @@ def _all_objects(self): def _integration( + estimator_cls, data_with_covariates, tmp_path, cell_type="LSTM", @@ -165,7 +166,7 @@ def _integration( **trainer_kwargs, ) - net = DeepAR.from_dataset( + net = estimator_cls.from_dataset( train_dataloader.dataset, hidden_size=5, cell_type=cell_type, @@ -185,7 +186,7 @@ def _integration( test_outputs = trainer.test(net, dataloaders=test_dataloader) assert len(test_outputs) > 0 # check loading - net = DeepAR.load_from_checkpoint(trainer.checkpoint_callback.best_model_path) + net = estimator_cls.load_from_checkpoint(trainer.checkpoint_callback.best_model_path) # check prediction net.predict( From 9b0e4ec4c7d47dc0115a87ea4297a22a2f0fe5eb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Sun, 23 Feb 2025 15:51:07 +0100 Subject: [PATCH 14/33] Update test_all_estimators.py --- pytorch_forecasting/tests/test_all_estimators.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/pytorch_forecasting/tests/test_all_estimators.py b/pytorch_forecasting/tests/test_all_estimators.py index a5f2c3783..2934d42db 100644 --- a/pytorch_forecasting/tests/test_all_estimators.py +++ b/pytorch_forecasting/tests/test_all_estimators.py @@ -116,7 +116,7 @@ def _all_objects(self): def _integration( - estimator_cls, + estimator_cls, data_with_covariates, tmp_path, cell_type="LSTM", @@ -186,7 +186,9 @@ def _integration( test_outputs = trainer.test(net, dataloaders=test_dataloader) assert len(test_outputs) > 0 # check loading - net = estimator_cls.load_from_checkpoint(trainer.checkpoint_callback.best_model_path) + net = estimator_cls.load_from_checkpoint( + trainer.checkpoint_callback.best_model_path + ) # check prediction net.predict( From 7de528537d6fe36dd554f3cad5550d6f66c512e7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Sun, 23 Feb 2025 16:02:28 +0100 Subject: [PATCH 15/33] Update test_all_estimators.py --- pytorch_forecasting/tests/test_all_estimators.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_forecasting/tests/test_all_estimators.py 
b/pytorch_forecasting/tests/test_all_estimators.py index 2934d42db..37b597712 100644 --- a/pytorch_forecasting/tests/test_all_estimators.py +++ b/pytorch_forecasting/tests/test_all_estimators.py @@ -28,7 +28,7 @@ class PackageConfig: # package to search for objects # expected type: str, package/module name, relative to python environment root - package_name = "pytroch_forecasting" + package_name = "pytorch_forecasting" # list of object types (class names) to exclude # expected type: list of str, str are class names From 57dfe3a4e47ac3a34199d787cc6282f43b18a9f9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Sun, 23 Feb 2025 16:42:01 +0100 Subject: [PATCH 16/33] test folders --- pytest.ini | 4 +- .../tests/test_all_estimators.py | 43 ++++++++++++++++--- 2 files changed, 40 insertions(+), 7 deletions(-) diff --git a/pytest.ini b/pytest.ini index 457863f87..52f4fa1c1 100644 --- a/pytest.ini +++ b/pytest.ini @@ -10,7 +10,9 @@ addopts = --no-cov-on-fail markers = -testpaths = tests/ +testpaths = + tests/ + pytorch_forecasting/tests/ log_cli_level = ERROR log_format = %(asctime)s %(levelname)s %(message)s log_date_format = %Y-%m-%d %H:%M:%S diff --git a/pytorch_forecasting/tests/test_all_estimators.py b/pytorch_forecasting/tests/test_all_estimators.py index 37b597712..8fbcc6ffe 100644 --- a/pytorch_forecasting/tests/test_all_estimators.py +++ b/pytorch_forecasting/tests/test_all_estimators.py @@ -6,10 +6,8 @@ import lightning.pytorch as pl from lightning.pytorch.callbacks import EarlyStopping from lightning.pytorch.loggers import TensorBoardLogger -from skbase.testing import ( - BaseFixtureGenerator as _BaseFixtureGenerator, - TestAllObjects as _TestAllObjects, -) +import pytest +from skbase.testing import BaseFixtureGenerator as _BaseFixtureGenerator from pytorch_forecasting._registry import all_objects from pytorch_forecasting.tests._config import EXCLUDE_ESTIMATORS, EXCLUDED_TESTS @@ -110,10 +108,43 @@ def _all_objects(self): # which sequence the conditional fixtures are generated in fixture_sequence = [ + "object_metadata", "object_class", "object_instance", ] + def _generate_object_metadata(self, test_name, **kwargs): + """Return object class fixtures. + + Fixtures parametrized + --------------------- + object_class: object inheriting from BaseObject + ranges over all object classes not excluded by self.excluded_tests + """ + object_classes_to_test = [ + est for est in self._all_objects() if not self.is_excluded(test_name, est) + ] + object_names = [est.__name__ for est in object_classes_to_test] + + return object_classes_to_test, object_names + + def _generate_object_class(self, test_name, **kwargs): + """Return object class fixtures. 
+ + Fixtures parametrized + --------------------- + object_class: object inheriting from BaseObject + ranges over all object classes not excluded by self.excluded_tests + """ + all_metadata = self._all_objects() + all_cls = [est.get_model_cls() for est in all_metadata] + object_classes_to_test = [ + est for est in all_cls if not self.is_excluded(test_name, est) + ] + object_names = [est.__name__ for est in object_classes_to_test] + + return object_classes_to_test, object_names + def _integration( estimator_cls, @@ -210,7 +241,7 @@ def _integration( ) -class TestAllPtForecasters(PackageConfig, BaseFixtureGenerator, _TestAllObjects): +class TestAllPtForecasters(PackageConfig, BaseFixtureGenerator): """Generic tests for all objects in the mini package.""" def test_doctest_examples(self, object_class): @@ -219,6 +250,6 @@ def test_doctest_examples(self, object_class): doctest.run_docstring_examples(object_class, globals()) - def certain_failure(self, object_class): + def test_certain_failure(self, object_class): """Fails for certain, for testing.""" assert False From c9f12dbdeea4aa431b52620b226cd57193a9a249 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Sun, 23 Feb 2025 16:59:30 +0100 Subject: [PATCH 17/33] Update test.yml --- .github/workflows/test.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 5fcd9c1ff..0083302dd 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -71,7 +71,7 @@ jobs: - name: Run pytest shell: bash - run: python -m pytest tests + run: python -m pytest pytest: name: Run pytest @@ -110,7 +110,7 @@ jobs: - name: Run pytest shell: bash - run: python -m pytest tests + run: python -m pytest - name: Statistics run: | From fa8144ebae6312d34458a2464da2d1bc3f186754 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Sun, 23 Feb 2025 18:56:39 +0100 Subject: [PATCH 18/33] test integration --- .../models/deepar/_deepar_metadata.py | 25 +---------- .../tests/test_all_estimators.py | 41 ++++++++++++++++++- 2 files changed, 41 insertions(+), 25 deletions(-) diff --git a/pytorch_forecasting/models/deepar/_deepar_metadata.py b/pytorch_forecasting/models/deepar/_deepar_metadata.py index 89aefc1b0..206f113f0 100644 --- a/pytorch_forecasting/models/deepar/_deepar_metadata.py +++ b/pytorch_forecasting/models/deepar/_deepar_metadata.py @@ -23,29 +23,8 @@ def get_model_cls(cls): return DeepAR @classmethod - def get_test_params(cls, parameter_set="default"): - """Return testing parameter settings for the skbase object. - - ``get_test_params`` is a unified interface point to store - parameter settings for testing purposes. This function is also - used in ``create_test_instance`` and ``create_test_instances_and_names`` - to construct test instances. - - ``get_test_params`` should return a single ``dict``, or a ``list`` of ``dict``. - - Each ``dict`` is a parameter configuration for testing, - and can be used to construct an "interesting" test instance. - A call to ``cls(**params)`` should - be valid for all dictionaries ``params`` in the return of ``get_test_params``. - - The ``get_test_params`` need not return fixed lists of dictionaries, - it can also return dynamic or stochastic parameter settings. - - Parameters - ---------- - parameter_set : str, default="default" - Name of the set of test parameters to return, for use in tests. If no - special parameters are defined for a value, will return `"default"` set. 
+ def get_test_train_params(cls): + """Return testing parameter settings for the trainer. Returns ------- diff --git a/pytorch_forecasting/tests/test_all_estimators.py b/pytorch_forecasting/tests/test_all_estimators.py index 8fbcc6ffe..378316d7a 100644 --- a/pytorch_forecasting/tests/test_all_estimators.py +++ b/pytorch_forecasting/tests/test_all_estimators.py @@ -71,6 +71,8 @@ class BaseFixtureGenerator(_BaseFixtureGenerator): object_instance: instance of estimator inheriting from BaseObject ranges over estimator classes not excluded by EXCLUDE_ESTIMATORS, EXCLUDED_TESTS instances are generated by create_test_instance class method of object_class + trainer_kwargs: list of dict + ranges over dictionaries of kwargs for the trainer """ # overrides object retrieval in scikit-base @@ -111,6 +113,7 @@ def _all_objects(self): "object_metadata", "object_class", "object_instance", + "trainer_kwargs", ] def _generate_object_metadata(self, test_name, **kwargs): @@ -145,6 +148,30 @@ def _generate_object_class(self, test_name, **kwargs): return object_classes_to_test, object_names + def _generate_trainer_kwargs(self, test_name, **kwargs): + """Return kwargs for the trainer. + + Fixtures parametrized + --------------------- + trainer_kwargs: dict + ranges over all kwargs for the trainer + """ + # call _generate_object_class to get all the classes + object_meta_to_test, _ = self._generate_object_metadata(test_name=test_name) + + # create instances from the classes + train_kwargs_to_test = [] + train_kwargs_names = [] + # retrieve all object parameters if multiple, construct instances + for est in object_meta_to_test: + est_name = est.__name__ + all_train_kwargs = est.get_test_train_params() + train_kwargs_to_test += all_train_kwargs + rg = range(len(all_train_kwargs)) + train_kwargs_names += [f"{est_name}_{i}" for i in rg] + + return train_kwargs_to_test, train_kwargs_names + def _integration( estimator_cls, @@ -250,6 +277,16 @@ def test_doctest_examples(self, object_class): doctest.run_docstring_examples(object_class, globals()) - def test_certain_failure(self, object_class): + def test_integration( + self, object_class, trainer_kwargs, data_with_covariates, tmp_path + ): """Fails for certain, for testing.""" - assert False + from pytorch_forecasting.metrics import NegativeBinomialDistributionLoss + + if "loss" in trainer_kwargs and isinstance( + trainer_kwargs["loss"], NegativeBinomialDistributionLoss + ): + data_with_covariates = data_with_covariates.assign( + volume=lambda x: x.volume.round() + ) + _integration(object_class, data_with_covariates, tmp_path, **trainer_kwargs) From 232a510bc6820786ef1ce46ab3115a04126054ae Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Sun, 23 Feb 2025 19:07:35 +0100 Subject: [PATCH 19/33] fixes --- .../models/base/_base_object.py | 8 +++++ .../models/deepar/_deepar_metadata.py | 4 ++- .../tests/test_all_estimators.py | 31 ++++++++++--------- 3 files changed, 27 insertions(+), 16 deletions(-) diff --git a/pytorch_forecasting/models/base/_base_object.py b/pytorch_forecasting/models/base/_base_object.py index 4fcb1bd22..7fd59d6a4 100644 --- a/pytorch_forecasting/models/base/_base_object.py +++ b/pytorch_forecasting/models/base/_base_object.py @@ -27,6 +27,14 @@ def get_model_cls(cls): """Get model class.""" raise NotImplementedError + @classmethod + def name(cls): + """Get model name.""" + name = cls.get_class_tags().get("info:name", None) + if name is None: + name = cls.get_model_cls().__name__ + return name + @classmethod def 
create_test_instance(cls, parameter_set="default"): """Construct an instance of the class, using first test parameter set. diff --git a/pytorch_forecasting/models/deepar/_deepar_metadata.py b/pytorch_forecasting/models/deepar/_deepar_metadata.py index 206f113f0..e477d63b0 100644 --- a/pytorch_forecasting/models/deepar/_deepar_metadata.py +++ b/pytorch_forecasting/models/deepar/_deepar_metadata.py @@ -7,12 +7,14 @@ class DeepARMetadata(_BasePtForecaster): """DeepAR metadata container.""" _tags = { + "info:name": "DeepAR", + "info:compute": 3, + "authors": ["jdb78"], "capability:exogenous": True, "capability:multivariate": True, "capability:pred_int": True, "capability:flexible_history_length": True, "capability:cold_start": False, - "info:compute": 3, } @classmethod diff --git a/pytorch_forecasting/tests/test_all_estimators.py b/pytorch_forecasting/tests/test_all_estimators.py index 378316d7a..c2d86e40e 100644 --- a/pytorch_forecasting/tests/test_all_estimators.py +++ b/pytorch_forecasting/tests/test_all_estimators.py @@ -127,7 +127,7 @@ def _generate_object_metadata(self, test_name, **kwargs): object_classes_to_test = [ est for est in self._all_objects() if not self.is_excluded(test_name, est) ] - object_names = [est.__name__ for est in object_classes_to_test] + object_names = [est.name() for est in object_classes_to_test] return object_classes_to_test, object_names @@ -156,21 +156,16 @@ def _generate_trainer_kwargs(self, test_name, **kwargs): trainer_kwargs: dict ranges over all kwargs for the trainer """ - # call _generate_object_class to get all the classes - object_meta_to_test, _ = self._generate_object_metadata(test_name=test_name) + if "object_metadata" in kwargs.keys(): + obj_meta = kwargs["object_metadata"] + else: + return [] - # create instances from the classes - train_kwargs_to_test = [] - train_kwargs_names = [] - # retrieve all object parameters if multiple, construct instances - for est in object_meta_to_test: - est_name = est.__name__ - all_train_kwargs = est.get_test_train_params() - train_kwargs_to_test += all_train_kwargs - rg = range(len(all_train_kwargs)) - train_kwargs_names += [f"{est_name}_{i}" for i in rg] + all_train_kwargs = obj_meta.get_test_train_params() + rg = range(len(all_train_kwargs)) + train_kwargs_names = [str(i) for i in rg] - return train_kwargs_to_test, train_kwargs_names + return all_train_kwargs, train_kwargs_names def _integration( @@ -278,11 +273,17 @@ def test_doctest_examples(self, object_class): doctest.run_docstring_examples(object_class, globals()) def test_integration( - self, object_class, trainer_kwargs, data_with_covariates, tmp_path + self, + object_metadata, + trainer_kwargs, + data_with_covariates, + tmp_path, ): """Fails for certain, for testing.""" from pytorch_forecasting.metrics import NegativeBinomialDistributionLoss + object_class = object_metadata.get_model_cls() + if "loss" in trainer_kwargs and isinstance( trainer_kwargs["loss"], NegativeBinomialDistributionLoss ): From 1c8d4b5c4fbf8dca91ec28e36e8781bb08a291bf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Sun, 23 Feb 2025 19:30:39 +0100 Subject: [PATCH 20/33] Update _conftest.py --- pytorch_forecasting/tests/_conftest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_forecasting/tests/_conftest.py b/pytorch_forecasting/tests/_conftest.py index e276446a6..20cad22c5 100644 --- a/pytorch_forecasting/tests/_conftest.py +++ b/pytorch_forecasting/tests/_conftest.py @@ -17,7 +17,7 @@ def gpus(): return 0 
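# Illustrative sketch, not part of the patches above or below: the new
# ``trainer_kwargs`` fixture pairs each ``object_metadata`` class with the
# parameter sets returned by its ``get_test_train_params`` classmethod.
# A hypothetical metadata class could expose something along these lines;
# the concrete values for DeepAR live in ``_deepar_metadata.py`` and may differ.
from pytorch_forecasting.metrics import NegativeBinomialDistributionLoss


class _ExampleMetadata:
    @classmethod
    def get_test_train_params(cls):
        """Return a list of kwargs dicts, one per integration test run."""
        return [
            dict(),  # default configuration
            # a loss that test_integration special-cases by rounding the target
            dict(loss=NegativeBinomialDistributionLoss()),
        ]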
-@pytest.fixture(scope="session") +@pytest.fixture(scope="package") def data_with_covariates(): data = get_stallion_data() data["month"] = data.date.dt.month.astype(str) From f632e32325a657fe975f9c76344eaba0585e17e3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Sun, 23 Feb 2025 19:37:01 +0100 Subject: [PATCH 21/33] try scenarios --- pytorch_forecasting/tests/_conftest.py | 2 +- pytorch_forecasting/tests/_data_scenarios.py | 261 ++++++++++++++++++ .../tests/test_all_estimators.py | 5 +- 3 files changed, 265 insertions(+), 3 deletions(-) create mode 100644 pytorch_forecasting/tests/_data_scenarios.py diff --git a/pytorch_forecasting/tests/_conftest.py b/pytorch_forecasting/tests/_conftest.py index 20cad22c5..e276446a6 100644 --- a/pytorch_forecasting/tests/_conftest.py +++ b/pytorch_forecasting/tests/_conftest.py @@ -17,7 +17,7 @@ def gpus(): return 0 -@pytest.fixture(scope="package") +@pytest.fixture(scope="session") def data_with_covariates(): data = get_stallion_data() data["month"] = data.date.dt.month.astype(str) diff --git a/pytorch_forecasting/tests/_data_scenarios.py b/pytorch_forecasting/tests/_data_scenarios.py new file mode 100644 index 000000000..062db97dd --- /dev/null +++ b/pytorch_forecasting/tests/_data_scenarios.py @@ -0,0 +1,261 @@ +import numpy as np +import pytest +import torch + +from pytorch_forecasting import TimeSeriesDataSet +from pytorch_forecasting.data import EncoderNormalizer, GroupNormalizer, NaNLabelEncoder +from pytorch_forecasting.data.examples import generate_ar_data, get_stallion_data + +torch.manual_seed(23) + + +@pytest.fixture(scope="session") +def gpus(): + if torch.cuda.is_available(): + return [0] + else: + return 0 + + +def data_with_covariates(): + data = get_stallion_data() + data["month"] = data.date.dt.month.astype(str) + data["log_volume"] = np.log1p(data.volume) + data["weight"] = 1 + np.sqrt(data.volume) + + data["time_idx"] = data["date"].dt.year * 12 + data["date"].dt.month + data["time_idx"] -= data["time_idx"].min() + + # convert special days into strings + special_days = [ + "easter_day", + "good_friday", + "new_year", + "christmas", + "labor_day", + "independence_day", + "revolution_day_memorial", + "regional_games", + "fifa_u_17_world_cup", + "football_gold_cup", + "beer_capital", + "music_fest", + ] + data[special_days] = ( + data[special_days].apply(lambda x: x.map({0: "", 1: x.name})).astype("category") + ) + data = data.astype(dict(industry_volume=float)) + + # select data subset + data = data[lambda x: x.sku.isin(data.sku.unique()[:2])][ + lambda x: x.agency.isin(data.agency.unique()[:2]) + ] + + # default target + data["target"] = data["volume"].clip(1e-3, 1.0) + + return data + + +def make_dataloaders(data_with_covariates, **kwargs): + training_cutoff = "2016-09-01" + max_encoder_length = 4 + max_prediction_length = 3 + + kwargs.setdefault("target", "volume") + kwargs.setdefault("group_ids", ["agency", "sku"]) + kwargs.setdefault("add_relative_time_idx", True) + kwargs.setdefault("time_varying_unknown_reals", ["volume"]) + + training = TimeSeriesDataSet( + data_with_covariates[lambda x: x.date < training_cutoff].copy(), + time_idx="time_idx", + max_encoder_length=max_encoder_length, + max_prediction_length=max_prediction_length, + **kwargs, # fixture parametrization + ) + + validation = TimeSeriesDataSet.from_dataset( + training, + data_with_covariates.copy(), + min_prediction_idx=training.index.time.max() + 1, + ) + train_dataloader = training.to_dataloader(train=True, batch_size=2, num_workers=0) + 
val_dataloader = validation.to_dataloader(train=False, batch_size=2, num_workers=0) + test_dataloader = validation.to_dataloader(train=False, batch_size=1, num_workers=0) + + return dict(train=train_dataloader, val=val_dataloader, test=test_dataloader) + + +@pytest.fixture( + params=[ + dict(), + dict( + static_categoricals=["agency", "sku"], + static_reals=["avg_population_2017", "avg_yearly_household_income_2017"], + time_varying_known_categoricals=["special_days", "month"], + variable_groups=dict( + special_days=[ + "easter_day", + "good_friday", + "new_year", + "christmas", + "labor_day", + "independence_day", + "revolution_day_memorial", + "regional_games", + "fifa_u_17_world_cup", + "football_gold_cup", + "beer_capital", + "music_fest", + ] + ), + time_varying_known_reals=[ + "time_idx", + "price_regular", + "price_actual", + "discount", + "discount_in_percent", + ], + time_varying_unknown_categoricals=[], + time_varying_unknown_reals=[ + "volume", + "log_volume", + "industry_volume", + "soda_volume", + "avg_max_temp", + ], + constant_fill_strategy={"volume": 0}, + categorical_encoders={"sku": NaNLabelEncoder(add_nan=True)}, + ), + dict(static_categoricals=["agency", "sku"]), + dict(randomize_length=True, min_encoder_length=2), + dict(target_normalizer=EncoderNormalizer(), min_encoder_length=2), + dict(target_normalizer=GroupNormalizer(transformation="log1p")), + dict( + target_normalizer=GroupNormalizer( + groups=["agency", "sku"], transformation="softplus", center=False + ) + ), + dict(target="agency"), + # test multiple targets + dict(target=["industry_volume", "volume"]), + dict(target=["agency", "volume"]), + dict( + target=["agency", "volume"], min_encoder_length=1, min_prediction_length=1 + ), + dict(target=["agency", "volume"], weight="volume"), + # test weights + dict(target="volume", weight="volume"), + ], + scope="session", +) +def multiple_dataloaders_with_covariates(data_with_covariates, request): + return make_dataloaders(data_with_covariates, **request.param) + + +@pytest.fixture(scope="session") +def dataloaders_with_different_encoder_decoder_length(data_with_covariates): + return make_dataloaders( + data_with_covariates.copy(), + target="target", + time_varying_known_categoricals=["special_days", "month"], + variable_groups=dict( + special_days=[ + "easter_day", + "good_friday", + "new_year", + "christmas", + "labor_day", + "independence_day", + "revolution_day_memorial", + "regional_games", + "fifa_u_17_world_cup", + "football_gold_cup", + "beer_capital", + "music_fest", + ] + ), + time_varying_known_reals=[ + "time_idx", + "price_regular", + "price_actual", + "discount", + "discount_in_percent", + ], + time_varying_unknown_categoricals=[], + time_varying_unknown_reals=[ + "target", + "volume", + "log_volume", + "industry_volume", + "soda_volume", + "avg_max_temp", + ], + static_categoricals=["agency"], + add_relative_time_idx=False, + target_normalizer=GroupNormalizer(groups=["agency", "sku"], center=False), + ) + + +@pytest.fixture(scope="session") +def dataloaders_with_covariates(data_with_covariates): + return make_dataloaders( + data_with_covariates.copy(), + target="target", + time_varying_known_reals=["discount"], + time_varying_unknown_reals=["target"], + static_categoricals=["agency"], + add_relative_time_idx=False, + target_normalizer=GroupNormalizer(groups=["agency", "sku"], center=False), + ) + + +@pytest.fixture(scope="session") +def dataloaders_multi_target(data_with_covariates): + return make_dataloaders( + data_with_covariates.copy(), + 
time_varying_unknown_reals=["target", "discount"], + target=["target", "discount"], + add_relative_time_idx=False, + ) + + +@pytest.fixture(scope="session") +def dataloaders_fixed_window_without_covariates(): + data = generate_ar_data(seasonality=10.0, timesteps=50, n_series=2) + validation = data.series.iloc[:2] + + max_encoder_length = 30 + max_prediction_length = 10 + + training = TimeSeriesDataSet( + data[lambda x: ~x.series.isin(validation)], + time_idx="time_idx", + target="value", + categorical_encoders={"series": NaNLabelEncoder().fit(data.series)}, + group_ids=["series"], + static_categoricals=[], + max_encoder_length=max_encoder_length, + max_prediction_length=max_prediction_length, + time_varying_unknown_reals=["value"], + target_normalizer=EncoderNormalizer(), + ) + + validation = TimeSeriesDataSet.from_dataset( + training, + data[lambda x: x.series.isin(validation)], + stop_randomization=True, + ) + batch_size = 2 + train_dataloader = training.to_dataloader( + train=True, batch_size=batch_size, num_workers=0 + ) + val_dataloader = validation.to_dataloader( + train=False, batch_size=batch_size, num_workers=0 + ) + test_dataloader = validation.to_dataloader( + train=False, batch_size=batch_size, num_workers=0 + ) + + return dict(train=train_dataloader, val=val_dataloader, test=test_dataloader) diff --git a/pytorch_forecasting/tests/test_all_estimators.py b/pytorch_forecasting/tests/test_all_estimators.py index c2d86e40e..b8a21cc6a 100644 --- a/pytorch_forecasting/tests/test_all_estimators.py +++ b/pytorch_forecasting/tests/test_all_estimators.py @@ -6,7 +6,6 @@ import lightning.pytorch as pl from lightning.pytorch.callbacks import EarlyStopping from lightning.pytorch.loggers import TensorBoardLogger -import pytest from skbase.testing import BaseFixtureGenerator as _BaseFixtureGenerator from pytorch_forecasting._registry import all_objects @@ -276,11 +275,13 @@ def test_integration( self, object_metadata, trainer_kwargs, - data_with_covariates, tmp_path, ): """Fails for certain, for testing.""" from pytorch_forecasting.metrics import NegativeBinomialDistributionLoss + from pytorch_forecasting.tests._data_scenarios import data_with_covariates + + data_with_covariates = data_with_covariates() object_class = object_metadata.get_model_cls() From a6691341001f813ac4c5d12aafb173645087d679 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Sun, 4 May 2025 18:28:04 +0200 Subject: [PATCH 22/33] Update _lookup.py --- pytorch_forecasting/_registry/_lookup.py | 49 ++++++++++++++++-------- 1 file changed, 32 insertions(+), 17 deletions(-) diff --git a/pytorch_forecasting/_registry/_lookup.py b/pytorch_forecasting/_registry/_lookup.py index 828c448b1..b4238f980 100644 --- a/pytorch_forecasting/_registry/_lookup.py +++ b/pytorch_forecasting/_registry/_lookup.py @@ -11,6 +11,7 @@ __author__ = ["fkiraly"] # all_objects is based on the sklearn utility all_estimators +from inspect import isclass from pathlib import Path from skbase.lookup import all_objects as _all_objects @@ -133,25 +134,39 @@ def all_objects( result = [] ROOT = str(Path(__file__).parent.parent) # package root directory - if isinstance(filter_tags, str): - filter_tags = {filter_tags: True} - filter_tags = filter_tags.copy() if filter_tags else None - - if object_types: - if filter_tags and "object_type" not in filter_tags.keys(): - object_tag_filter = {"object_type": object_types} - elif filter_tags: - filter_tags_filter = filter_tags.get("object_type", []) - if isinstance(object_types, str): - object_types = 
[object_types] - object_tag_update = {"object_type": object_types + filter_tags_filter} - filter_tags.update(object_tag_update) + def _coerce_to_str(obj): + if isinstance(obj, (list, tuple)): + return [_coerce_to_str(o) for o in obj] + if isclass(obj): + obj = obj.get_tag("object_type") + return obj + + def _coerce_to_list_of_str(obj): + obj = _coerce_to_str(obj) + if isinstance(obj, str): + return [obj] + return obj + + if object_types is not None: + object_types = _coerce_to_list_of_str(object_types) + object_types = list(set(object_types)) + + if object_types is not None: + if filter_tags is None: + filter_tags = {} + elif isinstance(filter_tags, str): + filter_tags = {filter_tags: True} else: - object_tag_filter = {"object_type": object_types} - if filter_tags: - filter_tags.update(object_tag_filter) + filter_tags = filter_tags.copy() + + if "object_type" in filter_tags: + obj_field = filter_tags["object_type"] + obj_field = _coerce_to_list_of_str(obj_field) + obj_field = obj_field + object_types else: - filter_tags = object_tag_filter + obj_field = object_types + + filter_tags["object_type"] = obj_field result = _all_objects( object_types=[_BaseObject], From d78bf5dc19cef1e659e8258552691c1713b2dd4f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Sun, 4 May 2025 18:32:38 +0200 Subject: [PATCH 23/33] Update _lookup.py --- pytorch_forecasting/_registry/_lookup.py | 95 +++++++++++++++--------- 1 file changed, 60 insertions(+), 35 deletions(-) diff --git a/pytorch_forecasting/_registry/_lookup.py b/pytorch_forecasting/_registry/_lookup.py index b4238f980..0fb4c0c9d 100644 --- a/pytorch_forecasting/_registry/_lookup.py +++ b/pytorch_forecasting/_registry/_lookup.py @@ -40,44 +40,64 @@ def all_objects( ---------- object_types: str, list of str, optional (default=None) Which kind of objects should be returned. - if None, no filter is applied and all objects are returned. - if str or list of str, strings define scitypes specified in search - only objects that are of (at least) one of the scitypes are returned - possible str values are entries of registry.BASE_CLASS_REGISTER (first col) - for instance 'regrssor_proba', 'distribution, 'metric' - return_names: bool, optional (default=True) + * if None, no filter is applied and all objects are returned. + * if str or list of str, strings define scitypes specified in search + only objects that are of (at least) one of the scitypes are returned - if True, estimator class name is included in the ``all_objects`` - return in the order: name, estimator class, optional tags, either as - a tuple or as pandas.DataFrame columns + return_names: bool, optional (default=True) - if False, estimator class name is removed from the ``all_objects`` return. + * if True, estimator class name is included in the ``all_objects`` + return in the order: name, estimator class, optional tags, either as + a tuple or as pandas.DataFrame columns + * if False, estimator class name is removed from the ``all_objects`` return. - filter_tags: dict of (str or list of str), optional (default=None) + filter_tags: dict of (str or list of str or re.Pattern), optional (default=None) For a list of valid tag strings, use the registry.all_tags utility. 
- ``filter_tags`` subsets the returned estimators as follows: + ``filter_tags`` subsets the returned objects as follows: * each key/value pair is statement in "and"/conjunction * key is tag name to sub-set on * value str or list of string are tag values * condition is "key must be equal to value, or in set(value)" - exclude_estimators: str, list of str, optional (default=None) - Names of estimators to exclude. + In detail, he return will be filtered to keep exactly the classes + where tags satisfy all the filter conditions specified by ``filter_tags``. + Filter conditions are as follows, for ``tag_name: search_value`` pairs in + the ``filter_tags`` dict, applied to a class ``klass``: + + - If ``klass`` does not have a tag with name ``tag_name``, it is excluded. + Otherwise, let ``tag_value`` be the value of the tag with name ``tag_name``. + - If ``search_value`` is a string, and ``tag_value`` is a string, + the filter condition is that ``search_value`` must match the tag value. + - If ``search_value`` is a string, and ``tag_value`` is a list, + the filter condition is that ``search_value`` is contained in ``tag_value``. + - If ``search_value`` is a ``re.Pattern``, and ``tag_value`` is a string, + the filter condition is that ``search_value.fullmatch(tag_value)`` + is true, i.e., the regex matches the tag value. + - If ``search_value`` is a ``re.Pattern``, and ``tag_value`` is a list, + the filter condition is that at least one element of ``tag_value`` + matches the regex. + - If ``search_value`` is iterable, then the filter condition is that + at least one element of ``search_value`` satisfies the above conditions, + applied to ``tag_value``. + + Note: ``re.Pattern`` is supported only from ``scikit-base`` version 0.8.0. + + exclude_objects: str, list of str, optional (default=None) + Names of objects to exclude. as_dataframe: bool, optional (default=False) - True: ``all_objects`` will return a pandas.DataFrame with named - columns for all of the attributes being returned. - - False: ``all_objects`` will return a list (either a list of - estimators or a list of tuples, see Returns) + * True: ``all_objects`` will return a ``pandas.DataFrame`` with named + columns for all of the attributes being returned. + * False: ``all_objects`` will return a list (either a list of + objects or a list of tuples, see Returns) return_tags: str or list of str, optional (default=None) Names of tags to fetch and return each estimator's value of. - For a list of valid tag strings, use the registry.all_tags utility. + For a list of valid tag strings, use the ``registry.all_tags`` utility. if str or list of str, the tag values named in return_tags will be fetched for each estimator and will be appended as either columns or tuple entries. @@ -88,27 +108,32 @@ def all_objects( Returns ------- all_objects will return one of the following: - 1. list of objects, if return_names=False, and return_tags is None - 2. list of tuples (optional object name, class, ~optional object - tags), if return_names=True or return_tags is not None. - 3. pandas.DataFrame if as_dataframe = True + + 1. list of objects, if ``return_names=False``, and ``return_tags`` is None + + 2. list of tuples (optional estimator name, class, optional estimator + tags), if ``return_names=True`` or ``return_tags`` is not ``None``. + + 3. 
``pandas.DataFrame`` if ``as_dataframe = True`` + if list of objects: entries are objects matching the query, - in alphabetical order of object name + in alphabetical order of estimator name + if list of tuples: - list of (optional object name, object, optional object - tags) matching the query, in alphabetical order of object name, + list of (optional estimator name, estimator, optional estimator + tags) matching the query, in alphabetical order of estimator name, where - ``name`` is the object name as string, and is an - optional return - ``object`` is the actual object - ``tags`` are the object's values for each tag in return_tags - and is an optional return. - if dataframe: - all_objects will return a pandas.DataFrame. + ``name`` is the estimator name as string, and is an + optional return + ``estimator`` is the actual estimator + ``tags`` are the estimator's values for each tag in return_tags + and is an optional return. + + if ``DataFrame``: column names represent the attributes contained in each column. "objects" will be the name of the column of objects, "names" - will be the name of the column of object class names and the string(s) + will be the name of the column of estimator class names and the string(s) passed in return_tags will serve as column names for all columns of tags that were optionally requested. From 3d2dafc73249fc3f6b5f6d461f0420d64ef111a1 Mon Sep 17 00:00:00 2001 From: Aryan Saini <116151399+phoeenniixx@users.noreply.github.com> Date: Tue, 13 May 2025 12:57:35 +0530 Subject: [PATCH 24/33] [ENH] EXPERIMENTAL PR: D1 and D2 layer for v2 refactor (#1811) ### Description This PR implements the basic skeleton of D1 and D2 layer for v2. See https://github.com/sktime/pytorch-forecasting/issues/1736 for discussion and design. --- pytorch_forecasting/data/__init__.py | 3 +- pytorch_forecasting/data/data_module.py | 709 ++++++++++++++++++ .../data/timeseries/__init__.py | 15 + .../_timeseries.py} | 21 +- .../data/timeseries/_timeseries_v2.py | 323 ++++++++ pytorch_forecasting/utils/_coerce.py | 25 + tests/test_data/test_d1.py | 379 ++++++++++ tests/test_data/test_data_module.py | 464 ++++++++++++ 8 files changed, 1918 insertions(+), 21 deletions(-) create mode 100644 pytorch_forecasting/data/data_module.py create mode 100644 pytorch_forecasting/data/timeseries/__init__.py rename pytorch_forecasting/data/{timeseries.py => timeseries/_timeseries.py} (99%) create mode 100644 pytorch_forecasting/data/timeseries/_timeseries_v2.py create mode 100644 pytorch_forecasting/utils/_coerce.py create mode 100644 tests/test_data/test_d1.py create mode 100644 tests/test_data/test_data_module.py diff --git a/pytorch_forecasting/data/__init__.py b/pytorch_forecasting/data/__init__.py index 301c8394d..17be285a0 100644 --- a/pytorch_forecasting/data/__init__.py +++ b/pytorch_forecasting/data/__init__.py @@ -13,10 +13,11 @@ TorchNormalizer, ) from pytorch_forecasting.data.samplers import TimeSynchronizedBatchSampler -from pytorch_forecasting.data.timeseries import TimeSeriesDataSet +from pytorch_forecasting.data.timeseries import TimeSeries, TimeSeriesDataSet __all__ = [ "TimeSeriesDataSet", + "TimeSeries", "NaNLabelEncoder", "GroupNormalizer", "TorchNormalizer", diff --git a/pytorch_forecasting/data/data_module.py b/pytorch_forecasting/data/data_module.py new file mode 100644 index 000000000..c8252014d --- /dev/null +++ b/pytorch_forecasting/data/data_module.py @@ -0,0 +1,709 @@ +####################################################################################### +# Disclaimer: This 
data-module is still work in progress and experimental, please +# use with care. This data-module is a basic skeleton of how the data-handling pipeline +# may look like in the future. +# This is D2 layer that will handle the preprocessing and data loaders. +# For now, this pipeline handles the simplest situation: The whole data can be loaded +# into the memory. +####################################################################################### + +from typing import Any, Dict, List, Optional, Tuple, Union +from warnings import warn + +from lightning.pytorch import LightningDataModule +from sklearn.preprocessing import RobustScaler, StandardScaler +import torch +from torch.utils.data import DataLoader, Dataset + +from pytorch_forecasting.data.encoders import ( + EncoderNormalizer, + NaNLabelEncoder, + TorchNormalizer, +) +from pytorch_forecasting.data.timeseries import TimeSeries +from pytorch_forecasting.utils._coerce import _coerce_to_dict + +NORMALIZER = Union[TorchNormalizer, NaNLabelEncoder, EncoderNormalizer] + + +class EncoderDecoderTimeSeriesDataModule(LightningDataModule): + """ + Lightning DataModule for processing time series data in an encoder-decoder format. + + This module handles preprocessing, splitting, and batching of time series data + for use in deep learning models. It supports categorical and continuous features, + various scalers, and automatic target normalization. + + Parameters + ---------- + time_series_dataset : TimeSeries + The dataset containing time series data. + max_encoder_length : int, default=30 + Maximum length of the encoder input sequence. + min_encoder_length : Optional[int], default=None + Minimum length of the encoder input sequence. + Defaults to `max_encoder_length` if not specified. + max_prediction_length : int, default=1 + Maximum length of the decoder output sequence. + min_prediction_length : Optional[int], default=None + Minimum length of the decoder output sequence. + Defaults to `max_prediction_length` if not specified. + min_prediction_idx : Optional[int], default=None + Minimum index from which predictions start. + allow_missing_timesteps : bool, default=False + Whether to allow missing timesteps in the dataset. + add_relative_time_idx : bool, default=False + Whether to add a relative time index feature. + add_target_scales : bool, default=False + Whether to add target scaling information. + add_encoder_length : Union[bool, str], default="auto" + Whether to include encoder length information. + target_normalizer : + Union[NORMALIZER, str, List[NORMALIZER], Tuple[NORMALIZER], None], + default="auto" + Normalizer for the target variable. If "auto", uses `RobustScaler`. + + categorical_encoders : Optional[Dict[str, NaNLabelEncoder]], default=None + Dictionary of categorical encoders. + + scalers : + Optional[Dict[str, Union[StandardScaler, RobustScaler, + TorchNormalizer, EncoderNormalizer]]], default=None + Dictionary of feature scalers. + + randomize_length : Union[None, Tuple[float, float], bool], default=False + Whether to randomize input sequence length. + batch_size : int, default=32 + Batch size for DataLoader. + num_workers : int, default=0 + Number of workers for DataLoader. + train_val_test_split : tuple, default=(0.7, 0.15, 0.15) + Proportions for train, validation, and test dataset splits. 
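    A minimal usage sketch (illustrative only, since this module is experimental;
    it assumes a pandas DataFrame ``df`` with an integer time index column
    ``"time_idx"``, a target column ``"volume"``, and a group column ``"agency"``):

        from pytorch_forecasting.data import TimeSeries
        from pytorch_forecasting.data.data_module import (
            EncoderDecoderTimeSeriesDataModule,
        )

        ds = TimeSeries(data=df, time="time_idx", target="volume", group=["agency"])
        dm = EncoderDecoderTimeSeriesDataModule(
            time_series_dataset=ds,
            max_encoder_length=12,
            max_prediction_length=6,
            batch_size=64,
        )
        dm.setup("fit")
        x, y = next(iter(dm.train_dataloader()))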
+ """ + + def __init__( + self, + time_series_dataset: TimeSeries, + max_encoder_length: int = 30, + min_encoder_length: Optional[int] = None, + max_prediction_length: int = 1, + min_prediction_length: Optional[int] = None, + min_prediction_idx: Optional[int] = None, + allow_missing_timesteps: bool = False, + add_relative_time_idx: bool = False, + add_target_scales: bool = False, + add_encoder_length: Union[bool, str] = "auto", + target_normalizer: Union[ + NORMALIZER, str, List[NORMALIZER], Tuple[NORMALIZER], None + ] = "auto", + categorical_encoders: Optional[Dict[str, NaNLabelEncoder]] = None, + scalers: Optional[ + Dict[ + str, + Union[StandardScaler, RobustScaler, TorchNormalizer, EncoderNormalizer], + ] + ] = None, + randomize_length: Union[None, Tuple[float, float], bool] = False, + batch_size: int = 32, + num_workers: int = 0, + train_val_test_split: tuple = (0.7, 0.15, 0.15), + ): + + self.time_series_dataset = time_series_dataset + self.max_encoder_length = max_encoder_length + self.min_encoder_length = min_encoder_length + self.max_prediction_length = max_prediction_length + self.min_prediction_length = min_prediction_length + self.min_prediction_idx = min_prediction_idx + self.allow_missing_timesteps = allow_missing_timesteps + self.add_relative_time_idx = add_relative_time_idx + self.add_target_scales = add_target_scales + self.add_encoder_length = add_encoder_length + self.randomize_length = randomize_length + self.target_normalizer = target_normalizer + self.categorical_encoders = categorical_encoders + self.scalers = scalers + self.batch_size = batch_size + self.num_workers = num_workers + self.train_val_test_split = train_val_test_split + + warn( + "TimeSeries is part of an experimental rework of the " + "pytorch-forecasting data layer, " + "scheduled for release with v2.0.0. " + "The API is not stable and may change without prior warning. " + "For beta testing, but not for stable production use. " + "Feedback and suggestions are very welcome in " + "pytorch-forecasting issue 1736, " + "https://github.com/sktime/pytorch-forecasting/issues/1736", + UserWarning, + ) + + super().__init__() + + # handle defaults and derived attributes + if isinstance(target_normalizer, str) and target_normalizer.lower() == "auto": + self._target_normalizer = RobustScaler() + else: + self._target_normalizer = target_normalizer + + self.time_series_metadata = time_series_dataset.get_metadata() + self._min_prediction_length = min_prediction_length or max_prediction_length + self._min_encoder_length = min_encoder_length or max_encoder_length + self._categorical_encoders = _coerce_to_dict(categorical_encoders) + self._scalers = _coerce_to_dict(scalers) + + self.categorical_indices = [] + self.continuous_indices = [] + self._metadata = None + + for idx, col in enumerate(self.time_series_metadata["cols"]["x"]): + if self.time_series_metadata["col_type"].get(col) == "C": + self.categorical_indices.append(idx) + else: + self.continuous_indices.append(idx) + + # overwrite __init__ params for upwards compatibility with AS PRs + # todo: should we avoid this and ensure classes are dataclass-like? + self.min_prediction_length = self._min_prediction_length + self.min_encoder_length = self._min_encoder_length + self.categorical_encoders = self._categorical_encoders + self.scalers = self._scalers + self.target_normalizer = self._target_normalizer + + def _prepare_metadata(self): + """Prepare metadata for model initialisation. 
+ + Returns + ------- + dict + dictionary containing the following keys: + + * ``encoder_cat``: Number of categorical variables in the encoder. + Computed as ``len(self.categorical_indices)``, which counts the + categorical feature indices. + * ``encoder_cont``: Number of continuous variables in the encoder. + Computed as ``len(self.continuous_indices)``, which counts the + continuous feature indices. + * ``decoder_cat``: Number of categorical variables in the decoder that + are known in advance. + Computed by filtering ``self.time_series_metadata["cols"]["x"]`` + where col_type == "C"(categorical) and col_known == "K" (known) + * ``decoder_cont``: Number of continuous variables in the decoder that + are known in advance. + Computed by filtering ``self.time_series_metadata["cols"]["x"]`` + where col_type == "F"(continuous) and col_known == "K"(known) + * ``target``: Number of target variables. + Computed as ``len(self.time_series_metadata["cols"]["y"])``, which + gives the number of output target columns.. + * ``static_categorical_features``: Number of static categorical features + Computed by filtering ``self.time_series_metadata["cols"]["st"]`` + (static features) where col_type == "C" (categorical). + * ``static_continuous_features``: Number of static continuous features + Computed as difference of + ``len(self.time_series_metadata["cols"]["st"])`` (static features) + and static_categorical_features that gives static continuous feature + * ``max_encoder_length``: maximum encoder length + Taken directly from `self.max_encoder_length`. + * ``max_prediction_length``: maximum prediction length + Taken directly from `self.max_prediction_length`. + * ``min_encoder_length``: minimum encoder length + Taken directly from `self.min_encoder_length`. + * ``min_prediction_length``: minimum prediction length + Taken directly from `self.min_prediction_length`. + """ + encoder_cat_count = len(self.categorical_indices) + encoder_cont_count = len(self.continuous_indices) + + decoder_cat_count = len( + [ + col + for col in self.time_series_metadata["cols"]["x"] + if self.time_series_metadata["col_type"].get(col) == "C" + and self.time_series_metadata["col_known"].get(col) == "K" + ] + ) + decoder_cont_count = len( + [ + col + for col in self.time_series_metadata["cols"]["x"] + if self.time_series_metadata["col_type"].get(col) == "F" + and self.time_series_metadata["col_known"].get(col) == "K" + ] + ) + + target_count = len(self.time_series_metadata["cols"]["y"]) + metadata = { + "encoder_cat": encoder_cat_count, + "encoder_cont": encoder_cont_count, + "decoder_cat": decoder_cat_count, + "decoder_cont": decoder_cont_count, + "target": target_count, + } + if self.time_series_metadata["cols"]["st"]: + static_cat_count = len( + [ + col + for col in self.time_series_metadata["cols"]["st"] + if self.time_series_metadata["col_type"].get(col) == "C" + ] + ) + static_cont_count = ( + len(self.time_series_metadata["cols"]["st"]) - static_cat_count + ) + + metadata["static_categorical_features"] = static_cat_count + metadata["static_continuous_features"] = static_cont_count + else: + metadata["static_categorical_features"] = 0 + metadata["static_continuous_features"] = 0 + + metadata.update( + { + "max_encoder_length": self.max_encoder_length, + "max_prediction_length": self.max_prediction_length, + "min_encoder_length": self._min_encoder_length, + "min_prediction_length": self._min_prediction_length, + } + ) + + return metadata + + @property + def metadata(self): + """Compute metadata for model initialization. 
+ + This property returns a dictionary containing the shapes and key information + related to the time series model. The metadata includes: + + * ``encoder_cat``: Number of categorical variables in the encoder. + * ``encoder_cont``: Number of continuous variables in the encoder. + * ``decoder_cat``: Number of categorical variables in the decoder that are + known in advance. + * ``decoder_cont``: Number of continuous variables in the decoder that are + known in advance. + * ``target``: Number of target variables. + + If static features are present, the following keys are added: + + * ``static_categorical_features``: Number of static categorical features + * ``static_continuous_features``: Number of static continuous features + + It also contains the following information: + + * ``max_encoder_length``: maximum encoder length + * ``max_prediction_length``: maximum prediction length + * ``min_encoder_length``: minimum encoder length + * ``min_prediction_length``: minimum prediction length + """ + if self._metadata is None: + self._metadata = self._prepare_metadata() + return self._metadata + + def _preprocess_data(self, series_idx: torch.Tensor) -> List[Dict[str, Any]]: + """Preprocess the data before feeding it into _ProcessedEncoderDecoderDataset. + + Preprocessing steps + -------------------- + + * Converts target (`y`) and features (`x`) to `torch.float32`. + * Masks time points that are at or before the cutoff time. + * Splits features into categorical and continuous subsets based on + predefined indices. + + + TODO: add scalers, target normalizers etc. + """ + sample = self.time_series_dataset[series_idx] + + target = sample["y"] + features = sample["x"] + times = sample["t"] + cutoff_time = sample["cutoff_time"] + + time_mask = torch.tensor(times <= cutoff_time, dtype=torch.bool) + + if isinstance(target, torch.Tensor): + target = target.float() + else: + target = torch.tensor(target, dtype=torch.float32) + + if isinstance(features, torch.Tensor): + features = features.float() + else: + features = torch.tensor(features, dtype=torch.float32) + + # TODO: add scalers, target normalizers etc. + + categorical = ( + features[:, self.categorical_indices] + if self.categorical_indices + else torch.zeros((features.shape[0], 0)) + ) + continuous = ( + features[:, self.continuous_indices] + if self.continuous_indices + else torch.zeros((features.shape[0], 0)) + ) + + return { + "features": {"categorical": categorical, "continuous": continuous}, + "target": target, + "static": sample.get("st", None), + "group": sample.get("group", torch.tensor([0])), + "length": len(target), + "time_mask": time_mask, + "times": times, + "cutoff_time": cutoff_time, + } + + class _ProcessedEncoderDecoderDataset(Dataset): + """PyTorch Dataset for processed encoder-decoder time series data. + + Parameters + ---------- + dataset : TimeSeries + The base time series dataset that provides access to raw data and metadata. + data_module : EncoderDecoderTimeSeriesDataModule + The data module handling preprocessing and metadata configuration. + windows : List[Tuple[int, int, int, int]] + List of window tuples containing + (series_idx, start_idx, enc_length, pred_length). + add_relative_time_idx : bool, default=False + Whether to include relative time indices. 
+ """ + + def __init__( + self, + dataset: TimeSeries, + data_module: "EncoderDecoderTimeSeriesDataModule", + windows: List[Tuple[int, int, int, int]], + add_relative_time_idx: bool = False, + ): + self.dataset = dataset + self.data_module = data_module + self.windows = windows + self.add_relative_time_idx = add_relative_time_idx + + def __len__(self): + return len(self.windows) + + def __getitem__(self, idx): + """Retrieve a processed time series window for dataloader input. + + x : dict + Dictionary containing model inputs: + + * ``encoder_cat`` : tensor of shape (enc_length, n_cat_features) + Categorical features for the encoder. + * ``encoder_cont`` : tensor of shape (enc_length, n_cont_features) + Continuous features for the encoder. + * ``decoder_cat`` : tensor of shape (pred_length, n_cat_features) + Categorical features for the decoder. + * ``decoder_cont`` : tensor of shape (pred_length, n_cont_features) + Continuous features for the decoder. + * ``encoder_lengths`` : tensor of shape (1,) + Length of the encoder sequence. + * ``decoder_lengths`` : tensor of shape (1,) + Length of the decoder sequence. + * ``decoder_target_lengths`` : tensor of shape (1,) + Length of the decoder target sequence. + * ``groups`` : tensor of shape (1,) + Group identifier for the time series instance. + * ``encoder_time_idx`` : tensor of shape (enc_length,) + Time indices for the encoder sequence. + * ``decoder_time_idx`` : tensor of shape (pred_length,) + Time indices for the decoder sequence. + * ``target_scale`` : tensor of shape (1,) + Scaling factor for the target values. + * ``encoder_mask`` : tensor of shape (enc_length,) + Boolean mask indicating valid encoder time points. + * ``decoder_mask`` : tensor of shape (pred_length,) + Boolean mask indicating valid decoder time points. + + If static features are present, the following keys are added: + + * ``static_categorical_features`` : tensor of shape + (1, n_static_cat_features), optional + Static categorical features, if available. + * ``static_continuous_features`` : tensor of shape (1, 0), optional + Placeholder for static continuous features (currently empty). + + y : tensor of shape ``(pred_length, n_targets)`` + Target values for the decoder sequence. 
+ """ + series_idx, start_idx, enc_length, pred_length = self.windows[idx] + data = self.data_module._preprocess_data(series_idx) + + end_idx = start_idx + enc_length + pred_length + encoder_indices = slice(start_idx, start_idx + enc_length) + decoder_indices = slice(start_idx + enc_length, end_idx) + + target_scale = data["target"][encoder_indices] + target_scale = target_scale[~torch.isnan(target_scale)].abs().mean() + if torch.isnan(target_scale) or target_scale == 0: + target_scale = torch.tensor(1.0) + + encoder_mask = ( + data["time_mask"][encoder_indices] + if "time_mask" in data + else torch.ones(enc_length, dtype=torch.bool) + ) + decoder_mask = ( + data["time_mask"][decoder_indices] + if "time_mask" in data + else torch.zeros(pred_length, dtype=torch.bool) + ) + + encoder_cat = data["features"]["categorical"][encoder_indices] + encoder_cont = data["features"]["continuous"][encoder_indices] + + features = data["features"] + metadata = self.data_module.time_series_metadata + + known_cat_indices = [ + i + for i, col in enumerate(metadata["cols"]["x"]) + if metadata["col_type"].get(col) == "C" + and metadata["col_known"].get(col) == "K" + ] + + known_cont_indices = [ + i + for i, col in enumerate(metadata["cols"]["x"]) + if metadata["col_type"].get(col) == "F" + and metadata["col_known"].get(col) == "K" + ] + + cat_map = { + orig_idx: i + for i, orig_idx in enumerate(self.data_module.categorical_indices) + } + cont_map = { + orig_idx: i + for i, orig_idx in enumerate(self.data_module.continuous_indices) + } + + mapped_known_cat_indices = [ + cat_map[idx] for idx in known_cat_indices if idx in cat_map + ] + mapped_known_cont_indices = [ + cont_map[idx] for idx in known_cont_indices if idx in cont_map + ] + + decoder_cat = ( + features["categorical"][decoder_indices][:, mapped_known_cat_indices] + if mapped_known_cat_indices + else torch.zeros((pred_length, 0)) + ) + + decoder_cont = ( + features["continuous"][decoder_indices][:, mapped_known_cont_indices] + if mapped_known_cont_indices + else torch.zeros((pred_length, 0)) + ) + + x = { + "encoder_cat": encoder_cat, + "encoder_cont": encoder_cont, + "decoder_cat": decoder_cat, + "decoder_cont": decoder_cont, + "encoder_lengths": torch.tensor(enc_length), + "decoder_lengths": torch.tensor(pred_length), + "decoder_target_lengths": torch.tensor(pred_length), + "groups": data["group"], + "encoder_time_idx": torch.arange(enc_length), + "decoder_time_idx": torch.arange(enc_length, enc_length + pred_length), + "target_scale": target_scale, + "encoder_mask": encoder_mask, + "decoder_mask": decoder_mask, + } + if data["static"] is not None: + x["static_categorical_features"] = data["static"].unsqueeze(0) + x["static_continuous_features"] = torch.zeros((1, 0)) + + y = data["target"][decoder_indices] + if y.ndim == 1: + y = y.unsqueeze(-1) + + return x, y + + def _create_windows(self, indices: torch.Tensor) -> List[Tuple[int, int, int, int]]: + """Generate sliding windows for training, validation, and testing. + + Returns + ------- + List[Tuple[int, int, int, int]] + A list of tuples, where each tuple consists of: + - ``series_idx`` : int + Index of the time series in `time_series_dataset`. + - ``start_idx`` : int + Start index of the encoder window. + - ``enc_length`` : int + Length of the encoder input sequence. + - ``pred_length`` : int + Length of the decoder output sequence. 
+ """ + windows = [] + + for idx in indices: + series_idx = idx.item() + sample = self.time_series_dataset[series_idx] + sequence_length = len(sample["y"]) + + if sequence_length < self.max_encoder_length + self.max_prediction_length: + continue + + effective_min_prediction_idx = ( + self.min_prediction_idx + if self.min_prediction_idx is not None + else self.max_encoder_length + ) + + max_prediction_idx = sequence_length - self.max_prediction_length + 1 + + if max_prediction_idx <= effective_min_prediction_idx: + continue + + for start_idx in range( + 0, max_prediction_idx - effective_min_prediction_idx + ): + if ( + start_idx + self.max_encoder_length + self.max_prediction_length + <= sequence_length + ): + windows.append( + ( + series_idx, + start_idx, + self.max_encoder_length, + self.max_prediction_length, + ) + ) + + return windows + + def setup(self, stage: Optional[str] = None): + """Prepare the datasets for training, validation, testing, or prediction. + + Parameters + ---------- + stage : Optional[str], default=None + Specifies the stage of setup. Can be one of: + - ``"fit"`` : Prepares training and validation datasets. + - ``"test"`` : Prepares the test dataset. + - ``"predict"`` : Prepares the dataset for inference. + - ``None`` : Prepares ``fit`` datasets. + """ + total_series = len(self.time_series_dataset) + self._split_indices = torch.randperm(total_series) + + self._train_size = int(self.train_val_test_split[0] * total_series) + self._val_size = int(self.train_val_test_split[1] * total_series) + + self._train_indices = self._split_indices[: self._train_size] + self._val_indices = self._split_indices[ + self._train_size : self._train_size + self._val_size + ] + self._test_indices = self._split_indices[self._train_size + self._val_size :] + + if stage is None or stage == "fit": + if not hasattr(self, "train_dataset") or not hasattr(self, "val_dataset"): + self.train_windows = self._create_windows(self._train_indices) + self.val_windows = self._create_windows(self._val_indices) + + self.train_dataset = self._ProcessedEncoderDecoderDataset( + self.time_series_dataset, + self, + self.train_windows, + self.add_relative_time_idx, + ) + self.val_dataset = self._ProcessedEncoderDecoderDataset( + self.time_series_dataset, + self, + self.val_windows, + self.add_relative_time_idx, + ) + + elif stage == "test": + if not hasattr(self, "test_dataset"): + self.test_windows = self._create_windows(self._test_indices) + self.test_dataset = self._ProcessedEncoderDecoderDataset( + self.time_series_dataset, + self, + self.test_windows, + self.add_relative_time_idx, + ) + elif stage == "predict": + predict_indices = torch.arange(len(self.time_series_dataset)) + self.predict_windows = self._create_windows(predict_indices) + self.predict_dataset = self._ProcessedEncoderDecoderDataset( + self.time_series_dataset, + self, + self.predict_windows, + self.add_relative_time_idx, + ) + + def train_dataloader(self): + return DataLoader( + self.train_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=True, + collate_fn=self.collate_fn, + ) + + def val_dataloader(self): + return DataLoader( + self.val_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + collate_fn=self.collate_fn, + ) + + def test_dataloader(self): + return DataLoader( + self.test_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + collate_fn=self.collate_fn, + ) + + def predict_dataloader(self): + return DataLoader( + self.predict_dataset, + batch_size=self.batch_size, 
+ num_workers=self.num_workers, + collate_fn=self.collate_fn, + ) + + @staticmethod + def collate_fn(batch): + x_batch = { + "encoder_cat": torch.stack([x["encoder_cat"] for x, _ in batch]), + "encoder_cont": torch.stack([x["encoder_cont"] for x, _ in batch]), + "decoder_cat": torch.stack([x["decoder_cat"] for x, _ in batch]), + "decoder_cont": torch.stack([x["decoder_cont"] for x, _ in batch]), + "encoder_lengths": torch.stack([x["encoder_lengths"] for x, _ in batch]), + "decoder_lengths": torch.stack([x["decoder_lengths"] for x, _ in batch]), + "decoder_target_lengths": torch.stack( + [x["decoder_target_lengths"] for x, _ in batch] + ), + "groups": torch.stack([x["groups"] for x, _ in batch]), + "encoder_time_idx": torch.stack([x["encoder_time_idx"] for x, _ in batch]), + "decoder_time_idx": torch.stack([x["decoder_time_idx"] for x, _ in batch]), + "target_scale": torch.stack([x["target_scale"] for x, _ in batch]), + "encoder_mask": torch.stack([x["encoder_mask"] for x, _ in batch]), + "decoder_mask": torch.stack([x["decoder_mask"] for x, _ in batch]), + } + + if "static_categorical_features" in batch[0][0]: + x_batch["static_categorical_features"] = torch.stack( + [x["static_categorical_features"] for x, _ in batch] + ) + x_batch["static_continuous_features"] = torch.stack( + [x["static_continuous_features"] for x, _ in batch] + ) + + y_batch = torch.stack([y for _, y in batch]) + return x_batch, y_batch diff --git a/pytorch_forecasting/data/timeseries/__init__.py b/pytorch_forecasting/data/timeseries/__init__.py new file mode 100644 index 000000000..788c08201 --- /dev/null +++ b/pytorch_forecasting/data/timeseries/__init__.py @@ -0,0 +1,15 @@ +"""Data loaders for time series data.""" + +from pytorch_forecasting.data.timeseries._timeseries import ( + TimeSeriesDataSet, + _find_end_indices, + check_for_nonfinite, +) +from pytorch_forecasting.data.timeseries._timeseries_v2 import TimeSeries + +__all__ = [ + "_find_end_indices", + "check_for_nonfinite", + "TimeSeriesDataSet", + "TimeSeries", +] diff --git a/pytorch_forecasting/data/timeseries.py b/pytorch_forecasting/data/timeseries/_timeseries.py similarity index 99% rename from pytorch_forecasting/data/timeseries.py rename to pytorch_forecasting/data/timeseries/_timeseries.py index 942a49721..30fe9e0bb 100644 --- a/pytorch_forecasting/data/timeseries.py +++ b/pytorch_forecasting/data/timeseries/_timeseries.py @@ -32,6 +32,7 @@ ) from pytorch_forecasting.data.samplers import TimeSynchronizedBatchSampler from pytorch_forecasting.utils import repr_class +from pytorch_forecasting.utils._coerce import _coerce_to_dict, _coerce_to_list from pytorch_forecasting.utils._dependencies import _check_matplotlib @@ -2663,23 +2664,3 @@ def __repr__(self) -> str: attributes=self.get_parameters(), extra_attributes=dict(length=len(self)), ) - - -def _coerce_to_list(obj): - """Coerce object to list. - - None is coerced to empty list, otherwise list constructor is used. - """ - if obj is None: - return [] - return list(obj) - - -def _coerce_to_dict(obj): - """Coerce object to dict. - - None is coerce to empty dict, otherwise deepcopy is used. - """ - if obj is None: - return {} - return deepcopy(obj) diff --git a/pytorch_forecasting/data/timeseries/_timeseries_v2.py b/pytorch_forecasting/data/timeseries/_timeseries_v2.py new file mode 100644 index 000000000..d5ecbcabb --- /dev/null +++ b/pytorch_forecasting/data/timeseries/_timeseries_v2.py @@ -0,0 +1,323 @@ +""" +Timeseries dataset - v2 prototype. 
+ +Beta version, experimental - use for testing but not in production. +""" + +from typing import Dict, List, Optional, Union +from warnings import warn + +import numpy as np +import pandas as pd +import torch +from torch.utils.data import Dataset + +from pytorch_forecasting.utils._coerce import _coerce_to_list + +####################################################################################### +# Disclaimer: This dataset class is still work in progress and experimental, please +# use with care. This class is a basic skeleton of how the data-handling pipeline may +# look like in the future. +# This is the D1 layer that is a "Raw Dataset Layer" mainly for raw data ingestion +# and turning the data to tensors. +# For now, this pipeline handles the simplest situation: The whole data can be loaded +# into the memory. +####################################################################################### + + +class TimeSeries(Dataset): + """PyTorch Dataset for time series data stored in pandas DataFrame. + + Parameters + ---------- + data : pd.DataFrame + data frame with sequence data. + Column names must all be str, and contain str as referred to below. + data_future : pd.DataFrame, optional, default=None + data frame with future data. + Column names must all be str, and contain str as referred to below. + May contain only columns that are in time, group, weight, known, or static. + time : str, optional, default = first col not in group_ids, weight, target, static. + integer typed column denoting the time index within ``data``. + This column is used to determine the sequence of samples. + If there are no missing observations, + the time index should increase by ``+1`` for each subsequent sample. + The first time_idx for each series does not necessarily + have to be ``0`` but any value is allowed. + target : str or List[str], optional, default = last column (at iloc -1) + column(s) in ``data`` denoting the forecasting target. + Can be categorical or numerical dtype. + group : List[str], optional, default = None + list of column names identifying a time series instance within ``data``. + This means that the ``group`` together uniquely identify an instance, + and ``group`` together with ``time`` uniquely identify a single observation + within a time series instance. + If ``None``, the dataset is assumed to be a single time series. + weight : str, optional, default=None + column name for weights. + If ``None``, it is assumed that there is no weight column. + num : list of str, optional, default = all columns with dtype in "fi" + list of numerical variables in ``data``, + list may also contain list of str, which are then grouped together. + cat : list of str, optional, default = all columns with dtype in "Obc" + list of categorical variables in ``data``, + list may also contain list of str, which are then grouped together + (e.g. useful for product categories). + known : list of str, optional, default = all variables + list of variables that change over time and are known in the future, + list may also contain list of str, which are then grouped together + (e.g. useful for special days or promotion categories). + unknown : list of str, optional, default = no variables + list of variables that are not known in the future, + list may also contain list of str, which are then grouped together + (e.g. useful for weather categories). 
+ static : list of str, optional, default = all variables not in known, unknown + list of variables that do not change over time, + list may also contain list of str, which are then grouped together. + """ + + def __init__( + self, + data: pd.DataFrame, + data_future: Optional[pd.DataFrame] = None, + time: Optional[str] = None, + target: Optional[Union[str, List[str]]] = None, + group: Optional[List[str]] = None, + weight: Optional[str] = None, + num: Optional[List[Union[str, List[str]]]] = None, + cat: Optional[List[Union[str, List[str]]]] = None, + known: Optional[List[Union[str, List[str]]]] = None, + unknown: Optional[List[Union[str, List[str]]]] = None, + static: Optional[List[Union[str, List[str]]]] = None, + ): + + self.data = data + self.data_future = data_future + self.time = time + self.target = target + self.group = group + self.weight = weight + self.num = num + self.cat = cat + self.known = known + self.unknown = unknown + self.static = static + + warn( + "TimeSeries is part of an experimental rework of the " + "pytorch-forecasting data layer, " + "scheduled for release with v2.0.0. " + "The API is not stable and may change without prior warning. " + "For beta testing, but not for stable production use. " + "Feedback and suggestions are very welcome in " + "pytorch-forecasting issue 1736, " + "https://github.com/sktime/pytorch-forecasting/issues/1736", + UserWarning, + ) + + super().__init__() + + # handle defaults, coercion, and derived attributes + self._target = _coerce_to_list(target) + self._group = _coerce_to_list(group) + self._num = _coerce_to_list(num) + self._cat = _coerce_to_list(cat) + self._known = _coerce_to_list(known) + self._unknown = _coerce_to_list(unknown) + self._static = _coerce_to_list(static) + + self.feature_cols = [ + col + for col in data.columns + if col not in [self.time] + self._group + [self.weight] + self._target + ] + if self._group: + self._groups = self.data.groupby(self._group).groups + self._group_ids = list(self._groups.keys()) + else: + self._groups = {"_single_group": self.data.index} + self._group_ids = ["_single_group"] + + self._prepare_metadata() + + # overwrite __init__ params for upwards compatibility with AS PRs + # todo: should we avoid this and ensure classes are dataclass-like? + self.group = self._group + self.target = self._target + self.num = self._num + self.cat = self._cat + self.known = self._known + self.unknown = self._unknown + self.static = self._static + + def _prepare_metadata(self): + """Prepare metadata for the dataset. + + The funcion returns metadata that contains: + + * ``cols``: dict { 'y': list[str], 'x': list[str], 'st': list[str] } + Names of columns for y, x, and static features. + List elements are in same order as column dimensions. + Columns not appearing are assumed to be named (x0, x1, etc.), + (y0, y1, etc.), (st0, st1, etc.). + * ``col_type``: dict[str, str] + maps column names to data types "F" (numerical) and "C" (categorical). + Column names not occurring are assumed "F". + * ``col_known``: dict[str, str] + maps column names to "K" (future known) or "U" (future unknown). + Column names not occurring are assumed "K". 
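        For illustration (hypothetical column names), with ``target="y"``, a
        numeric known feature ``"x1"``, a categorical unknown feature ``"c1"``,
        and no static columns, the resulting metadata would look roughly like:

            {
                "cols": {"y": ["y"], "x": ["x1", "c1"], "st": []},
                "col_type": {"y": "F", "x1": "F", "c1": "C"},
                "col_known": {"y": "U", "x1": "K", "c1": "U"},
            }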
+ """ + self.metadata = { + "cols": { + "y": self._target, + "x": self.feature_cols, + "st": self._static, + }, + "col_type": {}, + "col_known": {}, + } + + all_cols = self._target + self.feature_cols + self._static + for col in all_cols: + self.metadata["col_type"][col] = "C" if col in self._cat else "F" + + self.metadata["col_known"][col] = "K" if col in self._known else "U" + + def __len__(self) -> int: + """Return number of time series in the dataset.""" + return len(self._group_ids) + + def __getitem__(self, index: int) -> Dict[str, torch.Tensor]: + """Get time series data for given index. + + Returns + ------- + t : numpy.ndarray of shape (n_timepoints,) + Time index for each time point in the past or present. Aligned with `y`, + and `x` not ending in `f`. + + y : torch.Tensor of shape (n_timepoints, n_targets) + Target values for each time point. Rows are time points, aligned with `t`. + + x : torch.Tensor of shape (n_timepoints, n_features) + Features for each time point. Rows are time points, aligned with `t`. + + group : torch.Tensor of shape (n_groups,) + Group identifiers for time series instances. + + st : torch.Tensor of shape (n_static_features,) + Static features. + + cutoff_time : float or numpy.float64 + Cutoff time for the time series instance. + + Other Returns + ------------- + weights : torch.Tensor of shape (n_timepoints,), optional + Only included if weights are not `None`. + """ + time = self.time + feature_cols = self.feature_cols + _target = self._target + _known = self._known + _static = self._static + _group = self._group + _groups = self._groups + _group_ids = self._group_ids + weight = self.weight + data_future = self.data_future + + group_id = _group_ids[index] + + if _group: + mask = _groups[group_id] + data = self.data.loc[mask] + else: + data = self.data + + cutoff_time = data[time].max() + + data_vals = data[time].values + data_tgt_vals = data[_target].values + data_feat_vals = data[feature_cols].values + + result = { + "t": data_vals, + "y": torch.tensor(data_tgt_vals), + "x": torch.tensor(data_feat_vals), + "group": torch.tensor([hash(str(group_id))]), + "st": torch.tensor(data[_static].iloc[0].values if _static else []), + "cutoff_time": cutoff_time, + } + + if data_future is not None: + if _group: + future_mask = self.data_future.groupby(_group).groups[group_id] + future_data = self.data_future.loc[future_mask] + else: + future_data = self.data_future + + data_fut_vals = future_data[time].values + + combined_times = np.concatenate([data_vals, data_fut_vals]) + combined_times = np.unique(combined_times) + combined_times.sort() + + num_timepoints = len(combined_times) + x_merged = np.full((num_timepoints, len(feature_cols)), np.nan) + y_merged = np.full((num_timepoints, len(_target)), np.nan) + + current_time_indices = {t: i for i, t in enumerate(combined_times)} + for i, t in enumerate(data_vals): + idx = current_time_indices[t] + x_merged[idx] = data_feat_vals[i] + y_merged[idx] = data_tgt_vals[i] + + for i, t in enumerate(data_fut_vals): + if t in current_time_indices: + idx = current_time_indices[t] + for j, col in enumerate(_known): + if col in feature_cols: + feature_idx = feature_cols.index(col) + x_merged[idx, feature_idx] = future_data[col].values[i] + + result.update( + { + "t": combined_times, + "x": torch.tensor(x_merged, dtype=torch.float32), + "y": torch.tensor(y_merged, dtype=torch.float32), + } + ) + + if weight: + if self.data_future is not None and self.weight in self.data_future.columns: + weights_merged = np.full(num_timepoints, 
np.nan) + for i, t in enumerate(data_vals): + idx = current_time_indices[t] + weights_merged[idx] = data[weight].values[i] + + for i, t in enumerate(data_fut_vals): + if t in current_time_indices and self.weight in future_data.columns: + idx = current_time_indices[t] + weights_merged[idx] = future_data[weight].values[i] + + result["weights"] = torch.tensor(weights_merged, dtype=torch.float32) + else: + result["weights"] = torch.tensor( + data[self.weight].values, dtype=torch.float32 + ) + + return result + + def get_metadata(self) -> Dict: + """Return metadata about the dataset. + + Returns + ------- + Dict + Dictionary containing: + - cols: column names for y, x, and static features + - col_type: mapping of columns to their types (F/C) + - col_known: mapping of columns to their future known status (K/U) + """ + return self.metadata diff --git a/pytorch_forecasting/utils/_coerce.py b/pytorch_forecasting/utils/_coerce.py new file mode 100644 index 000000000..328431aa8 --- /dev/null +++ b/pytorch_forecasting/utils/_coerce.py @@ -0,0 +1,25 @@ +"""Coercion functions for various data types.""" + +from copy import deepcopy + + +def _coerce_to_list(obj): + """Coerce object to list. + + None is coerced to empty list, otherwise list constructor is used. + """ + if obj is None: + return [] + if isinstance(obj, str): + return [obj] + return list(obj) + + +def _coerce_to_dict(obj): + """Coerce object to dict. + + None is coerce to empty dict, otherwise deepcopy is used. + """ + if obj is None: + return {} + return deepcopy(obj) diff --git a/tests/test_data/test_d1.py b/tests/test_data/test_d1.py new file mode 100644 index 000000000..b32c13213 --- /dev/null +++ b/tests/test_data/test_d1.py @@ -0,0 +1,379 @@ +import numpy as np +import pandas as pd +import pytest +import torch + +from pytorch_forecasting.data.timeseries import TimeSeries + + +@pytest.fixture +def sample_data(): + """Create time series data for testing.""" + dates = pd.date_range(start="2023-01-01", periods=10, freq="D") + data = pd.DataFrame( + { + "timestamp": dates, + "target_value": np.sin(np.arange(10)) + 10, + "feature1": np.random.randn(10), + "feature2": np.random.randn(10), + "feature3": np.random.randn(10), + "group_id": [1, 1, 1, 1, 1, 2, 2, 2, 2, 2], + "weight": np.abs(np.random.randn(10)) + 0.1, + "static_feat": [10, 10, 10, 10, 10, 20, 20, 20, 20, 20], + } + ) + return data + + +@pytest.fixture +def future_data(): + """Create future time series data.""" + dates = pd.date_range(start="2023-01-11", periods=5, freq="D") + data = pd.DataFrame( + { + "timestamp": dates, + "feature1": np.random.randn(5), + "feature2": np.random.randn(5), + "feature3": np.random.randn(5), + "group_id": [1, 1, 1, 2, 2], + "weight": np.abs(np.random.randn(5)) + 0.1, + "static_feat": [10, 10, 10, 20, 20], + } + ) + return data + + +def test_init_basic(sample_data): + """Test basic initialization of TimeSeries class. + + Ensures that the class stores time, target, and correctly detects feature columns + when no group, known/unknown features, or static/weight features are specified.""" + ts = TimeSeries(data=sample_data, time="timestamp", target="target_value") + + assert ts.time == "timestamp" + assert ts.target == ["target_value"] + assert len(ts.feature_cols) == 6 # All columns except timestamp, target_value + assert len(ts) == 1 # Single group by default + + +def test_init_with_groups(sample_data): + """Test initialization with group parameter. + + Verifies that data is grouped correctly and each group is handled as a + separate time series. 
+ """ + ts = TimeSeries( + data=sample_data, time="timestamp", target="target_value", group=["group_id"] + ) + + assert ts.group == ["group_id"] + assert len(ts) == 2 # Two groups (1 and 2) + assert set(ts._group_ids) == {1, 2} + + +def test_init_with_features_categorization(sample_data): + """Test feature categorization. + + Ensures that numeric, categorical, and static features are categorized and + stored correctly in metadata.""" + ts = TimeSeries( + data=sample_data, + time="timestamp", + target="target_value", + num=["feature1", "feature2", "feature3"], + cat=[], + static=["static_feat"], + ) + + assert ts.num == ["feature1", "feature2", "feature3"] + assert ts.cat == [] + assert ts.static == ["static_feat"] + assert ts.metadata["col_type"]["feature1"] == "F" + assert ts.metadata["col_type"]["feature2"] == "F" + + +def test_init_with_known_unknown(sample_data): + """Test known and unknown features classification. + + Checks if the known and unknown feature categorization is correctly set + and stored in metadata.""" + ts = TimeSeries( + data=sample_data, + time="timestamp", + target="target_value", + known=["feature1"], + unknown=["feature2", "feature3"], + ) + + assert ts.known == ["feature1"] + assert ts.unknown == ["feature2", "feature3"] + assert ts.metadata["col_known"]["feature1"] == "K" + assert ts.metadata["col_known"]["feature2"] == "U" + + +def test_init_with_weight(sample_data): + """Test initialization with weight parameter. + + Verifies that the weight column is stored correctly and excluded + from the feature columns.""" + ts = TimeSeries( + data=sample_data, time="timestamp", target="target_value", weight="weight" + ) + + assert ts.weight == "weight" + assert "weight" not in ts.feature_cols + + +def test_getitem_basic(sample_data): + """Test __getitem__ with basic configuration. + + Checks the output structure of a single time series without grouping, + ensuring x, y are tensors of correct shapes.""" + ts = TimeSeries(data=sample_data, time="timestamp", target="target_value") + + result = ts[0] + assert torch.is_tensor(result["y"]) + assert torch.is_tensor(result["x"]) + assert "t" in result + assert "cutoff_time" in result + assert len(result["y"]) == 10 # 10 data points + assert result["y"].shape == (10, 1) # One target variable + assert result["x"].shape[1] == 6 # Six feature columns + + +def test_getitem_with_groups(sample_data): + """Test __getitem__ with groups parameter. + + Verifies the per-group access using index and checks that each group + has the correct number of time steps.""" + ts = TimeSeries( + data=sample_data, time="timestamp", target="target_value", group=["group_id"] + ) + + # group (1) + result_g1 = ts[0] + assert len(result_g1["t"]) == 5 # 5 data points in group 1 + + # group (2) + result_g2 = ts[1] + assert len(result_g2["t"]) == 5 # 5 data points in group 2 + + +def test_getitem_with_static(sample_data): + """Test __getitem__ with static features. + + Ensures static features are included in the output and correctly + mapped per group.""" + ts = TimeSeries( + data=sample_data, + time="timestamp", + target="target_value", + group=["group_id"], + static=["static_feat"], + ) + + result_g1 = ts[0] + result_g2 = ts[1] + + assert torch.is_tensor(result_g1["st"]) + assert result_g1["st"].item() == 10 # Static feature for group 1 + assert result_g2["st"].item() == 20 # Static feature for group 2 + + +def test_getitem_with_weight(sample_data): + """Test __getitem__ with weight parameter. 
+ + Validates that weights are correctly returned in the output and have the + expected length and type.""" + ts = TimeSeries( + data=sample_data, time="timestamp", target="target_value", weight="weight" + ) + + result = ts[0] + assert "weights" in result + assert torch.is_tensor(result["weights"]) + assert len(result["weights"]) == 10 + + +def test_with_future_data(sample_data, future_data): + """Test with future data provided. + + Verifies that future time steps are appended to the end of each group, + especially for known features.""" + ts = TimeSeries( + data=sample_data, + data_future=future_data, + time="timestamp", + target="target_value", + group=["group_id"], + known=["feature1"], + ) + + result_g1 = ts[0] # Group 1 + + assert len(result_g1["t"]) == 8 # 5 original + 3 future for group 1 + + feature1_idx = ts.feature_cols.index("feature1") + assert not torch.isnan( + result_g1["x"][-1, feature1_idx] + ) # feature1 is not NaN in last row + + +def test_future_data_with_weights(sample_data, future_data): + """Test handling of weights with future data. + + Ensures that weights from future data are combined properly and match the + time indices.""" + ts = TimeSeries( + data=sample_data, + data_future=future_data, + time="timestamp", + target="target_value", + group=["group_id"], + weight="weight", + ) + + result = ts[0] # Group 1 + assert "weights" in result + assert torch.is_tensor(result["weights"]) + assert len(result["weights"]) == len(result["t"]) + + +def test_future_data_missing_columns(sample_data): + """Test handling when future data is missing some columns. + + Verifies the handling of missing feature columns in future data by + checking NaN padding.""" + dates = pd.date_range(start="2023-01-11", periods=5, freq="D") + incomplete_future = pd.DataFrame( + { + "timestamp": dates, + "feature1": np.random.randn(5), + # Missing feature2, feature3 + "group_id": [1, 1, 1, 2, 2], + "weight": np.abs(np.random.randn(5)) + 0.1, + } + ) + + ts = TimeSeries( + data=sample_data, + data_future=incomplete_future, + time="timestamp", + target="target_value", + group=["group_id"], + known=["feature1"], + ) + + result = ts[0] + # Check that missing features are NaN in future timepoints + future_indices = np.where(result["t"] >= np.datetime64("2023-01-11"))[0] + feature2_idx = ts.feature_cols.index("feature2") + feature3_idx = ts.feature_cols.index("feature3") + assert torch.isnan(result["x"][future_indices[0], feature2_idx]) + assert torch.isnan(result["x"][future_indices[0], feature3_idx]) + + +def test_different_future_groups(sample_data): + """Test with future data that has different groups than original data. + + Ensures that groups present only in future data are ignored if not + in the original dataset.""" + dates = pd.date_range(start="2023-01-11", periods=5, freq="D") + future_with_new_group = pd.DataFrame( + { + "timestamp": dates, + "feature1": np.random.randn(5), + "feature2": np.random.randn(5), + "feature3": np.random.randn(5), + "group_id": [1, 1, 3, 3, 3], # Group 3 is new + "weight": np.abs(np.random.randn(5)) + 0.1, + "static_feat": [10, 10, 30, 30, 30], + } + ) + + ts = TimeSeries( + data=sample_data, + data_future=future_with_new_group, + time="timestamp", + target="target_value", + group=["group_id"], + ) + + # Original data has groups 1 and 2, but not 3 + assert len(ts) == 2 + assert 3 not in ts._group_ids + + +def test_multiple_targets(sample_data): + """Test handling of multiple target variables. 
+ + Verifies that multiple target columns are handled and returned + as the correct shape in the output.""" + sample_data["target_value2"] = np.cos(np.arange(10)) + 5 + + ts = TimeSeries( + data=sample_data, time="timestamp", target=["target_value", "target_value2"] + ) + + result = ts[0] + assert result["y"].shape == (10, 2) # Two target variables + + +def test_empty_groups(): + """Test handling of empty groups. + + Confirms that the class handles datasets with a single group and + no empty group errors occur.""" + data = pd.DataFrame( + { + "timestamp": pd.date_range(start="2023-01-01", periods=5, freq="D"), + "target_value": np.random.randn(5), + "group_id": [1, 1, 1, 1, 1], # Only one group + } + ) + + ts = TimeSeries( + data=data, time="timestamp", target="target_value", group=["group_id"] + ) + + assert len(ts) == 1 # Only one group + + +def test_metadata_structure(sample_data): + """Test the structure of metadata. + + Ensures the metadata dictionary includes the expected keys and + correct mappings of feature roles.""" + ts = TimeSeries( + data=sample_data, + time="timestamp", + target="target_value", + num=["feature1", "feature2", "feature3"], + cat=[], # No categorical features + static=["static_feat"], + known=["feature1"], + unknown=["feature2", "feature3"], + ) + + metadata = ts.get_metadata() + + assert "cols" in metadata + assert "col_type" in metadata + assert "col_known" in metadata + + assert metadata["cols"]["y"] == ["target_value"] + assert set(metadata["cols"]["x"]) == { + "feature1", + "feature2", + "feature3", + "group_id", + "weight", + "static_feat", + } + assert metadata["cols"]["st"] == ["static_feat"] + + assert metadata["col_type"]["feature1"] == "F" + assert metadata["col_type"]["feature2"] == "F" + + assert metadata["col_known"]["feature1"] == "K" + assert metadata["col_known"]["feature2"] == "U" diff --git a/tests/test_data/test_data_module.py b/tests/test_data/test_data_module.py new file mode 100644 index 000000000..4051b852c --- /dev/null +++ b/tests/test_data/test_data_module.py @@ -0,0 +1,464 @@ +import numpy as np +import pandas as pd +import pytest + +from pytorch_forecasting.data.data_module import EncoderDecoderTimeSeriesDataModule +from pytorch_forecasting.data.timeseries import TimeSeries + + +@pytest.fixture +def sample_timeseries_data(): + """Create a sample time series dataset with only numerical values.""" + num_groups = 10 + seq_length = 100 + + groups = [] + times = [] + values = [] + categorical_feature = [] + continuous_feature1 = [] + continuous_feature2 = [] + known_future = [] + + for g in range(num_groups): + for t in range(seq_length): + groups.append(g) + times.append(pd.Timestamp("2020-01-01") + pd.Timedelta(days=t)) + + value = 10 + 0.1 * t + 5 * np.sin(t / 10) + g * 2 + np.random.normal(0, 1) + values.append(value) + + categorical_feature.append(np.random.choice([0, 1, 2])) + + continuous_feature1.append(np.random.normal(g, 1)) + continuous_feature2.append(value * 0.5 + np.random.normal(0, 0.5)) + + known_future.append(t % 7) + + df = pd.DataFrame( + { + "group": groups, + "time": times, + "target": values, + "cat_feat": categorical_feature, + "cont_feat1": continuous_feature1, + "cont_feat2": continuous_feature2, + "known_future": known_future, + } + ) + + time_series = TimeSeries( + data=df, + time="time", + target="target", + group=["group"], + num=["cont_feat1", "cont_feat2", "known_future"], + cat=["cat_feat"], + known=["known_future"], + ) + + return time_series + + +@pytest.fixture +def data_module(sample_timeseries_data): + 
"""Create a data module instance.""" + dm = EncoderDecoderTimeSeriesDataModule( + time_series_dataset=sample_timeseries_data, + max_encoder_length=24, + max_prediction_length=12, + batch_size=4, + train_val_test_split=(0.7, 0.15, 0.15), + ) + return dm + + +def test_init(sample_timeseries_data): + """Test the initialization of the data module. + + Verifies hyperparameter assignment and basic time_series_metadata creation.""" + dm = EncoderDecoderTimeSeriesDataModule( + time_series_dataset=sample_timeseries_data, + max_encoder_length=24, + max_prediction_length=12, + batch_size=8, + ) + + assert dm.max_encoder_length == 24 + assert dm.max_prediction_length == 12 + assert dm.min_encoder_length == 24 + assert dm.min_prediction_length == 12 + assert dm.batch_size == 8 + assert dm.train_val_test_split == (0.7, 0.15, 0.15) + + assert isinstance(dm.time_series_metadata, dict) + assert "cols" in dm.time_series_metadata + + +def test_prepare_metadata(data_module): + """Test the metadata preparation method. + + Ensures that internal metadata keys are created correctly.""" + metadata = data_module._prepare_metadata() + + assert "encoder_cat" in metadata + assert "encoder_cont" in metadata + assert "decoder_cat" in metadata + assert "decoder_cont" in metadata + assert "target" in metadata + assert "max_encoder_length" in metadata + assert "max_prediction_length" in metadata + + assert metadata["max_encoder_length"] == 24 + assert metadata["max_prediction_length"] == 12 + + +def test_metadata_property(data_module): + """Test the metadata property. + + Confirms caching behavior and correct feature counts.""" + metadata = data_module.metadata + + # Should return the same object when called multiple times (caching) + assert data_module.metadata is metadata + + assert metadata["encoder_cat"] == 1 # cat_feat + assert metadata["encoder_cont"] == 3 # cont_feat1, cont_feat2, known_future + assert metadata["decoder_cat"] == 0 # No categorical features marked as known + assert metadata["decoder_cont"] == 1 # Only known_future marked as known + + +def test_setup(data_module): + """Test the setup method that prepares the datasets.""" + data_module.setup(stage="fit") + print(data_module._val_indices) + assert hasattr(data_module, "train_dataset") + assert hasattr(data_module, "val_dataset") + assert len(data_module.train_windows) > 0 + assert len(data_module.val_windows) > 0 + + data_module.setup(stage="test") + assert hasattr(data_module, "test_dataset") + assert len(data_module.test_windows) > 0 + + data_module.setup(stage="predict") + assert hasattr(data_module, "predict_dataset") + assert len(data_module.predict_windows) > 0 + + +def test_create_windows(data_module): + """Test the window creation logic. + + Validates window structure and length settings.""" + data_module.setup() + + windows = data_module._create_windows(data_module._train_indices) + + assert len(windows) > 0 + + for window in windows: + assert len(window) == 4 + assert window[2] == data_module.max_encoder_length + assert window[3] == data_module.max_prediction_length + + +def test_dataloader_creation(data_module): + """Test that dataloaders are created correctly. 
+ + Checks batch sizes and dataloader instantiation across all stages.""" + data_module.setup() + + train_loader = data_module.train_dataloader() + assert train_loader.batch_size == data_module.batch_size + assert train_loader.num_workers == data_module.num_workers + + val_loader = data_module.val_dataloader() + assert val_loader.batch_size == data_module.batch_size + + data_module.setup(stage="test") + test_loader = data_module.test_dataloader() + assert test_loader.batch_size == data_module.batch_size + + data_module.setup(stage="predict") + predict_loader = data_module.predict_dataloader() + assert predict_loader.batch_size == data_module.batch_size + + +def test_processed_dataset(data_module): + """Test the internal ProcessedEncoderDecoderDataset class. + + Verifies sample structure and tensor dimensions for encoder/decoder inputs.""" + data_module.setup() + + assert len(data_module.train_dataset) == len(data_module.train_windows) + assert len(data_module.val_dataset) == len(data_module.val_windows) + + x, y = data_module.train_dataset[0] + + required_keys = [ + "encoder_cat", + "encoder_cont", + "decoder_cat", + "decoder_cont", + "encoder_lengths", + "decoder_lengths", + "decoder_target_lengths", + "groups", + "encoder_time_idx", + "decoder_time_idx", + "target_scale", + "encoder_mask", + "decoder_mask", + ] + + for key in required_keys: + assert key in x + + assert x["encoder_cat"].shape[0] == data_module.max_encoder_length + assert x["decoder_cat"].shape[0] == data_module.max_prediction_length + + metadata = data_module.time_series_metadata + known_cat_count = len( + [ + col + for col in metadata["cols"]["x"] + if metadata["col_type"].get(col) == "C" + and metadata["col_known"].get(col) == "K" + ] + ) + + known_cont_count = len( + [ + col + for col in metadata["cols"]["x"] + if metadata["col_type"].get(col) == "F" + and metadata["col_known"].get(col) == "K" + ] + ) + + assert x["decoder_cat"].shape[1] == known_cat_count + assert x["decoder_cont"].shape[1] == known_cont_count + + assert y.shape[0] == data_module.max_prediction_length + + +def test_collate_fn(data_module): + """Test the collate function that combines batch samples. + + Ensures proper stacking of dictionary keys and batch outputs.""" + data_module.setup() + + batch_size = 3 + batch = [data_module.train_dataset[i] for i in range(batch_size)] + + x_batch, y_batch = data_module.collate_fn(batch) + + for key in x_batch: + assert x_batch[key].shape[0] == batch_size + + metadata = data_module.time_series_metadata + known_cat_count = len( + [ + col + for col in metadata["cols"]["x"] + if metadata["col_type"].get(col) == "C" + and metadata["col_known"].get(col) == "K" + ] + ) + + known_cont_count = len( + [ + col + for col in metadata["cols"]["x"] + if metadata["col_type"].get(col) == "F" + and metadata["col_known"].get(col) == "K" + ] + ) + + assert x_batch["decoder_cat"].shape[2] == known_cat_count + assert x_batch["decoder_cont"].shape[2] == known_cont_count + assert y_batch.shape[0] == batch_size + assert y_batch.shape[1] == data_module.max_prediction_length + + +def test_full_dataloader_iteration(data_module): + """Test a full iteration through the train dataloader. 
+ + Confirms batch retrieval and tensor dimensions match configuration.""" + data_module.setup() + train_loader = data_module.train_dataloader() + + batch = next(iter(train_loader)) + x_batch, y_batch = batch + + assert x_batch["encoder_cat"].shape[0] == data_module.batch_size + assert x_batch["encoder_cat"].shape[1] == data_module.max_encoder_length + + metadata = data_module.time_series_metadata + known_cat_count = len( + [ + col + for col in metadata["cols"]["x"] + if metadata["col_type"].get(col) == "C" + and metadata["col_known"].get(col) == "K" + ] + ) + + known_cont_count = len( + [ + col + for col in metadata["cols"]["x"] + if metadata["col_type"].get(col) == "F" + and metadata["col_known"].get(col) == "K" + ] + ) + + assert x_batch["decoder_cat"].shape[0] == data_module.batch_size + assert x_batch["decoder_cat"].shape[2] == known_cat_count + assert x_batch["decoder_cont"].shape[0] == data_module.batch_size + assert x_batch["decoder_cont"].shape[2] == known_cont_count + assert y_batch.shape[0] == data_module.batch_size + assert y_batch.shape[1] == data_module.max_prediction_length + + +def test_variable_encoder_lengths(sample_timeseries_data): + """Test with variable encoder lengths. + + Ensures random length behavior is respected and functional.""" + dm = EncoderDecoderTimeSeriesDataModule( + time_series_dataset=sample_timeseries_data, + max_encoder_length=24, + min_encoder_length=12, + max_prediction_length=12, + batch_size=4, + randomize_length=True, + ) + + dm.setup() + assert dm.min_encoder_length == 12 + assert dm.max_encoder_length == 24 + + +def test_preprocess_data(data_module, sample_timeseries_data): + """Test the _preprocess_data method. + + Checks preprocessing output structure and alignment with raw data.""" + if not hasattr(data_module, "_split_indices"): + data_module.setup() + + series_idx = data_module._train_indices[0] + + processed = data_module._preprocess_data(series_idx) + + assert "features" in processed + assert "categorical" in processed["features"] + assert "continuous" in processed["features"] + assert "target" in processed + assert "time_mask" in processed + + original_sample = sample_timeseries_data[series_idx.item()] + expected_length = len(original_sample["y"]) + + assert processed["features"]["categorical"].shape[0] == expected_length + assert processed["features"]["continuous"].shape[0] == expected_length + assert processed["target"].shape[0] == expected_length + + +def test_with_static_features(): + """Test with static features included. 
+ + Validates static feature support in both metadata and sample input.""" + df = pd.DataFrame( + { + "group": [0, 0, 0, 1, 1, 1], + "time": pd.date_range("2020-01-01", periods=6), + "target": [1, 2, 3, 4, 5, 6], + "static_cat": [0, 0, 0, 1, 1, 1], + "static_num": [10, 10, 10, 20, 20, 20], + "feature1": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], + } + ) + + ts = TimeSeries( + data=df, + time="time", + target="target", + group=["group"], + num=["feature1", "static_num"], + static=["static_cat", "static_num"], + cat=["static_cat"], + ) + + dm = EncoderDecoderTimeSeriesDataModule( + time_series_dataset=ts, + max_encoder_length=2, + max_prediction_length=1, + batch_size=2, + ) + + dm.setup() + + metadata = dm.metadata + assert metadata["static_categorical_features"] == 1 + assert metadata["static_continuous_features"] == 1 + + x, y = dm.train_dataset[0] + assert "static_categorical_features" in x + assert "static_continuous_features" in x + + +def test_different_train_val_test_split(sample_timeseries_data): + """Test with different train/val/test split ratios.""" + dm = EncoderDecoderTimeSeriesDataModule( + time_series_dataset=sample_timeseries_data, + max_encoder_length=24, + max_prediction_length=12, + batch_size=4, + train_val_test_split=(0.8, 0.1, 0.1), + ) + + dm.setup() + + total_series = len(sample_timeseries_data) + expected_train = int(0.8 * total_series) + expected_val = int(0.1 * total_series) + + assert len(dm._train_indices) == expected_train + assert len(dm._val_indices) == expected_val + assert len(dm._test_indices) == total_series - expected_train - expected_val + + +def test_multivariate_target(): + """Test with multivariate target (multiple target columns). + + Verifies correct handling of multivariate targets in data pipeline.""" + df = pd.DataFrame( + { + "group": np.repeat([0, 1], 50), + "time": np.tile(pd.date_range("2020-01-01", periods=50), 2), + "target1": np.random.normal(0, 1, 100), + "target2": np.random.normal(5, 2, 100), + "feature1": np.random.normal(0, 1, 100), + "feature2": np.random.normal(0, 1, 100), + } + ) + + ts = TimeSeries( + data=df, + time="time", + target=["target1", "target2"], + group=["group"], + num=["feature1", "feature2"], + ) + + dm = EncoderDecoderTimeSeriesDataModule( + time_series_dataset=ts, + max_encoder_length=10, + max_prediction_length=5, + batch_size=4, + ) + + dm.setup() + + x, y = dm.train_dataset[0] + assert y.shape[-1] == 2 From 15ea3c3d5584edca5f4332bc74490a7394e2460d Mon Sep 17 00:00:00 2001 From: Aryan Saini <116151399+phoeenniixx@users.noreply.github.com> Date: Sat, 17 May 2025 00:35:12 +0530 Subject: [PATCH 25/33] [BUG] EXPERIMENTAL PR: Solve the bug in `data_module` (#1834) This PR solves the bug in `data_module` where the `static_categorical_features` and `static_continuous_features` were not correctly calculated in `__getitem__` of nested class --- pytorch_forecasting/data/data_module.py | 36 +++++++++++++++++++++++-- tests/test_data/test_data_module.py | 8 ++++++ 2 files changed, 42 insertions(+), 2 deletions(-) diff --git a/pytorch_forecasting/data/data_module.py b/pytorch_forecasting/data/data_module.py index c8252014d..f6706275f 100644 --- a/pytorch_forecasting/data/data_module.py +++ b/pytorch_forecasting/data/data_module.py @@ -519,8 +519,40 @@ def __getitem__(self, idx): "decoder_mask": decoder_mask, } if data["static"] is not None: - x["static_categorical_features"] = data["static"].unsqueeze(0) - x["static_continuous_features"] = torch.zeros((1, 0)) + raw_st_tensor = data.get("static") + static_col_names = 
self.data_module.time_series_metadata["cols"]["st"] + + is_categorical_mask = torch.tensor( + [ + self.data_module.time_series_metadata["col_type"].get(col_name) + == "C" + for col_name in static_col_names + ], + dtype=torch.bool, + ) + + is_continuous_mask = ~is_categorical_mask + + st_cat_values_for_item = raw_st_tensor[is_categorical_mask] + st_cont_values_for_item = raw_st_tensor[is_continuous_mask] + + if st_cat_values_for_item.shape[0] > 0: + x["static_categorical_features"] = st_cat_values_for_item.unsqueeze( + 0 + ) + else: + x["static_categorical_features"] = torch.zeros( + (1, 0), dtype=torch.float32 + ) + + if st_cont_values_for_item.shape[0] > 0: + x["static_continuous_features"] = st_cont_values_for_item.unsqueeze( + 0 + ) + else: + x["static_continuous_features"] = torch.zeros( + (1, 0), dtype=torch.float32 + ) y = data["target"][decoder_indices] if y.ndim == 1: diff --git a/tests/test_data/test_data_module.py b/tests/test_data/test_data_module.py index 4051b852c..cad78aecd 100644 --- a/tests/test_data/test_data_module.py +++ b/tests/test_data/test_data_module.py @@ -405,6 +405,14 @@ def test_with_static_features(): x, y = dm.train_dataset[0] assert "static_categorical_features" in x assert "static_continuous_features" in x + assert ( + x["static_categorical_features"].shape[1] + == metadata["static_categorical_features"] + ) + assert ( + x["static_continuous_features"].shape[1] + == metadata["static_continuous_features"] + ) def test_different_train_val_test_split(sample_timeseries_data): From c04ebf378fd62bab949a04f98b4cf77bad8e10f9 Mon Sep 17 00:00:00 2001 From: ne0n Date: Fri, 16 May 2025 23:19:55 +0200 Subject: [PATCH 26/33] [BUG] fix incorrect concatenation dimension in `concat_sequences` (#1827) ### Description This PR fixes [1823](https://github.com/sktime/pytorch-forecasting/issues/1823) --- .../data/timeseries/_timeseries.py | 2 +- .../models/base/_base_model.py | 2 +- pytorch_forecasting/utils/_utils.py | 6 +-- .../test_temporal_fusion_transformer.py | 48 ++++++++++++++++++- 4 files changed, 52 insertions(+), 6 deletions(-) diff --git a/pytorch_forecasting/data/timeseries/_timeseries.py b/pytorch_forecasting/data/timeseries/_timeseries.py index 30fe9e0bb..f384367aa 100644 --- a/pytorch_forecasting/data/timeseries/_timeseries.py +++ b/pytorch_forecasting/data/timeseries/_timeseries.py @@ -2348,7 +2348,7 @@ def __getitem__(self, idx: int) -> tuple[dict[str, torch.Tensor], torch.Tensor]: @staticmethod def _collate_fn( - batches: list[tuple[dict[str, torch.Tensor], torch.Tensor]] + batches: list[tuple[dict[str, torch.Tensor], torch.Tensor]], ) -> tuple[dict[str, torch.Tensor], torch.Tensor]: """ Collate function to combine items into mini-batch for dataloader. 
diff --git a/pytorch_forecasting/models/base/_base_model.py b/pytorch_forecasting/models/base/_base_model.py index f7b14488f..1aa865dff 100644 --- a/pytorch_forecasting/models/base/_base_model.py +++ b/pytorch_forecasting/models/base/_base_model.py @@ -133,7 +133,7 @@ def _concatenate_output( str, List[Union[List[torch.Tensor], torch.Tensor, bool, int, str, np.ndarray]], ] - ] + ], ) -> Dict[ str, Union[torch.Tensor, np.ndarray, List[Union[torch.Tensor, int, bool, str]]] ]: diff --git a/pytorch_forecasting/utils/_utils.py b/pytorch_forecasting/utils/_utils.py index af93006cf..eb850b1e7 100644 --- a/pytorch_forecasting/utils/_utils.py +++ b/pytorch_forecasting/utils/_utils.py @@ -233,7 +233,7 @@ def autocorrelation(input, dim=0): def unpack_sequence( - sequence: Union[torch.Tensor, rnn.PackedSequence] + sequence: Union[torch.Tensor, rnn.PackedSequence], ) -> Tuple[torch.Tensor, torch.Tensor]: """ Unpack RNN sequence. @@ -257,7 +257,7 @@ def unpack_sequence( def concat_sequences( - sequences: Union[List[torch.Tensor], List[rnn.PackedSequence]] + sequences: Union[List[torch.Tensor], List[rnn.PackedSequence]], ) -> Union[torch.Tensor, rnn.PackedSequence]: """ Concatenate RNN sequences. @@ -272,7 +272,7 @@ def concat_sequences( if isinstance(sequences[0], rnn.PackedSequence): return rnn.pack_sequence(sequences, enforce_sorted=False) elif isinstance(sequences[0], torch.Tensor): - return torch.cat(sequences, dim=1) + return torch.cat(sequences, dim=0) elif isinstance(sequences[0], (tuple, list)): return tuple( concat_sequences([sequences[ii][i] for ii in range(len(sequences))]) diff --git a/tests/test_models/test_temporal_fusion_transformer.py b/tests/test_models/test_temporal_fusion_transformer.py index 24c249bd5..f0eab8671 100644 --- a/tests/test_models/test_temporal_fusion_transformer.py +++ b/tests/test_models/test_temporal_fusion_transformer.py @@ -10,9 +10,10 @@ import pytest import torch -from pytorch_forecasting import TimeSeriesDataSet +from pytorch_forecasting import Baseline, TimeSeriesDataSet from pytorch_forecasting.data import NaNLabelEncoder from pytorch_forecasting.data.encoders import GroupNormalizer, MultiNormalizer +from pytorch_forecasting.data.examples import generate_ar_data from pytorch_forecasting.metrics import ( CrossEntropy, MQF2DistributionLoss, @@ -521,3 +522,48 @@ def test_no_exogenous_variable(): return_y=True, return_index=True, ) + + +def test_correct_prediction_concatenation(): + data = generate_ar_data(seasonality=10.0, timesteps=100, n_series=2, seed=42) + data["static"] = 2 + data["date"] = pd.Timestamp("2020-01-01") + pd.to_timedelta(data.time_idx, "D") + data.head() + + # create dataset and dataloaders + max_encoder_length = 20 + max_prediction_length = 5 + + training_cutoff = data["time_idx"].max() - max_prediction_length + + context_length = max_encoder_length + prediction_length = max_prediction_length + + training = TimeSeriesDataSet( + data[lambda x: x.time_idx <= training_cutoff], + time_idx="time_idx", + target="value", + categorical_encoders={"series": NaNLabelEncoder().fit(data.series)}, + group_ids=["series"], + # only unknown variable is "value" + # and N-Beats can also not take any additional variables + time_varying_unknown_reals=["value"], + max_encoder_length=context_length, + max_prediction_length=prediction_length, + ) + + batch_size = 71 + train_dataloader = training.to_dataloader( + train=True, batch_size=batch_size, num_workers=0 + ) + + baseline_model = Baseline() + predictions = baseline_model.predict( + train_dataloader, + 
return_x=True, + return_y=True, + trainer_kwargs=dict(logger=None, accelerator="cpu"), + ) + + # The predicted output and the target should have the same size. + assert predictions.output.size() == predictions.y[0].size() From 524d05b05eb38836bb3118f85574a73f32ced832 Mon Sep 17 00:00:00 2001 From: Aryan Saini <116151399+phoeenniixx@users.noreply.github.com> Date: Sun, 18 May 2025 13:20:52 +0530 Subject: [PATCH 27/33] [ENH] EXPERIMENTAL PR: make the `data_module` dataclass-like (#1832) This PR makes the `data_modulel` dataclass-like See discussion in #1829 --- pytorch_forecasting/data/data_module.py | 8 -------- tests/test_data/test_data_module.py | 4 ++-- 2 files changed, 2 insertions(+), 10 deletions(-) diff --git a/pytorch_forecasting/data/data_module.py b/pytorch_forecasting/data/data_module.py index f6706275f..ec3c11de9 100644 --- a/pytorch_forecasting/data/data_module.py +++ b/pytorch_forecasting/data/data_module.py @@ -163,14 +163,6 @@ def __init__( else: self.continuous_indices.append(idx) - # overwrite __init__ params for upwards compatibility with AS PRs - # todo: should we avoid this and ensure classes are dataclass-like? - self.min_prediction_length = self._min_prediction_length - self.min_encoder_length = self._min_encoder_length - self.categorical_encoders = self._categorical_encoders - self.scalers = self._scalers - self.target_normalizer = self._target_normalizer - def _prepare_metadata(self): """Prepare metadata for model initialisation. diff --git a/tests/test_data/test_data_module.py b/tests/test_data/test_data_module.py index cad78aecd..b5f2067b5 100644 --- a/tests/test_data/test_data_module.py +++ b/tests/test_data/test_data_module.py @@ -86,8 +86,8 @@ def test_init(sample_timeseries_data): assert dm.max_encoder_length == 24 assert dm.max_prediction_length == 12 - assert dm.min_encoder_length == 24 - assert dm.min_prediction_length == 12 + assert dm._min_encoder_length == 24 + assert dm._min_prediction_length == 12 assert dm.batch_size == 8 assert dm.train_val_test_split == (0.7, 0.15, 0.15) From b82b42a25181c740f161b25c2c058f7a8a6ee08f Mon Sep 17 00:00:00 2001 From: Pranav Bhat Date: Thu, 22 May 2025 10:56:01 +0530 Subject: [PATCH 28/33] add initial version of tests for tide --- tests/test_models/test_tide.py | 202 +++++++++++++++++++++++++++++++++ 1 file changed, 202 insertions(+) create mode 100644 tests/test_models/test_tide.py diff --git a/tests/test_models/test_tide.py b/tests/test_models/test_tide.py new file mode 100644 index 000000000..7931a15e0 --- /dev/null +++ b/tests/test_models/test_tide.py @@ -0,0 +1,202 @@ +import pickle +import shutil + +import lightning.pytorch as pl +from lightning.pytorch.callbacks import EarlyStopping +from lightning.pytorch.loggers import TensorBoardLogger +import numpy as np +import pandas as pd +import pytest + +from pytorch_forecasting.data.timeseries import TimeSeriesDataSet +from pytorch_forecasting.metrics import MAE, SMAPE, QuantileLoss +from pytorch_forecasting.models import TiDEModel +from pytorch_forecasting.utils._dependencies import _get_installed_packages + + +def _integration(dataloader, tmp_path, trainer_kwargs=None, **kwargs): + train_dataloader = dataloader["train"] + val_dataloader = dataloader["val"] + test_dataloader = dataloader["test"] + + early_stop_callback = EarlyStopping( + monitor="val_loss", min_delta=1e-4, patience=1, verbose=False, mode="min" + ) + + logger = TensorBoardLogger(tmp_path) + if trainer_kwargs is None: + trainer_kwargs = {} + trainer = pl.Trainer( + max_epochs=2, + gradient_clip_val=0.1, + 
callbacks=[early_stop_callback], + enable_checkpointing=True, + default_root_dir=tmp_path, + limit_train_batches=2, + limit_val_batches=2, + limit_test_batches=2, + logger=logger, + **trainer_kwargs, + ) + + kwargs.setdefault("hidden_size", 16) + kwargs.setdefault("temporal_decoder_hidden", 8) + kwargs.setdefault("temporal_width_future", 4) + kwargs.setdefault("dropout", 0.1) + kwargs.setdefault("learning_rate", 0.01) + + net = TiDEModel.from_dataset( + train_dataloader.dataset, + **kwargs, + ) + net.size() + try: + trainer.fit( + net, + train_dataloaders=train_dataloader, + val_dataloaders=val_dataloader, + ) + test_outputs = trainer.test(net, dataloaders=test_dataloader) + assert len(test_outputs) > 0 + # check loading + net = TiDEModel.load_from_checkpoint( + trainer.checkpoint_callback.best_model_path + ) + + # check prediction + net.predict( + val_dataloader, + fast_dev_run=True, + return_index=True, + return_decoder_lengths=True, + trainer_kwargs=trainer_kwargs, + ) + finally: + shutil.rmtree(tmp_path, ignore_errors=True) + + predictions = net.predict( + val_dataloader, + fast_dev_run=True, + return_index=True, + return_decoder_lengths=True, + ) + return predictions + + +@pytest.mark.parametrize( + "kwargs", + [ + {}, + {"loss": SMAPE()}, + {"hidden_size": 32, "temporal_decoder_hidden": 16}, + {"dropout": 0.2, "use_layer_norm": True}, + ], +) +def test_integration(dataloaders_with_covariates, tmp_path, kwargs): + _integration(dataloaders_with_covariates, tmp_path, **kwargs) + + +@pytest.mark.parametrize( + "kwargs", + [ + {}, + ], +) +def test_multi_target_integration(dataloaders_multi_target, tmp_path, kwargs): + _integration(dataloaders_multi_target, tmp_path, **kwargs) + + +@pytest.fixture +def model(dataloaders_with_covariates): + dataset = dataloaders_with_covariates["train"].dataset + net = TiDEModel.from_dataset( + dataset, + hidden_size=16, + dropout=0.1, + temporal_width_future=4, + ) + return net + + +def test_pickle(model): + pkl = pickle.dumps(model) + pickle.loads(pkl) # noqa: S301 + + +@pytest.mark.skipif( + "matplotlib" not in _get_installed_packages(), + reason="skip test if required package matplotlib not installed", +) +def test_prediction_visualization(model, dataloaders_with_covariates): + raw_predictions = model.predict( + dataloaders_with_covariates["val"], + mode="raw", + return_x=True, + fast_dev_run=True, + ) + model.plot_prediction(raw_predictions.x, raw_predictions.output, idx=0) + + +def test_prediction_with_kwargs(model, dataloaders_with_covariates): + # Tests prediction works with different keyword arguments + model.predict( + dataloaders_with_covariates["val"], return_index=True, fast_dev_run=True + ) + model.predict( + dataloaders_with_covariates["val"], + return_x=True, + return_y=True, + fast_dev_run=True, + ) + + +def test_no_exogenous_variable(): + data = pd.DataFrame( + { + "target": np.ones(1600), + "group_id": np.repeat(np.arange(16), 100), + "time_idx": np.tile(np.arange(100), 16), + } + ) + training_dataset = TimeSeriesDataSet( + data=data, + time_idx="time_idx", + target="target", + group_ids=["group_id"], + max_encoder_length=10, + max_prediction_length=5, + time_varying_unknown_reals=["target"], + time_varying_known_reals=[], + ) + validation_dataset = TimeSeriesDataSet.from_dataset( + training_dataset, data, stop_randomization=True, predict=True + ) + training_data_loader = training_dataset.to_dataloader( + train=True, batch_size=8, num_workers=0 + ) + validation_data_loader = validation_dataset.to_dataloader( + train=False, batch_size=8, 
num_workers=0 + ) + forecaster = TiDEModel.from_dataset( + training_dataset, + ) + from lightning.pytorch import Trainer + + trainer = Trainer( + max_epochs=2, + limit_train_batches=8, + limit_val_batches=8, + ) + trainer.fit( + forecaster, + train_dataloaders=training_data_loader, + val_dataloaders=validation_data_loader, + ) + best_model_path = trainer.checkpoint_callback.best_model_path + best_model = TiDEModel.load_from_checkpoint(best_model_path) + best_model.predict( + validation_data_loader, + return_x=True, + return_y=True, + return_index=True, + ) From e46f9f6b078911ed90dd8e1b45dcd0aacfd0ded0 Mon Sep 17 00:00:00 2001 From: Pranav Bhat Date: Thu, 22 May 2025 11:45:33 +0530 Subject: [PATCH 29/33] refactor _integration to TiDE specific _integration function --- tests/test_models/test_tide.py | 107 +++++++++++++-------------------- 1 file changed, 43 insertions(+), 64 deletions(-) diff --git a/tests/test_models/test_tide.py b/tests/test_models/test_tide.py index 7931a15e0..c461d1158 100644 --- a/tests/test_models/test_tide.py +++ b/tests/test_models/test_tide.py @@ -11,76 +11,54 @@ from pytorch_forecasting.data.timeseries import TimeSeriesDataSet from pytorch_forecasting.metrics import MAE, SMAPE, QuantileLoss from pytorch_forecasting.models import TiDEModel +from pytorch_forecasting.tests.test_all_estimators import _integration from pytorch_forecasting.utils._dependencies import _get_installed_packages -def _integration(dataloader, tmp_path, trainer_kwargs=None, **kwargs): - train_dataloader = dataloader["train"] - val_dataloader = dataloader["val"] - test_dataloader = dataloader["test"] +def _tide_integration(dataloaders, tmp_path, trainer_kwargs=None, **kwargs): + """TiDE specific wrapper around the common integration test function. - early_stop_callback = EarlyStopping( - monitor="val_loss", min_delta=1e-4, patience=1, verbose=False, mode="min" - ) + Args: + dataloaders: Dictionary of dataloaders for train, val, and test. + tmp_path: Temporary path for saving the model. + trainer_kwargs: Additional arguments for the Trainer. + **kwargs: Additional arguments for the TiDEModel. - logger = TensorBoardLogger(tmp_path) - if trainer_kwargs is None: - trainer_kwargs = {} - trainer = pl.Trainer( - max_epochs=2, - gradient_clip_val=0.1, - callbacks=[early_stop_callback], - enable_checkpointing=True, - default_root_dir=tmp_path, - limit_train_batches=2, - limit_val_batches=2, - limit_test_batches=2, - logger=logger, - **trainer_kwargs, - ) + Returns: + Predictions from the trained model. 
+ """ - kwargs.setdefault("hidden_size", 16) - kwargs.setdefault("temporal_decoder_hidden", 8) - kwargs.setdefault("temporal_width_future", 4) - kwargs.setdefault("dropout", 0.1) - kwargs.setdefault("learning_rate", 0.01) + from pytorch_forecasting.tests._data_scenarios import data_with_covariates - net = TiDEModel.from_dataset( - train_dataloader.dataset, - **kwargs, - ) - net.size() - try: - trainer.fit( - net, - train_dataloaders=train_dataloader, - val_dataloaders=val_dataloader, - ) - test_outputs = trainer.test(net, dataloaders=test_dataloader) - assert len(test_outputs) > 0 - # check loading - net = TiDEModel.load_from_checkpoint( - trainer.checkpoint_callback.best_model_path - ) - - # check prediction - net.predict( - val_dataloader, - fast_dev_run=True, - return_index=True, - return_decoder_lengths=True, - trainer_kwargs=trainer_kwargs, - ) - finally: - shutil.rmtree(tmp_path, ignore_errors=True) - - predictions = net.predict( - val_dataloader, - fast_dev_run=True, - return_index=True, - return_decoder_lengths=True, + df = data_with_covariates() + + tide_kwargs = { + "temporal_decoder_hidden": 8, + "temporal_width_future": 4, + "droupout": 0.1, + } + + tide_kwargs.update(kwargs) + train_dataset = dataloaders["train"].dataset + + data_loader_kwargs = { + "target": train_dataset.target, + "group_ids": train_dataset.group_ids, + "time_varying_known_reals": train_dataset.time_varying_known_reals, + "time_varying_unknown_reals": train_dataset.time_varying_unknown_reals, + "static_categoricals": train_dataset.static_categoricals, + "static_reals": train_dataset.static_reals, + "add_relative_time_idx": train_dataset.add_relative_time_idx, + } + + return _integration( + TiDEModel, + df, + tmp_path, + data_loader_kwargs=data_loader_kwargs, + trainer_kwargs=trainer_kwargs, + **tide_kwargs, ) - return predictions @pytest.mark.parametrize( @@ -93,7 +71,7 @@ def _integration(dataloader, tmp_path, trainer_kwargs=None, **kwargs): ], ) def test_integration(dataloaders_with_covariates, tmp_path, kwargs): - _integration(dataloaders_with_covariates, tmp_path, **kwargs) + _tide_integration(dataloaders_with_covariates, tmp_path, **kwargs) @pytest.mark.parametrize( @@ -103,7 +81,7 @@ def test_integration(dataloaders_with_covariates, tmp_path, kwargs): ], ) def test_multi_target_integration(dataloaders_multi_target, tmp_path, kwargs): - _integration(dataloaders_multi_target, tmp_path, **kwargs) + _tide_integration(dataloaders_multi_target, tmp_path, **kwargs) @pytest.fixture @@ -196,6 +174,7 @@ def test_no_exogenous_variable(): best_model = TiDEModel.load_from_checkpoint(best_model_path) best_model.predict( validation_data_loader, + fast_dev_run=True, return_x=True, return_y=True, return_index=True, From 6dfe1a8dd2f3e72b22e3c59487c395b1f7e8275f Mon Sep 17 00:00:00 2001 From: Pranav Bhat Date: Thu, 22 May 2025 14:58:29 +0530 Subject: [PATCH 30/33] remove model-specific params from _integration in test_all_estimators --- pytorch_forecasting/tests/test_all_estimators.py | 3 --- tests/test_models/test_tide.py | 6 ++---- 2 files changed, 2 insertions(+), 7 deletions(-) diff --git a/pytorch_forecasting/tests/test_all_estimators.py b/pytorch_forecasting/tests/test_all_estimators.py index b8a21cc6a..529cd7bb4 100644 --- a/pytorch_forecasting/tests/test_all_estimators.py +++ b/pytorch_forecasting/tests/test_all_estimators.py @@ -171,7 +171,6 @@ def _integration( estimator_cls, data_with_covariates, tmp_path, - cell_type="LSTM", data_loader_kwargs={}, clip_target: bool = False, trainer_kwargs=None, @@ -221,11 
+220,9 @@ def _integration( net = estimator_cls.from_dataset( train_dataloader.dataset, hidden_size=5, - cell_type=cell_type, learning_rate=0.01, log_gradient_flow=True, log_interval=1000, - n_plotting_samples=100, **kwargs, ) net.size() diff --git a/tests/test_models/test_tide.py b/tests/test_models/test_tide.py index c461d1158..3a5073c7a 100644 --- a/tests/test_models/test_tide.py +++ b/tests/test_models/test_tide.py @@ -27,7 +27,6 @@ def _tide_integration(dataloaders, tmp_path, trainer_kwargs=None, **kwargs): Returns: Predictions from the trained model. """ - from pytorch_forecasting.tests._data_scenarios import data_with_covariates df = data_with_covariates() @@ -35,7 +34,7 @@ def _tide_integration(dataloaders, tmp_path, trainer_kwargs=None, **kwargs): tide_kwargs = { "temporal_decoder_hidden": 8, "temporal_width_future": 4, - "droupout": 0.1, + "dropout": 0.1, } tide_kwargs.update(kwargs) @@ -50,7 +49,6 @@ def _tide_integration(dataloaders, tmp_path, trainer_kwargs=None, **kwargs): "static_reals": train_dataset.static_reals, "add_relative_time_idx": train_dataset.add_relative_time_idx, } - return _integration( TiDEModel, df, @@ -66,7 +64,7 @@ def _tide_integration(dataloaders, tmp_path, trainer_kwargs=None, **kwargs): [ {}, {"loss": SMAPE()}, - {"hidden_size": 32, "temporal_decoder_hidden": 16}, + {"temporal_decoder_hidden": 16}, {"dropout": 0.2, "use_layer_norm": True}, ], ) From f23d4d14e6c64da050ec8452f46c1b347fc5fcd4 Mon Sep 17 00:00:00 2001 From: Pranav Bhat Date: Tue, 27 May 2025 16:14:03 +0530 Subject: [PATCH 31/33] add metadata class for tide --- .../models/tide/_tide_metadata.py | 53 +++++++++++++++++++ 1 file changed, 53 insertions(+) create mode 100644 pytorch_forecasting/models/tide/_tide_metadata.py diff --git a/pytorch_forecasting/models/tide/_tide_metadata.py b/pytorch_forecasting/models/tide/_tide_metadata.py new file mode 100644 index 000000000..35caacbfe --- /dev/null +++ b/pytorch_forecasting/models/tide/_tide_metadata.py @@ -0,0 +1,53 @@ +"""TiDE metadata container.""" + +from pytorch_forecasting.models.base._base_object import _BasePtForecaster + + +class TiDEModelMetadata(_BasePtForecaster): + """Metadata container for TiDE Model.""" + + _tags = { + "info:name": "TiDEModel", + "info:compute": 3, + "authors": ["Sohaib-Ahmed21"], + "capability:exogenous": True, + "capability:multivariate": True, + "capability:pred_int": True, + "capability:flexible_history_length": True, + "capability:cold_start": False, + } + + @classmethod + def get_model_cls(cls): + """Get model class.""" + from pytorch_forecasting.models.tide import TiDEModel + + return TiDEModel + + @classmethod + def get_test_train_params(cls): + """Return testing parameter settings for the trainer. + + Returns + ------- + params : dict or list of dict, default = {} + Parameters to create testing instances of the class. 
+ """ + + from pytorch_forecasting.data.encoders import GroupNormalizer + from pytorch_forecasting.metrics import SMAPE + + return [ + {}, + {"temporal_decoder_hidden": 16}, + { + "dropout": 0.2, + "use_layer_norm": True, + "loss": SMAPE(), + "data_loader_kwargs": dict( + target_normalizer=GroupNormalizer( + groups=["group"], transformation="softplus" + ) + ), + }, + ] From 228c2f143d2c2462cfb731b59e6f842278850e86 Mon Sep 17 00:00:00 2001 From: Pranav Bhat Date: Tue, 27 May 2025 16:18:19 +0530 Subject: [PATCH 32/33] add TiDEModelMetadata to __init__.py --- pytorch_forecasting/models/tide/__init__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pytorch_forecasting/models/tide/__init__.py b/pytorch_forecasting/models/tide/__init__.py index 0f265a153..35bf89611 100644 --- a/pytorch_forecasting/models/tide/__init__.py +++ b/pytorch_forecasting/models/tide/__init__.py @@ -1,9 +1,11 @@ """Tide model.""" from pytorch_forecasting.models.tide._tide import TiDEModel +from pytorch_forecasting.models.tide._tide_metadata import TiDEModelMetadata from pytorch_forecasting.models.tide.sub_modules import _TideModule __all__ = [ "_TideModule", "TiDEModel", + "TiDEModelMetadata", ] From 9be3f11ae96302b0e3bef90e83290af36bb40683 Mon Sep 17 00:00:00 2001 From: Pranav Bhat Date: Tue, 27 May 2025 20:14:43 +0530 Subject: [PATCH 33/33] fixed model-specific changes to provide test compatibility to TiDE --- .../models/tide/_tide_metadata.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/pytorch_forecasting/models/tide/_tide_metadata.py b/pytorch_forecasting/models/tide/_tide_metadata.py index 35caacbfe..30821c20f 100644 --- a/pytorch_forecasting/models/tide/_tide_metadata.py +++ b/pytorch_forecasting/models/tide/_tide_metadata.py @@ -38,16 +38,26 @@ def get_test_train_params(cls): from pytorch_forecasting.metrics import SMAPE return [ - {}, - {"temporal_decoder_hidden": 16}, + { + "data_loader_kwargs": dict( + add_relative_time_idx=False, + # must include this everytime since the data_loader_default_kwargs + # include this to be True. + ) + }, + { + "temporal_decoder_hidden": 16, + "data_loader_kwargs": dict(add_relative_time_idx=False), + }, { "dropout": 0.2, "use_layer_norm": True, "loss": SMAPE(), "data_loader_kwargs": dict( target_normalizer=GroupNormalizer( - groups=["group"], transformation="softplus" - ) + groups=["agency", "sku"], transformation="softplus" + ), + add_relative_time_idx=False, ), }, ]
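Taken together, the metadata container and the parameter sets it returns let the generic test suite discover a model class and its test configurations without hard-coding either one. A rough sketch of how such a container could be consumed; the loop is illustrative only (the real test framework discovers classes via the registry), and the tiny dataset mirrors the one built in `test_no_exogenous_variable` above:

```python
import numpy as np
import pandas as pd

from pytorch_forecasting.data.timeseries import TimeSeriesDataSet
from pytorch_forecasting.models.tide import TiDEModelMetadata

# tiny univariate dataset, mirroring test_no_exogenous_variable
data = pd.DataFrame(
    {
        "target": np.ones(1600),
        "group_id": np.repeat(np.arange(16), 100),
        "time_idx": np.tile(np.arange(100), 16),
    }
)
training_dataset = TimeSeriesDataSet(
    data=data,
    time_idx="time_idx",
    target="target",
    group_ids=["group_id"],
    max_encoder_length=10,
    max_prediction_length=5,
    time_varying_unknown_reals=["target"],
    time_varying_known_reals=[],
)

model_cls = TiDEModelMetadata.get_model_cls()  # resolves to TiDEModel lazily

for params in TiDEModelMetadata.get_test_train_params():
    params = dict(params)  # copy so the shared parameter list is not mutated
    # data-loader settings are consumed by the test harness, not the model
    data_loader_kwargs = params.pop("data_loader_kwargs", {})
    net = model_cls.from_dataset(training_dataset, **params)
    print(type(net).__name__, sorted(data_loader_kwargs))
```

Keeping `data_loader_kwargs` inside the parameter dicts, as the metadata class does, means a single list can describe both model hyperparameters and the dataset variations they need, at the cost of the harness having to split the two before calling `from_dataset`.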