#include "linalg.h"
#include "utils.h"

#include <tjutils/tjtest.h>
#include <tjutils/tjthread.h>


#ifdef HAVE_LIBGSL
#include <gsl/gsl_linalg.h>
#include <gsl/gsl_eigen.h>
#include <gsl/gsl_sort_vector.h>
#endif


#ifdef HAVE_LAPACK

// Declarations of the external (C linkage) LAPACK routines used in this file.
// All arguments, including scalars, are passed by pointer.

// cgelss_ / sgelss_: complex and real single-precision variants, wrapped by the
// gelss() overloads below. Note that only the complex variant takes an RWORK array.
extern "C" void cgelss_( int* M, int* N, int* NRHS, STD_complex* A, int* LDA, STD_complex* B, int* LDB, float* S, float* RCOND, int* RANK, STD_complex* WORK, int* LWORK, float* RWORK, int* INFO );
extern "C" void sgelss_( int* M, int* N, int* NRHS, float*       A, int* LDA, float*       B, int* LDB, float* S, float* RCOND, int* RANK, float*       WORK, int* LWORK, int* INFO );

//extern "C" void cgesvd_( char* JOBU, char* JOBVT, int* M, int* N, STD_complex* A, int* LDA, float* S, STD_complex* U, int* LDU, STD_complex* VT, int* LDVT, STD_complex* WORK, int* LWORK, float* RWORK, int* INFO );
//extern "C" void sgesvd_( char* JOBU, char* JOBVT, int* M, int* N, float*       A, int* LDA, float* S, float*       U, int* LDU, float* VT,       int* LDVT, float*       WORK, int* LWORK,               int* INFO );

// ssyev_: used by eigenvalues() (called there with JOBZ='N', UPLO='U') on a float matrix
extern "C" void ssyev_( char* JOBZ, char* UPLO, int* N, float* A, int* LDA, float* W, float* WORK, int* LWORK, int* INFO);

// Choose the appropriate LAPACK function via function overloading
// Complex overload: forwards all arguments unchanged to cgelss_ and returns
// the real part of WORK[0] truncated to int (the caller uses this value as
// the workspace size after a query with LWORK=-1).
int gelss( int* M, int* N, int* NRHS,
           STD_complex* A, int* LDA,
           STD_complex* B, int* LDB,
           float* S, float* RCOND, int* RANK,
           STD_complex* WORK, int* LWORK, float* RWORK,
           int* INFO ) {

  cgelss_( M, N, NRHS, A, LDA, B, LDB, S, RCOND, RANK, WORK, LWORK, RWORK, INFO );

  // First workspace element, converted to an integer count
  return static_cast<int>(WORK[0].real());
}
// Real (float) overload: forwards to sgelss_ and returns WORK[0] truncated
// to int. RWORK is accepted only so that both overloads share one parameter
// list; it is not passed on because sgelss_ has no such argument.
int gelss( int* M, int* N, int* NRHS,
           float* A, int* LDA,
           float* B, int* LDB,
           float* S, float* RCOND, int* RANK,
           float* WORK, int* LWORK, float* RWORK,
           int* INFO ) {

  sgelss_( M, N, NRHS, A, LDA, B, LDB, S, RCOND, RANK, WORK, LWORK, INFO );

  // First workspace element, converted to an integer count
  return static_cast<int>(WORK[0]);
}

/*
int gesvd( char* JOBU, char* JOBVT, int* M, int* N, STD_complex* A, int* LDA, float* S, STD_complex* U, int* LDU, STD_complex* VT, int* LDVT, STD_complex* WORK, int* LWORK, float* RWORK, int* INFO) {
  cgesvd_( JOBU, JOBVT, M, N, A, LDA, S, U, LDU, VT, LDVT, WORK, LWORK, RWORK, INFO );
  return int(WORK[0].real());
}
int gesvd( char* JOBU, char* JOBVT, int* M, int* N, float*       A, int* LDA, float* S, float*       U, int* LDU, float*       VT, int* LDVT, float*       WORK, int* LWORK, float* RWORK, int* INFO) {
  sgesvd_( JOBU, JOBVT, M, N, A, LDA, S, U, LDU, VT, LDVT, WORK, LWORK, INFO );
  return int(WORK[0]);
}
*/



////////////////////////////////////

