# Copyright (C) 2018 EDF
# All Rights Reserved
# This code is published under the GNU Lesser General Public License (GNU LGPL)
import numpy as np
import StOptReg as reg
import StOptGrids
import StOptGeners
from StOptGlobal import StateWithStocks



def SimulateRegressionControlUsingControl(p_grid, p_optimize, p_funcFinalValue, p_pointStock,
                                          p_initialRegime, p_fileToDump, key="Control"):
    """
    Simulate the optimal strategy forward in time, using the optimal
    controls previously dumped in an archive.

    Parameters:
    -----------
    p_grid
        Grid used for deterministic state (stocks for example).
        Not read by this function: at each time step the grid attached to
        the control read from the archive is used instead. Kept in the
        signature for compatibility with callers.
    p_optimize
        Optimizer defining the optimization between two time steps.
        It provides the simulator (``getSimulator``), the number of simulated
        functions (``getSimuFuncSize``) and the per-simulation step
        (``stepSimulateControl``).
    p_funcFinalValue
        Object defining the final value; its ``set(regime, ptStock,
        stochasticRealization)`` result is added to the first row of the
        cost function for each simulation.
    p_pointStock
        Initial point stock.
    p_initialRegime
        Regime at initial date.
    p_fileToDump
        Name of the archive file, opened in read mode, from which the
        controls are read.
    key
        Name under which the controls are stored in the archive
        (default ``"Control"``).

    Returns:
    --------
    Mean of all entries of the cost function array, whose shape is
    (``p_optimize.getSimuFuncSize()``, number of simulations).
    """
    simulator = p_optimize.getSimulator()
    nsteps = simulator.getNbStep()
    nsims = simulator.getNbSimul()
    # Every simulation starts from the same regime, stock point and
    # stochastic realization (the first particle column).
    particle0 = simulator.getParticles()[:, 0]
    states = [StateWithStocks(p_initialRegime, p_pointStock, particle0)
              for _ in range(nsims)]
    # Archive holding the controls computed during optimization.
    ar = StOptGeners.BinaryFileArchive(p_fileToDump, "r")
    # Cost function: one column per simulation.
    costFunction = np.zeros((p_optimize.getSimuFuncSize(), nsims))
    # Iterate on time steps.
    for istep in range(nsteps):
        control = ar.readGridAndRegressedValue(istep, key)
        # The grid stored with the control is the one used for this step.
        grid = control[0].getGrid()
        for i, state in enumerate(states):
            # The return value is ignored; presumably the state and the
            # cost column are updated in place -- TODO confirm.
            p_optimize.stepSimulateControl(grid, control, state, costFunction[:, i])
        # Advance the simulator and give each state its new realization.
        particles = simulator.stepForwardAndGetParticles()
        for i, state in enumerate(states):
            state.setStochasticRealization(particles[:, i])

    # Final step: add the terminal value to the first simulated function.
    for i, state in enumerate(states):
        costFunction[0, i] += p_funcFinalValue.set(state.getRegime(),
                                                   state.getPtStock(),
                                                   state.getStochasticRealization())

    # Average gain/cost.
    return costFunction.mean()
