File: function.hh

package info (click to toggle)
dune-grid 2.11.0-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 8,472 kB
  • sloc: cpp: 60,883; python: 1,438; perl: 191; makefile: 12; sh: 3
file content (393 lines) | stat: -rw-r--r-- 18,254 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
// SPDX-FileCopyrightText: Copyright © DUNE Project contributors, see file LICENSE.md in module root
// SPDX-License-Identifier: LicenseRef-GPL-2.0-only-with-DUNE-exception
// -*- tab-width: 2; indent-tabs-mode: nil; c-basic-offset: 2 -*-
// vi: set et ts=4 sw=2 sts=2:
#ifndef DUNE_PYTHON_GRID_FUNCTION_HH
#define DUNE_PYTHON_GRID_FUNCTION_HH

#include <functional>
#include <string>
#include <tuple>
#include <type_traits>
#include <utility>

#include <dune/common/ftraits.hh>
#include <dune/common/visibility.hh>

#include <dune/python/common/dimrange.hh>
#include <dune/python/common/typeregistry.hh>
#include <dune/python/common/vector.hh>
#include <dune/python/common/fvector.hh>
#include <dune/python/grid/simplegridfunction.hh>
#include <dune/python/grid/localview.hh>
#include <dune/python/grid/entity.hh>
#include <dune/python/grid/numpy.hh>
#include <dune/python/grid/object.hh>
#include <dune/python/grid/vtk.hh>

#if HAVE_DUNE_VTK
#include <dune/vtk/gridfunctions/gridfunction.hh>
#endif

#include <dune/python/pybind11/numpy.h>
#include <dune/python/pybind11/pybind11.h>

namespace Dune
{

  namespace Python
  {

    // GridFunctionTraits
    // ------------------

    // Compile-time description of a grid function type.
    //
    // Inherits everything from GridObjectTraits< GridFunction > and adds
    //  - LocalFunction: decayed result type of the unqualified call
    //    localFunction( gf ) for a const GridFunction reference,
    //  - Range: decayed result type of calling such a local function with a
    //    const LocalCoordinate reference,
    //  - GridView: the nested type GridFunction::GridView.
    template< class GridFunction >
    struct GridFunctionTraits
      : public GridObjectTraits< GridFunction >
    {
      using LocalCoordinate = typename GridObjectTraits< GridFunction >::LocalCoordinate;

      using LocalFunction = std::decay_t< decltype( localFunction( std::declval< const GridFunction & >() ) ) >;
      using Range = std::decay_t< decltype( std::declval< LocalFunction & >()( std::declval< const LocalCoordinate & >() ) ) >;

      using GridView = typename GridFunction::GridView;
    };



    namespace detail
    {

      // Highest-priority overload (PriorityTag< 2 >): takes part in overload
      // resolution only if the expression f( x ) is well formed for the given
      // argument type X (checked through the trailing return type) and returns
      // the result of exactly that call.
      // LocalCoordinate has to be passed explicitly by the caller; this overload
      // does not use it.
      template< class LocalCoordinate, class LocalFunction, class X >
      inline static auto callLocalFunction ( LocalFunction &&f, const X &x, PriorityTag< 2 > )
        -> decltype( f( x ) )
      {
        return f( x );
      }

      // Overload tagged PriorityTag< 1 >: accepts a numpy array whose scalar type is
      // the field type of LocalCoordinate and returns a generic Python object.
      // The evaluation is delegated to vectorize (declared elsewhere), which is
      // handed a lambda evaluating f at a single LocalCoordinate -- presumably
      // vectorize applies it to every coordinate stored in x; confirm against
      // its definition.
      // f is captured by reference, so the lambda must not outlive this call.
      template< class LocalCoordinate, class LocalFunction >
      inline static pybind11::object callLocalFunction ( LocalFunction &&f, pybind11::array_t< typename FieldTraits< LocalCoordinate >::field_type > x, PriorityTag< 1 > )
      {
        return vectorize( [ &f ] ( const LocalCoordinate &x ) { return f( x ); }, x );
      }

