# Copyright (C) 2016 EDF
# All Rights Reserved
# This code is published under the GNU Lesser General Public License (GNU LGPL)
import numpy as np


# Defines a simple lake (a stock fed by inflows, from which water can only be withdrawn) for optimization and simulation
# No constraint on the stock at the end of the optimization period (so the lake will be empty)
class OptimizeLake:
    # Lake model used for optimization (stepOptimize) and simulation
    # (stepSimulate, stepSimulateControl).
    # The stock has a single dimension: at each step it receives an inflow, is capped
    # at the maximal level of the grid, and the only command is a withdrawal bounded
    # by m_withdrawalRate. Each unit withdrawn yields a gain of one.

    # Constructor
    # p_withdrawalRate    maximal withdrawal between two time steps
    def __init__(self, p_withdrawalRate):
        self.m_withdrawalRate = p_withdrawalRate
        # simulator giving the inflows: has to be supplied with setSimulator before stepOptimize
        self.m_simulator = None

    # define the diffusion cone for parallelism
    # p_regionByProcessor         region (min max) treated by the processor for the different regimes treated
    # returns a list with one [min, max] pair: the min is the lower bound of the region
    # lowered by one withdrawal, the max is left unbounded (1e30)
    def getCone(self, p_regionByProcessor):
        minReached = p_regionByProcessor[0][0] - self.m_withdrawalRate

        return [[minReached, 1e30]]

    # permits to actualize the time (needed for simulation)
    # nothing to do for this object
    def incrementActuTimeStep(self):

        return None

    # defines a step in optimization
    # p_grid      grid at arrival step after command
    # p_stock     coordinate of the stock point to treat
    # p_condEsp   conditional expectation operator
    # p_phiIn     for each regime  gives the solution calculated at the previous step ( next time step by Dynamic Programming resolution)
    #             one row of p_phiIn[0] per simulation
    # returns a list of two (nbSimul, 1) arrays: the solution, then the optimal control
    # (0. when nothing is done, minus the withdrawal otherwise)
    def stepOptimize(self, p_grid, p_stock, p_condEsp, p_phiIn):

        nbSimul = p_condEsp[0].getNbSimul()
        inflows = self.m_simulator.getParticles()[0, :].transpose()
        nbInflows = len(inflows)
        # level min and max of the stock
        extremeValues = p_grid.getExtremeValues()
        minStorage = extremeValues[0][0]
        maxStorage = extremeValues[0][1]
        # value used to penalize a command that cannot be realized
        infeasible = -1e30
        # tolerance below zero accepted on the maximal withdrawal
        tolerance = 1e3 * np.finfo(float).eps

        cashSameStock = np.zeros(nbInflows)
        condExpSameStock = np.zeros(nbInflows)
        gainWithdrawal = np.zeros(nbInflows)
        cashWithdrawalStock = np.zeros(nbInflows)
        condExpWithdrawalStock = np.zeros(nbInflows)
        withdrawalMax = np.zeros(nbSimul)

        for i, inflow in enumerate(inflows):
            stockWithIn = p_stock[0] + inflow
            phiSimul = p_phiIn[0][i, :].transpose()

            # no withdrawal: the stock only receives the inflow (spilled above the max level)
            stockSame = np.array([min(stockWithIn, maxStorage)], dtype=float)

            if p_grid.isInside(stockSame):
                interpolatorCurrentStock = p_grid.createInterpolator(stockSame)
                # cash flow at current stock and previous step
                cashSameStock[i] = interpolatorCurrentStock.apply(phiSimul)
                # conditional expectation at current stock point
                condExpSameStock[i] = p_condEsp[0].getASimulation(i, interpolatorCurrentStock)
            else:
                cashSameStock[i] = infeasible
                condExpSameStock[i] = infeasible

            # maximal withdrawal: limited by the rate and by the level min of the stock
            withdrawalMax[i] = min(stockWithIn - minStorage, self.m_withdrawalRate)

            if withdrawalMax[i] < -tolerance:
                # only this simulation is penalized: the arrays must be kept as arrays
                cashWithdrawalStock[i] = infeasible
                condExpWithdrawalStock[i] = infeasible
                gainWithdrawal[i] = infeasible
            else:
                stockWithdrawal = np.array([min(stockWithIn - withdrawalMax[i], maxStorage)], dtype=float)
                interpolatorWithdrawalStock = p_grid.createInterpolator(stockWithdrawal)
                cashWithdrawalStock[i] = interpolatorWithdrawalStock.apply(phiSimul)
                condExpWithdrawalStock[i] = p_condEsp[0].getASimulation(i, interpolatorWithdrawalStock)
                gainWithdrawal[i] = withdrawalMax[i]

        # the decision is taken on conditional expectations, the solution stores realized cash flows
        espCondWithdrawal = gainWithdrawal + condExpWithdrawalStock
        withdrawalIsBetter = espCondWithdrawal > condExpSameStock

        solutionAndControl = [np.zeros((nbSimul, 1)), np.zeros((nbSimul, 1))]
        solutionAndControl[0][:, 0] = np.where(withdrawalIsBetter, gainWithdrawal + cashWithdrawalStock, cashSameStock)
        solutionAndControl[1][:, 0] = np.where(withdrawalIsBetter, -withdrawalMax, 0.)

        return solutionAndControl

    # get number of regimes
    def getNbRegime(self):

        return 1

    # number of controls
    def getNbControl(self):

        return 1

    # defines a step in simulation
    # Notice that this implementation is not optimal. In fact no interpolation is necessary for this asset.
    # This implementation is for test and example purpose
    # p_grid          grid at arrival step after command
    # p_continuation  defines the continuation operator for each regime
    # p_state         defines the state value (modified)
    # p_phiInOut      defines the value function (modified): size number of functions to follow
    def stepSimulate(self, p_grid, p_continuation, p_state, p_phiInOut):

        ptStock = p_state.getPtStock()
        # independent copy: updating the stock to return must not alter ptStock,
        # which is still needed to evaluate the withdrawal
        ptStockMax = ptStock.copy()
        # level min and max of the stock
        extremeValues = p_grid.getExtremeValues()
        minStorage = extremeValues[0][0]
        maxStorage = extremeValues[0][1]
        # inflow
        realization = p_state.getStochasticRealization()
        inflow = realization[0]
        # update storage
        ptStockCur = ptStock.copy()
        ptStockCur[0] = min(ptStockCur[0] + inflow, maxStorage)  # deverse if necessary
        # if do nothing
        espCondMax = -1.e30

        if p_grid.isInside(ptStockCur):
            espCondMax = p_continuation[0].getValue(ptStockCur, realization)
            # level reached is the one used for the continuation value (capped at the max level)
            ptStockMax[0] = ptStockCur[0]

        # gain to add at current point
        phiAdd = 0.
        # if withdrawal
        withdrawalMax = min(ptStock[0] + inflow - minStorage, self.m_withdrawalRate)

        if withdrawalMax > 0.:
            gainWithdrawal = withdrawalMax
            # level reached
            ptStockCur = ptStock.copy()
            ptStockCur[0] = min(ptStock[0] + inflow - withdrawalMax, maxStorage)
            continuationWithdrawal = p_continuation[0].getValue(ptStockCur, realization)
            espCondWithdrawal = gainWithdrawal + continuationWithdrawal

            if espCondWithdrawal > espCondMax:
                phiAdd = gainWithdrawal
                ptStockMax[0] = ptStockCur[0]

        # for return
        p_state.setPtStock(ptStockMax)
        p_phiInOut[0] += phiAdd

    # Defines a step in simulation using interpolation in controls
    # p_grid          grid at arrival step after command
    # p_control       defines the controls
    # p_state         defines the state value (modified)
    # p_phiInOut      defines the value function (modified): size number of functions to follow
    def stepSimulateControl(self, p_grid, p_control, p_state, p_phiInOut):

        ptStock = p_state.getPtStock()
        # level min and max of the stock
        extremeValues = p_grid.getExtremeValues()
        minStorage = extremeValues[0][0]
        maxStorage = extremeValues[0][1]
        # inflow
        realization = p_state.getStochasticRealization()
        inflow = realization[0]
        stockWithIn = ptStock[0] + inflow
        # get back control
        control = p_control[0].getValue(ptStock, realization)
        # truncate the control so that the stock reached stays between the min and max levels
        control = max(min(maxStorage - stockWithIn, control), minStorage - stockWithIn)
        ptStock[0] += inflow + control
        # for return
        p_state.setPtStock(ptStock)
        # a withdrawal is a negative control, so the gain is minus the control
        p_phiInOut[0] -= control

    # get size of the  function to follow in simulation
    def getSimuFuncSize(self):

        return 1

    # store the simulator
    def setSimulator(self, p_simulator):

        self.m_simulator = p_simulator

    # get the simulator back (None if setSimulator has not been called)
    def getSimulator(self):

        return self.m_simulator
