File: interpolation.cpp

package info (click to toggle)
fenics-basix 0.10.0.post0-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 3,156 kB
  • sloc: cpp: 23,435; python: 10,829; makefile: 43; sh: 26
file content (104 lines) | stat: -rw-r--r-- 4,016 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
// Copyright (c) 2021 Matthew Scroggs
// FEniCS Project
// SPDX-License-Identifier:    MIT

#include "interpolation.h"
#include "finite-element.h"
#include <array>
#include <concepts>
#include <cstddef>
#include <exception>
#include <functional>
#include <numeric>
#include <stdexcept>
#include <utility>
#include <vector>

using namespace basix;

/// Shorthand for an `md::mdspan` over elements of type `T` with `D`
/// dynamic extents, each indexed by `std::size_t`.
template <typename T, std::size_t D>
using mdspan_t = md::mdspan<T, md::dextents<std::size_t, D>>;

//----------------------------------------------------------------------------
/// @brief Compute the matrix that interpolates from one element into another.
///
/// `element_from` is tabulated (values only, derivative order 0) at the
/// points returned by `element_to.points()`, and the tabulated values are
/// contracted with `element_to.interpolation_matrix()`.
///
/// The value size of an element is taken to be the product of the entries
/// of its `value_shape()` (1 when the shape is empty).
///
/// @param[in] element_from The element to interpolate from
/// @param[in] element_to The element to interpolate into
/// @return The operator as a flat buffer together with its 2D shape; the
/// buffer is filled through an `mdspan_t<T, 2>` of that shape. The shape is
/// - `(dim_to, dim_from)` if both elements have the same value size,
/// - `(dim_to * vs_from, dim_from)` if only `element_to` has value size 1,
/// - `(dim_to, dim_from * vs_to)` if only `element_from` has value size 1.
/// @throws std::runtime_error If the elements are defined on different cell
/// types, or if their value sizes differ and neither of them is 1.
template <std::floating_point T>
std::pair<std::vector<T>, std::array<std::size_t, 2>>
basix::compute_interpolation_operator(const FiniteElement<T>& element_from,
                                      const FiniteElement<T>& element_to)
{
  if (element_from.cell_type() != element_to.cell_type())
  {
    throw std::runtime_error(
        "Cannot interpolate between elements defined on different cell types.");
  }

  // Tabulate element_from at the interpolation points of element_to
  const auto [points, shape] = element_to.points();
  const auto [tab_b, tab_shape]
      = element_from.tabulate(0, mdspan_t<const T, 2>(points.data(), shape));
  mdspan_t<const T, 4> tab(tab_b.data(), tab_shape);
  const auto [imb, imshape] = element_to.interpolation_matrix();
  mdspan_t<const T, 2> i_m(imb.data(), imshape);

  const std::size_t dim_to = element_to.dim();
  const std::size_t dim_from = element_from.dim();
  const std::size_t npts = tab.extent(1);

  // Product of the entries of a value shape. The shape is bound once so
  // that begin() and end() always refer to the same container, and the
  // product is seeded with a std::size_t so that it is not carried out in
  // int.
  auto value_size = [](const auto& value_shape) -> std::size_t
  {
    return std::accumulate(value_shape.begin(), value_shape.end(),
                           std::size_t{1}, std::multiplies{});
  };
  const std::size_t vs_from = value_size(element_from.value_shape());
  const std::size_t vs_to = value_size(element_to.value_shape());

  if (vs_from == vs_to)
  {
    // Same value size: contract over components and points
    const std::array<std::size_t, 2> out_shape = {dim_to, dim_from};
    std::vector<T> outb(out_shape[0] * out_shape[1], 0.0);
    mdspan_t<T, 2> out(outb.data(), out_shape);
    for (std::size_t i = 0; i < dim_to; ++i)
      for (std::size_t j = 0; j < dim_from; ++j)
        for (std::size_t k = 0; k < vs_from; ++k)
          for (std::size_t l = 0; l < npts; ++l)
            out(i, j) += i_m(i, k * npts + l) * tab(0, l, j, k);

    return {std::move(outb), out_shape};
  }

  if (vs_to == 1)
  {
    // Map element_from's components into element_to: one block of rows
    // per component of element_from, interleaved with stride vs_from
    const std::array<std::size_t, 2> out_shape = {dim_to * vs_from, dim_from};
    std::vector<T> outb(out_shape[0] * out_shape[1], 0.0);
    mdspan_t<T, 2> out(outb.data(), out_shape);
    for (std::size_t i = 0; i < vs_from; ++i)
      for (std::size_t j = 0; j < dim_to; ++j)
        for (std::size_t k = 0; k < dim_from; ++k)
          for (std::size_t l = 0; l < npts; ++l)
            out(i + j * vs_from, k) += i_m(j, l) * tab(0, l, k, i);

    return {std::move(outb), out_shape};
  }

  if (vs_from == 1)
  {
    // Map duplicates of element_from to components of element_to: one
    // block of columns per component of element_to, interleaved with
    // stride vs_to
    const std::array<std::size_t, 2> out_shape = {dim_to, dim_from * vs_to};
    std::vector<T> outb(out_shape[0] * out_shape[1], 0.0);
    mdspan_t<T, 2> out(outb.data(), out_shape);
    for (std::size_t i = 0; i < vs_to; ++i)
      for (std::size_t j = 0; j < dim_from; ++j)
        for (std::size_t k = 0; k < dim_to; ++k)
          for (std::size_t l = 0; l < npts; ++l)
            out(k, i + j * vs_to) += i_m(k, i * npts + l) * tab(0, l, j, 0);

    return {std::move(outb), out_shape};
  }

  throw std::runtime_error("Cannot interpolate between elements with this "
                           "combination of value sizes.");
}
//----------------------------------------------------------------------------
/// @cond
// Explicit instantiations of compute_interpolation_operator for float and
// double. The template is defined in this source file, so these emit the
// two versions that code seeing only the declaration can link against.
template std::pair<std::vector<float>, std::array<std::size_t, 2>>
basix::compute_interpolation_operator(const FiniteElement<float>&,
                                      const FiniteElement<float>&);
template std::pair<std::vector<double>, std::array<std::size_t, 2>>
basix::compute_interpolation_operator(const FiniteElement<double>&,
                                      const FiniteElement<double>&);
/// @endcond
//-----------------------------------------------------------------------------