Skip to content

Commit c76cf21

Browse files
add acoustic wave problem
1 parent 26ac12f commit c76cf21

File tree

8 files changed

+207
-3
lines changed

8 files changed

+207
-3
lines changed

docs/source/_rst/_code.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,7 @@ Problems Zoo
215215
.. toctree::
216216
:titlesonly:
217217

218+
AcousticWaveProblem <problem/zoo/acoustic_wave.rst>
218219
AdvectionProblem <problem/zoo/advection.rst>
219220
AllenCahnProblem <problem/zoo/allen_cahn.rst>
220221
DiffusionReactionProblem <problem/zoo/diffusion_reaction.rst>
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
AcousticWaveProblem
2+
=====================
3+
.. currentmodule:: pina.problem.zoo.acoustic_wave
4+
5+
.. automodule:: pina.problem.zoo.acoustic_wave
6+
7+
.. autoclass:: AcousticWaveProblem
8+
:members:
9+
:show-inheritance:

pina/equation/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
"DiffusionReaction",
1414
"Helmholtz",
1515
"Poisson",
16+
"AcousticWave",
1617
]
1718

1819
from .equation import Equation
@@ -27,5 +28,6 @@
2728
DiffusionReaction,
2829
Helmholtz,
2930
Poisson,
31+
AcousticWave,
3032
)
3133
from .system_equation import SystemEquation

pina/equation/equation_factory.py

Lines changed: 57 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,6 @@
66
from ..operator import grad, div, laplacian
77
from ..utils import check_consistency
88

9-
# Pylint warning disabled because the classes defined in this module
10-
# inherit from Equation and are meant to be simple containers for equations.
11-
129

