File: grid_generator.py

package info (click to toggle)
dune-grid 2.10.0-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 8,248 kB
  • sloc: cpp: 59,108; python: 1,437; perl: 191; makefile: 6; sh: 3
file content (383 lines) | stat: -rw-r--r-- 16,388 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
# SPDX-FileCopyrightText: Copyright © DUNE Project contributors, see file LICENSE.md in module root
# SPDX-License-Identifier: LicenseRef-GPL-2.0-only-with-DUNE-exception

import os, inspect
from ..generator.generator import SimpleGenerator
from dune.common.hashit import hashIt
from dune.common import FieldVector
from dune.common.utility import isString
from dune.deprecate import deprecated
from dune.grid import gridFunction, DataType, Partitions
from dune.grid import OutputType
from dune.generator.algorithm import cppType
from dune.generator import builder
from dune.deprecate import deprecated

def getDimgrid(constructor):
    """Return the grid dimension described by a grid ``constructor``.

    The dimension is taken from ``constructor.dimgrid`` if that attribute
    exists and is truthy, otherwise from the number of coordinates of the
    first entry of ``constructor["vertices"]``.

    Raises:
        ValueError: if neither source yields a (non-zero) dimension.
    """
    dimgrid = getattr(constructor, "dimgrid", None)
    if not dimgrid:
        try:
            dimgrid = len(constructor["vertices"][0])
        except (KeyError, IndexError, TypeError):
            # no "vertices" entry, an empty vertex list, or a constructor
            # that is not subscriptable by name: report the problem below
            pass
    if not dimgrid:
        raise ValueError("Couldn't extract dimension of grid from constructor arguments, add dimgrid parameter")
    return dimgrid

def triangulation(grid, level=0, *, partition=Partitions.all):
    """Build a matplotlib ``Triangulation`` of a two-dimensional grid view.

    The grid is tessellated at refinement ``level`` restricted to
    ``partition``. ``None`` is returned when that partition set contains no
    elements.

    Raises:
        Exception: if ``grid.dimGrid`` is not 2.
    """
    if grid.dimGrid != 2:
        raise Exception("Grid must be 2-dimensional for use as matplotlib triangulation.")
    from matplotlib.tri import Triangulation
    points, cells = grid.tessellate(level, partition=partition)
    # nothing to triangulate when the partition set is empty
    if not len(cells):
        return None
    xs, ys = points[:, 0], points[:, 1]
    return Triangulation(xs, ys, cells)

# Converters tried by `_writeVTK` for objects without their own
# ``addToVTKWriter``: each is called as ``dispatch(grid, f)`` and its result
# is used if it is not None.
_writeVTKDispatcher = []
def _writeVTK(vtk,grid,f,name,dataTag):
    """Add ``f`` under ``name`` to the VTK writer ``vtk`` with ``dataTag``.

    Strategies are tried in order:
    1. ``f.addToVTKWriter`` itself,
    2. the first converter in ``_writeVTKDispatcher`` that does not raise and
       returns a non-None object,
    3. wrapping ``f`` with ``gridFunction(grid)``.
    """
    try:
        f.addToVTKWriter(name, vtk, dataTag)
        return
    except AttributeError:
        pass
    for dispatch in _writeVTKDispatcher:
        try:
            func = dispatch(grid, f)
        except Exception:
            # a converter that cannot handle ``f`` simply does not apply;
            # unlike a bare ``except`` this does not swallow
            # KeyboardInterrupt/SystemExit
            continue
        if func is not None:
            func.addToVTKWriter(name, vtk, dataTag)
            return
    gridFunction(grid)(f).addToVTKWriter(name, vtk, dataTag)

