## Automatically adapted for scipy Oct 18, 2005 by

## Automatically adapted for scipy Oct 18, 2005 by

#
# Author: Pearu Peterson, March 2002
#
# w/ additions by Travis Oliphant, March 2002

__all__ = ['solve','inv','det','lstsq','norm','pinv','pinv2',
           'tri','tril','triu','toeplitz','hankel','lu_solve',
           'cho_solve','solve_banded','LinAlgError','kron',
           'all_mat', 'cholesky_banded', 'solveh_banded']

#from blas import get_blas_funcs
from flinalg import get_flinalg_funcs
from lapack import get_lapack_funcs
from numpy import asarray,zeros,sum,newaxis,greater_equal,subtract,arange,\
     conjugate,ravel,r_,mgrid,take,ones,dot,transpose,sqrt,add,real
import numpy
from numpy import asarray_chkfinite, outer, concatenate, reshape, single
from numpy import matrix as Matrix
from numpy.linalg import LinAlgError
from scipy.linalg import calc_lwork


def lu_solve(lu_and_piv, b, trans=0, overwrite_b=0):
    """ lu_solve((lu, piv), b, trans=0, overwrite_b=0) -> x

    Solve a system of equations given a previously factored matrix

    Inputs:

      (lu,piv) -- The factored matrix, a (the output of lu_factor)
      b        -- a set of right-hand sides
      trans    -- type of system to solve:
                  0 : a   * x = b   (no transpose)
                  1 : a^T * x = b   (transpose)
                  2 : a^H * x = b   (conjugate transpose)
      overwrite_b -- allow the data in b to be overwritten.

    Outputs:

       x -- the solution to the system

    Raises:

      ValueError -- if b's first dimension does not match lu, or if the
                    LAPACK getrs routine reports an illegal argument.
    """
    (lu, piv) = lu_and_piv
    b1 = asarray_chkfinite(b)
    # An array freshly built from a non-array input (e.g. a list) is private
    # to this call, so it is safe to let LAPACK overwrite it.
    overwrite_b = overwrite_b or (b1 is not b and not hasattr(b, '__array__'))
    if lu.shape[0] != b1.shape[0]:
        raise ValueError("incompatible dimensions.")
    getrs, = get_lapack_funcs(('getrs',), (lu, b1))
    x, info = getrs(lu, piv, b1, trans=trans, overwrite_b=overwrite_b)
    if info == 0:
        return x
    # A negative info means argument number -info was illegal.
    raise ValueError('illegal value in %d-th argument of internal getrs'
                     % (-info))

def cho_solve(c_and_lower, b, overwrite_b=0):
    """ cho_solve((c, lower), b, overwrite_b=0) -> x

    Solve a system of equations given a previously cholesky factored matrix

    Inputs:

      (c,lower) -- The factored matrix, a (the output of cho_factor)
      b         -- a set of right-hand sides
      overwrite_b -- allow the data in b to be overwritten.

    Outputs:

       x -- the solution to the system a*x = b

    Raises:

      ValueError -- if b's first dimension does not match c, or if the
                    LAPACK potrs routine reports an illegal argument.
    """
    (c, lower) = c_and_lower
    b1 = asarray_chkfinite(b)
    # An array freshly built from a non-array input (e.g. a list) is private
    # to this call, so it is safe to let LAPACK overwrite it.
    overwrite_b = overwrite_b or (b1 is not b and not hasattr(b, '__array__'))
    if c.shape[0] != b1.shape[0]:
        raise ValueError("incompatible dimensions.")
    potrs, = get_lapack_funcs(('potrs',), (c, b1))
    x, info = potrs(c, b1, lower=lower, overwrite_b=overwrite_b)
    if info == 0:
        return x
    # A negative info means argument number -info was illegal.
    raise ValueError('illegal value in %d-th argument of internal potrs'
                     % (-info))