// Evaluates the INFO status code of a LAPACK call.
//   INFO == 0 : nothing is logged, returns false
//   INFO <  0 : logs which argument was illegal, returns true
//   INFO >  0 : logs a convergence failure, returns true
// 'caller' is used as the function label of the log object.
bool report_error(int INFO, const char* caller) {
  Log<OdinData> odinlog("",caller);

  if(INFO==0) {
    return false; // success, nothing to report
  }

  if(INFO<0) {
    ODINLOG(odinlog,errorLog) << "the " << -INFO << "-th argument had an illegal value." << STD_endl;
  } else {
    ODINLOG(odinlog,errorLog) << "the algorithm failed to converge." << STD_endl;
  }
  return true;
}

////////////////////////////////////

// File-local mutex: the LAPACK-calling functions in this file lock it via
// MutexLock unless HAVE_THREADSAFE_LAPACK is defined, so that LAPACK calls
// made through this file are serialized.
static Mutex lapack_mutex;

////////////////////////////////////

/*
Array<STD_complex,2> conjugate_transpose_matrix(Array<STD_complex,2>& M) {return Array<STD_complex,2>(conj(M.transpose(1,0)));}
Array<float,2>       conjugate_transpose_matrix(Array<float,2>& M)       {return M.transpose(1,0);}


template <typename T>
bool svd_lapack(const Data<T,2>& A, Data<T,2>& U, Data<float,1>& sigma, Data<T,2>& V) {
  Log<OdinData> odinlog("","svd_lapack");

  int M=A.extent(0); // rows
  int N=A.extent(1); // cols
  ODINLOG(odinlog,normalDebug) << "M/N=" << M << "/" << N << STD_endl;


  Array<T,2> A_fortran(A.shape(), ColumnMajorArray<2>()); // Array with Fortran storage order
  A_fortran=A; // creates unique copy

  sigma.resize(N);

  Array<T,2> U_fortran(A.shape(), ColumnMajorArray<2>()); // Array with Fortran storage order
  Array<T,2> VT_fortran(TinyVector<int,2>(N,N), ColumnMajorArray<2>()); // Array with Fortran storage order


  Array<T,1> WORK(1);
  int LWORK=-1; // get optimal workspace size
  Array<float,1> RWORK(5*STD_min(M,N));

  int NRHS=1;
  int RANK;

  int INFO;

  char JOBU='A';
  char JOBVT='A';

#ifndef HAVE_THREADSAFE_LAPACK
  MutexLock lock(lapack_mutex); //LAPACK is generally not thread safe, so we will serialize use of LAPACK
#endif

  // call once to get optimal size of the WORK array
  LWORK=gesvd( &JOBU, &JOBVT, &M, &N, A_fortran.data(), &M, sigma.data(), U_fortran.data(), &M, VT_fortran.data(), &N, WORK.data(), &LWORK, RWORK.data(), &INFO);
  ODINLOG(odinlog,normalDebug) << "INFO/LWORK=" << INFO << "/" << LWORK << STD_endl;
  if(report_error(INFO,"svd_lapack(worksize)")) return false;

  WORK.resize(LWORK);

  // perform SVD
  gesvd( &JOBU, &JOBVT, &M, &N, A_fortran.data(), &M, sigma.data(), U_fortran.data(), &M, VT_fortran.data(), &N, WORK.data(), &LWORK, RWORK.data(), &INFO);
  ODINLOG(odinlog,normalDebug) << "INFO=" << INFO << STD_endl;
  if(report_error(INFO,"svd_lapack(svd)")) return false;

  U.resize(M,N);
  U=U_fortran;

  V.resize(N,N);
  V=conjugate_transpose_matrix(VT_fortran);

  return true;
}
*/

////////////////////////////////////



