/***************************************************************************
                          fitting.h  -  description
                             -------------------
    begin                : Fri Apr 6 2001
    copyright            : (C) 2000-2021 by Thies Jochimsen & Michael von Mengershausen
    email                : thies@jochimsen.de  mengers@cns.mpg.de
 ***************************************************************************/

/***************************************************************************
 *                                                                         *
 *   This program is free software; you can redistribute it and/or modify  *
 *   it under the terms of the GNU General Public License as published by  *
 *   the Free Software Foundation; either version 2 of the License, or     *
 *   (at your option) any later version.                                   *
 *                                                                         *
 ***************************************************************************/

#ifndef FITTING_H
#define FITTING_H

#include <tjutils/tjnumeric.h> // for MinimizationFunction
#include <odindata/data.h>
#include <odindata/linalg.h>
#include <odindata/utils.h>


// Default upper limit for the number of iterations; used as the default
// value of the 'max_iterations' arguments declared in this header.
#define DEFAULT_MAX_ITER 1000
// Default convergence tolerance; used as the default value of the
// 'tolerance' arguments declared in this header.
#define DEFAULT_TOLERANCE 1e-4

/**
  * @addtogroup odindata
  * @{
  */

/**
  * Structure representing a single fitting parameter, i.e. its
  * value together with the error estimate of that value.
  */
struct fitpar {

/**
  * Creates a parameter with both value and error set to zero.
  */
  fitpar()
    : val(0.0),
      err(0.0)
  {}

/**
  * Value of the fitting parameter which is varied during the fit.
  */
  float val;

/**
  * The error interval of the final result.
  */
  float err;

/**
  * Prints 'fp' to the stream 's' in the form "val +/- err"
  * and returns 's' to allow chaining.
  */
  friend STD_ostream& operator << (STD_ostream& s, const fitpar& fp) {
    s << fp.val;
    s << " +/- ";
    s << fp.err;
    return s;
  }

};


////////////////////////////////////////////////////////////////////////

/**
  * Base class of all multi-dimensional function classes which
  * are used for fitting.
  * The function has an independent variable 'x' (the argument to evaluate_f),
  * a dependent variable 'y' (the result of evaluate_f) and a number of
  * function parameters.
  * To use this class, derive from it and implement the pure virtual
  * functions 'evaluate_f' (function value), 'evaluate_df' (first derivative),
  * 'numof_fitpars', and 'get_fitpar'.
  * Parameters which are modified during the fit should be
  * members of type fitpar.
  * Constructor and destructor are protected, so objects can only be
  * created and destroyed via a derived class.
  */
class ModelFunction {

 public:

/**
  * Returns the function value at position 'x'.
  */
  virtual float evaluate_f(float x) const = 0;

/**
  * Returns the first derivatives at position 'x'.
  * NOTE(review): presumably one derivative per fitting parameter, in the
  * same order as used by get_fitpar() - confirm against the implementations.
  */
  virtual fvector evaluate_df(float x) const = 0;

/**
  * Returns the number of independent fitting parameters.
  */
  virtual unsigned int numof_fitpars() const = 0;

/**
  * Returns reference to the i'th fitting parameter.
  */
  virtual fitpar& get_fitpar(unsigned int i) = 0;

/**
  * Returns the function values for x-values 'xvals'.
  */
  Array<float,1> get_function(const Array<float,1>& xvals) const;
  

  // dummy array used for default arguments
  static const Array<float,1> defaultArray;

 protected:
   ModelFunction() {}
   virtual ~ModelFunction() {}

   // Spare parameter available to derived classes.
   // NOTE(review): presumably returned by get_fitpar() implementations
   // for an out-of-range index - confirm in the implementation file.
   fitpar dummy_fitpar;
   
};

////////////////////////////////////////////////////////////////////////

/**
  * Interface class for all function fits.
  * A fit is done in two steps: init() attaches the model function
  * and the number of data points, fit() then performs the actual fit.
  */
class FunctionFitInterface {

 public:

