#!/usr/bin/python3

# Copyright (C) 2016 EDF
# All Rights Reserved
# This code is published under the GNU Lesser General Public License (GNU LGPL)
import numpy as np
import unittest
import random
import math


# unit test for continuation values
##################################

class testContValues(unittest.TestCase):

    # test a regular grid  for stocks and a local function basis for regression
    def testSimpleGridsAndRegressor(self):
        import StOptGrids
        import StOptReg
        # low value for the meshes
        lowValues =np.array([1.,2.,3.],dtype=float)
        # size of the meshes
        step = np.array([0.7,2.3,1.9],dtype=float)
        # number of steps
        nbStep = np.array([3,2,4], dtype=np.int32)
        # create the regular grid
        #########################
        grid = StOptGrids.RegularSpaceGrid(lowValues,step,nbStep)
        # simulation
        nbSimul =10000
        np.random.seed(1000)
        x = np.random.uniform(-1.,1.,size=(1,nbSimul));
        # mesh
        nbMesh = np.array([16],dtype=np.int32)
        # Create the regressor
        #####################
        regressor = StOptReg.LocalLinearRegression(False,x,nbMesh)
        # regressed values
        toReal = (2+x[0,:]+(1+x[0,:])*(1+x[0,:]))
        # function to regress
        toRegress = toReal + 4*np.random.normal(0.,1,nbSimul)
        # create a matrix (number of stock points by number of simulations)
        toRegressMult = np.zeros(shape=(len(toRegress),grid.getNbPoints()))
        for i in range(toRegressMult.shape[1]):
           toRegressMult[:,i] = toRegress
        # Now create the continuation object
        ####################################
        contOb = StOptReg.ContinuationValue(grid,regressor,toRegressMult)
        # get back the regressed values at the point stock
        ptStock=  np.array([1.2,3.1,5.9],dtype=float)
        regressValues = contOb.getAllSimulations(ptStock)
        # do the same with an interpolator
        interp = grid.createInterpolator(ptStock)
        regressValuesInterp = contOb.getAllSimulations(interp)
        # test create of an interpoaltion object mixing grids for stocks and regression for uncertainties
        #################################################################################################
        gridAndRegressed = StOptReg.GridAndRegressedValue(grid,regressor,toRegressMult)
        # get back the regressed value for a point stock and an uncertainty
        valRegressed = gridAndRegressed.getValue(ptStock, x[:,0])
        # Now test simulations one by one
        #################################
        for i in range(nbSimul):
            regressed = gridAndRegressed.getValue(ptStock,x[:,i] )
            diff = regressed -regressValues[i]
            self.assertAlmostEqual(diff,0., 7, "test regression simulation by simulation")

    # test some mapping of GneralSpaceGrid
    def testGeneralGridInheritance(self):
        from StOptGrids import GeneralSpaceGrid, RegularSpaceGrid
        from StOptReg import LocalLinearRegression, ContinuationValue

        x = np.random.randn(5)
        regressor = LocalLinearRegression([1])
        
        regular = RegularSpaceGrid(np.array([0.]), np.array([0.5]), np.array([3]))
        ContinuationValue(regular, regressor, x) 
        
        general = GeneralSpaceGrid([[0., 1., 1.2, 1.5]])
        ContinuationValue(general, regressor, x) 
   
 
if __name__ == "__main__":
    # When executed as a script, run every test case defined in this module.
    unittest.main()