# Linear equations
def solve(a, b, sym_pos=0, lower=0, overwrite_a=0, overwrite_b=0,
          debug = 0):
    """ solve(a, b, sym_pos=0, lower=0, overwrite_a=0, overwrite_b=0) -> x

    Solve a linear system of equations a * x = b for x.

    Inputs:

      a -- An N x N matrix.
      b -- An N x nrhs matrix or N vector.
      sym_pos -- Assume a is symmetric and positive definite (uses LAPACK
                 posv instead of gesv).
      lower -- Passed to posv: use the lower triangle of a if true,
               otherwise the upper one. Only used if sym_pos is true.
      overwrite_a, overwrite_b -- allow the data in a / b to be overwritten.
      debug -- if true, print the effective overwrite flags.

    Outputs:

      x -- The solution to the system a * x = b

    Raises:

      ValueError -- a is not square, a and b have incompatible shapes, or
                    LAPACK reports an illegal argument.
      LinAlgError -- LAPACK reports info > 0 ("singular matrix").
    """
    a1, b1 = map(asarray_chkfinite, (a, b))
    if len(a1.shape) != 2 or a1.shape[0] != a1.shape[1]:
        raise ValueError('expected square matrix')
    if a1.shape[0] != b1.shape[0]:
        raise ValueError('incompatible dimensions')
    # Arrays freshly built from non-array inputs are private to this call,
    # so it is safe to let LAPACK overwrite them.
    overwrite_a = overwrite_a or (a1 is not a and not hasattr(a, '__array__'))
    overwrite_b = overwrite_b or (b1 is not b and not hasattr(b, '__array__'))
    if debug:
        print('solve:overwrite_a= %s' % (overwrite_a,))
        print('solve:overwrite_b= %s' % (overwrite_b,))
    if sym_pos:
        routine = 'posv'
        posv, = get_lapack_funcs(('posv',), (a1, b1))
        c, x, info = posv(a1, b1,
                          lower=lower,
                          overwrite_a=overwrite_a,
                          overwrite_b=overwrite_b)
    else:
        routine = 'gesv'
        gesv, = get_lapack_funcs(('gesv',), (a1, b1))
        lu, piv, x, info = gesv(a1, b1,
                                overwrite_a=overwrite_a,
                                overwrite_b=overwrite_b)

    if info == 0:
        return x
    if info > 0:
        raise LinAlgError("singular matrix")
    # A negative info means argument number -info was illegal.
    raise ValueError('illegal value in %d-th argument of internal %s'
                     % (-info, routine))

def solve_banded(l_and_u, ab, b, overwrite_ab=0, overwrite_b=0,
          debug = 0):
    """ solve_banded((l,u), ab, b, overwrite_ab=0, overwrite_b=0) -> x

    Solve a linear system of equations a * x = b for x where
    a is a banded matrix stored in diagonal ordered form, i.e. (with
    0-based indices) ab[u + i - j, j] == a[i, j].  Example for l=1, u=2
    (entries named with 1-based indices):

     *   *  a13 a24 ...
     *  a12 a23 a34 ...
    a11 a22 a33 a44 ...
    a21 a32 a43 a54 ...

    Inputs:

      (l,u) -- number of non-zero lower and upper diagonals, respectively.
      ab -- An (l+u+1) x N matrix holding the diagonals of a.
      b -- An N x nrhs matrix or N vector.
      overwrite_ab -- has no effect: ab is always copied into a larger
                      work array before LAPACK is called.
      overwrite_b -- allow the data in b to be overwritten.
      debug -- unused.

    Outputs:

      x -- The solution to the system a * x = b

    Raises:

      ValueError -- ab does not have l+u+1 rows, or LAPACK reports an
                    illegal argument.
      LinAlgError -- LAPACK reports info > 0 ("singular matrix").
    """
    (l, u) = l_and_u
    a1, b1 = map(asarray_chkfinite, (ab, b))
    if a1.shape[0] != l + u + 1:
        # Without this check a too-short ab would be silently broadcast
        # into the work array below, producing a wrong answer.
        raise ValueError('invalid values for the number of lower and upper '
                         'diagonals: l+u+1 (%d) does not equal ab.shape[0] '
                         '(%d)' % (l + u + 1, a1.shape[0]))
    overwrite_b = overwrite_b or (b1 is not b and not hasattr(b, '__array__'))

    gbsv, = get_lapack_funcs(('gbsv',), (a1, b1))
    # LAPACK ?gbsv expects 2*l+u+1 rows with the band in the last l+u+1;
    # the leading l rows are workspace for the factorization.
    a2 = zeros((2*l + u + 1, a1.shape[1]), dtype=gbsv.dtype)
    a2[l:, :] = a1
    lu, piv, x, info = gbsv(l, u, a2, b1,
                            overwrite_ab=1,
                            overwrite_b=overwrite_b)
    if info == 0:
        return x
    if info > 0:
        raise LinAlgError("singular matrix")
    # A negative info means argument number -info was illegal.
    raise ValueError('illegal value in %d-th argument of internal gbsv'
                     % (-info))