def writeVTK(grid, name,
             celldata=None, pointdata=None,
             cellvector=None, pointvector=None,
             number=None, subsampling=None,outputType=OutputType.appendedraw,
             write=True, nonconforming=False):
    """Write VTK output for ``grid``, or return the prepared writer.

    Args:
        grid: grid view providing ``vtkWriter``.
        name: output name passed to ``vtk.write``.
        celldata, pointdata, cellvector, pointvector: data to attach with the
            corresponding ``DataType``. Each is either a dict
            ``{name: function}`` or a list whose entries are functions with a
            ``name`` attribute or ``(function, name)`` pairs. ``None``
            functions/entries are skipped.
        number: if given, it is passed to ``vtk.write`` together with
            ``name`` (used by `SequencedVTK` for numbered output).
        subsampling: if given, the writer is created with
            ``grid.vtkWriter(subsampling)`` instead of
            ``grid.vtkWriter(nonconforming)``.
        outputType: ``OutputType`` forwarded to ``vtk.write``.
        write: if False, nothing is written and the writer is returned.
        nonconforming: forwarded to ``grid.vtkWriter`` without subsampling.

    Returns:
        The VTK writer if ``write`` is False, otherwise None.

    Raises:
        TypeError: for an ``outputType`` that is not an ``OutputType``, a data
            argument that is neither dict nor list, or a list entry that is
            neither a named function nor a ``(function, name)`` pair.
    """
    # explicit check instead of `assert`, which is stripped under `python -O`
    if not isinstance(outputType, OutputType):
        raise TypeError("Argument 'outputType' must be an OutputType instance.")

    vtk = grid.vtkWriter(nonconforming) if subsampling is None else grid.vtkWriter(subsampling)

    def nameAndFunction(entry):
        """Split a list entry into ``(name, function)``.

        Only the name lookup is guarded, so errors raised while adding the
        data to the writer are no longer misreported as a missing name.
        """
        try:
            return entry.name, entry
        except AttributeError:
            pass
        try:
            return entry[1], entry[0]
        except (TypeError, IndexError, KeyError):
            raise TypeError("""
Did you try to pass in a function without a name attribute?
Try using a dictionary with name:function instead.""") from None

    def addDataToVTKWriter(dataFunctions, dataName, dataTag):
        if dataFunctions is None:
            return
        if isinstance(dataFunctions, dict):
            for n, f in dataFunctions.items():
                if f is not None:
                    _writeVTK(vtk, grid, f, n, dataTag)
        elif isinstance(dataFunctions, list):
            for entry in dataFunctions:
                if entry is None:
                    continue
                n, f = nameAndFunction(entry)
                _writeVTK(vtk, grid, f, n, dataTag)
        else:
            raise TypeError("Argument '" + dataName + "' must be a dict or list instance.")

    addDataToVTKWriter(celldata, 'celldata', DataType.CellData)
    addDataToVTKWriter(pointdata, 'pointdata', DataType.PointData)
    addDataToVTKWriter(cellvector, 'cellvector', DataType.CellVector)
    addDataToVTKWriter(pointvector, 'pointvector', DataType.PointVector)

    if not write:
        return vtk
    if number is None:
        vtk.write(name, outputType)
    else:
        vtk.write(name, number, outputType)

class SequencedVTK:
    """Callable that writes a numbered series of VTK files.

    The VTK writer is prepared once via ``grid.writeVTK(..., write=False)``.
    Every call writes the current step under ``name`` with index ``number``
    and then increments ``number``.
    """
    def __init__(self, grid, name, number,
                 celldata, pointdata, cellvector, pointvector,
                 subsampling, outputType=OutputType.appendedraw):
        self.number = number
        self.name = name
        data = dict(celldata=celldata, pointdata=pointdata,
                    cellvector=cellvector, pointvector=pointvector)
        self.vtk = grid.writeVTK(name, subsampling=subsampling, write=False, **data)
        self.outputType = outputType

    def __call__(self):
        """Write the current step and advance the step counter."""
        step = self.number
        self.vtk.write(self.name, step, self.outputType)
        self.number = step + 1

def sequencedVTK(grid, name, celldata=None, pointdata=None, cellvector=None, pointvector=None,
                 number=0, subsampling=None, outputType=OutputType.appendedraw):
    """Return a `SequencedVTK` writer for ``grid`` starting at index ``number``."""
    writer = SequencedVTK(
        grid, name, number,
        celldata=celldata, pointdata=pointdata,
        cellvector=cellvector, pointvector=pointvector,
        subsampling=subsampling, outputType=outputType,
    )
    return writer

