From b3644a65b041a790b94756fb1d9bbf2797236d0a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Sat, 22 Feb 2025 23:18:51 +0100 Subject: [PATCH 01/26] test suite --- pyproject.toml | 1 + pytorch_forecasting/tests/__init__.py | 1 + 2 files changed, 2 insertions(+) create mode 100644 pytorch_forecasting/tests/__init__.py diff --git a/pyproject.toml b/pyproject.toml index f3d1e339c..8c661db1c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -102,6 +102,7 @@ dev = [ "pytest-dotenv>=0.5.2,<1.0.0", "tensorboard>=2.12.1,<3.0.0", "pandoc>=2.3,<3.0.0", + "scikit-base", ] # docs - dependencies for building the documentation diff --git a/pytorch_forecasting/tests/__init__.py b/pytorch_forecasting/tests/__init__.py new file mode 100644 index 000000000..6c2d26856 --- /dev/null +++ b/pytorch_forecasting/tests/__init__.py @@ -0,0 +1 @@ +"""PyTorch Forecasting test suite.""" From 4b2486e083ca93d8f4c1a29a6a25d882027815f3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Sat, 22 Feb 2025 23:33:29 +0100 Subject: [PATCH 02/26] skeleton --- pytorch_forecasting/_registry/__init__.py | 16 ++ pytorch_forecasting/_registry/_lookup.py | 214 ++++++++++++++++++ pytorch_forecasting/models/base/__init__.py | 2 + .../models/base/_base_object.py | 8 + pytorch_forecasting/tests/_config.py | 13 ++ .../tests/test_all_estimators.py | 118 ++++++++++ 6 files changed, 371 insertions(+) create mode 100644 pytorch_forecasting/_registry/__init__.py create mode 100644 pytorch_forecasting/_registry/_lookup.py create mode 100644 pytorch_forecasting/models/base/_base_object.py create mode 100644 pytorch_forecasting/tests/_config.py create mode 100644 pytorch_forecasting/tests/test_all_estimators.py diff --git a/pytorch_forecasting/_registry/__init__.py b/pytorch_forecasting/_registry/__init__.py new file mode 100644 index 000000000..bb0b88e61 --- /dev/null +++ b/pytorch_forecasting/_registry/__init__.py @@ -0,0 +1,16 @@ +"""PyTorch Forecasting registry.""" + +from 
pytorch_forecasting._registry._lookup import all_objects, all_tags +from pytorch_forecasting._registry._tags import ( + OBJECT_TAG_LIST, + OBJECT_TAG_REGISTER, + check_tag_is_valid, +) + +__all__ = [ + "OBJECT_TAG_LIST", + "OBJECT_TAG_REGISTER", + "all_objects", + "all_tags", + "check_tag_is_valid", +] diff --git a/pytorch_forecasting/_registry/_lookup.py b/pytorch_forecasting/_registry/_lookup.py new file mode 100644 index 000000000..ea3210cb0 --- /dev/null +++ b/pytorch_forecasting/_registry/_lookup.py @@ -0,0 +1,214 @@ +"""Registry lookup methods. + +This module exports the following methods for registry lookup: + +all_objects(object_types, filter_tags) + lookup and filtering of objects +""" +# copyright: skpro developers, BSD-3-Clause License (see LICENSE file) +# based on the sktime module of same name + +__author__ = ["fkiraly"] +# all_objects is based on the sklearn utility all_estimators + + +from copy import deepcopy +from operator import itemgetter +from pathlib import Path + +import pandas as pd +from skbase.lookup import all_objects as _all_objects + +from pytorch_forecasting.models.base import _BaseObject + + +def all_objects( + object_types=None, + filter_tags=None, + exclude_objects=None, + return_names=True, + as_dataframe=False, + return_tags=None, + suppress_import_stdout=True, +): + """Get a list of all objects from pytorch_forecasting. + + This function crawls the module and gets all classes that inherit + from skbase compatible base classes. + + Not included are: the base classes themselves, classes defined in test + modules. + + Parameters + ---------- + object_types: str, list of str, optional (default=None) + Which kind of objects should be returned. + if None, no filter is applied and all objects are returned. 
+ if str or list of str, strings define scitypes specified in search + only objects that are of (at least) one of the scitypes are returned + possible str values are entries of registry.BASE_CLASS_REGISTER (first col) + for instance 'regressor_proba', 'distribution', 'metric' + + return_names: bool, optional (default=True) + + if True, estimator class name is included in the ``all_objects`` + return in the order: name, estimator class, optional tags, either as + a tuple or as pandas.DataFrame columns + + if False, estimator class name is removed from the ``all_objects`` return. + + filter_tags: dict of (str or list of str), optional (default=None) + For a list of valid tag strings, use the registry.all_tags utility. + + ``filter_tags`` subsets the returned estimators as follows: + + * each key/value pair is a statement in "and"/conjunction + * key is tag name to sub-set on + * value, a str or list of str, gives the tag values + * condition is "key must be equal to value, or in set(value)" + + exclude_objects: str, list of str, optional (default=None) + Names of estimators to exclude. + + as_dataframe: bool, optional (default=False) + + True: ``all_objects`` will return a pandas.DataFrame with named + columns for all of the attributes being returned. + + False: ``all_objects`` will return a list (either a list of + estimators or a list of tuples, see Returns) + + return_tags: str or list of str, optional (default=None) + Names of tags to fetch and return each estimator's value of. + For a list of valid tag strings, use the registry.all_tags utility. + if str or list of str, + the tag values named in return_tags will be fetched for each + estimator and will be appended as either columns or tuple entries. + + suppress_import_stdout : bool, optional. Default=True + whether to suppress stdout printout upon import. + + Returns + ------- + all_objects will return one of the following: + 1. list of objects, if return_names=False, and return_tags is None + 2.
list of tuples (optional object name, class, ~optional object + tags), if return_names=True or return_tags is not None. + 3. pandas.DataFrame if as_dataframe = True + if list of objects: + entries are objects matching the query, + in alphabetical order of object name + if list of tuples: + list of (optional object name, object, optional object + tags) matching the query, in alphabetical order of object name, + where + ``name`` is the object name as string, and is an + optional return + ``object`` is the actual object + ``tags`` are the object's values for each tag in return_tags + and is an optional return. + if dataframe: + all_objects will return a pandas.DataFrame. + column names represent the attributes contained in each column. + "objects" will be the name of the column of objects, "names" + will be the name of the column of object class names and the string(s) + passed in return_tags will serve as column names for all columns of + tags that were optionally requested. + + Examples + -------- + >>> from skpro.registry import all_objects + >>> # return a complete list of objects as pd.Dataframe + >>> all_objects(as_dataframe=True) # doctest: +SKIP + >>> # return all probabilistic regressors by filtering for object type + >>> all_objects("regressor_proba", as_dataframe=True) # doctest: +SKIP + >>> # return all regressors which handle missing data in the input by tag filtering + >>> all_objects( + ... "regressor_proba", + ... filter_tags={"capability:missing": True}, + ... as_dataframe=True + ... 
) # doctest: +SKIP + + References + ---------- + Adapted version of sktime's ``all_estimators``, + which is an evolution of scikit-learn's ``all_estimators`` + """ + MODULES_TO_IGNORE = ( + "tests", + "setup", + "contrib", + "utils", + "all", + ) + + result = [] + ROOT = str(Path(__file__).parent.parent) # skpro package root directory + + if isinstance(filter_tags, str): + filter_tags = {filter_tags: True} + filter_tags = filter_tags.copy() if filter_tags else None + + if object_types: + if filter_tags and "object_type" not in filter_tags.keys(): + object_tag_filter = {"object_type": object_types} + elif filter_tags: + filter_tags_filter = filter_tags.get("object_type", []) + if isinstance(object_types, str): + object_types = [object_types] + object_tag_update = {"object_type": object_types + filter_tags_filter} + filter_tags.update(object_tag_update) + else: + object_tag_filter = {"object_type": object_types} + if filter_tags: + filter_tags.update(object_tag_filter) + else: + filter_tags = object_tag_filter + + result = _all_objects( + object_types=[_BaseObject], + filter_tags=filter_tags, + exclude_objects=exclude_objects, + return_names=return_names, + as_dataframe=as_dataframe, + return_tags=return_tags, + suppress_import_stdout=suppress_import_stdout, + package_name="skpro", + path=ROOT, + modules_to_ignore=MODULES_TO_IGNORE, + ) + + return result + + +def _check_list_of_str_or_error(arg_to_check, arg_name): + """Check that certain arguments are str or list of str. + + Parameters + ---------- + arg_to_check: argument we are testing the type of + arg_name: str, + name of the argument we are testing, will be added to the error if + ``arg_to_check`` is not a str or a list of str + + Returns + ------- + arg_to_check: list of str, + if arg_to_check was originally a str it converts it into a list of str + so that it can be iterated over. 
+ + Raises + ------ + TypeError if arg_to_check is not a str or list of str + """ + # check that return_tags has the right type: + if isinstance(arg_to_check, str): + arg_to_check = [arg_to_check] + if not isinstance(arg_to_check, list) or not all( + isinstance(value, str) for value in arg_to_check + ): + raise TypeError( + f"Error in all_objects! Argument {arg_name} must be either\ + a str or list of str" + ) + return arg_to_check diff --git a/pytorch_forecasting/models/base/__init__.py b/pytorch_forecasting/models/base/__init__.py index 4860e4838..474e7d564 100644 --- a/pytorch_forecasting/models/base/__init__.py +++ b/pytorch_forecasting/models/base/__init__.py @@ -7,8 +7,10 @@ BaseModelWithCovariates, Prediction, ) +from pytorch_forecasting.models.base._base_object import _BaseObject __all__ = [ + "_BaseObject", "AutoRegressiveBaseModel", "AutoRegressiveBaseModelWithCovariates", "BaseModel", diff --git a/pytorch_forecasting/models/base/_base_object.py b/pytorch_forecasting/models/base/_base_object.py new file mode 100644 index 000000000..7330867b1 --- /dev/null +++ b/pytorch_forecasting/models/base/_base_object.py @@ -0,0 +1,8 @@ +"""Base Classes for pytorch-forecasting models, skbase compatible for indexing.""" + +from skbase.base import BaseObject as _SkbaseBaseObject + + +class _BaseObject(_SkbaseBaseObject): + + pass diff --git a/pytorch_forecasting/tests/_config.py b/pytorch_forecasting/tests/_config.py new file mode 100644 index 000000000..dd9c2e889 --- /dev/null +++ b/pytorch_forecasting/tests/_config.py @@ -0,0 +1,13 @@ +"""Test configs.""" + +# list of str, names of estimators to exclude from testing +# WARNING: tests for these estimators will be skipped +EXCLUDE_ESTIMATORS = [ + "DummySkipped", + "ClassName", # exclude classes from extension templates +] + +# dictionary of lists of str, names of tests to exclude from testing +# keys are class names of estimators, values are lists of test names to exclude +# WARNING: tests with these names will be 
skipped +EXCLUDED_TESTS = {} diff --git a/pytorch_forecasting/tests/test_all_estimators.py b/pytorch_forecasting/tests/test_all_estimators.py new file mode 100644 index 000000000..c806691fa --- /dev/null +++ b/pytorch_forecasting/tests/test_all_estimators.py @@ -0,0 +1,118 @@ +"""Automated tests based on the skbase test suite template.""" +from inspect import isclass + +from skbase.testing import BaseFixtureGenerator as _BaseFixtureGenerator +from skbase.testing import TestAllObjects as _TestAllObjects + +from pytorch_forecasting._registry import all_objects +from pytorch_forecasting.tests._config import EXCLUDE_ESTIMATORS, EXCLUDED_TESTS + + +# whether to test only estimators from modules that are changed w.r.t. main +# default is False, can be set to True by pytest --only_changed_modules True flag +ONLY_CHANGED_MODULES = False + + +class PackageConfig: + """Contains package config variables for test classes.""" + + # class variables which can be overridden by descendants + # ------------------------------------------------------ + + # package to search for objects + # expected type: str, package/module name, relative to python environment root + package_name = "pytroch_forecasting" + + # list of object types (class names) to exclude + # expected type: list of str, str are class names + exclude_objects = EXCLUDE_ESTIMATORS + + # list of tests to exclude + # expected type: dict of lists, key:str, value: List[str] + # keys are class names of estimators, values are lists of test names to exclude + excluded_tests = EXCLUDED_TESTS + + +class BaseFixtureGenerator(_BaseFixtureGenerator): + """Fixture generator for base testing functionality in sktime. + + Test classes inheriting from this and not overriding pytest_generate_tests + will have estimator and scenario fixtures parametrized out of the box. 
+ + Descendants can override: + object_type_filter: str, class variable; None or scitype string + e.g., "forecaster", "transformer", "classifier", see BASE_CLASS_SCITYPE_LIST + which estimators are being retrieved and tested + fixture_sequence: list of str + sequence of fixture variable names in conditional fixture generation + _generate_[variable]: object methods, all (test_name: str, **kwargs) -> list + generating list of fixtures for fixture variable with name [variable] + to be used in test with name test_name + can optionally use values for fixtures earlier in fixture_sequence, + these must be input as kwargs in a call + is_excluded: static method (test_name: str, est: class) -> bool + whether test with name test_name should be excluded for estimator est + should be used only for encoding general rules, not individual skips + individual skips should go on the EXCLUDED_TESTS list in _config + requires _generate_object_class and _generate_object_instance as is + _excluded_scenario: static method (test_name: str, scenario) -> bool + whether scenario should be skipped in test with name test_name + requires _generate_estimator_scenario as is + + Fixtures parametrized + --------------------- + object_class: estimator inheriting from BaseObject + ranges over estimator classes not excluded by EXCLUDE_ESTIMATORS, EXCLUDED_TESTS + object_instance: instance of estimator inheriting from BaseObject + ranges over estimator classes not excluded by EXCLUDE_ESTIMATORS, EXCLUDED_TESTS + instances are generated by create_test_instance class method of object_class + """ + + # overrides object retrieval in scikit-base + def _all_objects(self): + """Retrieve list of all object classes of type self.object_type_filter. + + If self.object_type_filter is None, retrieve all objects. + If class, retrieve all classes inheriting from self.object_type_filter. + Otherwise (assumed str or list of str), retrieve all classes with tags + object_type in self.object_type_filter.
+ """ + filter = getattr(self, "object_type_filter", None) + + if isclass(filter): + object_types = filter.get_class_tag("object_type", None) + else: + object_types = filter + + obj_list = all_objects( + object_types=object_types, + return_names=False, + exclude_objects=self.exclude_objects, + ) + + if isclass(filter): + obj_list = [obj for obj in obj_list if issubclass(obj, filter)] + + # run_test_for_class selects the estimators to run + # based on whether they have changed, and whether they have all dependencies + # internally, uses the ONLY_CHANGED_MODULES flag, + # and checks the python env against python_dependencies tag + # obj_list = [obj for obj in obj_list if run_test_for_class(obj)] + + return obj_list + + # which sequence the conditional fixtures are generated in + fixture_sequence = [ + "object_class", + "object_instance", + ] + + +class TestAllPtForecasters(PackageConfig, BaseFixtureGenerator, _TestAllObjects): + """Generic tests for all objects in the mini package.""" + + def test_doctest_examples(self, object_class): + """Runs doctests for estimator class.""" + import doctest + + doctest.run_docstring_examples(object_class, globals()) From 02b0ce6fa53443044fffce8cbbce54a0c6d6b947 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Sat, 22 Feb 2025 23:33:58 +0100 Subject: [PATCH 03/26] skeleton --- pytorch_forecasting/_registry/__init__.py | 15 ++------------- 1 file changed, 2 insertions(+), 13 deletions(-) diff --git a/pytorch_forecasting/_registry/__init__.py b/pytorch_forecasting/_registry/__init__.py index bb0b88e61..f71836bfe 100644 --- a/pytorch_forecasting/_registry/__init__.py +++ b/pytorch_forecasting/_registry/__init__.py @@ -1,16 +1,5 @@ """PyTorch Forecasting registry.""" -from pytorch_forecasting._registry._lookup import all_objects, all_tags -from pytorch_forecasting._registry._tags import ( - OBJECT_TAG_LIST, - OBJECT_TAG_REGISTER, - check_tag_is_valid, -) +from pytorch_forecasting._registry._lookup import all_objects 
-__all__ = [ - "OBJECT_TAG_LIST", - "OBJECT_TAG_REGISTER", - "all_objects", - "all_tags", - "check_tag_is_valid", -] +__all__ = ["all_objects"] From 41cbf667f9aea5848c3390778a53612338319504 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Sun, 23 Feb 2025 06:42:02 +0100 Subject: [PATCH 04/26] Update test_all_estimators.py --- pytorch_forecasting/tests/test_all_estimators.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/pytorch_forecasting/tests/test_all_estimators.py b/pytorch_forecasting/tests/test_all_estimators.py index c806691fa..704ddfc20 100644 --- a/pytorch_forecasting/tests/test_all_estimators.py +++ b/pytorch_forecasting/tests/test_all_estimators.py @@ -1,13 +1,15 @@ """Automated tests based on the skbase test suite template.""" + from inspect import isclass -from skbase.testing import BaseFixtureGenerator as _BaseFixtureGenerator -from skbase.testing import TestAllObjects as _TestAllObjects +from skbase.testing import ( + BaseFixtureGenerator as _BaseFixtureGenerator, + TestAllObjects as _TestAllObjects, +) from pytorch_forecasting._registry import all_objects from pytorch_forecasting.tests._config import EXCLUDE_ESTIMATORS, EXCLUDED_TESTS - # whether to test only estimators from modules that are changed w.r.t. 
main # default is False, can be set to True by pytest --only_changed_modules True flag ONLY_CHANGED_MODULES = False From cef62d36df5eceff5238a2d6c7fd829319028446 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Sun, 23 Feb 2025 06:47:24 +0100 Subject: [PATCH 05/26] Update _base_object.py --- pytorch_forecasting/models/base/_base_object.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pytorch_forecasting/models/base/_base_object.py b/pytorch_forecasting/models/base/_base_object.py index 7330867b1..a91b10525 100644 --- a/pytorch_forecasting/models/base/_base_object.py +++ b/pytorch_forecasting/models/base/_base_object.py @@ -1,6 +1,8 @@ """Base Classes for pytorch-forecasting models, skbase compatible for indexing.""" -from skbase.base import BaseObject as _SkbaseBaseObject +from pytorch_forecasting.utils._dependencies import _safe_import + +_SkbaseBaseObject = _safe_import("skbase._base_object._BaseObject") class _BaseObject(_SkbaseBaseObject): From bc2e93b606095440772f7236eeebb070109c649f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Sun, 23 Feb 2025 09:45:10 +0100 Subject: [PATCH 06/26] Update _lookup.py --- pytorch_forecasting/_registry/_lookup.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pytorch_forecasting/_registry/_lookup.py b/pytorch_forecasting/_registry/_lookup.py index ea3210cb0..517bfa9af 100644 --- a/pytorch_forecasting/_registry/_lookup.py +++ b/pytorch_forecasting/_registry/_lookup.py @@ -5,7 +5,6 @@ all_objects(object_types, filter_tags) lookup and filtering of objects """ -# copyright: skpro developers, BSD-3-Clause License (see LICENSE file) # based on the sktime module of same name __author__ = ["fkiraly"] From eee1c86859dc1d66d46eb85c7b39938639f8231e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Sun, 23 Feb 2025 10:37:22 +0100 Subject: [PATCH 07/26] Update _lookup.py --- pytorch_forecasting/_registry/_lookup.py | 1 + 1 file changed, 1 insertion(+) diff 
--git a/pytorch_forecasting/_registry/_lookup.py b/pytorch_forecasting/_registry/_lookup.py index 517bfa9af..1ab4cfdb3 100644 --- a/pytorch_forecasting/_registry/_lookup.py +++ b/pytorch_forecasting/_registry/_lookup.py @@ -5,6 +5,7 @@ all_objects(object_types, filter_tags) lookup and filtering of objects """ + # based on the sktime module of same name __author__ = ["fkiraly"] From 164fe0d238ebe6b9f888c416f998a73948b365f7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Sun, 23 Feb 2025 13:05:54 +0100 Subject: [PATCH 08/26] base metadatda --- .../models/base/_base_object.py | 98 ++++++++++++++ .../models/deepar/_deepar_metadata.py | 128 ++++++++++++++++++ 2 files changed, 226 insertions(+) create mode 100644 pytorch_forecasting/models/deepar/_deepar_metadata.py diff --git a/pytorch_forecasting/models/base/_base_object.py b/pytorch_forecasting/models/base/_base_object.py index a91b10525..62fad456b 100644 --- a/pytorch_forecasting/models/base/_base_object.py +++ b/pytorch_forecasting/models/base/_base_object.py @@ -1,5 +1,7 @@ """Base Classes for pytorch-forecasting models, skbase compatible for indexing.""" +import inspect + from pytorch_forecasting.utils._dependencies import _safe_import _SkbaseBaseObject = _safe_import("skbase._base_object._BaseObject") @@ -8,3 +10,99 @@ class _BaseObject(_SkbaseBaseObject): pass + + +class _BasePtForecaster(_BaseObject): + """Base class for all PyTorch Forecasting forecaster metadata. + + This class points to model objects and contains metadata as tags. + """ + + _tags = { + "object_type": "forecaster_pytorch", + } + + @classmethod + def get_model_cls(cls): + """Get model class.""" + raise NotImplementedError + + + @classmethod + def create_test_instance(cls, parameter_set="default"): + """Construct an instance of the class, using first test parameter set. + + Parameters + ---------- + parameter_set : str, default="default" + Name of the set of test parameters to return, for use in tests. 
If no + special parameters are defined for a value, will return `"default"` set. + + Returns + ------- + instance : instance of the class with default parameters + + """ + if "parameter_set" in inspect.getfullargspec(cls.get_test_params).args: + params = cls.get_test_params(parameter_set=parameter_set) + else: + params = cls.get_test_params() + + if isinstance(params, list) and isinstance(params[0], dict): + params = params[0] + elif isinstance(params, dict): + pass + else: + raise TypeError( + "get_test_params should either return a dict or list of dict." + ) + + return cls.get_model_cls()(**params) + + @classmethod + def create_test_instances_and_names(cls, parameter_set="default"): + """Create list of all test instances and a list of names for them. + + Parameters + ---------- + parameter_set : str, default="default" + Name of the set of test parameters to return, for use in tests. If no + special parameters are defined for a value, will return `"default"` set. + + Returns + ------- + objs : list of instances of cls + i-th instance is ``cls(**cls.get_test_params()[i])`` + names : list of str, same length as objs + i-th element is name of i-th instance of obj in tests. 
+ The naming convention is ``{cls.__name__}-{i}`` if more than one instance, + otherwise ``{cls.__name__}`` + """ + if "parameter_set" in inspect.getfullargspec(cls.get_test_params).args: + param_list = cls.get_test_params(parameter_set=parameter_set) + else: + param_list = cls.get_test_params() + + objs = [] + if not isinstance(param_list, (dict, list)): + raise RuntimeError( + f"Error in {cls.__name__}.get_test_params, " + "return must be param dict for class, or list thereof" + ) + if isinstance(param_list, dict): + param_list = [param_list] + for params in param_list: + if not isinstance(params, dict): + raise RuntimeError( + f"Error in {cls.__name__}.get_test_params, " + "return must be param dict for class, or list thereof" + ) + objs += [cls.get_model_cls()(**params)] + + num_instances = len(param_list) + if num_instances > 1: + names = [cls.__name__ + "-" + str(i) for i in range(num_instances)] + else: + names = [cls.__name__] + + return objs, names diff --git a/pytorch_forecasting/models/deepar/_deepar_metadata.py b/pytorch_forecasting/models/deepar/_deepar_metadata.py new file mode 100644 index 000000000..330b86e80 --- /dev/null +++ b/pytorch_forecasting/models/deepar/_deepar_metadata.py @@ -0,0 +1,128 @@ +"""DeepAR metadata container.""" + +from pytorch_forecasting.models.base._base_object import _BasePtForecaster + + +class _DeepARMetadata(_BasePtForecaster): + """DeepAR metadata container.""" + + _tags = { + "capability:exogenous": True, + "capability:multivariate": True, + "capability:pred_int": True, + "capability:flexible_history_length": True, + "capability:cold_start": False, + "info:compute": 3, + } + + @classmethod + def get_model_cls(cls): + """Get model class.""" + from pytorch_forecasting.models import DeepAR + + return DeepAR + + @classmethod + def get_test_params(cls, parameter_set="default"): + """Return testing parameter settings for the skbase object. 
+ + ``get_test_params`` is a unified interface point to store + parameter settings for testing purposes. This function is also + used in ``create_test_instance`` and ``create_test_instances_and_names`` + to construct test instances. + + ``get_test_params`` should return a single ``dict``, or a ``list`` of ``dict``. + + Each ``dict`` is a parameter configuration for testing, + and can be used to construct an "interesting" test instance. + A call to ``cls(**params)`` should + be valid for all dictionaries ``params`` in the return of ``get_test_params``. + + The ``get_test_params`` need not return fixed lists of dictionaries, + it can also return dynamic or stochastic parameter settings. + + Parameters + ---------- + parameter_set : str, default="default" + Name of the set of test parameters to return, for use in tests. If no + special parameters are defined for a value, will return `"default"` set. + + Returns + ------- + params : dict or list of dict, default = {} + Parameters to create testing instances of the class + Each dict are parameters to construct an "interesting" test instance, i.e., + `MyClass(**params)` or `MyClass(**params[i])` creates a valid test instance. 
+ `create_test_instance` uses the first (or only) dictionary in `params` + """ + from pytorch_forecasting.data.encoders import GroupNormalizer + from pytorch_forecasting.metrics import ( + BetaDistributionLoss, + ImplicitQuantileNetworkDistributionLoss, + LogNormalDistributionLoss, + MultivariateNormalDistributionLoss, + NegativeBinomialDistributionLoss, + ) + + return [ + {}, + {"cell_type": "GRU"}, + dict( + loss=LogNormalDistributionLoss(), + clip_target=True, + data_loader_kwargs=dict( + target_normalizer=GroupNormalizer( + groups=["agency", "sku"], transformation="log" + ) + ), + ), + dict( + loss=NegativeBinomialDistributionLoss(), + clip_target=False, + data_loader_kwargs=dict( + target_normalizer=GroupNormalizer( + groups=["agency", "sku"], center=False + ) + ), + ), + dict( + loss=BetaDistributionLoss(), + clip_target=True, + data_loader_kwargs=dict( + target_normalizer=GroupNormalizer( + groups=["agency", "sku"], transformation="logit" + ) + ), + ), + dict( + data_loader_kwargs=dict( + lags={"volume": [2, 5]}, + target="volume", + time_varying_unknown_reals=["volume"], + min_encoder_length=2, + ) + ), + dict( + data_loader_kwargs=dict( + time_varying_unknown_reals=["volume", "discount"], + target=["volume", "discount"], + lags={"volume": [2], "discount": [2]}, + ) + ), + dict( + loss=ImplicitQuantileNetworkDistributionLoss(hidden_size=8), + ), + dict( + loss=MultivariateNormalDistributionLoss(), + trainer_kwargs=dict(accelerator="cpu"), + ), + dict( + loss=MultivariateNormalDistributionLoss(), + data_loader_kwargs=dict( + target_normalizer=GroupNormalizer( + groups=["agency", "sku"], transformation="log1p" + ) + ), + trainer_kwargs=dict(accelerator="cpu"), + ), + ] From 20e88d09993f3fed62ab52f93d0a4678f1a0c068 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Sun, 23 Feb 2025 13:21:34 +0100 Subject: [PATCH 09/26] registry --- pytorch_forecasting/_registry/_lookup.py | 18 +++--------------- pytorch_forecasting/models/base/__init__.py | 6 
+++++- .../models/base/_base_object.py | 2 +- .../utils/_dependencies/_safe_import.py | 3 +++ 4 files changed, 12 insertions(+), 17 deletions(-) diff --git a/pytorch_forecasting/_registry/_lookup.py b/pytorch_forecasting/_registry/_lookup.py index 1ab4cfdb3..828c448b1 100644 --- a/pytorch_forecasting/_registry/_lookup.py +++ b/pytorch_forecasting/_registry/_lookup.py @@ -11,12 +11,8 @@ __author__ = ["fkiraly"] # all_objects is based on the sklearn utility all_estimators - -from copy import deepcopy -from operator import itemgetter from pathlib import Path -import pandas as pd from skbase.lookup import all_objects as _all_objects from pytorch_forecasting.models.base import _BaseObject @@ -117,17 +113,9 @@ def all_objects( Examples -------- - >>> from skpro.registry import all_objects + >>> from pytorch_forecasting._registry import all_objects >>> # return a complete list of objects as pd.Dataframe >>> all_objects(as_dataframe=True) # doctest: +SKIP - >>> # return all probabilistic regressors by filtering for object type - >>> all_objects("regressor_proba", as_dataframe=True) # doctest: +SKIP - >>> # return all regressors which handle missing data in the input by tag filtering - >>> all_objects( - ... "regressor_proba", - ... filter_tags={"capability:missing": True}, - ... as_dataframe=True - ... 
) # doctest: +SKIP References ---------- @@ -143,7 +131,7 @@ def all_objects( ) result = [] - ROOT = str(Path(__file__).parent.parent) # skpro package root directory + ROOT = str(Path(__file__).parent.parent) # package root directory if isinstance(filter_tags, str): filter_tags = {filter_tags: True} @@ -173,7 +161,7 @@ def all_objects( as_dataframe=as_dataframe, return_tags=return_tags, suppress_import_stdout=suppress_import_stdout, - package_name="skpro", + package_name="pytorch_forecasting", path=ROOT, modules_to_ignore=MODULES_TO_IGNORE, ) diff --git a/pytorch_forecasting/models/base/__init__.py b/pytorch_forecasting/models/base/__init__.py index 474e7d564..7b69ec246 100644 --- a/pytorch_forecasting/models/base/__init__.py +++ b/pytorch_forecasting/models/base/__init__.py @@ -7,10 +7,14 @@ BaseModelWithCovariates, Prediction, ) -from pytorch_forecasting.models.base._base_object import _BaseObject +from pytorch_forecasting.models.base._base_object import ( + _BaseObject, + _BasePtForecaster, +) __all__ = [ "_BaseObject", + "_BasePtForecaster", "AutoRegressiveBaseModel", "AutoRegressiveBaseModelWithCovariates", "BaseModel", diff --git a/pytorch_forecasting/models/base/_base_object.py b/pytorch_forecasting/models/base/_base_object.py index 62fad456b..8895b4c2c 100644 --- a/pytorch_forecasting/models/base/_base_object.py +++ b/pytorch_forecasting/models/base/_base_object.py @@ -4,7 +4,7 @@ from pytorch_forecasting.utils._dependencies import _safe_import -_SkbaseBaseObject = _safe_import("skbase._base_object._BaseObject") +_SkbaseBaseObject = _safe_import("skbase.base.BaseObject", pkg_name="scikit-base") class _BaseObject(_SkbaseBaseObject): diff --git a/pytorch_forecasting/utils/_dependencies/_safe_import.py b/pytorch_forecasting/utils/_dependencies/_safe_import.py index f4805f9c1..ffbde8b5d 100644 --- a/pytorch_forecasting/utils/_dependencies/_safe_import.py +++ b/pytorch_forecasting/utils/_dependencies/_safe_import.py @@ -70,6 +70,9 @@ def 
_safe_import(import_path, pkg_name=None): if pkg_name is None: path_list = import_path.split(".") pkg_name = path_list[0] + else: + path_list = import_path.split(".") + path_list = [pkg_name] + path_list[1:] if pkg_name in _get_installed_packages(): try: From 318c1fbdbfface24fdc67568cf9a01a6dde1650c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Sun, 23 Feb 2025 13:33:42 +0100 Subject: [PATCH 10/26] fix private name --- pytorch_forecasting/models/deepar/__init__.py | 3 ++- pytorch_forecasting/models/deepar/_deepar_metadata.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/pytorch_forecasting/models/deepar/__init__.py b/pytorch_forecasting/models/deepar/__init__.py index 679f296f6..149e19e0d 100644 --- a/pytorch_forecasting/models/deepar/__init__.py +++ b/pytorch_forecasting/models/deepar/__init__.py @@ -1,5 +1,6 @@ """DeepAR: Probabilistic forecasting with autoregressive recurrent networks.""" from pytorch_forecasting.models.deepar._deepar import DeepAR +from pytorch_forecasting.models.deepar._deepar_metadata import DeepARMetadata -__all__ = ["DeepAR"] +__all__ = ["DeepAR", "DeepARMetadata"] diff --git a/pytorch_forecasting/models/deepar/_deepar_metadata.py b/pytorch_forecasting/models/deepar/_deepar_metadata.py index 330b86e80..89aefc1b0 100644 --- a/pytorch_forecasting/models/deepar/_deepar_metadata.py +++ b/pytorch_forecasting/models/deepar/_deepar_metadata.py @@ -3,7 +3,7 @@ from pytorch_forecasting.models.base._base_object import _BasePtForecaster -class _DeepARMetadata(_BasePtForecaster): +class DeepARMetadata(_BasePtForecaster): """DeepAR metadata container.""" _tags = { From 012ab3d78ed8a99e6920f3df704188834bbe1c14 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Sun, 23 Feb 2025 13:48:02 +0100 Subject: [PATCH 11/26] Update _base_object.py --- pytorch_forecasting/models/base/_base_object.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pytorch_forecasting/models/base/_base_object.py 
b/pytorch_forecasting/models/base/_base_object.py index 8895b4c2c..4fcb1bd22 100644 --- a/pytorch_forecasting/models/base/_base_object.py +++ b/pytorch_forecasting/models/base/_base_object.py @@ -27,7 +27,6 @@ def get_model_cls(cls): """Get model class.""" raise NotImplementedError - @classmethod def create_test_instance(cls, parameter_set="default"): """Construct an instance of the class, using first test parameter set. From 86365a00d88cda407674dbcac5c4d53bd26f3fce Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Sun, 23 Feb 2025 13:57:16 +0100 Subject: [PATCH 12/26] test failure --- pytorch_forecasting/tests/_conftest.py | 262 ++++++++++++++++++ .../tests/test_all_estimators.py | 101 +++++++ 2 files changed, 363 insertions(+) create mode 100644 pytorch_forecasting/tests/_conftest.py diff --git a/pytorch_forecasting/tests/_conftest.py b/pytorch_forecasting/tests/_conftest.py new file mode 100644 index 000000000..e276446a6 --- /dev/null +++ b/pytorch_forecasting/tests/_conftest.py @@ -0,0 +1,262 @@ +import numpy as np +import pytest +import torch + +from pytorch_forecasting import TimeSeriesDataSet +from pytorch_forecasting.data import EncoderNormalizer, GroupNormalizer, NaNLabelEncoder +from pytorch_forecasting.data.examples import generate_ar_data, get_stallion_data + +torch.manual_seed(23) + + +@pytest.fixture(scope="session") +def gpus(): + if torch.cuda.is_available(): + return [0] + else: + return 0 + + +@pytest.fixture(scope="session") +def data_with_covariates(): + data = get_stallion_data() + data["month"] = data.date.dt.month.astype(str) + data["log_volume"] = np.log1p(data.volume) + data["weight"] = 1 + np.sqrt(data.volume) + + data["time_idx"] = data["date"].dt.year * 12 + data["date"].dt.month + data["time_idx"] -= data["time_idx"].min() + + # convert special days into strings + special_days = [ + "easter_day", + "good_friday", + "new_year", + "christmas", + "labor_day", + "independence_day", + "revolution_day_memorial", + 
"regional_games", + "fifa_u_17_world_cup", + "football_gold_cup", + "beer_capital", + "music_fest", + ] + data[special_days] = ( + data[special_days].apply(lambda x: x.map({0: "", 1: x.name})).astype("category") + ) + data = data.astype(dict(industry_volume=float)) + + # select data subset + data = data[lambda x: x.sku.isin(data.sku.unique()[:2])][ + lambda x: x.agency.isin(data.agency.unique()[:2]) + ] + + # default target + data["target"] = data["volume"].clip(1e-3, 1.0) + + return data + + +def make_dataloaders(data_with_covariates, **kwargs): + training_cutoff = "2016-09-01" + max_encoder_length = 4 + max_prediction_length = 3 + + kwargs.setdefault("target", "volume") + kwargs.setdefault("group_ids", ["agency", "sku"]) + kwargs.setdefault("add_relative_time_idx", True) + kwargs.setdefault("time_varying_unknown_reals", ["volume"]) + + training = TimeSeriesDataSet( + data_with_covariates[lambda x: x.date < training_cutoff].copy(), + time_idx="time_idx", + max_encoder_length=max_encoder_length, + max_prediction_length=max_prediction_length, + **kwargs, # fixture parametrization + ) + + validation = TimeSeriesDataSet.from_dataset( + training, + data_with_covariates.copy(), + min_prediction_idx=training.index.time.max() + 1, + ) + train_dataloader = training.to_dataloader(train=True, batch_size=2, num_workers=0) + val_dataloader = validation.to_dataloader(train=False, batch_size=2, num_workers=0) + test_dataloader = validation.to_dataloader(train=False, batch_size=1, num_workers=0) + + return dict(train=train_dataloader, val=val_dataloader, test=test_dataloader) + + +@pytest.fixture( + params=[ + dict(), + dict( + static_categoricals=["agency", "sku"], + static_reals=["avg_population_2017", "avg_yearly_household_income_2017"], + time_varying_known_categoricals=["special_days", "month"], + variable_groups=dict( + special_days=[ + "easter_day", + "good_friday", + "new_year", + "christmas", + "labor_day", + "independence_day", + "revolution_day_memorial", + 
"regional_games", + "fifa_u_17_world_cup", + "football_gold_cup", + "beer_capital", + "music_fest", + ] + ), + time_varying_known_reals=[ + "time_idx", + "price_regular", + "price_actual", + "discount", + "discount_in_percent", + ], + time_varying_unknown_categoricals=[], + time_varying_unknown_reals=[ + "volume", + "log_volume", + "industry_volume", + "soda_volume", + "avg_max_temp", + ], + constant_fill_strategy={"volume": 0}, + categorical_encoders={"sku": NaNLabelEncoder(add_nan=True)}, + ), + dict(static_categoricals=["agency", "sku"]), + dict(randomize_length=True, min_encoder_length=2), + dict(target_normalizer=EncoderNormalizer(), min_encoder_length=2), + dict(target_normalizer=GroupNormalizer(transformation="log1p")), + dict( + target_normalizer=GroupNormalizer( + groups=["agency", "sku"], transformation="softplus", center=False + ) + ), + dict(target="agency"), + # test multiple targets + dict(target=["industry_volume", "volume"]), + dict(target=["agency", "volume"]), + dict( + target=["agency", "volume"], min_encoder_length=1, min_prediction_length=1 + ), + dict(target=["agency", "volume"], weight="volume"), + # test weights + dict(target="volume", weight="volume"), + ], + scope="session", +) +def multiple_dataloaders_with_covariates(data_with_covariates, request): + return make_dataloaders(data_with_covariates, **request.param) + + +@pytest.fixture(scope="session") +def dataloaders_with_different_encoder_decoder_length(data_with_covariates): + return make_dataloaders( + data_with_covariates.copy(), + target="target", + time_varying_known_categoricals=["special_days", "month"], + variable_groups=dict( + special_days=[ + "easter_day", + "good_friday", + "new_year", + "christmas", + "labor_day", + "independence_day", + "revolution_day_memorial", + "regional_games", + "fifa_u_17_world_cup", + "football_gold_cup", + "beer_capital", + "music_fest", + ] + ), + time_varying_known_reals=[ + "time_idx", + "price_regular", + "price_actual", + "discount", + 
"discount_in_percent", + ], + time_varying_unknown_categoricals=[], + time_varying_unknown_reals=[ + "target", + "volume", + "log_volume", + "industry_volume", + "soda_volume", + "avg_max_temp", + ], + static_categoricals=["agency"], + add_relative_time_idx=False, + target_normalizer=GroupNormalizer(groups=["agency", "sku"], center=False), + ) + + +@pytest.fixture(scope="session") +def dataloaders_with_covariates(data_with_covariates): + return make_dataloaders( + data_with_covariates.copy(), + target="target", + time_varying_known_reals=["discount"], + time_varying_unknown_reals=["target"], + static_categoricals=["agency"], + add_relative_time_idx=False, + target_normalizer=GroupNormalizer(groups=["agency", "sku"], center=False), + ) + + +@pytest.fixture(scope="session") +def dataloaders_multi_target(data_with_covariates): + return make_dataloaders( + data_with_covariates.copy(), + time_varying_unknown_reals=["target", "discount"], + target=["target", "discount"], + add_relative_time_idx=False, + ) + + +@pytest.fixture(scope="session") +def dataloaders_fixed_window_without_covariates(): + data = generate_ar_data(seasonality=10.0, timesteps=50, n_series=2) + validation = data.series.iloc[:2] + + max_encoder_length = 30 + max_prediction_length = 10 + + training = TimeSeriesDataSet( + data[lambda x: ~x.series.isin(validation)], + time_idx="time_idx", + target="value", + categorical_encoders={"series": NaNLabelEncoder().fit(data.series)}, + group_ids=["series"], + static_categoricals=[], + max_encoder_length=max_encoder_length, + max_prediction_length=max_prediction_length, + time_varying_unknown_reals=["value"], + target_normalizer=EncoderNormalizer(), + ) + + validation = TimeSeriesDataSet.from_dataset( + training, + data[lambda x: x.series.isin(validation)], + stop_randomization=True, + ) + batch_size = 2 + train_dataloader = training.to_dataloader( + train=True, batch_size=batch_size, num_workers=0 + ) + val_dataloader = validation.to_dataloader( + train=False, 
batch_size=batch_size, num_workers=0 + ) + test_dataloader = validation.to_dataloader( + train=False, batch_size=batch_size, num_workers=0 + ) + + return dict(train=train_dataloader, val=val_dataloader, test=test_dataloader) diff --git a/pytorch_forecasting/tests/test_all_estimators.py b/pytorch_forecasting/tests/test_all_estimators.py index 704ddfc20..609761e21 100644 --- a/pytorch_forecasting/tests/test_all_estimators.py +++ b/pytorch_forecasting/tests/test_all_estimators.py @@ -1,7 +1,11 @@ """Automated tests based on the skbase test suite template.""" from inspect import isclass +import shutil +import lightning.pytorch as pl +from lightning.pytorch.callbacks import EarlyStopping +from lightning.pytorch.loggers import TensorBoardLogger from skbase.testing import ( BaseFixtureGenerator as _BaseFixtureGenerator, TestAllObjects as _TestAllObjects, @@ -9,6 +13,7 @@ from pytorch_forecasting._registry import all_objects from pytorch_forecasting.tests._config import EXCLUDE_ESTIMATORS, EXCLUDED_TESTS +from pytorch_forecasting.tests._conftest import make_dataloaders # whether to test only estimators from modules that are changed w.r.t. 
main # default is False, can be set to True by pytest --only_changed_modules True flag @@ -110,6 +115,98 @@ def _all_objects(self): ] +def _integration( + data_with_covariates, + tmp_path, + cell_type="LSTM", + data_loader_kwargs={}, + clip_target: bool = False, + trainer_kwargs=None, + **kwargs, +): + data_with_covariates = data_with_covariates.copy() + if clip_target: + data_with_covariates["target"] = data_with_covariates["volume"].clip(1e-3, 1.0) + else: + data_with_covariates["target"] = data_with_covariates["volume"] + data_loader_default_kwargs = dict( + target="target", + time_varying_known_reals=["price_actual"], + time_varying_unknown_reals=["target"], + static_categoricals=["agency"], + add_relative_time_idx=True, + ) + data_loader_default_kwargs.update(data_loader_kwargs) + dataloaders_with_covariates = make_dataloaders( + data_with_covariates, **data_loader_default_kwargs + ) + + train_dataloader = dataloaders_with_covariates["train"] + val_dataloader = dataloaders_with_covariates["val"] + test_dataloader = dataloaders_with_covariates["test"] + + early_stop_callback = EarlyStopping( + monitor="val_loss", min_delta=1e-4, patience=1, verbose=False, mode="min" + ) + + logger = TensorBoardLogger(tmp_path) + if trainer_kwargs is None: + trainer_kwargs = {} + trainer = pl.Trainer( + max_epochs=3, + gradient_clip_val=0.1, + callbacks=[early_stop_callback], + enable_checkpointing=True, + default_root_dir=tmp_path, + limit_train_batches=2, + limit_val_batches=2, + limit_test_batches=2, + logger=logger, + **trainer_kwargs, + ) + + net = DeepAR.from_dataset( + train_dataloader.dataset, + hidden_size=5, + cell_type=cell_type, + learning_rate=0.01, + log_gradient_flow=True, + log_interval=1000, + n_plotting_samples=100, + **kwargs, + ) + net.size() + try: + trainer.fit( + net, + train_dataloaders=train_dataloader, + val_dataloaders=val_dataloader, + ) + test_outputs = trainer.test(net, dataloaders=test_dataloader) + assert len(test_outputs) > 0 + # check loading + 
net = DeepAR.load_from_checkpoint(trainer.checkpoint_callback.best_model_path) + + # check prediction + net.predict( + val_dataloader, + fast_dev_run=True, + return_index=True, + return_decoder_lengths=True, + trainer_kwargs=trainer_kwargs, + ) + finally: + shutil.rmtree(tmp_path, ignore_errors=True) + + net.predict( + val_dataloader, + fast_dev_run=True, + return_index=True, + return_decoder_lengths=True, + trainer_kwargs=trainer_kwargs, + ) + + class TestAllPtForecasters(PackageConfig, BaseFixtureGenerator, _TestAllObjects): """Generic tests for all objects in the mini package.""" @@ -118,3 +215,7 @@ def test_doctest_examples(self, object_class): import doctest doctest.run_docstring_examples(object_class, globals()) + + def certain_failure(self, object_class): + """Fails for certain, for testing.""" + assert False From f6dee46efaa6853afa299d5edca80d11a80367ce Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Sun, 23 Feb 2025 13:59:52 +0100 Subject: [PATCH 13/26] Update test_all_estimators.py --- pytorch_forecasting/tests/test_all_estimators.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pytorch_forecasting/tests/test_all_estimators.py b/pytorch_forecasting/tests/test_all_estimators.py index 609761e21..a5f2c3783 100644 --- a/pytorch_forecasting/tests/test_all_estimators.py +++ b/pytorch_forecasting/tests/test_all_estimators.py @@ -116,6 +116,7 @@ def _all_objects(self): def _integration( + estimator_cls, data_with_covariates, tmp_path, cell_type="LSTM", @@ -165,7 +166,7 @@ def _integration( **trainer_kwargs, ) - net = DeepAR.from_dataset( + net = estimator_cls.from_dataset( train_dataloader.dataset, hidden_size=5, cell_type=cell_type, @@ -185,7 +186,7 @@ def _integration( test_outputs = trainer.test(net, dataloaders=test_dataloader) assert len(test_outputs) > 0 # check loading - net = DeepAR.load_from_checkpoint(trainer.checkpoint_callback.best_model_path) + net = 
estimator_cls.load_from_checkpoint(trainer.checkpoint_callback.best_model_path) # check prediction net.predict( From 9b0e4ec4c7d47dc0115a87ea4297a22a2f0fe5eb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Sun, 23 Feb 2025 15:51:07 +0100 Subject: [PATCH 14/26] Update test_all_estimators.py --- pytorch_forecasting/tests/test_all_estimators.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/pytorch_forecasting/tests/test_all_estimators.py b/pytorch_forecasting/tests/test_all_estimators.py index a5f2c3783..2934d42db 100644 --- a/pytorch_forecasting/tests/test_all_estimators.py +++ b/pytorch_forecasting/tests/test_all_estimators.py @@ -116,7 +116,7 @@ def _all_objects(self): def _integration( - estimator_cls, + estimator_cls, data_with_covariates, tmp_path, cell_type="LSTM", @@ -186,7 +186,9 @@ def _integration( test_outputs = trainer.test(net, dataloaders=test_dataloader) assert len(test_outputs) > 0 # check loading - net = estimator_cls.load_from_checkpoint(trainer.checkpoint_callback.best_model_path) + net = estimator_cls.load_from_checkpoint( + trainer.checkpoint_callback.best_model_path + ) # check prediction net.predict( From 7de528537d6fe36dd554f3cad5550d6f66c512e7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Sun, 23 Feb 2025 16:02:28 +0100 Subject: [PATCH 15/26] Update test_all_estimators.py --- pytorch_forecasting/tests/test_all_estimators.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_forecasting/tests/test_all_estimators.py b/pytorch_forecasting/tests/test_all_estimators.py index 2934d42db..37b597712 100644 --- a/pytorch_forecasting/tests/test_all_estimators.py +++ b/pytorch_forecasting/tests/test_all_estimators.py @@ -28,7 +28,7 @@ class PackageConfig: # package to search for objects # expected type: str, package/module name, relative to python environment root - package_name = "pytroch_forecasting" + package_name = "pytorch_forecasting" # list of object types 
(class names) to exclude # expected type: list of str, str are class names From 57dfe3a4e47ac3a34199d787cc6282f43b18a9f9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Sun, 23 Feb 2025 16:42:01 +0100 Subject: [PATCH 16/26] test folders --- pytest.ini | 4 +- .../tests/test_all_estimators.py | 43 ++++++++++++++++--- 2 files changed, 40 insertions(+), 7 deletions(-) diff --git a/pytest.ini b/pytest.ini index 457863f87..52f4fa1c1 100644 --- a/pytest.ini +++ b/pytest.ini @@ -10,7 +10,9 @@ addopts = --no-cov-on-fail markers = -testpaths = tests/ +testpaths = + tests/ + pytorch_forecasting/tests/ log_cli_level = ERROR log_format = %(asctime)s %(levelname)s %(message)s log_date_format = %Y-%m-%d %H:%M:%S diff --git a/pytorch_forecasting/tests/test_all_estimators.py b/pytorch_forecasting/tests/test_all_estimators.py index 37b597712..8fbcc6ffe 100644 --- a/pytorch_forecasting/tests/test_all_estimators.py +++ b/pytorch_forecasting/tests/test_all_estimators.py @@ -6,10 +6,8 @@ import lightning.pytorch as pl from lightning.pytorch.callbacks import EarlyStopping from lightning.pytorch.loggers import TensorBoardLogger -from skbase.testing import ( - BaseFixtureGenerator as _BaseFixtureGenerator, - TestAllObjects as _TestAllObjects, -) +import pytest +from skbase.testing import BaseFixtureGenerator as _BaseFixtureGenerator from pytorch_forecasting._registry import all_objects from pytorch_forecasting.tests._config import EXCLUDE_ESTIMATORS, EXCLUDED_TESTS @@ -110,10 +108,43 @@ def _all_objects(self): # which sequence the conditional fixtures are generated in fixture_sequence = [ + "object_metadata", "object_class", "object_instance", ] + def _generate_object_metadata(self, test_name, **kwargs): + """Return object class fixtures. 
+ + Fixtures parametrized + --------------------- + object_class: object inheriting from BaseObject + ranges over all object classes not excluded by self.excluded_tests + """ + object_classes_to_test = [ + est for est in self._all_objects() if not self.is_excluded(test_name, est) + ] + object_names = [est.__name__ for est in object_classes_to_test] + + return object_classes_to_test, object_names + + def _generate_object_class(self, test_name, **kwargs): + """Return object class fixtures. + + Fixtures parametrized + --------------------- + object_class: object inheriting from BaseObject + ranges over all object classes not excluded by self.excluded_tests + """ + all_metadata = self._all_objects() + all_cls = [est.get_model_cls() for est in all_metadata] + object_classes_to_test = [ + est for est in all_cls if not self.is_excluded(test_name, est) + ] + object_names = [est.__name__ for est in object_classes_to_test] + + return object_classes_to_test, object_names + def _integration( estimator_cls, @@ -210,7 +241,7 @@ def _integration( ) -class TestAllPtForecasters(PackageConfig, BaseFixtureGenerator, _TestAllObjects): +class TestAllPtForecasters(PackageConfig, BaseFixtureGenerator): """Generic tests for all objects in the mini package.""" def test_doctest_examples(self, object_class): @@ -219,6 +250,6 @@ def test_doctest_examples(self, object_class): doctest.run_docstring_examples(object_class, globals()) - def certain_failure(self, object_class): + def test_certain_failure(self, object_class): """Fails for certain, for testing.""" assert False From c9f12dbdeea4aa431b52620b226cd57193a9a249 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Sun, 23 Feb 2025 16:59:30 +0100 Subject: [PATCH 17/26] Update test.yml --- .github/workflows/test.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 5fcd9c1ff..0083302dd 100644 --- a/.github/workflows/test.yml +++ 
b/.github/workflows/test.yml @@ -71,7 +71,7 @@ jobs: - name: Run pytest shell: bash - run: python -m pytest tests + run: python -m pytest pytest: name: Run pytest @@ -110,7 +110,7 @@ jobs: - name: Run pytest shell: bash - run: python -m pytest tests + run: python -m pytest - name: Statistics run: | From fa8144ebae6312d34458a2464da2d1bc3f186754 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Sun, 23 Feb 2025 18:56:39 +0100 Subject: [PATCH 18/26] test integration --- .../models/deepar/_deepar_metadata.py | 25 +---------- .../tests/test_all_estimators.py | 41 ++++++++++++++++++- 2 files changed, 41 insertions(+), 25 deletions(-) diff --git a/pytorch_forecasting/models/deepar/_deepar_metadata.py b/pytorch_forecasting/models/deepar/_deepar_metadata.py index 89aefc1b0..206f113f0 100644 --- a/pytorch_forecasting/models/deepar/_deepar_metadata.py +++ b/pytorch_forecasting/models/deepar/_deepar_metadata.py @@ -23,29 +23,8 @@ def get_model_cls(cls): return DeepAR @classmethod - def get_test_params(cls, parameter_set="default"): - """Return testing parameter settings for the skbase object. - - ``get_test_params`` is a unified interface point to store - parameter settings for testing purposes. This function is also - used in ``create_test_instance`` and ``create_test_instances_and_names`` - to construct test instances. - - ``get_test_params`` should return a single ``dict``, or a ``list`` of ``dict``. - - Each ``dict`` is a parameter configuration for testing, - and can be used to construct an "interesting" test instance. - A call to ``cls(**params)`` should - be valid for all dictionaries ``params`` in the return of ``get_test_params``. - - The ``get_test_params`` need not return fixed lists of dictionaries, - it can also return dynamic or stochastic parameter settings. - - Parameters - ---------- - parameter_set : str, default="default" - Name of the set of test parameters to return, for use in tests. 
If no - special parameters are defined for a value, will return `"default"` set. + def get_test_train_params(cls): + """Return testing parameter settings for the trainer. Returns ------- diff --git a/pytorch_forecasting/tests/test_all_estimators.py b/pytorch_forecasting/tests/test_all_estimators.py index 8fbcc6ffe..378316d7a 100644 --- a/pytorch_forecasting/tests/test_all_estimators.py +++ b/pytorch_forecasting/tests/test_all_estimators.py @@ -71,6 +71,8 @@ class BaseFixtureGenerator(_BaseFixtureGenerator): object_instance: instance of estimator inheriting from BaseObject ranges over estimator classes not excluded by EXCLUDE_ESTIMATORS, EXCLUDED_TESTS instances are generated by create_test_instance class method of object_class + trainer_kwargs: list of dict + ranges over dictionaries of kwargs for the trainer """ # overrides object retrieval in scikit-base @@ -111,6 +113,7 @@ def _all_objects(self): "object_metadata", "object_class", "object_instance", + "trainer_kwargs", ] def _generate_object_metadata(self, test_name, **kwargs): @@ -145,6 +148,30 @@ def _generate_object_class(self, test_name, **kwargs): return object_classes_to_test, object_names + def _generate_trainer_kwargs(self, test_name, **kwargs): + """Return kwargs for the trainer. 
+ + Fixtures parametrized + --------------------- + trainer_kwargs: dict + ranges over all kwargs for the trainer + """ + # call _generate_object_class to get all the classes + object_meta_to_test, _ = self._generate_object_metadata(test_name=test_name) + + # create instances from the classes + train_kwargs_to_test = [] + train_kwargs_names = [] + # retrieve all object parameters if multiple, construct instances + for est in object_meta_to_test: + est_name = est.__name__ + all_train_kwargs = est.get_test_train_params() + train_kwargs_to_test += all_train_kwargs + rg = range(len(all_train_kwargs)) + train_kwargs_names += [f"{est_name}_{i}" for i in rg] + + return train_kwargs_to_test, train_kwargs_names + def _integration( estimator_cls, @@ -250,6 +277,16 @@ def test_doctest_examples(self, object_class): doctest.run_docstring_examples(object_class, globals()) - def test_certain_failure(self, object_class): + def test_integration( + self, object_class, trainer_kwargs, data_with_covariates, tmp_path + ): """Fails for certain, for testing.""" - assert False + from pytorch_forecasting.metrics import NegativeBinomialDistributionLoss + + if "loss" in trainer_kwargs and isinstance( + trainer_kwargs["loss"], NegativeBinomialDistributionLoss + ): + data_with_covariates = data_with_covariates.assign( + volume=lambda x: x.volume.round() + ) + _integration(object_class, data_with_covariates, tmp_path, **trainer_kwargs) From 232a510bc6820786ef1ce46ab3115a04126054ae Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Sun, 23 Feb 2025 19:07:35 +0100 Subject: [PATCH 19/26] fixes --- .../models/base/_base_object.py | 8 +++++ .../models/deepar/_deepar_metadata.py | 4 ++- .../tests/test_all_estimators.py | 31 ++++++++++--------- 3 files changed, 27 insertions(+), 16 deletions(-) diff --git a/pytorch_forecasting/models/base/_base_object.py b/pytorch_forecasting/models/base/_base_object.py index 4fcb1bd22..7fd59d6a4 100644 --- 
a/pytorch_forecasting/models/base/_base_object.py +++ b/pytorch_forecasting/models/base/_base_object.py @@ -27,6 +27,14 @@ def get_model_cls(cls): """Get model class.""" raise NotImplementedError + @classmethod + def name(cls): + """Get model name.""" + name = cls.get_class_tags().get("info:name", None) + if name is None: + name = cls.get_model_cls().__name__ + return name + @classmethod def create_test_instance(cls, parameter_set="default"): """Construct an instance of the class, using first test parameter set. diff --git a/pytorch_forecasting/models/deepar/_deepar_metadata.py b/pytorch_forecasting/models/deepar/_deepar_metadata.py index 206f113f0..e477d63b0 100644 --- a/pytorch_forecasting/models/deepar/_deepar_metadata.py +++ b/pytorch_forecasting/models/deepar/_deepar_metadata.py @@ -7,12 +7,14 @@ class DeepARMetadata(_BasePtForecaster): """DeepAR metadata container.""" _tags = { + "info:name": "DeepAR", + "info:compute": 3, + "authors": ["jdb78"], "capability:exogenous": True, "capability:multivariate": True, "capability:pred_int": True, "capability:flexible_history_length": True, "capability:cold_start": False, - "info:compute": 3, } @classmethod diff --git a/pytorch_forecasting/tests/test_all_estimators.py b/pytorch_forecasting/tests/test_all_estimators.py index 378316d7a..c2d86e40e 100644 --- a/pytorch_forecasting/tests/test_all_estimators.py +++ b/pytorch_forecasting/tests/test_all_estimators.py @@ -127,7 +127,7 @@ def _generate_object_metadata(self, test_name, **kwargs): object_classes_to_test = [ est for est in self._all_objects() if not self.is_excluded(test_name, est) ] - object_names = [est.__name__ for est in object_classes_to_test] + object_names = [est.name() for est in object_classes_to_test] return object_classes_to_test, object_names @@ -156,21 +156,16 @@ def _generate_trainer_kwargs(self, test_name, **kwargs): trainer_kwargs: dict ranges over all kwargs for the trainer """ - # call _generate_object_class to get all the classes - 
object_meta_to_test, _ = self._generate_object_metadata(test_name=test_name) + if "object_metadata" in kwargs.keys(): + obj_meta = kwargs["object_metadata"] + else: + return [] - # create instances from the classes - train_kwargs_to_test = [] - train_kwargs_names = [] - # retrieve all object parameters if multiple, construct instances - for est in object_meta_to_test: - est_name = est.__name__ - all_train_kwargs = est.get_test_train_params() - train_kwargs_to_test += all_train_kwargs - rg = range(len(all_train_kwargs)) - train_kwargs_names += [f"{est_name}_{i}" for i in rg] + all_train_kwargs = obj_meta.get_test_train_params() + rg = range(len(all_train_kwargs)) + train_kwargs_names = [str(i) for i in rg] - return train_kwargs_to_test, train_kwargs_names + return all_train_kwargs, train_kwargs_names def _integration( @@ -278,11 +273,17 @@ def test_doctest_examples(self, object_class): doctest.run_docstring_examples(object_class, globals()) def test_integration( - self, object_class, trainer_kwargs, data_with_covariates, tmp_path + self, + object_metadata, + trainer_kwargs, + data_with_covariates, + tmp_path, ): """Fails for certain, for testing.""" from pytorch_forecasting.metrics import NegativeBinomialDistributionLoss + object_class = object_metadata.get_model_cls() + if "loss" in trainer_kwargs and isinstance( trainer_kwargs["loss"], NegativeBinomialDistributionLoss ): From 1c8d4b5c4fbf8dca91ec28e36e8781bb08a291bf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Sun, 23 Feb 2025 19:30:39 +0100 Subject: [PATCH 20/26] Update _conftest.py --- pytorch_forecasting/tests/_conftest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_forecasting/tests/_conftest.py b/pytorch_forecasting/tests/_conftest.py index e276446a6..20cad22c5 100644 --- a/pytorch_forecasting/tests/_conftest.py +++ b/pytorch_forecasting/tests/_conftest.py @@ -17,7 +17,7 @@ def gpus(): return 0 -@pytest.fixture(scope="session") 
+@pytest.fixture(scope="package") def data_with_covariates(): data = get_stallion_data() data["month"] = data.date.dt.month.astype(str) From f632e32325a657fe975f9c76344eaba0585e17e3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Sun, 23 Feb 2025 19:37:01 +0100 Subject: [PATCH 21/26] try scenarios --- pytorch_forecasting/tests/_conftest.py | 2 +- pytorch_forecasting/tests/_data_scenarios.py | 261 ++++++++++++++++++ .../tests/test_all_estimators.py | 5 +- 3 files changed, 265 insertions(+), 3 deletions(-) create mode 100644 pytorch_forecasting/tests/_data_scenarios.py diff --git a/pytorch_forecasting/tests/_conftest.py b/pytorch_forecasting/tests/_conftest.py index 20cad22c5..e276446a6 100644 --- a/pytorch_forecasting/tests/_conftest.py +++ b/pytorch_forecasting/tests/_conftest.py @@ -17,7 +17,7 @@ def gpus(): return 0 -@pytest.fixture(scope="package") +@pytest.fixture(scope="session") def data_with_covariates(): data = get_stallion_data() data["month"] = data.date.dt.month.astype(str) diff --git a/pytorch_forecasting/tests/_data_scenarios.py b/pytorch_forecasting/tests/_data_scenarios.py new file mode 100644 index 000000000..062db97dd --- /dev/null +++ b/pytorch_forecasting/tests/_data_scenarios.py @@ -0,0 +1,261 @@ +import numpy as np +import pytest +import torch + +from pytorch_forecasting import TimeSeriesDataSet +from pytorch_forecasting.data import EncoderNormalizer, GroupNormalizer, NaNLabelEncoder +from pytorch_forecasting.data.examples import generate_ar_data, get_stallion_data + +torch.manual_seed(23) + + +@pytest.fixture(scope="session") +def gpus(): + if torch.cuda.is_available(): + return [0] + else: + return 0 + + +def data_with_covariates(): + data = get_stallion_data() + data["month"] = data.date.dt.month.astype(str) + data["log_volume"] = np.log1p(data.volume) + data["weight"] = 1 + np.sqrt(data.volume) + + data["time_idx"] = data["date"].dt.year * 12 + data["date"].dt.month + data["time_idx"] -= data["time_idx"].min() + + # 
convert special days into strings + special_days = [ + "easter_day", + "good_friday", + "new_year", + "christmas", + "labor_day", + "independence_day", + "revolution_day_memorial", + "regional_games", + "fifa_u_17_world_cup", + "football_gold_cup", + "beer_capital", + "music_fest", + ] + data[special_days] = ( + data[special_days].apply(lambda x: x.map({0: "", 1: x.name})).astype("category") + ) + data = data.astype(dict(industry_volume=float)) + + # select data subset + data = data[lambda x: x.sku.isin(data.sku.unique()[:2])][ + lambda x: x.agency.isin(data.agency.unique()[:2]) + ] + + # default target + data["target"] = data["volume"].clip(1e-3, 1.0) + + return data + + +def make_dataloaders(data_with_covariates, **kwargs): + training_cutoff = "2016-09-01" + max_encoder_length = 4 + max_prediction_length = 3 + + kwargs.setdefault("target", "volume") + kwargs.setdefault("group_ids", ["agency", "sku"]) + kwargs.setdefault("add_relative_time_idx", True) + kwargs.setdefault("time_varying_unknown_reals", ["volume"]) + + training = TimeSeriesDataSet( + data_with_covariates[lambda x: x.date < training_cutoff].copy(), + time_idx="time_idx", + max_encoder_length=max_encoder_length, + max_prediction_length=max_prediction_length, + **kwargs, # fixture parametrization + ) + + validation = TimeSeriesDataSet.from_dataset( + training, + data_with_covariates.copy(), + min_prediction_idx=training.index.time.max() + 1, + ) + train_dataloader = training.to_dataloader(train=True, batch_size=2, num_workers=0) + val_dataloader = validation.to_dataloader(train=False, batch_size=2, num_workers=0) + test_dataloader = validation.to_dataloader(train=False, batch_size=1, num_workers=0) + + return dict(train=train_dataloader, val=val_dataloader, test=test_dataloader) + + +@pytest.fixture( + params=[ + dict(), + dict( + static_categoricals=["agency", "sku"], + static_reals=["avg_population_2017", "avg_yearly_household_income_2017"], + time_varying_known_categoricals=["special_days", "month"], 
+ variable_groups=dict( + special_days=[ + "easter_day", + "good_friday", + "new_year", + "christmas", + "labor_day", + "independence_day", + "revolution_day_memorial", + "regional_games", + "fifa_u_17_world_cup", + "football_gold_cup", + "beer_capital", + "music_fest", + ] + ), + time_varying_known_reals=[ + "time_idx", + "price_regular", + "price_actual", + "discount", + "discount_in_percent", + ], + time_varying_unknown_categoricals=[], + time_varying_unknown_reals=[ + "volume", + "log_volume", + "industry_volume", + "soda_volume", + "avg_max_temp", + ], + constant_fill_strategy={"volume": 0}, + categorical_encoders={"sku": NaNLabelEncoder(add_nan=True)}, + ), + dict(static_categoricals=["agency", "sku"]), + dict(randomize_length=True, min_encoder_length=2), + dict(target_normalizer=EncoderNormalizer(), min_encoder_length=2), + dict(target_normalizer=GroupNormalizer(transformation="log1p")), + dict( + target_normalizer=GroupNormalizer( + groups=["agency", "sku"], transformation="softplus", center=False + ) + ), + dict(target="agency"), + # test multiple targets + dict(target=["industry_volume", "volume"]), + dict(target=["agency", "volume"]), + dict( + target=["agency", "volume"], min_encoder_length=1, min_prediction_length=1 + ), + dict(target=["agency", "volume"], weight="volume"), + # test weights + dict(target="volume", weight="volume"), + ], + scope="session", +) +def multiple_dataloaders_with_covariates(data_with_covariates, request): + return make_dataloaders(data_with_covariates, **request.param) + + +@pytest.fixture(scope="session") +def dataloaders_with_different_encoder_decoder_length(data_with_covariates): + return make_dataloaders( + data_with_covariates.copy(), + target="target", + time_varying_known_categoricals=["special_days", "month"], + variable_groups=dict( + special_days=[ + "easter_day", + "good_friday", + "new_year", + "christmas", + "labor_day", + "independence_day", + "revolution_day_memorial", + "regional_games", + 
"fifa_u_17_world_cup", + "football_gold_cup", + "beer_capital", + "music_fest", + ] + ), + time_varying_known_reals=[ + "time_idx", + "price_regular", + "price_actual", + "discount", + "discount_in_percent", + ], + time_varying_unknown_categoricals=[], + time_varying_unknown_reals=[ + "target", + "volume", + "log_volume", + "industry_volume", + "soda_volume", + "avg_max_temp", + ], + static_categoricals=["agency"], + add_relative_time_idx=False, + target_normalizer=GroupNormalizer(groups=["agency", "sku"], center=False), + ) + + +@pytest.fixture(scope="session") +def dataloaders_with_covariates(data_with_covariates): + return make_dataloaders( + data_with_covariates.copy(), + target="target", + time_varying_known_reals=["discount"], + time_varying_unknown_reals=["target"], + static_categoricals=["agency"], + add_relative_time_idx=False, + target_normalizer=GroupNormalizer(groups=["agency", "sku"], center=False), + ) + + +@pytest.fixture(scope="session") +def dataloaders_multi_target(data_with_covariates): + return make_dataloaders( + data_with_covariates.copy(), + time_varying_unknown_reals=["target", "discount"], + target=["target", "discount"], + add_relative_time_idx=False, + ) + + +@pytest.fixture(scope="session") +def dataloaders_fixed_window_without_covariates(): + data = generate_ar_data(seasonality=10.0, timesteps=50, n_series=2) + validation = data.series.iloc[:2] + + max_encoder_length = 30 + max_prediction_length = 10 + + training = TimeSeriesDataSet( + data[lambda x: ~x.series.isin(validation)], + time_idx="time_idx", + target="value", + categorical_encoders={"series": NaNLabelEncoder().fit(data.series)}, + group_ids=["series"], + static_categoricals=[], + max_encoder_length=max_encoder_length, + max_prediction_length=max_prediction_length, + time_varying_unknown_reals=["value"], + target_normalizer=EncoderNormalizer(), + ) + + validation = TimeSeriesDataSet.from_dataset( + training, + data[lambda x: x.series.isin(validation)], + 
stop_randomization=True, + ) + batch_size = 2 + train_dataloader = training.to_dataloader( + train=True, batch_size=batch_size, num_workers=0 + ) + val_dataloader = validation.to_dataloader( + train=False, batch_size=batch_size, num_workers=0 + ) + test_dataloader = validation.to_dataloader( + train=False, batch_size=batch_size, num_workers=0 + ) + + return dict(train=train_dataloader, val=val_dataloader, test=test_dataloader) diff --git a/pytorch_forecasting/tests/test_all_estimators.py b/pytorch_forecasting/tests/test_all_estimators.py index c2d86e40e..b8a21cc6a 100644 --- a/pytorch_forecasting/tests/test_all_estimators.py +++ b/pytorch_forecasting/tests/test_all_estimators.py @@ -6,7 +6,6 @@ import lightning.pytorch as pl from lightning.pytorch.callbacks import EarlyStopping from lightning.pytorch.loggers import TensorBoardLogger -import pytest from skbase.testing import BaseFixtureGenerator as _BaseFixtureGenerator from pytorch_forecasting._registry import all_objects @@ -276,11 +275,13 @@ def test_integration( self, object_metadata, trainer_kwargs, - data_with_covariates, tmp_path, ): """Fails for certain, for testing.""" from pytorch_forecasting.metrics import NegativeBinomialDistributionLoss + from pytorch_forecasting.tests._data_scenarios import data_with_covariates + + data_with_covariates = data_with_covariates() object_class = object_metadata.get_model_cls() From a6691341001f813ac4c5d12aafb173645087d679 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Sun, 4 May 2025 18:28:04 +0200 Subject: [PATCH 22/26] Update _lookup.py --- pytorch_forecasting/_registry/_lookup.py | 49 ++++++++++++++++-------- 1 file changed, 32 insertions(+), 17 deletions(-) diff --git a/pytorch_forecasting/_registry/_lookup.py b/pytorch_forecasting/_registry/_lookup.py index 828c448b1..b4238f980 100644 --- a/pytorch_forecasting/_registry/_lookup.py +++ b/pytorch_forecasting/_registry/_lookup.py @@ -11,6 +11,7 @@ __author__ = ["fkiraly"] # all_objects is based on 
the sklearn utility all_estimators +from inspect import isclass from pathlib import Path from skbase.lookup import all_objects as _all_objects @@ -133,25 +134,39 @@ def all_objects( result = [] ROOT = str(Path(__file__).parent.parent) # package root directory - if isinstance(filter_tags, str): - filter_tags = {filter_tags: True} - filter_tags = filter_tags.copy() if filter_tags else None - - if object_types: - if filter_tags and "object_type" not in filter_tags.keys(): - object_tag_filter = {"object_type": object_types} - elif filter_tags: - filter_tags_filter = filter_tags.get("object_type", []) - if isinstance(object_types, str): - object_types = [object_types] - object_tag_update = {"object_type": object_types + filter_tags_filter} - filter_tags.update(object_tag_update) + def _coerce_to_str(obj): + if isinstance(obj, (list, tuple)): + return [_coerce_to_str(o) for o in obj] + if isclass(obj): + obj = obj.get_class_tag("object_type") # get_tag needs an instance + return obj + + def _coerce_to_list_of_str(obj): + obj = _coerce_to_str(obj) + if isinstance(obj, str): + return [obj] + return obj + + if object_types is not None: + object_types = _coerce_to_list_of_str(object_types) + object_types = list(set(object_types)) + + if object_types is not None: + if filter_tags is None: + filter_tags = {} + elif isinstance(filter_tags, str): + filter_tags = {filter_tags: True} else: - object_tag_filter = {"object_type": object_types} - if filter_tags: - filter_tags.update(object_tag_filter) + filter_tags = filter_tags.copy() + + if "object_type" in filter_tags: + obj_field = filter_tags["object_type"] + obj_field = _coerce_to_list_of_str(obj_field) + obj_field = obj_field + object_types else: - filter_tags = object_tag_filter + obj_field = object_types + + filter_tags["object_type"] = obj_field result = _all_objects( object_types=[_BaseObject], From d78bf5dc19cef1e659e8258552691c1713b2dd4f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Sun, 4 May 2025 18:32:38 +0200 Subject: [PATCH
23/26] Update _lookup.py --- pytorch_forecasting/_registry/_lookup.py | 95 +++++++++++++++--------- 1 file changed, 60 insertions(+), 35 deletions(-) diff --git a/pytorch_forecasting/_registry/_lookup.py b/pytorch_forecasting/_registry/_lookup.py index b4238f980..0fb4c0c9d 100644 --- a/pytorch_forecasting/_registry/_lookup.py +++ b/pytorch_forecasting/_registry/_lookup.py @@ -40,44 +40,64 @@ def all_objects( ---------- object_types: str, list of str, optional (default=None) Which kind of objects should be returned. - if None, no filter is applied and all objects are returned. - if str or list of str, strings define scitypes specified in search - only objects that are of (at least) one of the scitypes are returned - possible str values are entries of registry.BASE_CLASS_REGISTER (first col) - for instance 'regrssor_proba', 'distribution, 'metric' - return_names: bool, optional (default=True) + * if None, no filter is applied and all objects are returned. + * if str or list of str, strings define scitypes specified in search + only objects that are of (at least) one of the scitypes are returned - if True, estimator class name is included in the ``all_objects`` - return in the order: name, estimator class, optional tags, either as - a tuple or as pandas.DataFrame columns + return_names: bool, optional (default=True) - if False, estimator class name is removed from the ``all_objects`` return. + * if True, estimator class name is included in the ``all_objects`` + return in the order: name, estimator class, optional tags, either as + a tuple or as pandas.DataFrame columns + * if False, estimator class name is removed from the ``all_objects`` return. - filter_tags: dict of (str or list of str), optional (default=None) + filter_tags: dict of (str or list of str or re.Pattern), optional (default=None) For a list of valid tag strings, use the registry.all_tags utility. 
- ``filter_tags`` subsets the returned estimators as follows: + ``filter_tags`` subsets the returned objects as follows: * each key/value pair is statement in "and"/conjunction * key is tag name to sub-set on * value str or list of string are tag values * condition is "key must be equal to value, or in set(value)" - exclude_estimators: str, list of str, optional (default=None) - Names of estimators to exclude. + In detail, the return will be filtered to keep exactly the classes + where tags satisfy all the filter conditions specified by ``filter_tags``. + Filter conditions are as follows, for ``tag_name: search_value`` pairs in + the ``filter_tags`` dict, applied to a class ``klass``: + + - If ``klass`` does not have a tag with name ``tag_name``, it is excluded. + Otherwise, let ``tag_value`` be the value of the tag with name ``tag_name``. + - If ``search_value`` is a string, and ``tag_value`` is a string, + the filter condition is that ``search_value`` must match the tag value. + - If ``search_value`` is a string, and ``tag_value`` is a list, + the filter condition is that ``search_value`` is contained in ``tag_value``. + - If ``search_value`` is a ``re.Pattern``, and ``tag_value`` is a string, + the filter condition is that ``search_value.fullmatch(tag_value)`` + is true, i.e., the regex matches the tag value. + - If ``search_value`` is a ``re.Pattern``, and ``tag_value`` is a list, + the filter condition is that at least one element of ``tag_value`` + matches the regex. + - If ``search_value`` is iterable, then the filter condition is that + at least one element of ``search_value`` satisfies the above conditions, + applied to ``tag_value``. + + Note: ``re.Pattern`` is supported only from ``scikit-base`` version 0.8.0. + + exclude_objects: str, list of str, optional (default=None) + Names of objects to exclude.
as_dataframe: bool, optional (default=False) - True: ``all_objects`` will return a pandas.DataFrame with named - columns for all of the attributes being returned. - - False: ``all_objects`` will return a list (either a list of - estimators or a list of tuples, see Returns) + * True: ``all_objects`` will return a ``pandas.DataFrame`` with named + columns for all of the attributes being returned. + * False: ``all_objects`` will return a list (either a list of + objects or a list of tuples, see Returns) return_tags: str or list of str, optional (default=None) Names of tags to fetch and return each estimator's value of. - For a list of valid tag strings, use the registry.all_tags utility. + For a list of valid tag strings, use the ``registry.all_tags`` utility. if str or list of str, the tag values named in return_tags will be fetched for each estimator and will be appended as either columns or tuple entries. @@ -88,27 +108,32 @@ def all_objects( Returns ------- all_objects will return one of the following: - 1. list of objects, if return_names=False, and return_tags is None - 2. list of tuples (optional object name, class, ~optional object - tags), if return_names=True or return_tags is not None. - 3. pandas.DataFrame if as_dataframe = True + + 1. list of objects, if ``return_names=False``, and ``return_tags`` is None + + 2. list of tuples (optional estimator name, class, optional estimator + tags), if ``return_names=True`` or ``return_tags`` is not ``None``. + + 3. 
``pandas.DataFrame`` if ``as_dataframe = True`` + if list of objects: entries are objects matching the query, - in alphabetical order of object name + in alphabetical order of estimator name + if list of tuples: - list of (optional object name, object, optional object - tags) matching the query, in alphabetical order of object name, + list of (optional estimator name, estimator, optional estimator + tags) matching the query, in alphabetical order of estimator name, where - ``name`` is the object name as string, and is an - optional return - ``object`` is the actual object - ``tags`` are the object's values for each tag in return_tags - and is an optional return. - if dataframe: - all_objects will return a pandas.DataFrame. + ``name`` is the estimator name as string, and is an + optional return + ``estimator`` is the actual estimator + ``tags`` are the estimator's values for each tag in return_tags + and is an optional return. + + if ``DataFrame``: column names represent the attributes contained in each column. "objects" will be the name of the column of objects, "names" - will be the name of the column of object class names and the string(s) + will be the name of the column of estimator class names and the string(s) passed in return_tags will serve as column names for all columns of tags that were optionally requested. 
From 94e8ce8bbe73ecd1c6769b39a0416ce5ed8de69f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Tue, 27 May 2025 22:38:27 +0200 Subject: [PATCH 24/26] move cell_type and n_plotting samples to kwargs --- .../models/deepar/_deepar_metadata.py | 22 ++++++++++++++++--- .../tests/test_all_estimators.py | 3 --- 2 files changed, 19 insertions(+), 6 deletions(-) diff --git a/pytorch_forecasting/models/deepar/_deepar_metadata.py b/pytorch_forecasting/models/deepar/_deepar_metadata.py index e477d63b0..a9eb46a04 100644 --- a/pytorch_forecasting/models/deepar/_deepar_metadata.py +++ b/pytorch_forecasting/models/deepar/_deepar_metadata.py @@ -47,7 +47,7 @@ def get_test_train_params(cls): return [ {}, - {"cell_type": "GRU"}, + {"cell_type": "GRU", "n_plotting_samples": 100}, dict( loss=LogNormalDistributionLoss(), clip_target=True, @@ -56,6 +56,8 @@ def get_test_train_params(cls): groups=["agency", "sku"], transformation="log" ) ), + cell_type="LSTM", + n_plotting_samples=100, ), dict( loss=NegativeBinomialDistributionLoss(), @@ -65,6 +67,8 @@ def get_test_train_params(cls): groups=["agency", "sku"], center=False ) ), + cell_type="LSTM", + n_plotting_samples=100, ), dict( loss=BetaDistributionLoss(), @@ -74,6 +78,8 @@ def get_test_train_params(cls): groups=["agency", "sku"], transformation="logit" ) ), + cell_type="LSTM", + n_plotting_samples=100, ), dict( data_loader_kwargs=dict( @@ -81,20 +87,28 @@ def get_test_train_params(cls): target="volume", time_varying_unknown_reals=["volume"], min_encoder_length=2, - ) + ), + cell_type="LSTM", + n_plotting_samples=100, ), dict( data_loader_kwargs=dict( time_varying_unknown_reals=["volume", "discount"], target=["volume", "discount"], lags={"volume": [2], "discount": [2]}, - ) + ), + cell_type="LSTM", + n_plotting_samples=100, ), dict( loss=ImplicitQuantileNetworkDistributionLoss(hidden_size=8), + cell_type="LSTM", + n_plotting_samples=100, ), dict( loss=MultivariateNormalDistributionLoss(), + cell_type="LSTM", + 
n_plotting_samples=100, trainer_kwargs=dict(accelerator="cpu"), ), dict( @@ -104,6 +118,8 @@ def get_test_train_params(cls): groups=["agency", "sku"], transformation="log1p" ) ), + cell_type="LSTM", + n_plotting_samples=100, trainer_kwargs=dict(accelerator="cpu"), ), ] diff --git a/pytorch_forecasting/tests/test_all_estimators.py b/pytorch_forecasting/tests/test_all_estimators.py index b8a21cc6a..529cd7bb4 100644 --- a/pytorch_forecasting/tests/test_all_estimators.py +++ b/pytorch_forecasting/tests/test_all_estimators.py @@ -171,7 +171,6 @@ def _integration( estimator_cls, data_with_covariates, tmp_path, - cell_type="LSTM", data_loader_kwargs={}, clip_target: bool = False, trainer_kwargs=None, @@ -221,11 +220,9 @@ def _integration( net = estimator_cls.from_dataset( train_dataloader.dataset, hidden_size=5, - cell_type=cell_type, learning_rate=0.01, log_gradient_flow=True, log_interval=1000, - n_plotting_samples=100, **kwargs, ) net.size() From 5d2d00132b69a173d726d78b8698aeb77954d4f9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Tue, 27 May 2025 22:40:27 +0200 Subject: [PATCH 25/26] doctest runner fixed --- pytorch_forecasting/tests/test_all_estimators.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/pytorch_forecasting/tests/test_all_estimators.py b/pytorch_forecasting/tests/test_all_estimators.py index 529cd7bb4..535ab763c 100644 --- a/pytorch_forecasting/tests/test_all_estimators.py +++ b/pytorch_forecasting/tests/test_all_estimators.py @@ -7,6 +7,7 @@ from lightning.pytorch.callbacks import EarlyStopping from lightning.pytorch.loggers import TensorBoardLogger from skbase.testing import BaseFixtureGenerator as _BaseFixtureGenerator +from skbase.utils.doctest_run import run_doctest from pytorch_forecasting._registry import all_objects from pytorch_forecasting.tests._config import EXCLUDE_ESTIMATORS, EXCLUDED_TESTS @@ -264,9 +265,7 @@ class TestAllPtForecasters(PackageConfig, BaseFixtureGenerator): def 
test_doctest_examples(self, object_class): """Runs doctests for estimator class.""" - import doctest - - doctest.run_docstring_examples(object_class, globals()) + run_doctest(object_class, name=f"class {object_class.__name__}") def test_integration( self, From 487c98119c576a4060ef3bc14f2f565a6e60347f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Tue, 27 May 2025 23:21:45 +0200 Subject: [PATCH 26/26] Update test_all_estimators.py --- pytorch_forecasting/tests/test_all_estimators.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pytorch_forecasting/tests/test_all_estimators.py b/pytorch_forecasting/tests/test_all_estimators.py index 535ab763c..3e046cda1 100644 --- a/pytorch_forecasting/tests/test_all_estimators.py +++ b/pytorch_forecasting/tests/test_all_estimators.py @@ -7,7 +7,6 @@ from lightning.pytorch.callbacks import EarlyStopping from lightning.pytorch.loggers import TensorBoardLogger from skbase.testing import BaseFixtureGenerator as _BaseFixtureGenerator -from skbase.utils.doctest_run import run_doctest from pytorch_forecasting._registry import all_objects from pytorch_forecasting.tests._config import EXCLUDE_ESTIMATORS, EXCLUDED_TESTS @@ -265,6 +264,8 @@ class TestAllPtForecasters(PackageConfig, BaseFixtureGenerator): def test_doctest_examples(self, object_class): """Runs doctests for estimator class.""" + from skbase.utils.doctest_run import run_doctest + run_doctest(object_class, name=f"class {object_class.__name__}") def test_integration(