# All Rights Reserved
# This code is published under the GNU Lesser General Public License (GNU LGPL)
import StOptSDDP
import StOptGeners
import Simulators
import SDDPOptimizers
import imp

## @package backwardForwardSDDP
## backward and forward sweep by SDDP
## Runs the SDDP backward and forward sweeps from Python
## This is the python version of backwardForwardSDDP.cpp
## author Xavier Warin

## @brief Run SDDP backward sweeps, with periodic forward sweeps to check convergence
## @param  p_optimizer           optimizer solving the LP of one time step for one simulation; also provides the backward and forward simulators
## @param  p_nbSimulCheckForSimu number of simulations used by the forward sweep that checks convergence
## @param  p_initialState        initial state at the beginning of the simulation
## @param  p_finalCut            object holding the final cuts
## @param  p_dates               vector of exercise dates; the last date corresponds to the final cut object
## @param  p_meshForReg          number of meshes in each direction for the regression
## @param  p_nameRegressor       name of the archive storing the regressors
## @param  p_nameCut             name of the archive storing the cuts
## @param  p_nameVisitedStates   name of the archive storing the visited states
## @param  p_iter                maximum number of SDDP iterations (passed by value: unlike the C++ version, the caller's variable is not updated with the number of iterations achieved)
## @param  p_accuracy            requested accuracy, as the relative gap |backward - forward| / |forward| (a fraction, not a percentage); the achieved accuracy is printed, not returned
## @param  p_nStepConv           convergence is checked at the first iteration and then every p_nStepConv iterations
## @return [backward value of the last iteration, forward value of the most recent convergence check]
def backwardForwardSDDP(p_optimizer, p_nbSimulCheckForSimu, p_initialState, p_finalCut,p_dates,
                        p_meshForReg, p_nameRegressor, p_nameCut, p_nameVisitedStates,
                        p_iter, p_accuracy,p_nStepConv):
    
    try:
        imp.find_module('mpi4py')
        Found = True
    except:
        print("Not parallel module found ")
        Found = False
    if Found:
        from mpi4py import MPI
        world = MPI.COMM_WORLD
    simulatorForOptim = p_optimizer.getSimulatorBackward()
    simulatorForSim = p_optimizer.getSimulatorForward()
    
    if Found:
        bTask = (world.rank==0)
    else:
        bTask = True


    if bTask:
        # archive for regressors
        archiveForRegressor =StOptGeners.BinaryFileArchive(p_nameRegressor,"w")
        # archive of first set of admissible states
        archiveForInitialState = StOptGeners.BinaryFileArchive(p_nameVisitedStates,"w")
        # list of states
        vecSetOfStates= [None] * (p_dates.size-1)
         # create regressors for each date except the last
        for  idate in range(p_dates.size-2,0,-1):
            simulatorForOptim.updateDateIndex(idate)
            # initial regressor
            particlesCurrent = simulatorForOptim.getParticles()
            # regressor
            regCurrent =StOptSDDP.LocalLinearRegressionForSDDP(False,particlesCurrent,p_meshForReg)
            # store the regressor
            archiveForRegressor.dump(regCurrent)
            setOfStates = StOptSDDP.SDDPVisitedStates(regCurrent.getNbMeshTotal())
            setOfStates.addVisitedStateForAll(p_optimizer.oneAdmissibleState(p_dates[idate + 1]),regCurrent)
            # ad admissible state
            vecSetOfStates[idate] =setOfStates
                                              
        simulatorForOptim.updateDateIndex(0)
        particlesInit = simulatorForOptim.getParticles()
        regInit= StOptSDDP.LocalLinearRegressionForSDDP(True,particlesInit,p_meshForReg)
        archiveForRegressor.dump(regInit)
        setOfStates = StOptSDDP.SDDPVisitedStates(1)
        setOfStates.addVisitedStateForAll(p_initialState,regInit)
        vecSetOfStates[0]= setOfStates

        for idate in range(0,p_dates.size-1):
            archiveForInitialState.dump(vecSetOfStates[idate])


    iterMax = p_iter;
    p_iter = 0;
    accuracy = p_accuracy;
    p_accuracy = 1e10;
    # archive to read regressor : currently all processor reads the regression
    archiveReadRegressor = StOptGeners.BinaryFileArchive(p_nameRegressor, "r");
    # only create for first task
    archiveForCuts = StOptGeners.BinaryFileArchive(p_nameCut, "w+");
    # to store  all backward values
    backwardValues =[]
    # forward value
    forwardValueForConv = 0.
    istep = 0;
    # store evolution of convergence
    backwardMinusForwardPrev = 0;
    while ((accuracy < p_accuracy) and (p_iter < iterMax)):
        # increase step
        istep += 1;
        # actualize time for simulators
        simulatorForOptim.resetTime()
        simulatorForSim.resetTime()
        
        if Found:
            world.Barrier()
        # backward sweep
        backwardValues.append(StOptSDDP.backwardSDDP(p_optimizer , simulatorForOptim, p_dates,  p_initialState,
                                                     p_finalCut, archiveReadRegressor,
                                                     p_nameVisitedStates , archiveForCuts))        
        if Found:
            world.Barrier()

        if ((istep == p_nStepConv) or (p_iter == 0)):
            istep = 0
            oldParticleNb = simulatorForSim.getNbSimul()
            simulatorForSim.updateSimulationNumberAndResetTime(p_nbSimulCheckForSimu)
            bIncreaseCut = False;
            
            forwardValueForConv =  StOptSDDP.forwardSDDP(p_optimizer , simulatorForSim, p_dates, p_initialState, p_finalCut, bIncreaseCut, archiveReadRegressor,
                                                         archiveForCuts, p_nameVisitedStates)

            p_accuracy  = abs((backwardValues[p_iter] - forwardValueForConv) / forwardValueForConv)
            simulatorForSim.updateSimulationNumberAndResetTime(oldParticleNb)

            if bTask:
                print(" ACCURACY ", p_accuracy, " Backward ", backwardValues[p_iter], " Forward ", forwardValueForConv , "p_iter "  ,p_iter , " accuracy " ,  p_accuracy )
            backwardMinusForward = backwardValues[p_iter] - forwardValueForConv;
            if (p_iter > 0):
                if (backwardMinusForward * backwardMinusForwardPrev < 0):
                    if bTask:
                        print(" Curve are crossing : increase sample and simulations to get more accurate solution, decrease step for checking convergence")
                    #exit
                    p_accuracy = 0.
            backwardMinusForwardPrev = backwardMinusForward;
        elif(p_iter>0):
            backwardMinusForward = backwardValues[p_iter] - forwardValueForConv;
            if (backwardMinusForward * backwardMinusForwardPrev < 0):
                if bTask:
                    print(" Curve are crossing : increase sample and simulations to get more accurate solution, decrease step for checking convergence")
                    #exit
                p_accuracy = 0.
        p_iter += 1;
    if Found:
        world.Barrier()
    return [backwardValues[p_iter - 1], forwardValueForConv]


