File: test_split.py

package info (click to toggle)
fenics-ufl 2025.2.1-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 2,176 kB
  • sloc: python: 25,267; makefile: 170
file content (89 lines) | stat: -rwxr-xr-x 3,644 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
__authors__ = "Martin Sandve Alnæs"
__date__ = "2009-03-14 -- 2009-03-14"

from utils import FiniteElement, LagrangeElement, MixedElement, SymmetricElement

from ufl import Coefficient, FunctionSpace, Mesh, TestFunction, as_vector, product, split, triangle
from ufl.pullback import identity_pullback
from ufl.sobolevspace import H1


def test_split(self):
    """Check value shapes of mixed coefficients and the behaviour of ``split``.

    Covers scalar, vector, tensor and symmetric subelements on a triangle,
    flattening of mixed elements built from non-scalar subelements, and
    repeated ``split`` on nested mixed elements (with and without symmetry).

    ``self`` is unused in the body; it is kept so the signature matches what
    the test runner supplies (presumably a fixture from the suite's conftest
    -- TODO confirm).
    """
    cell = triangle
    d = cell.topological_dimension()
    # The symmetric-element index map and the explicit component indices used
    # in the split assertions below are written out for a 2D cell.
    assert d == 2, "test_split hard-codes component indices for a 2D cell"

    domain = Mesh(LagrangeElement(cell, 1, (d,)))
    f = LagrangeElement(cell, 1)
    v = FiniteElement(
        "Lagrange", cell, 1, (d,), identity_pullback, H1, sub_elements=[f] * d
    )
    w = FiniteElement(
        "Lagrange", cell, 1, (d + 1,), identity_pullback, H1, sub_elements=[f] * (d + 1)
    )
    t = FiniteElement(
        "Lagrange", cell, 1, (d, d), identity_pullback, H1, sub_elements=[f] * d**2
    )
    # 2x2 symmetric tensor with 3 stored components: the off-diagonal entries
    # (0, 1) and (1, 0) both map to component 1.
    s = SymmetricElement({(0, 0): 0, (0, 1): 1, (1, 0): 1, (1, 1): 2}, [f] * 3)
    m = MixedElement([f, v, w, t, s, s])

    f_space = FunctionSpace(domain, f)
    v_space = FunctionSpace(domain, v)
    w_space = FunctionSpace(domain, w)
    t_space = FunctionSpace(domain, t)
    s_space = FunctionSpace(domain, s)
    m_space = FunctionSpace(domain, m)

    # Check that shapes of all these functions are correct:
    assert () == Coefficient(f_space).ufl_shape
    assert (d,) == Coefficient(v_space).ufl_shape
    assert (d + 1,) == Coefficient(w_space).ufl_shape
    assert (d, d) == Coefficient(t_space).ufl_shape
    assert (d, d) == Coefficient(s_space).ufl_shape
    # Sum of value sizes, not accounting for symmetries:
    # 1 + d + (d + 1) + d*d + d*d + d*d.
    assert (3 * d * d + 2 * d + 2,) == Coefficient(m_space).ufl_shape

    # Shapes of subelements are reproduced: the split pieces together
    # account for every component of the flat mixed value.
    g = Coefficient(m_space)
    (size,) = g.ufl_shape
    assert size == sum(product(piece.ufl_shape) for piece in split(g))

    # Mixed elements of non-scalar subelements are flattened
    v2 = MixedElement([v, v])
    t2 = MixedElement([t, t])
    v2_space = FunctionSpace(domain, v2)
    t2_space = FunctionSpace(domain, t2)
    assert (d + d,) == Coefficient(v2_space).ufl_shape
    assert (2 * d * d,) == Coefficient(t2_space).ufl_shape

    # Split simple element: a non-mixed function splits into itself.
    u = TestFunction(f_space)
    assert split(u) == (u,)

    # Split twice on nested mixed elements gets
    # the innermost scalar subcomponents
    u = TestFunction(FunctionSpace(domain, MixedElement([f, v])))
    assert split(u) == (u[0], as_vector((u[1], u[2])))
    assert split(split(u)[1]) == (u[1], u[2])

    u = TestFunction(FunctionSpace(domain, MixedElement([f, [f, v]])))
    assert split(u) == (u[0], as_vector((u[1], u[2], u[3])))
    assert split(split(u)[1]) == (u[1], as_vector((u[2], u[3])))

    u = TestFunction(FunctionSpace(domain, MixedElement([[v, f], [f, v]])))
    assert split(u) == (as_vector((u[0], u[1], u[2])), as_vector((u[3], u[4], u[5])))
    assert split(split(u)[0]) == (as_vector((u[0], u[1])), u[2])
    assert split(split(u)[1]) == (u[3], as_vector((u[4], u[5])))
    assert split(split(split(u)[0])[0]) == (u[0], u[1])
    assert split(split(split(u)[0])[1]) == (u[2],)
    assert split(split(split(u)[1])[0]) == (u[3],)
    assert split(split(split(u)[1])[1]) == (u[4], u[5])

    # Split twice on nested mixed elements with symmetry: the symmetric part
    # splits into all d*d (unsymmetrised) components of the flat value.
    vs = MixedElement([v, s])
    vs_space = FunctionSpace(domain, vs)
    vs_test = TestFunction(vs_space)

    v_test, s_test = split(vs_test)
    assert split(v_test) == (vs_test[0], vs_test[1])
    assert split(s_test) == (vs_test[2], vs_test[3], vs_test[4], vs_test[5])