from numpy import sqrt, inner, finfo, zeros
from numpy.linalg import norm

from utils import make_system
from iterative import set_docstring

# Public API of this module.
__all__ = ['minres']


# Opening section of minres's docstring.  header, Ainfo and footer are handed
# to set_docstring (the decorator on minres), which presumably stitches them
# into the final docstring around the shared parameter docs -- confirm in
# iterative.set_docstring.  These are runtime strings: edit them as text.
header = \
"""Use MINimum RESidual iteration to solve Ax=b

MINRES minimizes norm(A*x - b) for a real symmetric matrix A.  Unlike
the Conjugate Gradient method, A can be indefinite or singular.

If shift != 0 then the method solves (A - shift*I)x = b
"""

# Description of the A argument, passed as the second set_docstring argument.
Ainfo = "The real symmetric N-by-N matrix of the linear system"

# Closing section (notes and references), passed as the third set_docstring
# argument.
footer = \
"""
Notes
-----
THIS FUNCTION IS EXPERIMENTAL AND SUBJECT TO CHANGE!

References
----------
Solution of sparse indefinite systems of linear equations,
    C. C. Paige and M. A. Saunders (1975),
    SIAM J. Numer. Anal. 12(4), pp. 617-629.
    http://www.stanford.edu/group/SOL/software/minres.html

This file is a translation of the following MATLAB implementation:
    http://www.stanford.edu/group/SOL/software/minres/matlab/
"""

@set_docstring(header,
               Ainfo,
               footer)
