#!/usr/bin/python3

# Copyright (C) 2016 EDF
# All Rights Reserved
# This code is published under the GNU Lesser General Public License (GNU LGPL)
import StOptSDDP
import StOptSDDPUnitTest
import numpy as np
import unittest
import StOptGeners

# Unit tests for the SDDP framework
############################


class FakeOptimizer(unittest.TestCase):
    # Checks that the combined backward/forward SDDP test entry point can be
    # called from Python with a test optimizer.  Nothing is asserted on the
    # returned value: the test passes as long as the call does not raise.

    def test(self):

        # Test optimizer built on top of the test simulator
        simulator = StOptSDDPUnitTest.SimulTest()
        optimizer = StOptSDDPUnitTest.OptimizeTest(simulator)

        # Default-constructed final cuts
        final_cut = StOptSDDP.SDDPFinalCut()

        # Numerical inputs of the SDDP problem
        nb_simul_check = 1
        initial_state = np.array([1.0, 2.0])
        dates = np.array([1.0, 2.0])
        nb_mesh = np.array([1, 1], dtype=np.int32)

        # Names handed to the entry point (single-space placeholders)
        name_regressor = " "
        name_cut = " "
        name_visited_states = " "

        # Iteration and convergence settings
        nb_iter = 1
        accuracy = 0.01
        n_step_conv = 10

        # Backward and forward mapping; arguments are passed positionally
        # in the order expected by the binding
        result = StOptSDDPUnitTest.backwardForwardSDDPTestMapping(
            optimizer,
            nb_simul_check,
            initial_state,
            final_cut,
            dates,
            nb_mesh,
            name_regressor,
            name_cut,
            name_visited_states,
            nb_iter,
            accuracy,
            n_step_conv,
        )
        


class LocalLinearRegressionForSDDP(unittest.TestCase):
    # Checks the Python binding of the SDDP local linear regressor, the
    # visited states container, their dump into a binary archive, and the
    # separate backward / forward SDDP test entry points.  Nothing is
    # asserted on returned values: the test passes if no call raises.

    def test(self):

        # Local linear regressor for SDDP, built from a 2x2 float64 particle
        # array and an int32 mesh array
        particle_set = np.array([[1.0, 2.0], [3.0, 3.0]], dtype=np.float64)
        mesh_sizes = np.array([1, 2], dtype=np.int32)
        sddp_regressor = StOptSDDP.LocalLinearRegressionForSDDP(
            1, particle_set, mesh_sizes
        )

        # Archive opened in write mode; the regressor is dumped first
        out_archive = StOptGeners.BinaryFileArchive("MyArchive", "w")
        out_archive.dump(sddp_regressor)

        # Visited states container, filled with one (state, particle) pair
        visited = StOptSDDP.SDDPVisitedStates(1)
        one_state = np.array([3.0, 4.0], dtype=np.float64)
        one_particle = np.array([1.0, 2.0], dtype=np.float64)
        visited.addVisitedState(one_state, one_particle, sddp_regressor)

        # The visited states go into the same archive
        out_archive.dump(visited)

        # Test optimizer and the backward simulator it exposes
        simulator = StOptSDDPUnitTest.SimulTest()
        optimizer = StOptSDDPUnitTest.OptimizeTest(simulator)
        backward_simulator = optimizer.getSimulatorBackward()

        # Inputs shared by the backward and forward mapping calls
        dates = np.array([1.0], dtype=np.float64)
        initial_state = np.array([1.0], dtype=np.float64)
        final_cut = StOptSDDP.SDDPFinalCut()
        regressor_archive = StOptGeners.BinaryFileArchive("MyArchiveReg", "w")
        cut_archive = StOptGeners.BinaryFileArchive("MyArchiveCut", "r")
        string_arg = "MyString"

        # Backward SDDP mapping
        backward_result = StOptSDDPUnitTest.backwardSDDPTestMapping(
            optimizer,
            backward_simulator,
            dates,
            initial_state,
            final_cut,
            regressor_archive,
            string_arg,
            cut_archive,
        )

        # Forward SDDP mapping
        forward_result = StOptSDDPUnitTest.forwardSDDPTestMapping(
            optimizer,
            backward_simulator,
            dates,
            initial_state,
            final_cut,
            1,
            regressor_archive,
            cut_archive,
            string_arg,
        )
        
        
# Run the test cases of this module when the file is executed as a script
if __name__ == "__main__":
    unittest.main()
