diff --git a/cpp/dolfinx/fem/Form.h b/cpp/dolfinx/fem/Form.h index 6bbb426e5c..6c71141b8a 100644 --- a/cpp/dolfinx/fem/Form.h +++ b/cpp/dolfinx/fem/Form.h @@ -55,6 +55,8 @@ struct integral_data /// @param[in] entities Indices of entities to integrate over. /// @param[in] coeffs Indices of the coefficients that are present /// (active) in `kernel`. + /// @param[in] custom_data Optional custom user data pointer passed to + /// the kernel function. template requires std::is_convertible_v< std::remove_cvref_t, @@ -64,9 +66,10 @@ struct integral_data std::vector> and std::is_convertible_v, std::vector> - integral_data(K&& kernel, V&& entities, W&& coeffs) + integral_data(K&& kernel, V&& entities, W&& coeffs, + std::optional custom_data = std::nullopt) : kernel(std::forward(kernel)), entities(std::forward(entities)), - coeffs(std::forward(coeffs)) + coeffs(std::forward(coeffs)), custom_data(custom_data) { } @@ -82,6 +85,11 @@ struct integral_data /// @brief Indices of coefficients (from the form) that are in this /// integral. std::vector coeffs; + + /// @brief Custom user data pointer passed to the kernel function. + /// This can be used to pass runtime-computed data (e.g., per-cell + /// quadrature rules, material properties) to the kernel. + std::optional custom_data = std::nullopt; }; /// @brief A representation of finite element variational forms. @@ -391,6 +399,41 @@ class Form return it->second.kernel; } + /// @brief Get the custom data pointer for an integral. + /// + /// The custom data pointer is passed to the kernel function during + /// assembly. This can be used to pass runtime-computed data to + /// kernels (e.g., per-cell quadrature rules, material properties). + /// + /// @param[in] type Integral type. + /// @param[in] id Integral subdomain ID. + /// @param[in] kernel_idx Index of the kernel (we may have multiple + /// kernels for a given ID in mixed-topology meshes). + /// @return Custom data pointer for the integral, or std::nullopt if not set. 
+ std::optional custom_data(IntegralType type, int id, + int kernel_idx) const + { + auto it = _integrals.find({type, id, kernel_idx}); + if (it == _integrals.end()) + throw std::runtime_error("Requested integral not found."); + return it->second.custom_data; + } + + /// @brief Set the custom data pointer for an integral. + /// + /// @param[in] type Integral type. + /// @param[in] id Integral subdomain ID. + /// @param[in] kernel_idx Index of the kernel. + /// @param[in] data Custom data pointer to set, or std::nullopt to clear. + void set_custom_data(IntegralType type, int id, int kernel_idx, + std::optional data) + { + auto it = _integrals.find({type, id, kernel_idx}); + if (it == _integrals.end()) + throw std::runtime_error("Requested integral not found."); + it->second.custom_data = data; + } + /// @brief Get types of integrals in the form. /// @return Integrals types. std::set integral_types() const diff --git a/cpp/dolfinx/fem/assemble_matrix_impl.h b/cpp/dolfinx/fem/assemble_matrix_impl.h index 3b6148f9e9..5b68e157a8 100644 --- a/cpp/dolfinx/fem/assemble_matrix_impl.h +++ b/cpp/dolfinx/fem/assemble_matrix_impl.h @@ -18,6 +18,7 @@ #include #include #include +#include #include #include #include @@ -60,6 +61,7 @@ using mdspan2_t = md::mdspan>; /// function mesh. /// @param cell_info1 Cell permutation information for the trial /// function mesh. +/// @param custom_data Custom user data pointer passed to the kernel. 
template void assemble_cells_matrix( la::MatSet auto mat_set, mdspan2_t x_dofmap, @@ -74,7 +76,8 @@ void assemble_cells_matrix( std::span bc1, FEkernel auto kernel, md::mdspan> coeffs, std::span constants, std::span cell_info0, - std::span cell_info1) + std::span cell_info1, + std::optional custom_data = std::nullopt) { if (cells.empty()) return; @@ -109,7 +112,7 @@ void assemble_cells_matrix( // Tabulate tensor std::ranges::fill(Ae, 0); kernel(Ae.data(), &coeffs(c, 0), constants.data(), cdofs.data(), nullptr, - nullptr, nullptr); + nullptr, custom_data.value_or(nullptr)); // Compute A = P_0 \tilde{A} P_1^T (dof transformation) P0(Ae, cell_info0, cell0, ndim1); // B = P0 \tilde{A} @@ -198,6 +201,7 @@ void assemble_cells_matrix( /// function mesh. /// @param[in] perms Entity permutation integer. Empty if entity /// permutations are not required. +/// @param custom_data Custom user data pointer passed to the kernel. template void assemble_entities( la::MatSet auto mat_set, mdspan2_t x_dofmap, @@ -221,7 +225,8 @@ void assemble_entities( md::mdspan> coeffs, std::span constants, std::span cell_info0, std::span cell_info1, - md::mdspan> perms) + md::mdspan> perms, + std::optional custom_data = std::nullopt) { if (entities.empty()) return; @@ -259,7 +264,7 @@ void assemble_entities( // Tabulate tensor std::ranges::fill(Ae, 0); kernel(Ae.data(), &coeffs(f, 0), constants.data(), cdofs.data(), - &local_entity, &perm, nullptr); + &local_entity, &perm, custom_data.value_or(nullptr)); P0(Ae, cell_info0, cell0, ndim1); P1T(Ae, cell_info1, cell1, ndim0); @@ -363,7 +368,8 @@ void assemble_interior_facets( coeffs, std::span constants, std::span cell_info0, std::span cell_info1, - md::mdspan> perms) + md::mdspan> perms, + std::optional custom_data = std::nullopt) { if (facets.empty()) return; @@ -440,7 +446,7 @@ void assemble_interior_facets( : std::array{perms(cells[0], local_facet[0]), perms(cells[1], local_facet[1])}; kernel(Ae.data(), &coeffs(f, 0, 0), constants.data(), 
cdofs.data(), - local_facet.data(), perm.data(), nullptr); + local_facet.data(), perm.data(), custom_data.value_or(nullptr)); // Local element layout is a 2x2 block matrix with structure // @@ -605,12 +611,14 @@ void assemble_matrix( std::span cells0 = a.domain_arg(IntegralType::cell, 0, i, cell_type_idx); std::span cells1 = a.domain_arg(IntegralType::cell, 1, i, cell_type_idx); auto& [coeffs, cstride] = coefficients.at({IntegralType::cell, i}); + std::optional custom_data + = a.custom_data(IntegralType::cell, i, cell_type_idx); assert(cells.size() * cstride == coeffs.size()); impl::assemble_cells_matrix( mat_set, x_dofmap, x, cells, {dofs0, bs0, cells0}, P0, {dofs1, bs1, cells1}, P1T, bc0, bc1, fn, md::mdspan(coeffs.data(), cells.size(), cstride), constants, - cell_info0, cell_info1); + cell_info0, cell_info1, custom_data); } md::mdspan> facet_perms; @@ -646,6 +654,8 @@ void assemble_matrix( assert(fn); auto& [coeffs, cstride] = coefficients.at({IntegralType::interior_facet, i}); + std::optional custom_data + = a.custom_data(IntegralType::interior_facet, i, 0); std::span facets = a.domain(IntegralType::interior_facet, i, 0); std::span facets0 = a.domain_arg(IntegralType::interior_facet, 0, i, 0); @@ -661,7 +671,7 @@ void assemble_matrix( mdspanx22_t(facets1.data(), facets1.size() / 4, 2, 2)}, P1T, bc0, bc1, fn, mdspanx2x_t(coeffs.data(), facets.size() / 4, 2, cstride), constants, - cell_info0, cell_info1, facet_perms); + cell_info0, cell_info1, facet_perms, custom_data); } for (auto itg_type : {fem::IntegralType::exterior_facet, @@ -688,6 +698,7 @@ void assemble_matrix( auto fn = a.kernel(itg_type, i, 0); assert(fn); auto& [coeffs, cstride] = coefficients.at({itg_type, i}); + std::optional custom_data = a.custom_data(itg_type, i, 0); std::span e = a.domain(itg_type, i, 0); mdspanx2_t entities(e.data(), e.size() / 2, 2); @@ -700,7 +711,7 @@ void assemble_matrix( mat_set, x_dofmap, x, entities, {dofs0, bs0, entities0}, P0, {dofs1, bs1, entities1}, P1T, bc0, bc1, fn, 
md::mdspan(coeffs.data(), entities.extent(0), cstride), constants, - cell_info0, cell_info1, perms); + cell_info0, cell_info1, perms, custom_data); } } } diff --git a/cpp/dolfinx/fem/assemble_scalar_impl.h b/cpp/dolfinx/fem/assemble_scalar_impl.h index 37166b46ea..016c3a9bd4 100644 --- a/cpp/dolfinx/fem/assemble_scalar_impl.h +++ b/cpp/dolfinx/fem/assemble_scalar_impl.h @@ -17,6 +17,7 @@ #include #include #include +#include #include namespace dolfinx::fem::impl @@ -30,7 +31,8 @@ T assemble_cells(mdspan2_t x_dofmap, std::span cells, FEkernel auto fn, std::span constants, md::mdspan> coeffs, - std::span> cdofs_b) + std::span> cdofs_b, + std::optional custom_data = std::nullopt) { T value(0); if (cells.empty()) @@ -49,7 +51,7 @@ T assemble_cells(mdspan2_t x_dofmap, std::copy_n(&x(x_dofs[i], 0), 3, std::next(cdofs_b.begin(), 3 * i)); fn(&value, &coeffs(index, 0), constants.data(), cdofs_b.data(), nullptr, - nullptr, nullptr); + nullptr, custom_data.value_or(nullptr)); } return value; @@ -77,7 +79,8 @@ T assemble_entities( FEkernel auto fn, std::span constants, md::mdspan> coeffs, md::mdspan> perms, - std::span> cdofs_b) + std::span> cdofs_b, + std::optional custom_data = std::nullopt) { T value(0); if (entities.empty()) @@ -99,7 +102,7 @@ T assemble_entities( // Permutations std::uint8_t perm = perms.empty() ? 
0 : perms(cell, local_entity); fn(&value, &coeffs(f, 0), constants.data(), cdofs_b.data(), &local_entity, - &perm, nullptr); + &perm, custom_data.value_or(nullptr)); } return value; @@ -120,7 +123,8 @@ T assemble_interior_facets( md::dynamic_extent>> coeffs, md::mdspan> perms, - std::span> cdofs_b) + std::span> cdofs_b, + std::optional custom_data = std::nullopt) { T value(0); if (facets.empty()) @@ -150,7 +154,7 @@ T assemble_interior_facets( : std::array{perms(cells[0], local_facet[0]), perms(cells[1], local_facet[1])}; fn(&value, &coeffs(f, 0, 0), constants.data(), cdofs_b.data(), - local_facet.data(), perm.data(), nullptr); + local_facet.data(), perm.data(), custom_data.value_or(nullptr)); } return value; @@ -178,11 +182,12 @@ T assemble_scalar( auto fn = M.kernel(IntegralType::cell, i, 0); assert(fn); auto& [coeffs, cstride] = coefficients.at({IntegralType::cell, i}); + std::optional custom_data = M.custom_data(IntegralType::cell, i, 0); std::span cells = M.domain(IntegralType::cell, i, 0); assert(cells.size() * cstride == coeffs.size()); value += impl::assemble_cells( x_dofmap, x, cells, fn, constants, - md::mdspan(coeffs.data(), cells.size(), cstride), cdofs_b); + md::mdspan(coeffs.data(), cells.size(), cstride), cdofs_b, custom_data); } mesh::CellType cell_type = mesh->topology()->cell_type(); @@ -204,6 +209,8 @@ T assemble_scalar( assert(fn); auto& [coeffs, cstride] = coefficients.at({IntegralType::interior_facet, i}); + std::optional custom_data + = M.custom_data(IntegralType::interior_facet, i, 0); std::span facets = M.domain(IntegralType::interior_facet, i, 0); constexpr std::size_t num_adjacent_cells = 2; @@ -220,7 +227,7 @@ T assemble_scalar( md::mdspan>( coeffs.data(), facets.size() / shape1, 2, cstride), - facet_perms, cdofs_b); + facet_perms, cdofs_b, custom_data); } for (auto itg_type : {fem::IntegralType::exterior_facet, @@ -236,6 +243,7 @@ T assemble_scalar( auto fn = M.kernel(itg_type, i, 0); assert(fn); auto& [coeffs, cstride] = 
coefficients.at({itg_type, i}); + std::optional custom_data = M.custom_data(itg_type, i, 0); std::span entities = M.domain(itg_type, i, 0); @@ -248,7 +256,7 @@ T assemble_scalar( entities.data(), entities.size() / 2, 2), fn, constants, md::mdspan(coeffs.data(), entities.size() / 2, cstride), perms, - cdofs_b); + cdofs_b, custom_data); } } diff --git a/cpp/dolfinx/fem/assemble_vector_impl.h b/cpp/dolfinx/fem/assemble_vector_impl.h index f57e646dd9..fdaa37471b 100644 --- a/cpp/dolfinx/fem/assemble_vector_impl.h +++ b/cpp/dolfinx/fem/assemble_vector_impl.h @@ -92,7 +92,8 @@ void _lift_bc_cells( md::mdspan> coeffs, std::span cell_info0, std::span cell_info1, std::span bc_values1, - std::span bc_markers1, std::span x0, T alpha) + std::span bc_markers1, std::span x0, T alpha, + std::optional custom_data = std::nullopt) { if (cells.empty()) return; @@ -164,7 +165,7 @@ void _lift_bc_cells( std::ranges::fill(Ae, 0); kernel(Ae.data(), &coeffs(index, 0), constants.data(), cdofs.data(), - nullptr, nullptr, nullptr); + nullptr, nullptr, custom_data.value_or(nullptr)); P0(Ae, cell_info0, c0, num_cols); P1T(Ae, cell_info1, c1, num_rows); @@ -286,7 +287,8 @@ void _lift_bc_entities( std::span cell_info0, std::span cell_info1, std::span bc_values1, std::span bc_markers1, std::span x0, T alpha, - md::mdspan> perms) + md::mdspan> perms, + std::optional custom_data = std::nullopt) { if (entities.empty()) return; @@ -345,7 +347,7 @@ void _lift_bc_entities( std::uint8_t perm = perms.empty() ? 
0 : perms(cell, local_entity); std::ranges::fill(Ae, 0); kernel(Ae.data(), &coeffs(index, 0), constants.data(), cdofs.data(), - &local_entity, &perm, nullptr); + &local_entity, &perm, custom_data.value_or(nullptr)); P0(Ae, cell_info0, cell0, num_cols); P1T(Ae, cell_info1, cell1, num_rows); @@ -443,7 +445,8 @@ void _lift_bc_interior_facets( std::span cell_info0, std::span cell_info1, std::span bc_values1, std::span bc_markers1, std::span x0, T alpha, - md::mdspan> perms) + md::mdspan> perms, + std::optional custom_data = std::nullopt) { if (facets.empty()) return; @@ -559,7 +562,7 @@ void _lift_bc_interior_facets( : std::array{perms(cells[0], local_facet[0]), perms(cells[1], local_facet[1])}; kernel(Ae.data(), &coeffs(f, 0, 0), constants.data(), cdofs.data(), - local_facet.data(), perm.data(), nullptr); + local_facet.data(), perm.data(), custom_data.value_or(nullptr)); if (cells0[0] >= 0) P0(Ae, cell_info0, cells0[0], num_cols); @@ -671,7 +674,8 @@ void assemble_cells( std::tuple> dofmap, FEkernel auto kernel, std::span constants, md::mdspan> coeffs, - std::span cell_info0) + std::span cell_info0, + std::optional custom_data = std::nullopt) { if (cells.empty()) return; @@ -698,7 +702,7 @@ void assemble_cells( // Tabulate vector for cell std::ranges::fill(be, 0); kernel(be.data(), &coeffs(index, 0), constants.data(), cdofs.data(), - nullptr, nullptr, nullptr); + nullptr, nullptr, custom_data.value_or(nullptr)); P0(be, cell_info0, c0, 1); // Scatter cell vector to 'global' vector array @@ -769,7 +773,8 @@ void assemble_entities( FEkernel auto kernel, std::span constants, md::mdspan> coeffs, std::span cell_info0, - md::mdspan> perms) + md::mdspan> perms, + std::optional custom_data = std::nullopt) { if (entities.empty()) return; @@ -801,7 +806,7 @@ void assemble_entities( // Tabulate element vector std::ranges::fill(be, 0); kernel(be.data(), &coeffs(f, 0), constants.data(), cdofs.data(), - &local_entity, &perm, nullptr); + &local_entity, &perm, 
custom_data.value_or(nullptr)); P0(be, cell_info0, cell0, 1); // Add element vector to global vector @@ -866,7 +871,8 @@ void assemble_interior_facets( md::dynamic_extent>> coeffs, std::span cell_info0, - md::mdspan> perms) + md::mdspan> perms, + std::optional custom_data = std::nullopt) { using X = scalar_value_t; @@ -918,7 +924,7 @@ void assemble_interior_facets( : std::array{perms(cells[0], local_facet[0]), perms(cells[1], local_facet[1])}; kernel(be.data(), &coeffs(f, 0, 0), constants.data(), cdofs.data(), - local_facet.data(), perm.data(), nullptr); + local_facet.data(), perm.data(), custom_data.value_or(nullptr)); if (cells0[0] >= 0) P0(be, cell_info0, cells0[0], 1); @@ -1026,6 +1032,7 @@ void lift_bc(V&& b, const Form& a, mdspan2_t x_dofmap, auto kernel = a.kernel(IntegralType::cell, i, 0); assert(kernel); auto& [_coeffs, cstride] = coefficients.at({IntegralType::cell, i}); + std::optional custom_data = a.custom_data(IntegralType::cell, i, 0); std::span cells = a.domain(IntegralType::cell, i, 0); std::span cells0 = a.domain_arg(IntegralType::cell, 0, i, 0); std::span cells1 = a.domain_arg(IntegralType::cell, 1, i, 0); @@ -1036,20 +1043,21 @@ void lift_bc(V&& b, const Form& a, mdspan2_t x_dofmap, _lift_bc_cells<1, 1>(b, x_dofmap, x, kernel, cells, {dofmap0, bs0, cells0}, P0, {dofmap1, bs1, cells1}, P1T, constants, coeffs, cell_info0, cell_info1, - bc_values1, bc_markers1, x0, alpha); + bc_values1, bc_markers1, x0, alpha, custom_data); } else if (bs0 == 3 and bs1 == 3) { _lift_bc_cells<3, 3>(b, x_dofmap, x, kernel, cells, {dofmap0, bs0, cells0}, P0, {dofmap1, bs1, cells1}, P1T, constants, coeffs, cell_info0, cell_info1, - bc_values1, bc_markers1, x0, alpha); + bc_values1, bc_markers1, x0, alpha, custom_data); } else { _lift_bc_cells(b, x_dofmap, x, kernel, cells, {dofmap0, bs0, cells0}, P0, {dofmap1, bs1, cells1}, P1T, constants, coeffs, cell_info0, - cell_info1, bc_values1, bc_markers1, x0, alpha); + cell_info1, bc_values1, bc_markers1, x0, alpha, + 
custom_data); } } @@ -1072,6 +1080,8 @@ void lift_bc(V&& b, const Form& a, mdspan2_t x_dofmap, assert(kernel); auto& [coeffs, cstride] = coefficients.at({IntegralType::interior_facet, i}); + std::optional custom_data + = a.custom_data(IntegralType::interior_facet, i, 0); using mdspanx22_t = md::mdspan& a, mdspan2_t x_dofmap, b, x_dofmap, x, kernel, facets, {dofmap0, bs0, facets0}, P0, {dofmap1, bs1, facets1}, P1T, constants, mdspanx2x_t(coeffs.data(), facets.extent(0), 2, cstride), cell_info0, - cell_info1, bc_values1, bc_markers1, x0, alpha, facet_perms); + cell_info1, bc_values1, bc_markers1, x0, alpha, facet_perms, + custom_data); } for (auto itg_type : {fem::IntegralType::exterior_facet, @@ -1105,6 +1116,7 @@ void lift_bc(V&& b, const Form& a, mdspan2_t x_dofmap, auto kernel = a.kernel(itg_type, i, 0); assert(kernel); auto& [coeffs, cstride] = coefficients.at({itg_type, i}); + std::optional custom_data = a.custom_data(itg_type, i, 0); using mdspanx2_t = md::mdspan& a, mdspan2_t x_dofmap, b, x_dofmap, x, kernel, entities, {dofmap0, bs0, entities0}, P0, {dofmap1, bs1, entities1}, P1T, constants, md::mdspan(coeffs.data(), entities.extent(0), cstride), cell_info0, - cell_info1, bc_values1, bc_markers1, x0, alpha, perms); + cell_info1, bc_values1, bc_markers1, x0, alpha, perms, custom_data); } } } @@ -1276,24 +1288,29 @@ void assemble_vector( std::span cells = L.domain(IntegralType::cell, i, cell_type_idx); std::span cells0 = L.domain_arg(IntegralType::cell, 0, i, cell_type_idx); auto& [coeffs, cstride] = coefficients.at({IntegralType::cell, i}); + void* custom_data = L.custom_data(IntegralType::cell, i, cell_type_idx) + .value_or(nullptr); assert(cells.size() * cstride == coeffs.size()); if (bs == 1) { impl::assemble_cells<1>( P0, b, x_dofmap, x, cells, {dofs, bs, cells0}, fn, constants, - md::mdspan(coeffs.data(), cells.size(), cstride), cell_info0); + md::mdspan(coeffs.data(), cells.size(), cstride), cell_info0, + custom_data); } else if (bs == 3) { 
impl::assemble_cells<3>( P0, b, x_dofmap, x, cells, {dofs, bs, cells0}, fn, constants, - md::mdspan(coeffs.data(), cells.size(), cstride), cell_info0); + md::mdspan(coeffs.data(), cells.size(), cstride), cell_info0, + custom_data); } else { - impl::assemble_cells( - P0, b, x_dofmap, x, cells, {dofs, bs, cells0}, fn, constants, - md::mdspan(coeffs.data(), cells.size(), cstride), cell_info0); + impl::assemble_cells(P0, b, x_dofmap, x, cells, {dofs, bs, cells0}, fn, + constants, + md::mdspan(coeffs.data(), cells.size(), cstride), + cell_info0, custom_data); } } @@ -1327,6 +1344,8 @@ void assemble_vector( assert(fn); auto& [coeffs, cstride] = coefficients.at({IntegralType::interior_facet, i}); + void* custom_data + = L.custom_data(IntegralType::interior_facet, i, 0).value_or(nullptr); std::span facets = L.domain(IntegralType::interior_facet, i, 0); std::span facets1 = L.domain_arg(IntegralType::interior_facet, 0, i, 0); assert((facets.size() / 4) * 2 * cstride == coeffs.size()); @@ -1339,7 +1358,7 @@ void assemble_vector( mdspanx22_t(facets1.data(), facets1.size() / 4, 2, 2)}, fn, constants, mdspanx2x_t(coeffs.data(), facets.size() / 4, 2, cstride), - cell_info0, facet_perms); + cell_info0, facet_perms, custom_data); } else if (bs == 3) { @@ -1350,7 +1369,7 @@ void assemble_vector( mdspanx22_t(facets1.data(), facets1.size() / 4, 2, 2)}, fn, constants, mdspanx2x_t(coeffs.data(), facets.size() / 4, 2, cstride), - cell_info0, facet_perms); + cell_info0, facet_perms, custom_data); } else { @@ -1361,7 +1380,7 @@ void assemble_vector( mdspanx22_t(facets1.data(), facets1.size() / 4, 2, 2)}, fn, constants, mdspanx2x_t(coeffs.data(), facets.size() / 4, 2, cstride), - cell_info0, facet_perms); + cell_info0, facet_perms, custom_data); } } @@ -1378,6 +1397,7 @@ void assemble_vector( auto fn = L.kernel(itg_type, i, 0); assert(fn); auto& [coeffs, cstride] = coefficients.at({itg_type, i}); + void* custom_data = L.custom_data(itg_type, i, 0).value_or(nullptr); std::span e = 
L.domain(itg_type, i, 0); mdspanx2_t entities(e.data(), e.size() / 2, 2); std::span e1 = L.domain_arg(itg_type, 0, i, 0); @@ -1388,7 +1408,7 @@ void assemble_vector( impl::assemble_entities<1>( P0, b, x_dofmap, x, entities, {dofs, bs, entities1}, fn, constants, md::mdspan(coeffs.data(), entities.extent(0), cstride), - cell_info0, perms); + cell_info0, perms, custom_data); } else if (bs == 3) { @@ -1396,7 +1416,7 @@ void assemble_vector( P0, b, x_dofmap, x, entities, {dofs, bs, entities1}, fn, constants, md::mdspan(coeffs.data(), entities.size() / 2, cstride), - cell_info0, perms); + cell_info0, perms, custom_data); } else { @@ -1404,7 +1424,7 @@ void assemble_vector( P0, b, x_dofmap, x, entities, {dofs, bs, entities1}, fn, constants, md::mdspan(coeffs.data(), entities.size() / 2, cstride), - cell_info0, perms); + cell_info0, perms, custom_data); } } } diff --git a/python/dolfinx/wrappers/fem.cpp b/python/dolfinx/wrappers/fem.cpp index 9e4b365722..738fab4540 100644 --- a/python/dolfinx/wrappers/fem.cpp +++ b/python/dolfinx/wrappers/fem.cpp @@ -761,6 +761,31 @@ void declare_form(nb::module_& m, std::string type) .def_prop_ro("integral_types", &dolfinx::fem::Form::integral_types) .def_prop_ro("needs_facet_permutations", &dolfinx::fem::Form::needs_facet_permutations) + .def( + "set_custom_data", + [](dolfinx::fem::Form& self, dolfinx::fem::IntegralType type, + int id, int kernel_idx, std::optional data) + { + self.set_custom_data(type, id, kernel_idx, + data ? std::optional((void*)*data) + : std::nullopt); + }, + nb::arg("type"), nb::arg("id"), nb::arg("kernel_idx"), + nb::arg("data").none(), + "Set custom data pointer for an integral. The data pointer is " + "passed to the kernel. Pass None to clear.") + .def( + "custom_data", + [](const dolfinx::fem::Form& self, + dolfinx::fem::IntegralType type, int id, + int kernel_idx) -> std::optional + { + auto cd = self.custom_data(type, id, kernel_idx); + return cd ? 
std::optional((std::uintptr_t)*cd) + : std::nullopt; + }, + nb::arg("type"), nb::arg("id"), nb::arg("kernel_idx"), + "Get custom data pointer for an integral, or None if not set.") .def( "domains", [](const dolfinx::fem::Form& self, diff --git a/python/test/unit/fem/test_custom_data.py b/python/test/unit/fem/test_custom_data.py new file mode 100644 index 0000000000..013b545846 --- /dev/null +++ b/python/test/unit/fem/test_custom_data.py @@ -0,0 +1,287 @@ +"""Unit tests for custom_data functionality in assembly.""" + +# Copyright (C) 2025 Susanne Claus +# +# This file is part of DOLFINx (https://www.fenicsproject.org) +# +# SPDX-License-Identifier: LGPL-3.0-or-later + +from mpi4py import MPI + +import numpy as np +import pytest + +import dolfinx +import ffcx.codegeneration.utils +from dolfinx import la +from dolfinx.fem import Form, IntegralType, form_cpp_class, functionspace +from dolfinx.mesh import create_unit_square + +numba = pytest.importorskip("numba") +ufcx_signature = ffcx.codegeneration.utils.numba_ufcx_kernel_signature + + +# Helper intrinsic to cast void* to a typed pointer for custom_data +@numba.extending.intrinsic +def voidptr_to_float64_ptr(typingctx, src): + """Cast a void pointer (CPointer(void)) to a float64 pointer. + + This function is used to access custom_data passed through the UFCx + tabulate_tensor interface. Since custom_data is passed as void*, this + intrinsic allows casting it to a typed float64 pointer for element access. + + Args: + typingctx: The typing context. + src: A void pointer (CPointer(void)) to cast. + + Returns: + sig: A Numba signature returning CPointer(float64). + codegen: A code generation function that performs the bitcast. 
+ + Example: + Inside a Numba cfunc kernel:: + + typed_ptr = voidptr_to_float64_ptr(custom_data) + scale = typed_ptr[0] # Access first float64 value + """ + # Accept CPointer(void) which shows as 'none*' in numba type system + if isinstance(src, numba.types.CPointer) and src.dtype == numba.types.void: + sig = numba.types.CPointer(numba.types.float64)(src) + + def codegen(context, builder, signature, args): + [src] = args + # Cast void* to float64* + dst_type = context.get_value_type(numba.types.CPointer(numba.types.float64)) + return builder.bitcast(src, dst_type) + + return sig, codegen + + +def tabulate_rank1_with_custom_data(dtype, xdtype): + """Kernel that reads a scaling factor from custom_data. + + Note: custom_data must be set to a valid pointer before assembly. + """ + + @numba.cfunc(ufcx_signature(dtype, xdtype), nopython=True) + def tabulate(b_, w_, c_, coords_, local_index, orientation, custom_data): + b = numba.carray(b_, (3), dtype=dtype) + coordinate_dofs = numba.carray(coords_, (3, 3), dtype=xdtype) + + # Cast void* to float64* and read the scale value + typed_ptr = voidptr_to_float64_ptr(custom_data) + scale = typed_ptr[0] + + x0, y0 = coordinate_dofs[0, :2] + x1, y1 = coordinate_dofs[1, :2] + x2, y2 = coordinate_dofs[2, :2] + + # 2x Element area Ae + Ae = abs((x0 - x1) * (y2 - y1) - (y0 - y1) * (x2 - x1)) + b[:] = scale * Ae / 6.0 + + return tabulate + + +def tabulate_rank2_with_custom_data(dtype, xdtype): + """Kernel that reads a scaling factor from custom_data for matrix assembly. + + Note: custom_data must be set to a valid pointer before assembly. 
+ """ + + @numba.cfunc(ufcx_signature(dtype, xdtype), nopython=True) + def tabulate(A_, w_, c_, coords_, entity_local_index, cell_orientation, custom_data): + A = numba.carray(A_, (3, 3), dtype=dtype) + coordinate_dofs = numba.carray(coords_, (3, 3), dtype=xdtype) + + # Cast void* to float64* and read the scale value + typed_ptr = voidptr_to_float64_ptr(custom_data) + scale = typed_ptr[0] + + x0, y0 = coordinate_dofs[0, :2] + x1, y1 = coordinate_dofs[1, :2] + x2, y2 = coordinate_dofs[2, :2] + + # 2x Element area Ae + Ae = abs((x0 - x1) * (y2 - y1) - (y0 - y1) * (x2 - x1)) + B = np.array([y1 - y2, y2 - y0, y0 - y1, x2 - x1, x0 - x2, x1 - x0], dtype=dtype).reshape( + 2, 3 + ) + A[:, :] = scale * np.dot(B.T, B) / (2 * Ae) + + return tabulate + + +@pytest.mark.parametrize("dtype", [np.float64]) +def test_custom_data_vector_assembly(dtype): + """Test that custom_data is correctly passed to kernels during vector assembly.""" + xdtype = np.real(dtype(0)).dtype + k1 = tabulate_rank1_with_custom_data(dtype, xdtype) + + mesh = create_unit_square(MPI.COMM_WORLD, 13, 13, dtype=xdtype) + V = functionspace(mesh, ("Lagrange", 1)) + + tdim = mesh.topology.dim + num_cells = mesh.topology.index_map(tdim).size_local + mesh.topology.index_map(tdim).num_ghosts + cells = np.arange(num_cells, dtype=np.int32) + active_coeffs = np.array([], dtype=np.int8) + + integrals = {IntegralType.cell: [(0, k1.address, cells, active_coeffs)]} + formtype = form_cpp_class(dtype) + L = Form(formtype([V._cpp_object], integrals, [], [], False, [], mesh=mesh._cpp_object)) + + # Create custom_data with scale=1.0 first + scale_value = np.array([1.0], dtype=dtype) + scale_ptr = scale_value.ctypes.data + L._cpp_object.set_custom_data(IntegralType.cell, 0, 0, scale_ptr) + + # Assemble with scale=1.0 + b1 = dolfinx.fem.assemble_vector(L) + b1.scatter_reverse(la.InsertMode.add) + norm1 = la.norm(b1) + + # Verify we can read back the custom_data pointer + assert L._cpp_object.custom_data(IntegralType.cell, 0, 0) == 
scale_ptr + + # Update custom_data to scale=2.0 + scale_value[0] = 2.0 + b2 = dolfinx.fem.assemble_vector(L) + b2.scatter_reverse(la.InsertMode.add) + norm2 = la.norm(b2) + + # The norm with scale=2 should be 2x the norm with scale=1 + assert np.isclose(norm2, 2.0 * norm1) + + # Test with scale=3.0 + scale_value[0] = 3.0 + b3 = dolfinx.fem.assemble_vector(L) + b3.scatter_reverse(la.InsertMode.add) + norm3 = la.norm(b3) + + assert np.isclose(norm3, 3.0 * norm1) + + +@pytest.mark.parametrize("dtype", [np.float64]) +def test_custom_data_matrix_assembly(dtype): + """Test that custom_data is correctly passed to kernels during matrix assembly.""" + xdtype = np.real(dtype(0)).dtype + k2 = tabulate_rank2_with_custom_data(dtype, xdtype) + + mesh = create_unit_square(MPI.COMM_WORLD, 13, 13, dtype=xdtype) + V = functionspace(mesh, ("Lagrange", 1)) + + cells = np.arange(mesh.topology.index_map(mesh.topology.dim).size_local, dtype=np.int32) + active_coeffs = np.array([], dtype=np.int8) + + integrals = {IntegralType.cell: [(0, k2.address, cells, active_coeffs)]} + formtype = form_cpp_class(dtype) + a = Form( + formtype( + [V._cpp_object, V._cpp_object], + integrals, + [], + [], + False, + [], + mesh=mesh._cpp_object, + ) + ) + + # Set custom_data with scale=1.0 first + scale_value = np.array([1.0], dtype=dtype) + a._cpp_object.set_custom_data(IntegralType.cell, 0, 0, scale_value.ctypes.data) + + # Assemble with scale=1.0 + A1 = dolfinx.fem.assemble_matrix(a) + A1.scatter_reverse() + norm1 = np.sqrt(A1.squared_norm()) + + # Update custom_data to scale=2.0 + scale_value[0] = 2.0 + A2 = dolfinx.fem.assemble_matrix(a) + A2.scatter_reverse() + norm2 = np.sqrt(A2.squared_norm()) + + # The norm with scale=2 should be 2x the norm with scale=1 + assert np.isclose(norm2, 2.0 * norm1) + + +@pytest.mark.parametrize("dtype", [np.float64]) +def test_custom_data_default_nullptr(dtype): + """Test that custom_data is unset (None, i.e. std::nullopt) by default.""" + xdtype = np.real(dtype(0)).dtype + + # Define a 
simple kernel that doesn't use custom_data + @numba.cfunc(ufcx_signature(dtype, xdtype), nopython=True) + def tabulate_simple(b_, w_, c_, coords_, local_index, orientation, custom_data): + b = numba.carray(b_, (3), dtype=dtype) + coordinate_dofs = numba.carray(coords_, (3, 3), dtype=xdtype) + + x0, y0 = coordinate_dofs[0, :2] + x1, y1 = coordinate_dofs[1, :2] + x2, y2 = coordinate_dofs[2, :2] + + Ae = abs((x0 - x1) * (y2 - y1) - (y0 - y1) * (x2 - x1)) + b[:] = Ae / 6.0 + + mesh = create_unit_square(MPI.COMM_WORLD, 5, 5, dtype=xdtype) + V = functionspace(mesh, ("Lagrange", 1)) + + tdim = mesh.topology.dim + num_cells = mesh.topology.index_map(tdim).size_local + mesh.topology.index_map(tdim).num_ghosts + cells = np.arange(num_cells, dtype=np.int32) + active_coeffs = np.array([], dtype=np.int8) + + integrals = {IntegralType.cell: [(0, tabulate_simple.address, cells, active_coeffs)]} + formtype = form_cpp_class(dtype) + L = Form(formtype([V._cpp_object], integrals, [], [], False, [], mesh=mesh._cpp_object)) + + # custom_data should be None (std::nullopt) by default + assert L._cpp_object.custom_data(IntegralType.cell, 0, 0) is None + + +@pytest.mark.parametrize("dtype", [np.float64]) +def test_custom_data_struct(dtype): + """Test passing a struct with multiple values through custom_data.""" + xdtype = np.real(dtype(0)).dtype + + # Define a kernel that reads two values from custom_data + @numba.cfunc(ufcx_signature(dtype, xdtype), nopython=True) + def tabulate_with_struct(b_, w_, c_, coords_, local_index, orientation, custom_data): + b = numba.carray(b_, (3), dtype=dtype) + coordinate_dofs = numba.carray(coords_, (3, 3), dtype=xdtype) + + # Cast void* to float64* and read two values: [scale, offset] + typed_ptr = voidptr_to_float64_ptr(custom_data) + scale = typed_ptr[0] + offset = typed_ptr[1] + + x0, y0 = coordinate_dofs[0, :2] + x1, y1 = coordinate_dofs[1, :2] + x2, y2 = coordinate_dofs[2, :2] + + Ae = abs((x0 - x1) * (y2 - y1) - (y0 - y1) * (x2 - x1)) + b[:] = scale 
* Ae / 6.0 + offset + + mesh = create_unit_square(MPI.COMM_WORLD, 5, 5, dtype=xdtype) + V = functionspace(mesh, ("Lagrange", 1)) + + tdim = mesh.topology.dim + num_cells = mesh.topology.index_map(tdim).size_local + mesh.topology.index_map(tdim).num_ghosts + cells = np.arange(num_cells, dtype=np.int32) + active_coeffs = np.array([], dtype=np.int8) + + integrals = {IntegralType.cell: [(0, tabulate_with_struct.address, cells, active_coeffs)]} + formtype = form_cpp_class(dtype) + L = Form(formtype([V._cpp_object], integrals, [], [], False, [], mesh=mesh._cpp_object)) + + # Create struct data: [scale=2.0, offset=0.5] + struct_data = np.array([2.0, 0.5], dtype=dtype) + L._cpp_object.set_custom_data(IntegralType.cell, 0, 0, struct_data.ctypes.data) + + b = dolfinx.fem.assemble_vector(L) + b.scatter_reverse(la.InsertMode.add) + + # Smoke check only: the offset should contribute to each DOF, so the + # assembled vector must be non-zero (exact values are not asserted here) + assert la.norm(b) > 0