"""Iterative methods for solving linear systems"""

# Public solver entry points exported by this module.
__all__ = [
    'bicg',
    'bicgstab',
    'cg',
    'cgs',
    'gmres',
    'qmr',
]

import _iterative
import numpy as np

from scipy.sparse.linalg.interface import LinearOperator
from utils import make_system

# Maps a NumPy dtype character to the one-letter prefix that the solvers
# below prepend to a routine name when looking it up in _iterative
# (e.g. 'd' + 'cgrevcom').
_type_conv = dict(f='s', d='d', F='c', D='z')


# Part of the docstring common to all iterative solvers.  set_docstring()
# places this text between a solver-specific header and footer.
common_doc = \
"""
Parameters
----------
A : {sparse matrix, dense matrix, LinearOperator}
    The N-by-N matrix of the linear system.
b : {array, matrix}
    Right hand side of the linear system. Has shape (N,) or (N,1).

Optional Parameters
-------------------
x0  : {array, matrix}
    Starting guess for the solution.
tol : float
    Relative tolerance to achieve before terminating.
maxiter : integer
    Maximum number of iterations.  Iteration will stop after maxiter
    steps even if the specified tolerance has not been achieved.
M : {sparse matrix, dense matrix, LinearOperator}
    Preconditioner for A.  The preconditioner should approximate the
    inverse of A.  Effective preconditioning dramatically improves the
    rate of convergence, which implies that fewer iterations are needed
    to reach a given error tolerance.
callback : function
    User-supplied function to call after each iteration.  It is called
    as callback(xk), where xk is the current solution vector.

Outputs
-------
x : {array, matrix}
    The converged solution.
info : integer
    Provides convergence information:
        0  : successful exit
        >0 : convergence to tolerance not achieved, number of iterations
        <0 : illegal input or breakdown

Deprecated Parameters
---------------------
xtype : {'f','d','F','D'}
    The type of the result.  If None, then it will be determined from
    A.dtype.char and b.  If A does not have a typecode method then it
    will compute A.matvec(x0) to get a typecode.   To save the extra
    computation when A does not have a typecode attribute use xtype=0
    for the same type as b or use xtype='f','d','F',or 'D'.
    This parameter has been superseded by LinearOperator.
"""


def set_docstring(header, footer):
    """Return a decorator that installs a solver docstring.

    The installed text is ``header``, the module-level ``common_doc`` and
    ``footer``, separated by single newlines.  The decorated function
    itself is returned, with only its ``__doc__`` replaced.
    """
    def combine(fn):
        fn.__doc__ = '\n'.join((header, common_doc, footer))
        return fn
    return combine



@set_docstring('Use BIConjugate Gradient iteration to solve A x = b','')
def bicg(A, b, x0=None, tol=1e-5, maxiter=None, xtype=None, M=None, callback=None):
    # The user-facing docstring is installed by the decorator above.
    A, M, x, b, postprocess = make_system(A, M, x0, b, xtype)

    n = len(b)
    if maxiter is None:
        maxiter = 10*n

    # BiCG uses both the forward (matvec) and reverse (rmatvec) products of
    # the operator and of the preconditioner.
    matvec = A.matvec
    rmatvec = A.rmatvec
    psolve = M.matvec
    rpsolve = M.rmatvec

    # Pick the routines matching the dtype of the solution vector.
    prefix = _type_conv[x.dtype.char]
    revcom = getattr(_iterative, prefix + 'bicgrevcom')
    stoptest = getattr(_iterative, prefix + 'stoptest2')

    # State threaded through successive revcom calls.
    work = np.zeros(6*n, dtype=x.dtype)
    resid, info = tol, 0
    ndx1, ndx2 = 1, -1
    ijob = 1
    iter_ = maxiter
    bnrm2 = -1.0
    first_stoptest = True
    notify = callback is not None

    while True:
        prev_iter = iter_
        x, iter_, resid, info, ndx1, ndx2, sclr1, sclr2, ijob = \
            revcom(b, x, work, iter_, resid, info, ndx1, ndx2, ijob)
        if notify and iter_ > prev_iter:
            callback(x)

        if ijob == -1:
            # revcom signals termination; report the final iterate.
            if notify:
                callback(x)
            break

        # ndx1/ndx2 are 1-based offsets of length-n segments inside work.
        seg1 = slice(ndx1 - 1, ndx1 - 1 + n)
        seg2 = slice(ndx2 - 1, ndx2 - 1 + n)

        if ijob in (1, 2, 5):
            # work[seg2] <- sclr2*work[seg2] + sclr1*(operator product);
            # the scaling is applied before the product is evaluated.
            work[seg2] *= sclr2
            if ijob == 1:
                product = matvec(work[seg1])
            elif ijob == 2:
                product = rmatvec(work[seg1])
            else:
                product = matvec(x)
            work[seg2] += sclr1*product
        elif ijob == 3:
            work[seg1] = psolve(work[seg2])
        elif ijob == 4:
            work[seg1] = rpsolve(work[seg2])
        elif ijob == 6:
            if first_stoptest:
                # Only the very first stopping test is entered with info=-1.
                info = -1
                first_stoptest = False
            bnrm2, resid, info = stoptest(work[seg1], b, bnrm2, tol, info)

        # Every call after the first re-enters revcom with ijob=2.
        ijob = 2

    if info > 0 and iter_ == maxiter and resid > tol:
        #info isn't set appropriately otherwise
        info = iter_

    return postprocess(x), info