def plot(self, function=None, *args, **kwargs):
    """Plot this grid view, or a function on it, with ``dune.plotting``.

    Without ``function`` the grid itself is drawn by
    ``dune.plotting.plotGrid``. Otherwise ``function`` is drawn by
    ``dune.plotting.plot``; an object without a ``grid`` attribute is first
    converted into a grid function with ``self.function``.
    """
    import dune.plotting
    # compare against None explicitly: a grid function object may define
    # __len__/__bool__ and must not be mistaken for "no function given"
    if function is None:
        dune.plotting.plotGrid(self, *args, **kwargs)
        return
    if not hasattr(function, "grid"):
        function = self.function(function)
    dune.plotting.plot(solution=function, *args, **kwargs)

isGenerator = SimpleGenerator("GridViewIndexSet", "Dune::Python")
def indexSet(gv):
    """Return the index set of grid view ``gv``.

    If reading ``gv._indexSet`` raises TypeError (presumably because the
    C++ IndexSet type has no Python binding yet - TODO confirm), the binding
    module for ``<GridView>::IndexSet`` is loaded through ``isGenerator``
    and the lookup is repeated.
    """
    try:
        return gv._indexSet
    except TypeError:
        pass
    includes = gv.cppIncludes + ["dune/python/grid/indexset.hh"]
    indexSetType = gv.cppTypeName + "::IndexSet"
    isGenerator.load(includes, indexSetType, "indexset_" + hashIt(indexSetType))
    return gv._indexSet
mcmgGenerator = SimpleGenerator("MultipleCodimMultipleGeomTypeMapper", "Dune::Python")
def mapper(gv,layout):
    """Return ``gv._mapper(layout)`` after loading the mapper binding.

    The module for ``Dune::MultipleCodimMultipleGeomTypeMapper<GridView>``
    is loaded through ``mcmgGenerator`` before the mapper is requested.
    """
    includes = gv.cppIncludes + ["dune/python/grid/mapper.hh"]
    mapperType = "Dune::MultipleCodimMultipleGeomTypeMapper< " + gv.cppTypeName + " >"
    mcmgGenerator.load(includes, mapperType, "mcmgmapper_" + hashIt(mapperType))
    return gv._mapper(layout)

import functools
def gfPlot(gf, *args, **kwargs):
    """Plot grid function ``gf`` via the ``plot`` method of its grid view."""
    view = gf.gridView
    view.plot(gf, *args, **kwargs)
def callbackFunction(callback_,e,x):
    """Evaluate the global-coordinate ``callback_`` at local point ``x`` of element ``e``."""
    globalPoint = e.geometry.toGlobal(x)
    return callback_(globalPoint)

def _includeFileSource(includeFiles):
    """Turn the ``includeFiles`` argument of `function` into C++ text.

    Returns ``(source, includes)``. A file name without a directory part is
    opened (relative to the current working directory) and its contents are
    pasted into ``source``. A path with a directory part becomes an
    ``#include <...>`` line and is also recorded in ``includes``. An object
    with a ``readable`` attribute (e.g. ``io.StringIO``) is read and pasted.
    A list may mix both kinds of file names; any other value adds nothing.
    """
    source = ""
    includes = []

    def addFile(includeFile):
        nonlocal source
        if not os.path.dirname(includeFile):
            with open(includeFile, "r") as include:
                source += include.read()
            source += "\n"
        else:
            source += "#include <" + includeFile + ">\n"
            includes.append(includeFile)

    if isString(includeFiles):
        addFile(includeFiles)
    elif hasattr(includeFiles, "readable"):  # for IOString
        with includeFiles as include:
            source += include.read()
        source += "\n"
    elif isinstance(includeFiles, list):
        # every entry is handled on its own (the `else` belongs to the `if`,
        # not to the loop)
        for includeFile in includeFiles:
            addFile(includeFile)
    return source, includes


