1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108
|
# --------------------------------------------------------------------
from petsc4py import PETSc
import unittest
from test_snes import BaseTestSNES
# --------------------------------------------------------------------
class MySNES:
    """Python SNES context that counts, and optionally traces, its callbacks.

    Every method below forwards to ``_log``, which tallies the call in
    ``call_log``. When ``trace`` is true each call is also echoed to stdout.
    An instance is installed on a solver through ``setPythonContext`` by the
    test case in this file.
    """

    def __init__(self):
        # When True, every logged call is printed as it happens.
        self.trace = False
        # Maps method name -> number of times it has been logged.
        self.call_log = {}

    def _log(self, method, *args):
        """Count one call of ``method`` and print it when tracing is on.

        ``args`` are the arguments the callback received; they are only
        used for the trace line.
        """
        self.call_log[method] = self.call_log.get(method, 0) + 1
        if not self.trace:
            return
        clsname = self.__class__.__name__
        # Show PETSc objects by type name only; other arguments are shown
        # as-is inside the printed tuple.
        pargs = tuple(
            type(a).__name__ if isinstance(a, PETSc.Object) else a
            for a in args
        )
        print(f'{clsname}.{method}{pargs}')

    def create(self, *args):
        """Log the ``create`` callback."""
        self._log('create', *args)

    def destroy(self, *args):
        """Log the ``destroy`` callback and, when tracing, dump all counts."""
        self._log('destroy', *args)
        if not self.trace:
            return
        for k, v in self.call_log.items():
            print(f'{k} {v}')

    def view(self, snes, viewer):
        """Log the ``view`` callback."""
        self._log('view', snes, viewer)

    def setFromOptions(self, snes):
        """Read the boolean ``trace`` option, then log the callback.

        The option is read before logging, so this very call is traced
        when the option switches tracing on.
        """
        OptDB = PETSc.Options(snes)
        self.trace = OptDB.getBool('trace', self.trace)
        self._log('setFromOptions', snes)

    def setUp(self, snes):
        """Log the ``setUp`` callback."""
        self._log('setUp', snes)

    def reset(self, snes):
        """Log the ``reset`` callback."""
        self._log('reset', snes)

    # Optional hooks, intentionally left disabled:
    #
    # def preSolve(self, snes):
    #     self._log('preSolve', snes)
    #
    # def postSolve(self, snes):
    #     self._log('postSolve', snes)

    def preStep(self, snes):
        """Log the ``preStep`` callback."""
        self._log('preStep', snes)

    def postStep(self, snes):
        """Log the ``postStep`` callback."""
        self._log('postStep', snes)

    # Further optional hooks, intentionally left disabled:
    #
    # def computeFunction(self, snes, x, F):
    #     self._log('computeFunction', snes, x, F)
    #     snes.computeFunction(x, F)
    #
    # def computeJacobian(self, snes, x, A, B):
    #     self._log('computeJacobian', snes, x, A, B)
    #     flag = snes.computeJacobian(x, A, B)
    #     return flag
    #
    # def linearSolve(self, snes, b, x):
    #     self._log('linearSolve', snes, b, x)
    #     snes.ksp.solve(b,x)
    #     ## return False # not succeed
    #     if snes.ksp.getConvergedReason() < 0:
    #         return False # not succeed
    #     return True # succeed
    #
    # def lineSearch(self, snes, x, y, F):
    #     self._log('lineSearch', snes, x, y, F)
    #     x.axpy(-1,y)
    #     snes.computeFunction(x, F)
    #     ## return False # not succeed
    #     return True # succeed
class TestSNESPython(BaseTestSNES, unittest.TestCase):
    """SNES test case that uses the Python SNES type with a ``MySNES`` context.

    Inherits from ``BaseTestSNES`` (imported from ``test_snes``) and selects
    the solver type through ``SNES_TYPE``.
    """

    SNES_TYPE = PETSc.SNES.Type.PYTHON

    def setUp(self):
        """Run the base fixture, then attach a fresh ``MySNES`` context."""
        # Presumably the base setUp creates self.snes - confirm in test_snes.
        super().setUp()
        self.snes.setPythonContext(MySNES())

    def testGetType(self):
        """Check getPythonType() equals ``<module>.<class>`` of the context."""
        ctx = self.snes.getPythonContext()
        pytype = f'{ctx.__module__}.{type(ctx).__name__}'
        # assertEqual reports both values on failure, unlike assertTrue(a == b).
        self.assertEqual(self.snes.getPythonType(), pytype)
# --------------------------------------------------------------------
if __name__ == '__main__':
    # Run this module's tests only when it is executed as a script.
    unittest.main()
# --------------------------------------------------------------------
|