File: test_lsqr.py

package info (click to toggle)
python-scipy 0.18.1-2
  • links: PTS, VCS
  • area: main
  • in suites: stretch
  • size: 75,464 kB
  • ctags: 79,406
  • sloc: python: 143,495; cpp: 89,357; fortran: 81,650; ansic: 79,778; makefile: 364; sh: 265
file content (121 lines) | stat: -rw-r--r-- 3,591 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
from __future__ import division, print_function, absolute_import

import numpy as np
from numpy.testing import (assert_, assert_equal, assert_almost_equal,
                           assert_array_almost_equal)
from scipy._lib.six import xrange

import scipy.sparse
import scipy.sparse.linalg
from scipy.sparse.linalg import lsqr
from time import time

# Set up a test problem
#
# The random draws come from a dedicated, seeded generator so that every run
# of the test-suite solves exactly the same system (failures are then
# reproducible) and the global NumPy random state is left untouched.
n = 35
G = np.eye(n)
_rng = np.random.RandomState(1234)
normal = _rng.normal
norm = np.linalg.norm

for _ in range(5):
    gg = normal(size=n)
    # NOTE: `gg` is 1-D, so `.T` is a no-op and `gg * gg.T` is the elementwise
    # square rather than an outer product.  Each update below therefore adds
    # one row vector, broadcast over every row of G.
    hh = gg * gg.T
    G += (hh + hh.T) * 0.5
    G += normal(size=n) * normal(size=n)

# Right-hand side of the system G x = b used by the tests below.
b = normal(size=n)

# Solver settings shared by the tests: stopping tolerances (passed as both
# `atol` and `btol`), verbosity flag and iteration limit (None lets lsqr pick
# its own default).
tol = 1e-10
show = False
maxit = None


def test_basic():
    # LSQR must agree with the dense direct solver on the module-level
    # system G x = b to within 1e-5 in the 2-norm.
    expected = np.linalg.solve(G, b)
    result = lsqr(G, b, show=show, atol=tol, btol=tol, iter_lim=maxit)
    error = norm(expected - result[0])
    assert_(error < 1e-5)


def test_gh_2466():
    # Smoke test: solving a 1x2 integer-valued sparse system with an integer
    # right-hand side must simply run to completion.
    rows = np.array([0, 0])
    cols = np.array([0, 1])
    data = np.array([1, -1])
    matrix = scipy.sparse.coo_matrix((data, (rows, cols)), shape=(1, 2))
    rhs = np.asarray([4])
    lsqr(matrix, rhs)


def test_well_conditioned_problems():
    # Check that the sparse lsqr solver returns the right solution on
    # identity systems generated from a range of random seeds.
    # This is a non-regression test for a potential ZeroDivisionError
    # raised when computing the `test2` & `test3` convergence conditions.
    size = 10
    identity_sparse = scipy.sparse.eye(size, size)
    identity_dense = identity_sparse.toarray()

    # Turn invalid floating-point operations into exceptions for the
    # duration of the solves.
    with np.errstate(invalid='raise'):
        for seed in range(10, 40):
            truth = np.random.RandomState(seed).rand(size)
            # Make sure no entry of the ground truth is exactly zero.
            truth[truth == 0] = 0.00001
            rhs = identity_sparse * truth[:, np.newaxis]
            result = lsqr(identity_sparse, rhs, show=show)
            estimate, istop = result[0], result[1]

            # The termination condition must correspond to an approximate
            # solution to Ax = b.
            assert_equal(istop, 1)

            # The ground truth solution must be recovered.
            assert_array_almost_equal(estimate, truth)

            # Sanity check: compare to the dense array solver.
            dense_estimate = np.linalg.solve(identity_dense, rhs).ravel()
            assert_array_almost_equal(estimate, dense_estimate)


def test_b_shapes():
    # lsqr must accept right-hand sides that are not plain 1-D vectors.
    cases = [
        # b given as a scalar.
        (np.array([[1.0, 2.0]]), 3.0),
        # b given as a column vector.
        (np.eye(10), np.ones((10, 1))),
    ]
    for matrix, rhs in cases:
        solution = lsqr(matrix, rhs)[0]
        # Flatten the right-hand side so it lines up with the 1-D product.
        residual = matrix.dot(solution) - np.ravel(rhs)
        assert_almost_equal(norm(residual), 0)


if __name__ == "__main__":
    # Manual diagnostic run: solve the module-level system G x = b with both
    # the dense direct solver and LSQR, then print a comparison report.
    svx = np.linalg.solve(G, b)

    tic = time()
    X = lsqr(G, b, show=show, atol=tol, btol=tol, iter_lim=maxit)
    # Stop the clock immediately so the reported time covers only the solve,
    # not the bookkeeping and printing that follow.
    elapsed = time() - tic

    # Entries of the tuple returned by lsqr (see the scipy documentation):
    # solution, iteration count, and the solver's own estimates of ||r||,
    # ||A^T r|| and ||x||.
    xo = X[0]
    k = X[2]
    phio = X[3]
    psio = X[7]
    chio = X[8]

    # G is a dense ndarray, for which `*` multiplies elementwise (with
    # broadcasting) instead of forming a matrix-vector product; use `.dot`
    # so the directly computed residuals are the real ones.
    lsqr_residual = G.dot(xo) - b
    lsqr_normal_residual = G.T.dot(lsqr_residual)
    direct_residual = G.dot(svx) - b

    # Largest entry of G - G^T decides whether G is reported as symmetric.
    mg = np.amax(G - G.T)
    sym = 'No' if mg > 1e-14 else 'Yes'

    print('LSQR')
    print("Is linear operator symmetric? " + sym)
    print("n: %3g  iterations:   %3g" % (n, k))
    print("Norms computed in %.2fs by LSQR" % elapsed)
    print(" ||x||  %9.4e  ||r|| %9.4e  ||Ar||  %9.4e " % (chio, phio, psio))
    print("Residual norms computed directly:")
    print(" ||x||  %9.4e  ||r|| %9.4e  ||Ar||  %9.4e" % (
        norm(xo), norm(lsqr_residual), norm(lsqr_normal_residual)))
    print("Direct solution norms:")
    print(" ||x||  %9.4e  ||r|| %9.4e " % (norm(svx), norm(direct_residual)))
    print("")
    print(" || x_{direct} - x_{LSQR}|| %9.4e " % norm(svx-xo))
    print("")