  virtual ~FunctionFitInterface() {}

/**
  * Prepare a non-linear least-square fit of function 'model_func' for 'nvals' values
  */
  virtual bool init(ModelFunction& model_func, unsigned int nvals) = 0;

/**
  * The fitting routine that takes the starting values from the model function,
  * y-values 'yvals', and optionally the corresponding y-error bars 'ysigma'
  * and x-vals 'xvals'. If no error-bars are given, they are all set to 0.1 and if no
  * x-vals are given equidistant points with an increment of one are chosen,
  * i.e. xvals(i)=i;
  * A maximum of 'max_iterations' iterations and the given 'tolerance' is used during the fit.
  * Returns true on success.
  *
  * Note: default arguments of virtual functions are selected by the static
  * type through which fit() is called, so implementations should repeat
  * exactly the same default values as declared here.
  */
  virtual bool fit(const Array<float,1>& yvals,
           const Array<float,1>& ysigma=defaultArray,
           const Array<float,1>& xvals=defaultArray,
           unsigned int max_iterations=DEFAULT_MAX_ITER, double tolerance=DEFAULT_TOLERANCE) = 0;

  // dummy array used as default argument for 'ysigma' and 'xvals' of fit(),
  // i.e. it stands for "argument not given"
  static const Array<float,1> defaultArray;
};

////////////////////////////////////////////////////////////////////////

class ModelData; // forward declaration, defined in the implementation
class GslData4Fit; // forward declaration, defined in the implementation

/**
  * Class which is used for derivative-based fitting of functions.
  */
class FunctionFitDerivative : public virtual FunctionFitInterface {

 public:

/**
  * Constructs uninitialized function fit, both internal
  * pointers are set to zero.
  */
   FunctionFitDerivative() : gsldata(0), data4fit(0) {}

/**
  * Destructor
  */
  ~FunctionFitDerivative();

  // implementing virtual functions of FunctionFitInterface,
  // the default arguments repeat those of the interface
  bool init(ModelFunction& model_func, unsigned int nvals);
  bool fit(const Array<float,1>& yvals,
           const Array<float,1>& ysigma=defaultArray,
           const Array<float,1>& xvals=defaultArray,
           unsigned int max_iterations=DEFAULT_MAX_ITER, double tolerance=DEFAULT_TOLERANCE);


 private:

  // NOTE(review): presumably reports the solver state of iteration 'iter'
  // for debugging - confirm in the implementation
  void print_state (size_t iter);

  // Opaque helper objects, zero until initialized.
  // NOTE(review): presumably allocated in init() and released in the
  // destructor; if so, copying objects of this class would share the
  // pointers - confirm in the implementation before copying.
  GslData4Fit* gsldata;
  ModelData* data4fit;
};



///////////////////////////////////////////////////////////////////////


/**
  * A function to fit an exponential curve to a 1D data set.
  * It uses the function
  *
  * A * exp(lambda * x)
  */
struct ExponentialFunction : public ModelFunction {

  // fitting parameters, named as in the formula above
  fitpar A;
  fitpar lambda;

  // implementing virtual functions of ModelFunction
  float evaluate_f(float x) const;
  fvector evaluate_df(float x) const;
  unsigned int numof_fitpars() const;
  fitpar& get_fitpar(unsigned int i);
};


///////////////////////////////////////////////////////////////////////

/**
  * A function to fit an exponential curve with a constant offset
  * to a 1D data set.
  * It uses the function
  *
  * A * exp(lambda * x) + C
  */
struct ExponentialFunctionWithOffset : public ModelFunction {

  // fitting parameters, named as in the formula above
  fitpar A;
  fitpar lambda;
  fitpar C;

  // implementing virtual functions of ModelFunction
  float evaluate_f(float x) const;
  fvector evaluate_df(float x) const;
  unsigned int numof_fitpars() const;
  fitpar& get_fitpar(unsigned int i);
};

///////////////////////////////////////////////////////////////////////


/**
  * A function to fit a Gaussian curve to a 1D data set.
  * It uses the function
  *
  * A * exp( - 2 * ( (x-x0) / fwhm )^2 )
  *
  * NOTE(review): with the formula as written, the function drops to
  * A*exp(-0.5) (about 0.61*A) at |x-x0| = fwhm/2, not to A/2, so 'fwhm'
  * is not the full width at half maximum in the strict sense - confirm
  * against the implementation which definition is intended.
  */
struct GaussianFunction : public ModelFunction {

  // fitting parameters, named as in the formula above
  fitpar A;
  fitpar x0;
  fitpar fwhm;

  // implementing virtual functions of ModelFunction
  float evaluate_f(float x) const;
  fvector evaluate_df(float x) const;
  unsigned int numof_fitpars() const;
  fitpar& get_fitpar(unsigned int i);
};


