File: test_iterative.py

package info (click to toggle)
python-scipy 0.7.2%2Bdfsg1-1
  • links: PTS, VCS
  • area: main
  • in suites: squeeze
  • size: 28,500 kB
  • ctags: 36,081
  • sloc: cpp: 216,880; fortran: 76,016; python: 71,576; ansic: 62,118; makefile: 243; sh: 17
file content (201 lines) | stat: -rw-r--r-- 6,148 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
#!/usr/bin/env python
""" Test functions for the sparse.linalg.isolve module
"""
import sys

from numpy.testing import *

from numpy import zeros, ones, arange, array, abs, max
from scipy.linalg import norm
from scipy.sparse import spdiags, csr_matrix

from scipy.sparse.linalg.interface import LinearOperator
from scipy.sparse.linalg.isolve import cg, cgs, bicg, bicgstab, gmres, qmr, minres

#def callback(x):
#    global A, b
#    res = b-dot(A,x)
#    #print "||A.x - b|| = " + str(norm(dot(A,x)-b))


#TODO check that methods preserve shape and type
#TODO test complex matrices
#TODO test both preconditioner methods

# Dimension of the 1-D Poisson test matrix.
N = 40

# Tridiagonal stencil: 2 on the main diagonal, -1 on the sub- and
# super-diagonal.  One row of `data` per diagonal, in the same order as
# the offsets handed to spdiags.
data = ones((3, N)) * array([[2.0], [-1.0], [-1.0]])
Poisson1D = spdiags(data, [0, -1, 1], N, N, format='csr')

# 10x10 diagonal matrix whose entries have mixed signs.
data = array([[6, -5, 2, 7, -1, 10, 4, -3, -8, 9]], dtype='d')
RandDiag = spdiags(data, [0], 10, 10, format='csr')

class TestIterative(TestCase):
    """Checks applied uniformly to every iterative solver.

    Each solver is registered together with the matrix properties it
    requires (symmetry, positive definiteness) and is only run against
    test matrices that satisfy those requirements.
    """

    def setUp(self):
        """Register the solvers under test and the matrices to run them on."""
        # list of tuples (solver, symmetric, positive_definite )
        self.solvers = [
            (cg,       True,  True),
            (cgs,      False, False),
            (bicg,     False, False),
            (bicgstab, False, False),
            (gmres,    False, False),
            (qmr,      False, False),
            (minres,   True,  False),
        ]

        # list of tuples (A, symmetric, positive_definite )
        self.cases = [
            # Symmetric and Positive Definite
            (Poisson1D, True, True),
            # Symmetric and Negative Definite
            (-Poisson1D, True, False),
            # Symmetric and Indefinite
            (RandDiag, True, False),
        ]

        # Non-symmetric and Positive Definite
        # bicg and cgs fail to converge on this one
        #data = ones((2,10))
        #data[0,:] =  2
        #data[1,:] = -1
        #A = spdiags( data, [0,-1], 10, 10, format='csr')
        #self.cases.append( (A,False,True) )

    def _applicable(self):
        """Yield (solver, A) for every case that meets the solver's needs."""
        for solver, req_sym, req_pos in self.solvers:
            for A, sym, pos in self.cases:
                if req_sym and not sym:
                    continue
                if req_pos and not pos:
                    continue
                yield solver, A

    def _check_solution(self, solver, A, b, x, info, tol):
        """Assert that `solver` reported success and met the residual bound.

        Uses unittest assertions rather than the `assert` statement so the
        check is still performed when Python runs with -O.
        """
        label = getattr(solver, '__name__', repr(solver))
        assert_equal(info, 0,
                     err_msg='%s returned a nonzero info flag' % label)
        self.assertTrue(norm(b - A*x) < tol*norm(b),
                        '%s: residual exceeds tol*norm(b)' % label)

    def test_maxiter(self):
        """test whether maxiter is respected"""

        # Poisson1D is registered as symmetric and positive definite, so
        # every solver's requirements are met and no filtering is needed.
        A = Poisson1D
        tol = 1e-12
        maxiter = 3

        for solver, _req_sym, _req_pos in self.solvers:
            label = getattr(solver, '__name__', repr(solver))
            b  = arange(A.shape[0], dtype=float)
            x0 = 0*b

            residuals = []
            def callback(x):
                residuals.append( norm(b - A*x) )

            x, info = solver(A, b, x0=x0, tol=tol, maxiter=maxiter,
                             callback=callback)

            # One callback invocation per iteration, and the info flag
            # is expected to equal the iteration limit.
            assert_equal(len(residuals), maxiter,
                         err_msg='%s: unexpected callback count' % label)
            assert_equal(info, maxiter,
                         err_msg='%s: unexpected info flag' % label)

    def test_convergence(self):
        """test whether all methods converge"""

        tol = 1e-8

        for solver, A in self._applicable():
            label = getattr(solver, '__name__', repr(solver))
            b  = arange(A.shape[0], dtype=float)
            x0 = 0*b

            x, info = solver(A, b, x0=x0, tol=tol)

            #ensure that x0 is not overwritten
            assert_array_equal(x0, 0*b,
                               err_msg='%s modified x0' % label)
            self._check_solution(solver, A, b, x, info, tol)

    def test_precond(self):
        """test whether all methods accept a trivial preconditioner"""

        tol = 1e-8

        def identity(b,which=None):
            """trivial preconditioner"""
            return b

        for solver, A in self._applicable():
            b  = arange(A.shape[0], dtype=float)
            x0 = 0*b

            precond = LinearOperator(A.shape, identity, rmatvec=identity)

            # qmr takes a pair of preconditioners instead of a single M.
            if solver is qmr:
                x, info = solver(A, b, M1=precond, M2=precond, x0=x0, tol=tol)
            else:
                x, info = solver(A, b, M=precond, x0=x0, tol=tol)
            self._check_solution(solver, A, b, x, info, tol)

            # Second pass: attach the preconditioner as psolve/rpsolve
            # attributes (presumably the attribute-based way of supplying
            # one -- confirm against the solver implementation).  A copy
            # is used so the shared test matrices are left untouched.
            A_attr = A.copy()
            A_attr.psolve  = identity
            A_attr.rpsolve = identity

            x, info = solver(A_attr, b, x0=x0, tol=tol)
            self._check_solution(solver, A_attr, b, x, info, tol)


