File: decomp_lu.py

package info (click to toggle)
python-scipy 0.10.1+dfsg2-1
  • links: PTS, VCS
  • area: main
  • in suites: wheezy
  • size: 42,232 kB
  • sloc: cpp: 224,773; ansic: 103,496; python: 85,210; fortran: 79,130; makefile: 272; sh: 43
file content (161 lines) | stat: -rw-r--r-- 4,543 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
"""LU decomposition functions."""

from warnings import warn

from numpy import asarray, asarray_chkfinite

# Local imports
from misc import _datacopied
from lapack import get_lapack_funcs
from flinalg import get_flinalg_funcs

# Public names of this module (the set exported by ``from ... import *``).
__all__ = ['lu', 'lu_solve', 'lu_factor']


def lu_factor(a, overwrite_a=False, check_finite=True):
    """Compute pivoted LU decomposition of a matrix.

    The decomposition is::

        A = P L U

    where P is a permutation matrix, L lower triangular with unit
    diagonal elements, and U upper triangular.

    Parameters
    ----------
    a : array, shape (N, N)
        Square matrix to decompose
    overwrite_a : boolean
        Whether to overwrite data in `a` (may increase performance)
    check_finite : boolean, optional
        Whether to check that `a` contains only finite numbers (default
        True). Non-finite input is rejected with a ValueError, which is
        consistent with `lu` and `lu_solve`. Disabling the check may save
        time, but inf/NaN then reach LAPACK unchecked.

    Returns
    -------
    lu : array, shape (N, N)
        Matrix containing U in its upper triangle, and L in its lower
        triangle. The unit diagonal elements of L are not stored.
    piv : array, shape (N,)
        Pivot indices representing the permutation matrix P:
        row i of matrix was interchanged with row piv[i].

    Raises
    ------
    ValueError
        If `a` is not a square 2-D array, if `check_finite` is true and `a`
        contains inf or NaN, or if LAPACK reports an illegal argument.

    Warns
    -----
    RuntimeWarning
        If LAPACK reports an exactly-zero diagonal element of U (singular
        matrix). The factorization is still returned.

    See also
    --------
    lu_solve : solve an equation system using the LU factorization of a matrix

    Notes
    -----
    This is a wrapper to the *GETRF routines from LAPACK.

    """
    if check_finite:
        a1 = asarray_chkfinite(a)
    else:
        a1 = asarray(a)
    if a1.ndim != 2 or a1.shape[0] != a1.shape[1]:
        raise ValueError('expected square matrix')
    # _datacopied (from misc) is assumed to report whether a1 is a private
    # copy of `a`; if so, LAPACK may work in place without touching caller data.
    overwrite_a = overwrite_a or _datacopied(a1, a)
    getrf, = get_lapack_funcs(('getrf',), (a1,))
    lu, piv, info = getrf(a1, overwrite_a=overwrite_a)
    if info < 0:
        raise ValueError('illegal value in %d-th argument of '
                         'internal getrf (lu_factor)' % -info)
    if info > 0:
        # Per the *GETRF contract, a positive info is the (1-based) index of
        # a zero diagonal entry of U. The factorization itself is complete.
        warn("Diagonal number %d is exactly zero. Singular matrix." % info,
             RuntimeWarning)
    return lu, piv


def lu_solve(lu_and_piv, b, trans=0, overwrite_b=False, check_finite=True):
    """Solve an equation system, a x = b, given the LU factorization of a

    Parameters
    ----------
    lu_and_piv : tuple (lu, piv)
        Factorization of the coefficient matrix a, as given by lu_factor
    b : array
        Right-hand side
    trans : {0, 1, 2}
        Type of system to solve:

        =====  =========
        trans  system
        =====  =========
        0      a x   = b
        1      a^T x = b
        2      a^H x = b
        =====  =========
    overwrite_b : boolean
        Whether to overwrite data in b (may increase performance)
    check_finite : boolean, optional
        Whether to check that `b` contains only finite numbers (default
        True, the previous behaviour). Disabling it skips the check.

    Returns
    -------
    x : array
        Solution to the system

    Raises
    ------
    ValueError
        If `b` is 0-d or its leading dimension does not match `lu`, if
        `check_finite` is true and `b` contains inf or NaN, or if LAPACK
        reports an illegal argument.

    See also
    --------
    lu_factor : LU factorize a matrix

    """
    # Tuple parameters in a signature (``def f((lu, piv), ...)``) are
    # Python 2 only (removed by PEP 3113), so unpack in the body instead.
    lu, piv = lu_and_piv
    if check_finite:
        b1 = asarray_chkfinite(b)
    else:
        b1 = asarray(b)
    overwrite_b = overwrite_b or _datacopied(b1, b)
    # A 0-d right-hand side has no leading dimension to compare against.
    if b1.ndim == 0 or lu.shape[0] != b1.shape[0]:
        raise ValueError("incompatible dimensions.")

    getrs, = get_lapack_funcs(('getrs',), (lu, b1))
    x, info = getrs(lu, piv, b1, trans=trans, overwrite_b=overwrite_b)
    if info == 0:
        return x
    raise ValueError('illegal value in %d-th argument of internal getrs '
                     '(lu_solve)' % -info)


def lu(a, permute_l=False, overwrite_a=False, check_finite=True):
    """Compute pivoted LU decomposition of a matrix.

    The decomposition is::

        A = P L U

    where P is a permutation matrix, L lower triangular with unit
    diagonal elements, and U upper triangular.

    Parameters
    ----------
    a : array, shape (M, N)
        Array to decompose
    permute_l : boolean
        Perform the multiplication P*L  (Default: do not permute)
    overwrite_a : boolean
        Whether to overwrite data in a (may improve performance)
    check_finite : boolean, optional
        Whether to check that `a` contains only finite numbers (default
        True, the previous behaviour). Disabling it skips the check.

    Returns
    -------
    (If permute_l == False)
    p : array, shape (M, M)
        Permutation matrix
    l : array, shape (M, K)
        Lower triangular or trapezoidal matrix with unit diagonal.
        K = min(M, N)
    u : array, shape (K, N)
        Upper triangular or trapezoidal matrix

    (If permute_l == True)
    pl : array, shape (M, K)
        Permuted L matrix.
        K = min(M, N)
    u : array, shape (K, N)
        Upper triangular or trapezoidal matrix

    Raises
    ------
    ValueError
        If `a` is not 2-D, if `check_finite` is true and `a` contains inf
        or NaN, or if the underlying routine reports an illegal argument.

    Notes
    -----
    This is a LU factorization routine written for Scipy. Only a negative
    ``info`` from the routine is treated as an error; unlike `lu_factor`,
    a positive ``info`` is not checked and produces no warning.

    """
    if check_finite:
        a1 = asarray_chkfinite(a)
    else:
        a1 = asarray(a)
    if a1.ndim != 2:
        raise ValueError('expected matrix')
    # _datacopied (from misc) is assumed to report whether a1 is a private
    # copy of `a`; if so, the routine may work in place on it.
    overwrite_a = overwrite_a or _datacopied(a1, a)
    flu, = get_flinalg_funcs(('lu',), (a1,))
    p, l, u, info = flu(a1, permute_l=permute_l, overwrite_a=overwrite_a)
    if info < 0:
        raise ValueError('illegal value in %d-th argument of '
                         'internal lu.getrf' % -info)
    if permute_l:
        # In this mode ``l`` holds the permuted L (P*L, per the docstring),
        # so the permutation matrix is not returned separately.
        return l, u
    return p, l, u