def _wrapModuleSource(moduleName, source, includes):
    """Add a header guard named after ``moduleName`` and the includes to ``source``.

    ``includes`` are de-duplicated and sorted and followed by the headers
    every grid-function module needs. The caller appends the module body
    and the closing ``#endif``.
    """
    guard = "#ifndef Guard_" + moduleName + "\n" + "#define Guard_" + moduleName + "\n\n"
    source = guard + source
    source += "".join("#include <" + i + ">\n" for i in sorted(set(includes)))
    source += "\n"
    source += '#include <dune/python/grid/function.hh>\n'
    source += '#include <dune/python/pybind11/pybind11.h>\n'
    source += '\n'
    return source


def _cppGridFunction(gv, callback, includeFiles, args):
    """Compile and return a grid function on ``gv`` backed by C++ code.

    The generated module calls ``callback<GridView>(arg0, ...)`` with the C++
    counterparts of ``args`` and registers the result as grid function. The
    definition of ``callback`` must be provided through ``includeFiles``.
    """
    if includeFiles is None:
        raise ValueError("""if `callback` is the name of a C++ function
            then at least one include file containing that function must be
            provided""")

    # the unique header guard is added by _wrapModuleSource
    source  = '#include <config.h>\n\n'
    source += '#define USING_DUNE_PYTHON 1\n\n'
    fileSource, includes = _includeFileSource(includeFiles)
    source += fileSource
    includes += gv.cppIncludes
    argTypes = []
    for arg in args:
        t, i = cppType(arg)
        argTypes.append(t)
        includes += i

    signature = callback + "( " + ", ".join(argTypes) + " )"
    moduleName = "gf_" + hashIt(signature) + "_" + hashIt(source)
    source = _wrapModuleSource(moduleName, source, includes)

    params = "".join(", " + t + " arg" + str(i) for i, t in enumerate(argTypes))
    callArgs = ",".join("arg" + str(i) for i in range(len(argTypes)))
    # keep gv (argument 1) and every extra argument alive with the result
    keepAlive = ",".join("pybind11::keep_alive<0," + str(i + 1) + ">()" for i in range(len(argTypes) + 1))
    source += "PYBIND11_MODULE( " + moduleName + ", module )\n"
    source += "{\n"
    source += "  module.def( \"gf\", [module] ( " + gv.cppTypeName + " &gv" + params + " ) {\n"
    source += "      auto callback=" + callback + "<" + gv.cppTypeName + ">( " + callArgs + "); \n"
    source += "      return Dune::Python::registerGridFunction<" + gv.cppTypeName + ",decltype(callback)>(module,pybind11::cast(gv),\"tmp\",callback);\n"
    source += "    },"
    source += "    " + keepAlive
    source += ");\n"
    source += "}\n"
    source += "#endif\n"
    return builder.load(moduleName, source, signature).gf(gv, *args)


