# Copyright (C) 2019 EDF
# All Rights Reserved
# This code is published under the GNU Lesser General Public License (GNU LGPL)
import numpy as np
import StOptGrids 
import StOptTree
import StOptGlobal
import StOptGeners


def DynamicProgrammingByTreeHighLevel(p_grid, p_optimize,  p_funcFinalValue, p_pointStock, p_initialRegime, p_fileToDump) :
    """Solve a stochastic control problem by backward dynamic programming on a tree.

    The terminal condition is obtained by evaluating ``p_funcFinalValue`` on the
    simulator nodes. The simulator attached to ``p_optimize`` is then stepped
    backward ``getNbStep()`` times. At each date:

    - a ``StOptTree.Tree`` is built from the simulator's probabilities and node
      connections;
    - one ``TransitionStepTreeDP`` resolution is performed;
    - the continuation values are written to a binary archive under the name
      "Continuation".

    :param p_grid: grid used for both the current and the next date of each
        transition step, and for the final interpolation
    :param p_optimize: optimizer; supplies the simulator (``getSimulator``) and
        the number of regimes (``getNbRegime``)
    :param p_funcFinalValue: terminal value function, passed to
        ``FinalStepDP.set`` together with the simulator nodes
    :param p_pointStock: stock point where the date-0 values are interpolated
    :param p_initialRegime: index of the regime whose date-0 values are used
    :param p_fileToDump: file name of the archive, opened with mode "w"
    :return: mean of the values returned by the interpolator's ``applyVec`` for
        the chosen regime
    """
    sim = p_optimize.getSimulator()

    # terminal condition evaluated on the simulator nodes
    finalStep = StOptGlobal.FinalStepDP(p_grid, p_optimize.getNbRegime())
    currentValues = finalStep.set(p_funcFinalValue, sim.getNodes())

    # archive collecting the continuation values of every date
    archive = StOptGeners.BinaryFileArchive(p_fileToDump, "w")
    archiveName = "Continuation"

    stepCount = sim.getNbStep()
    for step in range(stepCount):
        sim.stepBackward()
        # conditional expectation operator from this date's probabilities and
        # node connections (probabilities are queried before connections)
        condExp = StOptTree.Tree(sim.getProba(), sim.getConnected())
        # a fresh transition object per date, as p_optimize's simulator has moved
        stepper = StOptGlobal.TransitionStepTreeDP(p_grid, p_grid, p_optimize)
        solved = stepper.oneStep(currentValues, condExp)
        controls = solved[1]
        # the dump receives the values of the following date, before they are replaced
        stepper.dumpContinuationValues(archive, archiveName, step, currentValues, controls, condExp)
        currentValues = solved[0]

    # interpolate the chosen regime's values at the initial stock point
    interpolator = p_grid.createInterpolator(p_pointStock)
    valuesAtStock = interpolator.applyVec(currentValues[p_initialRegime])
    return valuesAtStock.mean()