def solveh_banded(ab, b, overwrite_ab=0, overwrite_b=0, lower=0):
    """ solveh_banded(ab, b, overwrite_ab=0, overwrite_b=0) -> c, x

    Solve a linear system of equations a * x = b for x where
    a is a banded symmetric or Hermitian positive definite
    matrix stored in lower diagonal ordered form (lower=1)

    a11 a22 a33 a44 a55 a66
    a21 a32 a43 a54 a65 *
    a31 a42 a53 a64 *   *

    or upper diagonal ordered form (lower=0)

    *   *   a31 a42 a53 a64
    *   a21 a32 a43 a54 a65
    a11 a22 a33 a44 a55 a66

    Inputs:

      ab -- the band storage of a, in one of the layouts above.
      b -- An N x nrhs matrix or N vector.
      overwrite_ab, overwrite_b -- allow the data in ab / b to be
                                   overwritten (passed straight to pbsv).
      lower -- true if ab is in lower form, false for upper form.

    Outputs:

      c:  the Cholesky factorization of ab (as returned by LAPACK pbsv)
      x:  the solution to ab * x = b

    Raises:

      LinAlgError -- a leading minor is not positive definite (info > 0).
      ValueError  -- LAPACK reports an illegal argument (info < 0).
    """
    ab1 = asarray_chkfinite(ab)
    b1 = asarray_chkfinite(b)

    pbsv, = get_lapack_funcs(('pbsv',), (ab1, b1))
    c, x, info = pbsv(ab1, b1, lower=lower,
                      overwrite_ab=overwrite_ab, overwrite_b=overwrite_b)
    if info > 0:
        raise LinAlgError("%d-th leading minor not positive definite" % info)
    if info < 0:
        raise ValueError('illegal value in %d-th argument of internal pbsv'
                         % (-info))
    return c, x

def cholesky_banded(ab, overwrite_ab=0, lower=0):
    """ cholesky_banded(ab, overwrite_ab=0, lower=0) -> c

    Compute the Cholesky decomposition of a
    banded symmetric or Hermitian positive definite
    matrix stored in lower diagonal ordered form (lower=1)

    a11 a22 a33 a44 a55 a66
    a21 a32 a43 a54 a65 *
    a31 a42 a53 a64 *   *

    or upper diagonal ordered form (lower=0)

    *   *   a31 a42 a53 a64
    *   a21 a32 a43 a54 a65
    a11 a22 a33 a44 a55 a66

    Inputs:

      ab -- the band storage of the matrix, in one of the layouts above.
      overwrite_ab -- allow the data in ab to be overwritten.
      lower -- true if ab is in lower form, false for upper form.

    Outputs:

      c:  the Cholesky factorization of ab (as returned by LAPACK pbtrf)

    Raises:

      LinAlgError -- a leading minor is not positive definite (info > 0).
      ValueError  -- LAPACK reports an illegal argument (info < 0).
    """
    band = asarray_chkfinite(ab)

    pbtrf, = get_lapack_funcs(('pbtrf',), (band,))
    c, info = pbtrf(band, lower=lower, overwrite_ab=overwrite_ab)

    if info > 0:
        raise LinAlgError("%d-th leading minor not positive definite" % info)
    if info < 0:
        raise ValueError('illegal value in %d-th argument of internal pbtrf'
                         % (-info))
    return c


