File: test_esv.py

package info (click to toggle)
python-scipy 0.10.1+dfsg2-1
  • links: PTS, VCS
  • area: main
  • in suites: wheezy
  • size: 42,232 kB
  • sloc: cpp: 224,773; ansic: 103,496; python: 85,210; fortran: 79,130; makefile: 272; sh: 43
file content (138 lines) | stat: -rw-r--r-- 4,801 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
import numpy as np
from numpy.testing import TestCase, assert_array_almost_equal, dec, \
                          assert_equal, assert_

from common import FUNCS_TP, FLAPACK_IS_EMPTY, CLAPACK_IS_EMPTY, FUNCS_FLAPACK, \
                   FUNCS_CLAPACK, PREC

# Symmetric 3x3 integer matrix used as input by the tests in this module;
# each test casts it to the dtype of the routine under test.
SYEV_ARG = np.array([
    [1, 2, 3],
    [2, 2, 3],
    [3, 3, 6],
])

# Reference eigenvalues that the tests expect for SYEV_ARG, listed in
# ascending order.
SYEV_REF = np.array([
    -0.6699243371851365,
    0.4876938861533345,
    9.182230451031804,
])

class TestEsv(TestCase):
    """Exercise the ``syev``/``syevr`` wrappers from the flapack/clapack tables.

    Every check runs on ``SYEV_ARG`` cast to the dtype that ``FUNCS_TP``
    lists for the routine, and compares results against ``SYEV_REF`` using
    the number of decimals that ``PREC`` lists for that dtype.
    """

    def _get_func(self, func, lang):
        """Look up the wrapper named `func` in the table selected by `lang`.

        Parameters
        ----------
        func : str
            Routine name used as the table key, e.g. ``'ssyev'``.
        lang : str
            ``'C'`` selects ``FUNCS_CLAPACK``; ``'F'`` selects
            ``FUNCS_FLAPACK``.

        Returns
        -------
        The object stored under `func` in the selected table.

        Raises
        ------
        ValueError
            If `lang` is neither ``'C'`` nor ``'F'``.
        """
        if lang == 'C':
            return FUNCS_CLAPACK[func]
        elif lang == 'F':
            return FUNCS_FLAPACK[func]
        raise ValueError("Lang %s ??" % lang)

    def _check_eigenpairs(self, a, w, v, decimal):
        """Assert ``a . v[:, i] ~= w[i] * v[:, i]`` for every entry of `w`.

        Parameters
        ----------
        a : ndarray
            Matrix that was passed to the routine.
        w : sequence
            Returned eigenvalues; its length sets how many columns of `v`
            are checked.
        v : ndarray
            Returned eigenvectors, one per column.
        decimal : int
            Precision forwarded to ``assert_array_almost_equal``.
        """
        for i in range(len(w)):
            assert_array_almost_equal(np.dot(a, v[:, i]), w[i]*v[:, i],
                                      decimal=decimal)

    def _test_base(self, func, lang):
        """Compute all eigenpairs of ``SYEV_ARG`` with `func` and verify them.

        The call must return a falsy ``info``, eigenvalues matching
        ``SYEV_REF``, and vectors satisfying the eigenpair relation.
        """
        tp = FUNCS_TP[func]
        a = SYEV_ARG.astype(tp)
        f = self._get_func(func, lang)

        w, v, info = f(a)

        assert_(not info, msg=repr(info))
        assert_array_almost_equal(w, SYEV_REF, decimal=PREC[tp])
        self._check_eigenpairs(a, w, v, PREC[tp])

    def _test_base_irange(self, func, irange, lang):
        """Verify `func` when called with an index range.

        `irange` is a ``[first, last]`` pair of zero-based, inclusive
        indices into the ascending reference eigenvalues ``SYEV_REF``.
        """
        tp = FUNCS_TP[func]
        a = SYEV_ARG.astype(tp)
        f = self._get_func(func, lang)

        w, v, info = f(a, irange=irange)
        # Both bounds are inclusive, hence the +1 on the slice stop and count.
        rslice = slice(irange[0], irange[1]+1)
        m = irange[1] - irange[0] + 1
        assert_(not info, msg=repr(info))

        assert_equal(len(w), m)
        assert_array_almost_equal(w, SYEV_REF[rslice], decimal=PREC[tp])
        self._check_eigenpairs(a, w, v, PREC[tp])

    def _test_base_vrange(self, func, vrange, lang):
        """Verify `func` when called with a value range.

        `vrange` is a ``[low, high]`` pair; the expected eigenvalues are the
        entries of ``SYEV_REF`` in the half-open interval ``(low, high]``.
        """
        tp = FUNCS_TP[func]
        a = SYEV_ARG.astype(tp)
        # Lower bound exclusive, upper bound inclusive.
        ew = [value for value in SYEV_REF if vrange[0] < value <= vrange[1]]
        f = self._get_func(func, lang)

        w, v, info = f(a, vrange=vrange)
        assert_(not info, msg=repr(info))

        assert_array_almost_equal(w, ew, decimal=PREC[tp])
        self._check_eigenpairs(a, w, v, PREC[tp])

    def _test_syevr_ranges(self, func, lang):
        """Run the index-range and value-range checks over fixed ranges."""
        for irange in ([0, 2], [0, 1], [1, 1], [1, 2]):
            self._test_base_irange(func, irange, lang)

        for vrange in ([-1, 10], [-1, 1], [0, 1], [1, 10]):
            self._test_base_vrange(func, vrange, lang)

    # Flapack tests
    @dec.skipif(FLAPACK_IS_EMPTY, "Flapack empty, skip flapack test")
    def test_ssyev(self):
        self._test_base('ssyev', 'F')

    @dec.skipif(FLAPACK_IS_EMPTY, "Flapack empty, skip flapack test")
    def test_dsyev(self):
        self._test_base('dsyev', 'F')

    @dec.skipif(FLAPACK_IS_EMPTY, "Flapack empty, skip flapack test")
    def test_ssyevr(self):
        self._test_base('ssyevr', 'F')

    @dec.skipif(FLAPACK_IS_EMPTY, "Flapack empty, skip flapack test")
    def test_dsyevr(self):
        self._test_base('dsyevr', 'F')

    @dec.skipif(FLAPACK_IS_EMPTY, "Flapack empty, skip flapack test")
    def test_ssyevr_ranges(self):
        self._test_syevr_ranges('ssyevr', 'F')

    @dec.skipif(FLAPACK_IS_EMPTY, "Flapack empty, skip flapack test")
    def test_dsyevr_ranges(self):
        self._test_syevr_ranges('dsyevr', 'F')

    # Clapack tests
    # Each is also skipped when the clapack table entry for the routine is
    # falsy.
    @dec.skipif(CLAPACK_IS_EMPTY or not FUNCS_CLAPACK["ssyev"],
                "Clapack empty, skip clapack test")
    def test_clapack_ssyev(self):
        self._test_base('ssyev', 'C')

    @dec.skipif(CLAPACK_IS_EMPTY or not FUNCS_CLAPACK["dsyev"],
                "Clapack empty, skip clapack test")
    def test_clapack_dsyev(self):
        self._test_base('dsyev', 'C')

    @dec.skipif(CLAPACK_IS_EMPTY or not FUNCS_CLAPACK["ssyevr"],
                "Clapack empty, skip clapack test")
    def test_clapack_ssyevr(self):
        self._test_base('ssyevr', 'C')

    @dec.skipif(CLAPACK_IS_EMPTY or not FUNCS_CLAPACK["dsyevr"],
                "Clapack empty, skip clapack test")
    def test_clapack_dsyevr(self):
        self._test_base('dsyevr', 'C')

    @dec.skipif(CLAPACK_IS_EMPTY or not FUNCS_CLAPACK["ssyevr"],
                "Clapack empty, skip clapack test")
    def test_clapack_ssyevr_ranges(self):
        self._test_syevr_ranges('ssyevr', 'C')

    @dec.skipif(CLAPACK_IS_EMPTY or not FUNCS_CLAPACK["dsyevr"],
                "Clapack empty, skip clapack test")
    def test_clapack_dsyevr_ranges(self):
        self._test_syevr_ranges('dsyevr', 'C')

if __name__ == "__main__":
    # Imported here because the module-level numpy.testing import does not
    # include run_module_suite; without this, running the file as a script
    # fails with NameError.
    from numpy.testing import run_module_suite
    run_module_suite()