1310
class FixedValue(Equation): # pylint: disable=R0903
1411
"""
@@ -452,3 +449,60 @@ def equation(input_, output_):
452449
return lap - self.forcing_term(input_)
453450

454451
super().__init__(equation)
452+
453+
454+
class AcousticWave(Equation): # pylint: disable=R0903
455+
r"""
456+
Implementation of the N-dimensional isotropic acoustic wave equation.
457+
The equation is defined as follows:
458+
459+
.. math::
460+
461+
\frac{\partial^2 u}{\partial t^2} - c^2 \Delta u = 0
462+
463+
or alternatively:
464+
465+
.. math::
466+
467+
\Box u = 0
468+
469+
Here, :math:`c` is the wave propagation speed, and :math:`\Box` is the
470+
d'Alembert operator.
471+
"""
472+
473+
def __init__(self, c):
474+
"""
475+
Initialization of the :class:`AcousticWaveEquation` class.
476+
477+
:param c: The wave propagation speed.
478+
:type c: float | int
479+
"""
480+
check_consistency(c, (float, int))
481+
self.c = c
482+
483+
def equation(input_, output_):
484+
"""
485+
Implementation of the acoustic wave equation.
486+
487+
:param LabelTensor input_: The input data of the problem.
488+
:param LabelTensor output_: The output data of the problem.
489+
:return: The residual of the acoustic wave equation.
490+
:rtype: LabelTensor
491+
:raises ValueError: If the ``input_`` labels do not contain the time
492+
variable 't'.
493+
"""
494+
# Ensure time is passed as input
495+
if "t" not in input_.labels:
496+
raise ValueError(
497+
"The ``input_`` labels must contain the time 't' variable."
498+
)
499+
500+
# Compute the time second derivative and the spatial laplacian
501+
u_tt = laplacian(output_, input_, d=["t"])
502+
u_xx = laplacian(
503+
output_, input_, d=[di for di in input_.labels if di != "t"]
504+
)
505+
506+
return u_tt - self.c**2 * u_xx
507+
508+
super().__init__(equation)

pina/problem/zoo/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
"Poisson2DSquareProblem",
99
"DiffusionReactionProblem",
1010
"InversePoisson2DSquareProblem",
11+
"AcousticWaveProblem",
1112
]
1213

1314
from .supervised_problem import SupervisedProblem
@@ -17,3 +18,4 @@
1718
from .poisson_2d_square import Poisson2DSquareProblem
1819
from .diffusion_reaction import DiffusionReactionProblem
1920
from .inverse_poisson_2d_square import InversePoisson2DSquareProblem
21+
from .acoustic_wave import AcousticWaveProblem

pina/problem/zoo/acoustic_wave.py

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
"""Formulation of the acoustic wave problem."""
2+
3+
import torch
4+
from ... import Condition
5+
from ...problem import SpatialProblem, TimeDependentProblem
6+
from ...utils import check_consistency
7+
from ...domain import CartesianDomain
8+
from ...equation import (
9+
Equation,
10+
SystemEquation,
11+
FixedValue,
12+
FixedGradient,
13+
AcousticWave,
14+
)
15+
16+
17+
def initial_condition(input_, output_):
18+
"""
19+
Definition of the initial condition of the acoustic wave problem.
20+
21+
:param LabelTensor input_: The input data of the problem.
22+
:param LabelTensor output_: The output data of the problem.
23+
:return: The residual of the initial condition.
24+
:rtype: LabelTensor
25+
"""
26+
arg = torch.pi * input_["x"]
27+
return output_ - torch.sin(arg) - 0.5 * torch.sin(4 * arg)
28+
29+
30+
class AcousticWaveProblem(TimeDependentProblem, SpatialProblem):
31+
r"""
32+
Implementation of the acoustic wave problem in the spatial interval
33+
:math:`[0, 1]` and temporal interval :math:`[0, 1]`.
34+
35+
.. seealso::
36+
37+
**Original reference**: Wang, Sifan, Xinling Yu, and
38+
Paris Perdikaris. *When and why PINNs fail to train:
39+
A neural tangent kernel perspective*. Journal of
40+
Computational Physics 449 (2022): 110768.
41+
DOI: `10.1016 <https://doi.org/10.1016/j.jcp.2021.110768>`_.
42+
43+
:Example:
44+
45+
>>> problem = AcousticWaveProblem(c=2.0)
46+
"""
47+
48+
output_variables = ["u"]
49+
spatial_domain = CartesianDomain({"x": [0, 1]})
50+
temporal_domain = CartesianDomain({"t": [0, 1]})
51+
52+
domains = {
53+
"D": CartesianDomain({"x": [0, 1], "t": [0, 1]}),
54+
"t0": CartesianDomain({"x": [0, 1], "t": 0.0}),
55+
"g1": CartesianDomain({"x": 0.0, "t": [0, 1]}),
56+
"g2": CartesianDomain({"x": 1.0, "t": [0, 1]}),
57+
}
58+
59+
conditions = {
60+
"g1": Condition(domain="g1", equation=FixedValue(value=0.0)),
61+
"g2": Condition(domain="g2", equation=FixedValue(value=0.0)),
62+
"t0": Condition(
63+
domain="t0",
64+
equation=SystemEquation(
65+
[Equation(initial_condition), FixedGradient(value=0.0, d="t")]
66+
),
67+
),
68+
}
69+
70+
def __init__(self, c=2.0):
71+
"""
72+
Initialization of the :class:`AcousticWaveProblem` class.
73+
74+
:param c: The wave propagation speed. Default is 2.0.
75+
:type c: float | int
76+
"""
77+
super().__init__()
78+
check_consistency(c, (float, int))
79+
self.c = c
80+
81+
self.conditions["D"] = Condition(
82+
domain="D", equation=AcousticWave(self.c)
83+
)
84+
85+
def solution(self, pts):
86+
"""
87+
Implementation of the analytical solution of the acoustic wave problem.
88+
89+
:param LabelTensor pts: Points where the solution is evaluated.
90+
:return: The analytical solution of the acoustic wave problem.
91+
:rtype: LabelTensor
92+
"""
93+
arg_x = torch.pi * pts["x"]
94+
arg_t = self.c * torch.pi * pts["t"]
95+
term1 = torch.sin(arg_x) * torch.cos(arg_t)
96+
term2 = 0.5 * torch.sin(4 * arg_x) * torch.cos(4 * arg_t)
97+
return term1 + term2

tests/test_equation/test_equation_factory.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
DiffusionReaction,
99
Helmholtz,
1010
Poisson,
11+
AcousticWave,
1112
)
1213
from pina import LabelTensor
1314
import torch
@@ -195,3 +196,22 @@ def test_poisson_equation(forcing_term):
195196
# Residual
196197
residual = equation.residual(pts, u)
197198
assert residual.shape == u.shape
199+
200+
201+
@pytest.mark.parametrize("c", [1.0, 10, -7.5])
202+
def test_acoustic_wave_equation(c):
203+
204+
# Constructor
205+
equation = AcousticWave(c=c)
206+
207+
# Should fail if c is not a float or int
208+
with pytest.raises(ValueError):
209+
AcousticWave(c="invalid")
210+
211+
# Residual
212+
residual = equation.residual(pts, u)
213+
assert residual.shape == u.shape
214+
215+
# Should fail if the input has no 't' label
216+
with pytest.raises(ValueError):
217+
residual = equation.residual(pts["x", "y"], u)
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
import pytest
2+
from pina.problem.zoo import AcousticWaveProblem
3+
from pina.problem import SpatialProblem, TimeDependentProblem
4+
5+
6+
@pytest.mark.parametrize("c", [0.1, 1])
7+
def test_constructor(c):
8+
9+
problem = AcousticWaveProblem(c=c)
10+
problem.discretise_domain(n=10, mode="random", domains="all")
11+
assert problem.are_all_domains_discretised
12+
assert isinstance(problem, SpatialProblem)
13+
assert isinstance(problem, TimeDependentProblem)
14+
assert hasattr(problem, "conditions")
15+
assert isinstance(problem.conditions, dict)
16+
17+
# Should fail if c is not a float or int
18+
with pytest.raises(ValueError):
19+
AcousticWaveProblem(c="invalid")

0 commit comments

Comments
 (0)