1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64
|
# Copyright (C) 2020 Jørgen S. Dokken
#
# This file is part of DOLFINX_MPC
#
# SPDX-License-Identifier: MIT
from __future__ import annotations
from mpi4py import MPI
from petsc4py import PETSc
import dolfinx.fem as fem
import numpy as np
import pytest
import ufl
from dolfinx.common import Timer, TimingType, list_timings
from dolfinx.mesh import CellType, create_unit_square
import dolfinx_mpc
import dolfinx_mpc.utils
from dolfinx_mpc.utils import get_assemblers # noqa: F401
@pytest.mark.parametrize("get_assemblers", ["C++"], indirect=True)
@pytest.mark.parametrize("master_point", [[1, 1], [0, 1]])
@pytest.mark.parametrize("degree", range(1, 4))
@pytest.mark.parametrize("celltype", [CellType.quadrilateral, CellType.triangle])
def test_mpc_assembly(master_point, degree, celltype, get_assemblers):  # noqa: F811
    """Compare an MPC-assembled right-hand side against the unconstrained one.

    A Lagrange space of the given ``degree`` on a 3x5 unit square is
    constrained with a general multi-point constraint. The linear form
    ``sin(2*pi*x) * sin(pi*y) * v * dx`` is assembled twice:

    - with the MPC-aware assembler supplied by the ``get_assemblers`` fixture;
    - with the plain dolfinx PETSc assembler.

    ``dolfinx_mpc.utils.compare_mpc_rhs`` then checks the two vectors against
    each other.

    Args:
        master_point: Coordinates of the master point that constrains the
            slave located at ``(0, 0)``.
        degree: Polynomial degree of the Lagrange space.
        celltype: Cell type of the unit-square mesh.
        get_assemblers: Indirect fixture yielding a pair whose second entry
            is the MPC vector assembler used here.
    """
    # Import the PETSc layer explicitly. ``import dolfinx.fem`` alone does not
    # guarantee that the ``dolfinx.fem.petsc`` submodule is loaded.
    from dolfinx.fem import petsc as fem_petsc

    _, assemble_vector = get_assemblers

    # Create mesh and function space
    mesh = create_unit_square(MPI.COMM_WORLD, 3, 5, celltype)
    V = fem.functionspace(mesh, ("Lagrange", degree))

    # Right-hand side form shared by the constrained and reference assemblies
    v = ufl.TestFunction(V)
    x = ufl.SpatialCoordinate(mesh)
    f = ufl.sin(2 * ufl.pi * x[0]) * ufl.sin(ufl.pi * x[1])
    rhs = ufl.inner(f, v) * ufl.dx
    linear_form = fem.form(rhs)

    def l2b(li):
        """Encode a point as raw bytes in the mesh geometry dtype.

        The result is used as a key in the constraint dictionary.
        """
        return np.array(li, dtype=mesh.geometry.x.dtype).tobytes()

    # Slave point -> {master point: coefficient}
    s_m_c = {
        l2b([1, 0]): {l2b([0, 1]): 0.43, l2b([1, 1]): 0.11},
        l2b([0, 0]): {l2b(master_point): 0.69},
    }
    mpc = dolfinx_mpc.MultiPointConstraint(V)
    mpc.create_general_constraint(s_m_c)
    mpc.finalize()

    root = 0
    comm = mesh.comm

    # Assemble with the MPC-aware assembler
    b = assemble_vector(linear_form, mpc)
    L_org = None
    try:
        # Accumulate ghost contributions onto their owning processes
        b.ghostUpdate(addv=PETSc.InsertMode.ADD_VALUES, mode=PETSc.ScatterMode.REVERSE)

        # Unconstrained reference vector, assembled without the MPC
        L_org = fem_petsc.assemble_vector(linear_form)
        L_org.ghostUpdate(addv=PETSc.InsertMode.ADD_VALUES, mode=PETSc.ScatterMode.REVERSE)

        with Timer("~TEST: Compare"):
            dolfinx_mpc.utils.compare_mpc_rhs(L_org, b, mpc, root=root)
        list_timings(comm, [TimingType.wall])
    finally:
        # Release the PETSc vectors even when assembly or comparison fails
        b.destroy()
        if L_org is not None:
            L_org.destroy()
|