///////////////////////////////////////////////////////////////////////


/**
  *
  * Class for fitting a sine function to a 1D curve
  *
  * y= A*sin(m*x + c)
  */
struct SinusFunction : public ModelFunction {

  // fitting parameters, named as in the formula above
  fitpar A;
  fitpar m;
  fitpar c;

  // implementing virtual functions of ModelFunction
  float evaluate_f(float x) const;
  fvector evaluate_df(float x) const;
  unsigned int numof_fitpars() const;
  fitpar& get_fitpar(unsigned int i);
};

///////////////////////////////////////////////////////////////////////


/**
  *
  * Class for fitting a gamma variate function to a 1D curve
  *
  * y= A*x^alpha*exp(-x/beta)
  */
struct GammaVariateFunction : public ModelFunction {

/**
  *
  * Set parameters from a simplified set of parameters: 'alphaval' is the
  * exponent alpha, 'xmax' and 'ymax' are the x- and y-values of the
  * maximum (see Madsen, Phys. Med. Biol. 37, 1992)
  */
  void set_pars(float alphaval, float xmax, float ymax);

  // fitting parameters, named as in the formula above
  fitpar A;
  fitpar alpha;
  fitpar beta;

  // implementing virtual functions of ModelFunction
  float evaluate_f(float x) const;
  fvector evaluate_df(float x) const;
  unsigned int numof_fitpars() const;
  fitpar& get_fitpar(unsigned int i);
};

///////////////////////////////////////////////////////////////////////

/**
  *
  * Class for polynomial fitting of function
  *
  * y= Sum_i a[i] x^i, with i in [0,N_rank]
  *
  * N_rank is the degree of the polynome to be fitted
  */
template <int N_rank>
struct PolynomialFunction {

  fitpar a[N_rank+1];

/**
  *
  * polynomial fitting routine.
  * Fits the function to the  y-values 'yvals', and optionally
  * the corresponding error bars 'ysigma' and x-values 'xvals'.
  * If no error-bars are given they are all set to 1.0 and if no
  * x-vals are given equidistant points with an increment of one
  * are chosen, i.e. xvals(i)=i;
  * Returns true on success.
  */
  bool fit(const Array<float,1>& yvals,
           const Array<float,1>& ysigma,
           const Array<float,1>& xvals);

/*
  bool fit(const Array<float,1>& yvals,
           const Array<float,1>& ysigma){
  	firstIndex fi;
	Array<float,1> xvals(yvals.size());
	xvals=fi;
	return fit(yvals,ysigma,xvals);
  };
*/

  bool fit(const Array<float,1>& yvals){
  	Array<float,1> ysigma(yvals.size());
	ysigma=1.;
	return fit(yvals,ysigma);
  };
  
/**
  * Returns the polynomial function values for x-values 'xvals'
  * using the current polynomial coefficients.
  */
  Array<float,1> get_function(const Array<float,1>& xvals) const;

};


/**
  * Weighted linear least-squares fit of the polynomial coefficients:
  * each data point contributes one row to a linear system, scaled by
  * the inverse of its error bar, which is then handed to solve_linear().
  * Only the 'val' members of the coefficients are set, the 'err'
  * members are reset to zero. Always returns true.
  */
template <int N_rank>
bool PolynomialFunction<N_rank>::fit(const Array<float,1>& yvals, const Array<float,1>& ysigma, const Array<float,1>& xvals) {

  const int ncoeff=N_rank+1;

  // start from a clean set of coefficients
  for(int icoeff=0; icoeff<ncoeff; icoeff++) {
    a[icoeff]=fitpar();
  }

  const int nsamples=yvals.size();

  // error bars: use the given ones only if their number matches, otherwise unity
  Array<float,1> errbars(nsamples);
  if(int(ysigma.size())!=nsamples) {
    errbars=1.0;
  } else {
    errbars=ysigma;
  }

  // x-values: use the given ones only if their number matches, otherwise the index
  Array<float,1> abscissa(nsamples);
  if(int(xvals.size())!=nsamples) {
    for(int isample=0; isample<nsamples; isample++) abscissa(isample)=isample;
  } else {
    abscissa=xvals;
  }

  // one row per data point, one column per power of x
  Array<float,2> design(nsamples,ncoeff);
  Array<float,1> rhs(nsamples);

  for(int isample=0; isample<nsamples; isample++) {
    float invsigma=secureInv( errbars(isample));

    rhs(isample)=invsigma*yvals(isample);

    for(int icoeff=0; icoeff<ncoeff; icoeff++) {
      design(isample,icoeff)=invsigma*pow(abscissa(isample),icoeff);
    }
  }

  Array<float,1> solution(solve_linear(design,rhs));

  for(int icoeff=0; icoeff<ncoeff; icoeff++) {
    a[icoeff].val=solution(icoeff);
  }

  return true;
}