# matrix inversion
def inv(a, overwrite_a=0):
    """ inv(a, overwrite_a=0) -> a_inv

    Return inverse of square matrix a.

    Inputs:

      a -- An N x N matrix.
      overwrite_a -- allow the data in a to be overwritten.

    Raises:

      ValueError -- a is not square, or LAPACK reports an illegal argument.
      LinAlgError -- a is singular (info > 0 from getrf/getri).
    """
    a1 = asarray_chkfinite(a)
    if len(a1.shape) != 2 or a1.shape[0] != a1.shape[1]:
        raise ValueError('expected square matrix')
    # An array freshly built from a non-array input is private to this call.
    overwrite_a = overwrite_a or (a1 is not a and not hasattr(a, '__array__'))
    # NOTE: flinalg also provides an 'inv' routine, but no advantage or
    # disadvantage was found over the LAPACK getrf/getri path used here.
    getrf, getri = get_lapack_funcs(('getrf', 'getri'), (a1,))
    #XXX: C ATLAS versions of getrf/i have rowmajor=1, this could be
    #     exploited for further optimization. But it will be probably
    #     a mess. So, a good testing site is required before trying
    #     to do that.
    # Chained comparison: getrf comes from clapack while getri does not.
    if getrf.module_name[:7] == 'clapack' != getri.module_name[:7]:
        # ATLAS 3.2.1 has getrf but not getri.
        lu, piv, info = getrf(transpose(a1),
                              rowmajor=0, overwrite_a=overwrite_a)
        lu = transpose(lu)
    else:
        lu, piv, info = getrf(a1, overwrite_a=overwrite_a)
    if info == 0:
        if getri.module_name[:7] == 'flapack':
            lwork = calc_lwork.getri(getri.prefix, a1.shape[0])
            lwork = lwork[1]
            # XXX: the following line fixes curious SEGFAULT when
            # benchmarking 500x500 matrix inverse. This seems to
            # be a bug in LAPACK ?getri routine because if lwork is
            # minimal (when using lwork[0] instead of lwork[1]) then
            # all tests pass. Further investigation is required if
            # more such SEGFAULTs occur.
            lwork = int(1.01*lwork)
            inv_a, info = getri(lu, piv,
                                lwork=lwork, overwrite_lu=1)
        else:  # clapack
            inv_a, info = getri(lu, piv, overwrite_lu=1)
    if info > 0:
        raise LinAlgError("singular matrix")
    if info < 0:
        # A negative info means argument number -info was illegal.
        raise ValueError('illegal value in %d-th argument of internal '
                         'getrf|getri' % (-info))
    return inv_a


## matrix and Vector norm
import decomp
def norm(x, ord=None):
    """ norm(x, ord=None) -> n

    Matrix or vector norm.

    Inputs:

      x -- a rank-1 (vector) or rank-2 (matrix) array
      ord -- the order of the norm.

     Comments:
       For arrays of any rank, if ord is None:
         return the square root of the sum of squared magnitudes of all
         elements (Euclidean norm for vectors, Frobenius norm for matrices)

       For vectors ord can be any real number including Inf or -Inf.
         ord = Inf, computes the maximum of the magnitudes
         ord = -Inf, computes minimum of the magnitudes
         ord = 0, counts the non-zero entries
         ord is any other finite value, computes
           sum(abs(x)**ord,axis=0)**(1.0/ord)

       For matrices ord can only be one of the following values:
         ord = 2 computes the largest singular value
         ord = -2 computes the smallest singular value
         ord = 1 computes the largest column sum of absolute values
         ord = -1 computes the smallest column sum of absolute values
         ord = Inf computes the largest row sum of absolute values
         ord = -Inf computes the smallest row sum of absolute values
         ord = 'fro' (or 'f') computes the frobenius norm
           sqrt(sum(diag(X.H * X),axis=0))

       For values ord < 0, the result is, strictly speaking, not a
       mathematical 'norm', but it may still be useful for numerical purposes.

     Raises:
       ValueError -- ord is not a valid matrix order, or (when ord is given)
                     x is neither rank 1 nor rank 2.
    """
    x = asarray_chkfinite(x)
    if ord is None:  # check the default case first and handle it immediately
        return sqrt(add.reduce(real((conjugate(x)*x).ravel())))

    nd = len(x.shape)
    # NumPy 2.0 removed the numpy.Inf alias; numpy.inf exists in all versions.
    Inf = numpy.inf
    if nd == 1:
        if ord == Inf:
            return numpy.amax(abs(x))
        elif ord == -Inf:
            return numpy.amin(abs(x))
        elif ord == 0:
            # Count of non-zero entries; the generic formula below would
            # divide by zero in 1.0/ord.
            return numpy.sum(x != 0, axis=0)
        elif ord == 1:
            return numpy.sum(abs(x), axis=0)  # special case for speedup
        elif ord == 2:
            # special case for speedup
            return sqrt(numpy.sum(real((conjugate(x)*x)), axis=0))
        else:
            return numpy.sum(abs(x)**ord, axis=0)**(1.0/ord)
    elif nd == 2:
        if ord == 2:
            return numpy.amax(decomp.svd(x, compute_uv=0))
        elif ord == -2:
            return numpy.amin(decomp.svd(x, compute_uv=0))
        elif ord == 1:
            return numpy.amax(numpy.sum(abs(x), axis=0))
        elif ord == Inf:
            return numpy.amax(numpy.sum(abs(x), axis=1))
        elif ord == -1:
            return numpy.amin(numpy.sum(abs(x), axis=0))
        elif ord == -Inf:
            return numpy.amin(numpy.sum(abs(x), axis=1))
        elif ord in ['fro', 'f']:
            return sqrt(add.reduce(real((conjugate(x)*x).ravel())))
        else:
            raise ValueError("Invalid norm order for matrices.")
    else:
        raise ValueError("Improper number of dimensions to norm.")