      // Overload tagged PriorityTag< 0 >: only enabled when the (reference-stripped)
      // LocalFunction type is not const. It forwards f and x to callLocalFunction
      // once more, this time with PriorityTag< 42 >, and returns a generic Python
      // object.
      // NOTE(review): PriorityTag is declared outside this file; presumably a tag
      // with a higher number converts to all lower ones, so the nested call
      // re-runs the dispatch from the top. If neither of the overloads tagged
      // 2 and 1 is viable for the forwarded arguments, this call looks like it
      // would select this overload again -- confirm that it cannot recurse.
      template< class LocalCoordinate, class LocalFunction, class X >
      inline static auto callLocalFunction ( LocalFunction &&f, const X &x, PriorityTag<0> )
        -> std::enable_if_t< !std::is_const< std::remove_reference_t< LocalFunction > >::value, pybind11::object >
      {
        return callLocalFunction< LocalCoordinate >( std::forward< LocalFunction >( f ), x, PriorityTag< 42 >() );
      }

      // Add the members shared by all grid function bindings to cls:
      //  - read-only properties "grid" and "gridView" (both return the same object),
      //  - read-only property "dimRange",
      //  - "addToVTKWriter",
      //  - "cellData" / "pointData", first with a level only and then once per
      //    partition set with an additional keyword-only "partition" argument,
      //  - "polygonData".
      // The scope argument is not used by this function.
      template< class GridFunction, class... options >
      void registerGridFunction ( pybind11::handle scope, pybind11::class_< GridFunction, options... > cls )
      {
        using pybind11::operator""_a;
        using Range = typename GridFunctionTraits< GridFunction >::Range;

        // one getter serves both property names
        auto viewGetter = [] ( const GridFunction &self ) -> pybind11::handle {
          return pybind11::cast(Dune::Python::gridView( self ));
        };
        cls.def_property_readonly( "grid", viewGetter );
        cls.def_property_readonly( "gridView", viewGetter );

        cls.def_property_readonly( "dimRange",
          [] ( pybind11::object /* self */ ) { return pybind11::int_( DimRange< Range >::value ); } );

        // keep_alive< 3, 1 >: argument 1 (the grid function) is kept alive as long
        // as argument 3 (the writer) exists
        cls.def( "addToVTKWriter", &addToVTKWriter< GridFunction >,
                 pybind11::keep_alive< 3, 1 >(),
                 "name"_a, "writer"_a, "dataType"_a );

        // overloads without a partition argument are registered first
        cls.def( "cellData",
          [] ( const GridFunction &self, int level ) { return cellData( self, level ); },
          "level"_a = 0 );
        cls.def( "pointData",
          [] ( const GridFunction &self, int level ) { return pointData( self, level ); },
          "level"_a = 0 );

        // adds one "cellData" and one "pointData" overload for the type of the
        // partition set passed in; the partition can only be given by keyword
        auto registerForPartition = [&cls](auto &partitionSet)
        {
          using PartitionArg = decltype(partitionSet);
          cls.def( "cellData",
            [] ( const GridFunction &self, int level, PartitionArg partition ) {
              return cellData( self, level, partition );
            },
            "level"_a = 0, pybind11::kw_only(), "partition"_a );
          cls.def( "pointData",
            [] ( const GridFunction &self, int level, PartitionArg partition ) {
              return pointData( self, level, partition );
            },
            "level"_a = 0, pybind11::kw_only(), "partition"_a );
        };
        registerForPartition(Dune::Partitions::interior);
        registerForPartition(Dune::Partitions::ghost);
        registerForPartition(Dune::Partitions::all);
        registerForPartition(Dune::Partitions::interiorBorder);
        registerForPartition(Dune::Partitions::interiorBorderOverlap);
        registerForPartition(Dune::Partitions::interiorBorderOverlapFront);

        cls.def( "polygonData", [] ( const GridFunction &self ) { return polygonData( self ); },
          R"doc(
            Store the grid with piecewise constant data in numpy arrays.

            Returns: pair with coordinate array storing the vertex coordinate of each polygon
                     in the grid and an array with a range type for each polygon.
          )doc" );
      }
    } // namespace detail

    // registerGridFunction
    // --------------------