/**
  * Evaluates the polynomial Sum_i a[i].val * x^i for every entry
  * of 'xvals' and returns the values in an array of the same size.
  */
template <int N_rank>
Array<float,1> PolynomialFunction<N_rank>::get_function(const Array<float,1>& xvals) const {
  const int nsamples=xvals.size();
  const int ncoeff=N_rank+1;

  Array<float,1> yvals(nsamples);
  yvals=0.0;

  for(int isample=0; isample<nsamples; isample++) {
    // add up the terms in ascending order of the exponent
    for(int icoeff=0; icoeff<ncoeff; icoeff++) {
      yvals(isample)+=a[icoeff].val*pow(xvals(isample),icoeff);
    }
  }

  return yvals;
}


///////////////////////////////////////////////////////////////////////


/**
  *
  * Class for linear regression of the function
  *
  * y= m*x + c
  *
  * For details see Numerical Recipes in C (2nd edition), section 15.2.
  */
struct LinearFunction {

  // slope 'm' and intercept 'c' of the formula above
  fitpar m;
  fitpar c;

/**
  *
  * Linear fitting routine.
  * Fits the function to the  y-values 'yvals', and optionally
  * the corresponding error bars 'ysigma' and x-values 'xvals'.
  * If no error-bars are given they are all set to 1.0 and if no
  * x-vals are given equidistant points with an increment of one
  * are chosen, i.e. xvals(i)=i;
  * Returns true on success.
  */
  bool fit(const Array<float,1>& yvals,
           const Array<float,1>& ysigma=defaultArray,
           const Array<float,1>& xvals=defaultArray);

/**
  * Returns the linear function values for x-values 'xvals'
  * using the current fit parameters.
  */
  Array<float,1> get_function(const Array<float,1>& xvals) const;


  // dummy array used as default argument for 'ysigma' and 'xvals' of fit(),
  // i.e. it stands for "argument not given"
  static const Array<float,1> defaultArray;
};



///////////////////////////////////////////////////////////////////////

class GslData4DownhillSimplex; // forward declaration, defined in the implementation

/**
  * Downhill simplex optimizer
  */
class DownhillSimplex {

 public:

/**
  * Construct downhill simplex optimizer
  * - function: Function to evaluate/minimize
  */
  DownhillSimplex(MinimizationFunction& function);

/**
  * Destructor
  */
  ~DownhillSimplex();

/**
  * Returns parameter values in 'result' which minimize attached 'function'
  * - starting_point: Starting from this initial point
  * - step_size: The size of the initial trial steps
  * - max_iterations: Max number of iterations
  * - tolerance: Tolerance
  * Returns true on success.
  */
  bool get_minimum_parameters(fvector& result, const fvector& starting_point, const fvector& step_size, unsigned int max_iterations=DEFAULT_MAX_ITER, double tolerance=DEFAULT_TOLERANCE);


 private:
  // NOTE(review): presumably the number of parameters of the attached
  // function - confirm in the implementation
  unsigned int ndim;

  // Opaque helper object.
  // NOTE(review): presumably allocated in the constructor and released in
  // the destructor; if so, copying objects of this class would share the
  // pointer - confirm in the implementation before copying.
  GslData4DownhillSimplex* gsldata;

};


///////////////////////////////////////////////////////////////////////

/**
  * Class for downhill-simplex-based fitting of functions.
  * It is a function fit (FunctionFitInterface) and at the same time
  * the MinimizationFunction which can be handed to a DownhillSimplex
  * optimizer.
  */
class FunctionFitDownhillSimplex : public virtual FunctionFitInterface, public MinimizationFunction {

 public:

/**
  * Constructs uninitialized function fit
  */
   FunctionFitDownhillSimplex();

/**
  * Destructor
  */
  ~FunctionFitDownhillSimplex();