def _pythonGridFunction(gv, callback, dimRange):
    """Return a grid function on ``gv`` that evaluates the Python ``callback``.

    A ``callback`` taking exactly one parameter is treated as a function of
    the global coordinate and wrapped via `callbackFunction`. The resulting
    grid function can then still be called with a global coordinate only.
    The generated grid-function class is cached per ``dimRange`` in
    ``gv.__class__._functions``.
    """
    if len(inspect.signature(callback).parameters) == 1:
        # global function, turn into a local function
        globalCallback = callback
        callback = functools.partial(callbackFunction, globalCallback)
    else:
        globalCallback = None

    if dimRange is None:
        # if no `dimRange` is given, evaluate the function on the first
        # element to determine the dimension of its return value. This can
        # fail if the function is singular in the computational domain, in
        # which case an exception is raised
        try:
            e = next(iter(gv.elements))
        except StopIteration:
            raise TypeError("Cannot determine dimension of range of "+
              "given grid function since the grid view has no elements. "+
              "Add a `dimRange` parameter to the grid function to "+
              "solve this issue - set `dimRange`=0 for a scalar function.") from None
        try:
            y = callback(e, e.referenceElement.position(0, 0))
        except ArithmeticError:
            try:
                y = callback(e, e.referenceElement.position(0, 2))
            except ArithmeticError:
                raise TypeError("Cannot determine dimension of range of "+
                  "given grid function due to arithmetic exceptions being "+
                  "raised. Add a `dimRange` parameter to the grid function to "+
                  "solve this issue - set `dimRange`=0 for a scalar function.")
        try:
            dimRange = len(y)
        except TypeError:
            dimRange = 0

    scalar = "false" if dimRange > 0 else "true"
    FieldVector(dimRange*[0])  # register FieldVector for the return value
    functions = gv.__class__._functions
    if dimRange not in functions:
        # the unique header guard is added by _wrapModuleSource
        source  = '#include <config.h>\n\n'
        source += '#define USING_DUNE_PYTHON 1\n\n'
        signature = gv.cppTypeName + "::gf<" + str(dimRange) + ">"
        moduleName = "gf_" + hashIt(signature) + "_" + hashIt(source)
        source = _wrapModuleSource(moduleName, source, gv.cppIncludes)
        source += "PYBIND11_MODULE( " + moduleName + ", module )\n"
        source += "{\n"
        source += "  typedef pybind11::function Evaluate;\n"
        source += "  Dune::Python::registerGridFunction< " + gv.cppTypeName + ", Evaluate, " + str(dimRange) + " >( module, \"gf\", " + scalar + " );\n"
        source += "}\n"
        source += "#endif\n"
        gfModule = builder.load(moduleName, source, signature)
        gfFunc = getattr(gfModule, "gf" + str(dimRange))
        # keep the generated (element, local) call; a call with a single
        # argument is forwarded to the global callback stored on the instance
        gfFunc._localCall = gfFunc.__call__
        def gfCall(self, e, x=None):
            if x is None:
                if not hasattr(self, "_globalCall"):
                    raise AttributeError("this grid function can not be called with a global coordinate")
                return self._globalCall(e)
            return self._localCall(e, x)
        gfFunc.__call__ = gfCall
        functions[dimRange] = gfFunc

    gf = functions[dimRange](gv, callback)
    if globalCallback is not None:  # allow to still call with only global coordinate
        gf._globalCall = globalCallback
    return gf


def function(gv,callback,includeFiles=None,*args,name=None,order=None,dimRange=None):
    """Create a grid function on the grid view ``gv``.

    Args:
        gv: the grid view.
        callback: either the name of a C++ function (a string, see
            ``includeFiles``), or a Python callable taking
            ``(element, localCoordinate)`` or only ``(globalCoordinate)``.
        includeFiles: required for a C++ ``callback``. It is a file name, a
            list of file names, or a readable object (e.g. ``io.StringIO``)
            providing the C++ function. Bare file names are read and pasted
            into the generated code; paths with a directory part are
            ``#include``d.
        *args: extra arguments forwarded to the C++ function.
        name: name of the grid function; defaults to ``"tmp<N>"`` using a
            counter stored on the grid view class.
        order: stored as ``gf.order``.
        dimRange: range dimension (0 for scalars) of a Python ``callback``.
            If omitted, it is determined by evaluating ``callback`` on the
            first element.

    Returns:
        The grid function, with ``plot``, ``name`` and ``order`` attributes set.

    Raises:
        ValueError: for a C++ ``callback`` without ``includeFiles``.
        TypeError: if ``dimRange`` is omitted and cannot be determined.
    """
    if name is None:
        name = "tmp"+str(gv._gfCounter)
        gv.__class__._gfCounter += 1
    if isString(callback):
        gf = _cppGridFunction(gv, callback, includeFiles, args)
    else:
        gf = _pythonGridFunction(gv, callback, dimRange)
    gf.plot = functools.partial(gfPlot, gf)
    gf.name = name
    gf.order = order
    return gf

