#!/usr/bin/python3

# Copyright (C) 2016 EDF
# All Rights Reserved
# This code is published under the GNU Lesser General Public License (GNU LGPL)
import numpy as np
import math
import StOptReg as reg
import StOptGrids
import StOptGlobal
import Simulators as sim
import Optimizers as opt
import Utils
import dp.DynamicProgrammingByRegressionDist as dynmpi
import dp.SimulateRegressionControlDist as srtmpi
import unittest
import importlib

# Absolute tolerance (the ``delta`` argument of ``assertAlmostEqual``) allowed
# between the optimized value and the value obtained by forward simulation.
accuracyClose = 1.0e5

    
class testGasStorageSwitchCostMpiHighLevelTest(unittest.TestCase):
    """High-level MPI tests of the gas storage problem with switching costs.

    Each test solves the storage problem backward with
    ``dynmpi.DynamicProgrammingByRegressionDist``, then re-simulates it forward
    with ``srtmpi.SimulateRegressionControlDist`` on a fresh set of paths, and
    checks on rank 0 that both values agree within ``accuracyClose``.
    Tests are reported as skipped (rather than silently passing) when mpi4py
    is not installed.
    """

    def _mpiWorld(self):
        """Return ``MPI.COMM_WORLD``; skip the current test if mpi4py is absent."""
        # ``import importlib`` alone does not guarantee that the
        # ``importlib.util`` submodule is loaded, so import it explicitly.
        import importlib.util
        if importlib.util.find_spec('mpi4py') is None:
            self.skipTest("mpi4py is not installed")
        from mpi4py import MPI
        return MPI.COMM_WORLD

    def _optimizeThenSimulate(self, world, fileToDump, regime=None):
        """Optimize the storage backward, simulate it forward, compare on rank 0.

        :param world: MPI communicator; all ranks synchronize on it between
            the optimization and the simulation.
        :param fileToDump: dump name shared by the optimization and the
            simulation; kept distinct per test so cases do not share a dump.
        :param regime: optional ``Utils.RegimeCurve`` passed as the extra last
            argument of the optimizer; when None the optimizer is built
            without it.
        :returns: ``(valueOptimMpi, valSimuMpi)``, the optimized value and the
            value obtained by forward simulation.
        """
        # storage
        ###############
        maxLevelStorage = 360000.
        injectionRateStorage = 60000.
        withdrawalRateStorage = 45000.
        injectionCostStorage = 0.35
        withdrawalCostStorage = 0.35
        switchingCostStorage = 4.

        maturity = 1.
        nstep = 10

        # define a time grid
        timeGrid = StOptGrids.OneDimRegularSpaceGrid(0., maturity / nstep, nstep)
        # periodicity factor
        iPeriod = 52
        # nstep + 1 future values
        futValues = [50. + 20 * math.sin((math.pi * i * iPeriod) / nstep)
                     for i in range(nstep + 1)]
        # define the future curve
        futureGrid = Utils.FutureCurve(timeGrid, futValues)
        # one dimensional factors
        nDim = 1
        sigma = np.zeros(nDim) + 0.94
        mr = np.zeros(nDim) + 0.29
        # number of simulations for the optimization and for the simulation
        nbsimulOpt = 20000
        nbsimulSim = 40000
        # grid
        #####
        nGrid = 40
        lowValues = np.zeros(1, dtype=float)
        step = np.zeros(1, dtype=float) + maxLevelStorage / nGrid
        nbStep = np.zeros(1, dtype=np.int32) + nGrid
        grid = StOptGrids.RegularSpaceGrid(lowValues, step, nbStep)

        # no actualization
        rate = 0.
        # a backward simulator
        ######################
        bForward = False
        backSimulator = sim.MeanRevertingSimulator(futureGrid, sigma, mr, rate, maturity, nstep, nbsimulOpt, bForward)
        # optimizer: the regime curve, when given, is an extra trailing argument
        ############
        optimizerArgs = [injectionRateStorage, withdrawalRateStorage, injectionCostStorage,
                         withdrawalCostStorage, switchingCostStorage]
        if regime is not None:
            optimizerArgs.append(regime)
        storage = opt.OptimizeGasStorageSwitchingCostMeanReverting(*optimizerArgs)
        # regressor
        ##########
        nMesh = 4
        nbMesh = np.zeros(1, dtype=np.int32) + nMesh
        regressor = reg.LocalLinearRegression(nbMesh)
        # final value
        vFunction = Utils.ZeroPayOff()

        # initial values
        initialStock = np.zeros(1) + maxLevelStorage
        initialRegime = 0  # here do nothing (no injection, no withdrawal)

        # Optimize
        ###########
        bOneFile = True
        # link the simulations to the optimizer
        storage.setSimulator(backSimulator)
        valueOptimMpi = dynmpi.DynamicProgrammingByRegressionDist(grid, storage, regressor, vFunction, initialStock, initialRegime, fileToDump, bOneFile)
        print("valOP", valueOptimMpi)

        # presumably the simulation relies on the dump written by the
        # optimization: synchronize all ranks before starting it
        world.barrier()

        # a forward simulator
        #####################
        bForward = True
        forSimulator = sim.MeanRevertingSimulator(futureGrid, sigma, mr, rate, maturity, nstep, nbsimulSim, bForward)
        # link the simulations to the optimizer
        storage.setSimulator(forSimulator)
        valSimuMpi = srtmpi.SimulateRegressionControlDist(grid, storage, vFunction, initialStock, initialRegime, fileToDump, bOneFile)
        print("valSimuMpi", valSimuMpi)

        if world.rank == 0:
            self.assertAlmostEqual(valueOptimMpi, valSimuMpi,
                                   msg="Re-adjust tolerance edge please",
                                   delta=accuracyClose)
            print("valOP", valueOptimMpi, "valSimuMpi", valSimuMpi)
        return valueOptimMpi, valSimuMpi

    def testGasStorageSwitchCostMpiHighLevel(self):
        """Switching-cost storage with the optimizer built without a regime curve."""
        world = self._mpiWorld()
        # Test methods must return None: returning a value raises a
        # DeprecationWarning since Python 3.11, so the result is discarded.
        self._optimizeThenSimulate(world, "CondExpGasSwiCostHLMpi")

    def test_switchingVaryingRegimeStorageMpi(self):
        """Switching-cost storage with a time-dependent regime curve."""
        world = self._mpiWorld()
        # regime values allowed
        #######################
        tvalues = np.array([0., 1e-3, 1. / 4 + 1e-3, 1. / 2 + 1e-3, 3. / 4. + 1e-3, 1.])
        timeRegimes = StOptGrids.OneDimSpaceGrid(tvalues)
        # one regime value per date in tvalues
        regValues = [3, 3, 1, 3, 2, 3]
        regime = Utils.RegimeCurve(timeRegimes, regValues)
        self._optimizeThenSimulate(world, "CondExpGas", regime)

if __name__ == "__main__":
    # Run every TestCase defined in this module through the unittest CLI.
    unittest.main()