  // implementing virtual functions of FunctionFitInterface,
  // the default arguments repeat those of the interface
  bool init(ModelFunction& model_func, unsigned int nvals);
  bool fit(const Array<float,1>& yvals,
           const Array<float,1>& ysigma=defaultArray,
           const Array<float,1>& xvals=defaultArray,
           unsigned int max_iterations=DEFAULT_MAX_ITER, double tolerance=DEFAULT_TOLERANCE);


  // implementing virtual functions of MinimizationFunction
  unsigned int numof_fitpars() const;
  float evaluate(const fvector& pars) const;


 private:
  // NOTE(review): presumably the model function attached by init() and
  // the optimizer used by fit() - confirm in the implementation
  ModelFunction* func;
  DownhillSimplex* ds;

  // NOTE(review): presumably copies of the data passed to fit(), kept so
  // that evaluate() can access them - confirm in the implementation
  Array<float,1> yvals_cache;
  Array<float,1> ysigma_cache;
  Array<float,1> xvals_cache;
};

///////////////////////////////////////////////////////////////////////

/**
  * Fits an N_rank-dimensional polynomial of order 'polynom_order' to each point of the
  * array using the values of its neighbours regarding their reliability
  * (i.e. their relative weight for the fit). Parameters are:
  * - value_map: The array to be fitted
  * - reliability_map: The reliability of each point, must have the same
  *   shape as value_map and must be non-negative
  * - polynom_order: Order of the polynomial
  * - kernel_size: Size of the neighbourhood of the pixel which is
  *   considered for the fit (using a Gaussian kernel with this FWHM)
  * - only_zero_reliability: Fit only pixel with zero reliability, all other
  *   pixel are copied unchanged from value_map
  *
  * This function returns the fitted array. Pixel for which not enough
  * neighbours are available for the fit are set to zero. On error (shape
  * mismatch or negative reliability), an array of zeroes is returned.
  */
