Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/fenicsx-refs.env
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
basix_ref=main
ufl_ref=main
ffcx_ref=main
ufl_ref=dokken/multiple_coordinate_elements
ffcx_ref=mscroggs/ufl_coordinate_elements

59 changes: 22 additions & 37 deletions python/demo/demo_mixed-topology.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,17 @@
import basix
import dolfinx.cpp as _cpp
import ufl
from dolfinx.cpp.mesh import GhostMode, create_cell_partitioner, create_mesh
from dolfinx.cpp.mesh import GhostMode, create_cell_partitioner
from dolfinx.mesh import create_mesh
from dolfinx.fem import (
FunctionSpace,
assemble_matrix,
assemble_vector,
coordinate_element,
mixed_topology_form,
)
from dolfinx.io.utils import cell_perm_vtk
from dolfinx.mesh import CellType, Mesh
from dolfinx import fem

if MPI.COMM_WORLD.size > 1:
print("Not yet running in parallel")
Expand Down Expand Up @@ -89,48 +90,32 @@

cells_np = [np.array(c) for c in cells]
geomx = np.array(geom, dtype=np.float64)
hexahedron = coordinate_element(CellType.hexahedron, 1)
prism = coordinate_element(CellType.prism, 1)

part = create_cell_partitioner(GhostMode.none)
mesh = create_mesh(
MPI.COMM_WORLD, cells_np, [hexahedron._cpp_object, prism._cpp_object], geomx, part
MPI.COMM_WORLD,
cells_np,
geomx,
[
basix.ufl.element("Lagrange", "hexahedron", 1, shape=(3,)),
basix.ufl.element("Lagrange", "prism", 1, shape=(3,)),
],
part,
)

# Create elements and dofmaps for each cell type
elements = [
basix.create_element(basix.ElementFamily.P, basix.CellType.hexahedron, 1),
basix.create_element(basix.ElementFamily.P, basix.CellType.prism, 1),
]
elements_cpp = [_cpp.fem.FiniteElement_float64(e._e, None, True) for e in elements]
# NOTE: Both dofmaps have the same IndexMap, but different cell_dofs
dofmaps = _cpp.fem.create_dofmaps(mesh.comm, mesh.topology, elements_cpp)

# Create C++ function space
V_cpp = _cpp.fem.FunctionSpace_float64(mesh, elements_cpp, dofmaps)

# Create forms for each cell type.
# FIXME This hack is required at the moment because UFL does not yet know
# about mixed topology meshes.
a = []
L = []
for i, cell_name in enumerate(["hexahedron", "prism"]):
print(f"Creating form for {cell_name}")
element = basix.ufl.wrap_element(elements[i])
domain = ufl.Mesh(basix.ufl.element("Lagrange", cell_name, 1, shape=(3,)))
V = FunctionSpace(Mesh(mesh, domain), element, V_cpp)
u, v = ufl.TrialFunction(V), ufl.TestFunction(V)
k = 12.0
x = ufl.SpatialCoordinate(domain)
a += [(ufl.inner(ufl.grad(u), ufl.grad(v)) - k**2 * u * v) * ufl.dx]
f = ufl.sin(ufl.pi * x[0]) * ufl.sin(ufl.pi * x[1])
L += [f * v * ufl.dx]
V = fem.functionspace(mesh, ("Lagrange", 1))


u, v = ufl.TrialFunction(V), ufl.TestFunction(V)
k = 12.0
x = ufl.SpatialCoordinate(domain)
a = (ufl.inner(ufl.grad(u), ufl.grad(v)) - k**2 * u * v) * ufl.dx
f = ufl.sin(ufl.pi * x[0]) * ufl.sin(ufl.pi * x[1])
L = f * v * ufl.dx

# Compile the form
# FIXME: For the time being, since UFL doesn't understand mixed topology
# meshes, we have to call mixed_topology_form instead of form.
a_form = mixed_topology_form(a, dtype=np.float64)
L_form = mixed_topology_form(L, dtype=np.float64)
a_form = form(a, dtype=np.float64)
L_form = form(L, dtype=np.float64)