    // Register the full Python interface of a grid function class on cls.
    //
    // In this order the function sets up
    //  - a nested Python class "LocalFunction" (with dynamic attributes) for the
    //    local function type; it is callable with one local coordinate or with a
    //    numpy array and has a read-only "dimRange" property,
    //  - "localFunction" on cls, returning either an unbound local function or
    //    one already bound to the given element,
    //  - "__call__" on cls, which binds a temporary local function to the element,
    //    evaluates it at x and unbinds it before returning the value,
    //  - the shared members added by detail::registerGridFunction,
    //  - if HAVE_DUNE_VTK is set: a constructor of Dune::Vtk::GridFunction< GridView >
    //    taking this grid function type, plus an implicit conversion to it.
    template< class GridFunction, class... options >
    void registerGridFunction ( pybind11::handle scope, pybind11::class_< GridFunction, options... > cls )
    {
      using pybind11::operator""_a;

      using Traits = GridFunctionTraits< GridFunction >;
      using Element = typename Traits::Element;
      using LocalCoordinate = typename Traits::LocalCoordinate;
      using LocalFunction = typename Traits::LocalFunction;
      using Range = typename Traits::Range;

      using Array = pybind11::array_t< typename FieldTraits< LocalCoordinate >::field_type >;

      // TODO subclassing from a non registered traits class not covered by TypeRegistry
      pybind11::class_< LocalFunction > localCls( cls, "LocalFunction", pybind11::dynamic_attr() );
      registerLocalView< Element >( localCls );
      localCls.def( "__call__",
        [] ( LocalFunction &self, const LocalCoordinate &x ) {
          return detail::callLocalFunction< LocalCoordinate >( self, x, PriorityTag<2>() );
        },
        "x"_a );
      localCls.def( "__call__",
        [] ( LocalFunction &self, Array x ) {
          return detail::callLocalFunction< LocalCoordinate >( self, x, PriorityTag<2>() );
        },
        "x"_a );
      localCls.def_property_readonly( "dimRange",
        [] ( pybind11::object /* self */ ) { return pybind11::int_( DimRange< Range >::value ); } );

      // keep_alive< 0, k >: argument k is kept alive as long as the returned object exists
      cls.def( "localFunction",
        [] ( const GridFunction &self ) { return localFunction( self ); },
        pybind11::keep_alive< 0, 1 >() );
      cls.def( "localFunction",
        [] ( const GridFunction &self, const Element &element ) {
          auto bound = localFunction(self);
          bound.bind(element);
          return bound;
        },
        pybind11::keep_alive< 0, 1 >(),
        pybind11::keep_alive< 0, 2 >() );

      // direct evaluation: bind, evaluate, unbind
      cls.def( "__call__",
        [] ( const GridFunction &self, const Element &element, LocalCoordinate &x ) {
          auto bound = localFunction(self);
          bound.bind(element);
          auto value = detail::callLocalFunction< LocalCoordinate >( bound, x, PriorityTag<2>() );
          bound.unbind();
          return value;
        },
        "element"_a, "x"_a );
      cls.def( "__call__",
        [] ( const GridFunction &self, const Element &element, Array x ) {
          auto bound = localFunction(self);
          bound.bind(element);
          auto value = detail::callLocalFunction< LocalCoordinate >( bound, x, PriorityTag<2>() );
          bound.unbind();
          return value;
        },
        "element"_a, "x"_a );

      detail::registerGridFunction(scope,cls);
#if HAVE_DUNE_VTK
      using GridView = typename Traits::GridView;
      using VirtualizedGF = Dune::Vtk::GridFunction<GridView>;
      // register the Function class if not already available
      auto vtkCls = Python::insertClass<VirtualizedGF>(scope,"VtkFunction",
          Python::GenerateTypeName("Dune::Vtk::GridFunction",MetaType<GridView>()),
          Python::IncludeFiles{"dune/vtk/gridfunctions/gridfunction.hh"});
      // the constructor for this GridFunction type is added on every call
      vtkCls.first.def( pybind11::init(
        [] ( GridFunction &gf ) {
          // TODO: perhaps grid functions should just have a name attribute in general
          return new VirtualizedGF( gf, Dune::Vtk::FieldInfo{"tmp"} );
        } ) );
      pybind11::implicitly_convertible<GridFunction,VirtualizedGF>();
#endif
    }