template <int N_rank>
Array<float,N_rank> polyniomial_fit(const Array<float,N_rank>& value_map, const Array<float,N_rank>& reliability_map,
                                    unsigned int polynom_order, float kernel_size, bool only_zero_reliability=false) {
  Log<OdinData> odinlog("","polyniomial_fit");

  // pre-set with zeroes, this is also what is returned on the error paths below
  Data<float,N_rank> result(value_map.shape());
  result=0.0;

  if(!same_shape(value_map,reliability_map)) {
    ODINLOG(odinlog,errorLog) << "size mismatch (value_map.shape()=" << value_map.shape() << ") != (reliability_map.shape()=" << reliability_map.shape() << ")" << STD_endl;
    return result;
  }

  if(min(reliability_map)<0.0) {
    ODINLOG(odinlog,errorLog) << "reliability_map must be non-negative" << STD_endl;
    return result;
  }

  // smallest extent among the dimensions with more than one element,
  // starting from the largest extent
  int minsize=max(value_map.shape());
  for(int idim=0; idim<N_rank; idim++) {
    int dimsize=value_map.shape()(idim);
    if( (dimsize>1) && (dimsize<minsize) ) minsize=dimsize;
  }
  if(minsize<=0) {
    return result;
  }

  // the polynomial order is limited by the smallest extent of the array
  if((minsize-1)<int(polynom_order)) {
    polynom_order=minsize-1;
    ODINLOG(odinlog,warningLog) << "array size too small, restricting polynom_order to " << polynom_order << STD_endl;
  }

  TinyVector<int,N_rank> valshape(value_map.shape());
  int nvals=value_map.numElements();

  // grid with (polynom_order+1) points in each dimension, its indices are
  // used below as the exponents of the polynomial terms
  TinyVector<int,N_rank> polsize; polsize=polynom_order+1;
  Data<int,N_rank> polarr(polsize);
  int npol=polarr.numElements();

  // enlarge the kernel if kernel_size^N_rank is smaller than the number of polynomial terms
  if(pow(kernel_size,float(N_rank))<float(npol)) {
    kernel_size=pow(double(npol),double(1.0/float(N_rank)));
    ODINLOG(odinlog,warningLog) << "kernel_size too small for polynome, increasing to " << kernel_size << STD_endl;
  }



  // neighbourhood: neighb_pixel (at least one) pixel to either side in each dimension
  int neighb_pixel=int(kernel_size);
  if(neighb_pixel<=0) neighb_pixel=1;
  TinyVector<int,N_rank> neighbsize; neighbsize=2*neighb_pixel+1;
  TinyVector<int,N_rank> neighboffset; neighboffset=-neighb_pixel;
  Data<int,N_rank> neighbarr(neighbsize); // neighbour grid around root pixel
  int nneighb=neighbarr.numElements();

  ODINLOG(odinlog,normalDebug) << "nvals/npol/nneighb=" << nvals << "/" << npol << "/" << nneighb << STD_endl;

  if(npol>nneighb) {
    ODINLOG(odinlog,warningLog) << "polynome order (" << npol << ") larger than number of neighbours (" << nneighb << ")" << STD_endl;
  }


  // linear system A*c=b which is set up and solved for each pixel,
  // c holds the resulting polynomial coefficients
  Array<float,2> A(npol,npol);
  Array<float,1> c(npol);
  Array<float,1> b(npol);

  TinyVector<int,N_rank> valindex;
  TinyVector<int,N_rank> neighbindex;
  TinyVector<int,N_rank> currindex;
  TinyVector<int,N_rank> diffindex;
  TinyVector<int,N_rank> polindex;
  TinyVector<int,N_rank> polindex_sum;

  // neighbours within this radius are counted to decide whether
  // enough pixel are available for the fit
  float epsilon=0.01;
  float relevant_radius=0.5*kernel_size*sqrt(double(N_rank))+epsilon;

  // iterate through pixels of value_map
  for(int ival=0; ival<nvals; ival++) {
    valindex=result.create_index(ival);

    if( (!only_zero_reliability) || (reliability_map(valindex)<=0.0) ) { // fit all pixel, or only those with zero reliability

      A=0.0;
      b=0.0;

      int n_relevant_neighb_pixel=0;

      // iterate through neighbourhood of pixel and accumulate them in a single
      // set of equations, weighted by their reliability
      for(int ineighb=0; ineighb<nneighb; ineighb++) {
        neighbindex=neighbarr.create_index(ineighb);
        currindex=valindex+neighboffset+neighbindex;

        bool valid_pixel=true;

        // is the pixel within value_map ?
        for(int irank=0; irank<N_rank; irank++) {
          if(currindex(irank)<0 || currindex(irank)>=valshape(irank)) valid_pixel=false;
        }

        // does the pixel have non-vanishing reliability
        float reliability=0.0;
        if(valid_pixel) reliability=reliability_map(currindex);
        if(reliability<=0.0) valid_pixel=false;


        if(valid_pixel) {

          diffindex=currindex-valindex; // (xk-x0,yk-y0,...)

          // weight is the reliability times a Gaussian of the distance to the root pixel
          float radiussqr=sum(diffindex*diffindex);
          float weight=reliability*exp(-2.0*radiussqr/(kernel_size*kernel_size));

          if(weight>0.0) {
            if(sqrt(radiussqr)<=relevant_radius) n_relevant_neighb_pixel++;

            // create b_i,j: weighted value times the polynomial term with exponents (i,j,..)
            for(int ipol=0; ipol<npol; ipol++) {
              polindex=polarr.create_index(ipol); // (i,j,..)
              float polproduct=1.0;
              for(int irank=0; irank<N_rank; irank++) polproduct*=pow(float(diffindex(irank)),float(polindex(irank)));
              b(ipol)+=weight*value_map(currindex)*polproduct;
            }

            // create A_ii',jj': weighted polynomial term with exponents (i+i',j+j',..)
            for(int ipol=0; ipol<npol; ipol++) {
              for(int ipol_prime=0; ipol_prime<npol; ipol_prime++) {
                polindex_sum=polarr.create_index(ipol)+polarr.create_index(ipol_prime); 
                float polproduct=1.0;
                for(int irank=0; irank<N_rank; irank++) polproduct*=pow(float(diffindex(irank)),float(polindex_sum(irank)));
                A(ipol,ipol_prime)+=weight*polproduct;
              }
            }

          }


        }
      }

      if(n_relevant_neighb_pixel>=npol) { // do we have enough pixel for the fit ?
        c=solve_linear(A,b);
        // NOTE(review): assumes create_index(0) is the all-zero index so that
        // c(0) is the constant term, i.e. the fitted value at the root pixel - confirm
        result(valindex)=c(0);
      }
      // otherwise the pixel keeps its initial value of zero

    } else result(valindex)=value_map(valindex); // pixel is not fitted, take over original value

  }

  return result;
}


/** @}
  */


#endif

