diff --git a/notebooks/structural_components_dataclass.ipynb b/notebooks/structural_components_dataclass.ipynb new file mode 100644 index 000000000..611d76767 --- /dev/null +++ b/notebooks/structural_components_dataclass.ipynb @@ -0,0 +1,583 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "ab70a522", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "OMP: Info #276: omp_set_nested routine deprecated, please use omp_set_max_active_levels instead.\n" + ] + } + ], + "source": [ + "from pymc_extras.statespace.models.structural import (\n", + " RegressionComponent,\n", + " RegressionComponentDataClass,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "17021aa3", + "metadata": {}, + "outputs": [], + "source": [ + "# Current way\n", + "reg = RegressionComponent(\n", + " name=\"regression\",\n", + " state_names=[\"a\", \"b\"],\n", + " observed_state_names=[\"y\"],\n", + " innovations=True,\n", + " share_states=True,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "219eb5da", + "metadata": {}, + "outputs": [], + "source": [ + "# Proposed way\n", + "reg_dataclass = RegressionComponentDataClass(\n", + " name=\"regression\",\n", + " state_names=[\"a\", \"b\"],\n", + " observed_state_names=[\"y\"],\n", + " innovations=True,\n", + " share_states=True,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "7ff76653", + "metadata": {}, + "source": [ + "# Reminder of current implementation" + ] + }, + { + "cell_type": "markdown", + "id": "c05f86f6", + "metadata": {}, + "source": [ + "Currently state names are a list of string that only contain the names of the states" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "7e37e574", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "['a[regression_shared]', 'b[regression_shared]']" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "reg.state_names" + ] + }, + { + "cell_type": "markdown", + "id": "0d484b59", + "metadata": {}, + "source": [ + "In the proposed dataclass implementation each state is a `StateProperty` and all the states are `StateProporties` dataclasses." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "dee62a66", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "states: ['a[regression_shared]', 'b[regression_shared]']\n", + "observed: [True, True]\n" + ] + } + ], + "source": [ + "print(reg_dataclass.state_names)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "cebd72af", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "name: a[regression_shared]\n", + "observed: True\n", + "shared: True\n" + ] + } + ], + "source": [ + "print(reg_dataclass.state_names[\"a[regression_shared]\"]) # state name is the key" + ] + }, + { + "cell_type": "markdown", + "id": "1b8690a1", + "metadata": {}, + "source": [ + "Similarly with shock names we now have a shock_info that is a `ShockProperties` dataclass composed of `ShockProperty` dataclasses" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "1320adac", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "['a_shared', 'b_shared']" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "reg.shock_names" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "6c905946", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "shocks: ['a_shared', 'b_shared']\n" + ] + } + ], + "source": [ + "print(reg_dataclass.shock_info)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6ff60922", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "name: a_shared\n" + ] + } + ], + "source": [ + "print(reg_dataclass.shock_info[\"a_shared\"])" + ] + }, + { + "cell_type": "markdown", + "id": "bdbe8f7c", + "metadata": {}, + "source": [ + "This pattern continues for data, parameters, and coords as shown below" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "ead54287", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "['data_regression']" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "reg.data_names" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "ba784a4a", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'data_regression': {'shape': (None, 2), 'dims': ('time', 'state_regression')}}" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "reg.data_info" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "521382b9", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "data: ['data_regression']\n", + "needs exogenous data: True\n" + ] + } + ], + "source": [ + "print(reg_dataclass.data_info)" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "85b7e774", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "name: data_regression\n", + "shape: (None, 2)\n", + "dims: ('time', 'state_regression')\n", + "is_exogenous: True\n" + ] + } + ], + "source": [ + "print(reg_dataclass.data_info[\"data_regression\"])" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "e1ed9d7a", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'beta_regression': {'shape': (2,),\n", + " 'constraints': None,\n", + " 'dims': ('state_regression',)},\n", + " 'sigma_beta_regression': {'shape': (2,),\n", + " 'constraints': 'Positive',\n", + " 'dims': ('state_regression',)}}" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "reg.param_info" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "8d194fe2", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "['beta_regression', 'sigma_beta_regression']" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "reg.param_names" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "7fccad81", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'beta_regression': ('state_regression',),\n", + " 'sigma_beta_regression': ('state_regression',)}" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "reg.param_dims" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "9787c813", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "parameters: ['beta_regression', 'sigma_beta_regression']\n" + ] + } + ], + "source": [ + "print(reg_dataclass.param_info)" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "914e97da", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "name: beta_regression\n", + "shape: (2,)\n", + "dims: ('state_regression',)\n" + ] + } + ], + "source": [ + "print(reg_dataclass.param_info[\"beta_regression\"])" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "98875fd1", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "name: sigma_beta_regression\n", + "shape: (2,)\n", + "dims: ('state_regression',)\n", + "constraints: Positive\n" + ] + } + ], + "source": [ + "print(reg_dataclass.param_info[\"sigma_beta_regression\"])" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "a195cec5", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'state_regression': ['a', 'b'], 'endog_regression': ['y']}" + ] + }, + "execution_count": 20, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "reg.coords" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "62622777", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "coordinates:\n", + " dimension: state_regression\n", + " labels: ['a', 'b']\n", + "\n", + " dimension: endog_regression\n", + " labels: ['y']\n", + "\n" + ] + } + ], + "source": [ + "print(reg_dataclass.coords)" + ] + }, + { + "cell_type": "markdown", + "id": "a79b845c", + "metadata": {}, + "source": [ + "# Mapping between items" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "9484c709", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "parameters: ['beta_regression', 'sigma_beta_regression']\n" + ] + } + ], + "source": [ + "# Important to be able to map between parameters -> dimensions -> dimension labels\n", + "print(reg_dataclass.param_info)" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": "85573fa2", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "name: beta_regression\n", + "shape: (2,)\n", + "dims: ('state_regression',)\n" + ] + } + ], + "source": [ + "print(reg_dataclass.param_info[\"beta_regression\"]) # Key is parameter name" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "id": "32f56fd4", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "dimension: state_regression\n", + "labels: ['a', 'b']\n" + ] + } + ], + "source": [ + "# dimension for parameter beta_regression is state_regression. Let's map to dimension labels\n", + "print(\n", + " reg_dataclass.coords[\n", + " reg_dataclass.param_info[\"beta_regression\"].dims[0] # Key is dimension name\n", + " ]\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "id": "35ae00a6", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "dimension: state_regression\n", + "labels: ['a', 'b']\n" + ] + } + ], + "source": [ + "# Equivalently\n", + "print(reg_dataclass.coords[\"state_regression\"])" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "pymc-extras", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.13.9" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/pymc_extras/statespace/core/properties.py b/pymc_extras/statespace/core/properties.py new file mode 100644 index 000000000..56e875f94 --- /dev/null +++ b/pymc_extras/statespace/core/properties.py @@ -0,0 +1,261 @@ +from collections.abc import Iterator +from dataclasses import dataclass, fields +from typing import Generic, Self, TypeVar + +from pymc_extras.statespace.core import PyMCStateSpace +from pymc_extras.statespace.utils.constants import ( + ALL_STATE_AUX_DIM, + ALL_STATE_DIM, + OBS_STATE_AUX_DIM, + OBS_STATE_DIM, + SHOCK_AUX_DIM, + SHOCK_DIM, +) + + +@dataclass(frozen=True) +class Property: + def __str__(self) -> str: + return "\n".join(f"{f.name}: {getattr(self, f.name)}" for f in fields(self)) + + +T = TypeVar("T", bound=Property) + + +@dataclass(frozen=True) +class Info(Generic[T]): + items: tuple[T, ...] + key_field: str = "name" + _index: dict[str, T] | None = None + + def __post_init__(self): + index = {} + missing_attr = [] + for item in self.items: + if not hasattr(item, self.key_field): + missing_attr.append(item) + continue + key = getattr(item, self.key_field) + if key in index: + raise ValueError(f"Duplicate {self.key_field} '{key}' detected.") + index[key] = item + if missing_attr: + raise AttributeError(f"Items missing attribute '{self.key_field}': {missing_attr}") + object.__setattr__(self, "_index", index) + + def _key(self, item: T) -> str: + return getattr(item, self.key_field) + + def get(self, key: str, default=None) -> T | None: + return self._index.get(key, default) + + def __getitem__(self, key: str) -> T: + try: + return self._index[key] + except KeyError as e: + available = ", ".join(self._index.keys()) + raise KeyError(f"No {self.key_field} '{key}'. Available: [{available}]") from e + + def __contains__(self, key: object) -> bool: + return key in self._index + + def __iter__(self) -> Iterator[str]: + return iter(self.items) + + def __len__(self) -> int: + return len(self.items) + + def __str__(self) -> str: + return f"{self.key_field}s: {list(self._index.keys())}" + + @property + def names(self) -> tuple[str, ...]: + return tuple(self._index.keys()) + + +@dataclass(frozen=True) +class Parameter(Property): + name: str + shape: tuple[int, ...] + dims: tuple[str, ...] + constraints: str | None = None + + +@dataclass(frozen=True) +class ParameterInfo(Info[Parameter]): + def __init__(self, parameters: list[Parameter]): + super().__init__(items=tuple(parameters), key_field="name") + + def add(self, parameter: Parameter) -> "ParameterInfo": + # return a new ParameterInfo with parameter appended + return ParameterInfo(parameters=[*list(self.items), parameter]) + + def merge(self, other: "ParameterInfo") -> "ParameterInfo": + """Combine parameters from two ParameterInfo objects.""" + if not isinstance(other, ParameterInfo): + raise TypeError(f"Cannot merge {type(other).__name__} with ParameterInfo") + + overlapping = set(self.names) & set(other.names) + if overlapping: + raise ValueError(f"Duplicate parameter names found: {overlapping}") + + return ParameterInfo(parameters=list(self.items) + list(other.items)) + + +@dataclass(frozen=True) +class Data(Property): + name: str + shape: tuple[int, ...] + dims: tuple[str, ...] + is_exogenous: bool + + +@dataclass(frozen=True) +class DataInfo(Info[Data]): + def __init__(self, data: list[Data]): + super().__init__(items=tuple(data), key_field="name") + + @property + def needs_exogenous_data(self) -> bool: + return any(d.is_exogenous for d in self.items) + + def __str__(self) -> str: + return f"data: {[d.name for d in self.items]}\nneeds exogenous data: {self.needs_exogenous_data}" + + def add(self, data: Data) -> "DataInfo": + # return a new DataInfo with data appended + return DataInfo(data=[*list(self.items), data]) + + def merge(self, other: "DataInfo") -> "DataInfo": + """Combine data from two DataInfo objects.""" + if not isinstance(other, DataInfo): + raise TypeError(f"Cannot merge {type(other).__name__} with DataInfo") + + overlapping = set(self.names) & set(other.names) + if overlapping: + raise ValueError(f"Duplicate data names found: {overlapping}") + + return DataInfo(data=list(self.items) + list(other.items)) + + +@dataclass(frozen=True) +class Coord(Property): + dimension: str + labels: tuple[str, ...] + + +@dataclass(frozen=True) +class CoordInfo(Info[Coord]): + def __init__(self, coords: list[Coord]): + super().__init__(items=tuple(coords), key_field="dimension") + + def __str__(self) -> str: + base = "coordinates:" + for coord in self.items: + coord_str = str(coord) + indented = "\n".join(" " + line for line in coord_str.splitlines()) + base += "\n" + indented + "\n" + return base + + @classmethod + def default_coords_from_model( + cls, model: PyMCStateSpace + ) -> ( + Self + ): # TODO: Need to figure out how to include Component type was causing circular import issues + states = tuple(model.state_names) + obs_states = tuple(model.observed_state_names) + shocks = tuple(model.shock_names) + + dim_to_labels = ( + (ALL_STATE_DIM, states), + (ALL_STATE_AUX_DIM, states), + (OBS_STATE_DIM, obs_states), + (OBS_STATE_AUX_DIM, obs_states), + (SHOCK_DIM, shocks), + (SHOCK_AUX_DIM, shocks), + ) + + coords = [Coord(dimension=dim, labels=labels) for dim, labels in dim_to_labels] + return cls(coords) + + def to_dict(self): + return {coord.dimension: coord.labels for coord in self.items if len(coord.labels) > 0} + + def add(self, coord: Coord) -> "CoordInfo": + # return a new CoordInfo with data appended + return CoordInfo(coords=[*list(self.items), coord]) + + def merge(self, other: "CoordInfo") -> "CoordInfo": + """Combine data from two CoordInfo objects.""" + if not isinstance(other, CoordInfo): + raise TypeError(f"Cannot merge {type(other).__name__} with CoordInfo") + + overlapping = set(self.names) & set(other.names) + if overlapping: + raise ValueError(f"Duplicate coord names found: {overlapping}") + + return CoordInfo(coords=list(self.items) + list(other.items)) + + +@dataclass(frozen=True) +class State(Property): + name: str + observed: bool + shared: bool + + +@dataclass(frozen=True) +class StateInfo(Info[State]): + def __init__(self, states: list[State]): + super().__init__(items=tuple(states), key_field="name") + + def __str__(self) -> str: + return ( + f"states: {[s.name for s in self.items]}\nobserved: {[s.observed for s in self.items]}" + ) + + @property + def observed_states(self) -> tuple[State, ...]: + return tuple(s for s in self.items if s.observed) + + def add(self, state: State) -> "StateInfo": + # return a new StateInfo with state appended + return StateInfo(states=[*list(self.items), state]) + + def merge(self, other: "StateInfo") -> "StateInfo": + """Combine states from two StateInfo objects.""" + if not isinstance(other, StateInfo): + raise TypeError(f"Cannot merge {type(other).__name__} with StateInfo") + + overlapping = set(self.names) & set(other.names) + if overlapping: + raise ValueError(f"Duplicate state names found: {overlapping}") + + return StateInfo(states=list(self.items) + list(other.items)) + + +@dataclass(frozen=True) +class Shock(Property): + name: str + + +@dataclass(frozen=True) +class ShockInfo(Info[Shock]): + def __init__(self, shocks: list[Shock]): + super().__init__(items=tuple(shocks), key_field="name") + + def add(self, shock: Shock) -> "ShockInfo": + # return a new ShockInfo with shock appended + return ShockInfo(shocks=[*list(self.items), shock]) + + def merge(self, other: "ShockInfo") -> "ShockInfo": + """Combine shocks from two ShockInfo objects.""" + if not isinstance(other, ShockInfo): + raise TypeError(f"Cannot merge {type(other).__name__} with ShockInfo") + + overlapping = set(self.names) & set(other.names) + if overlapping: + raise ValueError(f"Duplicate shock names found: {overlapping}") + + return ShockInfo(shocks=list(self.items) + list(other.items)) diff --git a/pymc_extras/statespace/models/structural/__init__.py b/pymc_extras/statespace/models/structural/__init__.py index f0bfb2f0a..8ef35c969 100644 --- a/pymc_extras/statespace/models/structural/__init__.py +++ b/pymc_extras/statespace/models/structural/__init__.py @@ -5,6 +5,9 @@ from pymc_extras.statespace.models.structural.components.level_trend import LevelTrendComponent from pymc_extras.statespace.models.structural.components.measurement_error import MeasurementError from pymc_extras.statespace.models.structural.components.regression import RegressionComponent +from pymc_extras.statespace.models.structural.components.regression_dataclass import ( + RegressionComponent as RegressionComponentDataClass, +) from pymc_extras.statespace.models.structural.components.seasonality import ( FrequencySeasonality, TimeSeasonality, @@ -17,5 +20,6 @@ "LevelTrendComponent", "MeasurementError", "RegressionComponent", + "RegressionComponentDataClass", "TimeSeasonality", ] diff --git a/pymc_extras/statespace/models/structural/components/regression.py b/pymc_extras/statespace/models/structural/components/regression.py index 5620b1ea7..1444b902b 100644 --- a/pymc_extras/statespace/models/structural/components/regression.py +++ b/pymc_extras/statespace/models/structural/components/regression.py @@ -2,6 +2,18 @@ from pytensor import tensor as pt +from pymc_extras.statespace.core.properties import ( + Coord, + CoordInfo, + Data, + DataInfo, + Parameter, + ParameterInfo, + Shock, + ShockInfo, + State, + StateInfo, +) from pymc_extras.statespace.models.structural.core import Component from pymc_extras.statespace.utils.constants import TIME_DIM @@ -194,64 +206,110 @@ def make_symbolic_graph(self) -> None: row_idx, col_idx = np.diag_indices(self.k_states) self.ssm["state_cov", row_idx, col_idx] = sigma_beta.ravel() ** 2 - def populate_component_properties(self) -> None: + def _set_parameters(self) -> None: k_endog = self.k_endog k_endog_effective = 1 if self.share_states else k_endog + k_states = self.k_states // k_endog_effective + + beta_param_name = f"beta_{self.name}" + beta_param_shape = (k_endog_effective, k_states) if k_endog_effective > 1 else (k_states,) + beta_param_dims = ( + (f"endog_{self.name}", f"state_{self.name}") + if k_endog_effective > 1 + else (f"state_{self.name}",) + ) + + beta_param_constraints = None + beta_parameter = Parameter( + name=beta_param_name, + shape=beta_param_shape, + dims=beta_param_dims, + constraints=beta_param_constraints, + ) + if self.innovations: + sigma_param_name = f"sigma_beta_{self.name}" + sigma_param_dims = (f"state_{self.name}",) + sigma_param_shape = (k_states,) + sigma_param_constraints = "Positive" + + sigma_parameter = Parameter( + name=sigma_param_name, + shape=sigma_param_shape, + dims=sigma_param_dims, + constraints=sigma_param_constraints, + ) + + self.param_info = ParameterInfo(parameters=[beta_parameter, sigma_parameter]) + self.param_names = self.param_info.names + else: + self.param_info = ParameterInfo(parameters=[beta_parameter]) + self.param_names = self.param_info.names + + def _set_data(self) -> None: + k_endog = self.k_endog + k_endog_effective = 1 if self.share_states else k_endog k_states = self.k_states // k_endog_effective + data_name = f"data_{self.name}" + data_shape = (None, k_states) + data_dims = (TIME_DIM, f"state_{self.name}") + + data_prop = Data(name=data_name, shape=data_shape, dims=data_dims, is_exogenous=True) + self.data_info = DataInfo(data=[data_prop]) + self.data_names = self.data_info.names + + def _set_shocks(self) -> None: if self.share_states: - self.shock_names = [f"{state_name}_shared" for state_name in self.state_names] + shock_names = [f"{state_name}_shared" for state_name in self.state_names] else: - self.shock_names = self.state_names + shock_names = self.state_names - self.param_names = [f"beta_{self.name}"] - self.data_names = [f"data_{self.name}"] - self.param_dims = { - f"beta_{self.name}": (f"endog_{self.name}", f"state_{self.name}") - if k_endog_effective > 1 - else (f"state_{self.name}",) - } + self.shock_info = ShockInfo(shocks=[Shock(name=name) for name in shock_names]) + self.shock_names = self.shock_info.names - base_names = self.state_names + def _set_states(self) -> None: + self.base_names = self.state_names if self.share_states: - self.state_names = [f"{name}[{self.name}_shared]" for name in base_names] + state_names = [f"{name}[{self.name}_shared]" for name in self.base_names] + self.state_info = StateInfo( + states=[State(name=name, observed=True, shared=True) for name in state_names] + ) + self.state_names = self.state_info.names else: - self.state_names = [ + state_names = [ f"{name}[{obs_name}]" for obs_name in self.observed_state_names - for name in base_names + for name in self.base_names ] + self.state_info = StateInfo( + states=[State(name=name, observed=True, shared=False) for name in state_names] + ) + self.state_names = self.state_info.names - self.param_info = { - f"beta_{self.name}": { - "shape": (k_endog_effective, k_states) if k_endog_effective > 1 else (k_states,), - "constraints": None, - "dims": (f"endog_{self.name}", f"state_{self.name}") - if k_endog_effective > 1 - else (f"state_{self.name}",), - }, - } - - self.data_info = { - f"data_{self.name}": { - "shape": (None, k_states), - "dims": (TIME_DIM, f"state_{self.name}"), - }, - } - self.coords = { - f"state_{self.name}": base_names, - f"endog_{self.name}": self.observed_state_names, - } + def _set_coords(self) -> None: + regression_state_coord = Coord( + dimension=f"state_{self.name}", labels=[state for state in self.base_names] + ) + endogenous_state_coord = Coord( + dimension=f"endog_{self.name}", labels=[state for state in self.observed_state_names] + ) - if self.innovations: - self.param_names += [f"sigma_beta_{self.name}"] - self.param_dims[f"sigma_beta_{self.name}"] = (f"state_{self.name}",) - self.param_info[f"sigma_beta_{self.name}"] = { - "shape": (k_states,), - "constraints": "Positive", - "dims": (f"state_{self.name}",) - if k_endog_effective == 1 - else (f"endog_{self.name}", f"state_{self.name}"), - } + self.coords = CoordInfo(coords=[regression_state_coord, endogenous_state_coord]) + + def populate_component_properties(self) -> None: + # Set parameter info + self._set_parameters() + + # Set data info + self._set_data() + + # Set shock info + self._set_shocks() + + # Set states info + self._set_states() + + # Set coordinates info + self._set_coords() diff --git a/pymc_extras/statespace/models/structural/components/regression_dataclass.py b/pymc_extras/statespace/models/structural/components/regression_dataclass.py new file mode 100644 index 000000000..607b2469b --- /dev/null +++ b/pymc_extras/statespace/models/structural/components/regression_dataclass.py @@ -0,0 +1,539 @@ +from dataclasses import dataclass, field + +import numpy as np + +from pytensor import tensor as pt + +from pymc_extras.statespace.models.structural.core import Component +from pymc_extras.statespace.utils.constants import TIME_DIM + + +@dataclass +class ParameterProperty: + name: str + shape: tuple[int, ...] + dims: tuple[str, ...] + constraints: str | None = None + + def __str__(self): + base = f"name: {self.name}\nshape: {self.shape}\ndims: {self.dims}" + if self.constraints: + return base + f"\nconstraints: {self.constraints}" + return base + + +@dataclass +class ParameterProperties: + parameters: list[ParameterProperty] + + def get_parameter(self, name: str) -> ParameterProperty | None: + return next((p for p in self.parameters if p.name == name), None) + + def __getitem__(self, name: str) -> ParameterProperty: + result = next((p for p in self.parameters if p.name == name), None) + if result is None: + raise KeyError(f"No parameter named '{name}'") + return result + + def __contains__(self, name: str) -> bool: + return any(p.name == name for p in self.parameters) + + def __str__(self): + base = f"parameters: {[parameter.name for parameter in self.parameters]}" + return base + + +@dataclass +class DataProperty: + name: str + shape: tuple[int, ...] + dims: tuple[str, ...] + is_exogenous: bool + + def __str__(self): + base = f"name: {self.name}\nshape: {self.shape}\ndims: {self.dims}\nis_exogenous: {self.is_exogenous}" + return base + + +@dataclass +class DataProperties: + data: list[DataProperty] + needs_exogenous_data: bool = field(default=False, init=False) + + def __post_init__(self): + for d in self.data: + if d.is_exogenous: + self.needs_exogenous_data = True + + def get_data(self, name: str) -> DataProperty | None: + return next((d for d in self.data if d.name == name), None) + + def __getitem__(self, name: str) -> DataProperty: + result = next((d for d in self.data if d.name == name), None) + if result is None: + raise KeyError(f"No data named '{name}'") + return result + + def __contains__(self, name: str) -> bool: + return any(d.name == name for d in self.data) + + def __str__(self): + base = f"data: {[d.name for d in self.data]}\nneeds exogenous data: {self.needs_exogenous_data}" + return base + + +@dataclass +class CoordProperty: + dimension: str + labels: list[str] + + def __str__(self): + base = f"dimension: {self.dimension}\nlabels: {self.labels}" + return base + + +@dataclass +class CoordProperties: + coords: list[CoordProperty] + + def get_coord(self, dimension: str) -> CoordProperty | None: + return next((c for c in self.coords if c.dimension == dimension), None) + + def __getitem__(self, dimension: str) -> CoordProperty: + result = next((c for c in self.coords if c.dimension == dimension), None) + if result is None: + raise KeyError(f"No coordinate named '{dimension}'") + return result + + def __contains__(self, dimension: str) -> bool: + return any(c.dimension == dimension for c in self.coords) + + def __str__(self): + base = "coordinates:" + for coord in self.coords: + coord_str = str(coord) + indented = "\n".join(" " + line for line in coord_str.splitlines()) + base += "\n" + indented + "\n" + return base + + +@dataclass +class StateProperty: + name: str + observed: bool + shared: bool + + def __str__(self): + base = f"name: {self.name}\nobserved: {self.observed}\nshared: {self.shared}" + return base + + +@dataclass +class StateProperties: + states: list[StateProperty] + + def get_state(self, name: str) -> StateProperty | None: + return next((s for s in self.states if s.name == name), None) + + def __getitem__(self, name: str) -> StateProperty: + result = next((s for s in self.states if s.name == name), None) + if result is None: + raise KeyError(f"No state named '{name}'") + return result + + def __contains__(self, name: str) -> bool: + return any(s.name == name for s in self.states) + + def __str__(self): + base = f"states: {[state.name for state in self.states]}\nobserved: {[state.observed for state in self.states]}" + return base + + +@dataclass +class ShockProperty: + name: str + + def __str__(self): + base = f"name: {self.name}" + return base + + +@dataclass +class ShockProperties: + shocks: list[ShockProperty] + + def get_state(self, name: str) -> ShockProperty | None: + return next((shock for shock in self.shocks if shock.name == name), None) + + def __getitem__(self, name: str) -> ShockProperty: + result = next((shock for shock in self.shocks if shock.name == name), None) + if result is None: + raise KeyError(f"No shock named '{name}'") + return result + + def __contains__(self, name: str) -> bool: + return any(shock.name == name for shock in self.shocks) + + def __str__(self): + base = f"shocks: {[shock.name for shock in self.shocks]}" + return base + + +class RegressionComponent(Component): + r""" + Regression component for exogenous variables in a structural time series model + + Parameters + ---------- + k_exog : int | None, default None + Number of exogenous variables to include in the regression. Must be specified if + state_names is not provided. + + name : str | None, default "regression" + A name for this regression component. Used to label dimensions and coordinates. + + state_names : list[str] | None, default None + List of strings for regression coefficient labels. If provided, must be of length + k_exog. If None and k_exog is provided, coefficients will be named + "{name}_1, {name}_2, ...". + + observed_state_names : list[str] | None, default None + List of strings for observed state labels. If None, defaults to ["data"]. + + innovations : bool, default False + Whether to include stochastic innovations in the regression coefficients, + allowing them to vary over time. If True, coefficients follow a random walk. + + share_states: bool, default False + Whether latent states are shared across the observed states. If True, there will be only one set of latent + states, which are observed by all observed states. If False, each observed state has its own set of + latent states. + + Notes + ----- + This component implements regression with exogenous variables in a structural time series + model. The regression component can be expressed as: + + .. math:: + y_t = \beta_t^T x_t + \epsilon_t + + Where :math:`y_t` is the dependent variable, :math:`x_t` is the vector of exogenous + variables, :math:`\beta_t` is the vector of regression coefficients, and :math:`\epsilon_t` + is the error term. + + When ``innovations=False`` (default), the coefficients are constant over time: + :math:`\beta_t = \beta_0` for all t. + + When ``innovations=True``, the coefficients follow a random walk: + :math:`\beta_{t+1} = \beta_t + \eta_t`, where :math:`\eta_t \sim N(0, \Sigma_\beta)`. + + The component supports both univariate and multivariate regression. In the multivariate + case, separate coefficients are estimated for each endogenous variable (i.e time series). + + Examples + -------- + Simple regression with constant coefficients: + + .. code:: python + + from pymc_extras.statespace import structural as st + import pymc as pm + import pytensor.tensor as pt + + trend = st.LevelTrendComponent(order=1, innovations_order=1) + regression = st.RegressionComponent(k_exog=2, state_names=['intercept', 'slope']) + ss_mod = (trend + regression).build() + + with pm.Model(coords=ss_mod.coords) as model: + # Prior for regression coefficients + betas = pm.Normal('betas', dims=ss_mod.param_dims['beta_regression']) + + # Prior for trend innovations + sigma_trend = pm.Exponential('sigma_trend', 1) + + ss_mod.build_statespace_graph(data) + idata = pm.sample() + + Multivariate regression with time-varying coefficients: + - There are 2 exogenous variables (price and income effects) + - There are 2 endogenous variables (sales and revenue) + - The regression coefficients are allowed to vary over time (`innovations=True`) + + .. code:: python + + regression = st.RegressionComponent( + k_exog=2, + state_names=['price_effect', 'income_effect'], + observed_state_names=['sales', 'revenue'], + innovations=True + ) + + with pm.Model(coords=ss_mod.coords) as model: + betas = pm.Normal('betas', dims=ss_mod.param_dims['beta_regression']) + + # Innovation variance for time-varying coefficients + sigma_beta = pm.Exponential('sigma_beta', 1, dims=ss_mod.param_dims['sigma_beta_regression']) + + ss_mod.build_statespace_graph(data) + idata = pm.sample() + """ + + def __init__( + self, + k_exog: int | None = None, + name: str | None = "regression", + state_names: list[str] | None = None, + observed_state_names: list[str] | None = None, + innovations=False, + share_states: bool = False, + ): + self.share_states = share_states + + if observed_state_names is None: + observed_state_names = ["data"] + + self.innovations = innovations + k_exog = self._handle_input_data(k_exog, state_names, name) + + k_states = k_exog + k_endog = len(observed_state_names) + k_posdef = k_exog + + super().__init__( + name=name, + k_endog=k_endog, + k_states=k_states * k_endog if not share_states else k_states, + k_posdef=k_posdef * k_endog if not share_states else k_posdef, + state_names=self.state_names, + share_states=share_states, + observed_state_names=observed_state_names, + measurement_error=False, + combine_hidden_states=False, + exog_names=[f"data_{name}"], + obs_state_idxs=np.ones(k_states), + ) + + @staticmethod + def _get_state_names(k_exog: int | None, state_names: list[str] | None, name: str): + if k_exog is None and state_names is None: + raise ValueError("Must specify at least one of k_exog or state_names") + if state_names is not None and k_exog is not None: + if len(state_names) != k_exog: + raise ValueError(f"Expected {k_exog} state names, found {len(state_names)}") + elif k_exog is None: + k_exog = len(state_names) + else: + state_names = [f"{name}_{i + 1}" for i in range(k_exog)] + + return k_exog, state_names + + def _handle_input_data(self, k_exog: int, state_names: list[str] | None, name) -> int: + k_exog, state_names = self._get_state_names(k_exog, state_names, name) + self.state_names = state_names + + return k_exog + + def make_symbolic_graph(self) -> None: + k_endog = self.k_endog + k_endog_effective = 1 if self.share_states else k_endog + + k_states = self.k_states // k_endog_effective + + betas = self.make_and_register_variable( + f"beta_{self.name}", shape=(k_endog, k_states) if k_endog_effective > 1 else (k_states,) + ) + regression_data = self.make_and_register_data(f"data_{self.name}", shape=(None, k_states)) + + self.ssm["initial_state", :] = betas.ravel() + self.ssm["transition", :, :] = pt.eye(self.k_states) + self.ssm["selection", :, :] = pt.eye(self.k_states) + + if self.share_states: + self.ssm["design"] = pt.specify_shape( + pt.join(1, *[pt.expand_dims(regression_data, 1) for _ in range(k_endog)]), + (None, k_endog, self.k_states), + ) + else: + Z = pt.linalg.block_diag(*[pt.expand_dims(regression_data, 1) for _ in range(k_endog)]) + self.ssm["design"] = pt.specify_shape( + Z, (None, k_endog, regression_data.type.shape[1] * k_endog) + ) + + if self.innovations: + sigma_beta = self.make_and_register_variable( + f"sigma_beta_{self.name}", + (k_states,) if k_endog_effective == 1 else (k_endog, k_states), + ) + row_idx, col_idx = np.diag_indices(self.k_states) + self.ssm["state_cov", row_idx, col_idx] = sigma_beta.ravel() ** 2 + + def _set_parameters(self) -> None: + k_endog = self.k_endog + k_endog_effective = 1 if self.share_states else k_endog + k_states = self.k_states // k_endog_effective + + beta_param_name = f"beta_{self.name}" + beta_param_shape = (k_endog_effective, k_states) if k_endog_effective > 1 else (k_states,) + beta_param_dims = ( + (f"endog_{self.name}", f"state_{self.name}") + if k_endog_effective > 1 + else (f"state_{self.name}",) + ) + + beta_param_constraints = None + + if self.innovations: + sigma_param_name = f"sigma_beta_{self.name}" + sigma_param_dims = (f"state_{self.name}",) + sigma_param_shape = (k_states,) + sigma_param_constraints = "Positive" + + beta_parameter = ParameterProperty( + name=beta_param_name, + shape=beta_param_shape, + dims=beta_param_dims, + constraints=beta_param_constraints, + ) + + sigma_parameter = ParameterProperty( + name=sigma_param_name, + shape=sigma_param_shape, + dims=sigma_param_dims, + constraints=sigma_param_constraints, + ) + + self.param_info = ParameterProperties(parameters=[beta_parameter, sigma_parameter]) + + def _set_data(self) -> None: + k_endog = self.k_endog + k_endog_effective = 1 if self.share_states else k_endog + k_states = self.k_states // k_endog_effective + + data_name = f"data_{self.name}" + data_shape = (None, k_states) + data_dims = (TIME_DIM, f"state_{self.name}") + + data_prop = DataProperty( + name=data_name, shape=data_shape, dims=data_dims, is_exogenous=True + ) + self.data_info = DataProperties(data=[data_prop]) + + def _set_shocks(self) -> None: + if self.share_states: + shock_names = [f"{state_name}_shared" for state_name in self.state_names] + else: + shock_names = self.state_names + + self.shock_info = ShockProperties(shocks=[ShockProperty(name=name) for name in shock_names]) + + def _set_states(self) -> None: + self.base_names = self.state_names + + if self.share_states: + state_names = [f"{name}[{self.name}_shared]" for name in self.base_names] + self.state_names = StateProperties( + states=[ + StateProperty(name=name, observed=True, shared=True) for name in state_names + ] + ) + else: + state_names = [ + f"{name}[{obs_name}]" + for obs_name in self.observed_state_names + for name in self.base_names + ] + self.state_names = StateProperties( + states=[ + StateProperty(name=name, observed=True, shared=False) for name in state_names + ] + ) + + def _set_coords(self) -> None: + regression_state_prop = CoordProperty( + dimension=f"state_{self.name}", labels=[state for state in self.base_names] + ) + endogenous_state_prop = CoordProperty( + dimension=f"endog_{self.name}", labels=[state for state in self.observed_state_names] + ) + + self.coords = CoordProperties(coords=[regression_state_prop, endogenous_state_prop]) + + def populate_component_properties(self) -> None: + # k_endog_eff, k_states = self._effective_shape_info() + + # 1. Set parameter info + self._set_parameters() + + # 2. Set data info + self._set_data() + + # 3. Set shock info + self._set_shocks() + + # 4. Set states info + self._set_states() + + # 5. Set coordinates info + self._set_coords() + + # def populate_component_properties(self) -> None: + # k_endog = self.k_endog + # k_endog_effective = 1 if self.share_states else k_endog + + # k_states = self.k_states // k_endog_effective + + # if self.share_states: + # self.shock_names = [f"{state_name}_shared" for state_name in self.state_names] + # else: + # self.shock_names = self.state_names + + # self.param_names = [f"beta_{self.name}"] + # self.data_names = [f"data_{self.name}"] + # self.param_dims = { + # f"beta_{self.name}": (f"endog_{self.name}", f"state_{self.name}") + # if k_endog_effective > 1 + # else (f"state_{self.name}",) + # } + + # base_names = self.state_names + + # if self.share_states: + # self.state_names = [f"{name}[{self.name}_shared]" for name in base_names] + # else: + # self.state_names = [ + # f"{name}[{obs_name}]" + # for obs_name in self.observed_state_names + # for name in base_names + # ] + + # self.param_info = { + # f"beta_{self.name}": { + # "shape": (k_endog_effective, k_states) if k_endog_effective > 1 else (k_states,), + # "constraints": None, + # "dims": (f"endog_{self.name}", f"state_{self.name}") + # if k_endog_effective > 1 + # else (f"state_{self.name}",), + # }, + # } + + # self.data_info = { + # f"data_{self.name}": { + # "shape": (None, k_states), + # "dims": (TIME_DIM, f"state_{self.name}"), + # }, + # } + # self.coords = { + # f"state_{self.name}": base_names, + # f"endog_{self.name}": self.observed_state_names, + # } + + # if self.innovations: + # self.param_names += [f"sigma_beta_{self.name}"] + # self.param_dims[f"sigma_beta_{self.name}"] = (f"state_{self.name}",) + # self.param_info[f"sigma_beta_{self.name}"] = { + # "shape": (k_states,), + # "constraints": "Positive", + # "dims": (f"state_{self.name}",) + # if k_endog_effective == 1 + # else (f"endog_{self.name}", f"state_{self.name}"), + # } diff --git a/pymc_extras/statespace/models/structural/core.py b/pymc_extras/statespace/models/structural/core.py index a2718251b..7b159e7f7 100644 --- a/pymc_extras/statespace/models/structural/core.py +++ b/pymc_extras/statespace/models/structural/core.py @@ -2,6 +2,7 @@ import logging from collections.abc import Sequence +from dataclasses import is_dataclass from itertools import pairwise from typing import Any @@ -12,6 +13,10 @@ from pytensor import tensor as pt from pymc_extras.statespace.core import PyMCStateSpace, PytensorRepresentation +from pymc_extras.statespace.core.properties import ( + Parameter, + ParameterInfo, +) from pymc_extras.statespace.models.utilities import ( add_tensors_by_dim_labels, conform_time_varying_and_time_invariant_matrices, @@ -136,6 +141,7 @@ class StructuralTimeSeries(PyMCStateSpace): methods (2nd ed.). Oxford University Press. """ + # TODO need to discuss cutting some of these args down. All the _names args are already inside of _info def __init__( self, ssm: PytensorRepresentation, @@ -150,6 +156,8 @@ def __init__( coords: dict[str, Sequence], param_info: dict[str, dict[str, Any]], data_info: dict[str, dict[str, Any]], + shock_info: dict[str, dict[str, Any]], + state_info: dict[str, dict[str, Any]], component_info: dict[str, dict[str, Any]], measurement_error: bool, name_to_variable: dict[str, Variable], @@ -165,7 +173,7 @@ def __init__( k_states, k_posdef, k_endog = ssm.k_states, ssm.k_posdef, ssm.k_endog param_names, param_dims, param_info = self._add_inital_state_cov_to_properties( - param_names, param_dims, param_info, k_states + param_info, k_states ) self._state_names = self._strip_data_names_if_unambiguous(state_names, k_endog) @@ -175,13 +183,13 @@ def __init__( self._param_dims = param_dims default_coords = make_default_coords(self) - coords.update(default_coords) + coords = coords.merge(default_coords) - self._coords = { - k: self._strip_data_names_if_unambiguous(v, k_endog) for k, v in coords.items() - } - self._param_info = param_info.copy() - self._data_info = data_info.copy() + self._coord_info = coords + self._param_info = param_info # .copy() #TODO add __copy__ to base class + self._data_info = data_info # .copy() + self._shock_info = shock_info + self._state_info = state_info self.measurement_error = measurement_error super().__init__( @@ -236,16 +244,25 @@ def _strip_data_names_if_unambiguous(self, names: list[str], k_endog: int): return names @staticmethod - def _add_inital_state_cov_to_properties(param_names, param_dims, param_info, k_states): - param_names += ["P0"] - param_dims["P0"] = (ALL_STATE_DIM, ALL_STATE_AUX_DIM) - param_info["P0"] = { - "shape": (k_states, k_states), - "constraints": "Positive semi-definite", - "dims": param_dims["P0"], - } + def _add_inital_state_cov_to_properties(param_info, k_states): + initial_state_cov_name = "P0" + initial_state_cov_shape = (k_states, k_states) + initial_state_cov_dims = (ALL_STATE_DIM, ALL_STATE_AUX_DIM) + initial_state_cov_constraints = "Positive semi-definite" + + initial_state_cov_param = Parameter( + name=initial_state_cov_name, + shape=initial_state_cov_shape, + dims=initial_state_cov_dims, + constraints=initial_state_cov_constraints, + ) + + if is_dataclass(param_info): + param_info = param_info.add(initial_state_cov_param) + else: + param_info = ParameterInfo(parameters=[initial_state_cov_param]) - return param_names, param_dims, param_info + return param_info.names, [p.dims for p in param_info], param_info @property def param_names(self): @@ -271,9 +288,9 @@ def shock_names(self): def param_dims(self): return self._param_dims - @property + @property # TODO discuss naming convention _info and need to clean up type hints def coords(self) -> dict[str, Sequence]: - return self._coords + return self._coord_info @property def param_info(self) -> dict[str, dict[str, Any]]: @@ -283,6 +300,14 @@ def param_info(self) -> dict[str, dict[str, Any]]: def data_info(self) -> dict[str, dict[str, Any]]: return self._data_info + @property + def state_info(self) -> dict[str, dict[str, Any]]: + return self._state_info + + @property + def shock_info(self) -> dict[str, dict[str, Any]]: + return self._shock_info + def make_symbolic_graph(self) -> None: """ Assign placeholder pytensor variables among statespace matrices in positions where PyMC variables will go. @@ -540,6 +565,8 @@ def __init__( self.param_info = {} self.data_info = {} + self.shock_info = {} + self.state_info = {} self.param_counts = {} @@ -595,7 +622,7 @@ def make_and_register_variable(self, name, shape, dtype=floatX) -> Variable: An error is raised if the provided name has already been registered, or if the name is not present in the ``param_names`` property. """ - if name not in self.param_names: + if name not in self.param_info: raise ValueError( f"{name} is not a model parameter. All placeholder variables should correspond to model " f"parameters." @@ -632,7 +659,7 @@ def make_and_register_data(self, name, shape, dtype=floatX) -> Variable: An error is raised if the provided name has already been registered, or if the name is not present in the ``data_names`` property. """ - if name not in self.data_names: + if name not in self.data_info: raise ValueError( f"{name} is not a model parameter. All placeholder variables should correspond to model " f"parameters." @@ -648,6 +675,21 @@ def make_and_register_data(self, name, shape, dtype=floatX) -> Variable: self._name_to_data[name] = placeholder return placeholder + def _set_parameters(self) -> None: + raise NotImplementedError + + def _set_data(self) -> None: + raise NotImplementedError + + def _set_shocks(self) -> None: + raise NotImplementedError + + def _set_states(self) -> None: + raise NotImplementedError + + def _set_coords(self) -> None: + raise NotImplementedError + def make_symbolic_graph(self) -> None: raise NotImplementedError @@ -764,15 +806,17 @@ def _combine_property(self, other, name, allow_duplicates=True): self_prop = getattr(self, name) other_prop = getattr(other, name) + # TODO discuss limiting the types we get here to only a dataclass type. By making the dataclasses immutable we now have to handle for tuples too. + if not isinstance(self_prop, type(other_prop)): raise TypeError( f"Property {name} of {self} and {other} are not the same and cannot be combined. Found " f"{type(self_prop)} for {self} and {type(other_prop)} for {other}'" ) - if not isinstance(self_prop, list | dict): + if not is_dataclass(self_prop) and not isinstance(self_prop, list | tuple | dict): raise TypeError( - f"All component properties are expected to be lists or dicts, but found {type(self_prop)}" + f"All component properties are expected to be dataclasses, but found {type(self_prop)}" f"for property {name} of {self} and {type(other_prop)} for {other}'" ) @@ -784,6 +828,12 @@ def _combine_property(self, other, name, allow_duplicates=True): new_prop = self_prop.copy() new_prop.update(other_prop) return new_prop + # TODO need to handle allow_duplicates but want to wait for above discussion first to see if we can cut down to just dataclass types + elif isinstance(self_prop, tuple): + new_prop = self_prop + other_prop + return new_prop + elif is_dataclass(self_prop): + return self_prop.merge(other_prop) def _combine_component_info(self, other): combined_info = {} @@ -817,6 +867,8 @@ def __add__(self, other): shock_names = self._combine_property(other, "shock_names") param_info = self._combine_property(other, "param_info") data_info = self._combine_property(other, "data_info") + shock_info = self._combine_property(other, "shock_info") + state_info = self._combine_property(other, "state_info") param_dims = self._combine_property(other, "param_dims") coords = self._combine_property(other, "coords") exog_names = self._combine_property(other, "exog_names") @@ -854,6 +906,8 @@ def __add__(self, other): ("param_dims", param_dims), ("param_info", param_info), ("data_info", data_info), + ("shock_info", shock_info), + ("state_info", state_info), ("exog_names", exog_names), ("_name_to_variable", _name_to_variable), ("_name_to_data", _name_to_data), @@ -908,6 +962,8 @@ def build( coords=self.coords, param_info=self.param_info, data_info=self.data_info, + shock_info=self.shock_info, + state_info=self.state_info, component_info=self._component_info, measurement_error=self.measurement_error, exog_names=self.exog_names, diff --git a/pymc_extras/statespace/models/utilities.py b/pymc_extras/statespace/models/utilities.py index 33be8d47d..cab8f3b3c 100644 --- a/pymc_extras/statespace/models/utilities.py +++ b/pymc_extras/statespace/models/utilities.py @@ -5,6 +5,7 @@ from pytensor.tensor import TensorVariable +from pymc_extras.statespace.core.properties import Coord, CoordInfo from pymc_extras.statespace.utils.constants import ( ALL_STATE_AUX_DIM, ALL_STATE_DIM, @@ -19,14 +20,23 @@ def make_default_coords(ss_mod): - coords = { - ALL_STATE_DIM: ss_mod.state_names, - ALL_STATE_AUX_DIM: ss_mod.state_names, - OBS_STATE_DIM: ss_mod.observed_states, - OBS_STATE_AUX_DIM: ss_mod.observed_states, - SHOCK_DIM: ss_mod.shock_names, - SHOCK_AUX_DIM: ss_mod.shock_names, - } + ALL_STATE_COORD = Coord(dimension=ALL_STATE_DIM, labels=ss_mod.state_names) + ALL_STATE_AUX_COORD = Coord(dimension=ALL_STATE_AUX_DIM, labels=ss_mod.state_names) + OBS_STATE_COORD = Coord(dimension=OBS_STATE_DIM, labels=ss_mod.observed_states) + OBS_STATE_AUX_COORD = Coord(dimension=OBS_STATE_AUX_DIM, labels=ss_mod.observed_states) + SHOCK_COORD = Coord(dimension=SHOCK_DIM, labels=ss_mod.shock_names) + SHOCK_AUX_COORD = Coord(dimension=SHOCK_AUX_DIM, labels=ss_mod.shock_names) + + coords = CoordInfo( + coords=[ + ALL_STATE_COORD, + ALL_STATE_AUX_COORD, + OBS_STATE_COORD, + OBS_STATE_AUX_COORD, + SHOCK_COORD, + SHOCK_AUX_COORD, + ] + ) return coords diff --git a/tests/statespace/core/test_properties.py b/tests/statespace/core/test_properties.py new file mode 100644 index 000000000..7f7cb8ae3 --- /dev/null +++ b/tests/statespace/core/test_properties.py @@ -0,0 +1,119 @@ +import pytest + +from pymc_extras.statespace.core.properties import ( + CoordInfo, + Data, + DataInfo, + Parameter, + ParameterInfo, + Shock, + ShockInfo, + State, + StateInfo, +) +from pymc_extras.statespace.utils.constants import ( + ALL_STATE_AUX_DIM, + ALL_STATE_DIM, + OBS_STATE_AUX_DIM, + OBS_STATE_DIM, + SHOCK_AUX_DIM, + SHOCK_DIM, +) + + +def test_property_str_formats_fields(): + p = Parameter(name="alpha", shape=(2,), dims=("param",)) + s = str(p).splitlines() + assert s == [ + "name: alpha", + "shape: (2,)", + "dims: ('param',)", + "constraints: None", + ] + + +def test_info_lookup_contains_and_missing_key(): + params = [ + Parameter("a", (1,), ("d",)), + Parameter("b", (2,), ("d",)), + Parameter("c", (3,), ("d",)), + ] + info = ParameterInfo(params) + + assert info.get("b").name == "b" + assert info["a"].shape == (1,) + assert "c" in info + + with pytest.raises(KeyError) as e: + _ = info["missing"] + assert "No name 'missing'" in str(e.value) + + +def test_data_info_needs_exogenous_and_str(): + data = [ + Data("price", (10,), ("time",), is_exogenous=False), + Data("x", (10,), ("time",), is_exogenous=True), + ] + info = DataInfo(data) + + assert info.needs_exogenous_data is True + s = str(info) + assert "data: ['price', 'x']" in s + assert "needs exogenous data: True" in s + + no_exog = DataInfo([Data("y", (10,), ("time",), is_exogenous=False)]) + assert no_exog.needs_exogenous_data is False + + +def test_coord_info_make_defaults_from_component_and_types(): + class DummyComponent: + state_names = ["x1", "x2"] + observed_state_names = ["x2"] + shock_names = ["eps1"] + + ci = CoordInfo.default_coords_from_model(DummyComponent()) + + expected = [ + (ALL_STATE_DIM, ("x1", "x2")), + (ALL_STATE_AUX_DIM, ("x1", "x2")), + (OBS_STATE_DIM, ("x2",)), + (OBS_STATE_AUX_DIM, ("x2",)), + (SHOCK_DIM, ("eps1",)), + (SHOCK_AUX_DIM, ("eps1",)), + ] + + assert len(ci.items) == 6 + for dim, labels in expected: + assert dim in ci + assert ci[dim].labels == labels + assert isinstance(ci[dim].labels, tuple) + + +def test_state_info_and_shockinfo_basic(): + states = [ + State("x1", observed=True, shared=False), + State("x2", observed=False, shared=True), + ] + state_info = StateInfo(states) + assert state_info["x1"].observed is True + s = str(state_info) + + assert "states: ['x1', 'x2']" in s + assert "observed: [True, False]" in s + + shocks = [Shock("s1"), Shock("s2")] + shock_info = ShockInfo(shocks) + + assert "s1" in shock_info + assert shock_info["s2"].name == "s2" + + +def test_info_is_iterable_and_unpackable(): + items = [Parameter("p1", (1,), ("d",)), Parameter("p2", (2,), ("d",))] + info = ParameterInfo(items) + + names = info.names + assert names == ("p1", "p2") + + a, b = info.items + assert a.name == "p1" and b.name == "p2" diff --git a/tests/statespace/models/structural/components/test_regression.py b/tests/statespace/models/structural/components/test_regression.py index c1732997d..ffde3348f 100644 --- a/tests/statespace/models/structural/components/test_regression.py +++ b/tests/statespace/models/structural/components/test_regression.py @@ -66,7 +66,7 @@ def test_exogenous_component(self, rng, regression_data, innovations): mod = mod.build(verbose=False) _assert_basic_coords_correct(mod) - assert mod.coords["state_exog"] == ["feature_1", "feature_2"] + assert mod.coords["state_exog"].labels == ["feature_1", "feature_2"] if innovations: # Check that sigma_beta parameter is included @@ -125,7 +125,7 @@ def test_regression_with_multiple_observed_states(self, rng, regression_data, in assert_allclose(x[0, 2:], params["beta_exog"][1], atol=ATOL, rtol=RTOL) mod = mod.build(verbose=False) - assert mod.coords["state_exog"] == ["feature_1", "feature_2"] + assert mod.coords["state_exog"].labels == ["feature_1", "feature_2"] Z = mod.ssm["design"].eval({"data_exog": regression_data}) vec_block_diag = np.vectorize(block_diag, signature="(n,m),(o,p)->(q,r)") @@ -164,8 +164,8 @@ def test_add_regression_components_with_multiple_observed_states( ) mod = (reg1 + reg2).build(verbose=False) - assert mod.coords["state_exog1"] == ["a", "b"] - assert mod.coords["state_exog2"] == ["c"] + assert mod.coords["state_exog1"].labels == ["a", "b"] + assert mod.coords["state_exog2"].labels == ["c"] Z = mod.ssm["design"].eval( { @@ -211,7 +211,7 @@ def test_filter_scans_time_varying_design_matrix(self, rng, time_series_data, in reg = st.RegressionComponent(state_names=["a", "b"], name="exog", innovations=innovations) mod = reg.build(verbose=False) - with pm.Model(coords=mod.coords) as m: + with pm.Model(coords=mod.coords.to_dict()) as m: data_exog = pm.Data("data_exog", data.values) x0 = pm.Normal("x0", dims=["state"]) @@ -249,14 +249,12 @@ def test_regression_multiple_shared_construction(): assert mod.k_states == 1 assert mod.k_posdef == 1 - assert mod.coords["state_regression"] == ["A"] - assert mod.coords["endog_regression"] == ["data_1", "data_2"] + assert mod.coords["state_regression"].labels == ["A"] + assert mod.coords["endog_regression"].labels == ["data_1", "data_2"] - assert mod.state_names == [ - "A[regression_shared]", - ] + assert mod.state_names == ("A[regression_shared]",) - assert mod.shock_names == ["A_shared"] + assert mod.shock_names == ("A_shared",) data = np.random.standard_normal(size=(10, 1)) Z = mod.ssm["design"].eval({"data_regression": data}) @@ -312,8 +310,8 @@ def test_regression_mixed_shared_and_not_shared(): assert mod.k_states == 4 assert mod.k_posdef == 4 - assert mod.state_names == ["A[data_1]", "A[data_2]", "B[joint_shared]", "C[joint_shared]"] - assert mod.shock_names == ["A", "B_shared", "C_shared"] + assert mod.state_names == ("A[data_1]", "A[data_2]", "B[joint_shared]", "C[joint_shared]") + assert mod.shock_names == ("A", "B_shared", "C_shared") data_joint = np.random.standard_normal(size=(10, 2)) data_individual = np.random.standard_normal(size=(10, 1)) diff --git a/tests/statespace/models/structural/conftest.py b/tests/statespace/models/structural/conftest.py index b9e58ca68..15dac710d 100644 --- a/tests/statespace/models/structural/conftest.py +++ b/tests/statespace/models/structural/conftest.py @@ -19,11 +19,11 @@ def rng(): def _assert_basic_coords_correct(mod): - assert mod.coords[ALL_STATE_DIM] == mod.state_names - assert mod.coords[ALL_STATE_AUX_DIM] == mod.state_names - assert mod.coords[SHOCK_DIM] == mod.shock_names - assert mod.coords[SHOCK_AUX_DIM] == mod.shock_names + assert mod.coords[ALL_STATE_DIM].labels == mod.state_names + assert mod.coords[ALL_STATE_AUX_DIM].labels == mod.state_names + assert mod.coords[SHOCK_DIM].labels == mod.shock_names + assert mod.coords[SHOCK_AUX_DIM].labels == mod.shock_names expected_obs = mod.observed_state_names if hasattr(mod, "observed_state_names") else ["data"] - assert mod.coords[OBS_STATE_DIM] == expected_obs - assert mod.coords[OBS_STATE_AUX_DIM] == expected_obs + assert mod.coords[OBS_STATE_DIM].labels == expected_obs + assert mod.coords[OBS_STATE_AUX_DIM].labels == expected_obs