# Assemble the matrix
A = assemble_matrix(a_form)
Expand Down
55 changes: 40 additions & 15 deletions python/dolfinx/fem/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -583,6 +583,7 @@ def functionspace(
mesh: Mesh,
element: typing.Union[
ufl.finiteelement.AbstractFiniteElement,
Sequence[ufl.finiteelement.AbstractFiniteElement],
ElementMetaData,
tuple[str, int],
tuple[str, int, tuple],
Expand All @@ -600,22 +601,46 @@ def functionspace(
"""
# Create UFL element
dtype = mesh.geometry.x.dtype
try:
e = ElementMetaData(*element) # type: ignore
ufl_e = basix.ufl.element(
e.family,
mesh.basix_cell(), # type: ignore
e.degree,
shape=e.shape,
symmetry=e.symmetry,
dtype=dtype,
)
except TypeError:
ufl_e = element # type: ignore
if len(mesh.topology._cpp_object.cell_types) > 1:
try:
e = ElementMetaData(*element) # type: ignore
ufl_e = [
basix.ufl.element(
e.family,
getattr(basix.CellType, _cpp.mesh.to_string(cell)),
e.degree,
shape=e.shape,
symmetry=e.symmetry,
dtype=dtype,
)
for cell in mesh.topology._cpp_object.cell_types
]
except TypeError:
ufl_e = element # type: ignore

# Check that element and mesh cell types match
for domain, element in zip(mesh.ufl_domain(), ufl_e):
if domain is None or element.cell != domain.ufl_cell():
raise ValueError("Non-matching UFL cell and mesh cell shapes.")
else:
try:
ufl_e = basix.ufl.element(
e.family,
mesh.basix_cell(), # type: ignore
e.degree,
shape=e.shape,
symmetry=e.symmetry,
dtype=dtype,
)
except TypeError:
ufl_e = element # type: ignore

# Check that element and mesh cell types match
if ((domain := mesh.ufl_domain()) is None) or ufl_e.cell != domain.ufl_cell():
raise ValueError("Non-matching UFL cell and mesh cell shapes.")

# Check that element and mesh cell types match
if ((domain := mesh.ufl_domain()) is None) or ufl_e.cell != domain.ufl_cell():
raise ValueError("Non-matching UFL cell and mesh cell shapes.")
# TODO
# TODO

# Create DOLFINx objects
element = finiteelement(mesh.topology.cell_type, ufl_e, dtype) # type: ignore
Expand Down
115 changes: 81 additions & 34 deletions python/dolfinx/mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,7 @@ class Mesh:
def __init__(
self,
msh: typing.Union[_cpp.mesh.Mesh_float32, _cpp.mesh.Mesh_float64],
domain: typing.Optional[ufl.Mesh],
domain: typing.Optional[typing.Union[ufl.Mesh, Sequence[ufl.Mesh]]],
):
"""Initialize mesh from a C++ mesh.

Expand All @@ -307,8 +307,13 @@ def __init__(
self._topology = Topology(self._cpp_object.topology)
self._geometry = Geometry(self._cpp_object.geometry)
self._ufl_domain = domain
if self._ufl_domain is not None:
self._ufl_domain._ufl_cargo = self._cpp_object # type: ignore
try:
for d in self._ufl_domain:
if d is not None:
d._ufl_cargo = self._cpp_object # type: ignore
except TypeError:
if self._ufl_domain is not None:
self._ufl_domain._ufl_cargo = self._cpp_object # type: ignore

@property
def comm(self):
Expand Down Expand Up @@ -680,23 +685,68 @@ def refine(
return Mesh(mesh1, ufl_domain), parent_cell, parent_facet


def extract_cmap_and_domain(
e: typing.Union[
ufl.Mesh,
basix.finite_element.FiniteElement,
basix.ufl._BasixElement,
_CoordinateElement,
],
gdim: int,
) -> tuple[typing.Any, typing.Any]:
try:
# e is a UFL domain
e_ufl = e.ufl_coordinate_element()
return _coordinate_element(e_ufl.basix_element), e
except AttributeError:
pass

try:
# e is a Basix 'UFL' element
domain = ufl.Mesh(e)
assert domain.geometric_dimension == gdim
return _coordinate_element(e.basix_element), domain
except (AttributeError, TypeError):
pass
try:
# e is a Basix element
# TODO: Resolve geometric dimension vs shape for manifolds
e_ufl = basix.ufl._BasixElement(e) # type: ignore
e_ufl = basix.ufl.blocked_element(e_ufl, shape=(gdim,))
domain = ufl.Mesh(e_ufl)
assert domain.geometric_dimension == gdim
return _coordinate_element(e), domain # type: ignore
except (AttributeError, TypeError):
pass
# e is a CoordinateElement
return e, None


def create_mesh(
comm: _MPI.Comm,
cells: npt.NDArray[np.int64],
cells: typing.Union[
npt.NDArray[np.int64],
typing.Sequence[npt.NDArray[np.int64]],
],
x: npt.NDArray[np.floating],
e: typing.Union[
ufl.Mesh,
basix.finite_element.FiniteElement,
basix.ufl._BasixElement,
_CoordinateElement,
Sequence[
typing.Union[
basix.finite_element.FiniteElement, basix.ufl._BasixElement, _CoordinateElement
]
],
],
partitioner: typing.Optional[Callable] = None,
) -> Mesh:
"""Create a mesh from topology and geometry arrays.

Args:
comm: MPI communicator to define the mesh on.
cells: Cells of the mesh. ``cells[i]`` are the 'nodes' of cell
cells: Cells of the mesh. ``cells[i]`` are the 'nodes' of cell
``i``.
x: Mesh geometry ('node' coordinates), with shape
``(num_nodes, gdim)``.
Expand All @@ -721,37 +771,34 @@ def create_mesh(
else:
gdim = x.shape[1]

dtype = None
try:
# e is a UFL domain
e_ufl = e.ufl_coordinate_element() # type: ignore
cmap = _coordinate_element(e_ufl.basix_element) # type: ignore
domain = e
dtype = cmap.dtype
# TODO: Resolve UFL vs Basix geometric dimension issue
# assert domain.geometric_dimension() == gdim
except AttributeError:
try:
# e is a Basix 'UFL' element
cmap = _coordinate_element(e.basix_element) # type: ignore
domain = ufl.Mesh(e)
dtype = cmap.dtype
assert domain.geometric_dimension() == gdim
except AttributeError:
try:
# e is a Basix element
# TODO: Resolve geometric dimension vs shape for manifolds
cmap = _coordinate_element(e) # type: ignore
e_ufl = basix.ufl._BasixElement(e) # type: ignore
e_ufl = basix.ufl.blocked_element(e_ufl, shape=(gdim,))
domain = ufl.Mesh(e_ufl)
# e and cells are Sequences
dtype = None
formatted_cells = []
cmaps = []
domains = []
for i, c in zip(e, cells):
cmap, domain = extract_cmap_and_domain(i, gdim)
if dtype is None:
dtype = cmap.dtype
assert domain.geometric_dimension() == gdim
except (AttributeError, TypeError):
# e is a CoordinateElement
cmap = e
domain = None
dtype = cmap.dtype # type: ignore
else:
assert dtype == cmap.dtype
c = np.asarray(c, dtype=np.int64, order="C")
formatted_cells.append(c)
cmaps.append(cmap._cpp_object)
domains.append(domain)

x = np.asarray(x, dtype=dtype, order="C")
msh: typing.Union[_cpp.mesh.Mesh_float32, _cpp.mesh.Mesh_float64] = _cpp.mesh.create_mesh(
comm, formatted_cells, cmaps, x, partitioner
)
return Mesh(msh, domains)

except KeyboardInterrupt:
pass

cmap, domain = extract_cmap_and_domain(e, gdim)
dtype = cmap.dtype

x = np.asarray(x, dtype=dtype, order="C")
cells = np.asarray(cells, dtype=np.int64, order="C")
Expand Down
2 changes: 1 addition & 1 deletion python/test/unit/mesh/test_mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ def test_UFLCell(interval, square, rectangle, cube, box):
def test_UFLDomain(interval, square, rectangle, cube, box):
def _check_ufl_domain(mesh):
domain = mesh.ufl_domain()
assert mesh.geometry.dim == domain.geometric_dimension()
assert mesh.geometry.dim == domain.geometric_dimension
assert mesh.topology.dim == domain.topological_dimension()
assert mesh.ufl_cell() == domain.ufl_cell()

Expand Down
Loading