    // Names the std::function type used for C++ callables evaluating a grid
    // function: it maps a codimension-0 entity and a local coordinate to Value.
    //
    // dimR == 0 selects a scalar function (Value is double); otherwise Value is
    // a Dune::FieldVector< double, dimR >. dimRange is dimR with 0 replaced by 1.
    template <class GridView, int dimR>
    struct stdFunction
    {
      static const unsigned int dimRange = (dimR == 0) ? 1 : dimR;
      using Entity = typename GridView::template Codim< 0 >::Entity;
      using Coordinate = typename Entity::Geometry::LocalCoordinate;
      using Value = std::conditional_t< dimR == 0, double, Dune::FieldVector< double, dimRange > >;
      using type = std::function< Value( const Entity &, const Coordinate & ) >;
    };
    template <class GridView,int dimR,class Evaluate>
    struct EvaluateType
    {
      typedef typename GridView::template Codim< 0 >::Entity Entity;
      typedef typename Entity::Geometry::LocalCoordinate Coordinate;
      typedef typename std::conditional< dimR == 0, double, Dune::FieldVector< double, dimR > >::type Value;
      static std::string name()
      { std::string entity = findInTypeRegistry<Entity>().first->second.name;
        std::string coord = findInTypeRegistry<Coordinate>().first->second.name;
        std::string value;
        if (dimR==0) value = "double";
        else
        {
          auto found = findInTypeRegistry<Value>();
          assert(!found.second);
          value = found.first->second.name;
        }
        return "std::function<"+value+"(const "+entity+"&,const "+coord+"&)>";
      }
    };
    template <class GridView,int dimR>
    struct EvaluateType<GridView,dimR,pybind11::function>
    {
      static std::string name() { return "pybind11::function"; }
    };

    namespace detail
    {

      // PyGridFunctionEvaluator
      // -----------------------
      // Primary template: empty. Only the specializations for the supported
      // callable types (pybind11::function and the std::function type named by
      // stdFunction) provide an interface.
      template <class GridView, int dimR, class Evaluate>
      struct DUNE_PRIVATE PyGridFunctionEvaluator
      {};

      // Evaluator wrapping a Python callable.
      //
      // Both call operators acquire the GIL, call the stored Python function with
      // (entity, x) and cast the result: to Value for a single local coordinate,
      // to a numpy array of doubles for a numpy array of coordinates.
      // Value is double for dimR == 0 and FieldVector< double, dimR > otherwise.
      template <class GridView, int dimR>
      struct DUNE_PRIVATE PyGridFunctionEvaluator<GridView,dimR,pybind11::function>
      {
        static const unsigned int dimRange = (dimR == 0) ? 1 : dimR;

        using Entity = typename GridView::template Codim< 0 >::Entity;
        using Coordinate = typename Entity::Geometry::LocalCoordinate;

        using Value = std::conditional_t< dimR == 0, double, Dune::FieldVector< double, dimRange > >;

        explicit PyGridFunctionEvaluator ( pybind11::function evaluate )
          : evaluate_( evaluate )
        {}

        // evaluate at one local coordinate
        Value operator() ( const Entity &entity, const Coordinate &x ) const
        {
          pybind11::gil_scoped_acquire gil;
          return pybind11::cast< Value >( evaluate_( entity, x ) );
        }

        // evaluate with a numpy array as coordinate argument
        pybind11::array_t< double > operator() ( const Entity &entity, const pybind11::array_t<double> x ) const
        {
          pybind11::gil_scoped_acquire gil;
          return pybind11::cast< pybind11::array_t< double > >( evaluate_( entity, x ) );
        }

        // access to the wrapped Python callable (returned by value)
        pybind11::function evaluate() const
        {
          return evaluate_;
        }

      private:
        pybind11::function evaluate_;
      };
      // Evaluator wrapping a C++ callable stored in the std::function type named
      // by stdFunction< GridView, dimR >.
      //
      // The call operator forwards (entity, x) to the stored function and returns
      // its result; in contrast to the pybind11::function specialization there is
      // no overload for numpy arrays.
      template <class GridView, int dimR>
      struct DUNE_PRIVATE PyGridFunctionEvaluator<GridView,dimR,
                          typename stdFunction<GridView,dimR>::type >
      {
        static const unsigned int dimRange = (dimR == 0) ? 1 : dimR;

        using Entity = typename GridView::template Codim< 0 >::Entity;
        using Coordinate = typename Entity::Geometry::LocalCoordinate;

        using Value = std::conditional_t< dimR == 0, double, Dune::FieldVector< double, dimRange > >;

        using Evaluate = typename stdFunction<GridView,dimR>::type;

        explicit PyGridFunctionEvaluator ( Evaluate evaluate )
          : evaluate_( evaluate )
        {}

        // evaluate at one local coordinate
        Value operator() ( const Entity &entity, const Coordinate &x ) const
        {
          return evaluate_( entity, x );
        }

        // access to the wrapped callable (returned by value)
        Evaluate evaluate() const
        {
          return evaluate_;
        }

      private:
        Evaluate evaluate_;
      };