### Determinant

def det(a, overwrite_a=0):
    """ det(a, overwrite_a=0) -> d

    Return determinant of a square matrix.

    Inputs:

      a -- An N x N matrix.
      overwrite_a -- allow the data in a to be overwritten.

    Raises:

      ValueError -- a is not square, or the internal getrf reports an
                    illegal argument.
    """
    a1 = asarray_chkfinite(a)
    if len(a1.shape) != 2 or a1.shape[0] != a1.shape[1]:
        raise ValueError('expected square matrix')
    # An array freshly built from a non-array input is private to this call.
    overwrite_a = overwrite_a or (a1 is not a and not hasattr(a, '__array__'))
    fdet, = get_flinalg_funcs(('det',), (a1,))
    a_det, info = fdet(a1, overwrite_a=overwrite_a)
    # Only info < 0 is treated as an error. NOTE(review): info > 0 presumably
    # signals a singular factorization (determinant 0) -- confirm in flinalg.
    if info < 0:
        raise ValueError('illegal value in %d-th argument of internal '
                         'det.getrf' % (-info))
    return a_det

### Linear Least Squares

def lstsq(a, b, cond=None, overwrite_a=0, overwrite_b=0):
    """ lstsq(a, b, cond=None, overwrite_a=0, overwrite_b=0) -> x,resids,rank,s

    Return least-squares solution of a * x = b.

    Inputs:

      a -- An M x N matrix.
      b -- An M x nrhs matrix or M vector.
      cond -- Used to determine effective rank of a.
      overwrite_a, overwrite_b -- allow the data in a / b to be overwritten.

    Outputs:

      x -- The solution (N x nrhs matrix) to the minimization problem:
                  2-norm(| b - a * x |) -> min
      resids -- The residual sum-of-squares for the solution matrix x
                (only if M>N and rank==N; otherwise an empty array).
      rank -- The effective rank of a.
      s -- Singular values of a in decreasing order. The condition number
           of a is abs(s[0]/s[-1]).

    Raises:

      ValueError -- a is not 2-D, shapes are incompatible, or LAPACK
                    reports an illegal argument.
      LinAlgError -- the SVD did not converge.
      NotImplementedError -- gelss is not provided by flapack.
    """
    a1, b1 = map(asarray_chkfinite, (a, b))
    if len(a1.shape) != 2:
        raise ValueError('expected matrix')
    m, n = a1.shape
    if len(b1.shape) == 2:
        nrhs = b1.shape[1]
    else:
        nrhs = 1
    if m != b1.shape[0]:
        raise ValueError('incompatible dimensions')
    gelss, = get_lapack_funcs(('gelss',), (a1, b1))
    if n > m:
        # need to extend b matrix as it will be filled with
        # a larger solution matrix
        b2 = zeros((n, nrhs), dtype=gelss.dtype)
        if len(b1.shape) == 2:
            b2[:m, :] = b1
        else:
            b2[:m, 0] = b1
        b1 = b2
    # Arrays freshly built from non-array inputs are private to this call.
    overwrite_a = overwrite_a or (a1 is not a and not hasattr(a, '__array__'))
    overwrite_b = overwrite_b or (b1 is not b and not hasattr(b, '__array__'))
    if gelss.module_name[:7] == 'flapack':
        lwork = calc_lwork.gelss(gelss.prefix, m, n, nrhs)[1]
        v, x, s, rank, info = gelss(a1, b1, cond=cond,
                                    lwork=lwork,
                                    overwrite_a=overwrite_a,
                                    overwrite_b=overwrite_b)
    else:
        raise NotImplementedError('calling gelss from %s'
                                  % (gelss.module_name))
    if info > 0:
        raise LinAlgError("SVD did not converge in Linear Least Squares")
    if info < 0:
        # A negative info means argument number -info was illegal.
        raise ValueError('illegal value in %d-th argument of internal gelss'
                         % (-info))
    resids = asarray([], dtype=x.dtype)
    if n < m:
        x1 = x[:n]
        if rank == n:
            # Rows n: of gelss's output hold the residual components; use
            # abs() so complex residuals contribute |r|**2 rather than r**2.
            resids = sum(abs(x[n:])**2, axis=0)
        x = x1
    return x, resids, rank, s


