1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161
|
"""LU decomposition functions."""
from warnings import warn
from numpy import asarray, asarray_chkfinite
# Local imports
from misc import _datacopied
from lapack import get_lapack_funcs
from flinalg import get_flinalg_funcs
__all__ = ['lu', 'lu_solve', 'lu_factor']
def lu_factor(a, overwrite_a=False):
    """Compute pivoted LU decomposition of a square matrix.

    The decomposition is::

        A = P L U

    where P is a permutation matrix, L lower triangular with unit
    diagonal elements, and U upper triangular.

    Parameters
    ----------
    a : array, shape (N, N)
        Square matrix to decompose.
    overwrite_a : boolean
        Whether to overwrite data in `a` (may increase performance).

    Returns
    -------
    lu : array, shape (N, N)
        Matrix containing U in its upper triangle, and L in its lower
        triangle. The unit diagonal elements of L are not stored.
    piv : array, shape (N,)
        Pivot indices representing the permutation matrix P:
        row i of matrix was interchanged with row piv[i].

    Raises
    ------
    ValueError
        If `a` is not a 2-D square array, or if the LAPACK routine reports
        an illegal argument.

    Warns
    -----
    RuntimeWarning
        If a diagonal element of U is exactly zero (singular matrix). The
        factorization is still returned in that case.

    See also
    --------
    lu_solve : solve an equation system using the LU factorization of a matrix

    Notes
    -----
    This is a wrapper to the *GETRF routines from LAPACK.
    """
    a1 = asarray(a)
    if a1.ndim != 2 or a1.shape[0] != a1.shape[1]:
        raise ValueError('expected square matrix')
    # If the conversion already produced a private copy, letting LAPACK
    # work in place cannot clobber the caller's data.
    overwrite_a = overwrite_a or _datacopied(a1, a)
    getrf, = get_lapack_funcs(('getrf',), (a1,))
    lu, piv, info = getrf(a1, overwrite_a=overwrite_a)
    if info < 0:
        # A negative info identifies the offending argument as -info.
        raise ValueError('illegal value in %d-th argument of '
                         'internal getrf (lu_factor)' % -info)
    if info > 0:
        # A positive info is the (LAPACK-numbered) zero diagonal entry of U;
        # this is only a warning, so the factorization is still returned.
        warn("Diagonal number %d is exactly zero. Singular matrix." % info,
             RuntimeWarning)
    return lu, piv
def lu_solve(lu_and_piv, b, trans=0, overwrite_b=False):
    """Solve an equation system, a x = b, given the LU factorization of a.

    Parameters
    ----------
    lu_and_piv : tuple (lu, piv)
        Factorization of the coefficient matrix a, as given by lu_factor.
    b : array
        Right-hand side.
    trans : {0, 1, 2}
        Type of system to solve:

        =====  =========
        trans  system
        =====  =========
        0      a x   = b
        1      a^T x = b
        2      a^H x = b
        =====  =========
    overwrite_b : boolean
        Whether to overwrite data in `b` (may increase performance).

    Returns
    -------
    x : array
        Solution to the system.

    Raises
    ------
    ValueError
        If `b` contains inf/NaN, if `b` is 0-d or its leading dimension
        does not match `lu`, or if the LAPACK routine reports an illegal
        argument.

    See also
    --------
    lu_factor : LU factorize a matrix
    """
    # Unpack explicitly: tuple parameters in a ``def`` signature are a
    # SyntaxError on Python 3 (PEP 3113). Positional callers are unaffected.
    lu, piv = lu_and_piv
    b1 = asarray_chkfinite(b)
    # If the conversion already copied b, LAPACK may overwrite that copy.
    overwrite_b = overwrite_b or _datacopied(b1, b)
    # A 0-d right-hand side has no leading dimension to compare; report it
    # as a dimension error instead of failing with IndexError.
    if b1.ndim == 0 or lu.shape[0] != b1.shape[0]:
        raise ValueError("incompatible dimensions.")
    getrs, = get_lapack_funcs(('getrs',), (lu, b1))
    x, info = getrs(lu, piv, b1, trans=trans, overwrite_b=overwrite_b)
    if info == 0:
        return x
    # A negative info identifies the offending argument as -info.
    raise ValueError('illegal value in %d-th argument of internal getrs '
                     '(lu_solve)' % -info)
def lu(a, permute_l=False, overwrite_a=False):
    """Compute the pivoted LU decomposition of a general (M, N) matrix.

    The factorization is::

        A = P L U

    with P a permutation matrix, L lower triangular (or trapezoidal) with a
    unit diagonal, and U upper triangular (or trapezoidal).

    Parameters
    ----------
    a : array, shape (M, N)
        Matrix to factorize.
    permute_l : boolean
        If true, return the product P*L in place of P and L
        (default: False).
    overwrite_a : boolean
        Whether the data in `a` may be overwritten (may improve performance).

    Returns
    -------
    (when permute_l is False)
    p : array, shape (M, M)
        Permutation matrix.
    l : array, shape (M, K)
        Lower triangular or trapezoidal matrix with unit diagonal,
        where K = min(M, N).
    u : array, shape (K, N)
        Upper triangular or trapezoidal matrix.

    (when permute_l is True)
    pl : array, shape (M, K)
        Permuted L matrix, where K = min(M, N).
    u : array, shape (K, N)
        Upper triangular or trapezoidal matrix.

    Raises
    ------
    ValueError
        If `a` contains inf/NaN, is not 2-D, or the underlying routine
        reports an illegal argument.

    Notes
    -----
    This is a LU factorization routine written for Scipy.
    """
    arr = asarray_chkfinite(a)
    if arr.ndim != 2:
        raise ValueError('expected matrix')
    # A private copy made during conversion may be clobbered freely.
    may_overwrite = overwrite_a or (_datacopied(arr, a))
    flu, = get_flinalg_funcs(('lu',), (arr,))
    perm, lower, upper, info = flu(arr, permute_l=permute_l,
                                   overwrite_a=may_overwrite)
    if info < 0:
        # A negative info identifies the offending argument as -info.
        raise ValueError('illegal value in %d-th argument of '
                         'internal lu.getrf' % -info)
    # With permute_l, `lower` already holds the permuted product P*L.
    return (lower, upper) if permute_l else (perm, lower, upper)
|