1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140
|
#!/usr/bin/env python
#
# Created by: Pearu Peterson, September 2002
#
from __future__ import division, print_function, absolute_import
from numpy.testing import TestCase, run_module_suite, assert_equal, \
assert_array_almost_equal, assert_, assert_raises, assert_allclose
import numpy as np
from scipy.linalg import _flapack as flapack
from scipy.linalg import inv
try:
from scipy.linalg import _clapack as clapack
except ImportError:
clapack = None
from scipy.linalg.lapack import get_lapack_funcs
# Scalar types used to parametrize the dtype loops in the tests below:
# the real ones, the complex ones, and their concatenation.
REAL_DTYPES = [np.float32, np.float64]
COMPLEX_DTYPES = [np.complex64, np.complex128]
DTYPES = REAL_DTYPES + COMPLEX_DTYPES
class TestFlapackSimple(TestCase):
    """Basic checks of a few routines exposed by ``flapack``."""

    def test_gebal(self):
        """Call ``?gebal`` for every precision prefix and check its outputs.

        With default options the 3x3 input must come back unchanged, with
        ``(lo, hi)`` spanning the whole matrix and unit ``pivscale``.  The
        4x4 input is only checked for a zero ``info`` when ``permute`` and
        ``scale`` are switched on.
        """
        a = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
        a1 = [[1, 0, 0, 3e-4],
              [4, 0, 0, 2e-3],
              [7, 1, 0, 0],
              [0, 1, 0, 0]]
        for p in 'sdzc':
            f = getattr(flapack, p + 'gebal', None)
            if f is None:
                # This flapack build does not provide the routine.
                continue
            ba, lo, hi, pivscale, info = f(a)
            assert_(not info, repr(info))
            assert_array_almost_equal(ba, a)
            assert_equal((lo, hi), (0, len(a[0]) - 1))
            assert_array_almost_equal(pivscale, np.ones(len(a)))

            ba, lo, hi, pivscale, info = f(a1, permute=1, scale=1)
            assert_(not info, repr(info))

    def test_gehrd(self):
        """Call ``dgehrd`` on a 3x3 matrix and check that ``info`` is zero."""
        a = [[-149, -50, -154],
             [537, 180, 546],
             [-27, -9, -25]]
        for p in 'd':
            f = getattr(flapack, p + 'gehrd', None)
            if f is None:
                # This flapack build does not provide the routine.
                continue
            ht, tau, info = f(a)
            assert_(not info, repr(info))

    def test_trsyl(self):
        """Solve small Sylvester equations with ``?trsyl`` and verify them.

        For each of the four dtypes the returned ``x`` is substituted back
        into the equation for the default call, the transposed (or
        conjugate-transposed) call and the ``isgn=-1`` call, and ``info``
        is required to be zero each time.
        """
        a = np.array([[1, 2], [0, 4]])
        b = np.array([[5, 6], [0, 8]])
        c = np.array([[9, 10], [11, 12]])

        # Test single and double implementations, including most
        # of the options
        for dtype in 'fdFD':
            a1, b1, c1 = a.astype(dtype), b.astype(dtype), c.astype(dtype)
            trsyl, = get_lapack_funcs(('trsyl',), (a1,))

            # Pick the transposition flag per dtype so the result does not
            # depend on the order in which the dtypes are visited.
            if dtype.isupper():  # is complex dtype
                a1[0] += 1j
                trans = 'C'
            else:
                trans = 'T'

            x, scale, info = trsyl(a1, b1, c1)
            assert_(not info, repr(info))
            assert_array_almost_equal(np.dot(a1, x) + np.dot(x, b1),
                                      scale * c1)

            x, scale, info = trsyl(a1, b1, c1, trana=trans, tranb=trans)
            assert_(not info, repr(info))
            assert_array_almost_equal(
                np.dot(a1.conjugate().T, x) + np.dot(x, b1.conjugate().T),
                scale * c1, decimal=4)

            x, scale, info = trsyl(a1, b1, c1, isgn=-1)
            assert_(not info, repr(info))
            assert_array_almost_equal(np.dot(a1, x) - np.dot(x, b1),
                                      scale * c1, decimal=4)
class TestLapack(TestCase):
    """Placeholder tests that only probe the LAPACK wrapper modules."""

    def test_flapack(self):
        """Look for an ``empty_module`` marker on ``flapack``; assert nothing."""
        if not hasattr(flapack, 'empty_module'):
            return
        # flapack module is empty

    def test_clapack(self):
        """Look for an ``empty_module`` marker on ``clapack``; assert nothing."""
        if not hasattr(clapack, 'empty_module'):
            return
        # clapack module is empty
class TestRegression(TestCase):
    """Regression tests tied to specific tickets."""

    def test_ticket_1645(self):
        """RQ routines must reject a too-small ``lwork`` and accept a valid one."""
        for dtype in DTYPES:
            a = np.zeros((300, 2), dtype=dtype)

            gerqf, = get_lapack_funcs(['gerqf'], [a])
            assert_raises(Exception, gerqf, a, lwork=2)
            rq, tau, work, info = gerqf(a)

            # The follow-up routine has a different name for real and
            # complex input; everything else about the check is the same.
            if dtype in REAL_DTYPES:
                routine_name = 'orgrq'
            elif dtype in COMPLEX_DTYPES:
                routine_name = 'ungrq'
            else:
                continue

            routine, = get_lapack_funcs([routine_name], [a])
            last_rows = rq[-2:]
            assert_raises(Exception, routine, last_rows, tau, lwork=1)
            routine(last_rows, tau, lwork=2)
class TestDpotr(TestCase):
    """Checks for the ``potrf``/``potri`` pair."""

    def test_gh_2691(self):
        """Check the 'lower' argument of dpotrf/dpotri (gh-2691).

        For every combination of ``lower`` and ``clean`` the matrix is
        factored with ``potrf`` and the factor is passed to ``potri``.  The
        triangle selected by ``lower`` must match the same triangle of
        ``inv(a)``, and both routines must report a zero ``info``.
        """
        # The input does not depend on the loop variables, so build it (and
        # the reference inverse) once.  The seed fixes the random matrix.
        np.random.seed(42)
        x = np.random.normal(size=(3, 3))
        a = x.dot(x.T)
        expected = inv(a)

        dpotrf, dpotri = get_lapack_funcs(("potrf", "potri"), (a, ))

        for lower in [True, False]:
            # Only the triangle selected by ``lower`` is compared.
            triangle = np.tril if lower else np.triu
            for clean in [True, False]:
                c, info = dpotrf(a, lower, clean=clean)
                assert_(not info, repr(info))

                dpt, info = dpotri(c, lower)
                assert_(not info, repr(info))

                assert_allclose(triangle(dpt), triangle(expected))
if __name__ == "__main__":
    # Executed as a script rather than imported: run this module's tests.
    run_module_suite()
|