@set_docstring('Use BIConjugate Gradient STABilized iteration to solve A x = b','')
def bicgstab(A, b, x0=None, tol=1e-5, maxiter=None, xtype=None, M=None, callback=None):
    # The user-facing docstring is installed by the decorator above.
    A,M,x,b,postprocess = make_system(A,M,x0,b,xtype)

    n = len(b)
    if maxiter is None:
        maxiter = n*10

    # Both callables are bound here once, so no lazy lookup is needed
    # inside the loop.
    matvec = A.matvec
    psolve = M.matvec

    # Pick the routines matching the dtype of the solution vector.
    ltr = _type_conv[x.dtype.char]
    revcom   = getattr(_iterative, ltr + 'bicgstabrevcom')
    stoptest = getattr(_iterative, ltr + 'stoptest2')

    # State threaded through successive revcom calls.
    resid = tol
    ndx1 = 1
    ndx2 = -1
    work = np.zeros(7*n,dtype=x.dtype)
    ijob = 1
    info = 0
    ftflag = True
    bnrm2 = -1.0
    iter_ = maxiter
    while True:
        olditer = iter_
        x, iter_, resid, info, ndx1, ndx2, sclr1, sclr2, ijob = \
           revcom(b, x, work, iter_, resid, info, ndx1, ndx2, ijob)
        if callback is not None and iter_ > olditer:
            callback(x)
        # ndx1/ndx2 are 1-based offsets of length-n segments inside work.
        slice1 = slice(ndx1-1, ndx1-1+n)
        slice2 = slice(ndx2-1, ndx2-1+n)
        if (ijob == -1):
            # revcom signals termination; report the final iterate.
            if callback is not None:
                callback(x)
            break
        elif (ijob == 1):
            # work[slice2] <- sclr2*work[slice2] + sclr1*A*work[slice1]
            work[slice2] *= sclr2
            work[slice2] += sclr1*matvec(work[slice1])
        elif (ijob == 2):
            # Apply the preconditioner.
            work[slice1] = psolve(work[slice2])
        elif (ijob == 3):
            # work[slice2] <- sclr2*work[slice2] + sclr1*A*x
            work[slice2] *= sclr2
            work[slice2] += sclr1*matvec(x)
        elif (ijob == 4):
            if ftflag:
                # Only the very first stopping test is entered with info=-1.
                info = -1
                ftflag = False
            bnrm2, resid, info = stoptest(work[slice1], b, bnrm2, tol, info)
        # Every call after the first re-enters revcom with ijob=2.
        ijob = 2

    if info > 0 and iter_ == maxiter and resid > tol:
        #info isn't set appropriately otherwise
        info = iter_

    return postprocess(x), info