def pinv(a, cond=None, rcond=None):
    """ pinv(a, cond=None, rcond=None) -> a_pinv

    Compute generalized inverse of A using least-squares solver.

    The pseudo-inverse is obtained by solving a * X = I in the least-squares
    sense with lstsq.  cond is forwarded to lstsq; rcond, when given, is an
    alias that takes precedence over cond.
    """
    a = asarray_chkfinite(a)
    if rcond is not None:
        cond = rcond
    identity_rhs = numpy.identity(a.shape[0], dtype=a.dtype)
    solution = lstsq(a, identity_rhs, cond=cond)
    return solution[0]


# Machine epsilons used by pinv2 for its default cutoff:
# eps for double precision, feps for single precision.
eps = numpy.finfo(numpy.float64).eps
feps = numpy.finfo(numpy.float32).eps

# Array type code -> precision class (0 = single, 1 = double).
_array_precision = dict(f=0, d=1, F=0, D=1)
def pinv2(a, cond=None, rcond=None):
    """ pinv2(a, cond=None, rcond=None) -> a_pinv

    Compute the generalized inverse of A using svd.

    Singular values not exceeding cond*max(s) are treated as zero.  rcond,
    when given, is an alias that takes precedence over cond.  If the
    resulting cond is None or -1, it defaults to 1e3*feps for
    single-precision input and 1e6*eps for double precision.
    """
    a = asarray_chkfinite(a)
    u, s, vh = decomp.svd(a)
    typecode = u.dtype.char
    if rcond is not None:
        cond = rcond
    if cond in (None, -1):
        if _array_precision[typecode] == 0:
            cond = feps*1e3
        else:
            cond = eps*1e6
    m, n = a.shape
    cutoff = cond*numpy.maximum.reduce(s)
    # Build the (m, n) pseudo-diagonal with reciprocals of the kept
    # singular values.
    psigma = zeros((m, n), typecode)
    for idx, sval in enumerate(s):
        if sval > cutoff:
            psigma[idx, idx] = 1.0/conjugate(sval)
    #XXX: use lapack/blas routines for dot
    # (u * psigma * vh)^H == v * Sigma^+ * u^H
    product = dot(dot(u, psigma), vh)
    return transpose(conjugate(product))

#-----------------------------------------------------------------------------
# matrix construction functions
#-----------------------------------------------------------------------------

def tri(N, M=None, k=0, dtype=None):
    """ Return an N-by-M array that is true (one) on and below the k-th
        diagonal and false (zero) elsewhere.

        M defaults to N.  k=0 is the main diagonal, k > 0 is above it and
        k < 0 is below it.  With dtype=None the result is boolean; otherwise
        it is cast to dtype.  For backward compatibility a string passed as
        M is taken as the dtype, so tri(N,'d') == tri(N,dtype='d').
    """
    if M is None:
        M = N
    if type(M) is str:
        #pearu: any objections to remove this feature?
        dtype, M = M, N
    below = greater_equal(subtract.outer(arange(N), arange(M)), -k)
    if dtype is not None:
        below = below.astype(dtype)
    return below

def tril(m, k=0):
    """ returns the elements on and below the k-th diagonal of m.  k=0 is the
        main diagonal, k > 0 is above and k < 0 is below the main diagonal.

        m must be 2-D; elements above the k-th diagonal are set to zero.
    """
    # (The former Numeric 'spacesaver' bookkeeping was dead code and has
    # been removed.)
    m = asarray(m)
    mask = tri(m.shape[0], m.shape[1], k=k, dtype=m.dtype.char)
    return mask*m