// Solves A*x=b through the gelss() overload matching T and stores x in 'result'.
//
//   result        : resized to the number of columns of A and filled on success,
//                   left untouched on failure
//   A             : matrix with A.extent(0) rows and A.extent(1) columns
//   b             : right-hand side, one element per row of A
//   sv_truncation : passed to LAPACK as RCOND
//
// Returns false if report_error() flags either the workspace query or the
// actual solve, true otherwise.
// The row count is used as leading dimension of the right-hand-side buffer,
// i.e. the callers in this file only pass matrices with cols<=rows (see shape_error()).
template <typename T>
bool solve_linear_lapack(Data<T,1>& result, const Data<T,2>& A, const Data<T,1>& b, float sv_truncation) {
  Log<OdinData> odinlog("","solve_linear_lapack");

  int nrows=A.extent(0);
  int ncols=A.extent(1);
  ODINLOG(odinlog,normalDebug) << "M/N=" << nrows << "/" << ncols << STD_endl;

  // Transposed copy of A, so that the buffer handed to LAPACK holds A column by column
  Array<T,2> a_transposed(ncols,nrows);
  for(int irow=0; irow<nrows; irow++) {
    for(int icol=0; icol<ncols; icol++) {
      a_transposed(icol,irow)=A(irow,icol);
    }
  }

  // Private copy of b: LAPACK reads the right-hand side from it and writes the solution into it
  Array<T,1> rhs_and_solution(nrows);
  rhs_and_solution=b;

  Array<float,1> singular_values(ncols);
  Array<float,1> real_work(5*STD_min(nrows,ncols));

  // Start with a one-element workspace, it is only used to receive the size in the query call
  Array<T,1> work(1);
  int work_size=-1; // -1 requests a workspace-size query

  int n_rhs=1;
  int rank;
  int info;

#ifndef HAVE_THREADSAFE_LAPACK
  // LAPACK is generally not thread safe, so its use is serialized
  MutexLock lock(lapack_mutex);
#endif

  // 1st call: workspace-size query only
  work_size=gelss( &nrows, &ncols, &n_rhs, a_transposed.data(), &nrows, rhs_and_solution.data(), &nrows, singular_values.data(), &sv_truncation, &rank, work.data(), &work_size, real_work.data(), &info );
  ODINLOG(odinlog,normalDebug) << "INFO/LWORK=" << info << "/" << work_size << STD_endl;
  if(report_error(info,"solve_linear_lapack(worksize)")) return false;

  work.resize(work_size);

  // 2nd call: the actual solve
  gelss( &nrows, &ncols, &n_rhs, a_transposed.data(), &nrows, rhs_and_solution.data(), &nrows, singular_values.data(), &sv_truncation, &rank, work.data(), &work_size, real_work.data(), &info );
  ODINLOG(odinlog,normalDebug) << "INFO=" << info << STD_endl;
  if(report_error(info,"solve_linear_lapack(svd)")) return false;

  // The first 'ncols' elements of the in/out buffer are taken as the solution
  result.resize(ncols);
  result=rhs_and_solution(Range(0,ncols-1));

  return true;
}

#endif

///////////////////////////////////////////////////////////////////////////////////////////////////////


// Validates the shapes of a linear system A*x=b before it is solved.
//   A_shape  : (rows, cols) of the matrix A
//   b_extent : number of elements of the right-hand side b
// Returns true (after logging the first problem found) if A has a zero
// extent, has more columns than rows, or if b does not have one element
// per row of A. Returns false if the shapes are acceptable.
bool shape_error(const TinyVector<int,2>& A_shape, int b_extent) {
  Log<OdinData> odinlog("solve_linear","shape_error");

  const int nrows=A_shape(0);
  const int ncols=A_shape(1);

  bool failed=true; // reset below if no check triggers

  if(!nrows || !ncols) {
    ODINLOG(odinlog,errorLog) << "Zero-size matrix" << STD_endl;
  } else if(ncols>nrows) {
    ODINLOG(odinlog,errorLog) << "cols>rows matrices not supported" << STD_endl;
  } else if(b_extent!=nrows) {
    ODINLOG(odinlog,errorLog) << "size mismatch (b_extent=" << b_extent << ") != (A_nrows=" << nrows << ")" << STD_endl;
  } else {
    failed=false;
  }

  return failed;
}

//////////////////////////////////////////////////

/*
ComplexData<2> matrix_inverse(const ComplexData<2>& A, float sv_truncation) {
  Log<OdinData> odinlog("","matrix_inverse(complex)");

  Range all=Range::all();

  ComplexData<2> result;

  bool svdresult=false;

  ComplexData<2> U;
  Data<float,1> sigma;
  ComplexData<2> V;

#ifdef HAVE_LAPACK
  svdresult=svd_lapack(A, U, sigma, V);
#endif
  if(!svdresult) {
    ODINLOG(odinlog,errorLog) << "SVD failed" << STD_endl;
  }

  // Regularization
  int ncol=sigma.size();
  float threshold=sv_truncation*sigma(0);
  ComplexData<1> Sinv(ncol);
  for(int icol=0; icol<ncol; icol++) {
    if(sigma(icol)<threshold) Sinv(icol)=STD_complex(0.0);
    else Sinv(icol)=STD_complex(secureInv(sigma(icol)));
  }

  // Calculating the inverse
  for(int icol=0; icol<ncol; icol++) {
    V(all,icol)*=Sinv(icol); // apply inverse singular values to V
  }
  result.reference(matrix_product(V,conjugate_transpose_matrix(U)));

  return result;
}
*/