def addAttr(module, cls):
    """Attach the Python-side grid view API to the bound class ``cls``.

    Installs the VTK helpers, the deprecated ``tesselate`` alias, the
    2d-only ``plot``/``triangulation`` helpers (which raise AttributeError
    for other dimensions), the ``indexSet`` property, ``mapper``,
    ``function``, and the bookkeeping attributes used by ``function``.
    """
    setattr(cls, "_module", module)
    setattr(cls, "writeVTK", writeVTK)
    setattr(cls, "sequencedVTK", sequencedVTK)
    # cache of generated grid-function classes keyed by dimRange (see `function`)
    setattr(cls, "_functions", {})

    @deprecated(name="dune.grid.GridView.tesselate", msg="Use 'tessellate' (note spelling)")
    def tesselate(gv, *args,**kwargs):
        return gv.tessellate(*args,**kwargs)
    setattr(cls,"tesselate", tesselate)

    if cls.dimension == 2:
        setattr(cls, "plot", plot)
        setattr(cls, "triangulation", triangulation)
    else:
        def throwFunc(msg):
            def throw(*args, **kwargs):
                raise AttributeError(msg)
            return throw
        setattr(cls, "plot", throwFunc("plot(...) only implemented on 2D grids"))
        setattr(cls, "triangulation", throwFunc("triangulation(...) only implemented on 2d grid"))

    cls.indexSet = property(indexSet)
    setattr(cls,"mapper",mapper)
    setattr(cls,"function",function)
    # counter used by `function` to generate default names "tmp<N>"
    setattr(cls,"_gfCounter",0)
def addHAttr(module):
    """Prepare the hierarchical-grid binding ``module`` for Python use.

    Registers the reference elements (via ``dune.geometry.module``) for every
    dimension from 0 up to the leaf grid's dimension. It then attaches
    `levelView` and `persistentContainer` to ``module.HierarchicalGrid``.
    """
    import dune.geometry
    dim = module.LeafGrid.dimension
    for codimDim in range(dim + 1):
        dune.geometry.module(codimDim)
    hierarchicalGrid = module.HierarchicalGrid
    hierarchicalGrid.levelView = levelView
    hierarchicalGrid.persistentContainer = persistentContainer

gvGenerator = SimpleGenerator("GridView", "Dune::Python")
def viewModule(includes, typeName, *args, **kwargs):
    """Load the Python binding module for grid view type ``typeName`` via ``gvGenerator``."""
    return gvGenerator.load(
        includes + ["dune/python/grid/gridview.hh"],
        typeName,
        "view_" + hashIt(typeName),
        *args,
        **kwargs,
    )

def levelView(hgrid,level):
    """Return the grid view of ``hgrid`` on refinement ``level``.

    The binding for ``<HierarchicalGrid>::LevelGridView`` is loaded first
    through `viewModule`.
    """
    viewModule(hgrid.cppIncludes, "typename " + hgrid.cppTypeName + "::LevelGridView")
    return hgrid._levelView(level)

pcGenerator = SimpleGenerator("PersistentContainer", "Dune::Python")
def persistentContainer(hgrid,codim,dimension):
    """Create a ``Dune::PersistentContainer`` on ``hgrid`` for codimension ``codim``.

    The container's value type is ``Dune::FieldVector<double, dimension>``.
    Its binding module is loaded through ``pcGenerator``.
    """
    includes = hgrid.cppIncludes + ["dune/python/grid/persistentcontainer.hh"]
    containerType = ("Dune::PersistentContainer<" + hgrid.cppTypeName
                     + ", Dune::FieldVector<double," + str(dimension) + ">>")
    pcModule = pcGenerator.load(includes, containerType,
                                "persistentcontainer_" + hashIt(containerType))
    return pcModule.PersistentContainer(hgrid, codim)

def module(includes, typeName, *args, **kwargs):
    """Load the hierarchical-grid binding module for ``typeName``.

    An optional ``generator`` keyword selects the generator. Without it, a
    new ``SimpleGenerator("HierarchicalGrid", "Dune::Python")`` is used.
    ``dynamicAttr=True`` and ``holder="std::shared_ptr"`` are always
    passed on to ``generator.load``.
    """
    if "generator" in kwargs:
        generator = kwargs.pop("generator")
    else:
        generator = SimpleGenerator("HierarchicalGrid", "Dune::Python")
    kwargs.update(dynamicAttr=True, holder="std::shared_ptr")
    return generator.load(
        includes + ["dune/python/grid/hierarchical.hh"],
        typeName,
        "hierarchicalgrid_" + hashIt(typeName),
        *args,
        **kwargs,
    )

if __name__ == "__main__":
    # run the doctests embedded in this module's docstrings
    from doctest import ELLIPSIS, testmod
    testmod(optionflags=ELLIPSIS)