#!/usr/bin/python3

# Copyright (C) 2016 EDF
# All Rights Reserved
# This code is published under the GNU Lesser General Public License (GNU LGPL)
import math
import numpy as np
import unittest
import StOptGrids
import StOptReg
import StOptGlobal
import StOptGeners
import Utils
import Simulators as sim
import Optimizers as opt
import importlib

# unit test for global shape
############################

class OptimizerConstructionMpi(unittest.TestCase):
    """Smoke test chaining the MPI-distributed regression DP steps on a swing option.

    Builds a one-dimensional Black-Scholes simulator and a swing optimizer,
    then runs, in order: the distributed final step, one distributed backward
    transition step (whose continuation values are dumped to a per-rank
    archive), and one distributed forward simulation step that reads that
    archive back. No numerical result is asserted: the test passes when every
    call completes without raising.
    """

    def test(self):
        """Run one final, one transition and one simulation step under MPI.

        The test is reported as skipped (instead of silently passing) when
        ``mpi4py`` is not installed.
        """
        # ``importlib.util`` is a submodule: a bare ``import importlib`` does
        # not guarantee it has been loaded, so import it explicitly here.
        import importlib.util

        if importlib.util.find_spec('mpi4py') is None:
            self.skipTest("mpi4py is not installed")
        from mpi4py import MPI
        comm = MPI.COMM_WORLD

        # simulator parameters (one asset)
        initialValues = np.zeros(1, dtype=float) + 1.
        sigma = np.zeros(1) + 0.2
        mu = np.zeros(1) + 0.05
        corr = np.ones((1, 1), dtype=float)
        # number of time steps
        nStep = 30
        # exercise dates
        dates = np.linspace(0., 1., nStep + 1)
        T = dates[-1]
        # simulation number (used for both optimization and simulation)
        nbSimul = 10

        # simulator
        ##########
        bsSim = sim.BlackScholesSimulator(initialValues, sigma, mu, corr, T, len(dates) - 1, nbSimul, False)
        strike = 1.
        # pay off
        payOff = Utils.BasketCall(strike)

        # optimizer
        ##########
        N = 3  # number of exercise dates
        swiOpt = opt.OptimizerSwingBlackScholes(payOff, N)
        # link simulator to optimizer
        swiOpt.setSimulator(bsSim)

        # archive: one file per MPI rank
        ########
        nameArchive = "Archive" + str(comm.rank)
        ar = StOptGeners.BinaryFileArchive(nameArchive, "w+")

        # regressor
        ##########
        nMesh = np.array([1])
        regressor = StOptReg.LocalLinearRegression(nMesh)

        # grids
        ######
        # low value for the meshes
        lowValues = np.array([0.], dtype=float)
        # size of the meshes
        step = np.array([1.], dtype=float)
        # number of steps
        nbStep = np.array([N], dtype=np.int32)
        gridArrival = StOptGrids.RegularSpaceGrid(lowValues, step, nbStep)
        gridStart = StOptGrids.RegularSpaceGrid(lowValues, step, nbStep - 1)

        # pay off function for swing
        ############################
        payOffBasket = Utils.BasketCall(strike)
        payoff = Utils.PayOffSwing(payOffBasket, N)

        # final step
        ############
        asset = bsSim.getParticles()
        # define which direction to split in parallel computing
        bSplit = np.array([1], dtype=bool)
        fin = StOptGlobal.FinalStepDPDist(gridArrival, 1, bSplit)
        values = fin.set(payoff, asset)

        # transition time step
        #####################
        # one step backward and get asset
        asset = bsSim.stepBackwardAndGetParticles()
        # update regressor
        regressor.updateSimulations(0, asset)
        transStep = StOptGlobal.TransitionStepRegressionDPDist(gridStart, gridArrival, swiOpt)
        valuesNext = transStep.oneStep(values, regressor)
        bOneFile = 0  # multiple files
        transStep.dumpContinuationValues(ar, "Continuation", 1, valuesNext[0], valuesNext[1], regressor, bOneFile)

        # simulate time step
        ####################
        # state of each simulation: one regime, all with same stock level
        # (dimension 2), same realization of simulation (dimension 3)
        vecOfStates = [StOptGlobal.StateWithStocks(1, np.array([0.]), np.zeros(1))
                       for _ in range(nbSimul)]
        simStep = StOptGlobal.SimulateStepRegressionDist(ar, 1, "Continuation", gridArrival, swiOpt, bOneFile)
        phi = np.zeros((1, nbSimul))
        # the result is not inspected: the test only checks that the step runs
        simStep.oneStep(vecOfStates, phi)


if __name__ == "__main__":
    # Execute the test cases of this module when it is run as a script.
    unittest.main()
