1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191
|
"""LU decomposition functions."""
from __future__ import division, print_function, absolute_import
from warnings import warn
from numpy import asarray, asarray_chkfinite
# Local imports
from .misc import _datacopied
from .lapack import get_lapack_funcs
from .flinalg import get_flinalg_funcs
__all__ = ['lu', 'lu_solve', 'lu_factor']
def lu_factor(a, overwrite_a=False, check_finite=True):
    """
    Compute pivoted LU decomposition of a square matrix.

    The decomposition is::

        A = P L U

    where P is a permutation matrix, L lower triangular with unit
    diagonal elements, and U upper triangular.

    Parameters
    ----------
    a : (N, N) array_like
        Square matrix to decompose.
    overwrite_a : bool, optional
        Whether to overwrite data in `a` (may increase performance).
    check_finite : bool, optional
        Whether to check that the input matrix contains only finite numbers.
        Disabling may give a performance gain, but may result in problems
        (crashes, non-termination) if the inputs do contain infinities or NaNs.

    Returns
    -------
    lu : (N, N) ndarray
        Matrix containing U in its upper triangle, and L in its lower triangle.
        The unit diagonal elements of L are not stored.
    piv : (N,) ndarray
        Pivot indices representing the permutation matrix P:
        row i of matrix was interchanged with row piv[i].

    Raises
    ------
    ValueError
        If `a` is not a square 2-D array, if `check_finite` is true and `a`
        contains infinities or NaNs, or if LAPACK reports an illegal value
        in one of the arguments of ``getrf``.

    Warns
    -----
    RuntimeWarning
        If a diagonal element of U is exactly zero (singular matrix). The
        factorization is still returned.

    See Also
    --------
    lu_solve : solve an equation system using the LU factorization of a matrix

    Notes
    -----
    This is a wrapper to the ``*GETRF`` routines from LAPACK.
    An empty (0, 0) input returns an empty `lu` and an empty int32 `piv`
    without calling LAPACK.
    """
    if check_finite:
        a1 = asarray_chkfinite(a)
    else:
        a1 = asarray(a)
    if a1.ndim != 2 or a1.shape[0] != a1.shape[1]:
        raise ValueError('expected square matrix')

    # The factorization of an empty matrix is trivially empty; short-circuit
    # so the LAPACK wrapper never sees a zero-sized array (its handling of
    # that case is not guaranteed).
    if a1.size == 0:
        return a1.copy(), asarray([], dtype='int32')

    # _datacopied is assumed to report whether a1 is a private copy of a;
    # if so, letting LAPACK overwrite it cannot clobber the caller's data.
    overwrite_a = overwrite_a or _datacopied(a1, a)
    getrf, = get_lapack_funcs(('getrf',), (a1,))
    lu, piv, info = getrf(a1, overwrite_a=overwrite_a)
    if info < 0:
        raise ValueError('illegal value in %d-th argument of '
                         'internal getrf (lu_factor)' % -info)
    if info > 0:
        # Per the LAPACK getrf contract, a positive info is the 1-based index
        # of an exactly-zero pivot; the factorization itself is complete.
        warn("Diagonal number %d is exactly zero. Singular matrix." % info,
             RuntimeWarning)
    return lu, piv