class TestQMR(TestCase):
    """QMR-specific checks."""

    def test_leftright_precond(self):
        """Check that QMR works with left and right preconditioners"""

        # LinearOperator comes from the module-level import; only splu
        # needs to be imported here.
        from scipy.sparse.linalg.dsolve import splu

        n = 100
        tol = 1e-8

        dat = ones(n)
        A = spdiags([-2*dat, 4*dat, -dat], [-1,0,1] ,n,n)
        b = arange(n,dtype='d')

        # Lower and upper bidiagonal matrices whose factorizations supply
        # the M1 and M2 preconditioners passed to qmr.
        L = spdiags([-dat/2, dat], [-1,0], n, n)
        U = spdiags([4*dat, -dat], [ 0,1], n, n)

        L_solver = splu(L)
        U_solver = splu(U)

        def L_solve(b):
            return L_solver.solve(b)
        def U_solve(b):
            return U_solver.solve(b)
        # The 'T' argument selects the transposed solve, used for rmatvec.
        def LT_solve(b):
            return L_solver.solve(b,'T')
        def UT_solve(b):
            return U_solver.solve(b,'T')

        M1 = LinearOperator( (n,n), matvec=L_solve, rmatvec=LT_solve )
        M2 = LinearOperator( (n,n), matvec=U_solve, rmatvec=UT_solve )

        x,info = qmr(A, b, tol=tol, maxiter=15, M1=M1, M2=M2)

        assert_equal(info,0)
        # unittest assertion instead of `assert`, which is removed under -O.
        self.assertTrue(norm(b - A*x) < tol*norm(b),
                        'qmr: residual exceeds tol*norm(b)')


class TestGMRES(TestCase):
    """GMRES-specific checks."""

    def test_callback(self):
        """Check the values GMRES passes to the callback."""

        #Define, A,b
        A = csr_matrix(array([[-2,1,0,0,0,0],[1,-2,1,0,0,0],[0,1,-2,1,0,0],[0,0,1,-2,1,0],[0,0,0,1,-2,1],[0,0,0,0,1,-2]]))
        b = ones((A.shape[0],))
        maxiter=1

        # Seeded with 1.0, then one entry appended per callback call.
        # Appending (rather than locating the last nonzero slot of a
        # preallocated array) stays correct when a reported value is
        # exactly zero and cannot index out of bounds.
        residuals = [1.0]

        def store_residual(r):
            # presumably r is a scalar residual norm -- confirm against gmres
            residuals.append(float(r))

        x,flag = gmres(A, b, x0=zeros(A.shape[0]), tol=1e-16, maxiter=maxiter, callback=store_residual)

        expected = array([1.0,   0.81649658092772603])
        assert_equal(len(residuals), maxiter+1,
                     err_msg='gmres: unexpected number of callback calls')
        # `max` and `abs` are the numpy versions imported at module level.
        diff = max(abs(array(residuals) - expected))
        # unittest assertion instead of `assert`, which is removed under -O.
        self.assertTrue(diff < 1e-5,
                        'gmres: callback values differ from expected')


if __name__ == "__main__":
    # Run this file's tests when it is executed directly.  nose is
    # imported here because no explicit import of it appears above; this
    # avoids depending on the star import to provide the name.
    import nose
    nose.run(argv=['', __file__])