def minres(A, b, x0=None, shift=0.0, tol=1e-5, maxiter=None, xtype=None,
           M=None, callback=None, show=False, check=False):
    # Preconditioned MINRES (Paige & Saunders) for (A - shift*I) x = b.
    #
    # Returns (x, info).  info is maxiter when the iteration limit stopped
    # the solve (istop == 6) and 0 otherwise.
    #
    # Raises ValueError when the preconditioner is found to be indefinite,
    # or, with check=True, when A or M fails the symmetry heuristics.
    A, M, x, b, postprocess = make_system(A, M, x0, b, xtype)

    matvec = A.matvec
    psolve = M.matvec

    first = 'Enter minres.   '
    last  = 'Exit  minres.   '

    n = A.shape[0]

    if maxiter is None:
        maxiter = 5 * n

    # Termination messages; msg[istop + 1] describes stop code istop (-1..9).
    msg   =[' beta2 = 0.  If M = I, b and x are eigenvectors    ',   # -1
            ' beta1 = 0.  The exact solution is  x = 0          ',   #  0
            ' A solution to Ax = b was found, given rtol        ',   #  1
            ' A least-squares solution was found, given rtol    ',   #  2
            ' Reasonable accuracy achieved, given eps           ',   #  3
            ' x has converged to an eigenvector                 ',   #  4
            ' acond has exceeded 0.1/eps                        ',   #  5
            ' The iteration limit was reached                   ',   #  6
            ' A  does not define a symmetric matrix             ',   #  7
            ' M  does not define a symmetric matrix             ',   #  8
            ' M  does not define a pos-def preconditioner       ']   #  9

    # Single-argument print(...) behaves the same under Python 2 and 3.
    if show:
        print(first + 'Solution of symmetric Ax = b')
        print(first + 'n      =  %3g     shift  =  %23.14e' % (n, shift))
        print(first + 'itnlim =  %3g     rtol   =  %11.2e' % (maxiter, tol))
        print('')

    istop = 0
    itn = 0
    Anorm = 0
    Acond = 0
    Arnorm = 0   # assigned up front so the final summary works if the loop never runs
    rnorm = 0
    ynorm = 0

    xtype = x.dtype

    eps = finfo(xtype).eps

    # Initial residual r1 = b - (A - shift*I) x.  Without a starting guess,
    # x = 0 and r1 = b, so the extra matvec is skipped.
    if x0 is None:
        x = zeros(n, dtype=xtype)
        r1 = b
    else:
        r1 = b - (matvec(x) - shift * x)

    # Set up y and v for the first Lanczos vector v1.
    # y  =  beta1 P' v1,  where  P = C**(-1).
    # v is really P' v1.
    y = psolve(r1)

    beta1 = inner(r1, y)

    if beta1 < 0:
        raise ValueError('indefinite preconditioner')
    elif beta1 == 0:
        # The (preconditioned) residual is zero: x already solves the system.
        return (postprocess(x), 0)

    beta1 = sqrt(beta1)

    if check:
        # Heuristic symmetry tests with a tolerance scaled by eps**(1/3).
        # are these too strict?

        # A: compare (A y)'(A y) with y'(A (A y)).
        w    = matvec(y)
        r2   = matvec(w)
        s    = inner(w, w)
        t    = inner(y, r2)
        z    = abs(s - t)
        epsa = (s + eps) * eps**(1.0/3.0)
        if z > epsa:
            raise ValueError('non-symmetric matrix')

        # M: with y = psolve(r1), compare y'y with r1' psolve(y).
        r2   = psolve(y)
        s    = inner(y, y)
        t    = inner(r1, r2)
        z    = abs(s - t)
        epsa = (s + eps) * eps**(1.0/3.0)
        if z > epsa:
            raise ValueError('non-symmetric preconditioner')

    # Initialize other quantities
    oldb   = 0
    beta   = beta1
    dbar   = 0
    epsln  = 0
    qrnorm = beta1
    phibar = beta1
    rhs1   = beta1
    rhs2   = 0
    tnorm2 = 0
    ynorm2 = 0
    cs     = -1
    sn     = 0
    w      = zeros(n, dtype=xtype)
    w2     = zeros(n, dtype=xtype)
    r2     = r1

    if show:
        print('')
        print('')
        print('   Itn     x(1)     Compatible    LS       norm(A)  cond(A) gbar/|A|')

    while itn < maxiter:
        itn += 1

        # Lanczos step on the preconditioned operator.
        s = 1.0/beta
        v = s*y

        y  = matvec(v)
        y  = y - shift * v

        if itn >= 2:
            y  = y - (beta/oldb)*r1

        alfa   = inner(v, y)
        y      = y - (alfa/beta)*r2
        r1     = r2
        r2     = y
        y      = psolve(r2)
        oldb   = beta
        beta   = inner(r2, y)
        if beta < 0:
            # beta**2 = r2' psolve(r2) can only be negative if the
            # preconditioner is not positive definite (msg entry 9).
            raise ValueError('indefinite preconditioner')
        beta    = sqrt(beta)
        tnorm2 += alfa**2 + oldb**2 + beta**2

        if itn == 1:
            if beta/beta1 <= 10*eps:
                istop = -1  # Terminate later
            #tnorm2 = alfa**2 ??
            gmax = abs(alfa)
            gmin = gmax

        # Apply previous rotation Qk-1 to get
        #   [deltak epslnk+1] = [cs  sn][dbark    0   ]
        #   [gbar k dbar k+1]   [sn -cs][alfak betak+1].

        oldeps = epsln
        delta  = cs * dbar  +  sn * alfa   # delta1 = 0         deltak
        gbar   = sn * dbar  -  cs * alfa   # gbar 1 = alfa1     gbar k
        epsln  =               sn * beta   # epsln2 = 0         epslnk+1
        dbar   =            -  cs * beta   # dbar 2 = beta2     dbar k+1
        root   = norm([gbar, dbar])
        Arnorm = phibar * root

        # Compute the next plane rotation Qk

        gamma  = norm([gbar, beta])       # gammak
        gamma  = max(gamma, eps)
        cs     = gbar / gamma             # ck
        sn     = beta / gamma             # sk
        phi    = cs * phibar              # phik
        phibar = sn * phibar              # phibark+1

        # Update  x.

        denom = 1.0/gamma
        w1    = w2
        w2    = w
        w     = (v - oldeps*w1 - delta*w2) * denom
        x     = x + phi*w

        # Go round again.

        gmax   = max(gmax, gamma)
        gmin   = min(gmin, gamma)
        z      = rhs1 / gamma
        ynorm2 = z**2  +  ynorm2
        rhs1   = rhs2 -  delta*z
        rhs2   =      -  epsln*z

        # Estimate various norms and test for convergence.

        Anorm  = sqrt(tnorm2)
        ynorm  = sqrt(ynorm2)
        epsa   = Anorm * eps
        epsx   = Anorm * ynorm * eps
        epsr   = Anorm * ynorm * tol
        diag   = gbar

        if diag == 0:
            diag = epsa

        qrnorm = phibar
        rnorm  = qrnorm
        test1  = rnorm / (Anorm*ynorm)    #  ||r|| / (||A|| ||x||)
        test2  = root  /  Anorm           # ||Ar|| / (||A|| ||r||)

        # Estimate  cond(A).
        # In this version we look at the diagonals of  R  in the
        # factorization of the lower Hessenberg matrix,  Q * H = R,
        # where H is the tridiagonal matrix from Lanczos with one
        # extra row, beta(k+1) e_k^T.

        Acond  = gmax/gmin

        # See if any of the stopping criteria are satisfied.
        # In rare cases, istop is already -1 from above (Abar = const*I).
        # Later assignments deliberately override earlier ones.

        if istop == 0:
            t1 = 1 + test1      # These tests work if tol < eps
            t2 = 1 + test2
            if t2    <= 1       : istop = 2
            if t1    <= 1       : istop = 1

            if itn   >= maxiter : istop = 6
            if Acond >= 0.1/eps : istop = 4
            if epsx  >= beta1   : istop = 3
            #if rnorm <= epsx   : istop = 2
            #if rnorm <= epsr   : istop = 1
            if test2 <= tol     : istop = 2
            if test1 <= tol     : istop = 1

        # Print an iteration line only when show is set and a condition holds.
        if show:
            prnt = (n <= 40 or itn <= 10 or itn >= maxiter - 10
                    or itn % 10 == 0
                    or qrnorm <= 10*epsx or qrnorm <= 10*epsr
                    or Acond <= 1e-2/eps
                    or istop != 0)
            if prnt:
                str1 = '%6g %12.5e %10.3e'  % (itn, x[0], test1)
                str2 = ' %10.3e'            % (test2,)
                str3 = ' %8.1e %8.1e %8.1e' % (Anorm, Acond, gbar/Anorm)

                print(str1 + str2 + str3)

                if itn % 10 == 0:
                    print('')

        if callback is not None:
            callback(x)

        # Any nonzero stop code (including -1 from the first step) ends the solve.
        if istop != 0:
            break

    if show:
        print('')
        print(last + ' istop   =  %3g               itn   =%5g' % (istop, itn))
        print(last + ' Anorm   =  %12.4e      Acond =  %12.4e' % (Anorm, Acond))
        print(last + ' rnorm   =  %12.4e      ynorm =  %12.4e' % (rnorm, ynorm))
        print(last + ' Arnorm  =  %12.4e' % (Arnorm,))
        print(last + msg[istop+1])

    if istop == 6:
        info = maxiter
    else:
        info = 0

    return (postprocess(x), info)


