Skip to content

Commit 5966d03

Browse files
Add SINDy model (#660)
1 parent 2108c76 commit 5966d03

File tree

5 files changed

+167
-0
lines changed

5 files changed

+167
-0
lines changed

docs/source/_rst/_code.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@ Models
106106
GraphNeuralKernel <model/graph_neural_operator_integral_kernel.rst>
107107
PirateNet <model/pirate_network.rst>
108108
EquivariantGraphNeuralOperator <model/equivariant_graph_neural_operator.rst>
109+
SINDy <model/sindy.rst>
109110

110111
Blocks
111112
-------------

docs/source/_rst/model/sindy.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
SINDy
2+
=======================
3+
.. currentmodule:: pina.model.sindy
4+
5+
.. autoclass:: SINDy
6+
:members:
7+
:show-inheritance:

pina/model/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
"GraphNeuralOperator",
1616
"PirateNet",
1717
"EquivariantGraphNeuralOperator",
18+
"SINDy",
1819
]
1920

2021
from .feed_forward import FeedForward, ResidualFeedForward
@@ -28,3 +29,4 @@
2829
from .graph_neural_operator import GraphNeuralOperator
2930
from .pirate_network import PirateNet
3031
from .equivariant_graph_neural_operator import EquivariantGraphNeuralOperator
32+
from .sindy import SINDy

pina/model/sindy.py

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
"""Module for the SINDy model class."""
2+
3+
from typing import Callable
4+
import torch
5+
from ..utils import check_consistency, check_positive_integer
6+
7+
8+
class SINDy(torch.nn.Module):
9+
r"""
10+
SINDy model class.
11+
12+
The Sparse Identification of Nonlinear Dynamics (SINDy) model identifies the
13+
governing equations of a dynamical system from data by learning a sparse
14+
linear combination of non-linear candidate functions.
15+
16+
The output of the model is expressed as product of a library matrix and a
17+
coefficient matrix:
18+
19+
.. math::
20+
21+
\dot{X} = \Theta(X) \Xi
22+
23+
where:
24+
- :math:`X \in \mathbb{R}^{B \times D}` is the input snapshots of the
25+
system state. Here, :math:`B` is the batch size and :math:`D` is the
26+
number of state variables.
27+
- :math:`\Theta(X) \in \mathbb{R}^{B \times L}` is the library matrix
28+
obtained by evaluating a set of candidate functions on the input data.
29+
Here, :math:`L` is the number of candidate functions in the library.
30+
- :math:`\Xi \in \mathbb{R}^{L \times D}` is the learned coefficient
31+
matrix that defines the sparse model.
32+
33+
.. seealso::
34+
35+
**Original reference**:
36+
Brunton, S.L., Proctor, J.L., and Kutz, J.N. (2016).
37+
*Discovering governing equations from data: Sparse identification of
38+
non-linear dynamical systems.*
39+
Proceedings of the National Academy of Sciences, 113(15), 3932-3937.
40+
DOI: `10.1073/pnas.1517384113
41+
<https://doi.org/10.1073/pnas.1517384113>`_
42+
"""
43+
44+
def __init__(self, library, output_dimension):
45+
"""
46+
Initialization of the :class:`SINDy` class.
47+
48+
:param list[Callable] library: The collection of candidate functions
49+
used to construct the library matrix. Each function must accept an
50+
input tensor of shape ``[..., D]`` and return a tensor of shape
51+
``[..., 1]``.
52+
:param int output_dimension: The number of output variables, typically
53+
the number of state derivatives. It determines the number of columns
54+
in the coefficient matrix.
55+
:raises ValueError: If ``library`` is not a list of callables.
56+
:raises AssertionError: If ``output_dimension`` is not a positive
57+
integer.
58+
"""
59+
super().__init__()
60+
61+
# Check consistency
62+
check_positive_integer(output_dimension, strict=True)
63+
check_consistency(library, Callable)
64+
if not isinstance(library, list):
65+
raise ValueError("`library` must be a list of callables.")
66+
67+
# Initialization
68+
self._library = library
69+
self._coefficients = torch.nn.Parameter(
70+
torch.zeros(len(library), output_dimension)
71+
)
72+
73+
def forward(self, x):
74+
"""
75+
Forward pass of the :class:`SINDy` model.
76+
77+
:param torch.Tensor x: The input batch of state variables.
78+
:return: The predicted time derivatives of the state variables.
79+
:rtype: torch.Tensor
80+
"""
81+
theta = torch.stack([f(x) for f in self.library], dim=-2)
82+
return torch.einsum("...li , lo -> ...o", theta, self.coefficients)
83+
84+
@property
85+
def library(self):
86+
"""
87+
The library of candidate functions.
88+
89+
:return: The library.
90+
:rtype: list[Callable]
91+
"""
92+
return self._library
93+
94+
@property
95+
def coefficients(self):
96+
"""
97+
The coefficients of the model.
98+
99+
:return: The coefficients.
100+
:rtype: torch.Tensor
101+
"""
102+
return self._coefficients

tests/test_model/test_sindy.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
import torch
2+
import pytest
3+
from pina.model import SINDy
4+
5+
# Define a simple library of candidate functions and some test data
6+
library = [lambda x: torch.pow(x, 2), lambda x: torch.sin(x)]
7+
8+
9+
@pytest.mark.parametrize("data", [torch.rand((20, 1)), torch.rand((5, 20, 1))])
10+
def test_constructor(data):
11+
SINDy(library, data.shape[-1])
12+
13+
# Should fail if output_dimension is not a positive integer
14+
with pytest.raises(AssertionError):
15+
SINDy(library, "not_int")
16+
with pytest.raises(AssertionError):
17+
SINDy(library, -1)
18+
19+
# Should fail if library is not a list
20+
with pytest.raises(ValueError):
21+
SINDy(lambda x: torch.pow(x, 2), 3)
22+
23+
# Should fail if library is not a list of callables
24+
with pytest.raises(ValueError):
25+
SINDy([1, 2, 3], 3)
26+
27+
28+
@pytest.mark.parametrize("data", [torch.rand((20, 1)), torch.rand((5, 20, 1))])
29+
def test_forward(data):
30+
31+
# Define model
32+
model = SINDy(library, data.shape[-1])
33+
with torch.no_grad():
34+
model.coefficients.data.fill_(1.0)
35+
36+
# Evaluate model
37+
output_ = model(data)
38+
vals = data.pow(2) + torch.sin(data)
39+
40+
print(data.shape, output_.shape, vals.shape)
41+
42+
assert output_.shape == data.shape
43+
assert torch.allclose(output_, vals, atol=1e-6, rtol=1e-6)
44+
45+
46+
@pytest.mark.parametrize("data", [torch.rand((20, 1)), torch.rand((5, 20, 1))])
47+
def test_backward(data):
48+
49+
# Define and evaluate model
50+
model = SINDy(library, data.shape[-1])
51+
output_ = model(data.requires_grad_())
52+
53+
loss = output_.mean()
54+
loss.backward()
55+
assert data.grad.shape == data.shape

0 commit comments

Comments
 (0)