def lu_solve(lu_and_piv, b, trans=0, overwrite_b=False, check_finite=True):
    """
    Solve an equation system, a x = b, given the LU factorization of a.

    Parameters
    ----------
    lu_and_piv : tuple of (lu, piv)
        Factorization of the coefficient matrix a, as given by lu_factor.
    b : array
        Right-hand side. Its first dimension must equal ``lu.shape[0]``.
    trans : {0, 1, 2}, optional
        Type of system to solve:

        =====  =========
        trans  system
        =====  =========
        0      a x = b
        1      a^T x = b
        2      a^H x = b
        =====  =========
    overwrite_b : bool, optional
        Whether to overwrite data in b (may increase performance).
    check_finite : bool, optional
        Whether to check that the right-hand side `b` contains only finite
        numbers (the factorization `lu` is not checked).
        Disabling may give a performance gain, but may result in problems
        (crashes, non-termination) if the inputs do contain infinities or NaNs.

    Returns
    -------
    x : array
        Solution to the system

    Raises
    ------
    ValueError
        If `b` is 0-d or its first dimension does not match ``lu``, if
        `check_finite` is true and `b` contains infinities or NaNs, or if
        LAPACK ``getrs`` reports an illegal argument value.

    See Also
    --------
    lu_factor : LU factorize a matrix
    """
    (lu, piv) = lu_and_piv
    if check_finite:
        b1 = asarray_chkfinite(b)
    else:
        b1 = asarray(b)
    overwrite_b = overwrite_b or _datacopied(b1, b)
    # A 0-d b has no first axis; report it as a dimension mismatch instead
    # of letting ``b1.shape[0]`` raise IndexError.
    if b1.ndim == 0 or lu.shape[0] != b1.shape[0]:
        raise ValueError("incompatible dimensions.")
    getrs, = get_lapack_funcs(('getrs',), (lu, b1))
    x, info = getrs(lu, piv, b1, trans=trans, overwrite_b=overwrite_b)
    if info == 0:
        return x
    # getrs signals failure only through a negative info (illegal argument).
    raise ValueError('illegal value in %d-th argument of internal getrs '
                     '(lu_solve)' % -info)
def lu(a, permute_l=False, overwrite_a=False, check_finite=True):
    """
    Compute the pivoted LU decomposition of a general matrix.

    The factorization has the form::

        A = P L U

    with P a permutation matrix, L lower triangular (or trapezoidal) with
    unit diagonal, and U upper triangular (or trapezoidal).

    Parameters
    ----------
    a : (M, N) array_like
        Array to decompose; it need not be square.
    permute_l : bool, optional
        If true, return the product P*L instead of P and L separately.
        Default is False (do not permute).
    overwrite_a : bool, optional
        Allow the routine to overwrite the data in `a` (may improve
        performance).
    check_finite : bool, optional
        Whether to check that the input matrix contains only finite numbers.
        Disabling may give a performance gain, but may result in problems
        (crashes, non-termination) if the inputs do contain infinities or NaNs.

    Returns
    -------
    **(If permute_l == False)**

    p : (M, M) ndarray
        Permutation matrix
    l : (M, K) ndarray
        Lower triangular or trapezoidal matrix with unit diagonal.
        K = min(M, N)
    u : (K, N) ndarray
        Upper triangular or trapezoidal matrix

    **(If permute_l == True)**

    pl : (M, K) ndarray
        Permuted L matrix.
        K = min(M, N)
    u : (K, N) ndarray
        Upper triangular or trapezoidal matrix

    Raises
    ------
    ValueError
        If `a` is not two-dimensional, if `check_finite` is true and `a`
        contains infinities or NaNs, or if the underlying routine returns a
        negative ``info`` (illegal argument). Other ``info`` values are not
        inspected.

    Notes
    -----
    This is a LU factorization routine written for Scipy; the work is done
    by the ``lu`` routine obtained from ``get_flinalg_funcs``.
    """
    convert = asarray_chkfinite if check_finite else asarray
    a1 = convert(a)
    if a1.ndim != 2:
        raise ValueError('expected matrix')

    # _datacopied is assumed to flag a1 as a private copy of a, in which case
    # overwriting it cannot affect the caller's array.
    may_overwrite = overwrite_a or _datacopied(a1, a)

    flu, = get_flinalg_funcs(('lu',), (a1,))
    perm, lower, upper, info = flu(a1, permute_l=permute_l,
                                   overwrite_a=may_overwrite)
    if info < 0:
        raise ValueError('illegal value in %d-th argument of internal '
                         'lu.getrf' % -info)

    # With permute_l the second output already holds P*L, so P is dropped.
    return (lower, upper) if permute_l else (perm, lower, upper)
|