@set_docstring('Use Conjugate Gradient iteration to solve A x = b','')
def cg(A, b, x0=None, tol=1e-5, maxiter=None, xtype=None, M=None, callback=None):
    # The user-facing docstring is installed by the decorator above.
    A, M, x, b, postprocess = make_system(A, M, x0, b, xtype)

    n = len(b)
    if maxiter is None:
        maxiter = 10*n

    apply_A = A.matvec
    apply_M = M.matvec

    # Pick the routines matching the dtype of the solution vector.
    prefix = _type_conv[x.dtype.char]
    revcom = getattr(_iterative, prefix + 'cgrevcom')
    stoptest = getattr(_iterative, prefix + 'stoptest2')

    # State threaded through successive revcom calls.
    work = np.zeros(4*n, dtype=x.dtype)
    resid, info = tol, 0
    ndx1, ndx2 = 1, -1
    ijob = 1
    iter_ = maxiter
    bnrm2 = -1.0
    first_stoptest = True
    notify = callback is not None

    while True:
        prev_iter = iter_
        x, iter_, resid, info, ndx1, ndx2, sclr1, sclr2, ijob = \
            revcom(b, x, work, iter_, resid, info, ndx1, ndx2, ijob)
        if notify and iter_ > prev_iter:
            callback(x)

        if ijob == -1:
            # revcom signals termination; report the final iterate.
            if notify:
                callback(x)
            break

        # ndx1/ndx2 are 1-based offsets of length-n segments inside work.
        seg1 = slice(ndx1 - 1, ndx1 - 1 + n)
        seg2 = slice(ndx2 - 1, ndx2 - 1 + n)

        if ijob in (1, 3):
            # work[seg2] <- sclr2*work[seg2] + sclr1*A*v, where v is a work
            # segment (job 1) or the current iterate (job 3).  The scaling
            # is applied before the product is evaluated.
            work[seg2] *= sclr2
            operand = work[seg1] if ijob == 1 else x
            work[seg2] += sclr1*apply_A(operand)
        elif ijob == 2:
            work[seg1] = apply_M(work[seg2])
        elif ijob == 4:
            if first_stoptest:
                # Only the very first stopping test is entered with info=-1.
                info = -1
                first_stoptest = False
            bnrm2, resid, info = stoptest(work[seg1], b, bnrm2, tol, info)

        # Every call after the first re-enters revcom with ijob=2.
        ijob = 2

    if info > 0 and iter_ == maxiter and resid > tol:
        #info isn't set appropriately otherwise
        info = iter_

    return postprocess(x), info


@set_docstring('Use Conjugate Gradient Squared iteration to solve A x = b','')
def cgs(A, b, x0=None, tol=1e-5, maxiter=None, xtype=None, M=None, callback=None):
    # The user-facing docstring is installed by the decorator above.
    A, M, x, b, postprocess = make_system(A, M, x0, b, xtype)

    n = len(b)
    if maxiter is None:
        maxiter = 10*n

    operator_product = A.matvec
    precondition = M.matvec

    # Pick the routines matching the dtype of the solution vector.
    type_letter = _type_conv[x.dtype.char]
    revcom = getattr(_iterative, type_letter + 'cgsrevcom')
    stoptest = getattr(_iterative, type_letter + 'stoptest2')

    # State threaded through successive revcom calls.
    work = np.zeros(7*n, dtype=x.dtype)
    resid = tol
    info = 0
    ndx1 = 1
    ndx2 = -1
    ijob = 1
    iter_ = maxiter
    bnrm2 = -1.0
    stoptest_pending_init = True
    have_callback = callback is not None

    while True:
        iter_before = iter_
        x, iter_, resid, info, ndx1, ndx2, sclr1, sclr2, ijob = \
            revcom(b, x, work, iter_, resid, info, ndx1, ndx2, ijob)
        if have_callback and iter_ > iter_before:
            callback(x)

        if ijob == -1:
            # revcom signals termination; report the final iterate.
            if have_callback:
                callback(x)
            break

        # ndx1/ndx2 are 1-based offsets of length-n segments inside work.
        first = slice(ndx1 - 1, ndx1 - 1 + n)
        second = slice(ndx2 - 1, ndx2 - 1 + n)

        if ijob == 2:
            work[first] = precondition(work[second])
        elif ijob == 4:
            if stoptest_pending_init:
                # Only the very first stopping test is entered with info=-1.
                info = -1
                stoptest_pending_init = False
            bnrm2, resid, info = stoptest(work[first], b, bnrm2, tol, info)
        elif ijob == 1 or ijob == 3:
            # work[second] <- sclr2*work[second] + sclr1*A*v, with v a work
            # segment (job 1) or the current iterate (job 3).  The scaling
            # is applied before the product is evaluated.
            work[second] *= sclr2
            if ijob == 1:
                work[second] += sclr1*operator_product(work[first])
            else:
                work[second] += sclr1*operator_product(x)

        # Every call after the first re-enters revcom with ijob=2.
        ijob = 2

    if info > 0 and iter_ == maxiter and resid > tol:
        #info isn't set appropriately otherwise
        info = iter_

    return postprocess(x), info