Data<float,1> solve_linear(const Data<float,2>& A, const Data<float,1>& b, float sv_truncation) {
  Log<OdinData> odinlog("","solve_linear(float)");

  Data<float,1> result;

  if(shape_error(A.shape(),b.extent(0))) return result;

#ifdef HAVE_LAPACK

  solve_linear_lapack(result,A, b, sv_truncation);

#else
#ifdef HAVE_LIBGSL

  Range all=Range::all();

  int nrows=A.extent(0);
  int ncols=A.extent(1);
  ODINLOG(odinlog,normalDebug) << "nrows/ncols=" << nrows << "/" << ncols << STD_endl;


  gsl_matrix *A_gsl=gsl_matrix_alloc (nrows, ncols);
  gsl_matrix *V_gsl=gsl_matrix_alloc (ncols, ncols);

  gsl_vector *S_gsl=gsl_vector_alloc (ncols);
  gsl_vector *work =gsl_vector_alloc (ncols);

  gsl_matrix *X_gsl=0; // additional workspace
  if( nrows > (2*ncols) ) X_gsl=gsl_matrix_alloc (ncols, ncols); // Use modified SVD for massively over-determined matrices


  for(int irow=0; irow<nrows; irow++) {
    for(int icol=0; icol<ncols; icol++) {
      gsl_matrix_set (A_gsl, irow, icol, A(irow,icol));
    }
  }

  if(X_gsl) gsl_linalg_SV_decomp_mod(A_gsl, X_gsl, V_gsl, S_gsl, work);
  else      gsl_linalg_SV_decomp(A_gsl,V_gsl,S_gsl,work);
  gsl_vector_free(work);
  if(X_gsl) gsl_matrix_free(X_gsl);


  Data<float,2> U(nrows,ncols);
  for(int irow=0; irow<nrows; irow++) {
    for(int icol=0; icol<ncols; icol++) {
      U(irow,icol)=gsl_matrix_get (A_gsl, irow, icol);
    }
  }
  gsl_matrix_free(A_gsl);


  Data<float,1> S(ncols);
  for(int icol=0; icol<ncols; icol++) S(icol)=gsl_vector_get(S_gsl, icol);
  gsl_vector_free(S_gsl);
  ODINLOG(odinlog,normalDebug) << "S=" << S << STD_endl;

  Data<float,2> V(ncols, ncols);
  for(int jcol=0; jcol<ncols; jcol++) {
    for(int icol=0; icol<ncols; icol++) {
      V(jcol,icol)=gsl_matrix_get (V_gsl, jcol, icol);
    }
  }
  gsl_matrix_free(V_gsl);


  // Regularization
  float threshold=sv_truncation*S(0);
  Data<float,1> Sinv(ncols);
  for(int icol=0; icol<ncols; icol++) {
    if(S(icol)<threshold) Sinv(icol)=0.0;
    else Sinv(icol)=secureInv(S(icol));
  }

  // Calculating the inverse
  for(int icol=0; icol<ncols; icol++) {
    V(all,icol)*=Sinv(icol); // apply inverse singular values to V
  }
  Data<float,2> Ainv(matrix_product(V,U.transpose(1,0)));

  result.reference(matrix_product(Ainv,b));


#else
#error "Neither LAPACK nor GSL not available"
#endif
#endif

  return result;
}



//////////////////////////////////////////////////




