#!/usr/bin/python3

# Copyright (C) 2016 EDF
# All Rights Reserved
# This code is published under the GNU Lesser General Public License (GNU LGPL)
import numpy as np
import unittest
import random
import math
import StOptGrids

# function used
def funcToInterpolate(x):
    """Return the natural logarithm of one plus the sum of the entries of x.

    x must expose a ``sum()`` method (for instance a numpy array).
    """
    shiftedTotal = 1. + x.sum()
    return math.log(shiftedTotal)

# unit test for sparse grids
############################

class testGrids(unittest.TestCase):
    """Smoke tests for the sparse grids exposed by StOptGrids.

    Each test samples funcToInterpolate on the points of a level 3 sparse
    grid, hierarchizes the nodal values, interpolates at a fixed point, then
    refines and coarsens the grid.  The results are printed, not asserted:
    a test succeeds when every call completes without raising.
    """

    # test sparse grids  with boundaries
    def testSparseGridsBounds(self):
        """Interpolate, refine and coarsen with SparseSpaceGridBound."""
        # low values
        lowValues = np.array([1., 2., 3.])
        # size of the domain
        sizeDomValues = np.array([3., 4., 3.])
        # anisotropic weights
        weights = np.array([1., 1., 1.])
        # level of the sparse grid
        level = 3
        # point where the interpolated value is evaluated all along the test
        ptInterp = np.array([2.3, 3.2, 5.9], dtype=float)
        # create the sparse grid with linear interpolator
        sparseGridLin = StOptGrids.SparseSpaceGridBound(lowValues, sizeDomValues, level, weights, 1)
        iterGrid = sparseGridLin.getGridIterator()
        # array to store
        data = np.empty(sparseGridLin.getNbPoints())
        # iterates on point
        while iterGrid.isValid():
            data[iterGrid.getCount()] = funcToInterpolate(iterGrid.getCoordinate())
            iterGrid.next()
        # Hierarchize the data
        hierarData = sparseGridLin.toHierarchize(data)
        # get back an interpolator
        interpol = sparseGridLin.createInterpolator(ptInterp)
        # calculate interpolated value
        interpValue = interpol.apply(hierarData)
        print(("Interpolated value sparse linear" , interpValue))
        # create the sparse grid with quadratic interpolator
        sparseGridQuad = StOptGrids.SparseSpaceGridBound(lowValues, sizeDomValues, level, weights, 2)
        # Hierarchize the data
        # NOTE(review): the nodal values sampled through the linear grid are
        # reused here; presumably both grids share the same points and
        # ordering -- confirm.
        hierarData = sparseGridQuad.toHierarchize(data)
        # get back an interpolator
        interpol = sparseGridQuad.createInterpolator(ptInterp)
        # calculate interpolated value
        interpValue = interpol.apply(hierarData)
        print(("Interpolated value sparse quadratic " , interpValue))
        # now refine
        precision = 1e-6
        print(("Size of hierarchical array " , len(hierarData)))
        valueAndHierar = sparseGridQuad.refine(precision, funcToInterpolate, data, hierarData)
        print(("Size of hierarchical array after refinement " , len(valueAndHierar[0])))
        # calculate interpolated value (a new interpolator is created after
        # the refine call, as the grid object was just asked to change)
        interpol1 = sparseGridQuad.createInterpolator(ptInterp)
        interpValue = interpol1.apply(valueAndHierar[1])
        print(("Interpolated value sparse quadratic after refinement " , interpValue))
        # coarsen the grid
        precision = 1e-4
        valueAndHierarCoarsen = sparseGridQuad.coarsen(precision, valueAndHierar[0], valueAndHierar[1])
        print(("Size of hierarchical array after coarsening " , len(valueAndHierarCoarsen[0])))
        # calculate interpolated value
        interpol2 = sparseGridQuad.createInterpolator(ptInterp)
        interpValue = interpol2.apply(valueAndHierarCoarsen[1])
        # label fixed: this value is obtained after coarsening, not refinement
        print(("Interpolated value sparse quadratic after coarsening " , interpValue))

    # test sparse grids eliminating boundaries
    def testSparseGridsNoBounds(self):
        """Interpolate, refine and coarsen with SparseSpaceGridNoBound."""
        # low values
        lowValues = np.array([1., 2., 3.], dtype=float)
        # size of the domain
        sizeDomValues = np.array([3., 4., 3.], dtype=float)
        # anisotropic weights
        weights = np.array([1., 1., 1.])
        # level of the sparse grid
        level = 3
        # point where the interpolated value is evaluated all along the test
        ptInterp = np.array([2.3, 3.2, 5.9], dtype=float)
        # create the sparse grid with linear interpolator
        sparseGridLin = StOptGrids.SparseSpaceGridNoBound(lowValues, sizeDomValues, level, weights, 1)
        iterGrid = sparseGridLin.getGridIterator()
        # array to store
        data = np.empty(sparseGridLin.getNbPoints())
        # iterates on point
        while iterGrid.isValid():
            data[iterGrid.getCount()] = funcToInterpolate(iterGrid.getCoordinate())
            iterGrid.next()
        # Hierarchize the data
        hierarData = sparseGridLin.toHierarchize(data)
        # get back an interpolator
        interpol = sparseGridLin.createInterpolator(ptInterp)
        # calculate interpolated value
        interpValue = interpol.apply(hierarData)
        print(("Interpolated value sparse linear" , interpValue))
        # create the sparse grid with quadratic interpolator
        sparseGridQuad = StOptGrids.SparseSpaceGridNoBound(lowValues, sizeDomValues, level, weights, 2)
        # Hierarchize the data
        # NOTE(review): the nodal values sampled through the linear grid are
        # reused here; presumably both grids share the same points and
        # ordering -- confirm.
        hierarData = sparseGridQuad.toHierarchize(data)
        # get back an interpolator
        interpol = sparseGridQuad.createInterpolator(ptInterp)
        # calculate interpolated value
        interpValue = interpol.apply(hierarData)
        print(("Interpolated value sparse quadratic " , interpValue))
        # test grids function: only checks that the accessors can be called,
        # their results are not used
        sparseGridQuad.getDimension()
        sparseGridQuad.getExtremeValues()
        # now refine
        precision = 1e-6
        print(("Size of hierarchical array " , len(hierarData)))
        valueAndHierar = sparseGridQuad.refine(precision, funcToInterpolate, data, hierarData)
        print(("Size of hierarchical array after refinement " , len(valueAndHierar[0])))
        # calculate interpolated value (a new interpolator is created after
        # the refine call, as the grid object was just asked to change)
        interpol1 = sparseGridQuad.createInterpolator(ptInterp)
        interpValue = interpol1.apply(valueAndHierar[1])
        # label fixed: this value is obtained after refinement, not coarsening
        print(("Interpolated value sparse quadratic after refinement " , interpValue))
        # coarsen the grid
        precision = 1e-4
        valueAndHierarCoarsen = sparseGridQuad.coarsen(precision, valueAndHierar[0], valueAndHierar[1])
        print(("Size of hierarchical array after coarsening " , len(valueAndHierarCoarsen[0])))
        # calculate interpolated value
        interpol2 = sparseGridQuad.createInterpolator(ptInterp)
        interpValue = interpol2.apply(valueAndHierarCoarsen[1])
        print(("Interpolated value sparse quadratic after coarsening " , interpValue))

if __name__ == "__main__":
    # Run all the test cases of this module when it is executed as a script.
    unittest.main()