def gmres(A, b, x0=None, tol=1e-5, restrt=20, maxiter=None, xtype=None, M=None, callback=None):
    """Use Generalized Minimal RESidual iteration to solve A x = b

    Parameters
    ----------
    A : {sparse matrix, dense matrix, LinearOperator}
        The N-by-N matrix of the linear system.
    b : {array, matrix}
        Right hand side of the linear system. Has shape (N,) or (N,1).

    Optional Parameters
    -------------------
    x0  : {array, matrix}
        Starting guess for the solution.
    tol : float
        Relative tolerance to achieve before terminating.
    restrt : integer
        Number of iterations between restarts. Larger values increase
        iteration cost, but may be necessary for convergence.
        Values larger than N are clipped to N.
    maxiter : integer
        Maximum number of iterations.  Iteration will stop after maxiter
        steps even if the specified tolerance has not been achieved.
        Defaults to 10*N.
    M : {sparse matrix, dense matrix, LinearOperator}
        Preconditioner for A.  The preconditioner should approximate the
        inverse of A.  Effective preconditioning dramatically improves the
        rate of convergence, which implies that fewer iterations are needed
        to reach a given error tolerance.
    callback : function
        User-supplied function to call after each iteration.  It is called
        as callback(rk), where rk is the current residual value: the same
        scalar that is compared against tol, not the residual vector.

    Outputs
    -------
    x : {array, matrix}
        The converged solution.
    info : integer
        Provides convergence information:
            0  : successful exit
            >0 : convergence to tolerance not achieved, number of iterations
            <0 : illegal input or breakdown

    See Also
    --------
    LinearOperator

    Deprecated Parameters
    ---------------------
    xtype : {'f','d','F','D'}
        The type of the result.  If None, then it will be determined from
        A.dtype.char and b.  If A does not have a typecode method then it
        will compute A.matvec(x0) to get a typecode.   To save the extra
        computation when A does not have a typecode attribute use xtype=0
        for the same type as b or use xtype='f','d','F',or 'D'.
        This parameter has been superseded by LinearOperator.

    """
    A,M,x,b,postprocess = make_system(A,M,x0,b,xtype)

    n = len(b)
    if maxiter is None:
        maxiter = n*10

    # The work arrays below are sized from restrt, so cap it at n.
    restrt = min(restrt, n)

    matvec = A.matvec
    psolve = M.matvec

    # Pick the routines matching the dtype of the solution vector.
    ltr = _type_conv[x.dtype.char]
    revcom   = getattr(_iterative, ltr + 'gmresrevcom')
    stoptest = getattr(_iterative, ltr + 'stoptest2')

    # State threaded through successive revcom calls.
    resid = tol
    ndx1 = 1
    ndx2 = -1
    work  = np.zeros((6+restrt)*n,dtype=x.dtype)
    work2 = np.zeros((restrt+1)*(2*restrt+2),dtype=x.dtype)
    ijob = 1
    info = 0
    ftflag = True
    bnrm2 = -1.0
    iter_ = maxiter

    # Bookkeeping used to decide when to invoke the callback and when the
    # iteration budget has been spent.
    old_ijob = ijob
    first_pass = True
    resid_ready = False
    iter_num = 1
    while True:
        x, iter_, resid, info, ndx1, ndx2, sclr1, sclr2, ijob = \
           revcom(b, x, restrt, work, work2, iter_, resid, info, ndx1, ndx2, ijob)
        # ndx1/ndx2 are 1-based offsets of length-n segments inside work.
        slice1 = slice(ndx1-1, ndx1-1+n)
        slice2 = slice(ndx2-1, ndx2-1+n)
        if (ijob == -1): # gmres success, update last residual
            if resid_ready and callback is not None:
                callback(resid)
                resid_ready = False

            break
        elif (ijob == 1):
            # work[slice2] <- sclr2*work[slice2] + sclr1*A*x
            work[slice2] *= sclr2
            work[slice2] += sclr1*matvec(x)
        elif (ijob == 2):
            # Apply the preconditioner.
            work[slice1] = psolve(work[slice2])
            # A preconditioner solve that directly follows a job-3 product
            # (other than on the very first pass) marks a residual as
            # ready to be reported.
            if not first_pass and old_ijob==3:
                resid_ready = True

            first_pass = False
        elif (ijob == 3):
            # work[slice2] <- sclr2*work[slice2] + sclr1*A*work[slice1]
            work[slice2] *= sclr2
            work[slice2] += sclr1*matvec(work[slice1])
            if resid_ready and callback is not None:
                callback(resid)
                resid_ready = False
                # NOTE: iter_num only advances when a callback is supplied,
                # so the iter_num > maxiter exit below depends on it.
                iter_num = iter_num+1

        elif (ijob == 4):
            if ftflag:
                # Only the very first stopping test is entered with info=-1.
                info = -1
                ftflag = False
            bnrm2, resid, info = stoptest(work[slice1], b, bnrm2, tol, info)

        old_ijob = ijob
        # Every call after the first re-enters revcom with ijob=2.
        ijob = 2

        if iter_num > maxiter:
            break

    if info >= 0 and resid > tol:
        #info isn't set appropriately otherwise
        info = maxiter

    return postprocess(x), info


