|
| 1 | +"""Module for the SINDy model class.""" |
| 2 | + |
| 3 | +from typing import Callable |
| 4 | +import torch |
| 5 | +from ..utils import check_consistency, check_positive_integer |
| 6 | + |
| 7 | + |
| 8 | +class SINDy(torch.nn.Module): |
| 9 | + r""" |
| 10 | + SINDy model class. |
| 11 | +
|
| 12 | + The Sparse Identification of Nonlinear Dynamics (SINDy) model identifies the |
| 13 | + governing equations of a dynamical system from data by learning a sparse |
| 14 | + linear combination of non-linear candidate functions. |
| 15 | +
|
| 16 | + The output of the model is expressed as product of a library matrix and a |
| 17 | + coefficient matrix: |
| 18 | +
|
| 19 | + .. math:: |
| 20 | +
|
| 21 | + \dot{X} = \Theta(X) \Xi |
| 22 | +
|
| 23 | + where: |
| 24 | + - :math:`X \in \mathbb{R}^{B \times D}` is the input snapshots of the |
| 25 | + system state. Here, :math:`B` is the batch size and :math:`D` is the |
| 26 | + number of state variables. |
| 27 | + - :math:`\Theta(X) \in \mathbb{R}^{B \times L}` is the library matrix |
| 28 | + obtained by evaluating a set of candidate functions on the input data. |
| 29 | + Here, :math:`L` is the number of candidate functions in the library. |
| 30 | + - :math:`\Xi \in \mathbb{R}^{L \times D}` is the learned coefficient |
| 31 | + matrix that defines the sparse model. |
| 32 | +
|
| 33 | + .. seealso:: |
| 34 | +
|
| 35 | + **Original reference**: |
| 36 | + Brunton, S.L., Proctor, J.L., and Kutz, J.N. (2016). |
| 37 | + *Discovering governing equations from data: Sparse identification of |
| 38 | + non-linear dynamical systems.* |
| 39 | + Proceedings of the National Academy of Sciences, 113(15), 3932-3937. |
| 40 | + DOI: `10.1073/pnas.1517384113 |
| 41 | + <https://doi.org/10.1073/pnas.1517384113>`_ |
| 42 | + """ |
| 43 | + |
| 44 | + def __init__(self, library, output_dimension): |
| 45 | + """ |
| 46 | + Initialization of the :class:`SINDy` class. |
| 47 | +
|
| 48 | + :param list[Callable] library: The collection of candidate functions |
| 49 | + used to construct the library matrix. Each function must accept an |
| 50 | + input tensor of shape ``[..., D]`` and return a tensor of shape |
| 51 | + ``[..., 1]``. |
| 52 | + :param int output_dimension: The number of output variables, typically |
| 53 | + the number of state derivatives. It determines the number of columns |
| 54 | + in the coefficient matrix. |
| 55 | + :raises ValueError: If ``library`` is not a list of callables. |
| 56 | + :raises AssertionError: If ``output_dimension`` is not a positive |
| 57 | + integer. |
| 58 | + """ |
| 59 | + super().__init__() |
| 60 | + |
| 61 | + # Check consistency |
| 62 | + check_positive_integer(output_dimension, strict=True) |
| 63 | + check_consistency(library, Callable) |
| 64 | + if not isinstance(library, list): |
| 65 | + raise ValueError("`library` must be a list of callables.") |
| 66 | + |
| 67 | + # Initialization |
| 68 | + self._library = library |
| 69 | + self._coefficients = torch.nn.Parameter( |
| 70 | + torch.zeros(len(library), output_dimension) |
| 71 | + ) |
| 72 | + |
| 73 | + def forward(self, x): |
| 74 | + """ |
| 75 | + Forward pass of the :class:`SINDy` model. |
| 76 | +
|
| 77 | + :param torch.Tensor x: The input batch of state variables. |
| 78 | + :return: The predicted time derivatives of the state variables. |
| 79 | + :rtype: torch.Tensor |
| 80 | + """ |
| 81 | + theta = torch.stack([f(x) for f in self.library], dim=-2) |
| 82 | + return torch.einsum("...li , lo -> ...o", theta, self.coefficients) |
| 83 | + |
| 84 | + @property |
| 85 | + def library(self): |
| 86 | + """ |
| 87 | + The library of candidate functions. |
| 88 | +
|
| 89 | + :return: The library. |
| 90 | + :rtype: list[Callable] |
| 91 | + """ |
| 92 | + return self._library |
| 93 | + |
| 94 | + @property |
| 95 | + def coefficients(self): |
| 96 | + """ |
| 97 | + The coefficients of the model. |
| 98 | +
|
| 99 | + :return: The coefficients. |
| 100 | + :rtype: torch.Tensor |
| 101 | + """ |
| 102 | + return self._coefficients |
0 commit comments