      // registerPyGridFunction
      // ----------------------

      // Register the Python class for
      //   SimpleGridFunction< GridView, PyGridFunctionEvaluator< GridView, dimRange, Evaluate > >
      // in scope.
      //
      // Parameters:
      //   scope   Python scope the class is inserted into
      //   name    base name; the Python class is called name + std::to_string( dimRange )
      //   scalar  value reported by the read-only "scalar" property
      //   (tag)   unnamed integral_constant carrying dimRange; not used in the body
      //
      // Returns the result of insertClass; its .first member is the pybind11
      // class, its .second member decides below whether the grid function
      // interface still has to be registered.
      //
      // For dimRange > 0 the scalar variant (dimRange == 0, std::function based)
      // is registered as well and "__getitem__" returns such a scalar grid
      // function for a single component.
      template< class GridView, class Evaluate, unsigned int dimRange >
      auto registerPyGridFunction ( pybind11::handle scope, const std::string &name, bool scalar, std::integral_constant< unsigned int, dimRange > )
      {
        using pybind11::operator""_a;

        typedef typename GridView::template Codim<0>::Entity Entity;
        typedef typename Entity::Geometry::LocalCoordinate LocalCoordinate;
        typedef PyGridFunctionEvaluator<GridView,dimRange,Evaluate> Evaluator;
        typedef SimpleGridFunction< GridView, Evaluator > GridFunction;
        if (dimRange>0)
          Dune::Python::registerFieldVector<double,dimRange>(scope);
        addToTypeRegistry<Evaluator>(GenerateTypeName("Dune::Python::detail::PyGridFunctionEvaluator",
                                            MetaType<GridView>(),dimRange,
                                            EvaluateType<GridView,dimRange,Evaluate>::name()
                                            ),
                         IncludeFiles{"dune/python/grid/function.hh"});

        std::string clsName = name + std::to_string( dimRange );
        auto gf = insertClass< GridFunction >( scope, clsName,
                  pybind11::dynamic_attr(),
                  GenerateTypeName("Dune::Python::SimpleGridFunction",
                                      MetaType<GridView>(), Dune::MetaType<Evaluator>()),
            IncludeFiles{"dune/python/grid/function.hh"});

        // constructor: (gridView, callable); keep_alive< 1, 2 > keeps the grid
        // view argument alive as long as the constructed object exists
        gf.first.def(pybind11::init([](GridView &gridView, Evaluate callable) {
              return new GridFunction( gridView,
                         PyGridFunctionEvaluator<GridView,dimRange,Evaluate>(callable) );
              }), "gridView"_a, "callable"_a, pybind11::keep_alive<1,2>() );

        // pickling: the state is the tuple (grid view, callable, __dict__)
        gf.first.def( pybind11::pickle( [](const pybind11::object &self) { // __getstate__
            GridFunction& gridFct = self.cast<GridFunction&>();
            /* Return a tuple that fully encodes the state of the object */
            pybind11::dict d;
            if (pybind11::hasattr(self, "__dict__")) {
              d = self.attr("__dict__");
            }
            return pybind11::make_tuple(gridFct.gridView(),gridFct.localEvaluator().evaluate(),d);
          },
        [](pybind11::tuple t) { // __setstate__
            if (t.size() != 3)
                throw std::runtime_error("Invalid state in GridFunction::setstate with "+std::to_string(t.size())+" arguments!");
            pybind11::handle pygv = t[0];
            GridView& gv = pygv.cast<GridView&>();
            pybind11::handle pyeval = t[1];
            Evaluate callable = pyeval.cast<Evaluate>();
            /* Create a new C++ instance */
            auto py_state = t[2].cast<pybind11::dict>();
            return std::make_pair(
                   new GridFunction(gv,PyGridFunctionEvaluator<GridView,dimRange,Evaluate>(callable)),
                   py_state);
          }
        ),pybind11::keep_alive<1,2>());

        // the remaining interface is only added when gf.second is set
        if (gf.second)
        {
          Dune::Python::registerGridFunction( scope, gf.first );
          gf.first.def_property_readonly( "scalar", [scalar] ( pybind11::object self ) { return scalar; } );
          if constexpr (dimRange>0)
          {
            typedef typename Dune::Python::stdFunction<GridView,0>::type Evaluate0;
            // the component views returned by __getitem__ use the scalar class
            detail::registerPyGridFunction< GridView, Evaluate0, 0 >
                    ( scope, name, true, std::integral_constant< unsigned int, 0 >() );
            gf.first.def( "__getitem__", [] ( const GridFunction &self, std::size_t c ) {
              // The evaluator of this class returns FieldVector< double, dimRange >,
              // so valid components are 0, ..., dimRange-1. Raising IndexError
              // instead of reading out of bounds also matches what Python expects
              // from __getitem__.
              if (c >= dimRange)
                throw pybind11::index_error( "component index " + std::to_string( c ) + " out of range" );
              // self is captured by reference; keep_alive< 0, 1 > below keeps the
              // Python object owning it alive as long as the returned function exists
              Evaluate0 eval0 = [&self,c](const Entity &e, const LocalCoordinate &x) -> double
              {
                auto lf = localFunction(self);
                lf.bind(e);
                auto y = detail::callLocalFunction< LocalCoordinate >( lf, x, PriorityTag<2>() );
                lf.unbind();
                // select the requested component (not always the first one)
                return y[c];
              };
              auto gridFunction = simpleGridFunction( gridView(self),
                  detail::PyGridFunctionEvaluator< GridView, 0, Evaluate0 >( std::move( eval0 ) )
                );
              return gridFunction;
              // return pybind11::cast( std::move( gridFunction ) );
            }, pybind11::keep_alive< 0, 1 >() );
          }
        }
        return gf;
      }

    } // namespace detail

