# Copyright (C) 2016 EDF
# All Rights Reserved
# This code is published under the GNU Lesser General Public License (GNU LGPL)
import numpy as np
import StOptGeners
import StOptGlobal
import imp


# Simulate the optimal strategy, distributed (MPI) version
# p_grid                   grid used for deterministic state (stocks for example)
# p_optimize               optimizer defining the optimization between two time steps
# p_funcFinalValue         function defining the final value
# p_pointStock             initial point stock
# p_initialRegime          regime at initial date
# p_fileToDump             name of the file used to dump continuation values in optimization
# p_bOneFile               if False, "_<MPI rank>" is appended to p_fileToDump (one archive per rank)
def SimulateRegressionControlDist(p_grid, p_optimize, p_funcFinalValue, p_pointStock, p_initialRegime, p_fileToDump, p_bOneFile) :
    """Simulate the optimal strategy forward in time and return the mean cost.

    The continuation values are read from a binary archive written during
    optimization, then the states are advanced step by step using
    ``StOptGlobal.SimulateStepRegressionControlDist``.

    Parameters
    ----------
    p_grid
        Grid used for the deterministic state (stocks for example). It is
        passed twice to ``SimulateStepRegressionControlDist``
        (presumably as both current and following grid -- confirm against
        the binding's signature).
    p_optimize
        Optimizer; must provide ``getSimulator()`` and ``getSimuFuncSize()``.
    p_funcFinalValue
        Object whose ``set(regime, ptStock, stochasticRealization)`` gives the
        final value of one simulation.
    p_pointStock
        Initial point stock, shared by all simulations.
    p_initialRegime
        Regime at the initial date, shared by all simulations.
    p_fileToDump
        Name of the archive holding the continuation values.
    p_bOneFile
        When falsy, ``"_" + str(rank)`` (rank in ``MPI.COMM_WORLD``) is
        appended to ``p_fileToDump`` before opening the archive. The flag is
        also forwarded unchanged to ``SimulateStepRegressionControlDist``.

    Returns
    -------
    The mean of the whole cost array (``costFunction.mean()``), or ``None``
    when ``mpi4py`` cannot be located (a message is printed in that case).
    """
    # Function-scope import: replaces the deprecated `imp.find_module` probe
    # (the `imp` module no longer exists in Python 3.12+).
    import importlib.util

    # Guard clause: without mpi4py nothing can be simulated here.
    if importlib.util.find_spec('mpi4py') is None:
        print("Not parallel module found ")
        return None

    from mpi4py import MPI

    world = MPI.COMM_WORLD
    # from the optimizer get back the simulator
    simulator = p_optimize.getSimulator()
    nbStep = simulator.getNbStep()
    # every simulation starts from the first column of the particle array
    particle0 = simulator.getParticles()[:, 0]
    nbSimul = simulator.getNbSimul()
    states = [StOptGlobal.StateWithStocks(p_initialRegime, p_pointStock, particle0)
              for _ in range(nbSimul)]

    # Truthiness test rather than `is False`, so that falsy non-singletons
    # (numpy.bool_(False), 0) also select the per-rank archive.
    toDump = p_fileToDump
    if not p_bOneFile:
        toDump += "_" + str(world.rank)

    ar = StOptGeners.BinaryFileArchive(toDump, "r")
    # name for continuation object in archive
    nameAr = "Continuation"
    # cost function: one row per followed function, one column per simulation
    costFunction = np.zeros((p_optimize.getSimuFuncSize(), nbSimul))

    # iterate on time steps
    for istep in range(nbStep):
        stepper = StOptGlobal.SimulateStepRegressionControlDist(
            ar, istep, nameAr, p_grid, p_grid, p_optimize, p_bOneFile)
        # oneStep returns the pair (states, cost function) instead of
        # updating its arguments (noted in the original as different from C++)
        newState = stepper.oneStep(states, costFunction)
        states = newState[0]
        costFunction = newState[1]
        # new stochastic state
        particles = simulator.stepForwardAndGetParticles()
        for i in range(nbSimul):
            states[i].setStochasticRealization(particles[:, i])

    # final : accept to exercise if not already done entirely.
    # The simulator is no longer stepped here, so the factor is read once.
    actu = simulator.getActu()
    for i in range(nbSimul):
        state = states[i]
        # the final value is added to the first row of the cost array only
        costFunction[0, i] += p_funcFinalValue.set(
            state.getRegime(), state.getPtStock(), state.getStochasticRealization()) * actu

    # average gain/cost
    return costFunction.mean()