def triu(m, k=0):
    """ returns the elements on and above the k-th diagonal of m.  k=0 is the
        main diagonal, k > 0 is above and k < 0 is below the main diagonal.

        m must be 2-D; elements below the k-th diagonal are set to zero.
    """
    # (The former Numeric 'spacesaver' bookkeeping was dead code and has
    # been removed.)
    m = asarray(m)
    # tri(..., k-1) marks everything strictly below the k-th diagonal.
    strictly_below = tri(m.shape[0], m.shape[1], k-1, m.dtype.char)
    return (1 - strictly_below)*m

def toeplitz(c,r=None):
    """ Construct a toeplitz matrix (i.e. a matrix with constant diagonals).

        Description:

           toeplitz(c,r) is a non-symmetric Toeplitz matrix with c as its first
           column and r as its first row.  If r[0] != c[0], a warning is
           printed and c[0] is used.

           toeplitz(c) is a Hermitian Toeplitz matrix whose first row is c
           and whose first column is conjugate(c), with c[0] on the diagonal
           (for real c it is symmetric).  The input c is not modified.

           If c or r is a scalar, c is returned unchanged.

        See also: hankel
    """
    isscalar = numpy.isscalar
    if isscalar(c) or isscalar(r):
        return c
    if r is None:
        # Build r and c from private copies. Aliasing r to the caller's c and
        # writing r[0] in place would mutate the caller's data (and fail
        # outright for tuples).
        r = ravel(asarray_chkfinite(c))
        c = conjugate(r)     # conjugate() returns a new array
        c[0] = r[0]          # keep the caller's c[0] on the diagonal
    r = ravel(asarray_chkfinite(r))
    c = ravel(asarray_chkfinite(c))
    rN, cN = len(r), len(c)
    if r[0] != c[0]:
        print("Warning: column and row values don't agree; column value used.")
    # Entry (i, j) is c[i-j] for i >= j and r[j-i] otherwise; vals holds
    # r reversed (without r[0]) followed by c, indexed by i - j + rN - 1.
    vals = r_[r[rN-1:0:-1], c]
    cols = mgrid[0:cN]
    rows = mgrid[rN:0:-1]
    indx = cols[:, newaxis]*ones((1, rN), dtype=int) + \
           rows[newaxis, :]*ones((cN, 1), dtype=int) - 1
    return take(vals, indx, 0)


def hankel(c,r=None):
    """ Construct a hankel matrix (i.e. matrix with constant anti-diagonals).

        Description:

          hankel(c,r) is a Hankel matrix whose first column is c and whose
          last row is r.  If r[0] != c[-1], a warning is printed and c[-1]
          is used.

          hankel(c) is a square Hankel matrix whose first column is c.
          Elements below the first anti-diagonal are zero.

          If c or r is a scalar, c is returned unchanged.

        See also:  toeplitz
    """
    if numpy.isscalar(c) or numpy.isscalar(r):
        return c
    if r is None:
        r = zeros(len(c))
    elif r[0] != c[-1]:
        print("Warning: column and row values don't agree; column value used.")
    r = ravel(asarray_chkfinite(r))
    c = ravel(asarray_chkfinite(c))
    # Entry (i, j) depends only on i + j: the first len(c) values come
    # from c, the rest from r[1:].
    vals = r_[c, r[1:]]
    anti_diag = arange(len(c))[:, newaxis] + arange(len(r))[newaxis, :]
    return take(vals, anti_diag, 0)

def all_mat(*args):
    """ all_mat(*args) -> list of matrices

    Convert every argument to a numpy matrix and return them as a list.
    A list comprehension is used because map() returns a lazy iterator on
    Python 3.
    """
    return [Matrix(arg) for arg in args]

def kron(a,b):
    """kronecker product of a and b

    Kronecker product of two matrices is block matrix
    [[ a[ 0 ,0]*b, a[ 0 ,1]*b, ... , a[ 0 ,n-1]*b  ],
     [ ...                                   ...   ],
     [ a[m-1,0]*b, a[m-1,1]*b, ... , a[m-1,n-1]*b  ]]
    """
    if not a.flags['CONTIGUOUS']:
        a = reshape(a, a.shape)
    if not b.flags['CONTIGUOUS']:
        b = reshape(b, b.shape)
    o = outer(a,b)
    o=o.reshape(a.shape + b.shape)
    return concatenate(concatenate(o, axis=1), axis=1)
