diff --git a/pina/model/spline.py b/pina/model/spline.py index 77bf18759..d9141fe8c 100644 --- a/pina/model/spline.py +++ b/pina/model/spline.py @@ -18,16 +18,16 @@ class Spline(torch.nn.Module): where: - - :math:`C_i \in \mathbb{R}` are the control points. These fixed points - influence the shape of the curve but are not generally interpolated, - except at the boundaries under certain knot multiplicities. + - :math:`C \in \mathbb{R}^n` are the learnable control coefficients. Its + entries :math:`C_i` influence the shape of the curve but are not generally + interpolated, except under certain knot multiplicities. - :math:`B_{i,k}(x)` are the B-spline basis functions of order :math:`k`, i.e., piecewise polynomials of degree :math:`k-1` with support on the interval :math:`[x_i, x_{i+k}]`. - :math:`X = \{ x_1, x_2, \dots, x_m \}` is the non-decreasing knot vector. If the first and last knots are repeated :math:`k` times, then the curve - interpolates the first and last control points. + interpolates the first and last control coefficients. .. note:: diff --git a/pina/model/spline_surface.py b/pina/model/spline_surface.py index 61798fe7e..767e5b0dc 100644 --- a/pina/model/spline_surface.py +++ b/pina/model/spline_surface.py @@ -15,14 +15,15 @@ class SplineSurface(torch.nn.Module): .. math:: - S(x, y) = \sum_{i,j=1}^{n_x, n_y} B_{i,k}(x) B_{j,s}(y) C_{i,j}, - \quad x \in [x_1, x_m], y \in [y_1, y_l] + S(x, y) = \sum_{i=1}^{n_x} \sum_{j=1}^{n_y} B_{i,k}(x) B_{j,s}(y) + C_{i,j}, \quad x \in [x_1, x_m], y \in [y_1, y_l] where: - - :math:`C_{i,j} \in \mathbb{R}^2` are the control points. These fixed - points influence the shape of the surface but are not generally - interpolated, except at the boundaries under certain knot multiplicities. + - :math:`C \in \mathbb{R}^{n_x \times n_y}` is the matrix of learnable + control coefficients. Its entries :math:`C_{i,j}` influence the shape of + the surface but are not generally interpolated, except under certain knot + multiplicities. - :math:`B_{i,k}(x)` and :math:`B_{j,s}(y)` are the B-spline basis functions defined over two orthogonal directions, with orders :math:`k` and :math:`s`, respectively. @@ -268,8 +269,8 @@ def control_points(self, control_points): # Check control points if control_points.shape != __valid_shape: raise ValueError( - "control_points must be of the correct shape. ", - f"Expected {__valid_shape}, got {control_points.shape}.", + f"control_points must be of the correct shape. " + f"Expected {__valid_shape}, got {control_points.shape}." ) # Register control points as a learnable parameter