def qmr(A, b, x0=None, tol=1e-5, maxiter=None, xtype=None, M1=None, M2=None, callback=None):
    """Use Quasi-Minimal Residual iteration to solve A x = b

    Parameters
    ----------
    A : {sparse matrix, dense matrix, LinearOperator}
        The N-by-N matrix of the linear system.
    b : {array, matrix}
        Right hand side of the linear system. Has shape (N,) or (N,1).

    Optional Parameters
    -------------------
    x0  : {array, matrix}
        Starting guess for the solution.
    tol : float
        Relative tolerance to achieve before terminating.
    maxiter : integer
        Maximum number of iterations.  Iteration will stop after maxiter
        steps even if the specified tolerance has not been achieved.
    M1 : {sparse matrix, dense matrix, LinearOperator}
        Left preconditioner for A.
    M2 : {sparse matrix, dense matrix, LinearOperator}
        Right preconditioner for A. Used together with the left
        preconditioner M1.  The matrix M1*A*M2 should be better
        conditioned than A alone.
    callback : function
        User-supplied function to call after each iteration.  It is called
        as callback(xk), where xk is the current solution vector.

    If neither M1 nor M2 is given and A provides ``psolve``/``rpsolve``
    methods, those are used as the left and right preconditioners.
    Otherwise any preconditioner that is not given defaults to the
    identity.  M1 and M2 are used through their ``matvec`` and ``rmatvec``
    methods.

    Outputs
    -------
    x : {array, matrix}
        The converged solution.
    info : integer
        Provides convergence information:
            0  : successful exit
            >0 : convergence to tolerance not achieved, number of iterations
            <0 : illegal input or breakdown

    See Also
    --------
    LinearOperator

    Deprecated Parameters
    ---------------------
    xtype : {'f','d','F','D'}
        The type of the result.  If None, then it will be determined from
        A.dtype.char and b.  If A does not have a typecode method then it
        will compute A.matvec(x0) to get a typecode.   To save the extra
        computation when A does not have a typecode attribute use xtype=0
        for the same type as b or use xtype='f','d','F',or 'D'.
        This parameter has been superseded by LinearOperator.

    """
    # Keep the caller's original object: its optional psolve/rpsolve
    # methods are looked up on it rather than on the wrapped operator.
    A_ = A
    A,M,x,b,postprocess = make_system(A,None,x0,b,xtype)

    if M1 is None and M2 is None and hasattr(A_,'psolve'):
        def left_psolve(b):
            return A_.psolve(b,'left')
        def right_psolve(b):
            return A_.psolve(b,'right')
        def left_rpsolve(b):
            return A_.rpsolve(b,'left')
        def right_rpsolve(b):
            return A_.rpsolve(b,'right')
        M1 = LinearOperator(A.shape, matvec=left_psolve, rmatvec=left_rpsolve)
        M2 = LinearOperator(A.shape, matvec=right_psolve, rmatvec=right_rpsolve)
    else:
        # Fill in each missing side independently so that supplying only
        # one of M1/M2 does not leave the other as None.
        def identity(v):
            return v
        if M1 is None:
            M1 = LinearOperator(A.shape, matvec=identity, rmatvec=identity)
        if M2 is None:
            M2 = LinearOperator(A.shape, matvec=identity, rmatvec=identity)

    n = len(b)
    if maxiter is None:
        maxiter = n*10

    # Pick the routines matching the dtype of the solution vector.
    ltr = _type_conv[x.dtype.char]
    revcom   = getattr(_iterative, ltr + 'qmrrevcom')
    stoptest = getattr(_iterative, ltr + 'stoptest2')

    # State threaded through successive revcom calls.
    resid = tol
    ndx1 = 1
    ndx2 = -1
    work = np.zeros(11*n,x.dtype)
    ijob = 1
    info = 0
    ftflag = True
    bnrm2 = -1.0
    iter_ = maxiter
    while True:
        olditer = iter_
        x, iter_, resid, info, ndx1, ndx2, sclr1, sclr2, ijob = \
           revcom(b, x, work, iter_, resid, info, ndx1, ndx2, ijob)
        if callback is not None and iter_ > olditer:
            callback(x)
        # ndx1/ndx2 are 1-based offsets of length-n segments inside work.
        slice1 = slice(ndx1-1, ndx1-1+n)
        slice2 = slice(ndx2-1, ndx2-1+n)
        if (ijob == -1):
            # revcom signals termination; report the final iterate.
            if callback is not None:
                callback(x)
            break
        elif (ijob == 1):
            # work[slice2] <- sclr2*work[slice2] + sclr1*A.matvec(work[slice1])
            work[slice2] *= sclr2
            work[slice2] += sclr1*A.matvec(work[slice1])
        elif (ijob == 2):
            # Same update using the reverse product A.rmatvec.
            work[slice2] *= sclr2
            work[slice2] += sclr1*A.rmatvec(work[slice1])
        elif (ijob == 3):
            work[slice1] = M1.matvec(work[slice2])
        elif (ijob == 4):
            work[slice1] = M2.matvec(work[slice2])
        elif (ijob == 5):
            work[slice1] = M1.rmatvec(work[slice2])
        elif (ijob == 6):
            work[slice1] = M2.rmatvec(work[slice2])
        elif (ijob == 7):
            # work[slice2] <- sclr2*work[slice2] + sclr1*A.matvec(x)
            work[slice2] *= sclr2
            work[slice2] += sclr1*A.matvec(x)
        elif (ijob == 8):
            if ftflag:
                # Only the very first stopping test is entered with info=-1.
                info = -1
                ftflag = False
            bnrm2, resid, info = stoptest(work[slice1], b, bnrm2, tol, info)
        # Every call after the first re-enters revcom with ijob=2.
        ijob = 2

    if info > 0 and iter_ == maxiter and resid > tol:
        #info isn't set appropriately otherwise
        info = iter_

    return postprocess(x), info
