diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index ccbb25b33..244cdbac2 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -112,7 +112,7 @@ jobs: - name: Run pytest shell: bash - run: python -m pytest tests + run: python -m pytest pytest: name: Run pytest @@ -152,7 +152,7 @@ jobs: - name: Run pytest shell: bash - run: python -m pytest tests + run: python -m pytest - name: Statistics run: | diff --git a/pyproject.toml b/pyproject.toml index d7256e782..7ebac1a32 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -102,6 +102,7 @@ dev = [ "pytest-dotenv>=0.5.2,<1.0.0", "tensorboard>=2.12.1,<3.0.0", "pandoc>=2.3,<3.0.0", + "scikit-base", ] # docs - dependencies for building the documentation diff --git a/pytest.ini b/pytest.ini index 457863f87..52f4fa1c1 100644 --- a/pytest.ini +++ b/pytest.ini @@ -10,7 +10,9 @@ addopts = --no-cov-on-fail markers = -testpaths = tests/ +testpaths = + tests/ + pytorch_forecasting/tests/ log_cli_level = ERROR log_format = %(asctime)s %(levelname)s %(message)s log_date_format = %Y-%m-%d %H:%M:%S diff --git a/pytorch_forecasting/_registry/__init__.py b/pytorch_forecasting/_registry/__init__.py new file mode 100644 index 000000000..f71836bfe --- /dev/null +++ b/pytorch_forecasting/_registry/__init__.py @@ -0,0 +1,5 @@ +"""PyTorch Forecasting registry.""" + +from pytorch_forecasting._registry._lookup import all_objects + +__all__ = ["all_objects"] diff --git a/pytorch_forecasting/_registry/_lookup.py b/pytorch_forecasting/_registry/_lookup.py new file mode 100644 index 000000000..0fb4c0c9d --- /dev/null +++ b/pytorch_forecasting/_registry/_lookup.py @@ -0,0 +1,242 @@ +"""Registry lookup methods. + +This module exports the following methods for registry lookup: + +all_objects(object_types, filter_tags) + lookup and filtering of objects +""" + +# based on the sktime module of same name + +__author__ = ["fkiraly"] +# all_objects is based on the sklearn utility all_estimators + +from inspect import isclass +from pathlib import Path + +from skbase.lookup import all_objects as _all_objects + +from pytorch_forecasting.models.base import _BaseObject + + +def all_objects( + object_types=None, + filter_tags=None, + exclude_objects=None, + return_names=True, + as_dataframe=False, + return_tags=None, + suppress_import_stdout=True, +): + """Get a list of all objects from pytorch_forecasting. + + This function crawls the module and gets all classes that inherit + from skbase compatible base classes. + + Not included are: the base classes themselves, classes defined in test + modules. + + Parameters + ---------- + object_types: str, list of str, optional (default=None) + Which kind of objects should be returned. + + * if None, no filter is applied and all objects are returned. + * if str or list of str, strings define scitypes specified in search + only objects that are of (at least) one of the scitypes are returned + + return_names: bool, optional (default=True) + + * if True, estimator class name is included in the ``all_objects`` + return in the order: name, estimator class, optional tags, either as + a tuple or as pandas.DataFrame columns + * if False, estimator class name is removed from the ``all_objects`` return. + + filter_tags: dict of (str or list of str or re.Pattern), optional (default=None) + For a list of valid tag strings, use the registry.all_tags utility. + + ``filter_tags`` subsets the returned objects as follows: + + * each key/value pair is statement in "and"/conjunction + * key is tag name to sub-set on + * value str or list of string are tag values + * condition is "key must be equal to value, or in set(value)" + + In detail, he return will be filtered to keep exactly the classes + where tags satisfy all the filter conditions specified by ``filter_tags``. + Filter conditions are as follows, for ``tag_name: search_value`` pairs in + the ``filter_tags`` dict, applied to a class ``klass``: + + - If ``klass`` does not have a tag with name ``tag_name``, it is excluded. + Otherwise, let ``tag_value`` be the value of the tag with name ``tag_name``. + - If ``search_value`` is a string, and ``tag_value`` is a string, + the filter condition is that ``search_value`` must match the tag value. + - If ``search_value`` is a string, and ``tag_value`` is a list, + the filter condition is that ``search_value`` is contained in ``tag_value``. + - If ``search_value`` is a ``re.Pattern``, and ``tag_value`` is a string, + the filter condition is that ``search_value.fullmatch(tag_value)`` + is true, i.e., the regex matches the tag value. + - If ``search_value`` is a ``re.Pattern``, and ``tag_value`` is a list, + the filter condition is that at least one element of ``tag_value`` + matches the regex. + - If ``search_value`` is iterable, then the filter condition is that + at least one element of ``search_value`` satisfies the above conditions, + applied to ``tag_value``. + + Note: ``re.Pattern`` is supported only from ``scikit-base`` version 0.8.0. + + exclude_objects: str, list of str, optional (default=None) + Names of objects to exclude. + + as_dataframe: bool, optional (default=False) + + * True: ``all_objects`` will return a ``pandas.DataFrame`` with named + columns for all of the attributes being returned. + * False: ``all_objects`` will return a list (either a list of + objects or a list of tuples, see Returns) + + return_tags: str or list of str, optional (default=None) + Names of tags to fetch and return each estimator's value of. + For a list of valid tag strings, use the ``registry.all_tags`` utility. + if str or list of str, + the tag values named in return_tags will be fetched for each + estimator and will be appended as either columns or tuple entries. + + suppress_import_stdout : bool, optional. Default=True + whether to suppress stdout printout upon import. + + Returns + ------- + all_objects will return one of the following: + + 1. list of objects, if ``return_names=False``, and ``return_tags`` is None + + 2. list of tuples (optional estimator name, class, optional estimator + tags), if ``return_names=True`` or ``return_tags`` is not ``None``. + + 3. ``pandas.DataFrame`` if ``as_dataframe = True`` + + if list of objects: + entries are objects matching the query, + in alphabetical order of estimator name + + if list of tuples: + list of (optional estimator name, estimator, optional estimator + tags) matching the query, in alphabetical order of estimator name, + where + ``name`` is the estimator name as string, and is an + optional return + ``estimator`` is the actual estimator + ``tags`` are the estimator's values for each tag in return_tags + and is an optional return. + + if ``DataFrame``: + column names represent the attributes contained in each column. + "objects" will be the name of the column of objects, "names" + will be the name of the column of estimator class names and the string(s) + passed in return_tags will serve as column names for all columns of + tags that were optionally requested. + + Examples + -------- + >>> from pytorch_forecasting._registry import all_objects + >>> # return a complete list of objects as pd.Dataframe + >>> all_objects(as_dataframe=True) # doctest: +SKIP + + References + ---------- + Adapted version of sktime's ``all_estimators``, + which is an evolution of scikit-learn's ``all_estimators`` + """ + MODULES_TO_IGNORE = ( + "tests", + "setup", + "contrib", + "utils", + "all", + ) + + result = [] + ROOT = str(Path(__file__).parent.parent) # package root directory + + def _coerce_to_str(obj): + if isinstance(obj, (list, tuple)): + return [_coerce_to_str(o) for o in obj] + if isclass(obj): + obj = obj.get_tag("object_type") + return obj + + def _coerce_to_list_of_str(obj): + obj = _coerce_to_str(obj) + if isinstance(obj, str): + return [obj] + return obj + + if object_types is not None: + object_types = _coerce_to_list_of_str(object_types) + object_types = list(set(object_types)) + + if object_types is not None: + if filter_tags is None: + filter_tags = {} + elif isinstance(filter_tags, str): + filter_tags = {filter_tags: True} + else: + filter_tags = filter_tags.copy() + + if "object_type" in filter_tags: + obj_field = filter_tags["object_type"] + obj_field = _coerce_to_list_of_str(obj_field) + obj_field = obj_field + object_types + else: + obj_field = object_types + + filter_tags["object_type"] = obj_field + + result = _all_objects( + object_types=[_BaseObject], + filter_tags=filter_tags, + exclude_objects=exclude_objects, + return_names=return_names, + as_dataframe=as_dataframe, + return_tags=return_tags, + suppress_import_stdout=suppress_import_stdout, + package_name="pytorch_forecasting", + path=ROOT, + modules_to_ignore=MODULES_TO_IGNORE, + ) + + return result + + +def _check_list_of_str_or_error(arg_to_check, arg_name): + """Check that certain arguments are str or list of str. + + Parameters + ---------- + arg_to_check: argument we are testing the type of + arg_name: str, + name of the argument we are testing, will be added to the error if + ``arg_to_check`` is not a str or a list of str + + Returns + ------- + arg_to_check: list of str, + if arg_to_check was originally a str it converts it into a list of str + so that it can be iterated over. + + Raises + ------ + TypeError if arg_to_check is not a str or list of str + """ + # check that return_tags has the right type: + if isinstance(arg_to_check, str): + arg_to_check = [arg_to_check] + if not isinstance(arg_to_check, list) or not all( + isinstance(value, str) for value in arg_to_check + ): + raise TypeError( + f"Error in all_objects! Argument {arg_name} must be either\ + a str or list of str" + ) + return arg_to_check diff --git a/pytorch_forecasting/models/base/__init__.py b/pytorch_forecasting/models/base/__init__.py index 4860e4838..7b69ec246 100644 --- a/pytorch_forecasting/models/base/__init__.py +++ b/pytorch_forecasting/models/base/__init__.py @@ -7,8 +7,14 @@ BaseModelWithCovariates, Prediction, ) +from pytorch_forecasting.models.base._base_object import ( + _BaseObject, + _BasePtForecaster, +) __all__ = [ + "_BaseObject", + "_BasePtForecaster", "AutoRegressiveBaseModel", "AutoRegressiveBaseModelWithCovariates", "BaseModel", diff --git a/pytorch_forecasting/models/base/_base_object.py b/pytorch_forecasting/models/base/_base_object.py new file mode 100644 index 000000000..7fd59d6a4 --- /dev/null +++ b/pytorch_forecasting/models/base/_base_object.py @@ -0,0 +1,115 @@ +"""Base Classes for pytorch-forecasting models, skbase compatible for indexing.""" + +import inspect + +from pytorch_forecasting.utils._dependencies import _safe_import + +_SkbaseBaseObject = _safe_import("skbase.base.BaseObject", pkg_name="scikit-base") + + +class _BaseObject(_SkbaseBaseObject): + + pass + + +class _BasePtForecaster(_BaseObject): + """Base class for all PyTorch Forecasting forecaster metadata. + + This class points to model objects and contains metadata as tags. + """ + + _tags = { + "object_type": "forecaster_pytorch", + } + + @classmethod + def get_model_cls(cls): + """Get model class.""" + raise NotImplementedError + + @classmethod + def name(cls): + """Get model name.""" + name = cls.get_class_tags().get("info:name", None) + if name is None: + name = cls.get_model_cls().__name__ + return name + + @classmethod + def create_test_instance(cls, parameter_set="default"): + """Construct an instance of the class, using first test parameter set. + + Parameters + ---------- + parameter_set : str, default="default" + Name of the set of test parameters to return, for use in tests. If no + special parameters are defined for a value, will return `"default"` set. + + Returns + ------- + instance : instance of the class with default parameters + + """ + if "parameter_set" in inspect.getfullargspec(cls.get_test_params).args: + params = cls.get_test_params(parameter_set=parameter_set) + else: + params = cls.get_test_params() + + if isinstance(params, list) and isinstance(params[0], dict): + params = params[0] + elif isinstance(params, dict): + pass + else: + raise TypeError( + "get_test_params should either return a dict or list of dict." + ) + + return cls.get_model_cls()(**params) + + @classmethod + def create_test_instances_and_names(cls, parameter_set="default"): + """Create list of all test instances and a list of names for them. + + Parameters + ---------- + parameter_set : str, default="default" + Name of the set of test parameters to return, for use in tests. If no + special parameters are defined for a value, will return `"default"` set. + + Returns + ------- + objs : list of instances of cls + i-th instance is ``cls(**cls.get_test_params()[i])`` + names : list of str, same length as objs + i-th element is name of i-th instance of obj in tests. + The naming convention is ``{cls.__name__}-{i}`` if more than one instance, + otherwise ``{cls.__name__}`` + """ + if "parameter_set" in inspect.getfullargspec(cls.get_test_params).args: + param_list = cls.get_test_params(parameter_set=parameter_set) + else: + param_list = cls.get_test_params() + + objs = [] + if not isinstance(param_list, (dict, list)): + raise RuntimeError( + f"Error in {cls.__name__}.get_test_params, " + "return must be param dict for class, or list thereof" + ) + if isinstance(param_list, dict): + param_list = [param_list] + for params in param_list: + if not isinstance(params, dict): + raise RuntimeError( + f"Error in {cls.__name__}.get_test_params, " + "return must be param dict for class, or list thereof" + ) + objs += [cls.get_model_cls()(**params)] + + num_instances = len(param_list) + if num_instances > 1: + names = [cls.__name__ + "-" + str(i) for i in range(num_instances)] + else: + names = [cls.__name__] + + return objs, names diff --git a/pytorch_forecasting/models/deepar/__init__.py b/pytorch_forecasting/models/deepar/__init__.py index 679f296f6..149e19e0d 100644 --- a/pytorch_forecasting/models/deepar/__init__.py +++ b/pytorch_forecasting/models/deepar/__init__.py @@ -1,5 +1,6 @@ """DeepAR: Probabilistic forecasting with autoregressive recurrent networks.""" from pytorch_forecasting.models.deepar._deepar import DeepAR +from pytorch_forecasting.models.deepar._deepar_metadata import DeepARMetadata -__all__ = ["DeepAR"] +__all__ = ["DeepAR", "DeepARMetadata"] diff --git a/pytorch_forecasting/models/deepar/_deepar_metadata.py b/pytorch_forecasting/models/deepar/_deepar_metadata.py new file mode 100644 index 000000000..a9eb46a04 --- /dev/null +++ b/pytorch_forecasting/models/deepar/_deepar_metadata.py @@ -0,0 +1,125 @@ +"""DeepAR metadata container.""" + +from pytorch_forecasting.models.base._base_object import _BasePtForecaster + + +class DeepARMetadata(_BasePtForecaster): + """DeepAR metadata container.""" + + _tags = { + "info:name": "DeepAR", + "info:compute": 3, + "authors": ["jdb78"], + "capability:exogenous": True, + "capability:multivariate": True, + "capability:pred_int": True, + "capability:flexible_history_length": True, + "capability:cold_start": False, + } + + @classmethod + def get_model_cls(cls): + """Get model class.""" + from pytorch_forecasting.models import DeepAR + + return DeepAR + + @classmethod + def get_test_train_params(cls): + """Return testing parameter settings for the trainer. + + Returns + ------- + params : dict or list of dict, default = {} + Parameters to create testing instances of the class + Each dict are parameters to construct an "interesting" test instance, i.e., + `MyClass(**params)` or `MyClass(**params[i])` creates a valid test instance. + `create_test_instance` uses the first (or only) dictionary in `params` + """ + from pytorch_forecasting.data.encoders import GroupNormalizer + from pytorch_forecasting.metrics import ( + BetaDistributionLoss, + ImplicitQuantileNetworkDistributionLoss, + LogNormalDistributionLoss, + MultivariateNormalDistributionLoss, + NegativeBinomialDistributionLoss, + ) + + return [ + {}, + {"cell_type": "GRU", "n_plotting_samples": 100}, + dict( + loss=LogNormalDistributionLoss(), + clip_target=True, + data_loader_kwargs=dict( + target_normalizer=GroupNormalizer( + groups=["agency", "sku"], transformation="log" + ) + ), + cell_type="LSTM", + n_plotting_samples=100, + ), + dict( + loss=NegativeBinomialDistributionLoss(), + clip_target=False, + data_loader_kwargs=dict( + target_normalizer=GroupNormalizer( + groups=["agency", "sku"], center=False + ) + ), + cell_type="LSTM", + n_plotting_samples=100, + ), + dict( + loss=BetaDistributionLoss(), + clip_target=True, + data_loader_kwargs=dict( + target_normalizer=GroupNormalizer( + groups=["agency", "sku"], transformation="logit" + ) + ), + cell_type="LSTM", + n_plotting_samples=100, + ), + dict( + data_loader_kwargs=dict( + lags={"volume": [2, 5]}, + target="volume", + time_varying_unknown_reals=["volume"], + min_encoder_length=2, + ), + cell_type="LSTM", + n_plotting_samples=100, + ), + dict( + data_loader_kwargs=dict( + time_varying_unknown_reals=["volume", "discount"], + target=["volume", "discount"], + lags={"volume": [2], "discount": [2]}, + ), + cell_type="LSTM", + n_plotting_samples=100, + ), + dict( + loss=ImplicitQuantileNetworkDistributionLoss(hidden_size=8), + cell_type="LSTM", + n_plotting_samples=100, + ), + dict( + loss=MultivariateNormalDistributionLoss(), + cell_type="LSTM", + n_plotting_samples=100, + trainer_kwargs=dict(accelerator="cpu"), + ), + dict( + loss=MultivariateNormalDistributionLoss(), + data_loader_kwargs=dict( + target_normalizer=GroupNormalizer( + groups=["agency", "sku"], transformation="log1p" + ) + ), + cell_type="LSTM", + n_plotting_samples=100, + trainer_kwargs=dict(accelerator="cpu"), + ), + ] diff --git a/pytorch_forecasting/tests/__init__.py b/pytorch_forecasting/tests/__init__.py new file mode 100644 index 000000000..6c2d26856 --- /dev/null +++ b/pytorch_forecasting/tests/__init__.py @@ -0,0 +1 @@ +"""PyTorch Forecasting test suite.""" diff --git a/pytorch_forecasting/tests/_config.py b/pytorch_forecasting/tests/_config.py new file mode 100644 index 000000000..dd9c2e889 --- /dev/null +++ b/pytorch_forecasting/tests/_config.py @@ -0,0 +1,13 @@ +"""Test configs.""" + +# list of str, names of estimators to exclude from testing +# WARNING: tests for these estimators will be skipped +EXCLUDE_ESTIMATORS = [ + "DummySkipped", + "ClassName", # exclude classes from extension templates +] + +# dictionary of lists of str, names of tests to exclude from testing +# keys are class names of estimators, values are lists of test names to exclude +# WARNING: tests with these names will be skipped +EXCLUDED_TESTS = {} diff --git a/pytorch_forecasting/tests/_conftest.py b/pytorch_forecasting/tests/_conftest.py new file mode 100644 index 000000000..e276446a6 --- /dev/null +++ b/pytorch_forecasting/tests/_conftest.py @@ -0,0 +1,262 @@ +import numpy as np +import pytest +import torch + +from pytorch_forecasting import TimeSeriesDataSet +from pytorch_forecasting.data import EncoderNormalizer, GroupNormalizer, NaNLabelEncoder +from pytorch_forecasting.data.examples import generate_ar_data, get_stallion_data + +torch.manual_seed(23) + + +@pytest.fixture(scope="session") +def gpus(): + if torch.cuda.is_available(): + return [0] + else: + return 0 + + +@pytest.fixture(scope="session") +def data_with_covariates(): + data = get_stallion_data() + data["month"] = data.date.dt.month.astype(str) + data["log_volume"] = np.log1p(data.volume) + data["weight"] = 1 + np.sqrt(data.volume) + + data["time_idx"] = data["date"].dt.year * 12 + data["date"].dt.month + data["time_idx"] -= data["time_idx"].min() + + # convert special days into strings + special_days = [ + "easter_day", + "good_friday", + "new_year", + "christmas", + "labor_day", + "independence_day", + "revolution_day_memorial", + "regional_games", + "fifa_u_17_world_cup", + "football_gold_cup", + "beer_capital", + "music_fest", + ] + data[special_days] = ( + data[special_days].apply(lambda x: x.map({0: "", 1: x.name})).astype("category") + ) + data = data.astype(dict(industry_volume=float)) + + # select data subset + data = data[lambda x: x.sku.isin(data.sku.unique()[:2])][ + lambda x: x.agency.isin(data.agency.unique()[:2]) + ] + + # default target + data["target"] = data["volume"].clip(1e-3, 1.0) + + return data + + +def make_dataloaders(data_with_covariates, **kwargs): + training_cutoff = "2016-09-01" + max_encoder_length = 4 + max_prediction_length = 3 + + kwargs.setdefault("target", "volume") + kwargs.setdefault("group_ids", ["agency", "sku"]) + kwargs.setdefault("add_relative_time_idx", True) + kwargs.setdefault("time_varying_unknown_reals", ["volume"]) + + training = TimeSeriesDataSet( + data_with_covariates[lambda x: x.date < training_cutoff].copy(), + time_idx="time_idx", + max_encoder_length=max_encoder_length, + max_prediction_length=max_prediction_length, + **kwargs, # fixture parametrization + ) + + validation = TimeSeriesDataSet.from_dataset( + training, + data_with_covariates.copy(), + min_prediction_idx=training.index.time.max() + 1, + ) + train_dataloader = training.to_dataloader(train=True, batch_size=2, num_workers=0) + val_dataloader = validation.to_dataloader(train=False, batch_size=2, num_workers=0) + test_dataloader = validation.to_dataloader(train=False, batch_size=1, num_workers=0) + + return dict(train=train_dataloader, val=val_dataloader, test=test_dataloader) + + +@pytest.fixture( + params=[ + dict(), + dict( + static_categoricals=["agency", "sku"], + static_reals=["avg_population_2017", "avg_yearly_household_income_2017"], + time_varying_known_categoricals=["special_days", "month"], + variable_groups=dict( + special_days=[ + "easter_day", + "good_friday", + "new_year", + "christmas", + "labor_day", + "independence_day", + "revolution_day_memorial", + "regional_games", + "fifa_u_17_world_cup", + "football_gold_cup", + "beer_capital", + "music_fest", + ] + ), + time_varying_known_reals=[ + "time_idx", + "price_regular", + "price_actual", + "discount", + "discount_in_percent", + ], + time_varying_unknown_categoricals=[], + time_varying_unknown_reals=[ + "volume", + "log_volume", + "industry_volume", + "soda_volume", + "avg_max_temp", + ], + constant_fill_strategy={"volume": 0}, + categorical_encoders={"sku": NaNLabelEncoder(add_nan=True)}, + ), + dict(static_categoricals=["agency", "sku"]), + dict(randomize_length=True, min_encoder_length=2), + dict(target_normalizer=EncoderNormalizer(), min_encoder_length=2), + dict(target_normalizer=GroupNormalizer(transformation="log1p")), + dict( + target_normalizer=GroupNormalizer( + groups=["agency", "sku"], transformation="softplus", center=False + ) + ), + dict(target="agency"), + # test multiple targets + dict(target=["industry_volume", "volume"]), + dict(target=["agency", "volume"]), + dict( + target=["agency", "volume"], min_encoder_length=1, min_prediction_length=1 + ), + dict(target=["agency", "volume"], weight="volume"), + # test weights + dict(target="volume", weight="volume"), + ], + scope="session", +) +def multiple_dataloaders_with_covariates(data_with_covariates, request): + return make_dataloaders(data_with_covariates, **request.param) + + +@pytest.fixture(scope="session") +def dataloaders_with_different_encoder_decoder_length(data_with_covariates): + return make_dataloaders( + data_with_covariates.copy(), + target="target", + time_varying_known_categoricals=["special_days", "month"], + variable_groups=dict( + special_days=[ + "easter_day", + "good_friday", + "new_year", + "christmas", + "labor_day", + "independence_day", + "revolution_day_memorial", + "regional_games", + "fifa_u_17_world_cup", + "football_gold_cup", + "beer_capital", + "music_fest", + ] + ), + time_varying_known_reals=[ + "time_idx", + "price_regular", + "price_actual", + "discount", + "discount_in_percent", + ], + time_varying_unknown_categoricals=[], + time_varying_unknown_reals=[ + "target", + "volume", + "log_volume", + "industry_volume", + "soda_volume", + "avg_max_temp", + ], + static_categoricals=["agency"], + add_relative_time_idx=False, + target_normalizer=GroupNormalizer(groups=["agency", "sku"], center=False), + ) + + +@pytest.fixture(scope="session") +def dataloaders_with_covariates(data_with_covariates): + return make_dataloaders( + data_with_covariates.copy(), + target="target", + time_varying_known_reals=["discount"], + time_varying_unknown_reals=["target"], + static_categoricals=["agency"], + add_relative_time_idx=False, + target_normalizer=GroupNormalizer(groups=["agency", "sku"], center=False), + ) + + +@pytest.fixture(scope="session") +def dataloaders_multi_target(data_with_covariates): + return make_dataloaders( + data_with_covariates.copy(), + time_varying_unknown_reals=["target", "discount"], + target=["target", "discount"], + add_relative_time_idx=False, + ) + + +@pytest.fixture(scope="session") +def dataloaders_fixed_window_without_covariates(): + data = generate_ar_data(seasonality=10.0, timesteps=50, n_series=2) + validation = data.series.iloc[:2] + + max_encoder_length = 30 + max_prediction_length = 10 + + training = TimeSeriesDataSet( + data[lambda x: ~x.series.isin(validation)], + time_idx="time_idx", + target="value", + categorical_encoders={"series": NaNLabelEncoder().fit(data.series)}, + group_ids=["series"], + static_categoricals=[], + max_encoder_length=max_encoder_length, + max_prediction_length=max_prediction_length, + time_varying_unknown_reals=["value"], + target_normalizer=EncoderNormalizer(), + ) + + validation = TimeSeriesDataSet.from_dataset( + training, + data[lambda x: x.series.isin(validation)], + stop_randomization=True, + ) + batch_size = 2 + train_dataloader = training.to_dataloader( + train=True, batch_size=batch_size, num_workers=0 + ) + val_dataloader = validation.to_dataloader( + train=False, batch_size=batch_size, num_workers=0 + ) + test_dataloader = validation.to_dataloader( + train=False, batch_size=batch_size, num_workers=0 + ) + + return dict(train=train_dataloader, val=val_dataloader, test=test_dataloader) diff --git a/pytorch_forecasting/tests/_data_scenarios.py b/pytorch_forecasting/tests/_data_scenarios.py new file mode 100644 index 000000000..062db97dd --- /dev/null +++ b/pytorch_forecasting/tests/_data_scenarios.py @@ -0,0 +1,261 @@ +import numpy as np +import pytest +import torch + +from pytorch_forecasting import TimeSeriesDataSet +from pytorch_forecasting.data import EncoderNormalizer, GroupNormalizer, NaNLabelEncoder +from pytorch_forecasting.data.examples import generate_ar_data, get_stallion_data + +torch.manual_seed(23) + + +@pytest.fixture(scope="session") +def gpus(): + if torch.cuda.is_available(): + return [0] + else: + return 0 + + +def data_with_covariates(): + data = get_stallion_data() + data["month"] = data.date.dt.month.astype(str) + data["log_volume"] = np.log1p(data.volume) + data["weight"] = 1 + np.sqrt(data.volume) + + data["time_idx"] = data["date"].dt.year * 12 + data["date"].dt.month + data["time_idx"] -= data["time_idx"].min() + + # convert special days into strings + special_days = [ + "easter_day", + "good_friday", + "new_year", + "christmas", + "labor_day", + "independence_day", + "revolution_day_memorial", + "regional_games", + "fifa_u_17_world_cup", + "football_gold_cup", + "beer_capital", + "music_fest", + ] + data[special_days] = ( + data[special_days].apply(lambda x: x.map({0: "", 1: x.name})).astype("category") + ) + data = data.astype(dict(industry_volume=float)) + + # select data subset + data = data[lambda x: x.sku.isin(data.sku.unique()[:2])][ + lambda x: x.agency.isin(data.agency.unique()[:2]) + ] + + # default target + data["target"] = data["volume"].clip(1e-3, 1.0) + + return data + + +def make_dataloaders(data_with_covariates, **kwargs): + training_cutoff = "2016-09-01" + max_encoder_length = 4 + max_prediction_length = 3 + + kwargs.setdefault("target", "volume") + kwargs.setdefault("group_ids", ["agency", "sku"]) + kwargs.setdefault("add_relative_time_idx", True) + kwargs.setdefault("time_varying_unknown_reals", ["volume"]) + + training = TimeSeriesDataSet( + data_with_covariates[lambda x: x.date < training_cutoff].copy(), + time_idx="time_idx", + max_encoder_length=max_encoder_length, + max_prediction_length=max_prediction_length, + **kwargs, # fixture parametrization + ) + + validation = TimeSeriesDataSet.from_dataset( + training, + data_with_covariates.copy(), + min_prediction_idx=training.index.time.max() + 1, + ) + train_dataloader = training.to_dataloader(train=True, batch_size=2, num_workers=0) + val_dataloader = validation.to_dataloader(train=False, batch_size=2, num_workers=0) + test_dataloader = validation.to_dataloader(train=False, batch_size=1, num_workers=0) + + return dict(train=train_dataloader, val=val_dataloader, test=test_dataloader) + + +@pytest.fixture( + params=[ + dict(), + dict( + static_categoricals=["agency", "sku"], + static_reals=["avg_population_2017", "avg_yearly_household_income_2017"], + time_varying_known_categoricals=["special_days", "month"], + variable_groups=dict( + special_days=[ + "easter_day", + "good_friday", + "new_year", + "christmas", + "labor_day", + "independence_day", + "revolution_day_memorial", + "regional_games", + "fifa_u_17_world_cup", + "football_gold_cup", + "beer_capital", + "music_fest", + ] + ), + time_varying_known_reals=[ + "time_idx", + "price_regular", + "price_actual", + "discount", + "discount_in_percent", + ], + time_varying_unknown_categoricals=[], + time_varying_unknown_reals=[ + "volume", + "log_volume", + "industry_volume", + "soda_volume", + "avg_max_temp", + ], + constant_fill_strategy={"volume": 0}, + categorical_encoders={"sku": NaNLabelEncoder(add_nan=True)}, + ), + dict(static_categoricals=["agency", "sku"]), + dict(randomize_length=True, min_encoder_length=2), + dict(target_normalizer=EncoderNormalizer(), min_encoder_length=2), + dict(target_normalizer=GroupNormalizer(transformation="log1p")), + dict( + target_normalizer=GroupNormalizer( + groups=["agency", "sku"], transformation="softplus", center=False + ) + ), + dict(target="agency"), + # test multiple targets + dict(target=["industry_volume", "volume"]), + dict(target=["agency", "volume"]), + dict( + target=["agency", "volume"], min_encoder_length=1, min_prediction_length=1 + ), + dict(target=["agency", "volume"], weight="volume"), + # test weights + dict(target="volume", weight="volume"), + ], + scope="session", +) +def multiple_dataloaders_with_covariates(data_with_covariates, request): + return make_dataloaders(data_with_covariates, **request.param) + + +@pytest.fixture(scope="session") +def dataloaders_with_different_encoder_decoder_length(data_with_covariates): + return make_dataloaders( + data_with_covariates.copy(), + target="target", + time_varying_known_categoricals=["special_days", "month"], + variable_groups=dict( + special_days=[ + "easter_day", + "good_friday", + "new_year", + "christmas", + "labor_day", + "independence_day", + "revolution_day_memorial", + "regional_games", + "fifa_u_17_world_cup", + "football_gold_cup", + "beer_capital", + "music_fest", + ] + ), + time_varying_known_reals=[ + "time_idx", + "price_regular", + "price_actual", + "discount", + "discount_in_percent", + ], + time_varying_unknown_categoricals=[], + time_varying_unknown_reals=[ + "target", + "volume", + "log_volume", + "industry_volume", + "soda_volume", + "avg_max_temp", + ], + static_categoricals=["agency"], + add_relative_time_idx=False, + target_normalizer=GroupNormalizer(groups=["agency", "sku"], center=False), + ) + + +@pytest.fixture(scope="session") +def dataloaders_with_covariates(data_with_covariates): + return make_dataloaders( + data_with_covariates.copy(), + target="target", + time_varying_known_reals=["discount"], + time_varying_unknown_reals=["target"], + static_categoricals=["agency"], + add_relative_time_idx=False, + target_normalizer=GroupNormalizer(groups=["agency", "sku"], center=False), + ) + + +@pytest.fixture(scope="session") +def dataloaders_multi_target(data_with_covariates): + return make_dataloaders( + data_with_covariates.copy(), + time_varying_unknown_reals=["target", "discount"], + target=["target", "discount"], + add_relative_time_idx=False, + ) + + +@pytest.fixture(scope="session") +def dataloaders_fixed_window_without_covariates(): + data = generate_ar_data(seasonality=10.0, timesteps=50, n_series=2) + validation = data.series.iloc[:2] + + max_encoder_length = 30 + max_prediction_length = 10 + + training = TimeSeriesDataSet( + data[lambda x: ~x.series.isin(validation)], + time_idx="time_idx", + target="value", + categorical_encoders={"series": NaNLabelEncoder().fit(data.series)}, + group_ids=["series"], + static_categoricals=[], + max_encoder_length=max_encoder_length, + max_prediction_length=max_prediction_length, + time_varying_unknown_reals=["value"], + target_normalizer=EncoderNormalizer(), + ) + + validation = TimeSeriesDataSet.from_dataset( + training, + data[lambda x: x.series.isin(validation)], + stop_randomization=True, + ) + batch_size = 2 + train_dataloader = training.to_dataloader( + train=True, batch_size=batch_size, num_workers=0 + ) + val_dataloader = validation.to_dataloader( + train=False, batch_size=batch_size, num_workers=0 + ) + test_dataloader = validation.to_dataloader( + train=False, batch_size=batch_size, num_workers=0 + ) + + return dict(train=train_dataloader, val=val_dataloader, test=test_dataloader) diff --git a/pytorch_forecasting/tests/test_all_estimators.py b/pytorch_forecasting/tests/test_all_estimators.py new file mode 100644 index 000000000..3e046cda1 --- /dev/null +++ b/pytorch_forecasting/tests/test_all_estimators.py @@ -0,0 +1,291 @@ +"""Automated tests based on the skbase test suite template.""" + +from inspect import isclass +import shutil + +import lightning.pytorch as pl +from lightning.pytorch.callbacks import EarlyStopping +from lightning.pytorch.loggers import TensorBoardLogger +from skbase.testing import BaseFixtureGenerator as _BaseFixtureGenerator + +from pytorch_forecasting._registry import all_objects +from pytorch_forecasting.tests._config import EXCLUDE_ESTIMATORS, EXCLUDED_TESTS +from pytorch_forecasting.tests._conftest import make_dataloaders + +# whether to test only estimators from modules that are changed w.r.t. main +# default is False, can be set to True by pytest --only_changed_modules True flag +ONLY_CHANGED_MODULES = False + + +class PackageConfig: + """Contains package config variables for test classes.""" + + # class variables which can be overridden by descendants + # ------------------------------------------------------ + + # package to search for objects + # expected type: str, package/module name, relative to python environment root + package_name = "pytorch_forecasting" + + # list of object types (class names) to exclude + # expected type: list of str, str are class names + exclude_objects = EXCLUDE_ESTIMATORS + + # list of tests to exclude + # expected type: dict of lists, key:str, value: List[str] + # keys are class names of estimators, values are lists of test names to exclude + excluded_tests = EXCLUDED_TESTS + + +class BaseFixtureGenerator(_BaseFixtureGenerator): + """Fixture generator for base testing functionality in sktime. + + Test classes inheriting from this and not overriding pytest_generate_tests + will have estimator and scenario fixtures parametrized out of the box. + + Descendants can override: + estimator_type_filter: str, class variable; None or scitype string + e.g., "forecaster", "transformer", "classifier", see BASE_CLASS_SCITYPE_LIST + which estimators are being retrieved and tested + fixture_sequence: list of str + sequence of fixture variable names in conditional fixture generation + _generate_[variable]: object methods, all (test_name: str, **kwargs) -> list + generating list of fixtures for fixture variable with name [variable] + to be used in test with name test_name + can optionally use values for fixtures earlier in fixture_sequence, + these must be input as kwargs in a call + is_excluded: static method (test_name: str, est: class) -> bool + whether test with name test_name should be excluded for estimator est + should be used only for encoding general rules, not individual skips + individual skips should go on the EXCLUDED_TESTS list in _config + requires _generate_object_class and _generate_object_instance as is + _excluded_scenario: static method (test_name: str, scenario) -> bool + whether scenario should be skipped in test with test_name test_name + requires _generate_estimator_scenario as is + + Fixtures parametrized + --------------------- + object_class: estimator inheriting from BaseObject + ranges over estimator classes not excluded by EXCLUDE_ESTIMATORS, EXCLUDED_TESTS + object_instance: instance of estimator inheriting from BaseObject + ranges over estimator classes not excluded by EXCLUDE_ESTIMATORS, EXCLUDED_TESTS + instances are generated by create_test_instance class method of object_class + trainer_kwargs: list of dict + ranges over dictionaries of kwargs for the trainer + """ + + # overrides object retrieval in scikit-base + def _all_objects(self): + """Retrieve list of all object classes of type self.object_type_filter. + + If self.object_type_filter is None, retrieve all objects. + If class, retrieve all classes inheriting from self.object_type_filter. + Otherwise (assumed str or list of str), retrieve all classes with tags + object_type in self.object_type_filter. + """ + filter = getattr(self, "object_type_filter", None) + + if isclass(filter): + object_types = filter.get_class_tag("object_type", None) + else: + object_types = filter + + obj_list = all_objects( + object_types=object_types, + return_names=False, + exclude_objects=self.exclude_objects, + ) + + if isclass(filter): + obj_list = [obj for obj in obj_list if issubclass(obj, filter)] + + # run_test_for_class selects the estimators to run + # based on whether they have changed, and whether they have all dependencies + # internally, uses the ONLY_CHANGED_MODULES flag, + # and checks the python env against python_dependencies tag + # obj_list = [obj for obj in obj_list if run_test_for_class(obj)] + + return obj_list + + # which sequence the conditional fixtures are generated in + fixture_sequence = [ + "object_metadata", + "object_class", + "object_instance", + "trainer_kwargs", + ] + + def _generate_object_metadata(self, test_name, **kwargs): + """Return object class fixtures. + + Fixtures parametrized + --------------------- + object_class: object inheriting from BaseObject + ranges over all object classes not excluded by self.excluded_tests + """ + object_classes_to_test = [ + est for est in self._all_objects() if not self.is_excluded(test_name, est) + ] + object_names = [est.name() for est in object_classes_to_test] + + return object_classes_to_test, object_names + + def _generate_object_class(self, test_name, **kwargs): + """Return object class fixtures. + + Fixtures parametrized + --------------------- + object_class: object inheriting from BaseObject + ranges over all object classes not excluded by self.excluded_tests + """ + all_metadata = self._all_objects() + all_cls = [est.get_model_cls() for est in all_metadata] + object_classes_to_test = [ + est for est in all_cls if not self.is_excluded(test_name, est) + ] + object_names = [est.__name__ for est in object_classes_to_test] + + return object_classes_to_test, object_names + + def _generate_trainer_kwargs(self, test_name, **kwargs): + """Return kwargs for the trainer. + + Fixtures parametrized + --------------------- + trainer_kwargs: dict + ranges over all kwargs for the trainer + """ + if "object_metadata" in kwargs.keys(): + obj_meta = kwargs["object_metadata"] + else: + return [] + + all_train_kwargs = obj_meta.get_test_train_params() + rg = range(len(all_train_kwargs)) + train_kwargs_names = [str(i) for i in rg] + + return all_train_kwargs, train_kwargs_names + + +def _integration( + estimator_cls, + data_with_covariates, + tmp_path, + data_loader_kwargs={}, + clip_target: bool = False, + trainer_kwargs=None, + **kwargs, +): + data_with_covariates = data_with_covariates.copy() + if clip_target: + data_with_covariates["target"] = data_with_covariates["volume"].clip(1e-3, 1.0) + else: + data_with_covariates["target"] = data_with_covariates["volume"] + data_loader_default_kwargs = dict( + target="target", + time_varying_known_reals=["price_actual"], + time_varying_unknown_reals=["target"], + static_categoricals=["agency"], + add_relative_time_idx=True, + ) + data_loader_default_kwargs.update(data_loader_kwargs) + dataloaders_with_covariates = make_dataloaders( + data_with_covariates, **data_loader_default_kwargs + ) + + train_dataloader = dataloaders_with_covariates["train"] + val_dataloader = dataloaders_with_covariates["val"] + test_dataloader = dataloaders_with_covariates["test"] + + early_stop_callback = EarlyStopping( + monitor="val_loss", min_delta=1e-4, patience=1, verbose=False, mode="min" + ) + + logger = TensorBoardLogger(tmp_path) + if trainer_kwargs is None: + trainer_kwargs = {} + trainer = pl.Trainer( + max_epochs=3, + gradient_clip_val=0.1, + callbacks=[early_stop_callback], + enable_checkpointing=True, + default_root_dir=tmp_path, + limit_train_batches=2, + limit_val_batches=2, + limit_test_batches=2, + logger=logger, + **trainer_kwargs, + ) + + net = estimator_cls.from_dataset( + train_dataloader.dataset, + hidden_size=5, + learning_rate=0.01, + log_gradient_flow=True, + log_interval=1000, + **kwargs, + ) + net.size() + try: + trainer.fit( + net, + train_dataloaders=train_dataloader, + val_dataloaders=val_dataloader, + ) + test_outputs = trainer.test(net, dataloaders=test_dataloader) + assert len(test_outputs) > 0 + # check loading + net = estimator_cls.load_from_checkpoint( + trainer.checkpoint_callback.best_model_path + ) + + # check prediction + net.predict( + val_dataloader, + fast_dev_run=True, + return_index=True, + return_decoder_lengths=True, + trainer_kwargs=trainer_kwargs, + ) + finally: + shutil.rmtree(tmp_path, ignore_errors=True) + + net.predict( + val_dataloader, + fast_dev_run=True, + return_index=True, + return_decoder_lengths=True, + trainer_kwargs=trainer_kwargs, + ) + + +class TestAllPtForecasters(PackageConfig, BaseFixtureGenerator): + """Generic tests for all objects in the mini package.""" + + def test_doctest_examples(self, object_class): + """Runs doctests for estimator class.""" + from skbase.utils.doctest_run import run_doctest + + run_doctest(object_class, name=f"class {object_class.__name__}") + + def test_integration( + self, + object_metadata, + trainer_kwargs, + tmp_path, + ): + """Fails for certain, for testing.""" + from pytorch_forecasting.metrics import NegativeBinomialDistributionLoss + from pytorch_forecasting.tests._data_scenarios import data_with_covariates + + data_with_covariates = data_with_covariates() + + object_class = object_metadata.get_model_cls() + + if "loss" in trainer_kwargs and isinstance( + trainer_kwargs["loss"], NegativeBinomialDistributionLoss + ): + data_with_covariates = data_with_covariates.assign( + volume=lambda x: x.volume.round() + ) + _integration(object_class, data_with_covariates, tmp_path, **trainer_kwargs)