if __name__ == '__main__':
    # Smoke test: A = diag(1..n), with the inverse diagonal attached as
    # A.psolve.  That attribute is presumably picked up by make_system as the
    # preconditioner -- confirm in utils.make_system.
    #
    # numpy's ones/arange are used directly, and scipy.linalg.norm is NOT
    # imported here: importing it at this (module) level would rebind the
    # global `norm` that minres relies on.
    from numpy import ones, arange
    from scipy.sparse import spdiags

    n = 10

    residuals = []

    def cb(x):
        # Record the true residual norm ||b - A x|| after every iteration.
        residuals.append(norm(b - A*x))

    #A = poisson((10,),format='csr')
    A = spdiags([arange(1, n+1, dtype=float)], [0], n, n, format='csr')
    M = spdiags([1.0/arange(1, n+1, dtype=float)], [0], n, n, format='csr')
    A.psolve = M.matvec

    # A nonzero right-hand side: with b = 0, minres returns at beta1 == 0
    # before the first iteration, and the callback never runs.
    b = ones(A.shape[0])
    x, info = minres(A, b, tol=1e-12, maxiter=None, callback=cb)
    #x = cg(A,b,x0=b,tol=1e-12,maxiter=None,callback=cb)[0]

    print('info = %d, iterations = %d' % (info, len(residuals)))
    if residuals:
        print('final residual norm = %e' % residuals[-1])