    // Convenience entry point: registers the Python grid function class for the
    // given GridView / Evaluate / dimRange combination by forwarding to
    // detail::registerPyGridFunction.
    // The result of that call is discarded, so the deduced return type is void.
    template< class GridView, class Evaluate, int dimRange >
    auto registerGridFunction ( pybind11::handle scope, std::string name, bool scalar )
    {
      using DimRangeTag = std::integral_constant< unsigned int, dimRange >;
      detail::registerPyGridFunction< GridView, Evaluate, dimRange >( scope, name, scalar, DimRangeTag() );
    }
    // Maps the return type of an evaluation callable to the dimRange template
    // argument used in this file: 0 for anything convertible to double,
    // Value::dimension for every other type.
    template <class Value> struct FunctionRange
    {
      static constexpr int value()
      {
        if constexpr (!std::is_convertible_v<Value,double>)
          return Value::dimension;
        else
          return 0;
      }
    };
    // Wrap a C++ callable eval( entity, localCoordinate ) as a Python grid function.
    //
    // The range dimension is derived from the return type of eval through
    // FunctionRange. The matching Python class is registered in scope under
    // name (flagged scalar if the dimension is 0); then eval is stored in the
    // corresponding std::function type and a SimpleGridFunction on the grid view
    // held by gp is created.
    //
    // Returns the new grid function cast to a Python object with the move
    // policy; gp is passed as the parent handle of the cast.
    template< class GridView, class Eval >
    auto registerGridFunction ( pybind11::handle scope, pybind11::object gp, std::string name, Eval eval )
    {
      using Entity = typename GridView::template Codim<0>::Entity;
      using LocalCoordinate = typename Entity::Geometry::LocalCoordinate;
      using Value = decltype(eval(std::declval<const Entity&>(),std::declval<const LocalCoordinate&>()));
      static constexpr int dimRange = FunctionRange<Value>::value();
      using Evaluate = typename Dune::Python::stdFunction<GridView,dimRange>::type;
      using Evaluator = detail::PyGridFunctionEvaluator< GridView, dimRange, Evaluate >;

      registerGridFunction< GridView, Evaluate, dimRange >( scope, name, dimRange==0 );

      Evaluate wrapped(eval);
      const GridView &view = gp.cast< const GridView & >();
      auto gridFunction = simpleGridFunction( view, Evaluator( std::move( wrapped ) ) );
      return pybind11::cast( std::move( gridFunction ), pybind11::return_value_policy::move, gp );
    }
  } // namespace Python

} // namespace Dune

#endif // #ifndef DUNE_PYTHON_GRID_FUNCTION_HH