File: test_ksp_py.py

package info (click to toggle)
petsc4py 3.24.3-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 3,612 kB
  • sloc: python: 13,569; ansic: 1,768; makefile: 345; f90: 313; sh: 14
file content (123 lines) | stat: -rw-r--r-- 3,007 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
# --------------------------------------------------------------------

from petsc4py import PETSc
import unittest
from test_ksp import BaseTestKSP

# --------------------------------------------------------------------


class MyKSP:
    """Base Python context for a ``PETSc.KSP`` of type ``python``.

    Holds a list of work vectors in ``self.work`` across the KSP lifecycle
    callbacks (``create``/``setUp``/``reset``/``destroy``).  It also
    provides :meth:`loop`, the per-iteration bookkeeping (residual norm,
    history, monitors, convergence test) shared by the concrete solvers.
    """

    def __init__(self):
        # Initialise eagerly so ``destroy``/``reset`` are safe even if
        # ``create`` was never invoked on this context.
        self.work = []

    def create(self, ksp):
        """Start with no work vectors; they are acquired in ``setUp``."""
        self.work = []

    def _release_work(self):
        """Destroy every held work vector and empty ``self.work`` in place."""
        for v in self.work:
            v.destroy()
        # Emptying the list ensures a later call cannot destroy the same
        # vectors a second time.
        del self.work[:]

    def destroy(self, ksp):
        """Destroy any work vectors still held; safe to call repeatedly."""
        self._release_work()

    def setUp(self, ksp):
        """Replace the contents of ``self.work`` with two vectors obtained
        from ``ksp.getWorkVecs(right=2, left=None)``."""
        self.work[:] = ksp.getWorkVecs(right=2, left=None)

    def reset(self, ksp):
        """Destroy the work vectors and leave ``self.work`` empty."""
        self._release_work()

    def loop(self, ksp, r):
        """Record one iteration for vector ``r`` and return the reason.

        Computes ``r.norm()``, stores it as the KSP residual norm, appends
        it to the convergence history, calls the monitors, then runs the
        convergence test.  A falsy reason advances the iteration counter;
        a truthy reason is stored with ``setConvergedReason`` and the
        counter is left unchanged.

        Returns the value from ``ksp.callConvergenceTest``, so solvers can
        iterate with ``while not self.loop(ksp, r): ...``.
        """
        its = ksp.getIterationNumber()
        rnorm = r.norm()
        ksp.setResidualNorm(rnorm)
        ksp.logConvergenceHistory(rnorm)
        ksp.monitor(its, rnorm)
        reason = ksp.callConvergenceTest(its, rnorm)
        if not reason:
            ksp.setIterationNumber(its + 1)
        else:
            ksp.setConvergedReason(reason)
        return reason


class MyRichardson(MyKSP):
    """Preconditioned Richardson iteration: ``x <- x + P (b - A x)``.

    The quantity handed to :meth:`MyKSP.loop` for the convergence check is
    the preconditioned update ``P (b - A x)``, not the raw residual.
    """

    def solve(self, ksp, b, x):
        """Update ``x`` in place until ``self.loop`` returns a truthy reason.

        Each pass forms ``b - A x``, applies the preconditioner to it, adds
        the result to ``x`` and then checks convergence.  At least one
        update is applied before the first check.
        """
        mat, _ = ksp.getOperators()
        pc = ksp.getPC()
        res, upd = self.work
        while True:
            mat.mult(x, res)        # res = A x
            res.aypx(-1, b)         # res = b - A x
            pc.apply(res, upd)      # upd = P res
            x.axpy(1, upd)          # x = x + upd
            if self.loop(ksp, upd):
                break


class MyCG(MyKSP):
    """Conjugate gradient context that never applies the preconditioner.

    Uses four work vectors: the two obtained by :meth:`MyKSP.setUp`
    followed by the search direction ``d`` and the product ``q = A d``.
    """

    def setUp(self, ksp):
        """Acquire the base work vectors, then append ``d`` and ``q``."""
        super().setUp(ksp)
        d = self.work[0].duplicate()
        q = d.duplicate()
        self.work += [d, q]

    def solve(self, ksp, b, x):
        """Run CG on ``A x = b``, updating ``x`` in place.

        Convergence is checked through :meth:`MyKSP.loop` on the residual
        ``r`` before every iteration, including the initial one.
        """
        A, _ = ksp.getOperators()
        # The second work vector is unused: no preconditioner is applied.
        r, _, d, q = self.work
        A.mult(x, r)
        r.aypx(-1, b)              # r = b - A x
        r.copy(d)                  # d = r (initial search direction)
        delta = r.dot(r)
        while not self.loop(ksp, r):
            A.mult(d, q)           # q = A d
            alpha = delta / d.dot(q)
            x.axpy(+alpha, d)      # x = x + alpha d
            r.axpy(-alpha, q)      # r = r - alpha q
            delta_old = delta
            delta = r.dot(r)
            beta = delta / delta_old
            d.aypx(beta, r)        # d = r + beta d


# --------------------------------------------------------------------


class BaseTestKSPPYTHON(BaseTestKSP):
    """Shared tests for a ``KSP`` of type ``python``.

    Concrete subclasses mix in ``unittest.TestCase`` and set
    ``ContextClass`` to the Python solver context.  A fresh instance of
    that class is attached to ``self.ksp`` in :meth:`setUp`.
    """

    # presumably read by BaseTestKSP.setUp to choose the KSP type — confirm
    KSP_TYPE = PETSc.KSP.Type.PYTHON
    # Must be overridden by concrete test cases; it is called in setUp.
    ContextClass = None

    def setUp(self):
        # self.ksp is presumably created by BaseTestKSP.setUp — confirm
        super().setUp()
        ctx = self.ContextClass()
        self.ksp.setPythonContext(ctx)

    def testGetType(self):
        """``getPythonType`` reports ``<module>.<class name>`` of the context."""
        ctx = self.ksp.getPythonContext()
        pytype = f'{ctx.__module__}.{type(ctx).__name__}'
        # assertEqual shows both values on failure, unlike assertTrue(a == b).
        self.assertEqual(self.ksp.getPythonType(), pytype)

    def tearDown(self):
        # Destroy the KSP explicitly, then call PETSc.garbage_cleanup().
        # NOTE(review): this does not chain to BaseTestKSP.tearDown —
        # confirm that is intended.
        self.ksp.destroy()
        PETSc.garbage_cleanup()


class TestKSPPYTHON_RICH(BaseTestKSPPYTHON, unittest.TestCase):
    """Python KSP driven by the ``MyRichardson`` context with a Jacobi PC."""

    ContextClass = MyRichardson
    PC_TYPE = PETSc.PC.Type.JACOBI


class TestKSPPYTHON_CG(BaseTestKSPPYTHON, unittest.TestCase):
    """Python KSP driven by the ``MyCG`` context with PC type ``none``."""

    ContextClass = MyCG
    PC_TYPE = PETSc.PC.Type.NONE


# --------------------------------------------------------------------

if __name__ == '__main__':
    # Allow running this test module directly as a script.
    unittest.main()