// Solves the complex linear system A*x=b and returns x.
//
// An empty array is returned if shape_error() rejects the shapes.
// With LAPACK, the work is delegated to solve_linear_lapack().
// Without LAPACK, the complex system is rewritten as a real system of twice
// the size (NRC, section 2.3),
//
//   ( Re(A)  -Im(A) ) ( Re(x) )   ( Re(b) )
//   ( Im(A)   Re(A) ) ( Im(x) ) = ( Im(b) )
//
// which is then solved by the float version of solve_linear().
ComplexData<1> solve_linear(const ComplexData<2>& A, const ComplexData<1>& b, float sv_truncation) {
  Log<OdinData> odinlog("","solve_linear(complex)");

  ComplexData<1> result;

  if(shape_error(A.shape(),b.extent(0))) return result;

#ifdef HAVE_LAPACK
  solve_linear_lapack(result, A, b, sv_truncation);

#else

  const int nr=A.extent(0);
  const int nc=A.extent(1);

  // Real block matrix, every element is written by the loop below
  Data<float,2> A_real(2*nr,2*nc);
  for(int ir=0; ir<nr; ir++) {
    for(int ic=0; ic<nc; ic++) {
      const STD_complex element=A(ir,ic);
      float real_part=element.real();
      float imag_part=element.imag();

      // diagonal blocks: real part
      A_real(ir,   ic)=real_part;
      A_real(nr+ir,nc+ic)=real_part;

      // off-diagonal blocks: +/- imaginary part
      A_real(nr+ir,ic)=imag_part;
      A_real(ir,   nc+ic)=-imag_part;
    }
  }

  // Stacked right-hand side: real parts first, then imaginary parts
  const int b_size=b.extent(0);
  Data<float,1> b_real(2*b_size);
  for(int ib=0; ib<b_size; ib++) {
    const STD_complex element=b(ib);
    b_real(ib)=element.real();
    b_real(b_size+ib)=element.imag();
  }

  Data<float,1> x_real=solve_linear(A_real,b_real,sv_truncation);

  // Recombine both halves of the real solution into complex numbers
  const int x_size=x_real.extent(0)/2;
  result.resize(x_size);
  for(int ix=0; ix<x_size; ix++) {
    result(ix)=STD_complex(x_real(ix),x_real(x_size+ix));
  }
#endif

  return result;
}


//////////////////////////////////////////////////

Data<float,1> eigenvalues(const Data<float,2>& A) {
  Log<OdinData> odinlog("","eigenvalues");
  Data<float,1> result;

  int N=A.extent(0);
  if(A.extent(1)!=N) {
    ODINLOG(odinlog,errorLog) << "Matrix not quadratic" << STD_endl;
    return result;
  }
  ODINLOG(odinlog,normalDebug) << "N=" << N << STD_endl;

  result.resize(N);
  result=0.0;

#ifdef HAVE_LAPACK

  // Array with Fortran storage order
  Array<float,2> A_fortran(N,N);
  for(int i=0; i<N; i++) {
    for(int j=0; j<N; j++) {
      A_fortran(i,j)=A(j,i);
    }
  }

  char JOBZ='N';
  char UPLO='U';

  Array<float,1> WORK(1);
  int LWORK=-1; // get optimal workspace size

  int INFO;

#ifndef HAVE_THREADSAFE_LAPACK
  MutexLock lock(lapack_mutex); //LAPACK is generally not thread safe, so we will serialize use of LAPACK
#endif

  // call once to get optimal size of the WORK array
  ssyev_(&JOBZ, &UPLO, &N, A_fortran.data(), &N, result.data(), WORK.data(), &LWORK, &INFO);
  LWORK=int(WORK(0));
  if(report_error(INFO,"eigenvalues(worksize)")) return result;

  WORK.resize(LWORK);

  // perform actual diagonalization
  ssyev_(&JOBZ, &UPLO, &N, A_fortran.data(), &N, result.data(), WORK.data(), &LWORK, &INFO);
  ODINLOG(odinlog,normalDebug) << "INFO=" << INFO << STD_endl;
  if(report_error(INFO,"eigenvalues(diagonalization)")) {
    ODINLOG(odinlog,normalDebug) << "A" << A << STD_endl;
    return result;
  }

#else
#ifdef HAVE_LIBGSL

  gsl_matrix* m=gsl_matrix_alloc(N, N);
  for(int j=0; j<N; j++) {
    for(int i=0; i<N; i++) {
      gsl_matrix_set(m, j, i, A(j,i));
    }
  }

  gsl_vector* eval=gsl_vector_alloc(N);

  gsl_eigen_symm_workspace* w=gsl_eigen_symm_alloc(N);

  if(!gsl_eigen_symm(m, eval, w)) {
    gsl_sort_vector(eval);
    for (int i=0; i<N; i++) {
      result(i)=gsl_vector_get(eval, i);
    }
  }

  gsl_vector_free(eval);
  gsl_eigen_symm_free(w);
  gsl_matrix_free(m);

#else
#error "GSL not available"
#endif
#endif
  return result;
}


//////////////////////////////////////////////////
// Unit Test


#ifndef NO_UNIT_TEST
// Unit test (registered under the label "linalg") for solve_linear() and eigenvalues()
class LinAlgTest : public UnitTest {

 public:
  LinAlgTest() : UnitTest("linalg") {}

 private:

  // Runs three checks in sequence and returns false (after logging the
  // involved data) as soon as one of them fails:
  //   1. solving a complex 3x3 system must reproduce the right-hand side
  //   2. an over-determined complex 100x4 system with a perturbed column,
  //      solved with sv_truncation=0.02, must reproduce reference values
  //   3. the eigenvalues of a symmetric 2x2 matrix must be 5 and 15
  bool check() const {
    Log<UnitTest> odinlog(this,"check");

    // Self-consistency check
    ComplexData<2> A(3,3);
    for(int i=0; i<9; i++) A(A.create_index(i))=STD_complex(float(i),sqrt(float(i)));
    ComplexData<1> b(3);
    b(0)=STD_complex(0.1,4.5);
    b(1)=STD_complex(4.1,0.2);
    b(2)=STD_complex(-3.4,-7.5);

    ComplexData<1> x=solve_linear(A,b);

    ComplexData<1> b_test(3);
    b_test=matrix_product(A,x);

    // NOTE(review): the signed differences are summed before taking the magnitude,
    // so deviations of opposite sign can cancel; consider summing magnitudes instead
    if(cabs(sum(b_test-b))>1.0e-3) {
      ODINLOG(odinlog,errorLog) << "A=" << A << STD_endl;
      ODINLOG(odinlog,errorLog) << "x=" << x << STD_endl;
      ODINLOG(odinlog,errorLog) << "b=" << b << STD_endl;
      ODINLOG(odinlog,errorLog) << "b_test=" << b_test << STD_endl;
      ODINLOG(odinlog,errorLog) << "test failed" << STD_endl;
      return false;
    }



    // Test overdetermined system and its errors
    int nrows=100;
    int ncols=4;
    int errcol=1;
    A.resize(nrows,ncols);
    b.resize(nrows);
    ComplexData<1> x_expected(ncols);
    for(int i=0; i<ncols; i++) x_expected(i)=STD_complex(sqrt(float(i)),float(i)*float(i)-10.23);

    for(int irow=0; irow<nrows; irow++) {
      b(irow)=STD_complex(0.0);
      for(int icol=0; icol<ncols; icol++) {

        A(irow,icol)=STD_complex(sqrt(float(irow+2*icol)+4.4),log(float(irow+icol+2)));
        b(irow)+=x_expected(icol)*A(irow,icol); // b is accumulated from the unperturbed matrix element

        if(icol==errcol) A(irow,icol)+=STD_complex(1.0); // Put error in error column
      }
    }

    ComplexData<1> x_solved=solve_linear(A,b,0.02);

    // Actual solution by GSL (replaces the values used above to generate b)
    x_expected(0)=STD_complex(0.941546,-6.1919);
    x_expected(1)=STD_complex(1.4224,-6.87925);
    x_expected(2)=STD_complex(1.03419,-6.38178);
    x_expected(3)=STD_complex(1.07938,-6.4723);

    // NOTE(review): same possible cancellation of signed differences as in the first check
    if(cabs(sum(x_solved-x_expected))>1.0e-3) {
      ODINLOG(odinlog,errorLog) << "A=" << A << STD_endl;
      ODINLOG(odinlog,errorLog) << "b=" << b << STD_endl;
      ODINLOG(odinlog,errorLog) << "x_expected=" << x_expected << STD_endl;
      ODINLOG(odinlog,errorLog) << "x_solved=" << x_solved << STD_endl;
      return false;
    }


    // Testing eigenvalues
    Data<float,2> Asimple(2,2);
    Asimple(0,0)=13;
    Asimple(1,1)=7;
    Asimple(1,0)=Asimple(0,1)=-4; // symmetric

    Data<float,1> eig_calculated(eigenvalues(Asimple));
    Data<float,1> eig_expected(2); eig_expected(0)=5; eig_expected(1)=15;

    // Compare with the same tolerance as the checks above instead of requiring
    // bit-exact floating-point equality; the extent is checked first so that
    // arrays of different size (eigenvalues() failed) are never subtracted
    if(eig_calculated.extent(0)!=eig_expected.extent(0) || sum(fabs(eig_calculated-eig_expected))>1.0e-3) {
      ODINLOG(odinlog,errorLog) << "eig_calculated=" << eig_calculated << STD_endl;
      ODINLOG(odinlog,errorLog) << "eig_expected=" << eig_expected << STD_endl;
      return false;
    }

    return true;
  }

};

// Creates a LinAlgTest instance on the heap. The pointer is not stored here;
// presumably the UnitTest base class keeps track of the instance - confirm
// against tjtest.h before changing the allocation.
void alloc_LinAlgTest() {new LinAlgTest();} // create test instance
#endif

